Skip to content

Commit 565ef0f

Browse files
committed
test: add datatype.parse_sqltype testcases
Signed-off-by: Đặng Minh Dũng <[email protected]>
1 parent b8e0dcd commit 565ef0f

File tree

4 files changed

+191
-14
lines changed

4 files changed

+191
-14
lines changed

conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import tests.assertions # noqa

sqlalchemy_trino/datatype.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,17 @@
3131

3232
# === Date and time ===
3333
'date': sqltypes.DATE,
34-
'time': sqltypes.Time,
35-
'time with time zone': sqltypes.Time,
34+
'time': sqltypes.TIME,
3635
'timestamp': sqltypes.TIMESTAMP,
37-
'timestamp with time zone': sqltypes.TIMESTAMP,
38-
39-
# 'interval year to month': IntervalOfYear, # TODO
40-
'interval day to second': sqltypes.Interval,
4136

37+
# 'interval year to month':
38+
# 'interval day to second':
39+
#
4240
# === Structural ===
43-
'array': sqltypes.ARRAY,
44-
# 'map': MAP
45-
# 'row': ROW
46-
41+
# 'array': ARRAY,
42+
# 'map': MAP
43+
# 'row': ROW
44+
#
4745
# === Mixed ===
4846
# 'ipaddress': IPADDRESS
4947
# 'uuid': UUID,
@@ -53,13 +51,39 @@
5351
# 'tdigest': TDIGEST,
5452
}
5553

54+
SQLType = Union[TypeEngine, Type[TypeEngine]]
55+
5656

5757
class MAP(TypeEngine):
58-
pass
58+
__visit_name__ = "MAP"
59+
60+
def __init__(self, key_type: SQLType, value_type: SQLType):
61+
if isinstance(key_type, type):
62+
key_type = key_type()
63+
self.key_type: TypeEngine = key_type
64+
65+
if isinstance(value_type, type):
66+
value_type = value_type()
67+
self.value_type: TypeEngine = value_type
68+
69+
@property
70+
def python_type(self):
71+
return dict
5972

6073

6174
class ROW(TypeEngine):
62-
pass
75+
__visit_name__ = "ROW"
76+
77+
def __init__(self, attr_types: Dict[str, SQLType]):
78+
for name, attr_type in attr_types.items():
79+
if isinstance(attr_type, type):
80+
attr_type = attr_type()
81+
attr_types[name] = attr_type
82+
self.attr_types: Dict[str, TypeEngine] = attr_types
83+
84+
@property
85+
def python_type(self):
86+
return dict
6387

6488

6589
def split(string: str, delimiter: str = ',',
@@ -106,15 +130,22 @@ def parse_sqltype(type_str: str) -> TypeEngine:
106130

107131
if type_name == "array":
108132
item_type = parse_sqltype(type_opts)
133+
if isinstance(item_type, sqltypes.ARRAY):
134+
dimensions = (item_type.dimensions or 1) + 1
135+
return sqltypes.ARRAY(item_type.item_type, dimensions=dimensions)
109136
return sqltypes.ARRAY(item_type)
110137
elif type_name == "map":
111138
key_type_str, value_type_str = split(type_opts)
112139
key_type = parse_sqltype(key_type_str)
113140
value_type = parse_sqltype(value_type_str)
114141
return MAP(key_type, value_type)
115142
elif type_name == "row":
116-
attr_types = split(type_opts)
117-
return ROW() # TODO
143+
attr_types: Dict[str, SQLType] = {}
144+
for attr_str in split(type_opts):
145+
name, attr_type_str = split(attr_str.strip(), delimiter=' ')
146+
attr_type = parse_sqltype(attr_type_str)
147+
attr_types[name] = attr_type
148+
return ROW(attr_types)
118149

119150
if type_name not in _type_map:
120151
util.warn(f"Did not recognize type '{type_name}'")

tests/assertions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from assertpy import add_extension, assert_that
2+
from sqlalchemy.sql.sqltypes import ARRAY
3+
4+
from sqlalchemy_trino.datatype import SQLType, MAP, ROW
5+
6+
7+
def assert_sqltype(this: SQLType, that: SQLType):
8+
if isinstance(this, type):
9+
this = this()
10+
if isinstance(that, type):
11+
that = that()
12+
assert_that(type(this)).is_same_as(type(that))
13+
if isinstance(this, ARRAY):
14+
assert_sqltype(this.item_type, that.item_type)
15+
if this.dimensions is None or this.dimensions == 1:
16+
assert_that(that.dimensions).is_in(None, 1)
17+
else:
18+
assert_that(this.dimensions).is_equal_to(this.dimensions)
19+
elif isinstance(this, MAP):
20+
assert_sqltype(this.key_type, that.key_type)
21+
assert_sqltype(this.value_type, that.value_type)
22+
elif isinstance(this, ROW):
23+
assert_that(len(this.attr_types)).is_equal_to(len(that.attr_types))
24+
for name, this_attr in this.attr_types.items():
25+
that_attr = this.attr_types[name]
26+
assert_sqltype(this_attr, that_attr)
27+
else:
28+
assert_that(str(this)).is_equal_to(str(that))
29+
30+
31+
@add_extension
32+
def is_sqltype(self, that):
33+
this = self.val
34+
assert_sqltype(this, that)

tests/test_datatype_parse.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import pytest
2+
from assertpy import assert_that
3+
from sqlalchemy.sql.sqltypes import *
4+
from sqlalchemy.sql.type_api import TypeEngine
5+
6+
from sqlalchemy_trino import datatype
7+
from sqlalchemy_trino.datatype import MAP, ROW
8+
9+
10+
@pytest.mark.parametrize(
11+
'type_str, sql_type',
12+
datatype._type_map.items(),
13+
ids=datatype._type_map.keys()
14+
)
15+
def test_parse_simple_type(type_str: str, sql_type: TypeEngine):
16+
actual_type = datatype.parse_sqltype(type_str)
17+
if not isinstance(actual_type, type):
18+
actual_type = type(actual_type)
19+
assert_that(actual_type).is_equal_to(sql_type)
20+
21+
22+
parse_type_options_testcases = {
23+
'VARCHAR(10)': VARCHAR(10),
24+
'DECIMAL(20)': DECIMAL(20),
25+
'DECIMAL(20, 3)': DECIMAL(20, 3),
26+
}
27+
28+
29+
@pytest.mark.parametrize(
30+
'type_str, sql_type',
31+
parse_type_options_testcases.items(),
32+
ids=parse_type_options_testcases.keys()
33+
)
34+
def test_parse_type_options(type_str: str, sql_type: TypeEngine):
35+
actual_type = datatype.parse_sqltype(type_str)
36+
assert_that(actual_type).is_sqltype(sql_type)
37+
38+
39+
parse_array_testcases = {
40+
'array(integer)': ARRAY(INTEGER()),
41+
'array(varchar(10))': ARRAY(VARCHAR(10)),
42+
'array(decimal(20,3))': ARRAY(DECIMAL(20, 3)),
43+
'array(array(varchar(10)))': ARRAY(VARCHAR(10), dimensions=2),
44+
}
45+
46+
47+
@pytest.mark.parametrize(
48+
'type_str, sql_type',
49+
parse_array_testcases.items(),
50+
ids=parse_array_testcases.keys()
51+
)
52+
def test_parse_array(type_str: str, sql_type: ARRAY):
53+
actual_type = datatype.parse_sqltype(type_str)
54+
assert_that(actual_type).is_sqltype(sql_type)
55+
56+
57+
parse_map_testcases = {
58+
'map(char, integer)': MAP(CHAR(), INTEGER()),
59+
'map(varchar(10), varchar(10))': MAP(VARCHAR(10), VARCHAR(10)),
60+
'map(varchar(10), decimal(20,3))': MAP(VARCHAR(10), DECIMAL(20, 3)),
61+
'map(char, array(varchar(10)))': MAP(CHAR(), ARRAY(VARCHAR(10))),
62+
'map(varchar(10), array(varchar(10)))': MAP(VARCHAR(10), ARRAY(VARCHAR(10))),
63+
'map(varchar(10), array(array(varchar(10))))': MAP(VARCHAR(10), ARRAY(VARCHAR(10), dimensions=2)),
64+
}
65+
66+
67+
@pytest.mark.parametrize(
68+
'type_str, sql_type',
69+
parse_map_testcases.items(),
70+
ids=parse_map_testcases.keys()
71+
)
72+
def test_parse_map(type_str: str, sql_type: ARRAY):
73+
actual_type = datatype.parse_sqltype(type_str)
74+
assert_that(actual_type).is_sqltype(sql_type)
75+
76+
77+
parse_row_testcases = {
78+
'row(a integer, b varchar)': ROW(dict(a=INTEGER(), b=VARCHAR())),
79+
'row(a varchar(20), b decimal(20,3))': ROW(dict(a=VARCHAR(20), b=DECIMAL(20, 3))),
80+
'row(x array(varchar(10)), y array(array(varchar(10))), z decimal(20,3))':
81+
ROW(dict(x=ARRAY(VARCHAR(10)), y=ARRAY(VARCHAR(10), dimensions=2), z=DECIMAL(20, 3))),
82+
}
83+
84+
85+
@pytest.mark.parametrize(
86+
'type_str, sql_type',
87+
parse_row_testcases.items(),
88+
ids=parse_row_testcases.keys()
89+
)
90+
def test_parse_row(type_str: str, sql_type: ARRAY):
91+
actual_type = datatype.parse_sqltype(type_str)
92+
assert_that(actual_type).is_sqltype(sql_type)
93+
94+
95+
parse_datetime_testcases = {
96+
'date': DATE(),
97+
'time': TIME(),
98+
'time with time zone': TIME(timezone=True),
99+
'timestamp': TIMESTAMP(),
100+
'timestamp with time zone': TIMESTAMP(timezone=True),
101+
}
102+
103+
104+
@pytest.mark.parametrize(
105+
'type_str, sql_type',
106+
parse_datetime_testcases.items(),
107+
ids=parse_datetime_testcases.keys()
108+
)
109+
def test_parse_datetime(type_str: str, sql_type: ARRAY):
110+
actual_type = datatype.parse_sqltype(type_str)
111+
assert_that(actual_type).is_sqltype(sql_type)

0 commit comments

Comments
 (0)