# Spaces: Sleeping
# File size: 8,009 Bytes
# 8778cfe
from abc import ABC, abstractmethod
import dataclasses
from typing import Any, Iterable, List, Optional, Tuple, Union
import nltk
import numpy as np
class BaseInputExample(ABC):
    """Abstract interface for one sentence of parser input.

    Concrete subclasses provide the following, either as plain attributes or
    as properties:

    * ``words`` -- the raw unicode text of each word in the sentence.
    * ``space_after`` -- for each word, whether whitespace follows it.
      Combined with ``words``, this must give a reversible tokenization of
      the raw input text.
    * ``tree`` -- an optional gold parse tree.
    """

    words: List[str]  # raw unicode text, one entry per word
    space_after: List[bool]  # per word: is it followed by whitespace?
    tree: Optional[nltk.Tree]  # gold parse, if one is available

    @abstractmethod
    def leaves(self) -> Optional[List[str]]:
        """Return the leaf strings to place in the parse tree.

        ``words`` must hold raw unicode text, whereas leaves follow whatever
        conventions the treebank uses: a '(' word might appear as '-LRB-'
        among the leaves, and leaves may also carry other transformations
        such as transliteration.
        """
        ...

    @abstractmethod
    def pos(self) -> Optional[List[Tuple[str, str]]]:
        """Return the sentence as a list of (leaf, part-of-speech tag) pairs."""
        ...
@dataclasses.dataclass
class CompressedParserOutput:
    """Parser output, encoded as a collection of numpy arrays.

    By default, a parser will return nltk.Tree objects. These have much nicer
    APIs than the CompressedParserOutput class, and the code involved is simpler
    and more readable. As a trade-off, code dealing with nltk.Tree objects is
    slower: the nltk.Tree type itself has some overhead, and algorithms dealing
    with it are implemented in pure Python as opposed to C or even CUDA. The
    CompressedParserOutput type is an alternative that has some optimizations
    for the sole purpose of speeding up inference.

    If trying a new parser type for research purposes, it's safe to ignore this
    class and the return_compressed argument to parse(). If the parser works
    well and is being released, the return_compressed argument can then be added
    with a dedicated fast implementation, or simply by using the from_tree
    method defined below.
    """

    # A parse tree is represented as a set of constituents. In the case of
    # non-binary trees, only the labeled non-terminal nodes are included: there
    # are no dummy nodes inserted for binarization purposes. However, single
    # words are always included in the set of constituents, and they may have a
    # null label if there is no phrasal category above the part-of-speech tag.
    # All constituents are sorted according to pre-order traversal, and each has
    # an associated start (the index of the first word in the constituent), end
    # (1 + the index of the last word in the constituent), and label (an index
    # into an external label_vocab dictionary). These are then stored in three
    # numpy arrays:
    starts: Iterable[int]  # Must be a numpy array
    ends: Iterable[int]  # Must be a numpy array
    labels: Iterable[int]  # Must be a numpy array

    # Part of speech tag ids as output by the parser (may be None if the parser
    # does not do POS tagging). These indices are associated with an external
    # tag_vocab dictionary.
    tags: Optional[Iterable[int]] = None  # Must be None or a numpy array

    def without_predicted_tags(self):
        """Return a copy of this output with ``tags`` set to None."""
        return dataclasses.replace(self, tags=None)

    def with_tags(self, tags):
        """Return a copy of this output with ``tags`` replaced by *tags*."""
        return dataclasses.replace(self, tags=tags)

    @classmethod
    def from_tree(
        cls, tree: nltk.Tree, label_vocab: dict, tag_vocab: Optional[dict] = None
    ) -> "CompressedParserOutput":
        """Encode *tree* as a CompressedParserOutput.

        Chains of single-child non-terminal nodes are collapsed into one
        constituent whose label is the "::"-joined sequence of their labels
        (outermost first). A "TOP" node inside such a chain is left out of the
        joined label. A preterminal's own tag is never part of the label, so a
        word with no phrasal node above it gets the empty label "".

        Args:
            tree: the parse tree to encode.
            label_vocab: maps label strings (possibly "::"-joined, possibly "")
                to integer ids.
            tag_vocab: if given, maps the tags returned by ``tree.pos()`` to
                integer ids, and these are stored in ``tags``. If None,
                ``tags`` is None.

        Returns:
            A CompressedParserOutput whose constituents are in pre-order.

        Raises:
            KeyError: if a label or tag is missing from its vocabulary.
        """
        num_words = len(tree.leaves())
        # After unary chains are collapsed, every non-word constituent has at
        # least two children. There are therefore at most num_words - 1 of
        # them, plus num_words word-level spans, so 2 * num_words is an upper
        # bound. (This assumes every subtree spans at least one word.)
        starts = np.empty(2 * num_words, dtype=int)
        ends = np.empty(2 * num_words, dtype=int)
        labels = np.empty(2 * num_words, dtype=int)

        def helper(tree, start, write_idx):
            # Write the constituent rooted at `tree` into slot `write_idx`,
            # then its descendants into the following slots.
            # Returns (index one past its last word, next free slot).
            label = []
            while len(tree) == 1 and not isinstance(tree[0], str):
                if tree.label() != "TOP":
                    label.append(tree.label())
                tree = tree[0]
            if len(tree) == 1 and isinstance(tree[0], str):
                # Preterminal: a one-word span labeled only by the collapsed
                # chain above it (the POS tag is not appended).
                starts[write_idx] = start
                ends[write_idx] = start + 1
                labels[write_idx] = label_vocab["::".join(label)]
                return start + 1, write_idx + 1
            label.append(tree.label())
            starts[write_idx] = start
            labels[write_idx] = label_vocab["::".join(label)]
            end = start
            new_write_idx = write_idx + 1
            for child in tree:
                end, new_write_idx = helper(child, end, new_write_idx)
            # The end is only known once all children have been written.
            ends[write_idx] = end
            return end, new_write_idx

        _, num_constituents = helper(tree, 0, 0)
        starts = starts[:num_constituents]
        ends = ends[:num_constituents]
        labels = labels[:num_constituents]
        if tag_vocab is None:
            tags = None
        else:
            tags = np.array([tag_vocab[tag] for _, tag in tree.pos()], dtype=int)
        return cls(starts=starts, ends=ends, labels=labels, tags=tags)

    def to_tree(
        self, leaves, label_from_index: dict, tag_from_index: Optional[dict] = None
    ):
        """Decode this output into an nltk.Tree rooted at "TOP".

        Args:
            leaves: one entry per word. Each entry is a (word, tag) tuple or a
                plain word string. When ``self.tags`` is None, any other object
                is used unchanged as that word's preterminal subtree, and a
                bare string gets the tag "UNK".
            label_from_index: maps label ids back to label strings. A "::" in a
                label expands into nested nodes (first sublabel outermost), and
                an empty label adds no node.
            tag_from_index: maps tag ids back to tag strings. It is required
                when ``self.tags`` is not None; predicted tags then replace any
                tag carried in ``leaves``.

        Returns:
            An nltk.Tree whose root label is "TOP".

        Raises:
            ValueError: if ``self.tags`` is set but ``tag_from_index`` is None.
        """
        if self.tags is not None:
            if tag_from_index is None:
                raise ValueError(
                    "tag_from_index is required to convert predicted pos tags"
                )
            predicted_tags = [tag_from_index[i] for i in self.tags]
            assert len(leaves) == len(predicted_tags), (
                "number of leaves does not match number of predicted tags"
            )
            leaves = [
                nltk.Tree(tag, [leaf[0] if isinstance(leaf, tuple) else leaf])
                for tag, leaf in zip(predicted_tags, leaves)
            ]
        else:
            leaves = [
                nltk.Tree(leaf[1], [leaf[0]])
                if isinstance(leaf, tuple)
                else (nltk.Tree("UNK", [leaf]) if isinstance(leaf, str) else leaf)
                for leaf in leaves
            ]

        idx = -1

        def helper():
            # Consume the next constituent (in pre-order) and return the list of
            # subtrees it yields. A labeled constituent yields one wrapped
            # subtree; an unlabeled one splices its children into the parent.
            nonlocal idx
            idx += 1
            i, j, label = (
                self.starts[idx],
                self.ends[idx],
                label_from_index[self.labels[idx]],
            )
            if (i + 1) >= j:
                children = [leaves[i]]
            else:
                children = []
                # In pre-order, the constituents that directly follow and lie
                # inside [i, j) are this constituent's descendants.
                while (
                    (idx + 1) < len(self.starts)
                    and i <= self.starts[idx + 1]
                    and self.ends[idx + 1] <= j
                ):
                    children.extend(helper())
            if label:
                for sublabel in reversed(label.split("::")):
                    children = [nltk.Tree(sublabel, children)]
            return children

        children = helper()
        return nltk.Tree("TOP", children)
class BaseParser(ABC):
    """Parser (abstract interface)."""

    @classmethod
    @abstractmethod
    def from_trained(
        cls,
        model_name: str,
        config: Optional[dict] = None,
        state_dict: Optional[dict] = None,
    ) -> "BaseParser":
        """Load a trained parser.

        Args:
            model_name: identifier of the trained model to load.
            config: optional configuration dict (semantics are defined by the
                subclass -- presumably overrides the stored config; confirm).
            state_dict: optional model weights (presumably used instead of
                loading them by name; confirm per subclass).
        """
        pass

    @abstractmethod
    def parallelize(self, *args, **kwargs):
        """Spread out pre-trained model layers across GPUs."""
        pass

    @abstractmethod
    def parse(
        self,
        examples: Iterable[BaseInputExample],
        return_compressed: bool = False,
        return_scores: bool = False,
        subbatch_max_tokens: Optional[int] = None,
    ) -> Union[Iterable[nltk.Tree], Iterable[Any]]:
        """Parse sentences.

        Args:
            examples: the sentences to parse.
            return_compressed: if True, return CompressedParserOutput objects
                instead of nltk.Tree objects.
            return_scores: whether to also return scores (format defined by
                the subclass).
            subbatch_max_tokens: optional token budget used when splitting the
                input into sub-batches (presumably None means no limit;
                confirm per subclass).

        Returns:
            One result per example; nltk.Tree objects by default.
        """
        pass

    @abstractmethod
    def encode_and_collate_subbatches(
        self, examples: List[BaseInputExample], subbatch_max_tokens: int
    ) -> List[dict]:
        """Split a batch into sub-batches and convert them to tensor features.

        Returns:
            One feature dict per sub-batch.
        """
        pass

    @abstractmethod
    def compute_loss(self, batch: dict):
        """Compute the loss for *batch*.

        NOTE(review): *batch* is presumably one of the feature dicts produced
        by encode_and_collate_subbatches -- confirm against implementations.
        """
        pass
|