
Advanced Usage

This guide covers advanced patterns and use cases for reverse attention tracing.

Working with Raw Token IDs

For maximum control — for example, to tokenize the input yourself or pass a custom attention mask — call trace() with token IDs instead of trace_text():

from reverse_attention import ReverseAttentionTracer

tracer = ReverseAttentionTracer(model, tokenizer)

# Tokenize manually. Encode the text once and read both fields from that
# single encoding, instead of running the tokenizer twice on the same input.
encoded = tokenizer("Your text", return_tensors="pt")
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask

# Trace with full control over every argument
result = tracer.trace(
    input_ids,
    attention_mask=attention_mask,
    target_pos=-1,
    layer=-1,
)

Analyzing Multiple Layers

Compare attention patterns across layers:

layers_to_compare = [-1, -5, -10, 0]
results_by_layer = {}

for layer_idx in layers_to_compare:
    layer_result = tracer.trace_text(
        "The quick brown fox jumps over the lazy dog.",
        layer=layer_idx,
    )
    # Keep every per-layer result around for later side-by-side inspection.
    results_by_layer[layer_idx] = layer_result

    # Show the first (top-ranked) three paths found at this layer.
    print(f"\nLayer {layer_idx}:")
    for ranked_path in layer_result.paths_text[:3]:
        print(f"  {ranked_path}")

Analyzing Multiple Positions

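Trace from each target position in turn. Positions index the tracer's own tokenization (`result.tokens`), which can include special tokens (such as a BOS token) that `tokenizer.tokenize()` leaves out: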

text = "The capital of France is Paris."

# Take the token list from the tracer itself so every index matches the
# positions trace_text() works with. tokenizer.tokenize() does not add
# special tokens (e.g. a BOS token), so for tokenizers that add them its
# indices would be shifted relative to result.tokens.
tokens = tracer.trace_text(text).tokens

for pos, token in enumerate(tokens):
    result = tracer.trace_text(text, target_pos=pos)
    print(f"\nPosition {pos} ({token}):")
    print(f"  Top path: {result.paths_text[0]}")

Batch Analysis

Analyze multiple texts one after another and collect a summary for each:

texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Pack my box with five dozen liquor jugs.",
    "How vexingly quick daft zebras jump!",
]

# One summary dict per input text, in input order.
results = []
for sample in texts:
    traced = tracer.trace_text(sample)
    summary = {
        "text": sample,
        "top_path": traced.paths_text[0],
        "num_beams": len(traced.beams),
    }
    results.append(summary)

Extracting Path Statistics

Analyze the discovered paths programmatically:

result = tracer.trace_text("Your text here")

# Print every recorded field of each beam; beams are numbered from 1.
for beam_number, beam in enumerate(result.beams, start=1):
    print(f"\nBeam {beam_number}:")
    print(f"  Path length: {len(beam.positions)}")
    print(f"  Positions: {beam.positions}")
    print(f"  Tokens: {beam.tokens}")
    formatted_attns = [f'{a:.3f}' for a in beam.edge_attns]
    print(f"  Edge attentions: {formatted_attns}")
    print(f"  Raw score: {beam.score_raw:.4f}")
    print(f"  Normalized score: {beam.score_norm:.4f}")

Custom Visualization

Access Sankey data for custom visualizations:

result = tracer.trace_text("Your text here")

# The raw Sankey graph object attached to the trace result
sankey = result.sankey

# Nodes: id, token text and sequence position
for node in sankey.nodes:
    node_line = f"Node: {node.id}, Token: {node.name}, Position: {node.position}"
    print(node_line)

# Links: source node -> target node, with the link's value
for link in sankey.links:
    link_line = f"Link: {link.source} -> {link.target}, Value: {link.value:.4f}"
    print(link_line)

Comparing Beam Search Parameters

Sweep beam-search parameters to see how they change the discovered paths:

import itertools

text = "The quick brown fox jumps over the lazy dog."

# Parameter grid
top_beams = [3, 5, 10]
top_ks = [3, 5, 10]
length_norms = ["none", "avg_logprob", "sqrt"]

for tb, tk, ln in itertools.product(top_beams, top_ks, length_norms):
    result = tracer.trace_text(
        text,
        top_beam=tb,
        top_k=tk,
        length_norm=ln,
    )

    print(f"top_beam={tb}, top_k={tk}, length_norm={ln}")

    # An empty beam list would raise ZeroDivisionError in the average below
    # (and IndexError on paths_text[0]), aborting the entire sweep; report it
    # and move on to the next setting instead.
    if not result.beams:
        print("  No beams found")
        continue

    avg_path_length = sum(len(b.positions) for b in result.beams) / len(result.beams)
    print(f"  Avg path length: {avg_path_length:.1f}")
    print(f"  Top path: {result.paths_text[0]}")

Head Aggregation Comparison

Compare mean vs max head aggregation:

text = "The quick brown fox jumps over the lazy dog."

# Trace the same text once per head-aggregation mode and show the top paths.
head_aggregations = ("mean", "max")
for agg in head_aggregations:
    result = tracer.trace_text(text, agg_heads=agg)
    header = f"\n{agg.upper()} aggregation:"
    print(header)
    for top_path in result.paths_text[:3]:
        print(f"  {top_path}")

Exporting Results

To JSON

import json

result = tracer.trace_text("Your text")

# NOTE(review): json.dump only handles plain Python types; if the tracer
# returns tensors/arrays for these fields, convert them first
# (e.g. .tolist() / float()).
export_data = {
    "metadata": {
        "seq_len": result.seq_len,
        "target_pos": result.target_pos,
        "layer": result.layer,
    },
    "tokens": result.tokens,
    "paths_text": result.paths_text,
    "beams": [
        {
            "positions": beam.positions,
            "tokens": beam.tokens,
            "edge_attns": beam.edge_attns,
            # Export both scores (raw and normalized) so the file carries
            # every per-beam value shown in the path statistics.
            "score_raw": beam.score_raw,
            "score_norm": beam.score_norm,
        }
        for beam in result.beams
    ],
}

# Write UTF-8 explicitly (open() otherwise uses the platform default) and
# keep non-ASCII tokens readable rather than \uXXXX-escaped.
with open("trace_result.json", "w", encoding="utf-8") as f:
    json.dump(export_data, f, indent=2, ensure_ascii=False)

To DataFrame

import pandas as pd

result = tracer.trace_text("Your text")

# One row per edge of every beam. Each edge attention is paired with the
# path position/token that follows it: positions[0] is the target position,
# so positions and tokens are read from index 1 onward.
rows = [
    {
        "beam_id": beam_id,
        "step": step,
        "position": pos,
        "token": token,
        "attention": attn,
    }
    for beam_id, beam in enumerate(result.beams)
    for step, (pos, token, attn) in enumerate(
        zip(beam.positions[1:], beam.tokens[1:], beam.edge_attns)
    )
]

df = pd.DataFrame(rows)
print(df)

Finding Attention Patterns

Hub Tokens

Find tokens that appear in many paths:

from collections import Counter

result = tracer.trace_text("Your text here", top_beam=10)

# Count how many beams pass through each position. positions[0] is the
# target position that every beam starts from, so it is skipped; otherwise
# it would trivially top the ranking with one hit per beam. set() counts a
# position at most once per beam, so the totals really are "beams", as
# printed below.
position_counts = Counter(
    pos for beam in result.beams for pos in set(beam.positions[1:])
)

print("Most common positions:")
for pos, count in position_counts.most_common(5):
    token = result.tokens[pos]
    print(f"  Position {pos} ({token}): {count} beams")

Long-Range Attention

Find paths with large position jumps:

result = tracer.trace_text("Your long text here...")

# Report beams whose largest single hop is longer than this many positions.
min_jump = 5

for i, beam in enumerate(result.beams):
    # Largest difference positions[j] - positions[j + 1] between consecutive
    # path positions. default=0 covers a beam with a single position, where
    # max() over an empty sequence would raise ValueError.
    max_jump = max(
        (a - b for a, b in zip(beam.positions, beam.positions[1:])),
        default=0,
    )
    if max_jump > min_jump:
        print(f"Beam {i}: max jump = {max_jump}")
        print(f"  Path: {' -> '.join(beam.tokens)}")

Integration with Other Tools

With Weights & Biases

import wandb

wandb.init(project="attention-analysis")

result = tracer.trace_text("Your text")

# Aggregate metrics need at least one beam: an empty trace would otherwise
# raise ZeroDivisionError (average) or IndexError (top score).
metrics = {"num_beams": len(result.beams)}
if result.beams:
    metrics["avg_path_length"] = (
        sum(len(b.positions) for b in result.beams) / len(result.beams)
    )
    metrics["top_score"] = result.beams[0].score_norm
wandb.log(metrics)

# Log visualization. render_html() is expected to write output/index.html,
# the file read below; the with-block closes the handle after reading.
tracer.render_html(result, "output/")
with open("output/index.html", encoding="utf-8") as html_file:
    wandb.log({"visualization": wandb.Html(html_file.read())})

With Matplotlib

import matplotlib.pyplot as plt

result = tracer.trace_text("Your text")

# Plot attention weights along top path
beam = result.beams[0]

# Edge j joins tokens[j] and tokens[j + 1]. Label each bar with both ends
# so it is clear which hop the weight belongs to.
edge_labels = [f"{src} -> {dst}" for src, dst in zip(beam.tokens, beam.tokens[1:])]

plt.figure(figsize=(10, 4))
plt.bar(range(len(beam.edge_attns)), beam.edge_attns)
# Right-align rotated labels so each one ends under its own bar.
plt.xticks(range(len(edge_labels)), edge_labels, rotation=45, ha="right")
plt.ylabel("Attention")
plt.title("Attention weights along top beam")
plt.tight_layout()
plt.savefig("attention_weights.png")

Performance Optimization

Reduce Memory Usage

import torch

# Run the trace without autograd bookkeeping to reduce memory use.
with torch.inference_mode():
    result = tracer.trace_text("Your text")

# Return cached-but-unused CUDA memory to the driver. empty_cache() already
# does nothing when CUDA was never initialized, so guarding it on GPU
# availability only makes the intent explicit.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Parallel Analysis

from concurrent.futures import ThreadPoolExecutor

# Replace with your own inputs. Every item must be a string: a literal `...`
# inside the list would be passed to trace_text() as an Ellipsis object.
texts = ["Text 1", "Text 2", "Text 3"]

def trace_text(text):
    """Trace ``text`` with the module-level ``tracer`` and return its result.

    Exists as a plain one-argument function so it can be handed to
    ``executor.map``.
    """
    traced = tracer.trace_text(text)
    return traced

# Note: This works for CPU; GPU may require more care.
# NOTE(review): every worker shares the same tracer/model instance; confirm
# that concurrent trace_text() calls are safe for your tracer before relying
# on this. Executor.map yields results in the same order as `texts`.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = [*pool.map(trace_text, texts)]