Commit 936b922: activation quant

qihqi committed Jul 23, 2024
1 parent e1a6068 commit 936b922

Showing 6 changed files with 49 additions and 13 deletions.
41 changes: 37 additions & 4 deletions jetstream_pt/third_party/mixtral/model.py
@@ -22,8 +22,10 @@
from torch.nn import functional as F
from .config import ModelArgs, find_multiple
from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
from jetstream_pt import quantize, torchjax

+import jax
+import jax.numpy as jnp


class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
    else:
      return self.forward_for_short_seq_len(x, expert_indices)

+  def _int_ti_eoi_teo(self, lhs, rhs):
+    # x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
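+    # Contracts lhs dim 1 ("i") with rhs dim 2 ("i"); there are no batch
+    # dims, so the output keeps the remaining dims in order: (t, e, o).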
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((1,), (2,)), ((), ())),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    return result
+
+  def _int_teo_eio_tei(self, lhs, rhs):
+    # torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
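+    # Contracts the "o" axes (lhs dim 2, rhs dim 2) and treats "e" as a
+    # batch dim (lhs dim 1, rhs dim 0).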
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((2,), (2,)), ((1,), (0,))),
+        None,
+        jnp.bfloat16.dtype,
+    )  # dot_general puts batch dims first, so the output layout is (e, t, i)
+    return result.transpose(0, 1)
+
+
  def forward_for_short_seq_len(
      self, x: Tensor, expert_indices: Tensor
  ) -> Tensor:
Expand Down Expand Up @@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
    # o = config.intermediate_size
    # i = config.dim
    with jax.named_scope("conditional_ff"):
-      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-      x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
+      x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+      x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
+      x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+      x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
+      x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+      x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
      expert_outs = (
-          torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+          self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
      )
      # e = 8; need to reduce to 2
      seq_indexes = torch.arange(seqlen).unsqueeze(1)
-      return expert_outs[seq_indexes, expert_indices]
+      return expert_outs[seq_indexes, expert_indices] * x1x3_scaler


class ConditionalFeedForward(nn.Module):
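Note on the change above: the two helpers compute the same contractions as the einsums they replace, but on int8 activations and weights with bfloat16 accumulation, multiplying the per-token activation scale and per-channel weight scale back in afterwards. Below is a minimal self-contained sketch of the idea; quantize_int8 is a hypothetical stand-in for quantize.quantize_tensor, whose exact semantics this diff does not show.

import jax
import jax.numpy as jnp


def quantize_int8(x, axis):
  # Symmetric int8 quantization with a per-slice scale so that x ~= q * scale.
  scale = jnp.max(jnp.abs(x), axis=axis, keepdims=True) / 127.0
  q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
  return q, scale


t, e, o, i = 4, 8, 16, 32  # tokens, experts, intermediate size, model dim
x = jax.random.normal(jax.random.PRNGKey(0), (t, i))
w1 = jax.random.normal(jax.random.PRNGKey(1), (e, o, i))

# Full-precision reference, as in the removed line:
# torch.einsum("ti,eoi -> teo", x, self.w1)
ref = jnp.einsum("ti,eoi->teo", x, w1)

# Quantized path, mirroring _int_ti_eoi_teo: int8 operands, bf16 accumulation.
x_int, x_scale = quantize_int8(x, axis=1)     # per-token scale, shape (t, 1)
w1_int, w1_scale = quantize_int8(w1, axis=2)  # per-channel scale, shape (e, o, 1)
raw = jax.lax.dot_general(
    x_int,
    w1_int,
    (((1,), (2,)), ((), ())),  # contract "i"; no batch dims -> (t, e, o)
    None,
    jnp.bfloat16.dtype,
)
approx = raw.astype(jnp.float32) * x_scale.reshape(t, 1, 1) * w1_scale.reshape(1, e, o)
print(jnp.max(jnp.abs(approx - ref)))  # small, dominated by int8 rounding

Re-quantizing x1 * x3 before the second matmul (x1x3_int in the diff) keeps both expert matmuls on the int8 path at the cost of one extra rounding step.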
1 change: 1 addition & 0 deletions mlperf/backend.py
@@ -283,6 +283,7 @@ def __init__(
        self.dataset.LoadSamplesToRam,
        self.dataset.UnloadSamplesFromRam,
    )
+    log.info(f'Dataset size: {self.dataset.total_sample_count} / perf count: {self.dataset.perf_count}')
    self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)

  def load_tokenizer(
4 changes: 2 additions & 2 deletions mlperf/benchmark_run.sh
@@ -2,7 +2,7 @@ BASEDIR=mlperf
API_URL=0.0.0.0:9000
USER_CONFIG=$BASEDIR/user.conf
DATA_DISK_DIR=$BASEDIR/data
-TOTAL_SAMPLE_COUNT=1000
+TOTAL_SAMPLE_COUNT=15000
DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl

# HF model id
@@ -29,4 +29,4 @@ python -m mlperf.main \
  --tokenizer-path ${TOKENIZER_PATH} \
  --log-interval 1000 \
  --output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log
-popd
\ No newline at end of file
+popd
2 changes: 1 addition & 1 deletion mlperf/mlperf.conf
@@ -88,7 +88,7 @@ gptj.Offline.min_query_count = 13368
rnnt.Offline.min_query_count = 2513
3d-unet.Offline.min_query_count = 43
stable-diffusion-xl.Offline.min_query_count = 5000
-llama2-70b.Offline.min_query_count = 1000
+llama2-70b.Offline.min_query_count = 15000
mixtral-8x7b.Offline.min_query_count = 1000

# These fields should be defined and overridden by user.conf.
12 changes: 7 additions & 5 deletions mlperf/start_server.sh
@@ -1,14 +1,16 @@
#!/usr/bin/env bash

CACHE_LENGTH=3072
-INPUT_SIZE=512
-OUTPUT_SIZE=512
-CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+INPUT_SIZE=2048
+OUTPUT_SIZE=1024
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized

pushd ..
python run_server.py \
+  --lazy_cache_update=1 \
+  --ring_buffer=0 \
  --model_name=mixtral \
-  --batch_size=128 \
+  --batch_size=256 \
  --max_cache_length=$CACHE_LENGTH \
  --max_decode_length=$OUTPUT_SIZE \
  --context_length=$INPUT_SIZE \
@@ -17,4 +19,4 @@ python run_server.py \
  --quantize_weights=1 \
  --quantize_type=int8_per_channel \
  --quantize_kv_cache=1
-popd
\ No newline at end of file
+popd
2 changes: 1 addition & 1 deletion mlperf/user.conf
@@ -1,3 +1,3 @@
mixtral-8x7b.Server.target_qps = 1.8
-mixtral-8x7b.Offline.target_qps = 4.0
+mixtral-8x7b.Offline.target_qps = 20.0
