import re import os import unicodedata from typing import List import uuid import hashlib import pandas as pd from common.call_llm import chat_stream_generator DOC_QA_ENDPOINT = os.environ.get("DOC_QA_ENDPOINT") prompt_template = """你是由猎户星空开发的AI助手,你的名字叫聚言。你可以根据下面给出的参考资料和聊天历史来回答用户问题。 ### 参考资料 ### {context} ### 聊天历史 ### {chat_history} ### 用户问题 ### {question} ### 回答要求 ### {requirement} """ def document_prompt_template(): return """["Source_id": {doc_id},"Content": "{page_content}"]""" def language_detect(text: str) -> str: text = re.sub(r"([ ■◼•*…— �●⚫]+|[·\.~•、—'}\n\t]{1,})", '', text.strip()) stats = { "zh": 0, "ja": 0, "ko": 0, "en": 0, "th": 0, "other": 0 } char_count = 0 for char in text: try: code_name = unicodedata.name(char) except Exception: continue char_count += 1 # 判断是否为中文 if 'CJK' in code_name: stats["zh"] += 1 # 判断是否为日文 elif 'HIRAGANA' in code_name or 'KATAKANA' in code_name: stats["ja"] += 1 # 判断是否为泰文 elif "THAI" in code_name: stats["th"] += 1 # 判断是否为韩文 elif 'HANGUL' in code_name: stats["ko"] += 1 # 判断是否为英文 elif 'LA' in code_name: stats["en"] += 1 else: stats["other"] += 1 lang = "" ratio = 0.0 for lan in stats: if lan == "other": continue # trick: 英文按字母统计不准确,除以4大致表示word个数 if lan == "en": stats[lan] /= 4.0 lan_r = float(stats[lan]) / char_count if ratio < lan_r: lang = lan ratio = lan_r return lang def language_prompt(lan: str) -> str: _ZH_LANGUAGE_MAP = { "zh": "中文", "en": "英文", "other": "中文", "ja": "中文", "zh_gd": "中文", "ko": "韩文", "th": "泰文" } return _ZH_LANGUAGE_MAP.get(lan.lower(), "中文") def _get_chat_history(chat_history: List[List]) -> str: if not chat_history: return "" chat_history_text = "" for human_msg, ai_msg in chat_history: human = "{'Human': '" + human_msg + "'}" ai = "{'AI': '" + ai_msg + "'}" chat_history_text += "[" + ", ".join([human, ai]) + "]\n" return chat_history_text def get_prompt(context: str, chat_history: str, question: str, trapped_switch: int, fallback: str, citations_switch: int) -> str: answer_prompts = ["1. 你只能根据上面参考资料中给出的事实信息来回答用户问题,不要胡编乱造。", "2. 如果向用户提出澄清问题有助于回答问题,可以尝试提问。"] index = 3 if len(fallback) > 0 and trapped_switch == 1: answer_prompts.append( str(index) + ". " + """如果参考资料中的信息不足以回答用户问题,请直接回答下面三个双引号中的内容:\"\"\"{fallback}\"\"\"。""".format( fallback=fallback)) index += 1 if citations_switch: citation_prompt = "如果你给出的答案里引用了参考资料中的内容,请在答案的结尾处添加你引用的Source_id,引用的Source_id值来自于参考资料中,并用两个方括号括起来。示例:[[d97b811489b73f46c8d2cb1bc888dbbe]]、[[b6be48868de736b90363d001c092c019]]" answer_prompts.append(str(index) + ". " + citation_prompt) index += 1 lan = language_detect(question) style_prompt = """请你以第一人称并且用严谨的风格来回答问题,一定要用{language}来回答,并且基于事实详细阐述。""".format( language=language_prompt(lan), ) answer_prompts.append(str(index) + ". " + style_prompt) answer_prompts = "\n".join(answer_prompts) prompt = prompt_template.format(context=context, chat_history=chat_history, question=question, requirement=answer_prompts) return prompt def generate_doc_qa(input_text: str, history: List[List[str]], doc_df: "pd.DataFrame", trapped_switch: str, fallback: str, citations_switch: str): """Generates chat responses according to the input text, history and page content.""" # handle input params print(f"input_text: {input_text}, history: {history}, page_content: {doc_df}, trapped_switch: {trapped_switch}, fallback: {fallback}, citations_switch: {citations_switch}") citations_switch = 1 if citations_switch == "开启引用" else 0 trapped_switch = 1 if trapped_switch == "自定义话术" else 0 fallback = fallback or "" input_text = input_text or "你好" history = (history or [])[-5:] # Keep the last 5 messages in history doc_df = doc_df[doc_df["文档片段内容"].notna()] # iterate over all documents context = "" source_id_map = dict() for _, row in doc_df.iterrows(): if not row["文档片段内容"] or not row["文档片段名称"]: continue source_id = hashlib.md5(str(uuid.uuid4()).encode("utf-8")).hexdigest() source_id_map[source_id] = row["文档片段名称"] context += document_prompt_template().format(doc_id=source_id, page_content=row["文档片段内容"]) + "\n\n" prompt = get_prompt(context.strip(), _get_chat_history(history), input_text, trapped_switch, fallback, citations_switch) print(f"docQA prompt: {prompt}") messages = [{"role": "user", "content": prompt}] # append latest message stream_response = chat_stream_generator(messages=messages, endpoint=DOC_QA_ENDPOINT) cache = "" for character in stream_response: if "[" in character or cache: cache += character continue history[-1][1] += character yield None, history if cache: source_ids = re.findall(r"\[\[(.*?)\]\]", cache) print(f"Matched source ids {source_ids}") for source_id in source_ids: origin_source_id = source_id_map.get(source_id, source_id) cache = cache.replace(source_id, origin_source_id) history[-1][1] += cache yield None, history