diff --git a/TCLIService/ttypes.py b/TCLIService/ttypes.py index 573bd043..47b4cec1 100644 --- a/TCLIService/ttypes.py +++ b/TCLIService/ttypes.py @@ -8,6 +8,7 @@ from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException from thrift.protocol.TProtocol import TProtocolException +import numpy as np import sys from thrift.transport import TTransport @@ -2013,9 +2014,7 @@ def read(self, iprot): if ftype == TType.LIST: self.values = [] (_etype51, _size48) = iprot.readListBegin() - for _i52 in range(_size48): - _elem53 = iprot.readBool() - self.values.append(_elem53) + self.values = np.frombuffer(iprot.trans.readAll(1 * _size48), dtype='>?') iprot.readListEnd() else: iprot.skip(ftype) @@ -2097,9 +2096,7 @@ def read(self, iprot): if ftype == TType.LIST: self.values = [] (_etype58, _size55) = iprot.readListBegin() - for _i59 in range(_size55): - _elem60 = iprot.readByte() - self.values.append(_elem60) + self.values = np.frombuffer(iprot.trans.readAll(1 * _size55), dtype='>i1') iprot.readListEnd() else: iprot.skip(ftype) @@ -2181,9 +2178,7 @@ def read(self, iprot): if ftype == TType.LIST: self.values = [] (_etype65, _size62) = iprot.readListBegin() - for _i66 in range(_size62): - _elem67 = iprot.readI16() - self.values.append(_elem67) + self.values = np.frombuffer(iprot.trans.readAll(2 * _size62), dtype='>i2') iprot.readListEnd() else: iprot.skip(ftype) @@ -2265,9 +2260,7 @@ def read(self, iprot): if ftype == TType.LIST: self.values = [] (_etype72, _size69) = iprot.readListBegin() - for _i73 in range(_size69): - _elem74 = iprot.readI32() - self.values.append(_elem74) + self.values = np.frombuffer(iprot.trans.readAll(4 * _size69), dtype='>i4') iprot.readListEnd() else: iprot.skip(ftype) @@ -2349,9 +2342,7 @@ def read(self, iprot): if ftype == TType.LIST: self.values = [] (_etype79, _size76) = iprot.readListBegin() - for _i80 in range(_size76): - _elem81 = iprot.readI64() - self.values.append(_elem81) + self.values = np.frombuffer(iprot.trans.readAll(8 * _size76), dtype='>i8') iprot.readListEnd() else: iprot.skip(ftype) @@ -2433,9 +2424,7 @@ def read(self, iprot): if ftype == TType.LIST: self.values = [] (_etype86, _size83) = iprot.readListBegin() - for _i87 in range(_size83): - _elem88 = iprot.readDouble() - self.values.append(_elem88) + self.values = np.frombuffer(iprot.trans.readAll(8 * _size83), dtype='>f8') iprot.readListEnd() else: iprot.skip(ftype) diff --git a/pyhive/common.py b/pyhive/common.py index 298633a1..c8b23cef 100644 --- a/pyhive/common.py +++ b/pyhive/common.py @@ -38,7 +38,7 @@ def _reset_state(self): # Internal helper state self._state = self._STATE_NONE - self._data = collections.deque() + self._data = None self._columns = None def _fetch_while(self, fn): diff --git a/pyhive/hive.py b/pyhive/hive.py index 3f71df33..7922ed06 100644 --- a/pyhive/hive.py +++ b/pyhive/hive.py @@ -10,6 +10,10 @@ import base64 import datetime +import io +import numpy as np +import pyarrow as pa +import pyarrow.json import re from decimal import Decimal from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context @@ -41,6 +45,7 @@ _logger = logging.getLogger(__name__) _TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)') +_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:.\d+)?)') ssl_cert_parameter_map = { "none": CERT_NONE, @@ -67,9 +72,36 @@ def _parse_timestamp(value): value = None return value +def _parse_date(value): + if value: + format = '%Y-%m-%d' + value = datetime.datetime.strptime(value, format).date() + else: + value = None + return value -TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal, - "TIMESTAMP_TYPE": _parse_timestamp} +def _parse_interval_day_time(value): + if value: + match = _INTERVAL_DAY_TIME_PATTERN.match(value) + if match: + days = int(match.group(1)) + hours = int(match.group(2)) + minutes = int(match.group(3)) + seconds = float(match.group(4)) + value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds) + else: + raise Exception( + 'Cannot convert "{}" into an interval_day_time'.format(value)) + else: + value = None + return value + +TYPES_CONVERTER = { + "DECIMAL_TYPE": Decimal, + "TIMESTAMP_TYPE": _parse_timestamp, + "DATE_TYPE": _parse_date, + "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time, +} class HiveParamEscaper(common.ParamEscaper): @@ -462,6 +494,48 @@ def cancel(self): response = self._connection.client.CancelOperation(req) _check_status(response) + def fetchone(self): + return self.fetchmany(1) + + def fetchall(self): + return self.fetchmany(-1) + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + if self._state == self._STATE_NONE: + raise exc.ProgrammingError("No query yet") + + if size == -1: + # Fetch everything + self._fetch_while(lambda: self._state != self._STATE_FINISHED) + else: + self._fetch_while(lambda: + (self._state != self._STATE_FINISHED) and + (self._data is None or self._data.num_rows < size) + ) + + if not self._data: + return None + + if size == -1: + # Fetch everything + size = self._data.num_rows + else: + size = min(size, self._data.num_rows) + + self._rownumber += size + rows = self._data[:size] + + if size == self._data.num_rows: + # Fetch everything + self._data = None + else: + self._data = self._data[size:] + + return rows + def _fetch_more(self): """Send another TFetchResultsReq and update state""" assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more" @@ -479,13 +553,19 @@ def _fetch_more(self): assert not response.results.rows, 'expected data in columnar format' columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in zip(response.results.columns, schema)] - new_data = list(zip(*columns)) - self._data += new_data + names = [col[0] for col in schema] + new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)]) # response.hasMoreRows seems to always be False, so we instead check the number of rows # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678 # if not response.hasMoreRows: - if not new_data: + if new_data.num_rows == 0: self._state = self._STATE_FINISHED + return + + if self._data is None: + self._data = new_data + else: + self._data = pa.concat_tables([self._data, new_data]) def poll(self, get_progress_update=True): """Poll for and return the raw status data provided by the Hive Thrift REST API. @@ -563,17 +643,42 @@ def _unwrap_column(col, type_=None): """Return a list of raw values from a TColumn instance.""" for attr, wrapper in iteritems(col.__dict__): if wrapper is not None: - result = wrapper.values - nulls = wrapper.nulls # bit set describing what's null - assert isinstance(nulls, bytes) - for i, char in enumerate(nulls): - byte = ord(char) if sys.version_info[0] == 2 else char - for b in range(8): - if byte & (1 << b): - result[i * 8 + b] = None - converter = TYPES_CONVERTER.get(type_, None) - if converter and type_: - result = [converter(row) if row else row for row in result] + if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']: + values = wrapper.values + # unpack nulls as a byte array + nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8')).view(bool) + # override a full mask as trailing False values are not sent + mask = np.zeros(values.shape, dtype='?') + end = min(len(mask), len(nulls)) + mask[:end] = nulls[:end] + + # float values are transferred as double + if type_ == 'FLOAT_TYPE': + values = values.astype('>f4') + + result = pa.array(values.byteswap().newbyteorder(), mask=mask) + else: + result = wrapper.values + nulls = wrapper.nulls # bit set describing what's null + if len(result) == 0: + return pa.array([]) + assert isinstance(nulls, bytes) + for i, char in enumerate(nulls): + byte = ord(char) if sys.version_info[0] == 2 else char + for b in range(8): + if byte & (1 << b): + result[i * 8 + b] = None + converter = TYPES_CONVERTER.get(type_, None) + if converter and type_: + result = [converter(row) if row else row for row in result] + if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']: + fd = io.BytesIO() + for row in result: + if row is None: + row = 'null' + fd.write(f'{{"c":{row}}}\n'.encode('utf8')) + fd.seek(0) + result = pa.json.read_json(fd)[0].combine_chunks() return result raise DataError("Got empty column value {}".format(col)) # pragma: no cover