1
+ """Trino integration tests.
2
+
3
+ These rely on having a Trino+Hadoop cluster set up.
4
+ They also require a tables created by make_test_tables.sh.
5
+ """
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import unicode_literals
9
+
10
+ import contextlib
11
+ import os
12
+ import requests
13
+
14
+ from pyhive import exc
15
+ from pyhive import trino
16
+ from pyhive .tests .dbapi_test_case import DBAPITestCase
17
+ from pyhive .tests .dbapi_test_case import with_cursor
18
+ from pyhive .tests .test_presto import TestPresto
19
+ import mock
20
+ import unittest
21
+ import datetime
22
+
23
+ _HOST = 'localhost'
24
+ _PORT = '18080'
25
+
26
+
27
+ class TestTrino (TestPresto ):
28
+ __test__ = True
29
+
30
+ def connect (self ):
31
+ return trino .connect (host = _HOST , port = _PORT , source = self .id ())
32
+
33
+ def test_bad_protocol (self ):
34
+ self .assertRaisesRegexp (ValueError , 'Protocol must be' ,
35
+ lambda : trino .connect ('localhost' , protocol = 'nonsense' ).cursor ())
36
+
37
+ def test_escape_args (self ):
38
+ escaper = trino .TrinoParamEscaper ()
39
+
40
+ self .assertEqual (escaper .escape_args ((datetime .date (2020 , 4 , 17 ),)),
41
+ ("date '2020-04-17'" ,))
42
+ self .assertEqual (escaper .escape_args ((datetime .datetime (2020 , 4 , 17 , 12 , 0 , 0 , 123456 ),)),
43
+ ("timestamp '2020-04-17 12:00:00.123'" ,))
44
+
45
+ @with_cursor
46
+ def test_description (self , cursor ):
47
+ cursor .execute ('SELECT 1 AS foobar FROM one_row' )
48
+ self .assertEqual (cursor .description , [('foobar' , 'integer' , None , None , None , None , True )])
49
+ self .assertIsNotNone (cursor .last_query_id )
50
+
51
+ @with_cursor
52
+ def test_complex (self , cursor ):
53
+ cursor .execute ('SELECT * FROM one_row_complex' )
54
+ # TODO Trino drops the union field
55
+
56
+ tinyint_type = 'tinyint'
57
+ smallint_type = 'smallint'
58
+ float_type = 'real'
59
+ self .assertEqual (cursor .description , [
60
+ ('boolean' , 'boolean' , None , None , None , None , True ),
61
+ ('tinyint' , tinyint_type , None , None , None , None , True ),
62
+ ('smallint' , smallint_type , None , None , None , None , True ),
63
+ ('int' , 'integer' , None , None , None , None , True ),
64
+ ('bigint' , 'bigint' , None , None , None , None , True ),
65
+ ('float' , float_type , None , None , None , None , True ),
66
+ ('double' , 'double' , None , None , None , None , True ),
67
+ ('string' , 'varchar' , None , None , None , None , True ),
68
+ ('timestamp' , 'timestamp' , None , None , None , None , True ),
69
+ ('binary' , 'varbinary' , None , None , None , None , True ),
70
+ ('array' , 'array(integer)' , None , None , None , None , True ),
71
+ ('map' , 'map(integer,integer)' , None , None , None , None , True ),
72
+ ('struct' , 'row(a integer,b integer)' , None , None , None , None , True ),
73
+ # ('union', 'varchar', None, None, None, None, True),
74
+ ('decimal' , 'decimal(10,1)' , None , None , None , None , True ),
75
+ ])
76
+ rows = cursor .fetchall ()
77
+ expected = [(
78
+ True ,
79
+ 127 ,
80
+ 32767 ,
81
+ 2147483647 ,
82
+ 9223372036854775807 ,
83
+ 0.5 ,
84
+ 0.25 ,
85
+ 'a string' ,
86
+ '1970-01-01 00:00:00.000' ,
87
+ b'123' ,
88
+ [1 , 2 ],
89
+ {"1" : 2 , "3" : 4 }, # Trino converts all keys to strings so that they're valid JSON
90
+ [1 , 2 ], # struct is returned as a list of elements
91
+ # '{0:1}',
92
+ '0.1' ,
93
+ )]
94
+ self .assertEqual (rows , expected )
95
+ # catch unicode/str
96
+ self .assertEqual (list (map (type , rows [0 ])), list (map (type , expected [0 ])))
0 commit comments