minh21 committed
Commit bcb98cf · Parent: 5cd1587

Create app.py

Files changed (1): app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
+ import gradio as gr
+ from gradio.components import Textbox, Checkbox
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
+ from peft import PeftModel
+ import torch
+ import datasets
+ from sentence_transformers import CrossEncoder
+ import math
+ import re
+ from nltk import sent_tokenize, word_tokenize
+ import nltk
+ nltk.download('punkt')  # tokenizer models required by sent_tokenize/word_tokenize
+
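+ # Pipeline: clean and chunk the given article, re-rank the chunks with a
+ # cross-encoder, answer the question with an IA3 fine-tuned Flan-T5, and
+ # optionally paraphrase the answer or compare against the base model.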
+ # Load the cross-encoder used to re-rank context chunks
+ top_k = 10  # number of top-ranked chunks kept as context
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+ # Load the tokenizer, the pretrained baseline, and the IA3 fine-tuned QA model
+ model_name = "google/flan-t5-large"
+ peft_name = "legacy107/flan-t5-large-ia3-covidqa"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = PeftModel.from_pretrained(model, peft_name)
+
+ # Load the IA3 adapter fine-tuned for answer paraphrasing
+ peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
+ paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)
+
+ max_length = 512  # maximum input length in tokens
+ max_target_length = 200  # maximum number of generated tokens
+
+ # Load the dataset and sample 10 random examples for the demo
+ dataset = datasets.load_dataset("minh21/COVID-QA-Chunk-64-testset-biencoder-data-90_10", split="train")
+ dataset = dataset.shuffle()
+ dataset = dataset.select(range(10))
+
+ # Context chunking parameters
+ min_sentences_per_chunk = 3
+ chunk_size = 64  # target chunk length in tokens
+ window_size = math.ceil(min_sentences_per_chunk * 0.25)  # initial sentence overlap window
+ over_lap_chunk_size = chunk_size * 0.25  # minimum token overlap between chunks
+
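+ # Split a context into chunks of roughly chunk_size tokens. Each new chunk is
+ # seeded with the trailing sentences of the previous one, so answers that
+ # straddle a chunk boundary are not lost.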
+ def chunk_splitter(context):
+     sentences = sent_tokenize(context)
+     chunks = []
+     current_chunk = []
+
+     for sentence in sentences:
+         # Keep adding sentences while the chunk is below the minimum sentence
+         # count or still under the chunk_size token budget
+         if len(current_chunk) < min_sentences_per_chunk:
+             current_chunk.append(sentence)
+             continue
+         elif len(word_tokenize(' '.join(current_chunk) + " " + sentence)) < chunk_size:
+             current_chunk.append(sentence)
+             continue
+
+         chunks.append(' '.join(current_chunk))
+
+         # Seed the next chunk with the trailing sentences of the finished one,
+         # widening the window until the overlap exceeds over_lap_chunk_size tokens
+         new_chunk = current_chunk[-window_size:]
+         new_window = window_size
+         buffer_new_chunk = new_chunk
+
+         while len(word_tokenize(' '.join(new_chunk))) <= over_lap_chunk_size:
+             buffer_new_chunk = new_chunk
+             new_window += 1
+             new_chunk = current_chunk[-new_window:]
+             if new_window >= len(current_chunk):
+                 break
+
+         current_chunk = buffer_new_chunk
+         current_chunk.append(sentence)
+
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+
+     return chunks
+
+
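+ # Strip boilerplate from a raw article: keep only the abstract body when an
+ # "Abstract:" marker is present, then remove links, DOIs, and file-size markers.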
+ def clean_data(text):
+     # Extract abstract content
+     index = text.find("\nAbstract: ")
+     if index != -1:
+         cleaned_text = text[index + len("\nAbstract: "):]
+     else:
+         cleaned_text = text  # if "\nAbstract: " is not found, keep the original text
+
+     # Remove http and https links, including malformed ones like "http//:..."
+     cleaned_text = re.sub(r'https?(:\/\/|\/\/:) ?\S+', '', cleaned_text)
+
+     # Remove DOI references like "doi:10.1371/journal.pone.0007211.s003"
+     cleaned_text = re.sub(r'doi: ?[\w./-]+', '', cleaned_text)
+
+     # Remove file-size markers like "(0.11 MB DOC)"
+     cleaned_text = re.sub(r'\(\d+\.\d+ MB DOC\)', '', cleaned_text)
+
+     # Remove bare "www." domains
+     cleaned_text = re.sub(r'www\.\w+(\.org)?', '', cleaned_text)
+
+     return cleaned_text
+
+
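+ # Rephrase a generated answer into a more natural sentence, using either the
+ # paraphrasing adapter or the plain pretrained model.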
+ def paraphrase_answer(question, answer, use_pretrained=False):
+     # Combine the question and the generated answer into a paraphrasing prompt
+     input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the paraphrased answer
+     with torch.no_grad():
+         if use_pretrained:
+             generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+         else:
+             generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+
+     # Decode and return the generated answer
+     paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     return paraphrased_answer
+
+
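+ # Score every chunk against the question with the cross-encoder and join the
+ # top_k highest-scoring chunks into a single retrieval context.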
+ def retrieve_context(question, contexts):
+     # Score each (question, chunk) pair with the cross-encoder
+     hits = [{"corpus_id": i} for i in range(len(contexts))]
+     cross_inp = [[question, contexts[hit["corpus_id"]]] for hit in hits]
+     cross_scores = cross_encoder.predict(cross_inp, show_progress_bar=False)
+
+     for idx in range(len(cross_scores)):
+         hits[idx]["cross-score"] = cross_scores[idx]
+
+     # Keep the top_k highest-scoring chunks, joined into one context string
+     hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
+
+     return " ".join(
+         [contexts[hit["corpus_id"]] for hit in hits[0:top_k]]
+     ).replace("\n", " ")
+
+
+ # Answer a question: clean and chunk the context, retrieve the most relevant
+ # chunks, then generate (and optionally paraphrase) an answer
+ def generate_answer(question, context, ground, do_pretrained, do_natural, do_pretrained_natural):
+     # `ground` is the ground-truth answer; it is only displayed in the UI examples
+     contexts = chunk_splitter(clean_data(context))
+     context = retrieve_context(question, contexts)
+
+     # Combine question and retrieved context
+     input_text = f"question: {question} context: {context}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the answer
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+
+     # Decode the generated answer
+     generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     # Paraphrase the answer into a more natural phrasing
+     paraphrased_answer = ""
+     if do_natural:
+         paraphrased_answer = paraphrase_answer(question, generated_answer)
+
+     # Get the pretrained (non-fine-tuned) model's answer
+     pretrained_answer = ""
+     if do_pretrained:
+         with torch.no_grad():
+             pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+         pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
+
+     # Get the pretrained model's natural answer
+     pretrained_paraphrased_answer = ""
+     if do_pretrained_natural:
+         pretrained_paraphrased_answer = paraphrase_answer(question, generated_answer, True)
+
+     return generated_answer, context, paraphrased_answer, pretrained_answer, pretrained_paraphrased_answer
+
+
+ # List examples from the dataset to populate the example gallery
+ def list_examples():
+     examples = []
+     for example in dataset:
+         context = example["context"]
+         question = example["question"]
+         answer = example["answer"]
+         examples.append([question, context, answer, True, True, True])
+     return examples
+
+
+ # Create the Gradio interface
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=[
+         Textbox(label="Question"),
+         Textbox(label="Context"),
+         Textbox(label="Ground truth"),
+         Checkbox(label="Include pretrained model's result"),
+         Checkbox(label="Include natural answer"),
+         Checkbox(label="Include pretrained model's natural answer")
+     ],
+     outputs=[
+         Textbox(label="Generated Answer"),
+         Textbox(label="Retrieved Context"),
+         Textbox(label="Natural Answer"),
+         Textbox(label="Pretrained Model's Answer"),
+         Textbox(label="Pretrained Model's Natural Answer")
+     ],
+     examples=list_examples(),
+     examples_per_page=1,
+ )
+
+ # Launch the Gradio interface
+ iface.launch()