3
3
from itertools import product
4
4
from typing import Iterable , List , Tuple
5
5
6
+ from einops import rearrange
6
7
from numpy import eye , ndarray
7
8
from torch import Tensor , cat , cuda , device , from_numpy , rand , randint
8
9
from torch .nn import Module , Parameter
@@ -24,13 +25,28 @@ def get_available_devices():
24
25
return devices
25
26
26
27
27
- def classification_targets (size , num_classes ):
28
- """Create random targets for classes 0, ..., `num_classes - 1`."""
28
+ def classification_targets (size : Tuple [int ], num_classes : int ) -> Tensor :
29
+ """Create random targets for classes 0, ..., `num_classes - 1`.
30
+
31
+ Args:
32
+ size: Size of the targets to create.
33
+ num_classes: Number of classes.
34
+
35
+ Returns:
36
+ Random targets.
37
+ """
29
38
return randint (size = size , low = 0 , high = num_classes )
30
39
31
40
32
- def regression_targets (size ):
33
- """Create random targets for regression."""
41
+ def regression_targets (size : Tuple [int ]) -> Tensor :
42
+ """Create random targets for regression.
43
+
44
+ Args:
45
+ size: Size of the targets to create.
46
+
47
+ Returns:
48
+ Random targets.
49
+ """
34
50
return rand (* size )
35
51
36
52
@@ -91,3 +107,27 @@ def ggn_block_diagonal(
91
107
92
108
# concatenate all blocks
93
109
return cat ([cat (row_blocks , dim = 1 ) for row_blocks in ggn_blocks ], dim = 0 ).numpy ()
110
+
111
+
112
+ class Rearrange (Module ):
113
+ """A module that rearranges the input tensor."""
114
+
115
+ def __init__ (self , pattern : str ):
116
+ """Initialize the module.
117
+
118
+ Args:
119
+ pattern: The rearrangement pattern.
120
+ """
121
+ super ().__init__ ()
122
+ self .pattern = pattern
123
+
124
+ def forward (self , x : Tensor ) -> Tensor :
125
+ """Rearrange the input tensor.
126
+
127
+ Args:
128
+ x: The input tensor.
129
+
130
+ Returns:
131
+ The rearranged tensor.
132
+ """
133
+ return rearrange (x , self .pattern )
0 commit comments