AttributeError: 'dict' object has no attribute 'grad_fn' occurred when visualizing the network structure #10

Open
@Jacoobr

Hi @amlankar. I am trying to visualize the MLE training network structure with the **make_dot()** function from the torchviz package. I added the following code to the train(self, epoch) function of the train_ce.py script:

    def train(self, epoch):
        print 'Starting training'
        self.model.train()

        accum = defaultdict(float)
        # To accumulate stats for printing
        for step, data in enumerate(self.train_loader):
            if self.global_step % self.opts['val_freq'] == 0:
                self.validate()
                self.save_checkpoint(epoch)             

            # Forward pass
            # data['img'] = Variable(data['img'], requires_grad=True)
            # data['fwd_poly'] = Variable(data['fwd_poly'], requires_grad=True)   # Variable data['fwd_poly'] used for correction interactive
            input1 = data['img']
            input2 = data['fwd_poly']
            output = self.model(input1.to(device), input2.to(device))
            ## used for generating '.dot' format network structure graph
            g = make_dot(output, params=dict(list(polyrnnpp.PolyRNNpp(self.opts).named_parameters())+[('input1', input1), ('input2', input2)]))
            g.render('./graph', view=False)

Then I get the following 'grad_fn' error:

 g = make_dot(output, params=dict(list(polyrnnpp.PolyRNNpp(self.opts).named_parameters())+[('input1', input1), ('input2', input2)]))
  File "/home/tzq-lxj/softWare/anaconda3/envs/polygonRNNPP/lib/python2.7/site-packages/torchviz/dot.py", line 37, in make_dot
    output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
AttributeError: 'dict' object has no attribute 'grad_fn'

Why is the grad_fn of the Variable output missing when make_dot() is called, and how can I fix this 'grad_fn' error? I also tried the SummaryWriter from tensorboardX for the same purpose. Both when testing with the generate_annotation.py script and when training with the mle_ce.py script, I got the same error. More importantly, if you know of another way to visualize the MLE network structure, please let me know. Thank you, I appreciate your reply.
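
The traceback above suggests that self.model(...) returns a dict rather than a tensor, while make_dot() reads .grad_fn from its first argument (or from each element when that argument is a tuple). Below is a minimal sketch of one possible workaround, assuming some of the dict values are tensors that carry autograd history: collect those values into a tuple before calling make_dot(). The helper name differentiable_outputs is hypothetical; it is not part of torchviz or of this repository.

```python
def differentiable_outputs(output):
    """Return a tuple of the values in `output` that carry autograd history.

    `output` may be a dict, a list/tuple, or a single tensor-like object.
    A value is kept only if it has a non-None `grad_fn` attribute.
    (Hypothetical helper for illustration; not a torchviz API.)
    """
    if isinstance(output, dict):
        candidates = list(output.values())
    elif isinstance(output, (list, tuple)):
        candidates = list(output)
    else:
        candidates = [output]
    # Duck-typed check: skips ints, strings, and tensors without a graph.
    return tuple(v for v in candidates
                 if getattr(v, 'grad_fn', None) is not None)
```

With such a helper, the call in train() might become `make_dot(differentiable_outputs(output), params=dict(self.model.named_parameters()))`. Passing the parameters of self.model, rather than those of a freshly constructed PolyRNNpp instance, is likely needed for parameter names to show up in the graph, since a new instance holds different tensor objects than the ones used in the forward pass.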
