# Spaces:
# Runtime error
# Runtime error
import json | |
from pyserini.search.lucene import LuceneSearcher | |
from tqdm import tqdm | |
def convert_unicode_to_normal(data):
    """Round-trip a string, or every string in a list, through UTF-8.

    Each string is encoded to UTF-8 bytes and decoded back, so the result
    compares equal to the input for any string that is encodable as UTF-8.
    A string that cannot be encoded (e.g. one containing lone surrogates)
    makes ``str.encode`` raise ``UnicodeEncodeError``.

    Args:
        data: A ``str`` or a ``list`` of ``str``. An empty list is allowed.

    Returns:
        A ``str`` when given a ``str``; a new ``list`` of ``str`` when given
        a list.

    Raises:
        ValueError: If ``data`` is neither a ``str`` nor a ``list``, or if
            the list contains a non-string element.
    """
    if isinstance(data, str):
        return data.encode('utf-8').decode('utf-8')
    if isinstance(data, list):
        # Validate every element (not just the first) and do it with a real
        # exception: `assert` vanishes under `python -O`, and indexing
        # data[0] crashed with IndexError on an empty list.
        for sample in data:
            if not isinstance(sample, str):
                raise ValueError(
                    f"expected a list of str, found element of type "
                    f"{type(sample).__name__}"
                )
        return [sample.encode('utf-8').decode('utf-8') for sample in data]
    raise ValueError(
        f"expected str or list of str, got {type(data).__name__}"
    )
# --- experiment configuration -------------------------------------------
# K is interpolated into the run-file name below and echoed in the final
# recall report.
K = 30

# lucene
index_dir = "/root/indexes/index-wikipedia-dpr-20210120"
# bm25
runfile_path = f"runs/q=NQtest_c=wikidpr_m=bm25_k={K}.run"
qafile_path = "/root/nota-fairseq/examples/information_retrieval/open_domain_data/NQ/qa_pairs/test.jsonl"
logging_path = "logging_q=NQ_c=wiki_including_ans.jsonl"

# Build the searcher on top of the pre-built index directory; it is used
# further down to fetch passage text by passage id.
searcher = LuceneSearcher(index_dir=index_dir)
# v2. Read the QA file first and index it by qid, so each run-file line can
# be matched by lookup instead of relying on both files sharing one order
# (the run file is sorted by query name).
print("read qa file")
pair_by_qid = {}
# Explicit UTF-8: the file holds natural-language questions/answers, and the
# default encoding of open() depends on the machine's locale.
with open(qafile_path, 'r', encoding='utf-8') as fr_qa:
    for pair in tqdm(fr_qa):
        # Tolerate blank lines (e.g. a trailing newline at end of file),
        # which json.loads would otherwise reject.
        if not pair.strip():
            continue
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"]  # str, str, list
        pair_by_qid[qid] = {'query': query, 'answers': answers}
# Walk the run file line by line, fetch each retrieved passage, and log every
# (query, passage) pair whose passage text contains one of the gold answers.
# A query counts as a hit for recall if at least one of its passages matched.
print("check retrieved passage include answer")
# A set de-duplicates qids with O(1) membership tests; the previous list made
# every hit cost a linear scan over all qids collected so far.
qid_with_ans_in_retrieval = set()
# Explicit UTF-8 on both files: the log is written with ensure_ascii=False,
# so a non-UTF-8 locale default could raise UnicodeEncodeError on write.
with open(runfile_path, 'r', encoding='utf-8') as fr_run, \
        open(logging_path, 'w', encoding='utf-8') as fw_log:
    for result in tqdm(fr_run):
        # split() without an argument accepts tabs or repeated spaces as
        # separators and drops the trailing newline from the last field.
        fields = result.split()
        if not fields:
            continue  # skip blank lines
        if len(fields) != 6:  # qid q_type pid k score engine
            raise ValueError(f"malformed run line (expected 6 fields): {result!r}")
        qid_, pid = fields[0], fields[2]
        if qid_ not in pair_by_qid:
            raise KeyError(f"qid {qid_!r} from run file not found in QA file")
        query, answers = pair_by_qid[qid_]['query'], pair_by_qid[qid_]['answers']
        # get passage
        doc = searcher.doc(pid)
        if doc is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"passage id {pid!r} not found in index {index_dir!r}")
        psg_txt = json.loads(doc.raw())['contents'].strip()
        psg_txt = convert_unicode_to_normal(psg_txt)
        # check if passage contains answer (plain, case-sensitive substring match)
        for ans in answers:
            if ans in psg_txt:
                log_w = {
                    "qid": qid_,
                    "pid": pid,
                    "query": query,
                    "answer": ans,
                    "passage": psg_txt,
                }
                fw_log.write(json.dumps(log_w, ensure_ascii=False) + '\n')
                qid_with_ans_in_retrieval.add(qid_)
                break  # one matching answer per passage is enough

num_queries = len(pair_by_qid)
num_hits = len(qid_with_ans_in_retrieval)
# Guard against an empty QA file so the report does not die on division by zero.
recall = num_hits / num_queries * 100 if num_queries else 0.0
print(f"#qid in test set: {num_queries}, #qid having answer with retrieval(BM25, K={K}): {num_hits}, Recall = {recall}")
# v1 (disabled: kept only as a string literal, never executed)
"""
with open(runfile_path, 'r') as fr_run, open(qafile_path, 'r') as fr_qa:
    for pair in tqdm(fr_qa):
        pair_data = json.loads(pair)
        qid, query, answers = pair_data["qid"], pair_data["query"], pair_data["answers"] # str, str, list
        for k in range(K):
            result=fr_run.readline()
            print(result)
            fields = result.split(' ')
            assert(len(fields) == 6) # qid q_type pid k score engine
            qid_, pid = fields[0], fields[2]
            assert(qid == qid_), f"qid={qid}, qid_={qid_} should be same"
            # get passage
            psg_txt = searcher.doc(pid)
            psg_txt = psg_txt.raw()
            psg_txt = json.loads(psg_txt)
            psg_txt = psg_txt['contents'].strip()
            psg_txt = convert_unicode_to_normal(psg_txt)
            # check if passage contains answer
            if any([ans in psg_txt for ans in answers]):
                import pdb
                pdb.set_trace()
"""