zaidmehdi commited on
Commit
9365c1c
·
1 Parent(s): ff49aa5

output 3 predictions and their probabilities

Browse files
Files changed (1) hide show
  1. src/main.py +8 -3
src/main.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import pickle
3
 
4
  import gradio as gr
 
5
  from transformers import AutoModel, AutoTokenizer
6
 
7
  from .utils import extract_hidden_state
@@ -35,9 +36,13 @@ language_model = AutoModel.from_pretrained(model_name)
35
 
36
  def classify_arabic_dialect(text):
37
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
38
- predicted_class = model.predict(text_embeddings)[0]
39
-
40
- return predicted_class
 
 
 
 
41
 
42
 
43
  with gr.Blocks() as demo:
 
2
  import pickle
3
 
4
  import gradio as gr
5
+ import numpy as np
6
  from transformers import AutoModel, AutoTokenizer
7
 
8
  from .utils import extract_hidden_state
 
36
 
37
  def classify_arabic_dialect(text):
38
  text_embeddings = extract_hidden_state(text, tokenizer, language_model)
39
+ probabilities = model.predict_proba(text_embeddings)[0]
40
+ top_three_indices = np.argsort(-probabilities)[:3]
41
+
42
+ top_three_labels = model.classes_[top_three_indices]
43
+ top_three_probabilities = probabilities[top_three_indices]
44
+
45
+ return top_three_labels, top_three_probabilities
46
 
47
 
48
  with gr.Blocks() as demo: