from base64 import b64encode
from io import BytesIO
from math import ceil

from multilingual_clip import pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
from transformers import AutoTokenizer, AutoModel

from CLIP_Explainability.clip_ import load, tokenize
from CLIP_Explainability.vit_cam import (
    interpret_vit,
    vit_perword_relevance,
)  # , interpret_vit_overlapped

MAX_IMG_WIDTH = 450  # For small dialog
MAX_IMG_HEIGHT = 800

st.set_page_config(layout="wide")


def init():
    """Populate ``st.session_state`` with models, metadata and image features.

    Loads both image encoders and both text encoders/tokenizers, resets the
    search-related state, reads the image metadata and ID list from disk, and
    stores L2-normalised image feature matrices for both models.

    Everything is written to ``st.session_state``; nothing is returned.
    """
    state = st.session_state
    state.current_page = 1

    device = "cuda" if torch.cuda.is_available() else "cpu"
    state.device = device

    # Load the open CLIP models
    ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
    ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"

    state.ml_image_model, state.ml_image_preprocess = load(
        ml_model_path, device=device, jit=False
    )

    # NOTE(review): unlike ja_model below, this text model is not moved to
    # `device`; confirm that is intentional before changing it.
    state.ml_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(
        ml_model_name
    )
    state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

    ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
    ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"

    state.ja_image_model, state.ja_image_preprocess = load(
        ja_model_path, device=device, jit=False
    )

    state.ja_model = AutoModel.from_pretrained(
        ja_model_name, trust_remote_code=True
    ).to(device)
    state.ja_tokenizer = AutoTokenizer.from_pretrained(
        ja_model_name, trust_remote_code=True
    )

    state.active_model = "M-CLIP (multiple languages)"

    state.search_image_ids = []
    state.search_image_scores = {}
    state.activations_image = None
    state.text_table_df = None

    # Load the image metadata, indexed by filename
    state.images_info = pd.read_csv("./metadata.csv").set_index("filename")

    # Load the image IDs (one per line). The context manager guarantees the
    # file handle is closed instead of being left to the garbage collector.
    with open("./images_list.txt", "r", encoding="utf-8") as ids_file:
        state.image_ids = ids_file.read().strip().split("\n")

    def _load_normalized(path):
        """Load a feature matrix from `path` and L2-normalise its last axis."""
        features = torch.from_numpy(np.load(path))
        if device == "cpu":
            # Force Float32 on CPU; on GPU the dtype stored in the file is
            # kept (the original comment says this is Float16).
            features = features.float()
        features = features.to(device)
        return features / features.norm(dim=-1, keepdim=True)

    # Load the image feature vectors.
    # Alternative feature sets that have been used with this app:
    #   ./multilingual_features.npy / ./hakuhodo_features.npy
    #   ./tiled_ml_features.npy     / ./tiled_ja_features.npy
    state.ml_image_features = _load_normalized("./resized_ml_features.npy")
    state.ja_image_features = _load_normalized("./resized_ja_features.npy")


# Streamlit re-runs this script on every interaction, so only run the
# expensive initialisation when the features are not yet in the session.
if (
    "ml_image_features" not in st.session_state
    or "ja_image_features" not in st.session_state
):
    with st.spinner("Loading models and data, please wait..."):
        init()


# The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
def encode_search_query(search_query, model_type):
    """Encode a text query into a unit-length text feature tensor.

    Args:
        search_query: The query text to encode.
        model_type: The active model label. ``"M-CLIP (multiple languages)"``
            selects the multilingual model; any other value selects the
            Japanese model.

    Returns:
        The text features, divided by their norm along the last axis.
    """
    state = st.session_state
    with torch.no_grad():
        if model_type == "M-CLIP (multiple languages)":
            text_encoded = state.ml_model.forward(search_query, state.ml_tokenizer)
        else:  # Japanese-only J-CLIP model
            # The Japanese model is moved to `state.device` at init time, so
            # the tokenizer output must live on that device as well.
            t_text = state.ja_tokenizer(
                search_query, padding=True, return_tensors="pt"
            ).to(state.device)
            text_encoded = state.ja_model.get_text_features(**t_text)

        # Normalise so a dot product with normalised image features is a
        # cosine similarity.
        return text_encoded / text_encoded.norm(dim=-1, keepdim=True)


def find_best_matches(text_features, image_features, image_ids):
    """Rank every image by cosine similarity to the text query.

    Args:
        text_features: Normalised text features; a single query row is
            expected (the result is squeezed on axis 1).
        image_features: Normalised image features, one row per image.
        image_ids: Sequence of image IDs, in the same order as the rows of
            ``image_features``.

    Returns:
        A list of ``[image_id, score]`` pairs sorted from best to worst match,
        where ``score`` is a Python float.
    """
    # The text encoder may run on a different device or in a different
    # precision than the stored image features; align them so the matmul
    # below cannot fail on a device/dtype mismatch. No-op when they agree.
    text_features = text_features.to(
        device=image_features.device, dtype=image_features.dtype
    )

    # Both sides are unit-length, so the dot product is the cosine similarity.
    similarities = (image_features @ text_features.T).squeeze(1)

    # Sort once on-device, then transfer indices and scores in bulk rather
    # than calling .item() per image.
    order = similarities.argsort(descending=True)
    ranked_indices = order.tolist()
    ranked_scores = similarities[order].tolist()

    return [[image_ids[i], score] for i, score in zip(ranked_indices, ranked_scores)]


def clip_search(search_query):
    """Run a text search and store the ranked results in the session state.

    Updates ``search_field_value`` to ``search_query`` and, for a non-empty
    query, overwrites ``search_image_ids`` (best match first) and
    ``search_image_scores`` (image ID -> similarity). An empty query leaves
    the previous results untouched.
    """
    state = st.session_state

    if state.search_field_value != search_query:
        state.search_field_value = search_query

    if not search_query:
        return

    model_type = state.active_model
    text_features = encode_search_query(search_query, model_type)

    # Pick the feature matrix produced by the same model as the text features.
    if model_type == "M-CLIP (multiple languages)":
        image_features = state.ml_image_features
    else:  # Japanese-only J-CLIP model
        image_features = state.ja_image_features

    matches = find_best_matches(text_features, image_features, state.image_ids)

    state.search_image_ids = [image_id for image_id, _ in matches]
    state.search_image_scores = {image_id: score for image_id, score in matches}


def string_search():
    """Re-run the search using the text currently in the search field."""
    clip_search(st.session_state.search_field_value)


def visualize_gradcam(viz_image_id):
    if not st.session_state.search_field_value:
        return

    header_cols = st.columns([80, 20], vertical_alignment="bottom")
    with header_cols[0]:
        st.title("Image + query details")
    with header_cols[1]:
        if st.button("Close"):
            st.rerun()

    st.markdown(
        f"**Query text:** {st.session_state.search_field_value} | **Image relevance:** {round(st.session_state.search_image_scores[viz_image_id], 3)}"
    )

    # with st.spinner("Calculating..."):
    info_text = st.text("Calculating activation regions...")

    image_url = st.session_state.images_info.loc[viz_image_id]["image_url"]
    image_response = requests.get(image_url)
    image = Image.open(BytesIO(image_response.content), formats=["JPEG", "GIF"])

    img_dim = 224
    if st.session_state.active_model == "M-CLIP (multiple languages)":
        img_dim = 240

    orig_img_dims = image.size

    altered_image = image.resize((img_dim, img_dim), Image.LANCZOS)

    if st.session_state.active_model == "M-CLIP (multiple languages)":
        p_image = (
            st.session_state.ml_image_preprocess(altered_image)
            .unsqueeze(0)
            .to(st.session_state.device)
        )

        # Sometimes used for token importance viz
        tokenized_text = st.session_state.ml_tokenizer.tokenize(
            st.session_state.search_field_value
        )
        image_model = st.session_state.ml_image_model
        # tokenize = st.session_state.ml_tokenizer.tokenize

        text_features = st.session_state.ml_model.forward(
            st.session_state.search_field_value, st.session_state.ml_tokenizer
        )

        vis_t = interpret_vit(
            p_image.type(st.session_state.ml_image_model.dtype),
            text_features,
            st.session_state.ml_image_model.visual,
            st.session_state.device,
            img_dim=img_dim,
        )

    else:
        p_image = (
            st.session_state.ja_image_preprocess(altered_image)
            .unsqueeze(0)
            .to(st.session_state.device)
        )

        #
Sometimes used for token importance viz tokenized_text = st.session_state.ja_tokenizer.tokenize( st.session_state.search_field_value ) image_model = st.session_state.ja_image_model t_text = st.session_state.ja_tokenizer( st.session_state.search_field_value, return_tensors="pt" ) text_features = st.session_state.ja_model.get_text_features(**t_text) vis_t = interpret_vit( p_image.type(st.session_state.ja_image_model.dtype), text_features, st.session_state.ja_image_model.visual, st.session_state.device, img_dim=img_dim, ) transform = ToPILImage() vis_img = transform(vis_t) if orig_img_dims[0] > orig_img_dims[1]: scale_factor = MAX_IMG_WIDTH / orig_img_dims[0] scaled_dims = [MAX_IMG_WIDTH, int(orig_img_dims[1] * scale_factor)] else: scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1] scaled_dims = [int(orig_img_dims[0] * scale_factor), MAX_IMG_HEIGHT] st.session_state.activations_image = vis_img.resize(scaled_dims) image_io = BytesIO() st.session_state.activations_image.save(image_io, "PNG") dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode("ascii") st.html( f"""