PoetryChat / src /models.py
Tsumugii24
initial commit
f7161fa
raw
history blame
4.08 kB
# -*- coding: utf-8 -*-
"""
Get model client from model name
"""
import json
import colorama
import gradio as gr
from loguru import logger
from src import config
from src.base_model import ModelType
from src.utils import (
hide_middle_chars,
i18n,
)
def get_model(
    model_name,
    lora_model_path=None,
    access_key=None,
    temperature=None,
    top_p=None,
    system_prompt=None,
    user_name="",
    original_model=None,
):
    """Create the chat client that matches ``model_name``.

    The model family is resolved with ``ModelType.get_type`` and the matching
    client class is imported lazily inside its branch, so only the backend
    that is actually requested gets imported.

    Args:
        model_name: Name of the model to load; also used as the chatbot label.
        lora_model_path: Accepted for interface compatibility; not used here.
        access_key: API key forwarded to the OpenAI / ZhipuAI clients.
        temperature: Accepted for interface compatibility; not used here.
        top_p: Accepted for interface compatibility; not used here.
        system_prompt: System prompt forwarded to the OpenAI / ZhipuAI clients.
        user_name: User name forwarded to every client.
        original_model: Previously active client, if any. It is returned
            unchanged when loading fails, and its ``history`` and
            ``history_file_path`` are carried over to the new client.

    Returns:
        A 6-tuple ``(model, msg, chatbot_update, lora_dropdown_update,
        access_key, masked_key)``. ``model`` is ``original_model`` (possibly
        ``None``) when loading failed; ``msg`` then describes the failure
        instead of announcing the model switch.
    """
    msg = i18n("模型设置为了:") + f" {model_name}"
    model_type = ModelType.get_type(model_name)
    lora_choices = ["No LoRA"]
    # Every non-OpenAI model type switches the app to local embeddings.
    if model_type != ModelType.OpenAI:
        config.local_embedding = True
    # Fall back to the previous client if nothing below succeeds.
    model = original_model
    chatbot = gr.Chatbot.update(label=model_name)
    try:
        if model_type == ModelType.OpenAI:
            logger.info(f"正在加载OpenAI模型: {model_name}")
            from src.openai_client import OpenAIClient
            model = OpenAIClient(
                model_name=model_name,
                api_key=access_key,
                system_prompt=system_prompt,
                user_name=user_name,
            )
        elif model_type == ModelType.OpenAIVision:
            logger.info(f"正在加载OpenAI Vision模型: {model_name}")
            from src.openai_client import OpenAIVisionClient
            model = OpenAIVisionClient(model_name, api_key=access_key, user_name=user_name)
        elif model_type == ModelType.ChatGLM:
            logger.info(f"正在加载ChatGLM模型: {model_name}")
            from src.chatglm import ChatGLMClient
            model = ChatGLMClient(model_name, user_name=user_name)
        elif model_type == ModelType.LLaMA:
            logger.info(f"正在加载LLaMA模型: {model_name}")
            from src.llama import LLaMAClient
            model = LLaMAClient(model_name, user_name=user_name)
        elif model_type == ModelType.ZhipuAI:  # todo: fix zhipu bug
            logger.info(f"正在加载ZhipuAI模型: {model_name}")
            from src.zhipu_client import ZhipuAIClient
            model = ZhipuAIClient(
                model_name=model_name,
                api_key=access_key,
                system_prompt=system_prompt,
                user_name=user_name,
            )
        elif model_type == ModelType.Unknown:
            raise ValueError(f"未知模型: {model_name}")
    except Exception as e:
        # Best-effort load: keep the previous client, but record the full
        # traceback and tell the caller that the switch did not happen
        # instead of returning the success message.
        logger.exception(f"模型加载失败: {model_name}")
        msg = f"模型加载失败: {model_name} ({e})"
    logger.info(msg)
    presudo_key = hide_middle_chars(access_key)
    # Carry the running conversation over to the client being returned.
    if original_model is not None and model is not None:
        model.history = original_model.history
        model.history_file_path = original_model.history_file_path
    return model, msg, chatbot, gr.Dropdown.update(choices=lora_choices, visible=False), access_key, presudo_key
if __name__ == "__main__":
    # Manual smoke test against a live OpenAI model; needs ../config.json
    # (relative to the working directory) with an "openai_api_key" entry.
    with open("../config.json", "r", encoding="utf-8") as f:
        openai_api_key = json.load(f)["openai_api_key"]
    # Do not echo the raw secret to stdout; show the masked form that
    # get_model also hands back to its caller.
    print('key:', hide_middle_chars(openai_api_key))
    client = get_model(model_name="gpt-3.5-turbo", access_key=openai_api_key)[0]
    if client is None:
        # get_model returns None when loading failed and there was no
        # previous model; stop here rather than fail on attribute access.
        raise SystemExit("模型加载失败, 无法继续测试")
    chatbot = []
    stream = False

    # Billing query
    logger.info(colorama.Back.GREEN + "测试账单功能" + colorama.Back.RESET)
    logger.info(client.billing_info())

    # Two chat rounds: a plain question, then a follow-up that depends on
    # the first round being kept in the client's history.
    rounds = (
        ("测试问答", "巴黎是中国的首都吗?"),
        ("测试记忆力", "我刚刚问了你什么问题?"),
    )
    for title, question in rounds:
        logger.info(colorama.Back.GREEN + title + colorama.Back.RESET)
        for i in client.predict(inputs=question, chatbot=chatbot, stream=stream):
            logger.info(i)
        logger.info(f"{title}后history : {client.history}")

    # Retry of the last round
    logger.info(colorama.Back.GREEN + "测试重试功能" + colorama.Back.RESET)
    for i in client.retry(chatbot=chatbot, stream=stream):
        logger.info(i)
    logger.info(f"重试后history : {client.history}")