import logging
import os
import re
import shutil

import gradio as gr
import openai
import pandas as pd
from backoff import on_exception, expo
from sqlalchemy import create_engine

from tools.doc_qa import DocQAPromptAdapter
from tools.web.overwrites import postprocess, reload_javascript
from tools.web.presets import (
    small_and_beautiful_theme,
    title,
    description,
    description_top,
    CONCURRENT_COUNT
)
from tools.web.utils import (
    convert_to_markdown,
    shared_state,
    reset_textbox,
    cancel_outputing,
    transfer_input,
    reset_state,
    delete_last_conversation
)

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)

openai.api_key = "xxx"  # placeholder; set a real key or read it from the environment
doc_adapter = DocQAPromptAdapter()


def add_llm(model_name, api_base, models):
    """Register a model name together with its API base URL."""
    models = models or {}
    if model_name and api_base:
        models.update({model_name: api_base})
    choices = list(models.keys())
    return "", "", models, gr.Dropdown.update(choices=choices, value=choices[0] if choices else None)


def set_openai_env(api_base):
    """Point the OpenAI client (and the embedding model) at the given endpoint."""
    openai.api_base = api_base
    doc_adapter.embeddings.openai_api_base = api_base


def get_file_list():
    """List the files currently in the document store."""
    if not os.path.exists("doc_store"):
        return []
    return os.listdir("doc_store")


file_list = get_file_list()


def upload_file(file):
    """Move an uploaded file into the document store."""
    if not os.path.exists("doc_store"):
        os.mkdir("doc_store")
    if file is not None:
        filename = os.path.basename(file.name)
        shutil.move(file.name, f"doc_store/{filename}")
        # Refresh the list and move the new file to the top.
        file_list = get_file_list()
        file_list.remove(filename)
        file_list.insert(0, filename)
        return gr.Dropdown.update(choices=file_list, value=filename)


def add_vector_store(filename, model_name, models, chunk_size, chunk_overlap):
    """Build (or reload) the vector store for the selected file."""
    api_base = models[model_name]
    set_openai_env(api_base)
    doc_adapter.chunk_size = chunk_size
    doc_adapter.chunk_overlap = chunk_overlap
    if filename is not None:
        vs_path = f"vector_store/{filename.split('.')[0]}-{filename.split('.')[-1]}"
        if not os.path.exists(vs_path):
            doc_adapter.create_vector_store(f"doc_store/{filename}", vs_path=vs_path)
            msg = f"Successfully added vector store for {filename}!"
        else:
            doc_adapter.reset_vector_store(vs_path=vs_path)
            msg = f"Successfully loaded vector store for {filename}!"
    else:
        msg = "Please select a file!"
    return msg
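
# Illustrative only: wiring the registry and document-QA helpers above together
# outside the UI. The model name, endpoint, and file name are placeholders, and
# doc_adapter's call semantics are taken from its usage in predict() below.
#
#   _, _, models, _ = add_llm("chatglm", "http://0.0.0.0:80/v1", None)
#   set_openai_env(models["chatglm"])
#   add_vector_store("faq.md", "chatglm", models, chunk_size=200, chunk_overlap=0)
#   prompt = doc_adapter("What is a vector store?")  # retrieval-augmented prompt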

def add_db(db_user, db_password, db_host, db_port, db_name, databases):
    """Register a database connection configuration."""
    databases = databases or {}
    if db_user and db_password and db_host and db_port and db_name:
        databases.update(
            {
                db_name: {
                    "user": db_user,
                    "password": db_password,
                    "host": db_host,
                    "port": int(db_port)
                }
            }
        )
    choices = list(databases.keys())
    return "", "", "", "", "", databases, gr.Dropdown.update(choices=choices, value=choices[0] if choices else None)


def get_table_names(select_database, databases):
    """Fetch the table names of the selected database."""
    if select_database:
        db_config = databases[select_database]
        con = create_engine(
            f"mysql+pymysql://{db_config['user']}:{db_config['password']}"
            f"@{db_config['host']}:{db_config['port']}/{select_database}"
        )
        tables = [t[0] for t in pd.read_sql("show tables;", con=con).values]
        return gr.Dropdown.update(choices=tables, value=[tables[0]])


def get_sql_result(x, con):
    """Extract the first ```sql ...;``` block from the reply, run it, and
    render the first ten rows as a markdown table."""
    q = r"sql\n(.+?);\n"
    sql = re.findall(q, x, re.DOTALL)[0] + ";"
    df = pd.read_sql(sql, con=con).iloc[:10, :]
    return df.to_markdown(numalign="center", stralign="center")


@on_exception(expo, openai.error.RateLimitError, max_tries=5)
def chat_completions_create(params):
    """Call the chat completion endpoint, retrying with exponential backoff on rate limits."""
    return openai.ChatCompletion.create(**params)


def predict(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    api_base = models[model_name]
    set_openai_env(api_base)
    if text == "":
        yield chatbot, history, "Empty context."
        return
    if history is None:
        history = []
    messages = []
    if is_dbqa:
        temperature = 0.0  # deterministic output for SQL generation
        db_config = databases[select_database]
        con = create_engine(
            f"mysql+pymysql://{db_config['user']}:{db_config['password']}"
            f"@{db_config['host']}:{db_config['port']}/{select_database}"
        )
        table_schema = ""
        for t in select_table:
            table_schema += pd.read_sql(f"show create table {t};", con=con)["Create Table"][0] + "\n\n"
        table_schema = table_schema.replace("DEFAULT NULL", "")
        messages.append(
            {
                "role": "system",
                "content": (
                    "You are a SQL assistant that generates accurate SQL queries from "
                    f"user questions. The known CREATE TABLE statements are:\n{table_schema}"
                    "Answer questions based on the database information above."
                )
            },
        )
    else:
        if not single_turn:
            for h in history[-memory_k:]:
                messages.extend(
                    [
                        {"role": "user", "content": h[0]},
                        {"role": "assistant", "content": h[1]}
                    ]
                )
    messages.append(
        {
            "role": "user",
            "content": doc_adapter(text) if is_kgqa else text
        }
    )
    params = dict(
        stream=True,
        messages=messages,
        model=model_name,
        top_p=top_p,
        temperature=temperature,
        max_tokens=max_tokens
    )
    res = chat_completions_create(params)
    x = ""
    # Initialize the outputs so an interrupt before the first content chunk
    # does not raise a NameError.
    a, b = chatbot, history
    for openai_object in res:
        delta = openai_object.choices[0]["delta"]
        if "content" in delta:
            x += delta["content"]
            a = [[y[0], convert_to_markdown(y[1])] for y in history] + [[text, convert_to_markdown(x)]]
            b = history + [[text, x]]
            yield a, b, "Generating..."
        if shared_state.interrupted:
            shared_state.recover()
            try:
                yield a, b, "Stop: Success"
                return
            except Exception:
                pass
    if is_dbqa:
        try:
            res = get_sql_result(x, con)
            a[-1][-1] += "\n\n" + convert_to_markdown(res)
            b[-1][-1] += "\n\n" + convert_to_markdown(res)
        except Exception:
            pass
    try:
        yield a, b, "Generate: Success"
    except Exception:
        pass


def retry(
    model_name,
    models,
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_tokens,
    memory_k,
    is_kgqa,
    single_turn,
    is_dbqa,
    select_database,
    select_table,
    databases,
):
    logging.info("Retry...")
    if len(history) == 0:
        yield chatbot, history, "Empty context."
        return
    chatbot.pop()
    inputs = history.pop()[0]
    for x in predict(
        model_name,
        models,
        inputs,
        chatbot,
        history,
        top_p,
        temperature,
        max_tokens,
        memory_k,
        is_kgqa,
        single_turn,
        is_dbqa,
        select_database,
        select_table,
        databases,
    ):
        yield x
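
# Illustrative only: a minimal smoke test of the streaming pipeline above,
# assuming an OpenAI-compatible server is reachable at the placeholder endpoint.
# Argument order follows predict()'s signature; knowledge-base and database
# modes are disabled here.
#
#   models = {"chatglm": "http://0.0.0.0:80/v1"}
#   for chat, hist, status in predict(
#       "chatglm", models, "Hello", [], [], 0.95, 1.0, 512, 5,
#       False, False, False, None, [], {},
#   ):
#       print(status)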

gr.Chatbot.postprocess = postprocess

with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()

with gr.Blocks(css=customCSS, theme=small_and_beautiful_theme) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML(title)
        status_display = gr.Markdown("Success", elem_id="status_display")
    gr.Markdown(description_top)
    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height="100%")
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("Stop")
            with gr.Row():
                emptyBtn = gr.Button(
                    "🧹 New Conversation",
                )
                retryBtn = gr.Button("🔄 Regenerate")
                delLastBtn = gr.Button("🗑️ Delete Last Conversation")
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Models"):
                    model_name = gr.Textbox(
                        placeholder="chatglm",
                        label="Model Name",
                    )
                    api_base = gr.Textbox(
                        placeholder="https://0.0.0.0:80/v1",
                        label="Model API Base",
                    )
                    add_model = gr.Button("Add Model")
                    with gr.Accordion(open=False, label="All Model Configs"):
                        models = gr.Json()
                    single_turn = gr.Checkbox(label="Single-turn Dialogue", value=False)
                    select_model = gr.Dropdown(
                        choices=list(models.value.keys()) if models.value else [],
                        value=list(models.value.keys())[0] if models.value else None,
                        label="Select Model",
                        interactive=True,
                    )
                with gr.Tab(label="Knowledge Base"):
                    is_kgqa = gr.Checkbox(
                        label="Knowledge Base QA",
                        value=False,
                        interactive=True,
                    )
                    gr.Markdown("""**Generate more accurate answers based on a local knowledge base!**""")
                    select_file = gr.Dropdown(
                        choices=file_list,
                        label="Select File",
                        interactive=True,
                        value=file_list[0] if len(file_list) > 0 else None
                    )
                    file = gr.File(
                        label="Upload File",
                        visible=True,
                        file_types=['.txt', '.md', '.docx', '.pdf']
                    )
                    add_vs = gr.Button(value="Add to Knowledge Base")
                with gr.Tab(label="Databases"):
                    with gr.Accordion(open=False, label="Database Config"):
                        db_user = gr.Textbox(
                            placeholder="root",
                            label="User",
                        )
                        db_password = gr.Textbox(
                            placeholder="password",
                            label="Password",
                            type="password"
                        )
                        db_host = gr.Textbox(
                            placeholder="0.0.0.0",
                            label="Host",
                        )
                        db_port = gr.Textbox(
                            placeholder="3306",
                            label="Port",
                        )
                        db_name = gr.Textbox(
                            placeholder="test",
                            label="Database Name",
                        )
                        add_database = gr.Button("Add Database")
                    with gr.Accordion(open=False, label="All Database Configs"):
                        databases = gr.Json()
                    select_database = gr.Dropdown(
                        choices=list(databases.value.keys()) if databases.value else [],
                        value=list(databases.value.keys())[0] if databases.value else None,
                        interactive=True,
                        label="Select Database"
                    )
                    select_table = gr.Dropdown(label="Select Tables", interactive=True, multiselect=True)
                    is_dbqa = gr.Checkbox(
                        label="Database QA",
                        value=False,
                        interactive=True,
                    )
                with gr.Tab(label="Parameters"):
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    memory_k = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=5,
                        step=1,
                        interactive=True,
                        label="Max Memory Window Size",
                    )
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=200,
                        step=100,
                        interactive=True,
                        label="Chunk Size",
                    )
Size", ) chunk_overlap = gr.Slider( minimum=0, maximum=100, value=0, step=10, interactive=True, label="Chunk Overlap", ) gr.Markdown(description) add_model.click( add_llm, inputs=[model_name, api_base, models], outputs=[model_name, api_base, models, select_model], ) add_database.click( add_db, inputs=[db_user, db_password, db_host, db_port, db_name, databases], outputs=[db_user, db_password, db_host, db_port, db_name, databases, select_database], ) select_database.change( get_table_names, inputs=[select_database, databases], outputs=select_table, ) file.upload( upload_file, inputs=file, outputs=select_file, ) add_vs.click( add_vector_store, inputs=[select_file, select_model, models, chunk_size, chunk_overlap], outputs=status_display, ) predict_args = dict( fn=predict, inputs=[ select_model, models, user_question, chatbot, history, top_p, temperature, max_tokens, memory_k, is_kgqa, single_turn, is_dbqa, select_database, select_table, databases, ], outputs=[chatbot, history, status_display], show_progress=True, ) retry_args = dict( fn=retry, inputs=[ select_model, models, user_question, chatbot, history, top_p, temperature, max_tokens, memory_k, is_kgqa, single_turn, is_dbqa, select_database, select_table, databases, ], outputs=[chatbot, history, status_display], show_progress=True, ) reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display]) cancelBtn.click(cancel_outputing, [], [status_display]) transfer_input_args = dict( fn=transfer_input, inputs=[user_input], outputs=[user_question, user_input, submitBtn, cancelBtn], show_progress=True, ) user_input.submit(**transfer_input_args).then(**predict_args) submitBtn.click(**transfer_input_args).then(**predict_args) emptyBtn.click( reset_state, outputs=[chatbot, history, status_display], show_progress=True, ) emptyBtn.click(**reset_args) retryBtn.click(**retry_args) delLastBtn.click( delete_last_conversation, [chatbot, history], [chatbot, history, status_display], show_progress=True, ) demo.title = "OpenLLM Chatbot 🚀 " if __name__ == "__main__": reload_javascript() demo.queue(concurrency_count=CONCURRENT_COUNT).launch()