import time

import streamlit as st
import torch

from app_lib.user_input import (
    get_cardinality,
    get_class_name,
    get_concepts,
    get_image,
    get_model_name,
)
from app_lib.test import (
    load_dataset,
    load_model,
    encode_image,
    encode_concepts,
    encode_class_name,
)


def _disable():
    """Mark the interactive controls as disabled.

    Used as the ``on_click`` callback of the "Test" button, so the flag is set
    before Streamlit reruns the script in response to the click.
    """
    st.session_state.disabled = True


def _format_errors(*errors):
    """Return a Markdown bullet list of the non-``None`` errors, or ``""`` if none."""
    return "".join(f"- {error}\n" for error in errors if error is not None)


def _run_test(model_name, image, class_name, concepts, device):
    """Run the test pipeline for the current inputs, one spinner per stage.

    Args:
        model_name: Name returned by ``get_model_name``; forwarded to
            ``load_dataset`` and ``load_model``.
        image: Image returned by ``get_image``; forwarded to ``encode_image``.
        class_name: Class name returned by ``get_class_name``.
        concepts: Concepts returned by ``get_concepts``.
        device: Torch device forwarded to the model/encoding helpers.
    """
    with st.spinner("Loading dataset"):
        # NOTE(review): result is not used below yet.
        embedding = load_dataset("imagenette", model_name)
        time.sleep(1)

    with st.spinner("Loading model"):
        model, preprocess, tokenizer = load_model(model_name, device)
        time.sleep(1)

    with st.spinner("Encoding concepts"):
        cbm = encode_concepts(tokenizer, model, concepts, device)
        time.sleep(1)

    with st.spinner("Preparing zero-shot classifier"):
        # NOTE(review): result is not used below yet.
        classifier = encode_class_name(tokenizer, model, class_name, device)

    with st.spinner("Encoding image"):
        h = encode_image(model, preprocess, image, device)
        # Project the image encoding onto the concept encodings.
        z = h @ cbm.T
        # Debug output to the server console.
        print(h.shape, cbm.shape, z.shape)
        time.sleep(2)


def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the test page and, when the "Test" button is clicked, run the test.

    Args:
        device: Torch device passed to ``load_model`` and the ``encode_*``
            helpers. The default is evaluated once, when this module is
            imported.
    """
    # Make the first access safe on a fresh session; a no-op if already set.
    if "disabled" not in st.session_state:
        st.session_state.disabled = False

    left, right = st.columns([0.40, 0.60])

    with left:
        model_name = get_model_name()

        row1 = st.columns(2)
        row2 = st.columns(2)

        with row1[0]:
            image = get_image()
            st.image(image, use_column_width=True)

        with row1[1]:
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()
            # NOTE(review): result is not used below yet; the call is kept
            # because it renders an input widget.
            cardinality = get_cardinality(concepts, concepts_ready)

        with row2[0]:
            change_image_button = st.button(
                "Change Image",
                use_container_width=True,
                disabled=st.session_state.disabled,
            )
            if change_image_button:
                # Presumably read by the entry script to expand the sidebar
                # (where the image is chosen) — confirm.
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()

        with row2[1]:
            ready = class_ready and concepts_ready

            error_message = _format_errors(class_error, concepts_error)
            if error_message:
                st.error(error_message)

            test_button = st.button(
                "Test",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    with right:
        if test_button:
            try:
                _run_test(model_name, image, class_name, concepts, device)
            finally:
                # Re-enable the controls even if a pipeline stage raised;
                # otherwise the session would stay locked.
                st.session_state.disabled = False
            st.experimental_rerun()