File size: 11,206 Bytes
022601f
1de1fd2
016285f
 
 
 
 
 
 
 
2424844
 
 
1de1fd2
016285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022601f
 
 
016285f
 
022601f
 
 
 
016285f
 
 
 
 
 
 
 
 
022601f
 
 
016285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022601f
016285f
 
 
 
 
 
 
022601f
016285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
022601f
 
 
016285f
 
 
 
 
022601f
016285f
 
 
 
 
022601f
016285f
022601f
016285f
 
 
 
 
 
 
 
 
022601f
 
016285f
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import gradio as gr
from response_db import ResponseDb
from create_cache import Game_Cache
import numpy as np
from PIL import Image
import pandas as pd
import torch
import pickle
import uuid

import nltk
# Fetch the NLTK 'punkt' resource when the module is imported. The visible code
# never calls nltk directly; presumably the model code imported below relies
# on it — TODO confirm.
nltk.download('punkt')

# Shared ResponseDb handle; ask_a_question() records each turn through db.add().
db = ResponseDb()
# Stylesheet passed to gr.Blocks(css=...). The .chatbot / .msg / .user / .bot
# classes match the markup generated for the conversation transcript, and
# .na_button is attached to the N/A button via elem_classes.
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
.msg.bot {background-color:lightgray}
.na_button {background-color:red;color:red}
"""

# Build both model sets once at import time: one from return_modules() and a
# "_yn" set from return_modules_yn(). Game_Session chooses between them from
# its `yn` flag. The third value returned by each factory is discarded.
from model.run_question_asking_model import return_modules, return_modules_yn
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()

class Game_Session:
    """State of one guessing game for one player.

    A session keeps references to the module-level models (both the
    open-ended set and the yes/no set), selects the active set from ``yn``,
    and carries the per-game data that the handlers fill in afterwards:
    image paths and arrays, the belief ``p_y_x`` over the images, captions,
    the remaining question pool and the alternating question/answer
    ``history``.
    """

    def __init__(self, taskid, yn, hard_setting):
        """Create an empty session.

        Args:
            taskid: Identifier of the task being played. Accepted for
                interface compatibility; it is not stored on the session.
            yn: Truthy to use the yes/no model set, falsy for the
                open-ended set.
            hard_setting: Stored as-is on the session; not interpreted here.
        """
        self.yn = yn
        self.hard_setting = hard_setting

        # Module-level models are only read here, so no ``global`` statement
        # is required.
        self.question_model = question_model
        self.response_model_simul = response_model_simul
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
        self.caption_model_yn = caption_model_yn

        # Per-game data, populated once a task has been loaded.
        self.image_files = None
        # Previously this was initialised under the misspelled name
        # ``image_np`` while every reader used ``images_np``, so a fresh
        # session raised AttributeError instead of exposing ``None``.
        self.images_np = None
        self.image_np = None  # legacy misspelling, kept for compatibility
        self.p_y_x = None
        self.p_r_qy = None
        self.p_y_xqr = None
        self.captions = None
        self.questions = None
        self.target_questions = None

        self.history = []
        self.game_id = str(uuid.uuid4())
        self.set_curr_models()

    def set_curr_models(self):
        """Point the ``curr_*`` attributes at the model set chosen by ``self.yn``."""
        if self.yn:
            self.curr_question_model = self.question_model_yn
            self.curr_caption_model = self.caption_model_yn
            self.curr_response_model_simul = self.response_model_simul_yn
        else:
            self.curr_question_model = self.question_model
            self.curr_caption_model = self.caption_model
            self.curr_response_model_simul = self.response_model_simul

    def get_next_question(self):
        """Return whatever the active question model selects as the next question.

        Delegates to ``select_best_question`` with the current belief, the
        remaining questions, the image arrays, the captions and the active
        response model.
        """
        return self.curr_question_model.select_best_question(
            self.p_y_x,
            self.questions,
            self.images_np,
            self.captions,
            self.curr_response_model_simul,
        )


def ask_a_question(input, taskid, gs):
    """Record the player's answer, update the belief and render the next turn.

    Args:
        input: The player's answer to the most recent question (free text,
            "yes"/"no", or "nothing" for N/A).
        taskid: Task ID value from the UI; only forwarded to ``db.add``.
        gs: The active ``Game_Session``; mutated in place.

    Returns:
        A 9-tuple matching the outputs wired to the answer buttons:
        transcript HTML, the session, then component updates for the answer
        textbox, N/A button, submit button, task-ID field, start button,
        yes button and no button.
    """
    # Function-scope import: the file does not import ``html`` at module level.
    from html import escape

    gs.history.append(input)
    question = gs.history[-2]  # the question this answer belongs to

    # Reweight the belief by the response likelihood and renormalise; if all
    # mass vanished, fall back to an all-zero vector rather than dividing by 0.
    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(input, question, gs.images_np, gs.captions)
    posterior = gs.p_y_x * gs.p_r_qy
    total = torch.sum(posterior)
    gs.p_y_xqr = posterior / total if total != 0 else torch.zeros_like(posterior)
    gs.p_y_x = gs.p_y_xqr

    gs.questions.remove(question)
    db.add(gs.game_id, taskid, len(gs.history)//2 - 1, question, input)

    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()
    finished = top_prob > 0.8
    if finished:
        # The logged guess is the zero-based index, while the message shown to
        # the player below is one-based; the log format is left unchanged.
        db.add(gs.game_id, taskid, len(gs.history)//2, f"Guess: Image {top_pred}", "")
    else:
        # Only select a follow-up question while the game is still running;
        # it used to be computed and immediately thrown away on the last turn.
        gs.history.append(gs.get_next_question())

    # Transcript: even positions are model questions, odd positions are the
    # player's answers. Escape every message so typed markup is shown as text
    # instead of being injected into the page.
    bubbles = []
    for position, msg in enumerate(gs.history):
        if msg == "nothing":
            msg = "n/a"
        speaker = "bot" if position % 2 == 0 else "user"
        bubbles.append("<div class='msg {}'> {}</div>".format(speaker, escape(str(msg))))
    transcript = "<div class='chatbot'>" + "".join(bubbles) + "</div>"

    if finished:
        transcript += f"<p>The model identified <b>Image {top_pred+1}</b> as the image. Please select a new task ID to continue.</p>"
        show_free_text, show_new_task, show_yes_no = False, True, False
    elif gs.yn:
        show_free_text, show_new_task, show_yes_no = False, False, True
    else:
        show_free_text, show_new_task, show_yes_no = True, False, False

    return (
        transcript,
        gs,
        gr.Textbox.update(visible=show_free_text),  # answer
        gr.Button.update(visible=show_free_text),   # na_box
        gr.Button.update(visible=show_free_text),   # submit
        gr.Number.update(visible=show_new_task),    # taskid
        gr.Button.update(visible=show_new_task),    # start_button
        gr.Button.update(visible=show_yes_no),      # yes_box (a Button, so Button.update)
        gr.Button.update(visible=show_yes_no),      # no_box
    )


def set_images(taskid):
    """Load a task's cached game, start a new session and ask the first question.

    Args:
        taskid: Row index into ``pilot-study.csv`` as entered in the UI.

    Returns:
        A 20-tuple matching the outputs wired to the start button: ten image
        paths, the new ``Game_Session``, the first-question HTML, the task
        banner update, then component updates for the answer textbox, N/A
        button, submit button, task-ID field, start button, yes button and
        no button.

    Raises:
        gr.Error: If ``taskid`` is missing or outside the rows of the CSV.
    """
    # Function-scope import: the file does not import ``html`` at module level.
    from html import escape

    pilot_study = pd.read_csv("pilot-study.csv")
    coco_ids = pilot_study['mscoco-id'].tolist()
    # Reject empty / out-of-range input explicitly. A negative value used to
    # wrap around silently and pick a task from the end of the list.
    if taskid is None or not 0 <= int(taskid) < len(coco_ids):
        raise gr.Error(f"Task ID must be a number from 0 to {len(coco_ids) - 1}.")
    taskid_original = taskid
    taskid = coco_ids[int(taskid)]

    # NOTE(review): pickle.load can execute arbitrary code from the file. This
    # is only safe while the cache directory holds trusted, locally generated
    # files — confirm that nothing user-supplied can end up there.
    with open(f'cache/{int(taskid)}.p', 'rb') as fp:
        game_cache = pickle.load(fp)
    gs = Game_Session(int(taskid), game_cache.yn, game_cache.hard_setting)

    file_names = [game_cache.image_files[k] for k in range(10)]
    # Paths handed to the gr.Image components.
    display_paths = [f"./mscoco-images/val2014/{name}" for name in file_names]
    # The models receive the paths relative to ./mscoco-images, leading slash
    # included (this is what slicing off the first 15 characters produced).
    gs.image_files = [f"/val2014/{name}" for name in file_names]

    gs.images_np = []
    for rel_path in gs.image_files:
        # Close each file handle as soon as its pixels have been copied out.
        with Image.open(f"./mscoco-images/{rel_path}") as img:
            pixels = np.asarray(img)
        if len(pixels.shape) == 2:
            # Single-channel image: replicate it into three channels.
            pixels = np.dstack([pixels]*3)
        gs.images_np.append(pixels)

    # Start from a uniform belief over the ten images.
    gs.p_y_x = (torch.ones(10)/10).to(gs.curr_question_model.device)
    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(gs.image_files, gs.captions, 0)
    # NOTE(review): curr_question_model is the module-level model object that
    # every session references, so replacing its question bank here affects
    # any other game in progress — confirm whether concurrent players matter.
    gs.curr_question_model.reset_question_bank()
    gs.curr_question_model.question_bank = game_cache.question_dict

    first_question = gs.get_next_question()
    first_question_html = f"<div class='chatbot'><div class='msg bot'>{escape(str(first_question))}</div></div>"
    gs.history.append(first_question)
    task_banner = f"<p>Current Task ID: <b>{int(taskid_original)}</b></p>"

    show_yes_no = bool(gs.yn)
    show_free_text = not show_yes_no
    if show_free_text:
        answer_update = gr.Textbox.update(visible=True, value='')
    else:
        answer_update = gr.Textbox.update(visible=False)

    return (
        *display_paths,
        gs,
        first_question_html,
        gr.HTML.update(value=task_banner, visible=True),
        answer_update,
        gr.Button.update(visible=show_free_text),  # na_box
        gr.Button.update(visible=show_free_text),  # submit
        gr.Number.update(visible=False),           # taskid
        gr.Button.update(visible=False),           # start_button
        gr.Button.update(visible=show_yes_no),     # yes_box
        gr.Button.update(visible=show_yes_no),     # no_box
    )


# Assemble the Gradio page: instructions, task picker, ten image slots, the
# conversation transcript and the answer controls, then wire up the buttons.
with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    gr.HTML("<h1>Image Q&A Guessing Game</h1>\
    <p style='font-size:120%;'>\
    Imagine you are playing 20-questions with an AI model.<br>\
    The AI model plays the role of the question asker. You play the role of the responder. <br>\
    There are 10 images. <b>Your image is Image 1</b>. The other images are distraction images.\
    The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>\
    <span style='color: #0000ff'>The goal of the model is to accurately guess the correct image (i.e. <b><span style='color: #0000ff'>Image 1</span></b>) in as few turns as possible.<br>\
    Your goal is to help the model guess the image by answering as clearly and accurately as possible.</span><br><br>\
    <b>Guidelines:</b><br>\
    <ol style='font-size:120%;'>\
        <li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>\
        <li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>\
    </ol> \
    <br>\
    (Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)<br></p>\
    <br>\
    <h2>Please enter a TaskID to start</h2>")

    # Task picker and the banner showing the current task.
    with gr.Column():
        with gr.Row():
            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
            start_button = gr.Button("Enter")
        with gr.Row():
            task_text = gr.HTML()

    # Two rows of five labelled image slots.
    with gr.Column() as img_block:
        with gr.Row():
            img1, img2, img3, img4, img5 = [
                gr.Image(label=f"Image {slot}", show_label=True) for slot in range(1, 6)
            ]
        with gr.Row():
            img6, img7, img8, img9, img10 = [
                gr.Image(label=f"Image {slot}", show_label=True) for slot in range(6, 11)
            ]
    conversation = gr.HTML()
    game_session_state = gr.State()

    # Visible free-text answer box plus hidden boxes holding the canned answers
    # that the N/A, Yes and No buttons submit.
    answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
    null_answer = gr.Textbox("nothing", visible=False)
    yes_answer = gr.Textbox("yes", visible=False)
    no_answer = gr.Textbox("no", visible=False)

    with gr.Column():
        with gr.Row():
            yes_box = gr.Button("Yes", visible=False)
            no_box = gr.Button("No", visible=False)
    with gr.Column():
        with gr.Row():
            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
            submit = gr.Button("Submit", visible=False)

    ### Button click events
    image_slots = [img1, img2, img3, img4, img5, img6, img7, img8, img9, img10]
    # Controls whose visibility both handlers update, in the order they return them.
    control_outputs = [answer, na_box, submit, taskid, start_button, yes_box, no_box]

    start_button.click(
        fn=set_images,
        inputs=taskid,
        outputs=image_slots + [game_session_state, conversation, task_text] + control_outputs,
    )
    # Every answer button runs the same handler; only the source textbox differs.
    for trigger, reply_box in (
        (submit, answer),
        (na_box, null_answer),
        (yes_box, yes_answer),
        (no_box, no_answer),
    ):
        trigger.click(
            fn=ask_a_question,
            inputs=[reply_box, taskid, game_session_state],
            outputs=[conversation, game_session_state] + control_outputs,
        )

demo.launch()