diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index b7298e1..44b9406 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -28,8 +28,7 @@
 import torch
 import numpy as np
 
-from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
-from jetstream.engine import sampling_utils
+from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils, sampling_utils
 import torch_xla2
 from torch.utils import _pytree as pytree
@@ -44,6 +43,7 @@
 from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model
 from absl import flags
+from collections.abc import Callable
 
 FLAGS = flags.FLAGS
@@ -60,6 +60,7 @@ class Prefix:
   token: jax.Array  # [1, seqlen]
   caches: List[Tuple[jax.Array, jax.Array]]
   seq_len: int  # true seqlen front pad
+  sampler: List[Any] | int  # user-defined sampler
 
 
 @struct.dataclass
@@ -73,8 +74,12 @@ class DecodeState:
   current_position: int
   lens: jax.Array  # [batch_size, 1], the output token length
   start: jax.Array  # [batch_size, 1], the starting pos for each slot
-  input_pos: jax.Array  # [batch_size, 1] input pos for each slot
+  input_pos: (
+      jax.Array
+  )  # [batch_size, 1] total (prefill + decode) length for each slot
   mask: jax.Array  # [batch_size, seqlen] -inf for invalid; 0 for valid
+  # The sampling function
+  samplers: Any  # NOTE: model specific
@@ -93,7 +98,8 @@ def __init__(
     self.pt_model = pt_model
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
-    self.rng = jax.random.PRNGKey(0)
+    self.rng = jax.random.key(0)
+
     self.weights = weights
 
     self.y_sharding = env.sharding_by_axis(1)
@@ -119,6 +125,7 @@ def __init__(
         donate_argnums=(1,),
         out_shardings=(self.get_decode_state_sharding(), None),
     )
+    # self.generate = self.generate_impl
 
     if self.env.page_attention:
       max_pages_per_sequence = (
@@ -168,6 +175,7 @@ def init_decode_state(
     scalers = []
     if self.env.quant_config.enable_kv_quantization:
       scalers = [c.scalers() for c in caches_obj]
+
     return DecodeState(
         jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32),
         caches,
@@ -181,6 +189,7 @@
             float("-inf"),
             dtype=self.default_dtype,
         ),  # mask
+        None,
     )
 
   # pylint: disable-next=all
@@ -280,19 +289,42 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     caches_res = [c.state() for c in caches]
     return torchjax.from_torch((res, caches_res))
 
-  def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
+  # Temporarily disabled because handling per-request sampling is not ready yet.
+  # @classmethod
+  # def _custom_sampling(self, logits, samplers) -> jnp.ndarray:
+  #   if len(logits.shape) == 2:
+  #     logits = jnp.expand_dims(logits, 0)
+
+  #   logits = logits[:, -1]
+
+  #   # Prefill and Generate have different batch size
+  #   current_batch_size = logits.shape[0]
+
+  #   idx = jnp.arange(current_batch_size)
+  #   apply_sampler = lambda i, l: jax.lax.switch(i, samplers, l)
+  #   apply_vmap = jax.vmap(apply_sampler, in_axes=(0, 0))
+  #   return apply_vmap(idx, logits).reshape(current_batch_size, -1)
+
+
+  def _sampling(
+      self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp
+  ) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
+
+    logits = logits[:, -1]
+    current_batch_size = logits.shape[0]
+
     return (
         sampling_utils.sampling(
-            logits[:, -1],
-            self.rng,
-            self.env.sampling_algorithm,
-            self.env.topk,
-            self.env.nucleus_topp,
-            self.env.temperature,
+            logits=logits,
+            rng=rng,
+            algorithm=algorithm,
+            topk=topk,
+            nucleus_topp=nucleus_topp,
+            temperature=temperature,
         )
-        .reshape(batch_size, -1)
+        .reshape(current_batch_size, -1)
        .astype(jnp.int32)
     )
@@ -301,7 +333,7 @@ def prefill(
       *,
       params: Any,  # Weights
       existing_prefix: Optional[Prefix] = None,
-      padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array],
+      padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array]
       true_length: int,
       sampler: Optional[Callable[[Any], Any]] = None,
   ) -> Tuple[Prefix, engine_api.ResultTokens]:
@@ -321,6 +353,7 @@
     )
     if len(logits.shape) == 3:  # b, seqlen, num words
       logits = logits[0]  # seqlen, num words
+
     if sampler:
       token = sampler(logits[true_length - 1])
     else:
@@ -332,6 +365,7 @@
           self.env.nucleus_topp,
           self.env.temperature,
       )
+    token = jnp.reshape(token, (1,))
     token_out = jnp.reshape(token, (1, 1))
     data = jnp.concatenate(
         [
@@ -357,7 +391,10 @@
     #         v, seq_len - true_length, true_length, axis=2))
     #     for k, v in updated_caches
     # ]
-    return Prefix(token, updated_caches, true_length), result
+    return (
+        Prefix(token, updated_caches, true_length, sampler),
+        result,
+    )
 
   def shrink_prefix(
       self,
@@ -476,6 +513,8 @@ def insert(cache, scaler, new_entry, update_index):
       caches.append((kcache, vcache))
       scales.append((kscale, vscale))
     lens = decode_state.lens.at[slot].set(1)
+
+    sampler = prefix.sampler if prefix.sampler else decode_state.samplers
     return DecodeState(
         tokens,
         caches,
@@ -485,6 +524,7 @@
         start,
         input_pos,
         mask,
+        sampler,
     )
 
   # pylint: disable-next=all
@@ -569,6 +609,9 @@ def insert(cache, scaler, new_entry):
       scales.append((kscale, vscale))
     lens = decode_state.lens.at[slot].set(1)
+
+    sampler = prefix.sampler if prefix.sampler else decode_state.samplers
+
     return DecodeState(
         tokens,
         caches,
@@ -578,6 +621,7 @@
         start,
         input_pos,
         mask,
+        sampler,
     )
 
   def _insert_page_attention(
@@ -613,6 +657,8 @@
     input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len)
     scales = None
     lens = decode_state.lens.at[slot].set(1)
+
+    sampler = prefix.sampler if prefix.sampler else decode_state.samplers
     return DecodeState(
         tokens,
         caches,
@@ -622,6 +668,7 @@
         start,
         input_pos,
         mask,
+        sampler,
     )
 
   def insert(
@@ -729,7 +776,9 @@ def false_comp(b, i, bk, start, end):
     return b_next, i_next
 
   def generate(
-      self, params: Any, decode_state: DecodeState, sampler=None
+      self,
+      params: Any,
+      decode_state: DecodeState,
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     return (None, None)
 
@@ -752,7 +801,6 @@
   def generate_impl(
       self,
       params: Any,
       decode_state: DecodeState,
-      sampler=None,
       page_token_indices=None,
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     # seq_len = padded_tokens.shape[0]
@@ -764,12 +812,16 @@ def generate_impl(
     else:
       input_indexes = decode_state.input_pos
 
-    ragged_batch_index, ragged_block_index = (
-        self.precompute_ragged_block_indices(decode_state)
-    )
-    ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
-        (-1)
-    ), ragged_block_index.reshape((-1))
+    # TODO(lancewang): Remove ragged index precomputation
+    # ragged_batch_index, ragged_block_index = (
+    #     self.precompute_ragged_block_indices(decode_state)
+    # )
+    # ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
+    #     (-1)
+    # ), ragged_block_index.reshape((-1))
+
+    ragged_batch_index = 0
+    ragged_block_index = 0
 
     def update_mask():
       if self.env.ring_buffer:
@@ -799,10 +851,20 @@ def update_mask():
       # fill mask later, now use flash attention
       mask = update_mask()
 
-    if sampler:
-      next_token = sampler(logits[:, -1])
+    # Temporarily disabled because handling per-request sampling is not ready yet.
+    # next_token = self._custom_sampling(logits, decode_state.samplers)
+    if decode_state.samplers:
+      next_token = decode_state.samplers(logits)
     else:
-      next_token = self._sampling(logits, self.env.batch_size)
+      next_token = self._sampling(
+          logits,
+          self.env.sampling_algorithm,
+          self.rng,
+          self.env.temperature,
+          self.env.topk,
+          self.env.nucleus_topp,
+      )
+
     if self.env.ring_buffer:
       input_pos = decode_state.input_pos + 1
       lens = decode_state.lens + 1
@@ -844,6 +906,7 @@ def update_mask():
         decode_state.start,
         input_pos,
         mask,
+        decode_state.samplers,
     )
     return new_decode_state, result_tokens
 
@@ -963,6 +1026,7 @@ def get_prefix_destination_sharding(self) -> Prefix:
         if self.env.page_attention
         else self.cache_sharding,
         self.replicated,
+        self.replicated,
     )
 
   def get_decode_state_sharding(self) -> DecodeState:
@@ -976,6 +1040,7 @@ def get_decode_state_sharding(self) -> DecodeState:
         self.replicated,
         self.replicated,
         self.replicated,
+        self.replicated,
     )
 
   def get_prefix_sequence_ddim(self) -> Any:
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 4917705..0202711 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -52,7 +52,10 @@ class JetEngineEnvironmentData:
   batch_size: int = 32  # batch size is generate step batch size
   cache_sequence_length: int = 2048  # size of the cache.
-  quant_config: QuantizationConfig = QuantizationConfig()
+  # quant_config: QuantizationConfig = QuantizationConfig()
+  quant_config: QuantizationConfig = dataclasses.field(
+      default_factory=QuantizationConfig
+  )
 
   model_type: str = "llama-2-13b"  # this implies the model config
diff --git a/run_interactive.py b/run_interactive.py
index 8463658..2e00159 100644
--- a/run_interactive.py
+++ b/run_interactive.py
@@ -23,6 +23,7 @@
 from absl import app
 from jetstream.engine import token_utils
 from jetstream_pt.config import FLAGS, create_engine_from_config_flags
+from jetstream.engine import sampling_utils
 
 
 # pylint: disable-next=all
@@ -30,6 +31,21 @@ def main(argv):
   engine = create_engine_from_config_flags()
 
+  rng = jax.random.key(1)
+  temperature = 1
+  topk = 1
+  topp = 0.2
+
+  sampler = jax.tree_util.Partial(
+      sampling_utils.jittable_sample_topk_logits,
+      rng=rng,
+      temperature=temperature,
+      topk=topk,
+  )
+  # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_topp_logits, rng=rng, temperature=temperature, topp=topp)
+  # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_greedy_logits)
+  # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_weighted_logits, rng=rng, temperature=temperature)
+
   start = time.perf_counter()
   params = engine.load_params()
   print("Load params ", time.perf_counter() - start)
@@ -77,7 +93,10 @@ def main(argv):
     jax.profiler.start_trace(profiling_output)
 
   prefill_result, _ = engine.prefill(
-      params=params, padded_tokens=tokens, true_length=true_length
+      params=params,
+      padded_tokens=tokens,
+      true_length=true_length,
+      sampler=sampler
   )
   # pylint: disable-next=all
   decode_state = engine.insert(prefill_result, decode_state, slot=slot)
diff --git a/tests/helpers.py b/tests/helpers.py
index ac0ea5f..860389d 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,7 +7,9 @@
 
 
 # pylint: disable-next=all
-def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
+def make_env_tiny(
+    bf16_enable=True, env_data_update_fn=lambda _: None, batch_size=1
+):
   torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
   torch.set_default_dtype(torch_dtype)
   jax.config.update("jax_dynamic_shapes", False)
@@ -19,7 +21,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
   environment_data.cache_sequence_length = 128
   environment_data.bf16_enable = bf16_enable
   environment_data.model_type = "llama-2-tiny"
-  environment_data.batch_size = 1
+  environment_data.batch_size = batch_size
   environment_data.num_layers = config.n_layers
   environment_data.cache_shape = (
       1,
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 57245c0..9704ce1 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -20,16 +20,64 @@
 from jetstream_pt.third_party.llama import model_exportable
 from jetstream_pt.engine import PyTorchEngine
+from jetstream_pt.engine import DecodeState
+from jetstream_pt.engine import Prefix
 from tests import helpers
+from jetstream_pt import cache_manager
+# from jetstream_pt.engine import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler
+# from jetstream.engine.sampling_util import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler
+from jetstream.engine.sampling_utils import jittable_sample_greedy_logits, jittable_sample_topp_logits, jittable_sample_topk_logits, jittable_sample_weighted_logits
+
+
+class MockEngine(PyTorchEngine):
+
+  def _call_model_prefill(self, weights, tokens, input_indexes):
+    caches = [
+        cache_manager.KVCachePrefill(
+            self.env.quant_config.enable_kv_quantization
+        )
+        for _ in self.pt_model.layers
+    ]
+    # logits = jnp.ones((self.env.batch_size, 1), jnp.float32)
+    assert (
+        self.env.batch_size == 1
+    ), f"The batch size {self.env.batch_size} != 1"
+    logits = jnp.array([[0.5, 0.6, 0.7, 0.8]])
+    return logits, caches
+
+  def _call_model_generate(
+      self,
+      weights,
+      tokens,
+      input_indexes,
+      caches,
+      cache_scales,
+      mask,
+      start,
+      input_pos,
+      ragged_batch_index,
+      ragged_block_index,
+      page_token_indices,
+  ):
+    logits = jnp.array(
+        [
+            [[0.5, 0.6, 0.7, 0.8]],
+            # [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+            [[0.4, 0.3, 0.2, 0.1]],
+        ]
+    )
+    return logits, caches, cache_scales
+
+
 class EngineTest(unittest.TestCase):
 
-  def setup(self):
-    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
+  def setup(self, batch_size=1):
+    env, model_arg = helpers.make_env_tiny(
+        bf16_enable=True, batch_size=batch_size
+    )
     model_ours = model_exportable.Transformer(model_arg, env)
-    engine = PyTorchEngine(pt_model=model_ours, env=env)
-    engine.rng = jax.random.PRNGKey(0)
+    engine = MockEngine(pt_model=model_ours, env=env)
+    engine.rng = jax.random.key(0)
     return engine
 
   def test_sampling_2D(self):
@@ -37,14 +85,23 @@ def test_sampling_2D(self):
     engine = self.setup()
     self.assertEqual(engine.env.sampling_algorithm, "greedy")
     logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits, "greedy", engine.rng, temperature=1.0, topk=1, nucleus_topp=0.0
+    )
     self.assertEqual(token, jnp.array([[0]]))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test weighted
     engine.env.sampling_algorithm = "weighted"
     engine.env.temperature = 5.0
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=5.0,
+        topk=1,
+        nucleus_topp=0.0,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
@@ -52,21 +109,36 @@ def test_sampling_2D(self):
     engine.env.sampling_algorithm = "topk"
     engine.env.temperature = 5.0
     engine.env.topk = 4
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=5.0,
+        topk=4,
+        nucleus_topp=0.0,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test nucleus
     engine.env.sampling_algorithm = "nucleus"
-    engine.env.temperature = 0.0
+    engine.env.temperature = 1.0
     engine.env.nucleus_topp = 0.8
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=1.0,
+        topk=1,
+        nucleus_topp=0.8,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
   def test_sampling_3D(self):
     # test greedy
-    engine = self.setup()
+    engine = self.setup(batch_size=2)
+
     self.assertEqual(engine.env.sampling_algorithm, "greedy")
     logits = jnp.array(
         [
@@ -74,14 +146,28 @@
             [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
         ]
     )
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test weighted
     engine.env.sampling_algorithm = "weighted"
     engine.env.temperature = 10.0
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
@@ -89,7 +175,14 @@
     engine.env.sampling_algorithm = "topk"
     engine.env.temperature = 1.0
     engine.env.topk = 3
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
@@ -97,10 +190,343 @@
     engine.env.sampling_algorithm = "nucleus"
     engine.env.temperature = 1.0
     engine.env.nucleus_topp = 0.8
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
+  # Temporarily disabled because handling per-request sampling is not ready yet.
+# def test_custom_sampling_3D(self):
+#   engine = self.setup(batch_size=2)
+#   rng = jax.random.key(3)
+
+#   engine.env.sampling_algorithm = ""
+
+#   # Need a different engine with batch size 1 to reshape the output
+#   rng_b1 = jax.random.key(3)
+#   logits = jnp.array(
+#       [
+#           [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]],
+#           [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+#       ]
+#   )
+
+#   # test greedy
+#   sampler = jittable_sample_greedy_logits
+#   samplers = [sampler, sampler]
+#   token = engine._custom_sampling(logits, samplers)
+
+#   original_tokens = []
+#   for i in range(2):
+#     original_token = engine._sampling(
+#         logits[i],
+#         "greedy",
+#         rng=rng,
+#         temperature=0.0,
+#         topk=0,
+#         nucleus_topp=0.0,
+#     )
+#     original_tokens.append(original_token)
+#   original_tokens = jnp.concatenate(original_tokens)
+
+#   print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#   self.assertTrue(jnp.array_equal(token, original_tokens))
+#   self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # test weighted
+#   sampler1 = jax.tree_util.Partial(
+#       jittable_sample_weighted_logits, rng=rng, temperature=1.0
+#   )
+#   sampler2 = jax.tree_util.Partial(
+#       jittable_sample_weighted_logits, rng=rng, temperature=1.0
+#   )
+#   samplers = [sampler1, sampler2]
+#   token = engine._custom_sampling(logits, samplers)
+
+#   original_tokens = []
+#   for i in range(2):
+#     rng_b1, sub_rng = jax.random.split(rng_b1)
+#     original_token = engine._sampling(
+#         logits[i],
+#         "weighted",
+#         rng,
+#         temperature=1,
+#         topk=0,
+#         nucleus_topp=0.0,
+#     )
+#     original_tokens.append(original_token)
+#   original_tokens = jnp.concatenate(original_tokens)
+
+#   print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#   self.assertTrue(jnp.array_equal(token, original_tokens))
+#   self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # # test topk
+#   sampler1 = jax.tree_util.Partial(
+#       jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3
+#   )
+#   sampler2 = jax.tree_util.Partial(
+#       jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3
+#   )
+#   samplers = [sampler1, sampler2]
+#   token = engine._custom_sampling(logits, samplers)
+
+#   original_tokens = []
+#   for i in range(2):
+#     # rng_b1, sub_rng = jax.random.split(rng_b1)
+#     sub_rng = rng
+#     original_token = engine._sampling(
+#         logits[i],
+#         "topk",
+#         rng=sub_rng,
+#         temperature=1.0,
+#         topk=3,
+#         nucleus_topp=0.0,
+#     )
+#     original_tokens.append(original_token)
+#   original_tokens = jnp.concatenate(original_tokens)
+
+#   print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#   self.assertTrue(jnp.array_equal(token, original_tokens))
+#   self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # test nucleus
+#   sampler1 = jax.tree_util.Partial(
+#       jittable_sample_topp_logits, rng=rng, temperature=1.0, topp=0.8
+#   )
+#   sampler2 = jax.tree_util.Partial(
+#       jittable_sample_topp_logits, rng=rng, temperature=1.0, topp=0.8
+#   )
+#   samplers = [sampler1, sampler2]
+#   token = engine._custom_sampling(logits, samplers)
+
+#   original_tokens = []
+#   for i in range(2):
+#     original_token = engine._sampling(
+#         logits[i],
+#         "nucleus",
+#         rng,
+#         temperature=1.0,
+#         topk=0,
+#         nucleus_topp=0.8,
+#     )
+#     original_tokens.append(original_token)
+#   original_tokens = jnp.concatenate(original_tokens)
+#   print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#   self.assertTrue(jnp.array_equal(token, original_tokens))
+#   self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # # test topk + greedy
+#   sampler1 = jax.tree_util.Partial(
+#       jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3
+#   )
+#   sampler2 = jax.tree_util.Partial(jittable_sample_greedy_logits)
+#   samplers = [sampler1, sampler2]
+#   token = engine._custom_sampling(logits, samplers)
+
+#   original_tokens = []
+#   i = 0
+#   original_token = engine._sampling(
+#       logits[i],
+#       "topk",
+#       rng,
+#       temperature=1.0,
+#       topk=3,
+#       nucleus_topp=0.8,
+#   )
+#   original_tokens.append(original_token)
+
+#   i = 1
+#   original_token = engine._sampling(
+#       logits[i],
+#       "greedy",
+#       rng,
+#       temperature=0.0,
+#       topk=0,
+#       nucleus_topp=0.0,
+#   )
+#   original_tokens.append(original_token)
+
+#   original_tokens = jnp.concatenate(original_tokens)
+
+#   print(f"custom sampling token {token} vs original tokens {original_tokens}")
+#   self.assertTrue(jnp.array_equal(token, original_tokens))
+#   self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+# # test Prefill
+# def test_prefill_with_custom_sampling(self):
+#   engine = self.setup()
+#   engine.rng = jax.random.key(3)
+
+#   engine.env.sampling_algorithm = ""
+
+#   # Inputs don't matter
+#   params = jnp.zeros((1,), jnp.float32)
+#   padded_tokens = jnp.zeros((1,), jnp.float32)
+#   true_length = 1
+
+#   # Greedy
+#   sampler = jax.tree_util.Partial(jittable_sample_greedy_logits)
+#   prefix, _ = engine.prefill(
+#       params=params,
+#       padded_tokens=padded_tokens,
+#       true_length=true_length,
+#       sampler=sampler,
+#   )
+#   token = prefix.token
+#   print(f"Greedy output: {token}")
+#   self.assertTrue(jnp.array_equal(token, jnp.array([3])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # Weighted
+#   sampler = jax.tree_util.Partial(
+#       jittable_sample_weighted_logits, rng=engine.rng, temperature=1.0
+#   )
+#   prefix, _ = engine.prefill(
+#       params=params,
+#       padded_tokens=padded_tokens,
+#       true_length=true_length,
+#       sampler=sampler,
+#   )
+#   token = prefix.token
+#   print(f"Weighted output: {token}")
+#   self.assertTrue(jnp.array_equal(token, jnp.array([2])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # Nucleus
+#   sampler = jax.tree_util.Partial(
+#       jittable_sample_topp_logits, rng=engine.rng, temperature=1.0, topp=0.8
+#   )
+#   prefix, _ = engine.prefill(
+#       params=params,
+#       padded_tokens=padded_tokens,
+#       true_length=true_length,
+#       sampler=sampler,
+#   )
+#   token = prefix.token
+#   print(f"Nucleus output: {token}")
+#   self.assertTrue(jnp.array_equal(token, jnp.array([2])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+#   # Topk
+#   sampler = jax.tree_util.Partial(
+#       jittable_sample_topk_logits, rng=engine.rng, temperature=1.0, topk=3
+#   )
+
+#   prefix, _ = engine.prefill(
+#       params=params,
+#       padded_tokens=padded_tokens,
+#       true_length=true_length,
+#       sampler=sampler,
+#   )
+#   token = prefix.token
+#   print(f"Topk output: {token}")
+#   self.assertTrue(jnp.array_equal(token, jnp.array([1])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+# def test_insert_no_wrap_with_custom_sampling(self):
+#   engine = self.setup()
+#   engine.env.sampling_algorithm = ""
+#   engine.env.batch_size = 2
+#   cache_shape = engine.env.cache_shape
+
+#   prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3])
+#   prefill_cache = []
+#   for _ in range(engine.env.num_layers):
+#     prefill_cache.append(
+#         (
+#             jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16),
+#             jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16),
+#         )
+#     )
+
+#   sampler = jittable_sample_greedy_logits
+#   prefix = Prefix(
+#       token=jnp.ones((1)),
+#       caches=prefill_cache,
+#       seq_len=16,
+#       sampler=sampler,
+#   )
+
+#   doesnt_matter = jnp.array([0])
+#   kv_cache = engine.env.make_caches_generate()
+#   kv_cache = [c.state() for c in kv_cache]
+
+#   base_sampler = jax.tree_util.Partial(
+#       engine._sampling,
+#       algorithm=engine.env.sampling_algorithm,
+#       rng=engine.rng,
+#       temperature=engine.env.temperature,
+#       topk=engine.env.topk,
+#       nucleus_topp=engine.env.nucleus_topp,
+#   )
+#   decode_state = DecodeState(
+#       tokens=jnp.zeros((engine.env.batch_size, 1)),
+#       caches=kv_cache,
+#       cache_scales=[doesnt_matter],
+#       current_position=16,
+#       lens=jnp.zeros((engine.env.batch_size, 1)),
+#       start=jnp.zeros((engine.env.batch_size, 1)),
+#       input_pos=jnp.zeros((engine.env.batch_size,)),
+#       mask=jnp.zeros((engine.env.batch_size, 128)),
+#       # samplers = [base_sampler] * engine.env.batch_size
+#       samplers=None,
+#   )
+
+#   # Insert to slot 1
+#   result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1)
+
+#   self.assertAlmostEqual(
+#       result_decode_state.tokens.all(), decode_state.tokens.all()
+#   )
+#   self.assertEqual(result_decode_state.samplers, prefix.sampler)
+
+# def test_generate_with_custom_sampling(self):
+#   engine = self.setup(batch_size=2)
+#   engine.rng = jax.random.key(3)
+#   engine.env.sampling_algorithm = ""
+
+#   # Inputs don't matter
+#   doesnt_matter = jnp.array([0])
+#   params = doesnt_matter
+
+#   greedy_sampler = jax.tree_util.Partial(jittable_sample_greedy_logits)
+#   weighted_sampler = jax.tree_util.Partial(
+#       jittable_sample_weighted_logits, rng=engine.rng, temperature=1.0
+#   )
+#   decode_state = DecodeState(
+#       tokens=jnp.zeros((engine.env.batch_size, 1)),
+#       caches=[doesnt_matter],
+#       cache_scales=[doesnt_matter],
+#       current_position=0,
+#       lens=jnp.zeros((engine.env.batch_size, 1)),
+#       start=doesnt_matter,
+#       input_pos=jnp.zeros((engine.env.batch_size,)),
+#       mask=jnp.zeros((engine.env.batch_size, 1)),
+#       samplers=weighted_sampler,
+#   )
+
+#   # Topk + Weighted
+#   # algorithm, temperature, topk, nucleus_topp
+#   decode_state, _ = engine.generate_impl(
+#       params=params, decode_state=decode_state
+#   )
+#   token = decode_state.tokens
+#   print(f"Topk + Weighted output: {token}")
+#   self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]])))
+#   self.assertTrue(jnp.isdtype(token, jnp.int32))
+
   # def test_insert(self):
   #   seqlen = 32
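
For reviewers, below is a minimal sketch of the per-request sampler flow this diff enables, assuming an engine created as in run_interactive.py and an already-tokenized prompt; prompt_tokens, true_length, and slot are illustrative placeholders, not names from the diff. The sampler is wrapped in jax.tree_util.Partial so it stays a jit-friendly pytree leaf: prefill() captures it in the returned Prefix, insert() copies it into DecodeState.samplers, and generate_impl() applies it to the decode logits in place of the env-level sampling settings.

import jax
from jetstream.engine import sampling_utils

# Bind the sampling hyperparameters up front; Partial keeps the
# callable a valid pytree leaf across jit boundaries.
sampler = jax.tree_util.Partial(
    sampling_utils.jittable_sample_topk_logits,
    rng=jax.random.key(1),
    temperature=1.0,
    topk=3,
)

params = engine.load_params()
decode_state = engine.init_decode_state()

# prefill() threads the sampler into the returned Prefix ...
prefix, _ = engine.prefill(
    params=params,
    padded_tokens=prompt_tokens,  # illustrative: padded prompt token ids
    true_length=true_length,  # illustrative: unpadded prompt length
    sampler=sampler,
)
# ... insert() copies it into DecodeState.samplers for this slot ...
decode_state = engine.insert(prefix, decode_state, slot=slot)
# ... and generate_impl() calls decode_state.samplers(logits) instead of
# falling back to env.sampling_algorithm / temperature / topk / nucleus_topp.
decode_state, result_tokens = engine.generate_impl(
    params=params, decode_state=decode_state
)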