import os
import uuid
import base64
import hashlib
import getpass
import logging
import threading
import concurrent.futures
from pathlib import Path

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import gradio as gr
import clip
import open_clip
from PIL import Image
from tqdm import tqdm
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from open_clip import create_model_from_pretrained, get_tokenizer
from sentence_transformers import util
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
)
from google.generativeai import GenerativeModel, configure
import google.generativeai as genai

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.llms import LlamaCpp
from langchain_community.document_loaders import DirectoryLoader, UnstructuredFileLoader

from llama_index.core import (
    Document,
    Document as LlamaIndexDocument,
    ServiceContext,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    set_global_service_context,
)
from llama_index.core.schema import ImageDocument
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.embeddings.langchain import LangchainEmbedding
from llama_index.llms.nvidia import NVIDIA
from llama_index.vector_stores.qdrant import QdrantVectorStore
import qdrant_client

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from unstructured.partition.pdf import partition_pdf

# Configure logging
# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


class MetadataMode:
    EMBED = "embed"
    INLINE = "inline"
    NONE = "none"


# Define the vectors configuration
vectors_config = {
    "vector_size": 768,  # or whatever the dimensionality of your vectors is
    "distance": "Cosine",  # can be "Cosine", "Euclidean", etc.
}
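# Illustrative sketch (an assumption, not part of the original pipeline): the dict above
# is never passed to Qdrant directly; qdrant-client expects its typed VectorParams config.
# A hypothetical helper like this could create a matching collection, given a QdrantClient.
from qdrant_client.models import VectorParams, Distance


def ensure_collection(client, name):
    """Create `name` with 768-dim cosine vectors if it does not already exist (hypothetical helper)."""
    try:
        client.get_collection(name)
    except Exception:
        client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=vectors_config["vector_size"], distance=Distance.COSINE),
        )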
class ClinicalBertEmbeddingWrapper:
    """Thin wrapper that exposes mean-pooled ClinicalBERT embeddings."""

    def __init__(self, model_name: str = "medicalai/ClinicalBERT"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def embed(self, text: str):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = self.mean_pooling(outputs, inputs["attention_mask"])
        return embeddings.squeeze().tolist()

    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def embed_documents(self, texts):
        return [self.embed(text) for text in texts]

    def embed_query(self, text):
        return self.embed(text)

    def get_text_embedding_batch(self, text_batch, show_progress=False):
        embeddings = []
        num_texts = len(text_batch)
        batch_size = 8  # process texts in batches of 8
        for i in tqdm(range(0, num_texts, batch_size), desc="Processing Batches", disable=not show_progress):
            batch_texts = text_batch[i:i + batch_size]
            batch_embeddings = self.embed_documents(batch_texts)
            embeddings.extend(batch_embeddings)
        return embeddings

    def get_agg_embedding_from_queries(self, queries):
        # Embed each query, then aggregate by averaging the embeddings
        embeddings = [torch.tensor(self.embed(query)) for query in queries]
        embeddings_tensor = torch.stack(embeddings)
        agg_embedding = embeddings_tensor.mean(dim=0)
        return agg_embedding.tolist()


# Load environment variables
load_dotenv()
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

nvidia_api_key = os.getenv("NVIDIA_API_KEY")
if not nvidia_api_key:
    raise ValueError("NVIDIA_API_KEY not found in .env file")
os.environ["NVIDIA_API_KEY"] = nvidia_api_key

model_name = "aaditya/OpenBioLLM-Llama3-8B-GGUF"
model_file = "openbiollm-llama3-8b.Q5_K_M.gguf"

QDRANT_URL = "https://f1e9a70a-afb9-498d-b66d-cb248e0d5557.us-east4-0.gcp.cloud.qdrant.io:6333"
QDRANT_API_KEY = "REXlX_PeDvCoXeS9uKCzC--e3-LQV0lw3_jBTdcLZ7P5_F6EOdwklA"

# Download the local GGUF model
model_path = hf_hub_download(model_name, filename=model_file, local_dir="./")

llm = NVIDIA(model="writer/palmyra-med-70b")
print(llm.model)  # show which hosted model is in use

local_llm = "openbiollm-llama3-8b.Q5_K_M.gguf"

# Initialize the ClinicalBERT text embedding model
# text_embed_model = ClinicalBertEmbeddings(model_name="medicalai/ClinicalBERT")
text_embed_model = ClinicalBertEmbeddingWrapper(model_name="medicalai/ClinicalBERT")

# Initially OpenBioLLM was used for text generation as well, but for faster responses
# during inference an externally hosted model is used instead; the local model also works fine.
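# Quick sanity check (illustrative sketch, safe to remove): the Qdrant collections used
# below are assumed to hold 768-dimensional cosine vectors, matching vectors_config above.
# The probe text is arbitrary.
_probe = text_embed_model.embed_query("canine dental health")
assert len(_probe) == vectors_config["vector_size"], (
    f"Embedding dim {len(_probe)} does not match configured size {vectors_config['vector_size']}"
)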
llm1 = LlamaCpp(
    model_path=local_llm,
    temperature=0.3,
    n_ctx=2048,
    top_p=1,
)

Settings.llm = llm
Settings.embed_model = text_embed_model

# Define the ServiceContext with ClinicalBERT embeddings for text
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=text_embed_model,  # use the ClinicalBERT embedding model
)
set_global_service_context(service_context)

# Just for logging and debugging
# logging.debug(f"LLM: {service_context.llm}")
# logging.debug(f"Embed Model: {service_context.embed_model}")
# logging.debug(f"Node Parser: {service_context.node_parser}")
# logging.debug(f"Prompt Helper: {service_context.prompt_helper}")

# Create a QdrantClient connected to the hosted Qdrant cluster
try:
    text_client = qdrant_client.QdrantClient(
        url=QDRANT_URL,
        api_key=QDRANT_API_KEY,
        port=443,
    )
    print("Qdrant client initialized successfully.")
except Exception as e:
    print(f"Error initializing Qdrant client: {e}")
    raise

# Load text documents from the ./Data directory
# text_documents = SimpleDirectoryReader("./Data").load_data()
loader = DirectoryLoader("./Data/", glob="**/*.pdf", show_progress=True, loader_cls=UnstructuredFileLoader)
documents = loader.load()

# Print document names
for doc in documents:
    print(f"Processing document: {doc.metadata.get('source', 'Unknown')}")

# Split documents into chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=700, chunk_overlap=70)
texts = text_splitter.split_documents(documents)

print(f"Loaded {len(documents)} documents")
print(f"Split into {len(texts)} chunks")

# Convert LangChain documents into LlamaIndex documents
text_documents = [
    LlamaIndexDocument(text=t.page_content, metadata=t.metadata)
    for t in texts
]

# Initialize the Qdrant vector store for text
try:
    text_vector_store = QdrantVectorStore(
        client=text_client,
        collection_name="pdf_text",
    )
    print("Qdrant text vector store initialized successfully.")
except Exception as e:
    print(f"Error initializing Qdrant text vector store: {e}")
    raise

# Initialize the Qdrant vector store for images
try:
    image_vector_store = QdrantVectorStore(
        client=text_client,
        collection_name="pdf_img",
    )
    print("Qdrant image vector store initialized successfully.")
except Exception as e:
    print(f"Error initializing Qdrant image vector store: {e}")
    raise

storage_context = StorageContext.from_defaults(vector_store=text_vector_store)

wiki_text_index = VectorStoreIndex.from_documents(
    text_documents,
    # storage_context=storage_context,
    service_context=service_context,
)
print(f"VectorStoreIndex created with {len(wiki_text_index.docstore.docs)} documents")

# Define the streaming query engine for text answers
streaming_qe = wiki_text_index.as_query_engine(streaming=True)

print(len(text_documents))

# Load the CLIP model used for image embeddings and text-to-image retrieval
model, preprocess = clip.load("ViT-B/32")
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print(
    "Model parameters:",
    f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}",
)
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
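# Optional smoke test (illustrative only; the query string is an arbitrary example).
# Uncomment to exercise the streaming text engine before building the image pipeline below.
# _probe_response = streaming_qe.query("What are the early signs of gingivitis in dogs?")
# _probe_response.print_response_stream()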
# Directories for source PDFs and extracted images
pdf_directory = Path("./data")
image_path = Path("./images1")
image_path.mkdir(exist_ok=True, parents=True)

# Dictionary to store image metadata
image_metadata_dict = {}

# Limit the number of extracted images that get summarized and indexed
MAX_IMAGES_PER_PDF = 15

# Iterate over each PDF file in the data folder and extract its images
for pdf_file in pdf_directory.glob("*.pdf"):
    images_per_pdf = 0  # not currently used; the limit below is applied to the total
    print(f"Processing: {pdf_file}")
    try:
        raw_pdf_elements = partition_pdf(
            filename=str(pdf_file),
            extract_images_in_pdf=True,
            infer_table_structure=True,
            chunking_strategy="by_title",
            max_characters=4000,
            new_after_n_chars=3800,
            combine_text_under_n_chars=2000,
            extract_image_block_output_dir=str(image_path),
        )
    except Exception as e:
        print(f"Error processing {pdf_file}: {e}")
        import traceback
        traceback.print_exc()
        continue


# Summarize an extracted image with Gemini
def summarize_image(image_path):
    # Load and base64-encode the image
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

    # Create a GenerativeModel object
    model = GenerativeModel("gemini-1.5-flash")

    # Prepare the prompt
    prompt = """
    You are an expert in analyzing medical images related to dog health.
    Please provide a detailed description of this medical image, including:
    1. The body part or area being examined
    2. Any visible structures, organs, or tissues
    3. Any abnormalities, lesions, or notable features
    4. Any other relevant details if the image is a medical diagram
    Please be as specific and detailed as possible in your analysis.
    """

    # Generate the response
    response = model.generate_content([
        prompt,
        {"mime_type": "image/jpeg", "data": encoded_image},
    ])
    return response.text


# Iterate through each extracted image file in the images directory
for image_file in os.listdir(image_path):
    if image_file.endswith((".jpg", ".jpeg", ".png")):
        # Generate a standard UUID for the image
        image_uuid = str(uuid.uuid4())
        image_file_name = image_file
        image_file_path = image_path / image_file

        # Generate the image summary
        # image_summary = generate_image_summary_with(str(image_file_path), model, feature_extractor, tokenizer, device)
        # image_summary = generate_summary_with_lm(str(image_file_path), preprocess, model, device, tokenizer, lm_model)
        image_summary = summarize_image(image_file_path)

        # Construct the metadata entry for the image
        image_metadata_dict[image_uuid] = {
            "filename": image_file_name,
            "img_path": str(image_file_path),  # path to the image on disk
            "summary": image_summary,  # Gemini-generated description
        }

        # Limit the total number of images processed
        if len(image_metadata_dict) >= MAX_IMAGES_PER_PDF:
            break

print(f"Number of items in image_metadata_dict: {len(image_metadata_dict)}")

# Print the metadata dictionary
for key, value in image_metadata_dict.items():
    print(f"UUID: {key}, Metadata: {value}")
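# Optional spot check (illustrative, safe to remove): print part of one Gemini summary
# to verify that extraction and summarization produced sensible text before embedding.
if image_metadata_dict:
    _sample_id = next(iter(image_metadata_dict))
    print("Sample summary:", image_metadata_dict[_sample_id]["summary"][:300])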
def plot_images_with_opencv(image_metadata_dict):
    original_images_urls = []
    images_shown = 0
    plt.figure(figsize=(16, 16))  # adjust the figure size as needed

    for image_id in image_metadata_dict:
        img_path = image_metadata_dict[image_id]["img_path"]
        if os.path.isfile(img_path):
            try:
                img = cv2.imread(img_path)
                if img is not None:
                    # Convert BGR (OpenCV) to RGB (matplotlib)
                    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    plt.subplot(8, 8, len(original_images_urls) + 1)
                    plt.imshow(img_rgb)
                    plt.xticks([])
                    plt.yticks([])
                    original_images_urls.append(image_metadata_dict[image_id]["filename"])
                    images_shown += 1
                    if images_shown >= 64:
                        break
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")

    plt.tight_layout()
    plt.show()


plot_images_with_opencv(image_metadata_dict)

# Set the device for the CLIP model: CUDA (GPU) if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)


# Preprocess an image loaded with OpenCV for CLIP
def preprocess_image(img):
    # Convert BGR (OpenCV) to RGB
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Convert to a PIL image and apply the CLIP preprocessing transform
    img_pil = Image.fromarray(img_rgb)
    return preprocess(img_pil)
    # Use the BiomedCLIP processor for preprocessing instead:
    # return preprocess(images=img_pil, return_tensors="pt")
    # return preprocess(img_pil).unsqueeze(0)


# Compute a CLIP embedding for every extracted image
img_emb_dict = {}
with torch.no_grad():
    for image_id in image_metadata_dict:
        img_file_path = image_metadata_dict[image_id]["img_path"]
        if os.path.isfile(img_file_path):
            try:
                # Load the image using OpenCV
                img = cv2.imread(img_file_path)
                if img is not None:
                    # Preprocess the image and move it to the target device
                    image = preprocess_image(img).unsqueeze(0).to(device)
                    # image = preprocess_image(img).to(device)
                    # Extract image features with CLIP
                    image_features = model.encode_image(image)
                    # Store the image features
                    img_emb_dict[image_id] = image_features
                else:
                    print(f"Failed to load image {img_file_path}")
            except Exception as e:
                print(f"Error processing image {img_file_path}: {e}")

print(len(img_emb_dict))  # e.g. 22 images -> 22 image embeddings

# Create a list of ImageDocument objects, one for each embedded image
img_documents = []
for image_id in image_metadata_dict:
    # img_emb_dict contains the image embeddings; skip images that failed to embed
    if image_id in img_emb_dict:
        filename = image_metadata_dict[image_id]["filename"]
        filepath = image_metadata_dict[image_id]["img_path"]
        summary = image_metadata_dict[image_id]["summary"]
        # print(filepath)

        # Create an ImageDocument for each image
        new_img_doc = ImageDocument(
            text=filename,
            metadata={"filepath": filepath, "summary": summary},  # include the summary in the metadata
        )
        # Set the CLIP image embedding on the ImageDocument
        new_img_doc.embedding = img_emb_dict[image_id].tolist()[0]
        img_documents.append(new_img_doc)

# Define the storage context for the image collection
storage_context = StorageContext.from_defaults(vector_store=image_vector_store)

# Define the image index
image_index = VectorStoreIndex.from_documents(
    img_documents,
    storage_context=storage_context,
)

# for doc in img_documents:
#     print(f"ImageDocument: {doc.text}, Embedding: {doc.embedding}, Metadata: {doc.metadata}")


def retrieve_results_from_image_index(query):
    """
    Take a text query as input and return the most similar image from the image vector
    store, scored by the average of the CLIP similarity and the ClinicalBERT similarity
    between the query and the stored image summary.
    """
    # Tokenize the text query and encode it with CLIP to produce the query embedding
    text = clip.tokenize(query).to(device)
    query_embedding = model.encode_text(text).tolist()[0]

    # Encode the query with ClinicalBERT for text (summary) similarity
    clinical_query_embedding = text_embed_model.embed_query(query)

    # Create a VectorStoreQuery
    image_vector_store_query = VectorStoreQuery(
        query_embedding=query_embedding,
        similarity_top_k=1,  # return 1 image
        mode="default",
    )

    # Execute the query against the image vector store
    image_retrieval_results = image_vector_store.query(image_vector_store_query)

    if image_retrieval_results.nodes:
        best_score = -1
        best_image = None
        for node, clip_score in zip(
            image_retrieval_results.nodes, image_retrieval_results.similarities
        ):
            image_path = node.metadata["filepath"]
            image_summary = node.metadata.get("summary", "")  # summaries are stored in metadata

            # Text similarity between the query and the image summary
            summary_embedding = text_embed_model.embed_query(image_summary)
            # text_score = util.cosine_similarity([clinical_query_embedding], [summary_embedding])[0][0]
            text_score = util.cos_sim(
                torch.tensor([clinical_query_embedding]),
                torch.tensor([summary_embedding]),
            )[0][0].item()

            # Average the CLIP and summary similarities
            avg_score = (clip_score + text_score) / 2
            if avg_score > best_score:
                best_score = avg_score
                best_image = image_path

        return best_image, best_score

    return None, 0.0
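# Illustrative usage (the query string is only an example): uncomment to spot-check
# image retrieval on its own, independent of the Gradio app defined below.
# _best_img, _best_score = retrieve_results_from_image_index("dog gums showing signs of gingivitis")
# if _best_img is not None:
#     print(f"Best image: {_best_img} (avg CLIP + ClinicalBERT score: {_best_score:.4f})")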
def plot_image_retrieve_results(image_retrieval_results):
    """
    Take image retrieval results and display each returned image in a subplot,
    titled with its similarity score formatted to four decimal places.
    """
    plt.figure(figsize=(16, 5))
    img_cnt = 0

    # Iterate over the image retrieval results and, for each result, display the
    # corresponding image and its score in a subplot.
    for returned_image, score in zip(
        image_retrieval_results.nodes, image_retrieval_results.similarities
    ):
        img_name = returned_image.text
        img_path = returned_image.metadata["filepath"]

        # Read the image using OpenCV and convert BGR to RGB (OpenCV reads BGR by default)
        image = cv2.imread(img_path)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        plt.subplot(2, 3, img_cnt + 1)
        plt.title("{:.4f}".format(score))
        plt.imshow(image_rgb)
        plt.xticks([])
        plt.yticks([])
        img_cnt += 1

    plt.tight_layout()
    plt.show()


def get_all_images():
    image_paths = []
    for _, metadata in image_metadata_dict.items():
        image_paths.append(metadata["img_path"])
    return image_paths


def load_image(image_path):
    return Image.open(image_path)


# Define the combined text + image query function
def combined_query(query, similarity_threshold=0.3):
    # Text query (stream the response, then concatenate the chunks)
    text_response = streaming_qe.query(query)
    text_result = ""
    for text in text_response.response_gen:
        text_result += text

    # Image query
    top_image_path, similarity_score = retrieve_results_from_image_index(query)

    if similarity_score >= similarity_threshold:
        return text_result, top_image_path, similarity_score
    else:
        return text_result, None, similarity_score


def gradio_interface(query):
    text_result, image_path, similarity_score = combined_query(query)
    top_image = load_image(image_path) if image_path else None
    all_images = [load_image(path) for path in get_all_images()]
    return text_result, top_image, all_images, f"Similarity Score: {similarity_score:.4f}"


with gr.Blocks() as iface:
    gr.Markdown("# Medical Knowledge Base Query System")
    with gr.Row():
        query_input = gr.Textbox(lines=2, placeholder="Enter your medical query here...")
        submit_button = gr.Button("Submit")
    with gr.Row():
        text_output = gr.Textbox(label="Text Response")
        image_output = gr.Image(label="Top Related Image (if similarity > threshold)")
        similarity_score_output = gr.Textbox(label="Similarity Score")
    gallery_output = gr.Gallery(label="All Extracted Images", show_label=True, elem_id="gallery")

    submit_button.click(
        fn=gradio_interface,
        inputs=query_input,
        outputs=[text_output, image_output, gallery_output, similarity_score_output],
    )

    # Load all images on startup
    iface.load(
        lambda: ["", None, [load_image(path) for path in get_all_images()], ""],
        outputs=[text_output, image_output, gallery_output, similarity_score_output],
    )

# Launch the Gradio interface
iface.launch(share=True)

# Just to check if it works or not
# def image_query(query):
#     image_retrieval_results = retrieve_results_from_image_index(query)
#     plot_image_retrieve_results(image_retrieval_results)

# query1 = "What is gingivitis?"
# # Generate image retrieval results
# image_query(query1)

# # Modify your text query function
# # def text_query(query):
# #     text_retrieval_results = process_query(query, text_embed_model, k=10)
# #     return text_retrieval_results

# # Function to query the text vector database
# def text_query(query: str, k: int = 10):
#     # Create a VectorStoreIndex from the existing vector store
#     index = VectorStoreIndex.from_vector_store(text_vector_store)
#     # Create a retriever with top-k configuration
#     retriever = index.as_retriever(similarity_top_k=k)
#     # Create a query engine
#     query_engine = RetrieverQueryEngine.from_args(retriever)
#     # Execute the query
#     response = query_engine.query(query)
#     return response

# # text_retrieval_results = text_query(query1)

# streaming_response = streaming_qe.query(query1)
# streaming_response.print_response_stream()