Spaces:
Paused
Paused
import os, sys, re | |
import shutil | |
import subprocess | |
import soundfile | |
from process_audio import segment_audio | |
from write_srt import write_to_file | |
from clean_text import clean_english, clean_german, clean_spanish | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
from transformers import AutoModelForCTC, AutoProcessor | |
import torch | |
import gradio as gr | |
# Pretrained Wav2Vec2 checkpoints, one per supported language.
# NOTE: get_subs resolves `<language>_tokenizer` and `<language>_asr_model`
# dynamically through globals(), so these names must keep that pattern.
english_model = "facebook/wav2vec2-large-960h-lv60-self"
english_tokenizer = Wav2Vec2Processor.from_pretrained(english_model)
english_asr_model = Wav2Vec2ForCTC.from_pretrained(english_model)

german_model = "flozi00/wav2vec2-large-xlsr-53-german-with-lm"
german_tokenizer = AutoProcessor.from_pretrained(german_model)
german_asr_model = AutoModelForCTC.from_pretrained(german_model)

spanish_model = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
spanish_tokenizer = AutoProcessor.from_pretrained(spanish_model)
spanish_asr_model = AutoModelForCTC.from_pretrained(spanish_model)

# Download the TextBlob corpora (presumably required by the clean_* helpers
# -- confirm against clean_text). Run the download with the interpreter that
# is executing this script so the corpora land in the same environment; a
# bare "python" may resolve to a different installation on PATH.
# The exit status is deliberately not checked (best effort at start-up).
command = [sys.executable, "-m", "textblob.download_corpora"]
subprocess.run(command)

# Running index of the subtitle entries written to the SRT file; incremented
# by transcribe_audio through a `global` statement.
line_count = 0
def sort_alphanumeric(data):
    """Return the strings in *data* sorted in natural (human) order.

    Runs of ASCII digits are compared by numeric value and everything else
    is compared case-insensitively, so ``"seg_2"`` sorts before ``"seg_10"``.

    Args:
        data: Iterable of strings (for example file names).

    Returns:
        A new list with the items of *data* in natural order.
    """
    def natural_key(item):
        # Splitting on a capturing group always yields the alternating
        # pattern text, digits, text, ... so every odd position holds an
        # ASCII digit run. Converting by position (instead of testing
        # str.isdigit()) keeps non-ASCII digits such as "²" as plain text:
        # int() cannot parse some of them, and the rest would put an int
        # into a text slot and break the comparison with other keys.
        chunks = re.split(r'([0-9]+)', item)
        return [
            int(chunk) if position % 2 else chunk.lower()
            for position, chunk in enumerate(chunks)
        ]

    return sorted(data, key=natural_key)
def transcribe_audio(tokenizer, asr_model, audio_file, file_handle):
    """Transcribe one audio segment and write it to the SRT file.

    Runs Wav2Vec2 inference on a single segment produced by the VAD
    segmentation step, cleans the decoded text with the helper matching the
    module-level ``lang`` and hands it to ``write_to_file``. Segments whose
    transcription is at most one character long are skipped.

    Relies on two module globals: ``lang`` (set by get_subs; a NameError is
    raised if it has never been assigned) and ``line_count`` (incremented for
    every entry written).

    Args:
        tokenizer: Processor used to build the model input and to decode the
            predicted token ids.
        asr_model: CTC model returning an object with a ``logits`` attribute.
        audio_file: Path of the segment. The part of the file name after the
            last ``_`` (minus the 4-character extension) is split on ``-`` and
            passed on as ``limits`` -- presumably the start/end times of the
            segment; confirm against segment_audio / write_to_file.
        file_handle: Open SRT file object forwarded to ``write_to_file``.
    """
    global line_count

    # The sample rate returned by soundfile is not used: the input is declared
    # as 16 kHz below, which matches the "-ar 16000" ffmpeg option in get_subs.
    speech, _ = soundfile.read(audio_file)
    input_values = tokenizer(
        speech, sampling_rate=16000, return_tensors="pt", padding='longest'
    ).input_values

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    prediction = torch.argmax(logits, dim=-1)
    infered_text = tokenizer.batch_decode(prediction)[0].lower()

    # Nothing useful was recognised in this segment.
    if len(infered_text) <= 1:
        return

    cleaners = {
        'english': clean_english,
        'german': clean_german,
        'spanish': clean_spanish,
    }
    cleaner = cleaners.get(lang)
    # Unknown languages are left uncleaned.
    if cleaner is not None:
        infered_text = cleaner(infered_text)
    print(infered_text)

    limits = audio_file.split(os.sep)[-1][:-4].split("_")[-1].split("-")
    line_count += 1
    write_to_file(file_handle, infered_text, line_count, limits)
def get_subs(input_file, language):
    """Generate an SRT subtitle file for a video.

    Extracts a mono 16 kHz WAV track with ffmpeg, splits it with
    ``segment_audio``, transcribes every resulting segment in natural file
    name order and writes the entries to ``<video name>.srt`` in the current
    working directory.

    The function works in a fixed ``audio`` directory under the current
    working directory (wiped at the start and removed at the end) and
    communicates with transcribe_audio through the module globals ``lang``
    and ``line_count``, so it must not run concurrently with itself.

    Args:
        input_file: Path of the uploaded video file.
        language: One of ``"English"``, ``"German"`` or ``"Spanish"``
            (case-insensitive).

    Returns:
        The name of the SRT file that was written.

    Raises:
        ValueError: If *language* is missing or not supported.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    global lang, line_count

    # Validate the language before doing any expensive work.
    requested = language.lower() if isinstance(language, str) else ""
    if requested not in ("english", "german", "spanish"):
        raise ValueError("Unsupported language: {!r}".format(language))
    lang = requested
    tokenizer = globals()[lang + '_tokenizer']
    asr_model = globals()[lang + '_asr_model']

    # Subtitle numbering must restart at 1 for every generated file.
    line_count = 0

    # Start from an empty working directory for the audio segments.
    audio_directory = os.path.join(os.getcwd(), "audio")
    if os.path.isdir(audio_directory):
        shutil.rmtree(audio_directory)
    os.mkdir(audio_directory)

    try:
        # Extract audio from the video file; fail loudly if ffmpeg does
        # instead of stumbling over the missing WAV file later.
        audio_file = os.path.join(audio_directory, "temp.wav")
        command = ["ffmpeg", "-i", input_file, "-ac", "1", "-ar", "16000",
                   "-vn", "-f", "wav", audio_file]
        subprocess.run(command, check=True)

        # splitext copes with extensions of any length (".webm", ".ts", ...).
        video_name = os.path.splitext(os.path.basename(input_file))[0]
        srt_file_name = video_name + ".srt"

        # Split the audio file based on VAD silent segments, then drop the
        # full-length track so only the segments remain.
        segment_audio(audio_file)
        os.remove(audio_file)

        full_track_name = os.path.basename(audio_file)
        # "w+" truncates leftovers from an earlier run with the same video
        # name while keeping the handle readable as well as writable.
        with open(srt_file_name, "w+", encoding="utf-8") as file_handle:
            for segment_name in sort_alphanumeric(os.listdir(audio_directory)):
                if segment_name == full_track_name:
                    continue
                segment_path = os.path.join(audio_directory, segment_name)
                transcribe_audio(tokenizer, asr_model, segment_path, file_handle)
    finally:
        # Always clean up the segments, even when a step above failed.
        shutil.rmtree(audio_directory, ignore_errors=True)

    return srt_file_name
# Gradio front end: a video upload plus a language selector go in, the
# generated SRT file comes out as a download.
video_input = gr.inputs.Video(label="Upload Video File")
language_input = gr.inputs.Radio(
    label="Choose Language",
    choices=['English', 'German', 'Spanish'],
)
transcript_output = gr.outputs.File(label="Auto-Transcript")

gradio_ui = gr.Interface(
    fn=get_subs,
    inputs=[video_input, language_input],
    outputs=transcript_output,
    title="Video to Subtitle",
    description="Get subtitles (SRT file) for your videos. Inference speed is about 10s/per 1min of video BUT the speed of uploading your video depends on your internet connection.",
    enable_queue=True,
)

# Start the web server as soon as the script is run.
gradio_ui.launch()