rlhf_loop_pseudo.py
The complete RLHF training loop in pseudocode
This script walks through the computation flow of a single RLHF training step, so you can see what happens at each stage and in what order.
What It Does
- Simulates all four models (Actor, Critic, Reward, Reference; see the setup sketch after this list)
- Walks through each step of the training loop
- Shows tensor shapes and intermediate values
- Demonstrates the complete PPO update
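For orientation, here is a minimal PyTorch sketch of how the four models are typically wired up, using tiny torch.nn.Linear stand-ins instead of real transformer networks (the variable names mirror the script; everything else is illustrative). The actor and critic receive gradient updates, while the reward and reference models stay frozen.

import copy
import torch

actor = torch.nn.Linear(16, 16)        # policy being trained (stand-in network)
critic = torch.nn.Linear(16, 1)        # value function, also trained
reference = copy.deepcopy(actor)       # frozen snapshot of the starting policy
reward_model = torch.nn.Linear(16, 1)  # frozen, trained separately on preferences

for frozen in (reference, reward_model):
    frozen.requires_grad_(False)       # no gradients for the frozen models
    frozen.eval()

# Only the actor and critic parameters are optimized
optimizer = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=1e-5
)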
Run It
python tutorial/part4-rlhf/chapter13-rlhf-flow/scripts/rlhf_loop_pseudo.py
Example Output
=== RLHF Training Loop Demo ===
Step 1: Sample prompts
Batch size: 4
Prompt shapes: (4, 64) tokens
Step 2: Generate responses (Actor)
Actor forward pass...
Generated tokens: (4, 128)
Actor logits: (4, 128, 50257)
Old log probs: (4, 128)
Step 3: Score responses (Reward Model)
Reward model forward pass...
Scores: [0.73, 0.45, 0.91, 0.62]
Step 4: Compute KL penalty (Reference)
Reference forward pass...
Reference log probs: (4, 128)
KL divergence per token: (4, 128)
Mean KL: 0.23
Step 5: Compute total rewards
reward = reward_model_score - β * KL
Total rewards: [0.50, 0.28, 0.75, 0.41]
Step 6: Compute advantages (Critic + GAE)
Critic forward pass...
Values: (4, 128)
GAE advantages: (4, 128)
Step 7: PPO update
Ratio = exp(new_log_prob - old_log_prob)
Clipped ratio: clip(ratio, 0.8, 1.2)
Actor loss: -0.042
Critic loss: 0.156
Update complete!
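In the output above, Step 5 reports one total reward per response. Inside the loop (and in the source below), the reward is actually shaped per token: every token is charged the KL penalty, and the reward model score is added only on the final token. A minimal runnable sketch of that shaping, with made-up numbers (shape_rewards is an illustrative helper, not part of the script):

def shape_rewards(log_probs, ref_log_probs, score, beta=0.02):
    """Per-token rewards: -beta * KL at every token, plus the score on the last one."""
    rewards = [-beta * (lp - ref_lp) for lp, ref_lp in zip(log_probs, ref_log_probs)]
    rewards[-1] += score
    return rewards

# Illustrative values, not taken from the run above
print(shape_rewards([-1.0, -0.5, -0.7], [-1.1, -0.9, -0.6], score=0.73))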
The Core Loop
for batch in dataloader:
    # 1. Generate
    responses, old_logprobs = actor.generate(batch.prompts)

    # 2. Score
    rewards = reward_model(batch.prompts, responses)

    # 3. KL penalty
    ref_logprobs = reference(batch.prompts, responses)
    kl = old_logprobs - ref_logprobs
    rewards = rewards - beta * kl

    # 4. Advantages
    values = critic(batch.prompts, responses)
    advantages = gae(rewards, values)
    returns = advantages + values          # regression targets for the critic

    # 5. PPO update
    new_logprobs = actor(batch.prompts, responses)
    new_values = critic(batch.prompts, responses)
    ratio = (new_logprobs - old_logprobs).exp()
    actor_loss = -torch.min(ratio * advantages,
                            ratio.clamp(0.8, 1.2) * advantages).mean()
    critic_loss = ((new_values - returns) ** 2).mean()

    # 6. Backprop
    optimizer.zero_grad()
    (actor_loss + critic_loss).backward()
    optimizer.step()
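The clipped ratio is what keeps the policy update conservative. Below is a small, self-contained PyTorch sketch of that actor loss on toy tensors (the function name and numbers are illustrative, not part of the script); running it shows how the min/clamp pair caps the incentive to move far from the old policy.

import torch

def ppo_actor_loss(new_logprobs, old_logprobs, advantages, clip_epsilon=0.2):
    """Clipped PPO surrogate loss for the actor (per-token, then averaged)."""
    ratio = (new_logprobs - old_logprobs).exp()
    unclipped = ratio * advantages
    clipped = ratio.clamp(1 - clip_epsilon, 1 + clip_epsilon) * advantages
    # Taking the element-wise min makes the objective pessimistic: the update
    # gets no extra credit for pushing the ratio beyond the clip range.
    return -torch.min(unclipped, clipped).mean()

# Toy example: one token with positive and one with negative advantage
old = torch.tensor([-1.2, -0.8])
new = torch.tensor([-0.9, -1.1])          # pretend the policy has moved
adv = torch.tensor([1.0, -0.5])
print(ppo_actor_loss(new, old, adv))      # scalar loss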
Source Code
#!/usr/bin/env python3
"""
RLHF Training Loop Pseudocode

This script demonstrates the complete RLHF training loop with
detailed comments explaining each step.

This is PSEUDOCODE: the model calls are simulated with lightweight placeholder
functions, so the script runs, but the numbers are random. It is meant to
illustrate the data flow and computations involved, not to train anything.

Usage:
    python rlhf_loop_pseudo.py
"""

from dataclasses import dataclass
from typing import List
import random
import math
@dataclass
class Prompt:
    """A training prompt."""
    text: str
    tokens: List[int]


@dataclass
class Response:
    """A generated response with metadata."""
    tokens: List[int]
    log_probs: List[float]      # From actor
    ref_log_probs: List[float]  # From reference
    values: List[float]         # From critic
    reward_score: float         # From reward model


@dataclass
class Experience:
    """One token of experience for PPO."""
    token: int
    log_prob: float
    ref_log_prob: float
    value: float
    reward: float
    advantage: float
def rlhf_training_step(
    prompts: List[Prompt],
    actor,          # The policy model being trained
    critic,         # The value function
    reward_model,   # Frozen reward model
    reference,      # Frozen reference policy
    kl_coef: float = 0.02,
    gamma: float = 1.0,
    lam: float = 0.95,
    clip_epsilon: float = 0.2,
) -> dict:
    """
    One step of RLHF training.

    This function shows the complete data flow through all four models.
    """
    print("=" * 70)
    print(" RLHF TRAINING STEP")
    print("=" * 70)
    # =========================================================================
    # STEP 1: Generate Responses (Actor)
    # =========================================================================
    print("\n[Step 1] GENERATE RESPONSES")
    print("-" * 50)

    responses = []
    for prompt in prompts:
        # Generate response from actor
        # In reality: autoregressive generation with temperature sampling
        response_tokens = generate_response(actor, prompt)

        # Get log probabilities from actor
        actor_log_probs = get_log_probs(actor, prompt.tokens, response_tokens)

        responses.append(Response(
            tokens=response_tokens,
            log_probs=actor_log_probs,
            ref_log_probs=[],  # Filled in step 3
            values=[],         # Filled in step 4
            reward_score=0,    # Filled in step 2
        ))
        print(f"  Generated {len(response_tokens)} tokens for prompt")
    # =========================================================================
    # STEP 2: Score Responses (Reward Model)
    # =========================================================================
    print("\n[Step 2] SCORE RESPONSES (Reward Model)")
    print("-" * 50)

    for i, (prompt, response) in enumerate(zip(prompts, responses)):
        # Get reward score for complete response
        # In reality: forward pass through reward model
        full_sequence = prompt.tokens + response.tokens
        response.reward_score = score_response(reward_model, full_sequence)
        print(f"  Response {i}: reward = {response.reward_score:.3f}")
    # =========================================================================
    # STEP 3: Compute KL Penalty (Reference Model)
    # =========================================================================
    print("\n[Step 3] COMPUTE KL PENALTY (Reference)")
    print("-" * 50)

    total_kl = 0
    for prompt, response in zip(prompts, responses):
        # Get reference log probabilities
        response.ref_log_probs = get_log_probs(
            reference, prompt.tokens, response.tokens
        )

        # Compute per-token KL divergence
        kl_per_token = []
        for actor_lp, ref_lp in zip(response.log_probs, response.ref_log_probs):
            # Per-token KL estimate: the log-prob ratio of the sampled token
            # under the actor vs. the reference policy
            kl = actor_lp - ref_lp
            kl_per_token.append(kl)

        avg_kl = sum(kl_per_token) / len(kl_per_token)
        total_kl += avg_kl

    avg_kl = total_kl / len(responses)
    print(f"  Average KL divergence: {avg_kl:.4f}")
    # =========================================================================
    # STEP 4: Compute Values (Critic)
    # =========================================================================
    print("\n[Step 4] COMPUTE VALUES (Critic)")
    print("-" * 50)

    for prompt, response in zip(prompts, responses):
        # Get value estimates for each token position
        # In reality: forward pass through critic
        response.values = get_values(critic, prompt.tokens, response.tokens)
        print(f"  Values computed: mean={sum(response.values)/len(response.values):.3f}")
    # =========================================================================
    # STEP 5: Compute Rewards with KL Penalty
    # =========================================================================
    print("\n[Step 5] COMPUTE REWARDS WITH KL PENALTY")
    print("-" * 50)

    all_experiences = []
    for prompt, response in zip(prompts, responses):
        experiences = []
        for t in range(len(response.tokens)):
            # Per-token KL penalty
            kl_penalty = kl_coef * (response.log_probs[t] - response.ref_log_probs[t])

            # Reward model score only at the last token; KL penalty at every token
            if t == len(response.tokens) - 1:
                token_reward = response.reward_score - kl_penalty
            else:
                token_reward = -kl_penalty  # Just KL penalty for non-final tokens

            experiences.append(Experience(
                token=response.tokens[t],
                log_prob=response.log_probs[t],
                ref_log_prob=response.ref_log_probs[t],
                value=response.values[t],
                reward=token_reward,
                advantage=0,  # Computed in step 6
            ))

        all_experiences.append(experiences)
        final_reward = experiences[-1].reward
        print(f"  Final token reward: {final_reward:.3f} "
              f"(score={response.reward_score:.3f}, kl_penalty included)")
    # =========================================================================
    # STEP 6: Compute GAE Advantages
    # =========================================================================
    print("\n[Step 6] COMPUTE GAE ADVANTAGES")
    print("-" * 50)

    for experiences in all_experiences:
        # GAE computation (backwards through the sequence)
        gae = 0
        for t in reversed(range(len(experiences))):
            exp = experiences[t]
            if t == len(experiences) - 1:
                next_value = 0  # Terminal state
            else:
                next_value = experiences[t + 1].value

            # TD error
            delta = exp.reward + gamma * next_value - exp.value
            # GAE
            gae = delta + gamma * lam * gae
            exp.advantage = gae

        # Normalize advantages
        advantages = [e.advantage for e in experiences]
        mean_adv = sum(advantages) / len(advantages)
        std_adv = (sum((a - mean_adv) ** 2 for a in advantages) / len(advantages)) ** 0.5
        for exp in experiences:
            exp.advantage = (exp.advantage - mean_adv) / (std_adv + 1e-8)

    print("  Advantages computed and normalized")
    # =========================================================================
    # STEP 7: PPO Update
    # =========================================================================
    print("\n[Step 7] PPO UPDATE")
    print("-" * 50)

    # Flatten all experiences
    flat_experiences = [exp for exps in all_experiences for exp in exps]

    # Compute PPO losses
    policy_losses = []
    value_losses = []
    clip_fractions = []

    for exp in flat_experiences:
        # New log probability (after potential update)
        # In reality: forward pass through updated actor
        new_log_prob = exp.log_prob  # Placeholder

        # Probability ratio
        ratio = math.exp(new_log_prob - exp.log_prob)

        # Clipped objective
        unclipped = ratio * exp.advantage
        clipped = max(min(ratio, 1 + clip_epsilon), 1 - clip_epsilon) * exp.advantage
        policy_loss = -min(unclipped, clipped)
        policy_losses.append(policy_loss)

        # Value loss
        # In reality: the critic's new prediction, regressed toward the return
        new_value = exp.value  # Placeholder
        value_loss = (new_value - (exp.reward + gamma * 0)) ** 2  # Simplified target
        value_losses.append(value_loss)

        # Track clipping
        if abs(ratio - 1) > clip_epsilon:
            clip_fractions.append(1)
        else:
            clip_fractions.append(0)

    avg_policy_loss = sum(policy_losses) / len(policy_losses)
    avg_value_loss = sum(value_losses) / len(value_losses)
    avg_clip_frac = sum(clip_fractions) / len(clip_fractions)

    print(f"  Policy loss: {avg_policy_loss:.4f}")
    print(f"  Value loss: {avg_value_loss:.4f}")
    print(f"  Clip fraction: {avg_clip_frac:.2%}")
    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print(" STEP SUMMARY")
    print("=" * 70)
    print(f"""
Models used:
- Actor: Generated {sum(len(r.tokens) for r in responses)} total tokens
- Reward: Scored {len(responses)} responses
- Reference: Computed KL for {len(responses)} responses
- Critic: Estimated values for {sum(len(r.tokens) for r in responses)} tokens

Losses:
- Policy loss: {avg_policy_loss:.4f}
- Value loss: {avg_value_loss:.4f}

KL penalty:
- Average KL: {avg_kl:.4f}
- KL coefficient: {kl_coef}
- Total KL penalty: {avg_kl * kl_coef:.4f}
""")

    return {
        'policy_loss': avg_policy_loss,
        'value_loss': avg_value_loss,
        'kl': avg_kl,
        'clip_fraction': avg_clip_frac,
    }
# =============================================================================
# Placeholder functions (would be real model calls in practice)
# =============================================================================

def generate_response(actor, prompt: Prompt) -> List[int]:
    """Generate response tokens from actor."""
    # Simulated: random tokens
    length = random.randint(10, 30)
    return [random.randint(0, 999) for _ in range(length)]


def get_log_probs(model, prompt_tokens: List[int],
                  response_tokens: List[int]) -> List[float]:
    """Get log probabilities from model."""
    # Simulated: random log probs
    return [random.uniform(-3, -0.1) for _ in response_tokens]


def score_response(reward_model, tokens: List[int]) -> float:
    """Get reward score from reward model."""
    # Simulated: random score
    return random.uniform(-1, 1)


def get_values(critic, prompt_tokens: List[int],
               response_tokens: List[int]) -> List[float]:
    """Get value estimates from critic."""
    # Simulated: decreasing values
    n = len(response_tokens)
    return [0.5 * (n - i) / n for i in range(n)]
def main():
    print("╔" + "═" * 68 + "╗")
    print("║" + " RLHF TRAINING LOOP DEMONSTRATION".center(68) + "║")
    print("╚" + "═" * 68 + "╝")

    # Create sample prompts
    prompts = [
        Prompt("What is the capital of France?", [1, 2, 3, 4, 5]),
        Prompt("Explain quantum computing.", [6, 7, 8, 9]),
        Prompt("Write a haiku about programming.", [10, 11, 12, 13, 14]),
        Prompt("What is machine learning?", [15, 16, 17]),
    ]
    print(f"\nBatch size: {len(prompts)} prompts")

    # Run one training step
    stats = rlhf_training_step(
        prompts=prompts,
        actor=None,  # Placeholder
        critic=None,
        reward_model=None,
        reference=None,
        kl_coef=0.02,
    )

    # Explain the process
    print("\n" + "=" * 70)
    print(" WHAT JUST HAPPENED")
    print("=" * 70)
    print("""
This simulated one complete RLHF training step:

1. GENERATION: Actor generated responses for each prompt
2. SCORING: Reward model evaluated response quality
3. KL COMPUTATION: Reference model computed divergence penalty
4. VALUE ESTIMATION: Critic predicted expected rewards
5. ADVANTAGE COMPUTATION: GAE combined rewards and values
6. PPO UPDATE: Actor and critic weights updated

In production, this happens with:
- Real neural network forward/backward passes
- GPU tensor operations
- Distributed training across multiple devices
- Gradient accumulation and synchronization

The key insight: RLHF is just PPO with:
- Reward from a learned reward model
- KL penalty to stay close to reference
- Four models instead of just actor-critic
""")


if __name__ == "__main__":
    main()
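In a real implementation the placeholder functions become forward passes through actual networks. As one hedged example (assuming PyTorch and a Hugging Face transformers causal LM; get_log_probs_real is an illustrative name, not part of the script), the log-probability helper could look roughly like the sketch below. During the PPO update you would run the same computation for the actor without torch.no_grad so gradients can flow.

import torch
from typing import List


@torch.no_grad()
def get_log_probs_real(model, prompt_tokens: List[int],
                       response_tokens: List[int]) -> List[float]:
    """Log-probabilities a causal LM assigns to each response token."""
    input_ids = torch.tensor([prompt_tokens + response_tokens])
    logits = model(input_ids).logits                 # (1, seq_len, vocab_size)
    log_probs = torch.log_softmax(logits, dim=-1)
    # The logits at position t predict the token at position t + 1,
    # so shift by one when selecting the response-token log probs.
    start = len(prompt_tokens)
    targets = input_ids[0, start:]
    token_log_probs = log_probs[0, start - 1:-1].gather(
        -1, targets.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.tolist()


# Usage sketch (model choice is just an example):
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   get_log_probs_real(model, prompt_tokens=[464, 3139], response_tokens=[286, 4881])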