import logging
import os
import re
import shutil

import gradio as gr
import openai
import pandas as pd
from backoff import on_exception, expo
from sqlalchemy import create_engine

from tools.doc_qa import DocQAPromptAdapter
from tools.web.overwrites import postprocess, reload_javascript
from tools.web.presets import (
    small_and_beautiful_theme,
    title,
    description,
    description_top,
    CONCURRENT_COUNT
)
from tools.web.utils import (
    convert_to_markdown,
    shared_state,
    reset_textbox,
    cancel_outputing,
    transfer_input,
    reset_state,
    delete_last_conversation
)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)


openai.api_key = "xxx"
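
# DocQAPromptAdapter (tools.doc_qa) is assumed to handle document loading, chunking,
# embedding and retrieval for the knowledge-base QA mode; its chunk_size/chunk_overlap
# attributes and its create/reset_vector_store methods are configured further below.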
doc_adapter = DocQAPromptAdapter()


def add_llm(model_name, api_base, models):
    """ Register a model name together with its API base URL. """
    models = models or {}
    if model_name and api_base:
        models.update({model_name: api_base})
    choices = list(models.keys())
    return "", "", models, gr.Dropdown.update(choices=choices, value=choices[0] if choices else None)


def set_openai_env(api_base):
    """ Point the OpenAI client and the embedding backend at the given API base. """
    openai.api_base = api_base
    doc_adapter.embeddings.openai_api_base = api_base


def get_file_list():
    """ List the files currently in the local document store. """
    if not os.path.exists("doc_store"):
        return []
    return os.listdir("doc_store")


file_list = get_file_list()


def upload_file(file):
    """ Move an uploaded file into the document store and refresh the file dropdown. """
    if not os.path.exists("doc_store"):
        os.mkdir("doc_store")

    if file is not None:
        filename = os.path.basename(file.name)
        shutil.move(file.name, f"doc_store/{filename}")
        # Refresh the list and move the new file to the top so it is selected by default.
        file_list = get_file_list()
        file_list.remove(filename)
        file_list.insert(0, filename)
        return gr.Dropdown.update(choices=file_list, value=filename)
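

# The vector store for a document is kept on disk under vector_store/<name>-<ext>.
# If it already exists, it is reloaded instead of being rebuilt from the source file.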
def add_vector_store(filename, model_name, models, chunk_size, chunk_overlap):
    """ Build (or reload) the vector store for the selected document. """
    api_base = models[model_name]
    set_openai_env(api_base)
    doc_adapter.chunk_size = chunk_size
    doc_adapter.chunk_overlap = chunk_overlap

    if filename is not None:
        vs_path = f"vector_store/{filename.split('.')[0]}-{filename.split('.')[-1]}"
        if not os.path.exists(vs_path):
            doc_adapter.create_vector_store(f"doc_store/{filename}", vs_path=vs_path)
            msg = f"Successfully added vector store for {filename}!"
        else:
            doc_adapter.reset_vector_store(vs_path=vs_path)
            msg = f"Successfully loaded vector store for {filename}!"
    else:
        msg = "Please select a file!"
    return msg
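

# Model and database configurations are held in gr.Json components and passed into every
# callback, so each handler below receives the current dict and returns the updated one
# together with refreshed dropdown choices.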
def add_db(db_user, db_password, db_host, db_port, db_name, databases):
    """ Register a database connection configuration. """
    databases = databases or {}
    if db_user and db_password and db_host and db_port and db_name:
        databases.update(
            {
                db_name: {
                    "user": db_user,
                    "password": db_password,
                    "host": db_host,
                    "port": int(db_port)
                }
            }
        )
    choices = list(databases.keys())
    return "", "", "", "", "", databases, gr.Dropdown.update(choices=choices, value=choices[0] if choices else None)


def get_table_names(select_database, databases):
    """ Fetch the table names of the selected database. """
    if select_database:
        db_config = databases[select_database]
        con = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{select_database}")
        tables = pd.read_sql("show tables;", con=con).values
        tables = [t[0] for t in tables]
        return gr.Dropdown.update(choices=tables, value=[tables[0]])
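

# Note: the pattern below assumes the model wraps its query in a ```sql ... ``` fenced
# block ending with ";"; only the first match is executed and output is capped at 10 rows.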
def get_sql_result(x, con):
    """ Extract the generated SQL from the model output, run it, and return the result as markdown. """
    q = r"sql\n(.+?);\n"
    sql = re.findall(q, x, re.DOTALL)[0] + ";"
    df = pd.read_sql(sql, con=con).iloc[:10, :]
    return df.to_markdown(numalign="center", stralign="center")


@on_exception(expo, openai.error.RateLimitError, max_tries=5)
def chat_completions_create(params):
    """ Call the chat completion endpoint, retrying with exponential backoff on rate limits. """
    return openai.ChatCompletion.create(**params)
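

# predict() streams the assistant reply token by token. With "使用知识库问答" enabled, the
# question is first expanded by doc_adapter into a retrieval-augmented prompt; with
# "使用数据库问答" enabled, a schema-aware system prompt is built, temperature is forced to 0,
# and the generated SQL is executed afterwards with its result appended to the answer.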
def predict(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    api_base = models[model_name]
    set_openai_env(api_base)

    if text == "":
        yield chatbot, history, "Empty context."
        return

    if history is None:
        history = []

    messages = []
    if is_dbqa:
        # Database QA: generate SQL deterministically against the selected tables' schemas.
        temperature = 0.0
        db_config = databases[select_database]
        con = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{select_database}")
        table_schema = ""
        for t in select_table:
            table_schema += pd.read_sql(f"show create table {t};", con=con)["Create Table"][0] + "\n\n"
        table_schema = table_schema.replace("DEFAULT NULL", "")
        messages.append(
            {
                "role": "system",
                "content": f"你现在是一名SQL助手,能够根据用户的问题生成准确的SQL查询。已知SQL的建表语句为:{table_schema}根据上述数据库信息,回答相关问题。"
            },
        )
    else:
        if not single_turn:
            # Replay up to memory_k previous exchanges as conversation context.
            for h in history[-memory_k:]:
                messages.extend(
                    [
                        {
                            "role": "user",
                            "content": h[0]
                        },
                        {
                            "role": "assistant",
                            "content": h[1]
                        }
                    ]
                )

    messages.append(
        {
            "role": "user",
            "content": doc_adapter(text) if is_kgqa else text
        }
    )

    params = dict(
        stream=True,
        messages=messages,
        model=model_name,
        top_p=top_p,
        temperature=temperature,
        max_tokens=max_tokens
    )

    res = chat_completions_create(params)
    x = ""
    for openai_object in res:
        delta = openai_object.choices[0]["delta"]
        if "content" in delta:
            x += delta["content"]

        # Update the chat display incrementally as tokens arrive.
        a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
            [text, convert_to_markdown(x)]
        ], history + [[text, x]]

        yield a, b, "Generating..."

        if shared_state.interrupted:
            shared_state.recover()
            try:
                yield a, b, "Stop: Success"
                return
            except:
                pass

    if is_dbqa:
        # Run the generated SQL and append the query result to the displayed answer.
        try:
            res = get_sql_result(x, con)
            a[-1][-1] += "\n\n" + convert_to_markdown(res)
            b[-1][-1] += "\n\n" + convert_to_markdown(res)
        except:
            pass

    try:
        yield a, b, "Generate: Success"
    except:
        pass


def retry(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    logging.info("Retry...")
    if len(history) == 0:
        yield chatbot, history, "Empty context."
        return
    # Drop the last exchange and re-run predict with the same user input.
    chatbot.pop()
    inputs = history.pop()[0]
    for x in predict(
        model_name,
        models,
        inputs,
        chatbot,
        history,
        top_p,
        temperature,
        max_tokens,
        memory_k,
        is_kgqa,
        single_turn,
        is_dbqa,
        select_database,
        select_table,
        databases,
    ):
        yield x
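

# Replace Gradio's default Chatbot postprocess with the project's custom renderer
# (tools.web.overwrites); presumably this controls how model output is converted to HTML.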
gr.Chatbot.postprocess = postprocess

with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
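

# UI layout: the left column holds the chat window and input controls; the right column
# holds tabs for model registration, knowledge-base files, database connections and
# generation parameters.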
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")

    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")

    gr.Markdown(description_top)

    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("发送")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("停止")
            with gr.Row():
                emptyBtn = gr.Button(
                    "🧹 新的对话",
                )
                retryBtn = gr.Button("🔄 重新生成")
                delLastBtn = gr.Button("🗑️ 删除最旧对话")

        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="模型"):
                    model_name = gr.Textbox(
                        placeholder="chatglm",
                        label="模型名称",
                    )
                    api_base = gr.Textbox(
                        placeholder="https://0.0.0.0:80/v1",
                        label="模型接口地址",
                    )
                    add_model = gr.Button("添加模型")
                    with gr.Accordion(open=False, label="所有模型配置"):
                        models = gr.Json()
                    single_turn = gr.Checkbox(label="使用单轮对话", value=False)
                    select_model = gr.Dropdown(
                        choices=[m[0] for m in models.value.items()] if models.value else [],
                        value=[m[0] for m in models.value.items()][0] if models.value else None,
                        label="选择模型",
                        interactive=True,
                    )

                with gr.Tab(label="知识库"):
                    is_kgqa = gr.Checkbox(
                        label="使用知识库问答",
                        value=False,
                        interactive=True,
                    )
                    gr.Markdown("""**基于本地知识库生成更加准确的回答!**""")
                    select_file = gr.Dropdown(
                        choices=file_list,
                        label="选择文件",
                        interactive=True,
                        value=file_list[0] if len(file_list) > 0 else None
                    )
                    file = gr.File(
                        label="上传文件",
                        visible=True,
                        file_types=['.txt', '.md', '.docx', '.pdf']
                    )
                    add_vs = gr.Button(value="添加到知识库")

                with gr.Tab(label="数据库"):
                    with gr.Accordion(open=False, label="数据库配置"):
                        db_user = gr.Textbox(
                            placeholder="root",
                            label="用户名",
                        )
                        db_password = gr.Textbox(
                            placeholder="password",
                            label="密码",
                            type="password"
                        )
                        db_host = gr.Textbox(
                            placeholder="0.0.0.0",
                            label="主机",
                        )
                        db_port = gr.Textbox(
                            placeholder="3306",
                            label="端口",
                        )
                        db_name = gr.Textbox(
                            placeholder="test",
                            label="数据库名称",
                        )
                        add_database = gr.Button("添加数据库")

                    with gr.Accordion(open=False, label="所有数据库配置"):
                        databases = gr.Json()
                    select_database = gr.Dropdown(
                        choices=[d[0] for d in databases.value.items()] if databases.value else [],
                        value=[d[0] for d in databases.value.items()][0] if databases.value else None,
                        interactive=True,
                        label="选择数据库"
                    )
                    select_table = gr.Dropdown(label="选择表", interactive=True, multiselect=True)
                    is_dbqa = gr.Checkbox(
                        label="使用数据库问答",
                        value=False,
                        interactive=True,
                    )
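
                # Generation and retrieval parameters: top_p/temperature/max_tokens go to the
                # chat API, memory_k bounds the replayed history, and chunk_size/chunk_overlap
                # configure the document splitter used when building a vector store.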
                with gr.Tab(label="参数"):
                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    memory_k = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=5,
                        step=1,
                        interactive=True,
                        label="Max Memory Window Size",
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=200,
                        step=100,
                        interactive=True,
                        label="Chunk Size",
                    )
                    chunk_overlap = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=0,
                        step=10,
                        interactive=True,
                        label="Chunk Overlap",
                    )

    gr.Markdown(description)
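
    # Event wiring: transfer_input copies the textbox content into the user_question state
    # and updates the input box and buttons, after which predict streams its answer into
    # the chatbot; the remaining handlers keep the model/database configs and dropdowns in sync.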
    add_model.click(
        add_llm,
        inputs=[model_name, api_base, models],
        outputs=[model_name, api_base, models, select_model],
    )

    add_database.click(
        add_db,
        inputs=[db_user, db_password, db_host, db_port, db_name, databases],
        outputs=[db_user, db_password, db_host, db_port, db_name, databases, select_database],
    )

    select_database.change(
        get_table_names,
        inputs=[select_database, databases],
        outputs=select_table,
    )

    file.upload(
        upload_file,
        inputs=file,
        outputs=select_file,
    )

    add_vs.click(
        add_vector_store,
        inputs=[select_file, select_model, models, chunk_size, chunk_overlap],
        outputs=status_display,
    )

    predict_args = dict(
        fn=predict,
        inputs=[
            select_model,
            models,
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_tokens,
            memory_k,
            is_kgqa,
            single_turn,
            is_dbqa,
            select_database,
            select_table,
            databases,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    retry_args = dict(
        fn=retry,
        inputs=[
            select_model,
            models,
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_tokens,
            memory_k,
            is_kgqa,
            single_turn,
            is_dbqa,
            select_database,
            select_table,
            databases,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )

    reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])

    cancelBtn.click(cancel_outputing, [], [status_display])
    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submitBtn, cancelBtn],
        show_progress=True,
    )

    user_input.submit(**transfer_input_args).then(**predict_args)

    submitBtn.click(**transfer_input_args).then(**predict_args)

    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)

    retryBtn.click(**retry_args)

    delLastBtn.click(
        delete_last_conversation,
        [chatbot, history],
        [chatbot, history, status_display],
        show_progress=True,
    )

demo.title = "OpenLLM Chatbot 🚀 "
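
# Entry point: reload the custom JavaScript overrides, then serve the app with a request
# queue allowing CONCURRENT_COUNT concurrent sessions.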
if __name__ == "__main__":
    reload_javascript()
    demo.queue(concurrency_count=CONCURRENT_COUNT).launch()