ferdmartin commited on
Commit
3c0afc4
1 Parent(s): 9dc5a58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -67,7 +67,7 @@ def main():
67
  logits = model(input_ids=input_ids, attention_mask=attention_mask)
68
  _, prediction = torch.max(logits, 1)
69
  prediction = prediction.item()
70
- predict_proba = round(torch.softmax(logits, 1).numpy().squeeze()[prediction].item(),4)
71
  return prediction, predict_proba
72
 
73
  def pred_str(prediction):
 
67
  logits = model(input_ids=input_ids, attention_mask=attention_mask)
68
  _, prediction = torch.max(logits, 1)
69
  prediction = prediction.item()
70
+ predict_proba = round(torch.softmax(logits, 1).cpu().squeeze().tolist()[prediction],4)
71
  return prediction, predict_proba
72
 
73
  def pred_str(prediction):