# from base64 import b64encode
from io import BytesIO
from math import ceil

import clip
from multilingual_clip import legacy_multilingual_clip, pt_multilingual_clip
import numpy as np
import pandas as pd
from PIL import Image
import requests
import streamlit as st
import torch
from torchvision.transforms import ToPILImage
from transformers import AutoTokenizer, AutoModel, BertTokenizer

from CLIP_Explainability.clip_ import load, tokenize
from CLIP_Explainability.rn_cam import (
    # interpret_rn,
    interpret_rn_overlapped,
    rn_perword_relevance,
)
from CLIP_Explainability.vit_cam import (
    # interpret_vit,
    vit_perword_relevance,
    interpret_vit_overlapped,
)

from pytorch_grad_cam.grad_cam import GradCAM

RUN_LITE = True  # Load models for CAM viz for M-CLIP and J-CLIP only

MAX_IMG_WIDTH = 500
MAX_IMG_HEIGHT = 800

st.set_page_config(layout="wide")


def find_best_matches(text_features, image_features, image_ids):
    """Rank every image against the query, best match first.

    The similarity is the product ``image_features @ text_features.T``; it is
    a cosine similarity only if both inputs are already L2-normalised (this
    function does not normalise them).

    Args:
        text_features: Query feature tensor. Presumably shaped ``(1, dim)`` so
            that the product has a size-1 second axis for ``squeeze(1)`` to
            drop -- TODO confirm against callers.
        image_features: Feature tensor with one row per image.
        image_ids: Sequence of image IDs, indexed in the same order as the
            rows of ``image_features``.

    Returns:
        A list of ``[image_id, score]`` pairs sorted by descending score,
        with ``score`` as a plain Python float.
    """
    # Compute the similarity between the search query and each image
    similarities = (image_features @ text_features.T).squeeze(1)

    # Sort the images by their similarity score (negated for descending order)
    best_image_idx = (-similarities).argsort().tolist()

    # Convert once to Python floats instead of indexing the tensor per element
    scores = similarities.tolist()

    # Return the image IDs of the best matches together with their scores
    return [[image_ids[i], scores[i]] for i in best_image_idx]


# The `encode_search_query` function takes a text description and encodes it into a feature vector using the CLIP model.
def encode_search_query(search_query, model_type):
    """Encode a text query into an L2-normalised feature tensor.

    Args:
        search_query: The query text.
        model_type: Display name of the active model. The two ViT names select
            the M-CLIP or J-CLIP text encoder; any other value falls through
            to the legacy text model.

    Returns:
        The normalised text features, moved to ``st.session_state.device``.
    """
    with torch.no_grad():
        if model_type == "M-CLIP (multilingual ViT)":
            text_encoded = st.session_state.ml_model.forward(
                search_query, st.session_state.ml_tokenizer
            )
        elif model_type == "J-CLIP (日本語 ViT)":
            t_text = st.session_state.ja_tokenizer(
                search_query,
                padding=True,
                return_tensors="pt",
                device=st.session_state.device,
            )
            text_encoded = st.session_state.ja_model.get_text_features(**t_text)
        else:  # model_type == legacy
            text_encoded = st.session_state.rn_model(search_query)

        # Normalize once for every model instead of repeating it per branch
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the feature vector
    return text_encoded.to(st.session_state.device)


def clip_search(search_query):
    """Run a text search and store the ranked results in the session state.

    Writes ``search_field_value`` (only when it differs from the query) and,
    for a non-empty query, ``search_image_ids`` (IDs, best match first) and
    ``search_image_scores`` (ID -> score). An empty query leaves any previous
    results untouched.
    """
    if st.session_state.search_field_value != search_query:
        st.session_state.search_field_value = search_query

    model_type = st.session_state.active_model

    if not search_query:
        return

    text_features = encode_search_query(search_query, model_type)

    # Pick the precomputed image features that belong to the active model
    if model_type == "M-CLIP (multilingual ViT)":
        image_features = st.session_state.ml_image_features
    elif model_type == "J-CLIP (日本語 ViT)":
        image_features = st.session_state.ja_image_features
    else:  # model_type == legacy
        image_features = st.session_state.rn_image_features

    # Sort the photos by their similarity score
    matches = find_best_matches(
        text_features, image_features, st.session_state.image_ids
    )

    st.session_state.search_image_ids = [match[0] for match in matches]
    st.session_state.search_image_scores = {match[0]: match[1] for match in matches}


def string_search():
    """Refresh the uploader flag and re-run the stored query, if there is one."""
    st.session_state.disable_uploader = (
        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
    )

    if "search_field_value" in st.session_state:
        clip_search(st.session_state.search_field_value)


def load_image_features():
    """Load the precomputed image features for the current vision mode.

    Reads ``./image_features/<prefix>_<model>_features.npy`` for the ``ml``,
    ``ja`` and ``rn`` models, converts each array to a tensor on the session
    device, L2-normalises it along the last axis and stores it in the session
    state as ``<model>_image_features``. Finally re-runs the current search.
    """
    # Filename prefix of the feature set for each vision mode; any mode other
    # than "tiled"/"stretched" uses the cropped features.
    vision_mode = st.session_state.vision_mode
    if vision_mode == "tiled":
        prefix = "tiled"
    elif vision_mode == "stretched":
        prefix = "resized"
    else:  # vision_mode == "cropped"
        prefix = "cropped"

    device = st.session_state.device

    normalized_features = {}
    for model_key in ("ml", "ja", "rn"):
        features = torch.from_numpy(
            np.load(f"./image_features/{prefix}_{model_key}_features.npy")
        )
        # Force Float32 on CPU; on any other device the dtype stored in the
        # file is kept (presumably Float16 -- verify against the .npy files).
        if device == "cpu":
            features = features.float()
        features = features.to(device)
        normalized_features[model_key] = features / features.norm(
            dim=-1, keepdim=True
        )

    # Publish only after every file loaded, so a failed load leaves the
    # previously stored features untouched.
    for model_key, features in normalized_features.items():
        st.session_state[f"{model_key}_image_features"] = features

    string_search()


def init():
    """Populate the session state: device, models, metadata and defaults.

    Loads the M-CLIP and J-CLIP models and tokenizers (plus the legacy ResNet
    ones unless ``RUN_LITE`` is set), reads ``./metadata.csv`` and
    ``./images_list.txt``, sets the default UI state and finally loads the
    image features.
    """
    st.session_state.current_page = 1

    # device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "cpu"

    st.session_state.device = device

    # Load the open CLIP models
    with st.spinner("Loading models and data, please wait..."):
        ml_model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-16Plus"
        ml_model_path = "./models/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"

        st.session_state.ml_image_model, st.session_state.ml_image_preprocess = load(
            ml_model_path, device=device, jit=False
        )

        st.session_state.ml_model = (
            pt_multilingual_clip.MultilingualCLIP.from_pretrained(ml_model_name)
        ).to(device)
        st.session_state.ml_tokenizer = AutoTokenizer.from_pretrained(ml_model_name)

        ja_model_name = "hakuhodo-tech/japanese-clip-vit-h-14-bert-wider"
        ja_model_path = "./models/ViT-H-14-laion2B-s32B-b79K.bin"

        st.session_state.ja_image_model, st.session_state.ja_image_preprocess = load(
            ja_model_path, device=device, jit=False
        )

        st.session_state.ja_model = AutoModel.from_pretrained(
            ja_model_name, trust_remote_code=True
        ).to(device)
        st.session_state.ja_tokenizer = AutoTokenizer.from_pretrained(
            ja_model_name, trust_remote_code=True
        )

        if not RUN_LITE:
            st.session_state.rn_image_model, st.session_state.rn_image_preprocess = (
                clip.load("RN50x4", device=device)
            )

            st.session_state.rn_model = legacy_multilingual_clip.load_model(
                "M-BERT-Base-69"
            ).to(device)
            st.session_state.rn_tokenizer = BertTokenizer.from_pretrained(
                "bert-base-multilingual-cased"
            )

    # Load the image IDs
    st.session_state.images_info = pd.read_csv("./metadata.csv")
    st.session_state.images_info.set_index("filename", inplace=True)

    with open("./images_list.txt", "r", encoding="utf-8") as images_list:
        st.session_state.image_ids = images_list.read().strip().split("\n")

    st.session_state.active_model = "M-CLIP (multilingual ViT)"

    st.session_state.vision_mode = "tiled"
    st.session_state.search_image_ids = []
    st.session_state.search_image_scores = {}
    st.session_state.text_table_df = None
    st.session_state.disable_uploader = (
        RUN_LITE and st.session_state.active_model == "Legacy (multilingual ResNet)"
    )

    with st.spinner("Loading models and data, please wait..."):
        load_image_features()


# Initialise once per session: `images_info` is only set by init().
if "images_info" not in st.session_state:
    init()


def get_overlay_vis(image, img_dim, image_model):
    """Build the activation overlay for ``image`` against the current query.

    The image is prepared according to ``st.session_state.vision_mode``
    (tiled, stretched, or centre-cropped otherwise), each resulting tile is run
    through the explainability routine of the active model, and the per-tile
    visualisations are stitched back together and scaled for display.

    Args:
        image: Image to explain; used through ``.size``, ``.resize`` and
            ``.crop``, i.e. presumably a ``PIL.Image.Image``.
        img_dim: Side length in pixels of the square tiles fed to the model.
        image_model: Image model of the active CLIP variant; must expose
            ``dtype`` and ``visual`` (and ``encode_image`` for the legacy
            model).

    Returns:
        A tuple ``(activations_image, image_features, mean_similarity)``: the
        overlay image scaled to fit ``MAX_IMG_WIDTH``/``MAX_IMG_HEIGHT``, a
        list with one entry per tile, and the mean of the per-tile
        similarities.
    """
    orig_img_dims = image.size
    vision_mode = st.session_state.vision_mode
    device = st.session_state.device
    query = st.session_state.search_field_value

    ##### If the features are based on tiled image slices
    tile_behavior = None

    if vision_mode == "tiled":
        scaled_dims = [img_dim, img_dim]

        # Tile along the longer side when the rounded aspect ratio is > 1
        if orig_img_dims[0] > orig_img_dims[1]:
            scale_ratio = round(orig_img_dims[0] / orig_img_dims[1])
            if scale_ratio > 1:
                scaled_dims = [scale_ratio * img_dim, img_dim]
                tile_behavior = "width"
        elif orig_img_dims[0] < orig_img_dims[1]:
            scale_ratio = round(orig_img_dims[1] / orig_img_dims[0])
            if scale_ratio > 1:
                scaled_dims = [img_dim, scale_ratio * img_dim]
                tile_behavior = "height"

        resized_image = image.resize(scaled_dims, Image.LANCZOS)

        if tile_behavior == "width":
            image_tiles = [
                resized_image.crop((x * img_dim, 0, (x + 1) * img_dim, img_dim))
                for x in range(scale_ratio)
            ]
        elif tile_behavior == "height":
            image_tiles = [
                resized_image.crop((0, y * img_dim, img_dim, (y + 1) * img_dim))
                for y in range(scale_ratio)
            ]
        else:
            image_tiles = [resized_image]

    elif vision_mode == "stretched":
        image_tiles = [image.resize((img_dim, img_dim), Image.LANCZOS)]

    else:  # vision_mode == "cropped"
        # Scale the shorter side to img_dim, keeping the aspect ratio
        if orig_img_dims[0] > orig_img_dims[1]:
            scale_factor = orig_img_dims[0] / orig_img_dims[1]
            resized_img_dims = (round(scale_factor * img_dim), img_dim)
        elif orig_img_dims[0] < orig_img_dims[1]:
            scale_factor = orig_img_dims[1] / orig_img_dims[0]
            resized_img_dims = (img_dim, round(scale_factor * img_dim))
        else:
            resized_img_dims = (img_dim, img_dim)

        resized_img = image.resize(resized_img_dims)

        left = round((resized_img_dims[0] - img_dim) / 2)
        top = round((resized_img_dims[1] - img_dim) / 2)
        right = left + img_dim
        bottom = top + img_dim

        # Crop the center of the image
        image_tiles = [resized_img.crop((left, top, right, bottom))]

    # Encode the query with the text encoder of the active model. This is
    # deliberately not wrapped in torch.no_grad(), matching the original code.
    active_model = st.session_state.active_model
    uses_vit = active_model in ("M-CLIP (multilingual ViT)", "J-CLIP (日本語 ViT)")

    if active_model == "M-CLIP (multilingual ViT)":
        preprocess = st.session_state.ml_image_preprocess
        text_features = st.session_state.ml_model.forward(
            query, st.session_state.ml_tokenizer
        )
    elif active_model == "J-CLIP (日本語 ViT)":
        preprocess = st.session_state.ja_image_preprocess
        t_text = st.session_state.ja_tokenizer(
            query,
            return_tensors="pt",
            device=device,
        )
        text_features = st.session_state.ja_model.get_text_features(**t_text)
    else:  # st.session_state.active_model == Legacy
        preprocess = st.session_state.rn_image_preprocess
        text_features = st.session_state.rn_model(query)

    if device == "cpu":
        text_features = text_features.float()
    text_features = text_features.to(device)

    if not uses_vit:
        # Loop-invariant: the normalised query is the same for every tile
        text_features_new = text_features / text_features.norm(dim=-1, keepdim=True)

    image_visualizations = []
    image_features = []
    image_similarities = []

    for altered_image in image_tiles:
        p_image = preprocess(altered_image).unsqueeze(0).to(device)

        if uses_vit:
            vis_t, img_feats, similarity = interpret_vit_overlapped(
                p_image.type(image_model.dtype),
                text_features.type(image_model.dtype),
                image_model.visual,
                device,
                img_dim=img_dim,
            )
        else:
            vis_t = interpret_rn_overlapped(
                p_image.type(image_model.dtype),
                text_features.type(image_model.dtype),
                image_model.visual,
                GradCAM,
                device,
                img_dim=img_dim,
            )

            image_feats = image_model.encode_image(p_image.type(image_model.dtype))
            image_feats_new = image_feats / image_feats.norm(dim=-1, keepdim=True)
            similarity = image_feats_new[0].dot(text_features_new[0])

            # NOTE(review): the legacy path stores the preprocessed image, not
            # `image_feats`, in `image_features` -- kept as in the original;
            # confirm whether callers expect encoded features here.
            img_feats = p_image

        image_visualizations.append(vis_t)
        image_features.append(img_feats)
        image_similarities.append(similarity.item())

    transform = ToPILImage()
    vis_images = [transform(vis_t) for vis_t in image_visualizations]

    if vision_mode == "cropped":
        # Put the explained centre crop back into the full resized image
        resized_img.paste(vis_images[0], (left, top))
        vis_images = [resized_img]

    # Fit the output to the display limits; clamp to 1px so that extreme
    # aspect ratios cannot request a zero-sized resize.
    if orig_img_dims[0] > orig_img_dims[1]:
        scale_factor = MAX_IMG_WIDTH / orig_img_dims[0]
        scaled_dims = [MAX_IMG_WIDTH, max(1, int(orig_img_dims[1] * scale_factor))]
    else:
        scale_factor = MAX_IMG_HEIGHT / orig_img_dims[1]
        scaled_dims = [max(1, int(orig_img_dims[0] * scale_factor)), MAX_IMG_HEIGHT]

    if tile_behavior == "width":
        vis_image = Image.new("RGB", (len(vis_images) * img_dim, img_dim))
        for x, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (x * img_dim, 0))
        activations_image = vis_image.resize(scaled_dims)

    elif tile_behavior == "height":
        vis_image = Image.new("RGB", (img_dim, len(vis_images) * img_dim))
        for y, v_img in enumerate(vis_images):
            vis_image.paste(v_img, (0, y * img_dim))
        activations_image = vis_image.resize(scaled_dims)

    else:
        activations_image = vis_images[0].resize(scaled_dims)

    return activations_image, image_features, np.mean(image_similarities)


def visualize_gradcam(image):
    if "search_field_value" not in st.session_state:
        return

    header_cols = st.columns([80, 20], vertical_alignment="bottom")
    with header_cols[0]:
        st.title("Image + query activation gradients")
    with header_cols[1]:
        if st.button("Close"):
            st.rerun()

    if st.session_state.active_model == "M-CLIP (multilingual ViT)":
        img_dim = 240
        image_model = st.session_state.ml_image_model
        # Sometimes used for token importance viz
        tokenized_text = st.session_state.ml_tokenizer.tokenize(
            st.session_state.search_field_value
        )
    elif st.session_state.active_model == "Legacy (multilingual ResNet)":
        img_dim = 288
        image_model = st.session_state.rn_image_model
        # Sometimes used for token importance viz
        tokenized_text = st.session_state.rn_tokenizer.tokenize(
            st.session_state.search_field_value
        )
    else:  # J-CLIP
        img_dim = 224
        image_model = st.session_state.ja_image_model
        # Sometimes used for token importance viz
        tokenized_text = st.session_state.ja_tokenizer.tokenize(
            st.session_state.search_field_value
        )

    st.image(image)

    with st.spinner("Calculating..."):
        # info_text = st.text("Calculating activation regions...")
        activations_image, image_features, similarity_score = get_overlay_vis(
            image, img_dim, image_model
        )

    st.markdown(
        f"**Query text:** {st.session_state.search_field_value} | **Approx. image relevance:** {round(similarity_score.item(), 3)}"
    )

    st.image(activations_image)

    # image_io = BytesIO()
    # activations_image.save(image_io, "PNG")
    # dataurl = "data:image/png;base64," + b64encode(image_io.getvalue()).decode(
    #     "ascii"
    # )

    # st.html(
    #     f"""