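# Spaces: Sleeping  (page-header residue captured with the source, apparently a Hugging Face Spaces status banner; not part of the program)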
import torch | |
import clip | |
import open_clip | |
import h5py | |
import streamlit as st | |
import numpy as np | |
import time | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
import ml_collections | |
from huggingface_hub import hf_hub_download | |
from ibydmt.test import xSKIT | |
from app_lib.utils import SUPPORTED_MODELS | |
from app_lib.ckde import cKDE | |
# Module-wide, unseeded random generator; used to draw the random concept subsets.
rng = np.random.default_rng()

# Options handed to every xSKIT tester built in this module. `r` is also read
# directly by `_test` as the number of random subsets drawn per concept.
testing_config = ml_collections.ConfigDict(
    {
        "significance_level": 0.05,
        "wealth": "ons",
        "bet": "tanh",
        "kernel": "rbf",
        "kernel_scale_method": "quantile",
        "kernel_scale": 0.5,
        "tau_max": 200,
        "r": 10,
    }
)
def _get_open_clip_model(model_name, device):
    """Create an OpenCLIP model (in eval mode), its image transform and tokenizer.

    ``model_name`` is looked up in ``SUPPORTED_MODELS`` to obtain the identifier
    given to ``open_clip.create_model_and_transforms``; the text after the last
    ``":"`` in ``model_name`` is the name given to ``open_clip.get_tokenizer``.

    Returns:
        A ``(model, preprocess, tokenizer)`` tuple.
    """
    _, _, tokenizer_name = model_name.rpartition(":")

    # Of the three values returned, the middle one is not used here.
    model, _unused, preprocess = open_clip.create_model_and_transforms(
        SUPPORTED_MODELS[model_name], device=device
    )
    model.eval()

    return model, preprocess, open_clip.get_tokenizer(tokenizer_name)
def _get_clip_model(model_name, device):
    """Load a model through the ``clip`` package together with its tokenizer.

    The text after the last ``":"`` in ``model_name`` is the name passed to
    ``clip.load``; ``clip.tokenize`` is returned as the tokenizer callable.

    Returns:
        A ``(model, preprocess, tokenizer)`` tuple.
    """
    _, _, clip_name = model_name.rpartition(":")
    model, preprocess = clip.load(clip_name, device=device)
    return model, preprocess, clip.tokenize
def load_dataset(dataset_name, model_name):
    """Fetch the training-embedding file for a dataset/model pair and read it.

    The file ``{dataset_name}_{model_name}_train.h5`` is obtained from the
    ``jacopoteneggi/IBYDMT`` dataset repository with ``hf_hub_download`` and
    its ``"embedding"`` entry is read in full.

    Returns:
        The contents of the ``"embedding"`` entry of the HDF5 file.
    """
    h5_filename = f"{dataset_name}_{model_name}_train.h5"
    local_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=h5_filename,
        repo_type="dataset",
    )

    # `[:]` materialises the data before the file handle is closed.
    with h5py.File(local_path, "r") as h5_file:
        return h5_file["embedding"][:]
def load_model(model_name, device):
    """Load the model, image transform and tokenizer selected by ``model_name``.

    Names containing ``"open_clip"`` are loaded with ``_get_open_clip_model``;
    any other name containing ``"clip"`` is loaded with ``_get_clip_model``.

    Args:
        model_name: Model identifier string.
        device: Device forwarded to the underlying loader.

    Returns:
        A ``(model, preprocess, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_name`` matches neither family.
    """
    # "open_clip" must be tested first: "clip" is a substring of it.
    if "open_clip" in model_name:
        return _get_open_clip_model(model_name, device)
    if "clip" in model_name:
        return _get_clip_model(model_name, device)

    # Previously this fell through to an UnboundLocalError at the return.
    raise ValueError(
        f"Unsupported model {model_name!r}: expected a name containing "
        "'open_clip' or 'clip'"
    )
def encode_concepts(tokenizer, model, concepts, device):
    """Encode concept strings into L2-normalised text features.

    Args:
        tokenizer: Callable that turns ``concepts`` into a tensor of tokens.
        model: Object exposing ``encode_text``.
        concepts: The concept strings, passed to ``tokenizer`` as-is.
        device: Device the tokens are moved to before encoding.

    Returns:
        A NumPy array holding the features returned by ``model.encode_text``,
        each divided by its norm along the last axis.
    """
    # Inference only: without no_grad the features would carry autograd state,
    # and calling .numpy() on a tensor that requires grad raises RuntimeError.
    with torch.no_grad():
        tokens = tokenizer(concepts).to(device)
        features = model.encode_text(tokens)
        # Out-of-place division leaves the tensor returned by the model intact.
        features = features / torch.linalg.norm(features, dim=-1, keepdim=True)
    return features.cpu().numpy()
def encode_image(model, preprocess, image, device):
    """Encode a single image into an L2-normalised feature vector.

    Args:
        model: Object exposing ``encode_image``.
        preprocess: Callable that converts ``image`` into a tensor.
        image: The input image, passed to ``preprocess`` as-is.
        device: Device the preprocessed tensor is moved to before encoding.

    Returns:
        A NumPy array holding the features returned by ``model.encode_image``
        for a batch of one, divided by their norm along the last axis.
    """
    # Inference only: without no_grad the features would carry autograd state,
    # and calling .numpy() on a tensor that requires grad raises RuntimeError.
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch axis for the single image.
        batch = preprocess(image).unsqueeze(0).to(device)
        features = model.encode_image(batch)
        # Out-of-place division leaves the tensor returned by the model intact.
        features = features / features.norm(dim=-1, keepdim=True)
    return features.cpu().numpy()
def encode_class_name(tokenizer, model, class_name, device):
    """Encode the prompt ``"A photo of a {class_name}"`` into normalised features.

    Args:
        tokenizer: Callable that turns a list of prompts into a tensor of tokens.
        model: Object exposing ``encode_text``.
        class_name: Class name interpolated into the prompt.
        device: Device the tokens are moved to before encoding.

    Returns:
        A NumPy array holding the features returned by ``model.encode_text``
        for the single prompt, divided by their norm along the last axis.
    """
    # Inference only: without no_grad the features would carry autograd state,
    # and calling .numpy() on a tensor that requires grad raises RuntimeError.
    with torch.no_grad():
        tokens = tokenizer([f"A photo of a {class_name}"]).to(device)
        features = model.encode_text(tokens)
        # Out-of-place division leaves the tensor returned by the model intact.
        features = features / torch.linalg.norm(features, dim=-1, keepdim=True)
    return features.cpu().numpy()
def _sample_random_subset(concept_idx, concepts, cardinality):
    """Draw up to ``cardinality`` random concept indices, never ``concept_idx``.

    The indices of all concepts other than ``concept_idx`` are shuffled with
    the module-level ``rng`` and the first ``cardinality`` are kept, so fewer
    are returned when fewer are available.

    Returns:
        A list of integer indices into ``concepts``.
    """
    candidates = set(range(len(concepts)))
    candidates.discard(concept_idx)
    shuffled = rng.permutation(list(candidates))
    return shuffled[:cardinality].tolist()
def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
    """Run ``testing_config.r`` xSKIT tests for one concept on random subsets.

    Each repetition draws a fresh random subset of the other concepts, builds a
    new ``xSKIT`` tester and runs it with ``interrupt_on_rejection=False``.

    Returns:
        A dict with the concept name under ``"concept"`` and, under
        ``"rejected"``, ``"tau"``, ``"wealth"`` and ``"subset"``, one list entry
        per repetition.
    """

    def cond_p(z, cond_idx, m):
        # The sampler returns a pair; only its second element is passed on.
        _, sampled_h = sampler.sample(z, cond_idx, m=m)
        return sampled_h

    def f(h):
        return (h @ classifier.T).squeeze()

    history = {"rejected": [], "tau": [], "wealth": [], "subset": []}

    for _ in range(testing_config.r):
        subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
        subset_names = [concepts[i] for i in subset_idx]

        tester = xSKIT(testing_config)
        rejected, tau = tester.test(
            z, concept_idx, subset_idx, cond_p, f, interrupt_on_rejection=False
        )
        # The wealth values are read from the tester after the run.
        wealth = tester.wealth._wealth

        history["rejected"].append(rejected)
        history["tau"].append(tau)
        history["wealth"].append(wealth)
        history["subset"].append(subset_names)

    return {"concept": concepts[concept_idx], **history}
def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
    """Run the per-concept tests for one image from within the Streamlit app.

    Loads the model, encodes the concepts, the image and the class prompt,
    builds a ``cKDE`` sampler from the stored embeddings, and runs ``_test``
    for every concept on a thread pool while updating a progress bar. When all
    tests have finished, ``st.session_state.disabled`` is reset to ``False``
    and a rerun of the app is requested.

    Args:
        image: Image to test; passed to the model's preprocessing transform.
        class_name: Class name used to build the classifier prompt.
        concepts: List of concept strings.
        cardinality: Maximum size of each random conditioning subset.
        dataset_name: Currently unused (see the note in the body).
        model_name: Model identifier passed to ``load_model``/``load_dataset``.
        device: Device used for model inference.

    Returns:
        None. The collected results are not returned or displayed yet.
    """
    with st.spinner("Loading model"):
        model, preprocess, tokenizer = load_model(model_name, device)

    with st.spinner("Encoding concepts"):
        cbm = encode_concepts(tokenizer, model, concepts, device)

    with st.spinner("Encoding image"):
        h = encode_image(model, preprocess, image, device)
    # Project the image features onto the concept features.
    z = (h @ cbm.T).squeeze()

    with st.spinner("Testing"):
        progress_bar = st.progress(0)

        # NOTE(review): `dataset_name` is ignored and "imagenette" is always
        # loaded -- confirm whether this is intentional before changing it.
        embedding = load_dataset("imagenette", model_name)
        semantics = embedding @ cbm.T
        sampler = cKDE(embedding, semantics)

        classifier = encode_class_name(tokenizer, model, class_name, device)

        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(
                    _test, z, concept_idx, concepts, cardinality, sampler, classifier
                )
                for concept_idx in range(len(concepts))
            ]

            results = []
            # Progress is advanced from this thread as each worker finishes.
            for done, future in enumerate(as_completed(futures), start=1):
                results.append(future.result())
                progress_bar.progress(done / len(concepts))

    # TODO: `results` is collected but not used yet; the unfinished plan was to
    # gather per-concept wealth into a (tau_max, len(concepts)) array.

    st.session_state.disabled = False
    # `st.experimental_rerun` was deprecated in favour of `st.rerun` and is
    # absent from recent Streamlit releases; fall back only on older versions.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()