# TinyOctopus / inference.py
# Source: Hugging Face repo (author: SaraAlthubaiti), commit 82b272f ("Create inference.py")
import torch
from transformers import WhisperFeatureExtractor
from models.tinyoctopus import TINYOCTOPUS
from utils import prepare_one_sample
# Load model and move it to the best available device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): `cfg` is never defined or imported anywhere in this file, so this
# line raises NameError at import time. Presumably a configuration object should
# be loaded (e.g. from a YAML/OmegaConf config) before this point — TODO confirm
# and add the missing config loading.
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates
# Load processor
# Feature extractor matching the distil-whisper/distil-large-v3 speech encoder.
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")
def transcribe(audio_path, task="dialect"):
    """Run TinyOctopus inference on a single audio file.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Task to perform. Options: "dialect", "asr", "translation".

    Returns:
        str: The generated text, or an ``"Error: ..."`` string if inference
            fails (callers depend on the string return, not an exception).

    Raises:
        ValueError: If ``task`` is not one of the supported options.
    """
    # The Arabic prompts are part of the model's trained input format — they are
    # runtime data, not comments, and must not be altered.
    task_prompts = {
        "dialect": "What is the dialect of the speaker?",
        "asr": "تعرف على الكلام وأعطني النص.",
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."
    }
    if task not in task_prompts:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")
    try:
        samples = prepare_one_sample(audio_path, wav_processor)
        # <Speech><SpeechHere></Speech> marks where the model splices in the
        # audio embeddings relative to the text prompt.
        prompt = [f"<Speech><SpeechHere></Speech> {task_prompts[task].strip()}"]
        # Inference only: no_grad avoids building the autograd graph
        # (model.eval() alone does not disable gradient tracking).
        with torch.no_grad():
            generated_text = model.generate(samples, {"temperature": 0.7}, prompts=prompt)[0]
        # Strip the tokenizer's BOS/EOS markers from the decoded text.
        return generated_text.replace('<s>', '').replace('</s>', '').strip()
    except Exception as e:
        # Best-effort API: surface failures as a string rather than crashing
        # the caller (matches the original contract).
        return f"Error: {e}"