
Commit 88e57ac

more README
1 parent e036ebc commit 88e57ac

5 files changed: +59 -37 lines changed


benchmarking/pyg_serial.py

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,9 @@ def create_parser():
 
 
 def get_dataset(download_path=None):
-    dataset = PygNodePropPredDataset(name="ogbn-products", root=input_dir, transform=T.NormalizeFeatures())
+    dataset = PygNodePropPredDataset(
+        name="ogbn-products", root=input_dir, transform=T.NormalizeFeatures()
+    )
     gcn_norm = T.GCNNorm()
     return (gcn_norm.forward(dataset[0]), dataset.num_classes)
 
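For context, the call reformatted above loads the `ogbn-products` benchmark with per-node feature normalization and then applies GCN normalization to the graph. A self-contained sketch of the same pattern, with a placeholder `root` path standing in for the module's `input_dir`:

```python
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset

# Load ogbn-products with row-normalized node features; "./data" is a
# placeholder for the module's input_dir.
dataset = PygNodePropPredDataset(
    name="ogbn-products", root="./data", transform=T.NormalizeFeatures()
)

# Apply symmetric GCN normalization to the first (and only) graph, as
# get_dataset() does, and keep the label count alongside it.
data = T.GCNNorm().forward(dataset[0])
num_classes = dataset.num_classes
```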

benchmarking/spmm.py

Lines changed: 10 additions & 15 deletions
@@ -5,7 +5,7 @@
 
 
 def multiply_sharded_matrices_padded(
-    pt_file
+    pt_file,
     shard_row,
     shard_col,
     shard_x_col,
@@ -36,14 +36,8 @@ def multiply_sharded_matrices_padded(
     original_x_cols = x.shape[1]
 
     # Calculate padded dimensions for edge_index (implied adjacency matrix)
-    padded_rows = (
-        (original_num_nodes + shard_row - 1) // shard_row * shard_row
-    )
-    padded_cols_x = (
-        (original_num_nodes + shard_col - 1)
-        // shard_col
-        * shard_col
-    )
+    padded_rows = (original_num_nodes + shard_row - 1) // shard_row * shard_row
+    padded_cols_x = (original_num_nodes + shard_col - 1) // shard_col * shard_col
 
     # Calculate padded dimensions for x
     padded_x_rows = (
@@ -54,9 +48,7 @@ def multiply_sharded_matrices_padded(
         padded_x_cols = original_x_cols
     else:
         padded_x_cols = (
-            (original_x_cols + shard_x_col - 1)
-            // shard_x_col
-            * shard_x_col
+            (original_x_cols + shard_x_col - 1) // shard_x_col * shard_x_col
         )
 
     # Calculate shard sizes for padded dimensions
@@ -99,7 +91,10 @@ def multiply_sharded_matrices_padded(
     x_end_col = x_col_shard_size
     sharded_x = padded_x[x_start_row:x_end_row, x_start_col:x_end_col]
 
-    print("Theoretical # of FLOPs (2 * NNZ * D): " + str(2 * sharded_adj_t._nnz() * sharded_x.shape[1]))
+    print(
+        "Theoretical # of FLOPs (2 * NNZ * D): "
+        + str(2 * sharded_adj_t._nnz() * sharded_x.shape[1])
+    )
 
     # Move tensors to CUDA if available
     if torch.cuda.is_available():
@@ -151,7 +146,7 @@ def multiply_sharded_matrices_padded(
     parser.add_argument(
         "pt_file",
         type=int,
-        help="Path to plexus processed .pt file containing the data"
+        help="Path to plexus processed .pt file containing the data",
     )
     parser.add_argument(
         "shard_row",
@@ -189,5 +184,5 @@ def multiply_sharded_matrices_padded(
         args.shard_col,
         args.shard_x_col,
         args.iterations,
-        args.warmup
+        args.warmup,
     )
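The padding arithmetic consolidated above is the usual round-up-to-a-multiple idiom, and the printed figure is the standard SpMM FLOP estimate: one multiply and one add per nonzero per feature column, hence 2 * NNZ * D. A small sketch with illustrative numbers (the names here are not from the repo):

```python
def round_up(n: int, k: int) -> int:
    # Smallest multiple of k that is >= n; same idiom as the diff above.
    return (n + k - 1) // k * k

# Example: pad 1,000,003 rows so they split evenly across 4 shards.
padded_rows = round_up(1_000_003, 4)  # -> 1_000_004

# SpMM FLOP estimate: each of the NNZ nonzeros contributes one multiply
# and one add for each of the D feature columns.
nnz, d = 123_456, 128
flops = 2 * nnz * d  # -> 31_604_736
```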

plexus/README.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+## Files
+
+- **gcn_conv.py**: This file implements a 3D tensor-parallel version of the `GCNConv` layer, a fundamental component in Graph Convolutional Networks.
+
+- **cross_entropy.py**: This file provides a parallel implementation of the cross-entropy loss function, a standard loss function used for node-level classification.
+
+- **utils/**: This subdirectory contains several utility modules that provide essential functionality for the Plexus framework:
+
+  - **general.py**: This module includes generic utility functions used throughout the framework, including the following:
+    - setting a random seed for reproducible experiments
+    - padding a number to make it divisible by another number, which is helpful when sharding
+    - functions for retrieving process group information
+    - functions for printing timing information
+
+  - **dataset.py**: This module provides utilities for preprocessing graph datasets. Key functions include:
+    - `preprocess_graph()`: Preprocesses a graph dataset. This includes normalizing the features and the adjacency matrix, and applying the double permutation scheme specific to Plexus. It is recommended to call `set_seed` from `general.py` before this function, since features are randomly initialized for datasets that do not originally contain them.
+    - `partition_graph_2d()`: Statically partitions a preprocessed graph in 2D, creating an individual file for each 2D matrix partition. This allows the data to be distributed across multiple devices, so no single GPU has to load the entire dataset.
+    - Other utility functions for data conversion and manipulation.
+
+  - **dataloader.py**: This module contains the `DataLoader` class, which is responsible for efficiently loading preprocessed graph data. The `DataLoader` supports two modes:
+    - Loading unpartitioned (original) preprocessed data.
+    - Loading partitioned data generated by `partition_graph_2d()`. In this case, the `DataLoader` automatically determines which files to load for each GPU and extracts the relevant data shards.
+
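The `dataset.py` bullets above correspond to the driver code removed from `plexus/utils/dataset.py` later in this commit. A minimal sketch of that workflow, with assumed import paths and placeholder directories:

```python
from plexus.utils.general import set_seed  # assumed import path
from plexus.utils.dataset import preprocess_graph, partition_graph_2d

# Seed first: preprocess_graph() randomly initializes features for
# datasets that do not ship with any.
set_seed(0)

# Normalize features and the adjacency matrix and apply Plexus's double
# permutation scheme, producing a processed .pt file.
preprocess_graph("papers", "/path/to/original", "/path/to/papers")

# Statically 2D-partition the processed graph into per-shard files; the
# 16 mirrors the removed __main__ example in dataset.py.
partition_graph_2d(
    "/path/to/papers/processed_papers.pt", 16, "/path/to/partitioned_papers"
)

# NOTE: the removed driver also registered safe globals via
# torch.serialization.add_safe_globals() before loading; mirror that if
# torch.load rejects the pickled PyG types.
```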

plexus/utils/README.md

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+## Files
+
+- **gcn_conv.py**: This file implements a 3D tensor-parallel version of the `GCNConv` layer, a fundamental component in Graph Convolutional Networks.
+
+- **cross_entropy.py**: This file provides a parallel implementation of the cross-entropy loss function, a standard loss function used for node-level classification.
+
+- **utils/**: This subdirectory contains several utility modules that provide essential functionality for the Plexus framework:
+
+  - **general.py**: This module includes generic utility functions used throughout the framework, including the following:
+    - setting a random seed for reproducible experiments
+    - padding a number to make it divisible by another number, which is helpful when sharding
+    - functions for retrieving process group information
+    - functions for printing timing information
+
+  - **dataset.py**: This module provides utilities for preprocessing graph datasets. Key functions include:
+    - `preprocess_graph()`: Preprocesses a graph dataset. This includes normalizing the features and the adjacency matrix, and applying the double permutation scheme specific to Plexus. It is recommended to call `set_seed` from `general.py` before this function, since features are randomly initialized for datasets that do not originally contain them.
+    - `partition_graph_2d()`: Statically partitions a preprocessed graph in 2D, creating an individual file for each 2D matrix partition. This allows the data to be distributed across multiple devices, so no single GPU has to load the entire dataset.
+    - Other utility functions for data conversion and manipulation.
+
+  - **dataloader.py**: This module contains the `DataLoader` class, which is responsible for efficiently loading preprocessed graph data. The `DataLoader` supports two modes:
+    - Loading unpartitioned (original) preprocessed data.
+    - Loading partitioned data generated by `partition_graph_2d()`. In this case, the `DataLoader` automatically determines which files to load for each GPU and extracts the relevant data shards.
+
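A hedged sketch of the two `DataLoader` modes described above; the constructor arguments shown are illustrative guesses, since this commit does not show the actual signature:

```python
from plexus.utils.dataloader import DataLoader  # assumed import path

# Mode 1: load the unpartitioned, preprocessed dataset directly
# (hypothetical argument).
loader = DataLoader("/path/to/papers/processed_papers.pt")

# Mode 2: point at the output directory of partition_graph_2d(); the
# loader determines which shard files belong to this GPU and extracts
# the relevant shards (hypothetical keyword).
loader = DataLoader("/path/to/partitioned_papers", partitioned=True)
```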

plexus/utils/dataset.py

Lines changed: 0 additions & 21 deletions
@@ -507,24 +507,3 @@ def process_partition(chunk_idx_dim1, chunk_idx_dim2):
     for future in futures:
         future.result()  # Ensure completion
 
-
-if __name__ == "__main__":
-    # don't delete these two lines
-    set_seed(0)
-    torch.serialization.add_safe_globals(
-        [dtype, scalar, GlobalStorage, DataEdgeAttr, DataTensorAttr]
-    )
-
-    """
-    preprocess_graph(
-        "papers",
-        "/pscratch/sd/a/aranjan/gnn-env/gnn-datasets/original",
-        "/pscratch/sd/a/aranjan/gnn-env/gnn-datasets/papers",
-    )
-    """
-
-    partition_graph_2d(
-        "/pscratch/sd/a/aranjan/gnn-env/gnn-datasets/papers/processed_papers.pt",
-        16,
-        "/pscratch/sd/a/aranjan/gnn-env/gnn-datasets/partitioned_papers",
-    )
