Skip to content

Commit

Permalink
feat: support chunking setting on KB
Browse files Browse the repository at this point in the history
  • Loading branch information
Mini256 committed Feb 14, 2025
1 parent 13604e6 commit 1e5cc9c
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 94 deletions.
40 changes: 40 additions & 0 deletions backend/app/alembic/versions/76f4e47a03d3_chunking_setting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""chunking_setting
Revision ID: 76f4e47a03d3
Revises: 2adc0b597dcd
Create Date: 2025-02-13 18:21:17.830980
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "76f4e47a03d3"
down_revision = "2adc0b597dcd"
branch_labels = None
depends_on = None


def upgrade():
    """Apply the chunking-setting schema changes.

    Adds `documents.content_format` (TEXT/MARKDOWN enum) and the JSON
    `knowledge_bases.chunking_config` column.
    """
    # NOTE(review): adding a NOT NULL enum column with no server_default may
    # fail or rely on the DB's implicit default when `documents` already has
    # rows — confirm against the target database's strict-mode behavior.
    content_format_column = sa.Column(
        "content_format",
        sa.Enum("TEXT", "MARKDOWN", name="contentformat"),
        nullable=False,
    )
    op.add_column("documents", content_format_column)

    chunking_config_column = sa.Column("chunking_config", sa.JSON(), nullable=True)
    op.add_column("knowledge_bases", chunking_config_column)


def downgrade():
    """Revert the chunking-setting schema changes (reverse order of upgrade)."""
    for table_name, column_name in (
        ("knowledge_bases", "chunking_config"),
        ("documents", "content_format"),
    ):
        op.drop_column(table_name, column_name)
8 changes: 2 additions & 6 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,8 @@ def _validate_sentry_sample_rate(self) -> Self:
COMPLIED_INTENT_ANALYSIS_PROGRAM_PATH: str | None = None
COMPLIED_PREREQUISITE_ANALYSIS_PROGRAM_PATH: str | None = None

# CAUTION: Do not change EMBEDDING_DIMS after initializing the database.
# Changing the embedding dimensions requires recreating the database and tables.
# The default EMBEDDING_DIMS and EMBEDDING_MAX_TOKENS are set for the OpenAI text-embedding-3-small model.
# If using a different embedding model, adjust these values according to the model's specifications.
# For example:
# maidalun1020/bce-embedding-base_v1: EMBEDDING_DIMS=768 EMBEDDING_MAX_TOKENS=512
# NOTICE: EMBEDDING_DIMS and EMBEDDING_MAX_TOKENS are deprecated and
# will be removed in the future.
EMBEDDING_DIMS: int = 1536
EMBEDDING_MAX_TOKENS: int = 2048

Expand Down
6 changes: 6 additions & 0 deletions backend/app/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ class DocIndexTaskStatus(str, enum.Enum):
FAILED = "failed"


class ContentFormat(str, enum.Enum):
    """Format of a document's stored content; drives which chunk splitter is used."""

    TEXT = "text"
    MARKDOWN = "markdown"


class Document(UpdatableBaseModel, table=True):
# Avoid "expected `enum` but got `str`" error.
model_config = ConfigDict(use_enum_values=True)
Expand All @@ -35,6 +40,7 @@ class Document(UpdatableBaseModel, table=True):
hash: str = Field(max_length=32)
name: str = Field(max_length=256)
content: str = Field(sa_column=Column(MEDIUMTEXT))
content_format: ContentFormat = Field(max_length=20)
mime_type: MimeTypes = Field(sa_column=Column(String(128), nullable=False))
source_uri: str = Field(max_length=512)
meta: dict | list = Field(default={}, sa_column=Column(JSON))
Expand Down
9 changes: 2 additions & 7 deletions backend/app/models/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy import Index

from app.core.config import settings
from app.models.knowledge_base import KnowledgeBase
from app.models.knowledge_base_scoped.registry import get_kb_scoped_registry
from app.models.knowledge_base_scoped.table_naming import (
Expand Down Expand Up @@ -42,14 +41,10 @@ class EntityBase(SQLModel):
class Entity(EntityBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
description_vec: Any = Field(
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
)
meta_vec: Any = Field(
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
)

__tablename__ = "entities"
Expand Down
97 changes: 91 additions & 6 deletions backend/app/models/knowledge_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import enum
from datetime import datetime
from typing import Optional
from typing import Annotated, Dict, Literal, Optional, Union
from uuid import UUID

from pydantic import BaseModel
from sqlalchemy import JSON, func
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlmodel import (
Expand All @@ -12,11 +13,17 @@
Relationship as SQLRelationship,
SQLModel,
)

from llama_index.core.node_parser.text.sentence import (
DEFAULT_PARAGRAPH_SEP,
SENTENCE_CHUNK_OVERLAP,
CHUNKING_REGEX,
)
from app.rag.node_parser.file.markdown import DEFAULT_CHUNK_HEADER_LEVEL
from app.api.admin_routes.models import KnowledgeBaseDescriptor
from app.exceptions import KBDataSourceNotFound
from app.models.auth import User
from app.models.data_source import DataSource
from app.models.document import ContentFormat
from app.models.embed_model import EmbeddingModel
from app.models.llm import LLM

Expand All @@ -36,11 +43,93 @@ class KnowledgeBaseDataSource(SQLModel, table=True):
__tablename__ = "knowledge_base_datasources"


class ChunkSplitter(str, enum.Enum):
    """Identifier for the node-parser used to split documents into chunks."""

    SENTENCE_SPLITTER = "sentence-splitter"
    MARKDOWN_SPLITTER = "markdown-splitter"


class BaseSplitterConfig(BaseModel):
    """Settings shared by every chunk-splitter configuration."""

    # Union discriminator; concrete subclasses narrow this to a Literal member.
    type: ChunkSplitter
    chunk_size: int = Field(
        default=1000,
        gt=0,
        description="The token chunk size for each chunk.",
    )


class SentenceSplitterConfig(BaseSplitterConfig):
    """Configuration mapped onto llama-index's SentenceSplitter kwargs."""

    type: Literal[ChunkSplitter.SENTENCE_SPLITTER] = ChunkSplitter.SENTENCE_SPLITTER
    chunk_overlap: int = Field(
        default=SENTENCE_CHUNK_OVERLAP,
        gt=0,
        description="The overlap size for each chunk.",
    )
    separator: str = Field(
        default=" ",
        description="The separator for splitting the text.",
    )
    paragraph_separator: str = Field(
        default=DEFAULT_PARAGRAPH_SEP,
        description="The paragraph separator for splitting the text.",
    )
    secondary_chunking_regex: str = Field(
        default=CHUNKING_REGEX,
        description="The regex for secondary chunking.",
    )


class MarkdownSplitterConfig(BaseSplitterConfig):
    """Configuration mapped onto the project's MarkdownNodeParser kwargs."""

    type: Literal[ChunkSplitter.MARKDOWN_SPLITTER] = ChunkSplitter.MARKDOWN_SPLITTER
    # Markdown header levels run 1 (#) through 6 (######).
    chunk_header_level: int = Field(
        default=DEFAULT_CHUNK_HEADER_LEVEL,
        ge=1,
        le=6,
        description="The header level to split on",
    )


# Discriminated union of all splitter configs; pydantic dispatches on the
# `type` field when validating raw dicts.
ChunkSplitterConfig = Annotated[
    Union[SentenceSplitterConfig, MarkdownSplitterConfig], Field(discriminator="type")
]


# Default per-content-format chunking rules: plain text is sentence-split,
# markdown is split on headers.
default_chunking_rules = {
    ContentFormat.TEXT: SentenceSplitterConfig(),
    ContentFormat.MARKDOWN: MarkdownSplitterConfig(),
}


class ChunkingMode(str, enum.Enum):
    """How a knowledge base's chunking is configured.

    AUTO carries only global knobs; ADVANCED carries per-format splitter rules.
    """

    AUTO = "auto"
    ADVANCED = "advanced"


class ChunkingConfig(BaseModel):
    """Base chunking configuration; only carries the mode discriminant."""

    mode: ChunkingMode = Field(default=ChunkingMode.AUTO)


class AutoChunkingConfig(ChunkingConfig):
    """AUTO-mode config: a few global knobs expanded into per-format rules later."""

    chunk_size: int = Field(default=1000, gt=0)
    chunk_overlap: int = Field(default=SENTENCE_CHUNK_OVERLAP, gt=0)
    paragraph_separator: str = Field(default=DEFAULT_PARAGRAPH_SEP)


class AdvancedChunkingConfig(ChunkingConfig):
    """ADVANCED-mode config: an explicit splitter rule per content format."""

    # NOTE(review): sharing `default_chunking_rules` as a default relies on
    # pydantic deep-copying mutable defaults per instance — confirm this holds
    # for the pinned pydantic version.
    rules: Dict[ContentFormat, ChunkSplitterConfig] = Field(
        default=default_chunking_rules
    )


class KnowledgeBase(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=255, nullable=False)
description: str = Field(sa_column=Column(MEDIUMTEXT))

# The config for transforming the document into (chunk) nodes.
chunking_config: Dict = Field(
sa_column=Column(JSON), default=AutoChunkingConfig().model_dump()
)

# Data sources config.
data_sources: list["DataSource"] = SQLRelationship(
link_model=KnowledgeBaseDataSource
Expand All @@ -64,10 +153,6 @@ class KnowledgeBase(SQLModel, table=True):
"foreign_keys": "KnowledgeBase.embedding_model_id",
},
)

# TODO: Support knowledge-base level retrieval config.

# TODO: Store the statistics of the knowledge base.
documents_total: int = Field(default=0)
data_sources_total: int = Field(default=0)

Expand Down
101 changes: 73 additions & 28 deletions backend/app/rag/build_index.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
import logging
from typing import Optional, Type
from typing import List, Optional, Type

from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.llms.llm import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TransformComponent

from sqlmodel import Session

from app.models.knowledge_base import KnowledgeBase
from app.models.document import ContentFormat
from app.models.knowledge_base import (
AdvancedChunkingConfig,
AutoChunkingConfig,
ChunkSplitter,
ChunkingConfig,
ChunkingMode,
KnowledgeBase,
MarkdownSplitterConfig,
SentenceSplitterConfig,
)
from app.rag.knowledge_base.index_store import (
get_kb_tidb_vector_store,
get_kb_tidb_graph_store,
)
from app.rag.indices.knowledge_graph import KnowledgeGraphIndex
from app.core.config import settings
from app.models import (
Document as DBDocument,
Chunk as DBChunk,
)
from app.models import Document, Chunk
from app.rag.node_parser.file.markdown import MarkdownNodeParser
from app.utils.dspy import get_dspy_lm_by_llama_llm

logger = logging.getLogger(__name__)
Expand All @@ -41,7 +50,7 @@ def __init__(

# TODO: move to ./indices/vector_search
def build_vector_index_for_document(
self, session: Session, db_document: Type[DBDocument]
self, session: Session, db_document: Type[Document]
):
"""
Build vector index and graph index from document.
Expand All @@ -52,37 +61,73 @@ def build_vector_index_for_document(
3. embedding text nodes.
4. Insert nodes into `chunks` table.
"""

if db_document.mime_type.lower() == "text/markdown":
# spliter = MarkdownNodeParser()
# TODO: FIX MarkdownNodeParser
spliter = SentenceSplitter(
chunk_size=settings.EMBEDDING_MAX_TOKENS,
)
else:
spliter = SentenceSplitter(
chunk_size=settings.EMBEDDING_MAX_TOKENS,
)

_transformations = [spliter]

vector_store = get_kb_tidb_vector_store(session, self._knowledge_base)
transformations = self._get_transformations(db_document)
vector_index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=self._embed_model,
transformations=_transformations,
transformations=transformations,
)

document = db_document.to_llama_document()
logger.info(f"Start building vector index for document #{document.doc_id}.")
vector_index.insert(document, source_uri=db_document.source_uri)
logger.info(f"Finish building vector index for document #{document.doc_id}.")
llama_document = db_document.to_llama_document()
logger.info(f"Start building vector index for document #{db_document.id}.")
vector_index.insert(llama_document, source_uri=db_document.source_uri)
logger.info(f"Finish building vector index for document #{db_document.id}.")
vector_store.close_session()

return

def _get_transformations(
self, db_document: Type[Document]
) -> List[TransformComponent]:
transformations = []

chunking_config_dict = self._knowledge_base.chunking_config
config = ChunkingConfig.model_validate(chunking_config_dict)
if config.mode != ChunkingMode.AUTO:
auto_chunking_config = AutoChunkingConfig.model_validate(
chunking_config_dict
)
chunking_config = AdvancedChunkingConfig(
mode=ChunkingMode.ADVANCED,
rules={
ContentFormat.TEXT: SentenceSplitterConfig(
chunk_size=auto_chunking_config.chunk_size,
chunk_overlap=auto_chunking_config.chunk_overlap,
),
ContentFormat.MARKDOWN: MarkdownSplitterConfig(
chunk_size=auto_chunking_config.chunk_size,
),
},
)
elif config.mode == ChunkingMode.ADVANCED:
chunking_config = AdvancedChunkingConfig.model_validate(
chunking_config_dict
)

# Chunking
content_format = db_document.content_format
if content_format in chunking_config.rules:
splitter_config = chunking_config.rules[content_format]
else:
splitter_config = chunking_config.rules[ContentFormat.TEXT]

match splitter_config.type:
case ChunkSplitter.MARKDOWN_SPLITTER:
transformations.append(
MarkdownNodeParser(**splitter_config.model_dump(exclude={"type"}))
)
case ChunkSplitter.SENTENCE_SPLITTER:
transformations.append(
SentenceSplitter(**splitter_config.model_dump(exclude={"type"}))
)
case _:
raise ValueError(f"Unsupported splitter type: {splitter_config.type}")

return transformations

# TODO: move to ./indices/knowledge_graph
def build_kg_index_for_chunk(self, session: Session, db_chunk: Type[DBChunk]):
def build_kg_index_for_chunk(self, session: Session, db_chunk: Type[Chunk]):
"""Build knowledge graph index from chunk.
Build knowledge graph index will do the following:
Expand Down
Loading

0 comments on commit 1e5cc9c

Please sign in to comment.