Skip to content

Commit

Permalink
Works
Browse files Browse the repository at this point in the history
  • Loading branch information
aelaguiz committed Mar 24, 2024
1 parent 1411f9c commit 4fd543c
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 17 deletions.
2 changes: 1 addition & 1 deletion langdspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList
from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool
from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy
from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner
from .model import Model, TrainedModelState
Expand Down
17 changes: 17 additions & 0 deletions langdspy/field_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def format_prompt(self, llm_type: str):
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}</{self.name}>"

class OutputFieldBool(OutputField):
def __init__(self, name: str, desc: str, **kwargs):
if not 'transformer' in kwargs:
kwargs['transformer'] = transformers.as_bool
if not 'validator' in kwargs:
kwargs['validator'] = validators.is_one_of
kwargs['choices'] = ['Yes', 'No']

super().__init__(name, desc, **kwargs)

def format_prompt_description(self, llm_type: str):
choices_str = ", ".join(['Yes', 'No'])
if llm_type == "openai":
return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"

class OutputFieldEnum(OutputField):
def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
kwargs['enum'] = enum
Expand Down
19 changes: 13 additions & 6 deletions langdspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def load(self, filepath):

def predict(self, X, llm):
y = Parallel(n_jobs=self.n_jobs, backend='threading')(
delayed(self.invoke)(item, {**self.kwargs, 'trained_state': self.trained_state, 'llm': llm})
delayed(self.invoke)(item, {**self.kwargs, 'trained_state': self.trained_state, 'llm': llm, 'max_tries': 1})
for item in tqdm(X, desc="Predicting", total=len(X))
)
return y
Expand Down Expand Up @@ -125,27 +125,34 @@ def fit(self, X, y, score_func, llm, n_examples=3, example_ratio=0.7, n_iter=Non
logger.debug(f"Total number of examples: {n_examples} Example size: {example_size} n_examples: {n_examples} example_X size: {len(example_X)} Scoring size: {len(scoring_X)}")

def evaluate_subset(subset):
# logger.debug(f"Evaluating subset: {subset}")
subset_X, subset_y = zip(*subset)
self.trained_state.examples = subset

# Predict on the scoring set
predicted_slugs = Parallel(n_jobs=self.n_jobs)(
predicted_y = Parallel(n_jobs=self.n_jobs)(
delayed(self.invoke)(item, config={
**self.kwargs,
'trained_state': self.trained_state,
'llm': llm
'llm': llm,
'max_tries': 1
})
for item in scoring_X
)
score = score_func(scoring_X, scoring_y, predicted_slugs)
logger.debug(f"Training subset scored {score}")
score = score_func(scoring_X, scoring_y, predicted_y)
# logger.debug(f"Training subset scored {score}")
return score, subset

# logger.debug(f"Generating subsets")

# Generate all possible subsets
all_subsets = list(itertools.combinations(zip(example_X, example_y), n_examples))
# all_subsets = list(itertools.combinations(zip(example_X, example_y), n_examples))
all_subsets = [random.sample(list(zip(example_X, example_y)), n_examples) for _ in range(n_iter)]


# Randomize the order of subsets
random.shuffle(all_subsets)
logger.debug(f"Total number of subsets: {len(all_subsets)}")

# Limit the number of iterations if n_iter is specified
if n_iter is not None:
Expand Down
12 changes: 9 additions & 3 deletions langdspy/prompt_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
formatted_prompt = self.template.format_prompt(**invoke_args)

if print_prompt:
print(f"------------------------PROMPT START--------------------------------")
print(formatted_prompt)
print(f"------------------------PROMPT END----------------------------------\n")

# logger.debug(f"Invoke args: {invoke_args}")
res = chain.invoke(invoke_args, config=config)
Expand All @@ -167,7 +169,9 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna

# logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
if print_prompt:
print(f"------------------------RESULT START--------------------------------")
print(res)
print(f"------------------------RESULT END----------------------------------\n")

# Use the parse_output_to_fields method from the PromptStrategy
parsed_output = {}
Expand Down Expand Up @@ -223,14 +227,16 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
if validation_err is None:
return res

logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
time.sleep(random.uniform(0.1, 1.5))
max_tries -= 1
if max_tries >= 1:
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
time.sleep(random.uniform(0.05, 0.25))

if hard_fail:
raise ValueError(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries.")
else:
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning unvalidated output.")
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning None.")
res = {attr_name: None for attr_name in self.template.output_variables.keys()}

return res

Expand Down
14 changes: 10 additions & 4 deletions langdspy/prompt_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar

output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()])
# Format the instruction with the extracted names
prompt += f"Provide answers for {output_field_names}. Follow the XML output format.\n"
prompt += f"Provide answers for output fields {output_field_names}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples.\n"

if self.hint_variables:
prompt += "\n<hints>\n"
Expand Down Expand Up @@ -326,9 +326,15 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
parsed_fields = {}
for output_name, output_field in self.output_variables.items():
pattern = fr"<{output_field.name}>(.*?)</{output_field.name}>"
match = re.search(pattern, output, re.DOTALL)
if match:
parsed_fields[output_name] = match.group(1).strip()
# match = re.search(pattern, output, re.DOTALL)
# if match:
# parsed_fields[output_name] = match.group(1).strip()
matches = re.findall(pattern, output, re.DOTALL)
if matches:
# Take the last match
last_match = matches[-1]
parsed_fields[output_name] = last_match.strip()


logger.debug(f"Parsed fields: {parsed_fields}")
return parsed_fields
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ virtualenv = "20.25.1"
xattr = "1.1.0"
yarl = "1.9.4"
zipp = "3.17.0"
ratelimit = "^2.2.1"

[tool.poetry.dev-dependencies]

Expand Down
16 changes: 15 additions & 1 deletion tests/test_field_descriptors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from enum import Enum
from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList
from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool

def test_input_field_initialization():
field = InputField("name", "description")
Expand Down Expand Up @@ -59,3 +59,17 @@ def test_output_field_enum_list_format_prompt_description():
field = OutputFieldEnumList("name", "description", TestEnum)
assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("openai")
assert "A comma-separated list of one or more of: <choices>VALUE1, VALUE2, VALUE3</choices>" in field.format_prompt_description("anthropic")


def test_output_field_bool_initialization():
field = OutputFieldBool("name", "description")
assert field.name == "name"
assert field.desc == "description"
assert field.transformer.__name__ == "as_bool"
assert field.validator.__name__ == "is_one_of"
assert field.kwargs['choices'] == ["Yes", "No"]

def test_output_field_bool_format_prompt_description():
field = OutputFieldBool("name", "description")
assert "One of: Yes, No" in field.format_prompt_description("openai")
assert "One of: <choices>Yes, No</choices>" in field.format_prompt_description("anthropic")
71 changes: 69 additions & 2 deletions tests/test_output_parsing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from langdspy.field_descriptors import InputField, OutputField
from langdspy.field_descriptors import InputField, OutputField, OutputFieldBool
from langdspy.prompt_strategies import PromptSignature, DefaultPromptStrategy
from langdspy.prompt_runners import PromptRunner
from langdspy.formatters import as_multiline

class TestOutputParsingPromptSignature(PromptSignature):
ticket_summary = InputField(name="Ticket Summary", desc="Summary of the ticket we're trying to analyze.")
Expand Down Expand Up @@ -69,4 +70,70 @@ def test_output_parsing_with_missing_fields():
result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])

assert result["buyer_issues_summary"] == "The buyer is trying to personalize their order by selecting variants like color or size, but after making their selections and hitting \"done\", the changes are not being reflected. They are also asking how long delivery will take."
assert result.get("buyer_issue_category") is None
assert result.get("buyer_issue_category") is None

def test_repeated_input_output():
output_data = """<input_fields>
<Ticket Summary>A summary of the ticket</Ticket Summary>
</input_fields>
<output_fields>
<Spam>One of: <choices>Yes, No</choices> - Is this ticket a spam or sales ticket</Spam>
</output_fields>
<examples>
<example>
<input>
<Ticket Summary>«Ticket ID: 2044
Freshdesk ID: 335398
Status: PENDING
Processing State: TRIAGED_READY
Subject: Horror collection
Priority: 2
Messages:»</Ticket Summary>
</input>
<output>
<Spam>No</Spam>
</output>
</example>
<example>
<input>
<Ticket Summary>«Ticket ID: 2504
Freshdesk ID: 334191
Status: PENDING
Processing State: TRIAGED_READY
Subject: Ch
Messages:»</Ticket Summary>
</input>
<output>
<Spam>No</Spam>
</output>
</example>
</examples>
<input>
<Ticket Summary>«Ticket ID: 2453
Freshdesk ID: 334312
Status: IN_PROGRESS
Processing State: TRIAGED_READY
Subject: No Response from Seller
Description: [Chatbot]: Hi there, how can we help you today? [user]: I sent a message to the seller on 2/2 and received an auto reply to allow 2-3 days for someone to get back to me. To date, I have not heard anything from the seller. [Chatbot]: (No Intent Predicted) [Chatbot]: I understand your concern about not hearing back from the seller. If it's been more than 2 business days since you contacted them, Cratejoy can assist by reaching out on your behalf. Please contact Cratejoy Support for further help with this issue. * Shipments Lost In Transit * Getting Help With An Unshipped Order * Damaged, Duplicate or Defective Items [Chatbot]: Was I able to help you resolve your question?
(Yes, thank you!)
</input>
<output>
<Spam>No</Spam>
</output>
"""

class IsTicketSpam(PromptSignature):
ticket_summary = InputField(name="Ticket Summary", desc="A summary of the ticket", formatter=as_multiline)
is_spam = OutputFieldBool(name="Spam", desc="Is this ticket a spam or sales ticket")

config = {"llm_type": "anthropic"}
prompt_runner = PromptRunner(template_class=IsTicketSpam, prompt_strategy=DefaultPromptStrategy)
result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])

assert result["is_spam"] == "No"

0 comments on commit 4fd543c

Please sign in to comment.