Skip to content

Commit

Permalink
feat(backend): support delete datasource endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
wd0517 committed Aug 8, 2024
1 parent 76d2b73 commit 1004346
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 55 deletions.
34 changes: 34 additions & 0 deletions backend/app/alembic/versions/bd17a4ebccc5_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""empty message
Revision ID: bd17a4ebccc5
Revises: a8c79553c9f6
Create Date: 2024-08-08 01:20:42.069228
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType


# revision identifiers, used by Alembic.
revision = 'bd17a4ebccc5'
down_revision = 'a8c79553c9f6'
branch_labels = None
depends_on = None


def upgrade():
    """Apply schema changes for the delete-datasource feature."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Soft-delete marker: NULL deleted_at means the data source is active.
    op.add_column('data_sources', sa.Column('deleted_at', sa.DateTime(), nullable=True))
    # Drop the unique index on documents.source_uri so duplicates are allowed.
    # NOTE(review): presumably to permit re-importing a document after its
    # data source is soft-deleted — confirm against import logic.
    op.drop_index('source_uri', table_name='documents')
    # Nullable GUID column linking a relationship back to a chunk.
    op.add_column('relationships', sa.Column('chunk_id', sqlmodel.sql.sqltypes.GUID(), nullable=True))
    # ### end Alembic commands ###


def downgrade():
    """Revert the delete-datasource schema changes (mirror of upgrade())."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('relationships', 'chunk_id')
    # Restore the unique index on documents.source_uri.
    # NOTE(review): this will fail if duplicate source_uri values were
    # created while the index was absent — confirm before downgrading.
    op.create_index('source_uri', 'documents', ['source_uri'], unique=True)
    op.drop_column('data_sources', 'deleted_at')
    # ### end Alembic commands ###
75 changes: 21 additions & 54 deletions backend/app/api/admin_routes/data_source.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select, func

from app.api.deps import SessionDep, CurrentSuperuserDep
from app.models import (
DataSource,
DataSourceType,
Document,
Chunk,
)
from app.tasks import import_documents_from_datasource
from app.repositories import data_source_repo

router = APIRouter()

Expand All @@ -38,9 +35,7 @@ def create_datasource(
user_id=user.id,
llm_id=request.llm_id,
)
session.add(data_source)
session.commit()
session.refresh(data_source)
data_source = data_source_repo.create(session, data_source)
import_documents_from_datasource.delay(data_source.id)
return data_source

Expand All @@ -51,11 +46,7 @@ def list_datasources(
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[DataSource]:
return paginate(
session,
select(DataSource).order_by(DataSource.created_at.desc()),
params,
)
return data_source_repo.paginate(session, params)


@router.get("/admin/datasources/{data_source_id}")
Expand All @@ -64,7 +55,7 @@ def get_datasource(
user: CurrentSuperuserDep,
data_source_id: int,
) -> DataSource:
data_source = session.get(DataSource, data_source_id)
data_source = data_source_repo.get(session, data_source_id)
if data_source is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -73,55 +64,31 @@ def get_datasource(
return data_source


@router.get("/admin/datasources/{data_source_id}/overview")
def get_datasource_overview(
@router.delete("/admin/datasources/{data_source_id}")
def delete_datasource(
session: SessionDep,
user: CurrentSuperuserDep,
data_source_id: int,
) -> dict:
data_source = session.get(DataSource, data_source_id)
):
data_source = data_source_repo.get(session, data_source_id)
if data_source is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Data source not found",
)
documents_count = session.scalar(
select(func.count(Document.id)).where(Document.data_source_id == data_source_id)
)
chunks_count = session.scalar(
select(func.count(Chunk.id)).where(
Chunk.document.has(Document.data_source_id == data_source_id)
)
)
return data_source_repo.delete(session, data_source)

statement = (
select(Document.index_status, func.count(Document.id))
.where(Document.data_source_id == data_source_id)
.group_by(Document.index_status)
.order_by(Document.index_status)
)
results = session.exec(statement).all()
vector_index_status = {s: c for s, c in results}

if data_source.build_kg_index:
statement = (
select(Chunk.index_status, func.count(Chunk.id))
.where(Chunk.document.has(Document.data_source_id == data_source_id))
.group_by(Chunk.index_status)
.order_by(Chunk.index_status)
@router.get("/admin/datasources/{data_source_id}/overview")
def get_datasource_overview(
session: SessionDep,
user: CurrentSuperuserDep,
data_source_id: int,
) -> dict:
data_source = data_source_repo.get(session, data_source_id)
if data_source is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Data source not found",
)
results = session.exec(statement).all()
kg_index_status = {s: c for s, c in results}
else:
kg_index_status = {}

return {
"documents": {
"total": documents_count,
},
"chunks": {
"total": chunks_count,
},
"kg_index": kg_index_status,
"vector_index": vector_index_status,
}
return data_source_repo.overview(session, data_source)
6 changes: 6 additions & 0 deletions backend/app/models/data_source.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from enum import Enum
from uuid import UUID
from typing import Optional
from datetime import datetime

from sqlmodel import (
Column,
Field,
JSON,
DateTime,
Relationship as SQLRelationship,
)

Expand Down Expand Up @@ -38,5 +40,9 @@ class DataSource(UpdatableBaseModel, table=True):
"foreign_keys": "DataSource.llm_id",
},
)
deleted_at: Optional[datetime] = Field(
default=None,
sa_column=Column(DateTime),
)

__tablename__ = "data_sources"
2 changes: 1 addition & 1 deletion backend/app/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Document(UpdatableBaseModel, table=True):
name: str = Field(max_length=256)
content: str = Field(sa_column=Column(MEDIUMTEXT))
mime_type: str = Field(max_length=64)
source_uri: str = Field(max_length=512, unique=True)
source_uri: str = Field(max_length=512)
meta: dict | list = Field(default={}, sa_column=Column(JSON))
# the last time the document was modified in the source system
last_modified_at: Optional[datetime] = Field(sa_column=Column(DateTime))
Expand Down
2 changes: 2 additions & 0 deletions backend/app/models/knowledge_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
from uuid import UUID
from typing import Optional, Any, List, Dict
from datetime import datetime

Expand Down Expand Up @@ -81,6 +82,7 @@ class Relationship(RelationshipBase, table=True):
"lazy": "joined",
},
)
chunk_id: UUID = Field(nullable=True)

__tablename__ = "relationships"

Expand Down
1 change: 1 addition & 0 deletions backend/app/repositories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .chat_engine import chat_engine_repo
from .chat import chat_repo
from .document import document_repo
from .data_source import data_source_repo
89 changes: 89 additions & 0 deletions backend/app/repositories/data_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Optional
from datetime import datetime, UTC

from sqlmodel import select, Session, func
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate

from app.models import DataSource, Document, Chunk
from app.repositories.base_repo import BaseRepo


class DataSourceRepo(BaseRepo):
    """Data-access layer for ``DataSource`` rows.

    Rows are soft-deleted: :meth:`delete` stamps ``deleted_at`` instead of
    removing the row, and every read method filters out rows whose
    ``deleted_at`` is set.
    """

    model_cls = DataSource

    def paginate(
        self,
        session: Session,
        params: Params | None = None,
    ) -> Page[DataSource]:
        """Return a page of non-deleted data sources, newest first.

        ``params`` defaults to a fresh ``Params()`` per call; using
        ``Params()`` directly as the default would share one instance
        created at function-definition time across all calls.
        """
        if params is None:
            params = Params()
        query = (
            select(DataSource)
            # .is_(None) renders the same "IS NULL" SQL as == None,
            # without tripping E711-style lint warnings.
            .where(DataSource.deleted_at.is_(None))
            .order_by(DataSource.created_at.desc())
        )
        return paginate(session, query, params)

    def get(
        self,
        session: Session,
        data_source_id: int,
    ) -> Optional[DataSource]:
        """Fetch one data source by id.

        Returns ``None`` when the row does not exist or is soft-deleted.
        """
        return session.exec(
            select(DataSource).where(
                DataSource.id == data_source_id,
                DataSource.deleted_at.is_(None),
            )
        ).first()

    def delete(self, session: Session, data_source: DataSource) -> None:
        """Soft-delete: stamp ``deleted_at`` with the current UTC time and commit.

        The row is kept in the table; read methods stop returning it.
        """
        data_source.deleted_at = datetime.now(UTC)
        session.add(data_source)
        session.commit()

    def overview(self, session: Session, data_source: DataSource) -> dict:
        """Aggregate counts and index-status breakdowns for one data source.

        Returns a dict of shape::

            {
                "documents": {"total": int},
                "chunks": {"total": int},
                "vector_index": {index_status: count, ...},
                "kg_index": {index_status: count, ...},  # {} unless build_kg_index
            }
        """
        data_source_id = data_source.id
        documents_count = session.scalar(
            select(func.count(Document.id)).where(
                Document.data_source_id == data_source_id
            )
        )
        chunks_count = session.scalar(
            select(func.count(Chunk.id)).where(
                Chunk.document.has(Document.data_source_id == data_source_id)
            )
        )

        # Per-status document counts for the vector index.
        statement = (
            select(Document.index_status, func.count(Document.id))
            .where(Document.data_source_id == data_source_id)
            .group_by(Document.index_status)
            .order_by(Document.index_status)
        )
        vector_index_status = dict(session.exec(statement).all())

        # Per-status chunk counts for the knowledge-graph index, computed
        # only when KG indexing is enabled for this data source.
        if data_source.build_kg_index:
            statement = (
                select(Chunk.index_status, func.count(Chunk.id))
                .where(Chunk.document.has(Document.data_source_id == data_source_id))
                .group_by(Chunk.index_status)
                .order_by(Chunk.index_status)
            )
            kg_index_status = dict(session.exec(statement).all())
        else:
            kg_index_status = {}

        return {
            "documents": {
                "total": documents_count,
            },
            "chunks": {
                "total": chunks_count,
            },
            "kg_index": kg_index_status,
            "vector_index": vector_index_status,
        }


# Module-level singleton, re-exported via app.repositories.__init__.
data_source_repo = DataSourceRepo()

0 comments on commit 1004346

Please sign in to comment.