-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_test_split_edges.py
86 lines (56 loc) · 2.63 KB
/
train_test_split_edges.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import math
import torch
from torch_geometric.utils import to_undirected
def train_split_edges(data):
    r"""Put every edge of a :obj:`torch_geometric.data.Data` object into the
    training split and enumerate all non-edges as training negatives.

    Unlike PyG's ``train_test_split_edges``, nothing is held out: the
    validation and test index tensors are always empty ``(2, 0)`` long
    tensors. The following attributes are written to :attr:`data`:
    ``train_pos_edge_index``, ``train_neg_edge_index``,
    ``val_pos_edge_index``, ``val_neg_edge_index``, ``test_pos_edge_index``
    and ``test_neg_edge_index``. When the graph has at least one edge,
    ``train_neg_adj_mask`` is also set; it is all ``False``, because every
    negative is consumed by the training split. ``data.edge_index`` is set
    to ``None``.

    Only edges stored with ``row > col`` are read, so the input is expected
    to be undirected (each edge present in both directions). Edges that
    appear only as ``row < col``, and self-loops, are ignored.

    Args:
        data (Data): The data object. Must expose ``num_nodes`` and an
            ``edge_index`` tensor whose first dimension is 2.

    :rtype: :class:`torch_geometric.data.Data`
    """
    num_nodes = data.num_nodes
    edge_index = data.edge_index
    # Create every new tensor on the same device as the input so that the
    # mask indexing below also works for GPU-resident graphs.
    device = edge_index.device

    def _empty():
        # Placeholder for a split that receives no edges.
        return torch.empty((2, 0), dtype=torch.long, device=device)

    if edge_index.size(1) == 0:
        data.train_pos_edge_index = edge_index
        data.train_neg_edge_index = _empty()
        data.test_pos_edge_index = _empty()
        data.test_neg_edge_index = _empty()
        data.val_pos_edge_index = _empty()
        data.val_neg_edge_index = _empty()
        data.edge_index = None
        return data

    data.edge_index = None

    # Keep one orientation of each undirected edge: the strictly
    # lower-triangular one (row > col). This also drops self-loops.
    row, col = edge_index
    mask = row > col
    row, col = row[mask], col[mask]

    # Positive edges: all of them go to the training split.
    perm = torch.randperm(row.size(0), device=device)
    row, col = row[perm], col[perm]
    data.test_pos_edge_index = _empty()
    data.val_pos_edge_index = _empty()
    data.train_pos_edge_index = to_undirected(torch.stack([row, col], dim=0))

    # Negative edges: every pair (i, j) with i > j that is not a positive.
    # diagonal=-1 selects exactly the strictly lower triangle, matching the
    # row > col orientation above. A larger diagonal would also admit
    # self-loops and upper-triangular pairs, whose positives are never
    # cleared from the mask.
    # The uint8 -> tril -> bool round-trip is kept from the original code.
    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8,
                              device=device)
    neg_adj_mask = neg_adj_mask.tril(diagonal=-1).to(torch.bool)
    neg_adj_mask[row, col] = False

    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()
    perm = torch.randperm(neg_row.size(0), device=device)
    neg_row, neg_col = neg_row[perm], neg_col[perm]

    # All negatives are used for training, so none remain available in the
    # mask.
    neg_adj_mask[neg_row, neg_col] = False
    data.train_neg_adj_mask = neg_adj_mask

    data.train_neg_edge_index = torch.stack([neg_row, neg_col], dim=0)
    data.test_neg_edge_index = _empty()
    data.val_neg_edge_index = _empty()
    return data