# WhisperX Gradio Space: transcription + speaker diarization demo.
import whisperx
import torch
import gradio as gr
import tempfile
import os
import spaces

# Prefer GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 4  # Reduce if GPU memory is insufficient.
compute_type = "float32"  # Change to "int8" if GPU memory is tight (may reduce accuracy).
def transcribe_whisperx(audio_file, task):
    """Transcribe an audio file with WhisperX and label each segment's speaker.

    Args:
        audio_file: Path to the uploaded/recorded audio file (Gradio
            ``type="filepath"``), or ``None`` when nothing was submitted.
        task: ``"transcribe"`` or ``"translate"``; forwarded to the WhisperX
            ASR pipeline (the previous version silently ignored this).

    Returns:
        str: One ``"<speaker>: <text>"`` line per aligned segment.

    Raises:
        gr.Error: If no audio file was supplied.
    """
    # Validate input BEFORE paying the cost of loading the large-v3 model.
    if audio_file is None:
        raise gr.Error("请上传或录制音频文件再提交请求!")

    # Load the WhisperX ASR model.
    model = whisperx.load_model("large-v3", device=device, compute_type=compute_type)

    # Load the audio and run the initial transcription.
    # Bug fix: pass `task` through so "translate" actually translates.
    audio = whisperx.load_audio(audio_file)
    result = model.transcribe(audio, batch_size=batch_size, task=task)

    # Release cached GPU memory before loading the alignment model.
    torch.cuda.empty_cache()

    # Align the transcription for accurate word-level timestamps.
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                            return_char_alignments=False)

    # Speaker diarization (pyannote models require an HF access token).
    hf_token = os.getenv("HF_TOKEN")
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
    diarize_segments = diarize_model(audio_file)
    result = whisperx.assign_word_speakers(diarize_segments, result)

    # Format output: one "<speaker>: <text>" line per segment.
    # str.join avoids the quadratic += string concatenation of the original.
    lines = []
    for segment in result["segments"]:
        speaker = segment.get("speaker", "未知")
        lines.append(f"{speaker}: {segment['text']}\n")
    return "".join(lines)
# Gradio UI: wrap the single transcription interface in a themed Blocks app.
demo = gr.Blocks(theme=gr.themes.Ocean())

transcribe_interface = gr.Interface(
    fn=transcribe_whisperx,
    inputs=[
        gr.Audio(sources=["microphone", "upload"], type="filepath"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    title="WhisperX: Transcribe and Diarize Audio",
    description="使用WhisperX对音频文件或麦克风输入进行转录和说话人分离。",
)

with demo:
    # Bug fix: a bare expression does not attach an existing Interface to a
    # Blocks layout — it must be rendered explicitly inside the context.
    transcribe_interface.render()

demo.queue().launch(ssr_mode=False)