bofenghuang commited on
Commit
7c7bb51
·
1 Parent(s): 3621473

add openai version

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. packages.txt +1 -0
  3. run_demo_openai.py +132 -0
app.py CHANGED
@@ -1 +1 @@
1
- run_demo_multi_models.py
 
1
+ run_demo_openai.py
packages.txt CHANGED
@@ -1 +1,2 @@
1
  ffmpeg
 
 
1
  ffmpeg
2
+ git+https://github.com/openai/whisper.git
run_demo_openai.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+
4
+ import gradio as gr
5
+ import pytube as pt
6
+ import torch
7
+ from huggingface_hub import hf_hub_download, model_info
8
+ from transformers.utils.logging import disable_progress_bar
9
+ import whisper
10
+
11
+
12
+ warnings.filterwarnings("ignore")
13
+ disable_progress_bar()
14
+
15
+ MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
16
+ CHECKPOINT_FILENAME = "checkpoint_openai.pt"
17
+
18
+ logging.basicConfig(
19
+ format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
20
+ datefmt="%Y-%m-%dT%H:%M:%SZ",
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+ logger.setLevel(logging.DEBUG)
24
+
25
+ device = 0 if torch.cuda.is_available() else "cpu"
26
+
27
+ downloaded_model_path = hf_hub_download(repo_id=MODEL_NAME, filename=CHECKPOINT_FILENAME)
28
+
29
+ model = whisper.load_model(downloaded_model_path, device=device)
30
+ logger.info(f"Model has been loaded on device `{device}`")
31
+
32
+ gen_kwargs = {
33
+ "task": "transcribe",
34
+ "language": "fr",
35
+ # "without_timestamps": True,
36
+ # decode options
37
+ # "beam_size": 5,
38
+ # "patience": 2,
39
+ # disable fallback
40
+ # "compression_ratio_threshold": None,
41
+ # "logprob_threshold": None,
42
+ # vad threshold
43
+ # "no_speech_threshold": None,
44
+ }
45
+
46
+ def transcribe(microphone, file_upload):
47
+ warn_output = ""
48
+ if (microphone is not None) and (file_upload is not None):
49
+ warn_output = (
50
+ "WARNING: You've uploaded an audio file and used the microphone. "
51
+ "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
52
+ )
53
+
54
+ elif (microphone is None) and (file_upload is None):
55
+ return "ERROR: You have to either use the microphone or upload an audio file"
56
+
57
+ file = microphone if microphone is not None else file_upload
58
+
59
+ text = model.transcribe(file, **gen_kwargs)["text"]
60
+
61
+ logger.info(f"Transcription: {text}")
62
+
63
+ return warn_output + text
64
+
65
+
66
+ def _return_yt_html_embed(yt_url):
67
+ video_id = yt_url.split("?v=")[-1]
68
+ HTML_str = (
69
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
70
+ " </center>"
71
+ )
72
+ return HTML_str
73
+
74
+
75
+ def yt_transcribe(yt_url):
76
+ yt = pt.YouTube(yt_url)
77
+ html_embed_str = _return_yt_html_embed(yt_url)
78
+ stream = yt.streams.filter(only_audio=True)[0]
79
+ stream.download(filename="audio.mp3")
80
+
81
+ text = model.transcribe("audio.mp3", **gen_kwargs)["text"]
82
+
83
+ logger.info(f'Transcription of "{yt_url}": {text}')
84
+
85
+ return html_embed_str, text
86
+
87
+
88
+ demo = gr.Blocks()
89
+
90
+ mf_transcribe = gr.Interface(
91
+ fn=transcribe,
92
+ inputs=[
93
+ gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record"),
94
+ gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload File"),
95
+ ],
96
+ # outputs="text",
97
+ outputs=gr.outputs.Textbox(label="Transcription"),
98
+ layout="horizontal",
99
+ theme="huggingface",
100
+ title="Whisper French Demo 🇫🇷 : Transcribe Audio",
101
+ description=(
102
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the the fine-tuned"
103
+ f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
104
+ " of arbitrary length."
105
+ ),
106
+ allow_flagging="never",
107
+ )
108
+
109
+ yt_transcribe = gr.Interface(
110
+ fn=yt_transcribe,
111
+ inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
112
+ # outputs=["html", "text"],
113
+ outputs=[
114
+ gr.outputs.HTML(label="YouTube Page"),
115
+ gr.outputs.Textbox(label="Transcription"),
116
+ ],
117
+ layout="horizontal",
118
+ theme="huggingface",
119
+ title="Whisper French Demo 🇫🇷 : Transcribe YouTube",
120
+ description=(
121
+ "Transcribe long-form YouTube videos with the click of a button! Demo uses the the fine-tuned checkpoint:"
122
+ f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files of"
123
+ " arbitrary length."
124
+ ),
125
+ allow_flagging="never",
126
+ )
127
+
128
+ with demo:
129
+ gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Transcribe Audio", "Transcribe YouTube"])
130
+
131
+ # demo.launch(server_name="0.0.0.0", debug=True, share=True)
132
+ demo.launch(enable_queue=True)