From 870807f2a12298717dede3a6f2e1d4a0d6191f57 Mon Sep 17 00:00:00 2001 From: Uche Ogbuji Date: Fri, 14 Jul 2023 10:48:05 -0600 Subject: [PATCH] [#16] Initial swipe at test/test_embedding_helper.py --- test/test_embedding_helper.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/test/test_embedding_helper.py b/test/test_embedding_helper.py index 2dbb6f3..f56c56f 100644 --- a/test/test_embedding_helper.py +++ b/test/test_embedding_helper.py @@ -10,12 +10,12 @@ import pytest from ogbujipt import embedding_helper -from ogbujipt.embedding_helper import qdrant_init_embedding_db, \ - qdrant_add_collection, qdrant_upsert_collection +from ogbujipt.embedding_helper import qdrant_collection from ogbujipt.text_helper import text_splitter embedding_helper.QDRANT_AVAILABLE = True + @pytest.fixture def CORRECT_STRING(): return 'And the secret thing in its heaving\nThreatens with iron mask\nThe last lighted torch of the century…' @@ -41,36 +41,27 @@ def test_embed_poem(mocker, COME_THUNDER_POEM, CORRECT_STRING): embedding_helper.models.VectorParams.side_effect = [mock_vparam] mocker.patch('ogbujipt.embedding_helper.QdrantClient') - client = qdrant_init_embedding_db() + coll = qdrant_collection(collection_name, embedding) - #client.count.side_effect = ['count=0'] - client.count.side_effect = lambda collection_name: 'count=0' - client = qdrant_add_collection( - client, - chunks, - embedding, - collection_name - ) - client.recreate_collection.assert_called_once_with( + # client.count.side_effect = ['count=0'] + coll.db.count.side_effect = lambda collection_name: 'count=0' + coll.add(chunks, collection_name) + coll.db.recreate_collection.assert_called_once_with( collection_name='test_collection', vectors_config=mock_vparam ) - + embedding.encode.assert_called_with(CORRECT_STRING) # Test update/insert into the DB mock_pstruct = object() embedding_helper.models.PointStruct.side_effect = lambda id=None, vector=None, payload=None: mock_pstruct - - client.count.reset_mock() - client = qdrant_upsert_collection( - client, - chunks, - embedding, - collection_name - ) - client.upsert.assert_called_with( + coll.db.count.reset_mock() + coll.upsert(chunks) + + # XXX: Add test with metadata + coll.db.upsert.assert_called_with( collection_name=collection_name, points=[mock_pstruct] )