make lance's change work for mixtral
qihqi committed Jul 23, 2024
1 parent 4fd8117 commit e1a6068
Showing 4 changed files with 145 additions and 6 deletions.
19 changes: 18 additions & 1 deletion benchmarks/mixtral_offline.sh
@@ -1,11 +1,23 @@
CACHE_LENGTH=$1
BATCH_SIZE=$2
INPUT_SIZE=1024
OUTPUT_SIZE=1024
CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
pushd ..
python -m benchmarks.run_offline \
--lazy_cache_update=1 \
--ring_buffer=0 \
--model_name=mixtral \
--batch_size=$BATCH_SIZE \
--max_cache_length=$CACHE_LENGTH \
@@ -17,4 +29,9 @@ python -m benchmarks.run_offline \
--quantize_type=int8_per_channel \
--quantize_kv_cache=1 \
--profiling_output=/mnt/disks/hanq/mixtral-profiles
popd
echo "batch was $2 cache was $1"
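The script takes its cache length and batch size as positional arguments ($1 and $2). A hypothetical invocation, with sweep values borrowed from the scenarios in offline_benchmark.py below:

bash mixtral_offline.sh 3072 256   # $1 = cache length, $2 = batch size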
117 changes: 117 additions & 0 deletions benchmarks/offline_benchmark.py
@@ -0,0 +1,117 @@
import math
import pandas as pd
import dataclasses
from collections import defaultdict
from absl import flags, app

from typing import Dict

FLAGS = flags.FLAGS

flags.DEFINE_string('dataset_path', '',
                    'Path to a pickled pandas DataFrame with tok_input_len '
                    'and tok_ref_output_len columns.')

@dataclasses.dataclass
class Stat:
  cache_size: int
  batch_size: int
  prefill_times: Dict[int, float]
  decode_time: float

# Measured prefill and decode times for each (cache_size, batch_size) bucket.
scenario1 = [
    Stat(
        cache_size=512,
        batch_size=2048,
        prefill_times={
            16: 0.016024088603444397,
            32: 0.021154335999926843,
            64: 0.02999803279999469,
            128: 0.043986773600045125,
            256: 0.07524209819985117,
            512: 0.13882793779994246,
        },
        decode_time=0.28033976474989686,
    ),
    Stat(
        cache_size=1280,
        batch_size=512,
        prefill_times={
            16: 0.016024088603444397,
            32: 0.020686019999993734,
            64: 0.02952769919993443,
            128: 0.04383329960000992,
            256: 0.07538782240008005,
            512: 0.13893127239989553,
            1024: 0.2693996697998955,
        },
        decode_time=0.11505070800001249,
    ),
    Stat(
        cache_size=3072,
        batch_size=256,
        prefill_times={
            32: 0.021193669800049976,
            64: 0.030565194799964956,
            128: 0.04334795760005363,
            256: 0.07586566419995507,
            512: 0.13899565000010625,
            1024: 0.26945373279995694,
            2048: 0.35605709000010394,
        },
        decode_time=0.06467210225014242,
    ),
]

# All three buckets currently share the same measured stats.
_scenario2_stat = Stat(
    cache_size=3072,
    batch_size=256,
    prefill_times={
        16: 0.018725800199899823,
        32: 0.02242145979980705,
        64: 0.02536285559981479,
        128: 0.034608948799723295,
        256: 0.0560826786000689,
        512: 0.10566568380017997,
        1024: 0.20719572800007882,
    },
    decode_time=0.0631,
)
scenario2 = [_scenario2_stat] * 3


def eval_scenario(dataset, scenario):
  total_input_tokens = 0
  total_output_tokens = 0
  total_prefill_times = defaultdict(float)
  total_decode_times = defaultdict(float)
  output_tokens_by_bucket = defaultdict(int)
  for _, data in dataset.iterrows():
    stat = scenario[data.bucket]
    total_input_tokens += data.tok_input_len
    total_output_tokens += data.tok_ref_output_len
    # Round the input length up to the next power of two to select a
    # prefill bucket.
    input_len_bucket = 2**math.ceil(math.log2(data.tok_input_len))
    total_prefill_times[input_len_bucket] += stat.prefill_times[input_len_bucket]
    output_tokens_by_bucket[data.bucket] += data.tok_ref_output_len

  for k in output_tokens_by_bucket.keys():
    stat = scenario[k]
    # Decode emits one token per sequence per step, so the per-token cost is
    # decode_time amortized over the batch.
    total_decode_times[k] = (
        output_tokens_by_bucket[k] / stat.batch_size * stat.decode_time
    )

  prefill_total = sum(total_prefill_times.values())
  decode_total = sum(total_decode_times.values())
  print('Total input tokens', total_input_tokens)
  print('Total output tokens', total_output_tokens)
  print('Input / output', total_input_tokens / total_output_tokens)
  print('Prefill times', total_prefill_times)
  print('Prefill throughput', total_input_tokens / prefill_total)
  print('Decode times', total_decode_times)
  print('Decode throughput', total_output_tokens / decode_total)
  print('Overall throughput',
        total_output_tokens / (decode_total + prefill_total))
  print('Prefill total time', prefill_total)
  print('Decode total time', decode_total)



def main(argv):
  dataset = pd.read_pickle(FLAGS.dataset_path)
  total_len = dataset.tok_input_len + dataset.tok_ref_output_len
  # Bucket assignment: +1 if the total length exceeds 512, and +1 more if the
  # total length exceeds 1280 or the input alone exceeds 1024.
  bucket = 0 + (total_len > 512) + ((total_len > 1280) | (dataset.tok_input_len > 1024))
  dataset.insert(2, 'bucket', bucket)
  print('======== scenario 1 ========')
  eval_scenario(dataset, scenario1)
  print('======== scenario 2 ========')
  eval_scenario(dataset, scenario2)


if __name__ == '__main__':
  app.run(main)
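To make the cost model concrete, here is a hedged walk-through of how one hypothetical dataset row (300 input tokens, 700 reference output tokens; values invented for illustration) is accounted for, using the scenario1 numbers above:

import math

tok_input_len, tok_ref_output_len = 300, 700
total_len = tok_input_len + tok_ref_output_len                # 1000 -> bucket 1
bucket = 0 + (total_len > 512) + ((total_len > 1280) or (tok_input_len > 1024))
input_len_bucket = 2 ** math.ceil(math.log2(tok_input_len))   # 300 rounds up to 512

# scenario1[1]: batch_size=512, prefill_times[512]=0.1389..., decode_time=0.1150...
prefill_cost = 0.13893127239989553                            # one 512-token prefill
decode_cost = tok_ref_output_len / 512 * 0.11505070800001249  # ~0.157, amortized over the batch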


1 change: 1 addition & 0 deletions benchmarks/run_offline.py
@@ -77,6 +77,7 @@ def run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
256: 23.59,
512: 35.28,
1024: 60.28,
2048: 60.28,
}


14 changes: 9 additions & 5 deletions jetstream_pt/third_party/mixtral/model.py
@@ -77,11 +77,15 @@ def forward(
     bsz, seqlen = idx.shape
     freqs_cis = self.freqs_cis[input_pos]
     freqs_cis = freqs_cis.reshape(bsz, seqlen, -1)
-    assert len(caches) == len(
-        self.layers
-    ), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match"
-    for layer, cache in zip(self.layers, caches):
-      with jax.named_scope("TransformerBlock"):
+
+    for layer_id, layer in enumerate(self.layers):
+      if caches[0].stacked:
+        # For the stacked case, there is only one layer of kv cache.
+        cache = caches[0]
+      else:
+        cache = caches[layer_id]
+
+      with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)):
         x = layer(
             x,
             freqs_cis,
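The new loop distinguishes a "stacked" cache (one object holding every layer's kv cache) from a per-layer list. A minimal, self-contained sketch of that selection pattern, with a stand-in Cache class since the real kv-cache type lives elsewhere in jetstream_pt:

import dataclasses
from typing import Any, List

@dataclasses.dataclass
class Cache:
  # Stand-in for the real kv-cache object; only the flag matters here.
  stacked: bool
  data: Any = None

def pick_cache(caches: List[Cache], layer_id: int) -> Cache:
  if caches[0].stacked:
    # Stacked: a single cache object is shared and indexed by layer internally.
    return caches[0]
  # Unstacked: one cache object per transformer layer.
  return caches[layer_id]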
