patrickvonplaten commited on
Commit
81a9d24
1 Parent(s): 4487a27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -20
app.py CHANGED
@@ -13,10 +13,10 @@ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
13
  use_flash_attention_2 = is_flash_attn_2_available()
14
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
- "openai/whisper-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
17
  )
18
  distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
19
- "distil-whisper/distil-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
20
  )
21
 
22
  if not use_flash_attention_2:
@@ -24,7 +24,7 @@ if not use_flash_attention_2:
24
  model = model.to_bettertransformer()
25
  distilled_model = distilled_model.to_bettertransformer()
26
 
27
- processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
28
 
29
  model.to(device)
30
  distilled_model.to(device)
@@ -38,7 +38,7 @@ pipe = pipeline(
38
  chunk_length_s=30,
39
  torch_dtype=torch_dtype,
40
  device=device,
41
- generate_kwargs={"language": "en", "task": "transcribe"},
42
  return_timestamps=True
43
  )
44
  pipe_forward = pipe._forward
@@ -52,7 +52,7 @@ distil_pipe = pipeline(
52
  chunk_length_s=15,
53
  torch_dtype=torch_dtype,
54
  device=device,
55
- generate_kwargs={"language": "en", "task": "transcribe"},
56
  )
57
  distil_pipe_forward = distil_pipe._forward
58
 
@@ -110,7 +110,7 @@ if __name__ == "__main__":
110
  "
111
  >
112
  <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
113
- Whisper vs Distil-Whisper: Speed Comparison
114
  </h1>
115
  </div>
116
  </div>
@@ -133,22 +133,11 @@ if __name__ == "__main__":
133
  audio = gr.components.Audio(type="filepath", label="Audio input")
134
  button = gr.Button("Transcribe")
135
  with gr.Row():
136
- distil_runtime = gr.components.Textbox(label="Distil-Whisper Transcription Time (s)")
137
- runtime = gr.components.Textbox(label="Whisper Transcription Time (s)")
138
- with gr.Row():
139
- distil_transcription = gr.components.Textbox(label="Distil-Whisper Transcription", show_copy_button=True)
140
- transcription = gr.components.Textbox(label="Whisper Transcription", show_copy_button=True)
141
  button.click(
142
  fn=transcribe,
143
  inputs=audio,
144
- outputs=[distil_transcription, distil_runtime, transcription, runtime],
145
- )
146
- gr.Markdown("## Examples")
147
- gr.Examples(
148
- [["./assets/example_1.wav"], ["./assets/example_2.wav"]],
149
- audio,
150
- outputs=[distil_transcription, distil_runtime, transcription, runtime],
151
- fn=transcribe,
152
- cache_examples=False,
153
  )
154
  demo.queue(max_size=10).launch()
 
13
  use_flash_attention_2 = is_flash_attn_2_available()
14
 
15
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
16
+ "openai/whisper-large-v3", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
17
  )
18
  distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
19
+ "primeline/whisper-large-v3-german", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=False, use_flash_attention_2=use_flash_attention_2
20
  )
21
 
22
  if not use_flash_attention_2:
 
24
  model = model.to_bettertransformer()
25
  distilled_model = distilled_model.to_bettertransformer()
26
 
27
+ processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
28
 
29
  model.to(device)
30
  distilled_model.to(device)
 
38
  chunk_length_s=30,
39
  torch_dtype=torch_dtype,
40
  device=device,
41
+ generate_kwargs={"language": "de", "task": "transcribe"},
42
  return_timestamps=True
43
  )
44
  pipe_forward = pipe._forward
 
52
  chunk_length_s=15,
53
  torch_dtype=torch_dtype,
54
  device=device,
55
+ generate_kwargs={"language": "de", "task": "transcribe"},
56
  )
57
  distil_pipe_forward = distil_pipe._forward
58
 
 
110
  "
111
  >
112
  <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
113
+ Whisper-v3 vs Whisper-v3-German
114
  </h1>
115
  </div>
116
  </div>
 
133
  audio = gr.components.Audio(type="filepath", label="Audio input")
134
  button = gr.Button("Transcribe")
135
  with gr.Row():
136
+ distil_transcription = gr.components.Textbox(label="Whisper-v3-German Transcription", show_copy_button=True)
137
+ transcription = gr.components.Textbox(label="Whisper-v3 Transcription", show_copy_button=True)
 
 
 
138
  button.click(
139
  fn=transcribe,
140
  inputs=audio,
141
+ outputs=[distil_transcription, transcription],
 
 
 
 
 
 
 
 
142
  )
143
  demo.queue(max_size=10).launch()