Skip to content

Commit 99e05e5

Browse files
committed
Use quoted identifiers
1 parent e2015b5 commit 99e05e5

File tree

2 files changed

+334
-295
lines changed

2 files changed

+334
-295
lines changed

libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import inspect
1515
import logging
1616
import os
17+
import re
1718
import uuid
1819
from collections.abc import Awaitable
1920
from typing import (
@@ -225,12 +226,20 @@ async def _atable_exists(connection: AsyncConnection, table_name: str) -> bool:
225226
raise
226227

227228

228-
def _normalize_oracle_identifier(name: str) -> str:
229+
def _quote_indentifier(name: str) -> str:
229230
name = name.strip()
230-
if name.startswith('"') and name.endswith('"'):
231-
return name
232-
else:
233-
return name.upper()
231+
reg = r'^(?:"[^"]+"|[^".]+)(?:\.(?:"[^"]+"|[^".]+))*$'
232+
pattern_validate = re.compile(reg)
233+
234+
if not pattern_validate.match(name):
235+
raise ValueError(f"Identifier name {name} is not valid.")
236+
237+
pattern_match = r'"([^"]+)"|([^".]+)'
238+
groups = re.findall(pattern_match, name)
239+
groups = [m[0] or m[1] for m in groups]
240+
groups = [f'"{g}"' for g in groups]
241+
242+
return ".".join(groups)
234243

235244

236245
@_handle_exceptions
@@ -245,16 +254,21 @@ def _index_exists(
245254
{"AND table_name = :table_name" if table_name else ""}
246255
"""
247256

257+
# this is an internal method, index_name and table_name comes with double quotes
258+
index_name = index_name.replace('"', "")
259+
if table_name:
260+
table_name = table_name.replace('"', "")
261+
248262
with connection.cursor() as cursor:
249263
# execute the query
250264
if table_name:
251265
cursor.execute(
252266
query,
253-
idx_name=_normalize_oracle_identifier(index_name),
254-
table_name=_normalize_oracle_identifier(table_name),
267+
idx_name=index_name,
268+
table_name=table_name,
255269
)
256270
else:
257-
cursor.execute(query, idx_name=_normalize_oracle_identifier(index_name))
271+
cursor.execute(query, idx_name=index_name)
258272
result = cursor.fetchone()
259273

260274
# check if the index exists
@@ -272,18 +286,21 @@ async def _aindex_exists(
272286
{"AND table_name = :table_name" if table_name else ""}
273287
"""
274288

289+
# this is an internal method, index_name and table_name comes with double quotes
290+
index_name = index_name.replace('"', "")
291+
if table_name:
292+
table_name = table_name.replace('"', "")
293+
275294
with connection.cursor() as cursor:
276295
# execute the query
277296
if table_name:
278297
await cursor.execute(
279298
query,
280-
idx_name=_normalize_oracle_identifier(index_name),
281-
table_name=_normalize_oracle_identifier(table_name),
299+
idx_name=index_name,
300+
table_name=table_name,
282301
)
283302
else:
284-
await cursor.execute(
285-
query, idx_name=_normalize_oracle_identifier(index_name)
286-
)
303+
await cursor.execute(query, idx_name=index_name)
287304
result = await cursor.fetchone()
288305

289306
# check if the index exists
@@ -309,7 +326,7 @@ def _get_distance_function(distance_strategy: DistanceStrategy) -> str:
309326

310327
def _get_index_name(base_name: str) -> str:
311328
unique_id = str(uuid.uuid4()).replace("-", "")
312-
return f"{base_name}_{unique_id}"
329+
return f'"{base_name}_{unique_id}"'
313330

314331

315332
def _get_table_dict(embedding_dim: int) -> Dict:
@@ -322,7 +339,6 @@ def _get_table_dict(embedding_dim: int) -> Dict:
322339
return cols_dict
323340

324341

325-
@_handle_exceptions
326342
def _create_table(connection: Connection, table_name: str, embedding_dim: int) -> None:
327343
cols_dict = _get_table_dict(embedding_dim)
328344

@@ -365,6 +381,8 @@ def create_index(
365381
if connection is None:
366382
raise ValueError("Failed to acquire a connection.")
367383
if params:
384+
if "idx_name" in params:
385+
params["idx_name"] = _quote_indentifier(params["idx_name"])
368386
if params["idx_type"] == "HNSW":
369387
_create_hnsw_index(
370388
connection,
@@ -425,6 +443,7 @@ def _get_hnsw_index_ddl(
425443
raise ValueError(f"Invalid parameter: {key}")
426444
else:
427445
config = defaults
446+
config["idx_name"] = _get_index_name(str(config["idx_name"]))
428447

429448
# base SQL statement
430449
idx_name = config["idx_name"]
@@ -519,6 +538,7 @@ def _get_ivf_index_ddl(
519538
raise ValueError(f"Invalid parameter: {key}")
520539
else:
521540
config = defaults
541+
config["idx_name"] = _get_index_name(str(config["idx_name"]))
522542

523543
# base SQL statement
524544
idx_name = config["idx_name"]
@@ -577,6 +597,8 @@ async def acreate_index(
577597
) -> None:
578598
async def context(connection: Any) -> None:
579599
if params:
600+
if "idx_name" in params:
601+
params["idx_name"] = _quote_indentifier(params["idx_name"])
580602
if params["idx_type"] == "HNSW":
581603
await _acreate_hnsw_index(
582604
connection,
@@ -657,6 +679,7 @@ def drop_table_purge(client: Any, table_name: str) -> None:
657679
RuntimeError: If an error occurs while dropping the table.
658680
"""
659681
connection = _get_connection(client)
682+
table_name = _quote_indentifier(table_name)
660683
if connection is None:
661684
raise ValueError("Failed to acquire a connection.")
662685
if _table_exists(connection, table_name):
@@ -680,6 +703,7 @@ async def adrop_table_purge(client: Any, table_name: str) -> None:
680703
Raises:
681704
RuntimeError: If an error occurs while dropping the table.
682705
"""
706+
table_name = _quote_indentifier(table_name)
683707

684708
async def context(connection: Any) -> None:
685709
if await _atable_exists(connection, table_name):
@@ -706,6 +730,7 @@ def drop_index_if_exists(client: Any, index_name: str) -> None:
706730
RuntimeError: If an error occurs while dropping the index.
707731
"""
708732
connection = _get_connection(client)
733+
index_name = _quote_indentifier(index_name)
709734
if connection is None:
710735
raise ValueError("Failed to acquire a connection.")
711736
if _index_exists(connection, index_name):
@@ -729,6 +754,7 @@ async def adrop_index_if_exists(client: Any, index_name: str) -> None:
729754
Raises:
730755
RuntimeError: If an error occurs while dropping the index.
731756
"""
757+
index_name = _quote_indentifier(index_name)
732758

733759
async def context(connection: Any) -> None:
734760
if await _aindex_exists(connection, index_name):
@@ -874,8 +900,6 @@ async def _handle_context(
874900
def output_type_string_handler(cursor: Any, metadata: Any) -> Any:
875901
if metadata.type_code is oracledb.DB_TYPE_CLOB:
876902
return cursor.var(oracledb.DB_TYPE_LONG, arraysize=cursor.arraysize)
877-
if metadata.type_code is oracledb.DB_TYPE_BLOB:
878-
return cursor.var(oracledb.DB_TYPE_LONG_RAW, arraysize=cursor.arraysize)
879903
if metadata.type_code is oracledb.DB_TYPE_NCLOB:
880904
return cursor.var(oracledb.DB_TYPE_LONG_NVARCHAR, arraysize=cursor.arraysize)
881905

@@ -911,7 +935,7 @@ def __init__(
911935
Callable[[str], List[float]],
912936
Embeddings,
913937
],
914-
table_name: str,
938+
table_name: str, # case sensitive
915939
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
916940
query: Optional[str] = "What is a Oracle database",
917941
params: Optional[Dict[str, Any]] = None,
@@ -935,7 +959,7 @@ def __init__(
935959
)
936960

937961
embedding_dim = self.get_embedding_dimension()
938-
_create_table(connection, table_name, embedding_dim)
962+
_create_table(connection, self.table_name, embedding_dim)
939963

940964
@classmethod
941965
@_ahandle_exceptions
@@ -969,7 +993,7 @@ async def context(connection: Any) -> None:
969993
)
970994

971995
embedding_dim = await self.aget_embedding_dimension()
972-
await _acreate_table(connection, table_name, embedding_dim)
996+
await _acreate_table(connection, self.table_name, embedding_dim)
973997

974998
await _handle_context(client, context)
975999

@@ -1014,7 +1038,7 @@ def _initialize(
10141038
)
10151039
self.embedding_function = embedding_function
10161040
self.query = query
1017-
self.table_name = table_name
1041+
self.table_name = _quote_indentifier(table_name)
10181042
self.distance_strategy = distance_strategy
10191043
self.params = params
10201044

@@ -1311,9 +1335,13 @@ def similarity_search_by_vector_with_relevance_scores(
13111335
# filter results if filter is provided
13121336
for result in results:
13131337
metadata = result[2] or {}
1338+
page_content_str = result[1] if result[1] is not None else ""
1339+
1340+
if not isinstance(page_content_str, str):
1341+
raise Exception("Unexpected type:", type(page_content_str))
13141342

13151343
doc = Document(
1316-
page_content=(result[1] if result[1] is not None else ""),
1344+
page_content=page_content_str,
13171345
metadata=metadata,
13181346
)
13191347
distance = result[3]
@@ -1357,9 +1385,12 @@ async def context(connection: Any) -> List:
13571385
# filter results if filter is provided
13581386
for result in results:
13591387
metadata = result[2] or {}
1388+
page_content_str = result[1] if result[1] is not None else ""
1389+
if not isinstance(page_content_str, str):
1390+
raise Exception("Unexpected type:", type(page_content_str))
13601391

13611392
doc = Document(
1362-
page_content=(result[1] if result[1] is not None else ""),
1393+
page_content=page_content_str,
13631394
metadata=metadata,
13641395
)
13651396
distance = result[3]
@@ -1405,7 +1436,9 @@ def similarity_search_by_vector_returning_embeddings(
14051436
results = cursor.fetchall()
14061437

14071438
for result in results:
1408-
page_content_str = result[1]
1439+
page_content_str = result[1] if result[1] is not None else ""
1440+
if not isinstance(page_content_str, str):
1441+
raise Exception("Unexpected type:", type(page_content_str))
14091442
metadata = result[2] or {}
14101443

14111444
# apply filter if provided and matches; otherwise, add all
@@ -1459,7 +1492,9 @@ async def context(connection: Any) -> List:
14591492
results = await cursor.fetchall()
14601493

14611494
for result in results:
1462-
page_content_str = result[1]
1495+
page_content_str = result[1] if result[1] is not None else ""
1496+
if not isinstance(page_content_str, str):
1497+
raise Exception("Unexpected type:", type(page_content_str))
14631498
metadata = result[2] or {}
14641499

14651500
# apply filter if provided and matches; otherwise, add all

0 commit comments

Comments
 (0)