"""Helpers for the IBYDMT Streamlit app: load a CLIP/OpenCLIP backbone, embed an
image, a class name and a list of concepts, and run ``xSKIT`` tests (from
``ibydmt``) for every concept in parallel.
"""

import torch
import clip
import open_clip
import h5py
import streamlit as st
import numpy as np
import time

from concurrent.futures import ThreadPoolExecutor, as_completed

import ml_collections
from huggingface_hub import hf_hub_download
from ibydmt.test import xSKIT

from app_lib.utils import SUPPORTED_MODELS
from app_lib.ckde import cKDE

# Module-level generator shared by every call to ``_sample_random_subset``
# (including calls made from the worker threads in ``test``).
rng = np.random.default_rng()

# Configuration handed to every ``xSKIT`` instance created in ``_test``.
# ``r`` is the number of random conditioning subsets tested per concept.
testing_config = ml_collections.ConfigDict()
testing_config.significance_level = 0.05
testing_config.wealth = "ons"
testing_config.bet = "tanh"
testing_config.kernel = "rbf"
testing_config.kernel_scale_method = "quantile"
testing_config.kernel_scale = 0.5
testing_config.tau_max = 200
testing_config.r = 10


def _get_open_clip_model(model_name, device):
    """Create an OpenCLIP model, its preprocessing transform and its tokenizer.

    ``SUPPORTED_MODELS[model_name]`` is passed to
    ``open_clip.create_model_and_transforms``; the tokenizer is looked up by the
    part of ``model_name`` after the last ``":"``. The model is put in eval mode.
    """
    backbone = model_name.split(":")[-1]
    model, _, preprocess = open_clip.create_model_and_transforms(
        SUPPORTED_MODELS[model_name], device=device
    )
    model.eval()
    tokenizer = open_clip.get_tokenizer(backbone)
    return model, preprocess, tokenizer


def _get_clip_model(model_name, device):
    """Load an OpenAI CLIP model by the backbone name after the last ``":"``.

    Returns ``(model, preprocess, clip.tokenize)``.
    """
    backbone = model_name.split(":")[-1]
    model, preprocess = clip.load(backbone, device=device)
    tokenizer = clip.tokenize
    return model, preprocess, tokenizer


def load_dataset(dataset_name, model_name):
    """Download ``{dataset_name}_{model_name}_train.h5`` and return its embeddings.

    The file is fetched from the ``jacopoteneggi/IBYDMT`` Hugging Face dataset
    repository; the full ``"embedding"`` dataset is read into memory.
    """
    dataset_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=f"{dataset_name}_{model_name}_train.h5",
        repo_type="dataset",
    )
    with h5py.File(dataset_path, "r") as dataset:
        embedding = dataset["embedding"][:]
    return embedding


def load_model(model_name, device):
    """Return ``(model, preprocess, tokenizer)`` for ``model_name``.

    Names containing ``"open_clip"`` are loaded with OpenCLIP; other names
    containing ``"clip"`` are loaded with OpenAI CLIP. The ``"open_clip"``
    check must come first because that string also contains ``"clip"``.

    Raises:
        ValueError: if ``model_name`` matches neither model family.
    """
    if "open_clip" in model_name:
        return _get_open_clip_model(model_name, device)
    if "clip" in model_name:
        return _get_clip_model(model_name, device)
    # Previously fell through to an UnboundLocalError on the return statement.
    raise ValueError(f"Unsupported model: {model_name!r}")


@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_concepts(tokenizer, model, concepts, device):
    """Embed the concept strings with the text encoder.

    Returns a NumPy array on the CPU, L2-normalized along the last axis.
    """
    concepts_text = tokenizer(concepts).to(device)
    concept_features = model.encode_text(concepts_text)
    concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
    return concept_features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_image(model, preprocess, image, device):
    """Preprocess a single image, embed it, and return it as a batch of one.

    The result is a NumPy array on the CPU, L2-normalized along the last axis.
    """
    image = preprocess(image)
    image = image.unsqueeze(0)
    image = image.to(device)
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_class_name(tokenizer, model, class_name, device):
    """Embed the prompt ``"A photo of a {class_name}"`` with the text encoder.

    Returns a NumPy array on the CPU, L2-normalized along the last axis.
    """
    class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
    class_features = model.encode_text(class_text)
    class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
    return class_features.cpu().numpy()


def _sample_random_subset(concept_idx, concepts, cardinality):
    """Return up to ``cardinality`` random concept indices other than ``concept_idx``."""
    candidates = [idx for idx in range(len(concepts)) if idx != concept_idx]
    return rng.permutation(candidates)[:cardinality].tolist()


def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
    """Run ``testing_config.r`` xSKIT tests for one concept.

    Each repetition draws a fresh random conditioning subset of size
    ``cardinality`` (see ``_sample_random_subset``) and runs
    ``xSKIT.test`` with ``interrupt_on_rejection=False``.

    Returns:
        dict with the tested concept name and, per repetition, the
        ``rejected`` / ``tau`` values returned by ``xSKIT.test``, the tester's
        wealth history, and the concept names of the conditioning subset.
    """

    def cond_p(z, cond_idx, m):
        # Draw m conditional samples from the cKDE sampler; only the
        # embedding-side samples are needed by the test.
        _, sample_h = sampler.sample(z, cond_idx, m=m)
        return sample_h

    def f(h):
        # Score embeddings against the class-name embedding (dot product).
        output = h @ classifier.T
        return output.squeeze()

    rejected_hist, tau_hist, wealth_hist, subset_hist = [], [], [], []
    for _ in range(testing_config.r):
        subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
        subset = [concepts[idx] for idx in subset_idx]

        tester = xSKIT(testing_config)
        rejected, tau = tester.test(
            z, concept_idx, subset_idx, cond_p, f, interrupt_on_rejection=False
        )
        # NOTE(review): reads a private attribute of the wealth object.
        wealth = tester.wealth._wealth

        rejected_hist.append(rejected)
        tau_hist.append(tau)
        wealth_hist.append(wealth)
        subset_hist.append(subset)

    return {
        "concept": concepts[concept_idx],
        "rejected": rejected_hist,
        "tau": tau_hist,
        "wealth": wealth_hist,
        "subset": subset_hist,
    }


def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
    """Test every concept's relevance for ``image`` w.r.t. ``class_name``.

    Shows Streamlit spinners and a progress bar while it loads the model,
    embeds the concepts and the image, builds a ``cKDE`` sampler from the
    training embeddings of ``dataset_name``, and runs ``_test`` for each
    concept on a thread pool. Afterwards it sets
    ``st.session_state.disabled = False`` and calls ``st.experimental_rerun()``.
    """
    with st.spinner("Loading model"):
        model, preprocess, tokenizer = load_model(model_name, device)

    with st.spinner("Encoding concepts"):
        cbm = encode_concepts(tokenizer, model, concepts, device)

    with st.spinner("Encoding image"):
        h = encode_image(model, preprocess, image, device)
        # encode_image returns a batch of one; take that row explicitly so z
        # stays 1-D (one entry per concept) even when there is a single concept,
        # where .squeeze() would collapse it to a 0-d array.
        z = (h @ cbm.T)[0]

    with st.spinner("Testing"):
        progress_bar = st.progress(0)

        # Use the requested dataset; this was previously hard-coded to
        # "imagenette", silently ignoring the dataset_name argument.
        embedding = load_dataset(dataset_name, model_name)
        semantics = embedding @ cbm.T
        sampler = cKDE(embedding, semantics)

        classifier = encode_class_name(tokenizer, model, class_name, device)

        with ThreadPoolExecutor() as executor:
            future_to_idx = {
                executor.submit(
                    _test, z, concept_idx, concepts, cardinality, sampler, classifier
                ): concept_idx
                for concept_idx in range(len(concepts))
            }

            # Store each result at its concept's index so `results` lines up
            # with `concepts` regardless of completion order.
            results = [None] * len(concepts)
            for n_done, future in enumerate(as_completed(future_to_idx), start=1):
                results[future_to_idx[future]] = future.result()
                progress_bar.progress(n_done / len(concepts))

        # TODO: `results` is not yet stored or returned; the previous draft
        # (commented out) started aggregating each concept's wealth into a
        # (tau_max, len(concepts)) array.

    st.session_state.disabled = False
    # NOTE(review): newer Streamlit releases expose this as st.rerun() —
    # confirm against the pinned Streamlit version before switching.
    st.experimental_rerun()