File size: 8,352 Bytes
ee21b96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    """Convert a marked-up sentence into a sample dict.

    The pronoun must be wrapped in square brackets (``[he]``); the query may
    optionally be wrapped in underscores (``_Bob_``). Span indices are word
    positions obtained by splitting the text before the marker on single
    spaces, so they line up with ``text.split(" ")``.

    Args:
        sentence: the marked-up sentence.

    Returns:
        A dict with ``idx`` (always 0), ``text`` (the sentence with all
        ``_``, ``[`` and ``]`` characters removed) and ``target`` holding
        ``span1_index``/``span1_text`` for the query (both ``None`` when the
        sentence has no underscore) and ``span2_index``/``span2_text`` for
        the pronoun.

    Raises:
        ValueError: if the sentence has no ``[...]`` pair, or contains only
            a single underscore (the tuple unpacking of ``split`` fails).
    """

    def word_index(prefix):
        # An empty prefix means the marker opens the sentence. Without this
        # guard "".split(" ") yields [""] and the index would come out as 1.
        if not prefix:
            return 0
        return len(prefix.rstrip().split(" "))

    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = word_index(prefix)
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = word_index(prefix)

    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }


def extended_noun_chunks(sentence):
    """Return the parser's noun chunks plus every run of same-tagged nouns.

    Besides the ``(start, end)`` ranges reported by ``sentence.noun_chunks``,
    each maximal run of consecutive tokens that all carry the ``NOUN`` tag,
    or all carry the ``PROPN`` tag, is added as its own range. Duplicates
    collapse, and the result is the list of ``sentence[start:end]`` slices
    in sorted ``(start, end)`` order.
    """
    ranges = set()
    for chunk in sentence.noun_chunks:
        ranges.add((chunk.start, chunk.end))

    nominal_tags = ("NOUN", "PROPN")
    run_begin = 0
    run_tag = None  # None while we are outside a noun run
    for position, token in enumerate(sentence):
        tag = token.pos_ if token.pos_ in nominal_tags else None
        if tag == run_tag:
            continue
        # The tag changed: close the run that just ended, if any ...
        if run_tag is not None:
            ranges.add((run_begin, position))
        # ... and open a new one when the current token is nominal.
        if tag is not None:
            run_begin = position
        run_tag = tag

    # A run that reaches the end of the sentence is still open here.
    if run_tag is not None:
        ranges.add((run_begin, len(sentence)))

    return [sentence[lo:hi] for lo, hi in sorted(ranges)]


def find_token(sentence, start_pos):
    """Return the first token of ``sentence`` whose ``idx`` equals ``start_pos``.

    Returns ``None`` when no token has that ``idx``.
    """
    matches = (tok for tok in sentence if tok.idx == start_pos)
    return next(matches, None)


def find_span(sentence, search_text, start=0):
    """Find ``search_text`` in ``sentence``, ignoring case.

    Candidate start tokens are taken from ``sentence[start:]``. A candidate
    matches when the lower-cased text from that token onwards begins with
    the lower-cased ``search_text`` and the match ends exactly at the end of
    some token; the slice of ``sentence`` covering those tokens is returned.
    Returns ``None`` when no candidate qualifies.
    """
    needle = search_text.lower()
    needle_len = len(needle)

    for first in sentence[start:]:
        tail = sentence[first.i :]
        if not tail.text.lower().startswith(needle):
            continue
        # Walk forward until a token ends exactly where the needle ends.
        for last in sentence[first.i :]:
            consumed = last.idx + len(last.text) - first.idx
            if consumed == needle_len:
                return sentence[first.i : last.i + 1]
    return None


@lru_cache(maxsize=1)
def get_detokenizer():
    """Return an English ``MosesDetokenizer``, built once and then cached.

    ``sacremoses`` is imported inside the function, so it is only required
    when the detokenizer is actually requested.
    """
    from sacremoses import MosesDetokenizer

    return MosesDetokenizer(lang="en")


@lru_cache(maxsize=1)
def get_spacy_nlp():
    """Return the loaded ``en_core_web_lg`` pipeline, built once and cached.

    The model package is imported inside the function, so it is only
    required when the pipeline is actually requested.
    """
    import en_core_web_lg

    return en_core_web_lg.load()


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    """Yield parsed samples from a JSON-lines file of pronoun-resolution data.

    Each non-blank line must be a JSON object with a ``text`` string and a
    ``target`` dict holding ``span1_text`` (the query, may be null),
    ``span2_text`` (the pronoun) and ``span2_index`` (the pronoun's position
    in ``text.split(" ")``). An optional ``label`` is passed through as-is
    (``None`` when absent). Samples whose query contains a newline are
    skipped.

    Args:
        input_fname: path of the JSON-lines file to read (opened as UTF-8).
        positive_only: when true, skip samples that carry a falsy ``label``.
        ngram_order: not used by this function; kept so existing callers
            keep working.
        eval: when true, yield ``(marked_text, label)`` where ``marked_text``
            is a string with the query wrapped in ``_`` and the pronoun
            wrapped in ``[]``. Otherwise yield
            ``(sentence, pronoun_span, query, label)`` where ``sentence`` is
            the document produced by the pipeline from ``get_spacy_nlp()``.

    Raises:
        Exception: if the pronoun is found neither at ``span2_index`` nor at
            the token right after it.
    """
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    def strip_pronoun(x):
        return x.rstrip('.,"')

    with open(input_fname, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # tolerate blank lines instead of failing in json.loads
                continue
            sample = json.loads(line)

            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith((".", ",")):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned by one; the bounds
                # check keeps a pronoun index on the last token from raising
                # IndexError instead of the intended error below
                next_idx = pronoun_idx + 1
                if (
                    next_idx < len(tokens)
                    and strip_pronoun(tokens[next_idx]) == pronoun
                ):
                    pronoun_idx = next_idx
                else:
                    raise Exception("Misaligned pronoun!")

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith((".", ",")):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith((".", ",")):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span: locate the token that starts at the
            # pronoun's character offset, then match the pronoun text there
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                # stitch the sentence back together in document order
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)


def winogrande_jsonl_iterator(input_fname, eval=False):
    """Yield ``(sentence, pronoun_span, query, cand)`` from a JSON-lines file.

    Each non-blank line must be a JSON object with ``sentence``, ``option1``
    and ``option2``; ``answer`` is also required unless ``eval`` is true.

    Args:
        input_fname: path of the JSON-lines file to read (opened as UTF-8).
        eval: when true, ``query`` is always ``option1`` and ``cand`` is
            ``option2``. Otherwise ``query`` is the option selected by
            ``answer`` (``option1`` when ``answer == "1"``, else ``option2``)
            and ``cand`` is the other one.

    Yields:
        Tuples ``(sentence, pronoun_span, query, cand)`` where
        ``pronoun_span`` is ``(i, i + 1)`` for the offset ``i`` of the first
        ``_`` in ``sentence``.

    Raises:
        ValueError: if a sentence contains no ``_``.
    """
    with open(input_fname, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # tolerate blank lines instead of failing in json.loads
                continue
            sample = json.loads(line)
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            blank = sentence.index("_")
            pronoun_span = (blank, blank + 1)

            # "answer" is only consulted outside eval mode
            if eval or sample["answer"] == "1":
                query, cand = option1, option2
            else:
                query, cand = option2, option1
            yield sentence, pronoun_span, query, cand


def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    """Drop pronoun chunks and/or chunks that coincide with a query string.

    Args:
        chunks: chunks to filter; each needs ``text``, and, when
            ``exclude_pronouns`` is set, ``lemma_`` plus iteration over
            tokens exposing ``pos_``.
        exclude_pronouns: drop chunks whose ``lemma_`` is ``"-PRON-"`` or
            whose tokens are all tagged ``PRON``.
        exclude_query: when given, drop chunks whose lower-cased text equals
            the lower-cased query; unless ``exact_match`` is set, also drop
            chunks where either string contains the other.
        exact_match: restrict the query comparison to equality.

    Returns:
        The surviving chunks in their original order. The input object
        itself is returned when no filter is active.
    """
    if exclude_pronouns:
        non_pronouns = []
        for chunk in chunks:
            if chunk.lemma_ == "-PRON-":
                continue
            if all(tok.pos_ == "PRON" for tok in chunk):
                continue
            non_pronouns.append(chunk)
        chunks = non_pronouns

    if exclude_query is None:
        return chunks

    banned = exclude_query.lower()

    def clashes(chunk):
        text = chunk.text.lower()
        if text == banned:
            return True
        if exact_match:
            return False
        # substring overlap in either direction counts as a clash
        return text in banned or banned in text

    return [chunk for chunk in chunks if not clashes(chunk)]