"""Gradio viewer for the WMTIS "identify" test split.

For one dataset example, shows the natural, normal and strange image side by
side together with each image's BLIP2 caption, rating, comments and hash.
A slider selects which example is displayed.
"""
import math
import os
import random

import gradio as gr
from datasets import load_dataset

wmtis = load_dataset("nlphuji/wmtis-identify")['test']
print("Loaded WMTIS identify, first example:")
print(wmtis[0])
# Largest valid example index (used as the slider's maximum).
dataset_size = len(wmtis) - 1

NATURAL_IMAGE = 'natural_image'
NORMAL_IMAGE = 'normal_image'
STRANGE_IMAGE = 'strange_image'

# Image kinds shown as columns, left to right. Each prefix is used to build the
# dataset column names, e.g. 'natural' -> 'natural_image', 'rating_natural'.
ITEM_KINDS = ('natural', 'normal', 'strange')


def _to_index(value):
    """Convert a slider value into a valid integer example index.

    The dataset only accepts integer keys, while the slider value may arrive
    as a float; None is guarded against defensively. Values outside
    [0, dataset_size] fall back to 0.
    """
    if value is None:
        return 0
    index = int(round(value))
    if not 0 <= index <= dataset_size:
        return 0
    return index


def get_empty_comment_if_needed(item):
    """Return '-' for a missing comment, otherwise return the comment unchanged.

    A comment counts as missing when it is the string 'nan' (presumably how
    empty comments were serialized in the dataset -- confirm), None, or a
    float NaN.
    """
    if item is None or item == 'nan':
        return '-'
    if isinstance(item, float) and math.isnan(item):
        return '-'
    return item


def _item_values(example, target_size, item):
    """Return the display values for image kind *item* of *example*.

    Order: resized image, caption, rating, comments, image hash. Both the
    initial components and the slider callback use this single helper, so
    their field order cannot drift apart.
    """
    return [
        example[f'{item}_image'].resize(target_size),
        example[f'{item}_image_caption'],
        example[f'rating_{item}'],
        get_empty_comment_if_needed(example[f'comments_{item}']),
        example[f'{item}_hash'],
    ]


def add_outputs_for_key(example, outputs, target_size, item):
    """Append the display values for image kind *item* to *outputs* in place."""
    outputs.extend(_item_values(example, target_size, item))


def func(index):
    """Slider callback: return the output values for example *index*.

    All images are resized to the size of the example's normal image so the
    three columns line up. Returns 5 values per image kind, in ITEM_KINDS order.
    """
    example = wmtis[_to_index(index)]
    target_size = example['normal_image'].size
    outputs = []
    for item in ITEM_KINDS:
        add_outputs_for_key(example, outputs, target_size, item)
    return outputs


def add_column_by_key(item, target_size, index=0):
    """Create one column of components for image kind *item* and return them.

    Must be called inside a ``gr.Blocks`` context. Components are initialised
    with the values of example *index*. Returns
    ``[image, caption, rating, comments, image_id]``, the same order in which
    ``func`` produces its values.
    """
    image, caption, rating, comments, image_id = _item_values(
        wmtis[index], target_size, item)
    with gr.Column():
        item_outputs = [
            gr.Image(value=image, label=f'{item.capitalize()} Image'),
            gr.Textbox(value=caption, label='BLIP2 Predicted Caption'),
            gr.Textbox(value=rating, label='Rating'),
            gr.Textbox(value=comments, label='Comments'),
            gr.Textbox(value=image_id, label='Image ID'),
        ]
    return item_outputs


demo = gr.Blocks()

with demo:
    gr.Markdown("# Main Challenge: Weirdness, not Synthesis")
    with gr.Column():
        # step=1 keeps the slider on whole-number example indices.
        slider = gr.Slider(minimum=0, maximum=dataset_size, value=0, step=1)
        with gr.Row():
            index = _to_index(slider.value)
            target_size = wmtis[index]['normal_image'].size
            natural_outputs = add_column_by_key('natural', target_size, index)
            normal_outputs = add_column_by_key('normal', target_size, index)
            strange_outputs = add_column_by_key('strange', target_size, index)
        slider.change(func, inputs=[slider],
                      outputs=natural_outputs + normal_outputs + strange_outputs)

demo.launch()