weight_update_demo.py
Compare different weight transfer mechanisms for RLHF
This script demonstrates how weights are transferred from training to inference engines in different RLHF architectures.
What It Does
- Simulates three weight transfer methods
- Measures transfer time and memory usage
- Shows trade-offs between approaches
Run It
python tutorial/part4-rlhf/chapter14-rlhf-architecture/scripts/weight_update_demo.py
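The model size (in GB) can be overridden via the script's --model-size flag; for example, a ~7B model in FP16 is about 14 GB:
python tutorial/part4-rlhf/chapter14-rlhf-architecture/scripts/weight_update_demo.py --model-size 14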
Example Output
=== Weight Transfer Mechanisms ===
Model size: 70B parameters = 140 GB (FP16)
Method 1: Disk-based Transfer
Write to disk: 28.0 seconds (5 GB/s SSD)
Read from disk: 28.0 seconds
Total: 56.0 seconds
Note: Works across any hardware configuration
Method 2: NCCL Transfer (Network)
Gather weights on training rank 0: 2.1 seconds
Transfer to inference cluster: 5.6 seconds (25 GB/s InfiniBand)
Broadcast to inference ranks: 2.1 seconds
Total: 9.8 seconds
Note: Requires network connectivity between clusters
Method 3: CUDA IPC (Same GPU)
Get IPC handle: 0.001 seconds
Serialize handle: 0.001 seconds
Reconstruct tensor: 0.001 seconds
Total: 0.003 seconds (!)
Note: Zero data movement - same memory, new reference
Comparison:
Disk: 56,000 ms (100%, baseline)
NCCL: 9,800 ms (18% of disk)
CUDA IPC: 3 ms (0.005% of disk)
The verl approach (CUDA IPC) achieves near-zero overhead!
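These totals are plain bandwidth arithmetic; here is a quick back-of-envelope check in Python (the bandwidth figures are the illustrative ones used above, not measurements):
model_gb = 140.0                      # 70B parameters in FP16
disk_s = 2 * model_gb / 5.0           # write + read at 5 GB/s SSD
nccl_s = 2.1 + model_gb / 25.0 + 2.1  # gather + 25 GB/s InfiniBand + broadcast
ipc_s = 0.003                         # create/serialize/open the handle only
print(f"disk={disk_s:.1f}s  nccl={nccl_s:.1f}s  ipc={ipc_s*1000:.0f}ms")
# disk=56.0s  nccl=9.8s  ipc=3ms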
The Key Insight
Disk transfer:
[GPU Memory] → [CPU Memory] → [Disk] → [CPU Memory] → [GPU Memory]
Lots of data movement
NCCL transfer:
[GPU Memory] ──network──► [GPU Memory]
Still moves all the data
CUDA IPC:
[GPU Memory] ← same memory! → [GPU Memory view]
No data movement at all!
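To make the zero-copy case concrete, here is a minimal standalone sketch (separate from the demo script below, assuming a machine with a CUDA GPU): passing a CUDA tensor through torch.multiprocessing shares it via a CUDA IPC handle, so the child process sees the same GPU memory rather than a copy.
import torch
import torch.multiprocessing as mp

def consumer(shared):
    # "shared" was rebuilt from an IPC handle: same device memory, new tensor object
    shared += 1.0

if __name__ == "__main__":
    mp.set_start_method("spawn")              # required for sharing CUDA tensors
    weights = torch.zeros(4, device="cuda")   # stands in for model weights
    p = mp.Process(target=consumer, args=(weights,))
    p.start()
    p.join()
    print(weights)  # tensor([1., 1., 1., 1.], device='cuda:0') - updated in place
Only the handle (a few hundred bytes) crosses the process boundary; the parent's allocation is modified in place.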
Source Code
#!/usr/bin/env python3
"""
Weight Update Mechanisms Demonstration
This script demonstrates different weight update mechanisms used in RLHF:
- Disk-based transfer
- NCCL-based transfer
- Shared memory (IPC handles)
Usage:
python weight_update_demo.py
"""
import argparse
from dataclasses import dataclass
@dataclass
class TransferMethod:
"""Configuration for a weight transfer method."""
name: str
description: str
bandwidth_gbps: float # GB/s
setup_overhead_ms: float
works_across_nodes: bool
works_same_gpu: bool
# Common transfer methods
TRANSFER_METHODS = {
"disk_ssd": TransferMethod(
name="Disk (NVMe SSD)",
description="Save to disk, load from disk",
bandwidth_gbps=7.0, # PCIe 4.0 NVMe
setup_overhead_ms=100,
works_across_nodes=True,
works_same_gpu=True,
),
"disk_hdd": TransferMethod(
name="Disk (HDD/NFS)",
description="Save to network storage",
bandwidth_gbps=0.2,
setup_overhead_ms=500,
works_across_nodes=True,
works_same_gpu=True,
),
"nccl_nvlink": TransferMethod(
name="NCCL (NVLink)",
description="GPU-to-GPU within node",
bandwidth_gbps=450, # NVLink 4.0
setup_overhead_ms=10,
works_across_nodes=False,
works_same_gpu=True,
),
"nccl_ib": TransferMethod(
name="NCCL (InfiniBand)",
description="GPU-to-GPU across nodes",
bandwidth_gbps=50, # 400Gbps IB
setup_overhead_ms=50,
works_across_nodes=True,
works_same_gpu=True,
),
"nccl_ethernet": TransferMethod(
name="NCCL (Ethernet)",
description="GPU-to-GPU over ethernet",
bandwidth_gbps=12.5, # 100Gbps
setup_overhead_ms=100,
works_across_nodes=True,
works_same_gpu=True,
),
"cuda_ipc": TransferMethod(
name="CUDA IPC Handle",
description="Share GPU memory pointer",
bandwidth_gbps=float('inf'), # Zero copy!
setup_overhead_ms=1,
works_across_nodes=False,
works_same_gpu=True, # Same GPU only!
),
}
def calculate_transfer_time(method: TransferMethod, size_gb: float) -> float:
"""Calculate transfer time in milliseconds."""
if method.bandwidth_gbps == float('inf'):
# Zero copy - only setup overhead
return method.setup_overhead_ms
transfer_ms = (size_gb / method.bandwidth_gbps) * 1000
return transfer_ms + method.setup_overhead_ms
def compare_methods(model_size_gb: float, architecture: str) -> None:
"""Compare transfer methods for a given scenario."""
print(f"\n{'='*70}")
print(f" WEIGHT UPDATE COMPARISON: {model_size_gb}GB Model, {architecture} Architecture")
print(f"{'='*70}")
applicable_methods = []
    for name, method in TRANSFER_METHODS.items():
        if architecture == "co-located" and not method.works_same_gpu:
            continue
        if architecture == "disaggregated-cross-node" and not method.works_across_nodes:
            continue
        if architecture == "disaggregated-same-node" and name == "cuda_ipc":
            # IPC handles reference memory on a single GPU; an inference
            # process running on a different GPU cannot use them directly
            continue
        applicable_methods.append((name, method))
print(f"\n{'Method':<25} {'Transfer Time':<15} {'Notes':<30}")
print("-" * 70)
for name, method in applicable_methods:
transfer_time = calculate_transfer_time(method, model_size_gb)
if transfer_time < 100:
time_str = f"{transfer_time:.1f} ms"
elif transfer_time < 60000:
time_str = f"{transfer_time/1000:.2f} s"
else:
time_str = f"{transfer_time/60000:.1f} min"
if transfer_time < 1000:
notes = "Excellent"
elif transfer_time < 10000:
notes = "Good"
elif transfer_time < 60000:
notes = "Acceptable"
else:
notes = "Slow"
if method.bandwidth_gbps == float('inf'):
notes = "Zero copy!"
print(f"{method.name:<25} {time_str:<15} {notes:<30}")
def demonstrate_ipc_concept():
"""Explain how CUDA IPC handles work."""
print("\n" + "=" * 70)
print(" CUDA IPC HANDLES: ZERO-COPY WEIGHT SHARING")
print("=" * 70)
print("""
How CUDA IPC (Inter-Process Communication) handles work:
┌─────────────────────────────────────────────────────────────────────┐
│ TRADITIONAL WEIGHT TRANSFER │
│ │
│ Training Process: Inference Process: │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU Memory: │ COPY DATA │ GPU Memory: │ │
│ │ [Weight tensor] │ ───────────────► │ [Weight tensor] │ │
│ │ 140 GB │ 140 GB moved! │ 140 GB │ │
│ └─────────────────┘ └─────────────────┘ │
│ │
│ Time: 140GB / 450 GB/s = 311 ms (NVLink) │
└─────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────┐
│ CUDA IPC HANDLE SHARING │
│ │
│ Training Process: Inference Process: │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU Memory: │ SHARE HANDLE │ GPU Memory: │ │
│ │ [Weight tensor] │ ───────────────► │ (same memory!) │ │
│ │ @ address 0x7f..│ Just a pointer │ [Weight tensor] │ │
│ └────────┬────────┘ (~100 bytes) └────────┬────────┘ │
│ │ │ │
│ └──────────── SAME GPU MEMORY ────────┘ │
│ │
│ Time: ~1 ms (only handle serialization) │
│ Data moved: ~100 bytes (not 140 GB!) │
└─────────────────────────────────────────────────────────────────────┘
The handle contains:
- GPU device ID
- Memory address (virtual)
- Size and stride information
- Reference counter handle
When the inference process "reconstructs" the tensor:
1. It creates a new Python tensor object
2. The tensor points to the SAME GPU memory
3. No data is copied!
Limitation: Both processes must be on the same GPU.
For multi-GPU setups, each GPU's weights need their own handle.
""")
def demonstrate_verl_approach():
"""Explain verl's weight update approach."""
print("\n" + "=" * 70)
print(" verl's WEIGHT UPDATE APPROACH")
print("=" * 70)
print("""
verl's Hybrid Engine uses a sophisticated weight update mechanism:
1. TRAINING PHASE
┌─────────────────────────────────────────────────────────────────┐
│ FSDP Training │
│ │
│ GPU 0: [Shard 0] [Shard 4] [Shard 8] ... │
│ GPU 1: [Shard 1] [Shard 5] [Shard 9] ... │
│ GPU 2: [Shard 2] [Shard 6] [Shard 10] ... │
│ GPU 3: [Shard 3] [Shard 7] [Shard 11] ... │
│ │
│ Weights are sharded across GPUs (FSDP) │
└─────────────────────────────────────────────────────────────────┘
2. GATHER FOR INFERENCE
┌─────────────────────────────────────────────────────────────────┐
│ All-Gather to reconstruct full weights │
│ │
│ GPU 0: [Full Layer 0] [Full Layer 1] ... │
│ GPU 1: [Full Layer 0] [Full Layer 1] ... │
│ GPU 2: [Full Layer 0] [Full Layer 1] ... │
│ GPU 3: [Full Layer 0] [Full Layer 1] ... │
│ │
│ (Temporary memory spike during gather) │
└─────────────────────────────────────────────────────────────────┘
3. CREATE IPC HANDLES
┌─────────────────────────────────────────────────────────────────┐
│ For each GPU's portion of weights: │
│ │
│ func, args = reduce_tensor(tensor)  # torch.multiprocessing │
│ handle = serialize((func, args))    # ~100 bytes │
│ │
│ Gather handles to coordinator (not data!) │
└─────────────────────────────────────────────────────────────────┘
4. INFERENCE ENGINE RECEIVES HANDLES
┌─────────────────────────────────────────────────────────────────┐
│ For each handle: │
│ │
│ func, args = deserialize(handle) │
│ tensor = func(*args) │
│ # tensor now points to same GPU memory as training tensor │
│ │
│ model.load_weights(tensor) # Just pointer assignment │
└─────────────────────────────────────────────────────────────────┘
Benefits:
- Zero data movement (weights stay in place)
- Microsecond-level "transfer" time
- Memory shared between engines
Complexity:
- Must manage tensor lifetimes carefully
- FSDP gather creates temporary memory spike
- Coordination between training and inference loops
""")
def calculate_rlhf_timeline(model_size_gb: float, method_name: str,
generation_time_s: float, training_time_s: float) -> None:
"""Calculate RLHF iteration timeline with weight updates."""
method = TRANSFER_METHODS[method_name]
transfer_time = calculate_transfer_time(method, model_size_gb)
transfer_time_s = transfer_time / 1000
total_time = generation_time_s + transfer_time_s + training_time_s + transfer_time_s
print(f"\n{'='*70}")
print(f" RLHF ITERATION TIMELINE")
print(f"{'='*70}")
print(f"\nConfiguration:")
print(f" Model size: {model_size_gb} GB")
print(f" Transfer method: {method.name}")
print(f" Generation time: {generation_time_s} s")
print(f" Training time: {training_time_s} s")
print(f"\nTimeline:")
print(f"""
┌─────────────────────────────────────────────────────────────────┐
│ Generation {generation_time_s:>6.1f} s │
├─────────────────────────────────────────────────────────────────┤
│ Weight transfer (infer → train) {transfer_time_s:>6.2f} s │
├─────────────────────────────────────────────────────────────────┤
│ Training (PPO update) {training_time_s:>6.1f} s │
├─────────────────────────────────────────────────────────────────┤
│ Weight transfer (train → infer) {transfer_time_s:>6.2f} s │
└─────────────────────────────────────────────────────────────────┘
Total iteration time: {total_time:.2f} s
""")
overhead_pct = (2 * transfer_time_s) / total_time * 100
print(f" Weight transfer overhead: {overhead_pct:.1f}% of iteration")
def main():
parser = argparse.ArgumentParser(description="Weight Update Demo")
parser.add_argument("--model-size", "-m", type=float, default=140,
help="Model size in GB (default: 140 for 70B model)")
args = parser.parse_args()
print("╔" + "═" * 68 + "╗")
print("║" + " WEIGHT UPDATE MECHANISMS DEMONSTRATION".center(68) + "║")
print("╚" + "═" * 68 + "╝")
# Compare methods for different architectures
compare_methods(args.model_size, "co-located")
compare_methods(args.model_size, "disaggregated-same-node")
compare_methods(args.model_size, "disaggregated-cross-node")
# Explain IPC
demonstrate_ipc_concept()
# Explain verl approach
demonstrate_verl_approach()
# Show timeline impact
calculate_rlhf_timeline(
model_size_gb=args.model_size,
method_name="cuda_ipc",
generation_time_s=30,
training_time_s=20
)
calculate_rlhf_timeline(
model_size_gb=args.model_size,
method_name="nccl_ib",
generation_time_s=30,
training_time_s=20
)
if __name__ == "__main__":
main()