
Move generation of sampling params (#1340)
This PR addresses #1335.
I have kept `SamplingParameters` in `outlines/generate/api.py` because it is imported in many places and removing it might break examples, I guess.
Would you want to change that, @dgerlanc?
Otherwise, if this looks good, I can make the docs changes if needed; let me know.
sky-2002 authored Dec 16, 2024
1 parent 365b566 commit 9b586db
Showing 3 changed files with 60 additions and 21 deletions.
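In short, the diff below deletes the `isinstance`-based branching in `SequenceGeneratorAdapter.__init__` and lets each sampler build its own `SamplingParameters` through a new `sampling_params` property. A minimal sketch of how the resulting API reads, using only names that appear in this diff (the sampler constructor arguments are the ones exercised in the tests below):

```python
from outlines.samplers import GreedySampler, MultinomialSampler, SamplingParameters

# Each sampler now builds its own description of how it samples.
greedy = GreedySampler()
assert greedy.sampling_params == SamplingParameters("greedy", 1, None, None, 0.0)

# MultinomialSampler passes its tuning knobs through unchanged.
multinomial = MultinomialSampler(samples=2, top_k=40, temperature=0.7)
assert multinomial.sampling_params.top_k == 40

# The adapter side collapses to a single assignment:
#     self.sampling_params = sampler.sampling_params
```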
28 changes: 7 additions & 21 deletions outlines/generate/api.py
@@ -4,7 +4,6 @@
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

from outlines.generate.generator import sequence_generator
from outlines.samplers import BeamSearchSampler, GreedySampler, MultinomialSampler

if TYPE_CHECKING:
import torch
@@ -353,11 +352,13 @@ def token_generator() -> Iterator[Union[List[str], str, List[List[str]]]]:
]

generated_sequences = [
self.format_sequence(
self.strip_stop_sequences(sequence, stop_sequences)
(
self.format_sequence(
self.strip_stop_sequences(sequence, stop_sequences)
)
if stop
else sequence
)
if stop
else sequence
for sequence, stop in zip(
generated_sequences, is_stop_at_reached
)
@@ -428,22 +429,7 @@ def __init__(self, model, logits_processor, sampler):
self.model = model
self.logits_processor = logits_processor

if isinstance(sampler, MultinomialSampler):
self.sampling_params = SamplingParameters(
"multinomial",
sampler.samples,
sampler.top_p,
sampler.top_k,
sampler.temperature,
)
elif isinstance(sampler, GreedySampler):
self.sampling_params = SamplingParameters(
"greedy", sampler.samples, None, None, 0.0
)
elif isinstance(sampler, BeamSearchSampler):
self.sampling_params = SamplingParameters(
"beam_search", sampler.samples, None, None, 1.0
)
self.sampling_params = sampler.sampling_params

def prepare_generation_parameters(
self,
30 changes: 30 additions & 0 deletions outlines/samplers.py
@@ -1,4 +1,5 @@
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Protocol, Tuple

if TYPE_CHECKING:
@@ -17,6 +18,17 @@ def __call__(
...


@dataclass(frozen=True)
class SamplingParameters:
"""Sampling parameters available in Outlines."""

sampler: str
num_samples: int = 1
top_p: Optional[float] = None
top_k: Optional[int] = None
temperature: Optional[float] = None


class GreedySampler:
"""Greedy Sampling algorithm.
@@ -76,6 +88,10 @@ def __call__(

return next_token_ids, ancestors, weights

@property
def sampling_params(self):
return SamplingParameters("greedy", self.samples, None, None, 0.0)


greedy = GreedySampler

@@ -161,6 +177,16 @@ def __call__(

return next_token_ids, ancestors, weights

@property
def sampling_params(self):
return SamplingParameters(
"multinomial",
self.samples,
self.top_p,
self.top_k,
self.temperature,
)


multinomial = MultinomialSampler

@@ -320,5 +346,9 @@ def __call__(

return next_token_ids, ancestors, weights

@property
def sampling_params(self):
return SamplingParameters("beam_search", self.samples, None, None, 1.0)


beam_search = BeamSearchSampler
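
Because `SamplingParameters` is declared with `@dataclass(frozen=True)`, the object each `sampling_params` property returns is immutable; a small sketch of that behavior, assuming the classes exactly as defined in the hunks above:

```python
import dataclasses

from outlines.samplers import BeamSearchSampler

params = BeamSearchSampler(beams=3).sampling_params
print(params.sampler, params.num_samples)  # beam_search 3

try:
    params.temperature = 0.5  # frozen dataclasses reject attribute assignment
except dataclasses.FrozenInstanceError:
    print("sampling params are read-only")
```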
23 changes: 23 additions & 0 deletions tests/test_samplers.py
@@ -47,6 +47,13 @@ def test_greedy():
assert ancestors.equal(torch.tensor([0, 1]))
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))

params = sampler.sampling_params
assert params.sampler == "greedy"
assert params.num_samples == 1
assert params.top_p is None
assert params.top_k is None
assert params.temperature == 0.0


def test_multinomial():
rng = torch.Generator()
@@ -72,6 +79,14 @@ def test_multinomial():
assert ancestors.equal(torch.tensor([0, 1]))
assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))

sampler = MultinomialSampler(samples=5, top_k=10, top_p=0.9, temperature=0.8)
params = sampler.sampling_params
assert params.sampler == "multinomial"
assert params.num_samples == 5
assert params.top_p == 0.9
assert params.top_k == 10
assert params.temperature == 0.8


def test_multinomial_init():
sampler = MultinomialSampler()
@@ -252,3 +267,11 @@ def test_beam_search():
]
)
)

sampler = BeamSearchSampler(beams=3)
params = sampler.sampling_params
assert params.sampler == "beam_search"
assert params.num_samples == 3
assert params.top_p is None
assert params.top_k is None
assert params.temperature == 1.0
