# -*- coding: utf-8 -*-
"""
@author: XuMing(xuming624@qq.com)
@description: Local ChatGLM client for chat completion
"""
import platform

from loguru import logger

from src.base_model import BaseLLMModel
from src.presets import LOCAL_MODELS


class ChatGLMClient(BaseLLMModel):
    def __init__(self, model_name, user_name=""):
        super().__init__(model_name=model_name, user=user_name)
        import torch
        from transformers import AutoModel, AutoTokenizer

        system_name = platform.system()
        logger.info(f"Loading model from {model_name}")
        if model_name in LOCAL_MODELS:
            model_path = LOCAL_MODELS[model_name]
        else:
            model_path = model_name
        self.CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        quantified = False
        if "int4" in model_name:
            quantified = True
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map='auto', torch_dtype='auto')
        if torch.cuda.is_available():
            logger.info("CUDA is available, using CUDA")
            model = model.half().cuda()
        # MPS acceleration still has some issues, so it is not used for now
        elif system_name == "Darwin" and model_path is not None and not quantified:
            logger.info("Running on macOS, using MPS")
            # running on macOS and model already downloaded
            model = model.half().to("mps")
        else:
            logger.info("GPU is not available, using CPU")
            model = model.float()
        model = model.eval()
        logger.info(f"Model loaded from {model_path}")
        self.CHATGLM_MODEL = model

    def _get_glm3_style_input(self):
        # ChatGLM3 takes the history as a list of {"role", "content"} dicts;
        # the latest user turn is popped off and passed separately as the query.
        history = self.history
        query = history.pop()["content"]
        return history, query

    def _get_glm2_style_input(self):
        # ChatGLM/ChatGLM2 take the history as [user, assistant] text pairs.
        history = [x["content"] for x in self.history]
        query = history.pop()
        logger.debug(f"{history}")
        assert len(history) % 2 == 0, f"History should be even length. current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def _get_glm_style_input(self):
        if "glm2" in self.model_name:
            return self._get_glm2_style_input()
        else:
            return self._get_glm3_style_input()

    def get_answer_at_once(self):
        history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        response, _ = self.CHATGLM_MODEL.chat(
            self.CHATGLM_TOKENIZER, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
        history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        for response, history in self.CHATGLM_MODEL.stream_chat(
                self.CHATGLM_TOKENIZER,
                query,
                history,
                max_length=self.token_upper_limit,
                top_p=self.top_p,
                temperature=self.temperature,
        ):
            yield response
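

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module: assumes "THUDM/chatglm3-6b"
    # is reachable as a Hugging Face model id (or listed in LOCAL_MODELS), and that
    # BaseLLMModel exposes the history/top_p/temperature attributes used above.
    client = ChatGLMClient("THUDM/chatglm3-6b", user_name="demo")
    client.history = [{"role": "user", "content": "Hello, who are you?"}]
    answer, total_tokens = client.get_answer_at_once()
    print(answer)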