ReverseAttentionTracer

The main class for tracing attention paths backward through transformer models.

Class Definition

ReverseAttentionTracer

Trace attention paths backward through a transformer model.

This class provides the main API for extracting attention weights, running beam search backward through the attention matrix, and visualizing the results as an interactive Sankey diagram.
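
The backward beam search can be sketched in miniature. This is a simplified illustration over a plain attention matrix, not the library's implementation; the real tracer also applies head aggregation, length normalization, and BOS stopping:

```python
import math

def backward_beam_search(attn, target_pos, top_beam=5, top_k=5):
    """Trace high-attention paths backward from target_pos.

    attn[i][j] is the attention weight from position i to position j <= i
    (a causal, row-stochastic matrix). Each beam is (path, score), where
    score is the cumulative log of the attention weights along the path.
    """
    beams = [([target_pos], 0.0)]
    finished = []
    while beams:
        candidates = []
        for path, score in beams:
            pos = path[-1]
            if pos == 0:                       # reached the sequence start
                finished.append((path, score))
                continue
            row = attn[pos][:pos]              # predecessors are j < pos
            preds = sorted(range(pos), key=lambda j: row[j], reverse=True)[:top_k]
            for j in preds:
                if row[j] > 0.0:
                    candidates.append((path + [j], score + math.log(row[j])))
        beams = sorted(candidates, key=lambda b: b[1], reverse=True)[:top_beam]
    return sorted(finished, key=lambda b: b[1], reverse=True)
```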

Example

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
tracer = ReverseAttentionTracer(model, tokenizer)

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
result = tracer.trace(input_ids)
tracer.render_html(result, "output/")

__init__(model, tokenizer, device=None, dtype=None)

Initialize the tracer.

Parameters:

Name Type Default Description
model PreTrainedModel required HuggingFace transformer model
tokenizer PreTrainedTokenizer required Corresponding tokenizer
device Optional[Union[str, torch.device]] None Device to run on (defaults to the model's device)
dtype Optional[torch.dtype] None Data type for computation (defaults to the model's dtype)

trace(input_ids, target_pos=None, attention_mask=None, layer=-1, top_beam=5, top_k=5, min_attn=0.0, agg_heads='mean', length_norm='avg_logprob', stop_at_bos=True, bos_token_id=None)

Trace attention paths backward from a target position.

Parameters:

Name Type Default Description
input_ids LongTensor required Input token IDs of shape [1, seq_len]
target_pos Optional[int] None Position to trace from (None means -1, the last token); supports negative indexing
attention_mask Optional[LongTensor] None Optional attention mask of shape [1, seq_len]
layer int -1 Layer index to extract attention from (-1 is the last layer); supports negative indexing
top_beam int 5 Number of beams to keep at each step
top_k int 5 Number of top predecessors to consider at each step
min_attn float 0.0 Minimum attention weight for a predecessor to be considered
agg_heads str 'mean' Head aggregation mode ("mean" or "max"; "none" is not supported for beam search)
length_norm str 'avg_logprob' Score normalization mode ("none", "avg_logprob", "sqrt", "pow:α")
stop_at_bos bool True Whether to stop tracing at BOS positions
bos_token_id Optional[int] None Override BOS token ID (uses the tokenizer's if None)

Returns:

Type Description
TraceResult Beams, Sankey data, and metadata

Raises:

Type Description
ValueError If parameters are invalid (top_beam < 1 or top_k < 1, agg_heads="none", or target_pos out of range)

trace_text(text, target_pos=None, **kwargs)

Convenience method to trace from text input.

Parameters:

Name Type Default Description
text str required Input text to tokenize and trace
target_pos Optional[int] None Position to trace from (None means -1, the last token)
**kwargs {} Additional arguments passed to trace()

Returns:

Type Description
TraceResult Beams, Sankey data, and metadata

render_html(trace_result, out_dir, open_browser=False)

Render trace result as interactive HTML visualization.

Parameters:

Name Type Default Description
trace_result TraceResult required Result from trace()
out_dir str required Output directory for HTML files
open_browser bool False Whether to open the result in a browser

Returns:

Type Description
str Path to the generated index.html

Usage Examples

Basic Initialization

from transformers import AutoModelForCausalLM, AutoTokenizer
from reverse_attention import ReverseAttentionTracer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

tracer = ReverseAttentionTracer(model, tokenizer)

With Custom Device

import torch

tracer = ReverseAttentionTracer(
    model,
    tokenizer,
    device="cuda:0",
    dtype=torch.float16,
)

Tracing from Text

result = tracer.trace_text(
    "The quick brown fox jumps over the lazy dog.",
    target_pos=-1,  # Last token
    layer=-1,       # Last layer
    top_beam=5,
    top_k=5,
)

Tracing from Token IDs

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

result = tracer.trace(
    input_ids,
    target_pos=-1,
    layer=-1,
)

Generating Visualization

html_path = tracer.render_html(
    result,
    out_dir="output/",
    open_browser=True,
)
print(f"Visualization at: {html_path}")

Methods

__init__

def __init__(
    self,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
)

Initialize the tracer with a model and tokenizer.

Parameters:

Name Type Default Description
model PreTrainedModel required HuggingFace transformer model
tokenizer PreTrainedTokenizer required Corresponding tokenizer
device str or torch.device None Device to run on (defaults to the model's device)
dtype torch.dtype None Data type for computation (defaults to the model's dtype)

trace

def trace(
    self,
    input_ids: torch.LongTensor,
    target_pos: Optional[int] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    layer: int = -1,
    top_beam: int = 5,
    top_k: int = 5,
    min_attn: float = 0.0,
    agg_heads: str = "mean",
    length_norm: str = "avg_logprob",
    stop_at_bos: bool = True,
    bos_token_id: Optional[int] = None,
) -> TraceResult

Trace attention paths backward from a target position.

Parameters:

Name Type Default Description
input_ids torch.LongTensor required Input token IDs of shape [1, seq_len]
target_pos Optional[int] None Position to trace from (None means -1, the last token); supports negative indexing
attention_mask Optional[torch.LongTensor] None Optional attention mask of shape [1, seq_len]
layer int -1 Layer index (supports negative indexing)
top_beam int 5 Number of beams to keep
top_k int 5 Top-k predecessors per step
min_attn float 0.0 Minimum attention threshold
agg_heads str "mean" Head aggregation: "mean" or "max"
length_norm str "avg_logprob" Score normalization mode
stop_at_bos bool True Stop at BOS tokens
bos_token_id Optional[int] None Override BOS token ID

Returns: TraceResult

Raises:

  • ValueError: If top_beam < 1 or top_k < 1
  • ValueError: If agg_heads="none" (not supported for beam search)
  • ValueError: If target_pos is out of range
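
A standalone sketch of these checks (a hypothetical helper mirroring the documented errors, not the library's own code):

```python
def validate_trace_args(seq_len, target_pos=-1, top_beam=5, top_k=5, agg_heads="mean"):
    """Raise ValueError for the invalid-argument cases documented for trace()."""
    if top_beam < 1 or top_k < 1:
        raise ValueError("top_beam and top_k must be >= 1")
    if agg_heads == "none":
        raise ValueError('agg_heads="none" is not supported for beam search')
    if not -seq_len <= target_pos < seq_len:    # negative indexing is allowed
        raise ValueError(f"target_pos {target_pos} out of range for seq_len {seq_len}")
```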

trace_text

def trace_text(
    self,
    text: str,
    target_pos: Optional[int] = None,
    **kwargs,
) -> TraceResult

Convenience method to trace from text input.

Parameters:

Name Type Default Description
text str required Input text to tokenize and trace
target_pos Optional[int] None Position to trace from (None means the last token)
**kwargs {} Additional arguments passed to trace()

Returns: TraceResult

render_html

def render_html(
    self,
    trace_result: TraceResult,
    out_dir: str,
    open_browser: bool = False,
) -> str

Render trace result as interactive HTML visualization.

Parameters:

Name Type Default Description
trace_result TraceResult required Result from trace()
out_dir str required Output directory for HTML files
open_browser bool False Open result in browser

Returns: Path to the generated index.html

Parameter Reference

length_norm Options

Value Formula Effect
"none" score Raw cumulative log probability
"avg_logprob" score / length Geometric mean (default)
"sqrt" score / sqrt(length) Moderate normalization
"pow:α" score / length^α Custom exponent (e.g., "pow:0.7")

agg_heads Options

Value Effect
"mean" Average attention across all heads
"max" Maximum attention across heads

Note: "none" is not supported for beam search.