Skip to content

Commit 44ca31b

Browse files
authored
SQL INSERT INTO in python (#88)
1 parent e8bec6e commit 44ca31b

File tree

9 files changed

+355
-113
lines changed

9 files changed

+355
-113
lines changed

data/sql_input_2.sql

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html
2+
-- https://www.postgresql.org/docs/16/sql-createtable.html
3+
-- https://www.postgresql.org/docs/16/sql-insert.html
4+
-- https://www.postgresql.org/docs/16/sql-select.html
5+
CREATE TABLE city (
6+
name VARCHAR,
7+
population INT,
8+
timezone INT
9+
);
10+
11+
INSERT INTO city (name, population, timezone)
12+
VALUES ('San Francisco', 852469, -8);
13+
14+
INSERT INTO city (name, population, timezone)
15+
VALUES ('New York', 8405837, -5);
16+
17+
SELECT
18+
name,
19+
population,
20+
timezone
21+
FROM city;

data/sql_input_3.sql

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html
2+
-- https://www.postgresql.org/docs/16/sql-createtable.html
3+
-- https://www.postgresql.org/docs/16/sql-insert.html
4+
-- https://www.postgresql.org/docs/16/sql-select.html
5+
CREATE TABLE city (
6+
name VARCHAR,
7+
population INT,
8+
timezone INT
9+
);
10+
11+
INSERT INTO city (name, timezone)
12+
VALUES ('San Francisco', -8);
13+
14+
INSERT INTO city (name, population)
15+
VALUES ('New York', 8405837);
16+
17+
SELECT
18+
name,
19+
population,
20+
timezone
21+
FROM city;

data/sql_output_0.json

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
{
2-
"table_name": ["city"]
3-
}
1+
[{ "table_name": "city" }]

data/sql_output_1.json

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
{
2-
"table_name": ["city", "town"]
3-
}
1+
[{ "table_name": "city" }, { "table_name": "town" }]

data/sql_output_2.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[
2+
{ "name": "San Francisco", "population": 852469, "timezone": -8 },
3+
{ "name": "New York", "population": 8405837, "timezone": -5 }
4+
]

data/sql_output_3.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[
2+
{ "name": "San Francisco", "population": null, "timezone": -8 },
3+
{ "name": "New York", "population": 8405837, "timezone": null }
4+
]

snippets/python/sql_test.py

Lines changed: 154 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,119 @@
11

2+
import dataclasses
23
import json
4+
import typing
5+
6+
7+
@dataclasses.dataclass(frozen=True)
class SQLState:
    """Immutable snapshot of the database.

    ``state`` maps each table name to a dict that may hold a ``"metadata"``
    dict and a ``"rows"`` list. The ``write_*`` methods never modify the
    snapshot they are called on; they return a new ``SQLState`` built from
    copies of the containers that change.
    """

    state: dict

    def read_table_meta(self, table_name: str) -> dict:
        """Return the metadata dict of ``table_name`` ({} if unknown)."""
        return self.state.get(table_name, {}).get("metadata", {})

    def read_table_rows(self, table_name: str) -> list[dict]:
        """Return the rows of ``table_name`` ([] if unknown or empty)."""
        return self.state.get(table_name, {}).get("rows", [])

    def read_information_schema(self) -> list[dict]:
        """Return the metadata of every table that has metadata."""
        # A table created only through write_table_rows has no "metadata"
        # entry; skip it instead of raising KeyError.
        return [
            table["metadata"]
            for table in self.state.values()
            if "metadata" in table
        ]

    def write_table_meta(self, table_name: str, data: dict) -> "SQLState":
        """Return a new state with ``data`` merged into the table metadata."""
        # Copy each container on the path to the change so that this
        # (frozen) snapshot is left untouched.
        table = dict(self.state.get(table_name, {}))
        table["metadata"] = {**table.get("metadata", {}), **data}
        return self.__class__({**self.state, table_name: table})

    def write_table_rows(self, table_name: str, data: dict) -> "SQLState":
        """Return a new state with the row ``data`` appended to the table."""
        table = dict(self.state.get(table_name, {}))
        # Build a fresh list rather than appending to the shared one.
        table["rows"] = [*table.get("rows", []), data]
        return self.__class__({**self.state, table_name: table})
37+
38+
39+
class SQLType:
    """Converters from raw SQL literal text to Python values."""

    @staticmethod
    def varchar(data) -> str:
        """Return ``data`` as a string without its surrounding quotes.

        At most one leading and one trailing quote character (single or
        double) is removed, after trimming outer whitespace.
        """
        data_str = str(data).strip()
        if data_str.startswith(("'", '"')):
            data_str = data_str[1:]
        if data_str.endswith(("'", '"')):
            data_str = data_str[:-1]
        return data_str

    @staticmethod
    def int(data) -> int:
        """Return ``data`` as an integer.

        Strings are trimmed first; other values are passed to the builtin
        ``int`` directly, so already-numeric input is accepted just as
        ``varchar`` accepts non-string input.

        Raises:
            ValueError: if the text is not a valid integer literal.
        """
        # Inside a method body ``int`` is the builtin: class-level names
        # are not visible from function scope.
        if isinstance(data, str):
            data = data.strip()
        return int(data)
52+
53+
54+
# Lookup from a declared SQL column type name to the converter that turns
# the literal text of a value into the matching Python value.
sql_type_map = dict(
    VARCHAR=SQLType.varchar,
    INT=SQLType.int,
)
58+
59+
60+
class SQLFunctions:
61+
@staticmethod
def create_table(state: SQLState, *args, table_schema="public") -> typing.Tuple[list, SQLState]:
    """Handle ``CREATE TABLE <name> (<column> <type>, ...)``.

    Args:
        state: current database snapshot.
        *args: whitespace-separated words of the statement; ``args[2]`` is
            the table name and the remaining words form the column list.
        table_schema: schema name recorded in the table metadata.

    Returns:
        ``(output, state)`` where ``output`` is always an empty list and
        ``state`` includes the new table. An existing table is left as is.

    Raises:
        ValueError: if a column is declared without a type.
    """
    output: list[dict] = []
    table_name = args[2]

    # Map column name -> declared type name, in declaration order.
    columns = {}
    columns_str = " ".join(args[3:]).replace("(", "").replace(")", "").strip()
    for column in columns_str.split(","):
        parts = column.split()
        if not parts:
            # Empty column list or trailing comma: nothing to record.
            continue
        if len(parts) < 2:
            raise ValueError(f"column {parts[0]!r} has no type")
        columns[parts[0]] = parts[1]

    if not state.read_table_meta(table_name):
        state = state.write_table_meta(
            table_name,
            {
                "table_name": table_name,
                "table_schema": table_schema,
                # NOTE: the misspelt key is what insert_into looks up;
                # rename both together or not at all.
                "colums": columns,
            },
        )
    return (output, state)
87+
88+
@staticmethod
def insert_into(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
    """Handle ``INSERT INTO <name> (<columns>) VALUES (<values>)``.

    Args:
        state: current database snapshot.
        *args: whitespace-separated words of the statement; ``args[2]`` is
            the table name.

    Returns:
        ``(output, state)`` where ``output`` is always an empty list and
        ``state`` contains the new row. If the table has no metadata the
        state is returned unchanged.

    Raises:
        ValueError: if ``VALUES`` is missing, if the number of columns and
            values differ, or if a column is not part of the table.
    """
    output: list[dict] = []
    table_name = args[2]

    # The first VALUES keyword separates columns from values; any later
    # occurrence can only be a word inside a string literal.
    if "VALUES" not in args:
        raise ValueError("VALUES not found")
    values_index = args.index("VALUES")

    def split_items(tokens) -> list:
        """Split a parenthesised comma-separated list, respecting quotes."""
        items = []
        current = []
        quote = None
        for char in " ".join(tokens):
            if quote is not None:
                # Inside a string literal everything is kept verbatim.
                current.append(char)
                if char == quote:
                    quote = None
            elif char in ("'", '"'):
                quote = char
                current.append(char)
            elif char == ",":
                items.append("".join(current).strip())
                current = []
            elif char not in "()":
                current.append(char)
        items.append("".join(current).strip())
        return items

    keys = split_items(args[3:values_index])
    values = split_items(args[values_index + 1 :])
    if len(keys) != len(values):
        # zip() would silently drop the surplus columns or values.
        raise ValueError(
            f"INSERT has {len(keys)} columns but {len(values)} values"
        )

    data = {}
    if metadata := state.read_table_meta(table_name):
        column_types = metadata["colums"]
        for key, value in zip(keys, values):
            if key not in column_types:
                raise ValueError(
                    f"unknown column {key!r} in table {table_name!r}"
                )
            # Convert the literal text using the declared column type.
            data[key] = sql_type_map[column_types[key]](value)
        state = state.write_table_rows(table_name, data)

    return (output, state)
113+
114+
@staticmethod
115+
def select(state: SQLState, *args) -> typing.Tuple[list, SQLState]:
116+
output: list[dict] = []
29117

30118
from_index = None
31119
where_index = None
@@ -34,49 +122,59 @@ def select(self, *args) -> dict:
34122
from_index = i
35123
if arg == "WHERE":
36124
where_index = i
125+
if from_index is None:
126+
raise ValueError("FROM not found")
37127

38128
# get select keys by getting the slice of args before FROM
39129
select_keys = " ".join(args[1:from_index]).split(",")
130+
select_keys = [key.strip() for key in select_keys]
40131

41132
# get where keys by getting the slice of args after WHERE
42133
from_value = args[from_index + 1]
43134

44-
# consider "information_schema.tables" a special case until
45-
# we figure out why its so different from the others
135+
# `information_schema.tables` is a special case
46136
if from_value == "information_schema.tables":
47-
target = self.information_schema_tables()
48-
49-
# fmt: off
50-
output = {
51-
key: [
52-
value for data in target
53-
for key, value in data.items()
54-
if key in select_keys
55-
]
56-
for key in select_keys
57-
}
58-
# fmt: on
59-
60-
return output
61-
62-
select.sql = "SELECT"
63-
64-
sql_map = {
65-
create_table.sql: create_table,
66-
select.sql: select,
67-
}
68-
69-
def run(self, input_sql: list[str]) -> list[str]:
70-
output = {}
71-
72-
for line in input_sql:
73-
if not line.startswith("--"):
74-
words = line.split(" ")
75-
for i in reversed(range(len(words))):
76-
key = " ".join(words[:i])
77-
if func := self.sql_map.get(key):
78-
output = func(self, *words)
79-
break
80-
81-
return [json.dumps(output)]
137+
data = state.read_information_schema()
138+
else:
139+
data = state.read_table_rows(from_value)
140+
141+
output = []
142+
for datum in data:
143+
# fmt: off
144+
output.append({
145+
key: datum.get(key)
146+
for key in select_keys
147+
})
148+
# fmt: on
149+
150+
return (output, state)
151+
152+
153+
# Dispatch table: the leading keyword(s) of a statement select the handler
# that executes it.
sql_function_map: dict[str, typing.Callable] = dict(
    [
        ("CREATE TABLE", SQLFunctions.create_table),
        ("SELECT", SQLFunctions.select),
        ("INSERT INTO", SQLFunctions.insert_into),
    ]
)
158+
159+
160+
def run_sql(input_sql: list[str]) -> list[str]:
    """Execute a SQL script and return the last statement's output.

    Args:
        input_sql: the script, one source line per element.

    Returns:
        A single-element list holding the JSON encoding of the output of
        the last statement that was executed (``"[]"`` if none ran).
    """
    output = []
    state = SQLState(state={})

    # Strip before testing for "--" so indented comment lines are dropped
    # too instead of being executed as part of a statement.
    stripped_lines = (line.strip() for line in input_sql)
    script = " ".join(line for line in stripped_lines if not line.startswith("--"))

    # Statements are separated by semi-colons, not by line breaks.
    for statement in script.split(";"):
        # split() with no argument collapses runs of spaces and tabs, so
        # "CREATE   TABLE" still matches its keyword.
        words = statement.split()
        # Try the longest prefix first so multi-word keywords are found.
        for end in range(len(words), 0, -1):
            func = sql_function_map.get(" ".join(words[:end]))
            if func:
                output, state = func(state, *words)
                break

    return [json.dumps(output)]
82180

0 commit comments

Comments
 (0)