|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass
from dataclasses import field
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
|
|
|
Symbol = TypeVar('Symbol')


@dataclass(repr=False)
class SymbolTable(Generic[Symbol]):
    '''SymbolTable that maps symbol IDs, found on the FSA arcs to
    actual objects. These objects can be arbitrary Python objects
    that can serve as keys in a dictionary (i.e. they need to be
    hashable and immutable).

    The SymbolTable can only be read from/written to disk if the
    symbols are strings.
    '''
    _id2sym: Dict[int, Symbol] = field(default_factory=dict)
    '''Map an integer to a symbol.
    '''

    _sym2id: Dict[Symbol, int] = field(default_factory=dict)
    '''Map a symbol to an integer.
    '''

    _next_available_id: int = 1
    '''A helper internal field that helps adding new symbols
    to the table efficiently.
    '''

    eps: Symbol = '<eps>'
    '''Null symbol, always mapped to index 0.
    '''

    def __post_init__(self):
        '''Validate the two maps against each other and register ``eps``.

        Raises:
          AssertionError: If the two maps disagree, or if id 0 is already
            taken by a symbol other than ``eps``.
        '''
        assert all(self._sym2id[sym] == idx
                   for idx, sym in self._id2sym.items())
        assert all(self._id2sym[idx] == sym
                   for sym, idx in self._sym2id.items())
        assert 0 not in self._id2sym or self._id2sym[0] == self.eps

        # Whatever value was passed for `_next_available_id` is replaced
        # by one past the largest id currently in use.
        self._next_available_id = max(self._id2sym, default=0) + 1
        self._id2sym.setdefault(0, self.eps)
        self._sym2id.setdefault(self.eps, 0)

    @staticmethod
    def from_str(s: str) -> 'SymbolTable':
        '''Build a symbol table from a string.

        The string consists of lines. Every line has two fields separated
        by space(s), tab(s) or both. The first field is the symbol and the
        second the integer id of the symbol. Empty lines are skipped.

        Args:
          s:
            The input string with the format described above.
        Returns:
          An instance of :class:`SymbolTable`.
        Raises:
          AssertionError: If a line does not have exactly two fields, or
            if a symbol or an id appears more than once.
          ValueError: If the second field of a line is not an integer.
        '''
        id2sym: Dict[int, str] = dict()
        sym2id: Dict[str, int] = dict()

        for line in s.split('\n'):
            fields = line.split()
            if not fields:
                continue
            assert len(fields) == 2, \
                f'Expect a line with 2 fields. Given: {len(fields)}'
            sym, idx = fields[0], int(fields[1])
            assert sym not in sym2id, f'Duplicated symbol {sym}'
            assert idx not in id2sym, f'Duplicated id {idx}'
            id2sym[idx] = sym
            sym2id[sym] = idx

        # Whatever symbol the input binds to id 0 becomes the epsilon
        # symbol; otherwise fall back to the default name.
        eps = id2sym.get(0, '<eps>')

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)

    @staticmethod
    def from_file(filename: str) -> 'SymbolTable':
        '''Build a symbol table from file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        The file is read as UTF-8.

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.

        Returns:
          An instance of :class:`SymbolTable`.
        '''
        with open(filename, 'r', encoding='utf-8') as f:
            return SymbolTable.from_str(f.read().strip())

    def to_str(self) -> str:
        '''
        Returns:
          Return a string representation of this object. You can pass
          it to the method ``from_str`` to recreate an identical object.
        '''
        # One "<symbol> <id>" line per entry, ordered by id.
        return ''.join(f'{symbol} {idx}\n'
                       for idx, symbol in sorted(self._id2sym.items()))

    def to_file(self, filename: str):
        '''Serialize the SymbolTable to a file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        The file is written as UTF-8, the encoding ``from_file`` reads.

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.
        '''
        # Explicit encoding: the platform default may not be UTF-8, which
        # would make the file unreadable by `from_file`.
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(self.to_str())

    def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
        '''Add a new symbol to the SymbolTable.

        If the symbol is already present, its existing id is returned and
        ``index`` is ignored.

        Args:
          symbol:
            The symbol to be added.
          index:
            Optional int id to which the symbol should be assigned.
            If it is not available, a ValueError will be raised.

        Returns:
          The int id to which the symbol has been assigned.
        '''
        # Already in the table? Return its id.
        if symbol in self._sym2id:
            return self._sym2id[symbol]

        # No specific id requested - use the next available one.
        if index is None:
            index = self._next_available_id

        # A specific id was requested but it is taken.
        if index in self._id2sym:
            raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
                             f"already occupied by {self._id2sym[index]}")
        self._sym2id[symbol] = index
        self._id2sym[index] = symbol

        # Keep the next available id above every id handed out so far.
        if self._next_available_id <= index:
            self._next_available_id = index + 1

        return index

    def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
        '''Get a symbol for an id or get an id for a symbol

        Args:
          k:
            If it is an id, it tries to find the symbol corresponding
            to the id; if it is a symbol, it tries to find the id
            corresponding to the symbol. Any ``int`` is treated as an id.

        Returns:
          An id or a symbol depending on the given `k`.

        Raises:
          KeyError: If `k` is not in the table.
        '''
        if isinstance(k, int):
            return self._id2sym[k]
        return self._sym2id[k]

    def merge(self, other: 'SymbolTable') -> 'SymbolTable':
        '''Create a union of two SymbolTables.
        Raises an AssertionError if the same IDs are occupied by
        different symbols.

        Args:
          other:
            A symbol table to merge with ``self``.

        Returns:
          A new symbol table.
        '''
        self._check_compatible(other)
        return SymbolTable(
            _id2sym={**self._id2sym, **other._id2sym},
            _sym2id={**self._sym2id, **other._sym2id},
            eps=self.eps
        )

    def _check_compatible(self, other: 'SymbolTable') -> None:
        '''Assert that ``other`` can be merged with ``self``.

        Raises:
          AssertionError: If the epsilon symbols differ, if a shared id
            maps to different symbols, or if a shared symbol maps to
            different ids.
        '''
        assert self.eps == other.eps, \
            f'Mismatched epsilon symbol: {self.eps} != {other.eps}'

        # Read the dictionaries directly instead of going through
        # ``self[...]``: ``get`` treats every int as an id and would
        # misroute symbols that happen to be integers.
        for idx in set(self._id2sym).intersection(other._id2sym):
            mine, theirs = self._id2sym[idx], other._id2sym[idx]
            assert mine == theirs, \
                f'ID conflict for id: {idx}, ' \
                f'self[idx] = "{mine}", ' \
                f'other[idx] = "{theirs}"'

        for sym in set(self._sym2id).intersection(other._sym2id):
            mine, theirs = self._sym2id[sym], other._sym2id[sym]
            assert mine == theirs, \
                f'ID conflict for symbol: {sym}, ' \
                f'self[sym] = "{mine}", ' \
                f'other[sym] = "{theirs}"'

    def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
        '''Same as :meth:`get`.'''
        return self.get(item)

    def __contains__(self, item: Union[int, Symbol]) -> bool:
        '''Return True if the id (for an ``int``) or the symbol is known.'''
        if isinstance(item, int):
            return item in self._id2sym
        return item in self._sym2id

    def __len__(self) -> int:
        '''Return the number of ids in the table (including id 0).'''
        return len(self._id2sym)

    def __eq__(self, other: object) -> bool:
        '''Two tables are equal when they hold the same id/symbol pairs.

        Comparing the maps directly means tables of equal size with
        different symbols compare unequal instead of raising KeyError.
        '''
        if not isinstance(other, SymbolTable):
            return NotImplemented
        return (self._id2sym == other._id2sym
                and self._sym2id == other._sym2id)

    @property
    def ids(self) -> List[int]:
        '''Returns a list of integer IDs corresponding to the symbols.
        '''
        return sorted(self._id2sym)

    @property
    def symbols(self) -> List[Symbol]:
        '''Returns a list of symbols (e.g., strings) corresponding to
        the integer IDs.
        '''
        return sorted(self._sym2id)
|
|
|
|
|
class TextToken:
    '''Map text tokens to integer ids, optionally framed by BOS/EOS.

    Ids are assigned by position in this order: the pad symbol (id 0),
    then the BOS symbol (if ``add_bos``), then the EOS symbol (if
    ``add_eos``), then ``text_tokens`` in sorted order.
    '''

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        '''
        Args:
          text_tokens:
            The tokens of the vocabulary. They are sorted before ids are
            assigned. Duplicates are not removed; in ``token2idx`` a
            repeated token keeps the id of its last occurrence.
          add_eos:
            Whether ``eos_symbol`` gets an id and is appended to every
            sequence.
          add_bos:
            Whether ``bos_symbol`` gets an id and is prepended to every
            sequence.
          pad_symbol:
            The padding symbol; it is placed first in the vocabulary.
          bos_symbol:
            The begin-of-sequence symbol.
          eos_symbol:
            The end-of-sequence symbol.
        '''
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Special symbols come first so their ids do not depend on the
        # contents of `text_tokens`.
        vocabulary = [pad_symbol]
        if add_bos:
            vocabulary.append(bos_symbol)
        if add_eos:
            vocabulary.append(eos_symbol)
        vocabulary.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(vocabulary)}
        self.idx2token = vocabulary

    def get_token_id_seq(self, text: Iterable[str]) -> Tuple[List[int], int]:
        '''Convert ``text`` to a list of token ids.

        Args:
          text:
            An iterable of tokens, e.g. a string (iterated character by
            character) or a list of token strings.

        Returns:
          A tuple ``(token_ids, token_lens)``: the ids of the tokens,
          including BOS/EOS when enabled, and the length of that list.

        Raises:
          KeyError: If a token is not in the vocabulary.
        '''
        seq: List[str] = []
        if self.add_bos:
            seq.append(self.bos_symbol)
        seq.extend(text)
        if self.add_eos:
            seq.append(self.eos_symbol)

        token_ids = [self.token2idx[token] for token in seq]
        # Measure the sequence actually built instead of adding the
        # boolean flags to the text length.
        return token_ids, len(token_ids)
|
|
|
|