#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This script uses the Whisper large-v3-turbo model from OpenAI for automatic speech recognition (ASR).
The model is finetuned for faster performance with a minor quality trade-off. It leverages the Hugging Face
Transformers library to load the model and processor, and performs transcription on an input audio file.
Whisper is a state-of-the-art model for ASR and speech translation, proposed in the paper "Robust Speech
Recognition via Large-Scale Weak Supervision" by Alec Radford et al. from OpenAI. Trained on over 5 million
hours of labeled data, Whisper demonstrates a strong ability to generalize to many datasets and domains in
a zero-shot setting.
The script performs the following steps:
1. Checks if a CUDA-enabled GPU is available and sets the appropriate device and data type.
2. Loads the Whisper large-v3-turbo model and processor from the Hugging Face Hub.
3. Initializes an ASR pipeline using the model and processor.
4. Defines a function `transcribe_audio` that takes an audio file path as input, performs transcription,
and outputs the result to the terminal and a text file.
5. The script expects an audio file path as a command-line argument and calls the `transcribe_audio` function.
Usage:
whisper <audio_file>
Dependencies:
- torch
- transformers
- datasets
- accelerate
Example:
whisper sample_audio.wav
"""
import os
import sys

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
# Use a CUDA GPU with float16 when available, otherwise fall back to CPU with float32.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3-turbo"

# Load the model and processor from the Hugging Face Hub.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

# Build the ASR pipeline from the loaded model and processor.
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
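
# Optional tuning, as a sketch rather than a requirement: for long recordings the
# pipeline can also be constructed with chunked inference and batching enabled,
# e.g. chunk_length_s=30 and batch_size=8 (verify the exact arguments against your
# installed transformers version):
#
#     pipe = pipeline(
#         "automatic-speech-recognition",
#         model=model,
#         tokenizer=processor.tokenizer,
#         feature_extractor=processor.feature_extractor,
#         chunk_length_s=30,  # split long audio into 30-second chunks
#         batch_size=8,       # decode several chunks in parallel
#         torch_dtype=torch_dtype,
#         device=device,
#     )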


def transcribe_audio(audio_path):
    """Transcribe the given audio file and write the text next to it."""
    # Perform transcription; the pipeline accepts a path to an audio file directly.
    result = pipe(audio_path)

    # Derive the output text file path from the input filename.
    base_filename = os.path.splitext(audio_path)[0]
    output_text_path = base_filename + ".txt"

    # Output the result to the terminal.
    print(result["text"])

    # Save the result to a text file.
    with open(output_text_path, "w", encoding="utf-8") as f:
        f.write(result["text"])
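
# Possible extension (a sketch, not part of the original script): the pipeline can
# return per-segment timestamps and be steered to a specific language via Whisper's
# generate_kwargs, roughly as follows (check the transformers docs for your version):
#
#     result = pipe(
#         audio_path,
#         return_timestamps=True,
#         generate_kwargs={"language": "english", "task": "transcribe"},
#     )
#     # result["chunks"] then holds per-segment timestamps and text.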


if __name__ == "__main__":
    # Expect exactly one argument: the path to the audio file to transcribe.
    if len(sys.argv) != 2:
        print("Usage: whisper <audio_file>")
        sys.exit(1)

    audio_file = sys.argv[1]
    transcribe_audio(audio_file)