[WIP] Implement SSL-EY #1443

Open: wants to merge 26 commits into base: master

Changes from 1 commit
Adding SSL-EY to tests
jameschapman19 committed Dec 8, 2023

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 6dbb3014eb1c044a2c3b5f37d12d333b1b0aaac2
74 changes: 6 additions & 68 deletions lightly/loss/ssley_loss.py
@@ -47,9 +47,6 @@ def __init__(
                 "distributed support."
             )
 
-        self.lambda_param = lambda_param
-        self.mu_param = mu_param
-        self.nu_param = nu_param
         self.gather_distributed = gather_distributed
         self.eps = eps
@@ -62,15 +59,6 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
             z_b:
                 Tensor with shape (batch_size, ..., dim).
         """
-        assert (
-            z_a.shape[0] > 1 and z_b.shape[0] > 1
-        ), f"z_a and z_b must have batch size > 1 but found {z_a.shape[0]} and {z_b.shape[0]}"
-        assert (
-            z_a.shape == z_b.shape
-        ), f"z_a and z_b must have same shape but found {z_a.shape} and {z_b.shape}."
-
-        # invariance term of the loss
-        inv_loss = invariance_loss(x=z_a, y=z_b)
 
         # gather all batches
         if self.gather_distributed and dist.is_initialized():
@@ -79,62 +67,12 @@ def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
                 z_a = torch.cat(gather(z_a), dim=0)
                 z_b = torch.cat(gather(z_b), dim=0)
 
-        var_loss = 0.5 * (
-            variance_loss(x=z_a, eps=self.eps) + variance_loss(x=z_b, eps=self.eps)
-        )
-        cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)
-
-        loss = (
-            self.lambda_param * inv_loss
-            + self.mu_param * var_loss
-            + self.nu_param * cov_loss
-        )
-        return loss
-
-
-def invariance_loss(x: Tensor, y: Tensor) -> Tensor:
-    """Returns SSL-EY invariance loss.
-
-    Args:
-        x:
-            Tensor with shape (batch_size, ..., dim).
-        y:
-            Tensor with shape (batch_size, ..., dim).
-    """
-    return F.mse_loss(x, y)
-
+        z_a = z_a - z_a.mean(dim=0)
+        z_b = z_b - z_b.mean(dim=0)
 
-def variance_loss(x: Tensor, eps: float = 0.0001) -> Tensor:
-    """Returns SSL-EY variance loss.
-
-    Args:
-        x:
-            Tensor with shape (batch_size, ..., dim).
-        eps:
-            Epsilon for numerical stability.
-    """
-    x = x - x.mean(dim=0)
-    std = torch.sqrt(x.var(dim=0) + eps)
-    loss = torch.mean(F.relu(1.0 - std))
-    return loss
+        C = 2*(z_a.T @ z_b) / (self.args.batch_size - 1)
+        V = (z_a.T @ z_a) / (self.args.batch_size - 1) + (z_b.T @ z_b) / (self.args.batch_size - 1)
 
+        loss = torch.trace(C)-torch.trace(V@V)
 
-def covariance_loss(x: Tensor) -> Tensor:
-    """Returns SSL-EY covariance loss.
-
-    Generalized version of the covariance loss with support for tensors with more than
-    two dimensions.
-
-    Args:
-        x:
-            Tensor with shape (batch_size, ..., dim).
-    """
-    x = x - x.mean(dim=0)
-    batch_size = x.size(0)
-    dim = x.size(-1)
-    # nondiag_mask has shape (dim, dim) with 1s on all non-diagonal entries.
-    nondiag_mask = ~torch.eye(dim, device=x.device, dtype=torch.bool)
-    # cov has shape (..., dim, dim)
-    cov = torch.einsum("b...c,b...d->...cd", x, x) / (batch_size - 1)
-    loss = cov[..., nondiag_mask].pow(2).sum(-1) / dim
-    return loss.mean()
+
+        return loss
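For reference, a minimal standalone sketch of the quantity this commit computes in forward. It is an illustration of the added lines, not the PR's final implementation: the full class is not shown in the diff, so the sketch takes the batch size from the input tensors instead of the `self.args.batch_size` attribute referenced in the added code, omits the distributed gather, and uses a hypothetical function name `ssley_objective`.

    # Standalone sketch of the SSL-EY objective from the diff above (assumptions noted).
    import torch


    def ssley_objective(z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        """Eckhart-Young objective for two batches of embeddings with shape (batch_size, dim)."""
        # Assumption: batch size taken from the inputs rather than self.args.batch_size.
        batch_size = z_a.shape[0]

        # Center each view over the batch, as in the commit.
        z_a = z_a - z_a.mean(dim=0)
        z_b = z_b - z_b.mean(dim=0)

        # C: scaled cross-covariance between the two views, shape (dim, dim).
        C = 2 * (z_a.T @ z_b) / (batch_size - 1)
        # V: sum of the per-view covariance matrices, shape (dim, dim).
        V = (z_a.T @ z_a) / (batch_size - 1) + (z_b.T @ z_b) / (batch_size - 1)

        # The commit returns trace(C) - trace(V @ V). This appears to be the objective
        # to maximize; a training loop that minimizes a loss would typically negate it.
        return torch.trace(C) - torch.trace(V @ V)


    # Hypothetical usage with random embeddings standing in for two augmented views:
    if __name__ == "__main__":
        z_a, z_b = torch.randn(256, 128), torch.randn(256, 128)
        print(ssley_objective(z_a, z_b))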