# IBYDMT / app_lib/main.py
# Source revision: commit 4f55ca2 ("Update") by jacopoteneggi, 2.82 kB.
import torch
import streamlit as st
import time
from app_lib.user_input import (
get_cardinality,
get_class_name,
get_concepts,
get_image,
get_model_name,
)
from app_lib.test import (
load_dataset,
load_model,
encode_image,
encode_concepts,
encode_class_name,
)
def _disable():
    """Mark the session as busy so the action buttons render disabled.

    Used as the ``on_click`` callback of the "Test" button.
    """
    session = st.session_state
    session.disabled = True
def _format_errors(*errors):
    """Return a markdown bullet list of the non-``None`` errors ("" if none)."""
    return "".join(f"- {error}\n" for error in errors if error is not None)


def _run_test(model_name, image, class_name, concepts, device):
    """Run the test pipeline for the current inputs, one spinner per stage.

    Args:
        model_name: model identifier returned by ``get_model_name``.
        image: image returned by ``get_image``.
        class_name: class name returned by ``get_class_name``.
        concepts: concepts returned by ``get_concepts``.
        device: torch device forwarded to the loading/encoding helpers.
    """
    # NOTE(review): the time.sleep calls add artificial pauses, presumably so
    # each spinner stays visible long enough to read; confirm before removing.
    with st.spinner("Loading dataset"):
        embedding = load_dataset("imagenette", model_name)  # not consumed yet
        time.sleep(1)
    with st.spinner("Loading model"):
        model, preprocess, tokenizer = load_model(model_name, device)
        time.sleep(1)
    with st.spinner("Encoding concepts"):
        cbm = encode_concepts(tokenizer, model, concepts, device)
        time.sleep(1)
    with st.spinner("Preparing zero-shot classifier"):
        classifier = encode_class_name(tokenizer, model, class_name, device)  # not consumed yet
    with st.spinner("Encoding image"):
        h = encode_image(model, preprocess, image, device)
        # Presumably image-vs-concept scores (h: image embedding(s),
        # cbm: one row per concept) -- TODO confirm shapes.
        z = h @ cbm.T
        # Debug output; goes to the server's stdout, not to the page.
        print(h.shape, cbm.shape, z.shape)
        time.sleep(2)


def main(device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Render the testing page and run the test when "Test" is clicked.

    The left column collects the inputs (model, image, class name, concepts,
    cardinality) and hosts the "Change Image" and "Test" buttons. The right
    column runs the test pipeline after "Test" is clicked.

    Args:
        device: torch device passed to the model-loading and encoding
            helpers. The default is evaluated once, at import time.
    """
    # Both buttons read this flag on every run; make sure it exists even when
    # the caller has not initialised session state.
    if "disabled" not in st.session_state:
        st.session_state.disabled = False

    columns = st.columns([0.40, 0.60])

    with columns[0]:
        model_name = get_model_name()

        row1 = st.columns(2)
        row2 = st.columns(2)

        with row1[0]:
            image = get_image()
            st.image(image, use_column_width=True)

        with row1[1]:
            class_name, class_ready, class_error = get_class_name()
            concepts, concepts_ready, concepts_error = get_concepts()
            # NOTE(review): the cardinality value is not used by the test yet.
            cardinality = get_cardinality(concepts, concepts_ready)

        with row2[0]:
            change_image_button = st.button(
                "Change Image",
                use_container_width=True,
                disabled=st.session_state.disabled,
            )
            if change_image_button:
                st.session_state.sidebar_state = "expanded"
                st.experimental_rerun()

        with row2[1]:
            ready = class_ready and concepts_ready

            error_message = _format_errors(class_error, concepts_error)
            if error_message:
                st.error(error_message)

            # _disable runs on click, so the buttons are disabled while the
            # test below executes.
            test_button = st.button(
                "Test",
                use_container_width=True,
                on_click=_disable,
                disabled=st.session_state.disabled or not ready,
            )

    with columns[1]:
        if test_button:
            try:
                _run_test(model_name, image, class_name, concepts, device)
            finally:
                # Always release the lock set by _disable; otherwise a failing
                # stage (or an interrupted run) would leave both buttons
                # disabled for the rest of the session.
                st.session_state.disabled = False
            # Rerun only on success so that an error stays visible.
            st.experimental_rerun()