Skip to content

Commit 7919eb1

Browse files
committed
self._g --> self.graph
1 parent b6b21fe commit 7919eb1

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

graphml_class/citation.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -473,41 +473,45 @@ def process(self):
473473
v.append(file_to_net[cited_key])
474474

475475
# Build our DGLGraph from the edge list :)
476-
self._g = dgl.graph((torch.tensor(u), torch.tensor(v)))
477-
self._g.ndata["x"] = self.featurize(file_to_net)
476+
self.graph = dgl.graph((torch.tensor(u), torch.tensor(v)))
477+
self.graph.ndata["feat"] = self.featurize(file_to_net)
478478

479479
#
480480
# Build label set for link prediction: balanced positive / negative, count = number of edges
481481
#
482482

483483
# The real edges are the positive labels
484-
pos_src, pos_dst = self._g.edges()
484+
pos_src, pos_dst = self.graph.edges()
485485

486486
# Negative sample labels sized by the number of actual, positive edges
487-
neg_src, neg_dst = global_uniform_negative_sampling(self._g, self._g.number_of_edges())
487+
neg_src, neg_dst = global_uniform_negative_sampling(
488+
self.graph, self.graph.number_of_edges()
489+
)
488490

489491
# Combine positive and negative samples
490492
self.src_nodes = torch.cat([pos_src, neg_src])
491493
self.dst_nodes = torch.cat([pos_dst, neg_dst])
492494

493495
# Generate labels: 1 for positive and 0 for negative samples
494-
self.labels = torch.cat([torch.ones_like(pos_src), torch.zeros_like(neg_src)]).float()
496+
self.graph.ndata["label"] = torch.cat(
497+
[torch.ones_like(pos_src), torch.zeros_like(neg_src)]
498+
).float()
495499

496500
def __getitem__(self, idx):
497501
assert idx == 0, "This dataset has only one graph"
498-
return self._g
502+
return self.graph
499503

500504
def __len__(self):
501505
"""__len__ we are one graph long"""
502506
return 1
503507

504508
def save(self):
505509
"""save Save our one graph to directory `self.save_path`"""
506-
save_graphs(self._save_path, [self._g])
510+
save_graphs(self._save_path, [self.graph])
507511

508512
def load(self):
509513
"""load Load processed data from directory `self.save_path`"""
510-
self._g = load_graphs(self._save_path)[0]
514+
self.graph = load_graphs(self._save_path)[0]
511515

512516
def has_cache(self):
513517
"""has_cache Check whether there are processed data in `self.save_path`"""

0 commit comments

Comments
 (0)