Spaces:
Runtime error
Runtime error
ferdmartin
commited on
Commit
•
3c0afc4
1
Parent(s):
9dc5a58
Update app.py
Browse files
app.py
CHANGED
@@ -67,7 +67,7 @@ def main():
|
|
67 |
logits = model(input_ids=input_ids, attention_mask=attention_mask)
|
68 |
_, prediction = torch.max(logits, 1)
|
69 |
prediction = prediction.item()
|
70 |
-
predict_proba = round(torch.softmax(logits, 1).
|
71 |
return prediction, predict_proba
|
72 |
|
73 |
def pred_str(prediction):
|
|
|
67 |
logits = model(input_ids=input_ids, attention_mask=attention_mask)
|
68 |
_, prediction = torch.max(logits, 1)
|
69 |
prediction = prediction.item()
|
70 |
+
predict_proba = round(torch.softmax(logits, 1).cpu().squeeze().tolist()[prediction],4)
|
71 |
return prediction, predict_proba
|
72 |
|
73 |
def pred_str(prediction):
|