Spaces:
Build error
Build error
import os | |
import time | |
import argparse | |
import datetime | |
import json | |
import numpy as np | |
from collections import defaultdict | |
import torch | |
import torch.backends.cudnn as cudnn | |
import torch.distributed as dist | |
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy | |
from timm.utils import accuracy, AverageMeter | |
from config import get_config | |
from models import build_model | |
from data import build_loader | |
from lr_scheduler import build_scheduler | |
from optimizer import build_optimizer | |
from logger import create_logger | |
from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor,load_pretained | |
have_wandb = False | |
try: | |
import wandb | |
have_wandb = True | |
except: | |
pass | |
# TODO use torch.amp | |
try: | |
# noinspection PyUnresolvedReferences | |
from apex import amp | |
except ImportError: | |
amp = None | |
import logging | |
logging.basicConfig(level=logging.INFO) | |
def parse_option(): | |
parser = argparse.ArgumentParser('MetaFG training and evaluation script', add_help=False) | |
parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) | |
parser.add_argument( | |
"--opts", | |
help="Modify config options by adding 'KEY VALUE' pairs. ", | |
default=None, | |
nargs='+', | |
) | |
# easy config modification | |
parser.add_argument('--batch-size', type=int, help="batch size for single GPU") | |
parser.add_argument('--data-path',default='./imagenet', type=str, help='path to dataset') | |
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') | |
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], | |
help='no: no cache, ' | |
'full: cache all data, ' | |
'part: sharding the dataset into nonoverlapping pieces and only cache one piece') | |
parser.add_argument('--resume', help='resume from checkpoint') | |
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") | |
parser.add_argument('--use-checkpoint', action='store_true', | |
help="whether to use gradient checkpointing to save memory") | |
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], | |
help='mixed precision opt level, if O0, no amp is used') | |
parser.add_argument('--output', default='output', type=str, metavar='PATH', | |
help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)') | |
parser.add_argument('--tag', help='tag of experiment') | |
parser.add_argument('--eval', action='store_true', help='Perform evaluation only') | |
parser.add_argument('--throughput', action='store_true', help='Test throughput only') | |
parser.add_argument('--num-workers', type=int, | |
help="num of workers on dataloader ") | |
parser.add_argument('--lr', type=float, metavar='LR', | |
help='learning rate') | |
parser.add_argument('--weight-decay', type=float, | |
help='weight decay (default: 0.05 for adamw)') | |
parser.add_argument('--min-lr', type=float, | |
help='learning rate') | |
parser.add_argument('--warmup-lr', type=float, | |
help='warmup learning rate') | |
parser.add_argument('--epochs', type=int, | |
help="epochs") | |
parser.add_argument('--warmup-epochs', type=int, | |
help="epochs") | |
parser.add_argument('--dataset', type=str, | |
help='dataset') | |
parser.add_argument('--lr-scheduler-name', type=str, | |
help='lr scheduler name,cosin linear,step') | |
parser.add_argument('--pretrain', type=str, | |
help='pretrain') | |
parser.add_argument('--wandb_job', type=str) | |
args, unparsed = parser.parse_known_args() | |
config = get_config(args) | |
if have_wandb and int(config.LOCAL_RANK) == 0: | |
wandb.init(name = args.wandb_job, config=args) | |
return args, config | |
def main(config): | |
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) | |
logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") | |
model = build_model(config) | |
if have_wandb and int(config.LOCAL_RANK) == 0: | |
wandb.config['model_config'] = config | |
model.cuda() | |
logger.info(str(model)) | |
optimizer = build_optimizer(config, model) | |
if config.AMP_OPT_LEVEL != "O0": | |
model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) | |
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[int(config.LOCAL_RANK)], broadcast_buffers=False) | |
model_without_ddp = model.module | |
if have_wandb and int(config.LOCAL_RANK) == 0: | |
wandb.watch(model, log_freq=100) | |
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
logger.info(f"number of params: {n_parameters}") | |
if hasattr(model_without_ddp, 'flops'): | |
flops = model_without_ddp.flops() | |
logger.info(f"number of GFLOPs: {flops / 1e9}") | |
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) | |
if config.AUG.MIXUP > 0.: | |
# smoothing is handled with mixup label transform | |
criterion = SoftTargetCrossEntropy() | |
elif config.MODEL.LABEL_SMOOTHING > 0.: | |
criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) | |
else: | |
criterion = torch.nn.CrossEntropyLoss() | |
max_accuracy = 0.0 | |
if config.MODEL.PRETRAINED: | |
load_pretained(config,model_without_ddp,logger) | |
# Run initial validation | |
logger.info("Start validation (on init)") | |
acc1, acc5, loss, stats = validate(config, data_loader_val, model, limit=10) | |
with open(os.path.join(config.OUTPUT, f'val_init.json'), 'w') as fp: | |
json.dump(stats, fp, indent=1) | |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") | |
if config.TRAIN.AUTO_RESUME: | |
resume_file = auto_resume_helper(config.OUTPUT) | |
if resume_file: | |
if config.MODEL.RESUME: | |
logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") | |
config.defrost() | |
config.MODEL.RESUME = resume_file | |
config.freeze() | |
logger.info(f'auto resuming from {resume_file}') | |
else: | |
logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') | |
if config.MODEL.RESUME: | |
logger.info(f"**********normal test***********") | |
max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) | |
acc1, acc5, loss, stats = validate(config, data_loader_val, model) | |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") | |
if config.DATA.ADD_META: | |
logger.info(f"**********mask meta test***********") | |
acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True) | |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") | |
if config.EVAL_MODE: | |
return | |
if config.THROUGHPUT_MODE: | |
throughput(data_loader_val, model, logger) | |
return | |
logger.info("Start training") | |
start_time = time.time() | |
for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): | |
data_loader_train.sampler.set_epoch(epoch) | |
train_one_epoch_local_data(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) | |
if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): | |
save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) | |
logger.info(f"**********normal test***********") | |
acc1, acc5, loss, stats = validate(config, data_loader_val, model) | |
with open(os.path.join(config.OUTPUT, f'val_{epoch}.json'), 'w') as fp: | |
json.dump(stats, fp, indent=1) | |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") | |
max_accuracy = max(max_accuracy, acc1) | |
logger.info(f'Max accuracy: {max_accuracy:.2f}%') | |
if config.DATA.ADD_META: | |
logger.info(f"**********mask meta test***********") | |
acc1, acc5, loss, stats = validate(config, data_loader_val, model,mask_meta=True) | |
with open(os.path.join(config.OUTPUT, f'val_{epoch}_meta.json'), 'w') as fp: | |
json.dump(stats, fp, indent=1) | |
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") | |
# data_loader_train.terminate() | |
if have_wandb and int(config.LOCAL_RANK) == 0: | |
wandb.run.summary["acc_top_1"] = acc1 | |
wandb.run.summary["acc_top_5"] = acc5 | |
wandb.run.summary["val_loss"] = loss | |
wandb.log({'val/acc1': acc1}) | |
wandb.log({'val/acc5': acc5}) | |
wandb.log({'val/loss': acc5}) | |
total_time = time.time() - start_time | |
total_time_str = str(datetime.timedelta(seconds=int(total_time))) | |
logger.info('Training time {}'.format(total_time_str)) | |
def train_one_epoch_local_data(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler,tb_logger=None): | |
model.train() | |
if hasattr(model.module,'cur_epoch'): | |
model.module.cur_epoch = epoch | |
model.module.total_epoch = config.TRAIN.EPOCHS | |
optimizer.zero_grad() | |
num_steps = len(data_loader) | |
batch_time = AverageMeter() | |
loss_meter = AverageMeter() | |
norm_meter = AverageMeter() | |
start = time.time() | |
end = time.time() | |
for idx, data in enumerate(data_loader): | |
if config.DATA.ADD_META: | |
samples, targets,meta = data | |
meta = [m.float() for m in meta] | |
meta = torch.stack(meta,dim=0) | |
meta = meta.cuda(non_blocking=True) | |
else: | |
samples, targets= data | |
meta = None | |
samples = samples.cuda(non_blocking=True) | |
targets = targets.cuda(non_blocking=True) | |
if mixup_fn is not None: | |
samples, targets = mixup_fn(samples, targets) | |
if config.DATA.ADD_META: | |
outputs = model(samples,meta) | |
else: | |
outputs = model(samples) | |
if config.TRAIN.ACCUMULATION_STEPS > 1: | |
loss = criterion(outputs, targets) | |
loss = loss / config.TRAIN.ACCUMULATION_STEPS | |
if config.AMP_OPT_LEVEL != "O0": | |
with amp.scale_loss(loss, optimizer) as scaled_loss: | |
scaled_loss.backward() | |
if config.TRAIN.CLIP_GRAD: | |
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) | |
else: | |
grad_norm = get_grad_norm(amp.master_params(optimizer)) | |
else: | |
loss.backward() | |
if config.TRAIN.CLIP_GRAD: | |
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) | |
else: | |
grad_norm = get_grad_norm(model.parameters()) | |
if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: | |
optimizer.step() | |
optimizer.zero_grad() | |
lr_scheduler.step_update(epoch * num_steps + idx) | |
else: | |
loss = criterion(outputs, targets) | |
optimizer.zero_grad() | |
if config.AMP_OPT_LEVEL != "O0": | |
with amp.scale_loss(loss, optimizer) as scaled_loss: | |
scaled_loss.backward() | |
if config.TRAIN.CLIP_GRAD: | |
grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) | |
else: | |
grad_norm = get_grad_norm(amp.master_params(optimizer)) | |
else: | |
loss.backward() | |
if config.TRAIN.CLIP_GRAD: | |
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) | |
else: | |
grad_norm = get_grad_norm(model.parameters()) | |
optimizer.step() | |
lr_scheduler.step_update(epoch * num_steps + idx) | |
torch.cuda.synchronize() | |
loss_meter.update(loss.item(), targets.size(0)) | |
norm_meter.update(grad_norm) | |
batch_time.update(time.time() - end) | |
end = time.time() | |
if idx % config.PRINT_FREQ == 0: | |
lr = optimizer.param_groups[0]['lr'] | |
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) | |
etas = batch_time.avg * (num_steps - idx) | |
if have_wandb and int(config.LOCAL_RANK) == 0 and idx % 100 == 0: | |
wandb.log({"train/loss": loss_meter.val}) | |
logger.info( | |
f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' | |
f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' | |
f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' | |
f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' | |
f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' | |
f'mem {memory_used:.0f}MB') | |
epoch_time = time.time() - start | |
logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") | |
def validate(config, data_loader, model, mask_meta=False, limit=None): | |
criterion = torch.nn.CrossEntropyLoss() | |
model.eval() | |
batch_time = AverageMeter() | |
loss_meter = AverageMeter() | |
acc1_meter = AverageMeter() | |
acc5_meter = AverageMeter() | |
stats = defaultdict(list) | |
end = time.time() | |
for idx, data in enumerate(data_loader): | |
if limit: | |
if idx > limit: | |
break | |
if config.DATA.ADD_META: | |
images,target,meta = data | |
meta = [m.float() for m in meta] | |
meta = torch.stack(meta,dim=0) | |
if mask_meta: | |
meta = torch.zeros_like(meta) | |
meta = meta.cuda(non_blocking=True) | |
else: | |
images, target = data | |
meta = None | |
images = images.cuda(non_blocking=True) | |
target = target.cuda(non_blocking=True) | |
# compute output | |
if config.DATA.ADD_META: | |
output = model(images,meta) | |
else: | |
output = model(images) | |
# measure accuracy and record loss | |
loss = criterion(output, target) | |
acc1, acc5 = accuracy(output, target, topk=(1, 5)) | |
acc1 = reduce_tensor(acc1) | |
acc5 = reduce_tensor(acc5) | |
loss = reduce_tensor(loss) | |
loss_meter.update(loss.item(), target.size(0)) | |
acc1_meter.update(acc1.item(), target.size(0)) | |
acc5_meter.update(acc5.item(), target.size(0)) | |
for t in target: | |
stats[int(t.item())].append((acc1.item(), acc5.item(), loss.item())) | |
# measure elapsed time | |
batch_time.update(time.time() - end) | |
end = time.time() | |
if idx % config.PRINT_FREQ == 0: | |
memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) | |
logger.info( | |
f'Test: [{idx}/{len(data_loader)}]\t' | |
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' | |
f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' | |
f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' | |
f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' | |
f'Mem {memory_used:.0f}MB') | |
logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') | |
return acc1_meter.avg, acc5_meter.avg, loss_meter.avg, stats | |
def throughput(data_loader, model, logger): | |
model.eval() | |
for idx, (images, _) in enumerate(data_loader): | |
images = images.cuda(non_blocking=True) | |
batch_size = images.shape[0] | |
for i in range(50): | |
model(images) | |
torch.cuda.synchronize() | |
logger.info(f"throughput averaged with 30 times") | |
tic1 = time.time() | |
for i in range(30): | |
model(images) | |
torch.cuda.synchronize() | |
tic2 = time.time() | |
logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") | |
return | |
if __name__ == '__main__': | |
_, config = parse_option() | |
if config.AMP_OPT_LEVEL != "O0": | |
assert amp is not None, "amp not installed!" | |
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: | |
rank = int(os.environ["RANK"]) | |
world_size = int(os.environ['WORLD_SIZE']) | |
print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") | |
else: | |
rank = -1 | |
world_size = -1 | |
torch.cuda.set_device(f'cuda:{config.LOCAL_RANK}') | |
torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) | |
torch.distributed.barrier() | |
seed = config.SEED + dist.get_rank() | |
torch.manual_seed(seed) | |
np.random.seed(seed) | |
cudnn.benchmark = True | |
# linear scale the learning rate according to total batch size, may not be optimal | |
linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 | |
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 | |
linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 | |
# gradient accumulation also need to scale the learning rate | |
if config.TRAIN.ACCUMULATION_STEPS > 1: | |
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS | |
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS | |
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS | |
config.defrost() | |
config.TRAIN.BASE_LR = linear_scaled_lr | |
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr | |
config.TRAIN.MIN_LR = linear_scaled_min_lr | |
config.freeze() | |
os.makedirs(config.OUTPUT, exist_ok=True) | |
logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}",local_rank=config.LOCAL_RANK) | |
if dist.get_rank() == 0: | |
path = os.path.join(config.OUTPUT, "config.json") | |
with open(path, "w") as f: | |
f.write(config.dump()) | |
logger.info(f"Full config saved to {path}") | |
# print config | |
logger.info(config.dump()) | |
main(config) | |