Merge pull request #7 from aelaguiz/nshot_train
Nshot train
Showing 22 changed files with 1,304 additions and 261 deletions.
@@ -0,0 +1,50 @@
name: Run Tests

on:
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.9

      - name: Install Poetry
        run: |
          curl -sSL https://install.python-poetry.org | python3 -
      - name: Configure Poetry
        run: |
          echo "$HOME/.local/bin" >> $GITHUB_PATH
          poetry config virtualenvs.in-project true
      - name: Set up cache
        uses: actions/cache@v2
        with:
          path: .venv
          key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}

      - name: Install dependencies
        run: |
          poetry install
      - name: Run tests with coverage
        run: |
          poetry run test
      - name: Generate coverage report
        run: |
          poetry run coverage
      - name: Upload coverage report
        uses: actions/upload-artifact@v2
        with:
          name: coverage-report
          path: htmlcov/
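Note that the `poetry run test` and `poetry run coverage` steps above resolve only if the project declares matching script entry points. A minimal sketch of the assumed pyproject.toml section (the module and function names here are hypothetical, not taken from this repository):

[tool.poetry.scripts]
# Hypothetical entry points: each maps a command name to a "module:callable" target.
test = "scripts.run_tests:main"
coverage = "scripts.run_coverage:main"

With entries like these, `poetry run test` invokes the `main()` callable in `scripts/run_tests.py` inside the project virtualenv.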
@@ -10,3 +10,9 @@ cache*
sources
*.pickle
logs
__pycache__/
*.pyc
.pytest_cache/
htmlcov/
.coverage
Two large diffs are not rendered by default.
@@ -0,0 +1,123 @@
import sys
sys.path.append('.')
sys.path.append('langdspy')

import os
import dotenv
dotenv.load_dotenv()

import logging

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logging.getLogger("httpx").disabled = True
logging.getLogger("openai").disabled = True
logging.getLogger("httpcore.connection").disabled = True
logging.getLogger("httpcore.http11").disabled = True
logging.getLogger("openai._base_client").disabled = True
logging.getLogger("paramiko.transport").disabled = True
logging.getLogger("anthropic._base_client").disabled = True
# logging.getLogger("langdspy").disabled = True

import langdspy
import httpx
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_anthropic import ChatAnthropic
from sklearn.metrics import accuracy_score
import json
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def get_llm():
    FAST_OPENAI_MODEL = os.getenv("FAST_OPENAI_MODEL")
    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
    # The env var arrives as a string; coerce to float for the client constructors.
    OPENAI_TEMPERATURE = float(os.getenv("OPENAI_TEMPERATURE", "0.0"))
    FAST_MODEL_PROVIDER = os.getenv("FAST_MODEL_PROVIDER")
    FAST_ANTHROPIC_MODEL = os.getenv("FAST_ANTHROPIC_MODEL")

    if FAST_MODEL_PROVIDER.lower() == "anthropic":
        _fast_llm = ChatAnthropic(model_name=FAST_ANTHROPIC_MODEL, temperature=OPENAI_TEMPERATURE, anthropic_api_key=ANTHROPIC_API_KEY)
    else:
        _fast_llm = ChatOpenAI(model_name=FAST_OPENAI_MODEL, temperature=OPENAI_TEMPERATURE, timeout=httpx.Timeout(15.0, read=60.0, write=10.0, connect=3.0), max_retries=2)

    return _fast_llm


class GenerateSlug(langdspy.PromptSignature):
    hint_slug = langdspy.HintField(desc="Generate a URL-friendly slug based on the provided H1, title, and product copy. The slug should be lowercase, use hyphens to separate words, and not exceed 50 characters.")

    h1 = langdspy.InputField(name="H1", desc="The H1 heading of the product page")
    title = langdspy.InputField(name="Title", desc="The title of the product page")
    product_copy = langdspy.InputField(name="Product Copy", desc="The product description or copy")

    slug = langdspy.OutputField(name="Slug", desc="The generated URL-friendly slug")


class ProductSlugGenerator(langdspy.Model):
    generate_slug = langdspy.PromptRunner(template_class=GenerateSlug, prompt_strategy=langdspy.DefaultPromptStrategy)

    def invoke(self, input, config):
        h1 = input['h1']
        title = input['title']
        product_copy = input['product_copy']

        slug_res = self.generate_slug.invoke({'h1': h1, 'title': title, 'product_copy': product_copy}, config=config)

        return slug_res.slug


def cosine_similarity_tfidf(true_slugs, predicted_slugs):
    # Convert slugs to lowercase
    true_slugs = [slug.lower() for slug in true_slugs]
    predicted_slugs = [slug.lower() for slug in predicted_slugs]

    # for i in range(len(true_slugs)):
    #     print(f"Actual Slug: {true_slugs[i]} Predicted: {predicted_slugs[i]}")

    vectorizer = TfidfVectorizer()
    true_vectors = vectorizer.fit_transform(true_slugs)
    predicted_vectors = vectorizer.transform(predicted_slugs)
    similarity_scores = cosine_similarity(true_vectors, predicted_vectors)
    return similarity_scores.diagonal()


def slug_similarity(true_slugs, predicted_slugs):
    similarity_scores = cosine_similarity_tfidf(true_slugs, predicted_slugs)
    average_similarity = sum(similarity_scores) / len(similarity_scores)
    return average_similarity


def evaluate_model(model, X, y):
    predicted_slugs = model.predict(X, llm)
    accuracy = slug_similarity(y, predicted_slugs)
    return accuracy


llm = get_llm()

if __name__ == "__main__":
    output_path = sys.argv[1]
    dataset_file = "data/amazon_products_split.json"
    with open(dataset_file, 'r') as file:
        dataset = json.load(file)

    X_train = dataset['train']['X']
    y_train = dataset['train']['y']
    X_test = dataset['test']['X']
    y_test = dataset['test']['y']

    model = ProductSlugGenerator(n_jobs=4, print_prompt=True)

    before_test_accuracy = None
    if os.path.exists(output_path):
        model.load(output_path)
    else:
        input("Hit enter to evaluate the untrained model...")
        before_test_accuracy = evaluate_model(model, X_test, y_test)
        print(f"Before Training Accuracy: {before_test_accuracy}")

    input("Hit enter to train the model...")
    model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=2, n_iter=100)

    input("Hit enter to evaluate the trained model...")
    # Evaluate the model on the test set
    test_accuracy = evaluate_model(model, X_test, y_test)
    print(f"Before Training Accuracy: {before_test_accuracy}")
    print(f"After Training Accuracy: {test_accuracy}")

    model.save(output_path)
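For reference, here is a minimal, self-contained sketch of the TF-IDF scoring that `slug_similarity` implements above; the slug pairs are made up for illustration:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Made-up slug pairs, illustration only.
true_slugs = ["wireless-noise-cancelling-headphones", "stainless-steel-water-bottle"]
predicted_slugs = ["noise-cancelling-wireless-headphones", "steel-water-bottle"]

vectorizer = TfidfVectorizer()
true_vectors = vectorizer.fit_transform(true_slugs)
predicted_vectors = vectorizer.transform(predicted_slugs)

# cosine_similarity returns the full pairwise matrix; diagonal() keeps only
# the score of true_slugs[i] against predicted_slugs[i].
scores = cosine_similarity(true_vectors, predicted_vectors).diagonal()
print(scores)         # first pair shares every token (1.0); second drops "stainless" (~0.87)
print(scores.mean())  # the average that model.fit maximizes via score_func

Because TF-IDF ignores token order, a predicted slug with the same words in a different order still scores 1.0; the metric rewards vocabulary overlap rather than exact string match.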
@@ -0,0 +1,61 @@
import json
import random
import sys
from sklearn.model_selection import train_test_split
from urllib.parse import urlparse

input_file = sys.argv[1]
output_file = sys.argv[2]

# Read the JSONL file
with open(input_file, 'r') as file:
    data = [json.loads(line) for line in file]

# Deduplicate based on the title
unique_data = {item['title']: item for item in data}.values()

# Process each item
X = []
y = []
for item in unique_data:
    # Trim strings
    item['title'] = item['title'].strip() if item.get('title') else ''
    item['h1'] = item['h1'].strip() if item.get('h1') else ''
    item['product_copy'] = ' '.join([copy.strip() for copy in item.get('product_copy', [])])
    url = item['url']
    parsed_url = urlparse(url)
    path_parts = parsed_url.path.split('/')
    print(f"URL: {url} Path: {path_parts}")
    try:
        if "dp" == path_parts[2]:
            item['slug'] = path_parts[1]
            item['product_id'] = path_parts[3]
            X.append({
                'title': item['title'],
                'h1': item['h1'],
                'product_copy': item['product_copy']
            })
            y.append(item['slug'])
        elif "dp" == path_parts[1]:
            item['product_id'] = path_parts[2]
            item['slug'] = None
        else:
            print(f"Unknown URL format: {url}")
    except IndexError:
        # URLs with fewer path segments than expected are skipped.
        print(f"Failed to parse URL: {url}")
        continue

# Split the data into train/test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a dictionary to store the datasets
datasets = {
    'train': {'X': X_train, 'y': y_train},
    'test': {'X': X_test, 'y': y_test}
}

# Save the datasets to a single JSON file
with open(output_file, 'w') as file:
    json.dump(datasets, file, indent=2)

print(f"Datasets saved to {output_file}")