Skip to content

Commit 18bee08

Browse files
committed
fix formatting
1 parent 31d3767 commit 18bee08

File tree

3 files changed

+34
-30
lines changed

3 files changed

+34
-30
lines changed

sample/classification_tree.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ SELECT
55
ELSE 2
66
END
77
END AS value
8-
FROM my_table
8+
FROM my_table;

sample/regression_tree.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ SELECT
2323
END
2424
ELSE 92
2525
END AS value
26-
FROM my_table
26+
FROM my_table;

tree_parser.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22

33
DEBUG = False
44

5+
def _debug(*args, **kwargs):
6+
if DEBUG:
7+
print(*args, **kwargs)
8+
9+
10+
def _print(*args, end='', **kwargs):
11+
if not DEBUG:
12+
print(*args, end=end, **kwargs)
13+
14+
515
def parse_tree(path, save_to='query.sql', column_name='value', src_table='my_table', tab=4):
616
OUT = sys.stdout
717
if not DEBUG:
@@ -15,8 +25,7 @@ def parse_tree(path, save_to='query.sql', column_name='value', src_table='my_tab
1525
node_type = 'split'
1626
stack = [0]
1727

18-
if not DEBUG:
19-
print('SELECT', end='')
28+
_print('SELECT')
2029

2130
for i, row in enumerate(rule):
2231
row = row.strip()
@@ -32,13 +41,11 @@ def parse_tree(path, save_to='query.sql', column_name='value', src_table='my_tab
3241
if depth < stack[-1]:
3342
stack.pop()
3443
while stack[-1] != depth:
35-
if not DEBUG:
36-
text = '\n' + ' ' * tab * stack[-1] + 'END'
37-
print(text, end='')
44+
text = f"\n{' ' * tab * stack[-1]}END"
45+
_print(text)
3846
stack.pop()
3947

40-
if DEBUG:
41-
print(row, stack)
48+
_debug(row, stack)
4249

4350
# infer node type
4451
node_type = 'leaf'
@@ -66,32 +73,29 @@ def parse_tree(path, save_to='query.sql', column_name='value', src_table='my_tab
6673
after = '' # handle cases to put END
6774
if i < len(rule) - 1:
6875
if rule[i + 1].count('|') <= stack[-2]:
69-
after = '\n' + ' ' * spacing * (depth - 1) + 'END'
76+
after = f"\n{' ' * spacing * (depth - 1)}END"
7077

7178
text = f" {text}{after}"
72-
73-
if not DEBUG:
74-
print(text, end='')
75-
else:
76-
end, text = '', ''
79+
_print(text)
80+
else: # split/internal node
81+
text = ''
7782
if else_flag:
78-
text = '\n' + ' ' * indent + 'ELSE'
79-
if not DEBUG:
80-
print(text, end='')
83+
text = f"\n{' ' * indent}ELSE"
84+
_print(text)
8185
else:
8286
start_idx = (spacing + 1) * depth + 1
8387
text = row[start_idx:]
84-
text = f"\n{' ' * indent}{end}CASE WHEN {text} THEN"
85-
if not DEBUG:
86-
print(text, end='')
88+
text = f"\n{' ' * indent}CASE WHEN {text} THEN"
89+
_print(text)
8790

88-
if DEBUG:
89-
print(stack)
90-
else:
91-
if stack[-1] > 1:
92-
text = '\n' + ' ' * spacing * stack[-1] + 'END'
93-
print(text, end='')
94-
95-
print(f"\n{' ' * spacing}END AS {column_name}")
96-
print(f"FROM {src_table}")
91+
_debug(stack)
92+
93+
while stack[-1] > 0:
94+
text = f"\n{' ' * spacing * stack[-1]}END"
95+
if stack[-1] == 1:
96+
text += f" AS {column_name}\nFROM {src_table};\n"
97+
98+
_print(text)
99+
stack.pop()
100+
97101
sys.stdout = OUT

0 commit comments

Comments
 (0)