Merge pull request #1162 from AI-Hypercomputer:sixiang-prefill-packing
PiperOrigin-RevId: 715860279
maxtext authors committed Jan 15, 2025
2 parents ed5bb31 + 67aa0a6 commit 5530f34
Showing 5 changed files with 402 additions and 25 deletions.
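This commit adds an optional prefill-packing path to the MLPerf offline harness: prompts are grouped into buckets keyed by their padded length, and once a bucket holds 1024 tokens' worth of prompts they are prefilled together in a single packed call and then inserted into individual decode slots. As orientation before the diff, here is a minimal sketch of that bucketing policy; the standalone function and helper names are hypothetical, not part of the commit, and the real logic lives in OfflineInference.batch_inference_with_callback below.

# Sketch only: hypothetical helpers, not the MaxText API. Prompts are bucketed by
# padded length; a bucket is flushed into one packed prefill once it reaches
# 1024 tokens, mirroring the len(bucket) * num_tokens == 1024 check in the diff.
from collections import defaultdict

PACKED_PREFILL_TOKENS = 1024  # packing target assumed from the checks in this PR

def route_prompts(prompts, prefill_one, prefill_packed):
    """prompts: iterable of (slot, padded_len, row); the two callbacks do the work."""
    buckets = defaultdict(list)
    for slot, padded_len, row in prompts:
        if padded_len == 64:
            prefill_one(slot, row)  # very short prompts are prefilled individually
            continue
        buckets[padded_len].append((slot, row))
        if padded_len * len(buckets[padded_len]) == PACKED_PREFILL_TOKENS:
            prefill_packed(buckets[padded_len], padded_len)
            buckets[padded_len] = []
    # leftover requests at the end of the stream fall back to individual prefill
    for bucket in buckets.values():
        for slot, row in bucket:
            prefill_one(slot, row)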
11 changes: 9 additions & 2 deletions MaxText/inference_mlperf/llama_offline_run.sh
@@ -11,6 +11,7 @@ dry_run=false
skip_warmup=false
test_run=false
enable_profiler=false
enable_batch_prefill=false
performance=true
audit=false
accuracy=false
@@ -22,6 +23,7 @@ for arg in "$@"; do
-t) test_run=true ;;
-s) skip_warmup=true ;;
-p) enable_profiler=true ;;
-c) enable_batch_prefill=true ;;
-d) audit=true ;;
-a) accuracy=true ;;
-f) fast_eval=true ;;
@@ -51,6 +53,11 @@ if "$enable_profiler"; then
PROFILER_OPTION="--enable_profile"
fi

BATCH_PREFILL_OPTION=""
if "$enable_batch_prefill"; then
BATCH_PREFILL_OPTION="--enable_batch_prefill"
fi

if [ -z "$TOKENIZER_PATH" ]; then
TOKENIZER_PATH=/home/${USER}/maxtext/assets/tokenizer.llama2
fi
@@ -90,7 +97,7 @@ export API_URL=0.0.0.0:9000
if "$test_run"; then
export DATASET_TYPE=test
export DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
export TOTAL_SAMPLE_COUNT=100
export TOTAL_SAMPLE_COUNT=1000
export USER_CONFIG=user${TOTAL_SAMPLE_COUNT}.conf
else
export DATASET_TYPE=full
@@ -131,7 +138,7 @@ run_loadgen() {
--maxengine_args "${MAXENGINE_ARGS}" \
--output_log_dir ${OUTPUT_LOG_DIR} \
--tok_outlen_multiplier ${TOK_OUTLEN_MULTIPLIER} \
${SKIP_WARMUP_OPTION} ${PROFILER_OPTION} 2>&1 | tee ${OUTPUT_LOG_DIR}/${LOADGEN_RUN_TYPE}_log.log
${SKIP_WARMUP_OPTION} ${PROFILER_OPTION} ${BATCH_PREFILL_OPTION} 2>&1 | tee ${OUTPUT_LOG_DIR}/${LOADGEN_RUN_TYPE}_log.log
}

run_loadgen_performance () {
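For reference, the new -c switch above only appends --enable_batch_prefill to the loadgen invocation, so a packed offline run can be started with something like: bash llama_offline_run.sh -r my_packed_run -c (the -r run-name flag is the one the benchmark wrapper further down already passes; treat this exact command line as an illustrative assumption rather than part of the commit).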
174 changes: 154 additions & 20 deletions MaxText/inference_mlperf/offline_inference.py
@@ -32,7 +32,6 @@
from maxengine import set_engine_vars_from_base_engine

log = logging.getLogger(__name__)
log.setLevel(os.getenv("LOGLEVEL", "INFO"))


@dataclasses.dataclass
@@ -55,7 +54,7 @@ def run(self):

class OfflineInference:

def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine):
def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine, enable_batch_prefill: bool):
self.live = False
self.engine = engine
self.decode_state = None
@@ -66,15 +65,18 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En
set_engine_vars_from_base_engine(engine, base_engine, rng)
self.params = params

self.enable_batch_prefill = enable_batch_prefill
self.batch_size = engine.max_concurrent_decodes
self.max_decode_length = engine.config.max_target_length - engine.config.max_prefill_predict_length
metadata = engine.get_tokenizer()
self.tokenizer = engine.build_tokenizer(metadata)
self.dummy = False

self._cached_pref = {}
self._cached_pref_batch = {}
self._cached_generate = None
self.detokenize_backlog = queue.Queue(10)
self.prefill_buckets = defaultdict(list)

def init_decode_state(self):
if self.decode_state is None:
@@ -83,7 +85,6 @@ def init_decode_state(self):
def warmup(self, max_length, warmup_samples):
self.init_decode_state()
interesting_buckets = [
32,
64,
128,
256,
@@ -102,6 +103,24 @@ def warmup(self, max_length, warmup_samples):
.lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
.compile()
)
# input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
# example_seq_len=16
# num_prompts = max_length//length
# self._cached_pref_batch[length] = (
# jax.jit(self._prefill_insert_batch, donate_argnums=(4,))
# .lower(
# self.params,
# tokens=input_data_batch,
# slots=jnp.arange(0, example_seq_len),
# num_prompts = 16,
# decoder_positions = jnp.arange(0, max_length),
# decoder_segment_ids = jnp.ones(max_length),
# start_pos=jnp.arange(0, max_length, max_length//example_seq_len),
# padded_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
# true_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
# decode_state=self.decode_state)
# .compile()
# )
self.batch_inference(warmup_samples, desc="warmup")
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
@@ -110,9 +129,48 @@
def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
"""return decodestate."""
prefill_result, first_token = self.engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = self.engine.insert(prefill_result, decode_state, slot=slot)
decode_state = self.engine.insert(prefill_result, decode_state, slot)
return first_token, decode_state

def _prefill_insert_batch(
self,
params,
tokens,
slots,
num_prompts,
decoder_positions,
decoder_segment_ids,
start_pos,
padded_lengths,
true_lengths,
decode_state,
):
"""return decodestate."""
prefill_results, first_tokens = self.engine.prefill_concat(
params=params,
padded_tokens=tokens,
decoder_positions=decoder_positions,
decoder_segment_ids=decoder_segment_ids,
start_pos=start_pos,
true_lengths=true_lengths,
num_prompts=num_prompts,
)
# decode_state = jax.lax.fori_loop(
# 0, num_prompts,
# lambda i, state: self.engine.insert(
# prefill_results[i],
# state,
# slot=slots[i],
# start_idx = start_pos[i],
# seq_len = padded_lengths[i]),
# decode_state
# )
for i in range(num_prompts):
decode_state = self.engine.insert_partial(
prefill_results[i], decode_state, slots[i], start_idx=start_pos[i].item(), seq_len=padded_lengths[i].item()
)
return first_tokens, decode_state

def batch_inference_with_callback(
self,
data: List[InputData],
@@ -125,20 +183,73 @@
token.
"""

def prefill(slot, tokens, true_length):
def prefill(prefill_bucket, prefill_len):
nonlocal self
if self.dummy:
log.info("dummy prefill")
return 123
if not self.enable_batch_prefill or prefill_len * len(prefill_bucket) != 1024:
prefill_result = []
prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(prefill_len)) is not None:
prefill_fn = cached
for slot, row in prefill_bucket:
first_token, self.decode_state = prefill_fn(
self.params, tokens=row.tokens, slot=slot, true_length=row.true_length, decode_state=self.decode_state
)
prefill_result.append((first_token, slot, row))
return prefill_result
else:
prefill_fn = self._prefill_insert_batch
if (cached := self._cached_pref_batch.get(prefill_len)) is not None:
prefill_fn = cached
positions = np.concatenate([np.arange(0, row.tokens.shape[0]) for (slot, row) in prefill_bucket])
positions = jnp.array(positions)

sequence_indicators = []
for idx, (slot, row) in enumerate(prefill_bucket):
zero_to_n = np.arange(0, row.tokens.shape[0])
ones_to_keep = zero_to_n < row.true_length
one_d_output = (zero_to_n < row.true_length).astype(int) * (idx * 2 + 1) + (zero_to_n >= row.true_length).astype(
int
) * (idx + 1) * 2
sequence_indicators.append(one_d_output)
sequence_indicator = jnp.array(np.concatenate(sequence_indicators))

tokens = jnp.concat([row.tokens for (slot, row) in prefill_bucket])

slots = [slot for (slot, row) in prefill_bucket]
padded_lengths = [row.tokens.shape[0] for (slot, row) in prefill_bucket]
true_lengths = [row.true_length for (slot, row) in prefill_bucket]
start_pos = np.cumsum([0] + [row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
start_pos = start_pos.tolist()

# pad slots to keep static shape of jitted function input
def pad_num_prompts_len_array(array_to_pad, pad_len):
if len(array_to_pad) < pad_len:
array_to_pad.extend([0] * (pad_len - len(array_to_pad)))
return jnp.array(array_to_pad)

slots = pad_num_prompts_len_array(slots, 16)
padded_lengths = pad_num_prompts_len_array(padded_lengths, 16)
true_lengths = pad_num_prompts_len_array(true_lengths, 16)
start_pos = pad_num_prompts_len_array(start_pos, 16)

first_tokens, self.decode_state = prefill_fn(
self.params,
tokens=tokens,
slots=slots,
num_prompts=len(prefill_bucket),
decoder_positions=positions,
decoder_segment_ids=sequence_indicator,
start_pos=start_pos,
padded_lengths=padded_lengths,
true_lengths=true_lengths,
decode_state=self.decode_state,
)
prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]

prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(len(tokens))) is not None:
prefill_fn = cached

first_token, self.decode_state = prefill_fn(
self.params, tokens=tokens, slot=slot, true_length=true_length, decode_state=self.decode_state
)
return first_token
return prefill_result

empty_slots = list(range(self.batch_size))
slot_to_id = {}
@@ -170,6 +281,7 @@ def decode():
self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
result_tokens_l.append(result_tokens)
for i in range(5):
# result_tokens.copy_to_host_async()
result_tokens = result_tokens_l[i].convert_to_numpy()
self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
# log.info(f"Decode put result {i} to queue")
@@ -182,6 +294,7 @@ def detokenize():
# log.info("Detokenize start")
newly_empty = []
result_tokens, is_first_token, row_id, _slot = self.detokenize_backlog.get(block=True)
# result_tokens = result_tokens.convert_to_numpy()
# log.info("Detokenize get from queue")
if is_first_token:
first_token = result_tokens.data[0][0].item()
@@ -199,7 +312,7 @@ def detokenize():
should_finish = emit_token(id_, token.item())
if should_finish or length >= self.max_decode_length:
newly_empty.append(slot)
log.info(f"Detokenize free up {slot}, length {length}")
log.debug(f"Detokenize free up {slot}, length {length}")
# Add slots of those that are empty to empty
for slot in newly_empty:
del slot_to_id[slot]
@@ -215,6 +328,7 @@ def detokenize():
)
self.live = True
detokenize_thread.start()
total_num_prefills = 0
for row in data:
while not empty_slots:
# If slots are all full, decode until there are free slots
@@ -224,16 +338,36 @@
decode()
# do one insert
num_tokens = len(row.tokens)
num_prefills[num_tokens] = 0 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
num_prefills[num_tokens] = 1 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
log.info(
f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
)
total_num_prefills += 1
log.info(f"Total num prefill: {total_num_prefills}")
slot = empty_slots.pop()
first_token = prefill(slot, row.tokens, row.true_length)
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)

# directly prefill prompts with 64 or less tokens
if num_tokens == 64:
first_token, slot, row = prefill([(slot, row)], 64)[0]
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
continue
self.prefill_buckets[num_tokens].append((slot, row))
prefill_buckets_len = {k: len(self.prefill_buckets[k]) for k in self.prefill_buckets}
log.debug(f"prefill buckets {prefill_buckets_len}")
if len(self.prefill_buckets[num_tokens]) * num_tokens == 1024:
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for first_token, slot, row in prefill_results:
log.debug(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
self.prefill_buckets[num_tokens] = []
# For leftover requests in buckets at the end of computation, do prefill individually.
for num_tokens in self.prefill_buckets.keys():
prefill_results = prefill(self.prefill_buckets[num_tokens], num_tokens)
for first_token, slot, row in prefill_results:
log.debug(f"Put row of len {row.tokens.shape[0]} true length {row.true_length} slot {slot} to detokenize backlog")
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
self.prefill_buckets = defaultdict(list)
while slot_to_id:
log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
log.debug(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
num_decodes += 1
decode()

@@ -248,7 +382,7 @@ def batch_inference(self, data: List[InputData], desc=""):
def callback(id_, token):
nonlocal res
if token == self.tokenizer.eos_id:
log.info(f"res[{id_}] eos")
log.debug(f"res[{id_}] eos")
if not res[id_] or res[id_][-1] != self.tokenizer.eos_id:
res[id_].append(token)
return token == self.tokenizer.eos_id
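The packed path relies on per-prompt decoder segment ids inside the concatenated buffer: real tokens of the idx-th prompt get the odd id 2*idx+1 and its padding gets the even id 2*(idx+1), which is exactly what the one_d_output expression computes. A standalone sketch of that encoding, NumPy only and with a hypothetical function name, follows.

# Sketch of the segment-id encoding used by the packed prefill (hypothetical
# wrapper name; mirrors the one_d_output expression in the diff above).
import numpy as np

def pack_segment_ids(padded_lengths, true_lengths):
    """Real tokens of prompt idx -> 2*idx + 1, its padding -> 2*(idx + 1)."""
    segments = []
    for idx, (padded, true) in enumerate(zip(padded_lengths, true_lengths)):
        positions = np.arange(padded)
        real = (positions < true).astype(int) * (2 * idx + 1)
        pad = (positions >= true).astype(int) * (2 * (idx + 1))
        segments.append(real + pad)
    return np.concatenate(segments)

# Two prompts padded to 4 tokens with 3 and 2 real tokens.
# Expected segment ids: 1 1 1 2 3 3 4 4
print(pack_segment_ids([4, 4], [3, 2]))

Giving each prompt (and each padding stretch) a distinct id presumably lets prefill_concat mask attention across prompt boundaries while still running the whole 1024-token buffer in a single forward pass.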
9 changes: 8 additions & 1 deletion MaxText/inference_mlperf/offline_mode.py
@@ -146,6 +146,13 @@
required=False,
)

flags.DEFINE_bool(
"enable_batch_prefill",
False,
"If set, enable batch prefilling.",
required=False,
)

flags.DEFINE_bool(
"skip_warmup",
False,
@@ -463,7 +470,7 @@ def main(argv):
max_target_length=target_length,
args_str=FLAGS.maxengine_args,
)
offline_inf = offline_inference.OfflineInference(engine, params, base_engine)
offline_inf = offline_inference.OfflineInference(engine, params, base_engine, FLAGS.enable_batch_prefill)
if params is None and offline_inf.params is not None:
base_engine = engine
params = offline_inf.params
@@ -7,6 +7,7 @@
run_name="trillium_llama2-70b"
dry_run=false
enable_profiler=false
enable_batch_prefill=false
enable_xla_flags=false
single_bucket=false
token_multiplier=3.0
@@ -36,6 +37,7 @@ for arg in "$@"; do
-t) test_mode=true ;;
-s) single_bucket=true ;;
-x) enable_xla_flags=true ;;
-c) enable_batch_prefill=true ;;
-r=*|--run=*) run_name="${arg#*=}" ;;
-r|--run) shift; run_name="$1" ;;
-m=*|--multiplier=*) token_multiplier="${arg#*=}" ;;
@@ -68,6 +70,10 @@ if "$test_mode"; then
RUN_OPTIONS="${RUN_OPTIONS} -t "
fi

if "$enable_batch_prefill"; then
RUN_OPTIONS="${RUN_OPTIONS} -c "
fi

if "$single_bucket"; then
export BATCH_AND_PREFILL_LEN="1024,54"
else
@@ -104,7 +110,7 @@ run_benchmark() {
local type=$1
case "$type" in
"performance")
$cmd bash llama_offline_run.sh -r benchmarks_performance_${RUN_DESC} ${RUN_OPTIONS}
$cmd bash llama_offline_run.sh ${RUN_OPTIONS} -r benchmarks_performance_${RUN_DESC}
;;
"audit")
$cmd bash llama_offline_run.sh -r benchmarks_audit_${RUN_DESC} -d
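Taken together, the new option is a straight pass-through: the wrapper's -c switch adds -c to RUN_OPTIONS, llama_offline_run.sh turns that into --enable_batch_prefill, offline_mode.py exposes it as an absl flag, and OfflineInference receives it as the enable_batch_prefill constructor argument that gates the packed prefill path. Enabling packing for a benchmark run is therefore just a matter of adding -c (optionally alongside -x for the tuned XLA flags) when invoking the wrapper script.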
