Quick Start

Get up and running with reverse attention tracing in under 5 minutes.

Basic Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from reverse_attention import ReverseAttentionTracer

# 1. Load a model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    attn_implementation="eager",  # Required: SDPA/flash attention don't return attention weights
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# 2. Create a tracer
tracer = ReverseAttentionTracer(model, tokenizer)

# 3. Trace attention paths from the last token
result = tracer.trace_text("The quick brown fox jumps over the lazy dog.")

# 4. See what we found
for i, path in enumerate(result.paths_text):
    print(f"Beam {i+1}: {path}")

Example output:

Beam 1: dog. ← lazy ← the ← over ← jumps
Beam 2: dog. ← the ← lazy ← over ← jumps
Beam 3: dog. ← lazy ← over ← jumps ← fox
...
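Conceptually, each beam is found by searching backward through a layer's attention matrix: start at the target token, repeatedly follow the top-k strongest attention edges to earlier positions, and keep only the best-scoring paths. A minimal sketch of that idea on a toy attention matrix (the function, scoring rule, and matrix are illustrative, not the library's internals):

```python
import heapq

def trace_beams(attn, target, top_k=2, top_beam=3, steps=3):
    """Beam-search backward through a (seq x seq) attention matrix.

    attn[i][j] = attention paid by position i to earlier position j.
    Each beam is (score, [positions]); score is the product of edge weights.
    """
    beams = [(1.0, [target])]
    for _ in range(steps):
        candidates = []
        for score, path in beams:
            cur = path[-1]
            if cur == 0:  # reached the first token; carry the beam forward
                candidates.append((score, path))
                continue
            # Follow the top-k strongest edges to earlier positions
            preds = heapq.nlargest(top_k, range(cur), key=lambda j: attn[cur][j])
            for j in preds:
                candidates.append((score * attn[cur][j], path + [j]))
        beams = heapq.nlargest(top_beam, candidates, key=lambda b: b[0])
    return beams

# Toy 4x4 lower-triangular (causal) attention matrix
attn = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.3, 0.7, 0.0, 0.0],
    [0.1, 0.5, 0.4, 0.0],
]
for score, path in trace_beams(attn, target=3):
    print(path, round(score, 3))
```

With beam search rather than a greedy walk, a path that starts with a slightly weaker edge (here 3 → 2) can still survive if its later edges are strong.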

Interactive Visualization

Generate a Sankey diagram of the traced attention paths:

# Generate HTML visualization
tracer.render_html(result, "output/", open_browser=True)

This creates an interactive HTML file in which you can:

  • Zoom and pan with scroll and drag
  • Click nodes to highlight connected paths
  • Filter by beam using the dropdown
  • Inspect details in the info panel

Using the CLI

Run the demo script for quick experimentation:

python examples/demo_qwen2.py --text "Your text here" --open-browser

CLI Options

Option           Default           Description
--model          Qwen/Qwen2-0.5B   Model name or path
--text           (sample text)     Text to trace
--target-pos     -1                Target position (negative indexing supported)
--layer          -1                Layer index (negative indexing supported)
--top-beam       5                 Number of beams to keep
--top-k          5                 Top-k predecessors per step
--output         output            Output directory
--open-browser   false             Open visualization in browser
--device         auto              Device (cuda/mps/cpu/auto)
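The flag set above maps naturally onto argparse. The following is an illustrative sketch of how such a CLI could be declared, not the actual source of examples/demo_qwen2.py:

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Mirror of the CLI flags above (illustrative, not the real script)."""
    p = argparse.ArgumentParser(description="Trace reverse attention paths.")
    p.add_argument("--model", default="Qwen/Qwen2-0.5B", help="Model name or path")
    p.add_argument("--text", default=None, help="Text to trace (default: sample text)")
    p.add_argument("--target-pos", type=int, default=-1,
                   help="Target position (negative indexing supported)")
    p.add_argument("--layer", type=int, default=-1,
                   help="Layer index (negative indexing supported)")
    p.add_argument("--top-beam", type=int, default=5, help="Number of beams to keep")
    p.add_argument("--top-k", type=int, default=5, help="Top-k predecessors per step")
    p.add_argument("--output", default="output", help="Output directory")
    p.add_argument("--open-browser", action="store_true",
                   help="Open visualization in browser")
    p.add_argument("--device", default="auto",
                   choices=["cuda", "mps", "cpu", "auto"], help="Device")
    return p

# Parse an explicit argument list instead of sys.argv for demonstration
args = build_parser().parse_args(["--text", "Hello world", "--layer", "-2"])
print(args.text, args.layer, args.top_beam)
```

Note that argparse accepts negative values like `--layer -2` directly, since none of the option strings look like negative numbers.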

Understanding the Output

TraceResult

The trace() and trace_text() methods return a TraceResult object:

result = tracer.trace_text("Hello world")

print(f"Sequence length: {result.seq_len}")
print(f"Target position: {result.target_pos}")
print(f"Layer: {result.layer}")
print(f"Tokens: {result.tokens}")
print(f"Number of beams: {len(result.beams)}")
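The fields above suggest a shape roughly like the following dataclass. This is a hypothetical sketch of the structure, not the library's actual definition:

```python
from dataclasses import dataclass, field
from types import SimpleNamespace

@dataclass
class TraceResult:
    """Illustrative shape of a trace result -- not the library's definition."""
    seq_len: int          # number of tokens in the traced sequence
    target_pos: int       # position the trace started from
    layer: int            # layer whose attention was traced
    tokens: list[str]     # decoded tokens of the whole sequence
    beams: list = field(default_factory=list)   # one BeamPath per kept beam

    @property
    def paths_text(self) -> list[str]:
        # One "tok ← tok ← tok" string per beam, as printed in Basic Usage
        return [" ← ".join(b.tokens) for b in self.beams]

# A stand-in beam with just a .tokens attribute, for demonstration
beam = SimpleNamespace(tokens=["dog.", "lazy", "over", "jumps"])
r = TraceResult(seq_len=10, target_pos=9, layer=-1,
                tokens="The quick brown fox jumps over the lazy dog .".split(),
                beams=[beam])
print(r.paths_text[0])   # dog. ← lazy ← over ← jumps
```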

BeamPath

Each beam contains detailed path information:

beam = result.beams[0]

print(f"Positions: {beam.positions}")      # [9, 7, 5, 3]
print(f"Tokens: {beam.tokens}")            # ['dog.', 'lazy', 'over', 'jumps']
print(f"Edge attentions: {beam.edge_attns}")  # [0.45, 0.32, 0.28]
print(f"Normalized score: {beam.score_norm}")
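Note the off-by-one between fields: four positions give three edges, one per hop. The exact definition of score_norm is internal to the library, but a common way to length-normalize a path score is the geometric mean of its edge attentions, so short and long beams stay comparable. A hypothetical sketch of that computation, not the library's actual formula:

```python
import math

def normalized_score(edge_attns: list[float]) -> float:
    """Geometric mean of edge attentions: one illustrative way to
    length-normalize a beam score (assumed, not the library's formula)."""
    if not edge_attns:
        return 0.0
    log_sum = sum(math.log(a) for a in edge_attns)
    return math.exp(log_sum / len(edge_attns))

# Three edges for a four-position path, as in the example above
edge_attns = [0.45, 0.32, 0.28]
print(round(normalized_score(edge_attns), 4))
```

Working in log space avoids underflow when multiplying many small attention weights along a long path.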

Next Steps