|
import base64
import os
import re
import typing as t
from mimetypes import guess_type
from pathlib import Path
from typing import Any, List, Optional

import gradio as gr
import nest_asyncio
from llama_index.core import (
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.base.response.schema import Response
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import ImageNode, MetadataMode, NodeWithScore, TextNode
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_parse import LlamaParse
|
|
|
# Patch asyncio so nested event loops are allowed; presumably needed because
# LlamaParse's synchronous API drives asyncio internally while Gradio may
# already be running a loop.
nest_asyncio.apply()


# Both services read their credentials from the environment. Fail fast with a
# clear message instead of the opaque ``TypeError`` that assigning ``None``
# into ``os.environ`` would raise, or a late authentication error.
_REQUIRED_ENV_VARS = ("OPENAI_API_KEY", "LLAMA_CLOUD_API_KEY")
_missing_env_vars = [name for name in _REQUIRED_ENV_VARS if not os.environ.get(name)]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env_vars)
    )


# Shared LlamaParse client: parses uploaded PDFs to markdown using a vendor
# multimodal model.
parser = LlamaParse(
    result_type="markdown",
    parsing_instruction="You are given a medical textbook on medicine",
    use_vendor_multimodal_model=True,
    vendor_multimodal_model_name="gpt-4o-mini-2024-07-18",
    show_progress=True,
    verbose=True,
    # Bypass LlamaParse's result cache in both directions (per the flag names).
    invalidate_cache=True,
    do_not_cache=True,
    num_workers=8,
    language="en",
)
|
|
|
|
|
def local_image_to_data_url(image_path):
    """Read an image file and return its contents as a ``data:`` URL.

    The MIME type is guessed from the file extension; when it cannot be
    guessed, ``image/png`` is used.

    Args:
        image_path: Path of the image file to embed.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.
    """
    guessed, _encoding = guess_type(image_path)
    mime = 'image/png' if guessed is None else guessed

    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()

    payload = base64.b64encode(raw_bytes).decode('utf-8')
    return f"data:{mime};base64,{payload}"
|
|
|
|
|
def get_page_number(file_name):
    """Return the page number encoded in an image file name.

    Only names ending in ``-page-<N>.jpg`` carry a page number; every other
    name yields ``0``, so such files sort ahead of numbered pages.

    >>> get_page_number("job-page-12.jpg")
    12
    >>> get_page_number("cover.png")
    0
    """
    found = re.search(r"-page-(\d+)\.jpg$", str(file_name))
    return int(found.group(1)) if found else 0
|
|
|
def _get_sorted_image_files(image_dir):
    """List the regular files directly inside *image_dir*, ordered by page.

    The order comes from ``get_page_number``. ``sorted`` is stable, so files
    with the same key keep their directory-listing order.
    """
    candidates = (entry for entry in Path(image_dir).iterdir() if entry.is_file())
    return sorted(candidates, key=get_page_number)
|
|
|
def get_text_nodes(md_json_objs, image_dir) -> t.List[TextNode]:
    """Build one ``TextNode`` per parsed page, linked to its page image.

    Args:
        md_json_objs: Results of ``parser.get_json_result``. Each entry is
            read for its ``"file_path"`` and its ``"pages"`` list, whose items
            carry the page markdown under ``"md"``.
        image_dir: Directory holding the images downloaded by
            ``parser.get_images``.

    Returns:
        A list of ``TextNode`` objects. Each node's metadata holds
        ``page_num`` (1-based position within its document) and
        ``document_name``, plus ``image_path`` when an image exists at the same
        position in the page-sorted image listing.
    """
    if not md_json_objs:
        return []

    # The directory listing does not depend on the document, so scan it once.
    # A missing directory simply means no page images are available.
    image_files = _get_sorted_image_files(image_dir) if Path(image_dir).is_dir() else []

    nodes = []
    for result in md_json_objs:
        document_name = Path(result["file_path"]).name
        for idx, page in enumerate(result["pages"]):
            metadata = {}
            # NOTE(review): pages are paired with images purely by position in
            # the page-sorted listing; if several documents share image_dir,
            # each document is matched against the same leading images.
            if idx < len(image_files):
                metadata["image_path"] = str(image_files[idx])
            metadata["page_num"] = idx + 1
            metadata["document_name"] = document_name
            nodes.append(TextNode(text=page["md"], metadata=metadata))
    return nodes
|
|
|
|
|
def upload_and_process_file(uploaded_file):
    """Parse an uploaded textbook with LlamaParse and download its images.

    Args:
        uploaded_file: Value produced by a Gradio ``File`` input. This is either
            a path string or a temp-file object whose ``.name`` is the path of
            the already-saved upload. ``None`` means nothing was uploaded.

    Returns:
        The list returned by ``parser.get_json_result``, or a user-facing
        message string when no file was provided.
    """
    if uploaded_file is None:
        return "Please upload a medical textbook (pdf)"

    # Gradio has already saved the upload to disk, so use that file in place.
    # Do not reopen it for writing: opening the same path with "wb" truncates
    # the PDF before it can be parsed.
    file_path = os.fspath(getattr(uploaded_file, "name", uploaded_file))

    md_json_objs = parser.get_json_result([file_path])

    # Called for its side effect: downloads the images into "data_images",
    # which get_text_nodes later pairs with the parsed pages.
    parser.get_images(md_json_objs, download_path="data_images")

    return md_json_objs
|
|
|
# Directory where the vector index is persisted between questions.
_STORAGE_DIR = "./storage_manuals"

# Prompt for the multimodal LLM. The query image itself is attached as an
# image document; the prompt only says whether one was supplied.
_QA_PROMPT_TMPL = """
You are a friendly medical chatbot designed to assist users by providing accurate and detailed responses to medical questions based on information from medical books.

### Context:
---------------------
{context_str}
---------------------

### Query Text:
{query_str}

### Query Image:
---------------------
{query_image_note}
---------------------

### Answer:
"""


class MultimodalQueryEngine(CustomQueryEngine):
    """Answer a query from retrieved page text plus page and query images.

    Retrieves text nodes, optionally runs node postprocessors, loads the page
    image recorded in each node's ``image_path`` metadata, and sends the
    formatted prompt together with those images to a multimodal LLM.
    """

    qa_prompt: PromptTemplate
    retriever: BaseRetriever
    multi_modal_llm: OpenAIMultiModal
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None
    # Path of an image the user supplied with the question, if any; it is sent
    # to the LLM ahead of the retrieved page images.
    query_image_path: Optional[str] = None

    def custom_query(self, query_str: str):
        """Retrieve context for *query_str* and ask the multimodal LLM."""
        nodes = self.retriever.retrieve(query_str)
        for postprocessor in self.node_postprocessors or []:
            nodes = postprocessor.postprocess_nodes(nodes, query_str=query_str)

        # Pair each retrieved page with its page image, skipping pages for
        # which no image was recorded at indexing time.
        image_nodes = [
            NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"]))
            for n in nodes
            if n.node.metadata.get("image_path")
        ]

        ctx_str = "\n\n".join(
            n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes
        )

        image_documents = [image_node.node for image_node in image_nodes]
        if self.query_image_path:
            # Send the user's image as a real image input; pasting a base64
            # data URL into the text prompt wastes (or overflows) the context
            # and is never seen as an image by the model.
            image_documents.insert(0, ImageNode(image_path=self.query_image_path))
            query_image_note = (
                "The user attached an image with this question; "
                "it is the first image provided."
            )
        else:
            query_image_note = "No image was provided with this question."

        fmt_prompt = self.qa_prompt.format(
            context_str=ctx_str,
            query_str=query_str,
            query_image_note=query_image_note,
        )

        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=image_documents,
        )

        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": nodes, "image_nodes": image_nodes},
        )


def ask_question(md_json_objs, query_text, uploaded_query_image=None):
    """Answer *query_text* from the uploaded textbook, optionally with an image.

    Args:
        md_json_objs: Parsed LlamaParse results for the knowledge base. A falsy
            value means nothing has been uploaded yet.
        query_text: The user's question.
        uploaded_query_image: Optional Gradio ``File`` value (a path string or
            a temp-file object whose ``.name`` is the saved upload's path).

    Returns:
        The answer text, or a user-facing message when input is missing.

    Side effects:
        Sets ``Settings.llm`` / ``Settings.embed_model`` globally. Creates the
        persisted index under ``_STORAGE_DIR`` on first use.
    """
    if not md_json_objs:
        return "No knowledge base loaded. Please upload a file first."
    if not query_text or not query_text.strip():
        return "Please enter a question."

    embed_model = OpenAIEmbedding(model="text-embedding-3-large")
    llm = OpenAI("gpt-4o-mini-2024-07-18")
    Settings.llm = llm
    Settings.embed_model = embed_model

    # NOTE(review): the index is built once and then always reloaded from
    # disk, so a textbook uploaded later is not indexed until _STORAGE_DIR is
    # removed.
    if not os.path.exists(_STORAGE_DIR):
        # Nodes are only needed when building a new index.
        text_nodes = get_text_nodes(md_json_objs, "data_images")
        index = VectorStoreIndex(text_nodes, embed_model=embed_model)
        index.storage_context.persist(persist_dir=_STORAGE_DIR)
    else:
        ctx = StorageContext.from_defaults(persist_dir=_STORAGE_DIR)
        index = load_index_from_storage(ctx)

    # Gradio has already saved the query image; use it in place. Reopening its
    # path with "wb" would truncate it.
    query_image_path = None
    if uploaded_query_image is not None:
        query_image_path = os.fspath(
            getattr(uploaded_query_image, "name", uploaded_query_image)
        )

    query_engine = MultimodalQueryEngine(
        qa_prompt=PromptTemplate(_QA_PROMPT_TMPL),
        retriever=index.as_retriever(),
        multi_modal_llm=OpenAIMultiModal(model="gpt-4o-mini-2024-07-18"),
        query_image_path=query_image_path,
    )
    response = query_engine.query(query_text)
    return response.response
|
|
|
|
|
# Parsed LlamaParse results of the most recent successful upload; read by the
# question tab when answering.
md_json_objs = []


def upload_wrapper(uploaded_file):
    """Gradio callback for the upload tab: parse the file and remember it.

    Only replaces the module-level ``md_json_objs`` when parsing produced
    results, so a failed or empty upload does not clobber a good knowledge
    base.

    Returns:
        A status message for the UI.
    """
    global md_json_objs
    result = upload_and_process_file(uploaded_file)
    # upload_and_process_file reports "nothing uploaded" as a message string.
    if isinstance(result, str):
        return result
    if not result:
        return "No content could be extracted from the uploaded file."
    md_json_objs = result
    return "File successfully processed!"
|
|
|
def _ask_with_uploaded_kb(query_text, uploaded_query_image=None):
    """Answer a question against the knowledge base stored by the upload tab."""
    # upload_wrapper stores its result in the module-level ``md_json_objs``.
    # Read it at call time so the latest upload is used. A gr.State input is
    # never written by the upload tab and would always arrive empty.
    return ask_question(md_json_objs, query_text, uploaded_query_image)


# Question tab. Top-level components (gr.Textbox / gr.File) replace the
# ``gr.inputs`` namespace, which was removed in Gradio 4.
iface = gr.Interface(
    fn=_ask_with_uploaded_kb,
    inputs=[
        gr.Textbox(label="Enter your query:"),
        gr.File(label="Upload a query image (if any):"),
    ],
    outputs="text",
    title="Medical Knowledge Base & Query System",
)

# Upload tab: parses the textbook and stores it for the question tab.
upload_iface = gr.Interface(
    fn=upload_wrapper,
    inputs=gr.File(label="Upload a medical textbook (pdf):"),
    outputs="text",
    title="Upload Knowledge Base",
)

app = gr.TabbedInterface([upload_iface, iface], ["Upload Knowledge Base", "Ask a Question"])

app.launch()
|
|