From 190840e1155d05ba8db6b8f22c3c7e04f5bedc6a Mon Sep 17 00:00:00 2001 From: The gemma Authors Date: Wed, 4 Dec 2024 01:47:54 -0800 Subject: [PATCH] Option to sow intermediate activations. The CL introduces a new SowModule which is used to configure which intermediate values are to be sowed. TransformerIntermediates and BlockIntermediates are passed around to store the intermediate values. PiperOrigin-RevId: 702635175 --- gemma/modules.py | 22 ++++- gemma/sampler.py | 27 +++++++ gemma/sampler_test.py | 98 ++++++++++++++++++++++ gemma/sow_lib.py | 165 ++++++++++++++++++++++++++++++++++++++ gemma/transformer.py | 65 +++++++++++++++ gemma/transformer_test.py | 114 ++++++++++++++++++++++++++ 6 files changed, 487 insertions(+), 4 deletions(-) create mode 100644 gemma/sow_lib.py diff --git a/gemma/modules.py b/gemma/modules.py index 7a1fa67..512ae27 100644 --- a/gemma/modules.py +++ b/gemma/modules.py @@ -18,6 +18,7 @@ from flax import linen as nn from gemma import layers from gemma import positional_embeddings +from gemma import sow_lib import jax import jax.numpy as jnp @@ -52,8 +53,8 @@ def _reconstruct_rotated_cache_positions(): cache_positions = cache_positions[None, None, :] # [1, 1, cache_len] segment_pos = segment_pos[:, :, None] # [B, seq_len, 1] - sliding_mask = (cache_positions > segment_pos - sliding_window_size) - sliding_mask *= (cache_positions < segment_pos + sliding_window_size) + sliding_mask = cache_positions > segment_pos - sliding_window_size + sliding_mask *= cache_positions < segment_pos + sliding_window_size return sliding_mask @@ -95,6 +96,7 @@ class Attention(nn.Module): query_pre_attn_scalar: float attn_logits_soft_cap: float | None = None sliding_window_size: int | None = None + sow: sow_lib.SowModule = sow_lib.SowModule() @property def use_qkv_einsum(self): @@ -127,6 +129,7 @@ def __call__( segment_pos: jax.Array, cache: LayerCache | None, attn_mask: jax.Array, + intermediates: sow_lib.BlockIntermediates | None = None, ) -> tuple[LayerCache | None, jax.Array]: seq_len = x.shape[1] @@ -190,6 +193,7 @@ def __call__( attn_mask *= sliding_mask padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK) + self.sow.maybe_sow_attn_logits_topk(padded_logits, intermediates) probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype) if self.use_gqa: # Reshape matrices to enable einsums over groups. @@ -242,9 +246,10 @@ class FeedForward(nn.Module): features: int hidden_dim: int transpose_gating_einsum: bool + sow: sow_lib.SowModule = sow_lib.SowModule() @nn.compact - def __call__(self, x): + def __call__(self, x, intermediates=None): # Some versions use an alternate parameter ordering that # transposes hidden_dim and features. if self.transpose_gating_einsum: @@ -267,6 +272,8 @@ def __call__(self, x): ff1 = jnp.dot(x, w_gating[1]) activations = gate_value * ff1 + self.sow.maybe_sow_ffw_hidden_topk(activations, intermediates) + # Down projection w_linear = self.param( 'linear', @@ -293,6 +300,7 @@ class Block(nn.Module): transpose_gating_einsum: bool attn_logits_soft_cap: float | None = None sliding_window_size: int | None = None + sow: sow_lib.SowModule = sow_lib.SowModule() def setup(self): self.pre_attention_norm = layers.RMSNorm() @@ -305,6 +313,7 @@ def setup(self): query_pre_attn_scalar=self.query_pre_attn_scalar, attn_logits_soft_cap=self.attn_logits_soft_cap, sliding_window_size=self.sliding_window_size, + sow=self.sow, ) self.post_attention_norm = None if self.use_post_attn_norm: @@ -315,6 +324,7 @@ def setup(self): features=self.embed_dim, hidden_dim=self.hidden_dim, transpose_gating_einsum=self.transpose_gating_einsum, + sow=self.sow, ) self.post_ffw_norm = None if self.use_post_ffw_norm: @@ -326,6 +336,7 @@ def __call__( segment_pos: jax.Array, cache: LayerCache | None, attn_mask: jax.Array, + intermediates: sow_lib.BlockIntermediates | None = None, ) -> tuple[LayerCache | None, jax.Array]: inputs_normalized = self.pre_attention_norm(x) cache, attn_output = self.attn( @@ -333,13 +344,16 @@ def __call__( segment_pos, cache, attn_mask, + intermediates, ) if self.post_attention_norm is not None: attn_output = self.post_attention_norm(attn_output) attn_output += x + self.sow.maybe_sow_rs_after_attention(attn_output, intermediates) outputs = self.pre_ffw_norm(attn_output) - outputs = self.mlp(outputs) + outputs = self.mlp(outputs, intermediates) if self.post_ffw_norm is not None: outputs = self.post_ffw_norm(outputs) outputs += attn_output + self.sow.maybe_sow_rs_after_ffw(outputs, intermediates) return cache, outputs diff --git a/gemma/sampler.py b/gemma/sampler.py index 1ec029c..428d812 100644 --- a/gemma/sampler.py +++ b/gemma/sampler.py @@ -23,6 +23,7 @@ import chex from gemma import modules from gemma import params as params_lib +from gemma import sow_lib from gemma import transformer as transformer_lib import jax import jax.numpy as jnp @@ -88,9 +89,13 @@ class _SamplingState: # List of tokens that are forbidden to be generated. forbidden_token_ids: Sequence[int] | None = None + # Intermediate activations from the model if requested. + intermediates: sow_lib.TransformerIntermediates | None = None + @dataclasses.dataclass class SamplerOutput: + """Output of the sampler.""" # Decoded samples from the model. text: list[str] @@ -101,6 +106,9 @@ class SamplerOutput: # Tokens corresponding to the generated samples. tokens: list[list[int]] + # Intermediate activations from the model if requested. + intermediates: sow_lib.TransformerIntermediates | None = None + class Sampler: """Sampler for gemma transformer.""" @@ -142,6 +150,7 @@ def _sample_step( sampler_state.positions[:, decoding_step], -1 ) last_token = last_token.reshape((batch_size, 1)) + step_intermediates = sow_lib.TransformerIntermediates() logits, cache = self.transformer.apply( {'params': params}, @@ -149,6 +158,7 @@ def _sample_step( step_positions, sampler_state.cache, attention_mask, + step_intermediates, ) if sampler_state.forbidden_token_ids: logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf) @@ -174,6 +184,9 @@ def _sample_step( else: logits_buffer = sampler_state.logits_buffer + if sampler_state.intermediates is not None: + sampler_state.intermediates.merge(decoding_step, step_intermediates) + done = sampler_state.done | jnp.equal( token_buffer[:, decoding_step + 1], self.vocab.eos_id() ) @@ -188,12 +201,21 @@ def _sample_step( done=done, total_sampling_steps=sampler_state.total_sampling_steps, forbidden_token_ids=sampler_state.forbidden_token_ids, + intermediates=sampler_state.intermediates, ) def init_cache(self, bsz) -> dict[str, modules.LayerCache]: """Initializes the attention cache for each layer.""" return self.transformer.config.init_cache(bsz, dtype=self.dtype) + def init_intermediates( + self, bsz, buffer_size + ) -> sow_lib.TransformerIntermediates: + """Initializes the intermediate activations that will be filled.""" + return self.transformer.config.init_intermediates( + bsz, buffer_size, self.transformer.sow + ) + def init_sample_state( self, all_input_ids: list[jax.Array], @@ -244,6 +266,7 @@ def init_sample_state( done=done, total_sampling_steps=total_sampling_steps, forbidden_token_ids=forbidden_token_ids, + intermediates=self.init_intermediates(bsz, buffer_size), ) def tokenize(self, input_string: str) -> jax.Array: @@ -358,9 +381,13 @@ def __call__( decoded_outputs = [self.vocab.DecodeIds(tokens) for tokens in out_tokens] + if sampling_state.intermediates is not None: + sampling_state.intermediates.trim(total_sampling_steps) + result = SamplerOutput( text=decoded_outputs, logits=out_logits, tokens=out_tokens, + intermediates=sampling_state.intermediates, ) return result diff --git a/gemma/sampler_test.py b/gemma/sampler_test.py index 5533c2c..cf23d0c 100644 --- a/gemma/sampler_test.py +++ b/gemma/sampler_test.py @@ -19,6 +19,7 @@ from absl.testing import absltest from gemma import modules from gemma import sampler as sampler_lib +from gemma import sow_lib from gemma import transformer as transformer_lib import jax import jax.numpy as jnp @@ -73,6 +74,17 @@ def EncodeAsIds(self, text: str) -> list[int]: # pylint: disable=invalid-name class SamplerTest(absltest.TestCase): + def assertNotAllZeros(self, array, msg=None): + if not jnp.any(array).item(): + msg = msg or f'The array {array} is all zeros.' + raise self.failureException(msg) + + def assertReasonableTensor(self, array, expected_shape=None): + self.assertIsNotNone(array) + # self.assertNotAllZeros(array) + if expected_shape is not None: + self.assertEqual(array.shape, expected_shape) + def test_samples(self): vocab = MockVocab() @@ -317,6 +329,92 @@ def test_sampler_mask_tokens_after_eos_ids(self): self.assertListEqual(list(masked_token_buffer[0]), [1, 5, 6, 2, 0, 0]) self.assertListEqual(list(masked_token_buffer[1]), [1, 3, 4, 2, 0, 0]) + def test_sampler_sows_intermediates(self): + vocab = MockVocab() + config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types + num_layers=3, + num_embed=vocab.GetPieceSize(), + embed_dim=64, + hidden_dim=128, + num_heads=2, + num_kv_heads=1, + head_dim=64, + max_cache_length=1024, + final_logit_softcap=None, + attention_types=[modules.AttentionType.GLOBAL], + use_post_attn_norm=None, + attn_logits_soft_cap=None, + use_post_ffw_norm=None, + ) + sow = sow_lib.SowModule( + embeddings=True, + rs_after_attention=False, # This should results in a None value. + rs_after_ffw=True, + attn_logits_topk=5, + ffw_hidden_topk=11, + ) + attention_mask = jnp.ones((1, 1, config.max_cache_length)) + cache = config.init_cache(1, dtype=jnp.float32) + transformer = transformer_lib.Transformer(config, sow=sow) + params = transformer.init( + jax.random.PRNGKey(0), + jnp.array([[1]]), + jnp.array([[1]]), + cache, + attention_mask, + ) + sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=vocab, + params=params['params'], + ) + raw_input = ['input string', 'hello world'] + + result = sampler(raw_input, total_generation_steps=10) + input_length = max([len(vocab.EncodeAsIds(i)) for i in raw_input]) + input_length += 1 # +1 for BOS token + output_length = max(len(tokens) for tokens in result.tokens) + length = input_length + output_length + self.assertIsNotNone(result) + intermediates = result.intermediates + self.assertIsNotNone(intermediates) + self.assertReasonableTensor( + intermediates.embeddings, + expected_shape=(2, length, config.embed_dim), + ) + # Verify that the intermediates are different for two different steps. + self.assertNotAlmostEqual( + jnp.sum(intermediates.embeddings[:, 1, ...]), + jnp.sum(intermediates.embeddings[:, 2, ...]), + ) + # Verify that the intermediates are filled in for each layer. + self.assertLen(intermediates.layers, config.num_layers) + for layer in intermediates.layers: + # For the requested intermediates we check the shape and that values are + # not all zeros, which was the initial value. + self.assertReasonableTensor( + layer.rs_after_ffw, + expected_shape=(2, length, config.embed_dim), + ) + self.assertReasonableTensor( + layer.attn_logits_topk_values, + expected_shape=(2, length, config.num_heads, sow.attn_logits_topk), + ) + self.assertReasonableTensor( + layer.attn_logits_topk_indices, + expected_shape=(2, length, config.num_heads, sow.attn_logits_topk), + ) + self.assertReasonableTensor( + layer.ffw_hidden_topk_values, + expected_shape=(2, length, sow.ffw_hidden_topk), + ) + self.assertReasonableTensor( + layer.ffw_hidden_topk_indices, + expected_shape=(2, length, sow.ffw_hidden_topk), + ) + # For the none requested intermediates we want to have None values. + self.assertIsNone(layer.rs_after_attention) + def test_compute_attention_mask(self): # Check that the input mask is correctly applied when total sampling steps # is lower than the max cache length. diff --git a/gemma/sow_lib.py b/gemma/sow_lib.py new file mode 100644 index 0000000..1a8f414 --- /dev/null +++ b/gemma/sow_lib.py @@ -0,0 +1,165 @@ +# Copyright 2024 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for sowing intermediate activations.""" + +import dataclasses +import jax + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass +class BlockIntermediates: + """Intermediate activations for a single layer (block).""" + + # Dense residual stream activations. + rs_after_attention: jax.Array | None = None + rs_after_ffw: jax.Array | None = None + + # Sparse representations for large activations. + ffw_hidden_topk_values: jax.Array | None = None + ffw_hidden_topk_indices: jax.Array | None = None + attn_logits_topk_values: jax.Array | None = None + attn_logits_topk_indices: jax.Array | None = None + + def merge(self, decoding_step, step_intermediates: 'BlockIntermediates'): + """Merges the intermediate activations from one step.""" + + # This logic is the same for all intermediates. The second dimenions is the + # length dimension, where we want to merge the intermediates from + # multiple steps. + for field in dataclasses.fields(self.__class__): + value = getattr(self, field.name) + if value is not None: + step_value = getattr(step_intermediates, field.name) + if step_value is None: + raise ValueError( + 'Intermediate step value is None for field %s' % field.name + ) + setattr( + self, + field.name, + value.at[:, decoding_step + 1].set(step_value[:, 0, ...]), + ) + + def trim(self, max_length: int): + """Trims the intermediate activations to the given length.""" + for field in dataclasses.fields(self.__class__): + value = getattr(self, field.name) + if value is not None: + setattr(self, field.name, value[:, :max_length, ...]) + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass +class TransformerIntermediates: + """Intermediate activations of one transformer step.""" + + # Embeddings of the input tokens. + embeddings: jax.Array | None = None + + # Intermediate activations of each layer. + layers: list[BlockIntermediates] = dataclasses.field(default_factory=list) + + def merge( + self, decoding_step, step_intermediates: 'TransformerIntermediates' + ): + """Merges the intermediate activations from one step.""" + if self.embeddings is not None: + assert step_intermediates.embeddings is not None + self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set( + step_intermediates.embeddings[:, 0, ...] + ) + for layer, step_layer in zip(self.layers, step_intermediates.layers): + layer.merge(decoding_step, step_layer) + + def trim(self, max_length: int): + """Trims the intermediate activations to the given length.""" + self.embeddings = self.embeddings[:, :max_length, ...] + for layer in self.layers: + layer.trim(max_length) + + +@dataclasses.dataclass(frozen=True) +class SowModule: + """Module for sowing intermediate activations.""" + + # Whether to sow embeddings. + embeddings: bool = False + + # Whether to sow activations after each attention block (in residual stream). + rs_after_attention: bool = False + + # Whether to sow activations after each FFW block (in residual stream). + # This is the same as the residual stream activations after a whole layer. + rs_after_ffw: bool = False + + # If non-zero, top-k activations in a ffw hidden layer are sowed. + # We use a sparse representation here to save memory. + ffw_hidden_topk: int = 0 + + # If non-zero, top-k attention logits are sowed. + # We use a sparse representation here to save memory. + attn_logits_topk: int = 0 + + def maybe_sow_embeddings( + self, + embeddings: jax.Array, + intermediates: TransformerIntermediates | None, + ): + """Sows embeddings if configured.""" + if intermediates is not None and self.embeddings: + intermediates.embeddings = embeddings + + def maybe_sow_rs_after_attention( + self, + activations: jax.Array, + intermediates: BlockIntermediates | None, + ): + """Sows activations after attention if configured.""" + if intermediates is not None and self.rs_after_attention: + intermediates.rs_after_attention = activations + + def maybe_sow_rs_after_ffw( + self, + activations: jax.Array, + intermediates: BlockIntermediates | None, + ): + """Sows activations after FFW if configured.""" + if intermediates is not None and self.rs_after_ffw: + intermediates.rs_after_ffw = activations + + def maybe_sow_ffw_hidden_topk( + self, + activations: jax.Array, + intermediates: BlockIntermediates | None, + ): + """Sows top-k activations in a ffw hidden layer if configured.""" + if intermediates is not None and self.ffw_hidden_topk: + ( + intermediates.ffw_hidden_topk_values, + intermediates.ffw_hidden_topk_indices, + ) = jax.lax.top_k(activations, self.ffw_hidden_topk) + + def maybe_sow_attn_logits_topk( + self, + logits: jax.Array, + intermediates: BlockIntermediates | None, + ): + """Sows top-k attention logits if configured.""" + if intermediates is not None and self.attn_logits_topk: + ( + intermediates.attn_logits_topk_values, + intermediates.attn_logits_topk_indices, + ) = jax.lax.top_k(logits, self.attn_logits_topk) diff --git a/gemma/transformer.py b/gemma/transformer.py index 630bf82..342a375 100644 --- a/gemma/transformer.py +++ b/gemma/transformer.py @@ -22,6 +22,7 @@ from gemma import layers from gemma import modules from gemma import params as params_lib +from gemma import sow_lib import jax import jax.numpy as jnp @@ -243,11 +244,65 @@ def init_cache( } return cache + def init_intermediates( + self, + batch_size: int, + buffer_size: int, + sow_config: sow_lib.SowModule, + dtype: jnp.dtype = jnp.float32, + ) -> sow_lib.TransformerIntermediates: + """Initializes the intermediate activations that will be filled.""" + intermediates = sow_lib.TransformerIntermediates() + residual_stream_dummy = jnp.zeros( + (batch_size, buffer_size, self.embed_dim), + dtype=dtype, + ) + if sow_config.embeddings: + intermediates.embeddings = residual_stream_dummy + for _ in range(self.num_layers): + block_intermediates = sow_lib.BlockIntermediates() + if sow_config.rs_after_attention: + block_intermediates.rs_after_attention = residual_stream_dummy + if sow_config.rs_after_ffw: + block_intermediates.rs_after_ffw = residual_stream_dummy + if sow_config.attn_logits_topk: + shape = ( + batch_size, + buffer_size, + self.num_heads, + sow_config.attn_logits_topk, + ) + block_intermediates.attn_logits_topk_values = jnp.zeros( + shape, + dtype=dtype, + ) + block_intermediates.attn_logits_topk_indices = jnp.zeros( + shape, + dtype=jnp.int32, + ) + if sow_config.ffw_hidden_topk: + shape = ( + batch_size, + buffer_size, + sow_config.ffw_hidden_topk, + ) + block_intermediates.ffw_hidden_topk_values = jnp.zeros( + shape, + dtype=dtype, + ) + block_intermediates.ffw_hidden_topk_indices = jnp.zeros( + shape, + dtype=jnp.int32, + ) + intermediates.layers.append(block_intermediates) + return intermediates + class Transformer(nn.Module): """Gemma transformer.""" config: TransformerConfig + sow: sow_lib.SowModule = sow_lib.SowModule() def setup(self): self.embedder = modules.Embedder( @@ -270,6 +325,7 @@ def setup(self): attn_type=attn_type, query_pre_attn_scalar=self.config.query_pre_attn_scalar(), transpose_gating_einsum=self.config.transpose_gating_einsum, + sow=self.sow, ) for i, attn_type in zip( range(self.config.num_layers), self.config.attention_types @@ -283,6 +339,7 @@ def __call__( positions: jax.Array, # [B, L] cache: Cache | None, # (sequence length L') attention_mask: jax.Array, # [B, L, L'] + intermediates: sow_lib.TransformerIntermediates | None = None, ) -> tuple[jax.Array, Cache | None]: """Transformer forward pass. @@ -294,6 +351,7 @@ def __call__( positions: input absolute positions. cache: Attention KV cache or None. attention_mask: transformer input mask. + intermediates: If not None, intermediate activations will be stored here. Returns: predicted_logits, new_cache @@ -302,17 +360,24 @@ def __call__( new_cache: updated cache if the input cache is not None, None elsewhere. """ x = self.embedder.encode(last_tokens) + self.sow.maybe_sow_embeddings(x, intermediates) for i, block in enumerate(self.blocks): layer_name = f'layer_{i}' layer_cache = cache[layer_name] if cache else None + layer_intermediates = ( + sow_lib.BlockIntermediates() if intermediates else None + ) layer_cache, x = block( x, positions, layer_cache, attention_mask, + intermediates=layer_intermediates, ) if cache is not None: cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch + if intermediates is not None: + intermediates.layers.append(layer_intermediates) x = self.final_norm(x) logits = self.embedder.decode(x) diff --git a/gemma/transformer_test.py b/gemma/transformer_test.py index 6f34a9e..3ec0f87 100644 --- a/gemma/transformer_test.py +++ b/gemma/transformer_test.py @@ -18,6 +18,7 @@ from absl.testing import absltest from absl.testing import parameterized from gemma import modules +from gemma import sow_lib from gemma import transformer as transformer_lib import jax import jax.numpy as jnp @@ -375,6 +376,119 @@ def test_query_pre_attn_scalar( ) self.assertEqual(config.query_pre_attn_scalar(), expected_scalar) + @parameterized.parameters([ + sow_lib.SowModule(embeddings=True), + sow_lib.SowModule(rs_after_attention=True), + sow_lib.SowModule(rs_after_ffw=True), + sow_lib.SowModule(attn_logits_topk=5), + sow_lib.SowModule(ffw_hidden_topk=11), + ]) + def test_sow_intermediates(self, sow_config): + batch_size = 3 + sequence_length = 7 + num_layers = 2 + config = transformer_lib.TransformerConfig( + num_layers=num_layers, + num_embed=4, + embed_dim=48, + hidden_dim=12, + num_heads=1, + head_dim=4, + num_kv_heads=1, + final_logit_softcap=None, + use_post_attn_norm=False, + use_post_ffw_norm=False, + attention_types=[modules.AttentionType.GLOBAL] * num_layers, + max_cache_length=sequence_length, + ) + empty_cache = config.init_cache(batch_size, dtype=jnp.float32) + attention_mask = jnp.ones( + (batch_size, sequence_length, sequence_length), dtype=jnp.bool + ) + with jax.numpy_rank_promotion('raise'): + transformer = transformer_lib.Transformer(config=config, sow=sow_config) + params = transformer.init( + jax.random.PRNGKey(0), + last_tokens=jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), + positions=jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), + cache=empty_cache, + attention_mask=attention_mask, + ) + intermediates = sow_lib.TransformerIntermediates() + _, _ = transformer.apply( + params, + jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), + jnp.tile(jnp.arange(sequence_length), (batch_size, 1)), + None, + attention_mask, + intermediates=intermediates, + ) + + if sow_config.embeddings: + embeddings = intermediates.embeddings + self.assertIsNotNone(embeddings) + self.assertEqual( + embeddings.shape, + (batch_size, sequence_length, config.embed_dim), + ) + else: + self.assertIsNone(intermediates.embeddings) + + self.assertLen(intermediates.layers, num_layers) + for block_intermediates in intermediates.layers: + if sow_config.rs_after_attention: + rs_after_attention = block_intermediates.rs_after_attention + self.assertIsNotNone(rs_after_attention) + self.assertEqual( + rs_after_attention.shape, + (batch_size, sequence_length, config.embed_dim), + ) + else: + self.assertIsNone(block_intermediates.rs_after_attention) + if sow_config.rs_after_ffw: + rs_after_ffw = block_intermediates.rs_after_ffw + self.assertIsNotNone(rs_after_ffw) + self.assertEqual( + rs_after_ffw.shape, + (batch_size, sequence_length, config.embed_dim), + ) + else: + self.assertIsNone(block_intermediates.rs_after_ffw) + if sow_config.attn_logits_topk: + attn_logits_topk_values = block_intermediates.attn_logits_topk_values + expected_shape = ( + batch_size, + sequence_length, + config.num_heads, + sow_config.attn_logits_topk, + ) + self.assertIsNotNone(attn_logits_topk_values) + self.assertEqual( + attn_logits_topk_values.shape, + expected_shape, + ) + attn_logits_topk_indices = block_intermediates.attn_logits_topk_indices + self.assertIsNotNone(attn_logits_topk_indices) + self.assertEqual(attn_logits_topk_indices.shape, expected_shape) + else: + self.assertIsNone(block_intermediates.attn_logits_topk_values) + self.assertIsNone(block_intermediates.attn_logits_topk_indices) + if sow_config.ffw_hidden_topk: + ffw_hidden_topk_values = block_intermediates.ffw_hidden_topk_values + expected_shape = ( + batch_size, + sequence_length, + sow_config.ffw_hidden_topk, + ) + self.assertIsNotNone(ffw_hidden_topk_values) + self.assertEqual(ffw_hidden_topk_values.shape, expected_shape) + ffw_hidden_topk_indices = block_intermediates.ffw_hidden_topk_indices + self.assertIsNotNone(ffw_hidden_topk_indices) + self.assertEqual(ffw_hidden_topk_indices.shape, expected_shape) + else: + self.assertIsNone(block_intermediates.ffw_hidden_topk_values) + self.assertIsNone(block_intermediates.ffw_hidden_topk_indices) + class TransformerUtilsTest(parameterized.TestCase):