karalif committed on
Commit
3411c2a
·
verified ·
1 Parent(s): 82b76c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -17
app.py CHANGED
@@ -2,17 +2,13 @@ import gradio as gr
2
  import re
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
  import torch
5
- import shap
6
- import numpy as np
7
 
8
  # Initialize your model and tokenizer here
9
  model_identifier = "karalif/myTestModel"
10
  new_model = AutoModelForSequenceClassification.from_pretrained(model_identifier)
11
  new_tokenizer = AutoTokenizer.from_pretrained(model_identifier)
12
 
13
- # SHAP Explainer Initialization
14
- explainer = shap.Explainer(new_model, new_tokenizer)
15
-
16
  def get_prediction(text):
17
  # Tokenize the input text
18
  encoding = new_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=200)
@@ -25,14 +21,9 @@ def get_prediction(text):
25
  sigmoid = torch.nn.Sigmoid()
26
  probs = sigmoid(logits.squeeze().cpu()).numpy()
27
 
28
- # Generate SHAP values
29
- shap_values = explainer([text])
30
-
31
- # Extracting top SHAP values and their corresponding tokens
32
- top_shap_values = np.abs(shap_values.values).mean(0).sum(-1)
33
- top_tokens_indices = np.argsort(-top_shap_values)[:5] # Getting indices of top 5 tokens
34
- top_tokens = [new_tokenizer.convert_ids_to_tokens(encoding['input_ids'][0][idx].item()) for idx in top_tokens_indices]
35
- top_shap_scores = top_shap_values[top_tokens_indices]
36
 
37
  # Prepare the HTML output with labels and their probabilities
38
  response = ""
@@ -43,10 +34,10 @@ def get_prediction(text):
43
  response += f"<span style='background-color:{colors[i]}; color:black;'>{label}</span>: {probs[i]*100:.1f}%<br>"
44
 
45
  influential_keywords = "INFLUENTIAL KEYWORDS:<br>"
46
- for token, score in zip(top_tokens, top_shap_scores):
47
- influential_keywords += f"{token} (Score: {score:.2f})<br>"
48
 
49
- return response, list(zip(top_tokens, top_shap_scores)), influential_keywords
50
 
51
  def predict(text):
52
  greeting_pattern = r"^(Halló|Hæ|Sæl|Góðan dag|Kær kveðja|Daginn|Kvöldið|Ágætis|Elsku)"
@@ -57,7 +48,7 @@ def predict(text):
57
  # Highlight the keywords in the input text
58
  modified_input = text
59
  for keyword, _ in keywords:
60
- modified_input = re.sub(rf"(\b{keyword}\b)", r"<span style='color:green;'>\1</span>", modified_input, flags=re.IGNORECASE)
61
 
62
  if not re.match(greeting_pattern, text, re.IGNORECASE):
63
  greeting_feedback = "OTHER FEEDBACK:<br>Heilsaðu dóninn þinn<br>"
 
2
  import re
3
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
  import torch
5
+ from keybert import KeyBERT
 
6
 
7
  # Initialize your model and tokenizer here
8
  model_identifier = "karalif/myTestModel"
9
  new_model = AutoModelForSequenceClassification.from_pretrained(model_identifier)
10
  new_tokenizer = AutoTokenizer.from_pretrained(model_identifier)
11
 
 
 
 
12
  def get_prediction(text):
13
  # Tokenize the input text
14
  encoding = new_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=200)
 
21
  sigmoid = torch.nn.Sigmoid()
22
  probs = sigmoid(logits.squeeze().cpu()).numpy()
23
 
24
+ # Initialize KeyBERT
25
+ kw_model = KeyBERT()
26
+ keywords = kw_model.extract_keywords(text, keyphrase_ngram_range=(1, 1), stop_words='english', use_maxsum=True, nr_candidates=20, top_n=5)
 
 
 
 
 
27
 
28
  # Prepare the HTML output with labels and their probabilities
29
  response = ""
 
34
  response += f"<span style='background-color:{colors[i]}; color:black;'>{label}</span>: {probs[i]*100:.1f}%<br>"
35
 
36
  influential_keywords = "INFLUENTIAL KEYWORDS:<br>"
37
+ for keyword, score in keywords:
38
+ influential_keywords += f"{keyword} (Score: {score:.2f})<br>"
39
 
40
+ return response, keywords, influential_keywords
41
 
42
  def predict(text):
43
  greeting_pattern = r"^(Halló|Hæ|Sæl|Góðan dag|Kær kveðja|Daginn|Kvöldið|Ágætis|Elsku)"
 
48
  # Highlight the keywords in the input text
49
  modified_input = text
50
  for keyword, _ in keywords:
51
+ modified_input = modified_input.replace(keyword, f"<span style='color:green;'>{keyword}</span>")
52
 
53
  if not re.match(greeting_pattern, text, re.IGNORECASE):
54
  greeting_feedback = "OTHER FEEDBACK:<br>Heilsaðu dóninn þinn<br>"