from typing import Union, List
from transformers import AutoTokenizer
import torch

class HFPTTokenizer(object):
    def __init__(self, pt_name=None):
        self.pt_name = pt_name
        self.added_sep_token = 0
        self.added_cls_token = 0
        self.enable_add_tokens = False
        # GPT-style tokenizers ship without pad/sep tokens; when we are not
        # adding special tokens ourselves, reuse the EOS token for both.
        self.gpt_special_case = (
            (not self.enable_add_tokens)
            and (pt_name is not None)
            and ('gpt' in pt_name)
        )

        if pt_name is None:
            self.tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(pt_name)

        # Adding tokens to GPT causes NaN training loss.
        # Disabled for now until further investigation.
        if self.enable_add_tokens:
            if self.tokenizer.sep_token is None:
                self.tokenizer.add_special_tokens({'sep_token': '<SEP>'})
                self.added_sep_token = 1
            if self.tokenizer.cls_token is None:
                self.tokenizer.add_special_tokens({'cls_token': '<CLS>'})
                self.added_cls_token = 1

        if self.gpt_special_case:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.sep_token = self.tokenizer.eos_token
    def get_eot_token(self):
        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)[0]

    def get_sot_token(self):
        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)[0]

    def get_eot_token_list(self):
        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)

    def get_sot_token_list(self):
        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)

    def get_tokenizer_obj(self):
        return self.tokenizer

    # The language model needs to know if new tokens
    # were added to the vocabulary.
    def check_added_tokens(self):
        return self.added_sep_token + self.added_cls_token
    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77):
        if isinstance(texts, str):
            texts = [texts]

        padding = 'max_length'
        seqstart = []
        seqend = []
        max_length = context_length

        # Reserve room for the manually added CLS/SEP tokens, if any.
        if self.added_cls_token > 0:
            seqstart = self.get_sot_token_list()
            max_length = max_length - 1

        if self.added_sep_token > 0:
            seqend = self.get_eot_token_list()
            max_length = max_length - 1

        tokens = self.tokenizer(
            texts, padding=padding,
            truncation=True,
            max_length=max_length
        )['input_ids']

        for i in range(len(tokens)):
            tokens[i] = seqstart + tokens[i] + seqend

        if self.gpt_special_case:
            # Force the last position to hold the end-of-text token.
            for i in range(len(tokens)):
                tokens[i][-1] = self.get_eot_token()

        return torch.tensor(tokens, dtype=torch.long)
    def get_vocab_size(self):
        return self.tokenizer.vocab_size

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77):
        return self.tokenize(texts, context_length)
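

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how this wrapper
# might be instantiated and called. The checkpoint name 'gpt2' and the context
# length of 77 are illustrative assumptions; any Hugging Face checkpoint name
# should work, and loading it requires the weights to be available locally or
# via network access.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tokenizer = HFPTTokenizer(pt_name='gpt2')

    # __call__ delegates to tokenize(); the result is a LongTensor of shape
    # (batch_size, context_length).
    ids = tokenizer(['a photo of a cat', 'a photo of a dog'], context_length=77)
    print(ids.shape)                        # torch.Size([2, 77])
    print(tokenizer.get_vocab_size())
    print(tokenizer.check_added_tokens())   # 0 unless enable_add_tokens is set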