import os
import time
from datetime import datetime

import numpy as np
import torch
import torchinfo
from tensorboardX import SummaryWriter
from tqdm import tqdm

import options
from datasets import create_train_val_dataloader
from networks.trainer import Trainer
from utils.earlystop import EarlyStopping
from utils.logger import create_logger
from validate import calculate_acc, validate

if __name__ == '__main__':
    train_opt = options.TrainOptions().parse()
    # val_opt = options.TestOptions().parse()

    # Logger writes to the checkpoints directory.
    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

    model = Trainer(train_opt)
    logger.info(model.device)

    # Reuse the CLIP preprocessing pipeline as the dataset transform;
    # the loaders share an 80/20 train/val split.
    train_loader, val_loader = create_train_val_dataloader(
        train_opt, clip_model=None, transform=model.clip_model.preprocess, k_split=0.8)
    logger.info(f"train {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")

    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))

    # Early stopping watches validation accuracy; the negative delta means any
    # improvement smaller than |delta| still resets the patience counter.
    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir,
                                   patience=train_opt.earlystop_epoch,
                                   delta=-0.001, verbose=True)

    start_time = time.time()
    logger.info(torchinfo.summary(
        model.model,
        input_size=(train_opt.batch_size, 16, 768),
        col_width=20,
        col_names=['input_size', 'output_size', 'num_params', 'trainable'],
        row_settings=['var_names'],
        verbose=0))
    logger.info("Length of train loader: %d" % (len(train_loader)))

    for epoch in range(train_opt.niter):
        y_true, y_pred = [], []
        pbar = tqdm(train_loader)
        for i, data in enumerate(pbar):
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.total_steps += 1

            model.set_input(data)
            model.optimize_parameters()

            # Collect sigmoid scores and labels to report training accuracy per epoch.
            y_pred.extend(model.output.sigmoid().flatten().tolist())
            y_true.extend(data[1].flatten().tolist())

            if model.total_steps % train_opt.loss_freq == 0:
                logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                train_writer.add_scalar('loss', model.loss, model.total_steps)
                logger.info("Iter time: {}".format((time.time() - start_time) / model.total_steps))

            # Disabled: checkpointing at fixed iteration counts (note the `and False`).
            if model.total_steps in [10, 30, 50, 100, 1000, 5000, 10000] and False:
                model.save_networks('model_iters_%s.pth' % model.total_steps)

            pbar.set_postfix_str(f"loss: {model.loss}")

        # Training accuracy at a 0.5 threshold: real, fake, and overall.
        r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
        logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")

        if epoch % train_opt.save_epoch_freq == 0:
            logger.info('saving the model at the end of epoch %d' % epoch)
            model.save_networks('model_epoch_%s.pth' % epoch)

        # Validation
        model.eval()
        ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
        val_writer.add_scalar('accuracy', acc, model.total_steps)
        val_writer.add_scalar('ap', ap, model.total_steps)
        logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

        early_stopping(acc, model.model)
        if early_stopping.early_stop:
            # When patience runs out, try lowering the learning rate once more
            # before giving up; a fresh EarlyStopping tracker restarts the count.
            cont_train = model.adjust_learning_rate()
            if cont_train:
                logger.info("Learning rate dropped by 10, continue training...")
                early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir,
                                               patience=train_opt.earlystop_epoch,
                                               delta=-0.002, verbose=True)
            else:
                logger.info("Early stopping.")
                break
        model.train()
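
# --- Example invocation (illustrative only) ---
# The actual flag names are defined by options.TrainOptions, which is not
# shown here; these are assumptions inferred from the attributes the script
# accesses (checkpoints_dir, name, batch_size, niter, loss_freq,
# save_epoch_freq, earlystop_epoch):
#
#   python train.py --name feature_transformer_run \
#       --checkpoints_dir ./checkpoints --batch_size 32 --niter 100 \
#       --loss_freq 100 --save_epoch_freq 1 --earlystop_epoch 5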