import os os.environ["OPENAI_API_KEY"] = "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" from langchain.embeddings.openai import OpenAIEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain.document_loaders import PyPDFLoader from langchain.vectorstores import Chroma from pypinyin import lazy_pinyin import gradio as gr import openai import random # import logging # logging.basicConfig( # filename='log/log.log', # level=logging.INFO, # format='%(asctime)s - %(levelname)s - %(message)s', # datefmt='%m/%d/%Y %H:%M:%S' # ) embedding = OpenAIEmbeddings() target_files = set() topics = ["农业", "宗教与文化", "建筑业与制造业", "医疗卫生保健", "国家治理", "法律法规", "财政税收", "教育", "金融", "贸易", "宏观经济", "社会发展", "科学技术", "能源环保", "国际关系", "国防安全","不限主题"] def get_path(target_string): folder_path = "./vector_data" all_vectors = os.listdir(folder_path) matching_files = [file for file in all_vectors if file.startswith(target_string)] for file in matching_files: file_path = os.path.join(folder_path, file) return file_path return "" def extract_partial_message(res_message, response): for chunk in response: if len(chunk["choices"][0]["delta"]) != 0: res_message = res_message + chunk["choices"][0]["delta"]["content"] yield res_message def format_messages(sys_prompt, history, message): history_openai_format = [{"role": "system", "content": sys_prompt}] for human, assistant in history: history_openai_format.append({"role": "user", "content": human}) history_openai_format.append({"role": "assistant", "content": assistant}) history_openai_format.append({"role": "user", "content": message}) return history_openai_format def get_domain(history, message): sys_prompt = """ 帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全] """ history_openai_format = format_messages(sys_prompt, history, message) print("history_openai_format:", history_openai_format) # logging.info(f"history_openai_format: {history_openai_format}") response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=False) domain = response['choices'][0]['message']['content'] print("匹配领域:", domain) # logging.info(f"匹配领域: {domain}") return domain def echo(message, history, flag1, flag2): global target_files, topics print("flag1:", flag1) print("flag2:", flag2) print("history:", history) print("message:", message) # logging.info(f"flag1: {flag1}") # logging.info(f"flag2: {flag2}") # logging.info(f"history: {history}") # logging.info(f"message: {message}") if len(flag1) == 0: # 不进行研报问答&研报生成 target_files.clear() history.clear() if flag2 not in [None, 16]: domain = topics[flag2] message = f"{domain}领域相关内容" elif flag2 in [None, 16]: message = f"{topics[random.randint(0, 16)]}领域相关内容" if message == "" else message domain = get_domain(history, message) persist_vector_path = get_path("".join(lazy_pinyin(domain))) db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding) docs = db.similarity_search_with_score(query=message, k=5) contents = [doc[0] for doc in docs] relevance = " ".join(doc.page_content for doc in contents) source = [doc.metadata for doc in contents] for item in source: target_files.add(item['source'] ) print("研报搜索结果:", target_files) # logging.info(f"研报搜索结果: {target_files}") sys_prompt = """ 你是一个研报助手,根据这篇文章:{} 来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。你只能回复中文。 如果这篇文章无法回答用户的问题,你必须根据你的知识面直接编造内容进行回答,不能回复不知道,不能说这段文字没有提供等话语。 """ sys_prompt = sys_prompt.format(relevance) history_openai_format = format_messages(sys_prompt, history, message) print("history_openai_format:", history_openai_format) # logging.info(f"history_openai_format: {history_openai_format}") response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True) partial_message = "搜索结果来自以下研报:\n" + '\n'.join(i for i in target_files) + '\n\n' for result_message in extract_partial_message(partial_message, response): yield result_message elif flag1 == ['研报问答']: print("target_files:", target_files) # logging.info(f"target_files: {target_files}") QA_pages = [] if not target_files: yield "请取消选中研报问答,先进行研报检索,再进行问答。" else: for item in target_files: loader = PyPDFLoader(item) QA_pages.extend(loader.load_and_split()) text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) documents = text_splitter.split_documents(QA_pages) db = Chroma.from_documents(documents, OpenAIEmbeddings()) docs = db.similarity_search_with_score(query=message, k=3) contents = [doc[0] for doc in docs] relevance = " ".join(doc.page_content for doc in contents) sys_prompt = """ 你是一个研报助手,根据这篇文章:{} 来回复用户的问题,如果这篇文章无法回答用户的问题,你必须根据你的知识面来编造进行专业的回答, 不能回复不知道,不能回复这篇文章不能回答的这种话语,你只能回复中文。 """ sys_prompt = sys_prompt.format(relevance) history_openai_format = format_messages(sys_prompt, history, message) print("history_openai_format:", history_openai_format) # logging.info(f"history_openai_format: {history_openai_format}") response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True) for result_message in extract_partial_message("", response): yield result_message elif flag1 == ['研报生成']: target_files.clear() sys_prompt = """ 你是一个研报助手,请根据用户的要求回复问题。 """ history_openai_format = format_messages(sys_prompt, history, message) print("history_openai_format:", history_openai_format) # logging.info(f"history_openai_format: {history_openai_format}") response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True) for result_message in extract_partial_message("", response): yield result_message elif len(flag1) == 2: yield "请选中一个选项,进行相关问答。" demo = gr.ChatInterface( echo, chatbot=gr.Chatbot(height=430, label="ChatReport"), textbox=gr.Textbox(placeholder="请输入问题", container=False, scale=7), title="研报助手", description="清芬院研报助手", theme="soft", additional_inputs=[ # gr.Radio(["研报问答", "研报生成"], type="index", label = "function"), # gr.Checkbox(label = "研报问答"), # gr.Checkbox(label = "研报生成"), gr.CheckboxGroup(["研报问答", "研报生成"], label="Function"), gr.Dropdown(topics, type="index"), # gr.Button(value="Run").click(echo, inputs=["", "", [], None], outputs=[""]) # btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3]) # gr.Blocks() ], # retry_btn="retry", undo_btn="清空输入框", clear_btn="清空聊天记录" ).queue() if __name__ == "__main__": demo.launch(share=True) ''' flag1: ['研报问答'] flag2: None history: [] message: gg target_files: set() '''