Introduces openai json support #26

Merged: 1 commit, merged May 14, 2024
15 changes: 15 additions & 0 deletions langdspy/field_descriptors.py
@@ -64,13 +64,19 @@ def format_prompt_description(self, llm_type: str):
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"

def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

def format_prompt_value_json(self, value, llm_type: str):
value = self.format_value(value)
return {self.name: value}


class InputFieldList(InputField):
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
@@ -157,18 +163,27 @@ def format_prompt_description(self, llm_type: str):
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"

def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

def format_prompt(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}:"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}</{self.name}>"

def format_prompt_json(self):
return {self.name: ""}

def format_prompt_value_json(self, value, llm_type: str):
return {self.name: value}


class OutputFieldBool(OutputField):
def __init__(self, name: str, desc: str, **kwargs):
if not 'transformer' in kwargs:
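The new JSON-mode helpers each return a one-key dict keyed by the field's name (`format_prompt_value_json` gives `{name: value}`, `format_prompt_json` gives `{name: ""}`); the prompt strategy then merges these with `dict.update` and serializes the result via `json.dumps`. A minimal, self-contained sketch of that flow; the `Field` class below is a simplified stand-in for illustration, not the real langdspy descriptor:

```python
import json

class Field:
    """Simplified stand-in for a langdspy field descriptor (illustrative only)."""
    def __init__(self, name: str):
        self.name = name

    def format_prompt_value_json(self, value, llm_type: str) -> dict:
        # Mirrors the diff: each field contributes a single-key dict.
        return {self.name: value}

fields = [Field("question"), Field("context")]
values = {"question": "What is JSON mode?", "context": "OpenAI structured output"}

input_dict = {}
for field in fields:
    input_dict.update(field.format_prompt_value_json(values[field.name], "openai_json"))

print(json.dumps(input_dict, indent=2))
# {
#   "question": "What is JSON mode?",
#   "context": "OpenAI structured output"
# }
```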
8 changes: 7 additions & 1 deletion langdspy/prompt_runners.py
@@ -82,7 +82,13 @@ def set_model_kwargs(self, model_kwargs):
self.model_kwargs.update(model_kwargs)

def _determine_llm_type(self, llm):
print(f"Determining llm type")
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
print(f"Getting llm type")
logger.debug(llm.kwargs)
if llm.kwargs.get('response_format', {}).get('type') == 'json_object':
logger.info("OpenAI model response format is json_object")
return 'openai_json'
return 'openai'
elif isinstance(llm, ChatAnthropic): # ChatAnthropic is the LangChain class for Anthropic models
return 'anthropic'
@@ -186,7 +192,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna

len_parsed_output = len(parsed_output.keys())
len_output_variables = len(self.template.output_variables.keys())
# logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")

if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
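With this change, `_determine_llm_type` returns `'openai_json'` whenever the OpenAI model's keyword arguments carry `response_format={"type": "json_object"}`, and `'openai'` otherwise. A standalone sketch of just that branch, assuming the model exposes its extra keyword arguments as a dict the way the diff reads `llm.kwargs` (the real method also checks `isinstance(llm, ChatOpenAI)` / `ChatAnthropic` first):

```python
def determine_llm_type(model_kwargs: dict) -> str:
    """Hypothetical helper mirroring the branch added in the diff."""
    if model_kwargs.get("response_format", {}).get("type") == "json_object":
        return "openai_json"
    return "openai"

assert determine_llm_type({"response_format": {"type": "json_object"}}) == "openai_json"
assert determine_llm_type({}) == "openai"
```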
129 changes: 120 additions & 9 deletions langdspy/prompt_strategies.py
@@ -1,4 +1,5 @@
from langchain.prompts import BasePromptTemplate # Assuming this is the correct import path
import json
import re
from langchain.prompts import FewShotPromptTemplate
from langchain_core.runnables import RunnableSerializable
@@ -77,8 +78,8 @@ class PromptStrategy(BaseModel):

def validate_inputs(self, inputs_dict):
if not set(inputs_dict.keys()) == set(self.input_variables.keys()):
logger.error(f"Input keys do not match expected input keys {inputs_dict.keys()} {self.input_variables.keys()}")
raise ValueError(f"Input keys do not match expected input keys {inputs_dict.keys()} {self.input_variables.keys()}")
logger.error(f"Input keys do not match expected input keys Expected = {inputs_dict.keys()} Received = {self.input_variables.keys()}")
raise ValueError(f"Input keys do not match expected input keys Expected: {inputs_dict.keys()} Received: {self.input_variables.keys()}")

def format(self, **kwargs: Any) -> str:
logger.debug(f"PromptStrategy format with kwargs: {kwargs}")
@@ -103,6 +104,8 @@ def format_prompt(self, **kwargs: Any) -> str:

if llm_type == 'openai':
prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'openai_json':
prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **kwargs)
elif llm_type == 'anthropic':
prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs)

@@ -114,7 +117,9 @@ def format_prompt(self, **kwargs: Any) -> str:
raise e

def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
if llm_type == 'openai':
if llm_type == 'openai_json':
return self._parse_openai_json_output_to_fields(output)
elif llm_type == 'openai':
return self._parse_openai_output_to_fields(output)
elif llm_type == 'anthropic':
return self._parse_anthropic_output_to_fields(output)
@@ -123,11 +128,14 @@ def parse_output_to_fields(self, output: str, llm_type: str) -> dict:
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")


@abstractmethod
def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass

@abstractmethod
def _format_openai_json_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass

@abstractmethod
def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
pass
@@ -145,10 +153,79 @@ def _parse_openai_output_to_fields(self, output: str) -> dict:
def _parse_anthropic_output_to_fields(self, output: str) -> dict:
pass

@abstractmethod
def _parse_openai_json_output_to_fields(self, output: str) -> dict:
pass


class DefaultPromptStrategy(PromptStrategy):
OUTPUT_TOKEN = "🔑"

def _format_openai_json_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
prompt = "Follow the following format. Answer with a JSON object. Attributes that have values should not be changed or repeated."

if len(self.output_variables) > 1:
output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()])
prompt += f" Provide answers for {output_field_names}.\n"

if self.hint_variables:
prompt += "\n"

for _, hint_field in self.hint_variables.items():
prompt += hint_field.format_prompt_description("openai") + "\n"

prompt += "\nInput Fields:\n"
input_fields_dict = {}
for input_name, input_field in self.input_variables.items():
input_fields_dict[input_field.name] = input_field.desc
prompt += json.dumps(input_fields_dict, indent=2) + "\n"

prompt += "\nOutput Fields:\n"
output_fields_dict = {}
for output_name, output_field in self.output_variables.items():
output_fields_dict[output_field.name] = output_field.desc
prompt += json.dumps(output_fields_dict, indent=2) + "\n"

if examples:
prompt += "\nExamples:\n"
for example_input, example_output in examples:
example_dict = {"input": {}, "output": {}}
for input_name, input_field in self.input_variables.items():
example_dict["input"].update(input_field.format_prompt_value_json(example_input.get(input_name), 'openai_json'))
for output_name, output_field in self.output_variables.items():
if isinstance(example_output, dict):
example_dict["output"].update(output_field.format_prompt_value_json(example_output.get(output_name), 'openai_json'))
else:
example_dict["output"].update(output_field.format_prompt_value_json(example_output, 'openai_json'))
prompt += json.dumps(example_dict, indent=2) + "\n"

if trained_state and trained_state.examples and use_training:
prompt += "\nTrained Examples:\n"
for example_X, example_y in trained_state.examples:
example_dict = {"input": {}, "output": {}}
for input_name, input_field in self.input_variables.items():
example_dict["input"].update(input_field.format_prompt_value_json(example_X.get(input_name), 'openai_json'))
for output_name, output_field in self.output_variables.items():
if isinstance(example_y, dict):
example_dict["output"].update(output_field.format_prompt_value_json(example_y.get(output_name), 'openai_json'))
else:
example_dict["output"].update(output_field.format_prompt_value_json(example_y, 'openai_json'))
prompt += json.dumps(example_dict, indent=2) + "\n"

prompt += "\nInput:\n"
input_dict = {}
for input_name, input_field in self.input_variables.items():
input_dict.update(input_field.format_prompt_value_json(kwargs.get(input_name), 'openai_json'))
prompt += json.dumps(input_dict, indent=2) + "\n"

prompt += "\nOutput:\n"
output_dict = {}
for output_name, output_field in self.output_variables.items():
output_dict.update(output_field.format_prompt_json())
prompt += json.dumps(output_dict, indent=2) + "\n"

return prompt

def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
# print(f"Formatting prompt {kwargs}")
prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "
@@ -291,20 +368,20 @@ def _parse_openai_output_to_fields(self, output: str) -> dict:
pattern = r'^([^:]+): (.*)'
lines = output.split(self.OUTPUT_TOKEN)
parsed_fields = {}
# logger.debug(f"Parsing output to fields with pattern {pattern} and lines {lines}")
logger.debug(f"Parsing output to fields with pattern {pattern} and lines {lines}")
for line in lines:
match = re.match(pattern, line, re.MULTILINE)
if match:
field_name, field_content = match.groups()
# logger.debug(f"Matched line {line} - field name {field_name} field content {field_content}")
logger.debug(f"Matched line {line} - field name {field_name} field content {field_content}")
output_field = self._get_output_field(field_name)
if output_field:
# logger.debug(f"Matched field {field_name} to output field {output_field}")
logger.debug(f"Matched field {field_name} to output field {output_field}")
parsed_fields[output_field] = field_content
else:
logger.error(f"Field {field_name} not found in output variables")
# else:
# logger.debug(f"NO MATCH line {line}")
else:
logger.debug(f"NO MATCH line {line}")

if len(self.output_variables) == 1:
first_value = next(iter(parsed_fields.values()), None)
@@ -343,3 +420,37 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
traceback.print_exc()

raise e

def _parse_openai_json_output_to_fields(self, output: str) -> dict:
print(f"Parsing openai json")
try:
# Parse the JSON output
json_output = json.loads(output)

# Initialize an empty dictionary to store the parsed fields
parsed_fields = {}

# Iterate over the output variables
for output_name, output_field in self.output_variables.items():
# Check if the output field exists in the JSON output
if output_field.name in json_output:
# Get the value of the output field from the JSON output
field_value = json_output[output_field.name]

# Apply any necessary transformations to the field value
transformed_value = output_field.transform_value(field_value)

# Store the transformed value in the parsed fields dictionary
parsed_fields[output_name] = transformed_value
else:
# If the output field is not present in the JSON output, set its value to None
parsed_fields[output_name] = None

logger.debug(f"Parsed fields: {parsed_fields}")
return parsed_fields
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON output: {e}")
raise e
except Exception as e:
logger.error(f"An error occurred while parsing JSON output: {e}")
raise e
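On the parsing side, `_parse_openai_json_output_to_fields` loads the model's JSON reply and maps each declared output field to its value, falling back to `None` when a key is absent. A minimal illustration of that mapping, with the field objects reduced to plain names and the transformer step omitted:

```python
import json

# attribute name -> JSON key (simplified stand-in for self.output_variables)
output_variables = {"answer": "answer", "confidence": "confidence"}

raw_reply = '{"answer": "42"}'  # hypothetical model output; "confidence" is absent
data = json.loads(raw_reply)

parsed_fields = {attr: data.get(key) for attr, key in output_variables.items()}
print(parsed_fields)  # {'answer': '42', 'confidence': None}
```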
11 changes: 4 additions & 7 deletions pyproject.toml
@@ -11,7 +11,6 @@ python = ">=3.9,<4.0"
aiohttp = "3.9.3"
aiosignal = "1.3.1"
annotated-types = "0.6.0"
anthropic = "0.19.1"
anyio = "4.2.0"
attrs = "23.2.0"
build = "1.1.1"
@@ -45,11 +44,6 @@ joblib = "1.3.2"
jsonpatch = "1.33"
jsonpointer = "2.4"
keyring = "24.3.1"
langchain = "0.1.7"
langchain-anthropic = "0.1.3"
langchain-community = "0.0.20"
langchain-core = "0.1.23"
langchain-openai = "0.0.6"
markdown-it-py = "3.0.0"
marshmallow = "3.20.2"
mdurl = "0.1.2"
Expand All @@ -58,7 +52,6 @@ msgpack = "1.0.8"
multidict = "6.0.5"
mypy-extensions = "1.0.0"
numpy = "1.26.4"
openai = "1.12.0"
orjson = "3.9.15"
packaging = "23.2"
pexpect = "4.9.0"
@@ -99,6 +92,10 @@ xattr = "1.1.0"
yarl = "1.9.4"
zipp = "3.17.0"
ratelimit = "^2.2.1"
langchain = "^0.1.20"
langchain-openai = "^0.1.6"
anthropic = "^0.25.8"
langchain-anthropic = "^0.1.11"

[tool.poetry.dev-dependencies]

59 changes: 58 additions & 1 deletion tests/test_manual_examples.py
@@ -164,4 +164,61 @@ class MultipleOutputPromptSignature(PromptSignature):
)
assert "<input>Example input</input>" in formatted_prompt
assert "<output1>Example output 1</output1>" in formatted_prompt
assert "<output2>Example output 2</output2>" in formatted_prompt
assert "<output2>Example output 2</output2>" in formatted_prompt

# ... (previous tests remain the same)

def test_format_prompt_with_examples_openai_json():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
formatted_prompt = prompt_runner.template._format_openai_json_prompt(
trained_state=None,
use_training=True,
examples=TestPromptSignature.__examples__,
input="Test input"
)
assert 'Hint field' in formatted_prompt
assert '"input": "Input field"' in formatted_prompt
assert '"output": "Output field"' in formatted_prompt
assert '"input": "Example input 1"' in formatted_prompt
assert '"output": "Example output 1"' in formatted_prompt
assert '"input": "Example input 2"' in formatted_prompt
assert '"output": "Example output 2"' in formatted_prompt
assert '"input": "Test input"' in formatted_prompt
assert '"output": ""' in formatted_prompt

def test_format_prompt_without_examples_openai_json():
prompt_runner = PromptRunner(template_class=TestPromptSignature, prompt_strategy=DefaultPromptStrategy)
formatted_prompt = prompt_runner.template._format_openai_json_prompt(
trained_state=None,
use_training=False,
examples=[],
input="Test input"
)
assert 'Hint field' in formatted_prompt
assert '"input": "Input field"' in formatted_prompt
assert '"output": "Output field"' in formatted_prompt
assert '"input": "Test input"' in formatted_prompt
assert '"output": ""' in formatted_prompt
assert "Example input 1" not in formatted_prompt
assert "Example output 1" not in formatted_prompt
assert "Example input 2" not in formatted_prompt
assert "Example output 2" not in formatted_prompt

def test_format_prompt_with_multiple_output_fields_openai_json():
class MultipleOutputPromptSignature(PromptSignature):
input = InputField(name="input", desc="Input field")
output1 = OutputField(name="output1", desc="Output field 1")
output2 = OutputField(name="output2", desc="Output field 2")
__examples__ = [
({"input": "Example input"}, {"output1": "Example output 1", "output2": "Example output 2"}),
]
prompt_runner = PromptRunner(template_class=MultipleOutputPromptSignature, prompt_strategy=DefaultPromptStrategy)
formatted_prompt = prompt_runner.template._format_openai_json_prompt(
trained_state=None,
use_training=True,
examples=MultipleOutputPromptSignature.__examples__,
input="Test input"
)
assert '"input": "Example input"' in formatted_prompt
assert '"output1": "Example output 1"' in formatted_prompt
assert '"output2": "Example output 2"' in formatted_prompt