diff --git a/README.rst b/README.rst
index 89c54532..cbdcce10 100644
--- a/README.rst
+++ b/README.rst
@@ -1,31 +1,31 @@
-================================
-Project is currently unsupported
-================================
+========================================================
+PyHive project has been donated to Apache Kyuubi
+========================================================
+You can follow its development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive
-.. image:: https://travis-ci.org/dropbox/PyHive.svg?branch=master
- :target: https://travis-ci.org/dropbox/PyHive
-.. image:: https://img.shields.io/codecov/c/github/dropbox/PyHive.svg
+Legacy notes / instructions
+===========================
-======
PyHive
-======
+**********
+
PyHive is a collection of Python `DB-API `_ and
-`SQLAlchemy `_ interfaces for `Presto `_ and
-`Hive `_.
+`SQLAlchemy `_ interfaces for `Presto `_,
+`Hive `_ and `Trino `_.
Usage
-=====
+**********
DB-API
------
.. code-block:: python
from pyhive import presto # or import hive or import trino
- cursor = presto.connect('localhost').cursor()
+ cursor = presto.connect('localhost').cursor() # or use hive.connect() / trino.connect()
cursor.execute('SELECT * FROM my_awesome_data LIMIT 10')
print cursor.fetchone()
print cursor.fetchall()
@@ -61,7 +61,7 @@ In Python 3.7 `async` became a keyword; you can use `async_` instead:
SQLAlchemy
----------
-First install this package to register it with SQLAlchemy (see ``setup.py``).
+First install this package to register it with SQLAlchemy (see ``entry_points`` in ``setup.py``).
.. code-block:: python
@@ -71,9 +71,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
# Presto
engine = create_engine('presto://localhost:8080/hive/default')
# Trino
- engine = create_engine('trino://localhost:8080/hive/default')
+ engine = create_engine('trino+pyhive://localhost:8080/hive/default')
# Hive
engine = create_engine('hive://localhost:10000/default')
+
+ # SQLAlchemy < 2.0
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()
@@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``).
logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True)
print select([func.count('*')], from_obj=logs).scalar()
+ # SQLAlchemy >= 2.0
+ metadata_obj = MetaData()
+ books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String))
+ metadata_obj.create_all(engine)
+ inspector = inspect(engine)
+ inspector.get_columns('books')
+
+ with engine.connect() as con:
+ data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" },
+ { "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }]
+ con.execute(books.insert(), data[0])
+ result = con.execute(text("select * from books"))
+ print(result.fetchall())
+
Note: query generation functionality is not exhaustive or fully tested, but there should be no
problem with raw SQL.
@@ -101,7 +117,7 @@ Passing session configuration
'session_props': {'query_max_run_time': '1234m'}}
)
create_engine(
- 'trino://user@host:443/hive',
+ 'trino+pyhive://user@host:443/hive',
connect_args={'protocol': 'https',
'session_props': {'query_max_run_time': '1234m'}}
)
@@ -116,27 +132,30 @@ Passing session configuration
)
Requirements
-============
+************
Install using
-- ``pip install 'pyhive[hive]'`` for the Hive interface and
-- ``pip install 'pyhive[presto]'`` for the Presto interface.
+- ``pip install 'pyhive[hive]'`` or ``pip install 'pyhive[hive_pure_sasl]'`` for the Hive interface
+- ``pip install 'pyhive[presto]'`` for the Presto interface
- ``pip install 'pyhive[trino]'`` for the Trino interface
+Note: the ``'pyhive[hive]'`` extra uses `sasl `_, which doesn't support Python 3.11; see this `github issue `_.
+Hence PyHive also supports `pure-sasl `_ via the additional extra ``'pyhive[hive_pure_sasl]'``, which supports Python 3.11.
+
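+For example, the same DB-API call works with either extra installed; the sketch below (host, port and username are placeholders) relies on PyHive picking whichever SASL implementation is available at connect time:
+
+.. code-block:: python
+
+    from pyhive import hive
+
+    # auth='NONE', 'CUSTOM' and 'LDAP' all go through the SASL PLAIN mechanism;
+    # the sasl package is used when installed, otherwise pure-sasl is used.
+    cursor = hive.connect(host='localhost', port=10000, username='me', auth='NONE').cursor()
+    cursor.execute('SHOW DATABASES')
+    print(cursor.fetchall())
+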
PyHive works with
- Python 2.7 / Python 3
-- For Presto: Presto install
-- For Trino: Trino install
+- For Presto: `Presto installation `_
+- For Trino: `Trino installation `_
- For Hive: `HiveServer2 `_ daemon
Changelog
-=========
+*********
See https://github.com/dropbox/PyHive/releases.
Contributing
-============
+************
- Please fill out the Dropbox Contributor License Agreement at https://opensource.dropbox.com/cla/ and note this in your pull request.
- Changes must come with tests, with the exception of trivial things like fixing comments. See .travis.yml for the test environment setup.
- Notes on project scope:
@@ -146,8 +165,28 @@ Contributing
- We prefer having a small number of generic features over a large number of specialized, inflexible features.
For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option.
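+
+For example, the ``requests_session`` argument mentioned above accepts a pre-configured session (a minimal sketch; the header name and value are placeholders):
+
+.. code-block:: python
+
+    import requests
+    from pyhive import presto
+
+    session = requests.Session()
+    session.headers['X-Example-Header'] = 'value'  # any per-request customization
+    cursor = presto.connect('localhost', requests_session=session).cursor()
+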
+Tips for test environment setup
+****************************************
+You can set up a test environment by following ``.travis.yml`` in this repository. It uses `Cloudera's CDH 5 `_, which requires a username and password to download.
+Obtaining those credentials may not be feasible for everyone, so alternative instructions for setting up a test environment are given below.
+
+You can clone `this repository `_, which has a Docker Compose setup for Presto and Hive.
+Add the lines below to its ``docker-compose.yaml`` to start Trino in the same environment::
+
+ trino:
+ image: trinodb/trino:351
+ ports:
+ - "18080:18080"
+ volumes:
+ - ./trino:/etc/trino
+
+Note: the ``./trino`` directory mounted as a Docker volume above is the `trino config from PyHive repository `_.
+
+Then run::
+
+ docker-compose up -d
+
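+Once the containers are up, the Trino endpoint used by the tests can be sanity-checked with a quick query (a sketch; port 18080 matches the mapping above)::
+
+    python -c "from pyhive import trino; cur = trino.connect('localhost', port=18080).cursor(); cur.execute('SELECT 1'); print(cur.fetchall())"
+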
Testing
-=======
+*******
.. image:: https://travis-ci.org/dropbox/PyHive.svg
:target: https://travis-ci.org/dropbox/PyHive
.. image:: http://codecov.io/github/dropbox/PyHive/coverage.svg?branch=master
@@ -166,7 +205,7 @@ WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and `
database called ``pyhive_test_database``.
Updating TCLIService
-====================
+********************
The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the
``generate.py`` file can be used: ``python generate.py ``. When left blank, the
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 0bf6d8a7..40bb605a 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -12,6 +12,8 @@ pytest-timeout==1.2.0
requests>=1.0.0
requests_kerberos>=0.12.0
sasl>=0.2.1
+pure-sasl>=0.6.2
+kerberos>=1.3.0
thrift>=0.10.0
#thrift_sasl>=0.1.0
git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches
diff --git a/pyhive/__init__.py b/pyhive/__init__.py
index 8ede6abb..0a6bb1f6 100644
--- a/pyhive/__init__.py
+++ b/pyhive/__init__.py
@@ -1,3 +1,3 @@
from __future__ import absolute_import
from __future__ import unicode_literals
-__version__ = '0.6.3'
+__version__ = '0.7.0'
diff --git a/pyhive/common.py b/pyhive/common.py
index c8b23cef..51b2207f 100644
--- a/pyhive/common.py
+++ b/pyhive/common.py
@@ -18,6 +18,11 @@
from future.utils import with_metaclass
from itertools import islice
+try:
+ from collections.abc import Iterable
+except ImportError:
+ from collections import Iterable
+
class DBAPICursor(with_metaclass(abc.ABCMeta, object)):
"""Base class for some common DB-API logic"""
@@ -245,7 +250,7 @@ def escape_item(self, item):
return self.escape_number(item)
elif isinstance(item, basestring):
return self.escape_string(item)
- elif isinstance(item, collections.Iterable):
+ elif isinstance(item, Iterable):
return self.escape_sequence(item)
elif isinstance(item, datetime.datetime):
return self.escape_datetime(item, self._DATETIME_FORMAT)
diff --git a/pyhive/hive.py b/pyhive/hive.py
index 7922ed06..7cdd1491 100644
--- a/pyhive/hive.py
+++ b/pyhive/hive.py
@@ -54,6 +54,45 @@
}
+def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
+ import sasl
+ sasl_client = sasl.Client()
+ sasl_client.setAttr('host', host)
+
+ if sasl_auth == 'GSSAPI':
+ sasl_client.setAttr('service', service)
+ elif sasl_auth == 'PLAIN':
+ sasl_client.setAttr('username', username)
+ sasl_client.setAttr('password', password)
+ else:
+ raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
+
+ sasl_client.init()
+ return sasl_client
+
+
+def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
+ from pyhive.sasl_compat import PureSASLClient
+
+ if sasl_auth == 'GSSAPI':
+ sasl_kwargs = {'service': service}
+ elif sasl_auth == 'PLAIN':
+ sasl_kwargs = {'username': username, 'password': password}
+ else:
+ raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
+
+ return PureSASLClient(host=host, **sasl_kwargs)
+
+
+def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
+ try:
+ return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
+ # The sasl library is available
+ except ImportError:
+ # Fallback to pure-sasl library
+ return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password)
+
+
def _parse_timestamp(value):
if value:
match = _TIMESTAMP_PATTERN.match(value)
@@ -232,7 +271,6 @@ def __init__(
self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
# Defer import so package dependency is optional
- import sasl
import thrift_sasl
if auth == 'KERBEROS':
@@ -243,20 +281,8 @@ def __init__(
if password is None:
# Password doesn't matter in NONE mode, just needs to be nonempty.
password = 'x'
-
- def sasl_factory():
- sasl_client = sasl.Client()
- sasl_client.setAttr('host', host)
- if sasl_auth == 'GSSAPI':
- sasl_client.setAttr('service', kerberos_service_name)
- elif sasl_auth == 'PLAIN':
- sasl_client.setAttr('username', username)
- sasl_client.setAttr('password', password)
- else:
- raise AssertionError
- sasl_client.init()
- return sasl_client
- self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
+
+ self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket)
else:
# All HS2 config options:
# https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
diff --git a/pyhive/presto.py b/pyhive/presto.py
index a38cd891..3217f4c2 100644
--- a/pyhive/presto.py
+++ b/pyhive/presto.py
@@ -9,6 +9,8 @@
from __future__ import unicode_literals
from builtins import object
+from decimal import Decimal
+
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
@@ -34,6 +36,11 @@
_logger = logging.getLogger(__name__)
+TYPES_CONVERTER = {
+ "decimal": Decimal,
+ # As of Presto 0.69, binary data is returned as the varbinary type in base64 format
+ "varbinary": base64.b64decode
+}
class PrestoParamEscaper(common.ParamEscaper):
def escape_datetime(self, item, format):
@@ -307,14 +314,13 @@ def _fetch_more(self):
"""Fetch the next URI and update state"""
self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs))
- def _decode_binary(self, rows):
- # As of Presto 0.69, binary data is returned as the varbinary type in base64 format
- # This function decodes base64 data in place
+ def _process_data(self, rows):
for i, col in enumerate(self.description):
- if col[1] == 'varbinary':
+ col_type = col[1].split("(")[0].lower()
+ if col_type in TYPES_CONVERTER:
for row in rows:
if row[i] is not None:
- row[i] = base64.b64decode(row[i])
+ row[i] = TYPES_CONVERTER[col_type](row[i])
def _process_response(self, response):
"""Given the JSON response from Presto's REST API, update the internal state with the next
@@ -341,7 +347,7 @@ def _process_response(self, response):
if 'data' in response_json:
assert self._columns
new_data = response_json['data']
- self._decode_binary(new_data)
+ self._process_data(new_data)
self._data += map(tuple, new_data)
if 'nextUri' not in response_json:
self._state = self._STATE_FINISHED
diff --git a/pyhive/sasl_compat.py b/pyhive/sasl_compat.py
new file mode 100644
index 00000000..dc65abe9
--- /dev/null
+++ b/pyhive/sasl_compat.py
@@ -0,0 +1,56 @@
+# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py
+# which uses Apache-2.0 license as of 21 May 2023.
+# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl
+# via PR https://github.com/cloudera/impyla/pull/179
+# Even though thrift_sasl lists pure-sasl as a dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34
+# it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82
+# Hence this code is required for the fallback to work.
+
+
+from puresasl.client import SASLClient, SASLError
+from contextlib import contextmanager
+
+@contextmanager
+def error_catcher(self, Exc = Exception):
+ try:
+ self.error = None
+ yield
+ except Exc as e:
+ self.error = str(e)
+
+
+class PureSASLClient(SASLClient):
+ def __init__(self, *args, **kwargs):
+ self.error = None
+ super(PureSASLClient, self).__init__(*args, **kwargs)
+
+ def start(self, mechanism):
+ with error_catcher(self, SASLError):
+ if isinstance(mechanism, list):
+ self.choose_mechanism(mechanism)
+ else:
+ self.choose_mechanism([mechanism])
+ return True, self.mechanism, self.process()
+ # else
+ return False, mechanism, None
+
+ def encode(self, incoming):
+ with error_catcher(self):
+ return True, self.unwrap(incoming)
+ # else
+ return False, None
+
+ def decode(self, outgoing):
+ with error_catcher(self):
+ return True, self.wrap(outgoing)
+ # else
+ return False, None
+
+ def step(self, challenge=None):
+ with error_catcher(self):
+ return True, self.process(challenge)
+ # else
+ return False, None
+
+ def getError(self):
+ return self.error
diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py
index 2ef49652..e2244525 100644
--- a/pyhive/sqlalchemy_hive.py
+++ b/pyhive/sqlalchemy_hive.py
@@ -13,11 +13,22 @@
import re
from sqlalchemy import exc
-from sqlalchemy import processors
+from sqlalchemy.sql import text
+try:
+ from sqlalchemy import processors
+except ImportError:
+ # Required for SQLAlchemy>=2.0
+ from sqlalchemy.engine import processors
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
-from sqlalchemy.databases import mysql
+try:
+ from sqlalchemy.databases import mysql
+ mysql_tinyinteger = mysql.MSTinyInteger
+except ImportError:
+ # Required for SQLAlchemy>2.0
+ from sqlalchemy.dialects import mysql
+ mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
@@ -121,7 +132,7 @@ def __init__(self, dialect):
_type_map = {
'boolean': types.Boolean,
- 'tinyint': mysql.MSTinyInteger,
+ 'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'int': types.Integer,
'bigint': types.BigInteger,
@@ -228,8 +239,8 @@ def _translate_colname(self, colname):
class HiveDialect(default.DefaultDialect):
- name = b'hive'
- driver = b'thrift'
+ name = 'hive'
+ driver = 'thrift'
execution_ctx_cls = HiveExecutionContext
preparer = HiveIdentifierPreparer
statement_compiler = HiveCompiler
@@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect):
supports_multivalues_insert = True
type_compiler = HiveTypeCompiler
supports_sane_rowcount = False
+ supports_statement_cache = False
@classmethod
def dbapi(cls):
return hive
+
+ @classmethod
+ def import_dbapi(cls):
+ return hive
def create_connect_args(self, url):
kwargs = {
@@ -265,7 +281,7 @@ def create_connect_args(self, url):
def get_schema_names(self, connection, **kw):
# Equivalent to SHOW DATABASES
- return [row[0] for row in connection.execute('SHOW SCHEMAS')]
+ return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))]
def get_view_names(self, connection, schema=None, **kw):
# Hive does not provide functionality to query tableType
@@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema):
# Using DESCRIBE works but is uglier.
try:
# This needs the table name to be unescaped (no backticks).
- rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall()
+ rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall()
except exc.OperationalError as e:
# Does the table exist?
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}'
@@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema):
raise exc.NoSuchTableError(full_table)
return rows
- def has_table(self, connection, table_name, schema=None):
+ def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
@@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' IN ' + self.identifier_preparer.quote_identifier(schema)
- return [row[0] for row in connection.execute(query)]
+ return [row[0] for row in connection.execute(text(query))]
def do_rollback(self, dbapi_connection):
# No transactions for Hive
diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py
index a199ebe1..bfe1ba04 100644
--- a/pyhive/sqlalchemy_presto.py
+++ b/pyhive/sqlalchemy_presto.py
@@ -9,11 +9,19 @@
from __future__ import unicode_literals
import re
+import sqlalchemy
from sqlalchemy import exc
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
-from sqlalchemy.databases import mysql
+from sqlalchemy.sql import text
+try:
+ from sqlalchemy.databases import mysql
+ mysql_tinyinteger = mysql.MSTinyInteger
+except ImportError:
+ # Required for SQLAlchemy>=2.0
+ from sqlalchemy.dialects import mysql
+ mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
@@ -21,6 +29,7 @@
from pyhive import presto
from pyhive.common import UniversalSet
+sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
# Just quote everything to make things simpler / easier to upgrade
@@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
_type_map = {
'boolean': types.Boolean,
- 'tinyint': mysql.MSTinyInteger,
+ 'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
@@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect):
supports_multivalues_insert = True
supports_unicode_statements = True
supports_unicode_binds = True
+ supports_statement_cache = False
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
@@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect):
@classmethod
def dbapi(cls):
return presto
+
+ @classmethod
+ def import_dbapi(cls):
+ return presto
def create_connect_args(self, url):
db_parts = (url.database or 'hive').split('/')
@@ -108,14 +122,14 @@ def create_connect_args(self, url):
return [], kwargs
def get_schema_names(self, connection, **kw):
- return [row.Schema for row in connection.execute('SHOW SCHEMAS')]
+ return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]
def _get_table_columns(self, connection, table_name, schema):
full_table = self.identifier_preparer.quote_identifier(table_name)
if schema:
full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
try:
- return connection.execute('SHOW COLUMNS FROM {}'.format(full_table))
+ return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
except (presto.DatabaseError, exc.DatabaseError) as e:
# Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
# it successfully does in the Hive version. The difference with Presto is that this
@@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema):
else:
raise
- def has_table(self, connection, table_name, schema=None):
+ def has_table(self, connection, table_name, schema=None, **kw):
try:
self._get_table_columns(connection, table_name, schema)
return True
@@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
# - a boolean column named "Partition Key"
# - a string in the "Comment" column
# - a string in the "Extra" column
+ if sqlalchemy_version >= 1.4:
+ row = row._mapping
is_partition_key = (
(part_key in row and row[part_key])
or row['Comment'].startswith(part_key)
@@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw):
query = 'SHOW TABLES'
if schema:
query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
- return [row.Table for row in connection.execute(query)]
+ return [row.Table for row in connection.execute(text(query))]
def do_rollback(self, dbapi_connection):
# No transactions for Presto
diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py
index 4b2b3698..11be2a6c 100644
--- a/pyhive/sqlalchemy_trino.py
+++ b/pyhive/sqlalchemy_trino.py
@@ -13,7 +13,13 @@
from sqlalchemy import types
from sqlalchemy import util
# TODO shouldn't use mysql type
-from sqlalchemy.databases import mysql
+try:
+ from sqlalchemy.databases import mysql
+ mysql_tinyinteger = mysql.MSTinyInteger
+except ImportError:
+ # Required for SQLAlchemy>=2.0
+ from sqlalchemy.dialects import mysql
+ mysql_tinyinteger = mysql.base.MSTinyInteger
from sqlalchemy.engine import default
from sqlalchemy.sql import compiler
from sqlalchemy.sql.compiler import SQLCompiler
@@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer):
_type_map = {
'boolean': types.Boolean,
- 'tinyint': mysql.MSTinyInteger,
+ 'tinyint': mysql_tinyinteger,
'smallint': types.SmallInteger,
'integer': types.Integer,
'bigint': types.BigInteger,
@@ -67,7 +73,12 @@ def visit_TEXT(self, type_, **kw):
class TrinoDialect(PrestoDialect):
name = 'trino'
+ supports_statement_cache = False
@classmethod
def dbapi(cls):
return trino
+
+ @classmethod
+ def import_dbapi(cls):
+ return trino
diff --git a/pyhive/tests/sqlalchemy_test_case.py b/pyhive/tests/sqlalchemy_test_case.py
index 652e05f4..db89d57b 100644
--- a/pyhive/tests/sqlalchemy_test_case.py
+++ b/pyhive/tests/sqlalchemy_test_case.py
@@ -3,6 +3,7 @@
from __future__ import unicode_literals
import abc
+import re
import contextlib
import functools
@@ -14,8 +15,10 @@
from sqlalchemy.schema import Index
from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
-from sqlalchemy.sql import expression
+from sqlalchemy.sql import expression, text
+from sqlalchemy import String
+sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
def with_engine_connection(fn):
"""Pass a connection to the given function and handle cleanup.
@@ -32,19 +35,33 @@ def wrapped_fn(self, *args, **kwargs):
engine.dispose()
return wrapped_fn
+def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks):
+ if sqlalchemy_version >= 1.4:
+ insp = sqlalchemy.inspect(engine)
+ insp.reflect_table(
+ table,
+ include_columns=include_columns,
+ exclude_columns=exclude_columns,
+ resolve_fks=resolve_fks,
+ )
+ else:
+ engine.dialect.reflecttable(
+ connection, table, include_columns=include_columns,
+ exclude_columns=exclude_columns, resolve_fks=resolve_fks)
+
class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)):
@with_engine_connection
def test_basic_query(self, engine, connection):
- rows = connection.execute('SELECT * FROM one_row').fetchall()
+ rows = connection.execute(text('SELECT * FROM one_row')).fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name
self.assertEqual(len(rows[0]), 1)
@with_engine_connection
def test_one_row_complex_null(self, engine, connection):
- one_row_complex_null = Table('one_row_complex_null', MetaData(bind=engine), autoload=True)
- rows = one_row_complex_null.select().execute().fetchall()
+ one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine)
+ rows = connection.execute(one_row_complex_null.select()).fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(list(rows[0]), [None] * len(rows[0]))
@@ -53,27 +70,26 @@ def test_reflect_no_such_table(self, engine, connection):
"""reflecttable should throw an exception on an invalid table"""
self.assertRaises(
NoSuchTableError,
- lambda: Table('this_does_not_exist', MetaData(bind=engine), autoload=True))
+ lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine))
self.assertRaises(
NoSuchTableError,
- lambda: Table('this_does_not_exist', MetaData(bind=engine),
- schema='also_does_not_exist', autoload=True))
+ lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine))
@with_engine_connection
def test_reflect_include_columns(self, engine, connection):
"""When passed include_columns, reflecttable should filter out other columns"""
- one_row_complex = Table('one_row_complex', MetaData(bind=engine))
- engine.dialect.reflecttable(
- connection, one_row_complex, include_columns=['int'],
+
+ one_row_complex = Table('one_row_complex', MetaData())
+ reflect_table(engine, connection, one_row_complex, include_columns=['int'],
exclude_columns=[], resolve_fks=True)
+
self.assertEqual(len(one_row_complex.c), 1)
self.assertIsNotNone(one_row_complex.c.int)
self.assertRaises(AttributeError, lambda: one_row_complex.c.tinyint)
@with_engine_connection
def test_reflect_with_schema(self, engine, connection):
- dummy = Table('dummy_table', MetaData(bind=engine), schema='pyhive_test_database',
- autoload=True)
+ dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine)
self.assertEqual(len(dummy.c), 1)
self.assertIsNotNone(dummy.c.a)
@@ -81,22 +97,22 @@ def test_reflect_with_schema(self, engine, connection):
@with_engine_connection
def test_reflect_partitions(self, engine, connection):
"""reflecttable should get the partition column as an index"""
- many_rows = Table('many_rows', MetaData(bind=engine), autoload=True)
+ many_rows = Table('many_rows', MetaData(), autoload_with=engine)
self.assertEqual(len(many_rows.c), 2)
self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)}))
- many_rows = Table('many_rows', MetaData(bind=engine))
- engine.dialect.reflecttable(
- connection, many_rows, include_columns=['a'],
+ many_rows = Table('many_rows', MetaData())
+ reflect_table(engine, connection, many_rows, include_columns=['a'],
exclude_columns=[], resolve_fks=True)
+
self.assertEqual(len(many_rows.c), 1)
self.assertFalse(many_rows.c.a.index)
self.assertFalse(many_rows.indexes)
- many_rows = Table('many_rows', MetaData(bind=engine))
- engine.dialect.reflecttable(
- connection, many_rows, include_columns=['b'],
+ many_rows = Table('many_rows', MetaData())
+ reflect_table(engine, connection, many_rows, include_columns=['b'],
exclude_columns=[], resolve_fks=True)
+
self.assertEqual(len(many_rows.c), 1)
self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)}))
@@ -104,11 +120,15 @@ def test_reflect_partitions(self, engine, connection):
def test_unicode(self, engine, connection):
"""Verify that unicode strings make it through SQLAlchemy and the backend"""
unicode_str = "中文"
- one_row = Table('one_row', MetaData(bind=engine))
- returned_str = sqlalchemy.select(
- [expression.bindparam("好", unicode_str)],
- from_obj=one_row,
- ).scalar()
+ one_row = Table('one_row', MetaData())
+
+ if sqlalchemy_version >= 1.4:
+ returned_str = connection.execute(sqlalchemy.select(
+ expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar()
+ else:
+ returned_str = connection.execute(sqlalchemy.select([
+ expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar()
+
self.assertEqual(returned_str, unicode_str)
@with_engine_connection
@@ -133,13 +153,21 @@ def test_get_table_names(self, engine, connection):
@with_engine_connection
def test_has_table(self, engine, connection):
- self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
- self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists())
+ if sqlalchemy_version >= 1.4:
+ insp = sqlalchemy.inspect(engine)
+ self.assertTrue(insp.has_table("one_row"))
+ self.assertFalse(insp.has_table("this_table_does_not_exist"))
+ else:
+ self.assertTrue(Table('one_row', MetaData(bind=engine)).exists())
+ self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists())
@with_engine_connection
def test_char_length(self, engine, connection):
- one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
- result = sqlalchemy.select([
- sqlalchemy.func.char_length(one_row_complex.c.string)
- ]).execute().scalar()
+ one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
+
+ if sqlalchemy_version >= 1.4:
+ result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar()
+ else:
+ result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar()
+
self.assertEqual(result, len('a string'))
diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py
index c70ed962..b49fc190 100644
--- a/pyhive/tests/test_hive.py
+++ b/pyhive/tests/test_hive.py
@@ -17,7 +17,6 @@
from decimal import Decimal
import mock
-import sasl
import thrift.transport.TSocket
import thrift.transport.TTransport
import thrift_sasl
@@ -204,15 +203,7 @@ def test_custom_transport(self):
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
sasl_auth = 'PLAIN'
- def sasl_factory():
- sasl_client = sasl.Client()
- sasl_client.setAttr('host', 'localhost')
- sasl_client.setAttr('username', 'test_username')
- sasl_client.setAttr('password', 'x')
- sasl_client.init()
- return sasl_client
-
- transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket)
+ transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket)
conn = hive.connect(thrift_transport=transport)
with contextlib.closing(conn):
with contextlib.closing(conn.cursor()) as cursor:
diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py
index 7c74f057..187b1c21 100644
--- a/pyhive/tests/test_presto.py
+++ b/pyhive/tests/test_presto.py
@@ -9,6 +9,8 @@
import contextlib
import os
+from decimal import Decimal
+
import requests
from pyhive import exc
@@ -93,7 +95,7 @@ def test_complex(self, cursor):
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
# '{0:1}',
- '0.1',
+ Decimal('0.1'),
)]
self.assertEqual(rows, expected)
# catch unicode/str
diff --git a/pyhive/tests/test_sasl_compat.py b/pyhive/tests/test_sasl_compat.py
new file mode 100644
index 00000000..49516249
--- /dev/null
+++ b/pyhive/tests/test_sasl_compat.py
@@ -0,0 +1,333 @@
+'''
+http://www.opensource.org/licenses/mit-license.php
+
+Copyright 2007-2011 David Alan Cridland
+Copyright 2011 Lance Stout
+Copyright 2012 Tyler L Hobbs
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this
+software and associated documentation files (the "Software"), to deal in the Software
+without restriction, including without limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
+to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or
+substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
+FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+'''
+# This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit
+# and by refactoring them to cover wrapper functions in sasl_compat.py along with added coverage for functions exclusive to sasl_compat.py.
+
+import unittest
+import base64
+import hashlib
+import hmac
+import kerberos
+from mock import patch
+import six
+import struct
+from puresasl import SASLProtocolException, QOP
+from puresasl.client import SASLError
+from pyhive.sasl_compat import PureSASLClient, error_catcher
+
+
+class TestPureSASLClient(unittest.TestCase):
+ """Test cases for initialization of SASL client using PureSASLClient class"""
+
+ def setUp(self):
+ self.sasl_kwargs = {}
+ self.sasl = PureSASLClient('localhost', **self.sasl_kwargs)
+
+ def test_start_no_mechanism(self):
+ """Test starting SASL authentication with no mechanism."""
+ success, mechanism, response = self.sasl.start(mechanism=None)
+ self.assertFalse(success)
+ self.assertIsNone(mechanism)
+ self.assertIsNone(response)
+ self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+ def test_start_wrong_mechanism(self):
+ """Test starting SASL authentication with a single unsupported mechanism."""
+ success, mechanism, response = self.sasl.start(mechanism='WRONG')
+ self.assertFalse(success)
+ self.assertEqual(mechanism, 'WRONG')
+ self.assertIsNone(response)
+ self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+ def test_start_list_of_invalid_mechanisms(self):
+ """Test starting SASL authentication with a list of unsupported mechanisms."""
+ self.sasl.start(['invalid1', 'invalid2'])
+ self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+ def test_start_list_of_valid_mechanisms(self):
+ """Test starting SASL authentication with a list of supported mechanisms."""
+ self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5'])
+ # Validate right mechanism is chosen based on score.
+ self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5')
+
+ def test_error_catcher_no_error(self):
+ """Test the error_catcher with no error."""
+ with error_catcher(self.sasl):
+ result, _, _ = self.sasl.start(mechanism='ANONYMOUS')
+
+ self.assertEqual(self.sasl.getError(), None)
+ self.assertEqual(result, True)
+
+ def test_error_catcher_with_error(self):
+ """Test the error_catcher with an error."""
+ with error_catcher(self.sasl):
+ result, _, _ = self.sasl.start(mechanism='WRONG')
+
+ self.assertEqual(result, False)
+ self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties')
+
+"""Assuming Client initilization went well and a mechanism is chosen, Below are the test cases for different mechanims"""
+
+class _BaseMechanismTests(unittest.TestCase):
+ """Base test case for SASL mechanisms."""
+
+ mechanism = 'ANONYMOUS'
+ sasl_kwargs = {}
+
+ def setUp(self):
+ self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
+ self.mechanism_class = self.sasl._chosen_mech
+
+ def test_init_basic(self, *args):
+ sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
+ mech = sasl._chosen_mech
+ self.assertIs(mech.sasl, sasl)
+
+ def test_step_basic(self, *args):
+ success, response = self.sasl.step(six.b('string'))
+ self.assertTrue(success)
+ self.assertIsInstance(response, six.binary_type)
+
+ def test_decode_encode(self, *args):
+ self.assertEqual(self.sasl.encode('msg'), (False, None))
+ self.assertEqual(self.sasl.getError(), '')
+ self.assertEqual(self.sasl.decode('msg'), (False, None))
+ self.assertEqual(self.sasl.getError(), '')
+
+
+class AnonymousMechanismTest(_BaseMechanismTests):
+ """Test case for the Anonymous SASL mechanism."""
+
+ mechanism = 'ANONYMOUS'
+
+
+class PlainTextMechanismTest(_BaseMechanismTests):
+ """Test case for the PlainText SASL mechanism."""
+
+ mechanism = 'PLAIN'
+ username = 'user'
+ password = 'pass'
+ sasl_kwargs = {'username': username, 'password': password}
+
+ def test_step(self):
+ for challenge in (None, '', b'asdf', u"\U0001F44D"):
+ success, response = self.sasl.step(challenge)
+ self.assertTrue(success)
+ self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}'))
+ self.assertIsInstance(response, six.binary_type)
+
+ def test_step_with_authorization_id_or_identity(self):
+ challenge = u"\U0001F44D"
+ identity = 'user2'
+
+ # Test that we can pass an identity
+ sasl_kwargs = self.sasl_kwargs.copy()
+ sasl_kwargs.update({'identity': identity})
+ sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
+ success, response = sasl.step(challenge)
+ self.assertTrue(success)
+ self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}'))
+ self.assertIsInstance(response, six.binary_type)
+ self.assertTrue(sasl.complete)
+
+ # Test that the sasl authorization_id has priority over identity
+ auth_id = 'user3'
+ sasl_kwargs.update({'authorization_id': auth_id})
+ sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs)
+ success, response = sasl.step(challenge)
+ self.assertTrue(success)
+ self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}'))
+ self.assertIsInstance(response, six.binary_type)
+ self.assertTrue(sasl.complete)
+
+ def test_decode_encode(self):
+ msg = 'msg'
+ self.assertEqual(self.sasl.decode(msg), (True, msg))
+ self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+
+class ExternalMechanismTest(_BaseMechanismTests):
+ """Test case for the External SASL mechanisms"""
+
+ mechanism = 'EXTERNAL'
+
+ def test_step(self):
+ self.assertEqual(self.sasl.step(), (True, b''))
+
+ def test_decode_encode(self):
+ msg = 'msg'
+ self.assertEqual(self.sasl.decode(msg), (True, msg))
+ self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+
+@patch('puresasl.mechanisms.kerberos.authGSSClientStep')
+@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response')))
+class GSSAPIMechanismTest(_BaseMechanismTests):
+ """Test case for the GSSAPI SASL mechanism."""
+
+ mechanism = 'GSSAPI'
+ service = 'GSSAPI'
+ sasl_kwargs = {'service': service}
+
+ @patch('puresasl.mechanisms.kerberos.authGSSClientWrap')
+ @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
+ def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args):
+ # bypassing step setup by setting qop directly
+ self.mechanism_class.qop = QOP.AUTH
+ msg = b'msg'
+ self.assertEqual(self.sasl.decode(msg), (True, msg))
+ self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+ # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication
+ for qop in (QOP.AUTH_INT, QOP.AUTH_CONF):
+ self.mechanism_class.qop = qop
+ with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1):
+ self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
+ self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
+ if qop == QOP.AUTH_CONF:
+ with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0):
+ self.assertEqual(self.sasl.encode(msg), (False, None))
+ self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.')
+
+ def test_step_no_user(self, authGSSClientResponse, *args):
+ msg = six.b('whatever')
+
+ # no user
+ self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
+ with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''):
+ self.assertEqual(self.sasl.step(msg), (True, six.b('')))
+
+ username = 'username'
+ # with user; this has to be last because it sets mechanism.user
+ with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE):
+ with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)):
+ self.assertEqual(self.sasl.step(msg), (True, six.b('')))
+ self.assertEqual(self.mechanism_class.user, six.b(username))
+
+ @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
+ def test_step_qop(self, *args):
+ self.mechanism_class._have_negotiated_details = True
+ self.mechanism_class.user = 'user'
+ msg = six.b('msg')
+ self.assertEqual(self.sasl.step(msg), (False, None))
+ self.assertEqual(self.sasl.getError(), 'Bad response from server')
+
+ max_len = 100
+ self.assertLess(max_len, self.sasl.max_buffer)
+ for i, qop in QOP.bit_map.items():
+ qop_size = struct.pack('!i', i << 24 | max_len)
+ response = base64.b64encode(qop_size)
+ with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response):
+ with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap:
+ self.mechanism_class.complete = False
+ self.assertEqual(self.sasl.step(msg), (True, qop_size))
+ self.assertTrue(self.mechanism_class.complete)
+ self.assertEqual(self.mechanism_class.qop, qop)
+ self.assertEqual(self.mechanism_class.max_buffer, max_len)
+
+ args = authGSSClientWrap.call_args[0]
+ out_data = args[1]
+ out = base64.b64decode(out_data)
+ self.assertEqual(out[:4], qop_size)
+ self.assertEqual(out[4:], six.b(self.mechanism_class.user))
+
+
+class CramMD5MechanismTest(_BaseMechanismTests):
+ """Test case for the CRAM-MD5 SASL mechanism."""
+
+ mechanism = 'CRAM-MD5'
+ username = 'user'
+ password = 'pass'
+ sasl_kwargs = {'username': username, 'password': password}
+
+ def test_step(self):
+ success, response = self.sasl.step(None)
+ self.assertTrue(success)
+ self.assertIsNone(response)
+ challenge = six.b('msg')
+ hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5)
+ hash.update(challenge)
+ success, response = self.sasl.step(challenge)
+ self.assertTrue(success)
+ self.assertIn(six.b(self.username), response)
+ self.assertIn(six.b(hash.hexdigest()), response)
+ self.assertIsInstance(response, six.binary_type)
+ self.assertTrue(self.sasl.complete)
+
+ def test_decode_encode(self):
+ msg = 'msg'
+ self.assertEqual(self.sasl.decode(msg), (True, msg))
+ self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+
+class DigestMD5MechanismTest(_BaseMechanismTests):
+ """Test case for the DIGEST-MD5 SASL mechanism."""
+
+ mechanism = 'DIGEST-MD5'
+ username = 'user'
+ password = 'pass'
+ sasl_kwargs = {'username': username, 'password': password}
+
+ def test_decode_encode(self):
+ msg = 'msg'
+ self.assertEqual(self.sasl.decode(msg), (True, msg))
+ self.assertEqual(self.sasl.encode(msg), (True, msg))
+
+ def test_step_basic(self, *args):
+ pass
+
+ def test_step(self):
+ """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism."""
+ testChallenge = (
+ b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r'
+ b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc'
+ b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess'
+ )
+ result, response = self.sasl.step(testChallenge)
+ self.assertTrue(result)
+ self.assertIsNotNone(response)
+
+ def test_step_server_answer(self):
+ """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism."""
+ sasl_kwargs = {'username': "chris", 'password': "secret"}
+ sasl = PureSASLClient('elwood.innosoft.com',
+ service="imap",
+ mechanism=self.mechanism,
+ mutual_auth=True,
+ **sasl_kwargs)
+ testChallenge = (
+ b'utf-8,username="chris",realm="elwood.innosoft.com",'
+ b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
+ b'digest-uri="imap/elwood.innosoft.com",'
+ b'response=d388dad90d4bbd760a152321f2143af7,qop=auth'
+ )
+ sasl.step(testChallenge)
+ sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"
+
+ serverResponse = (
+ b'rspauth=ea40f60335c427b5527b84dbabcdfffd'
+ )
+ sasl.step(serverResponse)
+ # assert that step chooses the only supported QOP for DIGEST-MD5
+ self.assertEqual(self.sasl.qop, QOP.AUTH)
diff --git a/pyhive/tests/test_sqlalchemy_hive.py b/pyhive/tests/test_sqlalchemy_hive.py
index 1ff0e817..790bec4c 100644
--- a/pyhive/tests/test_sqlalchemy_hive.py
+++ b/pyhive/tests/test_sqlalchemy_hive.py
@@ -4,6 +4,7 @@
from pyhive.sqlalchemy_hive import HiveDate
from pyhive.sqlalchemy_hive import HiveDecimal
from pyhive.sqlalchemy_hive import HiveTimestamp
+from sqlalchemy.exc import NoSuchTableError, OperationalError
from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
from pyhive.tests.sqlalchemy_test_case import with_engine_connection
from sqlalchemy import types
@@ -11,11 +12,15 @@
from sqlalchemy.schema import Column
from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
+from sqlalchemy.sql import text
import contextlib
import datetime
import decimal
import sqlalchemy.types
import unittest
+import re
+
+sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1))
_ONE_ROW_COMPLEX_CONTENTS = [
True,
@@ -64,7 +69,11 @@ def test_dotted_column_names(self, engine, connection):
"""When Hive returns a dotted column name, both the non-dotted version should be available
as an attribute, and the dotted version should remain available as a key.
"""
- row = connection.execute('SELECT * FROM one_row').fetchone()
+ row = connection.execute(text('SELECT * FROM one_row')).fetchone()
+
+ if sqlalchemy_version >= 1.4:
+ row = row._mapping
+
assert row.keys() == ['number_of_rows']
assert 'number_of_rows' in row
assert row.number_of_rows == 1
@@ -76,20 +85,33 @@ def test_dotted_column_names(self, engine, connection):
def test_dotted_column_names_raw(self, engine, connection):
"""When Hive returns a dotted column name, and raw mode is on, nothing should be modified.
"""
- row = connection.execution_options(hive_raw_colnames=True) \
- .execute('SELECT * FROM one_row').fetchone()
+ row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone()
+
+ if sqlalchemy_version >= 1.4:
+ row = row._mapping
+
assert row.keys() == ['one_row.number_of_rows']
assert 'number_of_rows' not in row
assert getattr(row, 'one_row.number_of_rows') == 1
assert row['one_row.number_of_rows'] == 1
+ @with_engine_connection
+ def test_reflect_no_such_table(self, engine, connection):
+ """reflecttable should throw an exception on an invalid table"""
+ self.assertRaises(
+ NoSuchTableError,
+ lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine))
+ self.assertRaises(
+ OperationalError,
+ lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine))
+
@with_engine_connection
def test_reflect_select(self, engine, connection):
"""reflecttable should be able to fill in a table from the name"""
- one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
+ one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
self.assertEqual(len(one_row_complex.c), 15)
self.assertIsInstance(one_row_complex.c.string, Column)
- row = one_row_complex.select().execute().fetchone()
+ row = connection.execute(one_row_complex.select()).fetchone()
self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
# TODO some of these types could be filled in better
@@ -112,15 +134,15 @@ def test_reflect_select(self, engine, connection):
@with_engine_connection
def test_type_map(self, engine, connection):
"""sqlalchemy should use the dbapi_type_map to infer types from raw queries"""
- row = connection.execute('SELECT * FROM one_row_complex').fetchone()
+ row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone()
self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS)
@with_engine_connection
def test_reserved_words(self, engine, connection):
"""Hive uses backticks"""
# Use keywords for the table/column name
- fake_table = Table('select', MetaData(bind=engine), Column('map', sqlalchemy.types.String))
- query = str(fake_table.select(fake_table.c.map == 'a'))
+ fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String))
+ query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine))
self.assertIn('`select`', query)
self.assertIn('`map`', query)
self.assertNotIn('"select"', query)
@@ -132,12 +154,12 @@ def test_switch_database(self):
with contextlib.closing(engine.connect()) as connection:
self.assertIn(
('dummy_table',),
- connection.execute('SHOW TABLES').fetchall()
+ connection.execute(text('SHOW TABLES')).fetchall()
)
- connection.execute('USE default')
+ connection.execute(text('USE default'))
self.assertIn(
('one_row',),
- connection.execute('SHOW TABLES').fetchall()
+ connection.execute(text('SHOW TABLES')).fetchall()
)
finally:
engine.dispose()
@@ -160,13 +182,13 @@ def test_lots_of_types(self, engine, connection):
cols.append(Column('hive_date', HiveDate))
cols.append(Column('hive_decimal', HiveDecimal))
cols.append(Column('hive_timestamp', HiveTimestamp))
- table = Table('test_table', MetaData(bind=engine), *cols, schema='pyhive_test_database')
- table.drop(checkfirst=True)
- table.create()
- connection.execute('SET mapred.job.tracker=local')
- connection.execute('USE pyhive_test_database')
+ table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,)
+ table.drop(checkfirst=True, bind=connection)
+ table.create(bind=connection)
+ connection.execute(text('SET mapred.job.tracker=local'))
+ connection.execute(text('USE pyhive_test_database'))
big_number = 10 ** 10 - 1
- connection.execute("""
+ connection.execute(text("""
INSERT OVERWRITE TABLE test_table
SELECT
1, "a", "a", "a", "a", "a", 0.1,
@@ -175,41 +197,39 @@ def test_lots_of_types(self, engine, connection):
"a", 1, 1,
0.1, 0.1, 0, 0, 0, "a",
false, "a", "a",
- 0, %d, 123 + 2000
+ 0, :big_number, 123 + 2000
FROM default.one_row
- """, big_number)
- row = connection.execute(table.select()).fetchone()
- self.assertEqual(row.hive_date, datetime.date(1970, 1, 1))
+ """), {"big_number": big_number})
+ row = connection.execute(text("select * from test_table")).fetchone()
+ self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0))
self.assertEqual(row.hive_decimal, decimal.Decimal(big_number))
self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000))
- table.drop()
+ table.drop(bind=connection)
@with_engine_connection
def test_insert_select(self, engine, connection):
- one_row = Table('one_row', MetaData(bind=engine), autoload=True)
- table = Table('insert_test', MetaData(bind=engine),
- Column('a', sqlalchemy.types.Integer),
- schema='pyhive_test_database')
- table.drop(checkfirst=True)
- table.create()
- connection.execute('SET mapred.job.tracker=local')
+ one_row = Table('one_row', MetaData(), autoload_with=engine)
+ table = Table('insert_test', MetaData(schema='pyhive_test_database'),
+ Column('a', sqlalchemy.types.Integer))
+ table.drop(checkfirst=True, bind=connection)
+ table.create(bind=connection)
+ connection.execute(text('SET mapred.job.tracker=local'))
# NOTE(jing) I'm stuck on a version of Hive without INSERT ... VALUES
connection.execute(table.insert().from_select(['a'], one_row.select()))
-
- result = table.select().execute().fetchall()
+
+ result = connection.execute(table.select()).fetchall()
expected = [(1,)]
self.assertEqual(result, expected)
@with_engine_connection
def test_insert_values(self, engine, connection):
- table = Table('insert_test', MetaData(bind=engine),
- Column('a', sqlalchemy.types.Integer),
- schema='pyhive_test_database')
- table.drop(checkfirst=True)
- table.create()
- connection.execute(table.insert([{'a': 1}, {'a': 2}]))
-
- result = table.select().execute().fetchall()
+ table = Table('insert_test', MetaData(schema='pyhive_test_database'),
+ Column('a', sqlalchemy.types.Integer),)
+ table.drop(checkfirst=True, bind=connection)
+ table.create(bind=connection)
+ connection.execute(table.insert().values([{'a': 1}, {'a': 2}]))
+
+ result = connection.execute(table.select()).fetchall()
expected = [(1,), (2,)]
self.assertEqual(result, expected)
diff --git a/pyhive/tests/test_sqlalchemy_presto.py b/pyhive/tests/test_sqlalchemy_presto.py
index a01e4a35..58a5c034 100644
--- a/pyhive/tests/test_sqlalchemy_presto.py
+++ b/pyhive/tests/test_sqlalchemy_presto.py
@@ -8,7 +8,9 @@
from sqlalchemy.schema import Column
from sqlalchemy.schema import MetaData
from sqlalchemy.schema import Table
+from sqlalchemy.sql import text
from sqlalchemy.types import String
+from decimal import Decimal
import contextlib
import unittest
@@ -27,11 +29,11 @@ def test_bad_format(self):
@with_engine_connection
def test_reflect_select(self, engine, connection):
"""reflecttable should be able to fill in a table from the name"""
- one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True)
+ one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
# Presto ignores the union column
self.assertEqual(len(one_row_complex.c), 15 - 1)
self.assertIsInstance(one_row_complex.c.string, Column)
- rows = one_row_complex.select().execute().fetchall()
+ rows = connection.execute(one_row_complex.select()).fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(list(rows[0]), [
True,
@@ -48,7 +50,7 @@ def test_reflect_select(self, engine, connection):
{"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
# '{0:1}',
- '0.1',
+ Decimal('0.1'),
])
# TODO some of these types could be filled in better
@@ -71,7 +73,7 @@ def test_url_default(self):
engine = create_engine('presto://localhost:8080/hive')
try:
with contextlib.closing(engine.connect()) as connection:
- self.assertEqual(connection.execute('SELECT 1 AS foobar FROM one_row').scalar(), 1)
+ self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1)
finally:
engine.dispose()
@@ -79,8 +81,8 @@ def test_url_default(self):
def test_reserved_words(self, engine, connection):
"""Presto uses double quotes, not backticks"""
# Use keywords for the table/column name
- fake_table = Table('select', MetaData(bind=engine), Column('current_timestamp', String))
- query = str(fake_table.select(fake_table.c.current_timestamp == 'a'))
+ fake_table = Table('select', MetaData(), Column('current_timestamp', String))
+ query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine))
self.assertIn('"select"', query)
self.assertIn('"current_timestamp"', query)
self.assertNotIn('`select`', query)
diff --git a/pyhive/tests/test_sqlalchemy_trino.py b/pyhive/tests/test_sqlalchemy_trino.py
new file mode 100644
index 00000000..c929f941
--- /dev/null
+++ b/pyhive/tests/test_sqlalchemy_trino.py
@@ -0,0 +1,93 @@
+from sqlalchemy.engine import create_engine
+from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase
+from pyhive.tests.sqlalchemy_test_case import with_engine_connection
+from sqlalchemy.exc import NoSuchTableError, DatabaseError
+from sqlalchemy.schema import MetaData, Table, Column
+from sqlalchemy.types import String
+from sqlalchemy.sql import text
+from sqlalchemy import types
+from decimal import Decimal
+
+import unittest
+import contextlib
+
+
+class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase):
+ def create_engine(self):
+ return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id()))
+
+ def test_bad_format(self):
+ self.assertRaises(
+ ValueError,
+ lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'),
+ )
+
+ @with_engine_connection
+ def test_reflect_select(self, engine, connection):
+ """reflecttable should be able to fill in a table from the name"""
+ one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine)
+ # Trino, like Presto, ignores the union column
+ self.assertEqual(len(one_row_complex.c), 15 - 1)
+ self.assertIsInstance(one_row_complex.c.string, Column)
+ rows = connection.execute(one_row_complex.select()).fetchall()
+ self.assertEqual(len(rows), 1)
+ self.assertEqual(list(rows[0]), [
+ True,
+ 127,
+ 32767,
+ 2147483647,
+ 9223372036854775807,
+ 0.5,
+ 0.25,
+ 'a string',
+ '1970-01-01 00:00:00.000',
+ b'123',
+ [1, 2],
+ {"1": 2, "3": 4},
+ [1, 2],
+ Decimal('0.1'),
+ ])
+
+ self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean)
+ self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer)
+ self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer)
+ self.assertIsInstance(one_row_complex.c.int.type, types.Integer)
+ self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger)
+ self.assertIsInstance(one_row_complex.c.float.type, types.Float)
+ self.assertIsInstance(one_row_complex.c.double.type, types.Float)
+ self.assertIsInstance(one_row_complex.c.string.type, String)
+ self.assertIsInstance(one_row_complex.c.timestamp.type, types.NullType)
+ self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY)
+ self.assertIsInstance(one_row_complex.c.array.type, types.NullType)
+ self.assertIsInstance(one_row_complex.c.map.type, types.NullType)
+ self.assertIsInstance(one_row_complex.c.struct.type, types.NullType)
+ self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType)
+
+ @with_engine_connection
+ def test_reflect_no_such_table(self, engine, connection):
+ """reflecttable should throw an exception on an invalid table"""
+ self.assertRaises(
+ NoSuchTableError,
+ lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine))
+ self.assertRaises(
+ DatabaseError,
+ lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine))
+
+ def test_url_default(self):
+ engine = create_engine('trino+pyhive://localhost:18080/hive')
+ try:
+ with contextlib.closing(engine.connect()) as connection:
+ self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1)
+ finally:
+ engine.dispose()
+
+ @with_engine_connection
+ def test_reserved_words(self, engine, connection):
+ """Trino uses double quotes, not backticks"""
+ # Use keywords for the table/column name
+ fake_table = Table('select', MetaData(), Column('current_timestamp', String))
+ query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine))
+ self.assertIn('"select"', query)
+ self.assertIn('"current_timestamp"', query)
+ self.assertNotIn('`select`', query)
+ self.assertNotIn('`current_timestamp`', query)
diff --git a/pyhive/tests/test_trino.py b/pyhive/tests/test_trino.py
index cdc8bb43..41bb489b 100644
--- a/pyhive/tests/test_trino.py
+++ b/pyhive/tests/test_trino.py
@@ -9,6 +9,8 @@
import contextlib
import os
+from decimal import Decimal
+
import requests
from pyhive import exc
@@ -89,7 +91,7 @@ def test_complex(self, cursor):
{"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON
[1, 2], # struct is returned as a list of elements
# '{0:1}',
- '0.1',
+ Decimal('0.1'),
)]
self.assertEqual(rows, expected)
# catch unicode/str
diff --git a/pyhive/trino.py b/pyhive/trino.py
index e8a1aabd..658457a3 100644
--- a/pyhive/trino.py
+++ b/pyhive/trino.py
@@ -124,7 +124,7 @@ def _process_response(self, response):
if 'data' in response_json:
assert self._columns
new_data = response_json['data']
- self._decode_binary(new_data)
+ self._process_data(new_data)
self._data += map(tuple, new_data)
if 'nextUri' not in response_json:
self._state = self._STATE_FINISHED
diff --git a/setup.py b/setup.py
index ad34a38b..d141ea1b 100755
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@ def run_tests(self):
'presto': ['requests>=1.0.0'],
'trino': ['requests>=1.0.0'],
'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
+ 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'],
'sqlalchemy': ['sqlalchemy>=1.3.0'],
'kerberos': ['requests_kerberos>=0.12.0'],
},
@@ -56,6 +57,8 @@ def run_tests(self):
'requests>=1.0.0',
'requests_kerberos>=0.12.0',
'sasl>=0.2.1',
+ 'pure-sasl>=0.6.2',
+ 'kerberos>=1.3.0',
'sqlalchemy>=1.3.0',
'thrift>=0.10.0',
],
@@ -69,7 +72,7 @@ def run_tests(self):
"hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect",
"hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect",
'presto = pyhive.sqlalchemy_presto:PrestoDialect',
- 'trino = pyhive.sqlalchemy_trino:TrinoDialect',
+ 'trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect',
],
}
)