StevenChen16 commited on
Commit
550cf61
·
verified ·
1 Parent(s): d6c72bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -54
app.py CHANGED
@@ -1,66 +1,114 @@
1
- import spaces
2
- import whisperx
3
  import torch
4
  import gradio as gr
 
 
5
  import tempfile
6
- import os
7
 
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- batch_size = 4 # 如果GPU内存不足,可适当减少
10
- compute_type = "float32" # 如果GPU内存不足,可改为 "int8"(可能影响准确度)
 
 
11
 
12
- @spaces.GPU
13
- def transcribe_whisperx(audio_file, task):
14
- # WhisperX模型加载
15
- model = whisperx.load_model("large-v3", device=device, compute_type=compute_type)
16
 
17
- if audio_file is None:
18
- raise gr.Error("请上传或录制音频文件再提交请求!")
19
-
20
- # 加载音频文件
21
- audio = whisperx.load_audio(audio_file)
22
-
23
- # 执行初步转录
24
- result = model.transcribe(audio, batch_size=batch_size)
25
-
26
- # 释放模型资源,防止GPU内存不足
27
- torch.cuda.empty_cache()
28
-
29
- # 加载对齐模型并对齐转录结果
30
- model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
31
- result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
32
-
33
- # 执行说话人分离
34
- hf_token = os.getenv("HF_TOKEN")
35
- diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token,
36
- device=device)
37
- diarize_segments = diarize_model(audio_file)
38
- result = whisperx.assign_word_speakers(diarize_segments, result)
39
-
40
- # 格式化输出文本
41
- output_text = ""
42
- for segment in result["segments"]:
43
- speaker = segment.get("speaker", "未知")
44
- text = segment["text"]
45
- output_text += f"{speaker}: {text}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- return output_text
 
 
 
48
 
49
- # Gradio界面
50
  demo = gr.Blocks(theme=gr.themes.Ocean())
51
 
52
- transcribe_interface = gr.Interface(
53
- fn=transcribe_whisperx,
54
- inputs=[
55
- gr.Audio(sources=["microphone", "upload"], type="filepath"),
56
- gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
57
- ],
58
- outputs="text",
59
- title="WhisperX: Transcribe and Diarize Audio",
60
- description="使用WhisperX对音频文件或麦克风输入进行转录和说话人分离。"
61
- )
62
-
63
  with demo:
64
- transcribe_interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- demo.queue().launch(ssr_mode=False)
 
1
+ import os
 
2
  import torch
3
  import gradio as gr
4
+ import whisperx
5
+ from transformers.pipelines.audio_utils import ffmpeg_read
6
  import tempfile
7
+ import gc
8
 
9
+ # Constants
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+ BATCH_SIZE = 4
12
+ COMPUTE_TYPE = "float32"
13
+ FILE_LIMIT_MB = 1000
14
 
15
+ def transcribe_audio(inputs, task):
16
+ if inputs is None:
17
+ raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
 
18
 
19
+ try:
20
+ # Load audio
21
+ if isinstance(inputs, str):
22
+ # For file path input
23
+ audio = whisperx.load_audio(inputs)
24
+ else:
25
+ # For microphone input (needs conversion)
26
+ audio = whisperx.load_audio(inputs)
27
+
28
+ # 1. Transcribe with base Whisper model
29
+ model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
30
+ result = model.transcribe(audio, batch_size=BATCH_SIZE)
31
+
32
+ # Clear GPU memory
33
+ del model
34
+ gc.collect()
35
+ torch.cuda.empty_cache()
36
+
37
+ # 2. Align whisper output
38
+ model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
39
+ result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
40
+
41
+ # Clear GPU memory again
42
+ del model_a
43
+ gc.collect()
44
+ torch.cuda.empty_cache()
45
+
46
+ # 3. Diarize audio
47
+ diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device=DEVICE)
48
+ diarize_segments = diarize_model(audio)
49
+
50
+ # 4. Assign speaker labels
51
+ result = whisperx.assign_word_speakers(diarize_segments, result)
52
+
53
+ # Format output
54
+ output_text = ""
55
+ for segment in result['segments']:
56
+ speaker = segment.get('speaker', 'Unknown Speaker')
57
+ text = segment['text']
58
+ output_text += f"{speaker}: {text}\n"
59
+
60
+ return output_text
61
+
62
+ except Exception as e:
63
+ raise gr.Error(f"Error processing audio: {str(e)}")
64
 
65
+ finally:
66
+ # Final cleanup
67
+ gc.collect()
68
+ torch.cuda.empty_cache()
69
 
70
+ # Create Gradio interface
71
  demo = gr.Blocks(theme=gr.themes.Ocean())
72
 
 
 
 
 
 
 
 
 
 
 
 
73
  with demo:
74
+ gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")
75
+
76
+ with gr.Row():
77
+ with gr.Column():
78
+ audio_input = gr.Audio(
79
+ sources=["microphone", "upload"],
80
+ type="filepath",
81
+ label="Audio Input (Microphone or File Upload)"
82
+ )
83
+ task = gr.Radio(
84
+ ["transcribe", "translate"],
85
+ label="Task",
86
+ value="transcribe"
87
+ )
88
+ submit_button = gr.Button("Process Audio")
89
+
90
+ with gr.Column():
91
+ output_text = gr.Textbox(
92
+ label="Transcription with Speaker Diarization",
93
+ lines=10,
94
+ placeholder="Transcribed text will appear here..."
95
+ )
96
+
97
+ gr.Markdown("""
98
+ ### Features:
99
+ - High-accuracy transcription using WhisperX
100
+ - Automatic speaker diarization
101
+ - Support for both microphone recording and file upload
102
+ - GPU-accelerated processing
103
+
104
+ ### Note:
105
+ Processing may take a few moments depending on the audio length and system resources.
106
+ """)
107
+
108
+ submit_button.click(
109
+ fn=transcribe_audio,
110
+ inputs=[audio_input, task],
111
+ outputs=output_text
112
+ )
113
 
114
+ demo.queue().launch(share=True)