import json
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer
from transformers import PreTrainedTokenizerFast


def build_tokenizer(files):
    """Train a whitespace-split, word-level tokenizer on the given text files."""
    assert isinstance(files, list) and len(files) > 0

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    trainer = WordLevelTrainer(special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"])
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.train(files, trainer)

    return tokenizer


def tokenizer_from_file(tokenizer_file):
    """Load a saved tokenizer, attach a BERT-style post-processor, and wrap it
    in a transformers PreTrainedTokenizerFast."""
    tokenizer = Tokenizer.from_file(tokenizer_file)

    # Add [CLS]/[SEP] around single sequences and sequence pairs.
    tokenizer.post_processor = TemplateProcessing(
        single="[CLS] $A [SEP]",
        pair="[CLS] $A [SEP] $B:1 [SEP]:1",
        special_tokens=[
            ("[PAD]", tokenizer.token_to_id("[PAD]")),
            ("[UNK]", tokenizer.token_to_id("[UNK]")),
            ("[CLS]", tokenizer.token_to_id("[CLS]")),
            ("[SEP]", tokenizer.token_to_id("[SEP]")),
            ("[MASK]", tokenizer.token_to_id("[MASK]")),
        ],
    )

    # Expose the tokenizer through the standard transformers interface.
    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer, model_max_length=512,
        pad_token="[PAD]", unk_token="[UNK]", cls_token="[CLS]",
        sep_token="[SEP]", mask_token="[MASK]")

    return tokenizer


# Train the word-level tokenizer once and cache it as tmp.json.
if not os.path.exists("tmp.json"):
    tokenizer = build_tokenizer(files=["gene_rank_merge_2021Aug25.txt", "../t5/t5finetune_data_flat.csv"])
    tokenizer.save("tmp.json")

with open("tmp.json") as f:
    d = json.load(f)

# The trained vocabulary should use contiguous ids starting at 0.
vmax = max(d["model"]["vocab"].values())
assert vmax + 1 == len(d["model"]["vocab"])

# Reserve 100 extra "unusedN" tokens after the existing vocabulary.
for i in range(100):
    d["model"]["vocab"][f"unused{i}"] = vmax + 1 + i

with open("bert.json", "w") as f:
    json.dump(d, f)

# Rebuild the tokenizer from the padded vocab and save it in transformers format.
tk = tokenizer_from_file("bert.json")
tk.save_pretrained("berttokenizer")
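
# --- Illustrative sanity check (a sketch, not part of the original pipeline) ---
# Assumes the "berttokenizer" directory written above; the gene symbols below are
# made up for illustration and may tokenize to [UNK] with real training files.
check = PreTrainedTokenizerFast.from_pretrained("berttokenizer")
enc = check("BRCA1 TP53 EGFR")
# With the template post-processor, the decoded tokens should start with [CLS]
# and end with [SEP].
print(check.convert_ids_to_tokens(enc["input_ids"]))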