zaidmehdi committed
Commit a025587 · 1 Parent(s): 1964ece

integrating new finetuned model

.gitattributes ADDED
@@ -0,0 +1 @@
+models/best_model_checkpoint.pth filter=lfs diff=lfs merge=lfs -text
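
The filter/diff/merge attributes route the new checkpoint through Git LFS, so the repository itself stores only a small pointer file (shown next) while the actual weights live in LFS storage; this line is exactly what `git lfs track` writes. For scripted access, the file can also be fetched through huggingface_hub. A minimal sketch, where the repo id is an assumption for illustration (this page does not name the repository):

from huggingface_hub import hf_hub_download

# Fetch the LFS-backed checkpoint without cloning the whole repo.
# repo_id is hypothetical -- the commit page does not state it.
checkpoint_path = hf_hub_download(
    repo_id="zaidmehdi/arabic-dialect-classifier",
    filename="models/best_model_checkpoint.pth",
    repo_type="space",
)
print(checkpoint_path)
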
models/best_model_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:393c58dcfc7d9f44298ae0defd11abfd18d49522c2997477c84f8f7d34e3a628
+size 559417326
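
This is the LFS pointer that gets committed in place of the binary: a spec version, the sha256 of the real file, and its size (about 559 MB). After pulling, the downloaded file can be checked against the pointer; a small sketch using the oid and size above:

import hashlib
import os

# Values copied from the LFS pointer in this commit
EXPECTED_OID = "393c58dcfc7d9f44298ae0defd11abfd18d49522c2997477c84f8f7d34e3a628"
EXPECTED_SIZE = 559417326

def verify_lfs_file(path):
    """Return True if the file matches the pointer's size and sha256 oid."""
    if os.path.getsize(path) != EXPECTED_SIZE:
        return False
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
            sha.update(chunk)
    return sha.hexdigest() == EXPECTED_OID

print(verify_lfs_file("models/best_model_checkpoint.pth"))
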
models/logistic_regression.pkl DELETED
Binary file (132 kB)
 
src/main.py CHANGED
@@ -1,22 +1,31 @@
 import os
-import pickle
 
 import gradio as gr
-from transformers import AutoModel, AutoTokenizer
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
-from .utils import extract_hidden_state
+from .utils import load_data
 
 
 # Load model
-models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
-model_file = os.path.join(models_dir, 'logistic_regression.pkl')
-
+model_name = "moussaKam/AraBART"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=21)
+
+models_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
+model_file = os.path.join(models_dir, 'best_model_checkpoint.pth')
 if os.path.exists(model_file):
     with open(model_file, "rb") as f:
-        model = pickle.load(f)
+        checkpoint = torch.load(model_file)
+        model.load_state_dict(checkpoint)
 else:
     print(f"Error: {model_file} not found.")
 
+# Load label encoder
+encoder_file = os.path.join(models_dir, 'label_encoder.pkl')
+label_encoder = load_data(encoder_file)
+
 # Load html
 html_dir = os.path.join(os.path.dirname(__file__), "templates")
 index_html_path = os.path.join(html_dir, "index.html")
@@ -27,20 +36,17 @@ if os.path.exists(index_html_path):
 else:
     print(f"Error: {index_html_path} not found.")
 
-# Load pre-trained model
-model_name = "moussaKam/AraBART"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-language_model = AutoModel.from_pretrained(model_name)
-
 
 def classify_arabic_dialect(text):
-    text_embeddings = extract_hidden_state(text, tokenizer, language_model)
-    probabilities = model.predict_proba(text_embeddings)[0]
-    labels = model.classes_
+    tokenized_text = tokenizer(text, return_tensors="pt")
+    output = model(**tokenized_text)
+    probabilities = F.softmax(output.logits, dim=1)[0]
+    labels = label_encoder.inverse_transform(range(len(probabilities)))
     predictions = {labels[i]: probabilities[i] for i in range(len(probabilities))}
 
     return predictions
 
+
 def main():
     with gr.Blocks() as demo:
         gr.HTML(index_html)
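
Two notes on the new version of src/main.py. First, load_data is imported from .utils but is not defined anywhere in this commit; given that it is used to read models/label_encoder.pkl, it is presumably a plain pickle loader along these lines (hypothetical sketch, not the actual helper):

import pickle

# Hypothetical stand-in for src/utils.py:load_data; the real helper is not in this diff.
def load_data(path):
    with open(path, "rb") as f:
        return pickle.load(f)
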
 
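Second, as written the handler returns 0-dim tensors as dictionary values, and torch.load with no map_location will try to restore tensors to the device they were saved on. A hedged hardening sketch reusing the names from the diff: map the checkpoint to CPU, switch the model to eval mode, disable gradient tracking during the forward pass, and convert probabilities to plain floats for gr.Label:

import torch
import torch.nn.functional as F

# Sketch only -- assumes model, tokenizer, label_encoder, and model_file
# are the objects defined in the new src/main.py above.
checkpoint = torch.load(model_file, map_location="cpu")  # safe on CPU-only hardware
model.load_state_dict(checkpoint)
model.eval()  # disable dropout for deterministic inference

def classify_arabic_dialect(text):
    tokenized_text = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # no gradients needed at serving time
        output = model(**tokenized_text)
    probabilities = F.softmax(output.logits, dim=1)[0]
    labels = label_encoder.inverse_transform(range(len(probabilities)))
    # .item() turns each 0-dim tensor into a plain float
    return {labels[i]: probabilities[i].item() for i in range(len(probabilities))}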