
Data Classes

This page documents the data classes used to represent trace results.

TraceResult

Container for the complete trace result.

TraceResult dataclass

Result of a reverse attention trace.

Fields

Field        Type            Description
seq_len      int             Sequence length
target_pos   int             Target position (resolved to a non-negative index)
layer        int             Layer index (resolved to a non-negative index)
top_beam     int             Number of beams used
top_k        int             Top-k value used
tokens       List[str]       All tokens in the sequence
beams        List[BeamPath]  Completed beam paths
sankey       SankeyData      Sankey visualization data
paths_text   List[str]       Human-readable path descriptions
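Both target_pos and layer are stored as non-negative indices, even when the trace was requested with a negative (Python-style) index. The sketch below shows what that resolution likely looks like. It assumes the package follows Python's usual negative-indexing convention, and `resolve_index` is a hypothetical helper, not part of the package.

```python
def resolve_index(index: int, length: int) -> int:
    """Map a possibly negative index into the range [0, length).

    Hypothetical helper: assumes Python-style negative indexing,
    where -1 means the last element and -length the first.
    """
    resolved = index + length if index < 0 else index
    if not 0 <= resolved < length:
        raise IndexError(f"index {index} out of range for length {length}")
    return resolved


# e.g. target_pos=-1 on a 10-token sequence refers to the last token
print(resolve_index(-1, 10))
print(resolve_index(3, 10))
```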

Example Usage

# Assumes `tracer` is an existing ReverseAttentionTracer instance
result = tracer.trace_text("Hello world")

# Access metadata
print(f"Sequence length: {result.seq_len}")
print(f"Target position: {result.target_pos}")
print(f"Layer: {result.layer}")

# Access tokens
print(f"Tokens: {result.tokens}")

# Access beams
for i, (beam, text) in enumerate(zip(result.beams, result.paths_text)):
    print(f"Beam {i} (normalized score {beam.score_norm:.3f}): {text}")

# Access Sankey data for custom visualization
nodes = result.sankey.nodes
links = result.sankey.links

BeamPath

A single completed attention path with token information.

BeamPath dataclass

A completed beam path with token information.

Fields

Field        Type         Description
positions    List[int]    Token positions in sequence (target → start)
tokens       List[str]    Token strings at each position
token_ids    List[int]    Token IDs at each position
edge_attns   List[float]  Attention weights along edges
score_raw    float        Raw cumulative log score
score_norm   float        Length-normalized score

Example Usage

beam = result.beams[0]  # Top beam

# Path structure
print(f"Positions: {beam.positions}")  # [9, 7, 5, 3, 0]
print(f"Tokens: {beam.tokens}")        # ['dog', 'lazy', 'over', ...]
print(f"Token IDs: {beam.token_ids}")

# Attention weights
print(f"Edge attentions: {beam.edge_attns}")
# One fewer edge than positions (edges connect positions)

# Scores
print(f"Raw score (log prob): {beam.score_raw}")
print(f"Normalized score: {beam.score_norm}")

Understanding the Structure

Positions:    [9]   →   [7]   →   [5]   →   [3]   →   [0]
Tokens:      "dog"  → "lazy"  → "over"  →   ...   →  "The"
Edge attns:        0.4       0.3      0.25       0.2

The path starts at target_pos and traces backward. edge_attns[i] is the attention from positions[i] to positions[i+1].
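The sketch below shows how positions, edge_attns and the two scores relate. It makes two assumptions:

- score_raw is the sum of log(edge_attns), which is consistent with BeamState.log_score being a cumulative log(attention).
- score_norm divides score_raw by the number of edges.

The library's exact normalization may differ. `BeamPathSketch` is a hypothetical stand-in, not the real BeamPath class.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class BeamPathSketch:
    """Hypothetical stand-in mirroring BeamPath's positions and edge_attns."""
    positions: List[int]
    edge_attns: List[float]

    def edges(self) -> List[Tuple[int, int, float]]:
        # edge_attns[i] is the attention from positions[i] to positions[i + 1]
        return list(zip(self.positions, self.positions[1:], self.edge_attns))

    @property
    def score_raw(self) -> float:
        # Assumed: cumulative log of the attention weights along the path
        return sum(math.log(a) for a in self.edge_attns)

    @property
    def score_norm(self) -> float:
        # Assumed: length normalization by the number of edges
        if not self.edge_attns:
            return 0.0
        return self.score_raw / len(self.edge_attns)


beam = BeamPathSketch(positions=[9, 7, 5, 3, 0], edge_attns=[0.4, 0.3, 0.25, 0.2])
assert len(beam.edge_attns) == len(beam.positions) - 1  # one fewer edge than positions
for src, dst, attn in beam.edges():
    print(f"{src} → {dst}: {attn}")
print(beam.score_raw, beam.score_norm)
```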


SankeyData

Data structure for Sankey diagram visualization.

SankeyData dataclass

Complete Sankey diagram data.

Fields

Field  Type              Description
nodes  List[SankeyNode]  Diagram nodes
links  List[SankeyLink]  Diagram links

Example Usage

sankey = result.sankey

# List all nodes
for node in sankey.nodes:
    print(f"Node {node.id}: '{node.name}' at position {node.position}")

# List all links
for link in sankey.links:
    print(f"Link: {link.source} → {link.target}, weight={link.value:.4f}")
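Many plotting libraries identify Sankey nodes by integer index rather than by string ID. Plotly's `go.Sankey` is one example. The sketch below converts SankeyData into parallel lists and relies only on the fields documented above. The `SimpleNamespace` objects are stand-ins for real nodes and links, and `to_index_lists` is a hypothetical helper.

```python
from types import SimpleNamespace


def to_index_lists(sankey):
    """Convert string node IDs in links into integer indices into sankey.nodes."""
    index_of = {node.id: i for i, node in enumerate(sankey.nodes)}
    return {
        "labels": [node.name for node in sankey.nodes],
        "source": [index_of[link.source] for link in sankey.links],
        "target": [index_of[link.target] for link in sankey.links],
        "value": [link.value for link in sankey.links],
    }


# Stand-in data using the documented field names (not real trace output)
sankey = SimpleNamespace(
    nodes=[
        SimpleNamespace(id="pos_9", name="dog", position=9),
        SimpleNamespace(id="pos_7", name="lazy", position=7),
        SimpleNamespace(id="pos_5", name="over", position=5),
    ],
    links=[
        SimpleNamespace(source="pos_9", target="pos_7", value=0.4),
        SimpleNamespace(source="pos_7", target="pos_5", value=0.3),
    ],
)
print(to_index_lists(sankey))
```

The resulting lists can be passed to Plotly as `go.Sankey(node=dict(label=...), link=dict(source=..., target=..., value=...))`.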

SankeyNode

A node in the Sankey diagram.

SankeyNode dataclass

Fields

Field     Type  Description
id        str   Unique identifier (e.g., "pos_5")
name      str   Display name (token text)
position  int   Position in sequence
layer     int   Layer index (default: 0)
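Node IDs follow the "pos_<position>" pattern shown above, so each sequence position appears at most once. The sketch below derives such nodes from beam paths. It assumes that ID scheme and that `name` is the token at the node's position. `SankeyNodeSketch` and `build_nodes` are hypothetical names, not the package's implementation.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class SankeyNodeSketch:
    """Hypothetical stand-in with the same fields as SankeyNode."""
    id: str
    name: str
    position: int
    layer: int = 0


def build_nodes(paths: List[List[int]], tokens: List[str]) -> List[SankeyNodeSketch]:
    """Create one node per distinct position visited by any beam path."""
    nodes: Dict[int, SankeyNodeSketch] = {}
    for positions in paths:
        for pos in positions:
            if pos not in nodes:
                # Assumed ID scheme: "pos_<position>"
                nodes[pos] = SankeyNodeSketch(id=f"pos_{pos}", name=tokens[pos], position=pos)
    # Order nodes by sequence position for a stable layout
    return sorted(nodes.values(), key=lambda n: n.position)


tokens = ["The", "cat", "sat", "on", "the", "mat"]
nodes = build_nodes([[5, 2, 0], [5, 3, 0]], tokens)
for node in nodes:
    print(node.id, node.name)
```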

SankeyLink

A link in the Sankey diagram.

Fields

Field         Type       Description
source        str        Source node ID
target        str        Target node ID
value         float      Link weight (aggregated attention)
beam_indices  List[int]  Which beams use this link

Link weights are computed by aggregating attention across beams with softmax weighting:

  1. Compute beam weights from the beams' normalized scores: weights = softmax(score_norm)
  2. For each link, sum over the beams that traverse it: value = Σ(beam_weight × edge attention)

This ensures that higher-scoring beams contribute more to link widths.
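A minimal sketch of the two aggregation steps follows. It assumes each beam is given as a `(positions, edge_attns, score_norm)` tuple and that node IDs follow the "pos_<position>" pattern. `aggregate_links` is a hypothetical helper, not the package's implementation.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def softmax(xs: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def aggregate_links(beams):
    """beams: list of (positions, edge_attns, score_norm) tuples (assumed shape)."""
    # Step 1: beam weights from normalized scores
    weights = softmax([score for _, _, score in beams])
    values: Dict[Tuple[str, str], float] = defaultdict(float)
    beam_indices: Dict[Tuple[str, str], List[int]] = defaultdict(list)
    # Step 2: each beam adds weight * attention to every link it traverses
    for b, ((positions, edge_attns, _), w) in enumerate(zip(beams, weights)):
        for src, dst, attn in zip(positions, positions[1:], edge_attns):
            key = (f"pos_{src}", f"pos_{dst}")
            values[key] += w * attn
            beam_indices[key].append(b)
    return dict(values), dict(beam_indices)


beams = [
    ([5, 2, 0], [0.5, 0.4], -0.8),
    ([5, 2, 1], [0.5, 0.7], -0.6),
]
values, beam_indices = aggregate_links(beams)
for (src, dst), value in values.items():
    print(f"{src} → {dst}: {value:.3f} (beams {beam_indices[(src, dst)]})")
```

Both beams share the first link, so its value combines both contributions. The second beam scores higher, so it receives the larger softmax weight.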


BeamState

Internal state during beam search (typically not accessed directly).

BeamState dataclass

State of a single beam during search.

Fields

Field        Type         Description
positions    List[int]    Positions traversed so far
edge_attns   List[float]  Attention weights along edges
log_score    float        Cumulative log(attention)
is_terminal  bool         Whether this beam has terminated

Properties

Property     Type  Description
current_pos  int   Most recent position in the beam
num_edges    int   Number of edges traversed
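The sketch below shows how the two properties likely follow from the fields. It assumes current_pos is the last entry of positions and num_edges is the length of edge_attns. `BeamStateSketch` and its `extend` method are hypothetical and only illustrate one search step. They do not reproduce the package's search loop, and termination is left to the caller.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class BeamStateSketch:
    """Hypothetical stand-in mirroring BeamState's documented fields."""
    positions: List[int]
    edge_attns: List[float] = field(default_factory=list)
    log_score: float = 0.0
    is_terminal: bool = False

    @property
    def current_pos(self) -> int:
        # Assumed: the most recent position is the last one appended
        return self.positions[-1]

    @property
    def num_edges(self) -> int:
        # Assumed: one attention weight is recorded per edge traversed
        return len(self.edge_attns)

    def extend(self, next_pos: int, attn: float) -> "BeamStateSketch":
        """Hypothetical helper: return a new state one step further back."""
        return BeamStateSketch(
            positions=self.positions + [next_pos],
            edge_attns=self.edge_attns + [attn],
            log_score=self.log_score + math.log(attn),
        )


state = BeamStateSketch(positions=[9]).extend(7, 0.4).extend(5, 0.3)
print(state.current_pos, state.num_edges, state.log_score)
```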

Importing Data Classes

All data classes are defined in the reverse_attention.beam module. The tracer itself is imported from the top-level package:

from reverse_attention import ReverseAttentionTracer
from reverse_attention.beam import (
    TraceResult,
    BeamPath,
    BeamState,
    SankeyData,
    SankeyNode,
    SankeyLink,
)

Or import just what you need:

from reverse_attention.beam import BeamPath, TraceResult