diff --git a/langdspy/field_descriptors.py b/langdspy/field_descriptors.py
index 2baec30..e1fdf68 100644
--- a/langdspy/field_descriptors.py
+++ b/langdspy/field_descriptors.py
@@ -64,6 +64,7 @@ def format_prompt_description(self, llm_type: str):
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}{self.name}>"
+
def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
if llm_type == "openai":
@@ -71,6 +72,11 @@ def format_prompt_value(self, value, llm_type: str):
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}{self.name}>"
+ def format_prompt_value_json(self, value, llm_type: str):
+ value = self.format_value(value)
+ return {self.name: value}
+
+
class InputFieldList(InputField):
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
@@ -157,18 +163,27 @@ def format_prompt_description(self, llm_type: str):
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}{self.name}>"
+
def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}{self.name}>"
+
def format_prompt(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}:"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.name}>"
+ def format_prompt_json(self):
+ return {self.name: ""}
+
+ def format_prompt_value_json(self, value, llm_type: str):
+ return {self.name: value}
+
+
class OutputFieldBool(OutputField):
def __init__(self, name: str, desc: str, **kwargs):
if not 'transformer' in kwargs:
diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py
index dacff94..3584058 100644
--- a/langdspy/prompt_runners.py
+++ b/langdspy/prompt_runners.py
@@ -82,7 +82,13 @@ def set_model_kwargs(self, model_kwargs):
self.model_kwargs.update(model_kwargs)
def _determine_llm_type(self, llm):
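+ # JSON mode: a ChatOpenAI configured with response_format={"type": "json_object"}
+ # (e.g. ChatOpenAI(model_kwargs={"response_format": {"type": "json_object"}}))
+ # is routed to the 'openai_json' prompt/parse path below.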
+ print(f"Determining llm type")
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
+ print(f"Getting llm type")
+ logger.debug(llm.kwargs)
+ if llm.kwargs.get('response_format', {}).get('type') == 'json_object':
+ logger.info("OpenAI model response format is json_object")
+ return 'openai_json'
return 'openai'
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
return 'anthropic'
@@ -186,7 +192,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
len_parsed_output = len(parsed_output.keys())
len_output_variables = len(self.template.output_variables.keys())
- # logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
+ logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py
index 1302a66..14f6879 100644
--- a/langdspy/prompt_strategies.py
+++ b/langdspy/prompt_strategies.py
@@ -1,4 +1,5 @@
from langchain.prompts import BasePromptTemplate # Assuming this is the correct import path
+import json
import re
from langchain.prompts import FewShotPromptTemplate
from langchain_core.runnables import RunnableSerializable
@@ -77,8 +78,8 @@ class PromptStrategy(BaseModel):
def validate_inputs(self, inputs_dict):
if not set(inputs_dict.keys()) == set(self.input_variables.keys()):
- logger.error(f"Input keys do not match expected input keys {inputs_dict.keys()} {self.input_variables.keys()}")
- raise ValueError(f"Input keys do not match expected input keys {inputs_dict.keys()} {self.input_variables.keys()}")
+ logger.error(f"Input keys do not match expected input keys Expected = {inputs_dict.keys()} Received = {self.input_variables.keys()}")
+ raise ValueError(f"Input keys do not match expected input keys Expected: {inputs_dict.keys()} Received: {self.input_variables.keys()}")
def format(self, **kwargs: Any) -> str:
logger.debug(f"PromptStrategy format with kwargs: {kwargs}")
@@ -103,6 +104,8 @@ def format_prompt(self, **kwargs: Any) -> str:
if llm_type == 'openai':
prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs)
+ elif llm_type == 'openai_json':
+ prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'anthropic':
prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs)
@@ -114,7 +117,9 @@ def format_prompt(self, **kwargs: Any) -> str:
raise e
def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
- if llm_type == 'openai':
+ if llm_type == 'openai_json':
+ return self._parse_openai_json_output_to_fields(output)
+ elif llm_type == 'openai':
return self._parse_openai_output_to_fields(output)
elif llm_type == 'anthropic':
return self._parse_anthropic_output_to_fields(output)
@@ -123,11 +128,14 @@ def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")
-
@abstractmethod
def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass
+ @abstractmethod
+ def _format_openai_json_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
+ pass
+
@abstractmethod
def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass
@@ -145,10 +153,79 @@ def _parse_openai_output_to_fields(self, output: str) -> dict:
def _parse_anthropic_output_to_fields(self, output: str) -> dict:
pass
+ @abstractmethod
+ def _parse_openai_json_output_to_fields(self, output: str) -> dict:
+ pass
+
class DefaultPromptStrategy(PromptStrategy):
OUTPUT_TOKEN = "🔑"
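+
+ # JSON-mode prompt: field descriptions, examples, and the current input/output skeleton are rendered
+ # with json.dumps so a model running with response_format={"type": "json_object"} can answer with a single JSON object.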
+ def _format_openai_json_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
+ prompt = "Follow the following format. Answer with a JSON object. Attributes that have values should not be changed or repeated."
+
+ if len(self.output_variables) > 1:
+ output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()])
+ prompt += f" Provide answers for {output_field_names}.\n"
+
+ if self.hint_variables:
+ prompt += "\n"
+
+ for _, hint_field in self.hint_variables.items():
+ prompt += hint_field.format_prompt_description("openai") + "\n"
+
+ prompt += "\nInput Fields:\n"
+ input_fields_dict = {}
+ for input_name, input_field in self.input_variables.items():
+ input_fields_dict[input_field.name] = input_field.desc
+ prompt += json.dumps(input_fields_dict, indent=2) + "\n"
+
+ prompt += "\nOutput Fields:\n"
+ output_fields_dict = {}
+ for output_name, output_field in self.output_variables.items():
+ output_fields_dict[output_field.name] = output_field.desc
+ prompt += json.dumps(output_fields_dict, indent=2) + "\n"
+
+ if examples:
+ prompt += "\nExamples:\n"
+ for example_input, example_output in examples:
+ example_dict = {"input": {}, "output": {}}
+ for input_name, input_field in self.input_variables.items():
+ example_dict["input"].update(input_field.format_prompt_value_json(example_input.get(input_name), 'openai_json'))
+ for output_name, output_field in self.output_variables.items():
+ if isinstance(example_output, dict):
+ example_dict["output"].update(output_field.format_prompt_value_json(example_output.get(output_name), 'openai_json'))
+ else:
+ example_dict["output"].update(output_field.format_prompt_value_json(example_output, 'openai_json'))
+ prompt += json.dumps(example_dict, indent=2) + "\n"
+
+ if trained_state and trained_state.examples and use_training:
+ prompt += "\nTrained Examples:\n"
+ for example_X, example_y in trained_state.examples:
+ example_dict = {"input": {}, "output": {}}
+ for input_name, input_field in self.input_variables.items():
+ example_dict["input"].update(input_field.format_prompt_value_json(example_X.get(input_name), 'openai_json'))
+ for output_name, output_field in self.output_variables.items():
+ if isinstance(example_y, dict):
+ example_dict["output"].update(output_field.format_prompt_value_json(example_y.get(output_name), 'openai_json'))
+ else:
+ example_dict["output"].update(output_field.format_prompt_value_json(example_y, 'openai_json'))
+ prompt += json.dumps(example_dict, indent=2) + "\n"
+
+ prompt += "\nInput:\n"
+ input_dict = {}
+ for input_name, input_field in self.input_variables.items():
+ input_dict.update(input_field.format_prompt_value_json(kwargs.get(input_name), 'openai_json'))
+ prompt += json.dumps(input_dict, indent=2) + "\n"
+
+ prompt += "\nOutput:\n"
+ output_dict = {}
+ for output_name, output_field in self.output_variables.items():
+ output_dict.update(output_field.format_prompt_json())
+ prompt += json.dumps(output_dict, indent=2) + "\n"
+
+ return prompt
+
def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
# print(f"Formatting prompt {kwargs}")
prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "
@@ -291,20 +368,20 @@ def _parse_openai_output_to_fields(self, output: str) -> dict:
pattern = r'^([^:]+): (.*)'
lines = output.split(self.OUTPUT_TOKEN)
parsed_fields = {}
- # logger.debug(f"Parsing output to fields with pattern {pattern} and lines {lines}")
+ logger.debug(f"Parsing output to fields with pattern {pattern} and lines {lines}")
for line in lines:
match = re.match(pattern, line, re.MULTILINE)
if match:
field_name, field_content = match.groups()
- # logger.debug(f"Matched line {line} - field name {field_name} field content {field_content}")
+ logger.debug(f"Matched line {line} - field name {field_name} field content {field_content}")
output_field = self._get_output_field(field_name)
if output_field:
- # logger.debug(f"Matched field {field_name} to output field {output_field}")
+ logger.debug(f"Matched field {field_name} to output field {output_field}")
parsed_fields[output_field] = field_content
else:
logger.error(f"Field {field_name} not found in output variables")
- # else:
- # logger.debug(f"NO MATCH line {line}")
+ else:
+ logger.debug(f"NO MATCH line {line}")
if len(self.output_variables) == 1:
first_value = next(iter(parsed_fields.values()), None)
@@ -343,3 +420,37 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
traceback.print_exc()
raise e
+
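+ # JSON-mode counterpart of _parse_openai_output_to_fields: json.loads the response and map
+ # top-level keys back to the template's output variables (missing keys become None).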
+ def _parse_openai_json_output_to_fields(self, output: str) -> dict:
+ print(f"Parsing openai json")
+ try:
+ # Parse the JSON output
+ json_output = json.loads(output)
+
+ # Initialize an empty dictionary to store the parsed fields
+ parsed_fields = {}
+
+ # Iterate over the output variables
+ for output_name, output_field in self.output_variables.items():
+ # Check if the output field exists in the JSON output
+ if output_field.name in json_output:
+ # Get the value of the output field from the JSON output
+ field_value = json_output[output_field.name]
+
+ # Apply any necessary transformations to the field value
+ transformed_value = output_field.transform_value(field_value)
+
+ # Store the transformed value in the parsed fields dictionary
+ parsed_fields[output_name] = transformed_value
+ else:
+ # If the output field is not present in the JSON output, set its value to None
+ parsed_fields[output_name] = None
+
+ logger.debug(f"Parsed fields: {parsed_fields}")
+ return parsed_fields
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse JSON output: {e}")
+ raise e
+ except Exception as e:
+ logger.error(f"An error occurred while parsing JSON output: {e}")
+ raise e
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 27f0808..ea72f16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,6 @@ python = ">=3.9,<4.0"
aiohttp = "3.9.3"
aiosignal = "1.3.1"
annotated-types = "0.6.0"
-anthropic = "0.19.1"
anyio = "4.2.0"
attrs = "23.2.0"
build = "1.1.1"
@@ -45,11 +44,6 @@ joblib = "1.3.2"
jsonpatch = "1.33"
jsonpointer = "2.4"
keyring = "24.3.1"
-langchain = "0.1.7"
-langchain-anthropic = "0.1.3"
-langchain-community = "0.0.20"
-langchain-core = "0.1.23"
-langchain-openai = "0.0.6"
markdown-it-py = "3.0.0"
marshmallow = "3.20.2"
mdurl = "0.1.2"
@@ -58,7 +52,6 @@ msgpack = "1.0.8"
multidict = "6.0.5"
mypy-extensions = "1.0.0"
numpy = "1.26.4"
-openai = "1.12.0"
orjson = "3.9.15"
packaging = "23.2"
pexpect = "4.9.0"
@@ -99,6 +92,10 @@ xattr = "1.1.0"
yarl = "1.9.4"
zipp = "3.17.0"
ratelimit = "^2.2.1"
+langchain = "^0.1.20"
+langchain-openai = "^0.1.6"
+anthropic = "^0.25.8"
+langchain-anthropic = "^0.1.11"
[tool.poetry.dev-dependencies]
diff --git a/tests/test_manual_examples.py b/tests/test_manual_examples.py
index ac643cd..fea721d 100644
--- a/tests/test_manual_examples.py
+++ b/tests/test_manual_examples.py
@@ -164,4 +164,61 @@ class MultipleOutputPromptSignature(PromptSignature):
)
assert "Example input" in formatted_prompt
assert "Example output 1" in formatted_prompt
- assert "Example output 2" in formatted_prompt
\ No newline at end of file
+ assert "Example output 2" in formatted_prompt
+
+
+def test_format_prompt_with_examples_openai_json():
+ prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
+ formatted_prompt = prompt_runner.template._format_openai_json_prompt(
+ trained_state=None,
+ use_training=True,
+ examples=TestPromptSignature.__examples__,
+ input="Test input"
+ )
+ assert 'Hint field' in formatted_prompt
+ assert '"input": "Input field"' in formatted_prompt
+ assert '"output": "Output field"' in formatted_prompt
+ assert '"input": "Example input 1"' in formatted_prompt
+ assert '"output": "Example output 1"' in formatted_prompt
+ assert '"input": "Example input 2"' in formatted_prompt
+ assert '"output": "Example output 2"' in formatted_prompt
+ assert '"input": "Test input"' in formatted_prompt
+ assert '"output": ""' in formatted_prompt
+
+def test_format_prompt_without_examples_openai_json():
+ prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
+ formatted_prompt = prompt_runner.template._format_openai_json_prompt(
+ trained_state=None,
+ use_training=False,
+ examples=[],
+ input="Test input"
+ )
+ assert 'Hint field' in formatted_prompt
+ assert '"input": "Input field"' in formatted_prompt
+ assert '"output": "Output field"' in formatted_prompt
+ assert '"input": "Test input"' in formatted_prompt
+ assert '"output": ""' in formatted_prompt
+ assert "Example input 1" not in formatted_prompt
+ assert "Example output 1" not in formatted_prompt
+ assert "Example input 2" not in formatted_prompt
+ assert "Example output 2" not in formatted_prompt
+
+def test_format_prompt_with_multiple_output_fields_openai_json():
+ class MultipleOutputPromptSignature(PromptSignature):
+ input = InputField(name="input", desc="Input field")
+ output1 = OutputField(name="output1", desc="Output field 1")
+ output2 = OutputField(name="output2", desc="Output field 2")
+ __examples__ = [
+ ({"input": "Example input"}, {"output1": "Example output 1", "output2": "Example output 2"}),
+ ]
+ prompt_runner = PromptRunner(template_class=MultipleOutputPromptSignature, prompt_strategy=DefaultPromptStrategy)
+ formatted_prompt = prompt_runner.template._format_openai_json_prompt(
+ trained_state=None,
+ use_training=True,
+ examples=MultipleOutputPromptSignature.__examples__,
+ input="Test input"
+ )
+ assert '"input": "Example input"' in formatted_prompt
+ assert '"output1": "Example output 1"' in formatted_prompt
+ assert '"output2": "Example output 2"' in formatted_prompt
\ No newline at end of file
diff --git a/tests/test_output_parsing.py b/tests/test_output_parsing.py
index 76d3d40..be5cf02 100644
--- a/tests/test_output_parsing.py
+++ b/tests/test_output_parsing.py
@@ -1,9 +1,19 @@
import pytest
-from langdspy.field_descriptors import InputField, OutputField, OutputFieldBool
+import sys
+
+import logging
+
+logger = logging.getLogger("langdspy")
+logger.setLevel(logging.DEBUG)
+handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(handler)
+
+from langdspy.field_descriptors import InputField, OutputField, OutputFieldBool, InputFieldDictList, HintField, InputFieldList
from langdspy.prompt_strategies import PromptSignature, DefaultPromptStrategy
from langdspy.prompt_runners import PromptRunner
from langdspy.formatters import as_multiline
+
class TestOutputParsingPromptSignature(PromptSignature):
ticket_summary = InputField(name="Ticket Summary", desc="Summary of the ticket we're trying to analyze.")
buyer_issues_summary = OutputField(name="Buyer Issues Summary", desc="Summary of the issues this buyer is facing.")
@@ -136,4 +146,24 @@ class IsTicketSpam(PromptSignature):
prompt_runner = PromptRunner(template_class=IsTicketSpam, prompt_strategy=DefaultPromptStrategy)
result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])
- assert result["is_spam"] == "No"
\ No newline at end of file
+ assert result["is_spam"] == "No"
+
+def test_output_parsing_openai_json():
+ prompt_runner = PromptRunner(template_class=TestOutputParsingPromptSignature, prompt_strategy=DefaultPromptStrategy)
+
+ input_data = {
+ "ticket_summary": "..."
+ }
+
+ output_data = """
+ {
+ "Buyer Issues Summary": "The buyer is trying to personalize their order by selecting variants like color or size, but after making their selections and hitting \\"done\\", the changes are not being reflected. They are also asking how long delivery will take.",
+ "Buyer Issue Enum": "BOX_CONTENTS_CUSTOMIZATION"
+ }
+ """
+
+ config = {"llm_type": "openai_json"}
+ result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])
+
+ assert result["buyer_issues_summary"] == "The buyer is trying to personalize their order by selecting variants like color or size, but after making their selections and hitting \"done\", the changes are not being reflected. They are also asking how long delivery will take."
+ assert result["buyer_issue_category"] == "BOX_CONTENTS_CUSTOMIZATION"
\ No newline at end of file
diff --git a/tests/test_prompt_formatting.py b/tests/test_prompt_formatting.py
index ef0af7a..86e36c3 100644
--- a/tests/test_prompt_formatting.py
+++ b/tests/test_prompt_formatting.py
@@ -50,18 +50,35 @@ def test_parse_output_anthropic():
assert parsed_output["output"] == "test output"
-def test_llm_type_detection_openai():
+def test_llm_type_detection_anthropic():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
- llm = ChatOpenAI()
+ llm = ChatAnthropic(model_name="claude-3-sonnet-20240229")
llm_type = prompt_runner._determine_llm_type(llm)
- assert llm_type == "openai"
+ assert llm_type == "anthropic"
-def test_llm_type_detection_anthropic():
+def test_format_prompt_openai_json():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
- llm = ChatAnthropic(model_name="claude-3-sonnet-20240229")
- llm_type = prompt_runner._determine_llm_type(llm)
+ formatted_prompt = prompt_runner.template._format_openai_json_prompt(trained_state=None, use_training=True, input="test input", examples=None)
+
+ print(formatted_prompt)
+
+ assert 'Hint field' in formatted_prompt
+ assert "Input Fields:" in formatted_prompt
+ assert '"input": "Input field"' in formatted_prompt
+ assert "Output Fields:" in formatted_prompt
+ assert '"output": "Output field"' in formatted_prompt
+ assert "Input:" in formatted_prompt
+ assert '"input": "test input"' in formatted_prompt
+ assert "Output:" in formatted_prompt
+ assert '"output": ""' in formatted_prompt
+
+def test_parse_output_openai_json():
+ prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
+
+ output = '{"output": "test output"}'
+ parsed_output = prompt_runner.template._parse_openai_json_output_to_fields(output)
- assert llm_type == "anthropic"
\ No newline at end of file
+ assert parsed_output["output"] == "test output"
\ No newline at end of file