Maximofn committed on
Commit
149ed58
1 Parent(s): d40da8b

Enhance Whisper transcription with multiple model support and performance improvements


- Add support for Whisper and Distil-Whisper models
- Implement dynamic model selection with performance tracking
- Create transcription_to_dict function for structured transcription parsing
- Add Flash Attention 2 support for optimized inference
- Improve transcription pipeline configuration and timestamp handling
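
For reference, a minimal sketch of how the updated interface could be driven (the module import, audio path, and device string are hypothetical; the signature and the (str, dict) return pair are taken from the diff below):

    import torch
    from transcribe import transcribe  # assumes transcribe.py is importable as a module

    device = "cuda:0" if torch.cuda.is_available() else "cpu"  # hypothetical device choice
    transcription_str, transcription_dict = transcribe(
        "audio/sample.mp3",  # hypothetical input file
        language="en",
        device=device,
    )

    # transcription_to_dict() output shape: {'text': str, 'chunks': [{'start', 'end', 'text'}, ...]}
    for chunk in transcription_dict["chunks"]:
        print(f"[{chunk['start']:.2f}s -> {chunk['end']:.2f}s] {chunk['text']}")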

Files changed (1)
  1. transcribe.py +145 -29
transcribe.py CHANGED
@@ -4,6 +4,12 @@ from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
 from tqdm import tqdm
 import torch
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+from transformers.utils import is_flash_attn_2_available
+from time import time
+
+TRANSCRIPTOR_WHISPER = "openai/whisper-large-v3-turbo"  # Time to transcribe: 296.53 seconds ==> minutes: 4.94
+TRANSCRIPTOR_DISTIL_WHISPER = "distil-whisper/distil-large-v3"  # Time to transcribe: 242.82 seconds ==> minutes: 4.05
+TRANSCRIPTOR = TRANSCRIPTOR_DISTIL_WHISPER
 
 
 def get_language_dict():
@@ -22,6 +28,59 @@ def get_language_dict():
     }
     return language_dict
 
+def transcription_to_dict(transcription):
+    """
+    Convert a transcription string into a structured dictionary.
+
+    Args:
+        transcription (str): String containing the transcription with timestamps
+
+    Returns:
+        dict: Dictionary with the full text and the chunks with their timestamps
+    """
+    try:
+        # If the input is a string, convert it to a dictionary
+        if isinstance(transcription, str):
+            # Evaluate the string as a Python dictionary
+            transcription_dict = eval(transcription)
+        else:
+            transcription_dict = transcription
+
+        # Validate the dictionary structure
+        if not isinstance(transcription_dict, dict):
+            raise ValueError("The transcription does not have the expected format")
+
+        if 'text' not in transcription_dict or 'chunks' not in transcription_dict:
+            raise ValueError("The transcription is missing the required fields (text and chunks)")
+
+        # Drop empty chunks and validate timestamps
+        cleaned_chunks = []
+        for chunk in transcription_dict['chunks']:
+            # Check that the chunk has text and valid timestamps
+            if (chunk.get('text') and
+                isinstance(chunk.get('timestamp'), (list, tuple)) and
+                len(chunk['timestamp']) == 2 and
+                chunk['timestamp'][0] is not None and
+                chunk['timestamp'][1] is not None):
+
+                cleaned_chunks.append({
+                    'start': float(chunk['timestamp'][0]),  # Convert to float
+                    'end': float(chunk['timestamp'][1]),    # Convert to float
+                    'text': chunk['text'].strip()
+                })
+
+        # Build the final cleaned dictionary
+        result = {
+            'text': transcription_dict['text'],
+            'chunks': cleaned_chunks
+        }
+
+        return result
+
+    except Exception as e:
+        print(f"Error processing the transcription: {e}")
+        return None
+
 def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
     """
     Transcribe audio file using Whisper model.
@@ -42,46 +101,103 @@ def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=
     filename_without_ext = os.path.splitext(audio_filename)[0]
     output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")
 
+    device = torch.device(device)
     torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
     # Load model and processor
-    model_id = "openai/whisper-large-v3-turbo"
-    model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        model_id,
-        torch_dtype=torch_dtype,
-        low_cpu_mem_usage=True,
-        use_safetensors=True
-    )
+    model_id = TRANSCRIPTOR
+    t0 = time()
+
+    # Configure Flash Attention 2 if it is available
+    print(f"Using Flash Attention 2: {is_flash_attn_2_available()}")
+    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
+        model_kwargs = {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+            **model_kwargs
+        )
+    else:
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            model_id,
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+        )
     model.to(device)
 
     processor = AutoProcessor.from_pretrained(model_id)
 
+    timestamp = True
+    if TRANSCRIPTOR == TRANSCRIPTOR_DISTIL_WHISPER:
+        timestamp = "word"
+    else:
+        timestamp = True
+
     # Create pipeline with timestamp generation
-    pipe = pipeline(
-        "automatic-speech-recognition",
-        model=model,
-        tokenizer=processor.tokenizer,
-        feature_extractor=processor.feature_extractor,
-        torch_dtype=torch_dtype,
-        device=device,
-        chunk_length_s=chunk_length_s,
-        stride_length_s=stride_length_s,
-        return_timestamps=True
-    )
+    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            torch_dtype=torch_dtype,
+            device=device,
+            chunk_length_s=chunk_length_s,
+            stride_length_s=stride_length_s,
+            return_timestamps=timestamp,
+            max_new_tokens=128,
+            batch_size=24,
+            model_kwargs=model_kwargs
+        )
+    else:
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            torch_dtype=torch_dtype,
+            device=device,
+            chunk_length_s=chunk_length_s,
+            stride_length_s=stride_length_s,
+            return_timestamps=timestamp,
+            max_new_tokens=128,
+        )
 
     # Transcribe with timestamps and generate attention mask
-    result = pipe(
-        audio_file,
-        return_timestamps=True,
-        generate_kwargs={
-            "language": language,
-            "task": "transcribe",
-            "use_cache": True,
-            "num_beams": 1
-        }
-    )
+    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
+        result = pipe(
+            audio_file,
+            return_timestamps=timestamp,
+            batch_size=24,
+            generate_kwargs={
+                "language": language,
+                "task": "transcribe",
+                "use_cache": True,
+                "num_beams": 1
+            }
+        )
+    else:
+        result = pipe(
+            audio_file,
+            return_timestamps=timestamp,
+            generate_kwargs={
+                "language": language,
+                "task": "transcribe",
+                "use_cache": True,
+                "num_beams": 1
+            }
+        )
+
+    t = time()
+    print(f"Time to transcribe: {t - t0:.2f} seconds")
+
+    transcription_str = result
+    transcription_dict = transcription_to_dict(transcription_str)
 
-    print(result)
+    return transcription_str, transcription_dict
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Transcribe audio files')
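
Note that transcribe() still builds an .srt path (output_file) that none of these hunks actually write; if SRT output is the eventual goal, one possible way to serialize the cleaned chunks is sketched below (an illustrative helper, not part of this commit):

    def chunks_to_srt(chunks):
        # Hypothetical helper: render transcription_to_dict() chunks as SRT text.
        def fmt(seconds):
            # SRT timestamps use the HH:MM:SS,mmm format
            h, rem = divmod(int(seconds), 3600)
            m, s = divmod(rem, 60)
            ms = int(round((seconds - int(seconds)) * 1000))
            return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
        blocks = []
        for i, chunk in enumerate(chunks, start=1):
            blocks.append(f"{i}\n{fmt(chunk['start'])} --> {fmt(chunk['end'])}\n{chunk['text']}\n")
        return "\n".join(blocks)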