Skip to content

Commit 77b2083

Browse files
committed
se3cnn
1 parent e0d57df commit 77b2083

File tree

49 files changed

+113
-113
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+113
-113
lines changed

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@ This library aims to create SE(3) equivariant convolutional neural networks.
77

88
## Hierarchy
99

10-
- `se3_cnn` contains the library
11-
- `se3_cnn/convolution.py` defines `SE3Convolution` the main class of the library
12-
- `se3_cnn/blocks` defines ways of introducing non linearity in an equivariant way
13-
- `se3_cnn/batchnorm.py` equivariant batch normalization
14-
- `se3_cnn/groupnorm.py` equivariant group normalization
15-
- `se3_cnn/dropout.py` equivariant dropout
10+
- `se3cnn` contains the library
11+
- `se3cnn/convolution.py` defines `SE3Convolution` the main class of the library
12+
- `se3cnn/blocks` defines ways of introducing non linearity in an equivariant way
13+
- `se3cnn/batchnorm.py` equivariant batch normalization
14+
- `se3cnn/groupnorm.py` equivariant group normalization
15+
- `se3cnn/dropout.py` equivariant dropout
1616
- `experiments` contains experiments made with the library
1717
- `examples` simple scripts
1818

1919
## Dependencies
2020

21-
- [pytorch](https://pytorch.org)
21+
- [pytorch](https://pytorch.org)
2222
- [lie_learn](https://github.com/AMLab-Amsterdam/lie_learn) is required to compute the irreducible representations of SO(3)
2323
- scipy
2424

examples/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# The class GatedBlock inherit from the class torch.nn.Module.
1818
# It contains one convolution, some ReLU and multiplications
19-
from se3_cnn.blocks import GatedBlock
19+
from se3cnn.blocks import GatedBlock
2020

2121

2222
class AvgSpacial(nn.Module):

examples/plots/kernels.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import matplotlib.pyplot as plt
55
from matplotlib import cm
66
from mpl_toolkits.mplot3d import Axes3D # pylint: disable=W
7-
from se3_cnn.SO3 import irr_repr, spherical_harmonics
8-
from se3_cnn.basis_kernels import _basis_transformation_Q_J
9-
from se3_cnn.util.cache_file import cached_dirpklgz
10-
from se3_cnn.SO3 import compose
7+
from se3cnn.SO3 import irr_repr, spherical_harmonics
8+
from se3cnn.basis_kernels import _basis_transformation_Q_J
9+
from se3cnn.util.cache_file import cached_dirpklgz
10+
from se3cnn.SO3 import compose
1111

1212

1313
def beta_alpha(n):
@@ -87,9 +87,9 @@ def main():
8787
# f(r^-1 x)
8888

8989
f = np.einsum(
90-
"ij,zjkba,kl->zilba",
91-
irr_repr(args.order_out, args.alpha, args.beta, args.gamma),
92-
f,
90+
"ij,zjkba,kl->zilba",
91+
irr_repr(args.order_out, args.alpha, args.beta, args.gamma),
92+
f,
9393
irr_repr(args.order_in, -args.gamma, -args.beta, -args.alpha)
9494
)
9595
# rho_out(r) f(r^-1 x) rho_in(r^-1)

examples/plots/kernels_cutaway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# pylint: disable=C,R,E1101
2-
from se3_cnn import basis_kernels
3-
from se3_cnn import SO3
2+
from se3cnn import basis_kernels
3+
from se3cnn import SO3
44
import numpy as np
55
import matplotlib.pyplot as plt
66
from functools import partial

examples/tetris.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88

9-
from se3_cnn.blocks import GatedBlock
9+
from se3cnn.blocks import GatedBlock
1010

1111

1212
class AvgSpacial(nn.Module):
@@ -86,8 +86,8 @@ def __init__(self):
8686
super(SE3Net, self).__init__()
8787
features = [
8888
(1,),
89-
(2, 2, 2, 2),
90-
(4, 4, 4, 4),
89+
(2, 2, 2, 2),
90+
(4, 4, 4, 4),
9191
(16,)
9292
]
9393
common_block_params = {

experiments/datasets/modelnet/modelnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from functools import partial
99

1010
from scipy.ndimage import affine_transform
11-
from se3_cnn.SO3 import rot
11+
from se3cnn.SO3 import rot
1212

1313

1414

@@ -43,7 +43,7 @@ class ModelNet(torch.utils.data.Dataset):
4343
def __init__(self, root_dir, dataset, mode, size, classes, transform=None, target_transform=None):
4444
'''
4545
:param root: directory to store dataset in
46-
:param dataset:
46+
:param dataset:
4747
:param mode: dataset to load: 'train', 'validation', 'test' or 'train_full'
4848
the validation set is split from the train set, the full train set can be accessed via 'train_full'
4949
:param transform: transformation applied to image in __getitem__
@@ -91,7 +91,7 @@ def __len__(self):
9191

9292

9393
class AddZAxis(object):
94-
''' add z-axis as second channel to volume
94+
''' add z-axis as second channel to volume
9595
the scale of the z-axis can be set freely
9696
if the volume tensor does not contain a channel dimension, add it
9797
the z axis is assumed to be the last axis of the volume tensor

experiments/scripts/MRI/networks/MICCAI2012/SE3BasicModel/SE3BasicModel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from functools import partial
66

7-
from se3_cnn.blocks import GatedBlock
8-
from se3_cnn import basis_kernels
7+
from se3cnn.blocks import GatedBlock
8+
from se3cnn import basis_kernels
99

1010

1111
class network(nn.Module):

experiments/scripts/MRI/networks/MICCAI2012/SE3UNet/SE3UNet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from functools import partial
66

7-
from se3_cnn.blocks import GatedBlock
8-
from se3_cnn import basis_kernels
7+
from se3cnn.blocks import GatedBlock
8+
from se3cnn import basis_kernels
99
from experiments.util.arch_blocks import Merge
1010

1111

experiments/scripts/MRI/networks/MRBrainS/SE3BasicModel/SE3BasicModel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from functools import partial
66

7-
from se3_cnn.blocks import GatedBlock
8-
from se3_cnn import basis_kernels
7+
from se3cnn.blocks import GatedBlock
8+
from se3cnn import basis_kernels
99

1010

1111
class network(nn.Module):

experiments/scripts/MRI/networks/MRBrainS/SE3UNet/SE3UNet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from functools import partial
66

7-
from se3_cnn import basis_kernels
7+
from se3cnn import basis_kernels
88
from experiments.util.arch_blocks import NonlinearityBlock
99
from experiments.util.arch_blocks import SkipSumBlock
1010

@@ -31,7 +31,7 @@ def __init__(self, output_size, args):
3131
'batch_norm_momentum': 0.01,
3232
}
3333

34-
features = [(3,), # in
34+
features = [(3,), # in
3535
[(6,6,6,4), (16,16,16,12)], # level 1 (enc and dec)
3636
( 32, 32, 32, 24), # level 2 (enc and dec)
3737
( 64, 64, 64, 48), # level 3 (bridge)

0 commit comments

Comments
 (0)