import os
from concurrent.futures import ThreadPoolExecutor, as_completed

import ml_collections
import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download

import app_lib.multimodal as multimodal
from app_lib.ckde import cKDE
from app_lib.config import Config
from app_lib.config import Constants as c
from ibydmt.test import xSKIT

rng = np.random.default_rng()


@torch.no_grad()
@torch.cuda.amp.autocast()
def _encode_concepts(model, concepts):
    # Embed the concept strings with the text encoder and L2-normalize.
    concept_features = model.encode_text(concepts)
    concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
    return concept_features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
def _encode_image(model, image):
    # Embed the input image and L2-normalize.
    image_features = model.encode_image(image)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()


@torch.no_grad()
@torch.cuda.amp.autocast()
def _encode_class_name(model, class_name):
    # Embed the prompt "A photo of a {class_name}" as the classifier direction.
    class_text = [f"A photo of a {class_name}"]
    class_features = model.encode_text(class_text)
    class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
    return class_features.cpu().numpy()


def _load_embedding(config):
    # Download the precomputed training-set embeddings for the configured
    # dataset/backbone pair from the Hugging Face Hub.
    dataset_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=(
            f"{config.data.dataset.lower()}_train_{config.backbone_name()}.parquet"
        ),
        repo_type="dataset",
    )
    dataset = pd.read_parquet(dataset_path)
    return np.array(dataset["embedding"].values.tolist())


def _sample_random_subset(concept_idx, concepts, cardinality):
    # Sample a random subset of concept indices, excluding the tested concept.
    sample_idx = list(set(range(len(concepts))) - {concept_idx})
    return rng.permutation(sample_idx)[:cardinality].tolist()


def _test(testing_config, z, concept_idx, concepts, cardinality, sampler, classifier):
    # Run r repetitions of the xSKIT test for a single concept, each on a
    # freshly sampled conditioning subset.
    def cond_p(z, cond_idx, m):
        _, sample_h = sampler.sample(z, cond_idx, m=m)
        return sample_h

    def f(h):
        output = h @ classifier.T
        return output.squeeze()

    rejected_hist, tau_hist, wealth_hist, subset_hist = [], [], [], []
    for _ in range(testing_config.r):
        subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
        subset = [concepts[idx] for idx in subset_idx]

        tester = xSKIT(testing_config)
        rejected, tau = tester.test(
            z,
            concept_idx,
            subset_idx,
            cond_p,
            f,
            interrupt_on="max_wealth",
            max_wealth=3 * 1 / testing_config.significance_level,
        )
        # Pad the wealth process to tau_max so all repetitions have equal length.
        wealth = tester.wealth._wealth
        wealth = wealth + [wealth[-1]] * (testing_config.tau_max - len(wealth))

        rejected_hist.append(rejected)
        tau_hist.append(tau)
        wealth_hist.append(wealth)
        subset_hist.append(subset)

    return {
        "concept": concepts[concept_idx],
        "rejected": rejected_hist,
        "tau": tau_hist,
        "wealth": wealth_hist,
        "subset": subset_hist,
    }


def get_testing_config(**kwargs):
    # Build the testing configuration (stored in the Streamlit session state)
    # from keyword overrides, falling back to defaults.
    testing_config = st.session_state.testing_config = ml_collections.ConfigDict()
    testing_config.significance_level = kwargs.get("significance_level", 0.05)
    testing_config.wealth = kwargs.get("wealth", "ons")
    testing_config.bet = kwargs.get("bet", "tanh")
    testing_config.kernel = kwargs.get("kernel", "rbf")
    testing_config.kernel_scale_method = kwargs.get("kernel_scale_method", "quantile")
    testing_config.kernel_scale = kwargs.get("kernel_scale", 0.5)
    testing_config.tau_max = kwargs.get("tau_max", 200)
    testing_config.r = kwargs.get("r", 10)
    return testing_config


def load_precomputed_results(image_name):
    # Load precomputed results shipped with the app for the given image.
    results = np.load(
        os.path.join("assets", "results", f"{image_name.split('.')[0]}.npy"),
        allow_pickle=True,
    ).item()
    return results


def test(
    testing_config,
    image,
    class_name,
    concepts,
    cardinality,
    dataset_name,
    model_name,
    device=c.DEVICE,
    with_streamlit=True,
):
    config = Config()
    config.data.dataset = dataset_name
    config.data.backbone = model_name

    if with_streamlit:
        with st.spinner("Loading model"):
            model = multimodal.get_model(config, device=device)
    else:
        model = multimodal.get_model(config, device=device)

    if with_streamlit:
        with st.spinner("Encoding concepts"):
            cbm = _encode_concepts(model, concepts)
    else:
        cbm = _encode_concepts(model, concepts)

    if with_streamlit:
        with st.spinner("Encoding image"):
            h = _encode_image(model, image)
    else:
        h = _encode_image(model, image)

    # Project the image embedding onto the concept directions.
    z = h @ cbm.T
    z = z.squeeze()

    if with_streamlit:
        progress_bar = st.progress(
            0,
            text=(
                "Testing concepts (can take up to a minute) [0 /"
                f" {len(concepts)} completed]"
            ),
        )
        progress_bar.progress(
            1 / (len(concepts) + 1),
            text=(
                "Testing concepts (can take up to a minute) [0 /"
                f" {len(concepts)} completed]"
            ),
        )

    # Fit the conditional KDE sampler on the training embeddings and their
    # concept projections.
    embedding = _load_embedding(config)
    semantics = embedding @ cbm.T
    sampler = cKDE(embedding, semantics)

    classifier = _encode_class_name(model, class_name)

    # Test each concept concurrently.
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(
                _test,
                testing_config,
                z,
                concept_idx,
                concepts,
                cardinality,
                sampler,
                classifier,
            )
            for concept_idx in range(len(concepts))
        ]

        results = []
        for idx, future in enumerate(as_completed(futures)):
            results.append(future.result())

            if with_streamlit:
                progress_bar.progress(
                    (idx + 2) / (len(concepts) + 1),
                    text=(
                        f"Testing concepts (can take up to a minute) [{idx + 1} /"
                        f" {len(concepts)} completed]"
                    ),
                )

    # Aggregate the per-concept results into (r, n_concepts)-shaped arrays.
    rejected = np.empty((testing_config.r, len(concepts)))
    tau = np.empty((testing_config.r, len(concepts)))
    wealth = np.empty((testing_config.r, testing_config.tau_max, len(concepts)))
    for _results in results:
        concept_idx = concepts.index(_results["concept"])

        rejected[:, concept_idx] = np.array(_results["rejected"])
        tau[:, concept_idx] = np.array(_results["tau"])
        wealth[:, :, concept_idx] = np.array(_results["wealth"])

    # Normalize rejection times to [0, 1].
    tau /= testing_config.tau_max

    results = {
        "significance_level": testing_config.significance_level,
        "concepts": concepts,
        "rejected": rejected,
        "tau": tau,
        "wealth": wealth,
    }
    return results
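

# Illustrative usage sketch (an assumption, not part of the original module):
# the image path, class name, concept list, dataset name, and backbone name
# below are placeholders and must match values supported by app_lib.config and
# app_lib.multimodal. Guarded so that importing this module stays side-effect
# free.
if __name__ == "__main__":
    from PIL import Image

    example_testing_config = get_testing_config(significance_level=0.05, r=10)
    example_image = Image.open("assets/images/example.jpg")  # placeholder path

    example_results = test(
        example_testing_config,
        example_image,
        class_name="cat",  # placeholder class name
        concepts=["whiskers", "fur", "tail", "collar"],  # placeholder concepts
        cardinality=2,
        dataset_name="imagenette",  # placeholder dataset
        model_name="clip:ViT-B/32",  # placeholder backbone
        with_streamlit=False,
    )
    # Average rejection rate of the null hypothesis per concept across the r
    # repetitions.
    print(
        dict(
            zip(
                example_results["concepts"],
                example_results["rejected"].mean(axis=0),
            )
        )
    )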