import torchaudio
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Load processor and model
processor = AutoProcessor.from_pretrained("ixxan/whisper-small-ug-cv-15")
model = AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-ug-cv-15")
def transcribe(audio_path: str) -> str:
    """
    Transcribes audio to text using the Whisper model for Uyghur.

    Args:
        audio_path (str): Path to the audio file to transcribe.

    Returns:
        str: The transcription of the audio.
    """
    # Load the audio file (shape: [channels, samples])
    audio_input, sampling_rate = torchaudio.load(audio_path)

    # Resample to the rate the Whisper feature extractor expects (16 kHz)
    target_rate = processor.feature_extractor.sampling_rate
    if sampling_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
        audio_input = resampler(audio_input)

    # Preprocess the audio into log-Mel input features
    inputs = processor(audio_input.squeeze(), sampling_rate=target_rate, return_tensors="pt")

    # Move model and inputs to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate token IDs for the transcription
    with torch.no_grad():
        generated_ids = model.generate(inputs["input_features"], max_length=225)

    # Decode the generated token IDs into the transcription text
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription
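

# Example usage (not part of the original script): a minimal sketch showing how
# transcribe() might be called. The file name "example.wav" is hypothetical and
# assumes a local audio file exists at that path.
if __name__ == "__main__":
    sample_path = "example.wav"  # hypothetical audio file for demonstration
    print(transcribe(sample_path))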