Skip to content

Commit 839367f

Browse files
committed
limit number of points in decision tree regressor visualization (#462)
1 parent ea24a47 commit 839367f

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

supervised/algorithms/decision_tree.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from sklearn.tree import _tree
2222
from dtreeviz.trees import dtreeviz
23+
from supervised.utils.subsample import subsample
2324

2425

2526
def get_rules(tree, feature_names, class_names):
@@ -204,14 +205,25 @@ def interpret(
204205
if explain_level == 0:
205206
return
206207
try:
207-
208-
viz = dtreeviz(
209-
self.model,
210-
X_train,
211-
y_train,
212-
target_name="target",
213-
feature_names=X_train.columns,
214-
)
208+
# 250 is hard limit for number of points used in visualization
209+
# if too many points are used then final SVG plot is very large (can be > 100MB)
210+
if X_train.shape[0] > 250:
211+
x, _, y, _ = subsample(X_train, y_train, REGRESSION, 250)
212+
viz = dtreeviz(
213+
self.model,
214+
x,
215+
y,
216+
target_name="target",
217+
feature_names=x.columns,
218+
)
219+
else:
220+
viz = dtreeviz(
221+
self.model,
222+
X_train,
223+
y_train,
224+
target_name="target",
225+
feature_names=X_train.columns,
226+
)
215227
tree_file_plot = os.path.join(model_file_path, learner_name + "_tree.svg")
216228
viz.save(tree_file_plot)
217229
except Exception as e:

0 commit comments

Comments
 (0)