# OpenLLM / app.py
# Origin: Hugging Face Space (author: xusenlin), revision d877fef.
import logging
import os
import re
import shutil
import gradio as gr
import openai
import pandas as pd
from backoff import on_exception, expo
from sqlalchemy import create_engine
from tools.doc_qa import DocQAPromptAdapter
from tools.web.overwrites import postprocess, reload_javascript
from tools.web.presets import (
small_and_beautiful_theme,
title,
description,
description_top,
CONCURRENT_COUNT
)
from tools.web.utils import (
convert_to_markdown,
shared_state,
reset_textbox,
cancel_outputing,
transfer_input,
reset_state,
delete_last_conversation
)
# Verbose logging with file/line context; useful while debugging streaming
# callbacks, at the cost of very chatty output.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)

# Placeholder key: the self-hosted OpenAI-compatible backends configured in
# the UI do not validate it, but the client library requires a non-empty value.
openai.api_key = "xxx"

# Single shared adapter used for knowledge-base (document QA) mode.
doc_adapter = DocQAPromptAdapter()
def add_llm(model_name, api_base, models):
    """Register a model endpoint and refresh the model dropdown.

    Args:
        model_name: display name typed by the user.
        api_base: base URL of the OpenAI-compatible endpoint.
        models: current name -> api_base mapping (may be None/empty).

    Returns:
        Two empty strings (clearing both textboxes), the updated mapping,
        and a dropdown update pre-selecting the first registered model.
    """
    if not models:
        models = {}
    if model_name and api_base:
        models[model_name] = api_base
    names = list(models)
    selected = names[0] if names else None
    return "", "", models, gr.Dropdown.update(choices=names, value=selected)
def set_openai_env(api_base):
    """Point both the OpenAI client and the doc-QA embeddings at *api_base*.

    Mutates module-level state; called before every request so that the
    backend selected in the UI is the one actually queried.
    """
    openai.api_base = api_base
    doc_adapter.embeddings.openai_api_base = api_base
def get_file_list():
    """Return the names of all uploaded documents, or [] if none exist yet."""
    store = "doc_store"
    return os.listdir(store) if os.path.exists(store) else []
# Initial contents of the document dropdown, computed once at startup.
file_list = get_file_list()
def upload_file(file):
    """Persist an uploaded file into ``doc_store`` and refresh the dropdown.

    Args:
        file: gradio file wrapper exposing a ``.name`` temp path, or None.

    Returns:
        A dropdown update listing the store's files with the new upload
        first and selected (None, i.e. no change, when *file* is None).
    """
    # Bug fix: the directory checked and the directory created must match —
    # previously the code tested for "doc_store" but created "docs", so the
    # subsequent move failed on a fresh install.
    os.makedirs("doc_store", exist_ok=True)
    if file is not None:
        filename = os.path.basename(file.name)
        # Bug fix: interpolate the actual filename into the destination path
        # (the destination was a literal placeholder, so every upload
        # clobbered the same file).
        shutil.move(file.name, f"doc_store/{filename}")
        file_list = get_file_list()
        # Show the newest upload first and pre-select it.
        file_list.remove(filename)
        file_list.insert(0, filename)
        return gr.Dropdown.update(choices=file_list, value=filename)
def add_vector_store(filename, model_name, models, chunk_size, chunk_overlap):
    """Build (or reload) the vector store backing knowledge-base QA.

    Args:
        filename: document name inside ``doc_store`` (or None).
        model_name: key into *models*, used to route embedding calls.
        models: name -> api_base mapping from the UI.
        chunk_size / chunk_overlap: text-splitting parameters forwarded to
            the doc adapter.

    Returns:
        A human-readable status message for the status bar.
    """
    # Route embedding requests to the endpoint of the selected model.
    api_base = models[model_name]
    set_openai_env(api_base)
    doc_adapter.chunk_size = chunk_size
    doc_adapter.chunk_overlap = chunk_overlap
    if filename is not None:
        # One store per file, keyed by stem + extension (e.g. "report-pdf").
        vs_path = f"vector_store/{filename.split('.')[0]}-{filename.split('.')[-1]}"
        if not os.path.exists(vs_path):
            # Bug fix: interpolate the actual filename into the source path
            # and the status messages (both were literal placeholders).
            doc_adapter.create_vector_store(f"doc_store/{filename}", vs_path=vs_path)
            msg = f"Successfully added vector store for {filename}!"
        else:
            # Store already built once — just reload it.
            doc_adapter.reset_vector_store(vs_path=vs_path)
            msg = f"Successfully loaded vector store for {filename}!"
    else:
        msg = "Please select a file!"
    return msg
def add_db(db_user, db_password, db_host, db_port, db_name, databases):
    """Store a MySQL connection profile and refresh the database dropdown.

    All five connection fields must be non-empty for the profile to be
    recorded. Clears the five textboxes, returns the updated profile map and
    a dropdown update pre-selecting the first known database.
    """
    if not databases:
        databases = {}
    if all([db_user, db_password, db_host, db_port, db_name]):
        databases[db_name] = {
            "user": db_user,
            "password": db_password,
            "host": db_host,
            "port": int(db_port),
        }
    names = list(databases)
    first = names[0] if names else None
    return "", "", "", "", "", databases, gr.Dropdown.update(choices=names, value=first)
def get_table_names(select_database, databases):
    """List the tables of the selected database for the table multi-select.

    Args:
        select_database: database name chosen in the dropdown (may be falsy).
        databases: name -> connection-profile mapping from the UI.

    Returns:
        A dropdown update with all table names, pre-selecting the first one;
        implicitly None (no update) when nothing is selected.
    """
    if select_database:
        db_config = databases[select_database]
        con = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{select_database}")
        tables = pd.read_sql("show tables;", con=con).values
        tables = [t[0] for t in tables]
        # Robustness fix: a database with zero tables previously raised
        # IndexError on tables[0].
        return gr.Dropdown.update(choices=tables, value=[tables[0]] if tables else None)
def get_sql_result(x, con):
    """Extract the first ```sql fenced statement from *x*, execute it on
    *con*, and render the first 10 result rows as a centred markdown table.

    Raises IndexError when *x* contains no fenced SQL block (callers wrap
    this function in try/except).
    """
    fenced_sql = r"sql\n(.+?);\n"
    statement = re.findall(fenced_sql, x, re.DOTALL)[0] + ";"
    frame = pd.read_sql(statement, con=con)
    head = frame.iloc[:10, :]
    return head.to_markdown(numalign="center", stralign="center")
@on_exception(expo, openai.error.RateLimitError, max_tries=5)
def chat_completions_create(params):
    """Call the OpenAI-compatible chat-completions endpoint.

    Retries with exponential backoff (up to 5 attempts) when the backend
    answers with a rate-limit error; all other errors propagate.
    """
    return openai.ChatCompletion.create(**params)
def predict(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    """Stream a chat completion, yielding UI updates as tokens arrive.

    Generator yielding ``(chatbot_display, history, status)`` tuples: the
    display holds markdown-rendered turns, ``history`` the raw text pairs.

    Modes:
      * ``is_dbqa``  — prepend a text-to-SQL system prompt built from the
        selected tables' DDL, then run the generated SQL and append its result.
      * ``is_kgqa``  — wrap the question with retrieved document context.
      * otherwise    — plain chat, replaying the last ``memory_k`` turns
        unless ``single_turn`` is set.
    """
    # Route this request to the endpoint of the selected model.
    api_base = models[model_name]
    set_openai_env(api_base)
    if text == "":
        yield chatbot, history, "Empty context."
        return
    if history is None:
        history = []
    messages = []
    if is_dbqa:
        # Deterministic decoding for SQL generation.
        temperature = 0.0
        db_config = databases[select_database]
        con = create_engine(f"mysql+pymysql://{db_config['user']}:{db_config['password']}@{db_config['host']}:{db_config['port']}/{select_database}")
        # Collect CREATE TABLE DDL for every selected table as schema context;
        # strip "DEFAULT NULL" noise to keep the prompt compact.
        table_schema = ""
        for t in select_table:
            table_schema += pd.read_sql(f"show create table {t};", con=con)["Create Table"][0] + "\n\n"
        table_schema = table_schema.replace("DEFAULT NULL", "")
        # System prompt (Chinese): "You are a SQL assistant; given these
        # CREATE TABLE statements, answer questions with accurate SQL."
        messages.append(
            {
                "role": "system",
                "content": f"你现在是一名SQL助手,能够根据用户的问题生成准确的SQL查询。已知SQL的建表语句为:{table_schema}根据上述数据库信息,回答相关问题。"
            },
        )
    else:
        if not single_turn:
            # Replay the last `memory_k` exchanges as conversational context.
            for h in history[-memory_k:]:
                messages.extend(
                    [
                        {
                            "role": "user",
                            "content": h[0]
                        },
                        {
                            "role": "assistant",
                            "content": h[1]
                        }
                    ]
                )
    messages.append(
        {
            "role": "user",
            # Knowledge-base mode wraps the question in a retrieval prompt.
            "content": doc_adapter(text) if is_kgqa else text
        }
    )
    params = dict(
        stream=True,
        messages=messages,
        model=model_name,
        top_p=top_p,
        temperature=temperature,
        max_tokens=max_tokens
    )
    res = chat_completions_create(params)
    x = ""  # accumulated assistant text
    for openai_object in res:
        delta = openai_object.choices[0]["delta"]
        if "content" in delta:
            x += delta["content"]
            # a = markdown-rendered view for the chatbot, b = raw history.
            a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
                [text, convert_to_markdown(x)]
            ], history + [[text, x]]
            yield a, b, "Generating..."
        # Cooperative cancellation: the stop button flips this shared flag.
        if shared_state.interrupted:
            shared_state.recover()
            try:
                yield a, b, "Stop: Success"
                return
            except:
                pass
    # NOTE(review): `a`/`b` are only bound once a content delta arrives; if
    # the stream yields none, the bare excepts below silently absorb the
    # resulting NameError.
    if is_dbqa:
        # Best-effort: run the generated SQL and append the result table;
        # any extraction/execution failure leaves the answer as-is.
        try:
            res = get_sql_result(x, con)
            a[-1][-1] += "\n\n" + convert_to_markdown(res)
            b[-1][-1] += "\n\n" + convert_to_markdown(res)
        except:
            pass
    try:
        yield a, b, "Generate: Success"
    except:
        pass
def retry(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    """Drop the most recent exchange and regenerate its answer.

    Pops the last (question, answer) pair off both the chatbot display and
    the raw history, then re-runs ``predict`` with the popped question,
    forwarding every yielded update unchanged. *text* is ignored; the
    regenerated question comes from history.
    """
    logging.info("Retry...")
    if not history:
        yield chatbot, history, "Empty context."
        return
    chatbot.pop()
    last_question = history.pop()[0]
    yield from predict(
        model_name,
        models,
        last_question,
        chatbot,
        history,
        top_p,
        temperature,
        max_tokens,
        memory_k,
        is_kgqa,
        single_turn,
        is_dbqa,
        select_database,
        select_table,
        databases,
    )
# Patch gradio's Chatbot post-processing so chat messages render through the
# project's custom converter (code highlighting, etc.).
gr.Chatbot.postprocess = postprocess

# Load the custom stylesheet shipped with the app.
with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
# ---------------------------------------------------------------------------
# UI layout. Left column: chat window and controls; right column: tabs for
# model endpoints, knowledge base, database connections and parameters.
# ---------------------------------------------------------------------------
with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    # Per-session state: raw chat history and the question being processed.
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        # --- left: chat area ------------------------------------------------
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("发送")  # "Send"
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("停止")  # "Stop"
            with gr.Row():
                emptyBtn = gr.Button(
                    "🧹 新的对话",  # "New conversation"
                )
                retryBtn = gr.Button("🔄 重新生成")  # "Regenerate"
                delLastBtn = gr.Button("🗑️ 删除最旧对话")  # "Delete oldest turn"
        # --- right: configuration tabs --------------------------------------
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="模型"):  # "Model" endpoints
                    model_name = gr.Textbox(
                        placeholder="chatglm",
                        label="模型名称",  # model name
                    )
                    api_base = gr.Textbox(
                        placeholder="https://0.0.0.0:80/v1",
                        label="模型接口地址",  # model API base URL
                    )
                    add_model = gr.Button("添加模型")  # "Add model"
                    with gr.Accordion(open=False, label="所有模型配置"):
                        # JSON view of every registered model -> api_base pair.
                        models = gr.Json()
                    single_turn = gr.Checkbox(label="使用单轮对话", value=False)
                    select_model = gr.Dropdown(
                        choices=[m[0] for m in models.value.items()] if models.value else [],
                        value=[m[0] for m in models.value.items()][0] if models.value else None,
                        label="选择模型",  # select model
                        interactive=True,
                    )
                with gr.Tab(label="知识库"):  # "Knowledge base" (document QA)
                    is_kgqa = gr.Checkbox(
                        label="使用知识库问答",
                        value=False,
                        interactive=True,
                    )
                    gr.Markdown("""**基于本地知识库生成更加准确的回答!**""")
                    select_file = gr.Dropdown(
                        choices=file_list,
                        label="选择文件",  # select file
                        interactive=True,
                        value=file_list[0] if len(file_list) > 0 else None
                    )
                    file = gr.File(
                        label="上传文件",  # upload file
                        visible=True,
                        file_types=['.txt', '.md', '.docx', '.pdf']
                    )
                    add_vs = gr.Button(value="添加到知识库")  # "Add to knowledge base"
                with gr.Tab(label="数据库"):  # "Database" (text-to-SQL QA)
                    with gr.Accordion(open=False, label="数据库配置"):
                        db_user = gr.Textbox(
                            placeholder="root",
                            label="用户名",  # user
                        )
                        db_password = gr.Textbox(
                            placeholder="password",
                            label="密码",  # password
                            type="password"
                        )
                        db_host = gr.Textbox(
                            placeholder="0.0.0.0",
                            label="主机",  # host
                        )
                        db_port = gr.Textbox(
                            placeholder="3306",
                            label="端口",  # port
                        )
                        db_name = gr.Textbox(
                            placeholder="test",
                            label="数据库名称",  # database name
                        )
                        add_database = gr.Button("添加数据库")  # "Add database"
                    with gr.Accordion(open=False, label="所有数据库配置"):
                        databases = gr.Json()
                    select_database = gr.Dropdown(
                        choices=[d[0] for d in databases.value.items()] if databases.value else [],
                        value=[d[0] for d in databases.value.items()][0] if databases.value else None,
                        interactive=True,
                        label="选择数据库"  # select database
                    )
                    select_table = gr.Dropdown(label="选择表", interactive=True, multiselect=True)
                    is_dbqa = gr.Checkbox(
                        label="使用数据库问答",
                        value=False,
                        interactive=True,
                    )
                with gr.Tab(label="参数"):  # "Parameters"
                    top_p = gr.Slider(
                        minimum=-0,  # NOTE(review): -0 == 0; probably meant 0.0
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    memory_k = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=5,
                        step=1,
                        interactive=True,
                        label="Max Memory Window Size",
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=200,
                        step=100,
                        interactive=True,
                        label="Chunk Size",
                    )
                    chunk_overlap = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=0,
                        step=10,
                        interactive=True,
                        label="Chunk Overlap",
                    )
    gr.Markdown(description)
    # -----------------------------------------------------------------------
    # Event wiring.
    # -----------------------------------------------------------------------
    add_model.click(
        add_llm,
        inputs=[model_name, api_base, models],
        outputs=[model_name, api_base, models, select_model],
    )
    add_database.click(
        add_db,
        inputs=[db_user, db_password, db_host, db_port, db_name, databases],
        outputs=[db_user, db_password, db_host, db_port, db_name, databases, select_database],
    )
    # Changing the selected database refreshes the table multi-select.
    select_database.change(
        get_table_names,
        inputs=[select_database, databases],
        outputs=select_table,
    )
    file.upload(
        upload_file,
        inputs=file,
        outputs=select_file,
    )
    add_vs.click(
        add_vector_store,
        inputs=[select_file, select_model, models, chunk_size, chunk_overlap],
        outputs=status_display,
    )
    # Shared argument bundles for the two generation entry points.
    predict_args = dict(
        fn=predict,
        inputs=[
            select_model,
            models,
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_tokens,
            memory_k,
            is_kgqa,
            single_turn,
            is_dbqa,
            select_database,
            select_table,
            databases,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    retry_args = dict(
        fn=retry,
        inputs=[
            select_model,
            models,
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_tokens,
            memory_k,
            is_kgqa,
            single_turn,
            is_dbqa,
            select_database,
            select_table,
            databases,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])
    # The stop button only flips the shared interrupt flag; predict() notices.
    cancelBtn.click(cancel_outputing, [], [status_display])
    # Submitting first moves the text into `user_question` state (and clears
    # the textbox), then kicks off generation.
    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submitBtn, cancelBtn],
        show_progress=True,
    )
    user_input.submit(**transfer_input_args).then(**predict_args)
    submitBtn.click(**transfer_input_args).then(**predict_args)
    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)
    retryBtn.click(**retry_args)
    delLastBtn.click(
        delete_last_conversation,
        [chatbot, history],
        [chatbot, history, status_display],
        show_progress=True,
    )

demo.title = "OpenLLM Chatbot 🚀 "

if __name__ == "__main__":
    reload_javascript()
    # The queue enables streaming generator outputs; CONCURRENT_COUNT bounds
    # how many requests run in parallel.
    demo.queue(concurrency_count=CONCURRENT_COUNT).launch()