# vqa-guessing-game / create_cache.py
# Author: sedrickkeh — commit 016285f ("Upload 13 files"), 3.74 kB.
import gradio as gr
import numpy as np
from PIL import Image
import torch
import pickle
class Game_Cache:
    """Plain, picklable record of one generated game task.

    create_cache builds one of these per task and pickles it.

    Attributes:
        question_dict: the question bank taken from the question model
            (create_cache passes ``question_model.question_bank``).
        image_files: the 10 image file names used by the task.
        yn: True for yes/no-question tasks.
        hard_setting: True when the task draws its images from the hard list.
    """

    def __init__(self, question_dict, image_files, yn, hard_setting):
        # Assigned in this order so the instance __dict__ (and thus the pickle
        # payload) keeps the same key order.
        self.question_dict, self.image_files = question_dict, image_files
        self.yn, self.hard_setting = yn, hard_setting
def _read_image_list(path):
    """Return the lines of the text file at *path*, stripped, one entry per line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


# Image file names, resolved by create_cache under ./mscoco-images/val2014/.
# Each base task index selects 10 consecutive entries of one of these lists.
image_list = _read_image_list('./mscoco/mscoco_images.txt')
image_list_hard = _read_image_list('./mscoco/mscoco_images_attribute_n=1.txt')

# Task ids 0-159 pack two flags on top of a base index 0-39:
#   ids 40-79 and 120-159 are yes/no-question tasks (create_cache subtracts 40);
#   after that subtraction, ids 80-159 use image_list_hard (create_cache subtracts 80).
yn_indices = list(range(40, 80)) + list(range(120, 160))
hard_setting_indices = list(range(80, 160))
from model.run_question_asking_model import return_modules, return_modules_yn

# Both model sets are built once, at import time; create_cache selects one of
# them per task. (`global` declarations are no-ops at module scope, so none are
# needed here.) The third value returned by each loader is not used by this script.
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
def create_cache(taskid):
    """Generate and pickle the cached game state for a single task id.

    ``taskid`` (``__main__`` iterates 0-159) packs two flags on top of a base
    index:

    * ids in ``yn_indices`` (40-79 and 120-159) are yes/no-question tasks and
      have 40 subtracted;
    * ids that are then in ``hard_setting_indices`` (80-159) use the hard image
      list and have 80 subtracted.

    The remaining base index selects 10 consecutive entries of the chosen image
    list. The function then:

    * generates captions and candidate questions for those images;
    * resets the question bank;
    * runs first-question selection under a uniform distribution over the 10
      images;
    * pickles a ``Game_Cache`` to ``./cache{taskid}.p``. It holds the question
      bank, the image names and both flags, and the file is named after the
      original, undecoded ``taskid``.

    Args:
        taskid: integer task id.

    Raises:
        IndexError: if the 10 selected entries run past the end of the image
            list.
    """
    original_taskid = taskid

    # Yes/no tasks use the models from return_modules_yn(); all other tasks use
    # the models from return_modules(), consistent with the `yn` flag cached below.
    if taskid in yn_indices:
        yn = True
        curr_question_model = question_model_yn
        curr_response_model_simul = response_model_simul_yn
        curr_caption_model = caption_model_yn
        taskid -= 40
    else:
        yn = False
        curr_question_model = question_model
        curr_response_model_simul = response_model_simul
        curr_caption_model = caption_model

    # Checked after the yes/no offset is removed, so ids 120-159 arrive here as 80-119.
    if taskid in hard_setting_indices:
        hard_setting = True
        image_list_curr = image_list_hard
        taskid -= 80
    else:
        hard_setting = False
        image_list_curr = image_list

    # Each base task owns 10 consecutive entries of the selected list.
    start = int(taskid) * 10
    image_names = [image_list_curr[start + i] for i in range(10)]
    # Paths handed to the models are "/val2014/<name>", i.e. relative to
    # ./mscoco-images but with a leading slash.
    image_files = [f"/val2014/{name}" for name in image_names]

    images_np = []
    for path in image_files:
        # Close each image file once its pixels have been copied into an array.
        with Image.open(f"./mscoco-images/{path}") as img:
            arr = np.asarray(img)
        # 2-D (single-channel) arrays are stacked into 3 identical channels.
        if arr.ndim == 2:
            arr = np.dstack([arr] * 3)
        images_np.append(arr)

    # Uniform distribution over the 10 candidate images.
    p_y_x = (torch.ones(10) / 10).to(curr_question_model.device)
    captions = curr_caption_model.get_captions(image_files)
    questions, _target_questions = curr_question_model.get_questions(image_files, captions, 0)
    curr_question_model.reset_question_bank()
    # The selected question itself is not stored; what gets cached is the
    # model's question_bank as it stands after this call (presumably populated
    # by it — TODO confirm against the question model).
    curr_question_model.select_best_question(
        p_y_x, questions, images_np, captions, curr_response_model_simul
    )

    cache = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
    # Name the file after the original id. The decoded base index (0-39) is
    # shared by all four settings, which would then overwrite each other.
    with open(f'./cache{int(original_taskid)}.p', 'wb') as fp:
        pickle.dump(cache, fp, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    # Run create_cache once for every task id 0-159.
    for task_id in range(160):
        create_cache(task_id)