imagecode-demo / app.py
BennoKrojer's picture
Update app.py
fc882a3
from turtle import color, onclick
import streamlit as st
from PIL import Image, ImageOps
import glob
import json
import requests
import random
import io
random.seed(10)

# Per-session defaults, written only on the first run of a session so that
# later reruns keep whatever the user has already changed:
#   'show'        - whether the correct image is currently revealed
#   'example_idx' - position of the current example in the validation list
for _state_key, _initial_value in (('show', False), ('example_idx', 0)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

st.set_page_config(layout="wide")
# Introductory text shown at the top of the page.
st.markdown("**This is a demo of the *ImageCoDe* benchmark. What is the task? You are given a description and you have to pick the image it describes, out of 10 images total.**")
st.markdown("**If you click the Sample button, you will get a new text and images. More details of ImageCoDe can be found in our ACL 2022 paper.**")

# Two-column layout; both columns are filled in further down the script.
col1, col2 = st.columns(2)

# Base URL under which the validation image sets are hosted.
prefix = 'https://raw.githubusercontent.com/BennoKrojer/imagecode-val-set/main/image-sets-val/'

# set2ids is indexed by image-set name and yields that set's image file names;
# descriptions is a list of (image set, target index, description) entries.
# Context managers make sure both file handles are closed right after parsing
# instead of being left open until garbage collection.
with open('set2ids.json', 'r', encoding='utf-8') as _set2ids_file:
    set2ids = json.load(_set2ids_file)
with open('valid_list.json', 'r', encoding='utf-8') as _descriptions_file:
    descriptions = json.load(_descriptions_file)
# Each click on the button steps to the next validation example. The index
# wraps around to 0 after the last example; without the modulo the lookup
# below raises IndexError once the end of the list has been passed.
if col1.button('Sample a description + 10 images from the validation set'):
    st.session_state.example_idx = (st.session_state.example_idx + 1) % len(descriptions)

# Unpack the current example: image-set name, index of the image the
# description refers to, and the description text itself.
img_set, true_idx, descr = descriptions[st.session_state.example_idx]
true_idx = int(true_idx)

# prefix already ends with '/', so strip it before joining to avoid a
# doubled slash ("image-sets-val//<set>/...") in every image URL.
_base_url = prefix.rstrip('/')
images = [_base_url + '/' + img_set + '/' + i for i in set2ids[img_set]]

# Entries of `images` are later replaced by decoded pictures, so keep an
# untouched copy of the URLs for any additional downloads.
img_urls = images.copy()
# Which of the candidate images the user currently has selected (0-9).
_raw_choice = col2.number_input('Image Index from 0 to 9', min_value=0, max_value=9, value=0)
index = int(_raw_choice)

# Every click on the button flips the "reveal the correct image" flag.
_reveal_clicked = col1.button('Toggle to reveal/hide the correct image, try to guess yourself before giving up!')
if _reveal_clicked:
    st.session_state.show = not st.session_state.show

# Heading naming the image set, followed by the description to be matched.
for _text in (f'**Description for {img_set}**:', f'**{descr}**'):
    col1.markdown(_text)
# Upper bound (seconds) on each image download so a stalled connection cannot
# block the app indefinitely.
_REQUEST_TIMEOUT_S = 30


def _fetch_image(url):
    """Download *url* and return it as a PIL image.

    Raises ``requests.HTTPError`` on a non-2xx response, so a missing file is
    reported as an HTTP failure rather than as an image-decoding error.
    """
    response = requests.get(url, timeout=_REQUEST_TIMEOUT_S)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def _with_border(img, divisor, colour):
    """Return a copy of *img* framed in *colour*.

    The border width is the image's shorter side integer-divided by
    *divisor*, so a smaller divisor gives a thicker frame.
    """
    return ImageOps.expand(img, border=min(img.size) // divisor, fill=colour)


# The large preview is rendered from the plain URL, i.e. without any border.
big_img = images[index]

# Mark the currently selected thumbnail with a thin blue frame.
_selected_img = _fetch_image(img_urls[index])
images[index] = _with_border(_selected_img, 18, 'blue')

# One caption per thumbnail (its index), plus the caption of the large preview.
caps = [str(i) for i in range(len(images))]
cap = str(index)

if st.session_state.show:
    # Reveal mode: label the target and give it a thick green frame. When the
    # selected image is the target, the green frame replaces the blue one and
    # the picture that was already downloaded is reused.
    _target_label = f'{true_idx} (TARGET IMAGE)'
    caps[true_idx] = _target_label
    _target_img = _selected_img if true_idx == index else _fetch_image(img_urls[true_idx])
    images[true_idx] = _with_border(_target_img, 8, 'green')
    if true_idx == index:
        cap = _target_label

col1.image(big_img, use_column_width=True, caption=cap)
col2.image(images, width=175, caption=caps)

# Show the position of the current example in the validation list.
col1.markdown(f'{st.session_state.example_idx}')