ActNorm implementation missing division by std on the shift parameter #6

@hemildesai

Hi,

Thanks for making the video lectures and homework public. I'm really enjoying the course so far. I was going through homework 2 and wanted to compare my work with the solutions. In the hw2 solution, I found the following implementation of ActNorm:

class ActNorm(nn.Module):
    def __init__(self, n_channels):
        super(ActNorm, self).__init__()
        self.log_scale = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.shift = nn.Parameter(torch.zeros(1, n_channels, 1, 1), requires_grad=True)
        self.n_channels = n_channels
        self.initialized = False

    def forward(self, x, reverse=False):
        if reverse:
            return (x - self.shift) * torch.exp(-self.log_scale), self.log_scale
        else:
            if not self.initialized:
                self.shift.data = -torch.mean(x, dim=[0, 2, 3], keepdim=True)
                self.log_scale.data = - torch.log(
                    torch.std(x.permute(1, 0, 2, 3).reshape(self.n_channels, -1), dim=1).reshape(1, self.n_channels, 1,
                                                                                                 1))
                self.initialized = True
                result = x * torch.exp(self.log_scale) + self.shift
            return x * torch.exp(self.log_scale) + self.shift, self.log_scale

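I think the shift needs to be divided by the standard deviation for the activations to be normalized. The forward pass computes x * exp(log_scale) + shift, so with exp(log_scale) = 1/std and shift = -mean, the per-channel output mean is mean/std - mean, which is zero only when std is 1. The shift should instead be set as follows, after log_scale has been initialized:

self.shift.data = -(torch.mean(x, dim=[0, 2, 3], keepdim=True) * torch.exp(self.log_scale))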
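
To check this, here is a minimal sketch of the layer with the shift scaled by 1/std, followed by a quick test on random data. The class name ActNormFixed and the test tensor are my own for illustration and are not from the course solutions. I also moved the log_scale initialization ahead of the shift, since the shift now depends on it.

```python
import torch
import torch.nn as nn


class ActNormFixed(nn.Module):
    """ActNorm with data-dependent init (hypothetical name, mirrors the class above)."""

    def __init__(self, n_channels):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(1, n_channels, 1, 1))
        self.shift = nn.Parameter(torch.zeros(1, n_channels, 1, 1))
        self.initialized = False

    def forward(self, x, reverse=False):
        if reverse:
            return (x - self.shift) * torch.exp(-self.log_scale), self.log_scale
        if not self.initialized:
            with torch.no_grad():
                # Per-channel statistics over batch, height, and width.
                mean = x.mean(dim=[0, 2, 3], keepdim=True)
                std = x.std(dim=[0, 2, 3], keepdim=True)
                # Set the scale first: the shift below depends on it.
                self.log_scale.copy_(-torch.log(std))
                # shift = -mean / std, so x / std + shift has zero mean per channel.
                self.shift.copy_(-mean * torch.exp(self.log_scale))
            self.initialized = True
        return x * torch.exp(self.log_scale) + self.shift, self.log_scale


# Assumed test input: a batch with non-zero mean and non-unit std.
torch.manual_seed(0)
x = torch.randn(64, 3, 8, 8) * 5.0 + 2.0

layer = ActNormFixed(3)
y, _ = layer(x)

channel_mean = y.mean(dim=[0, 2, 3])  # expected to be close to 0 per channel
channel_std = y.std(dim=[0, 2, 3])    # expected to be close to 1 per channel

# The reverse pass should recover the input.
x_rec, _ = layer(y, reverse=True)
```

With the original initialization (shift = -mean), the same check would give a per-channel mean of roughly mean/std - mean instead of 0.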

Let me know if I'm missing something.
