Official implementation for

**Graph Neural Networks for Learning Equivariant Representations of Neural Networks**
Miltiadis Kofinas\*, Boris Knyazev, Yan Zhang, Yunlu Chen, Gertjan J. Burghouts, Efstratios Gavves, Cees G. M. Snoek, David W. Zhang\*
ICLR 2024
https://arxiv.org/abs/2403.12143/

\*Joint first and last authors
To run the experiments, first create a clean virtual environment and install the requirements.
conda create -n neural-graphs python=3.9
conda activate neural-graphs
conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
conda install pyg==2.3.0 pytorch-scatter -c pyg
pip install hydra-core einops opencv-python
Install the repo:
git clone https://github.com/mkofinas/neural-graphs.git
cd neural-graphs
pip install -e .
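After installation, you can optionally run a quick check that the main dependencies import correctly and that CUDA is visible to PyTorch (this is just a generic sanity check, not part of the repository):

```python
# Optional environment sanity check for the packages installed above.
import torch
import torch_geometric
import torch_scatter  # installed as pytorch-scatter
import einops
import hydra

print(f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
print(f"PyG {torch_geometric.__version__}")
```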
An introductory notebook for INR classification with Neural Graphs is also available.
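To give a rough feel for the construction the notebook builds on, the sketch below illustrates the core idea from the paper for a plain MLP: neurons become graph nodes, biases become node features, and individual weights become edge features. The function and variable names (e.g. `mlp_to_graph`) are purely illustrative and do not correspond to the repository API.

```python
# Illustrative sketch only (not the repository API): an MLP viewed as a graph,
# with biases as node features and weights as edge features.
import itertools

import torch
import torch.nn as nn


def mlp_to_graph(mlp: nn.Sequential):
    """Turn a plain MLP into (node features, edge index, edge features)."""
    linears = [m for m in mlp if isinstance(m, nn.Linear)]
    layer_sizes = [linears[0].in_features] + [lin.out_features for lin in linears]
    # Node-index offset of each layer; neurons of all layers share one node set.
    offsets = list(itertools.accumulate([0] + layer_sizes[:-1]))
    num_nodes = sum(layer_sizes)

    node_feat = torch.zeros(num_nodes, 1)      # biases (input neurons stay zero)
    edge_index, edge_feat = [], []
    for i, lin in enumerate(linears):
        node_feat[offsets[i + 1]:offsets[i + 1] + lin.out_features, 0] = lin.bias.detach()
        w = lin.weight.detach()                # (out_features, in_features)
        src = offsets[i] + torch.arange(lin.in_features).repeat(lin.out_features)
        dst = offsets[i + 1] + torch.arange(lin.out_features).repeat_interleave(lin.in_features)
        edge_index.append(torch.stack([src, dst]))
        edge_feat.append(w.reshape(-1, 1))     # one scalar weight per edge
    return node_feat, torch.cat(edge_index, dim=1), torch.cat(edge_feat, dim=0)


# Example: a tiny INR-like MLP mapping 2-D coordinates to one output channel.
inr = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))
x, edge_index, e = mlp_to_graph(inr)
print(x.shape, edge_index.shape, e.shape)      # (35, 1), (2, 96), (96, 1)
```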
To run a specific experiment, please follow the README file within the corresponding experiment folder; it provides full instructions for downloading the data and reproducing the results reported in the paper.
- INR classification: `experiments/inr_classification`
- INR style editing: `experiments/style_editing`
- CNN generalization: `experiments/cnn_generalization`
- Learning to optimize (coming soon): `experiments/learning_to_optimize`
For INR classification, we use MNIST and Fashion MNIST. For INR style editing, we use MNIST. Download instructions for these datasets are provided in the corresponding experiment READMEs.
For CNN generalization, we use the grayscale CIFAR-10 (CIFAR10-GS) from the Small CNN Zoo dataset. We also introduce CNN Wild Park, a dataset of CNNs with varying numbers of layers, kernel sizes, activation functions, and residual connections between arbitrary layers.
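In the neural-graph view of a CNN described in the paper, channels play the role of neurons, and each kernel slice (one per input/output channel pair) becomes an edge feature; zero-padding to a shared maximum kernel size is one way to accommodate the varying kernel sizes in CNN Wild Park. The snippet below is a rough sketch of that idea for a single conv layer; `conv_to_edge_features` and `max_kernel` are hypothetical names, not the repository API.

```python
# Rough sketch (hypothetical helper, not the repository API): channels as nodes,
# flattened per-channel-pair kernels as edge features.
import torch
import torch.nn as nn


def conv_to_edge_features(conv: nn.Conv2d, max_kernel: int = 5):
    """Flatten each (out_channel, in_channel) kernel into a fixed-size edge feature.

    Kernels smaller than `max_kernel` are zero-padded so that layers with
    different kernel sizes can share one edge-feature dimension.
    """
    w = conv.weight.detach()                      # (C_out, C_in, k, k)
    c_out, c_in, k, _ = w.shape
    feat = torch.zeros(c_out, c_in, max_kernel, max_kernel)
    feat[:, :, :k, :k] = w
    edge_feat = feat.reshape(c_out * c_in, -1)    # one row per channel pair
    src = torch.arange(c_in).repeat(c_out)        # input-channel node per edge
    dst = torch.arange(c_out).repeat_interleave(c_in) + c_in  # output-channel nodes
    return torch.stack([src, dst]), edge_feat


conv = nn.Conv2d(3, 8, kernel_size=3)
edge_index, edge_feat = conv_to_edge_features(conv)
print(edge_index.shape, edge_feat.shape)          # (2, 24), (24, 25)
```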
If you find our work or this code useful in your own research, please consider citing the following papers:
@inproceedings{kofinas2024graph,
title={{G}raph {N}eural {N}etworks for {L}earning {E}quivariant {R}epresentations of {N}eural {N}etworks},
author={Kofinas, Miltiadis and Knyazev, Boris and Zhang, Yan and Chen, Yunlu and Burghouts,
Gertjan J. and Gavves, Efstratios and Snoek, Cees G. M. and Zhang, David W.},
booktitle = {12th International Conference on Learning Representations ({ICLR})},
year={2024}
}
@inproceedings{zhang2023neural,
title={{N}eural {N}etworks {A}re {G}raphs! {G}raph {N}eural {N}etworks for {E}quivariant {P}rocessing of {N}eural {N}etworks},
author={Zhang, David W. and Kofinas, Miltiadis and Zhang, Yan and Chen, Yunlu and Burghouts, Gertjan J. and Snoek, Cees G. M.},
booktitle = {Workshop on Topology, Algebra, and Geometry in Machine Learning (TAG-ML), ICML},
year={2023}
}
- This codebase started from github.com/AvivNavon/DWSNets, and the DWSNets implementation is copied from there.
- The NFN implementation is copied and slightly adapted from github.com/AllanYangZhou/nfn.
- We implemented the relational transformer in PyTorch, following the JAX implementation at github.com/CameronDiao/relational-transformer. Our implementation has some differences, which we describe in the paper.