@@ -63,7 +63,7 @@ class TransformerConfig:
   use_post_attn_norm: bool
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
-  max_cache_length: int = 1024
+  max_cache_length: int | None = 1024
   query_pre_attn_norm: QueryPreAttentionNormalisation = (
       QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
   )
@@ -83,7 +83,7 @@ def query_pre_attn_scalar(self) -> float:

   @classmethod
   def from_params(
-      cls, params: params_lib.Params, cache_size: int = 1024
+      cls, params: params_lib.Params, cache_size: int | None = 1024
   ) -> 'TransformerConfig':
     """Creates a TransformerConfig from loaded parameters.

@@ -116,7 +116,7 @@ def from_params(
     )

   @classmethod
-  def gemma_2b(cls, cache_size: int):
+  def gemma_2b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA_2B,
         num_embed=256128,
@@ -133,7 +133,7 @@ def gemma_2b(cls, cache_size: int):
     )

   @classmethod
-  def gemma_7b(cls, cache_size: int):
+  def gemma_7b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA_7B,
         num_embed=256128,
@@ -150,7 +150,7 @@ def gemma_7b(cls, cache_size: int):
     )

   @classmethod
-  def gemma2_2b(cls, cache_size: int):
+  def gemma2_2b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_2B,
         num_embed=256128,
@@ -174,7 +174,7 @@ def gemma2_2b(cls, cache_size: int):
     )

   @classmethod
-  def gemma2_9b(cls, cache_size: int):
+  def gemma2_9b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_9B,
         num_embed=256128,
@@ -199,7 +199,7 @@ def gemma2_9b(cls, cache_size: int):
     )

   @classmethod
-  def gemma2_27b(cls, cache_size: int):
+  def gemma2_27b(cls, cache_size: int | None):
     return cls(
         num_layers=_NUM_LAYERS_GEMMA2_27B,
         num_embed=256128,
@@ -229,6 +229,8 @@ def init_cache(
       dtype: jnp.dtype = jnp.bfloat16,
   ) -> Cache:
     """Initializes a new Transformer cache."""
+    if self.max_cache_length is None:
+      raise ValueError('max_cache_length must be set to initialize cache.')
     cache = {
         f'layer_{i}': modules.Attention.init_cache(
             self.max_cache_length,
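
A minimal usage sketch of the new `None` path, for context only. It assumes the usual `params_lib.load_and_format_params` loader and that `init_cache` keeps its existing `batch_size`/`dtype` arguments; the checkpoint path is a placeholder.

```python
# Sketch only: exercises the cache_size=None path introduced by this change.
import jax.numpy as jnp

from gemma import params as params_lib
from gemma import transformer as transformer_lib

# Placeholder path; assumes the standard gemma checkpoint loader.
params = params_lib.load_and_format_params('/path/to/checkpoint')

# cache_size may now be None, e.g. when no KV cache will be allocated.
config = transformer_lib.TransformerConfig.from_params(params, cache_size=None)

# Requesting a cache with max_cache_length unset now fails loudly up front.
try:
  cache = config.init_cache(batch_size=1, dtype=jnp.bfloat16)
except ValueError as err:
  print(err)  # max_cache_length must be set to initialize cache.
```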