import argparse |
import json |
import os |
import torch |
from datasets import load_dataset |
from tqdm.auto import tqdm |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder |
from common import articles_to_paragraphs, kilt_wikipedia_columns |
from common import kilt_wikipedia_paragraph_columns as columns |
def eval_generate(args):
    """Answer every question of a KILT JSONL file with retrieved passages and a generated answer.

    For each input item the question is embedded with the DPR question
    encoder, the nearest Wikipedia paragraphs are fetched from the FAISS
    index, short snippets are dropped, and the remaining passages are fed
    together with the question to the 'vblagoje/bart_eli5' seq2seq model.
    One JSON object per question is written to the output file.

    Args:
        args: Namespace providing ``question_encoder_name``,
            ``index_file_name``, ``kilt_input_file`` and
            ``kilt_output_file``.
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    question_model.eval()

    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    eli5_model.eval()

    min_snippet_length = 20  # minimum number of whitespace-separated words per kept passage
    topk = 21  # number of passages requested from the index per question
    min_chars_per_passage = 200

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(
        articles_to_paragraphs,
        batched=True,
        remove_columns=kilt_wikipedia_columns,
        batch_size=256,
        cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
        desc="Expanding wiki articles into paragraphs",
    )

    # Drop paragraphs whose character span is not longer than the minimum.
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    # Put the index on GPU 0 only when CUDA is actually available; ``None``
    # keeps it on the CPU so the script also runs on CPU-only machines.
    faiss_device = 0 if use_cuda else None
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=faiss_device)

    def embed_questions_for_retrieval(questions):
        """Return the question encoder's pooled output for ``questions`` as a NumPy array."""
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True,
                                   return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()

    def query_index(question):
        """Return the nearest passages for ``question`` as a list of per-passage dicts."""
        question_embedding = embed_questions_for_retrieval([question])
        _, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples(
            "embeddings", question_embedding, k=topk)
        # ``wiki_passages`` maps column name -> list of values. Transpose it into
        # one dict per hit, iterating over the hits actually returned so that a
        # result shorter than ``topk`` does not raise IndexError.
        column_values = [wiki_passages[name] for name in columns]
        return [dict(zip(columns, row)) for row in zip(*column_values)]

    def create_kilt_datapoint(q_id, query, answer, res_list):
        """Build the output record for one question from its answer and retrieved passages."""
        provenance = [{
            "wikipedia_id": r["wikipedia_id"],
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,
            "meta": None,
        } for r in res_list]

        output = [{"answer": answer, "provenance": provenance}]

        return {
            "id": q_id,
            "input": query,
            "output": output,
            "meta": None,
        }

    # One JSON object per line; blank lines (and an empty file) are tolerated.
    with open(args.kilt_input_file, "r", encoding="utf-8") as f:
        kilt_items = [json.loads(line) for line in f if line.strip()]

    max_passages = topk // 3  # only the first third of the retrieved passages is used as context

    kilt_output = []
    for item in tqdm(kilt_items, desc="Creating KILT response document"):
        query = item["input"]
        res_list = query_index(query)

        # Keep passages with enough words, then cap how many are passed on.
        res_list = [res for res in res_list
                    if len(res["text"].split()) > min_snippet_length][:max_passages]
        documents = [res["text"] for res in res_list]
        conditioned_doc = "<P> " + " <P> ".join(documents)

        query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

        model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
        generated_answers_encoded = eli5_model.generate(
            input_ids=model_input["input_ids"].to(device),
            attention_mask=model_input["attention_mask"].to(device),
            min_length=50,
            max_length=250,
            do_sample=False,
            early_stopping=True,
            num_beams=8,
            temperature=1.0,
            top_k=None,
            top_p=None,
            no_repeat_ngram_size=3,
            num_return_sequences=1,
        )
        answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=True)

        kilt_output.append(create_kilt_datapoint(item["id"], query, answer[0], res_list))

    with open(args.kilt_output_file, "w", encoding="utf-8") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )
    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )
    args = parser.parse_args()

    # Validate explicitly instead of with ``assert``: assertions are removed
    # under ``python -O``, which would let a missing file slip through until
    # after the models and dataset have been loaded.
    if not os.path.isfile(args.kilt_input_file):
        parser.error(f"Input file {args.kilt_input_file} couldn't be loaded")

    eval_generate(args)