enable concatenation of unpadded prompts
sixiang-google committed Jan 17, 2025
1 parent 6d8a42b commit 0394b0d
Showing 3 changed files with 119 additions and 68 deletions.
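The core of the change: instead of prefilling each prompt in its own power-of-two padded slot, batched prefill now concatenates the unpadded prompts back-to-back in a single max_prefill_length buffer, gives prompt i the decoder segment id 2*i + 1, restarts positions at 0 for every prompt, and pads only the tail of the buffer with zeros. A minimal NumPy sketch of that packing scheme follows; the helper name and example lengths are illustrative, not part of the commit.

import numpy as np

def pack_prompts(prompts, max_prefill_length):
    """prompts: list of 1-D int arrays holding unpadded token ids."""
    tokens, positions, segment_ids, start_pos = [], [], [], []
    offset = 0
    for idx, prompt in enumerate(prompts):
        true_length = len(prompt)
        start_pos.append(offset)                               # where this prompt begins
        tokens.append(prompt)
        positions.append(np.arange(true_length))               # positions restart per prompt
        segment_ids.append(np.full(true_length, 2 * idx + 1))  # odd id marks this prompt's tokens
        offset += true_length
    pad = max_prefill_length - offset                          # one shared pad region at the end
    tokens.append(np.zeros(pad, dtype=int))
    positions.append(np.arange(pad))
    segment_ids.append(np.zeros(pad, dtype=int))               # 0 marks padding
    return (np.concatenate(tokens), np.concatenate(positions),
            np.concatenate(segment_ids), np.array(start_pos))

toks, pos, seg, starts = pack_prompts([np.array([5, 6, 7]), np.array([8, 9, 10, 11, 12])], 16)
# seg    -> [1 1 1 3 3 3 3 3 0 0 0 0 0 0 0 0]
# starts -> [0 3]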
181 changes: 116 additions & 65 deletions MaxText/inference_mlperf/offline_inference.py
@@ -24,6 +24,7 @@
import threading
import traceback
import signal
import random

from jetstream.engine import engine_api

@@ -67,6 +68,7 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En

self.enable_batch_prefill = enable_batch_prefill
self.batch_size = engine.max_concurrent_decodes
self.max_prefill_length = engine.config.max_prefill_predict_length
self.max_decode_length = engine.config.max_target_length - engine.config.max_prefill_predict_length
metadata = engine.get_tokenizer()
self.tokenizer = engine.build_tokenizer(metadata)
@@ -105,32 +107,35 @@ def warmup(self, max_length, warmup_samples):
)
if length == 64 or length == 1024:
continue
log.info(f"Compiling batched prefill: {length}")
input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
num_prompts = max_length // length
self._cached_pref_batch[length] = (
jax.jit(
self._prefill_insert_batch,
static_argnames=(
"num_prompts",
"padded_length",
),
donate_argnames=("decode_state",),
)
.lower(
self.params,
tokens=input_data_batch,
slots=jnp.arange(0, 8, dtype=int),
num_prompts=num_prompts,
decoder_positions=jnp.arange(0, max_length, dtype=int),
decoder_segment_ids=jnp.ones(max_length, dtype=int),
start_pos=jnp.arange(0, max_length, 128, dtype=int),
padded_length=length,
true_lengths=jnp.full(8, length, dtype=int),
decode_state=self.decode_state,
)
.compile()
)
min_num_prompts = max_length // length
max_num_prompts = max_length // (length // 2)
possible_prompts = range(min_num_prompts, max_num_prompts)
for num_prompts in possible_prompts:
log.info(f"Compiling batched prefill: {length} num_prompts: {num_prompts}")
self._cached_pref_batch[(length, num_prompts)] = (
jax.jit(
self._prefill_insert_batch,
static_argnames=(
"num_prompts",
"padded_length",
),
donate_argnames=("decode_state",),
)
.lower(
self.params,
tokens=input_data_batch,
slots=jnp.arange(0, 16, dtype=int),
num_prompts=num_prompts,
decoder_positions=jnp.arange(0, max_length, dtype=int),
decoder_segment_ids=jnp.ones(max_length, dtype=int),
start_pos=jnp.arange(0, max_length, 64, dtype=int),
padded_length=length,
true_lengths=jnp.full(16, length, dtype=int),
decode_state=self.decode_state,
)
.compile()
)
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
)
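Because the buffer can now hold a varying number of prompts per padded length, warmup compiles and caches one executable per (length, num_prompts) pair instead of one per length. Below is a stripped-down sketch of that ahead-of-time compilation pattern; the stand-in function body and bucket sizes are illustrative, not the commit's actual prefill logic.

import jax
import jax.numpy as jnp

def prefill_insert_batch(params, tokens, num_prompts, padded_length):
    # Stand-in body; the real function prefills and inserts into the decode state.
    del params, padded_length
    return jnp.sum(tokens) + num_prompts

max_length = 1024
cache = {}
params = jnp.ones((4,))
for length in (128, 256, 512):
    for num_prompts in range(max_length // length, max_length // (length // 2)):
        cache[(length, num_prompts)] = (
            jax.jit(prefill_insert_batch, static_argnames=("num_prompts", "padded_length"))
            .lower(
                params,
                tokens=jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32")),
                num_prompts=num_prompts,
                padded_length=length,
            )
            .compile()
        )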
@@ -194,7 +199,11 @@ def prefill(prefill_bucket, prefill_len):
if self.dummy:
log.info("dummy prefill")
return 123
if not self.enable_batch_prefill or prefill_len in (64, 1024) or prefill_len * len(prefill_bucket) != 1024:
if (
not self.enable_batch_prefill
or prefill_len == self.max_prefill_length
or prefill_len * len(prefill_bucket) < self.max_prefill_length
):
prefill_result = []
prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(prefill_len)) is not None:
@@ -206,27 +215,25 @@ def prefill(prefill_bucket, prefill_len):
prefill_result.append((first_token, slot, row))
return prefill_result
else:
prefill_fn = self._prefill_insert_batch
if (cached := self._cached_pref_batch.get(prefill_len)) is not None:
prefill_fn = cached
positions = np.concatenate([np.arange(0, row.tokens.shape[0]) for (slot, row) in prefill_bucket])
positions = jnp.array(positions)

num_prompts = len(prefill_bucket)
sequence_indicators = []
total_len = 0
for idx, (slot, row) in enumerate(prefill_bucket):
zero_to_n = np.arange(0, row.tokens.shape[0])
ones_to_keep = zero_to_n < row.true_length
one_d_output = (zero_to_n < row.true_length).astype(int) * (idx * 2 + 1) + (zero_to_n >= row.true_length).astype(
int
) * (idx + 1) * 2
sequence_indicators.append(one_d_output)
sequence_indicators.append(np.full(row.true_length, idx * 2 + 1, dtype=int))
total_len += row.true_length
sequence_indicators.append(np.zeros(self.max_prefill_length - total_len, dtype=int))
sequence_indicator = jnp.array(np.concatenate(sequence_indicators))

tokens = jnp.concat([row.tokens for (slot, row) in prefill_bucket])
positions = [np.arange(0, row.true_length) for (slot, row) in prefill_bucket]
positions.append(np.arange(0, self.max_prefill_length - total_len))
positions = jnp.array(np.concatenate(positions))

tokens = [row.tokens[: row.true_length] for (slot, row) in prefill_bucket]
tokens.append(jnp.zeros(self.max_prefill_length - total_len, dtype=int))
tokens = jnp.concat(tokens)
slots = [slot for (slot, row) in prefill_bucket]
true_lengths = [row.true_length for (slot, row) in prefill_bucket]
start_pos = np.cumsum([0] + [row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
start_pos = np.cumsum([0] + [row.true_length for (slot, row) in prefill_bucket])[:-1]
start_pos = start_pos.tolist()

# pad slots to keep static shape of jitted function input
@@ -235,10 +242,15 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
array_to_pad.extend([0] * (pad_len - len(array_to_pad)))
return jnp.array(array_to_pad)

slots = pad_num_prompts_len_array(slots, 8)
true_lengths = pad_num_prompts_len_array(true_lengths, 8)
start_pos = pad_num_prompts_len_array(start_pos, 8)
# this lowered function has static input for num_prompts and padded_length
slots = pad_num_prompts_len_array(slots, 16)
true_lengths = pad_num_prompts_len_array(true_lengths, 16)
start_pos = pad_num_prompts_len_array(start_pos, 16)

prefill_fn = self._prefill_insert_batch
log.info(f"invoking compiled function with length {prefill_len} num_prompts {num_prompts}")
if (cached := self._cached_pref_batch.get((prefill_len, num_prompts))) is not None:
prefill_fn = cached

first_tokens, self.decode_state = prefill_fn(
self.params,
tokens=tokens,
@@ -248,11 +260,18 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
start_pos=start_pos,
true_lengths=true_lengths,
decode_state=self.decode_state,
) # pytype: disable=missing-parameter
) # pytype: disable=missing-parameter
prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]

return prefill_result

def prefill_batch(prefill_bucket, padded_len):
nonlocal self
prefill_results = prefill(prefill_bucket, padded_len)
for _first_token, _slot, _row in prefill_results:
log.info(f"Put row of len {_row.tokens.shape[0]} true length {_row.true_length} slot {_slot} to detokenize backlog")
self.detokenize_backlog.put((_first_token, True, _row.id, _slot), block=True)

empty_slots = list(range(self.batch_size))
slot_to_id = {}
num_prefills = {}
@@ -339,37 +358,57 @@ def detokenize():
log.info(f"decode-{desc}-{num_decodes}")
decode()
# do one insert
num_tokens = len(row.tokens)
num_prefills[num_tokens] = 1 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
log.info(
f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
padded_len = len(row.tokens)
num_prefills[padded_len] = 1 if padded_len not in num_prefills else num_prefills[padded_len] + 1
log.debug(
f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} padded_len {padded_len} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
)
total_num_prefills += 1
log.info(f"Total num prefill: {total_num_prefills}")
slot = empty_slots.pop()
# directly prefill prompts with 64 or less tokens, and with 1024 tokens
if num_tokens in (64, 1024) or not self.enable_batch_prefill:
first_token, slot, row = prefill([(slot, row)], num_tokens)[0]
# directly prefill prompts
if not self.enable_batch_prefill:
first_token, slot, row = prefill([(slot, row)], padded_len)[0]
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
continue

if self.prefill_buckets[padded_len // 2] != []:
prefill_batch(self.prefill_buckets[padded_len // 2], padded_len // 2)
self.prefill_buckets[padded_len // 2] = []
if padded_len == self.max_prefill_length:
first_token, slot, row = prefill([(slot, row)], padded_len)[0]
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
continue
self.prefill_buckets[num_tokens].append((slot, row))
if padded_len == 64:
row.tokens = jnp.concat([row.tokens, jnp.zeros(64, dtype=int)])
padded_len = 128

self.prefill_buckets[padded_len].append((slot, row))
prefill_buckets_len = {k: len(self.prefill_buckets[k]) for k in self.prefill_buckets}
log.debug(f"prefill buckets {prefill_buckets_len}")
if len(self.prefill_buckets[num_tokens]) * num_tokens == 1024:
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for first_token, slot, row in prefill_results:
log.debug(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
self.prefill_buckets[num_tokens] = []
if len(self.prefill_buckets[padded_len]) * padded_len >= self.max_prefill_length:
total_true_len = sum([row.true_length for (slot, row) in self.prefill_buckets[padded_len]])
# Can't hold another buffer, prefill right away
if total_true_len > self.max_prefill_length - padded_len // 2 and total_true_len <= self.max_prefill_length:
log.debug(
f"Normal batch {padded_len} total padded len {len(self.prefill_buckets[padded_len]) * padded_len} total true len {total_true_len}"
)
prefill_batch(self.prefill_buckets[padded_len], padded_len)
self.prefill_buckets[padded_len] = []
# Already overflowing: hold back the newest request and prefill the rest
elif total_true_len > self.max_prefill_length:
log.debug(
f"Overloading {padded_len} total padded len {len(self.prefill_buckets[padded_len]) * padded_len} total true len {total_true_len}"
)
current = self.prefill_buckets[padded_len][-1]
prefill_batch(self.prefill_buckets[padded_len][:-1], padded_len)
self.prefill_buckets[padded_len] = [current]
# For leftover requests in buckets at the end of computation, do prefill individually.
for num_tokens in self.prefill_buckets.keys():
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for first_token, slot, row in prefill_results:
log.debug(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
for padded_len in self.prefill_buckets.keys():
prefill_batch(self.prefill_buckets[padded_len], padded_len)
self.prefill_buckets = defaultdict(list)
while slot_to_id:
log.debug(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
num_decodes += 1
decode()
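The flush rule above can be traced with a small stand-alone example. The numbers below are made up; they only illustrate the two branches (flush the whole bucket when its true lengths nearly fill the buffer, otherwise hold back the newest request once the total overflows).

max_prefill_length, padded_len = 1024, 512
bucket, flushed = [], []
for true_len in (400, 350, 450):          # arriving requests (true lengths)
    bucket.append(true_len)
    if len(bucket) * padded_len >= max_prefill_length:
        total_true = sum(bucket)
        if max_prefill_length - padded_len // 2 < total_true <= max_prefill_length:
            flushed.append(list(bucket))  # fits snugly: prefill the whole bucket
            bucket = []
        elif total_true > max_prefill_length:
            flushed.append(bucket[:-1])   # overflow: prefill all but the newest request
            bucket = bucket[-1:]
# flushed == [[400, 350]]; the leftover [450] is prefilled in the final pass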

Expand All @@ -379,6 +418,18 @@ def detokenize():

def batch_inference(self, data: List[InputData], desc=""):
"""data is list of obj with id, tokens, and true length"""
data_dict = defaultdict(list)
log.info("sorting data")
for row in data:
data_dict[row.tokens.shape[0]].append(row)
data_dict[128] += data_dict[64]
data_dict[64] = []
data = []
for padded_len in np.power(2, [7, 8, 9, 10]):
log.info(f"padded len: {padded_len}, num: {len(data_dict[padded_len])}")
random.shuffle(data_dict[padded_len])
data += data_dict[padded_len]
log.info("finished sorting data")
res = defaultdict(list)

def callback(id_, token):
4 changes: 2 additions & 2 deletions MaxText/inference_mlperf/offline_mode.py
@@ -191,7 +191,7 @@

def pad_tokens(tokens):
true_length = len(tokens)
target_length = max(int(2 ** math.ceil(math.log2(true_length))), 32)
target_length = max(int(2 ** math.ceil(math.log2(true_length))), 128)
padded = tokens + [0] * (target_length - true_length)
return padded, true_length
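With the new floor of 128, the smallest bucket pad_tokens can produce lines up with the smallest bucket the batched prefill path uses (the 64-token bucket is folded into 128 on the inference side). A quick check with assumed prompt lengths:

import math

def target_length(true_length):
    return max(int(2 ** math.ceil(math.log2(true_length))), 128)

assert target_length(45) == 128    # was 64 under the old floor of 32
assert target_length(200) == 256
assert target_length(700) == 1024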

@@ -347,7 +347,6 @@ def flush_queries(self):
self.offline_inf_instances[group_idx].init_decode_state()
result = self.offline_inf_instances[group_idx].batch_inference(group, desc=f"batch-{group_idx}")
self.offline_inf_instances[group_idx].decode_state = None
gc.collect()
for key, val in result.items():
if not val:
log.info(f"Value empty for key {key}")
@@ -359,6 +358,7 @@

log.info("Flush queries end")
end = time.perf_counter()
gc.collect()

def LoadSamplesToRam(self, sample_list):
"""Pads the data, move them to jax array on device"""
2 changes: 1 addition & 1 deletion MaxText/maxengine.py
@@ -636,7 +636,7 @@ def copy(path, partial_cache, full_cache, annotations):
full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx)
## copy prefill cache
partial_cache = jax.lax.dynamic_slice(partial_cache, (0, start_idx), (1, seq_len))
partial_cache = jnp.mod(partial_cache, 2)
partial_cache = (partial_cache == partial_cache[0, 0]).astype(int)
full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
return full_cache
elif path_key == "cached_ar_lengths":
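The mask change above presumably reflects the new segment-id scheme: in the packed buffer each prompt's tokens carry their own odd id, so a slice taken for one prompt can now run into the next prompt's tokens rather than into padding. Comparing against the slice's first element keeps only the current prompt, where a plain mod-2 would also keep the neighbour. An illustrative NumPy check (values are made up):

import numpy as np

packed_segments = np.array([[1, 1, 1, 3, 3, 3, 3, 3, 0, 0]])  # two packed prompts plus padding
slice_ = packed_segments[:, 0:5]                 # prompt 0's slice overruns into prompt 1
print(np.mod(slice_, 2))                         # [[1 1 1 1 1]] -- wrongly keeps prompt 1's tokens
print((slice_ == slice_[0, 0]).astype(int))      # [[1 1 1 0 0]] -- only prompt 0 survives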
