|
import argparse |
|
import json |
|
import os |
|
|
|
import torch |
|
from datasets import load_dataset |
|
from tqdm.auto import tqdm |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder |
|
|
|
from common import articles_to_paragraphs, kilt_wikipedia_columns |
|
from common import kilt_wikipedia_paragraph_columns as columns |
|
|
|
|
|
def _read_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank lines are skipped, so a trailing newline, a stray empty line, or an
    empty file does not raise ``json.JSONDecodeError``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _columns_to_rows(batch, column_names):
    """Transpose a column-oriented batch into a list of per-example dicts.

    ``batch`` maps each name in ``column_names`` to a sequence of values, one
    per example. The number of rows is taken from the data itself instead of
    being assumed, so a batch shorter than requested does not raise
    ``IndexError``.
    """
    return [
        dict(zip(column_names, values))
        for values in zip(*(batch[name] for name in column_names))
    ]


def eval_generate(args):
    """Retrieve Wikipedia passages for each KILT query and generate an answer.

    Pipeline:
      1. Load the DPR question encoder (``args.question_encoder_name``) and the
         ``vblagoje/bart_eli5`` seq2seq model.
      2. Load ``kilt_wikipedia``, expand articles into paragraphs with
         ``articles_to_paragraphs``, keep paragraphs longer than
         ``min_chars_per_passage`` characters, and attach the faiss index read
         from ``args.index_file_name`` as the ``"embeddings"`` index.
      3. For every record in the JSONL file ``args.kilt_input_file``, retrieve
         ``topk`` passages, keep at most ``topk // 3`` of those with more than
         ``min_snippet_length`` words, and generate an answer with beam search
         conditioned on the question plus those passages.
      4. Write one KILT-format JSON object per line to ``args.kilt_output_file``.

    Args:
        args: ``argparse.Namespace`` with ``question_encoder_name``,
            ``index_file_name``, ``kilt_input_file`` and ``kilt_output_file``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    _ = eli5_model.eval()

    min_snippet_length = 20  # a passage must have more words than this to be used
    topk = 21  # passages requested from the index per query
    min_chars_per_passage = 200  # a paragraph must be longer than this to be searchable

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=256,
                                                   cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    # For `load_faiss_index`, device=None keeps the index on CPU and an int moves
    # it to that GPU. The original hard-coded GPU 0, which fails on CPU-only
    # hosts even though the models fall back to "cpu".
    faiss_device = 0 if device == "cuda" else None
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=faiss_device)

    def embed_questions_for_retrieval(questions):
        """Encode a list of question strings into DPR pooled embeddings (numpy)."""
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()

    def query_index(question):
        """Return up to ``topk`` nearest passages for ``question`` as dicts keyed by ``columns``."""
        question_embedding = embed_questions_for_retrieval([question])
        _scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples(
            "embeddings", question_embedding, k=topk)
        # The index may hand back fewer than `topk` examples; count the rows
        # actually returned rather than indexing up to `topk`.
        return _columns_to_rows(wiki_passages, columns)

    def create_kilt_datapoint(q_id, query, answer, res_list):
        """Build one KILT prediction record whose provenance is the passages in ``res_list``."""
        provenance = [{
            "wikipedia_id": r["wikipedia_id"],
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,
            "meta": None
        } for r in res_list]

        output = [{"answer": answer, "provenance": provenance}]

        return {"id": q_id,
                "input": query,
                "output": output,
                "meta": None
                }

    kilt_output = []
    kilt_items = _read_jsonl(args.kilt_input_file)
    for item in tqdm(kilt_items, desc="Creating KILT response document"):
        query = item["input"]
        res_list = query_index(query)

        # Drop very short snippets, then keep at most a third of the retrieved set.
        res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:topk // 3]
        documents = [res["text"] for res in res_list]
        conditioned_doc = "<P> " + " <P> ".join(documents)

        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

        model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
        generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
                                                        attention_mask=model_input["attention_mask"].to(device),
                                                        min_length=50,
                                                        max_length=250,
                                                        do_sample=False,
                                                        early_stopping=True,
                                                        num_beams=8,
                                                        temperature=1.0,
                                                        top_k=None,
                                                        top_p=None,
                                                        no_repeat_ngram_size=3,
                                                        num_return_sequences=1)
        answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=True)

        kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
        kilt_output.append(kilt_example)

    with open(args.kilt_output_file, "w", encoding="utf-8") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str,
                        help="KILT-format JSONL file with the questions to answer")
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str,
                        help="Where to write the KILT-format JSONL predictions")
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )
    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    args = parser.parse_args()

    # Use parser.error instead of assert: asserts are stripped under `python -O`.
    # Both paths are checked before eval_generate so a bad path fails fast,
    # before the expensive Wikipedia load/map/filter work starts.
    if not os.path.isfile(args.kilt_input_file):
        parser.error(f"Input file {args.kilt_input_file} couldn't be loaded")
    if not os.path.isfile(args.index_file_name):
        parser.error(f"Faiss index file {args.index_file_name} couldn't be found")

    eval_generate(args)
|
|