@@ -131,12 +131,22 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)
 
     self.prefill_kv_cache_annotations = max_utils.get_prefill_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
     self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
-        lambda x: jax.sharding.NamedSharding(self._mesh, x), self.prefill_kv_cache_annotations
+        lambda x: jax.sharding.NamedSharding(self._mesh, x),
+        self.prefill_kv_cache_annotations,
     )
 
+    if self.config.stack_prefill_result_cache:
+      # Add extra axis for the axis generated by the stack.
+      self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
+          lambda x: jax.sharding.NamedSharding(self._mesh, jax.sharding.PartitionSpec(None, *x.spec)),
+          self.prefill_kv_cache_shardings,
+      )
+      self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings["decoder"]["layers_0"]
+
     self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
     self.kv_cache_shardings = jax.tree_util.tree_map(
-        lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations
+        lambda x: jax.sharding.NamedSharding(self._mesh, x),
+        self.kv_cache_annotations,
     )
 
     if self.model.quant and not self.config.checkpoint_is_quantized:
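For context on the `PartitionSpec(None, *x.spec)` wrapping above: the stacked prefill cache gains a new leading layer axis, and prepending `None` leaves that axis unsharded while preserving each leaf's original per-layer spec. A minimal standalone sketch, not MaxText code; the 1-D mesh and the "tensor" axis name are made up:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D device mesh; the real MaxText mesh axes differ.
mesh = Mesh(np.array(jax.devices()), axis_names=("tensor",))

# Per-layer cache sharding, e.g. for leaves of shape (heads, head_dim).
per_layer = NamedSharding(mesh, PartitionSpec("tensor", None))

# After stacking, leaves have shape (num_layers, heads, head_dim);
# the new leading layer axis stays unsharded (None).
stacked = NamedSharding(mesh, PartitionSpec(None, *per_layer.spec))
print(stacked.spec)  # PartitionSpec(None, 'tensor', None)
```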
@@ -172,12 +182,40 @@ def model_apply(_p, _rng):
     params["aqt"] = new_vars["aqt"]
     params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"])
     self.abstract_params = jax.tree_util.tree_map(
-        lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params
+        lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
+        params,
     )
     max_utils.save_quantized_checkpoint_if_configured(self.config, params)
     self.model.quant.quant_mode = quantizations.get_quant_mode("serve")
     return params
 
+  def _maybe_stack_prefill_result_cache(self, cache):
+    """Stack the caches across the layers."""
+    if not self.config.stack_prefill_result_cache:
+      return cache
+
+    layer_keys = []
+    for i in range(self.config.num_decoder_layers):
+      layer_keys.append(f"layers_{i}")
+
+    layer_cache = [cache["decoder"][layer_key] for layer_key in layer_keys]
+
+    return jax.tree.map(lambda *c: jnp.stack(c), *layer_cache)
+
+  def _maybe_unstack_prefill_result_cache(self, cache):
+    """Unstack the caches across the layers."""
+    if not self.config.stack_prefill_result_cache:
+      return cache
+
+    flat_cache, treedef = jax.tree.flatten(cache)
+    layer_cache = [jax.tree.unflatten(treedef, flat_cache_vars) for flat_cache_vars in zip(*flat_cache, strict=True)]
+    res_cache = {"decoder": {}}
+
+    for i in range(self.config.num_decoder_layers):
+      res_cache["decoder"][f"layers_{i}"] = layer_cache[i]
+
+    return res_cache
+
   @functools.partial(jax.jit, static_argnums=(0,))
   def prefill(
       self,
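The two helpers added above are inverses: `_maybe_stack_prefill_result_cache` merges the per-layer cache dicts into one pytree whose leaves carry a leading layer axis, and `_maybe_unstack_prefill_result_cache` splits that axis back into per-layer subtrees. A minimal standalone round-trip sketch, not MaxText code; the two-layer cache, leaf names, and shapes are invented:

```python
import jax
import jax.numpy as jnp

num_layers = 2
# Hypothetical per-layer prefill cache.
cache = {
    "decoder": {
        f"layers_{i}": {
            "cached_key": jnp.full((1, 4, 8), i, dtype=jnp.float32),
            "cache_index": jnp.array([i]),
        }
        for i in range(num_layers)
    }
}

# Stack: one pytree whose leaves gain a leading layer axis.
layer_cache = [cache["decoder"][f"layers_{i}"] for i in range(num_layers)]
stacked = jax.tree.map(lambda *c: jnp.stack(c), *layer_cache)
assert stacked["cached_key"].shape == (num_layers, 1, 4, 8)

# Unstack: split the leading axis back into per-layer subtrees.
flat, treedef = jax.tree.flatten(stacked)
per_layer = [jax.tree.unflatten(treedef, leaves) for leaves in zip(*flat, strict=True)]
restored = {"decoder": {f"layers_{i}": per_layer[i] for i in range(num_layers)}}
assert restored["decoder"]["layers_1"]["cached_key"].shape == (1, 4, 8)
```

Stacking requires every layer's cache subtree to have the same structure and leaf shapes, which is also why the stacked prefill cache shardings above are taken from a single representative layer (`layers_0`).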
@@ -231,7 +269,9 @@ def prefill(
     next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32)
     generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
     selected_logits = jax.lax.dynamic_slice(
-        flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2])
+        flat_logits,
+        (0, true_length - 1, 0),
+        (flat_logits.shape[0], 1, flat_logits.shape[2]),
     )
     selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding)
 
@@ -259,9 +299,12 @@ def prefill(
         samples_per_slot=1,
     )
 
+    cache = new_vars["cache"]
+    cache = self._maybe_stack_prefill_result_cache(cache)
+
     return {
         "logits": selected_logits,
-        "cache": new_vars["cache"],
+        "cache": cache,
         "next_pos": next_pos,
         "generated_tokens": generated_tokens,
         "tokens": first_generated_token,
@@ -346,9 +389,17 @@ def insert(
     """Insert into KV cache"""
     unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)
 
+    unboxed_prefix["cache"] = self._maybe_unstack_prefill_result_cache(unboxed_prefix["cache"])
+
     def copy(path, partial_cache, full_cache, annotations):
       path_key = path[-1].key
-      if path_key in ["cache_ar_index", "cached_ar_key", "cached_ar_value", "cached_ar_key_scale", "cached_ar_value_scale"]:
+      if path_key in [
+          "cache_ar_index",
+          "cached_ar_key",
+          "cached_ar_value",
+          "cached_ar_key_scale",
+          "cached_ar_value_scale",
+      ]:
         return full_cache  # we don't even zero these out because we can mask them out.
 
       batch_idx = -1
@@ -388,12 +439,18 @@ def copy(path, partial_cache, full_cache, annotations):
        raise ValueError(f"We don't have a strategy for inserting {path_key}")
 
     inserted_cache = jax.tree_util.tree_map_with_path(
-        copy, unboxed_prefix["cache"], decode_state["cache"], self.kv_cache_annotations_named
+        copy,
+        unboxed_prefix["cache"],
+        decode_state["cache"],
+        self.kv_cache_annotations_named,
     )
     inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
     inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
     inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
-        decode_state["generated_tokens"], unboxed_prefix["generated_tokens"], slot, 0
+        decode_state["generated_tokens"],
+        unboxed_prefix["generated_tokens"],
+        slot,
+        0,
     )
     inserted_tokens = jax.lax.dynamic_update_index_in_dim(decode_state["tokens"], unboxed_prefix["tokens"], slot, 0)
 
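The `dynamic_update_index_in_dim` calls above each write a single-sequence prefill result into one batch slot of the batched decode state. A minimal standalone sketch with hypothetical shapes:

```python
import jax
import jax.numpy as jnp

# Hypothetical decode state for a batch of 4 slots, plus one prefill result.
decode_tokens = jnp.zeros((4, 1), dtype=jnp.int32)
prefix_tokens = jnp.array([[42]], dtype=jnp.int32)  # shape (1, 1)
slot = 2

# Overwrite row `slot` along axis 0; all other rows are left untouched.
updated = jax.lax.dynamic_update_index_in_dim(decode_tokens, prefix_tokens, slot, 0)
assert int(updated[2, 0]) == 42 and int(updated[0, 0]) == 0
```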
@@ -458,11 +515,26 @@ def init(abstract_params):
           mutable=["cache"],
       )
 
-      next_pos = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
-      generated_tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
-      tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
+      next_pos = jnp.zeros(
+          (int(self.config.per_device_batch_size * jax.device_count()), 1),
+          dtype=jnp.int32,
+      )
+      generated_tokens = jnp.zeros(
+          (int(self.config.per_device_batch_size * jax.device_count()), 1),
+          dtype=jnp.int32,
+      )
+      tokens = jnp.zeros(
+          (int(self.config.per_device_batch_size * jax.device_count()), 1),
+          dtype=jnp.int32,
+      )
       return {
-          "logits": jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)),
+          "logits": jnp.zeros(
+              (
+                  int(self.config.per_device_batch_size * jax.device_count()),
+                  1,
+                  self.config.vocab_size,
+              )
+          ),
           "cache": cache["cache"],
           "next_pos": next_pos,
           "generated_tokens": generated_tokens,
@@ -477,7 +549,8 @@ def init(abstract_params):
     mesh_annotations = nn.logical_to_mesh(logical_annotations)
 
     shardings = jax.tree_util.tree_map(
-        lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations
+        lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation),
+        mesh_annotations,
     )
 
     @functools.partial(jax.jit, out_shardings=shardings)
@@ -519,16 +592,21 @@ def colocated_cpus(self) -> None:
     raise NotImplementedError
 
 
-def set_engine_vars_from_base_engine(engine: engine_api.Engine, base_engine: engine_api.Engine, rng: jax.random.PRNGKey):
+def set_engine_vars_from_base_engine(
+    engine: engine_api.Engine,
+    base_engine: engine_api.Engine,
+    rng: jax.random.PRNGKey,
+):
   """Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
   mesh, and kv cache related vars set.
   """
   engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
   engine.state_mesh_annotations = base_engine.state_mesh_annotations
   engine.abstract_params = base_engine.abstract_params
-  engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine._mesh)  # pylint: disable=protected-access
+  engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh)  # pylint: disable=protected-access
   engine.kv_cache_shardings = jax.tree_util.tree_map(
-      lambda x: jax.sharding.NamedSharding(engine._mesh, x), engine.kv_cache_annotations  # pylint: disable=protected-access
+      lambda x: jax.sharding.NamedSharding(engine.mesh, x),
+      engine.kv_cache_annotations,  # pylint: disable=protected-access
   )