diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index b7298e1..a2103ec 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -53,6 +53,8 @@
 Params = jax.Array
 PrefillInputs = jax.Array
 
+NEG_INF = -1.0e7  # Masking value applied to logits during sampling
+
 
 @struct.dataclass
 # pylint: disable-next=all
@@ -60,6 +62,7 @@ class Prefix:
   token: jax.Array  # [1, seqlen]
   caches: List[Tuple[jax.Array, jax.Array]]
   seq_len: int  # true seqlen front pad
+  sampler_config: jax.Array  # [algorithm, temperature, topk, nucleus_topp]
 
 
 @struct.dataclass
@@ -73,8 +76,15 @@ class DecodeState:
   current_position: int
   lens: jax.Array  # [batch_size, 1], the output token length
   start: jax.Array  # [batch_size, 1], the starting pos for each slot
-  input_pos: jax.Array  # [batch_size, 1] input pos for each slot
+  input_pos: (
+      jax.Array
+  )  # [batch_size, 1] total (prefill + decode) length for each slot
   mask: jax.Array  # [batch_size, seqlen] -inf for invalid; 0 for valid
+  # The sampling configuration. If sampling_algorithm is "", this is a
+  # [batch_size, 4] array whose last dimension holds
+  # [algorithm, temperature, topk, nucleus]; otherwise it is a plain list.
+  sampler_config: jax.Array | List[int]
 
 
 # NOTE model specific
@@ -93,7 +103,9 @@ def __init__(
     self.pt_model = pt_model
     self.env = env
     self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
-    self.rng = jax.random.PRNGKey(0)
+    self.rng = jax.random.key(0)
+    # For sampling
+    self.splited_rngs = jax.random.split(self.rng, num=self.env.batch_size)
 
     self.weights = weights
     self.y_sharding = env.sharding_by_axis(1)
@@ -168,6 +180,20 @@ def init_decode_state(
     scalers = []
     if self.env.quant_config.enable_kv_quantization:
       scalers = [c.scalers() for c in caches_obj]
+
+    if self.env.sampling_algorithm == "":
+      # [algorithm, temperature, topk, nucleus]
+      sampler_config = [0, 0.0, 0, 0.0]
+      sampler_config = jnp.tile(sampler_config, (self.env.batch_size, 1))
+    else:
+      sampler_config = [
+          self.env.sampling_algorithm,
+          self.env.temperature,
+          self.env.topk,
+          self.env.nucleus_topp,
+      ]
+
     return DecodeState(
         jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32),
         caches,
@@ -181,6 +207,7 @@ def init_decode_state(
             float("-inf"),
             dtype=self.default_dtype,
         ),  # mask
+        sampler_config,
     )
 
   # pylint: disable-next=all
@@ -280,19 +307,119 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     caches_res = [c.state() for c in caches]
     return torchjax.from_torch((res, caches_res))
 
-  def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
+  def _greedy_sampling(self, logits):
+    return jnp.argmax(logits, axis=-1)
+
+  def _weighted_sampling(self, logits, rng, temperature):
+    return jax.random.categorical(rng, logits / temperature)
+
+  def _nucleus_sampling(self, logits, rng, temperature, nucleus_topp):
+    # return sampling_utils.sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng)
+    """Restrict sampling to the top logits with cumulative probability >=
+    nucleus_topp.
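+
+    For example (illustrative values): with token probabilities
+    [0.5, 0.3, 0.15, 0.05] and nucleus_topp = 0.8, the cumulative sum
+    reaches 0.8 at the second-highest token, so the top two logits are
+    kept and the rest are masked to NEG_INF before sampling.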
+
+    The nucleus sampling method is proposed in the paper `The Curious Case of
+    Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)`
+    """
+    # if nucleus_topp < 0:
+    #   raise ValueError(
+    #       f"Can't apply nucleus with parameter {nucleus_topp=} less than zero"
+    #   )
+    logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1]  # sort descending
+    sorted_cum_probs = jnp.cumsum(
+        jax.nn.softmax(logits_sorted, axis=-1), axis=-1
+    )  # get cumsum probs
+    cutoff_index = jnp.sum(
+        sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True
+    )  # find cutoff index
+    cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1)
+    logits = jnp.where(
+        logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits
+    )
+    result = jax.random.categorical(rng, logits / temperature)
+    return result
+
+  def _topk_sampling(self, logits, rng, temperature, topk):
+    # return sampling_utils.sample_topk_logits(logits, topk, temperature, rng)
+    """Restrict sampling to the best k logits."""
+    # if topk <= 0:
+    #   raise ValueError(f"Can't apply algorithm topk with parameter {topk=} <= 0")
+    sorted_indices = jnp.argsort(logits)[::-1]  # sort in descending order
+    topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk
+    topk_idxs = jnp.where(topk_mask, sorted_indices, -1)
+
+    topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits)
+
+    sampled_idx = jnp.expand_dims(
+        jax.random.categorical(rng, topk_logits / temperature).astype(
+            jnp.int32
+        ),
+        axis=-1,
+    )
+    sampled_tokens = jnp.squeeze(
+        jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1
+    ).astype(jnp.int32)
+
+    return sampled_tokens
+
+  # Algorithm type:
+  # 0: Greedy  1: Weighted  2: Nucleus  3: Top-k
+  def _apply_sampling(
+      self, logits, algorithm, rng, temperature, topk, nucleus_topp
+  ) -> jnp.ndarray:
+    return jax.lax.cond(
+        algorithm == 0,
+        lambda: self._greedy_sampling(logits),  # Greedy
+        lambda: jax.lax.cond(
+            algorithm == 1,
+            lambda: self._weighted_sampling(
+                logits, rng, temperature
+            ),  # Weighted
+            lambda: jax.lax.cond(
+                algorithm == 2,
+                lambda: self._nucleus_sampling(
+                    logits, rng, temperature, nucleus_topp
+                ),  # Nucleus sampling
+                lambda: self._topk_sampling(
+                    logits, rng, temperature, topk
+                ),  # Top-k sampling
+            ),
+        ),
+    )
+
+  def _custom_sampling(
+      self, logits, algorithm, rng, temperature, topk, nucleus_topp
+  ) -> jnp.ndarray:
+    if len(logits.shape) == 2:
+      logits = jnp.expand_dims(logits, 0)
+    logits = logits[:, -1]
+    # Vectorize over the batch so each slot can use its own algorithm, rng
+    # and parameters.
+    apply_sampling_v = jax.vmap(
+        self._apply_sampling, in_axes=(0, 0, 0, 0, 0, 0)
+    )
+    apply_sampling_v = jax.jit(apply_sampling_v)
+    return apply_sampling_v(
+        logits, algorithm, rng, temperature, topk, nucleus_topp
+    ).reshape(self.env.batch_size, -1)
+
+  def _sampling(
+      self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp
+  ) -> jnp.ndarray:
     if len(logits.shape) == 2:
       logits = jnp.expand_dims(logits, 0)
     return (
         sampling_utils.sampling(
             logits[:, -1],
-            self.rng,
-            self.env.sampling_algorithm,
-            self.env.topk,
-            self.env.nucleus_topp,
-            self.env.temperature,
+            rng,
+            algorithm,
+            topk,
+            nucleus_topp,
+            temperature,
         )
-        .reshape(batch_size, -1)
+        .reshape(self.env.batch_size, -1)
         .astype(jnp.int32)
     )
 
   def prefill(
@@ -301,7 +428,7 @@ def prefill(
       self,
       *,
       params: Any,  # Weights
       existing_prefix: Optional[Prefix] = None,
-      padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array],
+      padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array]
       true_length: int,
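+      # When env.sampling_algorithm == "", `sampler` is not a callable but
+      # the per-request config list [algorithm, temperature, topk,
+      # nucleus_topp]; see the isinstance check below.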
       sampler: Optional[Callable[[Any], Any]] = None,
   ) -> Tuple[Prefix, engine_api.ResultTokens]:
@@ -321,17 +448,45 @@ def prefill(
     )
     if len(logits.shape) == 3:  # b, seqlen, num words
       logits = logits[0]  # seqlen, num words
-    if sampler:
-      token = sampler(logits[true_length - 1])
+
+    if self.env.sampling_algorithm == "":
+      assert isinstance(
+          sampler, (list, tuple)
+      ), f"{type(sampler)} is not valid"
+      algorithm, temperature, topk, nucleus_topp = sampler
+
+      algorithm = jnp.array([algorithm])
+      temperature = jnp.array([temperature])
+      topk = jnp.array([topk])
+      nucleus_topp = jnp.array([nucleus_topp])
+
+      # Prefill only handles a batch size of 1, so the unsplit rng is enough.
+      sampling = self._custom_sampling
+      rng = self.rng.reshape(1)
     else:
-      token = sampling_utils.sampling(
-          logits[true_length - 1],
-          self.rng,
+      algorithm, temperature, topk, nucleus_topp = (
           self.env.sampling_algorithm,
+          self.env.temperature,
           self.env.topk,
           self.env.nucleus_topp,
-          self.env.temperature,
       )
+      sampling = self._sampling
+      rng = self.rng
+
+    token = sampling(
+        logits=logits,
+        algorithm=algorithm,
+        rng=rng,
+        temperature=temperature,
+        topk=topk,
+        nucleus_topp=nucleus_topp,
+    )
+
     token_out = jnp.reshape(token, (1, 1))
     data = jnp.concatenate(
         [
@@ -357,7 +512,10 @@ def prefill(
     #         v, seq_len - true_length, true_length, axis=2))
     #     for k, v in updated_caches
     # ]
-    return Prefix(token, updated_caches, true_length), result
+    return (
+        Prefix(
+            token,
+            updated_caches,
+            true_length,
+            # `sampler` is None when a fixed sampling_algorithm is configured,
+            # so fall back to a placeholder config instead of jnp.array(None).
+            jnp.array(sampler if sampler is not None else [0, 0.0, 0, 0.0]),
+        ),
+        result,
+    )
 
   def shrink_prefix(
       self,
@@ -476,6 +634,19 @@ def insert(cache, scaler, new_entry, update_index):
       caches.append((kcache, vcache))
       scales.append((kscale, vscale))
     lens = decode_state.lens.at[slot].set(1)
+
+    if self.env.sampling_algorithm == "":
+      sampler_config = decode_state.sampler_config.at[slot].set(
+          prefix.sampler_config
+      )
+    else:
+      sampler_config = [
+          self.env.sampling_algorithm,
+          self.env.temperature,
+          self.env.topk,
+          self.env.nucleus_topp,
+      ]
+
     return DecodeState(
         tokens,
         caches,
@@ -485,6 +656,7 @@ def insert(cache, scaler, new_entry, update_index):
         start,
         input_pos,
         mask,
+        sampler_config,
     )
 
   # pylint: disable-next=all
@@ -569,6 +741,19 @@ def insert(cache, scaler, new_entry):
      scales.append((kscale, vscale))
 
     lens = decode_state.lens.at[slot].set(1)
+
+    if self.env.sampling_algorithm == "":
+      sampler_config = decode_state.sampler_config.at[slot].set(
+          prefix.sampler_config
+      )
+    else:
+      sampler_config = [
+          self.env.sampling_algorithm,
+          self.env.temperature,
+          self.env.topk,
+          self.env.nucleus_topp,
+      ]
+
     return DecodeState(
         tokens,
         caches,
@@ -578,6 +763,7 @@ def insert(cache, scaler, new_entry):
         start,
         input_pos,
         mask,
+        sampler_config,
     )
 
   def _insert_page_attention(
@@ -613,6 +799,19 @@ def _insert_page_attention(
     input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len)
     scales = None
     lens = decode_state.lens.at[slot].set(1)
+
+    if self.env.sampling_algorithm == "":
+      sampler_config = decode_state.sampler_config.at[slot].set(
+          prefix.sampler_config
+      )
+    else:
+      sampler_config = [
+          self.env.sampling_algorithm,
+          self.env.temperature,
+          self.env.topk,
+          self.env.nucleus_topp,
+      ]
+
     return DecodeState(
         tokens,
         caches,
@@ -622,6 +821,7 @@ def _insert_page_attention(
         start,
         input_pos,
         mask,
+        sampler_config,
     )
 
   def insert(
       self,
@@ -729,7 +929,9 @@ def false_comp(b, i, bk, start, end):
 
     return b_next, i_next
 
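+  # The per-slot sampling configuration now travels in
+  # DecodeState.sampler_config, so generate() no longer takes a sampler
+  # argument.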
   def generate(
-      self, params: Any, decode_state: DecodeState, sampler=None
+      self,
+      params: Any,
+      decode_state: DecodeState,
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     return (None, None)
 
@@ -752,7 +954,6 @@ def generate_impl(
       self,
       params: Any,
       decode_state: DecodeState,
-      sampler=None,
       page_token_indices=None,
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     # seq_len = padded_tokens.shape[0]
@@ -764,12 +965,16 @@ def generate_impl(
     else:
       input_indexes = decode_state.input_pos
 
-    ragged_batch_index, ragged_block_index = (
-        self.precompute_ragged_block_indices(decode_state)
-    )
-    ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
-        (-1)
-    ), ragged_block_index.reshape((-1))
+    # TODO(lancewang): Remove ragged index precomputation
+    # ragged_batch_index, ragged_block_index = (
+    #     self.precompute_ragged_block_indices(decode_state)
+    # )
+    # ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
+    #     (-1)
+    # ), ragged_block_index.reshape((-1))
+
+    ragged_batch_index = 0
+    ragged_block_index = 0
 
     def update_mask():
       if self.env.ring_buffer:
@@ -799,10 +1004,22 @@ def update_mask():
       # fill mask later, now use flash attention
       mask = update_mask()
 
-    if sampler:
-      next_token = sampler(logits[:, -1])
+    if self.env.sampling_algorithm == "":
+      # Each slot samples with its own config row and its own rng.
+      next_token = self._custom_sampling(
+          logits=logits,
+          algorithm=decode_state.sampler_config[:, 0:1].reshape(-1),
+          rng=self.splited_rngs,
+          temperature=decode_state.sampler_config[:, 1:2].reshape(-1),
+          topk=decode_state.sampler_config[:, 2:3].reshape(-1),
+          nucleus_topp=decode_state.sampler_config[:, 3:4].reshape(-1),
+      )
     else:
-      next_token = self._sampling(logits, self.env.batch_size)
+      # sampler_config is a plain list here, so sample from the fixed
+      # environment configuration instead of slicing it.
+      next_token = self._sampling(
+          logits=logits,
+          algorithm=self.env.sampling_algorithm,
+          rng=self.rng,
+          temperature=self.env.temperature,
+          topk=self.env.topk,
+          nucleus_topp=self.env.nucleus_topp,
+      )
 
     if self.env.ring_buffer:
       input_pos = decode_state.input_pos + 1
       lens = decode_state.lens + 1
@@ -844,6 +1061,7 @@ def update_mask():
         decode_state.start,
         input_pos,
         mask,
+        decode_state.sampler_config,
     )
     return new_decode_state, result_tokens
 
@@ -963,6 +1181,7 @@ def get_prefix_destination_sharding(self) -> Prefix:
         if self.env.page_attention
         else self.cache_sharding,
         self.replicated,
+        self.replicated,
     )
 
   def get_decode_state_sharding(self) -> DecodeState:
@@ -976,6 +1195,7 @@ def get_decode_state_sharding(self) -> DecodeState:
         self.replicated,
         self.replicated,
         self.replicated,
+        self.replicated,
     )
 
   def get_prefix_sequence_ddim(self) -> Any:
diff --git a/tests/helpers.py b/tests/helpers.py
index ac0ea5f..860389d 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -7,7 +7,9 @@
 
 
 # pylint: disable-next=all
-def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
+def make_env_tiny(
+    bf16_enable=True, env_data_update_fn=lambda _: None, batch_size=1
+):
   torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
   torch.set_default_dtype(torch_dtype)
   jax.config.update("jax_dynamic_shapes", False)
@@ -19,7 +21,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
   environment_data.cache_sequence_length = 128
   environment_data.bf16_enable = bf16_enable
   environment_data.model_type = "llama-2-tiny"
-  environment_data.batch_size = 1
+  environment_data.batch_size = batch_size
   environment_data.num_layers = config.n_layers
   environment_data.cache_shape = (
       1,
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 57245c0..2fe94ba 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -20,16 +20,61 @@
 
 from jetstream_pt.third_party.llama import model_exportable
 from jetstream_pt.engine import PyTorchEngine
+from jetstream_pt.engine import DecodeState
+from jetstream_pt.engine import Prefix
 from tests import helpers
+from jetstream_pt import cache_manager
+
+
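+# MockEngine pins deterministic logits in place of the real model so the
+# sampling tests below can assert exact token ids.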
+class MockEngine(PyTorchEngine):
+
+  def _call_model_prefill(self, weights, tokens, input_indexes):
+    caches = [
+        cache_manager.KVCachePrefill(
+            self.env.quant_config.enable_kv_quantization
+        )
+        for _ in self.pt_model.layers
+    ]
+    # logits = jnp.ones((self.env.batch_size, 1), jnp.float32)
+    assert (
+        self.env.batch_size == 1
+    ), f"The batch size {self.env.batch_size} != 1"
+    logits = jnp.array([[0.5, 0.6, 0.7, 0.8]])
+    return logits, caches
+
+  def _call_model_generate(
+      self,
+      weights,
+      tokens,
+      input_indexes,
+      caches,
+      cache_scales,
+      mask,
+      start,
+      input_pos,
+      ragged_batch_index,
+      ragged_block_index,
+      page_token_indices,
+  ):
+    logits = jnp.array(
+        [
+            [[0.5, 0.6, 0.7, 0.8]],
+            # [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+            [[0.4, 0.3, 0.2, 0.1]],
+        ]
+    )
+    return logits, caches, cache_scales
 
 
 class EngineTest(unittest.TestCase):
 
-  def setup(self):
-    env, model_arg = helpers.make_env_tiny(bf16_enable=True)
+  def setup(self, batch_size=1):
+    env, model_arg = helpers.make_env_tiny(
+        bf16_enable=True, batch_size=batch_size
+    )
     model_ours = model_exportable.Transformer(model_arg, env)
-    engine = PyTorchEngine(pt_model=model_ours, env=env)
-    engine.rng = jax.random.PRNGKey(0)
+    engine = MockEngine(pt_model=model_ours, env=env)
+    engine.rng = jax.random.key(0)
     return engine
 
   def test_sampling_2D(self):
@@ -37,14 +82,23 @@ def test_sampling_2D(self):
     # test greedy
     engine = self.setup()
     self.assertEqual(engine.env.sampling_algorithm, "greedy")
     logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]])
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits, "greedy", engine.rng, temperature=1.0, topk=1, nucleus_topp=0.0
+    )
     self.assertEqual(token, jnp.array([[0]]))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test weighted
     engine.env.sampling_algorithm = "weighted"
     engine.env.temperature = 5.0
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=5.0,
+        topk=1,
+        nucleus_topp=0.0,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
@@ -52,21 +106,36 @@ def test_sampling_2D(self):
     engine.env.sampling_algorithm = "topk"
     engine.env.temperature = 5.0
     engine.env.topk = 4
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=5.0,
+        topk=4,
+        nucleus_topp=0.0,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test nucleus
     engine.env.sampling_algorithm = "nucleus"
-    engine.env.temperature = 0.0
+    engine.env.temperature = 1.0
     engine.env.nucleus_topp = 0.8
-    token = engine._sampling(logits, batch_size=1)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        temperature=0.0,
+        topk=1,
+        nucleus_topp=0.8,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
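+  # The 3D tests feed logits shaped [batch, seqlen, vocab]; _sampling keeps
+  # only the last position of each row, so every batch row yields one token.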
   def test_sampling_3D(self):
     # test greedy
-    engine = self.setup()
+    engine = self.setup(batch_size=2)
+
     self.assertEqual(engine.env.sampling_algorithm, "greedy")
     logits = jnp.array(
         [
@@ -74,14 +143,28 @@ def test_sampling_3D(self):
             [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
         ]
     )
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
     # test weighted
     engine.env.sampling_algorithm = "weighted"
     engine.env.temperature = 10.0
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
@@ -89,7 +172,14 @@ def test_sampling_3D(self):
     engine.env.sampling_algorithm = "topk"
     engine.env.temperature = 1.0
     engine.env.topk = 3
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
@@ -97,10 +187,351 @@ def test_sampling_3D(self):
     engine.env.sampling_algorithm = "nucleus"
     engine.env.temperature = 1.0
     engine.env.nucleus_topp = 0.8
-    token = engine._sampling(logits, batch_size=2)
+    token = engine._sampling(
+        logits,
+        engine.env.sampling_algorithm,
+        engine.rng,
+        engine.env.temperature,
+        engine.env.topk,
+        engine.env.nucleus_topp,
+    )
     self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]])))
     self.assertTrue(jnp.isdtype(token, jnp.int32))
 
+  def test_custom_sampling_3D(self):
+    engine = self.setup(batch_size=2)
+    engine.rng = jax.random.key(3)
+    engine.splited_rngs = jax.random.split(
+        engine.rng, num=engine.env.batch_size
+    )
+    engine.env.sampling_algorithm = ""
+
+    # A separate batch-size-1 engine is needed because _sampling reshapes
+    # its output to the engine's batch size.
+    engine_b1 = self.setup()
+    engine_b1.rng = jax.random.key(3)
+    logits = jnp.array(
+        [
+            [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]],
+            [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]],
+        ]
+    )
+
+    # test greedy
+    token = engine._custom_sampling(
+        logits,
+        jnp.array([0, 0]),
+        engine.splited_rngs,
+        temperature=jnp.array([[0.0], [0.0]]),
+        topk=jnp.array([[0], [0]]),
+        nucleus_topp=jnp.array([[0.0], [0.0]]),
+    )
+    original_tokens = []
+    for i in range(2):
+      original_token = engine_b1._sampling(
+          logits[i],
+          "greedy",
+          engine.splited_rngs[i],
+          temperature=1.0,
+          topk=0,
+          nucleus_topp=0.0,
+      )
+      original_tokens.append(original_token)
+    original_tokens = jnp.concatenate(original_tokens)
+
+    print(f"custom sampling token {token} vs original tokens {original_tokens}")
+    self.assertTrue(jnp.array_equal(token, original_tokens))
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test weighted
+    engine.env.sampling_algorithm = "weighted"
+    token = engine._custom_sampling(
+        logits,
+        jnp.array([1, 1]),
+        engine.splited_rngs,
+        temperature=jnp.array([1, 1]),
+        topk=jnp.array([0, 0]),
+        nucleus_topp=jnp.array([0.0, 0.0]),
+    )
+    original_tokens = []
+    for i in range(2):
+      original_token = engine_b1._sampling(
+          logits[i],
+          "weighted",
+          engine.splited_rngs[i],
+          temperature=1,
+          topk=0,
+          nucleus_topp=0.0,
+      )
+      original_tokens.append(original_token)
+    original_tokens = jnp.concatenate(original_tokens)
+
+    print(f"custom sampling token {token} vs original tokens {original_tokens}")
+    self.assertTrue(jnp.array_equal(token, original_tokens))
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
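+    # Each row carries its own [algorithm, temperature, topk, nucleus_topp]
+    # config and consumes its own entry of splited_rngs, which is why the
+    # batched _custom_sampling results must match the per-row _sampling
+    # calls.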
+    # test topk
+    engine.env.sampling_algorithm = "topk"
+    token = engine._custom_sampling(
+        logits,
+        jnp.array([3, 3]),
+        engine.splited_rngs,
+        temperature=jnp.array([[1.0], [1.0]]),
+        topk=jnp.array([[3], [3]]),
+        nucleus_topp=jnp.array([[0.0], [0.0]]),
+    )
+    original_tokens = []
+    for i in range(2):
+      original_token = engine_b1._sampling(
+          logits[i],
+          "topk",
+          engine.splited_rngs[i],
+          temperature=1.0,
+          topk=3,
+          nucleus_topp=0.0,
+      )
+      original_tokens.append(original_token)
+    original_tokens = jnp.concatenate(original_tokens)
+
+    print(f"custom sampling token {token} vs original tokens {original_tokens}")
+    self.assertTrue(jnp.array_equal(token, original_tokens))
+    self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test nucleus
+    engine.env.sampling_algorithm = "nucleus"
+    token = engine._custom_sampling(
+        logits,
+        jnp.array([2, 2]),
+        engine.splited_rngs,
+        temperature=jnp.array([[1.0], [1.0]]),
+        topk=jnp.array([[0], [0]]),
+        nucleus_topp=jnp.array([[0.8], [0.8]]),
+    )
+
+    original_tokens = []
+    for i in range(2):
+      original_token = engine_b1._sampling(
+          logits[i],
+          "nucleus",
+          engine.splited_rngs[i],
+          temperature=1.0,
+          topk=0,
+          nucleus_topp=0.8,
+      )
+      original_tokens.append(original_token)
+    original_tokens = jnp.concatenate(original_tokens)
+    print(f"custom sampling token {token} vs original tokens {original_tokens}")
+    self.assertTrue(jnp.array_equal(token, original_tokens))
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    # test greedy + topk
+    token = engine._custom_sampling(
+        logits,
+        jnp.array([0, 3]),
+        engine.splited_rngs,
+        temperature=jnp.array([[0.0], [1.0]]),
+        topk=jnp.array([[0], [3]]),
+        nucleus_topp=jnp.array([[0.0], [0.0]]),
+    )
+    original_tokens = []
+
+    i = 0
+    original_token = engine_b1._sampling(
+        logits[i],
+        "greedy",
+        engine.splited_rngs[i],
+        temperature=0.0,
+        topk=0,
+        nucleus_topp=0.8,
+    )
+    original_tokens.append(original_token)
+
+    i = 1
+    original_token = engine_b1._sampling(
+        logits[i],
+        "topk",
+        engine.splited_rngs[i],
+        temperature=1.0,
+        topk=3,
+        nucleus_topp=0.0,
+    )
+    original_tokens.append(original_token)
+
+    original_tokens = jnp.concatenate(original_tokens)
+
+    print(f"custom sampling token {token} vs original tokens {original_tokens}")
+    self.assertTrue(jnp.array_equal(token, original_tokens))
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+  def test_prefill_with_custom_sampling(self):
+    engine = self.setup()
+    engine.env.sampling_algorithm = ""
+
+    # Inputs don't matter
+    params = jnp.zeros((1,), jnp.float32)
+    padded_tokens = jnp.zeros((1,), jnp.float32)
+    true_length = 1
+
+    # Greedy
+    # algorithm, temperature, topk, nucleus_topp
+    sampler = [0, 1.0, 3, 0.8]
+    prefix, _ = engine.prefill(
+        params=params,
+        padded_tokens=padded_tokens,
+        true_length=true_length,
+        sampler=sampler,
+    )
+    token = prefix.token
+    print(f"Greedy output: {token}")
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
+
+    print(
+        f"prefix sampler config {prefix.sampler_config} vs sampler {jnp.array(sampler)}"
+    )
+    self.assertAlmostEqual(
+        prefix.sampler_config.all(), jnp.array(sampler).all()
+    )
+
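+    # The remaining cases reuse the same prefill inputs; only the sampler
+    # list [algorithm, temperature, topk, nucleus_topp] changes.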
print(f"Weighted output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Nucleus + sampler = [2, 1.0, 3, 0.0] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Topk output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Topk + sampler = [3, 1.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Nucleus output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + def test_insert_no_wrap_with_custom_sampling(self): + engine = self.setup() + engine.env.sampling_algorithm = "" + engine.env.batch_size = 2 + cache_shape = engine.env.cache_shape + + sampler_config_raw = [0, 1.0, 3, 0.8] + sampler_config = jnp.array(sampler_config_raw) + + prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3]) + prefill_cache = [] + for _ in range(engine.env.num_layers): + prefill_cache.append( + ( + jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), + jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), + ) + ) + + prefix = Prefix( + token=jnp.ones((1)), + caches=prefill_cache, + seq_len=16, + sampler_config=sampler_config, + ) + + doesnt_matter = jnp.array([0]) + kv_cache = engine.env.make_caches_generate() + kv_cache = [c.state() for c in kv_cache] + + decode_state = DecodeState( + tokens=jnp.zeros((engine.env.batch_size, 1)), + caches=kv_cache, + cache_scales=[doesnt_matter], + current_position=16, + lens=jnp.zeros((engine.env.batch_size, 1)), + start=jnp.zeros((engine.env.batch_size, 1)), + input_pos=jnp.zeros((engine.env.batch_size,)), + mask=jnp.zeros((engine.env.batch_size, 128)), + sampler_config=jnp.zeros((engine.env.batch_size, 4)), + ) + + # Insert to slot 1 + result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1) + + self.assertAlmostEqual( + result_decode_state.tokens.all(), decode_state.tokens.all() + ) + self.assertAlmostEqual( + result_decode_state.sampler_config.all(), + jnp.array([[0, 0, 0, 0], sampler_config_raw]).all(), + ) + + def test_decode_with_custom_sampling(self): + engine = self.setup(batch_size=2) + engine.rng = jax.random.key(3) + engine.splited_rngs = jax.random.split( + engine.rng, num=engine.env.batch_size + ) + engine.env.sampling_algorithm = "" + + # Inputs doesn't matter + doesnt_matter = jnp.array([0]) + params = doesnt_matter + + decode_state = DecodeState( + tokens=jnp.zeros((engine.env.batch_size, 1)), + caches=[doesnt_matter], + cache_scales=[doesnt_matter], + current_position=0, + lens=jnp.zeros((engine.env.batch_size, 1)), + start=doesnt_matter, + input_pos=jnp.zeros((engine.env.batch_size,)), + mask=jnp.zeros((engine.env.batch_size, 1)), + sampler_config=jnp.array([[0, 0.0, 0, 0.0], [3, 1.0, 3, 0.0]]), + ) + + # Topk + Weighted + # algorithm, temperature, topk, nucleus_topp + decode_state, _ = engine.generate_impl( + params=params, decode_state=decode_state + ) + token = decode_state.tokens + print(f"Greedy output: {token}") + 
+    print(f"Decode output: {token}")
+    self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]])))
+    self.assertTrue(jnp.isdtype(token, jnp.int32))
 
   # def test_insert(self):
   #   seqlen = 32