# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: ChatGLM local model client (ChatGLM / ChatGLM2 / ChatGLM3) built on transformers
"""
import platform

from loguru import logger

from src.base_model import BaseLLMModel
from src.presets import LOCAL_MODELS


class ChatGLMClient(BaseLLMModel):
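    """Client for locally hosted ChatGLM-family models (ChatGLM / ChatGLM2 / ChatGLM3)."""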
    def __init__(self, model_name, user_name=""):
        super().__init__(model_name=model_name, user=user_name)
        import torch
        from transformers import AutoModel, AutoTokenizer
        system_name = platform.system()
        logger.info(f"Loading model from {model_name}")
        if model_name in LOCAL_MODELS:
            model_path = LOCAL_MODELS[model_name]
        else:
            model_path = model_name
        self.CHATGLM_TOKENIZER = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # int4 checkpoints are already quantized; keep them off MPS below.
        quantized = "int4" in model_name
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map='auto', torch_dtype='auto')
        if torch.cuda.is_available():
            logger.info("CUDA is available, using CUDA")
            model = model.half().cuda()
        # MPS acceleration still has some issues, so it is only used on macOS for non-quantized models.
        elif system_name == "Darwin" and model_path is not None and not quantized:
            logger.info("Running on macOS, using MPS")
            # running on macOS and model already downloaded
            model = model.half().to("mps")
        else:
            logger.info("GPU is not available, using CPU")
            model = model.float()
        model = model.eval()
        logger.info(f"Model loaded from {model_path}")
        self.CHATGLM_MODEL = model

    def _get_glm3_style_input(self):
        # ChatGLM3 accepts the role/content dict history directly; pass the last
        # user message as the query without mutating self.history.
        history = self.history[:-1]
        query = self.history[-1]["content"]
        return history, query

    def _get_glm2_style_input(self):
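        # ChatGLM / ChatGLM2 expect history as [[user, assistant], ...] pairs,
        # with the latest user message passed separately as the query.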
        history = [x["content"] for x in self.history]
        query = history.pop()
        logger.debug(f"{history}")
        assert len(history) % 2 == 0, f"History should have an even length; current history is: {history}"
        history = [[history[i], history[i + 1]]
                   for i in range(0, len(history), 2)]
        return history, query

    def _get_glm_style_input(self):
        if "glm2" in self.model_name:
            return self._get_glm2_style_input()
        else:
            return self._get_glm3_style_input()

    def get_answer_at_once(self):
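        # Non-streaming generation: return the full response and its length.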
        history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        response, _ = self.CHATGLM_MODEL.chat(
            self.CHATGLM_TOKENIZER, query, history=history)
        return response, len(response)

    def get_answer_stream_iter(self):
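        # Streaming generation: stream_chat yields the cumulative response after each step.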
        history, query = self._get_glm_style_input()
        logger.debug(f"{history}")
        for response, history in self.CHATGLM_MODEL.stream_chat(
                self.CHATGLM_TOKENIZER,
                query,
                history,
                max_length=self.token_upper_limit,
                top_p=self.top_p,
                temperature=self.temperature,
        ):
            yield response
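

# Minimal usage sketch (not part of the original file). Assumptions: "chatglm3-6b"
# is a key in LOCAL_MODELS or a valid Hugging Face repo id, and BaseLLMModel
# exposes a mutable `history` list of {"role": ..., "content": ...} dicts, as the
# methods above expect.
if __name__ == "__main__":
    client = ChatGLMClient("chatglm3-6b")
    client.history.append({"role": "user", "content": "Hello, who are you?"})
    answer, _ = client.get_answer_at_once()
    print(answer)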