Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better enums 2 #21

Merged
merged 7 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,48 +1,45 @@
name: Run Tests

on:
pull_request:
branches: [main]

jobs:
test:
runs-on: ubuntu-latest
env:
OPENAI_API_KEY: "FAKE"

steps:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9

python-version: '3.11'
- name: Install Poetry
run: |
curl -sSL https://install.python-poetry.org | python3 -

- name: Configure Poetry
run: |
echo "$HOME/.local/bin" >> $GITHUB_PATH
poetry config virtualenvs.in-project true

- name: Set up cache
uses: actions/cache@v2
with:
path: .venv
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}

- name: Install dependencies
run: |
poetry install

- name: Run tests with coverage
run: |
poetry run test

poetry run pytest --cov=langdspy --cov-report=html tests/
- name: Check test results
if: failure()
run: |
echo "Tests failed. Please fix the failing tests."
exit 1
- name: Generate coverage report
run: |
poetry run coverage

poetry run coverage html
- name: Upload coverage report
uses: actions/upload-artifact@v2
with:
Expand Down
2 changes: 2 additions & 0 deletions langdspy/data_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def normalize_enum_value(val: str) -> str:
return val.replace(" ", "_").replace("-", "_").upper()
24 changes: 16 additions & 8 deletions langdspy/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum
from langchain_core.documents import Document
import re
from .data_helper import normalize_enum_value

def as_bool(value: str, kwargs: Dict[str, Any]) -> bool:
value = re.sub(r'[^\w\s]', '', value)
Expand All @@ -18,15 +19,22 @@ def as_json(val: str, kwargs: Dict[str, Any]) -> Any:

def as_enum(val: str, kwargs: Dict[str, Any]) -> Enum:
enum_class = kwargs['enum']
try:
return enum_class[val.upper()]
except KeyError:
raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration")
normalized_val = normalize_enum_value(val)
for member in enum_class:
if normalize_enum_value(member.name) == normalized_val:
return member
raise ValueError(f"{val} is not a valid member of the {enum_class.__name__} enumeration")

def as_enum_list(val: str, kwargs: Dict[str, Any]) -> List[Enum]:
enum_class = kwargs['enum']
values = [v.strip() for v in val.split(",")]
try:
return [enum_class[v.upper()] for v in values]
except KeyError as e:
raise ValueError(f"{e.args[0]} is not a valid member of the {enum_class.__name__} enumeration")
result = []
for v in values:
normalized_val = normalize_enum_value(v)
for member in enum_class:
if normalize_enum_value(member.name) == normalized_val:
result.append(member)
break
else:
raise ValueError(f"{v} is not a valid member of the {enum_class.__name__} enumeration")
return result
10 changes: 5 additions & 5 deletions langdspy/validators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import logging

from .data_helper import normalize_enum_value

logger = logging.getLogger("langdspy")

Expand Down Expand Up @@ -34,8 +34,8 @@ def is_one_of(input, output_val, kwargs) -> bool:

try:
if not kwargs.get('case_sensitive', False):
choices = [c.lower() for c in kwargs['choices']]
output_val = output_val.lower()
choices = [normalize_enum_value(c) for c in kwargs['choices']]
output_val = normalize_enum_value (output_val)

# logger.debug(f"Checking if {output_val} is one of {choices}")
for choice in choices:
Expand All @@ -61,8 +61,8 @@ def is_subset_of(input, output_val, kwargs) -> bool:
try:
values = [v.strip() for v in output_val.split(",")]
if not kwargs.get('case_sensitive', False):
choices = [c.lower() for c in kwargs['choices']]
values = [v.lower() for v in values]
choices = [normalize_enum_value(c) for c in kwargs['choices']]
values = [normalize_enum_value(v) for v in values]
for value in values:
if value not in choices:
return False
Expand Down
46 changes: 45 additions & 1 deletion tests/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,51 @@
dotenv.load_dotenv()
import pytest
from unittest.mock import MagicMock
from examples.amazon.generate_slugs import ProductSlugGenerator, slug_similarity, get_llm
import langdspy
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


class GenerateSlug(langdspy.PromptSignature):
hint_slug = langdspy.HintField(desc="Generate a URL-friendly slug based on the provided H1, title, and product copy. The slug should be lowercase, use hyphens to separate words, and not exceed 50 characters.")

h1 = langdspy.InputField(name="H1", desc="The H1 heading of the product page")
title = langdspy.InputField(name="Title", desc="The title of the product page")
product_copy = langdspy.InputField(name="Product Copy", desc="The product description or copy")

slug = langdspy.OutputField(name="Slug", desc="The generated URL-friendly slug")

class ProductSlugGenerator(langdspy.Model):
generate_slug = langdspy.PromptRunner(template_class=GenerateSlug, prompt_strategy=langdspy.DefaultPromptStrategy)

def invoke(self, input_dict, config):
h1 = input_dict['h1']
title = input_dict['title']
product_copy = input_dict['product_copy']

slug_res = self.generate_slug.invoke({'h1': h1, 'title': title, 'product_copy': product_copy}, config=config)

return slug_res.slug


def cosine_similarity_tfidf(true_slugs, predicted_slugs):
# Convert slugs to lowercase
true_slugs = [slug.lower() for slug in true_slugs]
predicted_slugs = [slug.lower() for slug in predicted_slugs]

# for i in range(len(true_slugs)):
# print(f"Actual Slug: {true_slugs[i]} Predicted: {predicted_slugs[i]}")

vectorizer = TfidfVectorizer()
true_vectors = vectorizer.fit_transform(true_slugs)
predicted_vectors = vectorizer.transform(predicted_slugs)
similarity_scores = cosine_similarity(true_vectors, predicted_vectors)
return similarity_scores.diagonal()

def slug_similarity(X, true_slugs, predicted_slugs):
similarity_scores = cosine_similarity_tfidf(true_slugs, predicted_slugs)
average_similarity = sum(similarity_scores) / len(similarity_scores)
return average_similarity

@pytest.fixture
def model():
Expand Down
19 changes: 14 additions & 5 deletions tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,31 @@ def test_as_enum():
class Fruit(Enum):
APPLE = 1
BANANA = 2
CHERRY_PIE = 3
DURIAN_FRUIT = 4

assert transformers.as_enum("APPLE", {"enum": Fruit}) == Fruit.APPLE
assert transformers.as_enum("BANANA", {"enum": Fruit}) == Fruit.BANANA
assert transformers.as_enum("cherry pie", {"enum": Fruit}) == Fruit.CHERRY_PIE
assert transformers.as_enum("Durian-Fruit", {"enum": Fruit}) == Fruit.DURIAN_FRUIT
assert transformers.as_enum("Durian_Fruit", {"enum": Fruit}) == Fruit.DURIAN_FRUIT
assert transformers.as_enum("Durian Fruit", {"enum": Fruit}) == Fruit.DURIAN_FRUIT

with pytest.raises(ValueError):
transformers.as_enum("CHERRY", {"enum": Fruit})
transformers.as_enum("MANGO", {"enum": Fruit})

def test_as_enum_list():
class Fruit(Enum):
APPLE = 1
BANANA = 2
CHERRY = 3
CHERRY_PIE = 3
DURIAN_FRUIT = 4

assert transformers.as_enum_list("APPLE", {"enum": Fruit}) == [Fruit.APPLE]
assert transformers.as_enum_list("BANANA, CHERRY", {"enum": Fruit}) == [Fruit.BANANA, Fruit.CHERRY]
assert transformers.as_enum_list("APPLE,BANANA,CHERRY", {"enum": Fruit}) == [Fruit.APPLE, Fruit.BANANA, Fruit.CHERRY]
assert transformers.as_enum_list("BANANA, CHERRY PIE", {"enum": Fruit}) == [Fruit.BANANA, Fruit.CHERRY_PIE]
assert transformers.as_enum_list("APPLE,BANANA,CHERRY PIE", {"enum": Fruit}) == [Fruit.APPLE, Fruit.BANANA, Fruit.CHERRY_PIE]
assert transformers.as_enum_list("Durian-Fruit, cherry pie", {"enum": Fruit}) == [Fruit.DURIAN_FRUIT, Fruit.CHERRY_PIE]
assert transformers.as_enum_list("Durian Fruit, cherry_pie", {"enum": Fruit}) == [Fruit.DURIAN_FRUIT, Fruit.CHERRY_PIE]

with pytest.raises(ValueError):
transformers.as_enum_list("DURIAN", {"enum": Fruit})
transformers.as_enum_list("MANGO", {"enum": Fruit})
13 changes: 8 additions & 5 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,27 @@ def test_is_one_of():
assert validators.is_one_of({}, 'cherry', {'choices': ['apple', 'banana']}) == False
assert validators.is_one_of({}, 'APPLE', {'choices': ['apple', 'banana'], 'case_sensitive': False}) == True
assert validators.is_one_of({}, 'none', {'choices': ['apple', 'banana'], 'none_ok': True}) == True
assert validators.is_one_of({}, 'apple pie', {'choices': ['apple_pie', 'banana split'], 'case_sensitive': False}) == True
assert validators.is_one_of({}, 'Apple-Pie', {'choices': ['apple_pie', 'banana-split'], 'case_sensitive': False}) == True

with pytest.raises(ValueError):
validators.is_one_of({}, 'apple', {})

def test_is_subset_of():
choices = ["apple", "banana", "cherry"]
choices = ["apple", "banana", "cherry_pie", "durian-fruit"]

assert validators.is_subset_of({}, "apple", {"choices": choices}) == True
assert validators.is_subset_of({}, "apple,banana", {"choices": choices}) == True
assert validators.is_subset_of({}, "apple, banana, cherry", {"choices": choices}) == True
assert validators.is_subset_of({}, "apple, banana, cherry_pie", {"choices": choices}) == True
assert validators.is_subset_of({}, "APPLE", {"choices": choices, "case_sensitive": False}) == True
assert validators.is_subset_of({}, "APPLE,BANANA", {"choices": choices, "case_sensitive": False}) == True
assert validators.is_subset_of({}, "Durian-Fruit, Cherry Pie", {"choices": choices, "case_sensitive": False}) == True

assert validators.is_subset_of({}, "durian", {"choices": choices}) == False
assert validators.is_subset_of({}, "apple,durian", {"choices": choices}) == False
assert validators.is_subset_of({}, "mango", {"choices": choices}) == False
assert validators.is_subset_of({}, "apple,mango", {"choices": choices}) == False

assert validators.is_subset_of({}, "none", {"choices": choices, "none_ok": True}) == True
assert validators.is_subset_of({}, "apple,none", {"choices": choices, "none_ok": True}) == False

with pytest.raises(ValueError):
validators.is_subset_of({}, "apple", {})
validators.is_subset_of({}, "apple", {})
Loading