IBYDMT / app_lib /user_input.py
jacopoteneggi's picture
Update
a40e67a verified
raw
history blame
2.6 kB
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select
from app_lib.utils import SUPPORTED_MODELS
def _validate_class_name(class_name):
    """Check that the user supplied a non-blank class name.

    Args:
        class_name: the raw text entered by the user, or ``None``.

    Returns:
        ``(True, None)`` for a usable name, otherwise
        ``(False, "Class name cannot be empty.")`` when the value is ``None``
        or consists only of whitespace.
    """
    missing = class_name is None or class_name.strip() == ""
    if missing:
        return (False, "Class name cannot be empty.")
    return (True, None)
def _validate_concepts(concepts, *, min_concepts=3, max_concepts=10):
    """Check that the number of concepts lies within the allowed range.

    Args:
        concepts: sized collection of concept strings.
        min_concepts: smallest accepted number of concepts (default 3).
        max_concepts: largest accepted number of concepts (default 10).

    Returns:
        ``(True, None)`` when ``min_concepts <= len(concepts) <= max_concepts``,
        otherwise ``(False, message)`` describing which bound was violated.
    """
    count = len(concepts)
    if count < min_concepts:
        return (False, f"You must provide at least {min_concepts} concepts")
    if count > max_concepts:
        return (False, f"Maximum {max_concepts} concepts allowed")
    return (True, None)
def get_model_name():
    """Render the model picker and return the selected model name.

    The choices are the keys of ``SUPPORTED_MODELS``. The widget is disabled
    whenever ``st.session_state.disabled`` is truthy.
    """
    model_names = list(SUPPORTED_MODELS.keys())
    return st.selectbox(
        "Choose a model to test",
        options=model_names,
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user upload an image or pick a bundled example, and open it.

    Both widgets are placed in the sidebar. The example picker is only
    rendered when nothing has been uploaded, because the uploader's result
    is checked first.

    Returns:
        The ``PIL.Image`` opened from the uploaded file or the selected
        example path.
    """
    # NOTE(review): all four example entries point at the same file; this
    # looks like a placeholder gallery. Confirm the intended assets.
    example_images = [
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
        "assets/ace.jpg",
    ]
    with st.sidebar:
        source = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if not source:
            # Fall back to the example gallery only when no upload is present.
            source = image_select(label="or select one", images=example_images)
    return Image.open(source)
def get_class_name():
    """Render the class-name text box and validate what the user typed.

    Returns:
        ``(class_name, class_ready, class_error)``. The last two values are
        the ``(is_valid, message)`` pair produced by ``_validate_class_name``.
    """
    name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    is_valid, message = _validate_class_name(name)
    return name, is_valid, message
def get_concepts():
    """Render the concept text area and return the parsed, validated concepts.

    The text is split into one concept per line. Each line is stripped of
    surrounding whitespace, blank lines are dropped, and duplicates are
    removed.

    Returns:
        ``(concepts, concepts_ready, concepts_error)``. ``concepts`` is a list
        of unique concept strings in the order the user first entered them.
        The last two values are the ``(is_valid, message)`` pair produced by
        ``_validate_concepts``.
    """
    raw_text = st.text_area(
        "Concepts to test (max 10)",
        help="List of concepts to test the predictions of the model with. Write one concept per line.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    stripped = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # return the concepts in an arbitrary, hash-dependent order that can
    # differ between runs and ignores the order the user typed them in.
    concepts = list(dict.fromkeys(line for line in stripped if line != ""))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_cardinality(concepts, concepts_ready):
    """Render the slider for the conditioning-set size and return its value.

    Args:
        concepts: the parsed concept list. Its length sets the slider's upper
            bound.
        concepts_ready: whether the concepts passed validation. The slider is
            disabled when they did not.

    Returns:
        The value selected on a slider that runs from 1 to
        ``max(2, len(concepts) - 1)``, starting at 2.
    """
    upper = max(2, len(concepts) - 1)
    is_disabled = st.session_state.disabled or not concepts_ready
    return st.slider(
        "Size of conditioning set",
        help="The number of concepts to condition model predictions on.",
        min_value=1,
        max_value=upper,
        value=2,
        step=1,
        disabled=is_disabled,
    )