Skip to content

Commit ef755ab

Browse files
Merge pull request #570 from KevinMusgrave/dev
v1.7.0
2 parents ad3f3c9 + 774df78 commit ef755ab

File tree

4 files changed

+36
-11
lines changed

4 files changed

+36
-11
lines changed

.flake8

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
[flake8]
22

33
extend-ignore =
4-
E266 # too many leading '#' for block comment
5-
E203 # whitespace before ':'
6-
E402 # module level import not at top of file
7-
E501 # line too long
8-
E741 # ambiguous variable names
9-
E265 # block comment should start with #
4+
# too many leading '#' for block comment
5+
E266
6+
# whitespace before ':'
7+
E203
8+
# module level import not at top of file
9+
E402
10+
# line too long
11+
E501
12+
# ambiguous variable names
13+
E741
14+
# block comment should start with #
15+
E265
1016

1117
per-file-ignores =
1218
__init__.py:F401
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.6.3"
1+
__version__ = "1.7.0"

src/pytorch_metric_learning/losses/arcface_loss.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,19 @@ def cast_types(self, dtype, device):
2121

2222
def modify_cosine_of_target_classes(self, cosine_of_target_classes):
2323
angles = self.get_angles(cosine_of_target_classes)
24-
return torch.cos(angles + self.margin)
24+
25+
# Compute cos of (theta + margin) and cos of theta
26+
cos_theta_plus_margin = torch.cos(angles + self.margin)
27+
cos_theta = torch.cos(angles)
28+
29+
# Keep the cost function monotonically decreasing
30+
unscaled_logits = torch.where(
31+
angles <= np.deg2rad(180) - self.margin,
32+
cos_theta_plus_margin,
33+
cos_theta - self.margin * np.sin(self.margin),
34+
)
35+
36+
return unscaled_logits
2537

2638
def scale_logits(self, logits, *_):
2739
return logits * self.scale

tests/losses/test_arcface_loss.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,16 @@ def test_arcface_loss(self):
3636

3737
for i, c in enumerate(labels):
3838
acos = torch.acos(torch.clamp(logits[i, c], -1, 1))
39-
logits[i, c] = torch.cos(
40-
acos + torch.tensor(np.radians(margin), dtype=dtype).to(TEST_DEVICE)
41-
)
39+
if acos <= (np.pi - np.radians(margin)):
40+
logits[i, c] = torch.cos(
41+
acos
42+
+ torch.tensor(np.radians(margin), dtype=dtype).to(TEST_DEVICE)
43+
)
44+
else:
45+
mg = np.radians(margin)
46+
logits[i, c] -= torch.tensor(mg * np.sin(mg), dtype=dtype).to(
47+
TEST_DEVICE
48+
)
4249

4350
correct_loss = F.cross_entropy(logits * scale, labels.to(TEST_DEVICE))
4451

0 commit comments

Comments
 (0)