13,935 changes: 11,867 additions & 2,068 deletions python/Demonstration.ipynb

Large diffs are not rendered by default.

42 changes: 39 additions & 3 deletions python/circuitsvis/attention.py
@@ -3,6 +3,7 @@

import numpy as np
import torch

from circuitsvis.utils.render import RenderedHTML, render


@@ -15,6 +16,8 @@ def attention_heads(
    negative_color: Optional[str] = None,
    positive_color: Optional[str] = None,
    mask_upper_tri: Optional[bool] = None,
    show_tokens: Optional[bool] = None,
    match_color: Optional[bool] = None,
) -> RenderedHTML:
"""Attention Heads

@@ -24,8 +27,8 @@ def attention_heads(
    is then shown in full size.

    Args:
-        attention: Attention head activations of the shape [dest_tokens x
-            src_tokens]
+        attention: Attention head activations of the shape [heads x dest_tokens x src_tokens]
+            or [dest_tokens x src_tokens] (will be expanded to a single head)
        tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length
            as the list of values.
        max_value: Maximum value. Used to determine how dark the token color is
@@ -41,19 +44,52 @@
        mask_upper_tri: Whether or not to mask the upper triangular portion of
            the attention patterns. Should be true for causal attention, false
            for bidirectional attention.
        show_tokens: Whether to show an interactive token visualization, where
            hovering over a token shows the attention strength to other tokens.
        match_color: Whether to match colors between the attention patterns,
            the token visualization, and the head headers for visual
            consistency.

    Returns:
        Html: Attention pattern visualization
    """

    # Convert attention to numpy array if it's not already
    attention_np: np.ndarray
    if isinstance(attention, torch.Tensor):
        attention_np = attention.detach().cpu().numpy()
    elif isinstance(attention, np.ndarray):
        attention_np = attention
    else:
        attention_np = np.array(attention)

    # Ensure attention is 3D (num_heads, dest_len, src_len)
    if attention_np.ndim == 2:
        attention_np = attention_np[np.newaxis, :, :]
    elif attention_np.ndim != 3:
        raise ValueError(
            f"Attention tensor must be 2D or 3D, got {attention_np.ndim}D tensor."
        )

    num_heads, dest_len, src_len = attention_np.shape

    # Validate token count matches attention dimensions
    if len(tokens) != dest_len or len(tokens) != src_len:
        raise ValueError(
            f"Token count ({len(tokens)}) doesn't match attention dimensions "
            f"(dest: {dest_len}, src: {src_len}). For causal attention, these should all be equal."
        )

    kwargs = {
-        "attention": attention,
+        "attention": attention_np,
        "attentionHeadNames": attention_head_names,
        "maxValue": max_value,
        "minValue": min_value,
        "negativeColor": negative_color,
        "positiveColor": positive_color,
        "tokens": tokens,
        "maskUpperTri": mask_upper_tri,
        "showTokens": show_tokens,
        "matchColor": match_color,
    }

    return render(
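A minimal usage sketch of the updated `attention_heads` API, based on the signature in this diff (the tensor values and token strings are illustrative, not from the PR):

```python
import torch

from circuitsvis.attention import attention_heads

tokens = ["The", "cat", "sat"]

# 4 heads attending over 3 tokens: [heads x dest_tokens x src_tokens]
attention = torch.rand(4, 3, 3)

# show_tokens and match_color are the new flags added in this PR
html = attention_heads(
    attention=attention,
    tokens=tokens,
    show_tokens=True,
    match_color=True,
)

# A 2D [dest_tokens x src_tokens] pattern is also accepted and is expanded
# to a single head; a token list of the wrong length raises ValueError
html_single = attention_heads(attention=torch.rand(3, 3), tokens=tokens)
```

The returned `RenderedHTML` displays inline in a notebook (as in the updated Demonstration.ipynb).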
103 changes: 98 additions & 5 deletions react/src/attention/AttentionHeads.tsx
@@ -1,6 +1,8 @@
-import React from "react";
+import React, { useMemo, useState } from "react";
import { Col, Container, Row } from "react-grid-system";
import { AttentionPattern } from "./AttentionPattern";
import { colorAttentionTensors } from "./AttentionPatterns";
import { Tokens, TokensView } from "./components/AttentionTokens";
import { useHoverLock, UseHoverLockState } from "./components/useHoverLock";

@@ -35,6 +37,7 @@ export function AttentionHeadsSelector({
  onMouseLeave,
  positiveColor,
  maskUpperTri,
  matchColor,
  tokens
}: AttentionHeadsProps & {
  attentionHeadNames: string[];
@@ -88,8 +91,12 @@ export function AttentionHeadsSelector({
            showAxisLabels={false}
            maxValue={maxValue}
            minValue={minValue}
-            negativeColor={negativeColor}
-            positiveColor={positiveColor}
+            negativeColor={matchColor ? undefined : negativeColor}
+            positiveColor={
+              matchColor
+                ? attentionHeadColor(idx, attention.length)
+                : positiveColor
+            }
            maskUpperTri={maskUpperTri}
          />
        </div>
@@ -115,14 +122,41 @@ export function AttentionHeads({
  negativeColor,
  positiveColor,
  maskUpperTri = true,
  showTokens = true,
  matchColor = false,
  tokens
}: AttentionHeadsProps) {
  // Attention head focussed state
  const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0);

  // State for the token view type
  const [tokensView, setTokensView] = useState<TokensView>(
    TokensView.DESTINATION_TO_SOURCE
  );

  // State for which token is focussed
  const {
    focused: focussedToken,
    onClick: onClickToken,
    onMouseEnter: onMouseEnterToken,
    onMouseLeave: onMouseLeaveToken
  } = useHoverLock();

  const headNames =
    attentionHeadNames || attention.map((_, idx) => `Head ${idx}`);

  // Color the attention values (by head) for interactive tokens
  const coloredAttention = useMemo(() => {
    if (!showTokens || !attention || attention.length === 0) return null;
    const numHeads = attention.length;
    const numDestTokens = attention[0]?.length || 0;
    const numSrcTokens = attention[0]?.[0]?.length || 0;

    if (numDestTokens === 0 || numSrcTokens === 0 || numHeads === 0)
      return null;
    return colorAttentionTensors(attention);
  }, [attention, showTokens]);

  return (
    <Container>
      <h3 style={{ marginBottom: 15 }}>
@@ -141,6 +175,7 @@
        onMouseLeave={onMouseLeave}
        positiveColor={positiveColor}
        maskUpperTri={maskUpperTri}
        matchColor={matchColor}
        tokens={tokens}
      />

@@ -166,8 +201,12 @@
            attention={attention[focused]}
            maxValue={maxValue}
            minValue={minValue}
-            negativeColor={negativeColor}
-            positiveColor={positiveColor}
+            negativeColor={matchColor ? undefined : negativeColor}
+            positiveColor={
+              matchColor
+                ? attentionHeadColor(focused, attention.length)
+                : positiveColor
+            }
            zoomed={true}
            maskUpperTri={maskUpperTri}
            tokens={tokens}
@@ -176,6 +215,42 @@
        </Col>
      </Row>

      {showTokens && coloredAttention && (
        <Row>
          <Col xs={12}>
            <div className="tokens" style={{ marginTop: 20 }}>
              <h4 style={{ display: "inline-block", marginRight: 15 }}>
                Tokens
                <span style={{ fontWeight: "normal" }}> (click to focus)</span>
              </h4>
              <select
                value={tokensView}
                onChange={(e) => setTokensView(e.target.value as TokensView)}
              >
                <option value={TokensView.DESTINATION_TO_SOURCE}>
                  Source ← Destination
                </option>
                <option value={TokensView.SOURCE_TO_DESTINATION}>
                  Destination ← Source
                </option>
              </select>
              <div style={{ marginTop: 10 }}>
                <Tokens
                  coloredAttention={coloredAttention}
                  focusedHead={focused}
                  focusedToken={focussedToken}
                  onClickToken={onClickToken}
                  onMouseEnterToken={onMouseEnterToken}
                  onMouseLeaveToken={onMouseLeaveToken}
                  tokens={tokens}
                  tokensView={tokensView}
                />
              </div>
            </div>
          </Col>
        </Row>
      )}

      <Row></Row>
    </Container>
  );
@@ -262,6 +337,24 @@ export interface AttentionHeadsProps {
   */
  showAxisLabels?: boolean;

  /**
   * Show interactive tokens
   *
   * Whether to show the interactive token visualization, where hovering over
   * a token shows the attention strength to the other tokens.
   *
   * @default true
   */
  showTokens?: boolean;

  /**
   * Match colors
   *
   * Whether to match colors between the attention patterns, the token
   * visualization, and the head headers for visual consistency.
   *
   * @default false
   */
  matchColor?: boolean;

  /**
   * List of tokens
   *
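The two new props map to snake_case keyword arguments on the Python side (`show_tokens` → `showTokens`, `match_color` → `matchColor`), as wired up in the `kwargs` dict earlier in this diff. A minimal sketch of opting out of both from Python (the array values and tokens are illustrative):

```python
import numpy as np

from circuitsvis.attention import attention_heads

# Two heads over four tokens, as a plain numpy array (also accepted)
attention = np.random.rand(2, 4, 4)
tokens = ["A", "person", "sat", "down"]

# Disable the interactive token view and per-head color matching
html = attention_heads(
    attention=attention,
    tokens=tokens,
    show_tokens=False,
    match_color=False,
)
```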