Skip to content

Commit

Permalink
Addeds OuptutFieldChooseOne
Browse files Browse the repository at this point in the history
  • Loading branch information
aelaguiz committed Mar 24, 2024
1 parent 4fd543c commit 3acce9b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
2 changes: 1 addition & 1 deletion langdspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool
from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool, OutputFieldChooseOne
from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy
from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner
from .model import Model, TrainedModelState
Expand Down
16 changes: 16 additions & 0 deletions langdspy/field_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ def format_prompt_description(self, llm_type: str):
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"

class OutputFieldChooseOne(OutputField):
def __init__(self, name: str, desc: str, choices: List[str], **kwargs):
kwargs['choices'] = choices

if not 'validator' in kwargs:
kwargs['validator'] = validators.is_one_of
kwargs['choices'] = choices
super().__init__(name, desc, **kwargs)

def format_prompt_description(self, llm_type: str):
choices_str = ", ".join(self.kwargs.get('choices', []))
if llm_type == "openai":
return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"

class OutputFieldEnum(OutputField):
def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
kwargs['enum'] = enum
Expand Down
19 changes: 17 additions & 2 deletions tests/test_field_descriptors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from enum import Enum
from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool
from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool, OutputFieldChooseOne

def test_input_field_initialization():
field = InputField("name", "description")
Expand Down Expand Up @@ -72,4 +72,19 @@ def test_output_field_bool_initialization():
def test_output_field_bool_format_prompt_description():
field = OutputFieldBool("name", "description")
assert "One of: Yes, No" in field.format_prompt_description("openai")
assert "One of: <choices>Yes, No</choices>" in field.format_prompt_description("anthropic")
assert "One of: <choices>Yes, No</choices>" in field.format_prompt_description("anthropic")

def test_output_field_choose_one_initialization():
choices = ["Option 1", "Option 2", "Option 3"]
field = OutputFieldChooseOne("name", "description", choices)
assert field.name == "name"
assert field.desc == "description"
assert field.validator.__name__ == "is_one_of"
assert field.kwargs['choices'] == choices

def test_output_field_choose_one_format_prompt_description():
choices = ["Option 1", "Option 2", "Option 3"]

field = OutputFieldChooseOne("name", "description", choices)
assert "One of: Option 1, Option 2, Option 3" in field.format_prompt_description("openai")
assert "One of: <choices>Option 1, Option 2, Option 3</choices>" in field.format_prompt_description("anthropic")

0 comments on commit 3acce9b

Please sign in to comment.