From 1182c48e7b50b836acca4c120315d3eb44d01a4d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 30 Sep 2024 01:43:37 +0000 Subject: [PATCH 1/9] Initial change. --- jetstream_pt/engine.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index b7298e1..0b6a60f 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -75,7 +75,9 @@ class DecodeState: start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - + topk: int + nucleus_topp: int + temperature: int # NOTE model specific @@ -280,7 +282,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) - def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray: + def _sampling(self, logits: Any, batch_size: int, topk: Any, nucleus_topp: Any, temperature: Any) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) return ( @@ -288,9 +290,9 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray: logits[:, -1], self.rng, self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - self.env.temperature, + topk, + nucleus_topp, + temperature, ) .reshape(batch_size, -1) .astype(jnp.int32) @@ -304,6 +306,7 @@ def prefill( padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], true_length: int, sampler: Optional[Callable[[Any], Any]] = None, + sampling_config: Any = None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -321,6 +324,11 @@ def prefill( ) if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words + + topk = sampling_config.topk if sampling_config else self.env.topk + nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp + temperature = sampling_config.temperature if sampling_config else self.env.temperature + if sampler: token = sampler(logits[true_length - 1]) else: @@ -328,9 +336,9 @@ def prefill( logits[true_length - 1], self.rng, self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - self.env.temperature, + topk, + nucleus_topp, + temperature, ) token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( @@ -729,7 +737,7 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState, sampler=None + self, params: Any, decode_state: DecodeState, sampler=None, sampling_config=None ) -> tuple[DecodeState, engine_api.ResultTokens]: return (None, None) @@ -753,6 +761,7 @@ def generate_impl( params: Any, decode_state: DecodeState, sampler=None, + sampling_config=None, page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] @@ -799,10 +808,15 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() + topk = sampling_config.topk if sampling_config else self.env.topk + nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp + temperature = sampling_config.temperature if sampling_config else self.env.temperature + if sampler: next_token = sampler(logits[:, -1]) else: - next_token = self._sampling(logits, self.env.batch_size) + next_token = self._sampling(logits, self.env.batch_size, topk, nucleus_topp, temperature) + if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 lens = decode_state.lens + 1 @@ -976,6 +990,10 @@ def get_decode_state_sharding(self) -> DecodeState: self.replicated, self.replicated, self.replicated, + self.replicated, + self.replicated, + self.replicated, + self.replicated, ) def get_prefix_sequence_ddim(self) -> Any: From 931e2ce69bfc821f4816c480d1ee80d990045a03 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 30 Sep 2024 05:30:50 +0000 Subject: [PATCH 2/9] Limit the changes to temperatures, sets temperature to each request; --- jetstream_pt/engine.py | 63 ++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 0b6a60f..649a67e 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -60,6 +60,8 @@ class Prefix: token: jax.Array # [1, seqlen] caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad + # temperature parameter for scaling probability + temperature: float @struct.dataclass @@ -75,9 +77,7 @@ class DecodeState: start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - topk: int - nucleus_topp: int - temperature: int + temperatures: List[float] # [batch_size, 1], the temperature for each slot # NOTE model specific @@ -170,6 +170,7 @@ def init_decode_state( scalers = [] if self.env.quant_config.enable_kv_quantization: scalers = [c.scalers() for c in caches_obj] + temperatures = [self.env.temperature] * self.env.batch_size return DecodeState( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, @@ -183,6 +184,7 @@ def init_decode_state( float("-inf"), dtype=self.default_dtype, ), # mask + temperatures, ) # pylint: disable-next=all @@ -282,7 +284,9 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) - def _sampling(self, logits: Any, batch_size: int, topk: Any, nucleus_topp: Any, temperature: Any) -> jnp.ndarray: + def _sampling( + self, logits: Any, batch_size: int, temperatures: List[float] + ) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) return ( @@ -290,9 +294,9 @@ def _sampling(self, logits: Any, batch_size: int, topk: Any, nucleus_topp: Any, logits[:, -1], self.rng, self.env.sampling_algorithm, - topk, - nucleus_topp, - temperature, + self.env.topk, + self.env.nucleus_topp, + temperatures, ) .reshape(batch_size, -1) .astype(jnp.int32) @@ -305,8 +309,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], true_length: int, - sampler: Optional[Callable[[Any], Any]] = None, - sampling_config: Any = None, + sampler: Any, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -325,21 +328,24 @@ def prefill( if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - topk = sampling_config.topk if sampling_config else self.env.topk - nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp - temperature = sampling_config.temperature if sampling_config else self.env.temperature + temperature = ( + sampler["temperature"] + if isinstance(sampler, dict) and "temperature" in sampler + else self.env.temperature + ) - if sampler: + if sampler and callable(sampler): token = sampler(logits[true_length - 1]) else: token = sampling_utils.sampling( logits[true_length - 1], self.rng, self.env.sampling_algorithm, - topk, - nucleus_topp, + self.env.topk, + self.env.nucleus_topp, temperature, ) + token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( [ @@ -365,7 +371,7 @@ def prefill( # v, seq_len - true_length, true_length, axis=2)) # for k, v in updated_caches # ] - return Prefix(token, updated_caches, true_length), result + return Prefix(token, updated_caches, true_length, temperature), result def shrink_prefix( self, @@ -484,6 +490,7 @@ def insert(cache, scaler, new_entry, update_index): caches.append((kcache, vcache)) scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) + decode_state.temperatures[slot] = prefix.temperature return DecodeState( tokens, caches, @@ -493,6 +500,7 @@ def insert(cache, scaler, new_entry, update_index): start, input_pos, mask, + decode_state.temperatures, ) # pylint: disable-next=all @@ -577,6 +585,7 @@ def insert(cache, scaler, new_entry): scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) + decode_state.temperatures[slot] = prefix.temperature return DecodeState( tokens, caches, @@ -586,6 +595,7 @@ def insert(cache, scaler, new_entry): start, input_pos, mask, + decode_state.temperatures, ) def _insert_page_attention( @@ -621,6 +631,7 @@ def _insert_page_attention( input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) scales = None lens = decode_state.lens.at[slot].set(1) + decode_state.temperatures[slot] = prefix.temperature return DecodeState( tokens, caches, @@ -630,6 +641,7 @@ def _insert_page_attention( start, input_pos, mask, + decode_state.temperatures, ) def insert( @@ -737,7 +749,9 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState, sampler=None, sampling_config=None + self, + params: Any, + decode_state: DecodeState, ) -> tuple[DecodeState, engine_api.ResultTokens]: return (None, None) @@ -761,7 +775,6 @@ def generate_impl( params: Any, decode_state: DecodeState, sampler=None, - sampling_config=None, page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] @@ -808,14 +821,12 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - topk = sampling_config.topk if sampling_config else self.env.topk - nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp - temperature = sampling_config.temperature if sampling_config else self.env.temperature - - if sampler: + if sampler and callable(sampler): next_token = sampler(logits[:, -1]) else: - next_token = self._sampling(logits, self.env.batch_size, topk, nucleus_topp, temperature) + next_token = self._sampling( + logits, self.env.batch_size, decode_state.temperatures + ) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -977,6 +988,7 @@ def get_prefix_destination_sharding(self) -> Prefix: if self.env.page_attention else self.cache_sharding, self.replicated, + self.replicated, ) def get_decode_state_sharding(self) -> DecodeState: @@ -991,9 +1003,6 @@ def get_decode_state_sharding(self) -> DecodeState: self.replicated, self.replicated, self.replicated, - self.replicated, - self.replicated, - self.replicated, ) def get_prefix_sequence_ddim(self) -> Any: From efefe9c6d0562c5e176b24b50a1b75b036267573 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 21 Oct 2024 20:35:24 +0000 Subject: [PATCH 3/9] Using sampler config instead of temperature only to be more comprehensive. Added tests. --- jetstream_pt/engine.py | 285 ++++++++++++++++++++----- tests/helpers.py | 6 +- tests/test_engine.py | 459 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 687 insertions(+), 63 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 649a67e..7c77aff 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -53,6 +53,8 @@ Params = jax.Array PrefillInputs = jax.Array +NEG_INF = -1.0e7 # Sampling masking + @struct.dataclass # pylint: disable-next=all @@ -60,8 +62,7 @@ class Prefix: token: jax.Array # [1, seqlen] caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad - # temperature parameter for scaling probability - temperature: float + sampler_config: jax.Array # Sampler or sampling config, [] @struct.dataclass @@ -75,9 +76,16 @@ class DecodeState: current_position: int lens: jax.Array # [batch_size, 1], the output token length start: jax.Array # [batch_size, 1], the starting pos for each slot - input_pos: jax.Array # [batch_size, 1] input pos for each slot + input_pos: ( + jax.Array + ) # [batch_size, 1] total (prefill + decode) length for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - temperatures: List[float] # [batch_size, 1], the temperature for each slot + # The sampling configuration. + # If sampling_algorithm set to empty, the shape is [batch_size, 1] + # Otherwise it's a list of intergers + # The last dimension contains [algorithm, temperature, topk, nucleus] + sampler_config: jax.Array | List[int] + # NOTE model specific @@ -95,7 +103,11 @@ def __init__( self.pt_model = pt_model self.env = env self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 - self.rng = jax.random.PRNGKey(0) + self.rng = jax.random.key(0).reshape( + 1, + ) + # For sampling + self.splited_rngs = jax.random.split(self.rng, num=self.env.batch_size) self.weights = weights self.y_sharding = env.sharding_by_axis(1) @@ -170,7 +182,20 @@ def init_decode_state( scalers = [] if self.env.quant_config.enable_kv_quantization: scalers = [c.scalers() for c in caches_obj] - temperatures = [self.env.temperature] * self.env.batch_size + + if self.env.sampling_algorithm == "": + # [algorithm, temperature, topk, nucleus] + sampler_config = [0, 0.0, 0, 0.0] + sampler_config = jnp.tile(sampler_config, (self.env.batch_size, 1)) + + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, @@ -184,7 +209,7 @@ def init_decode_state( float("-inf"), dtype=self.default_dtype, ), # mask - temperatures, + sampler_config, ) # pylint: disable-next=all @@ -284,21 +309,119 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) + def _greedy_sampling(self, logits): + return jnp.argmax(logits, axis=-1) + + def _weighted_sampling(self, logits, rng, temperature): + return jax.random.categorical(rng, logits / temperature) + + def _nucleus_sampling(self, logits, rng, temperature, nucleus_topp): + # return sampling_utils.sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) + """Restrict sampling to the top logits with cumulative probability >= + nucleus_topp. + + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + + """ + # if nucleus_topp < 0: + # raise ValueError( + # "Can't apply nucleus with parameter {nucleus_topp=} less zero" + # ) + logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending + sorted_cum_probs = jnp.cumsum( + jax.nn.softmax(logits_sorted, axis=-1), axis=-1 + ) # get cumsum probs + cutoff_index = jnp.sum( + sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True + ) # find cutoff index + cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) + logits = jnp.where( + logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits + ) + result = jax.random.categorical(rng, logits / temperature) + return result + + def _topk_sampling(self, logits, rng, temperature, topk): + # return sampling_utils.sample_topk_logits(logits, topk, temperature, rng) + """Restricting sampling to the best k logits.""" + # if topk <= 0: + # raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") + sorted_indices = jnp.argsort(logits)[::-1] # Sort in descending order + topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk + topk_idxs = jnp.where(topk_mask, sorted_indices, -1) + + topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) + + sampled_idx = jnp.expand_dims( + jax.random.categorical(rng, topk_logits / temperature).astype( + jnp.int32 + ), + axis=-1, + ) + sampled_tokens = jnp.squeeze( + jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 + ).astype(jnp.int32) + + return sampled_tokens + + # Algorithm type: + # 0: Greedy 1: Weighted 2: Nucleus 3: Top-k + def _apply_sampling( + self, logits, algorithm, rng, temperature, topk, nucleus_topp + ) -> jnp.ndarray: + return jax.lax.cond( + algorithm == 0, + lambda: self._greedy_sampling(logits), # Greedy + lambda: jax.lax.cond( + algorithm == 1, + lambda: self._weighted_sampling( + logits, rng, temperature + ), # Weighted + lambda: jax.lax.cond( + algorithm == 2, + lambda: self._nucleus_sampling( + logits, rng, temperature, nucleus_topp + ), # Nucleus sampling + lambda: self._topk_sampling( + logits, rng, temperature, topk + ), # Top-k sampling + ), + ), + ) + + def _custom_sampling( + self, logits, algorithm, rng, temperature, topk, nucleus_topp + ) -> jnp.ndarray: + if len(logits.shape) == 2: + logits = jnp.expand_dims(logits, 0) + logits = logits[:, -1] + apply_sampling_v = jax.vmap( + lambda _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp: self._apply_sampling( + _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp + ), + in_axes=(0, 0, 0, 0, 0, 0), + ) + apply_sampling_v = jax.jit(apply_sampling_v) + return apply_sampling_v( + logits, algorithm, rng, temperature, topk, nucleus_topp + ).reshape(self.env.batch_size, -1) + def _sampling( - self, logits: Any, batch_size: int, temperatures: List[float] + self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp ) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) return ( sampling_utils.sampling( logits[:, -1], - self.rng, - self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - temperatures, + rng, + algorithm, + topk, + nucleus_topp, + temperature, ) - .reshape(batch_size, -1) + .reshape(self.env.batch_size, -1) .astype(jnp.int32) ) @@ -307,9 +430,9 @@ def prefill( *, params: Any, # Weights existing_prefix: Optional[Prefix] = None, - padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], + padded_tokens: PrefillInputs, # PrefillInputs[jax.Array] true_length: int, - sampler: Any, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -328,23 +451,39 @@ def prefill( if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - temperature = ( - sampler["temperature"] - if isinstance(sampler, dict) and "temperature" in sampler - else self.env.temperature - ) - - if sampler and callable(sampler): - token = sampler(logits[true_length - 1]) + if self.env.sampling_algorithm == "": + assert ( + isinstance(sampler, List) + or isinstance(sampler, list) + or isinstance(sampler, Tuple) + or isinstance(sampler, tuple) + ), f"{type(sampler)} is not valid" + algorithm, temperature, topk, nucleus_topp = sampler + + algorithm = jnp.array([algorithm]) + temperature = jnp.array([temperature]) + topk = jnp.array([topk]) + nucleus_topp = jnp.array([nucleus_topp]) + + # Prefill only handle batch size of 1, therefore no need to use splitted rngs + sampling = self._custom_sampling else: - token = sampling_utils.sampling( - logits[true_length - 1], - self.rng, + algorithm, temperature, topk, nucleus_topp = ( self.env.sampling_algorithm, + self.env.temperature, self.env.topk, self.env.nucleus_topp, - temperature, ) + sampling = self._sampling + + token = sampling( + logits=logits, + algorithm=algorithm, + rng=self.rng, + temperature=temperature, + topk=topk, + nucleus_topp=nucleus_topp, + ) token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( @@ -371,7 +510,10 @@ def prefill( # v, seq_len - true_length, true_length, axis=2)) # for k, v in updated_caches # ] - return Prefix(token, updated_caches, true_length, temperature), result + return ( + Prefix(token, updated_caches, true_length, jnp.array(sampler)), + result, + ) def shrink_prefix( self, @@ -490,7 +632,19 @@ def insert(cache, scaler, new_entry, update_index): caches.append((kcache, vcache)) scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) - decode_state.temperatures[slot] = prefix.temperature + + if self.env.sampling_algorithm == "": + sampler_config = decode_state.sampler_config.at[slot].set( + prefix.sampler_config + ) + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( tokens, caches, @@ -500,7 +654,7 @@ def insert(cache, scaler, new_entry, update_index): start, input_pos, mask, - decode_state.temperatures, + sampler_config, ) # pylint: disable-next=all @@ -585,7 +739,19 @@ def insert(cache, scaler, new_entry): scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) - decode_state.temperatures[slot] = prefix.temperature + + if self.env.sampling_algorithm == "": + sampler_config = decode_state.sampler_config.at[slot].set( + prefix.sampler_config + ) + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( tokens, caches, @@ -595,7 +761,7 @@ def insert(cache, scaler, new_entry): start, input_pos, mask, - decode_state.temperatures, + sampler_config, ) def _insert_page_attention( @@ -631,7 +797,19 @@ def _insert_page_attention( input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) scales = None lens = decode_state.lens.at[slot].set(1) - decode_state.temperatures[slot] = prefix.temperature + + if self.env.sampling_algorithm == "": + sampler_config = decode_state.sampler_config.at[slot].set( + prefix.sampler_config + ) + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( tokens, caches, @@ -641,7 +819,7 @@ def _insert_page_attention( start, input_pos, mask, - decode_state.temperatures, + sampler_config, ) def insert( @@ -774,7 +952,6 @@ def generate_impl( self, params: Any, decode_state: DecodeState, - sampler=None, page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] @@ -786,12 +963,16 @@ def generate_impl( else: input_indexes = decode_state.input_pos - ragged_batch_index, ragged_block_index = ( - self.precompute_ragged_block_indices(decode_state) - ) - ragged_batch_index, ragged_block_index = ragged_batch_index.reshape( - (-1) - ), ragged_block_index.reshape((-1)) + # TODO(lancewang): Remove ragged index precomputation + # ragged_batch_index, ragged_block_index = ( + # self.precompute_ragged_block_indices(decode_state) + # ) + # ragged_batch_index, ragged_block_index = ragged_batch_index.reshape( + # (-1) + # ), ragged_block_index.reshape((-1)) + + ragged_batch_index = 0 + ragged_block_index = 0 def update_mask(): if self.env.ring_buffer: @@ -821,12 +1002,21 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - if sampler and callable(sampler): - next_token = sampler(logits[:, -1]) + if self.env.sampling_algorithm == "": + sampling = self._custom_sampling + rng = self.splited_rngs else: - next_token = self._sampling( - logits, self.env.batch_size, decode_state.temperatures - ) + sampling = self._sampling + rng = self.rng + + next_token = sampling( + logits=logits, + algorithm=decode_state.sampler_config[:, 0:1].reshape(-1), + rng=rng, + temperature=decode_state.sampler_config[:, 1:2].reshape(-1), + topk=decode_state.sampler_config[:, 2:3].reshape(-1), + nucleus_topp=decode_state.sampler_config[:, 3:4].reshape(-1), + ) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -869,6 +1059,7 @@ def update_mask(): decode_state.start, input_pos, mask, + decode_state.sampler_config, ) return new_decode_state, result_tokens diff --git a/tests/helpers.py b/tests/helpers.py index ac0ea5f..860389d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,9 @@ # pylint: disable-next=all -def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): +def make_env_tiny( + bf16_enable=True, env_data_update_fn=lambda _: None, batch_size=1 +): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) @@ -19,7 +21,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): environment_data.cache_sequence_length = 128 environment_data.bf16_enable = bf16_enable environment_data.model_type = "llama-2-tiny" - environment_data.batch_size = 1 + environment_data.batch_size = batch_size environment_data.num_layers = config.n_layers environment_data.cache_shape = ( 1, diff --git a/tests/test_engine.py b/tests/test_engine.py index 57245c0..2fe94ba 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -20,16 +20,61 @@ from jetstream_pt.third_party.llama import model_exportable from jetstream_pt.engine import PyTorchEngine +from jetstream_pt.engine import DecodeState +from jetstream_pt.engine import Prefix from tests import helpers +from jetstream_pt import cache_manager + + +class MockEngine(PyTorchEngine): + + def _call_model_prefill(self, weights, tokens, input_indexes): + caches = [ + cache_manager.KVCachePrefill( + self.env.quant_config.enable_kv_quantization + ) + for _ in self.pt_model.layers + ] + # logits = jnp.ones((self.env.batch_size, 1), jnp.float32) + assert ( + self.env.batch_size == 1 + ), f"The batch size {self.env.batch_size} != 1" + logits = jnp.array([[0.5, 0.6, 0.7, 0.8]]) + return logits, caches + + def _call_model_generate( + self, + weights, + tokens, + input_indexes, + caches, + cache_scales, + mask, + start, + input_pos, + ragged_batch_index, + ragged_block_index, + page_token_indices, + ): + logits = jnp.array( + [ + [[0.5, 0.6, 0.7, 0.8]], + # [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], + [[0.4, 0.3, 0.2, 0.1]], + ] + ) + return logits, caches, cache_scales class EngineTest(unittest.TestCase): - def setup(self): - env, model_arg = helpers.make_env_tiny(bf16_enable=True) + def setup(self, batch_size=1): + env, model_arg = helpers.make_env_tiny( + bf16_enable=True, batch_size=batch_size + ) model_ours = model_exportable.Transformer(model_arg, env) - engine = PyTorchEngine(pt_model=model_ours, env=env) - engine.rng = jax.random.PRNGKey(0) + engine = MockEngine(pt_model=model_ours, env=env) + engine.rng = jax.random.key(0) return engine def test_sampling_2D(self): @@ -37,14 +82,23 @@ def test_sampling_2D(self): engine = self.setup() self.assertEqual(engine.env.sampling_algorithm, "greedy") logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]]) - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, "greedy", engine.rng, temperature=1.0, topk=1, nucleus_topp=0.0 + ) self.assertEqual(token, jnp.array([[0]])) self.assertTrue(jnp.isdtype(token, jnp.int32)) # test weighted engine.env.sampling_algorithm = "weighted" engine.env.temperature = 5.0 - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + temperature=5.0, + topk=1, + nucleus_topp=0.0, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) @@ -52,21 +106,36 @@ def test_sampling_2D(self): engine.env.sampling_algorithm = "topk" engine.env.temperature = 5.0 engine.env.topk = 4 - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + temperature=5.0, + topk=4, + nucleus_topp=0.0, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) # test nucleus engine.env.sampling_algorithm = "nucleus" - engine.env.temperature = 0.0 + engine.env.temperature = 1.0 engine.env.nucleus_topp = 0.8 - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + temperature=0.0, + topk=1, + nucleus_topp=0.8, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) def test_sampling_3D(self): # test greedy - engine = self.setup() + engine = self.setup(batch_size=2) + self.assertEqual(engine.env.sampling_algorithm, "greedy") logits = jnp.array( [ @@ -74,14 +143,28 @@ def test_sampling_3D(self): [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], ] ) - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) # test weighted engine.env.sampling_algorithm = "weighted" engine.env.temperature = 10.0 - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) @@ -89,7 +172,14 @@ def test_sampling_3D(self): engine.env.sampling_algorithm = "topk" engine.env.temperature = 1.0 engine.env.topk = 3 - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) @@ -97,10 +187,351 @@ def test_sampling_3D(self): engine.env.sampling_algorithm = "nucleus" engine.env.temperature = 1.0 engine.env.nucleus_topp = 0.8 - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) + def test_custom_sampling_3D(self): + engine = self.setup(batch_size=2) + engine.rng = jax.random.key(3) + engine.splited_rngs = jax.random.split( + engine.rng, num=engine.env.batch_size + ) + engine.env.sampling_algorithm = "" + + # Need a different engine of batch size of 1 to reshape the output + engine_b1 = self.setup() + engine_b1.rng = jax.random.key(3) + logits = jnp.array( + [ + [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]], + [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], + ] + ) + + # test greedy + token = engine._custom_sampling( + logits, + jnp.array([0, 0]), + engine.splited_rngs, + temperature=jnp.array([[0.0], [0.0]]), + topk=jnp.array([[0], [0]]), + nucleus_topp=jnp.array([[0.0], [0.0]]), + ) + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "greedy", + engine.splited_rngs[i], + temperature=1.0, + topk=0, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test weighted + engine.env.sampling_algorithm = "weighted" + token = engine._custom_sampling( + logits, + jnp.array([1, 1]), + engine.splited_rngs, + temperature=jnp.array([1, 1]), + topk=jnp.array([0, 0]), + nucleus_topp=jnp.array([0.0, 0.0]), + ) + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "weighted", + engine.splited_rngs[i], + temperature=1, + topk=0, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test topk + engine.env.sampling_algorithm = "topk" + token = engine._custom_sampling( + logits, + jnp.array([3, 3]), + engine.splited_rngs, + temperature=jnp.array([[1.0], [1.0]]), + topk=jnp.array([[3], [3]]), + nucleus_topp=jnp.array([[0.0], [0.0]]), + ) + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "topk", + engine.splited_rngs[i], + temperature=1.0, + topk=3, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test nucleus + engine.env.sampling_algorithm = "nucleus" + token = engine._custom_sampling( + logits, + jnp.array([2, 2]), + engine.splited_rngs, + temperature=jnp.array([[1.0], [1.0]]), + topk=jnp.array([[0], [0]]), + nucleus_topp=jnp.array([[0.8], [0.8]]), + ) + + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "nucleus", + engine.splited_rngs[i], + temperature=1.0, + topk=0, + nucleus_topp=0.8, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test greedy + topk + token = engine._custom_sampling( + logits, + jnp.array([0, 3]), + engine.splited_rngs, + temperature=jnp.array([[0.0], [1.0]]), + topk=jnp.array([[0], [3]]), + nucleus_topp=jnp.array([[0.0], [0.0]]), + ) + original_tokens = [] + + i = 0 + original_token = engine_b1._sampling( + logits[i], + "greedy", + engine.splited_rngs[i], + temperature=0.0, + topk=0, + nucleus_topp=0.8, + ) + original_tokens.append(original_token) + + i = 1 + original_token = engine_b1._sampling( + logits[i], + "topk", + engine.splited_rngs[i], + temperature=1.0, + topk=3, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + def test_prefill_with_custom_sampling(self): + engine = self.setup() + engine.env.sampling_algorithm = "" + + # Inputs doesn't matter + params = jnp.zeros((1,), jnp.float32) + padded_tokens = jnp.zeros((1,), jnp.float32) + true_length = 1 + + # Greedy + # algorithm, temperature, topk, nucleus_topp + sampler = [0, 1.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Greedy output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + print( + f"prefix sampler config {prefix.sampler_config} vs sampler {jnp.array(sampler)}" + ) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Weighted + sampler = [1, 10.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Weighted output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Nucleus + sampler = [2, 1.0, 3, 0.0] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Topk output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Topk + sampler = [3, 1.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Nucleus output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + def test_insert_no_wrap_with_custom_sampling(self): + engine = self.setup() + engine.env.sampling_algorithm = "" + engine.env.batch_size = 2 + cache_shape = engine.env.cache_shape + + sampler_config_raw = [0, 1.0, 3, 0.8] + sampler_config = jnp.array(sampler_config_raw) + + prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3]) + prefill_cache = [] + for _ in range(engine.env.num_layers): + prefill_cache.append( + ( + jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), + jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), + ) + ) + + prefix = Prefix( + token=jnp.ones((1)), + caches=prefill_cache, + seq_len=16, + sampler_config=sampler_config, + ) + + doesnt_matter = jnp.array([0]) + kv_cache = engine.env.make_caches_generate() + kv_cache = [c.state() for c in kv_cache] + + decode_state = DecodeState( + tokens=jnp.zeros((engine.env.batch_size, 1)), + caches=kv_cache, + cache_scales=[doesnt_matter], + current_position=16, + lens=jnp.zeros((engine.env.batch_size, 1)), + start=jnp.zeros((engine.env.batch_size, 1)), + input_pos=jnp.zeros((engine.env.batch_size,)), + mask=jnp.zeros((engine.env.batch_size, 128)), + sampler_config=jnp.zeros((engine.env.batch_size, 4)), + ) + + # Insert to slot 1 + result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1) + + self.assertAlmostEqual( + result_decode_state.tokens.all(), decode_state.tokens.all() + ) + self.assertAlmostEqual( + result_decode_state.sampler_config.all(), + jnp.array([[0, 0, 0, 0], sampler_config_raw]).all(), + ) + + def test_decode_with_custom_sampling(self): + engine = self.setup(batch_size=2) + engine.rng = jax.random.key(3) + engine.splited_rngs = jax.random.split( + engine.rng, num=engine.env.batch_size + ) + engine.env.sampling_algorithm = "" + + # Inputs doesn't matter + doesnt_matter = jnp.array([0]) + params = doesnt_matter + + decode_state = DecodeState( + tokens=jnp.zeros((engine.env.batch_size, 1)), + caches=[doesnt_matter], + cache_scales=[doesnt_matter], + current_position=0, + lens=jnp.zeros((engine.env.batch_size, 1)), + start=doesnt_matter, + input_pos=jnp.zeros((engine.env.batch_size,)), + mask=jnp.zeros((engine.env.batch_size, 1)), + sampler_config=jnp.array([[0, 0.0, 0, 0.0], [3, 1.0, 3, 0.0]]), + ) + + # Topk + Weighted + # algorithm, temperature, topk, nucleus_topp + decode_state, _ = engine.generate_impl( + params=params, decode_state=decode_state + ) + token = decode_state.tokens + print(f"Greedy output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + # def test_insert(self): # seqlen = 32 From 240ea4774d23ac822f05726a37c4409fe126336c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 21 Oct 2024 20:50:11 +0000 Subject: [PATCH 4/9] Fix rng shape. Fix typo. --- jetstream_pt/engine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 7c77aff..a2103ec 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -103,9 +103,7 @@ def __init__( self.pt_model = pt_model self.env = env self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 - self.rng = jax.random.key(0).reshape( - 1, - ) + self.rng = jax.random.key(0) # For sampling self.splited_rngs = jax.random.split(self.rng, num=self.env.batch_size) self.weights = weights @@ -467,6 +465,9 @@ def prefill( # Prefill only handle batch size of 1, therefore no need to use splitted rngs sampling = self._custom_sampling + rng = self.rng.reshape( + 1, + ) else: algorithm, temperature, topk, nucleus_topp = ( self.env.sampling_algorithm, @@ -475,11 +476,12 @@ def prefill( self.env.nucleus_topp, ) sampling = self._sampling + rng = self.rng token = sampling( logits=logits, algorithm=algorithm, - rng=self.rng, + rng=rng, temperature=temperature, topk=topk, nucleus_topp=nucleus_topp, From b16a66f0ccfba0710f0a8a15549bc7df9054138d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 23 Oct 2024 04:40:20 +0000 Subject: [PATCH 5/9] Take sampling callable. --- jetstream_pt/engine.py | 382 +++++++++++++++++++++-------------------- tests/test_engine.py | 193 +++++++++------------ 2 files changed, 270 insertions(+), 305 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index a2103ec..86977f4 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -44,6 +44,7 @@ from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model from absl import flags +from collections.abc import Callable FLAGS = flags.FLAGS @@ -53,7 +54,29 @@ Params = jax.Array PrefillInputs = jax.Array -NEG_INF = -1.0e7 # Sampling masking + +class DefaultSampler: + def __init__(self, rng, temperature, topk, nucleus_topp, algorithm): + self.rng = rng + self.temperature = temperature + self.topk = topk + self.nucleus_topp = nucleus_topp + self.algorithm = algorithm + + def __call__(self, logits): + return PyTorchEngine._sampling(logits, self.algorithm, self.rng, self.temperature, self.topk, self.nucleus_topp) + + # Define how to flatten the instance into leaves and auxiliary data + def tree_flatten(self): + children = (self.rng, self.temperature, self.topk, self.nucleus_topp, self.algorithm) # Leaves to be flattened + aux_data = () # Auxiliary data that does not need to be traced + return children, aux_data + + # Define how to unflatten from leaves and auxiliary data + @classmethod + def tree_unflatten(cls, aux_data, children): + rng, temperature, topk, nucleus_topp, algorithm = children + return cls(rng, temperature, topk, nucleus_topp, algorithm) @struct.dataclass @@ -62,7 +85,7 @@ class Prefix: token: jax.Array # [1, seqlen] caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad - sampler_config: jax.Array # Sampler or sampling config, [] + sampler: List[Any] | int # User defined Sampler @struct.dataclass @@ -84,7 +107,8 @@ class DecodeState: # If sampling_algorithm set to empty, the shape is [batch_size, 1] # Otherwise it's a list of intergers # The last dimension contains [algorithm, temperature, topk, nucleus] - sampler_config: jax.Array | List[int] + # sampler_config: jax.Array | List[int] + samplers: jax.Array # NOTE model specific @@ -104,8 +128,7 @@ def __init__( self.env = env self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 self.rng = jax.random.key(0) - # For sampling - self.splited_rngs = jax.random.split(self.rng, num=self.env.batch_size) + self.weights = weights self.y_sharding = env.sharding_by_axis(1) @@ -117,10 +140,10 @@ def __init__( jax.config.update("jax_enable_x64", False) self.prefill_cache_sharding = self.env.prefill_cache_sharding - self.prefill = jax.jit( - self.prefill, - out_shardings=(self.get_prefix_destination_sharding(), None), - ) + # self.prefill = jax.jit( + # self.prefill, + # out_shardings=(self.get_prefix_destination_sharding(), None), + # ) self.insert = jax.jit( self.insert, donate_argnums=(0, 1), @@ -183,16 +206,17 @@ def init_decode_state( if self.env.sampling_algorithm == "": # [algorithm, temperature, topk, nucleus] - sampler_config = [0, 0.0, 0, 0.0] - sampler_config = jnp.tile(sampler_config, (self.env.batch_size, 1)) - + # sampler_config = [0, 0.0, 0, 0.0] + # sampler_config = jnp.tile(sampler_config, (self.env.batch_size, 1)) + samplers = [] else: - sampler_config = [ - self.env.sampling_algorithm, - self.env.temperature, - self.env.topk, - self.env.nucleus_topp, - ] + # sampler_config = [ + # self.env.sampling_algorithm, + # self.env.temperature, + # self.env.topk, + # self.env.nucleus_topp, + # ] + samplers = DefaultSampler(self.rng, self.env.temperature, self.env.topk, self.env.nucleus_topp, self.env.algorithm) return DecodeState( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), @@ -307,119 +331,133 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) - def _greedy_sampling(self, logits): - return jnp.argmax(logits, axis=-1) - - def _weighted_sampling(self, logits, rng, temperature): - return jax.random.categorical(rng, logits / temperature) - - def _nucleus_sampling(self, logits, rng, temperature, nucleus_topp): - # return sampling_utils.sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) - """Restrict sampling to the top logits with cumulative probability >= - nucleus_topp. - - The nucleus sampling method is proposed in the paper `The Curious Case of - Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` - - """ - # if nucleus_topp < 0: - # raise ValueError( - # "Can't apply nucleus with parameter {nucleus_topp=} less zero" - # ) - logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending - sorted_cum_probs = jnp.cumsum( - jax.nn.softmax(logits_sorted, axis=-1), axis=-1 - ) # get cumsum probs - cutoff_index = jnp.sum( - sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True - ) # find cutoff index - cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) - logits = jnp.where( - logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits - ) - result = jax.random.categorical(rng, logits / temperature) - return result - - def _topk_sampling(self, logits, rng, temperature, topk): - # return sampling_utils.sample_topk_logits(logits, topk, temperature, rng) - """Restricting sampling to the best k logits.""" - # if topk <= 0: - # raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") - sorted_indices = jnp.argsort(logits)[::-1] # Sort in descending order - topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk - topk_idxs = jnp.where(topk_mask, sorted_indices, -1) - - topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) - - sampled_idx = jnp.expand_dims( - jax.random.categorical(rng, topk_logits / temperature).astype( - jnp.int32 - ), - axis=-1, - ) - sampled_tokens = jnp.squeeze( - jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 - ).astype(jnp.int32) - - return sampled_tokens + # def _greedy_sampling(self, logits): + # return jnp.argmax(logits, axis=-1) + + # def _weighted_sampling(self, logits, rng, temperature): + # return jax.random.categorical(rng, logits / temperature) + + # def _nucleus_sampling(self, logits, rng, temperature, nucleus_topp): + # # return sampling_utils.sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) + # """Restrict sampling to the top logits with cumulative probability >= + # nucleus_topp. + + # The nucleus sampling method is proposed in the paper `The Curious Case of + # Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + + # """ + # # if nucleus_topp < 0: + # # raise ValueError( + # # "Can't apply nucleus with parameter {nucleus_topp=} less zero" + # # ) + # logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending + # sorted_cum_probs = jnp.cumsum( + # jax.nn.softmax(logits_sorted, axis=-1), axis=-1 + # ) # get cumsum probs + # cutoff_index = jnp.sum( + # sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True + # ) # find cutoff index + # cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) + # logits = jnp.where( + # logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits + # ) + # result = jax.random.categorical(rng, logits / temperature) + # return result + + # def _topk_sampling(self, logits, rng, temperature, topk): + # # return sampling_utils.sample_topk_logits(logits, topk, temperature, rng) + # """Restricting sampling to the best k logits.""" + # # if topk <= 0: + # # raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") + # sorted_indices = jnp.argsort(logits)[::-1] # Sort in descending order + # topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk + # topk_idxs = jnp.where(topk_mask, sorted_indices, -1) + + # topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) + + # sampled_idx = jnp.expand_dims( + # jax.random.categorical(rng, topk_logits / temperature).astype( + # jnp.int32 + # ), + # axis=-1, + # ) + # sampled_tokens = jnp.squeeze( + # jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 + # ).astype(jnp.int32) + + # return sampled_tokens # Algorithm type: # 0: Greedy 1: Weighted 2: Nucleus 3: Top-k - def _apply_sampling( - self, logits, algorithm, rng, temperature, topk, nucleus_topp - ) -> jnp.ndarray: - return jax.lax.cond( - algorithm == 0, - lambda: self._greedy_sampling(logits), # Greedy - lambda: jax.lax.cond( - algorithm == 1, - lambda: self._weighted_sampling( - logits, rng, temperature - ), # Weighted - lambda: jax.lax.cond( - algorithm == 2, - lambda: self._nucleus_sampling( - logits, rng, temperature, nucleus_topp - ), # Nucleus sampling - lambda: self._topk_sampling( - logits, rng, temperature, topk - ), # Top-k sampling - ), - ), - ) + # def _apply_sampling( + # self, logits, algorithm, rng, temperature, topk, nucleus_topp + # ) -> jnp.ndarray: + # return jax.lax.cond( + # algorithm == 0, + # lambda: self._greedy_sampling(logits), # Greedy + # lambda: jax.lax.cond( + # algorithm == 1, + # lambda: self._weighted_sampling( + # logits, rng, temperature + # ), # Weighted + # lambda: jax.lax.cond( + # algorithm == 2, + # lambda: self._nucleus_sampling( + # logits, rng, temperature, nucleus_topp + # ), # Nucleus sampling + # lambda: self._topk_sampling( + # logits, rng, temperature, topk + # ), # Top-k sampling + # ), + # ), + # ) def _custom_sampling( - self, logits, algorithm, rng, temperature, topk, nucleus_topp + self, logits, samplers ) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) + logits = logits[:, -1] - apply_sampling_v = jax.vmap( - lambda _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp: self._apply_sampling( - _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp - ), - in_axes=(0, 0, 0, 0, 0, 0), - ) - apply_sampling_v = jax.jit(apply_sampling_v) - return apply_sampling_v( - logits, algorithm, rng, temperature, topk, nucleus_topp - ).reshape(self.env.batch_size, -1) + + # Prefill and Generate have different batch size + current_batch_size = logits.shape[0] + + idx = jnp.arange(current_batch_size) + apply_sampler = lambda i, l: jax.lax.switch(i, samplers, l) + apply_vmap = jax.vmap(apply_sampler, in_axes=(0, 0)) + return apply_vmap(idx, logits).reshape(current_batch_size, -1) + + # apply_sampling_v = jax.vmap( + # lambda _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp: self._apply_sampling( + # _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp + # ), + # in_axes=(0, 0, 0, 0, 0, 0), + # ) + # apply_sampling_v = jax.jit(apply_sampling_v) + # return apply_sampling_v( + # logits, algorithm, rng, temperature, topk, nucleus_topp + # ).reshape(self.env.batch_size, -1) def _sampling( self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp ) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) + + logits = logits[:, -1] + current_batch_size = logits.shape[0] + return ( sampling_utils.sampling( - logits[:, -1], - rng, - algorithm, - topk, - nucleus_topp, - temperature, + logits=logits, + rng=rng, + algorithm=algorithm, + topk=topk, + nucleus_topp=nucleus_topp, + temperature=temperature, ) - .reshape(self.env.batch_size, -1) + .reshape(current_batch_size, -1) .astype(jnp.int32) ) @@ -449,25 +487,30 @@ def prefill( if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words + prefill_batch_size = logits.shape[0] if self.env.sampling_algorithm == "": - assert ( - isinstance(sampler, List) - or isinstance(sampler, list) - or isinstance(sampler, Tuple) - or isinstance(sampler, tuple) - ), f"{type(sampler)} is not valid" - algorithm, temperature, topk, nucleus_topp = sampler - - algorithm = jnp.array([algorithm]) - temperature = jnp.array([temperature]) - topk = jnp.array([topk]) - nucleus_topp = jnp.array([nucleus_topp]) + # assert ( + # isinstance(sampler, List) + # or isinstance(sampler, list) + # or isinstance(sampler, Tuple) + # or isinstance(sampler, tuple) + # ), f"{type(sampler)} is not valid" + # algorithm, temperature, topk, nucleus_topp = sampler + + # algorithm = jnp.array([algorithm]) + # temperature = jnp.array([temperature]) + # topk = jnp.array([topk]) + # nucleus_topp = jnp.array([nucleus_topp]) # Prefill only handle batch size of 1, therefore no need to use splitted rngs - sampling = self._custom_sampling - rng = self.rng.reshape( - 1, - ) + token = self._custom_sampling(logits, [sampler]) + + # if len(logits.shape) == 2: + # logits = jnp.expand_dims(logits, 0) + + # logits = logits[:, -1] + # token = sampler(logits) + else: algorithm, temperature, topk, nucleus_topp = ( self.env.sampling_algorithm, @@ -475,19 +518,15 @@ def prefill( self.env.topk, self.env.nucleus_topp, ) - sampling = self._sampling - rng = self.rng - - token = sampling( - logits=logits, - algorithm=algorithm, - rng=rng, - temperature=temperature, - topk=topk, - nucleus_topp=nucleus_topp, - ) + sampling = lambda logits: self._sampling(logits, algorithm, jax.random.key(0), temperature, topk, nucleus_topp) + # No need to store the sampler if using default + sampler = 0 + + # token = sampling( + # logits=logits + # ).reshape(prefill_batch_size, 1) + token_out = token.reshape(prefill_batch_size, 1) - token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( [ token_out, # First token @@ -513,7 +552,7 @@ def prefill( # for k, v in updated_caches # ] return ( - Prefix(token, updated_caches, true_length, jnp.array(sampler)), + Prefix(token_out, updated_caches, true_length, sampler), result, ) @@ -636,16 +675,10 @@ def insert(cache, scaler, new_entry, update_index): lens = decode_state.lens.at[slot].set(1) if self.env.sampling_algorithm == "": - sampler_config = decode_state.sampler_config.at[slot].set( - prefix.sampler_config - ) - else: - sampler_config = [ - self.env.sampling_algorithm, - self.env.temperature, - self.env.topk, - self.env.nucleus_topp, - ] + # sampler_config = decode_state.sampler_config.at[slot].set( + # prefix.sampler_config + # ) + decode_state.samplers[slot] = prefix.sampler return DecodeState( tokens, @@ -656,7 +689,7 @@ def insert(cache, scaler, new_entry, update_index): start, input_pos, mask, - sampler_config, + decode_state.samplers, ) # pylint: disable-next=all @@ -743,16 +776,7 @@ def insert(cache, scaler, new_entry): lens = decode_state.lens.at[slot].set(1) if self.env.sampling_algorithm == "": - sampler_config = decode_state.sampler_config.at[slot].set( - prefix.sampler_config - ) - else: - sampler_config = [ - self.env.sampling_algorithm, - self.env.temperature, - self.env.topk, - self.env.nucleus_topp, - ] + decode_state.samplers[slot] = prefix.sampler return DecodeState( tokens, @@ -763,7 +787,7 @@ def insert(cache, scaler, new_entry): start, input_pos, mask, - sampler_config, + decode_state.samplers, ) def _insert_page_attention( @@ -801,16 +825,7 @@ def _insert_page_attention( lens = decode_state.lens.at[slot].set(1) if self.env.sampling_algorithm == "": - sampler_config = decode_state.sampler_config.at[slot].set( - prefix.sampler_config - ) - else: - sampler_config = [ - self.env.sampling_algorithm, - self.env.temperature, - self.env.topk, - self.env.nucleus_topp, - ] + decode_state.samplers[slot] = prefix.sampler return DecodeState( tokens, @@ -821,7 +836,7 @@ def _insert_page_attention( start, input_pos, mask, - sampler_config, + decode_state.samplers, ) def insert( @@ -1005,20 +1020,9 @@ def update_mask(): mask = update_mask() if self.env.sampling_algorithm == "": - sampling = self._custom_sampling - rng = self.splited_rngs + next_token = self._custom_sampling(logits, decode_state.samplers) else: - sampling = self._sampling - rng = self.rng - - next_token = sampling( - logits=logits, - algorithm=decode_state.sampler_config[:, 0:1].reshape(-1), - rng=rng, - temperature=decode_state.sampler_config[:, 1:2].reshape(-1), - topk=decode_state.sampler_config[:, 2:3].reshape(-1), - nucleus_topp=decode_state.sampler_config[:, 3:4].reshape(-1), - ) + next_token = self._sampling(logits, self.env.algorithm, jax.random.key(0), self.env.temperature, self.env.topk, self.env.nucleus_topp) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -1061,7 +1065,7 @@ def update_mask(): decode_state.start, input_pos, mask, - decode_state.sampler_config, + decode_state.samplers, ) return new_decode_state, result_tokens diff --git a/tests/test_engine.py b/tests/test_engine.py index 2fe94ba..944781a 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -24,7 +24,8 @@ from jetstream_pt.engine import Prefix from tests import helpers from jetstream_pt import cache_manager - +# from jetstream_pt.engine import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler +from jetstream.core.utils.sampling_util import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler class MockEngine(PyTorchEngine): @@ -198,17 +199,15 @@ def test_sampling_3D(self): self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) + def test_custom_sampling_3D(self): engine = self.setup(batch_size=2) - engine.rng = jax.random.key(3) - engine.splited_rngs = jax.random.split( - engine.rng, num=engine.env.batch_size - ) + rng = jax.random.key(3) + engine.env.sampling_algorithm = "" # Need a different engine of batch size of 1 to reshape the output - engine_b1 = self.setup() - engine_b1.rng = jax.random.key(3) + rng_b1 = jax.random.key(3) logits = jnp.array( [ [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]], @@ -217,21 +216,17 @@ def test_custom_sampling_3D(self): ) # test greedy - token = engine._custom_sampling( - logits, - jnp.array([0, 0]), - engine.splited_rngs, - temperature=jnp.array([[0.0], [0.0]]), - topk=jnp.array([[0], [0]]), - nucleus_topp=jnp.array([[0.0], [0.0]]), - ) + sampler = GreedySampler() + samplers = [sampler, sampler] + token = engine._custom_sampling( logits, samplers) + original_tokens = [] for i in range(2): - original_token = engine_b1._sampling( + original_token = engine._sampling( logits[i], "greedy", - engine.splited_rngs[i], - temperature=1.0, + rng=rng, + temperature=0.0, topk=0, nucleus_topp=0.0, ) @@ -244,21 +239,18 @@ def test_custom_sampling_3D(self): self.assertTrue(jnp.isdtype(token, jnp.int32)) # test weighted - engine.env.sampling_algorithm = "weighted" - token = engine._custom_sampling( - logits, - jnp.array([1, 1]), - engine.splited_rngs, - temperature=jnp.array([1, 1]), - topk=jnp.array([0, 0]), - nucleus_topp=jnp.array([0.0, 0.0]), - ) + sampler1 = WeightedSampler(rng=rng, temperature=1.0) + sampler2 = WeightedSampler(rng=rng, temperature=1.0) + samplers = [sampler1, sampler2] + token = engine._custom_sampling(logits, samplers) + original_tokens = [] for i in range(2): - original_token = engine_b1._sampling( + rng_b1, sub_rng = jax.random.split(rng_b1) + original_token = engine._sampling( logits[i], "weighted", - engine.splited_rngs[i], + rng, temperature=1, topk=0, nucleus_topp=0.0, @@ -268,25 +260,23 @@ def test_custom_sampling_3D(self): print(f"custom sampling token {token} vs original tokens {original_tokens}") self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - # test topk - engine.env.sampling_algorithm = "topk" - token = engine._custom_sampling( - logits, - jnp.array([3, 3]), - engine.splited_rngs, - temperature=jnp.array([[1.0], [1.0]]), - topk=jnp.array([[3], [3]]), - nucleus_topp=jnp.array([[0.0], [0.0]]), - ) + # # test topk + sampler1 = TopkSampler(rng=rng, temperature=1.0, topk=3) + sampler2 = TopkSampler(rng=rng, temperature=1.0, topk=3) + samplers = [sampler1, sampler2] + token = engine._custom_sampling(logits, samplers) + original_tokens = [] for i in range(2): - original_token = engine_b1._sampling( + # rng_b1, sub_rng = jax.random.split(rng_b1) + sub_rng = rng + original_token = engine._sampling( logits[i], "topk", - engine.splited_rngs[i], + rng=sub_rng, temperature=1.0, topk=3, nucleus_topp=0.0, @@ -300,22 +290,17 @@ def test_custom_sampling_3D(self): self.assertTrue(jnp.isdtype(token, jnp.int32)) # test nucleus - engine.env.sampling_algorithm = "nucleus" - token = engine._custom_sampling( - logits, - jnp.array([2, 2]), - engine.splited_rngs, - temperature=jnp.array([[1.0], [1.0]]), - topk=jnp.array([[0], [0]]), - nucleus_topp=jnp.array([[0.8], [0.8]]), - ) + sampler1 = NucleusSampler(rng=rng, temperature=1.0, nucleus_topp=0.8) + sampler2 = NucleusSampler(rng=rng, temperature=1.0, nucleus_topp=0.8) + samplers = [sampler1, sampler2] + token = engine._custom_sampling(logits, samplers) original_tokens = [] for i in range(2): - original_token = engine_b1._sampling( + original_token = engine._sampling( logits[i], "nucleus", - engine.splited_rngs[i], + rng, temperature=1.0, topk=0, nucleus_topp=0.8, @@ -324,38 +309,34 @@ def test_custom_sampling_3D(self): original_tokens = jnp.concatenate(original_tokens) print(f"custom sampling token {token} vs original tokens {original_tokens}") self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - # test greedy + topk - token = engine._custom_sampling( - logits, - jnp.array([0, 3]), - engine.splited_rngs, - temperature=jnp.array([[0.0], [1.0]]), - topk=jnp.array([[0], [3]]), - nucleus_topp=jnp.array([[0.0], [0.0]]), - ) - original_tokens = [] + # # test topk + greedy + sampler1 = TopkSampler(rng=rng, temperature=1.0, topk=3) + sampler2 = GreedySampler() + samplers = [sampler1, sampler2] + token = engine._custom_sampling(logits, samplers) + original_tokens = [] i = 0 - original_token = engine_b1._sampling( + original_token = engine._sampling( logits[i], - "greedy", - engine.splited_rngs[i], - temperature=0.0, - topk=0, + "topk", + rng, + temperature=1.0, + topk=3, nucleus_topp=0.8, ) original_tokens.append(original_token) i = 1 - original_token = engine_b1._sampling( + original_token = engine._sampling( logits[i], - "topk", - engine.splited_rngs[i], - temperature=1.0, - topk=3, + "greedy", + rng, + temperature=0.0, + topk=0, nucleus_topp=0.0, ) original_tokens.append(original_token) @@ -364,11 +345,14 @@ def test_custom_sampling_3D(self): print(f"custom sampling token {token} vs original tokens {original_tokens}") self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) + # test Prefill def test_prefill_with_custom_sampling(self): engine = self.setup() + engine.rng = jax.random.key(3) + engine.env.sampling_algorithm = "" # Inputs doesn't matter @@ -377,8 +361,7 @@ def test_prefill_with_custom_sampling(self): true_length = 1 # Greedy - # algorithm, temperature, topk, nucleus_topp - sampler = [0, 1.0, 3, 0.8] + sampler = GreedySampler() prefix, _ = engine.prefill( params=params, padded_tokens=padded_tokens, @@ -390,15 +373,8 @@ def test_prefill_with_custom_sampling(self): self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - print( - f"prefix sampler config {prefix.sampler_config} vs sampler {jnp.array(sampler)}" - ) - self.assertAlmostEqual( - prefix.sampler_config.all(), jnp.array(sampler).all() - ) - # Weighted - sampler = [1, 10.0, 3, 0.8] + sampler = WeightedSampler(rng=engine.rng, temperature=1.0) prefix, _ = engine.prefill( params=params, padded_tokens=padded_tokens, @@ -407,14 +383,11 @@ def test_prefill_with_custom_sampling(self): ) token = prefix.token print(f"Weighted output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) + self.assertTrue(jnp.array_equal(token, jnp.array([[2]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - self.assertAlmostEqual( - prefix.sampler_config.all(), jnp.array(sampler).all() - ) # Nucleus - sampler = [2, 1.0, 3, 0.0] + sampler = NucleusSampler(rng=engine.rng, temperature=1.0, nucleus_topp=0.8) prefix, _ = engine.prefill( params=params, padded_tokens=padded_tokens, @@ -422,15 +395,13 @@ def test_prefill_with_custom_sampling(self): sampler=sampler, ) token = prefix.token - print(f"Topk output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + print(f"Nucleus output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[2]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - self.assertAlmostEqual( - prefix.sampler_config.all(), jnp.array(sampler).all() - ) # Topk - sampler = [3, 1.0, 3, 0.8] + sampler = TopkSampler(rng=engine.rng, temperature=1.0, topk=3) + prefix, _ = engine.prefill( params=params, padded_tokens=padded_tokens, @@ -438,12 +409,10 @@ def test_prefill_with_custom_sampling(self): sampler=sampler, ) token = prefix.token - print(f"Nucleus output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + print(f"Topk output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - self.assertAlmostEqual( - prefix.sampler_config.all(), jnp.array(sampler).all() - ) + def test_insert_no_wrap_with_custom_sampling(self): engine = self.setup() @@ -451,9 +420,6 @@ def test_insert_no_wrap_with_custom_sampling(self): engine.env.batch_size = 2 cache_shape = engine.env.cache_shape - sampler_config_raw = [0, 1.0, 3, 0.8] - sampler_config = jnp.array(sampler_config_raw) - prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3]) prefill_cache = [] for _ in range(engine.env.num_layers): @@ -464,11 +430,12 @@ def test_insert_no_wrap_with_custom_sampling(self): ) ) + sampler = GreedySampler() prefix = Prefix( token=jnp.ones((1)), caches=prefill_cache, seq_len=16, - sampler_config=sampler_config, + sampler=sampler, ) doesnt_matter = jnp.array([0]) @@ -484,7 +451,7 @@ def test_insert_no_wrap_with_custom_sampling(self): start=jnp.zeros((engine.env.batch_size, 1)), input_pos=jnp.zeros((engine.env.batch_size,)), mask=jnp.zeros((engine.env.batch_size, 128)), - sampler_config=jnp.zeros((engine.env.batch_size, 4)), + samplers = [BaseSampler()] * engine.env.batch_size ) # Insert to slot 1 @@ -493,17 +460,11 @@ def test_insert_no_wrap_with_custom_sampling(self): self.assertAlmostEqual( result_decode_state.tokens.all(), decode_state.tokens.all() ) - self.assertAlmostEqual( - result_decode_state.sampler_config.all(), - jnp.array([[0, 0, 0, 0], sampler_config_raw]).all(), - ) + self.assertEqual(result_decode_state.samplers[1], prefix.sampler) - def test_decode_with_custom_sampling(self): + def test_generate_with_custom_sampling(self): engine = self.setup(batch_size=2) engine.rng = jax.random.key(3) - engine.splited_rngs = jax.random.split( - engine.rng, num=engine.env.batch_size - ) engine.env.sampling_algorithm = "" # Inputs doesn't matter @@ -519,7 +480,7 @@ def test_decode_with_custom_sampling(self): start=doesnt_matter, input_pos=jnp.zeros((engine.env.batch_size,)), mask=jnp.zeros((engine.env.batch_size, 1)), - sampler_config=jnp.array([[0, 0.0, 0, 0.0], [3, 1.0, 3, 0.0]]), + samplers = [GreedySampler(), WeightedSampler(rng=engine.rng, temperature=1.0)], ) # Topk + Weighted @@ -528,7 +489,7 @@ def test_decode_with_custom_sampling(self): params=params, decode_state=decode_state ) token = decode_state.tokens - print(f"Greedy output: {token}") + print(f"Topk + Weighted output: {token}") self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) From 82a4682bf0671a212b24d39d0ccf072cdac50172 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sun, 1 Dec 2024 16:32:54 +0000 Subject: [PATCH 6/9] To control sampling from request. --- jetstream_pt/engine.py | 266 +++--------- jetstream_pt/environment.py | 5 +- run_interactive.py | 21 +- scripts/custom_sampling_benchmark.py | 91 ++++ scripts/jax_experiments.py | 38 -- tests/test_engine.py | 622 ++++++++++++++------------- 6 files changed, 511 insertions(+), 532 deletions(-) create mode 100644 scripts/custom_sampling_benchmark.py diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 86977f4..1061ff5 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -28,8 +28,7 @@ import torch import numpy as np -from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils -from jetstream.engine import sampling_utils +from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils, sampling_utils import torch_xla2 from torch.utils import _pytree as pytree @@ -54,29 +53,31 @@ Params = jax.Array PrefillInputs = jax.Array +STRATEGY_MAP = {"greedy": 0, "weighted": 1, "top_p": 2, "top_k": 3} -class DefaultSampler: - def __init__(self, rng, temperature, topk, nucleus_topp, algorithm): - self.rng = rng - self.temperature = temperature - self.topk = topk - self.nucleus_topp = nucleus_topp - self.algorithm = algorithm - def __call__(self, logits): - return PyTorchEngine._sampling(logits, self.algorithm, self.rng, self.temperature, self.topk, self.nucleus_topp) +# class DefaultSampler(sampling_utils.BaseSampler): +# def __init__(self, rng, temperature, topk, nucleus_topp, algorithm): +# self.rng = rng +# self.temperature = temperature +# self.topk = topk +# self.nucleus_topp = nucleus_topp +# self.algorithm = algorithm - # Define how to flatten the instance into leaves and auxiliary data - def tree_flatten(self): - children = (self.rng, self.temperature, self.topk, self.nucleus_topp, self.algorithm) # Leaves to be flattened - aux_data = () # Auxiliary data that does not need to be traced - return children, aux_data +# def __call__(self, logits): +# return PyTorchEngine._sampling(logits, self.algorithm, self.rng, self.temperature, self.topk, self.nucleus_topp) - # Define how to unflatten from leaves and auxiliary data - @classmethod - def tree_unflatten(cls, aux_data, children): - rng, temperature, topk, nucleus_topp, algorithm = children - return cls(rng, temperature, topk, nucleus_topp, algorithm) +# # Define how to flatten the instance into leaves and auxiliary data +# def tree_flatten(self): +# children = (self.rng, self.temperature, self.topk, self.nucleus_topp, self.algorithm) # Leaves to be flattened +# aux_data = () # Auxiliary data that does not need to be traced +# return children, aux_data + +# # Define how to unflatten from leaves and auxiliary data +# @classmethod +# def tree_unflatten(cls, aux_data, children): +# rng, temperature, topk, nucleus_topp, algorithm = children +# return cls(rng, temperature, topk, nucleus_topp, algorithm) @struct.dataclass @@ -103,12 +104,8 @@ class DecodeState: jax.Array ) # [batch_size, 1] total (prefill + decode) length for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - # The sampling configuration. - # If sampling_algorithm set to empty, the shape is [batch_size, 1] - # Otherwise it's a list of intergers - # The last dimension contains [algorithm, temperature, topk, nucleus] - # sampler_config: jax.Array | List[int] - samplers: jax.Array + # The sampling function + samplers: Any # NOTE model specific @@ -140,10 +137,10 @@ def __init__( jax.config.update("jax_enable_x64", False) self.prefill_cache_sharding = self.env.prefill_cache_sharding - # self.prefill = jax.jit( - # self.prefill, - # out_shardings=(self.get_prefix_destination_sharding(), None), - # ) + self.prefill = jax.jit( + self.prefill, + out_shardings=(self.get_prefix_destination_sharding(), None), + ) self.insert = jax.jit( self.insert, donate_argnums=(0, 1), @@ -154,6 +151,7 @@ def __init__( donate_argnums=(1,), out_shardings=(self.get_decode_state_sharding(), None), ) + # self.generate = self.generate_impl if self.env.page_attention: max_pages_per_sequence = ( @@ -204,20 +202,6 @@ def init_decode_state( if self.env.quant_config.enable_kv_quantization: scalers = [c.scalers() for c in caches_obj] - if self.env.sampling_algorithm == "": - # [algorithm, temperature, topk, nucleus] - # sampler_config = [0, 0.0, 0, 0.0] - # sampler_config = jnp.tile(sampler_config, (self.env.batch_size, 1)) - samplers = [] - else: - # sampler_config = [ - # self.env.sampling_algorithm, - # self.env.temperature, - # self.env.topk, - # self.env.nucleus_topp, - # ] - samplers = DefaultSampler(self.rng, self.env.temperature, self.env.topk, self.env.nucleus_topp, self.env.algorithm) - return DecodeState( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, @@ -231,7 +215,7 @@ def init_decode_state( float("-inf"), dtype=self.default_dtype, ), # mask - sampler_config, + None, ) # pylint: disable-next=all @@ -331,113 +315,22 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) - # def _greedy_sampling(self, logits): - # return jnp.argmax(logits, axis=-1) - - # def _weighted_sampling(self, logits, rng, temperature): - # return jax.random.categorical(rng, logits / temperature) - - # def _nucleus_sampling(self, logits, rng, temperature, nucleus_topp): - # # return sampling_utils.sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) - # """Restrict sampling to the top logits with cumulative probability >= - # nucleus_topp. - - # The nucleus sampling method is proposed in the paper `The Curious Case of - # Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` - - # """ - # # if nucleus_topp < 0: - # # raise ValueError( - # # "Can't apply nucleus with parameter {nucleus_topp=} less zero" - # # ) - # logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending - # sorted_cum_probs = jnp.cumsum( - # jax.nn.softmax(logits_sorted, axis=-1), axis=-1 - # ) # get cumsum probs - # cutoff_index = jnp.sum( - # sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True - # ) # find cutoff index - # cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) - # logits = jnp.where( - # logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits - # ) - # result = jax.random.categorical(rng, logits / temperature) - # return result - - # def _topk_sampling(self, logits, rng, temperature, topk): - # # return sampling_utils.sample_topk_logits(logits, topk, temperature, rng) - # """Restricting sampling to the best k logits.""" - # # if topk <= 0: - # # raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") - # sorted_indices = jnp.argsort(logits)[::-1] # Sort in descending order - # topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk - # topk_idxs = jnp.where(topk_mask, sorted_indices, -1) - - # topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) - - # sampled_idx = jnp.expand_dims( - # jax.random.categorical(rng, topk_logits / temperature).astype( - # jnp.int32 - # ), - # axis=-1, - # ) - # sampled_tokens = jnp.squeeze( - # jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 - # ).astype(jnp.int32) - - # return sampled_tokens - - # Algorithm type: - # 0: Greedy 1: Weighted 2: Nucleus 3: Top-k - # def _apply_sampling( - # self, logits, algorithm, rng, temperature, topk, nucleus_topp - # ) -> jnp.ndarray: - # return jax.lax.cond( - # algorithm == 0, - # lambda: self._greedy_sampling(logits), # Greedy - # lambda: jax.lax.cond( - # algorithm == 1, - # lambda: self._weighted_sampling( - # logits, rng, temperature - # ), # Weighted - # lambda: jax.lax.cond( - # algorithm == 2, - # lambda: self._nucleus_sampling( - # logits, rng, temperature, nucleus_topp - # ), # Nucleus sampling - # lambda: self._topk_sampling( - # logits, rng, temperature, topk - # ), # Top-k sampling - # ), - # ), - # ) - - def _custom_sampling( - self, logits, samplers - ) -> jnp.ndarray: - if len(logits.shape) == 2: - logits = jnp.expand_dims(logits, 0) + # Temporarily disabled becuase handling per request sampling is not ready yet. + # @classmethod + # def _custom_sampling(self, logits, samplers) -> jnp.ndarray: + # if len(logits.shape) == 2: + # logits = jnp.expand_dims(logits, 0) - logits = logits[:, -1] + # logits = logits[:, -1] - # Prefill and Generate have different batch size - current_batch_size = logits.shape[0] + # # Prefill and Generate have different batch size + # current_batch_size = logits.shape[0] - idx = jnp.arange(current_batch_size) - apply_sampler = lambda i, l: jax.lax.switch(i, samplers, l) - apply_vmap = jax.vmap(apply_sampler, in_axes=(0, 0)) - return apply_vmap(idx, logits).reshape(current_batch_size, -1) + # idx = jnp.arange(current_batch_size) + # apply_sampler = lambda i, l: jax.lax.switch(i, samplers, l) + # apply_vmap = jax.vmap(apply_sampler, in_axes=(0, 0)) + # return apply_vmap(idx, logits).reshape(current_batch_size, -1) - # apply_sampling_v = jax.vmap( - # lambda _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp: self._apply_sampling( - # _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp - # ), - # in_axes=(0, 0, 0, 0, 0, 0), - # ) - # apply_sampling_v = jax.jit(apply_sampling_v) - # return apply_sampling_v( - # logits, algorithm, rng, temperature, topk, nucleus_topp - # ).reshape(self.env.batch_size, -1) def _sampling( self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp @@ -487,45 +380,22 @@ def prefill( if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - prefill_batch_size = logits.shape[0] - if self.env.sampling_algorithm == "": - # assert ( - # isinstance(sampler, List) - # or isinstance(sampler, list) - # or isinstance(sampler, Tuple) - # or isinstance(sampler, tuple) - # ), f"{type(sampler)} is not valid" - # algorithm, temperature, topk, nucleus_topp = sampler - - # algorithm = jnp.array([algorithm]) - # temperature = jnp.array([temperature]) - # topk = jnp.array([topk]) - # nucleus_topp = jnp.array([nucleus_topp]) - - # Prefill only handle batch size of 1, therefore no need to use splitted rngs - token = self._custom_sampling(logits, [sampler]) - - # if len(logits.shape) == 2: - # logits = jnp.expand_dims(logits, 0) - - # logits = logits[:, -1] - # token = sampler(logits) - + if sampler: + token = sampler(logits[true_length - 1]) else: - algorithm, temperature, topk, nucleus_topp = ( + token = sampling_utils.sampling( + logits[true_length - 1], + self.rng, self.env.sampling_algorithm, - self.env.temperature, self.env.topk, self.env.nucleus_topp, + self.env.temperature, ) - sampling = lambda logits: self._sampling(logits, algorithm, jax.random.key(0), temperature, topk, nucleus_topp) - # No need to store the sampler if using default - sampler = 0 - + token = jnp.reshape(token, (1,)) + token_out = jnp.reshape(token, (1, 1)) # token = sampling( # logits=logits # ).reshape(prefill_batch_size, 1) - token_out = token.reshape(prefill_batch_size, 1) data = jnp.concatenate( [ @@ -552,7 +422,7 @@ def prefill( # for k, v in updated_caches # ] return ( - Prefix(token_out, updated_caches, true_length, sampler), + Prefix(token, updated_caches, true_length, sampler), result, ) @@ -674,12 +544,7 @@ def insert(cache, scaler, new_entry, update_index): scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) - if self.env.sampling_algorithm == "": - # sampler_config = decode_state.sampler_config.at[slot].set( - # prefix.sampler_config - # ) - decode_state.samplers[slot] = prefix.sampler - + sampler = prefix.sampler if prefix.sampler else decode_state.samplers return DecodeState( tokens, caches, @@ -689,7 +554,7 @@ def insert(cache, scaler, new_entry, update_index): start, input_pos, mask, - decode_state.samplers, + sampler, ) # pylint: disable-next=all @@ -775,8 +640,7 @@ def insert(cache, scaler, new_entry): lens = decode_state.lens.at[slot].set(1) - if self.env.sampling_algorithm == "": - decode_state.samplers[slot] = prefix.sampler + sampler = prefix.sampler if prefix.sampler else decode_state.samplers return DecodeState( tokens, @@ -787,7 +651,7 @@ def insert(cache, scaler, new_entry): start, input_pos, mask, - decode_state.samplers, + sampler, ) def _insert_page_attention( @@ -824,9 +688,7 @@ def _insert_page_attention( scales = None lens = decode_state.lens.at[slot].set(1) - if self.env.sampling_algorithm == "": - decode_state.samplers[slot] = prefix.sampler - + sampler = prefix.sampler if prefix.sampler else decode_state.samplers return DecodeState( tokens, caches, @@ -836,7 +698,7 @@ def _insert_page_attention( start, input_pos, mask, - decode_state.samplers, + sampler, ) def insert( @@ -1019,10 +881,18 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - if self.env.sampling_algorithm == "": - next_token = self._custom_sampling(logits, decode_state.samplers) + # next_token = self._custom_sampling(logits, decode_state.samplers) + if decode_state.samplers: + next_token = decode_state.samplers(logits) else: - next_token = self._sampling(logits, self.env.algorithm, jax.random.key(0), self.env.temperature, self.env.topk, self.env.nucleus_topp) + next_token = self._sampling( + logits, + self.env.sampling_algorithm, + self.rng, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py index 4917705..0202711 100644 --- a/jetstream_pt/environment.py +++ b/jetstream_pt/environment.py @@ -52,7 +52,10 @@ class JetEngineEnvironmentData: batch_size: int = 32 # batch size is generate step batch size cache_sequence_length: int = 2048 # size of the cache. - quant_config: QuantizationConfig = QuantizationConfig() + # quant_config: QuantizationConfig = QuantizationConfig() + quant_config: QuantizationConfig = dataclasses.field( + default_factory=QuantizationConfig + ) model_type: str = "llama-2-13b" # this implies the model config diff --git a/run_interactive.py b/run_interactive.py index 8463658..2e00159 100644 --- a/run_interactive.py +++ b/run_interactive.py @@ -23,6 +23,7 @@ from absl import app from jetstream.engine import token_utils from jetstream_pt.config import FLAGS, create_engine_from_config_flags +from jetstream.engine import sampling_utils # pylint: disable-next=all @@ -30,6 +31,21 @@ def main(argv): engine = create_engine_from_config_flags() + rng = jax.random.key(1) + temperature = 1 + topk = 1 + topp = 0.2 + + sampler = jax.tree_util.Partial( + sampling_utils.jittable_sample_topk_logits, + rng=rng, + temperature=temperature, + topk=topk, + ) + # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_topp_logits, rng=rng, temperature=temperature, topp=topp) + # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_greedy_logits) + # sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_weighted_logits, rng=rng, temperature=temperature) + start = time.perf_counter() params = engine.load_params() print("Load params ", time.perf_counter() - start) @@ -77,7 +93,10 @@ def main(argv): jax.profiler.start_trace(profiling_output) prefill_result, _ = engine.prefill( - params=params, padded_tokens=tokens, true_length=true_length + params=params, + padded_tokens=tokens, + true_length=true_length, + sampler=sampler ) # pylint: disable-next=all decode_state = engine.insert(prefill_result, decode_state, slot=slot) diff --git a/scripts/custom_sampling_benchmark.py b/scripts/custom_sampling_benchmark.py new file mode 100644 index 0000000..34188cc --- /dev/null +++ b/scripts/custom_sampling_benchmark.py @@ -0,0 +1,91 @@ + +import time + +import jax +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec, Mesh + + +from jetstream.engine import sampling_utils +from jetstream_pt.engine import PyTorchEngine + +def sample_topk_logits(logits, topk, temperature, rng): + sorted_indices = jnp.argsort(logits, descending=True) # Sort in descending order + topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk + topk_idxs = jnp.where(topk_mask, sorted_indices, -1) + topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) + + # self.rng, sub_rng = jax.random.split(self.rng) + sub_rng = rng + sampled_idx = jnp.expand_dims( + jax.random.categorical(sub_rng, topk_logits / temperature).astype( + jnp.int32 + ), + axis=-1, + ) + print(f"topk_idxs {topk_idxs.shape} sampled_idx {sampled_idx.shape}") + sampled_tokens = jnp.squeeze( + jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 + ).astype(jnp.int32) + + return sampled_tokens + +def sample_weighted_logits(logits, topk, temperature, rng): + return jax.random.categorical(rng, logits / temperature) + +def replicate_array_with_sharding(array): + # Create a sharding with None for all dimensions (meaning replicated) + mesh = jax.sharding.Mesh(jax.devices(), ('data',)) # or your existing mesh + sharding = NamedSharding( + mesh, + PartitionSpec(None,) * len(array.shape) # None for each dimension + ) + return jax.device_put(array, sharding) + +def test_custom_sampling(): + """test custom sampling performance""" + batch_size = 96 + hidden_size = 8192 + rng = jax.random.key(0) + logits = jax.random.normal(rng, (batch_size,1, hidden_size), dtype=float) + logits = replicate_array_with_sharding(logits) + + + topk = jax.random.randint(rng, (batch_size,), dtype=int, minval=1, maxval=3) + temperature = jax.random.uniform(rng, (batch_size,), dtype=float, minval=0.0, maxval=1.0) + + samplers = [] + # sampler = sample_topk_logits + sampler = sample_weighted_logits + sampler = jax.jit(sampler) + print(f"logits {logits[:, -1].shape}, topk {topk[0]}, temperature {temperature[0]}") + sampler(logits[:, -1], topk[0], temperature[0], rng) + + for i in range(batch_size): + rng, sub_rng = jax.random.split(rng) + partial = jax.tree_util.Partial(sampler, topk=topk[i], temperature=temperature[i], rng=sub_rng) + samplers.append(partial) + + custom_sampling = PyTorchEngine._custom_sampling +# custom_sampling = jax.jit(custom_sampling) + custom_sampling(logits, samplers) + + start = time.perf_counter() + loops = 10 + # for i in range(loops): + result = custom_sampling(logits, samplers) + result.block_until_ready() + end = time.perf_counter() + duration = end - start + print(f"Custom sampling: total time {duration} for {loops} loops") + + start = time.perf_counter() + loops = 10 + # for i in range(loops): + result = sampler(logits, topk[0], temperature[0], rng) + result.block_until_ready() + end = time.perf_counter() + duration = end - start + print(f"Uniform sampling: total time {duration} for {loops} loops") + +test_custom_sampling() diff --git a/scripts/jax_experiments.py b/scripts/jax_experiments.py index 3d38094..f4c939b 100644 --- a/scripts/jax_experiments.py +++ b/scripts/jax_experiments.py @@ -24,44 +24,6 @@ import torch_xla2.extra -def test1(): - """test jit cache size""" - - @functools.partial(jax.jit, static_argnums=(2,)) - # pylint: disable-next=all - def f(x, i, issum): - if issum: - return x + i - - return x - i - - x = jnp.ones((10,)) - print(f(x, 0, True)) - print("cache", f._cache_size()) - print(f(x, 1, False)) - print("cache", f._cache_size()) - - # pylint: disable-next=all - class A: - - def __init__(self, a): - self.a = a - - def incr(self): - """increase by 1""" - self.a += 1 - - @jax.jit - def f2(x): - a = A(x) - a.incr() - return a.a - - print(f2(x)) - print(f2(x)) - print(f2(x)) - - # pylint: disable-next=all def test2(): """test insert cache""" diff --git a/tests/test_engine.py b/tests/test_engine.py index 944781a..9704ce1 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -25,7 +25,9 @@ from tests import helpers from jetstream_pt import cache_manager # from jetstream_pt.engine import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler -from jetstream.core.utils.sampling_util import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler +# from jetstream.engine.sampling_util import BaseSampler, GreedySampler, WeightedSampler, TopkSampler, NucleusSampler +from jetstream.engine.sampling_utils import jittable_sample_greedy_logits, jittable_sample_topp_logits, jittable_sample_topk_logits, jittable_sample_weighted_logits + class MockEngine(PyTorchEngine): @@ -199,299 +201,331 @@ def test_sampling_3D(self): self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) - - def test_custom_sampling_3D(self): - engine = self.setup(batch_size=2) - rng = jax.random.key(3) - - engine.env.sampling_algorithm = "" - - # Need a different engine of batch size of 1 to reshape the output - rng_b1 = jax.random.key(3) - logits = jnp.array( - [ - [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]], - [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], - ] - ) - - # test greedy - sampler = GreedySampler() - samplers = [sampler, sampler] - token = engine._custom_sampling( logits, samplers) - - original_tokens = [] - for i in range(2): - original_token = engine._sampling( - logits[i], - "greedy", - rng=rng, - temperature=0.0, - topk=0, - nucleus_topp=0.0, - ) - original_tokens.append(original_token) - original_tokens = jnp.concatenate(original_tokens) - - print(f"custom sampling token {token} vs original tokens {original_tokens}") - self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # test weighted - sampler1 = WeightedSampler(rng=rng, temperature=1.0) - sampler2 = WeightedSampler(rng=rng, temperature=1.0) - samplers = [sampler1, sampler2] - token = engine._custom_sampling(logits, samplers) - - original_tokens = [] - for i in range(2): - rng_b1, sub_rng = jax.random.split(rng_b1) - original_token = engine._sampling( - logits[i], - "weighted", - rng, - temperature=1, - topk=0, - nucleus_topp=0.0, - ) - original_tokens.append(original_token) - original_tokens = jnp.concatenate(original_tokens) - - print(f"custom sampling token {token} vs original tokens {original_tokens}") - self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # # test topk - sampler1 = TopkSampler(rng=rng, temperature=1.0, topk=3) - sampler2 = TopkSampler(rng=rng, temperature=1.0, topk=3) - samplers = [sampler1, sampler2] - token = engine._custom_sampling(logits, samplers) - - original_tokens = [] - for i in range(2): - # rng_b1, sub_rng = jax.random.split(rng_b1) - sub_rng = rng - original_token = engine._sampling( - logits[i], - "topk", - rng=sub_rng, - temperature=1.0, - topk=3, - nucleus_topp=0.0, - ) - original_tokens.append(original_token) - original_tokens = jnp.concatenate(original_tokens) - - print(f"custom sampling token {token} vs original tokens {original_tokens}") - self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # test nucleus - sampler1 = NucleusSampler(rng=rng, temperature=1.0, nucleus_topp=0.8) - sampler2 = NucleusSampler(rng=rng, temperature=1.0, nucleus_topp=0.8) - samplers = [sampler1, sampler2] - token = engine._custom_sampling(logits, samplers) - - original_tokens = [] - for i in range(2): - original_token = engine._sampling( - logits[i], - "nucleus", - rng, - temperature=1.0, - topk=0, - nucleus_topp=0.8, - ) - original_tokens.append(original_token) - original_tokens = jnp.concatenate(original_tokens) - print(f"custom sampling token {token} vs original tokens {original_tokens}") - self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # # test topk + greedy - sampler1 = TopkSampler(rng=rng, temperature=1.0, topk=3) - sampler2 = GreedySampler() - samplers = [sampler1, sampler2] - token = engine._custom_sampling(logits, samplers) - - original_tokens = [] - i = 0 - original_token = engine._sampling( - logits[i], - "topk", - rng, - temperature=1.0, - topk=3, - nucleus_topp=0.8, - ) - original_tokens.append(original_token) - - i = 1 - original_token = engine._sampling( - logits[i], - "greedy", - rng, - temperature=0.0, - topk=0, - nucleus_topp=0.0, - ) - original_tokens.append(original_token) - - original_tokens = jnp.concatenate(original_tokens) - - print(f"custom sampling token {token} vs original tokens {original_tokens}") - self.assertTrue(jnp.array_equal(token, original_tokens)) - self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # test Prefill - def test_prefill_with_custom_sampling(self): - engine = self.setup() - engine.rng = jax.random.key(3) - - engine.env.sampling_algorithm = "" - - # Inputs doesn't matter - params = jnp.zeros((1,), jnp.float32) - padded_tokens = jnp.zeros((1,), jnp.float32) - true_length = 1 - - # Greedy - sampler = GreedySampler() - prefix, _ = engine.prefill( - params=params, - padded_tokens=padded_tokens, - true_length=true_length, - sampler=sampler, - ) - token = prefix.token - print(f"Greedy output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # Weighted - sampler = WeightedSampler(rng=engine.rng, temperature=1.0) - prefix, _ = engine.prefill( - params=params, - padded_tokens=padded_tokens, - true_length=true_length, - sampler=sampler, - ) - token = prefix.token - print(f"Weighted output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[2]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # Nucleus - sampler = NucleusSampler(rng=engine.rng, temperature=1.0, nucleus_topp=0.8) - prefix, _ = engine.prefill( - params=params, - padded_tokens=padded_tokens, - true_length=true_length, - sampler=sampler, - ) - token = prefix.token - print(f"Nucleus output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[2]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - # Topk - sampler = TopkSampler(rng=engine.rng, temperature=1.0, topk=3) - - prefix, _ = engine.prefill( - params=params, - padded_tokens=padded_tokens, - true_length=true_length, - sampler=sampler, - ) - token = prefix.token - print(f"Topk output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[1]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) - - - def test_insert_no_wrap_with_custom_sampling(self): - engine = self.setup() - engine.env.sampling_algorithm = "" - engine.env.batch_size = 2 - cache_shape = engine.env.cache_shape - - prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3]) - prefill_cache = [] - for _ in range(engine.env.num_layers): - prefill_cache.append( - ( - jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), - jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), - ) - ) - - sampler = GreedySampler() - prefix = Prefix( - token=jnp.ones((1)), - caches=prefill_cache, - seq_len=16, - sampler=sampler, - ) - - doesnt_matter = jnp.array([0]) - kv_cache = engine.env.make_caches_generate() - kv_cache = [c.state() for c in kv_cache] - - decode_state = DecodeState( - tokens=jnp.zeros((engine.env.batch_size, 1)), - caches=kv_cache, - cache_scales=[doesnt_matter], - current_position=16, - lens=jnp.zeros((engine.env.batch_size, 1)), - start=jnp.zeros((engine.env.batch_size, 1)), - input_pos=jnp.zeros((engine.env.batch_size,)), - mask=jnp.zeros((engine.env.batch_size, 128)), - samplers = [BaseSampler()] * engine.env.batch_size - ) - - # Insert to slot 1 - result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1) - - self.assertAlmostEqual( - result_decode_state.tokens.all(), decode_state.tokens.all() - ) - self.assertEqual(result_decode_state.samplers[1], prefix.sampler) - - def test_generate_with_custom_sampling(self): - engine = self.setup(batch_size=2) - engine.rng = jax.random.key(3) - engine.env.sampling_algorithm = "" - - # Inputs doesn't matter - doesnt_matter = jnp.array([0]) - params = doesnt_matter - - decode_state = DecodeState( - tokens=jnp.zeros((engine.env.batch_size, 1)), - caches=[doesnt_matter], - cache_scales=[doesnt_matter], - current_position=0, - lens=jnp.zeros((engine.env.batch_size, 1)), - start=doesnt_matter, - input_pos=jnp.zeros((engine.env.batch_size,)), - mask=jnp.zeros((engine.env.batch_size, 1)), - samplers = [GreedySampler(), WeightedSampler(rng=engine.rng, temperature=1.0)], - ) - - # Topk + Weighted - # algorithm, temperature, topk, nucleus_topp - decode_state, _ = engine.generate_impl( - params=params, decode_state=decode_state - ) - token = decode_state.tokens - print(f"Topk + Weighted output: {token}") - self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) - self.assertTrue(jnp.isdtype(token, jnp.int32)) + # Temporarily disabled becuase handling per request sampling is not ready yet. +# def test_custom_sampling_3D(self): +# engine = self.setup(batch_size=2) +# rng = jax.random.key(3) + +# engine.env.sampling_algorithm = "" + +# # Need a different engine of batch size of 1 to reshape the output +# rng_b1 = jax.random.key(3) +# logits = jnp.array( +# [ +# [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]], +# [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], +# ] +# ) + +# # test greedy +# sampler = jittable_sample_greedy_logits +# samplers = [sampler, sampler] +# token = engine._custom_sampling(logits, samplers) + +# original_tokens = [] +# for i in range(2): +# original_token = engine._sampling( +# logits[i], +# "greedy", +# rng=rng, +# temperature=0.0, +# topk=0, +# nucleus_topp=0.0, +# ) +# original_tokens.append(original_token) +# original_tokens = jnp.concatenate(original_tokens) + +# print(f"custom sampling token {token} vs original tokens {original_tokens}") +# self.assertTrue(jnp.array_equal(token, original_tokens)) +# self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # test weighted +# sampler1 = jax.tree_util.Partial( +# jittable_sample_weighted_logits, rng=rng, temperature=1.0 +# ) +# sampler2 = jax.tree_util.Partial( +# jittable_sample_weighted_logits, rng=rng, temperature=1.0 +# ) +# samplers = [sampler1, sampler2] +# token = engine._custom_sampling(logits, samplers) + +# original_tokens = [] +# for i in range(2): +# rng_b1, sub_rng = jax.random.split(rng_b1) +# original_token = engine._sampling( +# logits[i], +# "weighted", +# rng, +# temperature=1, +# topk=0, +# nucleus_topp=0.0, +# ) +# original_tokens.append(original_token) +# original_tokens = jnp.concatenate(original_tokens) + +# print(f"custom sampling token {token} vs original tokens {original_tokens}") +# self.assertTrue(jnp.array_equal(token, original_tokens)) +# self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # # test topk +# sampler1 = jax.tree_util.Partial( +# jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3 +# ) +# sampler2 = jax.tree_util.Partial( +# jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3 +# ) +# samplers = [sampler1, sampler2] +# token = engine._custom_sampling(logits, samplers) + +# original_tokens = [] +# for i in range(2): +# # rng_b1, sub_rng = jax.random.split(rng_b1) +# sub_rng = rng +# original_token = engine._sampling( +# logits[i], +# "topk", +# rng=sub_rng, +# temperature=1.0, +# topk=3, +# nucleus_topp=0.0, +# ) +# original_tokens.append(original_token) +# original_tokens = jnp.concatenate(original_tokens) + +# print(f"custom sampling token {token} vs original tokens {original_tokens}") +# self.assertTrue(jnp.array_equal(token, original_tokens)) +# self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # test nucleus +# sampler1 = jax.tree_util.Partial( +# jittable_sample_topp_logits, rng=rng, temperature=1.0, topp=0.8 +# ) +# sampler2 = jax.tree_util.Partial( +# jittable_sample_topp_logits, rng=rng, temperature=1.0, topp=0.8 +# ) +# samplers = [sampler1, sampler2] +# token = engine._custom_sampling(logits, samplers) + +# original_tokens = [] +# for i in range(2): +# original_token = engine._sampling( +# logits[i], +# "nucleus", +# rng, +# temperature=1.0, +# topk=0, +# nucleus_topp=0.8, +# ) +# original_tokens.append(original_token) +# original_tokens = jnp.concatenate(original_tokens) +# print(f"custom sampling token {token} vs original tokens {original_tokens}") +# self.assertTrue(jnp.array_equal(token, original_tokens)) +# self.assertTrue(jnp.array_equal(token, jnp.array([[2], [2]]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # # test topk + greedy +# sampler1 = jax.tree_util.Partial( +# jittable_sample_topk_logits, rng=rng, temperature=1.0, topk=3 +# ) +# sampler2 = jax.tree_util.Partial(jittable_sample_greedy_logits) +# samplers = [sampler1, sampler2] +# token = engine._custom_sampling(logits, samplers) + +# original_tokens = [] +# i = 0 +# original_token = engine._sampling( +# logits[i], +# "topk", +# rng, +# temperature=1.0, +# topk=3, +# nucleus_topp=0.8, +# ) +# original_tokens.append(original_token) + +# i = 1 +# original_token = engine._sampling( +# logits[i], +# "greedy", +# rng, +# temperature=0.0, +# topk=0, +# nucleus_topp=0.0, +# ) +# original_tokens.append(original_token) + +# original_tokens = jnp.concatenate(original_tokens) + +# print(f"custom sampling token {token} vs original tokens {original_tokens}") +# self.assertTrue(jnp.array_equal(token, original_tokens)) +# self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # test Prefill +# def test_prefill_with_custom_sampling(self): +# engine = self.setup() +# engine.rng = jax.random.key(3) + +# engine.env.sampling_algorithm = "" + +# # Inputs doesn't matter +# params = jnp.zeros((1,), jnp.float32) +# padded_tokens = jnp.zeros((1,), jnp.float32) +# true_length = 1 + +# # Greedy +# sampler = jax.tree_util.Partial(jittable_sample_greedy_logits) +# prefix, _ = engine.prefill( +# params=params, +# padded_tokens=padded_tokens, +# true_length=true_length, +# sampler=sampler, +# ) +# token = prefix.token +# print(f"Greedy output: {token}") +# self.assertTrue(jnp.array_equal(token, jnp.array([3]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # Weighted +# sampler = jax.tree_util.Partial( +# jittable_sample_weighted_logits, rng=engine.rng, temperature=1.0 +# ) +# prefix, _ = engine.prefill( +# params=params, +# padded_tokens=padded_tokens, +# true_length=true_length, +# sampler=sampler, +# ) +# token = prefix.token +# print(f"Weighted output: {token}") +# self.assertTrue(jnp.array_equal(token, jnp.array([2]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # Nucleus +# sampler = jax.tree_util.Partial( +# jittable_sample_topp_logits, rng=engine.rng, temperature=1.0, topp=0.8 +# ) +# prefix, _ = engine.prefill( +# params=params, +# padded_tokens=padded_tokens, +# true_length=true_length, +# sampler=sampler, +# ) +# token = prefix.token +# print(f"Nucleus output: {token}") +# self.assertTrue(jnp.array_equal(token, jnp.array([2]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# # Topk +# sampler = jax.tree_util.Partial( +# jittable_sample_topk_logits, rng=engine.rng, temperature=1.0, topk=3 +# ) + +# prefix, _ = engine.prefill( +# params=params, +# padded_tokens=padded_tokens, +# true_length=true_length, +# sampler=sampler, +# ) +# token = prefix.token +# print(f"Topk output: {token}") +# self.assertTrue(jnp.array_equal(token, jnp.array([1]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) + +# def test_insert_no_wrap_with_custom_sampling(self): +# engine = self.setup() +# engine.env.sampling_algorithm = "" +# engine.env.batch_size = 2 +# cache_shape = engine.env.cache_shape + +# prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3]) +# prefill_cache = [] +# for _ in range(engine.env.num_layers): +# prefill_cache.append( +# ( +# jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), +# jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), +# ) +# ) + +# sampler = jittable_sample_greedy_logits +# prefix = Prefix( +# token=jnp.ones((1)), +# caches=prefill_cache, +# seq_len=16, +# sampler=sampler, +# ) + +# doesnt_matter = jnp.array([0]) +# kv_cache = engine.env.make_caches_generate() +# kv_cache = [c.state() for c in kv_cache] + +# base_sampler = jax.tree_util.Partial( +# engine._sampling, +# algorithm=engine.env.sampling_algorithm, +# rng=engine.rng, +# temperature=engine.env.temperature, +# topk=engine.env.topk, +# nucleus_topp=engine.env.nucleus_topp, +# ) +# decode_state = DecodeState( +# tokens=jnp.zeros((engine.env.batch_size, 1)), +# caches=kv_cache, +# cache_scales=[doesnt_matter], +# current_position=16, +# lens=jnp.zeros((engine.env.batch_size, 1)), +# start=jnp.zeros((engine.env.batch_size, 1)), +# input_pos=jnp.zeros((engine.env.batch_size,)), +# mask=jnp.zeros((engine.env.batch_size, 128)), +# # samplers = [base_sampler] * engine.env.batch_size +# samplers=None, +# ) + +# # Insert to slot 1 +# result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1) + +# self.assertAlmostEqual( +# result_decode_state.tokens.all(), decode_state.tokens.all() +# ) +# self.assertEqual(result_decode_state.samplers, prefix.sampler) + +# def test_generate_with_custom_sampling(self): +# engine = self.setup(batch_size=2) +# engine.rng = jax.random.key(3) +# engine.env.sampling_algorithm = "" + +# # Inputs doesn't matter +# doesnt_matter = jnp.array([0]) +# params = doesnt_matter + +# greedy_sampler = jax.tree_util.Partial(jittable_sample_greedy_logits) +# weighted_sampler = jax.tree_util.Partial( +# jittable_sample_weighted_logits, rng=engine.rng, temperature=1.0 +# ) +# decode_state = DecodeState( +# tokens=jnp.zeros((engine.env.batch_size, 1)), +# caches=[doesnt_matter], +# cache_scales=[doesnt_matter], +# current_position=0, +# lens=jnp.zeros((engine.env.batch_size, 1)), +# start=doesnt_matter, +# input_pos=jnp.zeros((engine.env.batch_size,)), +# mask=jnp.zeros((engine.env.batch_size, 1)), +# samplers=weighted_sampler, +# ) + +# # Topk + Weighted +# # algorithm, temperature, topk, nucleus_topp +# decode_state, _ = engine.generate_impl( +# params=params, decode_state=decode_state +# ) +# token = decode_state.tokens +# print(f"Topk + Weighted output: {token}") +# self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]]))) +# self.assertTrue(jnp.isdtype(token, jnp.int32)) # def test_insert(self): From eb9a8ff5d729ccae69533c694032b312874d8f7e Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sun, 1 Dec 2024 16:38:27 +0000 Subject: [PATCH 7/9] Clean up. --- jetstream_pt/engine.py | 31 +--------- scripts/custom_sampling_benchmark.py | 84 ++++++++++++++++------------ 2 files changed, 50 insertions(+), 65 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 1061ff5..44b9406 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -53,32 +53,6 @@ Params = jax.Array PrefillInputs = jax.Array -STRATEGY_MAP = {"greedy": 0, "weighted": 1, "top_p": 2, "top_k": 3} - - -# class DefaultSampler(sampling_utils.BaseSampler): -# def __init__(self, rng, temperature, topk, nucleus_topp, algorithm): -# self.rng = rng -# self.temperature = temperature -# self.topk = topk -# self.nucleus_topp = nucleus_topp -# self.algorithm = algorithm - -# def __call__(self, logits): -# return PyTorchEngine._sampling(logits, self.algorithm, self.rng, self.temperature, self.topk, self.nucleus_topp) - -# # Define how to flatten the instance into leaves and auxiliary data -# def tree_flatten(self): -# children = (self.rng, self.temperature, self.topk, self.nucleus_topp, self.algorithm) # Leaves to be flattened -# aux_data = () # Auxiliary data that does not need to be traced -# return children, aux_data - -# # Define how to unflatten from leaves and auxiliary data -# @classmethod -# def tree_unflatten(cls, aux_data, children): -# rng, temperature, topk, nucleus_topp, algorithm = children -# return cls(rng, temperature, topk, nucleus_topp, algorithm) - @struct.dataclass # pylint: disable-next=all @@ -393,10 +367,6 @@ def prefill( ) token = jnp.reshape(token, (1,)) token_out = jnp.reshape(token, (1, 1)) - # token = sampling( - # logits=logits - # ).reshape(prefill_batch_size, 1) - data = jnp.concatenate( [ token_out, # First token @@ -881,6 +851,7 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() + # Temporarily disabled becuase handling per request sampling is not ready yet. # next_token = self._custom_sampling(logits, decode_state.samplers) if decode_state.samplers: next_token = decode_state.samplers(logits) diff --git a/scripts/custom_sampling_benchmark.py b/scripts/custom_sampling_benchmark.py index 34188cc..04891ab 100644 --- a/scripts/custom_sampling_benchmark.py +++ b/scripts/custom_sampling_benchmark.py @@ -1,4 +1,3 @@ - import time import jax @@ -9,65 +8,79 @@ from jetstream.engine import sampling_utils from jetstream_pt.engine import PyTorchEngine + def sample_topk_logits(logits, topk, temperature, rng): - sorted_indices = jnp.argsort(logits, descending=True) # Sort in descending order - topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk - topk_idxs = jnp.where(topk_mask, sorted_indices, -1) - topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) - - # self.rng, sub_rng = jax.random.split(self.rng) - sub_rng = rng - sampled_idx = jnp.expand_dims( - jax.random.categorical(sub_rng, topk_logits / temperature).astype( - jnp.int32 - ), - axis=-1, - ) - print(f"topk_idxs {topk_idxs.shape} sampled_idx {sampled_idx.shape}") - sampled_tokens = jnp.squeeze( - jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 - ).astype(jnp.int32) + sorted_indices = jnp.argsort( + logits, descending=True + ) # Sort in descending order + topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk + topk_idxs = jnp.where(topk_mask, sorted_indices, -1) + topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) + + # self.rng, sub_rng = jax.random.split(self.rng) + sub_rng = rng + sampled_idx = jnp.expand_dims( + jax.random.categorical(sub_rng, topk_logits / temperature).astype( + jnp.int32 + ), + axis=-1, + ) + print(f"topk_idxs {topk_idxs.shape} sampled_idx {sampled_idx.shape}") + sampled_tokens = jnp.squeeze( + jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 + ).astype(jnp.int32) + + return sampled_tokens - return sampled_tokens def sample_weighted_logits(logits, topk, temperature, rng): return jax.random.categorical(rng, logits / temperature) - + + def replicate_array_with_sharding(array): - # Create a sharding with None for all dimensions (meaning replicated) - mesh = jax.sharding.Mesh(jax.devices(), ('data',)) # or your existing mesh - sharding = NamedSharding( - mesh, - PartitionSpec(None,) * len(array.shape) # None for each dimension - ) - return jax.device_put(array, sharding) + # Create a sharding with None for all dimensions (meaning replicated) + mesh = jax.sharding.Mesh(jax.devices(), ("data",)) # or your existing mesh + sharding = NamedSharding( + mesh, + PartitionSpec( + None, + ) + * len(array.shape), # None for each dimension + ) + return jax.device_put(array, sharding) + def test_custom_sampling(): """test custom sampling performance""" batch_size = 96 hidden_size = 8192 rng = jax.random.key(0) - logits = jax.random.normal(rng, (batch_size,1, hidden_size), dtype=float) + logits = jax.random.normal(rng, (batch_size, 1, hidden_size), dtype=float) logits = replicate_array_with_sharding(logits) - topk = jax.random.randint(rng, (batch_size,), dtype=int, minval=1, maxval=3) - temperature = jax.random.uniform(rng, (batch_size,), dtype=float, minval=0.0, maxval=1.0) - + temperature = jax.random.uniform( + rng, (batch_size,), dtype=float, minval=0.0, maxval=1.0 + ) + samplers = [] # sampler = sample_topk_logits sampler = sample_weighted_logits sampler = jax.jit(sampler) - print(f"logits {logits[:, -1].shape}, topk {topk[0]}, temperature {temperature[0]}") + print( + f"logits {logits[:, -1].shape}, topk {topk[0]}, temperature {temperature[0]}" + ) sampler(logits[:, -1], topk[0], temperature[0], rng) for i in range(batch_size): rng, sub_rng = jax.random.split(rng) - partial = jax.tree_util.Partial(sampler, topk=topk[i], temperature=temperature[i], rng=sub_rng) + partial = jax.tree_util.Partial( + sampler, topk=topk[i], temperature=temperature[i], rng=sub_rng + ) samplers.append(partial) custom_sampling = PyTorchEngine._custom_sampling -# custom_sampling = jax.jit(custom_sampling) + # custom_sampling = jax.jit(custom_sampling) custom_sampling(logits, samplers) start = time.perf_counter() @@ -78,7 +91,7 @@ def test_custom_sampling(): end = time.perf_counter() duration = end - start print(f"Custom sampling: total time {duration} for {loops} loops") - + start = time.perf_counter() loops = 10 # for i in range(loops): @@ -88,4 +101,5 @@ def test_custom_sampling(): duration = end - start print(f"Uniform sampling: total time {duration} for {loops} loops") + test_custom_sampling() From cc206723f4dde3ce93710be30e1864b04710a580 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sun, 1 Dec 2024 16:40:32 +0000 Subject: [PATCH 8/9] Clean up more. --- scripts/jax_experiments.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/scripts/jax_experiments.py b/scripts/jax_experiments.py index f4c939b..3d38094 100644 --- a/scripts/jax_experiments.py +++ b/scripts/jax_experiments.py @@ -24,6 +24,44 @@ import torch_xla2.extra +def test1(): + """test jit cache size""" + + @functools.partial(jax.jit, static_argnums=(2,)) + # pylint: disable-next=all + def f(x, i, issum): + if issum: + return x + i + + return x - i + + x = jnp.ones((10,)) + print(f(x, 0, True)) + print("cache", f._cache_size()) + print(f(x, 1, False)) + print("cache", f._cache_size()) + + # pylint: disable-next=all + class A: + + def __init__(self, a): + self.a = a + + def incr(self): + """increase by 1""" + self.a += 1 + + @jax.jit + def f2(x): + a = A(x) + a.incr() + return a.a + + print(f2(x)) + print(f2(x)) + print(f2(x)) + + # pylint: disable-next=all def test2(): """test insert cache""" From e001d861d063b6e13bf7c1fefc0547d74d43a255 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Sun, 1 Dec 2024 16:42:52 +0000 Subject: [PATCH 9/9] Remove custom sampling benchmark from PR. --- scripts/custom_sampling_benchmark.py | 105 --------------------------- 1 file changed, 105 deletions(-) delete mode 100644 scripts/custom_sampling_benchmark.py diff --git a/scripts/custom_sampling_benchmark.py b/scripts/custom_sampling_benchmark.py deleted file mode 100644 index 04891ab..0000000 --- a/scripts/custom_sampling_benchmark.py +++ /dev/null @@ -1,105 +0,0 @@ -import time - -import jax -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec, Mesh - - -from jetstream.engine import sampling_utils -from jetstream_pt.engine import PyTorchEngine - - -def sample_topk_logits(logits, topk, temperature, rng): - sorted_indices = jnp.argsort( - logits, descending=True - ) # Sort in descending order - topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk - topk_idxs = jnp.where(topk_mask, sorted_indices, -1) - topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) - - # self.rng, sub_rng = jax.random.split(self.rng) - sub_rng = rng - sampled_idx = jnp.expand_dims( - jax.random.categorical(sub_rng, topk_logits / temperature).astype( - jnp.int32 - ), - axis=-1, - ) - print(f"topk_idxs {topk_idxs.shape} sampled_idx {sampled_idx.shape}") - sampled_tokens = jnp.squeeze( - jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 - ).astype(jnp.int32) - - return sampled_tokens - - -def sample_weighted_logits(logits, topk, temperature, rng): - return jax.random.categorical(rng, logits / temperature) - - -def replicate_array_with_sharding(array): - # Create a sharding with None for all dimensions (meaning replicated) - mesh = jax.sharding.Mesh(jax.devices(), ("data",)) # or your existing mesh - sharding = NamedSharding( - mesh, - PartitionSpec( - None, - ) - * len(array.shape), # None for each dimension - ) - return jax.device_put(array, sharding) - - -def test_custom_sampling(): - """test custom sampling performance""" - batch_size = 96 - hidden_size = 8192 - rng = jax.random.key(0) - logits = jax.random.normal(rng, (batch_size, 1, hidden_size), dtype=float) - logits = replicate_array_with_sharding(logits) - - topk = jax.random.randint(rng, (batch_size,), dtype=int, minval=1, maxval=3) - temperature = jax.random.uniform( - rng, (batch_size,), dtype=float, minval=0.0, maxval=1.0 - ) - - samplers = [] - # sampler = sample_topk_logits - sampler = sample_weighted_logits - sampler = jax.jit(sampler) - print( - f"logits {logits[:, -1].shape}, topk {topk[0]}, temperature {temperature[0]}" - ) - sampler(logits[:, -1], topk[0], temperature[0], rng) - - for i in range(batch_size): - rng, sub_rng = jax.random.split(rng) - partial = jax.tree_util.Partial( - sampler, topk=topk[i], temperature=temperature[i], rng=sub_rng - ) - samplers.append(partial) - - custom_sampling = PyTorchEngine._custom_sampling - # custom_sampling = jax.jit(custom_sampling) - custom_sampling(logits, samplers) - - start = time.perf_counter() - loops = 10 - # for i in range(loops): - result = custom_sampling(logits, samplers) - result.block_until_ready() - end = time.perf_counter() - duration = end - start - print(f"Custom sampling: total time {duration} for {loops} loops") - - start = time.perf_counter() - loops = 10 - # for i in range(loops): - result = sampler(logits, topk[0], temperature[0], rng) - result.block_until_ready() - end = time.perf_counter() - duration = end - start - print(f"Uniform sampling: total time {duration} for {loops} loops") - - -test_custom_sampling()