陈淑一 committed on
Commit
13c0d56
·
1 Parent(s): ddf861d

Update agent-model

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.env ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OPENAI API 访问密钥配置
2
+ OPENAI_API_KEY = ""
3
+
4
+ # 文心 API 访问密钥配置
5
+ # 方式1. 使用应用 AK/SK 鉴权
6
+ # 创建的应用的 API Key
7
+ QIANFAN_AK = ""
8
+ # 创建的应用的 Secret Key
9
+ QIANFAN_SK = ""
10
+ # 方式2. 使用安全认证 AK/SK 鉴权
11
+ # 安全认证方式获取的 Access Key
12
+ QIANFAN_ACCESS_KEY = ""
13
+ # 安全认证方式获取的 Secret Key
14
+ QIANFAN_SECRET_KEY = ""
15
+
16
+ # Ernie SDK 文心 API 访问密钥配置
17
+ EB_ACCESS_TOKEN = ""
18
+
19
+ # 控制台中获取的 APPID 信息
20
+ SPARK_APPID = ""
21
+ # 控制台中获取的 APIKey 信息
22
+ SPARK_API_KEY = ""
23
+ # 控制台中获取的 APISecret 信息
24
+ SPARK_API_SECRET = ""
25
+
26
+ # langchain中星火 API 访问密钥配置
27
+ # 控制台中获取的 APPID 信息
28
+ IFLYTEK_SPARK_APP_ID = ""
29
+ # 控制台中获取的 APIKey 信息
30
+ IFLYTEK_SPARK_API_KEY = ""
31
+ # 控制台中获取的 APISecret 信息
32
+ IFLYTEK_SPARK_API_SECRET = ""
33
+
34
+ # 智谱 API 访问密钥配置
35
+ ZHIPUAI_API_KEY = "c9bc35e8e7c1c076a8aaba862efb19af.DhiaibnU9Mys34de"
36
+ ZHIPUAI_API_KEY2 = "bd2f9388e369f6c46ef442556163b03c.79Jq4Gdqs9Ni9VnP"
Healthcare_agent.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from langchain import LLMChain, PromptTemplate
3
+ from langchain_openai import ChatOpenAI
4
+ from langchain.chains import RetrievalQA
5
+ from langchain.output_parsers import StrOutputParser
6
+ from langchain.embeddings import ZhipuAIEmbeddings
7
+ from langchain.vectorstores import Chroma
8
+ from diffusers import StableDiffusionPipeline
9
+ import requests
10
+ import gradio as gr
11
+ import os
12
+ from dotenv import load_dotenv, find_dotenv
13
+
14
# Load variables from a local .env file, if one can be found, into os.environ.
_ = load_dotenv(find_dotenv())

# Fail fast with an actionable message rather than a bare KeyError traceback.
# The exception type is unchanged, so existing handlers keep working.
try:
    zhipuai_api_key = os.environ["ZHIPUAI_API_KEY"]
except KeyError:
    raise KeyError(
        "ZHIPUAI_API_KEY is not set; export it or define it in a .env file"
    ) from None
16
+
17
class HealthcareAgent:
    """Symptom-triage assistant.

    Combines retrieval-augmented question answering over a Chroma vector
    store, a chat model reached through an OpenAI-compatible endpoint, a
    Stable Diffusion pipeline for illustrative images, and an AMap place
    search for nearby medical facilities.
    """

    def __init__(self):
        """Create the vector store, the chat model and the image pipeline."""
        self.vectordb = self.get_vectordb()
        self.llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/",
        )

        # Function-scope import: torch is only needed here to pick a device.
        import torch

        # Fall back to the CPU instead of crashing on machines without CUDA.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.diffusion_model = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5"
        ).to(device)

    def get_vectordb(self, persist_directory=None):
        """Open the persisted Chroma vector store.

        Args:
            persist_directory: Directory holding the Chroma database. When
                omitted, the ``VECTOR_DB_DIR`` environment variable is used,
                and failing that ``data_base/vector_db/chroma`` next to this
                file.

        Returns:
            A ``Chroma`` instance that embeds with ``ZhipuAIEmbeddings``.
        """
        if persist_directory is None:
            # NOTE(review): the previous value was a user-specific absolute
            # path ending in data_base/vector_db; the index files committed
            # with this project live in data_base/vector_db/chroma. Confirm
            # that this is the intended location.
            persist_directory = os.environ.get("VECTOR_DB_DIR") or os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "data_base",
                "vector_db",
                "chroma",
            )
        return Chroma(
            persist_directory=persist_directory,
            embedding_function=ZhipuAIEmbeddings(),
        )

    def generate_response(self, input_text):
        """Send ``input_text`` to the chat model and return the parsed reply."""
        reply = self.llm.invoke(input_text)
        return StrOutputParser().invoke(reply)

    def rag_search(self, symptoms):
        """Answer "which diseases match these symptoms" using retrieval QA.

        Args:
            symptoms: Free-text symptom description, used as the query.

        Returns:
            The ``"result"`` entry of the RetrievalQA chain output.
        """
        # RetrievalQA feeds the prompt two variables, ``context`` and
        # ``question``; a differently named placeholder is never filled and
        # the chain fails with a missing-input-key error.
        template = (
            "使用以下上下文来回答关于症状的问题。如果你不知道答案,就说你不知道,"
            "不要试图编造答案。最多使用三句话。尽量使答案简明扼要。"
            '总是在回答的最后说"谢谢你的提问!"。\n'
            "上下文: {context}\n"
            '问题: 基于这些症状 "{question}",可能是什么疾病?请列出这些疾病的其他常见症状。\n'
            "回答格式:\n"
            "可能的疾病: [疾病1, 疾病2, ...]\n"
            "其他常见症状: [症状1, 症状2, ...]\n"
            "回答:"
        )
        prompt = PromptTemplate(
            input_variables=["context", "question"], template=template
        )
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=self.vectordb.as_retriever(),
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )
        return qa_chain({"query": symptoms})["result"]

    def assess_severity(self, condition, symptoms):
        """Grade the severity of ``condition`` given ``symptoms``.

        Args:
            condition: Disease description, typically from ``rag_search``.
            symptoms: Free-text symptom description.

        Returns:
            The raw model reply; ``parse_severity_result`` extracts its fields.
        """
        template = (
            "使用以下上下文来评估疾病的严重程度。\n"
            "上下文: {context}\n"
            "疾病: {condition}\n"
            "症状: {symptoms}\n"
            "请根据给定的疾病和症状,评估病情的严重程度。将严重程度分为轻度、中度和重度三个等级。\n"
            "同时,请给出这个评估的理由,并提供一些建议。\n"
            "回答格式:\n"
            "严重程度: [轻度/中度/重度]\n"
            "理由: [您的解释]\n"
            "建议: [您的建议]\n"
            "回答:"
        )
        prompt = PromptTemplate(
            input_variables=["context", "condition", "symptoms"],
            template=template,
        )
        # RetrievalQA forwards only ``context`` and ``question`` to the prompt,
        # so ``condition`` and ``symptoms`` could never be filled in. Retrieve
        # the documents explicitly and pass all three variables ourselves.
        query = f"{condition} {symptoms}"
        documents = self.vectordb.as_retriever().get_relevant_documents(query)
        context = "\n\n".join(doc.page_content for doc in documents)
        chain = LLMChain(prompt=prompt, llm=self.llm)
        return chain.run(context=context, condition=condition, symptoms=symptoms)

    def generate_skin_condition_image(self, condition):
        """Generate one image per severity level for ``condition``.

        Returns:
            A list of three images, ordered mild, moderate, severe.
        """
        severities = ["轻度", "中度", "重度"]
        return [
            self.diffusion_model(
                f"{severity}{condition}的皮肤症状",
                num_inference_steps=50,
                guidance_scale=7.5,
            ).images[0]
            for severity in severities
        ]

    def recommend_medical_facility(self, user_location, condition, severity):
        """Recommend a facility type and list up to three nearby matches.

        The chat model picks a facility type (1-4); the AMap place-text API
        is then searched for that type in ``user_location``.

        Args:
            user_location: City passed to AMap as the ``city`` parameter.
            condition: Disease description.
            severity: Severity label.

        Returns:
            A user-facing message: the facility list, a "nothing found"
            notice, or a "temporarily unavailable" notice when the lookup
            cannot be performed. The AMap key is read from ``AMAP_API_KEY``.
        """
        template = (
            "\n"
            "基于以下信息推荐合适的医疗设施类型:\n"
            "\n"
            "疾病: {condition}\n"
            "严重程度: {severity}\n"
            "\n"
            "请从以下选项中选择最合适的医疗设施类型:\n"
            "1. 药房\n"
            "2. 社区医院\n"
            "3. 二甲医院\n"
            "4. 三甲医院\n"
            "\n"
            "只需回复数字1-4,不需要其他解释。\n"
            "\n"
            "推荐:\n"
        )
        prompt = PromptTemplate(
            template=template, input_variables=["condition", "severity"]
        )
        reply = LLMChain(prompt=prompt, llm=self.llm).run(
            condition=condition, severity=severity
        )

        facility_types = {
            "1": "药房",
            "2": "社区医院",
            "3": "二甲医院",
            "4": "三甲医院",
        }
        # Models often answer "3." or "推荐: 3"; take the first valid digit
        # rather than requiring the reply to be exactly one character.
        choice = next((ch for ch in reply if ch in facility_types), None)
        recommended_type = facility_types.get(choice, "医院")

        unavailable = "抱歉,我们暂时无法获取医疗设施信息。请稍后再试或直接联系当地医疗机构。"

        # The key comes from the environment, not from the source code.
        amap_key = os.environ.get("AMAP_API_KEY")
        if not amap_key:
            return unavailable

        try:
            # ``params`` URL-encodes the Chinese keywords and the user input;
            # the timeout keeps the UI from hanging on a stalled connection.
            response = requests.get(
                "https://restapi.amap.com/v3/place/text",
                params={
                    "key": amap_key,
                    "keywords": recommended_type,
                    "city": user_location,
                    "offset": 10,
                    "page": 1,
                    "extensions": "all",
                },
                timeout=10,
            )
        except requests.RequestException:
            return unavailable
        if response.status_code != 200:
            return unavailable
        try:
            data = response.json()
        except ValueError:
            return unavailable

        pois = data.get("pois") or []
        if data.get("status") != "1" or not pois:
            return (
                f"抱歉,我们无法在您的位置找到合适的{recommended_type}。"
                "请考虑寻求紧急医疗帮助或咨询当地卫生部门。"
            )

        parts = [
            f"根据您的情况,我们推荐您去{recommended_type}。以下是附近的几个选择:\n\n"
        ]
        for facility in pois[:3]:
            # Missing or empty fields fall back to the "not provided" label.
            parts.append(f"名称: {facility.get('name') or '未提供'}\n")
            parts.append(f"地址: {facility.get('address') or '未提供'}\n")
            parts.append(f"电话: {facility.get('tel') or '未提供'}\n\n")
        return "".join(parts)

    def interact(self, symptoms, user_location):
        """Run the full triage flow for one user request.

        Returns:
            A 4-tuple ``(condition, payload, is_skin_condition, facility)``.
            For skin conditions ``payload`` is the list of generated images
            and ``facility`` is ``None``; otherwise ``payload`` is
            ``(severity, reason, advice)`` and ``facility`` is the
            recommendation text.
        """
        condition = self.rag_search(symptoms)

        if "皮肤" in condition:
            images = self.generate_skin_condition_image(condition)
            # None is the placeholder for the facility recommendation.
            return condition, images, True, None

        severity_assessment = self.assess_severity(condition, symptoms)
        severity, reason, advice = self.parse_severity_result(severity_assessment)
        facility_recommendation = self.recommend_medical_facility(
            user_location, condition, severity
        )
        return condition, (severity, reason, advice), False, facility_recommendation

    def parse_severity_result(self, result):
        """Extract severity, reason and advice from the model reply.

        Each field is expected on its own line as ``label: value``. Both the
        ASCII colon and the full-width colon are accepted, and everything
        after the first colon is kept, so values may contain colons.

        Returns:
            ``(severity, reason, advice)``; a field that is not found is ``""``.
        """
        fields = {"严重程度": "", "理由": "", "建议": ""}
        for raw_line in result.splitlines():
            line = raw_line.strip()
            for label in fields:
                if not line.startswith(label):
                    continue
                rest = line[len(label):].lstrip()
                if rest[:1] in (":", "："):
                    fields[label] = rest[1:].strip()
                break
        return fields["严重程度"], fields["理由"], fields["建议"]
178
+
179
def gradio_interface():
    """Build the Gradio Blocks UI around a ``HealthcareAgent``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    agent = HealthcareAgent()
    severity_labels = ("轻度", "中度", "重度")

    def process_input(symptoms, user_location):
        """Run the agent and map its result onto the four output widgets."""
        condition, result, is_skin_condition, facility_recommendation = agent.interact(
            symptoms, user_location
        )
        facility_update = gr.update(visible=True, value=facility_recommendation)

        if not is_skin_condition:
            severity, reason, advice = result
            summary = f"诊断: {condition}\n严重程度: {severity}\n理由: {reason}\n建议: {advice}"
            return (
                gr.update(visible=True, value=summary),
                gr.update(visible=False),
                gr.update(visible=False),
                facility_update,
            )

        # Skin condition: show the gallery so the user can pick an image.
        return (
            gr.update(visible=True, value=condition),
            gr.update(visible=True, value=result),
            gr.update(visible=False),
            facility_update,
        )

    def on_select(evt: gr.SelectData):
        """Translate the selected gallery index into a severity label."""
        return f"您选择的严重程度为: {severity_labels[evt.index]}"

    with gr.Blocks() as iface:
        gr.Markdown("# 医疗保健助手")
        symptoms_input = gr.Textbox(label="请描述您的症状")
        location_input = gr.Textbox(label="请输入您的位置")
        submit_btn = gr.Button("提交")

        with gr.Group():
            text_output = gr.Textbox(label="诊断结果", visible=False)
            image_output = gr.Gallery(
                label="请选择最接近您症状的图片",
                visible=False,
                columns=3,
                height=300,
            )
            severity_output = gr.Textbox(label="严重程度", visible=False)
            facility_output = gr.Textbox(label="推荐医疗设施", visible=False)

        submit_btn.click(
            process_input,
            inputs=[symptoms_input, location_input],
            outputs=[text_output, image_output, severity_output, facility_output],
        )
        image_output.select(on_select, None, severity_output)

    return iface
210
+
211
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    gradio_interface().launch()
data_base/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data_base/knowledge_db/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data_base/vector_db/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data_base/vector_db/chroma/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a63cbf10a7a1118c3c49522388ddeac7432c320741107e5c223df88edadbd3df
3
+ size 12708000
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6700dc0d3644287e522ceb9b6618f4a25e1491d206fc9eb3cd96d71f12b9be20
3
+ size 100
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bcc379271d45360118c53ee26c3ed4c3d7021526a84b518c090b3dc36639d3d
3
+ size 172072
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28f97f051bcfc0a0a3cd4fda105dd685460522dc3d3b4c621d8ba5c69f489659
3
+ size 12000
data_base/vector_db/chroma/4a760640-7f28-4921-b9eb-107dd81a30e2/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bf6d0d0e5e00119db848bbe340a910c568b83b4462e53d4f072a36a4e5990c4
3
+ size 25736
data_processing.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import PyMuPDFLoader
2
+ from langchain_community.document_loaders import UnstructuredMarkdownLoader
3
+ from langchain.schema import Document
4
+ from langchain_community.embeddings import OpenAIEmbeddings
5
+ from langchain_community.vectorstores import Chroma
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ import re
8
+ import os
9
+ from dotenv import load_dotenv, find_dotenv
10
+ # 使用自己封装的智谱 Embedding,需要将封装代码下载到本地使用
11
+ from zhipuai_embedding import ZhipuAIEmbeddings
12
+
13
# Load local/project environment variables. find_dotenv() locates the .env
# file and load_dotenv() loads it into the current process environment. If the
# variables are already set globally, this line has no effect.
_ = load_dotenv(find_dotenv())

# Resolve data locations relative to this script, not to one developer's home
# directory or to the current working directory. Both can be overridden
# through the environment.
_base_dir = os.path.dirname(os.path.abspath(__file__))
_pdf_path = os.environ.get("KNOWLEDGE_PDF_PATH") or os.path.join(
    _base_dir, "data_base", "knowledge_db", "merck.pdf"
)

# Load the PDF with PyMuPDFLoader.
loader = PyMuPDFLoader(_pdf_path)
pdf_pages = loader.load()

# A newline whose neighbours on both sides are non-CJK characters.
pattern = re.compile(r'[^\u4e00-\u9fff](\n)[^\u4e00-\u9fff]', re.DOTALL)

for pdf_page in pdf_pages:
    # Drop the newline inside every match of the pattern above.
    pdf_page.page_content = pattern.sub(
        lambda match: match.group(0).replace('\n', ''), pdf_page.page_content
    )

    # Remove bullet characters.
    pdf_page.page_content = pdf_page.page_content.replace('•', '')

    # Replace each pair of consecutive newlines with a single newline.
    pdf_page.page_content = pdf_page.page_content.replace('\n\n', '\n')

# Split the pages into overlapping chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500, chunk_overlap=50)

split_docs = text_splitter.split_documents(pdf_pages)

# Build the Chroma vector store.
embedding = ZhipuAIEmbeddings()

# Directory the vector store is persisted to.
# NOTE(review): this used to be '../../data_base/vector_db/chroma', resolved
# against the working directory; confirm the script-relative location below
# is where the application expects to find the index.
persist_directory = os.environ.get("VECTOR_DB_DIR") or os.path.join(
    _base_dir, "data_base", "vector_db", "chroma"
)

vectordb = Chroma.from_documents(
    documents=split_docs,
    embedding=embedding,
    persist_directory=persist_directory,  # write the store to this directory
)
vectordb.persist()
print(f"向量库中存储的数量:{vectordb._collection.count()}")

print(f"Chroma 数据存储在: {persist_directory}")

# Smoke-test the index with a sample similarity search.
question = "headache"
sim_docs = vectordb.similarity_search(question, k=3)
print(f"检索到的内容数:{len(sim_docs)}")

for i, sim_doc in enumerate(sim_docs):
    print(f"检索到的第{i}个内容: \n{sim_doc.page_content[:200]}", end="\n--------------\n")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.50.2
2
+ langchain==0.0.350
3
+ langchain-openai==0.0.2.post1
4
+ chromadb==0.4.22
5
+ diffusers==0.25.0
6
+ transformers==4.36.2
7
+ torch==2.1.2
8
+ requests==2.31.0
9
+ python-dotenv==1.0.0
zhipuai_embedding.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Dict, List, Any
5
+
6
+
7
+ from langchain.embeddings.base import Embeddings
8
+ from langchain.pydantic_v1 import BaseModel, root_validator
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` implementation backed by ZhipuAI."""

    # The zhipuai.ZhipuAI client; populated by validate_environment.
    client: Any

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Instantiate ``zhipuai.ZhipuAI`` and store it as ``values["client"]``.

        Args:
            values: The model's field values.

        Returns:
            ``values`` with ``"client"`` set to a new ``ZhipuAI`` instance.

        Raises:
            ModuleNotFoundError: If the ``zhipuai`` package is not installed.
        """
        # Imported lazily so the dependency is only needed on instantiation.
        from zhipuai import ZhipuAI

        values["client"] = ZhipuAI()
        return values

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text.

        Args:
            text: The text to embed.

        Returns:
            The embedding as a list of floats.
        """
        embeddings = self.client.embeddings.create(
            model="embedding-2",
            input=text,
        )
        return embeddings.data[0].embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per input text.

        Args:
            texts: The texts to embed.

        Returns:
            A list of embeddings in input order, each a list of floats. One
            ``embed_query`` call is made per text.
        """
        return [self.embed_query(text) for text in texts]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous document embedding; not supported."""
        raise NotImplementedError("Please use `embed_documents`. Official does not support asynchronous requests")

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous query embedding; not supported."""
        # Point callers at the synchronous method, not back at this one.
        raise NotImplementedError("Please use `embed_query`. Official does not support asynchronous requests")