akhaliq's picture
akhaliq HF staff
add files
c80917c
raw
history blame
6.58 kB
import argparse
import random
import numpy as np
import torch
import pprint
import yaml
def str2bool(v):
    """Convert a command-line token to a bool (usable as an argparse ``type=``).

    Accepted spellings, case-insensitive and ignoring surrounding whitespace:
    'yes', 'true', 't', 'y', '1' -> True; 'no', 'false', 'f', 'n', '0' -> False.
    A value that is already a ``bool`` is returned unchanged, so the helper is
    also safe to call on non-string defaults.

    Args:
        v: The value to interpret (normally a string from the command line).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not one of the accepted
            spellings.
    """
    if isinstance(v, bool):
        return v
    # Normalize once instead of calling .lower() per branch; str() lets
    # non-string inputs (e.g. the int 1) go through the same lookup.
    token = str(v).strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def is_interactive():
    """Report whether the program is running without a script file.

    Returns True when the ``__main__`` module has no ``__file__`` attribute
    (the case for an interactive interpreter session), False otherwise.
    """
    import __main__
    launched_from_file = hasattr(__main__, '__file__')
    return not launched_from_file
def get_optimizer(optim, verbose=False):
    """Map an optimizer name to the optimizer the trainer should use.

    Args:
        optim: One of 'rms', 'adam', 'adamw', 'adamax', 'sgd'.
        verbose: If True, print which optimizer was selected.

    Returns:
        The matching ``torch.optim`` class, except for 'adamw'. For 'adamw'
        the plain string 'adamw' is returned, because the
        ``torch.optim.AdamW`` binding was deliberately disabled in the
        original. NOTE(review): the consumer presumably constructs AdamW
        itself; confirm against the trainer.

    Raises:
        AssertionError: If ``optim`` is not a supported name.
    """
    # name -> (message printed when verbose, value returned)
    choices = {
        'rms': ("Optimizer: Using RMSProp", torch.optim.RMSprop),
        'adam': ("Optimizer: Using Adam", torch.optim.Adam),
        'adamw': ("Optimizer: Using AdamW", 'adamw'),
        'adamax': ("Optimizer: Using Adamax", torch.optim.Adamax),
        'sgd': ("Optimizer: SGD", torch.optim.SGD),
    }
    if optim not in choices:
        # An explicit raise (rather than `assert False`) still fires under
        # `python -O`, where asserts are stripped and the old code crashed
        # with UnboundLocalError. AssertionError is kept so callers see the
        # same exception type as before.
        raise AssertionError("Please add your optimizer %s in the list." % optim)
    message, optimizer = choices[optim]
    if verbose:
        print(message)
    return optimizer
def _build_parser():
    """Create the ArgumentParser holding every supported command-line option."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=9595, help='random seed')

    # Data splits
    parser.add_argument("--train", default='karpathy_train')
    parser.add_argument("--valid", default='karpathy_val')
    parser.add_argument("--test", default='karpathy_test')

    # Quick experiments
    parser.add_argument('--train_topk', type=int, default=-1)
    parser.add_argument('--valid_topk', type=int, default=-1)

    # Checkpoint
    parser.add_argument('--output', type=str, default='snap/test')
    parser.add_argument('--load', type=str, default=None, help='Load the model (usually the fine-tuned model).')
    parser.add_argument('--from_scratch', action='store_true')

    # CPU/GPU
    parser.add_argument("--multiGPU", action='store_const', default=False, const=True)
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument("--distributed", action='store_true')
    parser.add_argument("--num_workers", default=0, type=int)
    parser.add_argument('--local_rank', type=int, default=-1)

    # Model config
    parser.add_argument('--tokenizer', type=str, default='openai/clip-vit-base-patch32')
    parser.add_argument('--max_text_length', type=int, default=40)

    # Training
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--valid_batch_size', type=int, default=None)
    parser.add_argument('--optim', default='adamw')
    parser.add_argument('--warmup_ratio', type=float, default=0.05)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--clip_grad_norm', type=float, default=-1.0)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--adam_eps', type=float, default=1e-6)
    parser.add_argument('--adam_beta1', type=float, default=0.9)
    parser.add_argument('--adam_beta2', type=float, default=0.999)
    parser.add_argument('--epochs', type=int, default=20)

    # Inference
    parser.add_argument('--start_from', type=str, default=None)

    # COCO caption / config selection
    parser.add_argument('--no_cls', action='store_true')
    parser.add_argument('--cfg', type=str, default=None)
    parser.add_argument('--id', type=str, default=None)

    # Etc.
    parser.add_argument('--comment', type=str, default='')
    parser.add_argument("--dry", action='store_true')
    return parser


def _load_yaml_config(cfg_name):
    """Read ``configs/<cfg_name>.yaml`` (relative to the CWD) as a dict.

    An empty file yields ``{}``. Raises ValueError if the document's top
    level is not a mapping.
    """
    cfg_path = f'configs/{cfg_name}.yaml'
    with open(cfg_path, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)
    if loaded is None:
        # safe_load returns None for an empty document; treat it as "no
        # overrides" instead of crashing later in dict.update(None).
        return {}
    if not isinstance(loaded, dict):
        raise ValueError(
            f'{cfg_path} must contain a YAML mapping at the top level, '
            f'got {type(loaded).__name__}')
    return loaded


def parse_args(parse=True, **optional_kwargs):
    """Build the run configuration and seed the random number generators.

    Values are merged with this precedence (later wins): command-line
    values and parser defaults, then ``optional_kwargs``, then the YAML file
    named by ``--cfg`` (``configs/<cfg>.yaml``, relative to the CWD).

    Args:
        parse: If True, parse ``sys.argv`` strictly. If False, use
            ``parse_known_args`` and ignore unrecognized arguments, which is
            useful for interactive environments such as Jupyter.
        **optional_kwargs: Extra settings merged over the parsed arguments.

    Returns:
        A ``Config`` whose attributes are the merged settings, plus
        ``optimizer`` as returned by ``get_optimizer(args.optim)``.

    Side effects:
        Seeds ``torch``, ``random`` and ``numpy`` with ``args.seed``.
    """
    parser = _build_parser()
    if parse:
        args = parser.parse_args()
    else:
        # For interactive environments (e.g. Jupyter): sys.argv may carry
        # arguments this parser does not define, so keep only known ones.
        args = parser.parse_known_args()[0]

    loaded_kwargs = {}
    if args.cfg is not None:
        loaded_kwargs = _load_yaml_config(args.cfg)

    # Namespace => dictionary, then apply overrides in precedence order.
    kwargs = {**vars(args), **optional_kwargs, **loaded_kwargs}
    args = Config(**kwargs)

    # Bind optimizer class.
    args.optimizer = get_optimizer(args.optim, verbose=False)

    # Set seeds
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    return args
class Config:
    """Plain attribute container: every constructor keyword becomes an attribute."""

    def __init__(self, **kwargs):
        """Attach each keyword argument to the instance using ``setattr``."""
        for name in kwargs:
            setattr(self, name, kwargs[name])

    @property
    def config_str(self):
        """The instance attributes rendered with ``pprint.pformat``."""
        return pprint.pformat(self.__dict__)

    def __repr__(self):
        """Return a 'Configurations' header followed by the attributes.

        ``pprint.pformat`` sorts dictionary keys, so attributes are listed
        alphabetically.
        """
        header = 'Configurations\n'
        return header + self.config_str
if __name__ == '__main__':
    # Running the module directly only parses sys.argv (plus any --cfg YAML
    # file) and seeds the RNGs; the resulting config is not used further.
    args = parse_args(parse=True)