Skip to content

Commit c362dba

Browse files
committed
Add Rearrange module util
1 parent 43a4822 commit c362dba

File tree

1 file changed

+44
-4
lines changed

1 file changed

+44
-4
lines changed

test/utils.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from itertools import product
44
from typing import Iterable, List, Tuple
55

6+
from einops import rearrange
67
from numpy import eye, ndarray
78
from torch import Tensor, cat, cuda, device, from_numpy, rand, randint
89
from torch.nn import Module, Parameter
@@ -24,13 +25,28 @@ def get_available_devices():
2425
return devices
2526

2627

27-
def classification_targets(size, num_classes):
28-
"""Create random targets for classes 0, ..., `num_classes - 1`."""
28+
def classification_targets(size: Tuple[int], num_classes: int) -> Tensor:
29+
"""Create random targets for classes 0, ..., `num_classes - 1`.
30+
31+
Args:
32+
size: Size of the targets to create.
33+
num_classes: Number of classes.
34+
35+
Returns:
36+
Random targets.
37+
"""
2938
return randint(size=size, low=0, high=num_classes)
3039

3140

32-
def regression_targets(size):
33-
"""Create random targets for regression."""
41+
def regression_targets(size: Tuple[int]) -> Tensor:
42+
"""Create random targets for regression.
43+
44+
Args:
45+
size: Size of the targets to create.
46+
47+
Returns:
48+
Random targets.
49+
"""
3450
return rand(*size)
3551

3652

@@ -91,3 +107,27 @@ def ggn_block_diagonal(
91107

92108
# concatenate all blocks
93109
return cat([cat(row_blocks, dim=1) for row_blocks in ggn_blocks], dim=0).numpy()
110+
111+
112+
class Rearrange(Module):
113+
"""A module that rearranges the input tensor."""
114+
115+
def __init__(self, pattern: str):
116+
"""Initialize the module.
117+
118+
Args:
119+
pattern: The rearrangement pattern.
120+
"""
121+
super().__init__()
122+
self.pattern = pattern
123+
124+
def forward(self, x: Tensor) -> Tensor:
125+
"""Rearrange the input tensor.
126+
127+
Args:
128+
x: The input tensor.
129+
130+
Returns:
131+
The rearranged tensor.
132+
"""
133+
return rearrange(x, self.pattern)

0 commit comments

Comments
 (0)