Merge pull request #26 from aelaguiz/openai_json
Introduces OpenAI JSON support
aelaguiz authored May 14, 2024
2 parents ad3a5e1 + 83fdabb commit 81b885a
Showing 7 changed files with 260 additions and 27 deletions.
15 changes: 15 additions & 0 deletions langdspy/field_descriptors.py
@@ -64,13 +64,19 @@ def format_prompt_description(self, llm_type: str):
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"

def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

def format_prompt_value_json(self, value, llm_type: str):
value = self.format_value(value)
return {self.name: value}


class InputFieldList(InputField):
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
@@ -157,18 +163,27 @@ def format_prompt_description(self, llm_type: str):
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"

def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

def format_prompt(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}:"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}</{self.name}>"

def format_prompt_json(self):
return {self.name: ""}

def format_prompt_value_json(self, value, llm_type: str):
return {self.name: value}


class OutputFieldBool(OutputField):
def __init__(self, name: str, desc: str, **kwargs):
if not 'transformer' in kwargs:
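For illustration, a minimal sketch of what the new JSON helpers on the field descriptors return. The field names and descriptions below are made up for the example, not taken from the diff:

from langdspy.field_descriptors import InputField, OutputField

# Hypothetical fields; any name/desc would do.
question = InputField(name="question", desc="The user question")
answer = OutputField(name="answer", desc="A short answer")

# The JSON-mode helpers return plain dicts keyed by field name, so a prompt
# builder can merge them with dict.update() and serialize via json.dumps().
question.format_prompt_value_json("What is 2 + 2?", "openai_json")  # {'question': 'What is 2 + 2?'}
answer.format_prompt_value_json("4", "openai_json")                 # {'answer': '4'}
answer.format_prompt_json()                                         # {'answer': ''}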
8 changes: 7 additions & 1 deletion langdspy/prompt_runners.py
@@ -82,7 +82,13 @@ def set_model_kwargs(self, model_kwargs):
self.model_kwargs.update(model_kwargs)

def _determine_llm_type(self, llm):
print(f"Determining llm type")
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
print(f"Getting llm type")
logger.debug(llm.kwargs)
if llm.kwargs.get('response_format', {}).get('type') == 'json_object':
logger.info("OpenAI model response format is json_object")
return 'openai_json'
return 'openai'
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
return 'anthropic'
@@ -186,7 +192,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna

len_parsed_output = len(parsed_output.keys())
len_output_variables = len(self.template.output_variables.keys())
# logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")

if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
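For context, a hedged sketch of how a caller might opt into the new 'openai_json' path. The runner inspects llm.kwargs for response_format, which depends on how the model object is constructed; the construction below (JSON mode requested through the model's keyword arguments) is an assumption, not something shown in this diff:

from langchain_openai import ChatOpenAI

# Assumption: OpenAI JSON mode is requested via model_kwargs so that
# _determine_llm_type() finds response_format == {"type": "json_object"}
# and returns 'openai_json' instead of 'openai'.
llm = ChatOpenAI(
    model="gpt-3.5-turbo-0125",
    model_kwargs={"response_format": {"type": "json_object"}},
)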
129 changes: 120 additions & 9 deletions langdspy/prompt_strategies.py
@@ -1,4 +1,5 @@
from langchain.prompts import BasePromptTemplate # Assuming this is the correct import path
import json
import re
from langchain.prompts import FewShotPromptTemplate
from langchain_core.runnables import RunnableSerializable
@@ -77,8 +78,8 @@ class PromptStrategy(BaseModel):

def validate_inputs(self, inputs_dict):
if not set(inputs_dict.keys()) == set(self.input_variables.keys()):
logger.error(f"Input keys do not match expected input keys {inputs_dict.keys()} {self.input_variables.keys()}")
raise ValueError(f"Input keys do not match expected input keys {inputs_dict.keys()} {self.input_variables.keys()}")
logger.error(f"Input keys do not match expected input keys Expected = {inputs_dict.keys()} Received = {self.input_variables.keys()}")
raise ValueError(f"Input keys do not match expected input keys Expected: {inputs_dict.keys()} Received: {self.input_variables.keys()}")

def format(self, **kwargs: Any) -> str:
logger.debug(f"PromptStrategy format with kwargs: {kwargs}")
@@ -103,6 +104,8 @@ def format_prompt(self, **kwargs: Any) -> str:

if llm_type == 'openai':
prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'openai_json':
prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'anthropic':
prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs)

@@ -114,7 +117,9 @@ def format_prompt(self, **kwargs: Any) -> str:
raise e

def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
if llm_type == 'openai':
if llm_type == 'openai_json':
return self._parse_openai_json_output_to_fields(output)
elif llm_type == 'openai':
return self._parse_openai_output_to_fields(output)
elif llm_type == 'anthropic':
return self._parse_anthropic_output_to_fields(output)
@@ -123,11 +128,14 @@ def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")


@abstractmethod
def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass

@abstractmethod
def _format_openai_json_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass

@abstractmethod
def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass
@@ -145,10 +153,79 @@ def _parse_openai_output_to_fields(self, output: str) -> dict:
def _parse_anthropic_output_to_fields(self, output: str) -> dict:
pass

@abstractmethod
def _parse_openai_json_output_to_fields(self, output: str) -> dict:
pass


class DefaultPromptStrategy(PromptStrategy):
OUTPUT_TOKEN = "🔑"

def _format_openai_json_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
prompt = "Follow the following format. Answer with a JSON object. Attributes that have values should not be changed or repeated."

if len(self.output_variables) > 1:
output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()])
prompt += f" Provide answers for {output_field_names}.\n"

if self.hint_variables:
prompt += "\n"

for _, hint_field in self.hint_variables.items():
prompt += hint_field.format_prompt_description("openai") + "\n"

prompt += "\nInput Fields:\n"
input_fields_dict = {}
for input_name, input_field in self.input_variables.items():
input_fields_dict[input_field.name] = input_field.desc
prompt += json.dumps(input_fields_dict, indent=2) + "\n"

prompt += "\nOutput Fields:\n"
output_fields_dict = {}
for output_name, output_field in self.output_variables.items():
output_fields_dict[output_field.name] = output_field.desc
prompt += json.dumps(output_fields_dict, indent=2) + "\n"

if examples:
prompt += "\nExamples:\n"
for example_input, example_output in examples:
example_dict = {"input": {}, "output": {}}
for input_name, input_field in self.input_variables.items():
example_dict["input"].update(input_field.format_prompt_value_json(example_input.get(input_name), 'openai_json'))
for output_name, output_field in self.output_variables.items():
if isinstance(example_output, dict):
example_dict["output"].update(output_field.format_prompt_value_json(example_output.get(output_name), 'openai_json'))
else:
example_dict["output"].update(output_field.format_prompt_value_json(example_output, 'openai_json'))
prompt += json.dumps(example_dict, indent=2) + "\n"

if trained_state and trained_state.examples and use_training:
prompt += "\nTrained Examples:\n"
for example_X, example_y in trained_state.examples:
example_dict = {"input": {}, "output": {}}
for input_name, input_field in self.input_variables.items():
example_dict["input"].update(input_field.format_prompt_value_json(example_X.get(input_name), 'openai_json'))
for output_name, output_field in self.output_variables.items():
if isinstance(example_y, dict):
example_dict["output"].update(output_field.format_prompt_value_json(example_y.get(output_name), 'openai_json'))
else:
example_dict["output"].update(output_field.format_prompt_value_json(example_y, 'openai_json'))
prompt += json.dumps(example_dict, indent=2) + "\n"

prompt += "\nInput:\n"
input_dict = {}
for input_name, input_field in self.input_variables.items():
input_dict.update(input_field.format_prompt_value_json(kwargs.get(input_name), 'openai_json'))
prompt += json.dumps(input_dict, indent=2) + "\n"

prompt += "\nOutput:\n"
output_dict = {}
for output_name, output_field in self.output_variables.items():
output_dict.update(output_field.format_prompt_json())
prompt += json.dumps(output_dict, indent=2) + "\n"

return prompt

def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
# print(f"Formatting prompt {kwargs}")
prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "
@@ -291,20 +368,20 @@ def _parse_openai_output_to_fields(self, output: str) -> dict:
pattern = r'^([^:]+): (.*)'
lines = output.split(self.OUTPUT_TOKEN)
parsed_fields = {}
# logger.debug(f"Parsing output to fields with pattern {pattern} and lines {lines}")
logger.debug(f"Parsing output to fields with pattern {pattern} and lines {lines}")
for line in lines:
match = re.match(pattern, line, re.MULTILINE)
if match:
field_name, field_content = match.groups()
# logger.debug(f"Matched line {line} - field name {field_name} field content {field_content}")
logger.debug(f"Matched line {line} - field name {field_name} field content {field_content}")
output_field = self._get_output_field(field_name)
if output_field:
# logger.debug(f"Matched field {field_name} to output field {output_field}")
logger.debug(f"Matched field {field_name} to output field {output_field}")
parsed_fields[output_field] = field_content
else:
logger.error(f"Field {field_name} not found in output variables")
# else:
# logger.debug(f"NO MATCH line {line}")
else:
logger.debug(f"NO MATCH line {line}")

if len(self.output_variables) == 1:
first_value = next(iter(parsed_fields.values()), None)
@@ -343,3 +420,37 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
traceback.print_exc()

raise e

def _parse_openai_json_output_to_fields(self, output: str) -> dict:
print(f"Parsing openai json")
try:
# Parse the JSON output
json_output = json.loads(output)

# Initialize an empty dictionary to store the parsed fields
parsed_fields = {}

# Iterate over the output variables
for output_name, output_field in self.output_variables.items():
# Check if the output field exists in the JSON output
if output_field.name in json_output:
# Get the value of the output field from the JSON output
field_value = json_output[output_field.name]

# Apply any necessary transformations to the field value
transformed_value = output_field.transform_value(field_value)

# Store the transformed value in the parsed fields dictionary
parsed_fields[output_name] = transformed_value
else:
# If the output field is not present in the JSON output, set its value to None
parsed_fields[output_name] = None

logger.debug(f"Parsed fields: {parsed_fields}")
return parsed_fields
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON output: {e}")
raise e
except Exception as e:
logger.error(f"An error occurred while parsing JSON output: {e}")
raise e
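
A simplified sketch of the contract the new JSON parser implements, using plain dicts in place of real OutputField objects. Field names are illustrative, and the real method additionally runs each field's transform_value on the parsed value:

import json

# attribute name -> field name, mirroring self.output_variables
output_variables = {"answer": "answer", "confidence": "confidence"}

raw = '{"answer": "Paris"}'   # what a JSON-mode model would return
data = json.loads(raw)

# Fields present in the JSON are picked up by name; missing fields become None.
parsed = {attr: data.get(name) for attr, name in output_variables.items()}
# parsed == {'answer': 'Paris', 'confidence': None}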
11 changes: 4 additions & 7 deletions pyproject.toml
@@ -11,7 +11,6 @@ python = ">=3.9,<4.0"
aiohttp = "3.9.3"
aiosignal = "1.3.1"
annotated-types = "0.6.0"
anthropic = "0.19.1"
anyio = "4.2.0"
attrs = "23.2.0"
build = "1.1.1"
@@ -45,11 +44,6 @@ joblib = "1.3.2"
jsonpatch = "1.33"
jsonpointer = "2.4"
keyring = "24.3.1"
langchain = "0.1.7"
langchain-anthropic = "0.1.3"
langchain-community = "0.0.20"
langchain-core = "0.1.23"
langchain-openai = "0.0.6"
markdown-it-py = "3.0.0"
marshmallow = "3.20.2"
mdurl = "0.1.2"
@@ -58,7 +52,6 @@ msgpack = "1.0.8"
multidict = "6.0.5"
mypy-extensions = "1.0.0"
numpy = "1.26.4"
openai = "1.12.0"
orjson = "3.9.15"
packaging = "23.2"
pexpect = "4.9.0"
@@ -99,6 +92,10 @@ xattr = "1.1.0"
yarl = "1.9.4"
zipp = "3.17.0"
ratelimit = "^2.2.1"
langchain = "^0.1.20"
langchain-openai = "^0.1.6"
anthropic = "^0.25.8"
langchain-anthropic = "^0.1.11"

[tool.poetry.dev-dependencies]

59 changes: 58 additions & 1 deletion tests/test_manual_examples.py
@@ -164,4 +164,61 @@ class MultipleOutputPromptSignature(PromptSignature):
)
assert "<input>Example input</input>" in formatted_prompt
assert "<output1>Example output 1</output1>" in formatted_prompt
assert "<output2>Example output 2</output2>" in formatted_prompt
assert "<output2>Example output 2</output2>" in formatted_prompt

# ... (previous tests remain the same)

def test_format_prompt_with_examples_openai_json():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
formatted_prompt = prompt_runner.template._format_openai_json_prompt(
trained_state=None,
use_training=True,
examples=TestPromptSignature.__examples__,
input="Test input"
)
assert 'Hint field' in formatted_prompt
assert '"input": "Input field"' in formatted_prompt
assert '"output": "Output field"' in formatted_prompt
assert '"input": "Example input 1"' in formatted_prompt
assert '"output": "Example output 1"' in formatted_prompt
assert '"input": "Example input 2"' in formatted_prompt
assert '"output": "Example output 2"' in formatted_prompt
assert '"input": "Test input"' in formatted_prompt
assert '"output": ""' in formatted_prompt

def test_format_prompt_without_examples_openai_json():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
formatted_prompt = prompt_runner.template._format_openai_json_prompt(
trained_state=None,
use_training=False,
examples=[],
input="Test input"
)
assert 'Hint field' in formatted_prompt
assert '"input": "Input field"' in formatted_prompt
assert '"output": "Output field"' in formatted_prompt
assert '"input": "Test input"' in formatted_prompt
assert '"output": ""' in formatted_prompt
assert "Example input 1" not in formatted_prompt
assert "Example output 1" not in formatted_prompt
assert "Example input 2" not in formatted_prompt
assert "Example output 2" not in formatted_prompt

def test_format_prompt_with_multiple_output_fields_openai_json():
class MultipleOutputPromptSignature(PromptSignature):
input = InputField(name="input", desc="Input field")
output1 = OutputField(name="output1", desc="Output field 1")
output2 = OutputField(name="output2", desc="Output field 2")
__examples__ = [
({"input": "Example input"}, {"output1": "Example output 1", "output2": "Example output 2"}),
]
prompt_runner = PromptRunner(template_class=MultipleOutputPromptSignature, prompt_strategy=DefaultPromptStrategy)
formatted_prompt = prompt_runner.template._format_openai_json_prompt(
trained_state=None,
use_training=True,
examples=MultipleOutputPromptSignature.__examples__,
input="Test input"
)
assert '"input": "Example input"' in formatted_prompt
assert '"output1": "Example output 1"' in formatted_prompt
assert '"output2": "Example output 2"' in formatted_prompt