File size: 4,144 Bytes
a182034
 
f327ce6
a182034
 
 
f327ce6
a182034
f327ce6
aa65748
a182034
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f327ce6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import gradio as gr
import faiss
import numpy as np
import pandas as pd
from transformers import pipeline
from sentence_transformers import SentenceTransformer

# Generative (long-form) QA model: a Sinhala-finetuned mT5-large seq2seq model.
model_name = "9wimu9/lfqa-mt5-large-sin-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Run generation on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
question_answerer_seq_2seq = model.to(device)

# Dense retriever used to embed both uploaded passages and incoming queries.
retriever_model = SentenceTransformer('9wimu9/retriever-model-sinhala-v2')
# Extractive QA pipeline (span selection) — XLM-R finetuned on SinQuAD.
question_answerer = pipeline("question-answering", model='9wimu9/xlm-roberta-large-en-si-only-finetuned-sinquad-v12')

def srq2seq_find_answer(query, context):
    """Generate a long-form answer for *query* conditioned on retrieved passages.

    Args:
        query: The question string.
        context: Iterable of passage strings; each is prefixed with "<P> "
            to match the lfqa training format.

    Returns:
        The single best decoded answer string (beam search, 9 beams).
    """
    conditioned_doc = "<P> " + " <P> ".join([d for d in context])
    query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

    model_input = tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")

    # Deterministic beam search. The original passed temperature/top_k/top_p
    # alongside do_sample=False; those sampling knobs are ignored by
    # generate() in that mode (and emit warnings), so they are dropped here.
    generated_answers_encoded = model.generate(input_ids=model_input["input_ids"].to(device),
                                               attention_mask=model_input["attention_mask"].to(device),
                                               min_length=2,
                                               max_length=120,
                                               early_stopping=True,
                                               num_beams=9,
                                               do_sample=False,
                                               eos_token_id=tokenizer.eos_token_id,
                                               no_repeat_ngram_size=8,
                                               num_return_sequences=1)
    return tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

def encode_file(file_path):
  """Read a text file (one passage per line) and embed every passage.

  Args:
      file_path: Path to a plain-text file; each line is one passage
          (trailing newlines are kept, matching the original behavior).

  Returns:
      (vectors, passages): the retriever embeddings and the parallel list
      of single-element [text] lists indexed by ``question_answer``.
  """
  passages = []
  with open(file_path) as file:
    for item in file:
      passages.append([item])
  # Encode the raw line texts directly — the previous pandas
  # DataFrame/Series round-trip produced the same input to encode()
  # at extra cost and with no added meaning.
  vectors = retriever_model.encode([p[0] for p in passages])
  return vectors, passages

def upload_file(files):
  """Gradio upload handler: build the FAISS index over the uploaded file.

  Stores the index and the passage list in module globals so the two
  question-answering handlers can reach them.  Only the first uploaded
  file is indexed (the UploadButton is configured for a single file).
  Returns the file paths so Gradio can display them in the File widget.
  """
  global index
  global passages
  file_paths = [file.name for file in files]
  vectors,passages = encode_file(file_paths[0])
  vector_dimension = vectors.shape[1]
  # L2-normalize in place, then use an L2 index — equivalent to ranking
  # by cosine similarity at query time (queries are normalized the same way).
  index = faiss.IndexFlatL2(vector_dimension)
  faiss.normalize_L2(vectors)
  index.add(vectors)
  return file_paths

def question_answer(search_text):
  """Extractive QA: retrieve the best passage, then extract a span answer.

  Args:
      search_text: The user's question.

  Returns:
      The answer span string from the extractive QA pipeline.

  Requires ``upload_file`` to have populated the module-level ``index``
  and ``passages`` first.
  """
  search_vector = retriever_model.encode(search_text)
  _vector = np.array([search_vector])
  faiss.normalize_L2(_vector)
  # Only the single closest passage is used below, so search with k=1
  # instead of ranking every vector in the index (k = index.ntotal).
  distances, ann = index.search(_vector, k=1)
  context = passages[ann[0][0]][0]
  result = question_answerer(question=search_text, context=context)
  return result['answer']


def question_answer_generated(search_text):
  """Generative QA: retrieve the best passage, then generate a long answer.

  Args:
      search_text: The user's question.

  Returns:
      The generated answer string from ``srq2seq_find_answer``.

  Requires ``upload_file`` to have populated the module-level ``index``
  and ``passages`` first.
  """
  search_vector = retriever_model.encode(search_text)
  _vector = np.array([search_vector])
  faiss.normalize_L2(_vector)
  # Only the top-1 passage feeds the generator, so k=1 suffices
  # (the original searched all index.ntotal vectors needlessly).
  distances, ann = index.search(_vector, k=1)
  context = passages[ann[0][0]][0]
  return srq2seq_find_answer(search_text, [context])

# UI: one upload row plus two QA rows (extractive and generative).
with gr.Blocks() as demo:
  with gr.Row():
    with gr.Column():
      file_output = gr.File()
      upload_button = gr.UploadButton("Click to Upload a File", file_types=["txt"], file_count="1")
      upload_button.upload(upload_file, upload_button, file_output)
  with gr.Row():
    with gr.Column():
      extractive_question = gr.Textbox(label="question")
      extractive_answer = gr.Textbox(label="answer")
      extractive_btn = gr.Button("get Answer - extraction QA")
      # Distinct api_name per endpoint: both buttons previously registered
      # api_name="greet", which collides in Gradio's API namespace.
      extractive_btn.click(fn=question_answer, inputs=extractive_question,
                           outputs=extractive_answer, api_name="extractive_qa")
  with gr.Row():
    with gr.Column():
      generative_question = gr.Textbox(label="question")
      generative_answer = gr.Textbox(label="answer")
      generative_btn = gr.Button("get Answer - Generated QA")
      generative_btn.click(fn=question_answer_generated, inputs=generative_question,
                           outputs=generative_answer, api_name="generative_qa")

demo.launch(debug=True)