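# app.py - Gradio demo for Sinhala question answering over an uploaded text
# file: each line is embedded with a SentenceTransformer retriever and indexed
# in FAISS; answers come from either an extractive (span-prediction) model or
# a generative (seq2seq long-form QA) model.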
import torch
import faiss
import numpy as np
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from sentence_transformers import SentenceTransformer
model_name = "9wimu9/lfqa-mt5-large-sin-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
question_answerer_seq_2seq = model.to(device)
retriever_model = SentenceTransformer('9wimu9/retriever-model-sinhala-v2')
question_answerer = pipeline("question-answering", model='9wimu9/xlm-roberta-large-en-si-only-finetuned-sinquad-v12')
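# Generative answer path: the query and retrieved passages are concatenated
# into a single "question: ... context: ..." prompt (passages separated by
# "<P>") and decoded with beam search.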
def seq2seq_find_answer(query, context):
    """Generate a long-form answer from the query and the retrieved passages."""
    conditioned_doc = "<P> " + " <P> ".join([d for d in context])
    query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
    model_input = tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
    generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
                                               attention_mask=model_input["attention_mask"].to(device),
                                               min_length=2,
                                               max_length=120,
                                               early_stopping=True,
                                               num_beams=9,
                                               do_sample=False,  # deterministic beam search; sampling knobs (temperature, top_k, top_p) are unused
                                               eos_token_id=tokenizer.eos_token_id,
                                               no_repeat_ngram_size=8,
                                               num_return_sequences=1)
    return tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=True)[0]
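# File ingestion: every non-empty line of the uploaded text file becomes one
# passage, embedded with the retriever model.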
def encode_file(file_path):
    passages = []
    with open(file_path, encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if line:  # skip blank lines so they do not pollute the index
                passages.append([line])
    df = pd.DataFrame(passages, columns=['text'])
    vectors = retriever_model.encode(df['text'].tolist())
    return vectors, passages
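# Build an in-memory FAISS index over the passage embeddings. Vectors are
# L2-normalized before indexing and search, so L2 distance ranks candidates
# the same way cosine similarity would.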
def upload_file(file):
    global index
    global passages
    vectors, passages = encode_file(file.name)
    vector_dimension = vectors.shape[1]
    index = faiss.IndexFlatL2(vector_dimension)
    faiss.normalize_L2(vectors)
    index.add(vectors)
    return file.name
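# Extractive QA: embed the question, retrieve the nearest passage from FAISS,
# then extract an answer span from it with the XLM-R QA pipeline.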
def question_answer(search_text):
    search_vector = retriever_model.encode(search_text)
    _vector = np.array([search_vector])
    faiss.normalize_L2(_vector)
    # Only the single nearest passage is used as context.
    distances, ann = index.search(_vector, k=1)
    context = passages[ann[0][0]][0]
    result = question_answerer(question=search_text, context=context)
    print(result)
    return result['answer']
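# Generative QA: same retrieval step, but the answer is generated by the
# seq2seq model rather than extracted as a span.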
def question_answer_generated(search_text):
    search_vector = retriever_model.encode(search_text)
    _vector = np.array([search_vector])
    faiss.normalize_L2(_vector)
    distances, ann = index.search(_vector, k=1)
    context = passages[ann[0][0]][0]
    return seq2seq_find_answer(search_text, [context])
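# Gradio UI: one row to upload the source file, plus one row per QA mode.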
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            upload_button = gr.UploadButton("Click to Upload a File", file_types=[".txt"], file_count="single")
            upload_button.upload(upload_file, upload_button, file_output)
    with gr.Row():
        with gr.Column():
            extractive_question = gr.Textbox(label="question")
            extractive_answer = gr.Textbox(label="answer")
            extractive_btn = gr.Button("Get Answer - Extractive QA")
            extractive_btn.click(fn=question_answer, inputs=extractive_question,
                                 outputs=extractive_answer, api_name="extractive_qa")
    with gr.Row():
        with gr.Column():
            generative_question = gr.Textbox(label="question")
            generative_answer = gr.Textbox(label="answer")
            generative_btn = gr.Button("Get Answer - Generative QA")
            generative_btn.click(fn=question_answer_generated, inputs=generative_question,
                                 outputs=generative_answer, api_name="generative_qa")

demo.launch(debug=True)