Skip to content

Commit

Permalink
Merge pull request #13 from aelaguiz/manual_example
Browse files Browse the repository at this point in the history
Manual example
  • Loading branch information
aelaguiz authored Mar 18, 2024
2 parents a5a1e2c + 3428661 commit 1171c94
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 16 deletions.
71 changes: 57 additions & 14 deletions langdspy/prompt_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class PromptSignature(BasePromptTemplate, BaseModel):
output_variables: Dict[str, Any] = []
hint_variables: Dict[str, Any] = [] # New attribute for hint fields
instance_id: str = Field(default_factory=str)
__examples__: List[Tuple[Dict[str, Any], Any]] = []


def __init__(self, **kwargs):
Expand All @@ -52,6 +53,25 @@ def __init__(self, **kwargs):
self.output_variables = outputs
self.hint_variables = hints

self.validate_examples()

def validate_examples(self):
    """Check that every manual example only names declared fields.

    Each entry of ``self.__examples__`` is unpacked as an
    ``(example_input, example_output)`` pair. Input keys must be present in
    ``self.input_variables``. A dict output must only use keys present in
    ``self.output_variables``; any non-dict output is accepted only when
    exactly one output variable is declared.

    Raises:
        ValueError: If an example names an unknown input or output field, or
            supplies a non-dict output while the number of declared output
            fields is not exactly one.
    """
    for inputs, output in self.__examples__:
        # Every key supplied for the example input must be a declared input.
        for field in inputs:
            if field in self.input_variables:
                continue
            raise ValueError(f"Example input field '{field}' not found in input_variables")

        if not isinstance(output, dict):
            # A bare value is only accepted when there is exactly one output field.
            if len(self.output_variables) != 1:
                raise ValueError("Example output must be a dictionary when there are multiple output fields")
            continue

        # Dict outputs are checked key by key against the declared outputs.
        for field in output:
            if field in self.output_variables:
                continue
            raise ValueError(f"Example output field '{field}' not found in output_variables")


class PromptStrategy(BaseModel):
best_subset: List[Any] = []

Expand All @@ -70,6 +90,8 @@ def format_prompt(self, **kwargs: Any) -> str:
trained_state = kwargs.pop('trained_state', None)
print_prompt = kwargs.pop('print_prompt', False)
use_training = kwargs.pop('use_training', True)
examples = kwargs.pop('__examples__', self.__examples__) # Add this line

# print(f"Formatting prompt with trained_state {trained_state} and print_prompt {print_prompt} and kwargs {kwargs}")
# print(f"Formatting prompt with use_training {use_training}")

Expand All @@ -80,9 +102,9 @@ def format_prompt(self, **kwargs: Any) -> str:
# logger.debug(f"PromptStrategy format_prompt with kwargs: {kwargs}")

if llm_type == 'openai':
prompt = self._format_openai_prompt(trained_state, use_training, **kwargs)
prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'anthropic':
prompt = self._format_anthropic_prompt(trained_state, use_training, **kwargs)
prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs)

if print_prompt:
print(prompt)
Expand All @@ -106,11 +128,11 @@ def parse_output_to_fields(self, output: str, llm_type: str) -> dict:


@abstractmethod
def _format_openai_prompt(self, **kwargs: Any) -> str:
def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass

@abstractmethod
def _format_anthropic_prompt(self, **kwargs: Any) -> str:
def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass

def _get_output_field(self, field_name):
Expand All @@ -130,7 +152,7 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
class DefaultPromptStrategy(PromptStrategy):
OUTPUT_TOKEN = "🔑"

def _format_openai_prompt(self, trained_state, use_training, **kwargs) -> str:
def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
# print(f"Formatting prompt {kwargs}")
prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "

Expand Down Expand Up @@ -159,11 +181,17 @@ def _format_openai_prompt(self, trained_state, use_training, **kwargs) -> str:
prompt += output_field.format_prompt_description("openai") + "\n"
# prompt += f"{self.OUTPUT_TOKEN}{output_field.name}: {output_field.desc}\n"

"""
if examples:
for example_input, example_output in examples:
prompt += "\n---\n\n"
for input_name, input_field in self.input_variables.items():
prompt += input_field.format_prompt_value(example_input.get(input_name), "openai") + "\n"
for output_name, output_field in self.output_variables.items():
if isinstance(example_output, dict):
prompt += output_field.format_prompt_value(example_output.get(output_name), "openai") + "\n"
else:
prompt += output_field.format_prompt_value(example_output, "openai") + "\n"

EXAMPLES GO HERE
"""
if trained_state and trained_state.examples and use_training:
for example_X, example_y in trained_state.examples:
prompt += "\n---\n\n"
Expand All @@ -188,7 +216,7 @@ def _format_openai_prompt(self, trained_state, use_training, **kwargs) -> str:

return prompt

def _format_anthropic_prompt(self, trained_state, use_training, **kwargs) -> str:
def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
# print(f"Formatting prompt {kwargs}")
prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "

Expand All @@ -212,10 +240,25 @@ def _format_anthropic_prompt(self, trained_state, use_training, **kwargs) -> str
prompt += output_field.format_prompt_description("anthropic") + "\n"
# prompt += f"{self.OUTPUT_TOKEN}{output_field.name}: {output_field.desc}\n"
prompt += "</output_fields>\n"
"""
EXAMPLES GO HERE
"""

if examples:
prompt += "\n<examples>\n"
for example_input, example_output in examples:
prompt += "\n<example>\n"
prompt += "<input>\n"
for input_name, input_field in self.input_variables.items():
prompt += input_field.format_prompt_value(example_input.get(input_name), "anthropic") + "\n"
prompt += "</input>\n"
prompt += "<output>\n"
for output_name, output_field in self.output_variables.items():
if isinstance(example_output, dict):
prompt += output_field.format_prompt_value(example_output.get(output_name), "anthropic") + "\n"
else:
prompt += output_field.format_prompt_value(example_output, "anthropic") + "\n"
prompt += "</output>\n"
prompt += "</example>\n"
prompt += "</examples>\n"

if trained_state and trained_state.examples and use_training:
prompt += "\n<examples>\n"
for example_X, example_y in trained_state.examples:
Expand Down
3 changes: 3 additions & 0 deletions langdspy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ def as_bool(value: str, kwargs: Dict[str, Any]) -> bool:
def as_json_list(val: str, kwargs: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Decode *val* as JSON and return the decoded value.

    ``kwargs`` is accepted but not used. The decoded value is returned as-is;
    it is not checked against the annotated list-of-dicts return type.
    """
    decoded = json.loads(val)
    return decoded

def as_json(val: str, kwargs: Dict[str, Any]) -> Any:
    """Decode *val* as JSON and return whatever value it encodes.

    ``kwargs`` is accepted but not used.
    """
    decoded = json.loads(val)
    return decoded

def as_enum(val: str, kwargs: Dict[str, Any]) -> Enum:
enum_class = kwargs['enum']
try:
Expand Down
167 changes: 167 additions & 0 deletions tests/test_manual_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import pytest
from langdspy.field_descriptors import InputField, OutputField, HintField
from langdspy.prompt_strategies import PromptSignature, DefaultPromptStrategy
from langdspy.prompt_runners import PromptRunner

class TestPromptSignature(PromptSignature):
    # Signature fixture shared by the prompt-formatting tests below: one input
    # field, one output field, one hint field and two manual examples.

    # The "Test" prefix matches pytest's default class-collection pattern, but
    # this is a fixture rather than a test case; opt it out of collection
    # explicitly so pytest does not try (and warn about failing) to collect it.
    __test__ = False

    input = InputField(name="input", desc="Input field")
    output = OutputField(name="output", desc="Output field")
    hint = HintField(desc="Hint field")
    # Each example is an (inputs, output) pair; the bare-string output form is
    # used here since only one output field is declared.
    __examples__ = [
        ({"input": "Example input 1"}, "Example output 1"),
        ({"input": "Example input 2"}, "Example output 2"),
    ]

def test_format_prompt_with_examples_openai():
    """The OpenAI-style prompt contains field descriptions, both examples and the live input."""
    runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
    call_args = {
        "trained_state": None,
        "use_training": True,
        "examples": TestPromptSignature.__examples__,
        "input": "Test input",
    }
    formatted_prompt = runner.template._format_openai_prompt(**call_args)
    print(formatted_prompt)

    expected_fragments = (
        "💡 Hint field",
        "✅input: Input field",
        "🔑output: Output field",
        "✅input: Example input 1",
        "🔑output: Example output 1",
        "✅input: Example input 2",
        "🔑output: Example output 2",
        "✅input: Test input",
        "🔑output:",
    )
    for fragment in expected_fragments:
        assert fragment in formatted_prompt

def test_format_prompt_with_examples_anthropic():
    """The Anthropic-style prompt contains field descriptions, both examples and the live input."""
    runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
    call_args = {
        "trained_state": None,
        "use_training": True,
        "examples": TestPromptSignature.__examples__,
        "input": "Test input",
    }
    formatted_prompt = runner.template._format_anthropic_prompt(**call_args)

    expected_fragments = (
        "<hint>Hint field</hint>",
        "<input>Input field</input>",
        "<output>Output field</output>",
        "<input>Example input 1</input>",
        "<output>Example output 1</output>",
        "<input>Example input 2</input>",
        "<output>Example output 2</output>",
        "<input>Test input</input>",
        "<output></output>",
    )
    for fragment in expected_fragments:
        assert fragment in formatted_prompt

def test_format_prompt_without_examples_openai():
    """With an empty examples list, the OpenAI-style prompt omits all example text."""
    runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
    call_args = {
        "trained_state": None,
        "use_training": False,
        "examples": [],
        "input": "Test input",
    }
    formatted_prompt = runner.template._format_openai_prompt(**call_args)

    must_contain = (
        "💡 Hint field",
        "✅input: Input field",
        "🔑output: Output field",
        "✅input: Test input",
        "🔑output:",
    )
    for fragment in must_contain:
        assert fragment in formatted_prompt

    must_not_contain = (
        "Example input 1",
        "Example output 1",
        "Example input 2",
        "Example output 2",
    )
    for fragment in must_not_contain:
        assert fragment not in formatted_prompt

def test_format_prompt_without_examples_anthropic():
    """With an empty examples list, the Anthropic-style prompt omits all example text."""
    runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
    call_args = {
        "trained_state": None,
        "use_training": False,
        "examples": [],
        "input": "Test input",
    }
    formatted_prompt = runner.template._format_anthropic_prompt(**call_args)

    must_contain = (
        "<hint>Hint field</hint>",
        "<input>Input field</input>",
        "<output>Output field</output>",
        "<input>Test input</input>",
        "<output></output>",
    )
    for fragment in must_contain:
        assert fragment in formatted_prompt

    must_not_contain = (
        "Example input 1",
        "Example output 1",
        "Example input 2",
        "Example output 2",
    )
    for fragment in must_not_contain:
        assert fragment not in formatted_prompt

def test_validate_examples_valid():
    """An example whose keys all match declared fields passes validation."""

    class ValidPromptSignature(PromptSignature):
        input1 = InputField(name="input1", desc="Input field 1")
        input2 = InputField(name="input2", desc="Input field 2")
        output1 = OutputField(name="output1", desc="Output field 1")
        output2 = OutputField(name="output2", desc="Output field 2")
        __examples__ = [
            (
                {"input1": "Example input 1", "input2": "Example input 2"},
                {"output1": "Example output 1", "output2": "Example output 2"},
            ),
        ]

    runner = PromptRunner(template_class=ValidPromptSignature, prompt_strategy=DefaultPromptStrategy)
    # Passing means no exception is raised by the explicit re-validation.
    runner.template.validate_examples()

def test_validate_examples_invalid_input_field():
    """Building a runner fails when an example names an undeclared input field."""

    class InvalidInputPromptSignature(PromptSignature):
        input = InputField(name="input", desc="Input field")
        output = OutputField(name="output", desc="Output field")
        __examples__ = [
            ({"invalid_input": "Example input"}, "Example output"),
        ]

    expected_message = "Example input field 'invalid_input' not found in input_variables"
    with pytest.raises(ValueError, match=expected_message):
        PromptRunner(template_class=InvalidInputPromptSignature, prompt_strategy=DefaultPromptStrategy)

def test_validate_examples_invalid_output_field():
    """Building a runner fails when an example names an undeclared output field."""

    class InvalidOutputPromptSignature(PromptSignature):
        input = InputField(name="input", desc="Input field")
        output = OutputField(name="output", desc="Output field")
        __examples__ = [
            ({"input": "Example input"}, {"invalid_output": "Example output"}),
        ]

    expected_message = "Example output field 'invalid_output' not found in output_variables"
    with pytest.raises(ValueError, match=expected_message):
        PromptRunner(template_class=InvalidOutputPromptSignature, prompt_strategy=DefaultPromptStrategy)

def test_validate_examples_invalid_output_format():
    """Building a runner fails when a bare-value output is given for two output fields."""

    class InvalidOutputFormatPromptSignature(PromptSignature):
        input = InputField(name="input", desc="Input field")
        output1 = OutputField(name="output1", desc="Output field 1")
        output2 = OutputField(name="output2", desc="Output field 2")
        __examples__ = [
            ({"input": "Example input"}, "Example output"),
        ]

    expected_message = "Example output must be a dictionary when there are multiple output fields"
    with pytest.raises(ValueError, match=expected_message):
        PromptRunner(template_class=InvalidOutputFormatPromptSignature, prompt_strategy=DefaultPromptStrategy)

def test_format_prompt_with_multiple_output_fields_openai():
    """A dict-valued example output is rendered per output field in the OpenAI-style prompt."""

    class MultipleOutputPromptSignature(PromptSignature):
        input = InputField(name="input", desc="Input field")
        output1 = OutputField(name="output1", desc="Output field 1")
        output2 = OutputField(name="output2", desc="Output field 2")
        __examples__ = [
            ({"input": "Example input"}, {"output1": "Example output 1", "output2": "Example output 2"}),
        ]

    runner = PromptRunner(template_class=MultipleOutputPromptSignature, prompt_strategy=DefaultPromptStrategy)
    call_args = {
        "trained_state": None,
        "use_training": True,
        "examples": MultipleOutputPromptSignature.__examples__,
        "input": "Test input",
    }
    formatted_prompt = runner.template._format_openai_prompt(**call_args)

    expected_fragments = (
        "✅input: Example input",
        "🔑output1: Example output 1",
        "🔑output2: Example output 2",
    )
    for fragment in expected_fragments:
        assert fragment in formatted_prompt

def test_format_prompt_with_multiple_output_fields_anthropic():
    """A dict-valued example output is rendered per output field in the Anthropic-style prompt."""

    class MultipleOutputPromptSignature(PromptSignature):
        input = InputField(name="input", desc="Input field")
        output1 = OutputField(name="output1", desc="Output field 1")
        output2 = OutputField(name="output2", desc="Output field 2")
        __examples__ = [
            ({"input": "Example input"}, {"output1": "Example output 1", "output2": "Example output 2"}),
        ]

    runner = PromptRunner(template_class=MultipleOutputPromptSignature, prompt_strategy=DefaultPromptStrategy)
    call_args = {
        "trained_state": None,
        "use_training": True,
        "examples": MultipleOutputPromptSignature.__examples__,
        "input": "Test input",
    }
    formatted_prompt = runner.template._format_anthropic_prompt(**call_args)

    expected_fragments = (
        "<input>Example input</input>",
        "<output1>Example output 1</output1>",
        "<output2>Example output 2</output2>",
    )
    for fragment in expected_fragments:
        assert fragment in formatted_prompt
4 changes: 2 additions & 2 deletions tests/test_prompt_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TestPromptSignature(PromptSignature):
def test_format_prompt_openai():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)

formatted_prompt = prompt_runner.template._format_openai_prompt(trained_state=None, use_training=True, input="test input")
formatted_prompt = prompt_runner.template._format_openai_prompt(trained_state=None, use_training=True, input="test input", examples=None)
print(formatted_prompt)

assert "💡 Hint field" in formatted_prompt
Expand All @@ -26,7 +26,7 @@ def test_format_prompt_openai():
def test_format_prompt_anthropic():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)

formatted_prompt = prompt_runner.template._format_anthropic_prompt(trained_state=None, use_training=True, input="test input")
formatted_prompt = prompt_runner.template._format_anthropic_prompt(trained_state=None, use_training=True, input="test input", examples=None)

assert "<hint>Hint field</hint>" in formatted_prompt
assert "<input>Input field</input>" in formatted_prompt
Expand Down

0 comments on commit 1171c94

Please sign in to comment.