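"""Generate KILT-formatted answers for ELI5 questions.

For each question in the KILT input file, retrieve supporting Wikipedia
paragraphs with a DPR question encoder and a FAISS index, generate a
long-form answer with the vblagoje/bart_eli5 model, and write the
predictions to a KILT-format JSONL file.
"""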
import argparse
import json
import os

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder

from common import articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def eval_generate(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DPR question encoder used to embed queries for dense retrieval
    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    # BART model fine-tuned on ELI5 for long-form answer generation
    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    _ = eli5_model.eval()

    min_snippet_length = 20
    topk = 21
    min_chars_per_passage = 200
    # expand KILT Wikipedia articles into individual paragraphs
    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(
        articles_to_paragraphs,
        batched=True,
        remove_columns=kilt_wikipedia_columns,
        batch_size=256,
        cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
        desc="Expanding wiki articles into paragraphs",
    )

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    # attach the precomputed passage embeddings as a FAISS index on GPU 0
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
    def embed_questions_for_retrieval(questions):
        # encode questions with the DPR question encoder and return dense vectors
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()
    def query_index(question):
        # retrieve the top-k nearest passages for a single question
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples(
            "embeddings", question_embedding, k=topk)

        # convert the column-oriented search result into one dict per retrieved passage
        column_values = [wiki_passages[k] for k in columns]
        retrieved_examples = []
        for i in range(topk):
            retrieved_examples.append({k: column_values[j][i] for j, k in enumerate(columns)})
        return retrieved_examples
    def create_kilt_datapoint(q_id, query, answer, res_list):
        # make a KILT data point
        # see https://github.com/facebookresearch/KILT#kilt-data-format
        provenance = [{
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list]
        output = [{"answer": answer, "provenance": provenance}]
        return {"id": q_id,
                "input": query,
                "output": output,  # each element is an answer or provenance (can have multiple of each)
                "meta": None  # dataset/task specific
                }
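    # Each line written to the output file is one such KILT data point serialized as JSON,
    # e.g. (structure only, values illustrative):
    #   {"id": "...", "input": "...", "output": [{"answer": "...", "provenance": [...]}], "meta": null}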
    kilt_output = []
    with open(args.kilt_input_file, "r") as f:
        kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]
        progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
        for idx, item in enumerate(kilt_items):
            query = item["input"]

            # retrieve passages and keep the top third that are long enough to be useful
            res_list = query_index(query)
            res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:topk // 3]
            documents = [res["text"] for res in res_list]

            # condition the generator on the question followed by the retrieved passages
            conditioned_doc = "<P> " + " <P> ".join([d for d in documents])
            query_and_docs = "question: {} context: {}".format(query, conditioned_doc)
            model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")

            generated_answers_encoded = eli5_model.generate(
                input_ids=model_input["input_ids"].to(device),
                attention_mask=model_input["attention_mask"].to(device),
                min_length=50,
                max_length=250,
                do_sample=False,
                early_stopping=True,
                num_beams=8,
                temperature=1.0,
                top_k=None,
                top_p=None,
                no_repeat_ngram_size=3,
                num_return_sequences=1,
            )
            answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                                 clean_up_tokenization_spaces=True)
            kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
            kilt_output.append(kilt_example)
            progress_bar.update(1)

    # write one KILT-format JSON object per line
    with open(args.kilt_output_file, "w") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )
    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    args = parser.parse_args()
    assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} could not be found"
    eval_generate(args)
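# Example invocation (the script file name is illustrative; the KILT input file and the
# FAISS index must already exist at the given paths):
#
#   python eval_generate.py \
#       --kilt_input_file ./eli5-dev-kilt.jsonl \
#       --index_file_name ../data/kilt_dpr_wikipedia_first.faiss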