
Commit af38d6e

Conchylicultor authored and The gemma Authors committed
Make cache_size optional
PiperOrigin-RevId: 701019861
1 parent 65a8858 commit af38d6e

1 file changed: +9 −7 lines changed

gemma/transformer.py

Lines changed: 9 additions & 7 deletions
@@ -63,7 +63,7 @@ class TransformerConfig:
   use_post_attn_norm: bool
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
-  max_cache_length: int = 1024
+  max_cache_length: int | None = 1024
   query_pre_attn_norm: QueryPreAttentionNormalisation = (
       QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
   )
@@ -83,7 +83,7 @@ def query_pre_attn_scalar(self) -> float:

   @classmethod
   def from_params(
-      cls, params: params_lib.Params, cache_size: int = 1024
+      cls, params: params_lib.Params, cache_size: int | None = 1024
   ) -> 'TransformerConfig':
     """Creates a TransformerConfig from loaded parameters.

@@ -116,7 +116,7 @@ def from_params(
     )

   @classmethod
-  def gemma_2b(cls, cache_size: int):
+  def gemma_2b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA_2B,
         num_embed=256128,
@@ -133,7 +133,7 @@ def gemma_2b(cls, cache_size: int):
     )

   @classmethod
-  def gemma_7b(cls, cache_size: int):
+  def gemma_7b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA_7B,
         num_embed=256128,
@@ -150,7 +150,7 @@ def gemma_7b(cls, cache_size: int):
     )

   @classmethod
-  def gemma2_2b(cls, cache_size: int):
+  def gemma2_2b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_2B,
         num_embed=256128,
@@ -174,7 +174,7 @@ def gemma2_2b(cls, cache_size: int):
     )

   @classmethod
-  def gemma2_9b(cls, cache_size: int):
+  def gemma2_9b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_9B,
         num_embed=256128,
@@ -199,7 +199,7 @@ def gemma2_9b(cls, cache_size: int):
     )

   @classmethod
-  def gemma2_27b(cls, cache_size: int):
+  def gemma2_27b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_27B,
         num_embed=256128,
@@ -229,6 +229,8 @@ def init_cache(
       dtype: jnp.dtype = jnp.bfloat16,
   ) -> Cache:
     """Initializes a new Transformer cache."""
+    if self.max_cache_length is None:
+      raise ValueError('max_cache_length must be set to initialize cache.')
     cache = {
         f'layer_{i}': modules.Attention.init_cache(
             self.max_cache_length,
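
For context, a brief usage sketch of the new behaviour (not part of this commit). It assumes the TransformerConfig classmethods shown in the diff above, that the preset constructors forward cache_size to the max_cache_length field (as the field change suggests), and that init_cache takes a batch_size argument, which is not visible in this hunk.

from gemma import transformer as transformer_lib

# cache_size=None is now accepted by the preset constructors.
config = transformer_lib.TransformerConfig.gemma_2b(cache_size=None)

if config.max_cache_length is None:
  # No KV cache is built, e.g. for training or full-sequence scoring.
  cache = None
else:
  # `batch_size` is an assumed argument name (not shown in this hunk).
  # init_cache now raises ValueError when max_cache_length is None,
  # hence the guard above.
  cache = config.init_cache(batch_size=1)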
