gae_visualizer.py

Visualize Generalized Advantage Estimation (GAE)

This script helps you understand how GAE works by visualizing the advantage computation for different λ values.

What It Does

  1. Creates a sample trajectory with rewards and values
  2. Computes advantages with different λ values
  3. Visualizes how λ affects the bias-variance tradeoff
  4. Shows why λ=0.95 is common

Run It

python tutorial/part4-rlhf/chapter12-rl-fundamentals/scripts/gae_visualizer.py

Example Output

=== GAE Visualizer ===

Sample trajectory (10 steps):
  Rewards: [0, 0, 0, 1, 0, 0, 0, 0, 1, 0]
  Values:  [0.5, 0.6, 0.7, 0.8, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

TD Errors (δ):
  Step 0: δ = 0 + 0.99*0.6 - 0.5 = 0.094
  Step 1: δ = 0 + 0.99*0.7 - 0.6 = 0.093
  ...

Advantages by λ:

λ = 0.0 (TD error only, high bias):
  A = [0.09, 0.09, 0.11, 0.42, -0.36, 0.10, 0.11, 0.12, 0.31, -0.89]
  Variance: 0.15

λ = 0.5 (balanced):
  A = [0.21, 0.19, 0.22, 0.38, -0.15, 0.16, 0.18, 0.18, 0.24, -0.89]
  Variance: 0.12

λ = 0.95 (common choice):
  A = [0.45, 0.38, 0.35, 0.32, -0.02, 0.23, 0.22, 0.20, 0.18, -0.89]
  Variance: 0.14

λ = 1.0 (full returns, low bias):
  A = [0.52, 0.44, 0.40, 0.35, 0.01, 0.26, 0.24, 0.21, 0.18, -0.89]
  Variance: 0.16

The GAE Formula

A^GAE_t = δ_t + (γλ)δ_{t+1} + (γλ)²δ_{t+2} + ...
        = Σ (γλ)^k δ_{t+k}

Where δ_t = r_t + γV(s_{t+1}) - V(s_t)
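
The sum above can be evaluated term by term, or with the backward recursion A_t = δ_t + γλ·A_{t+1} that the script uses. The following is a minimal sketch that checks the two agree on the 3-step trajectory from the script's numerical example. The helper names `td_errors`, `gae_direct`, and `gae_recursive` are illustrative and are not part of the script.

```python
from typing import List


def td_errors(rewards: List[float], values: List[float], gamma: float) -> List[float]:
    """δ_t = r_t + γV(s_{t+1}) - V(s_t), with V = 0 after the last step."""
    n = len(rewards)
    return [
        rewards[t] + gamma * (values[t + 1] if t + 1 < n else 0.0) - values[t]
        for t in range(n)
    ]


def gae_direct(deltas: List[float], gamma: float, lam: float) -> List[float]:
    """Evaluate Σ_k (γλ)^k δ_{t+k} term by term for every t."""
    n = len(deltas)
    return [
        sum((gamma * lam) ** k * deltas[t + k] for k in range(n - t))
        for t in range(n)
    ]


def gae_recursive(deltas: List[float], gamma: float, lam: float) -> List[float]:
    """Same quantity via A_t = δ_t + γλ·A_{t+1}, computed backwards."""
    advantages = [0.0] * len(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages


deltas = td_errors([0.1, 0.2, 1.0], [0.5, 0.4, 0.2], gamma=0.99)
direct = gae_direct(deltas, gamma=0.99, lam=0.95)
recursive = gae_recursive(deltas, gamma=0.99, lam=0.95)
for t, (a, b) in enumerate(zip(direct, recursive)):
    print(f"t={t}: direct={a:.4f}  recursive={b:.4f}")
```

The direct sum costs O(T²) for a trajectory of length T, while the recursion is O(T), which is presumably why the script computes advantages backwards.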

Why λ = 0.95?

  • λ = 0: Only considers immediate TD error (high bias, low variance)
  • λ = 1: Full Monte Carlo returns (low bias, high variance)
  • λ = 0.95: Good balance: stays close to the low-bias Monte Carlo end, while down-weighting distant TD errors by (γλ)^k to trim variance
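
One way to read these three cases is through the weight (γλ)^k that GAE puts on the TD error k steps ahead. The sketch below assumes γ = 0.99, the script's default. The "effective horizon" 1/(1-γλ) is simply the sum of the geometric weights; it is used here as an informal rule of thumb and is not something the script computes. Both function names are hypothetical.

```python
from typing import List


def td_error_weights(gamma: float, lam: float, steps: int) -> List[float]:
    """Weight (γλ)^k applied to the TD error k steps ahead, for k = 0..steps-1."""
    return [(gamma * lam) ** k for k in range(steps)]


def effective_horizon(gamma: float, lam: float) -> float:
    """Sum of the infinite geometric series of weights: 1 / (1 - γλ)."""
    decay = gamma * lam
    return float("inf") if decay >= 1.0 else 1.0 / (1.0 - decay)


gamma = 0.99
for lam in (0.0, 0.5, 0.95, 1.0):
    weights = td_error_weights(gamma, lam, steps=4)
    shown = ", ".join(f"{w:.2f}" for w in weights)
    print(f"λ={lam:<4} weights=[{shown}]  horizon≈{effective_horizon(gamma, lam):.1f}")
```

With γ = 0.99, λ = 0 counts only the current TD error, λ = 0.95 spreads its weight over roughly 17 steps, and λ = 1 over about 100.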

Source Code

#!/usr/bin/env python3
"""
GAE (Generalized Advantage Estimation) Visualizer

This script demonstrates how GAE works and its effect on advantage estimation.

Usage:
    python gae_visualizer.py
"""

import argparse
from typing import List, Tuple


def generate_trajectory(length: int = 10) -> Tuple[List[float], List[float]]:
    """
    Generate a sample trajectory with rewards and values.

    Returns:
        rewards: List of rewards at each step
        values: List of value estimates at each step
    """
    # Simulated trajectory: mostly small rewards, occasional large
    rewards = [
        0.1, 0.1, 0.2, 0.1, 0.5,  # Early exploration
        0.1, 0.3, 0.1, 0.1, 1.0,  # Some success at end
    ][:length]

    # Value estimates (what the critic predicts)
    values = [
        0.8, 0.7, 0.7, 0.6, 0.5,  # Decreasing as end approaches
        0.4, 0.4, 0.3, 0.2, 0.1,
    ][:length]

    return rewards, values


def compute_td_errors(rewards: List[float], values: List[float],
                       gamma: float = 0.99) -> List[float]:
    """
    Compute TD (Temporal Difference) errors.

    TD error = r_t + γV(s_{t+1}) - V(s_t)

    This is the "surprise" - how much better/worse than expected.
    """
    td_errors = []
    n = len(rewards)

    for t in range(n):
        r = rewards[t]
        v_t = values[t]
        v_next = values[t + 1] if t + 1 < n else 0  # Terminal state has 0 value

        delta = r + gamma * v_next - v_t
        td_errors.append(delta)

    return td_errors


def compute_gae(td_errors: List[float], gamma: float = 0.99,
                lam: float = 0.95) -> List[float]:
    """
    Compute GAE advantages.

    A^GAE_t = Σ_{k=0}^{T-t-1} (γλ)^k δ_{t+k}    (T = number of steps)

    λ controls bias-variance tradeoff:
    - λ=0: Just TD error (high bias, low variance)
    - λ=1: Full returns minus baseline (low bias, high variance)
    """
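    n = len(td_errors)
    advantages = [0.0] * n
    gae = 0.0

    # Work backwards: A_t = δ_t + γλ·A_{t+1}, with nothing past the last step
    for t in reversed(range(n)):
        gae = td_errors[t] + gamma * lam * gae
        advantages[t] = gae  # Fill in place instead of O(n) insert(0, ...)

    return advantages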


def compute_monte_carlo_returns(rewards: List[float], values: List[float],
                                 gamma: float = 0.99) -> List[float]:
    """
    Compute Monte Carlo returns (full returns minus baseline).

    This is GAE with λ=1.
    """
    n = len(rewards)
    returns = [0.0] * n
    G = 0

    for t in reversed(range(n)):
        G = rewards[t] + gamma * G
        returns[t] = G - values[t]  # Advantage = return - baseline

    return returns


def visualize_advantages(rewards: List[float], values: List[float],
                          td_errors: List[float],
                          advantages_by_lambda: dict):
    """Visualize how different λ values affect advantage estimation."""
    n = len(rewards)

    print("\n" + "=" * 80)
    print(" TRAJECTORY DATA")
    print("=" * 80)

    print(f"\n{'Step':<6} {'Reward':<10} {'Value':<10} {'TD Error':<12}")
    print("-" * 40)
    for t in range(n):
        print(f"{t:<6} {rewards[t]:<10.2f} {values[t]:<10.2f} {td_errors[t]:<12.4f}")

    print("\n" + "=" * 80)
    print(" GAE WITH DIFFERENT λ VALUES")
    print("=" * 80)

    header = f"{'Step':<6}"
    for lam in sorted(advantages_by_lambda.keys()):
        header += f"{'λ=' + str(lam):<12}"
    print(f"\n{header}")
    print("-" * (6 + 12 * len(advantages_by_lambda)))

    for t in range(n):
        row = f"{t:<6}"
        for lam in sorted(advantages_by_lambda.keys()):
            row += f"{advantages_by_lambda[lam][t]:<12.4f}"
        print(row)


def analyze_bias_variance():
    """Analyze the bias-variance tradeoff in GAE."""
    print("\n" + "=" * 80)
    print(" BIAS-VARIANCE TRADEOFF ANALYSIS")
    print("=" * 80)
    print("""
GAE with different λ values:

┌────────────────────────────────────────────────────────────────────┐
│                     λ = 0 (TD Error Only)                          │
│                                                                    │
│  A_t = δ_t = r_t + γV(s_{t+1}) - V(s_t)                           │
│                                                                    │
│  Properties:                                                        │
│    - HIGH BIAS: Only looks one step ahead                          │
│    - LOW VARIANCE: Single reward, single value estimate            │
│    - Fast to adapt, but might miss long-term patterns             │
└────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────┐
│                     λ = 1 (Monte Carlo)                            │
│                                                                    │
│  A_t = G_t - V(s_t) = Σ γ^k r_{t+k} - V(s_t)                      │
│                                                                    │
│  Properties:                                                        │
│    - LOW BIAS: Uses all future rewards                             │
│    - HIGH VARIANCE: Accumulates noise from many rewards            │
│    - Accurate but slow to learn                                    │
└────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────┐
│                     λ = 0.95 (Typical)                             │
│                                                                    │
│  A_t = Σ (γλ)^k δ_{t+k}                                           │
│                                                                    │
│  Properties:                                                        │
│    - BALANCED: Weights earlier steps more                          │
│    - PRACTICAL: Good empirical performance                         │
│    - Exponential decay of TD errors                                │
└────────────────────────────────────────────────────────────────────┘

The weighting scheme (λ = 0.95, γ = 0.99):

  Step t:   weight = 1.00
  Step t+1: weight = 0.94 (γλ = 0.99 × 0.95)
  Step t+2: weight = 0.88 (γλ)²
  Step t+3: weight = 0.83 (γλ)³
  ...
  Step t+10: weight = 0.54 (γλ)¹⁰

Far future TD errors contribute less, reducing variance while
maintaining enough signal for learning.
""")


def demonstrate_numerical_example():
    """Show a concrete numerical example of GAE computation."""
    print("\n" + "=" * 80)
    print(" NUMERICAL EXAMPLE: GAE COMPUTATION")
    print("=" * 80)

    # Simple 3-step trajectory
    rewards = [0.1, 0.2, 1.0]  # Big reward at end
    values = [0.5, 0.4, 0.2]   # Decreasing values
    gamma = 0.99
    lam = 0.95

    # Reuse the helpers above so the printed numbers always match the code
    d0, d1, d2 = compute_td_errors(rewards, values, gamma)
    a0, a1, a2 = compute_gae([d0, d1, d2], gamma, lam)

    # Note: literal braces in an f-string must be doubled ({{t+1}}),
    # otherwise Python tries to evaluate t+1 and raises NameError.
    print(f"""
Trajectory:
  Step 0: r=0.1, V=0.5
  Step 1: r=0.2, V=0.4
  Step 2: r=1.0, V=0.2 (last step; V=0 afterwards)

TD Errors (δ_t = r_t + γV_{{t+1}} - V_t):
  δ_0 = 0.1 + 0.99×0.4 - 0.5 = {d0:.4f}
  δ_1 = 0.2 + 0.99×0.2 - 0.4 = {d1:.4f}
  δ_2 = 1.0 + 0.99×0.0 - 0.2 = {d2:.4f}

GAE Computation (working backwards, λ={lam}):
  A_2 = δ_2 = {a2:.4f}
  A_1 = δ_1 + γλ×A_2 = {d1:.4f} + {gamma * lam:.4f}×{a2:.4f}
      = {a1:.4f}
  A_0 = δ_0 + γλ×A_1 = {d0:.4f} + {gamma * lam:.4f}×{a1:.4f}
      = {a0:.4f}

Notice: Step 0's advantage includes discounted information about the
big reward at step 2, but that information is attenuated by (γλ)².
""")


def main():
    parser = argparse.ArgumentParser(description="GAE Visualizer")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="Discount factor")
    parser.add_argument("--trajectory-length", type=int, default=10,
                        help="Length of trajectory (capped at the 10 built-in steps)")
    args = parser.parse_args()

    print("╔" + "═" * 78 + "╗")
    print("║" + " GENERALIZED ADVANTAGE ESTIMATION (GAE) VISUALIZER".center(78) + "║")
    print("╚" + "═" * 78 + "╝")

    # Generate trajectory
    rewards, values = generate_trajectory(args.trajectory_length)

    # Compute TD errors
    td_errors = compute_td_errors(rewards, values, args.gamma)

    # Compute GAE for different lambda values
    lambda_values = [0.0, 0.5, 0.9, 0.95, 1.0]
    advantages_by_lambda = {}

    for lam in lambda_values:
        advantages_by_lambda[lam] = compute_gae(td_errors, args.gamma, lam)

    # Visualize
    visualize_advantages(rewards, values, td_errors, advantages_by_lambda)

    # Analysis
    analyze_bias_variance()

    # Numerical example
    demonstrate_numerical_example()

    # Key insights
    print("\n" + "=" * 80)
    print(" KEY INSIGHTS FOR RLHF")
    print("=" * 80)
    print("""
In RLHF training:

1. TRAJECTORY = One response generation
   - States: prompt + partial response
   - Actions: generated tokens
   - Reward: typically only at the end (from reward model)

2. GAE HELPS WITH CREDIT ASSIGNMENT
   - Which tokens contributed to the final reward?
   - GAE propagates reward signal backwards through the response
   - λ controls how far back the signal reaches

3. TYPICAL RLHF SETTINGS
   - γ = 0.99 or 1.0 (we care about all tokens)
   - λ = 0.95 (good balance)
   - Sparse reward (only at end of generation)

4. VALUE FUNCTION IN RLHF
   - Critic network predicts expected future reward (the return)
   - Helps reduce variance in policy gradient
   - Often shares layers with the policy (actor-critic)

5. PPO USES GAE ADVANTAGES
   - Compute GAE for each token in response
   - Update policy using PPO-Clip objective
   - Clipped updates keep each policy step small, stabilizing training
""")


if __name__ == "__main__":
    main()
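
The RLHF notes printed by `main()` say that the reward typically arrives only at the end of a response and that λ controls how far back the signal reaches. The sketch below illustrates that claim under a deliberately simplified assumption: the critic outputs V(s) = 0 everywhere, so every TD error is zero except the last one. The function `sparse_reward_gae` and its parameters are hypothetical and are not part of the script.

```python
from typing import List


def sparse_reward_gae(num_tokens: int, final_reward: float,
                      gamma: float, lam: float) -> List[float]:
    """GAE for a response whose only reward arrives on the last token.

    Assumes V(s) = 0 everywhere (an uninformative critic), so every
    TD error is 0 except the last one, which equals final_reward.
    """
    deltas = [0.0] * (num_tokens - 1) + [final_reward]
    advantages = [0.0] * num_tokens
    running = 0.0
    # Backward recursion: A_t = δ_t + γλ·A_{t+1}
    for t in reversed(range(num_tokens)):
        running = deltas[t] + gamma * lam * running
        advantages[t] = running
    return advantages


for lam in (0.0, 0.5, 0.95, 1.0):
    adv = sparse_reward_gae(num_tokens=5, final_reward=1.0, gamma=1.0, lam=lam)
    print(f"λ={lam:<4} " + "  ".join(f"{a:.3f}" for a in adv))
```

Under this zero-critic assumption the advantage at token t is (γλ)^(T-1-t) times the final reward, where T is the number of tokens: with λ = 0 only the last token receives credit, and with γ = λ = 1 every token receives the full reward.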