# subtify/transcribe.py
import ast
import os
import argparse
from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
from tqdm import tqdm
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available
from time import time
TRANSCRIPTOR_WHISPER = "openai/whisper-large-v3-turbo" # Time to transcribe: 296.53 seconds ==> minutes: 4.94
TRANSCRIPTOR_DISTIL_WHISPER = "distil-whisper/distil-large-v3" # Time to transcribe: 242.82 seconds ==> minutes: 4.05
TRANSCRIPTOR = TRANSCRIPTOR_DISTIL_WHISPER
def get_language_dict():
language_dict = {}
# Iterate over the LANGUAGE_NAME_TO_CODE dictionary
for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # Extract the transcriber code (the part of the code before the underscore, lowercased)
lang_code = language_code.split('_')[0].lower()
# Check if the language code is present in WHISPER_LANGUAGES
if lang_code in WHISPER_LANGUAGES:
# Construct the entry for the resulting dictionary
language_dict[language_name] = {
"transcriber": lang_code,
"translator": language_code
}
return language_dict
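# Illustrative shape of one entry in the returned mapping (the actual names and codes come from lang_list):
#   {"<language name>": {"transcriber": "<code before the underscore>", "translator": "<full code from LANGUAGE_NAME_TO_CODE>"}}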
def transcription_to_dict(transcription):
"""
Convierte una transcripci贸n en formato string a un diccionario estructurado.
Args:
transcription (str): String que contiene la transcripci贸n con timestamps
Returns:
dict: Diccionario con el texto completo y los chunks con sus timestamps
"""
try:
        # If the input is a string, parse it into a dictionary
        if isinstance(transcription, str):
            # Parse the string as a Python literal (safer than eval)
            transcription_dict = ast.literal_eval(transcription)
else:
transcription_dict = transcription
        # Validate the dictionary structure
        if not isinstance(transcription_dict, dict):
            raise ValueError("The transcription does not have the expected format")
        if 'text' not in transcription_dict or 'chunks' not in transcription_dict:
            raise ValueError("The transcription is missing the required fields (text and chunks)")
        # Drop empty chunks and validate the timestamps
cleaned_chunks = []
for chunk in transcription_dict['chunks']:
            # Keep only chunks that have text and valid start/end timestamps
if (chunk.get('text') and
isinstance(chunk.get('timestamp'), (list, tuple)) and
len(chunk['timestamp']) == 2 and
chunk['timestamp'][0] is not None and
chunk['timestamp'][1] is not None):
cleaned_chunks.append({
                    'start': float(chunk['timestamp'][0]),  # Convert to float
                    'end': float(chunk['timestamp'][1]),    # Convert to float
'text': chunk['text'].strip()
})
        # Build the final, cleaned dictionary
result = {
'text': transcription_dict['text'],
'chunks': cleaned_chunks
}
return result
except Exception as e:
print(f"Error procesando la transcripci贸n: {e}")
return None
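# Minimal usage sketch for transcription_to_dict (the sample below is illustrative, not real Whisper output):
#   sample = {"text": " Hello world.", "chunks": [{"timestamp": (0.0, 1.5), "text": " Hello world."}]}
#   transcription_to_dict(sample)
#   -> {"text": " Hello world.", "chunks": [{"start": 0.0, "end": 1.5, "text": "Hello world."}]}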
def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
"""
Transcribe audio file using Whisper model.
Args:
audio_file (str): Path to audio file
language (str): Language code for transcription
device (str): Device to use for inference ('cuda' or 'cpu')
chunk_length_s (int): Length of audio chunks in seconds
stride_length_s (int): Stride length between chunks in seconds
"""
output_folder = "transcriptions"
if not os.path.exists(output_folder):
os.makedirs(output_folder)
# Get output filename
audio_filename = os.path.basename(audio_file)
filename_without_ext = os.path.splitext(audio_filename)[0]
output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")
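    # NOTE: output_file is prepared here but is not written by this function; it only returns the transcription.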
device = torch.device(device)
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
# Load model and processor
model_id = TRANSCRIPTOR
t0 = time()
    # Use Flash Attention 2 when it is available
print(f"Using Flash Attention 2: {is_flash_attn_2_available()}")
if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
model_kwargs = {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
**model_kwargs
)
else:
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
    # distil-whisper is asked for word-level timestamps; full Whisper returns segment-level timestamps
    if TRANSCRIPTOR == TRANSCRIPTOR_DISTIL_WHISPER:
        timestamp = "word"
    else:
        timestamp = True
# Create pipeline with timestamp generation
if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
chunk_length_s=chunk_length_s,
stride_length_s=stride_length_s,
return_timestamps=timestamp,
max_new_tokens=128,
batch_size=24,
model_kwargs=model_kwargs
)
else:
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
chunk_length_s=chunk_length_s,
stride_length_s=stride_length_s,
return_timestamps=timestamp,
max_new_tokens=128,
)
# Transcribe with timestamps and generate attention mask
if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
result = pipe(
audio_file,
return_timestamps=timestamp,
batch_size=24,
generate_kwargs={
"language": language,
"task": "transcribe",
"use_cache": True,
"num_beams": 1
}
)
else:
result = pipe(
audio_file,
return_timestamps=timestamp,
generate_kwargs={
"language": language,
"task": "transcribe",
"use_cache": True,
"num_beams": 1
}
)
t = time()
print(f"Time to transcribe: {t - t0:.2f} seconds")
transcription_str = result
transcription_dict = transcription_to_dict(transcription_str)
return transcription_str, transcription_dict
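

# The helper below is an illustrative sketch, not part of the original pipeline: it shows how the
# cleaned chunks returned by transcribe() could be serialized to the .srt path prepared above.
def chunks_to_srt(chunks):
    """Format cleaned chunks (as produced by transcription_to_dict) as SRT text."""
    def fmt(seconds):
        # Convert seconds to the SRT timestamp format HH:MM:SS,mmm
        ms = int(round(seconds * 1000))
        hours, ms = divmod(ms, 3_600_000)
        minutes, ms = divmod(ms, 60_000)
        secs, ms = divmod(ms, 1_000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"
    entries = []
    for index, chunk in enumerate(chunks, start=1):
        entries.append(f"{index}\n{fmt(chunk['start'])} --> {fmt(chunk['end'])}\n{chunk['text']}\n")
    return "\n".join(entries)

# Example call (hypothetical audio path; device must be 'cuda' or 'cpu'):
#   raw, parsed = transcribe("chunks/example.mp3", "es", "cuda")
#   if parsed is not None:
#       srt_text = chunks_to_srt(parsed["chunks"])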
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='Text file listing one input audio path per line')
parser.add_argument('language', help='Language of the audio file')
parser.add_argument('num_speakers', help='Number of speakers in the audio file')
parser.add_argument('device', help='Device to use for PyTorch inference')
args = parser.parse_args()
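    # Example invocation (hypothetical arguments):
    #   python transcribe.py audio_list.txt English 2 cuda
    # where audio_list.txt lists one source media path per line; the matching audio is read from chunks/<name>.mp3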
chunks_folder = "chunks"
with open(args.input_files, 'r') as f:
inputs = f.read().splitlines()
progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    language_dict = get_language_dict()
    for input_path in inputs:
        # Derive the base name from the listed path and read the matching audio chunk from chunks/
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
progress_bar.update(1)