reward_calculator.py

Understand reward calculation with KL penalty

This script demonstrates how the total reward in RLHF is computed from the reward model score and KL penalty.

What It Does

  1. Shows raw reward model scores
  2. Computes KL divergence between actor and reference
  3. Applies the KL penalty with different β values
  4. Demonstrates why the penalty prevents reward hacking

Run It

python tutorial/part4-rlhf/chapter13-rlhf-flow/scripts/reward_calculator.py

Example Output

=== RLHF Reward Calculator ===

Response: "This is a great product! I highly recommend it!"

Reward Model Score: 0.85 (high quality response)

KL Divergence Calculation:
  Actor log prob for each token:
    "This": -2.3,  "is": -1.1,  "a": -0.8,  ...
  Reference log prob for each token:
    "This": -2.1,  "is": -1.0,  "a": -0.9,  ...

  KL per token = actor_logp - ref_logp
    "This": -0.2,  "is": -0.1,  "a": +0.1,  ...

  Total KL: 0.45 (actor has diverged from reference)

Total Reward with Different β:
  β = 0.0: R = 0.85 - 0.0 * 0.45 = 0.85
  β = 0.1: R = 0.85 - 0.1 * 0.45 = 0.805
  β = 0.5: R = 0.85 - 0.5 * 0.45 = 0.625
  β = 1.0: R = 0.85 - 1.0 * 0.45 = 0.40

Observation: Higher β penalizes divergence more heavily.
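
The β sweep above is just the scalar formula R = rm_score - β × total_KL applied four times. A minimal sketch in plain Python (reusing the 0.85 / 0.45 numbers from the example output, which are illustrative values) reproduces the table:

rm_score, total_kl = 0.85, 0.45          # numbers from the example output above
for beta in [0.0, 0.1, 0.5, 1.0]:
    print(f"beta = {beta}: R = {rm_score} - {beta} * {total_kl} = {rm_score - beta * total_kl:.3f}")
# beta = 0.0: R = 0.85 - 0.0 * 0.45 = 0.850
# beta = 0.1: R = 0.85 - 0.1 * 0.45 = 0.805
# beta = 0.5: R = 0.85 - 0.5 * 0.45 = 0.625
# beta = 1.0: R = 0.85 - 1.0 * 0.45 = 0.400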

Why KL Penalty Matters

Without penalty (β=0):
  Actor learns to say "AMAZING! INCREDIBLE!" for everything
  Reward model gives high scores
  But output is unnatural

With penalty (β=0.1):
  Actor stays close to reference
  Must improve while remaining natural
  Better quality outputs
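
A quick comparison makes this concrete. The numbers below are made up for illustration (they are not produced by the script): a "hacked" response that the reward model loves but that drifts far from the reference loses to a natural response once the KL penalty is applied.

beta = 0.1
# Hypothetical (rm_score, total_kl) pairs for two candidate responses
candidates = {
    "hacked (AMAZING! INCREDIBLE!)": (0.95, 6.00),   # high RM score, far from reference
    "natural recommendation":        (0.85, 0.45),   # slightly lower score, stays close
}
for name, (rm_score, total_kl) in candidates.items():
    print(f"{name}: R = {rm_score} - {beta} * {total_kl} = {rm_score - beta * total_kl:.3f}")
# hacked (AMAZING! INCREDIBLE!): R = 0.95 - 0.1 * 6.0  = 0.350
# natural recommendation:        R = 0.85 - 0.1 * 0.45 = 0.805  -> wins under the penalty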

The Formula

def compute_reward(response, actor, reference, reward_model, beta):
    # Get reward model score
    rm_score = reward_model(response)

    # Compute KL divergence
    actor_logp = actor.log_prob(response)
    ref_logp = reference.log_prob(response)
    kl = (actor_logp - ref_logp).sum()

    # Total reward with penalty
    total_reward = rm_score - beta * kl

    return total_reward
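
As a quick sanity check, the sketch below exercises compute_reward with toy stand-in "models" (hypothetical classes, for illustration only). The per-token log probs are made up so that the total KL comes out to 0.45, matching the example above; it assumes the compute_reward function from the block above is in scope and that log_prob returns a NumPy array so that .sum() works:

import numpy as np

class ToyPolicy:
    """Stand-in for a policy model: returns fixed per-token log probs."""
    def __init__(self, log_probs):
        self._log_probs = np.array(log_probs)

    def log_prob(self, response):
        return self._log_probs

actor = ToyPolicy([-1.00, -1.20, -0.80])        # made-up actor log probs
reference = ToyPolicy([-1.20, -1.40, -0.85])    # per-token KL: 0.20 + 0.20 + 0.05 = 0.45
reward_model = lambda response: 0.85            # fixed RM score for this toy response

total = compute_reward("This is a great product!", actor, reference, reward_model, beta=0.1)
print(f"{total:.3f}")   # 0.85 - 0.1 * 0.45 = 0.805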

Source Code

#!/usr/bin/env python3
"""
RLHF Reward Calculator

This script demonstrates how rewards are computed in RLHF:
- Reward model scoring
- KL divergence penalty
- Combined reward signal

Usage:
    python reward_calculator.py
"""

import argparse
import math
from typing import List, Tuple


def compute_kl_divergence(actor_log_probs: List[float],
                          ref_log_probs: List[float]) -> List[float]:
    """
    Compute per-token KL divergence.

    KL(actor || ref) = Σ p_actor * log(p_actor / p_ref)
                     = Σ p_actor * (log p_actor - log p_ref)

    Since we have log probs, this simplifies to computing the difference
    and then exponentiating to get actual KL.
    """
    kl_per_token = []
    for actor_lp, ref_lp in zip(actor_log_probs, ref_log_probs):
        # Per-token KL approximation: log-prob difference at the sampled token.
        # (The exact KL would sum p_actor(x) * (log p_actor(x) - log p_ref(x))
        #  over the whole vocabulary; the difference is the standard estimator.)
        kl = actor_lp - ref_lp
        kl_per_token.append(kl)
    return kl_per_token


def compute_rewards(
    reward_model_score: float,
    actor_log_probs: List[float],
    ref_log_probs: List[float],
    kl_coef: float = 0.02,
    reward_at_end_only: bool = True,
) -> Tuple[List[float], dict]:
    """
    Compute per-token rewards with KL penalty.

    Args:
        reward_model_score: Score from reward model (typically for full response)
        actor_log_probs: Log probabilities from actor for each token
        ref_log_probs: Log probabilities from reference for each token
        kl_coef: Coefficient for KL penalty (β in papers)
        reward_at_end_only: If True, RM score only at last token

    Returns:
        List of rewards for each token
        Dictionary with stats
    """
    num_tokens = len(actor_log_probs)

    # Compute KL divergence
    kl_per_token = compute_kl_divergence(actor_log_probs, ref_log_probs)

    # Compute rewards
    rewards = []
    for t in range(num_tokens):
        kl_penalty = kl_coef * kl_per_token[t]

        if reward_at_end_only:
            # RM score only at last token
            if t == num_tokens - 1:
                r = reward_model_score - kl_penalty
            else:
                r = -kl_penalty  # Only penalty
        else:
            # RM score distributed across tokens
            r = reward_model_score / num_tokens - kl_penalty

        rewards.append(r)

    stats = {
        'total_kl': sum(kl_per_token),
        'avg_kl': sum(kl_per_token) / len(kl_per_token),
        'total_kl_penalty': sum(kl_coef * kl for kl in kl_per_token),
        'reward_model_score': reward_model_score,
        'total_reward': sum(rewards),
    }

    return rewards, stats


def visualize_rewards(rewards: List[float], kl_per_token: List[float],
                       kl_coef: float, rm_score: float):
    """Visualize reward distribution across tokens."""
    print("\n" + "=" * 70)
    print(" TOKEN-LEVEL REWARD BREAKDOWN")
    print("=" * 70)

    print(f"\n{'Token':<8} {'KL':<12} {'KL Penalty':<12} {'RM Contrib':<12} {'Reward':<12}")
    print("-" * 60)

    num_tokens = len(rewards)
    for t in range(num_tokens):
        kl = kl_per_token[t]
        penalty = kl_coef * kl
        rm_contrib = rm_score if t == num_tokens - 1 else 0

        print(f"{t:<8} {kl:>+.4f}     {-penalty:>+.4f}      {rm_contrib:>+.4f}      {rewards[t]:>+.4f}")

    print("-" * 60)
    print(f"{'Total':<8} {sum(kl_per_token):>+.4f}     {-kl_coef*sum(kl_per_token):>+.4f}      "
          f"{rm_score:>+.4f}      {sum(rewards):>+.4f}")


def demonstrate_kl_penalty_effect():
    """Show how KL coefficient affects learning."""
    print("\n" + "=" * 70)
    print(" KL COEFFICIENT EFFECT")
    print("=" * 70)

    # Simulated response with moderate divergence
    actor_lps = [-1.0, -1.2, -0.8, -1.5, -0.9]  # Actor log probs
    ref_lps = [-1.1, -1.0, -1.0, -1.2, -1.0]    # Reference log probs
    rm_score = 0.5  # Positive reward from RM

    print("\nScenario: Response with RM score = 0.5")
    print("Actor is somewhat divergent from reference\n")

    kl_values = [0.0, 0.01, 0.02, 0.05, 0.1, 0.5]

    print(f"{'KL Coef':<10} {'Total KL Penalty':<18} {'Net Reward':<12} {'Effect':<20}")
    print("-" * 60)

    for kl_coef in kl_values:
        rewards, stats = compute_rewards(rm_score, actor_lps, ref_lps, kl_coef)
        net_reward = stats['total_reward']

        if net_reward > rm_score * 0.8:
            effect = "Weak penalty"
        elif net_reward > 0:
            effect = "Moderate penalty"
        elif net_reward > -0.5:
            effect = "Strong penalty"
        else:
            effect = "Overwhelming penalty"

        print(f"{kl_coef:<10} {stats['total_kl_penalty']:>+.4f}            {net_reward:>+.4f}       {effect}")

    print("""
Interpretation:
  β ≈ 0.00: No penalty, risk of reward hacking
  β ≈ 0.02: Typical value, balanced
  β ≈ 0.10: Strong regularization, slower learning
  β ≈ 0.50: KL dominates, almost no RM signal
""")


def demonstrate_kl_scenarios():
    """Show different KL divergence scenarios."""
    print("\n" + "=" * 70)
    print(" KL DIVERGENCE SCENARIOS")
    print("=" * 70)

    kl_coef = 0.02

    scenarios = [
        ("Low divergence (similar to reference)", [-1.0, -1.1, -0.9], [-1.0, -1.0, -1.0]),
        ("High divergence (very different)", [-0.5, -2.0, -0.3], [-1.5, -0.8, -1.2]),
        ("More confident than reference", [-0.2, -0.3, -0.2], [-1.0, -1.0, -1.0]),
        ("Less confident than reference", [-2.0, -2.5, -2.0], [-1.0, -1.0, -1.0]),
    ]

    rm_score = 0.5

    for name, actor_lps, ref_lps in scenarios:
        kl_per_token = compute_kl_divergence(actor_lps, ref_lps)
        rewards, stats = compute_rewards(rm_score, actor_lps, ref_lps, kl_coef)

        print(f"\n{name}:")
        print(f"  Actor log probs: {actor_lps}")
        print(f"  Ref log probs:   {ref_lps}")
        print(f"  Per-token KL:    {[f'{k:.2f}' for k in kl_per_token]}")
        print(f"  Total KL:        {stats['total_kl']:.4f}")
        print(f"  KL penalty:      {stats['total_kl_penalty']:.4f}")
        print(f"  Net reward:      {stats['total_reward']:.4f}")


def main():
    parser = argparse.ArgumentParser(description="RLHF Reward Calculator")
    parser.add_argument("--kl-coef", "-k", type=float, default=0.02,
                        help="KL penalty coefficient")
    parser.add_argument("--rm-score", "-r", type=float, default=0.5,
                        help="Reward model score")
    args = parser.parse_args()

    print("╔" + "═" * 68 + "╗")
    print("║" + " RLHF REWARD CALCULATOR".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # Example response
    actor_log_probs = [-1.2, -0.8, -1.5, -0.9, -1.1, -1.0, -0.7, -1.3]
    ref_log_probs = [-1.0, -1.0, -1.2, -1.0, -1.0, -1.1, -1.0, -1.0]

    print(f"\nConfiguration:")
    print(f"  KL coefficient (β): {args.kl_coef}")
    print(f"  Reward model score: {args.rm_score}")
    print(f"  Response length: {len(actor_log_probs)} tokens")

    # Compute rewards
    rewards, stats = compute_rewards(
        args.rm_score, actor_log_probs, ref_log_probs, args.kl_coef
    )

    # Compute KL for visualization
    kl_per_token = compute_kl_divergence(actor_log_probs, ref_log_probs)

    # Visualize
    visualize_rewards(rewards, kl_per_token, args.kl_coef, args.rm_score)

    # Show stats
    print("\n" + "=" * 70)
    print(" SUMMARY STATISTICS")
    print("=" * 70)
    print(f"""
Reward Model Score:    {stats['reward_model_score']:>+.4f}
Total KL Divergence:   {stats['total_kl']:>+.4f}
Total KL Penalty:      {stats['total_kl_penalty']:>+.4f}
Net Total Reward:      {stats['total_reward']:>+.4f}

Reward Composition:
  RM contribution:     {stats['reward_model_score']:>+.4f} (at last token)
  KL penalty:          {-stats['total_kl_penalty']:>+.4f} (distributed)
  ────────────────────────────
  Net reward:          {stats['total_reward']:>+.4f}
""")

    # Demonstrate KL effects
    demonstrate_kl_penalty_effect()
    demonstrate_kl_scenarios()

    # Key insights
    print("\n" + "=" * 70)
    print(" KEY INSIGHTS")
    print("=" * 70)
    print("""
1. REWARD = RM_SCORE - β × KL
   The KL penalty prevents the actor from diverging too far from
   the reference model, avoiding "reward hacking".

2. KL IS COMPUTED PER-TOKEN
   Each token's probability is compared to the reference.
   This gives fine-grained control over divergence.

3. RM SCORE IS TYPICALLY END-ONLY
   The reward model scores the complete response.
   This score appears only at the last token.
   GAE propagates it backwards during training.

4. β IS A CRITICAL HYPERPARAMETER
   Too low: Reward hacking, degenerate solutions
   Too high: Learning is too slow, policy doesn't change
   Typical values: 0.01 - 0.05

5. NEGATIVE REWARDS ARE OK
   The policy gradient cares about relative advantages,
   not absolute reward values.
""")


if __name__ == "__main__":
    main()