ML Systems Infrastructure Tutorial
From distributed primitives to production RLHF: A hands-on journey through ML infrastructure
This tutorial takes you from zero to understanding how large-scale ML systems work. If you’re comfortable with PyTorch and understand transformers but wonder “how do people actually train GPT-4?”, this is for you.
Who This Is For
- Strong ML background: You know PyTorch, can train models, understand attention
- New to systems: You haven’t done distributed training, don’t know NCCL from TCP
- Curious about scale: You want to understand how 1000-GPU training actually works
What You’ll Learn
By the end of this tutorial, you’ll understand:
- How GPUs talk to each other - Communication primitives that enable distributed training
- How to parallelize training - Data, tensor, and pipeline parallelism strategies
- How inference servers work - KV cache, batching, and speculative decoding
- How RLHF systems are built - The four-model dance that makes ChatGPT possible
Tutorial Structure
Part I: Foundations of Distributed Computing (Chapters 1-4)
Start here. These concepts are the alphabet of distributed systems.
| Chapter | Topic | Key Concepts |
|---|---|---|
| Chapter 1 | Your First Distributed Program | rank, world_size, process groups |
| Chapter 2 | Point-to-Point Communication | send/recv, deadlock avoidance |
| Chapter 3 | Collective Operations | all_reduce, broadcast, scatter |
| Chapter 4 | NCCL and GPU Topology | Ring/Tree algorithms, NVLink |
Part II: Parallelism Strategies (Chapters 5-7)
Now you know the primitives. Let’s use them to train models that don’t fit on one GPU.
| Chapter | Topic | Key Concepts |
|---|---|---|
| Chapter 5 | Data Parallelism Deep Dive | DDP, FSDP, ZeRO stages |
| Chapter 6 | Tensor Parallelism | Column/row parallel, Megatron-style |
| Chapter 7 | Pipeline & Expert Parallelism | 1F1B scheduling, MoE |
Part III: LLM Inference Systems (Chapters 8-11)
Training is half the story. Serving models efficiently is the other half.
| Chapter | Topic | Key Concepts |
|---|---|---|
| Chapter 8 | Server Anatomy | Request lifecycle, prefill/decode |
| Chapter 9 | KV Cache Management | PagedAttention, RadixCache |
| Chapter 10 | Scheduling & CUDA Graphs | Zero-overhead scheduling |
| Chapter 11 | Speculative & Constraint Decoding | Draft models, structured output |
Part IV: RLHF Systems (Chapters 12-14)
The grand finale: training models with human feedback.
| Chapter | Topic | Key Concepts |
|---|---|---|
| Chapter 12 | RL Fundamentals for LLMs | PPO, GAE, policy gradients |
| Chapter 13 | RLHF Computation Flow | Four models, reward calculation |
| Chapter 14 | RLHF System Architecture | Co-located vs disaggregated |
How to Use This Tutorial
Prerequisites
pip install torch # Core requirement
pip install gymnasium # For RL chapter (optional)
No GPU required! All scripts have CPU fallback with the gloo backend.
Learning Path
Recommended order: Follow chapters sequentially. Each builds on the previous.
Hands-on learning: Each chapter has:
- Conceptual explanation (the chapter page)
- Runnable scripts (linked as sub-pages)
- Exercises to try
Running the Scripts
# Chapter 1: Your first distributed program
python tutorial/part1-distributed/chapter01-first-program/scripts/verify_setup.py
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py
# Chapter 3: Collective operations
python tutorial/part1-distributed/chapter03-collectives/scripts/collective_cheatsheet.py
Quick Start: See Something Work!
Want to jump in immediately? Run this:
python tutorial/part1-distributed/chapter01-first-program/scripts/verify_setup.py # Check your environment
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py # Your first distributed program!
You should see 4 processes talking to each other!
Core Mental Models
The Parallelism Zoo
Problem: Training too slow, but the model fits on one GPU?
└── Data Parallelism (replicate model, split the data)
    └── Replicas don't fit → ZeRO/FSDP (shard optimizer states, gradients, parameters)
Problem: Model too big for one GPU?
├── One layer too big → Tensor Parallelism (split layers)
└── All layers too big → Pipeline Parallelism (split model)
Problem: Model is MoE?
└── Add Expert Parallelism (distribute experts)
The Memory Hierarchy
Fast ──────────────────────────────────────────► Slow
GPU L2     GPU HBM    CPU RAM    Network    NVMe SSD
90TB/s     3TB/s      200GB/s    50GB/s     7GB/s
Goal: Keep computation in fast memory
Strategy: Overlap communication with computation
The Inference Pipeline
Request → Tokenizer → Scheduler → Model Runner → Detokenizer → Response
↓
[Prefill: Process prompt]
↓
[Decode: Generate tokens]
↓
[KV Cache: Remember context]
“The best way to understand distributed systems is to build one. The second best way is this tutorial.”
Happy learning!
Chapter 1: Your First Distributed Program
“The journey of a thousand GPUs begins with a single init_process_group.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain what rank, world_size, and process groups mean
- Initialize a distributed PyTorch environment
- Run code across multiple processes that communicate with each other
- Understand why we use multiprocessing (not multithreading) for distributed training
Prerequisites
- Python 3.8+
- PyTorch installed (pip install torch)
- Basic understanding of PyTorch tensors
- No GPU required (we’ll use CPU fallback)
Concept Overview
Why Distributed Computing?
Imagine you’re training a large language model. A single GPU has maybe 80GB of memory, but your model needs 500GB just for its parameters. What do you do?
The answer is distributed computing: spreading your computation across multiple GPUs (and multiple machines). But here’s the catch—those GPUs need to talk to each other. A lot.
The Python Problem: GIL
Python has a notorious feature called the Global Interpreter Lock (GIL). It prevents true parallel execution of Python threads. For compute-intensive tasks like deep learning, this is a showstopper.
Thread 1: "I want to multiply matrices!"
Thread 2: "I also want to multiply matrices!"
GIL: "One at a time, please. Thread 1, you go first."
Thread 2: *waits impatiently*
The solution? Multiprocessing. Instead of threads sharing one Python interpreter, we spawn completely separate Python processes. Each process gets its own interpreter, its own memory space, and (crucially) its own GPU.
The Distributed Vocabulary
Before we write code, let’s learn the language:
| Term | Definition | Analogy |
|---|---|---|
| World | All processes participating in training | The entire team |
| World Size | Total number of processes | Team size |
| Rank | Unique ID for each process (0 to world_size-1) | Employee ID |
| Local Rank | Process ID within a single machine | Desk number in an office |
| Process Group | A subset of processes that communicate together | A project sub-team |
| Backend | The communication library (NCCL, Gloo, MPI) | The phone system |
Machine 0 Machine 1
┌──────────────────┐ ┌──────────────────┐
│ GPU 0 (rank=0) │ │ GPU 0 (rank=2) │
│ GPU 1 (rank=1) │◄──────►│ GPU 1 (rank=3) │
└──────────────────┘ └──────────────────┘
local_rank: 0,1 local_rank: 0,1
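As a rough sketch of the relationship (assuming one process per GPU and the same GPU count on every machine; the helper name is illustrative, not a PyTorch API):
def global_rank(node_id: int, local_rank: int, gpus_per_node: int) -> int:
    # Global rank = which machine you are on, times GPUs per machine, plus your desk number.
    return node_id * gpus_per_node + local_rank

# Machine 1, GPU 1 in the diagram above → global rank 3
assert global_rank(node_id=1, local_rank=1, gpus_per_node=2) == 3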
Communication Backends
PyTorch supports three backends for inter-process communication:
| Backend | Best For | Supports CPU? | Supports GPU? |
|---|---|---|---|
| NCCL | GPU training | No | Yes (NVIDIA only) |
| Gloo | CPU training, fallback | Yes | Limited |
| MPI | HPC clusters | Yes | Yes |
Rule of thumb: Use NCCL for GPU training, Gloo for CPU or when NCCL isn’t available.
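That rule of thumb as a small sketch (one reasonable heuristic, not the only one):
import torch
import torch.distributed as dist

def pick_backend() -> str:
    """Prefer NCCL when NVIDIA GPUs and NCCL support are present; otherwise fall back to Gloo."""
    if torch.cuda.is_available() and dist.is_nccl_available():
        return "nccl"
    return "gloo"

print(f"Selected backend: {pick_backend()}")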
Code Walkthrough
Script 1: verify_setup.py
Let’s start by checking if your environment is ready for distributed computing.
This script checks:
- Is PyTorch installed?
- Is CUDA available?
- Which distributed backends are supported?
- How many GPUs do we have?
Run it with:
python tutorial/part1-distributed/chapter01-first-program/scripts/verify_setup.py
Script 2: hello_distributed.py
Now for the main event—your first distributed program!
The key function is torch.distributed.init_process_group():
import torch.distributed as dist
dist.init_process_group(
backend="gloo", # Communication backend
init_method="...", # How processes find each other
world_size=4, # Total number of processes
rank=0 # This process's ID
)
How do processes find each other?
The init_method parameter tells processes how to rendezvous:
"env://"- Use environment variables (MASTER_ADDR, MASTER_PORT)"tcp://hostname:port"- Explicit TCP address"file:///path/to/file"- Shared filesystem (for single-machine testing)
In this tutorial, the worker functions set MASTER_ADDR and MASTER_PORT themselves and rely on the default env:// method; mp.spawn() takes care of creating the processes and passing each one its rank.
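For reference, an explicit env:// rendezvous looks like the sketch below; the address and port are placeholder values for a single-machine, single-process run (on a cluster a launcher such as torchrun usually sets them for you):
import os
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"   # where rank 0 listens
os.environ["MASTER_PORT"] = "29500"       # any free port

dist.init_process_group(
    backend="gloo",
    init_method="env://",  # read MASTER_ADDR / MASTER_PORT from the environment
    rank=0,
    world_size=1,          # a one-process "world", just to show the call
)
dist.destroy_process_group()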
Understanding mp.spawn()
import torch.multiprocessing as mp
def worker(rank, world_size):
# Each process runs this function
print(f"Hello from rank {rank}!")
if __name__ == "__main__":
world_size = 4
mp.spawn(worker, args=(world_size,), nprocs=world_size)
mp.spawn():
- Creates world_size new processes
- Calls worker(rank, world_size) in each process
- Passes rank as the first argument automatically
Run it with:
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py
You should see output from 4 different processes!
Try It Yourself
Exercise 1: Modify World Size
Edit hello_distributed.py to use world_size=8. What changes in the output?
Exercise 2: Process-Specific Work
Modify the worker function so that:
- Even-ranked processes print “I handle even data!”
- Odd-ranked processes print “I handle odd data!”
Hint
if rank % 2 == 0:
print(f"Rank {rank}: I handle even data!")
else:
print(f"Rank {rank}: I handle odd data!")
Exercise 3: Investigate Environment Variables
Add code to print the following environment variables:
- RANK
- WORLD_SIZE
- LOCAL_RANK
- MASTER_ADDR
- MASTER_PORT
What values do they have? (Hint: Use os.environ.get("VAR_NAME", "not set"))
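One possible shape for that code, in the same spirit as the Exercise 2 hint (the helper name is ours, not part of the scripts):
import os

def print_dist_env(rank: int) -> None:
    """Print the distributed-related environment variables seen by this process."""
    for var in ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT"):
        print(f"[Rank {rank}] {var} = {os.environ.get(var, 'not set')}")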
Key Takeaways
- Multiprocessing, not multithreading - Python’s GIL forces us to use separate processes
- Every process has a unique rank - This is how you identify “who am I?”
- init_process_group is the handshake - Processes can’t communicate until they’ve all called this
- Choose the right backend - NCCL for GPUs, Gloo for CPU/fallback
- mp.spawn handles the boilerplate - It creates processes and passes ranks automatically
What’s Next?
In Chapter 2, we’ll learn point-to-point communication—how two specific processes can send data directly to each other. This is the foundation for pipeline parallelism.
Further Reading
verify_setup.py
Check if your environment is ready for distributed computing
This script verifies that PyTorch is installed correctly and checks for distributed computing capabilities.
What It Does
- Checks PyTorch installation and version
- Detects CUDA availability and GPU count
- Lists supported distributed backends (NCCL, Gloo, MPI)
- Provides recommendations based on your setup
Run It
python tutorial/part1-distributed/chapter01-first-program/scripts/verify_setup.py
Source Code
#!/usr/bin/env python3
"""
Verify your environment is ready for distributed PyTorch.
Run this script to check:
- PyTorch installation
- CUDA availability
- Distributed backends
- GPU count
Usage:
python verify_setup.py
"""
import sys
def print_header(title: str) -> None:
"""Print a formatted header."""
print("\n" + "=" * 50)
print(f" {title}")
print("=" * 50)
def check_pytorch() -> bool:
"""Check if PyTorch is installed and print version info."""
print_header("PyTorch Installation")
try:
import torch
print(f"[OK] PyTorch version: {torch.__version__}")
print(f"[OK] PyTorch location: {torch.__file__}")
return True
except ImportError:
print("[FAIL] PyTorch is not installed!")
print(" Install with: pip install torch")
return False
def check_cuda() -> bool:
"""Check CUDA availability and GPU information."""
print_header("CUDA / GPU Status")
import torch
if torch.cuda.is_available():
print(f"[OK] CUDA is available")
print(f"[OK] CUDA version: {torch.version.cuda}")
print(f"[OK] cuDNN version: {torch.backends.cudnn.version()}")
gpu_count = torch.cuda.device_count()
print(f"[OK] Number of GPUs: {gpu_count}")
for i in range(gpu_count):
props = torch.cuda.get_device_properties(i)
memory_gb = props.total_memory / (1024**3)
print(f" GPU {i}: {props.name} ({memory_gb:.1f} GB)")
return True
else:
print("[INFO] CUDA is not available")
print(" This is OK! We'll use CPU with 'gloo' backend")
print(" GPU training requires NVIDIA GPU + CUDA toolkit")
return False
def check_distributed_backends() -> dict:
"""Check which distributed backends are available."""
print_header("Distributed Backends")
import torch.distributed as dist
backends = {
"gloo": dist.is_gloo_available(),
"nccl": dist.is_nccl_available(),
"mpi": dist.is_mpi_available(),
}
for name, available in backends.items():
status = "[OK]" if available else "[NO]"
description = {
"gloo": "CPU training, cross-platform",
"nccl": "GPU training, NVIDIA only",
"mpi": "HPC clusters",
}
print(f"{status} {name.upper()}: {description[name]}")
# Recommendation
print("\nRecommendation:")
if backends["nccl"]:
print(" Use 'nccl' backend for GPU training")
if backends["gloo"]:
print(" Use 'gloo' backend for CPU training or testing")
return backends
def check_multiprocessing() -> bool:
"""Check multiprocessing support."""
print_header("Multiprocessing")
import torch.multiprocessing as mp
# Check start methods
methods = mp.get_all_start_methods()
print(f"[OK] Available start methods: {methods}")
print(f"[OK] Default start method: {mp.get_start_method()}")
# Check if spawn is available (needed for CUDA)
if "spawn" in methods:
print("[OK] 'spawn' method available (required for CUDA)")
return True
else:
print("[WARN] 'spawn' method not available")
return False
def _test_worker(rank: int, world_size: int) -> None:
    """Simple test worker (module-level so mp.spawn can pickle it)."""
    import os
    import torch
    import torch.distributed as dist
    # Use gloo backend (works on CPU)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size
    )
    # Create a tensor and do all_reduce
    tensor = torch.tensor([rank + 1.0])
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    if rank == 0:
        expected = sum(range(1, world_size + 1))
        if tensor.item() == expected:
            print(f"[OK] all_reduce test passed: {tensor.item()} == {expected}")
        else:
            print(f"[FAIL] all_reduce test failed: {tensor.item()} != {expected}")
    dist.destroy_process_group()
def run_simple_test() -> bool:
    """Run a simple distributed test on CPU."""
    print_header("Simple Distributed Test")
    import torch.multiprocessing as mp
    try:
        world_size = 2
        # The worker must be a module-level function: mp.spawn pickles it when
        # launching with the 'spawn' start method, and nested functions can't be pickled.
        mp.spawn(_test_worker, args=(world_size,), nprocs=world_size, join=True)
        return True
    except Exception as e:
        print(f"[FAIL] Distributed test failed: {e}")
        return False
def main() -> None:
"""Run all verification checks."""
print("\n" + "=" * 50)
print(" DISTRIBUTED PYTORCH ENVIRONMENT CHECK")
print("=" * 50)
results = {}
# Run checks
results["pytorch"] = check_pytorch()
if not results["pytorch"]:
print("\n[ABORT] PyTorch is required. Please install it first.")
sys.exit(1)
results["cuda"] = check_cuda()
results["backends"] = check_distributed_backends()
results["multiprocessing"] = check_multiprocessing()
results["test"] = run_simple_test()
# Summary
print_header("Summary")
all_ok = all([
results["pytorch"],
results["backends"]["gloo"], # At minimum we need gloo
results["multiprocessing"],
results["test"],
])
if all_ok:
print("[OK] Your environment is ready for distributed PyTorch!")
print("\nNext steps:")
print(" 1. Run: python hello_distributed.py")
print(" 2. Continue to Chapter 2: Point-to-Point Communication")
else:
print("[WARN] Some checks failed. Review the output above.")
# Hardware recommendation
if results["cuda"] and results["backends"]["nccl"]:
print("\n[TIP] You have GPU support! For best performance, use:")
print(" backend='nccl' for GPU collective operations")
else:
print("\n[TIP] No GPU detected. All exercises will work on CPU with:")
print(" backend='gloo'")
if __name__ == "__main__":
main()
hello_distributed.py
Your first distributed program - see multiple processes communicate!
This is the “Hello, World!” of distributed computing. It spawns multiple processes that initialize a process group and communicate.
What It Does
- Spawns 4 worker processes using mp.spawn()
- Each process initializes the distributed environment
- Processes perform a simple all_gather to collect data from everyone
- Each process prints what it received
Run It
# Default: 4 processes
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py
# Custom world size
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py --world-size 8
Expected Output
[Rank 0/4] Hello! PID=..., Device=CPU
[Rank 1/4] Hello! PID=..., Device=CPU
[Rank 2/4] Hello! PID=..., Device=CPU
[Rank 3/4] Hello! PID=..., Device=CPU
all_gather results (collected on all ranks):
  From rank 0: [0.0, 1.0]
  From rank 1: [10.0, 11.0]
  From rank 2: [20.0, 21.0]
  From rank 3: [30.0, 31.0]
...
Key Concepts Demonstrated
- mp.spawn() - Creates multiple processes, automatically passing rank
- dist.init_process_group() - The handshake that enables communication
- dist.all_gather() - Collect data from all processes
Source Code
#!/usr/bin/env python3
"""
Your First Distributed Program!
This script demonstrates the fundamentals of distributed PyTorch:
- Process group initialization
- Rank and world_size concepts
- Simple tensor communication with all_gather
Usage:
python hello_distributed.py
python hello_distributed.py --world-size 8
What this script does:
1. Spawns multiple processes (default: 4)
2. Each process initializes a distributed environment
3. Processes share information about themselves
4. We demonstrate all_gather to collect data from all processes
"""
import argparse
import os
from typing import List
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def get_device_info() -> dict:
"""Get information about the current process's compute device."""
if torch.cuda.is_available():
# Get local rank (which GPU this process should use)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device(f"cuda:{local_rank}")
device_name = torch.cuda.get_device_name(device)
else:
device = torch.device("cpu")
device_name = "CPU"
return {
"device": device,
"device_name": device_name,
"pid": os.getpid(),
}
def distributed_worker(rank: int, world_size: int, backend: str) -> None:
"""
The main function that runs in each distributed process.
Args:
rank: Unique identifier for this process (0 to world_size-1)
world_size: Total number of processes
backend: Communication backend ('gloo' or 'nccl')
"""
# =========================================================================
# Step 1: Initialize the process group
# =========================================================================
# This is the "handshake" - all processes must call this before communicating
    # We set the rendezvous environment variables ourselves (mp.spawn does NOT set them):
    # - MASTER_ADDR: Address of the rank 0 process
    # - MASTER_PORT: Port used for the rendezvous
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size
)
# =========================================================================
# Step 2: Get device and process information
# =========================================================================
info = get_device_info()
device = info["device"]
print(f"[Rank {rank}/{world_size}] Hello! PID={info['pid']}, Device={info['device_name']}")
# =========================================================================
# Step 3: Demonstrate all_gather - collect data from all processes
# =========================================================================
# Each process creates a tensor with its rank value
# After all_gather, every process has all tensors
# Create a tensor unique to this rank
my_tensor = torch.tensor([rank * 10.0, rank * 10.0 + 1], device=device)
print(f"[Rank {rank}] My tensor: {my_tensor.tolist()}")
# Prepare a list to receive tensors from all ranks
gathered_tensors: List[torch.Tensor] = [
torch.zeros(2, device=device) for _ in range(world_size)
]
# all_gather: collect my_tensor from all ranks into gathered_tensors
dist.all_gather(gathered_tensors, my_tensor)
# Synchronize before printing (ensures all processes complete the operation)
dist.barrier()
if rank == 0:
print("\n" + "=" * 50)
print("all_gather results (collected on all ranks):")
for i, tensor in enumerate(gathered_tensors):
print(f" From rank {i}: {tensor.tolist()}")
print("=" * 50 + "\n")
# =========================================================================
# Step 4: Demonstrate all_reduce - aggregate values across all processes
# =========================================================================
# Each process contributes its rank, and we sum them all
my_value = torch.tensor([float(rank)], device=device)
dist.all_reduce(my_value, op=dist.ReduceOp.SUM)
if rank == 0:
expected_sum = sum(range(world_size))
print(f"all_reduce (SUM) result: {my_value.item()}")
print(f" Expected: 0 + 1 + ... + {world_size-1} = {expected_sum}")
print(f" Correct: {my_value.item() == expected_sum}\n")
# =========================================================================
# Step 5: Show that rank 0 is special (often used as "master")
# =========================================================================
if rank == 0:
print("I am rank 0 - often called the 'master' or 'coordinator'")
print("Common responsibilities of rank 0:")
print(" - Logging and printing results")
print(" - Saving checkpoints")
print(" - Orchestrating distributed operations")
# =========================================================================
# Step 6: Clean up
# =========================================================================
# Always destroy the process group when done
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Your First Distributed Program")
parser.add_argument(
"--world-size", "-w",
type=int,
default=4,
help="Number of processes to spawn (default: 4)"
)
parser.add_argument(
"--backend", "-b",
type=str,
default="gloo",
choices=["gloo", "nccl"],
help="Distributed backend (default: gloo for CPU compatibility)"
)
args = parser.parse_args()
print("=" * 50)
print(" YOUR FIRST DISTRIBUTED PROGRAM")
print("=" * 50)
print(f"World size: {args.world_size}")
print(f"Backend: {args.backend}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU count: {torch.cuda.device_count()}")
print("=" * 50 + "\n")
# Spawn worker processes
# mp.spawn will:
# 1. Create args.world_size new processes
# 2. Call distributed_worker(rank, world_size, backend) in each
# 3. Pass rank=0,1,2,... to each process automatically
mp.spawn(
distributed_worker,
args=(args.world_size, args.backend),
nprocs=args.world_size,
join=True # Wait for all processes to complete
)
print("\nAll processes completed successfully!")
if __name__ == "__main__":
main()
Chapter 2: Point-to-Point Communication
“Before there were collective operations, there were two processes passing notes in class.”
Learning Objectives
By the end of this chapter, you will be able to:
- Send tensors directly between specific processes using send/recv
- Understand blocking vs non-blocking communication (isend/irecv)
- Recognize and avoid common deadlock patterns
- Implement a simple pipeline pattern
Prerequisites
- Completed Chapter 1: Your First Distributed Program
- Understanding of rank and world_size
- Ability to initialize a distributed process group
Concept Overview
What is Point-to-Point Communication?
In Chapter 1, we used all_gather and all_reduce—these are collective operations where everyone participates. But sometimes you need surgical precision: process 2 needs to send data specifically to process 5, and no one else.
This is point-to-point communication: a direct channel between two specific processes.
Collective (all_reduce): Point-to-Point (send/recv):
[0] [1] [2] [3] [0] ──────► [3]
\ | | / (direct)
\ | | /
▼ ▼ ▼ ▼
[combined]
The Four Operations
| Operation | Blocking? | Description |
|---|---|---|
| send(tensor, dst) | Yes | Send tensor to process dst, wait until done |
| recv(tensor, src) | Yes | Receive tensor from process src, wait until done |
| isend(tensor, dst) | No | Start sending, return immediately with a handle |
| irecv(tensor, src) | No | Start receiving, return immediately with a handle |
The “i” prefix stands for “immediate” (non-blocking).
The Blocking vs Non-Blocking Dance
Blocking operations are simpler but can lead to deadlocks:
# DEADLOCK! Both processes wait for each other forever
# Process 0 # Process 1
send(tensor, dst=1) send(tensor, dst=0)
recv(tensor, src=1) recv(tensor, src=0)
Both processes are stuck on send(), waiting for someone to receive—but no one is receiving because everyone is sending!
The fix: Carefully order operations or use non-blocking variants.
# CORRECT: Interleaved send/recv
# Process 0 # Process 1
send(tensor, dst=1) recv(tensor, src=0)
recv(tensor, src=1) send(tensor, dst=0)
Non-Blocking Operations
Non-blocking operations return a Work handle immediately:
# isend returns immediately, data transfer happens in background
handle = dist.isend(tensor, dst=1)
# Do other work while transfer is in progress
compute_something_else()
# Wait for the transfer to complete before using the tensor
handle.wait()
This is essential for overlapping computation with communication—a key optimization in real systems.
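As a sketch of that overlap (assuming an already-initialized process group; the helper name is illustrative), a symmetric exchange where both sides send and receive without deadlocking:
import torch
import torch.distributed as dist

def exchange_with_peer(my_tensor: torch.Tensor, peer: int) -> torch.Tensor:
    """Swap tensors with `peer` using non-blocking ops so neither side blocks on send."""
    recv_buf = torch.empty_like(my_tensor)
    send_req = dist.isend(my_tensor, dst=peer)  # transfer starts in the background
    recv_req = dist.irecv(recv_buf, src=peer)
    # ...independent computation could run here while both transfers are in flight...
    send_req.wait()
    recv_req.wait()
    return recv_buf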
Pipeline Parallelism: Where Point-to-Point Shines
Point-to-point communication is the backbone of pipeline parallelism. Imagine a model split across 4 GPUs:
Input ──► [Stage 0] ──► [Stage 1] ──► [Stage 2] ──► [Stage 3] ──► Output
GPU 0 GPU 1 GPU 2 GPU 3
│ │ │ │
└──send──────┴─────────────┴─────────────┘
activations flow forward
Each stage processes its part and sends the activations to the next stage. The last stage computes the loss and gradients flow backward via send/recv in the opposite direction.
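A rough sketch of the backward handoff for one middle stage (hypothetical helper; real schedules such as 1F1B interleave this with forward work, and here the forward is simply recomputed for clarity):
import torch
import torch.distributed as dist
import torch.nn as nn

def backward_step(stage: nn.Module, activations: torch.Tensor, rank: int) -> None:
    """Receive dL/d(output) from the next stage, backprop locally, send dL/d(input) upstream."""
    activations = activations.detach().requires_grad_()
    output = stage(activations)                 # recompute this stage's forward
    grad_output = torch.empty_like(output)
    dist.recv(grad_output, src=rank + 1)        # gradient arrives from the next stage
    output.backward(grad_output)                # fills stage parameter grads and activations.grad
    dist.send(activations.grad, dst=rank - 1)   # pass the input gradient to the previous stage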
Code Walkthrough
Script 1: send_recv_basic.py
This script demonstrates the fundamental pattern: passing a tensor through a chain of processes.
Rank 0 ──► Rank 1 ──► Rank 2 ──► Rank 3
(creates) (adds 10) (adds 10) (prints final)
Key points:
- Rank 0 only sends (it’s the source)
- Middle ranks receive then send (they’re relays)
- Last rank only receives (it’s the sink)
Script 2: pipeline_simulation.py
A mini pipeline parallelism demo! We split a simple “model” (just matrix multiplications) across processes and pass activations forward.
Common Pitfalls
Pitfall 1: Mismatched Send/Recv
# Process 0: sends to 1
dist.send(tensor, dst=1)
# Process 1: receives from 2 (WRONG!)
dist.recv(tensor, src=2) # Will hang forever!
Always ensure src/dst pairs match.
Pitfall 2: Buffer Reuse Before Completion
handle = dist.isend(tensor, dst=1)
tensor.fill_(0) # DANGER! Modifying buffer during transfer
handle.wait()
Never modify a tensor while an async operation is in progress.
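Two safe variants of the same fragment (sketches; pick the second if you actually need the overlap):
# Variant 1: wait before touching the buffer again.
handle = dist.isend(tensor, dst=1)
handle.wait()
tensor.fill_(0)  # safe: the transfer has completed

# Variant 2: send a copy so the original can be reused immediately.
send_buf = tensor.clone()
handle = dist.isend(send_buf, dst=1)
tensor.fill_(0)  # safe: send_buf, not tensor, is in flight
handle.wait()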
Pitfall 3: Forgetting to Wait
handle = dist.irecv(tensor, src=0)
# Forgot handle.wait()!
print(tensor) # Garbage data!
Always call .wait() before using received data.
Try It Yourself
Exercise 1: Ring Topology
Modify send_recv_basic.py to create a ring:
- Rank N sends to Rank (N+1) % world_size
- This means Rank 3 sends back to Rank 0
What value should the tensor have after going full circle?
Exercise 2: Bidirectional Communication
Write a script where:
- Even ranks send to odd ranks
- Odd ranks send to even ranks
- All at the same time (use isend/irecv to avoid deadlock)
Exercise 3: Measure Latency
Use time.perf_counter() to measure:
- Time for a blocking send/recv pair
- Time for an isend/irecv pair with wait()
Is there a difference? Why or why not?
Key Takeaways
- Point-to-point is surgical - You specify exactly which process sends and receives
- Blocking can deadlock - Be very careful with send/recv ordering
- Non-blocking enables overlap - isend/irecv let you compute while communicating
- Pipeline parallelism uses this heavily - Activations flow forward, gradients flow backward
- Always wait() before using data - Non-blocking doesn’t mean the data is ready
Mental Model: The Post Office
Think of distributed communication like a post office:
- send = Walking to the post office, handing over your package, and waiting until it’s delivered
- isend = Dropping your package in a mailbox and walking away
- recv = Waiting at home until the doorbell rings
- irecv = Setting up a notification to ping you when a package arrives
The post office (NCCL/Gloo) handles the actual delivery in the background.
What’s Next?
In Chapter 3, we’ll explore collective operations in depth—broadcast, scatter, all_gather, and the all-important all_reduce that makes gradient synchronization possible.
Further Reading
send_recv_basic.py
Demonstrates basic point-to-point communication between processes
This script shows how to pass a tensor through a chain of processes using blocking send and recv operations.
What It Does
- Rank 0 creates a tensor with the value 42 and sends it to rank 1
- Each middle rank receives from the previous rank, adds 10, and sends to the next rank
- The final rank receives the accumulated result and prints it
The Chain Pattern
Rank 0 ──► Rank 1 ──► Rank 2 ──► Rank 3
[42]       [42+10]    [52+10]    [62]
           = [52]     = [62]     (prints)
Run It
python tutorial/part1-distributed/chapter02-point-to-point/scripts/send_recv_basic.py
Expected Output
[Rank 0] Starting chain with value: 42.0
[Rank 1] Received: 42.0
[Rank 1] After adding 10: 52.0
[Rank 2] Received: 52.0
[Rank 2] After adding 10: 62.0
[Rank 3] Received final value: 62.0
Key Concepts Demonstrated
- Blocking send/recv - Operations wait until completion
- Chain topology - Data flows linearly through ranks
- Conditional logic by rank - First, middle, and last ranks have different roles
Source Code
#!/usr/bin/env python3
"""
Basic Point-to-Point Communication: The Chain Pattern
This script demonstrates send/recv in a chain topology:
Rank 0 → Rank 1 → Rank 2 → Rank 3
Each process receives from the previous rank, adds 10, and sends to the next.
Usage:
python send_recv_basic.py
python send_recv_basic.py --world-size 8
Key concepts:
- Blocking send/recv
- Chain topology (avoiding deadlocks)
- Careful ordering of operations
"""
import argparse
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def chain_worker(rank: int, world_size: int, backend: str) -> None:
"""
Worker function implementing a chain communication pattern.
Data flows: Rank 0 → Rank 1 → Rank 2 → ... → Rank (world_size-1)
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# Get device (CPU for gloo, GPU for nccl)
device = torch.device("cpu")
if backend == "nccl" and torch.cuda.is_available():
local_rank = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
# =========================================================================
# The Chain Pattern
# =========================================================================
# This pattern naturally avoids deadlocks because:
# - Rank 0 only sends (no one sends to it first)
# - Middle ranks receive then send (in that order)
# - Last rank only receives (no one receives from it)
if rank == 0:
# First process: create initial tensor and send
tensor = torch.tensor([42.0], device=device)
print(f"[Rank 0] Starting chain with value: {tensor.item()}")
dist.send(tensor, dst=1)
print(f"[Rank 0] Sent to rank 1")
elif rank == world_size - 1:
# Last process: receive and display final result
tensor = torch.zeros(1, device=device)
dist.recv(tensor, src=rank - 1)
print(f"[Rank {rank}] Received final value: {tensor.item()}")
print(f"\n{'='*50}")
print(f"Chain complete!")
print(f"Original: 42.0")
print(f"After {world_size - 1} additions of 10: {tensor.item()}")
print(f"Expected: {42.0 + (world_size - 1) * 10}")
print(f"{'='*50}")
else:
# Middle processes: receive, add 10, send
tensor = torch.zeros(1, device=device)
dist.recv(tensor, src=rank - 1)
print(f"[Rank {rank}] Received: {tensor.item()}")
tensor += 10 # Transform the data
print(f"[Rank {rank}] After adding 10: {tensor.item()}")
dist.send(tensor, dst=rank + 1)
print(f"[Rank {rank}] Sent to rank {rank + 1}")
# Synchronize all processes before cleanup
dist.barrier()
dist.destroy_process_group()
def demonstrate_deadlock_pattern():
"""
Educational function showing a deadlock pattern (DO NOT RUN).
"""
print("""
⚠️ DEADLOCK PATTERN (DO NOT USE):
# Process 0 # Process 1
send(tensor, dst=1) send(tensor, dst=0)
recv(tensor, src=1) recv(tensor, src=0)
Both processes block on send(), waiting for the other to receive.
Neither can proceed → DEADLOCK!
✓ CORRECT PATTERN (interleaved):
# Process 0 # Process 1
send(tensor, dst=1) recv(tensor, src=0)
recv(tensor, src=1) send(tensor, dst=0)
Process 0 sends while Process 1 receives → both can proceed.
""")
def main():
parser = argparse.ArgumentParser(
description="Demonstrate chain pattern point-to-point communication"
)
parser.add_argument(
"--world-size", "-w",
type=int,
default=4,
help="Number of processes in the chain (default: 4)"
)
parser.add_argument(
"--backend", "-b",
type=str,
default="gloo",
choices=["gloo", "nccl"],
help="Distributed backend"
)
parser.add_argument(
"--show-deadlock",
action="store_true",
help="Show deadlock pattern explanation (educational)"
)
args = parser.parse_args()
if args.show_deadlock:
demonstrate_deadlock_pattern()
return
print("=" * 50)
print(" POINT-TO-POINT COMMUNICATION: CHAIN PATTERN")
print("=" * 50)
print(f"World size: {args.world_size}")
print(f"Pattern: Rank 0 → Rank 1 → ... → Rank {args.world_size - 1}")
print(f"Operation: Each rank adds 10 before forwarding")
print("=" * 50 + "\n")
mp.spawn(
chain_worker,
args=(args.world_size, args.backend),
nprocs=args.world_size,
join=True
)
if __name__ == "__main__":
main()
pipeline_simulation.py
A mini pipeline parallelism demo with forward pass through distributed stages
This script simulates pipeline parallelism by splitting a simple neural network across multiple processes and passing activations forward.
What It Does
- Creates a simple “model” (matrix multiplications) split across stages
- Input data enters at rank 0
- Activations flow forward through each stage via send/recv
- Final output emerges at the last rank
Pipeline Architecture
[Input]
│
▼
┌─────────┐
│ Stage 0 │ ← Rank 0: First linear layer
│ Linear │
└────┬────┘
│ send activations
▼
┌─────────┐
│ Stage 1 │ ← Rank 1: Second linear layer
│ Linear │
└────┬────┘
│ send activations
▼
┌─────────┐
│ Stage 2 │ ← Rank 2: Third linear layer
│ Linear │
└────┬────┘
│ send activations
▼
┌─────────┐
│ Stage 3 │ ← Rank 3: Final layer + output
│ Linear │
└─────────┘
[Output]
Run It
python tutorial/part1-distributed/chapter02-point-to-point/scripts/pipeline_simulation.py
Key Concepts Demonstrated
- Pipeline parallelism - Model split across devices
- Activation passing - Intermediate results flow between stages
- Sequential dependency - Each stage waits for the previous
Why This Matters
Real pipeline parallelism (like GPipe or PipeDream) uses this same send/recv pattern but with:
- Micro-batching to keep all stages busy
- Backward pass for gradient computation
- Gradient checkpointing to save memory
Source Code
#!/usr/bin/env python3
"""
Pipeline Parallelism Simulation
This script simulates how pipeline parallelism works in practice:
- A "model" is split across multiple processes (stages)
- Data flows forward through the pipeline via send/recv
- Each stage processes its part of the model
Real pipeline parallelism also has backward pass and more complex
scheduling (1F1B, interleaved), but this shows the core concept.
Usage:
python pipeline_simulation.py
python pipeline_simulation.py --batch-size 64 --hidden-size 128
The "model" is just a series of Linear layers:
Input → Linear → ReLU → Linear → ReLU → ... → Output
Stage 0 Stage 1 Stage N-1
"""
import argparse
import os
import time
from typing import Optional
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
class PipelineStage(nn.Module):
"""
One stage of our pipeline (a simple feed-forward block).
In a real model like GPT, each stage might be several transformer layers.
"""
def __init__(self, input_size: int, output_size: int, stage_id: int):
super().__init__()
self.stage_id = stage_id
self.linear = nn.Linear(input_size, output_size)
self.activation = nn.ReLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.activation(self.linear(x))
def pipeline_worker(
rank: int,
world_size: int,
batch_size: int,
hidden_size: int,
num_microbatches: int,
backend: str
) -> None:
"""
Worker function for one pipeline stage.
Args:
rank: This stage's rank
world_size: Total number of stages
batch_size: Per-microbatch batch size
hidden_size: Model hidden dimension
num_microbatches: Number of microbatches to process
backend: Distributed backend
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
device = torch.device("cpu")
# Create this stage's model part
stage = PipelineStage(hidden_size, hidden_size, rank).to(device)
# =========================================================================
# Pipeline Forward Pass
# =========================================================================
# We process multiple microbatches to show the pipelining effect.
# In practice, while Stage 1 processes microbatch 0,
# Stage 0 can start processing microbatch 1 (pipeline filling).
timings = []
for mb_idx in range(num_microbatches):
start_time = time.perf_counter()
if rank == 0:
# First stage: generate input (in reality, this comes from data loader)
activations = torch.randn(batch_size, hidden_size, device=device)
print(f"[Stage {rank}] Microbatch {mb_idx}: Created input "
f"(shape: {list(activations.shape)})")
else:
# Other stages: receive activations from previous stage
activations = torch.zeros(batch_size, hidden_size, device=device)
dist.recv(activations, src=rank - 1)
print(f"[Stage {rank}] Microbatch {mb_idx}: Received from stage {rank - 1}")
# Process through this stage's model part
with torch.no_grad():
output = stage(activations)
if rank == world_size - 1:
# Last stage: we're done (in reality, compute loss here)
print(f"[Stage {rank}] Microbatch {mb_idx}: Completed! "
f"Output mean: {output.mean().item():.4f}")
else:
# Send activations to next stage
dist.send(output, dst=rank + 1)
print(f"[Stage {rank}] Microbatch {mb_idx}: Sent to stage {rank + 1}")
elapsed = time.perf_counter() - start_time
timings.append(elapsed)
# Synchronize before printing summary
dist.barrier()
if rank == 0:
print(f"\n{'='*60}")
print("PIPELINE SIMULATION COMPLETE")
print(f"{'='*60}")
print(f"Stages: {world_size}")
print(f"Microbatches: {num_microbatches}")
print(f"Batch size per microbatch: {batch_size}")
print(f"Hidden size: {hidden_size}")
print(f"\nIn a real pipeline:")
print(f" - Stages process different microbatches in parallel")
print(f" - Backward pass sends gradients in reverse")
print(f" - 1F1B schedule optimizes memory usage")
print(f"{'='*60}")
dist.destroy_process_group()
def visualize_pipeline():
"""Print a visualization of pipeline parallelism."""
print("""
═══════════════════════════════════════════════════════════════════════
PIPELINE PARALLELISM VISUALIZATION
═══════════════════════════════════════════════════════════════════════
The model is split across GPUs/processes:
Full Model: [Embed] → [Layer 0-3] → [Layer 4-7] → [Layer 8-11] → [Head]
↓ ↓ ↓ ↓ ↓
Pipeline: Stage 0 Stage 1 Stage 2 Stage 3 Stage 4
Data flows through stages via send/recv:
Time →
┌────────────────────────────────────────────────────────────────────────┐
│ │
│ Stage 0: [MB0 Fwd]─────►[MB1 Fwd]─────►[MB2 Fwd]─────►[MB3 Fwd] │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ Stage 1: [MB0 Fwd]─────►[MB1 Fwd]─────►[MB2 Fwd]─────►[MB3 Fwd] │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ Stage 2: [MB0 Fwd]─────►[MB1 Fwd]─────►[MB2 Fwd]─────►... │
│ │
└────────────────────────────────────────────────────────────────────────┘
MB = Microbatch, Fwd = Forward pass
Key insight: While Stage 2 processes MB0, Stage 1 processes MB1,
and Stage 0 processes MB2. The pipeline is "full" of work!
═══════════════════════════════════════════════════════════════════════
""")
def main():
parser = argparse.ArgumentParser(
description="Simulate pipeline parallelism with send/recv"
)
parser.add_argument(
"--world-size", "-w",
type=int,
default=4,
help="Number of pipeline stages (default: 4)"
)
parser.add_argument(
"--batch-size", "-b",
type=int,
default=32,
help="Batch size per microbatch (default: 32)"
)
parser.add_argument(
"--hidden-size",
type=int,
default=64,
help="Model hidden dimension (default: 64)"
)
parser.add_argument(
"--num-microbatches", "-m",
type=int,
default=4,
help="Number of microbatches to process (default: 4)"
)
parser.add_argument(
"--backend",
type=str,
default="gloo",
choices=["gloo", "nccl"],
help="Distributed backend"
)
parser.add_argument(
"--visualize",
action="store_true",
help="Show pipeline visualization and exit"
)
args = parser.parse_args()
if args.visualize:
visualize_pipeline()
return
print("=" * 60)
print(" PIPELINE PARALLELISM SIMULATION")
print("=" * 60)
print(f"Number of stages: {args.world_size}")
print(f"Batch size: {args.batch_size}")
print(f"Hidden size: {args.hidden_size}")
print(f"Microbatches: {args.num_microbatches}")
print("=" * 60 + "\n")
mp.spawn(
pipeline_worker,
args=(
args.world_size,
args.batch_size,
args.hidden_size,
args.num_microbatches,
args.backend
),
nprocs=args.world_size,
join=True
)
if __name__ == "__main__":
main()
Chapter 3: Collective Communication Operations
“In distributed training, all_reduce is the workhorse. Everything else is warm-up.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain what each collective operation does: broadcast, scatter, gather, all_gather, reduce, all_reduce
- Implement gradient synchronization using all_reduce
- Choose the right collective for different scenarios
- Compose primitives to build complex operations (distributed softmax)
Prerequisites
- Completed Chapters 1 & 2
- Understanding of rank, world_size, and basic communication
- Basic linear algebra (matrix operations)
Concept Overview
What are Collective Operations?
Unlike point-to-point (send/recv) where two specific processes communicate, collective operations involve all processes in a group simultaneously. They’re the building blocks of distributed deep learning.
The Collective Operation Zoo
| Operation | Description | Data Flow |
|---|---|---|
| broadcast | One process sends to all | [A] → [A] [A] [A] [A] |
| scatter | Split and distribute | [A B C D] → [A] [B] [C] [D] |
| gather | Collect to one process | [A] [B] [C] [D] → [A B C D] |
| all_gather | Collect to all processes | [A] [B] [C] [D] → [ABCD] [ABCD] [ABCD] [ABCD] |
| reduce | Aggregate to one process | [1] [2] [3] [4] → [10] (sum) |
| all_reduce | Aggregate to all processes | [1] [2] [3] [4] → [10] [10] [10] [10] (sum) |
| reduce_scatter | Reduce + scatter in one step | Each rank holds [A B C D] → rank i gets the sum of chunk i across ranks |
Visual Guide
BROADCAST (src=0): SCATTER (src=0):
┌───┐ ┌───┬───┬───┬───┐
│ A │ ─┐ │ A │ B │ C │ D │
└───┘ │ └─┬─┴─┬─┴─┬─┴─┬─┘
│ ┌───┐ ┌───┐ ┌───┐ ┌───┐ │ │ │ │
└──► A │ │ A │ │ A │ │ A │ ▼ ▼ ▼ ▼
└───┘ └───┘ └───┘ └───┘ ┌───┐┌───┐┌───┐┌───┐
R0 R1 R2 R3 │ A ││ B ││ C ││ D │
└───┘└───┘└───┘└───┘
R0 R1 R2 R3
ALL_GATHER: ALL_REDUCE (sum):
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
│ A │ │ B │ │ C │ │ D │ │ 1 │ │ 2 │ │ 3 │ │ 4 │
└─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘
│ │ │ │ │ │ │ │
└─────┴─────┴─────┘ └─────┴─────┴─────┘
│ │
▼ ▼
┌───────────────┐ ┌──────┐
│ A │ B │ C │ D │ (on all ranks) │ 10 │ (on all ranks)
└───────────────┘ └──────┘
The Star: all_reduce
all_reduce is the most important collective operation in distributed training. Here’s why:
In data-parallel training:
- Each GPU has a copy of the model
- Each GPU computes gradients on different data
- Gradients must be averaged across all GPUs ← all_reduce!
- Each GPU updates its model with the averaged gradients
# This single line synchronizes gradients across all GPUs
dist.all_reduce(gradient, op=dist.ReduceOp.SUM)
gradient /= world_size # Average
Reduction Operations
For reduce and all_reduce, you specify the aggregation operation:
| Operation | Python | Result |
|---|---|---|
| ReduceOp.SUM | sum(values) | Sum all |
| ReduceOp.PRODUCT | prod(values) | Multiply all |
| ReduceOp.MIN | min(values) | Minimum |
| ReduceOp.MAX | max(values) | Maximum |
Memory Semantics: In-Place vs Out-of-Place
Some operations modify tensors in-place, others require output buffers:
# all_reduce: IN-PLACE
tensor = torch.tensor([rank])
dist.all_reduce(tensor) # tensor is modified
# all_gather: OUT-OF-PLACE
tensor = torch.tensor([rank])
gathered = [torch.zeros(1) for _ in range(world_size)]
dist.all_gather(gathered, tensor) # tensor unchanged, gathered filled
Code Walkthrough
Script 1: collective_cheatsheet.py
This script demonstrates all major collective operations with clear before/after output. Run it to see exactly what each operation does.
Script 2: distributed_mean.py
A practical example: computing the mean of distributed data using all_reduce. This is exactly what happens during gradient synchronization.
When to Use What?
| Scenario | Operation | Why |
|---|---|---|
| Share hyperparameters from rank 0 | broadcast | One source, all need it |
| Distribute a dataset | scatter | Split data across workers |
| Collect predictions | gather | Aggregate results |
| Synchronize gradients | all_reduce | Everyone needs the sum |
| Share embeddings for lookup | all_gather | Everyone needs all data |
| Gradient bucketing | reduce_scatter | Efficient for large models |
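For instance, the first row of the table (share hyperparameters from rank 0) might look like this sketch, using broadcast_object_list for a plain Python dict inside an already-initialized process group (the values shown are placeholders):
import torch.distributed as dist

def share_config(rank: int) -> dict:
    """Rank 0 decides the config; every other rank receives it via broadcast."""
    config = [{"lr": 3e-4, "batch_size": 32}] if rank == 0 else [None]
    dist.broadcast_object_list(config, src=0)  # afterwards config[0] is identical on all ranks
    return config[0]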
Try It Yourself
Exercise 1: Distributed Mean (Easy)
Each process has a different number. Use all_reduce to compute the mean across all processes.
Exercise 2: Distributed Argmax (Medium)
Each process has a tensor. Find the global maximum value and which rank has it.
Hint: Use all_reduce with MAX, then all_gather to find who has it.
Exercise 3: Ring All-Reduce (Hard)
Implement all_reduce using only send/recv in a ring pattern:
- Each process sends to (rank + 1) % world_size
- Each process receives from (rank - 1) % world_size
- Iterate until all data is aggregated
This is essentially what NCCL’s ring algorithm does!
Key Takeaways
- all_reduce is king - It’s the foundation of gradient synchronization
- Collective operations are optimized - Don’t reimplement them with send/recv
- Know your memory semantics - Some ops are in-place, some aren’t
- Composability is powerful - Complex operations (softmax) build from primitives (see the sketch after this list)
- scatter vs broadcast - scatter distributes different data, broadcast replicates same data
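As an example of that composability, here is a sketch of a softmax over logits sharded across ranks (each rank holds one shard of the vector; subtracting the global max is the standard numerical-stability trick):
import torch
import torch.distributed as dist

def distributed_softmax(local_logits: torch.Tensor) -> torch.Tensor:
    """Softmax over a vector whose elements are split across ranks; returns this rank's shard."""
    # 1. Global max for numerical stability (all_reduce with MAX).
    global_max = local_logits.max()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX)
    # 2. Local numerators, global denominator (all_reduce with SUM).
    numer = torch.exp(local_logits - global_max)
    denom = numer.sum()
    dist.all_reduce(denom, op=dist.ReduceOp.SUM)
    # 3. Each rank keeps its shard of the probabilities.
    return numer / denom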
Performance Intuition
Communication volume for N processes, each with data size D:
| Operation | Volume per process |
|---|---|
| broadcast | D (receive) |
| scatter | D/N (receive) |
| all_gather | D * (N-1) (send + receive) |
| all_reduce | 2D * (N-1) / N (ring algorithm) |
This is why all_reduce with the ring algorithm is efficient—it has O(D) volume regardless of N (though latency scales with N).
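A quick worked example: with D = 1 GB of gradients per GPU and N = 8 GPUs, a ring all_reduce moves 2 x 1 GB x 7/8 = 1.75 GB through each GPU, and that figure barely grows as N increases because the factor (N-1)/N approaches 1.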
What’s Next?
In Chapter 4, we’ll dive into the actual algorithms NCCL uses (Ring, Tree, Double Binary Tree) and how to inspect GPU topology to understand communication performance.
Further Reading
collective_cheatsheet.py
A visual demonstration of all collective operations
This script is your reference guide to collective operations. It demonstrates each operation with clear before/after output so you can see exactly what happens.
What It Does
Runs through all major collective operations:
- Broadcast - One rank sends to all
- Scatter - Split and distribute
- Gather - Collect to one rank
- All-Gather - Everyone gets everything
- Reduce - Aggregate to one rank
- All-Reduce - Aggregate to all ranks
Run It
python tutorial/part1-distributed/chapter03-collectives/scripts/collective_cheatsheet.py
Expected Output
=== BROADCAST (src=0) ===
Before: Rank 0=[42], Rank 1=[0], Rank 2=[0], Rank 3=[0]
After: Rank 0=[42], Rank 1=[42], Rank 2=[42], Rank 3=[42]
=== SCATTER (src=0) ===
Before: Rank 0=[10,20,30,40]
After: Rank 0=[10], Rank 1=[20], Rank 2=[30], Rank 3=[40]
=== ALL_REDUCE (sum) ===
Before: Rank 0=[1], Rank 1=[2], Rank 2=[3], Rank 3=[4]
After: All ranks=[10] (1+2+3+4)
Quick Reference
| Operation | Before | After |
|---|---|---|
| broadcast | [A] [_] [_] [_] | [A] [A] [A] [A] |
| scatter | [ABCD] [_] [_] [_] | [A] [B] [C] [D] |
| gather | [A] [B] [C] [D] | [ABCD] [_] [_] [_] |
| all_gather | [A] [B] [C] [D] | [ABCD] [ABCD] [ABCD] [ABCD] |
| reduce | [1] [2] [3] [4] | [10] [_] [_] [_] |
| all_reduce | [1] [2] [3] [4] | [10] [10] [10] [10] |
Source Code
#!/usr/bin/env python3
"""
Collective Operations Cheatsheet
This script demonstrates all major collective operations with clear
before/after output. Run it to understand what each operation does.
Operations covered:
- broadcast: One-to-all (same data)
- scatter: One-to-all (different data)
- gather: All-to-one
- all_gather: All-to-all (collect)
- reduce: All-to-one (aggregate)
- all_reduce: All-to-all (aggregate)
Usage:
python collective_cheatsheet.py
python collective_cheatsheet.py --operation all_reduce
"""
import argparse
import os
from typing import List
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def print_state(rank: int, world_size: int, name: str, tensor: torch.Tensor,
is_before: bool = True) -> None:
"""Pretty print tensor state."""
dist.barrier() # Synchronize printing
if is_before:
if rank == 0:
print(f"\n{'='*50}")
print(f" {name}")
print(f"{'='*50}")
print("BEFORE:")
dist.barrier()
print(f" Rank {rank}: {tensor.tolist()}")
else:
dist.barrier()
if rank == 0:
print("AFTER:")
dist.barrier()
print(f" Rank {rank}: {tensor.tolist()}")
dist.barrier()
def demo_broadcast(rank: int, world_size: int, device: torch.device) -> None:
"""
BROADCAST: One process sends the same data to all others.
Use case: Share hyperparameters, model weights initialization,
random seed from rank 0 to all processes.
"""
# Before: only rank 0 has meaningful data
if rank == 0:
tensor = torch.tensor([42.0, 43.0, 44.0], device=device)
else:
tensor = torch.zeros(3, device=device)
print_state(rank, world_size, "BROADCAST (src=0)", tensor, is_before=True)
# Broadcast from rank 0 to all
dist.broadcast(tensor, src=0)
print_state(rank, world_size, "BROADCAST (src=0)", tensor, is_before=False)
if rank == 0:
print("\n[Explanation] Rank 0's data [42, 43, 44] was copied to all ranks.")
def demo_scatter(rank: int, world_size: int, device: torch.device) -> None:
"""
SCATTER: One process distributes different chunks to each process.
Use case: Distribute different batches of data to workers.
"""
# Before: only rank 0 has all data
if rank == 0:
scatter_list = [
torch.tensor([i * 10.0, i * 10 + 1.0], device=device)
for i in range(world_size)
]
print_state(rank, world_size, "SCATTER (src=0)", torch.stack(scatter_list), is_before=True)
else:
scatter_list = None
print_state(rank, world_size, "SCATTER (src=0)", torch.zeros(2, device=device), is_before=True)
# Receive buffer
recv_tensor = torch.zeros(2, device=device)
# Scatter from rank 0
dist.scatter(recv_tensor, scatter_list=scatter_list if rank == 0 else None, src=0)
print_state(rank, world_size, "SCATTER (src=0)", recv_tensor, is_before=False)
if rank == 0:
print("\n[Explanation] Rank 0 distributed different chunks to each rank:")
print(" Rank 0 got [0,1], Rank 1 got [10,11], etc.")
def demo_gather(rank: int, world_size: int, device: torch.device) -> None:
"""
GATHER: Collect data from all processes to one process.
Use case: Collect results, predictions, or metrics to rank 0.
"""
# Each rank has unique data
tensor = torch.tensor([rank * 100.0, rank * 100 + 1.0], device=device)
print_state(rank, world_size, "GATHER (dst=0)", tensor, is_before=True)
# Gather to rank 0
if rank == 0:
gather_list = [torch.zeros(2, device=device) for _ in range(world_size)]
else:
gather_list = None
dist.gather(tensor, gather_list=gather_list, dst=0)
if rank == 0:
result = torch.stack(gather_list)
print_state(rank, world_size, "GATHER (dst=0)", result, is_before=False)
print("\n[Explanation] Rank 0 collected all data. Other ranks have nothing new.")
else:
print_state(rank, world_size, "GATHER (dst=0)", tensor, is_before=False)
def demo_all_gather(rank: int, world_size: int, device: torch.device) -> None:
"""
ALL_GATHER: Collect data from all processes to ALL processes.
Use case: Share embeddings, gather activations for all-to-all attention.
"""
# Each rank has unique data
tensor = torch.tensor([rank + 1.0], device=device)
print_state(rank, world_size, "ALL_GATHER", tensor, is_before=True)
# All-gather: everyone gets everything
gathered = [torch.zeros(1, device=device) for _ in range(world_size)]
dist.all_gather(gathered, tensor)
gathered_tensor = torch.cat(gathered)
print_state(rank, world_size, "ALL_GATHER", gathered_tensor, is_before=False)
if rank == 0:
print("\n[Explanation] Every rank now has [1, 2, 3, 4] (data from all ranks).")
def demo_reduce(rank: int, world_size: int, device: torch.device) -> None:
"""
REDUCE: Aggregate (sum/max/min/product) data from all to one process.
Use case: Compute total loss, find global max, etc.
"""
# Each rank has data to contribute
tensor = torch.tensor([rank + 1.0], device=device)
print_state(rank, world_size, "REDUCE SUM (dst=0)", tensor, is_before=True)
# Reduce to rank 0 with sum
dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
print_state(rank, world_size, "REDUCE SUM (dst=0)", tensor, is_before=False)
if rank == 0:
print(f"\n[Explanation] Rank 0 has sum: 1+2+3+4 = {tensor.item()}")
print(" Other ranks' tensors are unchanged (or undefined).")
def demo_all_reduce(rank: int, world_size: int, device: torch.device) -> None:
"""
ALL_REDUCE: Aggregate and distribute result to ALL processes.
Use case: GRADIENT SYNCHRONIZATION! This is the heart of distributed training.
"""
# Each rank has gradients to contribute
tensor = torch.tensor([rank + 1.0, (rank + 1.0) * 2], device=device)
print_state(rank, world_size, "ALL_REDUCE SUM", tensor, is_before=True)
# All-reduce with sum
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print_state(rank, world_size, "ALL_REDUCE SUM", tensor, is_before=False)
if rank == 0:
print(f"\n[Explanation] All ranks now have the same sum!")
print(f" Element 0: 1+2+3+4 = 10")
print(f" Element 1: 2+4+6+8 = 20")
print(" This is how gradient synchronization works!")
def demo_reduce_scatter(rank: int, world_size: int, device: torch.device) -> None:
"""
REDUCE_SCATTER: Reduce + Scatter in one operation.
Use case: Efficient gradient synchronization for model parallelism,
ZeRO optimizer.
"""
# Each rank has a tensor that will be element-wise reduced, then scattered
tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device) * (rank + 1)
print_state(rank, world_size, "REDUCE_SCATTER SUM", tensor, is_before=True)
# Reduce-scatter
output = torch.zeros(1, device=device)
dist.reduce_scatter(output, [tensor[i:i+1].clone() for i in range(world_size)])
print_state(rank, world_size, "REDUCE_SCATTER SUM", output, is_before=False)
if rank == 0:
print("\n[Explanation] First sums across ranks, then each rank gets one chunk.")
print(" Rank i gets sum of position i from all ranks.")
def collective_worker(rank: int, world_size: int, operation: str, backend: str) -> None:
"""Main worker function."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29503"
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
device = torch.device("cpu")
operations = {
"broadcast": demo_broadcast,
"scatter": demo_scatter,
"gather": demo_gather,
"all_gather": demo_all_gather,
"reduce": demo_reduce,
"all_reduce": demo_all_reduce,
"reduce_scatter": demo_reduce_scatter,
"all": None, # Special case
}
if operation == "all":
for op_name, op_func in operations.items():
if op_name != "all" and op_func is not None:
op_func(rank, world_size, device)
dist.barrier()
if rank == 0:
print("\n" + "─" * 50)
else:
operations[operation](rank, world_size, device)
dist.barrier()
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Collective Operations Cheatsheet")
parser.add_argument(
"--operation", "-o",
type=str,
default="all",
choices=["broadcast", "scatter", "gather", "all_gather",
"reduce", "all_reduce", "reduce_scatter", "all"],
help="Which operation to demonstrate (default: all)"
)
parser.add_argument(
"--world-size", "-w",
type=int,
default=4,
help="Number of processes (default: 4)"
)
parser.add_argument(
"--backend", "-b",
type=str,
default="gloo",
choices=["gloo", "nccl"],
help="Distributed backend"
)
args = parser.parse_args()
print("╔" + "═" * 58 + "╗")
print("║" + " COLLECTIVE OPERATIONS CHEATSHEET".center(58) + "║")
print("╚" + "═" * 58 + "╝")
print(f"World size: {args.world_size}")
print(f"Operation: {args.operation}")
mp.spawn(
collective_worker,
args=(args.world_size, args.operation, args.backend),
nprocs=args.world_size,
join=True
)
if __name__ == "__main__":
main()
distributed_mean.py
Computing global mean with all_reduce - exactly what gradient sync does!
This script demonstrates the fundamental pattern of distributed gradient synchronization: using all_reduce to compute global averages.
What It Does
- Each process has local data (simulating local gradients)
- Uses all_reduce to sum all values
- Divides by world size to get the mean
- Every process now has the same averaged value
The Pattern
Local gradients: [1.0] [2.0] [3.0] [4.0]
│ │ │ │
└──────┴──────┴──────┘
│
all_reduce (SUM)
│
▼
Global sum: [10.0] [10.0] [10.0] [10.0]
│
÷ world_size
│
▼
Global mean: [2.5] [2.5] [2.5] [2.5]
Run It
python tutorial/part1-distributed/chapter03-collectives/scripts/distributed_mean.py
Why This Matters
This exact pattern is used in Distributed Data Parallel (DDP):
# In DDP, after the backward pass (conceptually):
for param in model.parameters():
    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
    param.grad /= world_size
All GPUs end up with identical averaged gradients, so the model stays synchronized.
Source Code
#!/usr/bin/env python3
"""
Distributed Mean Computation
This script shows how to compute the mean of data distributed across
multiple processes. This is EXACTLY what happens in gradient synchronization!
Scenario:
- Each process has local data (gradients in real training)
- We want the global mean across ALL data
Two approaches:
1. all_reduce(SUM) / world_size (simple, always works)
2. Local mean, then weighted average (more efficient for unequal sizes)
Usage:
python distributed_mean.py
python distributed_mean.py --data-size 1000
"""
import argparse
import os
from typing import Tuple
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def compute_mean_simple(rank: int, world_size: int, data_size: int,
                        device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Simple approach: all_reduce(SUM) / world_size
This works when all processes have equal-sized data.
It's what PyTorch DDP does for gradient synchronization.
"""
# Simulate local gradients (different on each rank)
local_data = torch.randn(data_size, device=device) + rank
# Step 1: Sum across all processes
total = local_data.clone()
dist.all_reduce(total, op=dist.ReduceOp.SUM)
# Step 2: Divide by number of processes
mean = total / world_size
return mean, local_data
def compute_mean_weighted(rank: int, world_size: int, local_sizes: list,
                          device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Weighted approach for unequal local sizes.
When processes have different amounts of data (e.g., last batch smaller),
we need to weight by the local size.
"""
# Each process has different amount of data
local_size = local_sizes[rank]
local_data = torch.randn(local_size, device=device) + rank
# Step 1: Compute local sum
local_sum = local_data.sum()
# Step 2: all_reduce the sum
dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
# Step 3: all_reduce the count
local_count = torch.tensor([float(local_size)], device=device)
dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
# Step 4: Global mean = global sum / global count
global_mean = local_sum / local_count
return global_mean, local_data
def mean_worker(rank: int, world_size: int, data_size: int, backend: str) -> None:
"""Worker demonstrating distributed mean computation."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29504"
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cpu")
    if backend == "nccl" and torch.cuda.is_available():
        # NCCL requires CUDA tensors; pick this process's GPU
        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)
# =========================================================================
# Demo 1: Simple Mean (equal data sizes)
# =========================================================================
if rank == 0:
print("=" * 60)
print(" DISTRIBUTED MEAN: Simple Approach (equal sizes)")
print("=" * 60)
torch.manual_seed(42 + rank) # Reproducible but different per rank
dist_mean, local_data = compute_mean_simple(rank, world_size, data_size, device)
local_mean = local_data.mean().item()
dist.barrier()
print(f"[Rank {rank}] Local mean: {local_mean:.4f}, Distributed mean: {dist_mean.mean().item():.4f}")
dist.barrier()
if rank == 0:
print("\n[Verification] Distributed mean should equal average of local means.")
print("This works because all ranks have equal-sized data.\n")
# =========================================================================
# Demo 2: Weighted Mean (unequal data sizes)
# =========================================================================
dist.barrier()
if rank == 0:
print("=" * 60)
print(" DISTRIBUTED MEAN: Weighted Approach (unequal sizes)")
print("=" * 60)
    # Simulate unequal batch sizes (e.g., the last rank gets a smaller batch)
    local_sizes = [data_size] * (world_size - 1) + [data_size // 2]
torch.manual_seed(42 + rank)
weighted_mean, local_data = compute_mean_weighted(rank, world_size, local_sizes, device)
local_mean = local_data.mean().item()
dist.barrier()
print(f"[Rank {rank}] Size: {local_sizes[rank]}, Local mean: {local_mean:.4f}, "
f"Weighted global mean: {weighted_mean.item():.4f}")
dist.barrier()
if rank == 0:
print("\n[Verification] Weighted mean properly accounts for different sizes.")
print("This is important when batch sizes vary!\n")
# =========================================================================
# Demo 3: Gradient Synchronization (the real use case)
# =========================================================================
dist.barrier()
if rank == 0:
print("=" * 60)
print(" PRACTICAL EXAMPLE: Gradient Synchronization")
print("=" * 60)
print("""
In distributed data-parallel training, each GPU computes gradients
on its local batch. To train correctly, we need the AVERAGE gradient
across all batches.
Pseudo-code for DDP:
# Forward pass (local)
loss = model(batch)
# Backward pass (local)
loss.backward() # Computes gradients locally
# Synchronize gradients
for param in model.parameters():
dist.all_reduce(param.grad, op=ReduceOp.SUM)
param.grad /= world_size
# Optimizer step (local, but now with averaged gradients)
optimizer.step()
""")
# Simulate gradient computation
torch.manual_seed(123 + rank)
fake_gradient = torch.randn(10, device=device)
if rank == 0:
print("Before synchronization:")
dist.barrier()
print(f" [Rank {rank}] gradient[0]: {fake_gradient[0].item():.4f}")
dist.barrier()
# Synchronize gradients
dist.all_reduce(fake_gradient, op=dist.ReduceOp.SUM)
fake_gradient /= world_size
if rank == 0:
print("\nAfter synchronization (all ranks have same gradient):")
dist.barrier()
print(f" [Rank {rank}] gradient[0]: {fake_gradient[0].item():.4f}")
dist.barrier()
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Distributed Mean Computation")
parser.add_argument(
"--data-size", "-d",
type=int,
default=100,
help="Size of local data per process (default: 100)"
)
parser.add_argument(
"--world-size", "-w",
type=int,
default=4,
help="Number of processes (default: 4)"
)
parser.add_argument(
"--backend", "-b",
type=str,
default="gloo",
choices=["gloo", "nccl"],
help="Distributed backend"
)
args = parser.parse_args()
print("╔" + "═" * 58 + "╗")
print("║" + " DISTRIBUTED MEAN COMPUTATION".center(58) + "║")
print("╚" + "═" * 58 + "╝")
mp.spawn(
mean_worker,
args=(args.world_size, args.data_size, args.backend),
nprocs=args.world_size,
join=True
)
if __name__ == "__main__":
main()
Chapter 4: NCCL Algorithms and GPU Topology
“Understanding your hardware is half the battle. The other half is making NCCL do what you want.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain how Ring and Tree algorithms work for all_reduce
- Inspect GPU topology and NVLink connections
- Understand why communication patterns matter for performance
- Choose the right NCCL settings for your hardware
Prerequisites
- Completed Chapters 1-3
- Understanding of all_reduce and collective operations
- Access to a machine with NVIDIA GPU (for hands-on topology inspection)
Concept Overview
Why Does the Algorithm Matter?
When you call dist.all_reduce(tensor), NCCL doesn’t just magically synchronize data. It runs a carefully designed algorithm that determines:
- Who sends to whom - The communication pattern
- What data flows - Partial aggregates vs full tensors
- How much bandwidth is used - Network saturation
- How long it takes - Latency characteristics
Different algorithms excel in different scenarios:
- Ring: Great bandwidth utilization, scales with data size
- Tree: Lower latency for small messages, scales better with node count
- Double Binary Tree: Best of both worlds for large clusters
The Ring Algorithm
Ring is the most intuitive collective algorithm. Picture GPUs arranged in a circle:
┌──────► GPU 1 ──────┐
│ │
│ ▼
GPU 0 GPU 2
▲ │
│ │
└────── GPU 3 ◄──────┘
How Ring All-Reduce Works:
Phase 1: Scatter-Reduce (each GPU accumulates partial sums)
Step 1: GPU0 sends chunk0 to GPU1, GPU1 sends chunk1 to GPU2, ...
Step 2: Recipients add their local chunk to received chunk, send result
... (N-1 steps total)
Phase 2: All-Gather (distribute the fully reduced chunks)
Step 1: GPU0 sends its complete chunk0 to GPU1, ...
... (N-1 steps total)
Ring Complexity:
- Total steps: 2(N-1) where N is number of GPUs
- Data per step: D/N where D is total data size
- Total data moved: 2D(N-1)/N ≈ 2D for large N
Ring’s Superpower: Bandwidth utilization is nearly 100%! Each GPU is always sending and receiving.
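To make the two phases concrete, here is a small pure-Python simulation of the chunk schedule described above. Plain lists stand in for GPU buffers; this is an educational sketch of the algorithm, not how NCCL actually implements it.

import copy

def ring_all_reduce(vectors):
    """Simulate ring all-reduce: every 'GPU' ends with the element-wise sum."""
    n = len(vectors)
    chunk_len = len(vectors[0]) // n  # assume length divisible by n
    # chunks[g][c] is GPU g's current copy of chunk c
    chunks = [[v[c * chunk_len:(c + 1) * chunk_len] for c in range(n)] for v in vectors]

    # Phase 1: scatter-reduce (N-1 steps). At step s, GPU i sends chunk (i - s) % n
    # to its right neighbor, which adds it to its own copy of that chunk.
    for s in range(n - 1):
        snap = copy.deepcopy(chunks)  # all sends happen "simultaneously"
        for i in range(n):
            c, dst = (i - s) % n, (i + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(snap[dst][c], snap[i][c])]
    # Now GPU i holds the fully reduced chunk (i + 1) % n.

    # Phase 2: all-gather (N-1 steps). Completed chunks circulate around the ring.
    for s in range(n - 1):
        snap = copy.deepcopy(chunks)
        for i in range(n):
            c, dst = (i + 1 - s) % n, (i + 1) % n
            chunks[dst][c] = snap[i][c]

    return [[x for c in range(n) for x in g_chunks[c]] for g_chunks in chunks]

# 4 simulated GPUs, 4 elements each: GPU g holds [g+1, g+1, g+1, g+1]
result = ring_all_reduce([[float(g + 1)] * 4 for g in range(4)])
print(result[0])  # every GPU ends with [10.0, 10.0, 10.0, 10.0]

Each of the 2(N-1) steps moves one chunk of size D/N per GPU, which is exactly where the 2D(N-1)/N total above comes from.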
The Tree Algorithm
For large clusters, Ring’s latency (O(N) steps) becomes problematic. Tree algorithms use a hierarchical structure:
GPU 0 (root)
/ \
GPU 1 GPU 2
/ \ / \
GPU 3 GPU 4 GPU 5 GPU 6
How Tree Reduce Works:
Step 1: Leaves (3,4,5,6) send to parents (1,2)
Step 2: Parents combine, send to root (0)
Step 3: Root has final result
Tree Complexity:
- Total steps: 2 * log2(N) (reduce up + broadcast down)
- Much better latency for small messages
Tree’s Tradeoff: Lower bandwidth utilization (only half the GPUs active at any time).
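A quick back-of-the-envelope comparison of step counts shows why the latency gap matters at scale. The per-step latency below is an arbitrary illustrative value, not a measurement:

import math

PER_STEP_LATENCY_US = 5.0  # assumed per-step launch/network latency, illustration only

for n in (8, 64, 512, 4096):
    ring_steps = 2 * (n - 1)                   # scatter-reduce + all-gather
    tree_steps = 2 * math.ceil(math.log2(n))   # reduce up + broadcast down
    print(f"N={n:5d}  ring: {ring_steps:5d} steps (~{ring_steps * PER_STEP_LATENCY_US:7.0f} us)   "
          f"tree: {tree_steps:3d} steps (~{tree_steps * PER_STEP_LATENCY_US:5.0f} us)")

For 4096 GPUs that is 8190 ring steps versus 24 tree steps: on a tiny message the fixed per-step cost dominates, and Ring loses badly.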
Double Binary Tree (for 24,000+ GPUs)
At scale (think training GPT-4), even tree algorithms hit bottlenecks. Double Binary Tree uses two complementary trees to keep all links busy:
Tree A: Tree B:
0 7
/ \ / \
1 2 6 5
/ \ / \ / \ / \
3 4 5 6 0 1 2 3
Different GPUs are roots/leaves in each tree, balancing the load.
NVLink: The Speed Demon
NVLink is NVIDIA’s high-bandwidth interconnect for GPU-to-GPU communication:
| Generation | Bandwidth per link (each direction) | Links per GPU |
|---|---|---|
| NVLink 1.0 | 20 GB/s | 4 |
| NVLink 2.0 | 25 GB/s | 6 |
| NVLink 3.0 | 25 GB/s | 12 |
| NVLink 4.0 | 25 GB/s | 18 |
For comparison, PCIe 4.0 x16 tops out at roughly 32 GB/s per direction.
On an NVLink 4.0 system such as the DGX H100, each GPU has 900 GB/s of aggregate NVLink bandwidth (18 links × 50 GB/s bidirectional). This is why DGX systems are so fast for training.
GPU Topology: The Key to Understanding Performance
Not all GPU pairs are connected equally! Use nvidia-smi topo -m to see your topology:
GPU0 GPU1 GPU2 GPU3 CPU Affinity
GPU0 X NV4 NV4 NV4 0-31
GPU1 NV4 X NV4 NV4 0-31
GPU2 NV4 NV4 X NV4 0-31
GPU3 NV4 NV4 NV4 X 0-31
Legend:
- X: Self
- NV#: NVLink with # links
- SYS: Cross NUMA node (slowest)
- NODE: Same NUMA node, no NVLink
- PHB: Same PCIe host bridge
- PXB: Different PCIe bridges
- PIX: Same PCIe switch
Rule of thumb: More NVLinks = faster. SYS = slow, avoid if possible.
Code Walkthrough
Script 1: topology_inspector.py
This script inspects your GPU topology and reports:
- How many GPUs you have
- NVLink connections between GPUs
- PCIe topology
- NUMA affinity
It also suggests optimal process placement.
Script 2: benchmark_algorithms.py
This script benchmarks different NCCL algorithms on your hardware:
- Measures all_reduce throughput
- Compares Ring vs Tree
- Shows how performance scales with message size
NCCL Environment Variables
You can tune NCCL behavior with environment variables:
| Variable | Description | Default |
|---|---|---|
| NCCL_ALGO | Algorithm: Ring, Tree, CollNetChain | Auto |
| NCCL_PROTO | Protocol: Simple, LL, LL128 | Auto |
| NCCL_NTHREADS | Threads per block | Auto |
| NCCL_DEBUG | Debugging output (WARN, INFO, TRACE) | WARN |
| NCCL_DEBUG_SUBSYS | Subsystems to debug | All |
Example: Force ring algorithm and show debug info:
NCCL_ALGO=Ring NCCL_DEBUG=INFO python train.py
Try It Yourself
Exercise 1: Inspect Your Topology
Run topology_inspector.py on a GPU machine and answer:
- How many NVLinks connect GPU 0 to GPU 1?
- Are any GPU pairs connected only via PCIe?
- What’s the CPU affinity for each GPU?
Exercise 2: Benchmark All-Reduce
Run benchmark_algorithms.py with different message sizes:
- 1 KB
- 1 MB
- 100 MB
When does Ring outperform Tree? When does Tree win?
Exercise 3: Measure the NVLink Advantage
If you have GPUs connected via NVLink AND PCIe:
- Run all_reduce between NVLink-connected GPUs
- Run all_reduce between PCIe-connected GPUs
- Calculate the speedup
Key Takeaways
- Ring excels at large messages - Nearly 100% bandwidth utilization
- Tree excels at low latency - O(log N) steps vs O(N)
- NVLink is crucial - 10x+ faster than PCIe
- Topology determines performance - Know your hardware!
- NCCL auto-selects - But you can override for specific cases
Performance Intuition
For a 1 GB all_reduce on 8 GPUs:
| Connection | Ring Bandwidth | Approximate Time |
|---|---|---|
| NVLink 4.0 (900 GB/s) | ~450 GB/s effective | ~2.2 ms |
| PCIe 4.0 x16 (32 GB/s) | ~16 GB/s effective | ~62 ms |
That’s a 28x difference just from interconnect!
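These numbers follow from the ring cost model used throughout this chapter: each GPU moves roughly 2 × D × (N-1)/N bytes over its links. A minimal sketch (real runs add protocol and launch overhead, so treat this as an order-of-magnitude check against the table):

def all_reduce_time_ms(size_gb: float, num_gpus: int, bandwidth_gb_s: float) -> float:
    """Ring all_reduce model: each GPU sends/receives about 2*D*(N-1)/N bytes."""
    data_moved_gb = 2 * size_gb * (num_gpus - 1) / num_gpus
    return data_moved_gb / bandwidth_gb_s * 1000  # milliseconds

for name, bw in [("NVLink 4.0 (~900 GB/s)", 900.0), ("PCIe 4.0 x16 (~32 GB/s)", 32.0)]:
    print(f"{name:24s} 1 GB all_reduce on 8 GPUs ≈ {all_reduce_time_ms(1.0, 8, bw):6.1f} ms")
# NVLink: ~1.9 ms, PCIe: ~54.7 ms -- the same order of magnitude as the table above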
What’s Next?
In Part II, we’ll use these primitives to implement actual parallelism strategies:
- Chapter 5: Data Parallelism (DDP, FSDP, ZeRO)
- Chapter 6: Tensor Parallelism (splitting layers)
- Chapter 7: Pipeline Parallelism (splitting models)
Further Reading
topology_inspector.py
Inspect your GPU topology and understand communication paths
This script examines your GPU setup and reports on NVLink connections, PCIe topology, and NUMA affinity.
What It Does
- Detects available GPUs and their properties
- Identifies NVLink connections between GPU pairs
- Maps PCIe topology (bridges, switches)
- Shows CPU/NUMA affinity for each GPU
- Suggests optimal process placement
Run It
python tutorial/part1-distributed/chapter04-nccl-topology/scripts/topology_inspector.py
Example Output (8-GPU DGX)
=== GPU Topology Inspector ===
Found 8 GPUs:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
...
NVLink Connections:
GPU 0 <--NV12--> GPU 1
GPU 0 <--NV12--> GPU 2
GPU 0 <--NV12--> GPU 3
...
PCIe Topology:
GPU 0-3: Same PCIe switch (fast)
GPU 4-7: Same PCIe switch (fast)
GPU 0-4: Cross-switch (slower)
NUMA Affinity:
GPU 0-3: NUMA node 0 (CPUs 0-31)
GPU 4-7: NUMA node 1 (CPUs 32-63)
Recommendations:
- For 4-GPU jobs, use GPUs 0-3 or 4-7 (same switch)
- For 8-GPU jobs, expect ~10% overhead from cross-switch communication
Why This Matters
Understanding topology helps you:
- Place processes optimally - Keep communicating processes on fast interconnects
- Debug performance issues - Unexpectedly slow? Check if you’re using PCIe instead of NVLink
- Choose parallelism strategy - Tensor parallel works best with NVLink
Source Code
#!/usr/bin/env python3
"""
GPU Topology Inspector
This script inspects your GPU topology and provides insights about:
- Number of GPUs and their properties
- NVLink connections between GPUs
- PCIe topology
- NUMA affinity
- Recommended process placement
Usage:
python topology_inspector.py
Note: This requires NVIDIA GPUs. On systems without GPUs, it will
display a simulated topology for educational purposes.
"""
import os
import subprocess
import sys
from typing import Dict, List, Optional, Tuple
def check_nvidia_smi() -> bool:
"""Check if nvidia-smi is available."""
try:
subprocess.run(
["nvidia-smi", "--version"],
capture_output=True,
check=True
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def get_gpu_count() -> int:
"""Get the number of GPUs."""
try:
result = subprocess.run(
["nvidia-smi", "-L"],
capture_output=True,
text=True,
check=True
)
return len(result.stdout.strip().split('\n'))
    except Exception:
return 0
def get_gpu_info() -> List[Dict]:
"""Get detailed GPU information."""
try:
result = subprocess.run(
["nvidia-smi", "--query-gpu=index,name,memory.total,pci.bus_id",
"--format=csv,noheader,nounits"],
capture_output=True,
text=True,
check=True
)
gpus = []
for line in result.stdout.strip().split('\n'):
parts = [p.strip() for p in line.split(',')]
if len(parts) >= 4:
gpus.append({
'index': int(parts[0]),
'name': parts[1],
'memory_mb': int(parts[2]),
'pci_bus': parts[3]
})
return gpus
    except Exception:
return []
def get_topology_matrix() -> Optional[str]:
"""Get the GPU topology matrix from nvidia-smi."""
try:
result = subprocess.run(
["nvidia-smi", "topo", "-m"],
capture_output=True,
text=True,
check=True
)
return result.stdout
    except Exception:
return None
def get_nvlink_status() -> Optional[str]:
"""Get NVLink status for GPU 0."""
try:
result = subprocess.run(
["nvidia-smi", "nvlink", "--status", "-i", "0"],
capture_output=True,
text=True,
check=True
)
return result.stdout
    except Exception:
return None
def parse_topology_matrix(matrix_str: str) -> Dict[Tuple[int, int], str]:
"""Parse topology matrix into a dict of GPU pairs to connection types."""
connections = {}
lines = matrix_str.strip().split('\n')
# Find the header line with GPU columns
header_idx = None
for i, line in enumerate(lines):
if 'GPU0' in line or 'GPU 0' in line:
header_idx = i
break
if header_idx is None:
return connections
# Parse the matrix
for line in lines[header_idx + 1:]:
if not line.strip() or 'Legend' in line:
break
parts = line.split()
if not parts or not parts[0].startswith('GPU'):
continue
try:
gpu_from = int(parts[0].replace('GPU', ''))
for col_idx, conn in enumerate(parts[1:]):
if conn in ['X', 'NV1', 'NV2', 'NV3', 'NV4', 'NV5', 'NV6',
'NV7', 'NV8', 'NV12', 'NV18', 'SYS', 'NODE',
'PHB', 'PXB', 'PIX']:
connections[(gpu_from, col_idx)] = conn
except (ValueError, IndexError):
continue
return connections
def print_header(title: str) -> None:
"""Print a formatted header."""
print("\n" + "═" * 60)
print(f" {title}")
print("═" * 60)
def print_simulated_topology() -> None:
"""Print a simulated topology for educational purposes."""
print_header("SIMULATED GPU TOPOLOGY (No GPUs Detected)")
print("""
This is a simulated DGX-A100 topology for educational purposes.
In a real DGX-A100 with 8 A100 GPUs:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12
GPU1 NV12 X NV12 NV12 NV12 NV12 NV12 NV12
GPU2 NV12 NV12 X NV12 NV12 NV12 NV12 NV12
GPU3 NV12 NV12 NV12 X NV12 NV12 NV12 NV12
GPU4 NV12 NV12 NV12 NV12 X NV12 NV12 NV12
GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12
GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12
GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X
Legend:
X = Self
NV# = Connected via NVLink (# = number of links)
SYS = Connected via PCIe across NUMA nodes (slowest)
NODE = Same NUMA node, connected via PCIe
PHB = Same PCIe host bridge
PXB = Different PCIe bridges, same PCIe switch
PIX = Same PCIe switch
Performance implications:
    NV12 (12 NVLinks): ~600 GB/s bidirectional (~300 GB/s per direction)
    SYS: ~12 GB/s effective (PCIe 4.0 x16 through the CPU)
This shows why NVLink matters: roughly 25x higher bandwidth per direction!
""")
def analyze_topology(connections: Dict[Tuple[int, int], str], num_gpus: int) -> None:
"""Analyze and report on the topology."""
print_header("TOPOLOGY ANALYSIS")
# Count NVLink connections
nvlink_pairs = []
pcie_pairs = []
for (g1, g2), conn in connections.items():
if g1 < g2: # Avoid double counting
if conn.startswith('NV'):
nvlink_pairs.append((g1, g2, conn))
elif conn in ['SYS', 'NODE', 'PHB', 'PXB', 'PIX']:
pcie_pairs.append((g1, g2, conn))
print(f"\nNVLink Connections ({len(nvlink_pairs)} pairs):")
if nvlink_pairs:
for g1, g2, conn in sorted(nvlink_pairs):
num_links = conn.replace('NV', '')
print(f" GPU{g1} <-> GPU{g2}: {conn} ({num_links} links)")
else:
print(" None detected")
print(f"\nPCIe-only Connections ({len(pcie_pairs)} pairs):")
if pcie_pairs:
for g1, g2, conn in sorted(pcie_pairs):
print(f" GPU{g1} <-> GPU{g2}: {conn}")
else:
print(" None (all pairs have NVLink)")
# Recommendations
print_header("RECOMMENDATIONS")
if len(nvlink_pairs) == (num_gpus * (num_gpus - 1)) // 2:
print("✓ All GPU pairs connected via NVLink")
print(" → Ideal for all-reduce operations")
print(" → Can use any GPU grouping")
elif nvlink_pairs:
print("⚠ Mixed NVLink/PCIe topology")
print(" → Group NVLink-connected GPUs together when possible")
print(" → Use process groups to exploit fast connections")
else:
print("⚠ No NVLink detected")
print(" → Performance will be limited by PCIe bandwidth")
print(" → Consider using smaller batch sizes to hide communication")
def main():
print("╔" + "═" * 58 + "╗")
print("║" + " GPU TOPOLOGY INSPECTOR".center(58) + "║")
print("╚" + "═" * 58 + "╝")
# Check if we have nvidia-smi
if not check_nvidia_smi():
print("\n[INFO] nvidia-smi not found. Showing simulated topology.")
print_simulated_topology()
return
# Get GPU count
num_gpus = get_gpu_count()
if num_gpus == 0:
print("\n[INFO] No NVIDIA GPUs detected. Showing simulated topology.")
print_simulated_topology()
return
# Get GPU information
print_header("GPU INFORMATION")
gpus = get_gpu_info()
for gpu in gpus:
print(f"\nGPU {gpu['index']}: {gpu['name']}")
print(f" Memory: {gpu['memory_mb']} MB")
print(f" PCI Bus: {gpu['pci_bus']}")
# Get topology matrix
print_header("TOPOLOGY MATRIX")
topo_matrix = get_topology_matrix()
if topo_matrix:
print(topo_matrix)
connections = parse_topology_matrix(topo_matrix)
analyze_topology(connections, num_gpus)
else:
print("Could not retrieve topology matrix")
# Get NVLink status
nvlink_status = get_nvlink_status()
if nvlink_status and "NVLINK" not in nvlink_status:
print_header("NVLINK STATUS (GPU 0)")
print(nvlink_status)
# PyTorch CUDA information
print_header("PYTORCH CUDA INFO")
try:
import torch
if torch.cuda.is_available():
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print(f"cuDNN version: {torch.backends.cudnn.version()}")
print(f"GPU count (torch): {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
print(f"\nGPU {i}: {props.name}")
print(f" Compute capability: {props.major}.{props.minor}")
print(f" Total memory: {props.total_memory / 1e9:.1f} GB")
print(f" Multi-processor count: {props.multi_processor_count}")
else:
print("PyTorch CUDA not available")
except ImportError:
print("PyTorch not installed")
# Summary
print_header("QUICK REFERENCE")
print("""
NCCL Environment Variables for Tuning:
NCCL_DEBUG=INFO Show what NCCL is doing
NCCL_ALGO=Ring Force ring algorithm
NCCL_ALGO=Tree Force tree algorithm
NCCL_NTHREADS=256 Set thread count
NCCL_P2P_DISABLE=1 Disable peer-to-peer (for debugging)
Common Commands:
nvidia-smi topo -m Show topology matrix
nvidia-smi nvlink --status Show NVLink connections
nvidia-smi -q -d MEMORY Show memory usage details
""")
if __name__ == "__main__":
main()
benchmark_algorithms.py
Benchmark NCCL algorithms on your hardware
This script measures all_reduce performance with different algorithms and message sizes, helping you understand your hardware’s communication characteristics.
What It Does
- Runs all_reduce with various message sizes (1KB to 1GB)
- Tests different NCCL algorithms (Ring, Tree)
- Measures throughput (GB/s) and latency (ms)
- Shows scaling behavior as data size increases
Run It
# Default benchmarks
python tutorial/part1-distributed/chapter04-nccl-topology/scripts/benchmark_algorithms.py
# Force specific algorithm
NCCL_ALGO=Ring python tutorial/part1-distributed/chapter04-nccl-topology/scripts/benchmark_algorithms.py
Example Output
=== All-Reduce Benchmark ===
Message Size | Latency (ms) | Throughput (GB/s) | Algorithm
-------------|--------------|-------------------|----------
1 KB | 0.05 | 0.02 | Tree
16 KB | 0.06 | 0.27 | Tree
256 KB | 0.12 | 2.13 | Ring
4 MB | 0.89 | 4.49 | Ring
64 MB | 12.50 | 5.12 | Ring
1 GB | 198.00 | 5.05 | Ring
Observations:
- Tree wins for small messages (< 256 KB): lower latency
- Ring wins for large messages (> 256 KB): better bandwidth
- Peak throughput: 5.12 GB/s (limited by PCIe)
Interpreting Results
Latency-bound (small messages):
- Tree algorithm is better
- Dominated by startup overhead
- Actual data transfer is fast
Bandwidth-bound (large messages):
- Ring algorithm is better
- Near-100% bandwidth utilization
- All GPUs sending/receiving simultaneously
Source Code
#!/usr/bin/env python3
"""
NCCL Algorithm Benchmark
This script benchmarks all_reduce performance with different:
- Message sizes (small vs large)
- Number of processes
- Backend settings
It demonstrates how performance characteristics change based on
these parameters, showing when Ring vs Tree algorithms excel.
Usage:
python benchmark_algorithms.py
python benchmark_algorithms.py --sizes 1000,1000000,100000000
Note: On CPU-only systems, this uses the gloo backend which
doesn't have Ring/Tree algorithm selection, but still demonstrates
how message size affects throughput.
"""
import argparse
import os
import time
from typing import List, Dict
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def format_bytes(size: int) -> str:
"""Format bytes into human-readable string."""
for unit in ['B', 'KB', 'MB', 'GB']:
if size < 1024:
return f"{size:.1f} {unit}"
size /= 1024
return f"{size:.1f} TB"
def format_bandwidth(bytes_per_sec: float) -> str:
"""Format bandwidth into human-readable string."""
return format_bytes(int(bytes_per_sec)) + "/s"
def benchmark_all_reduce(
tensor: torch.Tensor,
num_iterations: int = 100,
warmup_iterations: int = 10
) -> Dict:
"""
Benchmark all_reduce operation.
Returns dict with timing statistics.
"""
# Warmup
for _ in range(warmup_iterations):
dist.all_reduce(tensor.clone())
# Synchronize before timing
if torch.cuda.is_available():
torch.cuda.synchronize()
dist.barrier()
# Benchmark
times = []
for _ in range(num_iterations):
test_tensor = tensor.clone()
start = time.perf_counter()
dist.all_reduce(test_tensor)
dist.barrier()
if torch.cuda.is_available():
torch.cuda.synchronize()
end = time.perf_counter()
times.append(end - start)
return {
'mean_ms': sum(times) / len(times) * 1000,
'min_ms': min(times) * 1000,
'max_ms': max(times) * 1000,
'median_ms': sorted(times)[len(times)//2] * 1000,
}
def benchmark_worker(
rank: int,
world_size: int,
message_sizes: List[int],
backend: str,
num_iterations: int
) -> None:
"""Worker function for benchmarking."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29505"
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
device = torch.device("cpu")
if backend == "nccl" and torch.cuda.is_available():
local_rank = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
# Run benchmarks
results = []
for size in message_sizes:
# Create tensor of specified size (in bytes, using float32 = 4 bytes)
num_elements = size // 4
tensor = torch.randn(num_elements, device=device)
stats = benchmark_all_reduce(tensor, num_iterations=num_iterations)
# Calculate bandwidth
# all_reduce moves approximately 2 * size * (N-1) / N bytes (ring algorithm)
bytes_moved = 2 * size * (world_size - 1) / world_size
bandwidth = bytes_moved / (stats['mean_ms'] / 1000)
results.append({
'size': size,
'num_elements': num_elements,
'stats': stats,
'bandwidth': bandwidth,
})
dist.barrier()
# Only rank 0 prints results
if rank == 0:
print("\n" + "=" * 70)
print(" ALL_REDUCE BENCHMARK RESULTS")
print("=" * 70)
print(f"Backend: {backend}")
print(f"World size: {world_size}")
print(f"Device: {device}")
print(f"Iterations per test: {num_iterations}")
print("=" * 70)
print(f"\n{'Size':<12} {'Elements':<12} {'Mean (ms)':<12} {'Min (ms)':<12} {'Bandwidth':<15}")
print("-" * 70)
for r in results:
print(f"{format_bytes(r['size']):<12} "
f"{r['num_elements']:<12} "
f"{r['stats']['mean_ms']:<12.3f} "
f"{r['stats']['min_ms']:<12.3f} "
f"{format_bandwidth(r['bandwidth']):<15}")
print("\n" + "=" * 70)
print(" ANALYSIS")
print("=" * 70)
if len(results) >= 2:
# Compare small vs large messages
small = results[0]
large = results[-1]
small_latency = small['stats']['mean_ms']
large_latency = large['stats']['mean_ms']
size_ratio = large['size'] / small['size']
latency_ratio = large_latency / small_latency
print(f"\nLatency scaling:")
print(f" Message size increased {size_ratio:.0f}x")
print(f" Latency increased {latency_ratio:.1f}x")
if latency_ratio < size_ratio * 0.5:
print(f" → Latency grows sub-linearly with size (good bandwidth utilization)")
elif latency_ratio < size_ratio:
print(f" → Latency grows roughly linearly with size")
else:
print(f" → Latency grows super-linearly (possible bottleneck)")
print(f"\nBandwidth comparison:")
print(f" Small messages ({format_bytes(small['size'])}): {format_bandwidth(small['bandwidth'])}")
print(f" Large messages ({format_bytes(large['size'])}): {format_bandwidth(large['bandwidth'])}")
if large['bandwidth'] > small['bandwidth'] * 1.5:
print(f" → Large messages achieve much better bandwidth utilization")
print(f" → This is typical: large messages amortize fixed overhead")
print("""
Understanding the results:
1. SMALL MESSAGES (< 1 MB):
- Dominated by latency (startup cost)
- Tree algorithm excels here (O(log N) steps)
- Low bandwidth utilization
2. LARGE MESSAGES (> 10 MB):
- Dominated by bandwidth
- Ring algorithm excels here (~100% utilization)
- Latency becomes less important
3. NCCL AUTO-SELECTION:
- NCCL automatically chooses Ring or Tree based on message size
- Small: Tree (low latency)
- Large: Ring (high bandwidth)
- Crossover point is typically around 1-10 MB
4. THEORETICAL PEAK:
- NVLink 4.0: ~450 GB/s effective for all_reduce
- PCIe 4.0: ~16 GB/s effective for all_reduce
- If your numbers are much lower, check topology!
""")
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="NCCL Algorithm Benchmark")
parser.add_argument(
"--sizes",
type=str,
default="1000,10000,100000,1000000,10000000,100000000",
help="Comma-separated message sizes in bytes (default: 1KB to 100MB)"
)
parser.add_argument(
"--world-size", "-w",
type=int,
default=4,
help="Number of processes (default: 4)"
)
parser.add_argument(
"--backend", "-b",
type=str,
default="gloo",
choices=["gloo", "nccl"],
help="Distributed backend (default: gloo for CPU compatibility)"
)
parser.add_argument(
"--iterations", "-i",
type=int,
default=50,
help="Number of iterations per test (default: 50)"
)
args = parser.parse_args()
message_sizes = [int(s) for s in args.sizes.split(',')]
print("╔" + "═" * 58 + "╗")
print("║" + " NCCL ALGORITHM BENCHMARK".center(58) + "║")
print("╚" + "═" * 58 + "╝")
print(f"\nMessage sizes: {[format_bytes(s) for s in message_sizes]}")
print(f"World size: {args.world_size}")
print(f"Backend: {args.backend}")
print(f"Iterations: {args.iterations}")
if args.backend == "nccl" and not torch.cuda.is_available():
print("\n[WARN] NCCL backend requires CUDA. Falling back to gloo.")
args.backend = "gloo"
mp.spawn(
benchmark_worker,
args=(args.world_size, message_sizes, args.backend, args.iterations),
nprocs=args.world_size,
join=True
)
if __name__ == "__main__":
main()
Chapter 5: Data Parallelism Deep Dive
“Data parallelism is the gateway drug of distributed training. It’s deceptively simple, yet optimizing it is an art.”
Learning Objectives
By the end of this chapter, you will be able to:
- Implement basic data-parallel training manually
- Explain how PyTorch DDP works under the hood
- Understand ZeRO stages and their memory tradeoffs
- Choose between DDP, FSDP, and DeepSpeed for your use case
Prerequisites
- Completed Part I (Distributed Computing Foundations)
- Basic understanding of neural network training (forward, backward, optimizer step)
- Familiarity with PyTorch’s autograd
Concept Overview
What is Data Parallelism?
Data parallelism is the simplest form of distributed training:
- Replicate the entire model on each GPU
- Split the training batch across GPUs
- Compute forward and backward passes locally
- Synchronize gradients across all GPUs
- Update each model copy identically
Global Batch (size 256)
┌───────────────────────────────┐
│ B0 │ B1 │ B2 │ B3 │ B4 │ B5 │ B6 │ B7 │
└─┬───┴─┬───┴─┬───┴─┬───┴─┬───┴─┬───┴─┬───┴─┬─┘
│ │ │ │ │ │ │ │
▼ ▼ ▼ ▼ ▼ ▼ ▼ ▼
GPU 0 GPU 1 GPU 2 GPU 3 GPU 4 GPU 5 GPU 6 GPU 7
┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐
│ M │ │ M │ │ M │ │ M │ │ M │ │ M │ │ M │ │ M │
│ O │ │ O │ │ O │ │ O │ │ O │ │ O │ │ O │ │ O │
│ D │ │ D │ │ D │ │ D │ │ D │ │ D │ │ D │ │ D │
│ E │ │ E │ │ E │ │ E │ │ E │ │ E │ │ E │ │ E │
│ L │ │ L │ │ L │ │ L │ │ L │ │ L │ │ L │ │ L │
└─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘
│ │ │ │ │ │ │ │
└──────┴──────┴──────┴───┬──┴──────┴──────┴──────┘
│
all_reduce
(gradients)
The Core Insight: Gradient Averaging
Why does this work mathematically?
For a batch B split into B₀ and B₁:
∇L(B) = ∇L(B₀ ∪ B₁)
= (1/|B|) Σᵢ ∇L(xᵢ)
= (1/|B|) [Σᵢ∈B₀ ∇L(xᵢ) + Σᵢ∈B₁ ∇L(xᵢ)]
= (|B₀|/|B|) · ∇L(B₀) + (|B₁|/|B|) · ∇L(B₁)
With equal splits: ∇L(B) = (∇L(B₀) + ∇L(B₁)) / 2
This is exactly what all_reduce(gradients, SUM) / world_size computes!
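You can check the equivalence numerically on a single process, with no process group involved; the tiny model and random data below are arbitrary stand-ins:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
X, y = torch.randn(8, 4), torch.randn(8, 1)

def grad_of(batch_x, batch_y):
    """Return the weight gradient of the mean-squared error on one batch."""
    model.zero_grad()
    loss = ((model(batch_x) - batch_y) ** 2).mean()
    loss.backward()
    return model.weight.grad.clone()

full_batch_grad = grad_of(X, y)                                         # ∇L(B)
averaged_shards = (grad_of(X[:4], y[:4]) + grad_of(X[4:], y[4:])) / 2   # (∇L(B₀) + ∇L(B₁)) / 2
print(torch.allclose(full_batch_grad, averaged_shards, atol=1e-6))      # True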
PyTorch DistributedDataParallel (DDP)
DDP is PyTorch’s production-grade data parallelism implementation. Key features:
- Gradient Bucketing: Groups small gradients into buckets for efficient all_reduce
- Overlap with Backward: Starts all_reduce before backward is complete
- Broadcast Parameters: Ensures all replicas start with identical weights
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Initialize process group
dist.init_process_group("nccl")
# Create model and wrap with DDP
model = MyModel().to(device)
model = DDP(model, device_ids=[local_rank])
# Training loop (exactly like single-GPU!)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)   # here the model is assumed to return the loss directly
    loss.backward()       # Gradients synchronized automatically!
    optimizer.step()
The Memory Problem
Data parallelism replicates the entire model. For an LLM like LLaMA-70B:
| Component | Size per GPU |
|---|---|
| Parameters (FP16) | 140 GB |
| Gradients (FP16) | 140 GB |
| Optimizer states (Adam, FP32) | 560 GB |
| Total | 840 GB |
No single GPU has 840 GB! This is where ZeRO comes in.
ZeRO: Zero Redundancy Optimizer
ZeRO is DeepSpeed’s innovation for reducing memory redundancy in data parallelism.
ZeRO-1: Shard Optimizer States
Without ZeRO: Each GPU has full optimizer states (O₀, O₁, O₂, O₃)
With ZeRO-1: GPU 0 has O₀, GPU 1 has O₁, GPU 2 has O₂, GPU 3 has O₃
Optimizer step: each GPU updates only its parameter shard, then the updated parameters are all_gathered
Memory saved: (N-1)/N of optimizer states
ZeRO-2: Shard Optimizer States + Gradients
Without ZeRO: Each GPU has full gradients (G₀, G₁, G₂, G₃)
With ZeRO-2: Use reduce_scatter instead of all_reduce
Each GPU only keeps 1/N of gradients
Memory saved: (N-1)/N of gradients too
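The data flow is easiest to see in code. Below is a rough educational sketch of a ZeRO-2-style update on a flattened parameter buffer (not DeepSpeed's actual implementation; plain SGD with momentum stands in for Adam, and it assumes an already-initialized process group whose backend supports reduce_scatter, e.g. NCCL):

import torch
import torch.distributed as dist

def zero2_style_step(flat_param: torch.Tensor, flat_grad: torch.Tensor,
                     opt_state: dict, lr: float, rank: int, world_size: int) -> None:
    """One sharded update: reduce_scatter grads, update local shard, all_gather params."""
    shard_size = flat_param.numel() // world_size  # assume divisible for simplicity

    # 1. Reduce-scatter: each rank ends up with only the summed gradient shard it owns.
    grad_shard = torch.empty(shard_size, dtype=flat_grad.dtype, device=flat_grad.device)
    dist.reduce_scatter(grad_shard, list(flat_grad.chunk(world_size)), op=dist.ReduceOp.SUM)
    grad_shard /= world_size

    # 2. Update only this rank's parameter shard, using only this rank's optimizer state.
    param_shard = flat_param[rank * shard_size:(rank + 1) * shard_size]
    momentum = opt_state.setdefault("momentum", torch.zeros_like(param_shard))
    momentum.mul_(0.9).add_(grad_shard)
    param_shard -= lr * momentum

    # 3. All-gather the updated shards so every rank has full parameters for the next forward.
    gathered = list(torch.empty_like(flat_param).chunk(world_size))
    dist.all_gather(gathered, param_shard)
    flat_param.copy_(torch.cat(gathered))

Compared to DDP's single all_reduce, the reduce_scatter moves about D instead of 2D, and the other half of the traffic is spent all-gathering the updated parameters, so total communication stays roughly the same while gradient and optimizer memory shrink by a factor of N.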
ZeRO-3: Shard Everything (Parameters too)
Without ZeRO: Each GPU has full model (P₀, P₁, P₂, P₃)
With ZeRO-3: GPU 0 has P₀, GPU 1 has P₁, etc.
Before forward/backward: all_gather needed parameters
Memory saved: (N-1)/N of parameters
Memory Comparison
For a 70B parameter model with 8 GPUs:
| Strategy | Memory per GPU |
|---|---|
| DDP (replicated) | 840 GB |
| ZeRO-1 | 350 GB |
| ZeRO-2 | ~228 GB |
| ZeRO-3 | 105 GB |
ZeRO-3 achieves 8x memory reduction!
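The table is just arithmetic on per-parameter byte counts, under the same assumptions as above (2 bytes each for FP16 parameters and gradients, 8 bytes of optimizer state per parameter, sharded terms divided by N = 8):

def per_gpu_memory_gb(params_billion: float, n_gpus: int, strategy: str) -> float:
    """Rough per-GPU memory: FP16 params + FP16 grads + 8 B/param optimizer state."""
    params = params_billion * 2   # GB of FP16 parameters
    grads = params_billion * 2    # GB of FP16 gradients
    optim = params_billion * 8    # GB of Adam moments (FP32), ignoring the FP32 master copy
    return {
        "ddp":   params + grads + optim,
        "zero1": params + grads + optim / n_gpus,
        "zero2": params + (grads + optim) / n_gpus,
        "zero3": (params + grads + optim) / n_gpus,
    }[strategy]

for s in ("ddp", "zero1", "zero2", "zero3"):
    print(f"{s:6s}: {per_gpu_memory_gb(70, 8, s):7.1f} GB")
# ddp: 840.0, zero1: 350.0, zero2: 227.5, zero3: 105.0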
FSDP: PyTorch’s ZeRO Implementation
Fully Sharded Data Parallel (FSDP) is PyTorch’s native implementation of ZeRO-3:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3
# sharding_strategy=ShardingStrategy.SHARD_GRAD_OP, # ZeRO-2
# sharding_strategy=ShardingStrategy.NO_SHARD, # DDP-like
)
Communication Volume Comparison
| Strategy | Forward | Backward | Optimizer Step |
|---|---|---|---|
| DDP | 0 | 2D (all_reduce gradients) | 0 |
| ZeRO-1 | 0 | D (reduce_scatter gradients) | D (all_gather params) |
| ZeRO-2 | 0 | D (reduce_scatter gradients) | D (all_gather params) |
| ZeRO-3 | D (all_gather params) | 2D (all_gather params + reduce_scatter gradients) | 0 |
Where D = model size; volumes are per GPU.
ZeRO-1 and ZeRO-2 move roughly the same total volume as DDP (2D); ZeRO-3 moves about 1.5x more (3D), but uses up to 8x less memory!
When to Use What?
| Scenario | Recommendation |
|---|---|
| Model fits in GPU memory | DDP (fastest) |
| Model + gradients fit | ZeRO-2 / FSDP SHARD_GRAD_OP |
| Model doesn’t fit | ZeRO-3 / FSDP FULL_SHARD |
| Very large models (100B+) | ZeRO-3 + tensor parallelism |
Code Walkthrough
Script 1: simple_ddp.py
A minimal DDP implementation to understand the basics:
- Manual gradient synchronization with all_reduce
- Comparison with automatic DDP wrapper
- Measuring communication overhead
Script 2: gradient_sync_visualizer.py
Visualize how gradient synchronization works:
- Shows per-parameter gradients before/after sync
- Demonstrates gradient bucketing concept
- Compares sync strategies
Try It Yourself
Exercise 1: Manual DDP
Implement data-parallel training without using DDP wrapper:
- Broadcast initial weights from rank 0
- After backward(), manually all_reduce all gradients
- Verify your implementation matches DDP
Exercise 2: Gradient Bucketing
Modify gradient_sync_visualizer.py to bucket gradients:
- Group gradients into fixed-size buckets
- all_reduce each bucket as a single tensor
- Measure if bucketing improves throughput
Exercise 3: Measure Communication Overhead
Profile a DDP training run:
- Measure time spent in forward pass
- Measure time spent in backward pass (includes communication)
- Calculate communication/computation ratio
Key Takeaways
- DDP is the default choice - Simple, fast, well-optimized
- Gradient averaging is the key insight - Enables mathematically correct distributed training
- Memory is the bottleneck for LLMs - ZeRO/FSDP trades communication for memory
- Choose sharding level based on model size - Start with DDP, escalate as needed
- Communication overhead grows with sharding - ZeRO-3 moves roughly 1.5x the communication volume of DDP
The Efficiency Equation
Throughput ≈ min(Compute Throughput, Memory Bandwidth, Network Bandwidth)
- Compute bound: Add more GPUs with DDP
- Memory bound: Use ZeRO-3/FSDP
- Network bound: Optimize topology, reduce communication
What’s Next?
In Chapter 6, we’ll explore Tensor Parallelism—splitting individual layers across GPUs. This is how we train layers that are too large for a single GPU even with ZeRO-3.
Further Reading
simple_ddp.py
Understanding Distributed Data Parallel from first principles
This script implements data-parallel training both manually (with explicit all_reduce) and using PyTorch’s DDP wrapper, so you can see exactly what’s happening under the hood.
What It Does
- Creates a simple model on each process
- Manual approach: Runs forward/backward, then all_reduce gradients explicitly
- DDP approach: Wraps model in DDP, gradients sync automatically
- Verifies both approaches produce identical results
Run It
python tutorial/part2-parallelism/chapter05-data-parallel/scripts/simple_ddp.py
Key Learning Points
Manual Gradient Sync:
# After loss.backward()
for param in model.parameters():
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad /= world_size # Average
DDP Wrapper:
model = DDP(model, device_ids=[rank])
# Gradients are synced automatically during backward()!
Why DDP is Better
DDP optimizes what we do manually:
- Overlaps communication with computation - Starts all_reduce while backward is still running
- Buckets gradients - Groups small gradients for efficient communication
- Handles edge cases - Unused parameters, mixed precision, etc.
Source Code
#!/usr/bin/env python3
"""
Simple DDP Implementation
This script shows two approaches to data-parallel training:
1. Manual gradient synchronization (educational)
2. PyTorch's DDP wrapper (production)
Understanding the manual approach helps you appreciate what DDP does
automatically and why it's optimized the way it is.
Usage:
python simple_ddp.py
python simple_ddp.py --epochs 10 --batch-size 64
"""
import argparse
import os
import time
from typing import Tuple
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
class SimpleModel(nn.Module):
"""A simple MLP for demonstration."""
def __init__(self, input_size: int = 784, hidden_size: int = 256,
num_classes: int = 10):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
def create_dummy_dataset(num_samples: int, input_size: int,
num_classes: int) -> TensorDataset:
"""Create a dummy dataset for testing."""
X = torch.randn(num_samples, input_size)
y = torch.randint(0, num_classes, (num_samples,))
return TensorDataset(X, y)
def manual_gradient_sync(model: nn.Module, world_size: int) -> None:
"""
Manually synchronize gradients across all processes.
This is what DDP does automatically (but more efficiently).
"""
for param in model.parameters():
if param.grad is not None:
# Sum gradients across all processes
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
# Average by world size
param.grad /= world_size
def train_manual(
rank: int,
world_size: int,
model: nn.Module,
dataloader: DataLoader,
optimizer: optim.Optimizer,
criterion: nn.Module,
device: torch.device
) -> Tuple[float, float]:
"""
Train for one epoch with MANUAL gradient synchronization.
This is educational - showing exactly what DDP automates.
"""
model.train()
total_loss = 0.0
sync_time = 0.0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
# Forward pass (local)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# Backward pass (local)
loss.backward()
# Manual gradient synchronization (the key step!)
sync_start = time.perf_counter()
manual_gradient_sync(model, world_size)
sync_time += time.perf_counter() - sync_start
# Optimizer step (local, but with averaged gradients)
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader), sync_time
def train_ddp(
model: nn.Module,
dataloader: DataLoader,
optimizer: optim.Optimizer,
criterion: nn.Module,
device: torch.device
) -> float:
"""
Train for one epoch with PyTorch DDP.
DDP automatically handles gradient synchronization during backward().
"""
model.train()
total_loss = 0.0
for batch_idx, (data, target) in enumerate(dataloader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# DDP hooks into backward() to synchronize gradients
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(dataloader)
def compare_gradients(model1: nn.Module, model2: nn.Module) -> float:
"""Compare gradients between two models (should be identical after sync)."""
max_diff = 0.0
for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
if p1.grad is not None and p2.grad is not None:
diff = (p1.grad - p2.grad).abs().max().item()
max_diff = max(max_diff, diff)
return max_diff
def worker(
rank: int,
world_size: int,
args: argparse.Namespace
) -> None:
"""Worker function for each process."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29506"
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
device = torch.device("cpu")
# Create dataset and distributed sampler
dataset = create_dummy_dataset(
num_samples=1000,
input_size=784,
num_classes=10
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)
# =========================================================================
# Method 1: Manual Gradient Sync (Educational)
# =========================================================================
if rank == 0:
print("\n" + "=" * 60)
print(" METHOD 1: MANUAL GRADIENT SYNCHRONIZATION")
print("=" * 60)
# Create model (same initialization on all ranks via seeding)
torch.manual_seed(42)
model_manual = SimpleModel().to(device)
# Broadcast initial weights from rank 0 to ensure all replicas start identical
for param in model_manual.parameters():
dist.broadcast(param.data, src=0)
optimizer_manual = optim.SGD(model_manual.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
dist.barrier()
# Train one epoch manually
manual_loss, sync_time = train_manual(
rank, world_size, model_manual, dataloader,
optimizer_manual, criterion, device
)
dist.barrier()
if rank == 0:
print(f"\n[Manual] Loss: {manual_loss:.4f}")
print(f"[Manual] Time spent in gradient sync: {sync_time*1000:.2f} ms")
# =========================================================================
# Method 2: PyTorch DDP (Production)
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "=" * 60)
print(" METHOD 2: PYTORCH DDP (AUTOMATIC)")
print("=" * 60)
# Create fresh model with same seed
torch.manual_seed(42)
model_ddp = SimpleModel().to(device)
# Wrap with DDP - this enables automatic gradient sync
model_ddp = DDP(model_ddp)
optimizer_ddp = optim.SGD(model_ddp.parameters(), lr=0.01)
# Reset sampler for new epoch
sampler.set_epoch(0)
dist.barrier()
# Train one epoch with DDP
start_time = time.perf_counter()
ddp_loss = train_ddp(model_ddp, dataloader, optimizer_ddp, criterion, device)
ddp_time = time.perf_counter() - start_time
dist.barrier()
if rank == 0:
print(f"\n[DDP] Loss: {ddp_loss:.4f}")
print(f"[DDP] Total training time: {ddp_time*1000:.2f} ms")
# =========================================================================
# Comparison
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "=" * 60)
print(" COMPARISON")
print("=" * 60)
print(f"""
What DDP does that our manual approach doesn't:
1. GRADIENT BUCKETING
- Groups small gradients into larger buffers
- Reduces number of all_reduce calls
- Our manual: one all_reduce per parameter
2. OVERLAP WITH BACKWARD
- Starts all_reduce before backward completes
- Hides communication latency
- Our manual: all_reduce only after full backward
3. SMART BUFFER MANAGEMENT
- Reuses communication buffers
- Avoids memory allocation overhead
- Our manual: allocates on each call
4. BROADCAST ON FIRST FORWARD
- Ensures consistent initialization
- We did this manually with broadcast
Why DDP is faster:
- Fewer, larger all_reduce calls (bucketing)
- Communication overlapped with computation
- Highly optimized NCCL integration
""")
dist.barrier()
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Simple DDP Implementation")
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--world-size", "-w", type=int, default=4)
args = parser.parse_args()
print("╔" + "═" * 58 + "╗")
print("║" + " SIMPLE DDP: MANUAL vs AUTOMATIC".center(58) + "║")
print("╚" + "═" * 58 + "╝")
print(f"\nWorld size: {args.world_size}")
print(f"Batch size per GPU: {args.batch_size}")
print(f"Effective batch size: {args.batch_size * args.world_size}")
mp.spawn(worker, args=(args.world_size, args), nprocs=args.world_size, join=True)
if __name__ == "__main__":
main()
gradient_sync_visualizer.py
See exactly how gradients flow during distributed training
This script visualizes the gradient synchronization process, showing what each GPU has before and after all_reduce.
What It Does
- Each GPU computes gradients on its local batch
- Displays gradients BEFORE synchronization (different on each GPU)
- Performs all_reduce
- Displays gradients AFTER synchronization (identical everywhere)
Run It
python tutorial/part2-parallelism/chapter05-data-parallel/scripts/gradient_sync_visualizer.py
Example Output
=== BEFORE Gradient Sync ===
Rank 0: layer1.weight.grad = [0.123, -0.456, 0.789, ...]
Rank 1: layer1.weight.grad = [0.234, -0.567, 0.890, ...]
Rank 2: layer1.weight.grad = [0.345, -0.678, 0.901, ...]
Rank 3: layer1.weight.grad = [0.456, -0.789, 0.012, ...]
=== Performing all_reduce... ===
=== AFTER Gradient Sync ===
Rank 0: layer1.weight.grad = [0.290, -0.623, 0.648, ...] ← averaged
Rank 1: layer1.weight.grad = [0.290, -0.623, 0.648, ...] ← same!
Rank 2: layer1.weight.grad = [0.290, -0.623, 0.648, ...] ← same!
Rank 3: layer1.weight.grad = [0.290, -0.623, 0.648, ...] ← same!
The Insight
After all_reduce + averaging, every GPU has the exact same gradient. This is mathematically equivalent to computing the gradient on the combined batch from all GPUs.
Source Code
#!/usr/bin/env python3
"""
Gradient Synchronization Visualizer
This script visualizes how gradient synchronization works in distributed
training. It shows gradients before and after synchronization, demonstrating
the averaging that makes data parallelism work.
Key insights:
- Each rank computes different gradients (different data)
- After all_reduce + averaging, all ranks have identical gradients
- This is mathematically equivalent to training on the full batch
Usage:
python gradient_sync_visualizer.py
python gradient_sync_visualizer.py --verbose
"""
import argparse
import os
from typing import Dict, List
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
class TinyModel(nn.Module):
"""A tiny model for visualization purposes."""
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(4, 3, bias=False)
self.layer2 = nn.Linear(3, 2, bias=False)
def forward(self, x):
return self.layer2(torch.relu(self.layer1(x)))
def print_gradients(model: nn.Module, rank: int, prefix: str = "") -> Dict[str, torch.Tensor]:
"""Print gradients for all parameters."""
gradients = {}
for name, param in model.named_parameters():
if param.grad is not None:
gradients[name] = param.grad.clone()
if rank == 0:
print(f" {prefix}{name}:")
print(f" shape: {list(param.grad.shape)}")
print(f" grad[0,0]: {param.grad[0,0].item():.6f}")
return gradients
def visualize_sync(rank: int, world_size: int, verbose: bool) -> None:
"""Main visualization function."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29507"
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
device = torch.device("cpu")
# =========================================================================
# Setup: Create identical models on all ranks
# =========================================================================
torch.manual_seed(42)
model = TinyModel().to(device)
# Broadcast weights to ensure identical starting point
for param in model.parameters():
dist.broadcast(param.data, src=0)
dist.barrier()
if rank == 0:
print("\n" + "=" * 60)
print(" GRADIENT SYNCHRONIZATION VISUALIZATION")
print("=" * 60)
print(f"\nWorld size: {world_size}")
print(f"Model: {model}")
# =========================================================================
# Step 1: Each rank processes DIFFERENT data
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "-" * 60)
print(" STEP 1: Each rank has different input data")
print("-" * 60)
# Create rank-specific data (simulating distributed batch)
torch.manual_seed(rank * 100) # Different seed per rank!
local_input = torch.randn(8, 4, device=device) # Batch of 8 samples
local_target = torch.randn(8, 2, device=device)
dist.barrier()
print(f"[Rank {rank}] Input mean: {local_input.mean().item():.4f}, "
f"std: {local_input.std().item():.4f}")
dist.barrier()
# =========================================================================
# Step 2: Forward and backward (compute LOCAL gradients)
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "-" * 60)
print(" STEP 2: Compute gradients LOCALLY (before sync)")
print("-" * 60)
output = model(local_input)
loss = ((output - local_target) ** 2).mean() # MSE loss
model.zero_grad()
loss.backward()
dist.barrier()
# Show gradients before sync
print(f"\n[Rank {rank}] Loss: {loss.item():.6f}")
# Collect pre-sync gradients
pre_sync_grads = {}
for name, param in model.named_parameters():
if param.grad is not None:
pre_sync_grads[name] = param.grad.clone()
if verbose:
print(f"[Rank {rank}] {name} grad[0,0]: {param.grad[0,0].item():.6f}")
dist.barrier()
if rank == 0:
print("\n[Note] Gradients are DIFFERENT on each rank because")
print(" each rank processed different input data!")
# =========================================================================
# Step 3: Synchronize gradients (all_reduce + average)
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "-" * 60)
print(" STEP 3: Synchronize gradients (all_reduce + average)")
print("-" * 60)
for param in model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
param.grad /= world_size
dist.barrier()
# Show gradients after sync
post_sync_grads = {}
for name, param in model.named_parameters():
if param.grad is not None:
post_sync_grads[name] = param.grad.clone()
if verbose:
print(f"[Rank {rank}] {name} grad[0,0]: {param.grad[0,0].item():.6f}")
dist.barrier()
if rank == 0:
print("\n[Note] After sync, ALL ranks have IDENTICAL gradients!")
print(" These are the averaged gradients from all local batches.")
# =========================================================================
# Step 4: Verify all ranks have identical gradients
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "-" * 60)
print(" STEP 4: Verify gradient synchronization")
print("-" * 60)
# Gather gradients from all ranks to rank 0
for name, param in model.named_parameters():
if param.grad is not None:
grad_flat = param.grad.flatten()
gathered = [torch.zeros_like(grad_flat) for _ in range(world_size)]
dist.all_gather(gathered, grad_flat)
if rank == 0:
# Check all ranks have identical gradients
all_same = all(torch.allclose(gathered[0], g) for g in gathered[1:])
status = "✓" if all_same else "✗"
print(f" {status} {name}: all ranks identical = {all_same}")
# =========================================================================
# Step 5: Mathematical verification
# =========================================================================
dist.barrier()
if rank == 0:
print("\n" + "-" * 60)
print(" MATHEMATICAL INSIGHT")
print("-" * 60)
print("""
The synchronized gradient is mathematically equivalent to computing
the gradient on the ENTIRE distributed batch:
Let B = B₀ ∪ B₁ ∪ B₂ ∪ B₃ (union of all local batches)
∇L(B) = (1/|B|) Σᵢ ∇L(xᵢ)
= (1/4) [∇L(B₀) + ∇L(B₁) + ∇L(B₂) + ∇L(B₃)]
= all_reduce(local_gradients, SUM) / world_size
This is why data parallelism gives the SAME result as training
on a single GPU with a larger batch size!
""")
# =========================================================================
# Bonus: Show gradient change magnitude
# =========================================================================
dist.barrier()
if verbose and rank == 0:
print("-" * 60)
print(" GRADIENT CHANGE ANALYSIS")
print("-" * 60)
print("\nHow much did gradients change after sync?")
print("(This shows how different each rank's gradients were)\n")
dist.barrier()
for name in pre_sync_grads:
pre_grad = pre_sync_grads[name]
post_grad = post_sync_grads[name]
change = (post_grad - pre_grad).abs().mean().item()
change_pct = change / (pre_grad.abs().mean().item() + 1e-8) * 100
if verbose:
print(f"[Rank {rank}] {name}: changed by {change_pct:.1f}%")
dist.barrier()
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Gradient Sync Visualizer")
parser.add_argument("--world-size", "-w", type=int, default=4)
parser.add_argument("--verbose", "-v", action="store_true",
help="Show detailed gradient values")
args = parser.parse_args()
print("╔" + "═" * 58 + "╗")
print("║" + " GRADIENT SYNCHRONIZATION VISUALIZER".center(58) + "║")
print("╚" + "═" * 58 + "╝")
mp.spawn(
visualize_sync,
args=(args.world_size, args.verbose),
nprocs=args.world_size,
join=True
)
if __name__ == "__main__":
main()
Chapter 6: Tensor Parallelism from Scratch
“When your layer doesn’t fit on one GPU, you split the layer, not just the data.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain why tensor parallelism is needed for large models
- Implement column-parallel and row-parallel linear layers
- Understand how Megatron-LM partitions transformer layers
- Calculate communication costs for different TP strategies
Prerequisites
- Completed Chapter 5 (Data Parallelism)
- Linear algebra (matrix multiplication)
- Understanding of transformer architecture (attention, MLP)
Concept Overview
Why Tensor Parallelism?
Data parallelism replicates the entire model. But what if a single layer is too big?
Consider GPT-3’s embedding layer:
- Vocabulary: 50,000 tokens
- Hidden dimension: 12,288
- Size: 50,000 × 12,288 × 2 bytes = 1.2 GB (just for embeddings!)
For very large models, even a single linear layer might exceed GPU memory. Tensor parallelism splits individual layers across GPUs.
The Key Insight: Matrix Multiplication is Parallelizable
Matrix multiplication Y = XW can be computed in parts:
Column-wise splitting (split W by columns):
W = [W₁ | W₂] (split into left and right halves)
Y = X × [W₁ | W₂] = [X×W₁ | X×W₂] = [Y₁ | Y₂]
Each GPU computes part of the output. No communication needed—just concatenate!
Row-wise splitting (split W by rows):
W = [W₁] (split into top and bottom halves)
[W₂]
Y = X × [W₁; W₂] requires splitting X too...
This needs an all_reduce to combine partial results.
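You can check both identities in a single process before distributing anything. A minimal sketch in plain PyTorch (shapes and the split factor of 2 are arbitrary choices for illustration):
import torch

X = torch.randn(4, 8)                      # [batch, in_features]
W = torch.randn(8, 6)                      # [in_features, out_features]
Y = X @ W

# Column split: concatenate the partial outputs
W_cols = W.chunk(2, dim=1)                 # two [8, 3] blocks
Y_col = torch.cat([X @ Wc for Wc in W_cols], dim=1)
assert torch.allclose(Y, Y_col, atol=1e-5)

# Row split: sum the partial outputs (the input must be split the same way)
W_rows = W.chunk(2, dim=0)                 # two [4, 6] blocks
X_parts = X.chunk(2, dim=1)                # matching [4, 4] slices of X
Y_row = sum(Xp @ Wr for Xp, Wr in zip(X_parts, W_rows))
assert torch.allclose(Y, Y_row, atol=1e-5)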
Megatron-Style Tensor Parallelism
Megatron-LM (NVIDIA’s framework) cleverly combines column and row splits to minimize communication:
MLP Block (in a transformer layer):
MLP(X) = GeLU(X × W₁) × W₂
GPU 0: Y₀ = GeLU(X × W₁ᶜᵒˡ⁰) × W₂ʳᵒʷ⁰
GPU 1: Y₁ = GeLU(X × W₁ᶜᵒˡ¹) × W₂ʳᵒʷ¹
Y = all_reduce_sum(Y₀, Y₁) = Y₀ + Y₁
The trick: Column-parallel first, row-parallel second!
- After column-parallel W₁: each GPU has part of the hidden states (no comm needed)
- After row-parallel W₂: need all_reduce to sum partial products
Only ONE all_reduce per MLP block in the forward pass!
Attention Layer Tensor Parallelism
For multi-head attention with 32 heads on 4 GPUs:
- Each GPU handles 8 attention heads
- Q, K, V projections: column-parallel (each GPU computes 8 heads worth)
- Output projection: row-parallel (combine head outputs)
┌─────────────────────────────────────────────┐
│ Multi-Head Attention │
│ │
│ Heads 0-7 Heads 8-15 Heads 16-23 Heads 24-31
│ ┌─────┐ ┌─────┐ ┌─────┐ ┌─────┐
Input X ───► │GPU 0│ │GPU 1│ │GPU 2│ │GPU 3│
│ └──┬──┘ └──┬──┘ └──┬──┘ └──┬──┘
│ │ │ │ │
│ └───────────┴────────────┴────────────┘
│ │
│ all_reduce
│ │
│ ▼
│ Output
└─────────────────────────────────────────────┘
Communication Analysis
For a transformer layer with tensor parallelism degree T:
| Component | Communication Volume (per GPU, ring all_reduce) |
|---|---|
| MLP forward | 2 × batch × seq × hidden × (T-1)/T |
| MLP backward | 2 × batch × seq × hidden × (T-1)/T |
| Attention forward | 2 × batch × seq × hidden × (T-1)/T |
| Attention backward | 2 × batch × seq × hidden × (T-1)/T |
Total per layer: 8 × batch × seq × hidden × (T-1)/T elements per GPU (multiply by bytes per element to get bytes).
This is why TP is typically used within a node (NVLink), not across nodes (slow InfiniBand).
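As a quick worked example of the table above, assuming FP16 activations, a ring all_reduce, and 4 all_reduces per layer across forward and backward (the helper name and the numbers plugged in are illustrative):
def tp_comm_per_layer_gb(batch, seq, hidden, T, bytes_per_elem=2, allreduces=4):
    """Approximate per-GPU ring all_reduce traffic for one transformer layer."""
    size = batch * seq * hidden * bytes_per_elem      # one activation tensor, in bytes
    per_allreduce = 2 * size * (T - 1) / T            # ring all_reduce volume per GPU
    return allreduces * per_allreduce / 1e9

# e.g. batch=8, seq=4096, hidden=8192, TP=8  ->  about 3.8 GB per layer per step
print(f"{tp_comm_per_layer_gb(8, 4096, 8192, 8):.1f} GB")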
The Math: Column-Parallel Linear
class ColumnParallelLinear:
"""
Split the weight matrix W by columns.
W_full shape: [in_features, out_features]
W_local shape: [in_features, out_features // tp_size]
Forward: Y_local = X @ W_local
No communication needed in forward!
"""
def forward(self, X):
# Each GPU computes its portion of the output
return X @ self.weight # shape: [batch, out_features // tp_size]
The Math: Row-Parallel Linear
class RowParallelLinear:
"""
Split the weight matrix W by rows.
W_full shape: [in_features, out_features]
W_local shape: [in_features // tp_size, out_features]
Forward: Y_partial = X_local @ W_local
Y = all_reduce(Y_partial)
"""
def forward(self, X_local):
# Each GPU has part of input, computes partial output
Y_partial = X_local @ self.weight
# Sum across all GPUs
dist.all_reduce(Y_partial, op=dist.ReduceOp.SUM)
return Y_partial
Combining Column + Row: The MLP Recipe
def tp_mlp_forward(X, W1_col, W2_row, tp_group):
"""
Tensor-parallel MLP with minimal communication.
W1 is column-parallel: [hidden, 4*hidden//tp_size]
W2 is row-parallel: [4*hidden//tp_size, hidden]
"""
# Step 1: Column-parallel first linear
hidden = torch.relu(X @ W1_col) # No comm needed!
# Step 2: Row-parallel second linear
output = hidden @ W2_row
# Step 3: Only ONE all_reduce needed
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=tp_group)
return output
TP vs DP: When to Use Which?
| Factor | Data Parallel | Tensor Parallel |
|---|---|---|
| Granularity | Whole model | Single layer |
| Communication | Gradients only | Activations every layer |
| Scalability | 100s of GPUs | Usually ≤8 GPUs |
| Best for | Batch scaling | Large layers |
| Topology | Cross-node OK | Intra-node (NVLink) |
Rule of thumb: TP within node, DP across nodes.
Code Walkthrough
Script 1: tp_linear.py
Implements column-parallel and row-parallel linear layers from scratch:
- Shows weight initialization and sharding
- Demonstrates forward pass with all_reduce
- Verifies correctness against non-parallel version
Script 2: tp_mlp.py
A complete tensor-parallel MLP block:
- Combines column and row parallelism
- Shows how to minimize communication
- Compares performance with naive approach
Common Pitfalls
Pitfall 1: Forgetting to Split Inputs for Row-Parallel
Row-parallel expects the input to already be split. If you feed the full input, you’ll get wrong results!
Pitfall 2: Wrong Reduction Order
All_reduce must happen at the right place:
- After row-parallel layer
- NOT after column-parallel layer
Pitfall 3: Mismatched Dimensions
When transitioning from column to row parallel:
- Column-parallel output shape: [batch, intermediate // tp_size]
- Row-parallel input shape: [batch, intermediate // tp_size]
These must match (in the MLP, "intermediate" is 4 × hidden)!
Try It Yourself
Exercise 1: Verify Column-Parallel Correctness
Run tp_linear.py and verify that:
concatenate(column_parallel_outputs) == full_linear_output
Exercise 2: Count All-Reduces
Count the number of all_reduce calls in a full transformer layer with:
- TP degree = 4
- 12 attention heads
- 4096 hidden dimension
Exercise 3: Measure TP Overhead
Modify tp_mlp.py to measure:
- Time for matrix multiplications
- Time for all_reduce calls
- Communication percentage
Key Takeaways
- TP splits layers, not batches - Complementary to data parallelism
- Column-parallel needs no sync in forward - Output is naturally partitioned
- Row-parallel needs all_reduce - To sum partial products
- Megatron trick: column then row - One all_reduce per MLP forward pass (plus one in the backward)
- TP best within a node - Needs high bandwidth (NVLink)
Performance Intuition
For a 4-GPU TP setup with NVLink (900 GB/s total):
- MLP computation: ~1ms
- All-reduce (2MB activations): ~0.01ms
TP overhead is typically under 5% within a node. Across nodes over InfiniBand (~50 GB/s), the same all_reduce would be roughly an order of magnitude slower.
What’s Next?
In Chapter 7, we’ll explore Pipeline Parallelism and Expert Parallelism—splitting models by layers and routing tokens to specialized experts.
Further Reading
tp_linear.py
Implementing Column-Parallel and Row-Parallel Linear layers from scratch
This script demonstrates the fundamental building blocks of tensor parallelism: how to split a linear layer’s weight matrix across multiple GPUs.
What It Does
- Implements ColumnParallelLinear - splits weights by columns
- Implements RowParallelLinear - splits weights by rows
- Verifies that parallel execution equals sequential execution
- Shows where communication is (and isn’t) needed
The Two Splitting Strategies
Column-Parallel (no sync in forward):
W = [W₀ | W₁ | W₂ | W₃] ← split by columns
Y = X @ W = [X@W₀ | X@W₁ | X@W₂ | X@W₃]
Each GPU computes its slice independently!
Row-Parallel (needs all_reduce):
W = [W₀] ← split by rows
[W₁]
[W₂]
[W₃]
Y = X@W = X₀@W₀ + X₁@W₁ + X₂@W₂ + X₃@W₃
Requires all_reduce to sum partial results!
Run It
python tutorial/part2-parallelism/chapter06-tensor-parallel/scripts/tp_linear.py
Key Verification
The script verifies:
# Column parallel: concatenated outputs match full computation
torch.cat([y0, y1, y2, y3], dim=-1) == X @ W_full
# Row parallel: summed outputs match full computation
y0 + y1 + y2 + y3 == X @ W_full
Source Code
#!/usr/bin/env python3
"""
Tensor-Parallel Linear Layers
This script implements column-parallel and row-parallel linear layers
from scratch, showing exactly how tensor parallelism works.
Column-parallel: Split output dimension (no sync needed)
Row-parallel: Split input dimension (all_reduce needed)
Usage:
python tp_linear.py
python tp_linear.py --tp-size 4
"""
import argparse
import os
from typing import Tuple
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
class ColumnParallelLinear(nn.Module):
"""
Linear layer with column-parallel weight matrix.
The weight matrix W is split along the output dimension:
W_full: [in_features, out_features]
W_local: [in_features, out_features // tp_size]
Each GPU computes a portion of the output features.
Forward pass:
Y_local = X @ W_local (no communication!)
To get full Y, concatenate Y_local from all GPUs
"""
def __init__(self, in_features: int, out_features: int,
tp_size: int, tp_rank: int):
super().__init__()
assert out_features % tp_size == 0, "out_features must be divisible by tp_size"
self.in_features = in_features
self.out_features = out_features
self.tp_size = tp_size
self.tp_rank = tp_rank
self.out_features_local = out_features // tp_size
# Local weight: only 1/tp_size of the columns
self.weight = nn.Parameter(
torch.empty(in_features, self.out_features_local)
)
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass: Y_local = X @ W_local
Input: x of shape [batch, in_features]
Output: y of shape [batch, out_features // tp_size]
"""
return x @ self.weight
def __repr__(self):
return (f"ColumnParallelLinear(in={self.in_features}, "
f"out={self.out_features_local} (local) / {self.out_features} (total), "
f"tp_rank={self.tp_rank})")
class RowParallelLinear(nn.Module):
"""
Linear layer with row-parallel weight matrix.
The weight matrix W is split along the input dimension:
W_full: [in_features, out_features]
W_local: [in_features // tp_size, out_features]
Each GPU computes a partial result that must be summed.
Forward pass:
Y_partial = X_local @ W_local
Y = all_reduce(Y_partial) # Sum across all GPUs
"""
def __init__(self, in_features: int, out_features: int,
tp_size: int, tp_rank: int, tp_group=None):
super().__init__()
assert in_features % tp_size == 0, "in_features must be divisible by tp_size"
self.in_features = in_features
self.out_features = out_features
self.tp_size = tp_size
self.tp_rank = tp_rank
self.tp_group = tp_group
self.in_features_local = in_features // tp_size
# Local weight: only 1/tp_size of the rows
self.weight = nn.Parameter(
torch.empty(self.in_features_local, out_features)
)
nn.init.xavier_uniform_(self.weight)
def forward(self, x_local: torch.Tensor) -> torch.Tensor:
"""
Forward pass: Y = all_reduce(X_local @ W_local)
Input: x_local of shape [batch, in_features // tp_size]
Output: y of shape [batch, out_features]
"""
# Partial output (not yet complete)
y_partial = x_local @ self.weight
# Sum across all GPUs to get complete output
dist.all_reduce(y_partial, op=dist.ReduceOp.SUM, group=self.tp_group)
return y_partial
def __repr__(self):
return (f"RowParallelLinear(in={self.in_features_local} (local) / {self.in_features} (total), "
f"out={self.out_features}, tp_rank={self.tp_rank})")
def verify_column_parallel(rank: int, world_size: int) -> None:
"""Verify column-parallel linear correctness."""
device = torch.device("cpu")
if rank == 0:
print("\n" + "=" * 60)
print(" COLUMN-PARALLEL LINEAR VERIFICATION")
print("=" * 60)
# Parameters
batch_size = 4
in_features = 8
out_features = 8 # Must be divisible by world_size
# Create column-parallel layer
torch.manual_seed(42)
col_linear = ColumnParallelLinear(
in_features, out_features, world_size, rank
).to(device)
    # Build a full-size reference weight on rank 0 (used only for the shape
    # printout below; correctness is checked against the gathered shards)
    torch.manual_seed(42)
    if rank == 0:
        full_weight = torch.empty(in_features, out_features)
        nn.init.xavier_uniform_(full_weight)
# Gather all column-parallel weights to rank 0 for verification
local_weight = col_linear.weight.data.clone()
gathered_weights = [torch.zeros_like(local_weight) for _ in range(world_size)]
dist.all_gather(gathered_weights, local_weight)
if rank == 0:
reconstructed_weight = torch.cat(gathered_weights, dim=1)
print(f"\nWeight shapes:")
print(f" Local: {local_weight.shape}")
print(f" Reconstructed: {reconstructed_weight.shape}")
print(f" Full: {full_weight.shape}")
# Create test input (same on all ranks)
torch.manual_seed(123)
x = torch.randn(batch_size, in_features, device=device)
# Forward pass with column-parallel
y_local = col_linear(x)
# Gather outputs
gathered_outputs = [torch.zeros_like(y_local) for _ in range(world_size)]
dist.all_gather(gathered_outputs, y_local)
y_reconstructed = torch.cat(gathered_outputs, dim=1)
# Compare with full layer (only on rank 0)
if rank == 0:
# Use reconstructed weight for full computation
y_full = x @ reconstructed_weight
diff = (y_reconstructed - y_full).abs().max().item()
print(f"\nOutput comparison:")
print(f" Reconstructed shape: {y_reconstructed.shape}")
print(f" Full shape: {y_full.shape}")
print(f" Max difference: {diff:.2e}")
print(f" Correct: {diff < 1e-5}")
def verify_row_parallel(rank: int, world_size: int) -> None:
"""Verify row-parallel linear correctness."""
device = torch.device("cpu")
if rank == 0:
print("\n" + "=" * 60)
print(" ROW-PARALLEL LINEAR VERIFICATION")
print("=" * 60)
# Parameters
batch_size = 4
in_features = 8 # Must be divisible by world_size
out_features = 8
# Create row-parallel layer
torch.manual_seed(42)
row_linear = RowParallelLinear(
in_features, out_features, world_size, rank
).to(device)
# Create test input (full, same on all ranks)
torch.manual_seed(123)
x_full = torch.randn(batch_size, in_features, device=device)
# Split input for row-parallel (each rank gets a slice)
x_chunks = x_full.chunk(world_size, dim=1)
x_local = x_chunks[rank]
if rank == 0:
print(f"\nInput shapes:")
print(f" Full: {x_full.shape}")
print(f" Local: {x_local.shape}")
# Forward pass with row-parallel (includes all_reduce!)
y_row_parallel = row_linear(x_local)
# Gather all local weights to reconstruct full weight
local_weight = row_linear.weight.data.clone()
gathered_weights = [torch.zeros_like(local_weight) for _ in range(world_size)]
dist.all_gather(gathered_weights, local_weight)
if rank == 0:
# Reconstruct full weight
full_weight = torch.cat(gathered_weights, dim=0)
# Compute full output for comparison
y_full = x_full @ full_weight
diff = (y_row_parallel - y_full).abs().max().item()
print(f"\nOutput comparison:")
print(f" Row-parallel shape: {y_row_parallel.shape}")
print(f" Full shape: {y_full.shape}")
print(f" Max difference: {diff:.2e}")
print(f" Correct: {diff < 1e-5}")
def demonstrate_megatron_pattern(rank: int, world_size: int) -> None:
"""Demonstrate the Megatron column→row pattern."""
device = torch.device("cpu")
if rank == 0:
print("\n" + "=" * 60)
print(" MEGATRON PATTERN: Column + Row")
print("=" * 60)
print("""
The Megatron-LM pattern for an MLP block:
1. First linear (column-parallel):
- Input: [batch, hidden]
- Output: [batch, 4*hidden // tp_size]
- No communication needed!
2. Activation (GeLU):
- Applied locally
- Still no communication!
3. Second linear (row-parallel):
- Input: [batch, 4*hidden // tp_size]
- Output: [batch, hidden]
- ONE all_reduce to sum partial products
Result: Only 1 all_reduce per MLP block forward pass!
""")
# Demonstrate the pattern
batch_size = 4
hidden_size = 8
intermediate_size = 32 # 4x hidden
# Column-parallel first layer
torch.manual_seed(42 + rank)
W1_col = ColumnParallelLinear(
hidden_size, intermediate_size, world_size, rank
).to(device)
# Row-parallel second layer
torch.manual_seed(142 + rank)
W2_row = RowParallelLinear(
intermediate_size, hidden_size, world_size, rank
).to(device)
# Input (same on all ranks)
torch.manual_seed(200)
x = torch.randn(batch_size, hidden_size, device=device)
# Forward pass
# Step 1: Column-parallel (no communication)
h = W1_col(x)
if rank == 0:
print(f"\nAfter column-parallel W1:")
print(f" Input shape: {x.shape}")
print(f" Output shape: {h.shape} (partitioned across {world_size} GPUs)")
# Step 2: Activation (local)
h = torch.relu(h)
# Step 3: Row-parallel (one all_reduce)
y = W2_row(h)
if rank == 0:
print(f"\nAfter row-parallel W2:")
print(f" Input shape: {h.shape}")
print(f" Output shape: {y.shape} (after all_reduce)")
# Verify all ranks have the same output
gathered_outputs = [torch.zeros_like(y) for _ in range(world_size)]
dist.all_gather(gathered_outputs, y)
if rank == 0:
all_same = all(torch.allclose(gathered_outputs[0], g) for g in gathered_outputs[1:])
print(f"\nAll ranks have identical output: {all_same}")
print("\nKey insight: We achieved tensor parallelism with only ONE all_reduce!")
def worker(rank: int, world_size: int) -> None:
"""Main worker function."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29508"
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
verify_column_parallel(rank, world_size)
dist.barrier()
verify_row_parallel(rank, world_size)
dist.barrier()
demonstrate_megatron_pattern(rank, world_size)
dist.barrier()
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Tensor-Parallel Linear Layers")
parser.add_argument("--tp-size", "-t", type=int, default=4,
help="Tensor parallelism degree (default: 4)")
args = parser.parse_args()
print("╔" + "═" * 58 + "╗")
print("║" + " TENSOR-PARALLEL LINEAR LAYERS".center(58) + "║")
print("╚" + "═" * 58 + "╝")
print(f"\nTP degree: {args.tp_size}")
mp.spawn(worker, args=(args.tp_size,), nprocs=args.tp_size, join=True)
if __name__ == "__main__":
main()
tp_mlp.py
A complete Tensor-Parallel MLP block with minimal communication
This script implements the Megatron-style tensor-parallel MLP, showing how to chain column-parallel and row-parallel layers to minimize communication.
What It Does
- Implements a tensor-parallel MLP block:
- First linear: Column-parallel (expands hidden → 4×hidden)
- Activation: GeLU (local, no communication)
- Second linear: Row-parallel (contracts 4×hidden → hidden)
- Shows that only ONE all_reduce is needed per MLP forward pass
- Compares with naive approach (2 all_reduces)
The Megatron Trick
MLP(X) = GeLU(X @ W1) @ W2
Naive: all_reduce after W1, all_reduce after W2 = 2 communications
Smart: column-parallel W1, row-parallel W2 = 1 communication!
Why it works:
- Column-parallel W1 produces split outputs: [Y₀ | Y₁ | Y₂ | Y₃]
- Each GPU applies GeLU locally
- Row-parallel W2 expects split inputs (which we have!)
- Only need all_reduce at the end
Run It
python tutorial/part2-parallelism/chapter06-tensor-parallel/scripts/tp_mlp.py
Architecture Visualization
X (input)
│
▼
┌─────────────────────┐
│ Column-Parallel │ ← W1 split by columns
│ Linear (W1) │ No communication
└──────────┬──────────┘
│
▼
┌─────────────────────┐
│ GeLU │ ← Local operation
│ (no comm needed) │
└──────────┬──────────┘
│
▼
┌─────────────────────┐
│ Row-Parallel │ ← W2 split by rows
│ Linear (W2) │
└──────────┬──────────┘
│
▼
┌─────────────────────┐
│ all_reduce │ ← Only sync point!
│ │
└──────────┬──────────┘
│
▼
Y (output)
Source Code
#!/usr/bin/env python3
"""
Tensor-Parallel MLP Block
This script implements a complete tensor-parallel MLP block using
the Megatron-style column→row pattern for minimal communication.
Usage:
python tp_mlp.py
python tp_mlp.py --tp-size 4 --hidden-size 256
"""
import argparse
import os
import time
from typing import Tuple
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
class TensorParallelMLP(nn.Module):
"""
Tensor-parallel MLP using Megatron-style column→row parallelism.
Structure:
Input → [Column-Parallel Linear] → GeLU → [Row-Parallel Linear] → Output
Communication: 1 all_reduce per forward pass (after row-parallel)
"""
def __init__(self, hidden_size: int, intermediate_size: int,
tp_size: int, tp_rank: int, tp_group=None):
super().__init__()
assert intermediate_size % tp_size == 0
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.tp_size = tp_size
self.tp_rank = tp_rank
self.tp_group = tp_group
self.intermediate_local = intermediate_size // tp_size
# Column-parallel: W1 shape [hidden, intermediate // tp_size]
self.w1 = nn.Linear(hidden_size, self.intermediate_local, bias=False)
# Row-parallel: W2 shape [intermediate // tp_size, hidden]
self.w2 = nn.Linear(self.intermediate_local, hidden_size, bias=False)
self._init_weights()
def _init_weights(self):
"""Initialize weights with proper scaling for TP."""
nn.init.xavier_uniform_(self.w1.weight)
# Scale row-parallel weights to maintain variance after all_reduce
nn.init.xavier_uniform_(self.w2.weight)
self.w2.weight.data /= self.tp_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with minimal communication.
Args:
x: Input tensor of shape [batch, seq, hidden]
Returns:
Output tensor of shape [batch, seq, hidden]
"""
# Step 1: Column-parallel first linear (no communication)
h = self.w1(x)
# Step 2: Activation (local)
h = torch.nn.functional.gelu(h)
# Step 3: Row-parallel second linear
y = self.w2(h)
# Step 4: All-reduce to sum partial products
dist.all_reduce(y, op=dist.ReduceOp.SUM, group=self.tp_group)
return y
class NonParallelMLP(nn.Module):
"""Standard MLP for comparison."""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
nn.init.xavier_uniform_(self.w1.weight)
nn.init.xavier_uniform_(self.w2.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = torch.nn.functional.gelu(self.w1(x))
return self.w2(h)
def benchmark_tp_mlp(rank: int, world_size: int, hidden_size: int,
batch_size: int, seq_len: int, warmup: int = 10,
                     iterations: int = 100) -> Tuple[float, torch.Tensor]:
"""Benchmark tensor-parallel MLP."""
device = torch.device("cpu")
intermediate_size = hidden_size * 4
# Create TP MLP
tp_mlp = TensorParallelMLP(
hidden_size, intermediate_size, world_size, rank
).to(device)
# Create input
torch.manual_seed(42)
x = torch.randn(batch_size, seq_len, hidden_size, device=device)
# Warmup
for _ in range(warmup):
_ = tp_mlp(x)
dist.barrier()
# Benchmark
dist.barrier()
start = time.perf_counter()
for _ in range(iterations):
y = tp_mlp(x)
dist.barrier()
total_time = time.perf_counter() - start
return total_time / iterations, y
def verify_correctness(rank: int, world_size: int, hidden_size: int) -> None:
"""Verify TP MLP produces correct output."""
device = torch.device("cpu")
intermediate_size = hidden_size * 4
if rank == 0:
print("\n" + "=" * 60)
print(" CORRECTNESS VERIFICATION")
print("=" * 60)
# Create test input (same on all ranks)
torch.manual_seed(42)
x = torch.randn(4, 8, hidden_size, device=device)
# Create TP MLP with deterministic weights
torch.manual_seed(100)
tp_mlp = TensorParallelMLP(
hidden_size, intermediate_size, world_size, rank
).to(device)
# Forward pass
y_tp = tp_mlp(x)
# Gather TP weights to rank 0 for comparison
# W1 (column-parallel)
w1_local = tp_mlp.w1.weight.data.clone()
w1_gathered = [torch.zeros_like(w1_local) for _ in range(world_size)]
dist.all_gather(w1_gathered, w1_local)
# W2 (row-parallel)
w2_local = tp_mlp.w2.weight.data.clone()
w2_gathered = [torch.zeros_like(w2_local) for _ in range(world_size)]
dist.all_gather(w2_gathered, w2_local)
    if rank == 0:
        # Reconstruct full weights from the gathered shards
        w1_full = torch.cat(w1_gathered, dim=0)  # [intermediate, hidden]
        w2_full = torch.cat(w2_gathered, dim=1)  # [hidden, intermediate]
        # The gathered w2 shards already carry the 1/tp_size scaling applied
        # at init, so no correction is needed before comparing
        # Compute reference output
        h = torch.nn.functional.gelu(x @ w1_full.T)
        y_ref = h @ w2_full.T
diff = (y_tp - y_ref).abs().max().item()
print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y_tp.shape}")
print(f"Max difference from reference: {diff:.2e}")
print(f"Correct: {diff < 1e-5}")
def analyze_communication(rank: int, world_size: int,
hidden_size: int, batch_size: int, seq_len: int) -> None:
"""Analyze communication costs."""
if rank != 0:
return
print("\n" + "=" * 60)
print(" COMMUNICATION ANALYSIS")
print("=" * 60)
bytes_per_element = 4 # float32
elements_per_allreduce = batch_size * seq_len * hidden_size
bytes_per_allreduce = elements_per_allreduce * bytes_per_element
# Ring all_reduce volume
ring_volume = 2 * bytes_per_allreduce * (world_size - 1) / world_size
print(f"""
Configuration:
Hidden size: {hidden_size}
Batch size: {batch_size}
Sequence length: {seq_len}
TP degree: {world_size}
Per forward pass:
All-reduce calls: 1
Elements per all-reduce: {elements_per_allreduce:,}
Bytes per all-reduce: {bytes_per_allreduce / 1024:.1f} KB
Communication volume (ring algorithm):
Per GPU: {ring_volume / 1024:.1f} KB
Total across all GPUs: {ring_volume * world_size / 1024:.1f} KB
Comparison with non-TP:
Non-TP: 0 bytes (no communication)
TP: {ring_volume / 1024:.1f} KB per forward
This is the price of tensor parallelism!
But we can now handle models {world_size}x larger.
""")
def compare_scaling(rank: int, world_size: int) -> None:
"""Compare TP vs non-parallel scaling."""
if rank != 0:
return
print("\n" + "=" * 60)
print(" SCALING ANALYSIS")
print("=" * 60)
print("""
Memory scaling with Tensor Parallelism:
For an MLP with hidden_size H and intermediate_size 4H:
Non-parallel:
W1: H × 4H = 4H² parameters
W2: 4H × H = 4H² parameters
Total: 8H² parameters per GPU
With TP degree T:
W1: H × (4H/T) = 4H²/T parameters
W2: (4H/T) × H = 4H²/T parameters
Total: 8H²/T parameters per GPU
Example: H=4096, T=8 (8-way TP)
Non-parallel: 134M parameters (537 MB in FP32)
With 8-way TP: 16.7M parameters (67 MB per GPU)
This is how we fit 70B+ parameter models on GPUs!
""")
def worker(rank: int, world_size: int, hidden_size: int,
batch_size: int, seq_len: int) -> None:
"""Main worker function."""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29509"
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
# Verify correctness
verify_correctness(rank, world_size, hidden_size)
dist.barrier()
# Analyze communication
analyze_communication(rank, world_size, hidden_size, batch_size, seq_len)
dist.barrier()
# Benchmark
if rank == 0:
print("\n" + "=" * 60)
print(" BENCHMARK")
print("=" * 60)
avg_time, output = benchmark_tp_mlp(
rank, world_size, hidden_size, batch_size, seq_len
)
dist.barrier()
if rank == 0:
print(f"\nTP MLP forward pass: {avg_time * 1000:.3f} ms")
print(f"Output shape: {output.shape}")
# Compare scaling
compare_scaling(rank, world_size)
dist.destroy_process_group()
def main():
parser = argparse.ArgumentParser(description="Tensor-Parallel MLP Block")
parser.add_argument("--tp-size", "-t", type=int, default=4,
help="Tensor parallelism degree")
parser.add_argument("--hidden-size", "-H", type=int, default=64,
help="Hidden dimension")
parser.add_argument("--batch-size", "-b", type=int, default=4,
help="Batch size")
parser.add_argument("--seq-len", "-s", type=int, default=16,
help="Sequence length")
args = parser.parse_args()
print("╔" + "═" * 58 + "╗")
print("║" + " TENSOR-PARALLEL MLP BLOCK".center(58) + "║")
print("╚" + "═" * 58 + "╝")
print(f"\nTP degree: {args.tp_size}")
print(f"Hidden size: {args.hidden_size}")
print(f"Intermediate size: {args.hidden_size * 4}")
mp.spawn(
worker,
args=(args.tp_size, args.hidden_size, args.batch_size, args.seq_len),
nprocs=args.tp_size,
join=True
)
if __name__ == "__main__":
main()
Chapter 7: Pipeline and Expert Parallelism
“When one GPU can’t hold a single layer, we split the layer (TP). When it can’t hold all the layers, we split the model vertically into stages (PP).”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain how pipeline parallelism splits models across GPUs
- Implement 1F1B scheduling for efficient pipeline execution
- Understand Mixture-of-Experts (MoE) and Expert Parallelism
- Calculate the optimal parallelism strategy for a given model
Prerequisites
- Completed Chapters 5-6 (Data and Tensor Parallelism)
- Understanding of transformer architecture
- Basic knowledge of GPU memory hierarchy
Concept Overview
Pipeline Parallelism: Splitting by Layers
While tensor parallelism splits individual layers horizontally, pipeline parallelism splits the model vertically—each GPU holds a contiguous group of layers.
Full Model: [Embed] [Layer 0-5] [Layer 6-11] [Layer 12-17] [Layer 18-23] [Head]
↓ ↓ ↓ ↓ ↓ ↓
Pipeline: ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐
│GPU 0 │ │GPU 1 │ │GPU 2 │ │GPU 3 │ │GPU 4 │ │GPU 5 │
│Stage 0│ │Stage 1│ │Stage 2│ │Stage 3│ │Stage 4│ │Stage 5│
└──────┘ └──────┘ └──────┘ └──────┘ └──────┘ └──────┘
Communication: Activations flow forward, gradients flow backward (point-to-point send/recv).
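A minimal sketch of that point-to-point pattern: two stages, forward pass only, gloo backend for CPU portability. The tensor shapes are agreed on up front so recv knows what to allocate, and the port number is an arbitrary choice:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def stage(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29600"          # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    layer = torch.nn.Linear(16, 16)              # this stage's slice of the model
    if rank == 0:
        x = torch.randn(4, 16)                   # stage 0 owns the input batch
    else:
        x = torch.empty(4, 16)
        dist.recv(x, src=rank - 1)               # activations from the previous stage
    y = layer(x)
    if rank < world_size - 1:
        dist.send(y.detach(), dst=rank + 1)      # hand activations downstream
    else:
        print("final stage output:", y.shape)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(stage, args=(2,), nprocs=2, join=True)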
The Pipeline Bubble Problem
Naive pipeline execution has a fatal flaw: bubbles.
Time →
GPU 0: [F0] [F1] [F2] [F3] [B3] [B2] [B1] [B0]
GPU 1: [F0] [F1] [F2] [F3] [B3] [B2] [B1] [B0]
GPU 2: [F0] [F1] [F2] [F3] [B3] [B2] [B1] [B0]
GPU 3: [F0] [F1] [F2] [F3] [B3] [B2] [B1] [B0]
Bubbles = empty time where GPUs are idle
Bubble fraction = (P-1) / (M + P - 1), where P = pipeline stages, M = microbatches.
For P=4, M=4: Bubble = 3/7 = 43% wasted time!
1F1B Scheduling: The Solution
1F1B (One Forward, One Backward) interleaves forward and backward passes to reduce bubbles:
Time →
GPU 0: [F0] [F1] [F2] [F3] [B0] [F4] [B1] [F5] [B2] [B3]
GPU 1: [F0] [F1] [F2] [B0] [F3] [B1] [F4] [B2] [B3]
GPU 2: [F0] [F1] [B0] [F2] [B1] [F3] [B2] [B3]
GPU 3: [F0] [B0] [F1] [B1] [F2] [B2] [F3] [B3]
Key insight: Once the pipeline is “full,” each GPU does one forward then one backward, keeping memory constant.
Memory in Pipeline Parallelism
Each GPU stores:
- Model parameters for its stages
- Activations from forward pass (needed for backward)
1F1B memory advantage: Only need to store activations for P microbatches, not M.
Mixture of Experts (MoE)
MoE replaces the standard FFN with multiple “expert” FFNs:
Standard FFN:
Input → FFN → Output
MoE FFN:
                 ┌→ Expert 0 ─┐
Input → Router ──┼→ Expert 1 ─┼→ Weighted Sum → Output
                 ├→ Expert 2 ─┤
                 └→ Expert 3 ─┘
The router (a small neural network) decides which experts process each token. Typically, only top-K experts (K=1 or 2) are activated per token.
Why MoE?
- More parameters without more FLOPs
- Each token only activates a fraction of parameters
- DeepSeek-V3: 671B parameters but only 37B activated per token!
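The router described above is typically just a linear layer followed by a top-k selection. A minimal sketch in plain PyTorch (softmax over the selected logits is one common convention; some models normalize before selecting):
import torch
import torch.nn.functional as F

def topk_route(x, router_weight, k=2):
    """Pick the top-k experts per token and return normalized gate weights.

    x: [tokens, hidden]; router_weight: [hidden, num_experts].
    """
    logits = x @ router_weight                    # [tokens, num_experts]
    gates, expert_ids = logits.topk(k, dim=-1)    # best k experts per token
    gates = F.softmax(gates, dim=-1)              # renormalize over the chosen experts
    return expert_ids, gates

x = torch.randn(16, 64)                           # 16 tokens, hidden=64
W_router = torch.randn(64, 8)                     # 8 experts
expert_ids, gates = topk_route(x, W_router)
print(expert_ids.shape, gates.shape)              # torch.Size([16, 2]) twice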
Expert Parallelism (EP)
When you have 64+ experts, they don’t fit on one GPU. Expert Parallelism distributes experts across GPUs:
Token Routing
│
┌──────────┼──────────┐
▼ ▼ ▼
┌───────┐ ┌───────┐ ┌───────┐
│ GPU 0 │ │ GPU 1 │ │ GPU 2 │
│E0,E1,E2│ │E3,E4,E5│ │E6,E7,E8│
└───────┘ └───────┘ └───────┘
│ │ │
└──────────┴──────────┘
All-to-All
(collect results)
Communication pattern: All-to-All (each GPU sends tokens to the GPUs hosting the selected experts).
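The dispatch step is usually built on torch.distributed.all_to_all_single. A sketch under simplifying assumptions: tokens are already sorted by destination rank, one EP rank per process, and a backend that supports all_to_all (e.g. NCCL with CUDA tensors). The combine step after the experts run is the reverse exchange:
import torch
import torch.distributed as dist

def dispatch_tokens(hidden_states, dest_counts):
    """Send each token's hidden state to the rank that hosts its expert.

    hidden_states: [num_tokens, hidden], pre-sorted by destination rank.
    dest_counts:   list[int], how many of my tokens go to each rank.
    """
    world_size = dist.get_world_size()
    device = hidden_states.device
    # First exchange the counts so every rank knows how much it will receive
    send_counts = torch.tensor(dest_counts, dtype=torch.long, device=device)
    recv_counts = torch.empty(world_size, dtype=torch.long, device=device)
    dist.all_to_all_single(recv_counts, send_counts)
    # Then exchange the token activations themselves, with uneven splits
    recv_buf = hidden_states.new_empty(int(recv_counts.sum()), hidden_states.shape[1])
    dist.all_to_all_single(
        recv_buf, hidden_states,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=list(dest_counts),
    )
    return recv_buf   # the tokens this rank's experts must process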
EP vs TP: A Critical Comparison
For MoE models, EP is often better than TP:
| Aspect | Tensor Parallelism | Expert Parallelism |
|---|---|---|
| What’s split | Each expert matrix | Whole experts |
| Communication | 2 all-reduce per layer | 2 all-to-all per layer |
| Volume | 2 × batch × seq × hidden | 2 × k × batch × seq × hidden / N |
| Compute efficiency | Low (small GEMMs) | High (full expert GEMMs) |
Key insight: TP slices already small expert matrices, making GEMMs inefficient. EP keeps expert matrices whole.
Communication Volume Deep Dive
For TP with degree T on an MoE layer:
Volume = 2S (all-reduce, activations of size S)
For EP with N experts, k activated:
Volume = 2kS/N (all-to-all, only k/N of tokens go to each GPU)
When k << N (sparse activation), EP wins on communication too!
Combining Parallelisms: The 3D Approach
Real large-model training uses multiple parallelism strategies:
┌─────────────────────────────────────────────────────────────┐
│ DATA PARALLELISM │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ PIPELINE PARALLELISM │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ │
│ │ │TENSOR PARA- │ │TENSOR PARA- │ │TENSOR PARA- │ │ │
│ │ │LLELISM │ │LLELISM │ │LLELISM │ │ │
│ │ │ (Stage 0) │→│ (Stage 1) │→│ (Stage 2) │ │ │
│ │ │ 8 GPUs │ │ 8 GPUs │ │ 8 GPUs │ │ │
│ │ └─────────────┘ └─────────────┘ └─────────────┘ │ │
│ └──────────────────────────────────────────────────────┘ │
│ (Replica 0) │
│ ┌──────────────────────────────────────────────────────┐ │
│ │ ... more replicas ... │ │
│ └──────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
Rule of thumb:
- TP: Within a node (needs NVLink)
- PP: Across nodes (lower bandwidth OK)
- DP: Scales indefinitely (gradient sync)
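A sketch of how TP and DP process groups are commonly carved out of one flat world (pipeline groups follow the same idea with another stride). It assumes dist.init_process_group has already run and that world_size is divisible by tp_size:
import torch.distributed as dist

def build_2d_groups(tp_size):
    """Split the world into TP groups (adjacent ranks) and DP groups (strided ranks)."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_size = world_size // tp_size
    tp_group = dp_group = None
    # TP groups: ranks [0..tp-1], [tp..2tp-1], ... (kept within a node in practice)
    for i in range(dp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        group = dist.new_group(ranks)        # every process must make every new_group call
        if rank in ranks:
            tp_group = group
    # DP groups: ranks that hold the same shard, i.e. the same position in each TP group
    for j in range(tp_size):
        ranks = list(range(j, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group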
Code Walkthrough
Script 1: pipeline_schedule_viz.py
Visualizes different pipeline schedules:
- Naive (fill-drain)
- 1F1B
- Interleaved 1F1B
Shows bubble ratios and memory usage.
Script 2: parallel_strategy_calculator.py
Given model specs and hardware, calculates:
- Memory per GPU for each parallelism strategy
- Communication volume
- Recommended configuration
Try It Yourself
Exercise 1: Calculate Bubble Fraction
For a 4-stage pipeline with 16 microbatches:
- What’s the bubble fraction with naive scheduling?
- How does it improve with 1F1B?
Exercise 2: MoE Communication Analysis
For an MoE layer with:
- 64 experts
- top-2 routing (k=2)
- 4096 hidden dimension
- Batch of 4 × 1024 tokens
Calculate communication volume for:
- 8-way TP (splitting each expert)
- 8-way EP (8 experts per GPU)
Exercise 3: Design a Parallelism Strategy
You have:
- 70B parameter dense model
- 64 H100 GPUs (8 per node)
- 80GB memory per GPU
Design a parallelism strategy. Consider:
- Model size: ~140GB in FP16
- Activations and gradients
- Communication patterns
Key Takeaways
- PP splits the model by layers - Point-to-point communication only
- Bubbles are the enemy - 1F1B scheduling minimizes idle time
- MoE = sparse activation - More parameters, same compute
- EP beats TP for MoE - Keeps expert matrices whole
- Combine strategies - Real systems use TP + PP + DP + EP
The Parallelism Decision Tree
Is one layer too big for one GPU?
├─ Yes → Use Tensor Parallelism (within node)
└─ No
└─ Is the full model too big for one GPU?
├─ Yes → Use Pipeline Parallelism (across nodes OK)
│ + Use Tensor Parallelism if layers are large
└─ No
└─ Use Data Parallelism (scales indefinitely)
Is the model MoE?
├─ Yes → Add Expert Parallelism (across nodes OK)
└─ No → Continue with above strategy
What’s Next?
In Part III, we’ll dive into LLM Inference Systems—how to efficiently serve models after training. This includes KV cache management, batching strategies, and speculative decoding.
Further Reading
pipeline_schedule_viz.py
Visualize pipeline scheduling strategies and understand bubbles
This script creates ASCII visualizations of different pipeline scheduling algorithms, showing how they affect GPU utilization and memory usage.
What It Does
- Visualizes Naive (Fill-Drain) scheduling - shows massive bubbles
- Visualizes 1F1B scheduling - shows reduced bubbles
- Calculates bubble fraction for each approach
- Compares memory requirements
Run It
python tutorial/part2-parallelism/chapter07-pipeline-expert/scripts/pipeline_schedule_viz.py
Example Output
=== Naive Fill-Drain Schedule (P=4, M=8) ===
Time →
GPU 0: [F0][F1][F2][F3][F4][F5][F6][F7][ ][ ][ ][B7][B6][B5][B4][B3][B2][B1][B0]
GPU 1: [F0][F1][F2][F3][F4][F5][F6][F7][ ][ ][B7][B6][B5][B4][B3][B2][B1][B0]
GPU 2: [F0][F1][F2][F3][F4][F5][F6][F7][ ][B7][B6][B5][B4][B3][B2][B1][B0]
GPU 3: [F0][F1][F2][F3][F4][F5][F6][F7][B7][B6][B5][B4][B3][B2][B1][B0]
Bubble fraction: 27% (3 slots idle per GPU out of 11)
Peak memory: 8 microbatches of activations
=== 1F1B Schedule (P=4, M=8) ===
Time →
GPU 0: [F0][F1][F2][F3][B0][F4][B1][F5][B2][F6][B3][F7][B4][B5][B6][B7]
GPU 1: [F0][F1][F2][B0][F3][B1][F4][B2][F5][B3][F6][B4][F7][B5][B6][B7]
GPU 2: [F0][F1][B0][F2][B1][F3][B2][F4][B3][F5][B4][F6][B5][F7][B6][B7]
GPU 3: [F0][B0][F1][B1][F2][B2][F3][B3][F4][B4][F5][B5][F6][B6][F7][B7]
Bubble fraction: 27% (same ratio as naive; the win is memory)
Peak memory: 4 microbatches of activations (= P, not M!)
Key Insight
1F1B achieves:
- Lower bubble fraction by interleaving forward/backward
- Constant memory by releasing activations as soon as backward is done
Source Code
#!/usr/bin/env python3
"""
Pipeline Schedule Visualizer
This script visualizes different pipeline parallelism scheduling strategies:
- Naive (fill-drain): Simple but high bubble overhead
- 1F1B: Interleaved forward/backward for lower memory
- Shows bubble ratios and GPU utilization
Usage:
python pipeline_schedule_viz.py
python pipeline_schedule_viz.py --stages 4 --microbatches 8
"""
import argparse
from typing import List, Tuple
def visualize_naive_schedule(stages: int, microbatches: int) -> Tuple[str, float]:
"""
Visualize naive fill-drain pipeline schedule.
In naive scheduling:
1. Forward all microbatches through the pipeline
2. Then backward all microbatches
This leads to large bubbles at start and end.
"""
# Calculate timeline
# Forward: stage i starts at time i
# Backward: starts after all forwards complete
total_forward_time = microbatches + stages - 1
total_backward_time = microbatches + stages - 1
total_time = total_forward_time + total_backward_time
# Build visualization
lines = []
for stage in range(stages):
line = ["."] * total_time
# Forward passes
for mb in range(microbatches):
t = stage + mb
line[t] = f"F{mb}"
# Backward passes (start after all forwards)
for mb in range(microbatches):
t = total_forward_time + (stages - 1 - stage) + mb
line[t] = f"B{mb}"
lines.append(line)
    # Calculate bubble ratio: idle slots / total slots per stage
    work_per_stage = 2 * microbatches  # F + B for each microbatch
    bubble_per_stage = total_time - work_per_stage
    bubble_ratio = bubble_per_stage / total_time  # = (P-1) / (M + P - 1)
# Format output
output = []
output.append(f"Time → " + "".join([f"{i:>3}" for i in range(total_time)]))
output.append("-" * (8 + total_time * 3))
for stage, line in enumerate(lines):
formatted = "".join([f"{x:>3}" for x in line])
output.append(f"GPU {stage}: {formatted}")
return "\n".join(output), bubble_ratio
def visualize_1f1b_schedule(stages: int, microbatches: int) -> Tuple[str, float, int]:
"""
Visualize 1F1B (One Forward, One Backward) pipeline schedule.
Key insight: After warmup, each stage does 1F then 1B,
keeping activation memory bounded.
Memory = max(warmup_microbatches) = stages
"""
# Timeline representation
total_time = 2 * microbatches + 2 * (stages - 1)
lines = []
for stage in range(stages):
line = ["."] * total_time
warmup_steps = stages - stage # First stages need more warmup
# Warmup phase: only forwards
for mb in range(warmup_steps):
t = stage + mb
if t < total_time:
line[t] = f"F{mb}"
# Steady state: 1F1B
for mb in range(warmup_steps, microbatches):
# Forward at position
f_time = stage + mb
if f_time < total_time:
line[f_time] = f"F{mb}"
# Backward at position (for earlier microbatch)
b_mb = mb - warmup_steps
b_time = f_time + 1 if stage == stages - 1 else f_time + 2
if b_time < total_time:
line[b_time] = f"B{b_mb}"
# Cooldown: remaining backwards
cooldown_start = microbatches - warmup_steps
for i, mb in enumerate(range(cooldown_start, microbatches)):
b_time = stages + microbatches - 1 + stage + i
if b_time < total_time:
line[b_time] = f"B{mb}"
lines.append(line)
# Simplified bubble calculation
work_units = 2 * microbatches
bubble_ratio = (stages - 1) / (microbatches + stages - 1)
peak_memory = stages # Peak number of activations stored
# Format output
output = []
output.append(f"Time → " + "".join([f"{i:>3}" for i in range(min(total_time, 25))]))
output.append("-" * (8 + min(total_time, 25) * 3))
for stage, line in enumerate(lines):
formatted = "".join([f"{x:>3}" for x in line[:25]])
output.append(f"GPU {stage}: {formatted}" + ("..." if total_time > 25 else ""))
return "\n".join(output), bubble_ratio, peak_memory
def analyze_schedules(stages: int, microbatches: int) -> None:
"""Compare different scheduling strategies."""
print("=" * 70)
print(" PIPELINE SCHEDULE COMPARISON")
print("=" * 70)
print(f"\nConfiguration: {stages} stages, {microbatches} microbatches\n")
# Naive schedule
print("-" * 70)
print(" NAIVE (Fill-Drain) SCHEDULE")
print("-" * 70)
print("""
Strategy: Complete all forwards, then all backwards.
Memory: Must store activations for ALL microbatches.
""")
naive_viz, naive_bubble = visualize_naive_schedule(stages, microbatches)
print(naive_viz)
print(f"\nBubble ratio: {naive_bubble:.1%}")
print(f"Peak activation memory: {microbatches} microbatches worth")
print("\n")
# 1F1B schedule
print("-" * 70)
print(" 1F1B (One Forward, One Backward) SCHEDULE")
print("-" * 70)
print("""
Strategy: After warmup, alternate 1 forward then 1 backward.
Memory: Only store activations for 'stages' microbatches.
""")
fb_viz, fb_bubble, fb_memory = visualize_1f1b_schedule(stages, microbatches)
print(fb_viz)
print(f"\nBubble ratio: {fb_bubble:.1%}")
print(f"Peak activation memory: {fb_memory} microbatches worth")
# Comparison
print("\n" + "=" * 70)
print(" SUMMARY")
print("=" * 70)
print(f"""
{'Metric':<25} {'Naive':<20} {'1F1B':<20}
{'-'*65}
{'Bubble ratio':<25} {naive_bubble:<20.1%} {fb_bubble:<20.1%}
{'Peak memory':<25} {microbatches:<20} {fb_memory:<20}
Key insights:
1. 1F1B has the SAME bubble ratio but LOWER memory
2. More microbatches → lower bubble ratio (approaches 0 as M→∞)
3. Peak memory in 1F1B is bounded by pipeline depth
""")
def demonstrate_bubble_reduction() -> None:
"""Show how bubble ratio decreases with more microbatches."""
print("\n" + "=" * 70)
print(" BUBBLE RATIO vs MICROBATCHES")
print("=" * 70)
print("""
Bubble ratio = (P-1) / (M + P - 1)
Where P = pipeline stages, M = microbatches
""")
stages = 4
print(f"For P = {stages} stages:\n")
print(f"{'Microbatches':<15} {'Bubble Ratio':<15} {'Efficiency':<15}")
print("-" * 45)
for mb in [1, 2, 4, 8, 16, 32, 64]:
bubble = (stages - 1) / (mb + stages - 1)
efficiency = 1 - bubble
print(f"{mb:<15} {bubble:.1%:<15} {efficiency:.1%:<15}")
print("""
Takeaway: Use at least 4x as many microbatches as pipeline stages
for > 80% efficiency.
""")
def explain_memory_tradeoff() -> None:
"""Explain the memory-throughput tradeoff."""
print("\n" + "=" * 70)
print(" MEMORY vs THROUGHPUT TRADEOFF")
print("=" * 70)
print("""
The fundamental tradeoff in pipeline parallelism:
MORE MICROBATCHES:
✓ Lower bubble ratio (better throughput)
✗ More activation memory (naive) or same (1F1B)
✗ Smaller per-microbatch batch size (worse GPU utilization)
FEWER MICROBATCHES:
✗ Higher bubble ratio (worse throughput)
✓ Less activation memory
✓ Larger per-microbatch batch size (better GPU utilization)
1F1B ADVANTAGE:
With 1F1B, memory is bounded by pipeline depth, NOT microbatches.
This allows many microbatches for low bubbles without memory explosion.
Example calculation:
Model: 24 layers, 4096 hidden dim, batch 512
Pipeline: 4 stages (6 layers each)
Microbatches: 16 (32 samples each)
Naive memory: activations for all 16 in-flight microbatches per stage
1F1B memory: activations for only 4 microbatches per stage (= pipeline depth)
4x memory reduction, regardless of the per-microbatch activation size!
""")
def main():
parser = argparse.ArgumentParser(description="Pipeline Schedule Visualizer")
parser.add_argument("--stages", "-s", type=int, default=4,
help="Number of pipeline stages")
parser.add_argument("--microbatches", "-m", type=int, default=8,
help="Number of microbatches")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " PIPELINE PARALLELISM SCHEDULE VISUALIZER".center(68) + "║")
print("╚" + "═" * 68 + "╝")
analyze_schedules(args.stages, args.microbatches)
demonstrate_bubble_reduction()
explain_memory_tradeoff()
if __name__ == "__main__":
main()
parallel_strategy_calculator.py
Calculate optimal parallelism configuration for your model and hardware
This script helps you design a parallelism strategy by calculating memory requirements, communication volume, and efficiency for different configurations.
What It Does
- Takes model specifications (parameters, layers, hidden size)
- Takes hardware specifications (GPU memory, count, interconnect)
- Calculates memory per GPU for each parallelism strategy
- Recommends optimal configuration
Run It
python tutorial/part2-parallelism/chapter07-pipeline-expert/scripts/parallel_strategy_calculator.py
Example Usage
=== Parallelism Strategy Calculator ===
Model: 70B parameters, 80 layers, hidden=8192
Hardware: 64 GPUs (8 per node), 80GB each, NVLink intra-node
Configuration options:
| Strategy | TP | PP | DP | Memory/GPU | Comm Volume | Feasible? |
|---------------|----|----|----| -----------|-------------|-----------|
| Pure DP | 1 | 1 | 64 | 840 GB | 2D/step | No |
| TP=8, DP=8 | 8 | 1 | 8 | 105 GB | 8D/layer | No |
| TP=8, PP=2 | 8 | 2 | 4 | 52 GB | 8D/layer | Yes |
| TP=8, PP=4 | 8 | 4 | 2 | 26 GB | 8D/layer | Yes (rec) |
Recommended: TP=8 (within node), PP=4 (across nodes), DP=2
Reasoning:
- TP=8 uses NVLink bandwidth efficiently
- PP=4 distributes 80 layers across 4 stages (20 layers each)
- DP=2 provides batch parallelism for throughput
Input Parameters
The calculator considers:
- Model: Parameters, layers, hidden dimension, precision
- Hardware: GPU count, memory per GPU, interconnect bandwidth
- Training: Batch size, sequence length, microbatch count
What It Calculates
For each configuration:
- Memory per GPU: Parameters + gradients + optimizer + activations
- Communication volume: Per-step all_reduce/all_gather/send-recv
- Bubble fraction: For pipeline configurations
- Feasibility: Does it fit in GPU memory?
Source Code
#!/usr/bin/env python3
"""
Parallel Strategy Calculator
Given model specifications and hardware constraints, this script helps
you determine the optimal parallelism strategy.
It calculates:
- Memory requirements for different strategies
- Communication volumes
- Recommended configuration
Usage:
python parallel_strategy_calculator.py
python parallel_strategy_calculator.py --params 70 --gpus 64 --memory 80
"""
import argparse
from dataclasses import dataclass
from typing import Optional, Tuple
import math
@dataclass
class ModelConfig:
"""Model configuration."""
params_billions: float
hidden_size: int = 8192
num_layers: int = 80
num_heads: int = 64
vocab_size: int = 128000
intermediate_ratio: float = 4.0
is_moe: bool = False
num_experts: int = 1
top_k: int = 1 # Experts activated per token
@dataclass
class HardwareConfig:
"""Hardware configuration."""
num_gpus: int
memory_per_gpu_gb: float
intra_node_bandwidth_gbps: float = 900 # NVLink
inter_node_bandwidth_gbps: float = 50 # InfiniBand
gpus_per_node: int = 8
@dataclass
class TrainingConfig:
"""Training configuration."""
batch_size: int = 1024
sequence_length: int = 4096
dtype_bytes: int = 2 # FP16/BF16
def estimate_model_memory(config: ModelConfig, dtype_bytes: int = 2) -> dict:
"""Estimate memory requirements for model parameters."""
# Parameter count estimation
# Embedding: vocab_size * hidden_size
# Per layer: 4 * hidden^2 (QKV + O) + 2 * hidden * intermediate (FFN)
embedding_params = config.vocab_size * config.hidden_size
attention_params = 4 * config.hidden_size ** 2 # Q, K, V, O projections
ffn_params = 2 * config.hidden_size * int(config.hidden_size * config.intermediate_ratio)
if config.is_moe:
# MoE: multiply FFN by number of experts
ffn_params *= config.num_experts
layer_params = attention_params + ffn_params
total_params = embedding_params + (layer_params * config.num_layers)
# Convert to bytes
params_bytes = total_params * dtype_bytes
gradients_bytes = params_bytes # Same size as params
# Optimizer states (Adam: 2x params in FP32)
optimizer_bytes = total_params * 4 * 2
return {
'params': params_bytes / 1e9, # GB
'gradients': gradients_bytes / 1e9,
'optimizer': optimizer_bytes / 1e9,
'total': (params_bytes + gradients_bytes + optimizer_bytes) / 1e9,
'param_count': total_params,
}
def estimate_activation_memory(
config: ModelConfig,
training: TrainingConfig,
tp_degree: int = 1,
pp_degree: int = 1
) -> float:
"""Estimate activation memory per GPU in GB."""
    # Rough simplification: assume num_microbatches = batch_size // pp_degree,
    # so the resident microbatch has about pp_degree samples per stage
    batch_per_gpu = training.batch_size // max(1, training.batch_size // pp_degree)
seq_len = training.sequence_length
hidden = config.hidden_size
# Per-layer activations (simplified)
# Input to attention, attention output, FFN intermediate, etc.
activations_per_layer = 10 * batch_per_gpu * seq_len * hidden // tp_degree
layers_per_stage = config.num_layers // pp_degree
total_activation_bytes = activations_per_layer * layers_per_stage * training.dtype_bytes
return total_activation_bytes / 1e9
def calculate_communication_volume(
config: ModelConfig,
training: TrainingConfig,
tp_degree: int,
dp_degree: int,
pp_degree: int
) -> dict:
"""Calculate communication volume for different parallelism types."""
batch = training.batch_size
seq = training.sequence_length
hidden = config.hidden_size
dtype = training.dtype_bytes
# TP communication: all_reduce per layer
# 2 all_reduce per transformer layer (attention + FFN)
tp_volume_per_layer = 4 * batch * seq * hidden * dtype * (tp_degree - 1) / tp_degree
tp_volume_total = tp_volume_per_layer * config.num_layers / pp_degree
# PP communication: activations between stages
pp_volume = 2 * batch * seq * hidden * dtype # Forward and backward
# DP communication: gradient all_reduce
    params_per_gpu = config.params_billions * 1e9 / (pp_degree * tp_degree)
    dp_volume = 2 * params_per_gpu * dtype * (dp_degree - 1) / dp_degree
return {
'tp_per_step_gb': tp_volume_total / 1e9,
'pp_per_step_gb': pp_volume / 1e9,
'dp_per_step_gb': dp_volume / 1e9,
'total_gb': (tp_volume_total + pp_volume + dp_volume) / 1e9,
}
def find_optimal_strategy(
model: ModelConfig,
hardware: HardwareConfig,
training: TrainingConfig
) -> dict:
"""Find optimal parallelism strategy given constraints."""
mem = estimate_model_memory(model, training.dtype_bytes)
results = []
# Try different configurations
for tp in [1, 2, 4, 8]:
if tp > hardware.gpus_per_node:
continue
for pp in [1, 2, 4, 8, 16]:
if tp * pp > hardware.num_gpus:
continue
dp = hardware.num_gpus // (tp * pp)
if dp < 1:
continue
# Memory per GPU
params_per_gpu = mem['params'] / (tp * pp)
grads_per_gpu = mem['gradients'] / (tp * pp)
optimizer_per_gpu = mem['optimizer'] / (tp * pp) # With ZeRO-3
# ZeRO-3 shards optimizer across DP
optimizer_per_gpu = optimizer_per_gpu / dp
activation_mem = estimate_activation_memory(model, training, tp, pp)
total_mem = params_per_gpu + grads_per_gpu + optimizer_per_gpu + activation_mem
# Communication
comm = calculate_communication_volume(model, training, tp, dp, pp)
# Estimate if TP crosses nodes
tp_is_intra_node = tp <= hardware.gpus_per_node
results.append({
'tp': tp,
'pp': pp,
'dp': dp,
'memory_per_gpu': total_mem,
'fits': total_mem < hardware.memory_per_gpu_gb * 0.9, # 90% threshold
'communication': comm,
'tp_intra_node': tp_is_intra_node,
})
return results
def print_results(results: list, hardware: HardwareConfig) -> None:
"""Print strategy comparison."""
print("\n" + "=" * 80)
print(" STRATEGY COMPARISON")
print("=" * 80)
# Header
print(f"\n{'TP':>4} {'PP':>4} {'DP':>4} {'Mem/GPU':>10} {'Fits?':>8} "
f"{'TP Comm':>10} {'PP Comm':>10} {'DP Comm':>10}")
print("-" * 80)
valid_configs = []
for r in results:
fits = "✓" if r['fits'] else "✗"
tp_note = "" if r['tp_intra_node'] else "*"
print(f"{r['tp']:>4}{tp_note:<1} {r['pp']:>3} {r['dp']:>4} "
f"{r['memory_per_gpu']:>9.1f}GB {fits:>8} "
f"{r['communication']['tp_per_step_gb']:>9.2f}GB "
f"{r['communication']['pp_per_step_gb']:>9.2f}GB "
f"{r['communication']['dp_per_step_gb']:>9.2f}GB")
if r['fits']:
valid_configs.append(r)
print("\n* TP crosses node boundary (slower inter-node communication)")
# Recommendation
print("\n" + "=" * 80)
print(" RECOMMENDATION")
print("=" * 80)
if not valid_configs:
print("\n⚠ No configuration fits in memory!")
print(" Consider: More GPUs, larger GPU memory, or smaller batch size")
return
# Sort by communication volume (prefer lower communication)
valid_configs.sort(key=lambda x: (
0 if x['tp_intra_node'] else 1, # Prefer intra-node TP
x['communication']['total_gb'],
))
best = valid_configs[0]
print(f"""
Recommended configuration:
Tensor Parallelism (TP): {best['tp']}
Pipeline Parallelism (PP): {best['pp']}
Data Parallelism (DP): {best['dp']}
Memory per GPU: {best['memory_per_gpu']:.1f} GB (limit: {hardware.memory_per_gpu_gb} GB)
Total communication: {best['communication']['total_gb']:.2f} GB per step
Reasoning:
- TP={best['tp']} {"stays within a node (NVLink speed)" if best['tp_intra_node'] else "crosses nodes (slower)"}
- PP={best['pp']} splits model into {best['pp']} stages
- DP={best['dp']} {"provides excellent scaling" if best['dp'] > 1 else "single replica"}
""")
def main():
parser = argparse.ArgumentParser(description="Parallel Strategy Calculator")
parser.add_argument("--params", "-p", type=float, default=70,
help="Model parameters in billions (default: 70)")
parser.add_argument("--gpus", "-g", type=int, default=64,
help="Total number of GPUs (default: 64)")
parser.add_argument("--memory", "-m", type=float, default=80,
help="GPU memory in GB (default: 80)")
parser.add_argument("--batch-size", "-b", type=int, default=512,
help="Global batch size (default: 512)")
parser.add_argument("--seq-len", "-s", type=int, default=4096,
help="Sequence length (default: 4096)")
parser.add_argument("--moe", action="store_true",
help="Model is Mixture of Experts")
parser.add_argument("--num-experts", type=int, default=64,
help="Number of experts for MoE (default: 64)")
args = parser.parse_args()
print("╔" + "═" * 78 + "╗")
print("║" + " PARALLEL STRATEGY CALCULATOR".center(78) + "║")
print("╚" + "═" * 78 + "╝")
# Configure model
model = ModelConfig(
params_billions=args.params,
is_moe=args.moe,
num_experts=args.num_experts if args.moe else 1,
)
hardware = HardwareConfig(
num_gpus=args.gpus,
memory_per_gpu_gb=args.memory,
)
training = TrainingConfig(
batch_size=args.batch_size,
sequence_length=args.seq_len,
)
# Print configuration
print(f"\n{'─'*40}")
print(" MODEL CONFIGURATION")
print(f"{'─'*40}")
mem = estimate_model_memory(model)
print(f"Parameters: {args.params}B ({mem['param_count']/1e9:.1f}B actual)")
print(f"Parameters memory: {mem['params']:.1f} GB")
print(f"Gradients memory: {mem['gradients']:.1f} GB")
print(f"Optimizer memory: {mem['optimizer']:.1f} GB")
print(f"Total model memory: {mem['total']:.1f} GB")
if args.moe:
print(f"MoE: {args.num_experts} experts")
print(f"\n{'─'*40}")
print(" HARDWARE CONFIGURATION")
print(f"{'─'*40}")
print(f"GPUs: {args.gpus} ({args.gpus // 8} nodes)")
print(f"Memory per GPU: {args.memory} GB")
print(f"\n{'─'*40}")
print(" TRAINING CONFIGURATION")
print(f"{'─'*40}")
print(f"Batch size: {args.batch_size}")
print(f"Sequence length: {args.seq_len}")
# Calculate strategies
results = find_optimal_strategy(model, hardware, training)
print_results(results, hardware)
# Additional MoE considerations
if args.moe:
print("\n" + "=" * 80)
print(" MOE-SPECIFIC CONSIDERATIONS")
print("=" * 80)
print(f"""
For MoE models, also consider Expert Parallelism (EP):
With {args.num_experts} experts and 8-way EP:
- {args.num_experts // 8} experts per GPU
- Communication: 2 all-to-all per layer
EP vs TP for MoE:
- EP: Keeps full expert matrices → better GEMM efficiency
- TP: Slices experts → smaller GEMMs, worse efficiency
- EP preferred when num_experts >> TP degree
Recommendation: Use EP instead of/in addition to TP for the FFN experts.
""")
if __name__ == "__main__":
main()
Chapter 8: Anatomy of an LLM Inference Server
“Training is a sprint. Inference is a marathon that never ends.”
Learning Objectives
By the end of this chapter, you will be able to:
- Trace the lifecycle of a request through an inference server
- Explain the roles of Tokenizer, Scheduler, and Model Runner
- Understand why inference is fundamentally different from training
- Identify bottlenecks in inference serving
Prerequisites
- Completed Part II (Parallelism Strategies)
- Basic understanding of transformer architecture
- Familiarity with REST APIs
Concept Overview
Training vs Inference: A Tale of Two Challenges
| Aspect | Training | Inference |
|---|---|---|
| Goal | Update model weights | Generate tokens |
| Batch size | Fixed (large) | Dynamic (varies) |
| Latency | Irrelevant | Critical |
| Throughput | Samples/second | Tokens/second |
| Memory | Dominated by gradients | Dominated by KV cache |
| Workload | Homogeneous | Heterogeneous |
Training processes fixed batches for hours. Inference serves arbitrary requests in milliseconds.
The Inference Pipeline
When you send a prompt to an LLM, here’s what happens:
┌──────────────────────────────────────────────────────────────────────────┐
│ LLM INFERENCE SERVER │
│ │
│ HTTP Request ─────────────────────────────────────────► HTTP Response │
│ │ ▲ │
│ ▼ │ │
│ ┌─────────────┐ ┌───────────────┐ ┌─────────────────┐ │ │
│ │ API Adapter │───►│TokenizerMgr │───►│ Scheduler │ │ │
│ │ │ │(tokenize) │ │(batch requests) │ │ │
│ └─────────────┘ └───────────────┘ └───────┬─────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────────┐ │ │
│ │ Model Runner │ │ │
│ │ (GPU compute) │ │ │
│ └───────┬─────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────────┐ │ │
│ │DetokenizerMgr │──┘ │
│ │(tokens→text) │ │
│ └─────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────┘
Component Deep Dive
1. API Adapter
Translates HTTP requests into internal format:
- Parses JSON body
- Validates parameters (temperature, max_tokens, etc.)
- Creates a GenerateRequest object
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatRequest):
# Validate and convert to internal format
generate_request = convert_to_internal(request)
# Send to tokenizer manager
return await tokenizer_manager.generate(generate_request)
2. Tokenizer Manager
Handles text ↔ token conversion:
- Tokenizes input prompt
- Manages vocabulary and special tokens
- Queues tokenized requests for scheduler
3. Scheduler
The brain of the inference server:
- Manages request queue
- Decides which requests to batch together
- Allocates KV cache memory
- Chooses between prefill and decode
The scheduler is so important it gets its own chapters (9-10)!
4. Model Runner
Executes the actual neural network:
- Loads model weights
- Runs forward pass
- Samples next token
5. Detokenizer Manager
Converts tokens back to text:
- Decodes token IDs to strings
- Handles streaming output
- Manages stop sequences
The Two Phases of Inference
LLM inference has two distinct phases:
Phase 1: Prefill (Prompt Processing)
Input: "What is the capital of France?"
[token_0, token_1, token_2, ..., token_n]
Output: KV cache for all tokens + first generated token
Compute: Parallelizable (all tokens at once)
Memory: Write n entries to KV cache
Phase 2: Decode (Token Generation)
Input: Previously generated token + KV cache
[token_i]
Output: Next token
Compute: Sequential (one token at a time)
Memory: Read from KV cache, write 1 entry
Time →
Prefill: [===================] (process all prompt tokens)
↓
Decode: [=] [=] [=] [=] [=] [=] [=] [=] ...
t₁ t₂ t₃ t₄ t₅ t₆ t₇ t₈
Key insight: Prefill is compute-bound, decode is memory-bound.
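To see why, compare arithmetic intensity (FLOPs performed per byte of weights read) in the two phases. The sketch below uses rough, illustrative numbers (a 70B FP16 model, a 2048-token prompt, approximate H100 specs), not measurements:

```python
# Back-of-the-envelope arithmetic intensity for a 70B FP16 model.
# Illustrative estimate only; it ignores attention FLOPs and KV cache reads.
params = 70e9
flops_per_token = 2 * params            # ~2 FLOPs per parameter per generated token
weight_bytes = 2 * params               # FP16 weights must be read from HBM each step

# Prefill: a 2048-token prompt reuses each weight read across 2048 tokens.
prefill_intensity = flops_per_token * 2048 / weight_bytes    # ≈ 2048 FLOPs/byte

# Decode: one new token per step, so the full weights are re-read per token.
decode_intensity = flops_per_token / weight_bytes            # ≈ 1 FLOP/byte

# An H100 offers roughly 1000 TFLOPS (FP16) against ~3.35 TB/s of HBM bandwidth,
# so it needs on the order of ~300 FLOPs per byte to stay compute-bound.
print(f"Prefill: ~{prefill_intensity:.0f} FLOPs/byte  -> compute-bound")
print(f"Decode:  ~{decode_intensity:.0f} FLOPs/byte   -> memory-bound")
```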
Why Batching is Complicated
Training batches are simple: sequences share the same (padded) length and are processed together.
Inference batches are hard:
- Requests arrive at different times
- Different prompt lengths
- Different desired output lengths
- Some requests finish mid-batch
Continuous batching solves this:
Time →
Request A: [====prefill====][d][d][d][d][d][done]
Request B: [prefill][d][d][d][d][d][d][d][d]...
Request C: [====prefill====][d][d]...
Batched execution:
[A+B prefill] [A+B decode] [A+B+C] [B+C decode] ...
Memory: The Inference Bottleneck
For a 70B parameter model serving requests:
| Component | Memory |
|---|---|
| Model weights (FP16) | 140 GB |
| KV cache (per request) | ~10 GB for 32K context (see Chapter 9) |
| Activations | ~1 GB |
With 140 GB of weights and 80 GB GPU memory… we need tensor parallelism just to load the model!
And each request needs its own KV cache. Serving 100 concurrent requests at 32K context would need roughly 1 TB just for KV cache!
This is why KV cache management (Chapter 9) is critical.
Code Walkthrough
Script: minimal_inference_server.py
A simplified inference server showing the core components:
- Request queue management
- Simple batching
- Token-by-token generation
This isn’t production-ready but demonstrates the architecture.
Key Metrics
When evaluating inference servers:
| Metric | Definition | Target |
|---|---|---|
| TTFT | Time To First Token | < 500ms |
| ITL | Inter-Token Latency | < 50ms |
| Throughput | Tokens/second | Maximize |
| Concurrency | Simultaneous requests | Maximize |
Trade-offs:
- Higher concurrency → larger batches → higher GPU utilization and throughput
- Higher concurrency → more KV cache → potential OOM
- Larger batches → higher latency per request
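To make TTFT and ITL concrete, here is a small sketch of how they can be derived from per-token timestamps. The function and field names are invented for illustration; they are not taken from minimal_inference_server.py or any production server:

```python
import statistics
from typing import Dict, List

def latency_metrics(request_start: float, token_times: List[float]) -> Dict[str, float]:
    """Derive TTFT, mean ITL, and per-request throughput from token timestamps (seconds)."""
    ttft = token_times[0] - request_start                          # Time To First Token
    itls = [b - a for a, b in zip(token_times, token_times[1:])]   # gaps between tokens
    return {
        "ttft_ms": ttft * 1000,
        "mean_itl_ms": statistics.mean(itls) * 1000 if itls else 0.0,
        "tokens_per_s": len(token_times) / (token_times[-1] - request_start),
    }

# Example: first token after 400 ms, then one token every 40 ms.
print(latency_metrics(0.0, [0.4 + 0.04 * i for i in range(50)]))
```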
Try It Yourself
Exercise 1: Trace a Request
Using the minimal_inference_server.py:
- Add logging to each component
- Trace a single request through the system
- Measure time spent in each stage
Exercise 2: Measure Prefill vs Decode
Profile inference to measure:
- Time for prefill (prompt processing)
- Time per decode step
- How does prompt length affect prefill time?
- How does batch size affect decode time?
Exercise 3: Calculate KV Cache Size
For a model with:
- 32 layers
- 8192 hidden dimension
- 128 heads
- 32K max sequence length
Calculate:
- KV cache size per token
- KV cache size for one 32K request
- Max concurrent requests with 80 GB memory (after model weights)
Key Takeaways
- Inference is a pipeline - Multiple stages, each can be a bottleneck
- Prefill vs Decode - Different compute characteristics, different optimizations
- Memory dominates - KV cache limits concurrency
- Batching is complex - Continuous batching enables high throughput
- Latency matters - Unlike training, users are waiting
The Inference Optimization Hierarchy
Level 0: Model works (correctness)
↓
Level 1: Model fits in memory (quantization, TP)
↓
Level 2: Efficient memory management (KV cache, paging)
↓
Level 3: Efficient batching (continuous batching)
↓
Level 4: Kernel optimizations (FlashAttention, CUDA graphs)
↓
Level 5: Speculative decoding (draft models)
What’s Next?
In Chapter 9, we’ll dive deep into KV Cache Management—how systems like PagedAttention and RadixCache enable serving many concurrent requests efficiently.
Further Reading
minimal_inference_server.py
A simplified LLM inference server demonstrating core architecture
This script implements a minimal inference server showing the key components: request handling, batching, and token generation.
What It Does
- Creates a simple request queue
- Implements basic batching logic
- Simulates the prefill/decode loop
- Demonstrates streaming output
Architecture
┌─────────────────────────────────────────────────┐
│ Minimal Inference Server │
│ │
│ Request Queue ──► Batcher ──► Model ──► Output │
│ │
│ Components: │
│ - RequestQueue: FIFO queue for incoming prompts│
│ - SimpleBatcher: Groups requests for GPU │
│ - MockModel: Simulates forward pass │
│ - Generator: Token-by-token output loop │
└─────────────────────────────────────────────────┘
Run It
python tutorial/part3-inference/chapter08-server-anatomy/scripts/minimal_inference_server.py
Key Learning Points
Request Lifecycle:
# 1. Request arrives
request = Request(prompt="Hello, world!")
# 2. Tokenize
tokens = tokenizer.encode(request.prompt)
# 3. Add to queue
queue.add(request)
# 4. Batch processing
batch = batcher.get_next_batch()
# 5. Prefill (process prompt)
kv_cache = model.prefill(batch)
# 6. Decode (generate tokens)
while not done:
next_token = model.decode(kv_cache)
yield next_token
What This Demonstrates
- Separation of concerns: Each component has a single responsibility
- Queue management: Requests are processed fairly
- Batching strategy: Multiple requests share GPU
- Two-phase inference: Prefill then decode
What’s Missing (Real Systems)
- KV cache management (Chapter 9)
- CUDA graph optimization (Chapter 10)
- Speculative decoding (Chapter 11)
- Tensor parallelism for large models
- Production error handling
Source Code
#!/usr/bin/env python3
"""
Minimal LLM Inference Server
This script demonstrates the core components of an inference server:
- Request management
- Simple batching
- Token generation loop
This is educational, not production-ready. Real servers like vLLM and
SGLang have much more sophisticated implementations.
Usage:
python minimal_inference_server.py
python minimal_inference_server.py --num-requests 10
"""
import argparse
import asyncio
import time
from dataclasses import dataclass, field
from typing import List, Optional, AsyncIterator
from collections import deque
import random
@dataclass
class GenerateRequest:
"""A request to generate text."""
id: int
prompt: str
prompt_tokens: List[int]
max_tokens: int = 50
temperature: float = 1.0
created_at: float = field(default_factory=time.time)
# Tracking
generated_tokens: List[int] = field(default_factory=list)
is_finished: bool = False
prefill_done: bool = False
@dataclass
class Batch:
"""A batch of requests to process together."""
requests: List[GenerateRequest]
is_prefill: bool # True for prefill, False for decode
class SimpleTokenizer:
"""
A simplified tokenizer for demonstration.
Real tokenizers (like SentencePiece or tiktoken) are more complex.
"""
def __init__(self, vocab_size: int = 1000):
self.vocab_size = vocab_size
# Simple word-based tokenization
self.token_to_id = {"<pad>": 0, "<eos>": 1, "<unk>": 2}
self.id_to_token = {0: "<pad>", 1: "<eos>", 2: "<unk>"}
def encode(self, text: str) -> List[int]:
"""Convert text to token IDs."""
# Simplified: assign random IDs to words
words = text.lower().split()
tokens = []
for word in words:
# Hash word to get consistent token ID
token_id = hash(word) % (self.vocab_size - 3) + 3
tokens.append(token_id)
return tokens
def decode(self, token_ids: List[int]) -> str:
"""Convert token IDs back to text."""
# Simplified: just return placeholder
return f"[Generated {len(token_ids)} tokens]"
class SimpleModelRunner:
"""
A simplified model runner for demonstration.
Real model runners load actual neural networks and run GPU inference.
"""
def __init__(self, vocab_size: int = 1000, latency_ms: float = 10):
self.vocab_size = vocab_size
self.latency_ms = latency_ms
async def prefill(self, request: GenerateRequest) -> int:
"""
Process prompt and return first generated token.
Real prefill:
1. Run all prompt tokens through model in parallel
2. Build KV cache for all tokens
3. Sample first output token
"""
# Simulate compute time (proportional to prompt length)
prompt_len = len(request.prompt_tokens)
await asyncio.sleep(self.latency_ms * prompt_len / 100)
# "Generate" first token
first_token = random.randint(3, self.vocab_size - 1)
return first_token
async def decode(self, batch: List[GenerateRequest]) -> List[int]:
"""
Generate next token for each request in batch.
Real decode:
1. Run single token through model for each request
2. Update KV cache with new KV pairs
3. Sample next token for each request
"""
# Simulate compute time (roughly constant per batch)
await asyncio.sleep(self.latency_ms)
# "Generate" next tokens
next_tokens = []
for req in batch:
# 10% chance of generating EOS
if random.random() < 0.1:
next_tokens.append(1) # EOS
else:
next_tokens.append(random.randint(3, self.vocab_size - 1))
return next_tokens
class Scheduler:
"""
Manages request queue and batching decisions.
Key responsibilities:
1. Accept new requests
2. Decide which requests to process together
3. Manage prefill vs decode scheduling
"""
def __init__(self, max_batch_size: int = 4):
self.max_batch_size = max_batch_size
self.waiting_queue: deque = deque() # Requests waiting for prefill
self.running_batch: List[GenerateRequest] = [] # Requests in decode phase
self.completed: List[GenerateRequest] = []
def add_request(self, request: GenerateRequest):
"""Add a new request to the waiting queue."""
self.waiting_queue.append(request)
print(f"[Scheduler] Added request {request.id} to queue "
f"(queue size: {len(self.waiting_queue)})")
def get_next_batch(self) -> Optional[Batch]:
"""
Decide what to process next.
Strategy (simplified):
1. If we have requests waiting AND room in running batch, do prefill
2. If running batch has requests, do decode
"""
# Check for finished requests first
self.running_batch = [r for r in self.running_batch if not r.is_finished]
        # Prefill new requests if we have capacity
        # (simplified: prefill one request at a time, so this is an `if`, not a loop)
        if (self.waiting_queue and
                len(self.running_batch) < self.max_batch_size):
            request = self.waiting_queue.popleft()
            return Batch(requests=[request], is_prefill=True)
# Decode existing requests
if self.running_batch:
return Batch(requests=self.running_batch, is_prefill=False)
return None
def process_prefill_result(self, request: GenerateRequest, token: int):
"""Handle result from prefill."""
request.prefill_done = True
request.generated_tokens.append(token)
self.running_batch.append(request)
print(f"[Scheduler] Request {request.id} finished prefill, "
f"added to running batch (size: {len(self.running_batch)})")
def process_decode_result(self, request: GenerateRequest, token: int):
"""Handle result from decode."""
request.generated_tokens.append(token)
# Check if finished
if token == 1 or len(request.generated_tokens) >= request.max_tokens:
request.is_finished = True
self.completed.append(request)
print(f"[Scheduler] Request {request.id} finished "
f"({len(request.generated_tokens)} tokens)")
def has_work(self) -> bool:
"""Check if there's more work to do."""
return bool(self.waiting_queue or self.running_batch)
class InferenceServer:
"""
Main inference server orchestrating all components.
"""
def __init__(self, max_batch_size: int = 4):
self.tokenizer = SimpleTokenizer()
self.model_runner = SimpleModelRunner()
self.scheduler = Scheduler(max_batch_size)
self.request_counter = 0
async def generate(self, prompt: str, max_tokens: int = 50) -> GenerateRequest:
"""Submit a generation request."""
# Tokenize
tokens = self.tokenizer.encode(prompt)
# Create request
request = GenerateRequest(
id=self.request_counter,
prompt=prompt,
prompt_tokens=tokens,
max_tokens=max_tokens,
)
self.request_counter += 1
# Submit to scheduler
self.scheduler.add_request(request)
return request
async def run_step(self) -> bool:
"""Run one step of inference."""
batch = self.scheduler.get_next_batch()
if batch is None:
return False
if batch.is_prefill:
# Prefill phase
request = batch.requests[0]
print(f"[Server] Prefill request {request.id} "
f"({len(request.prompt_tokens)} prompt tokens)")
token = await self.model_runner.prefill(request)
self.scheduler.process_prefill_result(request, token)
else:
# Decode phase
print(f"[Server] Decode batch of {len(batch.requests)} requests")
tokens = await self.model_runner.decode(batch.requests)
for request, token in zip(batch.requests, tokens):
self.scheduler.process_decode_result(request, token)
return True
async def run_until_complete(self):
"""Run until all requests are complete."""
while self.scheduler.has_work():
await self.run_step()
async def run_demo(num_requests: int, max_batch_size: int):
"""Run a demonstration of the inference server."""
print("=" * 60)
print(" MINIMAL INFERENCE SERVER DEMO")
print("=" * 60)
server = InferenceServer(max_batch_size=max_batch_size)
# Sample prompts
prompts = [
"What is the capital of France?",
"Explain quantum computing in simple terms.",
"Write a haiku about programming.",
"What is machine learning?",
"Tell me a joke.",
"How does the internet work?",
"What is the meaning of life?",
"Describe a beautiful sunset.",
]
print(f"\nConfiguration:")
print(f" Max batch size: {max_batch_size}")
print(f" Number of requests: {num_requests}")
print(f"\n{'─' * 60}\n")
# Submit requests
requests = []
for i in range(num_requests):
prompt = prompts[i % len(prompts)]
request = await server.generate(prompt, max_tokens=20)
requests.append(request)
print(f"\n{'─' * 60}\n")
print("Processing requests...\n")
# Process all requests
start_time = time.time()
await server.run_until_complete()
total_time = time.time() - start_time
# Print results
print(f"\n{'─' * 60}")
print(" RESULTS")
print(f"{'─' * 60}\n")
total_tokens = 0
for req in server.scheduler.completed:
latency = time.time() - req.created_at
print(f"Request {req.id}: {len(req.generated_tokens)} tokens, "
f"{latency:.3f}s latency")
total_tokens += len(req.generated_tokens)
print(f"\n{'─' * 60}")
print(" SUMMARY")
print(f"{'─' * 60}")
print(f"Total requests: {num_requests}")
print(f"Total tokens generated: {total_tokens}")
print(f"Total time: {total_time:.3f}s")
print(f"Throughput: {total_tokens / total_time:.1f} tokens/second")
# Explain what's happening
print(f"\n{'─' * 60}")
print(" WHAT THIS DEMONSTRATES")
print(f"{'─' * 60}")
print("""
1. REQUEST FLOW:
Prompt → Tokenizer → Scheduler → Model Runner → Response
2. PREFILL vs DECODE:
- Prefill: Process entire prompt (one request at a time here)
- Decode: Generate tokens in batches
3. BATCHING:
- Multiple requests share GPU compute during decode
- Higher batch size → higher throughput but higher latency
4. CONTINUOUS BATCHING (simplified):
- New requests can start prefill while others decode
- Finished requests exit, making room for new ones
5. LIMITATIONS OF THIS DEMO:
- No actual model (just simulated delays)
- No KV cache management
- No memory management
- No streaming output
- Simplified scheduling logic
""")
def main():
parser = argparse.ArgumentParser(description="Minimal Inference Server Demo")
parser.add_argument("--num-requests", "-n", type=int, default=5,
help="Number of requests to process")
parser.add_argument("--batch-size", "-b", type=int, default=4,
help="Maximum batch size")
args = parser.parse_args()
asyncio.run(run_demo(args.num_requests, args.batch_size))
if __name__ == "__main__":
main()
Chapter 9: KV Cache Management
“In LLM inference, memory is the new compute. And KV cache is the memory hog.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain why KV cache exists and how it accelerates inference
- Calculate KV cache size for different models and contexts
- Understand PagedAttention and its benefits
- Explain how RadixCache enables prefix sharing
Prerequisites
- Completed Chapter 8 (Server Anatomy)
- Understanding of transformer attention mechanism
- Basic knowledge of memory management
Concept Overview
Why KV Cache?
In transformers, each attention layer computes:
Attention(Q, K, V) = softmax(QK^T / √d) V
During generation, we produce one token at a time. Without caching:
- Token 1: Compute K, V for position 0
- Token 2: Compute K, V for positions 0, 1 (recompute position 0!)
- Token 3: Compute K, V for positions 0, 1, 2 (recompute again!)
- …
- Token N: O(N²) total computations!
With KV cache:
- Token 1: Compute K₀, V₀, store in cache
- Token 2: Compute K₁, V₁, concatenate with cached K₀, V₀
- Token 3: Compute K₂, V₂, concatenate with cached
- …
- Token N: O(N) computations
KV cache trades memory for compute.
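A minimal single-head decode step shows the cache-append-then-attend pattern. This is a toy sketch (no batching, no multi-head attention, no real model), not any library's implementation:

```python
import torch

def decode_step(x_t, W_q, W_k, W_v, cache_k, cache_v):
    """One cached decode step for a single attention head (toy sketch, no batching)."""
    q = x_t @ W_q                           # (1, d): query for the new token only
    k_new = x_t @ W_k                       # K/V for this token are computed once...
    v_new = x_t @ W_v
    cache_k = torch.cat([cache_k, k_new])   # ...and appended to the cache
    cache_v = torch.cat([cache_v, v_new])
    scores = (q @ cache_k.T) / cache_k.shape[-1] ** 0.5
    attn = torch.softmax(scores, dim=-1)    # attend over ALL cached positions
    return attn @ cache_v, cache_k, cache_v

d = 64
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
cache_k = torch.empty(0, d)
cache_v = torch.empty(0, d)
for _ in range(5):                          # each step is O(seq_len), not O(seq_len²)
    x_t = torch.randn(1, d)
    out, cache_k, cache_v = decode_step(x_t, W_q, W_k, W_v, cache_k, cache_v)
print(out.shape, cache_k.shape)             # torch.Size([1, 64]) torch.Size([5, 64])
```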
KV Cache Size Calculation
For each token, we store K and V for every layer:
KV per token = 2 × num_layers × num_heads × head_dim × dtype_size
Example (LLaMA-70B):
Layers: 80
Heads: 64 (8 KV heads with GQA)
Head dim: 128
Dtype: FP16 (2 bytes)
KV per token = 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 320 KB
For 32K context:
KV per request = 320 KB × 32K = 10.24 GB
A single request needs 10 GB of KV cache! This is why memory management is critical.
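You can sanity-check these numbers in a few lines using the same formula; the full kv_cache_calculator.py later in this chapter does this for several models:

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    # Factor of 2 = one K and one V entry per layer per token
    return 2 * layers * kv_heads * head_dim * dtype_bytes

per_token = kv_bytes_per_token(layers=80, kv_heads=8, head_dim=128)  # LLaMA-70B with GQA
print(per_token)                    # 327680 bytes ≈ 320 KB
print(per_token * 32_768 / 2**30)   # ≈ 10 GiB for a single 32K-token request
```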
The Memory Fragmentation Problem
Traditional approach: Pre-allocate maximum context length per request.
Request A (needs 1K): [■■■■□□□□□□□□□□□□□□□□□□□□□□□□□□□□] 32K allocated
Request B (needs 2K): [■■■■■■■■□□□□□□□□□□□□□□□□□□□□□□□□] 32K allocated
Request C (needs 1K): [■■■■□□□□□□□□□□□□□□□□□□□□□□□□□□□□] 32K allocated
Total allocated: 96K tokens worth of memory
Actually used: 4K tokens
Waste: 96%!
This is internal fragmentation—memory reserved but unused.
PagedAttention: The Solution
PagedAttention (from vLLM) applies OS-style virtual memory to KV cache:
Physical Memory (Pages):
[Page 0][Page 1][Page 2][Page 3][Page 4][Page 5][Page 6][Page 7]
Request A (logical view): Request B (logical view):
[Tokens 0-255][Tokens 256-511] [Tokens 0-255][Tokens 256-511][Tokens 512-767]
↓ ↓ ↓ ↓ ↓
Page 2 Page 5 Page 0 Page 3 Page 7
(physical) (physical) (physical) (physical) (physical)
Key insight: Allocate physical pages only when needed. Different requests can share the same physical memory pool.
Benefits:
- Near-zero fragmentation
- Memory utilization > 95%
- More concurrent requests
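A toy allocator shows the core idea: each request's KV cache grows one fixed-size page at a time from a shared pool, instead of reserving the maximum context up front. This is a simplified sketch with invented names, not vLLM's actual data structures:

```python
class PagedKVAllocator:
    """Toy page allocator: each request's KV cache grows one page at a time."""

    def __init__(self, num_physical_pages: int, page_size: int = 256):
        self.page_size = page_size
        self.free_pages = list(range(num_physical_pages))  # shared physical pool
        self.page_tables = {}    # request_id -> list of physical page ids
        self.token_counts = {}   # request_id -> number of tokens stored so far

    def append_token(self, request_id: str) -> int:
        """Record one more token; a new page is taken only at a page boundary."""
        count = self.token_counts.get(request_id, 0)
        if count % self.page_size == 0:          # crossed into a new page
            if not self.free_pages:
                raise MemoryError("KV pool exhausted (a real system would evict or preempt)")
            self.page_tables.setdefault(request_id, []).append(self.free_pages.pop())
        self.token_counts[request_id] = count + 1
        page = self.page_tables[request_id][count // self.page_size]
        return page * self.page_size + count % self.page_size   # physical slot index

    def release(self, request_id: str):
        """Return a finished request's pages to the shared pool."""
        self.free_pages.extend(self.page_tables.pop(request_id, []))
        self.token_counts.pop(request_id, None)

alloc = PagedKVAllocator(num_physical_pages=8)
for _ in range(300):       # a 300-token request occupies 2 pages, not a 32K reservation
    alloc.append_token("req-A")
print(len(alloc.page_tables["req-A"]))   # -> 2
```

In real systems the attention kernel reads K and V through this page-table indirection; that kernel-level change is what PagedAttention adds on top of the allocator.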
The Three-Level KV Cache Hierarchy
Modern systems like SGLang use a three-level structure:
┌─────────────────────────────────────────────────────────────────┐
│ Level 1: RadixCache (Logical) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Radix Tree of Token Sequences │ │
│ │ │ │
│ │ [root] │ │
│ │ / \ │ │
│ │ "What is" "Tell me" │ │
│ │ / \ │ │
│ │ "the capital" "a joke" │ │
│ │ │ │
│ │ Purpose: Detect prefix sharing opportunities │ │
│ └─────────────────────────────────────────────────────────────┘ │
├─────────────────────────────────────────────────────────────────┤
│ Level 2: ReqToTokenPool (Mapping) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ [request_id, token_position] → physical_slot_index │ │
│ │ │ │
│ │ (req_0, pos_0) → slot_42 │ │
│ │ (req_0, pos_1) → slot_17 │ │
│ │ (req_1, pos_0) → slot_42 ← Same slot! Prefix sharing! │ │
│ │ │ │
│ │ Purpose: Map logical positions to physical memory │ │
│ └─────────────────────────────────────────────────────────────┘ │
├─────────────────────────────────────────────────────────────────┤
│ Level 3: TokenToKVPool (Physical GPU Memory) │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ slot_index → actual K,V tensors on GPU │ │
│ │ │ │
│ │ Slot 0: [K tensor][V tensor] │ │
│ │ Slot 1: [K tensor][V tensor] │ │
│ │ ... │ │
│ │ Slot N: [K tensor][V tensor] │ │
│ │ │ │
│ │ Purpose: Store actual KV values │ │
│ └─────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────┘
RadixCache: Automatic Prefix Sharing
RadixCache enables automatic prefix caching. If multiple requests share a prompt prefix, they share KV cache:
Request 1: "What is the capital of France?"
Request 2: "What is the capital of Germany?"
Request 3: "What is the largest planet?"
Radix Tree:
[root]
│
"What is the" (shared by all 3!)
/ \
"capital of" "largest planet"
/ \ │
"France" "Germany" [Request 3]
│ │
[Request 1] [Request 2]
Memory savings: "What is the" KV cache stored ONCE, used by 3 requests!
This is huge for:
- System prompts (shared across all requests)
- Few-shot examples
- Chat history prefixes
Cache Eviction: LRU with Reference Counting
When memory is full, we need to evict cached entries. RadixCache uses:
- Reference counting: Don’t evict entries in use
- LRU (Least Recently Used): Evict oldest unused entries first
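The sketch below combines the two policies: blocks with a nonzero reference count are pinned, and eviction walks from least to most recently used. Class and method names are illustrative, not SGLang's API:

```python
from collections import OrderedDict

class RefCountedLRU:
    """Minimal refcount-aware LRU for cached KV blocks (illustrative sketch)."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.entries = OrderedDict()   # block_id -> ref_count; insertion order ~ recency

    def acquire(self, block_id: str):
        """Mark a block as in use and as most recently used."""
        self.entries[block_id] = self.entries.get(block_id, 0) + 1
        self.entries.move_to_end(block_id)
        self._evict_if_needed()

    def release(self, block_id: str):
        """Drop one reference; the block becomes evictable once its count hits zero."""
        self.entries[block_id] -= 1

    def _evict_if_needed(self):
        # Walk from least to most recently used, skipping blocks that are still referenced.
        while len(self.entries) > self.capacity:
            for block_id, refs in self.entries.items():
                if refs == 0:
                    del self.entries[block_id]
                    break
            else:
                break   # everything is pinned; stay over capacity until something is released

cache = RefCountedLRU(capacity=2)
cache.acquire("system-prompt")
cache.acquire("few-shot-block")
cache.release("few-shot-block")
cache.acquire("user-turn")       # evicts "few-shot-block": unreferenced and least recently used
print(list(cache.entries))       # ['system-prompt', 'user-turn']
```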
Code Walkthrough
Script 1: kv_cache_calculator.py
Calculate KV cache sizes for different model configurations:
- Shows memory requirements per token, per request
- Estimates concurrent request capacity
- Compares with and without paging
Script 2: prefix_sharing_demo.py
Demonstrates how prefix sharing works:
- Shows memory savings from shared prefixes
- Visualizes the radix tree structure
- Calculates sharing efficiency
Memory Budget Planning
For a 70B model on 8× H100 (640 GB total):
| Component | Memory |
|---|---|
| Model weights (FP16) | 140 GB |
| CUDA kernels, etc. | ~20 GB |
| Available for KV cache | ~480 GB |
With 320 KB per token per request:
- Max tokens in cache: 480 GB / 320 KB = 1.5M tokens
- At 4K avg context: ~375 concurrent requests
- At 32K context: ~47 concurrent requests
This is why context length dramatically affects capacity!
Try It Yourself
Exercise 1: Calculate Your Model’s KV Cache
For your favorite model (LLaMA, Mistral, etc.):
- Find: num_layers, num_kv_heads, head_dim
- Calculate: KV bytes per token
- Calculate: Max requests at 8K context with 80GB GPU
Exercise 2: Measure Prefix Sharing Savings
Design a benchmark:
- Create 100 requests with shared system prompt
- Calculate memory with individual caching
- Calculate memory with prefix sharing
- What’s the savings percentage?
Exercise 3: Implement Simple LRU Cache
Implement a basic LRU cache for KV entries:
- Fixed capacity
- Reference counting
- Eviction when full
Key Takeaways
- KV cache is massive - Often larger than model weights for long contexts
- Fragmentation wastes memory - Pre-allocation is inefficient
- PagedAttention solves fragmentation - Near-100% memory utilization
- Prefix sharing saves memory - Especially for system prompts
- Memory limits concurrency - More memory = more concurrent requests
Trade-offs
| Approach | Pros | Cons |
|---|---|---|
| Pre-allocation | Simple, no overhead | Massive fragmentation |
| PagedAttention | Low fragmentation | Page table overhead |
| RadixCache | Prefix sharing | Tree management overhead |
| Quantized KV | Less memory | Slight quality loss |
What’s Next?
In Chapter 10, we’ll explore Advanced Scheduling and CUDA Graphs—how to hide scheduling overhead and maximize GPU utilization.
Further Reading
kv_cache_calculator.py
Calculate KV cache memory requirements for any model
This script helps you understand how much memory your KV cache will consume and plan your deployment accordingly.
What It Does
- Takes model parameters (layers, heads, head_dim, dtype)
- Calculates KV cache size per token
- Estimates memory for different context lengths
- Shows concurrent request capacity
Run It
python tutorial/part3-inference/chapter09-kv-cache/scripts/kv_cache_calculator.py
Example Output
=== KV Cache Calculator ===
Model: LLaMA-70B
Layers: 80
KV Heads: 8 (GQA)
Head Dim: 128
Dtype: FP16
KV Cache Size:
Per token: 320 KB
Per request (4K context): 1.28 GB
Per request (32K context): 10.24 GB
With 80 GB GPU Memory:
Model weights (FP16): 140 GB (requires 2+ GPUs)
After weights on 8x H100: ~480 GB available
Max concurrent requests:
At 4K context: 375 requests
At 8K context: 187 requests
At 32K context: 46 requests
Warning: Long context dramatically reduces concurrency!
The Formula
kv_bytes_per_token = 2 × layers × kv_heads × head_dim × dtype_bytes
↑ ↑ ↑
K+V layers 2 for FP16
Source Code
#!/usr/bin/env python3
"""
KV Cache Calculator
Calculate KV cache memory requirements for different LLM configurations.
This helps understand memory constraints and capacity planning.
Usage:
python kv_cache_calculator.py
python kv_cache_calculator.py --model llama-70b
python kv_cache_calculator.py --custom --layers 80 --heads 64 --dim 128
"""
import argparse
from dataclasses import dataclass
from typing import Dict
@dataclass
class ModelConfig:
"""Model configuration for KV cache calculation."""
name: str
num_layers: int
num_kv_heads: int # KV heads (may differ from attention heads with GQA)
head_dim: int
vocab_size: int = 128000
@property
def kv_bytes_per_token(self) -> int:
"""Calculate KV cache bytes per token (for FP16)."""
# K and V for each layer
return 2 * self.num_layers * self.num_kv_heads * self.head_dim * 2 # 2 bytes for FP16
# Common model configurations
MODELS = {
"llama-7b": ModelConfig("LLaMA-7B", num_layers=32, num_kv_heads=32, head_dim=128),
"llama-13b": ModelConfig("LLaMA-13B", num_layers=40, num_kv_heads=40, head_dim=128),
"llama-70b": ModelConfig("LLaMA-70B", num_layers=80, num_kv_heads=8, head_dim=128), # GQA
"mistral-7b": ModelConfig("Mistral-7B", num_layers=32, num_kv_heads=8, head_dim=128), # GQA
"qwen-72b": ModelConfig("Qwen-72B", num_layers=80, num_kv_heads=8, head_dim=128),
"deepseek-67b": ModelConfig("DeepSeek-67B", num_layers=95, num_kv_heads=8, head_dim=128),
}
def format_bytes(bytes_val: float) -> str:
"""Format bytes into human-readable string."""
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if bytes_val < 1024:
return f"{bytes_val:.2f} {unit}"
bytes_val /= 1024
return f"{bytes_val:.2f} PB"
def calculate_kv_cache(model: ModelConfig, context_lengths: list,
dtype_bytes: int = 2) -> Dict:
"""Calculate KV cache requirements."""
kv_per_token = (2 * model.num_layers * model.num_kv_heads *
model.head_dim * dtype_bytes)
results = {
'model': model.name,
'layers': model.num_layers,
'kv_heads': model.num_kv_heads,
'head_dim': model.head_dim,
'kv_bytes_per_token': kv_per_token,
'contexts': {}
}
for ctx_len in context_lengths:
kv_per_request = kv_per_token * ctx_len
results['contexts'][ctx_len] = {
'per_request': kv_per_request,
'per_request_formatted': format_bytes(kv_per_request),
}
return results
def analyze_capacity(model: ModelConfig, gpu_memory_gb: float,
model_size_gb: float, context_length: int,
dtype_bytes: int = 2) -> Dict:
"""Analyze how many concurrent requests can be served."""
# Available memory for KV cache
overhead_gb = 2 # CUDA kernels, activations, etc.
available_gb = gpu_memory_gb - model_size_gb - overhead_gb
if available_gb <= 0:
return {
'error': 'Model does not fit in GPU memory',
'available_gb': available_gb,
}
# KV cache per request
kv_per_token = (2 * model.num_layers * model.num_kv_heads *
model.head_dim * dtype_bytes)
kv_per_request = kv_per_token * context_length
kv_per_request_gb = kv_per_request / (1024 ** 3)
# Max concurrent requests
max_requests = int(available_gb / kv_per_request_gb)
# With PagedAttention (assuming 95% utilization vs 50% without)
requests_without_paging = int(max_requests * 0.5) # 50% utilization
requests_with_paging = int(max_requests * 0.95) # 95% utilization
return {
'gpu_memory_gb': gpu_memory_gb,
'model_size_gb': model_size_gb,
'available_for_kv_gb': available_gb,
'context_length': context_length,
'kv_per_request_gb': kv_per_request_gb,
'max_theoretical_requests': max_requests,
'requests_without_paging': requests_without_paging,
'requests_with_paging': requests_with_paging,
'paging_improvement': f"{(requests_with_paging / requests_without_paging - 1) * 100:.0f}%"
}
def compare_fragmentation(model: ModelConfig, requests: int,
avg_context: int, max_context: int,
dtype_bytes: int = 2) -> Dict:
"""Compare memory usage with and without paging."""
kv_per_token = (2 * model.num_layers * model.num_kv_heads *
model.head_dim * dtype_bytes)
# Without paging: allocate max_context for each request
memory_without_paging = requests * max_context * kv_per_token
# With paging: only allocate what's actually used
memory_with_paging = requests * avg_context * kv_per_token
waste = memory_without_paging - memory_with_paging
waste_pct = (waste / memory_without_paging) * 100
return {
'requests': requests,
'avg_context': avg_context,
'max_context': max_context,
'memory_without_paging': format_bytes(memory_without_paging),
'memory_with_paging': format_bytes(memory_with_paging),
'memory_wasted': format_bytes(waste),
'waste_percentage': f"{waste_pct:.1f}%",
}
def print_model_comparison():
"""Print KV cache comparison for common models."""
print("=" * 70)
print(" KV CACHE SIZE COMPARISON ACROSS MODELS")
print("=" * 70)
context_lengths = [2048, 4096, 8192, 32768, 131072]
print(f"\n{'Model':<15} {'Layers':<8} {'KV Heads':<10} "
f"{'Per Token':<12} {'@ 8K ctx':<12} {'@ 32K ctx':<12}")
print("-" * 70)
for name, model in MODELS.items():
results = calculate_kv_cache(model, context_lengths)
per_token = format_bytes(results['kv_bytes_per_token'])
at_8k = results['contexts'][8192]['per_request_formatted']
at_32k = results['contexts'][32768]['per_request_formatted']
print(f"{model.name:<15} {model.num_layers:<8} {model.num_kv_heads:<10} "
f"{per_token:<12} {at_8k:<12} {at_32k:<12}")
def print_capacity_analysis(model_name: str, gpu_config: str):
"""Print capacity analysis for a specific configuration."""
model = MODELS.get(model_name.lower())
if not model:
print(f"Unknown model: {model_name}")
return
# GPU configurations
gpu_configs = {
"h100": (80, "H100 80GB"),
"a100": (80, "A100 80GB"),
"a100-40": (40, "A100 40GB"),
"4090": (24, "RTX 4090 24GB"),
}
# Model sizes (approximate, FP16)
model_sizes = {
"llama-7b": 14,
"llama-13b": 26,
"llama-70b": 140,
"mistral-7b": 14,
"qwen-72b": 144,
"deepseek-67b": 134,
}
gpu_memory, gpu_name = gpu_configs.get(gpu_config, (80, "Custom"))
model_size = model_sizes.get(model_name.lower(), 14)
print("\n" + "=" * 70)
print(f" CAPACITY ANALYSIS: {model.name} on {gpu_name}")
print("=" * 70)
for context_len in [2048, 4096, 8192, 32768]:
capacity = analyze_capacity(model, gpu_memory, model_size, context_len)
if 'error' in capacity:
print(f"\n@ {context_len} context: {capacity['error']}")
continue
print(f"\n@ {context_len} context length:")
print(f" Available for KV cache: {capacity['available_for_kv_gb']:.1f} GB")
print(f" KV per request: {capacity['kv_per_request_gb']:.2f} GB")
print(f" Without PagedAttention: ~{capacity['requests_without_paging']} concurrent requests")
print(f" With PagedAttention: ~{capacity['requests_with_paging']} concurrent requests")
print(f" Improvement: {capacity['paging_improvement']}")
def print_fragmentation_analysis(model_name: str):
"""Show memory fragmentation with and without paging."""
model = MODELS.get(model_name.lower())
if not model:
print(f"Unknown model: {model_name}")
return
print("\n" + "=" * 70)
print(f" FRAGMENTATION ANALYSIS: {model.name}")
print("=" * 70)
scenarios = [
(100, 512, 8192, "Short prompts, 8K max"),
(50, 2048, 8192, "Medium prompts, 8K max"),
(20, 4096, 32768, "Long prompts, 32K max"),
(10, 8192, 131072, "Very long, 128K max"),
]
for requests, avg_ctx, max_ctx, desc in scenarios:
frag = compare_fragmentation(model, requests, avg_ctx, max_ctx)
print(f"\nScenario: {desc}")
print(f" Requests: {requests}, Avg context: {avg_ctx}, Max context: {max_ctx}")
print(f" Without paging: {frag['memory_without_paging']}")
print(f" With paging: {frag['memory_with_paging']}")
print(f" Memory saved: {frag['memory_wasted']} ({frag['waste_percentage']} reduction)")
def main():
parser = argparse.ArgumentParser(description="KV Cache Calculator")
parser.add_argument("--model", "-m", type=str, default="llama-70b",
choices=list(MODELS.keys()),
help="Model to analyze")
parser.add_argument("--gpu", "-g", type=str, default="h100",
choices=["h100", "a100", "a100-40", "4090"],
help="GPU type")
parser.add_argument("--custom", action="store_true",
help="Use custom model config")
parser.add_argument("--layers", type=int, default=80,
help="Number of layers (with --custom)")
parser.add_argument("--heads", type=int, default=8,
help="Number of KV heads (with --custom)")
parser.add_argument("--dim", type=int, default=128,
help="Head dimension (with --custom)")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " KV CACHE CALCULATOR".center(68) + "║")
print("╚" + "═" * 68 + "╝")
if args.custom:
custom_model = ModelConfig(
"Custom",
num_layers=args.layers,
num_kv_heads=args.heads,
head_dim=args.dim
)
MODELS["custom"] = custom_model
args.model = "custom"
# Model comparison
print_model_comparison()
# Capacity analysis
print_capacity_analysis(args.model, args.gpu)
# Fragmentation analysis
print_fragmentation_analysis(args.model)
# Key insights
print("\n" + "=" * 70)
print(" KEY INSIGHTS")
print("=" * 70)
print("""
1. KV CACHE DOMINATES MEMORY
- For long contexts, KV cache >> model weights
- 70B model @ 32K context: 140GB weights vs ~10GB KV per request
2. GQA DRAMATICALLY REDUCES KV CACHE
- LLaMA-70B uses 8 KV heads (vs 64 attention heads)
- 8x smaller KV cache per token!
3. PAGEDATTENTION NEARLY DOUBLES CAPACITY
- Eliminates internal fragmentation
- 95% utilization vs ~50% without paging
4. CONTEXT LENGTH IS THE KILLER
- 32K context: ~47 concurrent requests
- 128K context: ~12 concurrent requests
- Same GPU, same model!
5. QUANTIZED KV CACHE HELPS
- FP8 KV cache: 2x more requests
- INT8 KV cache: similar benefits
- Some quality trade-off
""")
if __name__ == "__main__":
main()
prefix_sharing_demo.py
Demonstrate memory savings from shared prompt prefixes
This script shows how RadixCache saves memory by sharing KV cache for common prompt prefixes.
What It Does
- Creates multiple requests with shared prefixes
- Shows memory usage WITHOUT prefix sharing
- Shows memory usage WITH prefix sharing
- Visualizes the radix tree structure
Run It
python tutorial/part3-inference/chapter09-kv-cache/scripts/prefix_sharing_demo.py
Example Output
=== Prefix Sharing Demo ===
Requests:
1. "You are a helpful assistant. What is 2+2?"
2. "You are a helpful assistant. Explain quantum computing."
3. "You are a helpful assistant. Write a poem."
Shared Prefix: "You are a helpful assistant. " (7 tokens)
Memory Analysis:
Without sharing:
Request 1: 100 tokens × 320 KB = 32 MB
Request 2: 120 tokens × 320 KB = 38.4 MB
Request 3: 90 tokens × 320 KB = 28.8 MB
Total: 99.2 MB
With sharing:
Shared prefix: 7 tokens × 320 KB = 2.24 MB (stored once)
Request 1 unique: 93 tokens × 320 KB = 29.76 MB
Request 2 unique: 113 tokens × 320 KB = 36.16 MB
Request 3 unique: 83 tokens × 320 KB = 26.56 MB
Total: 94.72 MB
Savings: 4.5% (increases with more requests sharing the prefix!)
Radix Tree:
[root]
│
"You are a helpful assistant."
/ | \
"What is" "Explain" "Write"
│ │ │
"2+2?" "quantum" "a poem"
Why This Matters
With 100 requests sharing a system prompt:
- Without sharing: 100× full prompt
- With sharing: 1× shared + 100× unique parts
- Savings: Up to 90%+ for long system prompts!
Source Code
#!/usr/bin/env python3
"""
Prefix Sharing Demonstration
This script demonstrates how prefix sharing (RadixCache) saves memory
by reusing KV cache for common prompt prefixes.
Usage:
python prefix_sharing_demo.py
python prefix_sharing_demo.py --num-requests 100
"""
import argparse
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set
from collections import defaultdict
@dataclass
class RadixNode:
"""A node in the radix tree."""
token: Optional[int] = None
children: Dict[int, 'RadixNode'] = field(default_factory=dict)
kv_index: Optional[int] = None # Index into KV cache
ref_count: int = 0 # Number of requests using this node
class RadixTree:
"""
Simplified RadixCache for demonstration.
Real implementations are more complex with:
- Compression (multiple tokens per node)
- LRU eviction
- Reference counting for safe deletion
"""
def __init__(self):
self.root = RadixNode()
self.next_kv_index = 0
self.total_nodes = 0
self.shared_nodes = 0
def insert(self, tokens: List[int]) -> List[int]:
"""
Insert a sequence and return KV indices.
Returns list of KV cache indices for each token.
Reuses existing indices where prefixes match.
"""
kv_indices = []
node = self.root
for token in tokens:
if token not in node.children:
# Create new node
new_node = RadixNode(token=token)
new_node.kv_index = self.next_kv_index
self.next_kv_index += 1
node.children[token] = new_node
self.total_nodes += 1
else:
# Reuse existing node (prefix sharing!)
self.shared_nodes += 1
node = node.children[token]
node.ref_count += 1
kv_indices.append(node.kv_index)
return kv_indices
def get_stats(self) -> Dict:
"""Get statistics about the tree."""
return {
'total_nodes': self.total_nodes,
'shared_accesses': self.shared_nodes,
'unique_kv_entries': self.next_kv_index,
}
def visualize_tree(node: RadixNode, prefix: str = "", is_last: bool = True,
depth: int = 0, max_depth: int = 5) -> List[str]:
"""Visualize the radix tree structure."""
lines = []
if depth > max_depth:
return lines
connector = "└── " if is_last else "├── "
token_str = f"[{node.token}]" if node.token is not None else "[root]"
ref_str = f" (refs: {node.ref_count})" if node.ref_count > 0 else ""
lines.append(f"{prefix}{connector}{token_str}{ref_str}")
children = list(node.children.values())
for i, child in enumerate(children):
extension = " " if is_last else "│ "
child_is_last = (i == len(children) - 1)
lines.extend(visualize_tree(child, prefix + extension, child_is_last,
depth + 1, max_depth))
return lines
def demo_prefix_sharing():
"""Demonstrate prefix sharing with example requests."""
print("=" * 70)
print(" RADIX CACHE PREFIX SHARING DEMO")
print("=" * 70)
# Simulate a tokenizer (just use word indices)
def tokenize(text: str) -> List[int]:
words = text.lower().split()
return [hash(w) % 1000 for w in words]
# Example requests with shared prefixes
requests = [
"You are a helpful assistant. What is the capital of France?",
"You are a helpful assistant. What is the capital of Germany?",
"You are a helpful assistant. What is the largest planet?",
"You are a helpful assistant. Tell me a joke.",
"You are a coding assistant. Write a Python function.",
"You are a coding assistant. Explain recursion.",
]
tree = RadixTree()
total_tokens = 0
request_indices = []
print("\nProcessing requests:\n")
for i, request in enumerate(requests):
tokens = tokenize(request)
total_tokens += len(tokens)
kv_indices = tree.insert(tokens)
request_indices.append(kv_indices)
print(f"Request {i + 1}: {request[:50]}...")
print(f" Tokens: {len(tokens)}, KV indices assigned: {len(set(kv_indices))} unique")
# Statistics
stats = tree.get_stats()
print("\n" + "-" * 70)
print(" MEMORY ANALYSIS")
print("-" * 70)
print(f"\nWithout prefix sharing:")
print(f" Total tokens across all requests: {total_tokens}")
print(f" KV cache entries needed: {total_tokens}")
print(f"\nWith prefix sharing (RadixCache):")
print(f" Unique KV cache entries: {stats['unique_kv_entries']}")
print(f" Shared prefix accesses: {stats['shared_accesses']}")
savings = (1 - stats['unique_kv_entries'] / total_tokens) * 100
print(f"\nMemory savings: {savings:.1f}%")
# Visualize tree (simplified)
print("\n" + "-" * 70)
print(" RADIX TREE STRUCTURE (first 5 levels)")
print("-" * 70)
print("\n".join(visualize_tree(tree.root)))
def analyze_system_prompt_sharing(num_requests: int, system_prompt_len: int,
user_prompt_len: int, kv_bytes_per_token: int):
"""Analyze memory savings from system prompt sharing."""
print("\n" + "=" * 70)
print(" SYSTEM PROMPT SHARING ANALYSIS")
print("=" * 70)
print(f"\nConfiguration:")
print(f" Number of requests: {num_requests}")
print(f" System prompt length: {system_prompt_len} tokens")
print(f" User prompt length: {user_prompt_len} tokens (average)")
print(f" KV bytes per token: {kv_bytes_per_token}")
total_tokens = num_requests * (system_prompt_len + user_prompt_len)
without_sharing = total_tokens * kv_bytes_per_token
# With sharing: system prompt cached once, user prompts unique
with_sharing = (system_prompt_len + num_requests * user_prompt_len) * kv_bytes_per_token
savings = without_sharing - with_sharing
savings_pct = (savings / without_sharing) * 100
print(f"\nMemory usage:")
print(f" Without sharing: {without_sharing / 1e9:.2f} GB")
print(f" With sharing: {with_sharing / 1e9:.2f} GB")
print(f" Saved: {savings / 1e9:.2f} GB ({savings_pct:.1f}%)")
# Break down by component
system_memory = system_prompt_len * kv_bytes_per_token
user_memory = num_requests * user_prompt_len * kv_bytes_per_token
print(f"\nWith sharing breakdown:")
print(f" System prompt (shared): {system_memory / 1e6:.2f} MB (cached once)")
print(f" User prompts (unique): {user_memory / 1e9:.2f} GB")
def analyze_few_shot_sharing(num_requests: int, num_examples: int,
example_len: int, query_len: int,
kv_bytes_per_token: int):
"""Analyze memory savings from few-shot example sharing."""
print("\n" + "=" * 70)
print(" FEW-SHOT EXAMPLE SHARING ANALYSIS")
print("=" * 70)
few_shot_len = num_examples * example_len
print(f"\nConfiguration:")
print(f" Number of requests: {num_requests}")
print(f" Few-shot examples: {num_examples} × {example_len} = {few_shot_len} tokens")
print(f" Query length: {query_len} tokens (average)")
total_tokens = num_requests * (few_shot_len + query_len)
without_sharing = total_tokens * kv_bytes_per_token
with_sharing = (few_shot_len + num_requests * query_len) * kv_bytes_per_token
savings = without_sharing - with_sharing
savings_pct = (savings / without_sharing) * 100
print(f"\nMemory usage:")
print(f" Without sharing: {without_sharing / 1e9:.2f} GB")
print(f" With sharing: {with_sharing / 1e9:.2f} GB")
print(f" Saved: {savings / 1e9:.2f} GB ({savings_pct:.1f}%)")
def main():
parser = argparse.ArgumentParser(description="Prefix Sharing Demo")
parser.add_argument("--num-requests", "-n", type=int, default=100,
help="Number of requests for analysis")
parser.add_argument("--system-prompt-len", type=int, default=500,
help="System prompt length in tokens")
parser.add_argument("--user-prompt-len", type=int, default=100,
help="Average user prompt length")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " PREFIX SHARING (RADIXCACHE) DEMONSTRATION".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Basic demo
demo_prefix_sharing()
# LLaMA-70B KV bytes per token (with GQA)
kv_bytes = 2 * 80 * 8 * 128 * 2 # 327,680 bytes
# System prompt sharing analysis
analyze_system_prompt_sharing(
num_requests=args.num_requests,
system_prompt_len=args.system_prompt_len,
user_prompt_len=args.user_prompt_len,
kv_bytes_per_token=kv_bytes
)
# Few-shot sharing analysis
analyze_few_shot_sharing(
num_requests=args.num_requests,
num_examples=5,
example_len=200,
query_len=50,
kv_bytes_per_token=kv_bytes
)
# Key insights
print("\n" + "=" * 70)
print(" KEY INSIGHTS")
print("=" * 70)
print("""
1. SYSTEM PROMPTS ARE FREE (almost)
- First request pays the cost
- Subsequent requests share the KV cache
- Especially valuable for long system prompts
2. FEW-SHOT EXAMPLES BENEFIT HUGELY
- 5 examples × 200 tokens = 1000 tokens shared
- With 100 requests: 99% memory reduction for examples!
3. RADIXCACHE IS AUTOMATIC
- No manual prefix specification needed
- Tree structure detects sharing automatically
- Works for any common prefix
4. LIMITATIONS:
- Only exact prefix matches benefit
- Different orderings = different prefixes
- Token-level sharing (not semantic)
5. REAL-WORLD IMPACT:
- APIs with shared system prompts: massive savings
- Batch inference with templates: huge efficiency
- Speculative decoding: shared draft prefixes
""")
if __name__ == "__main__":
main()
Chapter 10: Advanced Scheduling and CUDA Graphs
“The fastest code is code that doesn’t run. The second fastest is code you ran yesterday.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain why CPU scheduling overhead matters for inference
- Understand CUDA Graphs and when to use them
- Describe zero-overhead scheduling with FutureMap
- Identify scenarios where CUDA Graphs help vs hurt
Prerequisites
- Completed Chapter 8-9 (Server Anatomy, KV Cache)
- Basic understanding of CPU-GPU interaction
- Familiarity with asynchronous programming
Concept Overview
The Scheduling Overhead Problem
LLM inference involves many small operations:
- Schedule which requests to process
- Prepare batch metadata
- Launch GPU kernels
- Wait for results
- Process outputs
The problem? Steps 1, 2, and 5 happen on CPU while GPU waits!
Traditional scheduling:
Time →
CPU: [Schedule][Prepare]............[Process]............[Schedule][Prepare]
GPU: ............[Compute]............[idle].............[Compute]..........
↑ ↑
GPU working GPU waiting for CPU!
For small decode batches, CPU overhead can exceed GPU compute time.
How Bad Is It?
On a high-end setup (H100 + fast CPU):
- GPU decode step: ~5-10ms
- CPU scheduling overhead: ~2-5ms
That’s 20-50% overhead! For latency-sensitive applications, this is unacceptable.
CUDA Graphs: Recording Once, Playing Many
CUDA Graphs capture a sequence of GPU operations and replay them with minimal CPU overhead.
# Traditional approach: the CPU launches every kernel, every iteration
for i in range(1000):
    output = model(static_input)   # kernel launch overhead paid each time

# CUDA Graph approach
# Step 1: Capture (after a warmup pass, using fixed-shape static tensors)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)  # operations are recorded, not executed

# Step 2: Replay
for i in range(1000):
    graph.replay()  # a single CPU call replays the whole recorded sequence
Why it’s fast:
- One CPU→GPU launch instead of many
- GPU executes pre-optimized kernel sequence
- No kernel launch latency per operation
CUDA Graphs: The Constraints
CUDA Graphs require static computation:
| Allowed | Not Allowed |
|---|---|
| Fixed tensor shapes | Dynamic shapes |
| Deterministic operations | Random dropout |
| Pre-allocated memory | Dynamic allocation |
| Fixed control flow | Data-dependent branching |
This is perfect for inference (fixed model) but problematic for training.
Why Training Rarely Uses CUDA Graphs
- Dynamic optimizer updates: Gradient clipping changes behavior
- Learning rate scheduling: Different computations each step
- Gradient accumulation: Variable number of backwards
- Dropout: Random behavior
- Dynamic memory: Activation checkpointing allocates/frees
Zero-Overhead Scheduling: The FutureMap
SGLang’s innovation: overlap CPU scheduling with GPU compute.
Traditional:
Batch N: [CPU Schedule N] → [GPU Compute N] → [CPU Process N]
↓
Batch N+1: [CPU Schedule N+1] → [GPU Compute N+1]
Overlapped:
Batch N: [CPU Schedule N] → [GPU Compute N] ────────────────→
↓
Batch N+1: [CPU Schedule N+1] → [GPU Compute N+1]
↑
Running in parallel!
The challenge: Batch N+1’s inputs might depend on Batch N’s outputs!
FutureMap: Speculative Scheduling
FutureMap solves this with symbolic references:
1. CPU pre-allocates slots for Batch N's output
2. CPU schedules Batch N+1 using slot references (not actual values)
3. GPU runs Batch N, writes to pre-allocated slots
4. GPU's "resolve" kernel substitutes symbolic refs with real values
5. GPU runs Batch N+1
┌────────────────────────────────────────────────────────────────────┐
│ FutureMap Mechanism │
│ │
│ CPU Thread: │
│ 1. Reserve slots for Batch N output │
│ 2. Build Batch N+1 input with symbolic refs: [slot_42, slot_43] │
│ 3. Continue scheduling (no blocking!) │
│ │
│ GPU Thread: │
│ 1. Compute Batch N │
│ 2. Write results to reserved slots (42, 43) │
│ 3. Resolve kernel: [slot_42, slot_43] → [actual_token_ids] │
│ 4. Compute Batch N+1 │
│ │
└────────────────────────────────────────────────────────────────────┘
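One way to sketch the idea in code: symbolic references are encoded as negative indices into a pre-allocated buffer, batch N writes its sampled tokens into that buffer, and a resolve step swaps the placeholders for real token IDs before batch N+1 runs. This is a conceptual toy (CPU tensors, invented names), not SGLang's implementation:

```python
import torch

class FutureMap:
    """Conceptual sketch: placeholder token slots that a later GPU step fills in."""

    def __init__(self, capacity: int):
        self.buffer = torch.zeros(capacity, dtype=torch.long)  # future output tokens live here
        self.next_slot = 0

    def reserve(self, n: int) -> torch.Tensor:
        """CPU side: hand out symbolic references, encoded as negative indices."""
        slots = torch.arange(self.next_slot, self.next_slot + n)
        self.next_slot += n
        return -(slots + 1)                  # -1, -2, ... mean "not produced yet"

    def store(self, refs: torch.Tensor, tokens: torch.Tensor):
        """Batch N finishes: write its sampled tokens into the reserved slots."""
        self.buffer[-refs - 1] = tokens

    def resolve(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Before batch N+1 runs: swap symbolic references for real token ids."""
        resolved = input_ids.clone()
        future = input_ids < 0
        resolved[future] = self.buffer[-input_ids[future] - 1]
        return resolved

fmap = FutureMap(capacity=16)
refs = fmap.reserve(2)                          # CPU schedules batch N+1 with placeholders
fmap.store(refs, torch.tensor([101, 202]))      # batch N completes and fills the slots
batch_input = torch.tensor([int(refs[0]), 7, int(refs[1])])
print(fmap.resolve(batch_input))                # tensor([101,   7, 202])
```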
The Complete Overlap Scheduler
async def overlap_scheduler_loop(self):
"""SGLang's overlapped scheduling loop."""
last_batch = None
last_result = None
while True:
# Step 1: Schedule NEXT batch (CPU)
# This happens WHILE previous batch is computing!
next_batch = self.get_next_batch_to_run()
# Step 2: Launch next batch (GPU, non-blocking)
next_result = self.run_batch(next_batch)
# Step 3: Process PREVIOUS batch results (CPU)
# By now, previous batch is likely done
if last_batch is not None:
self.process_batch_result(last_batch, last_result)
last_batch = next_batch
last_result = next_result
Combining CUDA Graphs with Overlap Scheduling
The ultimate optimization:
- CUDA Graphs for decode batches (fixed shape, repeated)
- Overlap scheduling for prefill/mixed batches
- FutureMap to bridge the gap
Decode path (CUDA Graph):
- Capture graph for batch sizes [1, 2, 4, 8, 16, ...]
- Replay appropriate graph based on batch size
- Near-zero CPU overhead
Prefill path (Overlap):
- Variable prompt lengths
- Use overlap scheduling with FutureMap
- Reduced but not eliminated CPU overhead
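For the decode path, the graph pool is typically keyed by batch size: a batch is rounded up (padded) to the nearest captured size, or falls back to eager execution if it is larger than anything captured. A small sketch of that dispatch, with an assumed set of captured sizes:

```python
import bisect

CAPTURED_BATCH_SIZES = [1, 2, 4, 8, 16]   # sizes we captured CUDA Graphs for (assumed)

def pick_graph(batch_size: int):
    """Round a decode batch up to the smallest captured graph that fits (sketch)."""
    idx = bisect.bisect_left(CAPTURED_BATCH_SIZES, batch_size)
    if idx == len(CAPTURED_BATCH_SIZES):
        return None                        # larger than any captured graph: run eagerly
    return CAPTURED_BATCH_SIZES[idx]       # e.g. a batch of 5 is padded to the size-8 graph

for bs in [1, 3, 5, 16, 20]:
    print(bs, "->", pick_graph(bs))        # 1->1, 3->4, 5->8, 16->16, 20->None
```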
Code Walkthrough
Script 1: cuda_graph_simple.py
Demonstrates CUDA Graphs:
- Captures a simple model forward pass
- Compares replay vs normal execution
- Shows the constraints
Script 2: scheduling_overhead_benchmark.py
Measures scheduling overhead:
- Time breakdown: scheduling vs compute
- Impact of batch size
- Benefits of overlap scheduling
Try It Yourself
Exercise 1: Measure Kernel Launch Overhead
Write a benchmark that:
- Runs 100 small matrix multiplications normally
- Captures them in a CUDA Graph
- Compares total time
Exercise 2: Understand Shape Constraints
Try to capture a CUDA Graph with:
- Fixed input shape → works
- Different input shapes → observe behavior
- How do real systems handle multiple shapes?
Exercise 3: Simulate Overlap Scheduling
Implement a simple overlap scheduler:
- Queue of “batches” (just sleep timers)
- Measure throughput with vs without overlap
- What’s the speedup?
Key Takeaways
- CPU overhead is real - Can be 20-50% of decode time
- CUDA Graphs eliminate kernel launch overhead - But need static shapes
- Overlap scheduling hides CPU work - Schedule N+1 while computing N
- FutureMap enables speculation - Pre-allocate outputs, resolve later
- Real systems combine techniques - CUDA Graphs for decode, overlap for prefill
The Speed Hierarchy
From fastest to slowest:
- CUDA Graph replay: ~0.01ms overhead
- Overlap scheduled: ~0.5ms (hidden)
- Normal scheduling: ~2-5ms
- Naive Python loop: ~10ms+
When Not to Use CUDA Graphs
- Variable sequence lengths (prefill)
- Dynamic batch sizes (requests finishing)
- Debugging (graphs hide errors)
- Memory-constrained (graphs consume memory)
What’s Next?
In Chapter 11, we’ll explore Speculative and Constraint Decoding—using draft models to speed up generation and grammar constraints to ensure structured output.
Further Reading
cuda_graph_simple.py
Understand CUDA Graphs with a simple example
This script demonstrates how CUDA Graphs work by capturing and replaying a simple computation.
What It Does
- Creates a simple model (matrix multiplications)
- Runs it normally (CPU launches each kernel)
- Captures it as a CUDA Graph
- Replays the graph (single launch)
- Compares performance
Run It
python tutorial/part3-inference/chapter10-scheduling-cuda/scripts/cuda_graph_simple.py
Example Output
=== CUDA Graph Demo ===
Model: 3-layer MLP (512 → 1024 → 1024 → 512)
Normal execution (100 iterations):
Total time: 15.2 ms
Per iteration: 0.152 ms
CUDA Graph execution (100 iterations):
Capture time: 0.5 ms (one-time)
Total replay time: 3.1 ms
Per iteration: 0.031 ms
Speedup: 4.9x faster with CUDA Graphs!
Reason: Normal execution has ~0.12ms kernel launch overhead per iteration.
CUDA Graphs amortize this to near-zero.
Key Concepts
Capture Phase:
# Record operations into a graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
output = model(input) # NOT executed, just recorded
Replay Phase:
# Execute the recorded graph
for _ in range(100):
g.replay() # Single CPU→GPU call, entire sequence runs
The Constraint
CUDA Graphs need FIXED shapes. This doesn’t work:
# ERROR: Shape changes between iterations
for i in range(10):
input = torch.randn(i + 1, 1024) # Different size each time!
output = model(input)
Real systems capture graphs for common shapes: [1, 2, 4, 8, 16, …] batch sizes.
Source Code
#!/usr/bin/env python3
"""
CUDA Graphs Simple Demonstration
This script demonstrates CUDA Graphs for reducing kernel launch overhead.
CUDA Graphs record GPU operations and replay them with minimal CPU overhead.
Usage:
python cuda_graph_simple.py
Note: Requires CUDA GPU. Falls back to simulation on CPU.
"""
import argparse
import time
from typing import Tuple
def check_cuda() -> bool:
"""Check if CUDA is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def run_with_cuda():
"""Run the demo with actual CUDA Graphs."""
import torch
import torch.nn as nn
print("=" * 60)
print(" CUDA GRAPHS DEMONSTRATION")
print("=" * 60)
print(f"\nUsing GPU: {torch.cuda.get_device_name(0)}")
# Create a simple model
model = nn.Sequential(
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
).cuda()
model.eval()
# Fixed-size input (required for CUDA Graphs)
batch_size = 32
input_tensor = torch.randn(batch_size, 512, device='cuda')
output_tensor = torch.zeros(batch_size, 512, device='cuda')
num_iterations = 1000
# =========================================================================
# Method 1: Normal execution (kernel launch per operation)
# =========================================================================
print("\n" + "-" * 60)
print(" Method 1: Normal Execution")
print("-" * 60)
# Warmup
for _ in range(10):
with torch.no_grad():
_ = model(input_tensor)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
with torch.no_grad():
output = model(input_tensor)
torch.cuda.synchronize()
normal_time = time.perf_counter() - start
print(f"Total time: {normal_time * 1000:.2f} ms")
print(f"Per iteration: {normal_time / num_iterations * 1000:.3f} ms")
# =========================================================================
# Method 2: CUDA Graph capture and replay
# =========================================================================
print("\n" + "-" * 60)
print(" Method 2: CUDA Graph Replay")
print("-" * 60)
# Create static tensors for capture
static_input = torch.randn(batch_size, 512, device='cuda')
static_output = torch.zeros(batch_size, 512, device='cuda')
# Warmup for capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
with torch.no_grad():
static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)
# Capture the graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
with torch.no_grad():
static_output = model(static_input)
print("Graph captured successfully!")
print(f"Graph contains operations for: Linear → ReLU → Linear → ReLU → Linear")
# Benchmark graph replay
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(num_iterations):
# Copy new input to static buffer
static_input.copy_(input_tensor)
# Replay the graph
graph.replay()
torch.cuda.synchronize()
graph_time = time.perf_counter() - start
print(f"\nTotal time: {graph_time * 1000:.2f} ms")
print(f"Per iteration: {graph_time / num_iterations * 1000:.3f} ms")
# =========================================================================
# Comparison
# =========================================================================
print("\n" + "=" * 60)
print(" COMPARISON")
print("=" * 60)
speedup = normal_time / graph_time
overhead_saved = (normal_time - graph_time) / num_iterations * 1000
print(f"\nNormal execution: {normal_time / num_iterations * 1000:.3f} ms/iter")
print(f"CUDA Graph replay: {graph_time / num_iterations * 1000:.3f} ms/iter")
print(f"Speedup: {speedup:.2f}x")
print(f"Overhead saved per iteration: {overhead_saved:.3f} ms")
# Verify correctness
static_input.copy_(input_tensor)
graph.replay()
torch.cuda.synchronize()
with torch.no_grad():
expected = model(input_tensor)
diff = (static_output - expected).abs().max().item()
print(f"\nCorrectness check (max diff): {diff:.2e}")
print(f"Results match: {diff < 1e-5}")
# =========================================================================
# Demonstrate constraints
# =========================================================================
print("\n" + "=" * 60)
print(" CUDA GRAPH CONSTRAINTS")
print("=" * 60)
print("""
CUDA Graphs REQUIRE:
✓ Fixed tensor shapes
✓ Pre-allocated output buffers
✓ Deterministic operations
✓ Static control flow
CUDA Graphs FORBID:
✗ Dynamic shapes (different batch sizes)
✗ Random operations (dropout)
✗ CPU-GPU synchronization points
✗ Memory allocation during execution
For LLM inference:
- Decode phase: Fixed batch size → CUDA Graphs work great!
- Prefill phase: Variable prompt lengths → Cannot use CUDA Graphs
- Solution: Capture graphs for common batch sizes, fall back otherwise
""")
def run_simulation():
"""Simulate CUDA Graphs concept without GPU."""
print("=" * 60)
print(" CUDA GRAPHS SIMULATION (No GPU)")
print("=" * 60)
print("\nNote: Running without GPU. Demonstrating concept only.\n")
# Simulate overhead
kernel_launch_overhead_ms = 0.05 # Per kernel
compute_time_ms = 0.5 # Actual compute
num_kernels = 5 # Linear, ReLU, Linear, ReLU, Linear
num_iterations = 1000
# Normal execution: overhead per kernel per iteration
normal_time = num_iterations * (num_kernels * kernel_launch_overhead_ms + compute_time_ms)
# CUDA Graph: overhead only once for entire graph
graph_time = num_iterations * (kernel_launch_overhead_ms + compute_time_ms)
print("-" * 60)
print(" SIMULATED COMPARISON")
print("-" * 60)
print(f"\nAssumptions:")
print(f" Kernel launch overhead: {kernel_launch_overhead_ms} ms")
print(f" Number of kernels: {num_kernels}")
print(f" Compute time: {compute_time_ms} ms")
print(f" Iterations: {num_iterations}")
print(f"\nNormal execution:")
print(f" Per iteration: {num_kernels} × {kernel_launch_overhead_ms} + {compute_time_ms} = "
f"{num_kernels * kernel_launch_overhead_ms + compute_time_ms} ms")
print(f" Total: {normal_time} ms")
print(f"\nCUDA Graph replay:")
print(f" Per iteration: 1 × {kernel_launch_overhead_ms} + {compute_time_ms} = "
f"{kernel_launch_overhead_ms + compute_time_ms} ms")
print(f" Total: {graph_time} ms")
speedup = normal_time / graph_time
print(f"\nSpeedup: {speedup:.2f}x")
print(f"Overhead reduced from {num_kernels * kernel_launch_overhead_ms} ms to "
f"{kernel_launch_overhead_ms} ms per iteration")
def demonstrate_multiple_graphs():
"""Show how real systems handle multiple batch sizes."""
print("\n" + "=" * 60)
print(" HANDLING MULTIPLE BATCH SIZES")
print("=" * 60)
print("""
Real inference systems capture multiple CUDA Graphs:
Graph pool:
- batch_size=1: [captured graph for single request decode]
- batch_size=2: [captured graph for 2 requests]
- batch_size=4: [captured graph for 4 requests]
- batch_size=8: [captured graph for 8 requests]
- batch_size=16: [captured graph for 16 requests]
...
At runtime:
1. Check current batch size
2. If graph exists for this size: replay()
3. If not: fall back to normal execution
Trade-offs:
- More graphs = more GPU memory for graph storage
- Typical: capture for powers of 2 up to max batch size
- Padding: batch_size=5 might use batch_size=8 graph with padding
""")
def main():
parser = argparse.ArgumentParser(description="CUDA Graphs Demo")
parser.add_argument("--force-cpu", action="store_true",
help="Force CPU simulation even if GPU available")
args = parser.parse_args()
has_cuda = check_cuda() and not args.force_cpu
if has_cuda:
run_with_cuda()
else:
run_simulation()
demonstrate_multiple_graphs()
if __name__ == "__main__":
main()
scheduling_overhead_benchmark.py
Measure and visualize CPU scheduling overhead in inference
This script quantifies how much time is spent on CPU scheduling vs GPU computation.
What It Does
- Simulates inference batches of different sizes
- Measures scheduling time (CPU)
- Measures compute time (GPU)
- Shows overhead percentage
- Demonstrates overlap scheduling benefit
Run It
python tutorial/part3-inference/chapter10-scheduling-cuda/scripts/scheduling_overhead_benchmark.py
Example Output
=== Scheduling Overhead Benchmark ===
Batch Size | Schedule (ms) | Compute (ms) | Overhead %
-----------|---------------|--------------|------------
1 | 2.1 | 5.2 | 40%
4 | 2.3 | 6.1 | 38%
16 | 2.8 | 12.5 | 22%
64 | 3.5 | 45.2 | 8%
Observation: Larger batches amortize scheduling overhead.
=== With Overlap Scheduling ===
Batch Size | Effective Overhead
-----------|-------------------
1 | 5% (hidden)
4 | 3% (hidden)
16 | 1% (hidden)
64 | 0% (hidden)
Overlap scheduling hides CPU work behind GPU compute!
Why Small Batches Are Hard
Batch size 1:
Schedule: [====] 2ms
Compute: [========] 5ms
Total: 7ms for 1 token = 143 tokens/sec
Batch size 64:
Schedule: [====] 3.5ms
Compute: [======================================] 45ms
Total: 48.5ms for 64 tokens = 1320 tokens/sec
Small batches spend a much larger fraction of their time on scheduling, so per-token throughput suffers!
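To make the arithmetic above concrete, here is a tiny throughput calculator; the millisecond figures are the illustrative numbers from the example output above, not measurements.
def tokens_per_sec(batch_size, schedule_ms, compute_ms):
    """Tokens per second when every step pays scheduling + compute serially."""
    step_ms = schedule_ms + compute_ms
    return batch_size / (step_ms / 1000)

# Numbers from the example above (illustrative only)
print(tokens_per_sec(1, 2.0, 5.0))    # ~143 tokens/sec
print(tokens_per_sec(64, 3.5, 45.0))  # ~1319 tokens/sec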
Source Code
#!/usr/bin/env python3
"""
Scheduling Overhead Benchmark
This script demonstrates the CPU scheduling overhead in LLM inference
and shows how overlap scheduling reduces it.
Usage:
python scheduling_overhead_benchmark.py
"""
import argparse
import asyncio
import time
from dataclasses import dataclass
from typing import List, Tuple
import random
@dataclass
class BatchMetrics:
"""Metrics for a single batch."""
scheduling_time_ms: float
compute_time_ms: float
postprocess_time_ms: float
total_time_ms: float
class SchedulingOverheadBenchmark:
"""
Benchmark to measure and compare scheduling strategies.
Simulates:
- CPU scheduling overhead (preparing batches)
- GPU compute time (actual model execution)
- CPU postprocessing (handling results)
"""
def __init__(self, scheduling_ms: float = 2.0, compute_ms: float = 10.0,
postprocess_ms: float = 1.0):
self.scheduling_ms = scheduling_ms
self.compute_ms = compute_ms
self.postprocess_ms = postprocess_ms
def _simulate_scheduling(self) -> float:
"""Simulate CPU scheduling work."""
start = time.perf_counter()
# Simulate work: preparing batch metadata, allocating, etc.
time.sleep(self.scheduling_ms / 1000)
return (time.perf_counter() - start) * 1000
def _simulate_compute(self) -> float:
"""Simulate GPU compute time."""
start = time.perf_counter()
time.sleep(self.compute_ms / 1000)
return (time.perf_counter() - start) * 1000
def _simulate_postprocess(self) -> float:
"""Simulate result postprocessing."""
start = time.perf_counter()
time.sleep(self.postprocess_ms / 1000)
return (time.perf_counter() - start) * 1000
def run_sequential(self, num_batches: int) -> Tuple[float, List[BatchMetrics]]:
"""
Run batches sequentially (traditional approach).
Timeline: [Schedule] -> [Compute] -> [Postprocess] -> [Schedule] -> ...
"""
metrics = []
total_start = time.perf_counter()
for _ in range(num_batches):
batch_start = time.perf_counter()
sched_time = self._simulate_scheduling()
compute_time = self._simulate_compute()
post_time = self._simulate_postprocess()
batch_total = (time.perf_counter() - batch_start) * 1000
metrics.append(BatchMetrics(
scheduling_time_ms=sched_time,
compute_time_ms=compute_time,
postprocess_time_ms=post_time,
total_time_ms=batch_total
))
total_time = (time.perf_counter() - total_start) * 1000
return total_time, metrics
async def run_overlapped(self, num_batches: int) -> Tuple[float, List[BatchMetrics]]:
"""
Run batches with overlap scheduling.
Key insight: Schedule batch N+1 while batch N is computing.
Timeline (overlapped):
[Schedule 0] -> [Compute 0]
[Schedule 1] -> [Compute 1]
[Postprocess 0]
[Schedule 2] -> ...
"""
metrics = []
total_start = time.perf_counter()
# Pipeline: we overlap scheduling with previous compute
for i in range(num_batches):
batch_start = time.perf_counter()
if i == 0:
# First batch: no overlap possible
sched_time = self._simulate_scheduling()
compute_time = self._simulate_compute()
post_time = self._simulate_postprocess()
else:
# Subsequent batches: scheduling was done during previous compute
sched_time = 0 # Already done (overlapped)
# But we still need to do scheduling for NEXT batch
# This runs in "parallel" with compute
compute_start = time.perf_counter()
# Simulate both happening together
# In reality, GPU computes while CPU schedules
# Here we take max of the two times
parallel_time = max(self.compute_ms, self.scheduling_ms) / 1000
time.sleep(parallel_time)
compute_time = (time.perf_counter() - compute_start) * 1000
post_time = self._simulate_postprocess()
batch_total = (time.perf_counter() - batch_start) * 1000
metrics.append(BatchMetrics(
scheduling_time_ms=sched_time,
compute_time_ms=compute_time,
postprocess_time_ms=post_time,
total_time_ms=batch_total
))
total_time = (time.perf_counter() - total_start) * 1000
return total_time, metrics
def print_results(name: str, total_time: float, metrics: List[BatchMetrics],
num_batches: int):
"""Print benchmark results."""
print(f"\n{name}")
print("-" * 50)
avg_sched = sum(m.scheduling_time_ms for m in metrics) / len(metrics)
avg_compute = sum(m.compute_time_ms for m in metrics) / len(metrics)
avg_post = sum(m.postprocess_time_ms for m in metrics) / len(metrics)
avg_total = sum(m.total_time_ms for m in metrics) / len(metrics)
print(f"Total time: {total_time:.2f} ms")
print(f"Throughput: {num_batches / (total_time / 1000):.2f} batches/sec")
print(f"\nPer-batch breakdown:")
print(f" Scheduling: {avg_sched:.2f} ms")
print(f" Compute: {avg_compute:.2f} ms")
print(f" Postprocess: {avg_post:.2f} ms")
print(f" Total: {avg_total:.2f} ms")
overhead_pct = (avg_sched + avg_post) / avg_total * 100
print(f"\nCPU overhead: {overhead_pct:.1f}%")
def visualize_timeline(scheduling_ms: float, compute_ms: float,
postprocess_ms: float, num_batches: int = 4):
"""Visualize the scheduling timeline."""
print("\n" + "=" * 70)
print(" TIMELINE VISUALIZATION")
print("=" * 70)
scale = 2 # Characters per ms
def bar(char: str, ms: float) -> str:
return char * int(ms * scale)
print("\nSEQUENTIAL EXECUTION:")
print(" S = Schedule, C = Compute, P = Postprocess, . = idle\n")
cpu_line = ""
gpu_line = ""
for i in range(num_batches):
# CPU: schedule, then idle, then postprocess
cpu_line += bar('S', scheduling_ms)
cpu_line += bar('.', compute_ms)
cpu_line += bar('P', postprocess_ms)
# GPU: idle, then compute, then idle
gpu_line += bar('.', scheduling_ms)
gpu_line += bar('C', compute_ms)
gpu_line += bar('.', postprocess_ms)
print(f" CPU: {cpu_line}")
print(f" GPU: {gpu_line}")
print("\nOVERLAPPED EXECUTION:")
print(" CPU schedules batch N+1 while GPU computes batch N\n")
cpu_line = ""
gpu_line = ""
for i in range(num_batches):
if i == 0:
# First batch: no overlap
cpu_line += bar('S', scheduling_ms)
gpu_line += bar('.', scheduling_ms)
# Overlap: CPU schedules next while GPU computes
overlap_time = max(compute_ms, scheduling_ms)
if scheduling_ms <= compute_ms:
cpu_line += bar('S', scheduling_ms) + bar('.', compute_ms - scheduling_ms)
else:
cpu_line += bar('S', compute_ms) + bar('S', scheduling_ms - compute_ms)
gpu_line += bar('C', overlap_time)
# Postprocess
cpu_line += bar('P', postprocess_ms)
gpu_line += bar('.', postprocess_ms)
print(f" CPU: {cpu_line}")
print(f" GPU: {gpu_line}")
def main():
parser = argparse.ArgumentParser(description="Scheduling Overhead Benchmark")
parser.add_argument("--batches", "-b", type=int, default=20,
help="Number of batches to process")
parser.add_argument("--scheduling-ms", type=float, default=2.0,
help="Simulated scheduling time in ms")
parser.add_argument("--compute-ms", type=float, default=10.0,
help="Simulated compute time in ms")
parser.add_argument("--postprocess-ms", type=float, default=1.0,
help="Simulated postprocess time in ms")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " SCHEDULING OVERHEAD BENCHMARK".center(68) + "║")
print("╚" + "═" * 68 + "╝")
print(f"\nConfiguration:")
print(f" Batches: {args.batches}")
print(f" Scheduling time: {args.scheduling_ms} ms")
print(f" Compute time: {args.compute_ms} ms")
print(f" Postprocess time: {args.postprocess_ms} ms")
benchmark = SchedulingOverheadBenchmark(
scheduling_ms=args.scheduling_ms,
compute_ms=args.compute_ms,
postprocess_ms=args.postprocess_ms
)
# Run sequential
print("\n" + "=" * 70)
print(" BENCHMARK RESULTS")
print("=" * 70)
seq_time, seq_metrics = benchmark.run_sequential(args.batches)
print_results("SEQUENTIAL (Traditional)", seq_time, seq_metrics, args.batches)
# Run overlapped
overlap_time, overlap_metrics = asyncio.run(
benchmark.run_overlapped(args.batches)
)
print_results("OVERLAPPED (Zero-Overhead)", overlap_time, overlap_metrics, args.batches)
# Comparison
print("\n" + "=" * 70)
print(" COMPARISON")
print("=" * 70)
speedup = seq_time / overlap_time
time_saved = seq_time - overlap_time
print(f"\nSequential: {seq_time:.2f} ms")
print(f"Overlapped: {overlap_time:.2f} ms")
print(f"Speedup: {speedup:.2f}x")
print(f"Time saved: {time_saved:.2f} ms ({time_saved/seq_time*100:.1f}%)")
# Visualize
visualize_timeline(args.scheduling_ms, args.compute_ms,
args.postprocess_ms, num_batches=4)
# Analysis
print("\n" + "=" * 70)
print(" ANALYSIS")
print("=" * 70)
print(f"""
Key Observations:
1. OVERHEAD IMPACT
Without overlap: {args.scheduling_ms + args.postprocess_ms} ms overhead per batch
Total per batch: {args.scheduling_ms + args.compute_ms + args.postprocess_ms} ms
Overhead percentage: {(args.scheduling_ms + args.postprocess_ms) / (args.scheduling_ms + args.compute_ms + args.postprocess_ms) * 100:.1f}%
2. OVERLAP BENEFIT
Scheduling is hidden behind compute (when compute > scheduling)
Effective overhead: {args.postprocess_ms} ms (only postprocessing)
Overhead reduction: {(args.scheduling_ms) / (args.scheduling_ms + args.postprocess_ms) * 100:.0f}%
3. WHEN OVERLAP HELPS MOST
- Long compute times (GPU-bound workloads)
- Significant scheduling overhead
- Batch decode in LLM inference
4. WHEN OVERLAP HELPS LESS
- Very short compute times
- Scheduling time > compute time
- Complex dependencies between batches
""")
if __name__ == "__main__":
main()
Chapter 11: Speculative and Constraint Decoding
“Predict multiple tokens, verify in parallel. It’s like spell-checking as you type, but for LLMs.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain how speculative decoding accelerates generation
- Understand the acceptance/rejection mechanism
- Use constraint decoding for structured output (JSON, code)
- Choose when to apply these techniques
Prerequisites
- Completed Chapters 8-10 (Inference Systems)
- Understanding of autoregressive generation
- Basic probability concepts
Concept Overview
The Problem: Sequential Generation
Autoregressive LLMs generate one token at a time:
Prompt: "The capital of France is"
Step 1: → "Paris"
Step 2: → "."
Step 3: → " It"
Step 4: → " is"
...
Each step requires a full model forward pass. For a 70B model generating 100 tokens:
- 100 sequential forward passes
- Can’t parallelize within a request
- Memory bandwidth limited during decode
Speculative Decoding: The Key Insight
What if we could verify multiple tokens at once?
Speculative decoding uses a small “draft” model to guess multiple tokens, then verifies them in parallel with the large “target” model.
Draft model (small, fast):
Input: "The capital of France is"
Draft: ["Paris", ".", " It", " is", " known"]
(5 tokens in one pass)
Target model (large, accurate):
Verify all 5 tokens in ONE parallel forward pass
Accept: ["Paris", ".", " It"] (3 accepted)
Reject: [" is", " known"] (distribution mismatch)
Result: 3 tokens accepted with 1 target model pass instead of 3 passes (and the rejected position is resampled from the target's already-computed distribution)!
How Verification Works
The target model doesn’t just check “right or wrong”—it uses a probabilistic acceptance criterion:
For each position i:
p_target = target_model_probability(token_i)
p_draft = draft_model_probability(token_i)
If p_target >= p_draft:
ACCEPT (draft was conservative)
Else:
ACCEPT with probability p_target / p_draft
(randomly accept based on ratio)
If REJECT:
Sample new token from adjusted distribution
Stop accepting further tokens
This ensures the output distribution exactly matches the target model!
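A minimal sketch of this acceptance rule in plain Python; the token strings and probabilities below are made up for illustration, and a real implementation would also resample the rejected position from the adjusted (target minus draft) distribution.
import random

def accept(p_target, p_draft):
    """Speculative acceptance test for one drafted token."""
    if p_target >= p_draft:
        return True                               # target at least as confident: always keep
    return random.random() < p_target / p_draft   # otherwise keep with probability ratio

# (token, p_draft, p_target) per drafted position -- invented values
draft = [("Paris", 0.70, 0.90), (".", 0.60, 0.65), (" It", 0.50, 0.20)]
accepted = []
for token, p_d, p_t in draft:
    if accept(p_t, p_d):
        accepted.append(token)
    else:
        break  # first rejection: resample this position, then stop
print(accepted)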
Speedup Analysis
Let:
- γ = acceptance rate (typically 0.7-0.9)
- k = draft length (tokens generated by draft)
- c = cost ratio (target_time / draft_time, typically 10-50x)
Expected tokens per target forward pass:
E[tokens] = 1 + γ + γ² + ... + γ^k = (1 - γ^(k+1)) / (1 - γ)
For γ=0.8, k=5:
E[tokens] = (1 - 0.8^6) / (1 - 0.8) ≈ 3.69 tokens per pass
≈3.7x theoretical speedup!
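These formulas are easy to check in a few lines. A minimal sketch follows, where draft_cost is the cost of one draft forward pass relative to one target pass (i.e., 1/c in the notation above) and the simplified cost model charges k draft passes plus one target pass per verify step.
def expected_tokens(gamma, k):
    """Expected tokens per target forward pass: (1 - gamma^(k+1)) / (1 - gamma)."""
    return (1 - gamma ** (k + 1)) / (1 - gamma)

def overall_speedup(gamma, k, draft_cost):
    """Speedup vs. one target pass per token, charging k draft passes per verify step."""
    tokens = expected_tokens(gamma, k)
    cost_per_step = 1.0 + k * draft_cost   # one target pass + k draft passes
    return tokens / cost_per_step

print(expected_tokens(0.8, 5))        # ~3.69 tokens per target pass
print(overall_speedup(0.8, 5, 0.05))  # ~2.95x once draft cost is included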
Draft Model Selection
Good draft models:
- Same tokenizer as target (required!)
- Similar training data
- Much smaller (7B for 70B target)
Common pairings:
- LLaMA-70B target + LLaMA-7B draft
- GPT-4 target + GPT-3.5 draft (a commonly speculated pairing, not publicly confirmed)
- Mixtral target + Mistral draft
Constraint Decoding: Structured Output
Sometimes we need output to follow a specific format:
- JSON schema
- SQL query
- Function calls
- Code in specific language
Constraint decoding restricts token probabilities to only valid continuations.
Grammar-Based Constraints
Define valid output using a grammar:
json_value := object | array | string | number | "true" | "false" | "null"
object := "{" (pair ("," pair)*)? "}"
pair := string ":" json_value
...
At each generation step:
- Get logits from model
- Identify tokens that lead to valid states
- Mask invalid tokens (set probability to 0)
- Sample from valid tokens only
import torch

def constrained_sample(logits, grammar_state):
    # Get valid next tokens from grammar
    valid_tokens = grammar_state.get_valid_tokens()
    # Mask invalid tokens by setting their logits to -inf
    # (multiplying by a 0/1 mask would produce NaN from 0 * -inf)
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask[valid_tokens] = True
    logits = logits.masked_fill(~mask, float('-inf'))
    # Sample from the masked distribution
    return torch.multinomial(torch.softmax(logits, dim=-1), 1)
Regex Constraints
For simpler patterns, regex constraints work well:
# Only generate valid email addresses
pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
# At each step, check if current output + candidate token
# can still match the pattern
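Standard-library regex engines cannot directly answer "could this prefix still match?", so real systems compile the pattern to a finite automaton (libraries such as outlines do this). The toy sketch below hand-codes the automaton for a fixed ###-#### phone-number pattern purely to illustrate the idea; the names are invented for this example.
import string

# Allowed characters at each position of "###-####" (hand-coded toy automaton)
ALLOWED = [set(string.digits)] * 3 + [{"-"}] + [set(string.digits)] * 4

def can_still_match(prefix):
    """True if `prefix` can still be extended into a full ###-#### match."""
    if len(prefix) > len(ALLOWED):
        return False
    return all(ch in ALLOWED[i] for i, ch in enumerate(prefix))

def valid_next_chars(prefix, vocab):
    """Characters from `vocab` that keep the partial output extendable."""
    return [c for c in vocab if can_still_match(prefix + c)]

print(valid_next_chars("555", vocab=string.digits + "-. "))  # only '-' survives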
Combining Speculative + Constraint Decoding
Can we get both speedup AND structured output? Yes!
The draft model also generates under constraints:
- Draft generates k constrained tokens
- Target verifies (also checking constraints)
- All accepted tokens are guaranteed valid
Tricky part: Draft must use same constraint state as target.
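A rough sketch of the draft side of this combination, reusing the kind of grammar object shown later in json_constraint_demo.py. The helper names (draft_logits, grammar.copy()) are assumptions for illustration; a real system must roll the grammar state back to the last accepted position whenever the target rejects part of the draft.
def draft_under_constraints(draft_logits, grammar, k):
    """Draft up to k tokens, only ever proposing grammar-valid tokens.

    draft_logits(tokens) and the grammar object are placeholders. The grammar
    is advanced speculatively, so work on a copy (assumed copyable) and roll
    back to the last accepted position if the target rejects a token.
    """
    state = grammar.copy()
    drafted = []
    for _ in range(k):
        valid = state.get_valid_tokens()
        if not valid:
            break
        logits = draft_logits(drafted)
        token = max(valid, key=lambda t: logits[t])   # greedy over valid tokens only
        drafted.append(token)
        state.advance(token)
    return drafted  # the target verifies these with the same masks, position by position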
Code Walkthrough
Script 1: speculative_demo.py
Demonstrates speculative decoding:
- Simulates draft/target model interaction
- Shows acceptance/rejection process
- Calculates speedup
Script 2: json_constraint_demo.py
Demonstrates constraint decoding:
- Simple JSON schema
- Token masking
- Valid output generation
When to Use What
| Technique | Best For | Avoid When |
|---|---|---|
| Speculative | Long generations, high acceptance rate | Very different draft/target, short outputs |
| Constraint | Structured output, API responses | Free-form text |
| Combined | Structured output with length | Complex grammars + low acceptance |
Try It Yourself
Exercise 1: Calculate Speedup
For a system with:
- Acceptance rate: 75%
- Draft length: 4 tokens
- Draft cost: 5% of target cost
Calculate:
- Expected tokens per target pass
- Overall speedup including draft cost
Exercise 2: Design a Grammar
Write a simple grammar for:
- Python function definitions
- Email addresses
- Phone numbers
Exercise 3: Acceptance Rate Experiment
If you had access to models:
- Measure acceptance rate for different draft lengths
- Find the optimal draft length
- How does temperature affect acceptance?
Key Takeaways
- Speculative decoding parallelizes verification - Multiple tokens checked in one forward pass
- Acceptance criterion preserves distribution - Output is identical to non-speculative
- Draft model selection matters - Same tokenizer, similar distribution
- Constraint decoding ensures validity - Grammar-based token masking
- Both can combine - Speedup + structure
The Trade-off Triangle
Latency
/\
/ \
/ \
/ \
/________\
Quality Structure
- Speculative decoding: Latency ↓, Quality =, Structure =
- Constraint decoding: Latency ↑, Quality ≈, Structure ↑
- Combined: Latency ↓, Quality ≈, Structure ↑
What’s Next?
In Part IV, we’ll explore RLHF Systems—how to train LLMs with human feedback, including the complex multi-model orchestration required for PPO training.
Further Reading
speculative_demo.py
Simulate speculative decoding to understand the speedup
This script demonstrates how speculative decoding works by simulating the draft-verify process.
What It Does
- Simulates a draft model generating k tokens
- Simulates a target model verifying them
- Shows acceptance/rejection for each token
- Calculates effective speedup
Run It
python tutorial/part3-inference/chapter11-spec-constraint/scripts/speculative_demo.py
Example Output
=== Speculative Decoding Demo ===
Settings:
Draft length (k): 5
Acceptance rate (γ): 0.80
Simulation (10 generations):
Generation 1:
Draft tokens: ["The", "quick", "brown", "fox", "jumps"]
Target verify: [✓ accept] [✓ accept] [✓ accept] [✗ reject] [- skip]
Tokens accepted: 3 (with 1 target forward pass)
Generation 2:
Draft tokens: ["over", "the", "lazy", "dog", "."]
Target verify: [✓ accept] [✓ accept] [✓ accept] [✓ accept] [✓ accept]
Tokens accepted: 5 (with 1 target forward pass)
...
Summary:
Total tokens generated: 38
Total target forward passes: 10
Tokens per pass: 3.8 (vs 1.0 without speculation)
Theoretical speedup: 3.8x
Cost breakdown:
Target passes: 10 × 100ms = 1000ms
Draft passes: 50 × 5ms = 250ms
Total time: 1250ms
Time without speculation: 3800ms
Actual speedup: 3.04x
The Math
Expected tokens per target pass:
E[tokens] = Σ(i=0 to k) γⁱ = (1 - γ^(k+1)) / (1 - γ)
For γ=0.8, k=5: E[tokens] ≈ 3.69
Source Code
#!/usr/bin/env python3
"""
Speculative Decoding Demonstration
This script demonstrates how speculative decoding works:
- Draft model generates multiple candidate tokens
- Target model verifies them in parallel
- Acceptance/rejection based on probability ratio
Usage:
python speculative_demo.py
python speculative_demo.py --draft-length 5 --acceptance-rate 0.8
"""
import argparse
import random
from dataclasses import dataclass
from typing import List, Tuple, Optional
@dataclass
class Token:
"""Represents a generated token."""
id: int
text: str
draft_prob: float
target_prob: float
def simulate_draft_model(prompt: str, num_tokens: int,
vocab: List[str]) -> List[Token]:
"""
Simulate a draft model generating tokens.
In reality, this would be a small LLM like LLaMA-7B.
"""
tokens = []
for _ in range(num_tokens):
# Random token selection (simulated)
token_id = random.randint(0, len(vocab) - 1)
token_text = vocab[token_id]
# Random probabilities (simulated)
# Draft model is less confident
draft_prob = random.uniform(0.3, 0.9)
tokens.append(Token(
id=token_id,
text=token_text,
draft_prob=draft_prob,
target_prob=0 # Set by target model
))
return tokens
def simulate_target_verification(draft_tokens: List[Token],
base_acceptance_rate: float) -> List[Token]:
"""
Simulate target model verification of draft tokens.
In reality, this would run the large model (e.g., LLaMA-70B)
on all draft tokens in parallel.
"""
for token in draft_tokens:
# Target model's probability (simulated)
# Higher acceptance rate = closer to draft distribution
if random.random() < base_acceptance_rate:
# Target agrees or is more confident
token.target_prob = token.draft_prob * random.uniform(0.9, 1.5)
else:
# Target disagrees
token.target_prob = token.draft_prob * random.uniform(0.1, 0.8)
# Clamp to valid probability
token.target_prob = min(1.0, token.target_prob)
return draft_tokens
def speculative_acceptance(tokens: List[Token]) -> Tuple[List[Token], Optional[Token]]:
"""
Apply speculative decoding acceptance criterion.
For each token:
- If p_target >= p_draft: ACCEPT
- Else: ACCEPT with probability p_target / p_draft
- On first rejection: sample from adjusted distribution, stop
"""
accepted = []
correction_token = None
for i, token in enumerate(tokens):
if token.target_prob >= token.draft_prob:
# Definitely accept
accepted.append(token)
else:
# Probabilistic acceptance
acceptance_prob = token.target_prob / token.draft_prob
if random.random() < acceptance_prob:
accepted.append(token)
else:
# Reject: sample from (target - draft) distribution
# Simulated as a new random token
correction_token = Token(
id=random.randint(0, 99),
text=f"[corrected_{i}]",
draft_prob=0,
target_prob=token.target_prob
)
break # Stop accepting after first rejection
return accepted, correction_token
def run_speculative_decoding(prompt: str, target_length: int,
draft_length: int, acceptance_rate: float,
vocab: List[str]) -> Tuple[List[str], dict]:
"""
Run speculative decoding simulation.
Returns generated tokens and statistics.
"""
generated = []
stats = {
'target_calls': 0,
'draft_calls': 0,
'tokens_accepted': 0,
'tokens_rejected': 0,
'total_tokens': 0,
}
while len(generated) < target_length:
# Step 1: Draft model generates k tokens
remaining = target_length - len(generated)
k = min(draft_length, remaining)
draft_tokens = simulate_draft_model(prompt, k, vocab)
stats['draft_calls'] += 1
# Step 2: Target model verifies all k tokens in parallel (ONE call)
verified_tokens = simulate_target_verification(draft_tokens, acceptance_rate)
stats['target_calls'] += 1
# Step 3: Apply acceptance criterion
accepted, correction = speculative_acceptance(verified_tokens)
# Add accepted tokens
for token in accepted:
generated.append(token.text)
stats['tokens_accepted'] += 1
# Add correction token if any
if correction:
generated.append(correction.text)
stats['tokens_rejected'] += 1
stats['total_tokens'] = len(generated)
# Update prompt for next iteration
prompt = prompt + " " + " ".join(t.text for t in accepted)
if correction:
prompt += " " + correction.text
return generated[:target_length], stats
def calculate_speedup(stats: dict, draft_length: int,
draft_cost_ratio: float = 0.1) -> dict:
"""
Calculate speedup from speculative decoding.
Args:
stats: Statistics from run_speculative_decoding
draft_length: Number of tokens drafted per call
draft_cost_ratio: Cost of draft call relative to target (e.g., 0.1 = 10%)
"""
tokens = stats['total_tokens']
# Without speculative: one target call per token
baseline_cost = tokens
# With speculative: target + draft calls
spec_cost = stats['target_calls'] + stats['draft_calls'] * draft_cost_ratio
speedup = baseline_cost / spec_cost
tokens_per_target_call = tokens / stats['target_calls']
return {
'baseline_cost': baseline_cost,
'speculative_cost': spec_cost,
'speedup': speedup,
'tokens_per_target_call': tokens_per_target_call,
'acceptance_rate': stats['tokens_accepted'] / (stats['tokens_accepted'] + stats['tokens_rejected'])
}
def visualize_speculative_step(draft_tokens: List[Token],
accepted: List[Token],
correction: Optional[Token]):
"""Visualize a single speculative decoding step."""
print("\nDraft tokens:")
for i, token in enumerate(draft_tokens):
status = "✓" if token in accepted else "✗"
print(f" {i}: {token.text:15} p_draft={token.draft_prob:.2f} "
f"p_target={token.target_prob:.2f} {status}")
if correction:
print(f"\nCorrection token: {correction.text}")
print(f"Accepted: {len(accepted)}/{len(draft_tokens)} tokens")
def main():
parser = argparse.ArgumentParser(description="Speculative Decoding Demo")
parser.add_argument("--draft-length", "-k", type=int, default=5,
help="Number of tokens to draft")
parser.add_argument("--target-length", "-n", type=int, default=50,
help="Total tokens to generate")
parser.add_argument("--acceptance-rate", "-a", type=float, default=0.75,
help="Base acceptance rate (0-1)")
parser.add_argument("--draft-cost", type=float, default=0.1,
help="Draft cost as fraction of target cost")
parser.add_argument("--seed", type=int, default=42,
help="Random seed")
args = parser.parse_args()
random.seed(args.seed)
print("╔" + "═" * 68 + "╗")
print("║" + " SPECULATIVE DECODING DEMONSTRATION".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Simple vocabulary for demonstration
vocab = [
"the", "a", "is", "are", "was", "were", "has", "have", "had",
"will", "would", "could", "should", "may", "might", "must",
"and", "or", "but", "if", "then", "else", "when", "where",
"who", "what", "which", "that", "this", "these", "those",
"I", "you", "he", "she", "it", "we", "they", "me", "him", "her",
"good", "bad", "big", "small", "new", "old", "first", "last",
"time", "way", "year", "day", "thing", "man", "world", "life",
]
print(f"\nConfiguration:")
print(f" Draft length (k): {args.draft_length}")
print(f" Target length: {args.target_length}")
print(f" Base acceptance rate: {args.acceptance_rate}")
print(f" Draft cost ratio: {args.draft_cost}")
# Run speculative decoding
print("\n" + "=" * 70)
print(" RUNNING SPECULATIVE DECODING")
print("=" * 70)
prompt = "Once upon a time"
generated, stats = run_speculative_decoding(
prompt, args.target_length, args.draft_length,
args.acceptance_rate, vocab
)
print(f"\nGenerated text preview: {' '.join(generated[:20])}...")
# Calculate speedup
speedup_stats = calculate_speedup(stats, args.draft_length, args.draft_cost)
# Results
print("\n" + "=" * 70)
print(" RESULTS")
print("=" * 70)
print(f"\nGeneration statistics:")
print(f" Total tokens: {stats['total_tokens']}")
print(f" Target model calls: {stats['target_calls']}")
print(f" Draft model calls: {stats['draft_calls']}")
print(f" Tokens accepted: {stats['tokens_accepted']}")
print(f" Tokens rejected: {stats['tokens_rejected']}")
print(f"\nPerformance:")
print(f" Tokens per target call: {speedup_stats['tokens_per_target_call']:.2f}")
print(f" Effective acceptance rate: {speedup_stats['acceptance_rate']:.2%}")
print(f" Speedup: {speedup_stats['speedup']:.2f}x")
# Analysis
print("\n" + "=" * 70)
print(" ANALYSIS")
print("=" * 70)
print(f"""
How Speculative Decoding Works:
1. DRAFT PHASE (fast, cheap)
Small model generates {args.draft_length} candidate tokens quickly
Cost: ~{args.draft_cost*100:.0f}% of target model
2. VERIFY PHASE (one parallel call)
Large model processes ALL draft tokens in ONE forward pass
Each token gets target model probability
3. ACCEPT/REJECT
Token accepted if: p_target >= p_draft (always)
or: random < p_target/p_draft (probabilistic)
First rejection triggers resampling and stops
4. GUARANTEE
Output distribution is IDENTICAL to running target model alone
No quality degradation!
Why It Works:
- Verification is parallel (1 call for k tokens)
- High acceptance rate ({speedup_stats['acceptance_rate']:.0%}) means few rejections
- Draft model cost is negligible ({args.draft_cost*100:.0f}%)
When It Helps Most:
- High acceptance rate (similar draft/target distributions)
- Long generations (amortize setup cost)
- Memory-bound systems (decode phase)
When It Helps Less:
- Low acceptance rate (very different distributions)
- Short generations (overhead not amortized)
- Compute-bound systems (prefill phase)
""")
if __name__ == "__main__":
main()
json_constraint_demo.py
Generate valid JSON using constrained decoding
This script shows how to ensure LLM output follows a specific format using grammar-based constraints.
What It Does
- Defines a simple JSON grammar
- At each step, identifies valid next tokens
- Masks invalid tokens
- Generates guaranteed-valid JSON
Run It
python tutorial/part3-inference/chapter11-spec-constraint/scripts/json_constraint_demo.py
Example Output
=== JSON Constraint Decoding Demo ===
Target schema:
{
"name": <string>,
"age": <number>,
"active": <boolean>
}
Generation trace:
Step 1: State = START_OBJECT
Valid tokens: ['{']
Sampled: '{'
Step 2: State = EXPECT_KEY
Valid tokens: ['"name"', '"age"', '"active"']
Sampled: '"name"'
Step 3: State = EXPECT_COLON
Valid tokens: [':']
Sampled: ':'
Step 4: State = EXPECT_STRING
Valid tokens: ['"', 'a'-'z', 'A'-'Z', ...]
Sampled: '"Alice"'
...
Final output (guaranteed valid JSON):
{
"name": "Alice",
"age": 30,
"active": true
}
The Technique
def constrained_generate(model, grammar):
state = grammar.initial_state()
output = []
while not state.is_finished():
# Get model's preferences
logits = model.get_logits(output)
# Mask invalid tokens
valid_tokens = state.get_valid_tokens()
for i in range(vocab_size):
if i not in valid_tokens:
logits[i] = float('-inf')
# Sample from valid tokens only
token = sample(logits)
output.append(token)
state = state.advance(token)
return output
Why This Matters
Without constraints:
- Model might output invalid JSON
- Need retry logic
- Unpredictable latency
With constraints:
- Always valid output
- Single generation attempt
- Predictable behavior
Source Code
#!/usr/bin/env python3
"""
JSON Constraint Decoding Demonstration
This script demonstrates how constraint decoding ensures valid JSON output
by masking invalid tokens at each generation step.
Usage:
python json_constraint_demo.py
"""
import argparse
import random
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Set, Optional
class JsonState(Enum):
"""States in simplified JSON grammar."""
START = auto() # Expecting { or [
OBJECT_START = auto() # Just saw {, expecting key or }
OBJECT_KEY = auto() # Expecting string key
OBJECT_COLON = auto() # Expecting :
OBJECT_VALUE = auto() # Expecting value
OBJECT_COMMA = auto() # Expecting , or }
ARRAY_START = auto() # Just saw [, expecting value or ]
ARRAY_VALUE = auto() # Expecting value
ARRAY_COMMA = auto() # Expecting , or ]
STRING = auto() # Inside a string
NUMBER = auto() # Inside a number
DONE = auto() # Finished
@dataclass
class Token:
"""Represents a vocabulary token."""
id: int
text: str
is_valid: bool = True
class SimplifiedJsonGrammar:
"""
Simplified JSON grammar for demonstration.
Real implementations use proper grammar parsing (e.g., lark, interegular).
"""
def __init__(self):
self.state = JsonState.START
self.stack = [] # Track nested structures
# Simplified vocabulary
self.vocab = {
0: "{",
1: "}",
2: "[",
3: "]",
4: ":",
5: ",",
6: '"name"',
7: '"value"',
8: '"id"',
9: '"type"',
10: '"hello"',
11: '"world"',
12: "123",
13: "456",
14: "true",
15: "false",
16: "null",
}
def get_valid_tokens(self) -> Set[int]:
"""Return set of valid next token IDs given current state."""
valid = set()
if self.state == JsonState.START:
valid = {0, 2} # { or [
elif self.state == JsonState.OBJECT_START:
valid = {1, 6, 7, 8, 9} # } or string keys
elif self.state == JsonState.OBJECT_KEY:
valid = {6, 7, 8, 9} # String keys
elif self.state == JsonState.OBJECT_COLON:
valid = {4} # :
elif self.state == JsonState.OBJECT_VALUE:
valid = {0, 2, 6, 7, 10, 11, 12, 13, 14, 15, 16} # Any value
elif self.state == JsonState.OBJECT_COMMA:
if self.stack and self.stack[-1] == "object":
valid = {1, 5} # } or ,
else:
valid = {1}
elif self.state == JsonState.ARRAY_START:
valid = {0, 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, 16} # ] or values
elif self.state == JsonState.ARRAY_VALUE:
valid = {0, 2, 6, 7, 10, 11, 12, 13, 14, 15, 16} # Any value
elif self.state == JsonState.ARRAY_COMMA:
valid = {3, 5} # ] or ,
return valid
def advance(self, token_id: int):
"""Advance grammar state based on token."""
token = self.vocab[token_id]
if token == "{":
self.stack.append("object")
self.state = JsonState.OBJECT_START
elif token == "}":
if self.stack and self.stack[-1] == "object":
self.stack.pop()
if not self.stack:
self.state = JsonState.DONE
else:
self.state = JsonState.OBJECT_COMMA if self.stack[-1] == "object" else JsonState.ARRAY_COMMA
elif token == "[":
self.stack.append("array")
self.state = JsonState.ARRAY_START
elif token == "]":
if self.stack and self.stack[-1] == "array":
self.stack.pop()
if not self.stack:
self.state = JsonState.DONE
else:
self.state = JsonState.OBJECT_COMMA if self.stack[-1] == "object" else JsonState.ARRAY_COMMA
elif token == ":":
self.state = JsonState.OBJECT_VALUE
elif token == ",":
if self.stack[-1] == "object":
self.state = JsonState.OBJECT_KEY
else:
self.state = JsonState.ARRAY_VALUE
elif token.startswith('"') and self.state in [JsonState.OBJECT_START, JsonState.OBJECT_KEY]:
self.state = JsonState.OBJECT_COLON
elif token.startswith('"') or token in ["true", "false", "null"] or token.isdigit():
if self.state == JsonState.OBJECT_VALUE:
self.state = JsonState.OBJECT_COMMA
elif self.state in [JsonState.ARRAY_START, JsonState.ARRAY_VALUE]:
self.state = JsonState.ARRAY_COMMA
def constrained_generation(grammar: SimplifiedJsonGrammar,
max_tokens: int = 20) -> List[str]:
"""
Generate tokens with grammar constraints.
At each step:
1. Get valid tokens from grammar
2. "Sample" from valid tokens (random for demo)
3. Advance grammar state
"""
generated = []
for _ in range(max_tokens):
if grammar.state == JsonState.DONE:
break
valid_tokens = grammar.get_valid_tokens()
if not valid_tokens:
print("Warning: No valid tokens!")
break
# In real implementation, this would be:
# 1. Get logits from model
# 2. Mask invalid tokens (set to -inf)
# 3. Sample from softmax
# Here we just pick randomly from valid tokens
token_id = random.choice(list(valid_tokens))
token_text = grammar.vocab[token_id]
generated.append(token_text)
grammar.advance(token_id)
return generated
def demonstrate_token_masking():
"""Show how token masking works at each step."""
print("=" * 70)
print(" TOKEN MASKING DEMONSTRATION")
print("=" * 70)
grammar = SimplifiedJsonGrammar()
steps = []
generated = []
for i in range(10):
if grammar.state == JsonState.DONE:
break
valid_tokens = grammar.get_valid_tokens()
all_tokens = set(grammar.vocab.keys())
invalid_tokens = all_tokens - valid_tokens
# Pick a random valid token
token_id = random.choice(list(valid_tokens))
token_text = grammar.vocab[token_id]
steps.append({
'step': i,
'state': grammar.state.name,
'valid': [grammar.vocab[t] for t in valid_tokens],
'invalid_count': len(invalid_tokens),
'chosen': token_text,
})
generated.append(token_text)
grammar.advance(token_id)
print("\nStep-by-step token masking:\n")
for step in steps:
print(f"Step {step['step']}:")
print(f" State: {step['state']}")
print(f" Valid tokens: {step['valid']}")
print(f" Masked (invalid): {step['invalid_count']} tokens")
print(f" Chosen: {step['chosen']}")
print()
result = "".join(generated)
print(f"Generated JSON: {result}")
def compare_constrained_unconstrained():
"""Compare constrained vs unconstrained generation."""
print("\n" + "=" * 70)
print(" CONSTRAINED vs UNCONSTRAINED COMPARISON")
print("=" * 70)
vocab = list(SimplifiedJsonGrammar().vocab.values())
print("\nUNCONSTRAINED (random tokens):")
unconstrained = [random.choice(vocab) for _ in range(10)]
result = "".join(unconstrained)
print(f" Generated: {result}")
print(f" Valid JSON: {is_valid_json_like(result)}")
print("\nCONSTRAINED (grammar-guided):")
random.seed(42) # For reproducibility
grammar = SimplifiedJsonGrammar()
constrained = constrained_generation(grammar, max_tokens=15)
result = "".join(constrained)
print(f" Generated: {result}")
print(f" Valid JSON: {is_valid_json_like(result)}")
def is_valid_json_like(s: str) -> bool:
"""Simple check if string looks like valid JSON structure."""
s = s.strip()
if not s:
return False
# Check balanced brackets
stack = []
for char in s:
if char in "{[":
stack.append(char)
elif char == "}":
if not stack or stack[-1] != "{":
return False
stack.pop()
elif char == "]":
if not stack or stack[-1] != "[":
return False
stack.pop()
return len(stack) == 0 and (s.startswith("{") or s.startswith("["))
def show_real_world_usage():
"""Show real-world constraint decoding scenarios."""
print("\n" + "=" * 70)
print(" REAL-WORLD CONSTRAINT DECODING SCENARIOS")
print("=" * 70)
print("""
1. JSON SCHEMA CONSTRAINTS
Force model output to match a specific schema:
Schema: {"name": string, "age": number, "active": boolean}
At each step, only allow tokens that can lead to valid schema:
- After {"name": only allow ": and string tokens
- After "name": "...", only allow , or }
- etc.
2. SQL QUERY CONSTRAINTS
Ensure valid SQL syntax:
Grammar: SELECT columns FROM table WHERE condition
Mask tokens that would break syntax:
- After SELECT: only column names or *
- After FROM: only table names
- etc.
3. FUNCTION CALL CONSTRAINTS
Match function signature:
def greet(name: str, times: int = 1)
Force output like: greet("Alice", 3)
- First token must be function name
- Then (
- Then valid arguments matching types
- etc.
4. REGEX PATTERN CONSTRAINTS
Match patterns like email, URL, phone number:
Email: [a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}
Each token must keep the partial output matching the pattern.
5. PROGRAMMING LANGUAGE CONSTRAINTS
Generate syntactically valid code:
Python grammar ensures:
- Proper indentation
- Balanced parentheses
- Valid keywords
IMPLEMENTATION NOTE:
Real systems use libraries like:
- outlines (https://github.com/outlines-dev/outlines)
- guidance (https://github.com/guidance-ai/guidance)
- lmql (https://lmql.ai/)
""")
def main():
parser = argparse.ArgumentParser(description="JSON Constraint Decoding Demo")
parser.add_argument("--seed", type=int, default=42, help="Random seed")
args = parser.parse_args()
random.seed(args.seed)
print("╔" + "═" * 68 + "╗")
print("║" + " JSON CONSTRAINT DECODING DEMONSTRATION".center(68) + "║")
print("╚" + "═" * 68 + "╝")
demonstrate_token_masking()
compare_constrained_unconstrained()
show_real_world_usage()
# Summary
print("\n" + "=" * 70)
print(" KEY INSIGHTS")
print("=" * 70)
print("""
1. CONSTRAINT DECODING GUARANTEES VALIDITY
Every generated token is checked against grammar
Invalid tokens are masked (probability = 0)
Output is always syntactically correct
2. MINIMAL QUALITY IMPACT
Model still chooses among valid tokens
Only invalid options are removed
Semantic quality preserved
3. SLIGHT LATENCY INCREASE
Grammar state must be tracked
Valid token computation at each step
Usually <10% overhead
4. COMPOSABLE WITH OTHER TECHNIQUES
Works with speculative decoding
Works with beam search
Works with any sampling strategy
5. LIBRARY SUPPORT
Use production libraries (outlines, guidance, lmql)
They handle complex grammars efficiently
Pre-compiled finite automata for speed
""")
if __name__ == "__main__":
main()
Chapter 12: RL Fundamentals for LLMs
“Before you can teach a model with human feedback, you need to speak the language of reinforcement learning.”
Learning Objectives
By the end of this chapter, you will be able to:
- Explain the core RL concepts: states, actions, rewards, policies
- Understand value functions and the Bellman equation
- Implement policy gradients and the REINFORCE algorithm
- Explain PPO (Proximal Policy Optimization) and why it’s used for LLMs
Prerequisites
- Completed Part III (LLM Inference)
- Basic calculus (derivatives)
- Familiarity with neural network training
Concept Overview
RL in 60 Seconds
Supervised Learning: Given input X, predict label Y (teacher provides answer)
Reinforcement Learning: Given state S, take action A, observe reward R (learn from trial and error)
┌─────────────────────────────────────────────────────────────────────┐
│ RL FRAMEWORK │
│ │
│ ┌─────────┐ action a ┌─────────────┐ │
│ │ Agent │ ─────────────────────────► Environment │ │
│ │ (Policy)│ ◄───────────────────────── (World) │ │
│ └─────────┘ state s, reward r └─────────────┘ │
│ │
│ Goal: Learn policy π(a|s) that maximizes cumulative reward │
└─────────────────────────────────────────────────────────────────────┘
The LLM as an RL Agent
| RL Concept | LLM Interpretation |
|---|---|
| State | Prompt + generated tokens so far |
| Action | Next token to generate |
| Policy | The LLM itself (token probabilities) |
| Reward | Human preference score (or reward model) |
| Episode | One complete generation |
Value Functions: Predicting Future Rewards
State Value V(s): Expected total reward starting from state s
V(s) = E[R₀ + γR₁ + γ²R₂ + ... | S₀ = s]
Action Value Q(s,a): Expected total reward after taking action a in state s
Q(s,a) = E[R₀ + γR₁ + γ²R₂ + ... | S₀ = s, A₀ = a]
γ (gamma): Discount factor (0-1). Lower γ = short-sighted, higher γ = long-term thinking.
For LLMs, we typically use γ ≈ 1 (care equally about all future rewards).
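As a concrete example of the discount factor, here is a short helper that turns per-step rewards into discounted returns (the compute_returns used in the REINFORCE pseudocode below is essentially this); the reward values are invented.
def compute_returns(rewards, gamma):
    """G_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ... (computed right to left)."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))

print(compute_returns([1.0, 1.0, 1.0], gamma=0.9))   # [2.71, 1.9, 1.0]
print(compute_returns([1.0, 1.0, 1.0], gamma=1.0))   # [3.0, 2.0, 1.0] -- gamma ~ 1 for LLMs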
The Bellman Equation
The fundamental equation of RL:
V(s) = E[R + γV(s') | S = s]
= Σₐ π(a|s) [R(s,a) + γ Σₛ' P(s'|s,a) V(s')]
“The value of a state is the immediate reward plus the discounted value of the next state.”
This recursive structure enables dynamic programming solutions.
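A tiny value-iteration sketch on a made-up three-state chain MDP shows the Bellman backup in action; the transition table and rewards are invented purely for illustration.
# Toy deterministic chain: s0 -> s1 -> s2 (terminal). Reward 1.0 for leaving s1.
transitions = {0: 1, 1: 2}        # state -> next state (invented example)
rewards = {0: 0.0, 1: 1.0}        # reward for leaving each non-terminal state
gamma = 0.9

V = {0: 0.0, 1: 0.0, 2: 0.0}      # initialize values to zero
for _ in range(50):               # repeated Bellman backups until convergence
    for s, s_next in transitions.items():
        V[s] = rewards[s] + gamma * V[s_next]

print(V)  # V[1] -> 1.0, V[0] -> 0.9: value propagates backwards through the chain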
Policy Gradients: Learning by Gradient Ascent
Instead of computing values, directly optimize the policy!
Objective: Maximize expected reward
J(θ) = E[Σₜ R(sₜ, aₜ)]
Policy Gradient Theorem:
∇J(θ) = E[Σₜ ∇log π_θ(aₜ|sₜ) · Gₜ]
Where Gₜ = total future reward from time t.
Intuition:
- If action led to high reward: increase its probability (positive gradient)
- If action led to low reward: decrease its probability (negative gradient)
REINFORCE Algorithm
The simplest policy gradient algorithm:
for episode in episodes:
    # Collect trajectory with the current policy
    states, actions, rewards = collect_episode(policy)
    # Compute discounted returns G_t
    returns = compute_returns(rewards, gamma)
    # Policy gradient: sum per-step losses, then one update per episode
    optimizer.zero_grad()
    loss = sum(-log_prob(policy(s), a) * G
               for s, a, G in zip(states, actions, returns))
    loss.backward()
    optimizer.step()
Problem: High variance! Returns can vary wildly between episodes.
Variance Reduction: Baselines
Subtract a baseline from returns to reduce variance:
∇J(θ) = E[Σₜ ∇log π_θ(aₜ|sₜ) · (Gₜ - b(sₜ))]
Common baseline: Value function V(s) — learn to predict expected return.
This gives us the Advantage:
A(s,a) = Q(s,a) - V(s)
≈ R + γV(s') - V(s) (TD error)
“How much better is this action compared to the average?”
Actor-Critic: Best of Both Worlds
Actor: Policy network π_θ(a|s) Critic: Value network V_φ(s)
# Actor update (policy gradient with advantage)
advantage = reward + gamma * V(next_state) - V(state)
actor_loss = -log_prob(action) * advantage.detach()
# Critic update (value regression)
critic_loss = (V(state) - (reward + gamma * V(next_state).detach())) ** 2
Generalized Advantage Estimation (GAE)
GAE smoothly interpolates between:
- Low bias, high variance (full returns)
- High bias, low variance (TD error)
A^GAE_t = Σₖ (γλ)^k δₜ₊ₖ
Where δₜ = rₜ + γV(sₜ₊₁) - V(sₜ) (TD error)
λ controls the tradeoff:
- λ = 0: Just TD error (high bias, low variance)
- λ = 1: Full returns (low bias, high variance)
Typical: λ = 0.95
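A short numeric example of the GAE recursion on a made-up three-step trajectory shows how λ moves between pure TD error (λ=0) and full returns (λ=1); the reward and value numbers are invented.
def gae(rewards, values, gamma=0.99, lam=0.95):
    """Backward recursion: A_t = delta_t + gamma*lam*A_{t+1}, with delta_t the TD error."""
    values = values + [0.0]                      # bootstrap with V=0 after the last step
    advantages, running = [], 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        advantages.append(running)
    return list(reversed(advantages))

rewards = [0.0, 0.0, 1.0]          # invented trajectory: reward only at the end
values  = [0.5, 0.6, 0.7]          # invented critic estimates
for lam in (0.0, 0.95, 1.0):
    print(lam, [round(a, 3) for a in gae(rewards, values, lam=lam)])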
PPO: The Industry Standard
PPO (Proximal Policy Optimization) adds trust region constraints:
“Don’t change the policy too much in one update.”
PPO-Clip objective:
L^CLIP(θ) = E[min(rₜ(θ)Aₜ, clip(rₜ(θ), 1-ε, 1+ε)Aₜ)]
Where rₜ(θ) = π_θ(aₜ|sₜ) / π_θold(aₜ|sₜ) (probability ratio)
Intuition:
- If advantage is positive and ratio is high: clip to prevent too much increase
- If advantage is negative and ratio is low: clip to prevent too much decrease
- Keeps policy changes bounded
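A few lines of arithmetic make this clipping behaviour concrete; the ratio and advantage values are invented for illustration.
def ppo_clip_objective(ratio, advantage, eps=0.2):
    """Per-sample PPO-Clip objective: min(r*A, clip(r, 1-eps, 1+eps)*A)."""
    clipped_ratio = max(min(ratio, 1 + eps), 1 - eps)
    return min(ratio * advantage, clipped_ratio * advantage)

# Good action (A > 0) with ratio already high: clipping caps the incentive to push further
print(ppo_clip_objective(ratio=1.5, advantage=2.0))   # 2.4, not 3.0
# Bad action (A < 0) with ratio already low: the clipped term dominates, bounding the update
print(ppo_clip_objective(ratio=0.5, advantage=-2.0))  # -1.6, not -1.0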
Why PPO for LLMs?
- Stable training: Trust region prevents catastrophic forgetting
- Sample efficient: Reuses samples within trust region
- Proven at scale: Used by OpenAI, Anthropic, DeepMind
- Simple to implement: No second-order optimization
Code Walkthrough
Script 1: ppo_cartpole.py
A minimal PPO implementation on CartPole:
- Actor-Critic networks
- GAE advantage computation
- PPO-Clip objective
This isn’t for LLMs but shows PPO mechanics clearly.
Script 2: gae_visualizer.py
Visualizes how GAE works:
- Shows TD errors over trajectory
- Compares different λ values
- Demonstrates bias-variance tradeoff
The RLHF Connection
In RLHF:
- State: Prompt + partial response
- Action: Next token
- Reward: Comes from reward model (trained on human preferences)
- Episode: Complete response generation
The PPO objective becomes:
max E[R_reward_model(response) - β * KL(π || π_ref)]
Where:
- R_reward_model: Score from reward model
- KL term: Penalty for diverging from reference policy
- β: KL coefficient (prevents reward hacking)
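A minimal sketch of how this KL penalty is typically folded into per-token rewards during RLHF rollouts; the variable names are assumptions, and production frameworks such as trl or verl implement variations of this scheme.
def rlhf_token_rewards(reward_model_score, policy_logprobs, ref_logprobs, beta=0.05):
    """Per-token reward: -beta * (log pi - log pi_ref), plus the RM score on the last token."""
    rewards = [-beta * (lp - lr) for lp, lr in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += reward_model_score          # sequence-level score credited at the end
    return rewards

# Invented numbers: 4 generated tokens, reward model score 1.2
print(rlhf_token_rewards(1.2, [-1.0, -0.5, -2.0, -0.1], [-1.1, -0.7, -1.5, -0.2]))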
Try It Yourself
Exercise 1: Implement REINFORCE
Implement REINFORCE for a simple environment:
- Collect episodes
- Compute returns
- Update policy
- Track learning curves
Exercise 2: Add a Baseline
Modify your REINFORCE to use a learned baseline:
- Add a value network
- Compute advantages
- Compare variance with/without baseline
Exercise 3: Understand PPO Clipping
For different advantage signs and probability ratios:
- Compute clipped and unclipped objectives
- Determine which is used
- Explain why clipping helps stability
Key Takeaways
- RL learns from rewards, not labels - Trial and error, not supervision
- Value functions predict future rewards - Enables credit assignment
- Policy gradients directly optimize the policy - No need to estimate values
- Baselines reduce variance - Critical for practical training
- PPO is stable and scalable - The go-to algorithm for RLHF
The RL Hierarchy
Simple ────────────────────────────────────► Complex
REINFORCE → Actor-Critic → A2C → PPO → RLHF with PPO
↓ ↓ ↓ ↓ ↓
High Value as Parallel Trust Multi-model
variance baseline training region orchestration
What’s Next?
In Chapter 13, we’ll dive into RLHF Computation Flow—how the Actor, Critic, Reward, and Reference models work together during training.
Further Reading
ppo_cartpole.py
Learn PPO mechanics on a simple game before applying to LLMs
This script implements PPO (Proximal Policy Optimization) on the classic CartPole environment. It’s simpler than LLM training but demonstrates all the same concepts.
What It Does
- Creates Actor (policy) and Critic (value) networks
- Collects episodes using the current policy
- Computes advantages using GAE
- Updates policy with PPO-Clip objective
- Tracks learning progress
Run It
pip install gymnasium # Install gym environment
python tutorial/part4-rlhf/chapter12-rl-fundamentals/scripts/ppo_cartpole.py
Expected Output (from a full PPO implementation; the simplified demo below computes the PPO quantities but does not apply weight updates)
=== PPO on CartPole ===
Episode 10: Average Reward = 21.5
Episode 20: Average Reward = 45.3
Episode 30: Average Reward = 98.7
Episode 40: Average Reward = 187.2
Episode 50: Average Reward = 312.5
Episode 60: Average Reward = 500.0 (solved!)
Training complete! CartPole balanced for 500 steps.
Key Components
Actor Network:
class Actor(nn.Module):
def forward(self, state):
# Returns action probabilities
return F.softmax(self.net(state), dim=-1)
Critic Network:
class Critic(nn.Module):
def forward(self, state):
# Returns state value
return self.net(state)
PPO-Clip Loss:
ratio = new_prob / old_prob
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
loss = -torch.min(ratio * advantage, clipped * advantage).mean()
Why CartPole?
CartPole is the “Hello World” of RL:
- Simple (2D state, 2 actions)
- Fast feedback (episodes complete quickly)
- Clear success metric (balance for 500 steps)
The same PPO algorithm scales to LLMs with minimal changes!
Source Code
#!/usr/bin/env python3
"""
Minimal PPO Implementation on CartPole
This script demonstrates PPO's core concepts:
- Actor-Critic architecture
- GAE (Generalized Advantage Estimation)
- PPO-Clip objective
This is a simplified implementation for educational purposes.
For production RLHF, see verl, trl, or OpenRLHF.
Usage:
pip install gymnasium # if not installed
python ppo_cartpole.py
"""
import argparse
from dataclasses import dataclass
from typing import List, Tuple
import random
import math
# Try to import gymnasium, fall back to simulation if not available
try:
import gymnasium as gym
HAS_GYM = True
except ImportError:
HAS_GYM = False
print("Note: gymnasium not installed. Using simulated environment.")
@dataclass
class Experience:
"""Single step of experience."""
state: List[float]
action: int
reward: float
next_state: List[float]
done: bool
log_prob: float
value: float
class SimpleNetwork:
"""
Simple neural network simulation for demonstration.
In real implementations, use PyTorch or JAX.
"""
def __init__(self, input_size: int, hidden_size: int, output_size: int):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
# Random initialization (simplified)
self.w1 = [[random.gauss(0, 0.1) for _ in range(hidden_size)]
for _ in range(input_size)]
self.b1 = [0.0] * hidden_size
self.w2 = [[random.gauss(0, 0.1) for _ in range(output_size)]
for _ in range(hidden_size)]
self.b2 = [0.0] * output_size
def forward(self, x: List[float]) -> List[float]:
"""Forward pass."""
# Hidden layer
hidden = []
for j in range(self.hidden_size):
h = self.b1[j]
for i in range(self.input_size):
h += x[i] * self.w1[i][j]
hidden.append(max(0, h)) # ReLU
# Output layer
output = []
for j in range(self.output_size):
o = self.b2[j]
for i in range(self.hidden_size):
o += hidden[i] * self.w2[i][j]
output.append(o)
return output
def update(self, grads: List[float], lr: float):
"""Simplified gradient update (demonstration only)."""
# In real implementations, use proper backpropagation
pass
class ActorCritic:
"""
Actor-Critic network for PPO.
Actor: Outputs action probabilities
Critic: Outputs state value
"""
def __init__(self, state_size: int, action_size: int, hidden_size: int = 64):
self.actor = SimpleNetwork(state_size, hidden_size, action_size)
self.critic = SimpleNetwork(state_size, hidden_size, 1)
self.action_size = action_size
def get_action(self, state: List[float]) -> Tuple[int, float]:
"""
Sample action from policy.
Returns: (action, log_probability)
"""
logits = self.actor.forward(state)
# Softmax
max_logit = max(logits)
exp_logits = [math.exp(l - max_logit) for l in logits]
sum_exp = sum(exp_logits)
probs = [e / sum_exp for e in exp_logits]
# Sample action
r = random.random()
cumsum = 0
action = 0
for i, p in enumerate(probs):
cumsum += p
if r < cumsum:
action = i
break
log_prob = math.log(probs[action] + 1e-8)
return action, log_prob
def get_value(self, state: List[float]) -> float:
"""Get state value from critic."""
return self.critic.forward(state)[0]
def get_action_prob(self, state: List[float], action: int) -> float:
"""Get probability of specific action."""
logits = self.actor.forward(state)
max_logit = max(logits)
exp_logits = [math.exp(l - max_logit) for l in logits]
sum_exp = sum(exp_logits)
return exp_logits[action] / sum_exp
def compute_gae(experiences: List[Experience], gamma: float = 0.99,
lam: float = 0.95) -> List[float]:
"""
Compute Generalized Advantage Estimation.
GAE balances bias and variance in advantage estimation:
- λ=0: Just TD error (high bias, low variance)
- λ=1: Full returns (low bias, high variance)
"""
advantages = []
gae = 0
# Iterate backwards through experiences
for i in reversed(range(len(experiences))):
exp = experiences[i]
if exp.done:
next_value = 0
else:
next_value = experiences[i + 1].value if i + 1 < len(experiences) else 0
# TD error
delta = exp.reward + gamma * next_value - exp.value
# GAE
gae = delta + gamma * lam * (0 if exp.done else gae)
advantages.insert(0, gae)
return advantages
def ppo_update(actor_critic: ActorCritic, experiences: List[Experience],
advantages: List[float], clip_epsilon: float = 0.2,
lr: float = 3e-4) -> dict:
"""
PPO update step.
Key components:
1. Compute probability ratio (new policy / old policy)
2. Clip the ratio to prevent large updates
3. Take minimum of clipped and unclipped objectives
"""
# Compute returns for value function update
returns = []
for i, exp in enumerate(experiences):
returns.append(exp.value + advantages[i])
# PPO objectives (computed but not applied in this demo)
policy_losses = []
value_losses = []
clip_fractions = []
for i, exp in enumerate(experiences):
# New probability
new_prob = actor_critic.get_action_prob(exp.state, exp.action)
new_log_prob = math.log(new_prob + 1e-8)
# Probability ratio
ratio = math.exp(new_log_prob - exp.log_prob)
# Advantage
adv = advantages[i]
# Clipped objective
unclipped = ratio * adv
clipped = max(min(ratio, 1 + clip_epsilon), 1 - clip_epsilon) * adv
# PPO loss (take minimum)
policy_loss = -min(unclipped, clipped)
policy_losses.append(policy_loss)
# Track clipping
clip_fractions.append(1 if abs(ratio - 1) > clip_epsilon else 0)
# Value loss
new_value = actor_critic.get_value(exp.state)
value_loss = (new_value - returns[i]) ** 2
value_losses.append(value_loss)
return {
'policy_loss': sum(policy_losses) / len(policy_losses),
'value_loss': sum(value_losses) / len(value_losses),
'clip_fraction': sum(clip_fractions) / len(clip_fractions),
}
class SimulatedCartPole:
"""Simple CartPole simulation for when gymnasium isn't available."""
def __init__(self):
self.reset()
def reset(self) -> List[float]:
self.x = random.uniform(-0.05, 0.05)
self.x_dot = random.uniform(-0.05, 0.05)
self.theta = random.uniform(-0.05, 0.05)
self.theta_dot = random.uniform(-0.05, 0.05)
self.steps = 0
return [self.x, self.x_dot, self.theta, self.theta_dot]
def step(self, action: int) -> Tuple[List[float], float, bool]:
# Simplified physics
force = 10.0 if action == 1 else -10.0
self.x_dot += 0.02 * force + random.gauss(0, 0.01)
self.x += 0.02 * self.x_dot
self.theta_dot += 0.05 * force * (1 if self.theta > 0 else -1)
self.theta_dot += random.gauss(0, 0.01)
self.theta += 0.02 * self.theta_dot
self.steps += 1
done = (abs(self.x) > 2.4 or abs(self.theta) > 0.21 or self.steps > 200)
reward = 1.0 if not done else 0.0
return [self.x, self.x_dot, self.theta, self.theta_dot], reward, done
def run_episode(env, actor_critic: ActorCritic) -> List[Experience]:
"""Run one episode and collect experiences."""
if HAS_GYM:
state, _ = env.reset()
state = list(state)
else:
state = env.reset()
experiences = []
done = False
while not done:
action, log_prob = actor_critic.get_action(state)
value = actor_critic.get_value(state)
if HAS_GYM:
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
next_state = list(next_state)
else:
next_state, reward, done = env.step(action)
experiences.append(Experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
log_prob=log_prob,
value=value,
))
state = next_state
return experiences
def main():
parser = argparse.ArgumentParser(description="Minimal PPO on CartPole")
parser.add_argument("--episodes", "-e", type=int, default=100,
help="Number of episodes to train")
parser.add_argument("--gamma", type=float, default=0.99,
help="Discount factor")
parser.add_argument("--lam", type=float, default=0.95,
help="GAE lambda")
parser.add_argument("--clip-epsilon", type=float, default=0.2,
help="PPO clip parameter")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " MINIMAL PPO ON CARTPOLE".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Create environment
if HAS_GYM:
env = gym.make("CartPole-v1")
print("\nUsing gymnasium CartPole-v1")
else:
env = SimulatedCartPole()
print("\nUsing simulated CartPole")
# Create actor-critic
actor_critic = ActorCritic(state_size=4, action_size=2)
print(f"\nConfiguration:")
print(f" Episodes: {args.episodes}")
print(f" Gamma: {args.gamma}")
print(f" GAE Lambda: {args.lam}")
print(f" Clip Epsilon: {args.clip_epsilon}")
# Training loop
print("\n" + "=" * 60)
print(" TRAINING")
print("=" * 60)
episode_rewards = []
for episode in range(args.episodes):
# Collect episode
experiences = run_episode(env, actor_critic)
# Compute advantages using GAE
advantages = compute_gae(experiences, args.gamma, args.lam)
# PPO update
update_stats = ppo_update(actor_critic, experiences, advantages,
args.clip_epsilon)
# Track rewards
total_reward = sum(exp.reward for exp in experiences)
episode_rewards.append(total_reward)
if (episode + 1) % 10 == 0:
avg_reward = sum(episode_rewards[-10:]) / min(10, len(episode_rewards))
print(f"Episode {episode + 1:4d} | Reward: {total_reward:6.1f} | "
f"Avg(10): {avg_reward:6.1f} | "
f"Clip: {update_stats['clip_fraction']:.2f}")
# Summary
print("\n" + "=" * 60)
print(" SUMMARY")
print("=" * 60)
    n = min(10, len(episode_rewards))
    avg_first_10 = sum(episode_rewards[:n]) / n
    avg_last_10 = sum(episode_rewards[-n:]) / n
print(f"\nAverage reward (first 10 episodes): {avg_first_10:.1f}")
print(f"Average reward (last 10 episodes): {avg_last_10:.1f}")
print(f"Improvement: {avg_last_10 - avg_first_10:.1f}")
# Explain PPO
print("\n" + "=" * 60)
print(" PPO EXPLAINED")
print("=" * 60)
print(f"""
What just happened:
1. EPISODE COLLECTION
Agent interacted with environment
Stored: states, actions, rewards, log probs, values
2. ADVANTAGE COMPUTATION (GAE)
For each step, computed "how much better than expected"
λ={args.lam} balances bias/variance
3. PPO UPDATE
Computed policy gradient with clipped objective
Clip ε={args.clip_epsilon} prevents too large updates
Key PPO Components:
ratio = π_new(a|s) / π_old(a|s)
L^CLIP = min(ratio × A, clip(ratio, 1-ε, 1+ε) × A)
- If A > 0 (good action): ratio clipped at 1+ε (prevent overconfidence)
- If A < 0 (bad action): ratio clipped at 1-ε (prevent overcorrection)
Why PPO for RLHF:
- Stable training (no huge policy shifts)
- Sample efficient (reuse trajectories)
- Simple to implement and tune
- Proven at scale (ChatGPT, Claude, etc.)
""")
if HAS_GYM:
env.close()
if __name__ == "__main__":
main()
gae_visualizer.py
Visualize Generalized Advantage Estimation (GAE)
This script helps you understand how GAE works by visualizing the advantage computation for different λ values.
What It Does
- Creates a sample trajectory with rewards and values
- Computes advantages with different λ values
- Visualizes how λ affects the bias-variance tradeoff
- Shows why λ=0.95 is common
Run It
python tutorial/part4-rlhf/chapter12-rl-fundamentals/scripts/gae_visualizer.py
Example Output
=== GAE Visualizer ===
Sample trajectory (10 steps):
Rewards: [0, 0, 0, 1, 0, 0, 0, 0, 1, 0]
Values: [0.5, 0.6, 0.7, 0.8, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
TD Errors (δ):
Step 0: δ = 0 + 0.99*0.6 - 0.5 = 0.094
Step 1: δ = 0 + 0.99*0.7 - 0.6 = 0.093
...
Advantages by λ:
λ = 0.0 (TD error only, high bias):
A = [0.09, 0.09, 0.11, 0.42, -0.36, 0.10, 0.11, 0.12, 0.31, -0.89]
Variance: 0.15
λ = 0.5 (balanced):
A = [0.21, 0.19, 0.22, 0.38, -0.15, 0.16, 0.18, 0.18, 0.24, -0.89]
Variance: 0.12
λ = 0.95 (common choice):
A = [0.45, 0.38, 0.35, 0.32, -0.02, 0.23, 0.22, 0.20, 0.18, -0.89]
Variance: 0.14
λ = 1.0 (full returns, low bias):
A = [0.52, 0.44, 0.40, 0.35, 0.01, 0.26, 0.24, 0.21, 0.18, -0.89]
Variance: 0.16
The GAE Formula
A^GAE_t = δ_t + (γλ)δ_{t+1} + (γλ)²δ_{t+2} + ...
        = Σ_{k≥0} (γλ)^k δ_{t+k}
Where δ_t = r_t + γV(s_{t+1}) - V(s_t)
Why λ = 0.95?
- λ = 0: Only considers immediate TD error (high bias, low variance)
- λ = 1: Full Monte Carlo returns (low bias, high variance)
- λ = 0.95: Good balance - mostly looks ahead, slight smoothing
Source Code
#!/usr/bin/env python3
"""
GAE (Generalized Advantage Estimation) Visualizer
This script demonstrates how GAE works and its effect on advantage estimation.
Usage:
python gae_visualizer.py
"""
import argparse
from typing import List, Tuple
def generate_trajectory(length: int = 10) -> Tuple[List[float], List[float]]:
"""
Generate a sample trajectory with rewards and values.
Returns:
rewards: List of rewards at each step
values: List of value estimates at each step
"""
# Simulated trajectory: mostly small rewards, occasional large
rewards = [
0.1, 0.1, 0.2, 0.1, 0.5, # Early exploration
0.1, 0.3, 0.1, 0.1, 1.0, # Some success at end
][:length]
# Value estimates (what the critic predicts)
values = [
0.8, 0.7, 0.7, 0.6, 0.5, # Decreasing as end approaches
0.4, 0.4, 0.3, 0.2, 0.1,
][:length]
return rewards, values
def compute_td_errors(rewards: List[float], values: List[float],
gamma: float = 0.99) -> List[float]:
"""
Compute TD (Temporal Difference) errors.
TD error = r_t + γV(s_{t+1}) - V(s_t)
This is the "surprise" - how much better/worse than expected.
"""
td_errors = []
n = len(rewards)
for t in range(n):
r = rewards[t]
v_t = values[t]
v_next = values[t + 1] if t + 1 < n else 0 # Terminal state has 0 value
delta = r + gamma * v_next - v_t
td_errors.append(delta)
return td_errors
def compute_gae(td_errors: List[float], gamma: float = 0.99,
lam: float = 0.95) -> List[float]:
"""
Compute GAE advantages.
A^GAE_t = Σ_{k=0}^{T-t} (γλ)^k δ_{t+k}
λ controls bias-variance tradeoff:
- λ=0: Just TD error (high bias, low variance)
- λ=1: Full returns minus baseline (low bias, high variance)
"""
advantages = []
gae = 0
n = len(td_errors)
# Compute backwards
for t in reversed(range(n)):
gae = td_errors[t] + gamma * lam * gae
advantages.insert(0, gae)
return advantages
def compute_monte_carlo_returns(rewards: List[float], values: List[float],
gamma: float = 0.99) -> List[float]:
"""
Compute Monte Carlo returns (full returns minus baseline).
This is GAE with λ=1.
"""
n = len(rewards)
returns = [0.0] * n
G = 0
for t in reversed(range(n)):
G = rewards[t] + gamma * G
returns[t] = G - values[t] # Advantage = return - baseline
return returns
def visualize_advantages(rewards: List[float], values: List[float],
td_errors: List[float],
advantages_by_lambda: dict):
"""Visualize how different λ values affect advantage estimation."""
n = len(rewards)
print("\n" + "=" * 80)
print(" TRAJECTORY DATA")
print("=" * 80)
print(f"\n{'Step':<6} {'Reward':<10} {'Value':<10} {'TD Error':<12}")
print("-" * 40)
for t in range(n):
print(f"{t:<6} {rewards[t]:<10.2f} {values[t]:<10.2f} {td_errors[t]:<12.4f}")
print("\n" + "=" * 80)
print(" GAE WITH DIFFERENT λ VALUES")
print("=" * 80)
header = f"{'Step':<6}"
for lam in sorted(advantages_by_lambda.keys()):
header += f"{'λ=' + str(lam):<12}"
print(f"\n{header}")
print("-" * (6 + 12 * len(advantages_by_lambda)))
for t in range(n):
row = f"{t:<6}"
for lam in sorted(advantages_by_lambda.keys()):
row += f"{advantages_by_lambda[lam][t]:<12.4f}"
print(row)
def analyze_bias_variance():
"""Analyze the bias-variance tradeoff in GAE."""
print("\n" + "=" * 80)
print(" BIAS-VARIANCE TRADEOFF ANALYSIS")
print("=" * 80)
print("""
GAE with different λ values:
┌────────────────────────────────────────────────────────────────────┐
│ λ = 0 (TD Error Only) │
│ │
│ A_t = δ_t = r_t + γV(s_{t+1}) - V(s_t) │
│ │
│ Properties: │
│ - HIGH BIAS: Only looks one step ahead │
│ - LOW VARIANCE: Single reward, single value estimate │
│ - Fast to adapt, but might miss long-term patterns │
└────────────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────────────┐
│ λ = 1 (Monte Carlo) │
│ │
│ A_t = G_t - V(s_t) = Σ γ^k r_{t+k} - V(s_t) │
│ │
│ Properties: │
│ - LOW BIAS: Uses all future rewards │
│ - HIGH VARIANCE: Accumulates noise from many rewards │
│ - Accurate but slow to learn │
└────────────────────────────────────────────────────────────────────┘
┌────────────────────────────────────────────────────────────────────┐
│ λ = 0.95 (Typical) │
│ │
│ A_t = Σ (γλ)^k δ_{t+k} │
│ │
│ Properties: │
│ - BALANCED: Weights earlier steps more │
│ - PRACTICAL: Good empirical performance │
│ - Exponential decay of TD errors │
└────────────────────────────────────────────────────────────────────┘
The weighting scheme (λ = 0.95, γ = 0.99):
Step t: weight = 1.00
Step t+1: weight = 0.94 (γλ = 0.99 × 0.95)
Step t+2: weight = 0.88 (γλ)²
Step t+3: weight = 0.83 (γλ)³
...
Step t+10: weight = 0.54
Far future TD errors contribute less, reducing variance while
maintaining enough signal for learning.
""")
def demonstrate_numerical_example():
"""Show a concrete numerical example of GAE computation."""
print("\n" + "=" * 80)
print(" NUMERICAL EXAMPLE: GAE COMPUTATION")
print("=" * 80)
# Simple 3-step trajectory
rewards = [0.1, 0.2, 1.0] # Big reward at end
values = [0.5, 0.4, 0.2] # Decreasing values
gamma = 0.99
lam = 0.95
print(f"""
Trajectory:
Step 0: r=0.1, V=0.5
Step 1: r=0.2, V=0.4
Step 2: r=1.0, V=0.2 (terminal)
TD Errors (δ_t = r_t + γV_{{t+1}} - V_t):
δ_0 = 0.1 + 0.99×0.4 - 0.5 = {0.1 + 0.99*0.4 - 0.5:.4f}
δ_1 = 0.2 + 0.99×0.2 - 0.4 = {0.2 + 0.99*0.2 - 0.4:.4f}
δ_2 = 1.0 + 0.99×0.0 - 0.2 = {1.0 + 0.99*0.0 - 0.2:.4f}
GAE Computation (working backwards, λ={lam}):
A_2 = δ_2 = {1.0 + 0.99*0.0 - 0.2:.4f}
A_1 = δ_1 + γλ×A_2 = {0.2 + 0.99*0.2 - 0.4:.4f} + {gamma*lam:.4f}×{1.0 + 0.99*0.0 - 0.2:.4f}
= {(0.2 + 0.99*0.2 - 0.4) + gamma*lam*(1.0 + 0.99*0.0 - 0.2):.4f}
A_0 = δ_0 + γλ×A_1
= {(0.1 + 0.99*0.4 - 0.5) + gamma*lam*((0.2 + 0.99*0.2 - 0.4) + gamma*lam*(1.0 + 0.99*0.0 - 0.2)):.4f}
Notice: Step 0's advantage includes discounted information about the
big reward at step 2, but that information is attenuated by (γλ)².
""")
def main():
parser = argparse.ArgumentParser(description="GAE Visualizer")
parser.add_argument("--gamma", type=float, default=0.99,
help="Discount factor")
parser.add_argument("--trajectory-length", type=int, default=10,
help="Length of trajectory")
args = parser.parse_args()
print("╔" + "═" * 78 + "╗")
print("║" + " GENERALIZED ADVANTAGE ESTIMATION (GAE) VISUALIZER".center(78) + "║")
print("╚" + "═" * 78 + "╝")
# Generate trajectory
rewards, values = generate_trajectory(args.trajectory_length)
# Compute TD errors
td_errors = compute_td_errors(rewards, values, args.gamma)
# Compute GAE for different lambda values
lambda_values = [0.0, 0.5, 0.9, 0.95, 1.0]
advantages_by_lambda = {}
for lam in lambda_values:
advantages_by_lambda[lam] = compute_gae(td_errors, args.gamma, lam)
# Visualize
visualize_advantages(rewards, values, td_errors, advantages_by_lambda)
# Analysis
analyze_bias_variance()
# Numerical example
demonstrate_numerical_example()
# Key insights
print("\n" + "=" * 80)
print(" KEY INSIGHTS FOR RLHF")
print("=" * 80)
print("""
In RLHF training:
1. TRAJECTORY = One response generation
- States: prompt + partial response
- Actions: generated tokens
- Reward: typically only at the end (from reward model)
2. GAE HELPS WITH CREDIT ASSIGNMENT
- Which tokens contributed to the final reward?
- GAE propagates reward signal backwards through the response
- λ controls how far back the signal reaches
3. TYPICAL RLHF SETTINGS
- γ = 0.99 or 1.0 (we care about all tokens)
- λ = 0.95 (good balance)
- Sparse reward (only at end of generation)
4. VALUE FUNCTION IN RLHF
- Critic network predicts expected reward
- Helps reduce variance in policy gradient
- Often shares layers with the policy (actor-critic)
5. PPO USES GAE ADVANTAGES
- Compute GAE for each token in response
- Update policy using PPO-Clip objective
- Bounded updates prevent catastrophic forgetting
""")
if __name__ == "__main__":
main()
Chapter 13: RLHF Computation Flow
“Four models, one update. Orchestrating RLHF is like conducting a symphony of neural networks.”
Learning Objectives
By the end of this chapter, you will be able to:
- Name the four models in RLHF and their roles
- Trace the data flow through one RLHF training step
- Explain why we need a reference model
- Calculate memory requirements for RLHF training
Prerequisites
- Completed Chapter 12 (RL Fundamentals)
- Understanding of PPO and advantage estimation
- Familiarity with model architecture (transformers)
Concept Overview
The Four Models of RLHF
| Model | Role | Updates? | Size |
|---|---|---|---|
| Actor (Policy) | Generates responses | Yes | Full LLM |
| Critic (Value) | Predicts expected reward | Yes | Full LLM or smaller |
| Reward | Scores responses | No | Trained separately |
| Reference | Prevents reward hacking | No | Copy of initial actor |
┌─────────────────────────────────────────────────────────────────────────┐
│ RLHF MODEL ORCHESTRA │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Actor │ │ Critic │ │ Reward │ │Reference│ │
│ │(Policy) │ │(Value) │ │ Model │ │ Policy │ │
│ └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘ │
│ │ │ │ │ │
│ │ │ │ │ │
│ Generates Estimates Evaluates Anchors │
│ responses future reward quality updates │
│ │ │ │ │ │
│ └────────────────┴────────────────┴────────────────┘ │
│ │ │
│ PPO Update │
│ │
└─────────────────────────────────────────────────────────────────────────┘
The RLHF Training Loop
One step of RLHF training:
1. SAMPLE PROMPTS
└─► Get batch of prompts from dataset
2. GENERATE RESPONSES (Actor)
└─► Actor generates responses for each prompt
└─► Save token probabilities
3. SCORE RESPONSES (Reward Model)
└─► Reward model scores each response
└─► This is the "human feedback" signal
4. COMPUTE KL PENALTY (Reference)
└─► Compare actor probabilities to reference
└─► Penalize divergence (prevent reward hacking)
5. COMPUTE ADVANTAGES (Critic + GAE)
└─► Critic estimates values
└─► GAE computes advantages
6. PPO UPDATE (Actor + Critic)
└─► Update actor using PPO objective
└─► Update critic to predict rewards better
Detailed Data Flow
Prompt
│
▼
┌──────────────────────────┐
│ ACTOR │
│ Generate response │
│ Output: tokens, logits │
└───────────┬──────────────┘
│
┌─────────────┼─────────────┐
│ │ │
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ REWARD │ │REFERENCE │ │ CRITIC │
│ MODEL │ │ │ │ │
│Score: 0.8│ │ logits │ │ values │
└────┬─────┘ └────┬─────┘ └────┬─────┘
│ │ │
└─────────────┴─────────────┘
│
▼
┌─────────────────┐
│ COMPUTE REWARD │
│ R = R_rm - β*KL │
└────────┬────────┘
│
▼
┌─────────────────┐
│ COMPUTE GAE │
│ advantages │
└────────┬────────┘
│
▼
┌─────────────────┐
│ PPO UPDATE │
│ actor, critic │
└─────────────────┘
The Reward Calculation
The reward for each response combines:
R_total = R_reward_model - β * KL(π_actor || π_reference)
R_reward_model: Score from reward model (trained on human preferences)
KL penalty: Prevents “reward hacking”
Without KL penalty, the model might find degenerate solutions:
- Repeating phrases that game the reward model
- Producing unnatural but high-scoring outputs
- Catastrophic forgetting of language capabilities
Why Reference Model?
The reference model is a frozen copy of the initial policy. It serves as an anchor:
Without reference:
Actor → "AMAZING! INCREDIBLE! BEST EVER!" (reward hacks)
With reference:
Actor → Natural response similar to reference
If too different → KL penalty reduces total reward
KL divergence measures how different the actor’s distribution is from the reference:
KL(π_actor || π_ref) = Σ π_actor(token) * log(π_actor(token) / π_ref(token))
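In implementations (including the scripts later in this chapter), this is usually estimated per sampled token from log-probabilities rather than summed over the whole vocabulary. A minimal PyTorch sketch (tensor names like actor_logprobs are illustrative, not from a specific library):
import torch

def approx_kl_per_token(actor_logprobs: torch.Tensor,
                        ref_logprobs: torch.Tensor) -> torch.Tensor:
    """Per-token KL(actor || ref) estimated from the sampled tokens.

    Inputs are log-probabilities of the sampled tokens, shape (batch, seq_len).
    The plain log-ratio is the simplest estimator; it is unbiased in expectation
    but can be negative for individual tokens.
    """
    return actor_logprobs - ref_logprobs

def approx_kl_k3(actor_logprobs: torch.Tensor,
                 ref_logprobs: torch.Tensor) -> torch.Tensor:
    # A lower-variance, always non-negative estimator sometimes used instead.
    log_ratio = actor_logprobs - ref_logprobs
    return torch.exp(-log_ratio) - 1 + log_ratio

# Example: 2 responses, 4 sampled tokens each
actor_lp = torch.tensor([[-1.0, -0.8, -1.2, -0.9], [-0.5, -1.5, -1.0, -0.7]])
ref_lp = torch.tensor([[-1.1, -1.0, -1.0, -1.0], [-1.0, -1.0, -1.0, -1.0]])
print(approx_kl_per_token(actor_lp, ref_lp).mean())  # mean KL estimate over tokens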
Per-Token vs Per-Response Rewards
In practice, rewards can be assigned:
Per-response (most common):
- Reward model scores complete response
- Reward assigned to last token
- Other tokens get no RM score (only the per-token KL penalty; see the sketch below)
- GAE propagates signal backwards
Per-token (process reward):
- Each token gets a score
- More fine-grained signal
- Harder to obtain labels
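A minimal sketch of the per-response layout described above, assuming (batch, seq_len) tensors, one reward-model score per response, and the per-token KL estimate from earlier (all names are illustrative):
import torch

def build_token_rewards(rm_scores: torch.Tensor,         # (batch,) RM score per response
                        kl_per_token: torch.Tensor,      # (batch, seq_len) KL estimates
                        response_lengths: torch.Tensor,  # (batch,) generated tokens per response
                        kl_coef: float = 0.02) -> torch.Tensor:
    """Per-response scheme: RM score lands on the last generated token,
    the KL penalty applies to every token (padding masks omitted here)."""
    batch = kl_per_token.shape[0]
    rewards = -kl_coef * kl_per_token                    # KL penalty everywhere
    last_idx = (response_lengths - 1).clamp(min=0)       # position of the final token
    rewards[torch.arange(batch), last_idx] += rm_scores  # sparse RM signal
    return rewards

# Example: 2 responses padded to length 5
rewards = build_token_rewards(rm_scores=torch.tensor([0.8, 0.3]),
                              kl_per_token=torch.full((2, 5), 0.1),
                              response_lengths=torch.tensor([5, 3]))
print(rewards)   # RM score appears only at positions 4 and 2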
Memory Requirements
For a 7B parameter model with RLHF:
| Component | Memory (FP16) |
|---|---|
| Actor | 14 GB |
| Critic | 14 GB |
| Reward Model | 14 GB |
| Reference | 14 GB |
| Optimizer states | 56 GB |
| Activations | ~20 GB |
| Total | ~130 GB |
For 70B: multiply by 10 → ~1.3 TB!
This is why RLHF needs careful system design.
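A back-of-the-envelope calculator matching the table above (a rough sketch; the byte counts per parameter are rules of thumb, not exact numbers for any particular framework):
def rlhf_memory_gb(params_b: float,
                   weight_bytes: int = 2,             # FP16/BF16 weights
                   optim_bytes: int = 8,              # Adam m + v for the trained actor
                   activation_gb_per_b: float = 3.0   # very rough activation scaling
                   ) -> dict:
    """Rough RLHF memory estimate in GB for a model with params_b billion parameters."""
    weights = params_b * weight_bytes
    est = {
        "actor": weights, "critic": weights,
        "reward_model": weights, "reference": weights,
        "optimizer_states": params_b * optim_bytes,
        "activations": params_b * activation_gb_per_b,
    }
    est["total"] = sum(est.values())
    return est

for size_b in (7, 70):
    est = rlhf_memory_gb(size_b)
    print(f"{size_b}B: ~{est['total']:.0f} GB total "
          f"(4 x {est['actor']:.0f} GB weights, {est['optimizer_states']:.0f} GB optimizer states)")
For 7B this lands near the ~130 GB in the table; for 70B it comes out around 1.3 TB.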
Code Walkthrough
Script 1: rlhf_loop_pseudo.py
Pseudocode implementation of the RLHF loop:
- Shows exact data flow
- Demonstrates each computation
- Explains intermediate values
Script 2: reward_calculator.py
Implements reward calculation:
- Reward model scoring
- KL divergence computation
- Total reward with penalty
Common Questions
Q: Why not just fine-tune on high-reward responses?
Supervised fine-tuning on selected responses (rejection sampling) works, but:
- Wastes low-reward samples
- No gradient signal about “how bad” something is
- PPO makes more efficient use of data
Q: Can the critic share weights with the actor?
Yes! Common approaches:
- Separate critic: Full model, independent
- Shared backbone: Same transformer, different heads
- Value head: Small MLP on top of actor’s hidden states
Shared approaches save memory but may have optimization conflicts.
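A minimal sketch of the value-head variant in plain PyTorch (hidden_size and the module layout are illustrative, not a particular library's API):
import torch
import torch.nn as nn

class ValueHead(nn.Module):
    """Small MLP that maps the actor's hidden states to a scalar value per token."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size) from the shared transformer
        return self.net(hidden_states).squeeze(-1)   # (batch, seq_len) value estimates

# Usage sketch: run the shared backbone once, read values off the extra head
head = ValueHead(hidden_size=16)
hidden = torch.randn(2, 5, 16)   # stand-in for transformer outputs
print(head(hidden).shape)        # torch.Size([2, 5])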
Q: How is the reward model trained?
Before RLHF:
- Collect comparison data: “Response A is better than B”
- Train reward model with ranking loss
- Reward model learns human preferences
The reward model is then frozen during RLHF.
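The ranking loss is typically a pairwise (Bradley-Terry style) objective: push the score of the preferred response above the rejected one. A minimal sketch, assuming the reward model returns one scalar per sequence:
import torch
import torch.nn.functional as F

def pairwise_ranking_loss(chosen_scores: torch.Tensor,
                          rejected_scores: torch.Tensor) -> torch.Tensor:
    """Loss = -log sigmoid(r_chosen - r_rejected), averaged over the batch.

    Minimizing this pushes the reward model to score the preferred response higher.
    """
    return -F.logsigmoid(chosen_scores - rejected_scores).mean()

# Example with stand-in scores for a batch of 3 comparisons
chosen = torch.tensor([1.2, 0.3, 0.9])
rejected = torch.tensor([0.4, 0.5, -0.1])
print(pairwise_ranking_loss(chosen, rejected))   # smaller when chosen > rejected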
Try It Yourself
Exercise 1: Trace Data Flow
For a batch of 4 prompts with max response length 100:
- What are the tensor shapes at each stage?
- How many forward passes per training step?
- What’s the communication pattern?
Exercise 2: KL Penalty Tuning
The KL coefficient β controls the penalty:
- β too low: reward hacking
- β too high: no learning
Experiment (conceptually):
- What happens if β = 0?
- What happens if β = 10?
- How would you find the right β?
Exercise 3: Memory Optimization
You have 8× 80GB GPUs and want to train a 70B model with RLHF.
- What parallelism strategies would you use?
- Can you fit all 4 models?
- What trade-offs would you make?
Key Takeaways
- Four models, one loop - Actor, Critic, Reward, Reference
- KL penalty is crucial - Prevents reward hacking
- GAE for credit assignment - Propagates reward signal
- Memory is the bottleneck - 4× model weights minimum
- Reference stays frozen - Anchors the learning
The RLHF Equation
The complete PPO-RLHF objective:
L = E[
L^PPO(actor_params) # Policy improvement
- c₁ * L^VF(critic_params) # Value function loss
+ c₂ * Entropy(actor) # Exploration bonus
]
Where:
L^PPO = min(ratio * A, clip(ratio) * A)
L^VF = (V_predicted - R_observed)²
A = GAE(rewards, values)
rewards = R_reward_model - β * KL
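As a loss to minimize, that objective can be sketched roughly like this (flat per-token tensors; names such as logp_new and values_pred are illustrative):
import torch

def rlhf_ppo_loss(logp_new, logp_old, advantages, values_pred, returns,
                  entropy, clip_eps=0.2, c1=0.5, c2=0.01):
    """Combined PPO-RLHF loss to minimize (the negative of the objective above)."""
    ratio = (logp_new - logp_old).exp()
    unclipped = ratio * advantages
    clipped = ratio.clamp(1 - clip_eps, 1 + clip_eps) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()   # -L^PPO
    value_loss = ((values_pred - returns) ** 2).mean()    # L^VF
    return policy_loss + c1 * value_loss - c2 * entropy.mean()

# Stand-in tensors for 8 tokens
t = torch.randn(8)
loss = rlhf_ppo_loss(logp_new=t, logp_old=t.detach(), advantages=torch.randn(8),
                     values_pred=torch.randn(8), returns=torch.randn(8),
                     entropy=torch.rand(8))
print(loss)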
What’s Next?
In Chapter 14, we’ll explore RLHF System Architecture—how to efficiently orchestrate these models across GPUs with co-location, disaggregation, and hybrid approaches.
Further Reading
rlhf_loop_pseudo.py
The complete RLHF training loop in pseudocode
This script shows the exact computation flow of one RLHF training step, making it easy to understand what happens and when.
What It Does
- Simulates all four models (Actor, Critic, Reward, Reference)
- Walks through each step of the training loop
- Shows tensor shapes and intermediate values
- Demonstrates the complete PPO update
Run It
python tutorial/part4-rlhf/chapter13-rlhf-flow/scripts/rlhf_loop_pseudo.py
Example Output
=== RLHF Training Loop Demo ===
Step 1: Sample prompts
Batch size: 4
Prompt shapes: (4, 64) tokens
Step 2: Generate responses (Actor)
Actor forward pass...
Generated tokens: (4, 128)
Actor logits: (4, 128, 50257)
Old log probs: (4, 128)
Step 3: Score responses (Reward Model)
Reward model forward pass...
Scores: [0.73, 0.45, 0.91, 0.62]
Step 4: Compute KL penalty (Reference)
Reference forward pass...
Reference log probs: (4, 128)
KL divergence per token: (4, 128)
Mean KL: 0.23
Step 5: Compute total rewards
reward = reward_model_score - β * KL
Total rewards: [0.50, 0.28, 0.75, 0.41]
Step 6: Compute advantages (Critic + GAE)
Critic forward pass...
Values: (4, 128)
GAE advantages: (4, 128)
Step 7: PPO update
Ratio = exp(new_log_prob - old_log_prob)
Clipped ratio: clip(ratio, 0.8, 1.2)
Actor loss: -0.042
Critic loss: 0.156
Update complete!
The Core Loop
for batch in dataloader:
# 1. Generate
responses, old_logprobs = actor.generate(batch.prompts)
# 2. Score
rewards = reward_model(batch.prompts, responses)
# 3. KL penalty
ref_logprobs = reference(batch.prompts, responses)
kl = old_logprobs - ref_logprobs
rewards = rewards - beta * kl
# 4. Advantages
values = critic(batch.prompts, responses)
advantages = gae(rewards, values)
# 5. PPO update
new_logprobs = actor(batch.prompts, responses)
ratio = (new_logprobs - old_logprobs).exp()
    actor_loss = -torch.min(ratio * advantages,
                            ratio.clamp(0.8, 1.2) * advantages).mean()
    critic_loss = ((values - rewards) ** 2).mean()
    # 6. Backprop
    optimizer.zero_grad()
    (actor_loss + critic_loss).backward()
    optimizer.step()
Source Code
#!/usr/bin/env python3
"""
RLHF Training Loop Pseudocode
This script demonstrates the complete RLHF training loop with
detailed comments explaining each step.
This is PSEUDOCODE-style: the model calls are stubbed with placeholder
functions, so the script runs end-to-end, but all numbers are simulated.
It's meant to illustrate the data flow and computations involved.
Usage:
python rlhf_loop_pseudo.py
"""
from dataclasses import dataclass
from typing import List, Optional
import random
import math
@dataclass
class Prompt:
"""A training prompt."""
text: str
tokens: List[int]
@dataclass
class Response:
"""A generated response with metadata."""
tokens: List[int]
log_probs: List[float] # From actor
ref_log_probs: List[float] # From reference
values: List[float] # From critic
reward_score: float # From reward model
@dataclass
class Experience:
"""One token of experience for PPO."""
token: int
log_prob: float
ref_log_prob: float
value: float
reward: float
advantage: float
def rlhf_training_step(
prompts: List[Prompt],
actor, # The policy model being trained
critic, # The value function
reward_model, # Frozen reward model
reference, # Frozen reference policy
kl_coef: float = 0.02,
gamma: float = 1.0,
lam: float = 0.95,
clip_epsilon: float = 0.2,
) -> dict:
"""
One step of RLHF training.
This function shows the complete data flow through all four models.
"""
print("=" * 70)
print(" RLHF TRAINING STEP")
print("=" * 70)
# =========================================================================
# STEP 1: Generate Responses (Actor)
# =========================================================================
print("\n[Step 1] GENERATE RESPONSES")
print("-" * 50)
responses = []
for prompt in prompts:
# Generate response from actor
# In reality: autoregressive generation with temperature sampling
response_tokens = generate_response(actor, prompt)
# Get log probabilities from actor
actor_log_probs = get_log_probs(actor, prompt.tokens, response_tokens)
responses.append(Response(
tokens=response_tokens,
log_probs=actor_log_probs,
ref_log_probs=[], # Filled in step 3
values=[], # Filled in step 4
reward_score=0, # Filled in step 2
))
print(f" Generated {len(response_tokens)} tokens for prompt")
# =========================================================================
# STEP 2: Score Responses (Reward Model)
# =========================================================================
print("\n[Step 2] SCORE RESPONSES (Reward Model)")
print("-" * 50)
for i, (prompt, response) in enumerate(zip(prompts, responses)):
# Get reward score for complete response
# In reality: forward pass through reward model
full_sequence = prompt.tokens + response.tokens
response.reward_score = score_response(reward_model, full_sequence)
print(f" Response {i}: reward = {response.reward_score:.3f}")
# =========================================================================
# STEP 3: Compute KL Penalty (Reference Model)
# =========================================================================
print("\n[Step 3] COMPUTE KL PENALTY (Reference)")
print("-" * 50)
total_kl = 0
for prompt, response in zip(prompts, responses):
# Get reference log probabilities
response.ref_log_probs = get_log_probs(
reference, prompt.tokens, response.tokens
)
# Compute per-token KL divergence
kl_per_token = []
for actor_lp, ref_lp in zip(response.log_probs, response.ref_log_probs):
# KL = exp(actor_lp) * (actor_lp - ref_lp)
# Simplified: just the log ratio for penalty
kl = actor_lp - ref_lp
kl_per_token.append(kl)
avg_kl = sum(kl_per_token) / len(kl_per_token)
total_kl += avg_kl
avg_kl = total_kl / len(responses)
print(f" Average KL divergence: {avg_kl:.4f}")
# =========================================================================
# STEP 4: Compute Values (Critic)
# =========================================================================
print("\n[Step 4] COMPUTE VALUES (Critic)")
print("-" * 50)
for prompt, response in zip(prompts, responses):
# Get value estimates for each token position
# In reality: forward pass through critic
response.values = get_values(critic, prompt.tokens, response.tokens)
print(f" Values computed: mean={sum(response.values)/len(response.values):.3f}")
# =========================================================================
# STEP 5: Compute Rewards with KL Penalty
# =========================================================================
print("\n[Step 5] COMPUTE REWARDS WITH KL PENALTY")
print("-" * 50)
all_experiences = []
for prompt, response in zip(prompts, responses):
experiences = []
for t in range(len(response.tokens)):
# Per-token KL penalty
kl_penalty = kl_coef * (response.log_probs[t] - response.ref_log_probs[t])
# Reward: only at last token, minus KL at every token
if t == len(response.tokens) - 1:
token_reward = response.reward_score - kl_penalty
else:
token_reward = -kl_penalty # Just KL penalty for non-final tokens
experiences.append(Experience(
token=response.tokens[t],
log_prob=response.log_probs[t],
ref_log_prob=response.ref_log_probs[t],
value=response.values[t],
reward=token_reward,
advantage=0, # Computed in step 6
))
all_experiences.append(experiences)
final_reward = experiences[-1].reward
print(f" Final token reward: {final_reward:.3f} "
f"(score={response.reward_score:.3f}, kl_penalty included)")
# =========================================================================
# STEP 6: Compute GAE Advantages
# =========================================================================
print("\n[Step 6] COMPUTE GAE ADVANTAGES")
print("-" * 50)
for experiences in all_experiences:
# GAE computation (backwards)
gae = 0
for t in reversed(range(len(experiences))):
exp = experiences[t]
if t == len(experiences) - 1:
next_value = 0 # Terminal state
else:
next_value = experiences[t + 1].value
# TD error
delta = exp.reward + gamma * next_value - exp.value
# GAE
gae = delta + gamma * lam * gae
exp.advantage = gae
# Normalize advantages
advantages = [e.advantage for e in experiences]
mean_adv = sum(advantages) / len(advantages)
std_adv = (sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)) ** 0.5
for exp in experiences:
exp.advantage = (exp.advantage - mean_adv) / (std_adv + 1e-8)
print(f" Advantages computed and normalized")
# =========================================================================
# STEP 7: PPO Update
# =========================================================================
print("\n[Step 7] PPO UPDATE")
print("-" * 50)
# Flatten all experiences
flat_experiences = [exp for exps in all_experiences for exp in exps]
# Compute PPO losses
policy_losses = []
value_losses = []
clip_fractions = []
for exp in flat_experiences:
# New log probability (after potential update)
# In reality: forward pass through updated actor
new_log_prob = exp.log_prob # Placeholder
# Probability ratio
ratio = math.exp(new_log_prob - exp.log_prob)
# Clipped objective
unclipped = ratio * exp.advantage
clipped = max(min(ratio, 1 + clip_epsilon), 1 - clip_epsilon) * exp.advantage
policy_loss = -min(unclipped, clipped)
policy_losses.append(policy_loss)
# Value loss
# In reality: new value prediction
new_value = exp.value # Placeholder
value_loss = (new_value - (exp.reward + gamma * 0)) ** 2 # Simplified
value_losses.append(value_loss)
# Track clipping
if abs(ratio - 1) > clip_epsilon:
clip_fractions.append(1)
else:
clip_fractions.append(0)
avg_policy_loss = sum(policy_losses) / len(policy_losses)
avg_value_loss = sum(value_losses) / len(value_losses)
avg_clip_frac = sum(clip_fractions) / len(clip_fractions)
print(f" Policy loss: {avg_policy_loss:.4f}")
print(f" Value loss: {avg_value_loss:.4f}")
print(f" Clip fraction: {avg_clip_frac:.2%}")
# =========================================================================
# Summary
# =========================================================================
print("\n" + "=" * 70)
print(" STEP SUMMARY")
print("=" * 70)
print(f"""
Models used:
- Actor: Generated {sum(len(r.tokens) for r in responses)} total tokens
- Reward: Scored {len(responses)} responses
- Reference: Computed KL for {len(responses)} responses
- Critic: Estimated values for {sum(len(r.tokens) for r in responses)} tokens
Losses:
- Policy loss: {avg_policy_loss:.4f}
- Value loss: {avg_value_loss:.4f}
KL penalty:
- Average KL: {avg_kl:.4f}
- KL coefficient: {kl_coef}
- Total KL penalty: {avg_kl * kl_coef:.4f}
""")
return {
'policy_loss': avg_policy_loss,
'value_loss': avg_value_loss,
'kl': avg_kl,
'clip_fraction': avg_clip_frac,
}
# =============================================================================
# Placeholder functions (would be real model calls in practice)
# =============================================================================
def generate_response(actor, prompt: Prompt) -> List[int]:
"""Generate response tokens from actor."""
# Simulated: random tokens
length = random.randint(10, 30)
return [random.randint(0, 999) for _ in range(length)]
def get_log_probs(model, prompt_tokens: List[int],
response_tokens: List[int]) -> List[float]:
"""Get log probabilities from model."""
# Simulated: random log probs
return [random.uniform(-3, -0.1) for _ in response_tokens]
def score_response(reward_model, tokens: List[int]) -> float:
"""Get reward score from reward model."""
# Simulated: random score
return random.uniform(-1, 1)
def get_values(critic, prompt_tokens: List[int],
response_tokens: List[int]) -> List[float]:
"""Get value estimates from critic."""
# Simulated: decreasing values
n = len(response_tokens)
return [0.5 * (n - i) / n for i in range(n)]
def main():
print("╔" + "═" * 68 + "╗")
print("║" + " RLHF TRAINING LOOP DEMONSTRATION".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Create sample prompts
prompts = [
Prompt("What is the capital of France?", [1, 2, 3, 4, 5]),
Prompt("Explain quantum computing.", [6, 7, 8, 9]),
Prompt("Write a haiku about programming.", [10, 11, 12, 13, 14]),
Prompt("What is machine learning?", [15, 16, 17]),
]
print(f"\nBatch size: {len(prompts)} prompts")
# Run one training step
stats = rlhf_training_step(
prompts=prompts,
actor=None, # Placeholder
critic=None,
reward_model=None,
reference=None,
kl_coef=0.02,
)
# Explain the process
print("\n" + "=" * 70)
print(" WHAT JUST HAPPENED")
print("=" * 70)
print("""
This simulated one complete RLHF training step:
1. GENERATION: Actor generated responses for each prompt
2. SCORING: Reward model evaluated response quality
3. KL COMPUTATION: Reference model computed divergence penalty
4. VALUE ESTIMATION: Critic predicted expected rewards
5. ADVANTAGE COMPUTATION: GAE combined rewards and values
6. PPO UPDATE: Actor and critic weights updated
In production, this happens with:
- Real neural network forward/backward passes
- GPU tensor operations
- Distributed training across multiple devices
- Gradient accumulation and synchronization
The key insight: RLHF is just PPO with:
- Reward from a learned reward model
- KL penalty to stay close to reference
- Four models instead of just actor-critic
""")
if __name__ == "__main__":
main()
reward_calculator.py
Understand reward calculation with KL penalty
This script demonstrates how the total reward in RLHF is computed from the reward model score and KL penalty.
What It Does
- Shows raw reward model scores
- Computes KL divergence between actor and reference
- Applies the KL penalty with different β values
- Demonstrates why the penalty prevents reward hacking
Run It
python tutorial/part4-rlhf/chapter13-rlhf-flow/scripts/reward_calculator.py
Example Output
=== RLHF Reward Calculator ===
Response: "This is a great product! I highly recommend it!"
Reward Model Score: 0.85 (high quality response)
KL Divergence Calculation:
Actor log prob for each token:
"This": -2.3, "is": -1.1, "a": -0.8, ...
Reference log prob for each token:
"This": -2.1, "is": -1.0, "a": -0.9, ...
KL per token = actor_logp - ref_logp
"This": -0.2, "is": -0.1, "a": +0.1, ...
Total KL: 0.45 (actor has diverged from reference)
Total Reward with Different β:
β = 0.0: R = 0.85 - 0.0 * 0.45 = 0.85
β = 0.1: R = 0.85 - 0.1 * 0.45 = 0.805
β = 0.5: R = 0.85 - 0.5 * 0.45 = 0.625
β = 1.0: R = 0.85 - 1.0 * 0.45 = 0.40
Observation: Higher β penalizes divergence more heavily.
Why KL Penalty Matters
Without penalty (β=0):
Actor learns to say "AMAZING! INCREDIBLE!" for everything
Reward model gives high scores
But output is unnatural
With penalty (β=0.1):
Actor stays close to reference
Must improve while remaining natural
Better quality outputs
The Formula
def compute_reward(response, actor, reference, reward_model, beta):
# Get reward model score
rm_score = reward_model(response)
# Compute KL divergence
actor_logp = actor.log_prob(response)
ref_logp = reference.log_prob(response)
kl = (actor_logp - ref_logp).sum()
# Total reward with penalty
total_reward = rm_score - beta * kl
return total_reward
Source Code
#!/usr/bin/env python3
"""
RLHF Reward Calculator
This script demonstrates how rewards are computed in RLHF:
- Reward model scoring
- KL divergence penalty
- Combined reward signal
Usage:
python reward_calculator.py
"""
import argparse
import math
from typing import List, Tuple
def compute_kl_divergence(actor_log_probs: List[float],
ref_log_probs: List[float]) -> List[float]:
"""
Compute per-token KL divergence.
KL(actor || ref) = Σ p_actor * log(p_actor / p_ref)
= Σ p_actor * (log p_actor - log p_ref)
    Since we have log probs, a common practical approximation is to use the
    log-probability difference directly; the full per-token KL would weight
    this difference by the actor's probability.
"""
kl_per_token = []
for actor_lp, ref_lp in zip(actor_log_probs, ref_log_probs):
# Approximate KL using log prob difference
# Full KL would be: exp(actor_lp) * (actor_lp - ref_lp)
# Common approximation: just the difference (works well in practice)
kl = actor_lp - ref_lp
kl_per_token.append(kl)
return kl_per_token
def compute_rewards(
reward_model_score: float,
actor_log_probs: List[float],
ref_log_probs: List[float],
kl_coef: float = 0.02,
reward_at_end_only: bool = True,
) -> Tuple[List[float], dict]:
"""
Compute per-token rewards with KL penalty.
Args:
reward_model_score: Score from reward model (typically for full response)
actor_log_probs: Log probabilities from actor for each token
ref_log_probs: Log probabilities from reference for each token
kl_coef: Coefficient for KL penalty (β in papers)
reward_at_end_only: If True, RM score only at last token
Returns:
List of rewards for each token
Dictionary with stats
"""
num_tokens = len(actor_log_probs)
# Compute KL divergence
kl_per_token = compute_kl_divergence(actor_log_probs, ref_log_probs)
# Compute rewards
rewards = []
for t in range(num_tokens):
kl_penalty = kl_coef * kl_per_token[t]
if reward_at_end_only:
# RM score only at last token
if t == num_tokens - 1:
r = reward_model_score - kl_penalty
else:
r = -kl_penalty # Only penalty
else:
# RM score distributed across tokens
r = reward_model_score / num_tokens - kl_penalty
rewards.append(r)
stats = {
'total_kl': sum(kl_per_token),
'avg_kl': sum(kl_per_token) / len(kl_per_token),
'total_kl_penalty': sum(kl_coef * kl for kl in kl_per_token),
'reward_model_score': reward_model_score,
'total_reward': sum(rewards),
}
return rewards, stats
def visualize_rewards(rewards: List[float], kl_per_token: List[float],
kl_coef: float, rm_score: float):
"""Visualize reward distribution across tokens."""
print("\n" + "=" * 70)
print(" TOKEN-LEVEL REWARD BREAKDOWN")
print("=" * 70)
print(f"\n{'Token':<8} {'KL':<12} {'KL Penalty':<12} {'RM Contrib':<12} {'Reward':<12}")
print("-" * 60)
num_tokens = len(rewards)
for t in range(num_tokens):
kl = kl_per_token[t]
penalty = kl_coef * kl
rm_contrib = rm_score if t == num_tokens - 1 else 0
print(f"{t:<8} {kl:>+.4f} {-penalty:>+.4f} {rm_contrib:>+.4f} {rewards[t]:>+.4f}")
print("-" * 60)
print(f"{'Total':<8} {sum(kl_per_token):>+.4f} {-kl_coef*sum(kl_per_token):>+.4f} "
f"{rm_score:>+.4f} {sum(rewards):>+.4f}")
def demonstrate_kl_penalty_effect():
"""Show how KL coefficient affects learning."""
print("\n" + "=" * 70)
print(" KL COEFFICIENT EFFECT")
print("=" * 70)
# Simulated response with moderate divergence
    actor_lps = [-0.6, -0.8, -0.5, -1.0, -0.7]  # Actor log probs (more confident than reference)
    ref_lps = [-1.1, -1.0, -1.0, -1.2, -1.0]    # Reference log probs
rm_score = 0.5 # Positive reward from RM
print("\nScenario: Response with RM score = 0.5")
print("Actor is somewhat divergent from reference\n")
kl_values = [0.0, 0.01, 0.02, 0.05, 0.1, 0.5]
print(f"{'KL Coef':<10} {'Total KL Penalty':<18} {'Net Reward':<12} {'Effect':<20}")
print("-" * 60)
for kl_coef in kl_values:
rewards, stats = compute_rewards(rm_score, actor_lps, ref_lps, kl_coef)
net_reward = stats['total_reward']
if net_reward > rm_score * 0.8:
effect = "Weak penalty"
elif net_reward > 0:
effect = "Moderate penalty"
elif net_reward > -0.5:
effect = "Strong penalty"
else:
effect = "Overwhelming penalty"
print(f"{kl_coef:<10} {stats['total_kl_penalty']:>+.4f} {net_reward:>+.4f} {effect}")
print("""
Interpretation:
β ≈ 0.00: No penalty, risk of reward hacking
β ≈ 0.02: Typical value, balanced
β ≈ 0.10: Strong regularization, slower learning
β ≈ 0.50: KL dominates, almost no RM signal
""")
def demonstrate_kl_scenarios():
"""Show different KL divergence scenarios."""
print("\n" + "=" * 70)
print(" KL DIVERGENCE SCENARIOS")
print("=" * 70)
kl_coef = 0.02
scenarios = [
("Low divergence (similar to reference)", [-1.0, -1.1, -0.9], [-1.0, -1.0, -1.0]),
("High divergence (very different)", [-0.5, -2.0, -0.3], [-1.5, -0.8, -1.2]),
("More confident than reference", [-0.2, -0.3, -0.2], [-1.0, -1.0, -1.0]),
("Less confident than reference", [-2.0, -2.5, -2.0], [-1.0, -1.0, -1.0]),
]
rm_score = 0.5
for name, actor_lps, ref_lps in scenarios:
kl_per_token = compute_kl_divergence(actor_lps, ref_lps)
rewards, stats = compute_rewards(rm_score, actor_lps, ref_lps, kl_coef)
print(f"\n{name}:")
print(f" Actor log probs: {actor_lps}")
print(f" Ref log probs: {ref_lps}")
print(f" Per-token KL: {[f'{k:.2f}' for k in kl_per_token]}")
print(f" Total KL: {stats['total_kl']:.4f}")
print(f" KL penalty: {stats['total_kl_penalty']:.4f}")
print(f" Net reward: {stats['total_reward']:.4f}")
def main():
parser = argparse.ArgumentParser(description="RLHF Reward Calculator")
parser.add_argument("--kl-coef", "-k", type=float, default=0.02,
help="KL penalty coefficient")
parser.add_argument("--rm-score", "-r", type=float, default=0.5,
help="Reward model score")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " RLHF REWARD CALCULATOR".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Example response
    actor_log_probs = [-0.9, -0.7, -1.0, -0.8, -0.9, -1.0, -0.6, -0.8]
    ref_log_probs = [-1.0, -1.0, -1.2, -1.0, -1.0, -1.1, -1.0, -1.0]
print(f"\nConfiguration:")
print(f" KL coefficient (β): {args.kl_coef}")
print(f" Reward model score: {args.rm_score}")
print(f" Response length: {len(actor_log_probs)} tokens")
# Compute rewards
rewards, stats = compute_rewards(
args.rm_score, actor_log_probs, ref_log_probs, args.kl_coef
)
# Compute KL for visualization
kl_per_token = compute_kl_divergence(actor_log_probs, ref_log_probs)
# Visualize
visualize_rewards(rewards, kl_per_token, args.kl_coef, args.rm_score)
# Show stats
print("\n" + "=" * 70)
print(" SUMMARY STATISTICS")
print("=" * 70)
print(f"""
Reward Model Score: {stats['reward_model_score']:>+.4f}
Total KL Divergence: {stats['total_kl']:>+.4f}
Total KL Penalty: {stats['total_kl_penalty']:>+.4f}
Net Total Reward: {stats['total_reward']:>+.4f}
Reward Composition:
RM contribution: {stats['reward_model_score']:>+.4f} (at last token)
KL penalty: {-stats['total_kl_penalty']:>+.4f} (distributed)
────────────────────────────
Net reward: {stats['total_reward']:>+.4f}
""")
# Demonstrate KL effects
demonstrate_kl_penalty_effect()
demonstrate_kl_scenarios()
# Key insights
print("\n" + "=" * 70)
print(" KEY INSIGHTS")
print("=" * 70)
print("""
1. REWARD = RM_SCORE - β × KL
The KL penalty prevents the actor from diverging too far from
the reference model, avoiding "reward hacking".
2. KL IS COMPUTED PER-TOKEN
Each token's probability is compared to the reference.
This gives fine-grained control over divergence.
3. RM SCORE IS TYPICALLY END-ONLY
The reward model scores the complete response.
This score appears only at the last token.
GAE propagates it backwards during training.
4. β IS A CRITICAL HYPERPARAMETER
Too low: Reward hacking, degenerate solutions
Too high: Learning is too slow, policy doesn't change
Typical values: 0.01 - 0.05
5. NEGATIVE REWARDS ARE OK
The policy gradient cares about relative advantages,
not absolute reward values.
""")
if __name__ == "__main__":
main()
Chapter 14: RLHF System Architecture
“The difference between a working RLHF system and an efficient one is whether you can fit four models on your GPUs.”
Learning Objectives
By the end of this chapter, you will be able to:
- Compare co-located vs disaggregated RLHF architectures
- Explain weight update mechanisms between training and inference engines
- Understand the hybrid engine approach (verl)
- Design an RLHF system for a given hardware setup
Prerequisites
- Completed Chapters 12-13 (RL Fundamentals, RLHF Flow)
- Understanding of distributed training (Part II)
- Familiarity with inference systems (Part III)
Concept Overview
The RLHF Systems Challenge
RLHF requires:
- Generation (inference): Actor generates responses
- Scoring (inference): Reward model evaluates
- Training (training): PPO updates actor and critic
These have different optimal configurations:
- Generation: Large batch, high throughput
- Training: Gradient synchronization, memory for optimizer
Naively running both on the same GPUs wastes resources.
Architecture Options
| Architecture | Description | Pros | Cons |
|---|---|---|---|
| Co-located | All models on same GPUs | Simple, no transfer | Memory constrained |
| Disaggregated | Separate GPU groups | Optimized per workload | Network transfer |
| Hybrid | Smart resource sharing | Best utilization | Complex implementation |
Architecture 1: Co-located (slime, verl)
All models share the same GPUs, swapping memory between phases.
GPU 0-7 (same GPUs for everything):
Phase 1 - Generation:
┌────────────────────────────────────────────┐
│ Actor weights + KV cache for inference │
│ (Reference and Reward also loaded) │
└────────────────────────────────────────────┘
Phase 2 - Training:
┌────────────────────────────────────────────┐
│ Actor + Critic weights + gradients + │
│ optimizer states + activations │
└────────────────────────────────────────────┘
Memory swapping: After generation, KV cache is freed. Optimizer states loaded.
Advantage: No network transfer for weight updates. Disadvantage: Cannot parallelize generation and training.
Architecture 2: Disaggregated (OpenRLHF)
Separate GPU groups for different tasks.
Training Cluster (GPUs 0-31): Inference Cluster (GPUs 32-63):
┌───────────────────────────┐ ┌───────────────────────────┐
│ Actor training │ │ Actor inference │
│ Critic training │ │ (generation) │
│ Gradients + optimizer │ ◄────── │ │
└───────────────────────────┘ weights └───────────────────────────┘
│ ▲
│ ┌───────────────────────────┐
│ │ Reward Model │
└─────────────►│ (scoring) │
prompts └───────────────────────────┘
Weight transfer: After training, send updated weights to inference cluster.
Advantage: Generation and training can overlap. Disadvantage: Network bandwidth for weight transfer.
Architecture 3: Hybrid Engine (verl)
verl’s innovation: Keep weights in GPU memory, switch between training and inference modes.
Same GPUs, Different Modes:
Training Mode:
┌────────────────────────────────────────────┐
│ FSDP sharded weights │
│ Full gradients and optimizer states │
│ Backpropagation-ready tensors │
└────────────────────────────────────────────┘
│
│ mode switch (no data movement!)
▼
Inference Mode:
┌────────────────────────────────────────────┐
│ Same weights, viewed for inference │
│ KV cache allocated │
│ No gradient tracking │
└────────────────────────────────────────────┘
Key insight: Tensor memory is reused between modes. Only metadata changes.
Weight Update Mechanisms
How to get updated weights from training to inference?
Method 1: Disk-based (simplest)
# After training
torch.save(actor.state_dict(), "checkpoint.pt")
# Inference engine loads
actor.load_state_dict(torch.load("checkpoint.pt"))
- Pros: Works always, supports different cluster sizes
- Cons: I/O bound, slow for large models
Method 2: NCCL-based (disaggregated)
# Training rank 0 gathers full weights
full_weights = gather_weights(training_group)
# Send to inference rank 0
dist.send(full_weights, dst=inference_rank_0)
# Inference rank 0 broadcasts
dist.broadcast(full_weights, src=0, group=inference_group)
- Pros: Fast with good network
- Cons: Requires connectivity between clusters
Method 3: Shared memory (co-located)
# verl approach: Share GPU memory via CUDA IPC
handle = tensor._cuda_ipc_handle() # Get memory handle
serialized = serialize(handle) # Not the data, just the pointer!
# Other process
tensor = deserialize(serialized) # Reconstructs tensor from handle
# tensor points to SAME GPU memory - zero copy!
- Pros: Zero data movement
- Cons: Only works on same GPU
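The snippet above is schematic. As a rough, runnable illustration of the same mechanism, torch.multiprocessing shares CUDA tensors between processes by passing an IPC handle rather than copying the data. This is a generic PyTorch sketch, not verl's actual code path; it needs a GPU, the spawn start method, and the producer must keep the tensor alive while the consumer uses it:
import torch
import torch.multiprocessing as mp

def consumer(inbox, outbox):
    # The tensor received here is backed by the producer's GPU allocation:
    # only an IPC handle crossed the process boundary, not the data.
    weights = inbox.get()
    weights += 1.0                      # writes into the shared GPU memory
    torch.cuda.synchronize()
    outbox.put("done")

if __name__ == "__main__":
    mp.set_start_method("spawn")        # required when sharing CUDA tensors
    assert torch.cuda.is_available(), "this sketch needs a GPU"
    inbox, outbox = mp.Queue(), mp.Queue()
    weights = torch.zeros(4, device="cuda")   # stand-in for actor weights
    proc = mp.Process(target=consumer, args=(inbox, outbox))
    proc.start()
    inbox.put(weights)                  # sends a handle (a few bytes), zero copy
    assert outbox.get() == "done"
    print(weights)                      # tensor([1., 1., 1., 1.]) - same memory was updated
    proc.join()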
The verl Weight Update Deep Dive
verl’s weight update is elegant:
- Training finishes: Actor weights are FSDP-sharded across GPUs
- Gather to full: FSDP FULL_STATE_DICT gathers the sharded weights to rank 0
- Serialize handle: Create CUDA IPC handle (just a pointer)
- Share handle: Send handle to inference engine (tiny data!)
- Reconstruct tensor: Inference engine creates tensor from handle
- Same memory: Both engines now reference identical GPU memory
Training Engine Inference Engine
│ │
│ FSDP gathers │
▼ │
[Full tensor on GPU] │
│ │
│ Get IPC handle │
▼ │
[Handle: ptr=0x7f.., size=1GB] │
│ │
│ Send handle (few bytes!) │
└───────────────────────────────────►│
│ Reconstruct from handle
▼
[Same GPU memory, new tensor object]
Memory Timeline in Hybrid Engine
Time →
Phase 1: Generation
┌─────────────────────────────────────────────────────────────────┐
│ GPU Memory: [Actor weights][KV Cache][Reward Model][Reference] │
└─────────────────────────────────────────────────────────────────┘
Phase 2: Prepare for Training
┌─────────────────────────────────────────────────────────────────┐
│ GPU Memory: [Actor weights][Critic weights][Free space...] │
│ (KV cache freed, RM and Ref offloaded) │
└─────────────────────────────────────────────────────────────────┘
Phase 3: Training
┌─────────────────────────────────────────────────────────────────┐
│ GPU Memory: [Actor][Critic][Actor grads][Critic grads] │
│ [Adam states][Activations] │
└─────────────────────────────────────────────────────────────────┘
Phase 4: Back to Generation
┌─────────────────────────────────────────────────────────────────┐
│ GPU Memory: [Updated Actor][KV Cache][RM][Ref] │
│ (optimizer states offloaded) │
└─────────────────────────────────────────────────────────────────┘
Comparison: verl vs OpenRLHF vs slime
| Feature | verl | OpenRLHF | slime |
|---|---|---|---|
| Architecture | Hybrid | Disaggregated | Co-located |
| Weight transfer | IPC handles | NCCL/Disk | Disk or tensor |
| Generation engine | Custom | vLLM | SGLang |
| Training engine | Custom SPMD | Ray + DeepSpeed | Megatron |
| Memory efficiency | High | Medium | High |
| Scaling | Complex | Simpler | Complex |
Code Walkthrough
Script 1: weight_update_demo.py
Demonstrates weight update mechanisms:
- Simulates different transfer methods
- Compares overhead
Script 2: memory_timeline.py
Visualizes memory usage across RLHF phases:
- Shows peak memory per phase
- Identifies bottlenecks
System Design Guidelines
For Small Models (7B)
Single 8-GPU node:
- Co-located approach
- All 4 models fit with TP=1
- Simple implementation
For Medium Models (70B)
Multi-node setup:
- Disaggregated or Hybrid
- Actor/Critic: TP=8, PP=2 (16 GPUs)
- Reward/Reference: TP=8 (8 GPUs each)
- Total: 32+ GPUs
For Large Models (400B+)
Large cluster:
- Definitely disaggregated
- Separate clusters for training and inference
- Async weight updates
- Consider gradient checkpointing
Try It Yourself
Exercise 1: Memory Planning
For a 70B model RLHF setup:
- Calculate memory per GPU for co-located (8 GPUs)
- Calculate memory per GPU for disaggregated (32 GPUs)
- Which fits? What trade-offs?
Exercise 2: Weight Transfer Bandwidth
If weight transfer takes 10 seconds for 140GB:
- What’s the transfer bandwidth?
- How does this compare to training iteration time?
- Can we overlap transfer with anything?
Exercise 3: Design an RLHF System
You have: 64 H100 GPUs across 8 nodes. Model: 70B parameters.
Design:
- Training parallelism (TP, PP, DP)
- Inference parallelism
- Weight update mechanism
- Memory budget per GPU
Key Takeaways
- Architecture choice depends on scale - Co-located for small, disaggregated for large
- Weight transfer is critical - IPC handles enable zero-copy on same GPU
- Memory phases are distinct - Generation and training have different needs
- Hybrid engines maximize utilization - Same GPUs, different modes
- Real systems combine techniques - No one-size-fits-all
The RLHF Systems Maturity Model
Level 1: Naive Co-location
└─► All models loaded always
└─► Works but memory inefficient
Level 2: Smart Co-location
└─► Memory swapping between phases
└─► Better utilization
Level 3: Disaggregated
└─► Separate clusters
└─► Network weight transfer
Level 4: Hybrid Engine
└─► Shared memory, mode switching
└─► Minimal overhead
Level 5: Async Hybrid
└─► Overlapped generation and training
└─► Maximum throughput
What’s Next?
Congratulations! You’ve completed the ML Systems Tutorial. You now understand:
- Distributed training primitives
- Parallelism strategies
- LLM inference systems
- RLHF architecture
For continued learning:
- Study verl, OpenRLHF, or trl source code
- Implement a simple RLHF system
- Contribute to open-source ML systems projects
Further Reading
weight_update_demo.py
Compare different weight transfer mechanisms for RLHF
This script demonstrates how weights are transferred from training to inference engines in different RLHF architectures.
What It Does
- Simulates three weight transfer methods
- Measures transfer time and memory usage
- Shows trade-offs between approaches
Run It
python tutorial/part4-rlhf/chapter14-rlhf-architecture/scripts/weight_update_demo.py
Example Output
=== Weight Transfer Mechanisms ===
Model size: 70B parameters = 140 GB (FP16)
Method 1: Disk-based Transfer
Write to disk: 28.0 seconds (5 GB/s SSD)
Read from disk: 28.0 seconds
Total: 56.0 seconds
Note: Works across any hardware configuration
Method 2: NCCL Transfer (Network)
Gather weights on training rank 0: 2.1 seconds
Transfer to inference cluster: 5.6 seconds (25 GB/s InfiniBand)
Broadcast to inference ranks: 2.1 seconds
Total: 9.8 seconds
Note: Requires network connectivity between clusters
Method 3: CUDA IPC (Same GPU)
Get IPC handle: 0.001 seconds
Serialize handle: 0.001 seconds
Reconstruct tensor: 0.001 seconds
Total: 0.003 seconds (!)
Note: Zero data movement - same memory, new reference
Comparison:
Disk: 56,000 ms (100% transfer)
NCCL: 9,800 ms (18% of disk)
CUDA IPC: 3 ms (0.005% of disk)
The verl approach (CUDA IPC) achieves near-zero overhead!
The Key Insight
Disk transfer:
[GPU Memory] → [CPU Memory] → [Disk] → [CPU Memory] → [GPU Memory]
Lots of data movement
NCCL transfer:
[GPU Memory] ──network──► [GPU Memory]
Still moves all the data
CUDA IPC:
[GPU Memory] ← same memory! → [GPU Memory view]
No data movement at all!
Source Code
#!/usr/bin/env python3
"""
Weight Update Mechanisms Demonstration
This script demonstrates different weight update mechanisms used in RLHF:
- Disk-based transfer
- NCCL-based transfer
- Shared memory (IPC handles)
Usage:
python weight_update_demo.py
"""
import argparse
import time
from dataclasses import dataclass
from typing import Dict
@dataclass
class TransferMethod:
"""Configuration for a weight transfer method."""
name: str
description: str
bandwidth_gbps: float # GB/s
setup_overhead_ms: float
works_across_nodes: bool
works_same_gpu: bool
# Common transfer methods
TRANSFER_METHODS = {
"disk_ssd": TransferMethod(
name="Disk (NVMe SSD)",
description="Save to disk, load from disk",
bandwidth_gbps=7.0, # PCIe 4.0 NVMe
setup_overhead_ms=100,
works_across_nodes=True,
works_same_gpu=True,
),
"disk_hdd": TransferMethod(
name="Disk (HDD/NFS)",
description="Save to network storage",
bandwidth_gbps=0.2,
setup_overhead_ms=500,
works_across_nodes=True,
works_same_gpu=True,
),
"nccl_nvlink": TransferMethod(
name="NCCL (NVLink)",
description="GPU-to-GPU within node",
bandwidth_gbps=450, # NVLink 4.0
setup_overhead_ms=10,
works_across_nodes=False,
works_same_gpu=True,
),
"nccl_ib": TransferMethod(
name="NCCL (InfiniBand)",
description="GPU-to-GPU across nodes",
bandwidth_gbps=50, # 400Gbps IB
setup_overhead_ms=50,
works_across_nodes=True,
works_same_gpu=True,
),
"nccl_ethernet": TransferMethod(
name="NCCL (Ethernet)",
description="GPU-to-GPU over ethernet",
bandwidth_gbps=12.5, # 100Gbps
setup_overhead_ms=100,
works_across_nodes=True,
works_same_gpu=True,
),
"cuda_ipc": TransferMethod(
name="CUDA IPC Handle",
description="Share GPU memory pointer",
bandwidth_gbps=float('inf'), # Zero copy!
setup_overhead_ms=1,
works_across_nodes=False,
works_same_gpu=True, # Same GPU only!
),
}
def calculate_transfer_time(method: TransferMethod, size_gb: float) -> float:
"""Calculate transfer time in milliseconds."""
if method.bandwidth_gbps == float('inf'):
# Zero copy - only setup overhead
return method.setup_overhead_ms
transfer_ms = (size_gb / method.bandwidth_gbps) * 1000
return transfer_ms + method.setup_overhead_ms
def compare_methods(model_size_gb: float, architecture: str) -> None:
"""Compare transfer methods for a given scenario."""
print(f"\n{'='*70}")
print(f" WEIGHT UPDATE COMPARISON: {model_size_gb}GB Model, {architecture} Architecture")
print(f"{'='*70}")
applicable_methods = []
for name, method in TRANSFER_METHODS.items():
if architecture == "co-located" and not method.works_same_gpu:
continue
if architecture == "disaggregated-cross-node" and not method.works_across_nodes:
continue
applicable_methods.append((name, method))
print(f"\n{'Method':<25} {'Transfer Time':<15} {'Notes':<30}")
print("-" * 70)
for name, method in applicable_methods:
transfer_time = calculate_transfer_time(method, model_size_gb)
if transfer_time < 100:
time_str = f"{transfer_time:.1f} ms"
elif transfer_time < 60000:
time_str = f"{transfer_time/1000:.2f} s"
else:
time_str = f"{transfer_time/60000:.1f} min"
if transfer_time < 1000:
notes = "Excellent"
elif transfer_time < 10000:
notes = "Good"
elif transfer_time < 60000:
notes = "Acceptable"
else:
notes = "Slow"
if method.bandwidth_gbps == float('inf'):
notes = "Zero copy!"
print(f"{method.name:<25} {time_str:<15} {notes:<30}")
def demonstrate_ipc_concept():
"""Explain how CUDA IPC handles work."""
print("\n" + "=" * 70)
print(" CUDA IPC HANDLES: ZERO-COPY WEIGHT SHARING")
print("=" * 70)
print("""
How CUDA IPC (Inter-Process Communication) handles work:
┌─────────────────────────────────────────────────────────────────────┐
│ TRADITIONAL WEIGHT TRANSFER │
│ │
│ Training Process: Inference Process: │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU Memory: │ COPY DATA │ GPU Memory: │ │
│ │ [Weight tensor] │ ───────────────► │ [Weight tensor] │ │
│ │ 140 GB │ 140 GB moved! │ 140 GB │ │
│ └─────────────────┘ └─────────────────┘ │
│ │
│ Time: 140GB / 450 GB/s = 311 ms (NVLink) │
└─────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────┐
│ CUDA IPC HANDLE SHARING │
│ │
│ Training Process: Inference Process: │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU Memory: │ SHARE HANDLE │ GPU Memory: │ │
│ │ [Weight tensor] │ ───────────────► │ (same memory!) │ │
│ │ @ address 0x7f..│ Just a pointer │ [Weight tensor] │ │
│ └────────┬────────┘ (~100 bytes) └────────┬────────┘ │
│ │ │ │
│ └──────────── SAME GPU MEMORY ────────┘ │
│ │
│ Time: ~1 ms (only handle serialization) │
│ Data moved: ~100 bytes (not 140 GB!) │
└─────────────────────────────────────────────────────────────────────┘
The handle contains:
- GPU device ID
- Memory address (virtual)
- Size and stride information
- Reference counter handle
When the inference process "reconstructs" the tensor:
1. It creates a new Python tensor object
2. The tensor points to the SAME GPU memory
3. No data is copied!
Limitation: Both processes must be on the same GPU.
For multi-GPU setups, each GPU's weights need their own handle.
""")
def demonstrate_verl_approach():
"""Explain verl's weight update approach."""
print("\n" + "=" * 70)
print(" verl's WEIGHT UPDATE APPROACH")
print("=" * 70)
print("""
verl's Hybrid Engine uses a sophisticated weight update mechanism:
1. TRAINING PHASE
┌─────────────────────────────────────────────────────────────────┐
│ FSDP Training │
│ │
│ GPU 0: [Shard 0] [Shard 4] [Shard 8] ... │
│ GPU 1: [Shard 1] [Shard 5] [Shard 9] ... │
│ GPU 2: [Shard 2] [Shard 6] [Shard 10] ... │
│ GPU 3: [Shard 3] [Shard 7] [Shard 11] ... │
│ │
│ Weights are sharded across GPUs (FSDP) │
└─────────────────────────────────────────────────────────────────┘
2. GATHER FOR INFERENCE
┌─────────────────────────────────────────────────────────────────┐
│ All-Gather to reconstruct full weights │
│ │
│ GPU 0: [Full Layer 0] [Full Layer 1] ... │
│ GPU 1: [Full Layer 0] [Full Layer 1] ... │
│ GPU 2: [Full Layer 0] [Full Layer 1] ... │
│ GPU 3: [Full Layer 0] [Full Layer 1] ... │
│ │
│ (Temporary memory spike during gather) │
└─────────────────────────────────────────────────────────────────┘
3. CREATE IPC HANDLES
┌─────────────────────────────────────────────────────────────────┐
│ For each GPU's portion of weights: │
│ │
│ handle = tensor._cuda_ipc_handle() │
│ serialized = serialize(handle) # ~100 bytes │
│ │
│ Gather handles to coordinator (not data!) │
└─────────────────────────────────────────────────────────────────┘
4. INFERENCE ENGINE RECEIVES HANDLES
┌─────────────────────────────────────────────────────────────────┐
│ For each handle: │
│ │
│ tensor = reconstruct_from_handle(handle) │
│ # tensor now points to same GPU memory as training tensor │
│ │
│ model.load_weights(tensor) # Just pointer assignment │
└─────────────────────────────────────────────────────────────────┘
Benefits:
- Zero data movement (weights stay in place)
- Microsecond-level "transfer" time
- Memory shared between engines
Complexity:
- Must manage tensor lifetimes carefully
- FSDP gather creates temporary memory spike
- Coordination between training and inference loops
""")
def calculate_rlhf_timeline(model_size_gb: float, method_name: str,
generation_time_s: float, training_time_s: float) -> None:
"""Calculate RLHF iteration timeline with weight updates."""
method = TRANSFER_METHODS[method_name]
transfer_time = calculate_transfer_time(method, model_size_gb)
transfer_time_s = transfer_time / 1000
total_time = generation_time_s + transfer_time_s + training_time_s + transfer_time_s
print(f"\n{'='*70}")
print(f" RLHF ITERATION TIMELINE")
print(f"{'='*70}")
print(f"\nConfiguration:")
print(f" Model size: {model_size_gb} GB")
print(f" Transfer method: {method.name}")
print(f" Generation time: {generation_time_s} s")
print(f" Training time: {training_time_s} s")
print(f"\nTimeline:")
print(f"""
┌─────────────────────────────────────────────────────────────────┐
│ Generation {generation_time_s:>6.1f} s │
├─────────────────────────────────────────────────────────────────┤
│ Weight transfer (infer → train) {transfer_time_s:>6.2f} s │
├─────────────────────────────────────────────────────────────────┤
│ Training (PPO update) {training_time_s:>6.1f} s │
├─────────────────────────────────────────────────────────────────┤
│ Weight transfer (train → infer) {transfer_time_s:>6.2f} s │
└─────────────────────────────────────────────────────────────────┘
Total iteration time: {total_time:.2f} s
""")
overhead_pct = (2 * transfer_time_s) / total_time * 100
print(f" Weight transfer overhead: {overhead_pct:.1f}% of iteration")
def main():
parser = argparse.ArgumentParser(description="Weight Update Demo")
parser.add_argument("--model-size", "-m", type=float, default=140,
help="Model size in GB (default: 140 for 70B model)")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " WEIGHT UPDATE MECHANISMS DEMONSTRATION".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Compare methods for different architectures
compare_methods(args.model_size, "co-located")
compare_methods(args.model_size, "disaggregated-same-node")
compare_methods(args.model_size, "disaggregated-cross-node")
# Explain IPC
demonstrate_ipc_concept()
# Explain verl approach
demonstrate_verl_approach()
# Show timeline impact
calculate_rlhf_timeline(
model_size_gb=args.model_size,
method_name="cuda_ipc",
generation_time_s=30,
training_time_s=20
)
calculate_rlhf_timeline(
model_size_gb=args.model_size,
method_name="nccl_ib",
generation_time_s=30,
training_time_s=20
)
if __name__ == "__main__":
main()
memory_timeline.py
Visualize GPU memory usage across RLHF phases
This script shows how memory is allocated and freed during different phases of RLHF training.
What It Does
- Simulates RLHF memory allocation
- Shows memory usage for each phase
- Identifies peak memory and bottlenecks
- Demonstrates why phase-based swapping helps
Run It
python tutorial/part4-rlhf/chapter14-rlhf-architecture/scripts/memory_timeline.py
Example Output
=== RLHF Memory Timeline (70B model, 8 GPUs) ===
GPU Memory Available: 80 GB per GPU
Phase 1: Generation
┌─────────────────────────────────────────────────────────────┐
│ Actor weights (TP=8): 17.5 GB │
│ Reference weights (TP=8): 17.5 GB │
│ Reward model (TP=8): 17.5 GB │
│ KV Cache (batch=32): 20.0 GB │
│ ───────────────────────────────────── │
│ Total: 72.5 GB [OK - fits in 80 GB] │
└─────────────────────────────────────────────────────────────┘
Phase 2: Transition (Free KV cache, load critic)
Memory freed: 20.0 GB (KV cache)
Memory allocated: 17.5 GB (Critic weights)
Phase 3: Training
┌─────────────────────────────────────────────────────────────┐
│ Actor weights (TP=8): 17.5 GB │
│ Critic weights (TP=8): 17.5 GB │
│ Actor gradients: 17.5 GB │
│ Critic gradients: 17.5 GB │
│ Adam states (2x): 70.0 GB ← Offloaded! │
│ Activations: 10.0 GB │
│ ───────────────────────────────────── │
│ Without offload: 150.0 GB [FAIL - OOM] │
│ With optimizer offload: 80.0 GB [OK - barely fits] │
└─────────────────────────────────────────────────────────────┘
Memory Timeline:
Time →
┌──────────────────────────────────────────────────────────┐
80GB│████████████████████░░░░░░░░████████████████████████████│
│ Generation │Swap│ Training │
60GB│████████████████████░░░░░░░░████████████████████████████│
│ │ │ │
40GB│████████████████████░░░░░░░░████████████████████████████│
│ │ │ │
20GB│████████████████████░░░░░░░░████████████████████████████│
│ │ │ │
0GB└──────────────────────────────────────────────────────────┘
Legend: █ = allocated, ░ = free
Why Phase-Based Memory Matters
Without swapping:
All 4 models + optimizer + KV cache = 200+ GB per GPU
= Out of Memory!
With smart swapping:
Generation: Models + KV cache (no optimizer) = 72 GB
Training: Models + optimizer + grads (no KV) = 80 GB
= Fits!
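The arithmetic behind those two numbers is easy to check by hand. Here is a minimal back-of-the-envelope sketch; the figures assume a 70B model in BF16 with TP=8 and the same KV cache, activation, and optimizer sizes as the example output above:
# Per-GPU memory for a 70B model, BF16 weights, TP=8 (matching the example above)
weights_per_gpu = 70 * 2 / 8           # 17.5 GB per model copy
kv_cache = 20.0                        # GB at batch=32, as shown above
adam_states = 70 * 8 / 8               # 70 GB: two FP32 states per parameter
grads = 2 * weights_per_gpu            # actor + critic gradients
activations = 10.0                     # GB, rough estimate

generation = 3 * weights_per_gpu + kv_cache                      # actor + reward + reference
training = 2 * weights_per_gpu + grads + adam_states + activations

print(f"Generation phase:             {generation:.1f} GB")      # ~72.5 GB
print(f"Training (no offload):        {training:.1f} GB")        # ~150 GB -> OOM on 80 GB
print(f"Training (optimizer offload): {training - adam_states:.1f} GB")  # ~80 GB -> fits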
Source Code
#!/usr/bin/env python3
"""
RLHF Memory Timeline Visualizer
This script visualizes memory usage across different phases of RLHF training,
helping understand memory requirements and bottlenecks.
Usage:
python memory_timeline.py
python memory_timeline.py --model-size 70
"""
import argparse
import math
from dataclasses import dataclass
from typing import Dict
@dataclass
class ModelConfig:
"""Model configuration for memory estimation."""
name: str
params_billions: float
hidden_size: int
num_layers: int
vocab_size: int = 128000
MODELS = {
"7b": ModelConfig("7B", 7, 4096, 32),
"13b": ModelConfig("13B", 13, 5120, 40),
"70b": ModelConfig("70B", 70, 8192, 80),
"405b": ModelConfig("405B", 405, 16384, 126),
}
def estimate_memory(params_b: float, dtype_bytes: int = 2) -> float:
"""Estimate memory in GB for model parameters."""
return params_b * 1e9 * dtype_bytes / 1e9
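# e.g. estimate_memory(70) -> 140.0 GB: 70B parameters at 2 bytes each (BF16/FP16)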
def estimate_kv_cache(batch_size: int, seq_len: int, hidden_size: int,
num_layers: int, dtype_bytes: int = 2) -> float:
"""Estimate KV cache memory in GB."""
# K and V for each layer
kv_per_token = 2 * num_layers * hidden_size * dtype_bytes
total = batch_size * seq_len * kv_per_token
return total / 1e9
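# Worked example for the 70B config (80 layers, hidden 8192, no GQA modeled):
# each token's K+V costs 2 * 80 * 8192 * 2 bytes ≈ 2.6 MB, so batch=32 at
# seq_len=2048 is ≈172 GB in total, or ≈21.5 GB per GPU when sharded with TP=8.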
def estimate_optimizer_states(params_b: float) -> float:
"""Estimate Adam optimizer state memory in GB."""
# Adam: 2 states (m, v) in FP32
return params_b * 1e9 * 4 * 2 / 1e9
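# e.g. estimate_optimizer_states(70) -> 560 GB: Adam keeps m and v in FP32, 8 bytes/param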
def estimate_gradients(params_b: float, dtype_bytes: int = 2) -> float:
"""Estimate gradient memory in GB."""
return params_b * 1e9 * dtype_bytes / 1e9
def estimate_activations(batch_size: int, seq_len: int, hidden_size: int,
num_layers: int, dtype_bytes: int = 2) -> float:
"""Rough estimate of activation memory in GB."""
# Simplified: ~10x hidden per layer per token
per_layer = batch_size * seq_len * hidden_size * 10 * dtype_bytes
return per_layer * num_layers / 1e9
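# e.g. estimate_activations(4, 2048, 8192, 80) ≈ 107 GB; in practice gradient
# checkpointing trades recompute for a large reduction of this term.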
def calculate_phase_memory(model: ModelConfig, batch_size: int, seq_len: int,
phase: str) -> Dict[str, float]:
"""Calculate memory breakdown for a specific RLHF phase."""
memory = {}
actor_params = estimate_memory(model.params_billions)
critic_params = estimate_memory(model.params_billions)
reward_params = estimate_memory(model.params_billions)
reference_params = estimate_memory(model.params_billions)
if phase == "generation":
# Generation phase: Actor inference + KV cache
memory["Actor (weights)"] = actor_params
memory["KV Cache"] = estimate_kv_cache(
batch_size, seq_len, model.hidden_size, model.num_layers
)
memory["Reward Model"] = reward_params
memory["Reference Model"] = reference_params
memory["Misc (buffers)"] = 2.0
elif phase == "scoring":
# Scoring phase: Reward model forward
memory["Reward Model (weights)"] = reward_params
memory["Activations"] = estimate_activations(
batch_size, seq_len, model.hidden_size, model.num_layers
) * 0.2 # Inference uses less
elif phase == "training":
# Training phase: Actor + Critic with gradients and optimizer
memory["Actor (weights)"] = actor_params
memory["Critic (weights)"] = critic_params
memory["Actor (gradients)"] = estimate_gradients(model.params_billions)
memory["Critic (gradients)"] = estimate_gradients(model.params_billions)
memory["Optimizer States"] = estimate_optimizer_states(model.params_billions)
memory["Activations"] = estimate_activations(
batch_size, seq_len, model.hidden_size, model.num_layers
)
elif phase == "full_rlhf":
# Full RLHF: worst case (all models loaded)
memory["Actor"] = actor_params
memory["Critic"] = critic_params
memory["Reward Model"] = reward_params
memory["Reference Model"] = reference_params
memory["Optimizer States"] = estimate_optimizer_states(model.params_billions)
memory["Gradients"] = estimate_gradients(model.params_billions) * 2
memory["KV Cache (peak)"] = estimate_kv_cache(
batch_size, seq_len, model.hidden_size, model.num_layers
)
memory["Activations (peak)"] = estimate_activations(
batch_size, seq_len, model.hidden_size, model.num_layers
)
return memory
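# e.g. for the 70B config at batch=4, seq_len=2048 the "training" phase sums to
# roughly 1.2 TB of unsharded memory, which is why the sharding strategies from
# earlier chapters matter here.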
def visualize_memory_bar(memory_dict: Dict[str, float], max_memory: float,
available: float) -> None:
"""Visualize memory as horizontal bar chart."""
total = sum(memory_dict.values())
scale = 50 / max_memory # Characters per GB
print(f"\n{'Component':<25} {'Memory':<10} {'Visualization':<50}")
print("-" * 85)
for name, mem in sorted(memory_dict.items(), key=lambda x: -x[1]):
bar_len = int(mem * scale)
bar = "█" * bar_len
print(f"{name:<25} {mem:>7.1f} GB {bar}")
print("-" * 85)
print(f"{'TOTAL':<25} {total:>7.1f} GB")
if total > available:
print(f"\n⚠ EXCEEDS available memory ({available} GB)!")
print(f" Need {total/available:.1f}x GPUs or memory optimization")
else:
print(f"\n✓ Fits in {available} GB ({total/available*100:.0f}% utilized)")
def show_memory_timeline(model: ModelConfig, batch_size: int, seq_len: int,
gpu_memory: float) -> None:
"""Show memory across all RLHF phases."""
print("\n" + "=" * 80)
print(f" RLHF MEMORY TIMELINE: {model.name} Model")
print("=" * 80)
phases = ["generation", "scoring", "training"]
max_mem = 0
for phase in phases:
mem = calculate_phase_memory(model, batch_size, seq_len, phase)
phase_total = sum(mem.values())
max_mem = max(max_mem, phase_total)
# Now visualize
for phase in phases:
mem = calculate_phase_memory(model, batch_size, seq_len, phase)
print(f"\n--- Phase: {phase.upper()} ---")
visualize_memory_bar(mem, max_mem * 1.1, gpu_memory)
def show_scaling_analysis(model: ModelConfig, gpu_memory: float) -> None:
"""Show how memory scales with batch size."""
print("\n" + "=" * 80)
print(f" SCALING ANALYSIS: {model.name} Model")
print("=" * 80)
print(f"\nMemory breakdown (batch_size=4, seq_len=2048):\n")
mem = calculate_phase_memory(model, 4, 2048, "full_rlhf")
    # Split costs: weights, gradients, and optimizer state are fixed per model,
    # while KV cache and activations scale with batch size and sequence length.
    fixed = 0
    scaling = 0
    for name, m in mem.items():
        if "kv" in name.lower() or "activation" in name.lower():
            scaling += m
        else:
            fixed += m
    print(f"Fixed costs (weights, grads, optimizer): {fixed:.1f} GB")
    print(f"Scaling costs (KV cache, activations): {scaling:.1f} GB")
    print(f"Total: {fixed + scaling:.1f} GB")
    print(f"\nHow scaling costs change:")
    print(f"{'Batch Size':<12} {'KV Cache':<12} {'Activations':<15} {'Total':<12} {'Fits?':<6}")
    print("-" * 60)
    for bs in [1, 2, 4, 8, 16, 32]:
        kv = estimate_kv_cache(bs, 2048, model.hidden_size, model.num_layers)
        act = estimate_activations(bs, 2048, model.hidden_size, model.num_layers)
        total = fixed + kv + act
        fit = "✓" if total <= gpu_memory else "✗"
        print(f"{bs:<12} {kv:<12.1f} {act:<15.1f} {total:<12.1f} {fit}")
def recommend_setup(model: ModelConfig, gpu_memory: float, num_gpus: int) -> None:
"""Recommend setup for given constraints."""
print("\n" + "=" * 80)
print(f" RECOMMENDED SETUP: {model.name} on {num_gpus}x {gpu_memory}GB GPUs")
print("=" * 80)
total_memory = gpu_memory * num_gpus
# Estimate requirements
mem = calculate_phase_memory(model, 4, 2048, "full_rlhf")
required = sum(mem.values())
print(f"\nMemory analysis:")
print(f" Required (naive): {required:.1f} GB")
print(f" Available: {total_memory:.1f} GB")
if required <= gpu_memory:
print(f"\n✓ Fits on single GPU")
print(f" Recommendation: Simple co-located setup")
elif required <= total_memory:
print(f"\n✓ Fits across {num_gpus} GPUs")
# Determine parallelism
        # Assume ~70% of each GPU's memory is usable in practice (fragmentation,
        # buffers), so shard across ceil(required / (0.7 * gpu_memory)) GPUs.
        tp_needed = max(1, math.ceil(required / (gpu_memory * 0.7)))
        tp_needed = min(tp_needed, num_gpus, 8)  # Cap at 8 for TP
print(f" Recommendation:")
print(f" - Tensor Parallelism: {tp_needed}")
remaining_gpus = num_gpus // tp_needed
if remaining_gpus > 1:
print(f" - Data/Pipeline Parallelism: {remaining_gpus}")
else:
print(f"\n✗ Does not fit!")
print(f" Need {required / gpu_memory:.0f} GPUs or memory optimization")
print(f" Consider:")
print(f" - ZeRO-3 / FSDP for optimizer state sharding")
print(f" - Gradient checkpointing for activation memory")
print(f" - Disaggregated architecture")
def main():
parser = argparse.ArgumentParser(description="RLHF Memory Timeline")
parser.add_argument("--model-size", "-m", type=str, default="70b",
choices=list(MODELS.keys()),
help="Model size")
parser.add_argument("--batch-size", "-b", type=int, default=4,
help="Batch size")
parser.add_argument("--seq-len", "-s", type=int, default=2048,
help="Sequence length")
parser.add_argument("--gpu-memory", "-g", type=float, default=80,
help="GPU memory in GB")
parser.add_argument("--num-gpus", "-n", type=int, default=8,
help="Number of GPUs")
args = parser.parse_args()
model = MODELS[args.model_size]
print("╔" + "═" * 78 + "╗")
print("║" + " RLHF MEMORY TIMELINE VISUALIZER".center(78) + "║")
print("╚" + "═" * 78 + "╝")
print(f"\nConfiguration:")
print(f" Model: {model.name} ({model.params_billions}B parameters)")
print(f" Batch size: {args.batch_size}")
print(f" Sequence length: {args.seq_len}")
print(f" GPU memory: {args.gpu_memory} GB")
print(f" Number of GPUs: {args.num_gpus}")
# Show timeline
show_memory_timeline(model, args.batch_size, args.seq_len, args.gpu_memory)
# Show scaling
show_scaling_analysis(model, args.gpu_memory)
# Show recommendation
recommend_setup(model, args.gpu_memory, args.num_gpus)
# Key insights
print("\n" + "=" * 80)
print(" KEY INSIGHTS")
print("=" * 80)
print("""
1. FOUR MODELS = 4X WEIGHT MEMORY
RLHF needs Actor, Critic, Reward, Reference
For 70B: 4 × 140GB = 560GB just for weights!
2. OPTIMIZER STATES DOMINATE TRAINING
Adam needs 2× FP32 states per parameter
For 70B with Actor+Critic: ~1.1TB
3. MEMORY PHASES DIFFER SIGNIFICANTLY
- Generation: weights + KV cache (no gradients)
- Training: weights + gradients + optimizer (no KV)
Smart systems swap between phases
4. BATCH SIZE AFFECTS ACTIVATIONS
Larger batch → more activation memory
May need to reduce batch or checkpoint
5. SEQUENCE LENGTH AFFECTS KV CACHE
Longer sequences → larger KV cache
4K → 32K = 8x KV memory increase
""")
if __name__ == "__main__":
main()
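The numbers above are analytic estimates. On a real GPU you can cross-check a phase against PyTorch's allocator statistics; below is a minimal sketch, assuming PyTorch with CUDA is available and where run_phase is a placeholder for your own generation or training step (not part of this tutorial's scripts):
import torch

def measure_phase(name, run_phase):
    """Run one RLHF phase callable and report allocated / peak GPU memory in GB."""
    torch.cuda.reset_peak_memory_stats()
    run_phase()
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated() / 1e9
    peak = torch.cuda.max_memory_allocated() / 1e9
    print(f"{name:<12} allocated={allocated:6.1f} GB  peak={peak:6.1f} GB")

# Usage (placeholder callables): measure_phase("generation", lambda: actor.generate(prompts))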