Spaces:
Running
Running
Add Meta-Llama-3-8B-Instruct ctranslate2 as the translation model to use.
Browse files- app.py +28 -9
- config.json5 +8 -0
- src/config.py +2 -2
- src/translation/translationModel.py +14 -2
app.py
CHANGED
@@ -921,6 +921,8 @@ class WhisperTranscriber:
|
|
921 |
madlad400LangName: str = dataDict.pop("madlad400LangName")
|
922 |
seamlessModelName: str = dataDict.pop("seamlessModelName")
|
923 |
seamlessLangName: str = dataDict.pop("seamlessLangName")
|
|
|
|
|
924 |
|
925 |
translationBatchSize: int = dataDict.pop("translationBatchSize")
|
926 |
translationNoRepeatNgramSize: int = dataDict.pop("translationNoRepeatNgramSize")
|
@@ -954,6 +956,10 @@ class WhisperTranscriber:
|
|
954 |
selectedModelName = seamlessModelName if seamlessModelName is not None and len(seamlessModelName) > 0 else "seamless-m4t-v2-large/facebook"
|
955 |
selectedModel = next((modelConfig for modelConfig in self.app_config.models["seamless"] if modelConfig.name == selectedModelName), None)
|
956 |
translationLang = get_lang_from_seamlessT_Tx_name(seamlessLangName)
|
|
|
|
|
|
|
|
|
957 |
|
958 |
if translationLang is not None:
|
959 |
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=inputLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams, torchDtypeFloat16=translationTorchDtypeFloat16, usingBitsandbytes=translationUsingBitsandbytes)
|
@@ -1023,6 +1029,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
1023 |
ALMA_models = app_config.get_model_names("ALMA")
|
1024 |
madlad400_models = app_config.get_model_names("madlad400")
|
1025 |
seamless_models = app_config.get_model_names("seamless")
|
|
|
1026 |
if not torch.cuda.is_available(): # Loading only quantized or models with medium-low parameters in an environment without GPU support.
|
1027 |
nllb_models = list(filter(lambda nllb: any(name in nllb for name in ["-600M", "-1.3B", "-3.3B-ct2"]), nllb_models))
|
1028 |
m2m100_models = list(filter(lambda m2m100: "12B" not in m2m100, m2m100_models))
|
@@ -1057,20 +1064,24 @@ def create_ui(app_config: ApplicationConfig):
|
|
1057 |
gr.Dropdown(label="seamless - Model (for translate)", choices=seamless_models, elem_id="seamlessModelName"),
|
1058 |
gr.Dropdown(label="seamless - Language", choices=sorted(get_lang_seamlessT_Tx_names()), elem_id="seamlessLangName"),
|
1059 |
}
|
|
|
|
|
|
|
|
|
1060 |
|
1061 |
common_translation_inputs = lambda : {
|
1062 |
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
1063 |
-
gr.Number(label="Translation - No Repeat Ngram Size", precision=0, value=app_config.translation_no_repeat_ngram_size, elem_id="translationNoRepeatNgramSize"),
|
1064 |
-
gr.Number(label="Translation - Num Beams", precision=0, value=app_config.translation_num_beams, elem_id="translationNumBeams"),
|
1065 |
gr.Checkbox(label="Translation - Torch Dtype float16", visible=torch.cuda.is_available(), value=app_config.translation_torch_dtype_float16, info="Load the float32 translation model with float16 when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationTorchDtypeFloat16"),
|
1066 |
gr.Radio(label="Translation - Using Bitsandbytes", visible=torch.cuda.is_available(), choices=[None, "int8", "int4"], value=app_config.translation_using_bitsandbytes, info="Load the float32 translation model into mixed-8bit or 4bit precision quantized model when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationUsingBitsandbytes"),
|
1067 |
}
|
1068 |
|
1069 |
common_vad_inputs = lambda : {
|
1070 |
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
|
1071 |
-
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow"),
|
1072 |
-
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize"),
|
1073 |
-
gr.Number(label="VAD - Process Timeout (s)", precision=0, value=app_config.vad_process_timeout, elem_id="vadPocessTimeout"),
|
1074 |
}
|
1075 |
|
1076 |
common_word_timestamps_inputs = lambda : {
|
@@ -1148,12 +1159,16 @@ def create_ui(app_config: ApplicationConfig):
|
|
1148 |
with gr.Tab(label="seamless") as seamlessTab:
|
1149 |
with gr.Row():
|
1150 |
inputDict.update(common_seamless_inputs())
|
|
|
|
|
|
|
1151 |
m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
|
1152 |
nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
|
1153 |
mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
|
1154 |
almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
|
1155 |
madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
|
1156 |
seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
|
|
|
1157 |
with gr.Column():
|
1158 |
with gr.Tab(label="URL") as UrlTab:
|
1159 |
inputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
@@ -1164,14 +1179,14 @@ def create_ui(app_config: ApplicationConfig):
|
|
1164 |
UrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [sourceInput] )
|
1165 |
UploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [sourceInput] )
|
1166 |
MicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [sourceInput] )
|
1167 |
-
inputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task")})
|
1168 |
with gr.Accordion("VAD options", open=False):
|
1169 |
inputDict.update(common_vad_inputs())
|
1170 |
if isFull:
|
1171 |
inputDict.update({
|
1172 |
-
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding, elem_id = "vadPadding"),
|
1173 |
-
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window, elem_id = "vadPromptWindow"),
|
1174 |
-
gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode", value=app_config.vad_initial_prompt_mode, elem_id = "vadInitialPromptMode")})
|
1175 |
with gr.Accordion("Word Timestamps options", open=False):
|
1176 |
inputDict.update(common_word_timestamps_inputs())
|
1177 |
if isFull:
|
@@ -1250,12 +1265,16 @@ def create_ui(app_config: ApplicationConfig):
|
|
1250 |
with gr.Tab(label="seamless") as seamlessTab:
|
1251 |
with gr.Row():
|
1252 |
inputDict.update(common_seamless_inputs())
|
|
|
|
|
|
|
1253 |
m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
|
1254 |
nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
|
1255 |
mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
|
1256 |
almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
|
1257 |
madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
|
1258 |
seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
|
|
|
1259 |
with gr.Column():
|
1260 |
inputDict.update({
|
1261 |
gr.Dropdown(label="Input - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="inputLangName"),
|
|
|
921 |
madlad400LangName: str = dataDict.pop("madlad400LangName")
|
922 |
seamlessModelName: str = dataDict.pop("seamlessModelName")
|
923 |
seamlessLangName: str = dataDict.pop("seamlessLangName")
|
924 |
+
LlamaModelName: str = dataDict.pop("LlamaModelName")
|
925 |
+
LlamaLangName: str = dataDict.pop("LlamaLangName")
|
926 |
|
927 |
translationBatchSize: int = dataDict.pop("translationBatchSize")
|
928 |
translationNoRepeatNgramSize: int = dataDict.pop("translationNoRepeatNgramSize")
|
|
|
956 |
selectedModelName = seamlessModelName if seamlessModelName is not None and len(seamlessModelName) > 0 else "seamless-m4t-v2-large/facebook"
|
957 |
selectedModel = next((modelConfig for modelConfig in self.app_config.models["seamless"] if modelConfig.name == selectedModelName), None)
|
958 |
translationLang = get_lang_from_seamlessT_Tx_name(seamlessLangName)
|
959 |
+
elif translateInput == "Llama" and LlamaLangName is not None and len(LlamaLangName) > 0:
|
960 |
+
selectedModelName = LlamaModelName if LlamaModelName is not None and len(LlamaModelName) > 0 else "Meta-Llama-3-8B-Instruct-ct2-int8_float16/avan"
|
961 |
+
selectedModel = next((modelConfig for modelConfig in self.app_config.models["Llama"] if modelConfig.name == selectedModelName), None)
|
962 |
+
translationLang = get_lang_from_m2m100_name(LlamaLangName)
|
963 |
|
964 |
if translationLang is not None:
|
965 |
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=inputLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams, torchDtypeFloat16=translationTorchDtypeFloat16, usingBitsandbytes=translationUsingBitsandbytes)
|
|
|
1029 |
ALMA_models = app_config.get_model_names("ALMA")
|
1030 |
madlad400_models = app_config.get_model_names("madlad400")
|
1031 |
seamless_models = app_config.get_model_names("seamless")
|
1032 |
+
Llama_models = app_config.get_model_names("Llama")
|
1033 |
if not torch.cuda.is_available(): # Loading only quantized or models with medium-low parameters in an environment without GPU support.
|
1034 |
nllb_models = list(filter(lambda nllb: any(name in nllb for name in ["-600M", "-1.3B", "-3.3B-ct2"]), nllb_models))
|
1035 |
m2m100_models = list(filter(lambda m2m100: "12B" not in m2m100, m2m100_models))
|
|
|
1064 |
gr.Dropdown(label="seamless - Model (for translate)", choices=seamless_models, elem_id="seamlessModelName"),
|
1065 |
gr.Dropdown(label="seamless - Language", choices=sorted(get_lang_seamlessT_Tx_names()), elem_id="seamlessLangName"),
|
1066 |
}
|
1067 |
+
common_Llama_inputs = lambda : {
|
1068 |
+
gr.Dropdown(label="Llama - Model (for translate)", choices=Llama_models, elem_id="LlamaModelName"),
|
1069 |
+
gr.Dropdown(label="Llama - Language", choices=sorted(get_lang_m2m100_names()), elem_id="LlamaLangName"),
|
1070 |
+
}
|
1071 |
|
1072 |
common_translation_inputs = lambda : {
|
1073 |
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
1074 |
+
gr.Number(label="Translation - No Repeat Ngram Size", precision=0, value=app_config.translation_no_repeat_ngram_size, elem_id="translationNoRepeatNgramSize", info="Prevent repetitions of ngrams with this size (set 0 to disable)."),
|
1075 |
+
gr.Number(label="Translation - Num Beams", precision=0, value=app_config.translation_num_beams, elem_id="translationNumBeams", info="Beam size (1 for greedy search)."),
|
1076 |
gr.Checkbox(label="Translation - Torch Dtype float16", visible=torch.cuda.is_available(), value=app_config.translation_torch_dtype_float16, info="Load the float32 translation model with float16 when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationTorchDtypeFloat16"),
|
1077 |
gr.Radio(label="Translation - Using Bitsandbytes", visible=torch.cuda.is_available(), choices=[None, "int8", "int4"], value=app_config.translation_using_bitsandbytes, info="Load the float32 translation model into mixed-8bit or 4bit precision quantized model when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationUsingBitsandbytes"),
|
1078 |
}
|
1079 |
|
1080 |
common_vad_inputs = lambda : {
|
1081 |
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
|
1082 |
+
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow", info="If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged."),
|
1083 |
+
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize", info="Disables merging of adjacent speech sections if they are this number of seconds long."),
|
1084 |
+
gr.Number(label="VAD - Process Timeout (s)", precision=0, value=app_config.vad_process_timeout, elem_id="vadPocessTimeout", info="This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory. The default value is 30 minutes."),
|
1085 |
}
|
1086 |
|
1087 |
common_word_timestamps_inputs = lambda : {
|
|
|
1159 |
with gr.Tab(label="seamless") as seamlessTab:
|
1160 |
with gr.Row():
|
1161 |
inputDict.update(common_seamless_inputs())
|
1162 |
+
with gr.Tab(label="Llama") as llamaTab:
|
1163 |
+
with gr.Row():
|
1164 |
+
inputDict.update(common_Llama_inputs())
|
1165 |
m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
|
1166 |
nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
|
1167 |
mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
|
1168 |
almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
|
1169 |
madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
|
1170 |
seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
|
1171 |
+
llamaTab.select(fn=lambda: "Llama", inputs = [], outputs= [translateInput] )
|
1172 |
with gr.Column():
|
1173 |
with gr.Tab(label="URL") as UrlTab:
|
1174 |
inputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
|
|
1179 |
UrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [sourceInput] )
|
1180 |
UploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [sourceInput] )
|
1181 |
MicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [sourceInput] )
|
1182 |
+
inputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task", info="Select the task - either \"transcribe\" to transcribe the audio to text, or \"translate\" to translate it to English.")})
|
1183 |
with gr.Accordion("VAD options", open=False):
|
1184 |
inputDict.update(common_vad_inputs())
|
1185 |
if isFull:
|
1186 |
inputDict.update({
|
1187 |
+
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding, elem_id = "vadPadding", info="The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp to each transcribed line. The default value is 1 second."),
|
1188 |
+
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window, elem_id = "vadPromptWindow", info="The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds)."),
|
1189 |
+
gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode", value=app_config.vad_initial_prompt_mode, elem_id = "vadInitialPromptMode", info="prepend_all_segments: prepend the initial prompt to each VAD segment, prepend_first_segment: just the first segment")})
|
1190 |
with gr.Accordion("Word Timestamps options", open=False):
|
1191 |
inputDict.update(common_word_timestamps_inputs())
|
1192 |
if isFull:
|
|
|
1265 |
with gr.Tab(label="seamless") as seamlessTab:
|
1266 |
with gr.Row():
|
1267 |
inputDict.update(common_seamless_inputs())
|
1268 |
+
with gr.Tab(label="Llama") as llamaTab:
|
1269 |
+
with gr.Row():
|
1270 |
+
inputDict.update(common_Llama_inputs())
|
1271 |
m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
|
1272 |
nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
|
1273 |
mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
|
1274 |
almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
|
1275 |
madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
|
1276 |
seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
|
1277 |
+
llamaTab.select(fn=lambda: "Llama", inputs = [], outputs= [translateInput] )
|
1278 |
with gr.Column():
|
1279 |
inputDict.update({
|
1280 |
gr.Dropdown(label="Input - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="inputLangName"),
|
config.json5
CHANGED
@@ -292,6 +292,14 @@
|
|
292 |
"url": "facebook/seamless-m4t-v2-large",
|
293 |
"type": "huggingface"
|
294 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
]
|
296 |
},
|
297 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
|
|
292 |
"url": "facebook/seamless-m4t-v2-large",
|
293 |
"type": "huggingface"
|
294 |
}
|
295 |
+
],
|
296 |
+
"Llama": [
|
297 |
+
{
|
298 |
+
"name": "Meta-Llama-3-8B-Instruct-ct2-int8_float16/avan",
|
299 |
+
"url": "avans06/Meta-Llama-3-8B-Instruct-ct2-int8_float16",
|
300 |
+
"type": "huggingface",
|
301 |
+
"tokenizer_url": "avans06/Meta-Llama-3-8B-Instruct-ct2-int8_float16"
|
302 |
+
}
|
303 |
]
|
304 |
},
|
305 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
src/config.py
CHANGED
@@ -50,7 +50,7 @@ class VadInitialPromptMode(Enum):
|
|
50 |
return None
|
51 |
|
52 |
class ApplicationConfig:
|
53 |
-
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless"], List[ModelConfig]],
|
54 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
55 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
56 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
@@ -185,7 +185,7 @@ class ApplicationConfig:
|
|
185 |
# Load using json5
|
186 |
data = json5.load(f)
|
187 |
data_models = data.pop("models", [])
|
188 |
-
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless"], List[ModelConfig]] = {
|
189 |
key: [ModelConfig(**item) for item in value]
|
190 |
for key, value in data_models.items()
|
191 |
}
|
|
|
50 |
return None
|
51 |
|
52 |
class ApplicationConfig:
|
53 |
+
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless", "Llama"], List[ModelConfig]],
|
54 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
55 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
56 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
|
|
185 |
# Load using json5
|
186 |
data = json5.load(f)
|
187 |
data_models = data.pop("models", [])
|
188 |
+
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless", "Llama"], List[ModelConfig]] = {
|
189 |
key: [ModelConfig(**item) for item in value]
|
190 |
for key, value in data_models.items()
|
191 |
}
|
src/translation/translationModel.py
CHANGED
@@ -27,7 +27,7 @@ class TranslationModel:
|
|
27 |
localFilesOnly: bool = False,
|
28 |
loadModel: bool = False,
|
29 |
):
|
30 |
-
"""Initializes the M2M100 / Nllb-200 / mt5 / ALMA / madlad400 / seamless-m4t translation model.
|
31 |
|
32 |
Args:
|
33 |
modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
|
@@ -230,6 +230,9 @@ class TranslationModel:
|
|
230 |
if "ALMA" in self.modelPath:
|
231 |
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
|
232 |
self.transModel = ctranslate2.Generator(**kwargsModel)
|
|
|
|
|
|
|
233 |
else:
|
234 |
if "nllb" in self.modelPath:
|
235 |
kwargsTokenizer.update({"src_lang": self.whisperLang.nllb.code})
|
@@ -243,6 +246,8 @@ class TranslationModel:
|
|
243 |
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
|
244 |
if "m2m100" in self.modelPath:
|
245 |
self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
|
|
|
|
|
246 |
elif "mt5" in self.modelPath:
|
247 |
self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
|
248 |
kwargsTokenizer.update({"pretrained_model_name_or_path": self.modelPath, "legacy": False})
|
@@ -382,6 +387,12 @@ class TranslationModel:
|
|
382 |
output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
|
383 |
target = output[0]
|
384 |
result = self.transTokenizer.decode(target.sequences_ids[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
elif "madlad400" in self.modelPath:
|
386 |
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
|
387 |
output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
|
@@ -424,7 +435,8 @@ _MODELS = ["nllb-200",
|
|
424 |
"mt5",
|
425 |
"ALMA",
|
426 |
"madlad400",
|
427 |
-
"seamless"
|
|
|
428 |
|
429 |
def check_model_name(name):
|
430 |
return any(allowed_name in name for allowed_name in _MODELS)
|
|
|
27 |
localFilesOnly: bool = False,
|
28 |
loadModel: bool = False,
|
29 |
):
|
30 |
+
"""Initializes the M2M100 / Nllb-200 / mt5 / ALMA / madlad400 / seamless-m4t / Llama translation model.
|
31 |
|
32 |
Args:
|
33 |
modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
|
|
|
230 |
if "ALMA" in self.modelPath:
|
231 |
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
|
232 |
self.transModel = ctranslate2.Generator(**kwargsModel)
|
233 |
+
elif "Llama" in self.modelPath:
|
234 |
+
self.roleSystem = {"role": "system", "content":"You are an excellent and professional translation master who understands languages from all around the world. Please directly translate the following sentence from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ", please simply provide the translation below without further explanation and without using any emojis."}
|
235 |
+
self.transModel = ctranslate2.Generator(**kwargsModel)
|
236 |
else:
|
237 |
if "nllb" in self.modelPath:
|
238 |
kwargsTokenizer.update({"src_lang": self.whisperLang.nllb.code})
|
|
|
246 |
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
|
247 |
if "m2m100" in self.modelPath:
|
248 |
self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
|
249 |
+
elif "Llama" in self.modelPath:
|
250 |
+
self.terminators = [self.transTokenizer.eos_token_id, self.transTokenizer.convert_tokens_to_ids("<|eot_id|>")]
|
251 |
elif "mt5" in self.modelPath:
|
252 |
self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
|
253 |
kwargsTokenizer.update({"pretrained_model_name_or_path": self.modelPath, "legacy": False})
|
|
|
387 |
output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
|
388 |
target = output[0]
|
389 |
result = self.transTokenizer.decode(target.sequences_ids[0])
|
390 |
+
elif "Llama" in self.modelPath:
|
391 |
+
input_ids = self.transTokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": "'" + text + "', \n" + self.translationLang.whisper.names[0] + ":"}], tokenize=False, add_generation_prompt=True)
|
392 |
+
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(input_ids))
|
393 |
+
output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
|
394 |
+
target = output[0]
|
395 |
+
result = self.transTokenizer.decode(target.sequences_ids[0])
|
396 |
elif "madlad400" in self.modelPath:
|
397 |
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
|
398 |
output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
|
|
|
435 |
"mt5",
|
436 |
"ALMA",
|
437 |
"madlad400",
|
438 |
+
"seamless"
|
439 |
+
"Llama"]
|
440 |
|
441 |
def check_model_name(name):
|
442 |
return any(allowed_name in name for allowed_name in _MODELS)
|