"""Streamlit widgets that collect the image, class, concepts and test settings."""

import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS


def _validate_class_name(class_name):
    """Check that a class name was provided.

    Returns:
        A ``(is_valid, error_message)`` tuple; ``error_message`` is ``None``
        when the name is valid.
    """
    # None and whitespace-only input are both reported as "empty".
    if class_name is None or not class_name.strip():
        return (False, "Class name cannot be empty.")
    return (True, None)


def _validate_concepts(concepts):
    """Check that between 3 and 10 concepts were provided.

    Returns:
        A ``(is_valid, error_message)`` tuple; ``error_message`` is ``None``
        when the number of concepts is acceptable.
    """
    if len(concepts) < 3:
        return (False, "You must provide at least 3 concepts")
    if len(concepts) > 10:
        return (False, "Maximum 10 concepts allowed")
    return (True, None)


def _get_significance_level():
    """Render the significance-level slider and return the selected value.

    The widget is disabled when ``st.session_state.disabled`` is truthy.
    """
    step, default = 0.01, 0.05
    help_text = " ".join(
        [
            "The level of significance of the tests.",
            f"Defaults to {default:.2f}.",
        ]
    )
    return st.slider(
        "Significance level",
        help=help_text,
        min_value=step,
        max_value=1.0,
        value=default,
        step=step,
        disabled=st.session_state.disabled,
    )


def _get_tau_max():
    """Render the test-length slider and return the selected value as an int.

    The widget is disabled when ``st.session_state.disabled`` is truthy.
    """
    step, default = 50, 200
    help_text = " ".join(
        [
            "The maximum number of steps for each test.",
            f"Defaults to {default}.",
        ]
    )
    selected = st.slider(
        "Length of test",
        help=help_text,
        min_value=step,
        max_value=1000,
        step=step,
        value=default,
        disabled=st.session_state.disabled,
    )
    return int(selected)


def _get_number_of_tests():
    """Render the tests-per-concept slider and return the value as an int.

    The widget is disabled when ``st.session_state.disabled`` is truthy.
    """
    step, default = 5, 10
    help_text = " ".join(
        [
            "The number of tests to average for each concept.",
            f"Defaults to {default}.",
        ]
    )
    selected = st.slider(
        "Number of tests per concept",
        help=help_text,
        min_value=step,
        max_value=100,
        step=step,
        value=default,
        disabled=st.session_state.disabled,
    )
    return int(selected)


def _get_cardinality(concepts, concepts_ready):
    """Render the conditioning-set-size slider and return the selected value.

    Args:
        concepts: The concepts currently entered by the user.
        concepts_ready: Whether ``concepts`` passed validation; the slider is
            disabled when it did not (or when ``st.session_state.disabled``
            is truthy).
    """
    min_size = 1
    max_size = max(2, len(concepts) - 1)
    # Default to half of the concepts, clamped into the slider's own range:
    # with fewer than two concepts the raw default would be 0, which is below
    # the intended minimum of 1.
    default_size = min(max(len(concepts) // 2, min_size), max_size)
    help_text = " ".join(
        [
            "The number of concepts to condition model predictions on.",
            "Defaults to half of the number of concepts.",
        ]
    )
    return st.slider(
        "Size of conditioning set",
        help=help_text,
        min_value=min_size,
        max_value=max_size,
        value=default_size,
        step=1,
        disabled=st.session_state.disabled or not concepts_ready,
    )


def _get_dataset_name():
    """Render the dataset select box and return the chosen dataset name.

    The pre-selected entry is ``"imagenette"``; ``list.index`` raises
    ``ValueError`` if it is missing from ``SUPPORTED_DATASETS``.
    """
    default_index = SUPPORTED_DATASETS.index("imagenette")
    # Same single-space separator as the other help texts in this module.
    help_text = " ".join(
        [
            "Name of the dataset to use to train sampler.",
            "Defaults to Imagenette.",
        ]
    )
    return st.selectbox(
        "Dataset",
        options=SUPPORTED_DATASETS,
        index=default_index,
        help=help_text,
        disabled=st.session_state.disabled,
    )


def get_model_name():
    """Render the model select box and return the chosen model name.

    The options are the keys of ``SUPPORTED_MODELS``.
    """
    return st.selectbox(
        "Model to test",
        options=list(SUPPORTED_MODELS),
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )


def get_image():
    """Let the user upload or pick an image in the sidebar and open it.

    Returns:
        The result of ``PIL.Image.open`` on the uploaded file, or on the
        selected bundled image when nothing was uploaded.
    """
    with st.sidebar:
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        # The picker is only rendered while no file has been uploaded,
        # because ``or`` short-circuits on a truthy upload.
        image = uploaded_file or image_select(
            label="or select one",
            images=[
                "assets/ace.jpg",
                "assets/ace.jpg",
                "assets/ace.jpg",
                "assets/ace.jpg",
            ],
        )
    return Image.open(image)


def get_class_name():
    """Render the class-name text input and validate its content.

    Returns:
        A ``(class_name, class_ready, class_error)`` tuple, where the last two
        items come from ``_validate_class_name``.
    """
    class_name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    class_ready, class_error = _validate_class_name(class_name)
    return class_name, class_ready, class_error


def get_concepts():
    """Render the concepts text area, then parse and validate its content.

    One concept is read per line; surrounding whitespace is stripped, blank
    lines are dropped and duplicates are removed.

    Returns:
        A ``(concepts, concepts_ready, concepts_error)`` tuple, where the last
        two items come from ``_validate_concepts``.
    """
    raw_text = st.text_area(
        "Concepts to test",
        help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    stripped = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys removes duplicates while keeping the order the user typed;
    # a set would return the concepts in an arbitrary order.
    concepts = list(dict.fromkeys(concept for concept in stripped if concept != ""))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error


def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and collect its values.

    Args:
        concepts: The concepts currently entered by the user.
        concepts_ready: Whether ``concepts`` passed validation.

    Returns:
        A ``(significance_level, tau_max, r, cardinality, dataset_name)``
        tuple, where ``r`` is the number of tests per concept.
    """
    with st.expander("Advanced settings"):
        dataset_name = _get_dataset_name()
        significance_level = _get_significance_level()
        tau_max = _get_tau_max()
        r = _get_number_of_tests()
        cardinality = _get_cardinality(concepts, concepts_ready)

    st.divider()

    return significance_level, tau_max, r, cardinality, dataset_name