
Commit 25450ff

Merge pull request #23 from aelaguiz/anthropic_prompts
Anthropic prompts
2 parents b2e1173 + 3acce9b commit 25450ff

8 files changed (+168, −18 lines)


langdspy/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList
+from .field_descriptors import InputField, OutputField, InputFieldList, HintField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool, OutputFieldChooseOne
 from .prompt_strategies import PromptSignature, PromptStrategy, DefaultPromptStrategy
 from .prompt_runners import PromptRunner, RunnableConfig, Prediction, MultiPromptRunner
 from .model import Model, TrainedModelState
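
With this change the two new field types are exported from the package root, so downstream code can import them directly; a quick sketch:

from langdspy import OutputFieldBool, OutputFieldChooseOne  # new public exports in this commit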

langdspy/field_descriptors.py

Lines changed: 33 additions & 0 deletions
@@ -123,6 +123,39 @@ def format_prompt(self, llm_type: str):
         elif llm_type == "anthropic":
             return f"{self._start_format_anthropic()}</{self.name}>"
 
+class OutputFieldBool(OutputField):
+    def __init__(self, name: str, desc: str, **kwargs):
+        if not 'transformer' in kwargs:
+            kwargs['transformer'] = transformers.as_bool
+        if not 'validator' in kwargs:
+            kwargs['validator'] = validators.is_one_of
+            kwargs['choices'] = ['Yes', 'No']
+
+        super().__init__(name, desc, **kwargs)
+
+    def format_prompt_description(self, llm_type: str):
+        choices_str = ", ".join(['Yes', 'No'])
+        if llm_type == "openai":
+            return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
+        elif llm_type == "anthropic":
+            return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"
+
+class OutputFieldChooseOne(OutputField):
+    def __init__(self, name: str, desc: str, choices: List[str], **kwargs):
+        kwargs['choices'] = choices
+
+        if not 'validator' in kwargs:
+            kwargs['validator'] = validators.is_one_of
+            kwargs['choices'] = choices
+        super().__init__(name, desc, **kwargs)
+
+    def format_prompt_description(self, llm_type: str):
+        choices_str = ", ".join(self.kwargs.get('choices', []))
+        if llm_type == "openai":
+            return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
+        elif llm_type == "anthropic":
+            return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"
+
 class OutputFieldEnum(OutputField):
     def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
         kwargs['enum'] = enum
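
Both new classes follow the existing OutputField pattern: they install a default transformer/validator and render a "One of: ..." description per LLM type. A minimal usage sketch (the TriageTicket signature is hypothetical, not part of this commit):

from langdspy import PromptSignature, InputField, OutputFieldBool, OutputFieldChooseOne

class TriageTicket(PromptSignature):
    ticket_summary = InputField(name="Ticket Summary", desc="Summary of the ticket")
    # Validated against ['Yes', 'No'] and transformed with transformers.as_bool
    is_spam = OutputFieldBool(name="Spam", desc="Is this ticket spam")
    # Validated with validators.is_one_of against the supplied choices
    priority = OutputFieldChooseOne(name="Priority", desc="Ticket priority",
                                    choices=["Low", "Medium", "High"])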

langdspy/model.py

Lines changed: 13 additions & 6 deletions
@@ -75,7 +75,7 @@ def load(self, filepath):
 
     def predict(self, X, llm):
         y = Parallel(n_jobs=self.n_jobs, backend='threading')(
-            delayed(self.invoke)(item, {**self.kwargs, 'trained_state': self.trained_state, 'llm': llm})
+            delayed(self.invoke)(item, {**self.kwargs, 'trained_state': self.trained_state, 'llm': llm, 'max_tries': 1})
             for item in tqdm(X, desc="Predicting", total=len(X))
         )
         return y
@@ -125,27 +125,34 @@ def fit(self, X, y, score_func, llm, n_examples=3, example_ratio=0.7, n_iter=Non
         logger.debug(f"Total number of examples: {n_examples} Example size: {example_size} n_examples: {n_examples} example_X size: {len(example_X)} Scoring size: {len(scoring_X)}")
 
         def evaluate_subset(subset):
+            # logger.debug(f"Evaluating subset: {subset}")
             subset_X, subset_y = zip(*subset)
             self.trained_state.examples = subset
 
             # Predict on the scoring set
-            predicted_slugs = Parallel(n_jobs=self.n_jobs)(
+            predicted_y = Parallel(n_jobs=self.n_jobs)(
                 delayed(self.invoke)(item, config={
                     **self.kwargs,
                     'trained_state': self.trained_state,
-                    'llm': llm
+                    'llm': llm,
+                    'max_tries': 1
                 })
                 for item in scoring_X
             )
-            score = score_func(scoring_X, scoring_y, predicted_slugs)
-            logger.debug(f"Training subset scored {score}")
+            score = score_func(scoring_X, scoring_y, predicted_y)
+            # logger.debug(f"Training subset scored {score}")
             return score, subset
+
+        # logger.debug(f"Generating subsets")
 
         # Generate all possible subsets
-        all_subsets = list(itertools.combinations(zip(example_X, example_y), n_examples))
+        # all_subsets = list(itertools.combinations(zip(example_X, example_y), n_examples))
+        all_subsets = [random.sample(list(zip(example_X, example_y)), n_examples) for _ in range(n_iter)]
 
         # Randomize the order of subsets
        random.shuffle(all_subsets)
+        logger.debug(f"Total number of subsets: {len(all_subsets)}")
 
         # Limit the number of iterations if n_iter is specified
         if n_iter is not None:
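
The fit() change swaps exhaustive enumeration for random sampling: instead of materializing every C(n, k) combination of candidate examples, it draws n_iter random subsets, so cost scales with n_iter. Note the new comprehension consumes n_iter directly, so callers now need to pass a value even though the signature appears to still default it to None. A standalone sketch of the difference (toy data, not langdspy itself):

import itertools
import random

example_X = [f"x{i}" for i in range(10)]
example_y = [f"y{i}" for i in range(10)]
n_examples, n_iter = 3, 5

# Before: every 3-of-10 combination -> C(10, 3) = 120 candidate subsets
exhaustive = list(itertools.combinations(zip(example_X, example_y), n_examples))

# After: n_iter randomly drawn subsets, independent of C(n, k)
sampled = [random.sample(list(zip(example_X, example_y)), n_examples)
           for _ in range(n_iter)]

print(len(exhaustive), len(sampled))  # 120 5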

langdspy/prompt_runners.py

Lines changed: 9 additions & 3 deletions
@@ -150,7 +150,9 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
             formatted_prompt = self.template.format_prompt(**invoke_args)
 
             if print_prompt:
+                print(f"------------------------PROMPT START--------------------------------")
                 print(formatted_prompt)
+                print(f"------------------------PROMPT END----------------------------------\n")
 
             # logger.debug(f"Invoke args: {invoke_args}")
             res = chain.invoke(invoke_args, config=config)
@@ -167,7 +169,9 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
 
             # logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
             if print_prompt:
+                print(f"------------------------RESULT START--------------------------------")
                 print(res)
+                print(f"------------------------RESULT END----------------------------------\n")
 
             # Use the parse_output_to_fields method from the PromptStrategy
             parsed_output = {}
@@ -223,14 +227,16 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
             if validation_err is None:
                 return res
 
-            logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
-            time.sleep(random.uniform(0.1, 1.5))
             max_tries -= 1
+            if max_tries >= 1:
+                logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
+                time.sleep(random.uniform(0.05, 0.25))
 
         if hard_fail:
             raise ValueError(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries.")
         else:
-            logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning unvalidated output.")
+            logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning None.")
+            res = {attr_name: None for attr_name in self.template.output_variables.keys()}
 
         return res
 
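
Two behavioral changes here: the backoff sleep now only happens when another attempt remains, and a run that exhausts its retries with hard_fail=False returns a dict of None values (one per output field) instead of unvalidated output. A simplified sketch of the resulting control flow (not the real runner, which also formats prompts and parses fields):

import random
import time

def invoke_with_retries(run_once, validate, output_names, max_tries=3, hard_fail=False):
    total_max_tries = max_tries
    while max_tries >= 1:
        res = run_once()
        if validate(res):
            return res
        max_tries -= 1
        if max_tries >= 1:  # pause only if we will actually retry
            time.sleep(random.uniform(0.05, 0.25))
    if hard_fail:
        raise ValueError(f"Output validation failed after {total_max_tries} tries.")
    # Fall back to a None-filled result so callers see a consistent shape
    return {name: None for name in output_names}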

langdspy/prompt_strategies.py

Lines changed: 12 additions & 5 deletions
@@ -215,11 +215,12 @@ def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs)
 
     def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str:
         # print(f"Formatting prompt {kwargs}")
-        prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "
+        # prompt = "Follow the following format. Attributes that have values should not be changed or repeated. "
+        prompt = ""
 
         output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()])
         # Format the instruction with the extracted names
-        prompt += f"Provide answers for {output_field_names}. Follow the XML output format.\n"
+        prompt += f"Provide answers for output fields {output_field_names}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples.\n"
 
         if self.hint_variables:
             prompt += "\n<hints>\n"
@@ -325,9 +326,15 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict:
         parsed_fields = {}
         for output_name, output_field in self.output_variables.items():
             pattern = fr"<{output_field.name}>(.*?)</{output_field.name}>"
-            match = re.search(pattern, output, re.DOTALL)
-            if match:
-                parsed_fields[output_name] = match.group(1).strip()
+            # match = re.search(pattern, output, re.DOTALL)
+            # if match:
+            #     parsed_fields[output_name] = match.group(1).strip()
+            matches = re.findall(pattern, output, re.DOTALL)
+            if matches:
+                # Take the last match
+                last_match = matches[-1]
+                parsed_fields[output_name] = last_match.strip()
+
 
         logger.debug(f"Parsed fields: {parsed_fields}")
         return parsed_fields
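
The parsing change guards against models that echo the prompt's examples before answering: an echoed example can contain the same <FieldName>...</FieldName> tag as the real output, and re.search would return that first (wrong) occurrence. Taking the last match via re.findall picks up the model's actual answer. A small demonstration:

import re

output = """<example><output><Spam>No</Spam></output></example>
<output><Spam>Yes</Spam></output>"""

pattern = r"<Spam>(.*?)</Spam>"
print(re.search(pattern, output, re.DOTALL).group(1))  # 'No'  - first (echoed) match
print(re.findall(pattern, output, re.DOTALL)[-1])      # 'Yes' - last match, the real answer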

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -98,6 +98,7 @@ virtualenv = "20.25.1"
 xattr = "1.1.0"
 yarl = "1.9.4"
 zipp = "3.17.0"
+ratelimit = "^2.2.1"
 
 [tool.poetry.dev-dependencies]
 
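
The diff adds ratelimit as a dependency but contains no call sites for it. For reference, the package's decorator API looks like this (hypothetical limits, not taken from this commit):

from ratelimit import limits, sleep_and_retry

@sleep_and_retry              # block and retry instead of raising when the limit is hit
@limits(calls=10, period=60)  # allow at most 10 calls per 60 seconds
def call_llm(prompt):
    ...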

tests/test_field_descriptors.py

Lines changed: 30 additions & 1 deletion
@@ -1,6 +1,6 @@
 import pytest
 from enum import Enum
-from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList
+from langdspy.field_descriptors import InputField, InputFieldList, OutputField, OutputFieldEnum, OutputFieldEnumList, OutputFieldBool, OutputFieldChooseOne
 
 def test_input_field_initialization():
     field = InputField("name", "description")
@@ -59,3 +59,32 @@ def test_output_field_enum_list_format_prompt_description():
     field = OutputFieldEnumList("name", "description", TestEnum)
     assert "A comma-separated list of one or more of: VALUE1, VALUE2, VALUE3" in field.format_prompt_description("openai")
     assert "A comma-separated list of one or more of: <choices>VALUE1, VALUE2, VALUE3</choices>" in field.format_prompt_description("anthropic")
+
+
+def test_output_field_bool_initialization():
+    field = OutputFieldBool("name", "description")
+    assert field.name == "name"
+    assert field.desc == "description"
+    assert field.transformer.__name__ == "as_bool"
+    assert field.validator.__name__ == "is_one_of"
+    assert field.kwargs['choices'] == ["Yes", "No"]
+
+def test_output_field_bool_format_prompt_description():
+    field = OutputFieldBool("name", "description")
+    assert "One of: Yes, No" in field.format_prompt_description("openai")
+    assert "One of: <choices>Yes, No</choices>" in field.format_prompt_description("anthropic")
+
+def test_output_field_choose_one_initialization():
+    choices = ["Option 1", "Option 2", "Option 3"]
+    field = OutputFieldChooseOne("name", "description", choices)
+    assert field.name == "name"
+    assert field.desc == "description"
+    assert field.validator.__name__ == "is_one_of"
+    assert field.kwargs['choices'] == choices
+
+def test_output_field_choose_one_format_prompt_description():
+    choices = ["Option 1", "Option 2", "Option 3"]
+
+    field = OutputFieldChooseOne("name", "description", choices)
+    assert "One of: Option 1, Option 2, Option 3" in field.format_prompt_description("openai")
+    assert "One of: <choices>Option 1, Option 2, Option 3</choices>" in field.format_prompt_description("anthropic")

tests/test_output_parsing.py

Lines changed: 69 additions & 2 deletions
@@ -1,7 +1,8 @@
 import pytest
-from langdspy.field_descriptors import InputField, OutputField
+from langdspy.field_descriptors import InputField, OutputField, OutputFieldBool
 from langdspy.prompt_strategies import PromptSignature, DefaultPromptStrategy
 from langdspy.prompt_runners import PromptRunner
+from langdspy.formatters import as_multiline
 
 class TestOutputParsingPromptSignature(PromptSignature):
     ticket_summary = InputField(name="Ticket Summary", desc="Summary of the ticket we're trying to analyze.")
@@ -69,4 +70,70 @@ def test_output_parsing_with_missing_fields():
     result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])
 
     assert result["buyer_issues_summary"] == "The buyer is trying to personalize their order by selecting variants like color or size, but after making their selections and hitting \"done\", the changes are not being reflected. They are also asking how long delivery will take."
-    assert result.get("buyer_issue_category") is None
\ No newline at end of file
+    assert result.get("buyer_issue_category") is None
+
+def test_repeated_input_output():
+    output_data = """<input_fields>
+<Ticket Summary>A summary of the ticket</Ticket Summary>
+</input_fields>
+
+<output_fields>
+<Spam>One of: <choices>Yes, No</choices> - Is this ticket a spam or sales ticket</Spam>
+</output_fields>
+
+<examples>
+
+<example>
+<input>
+<Ticket Summary>«Ticket ID: 2044
+Freshdesk ID: 335398
+Status: PENDING
+Processing State: TRIAGED_READY
+Subject: Horror collection
+Priority: 2
+Messages:»</Ticket Summary>
+</input>
+<output>
+<Spam>No</Spam>
+</output>
+</example>
+
+<example>
+<input>
+<Ticket Summary>«Ticket ID: 2504
+Freshdesk ID: 334191
+Status: PENDING
+Processing State: TRIAGED_READY
+Subject: Ch
+Messages:»</Ticket Summary>
+</input>
+<output>
+<Spam>No</Spam>
+</output>
+</example>
+</examples>
+
+<input>
+<Ticket Summary>«Ticket ID: 2453
+Freshdesk ID: 334312
+Status: IN_PROGRESS
+Processing State: TRIAGED_READY
+Subject: No Response from Seller
+Description: [Chatbot]: Hi there, how can we help you today? [user]: I sent a message to the seller on 2/2 and received an auto reply to allow 2-3 days for someone to get back to me. To date, I have not heard anything from the seller. [Chatbot]: (No Intent Predicted) [Chatbot]: I understand your concern about not hearing back from the seller. If it's been more than 2 business days since you contacted them, Cratejoy can assist by reaching out on your behalf. Please contact Cratejoy Support for further help with this issue. * Shipments Lost In Transit * Getting Help With An Unshipped Order * Damaged, Duplicate or Defective Items [Chatbot]: Was I able to help you resolve your question?
+(Yes, thank you!)
+</input>
+
+<output>
+<Spam>No</Spam>
+</output>
+"""
+
+    class IsTicketSpam(PromptSignature):
+        ticket_summary = InputField(name="Ticket Summary", desc="A summary of the ticket", formatter=as_multiline)
+        is_spam = OutputFieldBool(name="Spam", desc="Is this ticket a spam or sales ticket")
+
+    config = {"llm_type": "anthropic"}
+    prompt_runner = PromptRunner(template_class=IsTicketSpam, prompt_strategy=DefaultPromptStrategy)
+    result = prompt_runner.template.parse_output_to_fields(output_data, config["llm_type"])
+
+    assert result["is_spam"] == "No"
