Skip to content

Commit d40a0ef

Browse files
authored
Merge pull request #7 from oracle/2.5.x
Release version 2.5.10
2 parents d594ed0 + 2c9b287 commit d40a0ef

File tree

104 files changed

+4493
-3218
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

104 files changed

+4493
-3218
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ Multiple extra dependencies can be installed together. For example:
7474
```python
7575
import ads
7676
from ads.common.auth import default_signer
77+
import oci
7778

78-
ads.set_auth(auth="api_key", profile="DEFAULT")
79+
ads.set_auth(auth="api_key", oci_config_location=oci.config.DEFAULT_LOCATION, profile="DEFAULT")
7980
bucket_name = <bucket-name>
8081
file_name = <file-name>
8182
namespace = <namespace>

ads/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111

1212
import IPython
13+
import oci
1314
from IPython import get_ipython
1415
from IPython.core.error import UsageError
1516

@@ -31,28 +32,39 @@
3132

3233
debug_mode = os.environ.get("DEBUG_MODE", False)
3334
documentation_mode = os.environ.get("DOCUMENTATION_MODE", "False") == "True"
35+
oci_config_path = oci.config.DEFAULT_LOCATION # "~/.oci/config"
3436
oci_key_profile = "DEFAULT"
3537
test_mode = os.environ.get("TEST_MODE", False)
3638
resource_principal_mode = bool(os.environ.get("RESOURCE_PRINCIPAL_MODE", False))
3739

3840

39-
def set_auth(auth="api_key", profile="DEFAULT"):
41+
def set_auth(auth="api_key", oci_config_location=oci.config.DEFAULT_LOCATION, profile="DEFAULT"):
    """
    Enable/disable resource principal identity or keypair identity in a notebook session.

    Parameters
    ----------
    auth: {'api_key', 'resource_principal'}, default 'api_key'
        Enable/disable resource principal identity or keypair identity in a notebook session
    oci_config_location: str, default oci.config.DEFAULT_LOCATION, which is '~/.oci/config'
        config file location
    profile: str, default 'DEFAULT'
        profile name for api keys config file
    """
    global resource_principal_mode
    global oci_config_path
    global oci_key_profile
    oci_key_profile = profile
    # Fall back to the SDK default location when the requested config file is
    # missing, so later client construction does not fail on a bad path.
    if os.path.exists(os.path.expanduser(oci_config_location)):
        oci_config_path = oci_config_location
    else:
        logging.warning(
            f"{oci_config_location} does not exist. Default value oci.config.DEFAULT_LOCATION is used instead."
        )
        oci_config_path = oci.config.DEFAULT_LOCATION
    if auth == "api_key":
        resource_principal_mode = False
    elif auth == "resource_principal":
        resource_principal_mode = True
5769

5870

ads/ads_version.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
{
2-
"version": "2.5.9"
2+
"version": "2.5.10"
33
}

ads/automl/provider.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
is_notebook,
2727
)
2828
from ads.dataset.label_encoder import DataFrameLabelEncoder
29+
from ads.dataset.helper import is_text_data
2930

3031
from IPython.core.display import display, HTML
3132

@@ -533,10 +534,11 @@ def train(self, **kwargs):
533534
)
534535

535536
self.train_start_time = time.time()
536-
if "time_budget" in kwargs:
537-
self.time_budget = kwargs.pop("time_budget")
538-
else:
539-
self.time_budget = 0 # unlimited
537+
538+
self.time_budget = kwargs.pop("time_budget", 0) # 0 means unlimited
539+
540+
self.col_types = kwargs.pop("col_types", None)
541+
540542
self.est = self._decide_estimator(**kwargs)
541543
with warnings.catch_warnings():
542544
warnings.simplefilter("ignore")
@@ -546,6 +548,7 @@ def train(self, **kwargs):
546548
X_valid=self.X_valid,
547549
y_valid=self.y_valid,
548550
time_budget=self.time_budget,
551+
col_types=self.col_types,
549552
)
550553
self.train_end_time = time.time()
551554
self.print_summary(max_rows=10)
@@ -613,8 +616,20 @@ def _decide_estimator(self, **kwargs):
613616
or self.ml_task_type == ml_task_types.MULTI_CLASS_TEXT_CLASSIFICATION
614617
):
615618
est = self.automl.Pipeline(
616-
task="classification", text=True, score_metric=score_metric, **kwargs
619+
task="classification", score_metric=score_metric, **kwargs
617620
)
621+
if not self.col_types:
622+
if len(self.X_train.columns) == 1:
623+
self.col_types = ['text']
624+
elif len(self.X_train.columns) == 2:
625+
self.col_types = ['text', 'text']
626+
else:
627+
raise ValueError("We detected a text classification problem. Pass " \
628+
"in `col_types = [<type of column1>, <type of column2>, ...]`." \
629+
" Valid types are: ['categorical', 'numerical', 'text', 'datetime'," \
630+
" 'timedelta']."
631+
)
632+
618633
elif self.ml_task_type == ml_task_types.REGRESSION:
619634
est = self.automl.Pipeline(
620635
task="regression", score_metric=score_metric, **kwargs

ads/bds/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2022 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

ads/bds/auth.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2022 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import os
9+
import subprocess
10+
from contextlib import contextmanager
11+
12+
13+
DEFAULT_KRB5_CONFIG_PATH = "~/.bds_config/krb5.conf"
14+
KRB5_CONFIG = "KRB5_CONFIG"
15+
16+
17+
class KRB5KinitError(Exception):
18+
"""KRB5KinitError class when kinit -kt command failed to generate cached ticket with the keytab file and the krb5 config file."""
19+
20+
pass
21+
22+
23+
def has_kerberos_ticket():
    """Check whether a cached Kerberos ticket exists.

    Returns
    -------
    bool
        True if ``klist -s`` exits with status 0, i.e. a valid ticket is cached.
    """
    # `klist -s` is silent and exits 0 only when a live ticket is cached;
    # the comparison already yields a bool, no `True if ... else False` needed.
    return subprocess.call(["klist", "-s"]) == 0
26+
27+
28+
def init_ccache_with_keytab(principal: str, keytab_file: str) -> None:
    """Initialize the Kerberos credential cache using a keytab file.

    Parameters
    ----------
    principal: str
        The unique identity to which Kerberos can assign tickets.
    keytab_file: str
        Path to your keytab file.

    Returns
    -------
    None
        Nothing.

    Raises
    ------
    KRB5KinitError
        If ``kinit`` exits with a non-zero status; carries kinit's stderr output.
    """
    # Pass the argument vector directly instead of building a command string
    # and splitting it, which would break on paths containing spaces.
    kinit_proc = subprocess.Popen(
        ["kinit", "-kt", keytab_file, principal], stderr=subprocess.PIPE
    )
    _, stderr_data = kinit_proc.communicate()

    # `!= 0` also catches negative return codes (process killed by a signal),
    # which a `> 0` check would silently ignore.
    if kinit_proc.returncode != 0:
        raise KRB5KinitError(stderr_data)
54+
55+
56+
@contextmanager
def krbcontext(
    principal: str, keytab_path: str, kerb5_path: str = DEFAULT_KRB5_CONFIG_PATH
) -> None:
    """A context manager for Kerberos-related actions.

    Provides a Kerberos context that you can put code inside. On entry the
    credential cache is initialized from the keytab when no cached ticket
    exists; when a valid ticket is already cached, entry is a no-op.

    Parameters
    ----------
    principal: str
        The unique identity to which Kerberos can assign tickets.
    keytab_path: str
        Path to your keytab file.
    kerb5_path: (str, optional).
        Path to your krb5 config file.

    Returns
    -------
    None
        Nothing.

    Examples
    --------
    >>> from ads.bds.auth import krbcontext
    >>> from pyhive import hive
    >>> with krbcontext(principal = "your_principal", keytab_path = "your_keytab_path"):
    >>>     hive_cursor = hive.connect(host="your_hive_host",
    ...                                port="your_hive_port",
    ...                                auth='KERBEROS',
    ...                                kerberos_service_name="hive").cursor()
    """
    # Delegate all ticket handling to refresh_ticket, then hand control
    # to the wrapped block.
    refresh_ticket(kerb5_path=kerb5_path, principal=principal, keytab_path=keytab_path)
    yield
91+
92+
93+
def refresh_ticket(
94+
principal: str, keytab_path: str, kerb5_path: str = DEFAULT_KRB5_CONFIG_PATH
95+
) -> None:
96+
"""generate new cached ticket based on the principal and keytab file path.
97+
98+
Parameters
99+
----------
100+
principal: str
101+
The unique identity to which Kerberos can assign tickets.
102+
keytab_path: str
103+
Path to your keytab file.
104+
kerb5_path: (str, optional).
105+
Path to your krb5 config file.
106+
107+
Returns
108+
-------
109+
None
110+
Nothing.
111+
112+
Examples
113+
--------
114+
>>> from ads.bds.auth import refresh_ticket
115+
>>> from pyhive import hive
116+
>>> refresh_ticket(principal = "your_principal", keytab_path = "your_keytab_path")
117+
>>> hive_cursor = hive.connect(host="your_hive_host",
118+
... port="your_hive_port",
119+
... auth='KERBEROS',
120+
... kerberos_service_name="hive").cursor()
121+
"""
122+
os.environ[KRB5_CONFIG] = os.path.abspath(os.path.expanduser(kerb5_path))
123+
if not os.path.exists(os.environ[KRB5_CONFIG]):
124+
raise FileNotFoundError(f"krb5 config file not found in {kerb5_path}.")
125+
if not has_kerberos_ticket():
126+
init_ccache_with_keytab(principal, keytab_path)

ads/bds/big_data_service.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2022 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
from abc import ABC, abstractmethod
8+
from time import time
9+
from typing import Dict, Iterator, List, Optional, Union
10+
11+
import impala
12+
import impala.dbapi as impyla # noqa
13+
import pandas as pd
14+
from impala.error import Error as ImpylaError # noqa
15+
from impala.error import HiveServer2Error as HS2Error # noqa
16+
17+
18+
class HiveConnection(ABC):
    """Base class Interface."""

    def __init__(self, **params):
        """Store the connection parameters; no connection is opened here."""
        self.params = params
        # Placeholder until a concrete subclass establishes the connection.
        self.con = None

    @abstractmethod
    def get_cursor(self):
        """return the cursor from the connection.

        Returns
        -------
        HiveServer2Cursor:
            cursor using a specific client.
        """
        return None
36+
37+
38+
class ImpylaHiveConnection(HiveConnection):
    """ImpalaHiveConnection class which uses impyla client."""

    def __init__(self, **params):
        """set up the impala connection."""
        # The base class already records params and initializes `con` to None.
        super().__init__(**params)

    def get_cursor(self) -> "impala.hiveserver2.HiveServer2Cursor":
        """return the cursor from the connection.

        Returns
        -------
        impala.hiveserver2.HiveServer2Cursor:
            cursor using impyla client.
        """
        return None
55+
56+
57+
class OracleHiveConnection(ImpylaHiveConnection):
    # Hive connection for Oracle Big Data Service, built on the impyla client.
    # NOTE(review): __init__, insert and query are visible here only as stubs
    # (`pass` / `return None`) — presumably implemented or overridden elsewhere.

    def __init__(
        self,
        host: str,
        port: str,
        **kwargs,
    ):
        """Initiate the connection.

        Parameters
        ----------
        host: str
            Hive host name.
        port: str
            Hive port.
        kwargs:
            Other connection parameters accepted by the client.
        """
        pass

    def insert(
        self,
        table_name: str,
        df: pd.DataFrame,
        if_exists: str,
        partition: List[str] = None,
    ):
        """insert a table from a pandas dataframe.

        Parameters
        ----------
        table_name (str):
            Table Name.
        df (pd.DataFrame):
            Data to be injected to the database.
        if_exists (str):
            Whether to replace, append or fail if the table already exists.
        partition (List[str], optional): Defaults to None.
            For partitioned tables, indicate the partition that's being
            inserted into, either with an ordered list of partition keys or a
            dict of partition field name to value. For example for the
            partition (year=2007, month=7), this can be either (2007, 7) or
            {'year': 2007, 'month': 7}.

        Raises
        ------
        ValueError
            If ``if_exists`` is not one of 'fail', 'replace', 'append'.
        """
        if if_exists not in ["fail", "replace", "append"]:
            # The f-prefix was missing originally, so the message printed the
            # literal text "{if_exists}" instead of the offending value.
            raise ValueError(
                f"Unknown option `if_exists`={if_exists}. Valid options are 'fail', 'replace', 'append'"
            )
        pass

    def _fetch_by_batch(
        self, cursor: "impala.hiveserver2.HiveServer2Cursor", chunksize: int
    ):
        """fetch the data by batch of chunksize."""
        # Yield successive batches until fetchmany returns an empty list.
        while True:
            rows = cursor.fetchmany(chunksize)
            if rows:
                yield rows
            else:
                break

    def query(
        self,
        sql: str,
        bind_variables: Optional[Dict] = None,
        chunksize: Optional[int] = None,
    ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
        """Query data which support select statement.

        Parameters
        ----------
        sql (str):
            sql query.
        bind_variables (Optional[Dict]):
            Parameters to be bound to variables in the SQL query, if any.
            Impyla supports all DB API `paramstyle`s, including `qmark`,
            `numeric`, `named`, `format`, `pyformat`.
        chunksize (Optional[int]): . Defaults to None.
            chunksize of each of the dataframe in the iterator.

        Returns
        -------
        Union[pd.DataFrame, Iterator[pd.DataFrame]]:
            A pandas DataFrame or a pandas DataFrame iterator.
        """
        return None

0 commit comments

Comments
 (0)