bhavitvyamalik's picture
UI
a3a0c58
raw
history blame
6.13 kB
import io
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import streamlit as st
from mtranslate import translate
from PIL import Image
from streamlit import caching

from .model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
    FlaxCLIPVisionMBartForConditionalGeneration,
)
from .utils import get_transformed_image
from .utils import (
    read_markdown,
    tokenizer,
    language_mapping,
    code_to_name,
    voicerss_tts,
)
def _load_dataset_example(mic_state, row):
    """Copy one ``reference.tsv`` row into the session state and load its image.

    Args:
        mic_state: Session-state object; its ``image_file``, ``caption``,
            ``lang_id`` and ``image`` attributes are overwritten.
        row: A row (pandas Series) with ``image_file``, ``caption`` and
            ``lang_id`` entries.
    """
    mic_state.image_file = row["image_file"]
    # Remove any leading/trailing '-' and space characters from the caption.
    mic_state.caption = row["caption"].strip("- ")
    mic_state.lang_id = row["lang_id"]
    mic_state.image = plt.imread(os.path.join("images", mic_state.image_file))


def _fetch_image_from_url(url, timeout=10):
    """Download the image at ``url`` and return it as an RGB numpy array.

    Args:
        url: HTTP(S) URL of an image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``numpy.ndarray`` of shape (height, width, 3).

    Raises:
        requests.RequestException: On network errors or non-2xx responses.
        OSError: If the payload cannot be decoded as an image
            (``PIL.UnidentifiedImageError`` subclasses ``OSError``).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Decode from the (content-decoded) body rather than the raw socket stream,
    # and force 3 channels so PNG (RGBA) or grayscale inputs yield the same
    # layout as RGB photos. NOTE(review): assumes get_transformed_image expects
    # 3-channel input -- confirm against .utils.
    return np.asarray(Image.open(io.BytesIO(response.content)).convert("RGB"))


def app(state):
    """Render the multilingual image-captioning page.

    Args:
        state: Session-state object. This page reads and writes its
            ``image_file``, ``caption``, ``lang_id``, ``image`` and ``model``
            attributes; ``image_file`` / ``model`` being ``None`` triggers
            first-run initialisation of the example image and the model.
    """
    mic_state = state
    with st.beta_expander("Usage"):
        st.write(read_markdown("usage.md"))
    st.write("\n")
    st.write(read_markdown("intro.md"))

    max_length = 64
    with st.sidebar.beta_expander('Generation Parameters'):
        do_sample = st.checkbox(
            "Sample",
            value=False,
            help="Sample from the model instead of using beam search.",
        )
        top_k = st.number_input(
            "Top K",
            min_value=10,
            max_value=200,
            value=50,
            step=1,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
        )
        num_beams = st.number_input(
            label="Number of Beams",
            min_value=2,
            max_value=10,
            value=4,
            step=1,
            help="Number of beams to be used in beam search.",
        )
        temperature = st.select_slider(
            label="Temperature",
            options=list(np.arange(0.0, 1.1, step=0.1)),
            value=1.0,
            help="The value used to modulate the next token probabilities.",
            format_func=lambda x: f"{x:.2f}",
        )
        top_p = st.select_slider(
            label="Top-P",
            options=list(np.arange(0.0, 1.1, step=0.1)),
            value=1.0,
            help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.",
            format_func=lambda x: f"{x:.2f}",
        )
        if st.button("Clear All Cache"):
            caching.clear_cache()

    # allow_output_mutation=True tells st.cache not to hash/mutation-check the
    # returned model object on every cache hit (large, not meant to be hashed).
    @st.cache(allow_output_mutation=True)
    def load_model(ckpt):
        """Load the CLIP-vision / mBART captioning model from ``ckpt``."""
        return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)

    # st.cache memoises on the arguments, so identical settings (even with
    # sampling enabled) return the previously generated caption.
    @st.cache
    def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
        """Generate a caption for ``pixel_values`` in language ``lang_code``.

        ``lang_code`` is a key of ``language_mapping``; the mapped tokenizer
        language code is forced as the first generated token. Returns the
        list of decoded strings from ``tokenizer.batch_decode``.
        """
        lang_code = language_mapping[lang_code]
        output_ids = mic_state.model.generate(
            input_ids=pixel_values,
            forced_bos_token_id=tokenizer.lang_code_to_id[lang_code],
            max_length=max_length,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
        )
        return tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)

    mic_checkpoints = ["flax-community/clip-vit-base-patch32_mbart-large-50"]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("reference.tsv", sep="\t")
    first_index = 25

    # Initialise session state on the first run.
    if mic_state.image_file is None:
        _load_dataset_example(mic_state, dummy_data.loc[first_index])

    # Load the model once per session.
    if mic_state.model is None:
        with st.spinner("Loading model..."):
            mic_state.model = load_model(mic_checkpoints[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000397133.jpg",
    )
    col1, col2, col3 = st.beta_columns([2, 1, 2])
    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        _load_dataset_example(mic_state, dummy_data.sample(1).iloc[0])
    col2.write("OR")
    if col3.button("Use above URL"):
        try:
            mic_state.image = _fetch_image_from_url(query1)
        except (requests.RequestException, OSError) as err:
            # Keep the previously shown image instead of crashing the page.
            st.error(f"Could not load an image from the URL: {err}")

    transformed_image = get_transformed_image(mic_state.image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display image.
    new_col1.image(mic_state.image, use_column_width="always")

    # Display reference caption.
    with new_col1.beta_expander("Reference Caption"):
        st.write("**Reference Caption**: " + mic_state.caption)
        st.markdown(
            f"""**English Translation**: {mic_state.caption if mic_state.lang_id == "en" else translate(mic_state.caption, 'en')}"""
        )

    # Select language.
    options = list(code_to_name.keys())
    lang_id = new_col2.selectbox(
        "Language",
        index=options.index(mic_state.lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language in which caption is to be generated.",
    )

    sequence = ['']
    if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
        with st.spinner("Generating Sequence... This might take some time, you can read our Article meanwhile!"):
            sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)

    if sequence != ['']:
        caption = sequence[0]
        new_col2.write("**Generated Caption**: " + caption)
        # Compute the translation first so the label is shown for every
        # language (a bare `label + a if c else b` drops the label when c is false).
        english_caption = caption if lang_id == "en" else translate(caption, 'en')
        new_col2.write("**English Translation**: " + english_caption)
        with new_col2:
            # \w is Unicode-aware for str patterns, so accented and non-Latin
            # letters survive; only punctuation/symbols are removed for TTS.
            clean_text = re.sub(r"[^\w ]+", "", caption)
            try:
                audio_bytes = voicerss_tts(clean_text, lang_id)
            except Exception:
                # Audio is best-effort: the external TTS service may be down.
                st.info("Unable to generate audio. Please try again in some time.")
            else:
                st.markdown("**Audio for the generated caption**")
                st.audio(audio_bytes)