Skip to content

Commit

Permalink
Produce rows as slices of pyarrow.Table
Browse files Browse the repository at this point in the history
  • Loading branch information
ptallada committed Feb 8, 2022
1 parent d199a1b commit 4754b49
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 35 deletions.
25 changes: 7 additions & 18 deletions TCLIService/ttypes.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyhive/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _reset_state(self):

# Internal helper state
self._state = self._STATE_NONE
self._data = collections.deque()
self._data = None
self._columns = None

def _fetch_while(self, fn):
Expand Down
137 changes: 121 additions & 16 deletions pyhive/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

import base64
import datetime
import io
import numpy as np
import pyarrow as pa
import pyarrow.json
import re
from decimal import Decimal
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
Expand Down Expand Up @@ -41,6 +45,7 @@
_logger = logging.getLogger(__name__)

_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:.\d+)?)')

ssl_cert_parameter_map = {
"none": CERT_NONE,
Expand All @@ -67,9 +72,36 @@ def _parse_timestamp(value):
value = None
return value

def _parse_date(value):
if value:
format = '%Y-%m-%d'
value = datetime.datetime.strptime(value, format).date()
else:
value = None
return value

TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
"TIMESTAMP_TYPE": _parse_timestamp}
def _parse_interval_day_time(value):
if value:
match = _INTERVAL_DAY_TIME_PATTERN.match(value)
if match:
days = int(match.group(1))
hours = int(match.group(2))
minutes = int(match.group(3))
seconds = float(match.group(4))
value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
else:
raise Exception(
'Cannot convert "{}" into an interval_day_time'.format(value))
else:
value = None
return value

TYPES_CONVERTER = {
"DECIMAL_TYPE": Decimal,
"TIMESTAMP_TYPE": _parse_timestamp,
"DATE_TYPE": _parse_date,
"INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
}


class HiveParamEscaper(common.ParamEscaper):
Expand Down Expand Up @@ -462,6 +494,48 @@ def cancel(self):
response = self._connection.client.CancelOperation(req)
_check_status(response)

def fetchone(self):
return self.fetchmany(1)

def fetchall(self):
return self.fetchmany(-1)

def fetchmany(self, size=None):
if size is None:
size = self.arraysize

if self._state == self._STATE_NONE:
raise exc.ProgrammingError("No query yet")

if size == -1:
# Fetch everything
self._fetch_while(lambda: self._state != self._STATE_FINISHED)
else:
self._fetch_while(lambda:
(self._state != self._STATE_FINISHED) and
(self._data is None or self._data.num_rows < size)
)

if not self._data:
return None

if size == -1:
# Fetch everything
size = self._data.num_rows
else:
size = min(size, self._data.num_rows)

self._rownumber += size
rows = self._data[:size]

if size == self._data.num_rows:
# Fetch everything
self._data = None
else:
self._data = self._data[size:]

return rows

def _fetch_more(self):
"""Send another TFetchResultsReq and update state"""
assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
Expand All @@ -479,13 +553,19 @@ def _fetch_more(self):
assert not response.results.rows, 'expected data in columnar format'
columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
zip(response.results.columns, schema)]
new_data = list(zip(*columns))
self._data += new_data
names = [col[0] for col in schema]
new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)])
# response.hasMoreRows seems to always be False, so we instead check the number of rows
# https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
# if not response.hasMoreRows:
if not new_data:
if new_data.num_rows == 0:
self._state = self._STATE_FINISHED
return

if self._data is None:
self._data = new_data
else:
self._data = pa.concat_tables([self._data, new_data])

def poll(self, get_progress_update=True):
"""Poll for and return the raw status data provided by the Hive Thrift REST API.
Expand Down Expand Up @@ -563,17 +643,42 @@ def _unwrap_column(col, type_=None):
"""Return a list of raw values from a TColumn instance."""
for attr, wrapper in iteritems(col.__dict__):
if wrapper is not None:
result = wrapper.values
nulls = wrapper.nulls # bit set describing what's null
assert isinstance(nulls, bytes)
for i, char in enumerate(nulls):
byte = ord(char) if sys.version_info[0] == 2 else char
for b in range(8):
if byte & (1 << b):
result[i * 8 + b] = None
converter = TYPES_CONVERTER.get(type_, None)
if converter and type_:
result = [converter(row) if row else row for row in result]
if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']:
values = wrapper.values
# unpack nulls as a byte array
nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8')).view(bool)
# override a full mask as trailing False values are not sent
mask = np.zeros(values.shape, dtype='?')
end = min(len(mask), len(nulls))
mask[:end] = nulls[:end]

# float values are transferred as double
if type_ == 'FLOAT_TYPE':
values = values.astype('>f4')

result = pa.array(values.byteswap().newbyteorder(), mask=mask)
else:
result = wrapper.values
nulls = wrapper.nulls # bit set describing what's null
if len(result) == 0:
return pa.array([])
assert isinstance(nulls, bytes)
for i, char in enumerate(nulls):
byte = ord(char) if sys.version_info[0] == 2 else char
for b in range(8):
if byte & (1 << b):
result[i * 8 + b] = None
converter = TYPES_CONVERTER.get(type_, None)
if converter and type_:
result = [converter(row) if row else row for row in result]
if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']:
fd = io.BytesIO()
for row in result:
if row is None:
row = 'null'
fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
fd.seek(0)
result = pa.json.read_json(fd)[0].combine_chunks()
return result
raise DataError("Got empty column value {}".format(col)) # pragma: no cover

Expand Down

0 comments on commit 4754b49

Please sign in to comment.