diff --git a/tests/unit/vertex_rag/conftest.py b/tests/unit/vertex_rag/conftest.py index 46846be388..18d325cb53 100644 --- a/tests/unit/vertex_rag/conftest.py +++ b/tests/unit/vertex_rag/conftest.py @@ -162,6 +162,8 @@ def rag_data_client_preview_mock_exception(): api_client_mock.list_rag_files.side_effect = Exception # delete_rag_file api_client_mock.delete_rag_file.side_effect = Exception + # update_rag_engine_config + api_client_mock.update_rag_engine_config.side_effect = Exception rag_data_client_mock_exception.return_value = api_client_mock yield rag_data_client_mock_exception diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index ee52bce881..e25d0108e8 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -60,6 +60,8 @@ SharePointSources as GapicSharePointSources, SlackSource as GapicSlackSource, RagContexts, + RagManagedDbConfig, + RagEngineConfig, RetrieveContextsResponse, RagVectorDbConfig as GapicRagVectorDbConfig, VertexAiSearchConfig as GapicVertexAiSearchConfig, @@ -75,7 +77,10 @@ TEST_CORPUS_DISCRIPTION = "My first corpus." TEST_RAG_CORPUS_ID = "generate-123" TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH +TEST_RAG_ENGINE_CONFIG_DISPLAY_NAME = "my-rag-engine-config-1" +TEST_RAG_ENGINE_CONFIG_DESCRIPTION = "My first rag engine config." TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}" +TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragEngineConfigs/test-rag-engine-config" # RagCorpus TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com" @@ -386,6 +391,18 @@ parent=TEST_RAG_CORPUS_RESOURCE_NAME, import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING, ) + +# Config Resource +TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = ( + TEST_RAG_CORPUS_RESOURCE_NAME + "/ragEngineConfigs/test-rag-engine-config" +) +TEST_RAG_ENGINE_CONFIG_DISPLAY_NAME = "test-rag-engine-config" +TEST_RAG_ENGINE_CONFIG_DESCRIPTION = "test-rag-engine-config-description" +TEST_RAG_ENGINE_CONFIG = RagEngineConfig( + name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, + rag_managed_db_config=RagManagedDbConfig(), +) + # Google Drive files TEST_DRIVE_FILE_ID = "456" TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}" diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index d23f718560..98e7d27501 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -296,6 +296,40 @@ def list_rag_corpora_pager_mock(): yield list_rag_corpora_pager_mock +@pytest.fixture() +def update_config_mock(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_mock: + update_rag_engine_config_mock.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG + ) + yield update_rag_engine_config_mock + + +@pytest.fixture() +def update_config_mock_exception(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_mock_exception: + update_rag_engine_config_mock_exception.side_effect = Exception + yield update_rag_engine_config_mock_exception + + +@pytest.fixture() +def update_config_mock_vertex_ai_engine_search_config(): + with mock.patch.object( + VertexRagDataServiceClient, + "update_rag_engine_config", + ) as update_rag_engine_config_mock_vertex_ai_engine_search_config: + update_rag_engine_config_mock_vertex_ai_engine_search_config.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_VERTEX_AI_ENGINE_SEARCH_CONFIG + ) + yield update_rag_engine_config_mock_vertex_ai_engine_search_config + + class MockResponse: def __init__(self, json_data, status_code): self.json_data = json_data @@ -1258,3 +1292,16 @@ def test_set_embedding_model_config_wrong_endpoint_format_error(self): test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS, ) e.match("endpoint must be of the format ") + + def test_update_rag_engine_config_success(self, rag_data_client_preview_mock): + rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG, + ) + assert rag_data_client_preview_mock.call_count == 1 + + def test_update_rag_engine_config_failure(self): + with pytest.raises(RuntimeError) as e: + rag.update_rag_engine_config( + rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG, + ) + e.match("Failed in RagEngineConfig update due to") diff --git a/vertexai/preview/rag/__init__.py b/vertexai/preview/rag/__init__.py index 5841fd5aa6..1245a07e02 100644 --- a/vertexai/preview/rag/__init__.py +++ b/vertexai/preview/rag/__init__.py @@ -25,6 +25,7 @@ get_file, list_files, delete_file, + update_rag_engine_config, ) from vertexai.preview.rag.rag_retrieval import ( retrieval_query, diff --git a/vertexai/preview/rag/rag_data.py b/vertexai/preview/rag/rag_data.py index d0fffa219d..2036e445a5 100644 --- a/vertexai/preview/rag/rag_data.py +++ b/vertexai/preview/rag/rag_data.py @@ -34,6 +34,7 @@ ListRagFilesRequest, RagCorpus as GapicRagCorpus, UpdateRagCorpusRequest, + UpdateRagEngineConfigRequest, ) from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import ( ListRagCorporaPager, @@ -51,6 +52,7 @@ RagCorpus, RagFile, RagManagedDb, + RagEngineConfig, RagVectorDbConfig, SharePointSources, SlackChannelsSource, @@ -926,3 +928,37 @@ def delete_file(name: str, corpus_name: Optional[str] = None) -> None: except Exception as e: raise RuntimeError("Failed in RagFile deletion due to: ", e) from e return None + + +def update_rag_engine_config( + rag_engine_config: RagEngineConfig, +) -> None: + """Update RagEngineConfig. + + Example usage: + ``` + import vertexai + from vertexai.preview import rag + vertexai.init(project="my-project") + rag_engine_config = rag.RagEngineConfig( + name="projects/my-project/locations/us-central1/ragEngineConfigs/my-rag-engine-config" + ), + ) + rag.update_rag_engine_config(rag_engine_config=rag_engine_config) + ``` + + Args: + rag_engine_config: The RagEngineConfig to update. + + Raises: + RuntimeError: Failed in RagEngineConfig update due to exception. + """ + request = UpdateRagEngineConfigRequest( + rag_engine_config=rag_engine_config, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + client.update_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in RagEngineConfig update due to: ", e) from e + return None diff --git a/vertexai/preview/rag/utils/resources.py b/vertexai/preview/rag/utils/resources.py index 27e2e7f208..28819ab9ac 100644 --- a/vertexai/preview/rag/utils/resources.py +++ b/vertexai/preview/rag/utils/resources.py @@ -524,3 +524,16 @@ class LlmParserConfig: model_name: str max_parsing_requests_per_min: Optional[int] = None custom_parsing_prompt: Optional[str] = None + + +@dataclasses.dataclass +class RagEngineConfig: + """RagEngineConfig. + + Attributes: + name: Generated resource name. Format: + ``projects/{project}/locations/{location}/ragEngineConfig/ + {rag_engine_config}`` + """ + + name: str