Spaces:
Sleeping
Sleeping
import json | |
import os | |
import streamlit as st | |
from PIL import Image | |
from streamlit_image_select import image_select | |
from app_lib.defaults import Defaults as d | |
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS | |
# Directory that holds the bundled example images.
IMAGE_DIR = os.path.join("assets", "images")
# File names of the bundled ``.jpg`` images, in sorted order.
IMAGE_NAMES = sorted(name for name in os.listdir(IMAGE_DIR) if name.endswith(".jpg"))
# Paths of the bundled images, index-aligned with ``IMAGE_NAMES``.
IMAGE_PATHS = [os.path.join(IMAGE_DIR, name) for name in IMAGE_NAMES]
# Presets loaded from JSON. The file is read through a context manager so the
# handle is closed deterministically, and decoded as UTF-8 (the JSON default)
# instead of whatever the platform locale happens to be.
with open("assets/image_presets.json", encoding="utf-8") as _presets_file:
    IMAGE_PRESETS = json.load(_presets_file)
def _validate_class_name(class_name): | |
if class_name is None: | |
return (False, "Class name cannot be empty.") | |
if class_name.strip() == "": | |
return (False, "Class name cannot be empty.") | |
return (True, None) | |
def _validate_concepts(concepts): | |
if len(concepts) < 3: | |
return (False, "You must provide at least 3 concepts") | |
if len(concepts) > 10: | |
return (False, "Maximum 10 concepts allowed") | |
return (True, None) | |
def _get_significance_level():
    """Render the "Significance level" slider and return its current value.

    The slider runs from one step up to 1.0, starts at the configured default,
    and is disabled while ``st.session_state.disabled`` is set.
    """
    initial_level = d.SIGNIFICANCE_LEVEL_VALUE
    increment = d.SIGNIFICANCE_LEVEL_STEP
    help_text = (
        f"The level of significance of the tests. Defaults to {initial_level:.2F}."
    )
    return st.slider(
        "Significance level",
        min_value=increment,
        max_value=1.0,
        value=initial_level,
        step=increment,
        help=help_text,
        disabled=st.session_state.disabled,
    )
def _get_tau_max():
    """Render the "Length of test" slider and return its value as an ``int``.

    The slider runs from one step up to 1000, starts at the configured default,
    and is disabled while ``st.session_state.disabled`` is set.
    """
    initial_length = d.TAU_MAX_VALUE
    increment = d.TAU_MAX_STEP
    help_text = (
        f"The maximum number of steps for each test. Defaults to {initial_length}."
    )
    selected = st.slider(
        "Length of test",
        min_value=increment,
        max_value=1000,
        value=initial_length,
        step=increment,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_number_of_tests():
    """Render the "Number of tests per concept" slider; return it as an ``int``.

    The slider runs from one step up to 100, starts at the configured default,
    and is disabled while ``st.session_state.disabled`` is set.
    """
    initial_count = d.R_VALUE
    increment = d.R_STEP
    help_text = (
        "The number of tests to average for each concept. "
        f"Defaults to {initial_count}."
    )
    selected = st.slider(
        "Number of tests per concept",
        min_value=increment,
        max_value=100,
        value=initial_count,
        step=increment,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_cardinality(concepts, concepts_ready):
    """Render the "Size of conditioning set" slider and return its value.

    Args:
        concepts: The current collection of concepts; its length bounds the
            slider's maximum.
        concepts_ready: Whether the concepts passed validation. The slider is
            disabled when this is false.

    Returns:
        The value selected in the slider.
    """
    default = d.CARDINALITY_VALUE
    step = d.CARDINALITY_STEP
    return st.slider(
        "Size of conditioning set",
        help=(
            "The number of concepts to condition model predictions on. "
            # This fragment must be an f-string, otherwise the tooltip shows
            # the literal text "{default}".
            f"Defaults to {default}."
        ),
        min_value=1,
        # At most all-but-one concept can be conditioned on; never let the
        # upper bound drop below 2 while few or no concepts are entered.
        max_value=max(2, len(concepts) - 1),
        value=default,
        step=step,
        disabled=st.session_state.disabled or not concepts_ready,
    )
def _get_dataset_name():
    """Render the "Dataset" select box and return the chosen dataset name.

    The options are ``SUPPORTED_DATASETS``, with the configured default dataset
    pre-selected. The widget is disabled while ``st.session_state.disabled``
    is set.
    """
    options = SUPPORTED_DATASETS
    default_idx = options.index(d.DATASET_NAME)
    return st.selectbox(
        "Dataset",
        options=options,
        index=default_idx,
        help=(
            # Trailing space keeps the two sentences from running together
            # once the adjacent literals are concatenated.
            "Name of the dataset to use to train sampler. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )
def get_model_name():
    """Render the "Model to test" select box and return the chosen model name.

    The options are the entries of ``SUPPORTED_MODELS``, with the configured
    default model pre-selected. The widget is disabled while
    ``st.session_state.disabled`` is set.
    """
    options = list(SUPPORTED_MODELS)
    default_idx = options.index(d.MODEL_NAME)
    return st.selectbox(
        "Model to test",
        options=options,
        index=default_idx,
        help=(
            # Trailing space and closing period keep the tooltip readable and
            # consistent with the other help texts in this module.
            "Name of the vision-language model to test the predictions of. "
            f"Defaults to {options[default_idx]}."
        ),
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user upload an image or pick one of the bundled examples.

    Returns:
        An ``(image_name, image)`` tuple. ``image_name`` is ``None`` for an
        uploaded file, otherwise the file name of the selected bundled image.
        ``image`` is the result of ``Image.open`` on the chosen source.
    """
    # NOTE(review): the whole upload-or-select flow is assumed to live in the
    # sidebar (the "or select one" label follows the uploader) - confirm
    # against the intended layout.
    with st.sidebar:
        upload = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if upload is not None:
            # Uploaded images have no preset, hence no name.
            return (None, Image.open(upload))

        preselected = IMAGE_NAMES.index("bowl_ace.jpg")
        chosen = image_select(
            label="or select one",
            images=IMAGE_PATHS,
            index=preselected,
            return_value="index",
        )
        chosen_name = IMAGE_NAMES[chosen]
        chosen_path = IMAGE_PATHS[chosen]
        return (chosen_name, Image.open(chosen_path))
def get_class_name(image_name=None):
    """Render the "Class to predict" text input and validate what was typed.

    Args:
        image_name: File name of a bundled image, or a falsy value. When given,
            the input is pre-filled with the ``"class_name"`` entry of the
            preset keyed by the part of the name before its first dot.

    Returns:
        A ``(class_name, class_ready, class_error)`` tuple, where the last two
        items come from ``_validate_class_name``.
    """
    if image_name:
        preset_key = image_name.split(".")[0]
        prefill = IMAGE_PRESETS[preset_key]["class_name"]
    else:
        prefill = ""

    typed_name = st.text_input(
        "Class to predict",
        value=prefill,
        placeholder="Type class name here",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        disabled=st.session_state.disabled,
    )
    is_ready, error = _validate_class_name(typed_name)
    return typed_name, is_ready, error
def get_concepts(image_name=None):
    """Render the "Concepts to test" text area and parse it into a list.

    Args:
        image_name: File name of a bundled image, or a falsy value. When given,
            the text area is pre-filled with the ``"concepts"`` entry of the
            preset keyed by the part of the name before its first dot, one
            concept per line.

    Returns:
        A ``(concepts, concepts_ready, concepts_error)`` tuple. ``concepts`` is
        the list of unique, stripped, non-empty lines in the order they were
        typed; the last two items come from ``_validate_concepts``.
    """
    default = (
        "\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
        if image_name
        else ""
    )
    raw_text = st.text_area(
        "Concepts to test",
        help=(
            "List of concepts to test the predictions of the model with. "
            "Write one concept per line. Maximum 10 concepts allowed."
        ),
        height=180,
        value=default,
        disabled=st.session_state.disabled,
        placeholder="Type one concept\nper line",
    )
    stripped_lines = (line.strip() for line in raw_text.split("\n"))
    # De-duplicate with ``dict.fromkeys`` rather than ``set``: a set loses the
    # order the user typed and its iteration order for strings varies between
    # processes, which made the returned list non-deterministic.
    concepts = list(dict.fromkeys(line for line in stripped_lines if line != ""))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and collect its values.

    Args:
        concepts: The current collection of concepts, forwarded to the
            cardinality slider.
        concepts_ready: Whether the concepts passed validation, forwarded to
            the cardinality slider.

    Returns:
        A ``(significance_level, tau_max, r, cardinality, dataset_name)``
        tuple. Note that the dataset widget is rendered first but its value is
        returned last.
    """
    with st.expander("Advanced settings"):
        chosen_dataset = _get_dataset_name()
        alpha = _get_significance_level()
        max_steps = _get_tau_max()
        tests_per_concept = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)

    # NOTE(review): the divider is assumed to sit below the expander rather
    # than inside it - confirm against the intended layout.
    st.divider()
    return alpha, max_steps, tests_per_concept, conditioning_size, chosen_dataset