
Commit

Merge pull request #11 from aelaguiz/specialize_prompts
Specialize prompts
aelaguiz authored Mar 18, 2024
2 parents 664b291 + c9f0d72 commit 8e0c8bb
Showing 9 changed files with 422 additions and 217 deletions.
20 changes: 10 additions & 10 deletions examples/amazon/generate_slugs.py
@@ -16,7 +16,7 @@
logging.getLogger("openai._base_client").disabled = True
logging.getLogger("paramiko.transport").disabled = True
logging.getLogger("anthropic._base_client").disabled = True
logging.getLogger("langdspy").disabled = True
# logging.getLogger("langdspy").disabled = True

import langdspy
import httpx
@@ -105,24 +105,24 @@ def evaluate_model(model, X, y):
X_test = dataset['test']['X']
y_test = dataset['test']['y']

model = ProductSlugGenerator(n_jobs=4, print_prompt=False)
model = ProductSlugGenerator(n_jobs=1, print_prompt=True)
# model.generate_slug.set_model_kwargs({'print_prompt': True})

before_test_accuracy = None
if os.path.exists(output_path):
model.load(output_path)
else:
input("Hit enter to evaluate the untrained model...")
# input("Hit enter to evaluate the untrained model...")
before_test_accuracy = evaluate_model(model, X_test, y_test)
print(f"Before Training Accuracy: {before_test_accuracy}")

input("Hit enter to train the model...")
model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=3, n_iter=500)
# input("Hit enter to train the model...")
# model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=3, n_iter=500)

input("Hit enter to evaluate the trained model...")
# Evaluate the model on the test set
test_accuracy = evaluate_model(model, X_test, y_test)
print(f"Before Training Accuracy: {before_test_accuracy}")
print(f"After Training Accuracy: {test_accuracy}")
# input("Hit enter to evaluate the trained model...")
# # Evaluate the model on the test set
# test_accuracy = evaluate_model(model, X_test, y_test)
# print(f"Before Training Accuracy: {before_test_accuracy}")
# print(f"After Training Accuracy: {test_accuracy}")

model.save(output_path)
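
With these edits the example becomes a single-threaded, prompt-printing evaluation run: it loads a previously trained model when one exists, otherwise it only measures untrained accuracy, and every training step is commented out. Restoring training amounts to re-enabling the commented lines, roughly as in this sketch (slug_similarity, llm, the dataset splits, evaluate_model, and output_path are all defined earlier in the same script):

# Sketch of the training path this commit disables in the example script.
# It simply mirrors the commented-out lines above.
model = ProductSlugGenerator(n_jobs=4, print_prompt=False)  # parallel, quiet run
model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=3, n_iter=500)
test_accuracy = evaluate_model(model, X_test, y_test)
print(f"After Training Accuracy: {test_accuracy}")
model.save(output_path)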
120 changes: 71 additions & 49 deletions langdspy/field_descriptors.py
@@ -11,15 +11,13 @@ class FieldDescriptor:
def __init__(self, name:str, desc: str, formatter: Optional[Callable[[Any], Any]] = None, transformer: Optional[Callable[[Any], Any]] = None, validator: Optional[Callable[[Any], Any]] = None, **kwargs):
assert "⏎" not in name, "Field name cannot contain newline character"
assert ":" not in name, "Field name cannot contain colon character"

self.name = name
self.desc = desc
self.formatter = formatter
self.transformer = transformer
self.validator = validator
self.kwargs = kwargs


def format_value(self, value: Any) -> Any:
if self.formatter:
return self.formatter(value, self.kwargs)
@@ -39,81 +37,105 @@ def validate_value(self, input: Input, value: Any) -> bool:
return True

class HintField(FieldDescriptor):
HINT_TOKEN = "💡"

HINT_TOKEN_OPENAI = "💡"
HINT_TOKEN_ANTHROPIC = None
def __init__(self, desc: str, formatter: Optional[Callable[[Any], Any]] = None, transformer: Optional[Callable[[Any], Any]] = None, validator: Optional[Callable[[Any], Any]] = None, **kwargs):
# Provide a default value for the name parameter, such as an empty string
super().__init__("", desc, formatter, transformer, validator, **kwargs)

def format_prompt_description(self):
return f"{self.HINT_TOKEN} {self.desc}"


def format_prompt_description(self):
return f"{self.HINT_TOKEN} {self.desc}"

def _start_format_openai(self):
return f"{self.HINT_TOKEN_OPENAI}"
def _start_format_anthropic(self):
return f"<hint>"
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()} {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</hint>"

class InputField(FieldDescriptor):
START_TOKEN = "✅"

def _start_format(self):
return f"{self.START_TOKEN}{self.name}"

def format_prompt_description(self):
return f"{self._start_format()}: {self.desc}"

def format_prompt_value(self, value):
START_TOKEN_OPENAI = "✅"
START_TOKEN_ANTHROPIC = None
def _start_format_openai(self):
return f"{self.START_TOKEN_OPENAI}{self.name}"
def _start_format_anthropic(self):
return f"<{self.name}>"
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"
def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
return f"{self._start_format()}: {value}"
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"

class InputFieldList(InputField):
def format_prompt_description(self):
return f"{self._start_format()}: {self.desc}"

def format_prompt_value(self, value):
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"
def format_prompt_value(self, value, llm_type: str):
res = ""
if len(value) >= 1:
if llm_type == "anthropic":
res += f"<{self.name}>\n"
for i, value in enumerate(value):
if i > 0:
res += "\n"
value = self.format_value(value)
res += f"{self.START_TOKEN} [{i}]: {value}"
if llm_type == "openai":
res += f"{self.START_TOKEN_OPENAI} [{i}]: {value}"
elif llm_type == "anthropic":
res += f"<item>{value}</item>"
if llm_type == "anthropic":
res += f"\n</{self.name}>"
else:
res += f"{self._start_format()}: NO VALUES SPECIFIED"


if llm_type == "openai":
res += f"{self._start_format_openai()}: NO VALUES SPECIFIED"
elif llm_type == "anthropic":
res += f"{self._start_format_anthropic()}NO VALUES SPECIFIED</{self.name}>"
return res

class OutputField(FieldDescriptor):
START_TOKEN = "🔑"

def _start_format(self):
return f"{self.START_TOKEN}{self.name}"

def format_prompt_description(self):
return f"{self._start_format()}: {self.desc}"

def format_prompt_value(self, value):
START_TOKEN_OPENAI = "🔑"
START_TOKEN_ANTHROPIC = None
def _start_format_openai(self):
return f"{self.START_TOKEN_OPENAI}{self.name}"
def _start_format_anthropic(self):
return f"<{self.name}>"
def format_prompt_description(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}: {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{self.desc}</{self.name}>"
def format_prompt_value(self, value, llm_type: str):
value = self.format_value(value)
return f"{self._start_format()}: {value}"

def format_prompt(self):
return f"{self._start_format()}:"
if llm_type == "openai":
return f"{self._start_format_openai()}: {value}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}{value}</{self.name}>"
def format_prompt(self, llm_type: str):
if llm_type == "openai":
return f"{self._start_format_openai()}:"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}</{self.name}>"

class OutputFieldEnum(OutputField):
def __init__(self, name: str, desc: str, enum: Enum, **kwargs):
kwargs['enum'] = enum

if not 'transformer' in kwargs:
kwargs['transformer'] = transformers.as_enum

if not 'validator' in kwargs:
kwargs['validator'] = validators.is_one_of
kwargs['choices'] = [e.name for e in enum]

super().__init__(name, desc, **kwargs)

def format_prompt_description(self):
def format_prompt_description(self, llm_type: str):
enum = self.kwargs.get('enum')
choices_str = ", ".join([e.name for e in enum])
return f"{self._start_format()}: One of: {choices_str} - {self.desc}"
if llm_type == "openai":
return f"{self._start_format_openai()}: One of: {choices_str} - {self.desc}"
elif llm_type == "anthropic":
return f"{self._start_format_anthropic()}One of: <choices>{choices_str}</choices> - {self.desc}</{self.name}>"
81 changes: 0 additions & 81 deletions langdspy/lcel_logger.py

This file was deleted.

22 changes: 20 additions & 2 deletions langdspy/prompt_runners.py
@@ -4,6 +4,8 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator, Extra, PrivateAttr
from langchain_core.pydantic_v1 import validator
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from typing import Any, Dict, List, Type, Optional, Callable
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -58,11 +60,25 @@ def check_template(

def set_model_kwargs(self, model_kwargs):
self.model_kwargs.update(model_kwargs)

def _determine_llm_type(self, llm):
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
return 'openai'
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
return 'anthropic'
else:
return 'openai' # Default to OpenAI if model type cannot be determined


def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[RunnableConfig] = {}):
total_max_tries = max_tries

hard_fail = config.get('hard_fail', False)
llm_type = config.get('llm_type') # Get the LLM type from the configuration
if llm_type is None:
llm_type = self._determine_llm_type(config['llm']) # Auto-detect the LLM type if not specified

logger.debug(f"LLM type: {llm_type}")

res = {}

@@ -88,7 +104,8 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
# logger.debug(f"Print prompt {print_prompt} kwargs print prompt {kwargs.get('print_prompt')} config print prompt {config.get('print_prompt')}")

# logger.debug(f"PromptRunner invoke with trained_state {trained_state}")
invoke_args = {**input, 'print_prompt': print_prompt, **kwargs, 'trained_state': trained_state, 'use_training': config.get('use_training', True)}
invoke_args = {**input, 'print_prompt': print_prompt, **kwargs, 'trained_state': trained_state, 'use_training': config.get('use_training', True), 'llm_type': llm_type}

# logger.debug(f"Invoke args: {invoke_args}")
res = chain.invoke(invoke_args, config=config)
except Exception as e:
@@ -109,7 +126,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
# Use the parse_output_to_fields method from the PromptStrategy
parsed_output = {}
try:
parsed_output = self.template.parse_output_to_fields(res)
parsed_output = self.template.parse_output_to_fields(res, llm_type)
except Exception as e:
import traceback
traceback.print_exc()
@@ -172,6 +189,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output:
# logger.debug(f"Template: {self.template}")
# logger.debug(f"Config: {config}")

chain = (
self.template
| config['llm']
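
Taken together, the prompt_runners.py changes make the runner resolve which prompt dialect to emit: an explicit 'llm_type' key in the RunnableConfig takes precedence, otherwise the type is inferred from the configured LLM client, and the result is threaded into both prompt formatting and output parsing. A minimal, self-contained sketch of that resolution order (assuming langchain_openai and langchain_anthropic are installed; the model name and dummy API key are placeholders, not values from this commit):

# Mirrors PromptRunner._determine_llm_type from the diff above.
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic

def determine_llm_type(llm) -> str:
    if isinstance(llm, ChatOpenAI):
        return "openai"
    if isinstance(llm, ChatAnthropic):
        return "anthropic"
    return "openai"  # unrecognized clients fall back to OpenAI-style prompts

# Resolution order used inside _invoke_with_retries: an explicit 'llm_type'
# config entry wins; otherwise the client class decides.
config = {"llm": ChatAnthropic(model="claude-3-haiku-20240307", anthropic_api_key="test")}
llm_type = config.get("llm_type") or determine_llm_type(config["llm"])
print(llm_type)  # -> anthropic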
(Diffs for the remaining 5 changed files are not shown.)
