# yonatanbitton's picture
# Update app.py
# 900d6f6
from datasets import load_dataset
import gradio as gr
import os
import random
# Load the 'test' split once at import time; every callback below reads rows
# from this module-level dataset.
wmtis = load_dataset("nlphuji/wmtis-identify")['test']
# Plain string: there is nothing to interpolate, so no f-prefix is needed.
print("Loaded WMTIS identify, first example:")
print(wmtis[0])
# NOTE: despite the name, this is the LAST VALID ROW INDEX (len - 1), not the
# number of rows. It is used as the slider's maximum.
dataset_size = len(wmtis) - 1

# Image column names. They match the f'{item}_image' keys built elsewhere in
# this file, but are not referenced by the code visible here.
NATURAL_IMAGE = 'natural_image'
NORMAL_IMAGE = 'normal_image'
STRANGE_IMAGE = 'strange_image'
def func(index):
    """Return the widget values for the dataset row at ``index``.

    Used as the slider's change callback. The returned list holds five values
    per image kind (image, caption, rating, comments, hash), for 'natural',
    'normal' and 'strange' in that order, with every image resized to the
    size of the row's 'normal_image'.

    Args:
        index: Row index into ``wmtis``. A float (which a slider may deliver)
            is truncated to an int before the lookup.

    Returns:
        A flat list of 15 values in the order described above.
    """
    # A slider value can arrive as a float (e.g. 3.0), which cannot be used
    # as a row index; only floats are coerced so other key types behave as
    # before.
    if isinstance(index, float):
        index = int(index)
    example = wmtis[index]
    outputs = []
    # All three images are shown at the normal image's size so the columns
    # line up.
    target_size = example['normal_image'].size
    for item in ('natural', 'normal', 'strange'):
        add_outputs_for_key(example, outputs, target_size, item)
    return outputs
def add_outputs_for_key(example, outputs, target_size, item):
    """Append the five display values for one image kind to ``outputs``.

    Values are appended in this fixed order: resized image, image caption,
    rating, comments (with the 'nan' placeholder mapped to '-'), image hash.

    Args:
        example: Mapping for one dataset row.
        outputs: List that is extended in place; nothing is returned.
        target_size: Size passed to the image's ``resize`` method.
        item: Image kind used to build the column keys, e.g. 'natural'.
    """
    resized_image = example[f'{item}_image'].resize(target_size)
    outputs.append(resized_image)

    outputs.append(example[f'{item}_image_caption'])
    outputs.append(example[f'rating_{item}'])

    comment = get_empty_comment_if_needed(example[f'comments_{item}'])
    outputs.append(comment)

    outputs.append(example[f'{item}_hash'])
# Top-level Blocks container: the UI is assembled inside the `with demo:`
# section further down and served by `demo.launch()`.
demo = gr.Blocks()
def get_empty_comment_if_needed(item):
    """Return '-' when ``item`` equals the string 'nan', else ``item`` unchanged.

    The comparison is an exact, case-sensitive match against 'nan'
    (presumably how a missing comment is stored in the dataset; verify
    against the data).
    """
    return '-' if item == 'nan' else item
def add_column_by_key(item, target_size):
    """Create one UI column showing a single image kind of the current row.

    Must be called inside an active Blocks layout context. The row shown is
    chosen by the module-level ``index`` variable, which has to be assigned
    before this function is called.

    Args:
        item: Image kind used to build the column keys, e.g. 'natural'.
        target_size: Size passed to the image's ``resize`` method.

    Returns:
        The created components as a list in this order: image, caption,
        rating, comments, image ID.
    """
    # Look the row up once and reuse it, instead of repeating the dataset
    # lookup for every widget.
    example = wmtis[index]
    with gr.Column():
        img_resized = example[f"{item}_image"].resize(target_size)
        i1 = gr.Image(value=img_resized, label=f'{item.capitalize()} Image')
        p1 = gr.Textbox(value=example[f"{item}_image_caption"], label='BLIP2 Predicted Caption')
        r1 = gr.Textbox(value=example[f"rating_{item}"], label='Rating')
        c1 = gr.Textbox(value=get_empty_comment_if_needed(example[f"comments_{item}"]), label='Comments')
        t1 = gr.Textbox(value=example[f"{item}_hash"], label='Image ID')
    item_outputs = [i1, p1, r1, c1, t1]
    return item_outputs
# Assemble the page: a title, an index slider, and one column per image kind.
with demo:
    gr.Markdown("# Main Challenge: Weirdness, not Synthesis")
    with gr.Column():
        # step=1 keeps the slider on whole numbers so every value is a valid
        # row index and no row is skipped.
        slider = gr.Slider(minimum=0, maximum=dataset_size, step=1)
        with gr.Row():
            # `index` must remain a module-level name: add_column_by_key
            # reads it to pick the row shown at start-up. Fall back to 0 if
            # the slider has no initial value.
            index = int(slider.value or 0)
            # dataset_size is the last valid index (len(wmtis) - 1), so only
            # values beyond it are out of range.
            if index > dataset_size:
                index = 0
            # All three columns use the normal image's size so they line up.
            target_size = wmtis[index]['normal_image'].size
            natural_outputs = add_column_by_key('natural', target_size)
            normal_outputs = add_column_by_key('normal', target_size)
            strange_outputs = add_column_by_key('strange', target_size)
    # The output order must match the list returned by func:
    # natural, normal, strange, five components each.
    slider.change(func, inputs=[slider], outputs=natural_outputs + normal_outputs + strange_outputs)

demo.launch()