# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------
from datetime import datetime
import time
import os
import sys
import importlib
import json
import random
# import wandb
import logging
import numpy as np
import copy
import contextlib
import shutil
from typing import Any, Callable, Union

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from mpi4py import MPI
from infinibatch import iterators

from .distributed_trainer import DistributedTrainer
from .utils_trainer import UtilsTrainer
from .utils.misc import *
from .utils.serialization import JSONEncoder, filter_jsonable

logger = logging.getLogger(__name__)


class DefaultTrainer(UtilsTrainer, DistributedTrainer):
    def __init__(self, opt):
        """
        Set up the task the model is being trained for.
        """
        super().__init__(opt)
        base_name = 'base_dir'
        base_path = os.path.join(self.opt['base_path'], '__init__.py')
        spec = importlib.util.spec_from_file_location(base_name, base_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[base_name] = module
        spec.loader.exec_module(module)
        logger.info(f"Imported {base_name} at base_path {self.opt['base_path']}")

        pipeline_module = importlib.import_module(f"base_dir.pipeline.{self.opt['PIPELINE']}")
        pipeline_class = getattr(pipeline_module, self.opt['PIPELINE'])
        logger.info(f"Pipeline for training: {self.opt['PIPELINE']}")
        self.pipeline = pipeline_class(self.opt)
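
        # Note (assumption, inferred from the dynamic imports above rather than documented):
        # the code expects a package rooted at opt['base_path'] containing an __init__.py
        # plus a pipeline/<PIPELINE>.py module that defines a class of the same name, e.g.
        # a hypothetical opt['PIPELINE'] = 'XDecoderPipeline' would resolve to the class
        # base_dir.pipeline.XDecoderPipeline.XDecoderPipeline.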

    def eval(self):
        logger.info('-----------------------------------------------')
        logger.info("Evaluating model ... ")
        self.mode = "eval"

        # self.model_names, self.raw_models, self.criteria = self.pipeline.set_up_model()
        self.raw_models = self.pipeline.initialize_model()
        self.model_names = self.raw_models.keys()

        # move models to the device
        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        # load model during evaluation
        if self.opt['WEIGHT'] and os.path.isfile(self.opt['RESUME_FROM']):
            model_path = self.opt['RESUME_FROM']
            self.load_model(model_path)
        else:
            raise ValueError(f"Model not found: {self.opt['RESUME_FROM']}")

        results = self._eval_on_set(self.save_folder)
        return results

    def _eval_on_set(self, save_folder):
        logger.info("Evaluation start ...")
        if self.opt['FP16']:
            from torch.cuda.amp import autocast
            with autocast():
                results = self.pipeline.evaluate_model(self, save_folder)
        else:
            results = self.pipeline.evaluate_model(self, save_folder)
        if self.opt['rank'] == 0:
            logger.info(results)
        return results

    def compute_loss(self, forward_func, batch):

        def forward(func, trainer, batch):
            if self.opt['FP16']:
                from torch.cuda.amp import autocast
                with autocast():
                    loss = func(trainer, batch)
            else:
                loss = func(trainer, batch)
            return loss

        loss = forward(forward_func, self, batch)
        return loss
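
    # Note on mixed precision: autocast() above wraps only the forward computation; the
    # FP16 backward pass is handled separately by self.grad_scaler in backward_loss() and
    # update_model() below.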

    def backward_loss(self, loss, model_names=['default']):  # noqa: E252

        def backward(loss_tensor):
            if self.opt['FP16']:
                self.grad_scaler.scale(loss_tensor).backward()
            else:
                loss_tensor.backward()

        if self.grad_acc_steps > 1:
            loss = loss / self.grad_acc_steps
        backward(loss)
        return loss
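
    # Dividing each micro-batch loss by grad_acc_steps before backward() makes the
    # accumulated gradient an average over the effective batch, i.e. approximately
    #   grad = (1 / grad_acc_steps) * sum_i dL_i/dw,
    # matching what a single large batch of the same total size would produce (assuming
    # the per-batch losses are themselves mean-reduced).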

    def update_model(self, model_name='default'):
        if self.opt['FP16']:
            self.grad_scaler.unscale_(self.optimizers[model_name])
            self.grad_scaler.step(self.optimizers[model_name])
        else:
            self.optimizers[model_name].step()

        self.optimizers[model_name].zero_grad()
        self.train_params['optim_steps'][model_name] += 1
        self.lr_schedulers[model_name].step()
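
    # FP16 flow across methods, following the standard torch.cuda.amp pattern:
    #   self.grad_scaler.scale(loss).backward()   # backward_loss(): scaled backward
    #   self.grad_scaler.unscale_(optimizer)      # update_model(): gradients back to true scale
    #   self.grad_scaler.step(optimizer)          # update_model(): skipped on inf/nan gradients
    #   self.grad_scaler.update()                 # train_step(): once per effective batch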

    def train_step(self, batch):
        self.grad_acc_batches.append(batch)  # support batch accumulation

        if self.is_gradient_accumulation_boundary():
            # set all modules and criteria into training mode
            for model_name in self.model_names:
                self.models[model_name].train()

            assert len(self.grad_acc_batches) == self.grad_acc_steps

            total_batch_sample = 0
            for batch_index, batch in enumerate(self.grad_acc_batches):
                loss_info, sample_size_info, extra_info = \
                    self.pipeline.forward_step(self,
                                               batch,
                                               self.grad_acc_batches,
                                               batch_index,
                                               is_distributed=(self.opt['world_size'] > 1))

                self.train_loss.update_iter(loss_info)
                total_batch_sample += sample_size_info['num_samples']

            if self.opt['FP16']:
                # Update GradScaler after an effective batch
                self.grad_scaler.update()

            # update losses and item counts of an effective batch to the AverageMeters
            if self.opt['world_size'] > 1:
                total_batch_sample = torch.tensor(total_batch_sample).to(self.opt['device'])
                torch.distributed.all_reduce(total_batch_sample, torch.distributed.ReduceOp.SUM)
                total_batch_sample = total_batch_sample.item()

            self.train_params['total_batch_size'] += total_batch_sample
            self.grad_acc_batches = []

        self.train_params['num_updates'] += 1
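
    # Accumulation flow: incoming batches are buffered in self.grad_acc_batches and only
    # processed once is_gradient_accumulation_boundary() fires; that helper lives in a
    # parent/utility class and is assumed here to return True exactly when grad_acc_steps
    # batches have been buffered (hence the assert above). pipeline.forward_step() is
    # presumably responsible for invoking compute_loss/backward_loss/update_model per
    # micro-batch; that contract is not visible in this file.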

    def init_train(self):
        self.mode = "train"
        logger.info('-------------------------------------------------------')
        logger.info("Training on rank: {}".format(self.opt['rank']))

        self.raw_models = self.pipeline.initialize_model()
        self.model_names = list(self.raw_models.keys())

        # move models to the device
        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        self.train_dataloaders = self.pipeline.get_dataloaders(self, 'train', is_evaluation=False)

        self.train_params = {
            "updates_per_epoch": len(self.train_dataloaders),
            "total_batch_size": 0,
            "num_updates": 0,
            "optim_steps": {module_name: 0 for module_name in self.model_names},
            "start_epoch_idx": 0,
            "start_batch_idx": 0,
            "current_epoch_idx": 0,
            "current_batch_idx": 0,
            "resume_epoch_idx": 0,
        }

        self.train_loss = LossMeter()
        self.grad_acc_batches = []

        if self.opt['CUDA']:
            torch.cuda.empty_cache()

        self.create_optimizer_and_scheduler()
        self.models = {model_name: self.raw_models[model_name] for model_name in self.model_names}
        self._initialize_ddp()

        if self.opt.get('WEIGHT', False):
            self.load_weight(self.opt['RESUME_FROM'], must_exist=True)
        if self.opt.get('RESUME', False):
            self.load_checkpoint(self.opt['RESUME_FROM'], must_exist=True)

        ######################
        # Start the main loop
        ######################
        if self.opt['rank'] == 0:
            # Train!
            logger.info("***** Running training *****")
            logger.info(f"  Num of GPUs = {self.opt['world_size']}")
            logger.info(f"  Num Epochs = {self.opt['SOLVER']['MAX_NUM_EPOCHS']}")
            logger.info(f"  Num of Mini Batches per Epoch = {self.train_params['updates_per_epoch']}")
            logger.info(f"  Total train mini-batches over all epochs = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch']}")
            logger.info(f"  Gradient Accumulation steps = {self.grad_acc_steps}")
            logger.info(f"  Total optimization steps = {self.opt['SOLVER']['MAX_NUM_EPOCHS'] * self.train_params['updates_per_epoch'] // self.grad_acc_steps}")

    def train(self):
        """
        Training
        """
        self.init_train()
        current_optim_steps = self._get_and_validate_current_optim_steps()
        num_epochs = self.opt['SOLVER']['MAX_NUM_EPOCHS']

        if self.opt.get('EVAL_AT_START', False):
            results = self._eval_on_set(self.save_folder)
            # if self.opt['rank'] == 0 and self.opt['WANDB']:
            #     wandb.log(results)

        train_prev_logged_time = datetime.now()
        for epoch in range(self.train_params['start_epoch_idx'], num_epochs):
            self.train_params['current_epoch_idx'] = epoch
            logger.info(f"Start epoch: {epoch} training.")
            epoch_start_time = datetime.now()

            for batch_idx, batch in enumerate(self.train_dataloaders):
                if self.train_params['current_epoch_idx'] == self.train_params['start_epoch_idx']:
                    if batch_idx < self.train_params['start_batch_idx']:  # skip the first few batches for resuming
                        continue

                self.train_params['current_batch_idx'] = batch_idx
                prev_optim_steps = current_optim_steps
                prev_total_batch_size = self.train_params['total_batch_size']

                # update
                self.prev_optim_steps = prev_optim_steps
                self.train_step(batch)
                current_optim_steps = self._get_and_validate_current_optim_steps()

                # logging
                if prev_optim_steps != current_optim_steps:  # an optimizer update was made
                    log_first = self.opt.get("LOG_FIRST", 10)
                    log_every = self.opt.get("LOG_EVERY", 100)
                    if (current_optim_steps % log_every == 0) or (epoch == 0 and current_optim_steps <= log_first):  # print logging
                        last_lr = {}
                        for module_name in self.model_names:
                            last_lr[module_name] = self.lr_schedulers[module_name].get_last_lr()[0]

                        train_time_delta = (datetime.now() - train_prev_logged_time).total_seconds()
                        train_prev_logged_time = datetime.now()

                        MB = 1024.0 * 1024.0
                        memory = torch.cuda.max_memory_allocated() / MB

                        if self.opt['rank'] == 0:
                            # if self.opt['WANDB']:
                            #     # log for wandb
                            #     wb_loss_info = {key: obj.val for key, obj in self.train_loss.losses.items()}
                            #     wandb.log(wb_loss_info, step=self.prev_optim_steps)

                            # log for terminal
                            logger.info(f"epochs[{epoch:6}] optim steps[{current_optim_steps:.0f}] "
                                        f"learning rate[{', '.join([f'{key}: {val:.5e}' for key, val in last_lr.items()])}] "
                                        f"train loss[{', '.join([f'{key}: {obj.val:.5f}/{obj.avg:.5f}' for key, obj in self.train_loss.losses.items()])}] "
                                        # f"total_loss[{total_loss:.5f}/{total_loss_avg:.5f} "
                                        f"items per batch[{self.train_params['total_batch_size'] - prev_total_batch_size}] "
                                        f"items per second[{(self.train_params['total_batch_size'] - prev_total_batch_size) / train_time_delta:.2f}] "
                                        f"total items[{self.train_params['total_batch_size']}] "
                                        f"mini batches[{self.train_params['num_updates']:6}] "
                                        f"memory[{memory:.0f}] "
                                        f"epoch remaining[{str((datetime.now() - epoch_start_time) / (batch_idx + 1) * (self.train_params['updates_per_epoch'] - batch_idx - 1)).split('.')[0]}]")

                # evaluate and save ckpt every epoch
                if batch_idx + 1 == self.train_params['updates_per_epoch']:
                    if self.opt.get('SAVE_CHECKPOINT', True):
                        self.save_checkpoint(self.train_params['num_updates'])
                    results = self._eval_on_set(self.save_folder)
                    # if self.opt['rank'] == 0 and self.opt['WANDB']:
                    #     wandb.log(results)
                    break

            logger.info(f"This epoch takes {datetime.now() - epoch_start_time}")
            logger.info(f"PROGRESS: {100.0 * (epoch + 1) / num_epochs:.2f}%")
            logger.info(f"Config files are at {self.opt['conf_files']}")

        # if not self.opt.get('SAVE_CHECKPOINT', True):
        #     self.save_checkpoint(self.train_params['num_updates'])
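

# Minimal usage sketch (illustrative only; the real entry point and the construction of the
# `opt` dict live elsewhere in this repo, so the helper below is a hypothetical placeholder):
#
#   opt = load_opt_from_config_files(conf_files)   # hypothetical config loader
#   trainer = DefaultTrainer(opt)
#   trainer.train()            # full training loop with per-epoch evaluation/checkpointing
#   # or, with opt['WEIGHT'] set and opt['RESUME_FROM'] pointing at a trained checkpoint:
#   results = trainer.eval()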