SSM Interp Support #1481
Merged
Description
Adds an interpretability surface for SSM/recurrent-mixer models to `TransformerBridge`, and extends `verify_models` to cover the SSM architectures.

Until now the bridge could run SSM models, but exposed almost none of their internals to interp tooling: no effective-attention view, no recurrent-state read, no way to intervene on the scan, and no shared vocabulary across the different SSM families. This PR closes that gap for every active SSM/hybrid architecture and wires the same quantities to the same hook names regardless of family, so `ActivationCache` can dispatch to them generically.

Motivation: mechanistic-interp work on state-space and linear-attention models (Mamba, the "Hidden Attention of Mamba" line of work, gated-delta-net) needs the same read/patch affordances we already have for attention. This makes SSM layers first-class alongside attention in the bridge.
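As a reference point for the read surface listed under "What's added", here is a minimal toy sketch of what an effective-attention view means for a single-channel SSM with a scalar per-step decay. It is not the bridge implementation: `scan` and `effective_attention` are hypothetical helpers written for illustration, and the recurrence `S_t = a_t * S_{t-1} + B_t * x_t`, `y_t = C_t · S_t` is an assumed simplification (real Mamba layers are multi-channel and discretized). The sketch materializes `M = L ⊙ (C Bᵀ)`, with `L[t][s]` the product of decays between positions `s` and `t`, and checks that applying `M` to `x` reproduces the recurrent readout.

```python
import random


def scan(a, B, C, x):
    """Recurrent form: S_t = a_t * S_{t-1} + B_t * x_t,  y_t = C_t . S_t."""
    S = [0.0] * len(B[0])  # state starts at zero
    states, ys = [], []
    for a_t, B_t, C_t, x_t in zip(a, B, C, x):
        S = [a_t * s + b * x_t for s, b in zip(S, B_t)]  # decay, then write
        states.append(S)
        ys.append(sum(c * s for c, s in zip(C_t, S)))    # readout
    return states, ys


def effective_attention(a, B, C):
    """M[t][s] = L[t][s] * (C_t . B_s), where L[t][s] = a_{s+1} * ... * a_t."""
    T = len(a)
    M = [[0.0] * T for _ in range(T)]  # causal: entries with s > t stay zero
    for t in range(T):
        decay = 1.0                    # L[t][t] is the empty product
        for s in range(t, -1, -1):
            M[t][s] = decay * sum(c * b for c, b in zip(C[t], B[s]))
            decay *= a[s]              # now a_s * ... * a_t, i.e. L[t][s-1]
    return M


random.seed(0)
T, N = 6, 4  # sequence length and state size (toy values)
a = [random.uniform(0.5, 1.0) for _ in range(T)]
B = [[random.gauss(0.0, 1.0) for _ in range(N)] for _ in range(T)]
C = [[random.gauss(0.0, 1.0) for _ in range(N)] for _ in range(T)]
x = [random.gauss(0.0, 1.0) for _ in range(T)]

states, y_scan = scan(a, B, C, x)
M = effective_attention(a, B, C)
y_attn = [sum(M[t][s] * x[s] for s in range(T)) for t in range(T)]

# The attention-style view and the recurrence agree up to float rounding.
max_err = max(abs(p - q) for p, q in zip(y_scan, y_attn))
print(f"max |y_scan - y_attn| = {max_err:.2e}")
```

In this toy setting, row `t` of `M` shows how much each earlier position `s` contributes to the output at `t`, and `states` is the per-step trajectory that a recurrent-state read would return.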
What's added
Read surface (all gated by independent references, see Verification):
- `compute_effective_attention` – materializes the effective-attention matrix `M = L ⊙ (C Bᵀ)` for Mamba-1, Mamba-2, and gated-delta-net (GDN). The GDN version is documented as an explicit gated-linear-attention heuristic (it drops the delta-rule key-removal), not a faithful decomposition.
- `compute_ssm_state` – reconstructs the recurrent state `Sₜ` from cached hooks for Mamba-1, Mamba-2, and GDN. GDN replays the full delta rule, key-removal included, and L2-normalizes Q/K like the fused kernel, so it is faithful.

Intervention surface:
- `eager_scan` path that swaps HF's fused kernel for a readable Python scan so `hook_ssm_write` (per-step write) and `hook_ssm_state` (post-scan state trajectory) fire and can be patched for Mamba-1, Mamba-2 (+ NemotronH / GraniteMoeHybrid hybrids), and GDN. Patching `hook_ssm_write` propagates through the recurrence. `hook_ssm_state` affects only the same-position readout. Default is off, so the standard forward is bit-for-bit unchanged.

Family-agnostic plumbing:
- `SSMMixerProtocol` + `find_ssm_mixer` (locates the mixer regardless of slot: `.mixer` / `.linear_attn`), and `SSMStateHookMixin` – the single definition point for the canonical `hook_ssm_*` vocabulary across all three families.
- `ActivationCache.compute_ssm_effective_attention`, `compute_ssm_state`, and `ssm_layers` dispatch across homogeneous and hybrid models (stacked tensor for pure models, per-layer dict for hybrids).

Architectures covered:
`verify_models` / benchmarks:
- … `[1,2,3,4]` (P1 forward-parity vs raw HF is exact since the mixer delegates to HF, P4 generation) instead of being skipped.
- … `[batch, seq, d_model]` residual the isolated harness feeds) while still testing the mixer node end-to-end.
- … `SKIPPED` (not a false failure) for attention-less / passthrough-mixer blocks.
- … `q_proj` forward-time `NotImplementedError` (Qwen3.5/Next) as a skip, matching the setter-time applicability gate.
- … `verify_models`; results recorded in the registry.

Docs:
- `docs/source/content/ssm_interpretability.md` (canonical hook vocabulary + `eager_scan` opt-in example), linked in the docs index.

Notable design / correctness notes
- The `logits_equivalence` divergence surfaced in P3 was traced to bf16 precision in compatibility-mode `center_unembed` (toggle: off = 0.000 in both dtypes, fp32 = 4.2e-5, bf16 = 0.375). The clean fp32 run passes P3=100.

Dependencies
None new – uses the SSM/recurrent modules already provided by `transformers`. Some tests and `verify_models` runs require the corresponding HF checkpoints and are availability-gated (skip cleanly when absent).

Type of change
Checklist: