陈淑一
committed on
Commit 13c0d56 · 1 Parent(s): ddf861d
Update agent-model
- .DS_Store +0 -0
- .env +36 -0
- Healthcare_agent.py +213 -0
- data_base/.DS_Store +0 -0
- data_base/knowledge_db/.DS_Store +0 -0
- data_base/vector_db/.DS_Store +0 -0
- data_base/vector_db/chroma/.DS_Store +0 -0
- data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/data_level0.bin +3 -0
- data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/header.bin +3 -0
- data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/index_metadata.pickle +3 -0
- data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/length.bin +3 -0
- data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/link_lists.bin +3 -0
- data_processing.py +75 -0
- requirements.txt +9 -0
- zhipuai_embedding.py +68 -0
.DS_Store
ADDED
Binary file (6.15 kB)
.env
ADDED
@@ -0,0 +1,36 @@
# OpenAI API access key configuration
OPENAI_API_KEY = ""

# Wenxin (ERNIE) API access key configuration
# Option 1: authenticate with the application AK/SK
# API Key of the created application
QIANFAN_AK = ""
# Secret Key of the created application
QIANFAN_SK = ""
# Option 2: authenticate with the security-credential AK/SK
# Access Key obtained via security credentials
QIANFAN_ACCESS_KEY = ""
# Secret Key obtained via security credentials
QIANFAN_SECRET_KEY = ""

# Ernie SDK (Wenxin) API access token configuration
EB_ACCESS_TOKEN = ""

# APPID obtained from the console
SPARK_APPID = ""
# APIKey obtained from the console
SPARK_API_KEY = ""
# APISecret obtained from the console
SPARK_API_SECRET = ""

# Spark API access key configuration for LangChain
# APPID obtained from the console
IFLYTEK_SPARK_APP_ID = ""
# APIKey obtained from the console
IFLYTEK_SPARK_API_KEY = ""
# APISecret obtained from the console
IFLYTEK_SPARK_API_SECRET = ""

# Zhipu AI API access key configuration
ZHIPUAI_API_KEY = "c9bc35e8e7c1c076a8aaba862efb19af.DhiaibnU9Mys34de"
ZHIPUAI_API_KEY2 = "bd2f9388e369f6c46ef442556163b03c.79Jq4Gdqs9Ni9VnP"
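The scripts added in this commit read these keys through python-dotenv rather than importing the file directly. Below is a minimal sketch of that loading step, assuming the .env file sits in or above the working directory; the variable name matches the one defined above and the usage mirrors Healthcare_agent.py and data_processing.py.

import os
from dotenv import load_dotenv, find_dotenv

# find_dotenv() locates the nearest .env file; load_dotenv() exports its entries
# into the process environment so os.environ can see them.
_ = load_dotenv(find_dotenv())

zhipuai_api_key = os.environ["ZHIPUAI_API_KEY"]  # required by Healthcare_agent.py
print("Zhipu key loaded:", bool(zhipuai_api_key))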
Healthcare_agent.py
ADDED
@@ -0,0 +1,213 @@
import chromadb
from langchain import LLMChain, PromptTemplate
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
# StrOutputParser is exposed under langchain.schema, not langchain.output_parsers
from langchain.schema.output_parser import StrOutputParser
# Use the local Zhipu AI embedding wrapper shipped in this repo (same as data_processing.py)
from zhipuai_embedding import ZhipuAIEmbeddings
from langchain.vectorstores import Chroma
from diffusers import StableDiffusionPipeline
import requests
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())  # read the local .env file
zhipuai_api_key = os.environ['ZHIPUAI_API_KEY']


class HealthcareAgent:
    def __init__(self):
        self.vectordb = self.get_vectordb()
        self.llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        self.diffusion_model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

    def get_vectordb(self):
        embedding = ZhipuAIEmbeddings()
        persist_directory = '/Users/chenshuyi/Documents/agent/data_base/vector_db'
        vectordb = Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding
        )
        return vectordb

    def generate_response(self, input_text):
        output = self.llm.invoke(input_text)
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        return output

    def rag_search(self, symptoms):
        # RetrievalQA supplies the variables "context" and "question", so the prompt
        # uses {question} for the symptom description instead of a custom {symptoms} variable.
        template = """使用以下上下文来回答关于症状的问题。如果你不知道答案,就说你不知道,不要试图编造答案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说"谢谢你的提问!"。
上下文: {context}
问题: 基于这些症状 "{question}",可能是什么疾病?请列出这些疾病的其他常见症状。
回答格式:
可能的疾病: [疾病1, 疾病2, ...]
其他常见症状: [症状1, 症状2, ...]
回答:"""
        QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
        retriever = self.vectordb.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        result = qa_chain({"query": symptoms})
        return result["result"]

    def assess_severity(self, condition, symptoms):
        # As above, only "context" and "question" are available inside RetrievalQA,
        # so the condition and symptoms are folded into the question string.
        template = """使用以下上下文来评估疾病的严重程度。
上下文: {context}
疾病和症状: {question}
请根据给定的疾病和症状,评估病情的严重程度。将严重程度分为轻度、中度和重度三个等级。
同时,请给出这个评估的理由,并提供一些建议。
回答格式:
严重程度: [轻度/中度/重度]
理由: [您的解释]
建议: [您的建议]
回答:"""
        QA_CHAIN_PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )
        retriever = self.vectordb.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        result = qa_chain({"query": f"疾病: {condition} 症状: {symptoms}"})
        return result["result"]

    def generate_skin_condition_image(self, condition):
        severities = ["轻度", "中度", "重度"]
        images = []
        for severity in severities:
            prompt = f"{severity}{condition}的皮肤症状"
            image = self.diffusion_model(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
            images.append(image)
        return images

    def recommend_medical_facility(self, user_location, condition, severity):
        # First let the LLM recommend a type of medical facility
        template = """
        基于以下信息推荐合适的医疗设施类型:

        疾病: {condition}
        严重程度: {severity}

        请从以下选项中选择最合适的医疗设施类型:
        1. 药房
        2. 社区医院
        3. 二甲医院
        4. 三甲医院

        只需回复数字1-4,不需要其他解释。

        推荐:
        """

        prompt = PromptTemplate(template=template, input_variables=["condition", "severity"])
        llm_chain = LLMChain(prompt=prompt, llm=self.llm)
        facility_type = llm_chain.run(condition=condition, severity=severity).strip()

        # Map the LLM's answer to an actual facility type
        facility_types = {
            "1": "药房",
            "2": "社区医院",
            "3": "二甲医院",
            "4": "三甲医院"
        }
        recommended_type = facility_types.get(facility_type, "医院")  # fall back to a generic hospital

        # Call the AMap (Gaode) place-search API for nearby medical facilities
        amap_key = "您的高德地图API密钥"  # replace with your actual AMap API key
        url = f"https://restapi.amap.com/v3/place/text?key={amap_key}&keywords={recommended_type}&city={user_location}&offset=10&page=1&extensions=all"

        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            if data["status"] == "1" and data["pois"]:
                facilities = data["pois"]
                # Return the top three results
                top_facilities = facilities[:3]
                result = f"根据您的情况,我们推荐您去{recommended_type}。以下是附近的几个选择:\n\n"
                for facility in top_facilities:
                    result += f"名称: {facility['name']}\n"
                    result += f"地址: {facility['address']}\n"
                    result += f"电话: {facility.get('tel', '未提供')}\n\n"
                return result
            else:
                return f"抱歉,我们无法在您的位置找到合适的{recommended_type}。请考虑寻求紧急医疗帮助或咨询当地卫生部门。"
        else:
            return "抱歉,我们暂时无法获取医疗设施信息。请稍后再试或直接联系当地医疗机构。"

    def interact(self, symptoms, user_location):
        condition = self.rag_search(symptoms)

        if "皮肤" in condition:
            images = self.generate_skin_condition_image(condition)
            return condition, images, True, None  # None is a placeholder for the facility recommendation
        else:
            severity_assessment = self.assess_severity(condition, symptoms)
            severity, reason, advice = self.parse_severity_result(severity_assessment)
            facility_recommendation = self.recommend_medical_facility(user_location, condition, severity)
            return condition, (severity, reason, advice), False, facility_recommendation

    def parse_severity_result(self, result):
        # This parser must be adapted to the model's actual output format;
        # the implementation below is only an example.
        lines = result.split('\n')
        severity = ""
        reason = ""
        advice = ""
        for line in lines:
            if line.startswith("严重程度:"):
                severity = line.split(':')[1].strip()
            elif line.startswith("理由:"):
                reason = line.split(':')[1].strip()
            elif line.startswith("建议:"):
                advice = line.split(':')[1].strip()
        return severity, reason, advice


def gradio_interface():
    agent = HealthcareAgent()

    def process_input(symptoms, user_location):
        condition, result, is_skin_condition, facility_recommendation = agent.interact(symptoms, user_location)
        if is_skin_condition:
            return gr.update(visible=True, value=condition), gr.update(visible=True, value=result), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation)
        else:
            severity, reason, advice = result
            return gr.update(visible=True, value=f"诊断: {condition}\n严重程度: {severity}\n理由: {reason}\n建议: {advice}"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation)

    def on_select(evt: gr.SelectData):
        severities = ["轻度", "中度", "重度"]
        return f"您选择的严重程度为: {severities[evt.index]}"

    with gr.Blocks() as iface:
        gr.Markdown("# 医疗保健助手")
        symptoms_input = gr.Textbox(label="请描述您的症状")
        location_input = gr.Textbox(label="请输入您的位置")
        submit_btn = gr.Button("提交")

        with gr.Group() as output_group:
            text_output = gr.Textbox(label="诊断结果", visible=False)
            image_output = gr.Gallery(label="请选择最接近您症状的图片", visible=False, columns=3, height=300)
            severity_output = gr.Textbox(label="严重程度", visible=False)
            facility_output = gr.Textbox(label="推荐医疗设施", visible=False)

        submit_btn.click(process_input, inputs=[symptoms_input, location_input], outputs=[text_output, image_output, severity_output, facility_output])
        image_output.select(on_select, None, severity_output)

    return iface


if __name__ == "__main__":
    iface = gradio_interface()
    iface.launch()
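For reference, a minimal sketch of driving HealthcareAgent without the Gradio UI, assuming the Chroma store under persist_directory has already been built by data_processing.py, a CUDA GPU is available for the Stable Diffusion pipeline, and the AMap key placeholder has been replaced; the symptom text and "北京" are illustrative inputs only.

from Healthcare_agent import HealthcareAgent

agent = HealthcareAgent()

# interact() returns (condition, result, is_skin_condition, facility_recommendation)
condition, result, is_skin, facility = agent.interact("头痛、发热、咳嗽三天", "北京")

print("诊断:", condition)
if is_skin:
    # For skin conditions, result is a list of generated severity reference images
    print(f"生成了 {len(result)} 张参考图片")
else:
    severity, reason, advice = result
    print("严重程度:", severity)
    print("建议:", advice)
    print(facility)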
data_base/.DS_Store
ADDED
Binary file (6.15 kB)
data_base/knowledge_db/.DS_Store
ADDED
Binary file (6.15 kB)
data_base/vector_db/.DS_Store
ADDED
Binary file (6.15 kB)
data_base/vector_db/chroma/.DS_Store
ADDED
Binary file (6.15 kB)
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a63cbf10a7a1118c3c49522388ddeac7432c320741107e5c223df88edadbd3df
size 12708000
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/header.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6700dc0d3644287e522ceb9b6618f4a25e1491d206fc9eb3cd96d71f12b9be20
size 100
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/index_metadata.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3bcc379271d45360118c53ee26c3ed4c3d7021526a84b518c090b3dc36639d3d
size 172072
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/length.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28f97f051bcfc0a0a3cd4fda105dd685460522dc3d3b4c621d8ba5c69f489659
size 12000
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/link_lists.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bf6d0d0e5e00119db848bbe340a910c568b83b4462e53d4f072a36a4e5990c4
size 25736
data_processing.py
ADDED
@@ -0,0 +1,75 @@
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.schema import Document
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
import re
import os
from dotenv import load_dotenv, find_dotenv
# Use the self-packaged Zhipu AI embedding wrapper; the wrapper code must be available locally
from zhipuai_embedding import ZhipuAIEmbeddings

# Read the local/project environment variables.
# find_dotenv() locates the path of the .env file
# load_dotenv() reads that .env file and loads its variables into the current environment
# If the variables are already set globally, this line has no effect.
_ = load_dotenv(find_dotenv())

# Create a PyMuPDFLoader instance with the path of the PDF document to load
loader = PyMuPDFLoader("/Users/chenshuyi/Documents/agent/data_base/knowledge_db/merck.pdf")

# Call PyMuPDFLoader.load() to load the PDF file
pdf_pages = loader.load()
# print(f"载入后的变量类型为:{type(pdf_pages)},", f"该 PDF 一共包含 {len(pdf_pages)} 页")

# pdf_page = pdf_pages[1]
# print(f"每一个元素的类型:{type(pdf_page)}.",
#       f"该文档的描述性数据:{pdf_page.metadata}",
#       f"查看该文档的内容:\n{pdf_page.page_content}",
#       sep="\n------\n")

pattern = re.compile(r'[^\u4e00-\u9fff](\n)[^\u4e00-\u9fff]', re.DOTALL)

for pdf_page in pdf_pages:
    # Use the regex to drop newlines that sit between non-Chinese characters
    pdf_page.page_content = re.sub(pattern, lambda match: match.group(0).replace('\n', ''), pdf_page.page_content)

    # Remove bullet symbols
    pdf_page.page_content = pdf_page.page_content.replace('•', '')

    # Collapse consecutive double newlines into single newlines
    pdf_page.page_content = pdf_page.page_content.replace('\n\n', '\n')

# Split the documents
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500, chunk_overlap=50)

split_docs = text_splitter.split_documents(pdf_pages)
# print(f"切分后的文件数量:{len(split_docs)}")
# print(f"切分后的字符数(可以用来大致评估 token 数):{sum([len(doc.page_content) for doc in split_docs])}")

# Build the Chroma vector store
embedding = ZhipuAIEmbeddings()

# Define the persistence path
persist_directory = '../../data_base/vector_db/chroma'
# !rm -rf '../../data_base/vector_db/chroma'  # remove the old database files first (if the folder already contains data)

vectordb = Chroma.from_documents(
    documents=split_docs,
    embedding=embedding,
    persist_directory=persist_directory  # persist the store to disk under persist_directory
)
vectordb.persist()
print(f"向量库中存储的数量:{vectordb._collection.count()}")

print(f"Chroma 数据存储在: {vectordb._persist_directory}")

question = "headache"
sim_docs = vectordb.similarity_search(question, k=3)
print(f"检索到的内容数:{len(sim_docs)}")

for i, sim_doc in enumerate(sim_docs):
    print(f"检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
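Once this script has persisted the store, later runs can reopen it instead of re-embedding the PDF. A minimal sketch of that reload path, assuming the same persist_directory and the local ZhipuAIEmbeddings wrapper; this mirrors what HealthcareAgent.get_vectordb() does.

from langchain_community.vectorstores import Chroma
from zhipuai_embedding import ZhipuAIEmbeddings

# Reopen the persisted collection; the embedding function must match the one
# used at build time so that query vectors live in the same space.
vectordb = Chroma(
    persist_directory='../../data_base/vector_db/chroma',
    embedding_function=ZhipuAIEmbeddings()
)

for doc in vectordb.similarity_search("headache", k=3):
    print(doc.page_content[:100], "...")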
requirements.txt
ADDED
@@ -0,0 +1,9 @@
gradio==3.50.2
langchain==0.0.350
langchain-openai==0.0.2.post1
chromadb==0.4.22
diffusers==0.25.0
transformers==4.36.2
torch==2.1.2
requests==2.31.0
python-dotenv==1.0.0
zhipuai_embedding.py
ADDED
@@ -0,0 +1,68 @@
from __future__ import annotations

import logging
from typing import Dict, List, Any

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator

logger = logging.getLogger(__name__)


class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """`Zhipuai Embeddings` embedding models."""

    client: Any
    """`zhipuai.ZhipuAI`"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        Instantiate ZhipuAI as values["client"].

        Args:
            values (Dict): the configuration dictionary; must contain the client field.

        Returns:
            values (Dict): the configuration dictionary. If the zhipuai package is installed,
            values["client"] holds an instantiated ZhipuAI client; otherwise
            'ModuleNotFoundError: No module named zhipuai' is raised.
        """
        from zhipuai import ZhipuAI
        values["client"] = ZhipuAI()
        return values

    def embed_query(self, text: str) -> List[float]:
        """
        Generate the embedding of the input text.

        Args:
            text (str): the text to embed.

        Returns:
            embeddings (List[float]): the embedding of the input text, as a list of floats.
        """
        embeddings = self.client.embeddings.create(
            model="embedding-2",
            input=text
        )
        return embeddings.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of input texts.

        Args:
            texts (List[str]): the texts to embed.

        Returns:
            List[List[float]]: one embedding per input document, each represented as a list of floats.
        """
        return [self.embed_query(text) for text in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed search docs."""
        raise NotImplementedError("Please use `embed_documents`. The official SDK does not support asynchronous requests.")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously embed query text."""
        raise NotImplementedError("Please use `embed_query`. The official SDK does not support asynchronous requests.")
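A minimal sketch of using this wrapper on its own, assuming ZHIPUAI_API_KEY is set in the environment (the zhipuai client is expected to pick the key up when ZhipuAI() is constructed without arguments, which is what load_dotenv makes possible here); the example strings are illustrative only.

from dotenv import load_dotenv, find_dotenv
from zhipuai_embedding import ZhipuAIEmbeddings

load_dotenv(find_dotenv())  # makes ZHIPUAI_API_KEY visible to the zhipuai client

embedding = ZhipuAIEmbeddings()

# Single query string -> one vector
vec = embedding.embed_query("头痛和发热可能是什么疾病?")
print(len(vec))  # dimensionality of the embedding-2 model output

# List of documents -> list of vectors (implemented as repeated embed_query calls)
vecs = embedding.embed_documents(["症状一", "症状二"])
print(len(vecs), len(vecs[0]))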