Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

weight_update_demo.py

Compare different weight transfer mechanisms for RLHF

This script demonstrates how weights are transferred from training to inference engines in different RLHF architectures.

What It Does

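  1. Models three families of weight transfer (disk, NCCL, CUDA IPC) across six concrete configurations
  2. Estimates transfer time analytically from bandwidth and setup-overhead figures (nothing is actually copied or measured)
  3. Shows trade-offs between approaches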

Run It

python tutorial/part4-rlhf/chapter14-rlhf-architecture/scripts/weight_update_demo.py

Example Output (illustrative — the script's actual tables use a different layout and different bandwidth figures)

=== Weight Transfer Mechanisms ===

Model size: 70B parameters = 140 GB (FP16)

Method 1: Disk-based Transfer
  Write to disk: 28.0 seconds (5 GB/s SSD)
  Read from disk: 28.0 seconds
  Total: 56.0 seconds
  Note: Works across any hardware configuration

Method 2: NCCL Transfer (Network)
  Gather weights on training rank 0: 2.1 seconds
  Transfer to inference cluster: 5.6 seconds (25 GB/s InfiniBand)
  Broadcast to inference ranks: 2.1 seconds
  Total: 9.8 seconds
  Note: Requires network connectivity between clusters

Method 3: CUDA IPC (Same GPU)
  Get IPC handle: 0.001 seconds
  Serialize handle: 0.001 seconds
  Reconstruct tensor: 0.001 seconds
  Total: 0.003 seconds (!)
  Note: Zero data movement - same memory, new reference

Comparison:
  Disk:     56,000 ms (100% transfer)
  NCCL:      9,800 ms (18% of disk)
  CUDA IPC:      3 ms (0.005% of disk)

The verl approach (CUDA IPC) achieves near-zero overhead!

The Key Insight

Disk transfer:
  [GPU Memory] → [CPU Memory] → [Disk] → [CPU Memory] → [GPU Memory]
  Lots of data movement

NCCL transfer:
  [GPU Memory] ──network──► [GPU Memory]
  Still moves all the data

CUDA IPC:
  [GPU Memory] ← same memory! → [GPU Memory view]
  No data movement at all!

Source Code

#!/usr/bin/env python3
"""
Weight Update Mechanisms Demonstration

This script demonstrates different weight update mechanisms used in RLHF:
- Disk-based transfer
- NCCL-based transfer
- Shared memory (IPC handles)

Usage:
    python weight_update_demo.py
"""

import argparse
import time
from dataclasses import dataclass
from typing import Dict


@dataclass
class TransferMethod:
    """Configuration for a weight transfer method.

    A plain record describing one way of moving weights between a training
    and an inference engine. The generated ``__init__`` takes the fields in
    the order declared below, so keep that order stable.
    """
    # Human-readable label printed in the comparison tables.
    name: str
    # Short explanation of how the transfer works.
    description: str
    # Throughput in GB/s (gigaBYTES, despite the "gbps" suffix).
    # float('inf') is used as the marker for a zero-copy method.
    bandwidth_gbps: float  # GB/s
    # Fixed per-transfer cost in milliseconds, added on top of the copy time.
    setup_overhead_ms: float
    # True if training and inference may live on different nodes.
    works_across_nodes: bool
    # True if usable when training and inference share the same GPU.
    works_same_gpu: bool


# Common transfer methods, keyed by the short ids that main() passes to
# calculate_rlhf_timeline() (e.g. "cuda_ipc", "nccl_ib").
# Units: bandwidth_gbps is GB/s (bytes, not bits); setup_overhead_ms is ms.
# NOTE(review): calculate_transfer_time() charges a single pass over the data
# at bandwidth_gbps, but the disk methods save and then load, i.e. move the
# weights twice — their estimates are likely optimistic (at equal read/write
# speed a round trip would take about twice the copy time). Confirm intent.
TRANSFER_METHODS = {
    "disk_ssd": TransferMethod(
        name="Disk (NVMe SSD)",
        description="Save to disk, load from disk",
        bandwidth_gbps=7.0,  # PCIe 4.0 NVMe
        setup_overhead_ms=100,
        works_across_nodes=True,
        works_same_gpu=True,
    ),
    "disk_hdd": TransferMethod(
        name="Disk (HDD/NFS)",
        description="Save to network storage",
        bandwidth_gbps=0.2,  # i.e. 200 MB/s
        setup_overhead_ms=500,
        works_across_nodes=True,
        works_same_gpu=True,
    ),
    "nccl_nvlink": TransferMethod(
        name="NCCL (NVLink)",
        description="GPU-to-GPU within node",
        bandwidth_gbps=450,  # NVLink 4.0 (presumably a per-direction figure — confirm)
        setup_overhead_ms=10,
        # NVLink connects GPUs inside a single node only.
        works_across_nodes=False,
        works_same_gpu=True,
    ),
    "nccl_ib": TransferMethod(
        name="NCCL (InfiniBand)",
        description="GPU-to-GPU across nodes",
        bandwidth_gbps=50,  # 400Gbps IB: 400 Gb/s / 8 = 50 GB/s
        setup_overhead_ms=50,
        works_across_nodes=True,
        works_same_gpu=True,
    ),
    "nccl_ethernet": TransferMethod(
        name="NCCL (Ethernet)",
        description="GPU-to-GPU over ethernet",
        bandwidth_gbps=12.5,  # 100Gbps: 100 Gb/s / 8 = 12.5 GB/s
        setup_overhead_ms=100,
        works_across_nodes=True,
        works_same_gpu=True,
    ),
    "cuda_ipc": TransferMethod(
        name="CUDA IPC Handle",
        description="Share GPU memory pointer",
        # Zero copy: calculate_transfer_time() charges only setup_overhead_ms
        # when the bandwidth is infinite.
        bandwidth_gbps=float('inf'),  # Zero copy!
        setup_overhead_ms=1,
        works_across_nodes=False,
        # NOTE(review): the two booleans cannot express "same GPU only";
        # the infinite bandwidth above is the only marker of that restriction.
        works_same_gpu=True,  # Same GPU only!
    ),
}


def calculate_transfer_time(method: TransferMethod, size_gb: float) -> float:
    """Estimate how long moving ``size_gb`` of weights takes with ``method``.

    The estimate is the method's fixed ``setup_overhead_ms`` plus a
    bandwidth-bound copy term. A method whose ``bandwidth_gbps`` is infinite
    is treated as zero-copy, so only its setup overhead is charged.

    Args:
        method: Transfer mechanism; ``bandwidth_gbps`` is in GB/s.
        size_gb: Payload size in GB.

    Returns:
        Estimated transfer time in milliseconds.
    """
    overhead_ms = method.setup_overhead_ms
    is_zero_copy = method.bandwidth_gbps == float('inf')
    if not is_zero_copy:
        copy_ms = (size_gb / method.bandwidth_gbps) * 1000
        return copy_ms + overhead_ms
    # Zero copy - nothing is moved, so the setup cost is the whole cost.
    return overhead_ms


def _applies_to_architecture(method: TransferMethod, architecture: str) -> bool:
    """Return True if ``method`` can serve the given deployment ``architecture``.

    * ``"co-located"``: training and inference share one GPU, so the method
      must support the same-GPU case.
    * ``"disaggregated-*"``: training and inference run on different GPUs, so
      zero-copy methods (marked by infinite bandwidth, which share one GPU's
      memory) do not apply.
    * ``"disaggregated-cross-node"``: additionally requires cross-node support.

    Any other architecture name applies no filtering.
    """
    is_zero_copy = method.bandwidth_gbps == float('inf')
    if architecture == "co-located":
        return method.works_same_gpu
    if architecture.startswith("disaggregated") and is_zero_copy:
        # Zero-copy sharing needs both engines on the same GPU.
        return False
    if architecture == "disaggregated-cross-node":
        return method.works_across_nodes
    return True


def compare_methods(model_size_gb: float, architecture: str) -> None:
    """Print a table of estimated transfer times for one deployment scenario.

    Args:
        model_size_gb: Size of the weights to move, in GB.
        architecture: ``"co-located"``, ``"disaggregated-same-node"`` or
            ``"disaggregated-cross-node"``; selects which entries of
            ``TRANSFER_METHODS`` are shown. Any other value shows every method.
    """
    print(f"\n{'='*70}")
    print(f" WEIGHT UPDATE COMPARISON: {model_size_gb}GB Model, {architecture} Architecture")
    print(f"{'='*70}")

    applicable_methods = [
        method for method in TRANSFER_METHODS.values()
        if _applies_to_architecture(method, architecture)
    ]

    print(f"\n{'Method':<25} {'Transfer Time':<15} {'Notes':<30}")
    print("-" * 70)

    for method in applicable_methods:
        transfer_time = calculate_transfer_time(method, model_size_gb)

        # Pick the unit that keeps the number readable.
        if transfer_time < 100:
            time_str = f"{transfer_time:.1f} ms"
        elif transfer_time < 60000:
            time_str = f"{transfer_time/1000:.2f} s"
        else:
            time_str = f"{transfer_time/60000:.1f} min"

        # Zero-copy methods get a dedicated note; the rest are rated by latency.
        if method.bandwidth_gbps == float('inf'):
            notes = "Zero copy!"
        elif transfer_time < 1000:
            notes = "Excellent"
        elif transfer_time < 10000:
            notes = "Good"
        elif transfer_time < 60000:
            notes = "Acceptable"
        else:
            notes = "Slow"

        print(f"{method.name:<25} {time_str:<15} {notes:<30}")


def demonstrate_ipc_concept():
    """Print an explainer contrasting a full weight copy with CUDA IPC sharing.

    Purely informational: writes a banner and two ASCII diagrams to stdout
    and returns None.
    """
    rule = "=" * 70
    explainer = """
How CUDA IPC (Inter-Process Communication) handles work:

┌─────────────────────────────────────────────────────────────────────┐
│ TRADITIONAL WEIGHT TRANSFER                                          │
│                                                                     │
│ Training Process:                     Inference Process:            │
│ ┌─────────────────┐                  ┌─────────────────┐           │
│ │ GPU Memory:     │   COPY DATA      │ GPU Memory:     │           │
│ │ [Weight tensor] │ ───────────────► │ [Weight tensor] │           │
│ │ 140 GB          │   140 GB moved!  │ 140 GB          │           │
│ └─────────────────┘                  └─────────────────┘           │
│                                                                     │
│ Time: 140GB / 450 GB/s = 311 ms (NVLink)                           │
└─────────────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────────────┐
│ CUDA IPC HANDLE SHARING                                              │
│                                                                     │
│ Training Process:                     Inference Process:            │
│ ┌─────────────────┐                  ┌─────────────────┐           │
│ │ GPU Memory:     │   SHARE HANDLE   │ GPU Memory:     │           │
│ │ [Weight tensor] │ ───────────────► │ (same memory!)  │           │
│ │ @ address 0x7f..│   Just a pointer │ [Weight tensor] │           │
│ └────────┬────────┘   (~100 bytes)   └────────┬────────┘           │
│          │                                     │                    │
│          └──────────── SAME GPU MEMORY ────────┘                    │
│                                                                     │
│ Time: ~1 ms (only handle serialization)                            │
│ Data moved: ~100 bytes (not 140 GB!)                               │
└─────────────────────────────────────────────────────────────────────┘

The handle contains:
  - GPU device ID
  - Memory address (virtual)
  - Size and stride information
  - Reference counter handle

When the inference process "reconstructs" the tensor:
  1. It creates a new Python tensor object
  2. The tensor points to the SAME GPU memory
  3. No data is copied!

Limitation: Both processes must be on the same GPU.
For multi-GPU setups, each GPU's weights need their own handle.
"""
    # Leading blank line + rule, title, rule, then the diagram body.
    for chunk in ("\n" + rule, " CUDA IPC HANDLES: ZERO-COPY WEIGHT SHARING", rule, explainer):
        print(chunk)


def demonstrate_verl_approach():
    """Print an explainer of verl's Hybrid Engine weight-update flow.

    Walks through four stages as ASCII diagrams (FSDP-sharded training,
    all-gather, IPC handle creation, handle reconstruction on the inference
    side), followed by benefits and complexity notes. Writes to stdout only
    and returns None. The snippets inside the diagrams are presumably
    illustrative pseudocode (``serialize`` and ``reconstruct_from_handle`` are
    not defined in this script).
    """
    print("\n" + "=" * 70)
    print(" verl's WEIGHT UPDATE APPROACH")
    print("=" * 70)
    # The "transfer" time bullet says millisecond-level to agree with the
    # ~1 ms handle-serialization figure used elsewhere in this demo.
    print("""
verl's Hybrid Engine uses a sophisticated weight update mechanism:

1. TRAINING PHASE
   ┌─────────────────────────────────────────────────────────────────┐
   │ FSDP Training                                                    │
   │                                                                 │
   │ GPU 0: [Shard 0] [Shard 4] [Shard 8]  ...                      │
   │ GPU 1: [Shard 1] [Shard 5] [Shard 9]  ...                      │
   │ GPU 2: [Shard 2] [Shard 6] [Shard 10] ...                      │
   │ GPU 3: [Shard 3] [Shard 7] [Shard 11] ...                      │
   │                                                                 │
   │ Weights are sharded across GPUs (FSDP)                         │
   └─────────────────────────────────────────────────────────────────┘

2. GATHER FOR INFERENCE
   ┌─────────────────────────────────────────────────────────────────┐
   │ All-Gather to reconstruct full weights                          │
   │                                                                 │
   │ GPU 0: [Full Layer 0] [Full Layer 1] ...                       │
   │ GPU 1: [Full Layer 0] [Full Layer 1] ...                       │
   │ GPU 2: [Full Layer 0] [Full Layer 1] ...                       │
   │ GPU 3: [Full Layer 0] [Full Layer 1] ...                       │
   │                                                                 │
   │ (Temporary memory spike during gather)                         │
   └─────────────────────────────────────────────────────────────────┘

3. CREATE IPC HANDLES
   ┌─────────────────────────────────────────────────────────────────┐
   │ For each GPU's portion of weights:                              │
   │                                                                 │
   │ handle = tensor._cuda_ipc_handle()                             │
   │ serialized = serialize(handle)  # ~100 bytes                   │
   │                                                                 │
   │ Gather handles to coordinator (not data!)                      │
   └─────────────────────────────────────────────────────────────────┘

4. INFERENCE ENGINE RECEIVES HANDLES
   ┌─────────────────────────────────────────────────────────────────┐
   │ For each handle:                                                │
   │                                                                 │
   │ tensor = reconstruct_from_handle(handle)                       │
   │ # tensor now points to same GPU memory as training tensor      │
   │                                                                 │
   │ model.load_weights(tensor)  # Just pointer assignment          │
   └─────────────────────────────────────────────────────────────────┘

Benefits:
  - Zero data movement (weights stay in place)
  - Millisecond-level "transfer" time (only ~100-byte handles move)
  - Memory shared between engines

Complexity:
  - Must manage tensor lifetimes carefully
  - FSDP gather creates temporary memory spike
  - Coordination between training and inference loops
""")


def calculate_rlhf_timeline(model_size_gb: float, method_name: str,
                             generation_time_s: float, training_time_s: float) -> None:
    """Print one RLHF iteration's timeline, including two weight transfers.

    The iteration is modelled as generation, a weight transfer, training,
    and a second weight transfer. Each transfer costs
    ``calculate_transfer_time(method, model_size_gb)``. Rows are labelled by
    their position in the cycle: the transfer after generation switches the
    weights over to training, and the one after the PPO update pushes the
    freshly trained weights to the inference engine for the next rollout.

    Args:
        model_size_gb: Size of the weights moved per transfer, in GB.
        method_name: Key into ``TRANSFER_METHODS`` (e.g. ``"cuda_ipc"``).
        generation_time_s: Generation (rollout) phase duration in seconds.
        training_time_s: Training (PPO update) phase duration in seconds.

    Raises:
        KeyError: If ``method_name`` is not a key of ``TRANSFER_METHODS``.
    """
    method = TRANSFER_METHODS[method_name]
    transfer_time = calculate_transfer_time(method, model_size_gb)
    transfer_time_s = transfer_time / 1000

    total_time = generation_time_s + transfer_time_s + training_time_s + transfer_time_s

    print(f"\n{'='*70}")
    print(" RLHF ITERATION TIMELINE")
    print(f"{'='*70}")
    print("\nConfiguration:")
    print(f"  Model size: {model_size_gb} GB")
    print(f"  Transfer method: {method.name}")
    print(f"  Generation time: {generation_time_s} s")
    print(f"  Training time: {training_time_s} s")

    print("\nTimeline:")
    print(f"""
  ┌─────────────────────────────────────────────────────────────────┐
  │ Generation                                      {generation_time_s:>6.1f} s      │
  ├─────────────────────────────────────────────────────────────────┤
  │ Weight transfer (infer → train)                 {transfer_time_s:>6.2f} s      │
  ├─────────────────────────────────────────────────────────────────┤
  │ Training (PPO update)                           {training_time_s:>6.1f} s      │
  ├─────────────────────────────────────────────────────────────────┤
  │ Weight transfer (train → infer)                 {transfer_time_s:>6.2f} s      │
  └─────────────────────────────────────────────────────────────────┘
  Total iteration time: {total_time:.2f} s
""")

    # Guard a zero-length iteration (e.g. all-zero inputs) so the share
    # computation cannot raise ZeroDivisionError.
    overhead_pct = (2 * transfer_time_s) / total_time * 100 if total_time else 0.0
    print(f"  Weight transfer overhead: {overhead_pct:.1f}% of iteration")


def main(argv=None):
    """Run the full demonstration: comparisons, explainers and timelines.

    Args:
        argv: Command-line arguments to parse, without the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            existing ``main()`` calls behave as before.

    Raises:
        SystemExit: Via ``parser.error`` (exit code 2) when ``--model-size``
            is not a positive, finite number.
    """
    parser = argparse.ArgumentParser(description="Weight Update Demo")
    parser.add_argument("--model-size", "-m", type=float, default=140,
                        help="Model size in GB (default: 140 for 70B model)")
    args = parser.parse_args(argv)
    # Negative/zero sizes give meaningless (even negative) transfer times,
    # and inf/NaN break the arithmetic; `not (0 < x < inf)` also rejects NaN
    # because every comparison with NaN is False.
    if not 0 < args.model_size < float('inf'):
        parser.error(
            f"--model-size must be a positive, finite number of GB (got {args.model_size})"
        )

    print("╔" + "═" * 68 + "╗")
    print("║" + " WEIGHT UPDATE MECHANISMS DEMONSTRATION".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # Compare methods for different architectures
    for architecture in ("co-located", "disaggregated-same-node", "disaggregated-cross-node"):
        compare_methods(args.model_size, architecture)

    # Explain IPC
    demonstrate_ipc_concept()

    # Explain verl approach
    demonstrate_verl_approach()

    # Show timeline impact: zero-copy IPC versus a cross-node network transfer
    for method_name in ("cuda_ipc", "nccl_ib"):
        calculate_rlhf_timeline(
            model_size_gb=args.model_size,
            method_name=method_name,
            generation_time_s=30,
            training_time_s=20,
        )


# Run the demonstration only when executed as a script, not on import.
if __name__ == "__main__":
    main()