cleaner files
- .gitignore +3 -1
- FastT5/__init__.py +4 -0
- FastT5/huggingface_utils.py +25 -0
- FastT5/mcq.py +311 -0
- FastT5/model_testing_tools.py +103 -0
- FastT5/onnx_exporter.py +294 -0
- FastT5/onnx_models.py +269 -0
- FastT5/onnx_models_structure.py +62 -0
- FastT5/ort_settings.py +96 -0
- app.py +1 -734
.gitignore
CHANGED
@@ -1,2 +1,4 @@
 venv
-.vscode
+.vscode
+s2v_reddit_2015_md.tar.gz
+__pycache__
FastT5/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .huggingface_utils import set_auth_token
+from .onnx_models import OnnxT5, export_and_get_onnx_model, get_onnx_model
+from .ort_settings import get_onnx_runtime_sessions
+from .onnx_exporter import generate_onnx_representation, quantize
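These four imports are the whole public surface of the package. As a rough usage sketch (not part of this commit; "t5-small" is only an example checkpoint name), the exported API composes like this:

from FastT5 import export_and_get_onnx_model
from transformers import AutoTokenizer

# export, quantize, and wrap the checkpoint in ONNX runtime sessions
model = export_and_get_onnx_model("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokens = tokenizer("translate English to French: hello", return_tensors="pt")
out = model.generate(input_ids=tokens["input_ids"],
                     attention_mask=tokens["attention_mask"], max_length=32)
print(tokenizer.decode(out.squeeze(), skip_special_tokens=True))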
FastT5/huggingface_utils.py
ADDED
@@ -0,0 +1,25 @@
+_auth_token = None
+
+
+def set_auth_token(token):
+    """Set the token which allows the user to authenticate to huggingface.co for downloading private models
+
+    Args:
+        token (Union[str, bool]): The token value to store. One of:
+            - an API key (from https://huggingface.co/organizations/ORGNAME/settings/token),
+            - a login token obtained by running `$ transformers-cli login`
+            - `True`, which tells transformers to use the login token stored in ~/.huggingface/token
+
+    Returns:
+        None
+    """
+    global _auth_token
+    _auth_token = token
+
+
+def get_auth_token():
+    """Get the user-configurable auth token, which defaults to None
+
+    Returns:
+        auth_token (Optional[Union[str, bool]]) for authenticating with huggingface.co
+    """
+    global _auth_token
+    return _auth_token
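The stored token is read back by get_auth_token() inside the exporter and model loaders below, which pass it to transformers as use_auth_token. A hedged sketch of exporting a private checkpoint (the repo id is a placeholder, not part of this commit):

from FastT5 import set_auth_token, export_and_get_onnx_model

set_auth_token(True)  # or set_auth_token("hf_...") with an API key
# "my-org/private-t5" is a hypothetical private repo id
model = export_and_get_onnx_model("my-org/private-t5")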
FastT5/mcq.py
ADDED
@@ -0,0 +1,311 @@
+from flashtext import KeywordProcessor
+from nltk.tokenize import sent_tokenize
+from similarity.normalized_levenshtein import NormalizedLevenshtein
+from nltk.corpus import stopwords
+import torch
+from collections import OrderedDict
+import string
+import pke
+import nltk
+import random
+nltk.download('brown')
+nltk.download('stopwords')
+nltk.download('popular')
+
+
+def MCQs_available(word, s2v):
+    word = word.replace(" ", "_")
+    sense = s2v.get_best_sense(word)
+    if sense is not None:
+        return True
+    else:
+        return False
+
+
+def edits(word):
+    "All edits that are one edit away from `word`."
+    letters = 'abcdefghijklmnopqrstuvwxyz ' + string.punctuation
+    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
+    deletes = [L + R[1:] for L, R in splits if R]
+    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
+    replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
+    inserts = [L + c + R for L, R in splits for c in letters]
+    return set(deletes + transposes + replaces + inserts)
+
+
+def sense2vec_get_words(word, s2v):
+    output = []
+
+    word_preprocessed = word.translate(
+        word.maketrans("", "", string.punctuation))
+    word_preprocessed = word_preprocessed.lower()
+
+    word_edits = edits(word_preprocessed)
+
+    word = word.replace(" ", "_")
+
+    sense = s2v.get_best_sense(word)
+    most_similar = s2v.most_similar(sense, n=15)
+
+    compare_list = [word_preprocessed]
+    for each_word in most_similar:
+        append_word = each_word[0].split("|")[0].replace("_", " ")
+        append_word = append_word.strip()
+        append_word_processed = append_word.lower()
+        append_word_processed = append_word_processed.translate(
+            append_word_processed.maketrans("", "", string.punctuation))
+        if append_word_processed not in compare_list and word_preprocessed not in append_word_processed and append_word_processed not in word_edits:
+            output.append(append_word.title())
+            compare_list.append(append_word_processed)
+
+    out = list(OrderedDict.fromkeys(output))
+
+    return out
+
+
+def get_options(answer, s2v):
+    distractors = []
+
+    try:
+        distractors = sense2vec_get_words(answer, s2v)
+        if len(distractors) > 0:
+            print(" Sense2vec_distractors successful for word : ", answer)
+            return distractors, "sense2vec"
+    except:
+        print(" Sense2vec_distractors failed for word : ", answer)
+
+    return distractors, "None"
+
+
+def tokenize_sentences(text):
+    sentences = [sent_tokenize(text)]
+    sentences = [y for x in sentences for y in x]
+    # Remove any short sentences less than 20 letters.
+    sentences = [sentence.strip()
+                 for sentence in sentences if len(sentence) > 20]
+    return sentences
+
+
+def get_sentences_for_keyword(keywords, sentences):
+    keyword_processor = KeywordProcessor()
+    keyword_sentences = {}
+    for word in keywords:
+        word = word.strip()
+        keyword_sentences[word] = []
+        keyword_processor.add_keyword(word)
+    for sentence in sentences:
+        keywords_found = keyword_processor.extract_keywords(sentence)
+        for key in keywords_found:
+            keyword_sentences[key].append(sentence)
+
+    for key in keyword_sentences.keys():
+        values = keyword_sentences[key]
+        values = sorted(values, key=len, reverse=True)
+        keyword_sentences[key] = values
+
+    delete_keys = []
+    for k in keyword_sentences.keys():
+        if len(keyword_sentences[k]) == 0:
+            delete_keys.append(k)
+    for del_key in delete_keys:
+        del keyword_sentences[del_key]
+
+    return keyword_sentences
+
+
+def is_far(words_list, currentword, thresh, normalized_levenshtein):
+    threshold = thresh
+    score_list = []
+    for word in words_list:
+        score_list.append(normalized_levenshtein.distance(
+            word.lower(), currentword.lower()))
+    if min(score_list) >= threshold:
+        return True
+    else:
+        return False
+
+
+def filter_phrases(phrase_keys, max, normalized_levenshtein):
+    filtered_phrases = []
+    if len(phrase_keys) > 0:
+        filtered_phrases.append(phrase_keys[0])
+        for ph in phrase_keys[1:]:
+            if is_far(filtered_phrases, ph, 0.7, normalized_levenshtein):
+                filtered_phrases.append(ph)
+            if len(filtered_phrases) >= max:
+                break
+    return filtered_phrases
+
+
+def get_nouns_multipartite(text):
+    out = []
+
+    extractor = pke.unsupervised.MultipartiteRank()
+    extractor.load_document(input=text, language='en')
+    pos = {'PROPN', 'NOUN'}
+    stoplist = list(string.punctuation)
+    stoplist += stopwords.words('english')
+    extractor.candidate_selection(pos=pos)
+    # 4. build the Multipartite graph and rank candidates using random walk,
+    # alpha controls the weight adjustment mechanism, see TopicRank for
+    # threshold/method parameters.
+    try:
+        extractor.candidate_weighting(alpha=1.1,
+                                      threshold=0.75,
+                                      method='average')
+    except:
+        return out
+
+    keyphrases = extractor.get_n_best(n=10)
+
+    for key in keyphrases:
+        out.append(key[0])
+
+    return out
+
+
+def get_phrases(doc):
+    phrases = {}
+    for np in doc.noun_chunks:
+        phrase = np.text
+        len_phrase = len(phrase.split())
+        if len_phrase > 1:
+            if phrase not in phrases:
+                phrases[phrase] = 1
+            else:
+                phrases[phrase] = phrases[phrase] + 1
+
+    phrase_keys = list(phrases.keys())
+    phrase_keys = sorted(phrase_keys, key=lambda x: len(x), reverse=True)
+    phrase_keys = phrase_keys[:50]
+    return phrase_keys
+
+
+def get_keywords(nlp, text, max_keywords, s2v, fdist, normalized_levenshtein, no_of_sentences):
+    doc = nlp(text)
+    max_keywords = int(max_keywords)
+
+    keywords = get_nouns_multipartite(text)
+    keywords = sorted(keywords, key=lambda x: fdist[x])
+    keywords = filter_phrases(keywords, max_keywords, normalized_levenshtein)
+
+    phrase_keys = get_phrases(doc)
+    filtered_phrases = filter_phrases(
+        phrase_keys, max_keywords, normalized_levenshtein)
+
+    total_phrases = keywords + filtered_phrases
+
+    total_phrases_filtered = filter_phrases(total_phrases, min(
+        max_keywords, 2 * no_of_sentences), normalized_levenshtein)
+
+    answers = []
+    for answer in total_phrases_filtered:
+        if answer not in answers and MCQs_available(answer, s2v):
+            answers.append(answer)
+
+    answers = answers[:max_keywords]
+    return answers
+
+
+def generate_questions_mcq(keyword_sent_mapping, device, tokenizer, model, sense2vec, normalized_levenshtein):
+    batch_text = []
+
+    answers = keyword_sent_mapping.keys()
+    for answer in answers:
+        txt = keyword_sent_mapping[answer]
+        txt_str = "\n".join(txt)
+        context = "context: " + txt_str
+        text = context + " " + "answer: " + answer + " </s>"
+        batch_text.append(text)
+    print(batch_text)
+
+    encoding = tokenizer.batch_encode_plus(
+        batch_text, pad_to_max_length=True, return_tensors="pt")
+
+    print("Running model for generation")
+    input_ids, attention_masks = encoding["input_ids"].to(
+        device), encoding["attention_mask"].to(device)
+
+    with torch.no_grad():
+        outs = model.generate(input_ids=input_ids,
+                              attention_mask=attention_masks,
+                              max_length=150)
+
+    output_array = {}
+    output_array["questions"] = []
+    # print(outs)
+    for index, val in enumerate(answers):
+        individual_question = {}
+        out = outs[index, :]
+        dec = tokenizer.decode(out, skip_special_tokens=True,
+                               clean_up_tokenization_spaces=True)
+
+        Question = dec.replace("question:", "")
+        Question = Question.strip()
+        individual_question["question_statement"] = Question
+        individual_question["question_type"] = "MCQ"
+        individual_question["answer"] = val
+        individual_question["id"] = index + 1
+        individual_question["options"], individual_question["options_algorithm"] = get_options(
+            val, sense2vec)
+
+        individual_question["options"] = filter_phrases(
+            individual_question["options"], 10, normalized_levenshtein)
+        index = 3
+        individual_question["extra_options"] = individual_question["options"][index:]
+        individual_question["options"] = individual_question["options"][:index]
+        individual_question["context"] = keyword_sent_mapping[val]
+
+        if len(individual_question["options"]) > 0:
+            output_array["questions"].append(individual_question)
+
+    return output_array
+
+
+# for normal one word questions
+def generate_normal_questions(keyword_sent_mapping, device, tokenizer, model):
+    batch_text = ""
+    answers = keyword_sent_mapping.keys()
+    for answer in answers:
+        txt = keyword_sent_mapping[answer]
+        context = "context: " + txt
+        text = context + " " + "answer: " + answer + " </s>"
+        batch_text.join(text)
+
+    encoding = tokenizer.batch_encode_plus(
+        batch_text, pad_to_max_length=True, return_tensors="pt")
+
+    print("Running model for generation")
+    input_ids, attention_masks = encoding["input_ids"].to(
+        device), encoding["attention_mask"].to(device)
+
+    with torch.no_grad():
+        outs = model.generate(input_ids=input_ids,
+                              attention_mask=attention_masks,
+                              max_length=150)
+
+    output_array = {}
+    output_array["questions"] = []
+
+    for index, val in enumerate(answers):
+        individual_quest = {}
+        out = outs[index, :]
+        dec = tokenizer.decode(out, skip_special_tokens=True,
+                               clean_up_tokenization_spaces=True)
+
+        Question = dec.replace('question:', '')
+        Question = Question.strip()
+
+        individual_quest['Question'] = Question
+        individual_quest['Answer'] = val
+        individual_quest["id"] = index + 1
+        individual_quest["context"] = keyword_sent_mapping[val]
+
+        output_array["questions"].append(individual_quest)
+
+    return output_array
+
+
+def random_choice():
+    a = random.choice([0, 1])
+    return bool(a)
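Taken together, the helpers in mcq.py form a keyword -> sentence -> distractor pipeline. A hedged usage sketch follows; the spaCy model, the sense2vec vectors (the s2v_reddit_2015_md archive ignored in .gitignore above), and the Brown-corpus frequency distribution are assumptions about how app.py wires them, not facts taken from this diff:

import spacy
from sense2vec import Sense2Vec
from nltk import FreqDist
from nltk.corpus import brown
from similarity.normalized_levenshtein import NormalizedLevenshtein
from FastT5.mcq import tokenize_sentences, get_keywords, get_sentences_for_keyword

nlp = spacy.load("en_core_web_sm")          # assumed spaCy pipeline
s2v = Sense2Vec().from_disk("s2v_old")      # assumed path to the extracted sense2vec vectors
fdist = FreqDist(brown.words())
levenshtein = NormalizedLevenshtein()

text = "The Earth revolves around the Sun once every year."
sentences = tokenize_sentences(text)
keywords = get_keywords(nlp, text, 4, s2v, fdist, levenshtein, len(sentences))
keyword_sentence_mapping = get_sentences_for_keyword(keywords, sentences)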
FastT5/model_testing_tools.py
ADDED
@@ -0,0 +1,103 @@
+from time import perf_counter as pc
+from matplotlib import pyplot as plt
+from transformers import AutoTokenizer
+
+import numpy as np
+
+
+def speed_test(
+    onnx_model,
+    torch_model,
+    beam_range: range = range(1, 10, 1),
+    seq_length_range: range = range(10, 500, 50),
+    input_text=None,
+):
+    """
+    method prints the time took for onnx and pytorch model to finish a text generation task
+
+    args:
+        input_text (str) : text input for the model.
+        onnx_model : onnx representation of the t5 model,
+        torch_model : torch represention of the t5 model,
+        beam_range (range) : provide a range, which takes starting end and steps (don't start with 0)
+        sequence_length-range (range) : takes the start, end and steps as a range (start with 10)
+    return :
+        onnx_model_latency : numpy array of latency for each beam number and sequence length
+        pytorch_model_latency : numpy array of latency for each beam number and sequence length
+    """
+
+    if input_text is None:
+        input_text = """translate English to French: A nucleus is a collection of a large number of up and down quarks, confined into triplets (neutrons and protons). According to the strange matter hypothesis, strangelets are more stable than nuclei, so nuclei are expected to decay into strangelets. But this process may be extremely slow because there is a large energy barrier to overcome:
+as the weak interaction starts making a nucleus into a strangelet, the first few strange quarks form strange baryons, such as the Lambda, which are heavy. Only if many conversions occur almost simultaneously will the number of strange quarks reach the critical proportion required to achieve a lower energy state. This is very unlikely to happen, so even if the strange matter hypothesis were correct, nuclei would never be seen to decay to strangelets because their lifetime would be longer than the age of the universe.
+The stability of strangelets depends on their size. This is because of (a) surface tension at the interface between quark matter and vacuum (which affects small strangelets more than big ones), and (b) screening of charges, which allows small strangelets to be charged, with a neutralizing cloud of electrons/positrons around them, but requires large strangelets, like any large piece of matter, to be electrically neutral in their interior. The charge screening distance tends to be of the order of a few femtometers, so only the outer few femtometers of a strangelet can carry charge.
+The surface tension of strange matter is unknown. If it is smaller than a critical value (a few MeV per square femtometer) then large strangelets are unstable and will tend to fission into smaller strangelets (strange stars would still be stabilized by gravity). If it is larger than the critical value, then strangelets become more stable as they get bigger.
+The known particles with strange quarks are unstable. Because the strange quark is heavier than the up and down quarks, it can spontaneously decay, via the weak interaction into an up quark. Consequently particles containing strange quarks, such as the Lambda particle, always lose their strangeness, by decaying into lighter particles containing only up and down quarks.
+But condensed states with a larger number of quarks might not suffer from this instability. That possible stability against decay is the "strange matter hypothesis" proposed separately by Arnold Bodmer[3] and Edward Witten.[4] According to this hypothesis, when a large enough number of quarks are concentrated together, the lowest energy state is one which has roughly equal numbers of up, down, and strange quarks, namely a strangelet. This stability would occur because of the Pauli exclusion principle; having three types of quarks, rather than two as in normal nuclear matter, allows more quarks to be placed in lower energy levels
+"""
+
+    tokenizer = AutoTokenizer.from_pretrained(torch_model.name_or_path)
+
+    xx = []
+    yy = []
+
+    for j in beam_range:
+        x = []
+        y = []
+        prev = [1, 2]
+        for i in seq_length_range:
+
+            token = tokenizer(
+                input_text,
+                padding=True,
+                truncation=True,
+                max_length=i,
+                pad_to_max_length=i,
+                return_tensors="pt",
+            )
+
+            input_ids = token["input_ids"]
+            attention_mask = token["attention_mask"]
+
+            a = pc()
+            out = onnx_model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_length=i,
+                num_beams=j,
+            )
+            b = pc()
+            x.append(b - a)
+
+            c = pc()
+            o = torch_model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_length=i,
+                num_beams=j,
+            )
+            d = pc()
+            y.append(d - c)
+
+            mean_y = np.mean(y)
+            mean_x = np.mean(x)
+            mean_ratio = mean_y / mean_x
+
+            print(f"seqL : {i}, onnx-{b-a}, pt-{d-c} .. X faster {(d-c)/(b-a)}")
+
+            # ...bleu_score-{bleu.compute(predictions=, references=[[tokenizer.decode(o.squeeze(), skip_special_tokens=True)], ])}')
+            # print(f'o---{tokenizer.decode(out.squeeze(), skip_special_tokens=True)}...p---{tokenizer.decode(o.squeeze(), skip_special_tokens=True)}')
+
+            if (o.shape[1] == prev[-1]) and (o.shape[1] == prev[-2]):
+                break
+
+            prev.append(o.shape[1])
+
+        print(f"beam no.- {j} onnx-{mean_x} pt-{mean_y} X ratio-{mean_ratio}")
+
+        xx.append(x)
+        yy.append(y)
+        plt.plot(x, "g", y, "r")
+        plt.pause(0.05)
+
+    plt.show()
+    return np.array(xx), np.array(yy)
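speed_test needs both representations of the same checkpoint; a minimal sketch, assuming a t5-small export as above (the ranges are illustrative, not values used by this repo):

from transformers import T5ForConditionalGeneration
from FastT5 import export_and_get_onnx_model
from FastT5.model_testing_tools import speed_test

onnx_model = export_and_get_onnx_model("t5-small")
torch_model = T5ForConditionalGeneration.from_pretrained("t5-small")
# returns latency arrays indexed by beam width and sequence length
onnx_lat, torch_lat = speed_test(onnx_model, torch_model,
                                 beam_range=range(1, 4),
                                 seq_length_range=range(10, 110, 50))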
FastT5/onnx_exporter.py
ADDED
@@ -0,0 +1,294 @@
+from .huggingface_utils import get_auth_token
+from .onnx_models_structure import (
+    T5Encoder,
+    DecoderWithLMhead,
+    DecoderWithLMheadInitial,
+)
+from transformers import (
+    AutoConfig,
+    T5ForConditionalGeneration,
+    MT5ForConditionalGeneration,
+)
+import torch
+import functools
+import operator
+from progress.bar import Bar
+from pathlib import Path
+import os
+
+_folder = Path.cwd()
+saved_models_path = _folder.joinpath("models")
+
+Bar.check_tty = False
+
+
+def create_t5_encoder_decoder(pretrained_version="t5-base"):
+    """Generates an encoder and a decoder model with a language model head from a pretrained huggingface model
+
+    Args:
+        pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
+
+    Returns:
+        simplified_encoder: pytorch t5 encoder with a wrapper to output only the hidden states
+        decoder_with_lm_head: pytorch t5 decoder with a language modeling head
+    """
+
+    if 'mt5' in pretrained_version:
+        model = MT5ForConditionalGeneration.from_pretrained(pretrained_version, use_auth_token=get_auth_token())
+    else:
+        model = T5ForConditionalGeneration.from_pretrained(pretrained_version, use_auth_token=get_auth_token())
+
+    return turn_model_into_encoder_decoder(model)
+
+
+def turn_model_into_encoder_decoder(model):
+    encoder = model.encoder
+    decoder = model.decoder
+    lm_head = model.lm_head
+
+    decoder_with_lm_head = DecoderWithLMhead(decoder, lm_head, model.config)
+    simplified_encoder = T5Encoder(encoder)
+    decoder_with_lm_head_init = DecoderWithLMheadInitial(decoder, lm_head, model.config)
+
+    return simplified_encoder, decoder_with_lm_head, decoder_with_lm_head_init
+
+
+def generate_onnx_representation(
+    pretrained_version=None,
+    model=None,
+    output_path=None,
+    input_sequence_length=256,
+    onnx_opset_version=12,  # no other opset versions are tested, change at your own risk
+):
+    """Exports a given huggingface pretrained model, or a given model and tokenizer, to onnx
+
+    Args:
+        pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
+        output_path (Optional[str]): if missing then use ./models
+        input_sequence_length (Optional[int]): typical input sequence length, for use by the ORT for possible optimization
+        onnx_opset_version (Optional[int]): ONNX Operator Set Version, default 12 is the only tested version
+    """
+    if (pretrained_version is None) and model is None:
+        print(
+            "You need to specify pretrained_version (the pretrained model you wish to export). Alternatively you can export a model you have in memory."
+        )
+        return
+
+    if model is not None:
+        (
+            simplified_encoder,
+            decoder_with_lm_head,
+            decoder_with_lm_head_init,
+        ) = turn_model_into_encoder_decoder(model)
+    else:
+        (
+            simplified_encoder,
+            decoder_with_lm_head,
+            decoder_with_lm_head_init,
+        ) = create_t5_encoder_decoder(pretrained_version)
+
+    # model paths for enc, dec and dec_init
+    output_path = saved_models_path if output_path is None else Path(output_path)
+    encoder_path, decoder_path, init_decoder_path = get_model_paths(
+        pretrained_version, output_path, quantized=False
+    )
+
+    model_config = AutoConfig.from_pretrained(pretrained_version, use_auth_token=get_auth_token())
+
+    # Though these are dummy inputs, ORT optimizations do reference these values,
+    # so it is worth using values as close to production as possible
+    batch_size = 1  # not configurable since only CPU
+    enc_seq_length = input_sequence_length
+    dec_seq_length = 1  # a decoder sequence length is always one because it's just the last generated token
+    input_ids = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
+    attention_mask = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
+
+    n_heads = model_config.num_heads
+    d_kv = model_config.d_kv
+
+    input_ids_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
+    attention_mask_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
+    enc_out = torch.ones(
+        (batch_size, enc_seq_length, model_config.d_model), dtype=torch.float32
+    )
+
+    # self_attention_past_key_values = torch.ones(
+    #     (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_a, d_kv), dtype=torch.float32)
+    # cross_attention_past_key_values = torch.ones(
+    #     (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_b, d_kv), dtype=torch.float32)
+
+    sa = torch.ones(
+        (batch_size, n_heads, dec_seq_length, d_kv), dtype=torch.float32
+    )  # 1, 8, 1, 64
+    ca = torch.ones(
+        (batch_size, n_heads, enc_seq_length, d_kv), dtype=torch.float32
+    )  # 1, 8, variable, 64
+    t5_block = (sa, sa, ca, ca)
+    past_key_values = (t5_block,) * model_config.num_decoder_layers
+
+    flat_past_key_values = functools.reduce(operator.iconcat, past_key_values, [])
+
+    decoder_all_inputs = tuple(
+        [input_ids_dec, attention_mask_dec, enc_out] + flat_past_key_values
+    )
+
+    # for progress bars
+    bar = Bar("Exporting to onnx...", max=3)
+
+    import warnings
+
+    # ignores all the warnings during conversion
+    warnings.filterwarnings("ignore")
+
+    # Exports to ONNX
+    with torch.no_grad():
+
+        decoder_inputs = [
+            "input_ids",
+            "encoder_attention_mask",
+            "encoder_hidden_states",
+        ]
+
+        pkv_input_names = ["pkv_{}".format(i) for i in range(len(flat_past_key_values))]
+
+        decoder_input_names = decoder_inputs + pkv_input_names
+
+        decoder_output_names = ["logits", "output_past_key_values"]
+
+        dyn_axis_general = {0: "batch", 1: "sequence"}
+        dyn_axis_pkv = {0: "batch", 2: "seq_length"}
+
+        dyn_axis = {
+            "input_ids": dyn_axis_general,
+            "encoder_attention_mask": dyn_axis_general,
+            "encoder_hidden_states": dyn_axis_general,
+            "logits": dyn_axis_general,
+            "output_past_key_values": dyn_axis_general,
+        }
+
+        dyn_pkv = {
+            "pkv_{}".format(i): dyn_axis_pkv
+            for i in range(len(flat_past_key_values))
+        }
+
+        dyn_axis_params = {**dyn_axis, **dyn_pkv}
+
+        # decoder to utilize past key values:
+        torch.onnx.export(
+            decoder_with_lm_head,
+            decoder_all_inputs,
+            decoder_path.as_posix(),
+            export_params=True,
+            do_constant_folding=True,
+            opset_version=onnx_opset_version,
+            input_names=decoder_input_names,
+            output_names=decoder_output_names,
+            dynamic_axes=dyn_axis_params,
+        )
+        bar.next()
+
+        torch.onnx.export(
+            simplified_encoder,
+            args=(input_ids, attention_mask),
+            f=encoder_path.as_posix(),
+            export_params=True,
+            opset_version=onnx_opset_version,
+            do_constant_folding=True,
+            input_names=["input_ids", "attention_mask"],
+            output_names=["hidden_states"],
+            dynamic_axes={
+                "input_ids": dyn_axis_general,
+                "attention_mask": dyn_axis_general,
+                "hidden_states": dyn_axis_general,
+            },
+        )
+        bar.next()
+        # initial decoder to produce past key values
+        torch.onnx.export(
+            decoder_with_lm_head_init,
+            (input_ids_dec, attention_mask_dec, enc_out),
+            init_decoder_path.as_posix(),
+            export_params=True,
+            opset_version=onnx_opset_version,
+            input_names=[
+                "input_ids",
+                "encoder_attention_mask",
+                "encoder_hidden_states",
+            ],
+            output_names=["logits", "past_key_values"],
+            dynamic_axes={
+                # batch_size, seq_length = input_shape
+                "input_ids": dyn_axis_general,
+                "encoder_attention_mask": dyn_axis_general,
+                "encoder_hidden_states": dyn_axis_general,
+                "logits": dyn_axis_general,
+                "past_key_values": dyn_axis_general,
+            },
+        )
+        bar.next()
+        bar.finish()
+
+    return encoder_path, decoder_path, init_decoder_path
+
+
+def get_model_paths(pretrained_model, model_path, quantized):
+
+    model_path.mkdir(parents=True, exist_ok=True)
+
+    # gets only the filename
+    pretrained_model_name = Path(pretrained_model).stem
+
+    if not quantized:
+        encoder_path = model_path.joinpath(f"{pretrained_model_name}-encoder.onnx")
+        decoder_path = model_path.joinpath(f"{pretrained_model_name}-decoder.onnx")
+        init_decoder_path = model_path.joinpath(
+            f"{pretrained_model_name}-init-decoder.onnx"
+        )
+    else:
+        encoder_path = model_path.joinpath(
+            f"{pretrained_model_name}-encoder-quantized.onnx"
+        )
+        decoder_path = model_path.joinpath(
+            f"{pretrained_model_name}-decoder-quantized.onnx"
+        )
+        init_decoder_path = model_path.joinpath(
+            f"{pretrained_model_name}-init-decoder-quantized.onnx"
+        )
+
+    return encoder_path, decoder_path, init_decoder_path
+
+
+def quantize(models_name_or_path):
+    """
+    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPU
+
+    Uses unsigned ints for activation values, signed ints for weights, per
+    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
+    it is faster on most CPU architectures
+    Args:
+        onnx_model_path: Path to location the exported ONNX model is stored
+    Returns: The Path generated for the quantized
+    """
+    from onnxruntime.quantization import quantize_dynamic, QuantType
+
+    bar = Bar("Quantizing...", max=3)
+
+    quant_model_paths = []
+    for model in models_name_or_path:
+        model_name = model.as_posix()
+        output_model_name = f"{model_name[:-5]}-quantized.onnx"
+        quantize_dynamic(
+            model_input=model_name,
+            model_output=output_model_name,
+            per_channel=True,
+            reduce_range=True,  # should be the same as per_channel
+            activation_type=QuantType.QUInt8,
+            weight_type=QuantType.QInt8,  # per docs, signed is faster on most CPUs
+            optimize_model=False,
+        )  # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
+        quant_model_paths.append(output_model_name)
+        bar.next()
+
+    bar.finish()
+
+    return tuple(quant_model_paths)
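These helpers can also be driven directly, which is essentially what export_and_get_onnx_model in onnx_models.py does below. A minimal sketch of that lower-level flow, assuming "t5-small":

from FastT5 import generate_onnx_representation, quantize, get_onnx_runtime_sessions

onnx_paths = generate_onnx_representation("t5-small")   # encoder, decoder, init-decoder .onnx paths
quantized_paths = quantize(onnx_paths)                  # writes *-quantized.onnx next to the originals
sessions = get_onnx_runtime_sessions(quantized_paths)   # three onnxruntime InferenceSessions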
FastT5/onnx_models.py
ADDED
@@ -0,0 +1,269 @@
+from .huggingface_utils import get_auth_token
+from .ort_settings import get_onnx_runtime_sessions
+from .onnx_exporter import (
+    generate_onnx_representation,
+    quantize,
+    get_model_paths,
+    saved_models_path,
+)
+from pathlib import Path
+
+from transformers import (
+    AutoConfig,
+    MT5Config,
+    T5ForConditionalGeneration,
+)
+from transformers.modeling_outputs import (
+    Seq2SeqLMOutput,
+    BaseModelOutput,
+)
+import torch
+import functools
+import operator
+import numpy
+
+
+class T5Encoder(torch.nn.Module):
+    def __init__(self, encoder_sess):
+        super().__init__()
+        self.encoder = encoder_sess
+        self.main_input_name = "input_ids"
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        inputs_embeds=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+
+        encoder_hidden_state = torch.from_numpy(
+            self.encoder.run(
+                None,
+                {
+                    "input_ids": input_ids.cpu().numpy(),
+                    "attention_mask": attention_mask.cpu().numpy(),
+                },
+            )[0]
+        )
+
+        return BaseModelOutput(encoder_hidden_state)
+
+
+class T5DecoderInit(torch.nn.Module):
+    def __init__(self, decoder_sess):
+        super().__init__()
+        self.decoder = decoder_sess
+
+    def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
+
+        decoder_outputs = self.decoder.run(
+            None,
+            {
+                "input_ids": input_ids.cpu().numpy(),
+                "encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
+                "encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
+            },
+        )
+
+        list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
+
+        out_past_key_values = tuple(
+            list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
+        )
+
+        return torch.from_numpy(decoder_outputs[0]), out_past_key_values
+
+
+class T5Decoder(torch.nn.Module):
+    def __init__(self, decoder_sess):
+        super().__init__()
+        self.decoder = decoder_sess
+
+    def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
+
+        decoder_inputs = {
+            "input_ids": input_ids.cpu().numpy(),
+            "encoder_attention_mask": attention_mask.cpu().numpy(),
+            "encoder_hidden_states": encoder_output.cpu().numpy(),
+        }
+
+        flat_past_key_values = functools.reduce(
+            operator.iconcat, past_key_values, [])
+
+        past_key_values = {
+            f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
+        }
+
+        decoder_outputs = self.decoder.run(
+            None, {**decoder_inputs, **past_key_values})
+        # converts each value of the list to tensor from numpy
+        list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
+
+        # creates a tuple of tuples of shape 6x4 from the above tuple
+        out_past_key_values = tuple(
+            list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
+        )
+
+        return torch.from_numpy(decoder_outputs[0]), out_past_key_values
+
+
+class OnnxT5(T5ForConditionalGeneration):
+    """creates a T5 model using onnx sessions (encode, decoder & init_decoder)"""
+
+    def __init__(self, model_or_model_path, onnx_model_sessions):
+        config = AutoConfig.from_pretrained(
+            model_or_model_path, use_auth_token=get_auth_token()
+        )
+        super().__init__(config)
+
+        # monkeypatch to work for MT5
+        if (
+            isinstance(model_or_model_path, str)
+            and "mt5" in model_or_model_path.lower()
+        ) or (
+            hasattr(model_or_model_path, "name_or_path")
+            and "mt5" in model_or_model_path.name_or_path
+        ):
+            self.model_type = "mt5"
+            self.config_class = MT5Config
+            self._keys_to_ignore_on_load_missing = [
+                r"encoder\.embed_tokens\.weight",
+            ]
+            self._keys_to_ignore_on_save = [
+                r"encoder\.embed_tokens\.weight",
+            ]
+
+        assert len(onnx_model_sessions) == 3, "all three models should be given"
+
+        encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions
+
+        self.encoder = T5Encoder(encoder_sess)
+        self.decoder = T5Decoder(decoder_sess)
+        self.decoder_init = T5DecoderInit(decoder_sess_init)
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        decoder_input_ids=None,
+        decoder_attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        encoder_outputs=None,
+        past_key_values=None,
+        inputs_embeds=None,
+        decoder_inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+
+        if encoder_outputs is None:
+            # Convert encoder inputs in embeddings if needed
+            encoder_outputs = self.encoder(
+                input_ids=input_ids, attention_mask=attention_mask
+            )
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        if past_key_values is not None:
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids[:, -1:]
+            if decoder_inputs_embeds is not None:
+                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
+
+        if past_key_values is None:
+
+            # runs only for the first time:
+            init_onnx_outputs = self.decoder_init(
+                decoder_input_ids, attention_mask, encoder_hidden_states
+            )
+
+            logits, past_key_values = init_onnx_outputs
+
+        else:
+
+            onnx_outputs = self.decoder(
+                decoder_input_ids,
+                attention_mask,
+                encoder_hidden_states,
+                past_key_values,
+            )
+
+            logits, past_key_values = onnx_outputs
+
+        return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
+
+
+def export_and_get_onnx_model(
+    model_or_model_path, custom_output_path=saved_models_path, quantized=True
+):
+    """
+    Method for whole pipeline,
+    converts from pytorch to onnx --> quantizes model --> sets onnx runtime
+    --> builds whole onnx model with all sessions
+
+    """
+
+    # Step 1. convert huggingfaces t5 model to onnx
+    onnx_model_paths = generate_onnx_representation(
+        model_or_model_path, output_path=custom_output_path
+    )
+
+    if quantized:
+        # Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
+        quant_model_paths = quantize(onnx_model_paths)
+
+        # step 3. setup onnx runtime
+        print("Setting up onnx model...")
+        model_sessions = get_onnx_runtime_sessions(quant_model_paths)
+    else:
+        print("Setting up onnx model...")
+        model_sessions = get_onnx_runtime_sessions(onnx_model_paths)
+
+    # step 4. get the onnx model
+    model = OnnxT5(model_or_model_path, model_sessions)
+    print("Done!")
+
+    return model
+
+
+def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
+    """
+    method gets the onnx model, if already converted models exists
+    Example:
+    >> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")
+
+    """
+
+    encoder_path, decoder_path, init_decoder_path = get_model_paths(
+        model_name, Path(onnx_models_path), quantized
+    )
+
+    if quantized:
+        assert (
+            encoder_path.exists()
+            and decoder_path.exists()
+            and init_decoder_path.exists()
+        ), "quantized model don't exist in the model folder, first quantize the model!"
+    else:
+        assert (
+            encoder_path.exists()
+            and decoder_path.exists()
+            and init_decoder_path.exists()
+        ), "all or some models don't exists in the model folder, first convert the model! "
+
+    model_paths = encoder_path, decoder_path, init_decoder_path
+
+    model_sessions = get_onnx_runtime_sessions(model_paths)
+
+    model = OnnxT5(model_name, model_sessions)
+
+    return model
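Once an export exists under ./models, get_onnx_model rebuilds the OnnxT5 wrapper without re-exporting. A sketch, assuming a prior quantized t5-small export produced by the functions above:

from FastT5 import get_onnx_model
from transformers import AutoTokenizer

model = get_onnx_model("t5-small")                       # expects t5-small-*-quantized.onnx in ./models
tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokens = tokenizer("summarize: onnx makes cpu inference fast", return_tensors="pt")
out = model.generate(**tokens, max_length=24)
print(tokenizer.decode(out.squeeze(), skip_special_tokens=True))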
FastT5/onnx_models_structure.py
ADDED
@@ -0,0 +1,62 @@
+import torch
+
+
+class DecoderWithLMhead(torch.nn.Module):
+    """ Creation of a class to combine the decoder and the lm head """
+
+    def __init__(self, decoder, lm_head, config):
+        super().__init__()
+        self.decoder = decoder
+        self.lm_head = lm_head
+        self.config = config
+
+    def forward(self, *inputs):
+
+        input_ids, attention_mask, encoder_hidden_states = inputs[:3]
+
+        list_pkv = inputs[3:]
+        past_key_values = tuple(list_pkv[i : i + 4] for i in range(0, len(list_pkv), 4))
+
+        decoder_output = self.decoder(
+            input_ids=input_ids,  # decoder_input_ids
+            encoder_attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            past_key_values=past_key_values,
+        )
+
+        lm_head_out = self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5))
+
+        return lm_head_out, decoder_output[1]
+
+
+class T5Encoder(torch.nn.Module):
+    """ Creation of a class to output only the last hidden state from the encoder """
+
+    def __init__(self, encoder):
+        super().__init__()
+        self.encoder = encoder
+
+    def forward(self, *input, **kwargs):
+        return self.encoder(*input, **kwargs)[0]
+
+
+class DecoderWithLMheadInitial(torch.nn.Module):
+    """ Creation of a class to combine the decoder and the lm head """
+
+    def __init__(self, decoder, lm_head, config):
+        super().__init__()
+        self.decoder = decoder
+        self.lm_head = lm_head
+        self.config = config
+
+    def forward(self, input_ids, attention_mask, encoder_hidden_states):
+        decoder_output = self.decoder(
+            input_ids=input_ids,
+            encoder_attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+
+        return (
+            self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
+            decoder_output[1],
+        )
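DecoderWithLMhead receives the past key/value cache flattened into positional tensors (so each one can be named during ONNX export) and regroups them into one 4-tuple per decoder layer. A small illustration of that regrouping, using t5-small-like shapes (the numbers are assumptions for illustration, not values read from this diff):

import torch

num_layers, n_heads, d_kv = 6, 8, 64           # illustrative t5-small-like values
sa = torch.ones(1, n_heads, 1, d_kv)           # self-attention cache for the last generated token
ca = torch.ones(1, n_heads, 256, d_kv)         # cross-attention cache over the encoder sequence
flat = [sa, sa, ca, ca] * num_layers           # the flat ordering DecoderWithLMhead.forward expects
per_layer = tuple(tuple(flat[i:i + 4]) for i in range(0, len(flat), 4))
assert len(per_layer) == num_layers and len(per_layer[0]) == 4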
FastT5/ort_settings.py
ADDED
@@ -0,0 +1,96 @@
+import os, psutil
+
+os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=True))
+os.environ["OMP_WAIT_POLICY"] = "ACTIVE"
+
+
+from onnxruntime import (
+    GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
+    ExecutionMode,
+)
+
+
+def get_onnx_runtime_sessions(
+    model_paths,
+    default: bool = True,
+    opt_level: int = 99,
+    parallel_exe_mode: bool = True,
+    n_threads: int = 0,
+    provider=[
+        "CPUExecutionProvider",
+    ],
+) -> InferenceSession:
+    """
+    Optimizes the model
+
+    Args:
+        model_paths (List or Tuple of str) : the path to, in order:
+            path_to_encoder (str) : the path of input onnx encoder model.
+            path_to_decoder (str) : the path of input onnx decoder model.
+            path_to_initial_decoder (str) : the path of input initial onnx decoder model.
+        default : set this to true, ort will choose the best settings for your hardware.
+            (you can test out different settings for better results.)
+        opt_level (int) : sess_options.GraphOptimizationLevel param if set 1 uses 'ORT_ENABLE_BASIC',
+            2 for 'ORT_ENABLE_EXTENDED' and 99 for 'ORT_ENABLE_ALL',
+            default value is set to 99.
+        parallel_exe_mode (bool) : Sets the execution mode. Default is True (parallel).
+        n_threads (int) : Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose
+        provider : execution providers list.
+
+    Returns:
+        encoder_session : encoder onnx InferenceSession
+        decoder_session : decoder onnx InferenceSession
+        decoder_sess_init : initial decoder onnx InferenceSession
+
+    """
+    path_to_encoder, path_to_decoder, path_to_initial_decoder = model_paths
+
+    if default:
+
+        encoder_sess = InferenceSession(str(path_to_encoder))
+
+        decoder_sess = InferenceSession(str(path_to_decoder))
+
+        decoder_sess_init = InferenceSession(str(path_to_initial_decoder))
+
+    else:
+
+        # Few properties that might have an impact on performances
+        options = SessionOptions()
+
+        if opt_level == 1:
+            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
+        elif opt_level == 2:
+            options.graph_optimization_level = (
+                GraphOptimizationLevel.ORT_ENABLE_EXTENDED
+            )
+        else:
+            assert opt_level == 99
+            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        # set this true for better performance
+        if parallel_exe_mode == True:
+            options.execution_mode = ExecutionMode.ORT_PARALLEL
+        else:
+            options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
+
+        options.intra_op_num_threads = n_threads
+        # options.inter_op_num_threads = 10
+
+        # options.enable_profiling = True
+
+        encoder_sess = InferenceSession(
+            str(path_to_encoder), options, providers=provider
+        )
+
+        decoder_sess = InferenceSession(
+            str(path_to_decoder), options, providers=provider
+        )
+
+        decoder_sess_init = InferenceSession(
+            str(path_to_initial_decoder), options, providers=provider
+        )
+
+    return encoder_sess, decoder_sess, decoder_sess_init
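When default=False, the SessionOptions branch above takes effect. A sketch with explicit settings (the three paths are placeholders for an existing quantized export):

from FastT5 import get_onnx_runtime_sessions

paths = ("models/t5-small-encoder-quantized.onnx",
         "models/t5-small-decoder-quantized.onnx",
         "models/t5-small-init-decoder-quantized.onnx")
enc_sess, dec_sess, dec_init_sess = get_onnx_runtime_sessions(
    paths, default=False, opt_level=99, parallel_exe_mode=False, n_threads=4)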
app.py
CHANGED
@@ -1,742 +1,9 @@
-import psutil
-from transformers import (
-    AutoConfig,
-    T5ForConditionalGeneration,
-    MT5ForConditionalGeneration,
-)
-import torch
 import time
 import gradio as gr
 from transformers import AutoTokenizer
-import onnxruntime as ort
-from transformers.modeling_outputs import (
-    Seq2SeqLMOutput,
-    BaseModelOutput,
-)
 import os
 from pathlib import Path
-from
-import operator
-import functools
-from onnxruntime import (
-    GraphOptimizationLevel,
-    InferenceSession,
-    SessionOptions,
-    ExecutionMode,
-)
-_auth_token = None
-
-
-def set_auth_token(token):
-    """Set the token which allows the user to authenticate to hugginface.co for downloading private models
-
-    Args:
-        token (Union[str, bool]): The token value to store. One of:
-            - an API key (from https://huggingface.co/organizations/ORGNAME/settings/token),
-            - a login token obtained by running `$ transformers-cli login`
-            - `True`, which tells transformers to use the login token stored in ~/.huggingface/token
-
-    Returns:
-        None
-    """
-    global _auth_token
-    _auth_token = token
-
-
-def get_auth_token():
-    """Get the user-configurable auth token, which defaults to None
-
-    Returns:
-        auth_token (Optional[Union[str, bool]]) for authenticating with huggingface.co
-    """
-    global _auth_token
-    return _auth_token
-
-
-os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=True))
-os.environ["OMP_WAIT_POLICY"] = "ACTIVE"
-
-
-def get_onnx_runtime_sessions(
-    model_paths,
-    default: bool = True,
-    opt_level: int = 99,
-    parallel_exe_mode: bool = True,
-    n_threads: int = 0,
-    provider=[
-        "CPUExecutionProvider",
-    ],
-) -> InferenceSession:
-    """
-    Optimizes the model
-
-    Args:
-        model_paths (List or Tuple of str) : the path to, in order:
-            path_to_encoder (str) : the path of input onnx encoder model.
-            path_to_decoder (str) : the path of input onnx decoder model.
-            path_to_initial_decoder (str) : the path of input initial onnx decoder model.
-        default : set this to true, ort will choose the best settings for your hardware.
-            (you can test out different settings for better results.)
-        opt_level (int) : sess_options.GraphOptimizationLevel param if set 1 uses 'ORT_ENABLE_BASIC',
-            2 for 'ORT_ENABLE_EXTENDED' and 99 for 'ORT_ENABLE_ALL',
-            default value is set to 99.
-        parallel_exe_mode (bool) : Sets the execution mode. Default is True (parallel).
-        n_threads (int) : Sets the number of threads used to parallelize the execution within nodes. Default is 0 to let onnxruntime choose
-        provider : execution providers list.
-
-    Returns:
-        encoder_session : encoder onnx InferenceSession
-        decoder_session : decoder onnx InferenceSession
-        decoder_sess_init : initial decoder onnx InferenceSession
-
-    """
-    path_to_encoder, path_to_decoder, path_to_initial_decoder = model_paths
-
-    if default:
-
-        encoder_sess = InferenceSession(str(path_to_encoder))
-
-        decoder_sess = InferenceSession(str(path_to_decoder))
-
-        decoder_sess_init = InferenceSession(str(path_to_initial_decoder))
-
-    else:
-
-        # Few properties that might have an impact on performances
-        options = SessionOptions()
-
-        if opt_level == 1:
-            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
-        elif opt_level == 2:
-            options.graph_optimization_level = (
-                GraphOptimizationLevel.ORT_ENABLE_EXTENDED
-            )
-        else:
-            assert opt_level == 99
-            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-
-        # set this true for better performance
-        if parallel_exe_mode == True:
-            options.execution_mode = ExecutionMode.ORT_PARALLEL
-        else:
-            options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
-
-        options.intra_op_num_threads = n_threads
-        # options.inter_op_num_threads = 10
-
-        # options.enable_profiling = True
-
-        encoder_sess = InferenceSession(
-            str(path_to_encoder), options, providers=provider
-        )
-
-        decoder_sess = InferenceSession(
-            str(path_to_decoder), options, providers=provider
-        )
-
-        decoder_sess_init = InferenceSession(
-            str(path_to_initial_decoder), options, providers=provider
-        )
-
-    return encoder_sess, decoder_sess, decoder_sess_init
-
-
-class DecoderWithLMhead(torch.nn.Module):
-    """ Creation of a class to combine the decoder and the lm head """
-
-    def __init__(self, decoder, lm_head, config):
-        super().__init__()
-        self.decoder = decoder
-        self.lm_head = lm_head
-        self.config = config
-
-    def forward(self, *inputs):
-
-        input_ids, attention_mask, encoder_hidden_states = inputs[:3]
-
-        list_pkv = inputs[3:]
-        past_key_values = tuple(list_pkv[i: i + 4]
-                                for i in range(0, len(list_pkv), 4))
-
-        decoder_output = self.decoder(
-            input_ids=input_ids,  # decoder_input_ids
-            encoder_attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-            past_key_values=past_key_values,
-        )
-
-        lm_head_out = self.lm_head(
-            decoder_output[0] * (self.config.d_model ** -0.5))
-
-        return lm_head_out, decoder_output[1]
-
-
-class T5Encoder(torch.nn.Module):
-    """ Creation of a class to output only the last hidden state from the encoder """
-
-    def __init__(self, encoder):
-        super().__init__()
-        self.encoder = encoder
-
-    def forward(self, *input, **kwargs):
-        return self.encoder(*input, **kwargs)[0]
-
-
-class DecoderWithLMheadInitial(torch.nn.Module):
-    """ Creation of a class to combine the decoder and the lm head """
-
-    def __init__(self, decoder, lm_head, config):
-        super().__init__()
-        self.decoder = decoder
-        self.lm_head = lm_head
-        self.config = config
-
-    def forward(self, input_ids, attention_mask, encoder_hidden_states):
-        decoder_output = self.decoder(
-            input_ids=input_ids,
-            encoder_attention_mask=attention_mask,
-            encoder_hidden_states=encoder_hidden_states,
-        )
-
-        return (
-            self.lm_head(decoder_output[0] * (self.config.d_model ** -0.5)),
-            decoder_output[1],
-        )
-
-
-_folder = Path.cwd()
-saved_models_path = _folder.joinpath("models")
-
-Bar.check_tty = False
-
-
-def create_t5_encoder_decoder(pretrained_version="t5-base"):
-    """Generates an encoder and a decoder model with a language model head from a pretrained huggingface model
-
-    Args:
-        pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
-
-    Returns:
-        simplified_encoder: pytorch t5 encoder with a wrapper to output only the hidden states
-        decoder_with_lm_head: pytorch t5 decoder with a language modeling head
-    """
-
-    if 'mt5' in pretrained_version:
-        model = MT5ForConditionalGeneration.from_pretrained(
-            pretrained_version, use_auth_token=get_auth_token())
-    else:
-        model = T5ForConditionalGeneration.from_pretrained(
-            pretrained_version, use_auth_token=get_auth_token())
-
-    return turn_model_into_encoder_decoder(model)
-
-
-def turn_model_into_encoder_decoder(model):
-    encoder = model.encoder
-    decoder = model.decoder
-    lm_head = model.lm_head
-
-    decoder_with_lm_head = DecoderWithLMhead(decoder, lm_head, model.config)
-    simplified_encoder = T5Encoder(encoder)
-    decoder_with_lm_head_init = DecoderWithLMheadInitial(
-        decoder, lm_head, model.config)
-
-    return simplified_encoder, decoder_with_lm_head, decoder_with_lm_head_init
-
-
-def generate_onnx_representation(
-    pretrained_version=None,
-    model=None,
-    output_path=None,
-    input_sequence_length=256,
-    onnx_opset_version=12,  # no other opset versions are tested, change at your own risk
-):
-    """Exports a given huggingface pretrained model, or a given model and tokenizer, to onnx
-
-    Args:
-        pretrained_version (str): Name of a pretrained model, or path to a pretrained / finetuned version of T5
-        output_path (Optional[str]): if missing then use ./models
-        input_sequence_length (Optional[int]): typical input sequence length, for use by the ORT for possible optimization
260 |
-
onnx_opset_version (Optional[int]): ONNX Operator Set Version, default 12 is the only tested version
|
261 |
-
"""
|
262 |
-
if (pretrained_version is None) and model is None:
|
263 |
-
print(
|
264 |
-
"You need to specify pretrained_version (the pretrained model you wish to export). Alternatively you can export a model you have in memory."
|
265 |
-
)
|
266 |
-
return
|
267 |
-
|
268 |
-
if model is not None:
|
269 |
-
(
|
270 |
-
simplified_encoder,
|
271 |
-
decoder_with_lm_head,
|
272 |
-
decoder_with_lm_head_init,
|
273 |
-
) = turn_model_into_encoder_decoder(model)
|
274 |
-
else:
|
275 |
-
(
|
276 |
-
simplified_encoder,
|
277 |
-
decoder_with_lm_head,
|
278 |
-
decoder_with_lm_head_init,
|
279 |
-
) = create_t5_encoder_decoder(pretrained_version)
|
280 |
-
|
281 |
-
# model paths for enc, dec and dec_init
|
282 |
-
output_path = saved_models_path if output_path is None else Path(
|
283 |
-
output_path)
|
284 |
-
encoder_path, decoder_path, init_decoder_path = get_model_paths(
|
285 |
-
pretrained_version, output_path, quantized=False
|
286 |
-
)
|
287 |
-
|
288 |
-
model_config = AutoConfig.from_pretrained(
|
289 |
-
pretrained_version, use_auth_token=get_auth_token())
|
290 |
-
|
291 |
-
# Though these are dummy inputs, ORT optimizations do reference these values,
|
292 |
-
# so it is worth using values as close to production as possible
|
293 |
-
batch_size = 1 # not configurable since only CPU
|
294 |
-
enc_seq_length = input_sequence_length
|
295 |
-
# a decoder sequence length is always one because it's just the last generated token
|
296 |
-
dec_seq_length = 1
|
297 |
-
input_ids = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
|
298 |
-
attention_mask = torch.ones(batch_size, enc_seq_length, dtype=torch.int64)
|
299 |
-
|
300 |
-
n_heads = model_config.num_heads
|
301 |
-
d_kv = model_config.d_kv
|
302 |
-
|
303 |
-
input_ids_dec = torch.ones(batch_size, dec_seq_length, dtype=torch.int64)
|
304 |
-
attention_mask_dec = torch.ones(
|
305 |
-
batch_size, dec_seq_length, dtype=torch.int64)
|
306 |
-
enc_out = torch.ones(
|
307 |
-
(batch_size, enc_seq_length, model_config.d_model), dtype=torch.float32
|
308 |
-
)
|
309 |
-
|
310 |
-
# self_attention_past_key_values = torch.ones(
|
311 |
-
# (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_a, d_kv), dtype=torch.float32)
|
312 |
-
# cross_attention_past_key_values = torch.ones(
|
313 |
-
# (model_config.num_decoder_layers, 2, batch_size, n_heads, seq_length_b, d_kv), dtype=torch.float32)
|
314 |
-
|
315 |
-
sa = torch.ones(
|
316 |
-
(batch_size, n_heads, dec_seq_length, d_kv), dtype=torch.float32
|
317 |
-
) # 1, 8, 1, 64
|
318 |
-
ca = torch.ones(
|
319 |
-
(batch_size, n_heads, enc_seq_length, d_kv), dtype=torch.float32
|
320 |
-
) # 1, 8, variable, 64
|
321 |
-
t5_block = (sa, sa, ca, ca)
|
322 |
-
past_key_values = (t5_block,) * model_config.num_decoder_layers
|
323 |
-
|
324 |
-
flat_past_key_values = functools.reduce(
|
325 |
-
operator.iconcat, past_key_values, [])
|
326 |
-
|
327 |
-
decoder_all_inputs = tuple(
|
328 |
-
[input_ids_dec, attention_mask_dec, enc_out] + flat_past_key_values
|
329 |
-
)
|
330 |
-
|
331 |
-
# for progress bars
|
332 |
-
bar = Bar("Exporting to onnx...", max=3)
|
333 |
-
|
334 |
-
import warnings
|
335 |
-
|
336 |
-
# ignores all the warnings during conversion
|
337 |
-
warnings.filterwarnings("ignore")
|
338 |
-
|
339 |
-
# Exports to ONNX
|
340 |
-
with torch.no_grad():
|
341 |
-
|
342 |
-
decoder_inputs = [
|
343 |
-
"input_ids",
|
344 |
-
"encoder_attention_mask",
|
345 |
-
"encoder_hidden_states",
|
346 |
-
]
|
347 |
-
|
348 |
-
pkv_input_names = ["pkv_{}".format(
|
349 |
-
i) for i in range(len(flat_past_key_values))]
|
350 |
-
|
351 |
-
decoder_input_names = decoder_inputs + pkv_input_names
|
352 |
-
|
353 |
-
decoder_output_names = ["logits", "output_past_key_values"]
|
354 |
-
|
355 |
-
dyn_axis_general = {0: "batch", 1: "sequence"}
|
356 |
-
dyn_axis_pkv = {0: "batch", 2: "seq_length"}
|
357 |
-
|
358 |
-
dyn_axis = {
|
359 |
-
"input_ids": dyn_axis_general,
|
360 |
-
"encoder_attention_mask": dyn_axis_general,
|
361 |
-
"encoder_hidden_states": dyn_axis_general,
|
362 |
-
"logits": dyn_axis_general,
|
363 |
-
"output_past_key_values": dyn_axis_general,
|
364 |
-
}
|
365 |
-
|
366 |
-
dyn_pkv = {
|
367 |
-
"pkv_{}".format(i): dyn_axis_pkv
|
368 |
-
for i in range(len(flat_past_key_values))
|
369 |
-
}
|
370 |
-
|
371 |
-
dyn_axis_params = {**dyn_axis, **dyn_pkv}
|
372 |
-
|
373 |
-
# decoder to utilize past key values:
|
374 |
-
torch.onnx.export(
|
375 |
-
decoder_with_lm_head,
|
376 |
-
decoder_all_inputs,
|
377 |
-
decoder_path.as_posix(),
|
378 |
-
export_params=True,
|
379 |
-
do_constant_folding=True,
|
380 |
-
opset_version=onnx_opset_version,
|
381 |
-
input_names=decoder_input_names,
|
382 |
-
output_names=decoder_output_names,
|
383 |
-
dynamic_axes=dyn_axis_params,
|
384 |
-
)
|
385 |
-
bar.next()
|
386 |
-
|
387 |
-
torch.onnx.export(
|
388 |
-
simplified_encoder,
|
389 |
-
args=(input_ids, attention_mask),
|
390 |
-
f=encoder_path.as_posix(),
|
391 |
-
export_params=True,
|
392 |
-
opset_version=onnx_opset_version,
|
393 |
-
do_constant_folding=True,
|
394 |
-
input_names=["input_ids", "attention_mask"],
|
395 |
-
output_names=["hidden_states"],
|
396 |
-
dynamic_axes={
|
397 |
-
"input_ids": dyn_axis_general,
|
398 |
-
"attention_mask": dyn_axis_general,
|
399 |
-
"hidden_states": dyn_axis_general,
|
400 |
-
},
|
401 |
-
)
|
402 |
-
bar.next()
|
403 |
-
# initial decoder to produce past key values
|
404 |
-
torch.onnx.export(
|
405 |
-
decoder_with_lm_head_init,
|
406 |
-
(input_ids_dec, attention_mask_dec, enc_out),
|
407 |
-
init_decoder_path.as_posix(),
|
408 |
-
export_params=True,
|
409 |
-
opset_version=onnx_opset_version,
|
410 |
-
input_names=[
|
411 |
-
"input_ids",
|
412 |
-
"encoder_attention_mask",
|
413 |
-
"encoder_hidden_states",
|
414 |
-
],
|
415 |
-
output_names=["logits", "past_key_values"],
|
416 |
-
dynamic_axes={
|
417 |
-
# batch_size, seq_length = input_shape
|
418 |
-
"input_ids": dyn_axis_general,
|
419 |
-
"encoder_attention_mask": dyn_axis_general,
|
420 |
-
"encoder_hidden_states": dyn_axis_general,
|
421 |
-
"logits": dyn_axis_general,
|
422 |
-
"past_key_values": dyn_axis_general,
|
423 |
-
},
|
424 |
-
)
|
425 |
-
bar.next()
|
426 |
-
bar.finish()
|
427 |
-
|
428 |
-
return encoder_path, decoder_path, init_decoder_path
|
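As a rough illustration of calling this exporter (the checkpoint name and output directory are placeholders, and it assumes generate_onnx_representation is importable from the FastT5 package like the other helpers):

from FastT5 import generate_onnx_representation

# returns the paths of the encoder, decoder and init-decoder ONNX graphs
encoder_path, decoder_path, init_decoder_path = generate_onnx_representation(
    "t5-small",            # any T5/MT5 checkpoint name or local path
    output_path="models",  # falls back to ./models when omitted
)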
429 |
-
|
430 |
-
|
431 |
-
def get_model_paths(pretrained_model, model_path, quantized):
|
432 |
-
|
433 |
-
model_path.mkdir(parents=True, exist_ok=True)
|
434 |
-
|
435 |
-
# gets only the filename
|
436 |
-
pretrained_model_name = Path(pretrained_model).stem
|
437 |
-
|
438 |
-
if not quantized:
|
439 |
-
encoder_path = model_path.joinpath(
|
440 |
-
f"{pretrained_model_name}-encoder.onnx")
|
441 |
-
decoder_path = model_path.joinpath(
|
442 |
-
f"{pretrained_model_name}-decoder.onnx")
|
443 |
-
init_decoder_path = model_path.joinpath(
|
444 |
-
f"{pretrained_model_name}-init-decoder.onnx"
|
445 |
-
)
|
446 |
-
else:
|
447 |
-
encoder_path = model_path.joinpath(
|
448 |
-
f"{pretrained_model_name}-encoder-quantized.onnx"
|
449 |
-
)
|
450 |
-
decoder_path = model_path.joinpath(
|
451 |
-
f"{pretrained_model_name}-decoder-quantized.onnx"
|
452 |
-
)
|
453 |
-
init_decoder_path = model_path.joinpath(
|
454 |
-
f"{pretrained_model_name}-init-decoder-quantized.onnx"
|
455 |
-
)
|
456 |
-
|
457 |
-
return encoder_path, decoder_path, init_decoder_path
|
458 |
-
|
459 |
-
|
460 |
-
def quantize(models_name_or_path):
|
461 |
-
"""
|
462 |
-
Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs
|
463 |
-
|
464 |
-
Uses unsigned ints for activation values, signed ints for weights, per
|
465 |
-
https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
|
466 |
-
(this combination is faster on most CPU architectures)
|
467 |
-
Args:
|
468 |
-
models_name_or_path: paths to the locations where the exported ONNX models are stored
|
469 |
-
Returns: The paths generated for the quantized models
|
470 |
-
"""
|
471 |
-
from onnxruntime.quantization import quantize_dynamic, QuantType
|
472 |
-
|
473 |
-
bar = Bar("Quantizing...", max=3)
|
474 |
-
|
475 |
-
quant_model_paths = []
|
476 |
-
for model in models_name_or_path:
|
477 |
-
model_name = model.as_posix()
|
478 |
-
output_model_name = f"{model_name[:-5]}-quantized.onnx"
|
479 |
-
quantize_dynamic(
|
480 |
-
model_input=model_name,
|
481 |
-
model_output=output_model_name,
|
482 |
-
per_channel=True,
|
483 |
-
reduce_range=True, # should be the same as per_channel
|
484 |
-
activation_type=QuantType.QUInt8,
|
485 |
-
weight_type=QuantType.QInt8, # per docs, signed is faster on most CPUs
|
486 |
-
optimize_model=False,
|
487 |
-
) # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],
|
488 |
-
quant_model_paths.append(output_model_name)
|
489 |
-
bar.next()
|
490 |
-
|
491 |
-
bar.finish()
|
492 |
-
|
493 |
-
return tuple(quant_model_paths)
|
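A short sketch of chaining the exporter with this quantizer (same assumptions as above; the checkpoint name is illustrative):

from FastT5 import generate_onnx_representation, quantize

onnx_model_paths = generate_onnx_representation("t5-small")
# writes *-quantized.onnx files next to the originals and returns their paths
quantized_model_paths = quantize(onnx_model_paths)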
494 |
-
|
495 |
-
|
496 |
-
class T5Encoder(torch.nn.Module):
|
497 |
-
def __init__(self, encoder_sess):
|
498 |
-
super().__init__()
|
499 |
-
self.encoder = encoder_sess
|
500 |
-
self.main_input_name = "input_ids"
|
501 |
-
|
502 |
-
def forward(
|
503 |
-
self,
|
504 |
-
input_ids,
|
505 |
-
attention_mask,
|
506 |
-
inputs_embeds=None,
|
507 |
-
head_mask=None,
|
508 |
-
output_attentions=None,
|
509 |
-
output_hidden_states=None,
|
510 |
-
return_dict=None,
|
511 |
-
):
|
512 |
-
|
513 |
-
encoder_hidden_state = torch.from_numpy(
|
514 |
-
self.encoder.run(
|
515 |
-
None,
|
516 |
-
{
|
517 |
-
"input_ids": input_ids.cpu().numpy(),
|
518 |
-
"attention_mask": attention_mask.cpu().numpy(),
|
519 |
-
},
|
520 |
-
)[0]
|
521 |
-
)
|
522 |
-
|
523 |
-
return BaseModelOutput(encoder_hidden_state)
|
524 |
-
|
525 |
-
|
526 |
-
class T5DecoderInit(torch.nn.Module):
|
527 |
-
def __init__(self, decoder_sess):
|
528 |
-
super().__init__()
|
529 |
-
self.decoder = decoder_sess
|
530 |
-
|
531 |
-
def forward(self, input_ids, encoder_attention_mask, encoder_hidden_states):
|
532 |
-
|
533 |
-
decoder_outputs = self.decoder.run(
|
534 |
-
None,
|
535 |
-
{
|
536 |
-
"input_ids": input_ids.cpu().numpy(),
|
537 |
-
"encoder_attention_mask": encoder_attention_mask.cpu().numpy(),
|
538 |
-
"encoder_hidden_states": encoder_hidden_states.cpu().numpy(),
|
539 |
-
},
|
540 |
-
)
|
541 |
-
|
542 |
-
list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
|
543 |
-
|
544 |
-
out_past_key_values = tuple(
|
545 |
-
list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
|
546 |
-
)
|
547 |
-
|
548 |
-
return torch.from_numpy(decoder_outputs[0]), out_past_key_values
|
549 |
-
|
550 |
-
|
551 |
-
class T5Decoder(torch.nn.Module):
|
552 |
-
def __init__(self, decoder_sess):
|
553 |
-
super().__init__()
|
554 |
-
self.decoder = decoder_sess
|
555 |
-
|
556 |
-
def forward(self, input_ids, attention_mask, encoder_output, past_key_values):
|
557 |
-
|
558 |
-
decoder_inputs = {
|
559 |
-
"input_ids": input_ids.cpu().numpy(),
|
560 |
-
"encoder_attention_mask": attention_mask.cpu().numpy(),
|
561 |
-
"encoder_hidden_states": encoder_output.cpu().numpy(),
|
562 |
-
}
|
563 |
-
|
564 |
-
flat_past_key_values = functools.reduce(
|
565 |
-
operator.iconcat, past_key_values, [])
|
566 |
-
|
567 |
-
past_key_values = {
|
568 |
-
f"pkv_{i}": pkv.cpu().numpy() for i, pkv in enumerate(flat_past_key_values)
|
569 |
-
}
|
570 |
-
|
571 |
-
decoder_outputs = self.decoder.run(
|
572 |
-
None, {**decoder_inputs, **past_key_values})
|
573 |
-
# converts each value of the list to tensor from numpy
|
574 |
-
list_pkv = tuple(torch.from_numpy(x) for x in decoder_outputs[1:])
|
575 |
-
|
576 |
-
# regroups the flat tuple into one 4-tuple of past key values per decoder layer
|
577 |
-
out_past_key_values = tuple(
|
578 |
-
list_pkv[i: i + 4] for i in range(0, len(list_pkv), 4)
|
579 |
-
)
|
580 |
-
|
581 |
-
return torch.from_numpy(decoder_outputs[0]), out_past_key_values
|
582 |
-
|
583 |
-
|
584 |
-
class OnnxT5(T5ForConditionalGeneration):
|
585 |
-
"""creates a T5 model using onnx sessions (encode, decoder & init_decoder)"""
|
586 |
-
|
587 |
-
def __init__(self, model_or_model_path, onnx_model_sessions):
|
588 |
-
config = AutoConfig.from_pretrained(
|
589 |
-
model_or_model_path, use_auth_token=get_auth_token()
|
590 |
-
)
|
591 |
-
super().__init__(config)
|
592 |
-
|
593 |
-
# monkeypatch to work for MT5
|
594 |
-
if (
|
595 |
-
isinstance(model_or_model_path, str)
|
596 |
-
and "mt5" in model_or_model_path.lower()
|
597 |
-
) or (
|
598 |
-
hasattr(model_or_model_path, "name_or_path")
|
599 |
-
and "mt5" in model_or_model_path.name_or_path
|
600 |
-
):
|
601 |
-
self.model_type = "mt5"
|
602 |
-
self.config_class = MT5Config
|
603 |
-
self._keys_to_ignore_on_load_missing = [
|
604 |
-
r"encoder\.embed_tokens\.weight",
|
605 |
-
]
|
606 |
-
self._keys_to_ignore_on_save = [
|
607 |
-
r"encoder\.embed_tokens\.weight",
|
608 |
-
]
|
609 |
-
|
610 |
-
assert len(onnx_model_sessions) == 3, "all three models should be given"
|
611 |
-
|
612 |
-
encoder_sess, decoder_sess, decoder_sess_init = onnx_model_sessions
|
613 |
-
|
614 |
-
self.encoder = T5Encoder(encoder_sess)
|
615 |
-
self.decoder = T5Decoder(decoder_sess)
|
616 |
-
self.decoder_init = T5DecoderInit(decoder_sess_init)
|
617 |
-
|
618 |
-
def forward(
|
619 |
-
self,
|
620 |
-
input_ids=None,
|
621 |
-
attention_mask=None,
|
622 |
-
decoder_input_ids=None,
|
623 |
-
decoder_attention_mask=None,
|
624 |
-
head_mask=None,
|
625 |
-
decoder_head_mask=None,
|
626 |
-
cross_attn_head_mask=None,
|
627 |
-
encoder_outputs=None,
|
628 |
-
past_key_values=None,
|
629 |
-
inputs_embeds=None,
|
630 |
-
decoder_inputs_embeds=None,
|
631 |
-
labels=None,
|
632 |
-
use_cache=None,
|
633 |
-
output_attentions=None,
|
634 |
-
output_hidden_states=None,
|
635 |
-
return_dict=None,
|
636 |
-
):
|
637 |
-
|
638 |
-
if encoder_outputs is None:
|
639 |
-
# Convert encoder inputs in embeddings if needed
|
640 |
-
encoder_outputs = self.encoder(
|
641 |
-
input_ids=input_ids, attention_mask=attention_mask
|
642 |
-
)
|
643 |
-
|
644 |
-
encoder_hidden_states = encoder_outputs[0]
|
645 |
-
|
646 |
-
if past_key_values is not None:
|
647 |
-
if decoder_input_ids is not None:
|
648 |
-
decoder_input_ids = decoder_input_ids[:, -1:]
|
649 |
-
if decoder_inputs_embeds is not None:
|
650 |
-
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
|
651 |
-
|
652 |
-
if past_key_values is None:
|
653 |
-
|
654 |
-
# runs only for the first time:
|
655 |
-
init_onnx_outputs = self.decoder_init(
|
656 |
-
decoder_input_ids, attention_mask, encoder_hidden_states
|
657 |
-
)
|
658 |
-
|
659 |
-
logits, past_key_values = init_onnx_outputs
|
660 |
-
|
661 |
-
else:
|
662 |
-
|
663 |
-
onnx_outputs = self.decoder(
|
664 |
-
decoder_input_ids,
|
665 |
-
attention_mask,
|
666 |
-
encoder_hidden_states,
|
667 |
-
past_key_values,
|
668 |
-
)
|
669 |
-
|
670 |
-
logits, past_key_values = onnx_outputs
|
671 |
-
|
672 |
-
return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
|
673 |
-
|
674 |
-
|
675 |
-
def export_and_get_onnx_model(
|
676 |
-
model_or_model_path, custom_output_path=saved_models_path, quantized=True
|
677 |
-
):
|
678 |
-
"""
|
679 |
-
Method for whole pipeline,
|
680 |
-
converts from pytorch to onnx --> quantizes model --> sets onnx runtime
|
681 |
-
--> builds whole onnx model with all sessions
|
682 |
-
|
683 |
-
"""
|
684 |
-
|
685 |
-
# Step 1. convert huggingfaces t5 model to onnx
|
686 |
-
onnx_model_paths = generate_onnx_representation(
|
687 |
-
model_or_model_path, output_path=custom_output_path
|
688 |
-
)
|
689 |
-
|
690 |
-
if quantized:
|
691 |
-
# Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
|
692 |
-
quant_model_paths = quantize(onnx_model_paths)
|
693 |
-
|
694 |
-
# step 3. setup onnx runtime
|
695 |
-
print("Setting up onnx model...")
|
696 |
-
model_sessions = get_onnx_runtime_sessions(quant_model_paths)
|
697 |
-
else:
|
698 |
-
print("Setting up onnx model...")
|
699 |
-
model_sessions = get_onnx_runtime_sessions(onnx_model_paths)
|
700 |
-
|
701 |
-
# step 4. get the onnx model
|
702 |
-
model = OnnxT5(model_or_model_path, model_sessions)
|
703 |
-
print("Done!")
|
704 |
-
|
705 |
-
return model
|
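Putting the pipeline together, a hedged end-to-end sketch (checkpoint name and prompt are illustrative only):

from transformers import AutoTokenizer
from FastT5 import export_and_get_onnx_model

model_name = "t5-small"
model = export_and_get_onnx_model(model_name)  # export -> quantize -> onnx runtime sessions
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("translate English to German: The house is wonderful.",
                   return_tensors="pt")
# OnnxT5 subclasses T5ForConditionalGeneration, so generate() works as usual
tokens = model.generate(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_length=64)
print(tokenizer.decode(tokens[0], skip_special_tokens=True))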
706 |
-
|
707 |
-
|
708 |
-
def get_onnx_model(model_name, onnx_models_path=saved_models_path, quantized=True):
|
709 |
-
"""
|
710 |
-
Gets the onnx model if already-converted models exist.
|
711 |
-
Example:
|
712 |
-
>>> get_onnx_model(model_name="t5-finetuned", onnx_models_path="../models/onnx/quantized/")
|
713 |
-
|
714 |
-
"""
|
715 |
-
|
716 |
-
encoder_path, decoder_path, init_decoder_path = get_model_paths(
|
717 |
-
model_name, Path(onnx_models_path), quantized
|
718 |
-
)
|
719 |
-
|
720 |
-
if quantized:
|
721 |
-
assert (
|
722 |
-
encoder_path.exists()
|
723 |
-
and decoder_path.exists()
|
724 |
-
and init_decoder_path.exists()
|
725 |
-
), "quantized model don't exist in the model folder, first quantize the model!"
|
726 |
-
else:
|
727 |
-
assert (
|
728 |
-
encoder_path.exists()
|
729 |
-
and decoder_path.exists()
|
730 |
-
and init_decoder_path.exists()
|
731 |
-
), "all or some models don't exists in the model folder, first convert the model! "
|
732 |
-
|
733 |
-
model_paths = encoder_path, decoder_path, init_decoder_path
|
734 |
-
|
735 |
-
model_sessions = get_onnx_runtime_sessions(model_paths)
|
736 |
-
|
737 |
-
model = OnnxT5(model_name, model_sessions)
|
738 |
-
|
739 |
-
return model
|
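And when the converted (and quantized) files already exist on disk, a short reload sketch (names and paths are placeholders):

from FastT5 import get_onnx_model

# expects t5-small-{encoder,decoder,init-decoder}-quantized.onnx under ./models
model = get_onnx_model("t5-small", onnx_models_path="models", quantized=True)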
740 |
|
741 |
|
742 |
trained_model_path = './t5_squad_v1/'
|
1 |
import time
|
2 |
import gradio as gr
|
3 |
from transformers import AutoTokenizer
|
4 |
import os
|
5 |
from pathlib import Path
|
6 |
+
from FastT5 import get_onnx_runtime_sessions, OnnxT5
|
7 |
|
8 |
|
9 |
trained_model_path = './t5_squad_v1/'
|