"""TensorRT-LLM Whisper runner: log-Mel front end plus encoder and decoder
engine sessions for transcription."""

import json
import re
from collections import OrderedDict
from pathlib import Path
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F

from whisper.tokenizer import get_tokenizer
from whisper_live.whisper_utils import (load_audio, load_audio_wav_format,
                                        mel_filters, pad_or_trim)

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo

SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk


class WhisperEncoding:

    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)

    def get_session(self, engine_dir):
        config_path = engine_dir / 'encoder_config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)

        self.dtype = config['builder_config']['precision']
        self.n_mels = config['builder_config']['n_mels']
        self.num_languages = config['builder_config']['num_languages']

        serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine'
        with open(serialize_path, 'rb') as f:
            session = Session.from_serialized_engine(f.read())
        return session

    def get_audio_features(self, mel):
        # The encoder engine has a single input tensor named 'x' (the mel
        # spectrogram); infer_shapes returns the matching output tensor infos.
        inputs = OrderedDict()
        inputs.update({'x': mel})
        input_list = [TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape)]
        output_info = self.session.infer_shapes(input_list)
        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        return outputs['output']
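
# Shape note (a sketch derived from this module's own constants, not an extra
# API): get_audio_features expects `mel` on the GPU in the engine dtype
# (float16 here) with shape (batch, n_mels, 3000) — 30 s of 16 kHz audio at
# HOP_LENGTH=160 — and the Whisper encoder downsamples time by 2, so the
# returned features have 1500 frames:
#
#   enc = WhisperEncoding(Path("whisper_small_en"))
#   mel = torch.zeros(1, enc.n_mels, 3000, dtype=torch.float16, device="cuda")
#   features = enc.get_audio_features(mel)  # (1, 1500, hidden_size)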

class WhisperDecoding:

    def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
        self.decoder_config = self.get_config(engine_dir)
        self.decoder_generation_session = self.get_session(
            engine_dir, runtime_mapping, debug_mode)

    def get_config(self, engine_dir):
        config_path = engine_dir / 'decoder_config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)
        decoder_config = OrderedDict()
        decoder_config.update(config['plugin_config'])
        decoder_config.update(config['builder_config'])
        return decoder_config

    def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
        dtype = self.decoder_config['precision']
        serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine'
        with open(serialize_path, "rb") as f:
            decoder_engine_buffer = f.read()

        decoder_model_config = ModelConfig(
            num_heads=self.decoder_config['num_heads'],
            num_kv_heads=self.decoder_config['num_heads'],
            hidden_size=self.decoder_config['hidden_size'],
            vocab_size=self.decoder_config['vocab_size'],
            num_layers=self.decoder_config['num_layers'],
            gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'],
            remove_input_padding=self.decoder_config['remove_input_padding'],
            cross_attention=self.decoder_config['cross_attention'],
            has_position_embedding=self.decoder_config['has_position_embedding'],
            has_token_type_embedding=self.decoder_config['has_token_type_embedding'],
        )
        decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode)
        return decoder_generation_session

    def generate(self,
                 decoder_input_ids,
                 encoder_outputs,
                 eot_id,
                 max_new_tokens=40,
                 num_beams=1):
        # Every item in the batch shares the same encoder/decoder lengths.
        encoder_input_lengths = torch.tensor(
            [encoder_outputs.shape[1] for _ in range(encoder_outputs.shape[0])],
            dtype=torch.int32,
            device='cuda')
        decoder_input_lengths = torch.tensor(
            [decoder_input_ids.shape[-1] for _ in range(decoder_input_ids.shape[0])],
            dtype=torch.int32,
            device='cuda')
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        # generation config
        sampling_config = SamplingConfig(end_id=eot_id,
                                         pad_id=eot_id,
                                         num_beams=num_beams)

        self.decoder_generation_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            beam_width=num_beams,
            encoder_max_input_length=encoder_outputs.shape[1])
        torch.cuda.synchronize()

        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
        output_ids = self.decoder_generation_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_outputs,
            encoder_input_lengths=encoder_input_lengths,
        )
        torch.cuda.synchronize()

        # convert the output_ids tensor into nested Python lists
        return output_ids.cpu().numpy().tolist()


class WhisperTRTLLM:

    def __init__(self, engine_dir, debug_mode=False, assets_dir=None, device=None):
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)

        self.encoder = WhisperEncoding(engine_dir)
        self.decoder = WhisperDecoding(engine_dir,
                                       runtime_mapping,
                                       debug_mode=debug_mode)
        self.n_mels = self.encoder.n_mels
        self.device = device
        # English-only tokenizer; for a multilingual engine, pass True and
        # num_languages=self.encoder.num_languages instead.
        self.tokenizer = get_tokenizer(
            False,
            language="en",
            task="transcribe",
        )
        self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir)

    def log_mel_spectrogram(self,
                            audio: Union[str, np.ndarray, torch.Tensor],
                            padding: int = 0,
                            return_duration: bool = True):
        """
        Compute the log-Mel spectrogram of an audio waveform.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
            The path to an audio file, or a NumPy array / Tensor containing
            the 16 kHz audio waveform
        padding: int
            Number of zero samples to pad on the right
        return_duration: bool
            If True, also return the duration of the audio in seconds

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the Mel spectrogram (n_mels is 80 or 128,
            depending on the encoder engine)
        """
        if not torch.is_tensor(audio):
            if isinstance(audio, str):
                if audio.endswith('.wav'):
                    audio, _ = load_audio_wav_format(audio)
                else:
                    audio = load_audio(audio)
            assert isinstance(
                audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
            duration = audio.shape[-1] / SAMPLE_RATE
            audio = pad_or_trim(audio, N_SAMPLES)
            audio = audio.astype(np.float32)
            audio = torch.from_numpy(audio)
        else:
            duration = audio.shape[-1] / SAMPLE_RATE

        if self.device is not None:
            audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        window = torch.hann_window(N_FFT).to(audio.device)
        stft = torch.stft(audio,
                          N_FFT,
                          HOP_LENGTH,
                          window=window,
                          return_complex=True)
        magnitudes = stft[..., :-1].abs()**2

        mel_spec = self.filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        if return_duration:
            return log_spec, duration
        return log_spec
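
    # Usage sketch (illustrative file name): `duration` is the original audio
    # length in seconds, measured before pad_or_trim fixes the window at 30 s:
    #
    #   mel, duration = model.log_mel_spectrogram("audio.wav")
    #   mel.shape  # (n_mels, 3000)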
    def process_batch(
            self,
            mel,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            num_beams=1):
        prompt_id = self.tokenizer.encode(
            text_prefix,
            allowed_special=set(self.tokenizer.special_tokens.keys()))
        prompt_id = torch.tensor(prompt_id)
        batch_size = mel.shape[0]
        decoder_input_ids = prompt_id.repeat(batch_size, 1)

        encoder_output = self.encoder.get_audio_features(mel)
        output_ids = self.decoder.generate(decoder_input_ids,
                                           encoder_output,
                                           self.tokenizer.eot,
                                           max_new_tokens=96,
                                           num_beams=num_beams)
        # keep only the first beam for each item in the batch
        texts = []
        for i in range(len(output_ids)):
            text = self.tokenizer.decode(output_ids[i][0]).strip()
            texts.append(text)
        return texts

    def transcribe(
            self,
            mel,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            dtype='float16',
            batch_size=1,
            num_beams=1):
        mel = mel.type(str_dtype_to_torch(dtype))
        mel = mel.unsqueeze(0)
        predictions = self.process_batch(mel, text_prefix, num_beams)
        prediction = predictions[0]

        # remove all special tokens from the prediction
        prediction = re.sub(r'<\|.*?\|>', '', prediction)
        return prediction.strip()


def decode_wav_file(
        model,
        mel,
        text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        dtype='float16',
        batch_size=1,
        num_beams=1,
        normalizer=None,
        mel_filters_dir=None):
    mel = mel.type(str_dtype_to_torch(dtype))
    mel = mel.unsqueeze(0)
    # repeat the mel spectrogram to match the batch size
    mel = mel.repeat(batch_size, 1, 1)
    predictions = model.process_batch(mel, text_prefix, num_beams)
    prediction = predictions[0]

    # remove all special tokens from the prediction
    prediction = re.sub(r'<\|.*?\|>', '', prediction)
    if normalizer:
        prediction = normalizer(prediction)
    return prediction.strip()


if __name__ == "__main__":
    tensorrt_llm.logger.set_level("error")
    model = WhisperTRTLLM("/root/TensorRT-LLM/examples/whisper/whisper_small_en",
                          False,
                          "../assets",
                          device="cuda")
    mel, total_duration = model.log_mel_spectrogram(
        "../assets/1221-135766-0002.wav")
    results = model.transcribe(mel)
    print(results, total_duration)
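
# Sketch of decode_wav_file with text normalization (not a tested path): the
# EnglishTextNormalizer import comes from the openai-whisper package and is an
# assumption here, as are the engine/asset paths reused from __main__ above.
#
#   from whisper.normalizers import EnglishTextNormalizer
#   model = WhisperTRTLLM("/root/TensorRT-LLM/examples/whisper/whisper_small_en",
#                         False, "../assets", device="cuda")
#   mel, _ = model.log_mel_spectrogram("../assets/1221-135766-0002.wav")
#   text = decode_wav_file(model, mel, batch_size=2,
#                          normalizer=EnglishTextNormalizer())
#   print(text)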