Basic Tracing Example¶
A complete, annotated walkthrough of using reverse attention tracing.
The Complete Script¶
#!/usr/bin/env python3
"""
Basic example of reverse attention tracing.
This script demonstrates:
1. Loading a model and tokenizer
2. Creating a tracer
3. Running a trace
4. Interpreting the results
5. Generating a visualization
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from reverse_attention import ReverseAttentionTracer
def main():
    # =========================================================
    # Step 1: Load Model and Tokenizer
    # =========================================================
    print("Loading model...")

    # We use Qwen2-0.5B as it's small enough to run on most machines
    # but still shows interesting attention patterns
    model_name = "Qwen/Qwen2-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,  # Required for some models
    )

    # Determine the best device
    if torch.cuda.is_available():
        device = "cuda"
        dtype = torch.float16  # Use half precision on GPU
    elif torch.backends.mps.is_available():
        device = "mps"
        dtype = torch.float16
    else:
        device = "cpu"
        dtype = torch.float32  # Full precision on CPU
    print(f"Using device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",  # Required for attention output
    )

    # =========================================================
    # Step 2: Create the Tracer
    # =========================================================
    tracer = ReverseAttentionTracer(model, tokenizer)

    # =========================================================
    # Step 3: Define Input and Run Trace
    # =========================================================
    # Our test sentence - a classic pangram
    text = "The quick brown fox jumps over the lazy dog."
    print(f"\nInput: {text}")
    print(f"Tokens: {tokenizer.tokenize(text)}")

    # Run the trace:
    # - target_pos=-1: start from the last token
    # - layer=-1: use the last transformer layer
    # - top_beam=5: keep the top 5 paths
    # - top_k=5: consider the top 5 predecessors at each step
    result = tracer.trace_text(
        text,
        target_pos=-1,
        layer=-1,
        top_beam=5,
        top_k=5,
    )

    # =========================================================
    # Step 4: Explore the Results
    # =========================================================
    print("\n" + "=" * 60)
    print("TRACE RESULTS")
    print("=" * 60)

    # Basic metadata
    print(f"\nSequence length: {result.seq_len}")
    print(f"Target position: {result.target_pos} ({result.tokens[result.target_pos]})")
    print(f"Layer analyzed: {result.layer}")
    print(f"Number of beams found: {len(result.beams)}")

    # Human-readable paths
    print("\n--- Top Attention Paths ---")
    for i, path_text in enumerate(result.paths_text):
        print(f"\nBeam {i + 1}:")
        print(f"  {path_text}")

    # Detailed beam analysis
    print("\n--- Detailed Beam Analysis ---")
    top_beam = result.beams[0]
    print("\nTop beam details:")
    print(f"  Positions: {top_beam.positions}")
    print(f"  Tokens: {top_beam.tokens}")
    print(f"  Edge attentions: {[f'{a:.3f}' for a in top_beam.edge_attns]}")
    print(f"  Raw score: {top_beam.score_raw:.4f}")
    print(f"  Normalized score: {top_beam.score_norm:.4f}")

    # =========================================================
    # Step 5: Generate Visualization
    # =========================================================
    output_dir = "output"
    print(f"\nGenerating visualization in: {output_dir}/")
    html_path = tracer.render_html(
        result,
        out_dir=output_dir,
        open_browser=False,  # Set to True to auto-open
    )
    print(f"Visualization saved to: {html_path}")
    print("\nOpen this file in a browser to see the interactive Sankey diagram.")

    # =========================================================
    # Bonus: Analyze Sankey Data
    # =========================================================
    print("\n--- Sankey Diagram Data ---")
    print(f"Nodes: {len(result.sankey.nodes)}")
    print(f"Links: {len(result.sankey.links)}")

    # Find the strongest link
    strongest_link = max(result.sankey.links, key=lambda l: l.value)
    print("\nStrongest link:")
    print(f"  {strongest_link.source} → {strongest_link.target}")
    print(f"  Weight: {strongest_link.value:.4f}")
    print(f"  Used by beams: {strongest_link.beam_indices}")


if __name__ == "__main__":
    main()
Running the Example¶
Save the script as basic_example.py and run it with python basic_example.py.
Expected Output¶
Loading model...
Using device: cuda
Input: The quick brown fox jumps over the lazy dog.
Tokens: ['The', ' quick', ' brown', ' fox', ' jumps', ' over', ' the', ' lazy', ' dog', '.']
============================================================
TRACE RESULTS
============================================================
Sequence length: 10
Target position: 9 (.)
Layer analyzed: 23
Number of beams found: 5
--- Top Attention Paths ---
Beam 1:
. ← dog ← lazy ← the ← over
Beam 2:
. ← dog ← the ← lazy ← over
Beam 3:
. ← dog ← lazy ← over ← jumps
...
--- Detailed Beam Analysis ---
Top beam details:
Positions: [9, 8, 7, 6, 5]
Tokens: ['.', 'dog', 'lazy', 'the', 'over']
Edge attentions: ['0.412', '0.289', '0.341', '0.256']
Raw score: -3.2451
Normalized score: -0.8113
Generating visualization in: output/
Visualization saved to: output/index.html
Open this file in a browser to see the interactive Sankey diagram.
Understanding the Results¶
The Path Format¶
A path like `. ← dog ← lazy ← the ← over` (Beam 1 above) reads right-to-left as the attention flow:
- We start at `.` (the target token)
- `.` most strongly attends to `dog`
- `dog` attends to `lazy`
- `lazy` attends to `the`
- `the` attends to `over`
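For illustration, the arrow format can be reproduced with a plain string join. This is a sketch, and `format_path` is a hypothetical helper, not part of the library:

```python
def format_path(tokens):
    # Join the beam's tokens (target first) with left arrows,
    # matching the ". ← dog ← lazy" style shown in the output above.
    return " ← ".join(t.strip() for t in tokens)


# Token strings from the top beam in the sample output
beam_tokens = ['.', ' dog', ' lazy', ' the', ' over']
print(format_path(beam_tokens))  # . ← dog ← lazy ← the ← over
```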
The Scores¶
- Raw score: Sum of log(attention) along the path. More negative = lower probability.
- Normalized score: Raw score divided by path length. Allows fair comparison across paths.
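The arithmetic behind the two scores can be sketched directly from the sample's edge attentions. Treat the result as illustrative of the stated formula only, not as a reproduction of the sample's exact score values:

```python
import math

# Edge attentions from the top beam in the sample output
edge_attns = [0.412, 0.289, 0.341, 0.256]

# Raw score: sum of log(attention) along the path (always <= 0)
score_raw = sum(math.log(a) for a in edge_attns)

# Normalized score: raw score divided by the number of edges,
# so paths of different lengths can be compared fairly
score_norm = score_raw / len(edge_attns)

print(f"raw: {score_raw:.4f}, normalized: {score_norm:.4f}")
```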
Multiple Beams¶
Having multiple beams shows alternative attention paths:
- Beam 1: The most probable path
- Beams 2-5: Alternative paths that also carry high attention
If beams are very similar, the attention is focused. If they differ significantly, attention is distributed.
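To make "similar vs. different" concrete, one simple measure (a sketch, not part of the library) is the Jaccard overlap of the positions each beam visits:

```python
def beam_overlap(positions_a, positions_b):
    # Jaccard similarity of the token positions two beams visit:
    # 1.0 means identical coverage; values near 0 mean disjoint paths.
    a, b = set(positions_a), set(positions_b)
    return len(a & b) / len(a | b)


# Positions from the sample output: Beam 1 vs. Beam 3
print(beam_overlap([9, 8, 7, 6, 5], [9, 8, 7, 5, 4]))  # 4 shared / 6 total
```

High overlap across all beam pairs suggests focused attention; low overlap suggests it is distributed.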
Variations to Try¶
Different Target Positions¶
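The code block for this subsection is not shown. The key point is that `target_pos` accepts any token index, with negative values counting from the end, an assumption consistent with `target_pos=-1` resolving to position 9 in the sample output:

```python
tokens = ['The', ' quick', ' brown', ' fox', ' jumps', ' over',
          ' the', ' lazy', ' dog', '.']

# Resolve a Python-style negative index into an absolute position
target_pos = -1
resolved = target_pos % len(tokens)
print(resolved, tokens[resolved])  # 9 .
```

In the script you would then pass, for example, `tracer.trace_text(text, target_pos=4)` to trace from ` jumps` instead of the final `.`.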
Different Layers¶
# Early layer (more local patterns)
result = tracer.trace_text(text, layer=0)
# Middle layer
result = tracer.trace_text(text, layer=10)
More Exploration¶
To surface more alternative paths, raise top_beam and top_k, for example tracer.trace_text(text, top_beam=10, top_k=10).
Filter Weak Attention¶
Next Steps¶
- Parameter Tuning - Fine-tune for your use case
- Visualization Guide - Interpret the diagrams
- Advanced Usage - More complex patterns