|
import json |
|
from jinja2 import Template |
|
from .DataBase import ChromaDB |
|
|
|
from .Models import GLM, GLM_api |
|
|
|
from .utils import * |
|
|
|
|
|
class ChatWorld:
    """Role-play chat front-end combining a ChromaDB story store with a GLM model.

    The system prompt is rendered from a Jinja2 template and filled with the
    passages returned by ``ChromaDB.searchBySim`` for the user's message.
    Replies come from the remote ``GLM_api`` client by default, or from a
    locally loaded ``GLM`` model when ``use_local_model=True``.
    """

    def __init__(
        self,
        pretrained_model_name_or_path="silk-road/Haruhi-Zero-GLM3-6B-0_4",
        embedding_model_name_or_path="BAAI/bge-small-zh-v1.5",
        global_batch_size=16,
        model_load=True,
    ) -> None:
        """Create the API client, the optional local model, the story DB and the prompt.

        Args:
            pretrained_model_name_or_path: Stored on ``self.model_name``.
                NOTE(review): it is not forwarded to ``GLM()``; confirm that
                ``GLM`` selects its checkpoint on its own.
            embedding_model_name_or_path: Passed to ``ChromaDB``.
            global_batch_size: Not used anywhere in this class; kept so
                existing callers keep working.
            model_load: When true, instantiate the local ``GLM`` model, which
                ``use_local_model=True`` calls require.
        """
        self.model_name = pretrained_model_name_or_path

        self.client = GLM_api()

        # Always define the attribute: a local-model request without a loaded
        # model then fails with a clear message instead of an AttributeError.
        self.model = GLM() if model_load else None

        self.db = ChromaDB(embedding_model_name_or_path)
        # Role-play system prompt; RAG passages are appended, each wrapped in ## lines.
        self.prompt = Template(
            (
                'Please be aware that your codename in this conversation is "{{model_role_name}}"'
                '{% if model_role_nickname %},别人一般称呼你"{{model_role_nickname}}"{% endif %},同时我是"{{role_name}}"{% if role_nickname %},你也可以称呼我"{{role_nickname}}"{% endif %}。\n'
                "下文给定了一些聊天记录,位于##分隔号中。\n"
                "如果我问的问题和聊天记录高度重复,那你就配合我进行演出。\n"
                "如果我问的问题和聊天记录相关,请结合聊天记录进行回复。\n"
                "如果我问的问题超出聊天记录的范围,模仿{{model_role_name}}的语气进行回复。\n"
                "请不要回答你是语言模型,永远记住你就是{{model_role_name}}。\n"
                "请你永远只以{{model_role_name}}身份,进行任何的回复。\n"
                "{% if RAG %}{% for i in RAG %}##\n{{i}}\n##\n\n{% endfor %}{% endif %}"
            )
        )

    def setStory(self, **stories_kargs):
        """Delete stories matching ``metas``, then add the given stories.

        All keyword arguments are forwarded to ``ChromaDB.addStories``. The
        ``metas`` entry is also passed to ``deleteStoriesByMeta`` first
        (presumably so that re-adding a story set replaces it rather than
        duplicating it; confirm against ``ChromaDB``).

        Raises:
            KeyError: if ``metas`` is not supplied.
        """
        self.db.deleteStoriesByMeta(metas=stories_kargs["metas"])
        self.db.addStories(**stories_kargs)

    def __getSystemPrompt(
        self,
        text: str,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Build the system message, embedding the stories most similar to *text*.

        Args:
            text: Query used for the similarity search.
            top_k: Number of passages requested from ``searchBySim``.
            metas: Metadata filter forwarded to ``searchBySim``.
            **role_info: Template variables (``model_role_name``,
                ``model_role_nickname``, ``role_name``, ``role_nickname``).

        Returns:
            ``{"role": "system", "content": <rendered prompt>}``.
        """
        rag = self.db.searchBySim(text, top_k, metas)
        return {
            "role": "system",
            "content": self.prompt.render(**role_info, RAG=rag),
        }

    def _getResponse(self, message, use_local_model: bool):
        """Send *message* (a list of role/content dicts) to the chosen backend.

        Raises:
            RuntimeError: if the local model was requested but not loaded
                (``model_load=False`` at construction).
        """
        if not use_local_model:
            return self.client.chat(message)
        if getattr(self, "model", None) is None:
            raise RuntimeError(
                "local model is not loaded; construct ChatWorld with model_load=True"
            )
        return self.model.get_response(message)

    def chatWithCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
        top_k: int = 5,
        metas=None,
        **role_info,
    ):
        """Chat in character, using retrieved stories as context.

        Args:
            text: The user's line; sent as ``"<role_name>:「<text>」"``.
            system_prompt: Prebuilt system message. When falsy, one is
                rendered from the template plus a similarity search.
            use_local_model: Use the local ``GLM`` model instead of the API.
            top_k: Passages to retrieve when building the system prompt.
            metas: Metadata filter for the similarity search.
            **role_info: Template variables; ``role_name`` is required.

        Raises:
            ValueError: if ``role_name`` is missing or empty.
        """
        user_role_name = role_info.get("role_name")
        # Validate before the (potentially costly) similarity search.
        if not user_role_name:
            raise ValueError("role_name is required")

        if not system_prompt:
            system_prompt = self.__getSystemPrompt(
                text=text, top_k=top_k, metas=metas, **role_info
            )

        message = [
            system_prompt,
            {"role": "user", "content": f"{user_role_name}:「{text}」"},
        ]
        logging_info(f"message: {message}")
        return self._getResponse(message, use_local_model)

    def chatWithoutCharacter(
        self,
        text: str,
        system_prompt: dict[str, str] = None,
        use_local_model: bool = False,
    ):
        """Send *text* as a plain user message, without role-play context.

        Args:
            text: The user message.
            system_prompt: Optional system message placed before the user
                message. It was previously accepted but ignored.
            use_local_model: Use the local ``GLM`` model instead of the API.
        """
        logging_info(f"text: {text}")

        message = [{"role": "user", "content": f"{text}"}]
        if system_prompt:
            message.insert(0, system_prompt)

        # Both backends receive the same message list, as in chatWithCharacter.
        # The local path previously received the raw string.
        return self._getResponse(message, use_local_model)

    @staticmethod
    def _extractJsonPayload(response: str) -> str:
        """Return the text inside the first ```json (or bare ```) fence of *response*.

        Falls back to the whole response when no fence is present, so bare
        JSON still parses. An unterminated fence yields everything after the
        opening marker instead of silently dropping the last character.
        """
        for fence in ("```json", "```"):
            start = response.find(fence)
            if start == -1:
                continue
            start += len(fence)
            end = response.find("```", start)
            return response[start:] if end == -1 else response[start:end]
        return response

    @staticmethod
    def _parseRoleNames(payload: str) -> tuple[list[str], dict[str, str]]:
        """Parse a JSON list of ``{"name", "nickname"}`` objects.

        Entries that are not objects or lack a ``name`` are skipped. A missing
        or null ``nickname`` becomes ``""``. Duplicate names are collapsed,
        keeping first-seen order; for nicknames, the last value wins.

        Returns:
            ``(names, {name: nickname})``.

        Raises:
            ValueError: if *payload* is not valid JSON or not a JSON list.
        """
        people = json.loads(payload)  # json.JSONDecodeError subclasses ValueError
        if not isinstance(people, list):
            raise ValueError(
                f"expected a JSON list of people, got {type(people).__name__}"
            )

        role_name_dict: dict[str, str] = {}
        for person in people:
            if not isinstance(person, dict) or not person.get("name"):
                continue  # skip a bad entry instead of discarding every result
            role_name_dict[person["name"]] = person.get("nickname") or ""
        return list(role_name_dict), role_name_dict

    def getRoleNameFromFile(self, input_file: str):
        """Ask the remote model to list the people mentioned in *input_file*.

        *input_file* is interpolated verbatim into the prompt, so it is
        presumably the file's text rather than a path; confirm with callers.

        Returns:
            ``(role_name_list, role_name_dict)`` mapping name -> nickname.
            Both are empty if the reply is empty or cannot be parsed.
        """
        prompt = (
            f"{input_file}\n"
            + '请你提取包含“人”(name,nickname)类型的所有信息,如果nickname不存在则设置为空字符串,并输出JSON格式。并且不要提取出重复的同一个人。例如格式如下:\n```json\n [{"name": "小明","nickname": "小明"},{"name": "小红","nickname": ""}]```'
        )

        response = self.chatWithoutCharacter(prompt, use_local_model=False)
        if not response:
            logging_info("getRoleNameFromFile: empty response from model")
            return [], {}

        payload = self._extractJsonPayload(response)
        logging_info(f"role-name JSON payload: {payload}")

        try:
            return self._parseRoleNames(payload)
        except ValueError as err:
            # Best-effort: a malformed model reply yields no names rather than a crash.
            logging_info(f"failed to parse role names: {err}")
            return [], {}
|
|