Skip to content

Commit

Permalink
Merge pull request #1289 from AI-Hypercomputer:sixiang/debug
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 730570138
  • Loading branch information
maxtext authors committed Feb 24, 2025
2 parents 5faba12 + ae5f4c2 commit b18d7c6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
6 changes: 3 additions & 3 deletions MaxText/inference_mlperf/offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,10 @@ def decode():
else:
assert False, "no generate fn"
result_tokens_l = []
for i in range(5):
for i in range(10):
self.decode_state, result_tokens = gen_fn(self.params, self.decode_state, None)
result_tokens_l.append(result_tokens)
for i in range(5):
for i in range(10):
# result_tokens.copy_to_host_async()
result_tokens = result_tokens_l[i].convert_to_numpy()
self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
Expand Down Expand Up @@ -414,7 +414,7 @@ def detokenize():
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
continue

if len(self.prefill_buckets[padded_len // 2]) == 0:
if len(self.prefill_buckets[padded_len // 2]) != 0:
prefill_batch(self.prefill_buckets[padded_len // 2], padded_len // 2)
self.prefill_buckets[padded_len // 2] = []
if padded_len == self.max_prefill_length:
Expand Down
9 changes: 8 additions & 1 deletion MaxText/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,10 @@ def copy(path, partial_cache, full_cache, annotations):
zeros = jnp.zeros((1, self.config.max_prefill_predict_length), dtype=jnp.int32)
## zero out in case prefill cache is too small to cover
full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx)
# In case partial_cache is too small to slice at the given index, pad it with an extra seqlen
if i == num_prompts - 1:
pad = jnp.zeros((1, seq_len), dtype=int)
partial_cache = jnp.concatenate([partial_cache, pad], axis=1)
## copy prefill cache
partial_cache = jax.lax.dynamic_slice(partial_cache, (0, start_idx), (1, seq_len))
partial_cache = (partial_cache == partial_cache[0, 0]).astype(int)
Expand All @@ -748,8 +752,11 @@ def copy(path, partial_cache, full_cache, annotations):
slice_size[seqlen_index] = seq_len

slice_size = tuple(slice_size)
# Same as in prefill_segment_id processing
if i == num_prompts - 1:
pad = jnp.zeros(slice_size, dtype=partial_cache.dtype)
partial_cache = jnp.concatenate([partial_cache, pad], axis=seqlen_index)
partial_cache = jax.lax.dynamic_slice(partial_cache, start_indices, slice_size)
# jax.debug.print("start_indices: {}, slice_size: {}", start_indices, slice_size)

return jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
else:
Expand Down

0 comments on commit b18d7c6

Please sign in to comment.