diff --git a/outlines/generate/api.py b/outlines/generate/api.py index 4919f2090..5d3c52a8a 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union from outlines.generate.generator import sequence_generator -from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler if TYPE_CHECKING: import torch @@ -353,11 +352,13 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]: ] generated_sequences = [ - self.format_sequence( - self.strip_stop_sequences(sequence, stop_sequences) + ( + self.format_sequence( + self.strip_stop_sequences(sequence, stop_sequences) + ) + if stop + else sequence ) - if stop - else sequence for sequence, stop in zip( generated_sequences, is_stop_at_reached ) @@ -428,22 +429,7 @@ def __init__(self, model, logits_processor, sampler): self.model = model self.logits_processor = logits_processor - if isinstance(sampler, MultinomialSampler): - self.sampling_params = SamplingParameters( - "multinomial", - sampler.samples, - sampler.top_p, - sampler.top_k, - sampler.temperature, - ) - elif isinstance(sampler, GreedySampler): - self.sampling_params = SamplingParameters( - "greedy", sampler.samples, None, None, 0.0 - ) - elif isinstance(sampler, BeamSearchSampler): - self.sampling_params = SamplingParameters( - "beam_search", sampler.samples, None, None, 1.0 - ) + self.sampling_params = sampler.sampling_params def prepare_generation_parameters( self, diff --git a/outlines/samplers.py b/outlines/samplers.py index b1421971f..3ab1728fc 100644 --- a/outlines/samplers.py +++ b/outlines/samplers.py @@ -1,4 +1,5 @@ import math +from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Optional, Protocol, Tuple if TYPE_CHECKING: @@ -17,6 +18,17 @@ def __call__( ... +@dataclass(frozen=True) +class SamplingParameters: + """Sampling parameters available in Outlines.""" + + sampler: str + num_samples: int = 1 + top_p: Optional[float] = None + top_k: Optional[int] = None + temperature: Optional[float] = None + + class GreedySampler: """Greedy Sampling algorithm. @@ -76,6 +88,10 @@ def __call__( return next_token_ids, ancestors, weights + @property + def sampling_params(self): + return SamplingParameters("greedy", self.samples, None, None, 0.0) + greedy = GreedySampler @@ -161,6 +177,16 @@ def __call__( return next_token_ids, ancestors, weights + @property + def sampling_params(self): + return SamplingParameters( + "multinomial", + self.samples, + self.top_p, + self.top_k, + self.temperature, + ) + multinomial = MultinomialSampler @@ -320,5 +346,9 @@ def __call__( return next_token_ids, ancestors, weights + @property + def sampling_params(self): + return SamplingParameters("beam_search", self.samples, None, None, 1.0) + beam_search = BeamSearchSampler diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 88cdb0fbc..10a7be26f 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -47,6 +47,13 @@ def test_greedy(): assert ancestors.equal(torch.tensor([0, 1])) assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]])) + params = sampler.sampling_params + assert params.sampler == "greedy" + assert params.num_samples == 1 + assert params.top_p is None + assert params.top_k is None + assert params.temperature == 0.0 + def test_multinomial(): rng = torch.Generator() @@ -72,6 +79,14 @@ def test_multinomial(): assert ancestors.equal(torch.tensor([0, 1])) assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]])) + sampler = MultinomialSampler(samples=5, top_k=10, top_p=0.9, temperature=0.8) + params = sampler.sampling_params + assert params.sampler == "multinomial" + assert params.num_samples == 5 + assert params.top_p == 0.9 + assert params.top_k == 10 + assert params.temperature == 0.8 + def test_multinomial_init(): sampler = MultinomialSampler() @@ -252,3 +267,11 @@ def test_beam_search(): ] ) ) + + sampler = BeamSearchSampler(beams=3) + params = sampler.sampling_params + assert params.sampler == "beam_search" + assert params.num_samples == 3 + assert params.top_p is None + assert params.top_k is None + assert params.temperature == 1.0