pipeline_simulation.py

A mini pipeline parallelism demo with a forward pass through distributed stages

This script simulates pipeline parallelism by splitting a simple neural network across multiple processes and passing activations forward.

What It Does

  1. Creates a simple “model” (matrix multiplications) split across stages
  2. Input data enters at rank 0
  3. Activations flow forward through each stage via send/recv (see the sketch below)
  4. Final output emerges at the last rank
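
Stripped of setup, the forward step at each stage reduces to a short receive/compute/send pattern. The snippet below is condensed from the full source at the bottom of this page; rank, stage, batch_size, and hidden_size are the same names used there:

# Condensed per-stage forward step (names match the full source below)
if rank == 0:
    activations = torch.randn(batch_size, hidden_size)   # first stage creates the input
else:
    activations = torch.zeros(batch_size, hidden_size)   # pre-allocated receive buffer
    dist.recv(activations, src=rank - 1)                 # blocks until the previous stage sends

output = stage(activations)                              # this stage's slice of the model

if rank != world_size - 1:
    dist.send(output, dst=rank + 1)                      # hand activations to the next stage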

Pipeline Architecture

     [Input]
        │
        ▼
   ┌─────────┐
   │ Stage 0 │  ← Rank 0: First linear layer
   │  Linear │
   └────┬────┘
        │ send activations
        ▼
   ┌─────────┐
   │ Stage 1 │  ← Rank 1: Second linear layer
   │  Linear │
   └────┬────┘
        │ send activations
        ▼
   ┌─────────┐
   │ Stage 2 │  ← Rank 2: Third linear layer
   │  Linear │
   └────┬────┘
        │ send activations
        ▼
   ┌─────────┐
   │ Stage 3 │  ← Rank 3: Final layer + output
   │  Linear │
   └─────────┘
     [Output]

Run It

python tutorial/part1-distributed/chapter02-point-to-point/scripts/pipeline_simulation.py
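
The script also accepts the flags defined in its argparse setup (see the source below), for example:

python tutorial/part1-distributed/chapter02-point-to-point/scripts/pipeline_simulation.py --world-size 4 --num-microbatches 8
python tutorial/part1-distributed/chapter02-point-to-point/scripts/pipeline_simulation.py --visualize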

Key Concepts Demonstrated

  • Pipeline parallelism - Model split across devices
  • Activation passing - Intermediate results flow between stages
  • Sequential dependency - Each stage waits for the previous

Why This Matters

Real pipeline parallelism (like GPipe or PipeDream) uses this same send/recv pattern but with:

  • Micro-batching to keep all stages busy (see the sketch after this list)
  • Backward pass for gradient computation
  • Gradient checkpointing to save memory
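
To see why micro-batching matters: in a GPipe-style schedule with p stages and m microbatches, the idle "bubble" fraction of a step is (p - 1) / (m + p - 1), so more microbatches mean less idle time. Below is a minimal sketch of the splitting step, assuming a hypothetical full_batch tensor shaped like the input rank 0 creates in this script:

# Hypothetical pre-processing: split one batch into microbatches for the pipeline
import torch

full_batch = torch.randn(128, 64)                 # (batch, hidden), as in the script
microbatches = torch.chunk(full_batch, 4, dim=0)  # four microbatches of 32 rows each

p, m = 4, 4                                       # stages, microbatches
bubble = (p - 1) / (m + p - 1)                    # 3/7 ≈ 0.43 of step time spent idle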

Source Code

#!/usr/bin/env python3
"""
Pipeline Parallelism Simulation

This script simulates how pipeline parallelism works in practice:
- A "model" is split across multiple processes (stages)
- Data flows forward through the pipeline via send/recv
- Each stage processes its part of the model

Real pipeline parallelism also has backward pass and more complex
scheduling (1F1B, interleaved), but this shows the core concept.

Usage:
    python pipeline_simulation.py
    python pipeline_simulation.py --batch-size 64 --hidden-size 128

The "model" is just a series of Linear layers:
    Input → Linear → ReLU → Linear → ReLU → ... → Output
            Stage 0         Stage 1         Stage N-1
"""

import argparse
import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


class PipelineStage(nn.Module):
    """
    One stage of our pipeline (a simple feed-forward block).

    In a real model like GPT, each stage might be several transformer layers.
    """

    def __init__(self, input_size: int, output_size: int, stage_id: int):
        super().__init__()
        self.stage_id = stage_id
        self.linear = nn.Linear(input_size, output_size)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.linear(x))


def pipeline_worker(
    rank: int,
    world_size: int,
    batch_size: int,
    hidden_size: int,
    num_microbatches: int,
    backend: str
) -> None:
    """
    Worker function for one pipeline stage.

    Args:
        rank: This stage's rank
        world_size: Total number of stages
        batch_size: Per-microbatch batch size
        hidden_size: Model hidden dimension
        num_microbatches: Number of microbatches to process
        backend: Distributed backend
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    # gloo runs on CPU tensors; nccl requires one CUDA device per rank
    device = torch.device(f"cuda:{rank}" if backend == "nccl" else "cpu")

    # Create this stage's model part
    stage = PipelineStage(hidden_size, hidden_size, rank).to(device)

    # =========================================================================
    # Pipeline Forward Pass
    # =========================================================================
    # We process multiple microbatches to show the pipelining effect.
    # In practice, while Stage 1 processes microbatch 0,
    # Stage 0 can start processing microbatch 1 (pipeline filling).

    timings = []

    for mb_idx in range(num_microbatches):
        start_time = time.perf_counter()

        if rank == 0:
            # First stage: generate input (in reality, this comes from data loader)
            activations = torch.randn(batch_size, hidden_size, device=device)
            print(f"[Stage {rank}] Microbatch {mb_idx}: Created input "
                  f"(shape: {list(activations.shape)})")
        else:
            # Other stages: receive activations from previous stage
            activations = torch.zeros(batch_size, hidden_size, device=device)
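            # dist.recv is blocking and fills the pre-allocated buffer in place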
            dist.recv(activations, src=rank - 1)
            print(f"[Stage {rank}] Microbatch {mb_idx}: Received from stage {rank - 1}")

        # Process through this stage's model part
        with torch.no_grad():
            output = stage(activations)

        if rank == world_size - 1:
            # Last stage: we're done (in reality, compute loss here)
            print(f"[Stage {rank}] Microbatch {mb_idx}: Completed! "
                  f"Output mean: {output.mean().item():.4f}")
        else:
            # Send activations to next stage
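            # dist.send blocks until the transfer completes; dist.isend/dist.irecv are the async variants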
            dist.send(output, dst=rank + 1)
            print(f"[Stage {rank}] Microbatch {mb_idx}: Sent to stage {rank + 1}")

        elapsed = time.perf_counter() - start_time
        timings.append(elapsed)

    # Synchronize before printing summary
    dist.barrier()

    if rank == 0:
        avg_ms = sum(timings) / len(timings) * 1000
        print(f"\n{'='*60}")
        print("PIPELINE SIMULATION COMPLETE")
        print(f"{'='*60}")
        print(f"Stages: {world_size}")
        print(f"Microbatches: {num_microbatches}")
        print(f"Batch size per microbatch: {batch_size}")
        print(f"Hidden size: {hidden_size}")
        print(f"Avg time per microbatch on stage 0: {avg_ms:.2f} ms")
        print("\nIn a real pipeline:")
        print("  - Stages process different microbatches in parallel")
        print("  - Backward pass sends gradients in reverse")
        print("  - 1F1B schedule optimizes memory usage")
        print(f"{'='*60}")

    dist.destroy_process_group()


def visualize_pipeline():
    """Print a visualization of pipeline parallelism."""
    print("""
    ═══════════════════════════════════════════════════════════════════════
    PIPELINE PARALLELISM VISUALIZATION
    ═══════════════════════════════════════════════════════════════════════

    The model is split across GPUs/processes:

    Full Model:     [Embed] → [Layer 0-3] → [Layer 4-7] → [Layer 8-11] → [Head]
                        ↓           ↓            ↓             ↓           ↓
    Pipeline:       Stage 0     Stage 1      Stage 2       Stage 3     Stage 4

    Data flows through stages via send/recv:

    Time →
    ┌────────────────────────────────────────────────────────────────────────┐
    │                                                                        │
    │  Stage 0:  [MB0 Fwd]─────►[MB1 Fwd]─────►[MB2 Fwd]─────►[MB3 Fwd]     │
    │                 │              │              │              │         │
    │                 ▼              ▼              ▼              ▼         │
    │  Stage 1:      [MB0 Fwd]─────►[MB1 Fwd]─────►[MB2 Fwd]─────►[MB3 Fwd] │
    │                     │              │              │              │     │
    │                     ▼              ▼              ▼              ▼     │
    │  Stage 2:          [MB0 Fwd]─────►[MB1 Fwd]─────►[MB2 Fwd]─────►...   │
    │                                                                        │
    └────────────────────────────────────────────────────────────────────────┘

    MB = Microbatch, Fwd = Forward pass

    Key insight: While Stage 2 processes MB0, Stage 1 processes MB1,
    and Stage 0 processes MB2. The pipeline is "full" of work!

    ═══════════════════════════════════════════════════════════════════════
    """)


def main():
    parser = argparse.ArgumentParser(
        description="Simulate pipeline parallelism with send/recv"
    )
    parser.add_argument(
        "--world-size", "-w",
        type=int,
        default=4,
        help="Number of pipeline stages (default: 4)"
    )
    parser.add_argument(
        "--batch-size", "-b",
        type=int,
        default=32,
        help="Batch size per microbatch (default: 32)"
    )
    parser.add_argument(
        "--hidden-size",
        type=int,
        default=64,
        help="Model hidden dimension (default: 64)"
    )
    parser.add_argument(
        "--num-microbatches", "-m",
        type=int,
        default=4,
        help="Number of microbatches to process (default: 4)"
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="gloo",
        choices=["gloo", "nccl"],
        help="Distributed backend"
    )
    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Show pipeline visualization and exit"
    )
    args = parser.parse_args()

    if args.visualize:
        visualize_pipeline()
        return

    print("=" * 60)
    print(" PIPELINE PARALLELISM SIMULATION")
    print("=" * 60)
    print(f"Number of stages: {args.world_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Hidden size: {args.hidden_size}")
    print(f"Microbatches: {args.num_microbatches}")
    print("=" * 60 + "\n")

    mp.spawn(
        pipeline_worker,
        args=(
            args.world_size,
            args.batch_size,
            args.hidden_size,
            args.num_microbatches,
            args.backend
        ),
        nprocs=args.world_size,
        join=True
    )


if __name__ == "__main__":
    main()