chatglm3-6b-chitchat / handler.py
Yingxu He
Update handler.py
7e813b1 verified
raw
history blame
1.63 kB
import torch
import chatglm_cpp
from typing import Dict, List, Any
# get dtype
# dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
# Markers of the flat-string conversation format carried in the request's
# "inputs" field and decoded by EndpointHandler.__call__.

# Separator on which the input string is split into individual turns.
TURN_BREAKER = "<||turn_breaker||>"
# Marker stripped from the first segment, which is always sent as the system message.
SYSTEM_SYMBOL = "<||system_symbol||>"
# Marker whose presence in a later segment makes that segment a user message.
USER_SYMBOL = "<||user_symbol||>"
# Marker stripped from every later segment that does not contain USER_SYMBOL
# (such segments are sent as assistant messages).
ASSISTANT_SYMBOL = "<||assistant_symbol||>"
class EndpointHandler:
    """Endpoint handler that answers a chat request with a chatglm_cpp pipeline.

    The conversation arrives as one flat string in ``data["inputs"]``: turns
    are separated by ``TURN_BREAKER``; the first segment is the system prompt
    and every following segment is a user or an assistant turn.
    """

    def __init__(self, path=""):
        """Load the ``q5_1.bin`` model file located in directory *path*.

        Args:
            path: Directory containing ``q5_1.bin``. An empty string means
                the current working directory.
        """
        # Local import so this block does not depend on the file's import list.
        import os

        # os.path.join keeps an empty ``path`` relative; the previous
        # f"{path}/q5_1.bin" turned it into "/q5_1.bin" at the filesystem root.
        self.pipeline = chatglm_cpp.Pipeline(os.path.join(path, "q5_1.bin"))

    def __call__(self, data: Any) -> str:
        """Run one chat completion for the request payload *data*.

        Args:
            data: Mapping with a string ``"inputs"`` entry (the encoded
                conversation) and an optional ``"parameters"`` mapping whose
                items are forwarded as keyword arguments to
                ``self.pipeline.chat``. The mapping is not modified.

        Returns:
            The ``content`` attribute of the message returned by
            ``self.pipeline.chat``.

        Raises:
            ValueError: If ``"inputs"`` is missing or is not a string.
        """
        # Read with .get so the caller's payload is left untouched.
        inputs = data.get("inputs")
        if not isinstance(inputs, str):
            raise ValueError(
                "'inputs' must be a string of turns separated by the turn "
                f"breaker, got {type(inputs).__name__}"
            )

        parameters = data.get("parameters")
        if parameters is None:
            parameters = {}

        prediction = self.pipeline.chat(
            self._build_messages(inputs), **parameters
        )
        return prediction.content

    @staticmethod
    def _build_messages(inputs):
        """Decode the flat *inputs* string into a list of chat messages."""
        system_text, *turns = inputs.split(TURN_BREAKER)

        # The first segment is always the system prompt, marker or not.
        messages = [
            chatglm_cpp.ChatMessage(
                role="system",
                content=system_text.replace(SYSTEM_SYMBOL, ""),
            )
        ]
        for turn in turns:
            # A segment without the user marker is treated as an assistant turn.
            if USER_SYMBOL in turn:
                role, symbol = "user", USER_SYMBOL
            else:
                role, symbol = "assistant", ASSISTANT_SYMBOL
            messages.append(
                chatglm_cpp.ChatMessage(
                    role=role,
                    content=turn.replace(symbol, ""),
                )
            )
        return messages