"""Precompute and pickle per-task game caches for the MSCOCO question-asking game.

For each task id, a group of 10 MSCOCO val2014 images is selected. The images
are captioned, candidate questions are generated, and the question model's
best-question selection is run once. The model's resulting ``question_bank``
(plus the image names and task flags) is then pickled to ``./cache{taskid}.p``.

Task-id layout (from ``yn_indices`` / ``hard_setting_indices`` below):
    0-39     non-yn models,  standard image list
    40-79    yes/no models,  standard image list
    80-119   non-yn models,  hard image list
    120-159  yes/no models,  hard image list
Every band maps onto image groups 0-39 of its image list.
"""
import gradio as gr
import numpy as np
from PIL import Image
import torch
import pickle

from model.run_question_asking_model import return_modules, return_modules_yn

# Number of candidate images shown in one game.
IMAGES_PER_TASK = 10
# Total number of task ids generated by the ``__main__`` entry point.
NUM_TASKS = 160


class Game_Cache:
    """Picklable container for one task's precomputed game state.

    The class name and attribute names must stay stable so that previously
    written pickles can still be loaded.
    """

    def __init__(self, question_dict, image_files, yn, hard_setting):
        # The question model's ``question_bank`` after best-question selection.
        self.question_dict = question_dict
        # Bare image file names (as listed in the mscoco image-list files).
        self.image_files = image_files
        # True when the task id is in ``yn_indices``.
        self.yn = yn
        # True when the images came from the hard image list.
        self.hard_setting = hard_setting


def _read_image_list(path):
    """Return the whitespace-stripped lines of the text file at *path*."""
    with open(path, 'r') as f:
        return [line.strip() for line in f]


image_list = _read_image_list('./mscoco/mscoco_images.txt')
image_list_hard = _read_image_list('./mscoco/mscoco_images_attribute_n=1.txt')

yn_indices = list(range(40, 80)) + list(range(120, 160))
hard_setting_indices = list(range(80, 160))

# Both model sets are loaded once at import time; create_cache picks one per task.
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()


def _select_models(yn):
    """Return ``(question_model, response_model_simul, caption_model)`` for the task type."""
    if yn:
        return question_model_yn, response_model_simul_yn, caption_model_yn
    return question_model, response_model_simul, caption_model


def _load_image_array(path):
    """Load the image at *path* as a numpy array, closing the file afterwards.

    2-D (grayscale) arrays are stacked into three identical channels; arrays
    with any other rank are returned unchanged.
    """
    with Image.open(path) as img:
        arr = np.asarray(img)
    return np.dstack([arr] * 3) if arr.ndim == 2 else arr


def create_cache(taskid):
    """Build the game cache for one task id and pickle it to disk.

    Args:
        taskid: Task id (see the module docstring for the 0-159 layout).
            Any value accepted by ``int()`` is allowed.

    Side effects:
        Calls ``reset_question_bank`` / ``select_best_question`` on the
        selected question model, and writes ``./cache{taskid}.p`` keyed by
        the task id as passed in (so the 160 tasks produce 160 files).

    Raises:
        IndexError: if the selected image list has too few entries for the
            task's image group.
    """
    # Normalize once so the membership tests below see the same value that is
    # later used for indexing (a string id would otherwise skip the dispatch).
    taskid = int(taskid)
    original_taskid = taskid

    yn = taskid in yn_indices
    # Yes/no tasks must use the yes/no model set (the original code had these
    # two model sets swapped).
    curr_question_model, curr_response_model_simul, curr_caption_model = _select_models(yn)
    if yn:
        taskid -= 40

    hard_setting = taskid in hard_setting_indices
    if hard_setting:
        image_list_curr = image_list_hard
        taskid -= 80
    else:
        image_list_curr = image_list

    # Indexing (not slicing) keeps the IndexError on a too-short image list.
    start = taskid * IMAGES_PER_TASK
    image_names = [image_list_curr[start + i] for i in range(IMAGES_PER_TASK)]

    # The models receive paths of the form "/val2014/<name>"; this matches what
    # the pipeline previously produced by slicing "./mscoco-images" off the path.
    image_files = [f"/val2014/{name}" for name in image_names]
    images_np = [_load_image_array(f"./mscoco-images/val2014/{name}") for name in image_names]

    # Uniform probability vector with one entry per image; presumably the
    # prior over which image is the target — TODO confirm.
    p_y_x = (torch.ones(IMAGES_PER_TASK) / IMAGES_PER_TASK).to(curr_question_model.device)
    captions = curr_caption_model.get_captions(image_files)
    questions, _target_questions = curr_question_model.get_questions(image_files, captions, 0)
    curr_question_model.reset_question_bank()
    # The returned question is not stored; the call is presumably needed for
    # its effect on ``question_bank``, which is what gets cached — TODO confirm.
    curr_question_model.select_best_question(
        p_y_x, questions, images_np, captions, curr_response_model_simul
    )

    cache = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
    # Key the file by the caller's task id; the reduced id (0-39) would make
    # the four task bands overwrite each other's caches.
    with open(f'./cache{original_taskid}.p', 'wb') as fp:
        pickle.dump(cache, fp, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    for task in range(NUM_TASKS):
        create_cache(task)