Spaces:
Runtime error
Runtime error
import re | |
import torch | |
# Column names of the article-level KILT Wikipedia dataset.
kilt_wikipedia_columns = [
    "kilt_id",
    "wikipedia_id",
    "wikipedia_title",
    "text",
    "anchors",
    "categories",
    "wikidata_info",
    "history",
]

# Column names of the paragraph-level table; this is the same set of keys
# that articles_to_paragraphs emits for each paragraph row.
kilt_wikipedia_paragraph_columns = [
    "wikipedia_id",
    "start_paragraph_id",
    "start_character",
    "end_paragraph_id",
    "end_character",
    "title",
    "section",
    "text",
]
def clean_question(text):
    """Normalize a question string.

    Steps, in order: strip markdown URL references via ``cleanup_references``,
    delete the ``[deleted]`` marker, turn newlines into spaces, collapse every
    run of two or more whitespace characters into a single space, then
    lower-case and strip the result.

    Args:
        text: Raw question string.

    Returns:
        The normalized, lower-cased string.
    """
    result = cleanup_references(text)
    # Remove the marker *before* collapsing whitespace so the gap it leaves
    # is collapsed too ("foo [deleted] bar" -> "foo bar", not "foo  bar").
    result = result.replace("[deleted]", "")
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    return result.lower().strip()
def cleanup_references(text):
    """Remove markdown-style URL references (``_URL_<n>_`` placeholders) from *text*.

    Three passes are applied in order:

    1. Numbered citation links such as ``([1](_URL_2_), [2](_URL_1_))`` are
       removed entirely, including a preceding ``(``/whitespace run and any
       trailing ``,``/``)`` characters.
    2. Links with descriptive text, ``[link text](_URL_19_)``, keep the link
       text and lose the target.
    3. Any remaining bare ``_URL_<n>_`` placeholders are deleted.

    Whitespace left behind by the removals is not collapsed here.

    Args:
        text: Input string.

    Returns:
        The string with references removed.
    """
    # 1. URL reference where both the link text and URL are removed:
    #    ...the cornerstone of Lee's personal views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
    #    ...the cornerstone of Lee's personal views on slavery.
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text)
    # 2. URL reference where the link text is preserved but the URL removed:
    #    At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
    #    At the outbreak of the Civil War, Leyburn left his church and joined the South.
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", r"\1", result)
    # 3. Dangling placeholders. `\d+` (not `\d`) so multi-digit ids such as
    #    _URL_19_ are removed as well.
    result = re.sub(r"_URL_\d+_", "", result)
    return result
def clean_answer(text):
    """Normalize an answer/passage string and cap its length.

    Steps, in order: strip markdown URL references via ``cleanup_references``,
    delete ``BULLET::::-`` list markers, turn newlines into spaces, collapse
    every run of two or more whitespace characters into a single space, strip,
    and finally keep only the leading tokens via ``trim`` (default limit).

    Args:
        text: Raw answer string.

    Returns:
        The cleaned, possibly truncated string.
    """
    result = cleanup_references(text)
    # Delete markers before collapsing whitespace: removing them afterwards
    # left double spaces, which `trim` (splitting on " ") counted as words.
    result = result.replace("BULLET::::-", "")
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    return trim(result.strip())
def trim(text, word_count: int = 100):
    """Return at most the first ``word_count`` space-separated tokens of *text*.

    Tokens come from ``text.split(" ")``, so consecutive spaces yield empty
    tokens that count toward the limit; callers are expected to collapse
    whitespace first. Standard slice semantics apply to ``word_count``
    (a negative value drops tokens from the end).
    """
    tokens = text.split(" ")
    head = tokens[:word_count]
    return " ".join(head)
def articles_to_paragraphs(examples):
    """Explode a batch of article rows into one row per paragraph.

    Takes and returns a dict of column lists (presumably used as a
    ``datasets.Dataset.map(..., batched=True)`` function — confirm with caller).

    Args:
        examples: Mapping with ``"text"`` (each entry a mapping whose
            ``"paragraph"`` is a list of strings), plus ``"wikipedia_id"`` and
            ``"wikipedia_title"`` lists indexed in parallel with ``"text"``.

    Returns:
        Dict of equal-length lists keyed by ``wikipedia_id``, ``title``,
        ``section``, ``text``, ``start_paragraph_id``, ``end_paragraph_id``,
        ``start_character`` and ``end_character``. Each row covers exactly one
        paragraph (start and end paragraph ids are equal; characters span
        0..len). ``section`` is the most recent paragraph containing
        ``"Section::::"`` within the same article (a header paragraph is
        labelled with itself), or ``""`` before the first one.
    """
    columns = {
        "wikipedia_id": [],
        "title": [],
        "section": [],
        "text": [],
        "start_paragraph_id": [],
        "end_paragraph_id": [],
        "start_character": [],
        "end_character": [],
    }
    for article_pos, article in enumerate(examples["text"]):
        # The section label resets at the start of every article.
        current_section = ""
        for para_pos, paragraph in enumerate(article["paragraph"]):
            if "Section::::" in paragraph:
                current_section = paragraph
            columns["wikipedia_id"].append(examples["wikipedia_id"][article_pos])
            columns["title"].append(examples["wikipedia_title"][article_pos])
            columns["section"].append(current_section)
            columns["text"].append(paragraph)
            columns["start_paragraph_id"].append(para_pos)
            columns["end_paragraph_id"].append(para_pos)
            columns["start_character"].append(0)
            columns["end_character"].append(len(paragraph))
    return columns
def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
    """Build a KILT-format data point from an ELI5 example and retrieved passages.

    See https://github.com/facebookresearch/KILT#kilt-data-format

    Args:
        eli5_example: Mapping providing ``"q_id"``, ``"title"`` and
            ``"answers"``; ``eli5_example["answers"]["text"]`` is iterated as
            the list of answers.
        columns: Keys copied from every passage in ``wiki_passages``. Every
            passage must contain all of them (``KeyError`` otherwise). The
            list should include ``"text"`` and every key used for the
            provenance entries below.
        wiki_passages: Iterable of passage mappings.
        min_length: Passages whose ``"text"`` has ``min_length`` or fewer
            whitespace-separated tokens are dropped.
        topk: Maximum number of passages kept, taken in input order after
            the length filter.

    Returns:
        Dict with keys ``"id"``, ``"input"``, ``"output"`` and ``"meta"``.
        For each answer, ``"output"`` gets an ``{"answer": ...}`` element
        followed by a ``{"provenance": [...]}`` element; every provenance
        element holds its own freshly built list of passage dicts.
    """
    # Project every passage onto `columns` (eagerly, as a list), then keep the
    # first `topk` passages that are long enough.
    projected = [{key: passage[key] for key in columns} for passage in wiki_passages]
    kept = [res for res in projected if len(res["text"].split()) > min_length][:topk]

    output = []
    for answer in eli5_example["answers"]["text"]:
        output.append({"answer": answer})
        output.append({"provenance": [
            # evidence set for the answer from the KILT knowledge source
            {
                "wikipedia_id": r["wikipedia_id"],  # *mandatory*
                "title": r["title"],
                "section": r["section"],
                "start_paragraph_id": r["start_paragraph_id"],
                "start_character": r["start_character"],
                "end_paragraph_id": r["end_paragraph_id"],
                "end_character": r["end_character"],
                "text": r["text"],
                "bleu_score": None,  # wrt original evidence
                "meta": None,  # dataset/task specific
            }
            for r in kept
        ]})
    return {
        "id": eli5_example["q_id"],
        "input": eli5_example["title"],
        "output": output,  # each element is an answer or provenance (can have multiple of each)
        "meta": None,  # dataset/task specific
    }
def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
    """Encode questions and return their pooled embeddings as a NumPy array.

    The tokenizer is called with fixed-length padding/truncation to
    ``max_length`` and PyTorch tensors; ``input_ids`` and ``attention_mask``
    are moved to ``device`` and passed positionally to ``question_model``
    under ``torch.no_grad()``. The model output's ``pooler_output`` is
    returned on the CPU.

    Args:
        question_model: Callable returning an object with ``pooler_output``;
            presumably already on ``device`` — confirm with caller.
        question_tokenizer: Hugging Face-style tokenizer callable.
        questions: Text(s) accepted by the tokenizer (presumably a list of str).
        max_length: Padding/truncation length in tokens.
        device: Target device for the input tensors.

    Returns:
        ``numpy.ndarray`` of pooled outputs (presumably (batch, hidden) — verify).
    """
    encoded = question_tokenizer(
        questions,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    with torch.no_grad():
        pooled = question_model(input_ids, attention_mask).pooler_output
    return pooled.cpu().numpy()
def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
    """Encode ``passages["text"]`` and return ``{"embeddings": ndarray}``.

    The tokenizer is called with fixed-length padding/truncation to
    ``max_length`` and PyTorch tensors; ``input_ids`` and ``attention_mask``
    are moved to ``device`` and passed positionally to ``ctx_model`` under
    ``torch.no_grad()``. The model output's ``pooler_output`` is returned on
    the CPU. The dict return value presumably suits a batched
    ``datasets.Dataset.map`` call — confirm with caller.

    Args:
        ctx_model: Callable returning an object with ``pooler_output``;
            presumably already on ``device`` — confirm with caller.
        ctx_tokenizer: Hugging Face-style tokenizer callable.
        passages: Mapping whose ``"text"`` entry is fed to the tokenizer.
        max_length: Padding/truncation length in tokens.
        device: Target device for the input tensors.

    Returns:
        Dict with a single ``"embeddings"`` key holding a ``numpy.ndarray``.
    """
    encoded = ctx_tokenizer(
        passages["text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    with torch.no_grad():
        pooled = ctx_model(input_ids, attention_mask).pooler_output
    embeddings = pooled.cpu().numpy()
    return {"embeddings": embeddings}