File size: 5,459 Bytes
21d29cb
 
 
 
 
31bf2aa
21d29cb
 
 
 
 
 
 
7cfca48
 
 
 
31bf2aa
7cfca48
 
 
21d29cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31bf2aa
21d29cb
 
 
 
31bf2aa
21d29cb
 
 
 
 
 
 
 
 
 
 
 
 
31bf2aa
21d29cb
 
 
 
31bf2aa
 
 
21d29cb
31bf2aa
21d29cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cfca48
21d29cb
 
 
31bf2aa
21d29cb
 
 
 
 
 
 
7cfca48
21d29cb
 
 
 
 
 
7cfca48
 
 
 
31bf2aa
7cfca48
 
 
21d29cb
 
 
 
 
 
 
 
 
 
 
 
 
 
7cfca48
21d29cb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union, Any

from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
    HfArgumentParser,
)

from data_utils import (
    filter_by_lang_regex,
    filter_by_num_tokens,
    filter_by_num_sents,
    filter_by_adv,
    normalizer
)

logger = logging.getLogger(__name__)


@dataclass
class TokenizerArguments:
    """
    Arguments controlling which data the byte-level BPE tokenizer is trained on,
    how it is trained, and where the result is written.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    # Comma-separated on the command line; normalized to a list in __post_init__.
    special_tokens: Optional[str] = field(
        default=None,
        metadata={"help": "The list of special tokens that you want to add in your training."}
    )
    vocab_size: Optional[int] = field(
        default=56000,
        metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
    )
    min_frequency: Optional[int] = field(
        default=2,
        metadata={"help": "The minimum frequency a pair should have in order to be merged"}
    )
    show_progress: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to show progress bars while training"}
    )

    def __post_init__(self):
        """Normalize ``special_tokens`` into a list and validate the data source.

        ``special_tokens`` may be ``None`` (a default set is used), a
        comma-separated string, or an already-built list/tuple of tokens
        (e.g. when the dataclass is constructed directly or from a JSON
        config). Surrounding whitespace is stripped from every token and
        empty entries, such as those left by a trailing comma, are dropped.

        Raises:
            ValueError: if neither ``dataset_name`` nor ``train_file`` is set,
                or if ``train_file`` does not end in ``csv``, ``json`` or ``txt``.
        """
        if self.special_tokens is None:
            special_tokens = [
                "<s>", "<pad>", "</s>", "<unk>", "<mask>",
                "<|endoftext|>", "<|startoftext|>",
                "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
            ]
            special_tokens += [f"[U{i}]" for i in range(1, 21)]
        elif isinstance(self.special_tokens, str):
            special_tokens = [token.strip() for token in self.special_tokens.split(",")]
        else:
            # Already a sequence (also makes a repeated __post_init__ call safe).
            special_tokens = [str(token).strip() for token in self.special_tokens]

        self.special_tokens = [token for token in special_tokens if token]

        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if extension not in ("csv", "json", "txt"):
                raise ValueError("`train_file` should be a csv, a json or a txt file.")


def _load_raw_dataset(tokenizer_args):
    """Load the ``train`` split described by ``tokenizer_args`` as a single Dataset.

    Uses ``dataset_name``/``dataset_config_name`` when a dataset name is given,
    otherwise the local ``train_file``. ``split="train"`` is requested in both
    cases so the result supports ``len()`` and slicing like a ``Dataset``
    (without it, loading local files yields a ``DatasetDict``).
    """
    if tokenizer_args.dataset_name is not None:
        return load_dataset(
            tokenizer_args.dataset_name,
            tokenizer_args.dataset_config_name,
            cache_dir=tokenizer_args.cache_dir,
            split="train",
        )

    extension = tokenizer_args.train_file.split(".")[-1]
    loader_kwargs = {}
    if extension == "txt":
        # The datasets library names its plain-text loader "text".
        extension = "text"
    elif extension == "csv":
        # csv files are read as tab-separated. `delimiter` is a csv-loader
        # option and is not a valid config key for the text/json loaders.
        loader_kwargs["delimiter"] = "\t"

    return load_dataset(
        extension,
        data_files={"train": tokenizer_args.train_file},
        cache_dir=tokenizer_args.cache_dir,
        split="train",
        **loader_kwargs,
    )


def _keep_example(example):
    """Return True if ``example["text"]`` passes every quality filter.

    The checks short-circuit in a fixed order, so each filter is only applied
    to texts that passed the filters before it.
    """
    text = example["text"]
    return bool(
        filter_by_lang_regex(text, ratio=0.75)
        and filter_by_num_tokens(text, gt=64)
        and filter_by_num_sents(text, gt=2)
        and filter_by_adv(text, ratio=50)
    )


def _iter_text_batches(dataset, batch_size=1000):
    """Yield the ``text`` column of ``dataset`` in lists of up to ``batch_size`` rows."""
    for start in range(0, len(dataset), batch_size):
        yield dataset[start:start + batch_size]["text"]


def main():
    """Parse arguments, preprocess the corpus, then train and save a byte-level BPE tokenizer."""
    parser = HfArgumentParser([TokenizerArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        tokenizer_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        tokenizer_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    logger.info("Training tokenizer")

    raw_dataset = _load_raw_dataset(tokenizer_args)

    logger.info("Preprocessing the dataset")
    # A single filter pass over the data instead of one pass per criterion.
    dataset = raw_dataset.filter(_keep_example)
    dataset = dataset.map(normalizer)
    logger.info("Preprocessed dataset kept %d out of %d", len(dataset), len(raw_dataset))

    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        _iter_text_batches(dataset),
        vocab_size=tokenizer_args.vocab_size,
        special_tokens=tokenizer_args.special_tokens,
        min_frequency=tokenizer_args.min_frequency,
        show_progress=tokenizer_args.show_progress,
    )

    os.makedirs(tokenizer_args.output_dir, exist_ok=True)
    tokenizer.save_model(tokenizer_args.output_dir)
    tokenizer.save(os.path.join(tokenizer_args.output_dir, "tokenizer.json"), pretty=True)
    # Logged only after both save calls have returned.
    logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}")


if __name__ == '__main__':
    # Train only when executed as a script, not when imported as a module.
    main()