Skip to content

Commit

Permalink
lintering
Browse files Browse the repository at this point in the history
  • Loading branch information
zimventures committed Jul 11, 2023
1 parent 1690737 commit 2429f71
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 12 deletions.
3 changes: 2 additions & 1 deletion chirps/embedding/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class Service(models.TextChoices):
OPEN_AI = 'OA', 'OpenAI'
LOCAL = 'LH', 'Locally Hosted: NOT IMPLEMENTED'

def get_provider_from_service_name(name: str) -> BaseEmbeddingProvider:
@classmethod
def get_provider_from_service_name(cls, name: str) -> BaseEmbeddingProvider:
"""Get the provider for the service."""
if name == Embedding.Service.OPEN_AI:
return OpenAIEmbeddingProvider()
Expand Down
1 change: 1 addition & 0 deletions chirps/embedding/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Base implementation for embedding providers."""
from abc import ABC
from typing import Any

Expand Down
5 changes: 3 additions & 2 deletions chirps/embedding/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""OpenAI implementation for embedding provider."""
import logging
from typing import Any

Expand All @@ -22,8 +23,8 @@ def embed(self, user: User, model: str, text: str) -> Any:
# Generate the embedding
try:
response = openai.Embedding.create(model=model, input=text)
except openai.error.InvalidRequestError as e:
raise EmbeddingError(f'Embedding failure: {str(e)}') from e
except openai.error.InvalidRequestError as err:
raise EmbeddingError(f'Embedding failure: {str(err)}') from err

logger.debug(
'Embedding complete',
Expand Down
12 changes: 7 additions & 5 deletions chirps/embedding/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Tests for the embedding app."""
# pylint: disable=consider-using-with
import json
from pathlib import Path
from unittest.mock import patch
Expand Down Expand Up @@ -54,7 +56,7 @@ def test_create_invalid(self):
'openai.Embedding.create',
side_effect=openai.error.InvalidRequestError('The model `invalid-model-001` does not exist', ''),
)
def test_create_invalid_openai_model(self, mock_openai_embedding_create):
def test_create_invalid_openai_model(self, _mock_openai_embedding_create):
"""Pass in a junk model name to the OpenAI service."""
response = self.client.get(
reverse('embedding_create'), {'text': 'test text', 'model': 'invalid-model-001', 'service': 'OA'}
Expand All @@ -67,9 +69,9 @@ def test_create_invalid_openai_model(self, mock_openai_embedding_create):

@patch(
'openai.Embedding.create',
return_value=json.loads(open(f'{fixture_path}/openai_embedding_create_mock.json').read()),
return_value=json.loads(open(f'{fixture_path}/openai_embedding_create_mock.json', encoding='utf-8').read()),
)
def test_create_valid(self, mock_openai_embedding_create):
def test_create_valid(self, _mock_openai_embedding_create):
"""Test creating a valid embedding."""
# Verify there are no embeddings in the database
self.assertEqual(0, Embedding.objects.all().count())
Expand All @@ -93,9 +95,9 @@ def test_create_valid(self, mock_openai_embedding_create):

@patch(
'openai.Embedding.create',
return_value=json.loads(open(f'{fixture_path}/openai_embedding_create_mock.json').read()),
return_value=json.loads(open(f'{fixture_path}/openai_embedding_create_mock.json', encoding='utf-8').read()),
)
def test_delete(self, mock_openai_embedding_create):
def test_delete(self, _mock_openai_embedding_create):
"""Test deleting an embedding."""
# Verify there are no embeddings in the database
self.assertEqual(0, Embedding.objects.all().count())
Expand Down
2 changes: 1 addition & 1 deletion chirps/embedding/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from . import views

urlpatterns = [
path('', views.list, name='embedding_list'),
path('', views.show, name='embedding_list'),
path('create/', views.create, name='embedding_create'),
path('delete/<int:embedding_id>/', views.delete, name='embedding_delete'),
]
6 changes: 3 additions & 3 deletions chirps/embedding/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def create(request):
embed_result = service.embed(
user=request.user, model=form.cleaned_data['model'], text=form.cleaned_data['text']
)
except EmbeddingError as e:
return HttpResponseBadRequest(json.dumps({'error': str(e)}), status=400)
except EmbeddingError as err:
return HttpResponseBadRequest(json.dumps({'error': str(err)}), status=400)

# Save the embedding result to the database
embedding = Embedding.objects.create(
Expand All @@ -61,7 +61,7 @@ def delete(request, embedding_id):


@login_required
def list(request):
def show(request):
"""List all embeddings."""
# Paginate the number of items returned to the user, defaulting to 25 per page
user_embeddings = Embedding.objects.filter(user=request.user).order_by('id')
Expand Down

0 comments on commit 2429f71

Please sign in to comment.