Skip to content

Commit

Permalink
Option to sow intermediate activations.
Browse files Browse the repository at this point in the history
The CL introduces a new SowModule which is used to configure which intermediate values that are to be sowed.
TransformerIntermediates and BlockIntermediates are passed around to store the intermediate values.

PiperOrigin-RevId: 702635175
  • Loading branch information
The gemma Authors committed Dec 9, 2024
1 parent af38d6e commit 11b2feb
Show file tree
Hide file tree
Showing 6 changed files with 488 additions and 4 deletions.
22 changes: 18 additions & 4 deletions gemma/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flax import linen as nn
from gemma import layers
from gemma import positional_embeddings
from gemma import sow_lib
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -52,8 +53,8 @@ def _reconstruct_rotated_cache_positions():

cache_positions = cache_positions[None, None, :] # [1, 1, cache_len]
segment_pos = segment_pos[:, :, None] # [B, seq_len, 1]
sliding_mask = (cache_positions > segment_pos - sliding_window_size)
sliding_mask *= (cache_positions < segment_pos + sliding_window_size)
sliding_mask = cache_positions > segment_pos - sliding_window_size
sliding_mask *= cache_positions < segment_pos + sliding_window_size
return sliding_mask


Expand Down Expand Up @@ -95,6 +96,7 @@ class Attention(nn.Module):
query_pre_attn_scalar: float
attn_logits_soft_cap: float | None = None
sliding_window_size: int | None = None
sow: sow_lib.SowModule = sow_lib.SowModule()

@property
def use_qkv_einsum(self):
Expand Down Expand Up @@ -127,6 +129,7 @@ def __call__(
segment_pos: jax.Array,
cache: LayerCache | None,
attn_mask: jax.Array,
intermediates: sow_lib.BlockIntermediates | None = None,
) -> tuple[LayerCache | None, jax.Array]:
seq_len = x.shape[1]

Expand Down Expand Up @@ -190,6 +193,7 @@ def __call__(
attn_mask *= sliding_mask

padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
self.sow.maybe_sow_attn_logits_topk(padded_logits, intermediates)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
if self.use_gqa:
# Reshape matrices to enable einsums over groups.
Expand Down Expand Up @@ -242,9 +246,10 @@ class FeedForward(nn.Module):
features: int
hidden_dim: int
transpose_gating_einsum: bool
sow: sow_lib.SowModule = sow_lib.SowModule()

@nn.compact
def __call__(self, x):
def __call__(self, x, intermediates=None):
# Some versions use an alternate parameter ordering that
# transposes hidden_dim and features.
if self.transpose_gating_einsum:
Expand All @@ -267,6 +272,8 @@ def __call__(self, x):
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1

self.sow.maybe_sow_ffw_hidden_topk(activations, intermediates)

# Down projection
w_linear = self.param(
'linear',
Expand All @@ -293,6 +300,7 @@ class Block(nn.Module):
transpose_gating_einsum: bool
attn_logits_soft_cap: float | None = None
sliding_window_size: int | None = None
sow: sow_lib.SowModule = sow_lib.SowModule()

def setup(self):
self.pre_attention_norm = layers.RMSNorm()
Expand All @@ -305,6 +313,7 @@ def setup(self):
query_pre_attn_scalar=self.query_pre_attn_scalar,
attn_logits_soft_cap=self.attn_logits_soft_cap,
sliding_window_size=self.sliding_window_size,
sow=self.sow,
)
self.post_attention_norm = None
if self.use_post_attn_norm:
Expand All @@ -315,6 +324,7 @@ def setup(self):
features=self.embed_dim,
hidden_dim=self.hidden_dim,
transpose_gating_einsum=self.transpose_gating_einsum,
sow=self.sow,
)
self.post_ffw_norm = None
if self.use_post_ffw_norm:
Expand All @@ -326,20 +336,24 @@ def __call__(
segment_pos: jax.Array,
cache: LayerCache | None,
attn_mask: jax.Array,
intermediates: sow_lib.BlockIntermediates | None = None,
) -> tuple[LayerCache | None, jax.Array]:
inputs_normalized = self.pre_attention_norm(x)
cache, attn_output = self.attn(
inputs_normalized,
segment_pos,
cache,
attn_mask,
intermediates,
)
if self.post_attention_norm is not None:
attn_output = self.post_attention_norm(attn_output)
attn_output += x
self.sow.maybe_sow_rs_after_attention(attn_output, intermediates)
outputs = self.pre_ffw_norm(attn_output)
outputs = self.mlp(outputs)
outputs = self.mlp(outputs, intermediates)
if self.post_ffw_norm is not None:
outputs = self.post_ffw_norm(outputs)
outputs += attn_output
self.sow.maybe_sow_rs_after_ffw(outputs, intermediates)
return cache, outputs
27 changes: 27 additions & 0 deletions gemma/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import chex
from gemma import modules
from gemma import params as params_lib
from gemma import sow_lib
from gemma import transformer as transformer_lib
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -88,9 +89,13 @@ class _SamplingState:
# List of tokens that are forbidden to be generated.
forbidden_token_ids: Sequence[int] | None = None

# Intermediate activations from the model if requested.
intermediates: sow_lib.TransformerIntermediates | None = None


@dataclasses.dataclass
class SamplerOutput:
"""Output of the sampler."""

# Decoded samples from the model.
text: list[str]
Expand All @@ -101,6 +106,9 @@ class SamplerOutput:
# Tokens corresponding to the generated samples.
tokens: list[list[int]]

# Intermediate activations from the model if requested.
intermediates: sow_lib.TransformerIntermediates | None = None


class Sampler:
"""Sampler for gemma transformer."""
Expand Down Expand Up @@ -142,13 +150,15 @@ def _sample_step(
sampler_state.positions[:, decoding_step], -1
)
last_token = last_token.reshape((batch_size, 1))
step_intermediates = sow_lib.TransformerIntermediates()

logits, cache = self.transformer.apply(
{'params': params},
last_token,
step_positions,
sampler_state.cache,
attention_mask,
step_intermediates,
)
if sampler_state.forbidden_token_ids:
logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)
Expand All @@ -174,6 +184,9 @@ def _sample_step(
else:
logits_buffer = sampler_state.logits_buffer

if sampler_state.intermediates is not None:
sampler_state.intermediates.merge(decoding_step, step_intermediates)

done = sampler_state.done | jnp.equal(
token_buffer[:, decoding_step + 1], self.vocab.eos_id()
)
Expand All @@ -188,12 +201,21 @@ def _sample_step(
done=done,
total_sampling_steps=sampler_state.total_sampling_steps,
forbidden_token_ids=sampler_state.forbidden_token_ids,
intermediates=sampler_state.intermediates,
)

def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
"""Initializes the attention cache for each layer."""
return self.transformer.config.init_cache(bsz, dtype=self.dtype)

def init_intermediates(
self, bsz, buffer_size
) -> sow_lib.TransformerIntermediates:
"""Initializes the intermediate activations that will be filled."""
return self.transformer.config.init_intermediates(
bsz, buffer_size, self.transformer.sow
)

def init_sample_state(
self,
all_input_ids: list[jax.Array],
Expand Down Expand Up @@ -244,6 +266,7 @@ def init_sample_state(
done=done,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
intermediates=self.init_intermediates(bsz, buffer_size),
)

def tokenize(self, input_string: str) -> jax.Array:
Expand Down Expand Up @@ -358,9 +381,13 @@ def __call__(

decoded_outputs = [self.vocab.DecodeIds(tokens) for tokens in out_tokens]

if sampling_state.intermediates is not None:
sampling_state.intermediates.trim(total_sampling_steps)

result = SamplerOutput(
text=decoded_outputs,
logits=out_logits,
tokens=out_tokens,
intermediates=sampling_state.intermediates,
)
return result
98 changes: 98 additions & 0 deletions gemma/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from absl.testing import absltest
from gemma import modules
from gemma import sampler as sampler_lib
from gemma import sow_lib
from gemma import transformer as transformer_lib
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -73,6 +74,17 @@ def EncodeAsIds(self, text: str) -> list[int]: # pylint: disable=invalid-name

class SamplerTest(absltest.TestCase):

def assertNotAllZeros(self, array, msg=None):
if not jnp.any(array).item():
msg = msg or f'The array {array} is all zeros.'
raise self.failureException(msg)

def assertReasonableTensor(self, array, expected_shape=None):
self.assertIsNotNone(array)
# self.assertNotAllZeros(array)
if expected_shape is not None:
self.assertEqual(array.shape, expected_shape)

def test_samples(self):
vocab = MockVocab()

Expand Down Expand Up @@ -317,6 +329,92 @@ def test_sampler_mask_tokens_after_eos_ids(self):
self.assertListEqual(list(masked_token_buffer[0]), [1, 5, 6, 2, 0, 0])
self.assertListEqual(list(masked_token_buffer[1]), [1, 3, 4, 2, 0, 0])

def test_sampler_sows_intermediates(self):
vocab = MockVocab()
config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types
num_layers=3,
num_embed=vocab.GetPieceSize(),
embed_dim=64,
hidden_dim=128,
num_heads=2,
num_kv_heads=1,
head_dim=64,
max_cache_length=1024,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL],
use_post_attn_norm=None,
attn_logits_soft_cap=None,
use_post_ffw_norm=None,
)
sow = sow_lib.SowModule(
embeddings=True,
rs_after_attention=False, # This should results in a None value.
rs_after_ffw=True,
attn_logits_topk=5,
ffw_hidden_topk=11,
)
attention_mask = jnp.ones((1, 1, config.max_cache_length))
cache = config.init_cache(1, dtype=jnp.float32)
transformer = transformer_lib.Transformer(config, sow=sow)
params = transformer.init(
jax.random.PRNGKey(0),
jnp.array([[1]]),
jnp.array([[1]]),
cache,
attention_mask,
)
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['params'],
)
raw_input = ['input string', 'hello world']

result = sampler(raw_input, total_generation_steps=10)
input_length = max([len(vocab.EncodeAsIds(i)) for i in raw_input])
input_length += 1 # +1 for BOS token
output_length = max(len(tokens) for tokens in result.tokens)
length = input_length + output_length
self.assertIsNotNone(result)
intermediates = result.intermediates
self.assertIsNotNone(intermediates)
self.assertReasonableTensor(
intermediates.embeddings,
expected_shape=(2, length, config.embed_dim),
)
# Verify that the intermediates are different for two different steps.
self.assertNotAlmostEqual(
jnp.sum(intermediates.embeddings[:, 1, ...]),
jnp.sum(intermediates.embeddings[:, 2, ...]),
)
# Verify that the intermediates are filled in for each layer.
self.assertLen(intermediates.layers, config.num_layers)
for layer in intermediates.layers:
# For the requested intermediates we check the shape and that values are
# not all zeros, which was the initial value.
self.assertReasonableTensor(
layer.rs_after_ffw,
expected_shape=(2, length, config.embed_dim),
)
self.assertReasonableTensor(
layer.attn_logits_topk_values,
expected_shape=(2, length, config.num_heads, sow.attn_logits_topk),
)
self.assertReasonableTensor(
layer.attn_logits_topk_indices,
expected_shape=(2, length, config.num_heads, sow.attn_logits_topk),
)
self.assertReasonableTensor(
layer.ffw_hidden_topk_values,
expected_shape=(2, length, sow.ffw_hidden_topk),
)
self.assertReasonableTensor(
layer.ffw_hidden_topk_indices,
expected_shape=(2, length, sow.ffw_hidden_topk),
)
# For the none requested intermediates we want to have None values.
self.assertIsNone(layer.rs_after_attention)

def test_compute_attention_mask(self):
# Check that the input mask is correctly applied when total sampling steps
# is lower than the max cache length.
Expand Down
Loading

0 comments on commit 11b2feb

Please sign in to comment.