File size: 3,979 Bytes
0339b60 6f179e7 0339b60 87818fb 6f179e7 87818fb 0339b60 6f179e7 0339b60 6f179e7 0339b60 6f179e7 0339b60 87818fb 6f179e7 0339b60 6f179e7 87818fb 6f179e7 0339b60 6f179e7 0339b60 87818fb 0339b60 87818fb 0339b60 6f179e7 0339b60 87818fb 6f179e7 87818fb 6f179e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import logging
import gradio as gr
from src import ChatWorld
# Single backend instance shared by every Gradio callback in this module.
chatWorld = ChatWorld()

# Role data parsed from the most recently uploaded document: the list of role
# names and a mapping from role name to the value shown in the nickname box.
# Both stay None until getContent() stores them.
role_name_list_global = role_name_dict_global = None
def getContent(input_file):
    """Load an uploaded story file and refresh both role pickers.

    Reads the uploaded file as UTF-8 text and passes it to
    ``chatWorld.setStory``. Extracts role names with
    ``chatWorld.getRoleNameFromFile`` and caches the results in the
    module-level ``role_name_list_global`` / ``role_name_dict_global`` so
    other callbacks can read them.

    Args:
        input_file: The value Gradio passes for a ``gr.File`` component. This
            is either a temp-file wrapper exposing ``.name`` (older Gradio) or
            a plain path string (Gradio 4's default ``type="filepath"``).

    Returns:
        A pair of ``gr.Radio`` updates sharing the same choices. The first
        preselects the first role and the second preselects the last role.
        If no roles were found, both are returned with ``value=None``
        instead of raising ``IndexError``.
    """
    global role_name_list_global, role_name_dict_global

    # Accept both a file wrapper (has .name) and a bare path string.
    path = getattr(input_file, "name", input_file)
    with open(path, "r", encoding="utf-8") as f:
        logging.info(f"read file {path}")
        input_text = f.read()
    # The whole document can be large; keep it out of INFO-level logs.
    logging.debug("file content: %s", input_text)
    chatWorld.setStory(stories=input_text, metas=None)

    # Cache the parsed role data for get_role_list() / change_role_list().
    role_name_list, role_name_dict = chatWorld.getRoleNameFromFile(input_text)
    role_name_list_global = role_name_list
    role_name_dict_global = role_name_dict

    # A document with no detectable roles would make [0] / [-1] raise.
    first_role = role_name_list[0] if role_name_list else None
    last_role = role_name_list[-1] if role_name_list else None
    return (
        gr.Radio(choices=role_name_list, interactive=True, value=first_role),
        gr.Radio(choices=role_name_list, interactive=True, value=last_role),
    )
def _respond(
    message,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
    use_local_model,
):
    """Send *message* to ``chatWorld`` and return its reply.

    This is the shared implementation of both chat callbacks, which differ
    only in ``use_local_model``. When ``withCharacter`` is truthy, the reply
    comes from ``chatWorld.chatWithCharacter`` with the selected role names
    and nicknames. Otherwise it comes from ``chatWorld.chatWithoutCharacter``.
    """
    if withCharacter:
        return chatWorld.chatWithCharacter(
            text=message,
            role_name=role_name,
            role_nickname=role_nickname,
            model_role_name=model_role_name,
            model_role_nickname=model_role_nickname,
            use_local_model=use_local_model,
        )
    return chatWorld.chatWithoutCharacter(
        text=message,
        use_local_model=use_local_model,
    )


def submit_message(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """``gr.ChatInterface`` callback that answers with the local model.

    ``history`` is part of the ChatInterface calling convention
    ``(message, history, *additional_inputs)`` but is not used here.
    """
    return _respond(
        message,
        model_role_name,
        role_name,
        model_role_nickname,
        role_nickname,
        withCharacter,
        use_local_model=True,
    )


def submit_message_api(
    message,
    history,
    model_role_name,
    role_name,
    model_role_nickname,
    role_nickname,
    withCharacter,
):
    """``gr.ChatInterface`` callback that answers with the API model.

    ``history`` is part of the ChatInterface calling convention
    ``(message, history, *additional_inputs)`` but is not used here.
    """
    return _respond(
        message,
        model_role_name,
        role_name,
        model_role_nickname,
        role_nickname,
        withCharacter,
        use_local_model=False,
    )
def get_role_list():
    """Return the cached role names, or a new empty list if there are none."""
    # Reading a module global needs no `global` statement.
    # A falsy value (None or an empty list) yields a fresh [].
    return role_name_list_global or []
def change_role_list(name):
    """Return the value stored for role *name*, used to fill a nickname box.

    Looks *name* up in ``role_name_dict_global``. It returns ``""`` in two
    cases instead of raising inside the UI callback:

    - no document has been uploaded yet (the dict is still ``None``, which
      previously raised ``TypeError``);
    - *name* is not a known role, e.g. ``None`` after a radio is cleared
      (previously ``KeyError``).
    """
    # presumably a plain dict of role name -> nickname; verify against
    # chatWorld.getRoleNameFromFile
    roles = role_name_dict_global or {}
    return roles.get(name, "")
# Gradio UI. It contains:
# - a document upload,
# - two role pickers, each with a nickname box,
# - a role-play toggle,
# - side-by-side chats backed by the local model and the API model.
with gr.Blocks() as demo:
    upload_c = gr.File(label="上传文档文件")

    # Role the model plays, plus its nickname.
    with gr.Row():
        model_role_name = gr.Radio(get_role_list(), label="模型角色名")
        model_role_nickname = gr.Textbox(label="模型角色昵称")

    # The other role (presumably the user's), plus its nickname.
    with gr.Row():
        role_name = gr.Radio(get_role_list(), label="角色名")
        role_nickname = gr.Textbox(label="角色昵称")

    # Selecting a role fills the matching nickname box.
    model_role_name.change(
        fn=change_role_list, inputs=[model_role_name], outputs=[model_role_nickname]
    )
    role_name.change(fn=change_role_list, inputs=[role_name], outputs=[role_nickname])

    # Uploading a document repopulates both role pickers.
    upload_c.upload(
        fn=getContent, inputs=upload_c, outputs=[model_role_name, role_name]
    )

    withCharacter = gr.Radio([True, False], value=True, label="是否进行角色扮演")

    # Both chats take the same extra inputs; only the backing model differs.
    with gr.Row():
        chatBox_local = gr.ChatInterface(
            submit_message,
            chatbot=gr.Chatbot(height=400, label="本地模型", render=False),
            additional_inputs=[
                model_role_name,
                role_name,
                model_role_nickname,
                role_nickname,
                withCharacter,
            ],
        )
        chatBox_api = gr.ChatInterface(
            submit_message_api,
            chatbot=gr.Chatbot(height=400, label="API模型", render=False),
            additional_inputs=[
                model_role_name,
                role_name,
                model_role_nickname,
                role_nickname,
                withCharacter,
            ],
        )

# Launch only when executed as a script, so that importing this module
# (tests, tooling) does not start a server as a side effect.
# NOTE(review): share=True creates a public link and 0.0.0.0 binds every
# network interface — confirm this exposure is intended.
if __name__ == "__main__":
    demo.launch(share=True, server_name="0.0.0.0")
|