Keyboard shortcuts

Press ← or → to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

hello_distributed.py

Your first distributed program - see multiple processes communicate!

This is the “Hello, World!” of distributed computing. It spawns multiple processes that initialize a process group and communicate.

What It Does

  1. Spawns 4 worker processes using mp.spawn()
  2. Each process initializes the distributed environment
  3. Processes perform a simple all_gather to collect data from everyone
  4. Rank 0 prints the tensors gathered from every rank
  5. Processes perform an all_reduce to sum their ranks, and rank 0 checks the result

Run It

# Default: 4 processes
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py

# Custom world size
python tutorial/part1-distributed/chapter01-first-program/scripts/hello_distributed.py --world-size 8

Expected Output

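[Rank 0/4] Hello! PID=<pid>, Device=<device name>
[Rank 1/4] Hello! PID=<pid>, Device=<device name>
[Rank 2/4] Hello! PID=<pid>, Device=<device name>
[Rank 3/4] Hello! PID=<pid>, Device=<device name>
[Rank 0] My tensor: [0.0, 1.0]
...
all_gather results (collected on all ranks):
  From rank 0: [0.0, 1.0]
  From rank 1: [10.0, 11.0]
  From rank 2: [20.0, 21.0]
  From rank 3: [30.0, 31.0]
...
all_reduce (SUM) result: 6.0
...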

Key Concepts Demonstrated

  • mp.spawn() - Creates multiple processes, automatically passing rank
  • dist.init_process_group() - The handshake that enables communication
  • dist.all_gather() - Collect data from all processes

Source Code

#!/usr/bin/env python3
"""
Your First Distributed Program!

This script demonstrates the fundamentals of distributed PyTorch:
- Process group initialization
- Rank and world_size concepts
- Simple tensor communication with all_gather

Usage:
    python hello_distributed.py
    python hello_distributed.py --world-size 8

What this script does:
1. Spawns multiple processes (default: 4)
2. Each process initializes a distributed environment
3. Processes share information about themselves
4. We demonstrate all_gather to collect data from all processes
"""

import argparse
import os
from typing import List, Optional, Sequence

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def get_device_info() -> dict:
    """Describe the compute device and OS process backing the caller.

    Returns:
        A dict with three entries:
        - "device": ``torch.device`` for the GPU indexed by the LOCAL_RANK
          environment variable (0 when unset) if CUDA is available,
          otherwise the CPU device.
        - "device_name": the CUDA device name, or the literal "CPU".
        - "pid": the operating-system id of the calling process.
    """
    if not torch.cuda.is_available():
        # No GPU visible: everything runs on the CPU.
        device, device_name = torch.device("cpu"), "CPU"
    else:
        # LOCAL_RANK selects which GPU this process should use.
        gpu_index = int(os.environ.get("LOCAL_RANK", 0))
        device = torch.device(f"cuda:{gpu_index}")
        device_name = torch.cuda.get_device_name(device)

    return dict(device=device, device_name=device_name, pid=os.getpid())


def distributed_worker(rank: int, world_size: int, backend: str) -> None:
    """
    The main function that runs in each distributed process.

    Joins the default process group, demonstrates all_gather and all_reduce,
    and tears the process group down again even if a step fails.

    Args:
        rank: Unique identifier for this process (0 to world_size-1)
        world_size: Total number of processes
        backend: Communication backend ('gloo' or 'nccl')
    """
    # =========================================================================
    # Step 1: Initialize the process group
    # =========================================================================
    # This is the "handshake" - all processes must call this before communicating.
    # mp.spawn only passes `rank` to this function; it does NOT export any
    # rendezvous settings, so the variables read by init_process_group are set
    # here:
    #   - MASTER_ADDR: Address of rank 0 process (always this machine, since
    #     every worker is spawned locally)
    #   - MASTER_PORT: Port for communication (a value already exported by the
    #     user wins, which allows picking another port when 29500 is busy)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ.setdefault("MASTER_PORT", "29500")

    # get_device_info() selects the GPU from LOCAL_RANK, which mp.spawn does not
    # set either. Derive it from the rank so ranks are spread over the visible
    # GPUs instead of all landing on cuda:0; the modulo keeps the index valid
    # when there are more processes than GPUs.
    if torch.cuda.is_available():
        os.environ["LOCAL_RANK"] = str(rank % torch.cuda.device_count())

    dist.init_process_group(
        backend=backend,
        rank=rank,
        world_size=world_size
    )

    try:
        # =====================================================================
        # Step 2: Get device and process information
        # =====================================================================
        info = get_device_info()
        device = info["device"]
        if device.type == "cuda":
            # Make this rank's GPU the current device so collectives that
            # allocate implicitly (e.g. barrier) use it as well.
            torch.cuda.set_device(device)

        print(f"[Rank {rank}/{world_size}] Hello! PID={info['pid']}, Device={info['device_name']}")

        # =====================================================================
        # Step 3: Demonstrate all_gather - collect data from all processes
        # =====================================================================
        # Each process creates a tensor with its rank value
        # After all_gather, every process has all tensors

        # Create a tensor unique to this rank
        my_tensor = torch.tensor([rank * 10.0, rank * 10.0 + 1], device=device)
        print(f"[Rank {rank}] My tensor: {my_tensor.tolist()}")

        # Prepare a list to receive tensors from all ranks
        gathered_tensors: List[torch.Tensor] = [
            torch.zeros(2, device=device) for _ in range(world_size)
        ]

        # all_gather: collect my_tensor from all ranks into gathered_tensors
        dist.all_gather(gathered_tensors, my_tensor)

        # Keep the ranks in step before rank 0 prints the summary
        dist.barrier()

        if rank == 0:
            print("\n" + "=" * 50)
            print("all_gather results (collected on all ranks):")
            for i, tensor in enumerate(gathered_tensors):
                print(f"  From rank {i}: {tensor.tolist()}")
            print("=" * 50 + "\n")

        # =====================================================================
        # Step 4: Demonstrate all_reduce - aggregate values across all processes
        # =====================================================================
        # Each process contributes its rank, and we sum them all.
        # all_reduce works in place: my_value holds the sum afterwards.

        my_value = torch.tensor([float(rank)], device=device)
        dist.all_reduce(my_value, op=dist.ReduceOp.SUM)

        if rank == 0:
            expected_sum = sum(range(world_size))
            print(f"all_reduce (SUM) result: {my_value.item()}")
            print(f"  Expected: 0 + 1 + ... + {world_size-1} = {expected_sum}")
            print(f"  Correct: {my_value.item() == expected_sum}\n")

        # =====================================================================
        # Step 5: Show that rank 0 is special (often used as "master")
        # =====================================================================
        if rank == 0:
            print("I am rank 0 - often called the 'master' or 'coordinator'")
            print("Common responsibilities of rank 0:")
            print("  - Logging and printing results")
            print("  - Saving checkpoints")
            print("  - Orchestrating distributed operations")
    finally:
        # =====================================================================
        # Step 6: Clean up
        # =====================================================================
        # Always destroy the process group when done - the finally clause
        # guarantees this also happens when one of the steps above raises.
        dist.destroy_process_group()


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse the command line, print a banner, and spawn the worker processes.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            None, in which case argparse reads sys.argv[1:].

    Invalid combinations (non-positive world size, nccl without enough GPUs)
    are reported through ``parser.error``, which prints the usage message and
    exits with status 2 before any process is spawned.
    """
    parser = argparse.ArgumentParser(description="Your First Distributed Program")
    parser.add_argument(
        "--world-size", "-w",
        type=int,
        default=4,
        help="Number of processes to spawn (default: 4)"
    )
    parser.add_argument(
        "--backend", "-b",
        type=str,
        default="gloo",
        choices=["gloo", "nccl"],
        help="Distributed backend (default: gloo for CPU compatibility)"
    )
    args = parser.parse_args(argv)

    # Fail early with a readable message instead of an obscure error from
    # inside the spawned workers.
    if args.world_size < 1:
        parser.error(
            f"--world-size must be a positive integer (got {args.world_size})"
        )
    if args.backend == "nccl":
        # nccl communicates between GPUs and needs one GPU per process.
        if not torch.cuda.is_available():
            parser.error("--backend nccl requires CUDA; use --backend gloo instead")
        gpu_count = torch.cuda.device_count()
        if args.world_size > gpu_count:
            parser.error(
                f"--backend nccl needs one GPU per process: "
                f"--world-size is {args.world_size} but only {gpu_count} GPU(s) are visible"
            )

    print("=" * 50)
    print(" YOUR FIRST DISTRIBUTED PROGRAM")
    print("=" * 50)
    print(f"World size: {args.world_size}")
    print(f"Backend: {args.backend}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU count: {torch.cuda.device_count()}")
    print("=" * 50 + "\n")

    # Spawn worker processes
    # mp.spawn will:
    #   1. Create args.world_size new processes
    #   2. Call distributed_worker(rank, world_size, backend) in each
    #   3. Pass rank=0,1,2,... to each process automatically
    mp.spawn(
        distributed_worker,
        args=(args.world_size, args.backend),
        nprocs=args.world_size,
        join=True  # Wait for all processes to complete
    )

    print("\nAll processes completed successfully!")


# Run the demo only when this file is executed as a script. main() starts the
# workers with mp.spawn, so keeping the call behind this guard prevents a plain
# import of the module from launching another round of processes.
if __name__ == "__main__":
    main()