Skip to content

Commit

Permalink
Merge pull request #14 from aelaguiz/enum_list
Browse files Browse the repository at this point in the history
Added output field enum list
  • Loading branch information
aelaguiz authored Mar 19, 2024
2 parents 1171c94 + 7421872 commit e94463c
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 8 deletions.
2 changes: 1 addition & 1 deletion langdspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum
from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList
from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy
from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner
from .model import Model, TrainedModelState
Expand Down
20 changes: 19 additions & 1 deletion langdspy/field_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,22 @@ def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"
return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"

class OutputFieldEnumList(OutputField):
def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
kwargs['enum'] = enum
if not 'transformer' in kwargs:
kwargs['transformer'] = transformers.as_enum_list
if not 'validator' in kwargs:
kwargs['validator'] = validators.is_subset_of
kwargs['choices'] = [e.name for e in enum]
super().__init__(name, desc, **kwargs)

def format_prompt_description(self, llm_type: str):
enum = self.kwargs.get('enum')
choices_str = ", ".join([e.name for e in enum])
if llm_type == "openai":
return f"{self._start_format_openai()}: A comma-separated list of one or more of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}A comma-separated list of one or more of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"
10 changes: 9 additions & 1 deletion langdspy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,12 @@ def as_enum(val: str, kwargs: Dict[str, Any]) -> Enum:
try:
return enum_class[val.upper()]
except KeyError:
raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration")
raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration")

def as_enum_list(val: str, kwargs: Dict[str, Any]) -> List[Enum]:
enum_class = kwargs['enum']
values = [v.strip() for v in val.split(",")]
try:
return [enum_class[v.upper()] for v in values]
except KeyError as e:
raise ValueError(f"{e.args[0]} is not a valid member of the {enum_class.__name__} enumeration")
22 changes: 21 additions & 1 deletion langdspy/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def is_one_of(input, output_val, kwargs) -> bool:
if not kwargs.get('choices'):
raise ValueError("is_one_of validator requires 'choices' keyword argument")


none_ok = False
if kwargs.get('none_ok', False):
none_ok = True
Expand All @@ -49,4 +48,25 @@ def is_one_of(input, output_val, kwargs) -> bool:
logger.error(f"Field must be one of {kwargs.get('choices')}, not {output_val}")
import traceback
traceback.print_exc()
return False

def is_subset_of(input, output_val, kwargs) -> bool:
if not kwargs.get('choices'):
raise ValueError("is_subset_of validator requires 'choices' keyword argument")

none_ok = kwargs.get('none_ok', False)
if none_ok and output_val.lower().strip() == "none":
return True

try:
values = [v.strip() for v in output_val.split(",")]
if not kwargs.get('case_sensitive', False):
choices = [c.lower() for c in kwargs['choices']]
values = [v.lower() for v in values]
for value in values:
if value not in choices:
return False
return True
except Exception as e:
logger.error(f"Field must be a comma-separated list of one or more of {kwargs.get('choices')}, not {output_val}")
return False
27 changes: 25 additions & 2 deletions tests/test_field_descriptors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from langdspy.field_descriptors import InputField, InputFieldList
from enum import Enum
from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList

def test_input_field_initialization():
field = InputField("name", "description")
Expand Down Expand Up @@ -35,4 +36,26 @@ def test_input_field_list_format_prompt_value():

def test_input_field_list_format_prompt_value_empty():
field = InputFieldList("name", "description")
assert field.format_prompt_value([], "openai") == "✅name: NO VALUES SPECIFIED"
assert field.format_prompt_value([], "openai") == "✅name: NO VALUES SPECIFIED"

class TestEnum(Enum):
VALUE1 = "value1"
VALUE2 = "value2"
VALUE3 = "value3"

def test_output_field_enum_list_initialization():
field = OutputFieldEnumList("name", "description", TestEnum)
assert field.name == "name"
assert field.desc == "description"
print(field.kwargs)
assert field.kwargs['enum'] == TestEnum
assert field.transformer.__name__ == "as_enum_list"
# assert field.kwargs['transformer'].__name__ == "as_enum_list"
assert field.validator.__name__ == "is_subset_of"
# assert field.kwargs['validator'].__name__ == "is_subset_of"
assert field.kwargs['choices'] == ["VALUE1", "VALUE2", "VALUE3"]

def test_output_field_enum_list_format_prompt_description():
field = OutputFieldEnumList("name", "description", TestEnum)
assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("openai")
assert "A comma-separated list of one or more of: <choices>VALUE1, VALUE2, VALUE3</choices>" in field.format_prompt_description("anthropic")
15 changes: 14 additions & 1 deletion tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,17 @@ class Fruit(Enum):
assert transformers.as_enum("BANANA", {"enum": Fruit}) == Fruit.BANANA

with pytest.raises(ValueError):
transformers.as_enum("CHERRY", {"enum": Fruit})
transformers.as_enum("CHERRY", {"enum": Fruit})

def test_as_enum_list():
class Fruit(Enum):
APPLE = 1
BANANA = 2
CHERRY = 3

assert transformers.as_enum_list("APPLE", {"enum": Fruit}) == [Fruit.APPLE]
assert transformers.as_enum_list("BANANA, CHERRY", {"enum": Fruit}) == [Fruit.BANANA, Fruit.CHERRY]
assert transformers.as_enum_list("APPLE,BANANA,CHERRY", {"enum": Fruit}) == [Fruit.APPLE, Fruit.BANANA, Fruit.CHERRY]

with pytest.raises(ValueError):
transformers.as_enum_list("DURIAN", {"enum": Fruit})
20 changes: 19 additions & 1 deletion tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,22 @@ def test_is_one_of():
assert validators.is_one_of({}, 'none', {'choices': ['apple', 'banana'], 'none_ok': True}) == True

with pytest.raises(ValueError):
validators.is_one_of({}, 'apple', {})
validators.is_one_of({}, 'apple', {})

def test_is_subset_of():
choices = ["apple", "banana", "cherry"]

assert validators.is_subset_of({}, "apple", {"choices": choices}) == True
assert validators.is_subset_of({}, "apple,banana", {"choices": choices}) == True
assert validators.is_subset_of({}, "apple, banana, cherry", {"choices": choices}) == True
assert validators.is_subset_of({}, "APPLE", {"choices": choices, "case_sensitive": False}) == True
assert validators.is_subset_of({}, "APPLE,BANANA", {"choices": choices, "case_sensitive": False}) == True

assert validators.is_subset_of({}, "durian", {"choices": choices}) == False
assert validators.is_subset_of({}, "apple,durian", {"choices": choices}) == False

assert validators.is_subset_of({}, "none", {"choices": choices, "none_ok": True}) == True
assert validators.is_subset_of({}, "apple,none", {"choices": choices, "none_ok": True}) == False

with pytest.raises(ValueError):
validators.is_subset_of({}, "apple", {})

0 comments on commit e94463c

Please sign in to comment.