Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specialize prompts #11

Merged
merged 6 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions examples/amazon/generate_slugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logging.getLogger("openai._base_client").disabled = True
logging.getLogger("paramiko.transport").disabled = True
logging.getLogger("anthropic._base_client").disabled = True
logging.getLogger("langdspy").disabled = True
# logging.getLogger("langdspy").disabled = True

import langdspy
import httpx
Expand All @@ -32,7 +32,7 @@ def get_llm():
FAST_OPENAI_MODEL = os.getenv("FAST_OPENAI_MODEL")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
OPENAI_TEMPERATURE = os.getenv("OPENAI_TEMPERATURE")
FAST_MODEL_PROVIDER = os.getenv("FAST_MODEL_PROVIDER")
FAST_MODEL_PROVIDER = os.getenv("FAST_MODEL_PROVIDER", "")
FAST_ANTHROPIC_MODEL = os.getenv("FAST_ANTHROPIC_MODEL")
FAST_GROQ_MODEL = os.getenv("FAST_GROQ_MODEL")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
Expand Down Expand Up @@ -82,14 +82,14 @@ def cosine_similarity_tfidf(true_slugs, predicted_slugs):
similarity_scores = cosine_similarity(true_vectors, predicted_vectors)
return similarity_scores.diagonal()

def slug_similarity(true_slugs, predicted_slugs):
def slug_similarity(X, true_slugs, predicted_slugs):
similarity_scores = cosine_similarity_tfidf(true_slugs, predicted_slugs)
average_similarity = sum(similarity_scores) / len(similarity_scores)
return average_similarity

def evaluate_model(model, X, y):
predicted_slugs = model.predict(X, llm)
accuracy = slug_similarity(y, predicted_slugs)
accuracy = slug_similarity(X, y, predicted_slugs)
return accuracy

llm = get_llm()
Expand All @@ -105,24 +105,24 @@ def evaluate_model(model, X, y):
X_test = dataset['test']['X']
y_test = dataset['test']['y']

model = ProductSlugGenerator(n_jobs=4, print_prompt=False)
model = ProductSlugGenerator(n_jobs=1, print_prompt=True)
# model.generate_slug.set_model_kwargs({'print_prompt': True})

before_test_accuracy = None
if os.path.exists(output_path):
model.load(output_path)
else:
input("Hit enter to evaluate the untrained model...")
# input("Hit enter to evaluate the untrained model...")
before_test_accuracy = evaluate_model(model, X_test, y_test)
print(f"Before Training Accuracy: {before_test_accuracy}")

input("Hit enter to train the model...")
model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=3, n_iter=500)
# input("Hit enter to train the model...")
# model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=3, n_iter=500)

input("Hit enter to evaluate the trained model...")
# Evaluate the model on the test set
test_accuracy = evaluate_model(model, X_test, y_test)
print(f"Before Training Accuracy: {before_test_accuracy}")
print(f"After Training Accuracy: {test_accuracy}")
# input("Hit enter to evaluate the trained model...")
# # Evaluate the model on the test set
# test_accuracy = evaluate_model(model, X_test, y_test)
# print(f"Before Training Accuracy: {before_test_accuracy}")
# print(f"After Training Accuracy: {test_accuracy}")

model.save(output_path)
108 changes: 71 additions & 37 deletions langdspy/field_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ class FieldDescriptor:
def __init__(self, name:str, desc: str, formatter: Optional[Callable[[Any], Any]] = None, transformer: Optional[Callable[[Any], Any]] = None, validator: Optional[Callable[[Any], Any]] = None, **kwargs):
assert "⏎" not in name, "Field name cannot contain newline character"
assert ":" not in name, "Field name cannot contain colon character"

self.name = name
self.desc = desc
self.formatter = formatter
self.transformer = transformer
self.validator = validator
self.kwargs = kwargs


def format_value(self, value: Any) -> Any:
if self.formatter:
return self.formatter(value, self.kwargs)
Expand All @@ -39,81 +37,117 @@ def validate_value(self, input: Input, value: Any) -> bool:
return True

class HintField(FieldDescriptor):
HINT_TOKEN = "💡"
HINT_TOKEN_OPENAI = "💡"
HINT_TOKEN_ANTHROPIC = None

def __init__(self, desc: str, formatter: Optional[Callable[[Any], Any]] = None, transformer: Optional[Callable[[Any], Any]] = None, validator: Optional[Callable[[Any], Any]] = None, **kwargs):
# Provide a default value for the name parameter, such as an empty string
super().__init__("", desc, formatter, transformer, validator, **kwargs)

def format_prompt_description(self):
return f"{self.HINT_TOKEN} {self.desc}"

def _start_format_openai(self):
return f"{self.HINT_TOKEN_OPENAI}"

def format_prompt_description(self):
return f"{self.HINT_TOKEN} {self.desc}"
def _start_format_anthropic(self):
return f"<hint>"

def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()} {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</hint>"

class InputField(FieldDescriptor):
START_TOKEN = "✅"
START_TOKEN_OPENAI = "✅"
START_TOKEN_ANTHROPIC = None

def _start_format_openai(self):
return f"{self.START_TOKEN_OPENAI}{self.name}"

def _start_format_anthropic(self):
return f"<{self.name}>"

def _start_format(self):
return f"{self.START_TOKEN}{self.name}"

def format_prompt_description(self):
return f"{self._start_format()}: {self.desc}"
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}: {self.desc}"

def format_prompt_value(self, value):
def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
return f"{self._start_format()}: {value}"
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

class InputFieldList(InputField):
def format_prompt_description(self):
return f"{self._start_format()}: {self.desc}"
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}: {self.desc}"

def format_prompt_value(self, value):
def format_prompt_value(self, value, llm_type: str):
res = ""
if len(value) >= 1:
for i, value in enumerate(value):
if i > 0:
res += "\n"
value = self.format_value(value)
res += f"{self.START_TOKEN} [{i}]: {value}"
if llm_type == "openai":
res += f"{self.START_TOKEN_OPENAI} [{i}]: {value}"
elif llm_type == "anthropic":
res += f"<item>{value}</item>"
else:
res += f"{self._start_format()}: NO VALUES SPECIFIED"

if llm_type == "openai":
res += f"{self._start_format_openai()}: NO VALUES SPECIFIED"
elif llm_type == "anthropic":
res += f"{self._start_format_anthropic()}NO VALUES SPECIFIED</{self.name}>"

return res

class OutputField(FieldDescriptor):
START_TOKEN = "🔑"
START_TOKEN_OPENAI = "🔑"
START_TOKEN_ANTHROPIC = None

def _start_format(self):
return f"{self.START_TOKEN}{self.name}"

def format_prompt_description(self):
return f"{self._start_format()}: {self.desc}"
def _start_format_openai(self):
return f"{self.START_TOKEN_OPENAI}{self.name}"

def format_prompt_value(self, value):
def _start_format_anthropic(self):
return f"<{self.name}>"

def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}: {self.desc}"

def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
return f"{self._start_format()}: {value}"
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

def format_prompt(self):
return f"{self._start_format()}:"
def format_prompt(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}:"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}</{self.name}>"

class OutputFieldEnum(OutputField):
def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
kwargs['enum'] = enum

if not 'transformer' in kwargs:
kwargs['transformer'] = transformers.as_enum

if not 'validator' in kwargs:
kwargs['validator'] = validators.is_one_of
kwargs['choices'] = [e.name for e in enum]

super().__init__(name, desc, **kwargs)

def format_prompt_description(self):
def format_prompt_description(self, llm_type: str):
enum = self.kwargs.get('enum')
choices_str = ", ".join([e.name for e in enum])
return f"{self._start_format()}: One of: {choices_str} - {self.desc}"
if llm_type == "openai":
return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}: One of: {choices_str} - {self.desc}"
81 changes: 0 additions & 81 deletions langdspy/lcel_logger.py

This file was deleted.

7 changes: 6 additions & 1 deletion langdspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def save(self, filepath):
def load(self, filepath):
with open(filepath, 'rb') as file:
self.trained_state = pickle.load(file)
setattr(self, 'trained_state', self.trained_state)
self.kwargs = {**self.kwargs, 'trained_state': self.trained_state}

for runner_name, runner in self.prompt_runners:
runner.set_model_kwargs(self.kwargs)


def predict(self, X, llm):
Expand Down Expand Up @@ -104,7 +109,7 @@ def evaluate_subset(subset):
})
for item in scoring_X
)
score = score_func(scoring_y, predicted_slugs)
score = score_func(scoring_X, scoring_y, predicted_slugs)
logger.debug(f"Training subset scored {score}")
return score, subset

Expand Down
Loading
Loading