# Trains a linear probe on precomputed A-OKVQA features (CLIP / ResNet / BERT
# question and/or image embeddings), using either a multi-label classifier head
# or a contrastive head over the answer vocabulary.

import os
import sys
import json
import argparse
import pathlib
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
import sentencepiece; import pytorch_lightning as pl
import torchmetrics.functional as MF

from load_aokvqa import load_aokvqa


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
    parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
    #
    parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
    parser.add_argument('--clip-model-type', type=str,
                        choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                        dest='clip_model_type', required=('clip' in sys.argv))
    parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
    parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
    parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')
    #
    parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
    parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
    # Defaults
    parser.add_argument('--bs', type=int, default=128, dest='batch_size')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--gpus', type=int, default=1)
    args = parser.parse_args()

    pl.seed_everything(1)
    vocab = args.vocab.read().splitlines()

    ## Data loading

    dm = AokvqaEmbeddingsDataModule(
        args.aokvqa_dir,
        args.train_features,
        args.val_features,
        args.objective,
        args.backbone,
        args.inputs,
        vocab,
        args.vocab_features,
        batch_size=args.batch_size,
        num_workers=16
    )

    ## Model definition

    model = LinearClassifier(
        args.objective,
        args.backbone,
        args.clip_model_type,
        args.inputs,
        len(vocab),
        args.lr
    )

    ## Training and testing loops

    logger = pl.loggers.TensorBoardLogger(
        args.log_dir,
        name=f'{args.backbone}-{args.objective}',
        version=f"inputs:{'+'.join(args.inputs)}"
    )

    trainer = pl.Trainer(
        logger=logger,
        gpus=args.gpus,
        max_epochs=args.epochs,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor="val_acc",
                filename="{epoch:02d}-{val_acc:.2f}",
                mode="max"
            )
        ],
    )

    trainer.fit(model, dm)
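
# Example invocation (a sketch only; the script name and the vocab/feature paths
# below are illustrative and depend on how the precomputed embeddings were saved):
#
#   python train.py \
#       --aokvqa-dir ./datasets/aokvqa/ \
#       --vocab ./features/vocab.txt \
#       --log-dir ./logs/ \
#       --backbone clip --clip-model-type ViT-B/32 \
#       --objective classifier \
#       --inputs question image \
#       --train-features ./features/clip-ViT-B-32_train.pt \
#       --val-features ./features/clip-ViT-B-32_val.pt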

class AokvqaEmbeddingsDataset(Dataset):

    def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):

        aokvqa_set = load_aokvqa(aokvqa_dir, split)

        # resnet is image-only and bert is question-only, and both support only the
        # classifier objective; CLIP features work with any input/objective combination.
        assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
            or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
            or ( backbone == 'clip' )

        embeddings = torch.load(input_features)
        if backbone == 'clip':
            # L2-normalize CLIP features so dot products act as cosine similarities
            for q in embeddings.keys():
                embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
                embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
        if objective == 'contrastive':
            vocab_embeddings = torch.load(vocab_features)
            vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)

        self.objective = objective
        self.vocab_len = len(vocab)

        self.embeddings = []
        self.answers = []

        for o in aokvqa_set:
            # Keep the answers (correct choice + direct answers) that appear in the vocab:
            # as vocab indices for the classifier, as CLIP embeddings for contrastive.
            correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
            correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
            if self.objective == 'contrastive':
                correct_answers = [vocab_embeddings[a] for a in correct_answers]
            if len(correct_answers) == 0: continue
            self.answers.append(correct_answers)

            q = o['question_id']
            if 'question' in inputs and 'image' in inputs:
                e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
            elif 'question' in inputs and 'image' not in inputs:
                e = embeddings[q]['question']
            elif 'question' not in inputs and 'image' in inputs:
                e = embeddings[q]['image']
            self.embeddings.append(e)

    def __getitem__(self, index):
        e = self.embeddings[index]
        a = self.answers[index]
        if self.objective == 'classifier':
            # Multi-hot target over the answer vocabulary
            a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
        elif self.objective == 'contrastive':
            # One randomly sampled correct-answer embedding
            a = random.sample(a, 1)[0]
        return e, a

    def __len__(self):
        return len(self.embeddings)


class AokvqaEmbeddingsDataModule(pl.LightningDataModule):

    def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
        super().__init__()
        self.aokvqa_dir = aokvqa_dir
        self.train_features = train_features
        self.val_features = val_features
        self.objective = objective
        self.backbone = backbone
        self.inputs = inputs
        self.vocab = vocab
        self.vocab_features = vocab_features
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = AokvqaEmbeddingsDataset(
            self.aokvqa_dir, 'train', self.train_features, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )
        self.val_dataset = AokvqaEmbeddingsDataset(
            self.aokvqa_dir, 'val', self.val_features, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=int(0.8 * self.num_workers)
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=int(0.2 * self.num_workers)
        )


class LinearClassifier(pl.LightningModule):

    def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
        super().__init__()
        self.save_hyperparameters(ignore=['lr'])
        self.lr = lr

        if self.hparams.backbone == 'clip':
            clip_dim = {
                'RN50' : 1024,
                'RN50x4' : 640,
                'RN50x16' : 768,
                'RN50x64' : 1024,
                'RN101' : 512,
                'ViT-B/32' : 512,
                'ViT-B/16' : 512,
                'ViT-L/14' : 768,
                'ViT-L/14@336px' : 768,
            }[clip_model_type]
            emb_dim = clip_dim * len(inputs)
        elif self.hparams.backbone == 'resnet':
            emb_dim = 2048
        elif self.hparams.backbone == 'bert':
            emb_dim = 768

        if self.hparams.objective == 'classifier':
            out_dim = vocab_len
        elif self.hparams.objective == 'contrastive':
            out_dim = clip_dim

        self.linear = nn.Linear(emb_dim, out_dim)

    def forward(self, x):
        x = self.linear(x)
        if self.hparams.objective == 'classifier':
            x = torch.sigmoid(x)
        return x

    def compute_loss(self, batch):
        x, y = batch
        y_pred = self.forward(x)

        if self.hparams.objective == 'classifier':
            loss = F.binary_cross_entropy(y_pred, y.float())
        elif self.hparams.objective == 'contrastive':
            indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
            sim = (y_pred @ y.T).softmax(dim=-1)
            loss = F.cross_entropy(sim, indices)

        if self.hparams.objective == 'classifier':
            acc = MF.f1_score(y_pred, y)
        elif self.hparams.objective == 'contrastive':
            acc = torch.mean(sim[indices, indices])

        return loss, acc
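
    # Shape sketch for compute_loss (descriptive note, batch size B):
    #   classifier:  y_pred is (B, vocab_len) sigmoid scores and y is a multi-hot
    #                target over the vocab -> binary cross-entropy, F1 as "acc".
    #   contrastive: y_pred is (B, clip_dim) and y is (B, clip_dim) normalized CLIP
    #                answer embeddings; sim = softmax(y_pred @ y.T) is (B, B), and
    #                cross-entropy against arange(B) rewards the diagonal, i.e.
    #                matching each example with one of its own correct answers;
    #                "acc" is the mean diagonal of sim.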

    def training_step(self, batch, batch_idx):
        loss, acc = self.compute_loss(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.compute_loss(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


if __name__ == '__main__':
    main()
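
# Minimal inference sketch (not part of this script; the checkpoint path and the
# `embedding` / `vocab` variables below are placeholders). save_hyperparameters()
# lets Lightning rebuild the module configuration from the checkpoint alone:
#
#   model = LinearClassifier.load_from_checkpoint('path/to/checkpoint.ckpt')
#   model.eval()
#   with torch.no_grad():
#       scores = model(embedding)   # classifier objective: sigmoid scores over the vocab
#   answer = vocab[scores.argmax().item()]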