|
import torch |
|
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast |
|
from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES |
|
import argparse |
|
import re |
|
|
|
language_dict = {} |
|
|
|
for language_name, language_code in LANGUAGE_NAME_TO_CODE.items(): |
|
|
|
lang_code = language_code.split('_')[0].lower() |
|
|
|
|
|
if lang_code in WHISPER_LANGUAGES: |
|
|
|
language_dict[language_name] = { |
|
"transcriber": lang_code, |
|
"translator": language_code |
|
} |
|
|
|
|
|
|
|
def translate(transcribed_text, source_languaje, target_languaje, translate_model, translate_tokenizer, device="cpu"): |
|
|
|
source_languaje_code = language_dict[source_languaje]["translator"] |
|
target_languaje_code = language_dict[target_languaje]["translator"] |
|
|
|
encoded = translate_tokenizer(transcribed_text, return_tensors="pt").to(device) |
|
generated_tokens = translate_model.generate( |
|
**encoded, |
|
forced_bos_token_id=translate_tokenizer.lang_code_to_id[target_languaje_code] |
|
) |
|
translated = translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] |
|
|
|
return translated |
|
|
|
def main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device): |
|
output_folder = "translated_transcriptions" |
|
_, transcription_file_name = transcription_file.split("/") |
|
transcription_file_name, _ = transcription_file_name.split(".") |
|
|
|
|
|
with open(transcription_file, "r") as f: |
|
transcription = f.read().splitlines() |
|
|
|
|
|
translate_transcription = "" |
|
for line in transcription: |
|
if re.match(r"\d+$", line): |
|
translate_transcription += f"{line}\n" |
|
elif re.match(r"\d\d:\d\d:\d\d,\d\d\d --> \d\d:\d\d:\d\d,\d\d\d", line): |
|
translate_transcription += f"{line}\n" |
|
elif re.match(r"^$", line): |
|
translate_transcription += f"{line}\n" |
|
else: |
|
translated = translate(line, source_languaje, target_languaje, translate_model, translate_tokenizer, device) |
|
|
|
translate_transcription += f"{translated}\n" |
|
|
|
|
|
output_file = f"{output_folder}/{transcription_file_name}_{target_languaje}.srt" |
|
with open(output_file, "w") as f: |
|
f.write(translate_transcription) |
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("transcription_file", help="Transcribed text") |
|
parser.add_argument("--source_languaje", type=str, required=True) |
|
parser.add_argument("--target_languaje", type=str, required=True) |
|
parser.add_argument("--device", type=str, default="cpu") |
|
args = parser.parse_args() |
|
|
|
transcription_file = args.transcription_file |
|
source_languaje = args.source_languaje |
|
target_languaje = args.target_languaje |
|
device = args.device |
|
|
|
|
|
print("Loading translation model") |
|
translate_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt").to(device) |
|
translate_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt") |
|
print("Translation model loaded") |
|
|
|
main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device) |
|
|