diff --git a/hypha/VERSION b/hypha/VERSION index cabf9054..04f91558 100644 --- a/hypha/VERSION +++ b/hypha/VERSION @@ -1,3 +1,3 @@ { - "version": "0.20.39.post10" + "version": "0.20.39.post11" } diff --git a/hypha/artifact.py b/hypha/artifact.py index bd5f8fec..193692be 100644 --- a/hypha/artifact.py +++ b/hypha/artifact.py @@ -16,20 +16,23 @@ text, and_, or_, + update, ) -from hrid import HRID +from sqlalchemy.sql import func from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.attributes import flag_modified +from sqlalchemy.ext.asyncio import ( + async_sessionmaker, + AsyncSession, +) + +from hrid import HRID from hypha.utils import remove_objects_async, list_objects_async, safe_join from hypha.utils.zenodo import ZenodoClient from botocore.exceptions import ClientError from hypha.s3 import FSFileResponse from aiobotocore.session import get_session -from sqlalchemy import update -from sqlalchemy.ext.asyncio import ( - async_sessionmaker, - AsyncSession, -) + from fastapi import APIRouter, Depends, HTTPException from hypha.core import ( UserInfo, @@ -1257,6 +1260,23 @@ async def read( artifact, version_index, s3_config ) + if artifact.type == "collection": + # Use with_only_columns to optimize the count query + count_q = select(func.count()).where( + ArtifactModel.parent_id == artifact.id + ) + result = await session.execute(count_q) + child_count = result.scalar() + artifact_data["config"] = artifact_data.get("config", {}) + artifact_data["config"]["child_count"] = child_count + elif artifact.type == "vector-collection" and self._vectordb_client: + artifact_data["config"] = artifact_data.get("config", {}) + artifact_data["config"]["vector_count"] = ( + await self._vectordb_client.count( + collection_name=f"{artifact.workspace}/{artifact.alias}" + ) + ).count + if not silent: await session.commit() @@ -1337,15 +1357,6 @@ async def commit( ), ) - if artifact.type == "vector-collection": - assert ( - self._vectordb_client - ), "The server is not configured to use a VectorDB client." - artifact.manifest["points"] = self._vectordb_client.count( - collection_name=f"{artifact.workspace}/{artifact.alias}" - ) - flag_modified(artifact, "manifest") - parent_artifact_config = ( parent_artifact.config if parent_artifact else {} ) @@ -1417,7 +1428,7 @@ async def delete( assert ( self._vectordb_client ), "The server is not configured to use a VectorDB client." - self._vectordb_client.delete_collection( + await self._vectordb_client.delete_collection( collection_name=f"{artifact.workspace}/{artifact.alias}" ) diff --git a/tests/test_artifact.py b/tests/test_artifact.py index 2c9875fa..ea76d7cb 100644 --- a/tests/test_artifact.py +++ b/tests/test_artifact.py @@ -76,6 +76,9 @@ async def test_artifact_vector_collection( vectors=vectors, ) + vc = await artifact_manager.read(artifact_id=vector_collection.id) + assert vc["config"]["vector_count"] == 3 + # Search for vectors by query vector query_vector = [random.random() for _ in range(384)] search_results = await artifact_manager.search_by_vector( @@ -929,6 +932,9 @@ async def test_edit_existing_artifact(minio_server, fastapi_server, test_user_to version="stage", ) + collection = await artifact_manager.read(artifact_id=collection.id) + assert collection["config"]["child_count"] == 1 + # Commit the artifact dataset = await artifact_manager.commit(artifact_id=dataset.id) versions = dataset["versions"]