Skip to content

Commit

Permalink
Fix & added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aelaguiz committed Mar 8, 2024
1 parent a90fc1e commit ddc626b
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
2 changes: 1 addition & 1 deletion langdspy/prompt_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna

while max_tries >= 1:
try:
res = chain.invoke({**input, 'trained_state': config['trained_state'], 'print_prompt': config.get('print_prompt', False)}, config=config)
res = chain.invoke({**input, 'trained_state': config.get('trained_state', None), 'print_prompt': config.get('print_prompt', False)}, config=config)
except Exception as e:
import traceback
traceback.print_exc()
Expand Down
2 changes: 1 addition & 1 deletion scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import subprocess

def test():
subprocess.run(["pytest", "tests/"])
subprocess.run(["pytest", "--tb=long", "tests/"])

def coverage():
subprocess.run(["pytest", "--cov=langdspy", "--cov-report=html", "tests/"])
76 changes: 76 additions & 0 deletions tests/test_model_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# tests/test_generate_slugs.py
import sys
sys.path.append('.')
sys.path.append('langdspy')
import os
import dotenv
dotenv.load_dotenv()
import pytest
from unittest.mock import MagicMock
from examples.amazon.generate_slugs import ProductSlugGenerator, slug_similarity, get_llm

@pytest.fixture
def model():
return ProductSlugGenerator(n_jobs=1, print_prompt=False)

@pytest.fixture
def llm():
return get_llm()

@pytest.fixture
def dataset():
return {
'train': {
'X': [
{'h1': 'Product 1', 'title': 'Title 1', 'product_copy': 'Description 1'},
{'h1': 'Product 2', 'title': 'Title 2', 'product_copy': 'Description 2'}
],
'y': ['product-1', 'product-2']
},
'test': {
'X': [
{'h1': 'Product 3', 'title': 'Title 3', 'product_copy': 'Description 3'},
{'h1': 'Product 4', 'title': 'Title 4', 'product_copy': 'Description 4'}
],
'y': ['product-3', 'product-4']
}
}

def test_invoke_untrained(model, llm, dataset):
input_data = dataset['test']['X'][0]
result = model.invoke(input_data, config={'llm': llm})
assert isinstance(result, str)
assert len(result) <= 50

def test_invoke_trained(model, llm, dataset):
model.fit(dataset['train']['X'], dataset['train']['y'], score_func=slug_similarity, llm=llm, n_examples=1, n_iter=1)
input_data = dataset['test']['X'][0]
result = model.invoke(input_data, config={'llm': llm})
assert isinstance(result, str)
assert len(result) <= 50

def test_predict_untrained(model, llm, dataset):
X_test = dataset['test']['X']
y_test = dataset['test']['y']
predicted_slugs = model.predict(X_test, llm)
assert len(predicted_slugs) == len(y_test)
for slug in predicted_slugs:
assert isinstance(slug, str)
assert len(slug) <= 50

def test_predict_trained(model, llm, dataset):
model.fit(dataset['train']['X'], dataset['train']['y'], score_func=slug_similarity, llm=llm, n_examples=1, n_iter=1)
X_test = dataset['test']['X']
y_test = dataset['test']['y']
predicted_slugs = model.predict(X_test, llm)
assert len(predicted_slugs) == len(y_test)
for slug in predicted_slugs:
assert isinstance(slug, str)
assert len(slug) <= 50

def test_fit(model, llm, dataset):
X_train = dataset['train']['X']
y_train = dataset['train']['y']
model.fit(X_train, y_train, score_func=slug_similarity, llm=llm, n_examples=1, n_iter=1)
assert model.trained_state.examples is not None
assert len(model.trained_state.examples) == 1

0 comments on commit ddc626b

Please sign in to comment.