diff --git a/gemma/gm/text/_sampler.py b/gemma/gm/text/_sampler.py
index 7c941a93..5ecd6f33 100644
--- a/gemma/gm/text/_sampler.py
+++ b/gemma/gm/text/_sampler.py
@@ -292,11 +292,6 @@ def sample(
     rng = _normalize_rng(rng)
 
     has_batch_dim = _get_has_batch_dim(prompt)
-    if stream and has_batch_dim:
-      raise ValueError(
-          'Streaming is not supported for batched prompts. Let us know if you'
-          ' need this feature.'
-      )
 
     # Normalize the text, images. Tokenize, shard,...
     inputs = self._get_inputs(
@@ -365,6 +360,7 @@ def sample(
       return self._stream_decode_state(  # pytype: disable=bad-return-type
           state,
           return_state=return_state,
+          has_batch_dim=has_batch_dim,
       )
     else:
       return self._decode_state(  # pytype: disable=bad-return-type
@@ -466,12 +462,13 @@ def _stream_decode_state(
       state_iter: Iterator[_sampler_loop.SamplingState],
       *,
       return_state: bool,
+      has_batch_dim: bool,
   ):
 
     for i, state in enumerate(state_iter):
       yield self._decode_state(
           state,
-          predicted_tokens=state.predicted_tokens[..., i],
-          has_batch_dim=False,
+          predicted_tokens=state.predicted_tokens[..., :i+1],
+          has_batch_dim=has_batch_dim,
           return_state=return_state,
       )
diff --git a/gemma/gm/text/_sampler_loop.py b/gemma/gm/text/_sampler_loop.py
index 47784ca8..dcc195fd 100644
--- a/gemma/gm/text/_sampler_loop.py
+++ b/gemma/gm/text/_sampler_loop.py
@@ -205,7 +205,7 @@ def _stream_sample_loop(
     for _ in range(max_new_tokens):
       # Exit if the cache is full.
       cache = _cache_helper.Cache(state.cache)
-      if state.done[0].tolist() or cache.is_full:
+      if jnp.all(state.done) or cache.is_full:
         break
 
       state = self._sample_step(
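
A minimal, hypothetical usage sketch (not part of the patch) of what this change enables: streaming generation over a batched prompt. Only the sample(prompt, stream=...) call path and the removal of the batched-streaming ValueError are taken from the diff; the gm.text.Sampler entry point and the model/params placeholders below are assumptions about the surrounding API.

# Hypothetical usage sketch; `model` and `params` stand in for a loaded
# Gemma transformer and its checkpoint parameters.
from gemma import gm

sampler = gm.text.Sampler(
    model=model,    # assumed: an instantiated Gemma model
    params=params,  # assumed: matching checkpoint parameters
)

prompts = [
    'The capital of France is',
    'The capital of Japan is',
]

# Previously, stream=True combined with a batched (list) prompt raised a
# ValueError. With this change, each streamed item decodes the tokens
# predicted so far (predicted_tokens[..., :i+1]) for every prompt in the batch.
for partial in sampler.sample(prompts, stream=True):
  print(partial)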