m3hrdadfi committed on
Commit 21d29cb
1 Parent(s): e9ad52d

Hello gpt2-persian

README.md ADDED
@@ -0,0 +1,33 @@
+ # GPT2 - Persian
+ 
+ ## Scripts
+ 
+ ### Normalizer
+ 
+ ```python
+ from src.normalizer import normalize
+ 
+ input_text = "ὑ蕉Ұ제ṅ尘̲改座◦花芝秀黄天자埃澤ಿ ˈazbab اینجا ایران خانه‌شما است؟!۱۲۳۱۲۳۱۳۱۲ اَلْحُرُوفُ ٱلْعَرَبِیَّة"
+ print(normalize(input_text))
+ ```
+ 
+ Output:
+ ```text
+ azbab اینجا ایران خانه‌شما است ؟ ! 1231231312 الحروف لعربیه
+ ```
+ 
+ ### Training the tokenizer
+ 
+ ```bash
+ python train_tokenizer.py --dataset_name oscar --dataset_config_name unshuffled_deduplicated_als --vocab_size 42000
+ ```
+ 
+ ### Configuration
+ 
+ ```bash
+ python create_config.py --name_or_path gpt2-medium --params '{"vocab_size": 42000}'
+ ```
+ 
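The two commands above produce a `config.json` (from `create_config.py`) and `vocab.json`/`merges.txt`/`tokenizer.json` (from `train_tokenizer.py`). A minimal sketch of loading those artifacts back, assuming both scripts wrote their output to the current directory (the loading code is illustrative and not part of this commit):

```python
from transformers import AutoConfig, GPT2TokenizerFast

# Assumes create_config.py and train_tokenizer.py wrote their files to "."
config = AutoConfig.from_pretrained(".")
tokenizer = GPT2TokenizerFast.from_pretrained(".")

print(config.vocab_size)  # 42000, as requested via --params and --vocab_size
print(tokenizer.tokenize("اینجا ایران است"))
```
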
notes/.keep ADDED
File without changes
src/__pycache__/dictionary.cpython-39.pyc ADDED
Binary file (2.01 kB). View file
 
src/create_config.py ADDED
@@ -0,0 +1,74 @@
+ import ast
+ import logging
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from typing import Optional
+ 
+ from transformers import (
+     HfArgumentParser,
+     AutoConfig,
+ )
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class ConfigArguments:
+     """
+     Arguments for the configuration we are going to set up.
+     """
+     output_dir: str = field(
+         default=".",
+         metadata={"help": "The output directory where the config will be written."},
+     )
+     name_or_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "The model checkpoint for weights initialization. "
+             "Don't set if you want to train a model from scratch."
+         },
+     )
+     params: Optional[str] = field(
+         default=None,
+         metadata={"help": "Custom configuration for the specific `name_or_path`"},
+     )
+ 
+     def __post_init__(self):
+         if self.params:
+             try:
+                 self.params = ast.literal_eval(self.params)
+             except Exception as e:
+                 print(f"Your custom parameters are not acceptable due to {e}")
+ 
+ 
+ def main():
+     parser = HfArgumentParser([ConfigArguments])
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         config_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
+     else:
+         config_args = parser.parse_args_into_dataclasses()[0]
+ 
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     logger.setLevel(logging.INFO)
+ 
+     logger.info(f"Setting up configuration {config_args.name_or_path} with extra params {config_args.params}")
+ 
+     if config_args.params and isinstance(config_args.params, dict):
+         config = AutoConfig.from_pretrained(config_args.name_or_path, **config_args.params)
+     else:
+         config = AutoConfig.from_pretrained(config_args.name_or_path)
+ 
+     logger.info(f"Your configuration is saved at {config_args.output_dir}/config.json")
+     config.save_pretrained(config_args.output_dir)
+ 
+ 
+ if __name__ == '__main__':
+     main()
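For reference, the `--params` string is parsed in `__post_init__` with `ast.literal_eval`; a tiny illustration of what that does with the value used in the README:

```python
import ast

# The CLI passes --params as a string; literal_eval turns it into a dict
# that is then forwarded to AutoConfig.from_pretrained(**params).
params = ast.literal_eval('{"vocab_size": 42000}')
print(params, type(params))  # {'vocab_size': 42000} <class 'dict'>
```
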
src/dictionary.py ADDED
@@ -0,0 +1,133 @@
+ characters = {
+     "ك": "ک",
+     "دِ": "د",
+     "بِ": "ب",
+     "زِ": "ز",
+     "ذِ": "ذ",
+     "شِ": "ش",
+     "سِ": "س",
+     "ى": "ی",
+     "ي": "ی",
+     "ؤ": "و",
+     "ے": "ی",
+     "ۀ": "ه",
+     "ﭘ": "پ",
+     "ﮐ": "ک",
+     "ﯽ": "ی",
+     "ﺎ": "ا",
+     "ﺑ": "ب",
+     "ﺘ": "ت",
+     "ﺧ": "خ",
+     "ﺩ": "د",
+     "ﺱ": "س",
+     "ﻀ": "ض",
+     "ﻌ": "ع",
+     "ﻟ": "ل",
+     "ﻡ": "م",
+     "ﻢ": "م",
+     "ﻪ": "ه",
+     "ﻮ": "و",
+     "ﺍ": "ا",
+     "ة": "ه",
+     "ﯾ": "ی",
+     "ﯿ": "ی",
+     "ﺒ": "ب",
+     "ﺖ": "ت",
+     "ﺪ": "د",
+     "ﺮ": "ر",
+     "ﺴ": "س",
+     "ﺷ": "ش",
+     "ﺸ": "ش",
+     "ﻋ": "ع",
+     "ﻤ": "م",
+     "ﻥ": "ن",
+     "ﻧ": "ن",
+     "ﻭ": "و",
+     "ﺭ": "ر",
+     "ﮔ": "گ",
+     "إ": "ا",
+     "ٕ": " ",
+     "ھ": "ه",
+     "...": ".",
+     "…": ".",
+     "-": " - ",
+     "هٔ": "ه",
+     "ﻯ": "ی",
+     "ﻛ": "ک",
+     "ﭼ": "چ",
+     "ﺓ": "ه",
+     "ﻴ": "ی",
+     "ﻊ": "ع",
+     "ﮬ": "ه",
+     "ﺟ": "ج",
+     "ﺳ": "س",
+     "ﻦ": "ن",
+     "ﺬ": "ذ",
+     "ﺋ": "ئ",
+     "ﷲ": "لله",
+     "ﺞ": "ج",
+     "ﺙ": "ث",
+     "ﻗ": "ق",
+     "ﮪ": "ه",
+     "ﺰ": "ز",
+     "ﯼ": "ی",
+     "ٺ": "ت",
+     "ﺻ": "ص",
+     "ﻂ": "ط",
+     "ﻣ": "م",
+     "ﻈ": "ظ",
+     "ﺐ": "ب",
+     "ﻍ": "غ",
+     "ݸ": "و",
+     "ﻨ": "ن",
+     "ﻝ": "ل",
+     "ﻩ": "ه",
+     "ﻲ": "ی",
+     "ﻐ": "غ",
+     "ﺲ": "س",
+     "ﺁ": "آ",
+     "ڔ": "ر",
+     "ﺫ": "ذ",
+     "ﭻ": "چ",
+     "ﺠ": "ج",
+     "ﯙ": "و",
+     "ﮏ": "ک",
+     "ﺣ": "ح",
+     "ﺝ": "ج",
+     "ﺼ": "ص",
+     "ﻳ": "ی",
+     "ﻘ": "ق",
+     "ﺨ": "خ",
+     "ﻔ": "ف",
+     "ﻎ": "غ",
+     "ئ": "ی",
+     "ﻓ": "ف",
+     "ﻕ": "ق",
+     "ﮋ": "ژ",
+     "ﺗ": "ت",
+     "ﻁ": "ط",
+     "ﺯ": "ز",
+     "ﮕ": "گ",
+     "ﺌ": "ئ",
+     "ﺵ": "ش",
+     "ۮ": "د",
+     "ﻫ": "ه",
+     "ﻬ": "ه",
+     "ﻏ": "غ",
+     "ﻰ": "ی",
+     "﷼": "ریال",
+     "ﺿ": "ض",
+     "ﺛ": "ث",
+     "ݐ": "پ",
+     "ﺏ": "ب",
+     "ﭙ": "پ",
+     "ﭽ": "چ",
+     "ﺜ": "ث",
+     "ﻃ": "ط",
+     "ۂ": "ه",
+     "ﻑ": "ف",
+     "ﺕ": "ت",
+     "ﻞ": "ل",
+ }
+ 
+ special_tokens = {}
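The `characters` mapping is consumed by `multiple_replace` in `src/normalizer.py`; a small self-contained illustration of that mechanism (the three-entry mapping below is just an excerpt of the table above):

```python
import re

characters = {"ك": "ک", "ي": "ی", "…": "."}  # excerpt of dictionary.characters
pattern = "|".join(map(re.escape, characters.keys()))

# Arabic kaf/yeh become their Persian forms, the ellipsis becomes a period.
print(re.sub(pattern, lambda m: characters[m.group()], "يك متن…"))  # -> یک متن.
```
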
src/normalizer.py ADDED
@@ -0,0 +1,93 @@
+ import re
+ 
+ import hazm
+ 
+ from regexes.currency import CURRENCY_REGEX
+ from regexes.email import EMAIL_REGEX
+ from regexes.latin import LATIN_REGEX, LATIN_WITH_SPECIAL_REGEX
+ from regexes.number import NUMBERS_REGEX
+ from regexes.phone import PHONE_REGEX
+ from regexes.quote import DOUBLE_QUOTE_REGEX, SINGLE_QUOTE_REGEX
+ from regexes.url import URL_REGEX
+ from regexes.persian import PERSIAN_REGEX
+ from regexes.punk import PUNK_REGEX
+ import dictionary
+ 
+ 
+ def make_trans(list_a, list_b):
+     return dict((ord(a), b) for a, b in zip(list_a, list_b))
+ 
+ 
+ def multiple_replace(text, chars_to_mapping):
+     pattern = "|".join(map(re.escape, chars_to_mapping.keys()))
+     return re.sub(pattern, lambda m: chars_to_mapping[m.group()], str(text))
+ 
+ 
+ ar2fa_digits = make_trans("٠١٢٣٤٥٦٧٨٩٪", "۰۱۲۳۴۵۶۷۸۹٪")
+ fa2en_digits = make_trans("۰۱۲۳۴۵۶۷۸۹٪", "0123456789%")
+ normalizer = hazm.Normalizer(persian_numbers=True)
+ 
+ 
+ def normalize(text, zwnj="\u200c", tokenized=False):
+     text = text.replace("\n", " ").replace("\t", " ")
+     text = re.sub(r"\u200c+", "\u200c", text)
+ 
+     text = normalizer.normalize(text)
+ 
+     if len(dictionary.characters) > 0:
+         text = multiple_replace(text, dictionary.characters)
+ 
+     text = text.translate(ar2fa_digits)
+     text = text.translate(fa2en_digits)
+ 
+     text = SINGLE_QUOTE_REGEX.sub("'", text)
+     text = DOUBLE_QUOTE_REGEX.sub('"', text)
+     text = CURRENCY_REGEX.sub(r" \1 ", text)
+     text = URL_REGEX.sub(r" \1 ", text)
+     text = EMAIL_REGEX.sub(r" \1 ", text)
+     text = PHONE_REGEX.sub(r" \1 ", text)
+     text = NUMBERS_REGEX.sub(r" \1 ", text)
+     text = LATIN_REGEX.sub(r" \1 ", text)
+     text = PUNK_REGEX.sub(r" \1 ", text)
+ 
+     # Allow only English and Persian characters
+     text = re.sub(PERSIAN_REGEX, " ", text)
+ 
+     text = text.replace(f" {zwnj} ", f"{zwnj}")
+     text = text.replace(f"{zwnj} ", f"{zwnj}")
+     text = text.replace(f" {zwnj}", f"{zwnj}")
+ 
+     if len(dictionary.special_tokens) > 0:
+         text = multiple_replace(text, dictionary.special_tokens)
+ 
+     tokens = []
+     for token in text.split():
+         token = token.strip()
+         if token:
+             if token.startswith(zwnj) and token.endswith(zwnj):
+                 token = token[1:-1]
+             if token.startswith(zwnj):
+                 token = token[1:]
+             elif token.endswith(zwnj):
+                 token = token[:-1]
+ 
+             tokens.append(token)
+ 
+     if tokenized:
+         return tokens
+ 
+     return " ".join(tokens)
+ 
+ 
+ if __name__ == '__main__':
+     import textwrap
+ 
+     input_text = "دارهٔ تحقیقات فدرال در سال ۱۹۰۸ به نام ادارهٔ تحقیقات (BOI یا BI) بنیان‌گذاری شد. نام این سازمان در سال ۱۹۳۵ به ادارهٔ تحقیقات فدرال تغییر یافت. دفتر مرکزی اف‌بی‌آی در ساختمان جی. ادگار هوور در شهر واشینگتن، دی.سی. واقع شده‌است."
+     input_text = "یونان (به یونانی: Ελλάδα, اِلادا)"
+     input_text = "نسخهٔ"
+     input_text = "ὑ蕉Ұ제ṅ尘̲改座◦花芝秀黄天자埃澤ಿ ˈazbab اینجا ایران خانه‌شما است؟!۱۲۳۱۲۳۱۳۱۲ اَلْحُرُوفُ ٱلْعَرَبِیَّة"
+     input_text = normalize(input_text)
+     print(textwrap.fill(input_text))
+     print(normalize(input_text, tokenized=True))
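In a training pipeline the normalizer would presumably be applied to the whole corpus before the tokenizer and model are trained; a minimal sketch with 🤗 datasets, assuming the OSCAR config from the README and its `text` column (the `.map` step is an assumption about how preprocessing is wired up, not code from this commit):

```python
from datasets import load_dataset
from src.normalizer import normalize

dataset = load_dataset("oscar", "unshuffled_deduplicated_als", split="train")
# Normalize every document; OSCAR exposes the raw text in the "text" column.
dataset = dataset.map(lambda example: {"text": normalize(example["text"])})
print(dataset[0]["text"][:100])
```
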
src/regexes/__init__.py ADDED
File without changes
src/regexes/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (174 Bytes). View file
 
src/regexes/__pycache__/currency.cpython-39.pyc ADDED
Binary file (693 Bytes). View file
 
src/regexes/__pycache__/email.cpython-39.pyc ADDED
Binary file (488 Bytes). View file
 
src/regexes/__pycache__/latin.cpython-39.pyc ADDED
Binary file (388 Bytes). View file
 
src/regexes/__pycache__/number.cpython-39.pyc ADDED
Binary file (354 Bytes). View file
 
src/regexes/__pycache__/persian.cpython-39.pyc ADDED
Binary file (555 Bytes). View file
 
src/regexes/__pycache__/phone.cpython-39.pyc ADDED
Binary file (384 Bytes). View file
 
src/regexes/__pycache__/punk.cpython-39.pyc ADDED
Binary file (315 Bytes). View file
 
src/regexes/__pycache__/quote.cpython-39.pyc ADDED
Binary file (535 Bytes). View file
 
src/regexes/__pycache__/url.cpython-39.pyc ADDED
Binary file (783 Bytes). View file
 
src/regexes/currency.py ADDED
@@ -0,0 +1,23 @@
+ import re
+ 
+ CURRENCIES = {
+     "$": "USD",
+     "zł": "PLN",
+     "£": "GBP",
+     "¥": "JPY",
+     "฿": "THB",
+     "₡": "CRC",
+     "₦": "NGN",
+     "₩": "KRW",
+     "₪": "ILS",
+     "₫": "VND",
+     "€": "EUR",
+     "₱": "PHP",
+     "₲": "PYG",
+     "₴": "UAH",
+     "₹": "INR",
+     "﷼": "IRR",
+ }
+ CURRENCY_REGEX = re.compile(
+     "({})+".format("|".join(re.escape(c) for c in CURRENCIES.keys()))
+ )
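As used in `normalize()`, this pattern pads currency symbols with spaces so they end up as standalone tokens; a quick check, assuming the snippet is run from inside `src/` like `normalizer.py` itself:

```python
from regexes.currency import CURRENCY_REGEX

# The rial sign gets surrounded by spaces and survives as its own token.
print(CURRENCY_REGEX.sub(r" \1 ", "قیمت آن ۵۰۰﷼ است"))
```
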
src/regexes/email.py ADDED
@@ -0,0 +1,6 @@
+ import re
+ 
+ EMAIL_REGEX = re.compile(
+     r"(?:^|(?<=[^\w@.)]))([\w+-](\.(?!\.))?)*?[\w+-](@|[(<{\[]at[)>}\]])(?:(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)(?:\.(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)*(?:\.(?:[a-z\\u00a1-\\uffff]{2,}))",
+     flags=re.IGNORECASE | re.UNICODE,
+ )
src/regexes/latin.py ADDED
@@ -0,0 +1,13 @@
+ import re
+ 
+ LATIN_WITH_SPECIAL_REGEX = re.compile(
+     r"(\b(?!URL|EMAIL|PHONE|NUMBER|CUR|LATIN\b)[0-9a-zA-Z]+)"
+ )
+ 
+ LATIN_REGEX = re.compile(
+     r"([0-9a-zA-Z]+)"
+ )
+ 
+ LATIN_SPACES_REGEX = re.compile(
+     r"([0-9a-zA-Z])"
+ )
src/regexes/number.py ADDED
@@ -0,0 +1,5 @@
+ import re
+ 
+ NUMBERS_REGEX = re.compile(
+     r"(?:^|(?<=[^\w,.]))[+–-]?(([1-9]\d{0,2}(,\d{3})+(\.\d*)?)|([1-9]\d{0,2}([ .]\d{3})+(,\d*)?)|(\d*?[.,]\d+)|\d+)(?:$|(?=\b))"
+ )
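`normalize()` uses this pattern the same way, padding numbers so they become standalone tokens; a quick check (run from inside `src/`):

```python
from regexes.number import NUMBERS_REGEX

# The grouped number "1,200.50" is matched as one unit and padded with spaces.
print(NUMBERS_REGEX.sub(r" \1 ", "قیمت 1,200.50 تومان"))
```
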
src/regexes/persian.py ADDED
@@ -0,0 +1,19 @@
+ import re
+ 
+ 
+ PERSIAN_ALPHA = "ءآئابتثجحخدذرزسشصضطظعغفقلمنهوپچژکگیە"  # noqa: E501
+ PERSIAN_DIGIT = "۰۱۲۳۴۵۶۷۸۹"
+ 
+ 
+ ZWNJ = "\u200c"
+ PUNK = '\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\=\>\?\@\[\]\^\_\`\{\|\}\~\«\»\؟\:\×\٬\٫\﷼\٪\،'
+ 
+ PERSIAN = (
+     "a-zA-Z0-9" +
+     PERSIAN_ALPHA +
+     PERSIAN_DIGIT +
+     ZWNJ +
+     PUNK
+ )
+ 
+ PERSIAN_REGEX = r"[^" + PERSIAN + "+]"
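`PERSIAN_REGEX` is the whitelist filter applied near the end of `normalize()`: anything outside Latin letters, digits, the Persian alphabet, ZWNJ, and the listed punctuation becomes a space. A quick check (run from inside `src/`), mirroring how the Greek and CJK characters vanish in the README example:

```python
import re

from regexes.persian import PERSIAN_REGEX

# The Greek and CJK characters are outside the whitelist and turn into spaces.
print(re.sub(PERSIAN_REGEX, " ", "ὑ蕉改 azbab اینجا ایران"))
```
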
src/regexes/phone.py ADDED
@@ -0,0 +1,6 @@
+ import re
+ 
+ 
+ PHONE_REGEX = re.compile(
+     r"((?:^|(?<=[^\w)]))(((\+?[01])|(\+\d{2}))[ .-]?)?(\(?\d{3,4}\)?/?[ .-]?)?(\d{3}[ .-]?\d{4})(\s?(?:ext\.?|[#x-])\s?\d{2,6})?(?:$|(?=\W)))|\+?\d{4,5}[ .-/]\d{6,9}"
+ )
src/regexes/punk.py ADDED
@@ -0,0 +1,5 @@
+ import re
+ 
+ PUNK_REGEX = re.compile(
+     r"([\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\=\?\@\[\\\]\^\_\`\{\|\}\~\«\»\⸮\؟\،\٬\٫\؛])"
+ )
src/regexes/quote.py ADDED
@@ -0,0 +1,25 @@
+ import re
+ 
+ 
+ strange_double_quotes = [
+     "«",
+     "‹",
+     "»",
+     "›",
+     "„",
+     "“",
+     "‟",
+     "”",
+     "❝",
+     "❞",
+     "❮",
+     "❯",
+     "〝",
+     "〞",
+     "〟",
+     "＂",
+ ]
+ strange_single_quotes = ["‘", "‛", "’", "❛", "❜", "`", "´", "‘", "’"]
+ 
+ DOUBLE_QUOTE_REGEX = re.compile("|".join(strange_double_quotes))
+ SINGLE_QUOTE_REGEX = re.compile("|".join(strange_single_quotes))
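`normalize()` maps all of these exotic quote characters onto plain ASCII quotes; a quick check (run from inside `src/`):

```python
from regexes.quote import DOUBLE_QUOTE_REGEX, SINGLE_QUOTE_REGEX

text = "گفت: «سلام» و ‘خداحافظ’"
text = DOUBLE_QUOTE_REGEX.sub('"', text)
text = SINGLE_QUOTE_REGEX.sub("'", text)
print(text)  # -> گفت: "سلام" و 'خداحافظ'
```
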
src/regexes/url.py ADDED
@@ -0,0 +1,38 @@
+ import re
+ 
+ URL_REGEX = re.compile(
+     r"(?:^|(?<![\w\/\.]))"
+     # protocol identifier
+     # r"(?:(?:https?|ftp)://)"  <-- alt?
+     r"(?:(?:https?:\/\/|ftp:\/\/|www\d{0,3}\.))"
+     # user:pass authentication
+     r"(?:\S+(?::\S*)?@)?" r"(?:"
+     # IP address exclusion
+     # private & local networks
+     r"(?!(?:10|127)(?:\.\d{1,3}){3})"
+     r"(?!(?:169\.254|192\.168)(?:\.\d{1,3}){2})"
+     r"(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})"
+     # IP address dotted notation octets
+     # excludes loopback network 0.0.0.0
+     # excludes reserved space >= 224.0.0.0
+     # excludes network & broadcast addresses
+     # (first & last IP address of each class)
+     r"(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])"
+     r"(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}"
+     r"(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))"
+     r"|"
+     # host name
+     r"(?:(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)"
+     # domain name
+     r"(?:\.(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)*"
+     # TLD identifier
+     r"(?:\.(?:[a-z\\u00a1-\\uffff]{2,}))" r"|" r"(?:(localhost))" r")"
+     # port number
+     r"(?::\d{2,5})?"
+     # resource path
+     r"(?:\/[^\)\]\}\s]*)?",
+     # r"(?:$|(?![\w?!+&\/\)]))",
+     # @jfilter: I removed the line above from the regex because I don't understand what it is used for, maybe it was useful?
+     # But I made sure that it does not include ), ] and } in the URL.
+     flags=re.UNICODE | re.IGNORECASE,
+ )
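A quick sanity check that the pattern recognises a typical URL inside Persian text (run from inside `src/`; the example URL is illustrative):

```python
from regexes.url import URL_REGEX

match = URL_REGEX.search("مستندات در https://huggingface.co/docs موجود است")
print(match.group(0) if match else "no match")  # -> https://huggingface.co/docs
```
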
src/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ datasets >= 1.1.3
+ jax>=0.2.8
+ jaxlib>=0.1.59
+ flax>=0.3.4
+ optax>=0.0.8
+ hazm
src/train_tokenizer.py ADDED
@@ -0,0 +1,132 @@
+ import logging
+ import os
+ import sys
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+ 
+ from datasets import load_dataset
+ from tokenizers import ByteLevelBPETokenizer
+ from transformers import (
+     HfArgumentParser,
+ )
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ @dataclass
+ class TokenizerArguments:
+     """
+     Arguments for the tokenizer we are going to set up.
+     """
+ 
+     output_dir: str = field(
+         default=".",
+         metadata={"help": "The output directory where the tokenizer will be written."},
+     )
+     dataset_name: Optional[str] = field(
+         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
+     )
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+     )
+     special_tokens: Optional[List[str]] = field(
+         default=None,
+         metadata={"help": "The list of special tokens that you want to add in your training."},
+     )
+     vocab_size: Optional[int] = field(
+         default=50257,
+         metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"},
+     )
+     min_frequency: Optional[int] = field(
+         default=2,
+         metadata={"help": "The minimum frequency a pair should have in order to be merged"},
+     )
+     show_progress: Optional[bool] = field(
+         default=True,
+         metadata={"help": "Whether to show progress bars while training"},
+     )
+ 
+     def __post_init__(self):
+         if self.special_tokens is None:
+             self.special_tokens = [
+                 "<s>", "<pad>", "</s>", "<unk>", "<mask>",
+                 "<|endoftext|>", "<|startoftext|>",
+                 "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>",
+             ]
+ 
+         self.special_tokens = self.special_tokens + [f"[U{i}]" for i in range(1, 21)]
+         if self.dataset_name is None and self.train_file is None:
+             raise ValueError("Need either a dataset name or a training file.")
+         else:
+             if self.train_file is not None:
+                 extension = self.train_file.split(".")[-1]
+                 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+ 
+ 
+ def main():
+     parser = HfArgumentParser([TokenizerArguments])
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         tokenizer_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
+     else:
+         tokenizer_args = parser.parse_args_into_dataclasses()[0]
+ 
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     logger.setLevel(logging.INFO)
+ 
+     logger.info("Training tokenizer")
+ 
+     if tokenizer_args.dataset_name is not None:
+         dataset = load_dataset(
+             tokenizer_args.dataset_name,
+             tokenizer_args.dataset_config_name,
+             cache_dir=tokenizer_args.cache_dir,
+             split="train",
+         )
+     else:
+         data_files = {"train": tokenizer_args.train_file}
+         extension = tokenizer_args.train_file.split(".")[-1]
+         if extension == "txt":
+             extension = "text"
+ 
+         dataset = load_dataset(
+             extension,
+             data_files=data_files,
+             delimiter="\t",
+             cache_dir=tokenizer_args.cache_dir,
+             # Load the single train split so the dataset can be indexed below.
+             split="train",
+         )
+ 
+     tokenizer = ByteLevelBPETokenizer()
+ 
+     def batch_iterative(batch_size=1000):
+         for i in range(0, len(dataset), batch_size):
+             yield dataset[i: i + batch_size]["text"]
+ 
+     tokenizer.train_from_iterator(
+         batch_iterative(),
+         vocab_size=tokenizer_args.vocab_size,
+         special_tokens=tokenizer_args.special_tokens,
+         min_frequency=tokenizer_args.min_frequency,
+         show_progress=tokenizer_args.show_progress,
+     )
+ 
+     logger.info(f"Your tokenizer is saved at {tokenizer_args.output_dir}")
+     os.makedirs(tokenizer_args.output_dir, exist_ok=True)
+     tokenizer.save_model(tokenizer_args.output_dir)
+     tokenizer.save(f"{tokenizer_args.output_dir}/tokenizer.json", pretty=True)
+ 
+ 
+ if __name__ == '__main__':
+     main()
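`save_model()` writes `vocab.json`/`merges.txt` and `save()` writes a single `tokenizer.json`; the latter can be reloaded directly with the tokenizers library. A minimal sketch, assuming the script was run with the default output directory so the file sits in the current directory:

```python
from tokenizers import Tokenizer

# Reload the serialized tokenizer produced by train_tokenizer.py.
tokenizer = Tokenizer.from_file("tokenizer.json")
print(tokenizer.encode("اینجا ایران است").tokens)
```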