File size: 3,744 Bytes
016285f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import numpy as np
from PIL import Image
import torch
import pickle

class Game_Cache:
    """Plain record holding the data pickled for a single task.

    Every constructor argument is stored unchanged as an attribute of the
    same name.

    Args:
        question_dict: Question bank taken from the question model.
        image_files: Image file names used by the task.
        yn: Flag marking a yes/no task.
        hard_setting: Flag marking a task drawn from the hard image list.
    """

    def __init__(self, question_dict, image_files, yn, hard_setting):
        # Assigned in the same order as the original so the instance __dict__
        # (and therefore the pickled payload) keeps the same key order.
        (self.question_dict, self.image_files,
         self.yn, self.hard_setting) = (question_dict, image_files, yn, hard_setting)

def _read_image_list(path):
    """Return the lines of the text file at *path*, each stripped of surrounding whitespace."""
    with open(path, 'r') as f:
        # Iterate the file object directly instead of materialising f.readlines().
        return [line.strip() for line in f]


# Image file names, one per line; create_cache takes 10 consecutive names per task.
image_list = _read_image_list('./mscoco/mscoco_images.txt')
image_list_hard = _read_image_list('./mscoco/mscoco_images_attribute_n=1.txt')

# Global task ids 40-79 and 120-159 are yes/no tasks.
yn_indices = list(range(40, 80)) + list(range(120, 160))
# Checked by create_cache after the yes/no offset (-40) is removed, so global
# ids 80-159 end up as hard-setting tasks.
hard_setting_indices = list(range(80, 160))


from model.run_question_asking_model import return_modules, return_modules_yn

# Both model stacks are loaded once at import time; create_cache chooses one of
# them per task. The third value returned by each loader is not used here.
# (The former module-level `global` statements were no-ops and were removed.)
question_model, response_model_simul, _, caption_model = return_modules()
question_model_yn, response_model_simul_yn, _, caption_model_yn = return_modules_yn()

def _decode_taskid(taskid):
    """Split a global task id into ``(yn, hard_setting, local_id)``.

    Ids in ``yn_indices`` are yes/no tasks and are shifted down by 40. The
    (possibly shifted) id is then checked against ``hard_setting_indices``
    and, if present, shifted down by 80. Whatever remains is the local id.
    """
    yn = taskid in yn_indices
    if yn:
        taskid -= 40
    hard_setting = taskid in hard_setting_indices
    if hard_setting:
        taskid -= 80
    return yn, hard_setting, taskid


def _load_rgb_array(path):
    """Read the image at *path* into a numpy array, stacking 2-D images into 3 channels."""
    # The `with` block closes the file handle; the array is materialised inside it.
    with Image.open(path) as img:
        arr = np.asarray(img)
    return np.dstack([arr] * 3) if arr.ndim == 2 else arr


def create_cache(taskid):
    """Precompute one task's question bank and pickle it to ``./cache{taskid}.p``.

    The global task id is decoded into a yes/no flag, a hard-setting flag and a
    local id. The local id selects 10 consecutive names from ``image_list`` or
    ``image_list_hard``. The chosen model stack captions those images and
    generates candidate questions, ``select_best_question`` is run once, and
    the question model's ``question_bank`` is saved inside a ``Game_Cache``.

    Args:
        taskid: Global task index (0-159 in the ``__main__`` loop).

    Raises:
        IndexError: If the selected image list has fewer than
            ``(local_id + 1) * 10`` entries.
    """
    original_taskid = taskid
    yn, hard_setting, local_id = _decode_taskid(taskid)

    # Yes/no tasks use the stack from return_modules_yn(); other tasks use the
    # stack from return_modules().
    # NOTE(review): the previous version had this pairing inverted (yn=True used
    # the return_modules() stack), contradicting the `yn` flag stored in the
    # cache. Confirm against the app that consumes these caches.
    if yn:
        q_model, r_model, c_model = question_model_yn, response_model_simul_yn, caption_model_yn
    else:
        q_model, r_model, c_model = question_model, response_model_simul, caption_model
    image_source = image_list_hard if hard_setting else image_list

    images_per_task = 10
    start = int(local_id) * images_per_task
    # Explicit indexing (not slicing) so that a too-short list raises IndexError.
    image_names = [image_source[start + k] for k in range(images_per_task)]
    # Paths handed to the models are relative to ./mscoco-images and keep a
    # leading slash, e.g. "/val2014/<name>".
    image_files = [f"/val2014/{name}" for name in image_names]
    images_np = [_load_rgb_array(f"./mscoco-images/{path}") for path in image_files]

    # Uniform probability vector of length 10, on the question model's device.
    p_y_x = (torch.ones(images_per_task) / images_per_task).to(q_model.device)
    captions = c_model.get_captions(image_files)
    questions, _ = q_model.get_questions(image_files, captions, 0)
    q_model.reset_question_bank()
    # Return value unused. Presumably called for its side effect on
    # q_model.question_bank (reset just above, read below). TODO confirm.
    q_model.select_best_question(p_y_x, questions, images_np, captions, r_model)

    gc = Game_Cache(q_model.question_bank, image_names, yn, hard_setting)
    # Key the file on the *global* id. Keying on the shifted local id made
    # ids 0/40/80/120 (etc.) all overwrite the same cache file.
    with open(f'./cache{int(original_taskid)}.p', 'wb') as fp:
        pickle.dump(gc, fp, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    # Build a cache file for every global task id, 0 through 159.
    all_task_ids = range(160)
    for task_id in all_task_ids:
        create_cache(task_id)