# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    """Convert a sentence with `_query_` and `[pronoun]` markup into a
    SuperGLUE WSC-style JSON dict."""
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }
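
# Usage sketch (hypothetical example sentence, not taken from the original data):
#
# >>> convert_sentence_to_json(
# ...     "The trophy does not fit in the _brown suitcase_ because [it] is too big."
# ... )
# {'idx': 0,
#  'text': 'The trophy does not fit in the brown suitcase because it is too big.',
#  'target': {'span1_index': 7, 'span1_text': 'brown suitcase',
#             'span2_index': 10, 'span2_text': 'it'}}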


def extended_noun_chunks(sentence):
    """Return spaCy noun chunks plus maximal runs of NOUN/PROPN tokens,
    sorted by position in the sentence."""
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
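
# Illustration (hedged: the exact chunks depend on the spaCy model's parses
# and POS tags, so the output below is indicative only):
#
# >>> doc = get_spacy_nlp()("The city councilmen refused the demonstrators a permit.")
# >>> [span.text for span in extended_noun_chunks(doc)]
# ['The city councilmen', 'city councilmen', 'the demonstrators', 'demonstrators', ...]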


def find_token(sentence, start_pos):
    """Return the token in `sentence` whose character offset is `start_pos`,
    or None if no token starts there."""
    found_tok = None
    for tok in sentence:
        if tok.idx == start_pos:
            found_tok = tok
            break
    return found_tok


def find_span(sentence, search_text, start=0):
    """Return the token span whose surface text matches `search_text`
    (case-insensitive), searching from token index `start`; None if absent."""
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i : next_tok.i + 1]
                    return span
    return None
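
# Usage sketch (hypothetical sentence; requires the en_core_web_lg model):
#
# >>> doc = get_spacy_nlp()("Jane gave Joan candy because she was hungry.")
# >>> find_span(doc, "joan candy").text
# 'Joan candy'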


# cache a single instance of the (expensive-to-load) detokenizer and spaCy model
@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer
    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg
    nlp = en_core_web_lg.load()
    return nlp


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    """Iterate over a SuperGLUE WSC-format .jsonl file, yielding
    (spaCy Doc, pronoun span, query text, label) tuples, or
    (marked-up sentence, label) pairs when eval=True."""
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue
            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)
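
# Usage sketch (hedged: "val.jsonl" is a placeholder path to a SuperGLUE
# WSC-format file; requires sacremoses and the en_core_web_lg spaCy model):
#
# for doc, pronoun_span, query, label in jsonl_iterator("val.jsonl"):
#     ...  # doc is a spaCy Doc, pronoun_span a spaCy Span, query a string
#
# for marked_sentence, label in jsonl_iterator("val.jsonl", eval=True):
#     ...  # the query is wrapped in _..._ and the pronoun in [...]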


def winogrande_jsonl_iterator(input_fname, eval=False):
    """Iterate over a Winogrande-format .jsonl file, yielding
    (sentence, pronoun character span, query option, candidate option)."""
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1
            yield sentence, pronoun_span, query, cand
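
# Usage sketch (hedged: "train.jsonl" is a placeholder for a Winogrande file
# with "sentence", "option1", "option2" and, for training splits, "answer"):
#
# for sentence, pronoun_span, query, cand in winogrande_jsonl_iterator("train.jsonl"):
#     ...  # pronoun_span is the character span of the "_" placeholder;
#          # query is the correct option and cand the other one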


def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    """Filter a list of spaCy spans, optionally dropping pure-pronoun chunks
    and chunks that (partially or exactly) match `exclude_query`."""
    if exclude_pronouns:
        chunks = [
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
        excl_txt = [exclude_query.lower()]
        filtered_chunks = []
        for chunk in chunks:
            lower_chunk = chunk.text.lower()
            found = False
            for excl in excl_txt:
                if (
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:
                filtered_chunks.append(chunk)
        chunks = filtered_chunks

    return chunks
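
# Usage sketch (hedged: chunk boundaries and POS tags depend on the spaCy
# model, so the resulting candidate list is indicative only):
#
# doc = get_spacy_nlp()("The trophy does not fit in the suitcase because it is too big.")
# candidates = filter_noun_chunks(
#     extended_noun_chunks(doc),
#     exclude_pronouns=True,        # drop chunks that are only pronouns, e.g. "it"
#     exclude_query="the trophy",   # drop chunks overlapping the query text
# )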