Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import os | |
import matplotlib.pyplot as plt | |
from transformers import AutoTokenizer, CLIPProcessor | |
from medclip.modeling_hybrid_clip import FlaxHybridCLIP | |
# Streamlit re-executes this script on every widget interaction, so expensive
# loaders must be cached by Streamlit itself (a functools cache would be reset
# on each rerun). Prefer the current `st.cache_resource`; fall back to the
# legacy `st.cache` on Streamlit releases that predate it.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Load the MedCLIP model and the CLIP processor used to encode queries.

    Returns:
        A ``(model, processor)`` tuple: the ``FlaxHybridCLIP`` checkpoint
        ``flax-community/medclip-roco`` and the ``CLIPProcessor`` from
        ``openai/clip-vit-base-patch32``.
    """
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor
# Streamlit re-executes this script on every widget interaction, so the
# feature store must be cached by Streamlit itself to avoid re-reading the
# pickle each time. Prefer `st.cache_resource` (returns the same arrays
# without copying); fall back to the legacy `st.cache` on older releases.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_resource
def load_image_embeddings(path='feature_store/image_embeddings_large.pkl'):
    """Load the precomputed image-embedding feature store.

    Args:
        path: Pickled DataFrame with an ``image_embedding`` column (one
            embedding vector per row -- assumed 1-D each) and a ``files``
            column holding the matching image file names.

    Returns:
        ``(image_files, image_embeds)``: a NumPy array of file names and the
        stacked embeddings (one row per image), each row L2-normalised.
    """
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # a trusted, locally generated feature store.
    embeddings_df = pd.read_pickle(path)
    image_embeds = np.stack(embeddings_df['image_embedding'])
    # The search code takes a dot product with a unit-norm query and treats it
    # as a cosine similarity, which only holds if the rows are unit-norm too.
    # For stores saved already normalised this is effectively a no-op.
    # Zero-norm rows are left as zeros instead of producing NaN.
    norms = np.linalg.norm(image_embeds, axis=-1, keepdims=True)
    image_embeds = image_embeds / np.where(norms == 0, 1, norms)
    image_files = np.asarray(embeddings_df['files'].tolist())
    return image_files, image_embeds
# Number of results to display, and the directory the result images live in.
k = 5
img_dir = './images'

st.title("MedCLIP 🩺")
st.image("./assets/logo.png", width=100)
st.markdown("""Search for medical images with natural language powered by a CLIP model [[Model Card]](https://huggingface.co/flax-community/medclip-roco) finetuned on the
[Radiology Objects in COntext (ROCO) dataset](https://github.com/razorx89/roco-dataset).""")
st.markdown("""Example queries:
* `ultrasound scans`
* `pathology`
* `pancreatic carcinoma`""")

image_list, image_embeddings = load_image_embeddings()
model, processor = load_model()

query = st.text_input("Enter your query here:")
if st.button("Search"):
    if not query.strip():
        # An empty prompt would still be embedded and "match" arbitrary images.
        st.warning("Please enter a search query.")
    else:
        with st.spinner(f"Searching ROCO test set for {query}..."):
            inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
            query_embedding = np.asarray(model.get_text_features(**inputs))
            # L2-normalise so the dot product below is a cosine similarity
            # (assumes the stored image embeddings are unit-norm as well).
            query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=-1, keepdims=True)
            # One similarity per stored image: (N, D) @ (D,) -> (N,).
            similarities = image_embeddings @ query_embedding[0]
            # Indices of the k most similar images, best match first.
            # (A plain `argsort()[-k:]` lists them worst-first.)
            topk_images = np.argsort(-similarities, kind="stable")[:k]
            matching_images = image_list[topk_images]
            # Displayed as cosine distance: lower means a closer match.
            top_scores = 1. - similarities[topk_images]

            # show images
            for img_path, score in zip(matching_images, top_scores):
                img = plt.imread(os.path.join(img_dir, img_path))
                st.image(img, caption=f"{img_path} (cosine distance: {score:.2f})")