StevenChen16 committed on
Commit
86a1f13
1 Parent(s): e32eb72

Update app.py

Files changed (1)
  1. app.py +45 -88
app.py CHANGED
@@ -1,109 +1,66 @@
- import spaces
- import gradio as gr
- import yt_dlp as youtube_dl
  import whisperx
  import tempfile
  import os
- import torch
- import gc
-
- # WhisperX configuration
- device = "cuda"  # if torch.cuda.is_available() else "cpu"
- batch_size = 4
- compute_type = "float32"
- MODEL_NAME = "large-v3"
- YT_LENGTH_LIMIT_S = 3600  # 1 hour YouTube files
-
- # Load the WhisperX model
- # @spaces.GPU
- # def load_whisperx_model():
- #     # Load the WhisperX model
- #     return whisperx.load_model(MODEL_NAME, device=device, compute_type=compute_type)
-
- # model = load_whisperx_model()

  @spaces.GPU
- def transcribe(inputs, task):
-     model = whisperx.load_model(MODEL_NAME, device=device, compute_type=compute_type)
-
-     if inputs is None:
-         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
-
-     # Load and transcribe the audio
-     audio = whisperx.load_audio(inputs)
      result = model.transcribe(audio, batch_size=batch_size)
-     print(result["segments"])  # segments before alignment
-
-     # Free resources to save GPU memory
-     gc.collect()
      torch.cuda.empty_cache()
-     del model
-
-     # Load the alignment model
      model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
      result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)

-     # Speaker diarization
-     diarize_model = whisperx.DiarizationPipeline(use_auth_token="your_huggingface_token", device=device)
-     result = whisperx.assign_word_speakers(diarize_model, result)
-
-     # Format the output
-     transcript = ""
-     for segment in result['segments']:
-         speaker = segment.get('speaker', 'Unknown')
-         transcript += f"{speaker}: {segment['text']}\n"
-
-     return transcript
-
- def _return_yt_html_embed(yt_url):
-     video_id = yt_url.split("?v=")[-1]
-     HTML_str = (
-         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
-         " </center>"
-     )
-     return HTML_str
-
- def download_yt_audio(yt_url, filename):
-     info_loader = youtube_dl.YoutubeDL()
-
-     try:
-         info = info_loader.extract_info(yt_url, download=False)
-     except youtube_dl.utils.DownloadError as err:
-         raise gr.Error(str(err))
-
-     file_length = info["duration"]
-     if file_length > YT_LENGTH_LIMIT_S:
-         raise gr.Error("YouTube video length exceeds the 1-hour limit.")
-
-     ydl_opts = {"outtmpl": filename, "format": "bestaudio[ext=m4a]"}
-
-     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
-         try:
-             ydl.download([yt_url])
-         except youtube_dl.utils.ExtractorError as err:
-             raise gr.Error(str(err))
-
- def yt_transcribe(yt_url, task):
-     html_embed_str = _return_yt_html_embed(yt_url)
-
-     with tempfile.TemporaryDirectory() as tmpdirname:
-         filepath = os.path.join(tmpdirname, "video.m4a")
-         download_yt_audio(yt_url, filepath)
-         result = transcribe(filepath, task)
-
-     return html_embed_str, result
-
- # Gradio interface setup
- demo = gr.Blocks()
- yt_transcribe_interface = gr.Interface(
-     fn=yt_transcribe,
-     inputs=[gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
-     outputs=["html", "text"],
-     title="WhisperX: Transcribe YouTube with Speaker Diarization",
-     description="Transcribe and diarize YouTube videos with WhisperX."
  )

  with demo:
-     gr.TabbedInterface([yt_transcribe_interface], ["YouTube"])
-
- demo.launch()
  import whisperx
+ import torch
+ import gradio as gr
  import tempfile
  import os
+ import spaces

+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ batch_size = 4  # reduce if GPU memory is limited
+ compute_type = "float32"  # switch to "int8" if GPU memory is limited (may reduce accuracy)

  @spaces.GPU
+ def transcribe_whisperx(audio_file, task):
+     # Load the WhisperX model
+     model = whisperx.load_model("large-v3", device=device, compute_type=compute_type)
+
+     if audio_file is None:
+         raise gr.Error("Please upload or record an audio file before submitting your request!")
+
+     # Load the audio file
+     audio = whisperx.load_audio(audio_file)
+
+     # Run the initial transcription
      result = model.transcribe(audio, batch_size=batch_size)
+
+     # Release cached memory so the GPU does not run out
      torch.cuda.empty_cache()
+
+     # Load the alignment model and align the transcription
      model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
      result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
+
+     # Run speaker diarization
+     hf_token = os.getenv("HF_TOKEN")
+     diarize_model = whisperx.DiarizationPipeline(use_auth_token=hf_token,
+                                                  device=device)
+     diarize_segments = diarize_model(audio_file)
+     result = whisperx.assign_word_speakers(diarize_segments, result)
+
+     # Format the output text
+     output_text = ""
+     for segment in result["segments"]:
+         speaker = segment.get("speaker", "Unknown")
+         text = segment["text"]
+         output_text += f"{speaker}: {text}\n"
+
+     return output_text

+ # Gradio interface
+ demo = gr.Blocks(theme=gr.themes.Ocean())

+ transcribe_interface = gr.Interface(
+     fn=transcribe_whisperx,
+     inputs=[
+         gr.Audio(sources=["microphone", "upload"], type="filepath"),
+         gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+     ],
+     outputs="text",
+     title="WhisperX: Transcribe and Diarize Audio",
+     description="Transcribe and diarize audio files or microphone input with WhisperX."
  )
  with demo:
+     transcribe_interface.render()  # render the Interface inside the Blocks layout

+ demo.queue().launch(ssr_mode=False)
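
For reference, the diarized-transcription flow that transcribe_whisperx wraps can be exercised outside Gradio and ZeroGPU with a short standalone script. This is a minimal sketch, assuming the same whisperx API used above, an HF_TOKEN environment variable holding a token with access to the gated pyannote diarization models, and a local sample.wav (the filename is illustrative):

    import os
    import whisperx

    device = "cpu"  # use "cuda" when a GPU is available
    audio = whisperx.load_audio("sample.wav")  # illustrative input file

    # Transcribe, then align to get word-level timestamps.
    model = whisperx.load_model("large-v3", device=device, compute_type="int8")
    result = model.transcribe(audio, batch_size=4)
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device,
                            return_char_alignments=False)

    # Diarize, then attach speaker labels to the aligned segments.
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.getenv("HF_TOKEN"), device=device)
    diarize_segments = diarize_model(audio)
    result = whisperx.assign_word_speakers(diarize_segments, result)

    for segment in result["segments"]:
        print(segment.get("speaker", "Unknown"), segment["text"].strip())

Note that app.py loads the model inside the @spaces.GPU-decorated function rather than at import time; on ZeroGPU Spaces the GPU is attached only for the duration of the decorated call, which is generally why model loading is deferred into that function.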