Update DebertaV3.cs
Browse files- DebertaV3.cs +7 -1
DebertaV3.cs
CHANGED
@@ -34,6 +34,12 @@ public sealed class DebertaV3 : MonoBehaviour
|
|
34 |
Model loadedModel = ModelLoader.Load(model);
|
35 |
engine = WorkerFactory.CreateWorker(backend, loadedModel);
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
|
38 |
Batch batch = GetTokenizedBatch(text, hypotheses);
|
39 |
float[] scores = GetBatchScores(batch);
|
@@ -106,7 +112,7 @@ public sealed class DebertaV3 : MonoBehaviour
|
|
106 |
// To obtain a single value (score) per example, a softmax function is applied
|
107 |
|
108 |
TensorFloat tensorScores;
|
109 |
-
if (multipleTrueClasses || logits.shape.
|
110 |
{
|
111 |
// Softmax over the entailment vs. contradiction dimension for each label independently
|
112 |
tensorScores = ops.Softmax(logits, -1);
|
|
|
34 |
Model loadedModel = ModelLoader.Load(model);
|
35 |
engine = WorkerFactory.CreateWorker(backend, loadedModel);
|
36 |
|
37 |
+
if (classes.Length == 0)
|
38 |
+
{
|
39 |
+
Debug.LogError("There need to be more than 0 classes");
|
40 |
+
return;
|
41 |
+
}
|
42 |
+
|
43 |
string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
|
44 |
Batch batch = GetTokenizedBatch(text, hypotheses);
|
45 |
float[] scores = GetBatchScores(batch);
|
|
|
112 |
// To obtain a single value (score) per example, a softmax function is applied
|
113 |
|
114 |
TensorFloat tensorScores;
|
115 |
+
if (multipleTrueClasses || logits.shape.Length(0, 1) == 1)
|
116 |
{
|
117 |
// Softmax over the entailment vs. contradiction dimension for each label independently
|
118 |
tensorScores = ops.Softmax(logits, -1);
|