Quick Start¶
Get up and running with reverse attention tracing in under 5 minutes.
Basic Usage¶
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from reverse_attention import ReverseAttentionTracer

# 1. Load a model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B",
    attn_implementation="eager",  # Required for attention output
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# 2. Create a tracer
tracer = ReverseAttentionTracer(model, tokenizer)

# 3. Trace attention paths from the last token
result = tracer.trace_text("The quick brown fox jumps over the lazy dog.")

# 4. See what we found
for i, path in enumerate(result.paths_text):
    print(f"Beam {i+1}: {path}")
```
Example output:
```
Beam 1: dog. ← lazy ← the ← over ← jumps
Beam 2: dog. ← the ← lazy ← over ← jumps
Beam 3: dog. ← lazy ← over ← jumps ← fox
...
```
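The arrows read right to left: each path starts at the target token and repeatedly steps back to an earlier token that receives strong attention. The core idea can be sketched with a toy, made-up attention matrix; this greedy single-path version is a simplification of the library's actual beam search:

```python
tokens = ["The", "quick", "fox", "jumps"]

# attn[i][j] = attention that token i pays to earlier token j (made-up values)
attn = [
    [1.0, 0.0, 0.0, 0.0],
    [0.6, 0.4, 0.0, 0.0],
    [0.2, 0.1, 0.7, 0.0],
    [0.1, 0.2, 0.5, 0.2],
]

def greedy_trace(attn, target):
    """Follow the strongest attention edge backwards from `target` to position 0."""
    path = [target]
    cur = target
    while cur > 0:
        # strongest predecessor strictly earlier in the sequence
        cur = max(range(cur), key=lambda j: attn[cur][j])
        path.append(cur)
    return path

path = greedy_trace(attn, len(tokens) - 1)
print(" ← ".join(tokens[p] for p in path))  # jumps ← fox ← The
```

Keeping the top-k predecessors at each step instead of only the argmax turns this into the beam search that produces multiple ranked paths.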
Interactive Visualization¶
Generate a beautiful Sankey diagram:
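A sketch of the call, assuming a `visualize` helper on the tracer; the actual method name and parameters may differ, so check the API Reference:

```python
# Hypothetical helper name — consult the API Reference for the real one
tracer.visualize(result, output_path="output/trace.html")
```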
This creates an interactive HTML file that you can:
- Zoom and pan with scroll and drag
- Click nodes to highlight connected paths
- Filter by beam using the dropdown
- Inspect details in the info panel
Using the CLI¶
Run the demo script for quick experimentation:
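An illustrative invocation using the options documented below; the entry-point name `demo.py` is an assumption and may differ in your checkout:

```shell
python demo.py \
    --model Qwen/Qwen2-0.5B \
    --text "The quick brown fox jumps over the lazy dog." \
    --top-beam 5 \
    --output output \
    --open-browser
```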
CLI Options¶
| Option | Default | Description |
|---|---|---|
| `--model` | `Qwen/Qwen2-0.5B` | Model name or path |
| `--text` | (sample text) | Text to trace |
| `--target-pos` | `-1` | Target position (negative indexing supported) |
| `--layer` | `-1` | Layer index (negative indexing supported) |
| `--top-beam` | `5` | Number of beams to keep |
| `--top-k` | `5` | Top-k predecessors per step |
| `--output` | `output` | Output directory |
| `--open-browser` | `false` | Open visualization in browser |
| `--device` | `auto` | Device (`cuda`/`mps`/`cpu`/`auto`) |
Understanding the Output¶
TraceResult¶
The `trace()` and `trace_text()` methods return a `TraceResult` object:
```python
result = tracer.trace_text("Hello world")

print(f"Sequence length: {result.seq_len}")
print(f"Target position: {result.target_pos}")
print(f"Layer: {result.layer}")
print(f"Tokens: {result.tokens}")
print(f"Number of beams: {len(result.beams)}")
```
BeamPath¶
Each beam contains detailed path information:
```python
beam = result.beams[0]

print(f"Positions: {beam.positions}")        # [9, 7, 5, 3]
print(f"Tokens: {beam.tokens}")              # ['dog.', 'lazy', 'over', 'jumps']
print(f"Edge attentions: {beam.edge_attns}") # [0.45, 0.32, 0.28]
print(f"Normalized score: {beam.score_norm}")
```
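Note that a path over n positions has n−1 edge attentions, one per backward step. The exact definition of `score_norm` isn't shown here; a common length-normalized choice, used purely for illustration, is the geometric mean of the edge attentions:

```python
import math

def geometric_mean(edge_attns):
    """Length-normalized path score: geometric mean of the edge attentions.

    An illustrative choice — the library's actual score_norm may be defined
    differently; see the API Reference.
    """
    return math.exp(sum(math.log(a) for a in edge_attns) / len(edge_attns))

score = geometric_mean([0.45, 0.32, 0.28])
print(round(score, 3))  # -> 0.343
```

Length normalization keeps longer paths comparable to shorter ones, since a raw product of attentions shrinks with every extra edge.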
Next Steps¶
- What is Reverse Attention? - Understand the theory
- Parameter Tuning - Optimize for your use case
- API Reference - Complete API documentation