Saif Rehman Nasir committed
Commit
58c81e4
1 Parent(s): 430df58

Add Graph Retriever and Generator code, Add input data, Update requirements

Files changed (4)
  1. app.py +34 -24
  2. diseases.pdf +0 -0
  3. rag.py +310 -0
  4. requirements.txt +9 -1
app.py CHANGED
@@ -1,9 +1,13 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
+import os
+from rag import local_retriever, global_retriever
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
+
+
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
@@ -11,33 +15,37 @@ def respond(
     message,
     history: list[tuple[str, str]],
     system_message,
-    max_tokens,
-    temperature,
+    search_strategy,
     top_p,
 ):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
+    if search_strategy == "Global":
+        return global_retriever(message, 2, "multiple paragraphs")
+    else:
+        messages = [{"role": "system", "content": system_message}]
+
+        for val in history:
+            if val[0]:
+                messages.append({"role": "user", "content": val[0]})
+            if val[1]:
+                messages.append({"role": "assistant", "content": val[1]})
+
+        messages.append({"role": "user", "content": message})
+
+        response = ""
+
+        for message in client.chat_completion(
+            messages,
+            max_tokens=2048,
+            stream=True,
+            temperature=1.0,
+            top_p=top_p,
+        ):
+            token = message.choices[0].delta.content
+
+            response += token
+
+        return response
+
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -45,9 +53,11 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Textbox(
+            value="You are a medical assistant Chatbot. For any query that you don't know, you will say 'I don't know'. You will answer with the given information:",
+            label="System message",
+        ),
+        gr.Dropdown(choices=["Local", "Global"], label="Select search strategy"),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
@@ -60,4 +70,4 @@ demo = gr.ChatInterface(
 
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
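
A minimal sketch (not part of the commit) of how gr.ChatInterface wires these widgets into respond(): the additional_inputs values are passed positionally after message and history, so the Dropdown selection arrives as search_strategy. The handler below is a stand-in, and the Slider defaults are assumed from the stock Gradio template:

import gradio as gr

def respond(message, history, system_message, search_strategy, top_p):
    # additional_inputs arrive in order: Textbox, Dropdown, Slider.
    if search_strategy == "Global":
        return f"(global graph answer for: {message})"
    return f"(chat answer for: {message}, system={system_message!r}, top_p={top_p})"

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a medical assistant Chatbot.", label="System message"),
        gr.Dropdown(choices=["Local", "Global"], label="Select search strategy"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()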
diseases.pdf ADDED
Binary file (376 kB).
 
rag.py ADDED
@@ -0,0 +1,310 @@
+import os
+from neo4j import GraphDatabase, Result
+import pandas as pd
+import numpy as np
+
+from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+from langchain_community.graphs import Neo4jGraph
+from langchain_community.vectorstores import Neo4jVector
+
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.output_parsers import StrOutputParser
+
+from langchain_huggingface import HuggingFaceEndpoint
+
+from typing import Dict, Any
+from tqdm import tqdm
+
+NEO4J_URI = os.getenv("NEO4J_URI")
+NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
+NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
+vector_index = os.getenv("VECTOR_INDEX")
+
+chat_llm = HuggingFaceEndpoint(
+    repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+    task="text-generation",
+    max_new_tokens=100,
+    do_sample=False,
+)
+
+
+def local_retriever(query: str):
+    topChunks = 3
+    topCommunities = 3
+    topOutsideRels = 10
+    topInsideRels = 10
+    topEntities = 10
+
+    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+    try:
+        lc_retrieval_query = """
+WITH collect(node) as nodes
+// Entity - Text Unit Mapping
+WITH
+collect {
+    UNWIND nodes as n
+    MATCH (n)<-[:HAS_ENTITY]->(c:__Chunk__)
+    WITH c, count(distinct n) as freq
+    RETURN c.text AS chunkText
+    ORDER BY freq DESC
+    LIMIT $topChunks
+} AS text_mapping,
+// Entity - Report Mapping
+collect {
+    UNWIND nodes as n
+    MATCH (n)-[:IN_COMMUNITY]->(c:__Community__)
+    WITH c, c.rank as rank, c.weight AS weight
+    RETURN c.summary
+    ORDER BY rank, weight DESC
+    LIMIT $topCommunities
+} AS report_mapping,
+// Outside Relationships
+collect {
+    UNWIND nodes as n
+    MATCH (n)-[r:RELATED]-(m)
+    WHERE NOT m IN nodes
+    RETURN r.description AS descriptionText
+    ORDER BY r.rank, r.weight DESC
+    LIMIT $topOutsideRels
+} as outsideRels,
+// Inside Relationships
+collect {
+    UNWIND nodes as n
+    MATCH (n)-[r:RELATED]-(m)
+    WHERE m IN nodes
+    RETURN r.description AS descriptionText
+    ORDER BY r.rank, r.weight DESC
+    LIMIT $topInsideRels
+} as insideRels,
+// Entities description
+collect {
+    UNWIND nodes as n
+    RETURN n.description AS descriptionText
+} as entities
+// We don't have covariates or claims here
+RETURN {Chunks: text_mapping, Reports: report_mapping,
+    Relationships: outsideRels + insideRels,
+    Entities: entities} AS text, 1.0 AS score, {} AS metadata
+"""
+
+        embedding_model_name = "nomic-ai/nomic-embed-text-v1"
+        embedding_model_kwargs = {"device": "cpu", "trust_remote_code": True}
+        encode_kwargs = {"normalize_embeddings": True}
+        embedding_model = HuggingFaceBgeEmbeddings(
+            model_name=embedding_model_name,
+            model_kwargs=embedding_model_kwargs,
+            encode_kwargs=encode_kwargs,
+        )
+
+        lc_vector = Neo4jVector.from_existing_index(
+            embedding_model,
+            url=NEO4J_URI,
+            username=NEO4J_USERNAME,
+            password=NEO4J_PASSWORD,
+            index_name=vector_index,
+            retrieval_query=lc_retrieval_query,
+        )
+        docs = lc_vector.similarity_search(
+            query,
+            k=topEntities,
+            params={
+                "topChunks": topChunks,
+                "topCommunities": topCommunities,
+                "topOutsideRels": topOutsideRels,
+                "topInsideRels": topInsideRels,
+            },
+        )
+
+        return docs[0]
+    except Exception as err:
+        return f"Error: {err}"
+    finally:
+        try:
+            driver.close()
+        except Exception as e:
+            print(f"Error closing driver: {e}")
+
+
+def global_retriever(query: str, level: int, response_type: str):
+    MAP_SYSTEM_PROMPT = """
+---Role---
+
+You are a helpful assistant responding to questions about data in the tables provided.
+
+---Goal---
+
+Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
+
+You should use the data provided in the data tables below as the primary context for generating the response.
+If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+
+Each key point in the response should have the following element:
+- Description: A comprehensive description of the point.
+- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
+
+The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+
+Points supported by data should list the relevant reports as references as follows:
+"This is an example sentence supported by data references [Data: Reports (report ids)]"
+
+**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+
+where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables.
+
+Do not include information where the supporting evidence for it is not provided. Always start with {{ and end with }}.
+
+The response can only be JSON formatted. Do not add any text before or after the JSON-formatted string in the output.
+
+The response should adhere to the following format:
+{{
+    "points": [
+        {{"description": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}},
+        {{"description": "Description of point 2 [Data: Reports (report ids)]", "score": score_value}}
+    ]
+}}
+
+---Data tables---
+
+"""
+    map_prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                MAP_SYSTEM_PROMPT,
+            ),
+            ("system", "{context_data}"),
+            (
+                "human",
+                "{question}",
+            ),
+        ]
+    )
+
+    map_chain = map_prompt | chat_llm | StrOutputParser()
+
+    REDUCE_SYSTEM_PROMPT = """
+---Role---
+
+You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts.
+
+
+---Goal---
+
+Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
+
+Note that the analysts' reports provided below are ranked in the **descending order of importance**.
+
+If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+
+The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
+
+Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
+
+The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+
+The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process.
+
+**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+
+where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+
+---Target response length and format---
+
+{response_type}
+
+
+---Analyst Reports---
+
+{report_data}
+
+
+---Goal---
+
+Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
+
+Note that the analysts' reports provided below are ranked in the **descending order of importance**.
+
+If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
+
+The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format.
+
+The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+
+The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process.
+
+**Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
+
+For example:
+
+"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Reports (2, 7, 34, 46, 64, +more)]. He is also CEO of company X [Data: Reports (1, 3)]"
+
+where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record.
+
+Do not include information where the supporting evidence for it is not provided.
+
+
+---Target response length and format---
+
+{response_type}
+
+Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
+"""
+
+    reduce_prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                REDUCE_SYSTEM_PROMPT,
+            ),
+            (
+                "human",
+                "{question}",
+            ),
+        ]
+    )
+
+    reduce_chain = reduce_prompt | chat_llm | StrOutputParser()
+
+    graph = Neo4jGraph(
+        url=NEO4J_URI,
+        username=NEO4J_USERNAME,
+        password=NEO4J_PASSWORD,
+        refresh_schema=False,
+    )
+
+    community_data = graph.query(
+        """
+        MATCH (c:__Community__)
+        WHERE c.level = $level
+        RETURN c.full_content AS output
+        """,
+        params={"level": level},
+    )
+    # print(community_data)
+    intermediate_results = []
+    i = 0
+    for community in tqdm(community_data[:10], desc="Processing communities"):
+        intermediate_response = map_chain.invoke(
+            {"question": query, "context_data": community["output"]}
+        )
+        intermediate_results.append(intermediate_response)
+        i += 1
+
+    final_response = reduce_chain.invoke(
+        {
+            "report_data": intermediate_results,
+            "question": query,
+            "response_type": response_type,
+        }
+    )
+    return final_response
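
A minimal sketch (not part of the commit) of exercising both retrievers directly. It assumes NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD and VECTOR_INDEX are set in the environment before rag is imported, a Hugging Face token is available for the HuggingFaceEndpoint, and the GraphRAG-style graph (__Entity__, __Chunk__ and __Community__ nodes) is already populated; the query strings are illustrative:

from rag import local_retriever, global_retriever

# Local search: embed the query, look up the top entities in the vector index,
# then expand each hit to its chunks, community summaries and relationships via
# the Cypher retrieval query above. Returns the top Document, or an error string.
print(local_retriever("What are the symptoms of Lyme disease?"))

# Global search: the map chain scores each level-2 community report against the
# question, then the reduce chain merges the intermediate answers into one response.
print(
    global_retriever(
        "Which diseases does the document cover?",
        level=2,
        response_type="multiple paragraphs",
    )
)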
requirements.txt CHANGED
@@ -1 +1,9 @@
-huggingface_hub==0.22.2
+huggingface_hub==0.22.2
+sentence_transformers
+numpy
+pandas
+neo4j
+langchain_community
+langchain_core
+langchain_huggingface
+tqdm