Spaces:
Running
Running
File size: 5,053 Bytes
80dc74c 5e91161 80dc74c 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 b30bcef 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 f6eb5e3 7e207f0 b30bcef 7e207f0 b30bcef 7e207f0 b30bcef 80dc74c 5ead791 4f55ca2 80dc74c 4f55ca2 80dc74c 4f55ca2 80dc74c 7e207f0 80dc74c 4f55ca2 80dc74c 7e207f0 b30bcef 7e207f0 b30bcef 7e207f0 b30bcef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import streamlit as st
from PIL import Image
from streamlit_image_select import image_select
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS
def _validate_class_name(class_name):
if class_name is None:
return (False, "Class name cannot be empty.")
if class_name.strip() == "":
return (False, "Class name cannot be empty.")
return (True, None)
def _validate_concepts(concepts):
if len(concepts) < 3:
return (False, "You must provide at least 3 concepts")
if len(concepts) > 10:
return (False, "Maximum 10 concepts allowed")
return (True, None)
def _get_significance_level():
    """Render the significance-level slider and return the selected value.

    The slider ranges from 0.01 to 1.0 in steps of 0.01 and starts at 0.05.
    It is disabled while ``st.session_state.disabled`` is truthy.

    Returns:
        The value returned by ``st.slider`` (a float in [0.01, 1.0]).
    """
    step = 0.01
    default = 0.05
    help_text = f"The level of significance of the tests. Defaults to {default:.2F}."
    return st.slider(
        "Significance level",
        help=help_text,
        min_value=step,
        max_value=1.0,
        value=default,
        step=step,
        disabled=st.session_state.disabled,
    )
def _get_tau_max():
    """Render the "Length of test" slider and return the choice as an ``int``.

    Per its help text, this is the maximum number of steps for each test.
    The slider ranges from 50 to 1000 in steps of 50 and starts at 200.
    It is disabled while ``st.session_state.disabled`` is truthy.
    """
    step = 50
    default = 200
    help_text = f"The maximum number of steps for each test. Defaults to {default}."
    selected = st.slider(
        "Length of test",
        help=help_text,
        min_value=step,
        max_value=1000,
        step=step,
        value=default,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_number_of_tests():
    """Render the "Number of tests per concept" slider and return it as an ``int``.

    Per its help text, this is how many tests are averaged for each concept.
    The slider ranges from 5 to 100 in steps of 5 and starts at 10.
    It is disabled while ``st.session_state.disabled`` is truthy.
    """
    step = 5
    default = 10
    help_text = f"The number of tests to average for each concept. Defaults to {default}."
    selected = st.slider(
        "Number of tests per concept",
        help=help_text,
        min_value=step,
        max_value=100,
        step=step,
        value=default,
        disabled=st.session_state.disabled,
    )
    return int(selected)
def _get_cardinality(concepts, concepts_ready):
    """Render the conditioning-set-size slider and return the selected size.

    Args:
        concepts: The current list of concepts. Only its length is used.
        concepts_ready: Whether the concepts passed validation. The slider
            is disabled when this is falsy or when
            ``st.session_state.disabled`` is truthy.

    Returns:
        The integer chosen on the slider, between 1 and
        ``max(2, len(concepts) - 1)``.
    """
    n_concepts = len(concepts)
    # The upper bound is kept at 2 or more even with very few concepts,
    # presumably so that min_value stays strictly below max_value for st.slider.
    max_value = max(2, n_concepts - 1)
    # Default to half the number of concepts, clamped into [1, max_value].
    # With fewer than two concepts, n_concepts // 2 is 0. That is below
    # min_value, and st.slider rejects a default outside [min_value, max_value].
    default = min(max(1, n_concepts // 2), max_value)
    return st.slider(
        "Size of conditioning set",
        help=" ".join(
            [
                "The number of concepts to condition model predictions on.",
                "Defaults to half of the number of concepts.",
            ]
        ),
        min_value=1,
        max_value=max_value,
        value=default,
        step=1,
        disabled=st.session_state.disabled or not concepts_ready,
    )
def _get_dataset_name():
    """Render the dataset selectbox and return the selected dataset name.

    The options are ``SUPPORTED_DATASETS``, with ``"imagenette"``
    preselected. ``list.index`` raises ``ValueError`` if that entry is
    missing. The widget is disabled while ``st.session_state.disabled``
    is truthy.
    """
    imagenette_index = SUPPORTED_DATASETS.index("imagenette")
    help_text = "Name of the dataset to use to train sampler. Defaults to Imagenette."
    return st.selectbox(
        "Dataset",
        options=SUPPORTED_DATASETS,
        index=imagenette_index,
        help=help_text,
        disabled=st.session_state.disabled,
    )
def get_model_name():
    """Render the model selectbox and return the selected model name.

    The options are the keys of ``SUPPORTED_MODELS``, in iteration order. No
    explicit index is passed, so Streamlit's default selection applies. The
    widget is disabled while ``st.session_state.disabled`` is truthy.
    """
    model_names = list(SUPPORTED_MODELS)
    return st.selectbox(
        "Model to test",
        options=model_names,
        help="Name of the vision-language model to test the predictions of.",
        disabled=st.session_state.disabled,
    )
def get_image():
    """Let the user pick an image in the sidebar and return it as a PIL image.

    An uploaded file (jpg/png/jpeg) takes precedence. If nothing has been
    uploaded, an image gallery is shown instead and its selection is used.

    Returns:
        The result of ``PIL.Image.open`` on the uploaded file or the selected
        gallery entry. Presumably ``image_select`` returns the chosen path
        from ``images``; confirm against the streamlit-image-select docs.
    """
    # NOTE(review): the gallery repeats the same asset four times. This looks
    # like a placeholder list.
    gallery = ["assets/ace.jpg"] * 4
    with st.sidebar:
        uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
        if uploaded_file:
            source = uploaded_file
        else:
            # The gallery widget is only rendered when no file was uploaded.
            source = image_select(label="or select one", images=gallery)
    return Image.open(source)
def get_class_name():
    """Render the class-name text input and return it with its validation result.

    Returns:
        A ``(class_name, class_ready, class_error)`` tuple. The last two
        values come from ``_validate_class_name``.
    """
    entered_name = st.text_input(
        "Class to test",
        help="Name of the class to build the zero-shot CLIP classifier with.",
        value="cat",
        disabled=st.session_state.disabled,
    )
    return (entered_name, *_validate_class_name(entered_name))
def get_concepts():
    """Render the concepts text area and return the parsed, validated concepts.

    The text is split on newlines, each line is stripped, and blank lines are
    dropped. Duplicates are then removed, keeping the first occurrence, so the
    concepts come back in the order the user typed them.

    Returns:
        A ``(concepts, concepts_ready, concepts_error)`` tuple.
        ``concepts`` is a list of unique, non-empty strings; the other two
        values come from ``_validate_concepts``.
    """
    raw_text = st.text_area(
        "Concepts to test",
        help="List of concepts to test the predictions of the model with. Write one concept per line. Maximum 10 concepts allowed.",
        height=160,
        value="piano\ncute\nwhiskers\nmusic\nwild",
        disabled=st.session_state.disabled,
    )
    stripped_lines = (line.strip() for line in raw_text.split("\n"))
    # dict.fromkeys removes duplicates while preserving first-seen order.
    # set() iterates in an arbitrary order (str hashing is randomized per
    # process), so the same input could produce differently ordered concepts.
    concepts = list(dict.fromkeys(line for line in stripped_lines if line))
    concepts_ready, concepts_error = _validate_concepts(concepts)
    return concepts, concepts_ready, concepts_error
def get_advanced_settings(concepts, concepts_ready):
    """Render the "Advanced settings" expander and return the chosen settings.

    The widgets are rendered in this order: dataset, significance level,
    test length, number of tests, and conditioning-set size.

    Args:
        concepts: Current concept list, forwarded to ``_get_cardinality``.
        concepts_ready: Concept validation flag, forwarded to ``_get_cardinality``.

    Returns:
        A ``(significance_level, tau_max, r, cardinality, dataset_name)``
        tuple. Note that this order differs from the rendering order above.
    """
    with st.expander("Advanced settings"):
        chosen_dataset = _get_dataset_name()
        alpha = _get_significance_level()
        test_length = _get_tau_max()
        num_tests = _get_number_of_tests()
        conditioning_size = _get_cardinality(concepts, concepts_ready)
        st.divider()
    return (alpha, test_length, num_tests, conditioning_size, chosen_dataset)
|