Spaces:
Runtime error
Runtime error
kotstantinovskii
commited on
Commit
·
59a4514
1
Parent(s):
407f9c2
Update model.py
Browse files
model.py
CHANGED
@@ -21,17 +21,14 @@ class ArxivModel:
|
|
21 |
def get_idx_class(self, tweet_text, thr=-1.0):
|
22 |
logits = self.get_logits(tweet_text)
|
23 |
|
24 |
-
print(logits)
|
25 |
-
|
26 |
if thr == -1.0:
|
27 |
-
return [np.argmax(logits)]
|
28 |
else:
|
29 |
sum_probs = 0.0
|
30 |
idxs = []
|
31 |
for p in np.argsort(logits)[::-1]:
|
32 |
sum_probs += logits[p]
|
33 |
-
idxs.append(p)
|
34 |
-
|
35 |
if sum_probs > thr:
|
36 |
return idxs
|
37 |
|
|
|
21 |
def get_idx_class(self, tweet_text, thr=-1.0):
|
22 |
logits = self.get_logits(tweet_text)
|
23 |
|
|
|
|
|
24 |
if thr == -1.0:
|
25 |
+
return [(np.argmax(logits), np.max(logits))]
|
26 |
else:
|
27 |
sum_probs = 0.0
|
28 |
idxs = []
|
29 |
for p in np.argsort(logits)[::-1]:
|
30 |
sum_probs += logits[p]
|
31 |
+
idxs.append((p, logits[p]))
|
|
|
32 |
if sum_probs > thr:
|
33 |
return idxs
|
34 |
|