Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import torch | |
from transformers import RealmForOpenQA, RealmRetriever | |
# Checkpoint identifier shared by the retriever and the open-QA model.
model_name = "google/realm-orqa-nq-openqa"

# Run on the GPU when torch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The tokenizer is taken from the retriever, and the open-QA model is built
# around that same retriever instance.
retriever = RealmRetriever.from_pretrained(model_name)
tokenizer = retriever.tokenizer
openqa = RealmForOpenQA.from_pretrained(model_name, retriever=retriever)
openqa.to(device)

# Block count as loaded from the checkpoint; request handlers slice back to
# this size before appending user-supplied documents.
default_num_block_records = openqa.config.num_block_records
def add_additional_documents(openqa, additional_documents):
    """Replace the user-supplied documents appended to the default corpus.

    Every call starts from the first ``default_num_block_records`` entries of
    the retriever's records and of the model's block embeddings, so documents
    added by an earlier call are discarded rather than accumulated.

    Args:
        openqa: Model exposing ``retriever``, ``embedder``, ``block_emb`` and
            ``config.num_block_records``. It is modified in place.
        additional_documents: Newline-separated text, one document per line.
            Lines are stripped of surrounding whitespace and blank lines are
            ignored. If no non-blank line remains, the model is left with
            only the default records.

    Returns:
        None.
    """
    # A trailing newline or an empty line would otherwise become an empty
    # document that is embedded and added to the searchable corpus.
    stripped_lines = (line.strip() for line in additional_documents.split("\n"))
    documents = [doc for doc in stripped_lines if doc]

    retriever = openqa.retriever
    tokenizer = retriever.tokenizer

    base_records = retriever.block_records[:default_num_block_records]
    base_embeds = openqa.block_emb[:default_num_block_records]

    if not documents:
        # Nothing usable was supplied: fall back to the default corpus only.
        retriever.block_records = base_records
        openqa.block_emb = base_embeds
        openqa.config.num_block_records = default_num_block_records
        return

    # docs: stored as encoded bytes in an object array, appended after the defaults
    np_documents = np.array([doc.encode("utf-8") for doc in documents], dtype=object)
    retriever.block_records = np.concatenate((base_records, np_documents), axis=0)

    # embeds: project each new document with the model's embedder
    inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        projected_score = openqa.embedder(**inputs, return_dict=True).projected_score
    openqa.block_emb = torch.cat((base_embeds, projected_score), dim=0)

    openqa.config.num_block_records = default_num_block_records + len(documents)
def question_answer(question, additional_documents):
    """Answer ``question`` with the open-QA model, optionally over extra documents.

    Args:
        question: Question text; it is tokenized and passed to the model.
        additional_documents: Newline-separated documents to append to the
            default corpus. ``None``, an empty string or a whitespace-only
            string means only the default records are searched.

    Returns:
        The predicted answer ids decoded to a string by the tokenizer.
    """
    question_ids = tokenizer(question, return_tensors="pt").input_ids

    if additional_documents and additional_documents.strip():
        add_additional_documents(openqa, additional_documents)
    else:
        # The model is shared module-level state. Without this reset, documents
        # appended by an earlier request would still be searched when the
        # field is left blank.
        records = openqa.retriever.block_records
        openqa.retriever.block_records = records[:default_num_block_records]
        openqa.block_emb = openqa.block_emb[:default_num_block_records]
        openqa.config.num_block_records = default_num_block_records

    with torch.no_grad():
        outputs = openqa(input_ids=question_ids.to(device), return_dict=True)
    return tokenizer.decode(outputs.predicted_answer_ids)
# Five-line text box for the optional user documents; the placeholder text
# describes the expected one-document-per-line format.
additional_documents_input = gr.inputs.Textbox(
    lines=5,
    placeholder="Each line represents a document entry. Leave blank to use default wiki documents.",
)

# Two inputs (question text, documents box) feed question_answer; the result
# is shown in a single text box. Flagging is turned off.
iface = gr.Interface(
    fn=question_answer,
    inputs=["text", additional_documents_input],
    outputs=["textbox"],
    allow_flagging="never",
)

# NOTE(review): gr.inputs.Textbox and launch(enable_queue=...) look like an
# older Gradio API - confirm the Gradio version pinned for this app accepts them.
iface.launch(enable_queue=True)