# NOTE(review): the three lines below ("Spaces:" / "Sleeping" / "Sleeping") appear to be
# stray console/pipeline-state output accidentally pasted into the file head — confirm and remove.
import logging | |
from time import perf_counter | |
from baseHandler import BaseHandler | |
from lightning_whisper_mlx import LightningWhisperMLX | |
import numpy as np | |
from rich.console import Console | |
import torch | |
# Module-level logging/console setup shared by the handler below.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Rich console used for user-facing output (the transcribed prompt in process()).
console = Console()
class LightningWhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.

    Backed by LightningWhisperMLX (Apple-MLX accelerated Whisper); transcribes
    a spoken prompt to text and yields the transcription downstream.
    """

    def setup(
        self,
        model_name="distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs=None,
    ):
        """
        Load the Whisper MLX model and run a warmup pass.

        Args:
            model_name: Whisper checkpoint name; a leading org prefix
                (e.g. "mlx-community/distil-large-v3") is stripped, since
                LightningWhisperMLX expects the bare model name.
            device: stored on the handler; not used by the MLX backend here.
            torch_dtype: unused — kept for interface compatibility with the
                other STT handlers in the pipeline.
            compile_mode: unused — kept for interface compatibility.
            gen_kwargs: unused — kept for interface compatibility. Defaults to
                None instead of a mutable {} to avoid the shared-default pitfall.
        """
        # Strip any "org/" prefix so only the bare checkpoint name is passed on.
        if len(model_name.split("/")) > 1:
            model_name = model_name.split("/")[-1]
        self.device = device
        self.model = LightningWhisperMLX(model=model_name, batch_size=6, quant=None)
        self.warmup()

    def warmup(self):
        """Run one dummy transcription so later calls are not hit by lazy-init cost."""
        logger.info(f"Warming up {self.__class__.__name__}")
        # A single warmup transcription on a short silent buffer.
        n_steps = 1
        dummy_input = np.array([0] * 512)
        for _ in range(n_steps):
            _ = self.model.transcribe(dummy_input)["text"].strip()

    def process(self, spoken_prompt):
        """
        Transcribe `spoken_prompt` (an audio array) and yield the text.

        Sets the module-global `pipeline_start` timestamp, read elsewhere in
        the pipeline to measure end-to-end latency.
        """
        logger.debug("infering whisper...")

        global pipeline_start
        pipeline_start = perf_counter()

        pred_text = self.model.transcribe(spoken_prompt)["text"].strip()
        # Free cached MPS memory after inference (no-op concern: assumes an
        # Apple-silicon/MPS build of torch — TODO confirm for other devices).
        torch.mps.empty_cache()

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text