from typing import Iterator, Optional, Union

import gradio as gr
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter


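# Minimal streaming wrapper around vLLM's LLMEngine: it drives the engine
# step-by-step so partial outputs can be surfaced while generation runs.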
class StreamingLLM:
    def __init__(
        self,
        model: str,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        **kwargs,
    ) -> None:
        # Forward any extra keyword arguments to EngineArgs so callers can
        # tweak engine options without changing this wrapper.
        engine_args = EngineArgs(
            model=model, quantization=quantization, dtype=dtype,
            enforce_eager=True, **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.request_counter = Counter()

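    # Submit one prompt and yield RequestOutputs as the engine steps; each
    # output carries the cumulative text generated so far for that request.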
    def generate(
        self,
        prompt: Optional[str] = None,
        sampling_params: Optional[SamplingParams] = None,
    ) -> Iterator[RequestOutput]:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params)

        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                yield output


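# Gradio front end: rebuilds the chat history in the model's chat template and
# streams partial completions back to the ChatInterface.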
class UI:
    def __init__(
        self,
        llm: StreamingLLM,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        sampling_params: Optional[SamplingParams] = None,
    ) -> None:
        self.llm = llm
        self.tokenizer = tokenizer
        self.sampling_params = sampling_params

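    # ChatInterface callback: convert (user, assistant) history pairs into
    # chat-format messages, apply the chat template, and stream the reply.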
    def _generate(self, message, history):
        history_chat_format = []
        for human, assistant in history:
            history_chat_format.append({"role": "user", "content": human})
            history_chat_format.append({"role": "assistant", "content": assistant})
        history_chat_format.append({"role": "user", "content": message})

        # add_generation_prompt appends the assistant header so the model
        # answers the latest user turn instead of continuing it.
        prompt = self.tokenizer.apply_chat_template(
            history_chat_format, tokenize=False, add_generation_prompt=True
        )

        for chunk in self.llm.generate(prompt, self.sampling_params):
            yield chunk.outputs[0].text

    def launch(self):
        gr.ChatInterface(self._generate).launch()


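# Demo entry point: an AWQ-quantized Llama 3 70B Instruct checkpoint served
# through the streaming wrapper; <|eot_id|> is added as a stop token because
# Llama 3 Instruct ends assistant turns with it.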
if __name__ == "__main__":
    llm = StreamingLLM(
        model="casperhansen/llama-3-70b-instruct-awq",
        quantization="AWQ",
        dtype="float16",
    )
    # The underlying Hugging Face tokenizer sits behind vLLM's tokenizer group.
    tokenizer = llm.llm_engine.tokenizer.tokenizer
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=4096,
        stop_token_ids=[
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ],
    )
    ui = UI(llm, tokenizer, sampling_params)
    ui.launch()