Spaces:
Runtime error
Runtime error
import os | |
os.environ["OPENAI_API_KEY"] = "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" | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.vectorstores import Chroma | |
from pypinyin import lazy_pinyin | |
import gradio as gr | |
import openai | |
import random | |
# import logging | |
# logging.basicConfig( | |
# filename='log/log.log', | |
# level=logging.INFO, | |
# format='%(asctime)s - %(levelname)s - %(message)s', | |
# datefmt='%m/%d/%Y %H:%M:%S' | |
# ) | |
embedding = OpenAIEmbeddings() | |
target_files = set() | |
topics = ["农业", "宗教与文化", "建筑业与制造业", "医疗卫生保健", "国家治理", "法律法规", "财政税收", "教育", "金融", "贸易", "宏观经济", "社会发展", "科学技术", "能源环保", "国际关系", "国防安全","不限主题"] | |
def get_path(target_string): | |
folder_path = "./vector_data" | |
all_vectors = os.listdir(folder_path) | |
matching_files = [file for file in all_vectors if file.startswith(target_string)] | |
for file in matching_files: | |
file_path = os.path.join(folder_path, file) | |
return file_path | |
return "" | |
def extract_partial_message(res_message, response): | |
for chunk in response: | |
if len(chunk["choices"][0]["delta"]) != 0: | |
res_message = res_message + chunk["choices"][0]["delta"]["content"] | |
yield res_message | |
def format_messages(sys_prompt, history, message): | |
history_openai_format = [{"role": "system", "content": sys_prompt}] | |
for human, assistant in history: | |
history_openai_format.append({"role": "user", "content": human}) | |
history_openai_format.append({"role": "assistant", "content": assistant}) | |
history_openai_format.append({"role": "user", "content": message}) | |
return history_openai_format | |
def get_domain(history, message): | |
sys_prompt = """ | |
帮我根据用户的问题划分到以下几个类别,输出最匹配的一个类别:[宗教与文化, 农业, 建筑业与制造业, 医疗卫生保健, 国家治理, 法律法规, 财政税收, 教育, 金融, 贸易, 宏观经济, 社会发展, 科学技术, 能源环保, 国际关系, 国防安全] | |
""" | |
history_openai_format = format_messages(sys_prompt, history, message) | |
print("history_openai_format:", history_openai_format) | |
# logging.info(f"history_openai_format: {history_openai_format}") | |
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=False) | |
domain = response['choices'][0]['message']['content'] | |
print("匹配领域:", domain) | |
# logging.info(f"匹配领域: {domain}") | |
return domain | |
def echo(message, history, flag1, flag2): | |
global target_files, topics | |
print("flag1:", flag1) | |
print("flag2:", flag2) | |
print("history:", history) | |
print("message:", message) | |
# logging.info(f"flag1: {flag1}") | |
# logging.info(f"flag2: {flag2}") | |
# logging.info(f"history: {history}") | |
# logging.info(f"message: {message}") | |
if len(flag1) == 0: # 不进行研报问答&研报生成 | |
target_files.clear() | |
history.clear() | |
if flag2 not in [None, 16]: | |
domain = topics[flag2] | |
message = f"{domain}领域相关内容" | |
elif flag2 in [None, 16]: | |
message = f"{topics[random.randint(0, 16)]}领域相关内容" if message == "" else message | |
domain = get_domain(history, message) | |
persist_vector_path = get_path("".join(lazy_pinyin(domain))) | |
db = Chroma(persist_directory=persist_vector_path, embedding_function=embedding) | |
docs = db.similarity_search_with_score(query=message, k=5) | |
contents = [doc[0] for doc in docs] | |
relevance = " ".join(doc.page_content for doc in contents) | |
source = [doc.metadata for doc in contents] | |
for item in source: | |
target_files.add(item['source'] ) | |
print("研报搜索结果:", target_files) | |
# logging.info(f"研报搜索结果: {target_files}") | |
sys_prompt = """ | |
你是一个研报助手,根据这篇文章:{} | |
来回复用户的问题生成总结,你需要严格按照这种格式回复:以上文章总结了*,主要观点是*。你只能回复中文。 | |
如果这篇文章无法回答用户的问题,你必须根据你的知识面直接编造内容进行回答,不能回复不知道,不能说这段文字没有提供等话语。 | |
""" | |
sys_prompt = sys_prompt.format(relevance) | |
history_openai_format = format_messages(sys_prompt, history, message) | |
print("history_openai_format:", history_openai_format) | |
# logging.info(f"history_openai_format: {history_openai_format}") | |
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True) | |
partial_message = "搜索结果来自以下研报:\n" + '\n'.join(i for i in target_files) + '\n\n' | |
for result_message in extract_partial_message(partial_message, response): | |
yield result_message | |
elif flag1 == ['研报问答']: | |
print("target_files:", target_files) | |
# logging.info(f"target_files: {target_files}") | |
QA_pages = [] | |
if not target_files: | |
yield "请取消选中研报问答,先进行研报检索,再进行问答。" | |
else: | |
for item in target_files: | |
loader = PyPDFLoader(item) | |
QA_pages.extend(loader.load_and_split()) | |
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) | |
documents = text_splitter.split_documents(QA_pages) | |
db = Chroma.from_documents(documents, OpenAIEmbeddings()) | |
docs = db.similarity_search_with_score(query=message, k=3) | |
contents = [doc[0] for doc in docs] | |
relevance = " ".join(doc.page_content for doc in contents) | |
sys_prompt = """ | |
你是一个研报助手,根据这篇文章:{} | |
来回复用户的问题,如果这篇文章无法回答用户的问题,你必须根据你的知识面来编造进行专业的回答, | |
不能回复不知道,不能回复这篇文章不能回答的这种话语,你只能回复中文。 | |
""" | |
sys_prompt = sys_prompt.format(relevance) | |
history_openai_format = format_messages(sys_prompt, history, message) | |
print("history_openai_format:", history_openai_format) | |
# logging.info(f"history_openai_format: {history_openai_format}") | |
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True) | |
for result_message in extract_partial_message("", response): | |
yield result_message | |
elif flag1 == ['研报生成']: | |
target_files.clear() | |
sys_prompt = """ | |
你是一个研报助手,请根据用户的要求回复问题。 | |
""" | |
history_openai_format = format_messages(sys_prompt, history, message) | |
print("history_openai_format:", history_openai_format) | |
# logging.info(f"history_openai_format: {history_openai_format}") | |
response = openai.ChatCompletion.create(model="gpt-4", messages=history_openai_format, temperature=1.0, stream=True) | |
for result_message in extract_partial_message("", response): | |
yield result_message | |
elif len(flag1) == 2: | |
yield "请选中一个选项,进行相关问答。" | |
demo = gr.ChatInterface( | |
echo, | |
chatbot=gr.Chatbot(height=430, label="ChatReport"), | |
textbox=gr.Textbox(placeholder="请输入问题", container=False, scale=7), | |
title="研报助手", | |
description="清芬院研报助手", | |
theme="soft", | |
additional_inputs=[ | |
# gr.Radio(["研报问答", "研报生成"], type="index", label = "function"), | |
# gr.Checkbox(label = "研报问答"), | |
# gr.Checkbox(label = "研报生成"), | |
gr.CheckboxGroup(["研报问答", "研报生成"], label="Function"), | |
gr.Dropdown(topics, type="index"), | |
# gr.Button(value="Run").click(echo, inputs=["", "", [], None], outputs=[""]) | |
# btn.click(combine, inputs=[txt, txt_2], outputs=[txt_3]) | |
# gr.Blocks() | |
], | |
# retry_btn="retry", | |
undo_btn="清空输入框", | |
clear_btn="清空聊天记录" | |
).queue() | |
if __name__ == "__main__": | |
demo.launch(share=True) | |
''' | |
flag1: ['研报问答'] | |
flag2: None | |
history: [] | |
message: gg | |
target_files: set() | |
''' |