Keyboard shortcuts

Press or to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

rlhf_loop_pseudo.py

The complete RLHF training loop in pseudocode

This script shows the exact computation flow of one RLHF training step, making it easy to understand what happens and when.

What It Does

  1. Simulates all four models (Actor, Critic, Reward, Reference)
  2. Walks through each step of the training loop
  3. Shows tensor shapes and intermediate values
  4. Demonstrates the complete PPO update

Run It

python tutorial/part4-rlhf/chapter13-rlhf-flow/scripts/rlhf_loop_pseudo.py

Example Output

=== RLHF Training Loop Demo ===

Step 1: Sample prompts
  Batch size: 4
  Prompt shapes: (4, 64) tokens

Step 2: Generate responses (Actor)
  Actor forward pass...
  Generated tokens: (4, 128)
  Actor logits: (4, 128, 50257)
  Old log probs: (4, 128)

Step 3: Score responses (Reward Model)
  Reward model forward pass...
  Scores: [0.73, 0.45, 0.91, 0.62]

Step 4: Compute KL penalty (Reference)
  Reference forward pass...
  Reference log probs: (4, 128)
  KL divergence per token: (4, 128)
  Mean KL: 0.23

Step 5: Compute total rewards
  reward = reward_model_score - β * KL
  Total rewards: [0.50, 0.28, 0.75, 0.41]

Step 6: Compute advantages (Critic + GAE)
  Critic forward pass...
  Values: (4, 128)
  GAE advantages: (4, 128)

Step 7: PPO update
  Ratio = exp(new_log_prob - old_log_prob)
  Clipped ratio: clip(ratio, 0.8, 1.2)
  Actor loss: -0.042
  Critic loss: 0.156

  Update complete!

The Core Loop

for batch in dataloader:
    # 1. Generate
    responses, old_logprobs = actor.generate(batch.prompts)

    # 2. Score
    rewards = reward_model(batch.prompts, responses)

    # 3. KL penalty
    ref_logprobs = reference(batch.prompts, responses)
    kl = old_logprobs - ref_logprobs
    rewards = rewards - beta * kl

    # 4. Advantages
    values = critic(batch.prompts, responses)
    advantages = gae(rewards, values)

    # 5. PPO update
    new_logprobs = actor(batch.prompts, responses)
    ratio = (new_logprobs - old_logprobs).exp()
    actor_loss = -torch.min(ratio * advantages,
                           ratio.clamp(0.8, 1.2) * advantages)
    critic_loss = (values - rewards) ** 2

    # 6. Backprop
    (actor_loss + critic_loss).backward()
    optimizer.step()

Source Code

#!/usr/bin/env python3
"""
RLHF Training Loop Pseudocode

This script demonstrates the complete RLHF training loop with
detailed comments explaining each step.

This is PSEUDOCODE - not runnable without actual model implementations.
It's meant to illustrate the data flow and computations involved.

Usage:
    python rlhf_loop_pseudo.py
"""

from dataclasses import dataclass
from typing import List, Optional
import random
import math


@dataclass
class Prompt:
    """A training prompt."""
    text: str
    tokens: List[int]


@dataclass
class Response:
    """A generated response with metadata."""
    tokens: List[int]
    log_probs: List[float]  # From actor
    ref_log_probs: List[float]  # From reference
    values: List[float]  # From critic
    reward_score: float  # From reward model


@dataclass
class Experience:
    """One token of experience for PPO."""
    token: int
    log_prob: float
    ref_log_prob: float
    value: float
    reward: float
    advantage: float


def rlhf_training_step(
    prompts: List[Prompt],
    actor,  # The policy model being trained
    critic,  # The value function
    reward_model,  # Frozen reward model
    reference,  # Frozen reference policy
    kl_coef: float = 0.02,
    gamma: float = 1.0,
    lam: float = 0.95,
    clip_epsilon: float = 0.2,
) -> dict:
    """
    One step of RLHF training.

    This function shows the complete data flow through all four models.
    """
    print("=" * 70)
    print(" RLHF TRAINING STEP")
    print("=" * 70)

    # =========================================================================
    # STEP 1: Generate Responses (Actor)
    # =========================================================================
    print("\n[Step 1] GENERATE RESPONSES")
    print("-" * 50)

    responses = []
    for prompt in prompts:
        # Generate response from actor
        # In reality: autoregressive generation with temperature sampling
        response_tokens = generate_response(actor, prompt)

        # Get log probabilities from actor
        actor_log_probs = get_log_probs(actor, prompt.tokens, response_tokens)

        responses.append(Response(
            tokens=response_tokens,
            log_probs=actor_log_probs,
            ref_log_probs=[],  # Filled in step 3
            values=[],  # Filled in step 4
            reward_score=0,  # Filled in step 2
        ))
        print(f"  Generated {len(response_tokens)} tokens for prompt")

    # =========================================================================
    # STEP 2: Score Responses (Reward Model)
    # =========================================================================
    print("\n[Step 2] SCORE RESPONSES (Reward Model)")
    print("-" * 50)

    for i, (prompt, response) in enumerate(zip(prompts, responses)):
        # Get reward score for complete response
        # In reality: forward pass through reward model
        full_sequence = prompt.tokens + response.tokens
        response.reward_score = score_response(reward_model, full_sequence)
        print(f"  Response {i}: reward = {response.reward_score:.3f}")

    # =========================================================================
    # STEP 3: Compute KL Penalty (Reference Model)
    # =========================================================================
    print("\n[Step 3] COMPUTE KL PENALTY (Reference)")
    print("-" * 50)

    total_kl = 0
    for prompt, response in zip(prompts, responses):
        # Get reference log probabilities
        response.ref_log_probs = get_log_probs(
            reference, prompt.tokens, response.tokens
        )

        # Compute per-token KL divergence
        kl_per_token = []
        for actor_lp, ref_lp in zip(response.log_probs, response.ref_log_probs):
            # KL = exp(actor_lp) * (actor_lp - ref_lp)
            # Simplified: just the log ratio for penalty
            kl = actor_lp - ref_lp
            kl_per_token.append(kl)

        avg_kl = sum(kl_per_token) / len(kl_per_token)
        total_kl += avg_kl

    avg_kl = total_kl / len(responses)
    print(f"  Average KL divergence: {avg_kl:.4f}")

    # =========================================================================
    # STEP 4: Compute Values (Critic)
    # =========================================================================
    print("\n[Step 4] COMPUTE VALUES (Critic)")
    print("-" * 50)

    for prompt, response in zip(prompts, responses):
        # Get value estimates for each token position
        # In reality: forward pass through critic
        response.values = get_values(critic, prompt.tokens, response.tokens)
        print(f"  Values computed: mean={sum(response.values)/len(response.values):.3f}")

    # =========================================================================
    # STEP 5: Compute Rewards with KL Penalty
    # =========================================================================
    print("\n[Step 5] COMPUTE REWARDS WITH KL PENALTY")
    print("-" * 50)

    all_experiences = []

    for prompt, response in zip(prompts, responses):
        experiences = []

        for t in range(len(response.tokens)):
            # Per-token KL penalty
            kl_penalty = kl_coef * (response.log_probs[t] - response.ref_log_probs[t])

            # Reward: only at last token, minus KL at every token
            if t == len(response.tokens) - 1:
                token_reward = response.reward_score - kl_penalty
            else:
                token_reward = -kl_penalty  # Just KL penalty for non-final tokens

            experiences.append(Experience(
                token=response.tokens[t],
                log_prob=response.log_probs[t],
                ref_log_prob=response.ref_log_probs[t],
                value=response.values[t],
                reward=token_reward,
                advantage=0,  # Computed in step 6
            ))

        all_experiences.append(experiences)
        final_reward = experiences[-1].reward
        print(f"  Final token reward: {final_reward:.3f} "
              f"(score={response.reward_score:.3f}, kl_penalty included)")

    # =========================================================================
    # STEP 6: Compute GAE Advantages
    # =========================================================================
    print("\n[Step 6] COMPUTE GAE ADVANTAGES")
    print("-" * 50)

    for experiences in all_experiences:
        # GAE computation (backwards)
        gae = 0
        for t in reversed(range(len(experiences))):
            exp = experiences[t]

            if t == len(experiences) - 1:
                next_value = 0  # Terminal state
            else:
                next_value = experiences[t + 1].value

            # TD error
            delta = exp.reward + gamma * next_value - exp.value

            # GAE
            gae = delta + gamma * lam * gae
            exp.advantage = gae

        # Normalize advantages
        advantages = [e.advantage for e in experiences]
        mean_adv = sum(advantages) / len(advantages)
        std_adv = (sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)) ** 0.5
        for exp in experiences:
            exp.advantage = (exp.advantage - mean_adv) / (std_adv + 1e-8)

        print(f"  Advantages computed and normalized")

    # =========================================================================
    # STEP 7: PPO Update
    # =========================================================================
    print("\n[Step 7] PPO UPDATE")
    print("-" * 50)

    # Flatten all experiences
    flat_experiences = [exp for exps in all_experiences for exp in exps]

    # Compute PPO losses
    policy_losses = []
    value_losses = []
    clip_fractions = []

    for exp in flat_experiences:
        # New log probability (after potential update)
        # In reality: forward pass through updated actor
        new_log_prob = exp.log_prob  # Placeholder

        # Probability ratio
        ratio = math.exp(new_log_prob - exp.log_prob)

        # Clipped objective
        unclipped = ratio * exp.advantage
        clipped = max(min(ratio, 1 + clip_epsilon), 1 - clip_epsilon) * exp.advantage

        policy_loss = -min(unclipped, clipped)
        policy_losses.append(policy_loss)

        # Value loss
        # In reality: new value prediction
        new_value = exp.value  # Placeholder
        value_loss = (new_value - (exp.reward + gamma * 0)) ** 2  # Simplified
        value_losses.append(value_loss)

        # Track clipping
        if abs(ratio - 1) > clip_epsilon:
            clip_fractions.append(1)
        else:
            clip_fractions.append(0)

    avg_policy_loss = sum(policy_losses) / len(policy_losses)
    avg_value_loss = sum(value_losses) / len(value_losses)
    avg_clip_frac = sum(clip_fractions) / len(clip_fractions)

    print(f"  Policy loss: {avg_policy_loss:.4f}")
    print(f"  Value loss: {avg_value_loss:.4f}")
    print(f"  Clip fraction: {avg_clip_frac:.2%}")

    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print(" STEP SUMMARY")
    print("=" * 70)
    print(f"""
Models used:
  - Actor: Generated {sum(len(r.tokens) for r in responses)} total tokens
  - Reward: Scored {len(responses)} responses
  - Reference: Computed KL for {len(responses)} responses
  - Critic: Estimated values for {sum(len(r.tokens) for r in responses)} tokens

Losses:
  - Policy loss: {avg_policy_loss:.4f}
  - Value loss: {avg_value_loss:.4f}

KL penalty:
  - Average KL: {avg_kl:.4f}
  - KL coefficient: {kl_coef}
  - Total KL penalty: {avg_kl * kl_coef:.4f}
""")

    return {
        'policy_loss': avg_policy_loss,
        'value_loss': avg_value_loss,
        'kl': avg_kl,
        'clip_fraction': avg_clip_frac,
    }


# =============================================================================
# Placeholder functions (would be real model calls in practice)
# =============================================================================

def generate_response(actor, prompt: Prompt) -> List[int]:
    """Generate response tokens from actor."""
    # Simulated: random tokens
    length = random.randint(10, 30)
    return [random.randint(0, 999) for _ in range(length)]


def get_log_probs(model, prompt_tokens: List[int],
                   response_tokens: List[int]) -> List[float]:
    """Get log probabilities from model."""
    # Simulated: random log probs
    return [random.uniform(-3, -0.1) for _ in response_tokens]


def score_response(reward_model, tokens: List[int]) -> float:
    """Get reward score from reward model."""
    # Simulated: random score
    return random.uniform(-1, 1)


def get_values(critic, prompt_tokens: List[int],
                response_tokens: List[int]) -> List[float]:
    """Get value estimates from critic."""
    # Simulated: decreasing values
    n = len(response_tokens)
    return [0.5 * (n - i) / n for i in range(n)]


def main():
    print("╔" + "═" * 68 + "╗")
    print("║" + " RLHF TRAINING LOOP DEMONSTRATION".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # Create sample prompts
    prompts = [
        Prompt("What is the capital of France?", [1, 2, 3, 4, 5]),
        Prompt("Explain quantum computing.", [6, 7, 8, 9]),
        Prompt("Write a haiku about programming.", [10, 11, 12, 13, 14]),
        Prompt("What is machine learning?", [15, 16, 17]),
    ]

    print(f"\nBatch size: {len(prompts)} prompts")

    # Run one training step
    stats = rlhf_training_step(
        prompts=prompts,
        actor=None,  # Placeholder
        critic=None,
        reward_model=None,
        reference=None,
        kl_coef=0.02,
    )

    # Explain the process
    print("\n" + "=" * 70)
    print(" WHAT JUST HAPPENED")
    print("=" * 70)
    print("""
This simulated one complete RLHF training step:

1. GENERATION: Actor generated responses for each prompt
2. SCORING: Reward model evaluated response quality
3. KL COMPUTATION: Reference model computed divergence penalty
4. VALUE ESTIMATION: Critic predicted expected rewards
5. ADVANTAGE COMPUTATION: GAE combined rewards and values
6. PPO UPDATE: Actor and critic weights updated

In production, this happens with:
- Real neural network forward/backward passes
- GPU tensor operations
- Distributed training across multiple devices
- Gradient accumulation and synchronization

The key insight: RLHF is just PPO with:
- Reward from a learned reward model
- KL penalty to stay close to reference
- Four models instead of just actor-critic
""")


if __name__ == "__main__":
    main()