From 1182c48e7b50b836acca4c120315d3eb44d01a4d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 30 Sep 2024 01:43:37 +0000 Subject: [PATCH 1/4] Initial change. --- jetstream_pt/engine.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index b7298e1..0b6a60f 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -75,7 +75,9 @@ class DecodeState: start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - + topk: int + nucleus_topp: int + temperature: int # NOTE model specific @@ -280,7 +282,7 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) - def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray: + def _sampling(self, logits: Any, batch_size: int, topk: Any, nucleus_topp: Any, temperature: Any) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) return ( @@ -288,9 +290,9 @@ def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray: logits[:, -1], self.rng, self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - self.env.temperature, + topk, + nucleus_topp, + temperature, ) .reshape(batch_size, -1) .astype(jnp.int32) @@ -304,6 +306,7 @@ def prefill( padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], true_length: int, sampler: Optional[Callable[[Any], Any]] = None, + sampling_config: Any = None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -321,6 +324,11 @@ def prefill( ) if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words + + topk = sampling_config.topk if sampling_config else self.env.topk + nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp + temperature = sampling_config.temperature if sampling_config else self.env.temperature + if sampler: token = sampler(logits[true_length - 1]) else: @@ -328,9 +336,9 @@ def prefill( logits[true_length - 1], self.rng, self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - self.env.temperature, + topk, + nucleus_topp, + temperature, ) token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( @@ -729,7 +737,7 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState, sampler=None + self, params: Any, decode_state: DecodeState, sampler=None, sampling_config=None ) -> tuple[DecodeState, engine_api.ResultTokens]: return (None, None) @@ -753,6 +761,7 @@ def generate_impl( params: Any, decode_state: DecodeState, sampler=None, + sampling_config=None, page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] @@ -799,10 +808,15 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() + topk = sampling_config.topk if sampling_config else self.env.topk + nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp + temperature = sampling_config.temperature if sampling_config else self.env.temperature + if sampler: next_token = sampler(logits[:, -1]) else: - next_token = self._sampling(logits, self.env.batch_size) + next_token = self._sampling(logits, self.env.batch_size, topk, 
nucleus_topp, temperature)
+
 if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 lens = decode_state.lens + 1 @@ -976,6 +990,10 @@ def get_decode_state_sharding(self) -> DecodeState: self.replicated, self.replicated, self.replicated, + self.replicated, + self.replicated, + self.replicated, + self.replicated, ) def get_prefix_sequence_ddim(self) -> Any: From 931e2ce69bfc821f4816c480d1ee80d990045a03 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 30 Sep 2024 05:30:50 +0000 Subject: [PATCH 2/4] Limit the changes to temperature; set the temperature for each request. --- jetstream_pt/engine.py | 63 ++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 0b6a60f..649a67e 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -60,6 +60,8 @@ class Prefix: token: jax.Array # [1, seqlen] caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad + # temperature parameter for scaling probability + temperature: float @struct.dataclass @@ -75,9 +77,7 @@ class DecodeState: start: jax.Array # [batch_size, 1], the starting pos for each slot input_pos: jax.Array # [batch_size, 1] input pos for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - topk: int - nucleus_topp: int - temperature: int + temperatures: List[float] # [batch_size, 1], the temperature for each slot # NOTE model specific @@ -170,6 +170,7 @@ def init_decode_state( scalers = [] if self.env.quant_config.enable_kv_quantization: scalers = [c.scalers() for c in caches_obj] + temperatures = [self.env.temperature] * self.env.batch_size return DecodeState( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, @@ -183,6 +184,7 @@ def init_decode_state( float("-inf"), dtype=self.default_dtype, ), # mask + temperatures, ) # pylint: disable-next=all @@ -282,7 +284,9 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) - def _sampling(self, logits: Any, batch_size: int, topk: Any, nucleus_topp: Any, temperature: Any) -> jnp.ndarray: + def _sampling( + self, logits: Any, batch_size: int, temperatures: List[float] + ) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) return ( @@ -290,9 +294,9 @@ def _sampling(self, logits: Any, batch_size: int, topk: Any, nucleus_topp: Any, logits[:, -1], self.rng, self.env.sampling_algorithm, - topk, - nucleus_topp, - temperature, + self.env.topk, + self.env.nucleus_topp, + temperatures, ) .reshape(batch_size, -1) .astype(jnp.int32) @@ -305,8 +309,7 @@ def prefill( existing_prefix: Optional[Prefix] = None, padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], true_length: int, - sampler: Optional[Callable[[Any], Any]] = None, - sampling_config: Any = None, + sampler: Any, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -325,21 +328,24 @@ def prefill( if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - topk = sampling_config.topk if sampling_config else self.env.topk - nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp - temperature = sampling_config.temperature if sampling_config else self.env.temperature + temperature = ( + sampler["temperature"] + if isinstance(sampler, dict) and "temperature" in sampler + else self.env.temperature + ) - if
sampler: + if sampler and callable(sampler): token = sampler(logits[true_length - 1]) else: token = sampling_utils.sampling( logits[true_length - 1], self.rng, self.env.sampling_algorithm, - topk, - nucleus_topp, + self.env.topk, + self.env.nucleus_topp, temperature, ) + token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( [ @@ -365,7 +371,7 @@ def prefill( # v, seq_len - true_length, true_length, axis=2)) # for k, v in updated_caches # ] - return Prefix(token, updated_caches, true_length), result + return Prefix(token, updated_caches, true_length, temperature), result def shrink_prefix( self, @@ -484,6 +490,7 @@ def insert(cache, scaler, new_entry, update_index): caches.append((kcache, vcache)) scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) + decode_state.temperatures[slot] = prefix.temperature return DecodeState( tokens, caches, @@ -493,6 +500,7 @@ def insert(cache, scaler, new_entry, update_index): start, input_pos, mask, + decode_state.temperatures, ) # pylint: disable-next=all @@ -577,6 +585,7 @@ def insert(cache, scaler, new_entry): scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) + decode_state.temperatures[slot] = prefix.temperature return DecodeState( tokens, caches, @@ -586,6 +595,7 @@ def insert(cache, scaler, new_entry): start, input_pos, mask, + decode_state.temperatures, ) def _insert_page_attention( @@ -621,6 +631,7 @@ def _insert_page_attention( input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) scales = None lens = decode_state.lens.at[slot].set(1) + decode_state.temperatures[slot] = prefix.temperature return DecodeState( tokens, caches, @@ -630,6 +641,7 @@ def _insert_page_attention( start, input_pos, mask, + decode_state.temperatures, ) def insert( @@ -737,7 +749,9 @@ def false_comp(b, i, bk, start, end): return b_next, i_next def generate( - self, params: Any, decode_state: DecodeState, sampler=None, sampling_config=None + self, + params: Any, + decode_state: DecodeState, ) -> tuple[DecodeState, engine_api.ResultTokens]: return (None, None) @@ -761,7 +775,6 @@ def generate_impl( params: Any, decode_state: DecodeState, sampler=None, - sampling_config=None, page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] @@ -808,14 +821,12 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - topk = sampling_config.topk if sampling_config else self.env.topk - nucleus_topp = sampling_config.nucleus_topp if sampling_config else self.env.nucleus_topp - temperature = sampling_config.temperature if sampling_config else self.env.temperature - - if sampler: + if sampler and callable(sampler): next_token = sampler(logits[:, -1]) else: - next_token = self._sampling(logits, self.env.batch_size, topk, nucleus_topp, temperature) + next_token = self._sampling( + logits, self.env.batch_size, decode_state.temperatures + ) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -977,6 +988,7 @@ def get_prefix_destination_sharding(self) -> Prefix: if self.env.page_attention else self.cache_sharding, self.replicated, + self.replicated, ) def get_decode_state_sharding(self) -> DecodeState: @@ -991,9 +1003,6 @@ def get_decode_state_sharding(self) -> DecodeState: self.replicated, self.replicated, self.replicated, - self.replicated, - self.replicated, - self.replicated, ) def get_prefix_sequence_ddim(self) -> Any: From efefe9c6d0562c5e176b24b50a1b75b036267573 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 21 Oct 2024 20:35:24 
+0000 Subject: [PATCH 3/4] Use sampler config instead of temperature only, to be more comprehensive. Added tests. --- jetstream_pt/engine.py | 285 ++++++++++++++++++++----- tests/helpers.py | 6 +- tests/test_engine.py | 459 +++++++++++++++++++++++++++++++++++++++-- 3 files changed, 687 insertions(+), 63 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 649a67e..7c77aff 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -53,6 +53,8 @@ Params = jax.Array PrefillInputs = jax.Array +NEG_INF = -1.0e7 # Sampling masking + @struct.dataclass # pylint: disable-next=all @@ -60,8 +62,7 @@ class Prefix: token: jax.Array # [1, seqlen] caches: List[Tuple[jax.Array, jax.Array]] seq_len: int # true seqlen front pad - # temperature parameter for scaling probability - temperature: float + sampler_config: jax.Array # Sampling config, [algorithm, temperature, topk, nucleus_topp] @struct.dataclass @@ -75,9 +76,16 @@ class DecodeState: current_position: int lens: jax.Array # [batch_size, 1], the output token length start: jax.Array # [batch_size, 1], the starting pos for each slot - input_pos: jax.Array # [batch_size, 1] input pos for each slot + input_pos: ( + jax.Array + ) # [batch_size, 1] total (prefill + decode) length for each slot mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid - temperatures: List[float] # [batch_size, 1], the temperature for each slot + # The sampling configuration. + # If sampling_algorithm is set to empty, the shape is [batch_size, 4] + # Otherwise it's a plain Python list + # The last dimension contains [algorithm, temperature, topk, nucleus] + sampler_config: jax.Array | List[int] + # NOTE model specific @@ -95,7 +103,11 @@ def __init__( self.pt_model = pt_model self.env = env self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 - self.rng = jax.random.PRNGKey(0) + self.rng = jax.random.key(0).reshape( + 1, + ) + # For sampling + self.splited_rngs = jax.random.split(self.rng, num=self.env.batch_size) self.weights = weights self.y_sharding = env.sharding_by_axis(1) @@ -170,7 +182,20 @@ def init_decode_state( scalers = [] if self.env.quant_config.enable_kv_quantization: scalers = [c.scalers() for c in caches_obj] - temperatures = [self.env.temperature] * self.env.batch_size + + if self.env.sampling_algorithm == "": + # [algorithm, temperature, topk, nucleus] + sampler_config = [0, 0.0, 0, 0.0] + sampler_config = jnp.tile(sampler_config, (self.env.batch_size, 1)) + + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32), caches, @@ -184,7 +209,7 @@ def init_decode_state( float("-inf"), dtype=self.default_dtype, ), # mask - temperatures, + sampler_config, ) # pylint: disable-next=all @@ -284,21 +309,119 @@ def _call_model_prefill(self, weights, tokens, input_indexes): caches_res = [c.state() for c in caches] return torchjax.from_torch((res, caches_res)) + def _greedy_sampling(self, logits): + return jnp.argmax(logits, axis=-1) + + def _weighted_sampling(self, logits, rng, temperature): + return jax.random.categorical(rng, logits / temperature) + + def _nucleus_sampling(self, logits, rng, temperature, nucleus_topp): + # return sampling_utils.sample_nucleus_topp_logits(logits, nucleus_topp, temperature, rng) + """Restrict sampling to the top logits with cumulative probability >= + nucleus_topp. 
+ + The nucleus sampling method is proposed in the paper `The Curious Case of + Neural Text Degeneration (https://arxiv.org/pdf/1904.09751.pdf)` + + """ + # if nucleus_topp < 0: + # raise ValueError( + # "Can't apply nucleus with parameter {nucleus_topp=} less zero" + # ) + logits_sorted = jnp.sort(logits, axis=-1)[..., ::-1] # sort descending + sorted_cum_probs = jnp.cumsum( + jax.nn.softmax(logits_sorted, axis=-1), axis=-1 + ) # get cumsum probs + cutoff_index = jnp.sum( + sorted_cum_probs < nucleus_topp, axis=-1, keepdims=True + ) # find cutoff index + cutoff_logit = jnp.take_along_axis(logits_sorted, cutoff_index, axis=-1) + logits = jnp.where( + logits < cutoff_logit, jnp.full_like(logits, NEG_INF), logits + ) + result = jax.random.categorical(rng, logits / temperature) + return result + + def _topk_sampling(self, logits, rng, temperature, topk): + # return sampling_utils.sample_topk_logits(logits, topk, temperature, rng) + """Restricting sampling to the best k logits.""" + # if topk <= 0: + # raise ValueError("Can't apply algorithm topk with parameter {topk=} <= 0") + sorted_indices = jnp.argsort(logits)[::-1] # Sort in descending order + topk_mask = jnp.arange(sorted_indices.shape[-1]) < topk + topk_idxs = jnp.where(topk_mask, sorted_indices, -1) + + topk_logits = jnp.where(topk_idxs == -1, -jnp.inf, logits) + + sampled_idx = jnp.expand_dims( + jax.random.categorical(rng, topk_logits / temperature).astype( + jnp.int32 + ), + axis=-1, + ) + sampled_tokens = jnp.squeeze( + jnp.take_along_axis(topk_idxs, sampled_idx, axis=-1), axis=-1 + ).astype(jnp.int32) + + return sampled_tokens + + # Algorithm type: + # 0: Greedy 1: Weighted 2: Nucleus 3: Top-k + def _apply_sampling( + self, logits, algorithm, rng, temperature, topk, nucleus_topp + ) -> jnp.ndarray: + return jax.lax.cond( + algorithm == 0, + lambda: self._greedy_sampling(logits), # Greedy + lambda: jax.lax.cond( + algorithm == 1, + lambda: self._weighted_sampling( + logits, rng, temperature + ), # Weighted + lambda: jax.lax.cond( + algorithm == 2, + lambda: self._nucleus_sampling( + logits, rng, temperature, nucleus_topp + ), # Nucleus sampling + lambda: self._topk_sampling( + logits, rng, temperature, topk + ), # Top-k sampling + ), + ), + ) + + def _custom_sampling( + self, logits, algorithm, rng, temperature, topk, nucleus_topp + ) -> jnp.ndarray: + if len(logits.shape) == 2: + logits = jnp.expand_dims(logits, 0) + logits = logits[:, -1] + apply_sampling_v = jax.vmap( + lambda _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp: self._apply_sampling( + _logits, _algorithm, _rngs, _temperature, _topk, _nucleus_topp + ), + in_axes=(0, 0, 0, 0, 0, 0), + ) + apply_sampling_v = jax.jit(apply_sampling_v) + return apply_sampling_v( + logits, algorithm, rng, temperature, topk, nucleus_topp + ).reshape(self.env.batch_size, -1) + def _sampling( - self, logits: Any, batch_size: int, temperatures: List[float] + self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp ) -> jnp.ndarray: if len(logits.shape) == 2: logits = jnp.expand_dims(logits, 0) return ( sampling_utils.sampling( logits[:, -1], - self.rng, - self.env.sampling_algorithm, - self.env.topk, - self.env.nucleus_topp, - temperatures, + rng, + algorithm, + topk, + nucleus_topp, + temperature, ) - .reshape(batch_size, -1) + .reshape(self.env.batch_size, -1) .astype(jnp.int32) ) @@ -307,9 +430,9 @@ def prefill( *, params: Any, # Weights existing_prefix: Optional[Prefix] = None, - padded_tokens: PrefillInputs, # PrefillInputs[jax.Array], + padded_tokens: 
PrefillInputs, # PrefillInputs[jax.Array] true_length: int, - sampler: Any, + sampler: Optional[Callable[[Any], Any]] = None, ) -> Tuple[Prefix, engine_api.ResultTokens]: if isinstance(padded_tokens, jax.Array): batched_token = padded_tokens.reshape(1, -1) @@ -328,23 +451,39 @@ def prefill( if len(logits.shape) == 3: # b, seqlen, num words logits = logits[0] # seqlen, num words - temperature = ( - sampler["temperature"] - if isinstance(sampler, dict) and "temperature" in sampler - else self.env.temperature - ) - - if sampler and callable(sampler): - token = sampler(logits[true_length - 1]) + if self.env.sampling_algorithm == "": + assert ( + isinstance(sampler, List) + or isinstance(sampler, list) + or isinstance(sampler, Tuple) + or isinstance(sampler, tuple) + ), f"{type(sampler)} is not valid" + algorithm, temperature, topk, nucleus_topp = sampler + + algorithm = jnp.array([algorithm]) + temperature = jnp.array([temperature]) + topk = jnp.array([topk]) + nucleus_topp = jnp.array([nucleus_topp]) + + # Prefill only handles a batch size of 1, so there is no need to use split rngs + sampling = self._custom_sampling else: - token = sampling_utils.sampling( - logits[true_length - 1], - self.rng, + algorithm, temperature, topk, nucleus_topp = ( self.env.sampling_algorithm, + self.env.temperature, self.env.topk, self.env.nucleus_topp, - temperature, ) + sampling = self._sampling + + token = sampling( + logits=logits, + algorithm=algorithm, + rng=self.rng, + temperature=temperature, + topk=topk, + nucleus_topp=nucleus_topp, + ) token_out = jnp.reshape(token, (1, 1)) data = jnp.concatenate( [ @@ -371,7 +510,10 @@ def prefill( # v, seq_len - true_length, true_length, axis=2)) # for k, v in updated_caches # ] - return Prefix(token, updated_caches, true_length, temperature), result + return ( + Prefix(token, updated_caches, true_length, jnp.array(sampler)), + result, + ) def shrink_prefix( self, @@ -484,6 +632,7 @@ def insert(cache, scaler, new_entry, update_index): caches.append((kcache, vcache)) scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) - decode_state.temperatures[slot] = prefix.temperature + + if self.env.sampling_algorithm == "": + sampler_config = decode_state.sampler_config.at[slot].set( + prefix.sampler_config + ) + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( tokens, caches, @@ -500,7 +654,7 @@ def insert(cache, scaler, new_entry, update_index): start, input_pos, mask, - decode_state.temperatures, + sampler_config, ) # pylint: disable-next=all @@ -585,7 +739,19 @@ def insert(cache, scaler, new_entry): scales.append((kscale, vscale)) lens = decode_state.lens.at[slot].set(1) - decode_state.temperatures[slot] = prefix.temperature + + if self.env.sampling_algorithm == "": + sampler_config = decode_state.sampler_config.at[slot].set( + prefix.sampler_config + ) + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( tokens, caches, @@ -595,7 +761,7 @@ def insert(cache, scaler, new_entry): start, input_pos, mask, - decode_state.temperatures, + sampler_config, ) def _insert_page_attention( @@ -631,7 +797,19 @@ def _insert_page_attention( input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len) scales = None lens = decode_state.lens.at[slot].set(1) - decode_state.temperatures[slot] = prefix.temperature + + if self.env.sampling_algorithm == "": + sampler_config = 
decode_state.sampler_config.at[slot].set( + prefix.sampler_config + ) + else: + sampler_config = [ + self.env.sampling_algorithm, + self.env.temperature, + self.env.topk, + self.env.nucleus_topp, + ] + return DecodeState( tokens, caches, @@ -641,7 +819,7 @@ def _insert_page_attention( start, input_pos, mask, - decode_state.temperatures, + sampler_config, ) def insert( @@ -774,7 +952,6 @@ def generate_impl( self, params: Any, decode_state: DecodeState, - sampler=None, page_token_indices=None, ) -> tuple[DecodeState, engine_api.ResultTokens]: # seq_len = padded_tokens.shape[0] @@ -786,12 +963,16 @@ def generate_impl( else: input_indexes = decode_state.input_pos - ragged_batch_index, ragged_block_index = ( - self.precompute_ragged_block_indices(decode_state) - ) - ragged_batch_index, ragged_block_index = ragged_batch_index.reshape( - (-1) - ), ragged_block_index.reshape((-1)) + # TODO(lancewang): Remove ragged index precomputation + # ragged_batch_index, ragged_block_index = ( + # self.precompute_ragged_block_indices(decode_state) + # ) + # ragged_batch_index, ragged_block_index = ragged_batch_index.reshape( + # (-1) + # ), ragged_block_index.reshape((-1)) + + ragged_batch_index = 0 + ragged_block_index = 0 def update_mask(): if self.env.ring_buffer: @@ -821,12 +1002,21 @@ def update_mask(): # fill mask later, now use flash attention mask = update_mask() - if sampler and callable(sampler): - next_token = sampler(logits[:, -1]) + if self.env.sampling_algorithm == "": + sampling = self._custom_sampling + rng = self.splited_rngs else: - next_token = self._sampling( - logits, self.env.batch_size, decode_state.temperatures - ) + sampling = self._sampling + rng = self.rng + + next_token = sampling( + logits=logits, + algorithm=decode_state.sampler_config[:, 0:1].reshape(-1), + rng=rng, + temperature=decode_state.sampler_config[:, 1:2].reshape(-1), + topk=decode_state.sampler_config[:, 2:3].reshape(-1), + nucleus_topp=decode_state.sampler_config[:, 3:4].reshape(-1), + ) if self.env.ring_buffer: input_pos = decode_state.input_pos + 1 @@ -869,6 +1059,7 @@ def update_mask(): decode_state.start, input_pos, mask, + decode_state.sampler_config, ) return new_decode_state, result_tokens diff --git a/tests/helpers.py b/tests/helpers.py index ac0ea5f..860389d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,7 +7,9 @@ # pylint: disable-next=all -def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): +def make_env_tiny( + bf16_enable=True, env_data_update_fn=lambda _: None, batch_size=1 +): torch_dtype = torch.bfloat16 if bf16_enable else torch.float32 torch.set_default_dtype(torch_dtype) jax.config.update("jax_dynamic_shapes", False) @@ -19,7 +21,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None): environment_data.cache_sequence_length = 128 environment_data.bf16_enable = bf16_enable environment_data.model_type = "llama-2-tiny" - environment_data.batch_size = 1 + environment_data.batch_size = batch_size environment_data.num_layers = config.n_layers environment_data.cache_shape = ( 1, diff --git a/tests/test_engine.py b/tests/test_engine.py index 57245c0..2fe94ba 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -20,16 +20,61 @@ from jetstream_pt.third_party.llama import model_exportable from jetstream_pt.engine import PyTorchEngine +from jetstream_pt.engine import DecodeState +from jetstream_pt.engine import Prefix from tests import helpers +from jetstream_pt import cache_manager + + +class MockEngine(PyTorchEngine): + + def 
_call_model_prefill(self, weights, tokens, input_indexes): + caches = [ + cache_manager.KVCachePrefill( + self.env.quant_config.enable_kv_quantization + ) + for _ in self.pt_model.layers + ] + # logits = jnp.ones((self.env.batch_size, 1), jnp.float32) + assert ( + self.env.batch_size == 1 + ), f"The batch size {self.env.batch_size} != 1" + logits = jnp.array([[0.5, 0.6, 0.7, 0.8]]) + return logits, caches + + def _call_model_generate( + self, + weights, + tokens, + input_indexes, + caches, + cache_scales, + mask, + start, + input_pos, + ragged_batch_index, + ragged_block_index, + page_token_indices, + ): + logits = jnp.array( + [ + [[0.5, 0.6, 0.7, 0.8]], + # [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], + [[0.4, 0.3, 0.2, 0.1]], + ] + ) + return logits, caches, cache_scales class EngineTest(unittest.TestCase): - def setup(self): - env, model_arg = helpers.make_env_tiny(bf16_enable=True) + def setup(self, batch_size=1): + env, model_arg = helpers.make_env_tiny( + bf16_enable=True, batch_size=batch_size + ) model_ours = model_exportable.Transformer(model_arg, env) - engine = PyTorchEngine(pt_model=model_ours, env=env) - engine.rng = jax.random.PRNGKey(0) + engine = MockEngine(pt_model=model_ours, env=env) + engine.rng = jax.random.key(0) return engine def test_sampling_2D(self): @@ -37,14 +82,23 @@ def test_sampling_2D(self): engine = self.setup() self.assertEqual(engine.env.sampling_algorithm, "greedy") logits = jnp.array([[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]]) - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, "greedy", engine.rng, temperature=1.0, topk=1, nucleus_topp=0.0 + ) self.assertEqual(token, jnp.array([[0]])) self.assertTrue(jnp.isdtype(token, jnp.int32)) # test weighted engine.env.sampling_algorithm = "weighted" engine.env.temperature = 5.0 - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + temperature=5.0, + topk=1, + nucleus_topp=0.0, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) @@ -52,21 +106,36 @@ def test_sampling_2D(self): engine.env.sampling_algorithm = "topk" engine.env.temperature = 5.0 engine.env.topk = 4 - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + temperature=5.0, + topk=4, + nucleus_topp=0.0, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) # test nucleus engine.env.sampling_algorithm = "nucleus" - engine.env.temperature = 0.0 + engine.env.temperature = 1.0 engine.env.nucleus_topp = 0.8 - token = engine._sampling(logits, batch_size=1) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + temperature=0.0, + topk=1, + nucleus_topp=0.8, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) def test_sampling_3D(self): # test greedy - engine = self.setup() + engine = self.setup(batch_size=2) + self.assertEqual(engine.env.sampling_algorithm, "greedy") logits = jnp.array( [ @@ -74,14 +143,28 @@ def test_sampling_3D(self): [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], ] ) - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) 
self.assertTrue(jnp.isdtype(token, jnp.int32)) # test weighted engine.env.sampling_algorithm = "weighted" engine.env.temperature = 10.0 - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) @@ -89,7 +172,14 @@ def test_sampling_3D(self): engine.env.sampling_algorithm = "topk" engine.env.temperature = 1.0 engine.env.topk = 3 - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[1], [0]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) @@ -97,10 +187,351 @@ def test_sampling_3D(self): engine.env.sampling_algorithm = "nucleus" engine.env.temperature = 1.0 engine.env.nucleus_topp = 0.8 - token = engine._sampling(logits, batch_size=2) + token = engine._sampling( + logits, + engine.env.sampling_algorithm, + engine.rng, + engine.env.temperature, + engine.env.topk, + engine.env.nucleus_topp, + ) self.assertTrue(jnp.array_equal(token, jnp.array([[3], [1]]))) self.assertTrue(jnp.isdtype(token, jnp.int32)) + def test_custom_sampling_3D(self): + engine = self.setup(batch_size=2) + engine.rng = jax.random.key(3) + engine.splited_rngs = jax.random.split( + engine.rng, num=engine.env.batch_size + ) + engine.env.sampling_algorithm = "" + + # Need a different engine of batch size of 1 to reshape the output + engine_b1 = self.setup() + engine_b1.rng = jax.random.key(3) + logits = jnp.array( + [ + [[0.4, 0.3, 0.2, 0.1], [0.5, 0.6, 0.7, 0.8]], + [[0.5, 0.6, 0.7, 0.8], [0.4, 0.3, 0.2, 0.1]], + ] + ) + + # test greedy + token = engine._custom_sampling( + logits, + jnp.array([0, 0]), + engine.splited_rngs, + temperature=jnp.array([[0.0], [0.0]]), + topk=jnp.array([[0], [0]]), + nucleus_topp=jnp.array([[0.0], [0.0]]), + ) + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "greedy", + engine.splited_rngs[i], + temperature=1.0, + topk=0, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [0]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test weighted + engine.env.sampling_algorithm = "weighted" + token = engine._custom_sampling( + logits, + jnp.array([1, 1]), + engine.splited_rngs, + temperature=jnp.array([1, 1]), + topk=jnp.array([0, 0]), + nucleus_topp=jnp.array([0.0, 0.0]), + ) + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "weighted", + engine.splited_rngs[i], + temperature=1, + topk=0, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test topk + engine.env.sampling_algorithm = "topk" + token = engine._custom_sampling( + logits, + jnp.array([3, 3]), + 
engine.splited_rngs, + temperature=jnp.array([[1.0], [1.0]]), + topk=jnp.array([[3], [3]]), + nucleus_topp=jnp.array([[0.0], [0.0]]), + ) + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "topk", + engine.splited_rngs[i], + temperature=1.0, + topk=3, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[1], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test nucleus + engine.env.sampling_algorithm = "nucleus" + token = engine._custom_sampling( + logits, + jnp.array([2, 2]), + engine.splited_rngs, + temperature=jnp.array([[1.0], [1.0]]), + topk=jnp.array([[0], [0]]), + nucleus_topp=jnp.array([[0.8], [0.8]]), + ) + + original_tokens = [] + for i in range(2): + original_token = engine_b1._sampling( + logits[i], + "nucleus", + engine.splited_rngs[i], + temperature=1.0, + topk=0, + nucleus_topp=0.8, + ) + original_tokens.append(original_token) + original_tokens = jnp.concatenate(original_tokens) + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + # test greedy + topk + token = engine._custom_sampling( + logits, + jnp.array([0, 3]), + engine.splited_rngs, + temperature=jnp.array([[0.0], [1.0]]), + topk=jnp.array([[0], [3]]), + nucleus_topp=jnp.array([[0.0], [0.0]]), + ) + original_tokens = [] + + i = 0 + original_token = engine_b1._sampling( + logits[i], + "greedy", + engine.splited_rngs[i], + temperature=0.0, + topk=0, + nucleus_topp=0.8, + ) + original_tokens.append(original_token) + + i = 1 + original_token = engine_b1._sampling( + logits[i], + "topk", + engine.splited_rngs[i], + temperature=1.0, + topk=3, + nucleus_topp=0.0, + ) + original_tokens.append(original_token) + + original_tokens = jnp.concatenate(original_tokens) + + print(f"custom sampling token {token} vs original tokens {original_tokens}") + self.assertTrue(jnp.array_equal(token, original_tokens)) + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + def test_prefill_with_custom_sampling(self): + engine = self.setup() + engine.env.sampling_algorithm = "" + + # Inputs don't matter + params = jnp.zeros((1,), jnp.float32) + padded_tokens = jnp.zeros((1,), jnp.float32) + true_length = 1 + + # Greedy + # algorithm, temperature, topk, nucleus_topp + sampler = [0, 1.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Greedy output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + + print( + f"prefix sampler config {prefix.sampler_config} vs sampler {jnp.array(sampler)}" + ) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Weighted + sampler = [1, 10.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Weighted output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[0]]))) + 
self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Nucleus + sampler = [2, 1.0, 3, 0.0] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Nucleus output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + # Topk + sampler = [3, 1.0, 3, 0.8] + prefix, _ = engine.prefill( + params=params, + padded_tokens=padded_tokens, + true_length=true_length, + sampler=sampler, + ) + token = prefix.token + print(f"Topk output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + self.assertAlmostEqual( + prefix.sampler_config.all(), jnp.array(sampler).all() + ) + + def test_insert_no_wrap_with_custom_sampling(self): + engine = self.setup() + engine.env.sampling_algorithm = "" + engine.env.batch_size = 2 + cache_shape = engine.env.cache_shape + + sampler_config_raw = [0, 1.0, 3, 0.8] + sampler_config = jnp.array(sampler_config_raw) + + prefill_cache_shape = (1, cache_shape[1], 16, cache_shape[3]) + prefill_cache = [] + for _ in range(engine.env.num_layers): + prefill_cache.append( + ( + jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), + jnp.ones(prefill_cache_shape, dtype=jnp.bfloat16), + ) + ) + + prefix = Prefix( + token=jnp.ones((1)), + caches=prefill_cache, + seq_len=16, + sampler_config=sampler_config, + ) + + doesnt_matter = jnp.array([0]) + kv_cache = engine.env.make_caches_generate() + kv_cache = [c.state() for c in kv_cache] + + decode_state = DecodeState( + tokens=jnp.zeros((engine.env.batch_size, 1)), + caches=kv_cache, + cache_scales=[doesnt_matter], + current_position=16, + lens=jnp.zeros((engine.env.batch_size, 1)), + start=jnp.zeros((engine.env.batch_size, 1)), + input_pos=jnp.zeros((engine.env.batch_size,)), + mask=jnp.zeros((engine.env.batch_size, 128)), + sampler_config=jnp.zeros((engine.env.batch_size, 4)), + ) + + # Insert to slot 1 + result_decode_state = engine._insert_no_wrap(prefix, decode_state, slot=1) + + self.assertAlmostEqual( + result_decode_state.tokens.all(), decode_state.tokens.all() + ) + self.assertAlmostEqual( + result_decode_state.sampler_config.all(), + jnp.array([[0, 0, 0, 0], sampler_config_raw]).all(), + ) + + def test_decode_with_custom_sampling(self): + engine = self.setup(batch_size=2) + engine.rng = jax.random.key(3) + engine.splited_rngs = jax.random.split( + engine.rng, num=engine.env.batch_size + ) + engine.env.sampling_algorithm = "" + + # Inputs don't matter + doesnt_matter = jnp.array([0]) + params = doesnt_matter + + decode_state = DecodeState( + tokens=jnp.zeros((engine.env.batch_size, 1)), + caches=[doesnt_matter], + cache_scales=[doesnt_matter], + current_position=0, + lens=jnp.zeros((engine.env.batch_size, 1)), + start=doesnt_matter, + input_pos=jnp.zeros((engine.env.batch_size,)), + mask=jnp.zeros((engine.env.batch_size, 1)), + sampler_config=jnp.array([[0, 0.0, 0, 0.0], [3, 1.0, 3, 0.0]]), + ) + + # Greedy + Topk + # algorithm, temperature, topk, nucleus_topp + decode_state, _ = engine.generate_impl( + params=params, decode_state=decode_state + ) + token = decode_state.tokens + print(f"Decode output: {token}") + self.assertTrue(jnp.array_equal(token, jnp.array([[3], [2]]))) + self.assertTrue(jnp.isdtype(token, jnp.int32)) + # 
def test_insert(self): # seqlen = 32 From 240ea4774d23ac822f05726a37c4409fe126336c Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 21 Oct 2024 20:50:11 +0000 Subject: [PATCH 4/4] Fix rng shape. Fix typo. --- jetstream_pt/engine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py index 7c77aff..a2103ec 100644 --- a/jetstream_pt/engine.py +++ b/jetstream_pt/engine.py @@ -103,9 +103,7 @@ def __init__( self.pt_model = pt_model self.env = env self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32 - self.rng = jax.random.key(0).reshape( - 1, - ) + self.rng = jax.random.key(0) # For sampling self.splited_rngs = jax.random.split(self.rng, num=self.env.batch_size) self.weights = weights @@ -467,6 +465,9 @@ def prefill( # Prefill only handles a batch size of 1, so there is no need to use split rngs sampling = self._custom_sampling + rng = self.rng.reshape( + 1, + ) else: algorithm, temperature, topk, nucleus_topp = ( self.env.sampling_algorithm, @@ -475,11 +476,12 @@ def prefill( self.env.nucleus_topp, ) sampling = self._sampling + rng = self.rng token = sampling( logits=logits, algorithm=algorithm, - rng=self.rng, + rng=rng, temperature=temperature, topk=topk, nucleus_topp=nucleus_topp,
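
The per-slot dispatch at the core of patch 3 — a nested jax.lax.cond keyed on an integer algorithm id, vmapped with one rng per batch row — can be exercised in isolation. Below is a minimal sketch of that pattern using only jax; the names (apply_sampling, topk_sample, config) are illustrative rather than the engine's, and the top-k step is a simplified rank-mask variant, not a copy of _topk_sampling:

import jax
import jax.numpy as jnp

def topk_sample(logits, rng, temperature, topk):
  # Rank each position (0 = largest logit), keep ranks < topk, mask the rest.
  ranks = jnp.argsort(jnp.argsort(logits)[::-1])
  masked = jnp.where(ranks < topk, logits, -jnp.inf)
  return jax.random.categorical(rng, masked / temperature)

def apply_sampling(logits, algorithm, rng, temperature, topk):
  # Nested lax.cond on the algorithm id, so each vmapped row picks its own
  # branch: 0 = greedy, 1 = weighted, anything else = top-k.
  return jax.lax.cond(
      algorithm == 0,
      lambda: jnp.argmax(logits, axis=-1),
      lambda: jax.lax.cond(
          algorithm == 1,
          lambda: jax.random.categorical(rng, logits / temperature),
          lambda: topk_sample(logits, rng, temperature, topk),
      ),
  )

logits = jnp.array([[2.0, 1.0, 0.5, 0.1],
                    [0.1, 0.5, 1.0, 2.0]])
# One row per slot: [algorithm, temperature, topk]. Row 0 decodes greedily;
# row 1 samples from its top-2 logits at temperature 0.7.
config = jnp.array([[0, 1.0, 0],
                    [3, 0.7, 2]])
rngs = jax.random.split(jax.random.key(0), logits.shape[0])
tokens = jax.vmap(apply_sampling)(
    logits,
    config[:, 0].astype(jnp.int32),
    rngs,
    config[:, 1],
    config[:, 2].astype(jnp.int32),
)
print(tokens)  # row 0 is the argmax (0); row 1 is drawn from {2, 3}

Because lax.cond under vmap lowers to a select that evaluates both branches, every algorithm runs for every row and the per-row id only chooses which result is kept — which is why the patches can keep a single traced generate step while different slots use different samplers.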
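
The top-p cutoff inside _nucleus_sampling is easiest to verify on a hand-checkable distribution. A minimal sketch with illustrative logit values, mirroring the patch's sort/cumsum/cutoff steps for a single unbatched row:

import jax
import jax.numpy as jnp

NEG_INF = -1.0e7  # same masking constant as the patch

logits = jnp.array([2.0, 1.0, 0.5, 0.1])  # illustrative values
nucleus_topp = 0.8

logits_sorted = jnp.sort(logits)[::-1]                 # descending: [2.0, 1.0, 0.5, 0.1]
cum_probs = jnp.cumsum(jax.nn.softmax(logits_sorted))  # ~[0.57, 0.79, 0.91, 1.00]
cutoff_index = jnp.sum(cum_probs < nucleus_topp)       # 2 entries stay below 0.8
cutoff_logit = logits_sorted[cutoff_index]             # 0.5
kept = jnp.where(logits < cutoff_logit, NEG_INF, logits)
# kept == [2.0, 1.0, 0.5, NEG_INF]: the nucleus keeps every token up to and
# including the one whose cumulative probability first reaches nucleus_topp.
token = jax.random.categorical(jax.random.key(0), kept / 1.0)  # temperature 1.0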