Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import torch | |
import pickle | |
class Game_Cache:
    """Picklable record of the precomputed state for one game task.

    Attributes:
        question_dict: the question model's question bank captured when the
            cache was built.
        image_files: image file names (as listed in the image-list text file)
            for the task.
        yn: True when the task was decoded as a "yn" task.
        hard_setting: True when the task was decoded as a hard-setting task.
    """

    def __init__(self, question_dict, image_files, yn, hard_setting):
        # One unpacking assignment; attributes are set left to right, i.e. in
        # parameter order.
        self.question_dict, self.image_files, self.yn, self.hard_setting = (
            question_dict,
            image_files,
            yn,
            hard_setting,
        )
def _read_stripped_lines(path):
    """Return every line of the text file at *path*, whitespace-stripped.

    Blank lines are kept as empty strings so that line positions (which
    create_cache uses to index images) are not shifted. The encoding is given
    explicitly so the result does not depend on the platform locale.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


# Image file names for the standard setting, one per line.
image_list = _read_stripped_lines('./mscoco/mscoco_images.txt')
# Image file names for the hard setting, one per line.
image_list_hard = _read_stripped_lines('./mscoco/mscoco_images_attribute_n=1.txt')
# Task ids run 0-159 in four runs of 40, as decoded by create_cache:
#   0-39 standard, 40-79 yn, 80-119 hard, 120-159 hard + yn.
yn_indices = [*range(40, 80), *range(120, 160)]
hard_setting_indices = [*range(80, 160)]
# NOTE(review): imported here rather than with the other imports at the top;
# presumably intentional (heavy import / load order) - confirm before moving.
from model.run_question_asking_model import return_modules, return_modules_yn

# Model sets are loaded once at import time and only read by create_cache, so
# no `global` declarations are needed (at module scope they have no effect).
# Each loader returns four objects; the third is not used by this script.
question_model, response_model_simul, _, caption_model = return_modules()
# Counterpart models for the "yn" tasks (presumably yes/no questions).
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()
def create_cache(taskid):
    """Precompute the game state for one task id and pickle it to disk.

    The task id is decoded against the module-level index tables:
    ``yn_indices`` selects the ``*_yn`` model set and shifts the id down by
    40; ``hard_setting_indices`` then selects ``image_list_hard`` and shifts
    the id down by 80. The resulting local id picks a window of 10
    consecutive image names from the chosen list.

    Args:
        taskid: task index (int or int-like); ``__main__`` iterates 0-159.

    Raises:
        IndexError: if the decoded id does not address a full window of
            images in the chosen image list.

    Side effects:
        Opens the task's images under ``./mscoco-images/``, runs the caption
        and question models, and writes a ``Game_Cache`` pickle to
        ``./cache{taskid}.p``, named by the ORIGINAL task id.
    """
    original_taskid = taskid
    images_per_task = 10

    # "yn" tasks use the *_yn model set; all other tasks use the base set.
    if taskid in yn_indices:
        yn = True
        curr_question_model = question_model_yn
        curr_response_model_simul = response_model_simul_yn
        curr_caption_model = caption_model_yn
        taskid -= 40
    else:
        yn = False
        curr_question_model = question_model
        curr_response_model_simul = response_model_simul
        curr_caption_model = caption_model

    if taskid in hard_setting_indices:
        hard_setting = True
        image_list_curr = image_list_hard
        taskid -= 80
    else:
        hard_setting = False
        image_list_curr = image_list

    start = int(taskid) * images_per_task
    image_names = image_list_curr[start:start + images_per_task]
    # Slicing silently truncates or wraps, so check the window explicitly.
    if start < 0 or len(image_names) != images_per_task:
        raise IndexError(
            f"task {original_taskid} needs images [{start}, {start + images_per_task}) "
            f"but the image list has {len(image_list_curr)} entries"
        )

    # Paths handed to the models are relative to ./mscoco-images, e.g.
    # "/val2014/<name>" (the same strings as "./mscoco-images/val2014/<name>"[15:]).
    image_files = [f"/val2014/{name}" for name in image_names]

    images_np = []
    for path in image_files:
        # Context manager so the underlying file handle is closed promptly.
        with Image.open(f"./mscoco-images/{path}") as img:
            arr = np.asarray(img)
        # 2-D (single-channel) images are replicated to three channels.
        images_np.append(np.dstack([arr] * 3) if arr.ndim == 2 else arr)

    # Uniform prior over the candidate images.
    p_y_x = (torch.ones(images_per_task) / images_per_task).to(curr_question_model.device)
    captions = curr_caption_model.get_captions(image_files)
    questions, _ = curr_question_model.get_questions(image_files, captions, 0)
    curr_question_model.reset_question_bank()
    # NOTE(review): the returned question is unused; the call presumably
    # populates curr_question_model.question_bank, which is what gets cached.
    curr_question_model.select_best_question(
        p_y_x, questions, images_np, captions, curr_response_model_simul
    )

    game_cache = Game_Cache(curr_question_model.question_bank, image_names, yn, hard_setting)
    # Named by the original id: the decoded local id only spans 0-39, so using
    # it would make the four yn/hard variants overwrite each other's files.
    with open(f'./cache{int(original_taskid)}.p', 'wb') as fp:
        pickle.dump(game_cache, fp, protocol=pickle.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    # Build a cache for every task id; 0-159 spans all four standard/hard x
    # yn/non-yn groups of 40 defined by yn_indices and hard_setting_indices.
    for task_id in range(160):
        create_cache(task_id)