Skip to content

Commit 8c0bf19

Browse files
darshanmehta17 authored and copybara-github committed
feat: RAG - Add ANN and KNN retrieval strategies for RagManagedDb in preview
PiperOrigin-RevId: 758994681
1 parent 0a127fd commit 8c0bf19

File tree

7 files changed

+507
-44
lines changed

7 files changed

+507
-44
lines changed

owlbot.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -149,10 +149,10 @@
149149
# Update samples config to use `ucaip-sample-tests` project
150150
s.replace(
151151
".kokoro/samples/python3.*/common.cfg",
152-
"""env_vars: \{
152+
"""env_vars: {
153153
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
154154
value: "python-docs-samples-tests-.*?"
155-
\}""",
155+
}""",
156156
"""env_vars: {
157157
key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
158158
value: "ucaip-sample-tests"

tests/unit/vertex_rag/test_rag_constants_preview.py

Lines changed: 172 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -17,60 +17,62 @@
1717

1818

1919
from google.cloud import aiplatform
20-
20+
from google.cloud.aiplatform_v1beta1 import (
21+
GoogleDriveSource,
22+
ImportRagFilesConfig,
23+
ImportRagFilesRequest,
24+
ImportRagFilesResponse,
25+
JiraSource as GapicJiraSource,
26+
RagContexts,
27+
RagCorpus as GapicRagCorpus,
28+
RagEngineConfig as GapicRagEngineConfig,
29+
RagFileChunkingConfig,
30+
RagFileParsingConfig,
31+
RagFileTransformationConfig,
32+
RagFile as GapicRagFile,
33+
RagManagedDbConfig as GapicRagManagedDbConfig,
34+
RagVectorDbConfig as GapicRagVectorDbConfig,
35+
RetrieveContextsResponse,
36+
SharePointSources as GapicSharePointSources,
37+
SlackSource as GapicSlackSource,
38+
VertexAiSearchConfig as GapicVertexAiSearchConfig,
39+
)
40+
from google.cloud.aiplatform_v1beta1.types import api_auth
2141
from vertexai.preview.rag import (
42+
ANN,
43+
Basic,
2244
EmbeddingModelConfig,
45+
Enterprise,
2346
Filter,
2447
HybridSearch,
48+
JiraQuery,
49+
JiraSource,
50+
KNN,
2551
LayoutParserConfig,
2652
LlmParserConfig,
2753
LlmRanker,
2854
Pinecone,
2955
RagCorpus,
56+
RagEmbeddingModelConfig,
57+
RagEngineConfig,
3058
RagFile,
59+
RagManagedDb,
60+
RagManagedDbConfig,
3161
RagResource,
3262
RagRetrievalConfig,
33-
Ranking,
63+
RagVectorDbConfig,
3464
RankService,
65+
Ranking,
3566
SharePointSource,
3667
SharePointSources,
37-
SlackChannelsSource,
3868
SlackChannel,
39-
JiraSource,
40-
JiraQuery,
41-
Weaviate,
69+
SlackChannelsSource,
4270
VertexAiSearchConfig,
43-
VertexVectorSearch,
4471
VertexFeatureStore,
45-
RagEmbeddingModelConfig,
4672
VertexPredictionEndpoint,
47-
RagVectorDbConfig,
48-
RagManagedDbConfig,
49-
RagEngineConfig,
50-
Basic,
51-
Enterprise,
52-
)
53-
from google.cloud.aiplatform_v1beta1 import (
54-
GoogleDriveSource,
55-
RagFileChunkingConfig,
56-
RagFileTransformationConfig,
57-
RagFileParsingConfig,
58-
ImportRagFilesConfig,
59-
ImportRagFilesRequest,
60-
ImportRagFilesResponse,
61-
JiraSource as GapicJiraSource,
62-
RagCorpus as GapicRagCorpus,
63-
RagFile as GapicRagFile,
64-
SharePointSources as GapicSharePointSources,
65-
SlackSource as GapicSlackSource,
66-
RagContexts,
67-
RagManagedDbConfig as GapicRagManagedDbConfig,
68-
RagEngineConfig as GapicRagEngineConfig,
69-
RetrieveContextsResponse,
70-
RagVectorDbConfig as GapicRagVectorDbConfig,
71-
VertexAiSearchConfig as GapicVertexAiSearchConfig,
73+
VertexVectorSearch,
74+
Weaviate,
7275
)
73-
from google.cloud.aiplatform_v1beta1.types import api_auth
7476
from google.protobuf import timestamp_pb2
7577

7678

@@ -102,6 +104,18 @@
102104
index_name=TEST_PINECONE_INDEX_NAME,
103105
api_key=TEST_PINECONE_API_KEY_SECRET_VERSION,
104106
)
107+
TEST_RAG_MANAGED_DB_ANN_TREE_DEPTH = 3
108+
TEST_RAG_MANAGED_DB_ANN_LEAF_COUNT = 100
109+
TEST_RAG_MANAGED_DB_CONFIG = RagManagedDb()
110+
TEST_RAG_MANAGED_DB_KNN_CONFIG = RagManagedDb(
111+
retrieval_strategy=KNN(),
112+
)
113+
TEST_RAG_MANAGED_DB_ANN_CONFIG = RagManagedDb(
114+
retrieval_strategy=ANN(
115+
tree_depth=TEST_RAG_MANAGED_DB_ANN_TREE_DEPTH,
116+
leaf_count=TEST_RAG_MANAGED_DB_ANN_LEAF_COUNT,
117+
),
118+
)
105119
TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT = "test-vector-search-index-endpoint"
106120
TEST_VERTEX_VECTOR_SEARCH_INDEX = "test-vector-search-index"
107121
TEST_VERTEX_VECTOR_SEARCH_CONFIG = VertexVectorSearch(
@@ -169,12 +183,45 @@
169183
),
170184
),
171185
)
186+
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB = GapicRagCorpus(
187+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
188+
display_name=TEST_CORPUS_DISPLAY_NAME,
189+
description=TEST_CORPUS_DISCRIPTION,
190+
rag_vector_db_config=GapicRagVectorDbConfig(
191+
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb()
192+
),
193+
)
194+
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB_KNN = GapicRagCorpus(
195+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
196+
display_name=TEST_CORPUS_DISPLAY_NAME,
197+
description=TEST_CORPUS_DISCRIPTION,
198+
rag_vector_db_config=GapicRagVectorDbConfig(
199+
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb(
200+
knn=GapicRagVectorDbConfig.RagManagedDb.KNN()
201+
)
202+
),
203+
)
204+
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB_ANN = GapicRagCorpus(
205+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
206+
display_name=TEST_CORPUS_DISPLAY_NAME,
207+
description=TEST_CORPUS_DISCRIPTION,
208+
rag_vector_db_config=GapicRagVectorDbConfig(
209+
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb(
210+
ann=GapicRagVectorDbConfig.RagManagedDb.ANN(
211+
tree_depth=TEST_RAG_MANAGED_DB_ANN_TREE_DEPTH,
212+
leaf_count=TEST_RAG_MANAGED_DB_ANN_LEAF_COUNT,
213+
)
214+
)
215+
),
216+
)
172217
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
173218
publisher_model="publishers/google/models/textembedding-gecko",
174219
)
175220
TEST_RAG_EMBEDDING_MODEL_CONFIG = RagEmbeddingModelConfig(
176221
vertex_prediction_endpoint=VertexPredictionEndpoint(
177-
publisher_model="publishers/google/models/textembedding-gecko",
222+
publisher_model="projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
223+
TEST_PROJECT, TEST_REGION
224+
),
178225
),
179226
)
180227
TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG = RagVectorDbConfig(
@@ -207,6 +254,21 @@
207254
description=TEST_CORPUS_DISCRIPTION,
208255
vector_db=TEST_PINECONE_CONFIG,
209256
)
257+
TEST_RAG_CORPUS_RAG_MANAGED_DB = RagCorpus(
258+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
259+
display_name=TEST_CORPUS_DISPLAY_NAME,
260+
vector_db=TEST_RAG_MANAGED_DB_CONFIG,
261+
)
262+
TEST_RAG_CORPUS_RAG_MANAGED_DB_KNN = RagCorpus(
263+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
264+
display_name=TEST_CORPUS_DISPLAY_NAME,
265+
vector_db=TEST_RAG_MANAGED_DB_KNN_CONFIG,
266+
)
267+
TEST_RAG_CORPUS_RAG_MANAGED_DB_ANN = RagCorpus(
268+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
269+
display_name=TEST_CORPUS_DISPLAY_NAME,
270+
vector_db=TEST_RAG_MANAGED_DB_ANN_CONFIG,
271+
)
210272
TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH = RagCorpus(
211273
name=TEST_RAG_CORPUS_RESOURCE_NAME,
212274
display_name=TEST_CORPUS_DISPLAY_NAME,
@@ -247,6 +309,37 @@
247309
),
248310
),
249311
)
312+
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB_BACKEND_CONFIG = GapicRagCorpus(
313+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
314+
display_name=TEST_CORPUS_DISPLAY_NAME,
315+
description=TEST_CORPUS_DISCRIPTION,
316+
vector_db_config=GapicRagVectorDbConfig(
317+
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb()
318+
),
319+
)
320+
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB_KNN_BACKEND_CONFIG = GapicRagCorpus(
321+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
322+
display_name=TEST_CORPUS_DISPLAY_NAME,
323+
description=TEST_CORPUS_DISCRIPTION,
324+
vector_db_config=GapicRagVectorDbConfig(
325+
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb(
326+
knn=GapicRagVectorDbConfig.RagManagedDb.KNN()
327+
)
328+
),
329+
)
330+
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB_ANN_BACKEND_CONFIG = GapicRagCorpus(
331+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
332+
display_name=TEST_CORPUS_DISPLAY_NAME,
333+
description=TEST_CORPUS_DISCRIPTION,
334+
vector_db_config=GapicRagVectorDbConfig(
335+
rag_managed_db=GapicRagVectorDbConfig.RagManagedDb(
336+
ann=GapicRagVectorDbConfig.RagManagedDb.ANN(
337+
tree_depth=TEST_RAG_MANAGED_DB_ANN_TREE_DEPTH,
338+
leaf_count=TEST_RAG_MANAGED_DB_ANN_LEAF_COUNT,
339+
)
340+
)
341+
),
342+
)
250343
TEST_RAG_CORPUS_BACKEND = RagCorpus(
251344
name=TEST_RAG_CORPUS_RESOURCE_NAME,
252345
display_name=TEST_CORPUS_DISPLAY_NAME,
@@ -255,12 +348,36 @@
255348
TEST_BACKEND_CONFIG_PINECONE_CONFIG = RagVectorDbConfig(
256349
vector_db=TEST_PINECONE_CONFIG,
257350
)
351+
TEST_BACKEND_CONFIG_RAG_MANAGED_DB_CONFIG = RagVectorDbConfig(
352+
vector_db=TEST_RAG_MANAGED_DB_CONFIG,
353+
)
354+
TEST_BACKEND_CONFIG_RAG_MANAGED_DB_KNN_CONFIG = RagVectorDbConfig(
355+
vector_db=TEST_RAG_MANAGED_DB_KNN_CONFIG,
356+
)
357+
TEST_BACKEND_CONFIG_RAG_MANAGED_DB_ANN_CONFIG = RagVectorDbConfig(
358+
vector_db=TEST_RAG_MANAGED_DB_ANN_CONFIG,
359+
)
258360
TEST_RAG_CORPUS_PINECONE_BACKEND = RagCorpus(
259361
name=TEST_RAG_CORPUS_RESOURCE_NAME,
260362
display_name=TEST_CORPUS_DISPLAY_NAME,
261363
description=TEST_CORPUS_DISCRIPTION,
262364
backend_config=TEST_BACKEND_CONFIG_PINECONE_CONFIG,
263365
)
366+
TEST_RAG_CORPUS_RAG_MANAGED_DB_BACKEND = RagCorpus(
367+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
368+
display_name=TEST_CORPUS_DISPLAY_NAME,
369+
backend_config=TEST_BACKEND_CONFIG_RAG_MANAGED_DB_CONFIG,
370+
)
371+
TEST_RAG_CORPUS_RAG_MANAGED_DB_KNN_BACKEND = RagCorpus(
372+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
373+
display_name=TEST_CORPUS_DISPLAY_NAME,
374+
backend_config=TEST_BACKEND_CONFIG_RAG_MANAGED_DB_KNN_CONFIG,
375+
)
376+
TEST_RAG_CORPUS_RAG_MANAGED_DB_ANN_BACKEND = RagCorpus(
377+
name=TEST_RAG_CORPUS_RESOURCE_NAME,
378+
display_name=TEST_CORPUS_DISPLAY_NAME,
379+
backend_config=TEST_BACKEND_CONFIG_RAG_MANAGED_DB_ANN_CONFIG,
380+
)
264381
TEST_BACKEND_CONFIG_VERTEX_VECTOR_SEARCH_CONFIG = RagVectorDbConfig(
265382
vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG,
266383
)
@@ -343,6 +460,15 @@
343460
# GCS
344461
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig(
345462
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
463+
rebuild_ann_index=False,
464+
)
465+
TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX = ImportRagFilesConfig(
466+
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
467+
rebuild_ann_index=True,
468+
)
469+
TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX.gcs_source.uris = [TEST_GCS_PATH]
470+
TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
471+
False
346472
)
347473
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
348474
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.advanced_parser.use_advanced_pdf_parsing = (
@@ -352,6 +478,10 @@
352478
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
353479
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
354480
)
481+
TEST_IMPORT_REQUEST_GCS_REBUILD_ANN_INDEX = ImportRagFilesRequest(
482+
parent=TEST_RAG_CORPUS_RESOURCE_NAME,
483+
import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS_REBUILD_ANN_INDEX,
484+
)
355485
# Google Drive folders
356486
TEST_DRIVE_FOLDER_ID = "123"
357487
TEST_DRIVE_FOLDER = (
@@ -362,6 +492,7 @@
362492
)
363493
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig(
364494
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
495+
rebuild_ann_index=False,
365496
)
366497
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
367498
GoogleDriveSource.ResourceId(
@@ -374,6 +505,7 @@
374505
)
375506
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig(
376507
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
508+
rebuild_ann_index=False,
377509
)
378510
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
379511
GoogleDriveSource.ResourceId(
@@ -432,6 +564,7 @@
432564
use_advanced_pdf_parsing=False
433565
)
434566
),
567+
rebuild_ann_index=False,
435568
)
436569
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800
437570

@@ -491,6 +624,7 @@
491624
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
492625
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
493626
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
627+
rebuild_ann_index=False,
494628
)
495629
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
496630
GapicSlackSource.SlackChannels(
@@ -544,6 +678,7 @@
544678
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
545679
rag_file_parsing_config=TEST_RAG_FILE_PARSING_CONFIG,
546680
rag_file_transformation_config=TEST_RAG_FILE_TRANSFORMATION_CONFIG,
681+
rebuild_ann_index=False,
547682
)
548683
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
549684
GapicJiraSource.JiraQueries(
@@ -591,6 +726,7 @@
591726
)
592727
]
593728
),
729+
rebuild_ann_index=False,
594730
)
595731

596732
TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE = ImportRagFilesRequest(
@@ -681,6 +817,7 @@
681817
)
682818
]
683819
),
820+
rebuild_ann_index=False,
684821
)
685822

686823
TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesRequest(

0 commit comments

Comments
 (0)