diff --git a/langdspy/__init__.py b/langdspy/__init__.py
index 5467281..e84c4b9 100644
--- a/langdspy/__init__.py
+++ b/langdspy/__init__.py
@@ -1,4 +1,4 @@
-from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum
+from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList
from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy
from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner
from .model import Model, TrainedModelState
diff --git a/langdspy/field_descriptors.py b/langdspy/field_descriptors.py
index 5715ec0..122101e 100644
--- a/langdspy/field_descriptors.py
+++ b/langdspy/field_descriptors.py
@@ -138,4 +138,22 @@ def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
- return f"{self._start_format_anthropic()}One of: {choices_str} - {self.desc}{self.name}>"
\ No newline at end of file
+ return f"{self._start_format_anthropic()}One of: {choices_str} - {self.desc}{self.name}>"
+
+class OutputFieldEnumList(OutputField):
+ def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
+ kwargs['enum'] = enum
+ if not 'transformer' in kwargs:
+ kwargs['transformer'] = transformers.as_enum_list
+ if not 'validator' in kwargs:
+ kwargs['validator'] = validators.is_subset_of
+ kwargs['choices'] = [e.name for e in enum]
+ super().__init__(name, desc, **kwargs)
+
+ def format_prompt_description(self, llm_type: str):
+ enum = self.kwargs.get('enum')
+ choices_str = ", ".join([e.name for e in enum])
+ if llm_type == "openai":
+ return f"{self._start_format_openai()}: A comma-separated list of one or more of: {choices_str} - {self.desc}"
+ elif llm_type == "anthropic":
+ return f"{self._start_format_anthropic()}A comma-separated list of one or more of: {choices_str} - {self.desc}{self.name}>"
\ No newline at end of file
diff --git a/langdspy/transformers.py b/langdspy/transformers.py
index c5b753a..7eef24c 100644
--- a/langdspy/transformers.py
+++ b/langdspy/transformers.py
@@ -21,4 +21,12 @@ def as_enum(val: str, kwargs: Dict[str, Any]) -> Enum:
try:
return enum_class[val.upper()]
except KeyError:
- raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration")
\ No newline at end of file
+ raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration")
+
+def as_enum_list(val: str, kwargs: Dict[str, Any]) -> List[Enum]:
+ enum_class = kwargs['enum']
+ values = [v.strip() for v in val.split(",")]
+ try:
+ return [enum_class[v.upper()] for v in values]
+ except KeyError as e:
+ raise ValueError(f"{e.args[0]} is not a valid member of the {enum_class.__name__} enumeration")
\ No newline at end of file
diff --git a/langdspy/validators.py b/langdspy/validators.py
index a2d0453..a13cb6b 100644
--- a/langdspy/validators.py
+++ b/langdspy/validators.py
@@ -24,7 +24,6 @@ def is_one_of(input, output_val, kwargs) -> bool:
if not kwargs.get('choices'):
raise ValueError("is_one_of validator requires 'choices' keyword argument")
-
none_ok = False
if kwargs.get('none_ok', False):
none_ok = True
@@ -49,4 +48,25 @@ def is_one_of(input, output_val, kwargs) -> bool:
logger.error(f"Field must be one of {kwargs.get('choices')}, not {output_val}")
import traceback
traceback.print_exc()
+ return False
+
+def is_subset_of(input, output_val, kwargs) -> bool:
+ if not kwargs.get('choices'):
+ raise ValueError("is_subset_of validator requires 'choices' keyword argument")
+
+ none_ok = kwargs.get('none_ok', False)
+ if none_ok and output_val.lower().strip() == "none":
+ return True
+
+ try:
+ values = [v.strip() for v in output_val.split(",")]
+ if not kwargs.get('case_sensitive', False):
+ choices = [c.lower() for c in kwargs['choices']]
+ values = [v.lower() for v in values]
+ for value in values:
+ if value not in choices:
+ return False
+ return True
+ except Exception as e:
+ logger.error(f"Field must be a comma-separated list of one or more of {kwargs.get('choices')}, not {output_val}")
return False
\ No newline at end of file
diff --git a/tests/test_field_descriptors.py b/tests/test_field_descriptors.py
index d630791..9734e0e 100644
--- a/tests/test_field_descriptors.py
+++ b/tests/test_field_descriptors.py
@@ -1,5 +1,6 @@
import pytest
-from langdspy.field_descriptors import InputField, InputFieldList
+from enum import Enum
+from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList
def test_input_field_initialization():
field = InputField("name", "description")
@@ -35,4 +36,26 @@ def test_input_field_list_format_prompt_value():
def test_input_field_list_format_prompt_value_empty():
field = InputFieldList("name", "description")
- assert field.format_prompt_value([], "openai") == "✅name: NO VALUES SPECIFIED"
\ No newline at end of file
+ assert field.format_prompt_value([], "openai") == "✅name: NO VALUES SPECIFIED"
+
+class TestEnum(Enum):
+ VALUE1 = "value1"
+ VALUE2 = "value2"
+ VALUE3 = "value3"
+
+def test_output_field_enum_list_initialization():
+ field = OutputFieldEnumList("name", "description", TestEnum)
+ assert field.name == "name"
+ assert field.desc == "description"
+ print(field.kwargs)
+ assert field.kwargs['enum'] == TestEnum
+ assert field.transformer.__name__ == "as_enum_list"
+ # assert field.kwargs['transformer'].__name__ == "as_enum_list"
+ assert field.validator.__name__ == "is_subset_of"
+ # assert field.kwargs['validator'].__name__ == "is_subset_of"
+ assert field.kwargs['choices'] == ["VALUE1", "VALUE2", "VALUE3"]
+
+def test_output_field_enum_list_format_prompt_description():
+ field = OutputFieldEnumList("name", "description", TestEnum)
+ assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("openai")
+ assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("anthropic")
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 9bd2398..a46cb01 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -27,4 +27,17 @@ class Fruit(Enum):
assert transformers.as_enum("BANANA", {"enum": Fruit}) == Fruit.BANANA
with pytest.raises(ValueError):
- transformers.as_enum("CHERRY", {"enum": Fruit})
\ No newline at end of file
+ transformers.as_enum("CHERRY", {"enum": Fruit})
+
+def test_as_enum_list():
+ class Fruit(Enum):
+ APPLE = 1
+ BANANA = 2
+ CHERRY = 3
+
+ assert transformers.as_enum_list("APPLE", {"enum": Fruit}) == [Fruit.APPLE]
+ assert transformers.as_enum_list("BANANA, CHERRY", {"enum": Fruit}) == [Fruit.BANANA, Fruit.CHERRY]
+ assert transformers.as_enum_list("APPLE,BANANA,CHERRY", {"enum": Fruit}) == [Fruit.APPLE, Fruit.BANANA, Fruit.CHERRY]
+
+ with pytest.raises(ValueError):
+ transformers.as_enum_list("DURIAN", {"enum": Fruit})
\ No newline at end of file
diff --git a/tests/test_validators.py b/tests/test_validators.py
index 18e1bb4..1ae4030 100644
--- a/tests/test_validators.py
+++ b/tests/test_validators.py
@@ -14,4 +14,22 @@ def test_is_one_of():
assert validators.is_one_of({}, 'none', {'choices': ['apple', 'banana'], 'none_ok': True}) == True
with pytest.raises(ValueError):
- validators.is_one_of({}, 'apple', {})
\ No newline at end of file
+ validators.is_one_of({}, 'apple', {})
+
+def test_is_subset_of():
+ choices = ["apple", "banana", "cherry"]
+
+ assert validators.is_subset_of({}, "apple", {"choices": choices}) == True
+ assert validators.is_subset_of({}, "apple,banana", {"choices": choices}) == True
+ assert validators.is_subset_of({}, "apple, banana, cherry", {"choices": choices}) == True
+ assert validators.is_subset_of({}, "APPLE", {"choices": choices, "case_sensitive": False}) == True
+ assert validators.is_subset_of({}, "APPLE,BANANA", {"choices": choices, "case_sensitive": False}) == True
+
+ assert validators.is_subset_of({}, "durian", {"choices": choices}) == False
+ assert validators.is_subset_of({}, "apple,durian", {"choices": choices}) == False
+
+ assert validators.is_subset_of({}, "none", {"choices": choices, "none_ok": True}) == True
+ assert validators.is_subset_of({}, "apple,none", {"choices": choices, "none_ok": True}) == False
+
+ with pytest.raises(ValueError):
+ validators.is_subset_of({}, "apple", {})