-
Notifications
You must be signed in to change notification settings - Fork 169
/
efdmix.py
118 lines (90 loc) · 3.08 KB
/
efdmix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import random
from contextlib import contextmanager
import torch
import torch.nn as nn
def deactivate_efdmix(m):
    """Switch off mixing on ``m`` if it is an ``EFDMix`` layer.

    Used in this file as a callback for ``model.apply(...)``; modules of any
    other type are left untouched.

    Args:
        m: the module to inspect.
    """
    # isinstance (rather than an exact type comparison) also covers
    # subclasses of EFDMix.
    if isinstance(m, EFDMix):
        m.set_activation_status(False)
def activate_efdmix(m):
    """Switch on mixing on ``m`` if it is an ``EFDMix`` layer.

    Used in this file as a callback for ``model.apply(...)``; modules of any
    other type are left untouched.

    Args:
        m: the module to inspect.
    """
    # isinstance (rather than an exact type comparison) also covers
    # subclasses of EFDMix.
    if isinstance(m, EFDMix):
        m.set_activation_status(True)
def random_efdmix(m):
    """Set the mix method of ``m`` to ``"random"`` if it is an ``EFDMix`` layer.

    Used in this file as a callback for ``model.apply(...)``; modules of any
    other type are left untouched.

    Args:
        m: the module to inspect.
    """
    # isinstance (rather than an exact type comparison) also covers
    # subclasses of EFDMix.
    if isinstance(m, EFDMix):
        m.update_mix_method("random")
def crossdomain_efdmix(m):
    """Set the mix method of ``m`` to ``"crossdomain"`` if it is an ``EFDMix`` layer.

    Used in this file as a callback for ``model.apply(...)``; modules of any
    other type are left untouched.

    Args:
        m: the module to inspect.
    """
    # isinstance (rather than an exact type comparison) also covers
    # subclasses of EFDMix.
    if isinstance(m, EFDMix):
        m.update_mix_method("crossdomain")
@contextmanager
def run_without_efdmix(model):
    """Context manager that runs the enclosed code with EFDMix switched off.

    On entry ``deactivate_efdmix`` is applied to ``model``; on exit (normal
    or via an exception) ``activate_efdmix`` is applied.

    Note:
        The previous activation state is not recorded: EFDMix is
        unconditionally re-activated on exit, so this assumes it was
        activated beforehand.

    Args:
        model: object exposing ``apply(fn)``.
    """
    try:
        model.apply(deactivate_efdmix)
        yield
    finally:
        # Always re-enable, even when the body raised.
        model.apply(activate_efdmix)
@contextmanager
def run_with_efdmix(model, mix=None):
    """Context manager that runs the enclosed code with EFDMix switched on.

    Before activation, the mix method is optionally switched to ``mix``.
    On exit (normal or via an exception) ``deactivate_efdmix`` is applied.

    Note:
        The previous activation state is not recorded: EFDMix is
        unconditionally deactivated on exit, so this assumes it was
        deactivated beforehand. A mix method selected here is not reverted
        on exit.

    Args:
        model: object exposing ``apply(fn)``.
        mix: ``"random"`` or ``"crossdomain"`` to change the mix method
            first; any other value (including the default ``None``) leaves
            the current mix method as it is.
    """
    # Pick the mix-method setter first, then apply it once.
    mix_setter = None
    if mix == "random":
        mix_setter = random_efdmix
    elif mix == "crossdomain":
        mix_setter = crossdomain_efdmix

    if mix_setter is not None:
        model.apply(mix_setter)

    try:
        model.apply(activate_efdmix)
        yield
    finally:
        # Always switch back off, even when the body raised.
        model.apply(deactivate_efdmix)
class EFDMix(nn.Module):
    """EFDMix.

    In training mode, while activated, and with probability ``p``, every
    (sample, channel) feature vector is interpolated with the sorted values
    of another sample of the batch, placed at the same ranks. In every other
    case ``forward`` returns its input unchanged.

    Reference:
      Zhang et al. Exact Feature Distribution Matching for Arbitrary Style
      Transfer and Domain Generalization. CVPR 2022.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6, mix="random"):
        """
        Args:
          p (float): probability of applying EFDMix in a forward pass.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues
            (stored on the instance; not read by ``forward``).
          mix (str): how to pair samples, "random" or "crossdomain".
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha
        self.mix = mix
        self._activated = True

    def __repr__(self):
        # Report the actual class name (this previously printed "MixStyle").
        return (
            f"{type(self).__name__}(p={self.p}, alpha={self.alpha}, "
            f"eps={self.eps}, mix={self.mix})"
        )

    def set_activation_status(self, status=True):
        """Enable (``True``) or disable (``False``) mixing in ``forward``."""
        self._activated = status

    def update_mix_method(self, mix="random"):
        """Change how samples are paired ("random" or "crossdomain")."""
        self.mix = mix

    def forward(self, x):
        """Apply EFDMix to a batch of features.

        Args:
          x (torch.Tensor): features with at least two dimensions, batch
            first and channels second; all remaining dimensions are
            flattened for the per-channel sort.

        Returns:
          torch.Tensor: ``x`` itself when the module is in eval mode,
          deactivated, or skipped by the probability check; otherwise a new
          tensor with the same shape as ``x``.

        Raises:
          NotImplementedError: if ``self.mix`` is neither "random" nor
            "crossdomain".
        """
        if not self.training or not self._activated:
            return x

        # Apply the mixing with probability p.
        if random.random() > self.p:
            return x

        B, C = x.size(0), x.size(1)
        # reshape (not view) so that non-contiguous inputs are accepted too.
        x_view = x.reshape(B, C, -1)
        value_x, index_x = torch.sort(x_view)  # sort inputs

        # One mixing weight per sample, broadcast over channels/positions.
        lmda = self.beta.sample((B, 1, 1))
        lmda = lmda.to(x.device)

        if self.mix == "random":
            # random shuffle
            perm = torch.randperm(B)
        elif self.mix == "crossdomain":
            # split into two halves and swap the order
            perm = torch.arange(B - 1, -1, -1)  # inverse index
            # chunk() may return a single chunk (e.g. B == 1), so shuffle
            # whatever chunks come back instead of unpacking exactly two.
            halves = perm.chunk(2)
            perm = torch.cat(
                [half[torch.randperm(half.shape[0])] for half in halves], 0
            )
        else:
            raise NotImplementedError(f"Unknown mix method: {self.mix!r}")

        # argsort of the sort indices gives the rank of every element, so
        # gather() places the partner's k-th smallest value where x has its
        # k-th smallest value.
        inverse_index = index_x.argsort(-1)
        x_view_copy = value_x[perm].gather(-1, inverse_index)
        # x_view is detached inside the difference, so x keeps a unit
        # gradient through the leading x_view term.
        new_x = x_view + (x_view_copy - x_view.detach()) * (1 - lmda)
        return new_x.view(x.shape)