Spaces:
Sleeping
Sleeping
change output name
Browse files- triplet_margin_loss.py +8 -5
triplet_margin_loss.py
CHANGED
@@ -88,9 +88,9 @@ class TripletMarginLoss(evaluate.EvaluationModule):
|
|
88 |
inputs_description=_KWARGS_DESCRIPTION,
|
89 |
features=datasets.Features(
|
90 |
{
|
91 |
-
"anchor": datasets.Sequence(datasets.Value("float")),
|
92 |
-
"positive": datasets.Sequence(datasets.Value("float")),
|
93 |
-
"negative": datasets.Sequence(datasets.Value("float")),
|
94 |
"margin": datasets.Value("float")
|
95 |
}
|
96 |
),
|
@@ -98,13 +98,16 @@ class TripletMarginLoss(evaluate.EvaluationModule):
|
|
98 |
)
|
99 |
|
100 |
def _compute(self, anchor, positive, negative, margin=1.0):
|
|
|
|
|
101 |
d_a_p_sum = 0.0
|
102 |
d_a_n_sum = 0.0
|
103 |
for a, p, n in zip(anchor, positive, negative):
|
104 |
d_a_p_sum += (a - p)**2
|
105 |
d_a_n_sum += (a - n)**2
|
|
|
106 |
return {
|
107 |
-
"
|
108 |
-
|
109 |
)
|
110 |
}
|
|
|
88 |
inputs_description=_KWARGS_DESCRIPTION,
|
89 |
features=datasets.Features(
|
90 |
{
|
91 |
+
"anchor": datasets.Sequence(datasets.Value("float"), id="reference"),
|
92 |
+
"positive": datasets.Sequence(datasets.Value("float"), id="sequence"),
|
93 |
+
"negative": datasets.Sequence(datasets.Value("float"), id="sequence"),
|
94 |
"margin": datasets.Value("float")
|
95 |
}
|
96 |
),
|
|
|
98 |
)
|
99 |
|
100 |
def _compute(self, anchor, positive, negative, margin=1.0):
|
101 |
+
if not (len(anchor) == len(positive) == len(negative)):
|
102 |
+
raise ValueError("Anchor, Positive and Negative examples must be of same length.")
|
103 |
d_a_p_sum = 0.0
|
104 |
d_a_n_sum = 0.0
|
105 |
for a, p, n in zip(anchor, positive, negative):
|
106 |
d_a_p_sum += (a - p)**2
|
107 |
d_a_n_sum += (a - n)**2
|
108 |
+
loss = max(np.sqrt(d_a_p_sum) - np.sqrt(d_a_n_sum) + margin, 0)
|
109 |
return {
|
110 |
+
"triplet_margin_loss": float(
|
111 |
+
loss
|
112 |
)
|
113 |
}
|