
Merge pull request #1209 from AI-Hypercomputer:msingh-d
PiperOrigin-RevId: 720700899
maxtext authors committed Jan 28, 2025
2 parents ce4cd52 + d4682a0 commit 8623e3a
Showing 3 changed files with 36 additions and 47 deletions.
39 changes: 20 additions & 19 deletions MaxText/inference_microbenchmark.py
@@ -18,10 +18,9 @@
import datetime
import jax
import json
-import sys

+from absl import app
from collections.abc import MutableMapping
-from typing import Any, Dict, Optional

from jetstream.engine import token_utils

@@ -36,14 +35,15 @@
warnings.simplefilter("ignore", category=FutureWarning)

_WARMUP_ITERS = 2

+_FLATTEN_MICROBENCHMARK_RESULTS = False
+# pylint: disable=too-many-positional-arguments


def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
"""Inner loop for benchmarking prefill step."""
start = datetime.datetime.now()
rng = jax.random.PRNGKey(1234)
+  prefill_result = None
for _ in range(iters):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
@@ -56,6 +56,7 @@ def prefill_benchmark_loop(engine, params, tokens, true_length, iters):
def prefill_benchmark(config, engine, params, tokens, true_length, num_model_params, iters):
"""Handles warmup, running prefill benchmark, and printing results."""
rng = jax.random.PRNGKey(1234)
+  prefill_result = None
for _ in range(_WARMUP_ITERS):
rng, rng_prefill = jax.random.split(rng)
prefill_result, _ = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length, rng=rng_prefill)
@@ -163,7 +164,7 @@ def ar_benchmark(config, engine, params, decode_state, global_batch_size, cache_
"step_in_ms_per_seq": ar_average_ms / global_batch_size,
"global_batch_size": global_batch_size,
"total_throughput_tokens_per_second": total_throughput,
"device_bandwidth_GB_per_second": bw_per_device,
"bw_per_device_GB_per_second": bw_per_device,
}
return result_dict, decode_state

@@ -197,7 +198,7 @@ def write_results(results, filename, flatten_microbenchmark_results):
"""Write the results microbenchmark results to a json file."""
if flatten_microbenchmark_results:
results["flattened_results"] = flatten_dict(results)
if filename != "":
if filename:
with open(filename, "w", encoding="utf-8") as f:
json.dump(results, f, indent=2)
return results
@@ -246,7 +247,8 @@ def summarize_prefill_result(engine, params, tokens, true_length):
}


-def main(config, inference_metadata: Optional[Dict[str, Any]] = None):
+def run_benchmarks(config):
+  """Run microbenchmarks."""
engine = maxengine.MaxEngine(config)
rng = jax.random.PRNGKey(1234)
rng, rng_load_params = jax.random.split(rng)
@@ -313,21 +315,20 @@ def main(config, inference_metadata: Optional[Dict[str, Any]] = None):

results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
print_results_for_analyze(results)
-  if inference_metadata:
-    flatten_microbenchmark_results = pyconfig.string_to_bool(
-        inference_metadata.get("flatten_microbenchmark_results", "false")
-    )
-  else:
-    flatten_microbenchmark_results = "false"
-  results = write_results(
-      results,
-      filename=config.inference_microbenchmark_log_file_path,
-      flatten_microbenchmark_results=flatten_microbenchmark_results,
-  )
+  if config.inference_microbenchmark_log_file_path:
+    write_results(
+        results,
+        filename=config.inference_microbenchmark_log_file_path,
+        flatten_microbenchmark_results=_FLATTEN_MICROBENCHMARK_RESULTS,
+    )
return results


-if __name__ == "__main__":
+def main(argv):
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
-  pyconfig.initialize(sys.argv)
-  main(pyconfig.config)
+  pyconfig.initialize(argv)
+  run_benchmarks(pyconfig.config)
+
+
+if __name__ == "__main__":
+  app.run(main)
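
The diff above separates concerns: the benchmark body now lives in run_benchmarks(config), while main(argv) handles the CLI wiring and is invoked through absl's app.run. A minimal sketch of driving the new entry point programmatically, assuming a MaxText checkout on PYTHONPATH; the config path and override values here are illustrative, not taken from this commit:

    import pyconfig
    from inference_microbenchmark import run_benchmarks

    # argv[0] is a program-name placeholder, mirroring sys.argv; the rest is a
    # base config file plus key=value overrides (hypothetical values).
    pyconfig.initialize(
        [
            "inference_microbenchmark",
            "MaxText/configs/base.yml",
            "per_device_batch_size=1",
        ]
    )
    results = run_benchmarks(pyconfig.config)  # returns the collated results dict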
40 changes: 14 additions & 26 deletions MaxText/maxengine.py
@@ -15,13 +15,12 @@
"""Implementation of Engine API for MaxText"""
import copy as cp
import functools
-from typing import Any, Optional, Tuple, Callable
+from typing import Any, List, Optional, Tuple, Callable
from collections import defaultdict

import flax
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
-from flax import struct

from layers import models, quantizations

@@ -39,26 +38,15 @@
import max_utils
import inference_utils
import pyconfig
-import jaxlib

import warnings

warnings.simplefilter("ignore", category=FutureWarning)

+DecodeState = Any
Prefix = Any
PackedPrefix = Any
Params = Any


-@struct.dataclass
-class DecodeState:
-  """The inputs into a generation step."""
-
-  prefill_cache: jax.Array
-  generate_cache: jax.Array
-  generate_cache_index: int
-  generate_lengths: jax.Array
-  generated_token: jax.Array
+PRNGKeyType = Any


class MaxEngineConfig:
@@ -110,7 +98,7 @@ def __init__(self, config: Any, devices: config_lib.Devices | None = None):
self.kv_cache_shardings = None
self.state_mesh_annotations = None

-  def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs) -> Params:
+  def load_params(self, *args, rng: Optional[PRNGKeyType] = None, **kwargs) -> Params:
"""Load Parameters, typically from GCS"""
# pylint: disable=unused-argument

@@ -126,7 +114,7 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)
# pylint: disable=isinstance-second-argument-not-valid-type
self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)
-        if isinstance(x, jaxlib.xla_extension.ArrayImpl)
+        if isinstance(x, jax.Array)
else None,
state.params,
)
@@ -158,7 +146,7 @@ def load_params(self, *args, rng: Optional[jax.random.PRNGKey] = None, **kwargs)
max_utils.print_mem_stats("After load_params")
return params

-  def quantize_params(self, state, rng: Optional[jax.random.PRNGKey] = None):
+  def quantize_params(self, state, rng: Optional[PRNGKeyType] = None):
"""Forward pass to quantize decode params."""
if rng is None:
rng = jax.random.PRNGKey(0)
@@ -227,7 +215,7 @@ def prefill(
padded_tokens: jax.Array,
true_length: int,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
-      rng: Optional[jax.random.PRNGKey] = None,
+      rng: Optional[PRNGKeyType] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
"""Computes a kv-cache for a new generate request.
@@ -325,8 +313,8 @@ def prefill_concat(
true_lengths: jax.Array,
num_prompts: int,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
-      rng: Optional[jax.random.PRNGKey] = None,
-  ) -> Tuple[Any, PackedPrefix, engine_api.ResultTokens]:
+      rng: Optional[PRNGKeyType] = None,
+  ) -> Tuple[Any, PackedPrefix, List[engine_api.ResultTokens]]:
"""Computes a kv-cache for a new packed generate request, which is a
concatenation of several shorter prompts. Experimentation shows that
longer prefill sequences gives approximately 15% boost in time per prefilled
@@ -424,7 +412,7 @@ def generate(
params: Params,
decode_state: DecodeState,
sampler: Optional[Callable[[Any], Any]] = None, # pylint: disable=unused-argument
-      rng: Optional[jax.random.PRNGKey] = None,
+      rng: Optional[PRNGKeyType] = None,
) -> Tuple[DecodeState, engine_api.ResultTokens]:
"""Run one generate step"""
if rng is None:
@@ -718,7 +706,7 @@ def build_tokenizer(self, metadata: tokenizer_pb2.TokenizerParameters) -> tokeni
def init_decode_state(
self,
*args, # pylint: disable=unused-argument
-      rng: Optional[jax.random.PRNGKey] = None,
+      rng: Optional[PRNGKeyType] = None,
**kwargs, # pylint: disable=unused-argument
) -> DecodeState:
"""Initialises any state which a generation step transforms."""
@@ -820,9 +808,9 @@ def colocated_cpus(self) -> None:


def set_engine_vars_from_base_engine(
-    engine: engine_api.Engine,
-    base_engine: engine_api.Engine,
-    rng: jax.random.PRNGKey,
+    engine: MaxEngine,
+    base_engine: MaxEngine,
+    rng: PRNGKeyType,
):
"""Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
mesh, and kv cache related vars set.
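
A note on the typing changes above: jax.random.PRNGKey is a constructor function, not a class, so annotations like Optional[jax.random.PRNGKey] never named a real type; the diff replaces them with the PRNGKeyType = Any alias. Similarly, isinstance checks now use the public jax.Array type rather than the private jaxlib.xla_extension.ArrayImpl, which moves between jaxlib releases. A small self-contained sketch (not from the commit) illustrating both points:

    import jax

    key = jax.random.PRNGKey(0)        # returns a key array; PRNGKey itself is a function
    print(isinstance(key, jax.Array))  # True: the stable public array type
    # The old check targeted a private class and is brittle across jaxlib versions:
    #   isinstance(key, jaxlib.xla_extension.ArrayImpl)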
4 changes: 2 additions & 2 deletions MaxText/tests/inference_microbenchmark_smoke_test.py
@@ -18,7 +18,7 @@
import pytest
import unittest
from absl.testing import absltest
-from inference_microbenchmark import main as inference_microbenchmark_main
+from inference_microbenchmark import run_benchmarks


class Inference_Microbenchmark(unittest.TestCase):
@@ -38,7 +38,7 @@ def test(self):
"weight_dtype=bfloat16",
]
)
-    inference_microbenchmark_main(pyconfig.config)
+    run_benchmarks(pyconfig.config)


if __name__ == "__main__":
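
The smoke test imports run_benchmarks directly rather than going through the new absl entry point. That matters because app.run never returns normally: after flag parsing it calls main and then raises SystemExit with its return value, which would tear down the pytest process. A short sketch of the behavior the test sidesteps, assuming only that absl-py is installed:

    from absl import app

    def main(argv):
      del argv  # unused
      return 0

    try:
      app.run(main)  # parses flags, runs main, then raises SystemExit(0)
    except SystemExit:
      print("app.run always exits; tests therefore call run_benchmarks directly")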
