kotstantinovskii commited on
Commit
59a4514
·
1 Parent(s): 407f9c2

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -5
model.py CHANGED
@@ -21,17 +21,14 @@ class ArxivModel:
21
  def get_idx_class(self, tweet_text, thr=-1.0):
22
  logits = self.get_logits(tweet_text)
23
 
24
- print(logits)
25
-
26
  if thr == -1.0:
27
- return [np.argmax(logits)]
28
  else:
29
  sum_probs = 0.0
30
  idxs = []
31
  for p in np.argsort(logits)[::-1]:
32
  sum_probs += logits[p]
33
- idxs.append(p)
34
-
35
  if sum_probs > thr:
36
  return idxs
37
 
 
21
  def get_idx_class(self, tweet_text, thr=-1.0):
22
  logits = self.get_logits(tweet_text)
23
 
 
 
24
  if thr == -1.0:
25
+ return [(np.argmax(logits), np.max(logits))]
26
  else:
27
  sum_probs = 0.0
28
  idxs = []
29
  for p in np.argsort(logits)[::-1]:
30
  sum_probs += logits[p]
31
+ idxs.append((p, logits[p]))
 
32
  if sum_probs > thr:
33
  return idxs
34