# IBYDMT / app_lib/test.py
# Author: jacopoteneggi — last commit "Update" (dffe47c, verified)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
import ml_collections
import numpy as np
import pandas as pd
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
import app_lib.multimodal as multimodal
from app_lib.ckde import cKDE
from app_lib.config import Config
from app_lib.config import Constants as c
from ibydmt.test import xSKIT
# Process-wide random generator used by _sample_random_subset to draw the
# conditioning subsets. It is deliberately left unseeded (seed=None pulls
# fresh OS entropy), so subset draws differ from run to run.
rng = np.random.default_rng(seed=None)
@torch.no_grad()
@torch.cuda.amp.autocast()
def _encode_concepts(model, concepts):
    """Embed the concept prompts with ``model.encode_text`` and L2-normalise them.

    Runs without autograd and under CUDA autocast. Each vector along the last
    axis of the result is scaled to unit Euclidean norm. The tensor is moved
    to the CPU and returned as a NumPy array, presumably shaped
    ``(len(concepts), dim)``; confirm against the model wrapper.
    """
    text_emb = model.encode_text(concepts)
    row_norms = torch.linalg.norm(text_emb, dim=-1, keepdim=True)
    text_emb /= row_norms
    return text_emb.cpu().numpy()
@torch.no_grad()
@torch.cuda.amp.autocast()
def _encode_image(model, image):
    """Embed ``image`` with ``model.encode_image`` and L2-normalise the result.

    Runs without autograd and under CUDA autocast. The last axis of the
    returned CPU NumPy array has unit Euclidean norm. ``image`` is forwarded
    as-is; presumably it is an already-preprocessed tensor (confirm with the
    caller).
    """
    visual = model.encode_image(image)
    magnitude = visual.norm(dim=-1, keepdim=True)
    visual /= magnitude
    return visual.cpu().numpy()
@torch.no_grad()
@torch.cuda.amp.autocast()
def _encode_class_name(model, class_name):
    """Embed the prompt ``"A photo of a <class_name>"`` and L2-normalise it.

    The prompt is sent to ``model.encode_text`` as a one-element list. The
    resulting embedding, scaled to unit norm along its last axis, is returned
    as a CPU NumPy array. ``_test`` uses it as the linear classifier head.
    """
    prompt = "A photo of a " + f"{class_name}"
    prompt_emb = model.encode_text([prompt])
    prompt_norm = torch.linalg.norm(prompt_emb, dim=-1, keepdim=True)
    prompt_emb /= prompt_norm
    return prompt_emb.cpu().numpy()
def _load_embedding(config):
    """Fetch the training-set embeddings for ``config`` from the Hugging Face Hub.

    Downloads ``<dataset>_train_<backbone>.parquet`` from the
    ``jacopoteneggi/IBYDMT`` dataset repository. The dataset name is
    lower-cased and the backbone comes from ``config.backbone_name()``. The
    file's ``embedding`` column is then stacked into one NumPy array,
    presumably shaped ``(num_samples, dim)``; confirm against the parquet
    schema.
    """
    parquet_name = (
        f"{config.data.dataset.lower()}_train_{config.backbone_name()}.parquet"
    )
    local_path = hf_hub_download(
        repo_id="jacopoteneggi/IBYDMT",
        filename=parquet_name,
        repo_type="dataset",
    )
    frame = pd.read_parquet(local_path)
    embedding_rows = frame["embedding"].values.tolist()
    return np.array(embedding_rows)
def _sample_random_subset(concept_idx, concepts, cardinality):
    """Draw a random conditioning subset of concept indices that excludes ``concept_idx``.

    All other indices into ``concepts`` are shuffled with the module-level
    ``rng`` and the first ``cardinality`` are kept. The result has fewer
    entries when not enough other concepts exist. Indices are returned as
    plain Python ints.
    """
    candidates = [idx for idx in range(len(concepts)) if idx != concept_idx]
    shuffled = rng.permutation(candidates)
    return shuffled[:cardinality].tolist()
def _test(testing_config, z, concept_idx, concepts, cardinality, sampler, classifier):
    """Run ``testing_config.r`` xSKIT tests for the concept at ``concept_idx``.

    Each repetition draws a fresh random conditioning subset of the other
    concepts via ``_sample_random_subset`` and runs ``xSKIT(testing_config).test``.
    Betting stops early once wealth reaches ``3 / significance_level``.

    Returns:
        A dict with the tested concept name (``"concept"``) and one entry per
        repetition in each of:
        ``"rejected"`` / ``"tau"``: the two values returned by ``tester.test``;
        ``"wealth"``: the tester's wealth trajectory, padded by repeating its
        last value up to ``testing_config.tau_max`` entries;
        ``"subset"``: the concept names used as the conditioning subset.
    """

    def cond_p(z, cond_idx, m):
        # Conditional sampler handed to xSKIT: only the second output of
        # sampler.sample is forwarded.
        _, sample_h = sampler.sample(z, cond_idx, m=m)
        return sample_h

    def f(h):
        # Classifier score: project the embedding(s) onto the class prompt.
        return (h @ classifier.T).squeeze()

    tau_max = testing_config.tau_max
    wealth_cap = 3 * 1 / testing_config.significance_level

    history = {"rejected": [], "tau": [], "wealth": [], "subset": []}
    for _ in range(testing_config.r):
        subset_idx = _sample_random_subset(concept_idx, concepts, cardinality)
        # Resolve names before testing, in case the tester touches subset_idx.
        subset_names = [concepts[idx] for idx in subset_idx]

        tester = xSKIT(testing_config)
        rejected, tau = tester.test(
            z,
            concept_idx,
            subset_idx,
            cond_p,
            f,
            interrupt_on="max_wealth",
            max_wealth=wealth_cap,
        )

        trajectory = tester.wealth._wealth
        # Pad early-stopped runs so every trajectory has tau_max entries.
        padded = trajectory + [trajectory[-1]] * (tau_max - len(trajectory))

        history["rejected"].append(rejected)
        history["tau"].append(tau)
        history["wealth"].append(padded)
        history["subset"].append(subset_names)

    return {"concept": concepts[concept_idx], **history}
def get_testing_config(**kwargs):
    """Build the xSKIT testing configuration and publish it to session state.

    A fresh ``ml_collections.ConfigDict`` is stored as
    ``st.session_state.testing_config`` and then populated. Each field takes
    the matching keyword argument when one is given, otherwise the default
    listed below. Keyword arguments that do not name a field are ignored.

    Returns:
        The populated ConfigDict (the same object stored in session state).
    """
    field_defaults = (
        ("significance_level", 0.05),
        ("wealth", "ons"),
        ("bet", "tanh"),
        ("kernel", "rbf"),
        ("kernel_scale_method", "quantile"),
        ("kernel_scale", 0.5),
        ("tau_max", 200),
        ("r", 10),
    )

    testing_config = ml_collections.ConfigDict()
    st.session_state.testing_config = testing_config
    for field_name, default_value in field_defaults:
        setattr(testing_config, field_name, kwargs.get(field_name, default_value))
    return testing_config
def load_precomputed_results(image_name):
    """Load cached testing results for ``image_name`` from ``assets/results``.

    The file read is ``assets/results/<stem>.npy``, relative to the current
    working directory. ``<stem>`` is ``image_name`` with only its final
    extension removed, so ``"my.image.jpg"`` maps to ``my.image.npy``; the
    previous ``split('.')[0]`` truncated it to ``my.npy``.

    Returns:
        The Python object stored in the ``.npy`` file, unwrapped with
        ``.item()``. Presumably this is the results dict produced by ``test``;
        confirm against how the assets were generated.
    """
    stem, _ = os.path.splitext(image_name)
    results_path = os.path.join("assets", "results", f"{stem}.npy")
    # allow_pickle is needed to restore the stored Python object. Unpickling
    # can execute arbitrary code, so this must only be used on trusted,
    # locally bundled files.
    return np.load(results_path, allow_pickle=True).item()
def _maybe_spinner(text, with_streamlit):
    """Return ``st.spinner(text)`` inside the Streamlit app, else a no-op context."""
    from contextlib import nullcontext

    return st.spinner(text) if with_streamlit else nullcontext()


def _testing_progress_text(completed, total):
    """Format the progress-bar label for ``completed`` out of ``total`` concepts."""
    return (
        f"Testing concepts (can take up to a minute) [{completed} /"
        f" {total} completed]"
    )


def test(
    testing_config,
    image,
    class_name,
    concepts,
    cardinality,
    dataset_name,
    model_name,
    device=c.DEVICE,
    with_streamlit=True,
):
    """Test every concept in ``concepts`` on ``image`` for class ``class_name``.

    Loads the multimodal model for ``dataset_name``/``model_name`` and encodes
    the concepts, the image and the prompt ``"A photo of a <class_name>"``.
    It then fits a ``cKDE`` sampler on the precomputed training embeddings and
    runs ``_test`` for each concept concurrently in a thread pool.

    Args:
        testing_config: ConfigDict as built by ``get_testing_config``. This
            function reads ``r``, ``tau_max`` and ``significance_level``.
        image: Forwarded to ``model.encode_image``; presumably a preprocessed
            image tensor (confirm with the caller).
        class_name: Class label inserted into the classifier prompt.
        concepts: List of concept prompts (strings) to test.
        cardinality: Size of the random conditioning subset per test.
        dataset_name: Stored as ``config.data.dataset``.
        model_name: Stored as ``config.data.backbone``.
        device: Device passed to ``multimodal.get_model``.
        with_streamlit: When true, show Streamlit spinners and a progress bar.

    Returns:
        Dict with ``"significance_level"``, ``"concepts"``, ``"rejected"`` and
        ``"tau"`` (arrays of shape ``(r, len(concepts))``, with ``tau``
        divided by ``tau_max``), and ``"wealth"`` (shape
        ``(r, tau_max, len(concepts))``). Column ``i`` of each array belongs
        to ``concepts[i]``.
    """
    config = Config()
    config.data.dataset = dataset_name
    config.data.backbone = model_name

    with _maybe_spinner("Loading model", with_streamlit):
        model = multimodal.get_model(config, device=device)
    with _maybe_spinner("Encoding concepts", with_streamlit):
        cbm = _encode_concepts(model, concepts)
    with _maybe_spinner("Encoding image", with_streamlit):
        h = _encode_image(model, image)

    # Concept activations of the image: its embedding projected onto each
    # concept embedding.
    z = (h @ cbm.T).squeeze()

    num_concepts = len(concepts)
    progress_bar = None
    if with_streamlit:
        progress_bar = st.progress(0, text=_testing_progress_text(0, num_concepts))
        # The bar has num_concepts + 1 steps: one before the embeddings are
        # loaded, then one per completed concept.
        progress_bar.progress(
            1 / (num_concepts + 1),
            text=_testing_progress_text(0, num_concepts),
        )

    embedding = _load_embedding(config)
    semantics = embedding @ cbm.T
    sampler = cKDE(embedding, semantics)
    classifier = _encode_class_name(model, class_name)

    r = testing_config.r
    tau_max = testing_config.tau_max
    rejected = np.empty((r, num_concepts))
    tau = np.empty((r, num_concepts))
    wealth = np.empty((r, tau_max, num_concepts))

    with ThreadPoolExecutor() as executor:
        # Key each future by its concept index. Results arrive out of order,
        # and mapping them back by concept name (concepts.index) would send
        # duplicate names to the same column and leave other columns of the
        # np.empty arrays uninitialised.
        future_to_idx = {
            executor.submit(
                _test,
                testing_config,
                z,
                concept_idx,
                concepts,
                cardinality,
                sampler,
                classifier,
            ): concept_idx
            for concept_idx in range(num_concepts)
        }
        for completed, future in enumerate(as_completed(future_to_idx), start=1):
            concept_idx = future_to_idx[future]
            concept_results = future.result()
            rejected[:, concept_idx] = np.array(concept_results["rejected"])
            tau[:, concept_idx] = np.array(concept_results["tau"])
            wealth[:, :, concept_idx] = np.array(concept_results["wealth"])
            if progress_bar is not None:
                progress_bar.progress(
                    (completed + 1) / (num_concepts + 1),
                    text=_testing_progress_text(completed, num_concepts),
                )

    # Express tau as a fraction of the maximum number of testing steps.
    tau /= tau_max
    return {
        "significance_level": testing_config.significance_level,
        "concepts": concepts,
        "rejected": rejected,
        "tau": tau,
        "wealth": wealth,
    }