Skip to content

Commit

Permalink
Handle enum in choice (#1279)
Browse files Browse the repository at this point in the history
This PR aims at solving #1218 (requires #1277 to be merged)

Quick fix to solve #1275
  • Loading branch information
g-prz authored Dec 8, 2024
1 parent dccaace commit e4f96fb
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ docs/build
*.gguf
.venv
benchmarks/results
.python-version

# Remove doc build folders
.cache/
Expand Down
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,29 @@ generator = outlines.generate.choice(model, ["Positive", "Negative"])
answer = generator(prompt)
```

You can also pass these choices through en enum:

````python
from enum import Enum

import outlines

class Sentiment(str, Enum):
positive = "Positive"
negative = "Negative"

model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")

prompt = """You are a sentiment-labelling assistant.
Is the following review positive or negative?
Review: This restaurant is just awesome!
"""

generator = outlines.generate.choice(model, Sentiment)
answer = generator(prompt)
````

### Type constraint

You can instruct the model to only return integers or floats:
Expand Down
18 changes: 14 additions & 4 deletions outlines/generate/choice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json as pyjson
import re
from enum import Enum
from functools import singledispatch
from typing import Callable, List
from typing import Callable, List, Union

from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_enum
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial
Expand All @@ -12,12 +15,19 @@

@singledispatch
def choice(
model, choices: List[str], sampler: Sampler = multinomial()
model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
regex_str = r"(" + r"|".join(choices) + r")"
if isinstance(choices, type(Enum)):
regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices)))
else:
choices = [re.escape(choice) for choice in choices] # type: ignore
regex_str = r"(" + r"|".join(choices) + r")"

generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: x
if isinstance(choices, type(Enum)):
generator.format_sequence = lambda x: pyjson.loads(x)
else:
generator.format_sequence = lambda x: x

return generator

Expand Down
31 changes: 28 additions & 3 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import re
from enum import Enum

import pytest

Expand Down Expand Up @@ -127,6 +128,18 @@ def model_t5(tmp_path_factory):
)


class MyEnum(Enum):
foo = "foo"
bar = "bar"
baz = "baz"


ALL_SAMPLE_CHOICES_FIXTURES = (
["foo", "bar", "baz"],
MyEnum,
)


##########################################
# Stuctured Generation Inputs
##########################################
Expand Down Expand Up @@ -264,21 +277,33 @@ def test_generate_json(request, model_fixture, sample_schema):


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES)
def test_generate_choice(request, model_fixture, sample_choices):
model = request.getfixturevalue(model_fixture)
generator = generate.choice(model, sample_choices)
res = generator(**get_inputs(model_fixture))
assert res in sample_choices
if isinstance(sample_choices, type(Enum)):
assert res in [elt.value for elt in sample_choices]
else:
assert res in sample_choices


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES)
def test_generate_choice_twice(request, model_fixture, sample_choices):
model = request.getfixturevalue(model_fixture)
generator = generate.choice(model, sample_choices)
res = generator(**get_inputs(model_fixture))
assert res in sample_choices
if isinstance(sample_choices, type(Enum)):
assert res in [elt.value for elt in sample_choices]
else:
assert res in sample_choices

res = generator(**get_inputs(model_fixture))
assert res in sample_choices
if isinstance(sample_choices, type(Enum)):
assert res in [elt.value for elt in sample_choices]
else:
assert res in sample_choices


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
Expand Down

0 comments on commit e4f96fb

Please sign in to comment.