-
Notifications
You must be signed in to change notification settings - Fork 4
/
data_loading.py
48 lines (38 loc) · 1.08 KB
/
data_loading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
from dataclasses import dataclass
from pathlib import Path
from torch_geometric.datasets import TUDataset
# Directory handed to TUDataset as its ``root`` (see load_dataset).
DATASETS_DIR = Path("datasets")

# Directory holding one ``<dataset_name>.json`` split file per dataset
# (see load_dataset_splits).
DATA_SPLITS_DIR = Path("data_splits")

# Dataset names accepted by load_dataset_splits; any other name is rejected
# there with a ValueError.
DATASET_NAMES = [
    "DD",
    "NCI1",
    "PROTEINS_full",
    "ENZYMES",
    "IMDB-BINARY",
    "IMDB-MULTI",
    "REDDIT-BINARY",
    "REDDIT-MULTI-5K",
    "COLLAB",
]
@dataclass
class DatasetSplit:
    """Pair of index lists describing one train/test partition.

    Instances are built by ``load_dataset_splits`` from the ``"train"`` and
    ``"test"`` entries of a split file.
    """

    # Indices belonging to the training part of the split
    # (presumably positions of samples in the dataset - confirm with callers).
    train_idxs: list[int]
    # Indices belonging to the test part of the split.
    test_idxs: list[int]
def load_dataset_splits(dataset_name: str) -> list[DatasetSplit]:
    """Load the stored train/test splits for a dataset.

    The splits are read from ``DATA_SPLITS_DIR / "<dataset_name>.json"``. The
    file is expected to hold a JSON array whose items each provide a
    ``"train"`` and a ``"test"`` entry.

    Args:
        dataset_name: Name of the dataset; must be listed in ``DATASET_NAMES``.

    Returns:
        One ``DatasetSplit`` per item of the JSON array, in file order.

    Raises:
        ValueError: If ``dataset_name`` is not in ``DATASET_NAMES``.
        OSError: If the split file cannot be opened (e.g. it does not exist).
        KeyError: If an item lacks the ``"train"`` or ``"test"`` key.
    """
    if dataset_name not in DATASET_NAMES:
        raise ValueError(
            f"Dataset {dataset_name} not recognized. It has to be one of: {DATASET_NAMES}"
        )

    file_path = DATA_SPLITS_DIR / f"{dataset_name}.json"
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # default, so the file is read the same way on every machine.
    with open(file_path, encoding="utf-8") as file:
        raw_splits = json.load(file)

    return [DatasetSplit(entry["train"], entry["test"]) for entry in raw_splits]
def load_dataset(dataset_name: str) -> TUDataset:
    """Create the ``TUDataset`` called ``dataset_name`` under ``DATASETS_DIR``.

    The dataset is constructed with both ``use_node_attr`` and
    ``use_edge_attr`` switched on. The name is forwarded to ``TUDataset``
    as-is; it is not checked against ``DATASET_NAMES`` here.

    Args:
        dataset_name: Dataset name passed through as ``TUDataset``'s ``name``.

    Returns:
        The constructed ``TUDataset`` instance.
    """
    # TUDataset receives the root as a plain string rather than a Path.
    storage_root = str(DATASETS_DIR)
    dataset = TUDataset(
        root=storage_root,
        name=dataset_name,
        use_node_attr=True,
        use_edge_attr=True,
    )
    return dataset