diff --git a/langdspy/__init__.py b/langdspy/__init__.py index 5467281..e84c4b9 100644 --- a/langdspy/__init__.py +++ b/langdspy/__init__.py @@ -1,4 +1,4 @@ -from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum +from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner from .model import Model, TrainedModelState diff --git a/langdspy/field_descriptors.py b/langdspy/field_descriptors.py index 5715ec0..122101e 100644 --- a/langdspy/field_descriptors.py +++ b/langdspy/field_descriptors.py @@ -138,4 +138,22 @@ def format_prompt_description(self, llm_type: str): if llm_type == "openai": return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}" elif llm_type == "anthropic": - return f"{self._start_format_anthropic()}One of: {choices_str} - {self.desc}" \ No newline at end of file + return f"{self._start_format_anthropic()}One of: {choices_str} - {self.desc}" + +class OutputFieldEnumList(OutputField): + def __init__(self, name: str, desc: str, enum: Enum, **kwargs): + kwargs['enum'] = enum + if not 'transformer' in kwargs: + kwargs['transformer'] = transformers.as_enum_list + if not 'validator' in kwargs: + kwargs['validator'] = validators.is_subset_of + kwargs['choices'] = [e.name for e in enum] + super().__init__(name, desc, **kwargs) + + def format_prompt_description(self, llm_type: str): + enum = self.kwargs.get('enum') + choices_str = ", ".join([e.name for e in enum]) + if llm_type == "openai": + return f"{self._start_format_openai()}: A comma-separated list of one or more of: {choices_str} - {self.desc}" + elif llm_type == "anthropic": + return f"{self._start_format_anthropic()}A comma-separated list of one or more of: {choices_str} - {self.desc}" \ No newline at end of file diff --git a/langdspy/transformers.py b/langdspy/transformers.py index c5b753a..7eef24c 100644 --- a/langdspy/transformers.py +++ b/langdspy/transformers.py @@ -21,4 +21,12 @@ def as_enum(val: str, kwargs: Dict[str, Any]) -> Enum: try: return enum_class[val.upper()] except KeyError: - raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration") \ No newline at end of file + raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration") + +def as_enum_list(val: str, kwargs: Dict[str, Any]) -> List[Enum]: + enum_class = kwargs['enum'] + values = [v.strip() for v in val.split(",")] + try: + return [enum_class[v.upper()] for v in values] + except KeyError as e: + raise ValueError(f"{e.args[0]} is not a valid member of the {enum_class.__name__} enumeration") \ No newline at end of file diff --git a/langdspy/validators.py b/langdspy/validators.py index a2d0453..a13cb6b 100644 --- a/langdspy/validators.py +++ b/langdspy/validators.py @@ -24,7 +24,6 @@ def is_one_of(input, output_val, kwargs) -> bool: if not kwargs.get('choices'): raise ValueError("is_one_of validator requires 'choices' keyword argument") - none_ok = False if kwargs.get('none_ok', False): none_ok = True @@ -49,4 +48,25 @@ def is_one_of(input, output_val, kwargs) -> bool: logger.error(f"Field must be one of {kwargs.get('choices')}, not {output_val}") import traceback traceback.print_exc() + return False + +def is_subset_of(input, output_val, kwargs) -> bool: + if not kwargs.get('choices'): + raise ValueError("is_subset_of validator requires 'choices' keyword argument") + + none_ok = kwargs.get('none_ok', False) + if none_ok and output_val.lower().strip() == "none": + return True + + try: + values = [v.strip() for v in output_val.split(",")] + if not kwargs.get('case_sensitive', False): + choices = [c.lower() for c in kwargs['choices']] + values = [v.lower() for v in values] + for value in values: + if value not in choices: + return False + return True + except Exception as e: + logger.error(f"Field must be a comma-separated list of one or more of {kwargs.get('choices')}, not {output_val}") return False \ No newline at end of file diff --git a/tests/test_field_descriptors.py b/tests/test_field_descriptors.py index d630791..9734e0e 100644 --- a/tests/test_field_descriptors.py +++ b/tests/test_field_descriptors.py @@ -1,5 +1,6 @@ import pytest -from langdspy.field_descriptors import InputField, InputFieldList +from enum import Enum +from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList def test_input_field_initialization(): field = InputField("name", "description") @@ -35,4 +36,26 @@ def test_input_field_list_format_prompt_value(): def test_input_field_list_format_prompt_value_empty(): field = InputFieldList("name", "description") - assert field.format_prompt_value([], "openai") == "✅name: NO VALUES SPECIFIED" \ No newline at end of file + assert field.format_prompt_value([], "openai") == "✅name: NO VALUES SPECIFIED" + +class TestEnum(Enum): + VALUE1 = "value1" + VALUE2 = "value2" + VALUE3 = "value3" + +def test_output_field_enum_list_initialization(): + field = OutputFieldEnumList("name", "description", TestEnum) + assert field.name == "name" + assert field.desc == "description" + print(field.kwargs) + assert field.kwargs['enum'] == TestEnum + assert field.transformer.__name__ == "as_enum_list" + # assert field.kwargs['transformer'].__name__ == "as_enum_list" + assert field.validator.__name__ == "is_subset_of" + # assert field.kwargs['validator'].__name__ == "is_subset_of" + assert field.kwargs['choices'] == ["VALUE1", "VALUE2", "VALUE3"] + +def test_output_field_enum_list_format_prompt_description(): + field = OutputFieldEnumList("name", "description", TestEnum) + assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("openai") + assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("anthropic") diff --git a/tests/test_transformers.py b/tests/test_transformers.py index 9bd2398..a46cb01 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -27,4 +27,17 @@ class Fruit(Enum): assert transformers.as_enum("BANANA", {"enum": Fruit}) == Fruit.BANANA with pytest.raises(ValueError): - transformers.as_enum("CHERRY", {"enum": Fruit}) \ No newline at end of file + transformers.as_enum("CHERRY", {"enum": Fruit}) + +def test_as_enum_list(): + class Fruit(Enum): + APPLE = 1 + BANANA = 2 + CHERRY = 3 + + assert transformers.as_enum_list("APPLE", {"enum": Fruit}) == [Fruit.APPLE] + assert transformers.as_enum_list("BANANA, CHERRY", {"enum": Fruit}) == [Fruit.BANANA, Fruit.CHERRY] + assert transformers.as_enum_list("APPLE,BANANA,CHERRY", {"enum": Fruit}) == [Fruit.APPLE, Fruit.BANANA, Fruit.CHERRY] + + with pytest.raises(ValueError): + transformers.as_enum_list("DURIAN", {"enum": Fruit}) \ No newline at end of file diff --git a/tests/test_validators.py b/tests/test_validators.py index 18e1bb4..1ae4030 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -14,4 +14,22 @@ def test_is_one_of(): assert validators.is_one_of({}, 'none', {'choices': ['apple', 'banana'], 'none_ok': True}) == True with pytest.raises(ValueError): - validators.is_one_of({}, 'apple', {}) \ No newline at end of file + validators.is_one_of({}, 'apple', {}) + +def test_is_subset_of(): + choices = ["apple", "banana", "cherry"] + + assert validators.is_subset_of({}, "apple", {"choices": choices}) == True + assert validators.is_subset_of({}, "apple,banana", {"choices": choices}) == True + assert validators.is_subset_of({}, "apple, banana, cherry", {"choices": choices}) == True + assert validators.is_subset_of({}, "APPLE", {"choices": choices, "case_sensitive": False}) == True + assert validators.is_subset_of({}, "APPLE,BANANA", {"choices": choices, "case_sensitive": False}) == True + + assert validators.is_subset_of({}, "durian", {"choices": choices}) == False + assert validators.is_subset_of({}, "apple,durian", {"choices": choices}) == False + + assert validators.is_subset_of({}, "none", {"choices": choices, "none_ok": True}) == True + assert validators.is_subset_of({}, "apple,none", {"choices": choices, "none_ok": True}) == False + + with pytest.raises(ValueError): + validators.is_subset_of({}, "apple", {})