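"""Gradio app for fast Whisper transcription.

Loads a Whisper / Distil-Whisper checkpoint on CPU, moves it to the GPU only inside
the @spaces.GPU-decorated handler (ZeroGPU-style allocation), and writes the results
out as .srt, .vtt, and .txt subtitle files.
"""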
import gradio as gr
import time
import logging
import torch
from sys import platform
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available
from languages import get_language_names
from subtitle_manager import Subtitle
import spaces
logging.basicConfig(level=logging.INFO)
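# Module-level cache: the pipeline is rebuilt only when a different model is selected.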
last_model = None
pipe = None
def write_file(output_file, subtitle):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)
def create_pipe(model, flash):
    """Build an ASR pipeline on CPU; return (pipe, model) so the model can be moved to GPU later."""
    # Load the model into RAM first
    torch_dtype = torch.float32  # Load onto CPU with float32 precision
    model_id = model
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    )
    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,  # Keep on CPU until GPU is requested
        device="cpu",  # Initially stay on CPU
    )
    return pipe, model  # Return both pipe and model for later GPU switch
def move_to_gpu(model):
    """Move the model to the best available device (CUDA, Apple MPS, or CPU) and return its name."""
    if torch.cuda.is_available():
        device = "cuda:0"
        torch_dtype = torch.float16  # Use float16 precision on GPU
        model.to(device, dtype=torch_dtype)
    elif platform == "darwin":
        device = "mps"
        model.to(device)
    else:
        device = "cpu"
    return device
@spaces.GPU
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe the given audio sources and return the generated subtitle files plus the VTT and TXT text."""
    global last_model
    global pipe

    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")
    if last_model is None:
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        pipe, model = create_pipe(modelName, flash)
    elif modelName != last_model:
        logging.info("new model")
        torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model..")
        pipe, model = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
    last_model = modelName
    # Now move the model to GPU after the pipe is created, within the function's context
    with torch.inference_mode():
        device = move_to_gpu(pipe.model)
        # Update pipe's device
        pipe.device = torch.device(device)
        pipe.model.to(pipe.device)
    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)
    logging.info(files)
    generate_kwargs = {}
    if languageName != "Automatic Detection" and not modelName.endswith(".en"):
        generate_kwargs["language"] = languageName
    if not modelName.endswith(".en"):
        generate_kwargs["task"] = task
    files_out = []
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,  # 30
            batch_size=batch_size,  # 24
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        logging.info(f"transcribe: {time.time() - start_time} sec.")

        file_out = file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]
    progress(1, desc="Completed!")
    return files_out, vtt, txt
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
    whisper_models = [
        "openai/whisper-tiny.en",
        "openai/whisper-base.en",
        "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
    ]
    waveform_options = gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )
    simple_transcribe = gr.Interface(fn=transcribe_webui_simple_progress,
                                     description=description,
                                     article=article,
                                     inputs=[
                                         gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v3",
                                                     label="Model", info="Select whisper model", interactive=True),
                                         gr.Dropdown(choices=["English"], value="English", interactive=False, visible=False,
                                                     label="Language",
                                                     info="Select audio voice language"),
                                         gr.Text(label="URL", info="(YouTube, etc.)", interactive=False, visible=False),
                                         gr.File(label="Upload Files", file_count="multiple", interactive=False, visible=False),
                                         gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input",
                                                  waveform_options=waveform_options),
                                         gr.Dropdown(choices=["transcribe", "translate"], label="Task",
                                                     value="transcribe", interactive=False, visible=False),
                                         gr.Checkbox(label='Flash', info='Use Flash Attention 2', interactive=False, visible=False),
                                         gr.Number(label='chunk_length_s', value=30, interactive=False, visible=False),
                                         gr.Number(label='batch_size', value=24, interactive=False, visible=False)
                                     ], outputs=[
                                         gr.File(label="Download"),
                                         gr.Text(label="Transcription"),
                                         gr.Text(label="Segments")
                                     ]
                                     )
if __name__ == "__main__":
    demo.launch()
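    # Note: demo.launch() blocks until the server stops; passing share=True would also expose a temporary public URL.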