Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
from streamlit_image_select import image_select | |
from app_lib.utils import SUPPORTED_MODELS | |
def _validate_class_name(class_name): | |
if class_name is None: | |
return (False, "Class name cannot be empty.") | |
if class_name.strip() == "": | |
return (False, "Class name cannot be empty.") | |
return (True, None) | |
def _validate_concepts(concepts): | |
if len(concepts) < 3: | |
return (False, "You must provide at least 3 concepts") | |
if len(concepts) > 10: | |
return (False, "Maximum 10 concepts allowed") | |
return (True, None) | |
def get_model_name():
    """Render the model selectbox and return the chosen entry.

    The options are the keys of ``SUPPORTED_MODELS``. The widget is disabled
    according to ``st.session_state.disabled``.
    """
    model_names = list(SUPPORTED_MODELS.keys())
    is_disabled = st.session_state.disabled
    selected_model = st.selectbox(
        "Choose a model to test",
        options=model_names,
        help="Name of the vision-language model to test the predictions of.",
        disabled=is_disabled,
    )
    return selected_model
def get_image():
    """Ask for an image in the sidebar and return it opened with PIL.

    The user can upload a jpg/png/jpeg file. When nothing is uploaded, the
    image is taken from the ``image_select`` gallery instead.
    """
    sample_paths = [
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
    ]
    with st.sidebar:
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if uploaded_file:
            # An upload takes precedence; the gallery is not called at all.
            source = uploaded_file
        else:
            source = image_select(label="or select one", images=sample_paths)
    return Image.open(source)
def get_class_name():
    """Render the class-name text input and validate what was entered.

    Returns:
        A ``(class_name, class_ready, class_error)`` triple, where the last
        two items are the result of ``_validate_class_name``.
    """
    entered_name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    is_valid, error_message = _validate_class_name(entered_name)
    return entered_name, is_valid, error_message
def get_concepts():
    """Render the concepts text area and return the parsed concept list.

    The text is split into lines (one concept per line). Each line is
    stripped, blank lines are dropped, and duplicates are removed while
    keeping the order in which concepts were first written.

    Returns:
        A ``(concepts, concepts_ready, concepts_error)`` triple, where the
        last two items are the result of ``_validate_concepts``.
    """
    raw_text = st.text_area(
        "Concepts to test (max 10)",
        help="List of concepts to test the predictions of the model with. Write one concept per line.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    stripped_lines = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys de-duplicates while preserving first-seen order. A set
    # round-trip would return the concepts in an arbitrary order that can
    # change between interpreter runs.
    concepts = list(dict.fromkeys(line for line in stripped_lines if line != ""))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_cardinality(concepts, concepts_ready):
    """Render the conditioning-set size slider and return its value.

    Args:
        concepts: Collection of concepts; its length bounds the slider.
        concepts_ready: Whether the concepts passed validation. The slider
            is disabled when they did not.
    """
    # The upper bound is one less than the number of concepts, but never
    # below 2, which is also the default value.
    upper_bound = max(2, len(concepts) - 1)
    is_disabled = st.session_state.disabled or not concepts_ready
    return st.slider(
        "Size of conditioning set",
        help="The number of concepts to condition model predictions on.",
        min_value=1,
        max_value=upper_bound,
        value=2,
        step=1,
        disabled=is_disabled,
    )