alexanderander30 committed on
Commit
32d353a
·
verified ·
1 Parent(s): 7d2b707

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +13 -0
  2. app.py +260 -0
  3. gitattributes +35 -0
  4. requirements.txt +8 -0
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Whisper JAX
3
+ emoji: 👀
4
+ colorFrom: yellow
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import os
4
+ import tempfile
5
+ import time
6
+
7
+ import gradio as gr
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+ import yt_dlp as youtube_dl
11
+ from jax.experimental.compilation_cache import compilation_cache as cc
12
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
13
+ from transformers.pipelines.audio_utils import ffmpeg_read
14
+
15
+ from whisper_jax import FlaxWhisperPipline
16
+
17
+
18
# Point JAX's compilation cache at a directory next to the app.
cc.initialize_cache("./jax_cache")
checkpoint = "openai/whisper-large-v3"

BATCH_SIZE = 32  # batch size given to the pipeline and to every forward call
CHUNK_LENGTH_S = 30  # chunk length, in seconds, used when splitting audio
NUM_PROC = 32  # NOTE(review): not referenced in the visible part of this file
FILE_LIMIT_MB = 1000  # uploaded files above this size are rejected
YT_LENGTH_LIMIT_S = 7200  # limit to 2 hour YouTube files

# Text shown in the Gradio UI.
title = "Whisper JAX: The Fastest Whisper API ⚡️"

description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v3) by OpenAI. This demo is running on JAX with a TPU v5e backend. Compared to PyTorch on an A100 GPU, it is over [**70x faster**](https://github.com/sanchit-gandhi/whisper-jax#benchmarks), making it the fastest Whisper API available.

Note that at peak times, you may find yourself in the queue for this demo. When you submit a request, your queue position will be shown in the top right-hand side of the demo pane. Once you reach the front of the queue, your audio file will be transcribed, with the progress displayed through a progress bar.

To skip the queue, you may wish to create your own inference endpoint by duplicating the demo, details for which can be found in the [Whisper JAX repository](https://github.com/sanchit-gandhi/whisper-jax#creating-an-endpoint).
"""

article = "Whisper large-v3 model by OpenAI. Backend running JAX on a TPU v5e directly through Hugging Face Spaces. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."

# Alphabetically sorted language names taken from the Whisper tokenizer table.
language_names = sorted(TO_LANGUAGE_CODE)

# Application logger: INFO and above go to a stream handler, formatted as
# "<timestamp>;<level>;<message>".
formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
ch.setFormatter(formatter)

logger = logging.getLogger("whisper-jax-app")
logger.setLevel(logging.INFO)
logger.addHandler(ch)
47
+
48
+
49
def identity(batch):
    """Return ``batch`` exactly as received (a no-op pass-through)."""
    unchanged = batch
    return unchanged
51
+
52
+
53
# Adapted from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """Format a time offset given in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
        seconds: Offset in seconds, or ``None`` for a missing timestamp.
        always_include_hours: Emit the ``HH:`` prefix even when the hour count is zero.
        decimal_marker: Separator placed between the seconds and the milliseconds.

    Returns:
        The formatted string, or ``None`` unchanged when ``seconds`` is ``None``.
    """
    if seconds is None:
        # we have a malformed timestamp so just return it as is
        return seconds

    # Work in integer milliseconds and peel off one unit at a time.
    remaining_ms = round(seconds * 1000.0)
    hours, remaining_ms = divmod(remaining_ms, 3_600_000)
    minutes, remaining_ms = divmod(remaining_ms, 60_000)
    whole_seconds, millis = divmod(remaining_ms, 1_000)

    show_hours = always_include_hours or hours > 0
    hours_prefix = f"{hours:02d}:" if show_hours else ""
    return f"{hours_prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
72
+
73
+
74
if __name__ == "__main__":
    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)

    # Chunking geometry, expressed in audio samples: each chunk is
    # CHUNK_LENGTH_S long with a stride of one sixth of that on either side,
    # so consecutive chunks start `step` samples apart.
    sampling_rate = pipeline.feature_extractor.sampling_rate
    stride_length_s = CHUNK_LENGTH_S / 6
    chunk_len = round(CHUNK_LENGTH_S * sampling_rate)
    stride_left = stride_right = round(stride_length_s * sampling_rate)
    step = chunk_len - stride_left - stride_right

    # do a pre-compile step so that the first user to use the demo isn't hit with a long transcription time
    logger.info("compiling forward call...")
    start = time.time()
    model_config = pipeline.model.config
    dummy_shape = (BATCH_SIZE, model_config.num_mel_bins, 2 * model_config.max_source_positions)
    # An all-ones batch is enough here: only the forward call matters, not its output.
    random_inputs = {"input_features": np.ones(dummy_shape)}
    random_timestamps = pipeline.forward(random_inputs, batch_size=BATCH_SIZE, return_timestamps=True)
    compile_time = time.time() - start
    logger.info(f"compiled in {compile_time}s")
92
+
93
def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
    """Transcribe pre-loaded audio batch by batch while driving a Gradio progress bar.

    Args:
        inputs: Dict with the audio under ``"array"``; it is passed as-is to
            ``pipeline.preprocess_batch``.
        task: Task name forwarded to ``pipeline.forward``.
        return_timestamps: When true, the returned text has one
            ``[start -> end] text`` line per chunk instead of the plain transcription.
        progress: Gradio progress tracker used to display per-batch progress.

    Returns:
        ``(text, runtime)`` where ``runtime`` is the wall-clock time in seconds
        spent in the forward passes (post-processing is not included).
    """
    # One chunk starts every `step` samples; the number of batches follows from that.
    total_samples = inputs["array"].shape[0]
    chunk_starts = np.arange(0, total_samples, step)
    batch_count = math.ceil(len(chunk_starts) / BATCH_SIZE)
    # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
    dummy_batches = list(range(batch_count))

    dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
    start_time = time.time()
    logger.info("transcribing...")
    # iterate over our chunked audio samples - always predict timestamps to reduce hallucinations
    model_outputs = [
        pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=True)
        for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing..."))
    ]
    runtime = time.time() - start_time
    logger.info("done transcription")

    logger.info("post-processing...")
    post_processed = pipeline.postprocess(model_outputs, return_timestamps=True)
    text = post_processed["text"]
    if return_timestamps:
        stamped_lines = [
            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in post_processed.get("chunks")
        ]
        text = "\n".join(stamped_lines)
    logger.info("done post-processing")
    return text, runtime
124
+
125
def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
    """Gradio handler: transcribe an uploaded or recorded audio file.

    Args:
        inputs: Path of the audio file, or ``None`` when nothing was submitted.
        task: Task name forwarded to ``tqdm_generate``.
        return_timestamps: Whether the returned text should carry per-chunk timestamps.
        progress: Gradio progress tracker.

    Returns:
        The ``(text, runtime)`` pair produced by ``tqdm_generate``.

    Raises:
        gr.Error: If no file was submitted or the file is larger than ``FILE_LIMIT_MB``.
    """
    progress(0, desc="Loading audio file...")
    logger.info("loading audio file...")

    if inputs is None:
        logger.warning("No audio file")
        raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")

    file_size_mb = os.path.getsize(inputs) / (1024 * 1024)
    if file_size_mb > FILE_LIMIT_MB:
        logger.warning("Max file size exceeded")
        raise gr.Error(
            f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
        )

    with open(inputs, "rb") as audio_file:
        raw_bytes = audio_file.read()

    # Decode at the sampling rate of the pipeline's feature extractor.
    sampling_rate = pipeline.feature_extractor.sampling_rate
    waveform = ffmpeg_read(raw_bytes, sampling_rate)
    audio = {"array": waveform, "sampling_rate": sampling_rate}
    logger.info("done loading")

    text, runtime = tqdm_generate(audio, task=task, return_timestamps=return_timestamps, progress=progress)
    return text, runtime
146
+
147
+ def _return_yt_html_embed(yt_url):
148
+ video_id = yt_url.split("?v=")[-1]
149
+ HTML_str = (
150
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
151
+ " </center>"
152
+ )
153
+ return HTML_str
154
+
155
def _yt_duration_seconds(info):
    """Return the video length in whole seconds, or ``None`` when ``info`` carries no duration."""
    duration = info.get("duration")
    if duration is not None:
        return int(duration)

    duration_string = info.get("duration_string")
    if not duration_string:
        return None
    # "SS", "MM:SS" or "HH:MM:SS": fold the fields left to right in base 60.
    total_s = 0
    for field in duration_string.split(":"):
        total_s = total_s * 60 + int(field)
    return total_s


def download_yt_audio(yt_url, filename):
    """Download the media behind ``yt_url`` to ``filename`` after checking its length.

    Args:
        yt_url: URL of the YouTube video.
        filename: Output path, passed to yt-dlp as its ``outtmpl``.

    Raises:
        gr.Error: If the metadata lookup or the download fails, if the video
            length cannot be determined, or if the video is longer than
            ``YT_LENGTH_LIMIT_S`` seconds.
    """
    # Metadata-only lookup first, so over-long videos are rejected before any download.
    with youtube_dl.YoutubeDL() as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise gr.Error(str(err)) from err

    file_length_s = _yt_duration_seconds(info)
    if file_length_s is None:
        # No duration in the metadata (live streams and playlists, presumably):
        # the limit cannot be enforced, so refuse rather than fail on a missing key.
        raise gr.Error("Could not determine the length of this YouTube video, so it cannot be transcribed.")

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except (youtube_dl.utils.DownloadError, youtube_dl.utils.ExtractorError) as err:
            # the metadata lookup above raises DownloadError, so the download
            # step must handle it too, not only ExtractorError
            raise gr.Error(str(err)) from err
182
+
183
def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress()):
    """Gradio handler: download a YouTube video and transcribe its audio.

    Args:
        yt_url: URL of the YouTube video.
        task: Task name forwarded to ``tqdm_generate``.
        return_timestamps: Whether the returned text should carry per-chunk timestamps.
        progress: Gradio progress tracker.

    Returns:
        ``(html_embed_str, text, runtime)``: the embed snippet for the video
        followed by the pair produced by ``tqdm_generate``.
    """
    progress(0, desc="Loading audio file...")
    logger.info("loading youtube file...")
    html_embed_str = _return_yt_html_embed(yt_url)

    # The download only has to live long enough to be read into memory.
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        download_yt_audio(yt_url, filepath)
        with open(filepath, "rb") as media_file:
            raw_bytes = media_file.read()

    sampling_rate = pipeline.feature_extractor.sampling_rate
    waveform = ffmpeg_read(raw_bytes, sampling_rate)
    audio = {"array": waveform, "sampling_rate": sampling_rate}
    logger.info("done loading...")

    text, runtime = tqdm_generate(audio, task=task, return_timestamps=return_timestamps, progress=progress)
    return html_embed_str, text, runtime
199
+
200
def _task_inputs():
    """Create a fresh "Task" radio and "Return timestamps" checkbox for one tab."""
    return [
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
        gr.Checkbox(value=False, label="Return timestamps"),
    ]


def _transcription_outputs():
    """Create fresh transcription and timing text boxes for one tab."""
    return [
        gr.Textbox(label="Transcription", show_copy_button=True),
        gr.Textbox(label="Transcription Time (s)"),
    ]


# Keyword arguments that are identical for all three interfaces.
_shared_interface_kwargs = {
    "allow_flagging": "never",
    "title": title,
    "description": description,
    "article": article,
}

# Tab 1: record from the microphone.
microphone_chunked = gr.Interface(
    fn=transcribe_chunked_audio,
    inputs=[gr.Audio(sources=["microphone"], type="filepath"), *_task_inputs()],
    outputs=_transcription_outputs(),
    **_shared_interface_kwargs,
)

# Tab 2: upload an audio file.
audio_chunked = gr.Interface(
    fn=transcribe_chunked_audio,
    inputs=[gr.Audio(sources=["upload"], label="Audio file", type="filepath"), *_task_inputs()],
    outputs=_transcription_outputs(),
    **_shared_interface_kwargs,
)

# Tab 3: paste a YouTube URL; the embedded video is shown next to the text.
youtube = gr.Interface(
    fn=transcribe_youtube,
    inputs=[
        gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
        *_task_inputs(),
    ],
    outputs=[gr.HTML(label="Video"), *_transcription_outputs()],
    examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
    cache_examples=False,
    **_shared_interface_kwargs,
)

with gr.Blocks() as demo:
    gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])

# At most 5 requests may wait in the queue at once.
demo.queue(max_size=5)
demo.launch(show_api=False)
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
2
+ jax[tpu]
3
+ torch
4
+ transformers>=4.40.0
5
+ flax
6
+ cached-property
7
+ requests
8
+ yt-dlp>=2023.3.4