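"""Transcribe audio files with Whisper or Distil-Whisper through the transformers ASR pipeline."""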
import os
import argparse
import ast

from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES

from tqdm import tqdm

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available

from time import time

TRANSCRIPTOR_WHISPER = "openai/whisper-large-v3-turbo"
TRANSCRIPTOR_DISTIL_WHISPER = "distil-whisper/distil-large-v3"
TRANSCRIPTOR = TRANSCRIPTOR_DISTIL_WHISPER


def get_language_dict():
    """Map each language name to the code used for transcription and for translation."""
    language_dict = {}

    for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
        # The transcriber only needs the base code (the part before the first underscore)
        lang_code = language_code.split('_')[0].lower()

        # Keep only languages that Whisper actually supports
        if lang_code in WHISPER_LANGUAGES:
            language_dict[language_name] = {
                "transcriber": lang_code,
                "translator": language_code
            }
    return language_dict


def transcription_to_dict(transcription):
    """
    Convert a transcription in string format into a structured dictionary.

    Args:
        transcription (str): String containing the transcription with timestamps

    Returns:
        dict: Dictionary with the full text and the chunks with their timestamps
    """
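    # For reference, the raw pipeline output this function expects looks roughly like
    # (shape inferred from the checks below; values are illustrative only):
    #   {"text": "hello world",
    #    "chunks": [{"timestamp": (0.0, 1.2), "text": "hello"},
    #               {"timestamp": (1.2, 2.4), "text": "world"}]}
    # and the cleaned result it returns looks like:
    #   {"text": "hello world",
    #    "chunks": [{"start": 0.0, "end": 1.2, "text": "hello"}, ...]}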
    try:
        if isinstance(transcription, str):
            # Parse the string representation safely instead of using eval()
            transcription_dict = ast.literal_eval(transcription)
        else:
            transcription_dict = transcription

        if not isinstance(transcription_dict, dict):
            raise ValueError("The transcription does not have the expected format")

        if 'text' not in transcription_dict or 'chunks' not in transcription_dict:
            raise ValueError("The transcription is missing the required fields (text and chunks)")

        # Keep only chunks that have text and a complete (start, end) timestamp pair
        cleaned_chunks = []
        for chunk in transcription_dict['chunks']:
            if (chunk.get('text') and
                    isinstance(chunk.get('timestamp'), (list, tuple)) and
                    len(chunk['timestamp']) == 2 and
                    chunk['timestamp'][0] is not None and
                    chunk['timestamp'][1] is not None):

                cleaned_chunks.append({
                    'start': float(chunk['timestamp'][0]),
                    'end': float(chunk['timestamp'][1]),
                    'text': chunk['text'].strip()
                })

        result = {
            'text': transcription_dict['text'],
            'chunks': cleaned_chunks
        }

        return result

    except Exception as e:
        print(f"Error processing the transcription: {e}")
        return None


def transcribe(audio_file, language, device, chunk_length_s=30, stride_length_s=5):
    """
    Transcribe an audio file using a Whisper model.

    Args:
        audio_file (str): Path to audio file
        language (str): Language code for transcription
        device (str): Device to use for inference ('cuda' or 'cpu')
        chunk_length_s (int): Length of audio chunks in seconds
        stride_length_s (int): Stride length between chunks in seconds

    Returns:
        tuple: The raw pipeline output and the cleaned dictionary produced by
        transcription_to_dict (or None if it could not be parsed).
    """
    output_folder = "transcriptions"
    os.makedirs(output_folder, exist_ok=True)

    audio_filename = os.path.basename(audio_file)
    filename_without_ext = os.path.splitext(audio_filename)[0]
    # Destination path for the .srt; this function computes it but does not write it
    output_file = os.path.join(output_folder, f"{filename_without_ext}.srt")
    device = torch.device(device)
    # Use half precision only when running on GPU; fall back to float32 on CPU
    torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
    model_id = TRANSCRIPTOR
    t0 = time()

    print(f"Using Flash Attention 2: {is_flash_attn_2_available()}")
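    # The full Whisper checkpoint is loaded with Flash Attention 2 when the package is
    # available (falling back to PyTorch SDPA); the Distil-Whisper checkpoint is loaded
    # with the default attention implementation.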
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        model_kwargs = {"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            **model_kwargs
        )
    else:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)

    # Distil-Whisper is asked for word-level timestamps; the full Whisper model
    # is asked for segment-level timestamps.
    if TRANSCRIPTOR == TRANSCRIPTOR_DISTIL_WHISPER:
        timestamp = "word"
    else:
        timestamp = True
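    # Both branches below build the same ASR pipeline; the full Whisper branch
    # additionally passes the attention-implementation kwargs and a batch size.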
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
            batch_size=24,
            model_kwargs=model_kwargs
        )
    else:
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,
            device=device,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            return_timestamps=timestamp,
            max_new_tokens=128,
        )
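    # Run generation with the language and task fixed and greedy decoding
    # (num_beams=1); the full Whisper branch also batches the audio chunks.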
    if TRANSCRIPTOR == TRANSCRIPTOR_WHISPER:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            batch_size=24,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    else:
        result = pipe(
            audio_file,
            return_timestamps=timestamp,
            generate_kwargs={
                "language": language,
                "task": "transcribe",
                "use_cache": True,
                "num_beams": 1
            }
        )
    t = time()
    print(f"Time to transcribe: {t - t0:.2f} seconds")

    # The pipeline already returns a dict; transcription_to_dict normalizes its chunks
    transcription_str = result
    transcription_dict = transcription_to_dict(transcription_str)

    return transcription_str, transcription_dict


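# Example call (hypothetical file path and settings), transcribing one chunk directly:
#   raw, clean = transcribe("chunks/interview_0.mp3", "en", "cuda")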
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Transcribe audio files')
    parser.add_argument('input_files', help='File containing the list of audio files to transcribe (one per line)')
    parser.add_argument('language', help='Language of the audio files')
    parser.add_argument('num_speakers', help='Number of speakers in the audio files')
    parser.add_argument('device', help='Device to use for PyTorch inference')
    args = parser.parse_args()

    chunks_folder = "chunks"

    with open(args.input_files, 'r') as f:
        inputs = f.read().splitlines()
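    # Each listed file is expected to have a matching .mp3 chunk in chunks_folder;
    # only its base name is used to locate it.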
    language_dict = get_language_dict()

    progress_bar = tqdm(total=len(inputs), desc="Transcribe audio files progress")
    for input_path in inputs:
        # Keep only the base name without extension and look up the chunk as .mp3
        input_name = os.path.splitext(os.path.basename(input_path))[0]
        extension = "mp3"
        file = f'{chunks_folder}/{input_name}.{extension}'
        # num_speakers is parsed above but not used by transcribe()
        transcribe(file, language_dict[args.language]["transcriber"], args.device)
        progress_bar.update(1)
    progress_bar.close()