import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

from app_lib.utils import SUPPORTED_MODELS

# Bounds on how many distinct concepts the user may submit. The concepts text
# area label and the concept-validation messages are derived from these.
MIN_CONCEPTS = 3
MAX_CONCEPTS = 10


def _validate_class_name(class_name):
    """Check that a non-blank class name was provided.

    Args:
        class_name: Raw value of the class-name input; may be ``None``.

    Returns:
        ``(True, None)`` if the name contains at least one non-whitespace
        character, otherwise ``(False, error_message)``.
    """
    if class_name is None or class_name.strip() == "":
        return (False, "Class name cannot be empty.")
    return (True, None)


def _validate_concepts(concepts, *, min_concepts=MIN_CONCEPTS, max_concepts=MAX_CONCEPTS):
    """Check that the number of concepts lies within the allowed bounds.

    Args:
        concepts: Sequence of already-cleaned concept strings.
        min_concepts: Smallest accepted number of concepts (inclusive).
        max_concepts: Largest accepted number of concepts (inclusive).

    Returns:
        ``(True, None)`` when ``min_concepts <= len(concepts) <= max_concepts``,
        otherwise ``(False, error_message)``.
    """
    if len(concepts) < min_concepts:
        return (False, f"You must provide at least {min_concepts} concepts")
    if len(concepts) > max_concepts:
        return (False, f"Maximum {max_concepts} concepts allowed")
    return (True, None)


def get_model_name():
    """Render a model picker and return the selected key of ``SUPPORTED_MODELS``.

    The widget is disabled while ``st.session_state.disabled`` is truthy.
    """
    return st.selectbox(
        "Choose a model to test",
        options=list(SUPPORTED_MODELS.keys()),
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )


def get_image():
    """Let the user upload an image or pick a bundled example in the sidebar.

    Returns:
        The chosen image, opened with ``PIL.Image.open``.
    """
    with st.sidebar:
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        # ``or`` short-circuits, so the example gallery is only evaluated when
        # ``uploaded_file`` is falsy (nothing uploaded).
        # NOTE(review): every gallery entry is the same asset — confirm whether
        # distinct example images were intended.
        image = uploaded_file or image_select(
            label="or select one",
            images=[
                "assets/ace.jpg",
                "assets/ace.jpg",
                "assets/ace.jpg",
                "assets/ace.jpg",
            ],
        )
    return Image.open(image)


def get_class_name():
    """Render the class-name input and validate it.

    Returns:
        ``(class_name, class_ready, class_error)`` where ``class_ready`` and
        ``class_error`` come from ``_validate_class_name``.
    """
    class_name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    class_ready, class_error = _validate_class_name(class_name)
    return class_name, class_ready, class_error


def get_concepts():
    """Render the concepts text area and parse it into a clean concept list.

    Each line is one concept. Surrounding whitespace is stripped, blank lines
    are dropped, and duplicates are removed while keeping the first
    occurrence, so the result follows the order the user typed.

    Returns:
        ``(concepts, concepts_ready, concepts_error)`` where the last two come
        from ``_validate_concepts``.
    """
    raw_text = st.text_area(
        f"Concepts to test (max {MAX_CONCEPTS})",
        help=(
            "List of concepts to test the predictions of the model with. "
            "Write one concept per line."
        ),
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    stripped = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys de-duplicates while preserving insertion order. The former
    # list(set(...)) reordered concepts arbitrarily: str hashing is randomized
    # per process, so the order could also change between server processes.
    concepts = list(dict.fromkeys(concept for concept in stripped if concept))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error


def get_cardinality(concepts, concepts_ready):
    """Render a slider choosing how many concepts to condition predictions on.

    Args:
        concepts: The parsed concept list; bounds the slider's maximum.
        concepts_ready: Whether the concepts passed validation; the slider is
            disabled when they did not.

    Returns:
        The selected integer, between 1 and ``max(2, len(concepts) - 1)``.
    """
    return st.slider(
        "Size of conditioning set",
        help="The number of concepts to condition model predictions on.",
        min_value=1,
        # Never below 2, which keeps the default value of 2 inside the range
        # even when few concepts were entered.
        max_value=max(2, len(concepts) - 1),
        value=2,
        step=1,
        disabled=st.session_state.disabled or not concepts_ready,
    )