-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathsinkhorn.py
93 lines (73 loc) · 2.97 KB
/
sinkhorn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
Source: https://gist.github.com/wohlert/8589045ab544082560cc5f8915cc90bd
"""
import torch
import torch.nn as nn
from pdb import set_trace as bp
class SinkhornSolver(nn.Module):
    """
    Optimal Transport solver under entropic regularisation.
    Based on the code of Gabriel Peyré.

    The solver runs Sinkhorn iterations in the log domain between two point
    clouds carrying uniform (empirical) weights and returns the transport
    cost ``sum(pi * C)`` of the resulting plan.

    Inputs are ``(n, d)`` / ``(m, d)`` point sets, optionally with a leading
    batch dimension ``(B, n, d)`` / ``(B, m, d)``.

    Args:
        epsilon: Entropic regularisation strength (scales the kernel).
        iterations: Maximum number of Sinkhorn iterations.
        ground_metric: Element-wise function applied to the coordinate
            differences ``x_i - y_j``; its output is summed over the last
            dimension to build the cost matrix.
        threshold: The iterations stop early once the batch-averaged L1
            change of the dual potentials drops below this value.
    """

    def __init__(self, epsilon, iterations=100,
                 ground_metric=lambda x: torch.pow(x, 2), threshold=1e-1):
        super(SinkhornSolver, self).__init__()
        self.epsilon = epsilon
        self.iterations = iterations
        self.ground_metric = ground_metric
        self.threshold = threshold

    def sinkhorn_loss(self, x, y):
        """Return the entropic OT cost between point sets ``x`` and ``y``.

        Returns a scalar tensor for unbatched input and a tensor of shape
        ``(B,)`` for batched input.
        """
        num_x = x.size(-2)
        num_y = y.size(-2)

        # Marginal densities are empirical (uniform) measures: one weight per
        # point. Shaping them like the inputs minus the feature dimension
        # keeps a leading batch dimension intact, even when it has size 1.
        a = x.new_ones(x.shape[:-1]) / num_x
        b = y.new_ones(y.shape[:-1]) / num_y

        # The log-marginals do not change across iterations.
        log_a = torch.log(a + 1e-8)
        log_b = torch.log(b + 1e-8)

        # Initialise the dual potentials (log domain).
        u = torch.zeros_like(a)
        v = torch.zeros_like(b)

        # Cost matrix, shape (..., num_x, num_y).
        C = self._compute_cost(x, y)

        # Sinkhorn iterations
        for _ in range(self.iterations):
            u0, v0 = u, v

            # u^{l+1} = a / (K v^l): reduce over the y index (last axis).
            K = self._log_boltzmann_kernel(u, v, C)
            u = self.epsilon * (log_a - torch.logsumexp(K, dim=-1)) + u

            # v^{l+1} = b / (K^T u^{l+1}): reduce over the x index.
            K = self._log_boltzmann_kernel(u, v, C)
            v = self.epsilon * (log_b - torch.logsumexp(K, dim=-2)) + v

            # Size of the change performed on the potentials.
            diff = (torch.sum(torch.abs(u - u0), dim=-1)
                    + torch.sum(torch.abs(v - v0), dim=-1))
            if torch.mean(diff).item() < self.threshold:
                break

        # Transport plan in the log domain, exponentiated.
        pi = torch.exp(self._log_boltzmann_kernel(u, v, C))

        # Sinkhorn distance: total cost of moving mass along the plan.
        return torch.sum(pi * C, dim=(-2, -1))

    def sinkhorn_normalized(self, x, y):
        """Return ``2 * W(x, y) - W(x, x) - W(y, y)``."""
        Wxy = self.sinkhorn_loss(x, y)
        Wxx = self.sinkhorn_loss(x, x)
        Wyy = self.sinkhorn_loss(y, y)
        return 2 * Wxy - Wxx - Wyy

    def forward(self, x, y):
        """Return the (un-normalised) Sinkhorn loss between ``x`` and ``y``."""
        # Alternative: self.sinkhorn_normalized(x, y)
        return self.sinkhorn_loss(x, y)

    def _compute_cost(self, x, y):
        """Return the pairwise cost matrix of shape ``(..., num_x, num_y)``."""
        # Broadcast x as (..., n, 1, d) against y as (..., 1, m, d).
        x_ = x.unsqueeze(-2)
        y_ = y.unsqueeze(-3)
        return torch.sum(self.ground_metric(x_ - y_), dim=-1)

    def _log_boltzmann_kernel(self, u, v, C=None, x=None, y=None):
        """Return ``(-C + u_i + v_j) / epsilon``.

        Args:
            u: Potentials for the x points, shape ``(..., num_x)``.
            v: Potentials for the y points, shape ``(..., num_y)``.
            C: Precomputed cost matrix. When omitted, it is computed from
                ``x`` and ``y``.
            x, y: Point sets used to build the cost when ``C`` is None.

        Raises:
            ValueError: If ``C`` is None and ``x`` or ``y`` is missing.
        """
        if C is None:
            if x is None or y is None:
                raise ValueError(
                    "Either C or both x and y must be provided"
                )
            C = self._compute_cost(x, y)
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.epsilon