Option to sow intermediate activations. #62

Open · wants to merge 1 commit into base: main
gemma/modules.py (22 changes: 18 additions & 4 deletions)
@@ -18,6 +18,7 @@
from flax import linen as nn
from gemma import layers
from gemma import positional_embeddings
from gemma import sow_lib
import jax
import jax.numpy as jnp

@@ -52,8 +53,8 @@ def _reconstruct_rotated_cache_positions():

cache_positions = cache_positions[None, None, :] # [1, 1, cache_len]
segment_pos = segment_pos[:, :, None] # [B, seq_len, 1]
sliding_mask = (cache_positions > segment_pos - sliding_window_size)
sliding_mask *= (cache_positions < segment_pos + sliding_window_size)
sliding_mask = cache_positions > segment_pos - sliding_window_size
sliding_mask *= cache_positions < segment_pos + sliding_window_size
return sliding_mask


@@ -95,6 +96,7 @@ class Attention(nn.Module):
query_pre_attn_scalar: float
attn_logits_soft_cap: float | None = None
sliding_window_size: int | None = None
sow: sow_lib.SowModule = sow_lib.SowModule()

@property
def use_qkv_einsum(self):
@@ -127,6 +129,7 @@ def __call__(
segment_pos: jax.Array,
cache: LayerCache | None,
attn_mask: jax.Array,
intermediates: sow_lib.BlockIntermediates | None = None,
) -> tuple[LayerCache | None, jax.Array]:
seq_len = x.shape[1]

@@ -190,6 +193,7 @@ def __call__(
attn_mask *= sliding_mask

padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
self.sow.maybe_sow_attn_logits_topk(padded_logits, intermediates)
probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype)
if self.use_gqa:
# Reshape matrices to enable einsums over groups.
@@ -242,9 +246,10 @@ class FeedForward(nn.Module):
features: int
hidden_dim: int
transpose_gating_einsum: bool
sow: sow_lib.SowModule = sow_lib.SowModule()

@nn.compact
def __call__(self, x):
def __call__(self, x, intermediates=None):
# Some versions use an alternate parameter ordering that
# transposes hidden_dim and features.
if self.transpose_gating_einsum:
@@ -267,6 +272,8 @@ def __call__(self, x):
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1

self.sow.maybe_sow_ffw_hidden_topk(activations, intermediates)

# Down projection
w_linear = self.param(
'linear',
@@ -293,6 +300,7 @@ class Block(nn.Module):
transpose_gating_einsum: bool
attn_logits_soft_cap: float | None = None
sliding_window_size: int | None = None
sow: sow_lib.SowModule = sow_lib.SowModule()

def setup(self):
self.pre_attention_norm = layers.RMSNorm()
@@ -305,6 +313,7 @@ def setup(self):
query_pre_attn_scalar=self.query_pre_attn_scalar,
attn_logits_soft_cap=self.attn_logits_soft_cap,
sliding_window_size=self.sliding_window_size,
sow=self.sow,
)
self.post_attention_norm = None
if self.use_post_attn_norm:
@@ -315,6 +324,7 @@ def setup(self):
features=self.embed_dim,
hidden_dim=self.hidden_dim,
transpose_gating_einsum=self.transpose_gating_einsum,
sow=self.sow,
)
self.post_ffw_norm = None
if self.use_post_ffw_norm:
@@ -326,20 +336,24 @@ def __call__(
segment_pos: jax.Array,
cache: LayerCache | None,
attn_mask: jax.Array,
intermediates: sow_lib.BlockIntermediates | None = None,
) -> tuple[LayerCache | None, jax.Array]:
inputs_normalized = self.pre_attention_norm(x)
cache, attn_output = self.attn(
inputs_normalized,
segment_pos,
cache,
attn_mask,
intermediates,
)
if self.post_attention_norm is not None:
attn_output = self.post_attention_norm(attn_output)
attn_output += x
self.sow.maybe_sow_rs_after_attention(attn_output, intermediates)
outputs = self.pre_ffw_norm(attn_output)
outputs = self.mlp(outputs)
outputs = self.mlp(outputs, intermediates)
if self.post_ffw_norm is not None:
outputs = self.post_ffw_norm(outputs)
outputs += attn_output
self.sow.maybe_sow_rs_after_ffw(outputs, intermediates)
return cache, outputs
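For context, `sow_lib` itself is not part of this excerpt. The `maybe_sow_*` calls above are no-ops unless the corresponding option is enabled on the `SowModule` and an intermediates container is passed in. Below is a minimal sketch of what one such helper could look like; the class name, field defaults, and top-k behaviour are assumptions inferred from the calls in this diff, not the PR's actual `sow_lib` implementation.

```python
# Hypothetical sketch of one sow helper, inferred from the calls in modules.py above.
# The real sow_lib.SowModule in this PR may differ.
import dataclasses

import jax


@dataclasses.dataclass(frozen=True)
class SowModuleSketch:
  """Flags controlling which intermediates get recorded (assumed fields)."""

  ffw_hidden_topk: int = 0  # 0 disables sowing of the FFW hidden activations.

  def maybe_sow_ffw_hidden_topk(self, activations: jax.Array, intermediates) -> None:
    """Records the top-k FFW hidden activations per position, if requested."""
    if intermediates is None or not self.ffw_hidden_topk:
      return
    # Keep only the k largest hidden activations (and their indices) so the
    # recorded buffer stays small: [..., hidden_dim] -> [..., k].
    values, indices = jax.lax.top_k(activations, self.ffw_hidden_topk)
    intermediates.ffw_hidden_topk_values = values
    intermediates.ffw_hidden_topk_indices = indices
```

During sampling these values are recorded per decoding step; the sampler changes below then fold each step into a step-indexed buffer via `intermediates.merge(decoding_step, step_intermediates)`.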
gemma/sampler.py (27 changes: 27 additions & 0 deletions)
@@ -23,6 +23,7 @@
import chex
from gemma import modules
from gemma import params as params_lib
from gemma import sow_lib
from gemma import transformer as transformer_lib
import jax
import jax.numpy as jnp
@@ -88,9 +89,13 @@ class _SamplingState:
# List of tokens that are forbidden to be generated.
forbidden_token_ids: Sequence[int] | None = None

# Intermediate activations from the model if requested.
intermediates: sow_lib.TransformerIntermediates | None = None


@dataclasses.dataclass
class SamplerOutput:
"""Output of the sampler."""

# Decoded samples from the model.
text: list[str]
@@ -101,6 +106,9 @@ class SamplerOutput:
# Tokens corresponding to the generated samples.
tokens: list[list[int]]

# Intermediate activations from the model if requested.
intermediates: sow_lib.TransformerIntermediates | None = None


class Sampler:
"""Sampler for gemma transformer."""
@@ -142,13 +150,15 @@ def _sample_step(
sampler_state.positions[:, decoding_step], -1
)
last_token = last_token.reshape((batch_size, 1))
step_intermediates = sow_lib.TransformerIntermediates()

logits, cache = self.transformer.apply(
{'params': params},
last_token,
step_positions,
sampler_state.cache,
attention_mask,
step_intermediates,
)
if sampler_state.forbidden_token_ids:
logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)
@@ -174,6 +184,9 @@ def _sample_step(
else:
logits_buffer = sampler_state.logits_buffer

if sampler_state.intermediates is not None:
sampler_state.intermediates.merge(decoding_step, step_intermediates)

done = sampler_state.done | jnp.equal(
token_buffer[:, decoding_step + 1], self.vocab.eos_id()
)
@@ -188,12 +201,21 @@ def _sample_step(
done=done,
total_sampling_steps=sampler_state.total_sampling_steps,
forbidden_token_ids=sampler_state.forbidden_token_ids,
intermediates=sampler_state.intermediates,
)

def init_cache(self, bsz) -> dict[str, modules.LayerCache]:
"""Initializes the attention cache for each layer."""
return self.transformer.config.init_cache(bsz, dtype=self.dtype)

def init_intermediates(
self, bsz, buffer_size
) -> sow_lib.TransformerIntermediates:
"""Initializes the intermediate activations that will be filled."""
return self.transformer.config.init_intermediates(
bsz, buffer_size, self.transformer.sow
)

def init_sample_state(
self,
all_input_ids: list[jax.Array],
@@ -244,6 +266,7 @@ def init_sample_state(
done=done,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
intermediates=self.init_intermediates(bsz, buffer_size),
)

def tokenize(self, input_string: str) -> jax.Array:
@@ -358,9 +381,13 @@ def __call__(

decoded_outputs = [self.vocab.DecodeIds(tokens) for tokens in out_tokens]

if sampling_state.intermediates is not None:
sampling_state.intermediates.trim(total_sampling_steps)

result = SamplerOutput(
text=decoded_outputs,
logits=out_logits,
tokens=out_tokens,
intermediates=sampling_state.intermediates,
)
return result
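Taken together, requesting intermediates from the sampler could look like the sketch below. The `SowModule` options and the intermediate field names mirror `test_sampler_sows_intermediates` further down; `config`, `params`, and `vocab` are assumed to be set up as in the existing tests.

```python
# Sketch of end-to-end usage, following the new test in sampler_test.py below.
from gemma import sampler as sampler_lib
from gemma import sow_lib
from gemma import transformer as transformer_lib

sow = sow_lib.SowModule(
    embeddings=True,       # token embeddings, [B, steps, embed_dim]
    rs_after_ffw=True,     # residual stream after each feed-forward block
    attn_logits_topk=5,    # top-5 attention logits per head
    ffw_hidden_topk=11,    # top-11 FFW hidden activations per position
)
transformer = transformer_lib.Transformer(config, sow=sow)
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['params'],
)

result = sampler(['input string'], total_generation_steps=10)
intermediates = result.intermediates
# Per-layer buffers, e.g. for layer i:
#   intermediates.embeddings
#   intermediates.layers[i].rs_after_ffw
#   intermediates.layers[i].attn_logits_topk_values / ..._indices
#   intermediates.layers[i].ffw_hidden_topk_values / ..._indices
#   intermediates.layers[i].rs_after_attention  # None, since not requested
```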
gemma/sampler_test.py (98 changes: 98 additions & 0 deletions)
@@ -19,6 +19,7 @@
from absl.testing import absltest
from gemma import modules
from gemma import sampler as sampler_lib
from gemma import sow_lib
from gemma import transformer as transformer_lib
import jax
import jax.numpy as jnp
@@ -73,6 +74,17 @@ def EncodeAsIds(self, text: str) -> list[int]:  # pylint: disable=invalid-name

class SamplerTest(absltest.TestCase):

def assertNotAllZeros(self, array, msg=None):
if not jnp.any(array).item():
msg = msg or f'The array {array} is all zeros.'
raise self.failureException(msg)

def assertReasonableTensor(self, array, expected_shape=None):
self.assertIsNotNone(array)
# self.assertNotAllZeros(array)
if expected_shape is not None:
self.assertEqual(array.shape, expected_shape)

def test_samples(self):
vocab = MockVocab()

@@ -317,6 +329,92 @@ def test_sampler_mask_tokens_after_eos_ids(self):
self.assertListEqual(list(masked_token_buffer[0]), [1, 5, 6, 2, 0, 0])
self.assertListEqual(list(masked_token_buffer[1]), [1, 3, 4, 2, 0, 0])

def test_sampler_sows_intermediates(self):
vocab = MockVocab()
config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types
num_layers=3,
num_embed=vocab.GetPieceSize(),
embed_dim=64,
hidden_dim=128,
num_heads=2,
num_kv_heads=1,
head_dim=64,
max_cache_length=1024,
final_logit_softcap=None,
attention_types=[modules.AttentionType.GLOBAL],
use_post_attn_norm=None,
attn_logits_soft_cap=None,
use_post_ffw_norm=None,
)
sow = sow_lib.SowModule(
embeddings=True,
rs_after_attention=False,  # This should result in a None value.
rs_after_ffw=True,
attn_logits_topk=5,
ffw_hidden_topk=11,
)
attention_mask = jnp.ones((1, 1, config.max_cache_length))
cache = config.init_cache(1, dtype=jnp.float32)
transformer = transformer_lib.Transformer(config, sow=sow)
params = transformer.init(
jax.random.PRNGKey(0),
jnp.array([[1]]),
jnp.array([[1]]),
cache,
attention_mask,
)
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['params'],
)
raw_input = ['input string', 'hello world']

result = sampler(raw_input, total_generation_steps=10)
input_length = max([len(vocab.EncodeAsIds(i)) for i in raw_input])
input_length += 1 # +1 for BOS token
output_length = max(len(tokens) for tokens in result.tokens)
length = input_length + output_length
self.assertIsNotNone(result)
intermediates = result.intermediates
self.assertIsNotNone(intermediates)
self.assertReasonableTensor(
intermediates.embeddings,
expected_shape=(2, length, config.embed_dim),
)
# Verify that the intermediates are different for two different steps.
self.assertNotAlmostEqual(
jnp.sum(intermediates.embeddings[:, 1, ...]),
jnp.sum(intermediates.embeddings[:, 2, ...]),
)
# Verify that the intermediates are filled in for each layer.
self.assertLen(intermediates.layers, config.num_layers)
for layer in intermediates.layers:
# For the requested intermediates we check the shape and that the values
# are not all zeros (the initial buffer value).
self.assertReasonableTensor(
layer.rs_after_ffw,
expected_shape=(2, length, config.embed_dim),
)
self.assertReasonableTensor(
layer.attn_logits_topk_values,
expected_shape=(2, length, config.num_heads, sow.attn_logits_topk),
)
self.assertReasonableTensor(
layer.attn_logits_topk_indices,
expected_shape=(2, length, config.num_heads, sow.attn_logits_topk),
)
self.assertReasonableTensor(
layer.ffw_hidden_topk_values,
expected_shape=(2, length, sow.ffw_hidden_topk),
)
self.assertReasonableTensor(
layer.ffw_hidden_topk_indices,
expected_shape=(2, length, sow.ffw_hidden_topk),
)
# For intermediates that were not requested we expect None values.
self.assertIsNone(layer.rs_after_attention)

def test_compute_attention_mask(self):
# Check that the input mask is correctly applied when total sampling steps
# is lower than the max cache length.