dragonSwing
commited on
Commit
•
217bb4e
1
Parent(s):
03456cd
Initialize commit
Browse files- README.md +80 -1
- config.json +14 -0
- configuration_seq2labels.py +62 -0
- gec_model.py +443 -0
- modeling_seq2labels.py +123 -0
- pytorch_model.bin +3 -0
- utils.py +233 -0
- vocabulary.py +277 -0
- vocabulary/.lock +0 -0
- vocabulary/d_tags.txt +4 -0
- vocabulary/labels.txt +15 -0
- vocabulary/non_padded_namespaces.txt +2 -0
README.md
CHANGED
@@ -1,3 +1,82 @@
|
|
1 |
---
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
language:
|
3 |
+
- vi
|
4 |
+
tags:
|
5 |
+
- capitalization
|
6 |
+
- punctuation
|
7 |
+
- token-classification
|
8 |
+
- sequence-tagger-model
|
9 |
+
license: mit
|
10 |
+
datasets:
|
11 |
+
- oscar-corpus/OSCAR-2109
|
12 |
+
metrics:
|
13 |
+
- accuracy
|
14 |
+
- precision
|
15 |
+
- recall
|
16 |
+
- f1
|
17 |
---
|
18 |
+
# ✨ vibert-capitalization-punctuation
|
19 |
+
This a [viBERT](https://huggingface.co/FPTAI/vibert-base-cased) model finetuned for punctuation restoration on the [OSCAR-2109](https://huggingface.co/datasets/oscar-corpus/OSCAR-2109) dataset.
|
20 |
+
The model predicts the punctuation and upper-casing of plain, lower-cased text. An example use case can be ASR output. Or other cases when text has lost punctuation.
|
21 |
+
This model is intended for direct use as a punctuation restoration model for the general Vietnamese language. Alternatively, you can use this for further fine-tuning on domain-specific texts for punctuation restoration tasks.
|
22 |
+
Model restores the following punctuations -- **[. , : ? ]**
|
23 |
+
The model also restores the complex upper-casing of words like *YouTube*, *MobiFone*.
|
24 |
+
|
25 |
+
-----------------------------------------------
|
26 |
+
## 🚋 Usage
|
27 |
+
**Below is a quick way to get up and running with the model.**
|
28 |
+
1. Download files from hub
|
29 |
+
```python
|
30 |
+
import os
|
31 |
+
import shutil
|
32 |
+
import sys
|
33 |
+
from huggingface_hub import snapshot_download
|
34 |
+
|
35 |
+
cache_dir = "./capu"
|
36 |
+
def download_files(repo_id, cache_dir=None, ignore_regex=None):
|
37 |
+
download_dir = snapshot_download(repo_id=repo_id, cache_dir=cache_dir, ignore_regex=ignore_regex)
|
38 |
+
if cache_dir is None or download_dir == cache_dir:
|
39 |
+
return download_dir
|
40 |
+
|
41 |
+
file_names = os.listdir(download_dir)
|
42 |
+
for file_name in file_names:
|
43 |
+
shutil.move(os.path.join(download_dir, file_name), cache_dir)
|
44 |
+
os.rmdir(download_dir)
|
45 |
+
return cache_dir
|
46 |
+
|
47 |
+
download_files(repo_id="dragonSwing/vibert-capu", cache_dir=cache_dir, ignore_regex=["*.json", "*.bin"])
|
48 |
+
sys.path.append(cache_dir)
|
49 |
+
```
|
50 |
+
2. Sample python code
|
51 |
+
```python
|
52 |
+
import os
|
53 |
+
from gec_model import GecBERTModel
|
54 |
+
model = GecBERTModel(
|
55 |
+
vocab_path=os.path.join(cache_dir, "vocabulary"),
|
56 |
+
model_paths="dragonSwing/vibert-capu",
|
57 |
+
split_chunk=True
|
58 |
+
)
|
59 |
+
model("theo đó thủ tướng dự kiến tiếp bộ trưởng nông nghiệp mỹ tom wilsack bộ trưởng thương mại mỹ gina raimondo bộ trưởng tài chính janet yellen gặp gỡ thượng nghị sĩ patrick leahy và một số nghị sĩ mỹ khác")
|
60 |
+
# Theo đó, Thủ tướng dự kiến tiếp Bộ trưởng Nông nghiệp Mỹ Tom Wilsack, Bộ trưởng Thương mại Mỹ Gina Raimondo, Bộ trưởng Tài chính Janet Yellen, gặp gỡ thượng nghị sĩ Patrick Leahy và một số nghị sĩ Mỹ khác.
|
61 |
+
```
|
62 |
+
**This model can work on arbitrarily large text in Vietnamese language.**
|
63 |
+
|
64 |
+
-----------------------------------------------
|
65 |
+
## 📡 Training data
|
66 |
+
Here is the number of product reviews we used for fine-tuning the model:
|
67 |
+
| Language | Number of text samples|
|
68 |
+
| -------- | ----------------- |
|
69 |
+
| Vietnamese | 5,600,000 |
|
70 |
+
-----------------------------------------------
|
71 |
+
## 🎯 Accuracy
|
72 |
+
Below is a breakdown of the performance of the model by each label on 120,000 held-out text samples:
|
73 |
+
| label | precision | recall | f1-score | support|
|
74 |
+
| --------- | -------------|-------- | ----------|--------|
|
75 |
+
| **Upper** | 0.88 | 0.89 | 0.89 | 56497
|
76 |
+
| **Complex-Upper** | 0.92 | 0.83 | 0.88 | 480
|
77 |
+
| **.** | 0.81 | 0.82 | 0.82 | 18139
|
78 |
+
| **,** | 0.73 | 0.70 | 0.71 | 22961
|
79 |
+
| **:** | 0.74 | 0.56 | 0.64 | 1432
|
80 |
+
| **?** | 0.80 | 0.76 | 0.78 | 1730
|
81 |
+
| **none** | 0.99 | 0.99 | 0.99 |475611
|
82 |
+
-----------------------------------------------
|
config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"initializer_range": 0.02,
|
3 |
+
"label_smoothing": 0.0,
|
4 |
+
"load_pretrained": false,
|
5 |
+
"model_type": "bert",
|
6 |
+
"num_detect_classes": 4,
|
7 |
+
"pad_token_id": 0,
|
8 |
+
"predictor_dropout": 0.0,
|
9 |
+
"pretrained_name_or_path": "FPTAI/vibert-base-cased",
|
10 |
+
"special_tokens_fix": true,
|
11 |
+
"transformers_version": "4.18.0",
|
12 |
+
"use_cache": true,
|
13 |
+
"vocab_size": 15
|
14 |
+
}
|
configuration_seq2labels.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
|
3 |
+
|
4 |
+
class Seq2LabelsConfig(PretrainedConfig):
|
5 |
+
r"""
|
6 |
+
This is the configuration class to store the configuration of a [`Seq2LabelsModel`]. It is used to
|
7 |
+
instantiate a Seq2Labels model according to the specified arguments, defining the model architecture. Instantiating a
|
8 |
+
configuration with the defaults will yield a similar configuration to that of the Seq2Labels architecture.
|
9 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
10 |
+
documentation from [`PretrainedConfig`] for more information.
|
11 |
+
Args:
|
12 |
+
vocab_size (`int`, *optional*, defaults to 30522):
|
13 |
+
Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
|
14 |
+
`inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
|
15 |
+
pretrained_name_or_path (`str`, *optional*, defaults to `bert-base-cased`):
|
16 |
+
Pretrained BERT-like model path
|
17 |
+
load_pretrained (`bool`, *optional*, defaults to `False`):
|
18 |
+
Whether to load pretrained model from `pretrained_name_or_path`
|
19 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
20 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
21 |
+
relevant if `config.is_decoder=True`.
|
22 |
+
predictor_dropout (`float`, *optional*):
|
23 |
+
The dropout ratio for the classification head.
|
24 |
+
special_tokens_fix (`bool`, *optional*, defaults to `False`):
|
25 |
+
Whether to add additional tokens to the BERT's embedding layer.
|
26 |
+
Examples:
|
27 |
+
```python
|
28 |
+
>>> from transformers import BertModel, BertConfig
|
29 |
+
>>> # Initializing a Seq2Labels style configuration
|
30 |
+
>>> configuration = Seq2LabelsConfig()
|
31 |
+
>>> # Initializing a model from the bert-base-uncased style configuration
|
32 |
+
>>> model = Seq2LabelsModel(configuration)
|
33 |
+
>>> # Accessing the model configuration
|
34 |
+
>>> configuration = model.config
|
35 |
+
```"""
|
36 |
+
model_type = "bert"
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
pretrained_name_or_path="bert-base-cased",
|
41 |
+
vocab_size=15,
|
42 |
+
num_detect_classes=4,
|
43 |
+
load_pretrained=False,
|
44 |
+
initializer_range=0.02,
|
45 |
+
pad_token_id=0,
|
46 |
+
use_cache=True,
|
47 |
+
predictor_dropout=0.0,
|
48 |
+
special_tokens_fix=False,
|
49 |
+
label_smoothing=0.0,
|
50 |
+
**kwargs
|
51 |
+
):
|
52 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
53 |
+
|
54 |
+
self.vocab_size = vocab_size
|
55 |
+
self.num_detect_classes = num_detect_classes
|
56 |
+
self.pretrained_name_or_path = pretrained_name_or_path
|
57 |
+
self.load_pretrained = load_pretrained
|
58 |
+
self.initializer_range = initializer_range
|
59 |
+
self.use_cache = use_cache
|
60 |
+
self.predictor_dropout = predictor_dropout
|
61 |
+
self.special_tokens_fix = special_tokens_fix
|
62 |
+
self.label_smoothing = label_smoothing
|
gec_model.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Wrapper of Seq2Labels model. Fixes errors based on model predictions"""
|
2 |
+
from collections import defaultdict
|
3 |
+
from difflib import SequenceMatcher
|
4 |
+
import logging
|
5 |
+
import re
|
6 |
+
from time import time
|
7 |
+
from typing import List, Union
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from transformers import AutoTokenizer
|
12 |
+
from .modeling_seq2labels import Seq2LabelsModel
|
13 |
+
from .vocabulary import Vocabulary
|
14 |
+
from .utils import PAD, UNK, START_TOKEN, get_target_sent_by_edits
|
15 |
+
|
16 |
+
logging.getLogger("werkzeug").setLevel(logging.ERROR)
|
17 |
+
logger = logging.getLogger(__file__)
|
18 |
+
|
19 |
+
|
20 |
+
class GecBERTModel(torch.nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
vocab_path=None,
|
24 |
+
model_paths=None,
|
25 |
+
weights=None,
|
26 |
+
device=None,
|
27 |
+
max_len=64,
|
28 |
+
min_len=3,
|
29 |
+
lowercase_tokens=False,
|
30 |
+
log=False,
|
31 |
+
iterations=3,
|
32 |
+
min_error_probability=0.0,
|
33 |
+
confidence=0,
|
34 |
+
resolve_cycles=False,
|
35 |
+
split_chunk=False,
|
36 |
+
chunk_size=48,
|
37 |
+
overlap_size=12,
|
38 |
+
min_words_cut=6,
|
39 |
+
punc_dict={':', ".", ",", "?"},
|
40 |
+
):
|
41 |
+
r"""
|
42 |
+
Args:
|
43 |
+
vocab_path (`str`):
|
44 |
+
Path to vocabulary directory.
|
45 |
+
model_paths (`List[str]`):
|
46 |
+
List of model paths.
|
47 |
+
weights (`int`, *Optional*, defaults to None):
|
48 |
+
Weights of each model. Only relevant if `is_ensemble is True`.
|
49 |
+
device (`int`, *Optional*, defaults to None):
|
50 |
+
Device to load model. If not set, device will be automatically choose.
|
51 |
+
max_len (`int`, defaults to 64):
|
52 |
+
Max sentence length to be processed (all longer will be truncated).
|
53 |
+
min_len (`int`, defaults to 3):
|
54 |
+
Min sentence length to be processed (all shorted will be returned w/o changes).
|
55 |
+
lowercase_tokens (`bool`, defaults to False):
|
56 |
+
Whether to lowercase tokens.
|
57 |
+
log (`bool`, defaults to False):
|
58 |
+
Whether to enable logging.
|
59 |
+
iterations (`int`, defaults to 3):
|
60 |
+
Max iterations to run during inference.
|
61 |
+
special_tokens_fix (`bool`, defaults to True):
|
62 |
+
Whether to fix problem with [CLS], [SEP] tokens tokenization.
|
63 |
+
min_error_probability (`float`, defaults to `0.0`):
|
64 |
+
Minimum probability for each action to apply.
|
65 |
+
confidence (`float`, defaults to `0.0`):
|
66 |
+
How many probability to add to $KEEP token.
|
67 |
+
split_chunk (`bool`, defaults to False):
|
68 |
+
Whether to split long sentences to multiple segments of `chunk_size`.
|
69 |
+
!Warning: if `chunk_size > max_len`, each segment will be truncate to `max_len`.
|
70 |
+
chunk_size (`int`, defaults to 48):
|
71 |
+
Length of each segment (in words). Only relevant if `split_chunk is True`.
|
72 |
+
overlap_size (`int`, defaults to 12):
|
73 |
+
Overlap size (in words) between two consecutive segments. Only relevant if `split_chunk is True`.
|
74 |
+
min_words_cut (`int`, defaults to 6):
|
75 |
+
Minimun number of words to be cut while merging two consecutive segments.
|
76 |
+
Only relevant if `split_chunk is True`.
|
77 |
+
punc_dict (List[str], defaults to `{':', ".", ",", "?"}`):
|
78 |
+
List of punctuations.
|
79 |
+
"""
|
80 |
+
super().__init__()
|
81 |
+
self.model_weights = list(map(float, weights)) if weights else [1] * len(model_paths)
|
82 |
+
self.device = (
|
83 |
+
torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else torch.device(device)
|
84 |
+
)
|
85 |
+
self.max_len = max_len
|
86 |
+
self.min_len = min_len
|
87 |
+
self.lowercase_tokens = lowercase_tokens
|
88 |
+
self.min_error_probability = min_error_probability
|
89 |
+
self.vocab = Vocabulary.from_files(vocab_path)
|
90 |
+
self.log = log
|
91 |
+
self.iterations = iterations
|
92 |
+
self.confidence = confidence
|
93 |
+
self.resolve_cycles = resolve_cycles
|
94 |
+
|
95 |
+
assert (
|
96 |
+
chunk_size > 0 and chunk_size // 2 >= overlap_size
|
97 |
+
), "Chunk merging required overlap size must be smaller than half of chunk size"
|
98 |
+
self.split_chunk = split_chunk
|
99 |
+
self.chunk_size = chunk_size
|
100 |
+
self.overlap_size = overlap_size
|
101 |
+
self.min_words_cut = min_words_cut
|
102 |
+
self.stride = chunk_size - overlap_size
|
103 |
+
self.punc_dict = punc_dict
|
104 |
+
self.punc_str = '[' + ''.join([f'\{x}' for x in punc_dict]) + ']'
|
105 |
+
# set training parameters and operations
|
106 |
+
|
107 |
+
self.indexers = []
|
108 |
+
self.models = []
|
109 |
+
if isinstance(model_paths, str):
|
110 |
+
model_paths = [model_paths]
|
111 |
+
for model_path in model_paths:
|
112 |
+
model = Seq2LabelsModel.from_pretrained(model_path)
|
113 |
+
config = model.config
|
114 |
+
model_name = config.pretrained_name_or_path
|
115 |
+
special_tokens_fix = config.special_tokens_fix
|
116 |
+
self.indexers.append(self._get_indexer(model_name, special_tokens_fix))
|
117 |
+
model.eval().to(self.device)
|
118 |
+
self.models.append(model)
|
119 |
+
|
120 |
+
def _get_indexer(self, weights_name, special_tokens_fix):
|
121 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
122 |
+
weights_name, do_basic_tokenize=False, do_lower_case=self.lowercase_tokens, model_max_length=1024
|
123 |
+
)
|
124 |
+
# to adjust all tokenizers
|
125 |
+
if hasattr(tokenizer, 'encoder'):
|
126 |
+
tokenizer.vocab = tokenizer.encoder
|
127 |
+
if hasattr(tokenizer, 'sp_model'):
|
128 |
+
tokenizer.vocab = defaultdict(lambda: 1)
|
129 |
+
for i in range(tokenizer.sp_model.get_piece_size()):
|
130 |
+
tokenizer.vocab[tokenizer.sp_model.id_to_piece(i)] = i
|
131 |
+
|
132 |
+
if special_tokens_fix:
|
133 |
+
tokenizer.add_tokens([START_TOKEN])
|
134 |
+
tokenizer.vocab[START_TOKEN] = len(tokenizer) - 1
|
135 |
+
return tokenizer
|
136 |
+
|
137 |
+
def forward(self, text: Union[str, List[str], List[List[str]]], is_split_into_words=False):
|
138 |
+
# Input type checking for clearer error
|
139 |
+
def _is_valid_text_input(t):
|
140 |
+
if isinstance(t, str):
|
141 |
+
# Strings are fine
|
142 |
+
return True
|
143 |
+
elif isinstance(t, (list, tuple)):
|
144 |
+
# List are fine as long as they are...
|
145 |
+
if len(t) == 0:
|
146 |
+
# ... empty
|
147 |
+
return True
|
148 |
+
elif isinstance(t[0], str):
|
149 |
+
# ... list of strings
|
150 |
+
return True
|
151 |
+
elif isinstance(t[0], (list, tuple)):
|
152 |
+
# ... list with an empty list or with a list of strings
|
153 |
+
return len(t[0]) == 0 or isinstance(t[0][0], str)
|
154 |
+
else:
|
155 |
+
return False
|
156 |
+
else:
|
157 |
+
return False
|
158 |
+
|
159 |
+
if not _is_valid_text_input(text):
|
160 |
+
raise ValueError(
|
161 |
+
"text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) "
|
162 |
+
"or `List[List[str]]` (batch of pretokenized examples)."
|
163 |
+
)
|
164 |
+
|
165 |
+
if is_split_into_words:
|
166 |
+
is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
|
167 |
+
else:
|
168 |
+
is_batched = isinstance(text, (list, tuple))
|
169 |
+
if is_batched:
|
170 |
+
text = [x.split() for x in text]
|
171 |
+
else:
|
172 |
+
text = text.split()
|
173 |
+
|
174 |
+
if not is_batched:
|
175 |
+
text = [text]
|
176 |
+
|
177 |
+
return self.handle_batch(text)
|
178 |
+
|
179 |
+
def split_chunks(self, batch):
|
180 |
+
# return batch pairs of indices
|
181 |
+
result = []
|
182 |
+
indices = []
|
183 |
+
for tokens in batch:
|
184 |
+
start = len(result)
|
185 |
+
num_token = len(tokens)
|
186 |
+
if num_token <= self.chunk_size:
|
187 |
+
result.append(tokens)
|
188 |
+
elif num_token > self.chunk_size and num_token < (self.chunk_size * 2 - self.overlap_size):
|
189 |
+
split_idx = (num_token + self.overlap_size + 1) // 2
|
190 |
+
result.append(tokens[:split_idx])
|
191 |
+
result.append(tokens[split_idx - self.overlap_size :])
|
192 |
+
else:
|
193 |
+
for i in range(0, num_token - self.overlap_size, self.stride):
|
194 |
+
result.append(tokens[i : i + self.chunk_size])
|
195 |
+
|
196 |
+
indices.append((start, len(result)))
|
197 |
+
|
198 |
+
return result, indices
|
199 |
+
|
200 |
+
def check_alnum(self, s):
|
201 |
+
if len(s) < 2:
|
202 |
+
return False
|
203 |
+
return not (s.isalpha() or s.isdigit())
|
204 |
+
|
205 |
+
def apply_chunk_merging(self, tokens, next_tokens):
|
206 |
+
# Return next tokens if current tokens list is empty
|
207 |
+
if not tokens:
|
208 |
+
return next_tokens
|
209 |
+
|
210 |
+
source_token_idx = []
|
211 |
+
target_token_idx = []
|
212 |
+
source_tokens = []
|
213 |
+
target_tokens = []
|
214 |
+
num_keep = self.overlap_size - self.min_words_cut
|
215 |
+
i = 0
|
216 |
+
while len(source_token_idx) < self.overlap_size and -i < len(tokens):
|
217 |
+
i -= 1
|
218 |
+
if tokens[i] not in self.punc_dict:
|
219 |
+
source_token_idx.insert(0, i)
|
220 |
+
source_tokens.insert(0, tokens[i].lower())
|
221 |
+
|
222 |
+
i = 0
|
223 |
+
while len(target_token_idx) < self.overlap_size and i < len(next_tokens):
|
224 |
+
if next_tokens[i] not in self.punc_dict:
|
225 |
+
target_token_idx.append(i)
|
226 |
+
target_tokens.append(next_tokens[i].lower())
|
227 |
+
i += 1
|
228 |
+
|
229 |
+
matcher = SequenceMatcher(None, source_tokens, target_tokens)
|
230 |
+
diffs = list(matcher.get_opcodes())
|
231 |
+
|
232 |
+
for diff in diffs:
|
233 |
+
tag, i1, i2, j1, j2 = diff
|
234 |
+
if tag == "equal":
|
235 |
+
if i1 >= num_keep:
|
236 |
+
tail_idx = source_token_idx[i1]
|
237 |
+
head_idx = target_token_idx[j1]
|
238 |
+
break
|
239 |
+
elif i2 > num_keep:
|
240 |
+
tail_idx = source_token_idx[num_keep]
|
241 |
+
head_idx = target_token_idx[j2 - i2 + num_keep]
|
242 |
+
break
|
243 |
+
elif tag == "delete" and i1 == 0:
|
244 |
+
num_keep += i2 // 2
|
245 |
+
|
246 |
+
tokens = tokens[:tail_idx] + next_tokens[head_idx:]
|
247 |
+
return tokens
|
248 |
+
|
249 |
+
def merge_chunks(self, batch):
|
250 |
+
result = []
|
251 |
+
if len(batch) == 1 or self.overlap_size == 0:
|
252 |
+
for sub_tokens in batch:
|
253 |
+
result.extend(sub_tokens)
|
254 |
+
else:
|
255 |
+
for _, sub_tokens in enumerate(batch):
|
256 |
+
try:
|
257 |
+
result = self.apply_chunk_merging(result, sub_tokens)
|
258 |
+
except Exception as e:
|
259 |
+
print(e)
|
260 |
+
|
261 |
+
result = " ".join(result)
|
262 |
+
return result
|
263 |
+
|
264 |
+
def predict(self, batches):
|
265 |
+
t11 = time()
|
266 |
+
predictions = []
|
267 |
+
for batch, model in zip(batches, self.models):
|
268 |
+
batch = batch.to(self.device)
|
269 |
+
with torch.no_grad():
|
270 |
+
prediction = model.forward(**batch)
|
271 |
+
predictions.append(prediction)
|
272 |
+
|
273 |
+
preds, idx, error_probs = self._convert(predictions)
|
274 |
+
t55 = time()
|
275 |
+
if self.log:
|
276 |
+
print(f"Inference time {t55 - t11}")
|
277 |
+
return preds, idx, error_probs
|
278 |
+
|
279 |
+
def get_token_action(self, token, index, prob, sugg_token):
|
280 |
+
"""Get lost of suggested actions for token."""
|
281 |
+
# cases when we don't need to do anything
|
282 |
+
if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
|
283 |
+
return None
|
284 |
+
|
285 |
+
if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
|
286 |
+
start_pos = index
|
287 |
+
end_pos = index + 1
|
288 |
+
elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
|
289 |
+
start_pos = index + 1
|
290 |
+
end_pos = index + 1
|
291 |
+
|
292 |
+
if sugg_token == "$DELETE":
|
293 |
+
sugg_token_clear = ""
|
294 |
+
elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
|
295 |
+
sugg_token_clear = sugg_token[:]
|
296 |
+
else:
|
297 |
+
sugg_token_clear = sugg_token[sugg_token.index('_') + 1 :]
|
298 |
+
|
299 |
+
return start_pos - 1, end_pos - 1, sugg_token_clear, prob
|
300 |
+
|
301 |
+
def preprocess(self, token_batch):
|
302 |
+
seq_lens = [len(sequence) for sequence in token_batch if sequence]
|
303 |
+
if not seq_lens:
|
304 |
+
return []
|
305 |
+
max_len = min(max(seq_lens), self.max_len)
|
306 |
+
batches = []
|
307 |
+
for indexer in self.indexers:
|
308 |
+
token_batch = [[START_TOKEN] + sequence[:max_len] for sequence in token_batch]
|
309 |
+
batch = indexer(
|
310 |
+
token_batch,
|
311 |
+
return_tensors="pt",
|
312 |
+
padding=True,
|
313 |
+
is_split_into_words=True,
|
314 |
+
truncation=True,
|
315 |
+
add_special_tokens=False,
|
316 |
+
)
|
317 |
+
offset_batch = []
|
318 |
+
for i in range(len(token_batch)):
|
319 |
+
word_ids = batch.word_ids(batch_index=i)
|
320 |
+
offsets = [0]
|
321 |
+
for i in range(1, len(word_ids)):
|
322 |
+
if word_ids[i] != word_ids[i - 1]:
|
323 |
+
offsets.append(i)
|
324 |
+
offset_batch.append(torch.LongTensor(offsets))
|
325 |
+
|
326 |
+
batch["input_offsets"] = torch.nn.utils.rnn.pad_sequence(
|
327 |
+
offset_batch, batch_first=True, padding_value=0
|
328 |
+
).to(torch.long)
|
329 |
+
|
330 |
+
batches.append(batch)
|
331 |
+
|
332 |
+
return batches
|
333 |
+
|
334 |
+
def _convert(self, data):
|
335 |
+
all_class_probs = torch.zeros_like(data[0]['logits'])
|
336 |
+
error_probs = torch.zeros_like(data[0]['max_error_probability'])
|
337 |
+
for output, weight in zip(data, self.model_weights):
|
338 |
+
class_probabilities_labels = torch.softmax(output['logits'], dim=-1)
|
339 |
+
all_class_probs += weight * class_probabilities_labels / sum(self.model_weights)
|
340 |
+
error_probs += weight * output['max_error_probability'] / sum(self.model_weights)
|
341 |
+
|
342 |
+
max_vals = torch.max(all_class_probs, dim=-1)
|
343 |
+
probs = max_vals[0].tolist()
|
344 |
+
idx = max_vals[1].tolist()
|
345 |
+
return probs, idx, error_probs.tolist()
|
346 |
+
|
347 |
+
def update_final_batch(self, final_batch, pred_ids, pred_batch, prev_preds_dict):
|
348 |
+
new_pred_ids = []
|
349 |
+
total_updated = 0
|
350 |
+
for i, orig_id in enumerate(pred_ids):
|
351 |
+
orig = final_batch[orig_id]
|
352 |
+
pred = pred_batch[i]
|
353 |
+
prev_preds = prev_preds_dict[orig_id]
|
354 |
+
if orig != pred and pred not in prev_preds:
|
355 |
+
final_batch[orig_id] = pred
|
356 |
+
new_pred_ids.append(orig_id)
|
357 |
+
prev_preds_dict[orig_id].append(pred)
|
358 |
+
total_updated += 1
|
359 |
+
elif orig != pred and pred in prev_preds:
|
360 |
+
# update final batch, but stop iterations
|
361 |
+
final_batch[orig_id] = pred
|
362 |
+
total_updated += 1
|
363 |
+
else:
|
364 |
+
continue
|
365 |
+
return final_batch, new_pred_ids, total_updated
|
366 |
+
|
367 |
+
def postprocess_batch(self, batch, all_probabilities, all_idxs, error_probs):
|
368 |
+
all_results = []
|
369 |
+
noop_index = self.vocab.get_token_index("$KEEP", "labels")
|
370 |
+
for tokens, probabilities, idxs, error_prob in zip(batch, all_probabilities, all_idxs, error_probs):
|
371 |
+
length = min(len(tokens), self.max_len)
|
372 |
+
edits = []
|
373 |
+
|
374 |
+
# skip whole sentences if there no errors
|
375 |
+
if max(idxs) == 0:
|
376 |
+
all_results.append(tokens)
|
377 |
+
continue
|
378 |
+
|
379 |
+
# skip whole sentence if probability of correctness is not high
|
380 |
+
if error_prob < self.min_error_probability:
|
381 |
+
all_results.append(tokens)
|
382 |
+
continue
|
383 |
+
|
384 |
+
for i in range(length + 1):
|
385 |
+
# because of START token
|
386 |
+
if i == 0:
|
387 |
+
token = START_TOKEN
|
388 |
+
else:
|
389 |
+
token = tokens[i - 1]
|
390 |
+
# skip if there is no error
|
391 |
+
if idxs[i] == noop_index:
|
392 |
+
continue
|
393 |
+
|
394 |
+
sugg_token = self.vocab.get_token_from_index(idxs[i], namespace='labels')
|
395 |
+
action = self.get_token_action(token, i, probabilities[i], sugg_token)
|
396 |
+
if not action:
|
397 |
+
continue
|
398 |
+
|
399 |
+
edits.append(action)
|
400 |
+
all_results.append(get_target_sent_by_edits(tokens, edits))
|
401 |
+
return all_results
|
402 |
+
|
403 |
+
def handle_batch(self, full_batch, merge_punc=True):
|
404 |
+
"""
|
405 |
+
Handle batch of requests.
|
406 |
+
"""
|
407 |
+
if self.split_chunk:
|
408 |
+
full_batch, indices = self.split_chunks(full_batch)
|
409 |
+
else:
|
410 |
+
indices = None
|
411 |
+
final_batch = full_batch[:]
|
412 |
+
batch_size = len(full_batch)
|
413 |
+
prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
|
414 |
+
short_ids = [i for i in range(len(full_batch)) if len(full_batch[i]) < self.min_len]
|
415 |
+
pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
|
416 |
+
total_updates = 0
|
417 |
+
|
418 |
+
for n_iter in range(self.iterations):
|
419 |
+
orig_batch = [final_batch[i] for i in pred_ids]
|
420 |
+
|
421 |
+
sequences = self.preprocess(orig_batch)
|
422 |
+
|
423 |
+
if not sequences:
|
424 |
+
break
|
425 |
+
probabilities, idxs, error_probs = self.predict(sequences)
|
426 |
+
|
427 |
+
pred_batch = self.postprocess_batch(orig_batch, probabilities, idxs, error_probs)
|
428 |
+
if self.log:
|
429 |
+
print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")
|
430 |
+
|
431 |
+
final_batch, pred_ids, cnt = self.update_final_batch(final_batch, pred_ids, pred_batch, prev_preds_dict)
|
432 |
+
total_updates += cnt
|
433 |
+
|
434 |
+
if not pred_ids:
|
435 |
+
break
|
436 |
+
if self.split_chunk:
|
437 |
+
final_batch = [self.merge_chunks(final_batch[start:end]) for (start, end) in indices]
|
438 |
+
else:
|
439 |
+
final_batch = [" ".join(x) for x in final_batch]
|
440 |
+
if merge_punc:
|
441 |
+
final_batch = [re.sub(r'\s+(%s)' % self.punc_str, r'\1', x) for x in final_batch]
|
442 |
+
|
443 |
+
return final_batch, total_updates
|
modeling_seq2labels.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import CrossEntropyLoss
|
4 |
+
from transformers import AutoConfig, AutoModel, BertPreTrainedModel
|
5 |
+
from transformers.modeling_outputs import ModelOutput
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
def get_range_vector(size: int, device: int) -> torch.Tensor:
|
11 |
+
"""
|
12 |
+
Returns a range vector with the desired size, starting at 0. The CUDA implementation
|
13 |
+
is meant to avoid copy data from CPU to GPU.
|
14 |
+
"""
|
15 |
+
return torch.arange(0, size, dtype=torch.long, device=device)
|
16 |
+
|
17 |
+
|
18 |
+
class Seq2LabelsOutput(ModelOutput):
|
19 |
+
loss: Optional[torch.FloatTensor] = None
|
20 |
+
logits: torch.FloatTensor = None
|
21 |
+
detect_logits: torch.FloatTensor = None
|
22 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
23 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
24 |
+
max_error_probability: Optional[torch.FloatTensor] = None
|
25 |
+
|
26 |
+
|
27 |
+
class Seq2LabelsModel(BertPreTrainedModel):
|
28 |
+
|
29 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
30 |
+
|
31 |
+
def __init__(self, config):
|
32 |
+
super().__init__(config)
|
33 |
+
self.num_labels = config.num_labels
|
34 |
+
self.num_detect_classes = config.num_detect_classes
|
35 |
+
self.label_smoothing = config.label_smoothing
|
36 |
+
|
37 |
+
if config.load_pretrained:
|
38 |
+
self.bert = AutoModel.from_pretrained(config.pretrained_name_or_path)
|
39 |
+
bert_config = self.bert.config
|
40 |
+
else:
|
41 |
+
bert_config = AutoConfig.from_pretrained(config.pretrained_name_or_path)
|
42 |
+
self.bert = AutoModel.from_config(bert_config)
|
43 |
+
|
44 |
+
if config.special_tokens_fix:
|
45 |
+
try:
|
46 |
+
vocab_size = self.bert.embeddings.word_embeddings.num_embeddings
|
47 |
+
except AttributeError:
|
48 |
+
# reserve more space
|
49 |
+
vocab_size = self.bert.word_embedding.num_embeddings + 5
|
50 |
+
self.bert.resize_token_embeddings(vocab_size + 1)
|
51 |
+
|
52 |
+
predictor_dropout = config.predictor_dropout if config.predictor_dropout is not None else 0.0
|
53 |
+
self.dropout = nn.Dropout(predictor_dropout)
|
54 |
+
self.classifier = nn.Linear(bert_config.hidden_size, config.vocab_size)
|
55 |
+
self.detector = nn.Linear(bert_config.hidden_size, config.num_detect_classes)
|
56 |
+
|
57 |
+
# Initialize weights and apply final processing
|
58 |
+
self.post_init()
|
59 |
+
|
60 |
+
def forward(
|
61 |
+
self,
|
62 |
+
input_ids: Optional[torch.Tensor] = None,
|
63 |
+
input_offsets: Optional[torch.Tensor] = None,
|
64 |
+
attention_mask: Optional[torch.Tensor] = None,
|
65 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
66 |
+
position_ids: Optional[torch.Tensor] = None,
|
67 |
+
head_mask: Optional[torch.Tensor] = None,
|
68 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
69 |
+
labels: Optional[torch.Tensor] = None,
|
70 |
+
d_tags: Optional[torch.Tensor] = None,
|
71 |
+
output_attentions: Optional[bool] = None,
|
72 |
+
output_hidden_states: Optional[bool] = None,
|
73 |
+
return_dict: Optional[bool] = None,
|
74 |
+
) -> Union[Tuple[torch.Tensor], Seq2LabelsOutput]:
|
75 |
+
r"""
|
76 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
77 |
+
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
|
78 |
+
"""
|
79 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
80 |
+
|
81 |
+
outputs = self.bert(
|
82 |
+
input_ids,
|
83 |
+
attention_mask=attention_mask,
|
84 |
+
token_type_ids=token_type_ids,
|
85 |
+
position_ids=position_ids,
|
86 |
+
head_mask=head_mask,
|
87 |
+
inputs_embeds=inputs_embeds,
|
88 |
+
output_attentions=output_attentions,
|
89 |
+
output_hidden_states=output_hidden_states,
|
90 |
+
return_dict=return_dict,
|
91 |
+
)
|
92 |
+
|
93 |
+
sequence_output = outputs[0]
|
94 |
+
|
95 |
+
if input_offsets is not None:
|
96 |
+
# offsets is (batch_size, d1, ..., dn, orig_sequence_length)
|
97 |
+
range_vector = get_range_vector(input_offsets.size(0), device=sequence_output.device).unsqueeze(1)
|
98 |
+
# selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
|
99 |
+
sequence_output = sequence_output[range_vector, input_offsets]
|
100 |
+
|
101 |
+
logits = self.classifier(self.dropout(sequence_output))
|
102 |
+
logits_d = self.detector(sequence_output)
|
103 |
+
|
104 |
+
loss = None
|
105 |
+
if labels is not None and d_tags is not None:
|
106 |
+
loss_labels_fct = CrossEntropyLoss(label_smoothing=self.label_smoothing)
|
107 |
+
loss_d_fct = CrossEntropyLoss()
|
108 |
+
loss_labels = loss_labels_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
109 |
+
loss_d = loss_d_fct(logits_d.view(-1, self.num_detect_classes), d_tags.view(-1))
|
110 |
+
loss = loss_labels + loss_d
|
111 |
+
|
112 |
+
if not return_dict:
|
113 |
+
output = (logits, logits_d) + outputs[2:]
|
114 |
+
return ((loss,) + output) if loss is not None else output
|
115 |
+
|
116 |
+
return Seq2LabelsOutput(
|
117 |
+
loss=loss,
|
118 |
+
logits=logits,
|
119 |
+
detect_logits=logits_d,
|
120 |
+
hidden_states=outputs.hidden_states,
|
121 |
+
attentions=outputs.attentions,
|
122 |
+
max_error_probability=torch.ones(logits.size(0)),
|
123 |
+
)
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d22fc1de03bc10237eafbd7487c04bc4ef6e890ecf3a77aa678e5995bc251bfd
|
3 |
+
size 461547689
|
utils.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import re
|
4 |
+
|
5 |
+
|
6 |
+
VOCAB_DIR = Path(__file__).resolve().parent.parent / "data"
|
7 |
+
PAD = "@@PADDING@@"
|
8 |
+
UNK = "@@UNKNOWN@@"
|
9 |
+
START_TOKEN = "$START"
|
10 |
+
SEQ_DELIMETERS = {"tokens": " ", "labels": "SEPL|||SEPR", "operations": "SEPL__SEPR"}
|
11 |
+
|
12 |
+
|
13 |
+
def get_verb_form_dicts():
|
14 |
+
path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
|
15 |
+
encode, decode = {}, {}
|
16 |
+
with open(path_to_dict, encoding="utf-8") as f:
|
17 |
+
for line in f:
|
18 |
+
words, tags = line.split(":")
|
19 |
+
word1, word2 = words.split("_")
|
20 |
+
tag1, tag2 = tags.split("_")
|
21 |
+
decode_key = f"{word1}_{tag1}_{tag2.strip()}"
|
22 |
+
if decode_key not in decode:
|
23 |
+
encode[words] = tags
|
24 |
+
decode[decode_key] = word2
|
25 |
+
return encode, decode
|
26 |
+
|
27 |
+
|
28 |
+
ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()
|
29 |
+
|
30 |
+
|
31 |
+
def get_target_sent_by_edits(source_tokens, edits):
|
32 |
+
target_tokens = source_tokens[:]
|
33 |
+
shift_idx = 0
|
34 |
+
for edit in edits:
|
35 |
+
start, end, label, _ = edit
|
36 |
+
target_pos = start + shift_idx
|
37 |
+
if start < 0:
|
38 |
+
continue
|
39 |
+
elif len(target_tokens) > target_pos:
|
40 |
+
source_token = target_tokens[target_pos]
|
41 |
+
else:
|
42 |
+
source_token = ""
|
43 |
+
if label == "":
|
44 |
+
del target_tokens[target_pos]
|
45 |
+
shift_idx -= 1
|
46 |
+
elif start == end:
|
47 |
+
word = label.replace("$APPEND_", "")
|
48 |
+
# Avoid appending same token twice
|
49 |
+
if (target_pos < len(target_tokens) and target_tokens[target_pos] == word) or (
|
50 |
+
target_pos > 0 and target_tokens[target_pos - 1] == word
|
51 |
+
):
|
52 |
+
continue
|
53 |
+
target_tokens[target_pos:target_pos] = [word]
|
54 |
+
shift_idx += 1
|
55 |
+
elif label.startswith("$TRANSFORM_"):
|
56 |
+
word = apply_reverse_transformation(source_token, label)
|
57 |
+
if word is None:
|
58 |
+
word = source_token
|
59 |
+
target_tokens[target_pos] = word
|
60 |
+
elif start == end - 1:
|
61 |
+
word = label.replace("$REPLACE_", "")
|
62 |
+
target_tokens[target_pos] = word
|
63 |
+
elif label.startswith("$MERGE_"):
|
64 |
+
target_tokens[target_pos + 1 : target_pos + 1] = [label]
|
65 |
+
shift_idx += 1
|
66 |
+
|
67 |
+
return replace_merge_transforms(target_tokens)
|
68 |
+
|
69 |
+
|
70 |
+
def replace_merge_transforms(tokens):
|
71 |
+
if all(not x.startswith("$MERGE_") for x in tokens):
|
72 |
+
return tokens
|
73 |
+
if tokens[0].startswith("$MERGE_"):
|
74 |
+
tokens = tokens[1:]
|
75 |
+
if tokens[-1].startswith("$MERGE_"):
|
76 |
+
tokens = tokens[:-1]
|
77 |
+
|
78 |
+
target_line = " ".join(tokens)
|
79 |
+
target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
|
80 |
+
target_line = target_line.replace(" $MERGE_SPACE ", "")
|
81 |
+
target_line = re.sub(r'([\.\,\?\:]\s+)+', r'\1', target_line)
|
82 |
+
return target_line.split()
|
83 |
+
|
84 |
+
|
85 |
+
def convert_using_case(token, smart_action):
|
86 |
+
if not smart_action.startswith("$TRANSFORM_CASE_"):
|
87 |
+
return token
|
88 |
+
if smart_action.endswith("LOWER"):
|
89 |
+
return token.lower()
|
90 |
+
elif smart_action.endswith("UPPER"):
|
91 |
+
return token.upper()
|
92 |
+
elif smart_action.endswith("CAPITAL"):
|
93 |
+
return token.capitalize()
|
94 |
+
elif smart_action.endswith("CAPITAL_1"):
|
95 |
+
return token[0] + token[1:].capitalize()
|
96 |
+
elif smart_action.endswith("UPPER_-1"):
|
97 |
+
return token[:-1].upper() + token[-1]
|
98 |
+
else:
|
99 |
+
return token
|
100 |
+
|
101 |
+
|
102 |
+
def convert_using_verb(token, smart_action):
|
103 |
+
key_word = "$TRANSFORM_VERB_"
|
104 |
+
if not smart_action.startswith(key_word):
|
105 |
+
raise Exception(f"Unknown action type {smart_action}")
|
106 |
+
encoding_part = f"{token}_{smart_action[len(key_word):]}"
|
107 |
+
decoded_target_word = decode_verb_form(encoding_part)
|
108 |
+
return decoded_target_word
|
109 |
+
|
110 |
+
|
111 |
+
def convert_using_split(token, smart_action):
|
112 |
+
key_word = "$TRANSFORM_SPLIT"
|
113 |
+
if not smart_action.startswith(key_word):
|
114 |
+
raise Exception(f"Unknown action type {smart_action}")
|
115 |
+
target_words = token.split("-")
|
116 |
+
return " ".join(target_words)
|
117 |
+
|
118 |
+
|
119 |
+
def convert_using_plural(token, smart_action):
|
120 |
+
if smart_action.endswith("PLURAL"):
|
121 |
+
return token + "s"
|
122 |
+
elif smart_action.endswith("SINGULAR"):
|
123 |
+
return token[:-1]
|
124 |
+
else:
|
125 |
+
raise Exception(f"Unknown action type {smart_action}")
|
126 |
+
|
127 |
+
|
128 |
+
def apply_reverse_transformation(source_token, transform):
|
129 |
+
if transform.startswith("$TRANSFORM"):
|
130 |
+
# deal with equal
|
131 |
+
if transform == "$KEEP":
|
132 |
+
return source_token
|
133 |
+
# deal with case
|
134 |
+
if transform.startswith("$TRANSFORM_CASE"):
|
135 |
+
return convert_using_case(source_token, transform)
|
136 |
+
# deal with verb
|
137 |
+
if transform.startswith("$TRANSFORM_VERB"):
|
138 |
+
return convert_using_verb(source_token, transform)
|
139 |
+
# deal with split
|
140 |
+
if transform.startswith("$TRANSFORM_SPLIT"):
|
141 |
+
return convert_using_split(source_token, transform)
|
142 |
+
# deal with single/plural
|
143 |
+
if transform.startswith("$TRANSFORM_AGREEMENT"):
|
144 |
+
return convert_using_plural(source_token, transform)
|
145 |
+
# raise exception if not find correct type
|
146 |
+
raise Exception(f"Unknown action type {transform}")
|
147 |
+
else:
|
148 |
+
return source_token
|
149 |
+
|
150 |
+
|
151 |
+
# def read_parallel_lines(fn1, fn2):
|
152 |
+
# lines1 = read_lines(fn1, skip_strip=True)
|
153 |
+
# lines2 = read_lines(fn2, skip_strip=True)
|
154 |
+
# assert len(lines1) == len(lines2)
|
155 |
+
# out_lines1, out_lines2 = [], []
|
156 |
+
# for line1, line2 in zip(lines1, lines2):
|
157 |
+
# if not line1.strip() or not line2.strip():
|
158 |
+
# continue
|
159 |
+
# else:
|
160 |
+
# out_lines1.append(line1)
|
161 |
+
# out_lines2.append(line2)
|
162 |
+
# return out_lines1, out_lines2
|
163 |
+
|
164 |
+
|
165 |
+
def read_parallel_lines(fn1, fn2):
|
166 |
+
with open(fn1, encoding='utf-8') as f1, open(fn2, encoding='utf-8') as f2:
|
167 |
+
for line1, line2 in zip(f1, f2):
|
168 |
+
line1 = line1.strip()
|
169 |
+
line2 = line2.strip()
|
170 |
+
|
171 |
+
yield line1, line2
|
172 |
+
|
173 |
+
|
174 |
+
def read_lines(fn, skip_strip=False):
|
175 |
+
if not os.path.exists(fn):
|
176 |
+
return []
|
177 |
+
with open(fn, 'r', encoding='utf-8') as f:
|
178 |
+
lines = f.readlines()
|
179 |
+
return [s.strip() for s in lines if s.strip() or skip_strip]
|
180 |
+
|
181 |
+
|
182 |
+
def write_lines(fn, lines, mode='w'):
|
183 |
+
if mode == 'w' and os.path.exists(fn):
|
184 |
+
os.remove(fn)
|
185 |
+
with open(fn, encoding='utf-8', mode=mode) as f:
|
186 |
+
f.writelines(['%s\n' % s for s in lines])
|
187 |
+
|
188 |
+
|
189 |
+
def decode_verb_form(original):
|
190 |
+
return DECODE_VERB_DICT.get(original)
|
191 |
+
|
192 |
+
|
193 |
+
def encode_verb_form(original_word, corrected_word):
|
194 |
+
decoding_request = original_word + "_" + corrected_word
|
195 |
+
decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
|
196 |
+
if original_word and decoding_response:
|
197 |
+
answer = decoding_response
|
198 |
+
else:
|
199 |
+
answer = None
|
200 |
+
return answer
|
201 |
+
|
202 |
+
|
203 |
+
def get_weights_name(transformer_name, lowercase):
|
204 |
+
if transformer_name == 'bert' and lowercase:
|
205 |
+
return 'bert-base-uncased'
|
206 |
+
if transformer_name == 'bert' and not lowercase:
|
207 |
+
return 'bert-base-cased'
|
208 |
+
if transformer_name == 'bert-large' and not lowercase:
|
209 |
+
return 'bert-large-cased'
|
210 |
+
if transformer_name == 'distilbert':
|
211 |
+
if not lowercase:
|
212 |
+
print('Warning! This model was trained only on uncased sentences.')
|
213 |
+
return 'distilbert-base-uncased'
|
214 |
+
if transformer_name == 'albert':
|
215 |
+
if not lowercase:
|
216 |
+
print('Warning! This model was trained only on uncased sentences.')
|
217 |
+
return 'albert-base-v1'
|
218 |
+
if lowercase:
|
219 |
+
print('Warning! This model was trained only on cased sentences.')
|
220 |
+
if transformer_name == 'roberta':
|
221 |
+
return 'roberta-base'
|
222 |
+
if transformer_name == 'roberta-large':
|
223 |
+
return 'roberta-large'
|
224 |
+
if transformer_name == 'gpt2':
|
225 |
+
return 'gpt2'
|
226 |
+
if transformer_name == 'transformerxl':
|
227 |
+
return 'transfo-xl-wt103'
|
228 |
+
if transformer_name == 'xlnet':
|
229 |
+
return 'xlnet-base-cased'
|
230 |
+
if transformer_name == 'xlnet-large':
|
231 |
+
return 'xlnet-large-cased'
|
232 |
+
|
233 |
+
return transformer_name
|
vocabulary.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import codecs
|
2 |
+
from collections import defaultdict
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Union, TYPE_CHECKING
|
7 |
+
from filelock import FileLock
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels")
|
13 |
+
DEFAULT_PADDING_TOKEN = "@@PADDING@@"
|
14 |
+
DEFAULT_OOV_TOKEN = "@@UNKNOWN@@"
|
15 |
+
NAMESPACE_PADDING_FILE = "non_padded_namespaces.txt"
|
16 |
+
_NEW_LINE_REGEX = re.compile(r"\n|\r\n")
|
17 |
+
|
18 |
+
|
19 |
+
def namespace_match(pattern: str, namespace: str):
|
20 |
+
"""
|
21 |
+
Matches a namespace pattern against a namespace string. For example, `*tags` matches
|
22 |
+
`passage_tags` and `question_tags` and `tokens` matches `tokens` but not
|
23 |
+
`stemmed_tokens`.
|
24 |
+
"""
|
25 |
+
if pattern[0] == "*" and namespace.endswith(pattern[1:]):
|
26 |
+
return True
|
27 |
+
elif pattern == namespace:
|
28 |
+
return True
|
29 |
+
return False
|
30 |
+
|
31 |
+
|
32 |
+
class _NamespaceDependentDefaultDict(defaultdict):
|
33 |
+
"""
|
34 |
+
This is a [defaultdict]
|
35 |
+
(https://docs.python.org/2/library/collections.html#collections.defaultdict) where the
|
36 |
+
default value is dependent on the key that is passed.
|
37 |
+
We use "namespaces" in the :class:`Vocabulary` object to keep track of several different
|
38 |
+
mappings from strings to integers, so that we have a consistent API for mapping words, tags,
|
39 |
+
labels, characters, or whatever else you want, into integers. The issue is that some of those
|
40 |
+
namespaces (words and characters) should have integers reserved for padding and
|
41 |
+
out-of-vocabulary tokens, while others (labels and tags) shouldn't. This class allows you to
|
42 |
+
specify filters on the namespace (the key used in the `defaultdict`), and use different
|
43 |
+
default values depending on whether the namespace passes the filter.
|
44 |
+
To do filtering, we take a set of `non_padded_namespaces`. This is a set of strings
|
45 |
+
that are either matched exactly against the keys, or treated as suffixes, if the
|
46 |
+
string starts with `*`. In other words, if `*tags` is in `non_padded_namespaces` then
|
47 |
+
`passage_tags`, `question_tags`, etc. (anything that ends with `tags`) will have the
|
48 |
+
`non_padded` default value.
|
49 |
+
# Parameters
|
50 |
+
non_padded_namespaces : `Iterable[str]`
|
51 |
+
A set / list / tuple of strings describing which namespaces are not padded. If a namespace
|
52 |
+
(key) is missing from this dictionary, we will use :func:`namespace_match` to see whether
|
53 |
+
the namespace should be padded. If the given namespace matches any of the strings in this
|
54 |
+
list, we will use `non_padded_function` to initialize the value for that namespace, and
|
55 |
+
we will use `padded_function` otherwise.
|
56 |
+
padded_function : `Callable[[], Any]`
|
57 |
+
A zero-argument function to call to initialize a value for a namespace that `should` be
|
58 |
+
padded.
|
59 |
+
non_padded_function : `Callable[[], Any]`
|
60 |
+
A zero-argument function to call to initialize a value for a namespace that should `not` be
|
61 |
+
padded.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
non_padded_namespaces: Iterable[str],
|
67 |
+
padded_function: Callable[[], Any],
|
68 |
+
non_padded_function: Callable[[], Any],
|
69 |
+
) -> None:
|
70 |
+
self._non_padded_namespaces = set(non_padded_namespaces)
|
71 |
+
self._padded_function = padded_function
|
72 |
+
self._non_padded_function = non_padded_function
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]):
|
76 |
+
# add non_padded_namespaces which weren't already present
|
77 |
+
self._non_padded_namespaces.update(non_padded_namespaces)
|
78 |
+
|
79 |
+
|
80 |
+
class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict):
|
81 |
+
def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
|
82 |
+
super().__init__(non_padded_namespaces, lambda: {padding_token: 0, oov_token: 1}, lambda: {})
|
83 |
+
|
84 |
+
|
85 |
+
class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict):
|
86 |
+
def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None:
|
87 |
+
super().__init__(non_padded_namespaces, lambda: {0: padding_token, 1: oov_token}, lambda: {})
|
88 |
+
|
89 |
+
|
90 |
+
class Vocabulary:
|
91 |
+
def __init__(
|
92 |
+
self,
|
93 |
+
counter: Dict[str, Dict[str, int]] = None,
|
94 |
+
min_count: Dict[str, int] = None,
|
95 |
+
max_vocab_size: Union[int, Dict[str, int]] = None,
|
96 |
+
non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES,
|
97 |
+
pretrained_files: Optional[Dict[str, str]] = None,
|
98 |
+
only_include_pretrained_words: bool = False,
|
99 |
+
tokens_to_add: Dict[str, List[str]] = None,
|
100 |
+
min_pretrained_embeddings: Dict[str, int] = None,
|
101 |
+
padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
|
102 |
+
oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
|
103 |
+
) -> None:
|
104 |
+
self._padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
|
105 |
+
self._oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
|
106 |
+
|
107 |
+
self._non_padded_namespaces = set(non_padded_namespaces)
|
108 |
+
|
109 |
+
self._token_to_index = _TokenToIndexDefaultDict(
|
110 |
+
self._non_padded_namespaces, self._padding_token, self._oov_token
|
111 |
+
)
|
112 |
+
self._index_to_token = _IndexToTokenDefaultDict(
|
113 |
+
self._non_padded_namespaces, self._padding_token, self._oov_token
|
114 |
+
)
|
115 |
+
|
116 |
+
@classmethod
|
117 |
+
def from_files(
|
118 |
+
cls,
|
119 |
+
directory: Union[str, os.PathLike],
|
120 |
+
padding_token: Optional[str] = DEFAULT_PADDING_TOKEN,
|
121 |
+
oov_token: Optional[str] = DEFAULT_OOV_TOKEN,
|
122 |
+
) -> "Vocabulary":
|
123 |
+
"""
|
124 |
+
Loads a `Vocabulary` that was serialized either using `save_to_files` or inside
|
125 |
+
a model archive file.
|
126 |
+
# Parameters
|
127 |
+
directory : `str`
|
128 |
+
The directory or archive file containing the serialized vocabulary.
|
129 |
+
"""
|
130 |
+
logger.info("Loading token dictionary from %s.", directory)
|
131 |
+
padding_token = padding_token if padding_token is not None else DEFAULT_PADDING_TOKEN
|
132 |
+
oov_token = oov_token if oov_token is not None else DEFAULT_OOV_TOKEN
|
133 |
+
|
134 |
+
if not os.path.isdir(directory):
|
135 |
+
raise ValueError(f"{directory} not exist")
|
136 |
+
|
137 |
+
# We use a lock file to avoid race conditions where multiple processes
|
138 |
+
# might be reading/writing from/to the same vocab files at once.
|
139 |
+
with FileLock(os.path.join(directory, ".lock")):
|
140 |
+
with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), "r", "utf-8") as namespace_file:
|
141 |
+
non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file]
|
142 |
+
|
143 |
+
vocab = cls(
|
144 |
+
non_padded_namespaces=non_padded_namespaces,
|
145 |
+
padding_token=padding_token,
|
146 |
+
oov_token=oov_token,
|
147 |
+
)
|
148 |
+
|
149 |
+
# Check every file in the directory.
|
150 |
+
for namespace_filename in os.listdir(directory):
|
151 |
+
if namespace_filename == NAMESPACE_PADDING_FILE:
|
152 |
+
continue
|
153 |
+
if namespace_filename.startswith("."):
|
154 |
+
continue
|
155 |
+
namespace = namespace_filename.replace(".txt", "")
|
156 |
+
if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces):
|
157 |
+
is_padded = False
|
158 |
+
else:
|
159 |
+
is_padded = True
|
160 |
+
filename = os.path.join(directory, namespace_filename)
|
161 |
+
vocab.set_from_file(filename, is_padded, namespace=namespace, oov_token=oov_token)
|
162 |
+
|
163 |
+
return vocab
|
164 |
+
|
165 |
+
@classmethod
|
166 |
+
def empty(cls) -> "Vocabulary":
|
167 |
+
"""
|
168 |
+
This method returns a bare vocabulary instantiated with `cls()` (so, `Vocabulary()` if you
|
169 |
+
haven't made a subclass of this object). The only reason to call `Vocabulary.empty()`
|
170 |
+
instead of `Vocabulary()` is if you are instantiating this object from a config file. We
|
171 |
+
register this constructor with the key "empty", so if you know that you don't need to
|
172 |
+
compute a vocabulary (either because you're loading a pre-trained model from an archive
|
173 |
+
file, you're using a pre-trained transformer that has its own vocabulary, or something
|
174 |
+
else), you can use this to avoid having the default vocabulary construction code iterate
|
175 |
+
through the data.
|
176 |
+
"""
|
177 |
+
return cls()
|
178 |
+
|
179 |
+
def set_from_file(
|
180 |
+
self,
|
181 |
+
filename: str,
|
182 |
+
is_padded: bool = True,
|
183 |
+
oov_token: str = DEFAULT_OOV_TOKEN,
|
184 |
+
namespace: str = "tokens",
|
185 |
+
):
|
186 |
+
"""
|
187 |
+
If you already have a vocabulary file for a trained model somewhere, and you really want to
|
188 |
+
use that vocabulary file instead of just setting the vocabulary from a dataset, for
|
189 |
+
whatever reason, you can do that with this method. You must specify the namespace to use,
|
190 |
+
and we assume that you want to use padding and OOV tokens for this.
|
191 |
+
# Parameters
|
192 |
+
filename : `str`
|
193 |
+
The file containing the vocabulary to load. It should be formatted as one token per
|
194 |
+
line, with nothing else in the line. The index we assign to the token is the line
|
195 |
+
number in the file (1-indexed if `is_padded`, 0-indexed otherwise). Note that this
|
196 |
+
file should contain the OOV token string!
|
197 |
+
is_padded : `bool`, optional (default=`True`)
|
198 |
+
Is this vocabulary padded? For token / word / character vocabularies, this should be
|
199 |
+
`True`; while for tag or label vocabularies, this should typically be `False`. If
|
200 |
+
`True`, we add a padding token with index 0, and we enforce that the `oov_token` is
|
201 |
+
present in the file.
|
202 |
+
oov_token : `str`, optional (default=`DEFAULT_OOV_TOKEN`)
|
203 |
+
What token does this vocabulary use to represent out-of-vocabulary characters? This
|
204 |
+
must show up as a line in the vocabulary file. When we find it, we replace
|
205 |
+
`oov_token` with `self._oov_token`, because we only use one OOV token across
|
206 |
+
namespaces.
|
207 |
+
namespace : `str`, optional (default=`"tokens"`)
|
208 |
+
What namespace should we overwrite with this vocab file?
|
209 |
+
"""
|
210 |
+
if is_padded:
|
211 |
+
self._token_to_index[namespace] = {self._padding_token: 0}
|
212 |
+
self._index_to_token[namespace] = {0: self._padding_token}
|
213 |
+
else:
|
214 |
+
self._token_to_index[namespace] = {}
|
215 |
+
self._index_to_token[namespace] = {}
|
216 |
+
with codecs.open(filename, "r", "utf-8") as input_file:
|
217 |
+
lines = _NEW_LINE_REGEX.split(input_file.read())
|
218 |
+
# Be flexible about having final newline or not
|
219 |
+
if lines and lines[-1] == "":
|
220 |
+
lines = lines[:-1]
|
221 |
+
for i, line in enumerate(lines):
|
222 |
+
index = i + 1 if is_padded else i
|
223 |
+
token = line.replace("@@NEWLINE@@", "\n")
|
224 |
+
if token == oov_token:
|
225 |
+
token = self._oov_token
|
226 |
+
self._token_to_index[namespace][token] = index
|
227 |
+
self._index_to_token[namespace][index] = token
|
228 |
+
if is_padded:
|
229 |
+
assert self._oov_token in self._token_to_index[namespace], "OOV token not found!"
|
230 |
+
|
231 |
+
def add_token_to_namespace(self, token: str, namespace: str = "tokens") -> int:
|
232 |
+
"""
|
233 |
+
Adds `token` to the index, if it is not already present. Either way, we return the index of
|
234 |
+
the token.
|
235 |
+
"""
|
236 |
+
if not isinstance(token, str):
|
237 |
+
raise ValueError(
|
238 |
+
"Vocabulary tokens must be strings, or saving and loading will break."
|
239 |
+
" Got %s (with type %s)" % (repr(token), type(token))
|
240 |
+
)
|
241 |
+
if token not in self._token_to_index[namespace]:
|
242 |
+
index = len(self._token_to_index[namespace])
|
243 |
+
self._token_to_index[namespace][token] = index
|
244 |
+
self._index_to_token[namespace][index] = token
|
245 |
+
return index
|
246 |
+
else:
|
247 |
+
return self._token_to_index[namespace][token]
|
248 |
+
|
249 |
+
def add_tokens_to_namespace(self, tokens: List[str], namespace: str = "tokens") -> List[int]:
|
250 |
+
"""
|
251 |
+
Adds `tokens` to the index, if they are not already present. Either way, we return the
|
252 |
+
indices of the tokens in the order that they were given.
|
253 |
+
"""
|
254 |
+
return [self.add_token_to_namespace(token, namespace) for token in tokens]
|
255 |
+
|
256 |
+
def get_token_index(self, token: str, namespace: str = "tokens") -> int:
|
257 |
+
try:
|
258 |
+
return self._token_to_index[namespace][token]
|
259 |
+
except KeyError:
|
260 |
+
try:
|
261 |
+
return self._token_to_index[namespace][self._oov_token]
|
262 |
+
except KeyError:
|
263 |
+
logger.error("Namespace: %s", namespace)
|
264 |
+
logger.error("Token: %s", token)
|
265 |
+
raise KeyError(
|
266 |
+
f"'{token}' not found in vocab namespace '{namespace}', and namespace "
|
267 |
+
f"does not contain the default OOV token ('{self._oov_token}')"
|
268 |
+
)
|
269 |
+
|
270 |
+
def get_token_from_index(self, index: int, namespace: str = "tokens") -> str:
|
271 |
+
return self._index_to_token[namespace][index]
|
272 |
+
|
273 |
+
def get_vocab_size(self, namespace: str = "tokens") -> int:
|
274 |
+
return len(self._token_to_index[namespace])
|
275 |
+
|
276 |
+
def get_namespaces(self) -> Set[str]:
|
277 |
+
return set(self._index_to_token.keys())
|
vocabulary/.lock
ADDED
File without changes
|
vocabulary/d_tags.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CORRECT
|
2 |
+
INCORRECT
|
3 |
+
@@UNKNOWN@@
|
4 |
+
@@PADDING@@
|
vocabulary/labels.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$KEEP
|
2 |
+
$TRANSFORM_CASE_CAPITAL
|
3 |
+
$APPEND_,
|
4 |
+
$APPEND_.
|
5 |
+
$TRANSFORM_VERB_VB_VBN
|
6 |
+
$TRANSFORM_CASE_UPPER
|
7 |
+
$APPEND_:
|
8 |
+
$APPEND_?
|
9 |
+
$TRANSFORM_VERB_VB_VBC
|
10 |
+
$TRANSFORM_CASE_LOWER
|
11 |
+
$TRANSFORM_CASE_CAPITAL_1
|
12 |
+
$TRANSFORM_CASE_UPPER_-1
|
13 |
+
$MERGE_SPACE
|
14 |
+
@@UNKNOWN@@
|
15 |
+
@@PADDING@@
|
vocabulary/non_padded_namespaces.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*tags
|
2 |
+
*labels
|