# NOTE(review): stray paste residue ("Spaces:" / "Sleeping") commented out —
# as bare top-level words it was a syntax error and made the module unimportable.
# Spaces: Sleeping Sleeping
import argparse | |
import faulthandler | |
import gc | |
import os | |
import tempfile | |
import torch | |
import whisperx | |
from whisperx.asr import FasterWhisperPipeline | |
def get_device() -> str:
    """Return the best available torch device string: "cuda" if available, else "cpu".

    (MPS is deliberately not selected; whisperx/faster-whisper support for it
    is unreliable — add it back here once verified.)
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def generate_subtitles_from_audio(
    audio_file_path: str,
    model: FasterWhisperPipeline,
    batch_size: int = 8,
    language: str = "ru",
):
    """Transcribe an audio file with a preloaded whisperx pipeline.

    Args:
        audio_file_path: Path to the audio file to transcribe.
        model: A loaded ``FasterWhisperPipeline`` (see ``whisperx.load_model``).
        batch_size: Inference batch size passed to ``model.transcribe``.
        language: Language code forced during transcription. Defaults to "ru"
            (previously hard-coded); pass another code to transcribe other languages.

    Returns:
        The transcription result dict produced by ``model.transcribe``
        (segments with timestamps and text).
    """
    audio = whisperx.load_audio(audio_file_path)
    return model.transcribe(audio, batch_size=batch_size, language=language)
def generate_subtitles_from_video(
    video_path: str,
    model_name: str = "base",
    batch_size: int = 8,
    compute_type: str = "int8",
    language: str = "ru",
):
    """Extract audio from a video and transcribe it with whisperx.

    Args:
        video_path: Path to the input video file.
        model_name: Whisper model identifier passed to ``whisperx.load_model``.
        batch_size: Inference batch size for transcription.
        compute_type: Model compute precision ("int8", "float16", "float32").
        language: Language code for loading/transcription. Defaults to "ru"
            (previously hard-coded).

    Returns:
        The transcription result dict from ``generate_subtitles_from_audio``.
    """
    # mkstemp returns an OPEN low-level fd plus a path; close the fd immediately
    # so it does not leak (the original code discarded it while still open).
    fd, audio_file = tempfile.mkstemp()
    os.close(fd)
    device = get_device()
    print("Loading model:")
    model = whisperx.load_model(model_name, device, compute_type=compute_type, language=language)
    try:
        print("Parsing audio:")
        parse_audio(video_path, audio_file)
        print("Generating subtitles:")
        result = generate_subtitles_from_audio(audio_file, model, batch_size=batch_size)
    finally:
        # Remove the temp file even if audio extraction or transcription raises.
        os.remove(audio_file)
        # Drop the (potentially GPU-resident) model and collect promptly.
        del model
        gc.collect()
    return result
def add_whisper_args(arg_parser: argparse.ArgumentParser):
    """Register the transcription-related CLI options on *arg_parser*."""
    arg_parser.add_argument("video", help="video file")
    arg_parser.add_argument(
        "--compute_type",
        default="int8",
        choices=["int8", "float16", "float32"],
        help="Base type for model",
    )
    arg_parser.add_argument(
        "--whisper_model",
        default="large-v2",
        help="model to use",
    )
    arg_parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for inference",
    )
if __name__ == "__main__":
    def main():
        """CLI entry point: parse arguments and print the transcription result."""
        # Dump tracebacks on hard crashes (useful for native faults in torch/ctranslate2).
        faulthandler.enable()
        parser = argparse.ArgumentParser(description="Get video subtitles from a video")
        add_whisper_args(parser)
        args = parser.parse_args()
        result = generate_subtitles_from_video(
            args.video, args.whisper_model, args.batch_size, args.compute_type
        )
        print(result)

    main()