bofenghuang committed
Commit bc8cab0 · 1 Parent(s): 009ac63

add multi model option

Files changed (2)
  1. app.py +1 -1
  2. run_demo_multi_models.py +138 -0
app.py CHANGED
@@ -1 +1 @@
- run_demo.py
+ run_demo_multi_models.py
run_demo_multi_models.py ADDED
@@ -0,0 +1,138 @@
+ import logging
+
+ import gradio as gr
+ import pytube as pt
+ import torch
+ from huggingface_hub import model_info
+ from transformers import pipeline
+
+ DEFAULT_MODEL_NAME = "bhuang/whisper-medium-cv11-french-case-punctuation"
+ MODEL_NAMES = [
+     "bhuang/whisper-small-cv11-french",
+     "bhuang/whisper-small-cv11-french-case-punctuation",
+     "bhuang/whisper-medium-cv11-french",
+     "bhuang/whisper-medium-cv11-french-case-punctuation",
+ ]
+ CHUNK_LENGTH_S = 30
+
+ logging.basicConfig(
+     format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+     datefmt="%Y-%m-%dT%H:%M:%SZ",
+ )
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+ device = 0 if torch.cuda.is_available() else "cpu"
+
+ cached_models = {}
+
+ def maybe_load_cached_pipeline(model_name):
+     pipe = cached_models.get(model_name)
+     if pipe is None:
+         # load pipeline
+         pipe = pipeline(
+             task="automatic-speech-recognition",
+             model=model_name,
+             chunk_length_s=CHUNK_LENGTH_S,
+             device=device,
+         )
+         # set forced_decoder_ids
+         pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="fr", task="transcribe")
+
+         logger.info(f"`{model_name}` pipeline has been initialized")
+
+         cached_models[model_name] = pipe
+     return pipe
+
+
+ def transcribe(microphone, file_upload, model_name):
+     warn_output = ""
+     if (microphone is not None) and (file_upload is not None):
+         warn_output = (
+             "WARNING: You've uploaded an audio file and used the microphone. "
+             "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
+         )
+
+     elif (microphone is None) and (file_upload is None):
+         return "ERROR: You have to either use the microphone or upload an audio file"
+
+     file = microphone if microphone is not None else file_upload
+
+     pipe = maybe_load_cached_pipeline(model_name)
+     text = pipe(file)["text"]
+
+     logger.info(f"Transcription: {text}")
+
+     return warn_output + text
+
+
+ def _return_yt_html_embed(yt_url):
+     video_id = yt_url.split("?v=")[-1]
+     HTML_str = (
+         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
+         " </center>"
+     )
+     return HTML_str
+
+
+ def yt_transcribe(yt_url, model_name):
+     yt = pt.YouTube(yt_url)
+     html_embed_str = _return_yt_html_embed(yt_url)
+     stream = yt.streams.filter(only_audio=True)[0]
+     stream.download(filename="audio.mp3")
+
+     pipe = maybe_load_cached_pipeline(model_name)
+     text = pipe("audio.mp3")["text"]
+
+     logger.info(f"Transcription: {text}")
+
+     return html_embed_str, text
+
+
+ # load default model
+ maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)
+
+ demo = gr.Blocks()
+
+ mf_transcribe = gr.Interface(
+     fn=transcribe,
+     inputs=[
+         gr.inputs.Audio(source="microphone", type="filepath", optional=True),
+         gr.inputs.Audio(source="upload", type="filepath", optional=True),
+         gr.inputs.Dropdown(choices=MODEL_NAMES, default=DEFAULT_MODEL_NAME, label="Whisper Model"),
+     ],
+     outputs="text",
+     layout="horizontal",
+     theme="huggingface",
+     title="Whisper Demo: Transcribe Audio",
+     description=(
+         "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the fine-tuned"
+         f" checkpoint [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
+         " of arbitrary length."
+     ),
+     allow_flagging="never",
+ )
+
+ yt_transcribe = gr.Interface(
+     fn=yt_transcribe,
+     inputs=[
+         gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
+         gr.inputs.Dropdown(choices=MODEL_NAMES, default=DEFAULT_MODEL_NAME, label="Whisper Model"),
+     ],
+     outputs=["html", "text"],
+     layout="horizontal",
+     theme="huggingface",
+     title="Whisper Demo: Transcribe YouTube",
+     description=(
+         "Transcribe long-form YouTube videos with the click of a button! Demo uses the fine-tuned checkpoint:"
+         f" [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files of"
+         " arbitrary length."
+     ),
+     allow_flagging="never",
+ )
+
+ with demo:
+     gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Transcribe Audio", "Transcribe YouTube"])
+
+ # demo.launch(server_name="0.0.0.0", debug=True, share=True)
+ demo.launch(enable_queue=True)
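
The new script keeps one speech-recognition pipeline per model name in a module-level dict, so switching models in the dropdown only pays the checkpoint-loading cost once per process. Below is a minimal standalone sketch of that caching pattern, assuming torch and transformers are installed; the model names mirror the diff above, and the audio.wav path is a hypothetical placeholder rather than part of the commit.

import torch
from transformers import pipeline

MODEL_NAMES = [
    "bhuang/whisper-small-cv11-french",
    "bhuang/whisper-medium-cv11-french-case-punctuation",
]
device = 0 if torch.cuda.is_available() else "cpu"
cached_models = {}


def maybe_load_cached_pipeline(model_name):
    # Build the ASR pipeline on first use, then reuse it for later requests.
    pipe = cached_models.get(model_name)
    if pipe is None:
        pipe = pipeline(
            task="automatic-speech-recognition",
            model=model_name,
            chunk_length_s=30,
            device=device,
        )
        # Force French transcription regardless of the detected language.
        pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(
            language="fr", task="transcribe"
        )
        cached_models[model_name] = pipe
    return pipe


if __name__ == "__main__":
    pipe = maybe_load_cached_pipeline(MODEL_NAMES[0])
    # A second lookup returns the same object without reloading the checkpoint.
    assert maybe_load_cached_pipeline(MODEL_NAMES[0]) is pipe
    # print(pipe("audio.wav")["text"])  # hypothetical local audio file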