Spaces:
Running
Running
SoybeanMilk
commited on
Commit
•
4e2f72e
1
Parent(s):
2def7a1
Add madlad400 support.
Browse files- app.py +19 -0
- config.json5 +14 -0
- src/config.py +2 -2
- src/translation/translationModel.py +11 -1
app.py
CHANGED
@@ -233,6 +233,8 @@ class WhisperTranscriber:
|
|
233 |
mt5LangName: str = decodeOptions.pop("mt5LangName")
|
234 |
ALMAModelName: str = decodeOptions.pop("ALMAModelName")
|
235 |
ALMALangName: str = decodeOptions.pop("ALMALangName")
|
|
|
|
|
236 |
|
237 |
translationBatchSize: int = decodeOptions.pop("translationBatchSize")
|
238 |
translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
|
@@ -368,6 +370,10 @@ class WhisperTranscriber:
|
|
368 |
selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
|
369 |
selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
|
370 |
translationLang = get_lang_from_m2m100_name(ALMALangName)
|
|
|
|
|
|
|
|
|
371 |
|
372 |
if translationLang is not None:
|
373 |
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
|
@@ -929,6 +935,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
929 |
m2m100_models = app_config.get_model_names("m2m100")
|
930 |
mt5_models = app_config.get_model_names("mt5")
|
931 |
ALMA_models = app_config.get_model_names("ALMA")
|
|
|
932 |
if not torch.cuda.is_available(): #Due to the poor support of GPTQ for CPUs, the execution time per iteration exceeds a thousand seconds when operating on a CPU. Therefore, when the system does not support a GPU, the GPTQ model is removed from the list.
|
933 |
ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))
|
934 |
|
@@ -952,6 +959,10 @@ def create_ui(app_config: ApplicationConfig):
|
|
952 |
gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
|
953 |
gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
|
954 |
}
|
|
|
|
|
|
|
|
|
955 |
|
956 |
common_translation_inputs = lambda : {
|
957 |
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
@@ -1036,10 +1047,14 @@ def create_ui(app_config: ApplicationConfig):
|
|
1036 |
with gr.Tab(label="ALMA") as simpleALMATab:
|
1037 |
with gr.Row():
|
1038 |
simpleInputDict.update(common_ALMA_inputs())
|
|
|
|
|
|
|
1039 |
simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
|
1040 |
simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
|
1041 |
simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
|
1042 |
simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
|
|
|
1043 |
with gr.Column():
|
1044 |
with gr.Tab(label="URL") as simpleUrlTab:
|
1045 |
simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
@@ -1103,10 +1118,14 @@ def create_ui(app_config: ApplicationConfig):
|
|
1103 |
with gr.Tab(label="ALMA") as fullALMATab:
|
1104 |
with gr.Row():
|
1105 |
fullInputDict.update(common_ALMA_inputs())
|
|
|
|
|
|
|
1106 |
fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
|
1107 |
fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
|
1108 |
fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
|
1109 |
fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
|
|
|
1110 |
with gr.Column():
|
1111 |
with gr.Tab(label="URL") as fullUrlTab:
|
1112 |
fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
|
|
233 |
mt5LangName: str = decodeOptions.pop("mt5LangName")
|
234 |
ALMAModelName: str = decodeOptions.pop("ALMAModelName")
|
235 |
ALMALangName: str = decodeOptions.pop("ALMALangName")
|
236 |
+
madlad400ModelName: str = decodeOptions.pop("madlad400ModelName")
|
237 |
+
madlad400LangName: str = decodeOptions.pop("madlad400LangName")
|
238 |
|
239 |
translationBatchSize: int = decodeOptions.pop("translationBatchSize")
|
240 |
translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
|
|
|
370 |
selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
|
371 |
selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
|
372 |
translationLang = get_lang_from_m2m100_name(ALMALangName)
|
373 |
+
elif translateInput == "madlad400" and madlad400LangName is not None and len(madlad400LangName) > 0:
|
374 |
+
selectedModelName = madlad400ModelName if madlad400ModelName is not None and len(madlad400ModelName) > 0 else "madlad400-10b-mt-ct2-int8_float16"
|
375 |
+
selectedModel = next((modelConfig for modelConfig in self.app_config.models["madlad400"] if modelConfig.name == selectedModelName), None)
|
376 |
+
translationLang = get_lang_from_m2m100_name(madlad400LangName)
|
377 |
|
378 |
if translationLang is not None:
|
379 |
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
|
|
|
935 |
m2m100_models = app_config.get_model_names("m2m100")
|
936 |
mt5_models = app_config.get_model_names("mt5")
|
937 |
ALMA_models = app_config.get_model_names("ALMA")
|
938 |
+
madlad400_models = app_config.get_model_names("madlad400")
|
939 |
if not torch.cuda.is_available(): #Due to the poor support of GPTQ for CPUs, the execution time per iteration exceeds a thousand seconds when operating on a CPU. Therefore, when the system does not support a GPU, the GPTQ model is removed from the list.
|
940 |
ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))
|
941 |
|
|
|
959 |
gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
|
960 |
gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
|
961 |
}
|
962 |
+
common_madlad400_inputs = lambda : {
|
963 |
+
gr.Dropdown(label="madlad400 - Model (for translate)", choices=madlad400_models, elem_id="madlad400ModelName"),
|
964 |
+
gr.Dropdown(label="madlad400 - Language", choices=sorted(get_lang_m2m100_names()), elem_id="madlad400LangName"),
|
965 |
+
}
|
966 |
|
967 |
common_translation_inputs = lambda : {
|
968 |
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
|
|
1047 |
with gr.Tab(label="ALMA") as simpleALMATab:
|
1048 |
with gr.Row():
|
1049 |
simpleInputDict.update(common_ALMA_inputs())
|
1050 |
+
with gr.Tab(label="madlad400") as simplemadlad400Tab:
|
1051 |
+
with gr.Row():
|
1052 |
+
simpleInputDict.update(common_madlad400_inputs())
|
1053 |
simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
|
1054 |
simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
|
1055 |
simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
|
1056 |
simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
|
1057 |
+
simplemadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [simpleTranslateInput] )
|
1058 |
with gr.Column():
|
1059 |
with gr.Tab(label="URL") as simpleUrlTab:
|
1060 |
simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
|
|
1118 |
with gr.Tab(label="ALMA") as fullALMATab:
|
1119 |
with gr.Row():
|
1120 |
fullInputDict.update(common_ALMA_inputs())
|
1121 |
+
with gr.Tab(label="madlad400") as fullmadlad400Tab:
|
1122 |
+
with gr.Row():
|
1123 |
+
fullInputDict.update(common_madlad400_inputs())
|
1124 |
fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
|
1125 |
fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
|
1126 |
fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
|
1127 |
fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
|
1128 |
+
fullmadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [fullTranslateInput] )
|
1129 |
with gr.Column():
|
1130 |
with gr.Tab(label="URL") as fullUrlTab:
|
1131 |
fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
config.json5
CHANGED
@@ -229,6 +229,20 @@
|
|
229 |
"type": "huggingface",
|
230 |
"tokenizer_url": "haoranxu/ALMA-13B"
|
231 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
]
|
233 |
},
|
234 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
|
|
229 |
"type": "huggingface",
|
230 |
"tokenizer_url": "haoranxu/ALMA-13B"
|
231 |
},
|
232 |
+
],
|
233 |
+
"madlad400": [
|
234 |
+
{
|
235 |
+
"name": "madlad400-3b-mt-ct2-int8_float16/SoybeanMilk",
|
236 |
+
"url": "SoybeanMilk/madlad400-3b-mt-ct2-int8_float16",
|
237 |
+
"type": "huggingface",
|
238 |
+
"tokenizer_url": "jbochi/madlad400-3b-mt"
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"name": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
|
242 |
+
"url": "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16",
|
243 |
+
"type": "huggingface",
|
244 |
+
"tokenizer_url": "jbochi/madlad400-10b-mt"
|
245 |
+
},
|
246 |
]
|
247 |
},
|
248 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
src/config.py
CHANGED
@@ -50,7 +50,7 @@ class VadInitialPromptMode(Enum):
|
|
50 |
return None
|
51 |
|
52 |
class ApplicationConfig:
|
53 |
-
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
|
54 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
55 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
56 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
@@ -181,7 +181,7 @@ class ApplicationConfig:
|
|
181 |
# Load using json5
|
182 |
data = json5.load(f)
|
183 |
data_models = data.pop("models", [])
|
184 |
-
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
|
185 |
key: [ModelConfig(**item) for item in value]
|
186 |
for key, value in data_models.items()
|
187 |
}
|
|
|
50 |
return None
|
51 |
|
52 |
class ApplicationConfig:
|
53 |
+
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]],
|
54 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
55 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
56 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
|
|
181 |
# Load using json5
|
182 |
data = json5.load(f)
|
183 |
data_models = data.pop("models", [])
|
184 |
+
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400"], List[ModelConfig]] = {
|
185 |
key: [ModelConfig(**item) for item in value]
|
186 |
for key, value in data_models.items()
|
187 |
}
|
src/translation/translationModel.py
CHANGED
@@ -159,6 +159,10 @@ class TranslationModel:
|
|
159 |
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
|
160 |
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
|
161 |
self.transModel = ctranslate2.Generator(self.modelPath, compute_type="auto", device=self.device)
|
|
|
|
|
|
|
|
|
162 |
elif "mt5" in self.modelPath:
|
163 |
self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
|
164 |
self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
|
@@ -277,6 +281,11 @@ class TranslationModel:
|
|
277 |
output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
|
278 |
target = output[0]
|
279 |
result = self.transTokenizer.decode(target.sequences_ids[0])
|
|
|
|
|
|
|
|
|
|
|
280 |
elif "mt5" in self.modelPath:
|
281 |
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
282 |
result = output[0]['generated_text']
|
@@ -299,7 +308,8 @@ class TranslationModel:
|
|
299 |
_MODELS = ["nllb-200",
|
300 |
"m2m100",
|
301 |
"mt5",
|
302 |
-
"ALMA"
|
|
|
303 |
|
304 |
def check_model_name(name):
|
305 |
return any(allowed_name in name for allowed_name in _MODELS)
|
|
|
159 |
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
|
160 |
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
|
161 |
self.transModel = ctranslate2.Generator(self.modelPath, compute_type="auto", device=self.device)
|
162 |
+
elif "madlad400" in self.modelPath:
|
163 |
+
self.madlad400Prefix = "<2" + self.translationLang.whisper.code + "> "
|
164 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
|
165 |
+
self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
|
166 |
elif "mt5" in self.modelPath:
|
167 |
self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
|
168 |
self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
|
|
|
281 |
output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
|
282 |
target = output[0]
|
283 |
result = self.transTokenizer.decode(target.sequences_ids[0])
|
284 |
+
elif "madlad400" in self.modelPath:
|
285 |
+
source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
|
286 |
+
output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
|
287 |
+
target = output[0].hypotheses[0]
|
288 |
+
result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
|
289 |
elif "mt5" in self.modelPath:
|
290 |
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
291 |
result = output[0]['generated_text']
|
|
|
308 |
_MODELS = ["nllb-200",
|
309 |
"m2m100",
|
310 |
"mt5",
|
311 |
+
"ALMA",
|
312 |
+
"madlad400"]
|
313 |
|
314 |
def check_model_name(name):
|
315 |
return any(allowed_name in name for allowed_name in _MODELS)
|