Spaces:
Sleeping
Sleeping
File size: 15,327 Bytes
4c2fab7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 |
import os
import time
import json
import asyncio
import gradio as gr
# set the env
from dotenv import load_dotenv
load_dotenv()
# get the root path of the project
current_file_path = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(current_file_path)
from textwrap import dedent
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
class OurLLM:
def __init__(self, model="gpt-4o"):
'''
params:
model: str,
模型名称 ["GLM-4-Flash", "GLM-4V-Flash",
"gpt-4o-mini", "gpt-4o", "o1-mini",
"gemini-1.5-flash-002", "gemini-1.5-pro-002",
"Qwen/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct"]
'''
self.model_name = model
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
OPENAI_API_KEY_DF = os.getenv('OPENAI_API_KEY_DF', OPENAI_API_KEY)
OPENAI_API_KEY_AZ = os.getenv('OPENAI_API_KEY_AZ', OPENAI_API_KEY)
OPENAI_API_KEY_CD = os.getenv('OPENAI_API_KEY_CD')
OPENAI_API_KEY_O1 = os.getenv('OPENAI_API_KEY_O1')
OPENAI_API_KEY_GLM = os.getenv('OPENAI_API_KEY_GLM')
OPENAI_API_KEY_SC = os.getenv('OPENAI_API_KEY_SC')
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL')
OPENAI_BASE_URL_GLM = os.getenv('OPENAI_BASE_URL_GLM')
OPENAI_BASE_URL_SC = os.getenv('OPENAI_BASE_URL_SC')
# 创建 API Key 映射
apiKeyMap = {
'gemini': {"base_url": OPENAI_BASE_URL, "api_key": OPENAI_API_KEY_DF},
'gpt': {"base_url": OPENAI_BASE_URL, "api_key": OPENAI_API_KEY_AZ},
'o1': {"base_url": OPENAI_BASE_URL, "api_key": OPENAI_API_KEY_O1},
'claude': {"base_url": OPENAI_BASE_URL, "api_key": OPENAI_API_KEY_CD},
'glm': {"base_url": OPENAI_BASE_URL_GLM, "api_key": OPENAI_API_KEY_GLM},
'qwen': {"base_url": OPENAI_BASE_URL_SC, "api_key": OPENAI_API_KEY_SC},
}
for name, info in apiKeyMap.items():
if name in model.lower():
self.base_url = info["base_url"]
self.api_key = info["api_key"]
break
assert self.base_url is not None, f"Base URL not found for model: {model}"
assert self.api_key is not None, f"API key not found for model: {model}"
chat_prompt = ChatPromptTemplate.from_messages(
[
("system", "{system_prompt}"),
("human", "{input}"),
# ("ai", "{chat_history}"),
]
)
self.chat_prompt = chat_prompt
self.llm = self.get_llm(model)
def clean_json(self, s):
return s.replace("```json", "").replace("```", "").strip()
def get_system_prompt(self, mode="assistant"):
prompt_map = {
"assistant": dedent("""
你是一个智能助手,擅长用简洁的中文回答用户的问题。
请确保你的回答准确、清晰、有条理,并且符合中文的语言习惯。
重要提示:
1. 回答要简洁明了,避免冗长
2. 使用适当的专业术语
3. 保持客观中立的语气
4. 如果不确定,要明确指出
"""),
# search
"keyword_expand": dedent("""
你是一个搜索关键词扩展专家,擅长将用户的搜索意图转化为多个相关的搜索词或短语。
用户会输入一段描述他们搜索需求的文本,请你生成与之相关的关键词列表。
你需要返回一个可以直接被 json 库解析的响应,包含以下内容:
{
"keywords": [关键词列表],
}
重要提示:
1. 关键词应该包含同义词、近义词、上位词、下位词
2. 短语要体现不同的表达方式和组合
3. 描述句子要涵盖不同的应用场景和用途
4. 所有内容必须与原始搜索意图高度相关
5. 扩展搜索意图到相关的应用场景和工具,例如:
- 如果搜索"PDF转MD",应包含PDF内容提取、PDF解析工具、PDF数据处理等
- 如果搜索"图片压缩",应包含批量压缩工具、图片格式转换等
- 如果搜索"代码格式化",应包含代码美化工具、语法检查器、代码风格统一等
- 如果搜索"文本翻译",应包含机器翻译API、多语言翻译工具、离线翻译软件等
- 如果搜索"数据可视化",应包含图表生成工具、数据分析库、交互式图表等
- 如果搜索"网络爬虫",应包含数据采集框架、反爬虫绕过、数据解析工具等
- 如果搜索"API测试",应包含接口测试工具、性能监控、自动化测试框架等
6. 所有内容主要使用英文表达,并对部分关键词添加额外的中文表示
7. 返回内容不要使用任何 markdown 格式 以及任何特殊字符
"""),
"zh2en": dedent("""
你是一个专业的中译英翻译专家,尤其擅长学术论文的翻译工作。
请将用户提供的中文内容翻译成地道、专业的英文。
重要提示:
1. 使用学术论文常用的表达方式和术语
2. 保持专业、正式的语气
3. 确保译文的准确性和流畅性
4. 对专业术语进行准确翻译
5. 遵循英文学术写作的语法规范
6. 保持原文的逻辑结构
7. 适当使用学术论文常见的过渡词和连接词
8. 如遇到模糊的表达,选择最符合学术上下文的翻译
9. 避免使用口语化或非正式的表达
10. 注意时态和语态的准确使用
"""),
"github_score": dedent("""
你是一个语义匹配评分专家,擅长根据用户需求和仓库描述进行语义匹配度评分。
用户会输入两部分内容:
1. 用户的具体需求描述
2. 多个仓库的描述列表(以1,2,3等数字开头)
请你仔细分析用户需求,并对每个仓库进行评分。
确保返回一个可以直接被 json 库解析的响应,包含以下内容:
{
"indices": [仓库编号列表,按分数从高到低],
"scores": [编号对应的匹配度评分列表,0-100的整数,表示匹配程度]
}
重要提示:
1. 评分范围为0-100的整数,高于60分表示具有明显相关性
2. 评分要客观反映仓库与需求的契合度
3. 只返回评分大于 60 的仓库
4. 返回内容不要使用任何 markdown 格式 以及任何特殊字符
""")
}
return prompt_map[mode]
def get_llm(self, model="gpt-4o-mini"):
'''
params:
model: str, 模型名称 ["gpt-4o-mini", "gpt-4o", "o1-mini", "gemini-1.5-flash-002"]
'''
llm = ChatOpenAI(
model=model,
base_url=self.base_url,
api_key=self.api_key,
)
print(f"Init model {model} successfully!")
return llm
def ask_question(self, question, system_prompt=None):
# 1. 获取系统提示
if system_prompt is None:
system_prompt = self.get_system_prompt()
# 2. 生成聊天提示
prompt = self.chat_prompt.format(input=question, system_prompt=system_prompt)
config = {
"configurable": {"response_format": {"type": "json_object"}}
}
# 3. 调用 LLM 进行回答
for _ in range(10):
try:
response = self.llm.invoke(prompt, config=config)
response.content = self.clean_json(response.content)
return response
except Exception as e:
print(e)
time.sleep(10)
continue
print(f"Failed to call llm for prompt: {prompt[0:10]}")
return None
async def ask_questions_parallel(self, questions, system_prompt=None):
# 1. 获取系统提示
if system_prompt is None:
system_prompt = self.get_system_prompt()
# 2. 定义异步函数
async def call_llm(prompt):
for _ in range(10):
try:
response = await self.llm.ainvoke(prompt)
response.content = self.clean_json(response.content)
return response
except Exception as e:
print(e)
await asyncio.sleep(10)
continue
print(f"Failed to call llm for prompt: {prompt[0:10]}")
return None
# 3. 构建 prompt
prompts = [self.chat_prompt.format(input=question, system_prompt=system_prompt) for question in questions]
# 4. 异步调用
tasks = [call_llm(prompt) for prompt in prompts]
results = await asyncio.gather(*tasks)
return results
class RepoSearch:
def __init__(self):
db_path = os.path.join(root_path, "database", "init")
embeddings = OpenAIEmbeddings(api_key=os.getenv("OPENAI_API_KEY"),
base_url=os.getenv("OPENAI_BASE_URL"),
model="text-embedding-3-small")
assert os.path.exists(db_path), f"Database not found: {db_path}"
self.vector_db = FAISS.load_local(db_path, embeddings,
index_name="init",
allow_dangerous_deserialization=True)
def search(self, query, k=10):
'''
name + description + html_url + topics
'''
results = self.vector_db.similarity_search(query + " technology", k=k)
simple_str = ""
simple_list = []
for i, doc in enumerate(results):
content = json.loads(doc.page_content)
metadata = doc.metadata
if content["description"] is None:
content["description"] = ""
# desc = content["description"] if len(content["description"]) < 300 else content["description"][:300] + "..."
simple_str += f"\t**{i+1}. {content['name']}** || {content['description']}\n" # 用于大模型匹配
simple_list.append({
"name": content["name"],
"description": content["description"],
**metadata, # 解包所有 metadata 字段
})
return simple_str, simple_list
def main():
search = RepoSearch()
llm = OurLLM(model="gpt-4o")
def respond(
prompt: str,
history,
is_llm_filter: bool = False,
is_keyword_expand: bool = False,
match_num: int = 40
):
# 1. 初始化历史记录
if not history:
history = [{"role": "system", "content": "You are a friendly chatbot"}]
history.append({"role": "user", "content": prompt})
response = {"role": "assistant", "content": ""}
yield history
# 2. 扩展用户问题关键词
if is_keyword_expand:
response["content"] = "开始扩展关键词..."
yield history + [response]
query = llm.ask_question(prompt, system_prompt=llm.get_system_prompt("keyword_expand")).content
prompt = ", ".join(json.loads(query)["keywords"])
# 3. 语义向量匹配
response["content"] = "开始语义向量匹配..."
yield history + [response]
match_str, simple_list = search.search(prompt, match_num)
# 4. 通过 LLM 评分得到最匹配的仓库索引
if not is_llm_filter:
simple_strs = [f"\t**{i+1}. {repo['name']}** [✨ {repo['star_count'] // 1000}k] || **Description:** {repo['description']} || **Url:** {repo['html_url']} \n" for i, repo in enumerate(simple_list)]
response["content"] = "".join(simple_strs)
yield history + [response]
else:
response["content"] = "开始通过 LLM 评分得到最匹配的仓库..."
yield history + [response]
query = ' ## 用户需要的仓库内容:' + prompt + '\n ## 搜索结果列表:' + match_str
out = llm.ask_question(query, system_prompt=llm.get_system_prompt("github_score")).content
matched_index = json.loads(out)["indices"]
# 5. 通过索引得到最匹配的仓库
result = [simple_list[idx-1] for idx in matched_index]
simple_strs = [f"\t**{i+1}. {repo['name']}** [✨ {repo['star_count'] // 1000}k] || **Description:** {repo['description']} || **Url:** {repo['html_url']} \n" for i, repo in enumerate(result)]
response["content"] = "".join(simple_strs)
yield history + [response]
with gr.Blocks() as demo:
gr.Markdown("## Github semantic search (基于语义的 github 仓库搜索) 🌐")
with gr.Row():
with gr.Column(scale=1):
# 添加控制参数
llm_filter = gr.Checkbox(
label="使用LLM过滤结果",
value=False,
info="是否使用 LLM 对搜索结果进行二次过滤"
)
keyword_expand = gr.Checkbox(
label="扩展关键词搜索",
value=False,
info="是否使用 LLM 扩展搜索关键词"
)
match_number = gr.Slider(
minimum=10,
maximum=100,
value=40,
step=10,
label="语义匹配数量",
info="进行语义匹配后返回的仓库数量,若使用 LLM 过滤,建议适当增加数量"
)
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Agent",
type="messages",
avatar_images=(None, "https://img1.baidu.com/it/u=2193901176,1740242983&fm=253&fmt=auto&app=138&f=JPEG?w=500&h=500"),
height="65vh"
)
prompt = gr.Textbox(max_lines=2, label="Chat Message")
# 更新submit调用,包含新的参数
prompt.submit(
respond,
[prompt, chatbot, llm_filter, keyword_expand, match_number],
[chatbot]
)
prompt.submit(lambda: "", None, [prompt])
demo.launch(share=False)
if __name__ == "__main__":
main() |