ferdmartin committed on
Commit 4fc2f5a
1 Parent(s): e59445e

Update app.py

Files changed (1)
  1. app.py +33 -12
app.py CHANGED
@@ -49,7 +49,12 @@ def main():
         translate(str.maketrans('', '', string.punctuation)).strip().lstrip()
 
     # Define the function to classify text
-    def nb_lr(model, text):
+    def nb_lr(model, text: str) -> (int, float):
+        """
+        This function takes a previously trained Sklearn Pipeline
+        model (Naive Bayes or Logistic Regression), then returns the final prediction
+        and the prediction probability as a tuple.
+        """
         # Clean and format the input text
         text = format_text(text)
         # Predict using either LR or NB and get prediction probability
@@ -58,6 +63,11 @@ def main():
         return prediction, predict_proba
 
     def torch_pred(tokenizer, model, text):
+        """
+        This function takes a pre-trained tokenizer and a previously trained transformer-based
+        model (DistilBert or Bert), then returns the final prediction
+        and the prediction probability as a tuple.
+        """
         # DL models (BERT/DistilBERT based models)
         cleaned_text_tokens = tokenizer([text], padding='max_length', max_length=512, truncation=True)
         with torch.inference_mode():
@@ -70,7 +80,11 @@ def main():
         predict_proba = round(torch.softmax(logits, 1).cpu().squeeze().tolist()[prediction], 4)
         return prediction, predict_proba
 
-    def pred_str(prediction):
+    def pred_str(prediction: int) -> str:
+        """
+        This function takes an integer value as input and returns a string representing the type of the input's source.
+        The input is expected to be a prediction from a classification model that distinguishes between human-made and AI-generated text.
+        """
         # Map the predicted class to string output
         if prediction == 0:
             return "Human-made 🤷‍♂️🤷‍♀️"
@@ -79,6 +93,9 @@ def main():
 
     @st.cache(allow_output_mutation=True, suppress_st_warning=True)
     def load_tokenizer(option):
+        """
+        Load the pre-trained tokenizer and save it in cache memory.
+        """
         if option == "BERT-based model":
             tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", padding='max_length', max_length=512, truncation=True)
         else:
@@ -87,6 +104,9 @@ def main():
 
     @st.cache(allow_output_mutation=True, suppress_st_warning=True)
     def load_model(option):
+        """
+        Load the trained Transformer-based models and save them in cache memory.
+        """
         if option == "BERT-based model":
             model = HF_BertBasedModelAppDocs.from_pretrained("ferdmartin/HF_BertBasedModelAppDocs").to(device)
         else:
@@ -95,7 +115,7 @@ def main():
 
 
     # Streamlit app:
-
+    # List of models available
     models_available = {"Logistic Regression": "models/baseline_model_lr2.joblib",
                         "Naive Bayes": "models/baseline_model_nb2.joblib",
                         "DistilBERT-based model (BERT light)": "ferdmartin/HF_DistilBertBasedModelAppDocs",
@@ -108,11 +128,12 @@ def main():
 
     # Check the model to use
     def restore_prediction_state():
+        """Restore session_state variable to clear prediction after changing model"""
         if "prediction" in st.session_state:
             del st.session_state.prediction
+
     option = st.selectbox("Select a model to use:", models_available, on_change=restore_prediction_state)
 
-
     # Load the selected trained model
     if option in ("BERT-based model", "DistilBERT-based model (BERT light)"):
         tokenizer = load_tokenizer(option)
@@ -135,20 +156,21 @@ def main():
     # Use model
     if st.button("Let's check this text!"):
         if text.strip() == "":
+            # In case there is no input for the model
             st.error("Please enter some text")
         else:
             with st.spinner("Wait for the magic 🪄🔮"):
-                # Use model
-                if option in ("Naive Bayes", "Logistic Regression"):
+                # Use models
+                if option in ("Naive Bayes", "Logistic Regression"):  # Use Sklearn pipeline models
                     prediction, predict_proba = nb_lr(model, text)
                     st.session_state["sklearn"] = True
                 else:
-                    prediction, predict_proba = torch_pred(tokenizer, model, text)
+                    prediction, predict_proba = torch_pred(tokenizer, model, text)  # Use transformers
                     st.session_state["torch"] = True
 
                 # Store the result in session state
-                st.session_state["color_pred"] = "blue" if prediction == 0 else "red"
-                prediction = pred_str(prediction)
+                st.session_state["color_pred"] = "blue" if prediction == 0 else "red"  # Set color for prediction output string
+                prediction = pred_str(prediction)  # Map predictions (int => str)
                 st.session_state["prediction"] = prediction
                 st.session_state["predict_proba"] = predict_proba
                 st.session_state["text"] = text
@@ -171,15 +193,14 @@ def main():
                 html = eli5.format_as_html(explainer.explain_prediction(target_names=["Human", "AI"]))
             else:
                 with st.spinner('Wait for it 💭... BERT-based model explanations take around 4-10 minutes. In case you want to abort, refresh the page.'):
-                    # TORCH EXPLAINER PRED FUNC (USES logits)
                     def f(x):
+                        """TORCH EXPLAINER PRED FUNC (USES logits)"""
                         tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=512, truncation=True) for v in x])#.cuda()
                         outputs = model(tv).detach().cpu().numpy()
                         scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
                         val = scipy.special.logit(scores[:, 1])  # use one vs rest logit units
                         return val
-                    # build an explainer using a token masker
-                    explainer = shap.Explainer(f, tokenizer)
+                    explainer = shap.Explainer(f, tokenizer)  # build explainer using masking tokens and selected transformer-based model
                     shap_values = explainer([st.session_state["text"]], fixed_context=1)
                     html = shap.plots.text(shap_values, display=False)
                     # Render HTML
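For context on the last hunk: the `f` wrapper is the scoring function SHAP perturbs. It tokenizes a batch of strings, runs the model, softmaxes the logits, and returns one-vs-rest logit units for class 1 ("AI"). Below is a minimal, self-contained sketch of the same wiring, with a stock AutoModelForSequenceClassification and a public checkpoint standing in for the app's custom HF_BertBasedModelAppDocs / HF_DistilBertBasedModelAppDocs classes; that substitution is why the sketch reads `.logits` off the output, whereas the app's custom heads appear to return logits directly.

# Sketch only: checkpoint name and example text are placeholders,
# not the app's actual models.
import numpy as np
import scipy.special
import torch
import shap
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2)

def f(x):
    """Score a batch of strings in one-vs-rest logit units for class 1."""
    tv = torch.tensor([tokenizer.encode(v, padding='max_length',
                                        max_length=512, truncation=True) for v in x])
    outputs = model(tv).logits.detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T  # row-wise softmax
    return scipy.special.logit(scores[:, 1])

explainer = shap.Explainer(f, tokenizer)  # tokenizer doubles as the token masker
shap_values = explainer(["Some input text."], fixed_context=1)
html = shap.plots.text(shap_values, display=False)  # HTML string, ready to embed

Passing the tokenizer as the second argument to shap.Explainer makes it the token masker, fixed_context=1 keeps the partition explainer tractable on long text, and display=False makes shap.plots.text return the HTML string that the app then renders with Streamlit components.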