sanchit-gandhi HF staff commited on
Commit
afda370
β€’
1 Parent(s): 3d06185

Update README.md (#144)

Browse files

- Update README.md (29aee6b8a7c9c86b46f63bf6fc2331151935026b)

Files changed (1) hide show
  1. README.md +136 -165
README.md CHANGED
@@ -114,68 +114,39 @@ license: apache-2.0
114
 
115
  # Whisper
116
 
117
- Whisper is a pre-trained model for automatic speech recognition (ASR) and speech translation. Trained on 680k hours
118
- of labelled data, Whisper models demonstrate a strong ability to generalise to many datasets and domains **without** the need
119
- for fine-tuning.
 
120
 
121
- Whisper was proposed in the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://arxiv.org/abs/2212.04356)
122
- by Alec Radford et al. from OpenAI. The original code repository can be found [here](https://github.com/openai/whisper).
123
 
124
- Whisper `large-v3` has the same architecture as the previous large models except the following minor differences:
125
-
126
- 1. The input uses 128 Mel frequency bins instead of 80
127
  2. A new language token for Cantonese
128
 
129
- The Whisper `large-v3` model is trained on 1 million hours of weakly labeled audio and 4 million hours of pseudolabeled audio collected using Whisper `large-v2`.
130
- The model was trained for 2.0 epochs over this mixture dataset.
131
-
132
- The `large-v3` model shows improved performance over a wide variety of languages, showing 10% to 20% reduction of errors compared to Whisper `large-v2`.
133
-
134
-
135
- **Disclaimer**: Content for this model card has partly been written by the Hugging Face team, and parts of it were
136
- copied and pasted from the original model card.
137
-
138
- ## Model details
139
-
140
- Whisper is a Transformer based encoder-decoder model, also referred to as a _sequence-to-sequence_ model.
141
- It was trained on 1 million hours of weakly labeled audio and 4 million hours of pseudolabeled audio collected using Whisper `large-v2`.
142
-
143
- The models were trained on either English-only data or multilingual data. The English-only models were trained
144
- on the task of speech recognition. The multilingual models were trained on both speech recognition and speech
145
- translation. For speech recognition, the model predicts transcriptions in the *same* language as the audio.
146
- For speech translation, the model predicts transcriptions to a *different* language to the audio.
147
 
148
- Whisper checkpoints come in five configurations of varying model sizes.
149
- The smallest four are trained on either English-only or multilingual data.
150
- The largest checkpoints are multilingual only. All ten of the pre-trained checkpoints
151
- are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The
152
- checkpoints are summarised in the following table with links to the models on the Hub:
153
 
154
- | Size | Parameters | English-only | Multilingual |
155
- |----------|------------|------------------------------------------------------|-----------------------------------------------------|
156
- | tiny | 39 M | [βœ“](https://huggingface.co/openai/whisper-tiny.en) | [βœ“](https://huggingface.co/openai/whisper-tiny) |
157
- | base | 74 M | [βœ“](https://huggingface.co/openai/whisper-base.en) | [βœ“](https://huggingface.co/openai/whisper-base) |
158
- | small | 244 M | [βœ“](https://huggingface.co/openai/whisper-small.en) | [βœ“](https://huggingface.co/openai/whisper-small) |
159
- | medium | 769 M | [βœ“](https://huggingface.co/openai/whisper-medium.en) | [βœ“](https://huggingface.co/openai/whisper-medium) |
160
- | large | 1550 M | x | [βœ“](https://huggingface.co/openai/whisper-large) |
161
- | large-v2 | 1550 M | x | [βœ“](https://huggingface.co/openai/whisper-large-v2) |
162
- | large-v3 | 1550 M | x | [βœ“](https://huggingface.co/openai/whisper-large-v3) |
163
 
164
  ## Usage
165
 
166
- Whisper `large-v3` is supported in Hugging Face πŸ€— Transformers. To run the model, first
167
- install the Transformers library through the GitHub repo. For this example, we'll also install πŸ€— Datasets to load toy
168
- audio dataset from the Hugging Face Hub:
169
 
170
  ```bash
171
  pip install --upgrade pip
172
- pip install --upgrade git+https://github.com/huggingface/transformers.git accelerate datasets[audio]
173
  ```
174
 
175
- ### Short-Form Transcription
176
-
177
  The model can be used with the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline)
178
- class to transcribe short-form audio files (< 30-seconds) as follows:
179
 
180
  ```python
181
  import torch
@@ -200,10 +171,6 @@ pipe = pipeline(
200
  model=model,
201
  tokenizer=processor.tokenizer,
202
  feature_extractor=processor.feature_extractor,
203
- max_new_tokens=128,
204
- chunk_length_s=30,
205
- batch_size=16,
206
- return_timestamps=True,
207
  torch_dtype=torch_dtype,
208
  device=device,
209
  )
@@ -216,9 +183,33 @@ print(result["text"])
216
  ```
217
 
218
  To transcribe a local audio file, simply pass the path to your audio file when you call the pipeline:
219
- ```diff
220
- - result = pipe(sample)
221
- + result = pipe("audio.mp3")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  ```
223
 
224
  Whisper predicts the language of the source audio automatically. If the source audio language is known *a-priori*, it
@@ -261,10 +252,6 @@ print(result["chunks"])
261
 
262
  <summary> For more control over the generation parameters, use the model + processor API directly: </summary>
263
 
264
- Ad-hoc generation arguments can be passed to `model.generate`, including `num_beams` for beam-search, `return_timestamps`
265
- for segment-level timestamps, and `prompt_ids` for prompting. See the [docstrings](https://huggingface.co/docs/transformers/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
266
- for more details.
267
-
268
  ```python
269
  import torch
270
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
@@ -277,7 +264,7 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
277
  model_id = "openai/whisper-large-v3"
278
 
279
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
280
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
281
  )
282
  model.to(device)
283
 
@@ -287,37 +274,58 @@ dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", spl
287
  dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
288
  sample = dataset[0]["audio"]
289
 
290
- input_features = processor(
291
- sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
292
- ).input_features
293
-
294
- input_features = input_features.to(device, dtype=torch_dtype)
 
 
 
 
295
 
296
  gen_kwargs = {
297
- "max_new_tokens": 128,
298
- "num_beams": 1,
299
- "return_timestamps": False,
 
 
 
 
 
300
  }
301
 
302
- pred_ids = model.generate(input_features, **gen_kwargs)
303
- pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=gen_kwargs["return_timestamps"])
304
 
305
  print(pred_text)
306
  ```
307
 
308
  </details>
309
 
310
- ### Sequential Long-Form
311
 
312
- This algorithm uses a sliding window for buffered inference of long audio files (> 30-seconds),
313
- and returns more accurate transcriptions compared to the [chunked long-form algorithm](#chunked-long-form).
 
 
 
 
 
 
 
314
 
315
  The sequential long-form algorithm should be used in either of the following scenarios:
316
- 1. Transcription accuracy is the most important factor, and latency is less of a consideration
317
  2. You are transcribing **batches** of long audio files, in which case the latency of sequential is comparable to chunked, while being up to 0.5% WER more accurate
318
 
319
- The [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline)
320
- class can be used to transcribe long audio files with the sequential algorithm as follows:
 
 
 
 
 
321
 
322
  ```python
323
  import torch
@@ -331,7 +339,7 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
331
  model_id = "openai/whisper-large-v3"
332
 
333
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
334
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
335
  )
336
  model.to(device)
337
 
@@ -342,7 +350,8 @@ pipe = pipeline(
342
  model=model,
343
  tokenizer=processor.tokenizer,
344
  feature_extractor=processor.feature_extractor,
345
- max_new_tokens=128,
 
346
  torch_dtype=torch_dtype,
347
  device=device,
348
  )
@@ -354,76 +363,21 @@ result = pipe(sample)
354
  print(result["text"])
355
  ```
356
 
357
- <details>
358
-
359
- <summary> For more control over the generation parameters, use the model + processor API directly: </summary>
360
-
361
- ```python
362
- import torch
363
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
364
- from datasets import Audio, load_dataset
365
-
366
-
367
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
368
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
369
-
370
- model_id = "openai/whisper-large-v3"
371
-
372
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
373
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
374
- )
375
- model.to(device)
376
-
377
- processor = AutoProcessor.from_pretrained(model_id)
378
-
379
- dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
380
- dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
381
- sample = dataset[0]["audio"]
382
-
383
- inputs = processor(
384
- sample["array"],
385
- sampling_rate=sample["sampling_rate"],
386
- return_tensors="pt",
387
- truncation=False,
388
- padding="longest",
389
- return_attention_mask=True,
390
- )
391
- inputs = inputs.to(device, dtype=torch_dtype)
392
-
393
- gen_kwargs = {
394
- "max_new_tokens": 448,
395
- "num_beams": 1,
396
- "condition_on_prev_tokens": False,
397
- "compression_ratio_threshold": 1.35, # zlib compression ratio threshold (in token space)
398
- "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
399
- "logprob_threshold": -1.0,
400
- "no_speech_threshold": 0.6,
401
- "return_timestamps": True,
402
- }
403
-
404
- pred_ids = model.generate(**i nputs, **gen_kwargs)
405
- pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=False)
406
-
407
- print(pred_text)
408
- ```
409
-
410
- </details>
411
-
412
- ### Chunked Long-Form
413
 
414
- large-v3 remains compatible with the Transformers chunked long-form algorithm. This algorithm should be used when
415
- a single large audio file is being transcribed and the fastest possible inference is required. In such circumstances,
416
- the chunked algorithm is up to 9x faster than OpenAI's sequential long-form implementation (see Table 7 of the
417
- [Distil-Whisper paper](https://arxiv.org/pdf/2311.00430.pdf)).
418
 
419
- To enable chunking, pass the `chunk_length_s` parameter to the `pipeline`. For distil-large-v3, a chunk length of 25-seconds
420
- is optimal. To activate batching over long audio files, pass the argument `batch_size`:
421
 
422
  ```python
423
  import torch
 
424
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
425
  from datasets import load_dataset
 
426
 
 
427
 
428
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
429
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -431,9 +385,12 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
431
  model_id = "openai/whisper-large-v3"
432
 
433
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
434
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
435
- )
436
- model.to(device)
 
 
 
437
 
438
  processor = AutoProcessor.from_pretrained(model_id)
439
 
@@ -442,9 +399,6 @@ pipe = pipeline(
442
  model=model,
443
  tokenizer=processor.tokenizer,
444
  feature_extractor=processor.feature_extractor,
445
- max_new_tokens=128,
446
- chunk_length_s=25,
447
- batch_size=16,
448
  torch_dtype=torch_dtype,
449
  device=device,
450
  )
@@ -452,20 +406,22 @@ pipe = pipeline(
452
  dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
453
  sample = dataset[0]["audio"]
454
 
455
- result = pipe(sample)
456
- print(result["text"])
457
- ```
 
458
 
459
- ### Additional Speed & Memory Improvements
 
 
460
 
461
- You can apply additional speed and memory improvements to Distil-Whisper to further reduce the inference speed and VRAM
462
- requirements. These optimisations primarily target the attention kernel, swapping it from an eager implementation to a
463
- more efficient flash attention version.
464
 
465
  #### Flash Attention 2
466
 
467
- We recommend using [Flash-Attention 2](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#flashattention-2)
468
- if your GPU allows for it. To do so, you first need to install [Flash Attention](https://github.com/Dao-AILab/flash-attention):
469
 
470
  ```
471
  pip install flash-attn --no-build-isolation
@@ -473,9 +429,8 @@ pip install flash-attn --no-build-isolation
473
 
474
  Then pass `attn_implementation="flash_attention_2"` to `from_pretrained`:
475
 
476
- ```diff
477
- - model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
478
- + model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2")
479
  ```
480
 
481
  #### Torch Scale-Product-Attention (SDPA)
@@ -496,20 +451,36 @@ returns `False`, you need to upgrade your PyTorch version according to the [offi
496
  Once a valid PyTorch version is installed, SDPA is activated by default. It can also be set explicitly by specifying
497
  `attn_implementation="sdpa"` as follows:
498
 
499
- ```diff
500
- - model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
501
- + model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
502
  ```
503
 
504
  For more information about how to use the SDPA refer to the [Transformers SDPA documentation](https://huggingface.co/docs/transformers/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention).
505
 
506
- #### Torch compile
507
 
508
- Coming soon...
 
 
 
 
 
 
 
 
 
 
 
509
 
510
- #### 4-bit and 8-bit Inference
 
 
 
 
 
 
 
 
511
 
512
- Coming soon...
513
 
514
  ## Fine-Tuning
515
 
@@ -529,7 +500,7 @@ In particular, we caution against using Whisper models to transcribe recordings
529
 
530
  ## Training Data
531
 
532
- The models are trained on 1 million hours of weakly labeled audio and 4 million hours of pseudolabeled audio collected using Whisper `large-v2`.
533
 
534
  As discussed in [the accompanying paper](https://cdn.openai.com/papers/whisper.pdf), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
535
 
 
114
 
115
  # Whisper
116
 
117
+ Whisper is a state-of-the-art model for automatic speech recognition (ASR) and speech translation, proposed in the paper
118
+ [Robust Speech Recognition via Large-Scale Weak Supervision](https://huggingface.co/papers/2212.04356) by Alec Radford
119
+ et al. from OpenAI. Trained on >5M hours of labeled data, Whisper demonstrates a strong ability to generalise to many
120
+ datasets and domains in a zero-shot setting.
121
 
122
+ Whisper large-v3 has the same architecture as the previous [large](https://huggingface.co/openai/whisper-large) and [large-v2](https://huggingface.co/openai/whisper-large-v2)
123
+ models, except for the following minor differences:
124
 
125
+ 1. The spectrogram input uses 128 Mel frequency bins instead of 80
 
 
126
  2. A new language token for Cantonese
127
 
128
+ The Whisper large-v3 model was trained on 1 million hours of weakly labeled audio and 4 million hours of pseudo-labeled
129
+ audio collected using Whisper [large-v2](https://huggingface.co/openai/whisper-large-v2) . The model was trained for 2.0 epochs over this mixture dataset.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ The large-v3 model shows improved performance over a wide variety of languages, showing 10% to 20% reduction of errors
132
+ compared to Whisper [large-v2](https://huggingface.co/openai/whisper-large-v2) . For more details on the different checkpoints available, refer to the section [Model details](#model-details).
 
 
 
133
 
134
+ **Disclaimer**: Content for this model card has partly been written by the πŸ€— Hugging Face team, and partly copied and
135
+ pasted from the original model card.
 
 
 
 
 
 
 
136
 
137
  ## Usage
138
 
139
+ Whisper large-v3 is supported in Hugging Face πŸ€— Transformers. To run the model, first install the Transformers
140
+ library. For this example, we'll also install πŸ€— Datasets to load toy audio dataset from the Hugging Face Hub, and
141
+ πŸ€— Accelerate to reduce the model loading time:
142
 
143
  ```bash
144
  pip install --upgrade pip
145
+ pip install --upgrade transformers datasets[audio] accelerate
146
  ```
147
 
 
 
148
  The model can be used with the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline)
149
+ class to transcribe audios of arbitrary length:
150
 
151
  ```python
152
  import torch
 
171
  model=model,
172
  tokenizer=processor.tokenizer,
173
  feature_extractor=processor.feature_extractor,
 
 
 
 
174
  torch_dtype=torch_dtype,
175
  device=device,
176
  )
 
183
  ```
184
 
185
  To transcribe a local audio file, simply pass the path to your audio file when you call the pipeline:
186
+
187
+ ```python
188
+ result = pipe("audio.mp3")
189
+ ```
190
+
191
+ Multiple audio files can be transcribed in parallel by specifying them as a list and setting the `batch_size` parameter:
192
+
193
+ ```python
194
+ result = pipe(["audio_1.mp3", "audio_2.mp3"], batch_size=2)
195
+ ```
196
+
197
+ Transformers is compatible with all Whisper decoding strategies, such as temperature fallback and condition on previous
198
+ tokens. The following example demonstrates how to enable these heuristics:
199
+
200
+ ```python
201
+ generate_kwargs = {
202
+ "max_new_tokens": 448,
203
+ "num_beams": 1,
204
+ "condition_on_prev_tokens": False,
205
+ "compression_ratio_threshold": 1.35, # zlib compression ratio threshold (in token space)
206
+ "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
207
+ "logprob_threshold": -1.0,
208
+ "no_speech_threshold": 0.6,
209
+ "return_timestamps": True,
210
+ }
211
+
212
+ result = pipe(sample, generate_kwargs=generate_kwargs)
213
  ```
214
 
215
  Whisper predicts the language of the source audio automatically. If the source audio language is known *a-priori*, it
 
252
 
253
  <summary> For more control over the generation parameters, use the model + processor API directly: </summary>
254
 
 
 
 
 
255
  ```python
256
  import torch
257
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
 
264
  model_id = "openai/whisper-large-v3"
265
 
266
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
267
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
268
  )
269
  model.to(device)
270
 
 
274
  dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
275
  sample = dataset[0]["audio"]
276
 
277
+ inputs = processor(
278
+ sample["array"],
279
+ sampling_rate=sample["sampling_rate"],
280
+ return_tensors="pt",
281
+ truncation=False,
282
+ padding="longest",
283
+ return_attention_mask=True,
284
+ )
285
+ inputs = inputs.to(device, dtype=torch_dtype)
286
 
287
  gen_kwargs = {
288
+ "max_new_tokens": 448,
289
+ "num_beams": 1,
290
+ "condition_on_prev_tokens": False,
291
+ "compression_ratio_threshold": 1.35, # zlib compression ratio threshold (in token space)
292
+ "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
293
+ "logprob_threshold": -1.0,
294
+ "no_speech_threshold": 0.6,
295
+ "return_timestamps": True,
296
  }
297
 
298
+ pred_ids = model.generate(**inputs, **gen_kwargs)
299
+ pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=False)
300
 
301
  print(pred_text)
302
  ```
303
 
304
  </details>
305
 
306
+ ## Additional Speed & Memory Improvements
307
 
308
+ You can apply additional speed and memory improvements to Whisper to further reduce the inference speed and VRAM
309
+ requirements.
310
+
311
+ ### Chunked Long-Form
312
+
313
+ Whisper has a receptive field of 30-seconds. To transcribe audios longer than this, one of two long-form algorithms are
314
+ required:
315
+ 1. **Sequential:** uses a "sliding window" for buffered inference, transcribing 30-second slices one after the other
316
+ 2. **Chunked:** splits long audio files into shorter ones (with a small overlap between segments), transcribes each segment independently, and stitches the resulting transcriptions at the boundaries
317
 
318
  The sequential long-form algorithm should be used in either of the following scenarios:
319
+ 1. Transcription accuracy is the most important factor, and speed is less of a consideration
320
  2. You are transcribing **batches** of long audio files, in which case the latency of sequential is comparable to chunked, while being up to 0.5% WER more accurate
321
 
322
+ Conversely, the chunked algorithm should be used when:
323
+ 1. Transcription speed is the most important factor
324
+ 2. You are transcribing a **single** long audio file
325
+
326
+ By default, Transformers uses the sequential algorithm. To enable the chunked algorithm, pass the `chunk_length_s`
327
+ parameter to the `pipeline`. For large-v3, a chunk length of 30-seconds is optimal. To activate batching over long
328
+ audio files, pass the argument `batch_size`:
329
 
330
  ```python
331
  import torch
 
339
  model_id = "openai/whisper-large-v3"
340
 
341
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
342
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
343
  )
344
  model.to(device)
345
 
 
350
  model=model,
351
  tokenizer=processor.tokenizer,
352
  feature_extractor=processor.feature_extractor,
353
+ chunk_length_s=30,
354
+ batch_size=16, # batch size for inference - set based on your device
355
  torch_dtype=torch_dtype,
356
  device=device,
357
  )
 
363
  print(result["text"])
364
  ```
365
 
366
+ #### Torch compile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
+ The Whisper forward pass is compatible with [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html)
369
+ for 4.5x speed-ups.
 
 
370
 
371
+ **Note:** `torch.compile` is currently not compatible with the Chunked long-form algorithm or Flash Attention 2 ⚠️
 
372
 
373
  ```python
374
  import torch
375
+ from torch.nn.attention import SDPBackend, sdpa_kernel
376
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
377
  from datasets import load_dataset
378
+ from tqdm import tqdm
379
 
380
+ torch.set_float32_matmul_precision("high")
381
 
382
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
383
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
385
  model_id = "openai/whisper-large-v3"
386
 
387
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
388
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
389
+ ).to(device)
390
+
391
+ # Enable static cache and compile the forward pass
392
+ model.generation_config.cache_implementation = "static"
393
+ model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
394
 
395
  processor = AutoProcessor.from_pretrained(model_id)
396
 
 
399
  model=model,
400
  tokenizer=processor.tokenizer,
401
  feature_extractor=processor.feature_extractor,
 
 
 
402
  torch_dtype=torch_dtype,
403
  device=device,
404
  )
 
406
  dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
407
  sample = dataset[0]["audio"]
408
 
409
+ # 2 warmup steps
410
+ for _ in tqdm(range(2), desc="Warm-up step"):
411
+ with sdpa_kernel(SDPBackend.MATH):
412
+ result = pipe(sample.copy())
413
 
414
+ # fast run
415
+ with sdpa_kernel(SDPBackend.MATH):
416
+ result = pipe(sample.copy())
417
 
418
+ print(result["text"])
419
+ ```
 
420
 
421
  #### Flash Attention 2
422
 
423
+ We recommend using [Flash-Attention 2](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#flashattention-2) if your GPU supports it and you are not using [torch.compile](#torch-compile).
424
+ To do so, first install [Flash Attention](https://github.com/Dao-AILab/flash-attention):
425
 
426
  ```
427
  pip install flash-attn --no-build-isolation
 
429
 
430
  Then pass `attn_implementation="flash_attention_2"` to `from_pretrained`:
431
 
432
+ ```python
433
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, attn_implementation="flash_attention_2")
 
434
  ```
435
 
436
  #### Torch Scale-Product-Attention (SDPA)
 
451
  Once a valid PyTorch version is installed, SDPA is activated by default. It can also be set explicitly by specifying
452
  `attn_implementation="sdpa"` as follows:
453
 
454
+ ```python
455
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, attn_implementation="sdpa")
 
456
  ```
457
 
458
  For more information about how to use the SDPA refer to the [Transformers SDPA documentation](https://huggingface.co/docs/transformers/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention).
459
 
 
460
 
461
+ ## Model details
462
+
463
+ Whisper is a Transformer based encoder-decoder model, also referred to as a _sequence-to-sequence_ model. There are two
464
+ flavours of Whisper model: English-only and multilingual. The English-only models were trained on the task of English
465
+ speech recognition. The multilingual models were trained simultaneously on multilingual speech recognition and speech
466
+ translation. For speech recognition, the model predicts transcriptions in the *same* language as the audio. For speech
467
+ translation, the model predicts transcriptions to a *different* language to the audio.
468
+
469
+ Whisper checkpoints come in five configurations of varying model sizes. The smallest four are available as English-only
470
+ and multilingual. The largest checkpoints are multilingual only. All ten of the pre-trained checkpoints
471
+ are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The
472
+ checkpoints are summarised in the following table with links to the models on the Hub:
473
 
474
+ | Size | Parameters | English-only | Multilingual |
475
+ |----------|------------|------------------------------------------------------|-----------------------------------------------------|
476
+ | tiny | 39 M | [βœ“](https://huggingface.co/openai/whisper-tiny.en) | [βœ“](https://huggingface.co/openai/whisper-tiny) |
477
+ | base | 74 M | [βœ“](https://huggingface.co/openai/whisper-base.en) | [βœ“](https://huggingface.co/openai/whisper-base) |
478
+ | small | 244 M | [βœ“](https://huggingface.co/openai/whisper-small.en) | [βœ“](https://huggingface.co/openai/whisper-small) |
479
+ | medium | 769 M | [βœ“](https://huggingface.co/openai/whisper-medium.en) | [βœ“](https://huggingface.co/openai/whisper-medium) |
480
+ | large | 1550 M | x | [βœ“](https://huggingface.co/openai/whisper-large) |
481
+ | large-v2 | 1550 M | x | [βœ“](https://huggingface.co/openai/whisper-large-v2) |
482
+ | large-v3 | 1550 M | x | [βœ“](https://huggingface.co/openai/whisper-large-v3) |
483
 
 
484
 
485
  ## Fine-Tuning
486
 
 
500
 
501
  ## Training Data
502
 
503
+ The large-v3 checkpoint is trained on 1 million hours of weakly labeled audio and 4 million hours of pseudo-labeled audio collected using Whisper large-v2.
504
 
505
  As discussed in [the accompanying paper](https://cdn.openai.com/papers/whisper.pdf), we see that performance on transcription in a given language is directly correlated with the amount of training data we employ in that language.
506