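# Gradio demo for browsing the WHOOPS! benchmark (nlphuji/whoops on the Hugging
# Face Hub): a slider pages through the shuffled test split, showing each image
# alongside its designer and crowd annotations.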
import math
from datasets import load_dataset
import gradio as gr
import os
import ast

# The WHOOPS! dataset is gated, so an access token is read from the environment.
auth_token = os.environ.get("auth_token")
whoops = load_dataset("nlphuji/whoops", token=auth_token, trust_remote_code=True)['test'].shuffle()
dataset_size = len(whoops)

IMAGE = 'image'
IMAGE_DESIGNER = 'image_designer'
DESIGNER_EXPLANATION = 'designer_explanation'
CROWD_CAPTIONS = 'crowd_captions'
CROWD_EXPLANATIONS = 'crowd_explanations'
CROWD_UNDERSPECIFIED_CAPTIONS = 'crowd_underspecified_captions'
QA = 'question_answering_pairs'
IMAGE_ID = 'image_id'
SELECTED_CAPTION = 'selected_caption'
COMMONSENSE_CATEGORY = 'commonsense_category'


# The image goes on the left; every other dataset field is listed on the right.
left_side_columns = [IMAGE]
right_side_columns = [x for x in whoops.features.keys() if x not in left_side_columns]
# List-valued columns (stored as string literals) are rendered as numbered lines.
enumerate_cols = [CROWD_CAPTIONS, CROWD_EXPLANATIONS, CROWD_UNDERSPECIFIED_CAPTIONS]

# The raw URL column duplicates the displayed image, so it is not shown.
right_side_columns.remove('image_url')

emoji_to_label = {IMAGE_DESIGNER: '🎨, 🧑‍🎨, 💻', DESIGNER_EXPLANATION: '💡, 🤔, 🧑‍🎨',
                  CROWD_CAPTIONS: '👥, 💬, 📝', CROWD_EXPLANATIONS: '👥, 💡, 🤔', CROWD_UNDERSPECIFIED_CAPTIONS: '👥, 💬, 👎',
                  QA: '❓, 🤔, 💡', IMAGE_ID: '🔍, 📄, 💾', COMMONSENSE_CATEGORY: '🤔, 📚, 💡', SELECTED_CAPTION: '📝, 👌, 💬'}
batch_size = 8  # examples shown per page
target_size = (1024, 1024)  # images are resized to a uniform size for display


def update_page(index):
    """Slider callback: return fresh values for every component on the page."""
    # Clamp so the last page never indexes past the end of the dataset; Gradio
    # expects exactly one value per output component, so always return a full batch.
    start_index = min(int(index) * batch_size, dataset_size - batch_size)
    values_lst = []
    for i in range(start_index, start_index + batch_size):
        values_lst += get_instance_values(whoops[i])
    return values_lst


def get_instance_values(example):
    values = []
    for k in left_side_columns + right_side_columns:
        if k == IMAGE:
            value = example[IMAGE].resize(target_size)
        elif k in enumerate_cols:
            # These columns hold lists serialized as string literals.
            value = list_to_string(ast.literal_eval(example[k]))
        elif k == QA:
            qa_list = [f"Q: {x[0]} A: {x[1]}" for x in ast.literal_eval(example[k])]
            value = list_to_string(qa_list)
        else:
            value = example[k]
        values.append(value)
    return values


def list_to_string(lst):
    return '\n'.join(f'{i + 1}. {item}' for i, item in enumerate(lst))
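
# For reference, list_to_string(["a cat", "a dog"]) produces:
# 1. a cat
# 2. a dog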

demo = gr.Blocks()


def get_col(example):
    """Build one example's components: image on the left, details in an accordion."""
    instance_values = get_instance_values(example)
    with gr.Column():
        inputs_left = []
        left_values = instance_values[:len(left_side_columns)]
        assert len(left_side_columns) == len(left_values)
        for key, value in zip(left_side_columns, left_values):
            if key == IMAGE:
                # get_instance_values already returns the resized image.
                input_k = gr.Image(value=value)
            else:
                label = key.capitalize().replace("_", " ")
                input_k = gr.Textbox(value=value, label=f"{label} {emoji_to_label[key]}")
            inputs_left.append(input_k)

        with gr.Accordion("Click for details", open=False):
            text_inputs_right = []
            right_values = instance_values[len(left_side_columns):]
            assert len(right_side_columns) == len(right_values)
            for key, value in zip(right_side_columns, right_values):
                label = key.capitalize().replace("_", " ")
                if not isinstance(value, str):
                    num_lines = 1
                else:
                    # Size the textbox to its content, assuming ~50 characters per line.
                    num_lines = max(1, len(value) // 50 + (len(value) % 50 > 0))
                text_input_k = gr.Textbox(value=value, label=f"{label} {emoji_to_label[key]}", lines=num_lines)
                text_inputs_right.append(text_input_k)
    return inputs_left, text_inputs_right


with demo:
    gr.Markdown("# Slide to iterate WHOOPS!")

    with gr.Column():
        num_batches = math.ceil(dataset_size / batch_size)
        # Pages are zero-indexed, so the last valid page is num_batches - 1.
        slider = gr.Slider(minimum=0, maximum=num_batches - 1, step=1, label=f'Page (out of {num_batches})')
        with gr.Row():
            # Build the components once, pre-filled with the first page of examples.
            all_inputs_left_right = []
            for i in range(batch_size):
                inputs_left, text_inputs_right = get_col(whoops[i])
                all_inputs_left_right += inputs_left + text_inputs_right

    slider.change(update_page, inputs=[slider], outputs=all_inputs_left_right)

demo.launch()
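
# Run locally by setting the auth_token environment variable (the dataset is
# gated) and executing this script; Gradio serves on http://127.0.0.1:7860 by
# default. Hosted as a Hugging Face Space, the app starts automatically.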