Beam Search Algorithm

The beam search algorithm is the heart of reverse attention tracing. It efficiently explores multiple paths through the attention matrix without exhaustively checking every possibility.

How It Works

The Basic Idea

  1. Start at the target position (e.g., the last token)
  2. Find the top-k positions that this token attends to most
  3. Keep multiple candidates (beams) alive, not just the best one
  4. Repeat for each candidate until reaching the start of the sequence
  5. Rank all completed paths by their cumulative score

Step-by-Step Example

Let's trace from position 5 with top_k=3 and top_beam=2:

Initial state:

Active beams: [pos 5]
Terminal beams: []

Step 1: Position 5 attends most to positions 4, 2, 1 (with attention 0.4, 0.3, 0.2)

Candidates:
  - [5 → 4] score: log(0.4) = -0.92
  - [5 → 2] score: log(0.3) = -1.20
  - [5 → 1] score: log(0.2) = -1.61

Keep top 2:
  - [5 → 4]
  - [5 → 2]

Step 2: Expand each beam

From position 4: attends to 3, 1, 0 (attention: 0.5, 0.3, 0.1)
From position 2: attends to 1, 0 (attention: 0.6, 0.3)

Candidates:
  - [5 → 4 → 3] score: -0.92 + log(0.5) = -1.61
  - [5 → 4 → 1] score: -0.92 + log(0.3) = -2.12
  - [5 → 4 → 0] score: -0.92 + log(0.1) = -3.22  ← TERMINAL
  - [5 → 2 → 1] score: -1.20 + log(0.6) = -1.71
  - [5 → 2 → 0] score: -1.20 + log(0.3) = -2.40  ← TERMINAL

Keep top 2 active:
  - [5 → 4 → 3]
  - [5 → 2 → 1]

Terminal:
  - [5 → 4 → 0]
  - [5 → 2 → 0]

Continue until all beams terminate (reach position 0 or BOS).
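The scoring in the example above can be reproduced in a few lines of Python (the attention weights are the toy values from the example, not real model output):

```python
import math

# Toy weights from Step 1: position 5 attends to 4, 2, 1.
step1 = {4: 0.4, 2: 0.3, 1: 0.2}

# Score each one-hop candidate as log(attention weight).
candidates = {pos: math.log(w) for pos, w in step1.items()}

# With top_beam=2, keep the two highest-scoring beams.
kept = sorted(candidates, key=candidates.get, reverse=True)[:2]
print(kept)                     # [4, 2]
print(round(candidates[4], 2))  # -0.92

# Step 2 extends a kept beam by adding the next hop's log weight,
# e.g. [5 -> 2 -> 1]: -1.20 + log(0.6) = -1.71
print(round(candidates[2] + math.log(0.6), 2))  # -1.71
```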

Why Log Probabilities?

Multiplying small attention weights leads to numerical underflow:

0.1 × 0.1 × 0.1 × 0.1 = 0.0001  # Works
0.1^50 ≈ 1e-50  # Underflows to 0 in float32!

Using log probabilities turns multiplication into addition:

log(0.1) + log(0.1) + log(0.1) + log(0.1) = -4  # Stable

This keeps scores in a manageable range regardless of path length.
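The underflow claim is easy to check, assuming float32 weights (the precision attention matrices are typically stored in):

```python
import numpy as np

weights = np.full(50, 0.1, dtype=np.float32)

# Direct multiplication: 0.1^50 = 1e-50 is below float32's smallest
# subnormal (~1.4e-45), so the running product collapses to exactly 0.
product = np.prod(weights)
print(product)  # 0.0

# Summing logs stays in a comfortable range: 50 * log(0.1) ≈ -115.13.
log_score = np.sum(np.log(weights.astype(np.float64)))
print(round(float(log_score), 2))  # -115.13
```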

Length Normalization

Longer paths accumulate more (negative) log scores. Without normalization, the algorithm would unfairly prefer shorter paths.

Normalization Methods

Method       Formula               Effect
none         score                 Raw cumulative log prob
avg_logprob  score / length        Geometric mean (default)
sqrt         score / sqrt(length)  Moderate normalization
pow:α        score / length^α      Tunable (e.g., pow:0.7)

When to Use Each

  • avg_logprob (default): Fair comparison across path lengths. Best for most use cases.
  • none: When you want to see absolute influence, favoring shorter paths.
  • sqrt: Slight preference for shorter paths while still considering longer ones.
  • pow:0.7: Fine-tune the balance.
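The normalization methods above can be sketched as a small dispatch. This is illustrative, not the tool's actual API; it assumes score is a cumulative log probability and length counts hops:

```python
import math

def normalize_score(score, length, method="avg_logprob"):
    """Normalize a cumulative log-prob score over a path of `length` hops."""
    if method == "none":
        return score
    if method == "avg_logprob":
        return score / length              # geometric mean in probability space
    if method == "sqrt":
        return score / math.sqrt(length)
    if method.startswith("pow:"):
        alpha = float(method.split(":", 1)[1])
        return score / length ** alpha
    raise ValueError(f"unknown method: {method}")

# A 1-hop path with weight 0.4 vs. a 3-hop path with weights 0.5 each:
short = normalize_score(math.log(0.4), 1)       # ≈ -0.92
long = normalize_score(3 * math.log(0.5), 3)    # ≈ -0.69
```

Under avg_logprob the longer path wins (-0.69 > -0.92); under none, its raw score of about -2.08 would lose to the shorter path's -0.92.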

Pruning Strategy

At each step, we keep only top_beam candidates. This is crucial for efficiency:

  • Without pruning: Exponential explosion (k^d candidates after d steps)
  • With pruning: Linear in sequence length (top_beam × top_k per step)

Pruning is based on normalized score, so we compare paths fairly regardless of their current length.
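The gap is easy to see in candidate counts; a sketch, assuming one expansion step per sequence position:

```python
top_k, top_beam, depth = 5, 5, 20

# Without pruning, every beam fans out by top_k at each step.
unpruned = top_k ** depth
# With pruning, each step scores at most top_beam * top_k candidates
# and keeps top_beam, so total work grows linearly with depth.
pruned_total = depth * top_beam * top_k

print(unpruned)      # 95367431640625
print(pruned_total)  # 500
```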

Termination Conditions

A beam terminates when:

  1. Reaches position 0: Can't go further back
  2. Reaches BOS token: Beginning of sequence marker (if stop_at_bos=True)
  3. No valid predecessors: All attention weights below min_attn threshold
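The three conditions can be combined into a single predicate; a sketch, assuming the caller supplies the BOS position and the min_attn threshold (the function name and signature here are illustrative):

```python
def should_terminate(pos, attn_row, bos_pos=None, min_attn=0.0, stop_at_bos=True):
    """Return True if a beam at position `pos` cannot be extended further.

    attn_row: attention weights from `pos` to every position in the sequence.
    """
    if pos == 0:                                   # 1. reached the start
        return True
    if stop_at_bos and bos_pos is not None and pos == bos_pos:
        return True                                # 2. hit the BOS token
    if all(w < min_attn for w in attn_row[:pos]):
        return True                                # 3. no valid predecessor
    return False

print(should_terminate(0, []))                          # True
print(should_terminate(3, [0.5, 0.3, 0.2, 0.0]))        # False
```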

The Algorithm in Code

import math

def beam_search_backward(attn, target_pos, top_beam, top_k):
    # Each beam is (path, cumulative log score); start at the target position.
    active_beams = [([target_pos], 0.0)]
    terminal_beams = []

    while active_beams:
        candidates = []

        for path, score in active_beams:
            pos = path[-1]
            # Find top-k predecessors: earlier positions, ranked by attention
            row = attn[pos][:pos]
            preds = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:top_k]

            for pred_pos in preds:
                weight = row[pred_pos]
                if weight <= 0.0:
                    continue  # no attention mass: not a valid predecessor
                new_beam = (path + [pred_pos], score + math.log(weight))

                if pred_pos == 0:  # reached the start of the sequence
                    terminal_beams.append(new_beam)
                else:
                    candidates.append(new_beam)

        # Keep only top_beam candidates, compared by length-normalized score
        candidates.sort(key=lambda b: b[1] / (len(b[0]) - 1), reverse=True)
        active_beams = candidates[:top_beam]

    # Rank completed paths by normalized score, best first
    return sorted(terminal_beams, key=lambda b: b[1] / (len(b[0]) - 1), reverse=True)

Parameter Effects

top_k (predecessors per step)

Higher values explore more alternatives at each position:

  • top_k=1: Greedy (always follow max attention)
  • top_k=3: Modest exploration
  • top_k=5: Good balance (default)
  • top_k=10: Thorough exploration, slower

top_beam (beams to keep)

Higher values maintain more diversity:

  • top_beam=1: Single best path
  • top_beam=5: Multiple perspectives (default)
  • top_beam=10: Very diverse paths

Trade-offs

Config             Diversity  Speed   Memory
Low k, low beam    Low        Fast    Low
High k, low beam   Medium     Medium  Medium
Low k, high beam   Medium     Medium  Medium
High k, high beam  High       Slow    High

Next Steps