diff --git a/private_gpt/components/ingest/ingest_component.py b/private_gpt/components/ingest/ingest_component.py
index 5ed039590..674d770fc 100644
--- a/private_gpt/components/ingest/ingest_component.py
+++ b/private_gpt/components/ingest/ingest_component.py
@@ -40,7 +40,12 @@ def __init__(
         self.transformations = transformations
 
     @abc.abstractmethod
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(
+        self,
+        file_name: str,
+        file_data: Path,
+        file_metadata: dict[str, Any] | None = None,
+    ) -> list[Document]:
         pass
 
     @abc.abstractmethod
@@ -117,9 +122,16 @@ def __init__(
     ) -> None:
         super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(
+        self,
+        file_name: str,
+        file_data: Path,
+        file_metadata: dict[str, Any] | None = None,
+    ) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        documents = IngestionHelper.transform_file_into_documents(
+            file_name, file_data, file_metadata
+        )
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
         )
@@ -175,9 +187,16 @@ def __init__(
             processes=self.count_workers
         )
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(
+        self,
+        file_name: str,
+        file_data: Path,
+        file_metadata: dict[str, Any] | None = None,
+    ) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+        documents = IngestionHelper.transform_file_into_documents(
+            file_name, file_data, file_metadata
+        )
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
         )
@@ -185,6 +204,7 @@ def ingest(self, file_name: str, file_data: Path) -> list[Document]:
         return self._save_docs(documents)
 
     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+
         documents = list(
             itertools.chain.from_iterable(
                 self._file_to_documents_work_pool.starmap(
@@ -257,12 +277,18 @@ def __init__(
             processes=self.count_workers
         )
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
+    def ingest(
+        self,
+        file_name: str,
+        file_data: Path,
+        file_metadata: dict[str, Any] | None = None,
+    ) -> list[Document]:
         logger.info("Ingesting file_name=%s", file_name)
         # Running in a single (1) process to release the current
         # thread, and take a dedicated CPU core for computation
         documents = self._file_to_documents_work_pool.apply(
-            IngestionHelper.transform_file_into_documents, (file_name, file_data)
+            IngestionHelper.transform_file_into_documents,
+            (file_name, file_data, file_metadata),
         )
         logger.info(
             "Transformed file=%s into count=%s documents", file_name, len(documents)
         )
@@ -271,9 +297,9 @@ def ingest(self, file_name: str, file_data: Path) -> list[Document]:
         return self._save_docs(documents)
 
     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
+
         # Lightweight threads, used for parallelize the
         # underlying IO calls made in the ingestion
-
         documents = list(
             itertools.chain.from_iterable(
                 self._ingest_work_pool.starmap(self.ingest, files)
@@ -459,8 +485,15 @@ def _flush(self) -> None:
         self.node_q.put(("flush", None, None, None))
         self.node_q.join()
 
-    def ingest(self, file_name: str, file_data: Path) -> list[Document]:
-        documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
+    def ingest(
+        self,
+        file_name: str,
+        file_data: Path,
+        file_metadata: dict[str, Any] | None = None,
+    ) -> list[Document]:
+        documents = IngestionHelper.transform_file_into_documents(
+            file_name, file_data, file_metadata
+        )
         self.doc_q.put(("process", file_name, documents))
         self._flush()
         return documents
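Note that the `bulk_ingest` paths above are untouched: they keep feeding 2-tuples of `(file_name, file_data)` into `starmap`, and because the new `file_metadata` parameter is a trailing keyword with a default, those call sites stay source-compatible. A minimal standalone sketch of why that works (names here are illustrative, not taken from the codebase):

```python
from multiprocessing.pool import ThreadPool


def ingest(file_name: str, file_data: str, file_metadata: dict | None = None) -> tuple:
    # Mirrors the widened signature: a trailing keyword default keeps
    # two-argument call sites (and 2-tuple starmap entries) working.
    return (file_name, file_metadata)


with ThreadPool(2) as pool:
    # 2-tuples still match: file_metadata falls back to None.
    print(pool.starmap(ingest, [("a.txt", "/tmp/a"), ("b.txt", "/tmp/b")]))
```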
diff --git a/private_gpt/components/ingest/ingest_helper.py b/private_gpt/components/ingest/ingest_helper.py
index a11090702..4ac358a98 100644
--- a/private_gpt/components/ingest/ingest_helper.py
+++ b/private_gpt/components/ingest/ingest_helper.py
@@ -1,5 +1,6 @@
 import logging
 from pathlib import Path
+from typing import Any
 
 from llama_index.core.readers import StringIterableReader
 from llama_index.core.readers.base import BaseReader
@@ -69,11 +70,13 @@ class IngestionHelper:
 
     @staticmethod
     def transform_file_into_documents(
-        file_name: str, file_data: Path
+        file_name: str, file_data: Path, file_metadata: dict[str, Any] | None = None
     ) -> list[Document]:
         documents = IngestionHelper._load_file_to_documents(file_name, file_data)
         for document in documents:
+            document.metadata.update(file_metadata or {})
             document.metadata["file_name"] = file_name
+
         IngestionHelper._exclude_metadata(documents)
         return documents
 
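The merge order in `transform_file_into_documents` is deliberate: the caller's dict is applied first, and `file_name` is stamped afterwards, so user-supplied metadata can never overwrite the canonical `file_name` key. A standalone sketch of that ordering (`merged_metadata` is a hypothetical helper, written only to show the semantics):

```python
from typing import Any


def merged_metadata(
    file_name: str, file_metadata: dict[str, Any] | None
) -> dict[str, Any]:
    metadata: dict[str, Any] = {}
    metadata.update(file_metadata or {})  # user-supplied keys first...
    metadata["file_name"] = file_name  # ...then the canonical file name wins
    return metadata


assert merged_metadata("test.txt", {"file_name": "spoofed", "foo": "bar"}) == {
    "file_name": "test.txt",
    "foo": "bar",
}
```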
diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py
index 56adba468..d9136be7b 100644
--- a/private_gpt/server/ingest/ingest_router.py
+++ b/private_gpt/server/ingest/ingest_router.py
@@ -1,6 +1,7 @@
-from typing import Literal
+import json
+from typing import Any, Literal
 
-from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
+from fastapi import APIRouter, Depends, Form, HTTPException, Request, UploadFile
 from pydantic import BaseModel, Field
 
 from private_gpt.server.ingest.ingest_service import IngestService
@@ -20,6 +21,16 @@ class IngestTextBody(BaseModel):
             "Chinese martial arts."
         ]
     )
+    metadata: dict[str, Any] | None = Field(
+        None,
+        examples=[
+            {
+                "title": "Avatar: The Last Airbender",
+                "author": "Michael Dante DiMartino, Bryan Konietzko",
+                "year": "2005",
+            }
+        ],
+    )
 
 
 class IngestResponse(BaseModel):
@@ -38,9 +49,16 @@ def ingest(request: Request, file: UploadFile) -> IngestResponse:
 
 
 @ingest_router.post("/ingest/file", tags=["Ingestion"])
-def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
+def ingest_file(
+    request: Request, file: UploadFile, metadata: str = Form(None)
+) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.
 
+    metadata: Optional metadata to associate with the file.
+    It can be omitted if not needed.
+    The metadata must be a valid JSON object,
+    e.g. {"title": "Avatar: The Last Airbender", "year": "2005"}
+
     The context obtained from files is later used in
     `/chat/completions`, `/completions`, and `/chunks` APIs.
 
@@ -57,7 +75,11 @@ def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
     service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest_bin_data(file.filename, file.file)
+
+    metadata_dict = None if metadata is None else json.loads(metadata)
+    ingested_documents = service.ingest_bin_data(
+        file.filename, file.file, metadata_dict
+    )
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 
 
@@ -73,11 +95,12 @@ def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
     extracted Metadata (which is later used to improve context retrieval).
     That ID can be used to filter the context used to create responses
     in `/chat/completions`, `/completions`, and `/chunks` APIs.
+
     """
     service = request.state.injector.get(IngestService)
     if len(body.file_name) == 0:
         raise HTTPException(400, "No file name provided")
-    ingested_documents = service.ingest_text(body.file_name, body.text)
+    ingested_documents = service.ingest_text(body.file_name, body.text, body.metadata)
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 
diff --git a/private_gpt/server/ingest/ingest_service.py b/private_gpt/server/ingest/ingest_service.py
index f9ae4728f..0cd003be1 100644
--- a/private_gpt/server/ingest/ingest_service.py
+++ b/private_gpt/server/ingest/ingest_service.py
@@ -48,7 +48,12 @@ def __init__(
             settings=settings(),
         )
 
-    def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
+    def _ingest_data(
+        self,
+        file_name: str,
+        file_data: AnyStr,
+        file_metadata: dict[str, str] | None = None,
+    ) -> list[IngestedDoc]:
         logger.debug("Got file data of size=%s to ingest", len(file_data))
         # llama-index mainly supports reading from files, so
         # we have to create a tmp file to read for it to work
@@ -60,27 +65,37 @@ def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
                 path_to_tmp.write_bytes(file_data)
             else:
                 path_to_tmp.write_text(str(file_data))
-            return self.ingest_file(file_name, path_to_tmp)
+            return self.ingest_file(file_name, path_to_tmp, file_metadata)
         finally:
             tmp.close()
             path_to_tmp.unlink()
 
-    def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
+    def ingest_file(
+        self,
+        file_name: str,
+        file_data: Path,
+        file_metadata: dict[str, str] | None = None,
+    ) -> list[IngestedDoc]:
         logger.info("Ingesting file_name=%s", file_name)
-        documents = self.ingest_component.ingest(file_name, file_data)
+        documents = self.ingest_component.ingest(file_name, file_data, file_metadata)
         logger.info("Finished ingestion file_name=%s", file_name)
         return [IngestedDoc.from_document(document) for document in documents]
 
-    def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
+    def ingest_text(
+        self, file_name: str, text: str, metadata: dict[str, str] | None = None
+    ) -> list[IngestedDoc]:
         logger.debug("Ingesting text data with file_name=%s", file_name)
-        return self._ingest_data(file_name, text)
+        return self._ingest_data(file_name, text, metadata)
 
     def ingest_bin_data(
-        self, file_name: str, raw_file_data: BinaryIO
+        self,
+        file_name: str,
+        raw_file_data: BinaryIO,
+        file_metadata: dict[str, str] | None = None,
     ) -> list[IngestedDoc]:
         logger.debug("Ingesting binary data with file_name=%s", file_name)
         file_data = raw_file_data.read()
-        return self._ingest_data(file_name, file_data)
+        return self._ingest_data(file_name, file_data, file_metadata)
 
     def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
         logger.info("Ingesting file_names=%s", [f[0] for f in files])
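End to end, a client now sends metadata as a JSON string in a multipart form field next to the file part. A hedged sketch using `requests` (the base URL, port, and file name are assumptions about a local deployment, not part of this change):

```python
import json

import requests

# Assumed local PrivateGPT URL; adjust host/port to your deployment.
url = "http://localhost:8001/v1/ingest/file"
metadata = {"title": "Avatar: The Last Airbender", "year": "2005"}

with open("avatar.txt", "rb") as f:  # hypothetical file
    response = requests.post(
        url,
        files={
            "file": ("avatar.txt", f),
            # A (None, value) tuple is sent as a plain form field, which is
            # what `metadata: str = Form(None)` expects on the server side.
            "metadata": (None, json.dumps(metadata)),
        },
    )
response.raise_for_status()
print(response.json())  # IngestResponse; doc_metadata includes the extra keys
```

One caveat worth noting: malformed JSON in `metadata` currently escapes `json.loads` as an unhandled `ValueError` (surfacing as a 500), not a 400; wrapping the parse in an `HTTPException` would tighten that.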
diff --git a/tests/fixtures/ingest_helper.py b/tests/fixtures/ingest_helper.py
index 25515f4e9..ab7946674 100644
--- a/tests/fixtures/ingest_helper.py
+++ b/tests/fixtures/ingest_helper.py
@@ -1,4 +1,6 @@
+import json
 from pathlib import Path
+from typing import Any
 
 import pytest
 from fastapi.testclient import TestClient
@@ -18,6 +20,20 @@ def ingest_file(self, path: Path) -> IngestResponse:
         ingest_result = IngestResponse.model_validate(response.json())
         return ingest_result
 
+    def ingest_file_with_metadata(
+        self, path: Path, metadata: dict[str, Any]
+    ) -> IngestResponse:
+        files = {
+            "file": (path.name, path.open("rb")),
+            "metadata": (None, json.dumps(metadata)),
+        }
+
+        response = self.test_client.post("/v1/ingest/file", files=files)
+
+        assert response.status_code == 200
+        ingest_result = IngestResponse.model_validate(response.json())
+        return ingest_result
+
 
 @pytest.fixture()
 def ingest_helper(test_client: TestClient) -> IngestHelper:
diff --git a/tests/server/ingest/test_ingest_routes.py b/tests/server/ingest/test_ingest_routes.py
index 896410a17..6fbf5fa4d 100644
--- a/tests/server/ingest/test_ingest_routes.py
+++ b/tests/server/ingest/test_ingest_routes.py
@@ -44,3 +44,25 @@ def test_ingest_plain_text(test_client: TestClient) -> None:
     assert response.status_code == 200
     ingest_result = IngestResponse.model_validate(response.json())
     assert len(ingest_result.data) == 1
+
+
+def test_ingest_text_with_metadata(test_client: TestClient) -> None:
+    response = test_client.post(
+        "/v1/ingest/text",
+        json={"file_name": "file_name", "text": "text", "metadata": {"foo": "bar"}},
+    )
+    assert response.status_code == 200
+    ingest_result = IngestResponse.model_validate(response.json())
+    assert len(ingest_result.data) == 1
+
+    assert ingest_result.data[0].doc_metadata == {
+        "file_name": "file_name",
+        "foo": "bar",
+    }
+
+
+def test_ingest_accepts_txt_files_with_metadata(ingest_helper: IngestHelper) -> None:
+    path = Path(__file__).parents[0] / "test.txt"
+    ingest_result = ingest_helper.ingest_file_with_metadata(path, {"foo": "bar"})
+    assert len(ingest_result.data) == 1
+    assert ingest_result.data[0].doc_metadata == {"file_name": "test.txt", "foo": "bar"}
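The `(None, json.dumps(metadata))` tuple in the fixture leans on standard multipart encoding: a `None` filename makes the encoder emit the part as a plain form field rather than a file upload, which is how it reaches the server's `metadata: str = Form(None)` parameter. A quick self-contained way to inspect that encoding (illustrative only; the request is prepared, never sent):

```python
import json

import requests

req = requests.Request(
    "POST",
    "http://localhost/v1/ingest/file",  # placeholder URL; request is never sent
    files={
        "file": ("test.txt", b"hello"),
        "metadata": (None, json.dumps({"foo": "bar"})),
    },
).prepare()

# The "metadata" part is emitted without a filename= attribute, so FastAPI
# binds it to `metadata: str = Form(None)` instead of an UploadFile.
print(req.body.decode())
```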