import argparse
import os
import re

import numpy as np
import torch
from tqdm import tqdm
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# Upper bound (in characters) used when packing consecutive segments into a
# single translate() call.
MAX_LENGTH = 500
# Marker substituted for every "[SPEAKER_xx]:" tag so the segments can be
# split apart again after translation.
MAGIC_STRING = "[$&]"
# When True, the intermediate raw texts are also written to disk.
DEBUG = False

# Maps a language name to the code used for transcription ("transcriber") and
# the full code used for translation ("translator"). Only languages whose
# short code appears in WHISPER_LANGUAGES are kept.
language_dict = {}
for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
    # Short code = text before the first underscore, lower-cased.
    lang_code = language_code.partition('_')[0].lower()
    if lang_code not in WHISPER_LANGUAGES:
        continue
    language_dict[language_name] = {
        "transcriber": lang_code,
        "translator": language_code,
    }
def translate(transcribed_text, source_languaje, target_languaje, translate_model, translate_tokenizer, device="cpu"):
    """Translate ``transcribed_text`` from one language to another.

    Args:
        transcribed_text: Text to translate.
        source_languaje: Language name of the input; must be a key of
            ``language_dict``.
        target_languaje: Language name of the output; must be a key of
            ``language_dict``.
        translate_model: Model exposing ``generate`` (the script loads an
            ``MBartForConditionalGeneration``).
        translate_tokenizer: Matching tokenizer (the script loads an
            ``MBart50TokenizerFast``). Its ``src_lang`` attribute is updated
            as a side effect.
        device: Device the encoded inputs are moved to.

    Returns:
        The first decoded sequence, with special tokens removed.

    Raises:
        KeyError: If either language name is missing from ``language_dict``
            or the target code is unknown to the tokenizer.
    """
    source_languaje_code = language_dict[source_languaje]["translator"]
    target_languaje_code = language_dict[target_languaje]["translator"]

    # The source language has to be set on the tokenizer *before* encoding;
    # otherwise the text is tagged with whatever language the tokenizer was
    # last configured with (its default on a fresh tokenizer), regardless of
    # the requested source language.
    translate_tokenizer.src_lang = source_languaje_code

    encoded = translate_tokenizer(transcribed_text, return_tensors="pt").to(device)
    generated_tokens = translate_model.generate(
        **encoded,
        # Forces the decoder to start with the target-language token.
        forced_bos_token_id=translate_tokenizer.lang_code_to_id[target_languaje_code],
    )
    return translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# Patterns for the lines of an SRT file handled by main(); compiled once
# instead of on every line.
_SRT_INDEX_RE = re.compile(r"\d+$")
_SRT_TIMESTAMP_RE = re.compile(r"\d\d:\d\d:\d\d,\d+ --> \d\d:\d\d:\d\d,\d+")
_SPEAKER_TAG_RE = re.compile(r"\[SPEAKER_\d\d\]:")


def _is_structural_line(line):
    """Return True for cue numbers, timestamp lines and empty lines."""
    return bool(
        _SRT_INDEX_RE.match(line)
        or _SRT_TIMESTAMP_RE.match(line)
        or line == ""
    )


def _collect_text(lines):
    """Join all text lines, replacing each speaker tag with MAGIC_STRING."""
    parts = []
    for line in tqdm(lines, desc='Concatenate transcriptions progress'):
        if _is_structural_line(line):
            continue
        parts.append(_SPEAKER_TAG_RE.sub(MAGIC_STRING, line) + " ")
    return "".join(parts)


def _translate_in_chunks(segments, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Translate segments, packing consecutive ones into chunks under MAX_LENGTH characters."""
    if not segments:
        # Nothing to translate (the input had no text lines).
        return ""

    translations = []
    chunk = segments[0] + MAGIC_STRING
    progress_bar = tqdm(total=len(segments), desc='Translate transcriptions progress')
    progress_bar.update(1)
    for segment in segments[1:]:
        if len(chunk) + len(segment) < MAX_LENGTH:
            chunk += segment + MAGIC_STRING
        else:
            # Chunk is full: translate it and start a new one with this segment.
            translations.append(
                translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
            )
            chunk = segment + MAGIC_STRING
        progress_bar.update(1)
    # Translate whatever is left in the last (possibly only) chunk.
    translations.append(
        translate(chunk, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
    )
    progress_bar.close()
    return "".join(translations)


def _rebuild_srt(lines, translated_segments):
    """Re-attach cue numbers, timestamps and speaker tags to the translated segments."""
    output = []
    progress_bar = tqdm(
        total=len(translated_segments),
        desc='Add time stamps to translated transcriptions progress',
    )
    i = 0
    for line in lines:
        if _is_structural_line(line):
            output.append(f"{line}\n")
            continue
        # Every text line consumes the next translated segment. Text lines
        # beyond the available segments are not written.
        # NOTE(review): this alignment presumably relies on each cue having a
        # single text line and on the model preserving MAGIC_STRING - confirm.
        if i < len(translated_segments):
            segment = translated_segments[i]
            if segment.startswith(" "):
                segment = segment[1:]  # drop the single separator space
            tag = _SPEAKER_TAG_RE.match(line)
            speaker = tag.group(0) if tag else ""
            output.append(f"{speaker} {segment}\n")
            progress_bar.update(1)
        i += 1
    progress_bar.close()
    return "".join(output)


def main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Translate the text of an SRT transcription and write a translated SRT.

    The text lines are concatenated (speaker tags replaced by MAGIC_STRING),
    translated in chunks, split again on MAGIC_STRING and written back next to
    the original cue numbers and timestamps.

    Args:
        transcription_file: Path of the SRT file to translate.
        source_languaje: Language name of the transcription.
        target_languaje: Language name to translate into; also used in the
            output file name.
        translate_model: Model forwarded to ``translate``.
        translate_tokenizer: Tokenizer forwarded to ``translate``.
        device: Device forwarded to ``translate``.

    Side effects:
        Writes ``translated_transcriptions/<name>_<target_languaje>.srt`` and,
        when DEBUG is true, two ``*_raw.srt`` files with the intermediate text.
    """
    output_folder = "translated_transcriptions"
    # The folder may not exist yet on a fresh checkout.
    os.makedirs(output_folder, exist_ok=True)

    # basename/splitext cope with any directory depth and with dots in the
    # file name, unlike unpacking the result of str.split.
    transcription_file_name = os.path.splitext(os.path.basename(transcription_file))[0]

    # Explicit UTF-8 so non-ASCII subtitles do not depend on the platform's
    # default encoding.
    with open(transcription_file, "r", encoding="utf-8") as f:
        transcription = f.read().splitlines()

    raw_transcription = _collect_text(transcription)

    if DEBUG:
        output_file = f"{output_folder}/{transcription_file_name}_raw.srt"
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(raw_transcription)

    raw_transcription_list = raw_transcription.split(MAGIC_STRING)
    # Text starting with a speaker tag yields an empty first element.
    if raw_transcription_list[0] == "":
        raw_transcription_list = raw_transcription_list[1:]

    translated_transcription = _translate_in_chunks(
        raw_transcription_list,
        source_languaje,
        target_languaje,
        translate_model,
        translate_tokenizer,
        device,
    )

    if DEBUG:
        output_file = f"{output_folder}/{transcription_file_name}_{target_languaje}_raw.srt"
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(translated_transcription)

    translated_transcription_time_stamps = _rebuild_srt(
        transcription,
        translated_transcription.split(MAGIC_STRING),
    )

    output_file = f"{output_folder}/{transcription_file_name}_{target_languaje}.srt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(translated_transcription_time_stamps)
if __name__ == "__main__":
    # Command line: one positional SRT path, two required language names and
    # an optional device (defaults to "cpu").
    parser = argparse.ArgumentParser()
    parser.add_argument("transcription_file", help="Transcribed text")
    for language_option in ("--source_languaje", "--target_languaje"):
        parser.add_argument(language_option, type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    # Model and tokenizer are loaded from the same checkpoint; the model is
    # moved to the requested device before the tokenizer is loaded.
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    print("Loading translation model")
    translate_model = MBartForConditionalGeneration.from_pretrained(checkpoint).to(args.device)
    translate_tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
    print("Translation model loaded")

    main(
        args.transcription_file,
        args.source_languaje,
        args.target_languaje,
        translate_model,
        translate_tokenizer,
        args.device,
    )