"""
Most of the tokenization code here is copied from the Facebook DPR and DrQA
codebases to avoid adding an extra dependency.
"""
|
import argparse
import collections
import copy
import json
import logging
import os
import re
import unicodedata

import numpy as np
import regex

logger = logging.getLogger(__name__)

DIRNAME = os.path.dirname(os.path.abspath(__file__))
|
DATA_DIR = os.path.join(DIRNAME, '../../data')

ANNOTATIONS_TO_DOWNLOAD = [
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/nq-annotations.jsonl', 'nq-annotations.jsonl'),
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/triviaqa-annotations.jsonl', 'triviaqa-annotations.jsonl'),
    ('https://dl.fbaipublicfiles.com/qaoverlap/data/webquestions-annotations.jsonl', 'webquestions-annotations.jsonl'),
]

# Download any annotation file that is not already present.
for link, dest in ANNOTATIONS_TO_DOWNLOAD:
    if not os.path.exists(os.path.join(DATA_DIR, dest)):
        os.system(f'wget {link} -P {DATA_DIR}/')

ANNOTATION_PATHS = {
    'tqa': os.path.join(DATA_DIR, 'triviaqa-annotations.jsonl'),
    'nq': os.path.join(DATA_DIR, 'nq-annotations.jsonl'),
    'webquestions': os.path.join(DATA_DIR, 'webquestions-annotations.jsonl'),
}
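# Each annotation line is a JSON record; this script only relies on it having
# an 'id' field and a 'labels' list drawn from the annotation labels used in
# evaluate_retrieval below, e.g. (illustrative, format inferred from usage):
#   {"id": "0", "labels": ["question_overlap", "answer_overlap"]}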
|
|
|
class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5
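    # Each element of `data` is a tuple indexed by the constants above:
    # (token text, token text with trailing whitespace, (start, end) character
    # span, and, when annotators are enabled, POS, lemma, and NER tag).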
|
|
|
    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
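    # For example (illustrative): tokenizing 'a  b' with SimpleTokenizer and
    # calling untokenize() returns 'a  b', since each token keeps its trailing
    # whitespace in the TEXT_WS field.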
|
|
|
    def words(self, uncased=False):
        """Returns a list of the text of each token.

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.

        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.

        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.

        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]
|
    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return the ngram as a string vs list
        """
        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        # Spans are stored as [s, e) over the word list.
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams
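    # For example (illustrative): with words ['a', 'b', 'c'], ngrams(n=2)
    # returns ['a', 'a b', 'b', 'b c', 'c'] when as_strings=True.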
|
|
|
    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence of tokens sharing this tag
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups
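    # For example (illustrative): with NER tags ['PERSON', 'PERSON', 'O', 'LOC']
    # over the tokens of 'Barack Obama visited Paris', entity_groups() returns
    # [('Barack Obama', 'PERSON'), ('Paris', 'LOC')].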
|
|
|
|
|
class Tokenizer(object):
    """Base tokenizer class.

    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()
|
|
class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()
|
    def tokenize(self, text):
        data = []
        matches = list(self._regexp.finditer(text))
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Get whitespace: extend each token's span to the start of the
            # next token so the original text can be reconstructed.
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
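    # Example usage (illustrative):
    #   >>> SimpleTokenizer().tokenize('Who wrote Hamlet?').words(uncased=True)
    #   ['who', 'wrote', 'hamlet', '?']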
|
|
|
|
|
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(
            pattern,
            flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
        )
    except BaseException:
        # Malformed patterns are treated as not matching.
        return False
    return pattern.search(text) is not None
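# For example (illustrative): regex_match('born in 1809', r'\d{4}') is True,
# while regex_match('born in 1809', r'\d{5}') is False.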
|
|
|
|
|
def _normalize(text):
    return unicodedata.normalize('NFD', text)
|
|
def read_jsonl(path):
    with open(path) as f:
        return [json.loads(l) for l in f]


def read_annotations(annotations_data_path):
    return read_jsonl(annotations_data_path)
|
|
def has_answers(text, answers, tokenizer, regex=False):
    text = _normalize(text)
    if regex:
        # Answers are regex patterns; match any of them against the raw text.
        for ans in answers:
            ans = _normalize(ans)
            if regex_match(text, ans):
                return True
    else:
        # Answers are strings; look for a token-level subsequence match.
        text = tokenizer.tokenize(text).words(uncased=True)
        for ans in answers:
            ans = _normalize(ans)
            ans = tokenizer.tokenize(ans).words(uncased=True)
            for i in range(0, len(text) - len(ans) + 1):
                if ans == text[i: i + len(ans)]:
                    return True
    return False
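# For example (illustrative):
#   has_answers('The capital of France is Paris.', ['paris'], SimpleTokenizer())
# returns True via the uncased token match.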
|
|
|
|
|
def evaluate_retrieval(retrieval_file, topk, annotation_file, regex=False): |
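    """Compute top-k retrieval accuracy overall and per overlap annotation label.

    `retrieval_file` is expected to be a JSON dict mapping question ids to
    {'answers': [...], 'contexts': [{'text': ..., 'has_answer': ...}, ...]}
    (format inferred from the usage below; 'has_answer' is optional).
    """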
|
    tokenizer = SimpleTokenizer()
    with open(retrieval_file) as f:
        retrieval = json.load(f)
    annotations = read_annotations(annotation_file)
    annotation_ids = {int(a['id']): a['labels'] for a in annotations}
    accuracy = {k: collections.defaultdict(list) for k in topk}
    max_k = max(topk)
    annotation_labels = [
        'total',
        'no_overlap',
        'question_overlap',
        'no_question_overlap',
        'answer_overlap',
        'no_answer_overlap',
        'answer_overlap_only',
    ]
|
    for qid in retrieval.keys():
        answers = retrieval[qid]['answers']
        contexts = retrieval[qid]['contexts']
        # Rank of the first context containing an answer; max_k means no
        # answer was found in the top max_k contexts.
        has_ans_idx = max_k

        for idx, ctx in enumerate(contexts):
            if idx >= max_k:
                break
            if 'has_answer' in ctx:
                if ctx['has_answer']:
                    has_ans_idx = idx
                    break
            else:
                # The first line of a context is assumed to be its title;
                # answers are matched against the passage text only.
                text = ctx['text'].split('\n')[1]
                if has_answers(text, answers, tokenizer, regex):
                    has_ans_idx = idx
                    break

        for annotation_label in annotation_labels:
            if (annotation_label in annotation_ids[int(qid)]
                    or annotation_label == 'total'
                    or (annotation_label == 'no_overlap'
                        and 'no_question_overlap' in annotation_ids[int(qid)]
                        and 'no_answer_overlap' in annotation_ids[int(qid)])):
                for k in topk:
                    accuracy[k][annotation_label].append(0 if has_ans_idx >= k else 1)

    for k in topk:
        for annotation_label in annotation_labels:
            print(f'Top{k}\taccuracy: {np.mean(accuracy[k][annotation_label])}\t{annotation_label}')
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--retrieval', type=str, metavar='path',
                        help="Path to retrieval output file.")
    parser.add_argument('--topk', type=int, nargs='+', help="topk to evaluate")
    parser.add_argument('--regex', action='store_true', default=False,
                        help="Treat answers as regex patterns when matching.")
    parser.add_argument('--dataset_name', choices=['nq', 'tqa', 'webquestions'], type=str,
                        help='name of dataset to evaluate on')
    args = parser.parse_args()

    evaluate_retrieval(args.retrieval, args.topk, ANNOTATION_PATHS[args.dataset_name], args.regex)
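    # Example invocation (file names are illustrative):
    #   python evaluate_retrieval.py --retrieval retrieval_results.json \
    #       --topk 1 5 20 100 --dataset_name nq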
|
|