Skip to content

Commit

Permalink
raise error when steerable basis for R2Conv is empty; rename filters var
Browse files Browse the repository at this point in the history
add test for output shape of R2Conv
rename filter variables in methods
  • Loading branch information
Gabri95 committed Nov 12, 2020
2 parents f89d8be + 734553a commit 274a9c6
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 82 deletions.
8 changes: 4 additions & 4 deletions e2cnn/nn/modules/pooling/pointwise_avg.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,15 +188,15 @@ def __init__(self,
r = -torch.sum((grid - mean) ** 2., dim=-1, dtype=torch.get_default_dtype())

# Build the gaussian kernel
filter = torch.exp(r / (2 * variance))
_filter = torch.exp(r / (2 * variance))

# Normalize
filter /= torch.sum(filter)
_filter /= torch.sum(_filter)

# The filter needs to be reshaped to be used in 2d depthwise convolution
filter = filter.view(1, 1, filter_size, filter_size).repeat((in_type.size, 1, 1, 1))
_filter = _filter.view(1, 1, filter_size, filter_size).repeat((in_type.size, 1, 1, 1))

self.register_buffer('filter', filter)
self.register_buffer('filter', _filter)

################################################################################################################

Expand Down
10 changes: 5 additions & 5 deletions e2cnn/nn/modules/pooling/pointwise_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(self,
"""

if dilation != 1:
raise NotImplementedError("Diltation larger than 1 is not supported yet")
raise NotImplementedError("Dilation larger than 1 is not supported yet")

super(PointwiseMaxPoolAntialiased, self).__init__(in_type, kernel_size, stride, padding, dilation, ceil_mode)

Expand All @@ -197,15 +197,15 @@ def __init__(self,
r = -torch.sum((grid - mean) ** 2., dim=-1, dtype=torch.get_default_dtype())

# Build the gaussian kernel
filter = torch.exp(r / (2 * variance))
_filter = torch.exp(r / (2 * variance))

# Normalize
filter /= torch.sum(filter)
_filter /= torch.sum(_filter)

# The filter needs to be reshaped to be used in 2d depthwise convolution
filter = filter.view(1, 1, filter_size, filter_size).repeat((in_type.size, 1, 1, 1))
_filter = _filter.view(1, 1, filter_size, filter_size).repeat((in_type.size, 1, 1, 1))

self.register_buffer('filter', filter)
self.register_buffer('filter', _filter)
self._pad = tuple(p + int((filter_size-1)//2) for p in self.padding)

def forward(self, input: GeometricTensor) -> GeometricTensor:
Expand Down
34 changes: 19 additions & 15 deletions e2cnn/nn/modules/r2_conv/basisexpansion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def __init__(self,
except EmptyBasisException:
# print(f"Empty basis at {reprs_names}")
pass

if len(_block_expansion_modules) == 0:
print('WARNING! The basis for the block expansion of the filter is empty!')

self._n_pairs = len(in_type._unique_representations) * len(out_type._unique_representations)

Expand Down Expand Up @@ -295,17 +298,18 @@ def _expand_block(self, weights, io_pair):
coefficients = coefficients.view(-1, block_expansion.dimension())

# expand the current subset of basis vectors and set the result in the appropriate place in the filter
filter = block_expansion(coefficients)
k, o, i, p = filter.shape
_filter = block_expansion(coefficients)
k, o, i, p = _filter.shape

filter = filter.view(self._out_count[io_pair[1]],
self._in_count[io_pair[0]],
o,
i,
self.S,
)
filter = filter.transpose(1, 2)
return filter
_filter = _filter.view(
self._out_count[io_pair[1]],
self._in_count[io_pair[0]],
o,
i,
self.S,
)
_filter = _filter.transpose(1, 2)
return _filter

def forward(self, weights: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -327,12 +331,12 @@ def forward(self, weights: torch.Tensor) -> torch.Tensor:
io_pair = self._representations_pairs[0]
in_indices = getattr(self, f"in_indices_{io_pair}")
out_indices = getattr(self, f"out_indices_{io_pair}")
filter = self._expand_block(weights, io_pair).reshape(out_indices[2], in_indices[2], self.S)
_filter = self._expand_block(weights, io_pair).reshape(out_indices[2], in_indices[2], self.S)

else:

# build the tensor which will contain the filter
filter = torch.zeros(self._output_size, self._input_size, self.S, device=weights.device)
_filter = torch.zeros(self._output_size, self._input_size, self.S, device=weights.device)

# iterate through all input-output field representations pairs
for io_pair in self._representations_pairs:
Expand All @@ -345,20 +349,20 @@ def forward(self, weights: torch.Tensor) -> torch.Tensor:
expanded = self._expand_block(weights, io_pair)

if self._contiguous[io_pair]:
filter[
_filter[
out_indices[0]:out_indices[1],
in_indices[0]:in_indices[1],
:,
] = expanded.reshape(out_indices[2], in_indices[2], self.S)
else:
filter[
_filter[
out_indices,
in_indices,
:,
] = expanded.reshape(-1, self.S)

# return the new filter
return filter
return _filter


def _retrieve_indices(type: FieldType):
Expand Down
60 changes: 30 additions & 30 deletions e2cnn/nn/modules/r2_conv/r2_transposed_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,36 +189,36 @@ def basisexpansion(self) -> BasisExpansion:
return self._basisexpansion

def expand_parameters(self) -> Tuple[torch.Tensor, torch.Tensor]:
filter = self.basisexpansion(self.weights)
filter = filter.reshape(filter.shape[0], filter.shape[1], self.kernel_size, self.kernel_size)
filter = filter.transpose(0, 1)
_filter = self.basisexpansion(self.weights)
_filter = _filter.reshape(_filter.shape[0], _filter.shape[1], self.kernel_size, self.kernel_size)
_filter = _filter.transpose(0, 1)

if self.bias is None:
bias = None
_bias = None
else:
bias = self.bias_expansion @ self.bias
_bias = self.bias_expansion @ self.bias

return filter, bias
return _filter, _bias

def forward(self, input: GeometricTensor):
assert input.type == self.in_type

if not self.training:
filter = self.filter
bias = self.expanded_bias
_filter = self.filter
_bias = self.expanded_bias
else:
# retrieve the filter and the bias
filter, bias = self.expand_parameters()
_filter, _bias = self.expand_parameters()

# use it for convolution and return the result
output = conv_transpose2d(
input.tensor, filter,
input.tensor, _filter,
padding=self.padding,
output_padding=self.output_padding,
stride=self.stride,
dilation=self.dilation,
groups=self.groups,
bias=bias)
bias=_bias)

return GeometricTensor(output, self.out_type)

Expand All @@ -233,11 +233,11 @@ def train(self, mode=True):
elif self.training:
# avoid re-computation of the filter and the bias on multiple consecutive calls of `.eval()`

filter, bias = self.expand_parameters()
_filter, _bias = self.expand_parameters()

self.register_buffer("filter", filter)
if bias is not None:
self.register_buffer("expanded_bias", bias)
self.register_buffer("filter", _filter)
if _bias is not None:
self.register_buffer("expanded_bias", _bias)
else:
self.expanded_bias = None

Expand Down Expand Up @@ -271,9 +271,9 @@ def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1, assertion: bo
import matplotlib.image as mpimg
from skimage.measure import block_reduce
from skimage.transform import resize
x = mpimg.imread('../test/group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:c, :, :]

x = mpimg.imread('../group/testimage.jpeg').transpose((2, 0, 1))[np.newaxis, 0:c, :, :]

x = resize(
x,
(x.shape[0], x.shape[1], initial_size, initial_size),
Expand Down Expand Up @@ -359,8 +359,8 @@ def export(self):
# set to eval mode so the filter and the bias are updated with the current
# values of the weights
self.eval()
filter = self.filter
bias = self.expanded_bias
_filter = self.filter
_bias = self.expanded_bias

# build the PyTorch Conv2d module
has_bias = self.bias is not None
Expand All @@ -374,9 +374,9 @@ def export(self):
bias=has_bias)

# set the filter and the bias
conv.weight.data[:] = filter.data
conv.weight.data[:] = _filter.data
if has_bias:
conv.bias.data[:] = bias.data
conv.bias.data[:] = _bias.data

return conv

Expand Down Expand Up @@ -434,10 +434,10 @@ def bandlimiting_filter(frequency_cutoff: Union[float, Callable[[float], float]]
if isinstance(frequency_cutoff, float):
frequency_cutoff = lambda r, fco=frequency_cutoff: r * frequency_cutoff

def filter(attributes: dict) -> bool:
def bl_filter(attributes: dict) -> bool:
return math.fabs(attributes["frequency"]) <= frequency_cutoff(attributes["radius"])

return filter
return bl_filter


def get_grid_coords(kernel_size: int, dilation: int = 1):
Expand Down Expand Up @@ -546,11 +546,11 @@ def _manual_fco3(max_radius: float) -> Callable[[float], float]:
"""

def filter(r: float) -> float:
def bl_filter(r: float) -> float:
max_freq = 0 if r == 0. else 1 if r == max_radius else 2
return max_freq

return filter
return bl_filter


def _manual_fco2(max_radius: float) -> Callable[[float], float]:
Expand All @@ -567,11 +567,11 @@ def _manual_fco2(max_radius: float) -> Callable[[float], float]:
"""

def filter(r: float) -> float:
def bl_filter(r: float) -> float:
max_freq = 0 if r == 0. else min(2 * r, 1 if r == max_radius else 2 * r - (r + 1) % 2)
return max_freq

return filter
return bl_filter


def _manual_fco1(max_radius: float) -> Callable[[float], float]:
Expand All @@ -588,11 +588,11 @@ def _manual_fco1(max_radius: float) -> Callable[[float], float]:
"""

def filter(r: float) -> float:
def bl_filter(r: float) -> float:
max_freq = 0 if r == 0. else min(2 * r, 2 if r == max_radius else 2 * r - (r + 1) % 2)
return max_freq

return filter
return bl_filter



Loading

0 comments on commit 274a9c6

Please sign in to comment.