Correctly handle prefix data for dynamic types
genzgd committed Dec 21, 2024
1 parent 0998510 commit 88a2a52
Showing 13 changed files with 226 additions and 125 deletions.
10 changes: 8 additions & 2 deletions CHANGELOG.md
@@ -17,9 +17,15 @@ release (0.9.0), unrecognized arguments/keywords for these methods of creating a
 instead of being passed as ClickHouse server settings. This is in conjunction with some refactoring in Client construction.
 The supported method of passing ClickHouse server settings is to prefix such arguments/query parameters with `ch_`.
 
-## 0.8.11, 2024-12-17
+## 0.8.11, TBD
 ### Improvement
-- Support of ISO8601 strings for inserting values to columns with DateTime64 type was added.
+- Support of ISO8601 strings for inserting values to columns with DateTime64 type was added. If the driver detects
+  that the inserted data for a DateTime64 is a string, it will attempt to parse an ISO-8601 datetime from that string.
+  Other string formats are not currently supported. Thanks to [Nikita Reznikov](https://github.com/rnv812) for the PR!
+
+### Bug Fix
+- Correctly handled native format column prefixes for Variant/Dynamic/JSON. Fixes https://github.com/ClickHouse/clickhouse-connect/issues/441
+  and possibly some other issues with the experimental types Variant, Dynamic, and JSON.
 
 ## 0.8.10, 2024-12-14
 ### Bug Fixes
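
The 0.8.11 improvement above adds ISO-8601 string parsing for DateTime64 inserts. A minimal sketch of that behavior from the client side (the connection details and table name are illustrative assumptions, not part of this commit):

```python
import clickhouse_connect

# Assumes a reachable local ClickHouse server; adjust host/credentials as needed
client = clickhouse_connect.get_client(host='localhost')
client.command('CREATE TABLE IF NOT EXISTS dt64_demo (ts DateTime64(3)) '
               'ENGINE MergeTree ORDER BY ts')

# With this improvement, an ISO-8601 string should be parsed into a DateTime64
# value on insert instead of requiring a Python datetime object
client.insert('dt64_demo', [['2024-12-21T10:30:00.123']], column_names=['ts'])
print(client.query('SELECT ts FROM dt64_demo').result_rows)
```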
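The bug-fix entry concerns reading experimental types whose native-format column prefixes were previously mishandled. A hedged round-trip sketch (assumes a ClickHouse version where Dynamic is available behind the experimental setting; the table name is illustrative):

```python
import clickhouse_connect

client = clickhouse_connect.get_client(host='localhost')
settings = {'allow_experimental_dynamic_type': 1}
client.command('CREATE TABLE IF NOT EXISTS dyn_demo (v Dynamic) '
               'ENGINE MergeTree ORDER BY tuple()', settings=settings)
client.insert('dyn_demo', [[1], ['two'], [3.5]], column_names=['v'])

# Reading back exercises the column-prefix path this commit fixes; before the
# fix, an unconsumed prefix could corrupt the read (see issue #441)
print(client.query('SELECT v FROM dyn_demo').result_rows)
```
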
41 changes: 23 additions & 18 deletions clickhouse_connect/datatypes/base.py
@@ -126,16 +126,19 @@ def write_column_prefix(self, dest: bytearray):
         if self.low_card:
             write_uint64(low_card_version, dest)
 
-    def read_column_prefix(self, source: ByteSource, _ctx: QueryContext):
+    def read_column_prefix(self, source: ByteSource, _ctx: QueryContext) -> Any:
         """
         Read the low cardinality version. Like the write method, this has to happen immediately for container classes
         :param source: The native protocol binary read buffer
-        :return: updated read pointer
+        :param _ctx: The current query context
+        :return: any state data required by the read_column_data method
         """
         if self.low_card:
             v = source.read_uint64()
             if v != low_card_version:
                 logger.warning('Unexpected low cardinality version %d reading type %s', v, self.name)
+            return v
+        return None
 
     def read_column(self, source: ByteSource, num_rows: int, ctx: QueryContext) -> Sequence:
         """
@@ -144,30 +147,31 @@ def read_column(self, source: ByteSource, num_rows: int, ctx: QueryContext) -> Sequence:
         :param source: Native protocol binary read buffer
         :param num_rows: Number of rows expected in the column
         :param ctx: QueryContext for query specific settings
-        :return: The decoded column data as a sequence and the updated location pointer
+        :return: The decoded column data as a sequence
         """
-        self.read_column_prefix(source, ctx)
-        return self.read_column_data(source, num_rows, ctx)
+        read_state = self.read_column_prefix(source, ctx)
+        return self.read_column_data(source, num_rows, ctx, read_state)
 
-    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext) -> Sequence:
+    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any) -> Sequence:
         """
         Public read method for all ClickHouseType data type columns
         :param source: Native protocol binary read buffer
         :param num_rows: Number of rows expected in the column
         :param ctx: QueryContext for query specific settings
-        :return: The decoded column plus the updated location pointer
+        :param read_state: Any information returned by the read_column_prefix method
+        :return: The decoded column
         """
         if self.low_card:
-            column = self._read_low_card_column(source, num_rows, ctx)
+            column = self._read_low_card_column(source, num_rows, ctx, read_state)
         elif self.nullable:
-            column = self._read_nullable_column(source, num_rows, ctx)
+            column = self._read_nullable_column(source, num_rows, ctx, read_state)
         else:
-            column = self._read_column_binary(source, num_rows, ctx)
+            column = self._read_column_binary(source, num_rows, ctx, read_state)
         return self._finalize_column(column, ctx)
 
-    def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext) -> Sequence:
+    def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any) -> Sequence:
         null_map = source.read_bytes(num_rows)
-        column = self._read_column_binary(source, num_rows, ctx)
+        column = self._read_column_binary(source, num_rows, ctx, read_state)
         null_obj = self._active_null(ctx)
         return data_conv.build_nullable_column(column, null_map, null_obj)
 
@@ -177,7 +181,8 @@ def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext) -> Sequence:
     # pylint: disable=no-self-use
     def _read_column_binary(self,
                             _source: ByteSource,
-                            _num_rows: int, _ctx: QueryContext) -> Union[Sequence, MutableSequence]:
+                            _num_rows: int, _ctx: QueryContext,
+                            _read_state: Any) -> Union[Sequence, MutableSequence]:
         """
         Lowest level read method for ClickHouseType native data columns
         :param _source: Native protocol binary read buffer
@@ -224,13 +229,13 @@ def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext):
         self._write_column_binary(column, dest, ctx)
 
     # pylint: disable=no-member
-    def _read_low_card_column(self, source: ByteSource, num_rows: int, ctx: QueryContext):
+    def _read_low_card_column(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any):
         if num_rows == 0:
             return []
         key_data = source.read_uint64()
         key_sz = 2 ** (key_data & 0xff)
         index_cnt = source.read_uint64()
-        index = self._read_column_binary(source, index_cnt, ctx)
+        index = self._read_column_binary(source, index_cnt, ctx, read_state)
         key_cnt = source.read_uint64()
         keys = source.read_array(array_type(key_sz, False), key_cnt)
         if self.nullable:
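
For reference, the low-cardinality read path above decodes a dictionary ("index") of distinct values plus one integer key per row. A toy decode with assumed values, independent of the library's ByteSource machinery:

```python
# Toy data, not read from a real native stream; for a nullable column the
# dictionary reserves slot 0 for null
index = [None, 'low', 'medium', 'high']
keys = [1, 3, 3, 2, 0]
column = [index[k] for k in keys]
print(column)  # ['low', 'high', 'high', 'medium', None]
```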
@@ -313,12 +318,12 @@ def __init_subclass__(cls, registered: bool = True):
         cls._struct_type = '<' + cls._array_type
         cls.byte_size = array.array(cls._array_type).itemsize
 
-    def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext):
+    def _read_column_binary(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any):
         if ctx.use_numpy:
             return numpy_conv.read_numpy_array(source, self.np_type, num_rows)
         return source.read_array(self._array_type, num_rows)
 
-    def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext) -> Sequence:
+    def _read_nullable_column(self, source: ByteSource, num_rows: int, ctx: QueryContext, _read_state: Any) -> Sequence:
         return data_conv.read_nullable_array(source, self._array_type, num_rows, self._active_null(ctx))
 
     def _build_lc_column(self, index: Sequence, keys: array.array, ctx: QueryContext):
@@ -357,7 +362,7 @@ def __init__(self, type_def: TypeDef):
         super().__init__(type_def)
         self._name_suffix = type_def.arg_str
 
-    def _read_column_binary(self, source: Sequence, num_rows: int, ctx: QueryContext):
+    def _read_column_binary(self, source: Sequence, num_rows: int, ctx: QueryContext, read_state: Any):
         raise NotSupportedError(f'{self.name} deserialization not supported')
 
     def _write_column_binary(self, column: Union[Sequence, MutableSequence], dest: bytearray, ctx: InsertContext):
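
Taken together, the base.py changes thread whatever read_column_prefix returns into read_column_data instead of discarding it. In the native format all column prefixes arrive before any row data, so a container must collect one state per element up front and pass each state to the matching data read. A simplified, self-contained sketch of that pattern (toy names throughout, not the clickhouse-connect API):

```python
import struct
from io import BytesIO
from typing import Any, List


def read_uint64(buf: BytesIO) -> int:
    return struct.unpack('<Q', buf.read(8))[0]


class ToyElement:
    def read_prefix(self, buf: BytesIO) -> Any:
        # e.g. a serialization version written once, before any row data
        return read_uint64(buf)

    def read_data(self, buf: BytesIO, num_rows: int, state: Any) -> List[int]:
        assert state == 1, f'unexpected serialization version {state}'
        return [read_uint64(buf) for _ in range(num_rows)]


class ToyTuple:
    def __init__(self, elements: List[ToyElement]):
        self.elements = elements

    def read_column(self, buf: BytesIO, num_rows: int) -> List[List[int]]:
        # Collect one prefix state per element first, then thread each state
        # into its element's data read: the shape of this commit's fix
        states = [e.read_prefix(buf) for e in self.elements]
        return [e.read_data(buf, num_rows, s)
                for e, s in zip(self.elements, states)]


# Two element prefixes (version 1 each), then two rows of data per element
buf = BytesIO(struct.pack('<6Q', 1, 1, 10, 11, 20, 21))
print(ToyTuple([ToyElement(), ToyElement()]).read_column(buf, 2))
# [[10, 11], [20, 21]]
```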
64 changes: 42 additions & 22 deletions clickhouse_connect/datatypes/container.py
@@ -1,6 +1,6 @@
 import array
 import logging
-from typing import Sequence, Collection
+from typing import Sequence, Collection, Any
 
 from clickhouse_connect.driver.insert import InsertContext
 from clickhouse_connect.driver.query import QueryContext
@@ -15,15 +15,20 @@
 
 
 class Array(ClickHouseType):
-    __slots__ = ('element_type',)
+    __slots__ = ('element_type', '_insert_name')
     python_type = list
 
+    @property
+    def insert_name(self):
+        return self._insert_name
+
     def __init__(self, type_def: TypeDef):
         super().__init__(type_def)
         self.element_type = get_from_name(type_def.values[0])
         self._name_suffix = f'({self.element_type.name})'
+        self._insert_name = f'Array({self.element_type.insert_name})'
 
-    def read_column_prefix(self, source: ByteSource, ctx:QueryContext):
+    def read_column_prefix(self, source: ByteSource, ctx: QueryContext):
         return self.element_type.read_column_prefix(source, ctx)
 
     def _data_size(self, sample: Sequence) -> int:
@@ -35,7 +40,7 @@ def _data_size(self, sample: Sequence) -> int:
         return total // len(sample) + 8
 
     # pylint: disable=too-many-locals
-    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext):
+    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any):
         final_type = self.element_type
         depth = 1
         while isinstance(final_type, Array):
@@ -48,7 +53,7 @@ def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext):
             offset_sizes.append(level_offsets)
             level_size = level_offsets[-1] if level_offsets else 0
         if level_size:
-            all_values = final_type.read_column_data(source, level_size, ctx)
+            all_values = final_type.read_column_data(source, level_size, ctx, read_state)
         else:
             all_values = []
         column = all_values if isinstance(all_values, list) else list(all_values)
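
The Array read path above turns per-level offset arrays plus a flat value buffer into nested lists. A worked toy example with assumed data for a single nesting level:

```python
# Offsets are cumulative end positions in the flat buffer, one per row
offsets = [2, 3, 5]
values = [10, 11, 12, 13, 14]

column, last = [], 0
for end in offsets:
    column.append(values[last:end])
    last = end
print(column)  # [[10, 11], [12], [13, 14]]
```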
@@ -86,18 +91,28 @@ def write_column_data(self, column: Sequence, dest: bytearray, ctx: InsertContext):
 
 
 class Tuple(ClickHouseType):
-    _slots = 'element_names', 'element_types'
+    _slots = 'element_names', 'element_types', '_insert_name'
     python_type = tuple
     valid_formats = 'tuple', 'dict', 'json', 'native'  # native is 'tuple' for unnamed tuples, and dict for named tuples
 
+    @property
+    def insert_name(self):
+        return self._insert_name
+
     def __init__(self, type_def: TypeDef):
         super().__init__(type_def)
         self.element_names = type_def.keys
         self.element_types = [get_from_name(name) for name in type_def.values]
         if self.element_names:
             self._name_suffix = f"({', '.join(quote_identifier(k) + ' ' + str(v) for k, v in zip(type_def.keys, type_def.values))})"
         else:
             self._name_suffix = type_def.arg_str
+        if self.element_names:
+            self._insert_name = f"Tuple({', '.join(quote_identifier(k) + ' ' + v.insert_name for k, v in zip(type_def.keys, self.element_types))})"
+        else:
+            self._insert_name = f"Tuple({', '.join(v.insert_name for v in self.element_types)})"
 
     def _data_size(self, sample: Collection) -> int:
         if len(sample) == 0:
@@ -114,14 +129,13 @@ def _data_size(self, sample: Collection) -> int:
         return elem_size
 
     def read_column_prefix(self, source: ByteSource, ctx: QueryContext):
-        for e_type in self.element_types:
-            e_type.read_column_prefix(source, ctx)
+        return [e_type.read_column_prefix(source, ctx) for e_type in self.element_types]
 
-    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext):
+    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any):
         columns = []
         e_names = self.element_names
-        for e_type in self.element_types:
-            column = e_type.read_column_data(source, num_rows, ctx)
+        for ix, e_type in enumerate(self.element_types):
+            column = e_type.read_column_data(source, num_rows, ctx, read_state[ix])
             columns.append(column)
         if e_names and self.read_format(ctx) != 'tuple':
             dicts = [{} for _ in range(num_rows)]
@@ -156,14 +170,19 @@ def convert_dict_insert(self, column: Sequence) -> Sequence:
 
 
 class Map(ClickHouseType):
-    _slots = 'key_type', 'value_type'
+    _slots = 'key_type', 'value_type', '_insert_name'
     python_type = dict
 
+    @property
+    def insert_name(self):
+        return self._insert_name
+
     def __init__(self, type_def: TypeDef):
         super().__init__(type_def)
         self.key_type = get_from_name(type_def.values[0])
         self.value_type = get_from_name(type_def.values[1])
         self._name_suffix = type_def.arg_str
+        self._insert_name = f'Map({self.key_type.insert_name}, {self.value_type.insert_name})'
 
     def _data_size(self, sample: Collection) -> int:
         total = 0
@@ -175,15 +194,16 @@ def _data_size(self, sample: Collection) -> int:
         return total // len(sample)
 
     def read_column_prefix(self, source: ByteSource, ctx: QueryContext):
-        self.key_type.read_column_prefix(source, ctx)
-        self.value_type.read_column_prefix(source, ctx)
+        key_state = self.key_type.read_column_prefix(source, ctx)
+        value_state = self.value_type.read_column_prefix(source, ctx)
+        return key_state, value_state
 
     # pylint: disable=too-many-locals
-    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext):
+    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any):
         offsets = source.read_array('Q', num_rows)
         total_rows = 0 if len(offsets) == 0 else offsets[-1]
-        keys = self.key_type.read_column_data(source, total_rows, ctx)
-        values = self.value_type.read_column_data(source, total_rows, ctx)
+        keys = self.key_type.read_column_data(source, total_rows, ctx, read_state[0])
+        values = self.value_type.read_column_data(source, total_rows, ctx, read_state[1])
         all_pairs = tuple(zip(keys, values))
         column = []
         app = column.append
@@ -231,12 +251,12 @@ def _data_size(self, sample: Collection) -> int:
         array_sample = [[tuple(sub_row[key] for key in keys) for sub_row in row] for row in sample]
         return self.tuple_array.data_size(array_sample)
 
-    def read_column_prefix(self, source: ByteSource, ctx:QueryContext):
-        self.tuple_array.read_column_prefix(source, ctx)
+    def read_column_prefix(self, source: ByteSource, ctx: QueryContext):
+        return self.tuple_array.read_column_prefix(source, ctx)
 
-    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext):
+    def read_column_data(self, source: ByteSource, num_rows: int, ctx: QueryContext, read_state: Any):
         keys = self.element_names
-        data = self.tuple_array.read_column_data(source, num_rows, ctx)
+        data = self.tuple_array.read_column_data(source, num_rows, ctx, read_state)
         return [[dict(zip(keys, x)) for x in row] for row in data]
 
     def write_column_prefix(self, dest: bytearray):
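
The Map read path follows the same offsets scheme as Array, zipping flat key and value columns into one dict per row. A toy decode with assumed data:

```python
offsets = [2, 3]         # row 0 holds two key/value pairs, row 1 holds one
keys = ['a', 'b', 'c']   # flat key column across all rows
values = [1, 2, 3]       # flat value column across all rows

all_pairs = tuple(zip(keys, values))
column, last = [], 0
for end in offsets:
    column.append(dict(all_pairs[last:end]))
    last = end
print(column)  # [{'a': 1, 'b': 2}, {'c': 3}]
```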