Data Classes
This page documents the data classes used to represent trace results.
TraceResult
Container for the complete trace result.
TraceResult (dataclass)
Result of a reverse attention trace.
Fields

| Field | Type | Description |
|---|---|---|
| seq_len | int | Sequence length |
| target_pos | int | Target position (resolved to positive index) |
| layer | int | Layer index (resolved to positive index) |
| top_beam | int | Number of beams used |
| top_k | int | Top-k value used |
| tokens | List[str] | All tokens in sequence |
| beams | List[BeamPath] | Completed beam paths |
| sankey | SankeyData | Sankey visualization data |
| paths_text | List[str] | Human-readable path descriptions |
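Both target_pos and layer are stored "resolved to positive index", which suggests that negative inputs such as -1 (last token, last layer) are accepted and converted. The sketch below assumes the conversion follows Python's own negative-indexing rule; resolve_index and num_layers are hypothetical names, and the library's exact rule is not documented on this page.

```python
def resolve_index(index: int, size: int) -> int:
    """Map a possibly negative index to its non-negative equivalent, Python-style.

    ASSUMPTION: mirrors Python sequence indexing; the library's rule may differ.
    """
    resolved = index + size if index < 0 else index
    if not 0 <= resolved < size:
        raise IndexError(f"index {index} out of range for size {size}")
    return resolved


seq_len = 10     # tokens in the traced sequence
num_layers = 12  # hypothetical model depth
print(resolve_index(-1, seq_len))     # 9  (last token)
print(resolve_index(-1, num_layers))  # 11 (last layer)
print(resolve_index(4, seq_len))      # 4  (already non-negative)
```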
Example Usage
```python
# `tracer` is assumed to be an existing ReverseAttentionTracer instance
result = tracer.trace_text("Hello world")

# Access metadata
print(f"Sequence length: {result.seq_len}")
print(f"Target position: {result.target_pos}")
print(f"Layer: {result.layer}")

# Access tokens
print(f"Tokens: {result.tokens}")

# Access beams; paths_text[i] describes beams[i]
for i, (beam, text) in enumerate(zip(result.beams, result.paths_text)):
    print(f"Beam {i}: {text}")

# Access Sankey data for custom visualization
nodes = result.sankey.nodes
links = result.sankey.links
```
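The fields of TraceResult and its BeamPath entries are meant to line up: the example above reads paths_text[i] alongside beams[i], and the page states that every path starts at target_pos. The sketch below checks those relationships on a hand-built result. It uses minimal stand-in dataclasses rather than the library's own, check_result is a hypothetical helper, and the check that beam.tokens equals result.tokens at the same positions is an assumption drawn from the field descriptions, not a documented guarantee.

```python
from dataclasses import dataclass
from typing import List


# Minimal stand-ins with only the fields used here (not the library's classes).
@dataclass
class BeamPath:
    positions: List[int]
    tokens: List[str]
    edge_attns: List[float]


@dataclass
class TraceResult:
    target_pos: int
    tokens: List[str]
    beams: List[BeamPath]
    paths_text: List[str]


def check_result(result: TraceResult) -> None:
    """Assert the relationships this page describes (some are assumptions)."""
    assert len(result.paths_text) == len(result.beams)  # one description per beam
    for beam in result.beams:
        assert beam.positions[0] == result.target_pos  # every path starts at the target
        assert len(beam.edge_attns) == len(beam.positions) - 1
        # ASSUMPTION: per-beam tokens are the sequence tokens at those positions
        assert beam.tokens == [result.tokens[p] for p in beam.positions]


toy = TraceResult(
    target_pos=3,
    tokens=["The", "cat", "sat", "down"],
    beams=[BeamPath(positions=[3, 1, 0], tokens=["down", "cat", "The"], edge_attns=[0.5, 0.6])],
    paths_text=["down → cat → The"],  # hypothetical formatting
)
check_result(toy)
print("toy result is consistent")
```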
BeamPath
A single completed attention path with token information.
BeamPath (dataclass)
A completed beam path with token information.
Fields

| Field | Type | Description |
|---|---|---|
| positions | List[int] | Token positions in sequence (target → start) |
| tokens | List[str] | Token strings at each position |
| token_ids | List[int] | Token IDs at each position |
| edge_attns | List[float] | Attention weights along edges |
| score_raw | float | Raw cumulative log score |
| score_norm | float | Length-normalized score |
Example Usage
```python
beam = result.beams[0]  # Top beam

# Path structure
print(f"Positions: {beam.positions}")  # e.g. [9, 7, 5, 3, 0]
print(f"Tokens: {beam.tokens}")        # e.g. ['dog', 'lazy', 'over', ...]
print(f"Token IDs: {beam.token_ids}")

# Attention weights: one fewer edge than positions (edges connect positions)
print(f"Edge attentions: {beam.edge_attns}")

# Scores
print(f"Raw score (log prob): {beam.score_raw}")
print(f"Normalized score: {beam.score_norm}")
```
Understanding the Structure
```text
Positions:   [9]  →  [7]  →  [5]  →  [3]  →  [0]
Tokens:     "dog"  "lazy"  "over"    ...    "The"
Edge attns:      0.4     0.3     0.25    0.2
```
The path starts at target_pos and traces backward. edge_attns[i] is the attention from positions[i] to positions[i+1].
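Because edge_attns[i] connects positions[i] to positions[i+1], pairing weights with their endpoints is a single zip. The sketch below does that for the path in the diagram and also recomputes the two scores under stated assumptions: score_raw as the sum of log edge attentions (matching "raw cumulative log score") and score_norm as that sum divided by the number of edges. The normalization rule is a guess and the library may normalize differently; the helpers edges and scores are hypothetical.

```python
import math
from typing import List, Tuple


def edges(positions: List[int], edge_attns: List[float]) -> List[Tuple[int, int, float]]:
    """Pair each attention weight with the (from, to) positions it connects."""
    if len(edge_attns) != len(positions) - 1:
        raise ValueError("expected exactly one fewer edge than positions")
    return list(zip(positions, positions[1:], edge_attns))


def scores(edge_attns: List[float]) -> Tuple[float, float]:
    """Hypothetical recomputation of (score_raw, score_norm).

    ASSUMPTIONS: score_raw = sum of log(attention) along the path, and
    score_norm = score_raw / number of edges.
    """
    raw = sum(math.log(a) for a in edge_attns)
    return raw, raw / len(edge_attns)


# Values from the structure diagram
positions = [9, 7, 5, 3, 0]
edge_attns = [0.4, 0.3, 0.25, 0.2]
for src, dst, attn in edges(positions, edge_attns):
    print(f"{src} → {dst}: attention {attn}")
raw, norm = scores(edge_attns)
print(f"score_raw={raw:.4f}, score_norm={norm:.4f}")
```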
SankeyData
Data structure for Sankey diagram visualization.
SankeyData (dataclass)
Complete Sankey diagram data.
Fields

| Field | Type | Description |
|---|---|---|
| nodes | List[SankeyNode] | Diagram nodes |
| links | List[SankeyLink] | Diagram links |
Example Usage
```python
sankey = result.sankey

# List all nodes
for node in sankey.nodes:
    print(f"Node {node.id}: '{node.name}' at position {node.position}")

# List all links
for link in sankey.links:
    print(f"Link: {link.source} → {link.target}, weight={link.value:.4f}")
```
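Links refer to nodes by string ID, while some plotting libraries expect integer indices into the node list (plotly's go.Sankey is one example: its link source and target are node indices). The sketch below performs that mapping. The dataclasses are minimal stand-ins mirroring the documented fields, and to_index_arrays is a hypothetical helper, not part of reverse_attention.

```python
from dataclasses import dataclass
from typing import Dict, List


# Stand-ins mirroring the documented fields (not the library's own classes).
@dataclass
class SankeyNode:
    id: str
    name: str
    position: int
    layer: int = 0


@dataclass
class SankeyLink:
    source: str
    target: str
    value: float
    beam_indices: List[int]


@dataclass
class SankeyData:
    nodes: List[SankeyNode]
    links: List[SankeyLink]


def to_index_arrays(sankey: SankeyData) -> Dict[str, list]:
    """Replace string node IDs with integer indices into sankey.nodes."""
    index_of = {node.id: i for i, node in enumerate(sankey.nodes)}
    return {
        "labels": [node.name for node in sankey.nodes],
        "source": [index_of[link.source] for link in sankey.links],
        "target": [index_of[link.target] for link in sankey.links],
        "value": [link.value for link in sankey.links],
    }


sankey = SankeyData(
    nodes=[SankeyNode("pos_9", "dog", 9), SankeyNode("pos_7", "lazy", 7), SankeyNode("pos_5", "over", 5)],
    links=[SankeyLink("pos_9", "pos_7", 0.4, [0]), SankeyLink("pos_7", "pos_5", 0.3, [0])],
)
arrays = to_index_arrays(sankey)
print(arrays)
```

With plotly installed, the lists could be passed as plotly.graph_objects.Sankey(node=dict(label=arrays["labels"]), link=dict(source=arrays["source"], target=arrays["target"], value=arrays["value"])).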
SankeyNode
A node in the Sankey diagram.
SankeyNode (dataclass)
Fields

| Field | Type | Description |
|---|---|---|
| id | str | Unique identifier (e.g., "pos_5") |
| name | str | Display name (token text) |
| position | int | Position in sequence |
| layer | int | Layer index (default: 0) |
SankeyLink
A link in the Sankey diagram.
SankeyLink (dataclass)
Fields

| Field | Type | Description |
|---|---|---|
| source | str | Source node ID |
| target | str | Target node ID |
| value | float | Link weight (aggregated attention) |
| beam_indices | List[int] | Which beams use this link |
Link Weight Calculation
Link weights are computed by aggregating attention across beams using softmax weighting:
- Compute beam weights: weights = softmax(normalized_scores)
- For each link, sum over the beams that use it: value = Σ(beam_weight × attention)
This ensures that higher-scoring beams contribute more to link widths.
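The two steps above can be sketched in plain Python. Each beam is given as a (positions, edge_attns, score_norm) tuple, links are keyed by node IDs of the form pos_<position> (taken from the SankeyNode id example; the library's actual ID scheme may differ), and aggregate_links is a hypothetical helper rather than the library's implementation.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

# (positions, edge_attns, score_norm) for one completed beam
Beam = Tuple[List[int], List[float], float]


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def aggregate_links(beams: List[Beam]) -> Dict[Tuple[str, str], Tuple[float, List[int]]]:
    """Return {(source_id, target_id): (value, beam_indices)} for every traversed edge."""
    weights = softmax([score_norm for _, _, score_norm in beams])
    values: Dict[Tuple[str, str], float] = defaultdict(float)
    users: Dict[Tuple[str, str], List[int]] = defaultdict(list)
    for b, (positions, attns, _) in enumerate(beams):
        # edge i runs from positions[i] to positions[i + 1] with weight attns[i]
        for src, dst, attn in zip(positions, positions[1:], attns):
            key = (f"pos_{src}", f"pos_{dst}")  # assumed node-ID scheme
            values[key] += weights[b] * attn
            users[key].append(b)
    return {key: (values[key], users[key]) for key in values}


beams = [
    ([9, 7, 5], [0.4, 0.3], -1.0),
    ([9, 7, 3], [0.4, 0.2], -1.5),
]
for (src, dst), (value, used_by) in aggregate_links(beams).items():
    print(f"{src} → {dst}: value={value:.4f}, beams={used_by}")
```

In this toy data both beams cross edge 9 → 7 with attention 0.4, so because the softmax weights sum to 1 that shared link keeps the value 0.4, while each link used by a single beam is scaled down by that beam's weight.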
BeamState
Internal state during beam search (typically not accessed directly).
BeamState (dataclass)
Fields

| Field | Type | Description |
|---|---|---|
| positions | List[int] | Positions traversed so far |
| edge_attns | List[float] | Attention weights along edges |
| log_score | float | Cumulative log(attention) |
| is_terminal | bool | Whether this beam has terminated |
Properties

| Property | Type | Description |
|---|---|---|
| current_pos | int | Most recent position in the beam |
| num_edges | int | Number of edges traversed |
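Both properties follow directly from the fields. The stand-in below shows them, plus a hypothetical extend method illustrating how log_score accumulates log(attention) as a beam grows by one edge. The field defaults and extend are assumptions for illustration, not the library's code.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class BeamState:
    """Stand-in mirroring the documented fields; defaults are assumptions."""
    positions: List[int]
    edge_attns: List[float] = field(default_factory=list)
    log_score: float = 0.0
    is_terminal: bool = False

    @property
    def current_pos(self) -> int:
        # Most recent position: the last one appended
        return self.positions[-1]

    @property
    def num_edges(self) -> int:
        # One edge per attention weight recorded so far
        return len(self.edge_attns)

    def extend(self, next_pos: int, attn: float) -> "BeamState":
        """Hypothetical helper: a new state with one more edge (self is unchanged)."""
        return BeamState(
            positions=self.positions + [next_pos],
            edge_attns=self.edge_attns + [attn],
            log_score=self.log_score + math.log(attn),
        )


start = BeamState(positions=[9])
state = start.extend(7, 0.4).extend(5, 0.3)
print(state.current_pos, state.num_edges, round(state.log_score, 4))
```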
Importing Data Classes
The data classes are defined in the reverse_attention.beam module; the tracer itself is imported from the top-level reverse_attention package:
```python
from reverse_attention import ReverseAttentionTracer
from reverse_attention.beam import (
    TraceResult,
    BeamPath,
    BeamState,
    SankeyData,
    SankeyNode,
    SankeyLink,
)
```
Or import just what you need: