Skip to content

Commit 7cdaee1

Browse files
authored
Adw secret helper (#997)
1 parent bff71a5 commit 7cdaee1

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

ads/oracledb/oracle_db.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
"""
@@ -17,19 +16,20 @@
1716
Note: We need to account for cx_Oracle though oracledb can operate in thick mode. The end user may be is using one of the old conda packs or an environment where cx_Oracle is the only available driver.
1817
"""
1918

20-
from ads.common.utils import ORACLE_DEFAULT_PORT
21-
2219
import logging
23-
import numpy as np
2420
import os
25-
import pandas as pd
2621
import tempfile
27-
from time import time
28-
from typing import Dict, Optional, List, Union, Iterator
2922
import zipfile
23+
from time import time
24+
from typing import Dict, Iterator, List, Optional, Union
25+
26+
import numpy as np
27+
import pandas as pd
28+
3029
from ads.common.decorator.runtime_dependency import (
3130
OptionalDependency,
3231
)
32+
from ads.common.utils import ORACLE_DEFAULT_PORT
3333

3434
logger = logging.getLogger("ads.oracle_connector")
3535
CX_ORACLE = "cx_Oracle"
@@ -40,17 +40,17 @@
4040
import oracledb as oracle_driver # Both the driver share same signature for the APIs that we are using.
4141

4242
PYTHON_DRIVER_NAME = PYTHON_ORACLEDB
43-
except:
43+
except ModuleNotFoundError:
4444
logger.info("oracledb package not found. Trying to load cx_Oracle")
4545
try:
4646
import cx_Oracle as oracle_driver
4747

4848
PYTHON_DRIVER_NAME = CX_ORACLE
49-
except ModuleNotFoundError:
49+
except ModuleNotFoundError as err2:
5050
raise ModuleNotFoundError(
5151
f"Neither `oracledb` nor `cx_Oracle` module was not found. Please run "
5252
f"`pip install {OptionalDependency.DATA}`."
53-
)
53+
) from err2
5454

5555

5656
class OracleRDBMSConnection(oracle_driver.Connection):
@@ -75,7 +75,7 @@ def __init__(
7575
logger.info(
7676
"Running oracledb driver in thick mode. For mTLS based connection, thick mode is default."
7777
)
78-
except:
78+
except Exception:
7979
logger.info(
8080
"Could not use thick mode. The driver is running in thin mode. System might prompt for passphrase"
8181
)
@@ -154,7 +154,6 @@ def insert(
154154
batch_size=100000,
155155
encoding="utf-8",
156156
):
157-
158157
if if_exists not in ["fail", "replace", "append"]:
159158
raise ValueError(
160159
f"Unknown option `if_exists`={if_exists}. Valid options are 'fail', 'replace', 'append'"
@@ -173,7 +172,6 @@ def insert(
173172
df_orcl.columns = df_orcl.columns.str.replace(r"\W+", "_", regex=True)
174173
table_exist = True
175174
with self.cursor() as cursor:
176-
177175
if if_exists != "replace":
178176
try:
179177
cursor.execute(f"SELECT 1 from {table_name} FETCH NEXT 1 ROWS ONLY")
@@ -275,7 +273,6 @@ def chunks(lst: List, batch_size: int):
275273
yield lst[i : i + batch_size]
276274

277275
for batch in chunks(record_data, batch_size=batch_size):
278-
279276
cursor.executemany(sql, batch, batcherrors=True)
280277

281278
for error in cursor.getbatcherrors():
@@ -304,7 +301,6 @@ def _fetch_by_batch(self, cursor, chunksize):
304301
def query(
305302
self, sql: str, bind_variables: Optional[Dict], chunksize=None
306303
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
307-
308304
start_time = time()
309305

310306
cursor = self.cursor()
@@ -315,10 +311,8 @@ def query(
315311
cursor.execute(sql, **bind_variables)
316312
columns = [row[0] for row in cursor.description]
317313
df = iter(
318-
(
319-
pd.DataFrame(data=rows, columns=columns)
320-
for rows in self._fetch_by_batch(cursor, chunksize)
321-
)
314+
pd.DataFrame(data=rows, columns=columns)
315+
for rows in self._fetch_by_batch(cursor, chunksize)
322316
)
323317

324318
else:
@@ -332,3 +326,21 @@ def query(
332326
)
333327

334328
return df
329+
330+
331+
def get_adw_connection(vault_secret_id: str) -> "oracledb.Connection":
332+
"""Creates ADW connection from the credentials stored in the vault"""
333+
import oracledb
334+
335+
from ads.secrets.adb import ADBSecretKeeper
336+
337+
secret = vault_secret_id
338+
339+
logging.getLogger().debug("A secret id was used to retrieve credentials.")
340+
creds = ADBSecretKeeper.load_secret(secret).to_dict()
341+
user = creds.pop("user_name", None)
342+
password = creds.pop("password", None)
343+
if not user or not password:
344+
raise ValueError(f"The user or password is missing in {secret}")
345+
logging.getLogger().debug("Downloaded secrets successfully.")
346+
return oracledb.connect(user=user, password=password, **creds)

0 commit comments

Comments
 (0)