Beam Search Algorithm¶
The beam search algorithm is the heart of reverse attention tracing. It efficiently explores multiple paths through the attention matrix without exhaustively checking every possibility.
How It Works¶
The Basic Idea¶
- Start at the target position (e.g., the last token)
- Find the top-k positions that this token attends to most
- Keep multiple candidates (beams) alive, not just the best one
- Repeat for each candidate until reaching the start of the sequence
- Rank all completed paths by their cumulative score
Step-by-Step Example¶
Let's trace from position 5 with top_k=3 and top_beam=2:
Initial state: a single beam at position 5 with cumulative score 0 (the log probability of an empty path).
Step 1: Position 5 attends most to positions 4, 2, 1 (with attention 0.4, 0.3, 0.2)
Candidates:
- [5 → 4] score: log(0.4) = -0.92
- [5 → 2] score: log(0.3) = -1.20
- [5 → 1] score: log(0.2) = -1.61
Keep top 2:
- [5 → 4]
- [5 → 2]
Step 2: Expand each beam
From position 4: attends to 3, 1, 0 (attention: 0.5, 0.3, 0.1)
From position 2: attends to 1, 0 (attention: 0.6, 0.3)
Candidates:
- [5 → 4 → 3] score: -0.92 + log(0.5) = -1.61
- [5 → 4 → 1] score: -0.92 + log(0.3) = -2.12
- [5 → 4 → 0] score: -0.92 + log(0.1) = -3.22 ← TERMINAL
- [5 → 2 → 1] score: -1.20 + log(0.6) = -1.71
- [5 → 2 → 0] score: -1.20 + log(0.3) = -2.41 ← TERMINAL
Keep top 2 active:
- [5 → 4 → 3]
- [5 → 2 → 1]
Terminal:
- [5 → 4 → 0]
- [5 → 2 → 0]
Continue until all beams terminate (reach position 0 or BOS).
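The arithmetic in the example above can be checked directly (positions and weights are taken from the example):

```python
import math

# Step 1: position 5 attends to 4, 2, 1 with weights 0.4, 0.3, 0.2
step1 = {pos: math.log(w) for pos, w in [(4, 0.4), (2, 0.3), (1, 0.2)]}

# Step 2: extend the two surviving beams [5 -> 4] and [5 -> 2]
candidates = {
    (5, 4, 3): step1[4] + math.log(0.5),  # approx -1.61
    (5, 4, 1): step1[4] + math.log(0.3),  # approx -2.12
    (5, 4, 0): step1[4] + math.log(0.1),  # approx -3.22, terminal
    (5, 2, 1): step1[2] + math.log(0.6),  # approx -1.71
    (5, 2, 0): step1[2] + math.log(0.3),  # approx -2.41, terminal
}

# The two best non-terminal candidates survive the pruning step
best = sorted((p for p in candidates if p[-1] != 0),
              key=candidates.get, reverse=True)[:2]
# best == [(5, 4, 3), (5, 2, 1)]
```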
Why Log Probabilities?¶
Multiplying many small attention weights together quickly underflows to zero in floating point. Using log probabilities turns multiplication into addition, which keeps scores in a manageable range regardless of path length.
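A minimal illustration of the underflow problem and the log-space fix:

```python
import math

weights = [0.1] * 400  # a long path of small attention weights

# Direct multiplication underflows to exactly 0.0 in float64
prob = 1.0
for w in weights:
    prob *= w

# Summing log probabilities stays in a manageable range
log_prob = sum(math.log(w) for w in weights)  # about -921.0
```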
Length Normalization¶
Longer paths accumulate more (negative) log scores. Without normalization, the algorithm would unfairly prefer shorter paths.
Normalization Methods¶
| Method | Formula | Effect |
|---|---|---|
| `none` | `score` | Raw cumulative log prob |
| `avg_logprob` | `score / length` | Geometric mean (default) |
| `sqrt` | `score / sqrt(length)` | Moderate normalization |
| `pow:α` | `score / length^α` | Tunable (e.g., `pow:0.7`) |
When to Use Each¶
- `avg_logprob` (default): Fair comparison across path lengths. Best for most use cases.
- `none`: When you want to see absolute influence, favoring shorter paths.
- `sqrt`: Slight preference for shorter paths while still considering longer ones.
- `pow:0.7`: Fine-tune the balance between the two.
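The four methods in the table reduce to one small function (the name `normalize_score` is illustrative, not the library's actual API):

```python
import math

def normalize_score(score, length, method="avg_logprob"):
    # score: cumulative log probability; length: number of hops in the path
    if method == "none":
        return score
    if method == "avg_logprob":
        return score / length
    if method == "sqrt":
        return score / math.sqrt(length)
    if method.startswith("pow:"):
        alpha = float(method.split(":", 1)[1])
        return score / length ** alpha
    raise ValueError(f"unknown normalization method: {method}")
```

For example, a 2-hop path with cumulative score -2.4 gets an `avg_logprob` score of -1.2.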
Pruning Strategy¶
At each step, we keep only top_beam candidates. This is crucial for efficiency:
- Without pruning: exponential explosion (top_k^d candidates after d steps)
- With pruning: linear in sequence length (at most top_beam × top_k candidates per step)
Pruning is based on normalized score, so we compare paths fairly regardless of their current length.
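The difference is easy to quantify with the default-ish settings from this page; with `top_k=5` a 10-step trace would generate almost ten million candidates unpruned, versus 25 per step with pruning:

```python
top_k, top_beam, depth = 5, 5, 10

unpruned = top_k ** depth        # top_k^d candidates after d steps
per_step = top_beam * top_k      # candidates generated per step with pruning
total_pruned = per_step * depth  # grows linearly with depth
```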
Termination Conditions¶
A beam terminates when:
- Reaches position 0: can't go further back
- Reaches the BOS token: beginning-of-sequence marker (if `stop_at_bos=True`)
- Has no valid predecessors: all attention weights below the `min_attn` threshold
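The three conditions can be sketched as a single predicate (the helper name and signature are assumptions for illustration; only `stop_at_bos` and `min_attn` come from the text above):

```python
def should_terminate(pos, token_id, best_attn, bos_token_id=None,
                     stop_at_bos=True, min_attn=1e-4):
    # 1. Reached position 0: can't go further back
    if pos == 0:
        return True
    # 2. Reached the BOS token (if stop_at_bos=True)
    if stop_at_bos and bos_token_id is not None and token_id == bos_token_id:
        return True
    # 3. No valid predecessors: even the strongest attention is below min_attn
    if best_attn < min_attn:
        return True
    return False
```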
The Algorithm in Code¶
A runnable sketch of the algorithm (a beam is a `(path, score)` pair; BOS handling is omitted for brevity):

```python
import math
import numpy as np

def beam_search_backward(attn, target_pos, top_beam, top_k, min_attn=1e-6):
    # Each beam is (path, cumulative log prob); start at the target position
    active_beams = [([target_pos], 0.0)]
    terminal_beams = []
    while active_beams:
        candidates = []
        for path, score in active_beams:
            pos = path[-1]
            # Find top-k predecessors among earlier positions
            weights = attn[pos, :pos]
            predecessors = np.argsort(weights)[::-1][:top_k]
            extended = False
            for pred_pos in predecessors:
                attn_weight = weights[pred_pos]
                if attn_weight < min_attn:
                    continue
                extended = True
                new_beam = (path + [int(pred_pos)],
                            score + math.log(attn_weight))
                if pred_pos == 0:  # reached the start of the sequence
                    terminal_beams.append(new_beam)
                else:
                    candidates.append(new_beam)
            if not extended:  # no valid predecessors
                terminal_beams.append((path, score))
        # Keep only the top_beam candidates, compared by normalized score
        candidates.sort(key=lambda b: b[1] / (len(b[0]) - 1), reverse=True)
        active_beams = candidates[:top_beam]
    # Rank completed paths by length-normalized score (avg_logprob)
    terminal_beams.sort(key=lambda b: b[1] / max(len(b[0]) - 1, 1),
                        reverse=True)
    return terminal_beams
```
Parameter Effects¶
top_k (predecessors per step)¶
Higher values explore more alternatives at each position:
- `top_k=1`: Greedy (always follow the maximum attention)
- `top_k=3`: Modest exploration
- `top_k=5`: Good balance (default)
- `top_k=10`: Thorough exploration, slower
top_beam (beams to keep)¶
Higher values maintain more diversity:
- `top_beam=1`: Single best path
- `top_beam=5`: Multiple perspectives (default)
- `top_beam=10`: Very diverse paths
Trade-offs¶
| Config | Diversity | Speed | Memory |
|---|---|---|---|
| Low k, low beam | Low | Fast | Low |
| High k, low beam | Medium | Medium | Medium |
| Low k, high beam | Medium | Medium | Medium |
| High k, high beam | High | Slow | High |
Next Steps¶
- Parameter Tuning - Optimize for your use case
- Visualization Guide - Interpret the results
- API Reference - Full parameter documentation