Commit 7331e13

Merge pull request #1067 from AI-Hypercomputer:less_prefill_array

PiperOrigin-RevId: 700785020

Authored and committed by: maxtext authors
2 parents: 3f93a89 + 1d7aa1b

File tree (3 files changed, +159 −16 lines):
  MaxText/configs/base.yml
  MaxText/maxengine.py
  MaxText/tests/maxengine_test.py

MaxText/configs/base.yml
Lines changed: 3 additions & 0 deletions

@@ -441,6 +441,9 @@ inference_microbenchmark_log_file_path: ""
 inference_metadata_file: "" # path to a json file
 enable_model_warmup: False
 
+# Stack prefill cache across the layer to reduce the
+# Python layer latency.
+stack_prefill_result_cache: False
 
 # KV Cache layout control
 # Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
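The new flag defaults to False. A minimal sketch of turning it on programmatically, mirroring the pyconfig.initialize call in the new test further down (running from the MaxText/ directory so "configs/base.yml" resolves is an assumption):

# Minimal sketch: enable the flag added by this commit via pyconfig.
import pyconfig

pyconfig.initialize(
    [None, "configs/base.yml"],
    enable_checkpointing=False,
    stack_prefill_result_cache=True,  # new option from this change
)
config = pyconfig.config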

MaxText/maxengine.py
Lines changed: 94 additions & 16 deletions
@@ -131,12 +131,22 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)
 
     self.prefill_kv_cache_annotations = max_utils.get_prefill_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
     self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
-        lambda x: jax.sharding.NamedSharding(self._mesh, x), self.prefill_kv_cache_annotations
+        lambda x: jax.sharding.NamedSharding(self._mesh, x),
+        self.prefill_kv_cache_annotations,
     )
 
+    if self.config.stack_prefill_result_cache:
+      # Add extra axis for the axis generated by the stack.
+      self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
+          lambda x: jax.sharding.NamedSharding(self._mesh, jax.sharding.PartitionSpec(None, *x.spec)),
+          self.prefill_kv_cache_shardings,
+      )
+      self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings["decoder"]["layers_0"]
+
     self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
     self.kv_cache_shardings = jax.tree_util.tree_map(
-        lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations
+        lambda x: jax.sharding.NamedSharding(self._mesh, x),
+        self.kv_cache_annotations,
     )
 
     if self.model.quant and not self.config.checkpoint_is_quantized:
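The PartitionSpec(None, *x.spec) expression above keeps each original axis's sharding and leaves the new leading layer axis (created by stacking) unsharded. A standalone sketch of the same idea, using a hypothetical one-dimensional mesh rather than MaxText's real mesh:

# Sketch only: a made-up 1-D mesh and a made-up per-layer cache sharding.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Per-layer sharding: batch axis split over "data", second axis replicated.
per_layer = NamedSharding(mesh, PartitionSpec("data", None))

# After stacking caches across layers, prepend None so the new layer axis is
# unsharded, the same transformation applied to prefill_kv_cache_shardings above.
stacked = NamedSharding(mesh, PartitionSpec(None, *per_layer.spec))
print(stacked.spec)  # PartitionSpec(None, 'data', None)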
@@ -172,12 +182,40 @@ def model_apply(_p, _rng):
       params["aqt"] = new_vars["aqt"]
       params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"])
       self.abstract_params = jax.tree_util.tree_map(
-          lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params
+          lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding),
+          params,
       )
       max_utils.save_quantized_checkpoint_if_configured(self.config, params)
       self.model.quant.quant_mode = quantizations.get_quant_mode("serve")
     return params
 
+  def _maybe_stack_prefill_result_cache(self, cache):
+    """Stack the caches across the layers."""
+    if not self.config.stack_prefill_result_cache:
+      return cache
+
+    layer_keys = []
+    for i in range(self.config.num_decoder_layers):
+      layer_keys.append(f"layers_{i}")
+
+    layer_cache = [cache["decoder"][layer_key] for layer_key in layer_keys]
+
+    return jax.tree.map(lambda *c: jnp.stack(c), *layer_cache)
+
+  def _maybe_unstack_prefill_result_cache(self, cache):
+    """Unstack the caches across the layers."""
+    if not self.config.stack_prefill_result_cache:
+      return cache
+
+    flat_cache, treedef = jax.tree.flatten(cache)
+    layer_cache = [jax.tree.unflatten(treedef, flat_cache_vars) for flat_cache_vars in zip(*flat_cache, strict=True)]
+    res_cache = {"decoder": {}}
+
+    for i in range(self.config.num_decoder_layers):
+      res_cache["decoder"][f"layers_{i}"] = layer_cache[i]
+
+    return res_cache
+
   @functools.partial(jax.jit, static_argnums=(0,))
   def prefill(
       self,
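The two new _maybe_* helpers above are the core of the change: stacking turns the per-layer cache pytrees into one pytree whose leaves carry a leading layer axis, and unstacking reverses it. A small self-contained sketch of the same round trip, with made-up leaf names and shapes:

# Sketch of the stack/unstack round trip; leaf names and shapes are made up.
import jax
from jax import numpy as jnp

num_layers = 4
per_layer = [{"key": jnp.ones((1, 8)), "value": jnp.zeros((1, 8))} for _ in range(num_layers)]

# Stack: one pytree whose leaves gain a leading layer axis,
# the same jax.tree.map(jnp.stack) pattern as _maybe_stack_prefill_result_cache.
stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_layer)
assert stacked["key"].shape == (num_layers, 1, 8)

# Unstack: flatten once, then rebuild one pytree per layer slice,
# as _maybe_unstack_prefill_result_cache does.
flat, treedef = jax.tree.flatten(stacked)
unstacked = [jax.tree.unflatten(treedef, leaves) for leaves in zip(*flat, strict=True)]
assert unstacked[0]["key"].shape == (1, 8)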
@@ -231,7 +269,9 @@ def prefill(
     next_pos = jnp.full((1, 1), true_length, dtype=jnp.int32)
     generated_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
     selected_logits = jax.lax.dynamic_slice(
-        flat_logits, (0, true_length - 1, 0), (flat_logits.shape[0], 1, flat_logits.shape[2])
+        flat_logits,
+        (0, true_length - 1, 0),
+        (flat_logits.shape[0], 1, flat_logits.shape[2]),
     )
     selected_logits = jax.lax.with_sharding_constraint(selected_logits, self.replicated_sharding)
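For context, the dynamic_slice above picks out the logits at the last real token of the prompt while keeping a length-1 sequence axis. A tiny sketch with made-up shapes (batch 2, sequence 8, vocab 5):

# Sketch: select logits at position true_length - 1, shapes are made up.
import jax
from jax import numpy as jnp

flat_logits = jnp.arange(2 * 8 * 5, dtype=jnp.float32).reshape(2, 8, 5)  # (batch, seq, vocab)
true_length = 3  # number of real (non-padded) prompt tokens

selected = jax.lax.dynamic_slice(
    flat_logits,
    (0, true_length - 1, 0),
    (flat_logits.shape[0], 1, flat_logits.shape[2]),
)
assert selected.shape == (2, 1, 5)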

@@ -259,9 +299,12 @@
         samples_per_slot=1,
     )
 
+    cache = new_vars["cache"]
+    cache = self._maybe_stack_prefill_result_cache(cache)
+
     return {
         "logits": selected_logits,
-        "cache": new_vars["cache"],
+        "cache": cache,
         "next_pos": next_pos,
         "generated_tokens": generated_tokens,
         "tokens": first_generated_token,
@@ -346,9 +389,17 @@ def insert(
     """Insert into KV cache"""
     unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)
 
+    unboxed_prefix["cache"] = self._maybe_unstack_prefill_result_cache(unboxed_prefix["cache"])
+
     def copy(path, partial_cache, full_cache, annotations):
       path_key = path[-1].key
-      if path_key in ["cache_ar_index", "cached_ar_key", "cached_ar_value", "cached_ar_key_scale", "cached_ar_value_scale"]:
+      if path_key in [
+          "cache_ar_index",
+          "cached_ar_key",
+          "cached_ar_value",
+          "cached_ar_key_scale",
+          "cached_ar_value_scale",
+      ]:
         return full_cache  # we don't even zero these out because we can mask them out.
 
       batch_idx = -1
@@ -388,12 +439,18 @@ def copy(path, partial_cache, full_cache, annotations):
       raise ValueError(f"We don't have a strategy for inserting {path_key}")
 
     inserted_cache = jax.tree_util.tree_map_with_path(
-        copy, unboxed_prefix["cache"], decode_state["cache"], self.kv_cache_annotations_named
+        copy,
+        unboxed_prefix["cache"],
+        decode_state["cache"],
+        self.kv_cache_annotations_named,
     )
     inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
     inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
     inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
-        decode_state["generated_tokens"], unboxed_prefix["generated_tokens"], slot, 0
+        decode_state["generated_tokens"],
+        unboxed_prefix["generated_tokens"],
+        slot,
+        0,
     )
     inserted_tokens = jax.lax.dynamic_update_index_in_dim(decode_state["tokens"], unboxed_prefix["tokens"], slot, 0)

@@ -458,11 +515,26 @@ def init(abstract_params):
           mutable=["cache"],
       )
 
-      next_pos = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
-      generated_tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
-      tokens = jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1), dtype=jnp.int32)
+      next_pos = jnp.zeros(
+          (int(self.config.per_device_batch_size * jax.device_count()), 1),
+          dtype=jnp.int32,
+      )
+      generated_tokens = jnp.zeros(
+          (int(self.config.per_device_batch_size * jax.device_count()), 1),
+          dtype=jnp.int32,
+      )
+      tokens = jnp.zeros(
+          (int(self.config.per_device_batch_size * jax.device_count()), 1),
+          dtype=jnp.int32,
+      )
       return {
-          "logits": jnp.zeros((int(self.config.per_device_batch_size * jax.device_count()), 1, self.config.vocab_size)),
+          "logits": jnp.zeros(
+              (
+                  int(self.config.per_device_batch_size * jax.device_count()),
+                  1,
+                  self.config.vocab_size,
+              )
+          ),
           "cache": cache["cache"],
           "next_pos": next_pos,
           "generated_tokens": generated_tokens,
@@ -477,7 +549,8 @@ def init(abstract_params):
     mesh_annotations = nn.logical_to_mesh(logical_annotations)
 
     shardings = jax.tree_util.tree_map(
-        lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations
+        lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation),
+        mesh_annotations,
     )
 
     @functools.partial(jax.jit, out_shardings=shardings)
@@ -519,16 +592,21 @@ def colocated_cpus(self) -> None:
     raise NotImplementedError
 
 
-def set_engine_vars_from_base_engine(engine: engine_api.Engine, base_engine: engine_api.Engine, rng: jax.random.PRNGKey):
+def set_engine_vars_from_base_engine(
+    engine: engine_api.Engine,
+    base_engine: engine_api.Engine,
+    rng: jax.random.PRNGKey,
+):
   """Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
   mesh, and kv cache related vars set.
   """
   engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
   engine.state_mesh_annotations = base_engine.state_mesh_annotations
   engine.abstract_params = base_engine.abstract_params
-  engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine._mesh)  # pylint: disable=protected-access
+  engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh)  # pylint: disable=protected-access
   engine.kv_cache_shardings = jax.tree_util.tree_map(
-      lambda x: jax.sharding.NamedSharding(engine._mesh, x), engine.kv_cache_annotations  # pylint: disable=protected-access
+      lambda x: jax.sharding.NamedSharding(engine.mesh, x),
+      engine.kv_cache_annotations,  # pylint: disable=protected-access
   )
 
 
MaxText/tests/maxengine_test.py
Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+"""
+Copyright 2024 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+     https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+""" Tests for the maxengine """
+
+import jax
+from jax import numpy as jnp
+import numpy as np
+import unittest
+import pyconfig
+from maxengine import MaxEngine
+
+
+class MaxEngineTest(unittest.TestCase):
+  """Tests for MaxEngine."""
+
+  # TODO: add unit test for the MaxEngine.
+
+  def test_stack_and_unstack_prefill_cache(self):
+    pyconfig.initialize(
+        [None, "configs/base.yml"],
+        enable_checkpointing=False,
+        stack_prefill_result_cache=True,
+    )
+    config = pyconfig.config
+    engine = MaxEngine(config, jax.devices())
+    num_layers = engine.config.num_decoder_layers
+    input = {
+        "decoder": {},
+    }
+    for i in range(num_layers):
+      input["decoder"][f"layers_{i}"] = {
+          "a": jnp.ones((1, 10)),
+          "b": jnp.ones((1, 9)),
+      }
+
+    expected_stacked = {
+        "a": jnp.ones((num_layers, 1, 10)),
+        "b": jnp.ones((num_layers, 1, 9)),
+    }
+    got_stacked = engine._maybe_stack_prefill_result_cache(input)
+    jax.tree.map(np.testing.assert_array_equal, got_stacked, expected_stacked)
+
+    got_unstacked = engine._maybe_unstack_prefill_result_cache(got_stacked)
+    jax.tree.map(np.testing.assert_array_equal, got_unstacked, input)
+
+
+if __name__ == "__main__":
+  unittest.main()
