Skip to content

Basic Tracing Example

A complete, annotated walkthrough of using reverse attention tracing.

The Complete Script

#!/usr/bin/env python3
"""
Basic example of reverse attention tracing.

This script demonstrates:
1. Loading a model and tokenizer
2. Creating a tracer
3. Running a trace
4. Interpreting the results
5. Generating a visualization
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from reverse_attention import ReverseAttentionTracer


def main():
    """Run a complete reverse attention trace and print a summary.

    Loads Qwen2-0.5B, traces attention backward from the last token of a
    pangram using the last transformer layer, prints the beams that were
    found, writes an HTML visualization to ``output/``, and reports the
    strongest link of the Sankey data.
    """
    # =========================================================
    # Step 1: Load Model and Tokenizer
    # =========================================================

    print("Loading model...")

    # We use Qwen2-0.5B as it's small enough to run on most machines
    # but still shows interesting attention patterns
    model_name = "Qwen/Qwen2-0.5B"

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        # Required for some models. It allows the Hub repository to run its
        # own Python code locally, so only enable it for repos you trust.
        trust_remote_code=True,
    )

    # Determine the best device
    if torch.cuda.is_available():
        device = "cuda"
        dtype = torch.float16  # Use half precision on GPU
    elif torch.backends.mps.is_available():
        device = "mps"
        dtype = torch.float16
    else:
        device = "cpu"
        dtype = torch.float32  # Full precision on CPU

    print(f"Using device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="eager",  # Required for attention output
    )

    # =========================================================
    # Step 2: Create the Tracer
    # =========================================================

    tracer = ReverseAttentionTracer(model, tokenizer)

    # =========================================================
    # Step 3: Define Input and Run Trace
    # =========================================================

    # Our test sentence - a classic pangram
    text = "The quick brown fox jumps over the lazy dog."

    print(f"\nInput: {text}")
    # Byte-level BPE tokenizers (such as Qwen2's) return raw pieces with a
    # leading-space marker (e.g. 'Ġquick') from tokenize(); convert each
    # piece back to text so leading spaces are displayed as spaces.
    display_tokens = [
        tokenizer.convert_tokens_to_string([piece])
        for piece in tokenizer.tokenize(text)
    ]
    print(f"Tokens: {display_tokens}")

    # Run the trace
    # - target_pos=-1: Start from the last token
    # - layer=-1: Use the last transformer layer
    # - top_beam=5: Keep the top 5 paths
    # - top_k=5: Consider top 5 predecessors at each step
    result = tracer.trace_text(
        text,
        target_pos=-1,
        layer=-1,
        top_beam=5,
        top_k=5,
    )

    # =========================================================
    # Step 4: Explore the Results
    # =========================================================

    print("\n" + "=" * 60)
    print("TRACE RESULTS")
    print("=" * 60)

    # Basic metadata
    print(f"\nSequence length: {result.seq_len}")
    print(f"Target position: {result.target_pos} ({result.tokens[result.target_pos]})")
    print(f"Layer analyzed: {result.layer}")
    print(f"Number of beams found: {len(result.beams)}")

    # Human-readable paths
    print("\n--- Top Attention Paths ---")
    for beam_number, path_text in enumerate(result.paths_text, start=1):
        print(f"\nBeam {beam_number}:")
        print(f"  {path_text}")

    # Detailed beam analysis
    print("\n--- Detailed Beam Analysis ---")
    if not result.beams:
        # Indexing an empty beam list would raise IndexError; report instead.
        print("\nNo beams found.")
    else:
        # The first beam is the most probable path.
        best_beam = result.beams[0]

        print("\nTop beam details:")
        print(f"  Positions: {best_beam.positions}")
        print(f"  Tokens: {best_beam.tokens}")
        print(f"  Edge attentions: {[f'{a:.3f}' for a in best_beam.edge_attns]}")
        print(f"  Raw score: {best_beam.score_raw:.4f}")
        print(f"  Normalized score: {best_beam.score_norm:.4f}")

    # =========================================================
    # Step 5: Generate Visualization
    # =========================================================

    output_dir = "output"
    print(f"\nGenerating visualization in: {output_dir}/")

    html_path = tracer.render_html(
        result,
        out_dir=output_dir,
        open_browser=False,  # Set to True to auto-open
    )

    print(f"Visualization saved to: {html_path}")
    print("\nOpen this file in a browser to see the interactive Sankey diagram.")

    # =========================================================
    # Bonus: Analyze Sankey Data
    # =========================================================

    print("\n--- Sankey Diagram Data ---")
    print(f"Nodes: {len(result.sankey.nodes)}")
    print(f"Links: {len(result.sankey.links)}")

    if not result.sankey.links:
        # max() raises ValueError on an empty sequence.
        print("\nNo links to analyze.")
        return

    # Find the strongest link
    strongest_link = max(result.sankey.links, key=lambda link: link.value)
    print("\nStrongest link:")
    print(f"  {strongest_link.source} → {strongest_link.target}")
    print(f"  Weight: {strongest_link.value:.4f}")
    print(f"  Used by beams: {strongest_link.beam_indices}")


if __name__ == "__main__":
    main()

Running the Example

Save the script as basic_example.py and run:

python basic_example.py

Expected Output

Loading model...
Using device: cuda

Input: The quick brown fox jumps over the lazy dog.
Tokens: ['The', ' quick', ' brown', ' fox', ' jumps', ' over', ' the', ' lazy', ' dog', '.']

============================================================
TRACE RESULTS
============================================================

Sequence length: 10
Target position: 9 (.)
Layer analyzed: 23
Number of beams found: 5

--- Top Attention Paths ---

Beam 1:
  . ← dog ← lazy ← the ← over

Beam 2:
  . ← dog ← the ← over ← jumps

Beam 3:
  . ← dog ← lazy ← over ← jumps

...

--- Detailed Beam Analysis ---

Top beam details:
  Positions: [9, 8, 7, 6, 5]
  Tokens: ['.', 'dog', 'lazy', 'the', 'over']
  Edge attentions: ['0.412', '0.289', '0.341', '0.256']
  Raw score: -4.5665
  Normalized score: -1.1416

Generating visualization in: output/
Visualization saved to: output/index.html

Open this file in a browser to see the interactive Sankey diagram.

--- Sankey Diagram Data ---
...

Understanding the Results

The Path Format

. ← dog ← lazy ← the ← over

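Each arrow points in the direction information flows: from an attended token to the token that attends to it. The trace itself reads left to right. It starts at the target and steps back through what each token attends to: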

  1. We start at . (the target token)
  2. . most strongly attends to dog
  3. dog attends to lazy
  4. lazy attends to the
  5. the attends to over

The Scores

  • Raw score: Sum of log(attention) over the path's edges. The more negative the score, the less probable the path.
  • Normalized score: Raw score divided by the path length (the number of edges; 4 in the example above). This allows fair comparison between paths of different lengths.

Multiple Beams

Having multiple beams shows alternative attention paths:

  • Beam 1: The most probable path
  • Beams 2–5: Alternative paths that also have high attention

If beams are very similar, the attention is focused. If they differ significantly, attention is distributed.

Variations to Try

Different Target Positions

# Trace from "fox" instead of the last token
result = tracer.trace_text(text, target_pos=3)

Different Layers

# Early layer (more local patterns)
result = tracer.trace_text(text, layer=0)

# Middle layer
result = tracer.trace_text(text, layer=10)

More Exploration

# More beams and candidates
result = tracer.trace_text(
    text,
    top_beam=10,
    top_k=8,
)

Filter Weak Attention

# Only strong attention connections
result = tracer.trace_text(
    text,
    min_attn=0.1,
)

Next Steps