from typing import Iterator, Optional, Union

from vllm.engine.llm_engine import LLMEngine
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
from vllm.outputs import RequestOutput
from vllm import SamplingParams
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
import gradio as gr


class StreamingLLM:
    """Thin wrapper around vLLM's LLMEngine that yields partial outputs
    as they are produced, instead of blocking until generation finishes."""

    def __init__(
        self,
        model: str,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        **kwargs,
    ) -> None:
        engine_args = EngineArgs(model=model, quantization=quantization,
                                 dtype=dtype, enforce_eager=True)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        self.request_counter = Counter()

    def generate(
        self,
        prompt: Optional[str] = None,
        sampling_params: Optional[SamplingParams] = None,
    ) -> Iterator[RequestOutput]:
        # Register the request, then drive the engine one step at a time,
        # yielding the (cumulative) RequestOutput after every step.
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params)

        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                yield output


class UI:
    """Gradio chat front end that streams tokens from a StreamingLLM."""

    def __init__(
        self,
        llm: StreamingLLM,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        self.llm = llm
        self.tokenizer = tokenizer
        self.sampling_params = sampling_params

    def _generate(self, message, history):
        # Convert Gradio's (user, assistant) history tuples into the
        # chat-messages format expected by apply_chat_template.
        history_chat_format = []
        for human, assistant in history:
            history_chat_format.append({"role": "user", "content": human})
            history_chat_format.append({"role": "assistant", "content": assistant})
        history_chat_format.append({"role": "user", "content": message})

        # add_generation_prompt appends the assistant header so the model
        # completes a new assistant turn rather than continuing the user's.
        prompt = self.tokenizer.apply_chat_template(
            history_chat_format, tokenize=False, add_generation_prompt=True)

        # outputs[0].text is cumulative, which is exactly what gr.ChatInterface
        # expects from a streaming generator.
        for chunk in self.llm.generate(prompt, self.sampling_params):
            yield chunk.outputs[0].text

    def launch(self):
        gr.ChatInterface(self._generate).launch()


if __name__ == "__main__":
    llm = StreamingLLM(model="casperhansen/llama-3-70b-instruct-awq",
                       quantization="AWQ", dtype="float16")

    # Unwrap the underlying Hugging Face tokenizer from vLLM's tokenizer group.
    tokenizer = llm.llm_engine.tokenizer.tokenizer
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=4096,
        # Llama 3 ends assistant turns with <|eot_id|>, so stop on it in
        # addition to the regular EOS token.
        stop_token_ids=[tokenizer.eos_token_id,
                        tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    )

    ui = UI(llm, tokenizer, sampling_params)
    ui.launch()
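
# A minimal sketch (kept commented out so it does not interfere with the
# Gradio app above) of consuming StreamingLLM.generate directly from a
# terminal. Because RequestOutput.outputs[0].text is cumulative, only the
# newly generated suffix is printed on each engine step. The prompt and
# sampling values below are arbitrary examples, not part of the original app.
#
#     llm = StreamingLLM(model="casperhansen/llama-3-70b-instruct-awq",
#                        quantization="AWQ", dtype="float16")
#     params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=256)
#     printed = 0
#     for chunk in llm.generate("Explain paged attention briefly.", params):
#         text = chunk.outputs[0].text
#         print(text[printed:], end="", flush=True)
#         printed = len(text)
#     print()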