yangwang825
committed on
Commit
•
5b7e94f
1
Parent(s):
ae0f9e7
Update modeling_wavlm_spkreg.py
Browse files — modeling_wavlm_spkreg.py: +18 −4
modeling_wavlm_spkreg.py
CHANGED
@@ -456,7 +456,7 @@ class AAMSoftmaxLoss(nn.Module):
|
|
456 |
def __init__(
|
457 |
self,
|
458 |
scale: float = 30.0,
|
459 |
-
margin: float = 0.
|
460 |
easy_margin: bool = False,
|
461 |
label_smoothing: float = 0.0,
|
462 |
reduction: str = "mean"
|
@@ -489,9 +489,23 @@ class AAMSoftmaxLoss(nn.Module):
|
|
489 |
"""
|
490 |
_, num_labels = inputs.shape
|
491 |
# `inputs` are the outputs from AngularLinear()
|
492 |
-
|
493 |
-
theta = torch.acos(cos_theta)
|
494 |
-
psi = torch.cos(theta + self.margin)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
495 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
496 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
497 |
loss = F.cross_entropy(
|
|
|
456 |
def __init__(
|
457 |
self,
|
458 |
scale: float = 30.0,
|
459 |
+
margin: float = 0.2,
|
460 |
easy_margin: bool = False,
|
461 |
label_smoothing: float = 0.0,
|
462 |
reduction: str = "mean"
|
|
|
489 |
"""
|
490 |
_, num_labels = inputs.shape
|
491 |
# `inputs` are the outputs from AngularLinear()
|
492 |
+
epsilon = 1e-6
|
493 |
+
# theta = torch.acos(cos_theta)
|
494 |
+
# psi = torch.cos(theta + self.margin)
|
495 |
+
cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
|
496 |
+
sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
|
497 |
+
sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
|
498 |
+
|
499 |
+
cos_m = math.cos(self.margin)
|
500 |
+
sin_m = math.sin(self.margin)
|
501 |
+
psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
|
502 |
+
|
503 |
+
if self.easy_margin:
|
504 |
+
psi = torch.where(cos_theta > 0, psi, cos_theta)
|
505 |
+
else:
|
506 |
+
# Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
|
507 |
+
psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
|
508 |
+
|
509 |
one_hot = nn.functional.one_hot(targets, num_labels)
|
510 |
outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
|
511 |
loss = F.cross_entropy(
|