From f45928af7ae77fc27c4bcc01fd3c047bdbf4063a Mon Sep 17 00:00:00 2001 From: Wei Ouyang Date: Thu, 21 Nov 2024 01:51:07 -0800 Subject: [PATCH] Return pagination --- hypha/VERSION | 2 +- hypha/artifact.py | 66 +++++++++++++++++++++++++++++++++++++---- hypha/core/workspace.py | 23 +++++++++++++- tests/test_artifact.py | 16 ++++++++++ tests/test_server.py | 3 ++ 5 files changed, 103 insertions(+), 7 deletions(-) diff --git a/hypha/VERSION b/hypha/VERSION index 04f91558..906e73fa 100644 --- a/hypha/VERSION +++ b/hypha/VERSION @@ -1,3 +1,3 @@ { - "version": "0.20.39.post11" + "version": "0.20.39.post12" } diff --git a/hypha/artifact.py b/hypha/artifact.py index 193692be..5bd78f00 100644 --- a/hypha/artifact.py +++ b/hypha/artifact.py @@ -196,6 +196,8 @@ async def list_children( mode: str = "AND", limit: int = 100, order_by: str = None, + pagination: bool = False, + silent: bool = False, user_info: self.store.login_optional = Depends(self.store.login_optional), ): """List child artifacts of a specified artifact.""" @@ -212,6 +214,8 @@ async def list_children( keywords=keywords, filters=filters, mode=mode, + pagination=pagination, + silent=silent, context={"user": user_info.model_dump(), "ws": workspace}, ) except KeyError: @@ -1607,6 +1611,7 @@ async def search_by_vector( limit: int = 10, with_payload: bool = True, with_vectors: bool = False, + pagination: bool = False, context: dict = None, ): user_info = UserInfo.model_validate(context["user"]) @@ -1635,6 +1640,16 @@ async def search_by_vector( with_payload=with_payload, with_vectors=with_vectors, ) + if pagination: + count = await self._vectordb_client.count( + collection_name=f"{artifact.workspace}/{artifact.alias}" + ) + return { + "total": count.count, + "items": search_results, + "offset": offset, + "limit": limit, + } return search_results except Exception as e: raise e @@ -1650,6 +1665,7 @@ async def search_by_text( limit: int = 10, with_payload: bool = True, with_vectors: bool = False, + pagination: bool = False, context: dict = None, ): user_info = UserInfo.model_validate(context["user"]) @@ -1676,6 +1692,16 @@ async def search_by_text( with_payload=with_payload, with_vectors=with_vectors, ) + if pagination: + count = await self._vectordb_client.count( + collection_name=f"{artifact.workspace}/{artifact.alias}" + ) + return { + "total": count.count, + "items": search_results, + "offset": offset, + "limit": limit, + } return search_results except Exception as e: raise e @@ -1957,7 +1983,7 @@ async def list_files( self, artifact_id: str, dir_path: str = None, - max_length: int = 1000, + limit: int = 1000, version: str = None, context: dict = None, ): @@ -1999,7 +2025,7 @@ async def list_files( s3_client, s3_config["bucket"], full_path, - max_length=max_length, + max_length=limit, ) return items except Exception as e: @@ -2015,8 +2041,9 @@ async def list_children( mode="AND", offset: int = 0, limit: int = 100, - order_by=None, - silent=False, + order_by: str = None, + silent: bool = False, + pagination: bool = False, context: dict = None, ): """ @@ -2058,16 +2085,26 @@ async def list_children( query = select( *[getattr(ArtifactModel, field) for field in list_fields] ).where(ArtifactModel.parent_id == parent_artifact.id) + count_query = select(func.count()).where( + ArtifactModel.parent_id == parent_artifact.id + ) else: # If list_fields is empty or not specified, select all columns query = select(ArtifactModel).where( ArtifactModel.parent_id == parent_artifact.id ) + count_query = select(func.count()).where( + ArtifactModel.parent_id == parent_artifact.id + ) else: query = select(ArtifactModel).where( ArtifactModel.parent_id == None, ArtifactModel.workspace == context["ws"], ) + count_query = select(func.count()).where( + ArtifactModel.parent_id == None, + ArtifactModel.workspace == context["ws"], + ) conditions = [] # Handle keyword-based search across manifest fields @@ -2189,6 +2226,7 @@ async def list_children( ) query = query.where(stage_condition) + count_query = count_query.where(stage_condition) # Combine conditions based on mode (AND/OR) if conditions: @@ -2197,6 +2235,18 @@ async def list_children( if mode == "OR" else query.where(and_(*conditions)) ) + count_query = ( + count_query.where(or_(*conditions)) + if mode == "OR" + else count_query.where(and_(*conditions)) + ) + + if pagination: + # Execute the count query + result = await session.execute(count_query) + total_count = result.scalar() + else: + total_count = None # Pagination and ordering order_field_map = { @@ -2237,7 +2287,13 @@ async def list_children( session, parent_artifact.id, "view_count" ) await session.commit() - + if pagination: + return { + "items": results, + "total": total_count, + "offset": offset, + "limit": limit, + } return results except Exception as e: diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py index 5944bf88..c32b7aab 100644 --- a/hypha/core/workspace.py +++ b/hypha/core/workspace.py @@ -860,6 +860,9 @@ async def search_services( None, description="Order by field, default is score if embedding or text_query is provided.", ), + pagination: Optional[bool] = Field( + False, description="Enable pagination, return metadata with total count." + ), context: Optional[dict] = None, ): """ @@ -928,12 +931,30 @@ async def search_services( query, query_params=query_params ) + # Handle pagination + if pagination: + count_query = Query(query_string).paging(0, 0).dialect(2) + count_results = await self._redis.ft("service_info_index").search( + count_query, query_params=query_params + ) + total_count = count_results.total + else: + total_count = None + # Convert results to dictionaries and return services = [ ServiceInfo.from_redis_dict(vars(doc), in_bytes=False) for doc in results.docs ] - return [service.model_dump() for service in services] + if pagination: + return { + "items": [service.model_dump() for service in services], + "total": total_count, + "offset": offset, + "limit": limit, + } + else: + return [service.model_dump() for service in services] @schema_method async def list_services( diff --git a/tests/test_artifact.py b/tests/test_artifact.py index ea76d7cb..d89c3459 100644 --- a/tests/test_artifact.py +++ b/tests/test_artifact.py @@ -88,6 +88,14 @@ async def test_artifact_vector_collection( ) assert len(search_results) <= 2 + results = await artifact_manager.search_by_vector( + artifact_id=vector_collection.id, + query_vector=query_vector, + limit=2, + pagination=True, + ) + assert results["total"] == 3 + query_filter = { "should": None, "min_should": None, @@ -248,6 +256,14 @@ async def test_sqlite_create_and_search_artifacts( assert len(search_results) == len(datasets) + results = await artifact_manager.list( + parent_id=collection.id, + filters={"stage": True, "manifest": {"description": "*dataset*"}}, + pagination=True, + ) + assert results["total"] == len(datasets) + assert len(results["items"]) == len(datasets) + # list application only search_results = await artifact_manager.list( parent_id=collection.id, filters={"stage": True, "type": "application"} diff --git a/tests/test_server.py b/tests/test_server.py index 3f8f454c..15e45cdc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -382,6 +382,9 @@ async def test_service_search(fastapi_server_redis_1, test_user_token): assert "natural language processing" in services[0]["docs"] assert services[0]["score"] < services[1]["score"] + results = await api.search_services(text_query=text_query, limit=3, pagination=True) + assert results["total"] >= 1 + embedding = np.ones(384).astype(np.float32) await api.register_service( {