Skip to content

Commit 8cbd762

Browse files
committed
refactor
1 parent 0c349da commit 8cbd762

File tree

1 file changed

+78
-56
lines changed

1 file changed

+78
-56
lines changed

tree_parser.py

Lines changed: 78 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -32,70 +32,92 @@ def parse_tree(path, save_to='query.sql', column_name='value', src_table='my_tab
3232

3333
depth = row.count('|')
3434
indent = depth * tab
35-
else_flag = False
36-
37-
if depth > stack[-1]:
38-
stack.append(depth)
39-
else:
40-
else_flag = True
41-
if depth < stack[-1]:
42-
stack.pop()
43-
while stack[-1] != depth:
44-
text = f"\n{' ' * tab * stack[-1]}END"
45-
_print(text)
46-
stack.pop()
4735

36+
else_flag = _handle_else_flag(tab, stack, depth)
4837
_debug(row, stack)
49-
50-
# infer node type
51-
node_type = 'leaf'
52-
if (row.count('<') > 0) | (row.count('>') > 0):
53-
node_type = 'split'
54-
55-
text = ''
38+
39+
node_type = _get_node_type(row)
5640
if node_type == 'leaf':
57-
stack.pop()
58-
59-
# infer task
60-
if task_type == '':
61-
if row.find('class') != -1:
62-
task_type = 'classification'
63-
else:
64-
task_type = 'regression'
65-
66-
# infer value
67-
start_idx = row.index(':') + 1
68-
if task_type == 'regression':
69-
text = row[start_idx + 2:-1]
70-
else:
71-
text = row[start_idx + 1:]
72-
73-
after = '' # handle cases to put END
74-
if i < len(rule) - 1:
75-
if rule[i + 1].count('|') <= stack[-2]:
76-
after = f"\n{' ' * tab * (depth - 1)}END"
77-
78-
text = f" {text}{after}"
79-
_print(text)
80-
else: # split/internal node
81-
text = ''
82-
if else_flag:
83-
text = f"\n{' ' * indent}ELSE"
84-
_print(text)
85-
else:
86-
start_idx = depth * spacing + depth + 1
87-
text = row[start_idx:]
88-
text = f"\n{' ' * indent}CASE WHEN {text} THEN"
89-
_print(text)
41+
_process_leaf_node(tab, rule, task_type, stack, i, row, indent)
42+
else:
43+
_process_split_node(spacing, row, depth, indent, else_flag)
9044

9145
_debug(stack)
92-
46+
_post_process(column_name, src_table, tab, stack)
47+
sys.stdout = OUT
48+
49+
50+
def _post_process(column_name, src_table, tab, stack):
9351
while stack[-1] > 0:
94-
text = f"\n{' ' * tab * stack[-1]}END"
52+
prev_depth = stack[-1]
53+
text = f"\n{' ' * prev_depth * tab}END"
9554
if stack[-1] == 1:
9655
text += f" AS {column_name}\nFROM {src_table};\n"
9756

9857
_print(text)
9958
stack.pop()
100-
101-
sys.stdout = OUT
59+
60+
61+
def _process_split_node(spacing, row, depth, indent, else_flag):
62+
text = ''
63+
if else_flag:
64+
text = f"\n{' ' * indent}ELSE"
65+
_print(text)
66+
else:
67+
start_idx = depth * spacing + depth + 1
68+
text = row[start_idx:]
69+
text = f"\n{' ' * indent}CASE WHEN {text} THEN"
70+
_print(text)
71+
72+
73+
def _process_leaf_node(tab, rule, task_type, stack, i, row, indent):
74+
stack.pop()
75+
76+
if task_type == '': # infer task
77+
task_type = _get_task_type(row)
78+
79+
# infer value
80+
start_idx = row.index(':') + 1
81+
if task_type == 'regression':
82+
text = row[start_idx + 2:-1]
83+
else:
84+
text = row[start_idx + 1:]
85+
86+
after = '' # handle cases to put an END
87+
if i < len(rule) - 1:
88+
if rule[i + 1].count('|') <= stack[-2]:
89+
after = f"\n{' ' * (indent - tab)}END"
90+
91+
text = f" {text}{after}"
92+
_print(text)
93+
94+
95+
def _get_task_type(row):
96+
if row.find('class') != -1:
97+
task_type = 'classification'
98+
else:
99+
task_type = 'regression'
100+
return task_type
101+
102+
103+
def _get_node_type(row):
104+
node_type = 'leaf'
105+
if (row.count('<') > 0) | (row.count('>') > 0):
106+
node_type = 'split'
107+
return node_type
108+
109+
110+
def _handle_else_flag(tab, stack, depth):
111+
else_flag = False
112+
if depth > stack[-1]:
113+
stack.append(depth)
114+
else:
115+
else_flag = True
116+
if depth < stack[-1]:
117+
stack.pop()
118+
while stack[-1] != depth:
119+
prev_depth = stack[-1]
120+
text = f"\n{' ' * prev_depth * tab}END"
121+
_print(text)
122+
stack.pop()
123+
return else_flag

0 commit comments

Comments
 (0)