Advanced Usage¶
This guide covers advanced patterns and use cases for reverse attention tracing.
Working with Raw Token IDs¶
For maximum control, use trace() instead of trace_text():
from reverse_attention import ReverseAttentionTracer
tracer = ReverseAttentionTracer(model, tokenizer)
# Tokenize manually (a single tokenizer call returns both tensors)
encoding = tokenizer("Your text", return_tensors="pt")
input_ids = encoding.input_ids
attention_mask = encoding.attention_mask
# Trace with full control
result = tracer.trace(
input_ids,
attention_mask=attention_mask,
target_pos=-1,
layer=-1,
)
Analyzing Multiple Layers¶
Compare attention patterns across layers:
results_by_layer = {}
for layer_idx in [-1, -5, -10, 0]:
result = tracer.trace_text(
"The quick brown fox jumps over the lazy dog.",
layer=layer_idx,
)
results_by_layer[layer_idx] = result
print(f"\nLayer {layer_idx}:")
for path in result.paths_text[:3]:
print(f" {path}")
Analyzing Multiple Positions¶
Trace from different target positions:
text = "The capital of France is Paris."
tokens = tokenizer.tokenize(text)
for pos in range(len(tokens)):
result = tracer.trace_text(text, target_pos=pos)
print(f"\nPosition {pos} ({tokens[pos]}):")
print(f" Top path: {result.paths_text[0]}")
Batch Analysis¶
Analyze multiple texts efficiently:
texts = [
"The quick brown fox jumps over the lazy dog.",
"Pack my box with five dozen liquor jugs.",
"How vexingly quick daft zebras jump!",
]
results = []
for text in texts:
result = tracer.trace_text(text)
results.append({
"text": text,
"top_path": result.paths_text[0],
"num_beams": len(result.beams),
})
Extracting Path Statistics¶
Analyze the discovered paths programmatically:
result = tracer.trace_text("Your text here")
for i, beam in enumerate(result.beams):
print(f"\nBeam {i + 1}:")
print(f" Path length: {len(beam.positions)}")
print(f" Positions: {beam.positions}")
print(f" Tokens: {beam.tokens}")
print(f" Edge attentions: {[f'{a:.3f}' for a in beam.edge_attns]}")
print(f" Raw score: {beam.score_raw:.4f}")
print(f" Normalized score: {beam.score_norm:.4f}")
Custom Visualization¶
Access Sankey data for custom visualizations:
result = tracer.trace_text("Your text here")
# Access raw Sankey data
sankey = result.sankey
# Nodes
for node in sankey.nodes:
print(f"Node: {node.id}, Token: {node.name}, Position: {node.position}")
# Links
for link in sankey.links:
print(f"Link: {link.source} -> {link.target}, Value: {link.value:.4f}")
Comparing Beam Search Parameters¶
Find optimal parameters for your use case:
import itertools
text = "The quick brown fox jumps over the lazy dog."
# Parameter grid
top_beams = [3, 5, 10]
top_ks = [3, 5, 10]
length_norms = ["none", "avg_logprob", "sqrt"]
for tb, tk, ln in itertools.product(top_beams, top_ks, length_norms):
result = tracer.trace_text(
text,
top_beam=tb,
top_k=tk,
length_norm=ln,
)
avg_path_length = sum(len(b.positions) for b in result.beams) / len(result.beams)
print(f"top_beam={tb}, top_k={tk}, length_norm={ln}")
print(f" Avg path length: {avg_path_length:.1f}")
print(f" Top path: {result.paths_text[0]}")
Head Aggregation Comparison¶
Compare mean vs max head aggregation:
text = "The quick brown fox jumps over the lazy dog."
for agg in ["mean", "max"]:
result = tracer.trace_text(text, agg_heads=agg)
print(f"\n{agg.upper()} aggregation:")
for path in result.paths_text[:3]:
print(f" {path}")
Exporting Results¶
To JSON¶
import json
result = tracer.trace_text("Your text")
export_data = {
"metadata": {
"seq_len": result.seq_len,
"target_pos": result.target_pos,
"layer": result.layer,
},
"tokens": result.tokens,
"paths_text": result.paths_text,
"beams": [
{
"positions": beam.positions,
"tokens": beam.tokens,
"edge_attns": beam.edge_attns,
"score_norm": beam.score_norm,
}
for beam in result.beams
],
}
with open("trace_result.json", "w") as f:
json.dump(export_data, f, indent=2)
To DataFrame¶
import pandas as pd
result = tracer.trace_text("Your text")
rows = []
for i, beam in enumerate(result.beams):
for j, (pos, token, attn) in enumerate(zip(
beam.positions[1:], # Skip target position
beam.tokens[1:],
beam.edge_attns,
)):
rows.append({
"beam_id": i,
"step": j,
"position": pos,
"token": token,
"attention": attn,
})
df = pd.DataFrame(rows)
print(df)
Finding Attention Patterns¶
Hub Tokens¶
Find tokens that appear in many paths:
from collections import Counter
result = tracer.trace_text("Your text here", top_beam=10)
position_counts = Counter()
for beam in result.beams:
for pos in beam.positions:
position_counts[pos] += 1
print("Most common positions:")
for pos, count in position_counts.most_common(5):
token = result.tokens[pos]
print(f" Position {pos} ({token}): {count} beams")
Long-Range Attention¶
Find paths with large position jumps:
result = tracer.trace_text("Your long text here...")
for i, beam in enumerate(result.beams):
max_jump = max(
beam.positions[j] - beam.positions[j+1]
for j in range(len(beam.positions) - 1)
)
if max_jump > 5:
print(f"Beam {i}: max jump = {max_jump}")
print(f" Path: {' -> '.join(beam.tokens)}")
Integration with Other Tools¶
With Weights & Biases¶
import wandb
wandb.init(project="attention-analysis")
result = tracer.trace_text("Your text")
wandb.log({
"num_beams": len(result.beams),
"avg_path_length": sum(len(b.positions) for b in result.beams) / len(result.beams),
"top_score": result.beams[0].score_norm,
})
# Log visualization
tracer.render_html(result, "output/")
wandb.log({"visualization": wandb.Html(open("output/index.html").read())})
With Matplotlib¶
import matplotlib.pyplot as plt
result = tracer.trace_text("Your text")
# Plot attention weights along top path
beam = result.beams[0]
plt.figure(figsize=(10, 4))
plt.bar(range(len(beam.edge_attns)), beam.edge_attns)
plt.xticks(range(len(beam.tokens) - 1), beam.tokens[:-1], rotation=45)
plt.ylabel("Attention")
plt.title("Attention weights along top beam")
plt.tight_layout()
plt.savefig("attention_weights.png")
Performance Optimization¶
Reduce Memory Usage¶
import torch
# Use inference mode
with torch.inference_mode():
result = tracer.trace_text("Your text")
# Clear CUDA cache if needed
torch.cuda.empty_cache()
Parallel Analysis¶
from concurrent.futures import ThreadPoolExecutor
texts = ["Text 1", "Text 2", "Text 3", ...]
def trace_text(text):
return tracer.trace_text(text)
# Note: This works for CPU; GPU may require more care
with ThreadPoolExecutor(max_workers=4) as executor:
results = list(executor.map(trace_text, texts))