# vibert-capu / vocabulary.py
# Commit 217bb4e by dragonSwing: "Initialize commit" (13.2 kB)
import codecs
from collections import defaultdict
import logging
import os
import re
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union, TYPE_CHECKING
from filelock import FileLock
# Module logger; records are emitted under this module's import name.
logger = logging.getLogger(__name__)

# Namespace patterns that get no padding / OOV entries by default.  A leading "*" means
# "any namespace ending with the rest of the pattern" (see `namespace_match`).
DEFAULT_NON_PADDED_NAMESPACES = (
    "*tags",
    "*labels",
)

# Reserved token strings used when the caller does not supply its own.
DEFAULT_PADDING_TOKEN: str = "@@PADDING@@"
DEFAULT_OOV_TOKEN: str = "@@UNKNOWN@@"

# File inside a serialized vocabulary directory listing the non-padded namespace patterns,
# one per line.
NAMESPACE_PADDING_FILE: str = "non_padded_namespaces.txt"

# Splits vocabulary-file contents into lines, accepting both "\n" and "\r\n" endings.
_NEW_LINE_REGEX = re.compile(r"\n|\r\n")
def namespace_match(pattern: str, namespace: str) -> bool:
    """
    Matches a namespace pattern against a namespace string.

    A pattern that starts with `*` is a suffix match, so `*tags` matches `passage_tags`
    and `question_tags`. Any other pattern must equal the namespace exactly, so `tokens`
    matches `tokens` but not `stemmed_tokens`.

    # Parameters

    pattern : `str`
        The pattern to test. An empty pattern matches only an empty namespace.
    namespace : `str`
        The namespace name to test against `pattern`.

    # Returns

    `bool`
        `True` if `namespace` matches `pattern`, `False` otherwise.
    """
    # `startswith` (instead of `pattern[0]`) keeps an empty pattern -- e.g. a blank line
    # read from a non-padded-namespaces file -- from raising IndexError.
    if pattern.startswith("*"):
        # An exact match is also a suffix match, so no separate equality check is needed.
        return namespace.endswith(pattern[1:])
    return pattern == namespace
class _NamespaceDependentDefaultDict(defaultdict):
    """
    This is a [defaultdict]
    (https://docs.python.org/3/library/collections.html#collections.defaultdict) where the
    default value is dependent on the key that is passed.

    We use "namespaces" in the :class:`Vocabulary` object to keep track of several different
    mappings from strings to integers, so that we have a consistent API for mapping words, tags,
    labels, characters, or whatever else you want, into integers. The issue is that some of those
    namespaces (words and characters) should have integers reserved for padding and
    out-of-vocabulary tokens, while others (labels and tags) shouldn't. This class allows you to
    specify filters on the namespace (the key used in the `defaultdict`), and use different
    default values depending on whether the namespace passes the filter.

    To do filtering, we take a set of `non_padded_namespaces`. This is a set of strings
    that are either matched exactly against the keys, or treated as suffixes, if the
    string starts with `*`. In other words, if `*tags` is in `non_padded_namespaces` then
    `passage_tags`, `question_tags`, etc. (anything that ends with `tags`) will have the
    `non_padded` default value.

    # Parameters

    non_padded_namespaces : `Iterable[str]`
        A set / list / tuple of strings describing which namespaces are not padded. If a namespace
        (key) is missing from this dictionary, we will use :func:`namespace_match` to see whether
        the namespace should be padded. If the given namespace matches any of the strings in this
        list, we will use `non_padded_function` to initialize the value for that namespace, and
        we will use `padded_function` otherwise.
    padded_function : `Callable[[], Any]`
        A zero-argument function to call to initialize a value for a namespace that `should` be
        padded.
    non_padded_function : `Callable[[], Any]`
        A zero-argument function to call to initialize a value for a namespace that should `not` be
        padded.
    """

    def __init__(
        self,
        non_padded_namespaces: Iterable[str],
        padded_function: Callable[[], Any],
        non_padded_function: Callable[[], Any],
    ) -> None:
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._padded_function = padded_function
        self._non_padded_function = non_padded_function
        # No `default_factory` is passed: the default depends on the key, which a plain
        # factory never sees, so missing keys are handled in `__missing__` instead.
        super().__init__()

    def __missing__(self, key: str) -> Any:
        """
        Build, store and return the default value for a namespace that is not present yet.

        Called by `dict.__getitem__` (i.e. `self[key]`) when `key` is absent; `in` and
        `get()` do not trigger it.

        # Parameters

        key : `str`
            The namespace being looked up.

        # Returns

        `Any`
            The freshly created value, now stored under `key`.
        """
        if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces):
            value = self._non_padded_function()
        else:
            value = self._padded_function()
        # Store it so later lookups (and mutations through the returned object) persist.
        self[key] = value
        return value

    def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]) -> None:
        """
        Add more non-padded namespace patterns.

        Only affects namespaces created after this call; existing entries are unchanged.
        """
        # add non_padded_namespaces which weren't already present
        self._non_padded_namespaces.update(non_padded_namespaces)
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
    """
    Namespace -> (token -> index) mapping.

    The padded factory returns `{padding_token: 0, oov_token: 1}`; the non-padded factory
    returns an empty dict. Both factories are handed to the parent class, which decides
    when to call them.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        def padded_default() -> Dict[str, int]:
            # Padding token maps to index 0, OOV token to index 1.
            return {padding_token: 0, oov_token: 1}

        def non_padded_default() -> Dict[str, int]:
            return {}

        super().__init__(
            non_padded_namespaces,
            padded_function=padded_default,
            non_padded_function=non_padded_default,
        )
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
    """
    Namespace -> (index -> token) mapping, the inverse of `_TokenToIndexDefaultDict`.

    The padded factory returns `{0: padding_token, 1: oov_token}`; the non-padded factory
    returns an empty dict. Both factories are handed to the parent class, which decides
    when to call them.
    """

    def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
        def padded_default() -> Dict[int, str]:
            # Index 0 holds the padding token, index 1 the OOV token.
            return {0: padding_token, 1: oov_token}

        def non_padded_default() -> Dict[int, str]:
            return {}

        super().__init__(
            non_padded_namespaces,
            padded_function=padded_default,
            non_padded_function=non_padded_default,
        )
class Vocabulary:
    """
    Maps strings to integer ids, keeping a separate mapping per *namespace*.

    Each namespace (e.g. `"tokens"` or `"labels"`) has a token -> index dictionary and an
    index -> token dictionary. Padded namespaces reserve index 0 for the padding token;
    namespaces matching one of `non_padded_namespaces` (see `namespace_match`) do not.
    Unknown tokens are looked up as the OOV token.
    """

    def __init__(
        self,
        counter: Optional[Dict[str, Dict[str, int]]] = None,
        min_count: Optional[Dict[str, int]] = None,
        max_vocab_size: Optional[Union[int, Dict[str, int]]] = None,
        non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
        pretrained_files: Optional[Dict[str, str]] = None,
        only_include_pretrained_words: bool = False,
        tokens_to_add: Optional[Dict[str, List[str]]] = None,
        min_pretrained_embeddings: Optional[Dict[str, int]] = None,
        padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
        oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
    ) -> None:
        """
        Create an empty vocabulary.

        # Parameters

        non_padded_namespaces : `Iterable[str]`, optional (default=`DEFAULT_NON_PADDED_NAMESPACES`)
            Namespace patterns (exact names, or `*suffix`) that get no padding / OOV entries.
        padding_token : `Optional[str]`, optional (default=`DEFAULT_PADDING_TOKEN`)
            The padding token; `None` falls back to `DEFAULT_PADDING_TOKEN`.
        oov_token : `Optional[str]`, optional (default=`DEFAULT_OOV_TOKEN`)
            The out-of-vocabulary token; `None` falls back to `DEFAULT_OOV_TOKEN`.
        counter, min_count, max_vocab_size, pretrained_files, only_include_pretrained_words, tokens_to_add, min_pretrained_embeddings
            Accepted for signature compatibility but ignored by this implementation: no
            tokens are added here. A warning is logged if any of them is supplied. Populate
            the vocabulary with `from_files`, `set_from_file` or `add_token_to_namespace`.
        """
        # These arguments used to be dropped silently, which hides a caller's mistake
        # (e.g. expecting `tokens_to_add` to be honoured). Surface it instead.
        ignored_arguments = {
            "counter": counter,
            "min_count": min_count,
            "max_vocab_size": max_vocab_size,
            "pretrained_files": pretrained_files,
            "only_include_pretrained_words": only_include_pretrained_words,
            "tokens_to_add": tokens_to_add,
            "min_pretrained_embeddings": min_pretrained_embeddings,
        }
        supplied = [name for name, value in ignored_arguments.items() if value]
        if supplied:
            logger.warning(
                "Vocabulary ignores these constructor arguments: %s", ", ".join(supplied)
            )

        self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
        self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
        self._non_padded_namespaces = set(non_padded_namespaces)
        self._token_to_index = _TokenToIndexDefaultDict(
            self._non_padded_namespaces, self._padding_token, self._oov_token
        )
        self._index_to_token = _IndexToTokenDefaultDict(
            self._non_padded_namespaces, self._padding_token, self._oov_token
        )

    @classmethod
    def from_files(
        cls,
        directory: Union[str, os.PathLike],
        padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
        oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
    ) -> "Vocabulary":
        """
        Loads a `Vocabulary` from a directory of vocabulary files.

        The directory must contain `NAMESPACE_PADDING_FILE`, with one non-padded namespace
        pattern per line (blank lines are ignored). It also holds one file per namespace,
        named `<namespace>.txt`, in the format expected by `set_from_file`. Hidden files
        (names starting with `.`) and sub-directories are skipped.

        # Parameters

        directory : `Union[str, os.PathLike]`
            The directory containing the serialized vocabulary. Archive files are not
            supported.
        padding_token : `Optional[str]`, optional (default=`DEFAULT_PADDING_TOKEN`)
            The padding token; `None` falls back to `DEFAULT_PADDING_TOKEN`.
        oov_token : `Optional[str]`, optional (default=`DEFAULT_OOV_TOKEN`)
            The OOV token string as written in the files; `None` falls back to
            `DEFAULT_OOV_TOKEN`.

        # Raises

        ValueError
            If `directory` is not an existing directory.
        """
        logger.info("Loading token dictionary from %s.", directory)
        padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
        oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
        if not os.path.isdir(directory):
            raise ValueError(f"{directory} not exist")

        # We use a lock file to avoid race conditions where multiple processes
        # might be reading/writing from/to the same vocab files at once.
        with FileLock(os.path.join(directory, ".lock")):
            padding_path = os.path.join(directory, NAMESPACE_PADDING_FILE)
            with codecs.open(padding_path, "r", "utf-8") as namespace_file:
                # Skip blank lines: an empty pattern is meaningless and would otherwise be
                # handed to `namespace_match`.
                non_padded_namespaces = [line.strip() for line in namespace_file if line.strip()]

            vocab = cls(
                non_padded_namespaces=non_padded_namespaces,
                padding_token=padding_token,
                oov_token=oov_token,
            )

            # Check every file in the directory.
            for namespace_filename in os.listdir(directory):
                # The padding file and hidden files (including ".lock") are not vocabularies.
                if namespace_filename == NAMESPACE_PADDING_FILE or namespace_filename.startswith("."):
                    continue
                filename = os.path.join(directory, namespace_filename)
                if not os.path.isfile(filename):
                    continue
                # Strip only a trailing ".txt"; `str.replace` would also remove ".txt"
                # occurring in the middle of the file name.
                if namespace_filename.endswith(".txt"):
                    namespace = namespace_filename[: -len(".txt")]
                else:
                    namespace = namespace_filename
                is_padded = not any(
                    namespace_match(pattern, namespace) for pattern in non_padded_namespaces
                )
                vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
            return vocab

    @classmethod
    def empty(cls) -> "Vocabulary":
        """
        Return a bare vocabulary instantiated with `cls()` (so `Vocabulary()` unless you have
        made a subclass).

        Useful when no vocabulary needs to be computed from data, e.g. because it will be
        filled from files afterwards or a pre-trained transformer brings its own vocabulary.
        """
        return cls()

    def set_from_file(
        self,
        filename: str,
        is_padded: bool = True,
        oov_token: str = DEFAULT_OOV_TOKEN,
        namespace: str = "tokens",
    ):
        """
        Replace the contents of `namespace` with the vocabulary stored in `filename`.

        Use this when you already have a vocabulary file for a trained model and want to use
        it instead of building a vocabulary from a dataset. By default (`is_padded=True`)
        the namespace gets a padding token and the file must contain the OOV token.

        # Parameters

        filename : `str`
            The file containing the vocabulary to load. It should be formatted as one token per
            line, with nothing else in the line. The index we assign to the token is the line
            number in the file (1-indexed if `is_padded`, 0-indexed otherwise). The literal
            `@@NEWLINE@@` in a line is decoded to a newline character. Note that this
            file should contain the OOV token string!
        is_padded : `bool`, optional (default=`True`)
            Is this vocabulary padded? For token / word / character vocabularies, this should be
            `True`; while for tag or label vocabularies, this should typically be `False`. If
            `True`, we add a padding token with index 0, and we enforce that the `oov_token` is
            present in the file.
        oov_token : `str`, optional (default=`DEFAULT_OOV_TOKEN`)
            What token does this vocabulary use to represent out-of-vocabulary characters? This
            must show up as a line in the vocabulary file. When we find it, we replace
            `oov_token` with `self._oov_token`, because we only use one OOV token across
            namespaces.
        namespace : `str`, optional (default=`"tokens"`)
            What namespace should we overwrite with this vocab file?

        # Raises

        AssertionError
            If `is_padded` is `True` and the file does not contain `oov_token`.
        """
        if is_padded:
            self._token_to_index[namespace] = {self._padding_token: 0}
            self._index_to_token[namespace] = {0: self._padding_token}
        else:
            self._token_to_index[namespace] = {}
            self._index_to_token[namespace] = {}
        with codecs.open(filename, "r", "utf-8") as input_file:
            # codecs.open performs no newline translation, hence the explicit regex split.
            lines = _NEW_LINE_REGEX.split(input_file.read())
            # Be flexible about having final newline or not
            if lines and lines[-1] == "":
                lines = lines[:-1]
            for i, line in enumerate(lines):
                index = i + 1 if is_padded else i
                token = line.replace("@@NEWLINE@@", "\n")
                if token == oov_token:
                    token = self._oov_token
                self._token_to_index[namespace][token] = index
                self._index_to_token[namespace][index] = token
        if is_padded and self._oov_token not in self._token_to_index[namespace]:
            # Raised explicitly instead of via `assert` so the check survives `python -O`;
            # AssertionError is kept so existing callers catching it still work.
            raise AssertionError("OOV token not found!")

    def add_token_to_namespace(self, token: str, namespace: str = "tokens") -> int:
        """
        Adds `token` to the index, if it is not already present. Either way, we return the index of
        the token.

        # Raises

        ValueError
            If `token` is not a `str`.
        """
        if not isinstance(token, str):
            raise ValueError(
                "Vocabulary tokens must be strings, or saving and loading will break."
                f" Got {token!r} (with type {type(token)})"
            )
        token_to_index = self._token_to_index[namespace]
        if token in token_to_index:
            return token_to_index[token]
        # New tokens get the next index after the current entries.
        index = len(token_to_index)
        token_to_index[token] = index
        self._index_to_token[namespace][index] = token
        return index

    def add_tokens_to_namespace(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
        """
        Adds `tokens` to the index, if they are not already present. Either way, we return the
        indices of the tokens in the order that they were given.
        """
        return [self.add_token_to_namespace(token, namespace) for token in tokens]

    def get_token_index(self, token: str, namespace: str = "tokens") -> int:
        """
        Return the index of `token` in `namespace`, falling back to the OOV token's index.

        # Raises

        KeyError
            If neither `token` nor the OOV token can be found in `namespace`.
        """
        try:
            return self._token_to_index[namespace][token]
        except KeyError:
            try:
                return self._token_to_index[namespace][self._oov_token]
            except KeyError:
                logger.error("Namespace: %s", namespace)
                logger.error("Token: %s", token)
                raise KeyError(
                    f"'{token}' not found in vocab namespace '{namespace}', and namespace "
                    f"does not contain the default OOV token ('{self._oov_token}')"
                )

    def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
        """
        Return the token stored at `index` in `namespace`.

        # Raises

        KeyError
            If `index` is not assigned in `namespace`.
        """
        return self._index_to_token[namespace][index]

    def get_vocab_size(self, namespace: str = "tokens") -> int:
        """Return the number of entries in `namespace`'s token -> index mapping."""
        return len(self._token_to_index[namespace])

    def get_namespaces(self) -> Set[str]:
        """Return the names of the namespaces present in the index -> token mapping."""
        return set(self._index_to_token.keys())