# WhisperFusion / llm_service.py
# Last commit bd36f24 by makaveli10: "cleanup logs; send whisper and llm latency to client"
# (Scraped page metadata: raw | history blame | No virus | 13.4 kB)
import time
import json
from pathlib import Path
from typing import Optional
import logging
logging.basicConfig(level = logging.INFO)
import numpy as np
import torch
from transformers import AutoTokenizer
import re
import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
if PYTHON_BINDINGS:
from tensorrt_llm.runtime import ModelRunnerCpp
def read_model_name(engine_dir: str):
    """Return the model name recorded in ``<engine_dir>/config.json``.

    When ``get_engine_version`` reports no version (``None``), the name is
    read from ``builder_config.name``. Otherwise it is read from
    ``pretrained_config.architecture``.
    """
    version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir)
    config_path = Path(engine_dir) / "config.json"
    with open(config_path, 'r') as handle:
        engine_config = json.load(handle)
    section, key = (
        ('builder_config', 'name') if version is None
        else ('pretrained_config', 'architecture')
    )
    return engine_config[section][key]
def throttle_generator(generator, stream_interval):
    """Yield every ``stream_interval``-th item of ``generator``, plus the last one.

    Items at indices 0, k, 2k, ... are yielded (k = ``stream_interval``, which
    must be a positive int). The final item is always yielded as well, so
    consumers see the complete result, but it is never yielded twice.

    An empty ``generator`` yields nothing. Previously this raised
    ``UnboundLocalError`` because the loop index was never bound.
    """
    last_index = -1  # stays -1 when the generator is empty
    out = None
    for last_index, out in enumerate(generator):
        if not last_index % stream_interval:
            yield out
    # Emit the final item unless the loop above already yielded it.
    if last_index >= 0 and last_index % stream_interval:
        yield out
def load_tokenizer(tokenizer_dir: Optional[str] = None,
                   vocab_file: Optional[str] = None,
                   model_name: str = 'gpt',
                   tokenizer_type: Optional[str] = None):
    """Load a tokenizer and resolve the pad / end token ids for ``model_name``.

    Args:
        tokenizer_dir: Path passed to ``AutoTokenizer.from_pretrained``. For
            ``model_name == 'qwen'``, its ``generation_config.json`` is also read.
        vocab_file: If given, the tokenizer is built directly from this vocab
            file with ``T5Tokenizer`` (the gpt-next path). Only valid for
            ``model_name == 'gpt'``.
        model_name: Architecture name. Selects how pad/end ids are resolved.
        tokenizer_type: Forwarded to ``from_pretrained``. ``"llama"`` selects
            the slow (non-fast) tokenizer.

    Returns:
        ``(tokenizer, pad_id, end_id)``.

    Raises:
        AssertionError: ``vocab_file`` given with a ``model_name`` other than 'gpt'.
        ValueError: the Qwen ``chat_format`` is neither 'raw' nor 'chatml'.
    """
    if vocab_file is None:
        # llama-type tokenizers are loaded with the slow implementation.
        use_fast = tokenizer_type != "llama"
        # Should set both padding_side and truncation_side to be 'left'.
        # NOTE(review): trust_remote_code=True executes Python shipped with the
        # tokenizer files; only point tokenizer_dir at trusted sources.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
                                                  legacy=False,
                                                  padding_side='left',
                                                  truncation_side='left',
                                                  trust_remote_code=True,
                                                  tokenizer_type=tokenizer_type,
                                                  use_fast=use_fast)
    else:
        # For gpt-next, directly load from tokenizer.model
        assert model_name == 'gpt'
        # T5Tokenizer is not imported at module level; import it here so this
        # branch works instead of raising NameError.
        from transformers import T5Tokenizer
        tokenizer = T5Tokenizer(vocab_file=vocab_file,
                                padding_side='left',
                                truncation_side='left')

    if model_name == 'qwen':
        config_path = Path(tokenizer_dir) / "generation_config.json"
        with open(config_path, encoding="utf-8") as f:
            gen_config = json.load(f)
        chat_format = gen_config['chat_format']
        if chat_format == 'raw':
            pad_id = gen_config['pad_token_id']
            end_id = gen_config['eos_token_id']
        elif chat_format == 'chatml':
            pad_id = tokenizer.im_end_id
            end_id = tokenizer.im_end_id
        else:
            raise ValueError(f"unknown chat format: {chat_format}")
    elif model_name == 'glm_10b':
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eop_token_id
    else:
        # Tokenizers without a pad token reuse EOS for padding.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eos_token_id

    return tokenizer, pad_id, end_id
class TensorRTLLMEngine:
    """Chat-reply generator backed by a TensorRT-LLM engine.

    ``run`` loads the engine, then loops forever. On each pass it:

    * takes the newest transcription from ``transcription_queue``;
    * builds a ChatML prompt from that client's conversation history;
    * generates a reply;
    * puts the reply on ``llm_queue`` and ``audio_queue``.
    """

    def __init__(self):
        """Create an empty engine; model state is loaded by ``initialize_model``."""

    def initialize_model(self, engine_dir, tokenizer_dir):
        """Load the tokenizer and the TensorRT-LLM runner.

        Args:
            engine_dir: Engine directory (its ``config.json`` names the model).
            tokenizer_dir: Directory handed to ``load_tokenizer``.
        """
        self.log_level = 'error'
        self.runtime_rank = tensorrt_llm.mpi_rank()
        logger.set_level(self.log_level)

        model_name = read_model_name(engine_dir)
        self.tokenizer, self.pad_id, self.end_id = load_tokenizer(
            tokenizer_dir=tokenizer_dir,
            vocab_file=None,
            model_name=model_name,
            tokenizer_type=None,
        )
        self.prompt_template = None
        # Always the Python ModelRunner; ModelRunnerCpp is imported but unused here.
        self.runner_cls = ModelRunner
        self.runner_kwargs = dict(engine_dir=engine_dir,
                                  lora_dir=None,
                                  rank=self.runtime_rank,
                                  debug_mode=False,
                                  lora_ckpt_source='hf')
        self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
        # Most recent answered prompt and its reply. A repeat of the same prompt
        # with eos=True is answered from this cache without re-running inference.
        self.last_prompt = None
        self.last_output = None
        self.infer_time = None

    def parse_input(
        self,
        input_text=None,
        add_special_tokens=True,
        max_input_length=923,
        pad_id=None,
    ):
        """Tokenize each text into an int32 tensor of ids.

        Args:
            input_text: Iterable of strings (one per batch entry).
            add_special_tokens: Forwarded to ``tokenizer.encode``.
            max_input_length: Inputs are truncated to this many tokens.
                Truncation keeps the end of the text when the tokenizer
                truncates on the left.
            pad_id: Ignored; kept for interface compatibility.

        Returns:
            List of 1-D ``torch.int32`` tensors, one per input text.
        """
        if self.pad_id is None:
            self.pad_id = self.tokenizer.pad_token_id

        batch_input_ids = []
        for curr_text in input_text:
            if self.prompt_template is not None:
                curr_text = self.prompt_template.format(input_text=curr_text)
            input_ids = self.tokenizer.encode(
                curr_text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=max_input_length,
            )
            batch_input_ids.append(torch.tensor(input_ids, dtype=torch.int32))
        return batch_input_ids

    def decode_tokens(
        self,
        output_ids,
        input_lengths,
        sequence_lengths,
        transcription_queue
    ):
        """Decode the generated (non-prompt) tokens of every beam.

        Args:
            output_ids: Tensor whose ``.size()`` unpacks as
                (batch, beams, tokens).
            input_lengths: Prompt length per batch entry; generated tokens
                start at this index.
            sequence_lengths: Indexed ``[batch][beam]``; end index of each
                sequence.
            transcription_queue: Polled before each item. A non-empty queue
                means a newer transcription is waiting, so decoding is
                abandoned.

        Returns:
            One decoded string per beam for the LAST batch entry (``run``
            always passes a batch of one), or ``None`` if interrupted.
        """
        batch_size, num_beams, _ = output_ids.size()
        output = []
        for batch_idx in range(batch_size):
            if transcription_queue.qsize() != 0:
                return None
            output_begin = input_lengths[batch_idx]
            output = []
            for beam in range(num_beams):
                if transcription_queue.qsize() != 0:
                    return None
                output_end = sequence_lengths[batch_idx][beam]
                generated = output_ids[batch_idx][beam][output_begin:output_end].tolist()
                output.append(self.tokenizer.decode(generated))
        return output

    def format_prompt_qa(self, prompt, conversation_history):
        """Build an "Instruct:/Output:" prompt from (user, reply) history pairs."""
        turns = [
            f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
            for user_prompt, llm_response in conversation_history
        ]
        return "".join(turns) + f"Instruct: {prompt}\nOutput:"

    def format_prompt_chat(self, prompt, conversation_history):
        """Build an "Alice:/Bob:" chat prompt from (user, reply) history pairs."""
        turns = [
            f"Alice: {user_prompt}\nBob:{llm_response}\n"
            for user_prompt, llm_response in conversation_history
        ]
        return "".join(turns) + f"Alice: {prompt}\nBob:"

    def format_prompt_chatml(self, prompt, conversation_history, system_prompt=""):
        """Build a ChatML prompt: system turn, history turns, then the new user turn.

        NOTE(review): the prompt ends after the user turn and does not open an
        ``<|im_start|>assistant`` turn; confirm this is intended for the model.
        """
        parts = ["<|im_start|>system\n" + system_prompt + "<|im_end|>\n"]
        for user_prompt, llm_response in conversation_history:
            parts.append(f"<|im_start|>user\n{user_prompt}<|im_end|>\n")
            parts.append(f"<|im_start|>assistant\n{llm_response}<|im_end|>\n")
        parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n")
        return "".join(parts)

    def _publish(self, uid, output, llm_queue, audio_queue):
        """Put ``output`` (with eos flag and latency) on the client and audio queues."""
        llm_queue.put({
            "uid": uid,
            "llm_output": output,
            "eos": self.eos,
            "latency": self.infer_time
        })
        audio_queue.put({"llm_output": output, "eos": self.eos})

    def _generate(self, batch_input_ids, input_lengths, transcription_queue,
                  max_output_len, max_attention_window_size, num_beams,
                  streaming, streaming_interval):
        """Run the engine on ``batch_input_ids`` and decode the reply.

        Returns:
            The decoded beams (see ``decode_tokens``), or ``None`` when a
            newer transcription arrived while decoding, or when a stream
            produced no steps.
        """
        with torch.no_grad():
            outputs = self.runner.generate(
                batch_input_ids,
                max_new_tokens=max_output_len,
                max_attention_window_size=max_attention_window_size,
                end_id=self.end_id,
                pad_id=self.pad_id,
                temperature=1.0,
                top_k=1,
                top_p=0.0,
                num_beams=num_beams,
                length_penalty=1.0,
                repetition_penalty=1.0,
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=streaming,
                output_sequence_lengths=True,
                return_dict=True)
            torch.cuda.synchronize()

        if not streaming:
            return self.decode_tokens(
                outputs['output_ids'],
                input_lengths,
                outputs['sequence_lengths'],
                transcription_queue
            )

        # Start from None so an empty stream cannot reuse a previous reply.
        output = None
        for curr_outputs in throttle_generator(outputs, streaming_interval):
            output = self.decode_tokens(
                curr_outputs['output_ids'],
                input_lengths,
                curr_outputs['sequence_lengths'],
                transcription_queue
            )
            if output is None:
                # Interrupted by transcription queue
                return None
        return output

    def run(
        self,
        model_path,
        tokenizer_path,
        transcription_queue=None,
        llm_queue=None,
        audio_queue=None,
        input_text=None,
        max_output_len=50,
        max_attention_window_size=4096,
        num_beams=1,
        streaming=False,
        streaming_interval=4,
        debug=False,
    ):
        """Load the engine, then answer transcriptions forever (never returns).

        Input:
            Items read from ``transcription_queue`` are dicts with ``"uid"``,
            ``"prompt"`` and ``"eos"`` keys. When several items are pending,
            only the newest is processed.

        Output, per reply:
            * ``llm_queue`` receives
              ``{"uid", "llm_output", "eos", "latency"}``.
            * ``audio_queue`` receives ``{"llm_output", "eos"}``.
            * ``latency`` is the inference wall time in seconds.

        End of utterance:
            On ``eos=True``, the (prompt, reply) pair is added to that uid's
            history and the repeat-prompt cache is cleared.

        Args:
            model_path: TensorRT-LLM engine directory.
            tokenizer_path: Tokenizer directory.
            transcription_queue, llm_queue, audio_queue: Queue-like objects
                (``get``/``put``/``qsize``); required despite the ``None``
                defaults.
            input_text: Unused; overwritten per request. Kept for
                compatibility.
            max_output_len: Maximum number of new tokens per reply.
            max_attention_window_size: Forwarded to ``runner.generate``.
            num_beams: Beam count. ``output[0]`` is the beam that is cleaned,
                logged and stored in history.
            streaming: Generate incrementally, polling for newer
                transcriptions.
            streaming_interval: Decode every N-th streamed step
                (plus the last).
            debug: Unused.
        """
        self.initialize_model(
            model_path,
            tokenizer_path,
        )
        logging.info("[LLM INFO:] Loaded LLM TensorRT Engine.")

        conversation_history = {}

        while True:
            # Get the last transcription output from the queue; skip stale ones.
            transcription_output = transcription_queue.get()
            if transcription_queue.qsize() != 0:
                continue

            uid = transcription_output["uid"]
            history = conversation_history.setdefault(uid, [])
            prompt = transcription_output['prompt'].strip()

            # if prompt is same but EOS is True, we need that to send outputs to websockets
            if self.last_prompt == prompt:
                if self.last_output is not None and transcription_output["eos"]:
                    self.eos = transcription_output["eos"]
                    self._publish(uid, self.last_output, llm_queue, audio_queue)
                    history.append((prompt, self.last_output[0].strip()))
                    continue

            # format_prompt_qa / format_prompt_chat are alternative prompt styles.
            input_text = [self.format_prompt_chatml(
                prompt, history,
                system_prompt="You are Dolphin, a helpful AI assistant")]
            self.eos = transcription_output["eos"]

            batch_input_ids = self.parse_input(
                input_text=input_text,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None,
            )
            input_lengths = [x.size(0) for x in batch_input_ids]

            logging.info(f"[LLM INFO:] Running LLM Inference with WhisperLive prompt: {prompt}, eos: {self.eos}")
            start = time.perf_counter()
            output = self._generate(
                batch_input_ids, input_lengths, transcription_queue,
                max_output_len, max_attention_window_size, num_beams,
                streaming, streaming_interval,
            )
            elapsed = time.perf_counter() - start

            # None: a newer transcription is waiting, or the stream was empty.
            # Skip this item. The old code crashed here on eos=True.
            if output is None:
                continue

            self.infer_time = elapsed  # seconds
            output[0] = clean_llm_output(output[0])
            self.last_output = output
            self.last_prompt = prompt
            self._publish(uid, output, llm_queue, audio_queue)
            logging.info(f"[LLM INFO:] Output: {output[0]}\nLLM inference done in {self.infer_time * 1000:.2f} ms\n\n")

            if self.eos:
                history.append((prompt, output[0].strip()))
                self.last_prompt = None
                self.last_output = None
def clean_llm_output(output):
    """Tidy a raw model reply before it is sent to clients.

    Steps:
    1. Remove the role markers "Dolphin"/"Assistant" in the exact forms
       listed below.
    2. If the text does not already end in '.', '?' or '!', cut it back to
       the last such character, which drops a trailing partial sentence.
       The cut only happens when that character is not the very first one.
    """
    for marker in ("\n\nDolphin\n\n", "\nDolphin\n\n", "Dolphin: ", "Assistant: "):
        output = output.replace(marker, "")

    if output.endswith(('.', '?', '!')):
        return output

    cut = max(output.rfind(mark) for mark in '.?!')
    return output[:cut + 1] if cut > 0 else output