@@ -473,41 +473,45 @@ def process(self):
473
473
v .append (file_to_net [cited_key ])
474
474
475
475
# Build our DGLGraph from the edge list :)
476
- self ._g = dgl .graph ((torch .tensor (u ), torch .tensor (v )))
477
- self ._g .ndata ["x " ] = self .featurize (file_to_net )
476
+ self .graph = dgl .graph ((torch .tensor (u ), torch .tensor (v )))
477
+ self .graph .ndata ["feat " ] = self .featurize (file_to_net )
478
478
479
479
#
480
480
# Build label set for link prediction: balanced positive / negative, count = number of edges
481
481
#
482
482
483
483
# The real edges are the positive labels
484
- pos_src , pos_dst = self ._g .edges ()
484
+ pos_src , pos_dst = self .graph .edges ()
485
485
486
486
# Negative sample labels sized by the number of actual, positive edges
487
- neg_src , neg_dst = global_uniform_negative_sampling (self ._g , self ._g .number_of_edges ())
487
+ neg_src , neg_dst = global_uniform_negative_sampling (
488
+ self .graph , self .graph .number_of_edges ()
489
+ )
488
490
489
491
# Combine positive and negative samples
490
492
self .src_nodes = torch .cat ([pos_src , neg_src ])
491
493
self .dst_nodes = torch .cat ([pos_dst , neg_dst ])
492
494
493
495
# Generate labels: 1 for positive and 0 for negative samples
494
- self .labels = torch .cat ([torch .ones_like (pos_src ), torch .zeros_like (neg_src )]).float ()
496
+ self .graph .ndata ["label" ] = torch .cat (
497
+ [torch .ones_like (pos_src ), torch .zeros_like (neg_src )]
498
+ ).float ()
495
499
496
500
def __getitem__ (self , idx ):
497
501
assert idx == 0 , "This dataset has only one graph"
498
- return self ._g
502
+ return self .graph
499
503
500
504
def __len__ (self ):
501
505
"""__len__ we are one graph long"""
502
506
return 1
503
507
504
508
def save (self ):
505
509
"""save Save our one graph to directory `self.save_path`"""
506
- save_graphs (self ._save_path , [self ._g ])
510
+ save_graphs (self ._save_path , [self .graph ])
507
511
508
512
def load (self ):
509
513
"""load Load processed data from directory `self.save_path`"""
510
- self ._g = load_graphs (self ._save_path )[0 ]
514
+ self .graph = load_graphs (self ._save_path )[0 ]
511
515
512
516
def has_cache (self ):
513
517
"""has_cache Check whether there are processed data in `self.save_path`"""
0 commit comments