# IBYDMT/app_lib/test.py
from concurrent.futures import ThreadPoolExecutor, as_completed

import clip
import h5py
import ml_collections
import numpy as np
import open_clip
import streamlit as st
import torch
from huggingface_hub import hf_hub_download

from app_lib.ckde import cKDE
from app_lib.utils import SUPPORTED_MODELS
from ibydmt.test import xSKIT
rng = np.random.default_rng()
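# Configuration of the sequential, betting-based concept importance test.
# The fields mirror the ibydmt.test.xSKIT options: significance level of the
# test, wealth process ("ons", i.e. online Newton step) and betting function
# ("tanh"), an RBF kernel with a quantile-based bandwidth, the maximum number
# of test steps (tau_max), and the number of repetitions per concept (r).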
testing_config = ml_collections.ConfigDict()
testing_config.significance_level = 0.05
testing_config.wealth = "ons"
testing_config.bet = "tanh"
testing_config.kernel = "rbf"
testing_config.kernel_scale_method = "quantile"
testing_config.kernel_scale = 0.5
testing_config.tau_max = 200
testing_config.r = 10
def _get_open_clip_model(model_name, device):
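    """Load an OpenCLIP model, its image preprocessing transform, and the
    tokenizer for its text backbone.

    model_name is expected to look like "open_clip:<backbone>";
    SUPPORTED_MODELS maps it to the OpenCLIP spec used to build the model."""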
backbone = model_name.split(":")[-1]
model, _, preprocess = open_clip.create_model_and_transforms(
SUPPORTED_MODELS[model_name], device=device
)
model.eval()
tokenizer = open_clip.get_tokenizer(backbone)
return model, preprocess, tokenizer
def _get_clip_model(model_name, device):
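    """Load an OpenAI CLIP model ("clip:<backbone>"), its preprocessing
    transform, and clip.tokenize as the tokenizer."""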
backbone = model_name.split(":")[-1]
model, preprocess = clip.load(backbone, device=device)
tokenizer = clip.tokenize
return model, preprocess, tokenizer
def load_dataset(dataset_name, model_name):
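    """Download the precomputed train-split image embeddings for the given
    (dataset, model) pair from the IBYDMT dataset repository on the Hub and
    return them as a NumPy array."""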
dataset_path = hf_hub_download(
repo_id="jacopoteneggi/IBYDMT",
filename=f"{dataset_name}_{model_name}_train.h5",
repo_type="dataset",
)
with h5py.File(dataset_path, "r") as dataset:
embedding = dataset["embedding"][:]
return embedding
def load_model(model_name, device):
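    """Return (model, preprocess, tokenizer) for the requested backbone,
    dispatching between OpenCLIP and OpenAI CLIP based on the model name."""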
if "open_clip" in model_name:
model, preprocess, tokenizer = _get_open_clip_model(model_name, device)
elif "clip" in model_name:
model, preprocess, tokenizer = _get_clip_model(model_name, device)
return model, preprocess, tokenizer
@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_concepts(tokenizer, model, concepts, device):
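    """Encode a list of concept names with the text encoder and return the
    L2-normalized embeddings as a NumPy array of shape (n_concepts, d)."""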
concepts_text = tokenizer(concepts).to(device)
concept_features = model.encode_text(concepts_text)
concept_features /= torch.linalg.norm(concept_features, dim=-1, keepdim=True)
return concept_features.cpu().numpy()
@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_image(model, preprocess, image, device):
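    """Preprocess the input image, encode it with the image encoder, and
    return the L2-normalized embedding as a NumPy array of shape (1, d)."""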
image = preprocess(image)
image = image.unsqueeze(0)
image = image.to(device)
image_features = model.encode_image(image)
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy()
@torch.no_grad()
@torch.cuda.amp.autocast()
def encode_class_name(tokenizer, model, class_name, device):
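    """Encode the prompt "A photo of a {class_name}" and return its
    L2-normalized text embedding, used as a zero-shot classifier for the
    class."""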
class_text = tokenizer([f"A photo of a {class_name}"]).to(device)
class_features = model.encode_text(class_text)
class_features /= torch.linalg.norm(class_features, dim=-1, keepdim=True)
return class_features.cpu().numpy()
def _sample_random_subset(concept_idx, concepts, cardinality):
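    """Sample a random subset of concept indices of the given cardinality,
    excluding concept_idx."""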
sample_idx = list(set(range(len(concepts))) - {concept_idx})
return rng.permutation(sample_idx)[:cardinality].tolist()
def _test(z, concept_idx, concepts, cardinality, sampler, classifier):
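    """Test the semantic importance of concepts[concept_idx] for the concept
    representation z of one image.

    Runs testing_config.r independent xSKIT tests, each conditioning on a
    fresh random subset of the remaining concepts, and returns the histories
    of rejections, stopping times, wealth trajectories, and sampled subsets."""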
def cond_p(z, cond_idx, m):
_, sample_h = sampler.sample(z, cond_idx, m=m)
return sample_h
def f(h):
output = h @ classifier.T
return output.squeeze()
rejected_hist, tau_hist, wealth_hist, subset_hist = [], [], [], []
for _ in range(testing_config.r):
subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
subset = [concepts[idx] for idx in subset_idx]
tester = xSKIT(testing_config)
rejected, tau = tester.test(
z, concept_idx, subset_idx, cond_p, f, interrupt_on_rejection=False
)
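        # Read the accumulated wealth trajectory from the tester's internal
        # (private) `_wealth` attribute.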
wealth = tester.wealth._wealth
rejected_hist.append(rejected)
tau_hist.append(tau)
wealth_hist.append(wealth)
subset_hist.append(subset)
return {
"concept": concepts[concept_idx],
"rejected": rejected_hist,
"tau": tau_hist,
"wealth": wealth_hist,
"subset": subset_hist,
}
def test(image, class_name, concepts, cardinality, dataset_name, model_name, device):
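    """Run the full testing pipeline for one image: encode the concepts and
    the image, project the image onto the concept bank, fit a cKDE sampler on
    a background set of training embeddings, encode the class prompt as a
    zero-shot classifier, and test every concept in parallel while updating a
    progress bar."""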
with st.spinner("Loading model"):
model, preprocess, tokenizer = load_model(model_name, device)
with st.spinner("Encoding concepts"):
cbm = encode_concepts(tokenizer, model, concepts, device)
with st.spinner("Encoding image"):
h = encode_image(model, preprocess, image, device)
z = h @ cbm.T
z = z.squeeze()
with st.spinner("Testing"):
progress_bar = st.progress(0)
        embedding = load_dataset(dataset_name, model_name)
semantics = embedding @ cbm.T
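        # cKDE is a conditional kernel density estimator over (embedding,
        # concept-semantics) pairs; it is used to resample embeddings
        # consistent with a conditioning subset of concepts during testing.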
sampler = cKDE(embedding, semantics)
classifier = encode_class_name(tokenizer, model, class_name, device)
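        # Test each concept in parallel; the progress bar advances as each
        # concept's results come back.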
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
_test, z, concept_idx, concepts, cardinality, sampler, classifier
)
for concept_idx in range(len(concepts))
]
results = []
for idx, future in enumerate(as_completed(futures)):
results.append(future.result())
progress_bar.progress((idx + 1) / len(concepts))
    # NOTE (assumption): this snippet never uses `results` after the loop;
    # storing them in the session state is a placeholder hook so the UI can
    # read them after the rerun below.
    st.session_state.results = results
    st.session_state.disabled = False
    # st.experimental_rerun() has since been deprecated in favor of
    # st.rerun(); it is kept as-is here.
    st.experimental_rerun()