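# Loads the question-asking, response (simulated and ground-truth), and caption
# models for the image reference game in two configurations: an open-ended
# setup with GPT-3-generated questions, and a rule-based yes/no setup. The
# loaded models are exposed through return_modules() and return_modules_yn();
# the commented-out code at the bottom is the original standalone game loop.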
import argparse
import logging
import random

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from model.model.question_asking_model import get_question_model
from model.model.caption_model import get_caption_model
from model.model.response_model import get_response_model
from model.utils import logging_handler, image_saver, assert_checks
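
# Fix the random seed so any sampling below is reproducible across runs.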
random.seed(123)
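
# Command-line options for the game setup, model choices, and output naming.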
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--include_what', action='store_true')
parser.add_argument('--target_idx', type=int, default=0)
parser.add_argument('--max_num_questions', type=int, default=25)
parser.add_argument('--num_images', type=int, default=10)
parser.add_argument('--beam', type=int, default=1)
parser.add_argument('--num_samples', type=int, default=100)
parser.add_argument('--threshold', type=float, default=0.9)
parser.add_argument('--caption_strategy', type=str, default='simple', choices=['simple', 'granular', 'gtruth'])
parser.add_argument('--sample_strategy', type=str, default='random', choices=['random', 'attribute', 'clip'])
parser.add_argument('--attribute_n', type=int, default=1)  # Number of attributes to split
parser.add_argument('--response_type_simul', type=str, default='VQA1', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
parser.add_argument('--response_type_gtruth', type=str, default='VQA2', choices=['simple', 'QA', 'VQA1', 'VQA2', 'VQA3', 'VQA4'])
parser.add_argument('--question_strategy', type=str, default='gpt3', choices=['rule', 'gpt3'])
parser.add_argument('--multiplier_mode', type=str, default='soft', choices=['soft', 'hard', 'none'])
parser.add_argument('--gpt3_save_name', type=str, default='questions_gpt3')
parser.add_argument('--save_name', type=str, default=None)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
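
# Hard-coded overrides for the open-ended configuration: GPT-3-generated
# questions (including "what" questions), ground-truth captions,
# attribute-based sampling, and soft multiplier mode.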
args.question_strategy = 'gpt3'
args.include_what = True
args.response_type_simul = 'VQA1'
args.response_type_gtruth = 'VQA3'
args.multiplier_mode = 'soft'
args.sample_strategy = 'attribute'
args.attribute_n = 1
args.caption_strategy = 'gtruth'
assert_checks(args)
if args.save_name is None:
    logger = logging_handler(args.verbose, args.save_name)
args.load_response_model = True
print("1. Loading question model ...")
question_model = get_question_model(args)
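# Store the question generator on args so the model constructors below can reuse it.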
args.question_generator = question_model.question_generator
print("2. Loading response model simul ...")
response_model_simul = get_response_model(args, args.response_type_simul)
response_model_simul.to(args.device)
print("3. Loading response model gtruth ...")
response_model_gtruth = get_response_model(args, args.response_type_gtruth)
response_model_gtruth.to(args.device)
print("4. Loading caption model ...")
caption_model = get_caption_model(args, question_model)


def return_modules():
    return question_model, response_model_simul, response_model_gtruth, caption_model
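

# Second configuration: rule-based yes/no questions only (no "what" questions,
# no multiplier), exposed separately through return_modules_yn() below.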
args.question_strategy = 'rule'
args.include_what = False
args.response_type_simul = 'VQA1'
args.response_type_gtruth = 'VQA3'
args.multiplier_mode = 'none'
args.sample_strategy = 'attribute'
args.attribute_n = 1
args.caption_strategy = 'gtruth'
print("1. Loading question model ...")
question_model_yn = get_question_model(args)
args.question_generator_yn = question_model_yn.question_generator
print("2. Loading response model simul ...")
response_model_simul_yn = get_response_model(args, args.response_type_simul)
response_model_simul_yn.to(args.device)
print("3. Loading response model gtruth ...")
response_model_gtruth_yn = get_response_model(args, args.response_type_gtruth)
response_model_gtruth_yn.to(args.device)
print("4. Loading caption model ...")
caption_model_yn = get_caption_model(args, question_model_yn)


def return_modules_yn():
    return question_model_yn, response_model_simul_yn, response_model_gtruth_yn, caption_model_yn
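

# The commented-out code below is the original standalone reference-game loop.
# It assumes a ReferenceGameData dataset class from the project's data code and
# MS-COCO images under ./../../../data/ms-coco/images/.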
# args.question_strategy = 'gpt3'
# args.include_what = True
# args.response_type_simul = 'VQA1'
# args.response_type_gtruth = 'VQA3'
# args.multiplier_mode = 'none'
# args.sample_strategy = 'attribute'
# args.attribute_n = 1
# args.caption_strategy = 'gtruth'
# assert_checks(args)
# if args.save_name is None: logger = logging_handler(args.verbose, args.save_name)
# args.load_response_model = True
# print("1. Loading question model ...")
# question_model = get_question_model(args)
# args.question_generator = question_model.question_generator
# print("2. Loading response model simul ...")
# response_model_simul = get_response_model(args, args.response_type_simul)
# response_model_simul.to(args.device)
# print("3. Loading response model gtruth ...")
# response_model_gtruth = get_response_model(args, args.response_type_gtruth)
# response_model_gtruth.to(args.device)
# print("4. Loading caption model ...")
# caption_model = get_caption_model(args, question_model)
# dataloader = DataLoader(dataset=ReferenceGameData(split='test',
#                                                    num_images=args.num_images,
#                                                    num_samples=args.num_samples,
#                                                    sample_strategy=args.sample_strategy,
#                                                    attribute_n=args.attribute_n))
# def return_modules():
#     return question_model, response_model_simul, response_model_gtruth, caption_model
# game_lens, game_preds = [], []
# for t, batch in enumerate(tqdm(dataloader)):
#     image_files = [image[0] for image in batch['images'][:args.num_images]]
#     image_files = [str(i).split('/')[1] for i in image_files]
#     with open("mscoco_images_attribute_n=1.txt", 'a') as f:
#         for i in image_files:
#             f.write(str(i) + "\n")
#     images = [np.asarray(Image.open(f"./../../../data/ms-coco/images/{i}")) for i in image_files]
#     images = [np.dstack([i]*3) if len(i.shape) == 2 else i for i in images]
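#     # Start from a uniform prior over the candidate target images.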
#     p_y_x = (torch.ones(args.num_images)/args.num_images).to(question_model.device)
#     if args.save_name is not None:
#         logger = logging_handler(args.verbose, args.save_name, t)
#         image_saver(images, args.save_name, t)
#     captions = caption_model.get_captions(image_files)
#     questions, target_questions = question_model.get_questions(image_files, captions, args.target_idx)
#     question_model.reset_question_bank()
#     logger.info(questions)
#     for idx, c in enumerate(captions): logger.info(f"Image_{idx}: {c}")
#     num_questions_original = len(questions)
#     for j in range(min(args.max_num_questions, num_questions_original)):
#         # Select best question
#         question = question_model.select_best_question(p_y_x, questions, images, captions, response_model_simul)
#         logger.info(f"Question: {question}")
#         # Ask the question and get the model's response
#         response = response_model_gtruth.get_response(question, images[args.target_idx], captions[args.target_idx], target_questions, is_a=1-args.include_what)
#         logger.info(f"Response: {response}")
#         # Update the probabilities
#         p_r_qy = response_model_simul.get_p_r_qy(response, question, images, captions)
#         logger.info(f"P(r|q,y):\n{np.around(p_r_qy.cpu().detach().numpy(), 3)}")
#         p_y_xqr = p_y_x * p_r_qy
#         p_y_xqr = p_y_xqr / torch.sum(p_y_xqr) if torch.sum(p_y_xqr) != 0 else torch.zeros_like(p_y_xqr)
#         p_y_x = p_y_xqr
#         logger.info(f"Updated distribution:\n{np.around(p_y_x.cpu().detach().numpy(), 3)}\n")
#         # Don't repeat the same question again in the future
#         questions.remove(question)
#         # Terminate if probability exceeds threshold or if out of questions to ask
#         top_prob = torch.max(p_y_x).item()
#         if top_prob >= args.threshold or j == min(args.max_num_questions, num_questions_original) - 1:
#             game_preds.append(torch.argmax(p_y_x).item())
#             game_lens.append(j + 1)
#             logger.info(f"pred:{game_preds[-1]} game_len:{game_lens[-1]}")
#             break
# logger = logging_handler(args.verbose, args.save_name, "final_results")
# logger.info(f"Game lengths:\n{game_lens}")
# logger.info(f"Average game length: {sum(game_lens)/len(game_lens)}")
# logger.info(f"Predictions:\n{game_preds}")
# logger.info(f"Accuracy:\n{sum([i == args.target_idx for i in game_preds])/len(game_preds)}")