@@ -22,10 +22,10 @@ def make_pts(N):
2222
2323
2424class Graph :
25- def __init__ (self , vis = False ):
25+ def __init__ (self , vis = False , vis_args = {} ):
2626 self .gifs = []
2727 if vis :
28- self .vis = visdom .Visdom ()
28+ self .vis = visdom .Visdom (** vis_args )
2929 else :
3030 self .vis = None
3131 self .first = True
@@ -74,8 +74,8 @@ def graph(self, outfile, model=None):
7474
7575
7676class Simple (Graph ):
77- def __init__ (self , N , vis = False ):
78- super ().__init__ (vis )
77+ def __init__ (self , N , vis = False , vis_args = {} ):
78+ super ().__init__ (vis , vis_args )
7979 self .N = N
8080 self .X = make_pts (N )
8181 self .y = []
@@ -85,8 +85,8 @@ def __init__(self, N, vis=False):
8585
8686
8787class Split (Graph ):
88- def __init__ (self , N , vis = False ):
89- super ().__init__ (vis )
88+ def __init__ (self , N , vis = False , vis_args = {} ):
89+ super ().__init__ (vis , vis_args )
9090 self .N = N
9191 self .X = make_pts (N )
9292 self .y = []
@@ -96,8 +96,8 @@ def __init__(self, N, vis=False):
9696
9797
9898class Xor (Graph ):
99- def __init__ (self , N , vis = False ):
100- super ().__init__ (vis )
99+ def __init__ (self , N , vis = False , vis_args = {} ):
100+ super ().__init__ (vis , vis_args )
101101 self .N = N
102102 self .X = make_pts (N )
103103 self .y = []
0 commit comments