# WhisperX-V2 / app.py
import whisperx
import torch
import gradio as gr
import os
import spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 4  # reduce if GPU memory is insufficient
compute_type = "float32"  # switch to "int8" if GPU memory is insufficient (may reduce accuracy)
@spaces.GPU
def transcribe_whisperx(audio_file, task):
    # Validate input before loading the (large) model
    if audio_file is None:
        raise gr.Error("Please upload or record an audio file before submitting!")

    # Load the WhisperX model
    model = whisperx.load_model("large-v3", device=device, compute_type=compute_type)

    # Load the audio file
    audio = whisperx.load_audio(audio_file)

    # Initial transcription; `task` selects "transcribe" or "translate"
    result = model.transcribe(audio, batch_size=batch_size, task=task)
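    # `result` is a dict with "segments" (each holding "start", "end", "text")
    # and the detected "language" code, which the aligner below relies on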
    # Free the transcription model before loading the next one,
    # to avoid running out of GPU memory
    del model
    torch.cuda.empty_cache()
    # Load the alignment model and align the transcript to get
    # accurate word-level timestamps
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
    # Run speaker diarization (requires an HF token with access to the gated
    # pyannote models, supplied via the HF_TOKEN environment variable)
    hf_token = os.getenv("HF_TOKEN")
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
    diarize_segments = diarize_model(audio_file)
    result = whisperx.assign_word_speakers(diarize_segments, result)
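    # After diarization, segments that overlap a diarized turn gain a
    # "speaker" label such as "SPEAKER_00"; segments without one keep no
    # such key, hence the .get() fallback below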
    # Format the output as one "speaker: text" line per segment
    output_text = ""
    for segment in result["segments"]:
        speaker = segment.get("speaker", "UNKNOWN")
        text = segment["text"]
        output_text += f"{speaker}: {text}\n"
return output_text
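
# A minimal sketch for exercising the pipeline without the Gradio UI
# ("sample.wav" is a hypothetical local file):
#
#     print(transcribe_whisperx("sample.wav", task="transcribe"))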
# Gradio UI
demo = gr.Blocks(theme=gr.themes.Ocean())
transcribe_interface = gr.Interface(
fn=transcribe_whisperx,
inputs=[
gr.Audio(sources=["microphone", "upload"], type="filepath"),
gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
],
outputs="text",
title="WhisperX: Transcribe and Diarize Audio",
description="使用WhisperX对音频文件或麦克风输入进行转录和说话人分离。"
)
with demo:
    # An Interface built outside a Blocks context must be explicitly
    # rendered into it
    transcribe_interface.render()
demo.queue().launch(ssr_mode=False)