# Mihaj's picture
# Update app.py
# 6a5a2f9 verified
# raw
# history blame
# 6.67 kB
import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import os
import soundfile as sf
from pyannote.audio import Pipeline
import torch
from pydub import AudioSegment
from pydub.playback import play
from datetime import datetime, timedelta
import time
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
# --- Module-level configuration and model loading -------------------------
# Everything below runs once at import time and loads the models used by the
# request handler.

# Access token read from the environment; it is handed to the pyannote
# pipeline loader further down.
HF_TOKEN = os.getenv("HF_TOKEN")

# Target sample rate and channel count applied by preprocess().
sr, channels = 16000, 1

# NOTE: model_name is not referenced anywhere else in the visible source;
# the recogniser is built from bond005_model.
model_name = "Mihaj/wav2vec2-large-xls-r-300m-ruOH-alphav"
bond005_model = "bond005/wav2vec2-large-ru-golos-with-lm"

# Speech recogniser: Wav2Vec2 CTC model combined with the processor's
# feature extractor and language-model decoder.
processor = Wav2Vec2ProcessorWithLM.from_pretrained(bond005_model)
model = Wav2Vec2ForCTC.from_pretrained(bond005_model)
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)

# NOTE: `model` is deliberately rebound here. `pipe` already holds the
# Wav2Vec2 model; from this point on `model` is the Silero VAD model that
# fast_transcribe() passes to get_speech_timestamps().
model = load_silero_vad()

# Speaker-diarisation pipeline (pyannote), authenticated with HF_TOKEN.
pipeline_dia = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=HF_TOKEN,
)

# Path of the intermediate WAV file written by preprocess().
temp_path = "temp.wav"
def preprocess(audio_path):
    """Convert an audio file to the WAV layout used by the recognisers.

    The input is decoded with pydub, resampled to the module-level ``sr``,
    down-mixed to ``channels`` channels and written to ``temp_path``.

    Args:
        audio_path: Path of the audio file to convert.

    Returns:
        The path of the converted WAV file (the module-level ``temp_path``).
    """
    print("PREPROCESSING STARTED")
    # No explicit ``format``: the container is detected from the file itself.
    # Hard-coding format="mp3" forced the MP3 demuxer on every input, although
    # the UI hands over whatever file the user supplied (not only MP3).
    sound = AudioSegment.from_file(audio_path)
    sound = sound.set_frame_rate(sr).set_channels(channels)
    # export() hands back the file object it wrote to; close it explicitly so
    # the data is flushed to disk before the caller re-opens temp_path.
    exported = sound.export(temp_path, format="wav")
    if hasattr(exported, "close"):
        exported.close()
    print("PREPROCESSING ENDED")
    return temp_path
def _srt_timestamp(seconds):
    """Return *seconds* (offset from the start of the audio) as ``HH:MM:SS,mmm``.

    Plain integer arithmetic is used instead of ``datetime.fromtimestamp`` so
    the result does not depend on the server's time zone, and every field is
    zero-padded.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    hours, total_ms = divmod(total_ms, 3600 * 1000)
    minutes, total_ms = divmod(total_ms, 60 * 1000)
    secs, millis = divmod(total_ms, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def fast_transcribe(diarise, how_diarise, translate, audio):
    """Transcribe an audio file, optionally as numbered subtitle cues.

    Args:
        diarise: When falsy, the whole recording is recognised in one pass and
            the plain transcript is returned. When truthy, the recording is
            split into segments and one subtitle cue is produced per segment.
        how_diarise: ``"Accurate"`` splits with the pyannote pipeline and
            prefixes every cue with the speaker label; any other value splits
            with Silero VAD (no speaker labels).
        translate: When truthy (and ``diarise`` is truthy) each cue also gets
            a translated line. Ignored when ``diarise`` is falsy.
        audio: Path of the input audio file; it is converted by
            ``preprocess`` first.

    Returns:
        The transcript as a string: plain text, or cues of the form
        ``index / start --> end / text`` joined by blank lines.

    Raises:
        RuntimeError: Translation was requested but no ``translator`` object
            exists in this module.
    """
    audio = preprocess(audio)
    y, sr = sf.read(audio)

    if not diarise:
        print("RECOGNISING FULL AUDIO")
        res = pipe(y, chunk_length_s=10, stride_length_s=(4, 2))
        print("RECOGNISING FULL AUDIO ENDED")
        return res["text"]

    # NOTE(review): no `translator` is defined or imported in the visible
    # source. Fail before the expensive diarisation step with a clear message
    # instead of a NameError in the middle of the loop.
    if translate and "translator" not in globals():
        raise RuntimeError(
            "Translation was requested but no `translator` object is defined in this module."
        )

    # Collect (start_sample, end_sample, speaker_label_or_None) triples.
    print("DIARISING")
    if how_diarise == "Accurate":
        dia = pipeline_dia(audio)
        segments = []
        for row in dia.to_lab().split("\n"):
            if not row.strip():
                # Skip blank rows so cue numbers stay contiguous.
                continue
            fields = row.split(" ")
            # The first two fields are start/end in seconds; convert them to
            # sample indices into `y`.
            segments.append((int(float(fields[0]) * sr), int(float(fields[1]) * sr), fields[2]))
    else:
        wav = read_audio(audio)  # backend (sox, soundfile, or ffmpeg) required!
        speech_timestamps = get_speech_timestamps(wav, model, speech_pad_ms=80, min_silence_duration_ms=150, window_size_samples=256)
        # The 'start'/'end' values are used directly as sample indices.
        segments = [(stamp["start"], stamp["end"], None) for stamp in speech_timestamps]
    print("DIARISING ENDED")

    lines = []
    for i, (start, end, label) in enumerate(segments):
        start_time_srt = _srt_timestamp(start / sr)
        end_time_srt = _srt_timestamp(end / sr)
        speaker_note = f" SPEAKER_{label}" if label is not None else ""
        print(f"RECOGNISING LINE_{i} T_START {start_time_srt} T_END {end_time_srt}{speaker_note}")
        trans = pipe(y[start:end], chunk_length_s=10, stride_length_s=(4, 2))["text"]
        print("RECOGNISING ENDED")

        prefix = f"[{label}] " if label is not None else ""
        cue = f"{i + 1}\n{start_time_srt} --> {end_time_srt}\n{prefix}{trans}\n"
        if translate:
            print("TRANSLATION STARTED")
            # Translate the recognised text itself (not the literal 'trans').
            trans_eng = translator.translate(trans, src='ru', dest="en").text
            print(f"TRANSLATION ENDED RESULT {trans_eng}")
            cue += f"{prefix}{trans_eng}\n"
        lines.append(cue)
        print(f"LINE RESULT {trans}")

    return "\n".join(lines)
# Gradio front end: one tab with the subtitle/translation options, the audio
# input, a text output and a "Run" button wired to fast_transcribe().
# NOTE(review): the source this was recovered from had lost its indentation,
# so the widget nesting below (what sits inside the Row/Column) is a
# reconstruction — confirm it matches the intended layout.
with gr.Blocks() as demo:
    # A space after '#' is needed for the first line to render as a heading;
    # the stray '"' that followed the model URL has been removed.
    gr.Markdown("""
# Wav2Vec2 RuOH
Realtime demo for Russian Oral History recognition using several diarization methods (Silero VAD, Pyannote) and a Wav2Vec large model from bond005. https://huggingface.co/bond005/wav2vec2-large-ru-golos-with-lm
""")
    with gr.Tab("Fast Translation"):
        with gr.Row():
            with gr.Column():
                fast_diarize_input = gr.Checkbox(label="Subtitles", info="Do you want subtitles?")
                fast_diarize_radio_input = gr.Radio(["Fast", "Accurate", "-"], label="separating_on_subtitles_option", info="You can choose separating audio on smaller pieces by faster yet low quality variant (Silero VAD), or slower yet high quality variant (Pyannote.Diarization, this option will detect different speakers)")
                fast_translate_input = gr.Checkbox(label="Translate", info="Do you want translation to English?")
                fast_audio_input = gr.Audio(type="filepath")
            fast_output = gr.Textbox()
        # Order must match the parameters of fast_transcribe:
        # (diarise, how_diarise, translate, audio).
        fast_inputs = [fast_diarize_input, fast_diarize_radio_input, fast_translate_input, fast_audio_input]
        fast_recognize_button = gr.Button("Run")
        fast_recognize_button.click(fast_transcribe, inputs=fast_inputs, outputs=fast_output)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()