IBYDMT / app_lib /user_input.py
jacopoteneggi's picture
Update
f6eb5e3 verified
raw
history blame
5.05 kB
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
def _validate_class_name(class_name):
if class_name is None:
return (False, "Class name cannot be empty.")
if class_name.strip() == "":
return (False, "Class name cannot be empty.")
return (True, None)
def _validate_concepts(concepts):
if len(concepts) < 3:
return (False, "You must provide at least 3 concepts")
if len(concepts) > 10:
return (False, "Maximum 10 concepts allowed")
return (True, None)
def _get_significance_level():
    """Render the slider selecting the tests' significance level.

    Returns:
        The chosen significance level (float in [0.01, 1.0]).
    """
    step = 0.01
    default = 0.05
    # Same text as " ".join([...]) in the original, assembled inline.
    help_text = f"The level of significance of the tests. Defaults to {default:.2F}."
    return st.slider(
        "Significance level",
        help=help_text,
        min_value=step,
        max_value=1.0,
        value=default,
        step=step,
        disabled=st.session_state.disabled,
    )
def _get_tau_max():
    """Render the slider controlling the maximum number of testing steps.

    Returns:
        The chosen maximum test length as an int.
    """
    step = 50
    default = 200
    help_text = f"The maximum number of steps for each test. Defaults to {default}."
    selected = st.slider(
        "Length of test",
        help=help_text,
        min_value=step,
        max_value=1000,
        step=step,
        value=default,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_number_of_tests():
    """Render the slider for the number of tests averaged per concept.

    Returns:
        The chosen number of tests as an int.
    """
    step = 5
    default = 10
    help_text = f"The number of tests to average for each concept. Defaults to {default}."
    selected = st.slider(
        "Number of tests per concept",
        help=help_text,
        min_value=step,
        max_value=100,
        step=step,
        value=default,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_cardinality(concepts, concepts_ready):
    """Render the slider for the size of the conditioning set.

    Args:
        concepts: Current list of concept strings (may be short or empty
            while the user is still typing).
        concepts_ready: Whether the concept list passed validation; the
            slider is disabled until it does.

    Returns:
        The selected cardinality (int).
    """
    min_value = 1
    max_value = max(2, len(concepts) - 1)
    # Default is half the number of concepts, but len(concepts) // 2 is 0
    # for fewer than 2 concepts, which would fall below min_value and make
    # st.slider raise a StreamlitAPIException — clamp into the valid range.
    default = min(max(len(concepts) // 2, min_value), max_value)
    return st.slider(
        "Size of conditioning set",
        help=" ".join(
            [
                "The number of concepts to condition model predictions on.",
                "Defaults to half of the number of concepts.",
            ]
        ),
        min_value=min_value,
        max_value=max_value,
        value=default,
        step=1,
        disabled=st.session_state.disabled or not concepts_ready,
    )
def _get_dataset_name():
    """Render the dataset selector, defaulting to Imagenette.

    Returns:
        The selected dataset name from SUPPORTED_DATASETS.
    """
    default_index = SUPPORTED_DATASETS.index("imagenette")
    help_text = "Name of the dataset to use to train sampler. Defaults to Imagenette."
    return st.selectbox(
        "Dataset",
        options=SUPPORTED_DATASETS,
        index=default_index,
        help=help_text,
        disabled=st.session_state.disabled,
    )
def get_model_name():
    """Render the model selector.

    Returns:
        The selected key from SUPPORTED_MODELS.
    """
    model_options = list(SUPPORTED_MODELS.keys())
    return st.selectbox(
        "Model to test",
        options=model_options,
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user upload an image or pick one from the sample gallery.

    Returns:
        A PIL Image opened from the uploaded file or the selected asset path.
    """
    with st.sidebar:
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        # Explicit branch instead of `or`: the gallery widget is only
        # rendered when no file has been uploaded (same short-circuit).
        if uploaded_file:
            source = uploaded_file
        else:
            source = image_select(
                label="or select one",
                images=[
                    "assets/ace.jpg",
                    "assets/ace.jpg",
                    "assets/ace.jpg",
                    "assets/ace.jpg",
                ],
            )
    return Image.open(source)
def get_class_name():
    """Render the class-name text input and validate its contents.

    Returns:
        Tuple ``(class_name, class_ready, class_error)``.
    """
    name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    ready, error = _validate_class_name(name)
    return name, ready, error
def get_concepts():
    """Render the concepts text area and parse it into a validated list.

    Returns:
        Tuple ``(concepts, concepts_ready, concepts_error)`` where
        ``concepts`` is the deduplicated list of stripped, non-empty lines.
    """
    raw = st.text_area(
        "Concepts to test",
        help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    concepts = [line.strip() for line in raw.split("\n")]
    concepts = [concept for concept in concepts if concept]
    # dict.fromkeys deduplicates while preserving the user's input order;
    # the original list(set(...)) reordered concepts arbitrarily between
    # reruns, shuffling downstream widgets and results.
    concepts = list(dict.fromkeys(concepts))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_advanced_settings(concepts, concepts_ready):
    """Render the advanced-settings expander and collect its widget values.

    Args:
        concepts: Current concept list (drives the cardinality slider range).
        concepts_ready: Whether the concept list passed validation.

    Returns:
        Tuple ``(significance_level, tau_max, r, cardinality, dataset_name)``.
    """
    with st.expander("Advanced settings"):
        dataset_name = _get_dataset_name()
        significance_level = _get_significance_level()
        tau_max = _get_tau_max()
        r = _get_number_of_tests()
        cardinality = _get_cardinality(concepts, concepts_ready)
    st.divider()
    return significance_level, tau_max, r, cardinality, dataset_name