Skip to content

Commit 6e23527

Browse files
Merge pull request #120 from RohanArepally/master
Fix bug with querying rows that are nested containers
2 parents 81a9253 + 838d614 commit 6e23527

File tree

3 files changed

+28
-11
lines changed

3 files changed

+28
-11
lines changed

aiochclient/_types.pyx

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ cdef class StrType:
174174
self.container = container
175175

176176
cdef str _convert(self, str string):
177+
string = decode(string.encode())
177178
if self.container:
178179
return remove_single_quotes(string)
179180
return string
@@ -524,7 +525,7 @@ cdef class TupleType:
524525

525526
cdef tuple _convert(self, str string):
526527
return tuple(
527-
tp(decode(val.encode()))
528+
tp(val)
528529
for tp, val in zip(self.types, seq_parser(string[1:-1]))
529530
)
530531

@@ -554,7 +555,7 @@ cdef class MapType:
554555
cdef dict _convert(self, str string):
555556
key, value = string[1:-1].split(':', 1)
556557
return {
557-
self.key_type.p_type(decode(key.encode())): self.value_type.p_type(decode(value.encode()))
558+
self.key_type.p_type(key): self.value_type.p_type(value)
558559
}
559560

560561
cpdef dict p_type(self, string):
@@ -579,7 +580,7 @@ cdef class ArrayType:
579580
)
580581

581582
cdef list _convert(self, str string):
582-
return [self.type.p_type(decode(val.encode())) for val in seq_parser(string[1:-1])]
583+
return [self.type.p_type(val) for val in seq_parser(string[1:-1])]
583584

584585
cpdef list p_type(self, str string):
585586
return self._convert(string)
@@ -611,7 +612,7 @@ cdef class NestedType:
611612
for val in seq_parser(string[1:-1]):
612613
temp = []
613614
for tp, elem in zip(self.types, seq_parser(val.strip("()"))):
614-
temp.append(tp.p_type(decode(elem.encode())))
615+
temp.append(tp.p_type(elem))
615616
result.append(tuple(temp))
616617
return result
617618

aiochclient/types.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def unconvert(value) -> bytes:
147147

148148
class StrType(BaseType):
149149
def p_type(self, string: str) -> str:
150+
string = self.decode(string.encode())
150151
if self.container:
151152
return remove_single_quotes(string)
152153
return string
@@ -299,7 +300,7 @@ def __init__(self, name: str, **kwargs):
299300

300301
def p_type(self, string: str) -> tuple:
301302
return tuple(
302-
tp.p_type(self.decode(val.encode()))
303+
tp.p_type(val)
303304
for tp, val in zip(self.types, self.seq_parser(string.strip("()")))
304305
)
305306

@@ -324,9 +325,8 @@ def __init__(self, name: str, **kwargs):
324325
def p_type(self, string: str) -> dict:
325326
key, value = string[1:-1].split(':', 1)
326327
return {
327-
self.key_type.p_type(self.decode(key.encode())): self.value_type.p_type(
328-
self.decode(value.encode())
329-
)
328+
self.key_type.p_type(key): self.value_type.p_type(value)
329+
330330
}
331331

332332
def convert(self, value: bytes) -> dict:
@@ -350,7 +350,7 @@ def __init__(self, name: str, **kwargs):
350350

351351
def p_type(self, string: str) -> list:
352352
return [
353-
self.type.p_type(self.decode(val.encode()))
353+
self.type.p_type(val)
354354
for val in self.seq_parser(string[1:-1])
355355
]
356356

@@ -375,7 +375,7 @@ def __init__(self, name: str, **kwargs):
375375
def p_type(self, string: str) -> List[tuple]:
376376
return [
377377
tuple(
378-
tp.p_type(self.decode(elem.encode()))
378+
tp.p_type(elem)
379379
for tp, elem in zip(self.types, self.seq_parser(val.strip("()")))
380380
)
381381
for val in self.seq_parser(string[1:-1])

tests.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def rows(uuid):
4949
["hello", "world"],
5050
["hello", "world"],
5151
["hello", None],
52+
[("hello\'", 3, "hello")],
5253
"'\b\f\r\n\t\\",
5354
uuid,
5455
[uuid, uuid, uuid],
@@ -105,6 +106,7 @@ def rows(uuid):
105106
[],
106107
[],
107108
[],
109+
[],
108110
"'\b\f\r\n\t\\",
109111
None,
110112
[],
@@ -198,6 +200,7 @@ async def all_types_db(chclient, rows):
198200
array_string Array(String),
199201
array_low_cardinality_string Array(LowCardinality(String)),
200202
array_nullable_string Array(Nullable(String)),
203+
array_tuple Array(Tuple(String, UInt8, String)),
201204
escape_string String,
202205
uuid Nullable(UUID),
203206
array_uuid Array(UUID),
@@ -258,7 +261,7 @@ async def all_types_db(chclient, rows):
258261
def class_chclient(chclient, all_types_db, rows, request):
259262
request.cls.ch = chclient
260263
cls_rows = rows
261-
cls_rows[1][44] = dt.datetime(
264+
cls_rows[1][45] = dt.datetime(
262265
2019, 1, 1, 3, 0
263266
) # DateTime64 always returns datetime type
264267
request.cls.rows = [tuple(r) for r in cls_rows]
@@ -676,6 +679,19 @@ async def test_array_string(self):
676679
record = await self.select_record_bytes("array_string")
677680
assert record[0] == result
678681
assert record["array_string"] == result
682+
683+
async def test_array_tuple(self):
684+
result = [("hello'", 3, "hello")]
685+
assert await self.select_field("array_tuple") == result
686+
record = await self.select_record("array_tuple")
687+
assert record[0] == result
688+
assert record["array_tuple"] == result
689+
690+
result = b"[('hello\\'',3,'hello')]"
691+
assert await self.select_field_bytes("array_tuple") == result
692+
record = await self.select_record_bytes("array_tuple")
693+
assert record[0] == result
694+
assert record["array_tuple"] == result
679695

680696
async def test_array_low_cardinality_string(self):
681697
result = ["hello", "world"]

0 commit comments

Comments
 (0)