From 653be9cc56e46fd046f1e614b2c9088fe4118ef8 Mon Sep 17 00:00:00 2001
From: Xiang Si
Date: Fri, 17 Jan 2025 21:58:52 +0000
Subject: [PATCH] enable concatenation of unpadded prompts

---
 MaxText/inference_mlperf/offline_inference.py | 181 +++++++++++-------
 MaxText/inference_mlperf/offline_mode.py      |   4 +-
 MaxText/maxengine.py                          |   2 +-
 3 files changed, 119 insertions(+), 68 deletions(-)

diff --git a/MaxText/inference_mlperf/offline_inference.py b/MaxText/inference_mlperf/offline_inference.py
index 6c993675a..d5573ffa2 100644
--- a/MaxText/inference_mlperf/offline_inference.py
+++ b/MaxText/inference_mlperf/offline_inference.py
@@ -24,6 +24,7 @@
 import threading
 import traceback
 import signal
+import random
 
 from jetstream.engine import engine_api
 
@@ -67,6 +68,7 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En
     self.enable_batch_prefill = enable_batch_prefill
     self.batch_size = engine.max_concurrent_decodes
+    self.max_prefill_length = engine.config.max_prefill_predict_length
     self.max_decode_length = engine.config.max_target_length - engine.config.max_prefill_predict_length
     metadata = engine.get_tokenizer()
     self.tokenizer = engine.build_tokenizer(metadata)
@@ -105,32 +107,35 @@ def warmup(self, max_length, warmup_samples):
       )
       if length == 64 or length == 1024:
         continue
-      log.info(f"Compiling batched prefill: {length}")
       input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
-      num_prompts = max_length // length
-      self._cached_pref_batch[length] = (
-          jax.jit(
-              self._prefill_insert_batch,
-              static_argnames=(
-                  "num_prompts",
-                  "padded_length",
-              ),
-              donate_argnames=("decode_state",),
-          )
-          .lower(
-              self.params,
-              tokens=input_data_batch,
-              slots=jnp.arange(0, 8, dtype=int),
-              num_prompts=num_prompts,
-              decoder_positions=jnp.arange(0, max_length, dtype=int),
-              decoder_segment_ids=jnp.ones(max_length, dtype=int),
-              start_pos=jnp.arange(0, max_length, 128, dtype=int),
-              padded_length=length,
-              true_lengths=jnp.full(8, length, dtype=int),
-              decode_state=self.decode_state,
-          )
-          .compile()
-      )
+      min_num_prompts = max_length // length
+      max_num_prompts = max_length // (length // 2)
+      possible_prompts = range(min_num_prompts, max_num_prompts)
+      for num_prompts in possible_prompts:
+        log.info(f"Compiling batched prefill: {length} num_prompts: {num_prompts}")
+        self._cached_pref_batch[(length, num_prompts)] = (
+            jax.jit(
+                self._prefill_insert_batch,
+                static_argnames=(
+                    "num_prompts",
+                    "padded_length",
+                ),
+                donate_argnames=("decode_state",),
+            )
+            .lower(
+                self.params,
+                tokens=input_data_batch,
+                slots=jnp.arange(0, 16, dtype=int),
+                num_prompts=num_prompts,
+                decoder_positions=jnp.arange(0, max_length, dtype=int),
+                decoder_segment_ids=jnp.ones(max_length, dtype=int),
+                start_pos=jnp.arange(0, max_length, 64, dtype=int),
+                padded_length=length,
+                true_lengths=jnp.full(16, length, dtype=int),
+                decode_state=self.decode_state,
+            )
+            .compile()
+        )
     self._cached_generate = (
         jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
     )
@@ -194,7 +199,11 @@ def prefill(prefill_bucket, prefill_len):
       if self.dummy:
         log.info("dummy prefill")
         return 123
-      if not self.enable_batch_prefill or prefill_len in (64, 1024) or prefill_len * len(prefill_bucket) != 1024:
+      if (
+          not self.enable_batch_prefill
+          or prefill_len == self.max_prefill_length
+          or prefill_len * len(prefill_bucket) < self.max_prefill_length
+      ):
         prefill_result = []
         prefill_fn = self._prefill_insert
         if (cached := self._cached_pref.get(prefill_len)) is not None:
@@ -206,27 +215,25 @@ def prefill(prefill_bucket, prefill_len):
           prefill_result.append((first_token, slot, row))
         return prefill_result
       else:
-        prefill_fn = self._prefill_insert_batch
-        if (cached := self._cached_pref_batch.get(prefill_len)) is not None:
-          prefill_fn = cached
-        positions = np.concatenate([np.arange(0, row.tokens.shape[0]) for (slot, row) in prefill_bucket])
-        positions = jnp.array(positions)
-
+        num_prompts = len(prefill_bucket)
         sequence_indicators = []
+        total_len = 0
         for idx, (slot, row) in enumerate(prefill_bucket):
-          zero_to_n = np.arange(0, row.tokens.shape[0])
-          ones_to_keep = zero_to_n < row.true_length
-          one_d_output = (zero_to_n < row.true_length).astype(int) * (idx * 2 + 1) + (zero_to_n >= row.true_length).astype(
-              int
-          ) * (idx + 1) * 2
-          sequence_indicators.append(one_d_output)
+          sequence_indicators.append(np.full(row.true_length, idx * 2 + 1, dtype=int))
+          total_len += row.true_length
+        sequence_indicators.append(np.zeros(self.max_prefill_length - total_len, dtype=int))
         sequence_indicator = jnp.array(np.concatenate(sequence_indicators))
-        tokens = jnp.concat([row.tokens for (slot, row) in prefill_bucket])
+        positions = [np.arange(0, row.true_length) for (slot, row) in prefill_bucket]
+        positions.append(np.arange(0, self.max_prefill_length - total_len))
+        positions = jnp.array(np.concatenate(positions))
+        tokens = [row.tokens[: row.true_length] for (slot, row) in prefill_bucket]
+        tokens.append(jnp.zeros(self.max_prefill_length - total_len, dtype=int))
+        tokens = jnp.concat(tokens)
         slots = [slot for (slot, row) in prefill_bucket]
         true_lengths = [row.true_length for (slot, row) in prefill_bucket]
-        start_pos = np.cumsum([0] + [row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
+        start_pos = np.cumsum([0] + [row.true_length for (slot, row) in prefill_bucket])[:-1]
         start_pos = start_pos.tolist()
 
       # pad slots to keep static shape of jitted function input
@@ -235,10 +242,15 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
         array_to_pad.extend([0] * (pad_len - len(array_to_pad)))
         return jnp.array(array_to_pad)
 
-      slots = pad_num_prompts_len_array(slots, 8)
-      true_lengths = pad_num_prompts_len_array(true_lengths, 8)
-      start_pos = pad_num_prompts_len_array(start_pos, 8)
-      # this lowered function has static input for num_prompts and padded_length
+      slots = pad_num_prompts_len_array(slots, 16)
+      true_lengths = pad_num_prompts_len_array(true_lengths, 16)
+      start_pos = pad_num_prompts_len_array(start_pos, 16)
+
+      prefill_fn = self._prefill_insert_batch
+      log.info(f"invoking compiled function with length {prefill_len} num_prompts {num_prompts}")
+      if (cached := self._cached_pref_batch.get((prefill_len, num_prompts))) is not None:
+        prefill_fn = cached
+
       first_tokens, self.decode_state = prefill_fn(
           self.params,
           tokens=tokens,
@@ -248,11 +260,18 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
           start_pos=start_pos,
           true_lengths=true_lengths,
           decode_state=self.decode_state,
-      ) # pytype: disable=missing-parameter
+      )  # pytype: disable=missing-parameter
       prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]
       return prefill_result
 
+    def prefill_batch(prefill_bucket, padded_len):
+      nonlocal self
+      prefill_results = prefill(prefill_bucket, padded_len)
+      for _first_token, _slot, _row in prefill_results:
+        log.info(f"Put row of len {_row.tokens.shape[0]} true length {_row.true_length} slot {_slot} to detokenize backlog")
+        self.detokenize_backlog.put((_first_token, True, _row.id, _slot), block=True)
+
     empty_slots = list(range(self.batch_size))
     slot_to_id = {}
     num_prefills = {}
@@ -339,37 +358,57 @@ def detokenize():
         log.info(f"decode-{desc}-{num_decodes}")
         decode()
       # do one insert
-      num_tokens = len(row.tokens)
-      num_prefills[num_tokens] = 1 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
-      log.info(
-          f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
+      padded_len = len(row.tokens)
+      num_prefills[padded_len] = 1 if padded_len not in num_prefills else num_prefills[padded_len] + 1
+      log.debug(
+          f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} padded_len {padded_len} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
       )
       total_num_prefills += 1
       log.info(f"Total num prefill: {total_num_prefills}")
       slot = empty_slots.pop()
-      # directly prefill prompts with 64 or less tokens, and with 1024 tokens
-      if num_tokens in (64, 1024) or not self.enable_batch_prefill:
-        first_token, slot, row = prefill([(slot, row)], num_tokens)[0]
+      # directly prefill prompts
+      if not self.enable_batch_prefill:
+        first_token, slot, row = prefill([(slot, row)], padded_len)[0]
+        self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
+        continue
+
+      if self.prefill_buckets[padded_len // 2] != []:
+        prefill_batch(self.prefill_buckets[padded_len // 2], padded_len // 2)
+        self.prefill_buckets[padded_len // 2] = []
+      if padded_len == self.max_prefill_length:
+        first_token, slot, row = prefill([(slot, row)], padded_len)[0]
         self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
         continue
-      self.prefill_buckets[num_tokens].append((slot, row))
+      if padded_len == 64:
+        row.tokens = jnp.concat([row.tokens, jnp.zeros(64, dtype=int)])
+        padded_len = 128
+
+      self.prefill_buckets[padded_len].append((slot, row))
       prefill_buckets_len = {k: len(self.prefill_buckets[k]) for k in self.prefill_buckets}
       log.debug(f"prefill buckets {prefill_buckets_len}")
-      if len(self.prefill_buckets[num_tokens]) * num_tokens == 1024:
-        prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
-        for first_token, slot, row in prefill_results:
-          log.debug(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
-          self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
-        self.prefill_buckets[num_tokens] = []
+      if len(self.prefill_buckets[padded_len]) * padded_len >= self.max_prefill_length:
+        total_true_len = sum([row.true_length for (slot, row) in self.prefill_buckets[padded_len]])
+        # Can't hold another buffer, prefill right away
+        if total_true_len > self.max_prefill_length - padded_len // 2 and total_true_len <= self.max_prefill_length:
+          log.debug(
+              f"Normal batch {padded_len} total padded len {len(self.prefill_buckets[padded_len]) * padded_len} total true len {total_true_len}"
+          )
+          prefill_batch(self.prefill_buckets[padded_len], padded_len)
+          self.prefill_buckets[padded_len] = []
+        # Already overloaded, hold back the last request and prefill the rest
+        elif total_true_len > self.max_prefill_length:
+          log.debug(
+              f"Overloading {padded_len} total padded len {len(self.prefill_buckets[padded_len]) * padded_len} total true len {total_true_len}"
+          )
+          current = self.prefill_buckets[padded_len][-1]
+          prefill_batch(self.prefill_buckets[padded_len][:-1], padded_len)
+          self.prefill_buckets[padded_len] = [current]
 
     # For leftover requests in buckets at the end of computation, do prefill individually.
-    for num_tokens in self.prefill_buckets.keys():
-      prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
-      for first_token, slot, row in prefill_results:
-        log.debug(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
-        self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
+    for padded_len in self.prefill_buckets.keys():
+      prefill_batch(self.prefill_buckets[padded_len], padded_len)
     self.prefill_buckets = defaultdict(list)
     while slot_to_id:
-      log.debug(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
+      log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
       num_decodes += 1
       decode()
@@ -379,6 +418,18 @@ def detokenize():
 
   def batch_inference(self, data: List[InputData], desc=""):
     """data is list of obj with id, tokens, and true length"""
+    data_dict = defaultdict(list)
+    log.info("sorting data")
+    for row in data:
+      data_dict[row.tokens.shape[0]].append(row)
+    data_dict[128] += data_dict[64]
+    data_dict[64] = []
+    data = []
+    for padded_len in np.power(2, [7, 8, 9, 10]):
+      log.info(f"padded len: {padded_len}, num: {len(data_dict[padded_len])}")
+      random.shuffle(data_dict[padded_len])
+      data += data_dict[padded_len]
+    log.info("finished sorting data")
     res = defaultdict(list)
 
     def callback(id_, token):
diff --git a/MaxText/inference_mlperf/offline_mode.py b/MaxText/inference_mlperf/offline_mode.py
index ef67cd0e5..4b07626f1 100644
--- a/MaxText/inference_mlperf/offline_mode.py
+++ b/MaxText/inference_mlperf/offline_mode.py
@@ -191,7 +191,7 @@ def pad_tokens(tokens):
   true_length = len(tokens)
-  target_length = max(int(2 ** math.ceil(math.log2(true_length))), 32)
+  target_length = max(int(2 ** math.ceil(math.log2(true_length))), 128)
   padded = tokens + [0] * (target_length - true_length)
   return padded, true_length
 
@@ -347,7 +347,6 @@ def flush_queries(self):
       self.offline_inf_instances[group_idx].init_decode_state()
       result = self.offline_inf_instances[group_idx].batch_inference(group, desc=f"batch-{group_idx}")
       self.offline_inf_instances[group_idx].decode_state = None
-      gc.collect()
       for key, val in result.items():
         if not val:
           log.info(f"Value empty for key {key}")
@@ -359,6 +358,7 @@ def flush_queries(self):
     log.info("Flush queries end")
     end = time.perf_counter()
+    gc.collect()
 
   def LoadSamplesToRam(self, sample_list):
     """Pads the data, move them to jax array on device"""
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index 12bbab5b2..f7e006571 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -636,7 +636,7 @@ def copy(path, partial_cache, full_cache, annotations):
           full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, zeros, slot, batch_idx)
         ## copy prefill cache
         partial_cache = jax.lax.dynamic_slice(partial_cache, (0, start_idx), (1, seq_len))
-        partial_cache = jnp.mod(partial_cache, 2)
+        partial_cache = (partial_cache == partial_cache[0, 0]).astype(int)
         full_cache = jax.lax.dynamic_update_index_in_dim(full_cache, partial_cache, slot, batch_idx)
         return full_cache
       elif path_key == "cached_ar_lengths":
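
The standalone sketch below is not part of the patch itself; it only illustrates the packing scheme that the reworked prefill() path builds by hand: each prompt contributes just its true-length tokens, prompt idx is tagged with decoder segment id idx * 2 + 1, positions restart at 0 for every prompt, and the remainder of the max_prefill_length window is zero-padded with segment id 0 so it is treated as padding. The helper name pack_prompts, the MAX_PREFILL_LENGTH constant, and the pure-numpy setup are illustrative assumptions, not MaxText or JetStream APIs.

# Illustrative sketch (assumption: not MaxText code) of packing several
# unpadded prompts into one prefill window, mirroring the logic in prefill().
import numpy as np

MAX_PREFILL_LENGTH = 1024  # assumption: mirrors max_prefill_predict_length


def pack_prompts(prompts, max_prefill_length=MAX_PREFILL_LENGTH):
  """prompts: list of 1-D int arrays holding only true-length tokens."""
  tokens, positions, segment_ids, start_pos = [], [], [], []
  total_len = 0
  for idx, prompt in enumerate(prompts):
    start_pos.append(total_len)            # where this prompt starts in the window
    tokens.append(prompt)                  # only true-length tokens, no per-prompt padding
    positions.append(np.arange(prompt.shape[0]))            # positions restart per prompt
    segment_ids.append(np.full(prompt.shape[0], idx * 2 + 1, dtype=int))
    total_len += prompt.shape[0]
  assert total_len <= max_prefill_length, "bucket overflows the prefill window"
  pad = max_prefill_length - total_len
  tokens.append(np.zeros(pad, dtype=int))
  positions.append(np.arange(pad))
  segment_ids.append(np.zeros(pad, dtype=int))  # segment id 0 marks padding
  return (
      np.concatenate(tokens),
      np.concatenate(positions),
      np.concatenate(segment_ids),
      start_pos,
  )


if __name__ == "__main__":
  prompts = [np.arange(1, 5), np.arange(1, 4)]  # true lengths 4 and 3
  toks, pos, seg, starts = pack_prompts(prompts, max_prefill_length=16)
  print(toks)    # [1 2 3 4 1 2 3 0 0 0 0 0 0 0 0 0]
  print(pos)     # [0 1 2 3 0 1 2 0 1 2 3 4 5 6 7 8]
  print(seg)     # [1 1 1 1 3 3 3 0 0 0 0 0 0 0 0 0]
  print(starts)  # [0, 4]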