
Hello, a question about the GNN class in VSUAModule.py #8

Open
Davidwdq opened this issue Oct 16, 2020 · 1 comment

Davidwdq commented Oct 16, 2020

```python
# Excerpt: GNN.forward from VSUAModule.py (assumes `import torch` at module level)
def forward(self, obj_vecs, attr_vecs, rela_vecs, edges, rela_masks=None):
    # for easily indexing the subject and object of each relation in the tensors
    obj_vecs, attr_vecs, rela_vecs, edges, ori_shape = self.feat_3d_to_2d(obj_vecs, attr_vecs, rela_vecs, edges)

    # obj: object features are passed through unchanged
    new_obj_vecs = obj_vecs

    # attr: residual update from each attribute concatenated with its object feature
    new_attr_vecs = self.gnn_attr(torch.cat([obj_vecs, attr_vecs], dim=-1)) + attr_vecs

    # rela
    # get node features for each triplet <subject, relation, object>
    s_idx = edges[:, 0].contiguous()  # index of subject
    o_idx = edges[:, 1].contiguous()  # index of object
    s_vecs = obj_vecs[s_idx]
    o_vecs = obj_vecs[o_idx]
    if self.opt.rela_gnn_type == 0:
        # concatenate [subject, relation, object]
        t_vecs = torch.cat([s_vecs, rela_vecs, o_vecs], dim=1)
    elif self.opt.rela_gnn_type == 1:
        # sum subject and object, then concatenate with the relation
        t_vecs = torch.cat([s_vecs + o_vecs, rela_vecs], dim=1)
    else:
        raise NotImplementedError()
    # residual update of the relation features
    new_rela_vecs = self.gnn_rela(t_vecs) + rela_vecs

    # restore the original (batched) shapes
    new_obj_vecs, new_attr_vecs, new_rela_vecs = self.feat_2d_to_3d(new_obj_vecs, new_attr_vecs, new_rela_vecs, rela_masks, ori_shape)

    return new_obj_vecs, new_attr_vecs, new_rela_vecs
```
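The triplet construction in the `rela` branch can be sketched without PyTorch. This is a minimal toy version under stated assumptions: plain Python lists stand in for tensors, `gather_triplets` is a hypothetical helper (not part of VSUAModule.py), and its `rela_gnn_type` argument mirrors `self.opt.rela_gnn_type` in the excerpt.

```python
# Toy sketch (hypothetical helper, not from VSUAModule.py): build one input row
# per relation from its <subject, relation, object> features, mirroring the
# two rela_gnn_type options in GNN.forward. Plain lists stand in for tensors.

def gather_triplets(obj_vecs, rela_vecs, edges, rela_gnn_type=0):
    """obj_vecs: N object vectors; rela_vecs: R relation vectors;
    edges: R (subject_idx, object_idx) pairs, one per relation."""
    t_vecs = []
    for (s_idx, o_idx), r in zip(edges, rela_vecs):
        s = obj_vecs[s_idx]  # subject feature, like obj_vecs[s_idx]
        o = obj_vecs[o_idx]  # object feature, like obj_vecs[o_idx]
        if rela_gnn_type == 0:
            # concatenate [subject, relation, object]
            t_vecs.append([*s, *r, *o])
        elif rela_gnn_type == 1:
            # element-wise subject + object, then concatenate the relation
            t_vecs.append([a + b for a, b in zip(s, o)] + list(r))
        else:
            raise NotImplementedError()
    return t_vecs


obj = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
rela = [[5.0, 5.0], [7.0, 7.0]]
edges = [(0, 1), (2, 0)]
print(gather_triplets(obj, rela, edges, rela_gnn_type=0))
# [[1.0, 0.0, 5.0, 5.0, 0.0, 2.0], [3.0, 3.0, 7.0, 7.0, 1.0, 0.0]]
print(gather_triplets(obj, rela, edges, rela_gnn_type=1))
# [[1.0, 2.0, 5.0, 5.0], [4.0, 3.0, 7.0, 7.0]]
```

With type 0 each row has width D_obj + D_rela + D_obj; with type 1 it has width D_obj + D_rela, so the first layer of `gnn_rela` has to accept a different input width for each setting.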

Hello, I'd like to ask: are new_obj_vecs, new_attr_vecs, and new_rela_vecs here the three types of image features after they have been refined by the GNN?

ltguo19 (Owner) commented Oct 19, 2020

Yes, you are right.
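To make the answer concrete: in this `forward`, `new_obj_vecs` is `obj_vecs` passed through as-is, while `new_attr_vecs` and `new_rela_vecs` are residual updates (`gnn(...) + input`), so all three keep their input shapes. The toy sketch below illustrates this under loud assumptions: `toy_update` is a hypothetical, non-learned stand-in for the learned `gnn_attr` / `gnn_rela` networks, only `rela_gnn_type == 0` is shown, and lists replace tensors.

```python
# Toy sketch of the residual refinement in GNN.forward. `toy_update` is a
# hypothetical, non-learned stand-in for gnn_attr / gnn_rela (not from the repo).

def toy_update(x, out_dim):
    # Hypothetical "network": mean of the input, repeated out_dim times
    mean = sum(x) / len(x)
    return [mean] * out_dim


def add(u, v):
    # element-wise sum, playing the role of the "+ attr_vecs" / "+ rela_vecs" residual
    return [a + b for a, b in zip(u, v)]


def refine(obj_vecs, attr_vecs, rela_vecs, edges):
    # obj: passed through unchanged
    new_obj = [list(o) for o in obj_vecs]
    # attr: update from [object, attribute], plus the old attribute feature
    new_attr = [add(toy_update([*o, *a], len(a)), a)
                for o, a in zip(obj_vecs, attr_vecs)]
    # rela: update from [subject, relation, object] (rela_gnn_type == 0), plus residual
    new_rela = [add(toy_update([*obj_vecs[s_idx], *r, *obj_vecs[o_idx]], len(r)), r)
                for (s_idx, o_idx), r in zip(edges, rela_vecs)]
    return new_obj, new_attr, new_rela


obj = [[1.0, 3.0], [2.0, 2.0]]
attr = [[0.0, 0.0], [4.0, 0.0]]
rela = [[2.0, 2.0]]
edges = [(0, 1)]
new_obj, new_attr, new_rela = refine(obj, attr, rela, edges)
print(new_obj)   # [[1.0, 3.0], [2.0, 2.0]]
print(new_attr)  # [[1.0, 1.0], [6.0, 2.0]]
print(new_rela)  # [[4.0, 4.0]]
```

Because the updates are residual, the networks only add a correction on top of the original attribute and relation features rather than replacing them.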
