
SSM Interp Support #1481

Merged
jlarson4 merged 15 commits into dev from feature/improved-SSM-interp-surface
Jul 2, 2026
jlarson4 (Collaborator) commented Jul 2, 2026

Description

Adds an interpretability surface for SSM/recurrent-mixer models to TransformerBridge, and extends verify_models to cover the SSM architectures.

Until now the bridge could run SSM models, but exposed almost none of their internals to interp tooling: no effective-attention view, no recurrent-state read, no way to intervene on the scan, and no shared vocabulary across the different SSM families. This PR closes that gap for every active SSM/hybrid architecture and wires the same quantities to the same hook names regardless of family, so ActivationCache can dispatch to them generically.

Motivation: mechanistic-interp work on state-space and linear-attention models (Mamba, the "Hidden Attention of Mamba" line of work, gated-delta-net) needs the same read/patch affordances we already have for attention. This makes SSM layers first-class alongside attention in the bridge.

What's added

Read surface (all validated against independent references; see Notable design / correctness notes):

  • compute_effective_attention – materializes the effective-attention matrix M = L ⊙ (C Bᵀ) for Mamba-1, Mamba-2, and gated-delta-net (GDN). The GDN version is documented as an explicit gated-linear-attention heuristic (it drops the delta-rule key-removal), not a faithful decomposition.
  • compute_ssm_state – reconstructs the recurrent state Sₜ from cached hooks for Mamba-1, Mamba-2, and GDN (GDN replays the full delta rule, key-removal included, and L2-normalizes Q/K like the fused kernel, so it is faithful).
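For readers unfamiliar with the M = L ⊙ (C Bᵀ) form, here is a minimal NumPy sketch under a simplifying assumption: one channel with a single scalar decay `a[t]` per step (roughly the Mamba-2 shape; Mamba-1 uses per-channel, per-state decays). The function names and tensor shapes are illustrative only and are not the bridge's API. The sketch builds M and compares it with an independent fp64 step recurrence, in the same non-circular spirit as the checks in the correctness notes.

```python
import numpy as np

def toy_effective_attention(a, B, C):
    """Toy M = L * (C @ B.T) for a scalar-decay SSM (assumed simplification).

    a: (T,) per-step decay in (0, 1]; B, C: (T, N) input/output projections.
    Returns a (T, T) lower-triangular matrix with y = M @ x.
    """
    log_cum = np.cumsum(np.log(a))  # log of prod_{k<=t} a[k]
    # L[t, s] = prod_{k=s+1..t} a[k] for s <= t, zero above the diagonal.
    L = np.tril(np.exp(log_cum[:, None] - log_cum[None, :]))
    return L * (C @ B.T)

def toy_step_recurrence(a, B, C, x):
    """Independent reference: h_t = a_t * h_{t-1} + B_t * x_t, y_t = C_t . h_t."""
    T, N = B.shape
    h = np.zeros(N)              # float64 by default
    states = np.empty((T, N))
    ys = np.empty(T)
    for t in range(T):
        h = a[t] * h + B[t] * x[t]
        states[t] = h
        ys[t] = C[t] @ h
    return ys, states

rng = np.random.default_rng(0)
T, N = 6, 4
a = rng.uniform(0.5, 1.0, size=T)
B = rng.standard_normal((T, N))
C = rng.standard_normal((T, N))
x = rng.standard_normal(T)

M = toy_effective_attention(a, B, C)
ys, states = toy_step_recurrence(a, B, C, x)
print("max |M @ x - y_recurrence| =", np.abs(M @ x - ys).max())
```

The second function also returns the per-step state trajectory, which is the quantity a state reconstruction has to match in this toy setting.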

Intervention surface:

  • Opt-in eager_scan path that swaps HF's fused kernel for a readable Python scan so that hook_ssm_write (per-step write) and hook_ssm_state (post-scan state trajectory) fire and can be patched for Mamba-1, Mamba-2 (+ NemotronH / GraniteMoeHybrid hybrids), and GDN. Patching hook_ssm_write propagates through the recurrence, whereas patching hook_ssm_state affects only the same-position readout. The default is off, so the standard forward is bit-for-bit unchanged.
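The difference between the two hook points can be illustrated with a toy scan. This is a sketch under assumptions, not the bridge's eager_scan: the function `toy_eager_scan`, its callback signatures, and the scalar-decay recurrence are all hypothetical. Only the hook names and their placement (per-step write versus post-scan state trajectory) come from the description above.

```python
import numpy as np

def toy_eager_scan(a, B, C, x, write_hook=None, state_hook=None):
    """Readable scan with two patch points (hypothetical signatures).

    write_hook(write, t) -> write : patches the per-step write before it
                                    enters the recurrence.
    state_hook(states) -> states  : patches the state trajectory after the
                                    scan, just before the readout.
    """
    T, N = B.shape
    h = np.zeros(N)
    states = np.empty((T, N))
    for t in range(T):
        write = B[t] * x[t]
        if write_hook is not None:
            write = write_hook(write, t)
        h = a[t] * h + write          # patched writes are carried forward
        states[t] = h
    if state_hook is not None:
        states = state_hook(states)   # no further recurrence after this
    return np.einsum("tn,tn->t", C, states)

def zero_write_at_2(write, t):
    return np.zeros_like(write) if t == 2 else write

def zero_state_at_2(states):
    patched = states.copy()
    patched[2] = 0.0
    return patched

rng = np.random.default_rng(1)
T, N = 6, 4
a = rng.uniform(0.5, 1.0, size=T)
B = rng.standard_normal((T, N))
C = rng.standard_normal((T, N))
x = rng.standard_normal(T)

y_base = toy_eager_scan(a, B, C, x)
y_write = toy_eager_scan(a, B, C, x, write_hook=zero_write_at_2)
y_state = toy_eager_scan(a, B, C, x, state_hook=zero_state_at_2)
print("positions changed by write patch:", np.flatnonzero(~np.isclose(y_write, y_base)))
print("positions changed by state patch:", np.flatnonzero(~np.isclose(y_state, y_base)))
```

In this toy, zeroing the write at step 2 changes the output from step 2 onward, while zeroing the stored state at step 2 changes only the step-2 readout.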

Family-agnostic plumbing:

  • SSMMixerProtocol + find_ssm_mixer (locates the mixer regardless of slot: .mixer / .linear_attn), and SSMStateHookMixin – the single definition point for the canonical hook_ssm_* vocabulary across all three families.
  • ActivationCache.compute_ssm_effective_attention, compute_ssm_state, and ssm_layers dispatch across homogeneous and hybrid models (stacked tensor for pure models, per-layer dict for hybrids).
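As a rough illustration of slot-agnostic lookup and of the stacked-versus-dict return shape, the sketch below uses only the standard library. Everything named `Toy*` or `toy_*` is hypothetical, as is the single-method protocol; the real SSMMixerProtocol members are not listed in this description. Only the slot names `.mixer` and `.linear_attn` come from the text above.

```python
from types import SimpleNamespace
from typing import Any, Optional, Protocol, runtime_checkable

# Slot names as listed in the PR description.
SSM_SLOTS = ("mixer", "linear_attn")

@runtime_checkable
class ToyMixerProtocol(Protocol):
    """Hypothetical stand-in protocol: anything that can report its state."""
    def compute_ssm_state(self) -> Any: ...

def toy_find_mixer(block) -> Optional[ToyMixerProtocol]:
    """Return the first slot on `block` that satisfies the protocol, if any."""
    for slot in SSM_SLOTS:
        candidate = getattr(block, slot, None)
        if isinstance(candidate, ToyMixerProtocol):
            return candidate
    return None

def toy_collect_states(blocks):
    """List for pure-SSM stacks (stand-in for a stacked tensor), dict for hybrids."""
    per_layer = {}
    for index, block in enumerate(blocks):
        mixer = toy_find_mixer(block)
        if mixer is not None:
            per_layer[index] = mixer.compute_ssm_state()
    if len(per_layer) == len(blocks):
        return [per_layer[i] for i in range(len(blocks))]
    return per_layer

class ToyMixer:
    def __init__(self, value):
        self.value = value
    def compute_ssm_state(self):
        return [self.value]

pure_blocks = [SimpleNamespace(mixer=ToyMixer(0.0)),
               SimpleNamespace(linear_attn=ToyMixer(1.0))]
hybrid_blocks = [SimpleNamespace(mixer=ToyMixer(0.0)),
                 SimpleNamespace(self_attn=object()),   # attention-only layer
                 SimpleNamespace(linear_attn=ToyMixer(2.0))]

pure_result = toy_collect_states(pure_blocks)
hybrid_result = toy_collect_states(hybrid_blocks)
print(pure_result)
print(hybrid_result)
```

Keying the hybrid result by layer index keeps attention-only layers from shifting the positions of SSM layers.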

Architectures covered:

  • Mamba-1
  • Mamba-2
  • NemotronH (hybrid)
  • GraniteMoeHybrid (hybrid)
  • Gated-delta-net (GDN; Qwen3.5 / Qwen3-Next)

verify_models / benchmarks:

  • SSM adapters now run the full phase set [1,2,3,4] instead of being skipped. P1 forward parity against raw HF is exact because the mixer delegates to HF; P4 covers generation.
  • Component-output benchmark skips SSM mixer internals (they take channel-first / d_inner / dt_rank / state shapes, not the [batch, seq, d_model] residual the isolated harness feeds) while still testing the mixer node end-to-end.
  • Weight-centering benchmarks now return SKIPPED (not a false failure) for attention-less / passthrough-mixer blocks.
  • Hook-registration treats the gated-q_proj forward-time NotImplementedError (Qwen3.5/Next) as a skip, matching the setter-time applicability gate.
  • Real checkpoints across all families were run through verify_models; results recorded in the registry.
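The skip-versus-fail distinction in the benchmark bullets can be sketched as a tri-state result. The names below (`ToyStatus`, `toy_run_check`, the `applicable` flag) are hypothetical and do not mirror the verify_models code; the sketch only shows the pattern of reporting an inapplicable check, or one that raises NotImplementedError at forward time, as skipped rather than failed.

```python
from enum import Enum

class ToyStatus(Enum):
    PASSED = "passed"
    FAILED = "failed"
    SKIPPED = "skipped"

def toy_run_check(check, applicable=True):
    """Run `check()` and map the outcome to a tri-state status.

    applicable=False models a setter-time applicability gate (for example,
    a weight-centering check on an attention-less block).
    """
    if not applicable:
        return ToyStatus.SKIPPED
    try:
        ok = check()
    except NotImplementedError:
        # Forward-time "not supported here" is treated like the gate above.
        return ToyStatus.SKIPPED
    return ToyStatus.PASSED if ok else ToyStatus.FAILED

def unsupported_hook():
    raise NotImplementedError("hook not available for this projection")

print(toy_run_check(lambda: True))
print(toy_run_check(unsupported_hook))
print(toy_run_check(lambda: True, applicable=False))
```

Other exception types still propagate in this sketch, so genuine errors are not silently converted into skips.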

Docs:

  • New docs/source/content/ssm_interpretability.md (canonical hook vocabulary + eager_scan opt-in example), linked in the docs index.
  • README touch-up.

Notable design / correctness notes

  • Reconstructions are validated against independent fp64 step recurrences and HF-output parity, not against the implementation itself (non-circular gates).
  • The Granite logits_equivalence divergence surfaced in P3 was traced to bf16 precision in compatibility-mode center_unembed. With the toggle off the divergence is 0.000 in both dtypes; with it on, the divergence is 4.2e-5 in fp32 and 0.375 in bf16. The clean fp32 run passes P3=100.
  • Eager-scan device allocation is input-device-correct (regression covered on cpu/cuda/mps).

Dependencies

None new – uses the SSM/recurrent modules already provided by transformers. Some tests and verify_models runs require the corresponding HF checkpoints and are availability-gated (skip cleanly when absent).

Type of change

  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

jlarson4 merged commit d4560ca into dev on Jul 2, 2026
25 checks passed