# virtex-redcaps / app.py
# Author: zamborg
# Last commit: 5a0da41 ("syntax")
# Size: 2.26 kB
import streamlit as st
import io
import sys
import time
# Make modules under ./virtex/ (the VirTex code) importable. The path is
# resolved against the current working directory, so it only works when the
# app is started from the directory that contains virtex/.
sys.path.append("./virtex/")

st.title("Image Captioning Demo from Redcaps")
st.sidebar.markdown("\nImage Captioning Model from VirTex trained on Redcaps\n")
def _resource_cache():
    """Return a Streamlit decorator that caches one shared resource per process.

    Streamlit re-executes this whole script on every widget interaction, so
    heavyweight objects (downloaded files, the model) must be cached or they
    are rebuilt on every click. Prefers ``st.cache_resource`` (newer
    Streamlit), then ``st.experimental_singleton``, then the legacy
    ``st.cache`` with output mutation allowed, so the app keeps working across
    Streamlit versions.
    """
    for attr in ("cache_resource", "experimental_singleton"):
        decorator = getattr(st, attr, None)
        if decorator is not None:
            return decorator
    return st.cache(allow_output_mutation=True)


@_resource_cache()
def _load_model():
    """Download the model files and build the captioning model and image loader.

    Executed once per server process thanks to the resource cache; timing
    information goes to the server log instead of the user-facing page.

    Returns:
        tuple: ``(VirTexModel(), ImageLoader())``.
    """
    start = time.time()
    download_files()
    print(f"download TIME: {time.time() - start}")
    start = time.time()
    model = VirTexModel()
    loader = ImageLoader()
    print(f"model TIME: {time.time() - start}")
    return model, loader


with st.spinner("Loading Model"):
    start = time.time()
    # Import inside the spinner so the (timed) model import shows as loading.
    # Explicit names instead of `from model import *`: `Image` is presumably
    # PIL's Image re-exported by model.py, since this file uses it without
    # importing PIL itself.
    from model import (
        Image,
        ImageLoader,
        VirTexModel,
        download_files,
        get_rand_img,
        get_samples,
    )
    print(f"Import TIME: {time.time() - start}")
    sample_images = get_samples()
    virtexModel, imageLoader = _load_model()
    # Fallback image, used when no sidebar sample is selected.
    random_image = get_rand_img(sample_images)
st.sidebar.title("Select a sample image")
sample_image = st.sidebar.selectbox("", sample_images)

# st.button is True only on the rerun triggered by its click. Clearing the
# dropdown choice makes the image-selection step fall back to random_image.
if st.sidebar.button("Random Sample Image"):
    sample_image = None
    random_image = get_rand_img(sample_images)
# Decoded uploaded image, or None when nothing usable was submitted this run.
uploaded_image = None
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Submit")
    # Only decode on the run where the form was actually submitted with a file.
    if uploaded_file is not None and submitted:
        try:
            uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
        except OSError:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # data it cannot decode; report it instead of crashing the page.
            st.error("Could not read the uploaded file as an image.")
if uploaded_image is None and submitted:
    # The form was submitted without a usable file.
    st.write("Please select a file to upload")
else:
    # Priority: uploaded file, then dropdown sample, then random sample.
    image_file = sample_image if sample_image is not None else random_image
    source = uploaded_image if uploaded_image is not None else Image.open(image_file)
    try:
        image_dict = imageLoader.transform(source)
    finally:
        # The model input has been extracted; release the source image. For
        # sample files this also closes the file handle opened by Image.open,
        # which previously leaked because the name was rebound.
        source.close()

    # Show the image as the model sees it (after the loader's transform).
    image = imageLoader.to_image(image_dict["image"].squeeze(0))
    st.image(image, "Your Image")

    with st.spinner("Generating Caption"):
        subreddit, caption = virtexModel.predict(image_dict)
    st.header("Predicted Caption:\n\n")
    st.subheader(f"Subreddit: {subreddit}\n")
    st.subheader(f"Caption: {caption}\n")
    image.close()
# from model import *
# download_files()
# sample_images = get_samples()
# v, il = VirTexModel(), ImageLoader()
# for s in sample_images:
# subreddit, caption = v.predict(il.load(s))
# print("=====================")
# print(subreddit)
# print(caption)