from abc import ABC, abstractmethod
import dataclasses
from typing import Any, Iterable, List, Optional, Tuple, Union

import nltk
import numpy as np


class BaseInputExample(ABC):
    """Parser input for a single sentence (abstract interface)."""

    # Subclasses must define the following attributes or properties.
    # `words` is a list of unicode representations for each word in the sentence
    # and `space_after` is a list of booleans that indicate whether there is
    # whitespace after a word. Together, these should form a reversible
    # tokenization of raw text input. `tree` is an optional gold parse tree.
    words: List[str]
    space_after: List[bool]
    tree: Optional[nltk.Tree]

    @abstractmethod
    def leaves(self) -> Optional[List[str]]:
        """Returns leaves to use in the parse tree.

        While `words` must be raw unicode text, these should be whatever is
        standard for the treebank. For example, '(' in words might correspond to
        '-LRB-' in leaves, and leaves might include other transformations such
        as transliteration.
        """
        pass

    @abstractmethod
    def pos(self) -> Optional[List[Tuple[str, str]]]:
        """Returns a list of (leaf, part-of-speech tag) tuples."""
        pass
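

# A minimal concrete input type, shown as a hedged sketch: the class name
# SimpleInputExample is illustrative and not part of this interface. It keeps
# raw words as-is, so leaves() returns them unchanged; a treebank-aware
# implementation might instead apply escaping such as '(' -> '-LRB-'.
@dataclasses.dataclass
class SimpleInputExample(BaseInputExample):
    words: List[str]
    space_after: List[bool]
    tree: Optional[nltk.Tree] = None

    def leaves(self) -> Optional[List[str]]:
        return list(self.words)

    def pos(self) -> Optional[List[Tuple[str, str]]]:
        # No gold part-of-speech tags in this sketch; parsers that predict
        # their own tags do not need them.
        return None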


@dataclasses.dataclass
class CompressedParserOutput:
    """Parser output, encoded as a collection of numpy arrays.

    By default, a parser will return nltk.Tree objects. These have much nicer
    APIs than the CompressedParserOutput class, and the code involved is simpler
    and more readable. As a trade-off, code dealing with nltk.Tree objects is
    slower: the nltk.Tree type itself has some overhead, and algorithms dealing
    with it are implemented in pure Python as opposed to C or even CUDA. The
    CompressedParserOutput type is an alternative that has some optimizations
    for the sole purpose of speeding up inference.

    If trying a new parser type for research purposes, it's safe to ignore this
    class and the return_compressed argument to parse(). If the parser works
    well and is being released, the return_compressed argument can then be added
    with a dedicated fast implementation, or simply by using the from_tree
    method defined below.
    """

    # A parse tree is represented as a set of constituents. In the case of
    # non-binary trees, only the labeled non-terminal nodes are included: there
    # are no dummy nodes inserted for binarization purposes. However, single
    # words are always included in the set of constituents, and they may have a
    # null label if there is no phrasal category above the part-of-speech tag.
    # All constituents are sorted according to pre-order traversal, and each has
    # an associated start (the index of the first word in the constituent), end
    # (1 + the index of the last word in the constituent), and label (index
    # associated with an external label_vocab dictionary.) These are then stored
    # in three numpy arrays:
    starts: Iterable[int]  # Must be a numpy array
    ends: Iterable[int]  # Must be a numpy array
    labels: Iterable[int]  # Must be a numpy array
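
    # Hedged worked example (the vocabulary below is illustrative): the tree
    #   (TOP (S (NP (PRP I)) (VP (VBP eat))))
    # with label_vocab {"": 0, "S": 1, "NP": 2, "VP": 3} compresses to
    #   starts = [0, 0, 1], ends = [2, 1, 2], labels = [1, 2, 3]
    # i.e. the S span covers both words, while the one-word NP and VP spans
    # carry the phrasal labels above their part-of-speech tags.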

    # Part-of-speech tag ids as output by the parser (may be None if the
    # parser does not do POS tagging). These indices are associated with an
    # external tag_vocab dictionary.
    tags: Optional[Iterable[int]] = None  # Must be None or a numpy array

    def without_predicted_tags(self):
        return dataclasses.replace(self, tags=None)

    def with_tags(self, tags):
        return dataclasses.replace(self, tags=tags)

    @classmethod
    def from_tree(
        cls, tree: nltk.Tree, label_vocab: dict, tag_vocab: Optional[dict] = None
    ) -> "CompressedParserOutput":
        num_words = len(tree.leaves())
        # After unary chains are collapsed, a tree over N words has at most
        # 2N - 1 constituents, so 2N slots are always sufficient.
        starts = np.empty(2 * num_words, dtype=int)
        ends = np.empty(2 * num_words, dtype=int)
        labels = np.empty(2 * num_words, dtype=int)

        def helper(tree, start, write_idx):
            nonlocal starts, ends, labels
            label = []
            # Collapse unary chains (e.g. S -> VP) into a single constituent
            # whose label joins the chain with "::"; the TOP node is dropped.
            while len(tree) == 1 and not isinstance(tree[0], str):
                if tree.label() != "TOP":
                    label.append(tree.label())
                tree = tree[0]

            if len(tree) == 1 and isinstance(tree[0], str):
                starts[write_idx] = start
                ends[write_idx] = start + 1
                labels[write_idx] = label_vocab["::".join(label)]
                return start + 1, write_idx + 1

            label.append(tree.label())
            starts[write_idx] = start
            labels[write_idx] = label_vocab["::".join(label)]

            end = start
            new_write_idx = write_idx + 1
            for child in tree:
                end, new_write_idx = helper(child, end, new_write_idx)

            ends[write_idx] = end
            return end, new_write_idx

        _, num_constituents = helper(tree, 0, 0)
        starts = starts[:num_constituents]
        ends = ends[:num_constituents]
        labels = labels[:num_constituents]

        if tag_vocab is None:
            tags = None
        else:
            tags = np.array([tag_vocab[tag] for _, tag in tree.pos()], dtype=int)

        return cls(starts=starts, ends=ends, labels=labels, tags=tags)

    def to_tree(
        self, leaves, label_from_index: dict, tag_from_index: Optional[dict] = None
    ) -> nltk.Tree:
        if self.tags is not None:
            if tag_from_index is None:
                raise ValueError(
                    "tags_from_index is required to convert predicted pos tags"
                )
            predicted_tags = [tag_from_index[i] for i in self.tags]
            assert len(leaves) == len(predicted_tags)
            leaves = [
                nltk.Tree(tag, [leaf[0] if isinstance(leaf, tuple) else leaf])
                for tag, leaf in zip(predicted_tags, leaves)
            ]
        else:
            leaves = [
                nltk.Tree(leaf[1], [leaf[0]])
                if isinstance(leaf, tuple)
                else (nltk.Tree("UNK", [leaf]) if isinstance(leaf, str) else leaf)
                for leaf in leaves
            ]

        # Constituents are stored in pre-order, so a single shared cursor
        # (idx) advances through the arrays as the recursion rebuilds nesting.
        idx = -1

        def helper():
            nonlocal idx
            idx += 1
            i, j, label = (
                self.starts[idx],
                self.ends[idx],
                label_from_index[self.labels[idx]],
            )
            if (i + 1) >= j:
                children = [leaves[i]]
            else:
                children = []
                while (
                    (idx + 1) < len(self.starts)
                    and i <= self.starts[idx + 1]
                    and self.ends[idx + 1] <= j
                ):
                    children.extend(helper())

            if label:
                for sublabel in reversed(label.split("::")):
                    children = [nltk.Tree(sublabel, children)]

            return children

        children = helper()
        return nltk.Tree("TOP", children)
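

# Roundtrip sketch (hedged; the tree and vocabulary are illustrative):
#
#   label_vocab = {"": 0, "S": 1, "NP": 2, "VP": 3}
#   tree = nltk.Tree.fromstring("(TOP (S (NP (PRP I)) (VP (VBP eat))))")
#   compressed = CompressedParserOutput.from_tree(tree, label_vocab)
#   label_from_index = {v: k for k, v in label_vocab.items()}
#   assert compressed.to_tree(tree.pos(), label_from_index) == tree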


class BaseParser(ABC):
    """Parser (abstract interface)"""

    @classmethod
    @abstractmethod
    def from_trained(
        cls,
        model_name: str,
        config: Optional[dict] = None,
        state_dict: Optional[dict] = None,
    ) -> "BaseParser":
        """Load a trained parser."""
        pass

    @abstractmethod
    def parallelize(self, *args, **kwargs):
        """Spread out pre-trained model layers across GPUs."""
        pass

    @abstractmethod
    def parse(
        self,
        examples: Iterable[BaseInputExample],
        return_compressed: bool = False,
        return_scores: bool = False,
        subbatch_max_tokens: Optional[int] = None,
    ) -> Union[Iterable[nltk.Tree], Iterable[Any]]:
        """Parse sentences."""
        pass

    @abstractmethod
    def encode_and_collate_subbatches(
        self, examples: List[BaseInputExample], subbatch_max_tokens: int
    ) -> List[dict]:
        """Split batch into sub-batches and convert to tensor features"""
        pass

    @abstractmethod
    def compute_loss(self, batch: dict):
        """Compute the loss for a single collated sub-batch."""
        pass
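

# Usage sketch for a hypothetical concrete subclass (the name MyParser and
# the model path below are placeholders, not part of this module):
#
#   parser = MyParser.from_trained("path/to/model")
#   for tree in parser.parse(examples):  # examples: list of BaseInputExample
#       print(tree.pformat())
#
# With return_compressed=True, parse() yields CompressedParserOutput objects
# instead, which can be decoded via to_tree() as shown above.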