"""Gradio demo app for Orion-14B.

Builds one tab per capability: basic chat, document QA, plugins, QA-pair
generation and data extraction. Each submit action runs a validator first
and only calls the model-backed handler when validation succeeds.
"""

import os
from typing import Dict

import gradio as gr
import pandas as pd

from chat_task.chat import generate_chat
from doc_qa_task.doc_qa import generate_doc_qa
from examples import (
    load_examples,
    preprocess_docqa_examples,
    preprocess_extraction_examples,
    preprocess_qa_generator_examples,
)
from extract_data_task.extract import extract_slots
from plugin_task.api import api_plugin_chat
from qa_generator_task.generate_qa import generate_qa_pairs
from plugin_task.plugins import PLUGIN_JSON_SCHEMA

abs_path = os.path.abspath(__file__)
current_dir = os.path.dirname(abs_path)
# Directory holding the button icons (clear.png, send.svg).
statistic_path = os.path.join(current_dir, "images")

# Runs at import time, before the UI below (and its cached examples) is built.
load_examples()


def clear_session():
    """Reset the basic-chat tab: empty input text and no chat history."""
    return "", None


def clear_plugin_session(session: Dict):
    """Reset the plugin tab.

    Empties ``session`` in place and returns it, followed by ``None`` for the
    input textbox and ``None`` for the chatbot.
    """
    session.clear()
    return session, None, None


def show_custom_fallback_textbox(x):
    """Show the custom-fallback row only when "自定义话术" is selected.

    NOTE(review): this helper is not wired to any event in this file.
    """
    return [gr.Row(visible=x == "自定义话术"), gr.Textbox()]


def validate_field_word_count(
    input_text: str, description: str, max_word_count: int = 3000
):
    """Raise a user-facing error if a text field is empty or too long.

    The limit is measured with ``len`` (characters), which for Chinese text
    is the "字" count used in the error message.

    :param input_text: text to validate; ``None`` is treated as empty.
    :param description: field name shown in the error message.
    :param max_word_count: maximum number of characters allowed.
    :raises gr.Error: if the text is empty or longer than ``max_word_count``.
    """
    if not input_text:
        raise gr.Error(f"{description}不能为空")
    if len(input_text) > max_word_count:
        raise gr.Error(f"{description}字数不能超过{max_word_count}字")


def _join_columns(df: "pd.DataFrame", columns) -> str:
    """Join every cell of ``columns`` (column by column) with single spaces.

    Missing cells (``None``/NaN) count as "" and other non-string cells are
    converted with ``str``, so the join cannot raise ``TypeError``. The result
    is stripped so that a table of blank cells is reported as empty instead
    of passing the emptiness check as a run of separator spaces.
    """
    parts = [
        "" if pd.isna(cell) else str(cell)
        for column in columns
        for cell in df[column].tolist()
    ]
    return " ".join(parts).strip()


def validate_chat(input_text: str):
    """Validate the basic-chat input (non-empty, at most 500 characters)."""
    validate_field_word_count(input_text, "输入", 500)


def validate_doc_qa(
    input_text: str,
    doc_df: "pd.DataFrame",
    fallback_ratio: str,
    fallback_text_input: str,
):
    """Validate the document-QA form before generation.

    :param input_text: user question, at most 500 characters.
    :param doc_df: document snippets with columns "文档片段名称" and
        "文档片段内容"; their combined text must be non-empty and at most
        2500 characters.
    :param fallback_ratio: selected fallback mode (radio value).
    :param fallback_text_input: custom fallback reply; validated (at most 100
        characters) only when the "自定义话术" mode is selected.
    :raises gr.Error: if any field fails validation.
    """
    if fallback_ratio == "自定义话术":
        validate_field_word_count(fallback_text_input, "自定义话术", 100)
    validate_field_word_count(input_text, "输入", 500)
    # Snippet names first, then snippet contents, all space-separated.
    page_content_full_text = _join_columns(doc_df, ["文档片段名称", "文档片段内容"])
    validate_field_word_count(page_content_full_text, "文档信息", 2500)


def validate_qa_pair_generator(input_text: str):
    """Validate the QA-pair generator input (non-empty, at most 3000 characters)."""
    validate_field_word_count(input_text, "输入")


def validate_extraction(
    input_text: str,
    extraction_df: "pd.DataFrame",
):
    """Validate the extraction form.

    :param input_text: text to extract from, at most 1500 characters.
    :param extraction_df: fields to extract, with columns "字段名称" and
        "字段描述"; their combined text must be non-empty and at most 1500
        characters.
    :raises gr.Error: if either field fails validation.
    """
    extraction_full_text = _join_columns(extraction_df, ["字段名称", "字段描述"])
    validate_field_word_count(input_text, "输入", 1500)
    validate_field_word_count(extraction_full_text, "待抽取字段描述", 1500)


def validate_plugin(input_text: str):
    """Validate the plugin-chat input (non-empty, at most 500 characters)."""
    validate_field_word_count(input_text, "输入", 500)


def _clear_button(label: str = "清除历史"):
    """Create the small "clear" icon button used in each tab's action row."""
    return gr.Button(
        label,
        min_width=17,
        size="sm",
        scale=1,
        icon=os.path.join(statistic_path, "clear.png"),
    )


def _send_button():
    """Create the small primary "send" icon button used in each tab's action row."""
    return gr.Button(
        "发送",
        variant="primary",
        min_width=17,
        size="sm",
        scale=1,
        icon=os.path.join(statistic_path, "send.svg"),
    )


with gr.Blocks(
    title="Orion-14B",
    theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
) as demo:

    def user(user_message, history):
        """Append the user's message to the history with an empty reply slot.

        The message is also returned unchanged for the input textbox; the
        reply slot is presumably filled by the generation step chained after
        this one.
        """
        return user_message, (history or []) + [[user_message, ""]]

    # NOTE(review): the header markup appears to be empty here — confirm the
    # intended content.
    gr.Markdown(
        """
"""
    )

    # ---- Tab 1: basic chat ------------------------------------------------
    with gr.Tab("基础能力"):
        chatbot = gr.Chatbot(
            label="Orion-14B-Chat",
            elem_classes="control-height",
            show_copy_button=True,
            min_width=1368,
            height=416,
        )
        chat_text_input = gr.Textbox(label="输入", min_width=1368)
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    [
                        "可以给我讲个笑话吗?",
                        "什么是伟大的诗歌?",
                        "你知道李白吗?",
                        "黑洞是如何工作的?",
                        "在表中插入一条数据,id为1,name为张三,age为18,请问SQL语句是什么?",
                    ],
                    chat_text_input,
                    label="试试问",
                )
            with gr.Column(scale=1):
                with gr.Row(variant="compact"):
                    clear_history = _clear_button()
                    submit = _send_button()

        # Validate -> echo the user message into the history -> generate.
        chat_text_input.submit(
            fn=validate_chat, inputs=[chat_text_input], outputs=[], queue=False
        ).success(
            user, [chat_text_input, chatbot], [chat_text_input, chatbot], queue=False
        ).success(
            fn=generate_chat,
            inputs=[chat_text_input, chatbot],
            outputs=[chat_text_input, chatbot],
        )
        submit.click(
            fn=validate_chat, inputs=[chat_text_input], outputs=[], queue=False
        ).success(
            user, [chat_text_input, chatbot], [chat_text_input, chatbot], queue=False
        ).success(
            fn=generate_chat,
            inputs=[chat_text_input, chatbot],
            outputs=[chat_text_input, chatbot],
            api_name="chat",
        )
        clear_history.click(
            fn=clear_session, inputs=[], outputs=[chat_text_input, chatbot], queue=False
        )

    # ---- Tab 2: document QA -----------------------------------------------
    with gr.Tab("基于文档问答"):
        with gr.Row():
            with gr.Column(scale=3, min_width=357, variant="panel"):
                gr.Markdown(
                    '配置项'
                )
                citations_radio = gr.Radio(
                    ["开启引用", "关闭引用"], label="引用", value="关闭引用"
                )
                fallback_radio = gr.Radio(
                    ["使用大模型知识", "自定义话术"],
                    label="超纲问题回复",
                    value="自定义话术",
                )
                fallback_text_input = gr.Textbox(
                    label="自定义话术",
                    value="抱歉,我还在学习中,暂时无法回答您的问题。",
                )
                gr.Markdown(
                    '文档信息'
                )
                doc_df = gr.Dataframe(
                    headers=["文档片段内容", "文档片段名称"],
                    datatype=["str", "str"],
                    row_count=6,
                    col_count=(2, "fixed"),
                    label="",
                    interactive=True,
                    wrap=True,
                    elem_classes="control-height",
                    height=300,
                )
            with gr.Column(scale=2, min_width=430):
                chatbot = gr.Chatbot(
                    label="适用场景:预期LLM通过自由知识回答",
                    elem_classes="control-height",
                    show_copy_button=True,
                    min_width=999,
                    height=419,
                )
                doc_qa_input = gr.Textbox(label="输入", min_width=999, max_lines=10)
                with gr.Row():
                    with gr.Column(scale=2):
                        gr.Examples(
                            [
                                "哪些情况下不能超车?",
                                "参观须知",
                                "青岛啤酒酒精含量是多少?",
                            ],
                            doc_qa_input,
                            label="试试问",
                            cache_examples=True,
                            fn=preprocess_docqa_examples,
                            outputs=[doc_df],
                        )
                    with gr.Column(scale=1):
                        with gr.Row(variant="compact"):
                            clear_history = _clear_button()
                            submit = _send_button()

        doc_qa_input.submit(
            fn=validate_doc_qa,
            inputs=[
                doc_qa_input,
                doc_df,
                fallback_radio,
                fallback_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user, [doc_qa_input, chatbot], [doc_qa_input, chatbot], queue=False
        ).success(
            fn=generate_doc_qa,
            inputs=[
                doc_qa_input,
                chatbot,
                doc_df,
                fallback_radio,
                fallback_text_input,
                citations_radio,
            ],
            outputs=[doc_qa_input, chatbot],
            scroll_to_output=True,
            api_name="doc_qa",
        )
        submit.click(
            fn=validate_doc_qa,
            inputs=[
                doc_qa_input,
                doc_df,
                fallback_radio,
                fallback_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user, [doc_qa_input, chatbot], [doc_qa_input, chatbot], queue=False
        ).success(
            fn=generate_doc_qa,
            inputs=[
                doc_qa_input,
                chatbot,
                doc_df,
                fallback_radio,
                fallback_text_input,
                citations_radio,
            ],
            outputs=[doc_qa_input, chatbot],
            scroll_to_output=True,
        )
        # inputs=[] means Gradio calls fn with no arguments, so the lambda must
        # take none (a one-argument lambda raised TypeError on click).
        clear_history.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[doc_df, doc_qa_input, chatbot],
            queue=False,
        )

    # ---- Tab 3: plugins -----------------------------------------------------
    with gr.Tab("插件能力"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
                    '配置项'
                )
                # One on/off radio per plugin, labelled with its "name_for_human".
                radio_plugins = [
                    gr.Radio(
                        ["开启", "关闭"],
                        label=plugin_json["name_for_human"],
                        value="开启",
                    )
                    for plugin_json in PLUGIN_JSON_SCHEMA
                ]
            with gr.Column(scale=3):
                # Session state passed to and returned by api_plugin_chat.
                session = gr.State(value={})
                chatbot = gr.Chatbot(
                    label="适用场景:需要LLM调用API解决问题",
                    elem_classes="control-height",
                    show_copy_button=True,
                )
                plugin_text_input = gr.Textbox(label="输入")
                with gr.Row():
                    with gr.Column(scale=2):
                        gr.Examples(
                            [
                                "北京天气怎么样?",
                                "查询物流信息",
                                "每日壁纸",
                                "bing今天的壁纸是什么",
                                "查询手机号码归属地",
                            ],
                            plugin_text_input,
                            label="试试问",
                        )
                    with gr.Column(scale=1):
                        with gr.Row(variant="compact"):
                            clear_history = _clear_button()
                            submit = _send_button()

        plugin_text_input.submit(
            fn=validate_plugin,
            inputs=[
                plugin_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user,
            [plugin_text_input, chatbot],
            [plugin_text_input, chatbot],
            scroll_to_output=True,
        ).success(
            fn=api_plugin_chat,
            inputs=[session, plugin_text_input, chatbot, *radio_plugins],
            outputs=[session, plugin_text_input, chatbot],
            scroll_to_output=True,
        )
        submit.click(
            fn=validate_plugin,
            inputs=[
                plugin_text_input,
            ],
            outputs=[],
            queue=False,
        ).success(
            user,
            [plugin_text_input, chatbot],
            [plugin_text_input, chatbot],
            scroll_to_output=True,
        ).success(
            fn=api_plugin_chat,
            inputs=[session, plugin_text_input, chatbot, *radio_plugins],
            outputs=[session, plugin_text_input, chatbot],
            api_name="plugin",
            scroll_to_output=True,
        )
        clear_history.click(
            fn=clear_plugin_session,
            inputs=[session],
            outputs=[session, plugin_text_input, chatbot],
            queue=False,
        )

    # ---- Tab 4: QA-pair generation -----------------------------------------
    with gr.Tab("生成QA对"):
        with gr.Row(equal_height=True):
            qa_generator_output = gr.Code(
                language="json",
                show_label=False,
                min_width=1368,
            )
        with gr.Row():
            qa_generator_input = gr.Textbox(
                label="输入",
                show_label=True,
                info="",
                min_width=1368,
                lines=5,
                max_lines=10,
            )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    [
                        "第一章 总 则 \n第...",
                        "金字塔,在建筑学上是...",
                        "山西老陈醋是以高粱、...",
                        "室内装饰构造虚拟仿真...",
                        "猎户星空(Orion...",
                    ],
                    qa_generator_input,
                    label="试试问",
                    cache_examples=True,
                    fn=preprocess_qa_generator_examples,
                    outputs=[qa_generator_input],
                )
            with gr.Column(scale=1):
                with gr.Row(variant="compact"):
                    clear = _clear_button("清除")
                    submit = _send_button()

        submit.click(
            fn=validate_qa_pair_generator,
            inputs=[qa_generator_input],
            outputs=[],
        ).success(
            fn=generate_qa_pairs,
            inputs=[qa_generator_input],
            outputs=[qa_generator_output, qa_generator_input],
            scroll_to_output=True,
            api_name="qa_generator",
        )
        # Zero-argument lambda: inputs=[] supplies no arguments.
        clear.click(
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[qa_generator_input, qa_generator_output],
            queue=False,
        )

    # ---- Tab 5: data extraction ---------------------------------------------
    with gr.Tab("抽取数据"):
        extract_outpu_df = gr.Dataframe(
            label="",
            headers=["字段名称", "字段抽取结果"],
            datatype=["str", "str"],
            col_count=(2, "fixed"),
            wrap=True,
            elem_classes="control-height",
            height=234,
            row_count=5,
        )
        extract_input = gr.Textbox(label="输入", lines=5, min_width=1368, max_lines=10)
        extraction_df = gr.Dataframe(
            headers=["字段名称", "字段描述"],
            datatype=["str", "str"],
            row_count=3,
            col_count=(2, "fixed"),
            label="",
            interactive=True,
            wrap=True,
            elem_classes="control-height",
            height=180,
        )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    ["第一条合同当...", "发票编号: IN...", "发件人:John..."],
                    extract_input,
                    label="试试问",
                    cache_examples=True,
                    fn=preprocess_extraction_examples,
                    outputs=[extract_input, extraction_df],
                )
            with gr.Column(scale=1):
                with gr.Row(variant="compact"):
                    clear = _clear_button()
                    submit = _send_button()

        submit.click(
            fn=validate_extraction,
            inputs=[extract_input, extraction_df],
            outputs=[],
        ).success(
            fn=extract_slots,
            inputs=[extract_input, extraction_df],
            outputs=[extract_outpu_df],
            scroll_to_output=True,
            api_name="extract",
        )
        # Zero-argument lambda: inputs=[] supplies no arguments.
        clear.click(
            fn=lambda: ("", None, None),
            inputs=[],
            outputs=[
                extract_input,
                extraction_df,
                extract_outpu_df,
            ],
            queue=False,
        )


if __name__ == "__main__":
    demo.queue(api_open=False, max_size=40).launch(
        height=800,
        share=False,
        server_name="0.0.0.0",
        show_api=False,
        max_threads=4,
    )