Spaces:
Running
Running
import threading | |
import time | |
import openai | |
from pytube import YouTube | |
from os import getenv, getcwd | |
from pathlib import Path | |
from enum import Enum, auto | |
import logging | |
import subprocess | |
from src.srt_util.srt import SrtScript | |
from src.srt_util.srt2ass import srt2ass | |
from time import time, strftime, gmtime, sleep | |
from src.translators.translation import get_translation, prompt_selector | |
import torch | |
import stable_whisper | |
import shutil | |
"""
Supported task inputs and their parameters:

YouTube link
- link
- model
- output type
Video file
- path
- model
- output type
Audio file
- path
- model
- output type
"""
"""
Task overview:

TaskID
Progress: Enum
Computing resource status
SRT_Script: SrtScript
- Input module -> initialize (ASR module)
- Pre-process
- Translation (%)
- Post-process (timestamps)
- Output module: SRT_Script --> output (.srt)
- (Optional) mp4
"""
class TaskStatus(str, Enum):
    """Pipeline stage that a Task is currently in.

    Every member is also a ``str`` whose value equals the member's name, so a
    status compares equal to its plain-string spelling
    (``TaskStatus.CREATED == "CREATED"``).
    """

    # Make ``auto()`` yield the member's own name as its value. Enum requires
    # this hook to be defined before the first member that uses ``auto()``.
    def _generate_next_value_(name, start, count, last_values):
        return name

    CREATED = auto()           # set when the Task object is constructed
    INITIALIZING_ASR = auto()  # speech-recognition stage
    PRE_PROCESSING = auto()    # SRT pre-processing stage
    TRANSLATING = auto()       # translation stage
    POST_PROCESSING = auto()   # SRT post-processing stage
    OUTPUT_MODULE = auto()     # subtitle / video rendering stage
class Task:
    """A single subtitle-generation job.

    ``run_pipeline`` executes five stages in order: ASR (``get_srt_class``),
    ``preprocess``, ``translation``, ``postprocess`` and ``output_render``.
    Subclasses set ``self.audio_path`` (and ``self.video_path`` when a video is
    available) before calling ``run_pipeline``. The current stage is exposed
    through the lock-protected ``status`` property.
    """

    @property
    def status(self):
        """Current ``TaskStatus`` of this task, read under the status lock."""
        with self.__status_lock:
            return self.__status

    @status.setter
    def status(self, new_status):
        with self.__status_lock:
            self.__status = new_status

    def __init__(self, task_id, task_local_dir, task_cfg):
        """Store the task configuration and log it.

        Args:
            task_id: identifier used in log lines and in output file names.
            task_local_dir: working directory of the task. Other methods call
                ``.joinpath`` on it, so a ``pathlib.Path`` is expected.
            task_cfg: configuration mapping. Keys read by this class: "ASR"
                (with "whisper_config" -> "method" / "whisper_model"),
                "translation" (with "model" and "chunk_size"), "output_type"
                (with "subtitle", "video", "bilingual"), "target_lang",
                "source_lang", "field", "pre_process" (with "sentence_form",
                "spell_check", "term_correct") and "post_process" (with
                "check_len_and_split", "remove_trans_punctuation").

        Side effect: sets ``openai.api_key`` for the whole ``openai`` module
        from the OPENAI_API_KEY environment variable.
        """
        self.__status_lock = threading.Lock()
        self.__status = TaskStatus.CREATED
        self.gpu_status = 0
        openai.api_key = getenv("OPENAI_API_KEY")
        self.task_id = task_id
        self.task_local_dir = task_local_dir
        self.ASR_setting = task_cfg["ASR"]
        self.translation_setting = task_cfg["translation"]
        self.translation_model = self.translation_setting["model"]
        self.output_type = task_cfg["output_type"]
        self.target_lang = task_cfg["target_lang"]
        self.source_lang = task_cfg["source_lang"]
        self.field = task_cfg["field"]
        self.pre_setting = task_cfg["pre_process"]
        self.post_setting = task_cfg["post_process"]

        self.audio_path = None
        # Only some subclasses have a video; output_render() checks for None.
        self.video_path = None
        self.SRT_Script = None
        self.result = None
        self.t_s = None  # pipeline start time, set by get_srt_class()
        self.s_t = None  # legacy (misspelled) attribute kept for compatibility; never updated
        self.t_e = None  # pipeline end time, set by output_render()
        self.translation_progress = None  # updated by update_translation_progress()

        print(f"Task ID: {self.task_id}")
        logging.info(f"Task ID: {self.task_id}")
        logging.info(f"{self.source_lang} -> {self.target_lang} task in {self.field}")
        logging.info(f"Translation Model: {self.translation_model}")
        logging.info(f"subtitle_type: {self.output_type['subtitle']}")
        logging.info(f"video_ouput: {self.output_type['video']}")
        logging.info(f"bilingual_ouput: {self.output_type['bilingual']}")
        logging.info("Pre-process setting:")
        for key, value in self.pre_setting.items():
            logging.info(f"{key}: {value}")
        logging.info("Post-process setting:")
        for key, value in self.post_setting.items():
            logging.info(f"{key}: {value}")

    @staticmethod
    def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
        """Create a ``YoutubeTask`` that downloads ``youtube_url`` when run."""
        logging.info("Task Creation method: Youtube Link")
        return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)

    @staticmethod
    def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
        """Create an ``AudioTask`` for an audio file already on disk."""
        logging.info("Task Creation method: Audio File")
        return AudioTask(task_id, task_dir, task_cfg, audio_path)

    @staticmethod
    def fromVideoFile(video_path, task_id, task_dir, task_cfg):
        """Create a ``VideoTask`` for a local video file."""
        logging.info("Task Creation method: Video File")
        return VideoTask(task_id, task_dir, task_cfg, video_path)

    # Module 1 ASR: audio --> SRT_script
    def get_srt_class(self):
        """Module 1 (ASR): transcribe ``self.audio_path`` into ``self.SRT_Script``.

        Starts the pipeline timer ``self.t_s``. The transcript is saved as
        ``task_<task_id>_<source_lang>.srt`` in the task directory. If that file
        already exists, transcription is skipped.

        Raises:
            ValueError: if ``ASR.whisper_config.method`` is neither "api" nor
                "stable".
        """
        # TODO: setup ASR module like translator
        self.status = TaskStatus.INITIALIZING_ASR
        self.t_s = time()
        whisper_config = self.ASR_setting["whisper_config"]
        method = whisper_config["method"]
        whisper_model = whisper_config["whisper_model"]
        src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id}_{self.source_lang}.srt")
        if Path.exists(src_srt_path):
            # NOTE(review): the existing .srt is not loaded, so self.SRT_Script
            # stays None and later stages will fail unless it was set elsewhere.
            logging.warning(f"{src_srt_path} already exists, skipping ASR")
            return

        # extract script from audio
        logging.info("extract script from audio")
        if method == "api":
            # "verbose_json" returns parsed JSON with a "segments" list, which is
            # what the code below indexes; "srt" would return plain SRT text.
            # NOTE(review): assumes SrtScript only needs the start/end/text
            # fields present in the API's segments — confirm.
            with open(self.audio_path, 'rb') as audio_file:
                transcript = openai.Audio.transcribe(model="whisper-1", file=audio_file,
                                                     response_format="verbose_json")
        elif method == "stable":
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = stable_whisper.load_model(whisper_model, device)
            result = model.transcribe(str(self.audio_path), regroup=False,
                                      initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
            # Regrouping calls are used for their effect on `result`; their
            # return values are discarded.
            (
                result
                .split_by_punctuation(['.', '。', '?'])
                .merge_by_gap(.15, max_words=3)
                .merge_by_punctuation([' '])
                .split_by_punctuation(['.', '。', '?'])
            )
            transcript = result.to_dict()
            # Drop the model reference so empty_cache() below can release its memory.
            del model
        else:
            raise ValueError(f"Unsupported ASR method: {method!r} (expected 'api' or 'stable')")

        # after get the transcript, release the gpu resource
        torch.cuda.empty_cache()
        self.SRT_Script = SrtScript(self.source_lang, self.target_lang, transcript['segments'], self.field)
        # save the srt script to local
        self.SRT_Script.write_srt_file_src(src_srt_path)

    # Module 2: SRT preprocess: perform preprocess steps
    def preprocess(self):
        """Module 2: apply the enabled pre-processing steps and save the result.

        Writes ``<task_id>_processed.srt`` to the task directory (plus an .ass
        copy when the subtitle type is "ass") and stores the source-only text
        in ``self.script_input``.
        """
        self.status = TaskStatus.PRE_PROCESSING
        logging.info("--------------------Start Preprocessing SRT class--------------------")
        if self.pre_setting["sentence_form"]:
            self.SRT_Script.form_whole_sentence()
        if self.pre_setting["spell_check"]:
            self.SRT_Script.spell_check_term()
        if self.pre_setting["term_correct"]:
            self.SRT_Script.correct_with_force_term()
        processed_srt_path_src = str(Path(self.task_local_dir) / f'{self.task_id}_processed.srt')
        self.SRT_Script.write_srt_file_src(processed_srt_path_src)

        if self.output_type["subtitle"] == "ass":
            logging.info("write English .srt file to .ass")
            assSub_src = srt2ass(processed_srt_path_src, "default", "No", "Modest")
            logging.info('ASS subtitle saved as: ' + assSub_src)
        self.script_input = self.SRT_Script.get_source_only()

    def update_translation_progress(self, new_progress):
        """Record ``new_progress`` in ``self.translation_progress``.

        The value is only recorded while the task is in the TRANSLATING stage.
        """
        if self.status == TaskStatus.TRANSLATING:
            self.translation_progress = new_progress

    # Module 3: perform srt translation
    def translation(self):
        """Module 3: run ``get_translation`` on ``self.SRT_Script``.

        The prompt is chosen for this source/target language pair and field.
        """
        # Publish the stage like the other modules do; update_translation_progress()
        # relies on it.
        self.status = TaskStatus.TRANSLATING
        logging.info("---------------------Start Translation--------------------")
        prompt = prompt_selector(self.source_lang, self.target_lang, self.field)
        get_translation(self.SRT_Script, self.translation_model, self.task_id, prompt, self.translation_setting['chunk_size'])

    # Module 4: perform srt post process steps
    def postprocess(self):
        """Module 4: apply the enabled post-processing steps to ``self.SRT_Script``."""
        self.status = TaskStatus.POST_PROCESSING
        logging.info("---------------------Start Post-processing SRT class---------------------")
        if self.post_setting["check_len_and_split"]:
            self.SRT_Script.check_len_and_split()
        if self.post_setting["remove_trans_punctuation"]:
            self.SRT_Script.remove_trans_punctuation()
        logging.info("---------------------Post-processing SRT class finished---------------------")

    # Module 5: output module
    def output_render(self):
        """Module 5: write the translated subtitles and optionally burn them into a video.

        Writes ``results/<task_id>_<target_lang>.srt`` (and, when bilingual
        output is enabled, ``results/<task_id>_<source_lang>_<target_lang>.srt``),
        converts the last written subtitle to .ass when the subtitle type is
        "ass", and, when video output is enabled and ``self.video_path`` is set,
        runs ffmpeg to hard-code the subtitles into ``results/<task_id>.mp4``.

        Returns:
            Path (str) of the final artefact: the .mp4 if one was rendered,
            otherwise the last subtitle file produced.
        """
        self.status = TaskStatus.OUTPUT_MODULE
        video_out = self.output_type["video"]
        subtitle_type = self.output_type["subtitle"]
        is_bilingual = self.output_type["bilingual"]

        results_dir = f"{self.task_local_dir}/results"
        # Ensure the output directory exists before any file is written into it.
        Path(results_dir).mkdir(parents=True, exist_ok=True)

        subtitle_path = f"{results_dir}/{self.task_id}_{self.target_lang}.srt"
        self.SRT_Script.write_srt_file_translate(subtitle_path)
        if is_bilingual:
            subtitle_path = f"{results_dir}/{self.task_id}_{self.source_lang}_{self.target_lang}.srt"
            self.SRT_Script.write_srt_file_bilingual(subtitle_path)

        if subtitle_type == "ass":
            logging.info("write .srt file to .ass")
            subtitle_path = srt2ass(subtitle_path, "default", "No", "Modest")
            logging.info('ASS subtitle saved as: ' + subtitle_path)

        final_res = subtitle_path

        # encode to .mp4 video file
        if video_out and self.video_path is not None:
            logging.info("encoding video file")
            logging.info(f'ffmpeg comand: \nffmpeg -i {self.video_path} -vf "subtitles={subtitle_path}" {results_dir}/{self.task_id}.mp4')
            # NOTE: ffmpeg's return code is not checked; final_res points at the
            # .mp4 even if encoding failed.
            subprocess.run(
                ["ffmpeg",
                 "-i", self.video_path,
                 "-vf", f"subtitles={subtitle_path}",
                 f"{results_dir}/{self.task_id}.mp4"])
            final_res = f"{results_dir}/{self.task_id}.mp4"

        self.t_e = time()
        logging.info(
            "Pipeline finished, time duration:{}".format(strftime("%H:%M:%S", gmtime(self.t_e - self.t_s))))
        return final_res

    def run_pipeline(self):
        """Run all five stages in order and store the final artefact path in ``self.result``."""
        self.get_srt_class()
        self.preprocess()
        self.translation()
        self.postprocess()
        self.result = self.output_render()
class YoutubeTask(Task):
    """Task whose input is a YouTube URL.

    ``run`` downloads the video and an audio track into the task directory,
    then starts the pipeline.
    """

    def __init__(self, task_id, task_local_dir, task_cfg, youtube_url):
        super().__init__(task_id, task_local_dir, task_cfg)
        self.youtube_url = youtube_url

    def run(self):
        """Download the media for ``self.youtube_url`` and run the pipeline.

        Picks the highest-resolution progressive MP4 stream. If no audio-only
        stream is available, the audio is extracted from the downloaded MP4
        with ffmpeg.

        Raises:
            FileNotFoundError: if no progressive MP4 stream exists for the link.
        """
        source = YouTube(self.youtube_url)
        best_video = (
            source.streams
            .filter(progressive=True, file_extension='mp4')
            .order_by('resolution')
            .desc()
            .first()
        )
        if not best_video:
            raise FileNotFoundError(f" Video stream not found for link {self.youtube_url}")

        video_name = f"task_{self.task_id}.mp4"
        audio_name = f"task_{self.task_id}.mp3"
        target_dir = str(self.task_local_dir)

        best_video.download(target_dir, filename=video_name)
        logging.info(f'Video Name: {best_video.default_filename}')

        audio_stream = source.streams.filter(only_audio=True).first()
        if audio_stream:
            audio_stream.download(target_dir, filename=audio_name)
        else:
            logging.info(" download audio failed, using ffmpeg to extract audio")
            ffmpeg_cmd = [
                'ffmpeg', '-i', self.task_local_dir.joinpath(video_name), '-f', 'mp3',
                '-ab', '192000', '-vn', self.task_local_dir.joinpath(audio_name),
            ]
            subprocess.run(ffmpeg_cmd)
            logging.info("audio extraction finished")

        self.video_path = self.task_local_dir.joinpath(video_name)
        self.audio_path = self.task_local_dir.joinpath(audio_name)
        logging.info(f" Video File Dir: {self.video_path}")
        logging.info(f" Audio File Dir: {self.audio_path}")
        logging.info(" Data Prep Complete. Start pipeline")
        super().run_pipeline()
class AudioTask(Task):
    """Task whose input is an audio file already on disk; it has no video."""

    def __init__(self, task_id, task_local_dir, task_cfg, audio_path):
        super().__init__(task_id, task_local_dir, task_cfg)
        # TODO: check audio format
        self.audio_path, self.video_path = audio_path, None

    def run(self):
        """Log the input paths and start the pipeline; no data preparation is needed."""
        for label, path in (("Video", self.video_path), ("Audio", self.audio_path)):
            logging.info(f"{label} File Dir: {path}")
        logging.info("Data Prep Complete. Start pipeline")
        super().run_pipeline()
class VideoTask(Task):
    """Task whose input is a local video file.

    The video is copied into the task directory. ``run`` extracts its audio
    track with ffmpeg and then starts the pipeline.
    """

    def __init__(self, task_id, task_local_dir, task_cfg, video_path):
        super().__init__(task_id, task_local_dir, task_cfg)
        # TODO: check video format {.mp4}
        copied_video = f"{task_local_dir}/task_{self.task_id}.mp4"
        print(copied_video)
        logging.info(f"Copy video file to: {copied_video}")
        shutil.copyfile(video_path, copied_video)
        self.video_path = copied_video

    def run(self):
        """Extract the audio as ``task_<task_id>.mp3`` in the task directory, then run the pipeline."""
        logging.info("using ffmpeg to extract audio")
        extracted_audio = self.task_local_dir.joinpath(f"task_{self.task_id}.mp3")
        subprocess.run([
            'ffmpeg', '-i', self.video_path, '-f', 'mp3',
            '-ab', '192000', '-vn', extracted_audio,
        ])
        logging.info("audio extraction finished")
        self.audio_path = extracted_audio

        logging.info(f" Video File Dir: {self.video_path}")
        logging.info(f" Audio File Dir: {self.audio_path}")
        logging.info("Data Prep Complete. Start pipeline")
        super().run_pipeline()