import os
import time
from datetime import datetime

from tqdm import tqdm
from tensorboardX import SummaryWriter
import torch
import torchinfo
import numpy as np

import options
from validate import validate, calculate_acc
from datasetss import *
from utilss.logger import create_logger
from utilss.earlystop import EarlyStopping
from networks.trainer import Trainer
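

# Training entry point: trains the FeatureTransformer on CLIP-derived features
# (the torchinfo summary below uses inputs of shape [batch_size, 16, 768]), with
# TensorBoard logging, periodic checkpointing, and accuracy-based early stopping.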
if __name__ == '__main__':
    train_opt = options.TrainOptions().parse()
    # val_opt = options.TestOptions().parse()

    # logger
    logger = create_logger(output_dir=train_opt.checkpoints_dir, name="FeatureTransformer")
    logger.info(f"working dir: {train_opt.checkpoints_dir}")

    model = Trainer(train_opt)
    # logger.info(opt.gpu_ids[0])
    logger.info(model.device)

    # extract_feature_model = model.extract_feature_model
    train_loader, val_loader = create_train_val_dataloader(
        train_opt, clip_model=None, transform=model.clip_model.preprocess, k_split=0.8)
    logger.info(f"train {len(train_loader)}")
    logger.info(f"validate {len(val_loader)}")

    train_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "train"))
    val_writer = SummaryWriter(os.path.join(train_opt.checkpoints_dir, train_opt.name, "val"))

    early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir, patience=train_opt.earlystop_epoch,
                                   delta=-0.001, verbose=True)
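    # NOTE: the delta above is negative, so (assuming the usual utilss.earlystop convention)
    # validation accuracy has to improve by more than 0.001 before the patience counter resets.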

    start_time = time.time()
    logger.info(torchinfo.summary(model.model, input_size=(train_opt.batch_size, 16, 768), col_width=20,
                                  col_names=['input_size', 'output_size', 'num_params', 'trainable'],
                                  row_settings=['var_names'], verbose=0))
    logger.info("Length of train loader: %d" % len(train_loader))
    for epoch in range(train_opt.niter):
        y_true, y_pred = [], []
        pbar = tqdm(train_loader)
        for i, data in enumerate(pbar):
            pbar.set_description(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
            model.total_steps += 1

            model.set_input(data)
            model.optimize_parameters()
            y_pred.extend(model.output.sigmoid().flatten().tolist())
            y_true.extend(data[1].flatten().tolist())

            if model.total_steps % train_opt.loss_freq == 0:
                logger.info("Train loss: {} at step: {}".format(model.loss, model.total_steps))
                train_writer.add_scalar('loss', model.loss, model.total_steps)
                logger.info("Iter time: {}".format((time.time() - start_time) / model.total_steps))

            if model.total_steps in [10, 30, 50, 100, 1000, 5000, 10000] and False:  # save models at these iters (disabled by the `and False`)
                model.save_networks('model_iters_%s.pth' % model.total_steps)

            # logger.info("trained one batch")
            pbar.set_postfix_str(f"loss: {model.loss}")

        r_acc0, f_acc0, acc0 = calculate_acc(np.array(y_true), np.array(y_pred), 0.5)
        logger.info(f"TrainSet r_acc: {r_acc0}, f_acc: {f_acc0}, acc: {acc0}")
        if epoch % train_opt.save_epoch_freq == 0:
            logger.info('saving the model at the end of epoch %d' % epoch)
            model.save_networks('model_epoch_%s.pth' % epoch)

        # Validation
        model.eval()
        ap, r_acc, f_acc, acc = validate(model, val_loader, logger=logger)
        val_writer.add_scalar('accuracy', acc, model.total_steps)
        val_writer.add_scalar('ap', ap, model.total_steps)
        logger.info("(Val @ epoch {}) acc: {}; ap: {}".format(epoch, acc, ap))

        early_stopping(acc, model.model)
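        # When patience runs out, Trainer.adjust_learning_rate (which, per the log message below,
        # divides the learning rate by 10) decides whether to continue; it presumably returns False
        # once the rate cannot be lowered further, which ends training.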
        if early_stopping.early_stop:
            cont_train = model.adjust_learning_rate()
            if cont_train:
                logger.info("Learning rate dropped by 10, continue training...")
                early_stopping = EarlyStopping(save_path=train_opt.checkpoints_dir,
                                               patience=train_opt.earlystop_epoch, delta=-0.002, verbose=True)
            else:
                logger.info("Early stopping.")
                break

        model.train()