diff --git a/jetstream_pt/third_party/mixtral/model.py b/jetstream_pt/third_party/mixtral/model.py
index 1602e886..422f4990 100644
--- a/jetstream_pt/third_party/mixtral/model.py
+++ b/jetstream_pt/third_party/mixtral/model.py
@@ -22,8 +22,10 @@
 from torch.nn import functional as F
 from .config import ModelArgs, find_multiple
 from jetstream_pt.layers import Attention, get_quantized_linear_layer, get_quantized_enbedding_layer
+from jetstream_pt import quantize, torchjax
 
 import jax
+import jax.numpy as jnp
 
 
 class Transformer(nn.Module):
@@ -233,6 +235,31 @@ def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
     else:
       return self.forward_for_short_seq_len(x, expert_indices)
 
+  def _int_ti_eoi_teo(self, lhs, rhs):
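+    # lhs is the per-token quantized activation of shape (t, i); rhs is a
+    # quantized expert weight of shape (e, o, i). The contraction below pairs
+    # lhs dim 1 with rhs dim 2 (both `i`) with no batch dims, so the output is
+    # (t, e, o), accumulated in bfloat16.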
+    # Equivalent to torch.einsum("ti,eoi -> teo", x, w) for w1 and w3; the
+    # caller applies the weight/activation scalers (and SiLU on the w1 path).
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((1,), (2,)), ((), ())),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    return result
+
+  def _int_teo_eio_tei(self, lhs, rhs):
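+    # lhs is the re-quantized (x1 * x3) of shape (t, e, o); rhs is the
+    # quantized w2 of shape (e, i, o). The contraction below pairs dim 2 of
+    # both operands (`o`) and batches over `e` (lhs dim 1, rhs dim 0).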
+    # Equivalent to torch.einsum("teo, eio -> tei", x1 * x3, self.w2); the
+    # caller applies w2_scaler.
+    result = torchjax.call_jax(
+        jax.lax.dot_general,
+        lhs,
+        rhs,
+        (((2,), (2,)), ((1,), (0,))),
+        None,
+        jnp.bfloat16.dtype,
+    )
+    # dot_general places batch dims first, so the result is (e, t, i);
+    # transpose back to (t, e, i).
+    return result.transpose(0, 1)
+
   def forward_for_short_seq_len(
       self, x: Tensor, expert_indices: Tensor
   ) -> Tensor:
@@ -260,14 +287,20 @@ def forward_for_long_seq_len(self, x, expert_indices):
     # o = config.imtermediate size
     # i = config.dim
     with jax.named_scope("conditional_ff"):
-      x1 = F.silu(torch.einsum("ti,eoi -> teo", x, self.w1) * self.w1_scaler)
-      x3 = torch.einsum("ti, eoi-> teo", x, self.w3) * self.w3_scaler
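+      # Quantize the activations per token: axis 1 of x is the contraction
+      # axis `i`, so quantize_tensor returns int values and a per-token scaler
+      # (the third return value, presumably a zero point, is unused here). The
+      # scaler is reshaped to broadcast against the (t, e, o) matmul output.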
+      x_int, x_scaler, _ = quantize.quantize_tensor(x, (1,))
+      x_scaler = x_scaler.reshape(seqlen, 1, 1)
+
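+      # The integer matmuls return unscaled bfloat16 values; multiplying by the
+      # weight scaler and the per-token activation scaler restores the original
+      # floating-point scale.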
+      x1 = F.silu(self._int_ti_eoi_teo(x_int, self.w1) * self.w1_scaler * x_scaler)
+      x3 = self._int_ti_eoi_teo(x_int, self.w3) * self.w3_scaler * x_scaler
+
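+      # Re-quantize x1 * x3 (shape (t, e, o)) over its last two axes so the w2
+      # matmul also runs on integer operands; its per-token scaler is applied
+      # after the expert gather at the end of this function.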
+      x1x3_int, x1x3_scaler, _ = quantize.quantize_tensor(x1 * x3, (1, 2))
+      x1x3_scaler = x1x3_scaler.reshape(seqlen, 1, 1)
       expert_outs = (
-          torch.einsum("teo, eio -> tei", (x1 * x3), self.w2) * self.w2_scaler
+          self._int_teo_eio_tei(x1x3_int, self.w2) * self.w2_scaler
       )
       # e = 8; need to reduce to 2
       seq_indexes = torch.arange(seqlen).unsqueeze(1)
-      return expert_outs[seq_indexes, expert_indices]
+      return expert_outs[seq_indexes, expert_indices] * x1x3_scaler
 
 
 class ConditionalFeedForward(nn.Module):
diff --git a/mlperf/backend.py b/mlperf/backend.py
index 806eb727..fb168e51 100644
--- a/mlperf/backend.py
+++ b/mlperf/backend.py
@@ -283,6 +283,7 @@ def __init__(
         self.dataset.LoadSamplesToRam,
         self.dataset.UnloadSamplesFromRam,
     )
+    log.info(f'Dataset size: {self.dataset.total_sample_count} (perf_count: {self.dataset.perf_count})')
     self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)
 
   def load_tokenizer(
diff --git a/mlperf/benchmark_run.sh b/mlperf/benchmark_run.sh
index 946c301a..ded415de 100755
--- a/mlperf/benchmark_run.sh
+++ b/mlperf/benchmark_run.sh
@@ -2,7 +2,7 @@ BASEDIR=mlperf
 API_URL=0.0.0.0:9000
 USER_CONFIG=$BASEDIR/user.conf
 DATA_DISK_DIR=$BASEDIR/data
-TOTAL_SAMPLE_COUNT=1000
+TOTAL_SAMPLE_COUNT=15000
 DATASET_PATH=$BASEDIR/data/mixtral_15k_data.pkl
 
 # HF model id
@@ -29,4 +29,4 @@ python -m mlperf.main \
 	--tokenizer-path ${TOKENIZER_PATH} \
 	--log-interval 1000 \
 	--output-log-dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/server_accuracy_log.log
-popd
\ No newline at end of file
+popd
diff --git a/mlperf/mlperf.conf b/mlperf/mlperf.conf
index 9400d0af..8db7027d 100644
--- a/mlperf/mlperf.conf
+++ b/mlperf/mlperf.conf
@@ -88,7 +88,7 @@ gptj.Offline.min_query_count = 13368
 rnnt.Offline.min_query_count = 2513
 3d-unet.Offline.min_query_count = 43
 stable-diffusion-xl.Offline.min_query_count = 5000
-llama2-70b.Offline.min_query_count = 1000
+llama2-70b.Offline.min_query_count = 15000
 mixtral-8x7b.Offline.min_query_count = 1000
 
 # These fields should be defined and overridden by user.conf.
diff --git a/mlperf/start_server.sh b/mlperf/start_server.sh
index 74f9d6b3..4a334b1a 100755
--- a/mlperf/start_server.sh
+++ b/mlperf/start_server.sh
@@ -1,14 +1,16 @@
 #!/usr/bin/env bash
 
 CACHE_LENGTH=3072
-INPUT_SIZE=512
-OUTPUT_SIZE=512
-CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized/
+INPUT_SIZE=2048
+OUTPUT_SIZE=1024
+CHECKPOINT_PATH=mlperf/data/mixtral-instruct-quantized
 
 pushd ..
 python run_server.py \
+  --lazy_cache_update=1 \
+  --ring_buffer=0 \
   --model_name=mixtral \
-  --batch_size=128 \
+  --batch_size=256 \
   --max_cache_length=$CACHE_LENGTH \
   --max_decode_length=$OUTPUT_SIZE \
   --context_length=$INPUT_SIZE \
@@ -17,4 +19,4 @@ python run_server.py \
   --quantize_weights=1 \
   --quantize_type=int8_per_channel \
   --quantize_kv_cache=1
-popd
\ No newline at end of file
+popd
diff --git a/mlperf/user.conf b/mlperf/user.conf
index 2b1fa841..4f776a5b 100644
--- a/mlperf/user.conf
+++ b/mlperf/user.conf
@@ -1,3 +1,3 @@
 mixtral-8x7b.Server.target_qps = 1.8
-mixtral-8x7b.Offline.target_qps = 4.0
+mixtral-8x7b.Offline.target_qps = 20.0