legacy107 committed on
Commit
c462daf
·
1 Parent(s): 4d99842

Create app.py

Files changed (1)
  1. app.py +146 -0
app.py ADDED
@@ -0,0 +1,146 @@
+ import gradio as gr
+ from gradio.components import Textbox, Checkbox
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
+ from peft import PeftModel
+ import torch
+ import datasets
+ from sentence_transformers import CrossEncoder
+ import nltk
+ nltk.download('punkt')
+
+ # Load the cross-encoder used to re-rank context chunks
+ top_k = 10
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+ # Load the tokenizer, the pretrained baseline, and the IA3 fine-tuned model
+ model_name = "google/flan-t5-large"
+ peft_name = "legacy107/flan-t5-large-ia3-newsqa"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ pretrained_model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+ model = PeftModel.from_pretrained(model, peft_name)
+ max_length = 512
+ max_target_length = 200
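+
+ # Note (added for clarity; not in the original commit): IA3 keeps the base
+ # flan-t5-large weights frozen and injects small learned rescaling vectors,
+ # so `model` is the base model plus the lightweight adapter from peft_name.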
+
+ # Load the dataset and sample 10 random test examples for the demo
+ dataset = datasets.load_dataset("legacy107/newsqa", split="test")
+ dataset = dataset.shuffle()
+ dataset = dataset.select(range(10))
+
+ # Context chunking: split the context into chunks of whole sentences of
+ # roughly chunk_size words, carrying ~overlap (10%) words over between chunks
+ def chunk_splitter(context, chunk_size=50, overlap=0.10):
+     overlap_size = chunk_size * overlap
+     sentences = nltk.sent_tokenize(context)
+
+     chunks = []
+     text = sentences[0]
+
+     if len(sentences) == 1:
+         chunks.append(text)
+
+     i = 1
+     while i < len(sentences):
+         text += " " + sentences[i]
+         i += 1
+         # Greedily append sentences while the chunk stays within chunk_size words
+         while i < len(sentences) and len(nltk.word_tokenize(f"{text} {sentences[i]}")) <= chunk_size:
+             text += " " + sentences[i]
+             i += 1
+
+         text = " ".join(text.split())  # collapse newlines and repeated whitespace
+         chunks.append(text)
+
+         if i >= len(sentences):
+             break
+
+         # Seed the next chunk with up to overlap_size words of trailing context
+         # (j is decremented before the check so the seed sentence isn't doubled)
+         j = i - 1
+         text = sentences[j]
+         j -= 1
+         while j >= 0 and len(nltk.word_tokenize(f"{sentences[j]} {text}")) <= overlap_size:
+             text = sentences[j] + " " + text
+             j -= 1
+
+     return chunks
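+
+ # Illustrative usage (not in the original commit): for a long article this
+ # yields ~50-word chunks whose openings repeat up to ~5 words
+ # (chunk_size * overlap) from the end of the previous chunk, e.g.:
+ #   chunks = chunk_splitter(article_text)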
+
+
+ def retrieve_context(query, contexts):
+     hits = [{"corpus_id": i} for i in range(len(contexts))]
+     cross_inp = [[query, contexts[hit["corpus_id"]]] for hit in hits]
+     cross_scores = cross_encoder.predict(cross_inp, show_progress_bar=False)
+
+     for idx in range(len(cross_scores)):
+         hits[idx]["cross-score"] = cross_scores[idx]
+
+     hits = sorted(hits, key=lambda x: x["cross-score"], reverse=True)
+
+     return " ".join(
+         [contexts[hit["corpus_id"]] for hit in hits[0:top_k]]
+     ).replace("\n", " ")
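+
+ # Illustrative usage (not in the original commit): score every chunk against
+ # the question and concatenate the top_k best-scoring chunks, e.g.:
+ #   context = retrieve_context("Who was elected?", chunks)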
+
+
+ # Generate answers with the fine-tuned model (and optionally the pretrained
+ # baseline); `ground` is the ground-truth textbox value and is only displayed
+ def generate_answer(question, context, ground, do_pretrained):
+     # Chunk the context and keep only the most relevant chunks
+     contexts = chunk_splitter(context)
+     context = retrieve_context(question, contexts)
+
+     # Combine question and context
+     input_text = f"question: {question} context: {context}"
+
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+
+     # Generate the answer
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+
+     # Decode the generated answer
+     generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+
+     # Optionally get the pretrained (no-adapter) model's answer
+     pretrained_answer = ""
+     if do_pretrained:
+         with torch.no_grad():
+             pretrained_generated_ids = pretrained_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+         pretrained_answer = tokenizer.decode(pretrained_generated_ids[0], skip_special_tokens=True)
+
+     return generated_answer, context, pretrained_answer
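+
+ # Illustrative usage (not in the original commit): the handler can be called
+ # directly, outside the UI; the ground-truth argument is only displayed:
+ #   answer, retrieved, baseline = generate_answer(question, article, "", True)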
+
+
+ # Define a function to list examples from the dataset
+ def list_examples():
+     examples = []
+     for example in dataset:
+         context = example["context"]
+         question = example["question"]
+         answer = " | ".join(example["answers"])
+         examples.append([question, context, answer, True])
+     return examples
+
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=[
+         Textbox(label="Question"),
+         Textbox(label="Context"),
+         Textbox(label="Ground truth"),
+         Checkbox(label="Include pretrained model's result")
+     ],
+     outputs=[
+         Textbox(label="Generated Answer"),
+         Textbox(label="Retrieved Context"),
+         Textbox(label="Pretrained Model's Answer")
+     ],
+     examples=list_examples(),
+     examples_per_page=1
+ )
+
+ # Launch the Gradio interface
+ iface.launch()
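+
+ # Assumption (standard Gradio, not in the original commit): a public share
+ # link could be created with iface.launch(share=True) instead.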