minko186 commited on
Commit
227a8b5
·
verified ·
1 Parent(s): c6bd7c4

Update predictors.py

Browse files
Files changed (1) hide show
  1. predictors.py +297 -1
predictors.py CHANGED
@@ -11,7 +11,303 @@ import numpy as np
11
  import concurrent
12
  from multiprocessing import Pool
13
  from const import url_types
14
- from collections import defaultdict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  WORD = re.compile(r"\w+")
17
  model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
11
  import concurrent
12
  from multiprocessing import Pool
13
  from const import url_types
14
+ from collections import defaultdictimport torch
15
+ import numpy as np
16
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
17
+ import nltk
18
+ import torch.nn.functional as F
19
+ import nltk
20
+ from scipy.special import softmax
21
+ import yaml
22
+ from utils import *
23
+ import joblib
24
+ from optimum.bettertransformer import BetterTransformer
25
+ import gc
26
+ from cleantext import clean
27
+ import gradio as gr
28
+ from tqdm.auto import tqdm
29
+ from transformers import pipeline
30
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
31
+ import nltk
32
+ from nltk.tokenize import sent_tokenize
33
+ from optimum.pipelines import pipeline
34
+
35
+ with open("config.yaml", "r") as file:
36
+ params = yaml.safe_load(file)
37
+
38
+ nltk.download("punkt")
39
+ nltk.download("stopwords")
40
+ device_needed = "cuda" if torch.cuda.is_available() else "cpu"
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ print('DEVICE IS :' , device)
43
+
44
+ text_bc_model_path = params["TEXT_BC_MODEL_PATH"]
45
+ text_mc_model_path = params["TEXT_MC_MODEL_PATH"]
46
+ text_quillbot_model_path = params["TEXT_QUILLBOT_MODEL_PATH"]
47
+ quillbot_labels = params["QUILLBOT_LABELS"]
48
+ mc_label_map = params["MC_OUTPUT_LABELS"]
49
+ mc_token_size = int(params["MC_TOKEN_SIZE"])
50
+ bc_token_size = int(params["BC_TOKEN_SIZE"])
51
+ bias_checker_model_name = params['BIAS_CHECKER_MODEL_PATH']
52
+ bias_corrector_model_name = params['BIAS_CORRECTOR_MODEL_PATH']
53
+ # access_token = params['HF_TOKEN']
54
+
55
+ text_bc_tokenizer = AutoTokenizer.from_pretrained(text_bc_model_path)
56
+ text_bc_model = AutoModelForSequenceClassification.from_pretrained(text_bc_model_path).to(device)
57
+ text_mc_tokenizer = AutoTokenizer.from_pretrained(text_mc_model_path)
58
+ text_mc_model = AutoModelForSequenceClassification.from_pretrained(text_mc_model_path).to(device)
59
+ quillbot_tokenizer = AutoTokenizer.from_pretrained(text_quillbot_model_path)
60
+ quillbot_model = AutoModelForSequenceClassification.from_pretrained(text_quillbot_model_path).to(device)
61
+
62
+ # proxy models for explainability
63
+ mini_bc_model_name = "polygraf-ai/bc-model"
64
+ bc_tokenizer_mini = AutoTokenizer.from_pretrained(mini_bc_model_name)
65
+ bc_model_mini = AutoModelForSequenceClassification.from_pretrained(mini_bc_model_name).to(device_needed)
66
+ mini_humanizer_model_name = "polygraf-ai/humanizer-model"
67
+ humanizer_tokenizer_mini = AutoTokenizer.from_pretrained(mini_humanizer_model_name)
68
+ humanizer_model_mini = AutoModelForSequenceClassification.from_pretrained(mini_humanizer_model_name).to(device_needed)
69
+
70
+ bc_model_mini = BetterTransformer.transform(bc_model_mini)
71
+ humanizer_model_mini = BetterTransformer.transform(humanizer_model_mini)
72
+ text_bc_model = BetterTransformer.transform(text_bc_model)
73
+ text_mc_model = BetterTransformer.transform(text_mc_model)
74
+ quillbot_model = BetterTransformer.transform(quillbot_model)
75
+
76
+ bias_model_checker = AutoModelForSequenceClassification.from_pretrained(bias_checker_model_name)
77
+ tokenizer = AutoTokenizer.from_pretrained(bias_checker_model_name)
78
+ bias_model_checker = BetterTransformer.transform(bias_model_checker, keep_original_model=False)
79
+ bias_checker = pipeline(
80
+ "text-classification",
81
+ model=bias_checker_model_name,
82
+ tokenizer=bias_checker_model_name,
83
+ )
84
+ gc.collect()
85
+ bias_corrector = pipeline( "text2text-generation", model=bias_corrector_model_name, accelerator="ort")
86
+
87
+ # model score calibration
88
+ iso_reg = joblib.load("isotonic_regression_model.joblib")
89
+
90
+
91
+ def split_text(text: str) -> list:
92
+ sentences = sent_tokenize(text)
93
+ return [[sentence] for sentence in sentences]
94
+
95
+ def correct_text(text: str, bias_checker, bias_corrector, separator: str = " ") -> tuple:
96
+ sentence_batches = split_text(text)
97
+ corrected_text = []
98
+ corrections = []
99
+ for batch in tqdm(sentence_batches, total=len(sentence_batches), desc="correcting text.."):
100
+ raw_text = " ".join(batch)
101
+ results = bias_checker(raw_text)
102
+ if results[0]["label"] != "LABEL_1" or (results[0]["label"] == "LABEL_1" and results[0]["score"] < 0.9):
103
+ corrected_batch = bias_corrector(raw_text)
104
+ corrected_version = corrected_batch[0]["generated_text"]
105
+ corrected_text.append(corrected_version)
106
+ corrections.append((raw_text, corrected_version))
107
+ else:
108
+ corrected_text.append(raw_text)
109
+ corrected_text = separator.join(corrected_text)
110
+ return corrected_text, corrections
111
+
112
+ def update(text: str):
113
+ text = clean(text, lower=False)
114
+ corrected_text, corrections = correct_text(text, bias_checker, bias_corrector)
115
+ corrections_display = "".join([f"{corr}" for orig, corr in corrections])
116
+ if corrections_display == "":
117
+ corrections_display = text
118
+ return corrections_display
119
+
120
+ def update_main(text: str):
121
+ text = clean(text, lower=False)
122
+ corrected_text, corrections = correct_text(text, bias_checker, bias_corrector)
123
+ corrections_display = "\n\n".join([f"Original: {orig}\nCorrected: {corr}" for orig, corr in corrections])
124
+ return corrected_text, corrections_display
125
+
126
+ def split_text(text: str) -> list:
127
+ sentences = sent_tokenize(text)
128
+ return [[sentence] for sentence in sentences]
129
+
130
+ def get_token_length(tokenizer, sentence):
131
+ return len(tokenizer.tokenize(sentence))
132
+
133
+ def split_text_allow_complete_sentences_nltk(text, type_det="bc"):
134
+ sentences = sent_tokenize(text)
135
+ chunks = []
136
+ current_chunk = []
137
+ current_length = 0
138
+ if type_det == "bc":
139
+ tokenizer = text_bc_tokenizer
140
+ max_tokens = bc_token_size
141
+ elif type_det == "mc":
142
+ tokenizer = text_mc_tokenizer
143
+ max_tokens = mc_token_size
144
+
145
+ elif type_det == "quillbot":
146
+ tokenizer = quillbot_tokenizer
147
+ max_tokens = 256
148
+
149
+ def add_sentence_to_chunk(sentence):
150
+ nonlocal current_chunk, current_length
151
+ sentence_length = get_token_length(tokenizer, sentence)
152
+ if current_length + sentence_length > max_tokens:
153
+ chunks.append((current_chunk, current_length))
154
+ current_chunk = []
155
+ current_length = 0
156
+ current_chunk.append(sentence)
157
+ current_length += sentence_length
158
+
159
+ for sentence in sentences:
160
+ add_sentence_to_chunk(sentence)
161
+ if current_chunk:
162
+ chunks.append((current_chunk, current_length))
163
+ adjusted_chunks = []
164
+ while chunks:
165
+ chunk = chunks.pop(0)
166
+ if len(chunks) > 0 and chunk[1] < max_tokens / 2:
167
+ next_chunk = chunks.pop(0)
168
+ combined_length = chunk[1] + next_chunk[1]
169
+ if combined_length <= max_tokens:
170
+ adjusted_chunks.append((chunk[0] + next_chunk[0], combined_length))
171
+ else:
172
+ adjusted_chunks.append(chunk)
173
+ chunks.insert(0, next_chunk)
174
+ else:
175
+ adjusted_chunks.append(chunk)
176
+ result_chunks = [" ".join(chunk[0]) for chunk in adjusted_chunks]
177
+ return result_chunks
178
+
179
+
180
+ def predict_quillbot(text, bias_buster_selected):
181
+ if bias_buster_selected:
182
+ text = update(text)
183
+ with torch.no_grad():
184
+ quillbot_model.eval()
185
+ tokenized_text = quillbot_tokenizer(
186
+ text,
187
+ padding="max_length",
188
+ truncation=True,
189
+ max_length=256,
190
+ return_tensors="pt",
191
+ ).to(device)
192
+ output = quillbot_model(**tokenized_text)
193
+ output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
194
+ q_score = {
195
+ "Humanized": output_norm[1].item(),
196
+ "Original": output_norm[0].item(),
197
+ }
198
+ return q_score
199
+
200
+
201
+ def predict_for_explainanility(text, model_type=None):
202
+ if model_type == "quillbot":
203
+ cleaning = False
204
+ max_length = 256
205
+ model = humanizer_model_mini
206
+ tokenizer = humanizer_tokenizer_mini
207
+ elif model_type == "bc":
208
+ cleaning = True
209
+ max_length = bc_token_size
210
+ model = bc_model_mini
211
+ tokenizer = bc_tokenizer_mini
212
+ else:
213
+ raise ValueError("Invalid model type")
214
+ with torch.no_grad():
215
+ if cleaning:
216
+ text = [remove_special_characters(t) for t in text]
217
+ tokenized_text = tokenizer(
218
+ text,
219
+ return_tensors="pt",
220
+ padding="max_length",
221
+ truncation=True,
222
+ max_length=max_length,
223
+ ).to(device_needed)
224
+ outputs = model(**tokenized_text)
225
+ tensor_logits = outputs[0]
226
+ probas = F.softmax(tensor_logits).detach().cpu().numpy()
227
+ return probas
228
+
229
+
230
+ def predict_bc(model, tokenizer, text):
231
+ with torch.no_grad():
232
+ model.eval()
233
+ tokens = text_bc_tokenizer(
234
+ text,
235
+ padding="max_length",
236
+ truncation=True,
237
+ max_length=bc_token_size,
238
+ return_tensors="pt",
239
+ ).to(device)
240
+ output = model(**tokens)
241
+ output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
242
+ return output_norm
243
+
244
+
245
+ def predict_mc(model, tokenizer, text):
246
+ with torch.no_grad():
247
+ model.eval()
248
+ tokens = text_mc_tokenizer(
249
+ text,
250
+ padding="max_length",
251
+ truncation=True,
252
+ return_tensors="pt",
253
+ max_length=mc_token_size,
254
+ ).to(device)
255
+ output = model(**tokens)
256
+ output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
257
+ return output_norm
258
+
259
+
260
+ def predict_bc_scores(input):
261
+ bc_scores = []
262
+ samples_len_bc = len(
263
+ split_text_allow_complete_sentences_nltk(input, type_det="bc")
264
+ )
265
+ segments_bc = split_text_allow_complete_sentences_nltk(input, type_det="bc")
266
+ for i in range(samples_len_bc):
267
+ cleaned_text_bc = remove_special_characters(segments_bc[i])
268
+ bc_score = predict_bc(text_bc_model, text_bc_tokenizer, cleaned_text_bc)
269
+ bc_scores.append(bc_score)
270
+ bc_scores_array = np.array(bc_scores)
271
+ average_bc_scores = np.mean(bc_scores_array, axis=0)
272
+ bc_score_list = average_bc_scores.tolist()
273
+ print(
274
+ f"Original BC scores: AI: {bc_score_list[1]}, HUMAN: {bc_score_list[0]}"
275
+ )
276
+ # isotonic regression calibration
277
+ ai_score = iso_reg.predict([bc_score_list[1]])[0]
278
+ human_score = 1 - ai_score
279
+ bc_score = {"AI": ai_score, "HUMAN": human_score}
280
+ print(f"Calibration BC scores: AI: {ai_score}, HUMAN: {human_score}")
281
+ print(f"Input Text: {cleaned_text_bc}")
282
+ return bc_score
283
+
284
+
285
+ def predict_mc_scores(input):
286
+ # BC SCORE
287
+ bc_scores = []
288
+ samples_len_bc = len(
289
+ split_text_allow_complete_sentences_nltk(input, type_det="bc")
290
+ )
291
+ segments_bc = split_text_allow_complete_sentences_nltk(input, type_det="bc")
292
+ for i in range(samples_len_bc):
293
+ cleaned_text_bc = remove_special_characters(segments_bc[i])
294
+ bc_score = predict_bc(text_bc_model, text_bc_tokenizer, cleaned_text_bc)
295
+ bc_scores.append(bc_score)
296
+ bc_scores_array = np.array(bc_scores)
297
+ average_bc_scores = np.mean(bc_scores_array, axis=0)
298
+ bc_score_list = average_bc_scores.tolist()
299
+ print(
300
+ f"Original BC scores: AI: {bc_score_list[1]}, HUMAN: {bc_score_list[0]}"
301
+ )
302
+ # isotonic regression calibration
303
+ ai_score = iso_reg.predict([bc_score_list[1]])[0]
304
+ human_score = 1 - ai_score
305
+ bc_score = {"AI": ai_score, "HUMAN": human_score}
306
+ print(f"Calibration BC scores: AI: {ai_score}, HUMAN: {human_score}")
307
+ mc_scores = []
308
+ segments_mc = split_text_allow_complete_sentences_nltk(
309
+ input, type_det="mc"
310
+
311
 
312
  WORD = re.compile(r"\w+")
313
  model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")