Skip to content

feat: Implement preview update_rag_engine_config in rag_data.py #5167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/unit/vertex_rag/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def rag_data_client_preview_mock_exception():
api_client_mock.list_rag_files.side_effect = Exception
# delete_rag_file
api_client_mock.delete_rag_file.side_effect = Exception
# update_rag_engine_config
api_client_mock.update_rag_engine_config.side_effect = Exception
rag_data_client_mock_exception.return_value = api_client_mock
yield rag_data_client_mock_exception

Expand Down
17 changes: 17 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
SharePointSources as GapicSharePointSources,
SlackSource as GapicSlackSource,
RagContexts,
RagManagedDbConfig,
RagEngineConfig,
RetrieveContextsResponse,
RagVectorDbConfig as GapicRagVectorDbConfig,
VertexAiSearchConfig as GapicVertexAiSearchConfig,
Expand All @@ -75,7 +77,10 @@
TEST_CORPUS_DISCRIPTION = "My first corpus."
TEST_RAG_CORPUS_ID = "generate-123"
TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH
TEST_RAG_ENGINE_CONFIG_DISPLAY_NAME = "my-rag-engine-config-1"
TEST_RAG_ENGINE_CONFIG_DESCRIPTION = "My first rag engine config."
TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}"
TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragEngineConfigs/test-rag-engine-config"

# RagCorpus
TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com"
Expand Down Expand Up @@ -386,6 +391,18 @@
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING,
)

# Config Resource
TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME = (
TEST_RAG_CORPUS_RESOURCE_NAME + "/ragEngineConfigs/test-rag-engine-config"
)
TEST_RAG_ENGINE_CONFIG_DISPLAY_NAME = "test-rag-engine-config"
TEST_RAG_ENGINE_CONFIG_DESCRIPTION = "test-rag-engine-config-description"
TEST_RAG_ENGINE_CONFIG = RagEngineConfig(
name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME,
rag_managed_db_config=RagManagedDbConfig(),
)

# Google Drive files
TEST_DRIVE_FILE_ID = "456"
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/vertex_rag/test_rag_data_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,40 @@ def list_rag_corpora_pager_mock():
yield list_rag_corpora_pager_mock


@pytest.fixture()
def update_config_mock():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_engine_config",
) as update_rag_engine_config_mock:
update_rag_engine_config_mock.return_value = (
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG
)
yield update_rag_engine_config_mock


@pytest.fixture()
def update_config_mock_exception():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_engine_config",
) as update_rag_engine_config_mock_exception:
update_rag_engine_config_mock_exception.side_effect = Exception
yield update_rag_engine_config_mock_exception


@pytest.fixture()
def update_config_mock_vertex_ai_engine_search_config():
with mock.patch.object(
VertexRagDataServiceClient,
"update_rag_engine_config",
) as update_rag_engine_config_mock_vertex_ai_engine_search_config:
update_rag_engine_config_mock_vertex_ai_engine_search_config.return_value = (
test_rag_constants_preview.TEST_GAPIC_RAG_ENGINE_CONFIG_VERTEX_AI_ENGINE_SEARCH_CONFIG
)
yield update_rag_engine_config_mock_vertex_ai_engine_search_config


class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
Expand Down Expand Up @@ -1258,3 +1292,16 @@ def test_set_embedding_model_config_wrong_endpoint_format_error(self):
test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS,
)
e.match("endpoint must be of the format ")

def test_update_rag_engine_config_success(self, rag_data_client_preview_mock):
rag.update_rag_engine_config(
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG,
)
assert rag_data_client_preview_mock.call_count == 1

def test_update_rag_engine_config_failure(self):
with pytest.raises(RuntimeError) as e:
rag.update_rag_engine_config(
rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG,
)
e.match("Failed in RagEngineConfig update due to")
1 change: 1 addition & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_file,
list_files,
delete_file,
update_rag_engine_config,
)
from vertexai.preview.rag.rag_retrieval import (
retrieval_query,
Expand Down
36 changes: 36 additions & 0 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ListRagFilesRequest,
RagCorpus as GapicRagCorpus,
UpdateRagCorpusRequest,
UpdateRagEngineConfigRequest,
)
from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import (
ListRagCorporaPager,
Expand All @@ -51,6 +52,7 @@
RagCorpus,
RagFile,
RagManagedDb,
RagEngineConfig,
RagVectorDbConfig,
SharePointSources,
SlackChannelsSource,
Expand Down Expand Up @@ -926,3 +928,37 @@ def delete_file(name: str, corpus_name: Optional[str] = None) -> None:
except Exception as e:
raise RuntimeError("Failed in RagFile deletion due to: ", e) from e
return None


def update_rag_engine_config(
rag_engine_config: RagEngineConfig,
) -> None:
"""Update RagEngineConfig.

Example usage:
```
import vertexai
from vertexai.preview import rag
vertexai.init(project="my-project")
rag_engine_config = rag.RagEngineConfig(
name="projects/my-project/locations/us-central1/ragEngineConfigs/my-rag-engine-config"
),
)
rag.update_rag_engine_config(rag_engine_config=rag_engine_config)
```

Args:
rag_engine_config: The RagEngineConfig to update.

Raises:
RuntimeError: Failed in RagEngineConfig update due to exception.
"""
request = UpdateRagEngineConfigRequest(
rag_engine_config=rag_engine_config,
)
client = _gapic_utils.create_rag_data_service_client()
try:
client.update_rag_engine_config(request=request)
except Exception as e:
raise RuntimeError("Failed in RagEngineConfig update due to: ", e) from e
return None
13 changes: 13 additions & 0 deletions vertexai/preview/rag/utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,16 @@ class LlmParserConfig:
model_name: str
max_parsing_requests_per_min: Optional[int] = None
custom_parsing_prompt: Optional[str] = None


@dataclasses.dataclass
class RagEngineConfig:
"""RagEngineConfig.

Attributes:
name: Generated resource name. Format:
``projects/{project}/locations/{location}/ragEngineConfig/
{rag_engine_config}``
"""

name: str