
Commit

fix
Mini256 committed Feb 14, 2025
1 parent 676b385 commit c5b3347
Showing 6 changed files with 92 additions and 91 deletions.
@@ -1,8 +1,8 @@
-"""chunking_setting
+"""chunking_settings

-Revision ID: 76f4e47a03d3
+Revision ID: a849e5f21516
 Revises: 2adc0b597dcd
-Create Date: 2025-02-13 18:21:17.830980
+Create Date: 2025-02-14 17:15:25.930884

 """

@@ -11,22 +11,14 @@


 # revision identifiers, used by Alembic.
-revision = "76f4e47a03d3"
+revision = "a849e5f21516"
 down_revision = "2adc0b597dcd"
 branch_labels = None
 depends_on = None


 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    op.add_column(
-        "documents",
-        sa.Column(
-            "content_format",
-            sa.Enum("TEXT", "MARKDOWN", name="contentformat"),
-            nullable=False,
-        ),
-    )
     op.add_column(
         "knowledge_bases", sa.Column("chunking_config", sa.JSON(), nullable=True)
     )
@@ -36,5 +28,4 @@ def upgrade():
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column("knowledge_bases", "chunking_config")
-    op.drop_column("documents", "content_format")
     # ### end Alembic commands ###
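With the fix, this migration only manages the knowledge_bases.chunking_config JSON column; the documents.content_format column and its enum are no longer created. A minimal sketch of what that column ends up holding, assuming the GeneralChunkingConfig model introduced in backend/app/models/knowledge_base.py below (runnable only inside this backend):

from app.models.knowledge_base import GeneralChunkingConfig

# The JSON column stores a plain dict dumped from the config model;
# the concrete defaults come from DEFAULT_CHUNK_SIZE and friends.
config = GeneralChunkingConfig()
row_value = config.model_dump()  # e.g. {"mode": "general", "chunk_size": ..., ...}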
3 changes: 2 additions & 1 deletion backend/app/api/admin_routes/knowledge_base/models.py
@@ -14,7 +14,7 @@
 )
 from app.exceptions import KBNoVectorIndexConfigured
 from app.models import KgIndexStatus
-from app.models.knowledge_base import IndexMethod
+from app.models.knowledge_base import IndexMethod, ChunkingConfig, GeneralChunkingConfig


 class KnowledgeBaseCreate(BaseModel):
@@ -25,6 +25,7 @@ class KnowledgeBaseCreate(BaseModel):
     )
     llm_id: Optional[int] = None
     embedding_model_id: Optional[int] = None
+    chunking_config: ChunkingConfig = Field(default_factory=GeneralChunkingConfig)
     data_sources: list[KBDataSourceCreate] = Field(default_factory=list)

     @field_validator("name")
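The new chunking_config field uses pydantic's default_factory, so each create request gets its own fresh GeneralChunkingConfig rather than a shared instance. A self-contained sketch of the pattern with simplified stand-in models (not the project's classes; the default value is an assumption):

from pydantic import BaseModel, Field

class ChunkingConfig(BaseModel):
    mode: str = "general"

class GeneralChunkingConfig(ChunkingConfig):
    chunk_size: int = Field(default=1200, gt=0)  # assumed default for illustration

class CreateRequest(BaseModel):
    name: str
    # default_factory builds a fresh config per request instead of sharing one.
    chunking_config: ChunkingConfig = Field(default_factory=GeneralChunkingConfig)

req = CreateRequest(name="demo")
print(req.chunking_config.mode)  # "general"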
1 change: 0 additions & 1 deletion backend/app/models/document.py
@@ -40,7 +40,6 @@ class Document(UpdatableBaseModel, table=True):
     hash: str = Field(max_length=32)
     name: str = Field(max_length=256)
     content: str = Field(sa_column=Column(MEDIUMTEXT))
-    content_format: ContentFormat = Field(max_length=20)
     mime_type: MimeTypes = Field(sa_column=Column(String(128), nullable=False))
     source_uri: str = Field(max_length=512)
     meta: dict | list = Field(default={}, sa_column=Column(JSON))
64 changes: 34 additions & 30 deletions backend/app/models/knowledge_base.py
@@ -1,6 +1,6 @@
 import enum
 from datetime import datetime
-from typing import Annotated, Dict, Literal, Optional, Union
+from typing import Dict, Optional, Union
 from uuid import UUID

 from pydantic import BaseModel
@@ -17,14 +17,17 @@
     DEFAULT_PARAGRAPH_SEP,
     SENTENCE_CHUNK_OVERLAP,
 )
-from app.rag.node_parser.file.markdown import DEFAULT_CHUNK_HEADER_LEVEL
+from app.rag.node_parser.file.markdown import (
+    DEFAULT_CHUNK_HEADER_LEVEL,
+    DEFAULT_CHUNK_SIZE,
+)
 from app.api.admin_routes.models import KnowledgeBaseDescriptor
 from app.exceptions import KBDataSourceNotFound
 from app.models.auth import User
 from app.models.data_source import DataSource
-from app.models.document import ContentFormat
 from app.models.embed_model import EmbeddingModel
 from app.models.llm import LLM
+from app.types import MimeTypes

 # For compatibility with old code, define a fake knowledge base id.
 PHONY_KNOWLEDGE_BASE_ID = 0
@@ -42,22 +45,20 @@ class KnowledgeBaseDataSource(SQLModel, table=True):
     __tablename__ = "knowledge_base_datasources"


 # Chunking Settings.


 class ChunkSplitter(str, enum.Enum):
-    SENTENCE_SPLITTER = "sentence-splitter"
-    MARKDOWN_SPLITTER = "markdown-splitter"
+    SENTENCE_SPLITTER = "SentenceSplitter"
+    MARKDOWN_NODE_PARSER = "MarkdownNodeParser"


-class BaseSplitterConfig(BaseModel):
-    type: ChunkSplitter
+class SentenceSplitterOptions(BaseModel):
     chunk_size: int = Field(
         description="The token chunk size for each chunk.",
         default=1000,
         gt=0,
     )


-class SentenceSplitterConfig(BaseSplitterConfig):
-    type: Literal[ChunkSplitter.SENTENCE_SPLITTER] = ChunkSplitter.SENTENCE_SPLITTER
     chunk_overlap: int = Field(
         description="The overlap size for each chunk.",
         default=SENTENCE_CHUNK_OVERLAP,
@@ -69,8 +70,12 @@ class SentenceSplitterConfig(BaseSplitterConfig):
     )


-class MarkdownSplitterConfig(BaseSplitterConfig):
-    type: Literal[ChunkSplitter.MARKDOWN_SPLITTER] = ChunkSplitter.MARKDOWN_SPLITTER
+class MarkdownNodeParserOptions(BaseModel):
+    chunk_size: int = Field(
+        description="The token chunk size for each chunk.",
+        default=1000,
+        gt=0,
+    )
     chunk_header_level: int = Field(
         description="The header level to split on",
         default=DEFAULT_CHUNK_HEADER_LEVEL,
@@ -79,36 +84,35 @@ class MarkdownSplitterConfig(BaseSplitterConfig):
     )


-ChunkSplitterConfig = Annotated[
-    Union[SentenceSplitterConfig, MarkdownSplitterConfig], Field(discriminator="type")
-]
-
-
-default_chunking_rules = {
-    ContentFormat.TEXT: SentenceSplitterConfig(),
-    ContentFormat.MARKDOWN: MarkdownSplitterConfig(),
-}
+class ChunkSplitterConfig(BaseModel):
+    splitter: ChunkSplitter = Field(default=ChunkSplitter.SENTENCE_SPLITTER)
+    splitter_options: Union[SentenceSplitterOptions, MarkdownNodeParserOptions] = (
+        Field()
+    )


 class ChunkingMode(str, enum.Enum):
-    AUTO = "auto"
+    GENERAL = "general"
     ADVANCED = "advanced"


 class ChunkingConfig(BaseModel):
-    mode: ChunkingMode = Field(default=ChunkingMode.AUTO)
+    mode: ChunkingMode = Field(default=ChunkingMode.GENERAL)


-class AutoChunkingConfig(ChunkingConfig):
-    chunk_size: int = Field(default=1000, gt=0)
+class GeneralChunkingConfig(ChunkingConfig):
+    mode: ChunkingMode = Field(default=ChunkingMode.GENERAL)
+    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, gt=0)
     chunk_overlap: int = Field(default=SENTENCE_CHUNK_OVERLAP, gt=0)
     paragraph_separator: str = Field(default=DEFAULT_PARAGRAPH_SEP)


 class AdvancedChunkingConfig(ChunkingConfig):
-    rules: Dict[ContentFormat, ChunkSplitterConfig] = Field(
-        default=default_chunking_rules
-    )
+    mode: ChunkingMode = Field(default=ChunkingMode.ADVANCED)
+    rules: Dict[MimeTypes, ChunkSplitterConfig] = Field(default_factory=dict)


 # Knowledge Base Model


 class KnowledgeBase(SQLModel, table=True):
@@ -118,7 +122,7 @@ class KnowledgeBase(SQLModel, table=True):

     # The config for transforming the document into (chunk) nodes.
     chunking_config: Dict = Field(
-        sa_column=Column(JSON), default=AutoChunkingConfig().model_dump()
+        sa_column=Column(JSON), default=GeneralChunkingConfig().model_dump()
     )

     # Data sources config.
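To see the two config shapes side by side: general mode carries flat knobs applied to every supported format, while advanced mode carries an explicit per-mime-type rule table. A sketch using the models above (runnable only inside this backend; the MimeTypes members are taken from app.types as used in build_index.py below):

from app.models.knowledge_base import (
    AdvancedChunkingConfig,
    ChunkSplitter,
    ChunkSplitterConfig,
    GeneralChunkingConfig,
    MarkdownNodeParserOptions,
)
from app.types import MimeTypes

general = GeneralChunkingConfig(chunk_size=1024, chunk_overlap=128)
advanced = AdvancedChunkingConfig(
    rules={
        MimeTypes.MARKDOWN: ChunkSplitterConfig(
            splitter=ChunkSplitter.MARKDOWN_NODE_PARSER,
            splitter_options=MarkdownNodeParserOptions(chunk_size=1024),
        ),
    }
)

# Both dump to plain dicts, which is what the JSON chunking_config column stores.
print(general.model_dump())
print(advanced.model_dump())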
82 changes: 45 additions & 37 deletions backend/app/rag/build_index.py
@@ -8,17 +8,15 @@
 from llama_index.core.schema import TransformComponent

 from sqlmodel import Session

-from app.models.document import ContentFormat
 from app.models.knowledge_base import (
-    AdvancedChunkingConfig,
-    AutoChunkingConfig,
     ChunkSplitter,
-    ChunkingConfig,
     ChunkingMode,
     KnowledgeBase,
-    MarkdownSplitterConfig,
-    SentenceSplitterConfig,
+    SentenceSplitterOptions,
+    GeneralChunkingConfig,
+    ChunkSplitterConfig,
+    MarkdownNodeParserOptions,
+    AdvancedChunkingConfig,
 )
 from app.rag.knowledge_base.index_store import (
     get_kb_tidb_vector_store,
@@ -27,6 +25,7 @@
 from app.rag.indices.knowledge_graph import KnowledgeGraphIndex
 from app.models import Document, Chunk
 from app.rag.node_parser.file.markdown import MarkdownNodeParser
+from app.types import MimeTypes
 from app.utils.dspy import get_dspy_lm_by_llama_llm

 logger = logging.getLogger(__name__)
@@ -83,46 +82,55 @@ def _get_transformations(
         transformations = []

         chunking_config_dict = self._knowledge_base.chunking_config
-        config = ChunkingConfig.model_validate(chunking_config_dict)
-        if config.mode != ChunkingMode.AUTO:
-            auto_chunking_config = AutoChunkingConfig.model_validate(
-                chunking_config_dict
-            )
-            chunking_config = AdvancedChunkingConfig(
-                mode=ChunkingMode.ADVANCED,
-                rules={
-                    ContentFormat.TEXT: SentenceSplitterConfig(
-                        chunk_size=auto_chunking_config.chunk_size,
-                        chunk_overlap=auto_chunking_config.chunk_overlap,
-                    ),
-                    ContentFormat.MARKDOWN: MarkdownSplitterConfig(
-                        chunk_size=auto_chunking_config.chunk_size,
-                    ),
-                },
-            )
-        elif config.mode == ChunkingMode.ADVANCED:
-            chunking_config = AdvancedChunkingConfig.model_validate(
-                chunking_config_dict
-            )
+        mode = (
+            chunking_config_dict["mode"]
+            if "mode" in chunking_config_dict
+            else ChunkingMode.GENERAL
+        )
+
+        if mode == ChunkingMode.ADVANCED:
+            chunking_config = AdvancedChunkingConfig.model_validate(
+                chunking_config_dict
+            )
+            rules = chunking_config.rules
+        else:
+            chunking_config = GeneralChunkingConfig.model_validate(chunking_config_dict)
+            rules = {
+                MimeTypes.PLAIN_TXT: ChunkSplitterConfig(
+                    splitter=ChunkSplitter.SENTENCE_SPLITTER,
+                    splitter_options=SentenceSplitterOptions(
+                        chunk_size=chunking_config.chunk_size,
+                        chunk_overlap=chunking_config.chunk_overlap,
+                        paragraph_separator=chunking_config.paragraph_separator,
+                    ),
+                ),
+                MimeTypes.MARKDOWN: ChunkSplitterConfig(
+                    splitter=ChunkSplitter.MARKDOWN_NODE_PARSER,
+                    splitter_options=MarkdownNodeParserOptions(
+                        chunk_size=chunking_config.chunk_size,
+                    ),
+                ),
+            }

         # Chunking
-        content_format = db_document.content_format
-        if content_format in chunking_config.rules:
-            splitter_config = chunking_config.rules[content_format]
-        else:
-            splitter_config = chunking_config.rules[ContentFormat.TEXT]
+        mime_type = db_document.mime_type
+        if mime_type not in rules:
+            raise RuntimeError(
+                f"Cannot chunk the document in {db_document.mime_type} format"
+            )

-        match splitter_config.type:
-            case ChunkSplitter.MARKDOWN_SPLITTER:
-                transformations.append(
-                    MarkdownNodeParser(**splitter_config.model_dump(exclude={"type"}))
-                )
+        rule = rules[mime_type]
+        match rule.splitter:
+            case ChunkSplitter.MARKDOWN_NODE_PARSER:
+                options = MarkdownNodeParserOptions.model_validate(
+                    rule.splitter_options
+                )
+                transformations.append(MarkdownNodeParser(**options.model_dump()))
             case ChunkSplitter.SENTENCE_SPLITTER:
-                transformations.append(
-                    SentenceSplitter(**splitter_config.model_dump(exclude={"type"}))
-                )
+                options = SentenceSplitterOptions.model_validate(rule.splitter_options)
+                transformations.append(SentenceSplitter(**options.model_dump()))
             case _:
-                raise ValueError(f"Unsupported splitter type: {splitter_config.type}")
+                raise ValueError(f"Unsupported chunking splitter type: {rule.splitter}")

         return transformations

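The resolution flow in _get_transformations is: read the stored mode, build (general) or load (advanced) a per-mime-type rules table, then dispatch on the rule's splitter. A standalone sketch of that dispatch with simplified stand-ins (not the project's classes; mime-type strings are illustrative):

import enum
from dataclasses import dataclass

class ChunkSplitter(str, enum.Enum):
    SENTENCE_SPLITTER = "SentenceSplitter"
    MARKDOWN_NODE_PARSER = "MarkdownNodeParser"

@dataclass
class Rule:
    splitter: ChunkSplitter
    options: dict

# Resolved rules table: general mode derives it from flat defaults,
# advanced mode reads it straight from the stored config.
rules = {
    "text/plain": Rule(ChunkSplitter.SENTENCE_SPLITTER, {"chunk_size": 1000}),
    "text/markdown": Rule(ChunkSplitter.MARKDOWN_NODE_PARSER, {"chunk_size": 1000}),
}

def pick_splitter(mime_type: str) -> tuple[str, dict]:
    if mime_type not in rules:
        # Mirrors the RuntimeError raised above for unsupported formats.
        raise RuntimeError(f"Cannot chunk documents of type {mime_type}")
    rule = rules[mime_type]
    match rule.splitter:
        case ChunkSplitter.MARKDOWN_NODE_PARSER:
            return ("MarkdownNodeParser", rule.options)
        case ChunkSplitter.SENTENCE_SPLITTER:
            return ("SentenceSplitter", rule.options)
        case _:
            raise ValueError(f"Unsupported splitter: {rule.splitter}")

print(pick_splitter("text/markdown"))  # ('MarkdownNodeParser', {'chunk_size': 1000})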
16 changes: 7 additions & 9 deletions backend/app/rag/datasource/file.py
@@ -8,7 +8,6 @@

 from app.models import Document, Upload
 from app.file_storage import default_file_storage
-from app.models.document import ContentFormat
 from app.types import MimeTypes
 from .base import BaseDataSource

@@ -37,29 +36,28 @@ def load_documents(self) -> Generator[Document, None, None]:
         with default_file_storage.open(upload.path) as f:
             if upload.mime_type == MimeTypes.PDF:
                 content = extract_text_from_pdf(f)
-                content_format = ContentFormat.TEXT
+                mime_type = MimeTypes.PLAIN_TXT
             elif upload.mime_type == MimeTypes.DOCX:
                 content = extract_text_from_docx(f)
-                content_format = ContentFormat.TEXT
+                mime_type = MimeTypes.PLAIN_TXT
             elif upload.mime_type == MimeTypes.PPTX:
                 content = extract_text_from_pptx(f)
-                content_format = ContentFormat.TEXT
+                mime_type = MimeTypes.PLAIN_TXT
             elif upload.mime_type == MimeTypes.XLSX:
                 content = extract_text_from_xlsx(f)
-                content_format = ContentFormat.TEXT
+                mime_type = MimeTypes.PLAIN_TXT
             elif upload.mime_type == MimeTypes.MARKDOWN:
                 content = f.read()
-                content_format = ContentFormat.MARKDOWN
+                mime_type = MimeTypes.MARKDOWN
             else:
                 content = f.read()
-                content_format = ContentFormat.TEXT
+                mime_type = MimeTypes.PLAIN_TXT

             document = Document(
                 name=upload.name,
                 hash=hash(content),
                 content=content,
-                content_format=content_format,
-                mime_type=upload.mime_type,
+                mime_type=mime_type,
                 knowledge_base_id=self.knowledge_base_id,
                 data_source_id=self.data_source_id,
                 user_id=self.user_id,
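The ladder above normalizes every extracted upload to a stored mime type: PDF, DOCX, PPTX, XLSX, and unknown types are saved as PLAIN_TXT after text extraction, while Markdown passes through unchanged. The same mapping, table-driven, as a sketch against this module's imports (not the commit's code):

# Equivalent lookup table for the if/elif ladder above; the extractors and
# MimeTypes members come from this module's imports.
EXTRACTORS = {
    MimeTypes.PDF: (extract_text_from_pdf, MimeTypes.PLAIN_TXT),
    MimeTypes.DOCX: (extract_text_from_docx, MimeTypes.PLAIN_TXT),
    MimeTypes.PPTX: (extract_text_from_pptx, MimeTypes.PLAIN_TXT),
    MimeTypes.XLSX: (extract_text_from_xlsx, MimeTypes.PLAIN_TXT),
    MimeTypes.MARKDOWN: (lambda f: f.read(), MimeTypes.MARKDOWN),
}

def extract(upload_mime_type, f):
    extractor, stored_mime_type = EXTRACTORS.get(
        upload_mime_type, (lambda f: f.read(), MimeTypes.PLAIN_TXT)
    )
    return extractor(f), stored_mime_type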
