# benepar/parse_base.py

from abc import ABC, abstractmethod
import dataclasses
from typing import Any, Iterable, List, Optional, Tuple, Union

import nltk
import numpy as np


class BaseInputExample(ABC):
    """Parser input for a single sentence (abstract interface)."""

    # Subclasses must define the following attributes or properties.
    # `words` is a list of unicode representations for each word in the
    # sentence, and `space_after` is a list of booleans indicating whether
    # there is whitespace after each word. Together, these should form a
    # reversible tokenization of the raw text input. `tree` is an optional
    # gold parse tree.
    words: List[str]
    space_after: List[bool]
    tree: Optional[nltk.Tree]
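
    # For example (illustrative values, not from the original source), the
    # raw text "Hello, world!" could be represented as:
    #
    #   words       = ["Hello", ",", "world", "!"]
    #   space_after = [False, True, False, False]
    #
    # so that appending a space after each word whose space_after entry is
    # True reproduces the original text exactly.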

    @abstractmethod
    def leaves(self) -> Optional[List[str]]:
        """Returns leaves to use in the parse tree.

        While `words` must be raw unicode text, these should be whatever is
        standard for the treebank. For example, '(' in words might correspond
        to '-LRB-' in leaves, and leaves might include other transformations
        such as transliteration.
        """
        pass

    @abstractmethod
    def pos(self) -> Optional[List[Tuple[str, str]]]:
        """Returns a list of (leaf, part-of-speech tag) tuples."""
        pass
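

# Illustrative only (an assumed helper, not part of the original module): a
# minimal concrete implementation of BaseInputExample that uses the raw words
# directly as tree leaves and carries no gold tree or POS tags.
@dataclasses.dataclass
class _ExampleSentence(BaseInputExample):
    words: List[str]
    space_after: List[bool]
    tree: Optional[nltk.Tree] = None

    def leaves(self) -> Optional[List[str]]:
        # No treebank-specific escaping (e.g. '(' -> '-LRB-') in this sketch.
        return self.words

    def pos(self) -> Optional[List[Tuple[str, str]]]:
        return None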


@dataclasses.dataclass
class CompressedParserOutput:
    """Parser output, encoded as a collection of numpy arrays.

    By default, a parser will return nltk.Tree objects. These have much nicer
    APIs than the CompressedParserOutput class, and the code involved is
    simpler and more readable. As a trade-off, code dealing with nltk.Tree
    objects is slower: the nltk.Tree type itself has some overhead, and
    algorithms dealing with it are implemented in pure Python as opposed to C
    or even CUDA. The CompressedParserOutput type is an alternative that has
    some optimizations for the sole purpose of speeding up inference.

    If trying a new parser type for research purposes, it's safe to ignore
    this class and the return_compressed argument to parse(). If the parser
    works well and is being released, the return_compressed argument can then
    be added with a dedicated fast implementation, or simply by using the
    from_tree method defined below.
    """

    # A parse tree is represented as a set of constituents. In the case of
    # non-binary trees, only the labeled non-terminal nodes are included:
    # there are no dummy nodes inserted for binarization purposes. However,
    # single words are always included in the set of constituents, and they
    # may have a null label if there is no phrasal category above the
    # part-of-speech tag. All constituents are sorted according to pre-order
    # traversal, and each has an associated start (the index of the first
    # word in the constituent), end (1 + the index of the last word in the
    # constituent), and label (an index into an external label_vocab
    # dictionary). These are then stored in three numpy arrays:
    starts: Iterable[int]  # Must be a numpy array
    ends: Iterable[int]  # Must be a numpy array
    labels: Iterable[int]  # Must be a numpy array

    # Part-of-speech tag ids as output by the parser (may be None if the
    # parser does not do POS tagging). These indices are associated with an
    # external tag_vocab dictionary.
    tags: Optional[Iterable[int]] = None  # Must be None or a numpy array
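
    # As a worked example (with an assumed toy label_vocab of
    # {"": 0, "NP": 1, "S": 2, "VP": 3}), the tree
    #
    #   (TOP (S (NP (PRP I)) (VP (VBP eat) (NP (NN pie)))))
    #
    # contains five constituents in pre-order: S spanning words 0-3, the
    # single word "I" with label NP, VP spanning words 1-3, the single word
    # "eat" with a null label (no phrasal category above VBP), and the single
    # word "pie" with label NP. It is therefore encoded as:
    #
    #   starts = [0, 0, 1, 1, 2]
    #   ends   = [3, 1, 3, 2, 3]
    #   labels = [2, 1, 3, 0, 1]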

    def without_predicted_tags(self):
        return dataclasses.replace(self, tags=None)

    def with_tags(self, tags):
        return dataclasses.replace(self, tags=tags)

    @classmethod
    def from_tree(
        cls, tree: nltk.Tree, label_vocab: dict, tag_vocab: Optional[dict] = None
    ) -> "CompressedParserOutput":
        num_words = len(tree.leaves())
        # A tree over n words has at most 2n - 1 constituents, so 2n slots
        # are always enough; the arrays are trimmed after the traversal.
        starts = np.empty(2 * num_words, dtype=int)
        ends = np.empty(2 * num_words, dtype=int)
        labels = np.empty(2 * num_words, dtype=int)

        def helper(tree, start, write_idx):
            nonlocal starts, ends, labels
            label = []
            # Collapse unary chains into a single constituent whose label
            # joins the chain with "::"; the root TOP label is skipped.
            while len(tree) == 1 and not isinstance(tree[0], str):
                if tree.label() != "TOP":
                    label.append(tree.label())
                tree = tree[0]
            if len(tree) == 1 and isinstance(tree[0], str):
                # A single word: emit a length-1 constituent whose label is
                # empty unless a phrasal category sits above the POS tag.
                starts[write_idx] = start
                ends[write_idx] = start + 1
                labels[write_idx] = label_vocab["::".join(label)]
                return start + 1, write_idx + 1
            label.append(tree.label())
            starts[write_idx] = start
            labels[write_idx] = label_vocab["::".join(label)]
            end = start
            new_write_idx = write_idx + 1
            for child in tree:
                end, new_write_idx = helper(child, end, new_write_idx)
            ends[write_idx] = end
            return end, new_write_idx

        _, num_constituents = helper(tree, 0, 0)
        starts = starts[:num_constituents]
        ends = ends[:num_constituents]
        labels = labels[:num_constituents]
        if tag_vocab is None:
            tags = None
        else:
            tags = np.array([tag_vocab[tag] for _, tag in tree.pos()], dtype=int)
        return cls(starts=starts, ends=ends, labels=labels, tags=tags)

    def to_tree(
        self, leaves, label_from_index: dict, tag_from_index: Optional[dict] = None
    ):
        if self.tags is not None:
            if tag_from_index is None:
                raise ValueError(
                    "tag_from_index is required to convert predicted pos tags"
                )
            predicted_tags = [tag_from_index[i] for i in self.tags]
            assert len(leaves) == len(predicted_tags)
            leaves = [
                nltk.Tree(tag, [leaf[0] if isinstance(leaf, tuple) else leaf])
                for tag, leaf in zip(predicted_tags, leaves)
            ]
        else:
            # With no predicted tags, fall back to gold tags when leaves are
            # (word, tag) tuples, or to a placeholder "UNK" tag for plain
            # strings.
            leaves = [
                nltk.Tree(leaf[1], [leaf[0]])
                if isinstance(leaf, tuple)
                else (nltk.Tree("UNK", [leaf]) if isinstance(leaf, str) else leaf)
                for leaf in leaves
            ]

        idx = -1

        def helper():
            # Consume the next constituent in pre-order and return the
            # subtree(s) it spans, recursing into any constituents nested
            # inside the span [i, j).
            nonlocal idx
            idx += 1
            i, j, label = (
                self.starts[idx],
                self.ends[idx],
                label_from_index[self.labels[idx]],
            )
            if (i + 1) >= j:
                children = [leaves[i]]
            else:
                children = []
                while (
                    (idx + 1) < len(self.starts)
                    and i <= self.starts[idx + 1]
                    and self.ends[idx + 1] <= j
                ):
                    children.extend(helper())

            if label:
                # Re-expand "::"-joined unary chains, innermost label first.
                for sublabel in reversed(label.split("::")):
                    children = [nltk.Tree(sublabel, children)]
            return children

        children = helper()
        return nltk.Tree("TOP", children)


class BaseParser(ABC):
    """Parser (abstract interface)."""

    @classmethod
    @abstractmethod
    def from_trained(
        cls,
        model_name: str,
        config: Optional[dict] = None,
        state_dict: Optional[dict] = None,
    ) -> "BaseParser":
        """Load a trained parser."""
        pass

    @abstractmethod
    def parallelize(self, *args, **kwargs):
        """Spread out pre-trained model layers across GPUs."""
        pass

    @abstractmethod
    def parse(
        self,
        examples: Iterable[BaseInputExample],
        return_compressed: bool = False,
        return_scores: bool = False,
        subbatch_max_tokens: Optional[int] = None,
    ) -> Union[Iterable[nltk.Tree], Iterable[Any]]:
        """Parse sentences."""
        pass

    @abstractmethod
    def encode_and_collate_subbatches(
        self, examples: List[BaseInputExample], subbatch_max_tokens: int
    ) -> List[dict]:
        """Split a batch into sub-batches and convert them to tensor features."""
        pass

    @abstractmethod
    def compute_loss(self, batch: dict):
        """Compute the training loss for a single collated sub-batch."""
        pass
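

if __name__ == "__main__":
    # Minimal round-trip sketch for CompressedParserOutput. The tree and the
    # toy label_vocab below are illustrative assumptions, not fixtures from
    # the original project.
    tree = nltk.Tree.fromstring(
        "(TOP (S (NP (PRP I)) (VP (VBP eat) (NP (NN pie)))))"
    )
    label_vocab = {"": 0, "NP": 1, "S": 2, "VP": 3}

    compressed = CompressedParserOutput.from_tree(tree, label_vocab)
    # Decoding uses the inverse vocabulary; passing (word, tag) tuples as
    # leaves supplies gold POS tags, since no tags were predicted here.
    label_from_index = {i: label for label, i in label_vocab.items()}
    restored = compressed.to_tree(tree.pos(), label_from_index)
    assert restored == tree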