Spaces:
Running
Running
import streamlit as st | |
import requests | |
import io | |
# Designing the interface
# The emoji are spelled as escapes so the title survives any re-encoding of
# this file (the previous literal had been decoded with the wrong code page).
# NOTE(review): the trailing emoji was unrecoverable from the garbled text and
# is reconstructed as U+1F4DD (memo) - confirm against the deployed app.
st.title("\U0001F5BC\uFE0F Image Captioning Demo \U0001F4DD")
st.write("[Yih-Dar SHIEH](https://huggingface.co/ydshieh)")

# Sidebar blurb: how the model was assembled and trained.
st.sidebar.markdown(
    """
An image captioning model by combining ViT model with GPT2 model.
The encoder (ViT) and decoder (GPT2) are combined using Hugging Face transformers' [Vision-To-Text Encoder-Decoder
framework](https://huggingface.co/transformers/master/model_doc/visionencoderdecoder.html).
The pretrained weights of both models are loaded, with a set of randomly initialized cross-attention weights.
The model is trained on the COCO 2017 dataset for about 6900 steps (batch_size=256).
[Follow-up work of [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).]\n
"""
)
# The star import supplies the helpers used further down (get_random_image_id,
# sample_image_ids, sample_dir, predict, plus os and Image). It runs under a
# spinner; presumably importing the module is what loads the model.
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
    from model import *

# Default image for this run; replaced below if the user asks for a new one.
random_image_id = get_random_image_id()

st.sidebar.title("Select a sample image")
sample_image_id = st.sidebar.selectbox("Please choose a sample image", sample_image_ids)

# A fresh random pick overrides whatever the select box currently holds;
# the string "None" is the marker for "no sample selected".
random_requested = st.sidebar.button("Random COCO 2017 (val) images")
if random_requested:
    random_image_id = get_random_image_id()
    sample_image_id = "None"

# Upload form: the file only counts once the form has been submitted.
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
    uploaded_file = st.file_uploader("Choose a file")
    submitted = st.form_submit_button("Upload")

has_upload = submitted and uploaded_file is not None
bytes_data = io.BytesIO(uploaded_file.getvalue()) if has_upload else None
def _fit_within(width, height, max_width=512, max_height=384):
    """Return a ``(width, height)`` that fits the preview box, keeping aspect ratio.

    The height is capped first and the width second. Both results are clamped
    to at least 1 so an extremely elongated image never produces a zero-sized
    target.
    """
    if height > max_height:
        width = int(width / height * max_height)
        height = max_height
    if width > max_width:
        # Scale the height BEFORE overwriting the width: computing it from the
        # already-capped width leaves the height unchanged and distorts the image.
        height = int(height / width * max_width)
        width = max_width
    return max(width, 1), max(height, 1)


if bytes_data is None and submitted:
    # The upload form was submitted without a file attached.
    st.write("No file is selected to upload")
else:
    # An explicitly selected sample wins over the random pick.
    image_id = random_image_id
    if sample_image_id != "None":
        if not isinstance(sample_image_id, int):
            raise TypeError(
                f"sample image id must be an int, got {type(sample_image_id).__name__}"
            )
        image_id = sample_image_id

    image_stem = str(image_id).zfill(12)
    image_url = f"http://images.cocodataset.org/val2017/{image_stem}.jpg"
    sample_path = os.path.join(sample_dir, f"COCO_val2017_{image_stem}.jpg")

    # Source priority: uploaded bytes, then a local sample file, then download.
    if bytes_data is not None:
        image = Image.open(bytes_data)
    elif os.path.isfile(sample_path):
        image = Image.open(sample_path)
    else:
        # Bounded wait and explicit status check, so a hung or failed download
        # is reported as an HTTP problem rather than as an unreadable image.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))

    try:
        # One resize straight from the source for the on-page preview; the
        # model below still receives the original, unscaled image.
        preview = image.resize(size=_fit_within(*image.size))

        if bytes_data is None:
            st.markdown(f"[{image_stem}.jpg]({image_url})")
        st.image(preview, '\n\nSelected Image')
        preview.close()

        # For newline
        st.sidebar.write('\n')

        with st.spinner('Generating image caption ...'):
            caption = predict(image)
            st.header('Predicted caption:\n\n')
            st.subheader(caption)

        st.sidebar.header("ViT-GPT2 predicts: ")
        st.sidebar.write(f"{caption}")
    finally:
        # Release the source image even if display or prediction fails.
        image.close()