Towhidul committed
Commit 36f2167 · verified · 1 Parent(s): 24695dd

Create app.py

Files changed (1)
app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
+ import gradio as gr
+ import os
+ import nest_asyncio
+ import re
+ from pathlib import Path
+ import typing as t
+ import base64
+ from mimetypes import guess_type
+ from llama_parse import LlamaParse
+ from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage, Settings
+ from llama_index.core.schema import TextNode, ImageNode, NodeWithScore, MetadataMode
+ from llama_index.core.retrievers import BaseRetriever
+ from llama_index.embeddings.openai import OpenAIEmbedding
+ from llama_index.llms.openai import OpenAI
+ from llama_index.core.query_engine import CustomQueryEngine
+ from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+ from llama_index.core.prompts import PromptTemplate
+ from llama_index.core.base.response.schema import Response
+ from llama_index.core.postprocessor.types import BaseNodePostprocessor
+ from typing import List, Optional
+
+ nest_asyncio.apply()
+
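+ # Overview: LlamaParse converts an uploaded PDF into per-page markdown plus
+ # page images, LlamaIndex indexes the page text, and a multimodal LLM answers
+ # queries from the retrieved text together with the matching page images.
+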
+ # API keys are expected in the environment (e.g. as Space secrets);
+ # default to "" so a missing key fails at request time, not at import time
+ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "")
+ os.environ["LLAMA_CLOUD_API_KEY"] = os.getenv("LLAMA_CLOUD_API_KEY", "")
+
+ # Initialize the LlamaParse parser
+ parser = LlamaParse(
+     result_type="markdown",
+     parsing_instruction="You are given a medical textbook.",
+     use_vendor_multimodal_model=True,
+     vendor_multimodal_model_name="gpt-4o-mini-2024-07-18",
+     show_progress=True,
+     verbose=True,
+     invalidate_cache=True,
+     do_not_cache=True,
+     num_workers=8,
+     language="en",
+ )
+
+ # Encode a local image file as a base64 data URL
+ def local_image_to_data_url(image_path):
+     mime_type, _ = guess_type(image_path)
+     if mime_type is None:
+         mime_type = "image/png"
+     with open(image_path, "rb") as image_file:
+         base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8")
+     return f"data:{mime_type};base64,{base64_encoded_data}"
+
+ # Extract the page number from an image file name such as "...-page-3.jpg"
+ def get_page_number(file_name):
+     match = re.search(r"-page-(\d+)\.jpg$", str(file_name))
+     if match:
+         return int(match.group(1))
+     return 0
+
+ # Return the page image files sorted by page number
+ def _get_sorted_image_files(image_dir):
+     raw_files = [f for f in Path(image_dir).iterdir() if f.is_file()]
+     sorted_files = sorted(raw_files, key=get_page_number)
+     return sorted_files
+
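+ # Note: the "-page-N.jpg" suffix matched above is the naming scheme LlamaParse
+ # uses for downloaded page images; get_text_nodes below relies on it to pair
+ # each page's text with its image.
+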
+ # Build one TextNode per parsed page, attaching the page image path as metadata
+ def get_text_nodes(md_json_objs, image_dir) -> t.List[TextNode]:
+     nodes = []
+     for result in md_json_objs:
+         json_dicts = result["pages"]
+         document_name = result["file_path"].split("/")[-1]
+         docs = [doc["md"] for doc in json_dicts]
+         image_files = _get_sorted_image_files(image_dir)
+         for idx, doc in enumerate(docs):
+             node = TextNode(
+                 text=doc,
+                 metadata={
+                     "image_path": str(image_files[idx]),
+                     "page_num": idx + 1,
+                     "document_name": document_name,
+                 },
+             )
+             nodes.append(node)
+     return nodes
+
+ # Gradio interface functions
+ def upload_and_process_file(uploaded_file):
+     if uploaded_file is None:
+         return None
+
+     # Gradio's File component yields a temp-file object (with .name) or a
+     # plain path string depending on the Gradio version
+     file_path = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
+
+     # Parse the PDF and download the per-page images
+     md_json_objs = parser.get_json_result([file_path])
+     parser.get_images(md_json_objs, download_path="data_images")
+
+     return md_json_objs
+
+ def ask_question(md_json_objs, query_text, uploaded_query_image=None):
+     if not md_json_objs:
+         return "No knowledge base loaded. Please upload a file first."
+
+     text_nodes = get_text_nodes(md_json_objs, "data_images")
+
+     # Set up the index and LLM
+     embed_model = OpenAIEmbedding(model="text-embedding-3-large")
+     llm = OpenAI(model="gpt-4o-mini-2024-07-18")
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+
+     # Build and persist the index on first use; reload it afterwards
+     if not os.path.exists("storage_manuals"):
+         index = VectorStoreIndex(text_nodes, embed_model=embed_model)
+         index.storage_context.persist(persist_dir="./storage_manuals")
+     else:
+         ctx = StorageContext.from_defaults(persist_dir="./storage_manuals")
+         index = load_index_from_storage(ctx)
+
+     retriever = index.as_retriever()
+
+     # Encode the query image, if provided, as a data URL
+     encoded_image_url = None
+     if uploaded_query_image is not None:
+         query_image_path = (
+             uploaded_query_image
+             if isinstance(uploaded_query_image, str)
+             else uploaded_query_image.name
+         )
+         encoded_image_url = local_image_to_data_url(query_image_path)
+
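+     # The query image travels inside the prompt as a data URL, while the
+     # retrieved page images are passed separately to the multimodal LLM via
+     # image_documents in custom_query below.
+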
+     # Set up the query engine
+     QA_PROMPT_TMPL = """
+ You are a friendly medical chatbot designed to assist users by providing accurate and detailed responses to medical questions based on information from medical books.
+
+ ### Context:
+ ---------------------
+ {context_str}
+ ---------------------
+
+ ### Query Text:
+ {query_str}
+
+ ### Query Image:
+ ---------------------
+ {encoded_image_url}
+ ---------------------
+
+ ### Answer:
+ """
+     QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)
+     gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")
+
+     class MultimodalQueryEngine(CustomQueryEngine):
+         qa_prompt: PromptTemplate
+         retriever: BaseRetriever
+         multi_modal_llm: OpenAIMultiModal
+         node_postprocessors: Optional[List[BaseNodePostprocessor]] = None
+
+         def __init__(
+             self,
+             qa_prompt: PromptTemplate,
+             retriever: BaseRetriever,
+             multi_modal_llm: OpenAIMultiModal,
+             node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
+         ):
+             super().__init__(
+                 qa_prompt=qa_prompt,
+                 retriever=retriever,
+                 multi_modal_llm=multi_modal_llm,
+                 node_postprocessors=node_postprocessors or [],
+             )
+
+         def custom_query(self, query_str: str):
+             # retrieve the most relevant text nodes
+             nodes = self.retriever.retrieve(query_str)
+
+             # create image nodes from the page images attached to those nodes
+             image_nodes = [
+                 NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"]))
+                 for n in nodes
+             ]
+
+             # create a context string from the parsed markdown text
+             ctx_str = "\n\n".join(
+                 [r.node.get_content(metadata_mode=MetadataMode.LLM).strip() for r in nodes]
+             )
+
+             # format the prompt; encoded_image_url is closed over from ask_question
+             fmt_prompt = self.qa_prompt.format(
+                 context_str=ctx_str, query_str=query_str, encoded_image_url=encoded_image_url
+             )
+
+             # let the multimodal LLM read the page images and answer the prompt
+             llm_response = self.multi_modal_llm.complete(
+                 prompt=fmt_prompt,
+                 image_documents=[image_node.node for image_node in image_nodes],
+             )
+
+             return Response(
+                 response=str(llm_response),
+                 source_nodes=nodes,
+                 metadata={"text_nodes": nodes, "image_nodes": image_nodes},
+             )
+
+     query_engine = MultimodalQueryEngine(QA_PROMPT, retriever, gpt_4o_mm)
+
+     response = query_engine.custom_query(query_text)
+     return response.response
+
+ # Define the Gradio interface; the parsed knowledge base is kept in a
+ # module-level variable shared between the two tabs
+ md_json_objs = []
+
+ def upload_wrapper(uploaded_file):
+     global md_json_objs
+     if uploaded_file is None:
+         return "Please upload a medical textbook (pdf)"
+     md_json_objs = upload_and_process_file(uploaded_file)
+     return "File successfully processed!"
+
+ def ask_wrapper(query_text, uploaded_query_image=None):
+     return ask_question(md_json_objs, query_text, uploaded_query_image)
+
+ # gr.inputs.* was removed in recent Gradio releases; use the top-level
+ # components instead (a File input left empty simply passes None)
+ iface = gr.Interface(
+     fn=ask_wrapper,
+     inputs=[
+         gr.Textbox(label="Enter your query:"),
+         gr.File(label="Upload a query image (if any):"),
+     ],
+     outputs="text",
+     title="Medical Knowledge Base & Query System",
+ )
+
+ upload_iface = gr.Interface(
+     fn=upload_wrapper,
+     inputs=gr.File(label="Upload a medical textbook (pdf):"),
+     outputs="text",
+     title="Upload Knowledge Base",
+ )
+
+ app = gr.TabbedInterface([upload_iface, iface], ["Upload Knowledge Base", "Ask a Question"])
+ app.launch()