theAIguy commited on
Commit
cd29486
1 Parent(s): 8e294b0

change output name

Browse files
Files changed (1) hide show
  1. triplet_margin_loss.py +8 -5
triplet_margin_loss.py CHANGED
@@ -88,9 +88,9 @@ class TripletMarginLoss(evaluate.EvaluationModule):
88
  inputs_description=_KWARGS_DESCRIPTION,
89
  features=datasets.Features(
90
  {
91
- "anchor": datasets.Sequence(datasets.Value("float")),
92
- "positive": datasets.Sequence(datasets.Value("float")),
93
- "negative": datasets.Sequence(datasets.Value("float")),
94
  "margin": datasets.Value("float")
95
  }
96
  ),
@@ -98,13 +98,16 @@ class TripletMarginLoss(evaluate.EvaluationModule):
98
  )
99
 
100
  def _compute(self, anchor, positive, negative, margin=1.0):
 
 
101
  d_a_p_sum = 0.0
102
  d_a_n_sum = 0.0
103
  for a, p, n in zip(anchor, positive, negative):
104
  d_a_p_sum += (a - p)**2
105
  d_a_n_sum += (a - n)**2
 
106
  return {
107
- "accuracy": float(
108
- max(np.sqrt(d_a_p_sum) - np.sqrt(d_a_n_sum) + margin, 0)
109
  )
110
  }
 
88
  inputs_description=_KWARGS_DESCRIPTION,
89
  features=datasets.Features(
90
  {
91
+ "anchor": datasets.Sequence(datasets.Value("float"), id="reference"),
92
+ "positive": datasets.Sequence(datasets.Value("float"), id="sequence"),
93
+ "negative": datasets.Sequence(datasets.Value("float"), id="sequence"),
94
  "margin": datasets.Value("float")
95
  }
96
  ),
 
98
  )
99
 
100
  def _compute(self, anchor, positive, negative, margin=1.0):
101
+ if not (len(anchor) == len(positive) == len(negative)):
102
+ raise ValueError("Anchor, Positive and Negative examples must be of same length.")
103
  d_a_p_sum = 0.0
104
  d_a_n_sum = 0.0
105
  for a, p, n in zip(anchor, positive, negative):
106
  d_a_p_sum += (a - p)**2
107
  d_a_n_sum += (a - n)**2
108
+ loss = max(np.sqrt(d_a_p_sum) - np.sqrt(d_a_n_sum) + margin, 0)
109
  return {
110
+ "triplet_margin_loss": float(
111
+ loss
112
  )
113
  }