import json import os import streamlit as st from PIL import Image from streamlit_image_select import image_select from app_lib.defaults import Defaults as d from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS IMAGE_DIR = os.path.join("assets", "images") IMAGE_NAMES = list(sorted(filter(lambda x: x.endswith(".jpg"), os.listdir(IMAGE_DIR)))) IMAGE_PATHS = list(map(lambda x: os.path.join(IMAGE_DIR, x), IMAGE_NAMES)) IMAGE_PRESETS = json.load(open("assets/image_presets.json")) def _validate_class_name(class_name): if class_name is None: return (False, "Class name cannot be empty.") if class_name.strip() == "": return (False, "Class name cannot be empty.") return (True, None) def _validate_concepts(concepts): if len(concepts) < 3: return (False, "You must provide at least 3 concepts") if len(concepts) > 10: return (False, "Maximum 10 concepts allowed") return (True, None) def _get_significance_level(): default = d.SIGNIFICANCE_LEVEL_VALUE step = d.SIGNIFICANCE_LEVEL_STEP return st.slider( "Significance level", help=f"The level of significance of the tests. Defaults to {default:.2F}.", min_value=step, max_value=1.0, value=default, step=step, disabled=st.session_state.disabled, ) def _get_tau_max(): default = d.TAU_MAX_VALUE step = d.TAU_MAX_STEP return int( st.slider( "Length of test", help=f"The maximum number of steps for each test. Defaults to {default}.", min_value=step, max_value=1000, step=step, value=default, disabled=st.session_state.disabled, ) ) def _get_number_of_tests(): default = d.R_VALUE step = d.R_STEP return int( st.slider( "Number of tests per concept", help=( "The number of tests to average for each concept. " f"Defaults to {default}." ), min_value=step, max_value=100, step=step, value=default, disabled=st.session_state.disabled, ) ) def _get_cardinality(concepts, concepts_ready): default = d.CARDINALITY_VALUE step = d.CARDINALITY_STEP return st.slider( "Size of conditioning set", help=( "The number of concepts to condition model predictions on. " "Defaults to {default}." ), min_value=1, max_value=max(2, len(concepts) - 1), value=default, step=step, disabled=st.session_state.disabled or not concepts_ready, ) def _get_dataset_name(): options = SUPPORTED_DATASETS default_idx = options.index(d.DATASET_NAME) return st.selectbox( "Dataset", options=options, index=default_idx, help=( "Name of the dataset to use to train sampler." f"Defaults to {SUPPORTED_DATASETS[default_idx]}." ), disabled=st.session_state.disabled, ) def get_model_name(): options = list(SUPPORTED_MODELS) default_idx = options.index(d.MODEL_NAME) return st.selectbox( "Model to test", options=options, index=default_idx, help=( "Name of the vision-language model to test the predictions of." f"Defaults to {options[default_idx]}" ), disabled=st.session_state.disabled, ) def get_image(): with st.sidebar: uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"]) if uploaded_file is not None: return (None, Image.open(uploaded_file)) else: DEFAULT = IMAGE_NAMES.index("bowl_ace.jpg") image_idx = image_select( label="or select one", images=IMAGE_PATHS, index=DEFAULT, return_value="index", ) image_name, image_path = IMAGE_NAMES[image_idx], IMAGE_PATHS[image_idx] return (image_name, Image.open(image_path)) def get_class_name(image_name=None): default = ( IMAGE_PRESETS[image_name.split(".")[0]]["class_name"] if image_name else "" ) class_name = st.text_input( "Class to predict", help="Name of the class to build the zero-shot CLIP classifier with.", value=default, disabled=st.session_state.disabled, placeholder="Type class name here", ) class_ready, class_error = _validate_class_name(class_name) return class_name, class_ready, class_error def get_concepts(image_name=None): default = ( "\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"]) if image_name else "" ) concepts = st.text_area( "Concepts to test", help=( "List of concepts to test the predictions of the model with. " "Write one concept per line. Maximum 10 concepts allowed." ), height=180, value=default, disabled=st.session_state.disabled, placeholder="Type one concept\nper line", ) concepts = concepts.split("\n") concepts = [concept.strip() for concept in concepts] concepts = [concept for concept in concepts if concept != ""] concepts = list(set(concepts)) concepts_ready, concepts_error = _validate_concepts(concepts) return concepts, concepts_ready, concepts_error def get_advanced_settings(concepts, concepts_ready): with st.expander("Advanced settings"): dataset_name = _get_dataset_name() significance_level = _get_significance_level() tau_max = _get_tau_max() r = _get_number_of_tests() cardinality = _get_cardinality(concepts, concepts_ready) st.divider() return significance_level, tau_max, r, cardinality, dataset_name