import sys
import os
import argparse
import pathlib
from tqdm import tqdm
import json
import torch
import torch.nn as nn
# Import sentencepiece before pytorch_lightning to avoid the crash described in
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece
import pytorch_lightning as pl
import clip
from transfer_experiments.train import LinearClassifier
from load_aokvqa import load_aokvqa
from evaluation.remap_predictions import map_to_choices
parser = argparse.ArgumentParser()
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--features', type=pathlib.Path, required=True)
parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
#
parser_weights = parser.add_mutually_exclusive_group(required=True)
parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path')
parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot')
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv))
#
parser.add_argument('--vocab', type=argparse.FileType('r'))
parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
parser.add_argument('--clip-model-type', type=str,
    choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
    dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv))
#
args = parser.parse_args()
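
# Example invocation (a sketch; the script and feature filenames below are
# hypothetical placeholders, and --zero-shot with --mc additionally requires
# --inputs and --clip-model-type):
#   python predict.py --split val --aokvqa-dir ./datasets/aokvqa/ \
#       --features ./features/clip-ViT-B-32_val.pt --zero-shot \
#       --inputs question image --mc --clip-model-type ViT-B/32 \
#       --out predictions_val.json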
## Load dataset

aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)

## Load models

device = "cuda" if torch.cuda.is_available() else "cpu"
if args.checkpoint_path is not None:
    # Trained linear head: restore the checkpoint and its saved hyperparameters
    classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path)
    classifier.to(device)
    hp = classifier.hparams
elif args.clip_zero_shot:
    # Zero-shot: no trained head, so an Identity stands in and an AttributeDict
    # mimics the checkpoint hyperparameters
    classifier = nn.Identity().to(device)
    hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs)
# Load input features
embeddings = torch.load(args.features)
if hp.backbone == 'clip':
    # L2-normalize CLIP features so the dot products below act as cosine similarities
    for q in embeddings.keys():
        embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True)
        embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True)
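
# Layout assumed by the indexing above (one entry per question_id):
#   {question_id: {'question': Tensor[dim], 'image': Tensor[dim]}, ...}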
# Load vocab, vocab features, clip
if (hp.objective == 'classifier') or \
   (hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False):
    vocab = args.vocab.read().splitlines()

if hp.objective in ['contrastive', 'zero-shot']:
    if args.multiple_choice is False:
        # Direct answer: precomputed embeddings for the whole vocab, unit-normalized
        vocab_features = torch.load(args.vocab_features).cpu()
        vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
        logit_scale = 1.0  # no learned temperature here; the argmax is scale-invariant
    else:
        # Multiple choice: each question's choices are encoded on the fly with CLIP
        clip_model = clip.load(hp.clip_model_type, device=device)[0]
        logit_scale = clip_model.logit_scale.exp().cpu()  # CLIP's learned temperature
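
# Scoring below follows CLIP's zero-shot formulation: with a unit-norm input x and
# candidate features V, softmax(logit_scale * x @ V.t()) gives a distribution over
# the candidates, and the argmax is taken as the prediction.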
## Prediction loop
predictions = {}
with torch.no_grad():
    for o in tqdm(aokvqa_set):
        q = o['question_id']

        # Load input embedding (from question / image)
        if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs):
            # Zero-shot fuses the modalities by summing the unit-norm embeddings
            e = embeddings[q]['question'] + embeddings[q]['image']
        elif 'question' in hp.inputs and 'image' in hp.inputs:
            # Trained heads consume the concatenated embeddings instead
            e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
        elif 'question' in hp.inputs:
            e = embeddings[q]['question']
        elif 'image' in hp.inputs:
            e = embeddings[q]['image']

        # Pass inputs through model (an Identity in the zero-shot case)
        e = e.unsqueeze(0).to(device)
        x = classifier(e)[0].cpu()

        # Predict
        if hp.objective in ['contrastive', 'zero-shot']:
            if args.multiple_choice:
                vocab = o['choices']
                # Encode this question's choices
                vocab_features = clip.tokenize(vocab).to(device)
                vocab_features = torch.stack([
                    clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features
                ], dim=1)[0]
                vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
                vocab_features = vocab_features.float().cpu()

            # Scaled cosine similarity against the candidate answers
            x = logit_scale * x @ vocab_features.t()
            x = x.softmax(dim=-1)

        predictions[q] = vocab[x.argmax().item()]
## Save and evaluate predictions
# Map prediction to nearest neighbor choice (by word embeddings)
if args.multiple_choice and hp.objective == 'classifier':
    predictions = map_to_choices(aokvqa_set, predictions)

json.dump(predictions, args.output_file)
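
# The output JSON maps question_id -> predicted answer string, e.g.
# {"<question_id>": "garage"} (illustrative values only).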