Parth211 committed on
Commit 9a39944 · verified · 1 Parent(s): c4c6e6c
Files changed (1)
  1. app.py +80 -85
app.py CHANGED
@@ -237,88 +237,77 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"User: {user_message}")
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
-
-
-
-
-
-###############################################
-class RAGEvaluator:
-    def __init__(self):
-        self.gpt2_model, self.gpt2_tokenizer = self.load_gpt2_model()
-        self.bias_pipeline = pipeline("zero-shot-classification", model="Hate-speech-CNERG/dehatebert-mono-english")
-
-    def load_gpt2_model(self):
-        model = GPT2LMHeadModel.from_pretrained('gpt2')
-        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        return model, tokenizer
-
-    def evaluate_bleu_rouge(self, candidates, references):
-        bleu_score = corpus_bleu(candidates, [references]).score
-        scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
-        rouge_scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)]
-        rouge1 = sum([score['rouge1'].fmeasure for score in rouge_scores]) / len(rouge_scores)
-        return bleu_score, rouge1
-
-    def evaluate_bert_score(self, candidates, references):
-        P, R, F1 = score(candidates, references, lang="en", model_type='bert-base-multilingual-cased')
-        return P.mean().item(), R.mean().item(), F1.mean().item()
-
-    def evaluate_perplexity(self, text):
-        encodings = self.gpt2_tokenizer(text, return_tensors='pt')
-        max_length = self.gpt2_model.config.n_positions
-        stride = 512
-        lls = []
-        for i in range(0, encodings.input_ids.size(1), stride):
-            begin_loc = max(i + stride - max_length, 0)
-            end_loc = min(i + stride, encodings.input_ids.size(1))
-            trg_len = end_loc - i
-            input_ids = encodings.input_ids[:, begin_loc:end_loc]
-            target_ids = input_ids.clone()
-            target_ids[:, :-trg_len] = -100
-            with torch.no_grad():
-                outputs = self.gpt2_model(input_ids, labels=target_ids)
-                log_likelihood = outputs[0] * trg_len
-            lls.append(log_likelihood)
-        ppl = torch.exp(torch.stack(lls).sum() / end_loc)
-        return ppl.item()
-
-    def evaluate_diversity(self, texts):
-        all_tokens = [tok for text in texts for tok in text.split()]
-        unique_bigrams = set(ngrams(all_tokens, 2))
-        diversity_score = len(unique_bigrams) / len(all_tokens) if all_tokens else 0
-        return diversity_score
-
-    def evaluate_racial_bias(self, text):
-        results = self.bias_pipeline([text], candidate_labels=["hate speech", "not hate speech"])
-        bias_score = results[0]['scores'][results[0]['labels'].index('hate speech')]
-        return bias_score
-
-    def evaluate_all(self, question, response, reference):
-        candidates = [response]
-        references = [reference]
-        bleu, rouge1 = self.evaluate_bleu_rouge(candidates, references)
-        bert_p, bert_r, bert_f1 = self.evaluate_bert_score(candidates, references)
-        perplexity = self.evaluate_perplexity(response)
-        diversity = self.evaluate_diversity(candidates)
-        racial_bias = self.evaluate_racial_bias(response)
-        return {
-            "BLEU": bleu,
-            "ROUGE-1": rouge1,
-            "BERT P": bert_p,
-            "BERT R": bert_r,
-            "BERT F1": bert_f1,
-            "Perplexity": perplexity,
-            "Diversity": diversity,
-            "Racial Bias": racial_bias
-        }
-
-    ###################################
 
-evaluator = RAGEvaluator()
-
-
-#################################
+#----------------------------------------------------------------------------------
+def load_gpt2_model():
+    model = GPT2LMHeadModel.from_pretrained('gpt2')
+    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+    return model, tokenizer
+
+gpt2_model, gpt2_tokenizer = load_gpt2_model()
+bias_pipeline = pipeline("zero-shot-classification", model="Hate-speech-CNERG/dehatebert-mono-english")
+
+def evaluate_bleu_rouge(candidates, references):
+    bleu_score = corpus_bleu(candidates, [references]).score
+    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
+    rouge_scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)]
+    rouge1 = sum([score['rouge1'].fmeasure for score in rouge_scores]) / len(rouge_scores)
+    return bleu_score, rouge1
+
+def evaluate_bert_score(candidates, references):
+    P, R, F1 = score(candidates, references, lang="en", model_type='bert-base-multilingual-cased')
+    return P.mean().item(), R.mean().item(), F1.mean().item()
+
+def evaluate_perplexity(text, model, tokenizer):
+    encodings = tokenizer(text, return_tensors='pt')
+    max_length = model.config.n_positions
+    stride = 512
+    lls = []
+    for i in range(0, encodings.input_ids.size(1), stride):
+        begin_loc = max(i + stride - max_length, 0)
+        end_loc = min(i + stride, encodings.input_ids.size(1))
+        trg_len = end_loc - i
+        input_ids = encodings.input_ids[:, begin_loc:end_loc]
+        target_ids = input_ids.clone()
+        target_ids[:, :-trg_len] = -100
+        with torch.no_grad():
+            outputs = model(input_ids, labels=target_ids)
+            log_likelihood = outputs[0] * trg_len
+        lls.append(log_likelihood)
+    ppl = torch.exp(torch.stack(lls).sum() / end_loc)
+    return ppl.item()
+
+def evaluate_diversity(texts):
+    all_tokens = [tok for text in texts for tok in text.split()]
+    unique_bigrams = set(ngrams(all_tokens, 2))
+    diversity_score = len(unique_bigrams) / len(all_tokens) if all_tokens else 0
+    return diversity_score
+
+def evaluate_racial_bias(text, pipeline):
+    results = pipeline([text], candidate_labels=["hate speech", "not hate speech"])
+    bias_score = results[0]['scores'][results[0]['labels'].index('hate speech')]
+    return bias_score
+
+def evaluate_all(question, response, reference, gpt2_model, gpt2_tokenizer, bias_pipeline):
+    candidates = [response]
+    references = [reference]
+    bleu, rouge1 = evaluate_bleu_rouge(candidates, references)
+    bert_p, bert_r, bert_f1 = evaluate_bert_score(candidates, references)
+    perplexity = evaluate_perplexity(response, gpt2_model, gpt2_tokenizer)
+    diversity = evaluate_diversity(candidates)
+    racial_bias = evaluate_racial_bias(response, bias_pipeline)
+    return {
+        "BLEU": bleu,
+        "ROUGE-1": rouge1,
+        "BERT P": bert_p,
+        "BERT R": bert_r,
+        "BERT F1": bert_f1,
+        "Perplexity": perplexity,
+        "Diversity": diversity,
+        "Racial Bias": racial_bias
+    }
+
+#---------------------------------------------------------------------------------
 
 def display_metrics(metrics):
     result = ""
@@ -339,8 +328,14 @@ def display_metrics(metrics):
         elif k == 'Racial Bias':
             result += f"Racial Bias score indicates the presence of biased language in the generated output. Higher scores indicate more bias. Score obtained: {v}\n\n"
     return result
+#---------------------------------------------------------------------------------------------------------------------------------------------------
+
+
+
+
+
 
-def conversation(qa_chain, message, history, evaluator):
+def conversation(qa_chain, message, history, gpt2_model, gpt2_tokenizer, bias_pipeline):
     formatted_chat_history = format_chat_history(message, history)
     question_by_user = message
 
@@ -363,7 +358,7 @@ def conversation(qa_chain, message, history, evaluator):
     new_history = history + [(message, response_answer)]
 
     # Evaluate the metrics
-    metrics = evaluator.evaluate_all(question_by_user, answer_of_question, context)
+    metrics = evaluate_all(question_by_user, answer_of_question, context)
     evaluation_metrics = display_metrics(metrics)
 
     return (qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page,
@@ -470,12 +465,12 @@ def demo():
 
     # Chatbot events
     msg.submit(conversation, \
-        inputs=[qa_chain, msg, chatbot,evaluator], \
+        inputs=[qa_chain, msg, chatbot], \
         outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page,metrics_output], \
         queue=False)
 
     submit_btn.click(conversation,
-        inputs=[qa_chain, msg, history,evaluator],
+        inputs=[qa_chain, msg, history],
         outputs=[qa_chain, chatbot, history, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page, metrics_output])
 
     clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
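For reference, a minimal usage sketch of the refactored helpers (not part of this commit). It assumes app.py can be imported without launching the Gradio demo, and that gpt2_model, gpt2_tokenizer, and bias_pipeline are the module-level objects created in the diff above; the question/response/reference strings are toy placeholders. Unlike the removed RAGEvaluator.evaluate_all, the new evaluate_all expects the GPT-2 model, tokenizer, and bias pipeline as explicit arguments.

# Hypothetical usage sketch; assumes `from app import ...` works without side effects.
from app import evaluate_all, gpt2_model, gpt2_tokenizer, bias_pipeline

# Toy inputs; in the app these come from the QA chain's answer and retrieved context.
question = "What does the uploaded report say about revenue?"
response = "The report states that revenue grew by 12% year over year."
reference = "Revenue increased 12% compared to the previous year."

# The refactored evaluate_all takes the model, tokenizer, and bias pipeline
# explicitly instead of reading them from a class instance.
metrics = evaluate_all(question, response, reference,
                       gpt2_model, gpt2_tokenizer, bias_pipeline)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")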