From 4fd543caa67e68ea49ce0720dab07924f527eb3a Mon Sep 17 00:00:00 2001 From: Amir Elaguizy Date: Sun, 24 Mar 2024 17:03:42 -0500 Subject: [PATCH] Works --- langdspy/__init__.py | 2 +- langdspy/field_descriptors.py | 17 ++++++++ langdspy/model.py | 19 ++++++--- langdspy/prompt_runners.py | 12 ++++-- langdspy/prompt_strategies.py | 14 +++++-- pyproject.toml | 1 + tests/test_field_descriptors.py | 16 +++++++- tests/test_output_parsing.py | 71 ++++++++++++++++++++++++++++++++- 8 files changed, 135 insertions(+), 17 deletions(-) diff --git a/langdspy/__init__.py b/langdspy/__init__.py index e84c4b9..ae02984 100644 --- a/langdspy/__init__.py +++ b/langdspy/__init__.py @@ -1,4 +1,4 @@ -from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList +from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner from .model import Model, TrainedModelState diff --git a/langdspy/field_descriptors.py b/langdspy/field_descriptors.py index 122101e..97e0e61 100644 --- a/langdspy/field_descriptors.py +++ b/langdspy/field_descriptors.py @@ -123,6 +123,23 @@ def format_prompt(self, llm_type: str): elif llm_type == "anthropic": return f"{self._start_format_anthropic()}" +class OutputFieldBool(OutputField): + def __init__(self, name: str, desc: str, **kwargs): + if not 'transformer' in kwargs: + kwargs['transformer'] = transformers.as_bool + if not 'validator' in kwargs: + kwargs['validator'] = validators.is_one_of + kwargs['choices'] = ['Yes', 'No'] + + super().__init__(name, desc, **kwargs) + + def format_prompt_description(self, llm_type: str): + choices_str = ", ".join(['Yes', 'No']) + if llm_type == "openai": + return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}" + elif llm_type == "anthropic": + return f"{self._start_format_anthropic()}One of: {choices_str} - {self.desc}" + class OutputFieldEnum(OutputField): def __init__(self, name: str, desc: str, enum: Enum, **kwargs): kwargs['enum'] = enum diff --git a/langdspy/model.py b/langdspy/model.py index d2bdb35..855264f 100644 --- a/langdspy/model.py +++ b/langdspy/model.py @@ -75,7 +75,7 @@ def load(self, filepath): def predict(self, X, llm): y = Parallel(n_jobs=self.n_jobs, backend='threading')( - delayed(self.invoke)(item, {**self.kwargs, 'trained_state': self.trained_state, 'llm': llm}) + delayed(self.invoke)(item, {**self.kwargs, 'trained_state': self.trained_state, 'llm': llm, 'max_tries': 1}) for item in tqdm(X, desc="Predicting", total=len(X)) ) return y @@ -125,27 +125,34 @@ def fit(self, X, y, score_func, llm, n_examples=3, example_ratio=0.7, n_iter=Non logger.debug(f"Total number of examples: {n_examples} Example size: {example_size} n_examples: {n_examples} example_X size: {len(example_X)} Scoring size: {len(scoring_X)}") def evaluate_subset(subset): + # logger.debug(f"Evaluating subset: {subset}") subset_X, subset_y = zip(*subset) self.trained_state.examples = subset # Predict on the scoring set - predicted_slugs = Parallel(n_jobs=self.n_jobs)( + predicted_y = Parallel(n_jobs=self.n_jobs)( delayed(self.invoke)(item, config={ **self.kwargs, 'trained_state': self.trained_state, - 'llm': llm + 'llm': llm, + 'max_tries': 1 }) for item in scoring_X ) - score = score_func(scoring_X, scoring_y, predicted_slugs) - 
logger.debug(f"Training subset scored {score}") + score = score_func(scoring_X, scoring_y, predicted_y) + # logger.debug(f"Training subset scored {score}") return score, subset + + # logger.debug(f"Generating subsets") # Generate all possible subsets - all_subsets = list(itertools.combinations(zip(example_X, example_y), n_examples)) + # all_subsets = list(itertools.combinations(zip(example_X, example_y), n_examples)) + all_subsets = [random.sample(list(zip(example_X, example_y)), n_examples) for _ in range(n_iter)] + # Randomize the order of subsets random.shuffle(all_subsets) + logger.debug(f"Total number of subsets: {len(all_subsets)}") # Limit the number of iterations if n_iter is specified if n_iter is not None: diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 4bde690..dacff94 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -150,7 +150,9 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna formatted_prompt = self.template.format_prompt(**invoke_args) if print_prompt: + print(f"------------------------PROMPT START--------------------------------") print(formatted_prompt) + print(f"------------------------PROMPT END----------------------------------\n") # logger.debug(f"Invoke args: {invoke_args}") res = chain.invoke(invoke_args, config=config) @@ -167,7 +169,9 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna # logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}") if print_prompt: + print(f"------------------------RESULT START--------------------------------") print(res) + print(f"------------------------RESULT END----------------------------------\n") # Use the parse_output_to_fields method from the PromptStrategy parsed_output = {} @@ -223,14 +227,16 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna if validation_err is None: return res - logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry") - time.sleep(random.uniform(0.1, 1.5)) max_tries -= 1 + if max_tries >= 1: + logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry") + time.sleep(random.uniform(0.05, 0.25)) if hard_fail: raise ValueError(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries.") else: - logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning unvalidated output.") + logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning None.") + res = {attr_name: None for attr_name in self.template.output_variables.keys()} return res diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 1e8a51d..1302a66 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -220,7 +220,7 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()]) # Format the instruction with the extracted names - prompt += f"Provide answers for {output_field_names}. Follow the XML output format.\n" + prompt += f"Provide answers for output fields {output_field_names}. 
Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples.\n"
 
         if self.hint_variables:
             prompt += "\n\n"
@@ -326,9 +326,15 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
         parsed_fields = {}
         for output_name, output_field in self.output_variables.items():
             pattern = fr"<{output_field.name}>(.*?)</{output_field.name}>"
-            match = re.search(pattern, output, re.DOTALL)
-            if match:
-                parsed_fields[output_name] = match.group(1).strip()
+            # match = re.search(pattern, output, re.DOTALL)
+            # if match:
+            #     parsed_fields[output_name] = match.group(1).strip()
+            matches = re.findall(pattern, output, re.DOTALL)
+            if matches:
+                # Take the last match
+                last_match = matches[-1]
+                parsed_fields[output_name] = last_match.strip()
+
         logger.debug(f"Parsed fields: {parsed_fields}")
 
         return parsed_fields
diff --git a/pyproject.toml b/pyproject.toml
index a35a850..27f0808 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,6 +98,7 @@ virtualenv = "20.25.1"
 xattr = "1.1.0"
 yarl = "1.9.4"
 zipp = "3.17.0"
+ratelimit = "^2.2.1"
 
 
 [tool.poetry.dev-dependencies]
diff --git a/tests/test_field_descriptors.py b/tests/test_field_descriptors.py
index 9734e0e..f7b6c87 100644
--- a/tests/test_field_descriptors.py
+++ b/tests/test_field_descriptors.py
@@ -1,6 +1,6 @@
 import pytest
 from enum import Enum
-from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList
+from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool
 
 def test_input_field_initialization():
     field = InputField("name", "description")
@@ -59,3 +59,17 @@ def test_output_field_enum_list_format_prompt_description():
     field = OutputFieldEnumList("name", "description", TestEnum)
     assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("openai")
     assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("anthropic")
+
+
+def test_output_field_bool_initialization():
+    field = OutputFieldBool("name", "description")
+    assert field.name == "name"
+    assert field.desc == "description"
+    assert field.transformer.__name__ == "as_bool"
+    assert field.validator.__name__ == "is_one_of"
+    assert field.kwargs['choices'] == ["Yes", "No"]
+
+def test_output_field_bool_format_prompt_description():
+    field = OutputFieldBool("name", "description")
+    assert "One of: Yes, No" in field.format_prompt_description("openai")
+    assert "One of: Yes, No" in field.format_prompt_description("anthropic")
\ No newline at end of file
diff --git a/tests/test_output_parsing.py b/tests/test_output_parsing.py
index 4303330..76d3d40 100644
--- a/tests/test_output_parsing.py
+++ b/tests/test_output_parsing.py
@@ -1,7 +1,8 @@
 import pytest
-from langdspy.field_descriptors import InputField, OutputField
+from langdspy.field_descriptors import InputField, OutputField, OutputFieldBool
 from langdspy.prompt_strategies import PromptSignature, DefaultPromptStrategy
 from langdspy.prompt_runners import PromptRunner
+from langdspy.formatters import as_multiline
 
 class TestOutputParsingPromptSignature(PromptSignature):
     ticket_summary = InputField(name="Ticket Summary", desc="Summary of the ticket we're trying to analyze.")
@@ -69,4 +70,70 @@ def test_output_parsing_with_missing_fields():
     result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])
 
     assert result["buyer_issues_summary"] == "The buyer is trying to 
personalize their order by selecting variants like color or size, but after making their selections and hitting \"done\", the changes are not being reflected. They are also asking how long delivery will take."
-    assert result.get("buyer_issue_category") is None
\ No newline at end of file
+    assert result.get("buyer_issue_category") is None
+
+def test_repeated_input_output():
+    output_data = """
+<Ticket Summary>A summary of the ticket</Ticket Summary>
+
+
+
+<Spam>One of: Yes, No - Is this ticket a spam or sales ticket</Spam>
+
+
+
+
+
+
+<Ticket Summary>«Ticket ID: 2044
+Freshdesk ID: 335398
+Status: PENDING
+Processing State: TRIAGED_READY
+Subject: Horror collection
+Priority: 2
+Messages:»</Ticket Summary>
+
+
+<Spam>No</Spam>
+
+
+
+
+
+<Ticket Summary>«Ticket ID: 2504
+Freshdesk ID: 334191
+Status: PENDING
+Processing State: TRIAGED_READY
+Subject: Ch
+Messages:»</Ticket Summary>
+
+
+<Spam>No</Spam>
+
+
+
+
+
+<Ticket Summary>«Ticket ID: 2453
+Freshdesk ID: 334312
+Status: IN_PROGRESS
+Processing State: TRIAGED_READY
+Subject: No Response from Seller
+Description: [Chatbot]: Hi there, how can we help you today? [user]: I sent a message to the seller on 2/2 and received an auto reply to allow 2-3 days for someone to get back to me. To date, I have not heard anything from the seller. [Chatbot]: (No Intent Predicted) [Chatbot]: I understand your concern about not hearing back from the seller. If it's been more than 2 business days since you contacted them, Cratejoy can assist by reaching out on your behalf. Please contact Cratejoy Support for further help with this issue. * Shipments Lost In Transit * Getting Help With An Unshipped Order * Damaged, Duplicate or Defective Items [Chatbot]: Was I able to help you resolve your question?
+(Yes, thank you!)</Ticket Summary>
+
+
+
+<Spam>No</Spam>
+
+"""
+
+    class IsTicketSpam(PromptSignature):
+        ticket_summary = InputField(name="Ticket Summary", desc="A summary of the ticket", formatter=as_multiline)
+        is_spam = OutputFieldBool(name="Spam", desc="Is this ticket a spam or sales ticket")
+
+    config = {"llm_type": "anthropic"}
+    prompt_runner = PromptRunner(template_class=IsTicketSpam, prompt_strategy=DefaultPromptStrategy)
+    result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])
+
+    assert result["is_spam"] == "No"
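
Usage sketch (not applied by git am; a minimal, hedged example of the new OutputFieldBool): it mirrors the IsTicketSpam signature from tests/test_output_parsing.py and relies only on names exported by langdspy/__init__.py after this patch. The `llm` client and its configuration are assumed to exist elsewhere and are not part of the diff above.

    from langdspy import (
        PromptRunner, PromptSignature, DefaultPromptStrategy,
        InputField, OutputFieldBool,
    )
    from langdspy.formatters import as_multiline

    class IsTicketSpam(PromptSignature):
        # Multiline ticket text is wrapped by the as_multiline formatter, as in the tests
        ticket_summary = InputField(name="Ticket Summary", desc="A summary of the ticket", formatter=as_multiline)
        # Prompts for "Yes"/"No"; the default transformer (transformers.as_bool) converts the answer
        is_spam = OutputFieldBool(name="Spam", desc="Is this ticket a spam or sales ticket")

    runner = PromptRunner(template_class=IsTicketSpam, prompt_strategy=DefaultPromptStrategy)
    # Invocation shape is an assumption based on how Model.invoke passes config in model.py:
    # result = runner.invoke({"ticket_summary": "..."}, config={"llm": llm, "llm_type": "anthropic"})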