Hi @amlankar. I am trying to visualize the MLE training network structure with the **make_dot()** function from the torchviz package. I added the following code to the `train(self, epoch)` function of the train_ce.py script:
```python
def train(self, epoch):
    print 'Starting training'
    self.model.train()
    accum = defaultdict(float)
    # To accumulate stats for printing
    for step, data in enumerate(self.train_loader):
        if self.global_step % self.opts['val_freq'] == 0:
            self.validate()
            self.save_checkpoint(epoch)

        # Forward pass
        # data['img'] = Variable(data['img'], requires_grad=True)
        # data['fwd_poly'] = Variable(data['fwd_poly'], requires_grad=True) # Variable data['fwd_poly'] used for correction interactive
        input1 = data['img']
        input2 = data['fwd_poly']
        output = self.model(input1.to(device), input2.to(device))

        # Used for generating the '.dot' format network structure graph
        g = make_dot(output, params=dict(list(polyrnnpp.PolyRNNpp(self.opts).named_parameters()) + [('input1', input1), ('input2', input2)]))
        g.render('./graph', view=False)
```
Then I get the following `grad_fn` error:
```
    g = make_dot(output, params=dict(list(polyrnnpp.PolyRNNpp(self.opts).named_parameters())+[('input1', input1), ('input2', input2)]))
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 37, in make_dot
    output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
AttributeError: 'dict' object has no attribute 'grad_fn'
```
Why does the `grad_fn` of the output Variable disappear when `make_dot()` is called, and how can I fix this `grad_fn` error? I also tried to do the same thing with the `SummaryWriter` of tensorboardX, both when testing with the generate_annotation.py script and when training with the mle_ce.py script, and I got the same error. More importantly, if you know of another way to visualize the MLE network structure, please let me know. Thank you, I appreciate your reply.
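The last line of the traceback hints at the likely cause: `make_dot()` reads `var.grad_fn`, but `output` is a `dict`, so the model apparently returns a dictionary of results rather than a single tensor. Nothing is "lost"; the dict simply has no `grad_fn` of its own. A minimal sketch, assuming that is the situation: the hypothetical helper `graph_roots` (not part of torchviz or this repo) walks a possibly nested dict/list/tuple output and collects every value whose `grad_fn` is not None. The resulting tuple can then be passed to `make_dot()`, since the `dot.py` line in the traceback already accepts a tuple of tensors.

```python
def graph_roots(output):
    """Return a tuple of the objects in `output` that carry autograd history.

    Hypothetical helper, duck-typed: any object whose `grad_fn` attribute
    is not None counts as a graph root. Nested dicts, lists and tuples are
    searched recursively; everything else (leaf tensors, strings, numbers,
    None) is skipped. Written to run on both Python 2.7 and Python 3.
    """
    # A tensor produced by a differentiable op is its own root.
    if getattr(output, "grad_fn", None) is not None:
        return (output,)
    # Containers: recurse into their values or elements.
    if isinstance(output, dict):
        children = list(output.values())
    elif isinstance(output, (list, tuple)):
        children = list(output)
    else:
        # No autograd history and not a container: contributes nothing.
        return ()
    roots = ()
    for child in children:
        roots += graph_roots(child)
    return roots
```

In the training loop this would become something like `g = make_dot(graph_roots(output), params=dict(self.model.named_parameters()))`. Using `self.model` instead of a freshly constructed `polyrnnpp.PolyRNNpp(self.opts)` is probably also worth doing: torchviz's `make_dot` maps parameter names by object identity (`id()`), so parameters from a new model instance will not match the nodes in the graph of `self.model`.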