1
1
2
+ import dataclasses
2
3
import json
4
+ import typing
5
+
6
+
7
+ @dataclasses .dataclass (frozen = True )
8
+ class SQLState :
9
+ state : dict
10
+
11
+ def read_table_meta (self , table_name : str ) -> dict :
12
+ return self .state .get (table_name , {}).get ("metadata" , {})
13
+
14
+ def read_table_rows (self , table_name : str ) -> list [dict ]:
15
+ return self .state .get (table_name , {}).get ("rows" , [])
16
+
17
+ def read_information_schema (self ) -> list [dict ]:
18
+ return [data ["metadata" ] for data in self .state .values ()]
19
+
20
+ def write_table_meta (self , table_name : str , data : dict ):
21
+ state = self .state
22
+ table = state .get (table_name , {})
23
+ metadata = table .get ("metadata" , {})
24
+ metadata .update (data )
25
+ table ["metadata" ] = metadata
26
+ state [table_name ] = table
27
+ return self .__class__ (state )
28
+
29
+ def write_table_rows (self , table_name : str , data : dict ):
30
+ state = self .state
31
+ table = state .get (table_name , {})
32
+ rows = table .get ("rows" , [])
33
+ rows .append (data )
34
+ table ["rows" ] = rows
35
+ state [table_name ] = table
36
+ return self .__class__ (state )
37
+
38
+
39
+ class SQLType :
40
+ @staticmethod
41
+ def varchar (data ) -> str :
42
+ data_str = str (data ).strip ()
43
+ if data_str .startswith ("'" ) or data_str .startswith ('"' ):
44
+ data_str = data_str [1 :]
45
+ if data_str .endswith ("'" ) or data_str .endswith ('"' ):
46
+ data_str = data_str [:- 1 ]
47
+ return data_str
48
+
49
+ @staticmethod
50
+ def int (data ) -> int :
51
+ return int (data .strip ())
52
+
53
+
54
+ sql_type_map = {
55
+ "VARCHAR" : SQLType .varchar ,
56
+ "INT" : SQLType .int ,
57
+ }
58
+
59
+
60
+ class SQLFunctions :
61
+ @staticmethod
62
+ def create_table (state : SQLState , * args , table_schema = "public" ) -> typing .Tuple [list , SQLState ]:
63
+ output : list [dict ] = []
64
+ table_name = args [2 ]
3
65
66
+ # get columns
67
+ columns = {}
68
+ columns_str = " " .join (args [3 :]).replace ("(" , "" ).replace (")" , "" ).strip ()
69
+ if columns_str :
70
+ # fmt: off
71
+ columns = {
72
+ column .strip ().split (" " )[0 ]: column .strip ().split (" " )[1 ]
73
+ for column in columns_str .split ("," )
74
+ }
75
+ # fmt: on
4
76
5
- class SQL :
6
- data : dict = {}
7
-
8
- def __init__ (self ) -> None :
9
- self .data = {}
10
-
11
- def information_schema_tables (self ) -> list [dict ]:
12
- return [data ["metadata" ] for data in self .data .values ()]
13
-
14
- def create_table (self , * args , table_schema = "public" ) -> dict :
15
- table_name = args [2 ]
16
- if not self .data .get (table_name ):
17
- self .data [table_name ] = {
18
- "metadata" : {
77
+ if not state .read_table_meta (table_name ):
78
+ state = state .write_table_meta (
79
+ table_name ,
80
+ {
19
81
"table_name" : table_name ,
20
82
"table_schema" : table_schema ,
83
+ "colums" : columns ,
21
84
},
22
- }
23
- return {}
85
+ )
86
+ return (output , state )
87
+
88
+ @staticmethod
89
+ def insert_into (state : SQLState , * args ) -> typing .Tuple [list , SQLState ]:
90
+ output : list [dict ] = []
91
+ table_name = args [2 ]
92
+
93
+ values_index = None
94
+ for i , arg in enumerate (args ):
95
+ if arg == "VALUES" :
96
+ values_index = i
97
+ if values_index is None :
98
+ raise ValueError ("VALUES not found" )
99
+
100
+ keys = " " .join (args [3 :values_index ]).replace ("(" , "" ).replace (")" , "" ).split ("," )
101
+ keys = [key .strip () for key in keys ]
102
+ values = " " .join (args [values_index + 1 :]).replace ("(" , "" ).replace (")" , "" ).split ("," )
103
+ values = [value .strip () for value in values ]
104
+ key_value_map = dict (zip (keys , values ))
24
105
25
- create_table .sql = "CREATE TABLE"
106
+ data = {}
107
+ if metadata := state .read_table_meta (table_name ):
108
+ for key , value in key_value_map .items ():
109
+ data [key ] = sql_type_map [metadata ["colums" ][key ]](value )
110
+ state = state .write_table_rows (table_name , data )
26
111
27
- def select (self , * args ) -> dict :
28
- output = {}
112
+ return (output , state )
113
+
114
+ @staticmethod
115
+ def select (state : SQLState , * args ) -> typing .Tuple [list , SQLState ]:
116
+ output : list [dict ] = []
29
117
30
118
from_index = None
31
119
where_index = None
@@ -34,49 +122,59 @@ def select(self, *args) -> dict:
34
122
from_index = i
35
123
if arg == "WHERE" :
36
124
where_index = i
125
+ if from_index is None :
126
+ raise ValueError ("FROM not found" )
37
127
38
128
# get select keys by getting the slice of args before FROM
39
129
select_keys = " " .join (args [1 :from_index ]).split ("," )
130
+ select_keys = [key .strip () for key in select_keys ]
40
131
41
132
# get where keys by getting the slice of args after WHERE
42
133
from_value = args [from_index + 1 ]
43
134
44
- # consider "information_schema.tables" a special case until
45
- # we figure out why its so different from the others
135
+ # `information_schema.tables` is a special case
46
136
if from_value == "information_schema.tables" :
47
- target = self .information_schema_tables ()
48
-
49
- # fmt: off
50
- output = {
51
- key : [
52
- value for data in target
53
- for key , value in data .items ()
54
- if key in select_keys
55
- ]
56
- for key in select_keys
57
- }
58
- # fmt: on
59
-
60
- return output
61
-
62
- select .sql = "SELECT"
63
-
64
- sql_map = {
65
- create_table .sql : create_table ,
66
- select .sql : select ,
67
- }
68
-
69
- def run (self , input_sql : list [str ]) -> list [str ]:
70
- output = {}
71
-
72
- for line in input_sql :
73
- if not line .startswith ("--" ):
74
- words = line .split (" " )
75
- for i in reversed (range (len (words ))):
76
- key = " " .join (words [:i ])
77
- if func := self .sql_map .get (key ):
78
- output = func (self , * words )
79
- break
80
-
81
- return [json .dumps (output )]
137
+ data = state .read_information_schema ()
138
+ else :
139
+ data = state .read_table_rows (from_value )
140
+
141
+ output = []
142
+ for datum in data :
143
+ # fmt: off
144
+ output .append ({
145
+ key : datum .get (key )
146
+ for key in select_keys
147
+ })
148
+ # fmt: on
149
+
150
+ return (output , state )
151
+
152
+
153
+ sql_function_map : dict [str , typing .Callable ] = {
154
+ "CREATE TABLE" : SQLFunctions .create_table ,
155
+ "SELECT" : SQLFunctions .select ,
156
+ "INSERT INTO" : SQLFunctions .insert_into ,
157
+ }
158
+
159
+
160
+ def run_sql (input_sql : list [str ]) -> list [str ]:
161
+ output = []
162
+ state = SQLState (state = {})
163
+
164
+ # remove comments
165
+ input_sql = [line .strip () for line in input_sql if not line .startswith ("--" )]
166
+
167
+ # re-split on semi-colons
168
+ input_sql = " " .join (input_sql ).split (";" )
169
+
170
+ # iterate over each line of sql
171
+ for line in input_sql :
172
+ words = line .split (" " )
173
+ for i in reversed (range (len (words ) + 1 )):
174
+ key = " " .join (words [:i ]).strip ()
175
+ if func := sql_function_map .get (key ):
176
+ output , state = func (state , * [word for word in words if word ])
177
+ break
178
+
179
+ return [json .dumps (output )]
82
180
0 commit comments