From 936b922f70e1d7d8dda185fb25a823a7a46efa4c Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Tue, 23 Jul 2024 21:37:08 +0000
Subject: [PATCH] activation quant

---
 jetstream_pt/third_party/mixtral/model.py | 41 ++++++++++++++++++++---
 mlperf/backend.py                         |  1 +
 mlperf/benchmark_run.sh                   |  4 +--
 mlperf/mlperf.conf                        |  2 +-
 mlperf/start_server.sh                    | 12 ++++---
 mlperf/user.conf                          |  2 +-
 6 files changed, 49 insertions(+), 13 deletions(-)

diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py
index 1602e886..422f4990 100644
--- a/jetstream_pt/third_party/mixtral/model.py
+++ b/jetstream_pt/third_party/mixtral/model.py
@@ -22,8 +22,10 @@ from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
+from jetstream_pt import quantize, torchjax
 
 import jax
+import jax.numpy as jnp
 
 
 class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     else:
       return self.forward_for_short_seq_len(x, expert_indices)
 
+  def _int_ti_eoi_teo(self, lhs, rhs):
+    # x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((1,), (2,)), ((), ())),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    return result
+
+  def _int_teo_eio_tei(self, lhs, rhs):
+    # torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((2,), (2,)), ((1,), (0,))),
+        None,
+        jnp.bfloat16.dtype,
+    )  # output is (eti) for some reason
+    return result.transpose(0, 1)
+
+
   def forward_for_short_seq_len(
       self, x: Tensor, expert_indices: Tensor
   ) -> Tensor:
@@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
     # o = config.imtermediate size
     # i = config.dim
     with jax.named_scope("conditional_ff"):
-      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-      x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
+      x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+      x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
+      x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+      x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
+      x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+      x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
       expert_outs = (
-          torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+          self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
       )
       # e = 8; need to reduce to 2
       seq_indexes = torch.arange(seqlen).unsqueeze(1)
-      return expert_outs[seq_indexes, expert_indices]
+      return expert_outs[seq_indexes, expert_indices] * x1x3_scaler
 
 
 class ConditionalFeedForward(nn.Module):
diff --git a/mlperf/backend.py b/mlperf/backend.py
index 806eb727..fb168e51 100644
--- a/mlperf/backend.py
+++ b/mlperf/backend.py
@@ -283,6 +283,7 @@ def __init__(
         self.dataset.LoadSamplesToRam,
         self.dataset.UnloadSamplesFromRam,
     )
+    log.info(f'DATA set size: {self.dataset.total_sample_count} / {self.dataset.perf_count}')
     self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)
 
   def load_tokenizer(
diff --git a/mlperf/benchmark_run.sh b/mlperf/benchmark_run.sh
index 946c301a..ded415de 100755
--- a/mlperf/benchmark_run.sh
+++ b/mlperf/benchmark_run.sh
@@ -2,7 +2,7 @@ BASEDIR=mlperf
 API_URL=0.0.0.0:9000
 USER_CONFIG=$BASEDIR/user.conf
 DATA_DISK_DIR=$BASEDIR/data
-TOTAL_SAMPLE_COUNT=1000
+TOTAL_SAMPLE_COUNT=15000
 DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl
 
 # HF model id
@@ -29,4 +29,4 @@ python -m mlperf.main \
   --tokenizer-path ${TOKENIZER_PATH} \
   --log-interval 1000 \
   --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log
-popd
\ No newline at end of file
+popd
diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf
index 9400d0af..8db7027d 100644
--- a/mlperf/mlperf.conf
+++ b/mlperf/mlperf.conf
@@ -88,7 +88,7 @@ gptj.Offline.min_query_count = 13368
 rnnt.Offline.min_query_count = 2513
 3d-unet.Offline.min_query_count = 43
 stable-diffusion-xl.Offline.min_query_count = 5000
-llama2-70b.Offline.min_query_count = 1000
+llama2-70b.Offline.min_query_count = 15000
 mixtral-8x7b.Offline.min_query_count = 1000
 
 # These fields should be defined and overridden by user.conf.
diff --git a/mlperf/start_server.sh b/mlperf/start_server.sh
index 74f9d6b3..4a334b1a 100755
--- a/mlperf/start_server.sh
+++ b/mlperf/start_server.sh
@@ -1,14 +1,16 @@
 #!/usr/bin/env bash
 
 CACHE_LENGTH=3072
-INPUT_SIZE=512
-OUTPUT_SIZE=512
-CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+INPUT_SIZE=2048
+OUTPUT_SIZE=1024
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized
 
 pushd ..
 python run_server.py \
+  --lazy_cache_update=1 \
+  --ring_buffer=0 \
   --model_name=mixtral \
-  --batch_size=128 \
+  --batch_size=256 \
   --max_cache_length=$CACHE_LENGTH \
   --max_decode_length=$OUTPUT_SIZE \
   --context_length=$INPUT_SIZE \
@@ -17,4 +19,4 @@ python run_server.py \
   --quantize_weights=1 \
   --quantize_type=int8_per_channel \
   --quantize_kv_cache=1
-popd
\ No newline at end of file
+popd
diff --git a/mlperf/user.conf b/mlperf/user.conf
index 2b1fa841..4f776a5b 100644
--- a/mlperf/user.conf
+++ b/mlperf/user.conf
@@ -1,3 +1,3 @@
 mixtral-8x7b.Server.target_qps = 1.8
-mixtral-8x7b.Offline.target_qps = 4.0
+mixtral-8x7b.Offline.target_qps = 20.0
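
Note (illustration, not part of the patch): the core pattern in forward_for_long_seq_len
above is per-token activation quantization around the expert matmuls -- quantize x to
int8, run the int8 x int8 dot_general with a bfloat16 accumulator, then fold the
activation and weight scalers back in. Below is a minimal standalone sketch of that
pattern in plain JAX; quantize_per_token, the toy sizes, and the scale handling are
illustrative stand-ins for jetstream_pt.quantize.quantize_tensor and torchjax.call_jax,
not the actual APIs.

import jax
import jax.numpy as jnp


def quantize_per_token(x, axis):
  # Symmetric int8 quantization along `axis`: x ~= x_int * scale.
  amax = jnp.max(jnp.abs(x), axis=axis, keepdims=True)
  scale = jnp.maximum(amax / 127.0, 1e-6)  # avoid divide-by-zero on all-zero rows
  x_int = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
  return x_int, scale


def int_ti_eoi_teo(x_int, w_int):
  # einsum("ti,eoi->teo") on int8 operands, accumulated in bfloat16
  # (the patch requests this via preferred_element_type).
  return jax.lax.dot_general(
      x_int, w_int,
      (((1,), (2,)), ((), ())),  # contract x's i dim with w's i dim
      preferred_element_type=jnp.bfloat16,
  )


# Toy sizes: t tokens, e experts, o intermediate dim, i model dim.
t, e, o, i = 8, 8, 128, 64
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (t, i), dtype=jnp.bfloat16)
w1_int = jax.random.randint(k2, (e, o, i), -128, 128, dtype=jnp.int8)
w1_scaler = jnp.full((e, o), 0.01, dtype=jnp.bfloat16)  # per-channel weight scale

x_int, x_scaler = quantize_per_token(x, axis=1)
x_scaler = x_scaler.reshape(t, 1, 1)
# Mirrors `F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)`.
x1 = jax.nn.silu(int_ti_eoi_teo(x_int, w1_int) * w1_scaler * x_scaler)
print(x1.shape)  # (t, e, o)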