from time import perf_counter
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
)
import torch

from baseHandler import BaseHandler
from rich.console import Console
import logging

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

console = Console()


class WhisperSTTHandler(BaseHandler):
    """
    Handles the Speech To Text generation using a Whisper model.

    `setup` loads the processor and model, optionally wraps the model's
    forward pass with `torch.compile`, and runs a warmup. `process`
    transcribes one spoken prompt and yields the decoded text.
    """

    def setup(
        self,
        model_name="distil-whisper/distil-large-v3",
        device="cuda",
        torch_dtype="float16",
        compile_mode=None,
        gen_kwargs=None,
    ):
        """Load the Whisper processor/model and warm the model up.

        Args:
            model_name: Model id or path passed to both
                `AutoProcessor.from_pretrained` and
                `AutoModelForSpeechSeq2Seq.from_pretrained`.
            device: Device the model and the input features are moved to.
            torch_dtype: Name of a `torch` dtype attribute (e.g. "float16"),
                resolved with `getattr(torch, ...)`. It is used for both the
                model weights and the input features.
            compile_mode: If truthy, `model.forward` is wrapped with
                `torch.compile(mode=compile_mode, fullgraph=True)` and the
                static cache implementation is selected. It also chooses the
                warmup strategy in `warmup`.
            gen_kwargs: Extra keyword arguments forwarded to `model.generate`.
                None (the default) means no extra arguments. A fresh dict is
                used per call, so no dict is shared between instances.
        """
        self.device = device
        self.torch_dtype = getattr(torch, torch_dtype)
        self.compile_mode = compile_mode
        # Avoid a shared mutable default; keep the caller's dict if one is given.
        self.gen_kwargs = {} if gen_kwargs is None else gen_kwargs

        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
        ).to(device)

        # compile
        if self.compile_mode:
            # The static cache is enabled together with a fullgraph compile.
            # Presumably this keeps cache shapes fixed across decoding steps.
            self.model.generation_config.cache_implementation = "static"
            self.model.forward = torch.compile(
                self.model.forward, mode=self.compile_mode, fullgraph=True
            )

        self.warmup()

    def prepare_model_inputs(self, spoken_prompt):
        """Convert raw audio into Whisper input features on the model's device/dtype.

        The audio is declared to the processor as 16 kHz. Callers are assumed
        to supply audio at that rate; this method does not resample.
        """
        input_features = self.processor(
            spoken_prompt, sampling_rate=16000, return_tensors="pt"
        ).input_features
        input_features = input_features.to(self.device, dtype=self.torch_dtype)

        return input_features

    def warmup(self):
        """Run dummy generations so one-time costs are paid before real requests.

        Uses one step for compile_mode "default" and two steps otherwise.
        When `self.device == "cuda"`, the elapsed warmup time is logged.
        """
        logger.info(f"Warming up {self.__class__.__name__}")

        # 2 warmup steps for no compile or compile mode with CUDA graphs capture
        n_steps = 1 if self.compile_mode == "default" else 2
        # Random features with shape (1, num_mel_bins, 3000). The 3000 frames
        # presumably match Whisper's fixed 30 s input window; confirm per model.
        dummy_input = torch.randn(
            (1, self.model.config.num_mel_bins, 3000),
            dtype=self.torch_dtype,
            device=self.device,
        )

        if self.compile_mode not in (None, "default"):
            # Generating more tokens than in any previous call triggers a new
            # CUDA graph capture. Warm up with the maximum token count targeted
            # by later calls: when max_new_tokens is set, force min_new_tokens
            # up to it so the dummy run cannot stop early on EOS. Missing keys
            # are tolerated instead of raising KeyError.
            warmup_gen_kwargs = dict(self.gen_kwargs)
            max_new_tokens = warmup_gen_kwargs.get("max_new_tokens")
            if max_new_tokens is not None:
                warmup_gen_kwargs["min_new_tokens"] = max_new_tokens
        else:
            warmup_gen_kwargs = self.gen_kwargs

        if self.device == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start_event.record()

        for _ in range(n_steps):
            _ = self.model.generate(dummy_input, **warmup_gen_kwargs)

        if self.device == "cuda":
            end_event.record()
            torch.cuda.synchronize()

            # elapsed_time() is in milliseconds; convert to seconds for the log.
            logger.info(
                f"{self.__class__.__name__}: warmed up! time: {start_event.elapsed_time(end_event) * 1e-3:.3f} s"
            )

    def process(self, spoken_prompt):
        """Transcribe `spoken_prompt` and yield the decoded text.

        Side effects: sets the module-level global `pipeline_start` to
        `perf_counter()` and prints the transcription to the rich console.

        Yields:
            The decoded text of the first (and only) sequence in the batch.
        """
        logger.debug("infering whisper...")

        # Module-level start timestamp for this pass. Presumably read by other
        # code to measure end-to-end latency; that reader is not visible here.
        global pipeline_start
        pipeline_start = perf_counter()

        input_features = self.prepare_model_inputs(spoken_prompt)
        pred_ids = self.model.generate(input_features, **self.gen_kwargs)
        pred_text = self.processor.batch_decode(
            pred_ids, skip_special_tokens=True, decode_with_timestamps=False
        )[0]

        logger.debug("finished whisper inference")
        console.print(f"[yellow]USER: {pred_text}")

        yield pred_text