import asyncio
from typing import List, Union

from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
from lagent.utils.util import filter_suffix


def asdict_completion(output):
    """Convert a vLLM ``CompletionOutput`` into a plain dictionary."""
    return {
        key: getattr(output, key)
        for key in [
            'text', 'token_ids', 'cumulative_logprob', 'logprobs',
            'finish_reason', 'stop_reason'
        ]
    }
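

# For reference, ``asdict_completion`` mirrors the fields of vLLM's
# ``CompletionOutput``. A rough sketch of the returned shape (the values
# below are purely illustrative, not actual output):
#     {'text': 'Hi there!', 'token_ids': [...], 'cumulative_logprob': -1.2,
#      'logprobs': None, 'finish_reason': 'stop', 'stop_reason': None}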


class VllmModel(BaseLLM):
    """A wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:

            - i) A local directory path of a huggingface model.
            - ii) The model_id of a model hosted inside a model repo
              on huggingface.co, such as "internlm/internlm-chat-7b",
              "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
              and so on.
        tp (int): tensor parallel size, i.e. the number of GPUs the model
            is sharded across. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vLLM model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import LLM
        self.model = LLM(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): whether to remove special tokens in
                the decoding. Defaults to False.
            return_dict (bool): whether to return the full completion dict
                produced by ``asdict_completion`` instead of plain text.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        # Map the generic generation params onto vLLM's SamplingParams.
        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)
        response = self.model.generate(prompt, sampling_params=sampling_config)
        texts = [resp.outputs[0].text for resp in response]
        # Strip any trailing stop words left in the generated texts.
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.outputs[0].text = text
        if batched:
            return [asdict_completion(resp.outputs[0])
                    for resp in response] if return_dict else texts
        return asdict_completion(
            response[0].outputs[0]) if return_dict else texts[0]
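

# A minimal usage sketch for ``VllmModel`` (the model path and generation
# kwargs below are illustrative, not prescribed by this module):
#
#     model = VllmModel(path='internlm/internlm-chat-7b', tp=1,
#                       max_new_tokens=256)
#     print(model.generate('Hello, who are you?'))
#     # Batched inputs return a list; return_dict=True returns the
#     # completion dicts produced by ``asdict_completion``.
#     outs = model.generate(['Hi', 'Bonjour'], return_dict=True)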


class AsyncVllmModel(AsyncBaseLLM):
    """An asynchronous wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:

            - i) A local directory path of a huggingface model.
            - ii) The model_id of a model hosted inside a model repo
              on huggingface.co, such as "internlm/internlm-chat-7b",
              "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
              and so on.
        tp (int): tensor parallel size, i.e. the number of GPUs the model
            is sharded across. Defaults to 1.
        vllm_cfg (dict): Other kwargs for vLLM model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        engine_args = AsyncEngineArgs(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)
        self.model = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]]): request ids forwarded to the
                async engine, one per input. Defaults to
                ``range(len(inputs))``.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): whether to remove special tokens in
                the decoding. Defaults to False.
            return_dict (bool): whether to return the full completion dict
                produced by ``asdict_completion`` instead of plain text.
                Defaults to False.

        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        # Map the generic generation params onto vLLM's SamplingParams.
        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)

        async def _inner_generate(uid, text):
            # Drain the engine's async stream and keep only the final
            # output for this request.
            resp, generator = '', self.model.generate(
                text, sampling_params=sampling_config, request_id=uid)
            async for out in generator:
                resp = out.outputs[0]
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # Strip any trailing stop words left in the generated texts.
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict_completion(resp)
                    for resp in response] if return_dict else texts
        return asdict_completion(response[0]) if return_dict else texts[0]
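

# A minimal usage sketch for ``AsyncVllmModel`` (model path and kwargs are
# illustrative, not prescribed by this module):
#
#     model = AsyncVllmModel(path='internlm/internlm-chat-7b', tp=1,
#                            max_new_tokens=256)
#     texts = asyncio.run(
#         model.generate(['Hello', 'How are you?'], session_ids=[0, 1]))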