Skip to content

Commit

Permalink
[#16] Initial swipe at test/test_embedding_helper.py
Browse files Browse the repository at this point in the history
  • Loading branch information
uogbuji committed Jul 14, 2023
1 parent 86af70b commit 870807f
Showing 1 changed file with 13 additions and 22 deletions.
35 changes: 13 additions & 22 deletions test/test_embedding_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import pytest

from ogbujipt import embedding_helper
from ogbujipt.embedding_helper import qdrant_init_embedding_db, \
qdrant_add_collection, qdrant_upsert_collection
from ogbujipt.embedding_helper import qdrant_collection
from ogbujipt.text_helper import text_splitter

embedding_helper.QDRANT_AVAILABLE = True


@pytest.fixture
def CORRECT_STRING():
return 'And the secret thing in its heaving\nThreatens with iron mask\nThe last lighted torch of the century…'
Expand All @@ -41,36 +41,27 @@ def test_embed_poem(mocker, COME_THUNDER_POEM, CORRECT_STRING):
embedding_helper.models.VectorParams.side_effect = [mock_vparam]
mocker.patch('ogbujipt.embedding_helper.QdrantClient')

client = qdrant_init_embedding_db()
coll = qdrant_collection(collection_name, embedding)

#client.count.side_effect = ['count=0']
client.count.side_effect = lambda collection_name: 'count=0'
client = qdrant_add_collection(
client,
chunks,
embedding,
collection_name
)
client.recreate_collection.assert_called_once_with(
# client.count.side_effect = ['count=0']
coll.db.count.side_effect = lambda collection_name: 'count=0'
coll.add(chunks, collection_name)
coll.db.recreate_collection.assert_called_once_with(
collection_name='test_collection',
vectors_config=mock_vparam
)

embedding.encode.assert_called_with(CORRECT_STRING)

# Test update/insert into the DB
mock_pstruct = object()
embedding_helper.models.PointStruct.side_effect = lambda id=None, vector=None, payload=None: mock_pstruct

client.count.reset_mock()
client = qdrant_upsert_collection(
client,
chunks,
embedding,
collection_name
)

client.upsert.assert_called_with(
coll.db.count.reset_mock()
coll.upsert(chunks)

# XXX: Add test with metadata
coll.db.upsert.assert_called_with(
collection_name=collection_name,
points=[mock_pstruct]
)

0 comments on commit 870807f

Please sign in to comment.