"""Gradio front end for the Image Q&A guessing game.

The question-asking model plays the asker and the human plays the responder.
Each task shows ten images; the model keeps a belief ``p_y_x`` over them,
updates it after every answer, and stops once one image exceeds
``GUESS_THRESHOLD``.  Every question/answer pair and the final guess are logged
through ``StResponseDb``.

NOTE(review): this file was recovered from a copy in which indentation and the
HTML markup inside string literals had been stripped.  The chat markup is
rebuilt from the CSS classes defined below (``chatbot``, ``msg``, ``bot``,
``user``); the introduction markup and the exact nesting of the layout
containers are a best-effort reconstruction -- confirm against the deployed UI.
"""
import html
import pickle
import uuid

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image

# NOTE(review): Game_Cache is not referenced by name in this file.  It is
# presumably kept importable so that unpickling the cache files can resolve
# the class -- confirm before removing.
from create_cache import Game_Cache  # noqa: F401
from model.run_question_asking_model import return_modules, return_modules_yn
from response_db import StResponseDb

db = StResponseDb()

# Number of candidate images shown per task (Image 1 is the target).
NUM_IMAGES = 10
# The model commits to a guess once its top belief exceeds this probability.
GUESS_THRESHOLD = 0.8

# NOTE(review): .na_button sets both background and text colour to red, which
# hides the label wherever the rule applies -- kept as found, confirm intent.
css = """
.chatbot {display:flex;flex-direction:column}
.msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
.msg.user {background-color:cornflowerblue;color:white;align-self:self-end}
.msg.bot {background-color:lightgray}
.na_button {background-color:red;color:red}
"""

INTRO_HTML = """
<h1>Image Q&A Guessing Game</h1>
<p>
Imagine you are playing 20-questions with an AI model.<br>
The AI model plays the role of the question asker. You play the role of the responder.<br>
There are 10 images. Your image is Image 1. The other images are distraction images.
The model can see all 10 images and all the questions and answers for the current set of images. It will ask a question based on the available information.<br>
The goal of the model is to accurately guess the correct image (i.e. Image 1) in as few turns as possible.<br>
Your goal is to help the model guess the image by answering as clearly and accurately as possible.
</p>
<p>Guidelines:</p>
<ol>
<li>It is best to keep your answers short (a single word or a short phrase). No need to answer in full sentences.</li>
<li>If you feel that the question cannot be answered or does not apply to Image 1, please select N/A.</li>
</ol>
<p>(Note: We are testing multiple game settings. In some instances, the game will be open-ended, while in other instances, the answer choices will be limited to yes/no.)</p>
<h3>Please enter a TaskID to start</h3>
"""

# Both model families are loaded once at import time and shared by every
# session; a session only picks which family it talks to.
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()


class Game_Session:
    """Per-user state for one task: models in use, belief, and transcript.

    Attributes set here are placeholders; ``set_images`` fills in the image
    data, captions, questions and the initial belief.
    """

    def __init__(self, taskid, yn, hard_setting):
        """Create an empty session.

        Args:
            taskid: Identifier of the task being played (accepted for
                interface compatibility; not stored).
            yn: Truthy when the game is restricted to yes/no answers.
            hard_setting: Difficulty flag copied from the game cache.
        """
        self.yn = yn
        self.hard_setting = hard_setting

        # Module-level models are only read, so no ``global`` is required.
        self.question_model = question_model
        self.response_model_simul = response_model_simul
        self.caption_model = caption_model
        self.question_model_yn = question_model_yn
        self.response_model_simul_yn = response_model_simul_yn
        self.caption_model_yn = caption_model_yn

        self.image_files = None
        # Was misspelled ``image_np``; every reader and writer uses
        # ``images_np``.
        self.images_np = None
        self.p_y_x = None
        self.p_r_qy = None
        self.p_y_xqr = None
        self.captions = None
        self.questions = None
        self.target_questions = None

        # Alternating transcript: even indices are model questions, odd
        # indices are user answers.
        self.history = []
        self.game_id = str(uuid.uuid4())
        self.set_curr_models()

    def set_curr_models(self):
        """Select the yes/no or the open-ended model family for this session."""
        if self.yn:
            self.curr_question_model = self.question_model_yn
            self.curr_caption_model = self.caption_model_yn
            self.curr_response_model_simul = self.response_model_simul_yn
        else:
            self.curr_question_model = self.question_model
            self.curr_caption_model = self.caption_model
            self.curr_response_model_simul = self.response_model_simul

    def get_next_question(self):
        """Return the question the current model selects for the current belief."""
        return self.curr_question_model.select_best_question(
            self.p_y_x,
            self.questions,
            self.images_np,
            self.captions,
            self.curr_response_model_simul,
        )


def _render_history(history):
    """Render the transcript as chat bubbles.

    Even positions are styled as ``bot`` and odd positions as ``user``.  The
    internal N/A marker ``"nothing"`` is displayed as ``n/a``.  Message text is
    HTML-escaped because it contains free-form user input.
    """
    bubbles = []
    for turn, msg in enumerate(history):
        if msg == "nothing":
            msg = "n/a"
        speaker = "bot" if turn % 2 == 0 else "user"
        bubbles.append(
            "<div class='msg {}'>{}</div>".format(speaker, html.escape(str(msg)))
        )
    return "<div class='chatbot'>" + "".join(bubbles) + "</div>"


def _control_updates(*, free_text=False, task_entry=False, yes_no=False,
                     clear_answer=False):
    """Build visibility updates for the seven input controls.

    The order matches the tail of every ``outputs`` list in the UI wiring:
    answer, na_box, submit, taskid, start_button, yes_box, no_box.

    Args:
        free_text: Show the answer textbox with its N/A and Submit buttons.
        task_entry: Show the task-ID field and the Enter button.
        yes_no: Show the Yes / No buttons.
        clear_answer: Also reset the answer textbox to an empty string.
    """
    if clear_answer:
        answer_update = gr.Textbox.update(visible=free_text, value='')
    else:
        answer_update = gr.Textbox.update(visible=free_text)
    return (
        answer_update,
        gr.Button.update(visible=free_text),
        gr.Button.update(visible=free_text),
        gr.Number.update(visible=task_entry),
        gr.Button.update(visible=task_entry),
        gr.Button.update(visible=yes_no),
        gr.Button.update(visible=yes_no),
    )


def _load_rgb_array(path):
    """Load an image as an array, stacking grayscale into three channels."""
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(path) as img:
        pixels = np.asarray(img)
    if pixels.ndim == 2:
        pixels = np.dstack([pixels] * 3)
    return pixels


def ask_a_question(input, taskid, gs):
    """Record the user's answer, update the belief and advance the game.

    Args:
        input: The user's answer to the last question in ``gs.history``
            (``"nothing"`` is the N/A marker).
        taskid: Value of the task-ID field, used only for logging.
        gs: The active ``Game_Session``; mutated in place.

    Returns:
        A tuple ``(conversation_html, gs, *control_updates)`` matching the
        ``outputs`` of the answer buttons.
    """
    gs.history.append(input)
    question = gs.history[-2]

    gs.p_r_qy = gs.curr_response_model_simul.get_p_r_qy(
        input, question, gs.images_np, gs.captions
    )
    # Multiply the belief by the answer likelihood and renormalise; an
    # all-zero product is kept as zeros rather than dividing by zero.
    unnormalised = gs.p_y_x * gs.p_r_qy
    total = torch.sum(unnormalised)
    if total != 0:
        gs.p_y_xqr = unnormalised / total
    else:
        gs.p_y_xqr = torch.zeros_like(unnormalised)
    gs.p_y_x = gs.p_y_xqr

    # A question is asked at most once per game.
    gs.questions.remove(question)
    db.add(gs.game_id, taskid, len(gs.history) // 2 - 1, question, input)
    gs.history.append(gs.get_next_question())

    top_prob = torch.max(gs.p_y_x).item()
    top_pred = torch.argmax(gs.p_y_x).item()
    finished = top_prob > GUESS_THRESHOLD

    if finished:
        # Drop the question that was just queued: the game is over.
        gs.history = gs.history[:-1]
        # NOTE(review): the log stores the 0-based index while the UI below
        # shows the 1-based image number -- kept as found.
        db.add(gs.game_id, taskid, len(gs.history) // 2,
               f"Guess: Image {top_pred}", "")

    chat_html = _render_history(gs.history)

    if finished:
        chat_html += (
            f"<p>The model identified <b>Image {top_pred+1}</b> as the image. "
            "Please select a new task ID to continue.</p>"
        )
        controls = _control_updates(task_entry=True)
    elif gs.yn:
        controls = _control_updates(yes_no=True)
    else:
        controls = _control_updates(free_text=True)
    return (chat_html, gs, *controls)


def set_images(taskid):
    """Start a new game for the given task ID.

    Args:
        taskid: Row index into ``pilot-study.csv`` as entered in the UI.

    Returns:
        A tuple of the ten image paths, the new ``Game_Session``, the
        conversation HTML holding the first question, the task banner update
        and the seven control updates, matching ``start_button``'s outputs.

    Raises:
        gr.Error: If the task ID is missing or out of range, or the cached
            game does not contain ``NUM_IMAGES`` images.
    """
    pilot_study = pd.read_csv("pilot-study.csv")
    mscoco_ids = pilot_study['mscoco-id'].tolist()

    # Without this check a negative ID silently wraps around to the end of
    # the list and an empty field raises a raw TypeError.
    if taskid is None or not 0 <= int(taskid) < len(mscoco_ids):
        raise gr.Error(
            f"Task ID must be a number from 0 to {len(mscoco_ids) - 1}."
        )
    taskid_original = int(taskid)
    mscoco_id = int(mscoco_ids[taskid_original])

    # SECURITY: pickle executes arbitrary code on load; the cache directory
    # must only ever contain files produced by this project.
    with open(f'cache/{mscoco_id}.p', 'rb') as fp:
        game_cache = pickle.load(fp)

    gs = Game_Session(mscoco_id, game_cache.yn, game_cache.hard_setting)

    display_paths = [
        f"./mscoco-images/val2014/{name}"
        for name in game_cache.image_files[:NUM_IMAGES]
    ]
    if len(display_paths) != NUM_IMAGES:
        raise gr.Error(
            f"Task {taskid_original} does not contain {NUM_IMAGES} images."
        )

    # The models receive the path with the first 15 characters removed
    # (i.e. "/val2014/<file>"); this exact form is preserved from the
    # original code in case downstream lookups depend on it.
    gs.image_files = [path[15:] for path in display_paths]
    gs.images_np = [
        _load_rgb_array(f"./mscoco-images/{rel}") for rel in gs.image_files
    ]

    # Start from a uniform belief over the candidate images.
    gs.p_y_x = (torch.ones(NUM_IMAGES) / NUM_IMAGES).to(gs.curr_question_model.device)
    gs.captions = gs.curr_caption_model.get_captions(gs.image_files)
    gs.questions, gs.target_questions = gs.curr_question_model.get_questions(
        gs.image_files, gs.captions, 0
    )
    gs.curr_question_model.reset_question_bank()
    gs.curr_question_model.question_bank = game_cache.question_dict

    first_question = gs.curr_question_model.select_best_question(
        gs.p_y_x, gs.questions, gs.images_np, gs.captions,
        gs.curr_response_model_simul,
    )
    gs.history.append(first_question)
    first_question_html = _render_history(gs.history)

    task_html = f"<p>Current Task ID: <b>{taskid_original}</b></p>"

    yes_no_mode = bool(gs.yn)
    controls = _control_updates(
        free_text=not yes_no_mode,
        yes_no=yes_no_mode,
        # The open-ended game starts with an empty answer box.
        clear_answer=not yes_no_mode,
    )
    return (
        *display_paths,
        gs,
        first_question_html,
        gr.HTML.update(value=task_html, visible=True),
        *controls,
    )


with gr.Blocks(title="Image Q&A Guessing Game", css=css) as demo:
    gr.HTML(INTRO_HTML)

    with gr.Column():
        with gr.Row():
            taskid = gr.Number(label="Task ID (Enter a number from 0 to 160)", value=0)
            start_button = gr.Button("Enter")
        with gr.Row():
            task_text = gr.HTML()

    images = []
    with gr.Column() as img_block:
        # Two rows of five images, labelled "Image 1" .. "Image 10".
        for row_start in range(0, NUM_IMAGES, 5):
            with gr.Row():
                for number in range(row_start + 1, row_start + 6):
                    images.append(gr.Image(label=f"Image {number}", show_label=True))

    conversation = gr.HTML()
    game_session_state = gr.State()

    answer = gr.Textbox(placeholder="Insert answer here.", label="Answer the given question.", visible=False)
    # Hidden textboxes that carry the canned answers for the N/A, Yes and No
    # buttons into ask_a_question.
    null_answer = gr.Textbox("nothing", visible=False)
    yes_answer = gr.Textbox("yes", visible=False)
    no_answer = gr.Textbox("no", visible=False)

    with gr.Column():
        with gr.Row():
            yes_box = gr.Button("Yes", visible=False)
            no_box = gr.Button("No", visible=False)
    with gr.Column():
        with gr.Row():
            na_box = gr.Button("N/A", visible=False, elem_classes="na_button")
            submit = gr.Button("Submit", visible=False)

    ### Button click events
    control_outputs = [answer, na_box, submit, taskid, start_button, yes_box, no_box]
    turn_outputs = [conversation, game_session_state, *control_outputs]

    start_button.click(
        fn=set_images,
        inputs=taskid,
        outputs=[*images, game_session_state, conversation, task_text, *control_outputs],
    )
    submit.click(
        fn=ask_a_question,
        inputs=[answer, taskid, game_session_state],
        outputs=turn_outputs,
    )
    for button, canned_answer in (
        (na_box, null_answer),
        (yes_box, yes_answer),
        (no_box, no_answer),
    ):
        button.click(
            fn=ask_a_question,
            inputs=[canned_answer, taskid, game_session_state],
            outputs=turn_outputs,
        )

if __name__ == "__main__":
    demo.launch()