# ReverseAttentionTracer

The main class for tracing attention paths backward through transformer models.
## Class Definition

### ReverseAttentionTracer

Trace attention paths backward through a transformer model.

This class provides the main API for extracting attention weights, running beam search backward through the attention matrix, and visualizing the results as an interactive Sankey diagram.

Example:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

tracer = ReverseAttentionTracer(model, tokenizer)
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
result = tracer.trace(input_ids)
tracer.render_html(result, "output/")
```
#### `__init__(model, tokenizer, device=None, dtype=None)`

Initialize the tracer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `PreTrainedModel` | HuggingFace transformer model | *required* |
| `tokenizer` | `PreTrainedTokenizer` | Corresponding tokenizer | *required* |
| `device` | `Optional[Union[str, torch.device]]` | Device to run on (defaults to the model's device) | `None` |
| `dtype` | `Optional[torch.dtype]` | Data type for computation (defaults to the model's dtype) | `None` |
#### `trace(input_ids, target_pos=None, attention_mask=None, layer=-1, top_beam=5, top_k=5, min_attn=0.0, agg_heads='mean', length_norm='avg_logprob', stop_at_bos=True, bos_token_id=None)`

Trace attention paths backward from a target position.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_ids` | `torch.LongTensor` | Input token IDs `[1, seq_len]` | *required* |
| `target_pos` | `Optional[int]` | Position to trace from (default: -1, the last token). Supports negative indexing. | `None` |
| `attention_mask` | `Optional[torch.LongTensor]` | Optional attention mask `[1, seq_len]` | `None` |
| `layer` | `int` | Layer index to extract attention from (default: -1, the last layer). Supports negative indexing. | `-1` |
| `top_beam` | `int` | Number of beams to keep at each step | `5` |
| `top_k` | `int` | Number of top predecessors to consider at each step | `5` |
| `min_attn` | `float` | Minimum attention threshold for considering a predecessor | `0.0` |
| `agg_heads` | `str` | Head aggregation mode: `"mean"` or `"max"` (`"none"` is not supported for beam search) | `'mean'` |
| `length_norm` | `str` | Score normalization mode: `"none"`, `"avg_logprob"`, `"sqrt"`, or `"pow:α"` | `'avg_logprob'` |
| `stop_at_bos` | `bool` | Whether to stop at BOS positions | `True` |
| `bos_token_id` | `Optional[int]` | Override the BOS token ID (uses the tokenizer's if `None`) | `None` |
Returns:

| Type | Description |
|---|---|
| `TraceResult` | `TraceResult` with beams, Sankey data, and metadata |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If parameters are invalid |
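As a rough illustration of how `top_k`, `min_attn`, and `top_beam` interact during the backward search, here is a simplified single-step sketch in plain Python. This is not the library's actual implementation; attention is treated as an ordinary nested list and scores as cumulative log attention weights.

```python
import math

def expand_beams(beams, attn, top_k=5, min_attn=0.0, top_beam=5):
    """One backward step: extend each beam toward its strongest predecessors.

    beams: list of (path, score) pairs; path ends at the position to extend,
           score is the cumulative log of the attention weights along the path.
    attn:  square attention matrix, attn[q][k] = weight from position q to k.
    """
    candidates = []
    for path, score in beams:
        pos = path[-1]
        row = attn[pos]
        # Keep predecessors strictly before pos (causal attention) whose
        # weight clears the min_attn threshold.
        preds = [(p, w) for p, w in enumerate(row[:pos]) if w > min_attn]
        # Expand only the top_k strongest predecessors of each beam.
        preds.sort(key=lambda pw: pw[1], reverse=True)
        for p, w in preds[:top_k]:
            candidates.append((path + [p], score + math.log(w)))
    # Prune the pool back down to the top_beam best extended beams.
    candidates.sort(key=lambda ps: ps[1], reverse=True)
    return candidates[:top_beam]
```

Repeating this step until every surviving beam reaches position 0 (or a BOS token, when `stop_at_bos=True`) yields the traced paths.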
#### `trace_text(text, target_pos=None, **kwargs)`

Convenience method to trace from text input.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `text` | `str` | Input text to tokenize and trace | *required* |
| `target_pos` | `Optional[int]` | Position to trace from (default: -1, the last token) | `None` |
| `**kwargs` | | Additional arguments passed to `trace()` | `{}` |
Returns:

| Type | Description |
|---|---|
| `TraceResult` | `TraceResult` with beams, Sankey data, and metadata |
#### `render_html(trace_result, out_dir, open_browser=False)`

Render a trace result as an interactive HTML visualization.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `trace_result` | `TraceResult` | Result from `trace()` | *required* |
| `out_dir` | `str` | Output directory for the HTML files | *required* |
| `open_browser` | `bool` | Whether to open the result in a browser | `False` |

Returns:

| Type | Description |
|---|---|
| `str` | Path to the generated `index.html` |
## Usage Examples

### Basic Initialization

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from reverse_attention import ReverseAttentionTracer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

tracer = ReverseAttentionTracer(model, tokenizer)
```

### With Custom Device

```python
import torch

tracer = ReverseAttentionTracer(
    model,
    tokenizer,
    device="cuda:0",
    dtype=torch.float16,
)
```

### Tracing from Text

```python
result = tracer.trace_text(
    "The quick brown fox jumps over the lazy dog.",
    target_pos=-1,  # Last token
    layer=-1,       # Last layer
    top_beam=5,
    top_k=5,
)
```

### Tracing from Token IDs

```python
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
result = tracer.trace(
    input_ids,
    target_pos=-1,
    layer=-1,
)
```

### Generating Visualization

```python
html_path = tracer.render_html(
    result,
    out_dir="output/",
    open_browser=True,
)
print(f"Visualization at: {html_path}")
```
## Methods

### `__init__`

```python
def __init__(
    self,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
)
```

Initialize the tracer with a model and tokenizer.

Parameters:

| Name | Type | Description |
|---|---|---|
| `model` | `PreTrainedModel` | HuggingFace transformer model |
| `tokenizer` | `PreTrainedTokenizer` | Corresponding tokenizer |
| `device` | `str` or `torch.device` | Device to run on (defaults to the model's device) |
| `dtype` | `torch.dtype` | Data type for computation (defaults to the model's dtype) |
### `trace`

```python
def trace(
    self,
    input_ids: torch.LongTensor,
    target_pos: Optional[int] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    layer: int = -1,
    top_beam: int = 5,
    top_k: int = 5,
    min_attn: float = 0.0,
    agg_heads: str = "mean",
    length_norm: str = "avg_logprob",
    stop_at_bos: bool = True,
    bos_token_id: Optional[int] = None,
) -> TraceResult
```

Trace attention paths backward from a target position.

Parameters:

| Name | Type | Default | Description |
|---|---|---|---|
| `input_ids` | `torch.LongTensor` | *required* | Input token IDs `[1, seq_len]` |
| `target_pos` | `int` | `-1` | Position to trace from (supports negative indexing) |
| `attention_mask` | `torch.LongTensor` | `None` | Optional attention mask `[1, seq_len]` |
| `layer` | `int` | `-1` | Layer index (supports negative indexing) |
| `top_beam` | `int` | `5` | Number of beams to keep |
| `top_k` | `int` | `5` | Top-k predecessors per step |
| `min_attn` | `float` | `0.0` | Minimum attention threshold |
| `agg_heads` | `str` | `"mean"` | Head aggregation: `"mean"` or `"max"` |
| `length_norm` | `str` | `"avg_logprob"` | Score normalization mode |
| `stop_at_bos` | `bool` | `True` | Stop at BOS tokens |
| `bos_token_id` | `int` | `None` | Override BOS token ID |

Returns: `TraceResult`
Raises:

- `ValueError`: If `top_beam < 1` or `top_k < 1`
- `ValueError`: If `agg_heads="none"` (not supported for beam search)
- `ValueError`: If `target_pos` is out of range
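The `ValueError` conditions above can be mirrored by a small validation helper. The function below is a sketch of the documented checks, not the library's own code; its name and signature are hypothetical.

```python
def validate_trace_args(seq_len, target_pos=-1, top_beam=5, top_k=5,
                        agg_heads="mean"):
    """Raise ValueError for the invalid-parameter cases documented above."""
    if top_beam < 1:
        raise ValueError("top_beam must be >= 1")
    if top_k < 1:
        raise ValueError("top_k must be >= 1")
    if agg_heads == "none":
        raise ValueError('agg_heads="none" is not supported for beam search')
    # target_pos supports Python-style negative indexing into the sequence.
    if not -seq_len <= target_pos < seq_len:
        raise ValueError(
            f"target_pos {target_pos} out of range for length {seq_len}"
        )
```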
### `trace_text`

Convenience method to trace from text input.

Parameters:

| Name | Type | Default | Description |
|---|---|---|---|
| `text` | `str` | *required* | Input text to tokenize and trace |
| `target_pos` | `int` | `-1` | Position to trace from |
| `**kwargs` | | | Additional arguments passed to `trace()` |

Returns: `TraceResult`
### `render_html`

```python
def render_html(
    self,
    trace_result: TraceResult,
    out_dir: str,
    open_browser: bool = False,
) -> str
```

Render a trace result as an interactive HTML visualization.

Parameters:

| Name | Type | Default | Description |
|---|---|---|---|
| `trace_result` | `TraceResult` | *required* | Result from `trace()` |
| `out_dir` | `str` | *required* | Output directory for the HTML files |
| `open_browser` | `bool` | `False` | Open the result in a browser |

Returns: Path to the generated `index.html`
## Parameter Reference

### `length_norm` Options

| Value | Formula | Effect |
|---|---|---|
| `"none"` | `score` | Raw cumulative log probability |
| `"avg_logprob"` | `score / length` | Geometric mean (default) |
| `"sqrt"` | `score / sqrt(length)` | Moderate normalization |
| `"pow:α"` | `score / length^α` | Custom exponent (e.g., `"pow:0.7"`) |
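The formulas above are easy to express directly. The sketch below is an illustration rather than the library's code: it applies each documented mode to a cumulative log probability `score` accumulated over a path of `length` steps.

```python
import math

def length_normalize(score, length, mode="avg_logprob"):
    """Apply one of the documented length_norm modes to a beam score."""
    if mode == "none":
        return score                      # raw cumulative log probability
    if mode == "avg_logprob":
        return score / length             # geometric mean per step
    if mode == "sqrt":
        return score / math.sqrt(length)  # moderate normalization
    if mode.startswith("pow:"):
        alpha = float(mode.split(":", 1)[1])
        return score / length ** alpha    # custom exponent, e.g. "pow:0.7"
    raise ValueError(f"unknown length_norm mode: {mode!r}")
```

Note that `"pow:1"` coincides with `"avg_logprob"` and `"pow:0.5"` with `"sqrt"`, so `"pow:α"` generalizes both.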
### `agg_heads` Options

| Value | Effect |
|---|---|
| `"mean"` | Average attention across all heads |
| `"max"` | Maximum attention across heads |

Note: `"none"` is not supported for beam search.
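For intuition, both supported modes reduce the per-head weights for one (query, key) pair to a single scalar; the weight values below are made up for illustration:

```python
# Hypothetical per-head attention weights for one (query, key) pair.
head_attn = [0.10, 0.40, 0.05, 0.25]

mean_agg = sum(head_attn) / len(head_attn)  # "mean": average across heads -> 0.20
max_agg = max(head_attn)                    # "max": strongest single head -> 0.40
```

`"mean"` smooths over heads and favors broadly attended positions, while `"max"` lets a single strongly attending head dominate the backward search.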