yangwang825 commited on
Commit
5b7e94f
1 Parent(s): ae0f9e7

Update modeling_wavlm_spkreg.py

Browse files
Files changed (1) hide show
  1. modeling_wavlm_spkreg.py +18 -4
modeling_wavlm_spkreg.py CHANGED
@@ -456,7 +456,7 @@ class AAMSoftmaxLoss(nn.Module):
456
  def __init__(
457
  self,
458
  scale: float = 30.0,
459
- margin: float = 0.35,
460
  easy_margin: bool = False,
461
  label_smoothing: float = 0.0,
462
  reduction: str = "mean"
@@ -489,9 +489,23 @@ class AAMSoftmaxLoss(nn.Module):
489
  """
490
  _, num_labels = inputs.shape
491
  # `inputs` are the outputs from AngularLinear()
492
- cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
493
- theta = torch.acos(cos_theta)
494
- psi = torch.cos(theta + self.margin)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  one_hot = nn.functional.one_hot(targets, num_labels)
496
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
497
  loss = F.cross_entropy(
 
456
  def __init__(
457
  self,
458
  scale: float = 30.0,
459
+ margin: float = 0.2,
460
  easy_margin: bool = False,
461
  label_smoothing: float = 0.0,
462
  reduction: str = "mean"
 
489
  """
490
  _, num_labels = inputs.shape
491
  # `inputs` are the outputs from AngularLinear()
492
+ epsilon = 1e-6
493
+ # theta = torch.acos(cos_theta)
494
+ # psi = torch.cos(theta + self.margin)
495
+ cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
496
+ sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
497
+ sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)
498
+
499
+ cos_m = math.cos(self.margin)
500
+ sin_m = math.sin(self.margin)
501
+ psi = cos_theta * cos_m - sin_theta * sin_m # cos(theta + m)
502
+
503
+ if self.easy_margin:
504
+ psi = torch.where(cos_theta > 0, psi, cos_theta)
505
+ else:
506
+ # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
507
+ psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)
508
+
509
  one_hot = nn.functional.one_hot(targets, num_labels)
510
  outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
511
  loss = F.cross_entropy(