peizesun committed on
Commit
0ec114b
1 Parent(s): ac1256d

Upload t5.py

Browse files
Files changed (1) hide show
  1. t5.py +201 -0
t5.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/t5.py
3
+ import os
4
+ import re
5
+ import html
6
+ import urllib.parse as ul
7
+
8
+ import ftfy
9
+ import torch
10
+ from bs4 import BeautifulSoup
11
+ from transformers import T5EncoderModel, AutoTokenizer
12
+ from huggingface_hub import hf_hub_download
13
+
14
+
15
class T5Embedder:
    """Wrap a T5 encoder and its tokenizer to turn text prompts into embeddings.

    The class downloads (or locates) the tokenizer/model files, loads them with
    ``transformers``, and exposes :meth:`get_text_embeddings`. Captions are
    optionally normalised with :meth:`clean_caption` before tokenisation.
    """

    # Model names for which the full set of weights is fetched from the
    # ``DeepFloyd/<name>`` repository in ``__init__``.
    available_models = ['t5-v1_1-xxl', 't5-v1_1-xl', 'flan-t5-xl']

    # Runs of "bad" punctuation that clean_caption replaces with a space.
    # Written as a single raw string: the previous concatenation of non-raw
    # literals ('\)', '\/', ...) relied on invalid escape sequences, which emit
    # SyntaxWarning on Python 3.12+. The compiled pattern is unchanged.
    bad_punct_regex = re.compile(r'[#®•©™&@·º½¾¿¡§~\)\(\]\[\}\{\|\\/\*]{1,}')

    def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, local_cache=False, cache_dir=None, hf_token=None,
                 use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None,
                 model_max_length=120):
        """Load the tokenizer and the T5 encoder.

        Args:
            device: Anything accepted by ``torch.device``; inputs are moved here.
            dir_or_name: One of ``available_models``, or a path/name handed
                directly to ``T5EncoderModel.from_pretrained``.
            local_cache: If true, load both tokenizer and model from
                ``<cache_dir>/<dir_or_name>`` without downloading anything.
            cache_dir: Root cache directory; defaults to ``~/.cache/IF_``.
            hf_token: Token forwarded to ``hf_hub_download``.
            use_text_preprocessing: If true, captions go through
                ``clean_caption``; otherwise they are only lower-cased/stripped.
            t5_model_kwargs: Keyword arguments for ``from_pretrained``. When
                ``None`` a default set (dtype, low CPU memory, device map) is used.
            torch_dtype: Model dtype; defaults to ``torch.bfloat16``.
            use_offload_folder: Accepted for interface compatibility; it is not
                read anywhere in this constructor.
            model_max_length: Token length used for padding/truncation.
        """
        self.device = torch.device(device)
        self.torch_dtype = torch_dtype or torch.bfloat16
        if t5_model_kwargs is None:
            t5_model_kwargs = {'low_cpu_mem_usage': True, 'torch_dtype': self.torch_dtype}
            # NOTE(review): kept inside the ``None`` branch so caller-supplied
            # kwargs are passed through untouched — confirm against upstream.
            t5_model_kwargs['device_map'] = {'shared': self.device, 'encoder': self.device}

        self.use_text_preprocessing = use_text_preprocessing
        self.hf_token = hf_token
        self.cache_dir = cache_dir or os.path.expanduser('~/.cache/IF_')
        self.dir_or_name = dir_or_name
        tokenizer_path, path = dir_or_name, dir_or_name
        if local_cache:
            # Everything is expected to already be on disk.
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            tokenizer_path, path = cache_dir, cache_dir
        elif dir_or_name in self.available_models:
            # Known model: fetch tokenizer files and sharded weights.
            # NOTE(review): ``force_filename`` is a legacy hf_hub_download
            # argument; verify it is supported by the pinned huggingface_hub.
            cache_dir = os.path.join(self.cache_dir, dir_or_name)
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
                'pytorch_model.bin.index.json', 'pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin'
            ]:
                hf_hub_download(repo_id=f'DeepFloyd/{dir_or_name}', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path, path = cache_dir, cache_dir
        else:
            # Custom model path: only the tokenizer files are fetched (from the
            # t5-v1_1-xxl repo); the weights are loaded from ``dir_or_name``.
            cache_dir = os.path.join(self.cache_dir, 't5-v1_1-xxl')
            for filename in [
                'config.json', 'special_tokens_map.json', 'spiece.model', 'tokenizer_config.json',
            ]:
                hf_hub_download(repo_id='DeepFloyd/t5-v1_1-xxl', filename=filename, cache_dir=cache_dir,
                                force_filename=filename, token=self.hf_token)
            tokenizer_path = cache_dir

        print(tokenizer_path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval()
        self.model_max_length = model_max_length

    def get_text_embeddings(self, texts):
        """Encode an iterable of captions.

        Each caption is preprocessed, tokenised with padding/truncation to
        ``self.model_max_length``, and passed through the encoder under
        ``torch.no_grad()``.

        Returns:
            A tuple ``(embeddings, attention_mask)``: the encoder's
            ``last_hidden_state`` (detached) and the tokenizer's attention
            mask, the latter moved to ``self.device``.
        """
        texts = [self.text_preprocessing(text) for text in texts]

        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )

        # Move each tensor to the target device exactly once.
        input_ids = text_tokens_and_mask['input_ids'].to(self.device)
        attention_mask = text_tokens_and_mask['attention_mask'].to(self.device)

        with torch.no_grad():
            text_encoder_embs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )['last_hidden_state'].detach()
        return text_encoder_embs, attention_mask

    def text_preprocessing(self, text):
        """Normalise one caption before tokenisation.

        With ``use_text_preprocessing`` enabled the caption is passed through
        :meth:`clean_caption` twice; otherwise it is lower-cased and stripped.
        """
        if self.use_text_preprocessing:
            # The exact text cleaning as was in the training stage
            # (the double application is deliberate and must be kept).
            text = self.clean_caption(text)
            text = self.clean_caption(text)
            return text
        return text.lower().strip()

    @staticmethod
    def basic_clean(text):
        """Fix mojibake with ftfy, unescape HTML entities twice, and strip."""
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        return text.strip()

    def clean_caption(self, caption):
        """Apply the caption-cleaning pipeline and return the cleaned string.

        The steps run in a fixed order (URL removal, HTML stripping, CJK
        removal, punctuation normalisation, removal of ids / filenames /
        marketing phrases, whitespace collapsing). The order and the patterns
        are kept exactly as they are because ``text_preprocessing`` states the
        cleaning must match what was used during training.
        """
        caption = str(caption)
        caption = ul.unquote_plus(caption)
        caption = caption.strip().lower()
        caption = re.sub('<person>', 'person', caption)
        # urls:
        caption = re.sub(
            r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        caption = re.sub(
            r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))',  # noqa
            '', caption)  # regex for urls
        # html:
        caption = BeautifulSoup(caption, features='html.parser').text

        # @<nickname>
        caption = re.sub(r'@[\w\d]+\b', '', caption)

        # 31C0—31EF CJK Strokes
        # 31F0—31FF Katakana Phonetic Extensions
        # 3200—32FF Enclosed CJK Letters and Months
        # 3300—33FF CJK Compatibility
        # 3400—4DBF CJK Unified Ideographs Extension A
        # 4DC0—4DFF Yijing Hexagram Symbols
        # 4E00—9FFF CJK Unified Ideographs
        caption = re.sub(r'[\u31c0-\u31ef]+', '', caption)
        caption = re.sub(r'[\u31f0-\u31ff]+', '', caption)
        caption = re.sub(r'[\u3200-\u32ff]+', '', caption)
        caption = re.sub(r'[\u3300-\u33ff]+', '', caption)
        caption = re.sub(r'[\u3400-\u4dbf]+', '', caption)
        caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption)
        caption = re.sub(r'[\u4e00-\u9fff]+', '', caption)
        #######################################################

        # all types of dash --> "-"
        caption = re.sub(
            r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+',  # noqa
            '-', caption)

        # normalise quotation marks to one standard
        caption = re.sub(r'[`´«»“”¨]', '"', caption)
        caption = re.sub(r'[‘’]', "'", caption)

        # &quot;
        caption = re.sub(r'&quot;?', '', caption)
        # &amp
        caption = re.sub(r'&amp', '', caption)

        # ip addresses:
        caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption)

        # article ids:
        caption = re.sub(r'\d:\d\d\s+$', '', caption)

        # literal backslash-n sequences
        caption = re.sub(r'\\n', ' ', caption)

        # "#123"
        caption = re.sub(r'#\d{1,3}\b', '', caption)
        # "#12345.."
        caption = re.sub(r'#\d{5,}\b', '', caption)
        # "123456.."
        caption = re.sub(r'\b\d{6,}\b', '', caption)
        # filenames:
        caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption)

        # collapse repeated quotes and runs of dots
        caption = re.sub(r'[\"\']{2,}', r'"', caption)  # """AUSVERKAUFT"""
        caption = re.sub(r'[\.]{2,}', r' ', caption)  # """AUSVERKAUFT"""

        caption = re.sub(self.bad_punct_regex, r' ', caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        caption = re.sub(r'\s+\.\s+', r' ', caption)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r'(?:\-|\_)')
        if len(re.findall(regex2, caption)) > 3:
            caption = re.sub(regex2, ' ', caption)

        caption = self.basic_clean(caption)

        caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption)  # jc6640
        caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption)  # jc6640vc
        caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption)  # 6640vc231

        caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption)
        caption = re.sub(r'(free\s)?download(\sfree)?', '', caption)
        caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption)
        caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption)
        caption = re.sub(r'\bpage\s+\d+\b', '', caption)

        caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption)  # j2d1a2a...

        # dimensions such as "1920x1080"
        caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption)

        caption = re.sub(r'\b\s+\:\s+', r': ', caption)
        caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption)
        caption = re.sub(r'\s+', ' ', caption)

        # NOTE(review): the original code called ``caption.strip()`` here and
        # discarded the result, so the anchored patterns below see the
        # unstripped string. That behaviour is kept on purpose: the cleaning is
        # meant to match the training stage exactly.

        caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption)
        caption = re.sub(r'^[\'\_,\-\:;]', r'', caption)
        caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption)
        caption = re.sub(r'^\.\S+$', '', caption)

        return caption.strip()