Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
from streamlit_image_select import image_select | |
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS | |
def _validate_class_name(class_name): | |
if class_name is None: | |
return (False, "Class name cannot be empty.") | |
if class_name.strip() == "": | |
return (False, "Class name cannot be empty.") | |
return (True, None) | |
def _validate_concepts(concepts): | |
if len(concepts) < 3: | |
return (False, "You must provide at least 3 concepts") | |
if len(concepts) > 10: | |
return (False, "Maximum 10 concepts allowed") | |
return (True, None) | |
def _get_significance_level():
    """Render the significance-level slider and return the selected value.

    The slider runs from 0.01 to 1.0 in steps of 0.01 and starts at 0.05.
    It is disabled while ``st.session_state.disabled`` is truthy.
    """
    step, default = 0.01, 0.05
    help_text = " ".join(
        [
            "The level of significance of the tests.",
            f"Defaults to {default:.2F}.",
        ]
    )
    return st.slider(
        "Significance level",
        min_value=step,
        max_value=1.0,
        value=default,
        step=step,
        help=help_text,
        disabled=st.session_state.disabled,
    )
def _get_tau_max():
    """Render the "Length of test" slider and return its value as an int.

    The slider runs from 50 to 1000 in steps of 50 and starts at 200.
    It is disabled while ``st.session_state.disabled`` is truthy.
    """
    step, default = 50, 200
    help_text = " ".join(
        [
            "The maximum number of steps for each test.",
            f"Defaults to {default}.",
        ]
    )
    selected = st.slider(
        "Length of test",
        min_value=step,
        max_value=1000,
        value=default,
        step=step,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_number_of_tests():
    """Render the "Number of tests per concept" slider and return an int.

    The slider runs from 5 to 100 in steps of 5 and starts at 10.
    It is disabled while ``st.session_state.disabled`` is truthy.
    """
    step, default = 5, 10
    help_text = " ".join(
        [
            "The number of tests to average for each concept.",
            f"Defaults to {default}.",
        ]
    )
    selected = st.slider(
        "Number of tests per concept",
        min_value=step,
        max_value=100,
        value=default,
        step=step,
        help=help_text,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_cardinality(concepts, concepts_ready):
    """Render the slider for the size of the conditioning set.

    Args:
        concepts: The list of concepts entered by the user (may be shorter
            than the validated minimum; the slider is still rendered).
        concepts_ready: Whether the concepts passed validation. The slider is
            disabled when this is false or ``st.session_state.disabled`` is
            truthy.

    Returns:
        The value selected on the slider, between 1 and
        ``max(2, len(concepts) - 1)``.
    """
    max_value = max(2, len(concepts) - 1)
    # Default to half the number of concepts, but never below ``min_value``:
    # with fewer than two concepts ``len(concepts) // 2`` is 0, and
    # ``st.slider`` raises when ``value`` lies outside [min_value, max_value].
    default = max(1, len(concepts) // 2)
    return st.slider(
        "Size of conditioning set",
        help=" ".join(
            [
                "The number of concepts to condition model predictions on.",
                "Defaults to half of the number of concepts.",
            ]
        ),
        min_value=1,
        max_value=max_value,
        value=default,
        step=1,
        disabled=st.session_state.disabled or not concepts_ready,
    )
def _get_dataset_name():
    """Render the dataset selectbox, preselecting "imagenette".

    Returns:
        The dataset name chosen from ``SUPPORTED_DATASETS``.
    """
    # ``list.index`` raises ValueError if "imagenette" is not supported.
    default_index = SUPPORTED_DATASETS.index("imagenette")
    help_text = " ".join(
        [
            "Name of the dataset to use to train sampler.",
            "Defaults to Imagenette.",
        ]
    )
    return st.selectbox(
        "Dataset",
        options=SUPPORTED_DATASETS,
        index=default_index,
        help=help_text,
        disabled=st.session_state.disabled,
    )
def get_model_name():
    """Render the model selectbox and return the chosen model name.

    The options are the keys of ``SUPPORTED_MODELS``, in their stored order.
    """
    model_names = [*SUPPORTED_MODELS.keys()]
    return st.selectbox(
        "Model to test",
        options=model_names,
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user upload an image or pick an example, then open it.

    Returns:
        A PIL image opened from the uploaded file when one was provided,
        otherwise from the example chosen with ``image_select``.
    """
    # NOTE(review): all four example entries point at the same file —
    # presumably placeholders; confirm the intended example assets.
    example_paths = [
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
    ]
    with st.sidebar:
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if uploaded_file:
            image = uploaded_file
        else:
            # The example picker is only shown when nothing has been uploaded.
            image = image_select(label="or select one", images=example_paths)
    return Image.open(image)
def get_class_name():
    """Render the class-name text input and validate it.

    Returns:
        ``(class_name, class_ready, class_error)``. ``class_ready`` and
        ``class_error`` come from ``_validate_class_name``.
    """
    name = st.text_input(
        "Class to test",
        value="cat",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        disabled=st.session_state.disabled,
    )
    is_valid, error_message = _validate_class_name(name)
    return name, is_valid, error_message
def get_concepts():
    """Render the concepts text area and parse it into a list of concepts.

    The text is split on newlines. Each entry is stripped, blank entries are
    dropped, and duplicates are removed while keeping first-occurrence order.

    Returns:
        ``(concepts, concepts_ready, concepts_error)``. ``concepts_ready`` and
        ``concepts_error`` come from ``_validate_concepts``.
    """
    raw_text = st.text_area(
        "Concepts to test",
        help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    concepts = [concept.strip() for concept in raw_text.split("\n")]
    concepts = [concept for concept in concepts if concept]
    # Deduplicate while preserving the user's order. ``list(set(...))`` used
    # arbitrary hash order, which changes between interpreter runs because
    # str hashing is randomized.
    concepts = list(dict.fromkeys(concepts))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and collect its values.

    Widgets are rendered in this order: dataset, significance level, test
    length, number of tests, conditioning-set size. A divider follows them.

    Args:
        concepts: The parsed concepts, forwarded to the cardinality slider.
        concepts_ready: Whether the concepts passed validation.

    Returns:
        ``(significance_level, tau_max, r, cardinality, dataset_name)``.
    """
    with st.expander("Advanced settings"):
        dataset = _get_dataset_name()
        alpha = _get_significance_level()
        max_steps = _get_tau_max()
        num_tests = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)
        st.divider()
    return alpha, max_steps, num_tests, conditioning_size, dataset