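"""Streamlit input components for the concept-testing demo.

Collects the inputs needed to run the tests: an image, the class to
predict with a zero-shot CLIP classifier, the concepts to test, and the
advanced test settings (significance level, test length, number of tests,
conditioning-set cardinality, and dataset).
"""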
import json
import os

import streamlit as st
from PIL import Image
from streamlit_image_select import image_select

import app_lib.defaults as defaults
from app_lib.utils import SUPPORTED_DATASETS, SUPPORTED_MODELS

IMAGE_DIR = os.path.join("assets", "images")
IMAGE_NAMES = sorted(name for name in os.listdir(IMAGE_DIR) if name.endswith(".jpg"))
IMAGE_PATHS = [os.path.join(IMAGE_DIR, name) for name in IMAGE_NAMES]

# Per-image presets (class name and concepts), keyed by image file stem.
with open(os.path.join("assets", "image_presets.json")) as fp:
    IMAGE_PRESETS = json.load(fp)


def _validate_class_name(class_name):
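    """Validate the class name; return `(ok, error_message)`."""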
    if class_name is None or class_name.strip() == "":
        return (False, "Class name cannot be empty.")
    return (True, None)


def _validate_concepts(concepts):
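    """Validate the concept list; return `(ok, error_message)`."""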
    if len(concepts) < 3:
        return (False, "You must provide at least 3 concepts.")
    if len(concepts) > 10:
        return (False, "Maximum 10 concepts allowed.")
return (True, None)


def _get_significance_level():
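    """Slider for the significance level of the tests."""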
default = defaults.SIGNIFICANCE_LEVEL_VALUE
step = defaults.SIGNIFICANCE_LEVEL_STEP
return st.slider(
"Significance level",
help=f"The level of significance of the tests. Defaults to {default:.2F}.",
min_value=step,
max_value=1.0,
value=default,
step=step,
disabled=st.session_state.disabled,
)


def _get_tau_max():
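    """Slider for the maximum number of steps (tau_max) of each test."""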
default = defaults.TAU_MAX_VALUE
step = defaults.TAU_MAX_STEP
return int(
st.slider(
"Length of test",
help=f"The maximum number of steps for each test. Defaults to {default}.",
min_value=step,
max_value=1000,
step=step,
value=default,
disabled=st.session_state.disabled,
)
)


def _get_number_of_tests():
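    """Slider for the number of tests (r) averaged per concept."""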
default = defaults.R_VALUE
step = defaults.R_STEP
return int(
st.slider(
"Number of tests per concept",
help=(
"The number of tests to average for each concept. "
f"Defaults to {default}."
),
min_value=step,
max_value=100,
step=step,
value=default,
disabled=st.session_state.disabled,
)
)


def _get_cardinality(concepts, concepts_ready):
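    """Slider for the size of the conditioning set, disabled until the concepts are valid."""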
default = defaults.CARDINALITY_VALUE
step = defaults.CARDINALITY_STEP
return st.slider(
"Size of conditioning set",
        help=(
            "The number of concepts to condition model predictions on. "
            f"Defaults to {default}."
        ),
min_value=1,
max_value=max(2, len(concepts) - 1),
value=default,
step=step,
disabled=st.session_state.disabled or not concepts_ready,
)


def _get_dataset_name():
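    """Selectbox for the dataset used to train the sampler."""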
options = SUPPORTED_DATASETS
default_idx = options.index(defaults.DATASET_NAME)
return st.selectbox(
"Dataset",
options=options,
index=default_idx,
        help=(
            "Name of the dataset used to train the sampler. "
            f"Defaults to {options[default_idx]}."
        ),
disabled=st.session_state.disabled,
)


def get_model_name():
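    """Selectbox for the vision-language model whose predictions are tested."""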
options = list(SUPPORTED_MODELS.keys())
default_idx = options.index(defaults.MODEL_NAME)
return st.selectbox(
"Model to test",
options=options,
index=default_idx,
        help=(
            "Name of the vision-language model to test the predictions of. "
            f"Defaults to {options[default_idx]}."
        ),
disabled=st.session_state.disabled,
)


def get_image():
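    """Sidebar widget to upload an image or pick one of the bundled examples.

    Returns a tuple `(image_name, image)`; `image_name` is None for uploads.
    """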
with st.sidebar:
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
return (None, Image.open(uploaded_file))
else:
            default_idx = IMAGE_NAMES.index("bowl_ace.jpg")
            image_idx = image_select(
                label="or select one",
                images=IMAGE_PATHS,
                index=default_idx,
                return_value="index",
            )
image_name, image_path = IMAGE_NAMES[image_idx], IMAGE_PATHS[image_idx]
return (image_name, Image.open(image_path))


def get_class_name(image_name=None):
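    """Text input for the class to predict, prefilled from presets for bundled images.

    Returns `(class_name, class_ready, class_error)`.
    """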
default = (
IMAGE_PRESETS[image_name.split(".")[0]]["class_name"] if image_name else ""
)
class_name = st.text_input(
"Class to predict",
help="Name of the class to build the zero-shot CLIP classifier with.",
value=default,
disabled=st.session_state.disabled,
placeholder="Type class name here",
)
class_ready, class_error = _validate_class_name(class_name)
return class_name, class_ready, class_error


def get_concepts(image_name=None):
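    """Text area for the concepts to test, one per line, prefilled from presets.

    Returns `(concepts, concepts_ready, concepts_error)`.
    """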
default = (
"\n".join(IMAGE_PRESETS[image_name.split(".")[0]]["concepts"])
if image_name
else ""
)
concepts = st.text_area(
"Concepts to test",
help=(
"List of concepts to test the predictions of the model with. "
"Write one concept per line. Maximum 10 concepts allowed."
),
height=180,
value=default,
disabled=st.session_state.disabled,
placeholder="Type one concept\nper line",
)
    # One concept per line: strip whitespace, drop empty lines, and
    # deduplicate while preserving input order.
    concepts = [concept.strip() for concept in concepts.split("\n")]
    concepts = [concept for concept in concepts if concept != ""]
    concepts = list(dict.fromkeys(concepts))
concepts_ready, concepts_error = _validate_concepts(concepts)
return concepts, concepts_ready, concepts_error


def get_advanced_settings(concepts, concepts_ready):
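    """Expander with the advanced test settings.

    Returns `(significance_level, tau_max, r, cardinality, dataset_name)`.
    """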
with st.expander("Advanced settings"):
dataset_name = _get_dataset_name()
significance_level = _get_significance_level()
tau_max = _get_tau_max()
r = _get_number_of_tests()
cardinality = _get_cardinality(concepts, concepts_ready)
st.divider()
return significance_level, tau_max, r, cardinality, dataset_name
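

# A minimal sketch of how these components might be wired together in the
# app's entry point (hypothetical; the actual page layout lives elsewhere):
#
#     image_name, image = get_image()
#     class_name, class_ready, class_error = get_class_name(image_name)
#     concepts, concepts_ready, concepts_error = get_concepts(image_name)
#     settings = get_advanced_settings(concepts, concepts_ready)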