SoybeanMilk committed • 4dcbce8
Parent(s): 7282235

Add madlad400 support.

Files changed: src/translation/translationModel.py (+204 -55)

src/translation/translationModel.py (CHANGED)

Old version (lines removed in this commit are marked with "-"):
@@ -3,17 +3,14 @@ import warnings
 import huggingface_hub
 import requests
 import torch
-
 import ctranslate2
 import transformers
-
-import re

 from typing import Optional
 from src.config import ModelConfig
 from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code

-
 class TranslationModel:
     def __init__(
         self,
@@ -68,7 +65,7 @@ class TranslationModel:
         if os.path.isdir(modelConfig.url):
             self.modelPath = modelConfig.url
         else:
-            self.modelPath = download_model(
                 modelConfig,
                 localFilesOnly=localFilesOnly,
                 cacheDir=downloadRoot,
@@ -86,85 +83,233 @@ class TranslationModel:
         self.load_model()

     def load_model(self):
-        ...

-        ...
-            self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
-        elif "ALMA" in self.modelPath:
-            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
-            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
-            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
-            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
-        else:
-            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
-            self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
-            if "m2m100" in self.modelPath:
-                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
-            else: #NLLB
-                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)

     def release_vram(self):
         try:
             if torch.cuda.is_available():
                 if "ct2" not in self.modelPath:
-                    ...
             print("release vram end.")
         except Exception as e:
             print("Error release vram: " + str(e))


     def translation(self, text: str, max_length: int = 400):
         output = None
         result = None
         try:
             if "ct2" in self.modelPath:
-                ...
             elif "mt5" in self.modelPath:
                 output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
                 result = output[0]['generated_text']
             elif "ALMA" in self.modelPath:
-                ...
                 result = output[0]['generated_text']
-                result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
-                result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
-                return result.strip()
             else: #M2M100 & NLLB
                 output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                 result = output[0]['translation_text']
         except Exception as e:
             print("Error translation text: " + str(e))

         return result


-_MODELS = ["
-           ...
-           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
-           "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
-           "m2m100_1.2B", "m2m100_418M",
-           "mt5-zh-ja-en-trimmed",
-           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
-           "ALMA-13B-GPTQ"]

 def check_model_name(name):
     return any(allowed_name in name for allowed_name in _MODELS)
@@ -224,7 +369,8 @@ def download_model(
         "vocab.json", #m2m100
         "model.safetensors",
         "quantize_config.json",
-        "tokenizer.model"
     ]

     kwargs = {
@@ -232,6 +378,9 @@ def download_model(
         "allow_patterns": allowPatterns,
         #"tqdm_class": disabled_tqdm,
     }

     if outputDir is not None:
         kwargs["local_dir"] = outputDir

New version (lines added in this commit are marked with "+"):

 import huggingface_hub
 import requests
 import torch
 import ctranslate2
 import transformers
+import traceback

 from typing import Optional
 from src.config import ModelConfig
 from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code

 class TranslationModel:
     def __init__(
         self,
...
         if os.path.isdir(modelConfig.url):
             self.modelPath = modelConfig.url
         else:
+            self.modelPath = modelConfig.url if getattr(modelConfig, "model_file", None) is not None else download_model(
                 modelConfig,
                 localFilesOnly=localFilesOnly,
                 cacheDir=downloadRoot,
...
         self.load_model()

     def load_model(self):
+        """
+        [from_pretrained]
+        low_cpu_mem_usage (bool, optional):
+            Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+            This is an experimental feature and subject to change at any moment.
+
+        [transformers.AutoTokenizer.from_pretrained]
+        use_fast (bool, optional, defaults to True):
+            Use a fast Rust-based tokenizer if it is supported for a given model.
+            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
+
+        [transformers.AutoModelForCausalLM.from_pretrained]
+        device_map (str or Dict[str, Union[int, str, torch.device]], optional):
+            Sent directly as model_kwargs (just a simpler shortcut). When the accelerate library is present,
+            set device_map="auto" to compute the most optimized device_map automatically.
+        revision (str, optional, defaults to "main"):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id,
+            since we use a git-based system for storing models and other artifacts on huggingface.co,
+            so revision can be any identifier allowed by git.
+        code_revision (str, optional, defaults to "main"):
+            The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model.
+            It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models
+            and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
+        trust_remote_code (bool, optional, defaults to False):
+            Whether or not to allow for custom models defined on the Hub in their own modeling files.
+            This option should only be set to True for repositories you trust and in which you have read the code,
+            as it will execute code present on the Hub on your local machine.
+
+        [transformers.pipeline "text-generation"]
+        do_sample:
+            If set to True, this parameter enables decoding strategies such as multinomial sampling,
+            beam-search multinomial sampling, Top-K sampling and Top-p sampling.
+            All these strategies select the next token from the probability distribution
+            over the entire vocabulary with various strategy-specific adjustments.
+        temperature (float, optional, defaults to 1.0):
+            The value used to modulate the next token probabilities.
+        top_k (int, optional, defaults to 50):
+            The number of highest probability vocabulary tokens to keep for top-k-filtering.
+        top_p (float, optional, defaults to 1.0):
+            If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        repetition_penalty (float, optional, defaults to 1.0):
+            The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
+
+        [transformers.GPTQConfig]
+        use_exllama (bool, optional):
+            Whether to use the exllama backend. Defaults to True if unset. Only works with bits = 4.
+
+        [ExLlama]
+        ExLlama is a Python/C++/CUDA implementation of the Llama model that is designed for faster inference with 4-bit GPTQ weights (check out these benchmarks).
+        The ExLlama kernel is activated by default when you create a [GPTQConfig] object.
+        To boost inference speed even further, use the ExLlamaV2 kernels by configuring the exllama_config parameter.
+        The ExLlama kernels are only supported when the entire model is on the GPU.
+        If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2), then you'll need to disable the ExLlama kernel.
+        This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file.
+        https://github.com/huggingface/transformers/blob/main/docs/source/en/quantization.md#exllama
+
+        [ctransformers]
+        gpu_layers:
+            The number of layers to run on GPU. Depending on how much GPU memory is available you can increase gpu_layers.
+            Start with a larger value such as gpu_layers=100 and, if it runs out of memory, try smaller values.
+            To run some of the model layers on GPU, set the `gpu_layers` parameter.
+            https://github.com/marella/ctransformers/issues/68
+        """
+        try:
+            print('\n\nLoading model: %s\n\n' % self.modelPath)
+            if "ct2" in self.modelPath:
+                if any(name in self.modelPath for name in ["nllb", "m2m100"]):
+                    if "nllb" in self.modelPath:
+                        self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
+                        self.targetPrefix = [self.translationLang.nllb.code]
+                    elif "m2m100" in self.modelPath:
+                        self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
+                        self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
+                    self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+                elif "ALMA" in self.modelPath:
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
+                    self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
+                    self.transModel = ctranslate2.Generator(self.modelPath, compute_type="auto", device=self.device)
+                elif "madlad400" in self.modelPath:
+                    self.madlad400Prefix = "<2" + self.translationLang.whisper.code + "> "
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
+                    self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+            elif "mt5" in self.modelPath:
+                self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
+                self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
+                self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath, low_cpu_mem_usage=True)
+                self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
+            elif "ALMA" in self.modelPath:
+                self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
+                if "GPTQ" in self.modelPath:
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
+                    if self.device == "cpu":
+                        # GPTQ support on CPU is poor, so running it on CPU is strongly discouraged.
+                        # Set torch_dtype=torch.float32 to prevent the exception "addmm_impl_cpu_ not implemented for 'Half'".
+                        transModelConfig = transformers.AutoConfig.from_pretrained(self.modelPath)
+                        transModelConfig.quantization_config["use_exllama"] = False
+                        self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=False, revision=self.modelConfig.revision, config=transModelConfig, torch_dtype=torch.float32)
+                    else:
+                        # transModelConfig.quantization_config["exllama_config"] = {"version":2} # After switching to ExLlamaV2, VRAM cannot be effectively released, which may be an issue; the V2 kernels are not adopted for now.
+                        self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=False, revision=self.modelConfig.revision)
+                elif "GGUF" in self.modelPath:
+                    import ctransformers
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url)
+                    if self.device == "cpu":
+                        self.transModel = ctransformers.AutoModelForCausalLM.from_pretrained(self.modelPath, hf=True, model_file=self.modelConfig.model_file)
+                    else:
+                        self.transModel = ctransformers.AutoModelForCausalLM.from_pretrained(self.modelPath, hf=True, model_file=self.modelConfig.model_file, gpu_layers=50)
+                self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, do_sample=True, temperature=0.7, top_k=40, top_p=0.95, repetition_penalty=1.1)
+            else:
+                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
+                self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
+                if "m2m100" in self.modelPath:
+                    self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
+                else: #NLLB
+                    self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)

+        except Exception as e:
+            self.release_vram()
+            raise e
+
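
The madlad400 branch above is the new code path introduced by this commit. A minimal standalone sketch of that path, assuming a locally available CTranslate2 conversion of a MADLAD-400 checkpoint; the directory name, device, and the "<2ja>" target code below are illustrative, not taken from this repo:

    import ctranslate2
    import transformers

    model_dir = "models/madlad400-3b-mt-ct2"  # assumed local CTranslate2 conversion
    # assumes the tokenizer files were kept alongside the converted weights
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir)
    translator = ctranslate2.Translator(model_dir, compute_type="auto", device="cpu")

    # MADLAD-400 selects the target language with a "<2xx>" prefix on the source text,
    # mirroring self.madlad400Prefix in the diff above.
    source = tokenizer.convert_ids_to_tokens(tokenizer.encode("<2ja> Hello, how are you?"))
    output = translator.translate_batch([source], beam_size=3, no_repeat_ngram_size=3)
    target = output[0].hypotheses[0]
    print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target)))
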

     def release_vram(self):
         try:
             if torch.cuda.is_available():
                 if "ct2" not in self.modelPath:
+                    try:
+                        if getattr(self, "transModel", None) is not None:
+                            device = torch.device("cpu")
+                            self.transModel.to(device)
+                    except Exception as e:
+                        print(traceback.format_exc())
+                        print("\tself.transModel.to cpu, error: " + str(e))
+                    if getattr(self, "transTranslator", None) is not None:
+                        del self.transTranslator
+                if "ct2" in self.modelPath:
+                    if getattr(self, "transModel", None) is not None and getattr(self.transModel, "unload_model", None) is not None:
+                        self.transModel.unload_model()
+
+                if getattr(self, "transTokenizer", None) is not None:
+                    del self.transTokenizer
+                if getattr(self, "transModel", None) is not None:
+                    del self.transModel
+                try:
+                    torch.cuda.empty_cache()
+                except Exception as e:
+                    print(traceback.format_exc())
+                    print("\tcuda empty cache, error: " + str(e))
+                import gc
+                gc.collect()
             print("release vram end.")
         except Exception as e:
+            print(traceback.format_exc())
             print("Error release vram: " + str(e))


     def translation(self, text: str, max_length: int = 400):
+        """
+        [ctranslate2]
+        max_batch_size:
+            The maximum batch size. If the number of inputs is greater than max_batch_size,
+            the inputs are sorted by length and split by chunks of max_batch_size examples
+            so that the number of padding positions is minimized.
+        no_repeat_ngram_size:
+            Prevent repetitions of ngrams with this size (set 0 to disable).
+        beam_size:
+            Beam size (1 for greedy search).
+
+        [ctranslate2.Generator.generate_batch]
+        sampling_temperature:
+            Sampling temperature to generate more random samples.
+        sampling_topk:
+            Randomly sample predictions from the top K candidates.
+        sampling_topp:
+            Keep the most probable tokens whose cumulative probability exceeds this value.
+        repetition_penalty:
+            Penalty applied to the score of previously generated tokens (set > 1 to penalize).
+        include_prompt_in_result:
+            Include the start_tokens in the result.
+            If include_prompt_in_result is True (the default), the decoding loop is constrained to generate the start tokens that are then included in the result.
+            If include_prompt_in_result is False, the start tokens are forwarded in the decoder at once to initialize its state (i.e. the KV cache for Transformer models).
+            For variable-length inputs, only the tokens up to the minimum length in the batch are forwarded at once. The remaining tokens are generated in the decoding loop with constrained decoding.
+
+        [transformers.TextGenerationPipeline.__call__]
+        return_full_text (bool, optional, defaults to True):
+            If set to False only added text is returned, otherwise the full text is returned. Only meaningful if return_text is set to True.
+        """
         output = None
         result = None
         try:
             if "ct2" in self.modelPath:
+                if any(name in self.modelPath for name in ["nllb", "m2m100"]):
+                    source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
+                    output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
+                    target = output[0].hypotheses[0][1:]
+                    result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
+                elif "ALMA" in self.modelPath:
+                    source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": "))
+                    output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
+                    target = output[0]
+                    result = self.transTokenizer.decode(target.sequences_ids[0])
+                elif "madlad400" in self.modelPath:
+                    source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
+                    output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
+                    target = output[0].hypotheses[0]
+                    result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
             elif "mt5" in self.modelPath:
                 output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
                 result = output[0]['generated_text']
             elif "ALMA" in self.modelPath:
+                if "GPTQ" in self.modelPath:
+                    output = self.transTranslator(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": ", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams, return_full_text=False)
+                elif "GGUF" in self.modelPath:
+                    output = self.transTranslator(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": ", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams, return_full_text=False)
                 result = output[0]['generated_text']
             else: #M2M100 & NLLB
                 output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                 result = output[0]['translation_text']
         except Exception as e:
+            print(traceback.format_exc())
             print("Error translation text: " + str(e))

         return result
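
Because the GPTQ and GGUF pipelines above pass return_full_text=False, only the completion after the final language-name marker is returned, which is why the re.sub prompt-stripping (and the import re) from the old version could be dropped. A rough sketch of the prompt shape the ALMA branches assume; the language names are illustrative stand-ins for whisper.names[0] values:

    # Prompt layout assumed by the ALMA branches above; the model is expected to
    # continue after the trailing "Chinese: " marker, and with
    # return_full_text=False only that continuation comes back.
    src_name, tgt_name = "English", "Chinese"
    text = "Hello, how are you?"
    prompt = "Translate this from " + src_name + " to " + tgt_name + ":\n" + src_name + ": " + text + "\n" + tgt_name + ": "
    print(prompt)
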


+_MODELS = ["nllb-200",
+           "m2m100",
+           "mt5",
+           "ALMA",
+           "madlad400"]

 def check_model_name(name):
     return any(allowed_name in name for allowed_name in _MODELS)
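
check_model_name does substring matching against the family names in the new _MODELS list, so every variant of a family (CTranslate2 conversions, quantized builds, fine-tunes) passes the check. For example:

    check_model_name("ALMA-13B-GPTQ")                     # True, matches "ALMA"
    check_model_name("nllb-200-distilled-600M-ct2-int8")  # True, matches "nllb-200"
    check_model_name("madlad400-3b-mt-ct2-int8")          # True, matches "madlad400" (name illustrative)
    check_model_name("opus-mt-en-zh")                     # False, no listed family matches
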
...
         "vocab.json", #m2m100
         "model.safetensors",
         "quantize_config.json",
+        "tokenizer.model",
+        "vocabulary.json"
     ]

     kwargs = {

         "allow_patterns": allowPatterns,
         #"tqdm_class": disabled_tqdm,
     }
+
+    if modelConfig.revision is not None:
+        kwargs["revision"] = modelConfig.revision

     if outputDir is not None:
         kwargs["local_dir"] = outputDir
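
The kwargs assembled above are presumably handed to huggingface_hub.snapshot_download; the actual call site is outside the ranges shown in this diff. A hedged sketch of how the new revision key would take effect, with the repo id, patterns, and revision chosen for illustration only:

    # Sketch only: assumes download_model forwards these kwargs to
    # huggingface_hub.snapshot_download, as the parameter names suggest.
    import huggingface_hub

    allow_patterns = ["config.json", "tokenizer.json", "model.safetensors",
                      "tokenizer.model", "vocabulary.json"]
    revision = "main"        # e.g. modelConfig.revision; a branch, tag, or commit id
    output_dir = None

    kwargs = {"allow_patterns": allow_patterns}
    if revision is not None:
        kwargs["revision"] = revision
    if output_dir is not None:
        kwargs["local_dir"] = output_dir

    path = huggingface_hub.snapshot_download("owner/repo-name", **kwargs)
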