"""Gradio demo: look up clinical code annotations for notes similar to a query.

At start-up the script embeds every annotated note of the dataset with
PubMedBERT, stores one ChromaDB entry per (note, annotation) pair, and then
serves a Gradio text box that returns the annotations of the nearest notes.
"""
import os

import chromadb
import gradio as gr
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
DATASET_NAME = "thankrandomness/mimic-iii-sample"
COLLECTION_NAME = "pubmedbert_embeddings"

# NOTE: the private dataset variant can be loaded instead with
#   load_dataset("thankrandomness/mimic-iii", token=os.getenv("HF_API_TOKEN"))
dataset = load_dataset(DATASET_NAME)

# Load PubMedBERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)


def embed_text(text, max_length=512):
    """Embed text with PubMedBERT using attention-mask-weighted mean pooling.

    Args:
        text: A single string or a list of strings.
        max_length: Maximum number of tokens kept per text; longer inputs
            are truncated.

    Returns:
        A 1-D numpy array (hidden_size,) when ``text`` is a string, or a 2-D
        numpy array (len(text), hidden_size) when ``text`` is a list.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    # Average only over real tokens: padded positions still produce hidden
    # states and would otherwise dilute the mean.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

    vectors = pooled.numpy()
    # Keep the batch axis for list input so callers can index per text.
    return vectors[0] if isinstance(text, str) else vectors


# Initialize ChromaDB client (in-memory). get_or_create avoids an error when
# the collection already exists, e.g. if this module is executed twice.
client = chromadb.Client()
collection = client.get_or_create_collection(name=COLLECTION_NAME)


def _extract_annotations(note):
    """Return the note's annotations as flat metadata dicts, skipping incomplete ones."""
    annotations_list = []
    for annotation in note.get('annotations') or []:
        try:
            annotations_list.append({
                "code": annotation['code'],
                "code_system": annotation['code_system'],
                "description": annotation['description'],
            })
        except KeyError as e:
            print(f"Skipping annotation due to missing key: {e}")
    return annotations_list


def _index_dataset(rows):
    """Embed each annotated note in ``rows`` and upsert it into the collection.

    One entry is stored per annotation; all entries of a note share the
    note's embedding and carry the annotation as metadata.
    """
    for row in rows:
        for note in row['notes']:
            text = note.get('text', '')
            annotations_list = _extract_annotations(note)
            print(f"Processed annotations for note {note['note_id']}: {annotations_list}")

            if not (text and annotations_list):
                print(f"Skipping note {note['note_id']} due to missing 'text' or 'annotations'")
                continue

            embedding = embed_text(text).tolist()
            # Single batched upsert per note instead of one call per annotation.
            collection.upsert(
                ids=[f"note_{note['note_id']}_{j}" for j in range(len(annotations_list))],
                embeddings=[embedding] * len(annotations_list),
                metadatas=annotations_list,
            )


# Process the dataset and upsert into ChromaDB
_index_dataset(dataset['train'])


def retrieve_relevant_text(input_text, n_results=5):
    """Return the annotations of the stored notes nearest to ``input_text``.

    Args:
        input_text: Query string to embed.
        n_results: Maximum number of entries to return.

    Returns:
        A list of dicts with keys ``"similarity_score"`` (the distance
        reported by ChromaDB; smaller means closer) and ``"annotation"``
        (the stored metadata dict with ``code``, ``code_system``,
        ``description``).
    """
    available = collection.count()
    if available == 0:
        return []

    input_embedding = embed_text(input_text).tolist()
    results = collection.query(
        query_embeddings=[input_embedding],
        n_results=min(n_results, available),
        include=["metadatas", "distances"],
    )
    print(results)

    # query() returns parallel lists with one inner list per query embedding;
    # only one query was sent, so take index 0.
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    return [
        {"similarity_score": distance, "annotation": metadata}
        for metadata, distance in zip(metadatas, distances)
    ]


def gradio_interface(input_text):
    """Format the retrieval results for ``input_text`` as one line per match."""
    results = retrieve_relevant_text(input_text)
    if not results:
        return "No results found."

    formatted_results = [
        f"Similarity Score: {result['similarity_score']:.2f}, "
        f"Code: {result['annotation']['code']}, "
        f"Description: {result['annotation']['description']}"
        for result in results
    ]
    # The "text" output component expects a string, not a list.
    return "\n".join(formatted_results)


interface = gr.Interface(fn=gradio_interface, inputs="text", outputs="text")
interface.launch()