Spaces:
Sleeping
Sleeping
import torch | |
import streamlit as st | |
import time | |
from app_lib.user_input import ( | |
get_cardinality, | |
get_class_name, | |
get_concepts, | |
get_image, | |
get_model_name, | |
) | |
from app_lib.test import ( | |
load_dataset, | |
load_model, | |
encode_image, | |
encode_concepts, | |
encode_class_name, | |
) | |
def _disable():
    """Set the session-wide ``disabled`` flag to True.

    Used as the ``on_click`` callback of the "Test" button; the buttons in
    ``main`` read this flag for their ``disabled`` argument.
    """
    st.session_state["disabled"] = True
def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the demo page and run the encoding pipeline when "Test" is clicked.

    The left column collects the user inputs (model name, image, class name,
    concepts, cardinality) and shows the "Change Image" and "Test" buttons.
    When "Test" is clicked, the right column loads the dataset and the model,
    encodes the concepts, the class name and the image, prints the resulting
    tensor shapes, and finally triggers a rerun of the script.

    Args:
        device: Torch device forwarded to the model-loading and encoding
            helpers. The default is evaluated once, when the function is
            defined: CUDA if available, otherwise CPU.
    """
    columns = st.columns([0.40, 0.60])

    with columns[0]:
        model_name = get_model_name()

        row1 = st.columns(2)
        row2 = st.columns(2)

        with row1[0]:
            image = get_image()
            st.image(image, use_column_width=True)

        with row1[1]:
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()
            # NOTE(review): the returned cardinality is not consumed yet; the
            # call is kept because the helper may render an input widget.
            cardinality = get_cardinality(concepts, concepts_ready)

        with row2[0]:
            change_image_button = st.button(
                "Change Image",
                use_container_width=True,
                disabled=st.session_state.disabled,
            )
            if change_image_button:
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()

        with row2[1]:
            ready = class_ready and concepts_ready

            # One "- <message>" bullet line per validation error that is present.
            error_message = "".join(
                f"- {error}\n"
                for error in (class_error, concepts_error)
                if error is not None
            )
            if error_message:
                st.error(error_message)

            # _disable sets st.session_state.disabled when the button is
            # clicked, so both buttons are disabled while the test runs.
            test_button = st.button(
                "Test",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    with columns[1]:
        if test_button:
            try:
                with st.spinner("Loading dataset"):
                    # NOTE(review): embedding is not used further down yet.
                    embedding = load_dataset("imagenette", model_name)
                    time.sleep(1)

                with st.spinner("Loading model"):
                    model, preprocess, tokenizer = load_model(model_name, device)
                    time.sleep(1)

                with st.spinner("Encoding concepts"):
                    cbm = encode_concepts(tokenizer, model, concepts, device)
                    time.sleep(1)

                with st.spinner("Preparing zero-shot classifier"):
                    # NOTE(review): classifier is not used further down yet.
                    classifier = encode_class_name(
                        tokenizer, model, class_name, device
                    )

                with st.spinner("Encoding image"):
                    h = encode_image(model, preprocess, image, device)
                    z = h @ cbm.T
                    print(h.shape, cbm.shape, z.shape)
                    time.sleep(2)
            finally:
                # Always clear the flag set by _disable. Without this, a
                # failure in any step above would leave it True and both
                # buttons would stay disabled on every later run of the
                # script in this session.
                st.session_state.disabled = False

            st.experimental_rerun()