1111from urllib .error import HTTPError
1212
1313# PyTorch Lightning
14- import lightning as L
14+ import pytorch_lightning as pl
1515
1616# PyTorch
1717import torch
2525import torch_geometric .nn as geom_nn
2626
2727# PL callbacks
28- from lightning . pytorch .callbacks import ModelCheckpoint
28+ from pytorch_lightning .callbacks import ModelCheckpoint
2929from torch import Tensor
3030
3131AVAIL_GPUS = min (1 , torch .cuda .device_count ())
3636CHECKPOINT_PATH = os .environ .get ("PATH_CHECKPOINT" , "saved_models/GNNs/" )
3737
3838# Setting the seed
39- L .seed_everything (42 )
39+ pl .seed_everything (42 )
4040
4141# Ensure that all operations are deterministic on GPU (if used) for reproducibility
4242torch .backends .cudnn .deterministic = True
@@ -592,7 +592,7 @@ def forward(self, x, *args, **kwargs):
592592
593593
594594# %%
595- class NodeLevelGNN (L .LightningModule ):
595+ class NodeLevelGNN (pl .LightningModule ):
596596 def __init__ (self , model_name , ** model_kwargs ):
597597 super ().__init__ ()
598598 # Saving hyperparameters
@@ -654,13 +654,13 @@ def test_step(self, batch, batch_idx):
654654
655655# %%
656656def train_node_classifier (model_name , dataset , ** model_kwargs ):
657- L .seed_everything (42 )
657+ pl .seed_everything (42 )
658658 node_data_loader = geom_data .DataLoader (dataset , batch_size = 1 )
659659
660660 # Create a PyTorch Lightning trainer
661661 root_dir = os .path .join (CHECKPOINT_PATH , "NodeLevel" + model_name )
662662 os .makedirs (root_dir , exist_ok = True )
663- trainer = L .Trainer (
663+ trainer = pl .Trainer (
664664 default_root_dir = root_dir ,
665665 callbacks = [ModelCheckpoint (save_weights_only = True , mode = "max" , monitor = "val_acc" )],
666666 accelerator = "auto" ,
@@ -676,7 +676,7 @@ def train_node_classifier(model_name, dataset, **model_kwargs):
676676 print ("Found pretrained model, loading..." )
677677 model = NodeLevelGNN .load_from_checkpoint (pretrained_filename )
678678 else :
679- L .seed_everything ()
679+ pl .seed_everything ()
680680 model = NodeLevelGNN (
681681 model_name = model_name , c_in = dataset .num_node_features , c_out = dataset .num_classes , ** model_kwargs
682682 )
@@ -892,7 +892,7 @@ def forward(self, x, edge_index, batch_idx):
892892
893893
894894# %%
895- class GraphLevelGNN (L .LightningModule ):
895+ class GraphLevelGNN (pl .LightningModule ):
896896 def __init__ (self , ** model_kwargs ):
897897 super ().__init__ ()
898898 # Saving hyperparameters
@@ -941,12 +941,12 @@ def test_step(self, batch, batch_idx):
941941
942942# %%
943943def train_graph_classifier (model_name , ** model_kwargs ):
944- L .seed_everything (42 )
944+ pl .seed_everything (42 )
945945
946946 # Create a PyTorch Lightning trainer with the generation callback
947947 root_dir = os .path .join (CHECKPOINT_PATH , "GraphLevel" + model_name )
948948 os .makedirs (root_dir , exist_ok = True )
949- trainer = L .Trainer (
949+ trainer = pl .Trainer (
950950 default_root_dir = root_dir ,
951951 callbacks = [ModelCheckpoint (save_weights_only = True , mode = "max" , monitor = "val_acc" )],
952952 accelerator = "cuda" ,
@@ -962,7 +962,7 @@ def train_graph_classifier(model_name, **model_kwargs):
962962 print ("Found pretrained model, loading..." )
963963 model = GraphLevelGNN .load_from_checkpoint (pretrained_filename )
964964 else :
965- L .seed_everything (42 )
965+ pl .seed_everything (42 )
966966 model = GraphLevelGNN (
967967 c_in = tu_dataset .num_node_features ,
968968 c_out = 1 if tu_dataset .num_classes == 2 else tu_dataset .num_classes ,
0 commit comments