File size: 13,431 Bytes
bd36f24
201054b
 
 
bd36f24
16388cf
 
 
201054b
 
 
28cd485
201054b
 
 
 
 
 
 
 
 
 
855f306
201054b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3f7cd8
201054b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7cc82ad
16388cf
 
201054b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855f306
201054b
 
 
 
 
 
 
 
9dcd6a2
201054b
 
 
9dcd6a2
 
 
 
201054b
 
 
9dcd6a2
 
 
201054b
 
 
 
 
 
 
 
e3f7cd8
 
 
 
 
ea46c22
e3f7cd8
 
 
 
 
ea46c22
28cd485
 
 
 
 
 
 
 
7cc82ad
 
f1e930a
 
7cc82ad
 
ea46c22
7cc82ad
28cd485
201054b
 
16388cf
201054b
7cc82ad
201054b
7cc82ad
f1e930a
 
201054b
f1e930a
bd36f24
e3f7cd8
 
 
7cc82ad
9dcd6a2
 
7cc82ad
9dcd6a2
 
e3f7cd8
 
 
9dcd6a2
ea46c22
e3f7cd8
3a0cae9
16388cf
3a0cae9
 
bd36f24
 
 
 
 
 
16388cf
e3f7cd8
5d9af51
e3f7cd8
16388cf
e3f7cd8
28cd485
 
3a0cae9
 
16388cf
7cc82ad
 
 
 
 
 
201054b
855f306
bd36f24
 
 
7cc82ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9dcd6a2
 
7cc82ad
9dcd6a2
 
 
e3f7cd8
9dcd6a2
 
 
7cc82ad
 
 
 
 
16388cf
7cc82ad
16388cf
7cc82ad
201054b
 
 
7cc82ad
9dcd6a2
201054b
bd36f24
16388cf
 
 
28cd485
16388cf
 
bd36f24
 
 
 
 
 
16388cf
bd36f24
16388cf
 
e3f7cd8
 
 
16388cf
3a0cae9
201054b
28cd485
 
 
bd36f24
 
28cd485
 
 
 
 
 
 
 
 
 
 
 
 
201054b
28cd485
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import time
import json
from pathlib import Path
from typing import Optional

import logging
logging.basicConfig(level = logging.INFO)

import numpy as np
import torch
from transformers import AutoTokenizer
import re

import tensorrt_llm
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner

if PYTHON_BINDINGS:
    from tensorrt_llm.runtime import ModelRunnerCpp


def read_model_name(engine_dir: str):
    """Return the model name recorded in an engine directory's ``config.json``.

    Args:
        engine_dir: Directory holding the built engine and its ``config.json``.

    Returns:
        ``config['builder_config']['name']`` when
        ``tensorrt_llm.runtime.engine.get_engine_version`` reports ``None``
        for the directory, otherwise
        ``config['pretrained_config']['architecture']``.
    """
    version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir)

    config_path = Path(engine_dir) / "config.json"
    with open(config_path, 'r') as handle:
        engine_config = json.load(handle)

    # A versioned engine keeps the name under 'pretrained_config'; an
    # unversioned one (version is None) keeps it under 'builder_config'.
    if version is not None:
        return engine_config['pretrained_config']['architecture']
    return engine_config['builder_config']['name']


def throttle_generator(generator, stream_interval):
    """Yield every ``stream_interval``-th item of ``generator`` plus the last.

    Items at indices 0, ``stream_interval``, ``2 * stream_interval``, ... are
    forwarded as they arrive. Once the source is exhausted, the final item is
    forwarded as well unless it was already yielded by the interval rule.
    An empty source yields nothing.

    Args:
        generator: Any iterable of items to thin out.
        stream_interval: Positive integer step between forwarded items.

    Yields:
        The selected items, in their original order.
    """
    # Start below zero so an empty source can be told apart from "one item".
    last_index = -1
    last_item = None
    for last_index, last_item in enumerate(generator):
        if last_index % stream_interval == 0:
            yield last_item

    # Flush the final item when the interval rule did not already emit it.
    if last_index >= 0 and last_index % stream_interval:
        yield last_item


def load_tokenizer(tokenizer_dir: Optional[str] = None,
                   vocab_file: Optional[str] = None,
                   model_name: str = 'gpt',
                   tokenizer_type: Optional[str] = None):
    """Load a tokenizer and resolve the pad / end token ids for a model.

    Args:
        tokenizer_dir: Directory passed to ``AutoTokenizer.from_pretrained``
            when ``vocab_file`` is not given. For ``model_name == 'qwen'`` it
            must also contain ``generation_config.json``.
        vocab_file: Optional vocabulary file. When given, a ``T5Tokenizer`` is
            built from it instead and ``model_name`` must be ``'gpt'``.
        model_name: Selects how the pad / end ids are derived
            (``'qwen'``, ``'glm_10b'``, or anything else for the default).
        tokenizer_type: Forwarded to ``AutoTokenizer.from_pretrained``; the
            value ``"llama"`` disables the fast tokenizer.

    Returns:
        A ``(tokenizer, pad_id, end_id)`` tuple.

    Raises:
        AssertionError: ``vocab_file`` is given but ``model_name != 'gpt'``.
        Exception: the qwen ``chat_format`` is neither ``'raw'`` nor
            ``'chatml'``.
    """
    if vocab_file is None:
        use_fast = tokenizer_type != "llama"
        # Should set both padding_side and truncation_side to be 'left'
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
                                                  legacy=False,
                                                  padding_side='left',
                                                  truncation_side='left',
                                                  trust_remote_code=True,
                                                  tokenizer_type=tokenizer_type,
                                                  use_fast=use_fast)
    else:
        # For gpt-next, directly load from tokenizer.model
        assert model_name == 'gpt'
        # Imported here because the module-level imports only bring in
        # AutoTokenizer; without this the branch fails with NameError.
        from transformers import T5Tokenizer
        tokenizer = T5Tokenizer(vocab_file=vocab_file,
                                padding_side='left',
                                truncation_side='left')

    if model_name == 'qwen':
        with open(Path(tokenizer_dir) / "generation_config.json") as f:
            gen_config = json.load(f)
        chat_format = gen_config['chat_format']
        if chat_format == 'raw':
            pad_id = gen_config['pad_token_id']
            end_id = gen_config['eos_token_id']
        elif chat_format == 'chatml':
            # chatml uses the tokenizer's im_end id for both pad and end.
            pad_id = tokenizer.im_end_id
            end_id = tokenizer.im_end_id
        else:
            raise Exception(f"unknown chat format: {chat_format}")
    elif model_name == 'glm_10b':
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eop_token_id
    else:
        # Fall back to the EOS id when the tokenizer defines no pad token.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        pad_id = tokenizer.pad_token_id
        end_id = tokenizer.eos_token_id

    return tokenizer, pad_id, end_id


class TensorRTLLMEngine:
    """Serve LLM completions for prompts arriving on a transcription queue.

    ``run`` loads the engine and tokenizer, then loops forever: it takes the
    newest transcription message, builds a ChatML prompt from the per-``uid``
    conversation history, runs generation, and publishes the decoded text to
    the LLM and audio queues.
    """

    def __init__(self):
        pass

    def initialize_model(self, engine_dir, tokenizer_dir):
        """Load the tokenizer and the TensorRT-LLM runner.

        Args:
            engine_dir: Directory of the built engine (also read for the
                model name).
            tokenizer_dir: Directory handed to ``load_tokenizer``.

        Sets ``tokenizer``, ``pad_id``, ``end_id``, ``runner`` and resets the
        ``last_prompt`` / ``last_output`` cache on the instance.
        """
        self.log_level = 'error'
        self.runtime_rank = tensorrt_llm.mpi_rank()
        logger.set_level(self.log_level)
        model_name = read_model_name(engine_dir)
        self.tokenizer, self.pad_id, self.end_id = load_tokenizer(
            tokenizer_dir=tokenizer_dir,
            vocab_file=None,
            model_name=model_name,
            tokenizer_type=None,
        )
        self.prompt_template = None
        self.runner_cls = ModelRunner
        self.runner_kwargs = dict(engine_dir=engine_dir,
                                  lora_dir=None,
                                  rank=self.runtime_rank,
                                  debug_mode=False,
                                  lora_ckpt_source='hf')
        self.runner = self.runner_cls.from_dir(**self.runner_kwargs)
        # Cache of the most recent prompt/answer, used by run() to avoid
        # re-running inference when the same prompt is re-sent with eos set.
        self.last_prompt = None
        self.last_output = None

    def parse_input(
        self,
        input_text=None,
        add_special_tokens=True,
        max_input_length=923,
        pad_id=None,
    ):
        """Tokenize a batch of prompts into int32 tensors.

        Args:
            input_text: Iterable of prompt strings.
            add_special_tokens: Forwarded to ``tokenizer.encode``.
            max_input_length: Truncation length forwarded as ``max_length``.
            pad_id: Accepted for interface compatibility; not used here
                (``self.pad_id`` is what gets filled in).

        Returns:
            A list with one 1-D ``torch.int32`` tensor per input string.
        """
        if self.pad_id is None:
            self.pad_id = self.tokenizer.pad_token_id

        batch_input_ids = []
        for curr_text in input_text:
            if self.prompt_template is not None:
                curr_text = self.prompt_template.format(input_text=curr_text)
            input_ids = self.tokenizer.encode(
                curr_text,
                add_special_tokens=add_special_tokens,
                truncation=True,
                max_length=max_input_length
            )
            batch_input_ids.append(input_ids)

        return [torch.tensor(x, dtype=torch.int32) for x in batch_input_ids]

    def decode_tokens(
        self,
        output_ids,
        input_lengths,
        sequence_lengths,
        transcription_queue
    ):
        """Decode the generated part of each beam back into text.

        Args:
            output_ids: Tensor whose ``size()`` is
                ``(batch_size, num_beams, length)``.
            input_lengths: Per-batch prompt lengths; tokens before this
                offset are the prompt and are skipped.
            sequence_lengths: Indexed as ``[batch][beam]``; end offset of the
                generated sequence.
            transcription_queue: Queue polled via ``qsize()``; decoding is
                abandoned as soon as it is non-empty.

        Returns:
            ``None`` if the transcription queue became non-empty, otherwise a
            list of decoded strings, one per beam, for the last batch entry
            (an empty list when the batch is empty).
        """
        batch_size, num_beams, _ = output_ids.size()
        # Defined up front so an empty batch returns [] instead of raising.
        output = []
        for batch_idx in range(batch_size):
            if transcription_queue.qsize() != 0:
                return None

            # Each batch entry starts a fresh list, so the value returned
            # after the loop belongs to the last batch entry.
            output = []
            output_begin = input_lengths[batch_idx]
            for beam in range(num_beams):
                if transcription_queue.qsize() != 0:
                    return None

                output_end = sequence_lengths[batch_idx][beam]
                outputs = output_ids[batch_idx][beam][
                    output_begin:output_end].tolist()
                output.append(self.tokenizer.decode(outputs))
        return output

    def format_prompt_qa(self, prompt, conversation_history):
        """Build an ``Instruct:`` / ``Output:`` prompt from history and prompt."""
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Instruct: {user_prompt}\nOutput:{llm_response}\n"
        return f"{formatted_prompt}Instruct: {prompt}\nOutput:"

    def format_prompt_chat(self, prompt, conversation_history):
        """Build an ``Alice:`` / ``Bob:`` dialogue prompt from history and prompt."""
        formatted_prompt = ""
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"Alice: {user_prompt}\nBob:{llm_response}\n"
        return f"{formatted_prompt}Alice: {prompt}\nBob:"

    def format_prompt_chatml(self, prompt, conversation_history, system_prompt=""):
        """Build a ChatML prompt: system turn, past turns, then the new user turn."""
        formatted_prompt = ("<|im_start|>system\n" + system_prompt + "<|im_end|>\n")
        for user_prompt, llm_response in conversation_history:
            formatted_prompt += f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            formatted_prompt += f"<|im_start|>assistant\n{llm_response}<|im_end|>\n"
        formatted_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n"
        return formatted_prompt

    def run(
        self,
        model_path,
        tokenizer_path,
        transcription_queue=None,
        llm_queue=None,
        audio_queue=None,
        input_text=None,
        max_output_len=50,
        max_attention_window_size=4096,
        num_beams=1,
        streaming=False,
        streaming_interval=4,
        debug=False,
    ):
        """Load the model and serve transcription prompts forever.

        Each message taken from ``transcription_queue`` is read for the keys
        ``"uid"``, ``"prompt"`` and ``"eos"``. Results are published as
        ``{"uid", "llm_output", "eos", "latency"}`` on ``llm_queue`` and
        ``{"llm_output", "eos"}`` on ``audio_queue``. This method does not
        return.

        Args:
            model_path: Engine directory passed to ``initialize_model``.
            tokenizer_path: Tokenizer directory passed to ``initialize_model``.
            transcription_queue: Source queue of transcription messages.
            llm_queue: Destination queue for LLM results.
            audio_queue: Destination queue for the audio stage.
            input_text: Unused; overwritten with the formatted prompt.
            max_output_len: Forwarded as ``max_new_tokens``.
            max_attention_window_size: Forwarded to ``runner.generate``.
            num_beams: Forwarded to ``runner.generate``.
            streaming: Whether ``runner.generate`` streams partial outputs.
            streaming_interval: Throttle step for streamed outputs.
            debug: Unused.
        """
        self.initialize_model(
            model_path,
            tokenizer_path,
        )

        logging.info("[LLM INFO:] Loaded LLM TensorRT Engine.")

        conversation_history = {}

        while True:

            # Get the last transcription output from the queue; anything
            # older than the newest pending message is skipped.
            transcription_output = transcription_queue.get()
            if transcription_queue.qsize() != 0:
                continue

            uid = transcription_output["uid"]
            history = conversation_history.setdefault(uid, [])

            prompt = transcription_output['prompt'].strip()

            # if prompt is same but EOS is True, we need that to send outputs to websockets
            if self.last_prompt == prompt:
                if self.last_output is not None and transcription_output["eos"]:
                    self.eos = transcription_output["eos"]
                    llm_queue.put({
                        "uid": uid,
                        "llm_output": self.last_output,
                        "eos": self.eos,
                        "latency": self.infer_time
                    })
                    audio_queue.put({"llm_output": self.last_output, "eos": self.eos})
                    history.append((prompt, self.last_output[0].strip()))
                    continue

            input_text = [
                self.format_prompt_chatml(
                    prompt,
                    history,
                    system_prompt="You are Dolphin, a helpful AI assistant",
                )
            ]

            self.eos = transcription_output["eos"]

            batch_input_ids = self.parse_input(
                input_text=input_text,
                add_special_tokens=True,
                max_input_length=923,
                pad_id=None,
            )

            input_lengths = [x.size(0) for x in batch_input_ids]

            logging.info(f"[LLM INFO:] Running LLM Inference with WhisperLive prompt: {prompt}, eos: {self.eos}")
            start = time.time()
            with torch.no_grad():
                outputs = self.runner.generate(
                    batch_input_ids,
                    max_new_tokens=max_output_len,
                    max_attention_window_size=max_attention_window_size,
                    end_id=self.end_id,
                    pad_id=self.pad_id,
                    temperature=1.0,
                    top_k=1,
                    top_p=0.0,
                    num_beams=num_beams,
                    length_penalty=1.0,
                    repetition_penalty=1.0,
                    stop_words_list=None,
                    bad_words_list=None,
                    lora_uids=None,
                    prompt_table_path=None,
                    prompt_tasks=None,
                    streaming=streaming,
                    output_sequence_lengths=True,
                    return_dict=True)
                torch.cuda.synchronize()

            # Stays None when the stream yields nothing or decoding is
            # interrupted by a newer transcription.
            output = None
            if streaming:
                for curr_outputs in throttle_generator(outputs, streaming_interval):
                    output = self.decode_tokens(
                        curr_outputs['output_ids'],
                        input_lengths,
                        curr_outputs['sequence_lengths'],
                        transcription_queue
                    )

                    if output is None:
                        break

                # Interrupted by transcription queue
                if output is None:
                    continue
            else:
                output = self.decode_tokens(
                    outputs['output_ids'],
                    input_lengths,
                    outputs['sequence_lengths'],
                    transcription_queue
                )
            self.infer_time = time.time() - start

            # Nothing was decoded (interrupted or empty result): there is
            # nothing to publish or to record in the history, so wait for the
            # newer transcription instead of indexing into a missing output.
            if not output:
                continue

            output[0] = clean_llm_output(output[0])
            self.last_output = output
            self.last_prompt = prompt
            llm_queue.put({
                "uid": uid,
                "llm_output": output,
                "eos": self.eos,
                "latency": self.infer_time
            })
            audio_queue.put({"llm_output": output, "eos": self.eos})
            # infer_time is a time.time() difference, i.e. seconds.
            logging.info(f"[LLM INFO:] Output: {output[0]}\nLLM inference done in {self.infer_time} s\n\n")

            if self.eos:
                # End of utterance: commit the exchange to the history and
                # clear the cache so the next prompt starts fresh.
                history.append((prompt, output[0].strip()))
                self.last_prompt = None
                self.last_output = None

def clean_llm_output(output):
    """Strip speaker markers and trim a trailing unfinished sentence.

    The speaker markers are removed wherever they occur. If the remaining
    text does not already end in ``.``, ``?`` or ``!``, it is cut just after
    the last such character, provided that character is not at index 0;
    otherwise the text is returned unchanged.
    """
    speaker_markers = (
        "\n\nDolphin\n\n",
        "\nDolphin\n\n",
        "Dolphin: ",
        "Assistant: ",
    )
    for marker in speaker_markers:
        output = output.replace(marker, "")

    # Already ends on a sentence boundary: nothing to trim.
    if output.endswith(('.', '?', '!')):
        return output

    # Position of the right-most sentence terminator (-1 when none exists).
    cut = max(output.rfind(mark) for mark in ('.', '?', '!'))
    if cut > 0:
        return output[:cut + 1]

    return output