From 8df7254c4016cbcb8a630166fdab9073955b0e48 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Mon, 14 Feb 2022 09:53:54 -0800 Subject: [PATCH 01/12] chore: rename Trino entry point (#428) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ad34a38b..be593fc0 100755 --- a/setup.py +++ b/setup.py @@ -69,7 +69,7 @@ def run_tests(self): "hive.http = pyhive.sqlalchemy_hive:HiveHTTPDialect", "hive.https = pyhive.sqlalchemy_hive:HiveHTTPSDialect", 'presto = pyhive.sqlalchemy_presto:PrestoDialect', - 'trino = pyhive.sqlalchemy_trino:TrinoDialect', + 'trino.pyhive = pyhive.sqlalchemy_trino:TrinoDialect', ], } ) From 3547bd6cccf963a033928b73c5ed498684335c39 Mon Sep 17 00:00:00 2001 From: serenajiang Date: Mon, 7 Mar 2022 13:43:09 -0800 Subject: [PATCH 02/12] Support for Presto decimals (#430) * Support for Presto decimals * lower --- pyhive/presto.py | 18 ++++++++++++------ pyhive/tests/test_presto.py | 4 +++- pyhive/tests/test_trino.py | 4 +++- pyhive/trino.py | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/pyhive/presto.py b/pyhive/presto.py index a38cd891..3217f4c2 100644 --- a/pyhive/presto.py +++ b/pyhive/presto.py @@ -9,6 +9,8 @@ from __future__ import unicode_literals from builtins import object +from decimal import Decimal + from pyhive import common from pyhive.common import DBAPITypeObject # Make all exceptions visible in this module per DB-API @@ -34,6 +36,11 @@ _logger = logging.getLogger(__name__) +TYPES_CONVERTER = { + "decimal": Decimal, + # As of Presto 0.69, binary data is returned as the varbinary type in base64 format + "varbinary": base64.b64decode +} class PrestoParamEscaper(common.ParamEscaper): def escape_datetime(self, item, format): @@ -307,14 +314,13 @@ def _fetch_more(self): """Fetch the next URI and update state""" self._process_response(self._requests_session.get(self._nextUri, **self._requests_kwargs)) - def _decode_binary(self, rows): - # As of Presto 0.69, binary data is returned as the varbinary type in base64 format - # This function decodes base64 data in place + def _process_data(self, rows): for i, col in enumerate(self.description): - if col[1] == 'varbinary': + col_type = col[1].split("(")[0].lower() + if col_type in TYPES_CONVERTER: for row in rows: if row[i] is not None: - row[i] = base64.b64decode(row[i]) + row[i] = TYPES_CONVERTER[col_type](row[i]) def _process_response(self, response): """Given the JSON response from Presto's REST API, update the internal state with the next @@ -341,7 +347,7 @@ def _process_response(self, response): if 'data' in response_json: assert self._columns new_data = response_json['data'] - self._decode_binary(new_data) + self._process_data(new_data) self._data += map(tuple, new_data) if 'nextUri' not in response_json: self._state = self._STATE_FINISHED diff --git a/pyhive/tests/test_presto.py b/pyhive/tests/test_presto.py index 7c74f057..187b1c21 100644 --- a/pyhive/tests/test_presto.py +++ b/pyhive/tests/test_presto.py @@ -9,6 +9,8 @@ import contextlib import os +from decimal import Decimal + import requests from pyhive import exc @@ -93,7 +95,7 @@ def test_complex(self, cursor): {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str diff --git a/pyhive/tests/test_trino.py b/pyhive/tests/test_trino.py index cdc8bb43..41bb489b 100644 --- a/pyhive/tests/test_trino.py +++ 
b/pyhive/tests/test_trino.py @@ -9,6 +9,8 @@ import contextlib import os +from decimal import Decimal + import requests from pyhive import exc @@ -89,7 +91,7 @@ def test_complex(self, cursor): {"1": 2, "3": 4}, # Trino converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str diff --git a/pyhive/trino.py b/pyhive/trino.py index e8a1aabd..658457a3 100644 --- a/pyhive/trino.py +++ b/pyhive/trino.py @@ -124,7 +124,7 @@ def _process_response(self, response): if 'data' in response_json: assert self._columns new_data = response_json['data'] - self._decode_binary(new_data) + self._process_data(new_data) self._data += map(tuple, new_data) if 'nextUri' not in response_json: self._state = self._STATE_FINISHED From 1f99552303626cce9eb6867fb7401fc810637fd6 Mon Sep 17 00:00:00 2001 From: Usiel Riedl Date: Tue, 9 May 2023 10:05:04 +0200 Subject: [PATCH 03/12] Use str type for driver and name in HiveDialect (#450) PyHive's HiveDialect usage of bytes for the name and driver fields is not the norm and is causing issues upstream: https://github.com/apache/superset/issues/22316 Even other dialects within PyHive use strings. SQLAlchemy does not strictly require a string, but all the stock dialects return a string, so I figure it is heavily implied. I think the risk of breaking something upstream with this change is low (but it is there ofc). I figure in most cases we just make someone's `str(dialect.driver)` expression redundant. Examples for some of the other stock sqlalchemy dialects (name and driver fields using str): https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/sqlite/pysqlite.py#L501 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/sqlite/base.py#L1891 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/mysql/base.py#L2383 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/mysql/mysqldb.py#L113 https://github.com/zzzeek/sqlalchemy/blob/main/lib/sqlalchemy/dialects/mysql/pymysql.py#L59 --- pyhive/sqlalchemy_hive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index 2ef49652..f39f1793 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -228,8 +228,8 @@ def _translate_colname(self, colname): class HiveDialect(default.DefaultDialect): - name = b'hive' - driver = b'thrift' + name = 'hive' + driver = 'thrift' execution_ctx_cls = HiveExecutionContext preparer = HiveIdentifierPreparer statement_compiler = HiveCompiler From 1c1da8b17bdf0e7e881e15bb731119558bd5440f Mon Sep 17 00:00:00 2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Wed, 17 May 2023 20:49:50 +0530 Subject: [PATCH 04/12] Correcting Iterable import for python 3.10 (#451) --- pyhive/common.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyhive/common.py b/pyhive/common.py index 298633a1..51692b97 100644 --- a/pyhive/common.py +++ b/pyhive/common.py @@ -18,6 +18,11 @@ from future.utils import with_metaclass from itertools import islice +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable + class DBAPICursor(with_metaclass(abc.ABCMeta, object)): """Base class for some common DB-API logic""" @@ -245,7 +250,7 @@ def escape_item(self, item): return self.escape_number(item) elif isinstance(item, basestring): return
self.escape_string(item) - elif isinstance(item, collections.Iterable): + elif isinstance(item, Iterable): return self.escape_sequence(item) elif isinstance(item, datetime.datetime): return self.escape_datetime(item, self._DATETIME_FORMAT) From b0206d3cb8a9f9a95a36feeae311f6b0141c6675 Mon Sep 17 00:00:00 2001 From: nicholas-miles Date: Wed, 17 May 2023 08:21:07 -0700 Subject: [PATCH 05/12] changing drivers to support hive, presto and trino with sqlalchemy>=2.0 (#448) --- pyhive/sqlalchemy_hive.py | 14 +++++++++++--- pyhive/sqlalchemy_presto.py | 8 ++++++-- pyhive/sqlalchemy_trino.py | 8 ++++++-- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index f39f1793..34fdb648 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,11 +13,19 @@ import re from sqlalchemy import exc -from sqlalchemy import processors +try: + from sqlalchemy import processors +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.engine import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases.mysql import MSTinyInteger +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.dialects.mysql import MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -121,7 +129,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': MSTinyInteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index a199ebe1..94d06412 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -13,7 +13,11 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases.mysql import MSTinyInteger +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.dialects.mysql import MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -29,7 +33,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 4b2b3698..686a42c7 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,7 +13,11 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases.mysql import MSTinyInteger +except ImportError: + # Newer versions of sqlalchemy require: + from sqlalchemy.dialects.mysql import MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -28,7 +32,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, From df03bef66500541fa921ec3614ec06a15ca17615 Mon Sep 17 00:00:00 2001 
From: Bogdan Date: Wed, 17 May 2023 09:32:32 -0700 Subject: [PATCH 06/12] Revert "changing drivers to support hive, presto and trino with sqlalchemy>=2.0 (#448)" (#452) This reverts commit b0206d3cb8a9f9a95a36feeae311f6b0141c6675. --- pyhive/sqlalchemy_hive.py | 14 +++----------- pyhive/sqlalchemy_presto.py | 8 ++------ pyhive/sqlalchemy_trino.py | 8 ++------ 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index 34fdb648..f39f1793 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,19 +13,11 @@ import re from sqlalchemy import exc -try: - from sqlalchemy import processors -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.engine import processors +from sqlalchemy import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -try: - from sqlalchemy.databases.mysql import MSTinyInteger -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.dialects.mysql import MSTinyInteger +from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -129,7 +121,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': MSTinyInteger, + 'tinyint': mysql.MSTinyInteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index 94d06412..a199ebe1 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -13,11 +13,7 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -try: - from sqlalchemy.databases.mysql import MSTinyInteger -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.dialects.mysql import MSTinyInteger +from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -33,7 +29,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': MSTinyInteger, + 'tinyint': mysql.MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 686a42c7..4b2b3698 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,11 +13,7 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -try: - from sqlalchemy.databases.mysql import MSTinyInteger -except ImportError: - # Newer versions of sqlalchemy require: - from sqlalchemy.dialects.mysql import MSTinyInteger +from sqlalchemy.databases import mysql from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -32,7 +28,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': MSTinyInteger, + 'tinyint': mysql.MSTinyInteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, From 0bd6f5b5f76f759cd01b83287cec15da9789753e Mon Sep 17 00:00:00 2001 From: Bogdan Date: Wed, 17 May 2023 10:02:09 -0700 Subject: [PATCH 07/12] Update __init__.py (#453) https://github.com/dropbox/PyHive/commit/1c1da8b17bdf0e7e881e15bb731119558bd5440f 
https://github.com/dropbox/PyHive/commit/1f99552303626cce9eb6867fb7401fc810637fd6 --- pyhive/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyhive/__init__.py b/pyhive/__init__.py index 8ede6abb..0a6bb1f6 100644 --- a/pyhive/__init__.py +++ b/pyhive/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import from __future__ import unicode_literals -__version__ = '0.6.3' +__version__ = '0.7.0' From 4367cc550252f9e6f85782bee7f8694325a742a6 Mon Sep 17 00:00:00 2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Tue, 20 Jun 2023 16:37:18 +0530 Subject: [PATCH 08/12] use pure-sasl with python 3.11 (#454) --- dev_requirements.txt | 2 + pyhive/hive.py | 56 ++++-- pyhive/sasl_compat.py | 56 ++++++ pyhive/tests/test_hive.py | 11 +- pyhive/tests/test_sasl_compat.py | 333 +++++++++++++++++++++++++++++++ setup.py | 3 + 6 files changed, 436 insertions(+), 25 deletions(-) create mode 100644 pyhive/sasl_compat.py create mode 100644 pyhive/tests/test_sasl_compat.py diff --git a/dev_requirements.txt b/dev_requirements.txt index 0bf6d8a7..40bb605a 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -12,6 +12,8 @@ pytest-timeout==1.2.0 requests>=1.0.0 requests_kerberos>=0.12.0 sasl>=0.2.1 +pure-sasl>=0.6.2 +kerberos>=1.3.0 thrift>=0.10.0 #thrift_sasl>=0.1.0 git+https://github.com/cloudera/thrift_sasl # Using master branch in order to get Python 3 SASL patches diff --git a/pyhive/hive.py b/pyhive/hive.py index 3f71df33..c1287488 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -49,6 +49,45 @@ } +def get_sasl_client(host, sasl_auth, service=None, username=None, password=None): + import sasl + sasl_client = sasl.Client() + sasl_client.setAttr('host', host) + + if sasl_auth == 'GSSAPI': + sasl_client.setAttr('service', service) + elif sasl_auth == 'PLAIN': + sasl_client.setAttr('username', username) + sasl_client.setAttr('password', password) + else: + raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + + sasl_client.init() + return sasl_client + + +def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None): + from pyhive.sasl_compat import PureSASLClient + + if sasl_auth == 'GSSAPI': + sasl_kwargs = {'service': service} + elif sasl_auth == 'PLAIN': + sasl_kwargs = {'username': username, 'password': password} + else: + raise ValueError("sasl_auth only supports GSSAPI and PLAIN") + + return PureSASLClient(host=host, **sasl_kwargs) + + +def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None): + try: + return get_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + # The sasl library is available + except ImportError: + # Fallback to pure-sasl library + return get_pure_sasl_client(host=host, sasl_auth=sasl_auth, service=service, username=username, password=password) + + def _parse_timestamp(value): if value: match = _TIMESTAMP_PATTERN.match(value) @@ -200,7 +239,6 @@ def __init__( self._transport = thrift.transport.TTransport.TBufferedTransport(socket) elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'): # Defer import so package dependency is optional - import sasl import thrift_sasl if auth == 'KERBEROS': @@ -211,20 +249,8 @@ def __init__( if password is None: # Password doesn't matter in NONE mode, just needs to be nonempty. 
password = 'x' - - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', host) - if sasl_auth == 'GSSAPI': - sasl_client.setAttr('service', kerberos_service_name) - elif sasl_auth == 'PLAIN': - sasl_client.setAttr('username', username) - sasl_client.setAttr('password', password) - else: - raise AssertionError - sasl_client.init() - return sasl_client - self._transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + + self._transport = thrift_sasl.TSaslClientTransport(lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth, service=kerberos_service_name, username=username, password=password), sasl_auth, socket) else: # All HS2 config options: # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration diff --git a/pyhive/sasl_compat.py b/pyhive/sasl_compat.py new file mode 100644 index 00000000..dc65abe9 --- /dev/null +++ b/pyhive/sasl_compat.py @@ -0,0 +1,56 @@ +# Original source of this file is https://github.com/cloudera/impyla/blob/master/impala/sasl_compat.py +# which uses Apache-2.0 license as of 21 May 2023. +# This code was added to Impyla in 2016 as a compatibility layer to allow use of either python-sasl or pure-sasl +# via PR https://github.com/cloudera/impyla/pull/179 +# Even though thrift_sasl lists pure-sasl as dependency here https://github.com/cloudera/thrift_sasl/blob/master/setup.py#L34 +# but it still calls functions native to python-sasl in this file https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py#L82 +# Hence this code is required for the fallback to work. + + +from puresasl.client import SASLClient, SASLError +from contextlib import contextmanager + +@contextmanager +def error_catcher(self, Exc = Exception): + try: + self.error = None + yield + except Exc as e: + self.error = str(e) + + +class PureSASLClient(SASLClient): + def __init__(self, *args, **kwargs): + self.error = None + super(PureSASLClient, self).__init__(*args, **kwargs) + + def start(self, mechanism): + with error_catcher(self, SASLError): + if isinstance(mechanism, list): + self.choose_mechanism(mechanism) + else: + self.choose_mechanism([mechanism]) + return True, self.mechanism, self.process() + # else + return False, mechanism, None + + def encode(self, incoming): + with error_catcher(self): + return True, self.unwrap(incoming) + # else + return False, None + + def decode(self, outgoing): + with error_catcher(self): + return True, self.wrap(outgoing) + # else + return False, None + + def step(self, challenge=None): + with error_catcher(self): + return True, self.process(challenge) + # else + return False, None + + def getError(self): + return self.error diff --git a/pyhive/tests/test_hive.py b/pyhive/tests/test_hive.py index c70ed962..b49fc190 100644 --- a/pyhive/tests/test_hive.py +++ b/pyhive/tests/test_hive.py @@ -17,7 +17,6 @@ from decimal import Decimal import mock -import sasl import thrift.transport.TSocket import thrift.transport.TTransport import thrift_sasl @@ -204,15 +203,7 @@ def test_custom_transport(self): socket = thrift.transport.TSocket.TSocket('localhost', 10000) sasl_auth = 'PLAIN' - def sasl_factory(): - sasl_client = sasl.Client() - sasl_client.setAttr('host', 'localhost') - sasl_client.setAttr('username', 'test_username') - sasl_client.setAttr('password', 'x') - sasl_client.init() - return sasl_client - - transport = thrift_sasl.TSaslClientTransport(sasl_factory, sasl_auth, socket) + transport = thrift_sasl.TSaslClientTransport(lambda: 
hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket) conn = hive.connect(thrift_transport=transport) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: diff --git a/pyhive/tests/test_sasl_compat.py b/pyhive/tests/test_sasl_compat.py new file mode 100644 index 00000000..49516249 --- /dev/null +++ b/pyhive/tests/test_sasl_compat.py @@ -0,0 +1,333 @@ +''' +http://www.opensource.org/licenses/mit-license.php + +Copyright 2007-2011 David Alan Cridland +Copyright 2011 Lance Stout +Copyright 2012 Tyler L Hobbs + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons +to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or +substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +''' +# This file was generated by referring test cases from the pure-sasl repo i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit +# and by refactoring them to cover wrapper functions in sasl_compat.py along with added coverage for functions exclusive to sasl_compat.py. 
+ +import unittest +import base64 +import hashlib +import hmac +import kerberos +from mock import patch +import six +import struct +from puresasl import SASLProtocolException, QOP +from puresasl.client import SASLError +from pyhive.sasl_compat import PureSASLClient, error_catcher + + +class TestPureSASLClient(unittest.TestCase): + """Test cases for initialization of SASL client using PureSASLClient class""" + + def setUp(self): + self.sasl_kwargs = {} + self.sasl = PureSASLClient('localhost', **self.sasl_kwargs) + + def test_start_no_mechanism(self): + """Test starting SASL authentication with no mechanism.""" + success, mechanism, response = self.sasl.start(mechanism=None) + self.assertFalse(success) + self.assertIsNone(mechanism) + self.assertIsNone(response) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_wrong_mechanism(self): + """Test starting SASL authentication with a single unsupported mechanism.""" + success, mechanism, response = self.sasl.start(mechanism='WRONG') + self.assertFalse(success) + self.assertEqual(mechanism, 'WRONG') + self.assertIsNone(response) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_list_of_invalid_mechanisms(self): + """Test starting SASL authentication with a list of unsupported mechanisms.""" + self.sasl.start(['invalid1', 'invalid2']) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + + def test_start_list_of_valid_mechanisms(self): + """Test starting SASL authentication with a list of supported mechanisms.""" + self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5']) + # Validate right mechanism is chosen based on score. 
+ self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5') + + def test_error_catcher_no_error(self): + """Test the error_catcher with no error.""" + with error_catcher(self.sasl): + result, _, _ = self.sasl.start(mechanism='ANONYMOUS') + + self.assertEqual(self.sasl.getError(), None) + self.assertEqual(result, True) + + def test_error_catcher_with_error(self): + """Test the error_catcher with an error.""" + with error_catcher(self.sasl): + result, _, _ = self.sasl.start(mechanism='WRONG') + + self.assertEqual(result, False) + self.assertEqual(self.sasl.getError(), 'None of the mechanisms listed meet all required properties') + +"""Assuming Client initilization went well and a mechanism is chosen, Below are the test cases for different mechanims""" + +class _BaseMechanismTests(unittest.TestCase): + """Base test case for SASL mechanisms.""" + + mechanism = 'ANONYMOUS' + sasl_kwargs = {} + + def setUp(self): + self.sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + self.mechanism_class = self.sasl._chosen_mech + + def test_init_basic(self, *args): + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs) + mech = sasl._chosen_mech + self.assertIs(mech.sasl, sasl) + + def test_step_basic(self, *args): + success, response = self.sasl.step(six.b('string')) + self.assertTrue(success) + self.assertIsInstance(response, six.binary_type) + + def test_decode_encode(self, *args): + self.assertEqual(self.sasl.encode('msg'), (False, None)) + self.assertEqual(self.sasl.getError(), '') + self.assertEqual(self.sasl.decode('msg'), (False, None)) + self.assertEqual(self.sasl.getError(), '') + + +class AnonymousMechanismTest(_BaseMechanismTests): + """Test case for the Anonymous SASL mechanism.""" + + mechanism = 'ANONYMOUS' + + +class PlainTextMechanismTest(_BaseMechanismTests): + """Test case for the PlainText SASL mechanism.""" + + mechanism = 'PLAIN' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_step(self): + for challenge in (None, '', b'asdf', u"\U0001F44D"): + success, response = self.sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + + def test_step_with_authorization_id_or_identity(self): + challenge = u"\U0001F44D" + identity = 'user2' + + # Test that we can pass an identity + sasl_kwargs = self.sasl_kwargs.copy() + sasl_kwargs.update({'identity': identity}) + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + success, response = sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'{identity}\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(sasl.complete) + + # Test that the sasl authorization_id has priority over identity + auth_id = 'user3' + sasl_kwargs.update({'authorization_id': auth_id}) + sasl = PureSASLClient('localhost', mechanism=self.mechanism, **sasl_kwargs) + success, response = sasl.step(challenge) + self.assertTrue(success) + self.assertEqual(response, six.b(f'{auth_id}\x00{self.username}\x00{self.password}')) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(sasl.complete) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +class ExternalMechanismTest(_BaseMechanismTests): + """Test 
case for the External SASL mechanisms""" + + mechanism = 'EXTERNAL' + + def test_step(self): + self.assertEqual(self.sasl.step(), (True, b'')) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +@patch('puresasl.mechanisms.kerberos.authGSSClientStep') +@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response'))) +class GSSAPIMechanismTest(_BaseMechanismTests): + """Test case for the GSSAPI SASL mechanism.""" + + mechanism = 'GSSAPI' + service = 'GSSAPI' + sasl_kwargs = {'service': service} + + @patch('puresasl.mechanisms.kerberos.authGSSClientWrap') + @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args): + # bypassing step setup by setting qop directly + self.mechanism_class.qop = QOP.AUTH + msg = b'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + # Test for behavior with different QOP like data integrity and confidentiality for Kerberos authentication + for qop in (QOP.AUTH_INT, QOP.AUTH_CONF): + self.mechanism_class.qop = qop + with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1): + self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + if qop == QOP.AUTH_CONF: + with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0): + self.assertEqual(self.sasl.encode(msg), (False, None)) + self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.') + + def test_step_no_user(self, authGSSClientResponse, *args): + msg = six.b('whatever') + + # no user + self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value))) + with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''): + self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + + username = 'username' + # with user; this has to be last because it sets mechanism.user + with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE): + with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)): + self.assertEqual(self.sasl.step(msg), (True, six.b(''))) + self.assertEqual(self.mechanism_class.user, six.b(username)) + + @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap') + def test_step_qop(self, *args): + self.mechanism_class._have_negotiated_details = True + self.mechanism_class.user = 'user' + msg = six.b('msg') + self.assertEqual(self.sasl.step(msg), (False, None)) + self.assertEqual(self.sasl.getError(), 'Bad response from server') + + max_len = 100 + self.assertLess(max_len, self.sasl.max_buffer) + for i, qop in QOP.bit_map.items(): + qop_size = struct.pack('!i', i << 24 | max_len) + response = base64.b64encode(qop_size) + with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response): + with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap: + self.mechanism_class.complete = False + self.assertEqual(self.sasl.step(msg), (True, qop_size)) + self.assertTrue(self.mechanism_class.complete) + self.assertEqual(self.mechanism_class.qop, qop) + 
self.assertEqual(self.mechanism_class.max_buffer, max_len) + + args = authGSSClientWrap.call_args[0] + out_data = args[1] + out = base64.b64decode(out_data) + self.assertEqual(out[:4], qop_size) + self.assertEqual(out[4:], six.b(self.mechanism_class.user)) + + +class CramMD5MechanismTest(_BaseMechanismTests): + """Test case for the CRAM-MD5 SASL mechanism.""" + + mechanism = 'CRAM-MD5' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_step(self): + success, response = self.sasl.step(None) + self.assertTrue(success) + self.assertIsNone(response) + challenge = six.b('msg') + hash = hmac.HMAC(key=six.b(self.password), digestmod=hashlib.md5) + hash.update(challenge) + success, response = self.sasl.step(challenge) + self.assertTrue(success) + self.assertIn(six.b(self.username), response) + self.assertIn(six.b(hash.hexdigest()), response) + self.assertIsInstance(response, six.binary_type) + self.assertTrue(self.sasl.complete) + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + +class DigestMD5MechanismTest(_BaseMechanismTests): + """Test case for the DIGEST-MD5 SASL mechanism.""" + + mechanism = 'DIGEST-MD5' + username = 'user' + password = 'pass' + sasl_kwargs = {'username': username, 'password': password} + + def test_decode_encode(self): + msg = 'msg' + self.assertEqual(self.sasl.decode(msg), (True, msg)) + self.assertEqual(self.sasl.encode(msg), (True, msg)) + + def test_step_basic(self, *args): + pass + + def test_step(self): + """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism.""" + testChallenge = ( + b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r' + b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc' + b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess' + ) + result, response = self.sasl.step(testChallenge) + self.assertTrue(result) + self.assertIsNotNone(response) + + def test_step_server_answer(self): + """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism.""" + sasl_kwargs = {'username': "chris", 'password': "secret"} + sasl = PureSASLClient('elwood.innosoft.com', + service="imap", + mechanism=self.mechanism, + mutual_auth=True, + **sasl_kwargs) + testChallenge = ( + b'utf-8,username="chris",realm="elwood.innosoft.com",' + b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",' + b'digest-uri="imap/elwood.innosoft.com",' + b'response=d388dad90d4bbd760a152321f2143af7,qop=auth' + ) + sasl.step(testChallenge) + sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk" + + serverResponse = ( + b'rspauth=ea40f60335c427b5527b84dbabcdfffd' + ) + sasl.step(serverResponse) + # assert that step choses the only supported QOP for for DIGEST-MD5 + self.assertEqual(self.sasl.qop, QOP.AUTH) diff --git a/setup.py b/setup.py index be593fc0..d141ea1b 100755 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ def run_tests(self): 'presto': ['requests>=1.0.0'], 'trino': ['requests>=1.0.0'], 'hive': ['sasl>=0.2.1', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], + 'hive_pure_sasl': ['pure-sasl>=0.6.2', 'thrift>=0.10.0', 'thrift_sasl>=0.1.0'], 'sqlalchemy': ['sqlalchemy>=1.3.0'], 'kerberos': ['requests_kerberos>=0.12.0'], }, @@ -56,6 +57,8 @@ def run_tests(self): 'requests>=1.0.0', 'requests_kerberos>=0.12.0', 'sasl>=0.2.1', + 'pure-sasl>=0.6.2', + 'kerberos>=1.3.0', 'sqlalchemy>=1.3.0', 'thrift>=0.10.0', ], From 5f0ee1f2ad558e120474b31ef065bb42457d1208 Mon Sep 17 00:00:00 
2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Sat, 8 Jul 2023 15:17:51 +0530 Subject: [PATCH 09/12] minimal changes for sqlalchemy 2.0 support (#457) --- README.rst | 18 ++++- pyhive/sqlalchemy_hive.py | 30 ++++++-- pyhive/sqlalchemy_presto.py | 28 ++++++-- pyhive/sqlalchemy_trino.py | 15 +++- pyhive/tests/sqlalchemy_test_case.py | 88 +++++++++++++++-------- pyhive/tests/test_sqlalchemy_hive.py | 98 ++++++++++++++++---------- pyhive/tests/test_sqlalchemy_presto.py | 14 ++-- pyhive/tests/test_sqlalchemy_trino.py | 93 ++++++++++++++++++++++++ 8 files changed, 293 insertions(+), 91 deletions(-) create mode 100644 pyhive/tests/test_sqlalchemy_trino.py diff --git a/README.rst b/README.rst index 89c54532..5afd746c 100644 --- a/README.rst +++ b/README.rst @@ -71,9 +71,11 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). # Presto engine = create_engine('presto://localhost:8080/hive/default') # Trino - engine = create_engine('trino://localhost:8080/hive/default') + engine = create_engine('trino+pyhive://localhost:8080/hive/default') # Hive engine = create_engine('hive://localhost:10000/default') + + # SQLAlchemy < 2.0 logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() @@ -82,6 +84,20 @@ First install this package to register it with SQLAlchemy (see ``setup.py``). logs = Table('my_awesome_data', MetaData(bind=engine), autoload=True) print select([func.count('*')], from_obj=logs).scalar() + # SQLAlchemy >= 2.0 + metadata_obj = MetaData() + books = Table("books", metadata_obj, Column("id", Integer), Column("title", String), Column("primary_author", String)) + metadata_obj.create_all(engine) + inspector = inspect(engine) + inspector.get_columns('books') + + with engine.connect() as con: + data = [{ "id": 1, "title": "The Hobbit", "primary_author": "Tolkien" }, + { "id": 2, "title": "The Silmarillion", "primary_author": "Tolkien" }] + con.execute(books.insert(), data[0]) + result = con.execute(text("select * from books")) + print(result.fetchall()) + Note: query generation functionality is not exhaustive or fully tested, but there should be no problem with raw SQL. 
diff --git a/pyhive/sqlalchemy_hive.py b/pyhive/sqlalchemy_hive.py index f39f1793..e2244525 100644 --- a/pyhive/sqlalchemy_hive.py +++ b/pyhive/sqlalchemy_hive.py @@ -13,11 +13,22 @@ import re from sqlalchemy import exc -from sqlalchemy import processors +from sqlalchemy.sql import text +try: + from sqlalchemy import processors +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.engine import processors from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -121,7 +132,7 @@ def __init__(self, dialect): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'int': types.Integer, 'bigint': types.BigInteger, @@ -247,10 +258,15 @@ class HiveDialect(default.DefaultDialect): supports_multivalues_insert = True type_compiler = HiveTypeCompiler supports_sane_rowcount = False + supports_statement_cache = False @classmethod def dbapi(cls): return hive + + @classmethod + def import_dbapi(cls): + return hive def create_connect_args(self, url): kwargs = { @@ -265,7 +281,7 @@ def create_connect_args(self, url): def get_schema_names(self, connection, **kw): # Equivalent to SHOW DATABASES - return [row[0] for row in connection.execute('SHOW SCHEMAS')] + return [row[0] for row in connection.execute(text('SHOW SCHEMAS'))] def get_view_names(self, connection, schema=None, **kw): # Hive does not provide functionality to query tableType @@ -280,7 +296,7 @@ def _get_table_columns(self, connection, table_name, schema): # Using DESCRIBE works but is uglier. try: # This needs the table name to be unescaped (no backticks). - rows = connection.execute('DESCRIBE {}'.format(full_table)).fetchall() + rows = connection.execute(text('DESCRIBE {}'.format(full_table))).fetchall() except exc.OperationalError as e: # Does the table exist? 
regex_fmt = r'TExecuteStatementResp.*SemanticException.*Table not found {}' @@ -296,7 +312,7 @@ def _get_table_columns(self, connection, table_name, schema): raise exc.NoSuchTableError(full_table) return rows - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True @@ -361,7 +377,7 @@ def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' IN ' + self.identifier_preparer.quote_identifier(schema) - return [row[0] for row in connection.execute(query)] + return [row[0] for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): # No transactions for Hive diff --git a/pyhive/sqlalchemy_presto.py b/pyhive/sqlalchemy_presto.py index a199ebe1..bfe1ba04 100644 --- a/pyhive/sqlalchemy_presto.py +++ b/pyhive/sqlalchemy_presto.py @@ -9,11 +9,19 @@ from __future__ import unicode_literals import re +import sqlalchemy from sqlalchemy import exc from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +from sqlalchemy.sql import text +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -21,6 +29,7 @@ from pyhive import presto from pyhive.common import UniversalSet +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) class PrestoIdentifierPreparer(compiler.IdentifierPreparer): # Just quote everything to make things simpler / easier to upgrade @@ -29,7 +38,7 @@ class PrestoIdentifierPreparer(compiler.IdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, @@ -80,6 +89,7 @@ class PrestoDialect(default.DefaultDialect): supports_multivalues_insert = True supports_unicode_statements = True supports_unicode_binds = True + supports_statement_cache = False returns_unicode_strings = True description_encoding = None supports_native_boolean = True @@ -88,6 +98,10 @@ class PrestoDialect(default.DefaultDialect): @classmethod def dbapi(cls): return presto + + @classmethod + def import_dbapi(cls): + return presto def create_connect_args(self, url): db_parts = (url.database or 'hive').split('/') @@ -108,14 +122,14 @@ def create_connect_args(self, url): return [], kwargs def get_schema_names(self, connection, **kw): - return [row.Schema for row in connection.execute('SHOW SCHEMAS')] + return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))] def _get_table_columns(self, connection, table_name, schema): full_table = self.identifier_preparer.quote_identifier(table_name) if schema: full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table try: - return connection.execute('SHOW COLUMNS FROM {}'.format(full_table)) + return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table))) except (presto.DatabaseError, exc.DatabaseError) as e: # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which # it successfully does in the Hive version. 
The difference with Presto is that this @@ -134,7 +148,7 @@ def _get_table_columns(self, connection, table_name, schema): else: raise - def has_table(self, connection, table_name, schema=None): + def has_table(self, connection, table_name, schema=None, **kw): try: self._get_table_columns(connection, table_name, schema) return True @@ -176,6 +190,8 @@ def get_indexes(self, connection, table_name, schema=None, **kw): # - a boolean column named "Partition Key" # - a string in the "Comment" column # - a string in the "Extra" column + if sqlalchemy_version >= 1.4: + row = row._mapping is_partition_key = ( (part_key in row and row[part_key]) or row['Comment'].startswith(part_key) @@ -192,7 +208,7 @@ def get_table_names(self, connection, schema=None, **kw): query = 'SHOW TABLES' if schema: query += ' FROM ' + self.identifier_preparer.quote_identifier(schema) - return [row.Table for row in connection.execute(query)] + return [row.Table for row in connection.execute(text(query))] def do_rollback(self, dbapi_connection): # No transactions for Presto diff --git a/pyhive/sqlalchemy_trino.py b/pyhive/sqlalchemy_trino.py index 4b2b3698..11be2a6c 100644 --- a/pyhive/sqlalchemy_trino.py +++ b/pyhive/sqlalchemy_trino.py @@ -13,7 +13,13 @@ from sqlalchemy import types from sqlalchemy import util # TODO shouldn't use mysql type -from sqlalchemy.databases import mysql +try: + from sqlalchemy.databases import mysql + mysql_tinyinteger = mysql.MSTinyInteger +except ImportError: + # Required for SQLAlchemy>=2.0 + from sqlalchemy.dialects import mysql + mysql_tinyinteger = mysql.base.MSTinyInteger from sqlalchemy.engine import default from sqlalchemy.sql import compiler from sqlalchemy.sql.compiler import SQLCompiler @@ -28,7 +34,7 @@ class TrinoIdentifierPreparer(PrestoIdentifierPreparer): _type_map = { 'boolean': types.Boolean, - 'tinyint': mysql.MSTinyInteger, + 'tinyint': mysql_tinyinteger, 'smallint': types.SmallInteger, 'integer': types.Integer, 'bigint': types.BigInteger, @@ -67,7 +73,12 @@ def visit_TEXT(self, type_, **kw): class TrinoDialect(PrestoDialect): name = 'trino' + supports_statement_cache = False @classmethod def dbapi(cls): return trino + + @classmethod + def import_dbapi(cls): + return trino diff --git a/pyhive/tests/sqlalchemy_test_case.py b/pyhive/tests/sqlalchemy_test_case.py index 652e05f4..db89d57b 100644 --- a/pyhive/tests/sqlalchemy_test_case.py +++ b/pyhive/tests/sqlalchemy_test_case.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import abc +import re import contextlib import functools @@ -14,8 +15,10 @@ from sqlalchemy.schema import Index from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table -from sqlalchemy.sql import expression +from sqlalchemy.sql import expression, text +from sqlalchemy import String +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) def with_engine_connection(fn): """Pass a connection to the given function and handle cleanup. 
@@ -32,19 +35,33 @@ def wrapped_fn(self, *args, **kwargs): engine.dispose() return wrapped_fn +def reflect_table(engine, connection, table, include_columns, exclude_columns, resolve_fks): + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + insp.reflect_table( + table, + include_columns=include_columns, + exclude_columns=exclude_columns, + resolve_fks=resolve_fks, + ) + else: + engine.dialect.reflecttable( + connection, table, include_columns=include_columns, + exclude_columns=exclude_columns, resolve_fks=resolve_fks) + class SqlAlchemyTestCase(with_metaclass(abc.ABCMeta, object)): @with_engine_connection def test_basic_query(self, engine, connection): - rows = connection.execute('SELECT * FROM one_row').fetchall() + rows = connection.execute(text('SELECT * FROM one_row')).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(rows[0].number_of_rows, 1) # number_of_rows is the column name self.assertEqual(len(rows[0]), 1) @with_engine_connection def test_one_row_complex_null(self, engine, connection): - one_row_complex_null = Table('one_row_complex_null', MetaData(bind=engine), autoload=True) - rows = one_row_complex_null.select().execute().fetchall() + one_row_complex_null = Table('one_row_complex_null', MetaData(), autoload_with=engine) + rows = connection.execute(one_row_complex_null.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [None] * len(rows[0])) @@ -53,27 +70,26 @@ def test_reflect_no_such_table(self, engine, connection): """reflecttable should throw an exception on an invalid table""" self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), autoload=True)) + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) self.assertRaises( NoSuchTableError, - lambda: Table('this_does_not_exist', MetaData(bind=engine), - schema='also_does_not_exist', autoload=True)) + lambda: Table('this_does_not_exist', MetaData(schema='also_does_not_exist'), autoload_with=engine)) @with_engine_connection def test_reflect_include_columns(self, engine, connection): """When passed include_columns, reflecttable should filter out other columns""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, one_row_complex, include_columns=['int'], + + one_row_complex = Table('one_row_complex', MetaData()) + reflect_table(engine, connection, one_row_complex, include_columns=['int'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(one_row_complex.c), 1) self.assertIsNotNone(one_row_complex.c.int) self.assertRaises(AttributeError, lambda: one_row_complex.c.tinyint) @with_engine_connection def test_reflect_with_schema(self, engine, connection): - dummy = Table('dummy_table', MetaData(bind=engine), schema='pyhive_test_database', - autoload=True) + dummy = Table('dummy_table', MetaData(schema='pyhive_test_database'), autoload_with=engine) self.assertEqual(len(dummy.c), 1) self.assertIsNotNone(dummy.c.a) @@ -81,22 +97,22 @@ def test_reflect_with_schema(self, engine, connection): @with_engine_connection def test_reflect_partitions(self, engine, connection): """reflecttable should get the partition column as an index""" - many_rows = Table('many_rows', MetaData(bind=engine), autoload=True) + many_rows = Table('many_rows', MetaData(), autoload_with=engine) self.assertEqual(len(many_rows.c), 2) self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) - many_rows = Table('many_rows', MetaData(bind=engine)) - 
engine.dialect.reflecttable( - connection, many_rows, include_columns=['a'], + many_rows = Table('many_rows', MetaData()) + reflect_table(engine, connection, many_rows, include_columns=['a'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(many_rows.c), 1) self.assertFalse(many_rows.c.a.index) self.assertFalse(many_rows.indexes) - many_rows = Table('many_rows', MetaData(bind=engine)) - engine.dialect.reflecttable( - connection, many_rows, include_columns=['b'], + many_rows = Table('many_rows', MetaData()) + reflect_table(engine, connection, many_rows, include_columns=['b'], exclude_columns=[], resolve_fks=True) + self.assertEqual(len(many_rows.c), 1) self.assertEqual(repr(many_rows.indexes), repr({Index('partition', many_rows.c.b)})) @@ -104,11 +120,15 @@ def test_reflect_partitions(self, engine, connection): def test_unicode(self, engine, connection): """Verify that unicode strings make it through SQLAlchemy and the backend""" unicode_str = "中文" - one_row = Table('one_row', MetaData(bind=engine)) - returned_str = sqlalchemy.select( - [expression.bindparam("好", unicode_str)], - from_obj=one_row, - ).scalar() + one_row = Table('one_row', MetaData()) + + if sqlalchemy_version >= 1.4: + returned_str = connection.execute(sqlalchemy.select( + expression.bindparam("好", unicode_str, type_=String())).select_from(one_row)).scalar() + else: + returned_str = connection.execute(sqlalchemy.select([ + expression.bindparam("好", unicode_str, type_=String())]).select_from(one_row)).scalar() + self.assertEqual(returned_str, unicode_str) @with_engine_connection @@ -133,13 +153,21 @@ def test_get_table_names(self, engine, connection): @with_engine_connection def test_has_table(self, engine, connection): - self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) - self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) + if sqlalchemy_version >= 1.4: + insp = sqlalchemy.inspect(engine) + self.assertTrue(insp.has_table("one_row")) + self.assertFalse(insp.has_table("this_table_does_not_exist")) + else: + self.assertTrue(Table('one_row', MetaData(bind=engine)).exists()) + self.assertFalse(Table('this_table_does_not_exist', MetaData(bind=engine)).exists()) @with_engine_connection def test_char_length(self, engine, connection): - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) - result = sqlalchemy.select([ - sqlalchemy.func.char_length(one_row_complex.c.string) - ]).execute().scalar() + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + + if sqlalchemy_version >= 1.4: + result = connection.execute(sqlalchemy.select(sqlalchemy.func.char_length(one_row_complex.c.string))).scalar() + else: + result = connection.execute(sqlalchemy.select([sqlalchemy.func.char_length(one_row_complex.c.string)])).scalar() + self.assertEqual(result, len('a string')) diff --git a/pyhive/tests/test_sqlalchemy_hive.py b/pyhive/tests/test_sqlalchemy_hive.py index 1ff0e817..790bec4c 100644 --- a/pyhive/tests/test_sqlalchemy_hive.py +++ b/pyhive/tests/test_sqlalchemy_hive.py @@ -4,6 +4,7 @@ from pyhive.sqlalchemy_hive import HiveDate from pyhive.sqlalchemy_hive import HiveDecimal from pyhive.sqlalchemy_hive import HiveTimestamp +from sqlalchemy.exc import NoSuchTableError, OperationalError from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase from pyhive.tests.sqlalchemy_test_case import with_engine_connection from sqlalchemy import types @@ -11,11 +12,15 @@ from sqlalchemy.schema import Column from sqlalchemy.schema 
import MetaData from sqlalchemy.schema import Table +from sqlalchemy.sql import text import contextlib import datetime import decimal import sqlalchemy.types import unittest +import re + +sqlalchemy_version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) _ONE_ROW_COMPLEX_CONTENTS = [ True, @@ -64,7 +69,11 @@ def test_dotted_column_names(self, engine, connection): """When Hive returns a dotted column name, both the non-dotted version should be available as an attribute, and the dotted version should remain available as a key. """ - row = connection.execute('SELECT * FROM one_row').fetchone() + row = connection.execute(text('SELECT * FROM one_row')).fetchone() + + if sqlalchemy_version >= 1.4: + row = row._mapping + assert row.keys() == ['number_of_rows'] assert 'number_of_rows' in row assert row.number_of_rows == 1 @@ -76,20 +85,33 @@ def test_dotted_column_names(self, engine, connection): def test_dotted_column_names_raw(self, engine, connection): """When Hive returns a dotted column name, and raw mode is on, nothing should be modified. """ - row = connection.execution_options(hive_raw_colnames=True) \ - .execute('SELECT * FROM one_row').fetchone() + row = connection.execution_options(hive_raw_colnames=True).execute(text('SELECT * FROM one_row')).fetchone() + + if sqlalchemy_version >= 1.4: + row = row._mapping + assert row.keys() == ['one_row.number_of_rows'] assert 'number_of_rows' not in row assert getattr(row, 'one_row.number_of_rows') == 1 assert row['one_row.number_of_rows'] == 1 + @with_engine_connection + def test_reflect_no_such_table(self, engine, connection): + """reflecttable should throw an exception on an invalid table""" + self.assertRaises( + NoSuchTableError, + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + self.assertRaises( + OperationalError, + lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) self.assertEqual(len(one_row_complex.c), 15) self.assertIsInstance(one_row_complex.c.string, Column) - row = one_row_complex.select().execute().fetchone() + row = connection.execute(one_row_complex.select()).fetchone() self.assertEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) # TODO some of these types could be filled in better @@ -112,15 +134,15 @@ def test_reflect_select(self, engine, connection): @with_engine_connection def test_type_map(self, engine, connection): """sqlalchemy should use the dbapi_type_map to infer types from raw queries""" - row = connection.execute('SELECT * FROM one_row_complex').fetchone() + row = connection.execute(text('SELECT * FROM one_row_complex')).fetchone() self.assertListEqual(list(row), _ONE_ROW_COMPLEX_CONTENTS) @with_engine_connection def test_reserved_words(self, engine, connection): """Hive uses backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(bind=engine), Column('map', sqlalchemy.types.String)) - query = str(fake_table.select(fake_table.c.map == 'a')) + fake_table = Table('select', MetaData(), Column('map', sqlalchemy.types.String)) + query = str(fake_table.select().where(fake_table.c.map == 'a').compile(engine)) self.assertIn('`select`', query) self.assertIn('`map`', query) 
self.assertNotIn('"select"', query) @@ -132,12 +154,12 @@ def test_switch_database(self): with contextlib.closing(engine.connect()) as connection: self.assertIn( ('dummy_table',), - connection.execute('SHOW TABLES').fetchall() + connection.execute(text('SHOW TABLES')).fetchall() ) - connection.execute('USE default') + connection.execute(text('USE default')) self.assertIn( ('one_row',), - connection.execute('SHOW TABLES').fetchall() + connection.execute(text('SHOW TABLES')).fetchall() ) finally: engine.dispose() @@ -160,13 +182,13 @@ def test_lots_of_types(self, engine, connection): cols.append(Column('hive_date', HiveDate)) cols.append(Column('hive_decimal', HiveDecimal)) cols.append(Column('hive_timestamp', HiveTimestamp)) - table = Table('test_table', MetaData(bind=engine), *cols, schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute('SET mapred.job.tracker=local') - connection.execute('USE pyhive_test_database') + table = Table('test_table', MetaData(schema='pyhive_test_database'), *cols,) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(text('SET mapred.job.tracker=local')) + connection.execute(text('USE pyhive_test_database')) big_number = 10 ** 10 - 1 - connection.execute(""" + connection.execute(text(""" INSERT OVERWRITE TABLE test_table SELECT 1, "a", "a", "a", "a", "a", 0.1, @@ -175,41 +197,39 @@ def test_lots_of_types(self, engine, connection): "a", 1, 1, 0.1, 0.1, 0, 0, 0, "a", false, "a", "a", - 0, %d, 123 + 2000 + 0, :big_number, 123 + 2000 FROM default.one_row - """, big_number) - row = connection.execute(table.select()).fetchone() - self.assertEqual(row.hive_date, datetime.date(1970, 1, 1)) + """), {"big_number": big_number}) + row = connection.execute(text("select * from test_table")).fetchone() + self.assertEqual(row.hive_date, datetime.datetime(1970, 1, 1, 0, 0)) self.assertEqual(row.hive_decimal, decimal.Decimal(big_number)) self.assertEqual(row.hive_timestamp, datetime.datetime(1970, 1, 1, 0, 0, 2, 123000)) - table.drop() + table.drop(bind=connection) @with_engine_connection def test_insert_select(self, engine, connection): - one_row = Table('one_row', MetaData(bind=engine), autoload=True) - table = Table('insert_test', MetaData(bind=engine), - Column('a', sqlalchemy.types.Integer), - schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute('SET mapred.job.tracker=local') + one_row = Table('one_row', MetaData(), autoload_with=engine) + table = Table('insert_test', MetaData(schema='pyhive_test_database'), + Column('a', sqlalchemy.types.Integer)) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(text('SET mapred.job.tracker=local')) # NOTE(jing) I'm stuck on a version of Hive without INSERT ... 
VALUES connection.execute(table.insert().from_select(['a'], one_row.select())) - - result = table.select().execute().fetchall() + + result = connection.execute(table.select()).fetchall() expected = [(1,)] self.assertEqual(result, expected) @with_engine_connection def test_insert_values(self, engine, connection): - table = Table('insert_test', MetaData(bind=engine), - Column('a', sqlalchemy.types.Integer), - schema='pyhive_test_database') - table.drop(checkfirst=True) - table.create() - connection.execute(table.insert([{'a': 1}, {'a': 2}])) - - result = table.select().execute().fetchall() + table = Table('insert_test', MetaData(schema='pyhive_test_database'), + Column('a', sqlalchemy.types.Integer),) + table.drop(checkfirst=True, bind=connection) + table.create(bind=connection) + connection.execute(table.insert().values([{'a': 1}, {'a': 2}])) + + result = connection.execute(table.select()).fetchall() expected = [(1,), (2,)] self.assertEqual(result, expected) diff --git a/pyhive/tests/test_sqlalchemy_presto.py b/pyhive/tests/test_sqlalchemy_presto.py index a01e4a35..58a5c034 100644 --- a/pyhive/tests/test_sqlalchemy_presto.py +++ b/pyhive/tests/test_sqlalchemy_presto.py @@ -8,7 +8,9 @@ from sqlalchemy.schema import Column from sqlalchemy.schema import MetaData from sqlalchemy.schema import Table +from sqlalchemy.sql import text from sqlalchemy.types import String +from decimal import Decimal import contextlib import unittest @@ -27,11 +29,11 @@ def test_bad_format(self): @with_engine_connection def test_reflect_select(self, engine, connection): """reflecttable should be able to fill in a table from the name""" - one_row_complex = Table('one_row_complex', MetaData(bind=engine), autoload=True) + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) # Presto ignores the union column self.assertEqual(len(one_row_complex.c), 15 - 1) self.assertIsInstance(one_row_complex.c.string, Column) - rows = one_row_complex.select().execute().fetchall() + rows = connection.execute(one_row_complex.select()).fetchall() self.assertEqual(len(rows), 1) self.assertEqual(list(rows[0]), [ True, @@ -48,7 +50,7 @@ def test_reflect_select(self, engine, connection): {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', - '0.1', + Decimal('0.1'), ]) # TODO some of these types could be filled in better @@ -71,7 +73,7 @@ def test_url_default(self): engine = create_engine('presto://localhost:8080/hive') try: with contextlib.closing(engine.connect()) as connection: - self.assertEqual(connection.execute('SELECT 1 AS foobar FROM one_row').scalar(), 1) + self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) finally: engine.dispose() @@ -79,8 +81,8 @@ def test_url_default(self): def test_reserved_words(self, engine, connection): """Presto uses double quotes, not backticks""" # Use keywords for the table/column name - fake_table = Table('select', MetaData(bind=engine), Column('current_timestamp', String)) - query = str(fake_table.select(fake_table.c.current_timestamp == 'a')) + fake_table = Table('select', MetaData(), Column('current_timestamp', String)) + query = str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) self.assertIn('"select"', query) self.assertIn('"current_timestamp"', query) self.assertNotIn('`select`', query) diff --git a/pyhive/tests/test_sqlalchemy_trino.py b/pyhive/tests/test_sqlalchemy_trino.py new file mode 100644 index 
00000000..c929f941 --- /dev/null +++ b/pyhive/tests/test_sqlalchemy_trino.py @@ -0,0 +1,93 @@ +from sqlalchemy.engine import create_engine +from pyhive.tests.sqlalchemy_test_case import SqlAlchemyTestCase +from pyhive.tests.sqlalchemy_test_case import with_engine_connection +from sqlalchemy.exc import NoSuchTableError, DatabaseError +from sqlalchemy.schema import MetaData, Table, Column +from sqlalchemy.types import String +from sqlalchemy.sql import text +from sqlalchemy import types +from decimal import Decimal + +import unittest +import contextlib + + +class TestSqlAlchemyTrino(unittest.TestCase, SqlAlchemyTestCase): + def create_engine(self): + return create_engine('trino+pyhive://localhost:18080/hive/default?source={}'.format(self.id())) + + def test_bad_format(self): + self.assertRaises( + ValueError, + lambda: create_engine('trino+pyhive://localhost:18080/hive/default/what'), + ) + + @with_engine_connection + def test_reflect_select(self, engine, connection): + """reflecttable should be able to fill in a table from the name""" + one_row_complex = Table('one_row_complex', MetaData(), autoload_with=engine) + # Presto ignores the union column + self.assertEqual(len(one_row_complex.c), 15 - 1) + self.assertIsInstance(one_row_complex.c.string, Column) + rows = connection.execute(one_row_complex.select()).fetchall() + self.assertEqual(len(rows), 1) + self.assertEqual(list(rows[0]), [ + True, + 127, + 32767, + 2147483647, + 9223372036854775807, + 0.5, + 0.25, + 'a string', + '1970-01-01 00:00:00.000', + b'123', + [1, 2], + {"1": 2, "3": 4}, + [1, 2], + Decimal('0.1'), + ]) + + self.assertIsInstance(one_row_complex.c.boolean.type, types.Boolean) + self.assertIsInstance(one_row_complex.c.tinyint.type, types.Integer) + self.assertIsInstance(one_row_complex.c.smallint.type, types.Integer) + self.assertIsInstance(one_row_complex.c.int.type, types.Integer) + self.assertIsInstance(one_row_complex.c.bigint.type, types.BigInteger) + self.assertIsInstance(one_row_complex.c.float.type, types.Float) + self.assertIsInstance(one_row_complex.c.double.type, types.Float) + self.assertIsInstance(one_row_complex.c.string.type, String) + self.assertIsInstance(one_row_complex.c.timestamp.type, types.NullType) + self.assertIsInstance(one_row_complex.c.binary.type, types.VARBINARY) + self.assertIsInstance(one_row_complex.c.array.type, types.NullType) + self.assertIsInstance(one_row_complex.c.map.type, types.NullType) + self.assertIsInstance(one_row_complex.c.struct.type, types.NullType) + self.assertIsInstance(one_row_complex.c.decimal.type, types.NullType) + + @with_engine_connection + def test_reflect_no_such_table(self, engine, connection): + """reflecttable should throw an exception on an invalid table""" + self.assertRaises( + NoSuchTableError, + lambda: Table('this_does_not_exist', MetaData(), autoload_with=engine)) + self.assertRaises( + DatabaseError, + lambda: Table('this_does_not_exist', MetaData(schema="also_does_not_exist"), autoload_with=engine)) + + def test_url_default(self): + engine = create_engine('trino+pyhive://localhost:18080/hive') + try: + with contextlib.closing(engine.connect()) as connection: + self.assertEqual(connection.execute(text('SELECT 1 AS foobar FROM one_row')).scalar(), 1) + finally: + engine.dispose() + + @with_engine_connection + def test_reserved_words(self, engine, connection): + """Trino uses double quotes, not backticks""" + # Use keywords for the table/column name + fake_table = Table('select', MetaData(), Column('current_timestamp', String)) + query = 
str(fake_table.select().where(fake_table.c.current_timestamp == 'a').compile(engine)) + self.assertIn('"select"', query) + self.assertIn('"current_timestamp"', query) + self.assertNotIn('`select`', query) + self.assertNotIn('`current_timestamp`', query) From d4ae481675ac5588ba9101596fa26f22ef0e77c4 Mon Sep 17 00:00:00 2001 From: Multazim Deshmukh <57723564+mdeshmu@users.noreply.github.com> Date: Wed, 12 Jul 2023 16:09:00 +0530 Subject: [PATCH 10/12] update readme to reflect recent changes (#459) --- README.rst | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index 5afd746c..bb10e607 100644 --- a/README.rst +++ b/README.rst @@ -14,8 +14,8 @@ PyHive ====== PyHive is a collection of Python `DB-API `_ and -`SQLAlchemy `_ interfaces for `Presto `_ and -`Hive `_. +`SQLAlchemy `_ interfaces for `Presto `_ , +`Hive `_ and `Trino `_. Usage ===== @@ -25,7 +25,7 @@ DB-API .. code-block:: python from pyhive import presto # or import hive or import trino - cursor = presto.connect('localhost').cursor() + cursor = presto.connect('localhost').cursor() # or use hive.connect or use trino.connect cursor.execute('SELECT * FROM my_awesome_data LIMIT 10') print cursor.fetchone() print cursor.fetchall() @@ -61,7 +61,7 @@ In Python 3.7 `async` became a keyword; you can use `async_` instead: SQLAlchemy ---------- -First install this package to register it with SQLAlchemy (see ``setup.py``). +First install this package to register it with SQLAlchemy, see ``entry_points`` in ``setup.py``. .. code-block:: python @@ -117,7 +117,7 @@ Passing session configuration 'session_props': {'query_max_run_time': '1234m'}} ) create_engine( - 'trino://user@host:443/hive', + 'trino+pyhive://user@host:443/hive', connect_args={'protocol': 'https', 'session_props': {'query_max_run_time': '1234m'}} ) @@ -136,15 +136,18 @@ Requirements Install using -- ``pip install 'pyhive[hive]'`` for the Hive interface and -- ``pip install 'pyhive[presto]'`` for the Presto interface. +- ``pip install 'pyhive[hive]'`` or ``pip install 'pyhive[hive_pure_sasl]'`` for the Hive interface +- ``pip install 'pyhive[presto]'`` for the Presto interface - ``pip install 'pyhive[trino]'`` for the Trino interface +Note: ``'pyhive[hive]'`` extras uses `sasl `_ that doesn't support Python 3.11, See `github issue `_. +Hence PyHive also supports `pure-sasl `_ via additional extras ``'pyhive[hive_pure_sasl]'`` which support Python 3.11. + PyHive works with - Python 2.7 / Python 3 -- For Presto: Presto install -- For Trino: Trino install +- For Presto: `Presto installation `_ +- For Trino: `Trino installation `_ - For Hive: `HiveServer2 `_ daemon Changelog @@ -162,6 +165,26 @@ Contributing - We prefer having a small number of generic features over a large number of specialized, inflexible features. For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option. +Tips for test environment setup +================================ +You can setup test environment by following ``.travis.yaml`` in this repository. It uses `Cloudera's CDH 5 `_ which requires username and password for download. +It may not be feasible for everyone to get those credentials. Hence below are alternative instructions to setup test environment. + +You can clone `this repository `_ which has Docker Compose setup for Presto and Hive. 
+You can add below lines to its docker-compose.yaml to start Trino in same environment:: + + trino: + image: trinodb/trino:351 + ports: + - "18080:18080" + volumes: + - ./trino:/etc/trino + +Note: ``./trino`` for docker volume defined above is `trino config from PyHive repository `_ + +Then run:: + docker-compose up -d + Testing ======= .. image:: https://travis-ci.org/dropbox/PyHive.svg From 486eaefdea1326bd5f63a5dd4734c2646cf0bf84 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 30 May 2024 15:26:26 -0700 Subject: [PATCH 11/12] Update README.rst (#475) --- README.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index bb10e607..1f4db670 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,7 @@ ================================ -Project is currently unsupported +Pyhive project has been donated to Apache Kyuubi. + +You can follow it's development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive ================================ From ac09074a652fd50e10b57a7f0bbc4f6410961301 Mon Sep 17 00:00:00 2001 From: Bogdan Date: Thu, 30 May 2024 15:35:22 -0700 Subject: [PATCH 12/12] Update README.rst (#476) --- README.rst | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/README.rst b/README.rst index 1f4db670..cbdcce10 100644 --- a/README.rst +++ b/README.rst @@ -1,26 +1,24 @@ -================================ -Pyhive project has been donated to Apache Kyuubi. +======================================================== +PyHive project has been donated to Apache Kyuubi +======================================================== You can follow it's development and report any issues you are experiencing here: https://github.com/apache/kyuubi/tree/master/python/pyhive -================================ +Legacy notes / instructions +=========================== -.. image:: https://travis-ci.org/dropbox/PyHive.svg?branch=master - :target: https://travis-ci.org/dropbox/PyHive -.. image:: https://img.shields.io/codecov/c/github/dropbox/PyHive.svg - -====== PyHive -====== +********** + PyHive is a collection of Python `DB-API `_ and `SQLAlchemy `_ interfaces for `Presto `_ , `Hive `_ and `Trino `_. Usage -===== +********** DB-API ------ @@ -134,7 +132,7 @@ Passing session configuration ) Requirements -============ +************ Install using @@ -153,11 +151,11 @@ PyHive works with - For Hive: `HiveServer2 `_ daemon Changelog -========= +********* See https://github.com/dropbox/PyHive/releases. Contributing -============ +************ - Please fill out the Dropbox Contributor License Agreement at https://opensource.dropbox.com/cla/ and note this in your pull request. - Changes must come with tests, with the exception of trivial things like fixing comments. See .travis.yml for the test environment setup. - Notes on project scope: @@ -168,7 +166,7 @@ Contributing For example, the Presto code takes an arbitrary ``requests_session`` argument for customizing HTTP calls, as opposed to having a separate parameter/branch for each ``requests`` option. Tips for test environment setup -================================ +**************************************** You can setup test environment by following ``.travis.yaml`` in this repository. It uses `Cloudera's CDH 5 `_ which requires username and password for download. It may not be feasible for everyone to get those credentials. Hence below are alternative instructions to setup test environment. 
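For anyone following the test-environment tips above, a quick way to confirm the resulting services are reachable is a small PyHive DB-API smoke test. The sketch below is illustrative only and rests on assumptions about how the compose services are exposed: Presto on ``localhost:8080``, Trino on ``localhost:18080`` (matching the compose snippet above), and HiveServer2 on ``localhost:10000``; depending on the Hive image you may also need to pass ``username`` or ``auth`` to ``hive.connect``.

.. code-block:: python

    from pyhive import hive, presto, trino

    # Assumed endpoints -- adjust if your docker-compose maps different hosts/ports.
    cursors = {
        'presto': presto.connect('localhost', port=8080).cursor(),
        'trino': trino.connect('localhost', port=18080).cursor(),
        'hive': hive.connect('localhost', port=10000).cursor(),
    }

    for name, cursor in cursors.items():
        cursor.execute('SELECT 1')       # trivial query supported by all three backends
        print(name, cursor.fetchone())   # expect (1,) from each service

If each backend prints ``(1,)``, the DB-API layer can reach the services and the test suite described below has something to run against.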
@@ -188,7 +186,7 @@ Then run:: docker-compose up -d Testing -======= +******* .. image:: https://travis-ci.org/dropbox/PyHive.svg :target: https://travis-ci.org/dropbox/PyHive .. image:: http://codecov.io/github/dropbox/PyHive/coverage.svg?branch=master @@ -207,7 +205,7 @@ WARNING: This drops/creates tables named ``one_row``, ``one_row_complex``, and ` database called ``pyhive_test_database``. Updating TCLIService -==================== +******************** The TCLIService module is autogenerated using a ``TCLIService.thrift`` file. To update it, the ``generate.py`` file can be used: ``python generate.py ``. When left blank, the