'''
Author: Chris Xiao yl.xiao@mail.utoronto.ca
Date: 2023-09-11 18:27:02
LastEditors: Chris Xiao yl.xiao@mail.utoronto.ca
LastEditTime: 2023-12-17 18:22:47
FilePath: /EndoSAM/endoSAM/train.py
Description: fine-tune training script
I Love IU
Copyright (c) 2023 by Chris Xiao yl.xiao@mail.utoronto.ca, All Rights Reserved.
'''
import argparse
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
import os
from dataset import EndoVisDataset
from utils import make_if_dont_exist, setup_logger, one_hot_embedding_3d, save_checkpoint, plot_progress
import datetime
import torch
from model import EndoSAMAdapter
import numpy as np
from segment_anything.build_sam import sam_model_registry
from loss import ce_loss, mse_loss
from tqdm import tqdm
import wget
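
# Official SAM checkpoint URLs released by Meta AI (Facebook), keyed by backbone variant.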
COMMON_MODEL_LINKS = {
    'default': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
}


def parse_command():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', default=None, type=str, help='path to config file')
    parser.add_argument('--resume', action='store_true', help='resume training from the latest checkpoint')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_command()
    cfg_path = args.cfg
    resume = args.resume
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
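
    # A config file is required; fail fast if the path is missing or unreadable.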
    if cfg_path is not None:
        if os.path.exists(cfg_path):
            cfg = OmegaConf.load(cfg_path)
        else:
            raise FileNotFoundError(f'config file {cfg_path} not found')
    else:
        raise ValueError('config file not specified')
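
    # Without a usable SAM checkpoint path in the config, download the matching
    # checkpoint into ../sam_ckpts and write its location back into the config file.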
    if ('sam_model_dir' not in OmegaConf.to_container(cfg)['model']
            or OmegaConf.is_missing(cfg.model, 'sam_model_dir')
            or not os.path.exists(cfg.model.sam_model_dir)):
        print("Didn't find SAM checkpoint. Downloading from Facebook AI...")
        parent_dir = os.path.dirname(os.getcwd())
        model_dir = os.path.join(parent_dir, 'sam_ckpts')
        make_if_dont_exist(model_dir, overwrite=True)
        checkpoint = os.path.join(model_dir, cfg.model.sam_model_type + '.pth')
        wget.download(COMMON_MODEL_LINKS[cfg.model.sam_model_type], checkpoint)
        OmegaConf.update(cfg, 'model.sam_model_dir', checkpoint)
        OmegaConf.save(cfg, cfg_path)
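
    # Unpack experiment name, dataset settings, and per-experiment output folders.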
    exp = cfg.experiment_name
    root_dir = cfg.dataset.dataset_dir
    img_format = cfg.dataset.img_format
    ann_format = cfg.dataset.ann_format
    model_path = cfg.model_folder
    log_path = cfg.log_folder
    ckpt_path = cfg.ckpt_folder
    plot_path = cfg.plot_folder
    model_exp_path = os.path.join(model_path, exp)
    log_exp_path = os.path.join(log_path, exp)
    ckpt_exp_path = os.path.join(ckpt_path, exp)
    plot_exp_path = os.path.join(plot_path, exp)
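
    # Fresh runs (re)create every output directory; resumed runs reuse the existing ones.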
    if not resume:
        make_if_dont_exist(model_path, overwrite=True)
        make_if_dont_exist(log_path, overwrite=True)
        make_if_dont_exist(ckpt_path, overwrite=True)
        make_if_dont_exist(plot_path, overwrite=True)
        make_if_dont_exist(model_exp_path, overwrite=True)
        make_if_dont_exist(log_exp_path, overwrite=True)
        make_if_dont_exist(ckpt_exp_path, overwrite=True)
        make_if_dont_exist(plot_exp_path, overwrite=True)
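
    # One timestamped log file per training run.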
    datetime_object = 'training_log_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + '.log'
    logger = setup_logger('EndoSAM', os.path.join(log_exp_path, datetime_object))
    logger.info(f"Welcome To {exp} Fine-Tuning")
    logger.info("Load Dataset-Specific Parameters")
    train_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format,
                                   mode='train', encoder_size=cfg.model.encoder_size)
    valid_dataset = EndoVisDataset(root_dir, ann_format=ann_format, img_format=img_format,
                                   mode='val', encoder_size=cfg.model.encoder_size)
    train_loader = DataLoader(train_dataset, batch_size=cfg.train_bs, shuffle=True, num_workers=cfg.num_workers)
    valid_loader = DataLoader(valid_dataset, batch_size=cfg.val_bs, shuffle=True, num_workers=cfg.num_workers)
    logger.info("Load Model-Specific Parameters")
    sam_mask_encoder, sam_prompt_encoder, sam_mask_decoder = sam_model_registry[cfg.model.sam_model_type](
        checkpoint=cfg.model.sam_model_dir, customized=cfg.model.sam_model_customized)
    model = EndoSAMAdapter(device, cfg.model.class_num, sam_mask_encoder, sam_prompt_encoder,
                           sam_mask_decoder, num_token=cfg.num_token).to(device)
    lr = cfg.opt_params.lr_default
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_losses = []
    val_losses = []
    best_val_loss = np.inf
    max_iter = cfg.max_iter
    val_iter = cfg.val_iter
    start_epoch = 0
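
    # Resuming restores weights, optimizer state, loss history, and the epoch counter
    # from the rolling checkpoint written by save_checkpoint below.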
    if resume:
        ckpt = torch.load(os.path.join(ckpt_exp_path, 'ckpt.pth'), map_location=device)
        optimizer.load_state_dict(ckpt['optimizer'])
        model.load_state_dict(ckpt['weights'])
        best_val_loss = ckpt['best_val_loss']
        train_losses = ckpt['train_losses']
        val_losses = ckpt['val_losses']
        lr = optimizer.param_groups[0]['lr']
        start_epoch = ckpt['epoch'] + 1
        logger.info("Resume Training")
    else:
        logger.info("Start Training")
    for epoch in range(start_epoch, max_iter):
        logger.info(f"Epoch {epoch + 1}/{max_iter}:")
        losses = []
        model.train()
        with tqdm(train_loader, unit='batch', desc='Training') as tdata:
            for img, ann, _, _ in tdata:
                img = img.to(device)
                ann = ann.to(device).unsqueeze(1).long()
                ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
                optimizer.zero_grad()
                pred, pred_quality = model(img)
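                # Config-weighted sum of cross-entropy and MSE losses.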
                loss = cfg.losses.ce.weight * ce_loss(ann, pred) + cfg.losses.mse.weight * mse_loss(ann, pred)
                tdata.set_postfix(loss=loss.item())
                loss.backward()
                optimizer.step()
                losses.append(loss.item())
        avg_loss = np.mean(losses)
        logger.info(f"\ttraining loss: {avg_loss}")
        train_losses.append([epoch + 1, avg_loss])
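
        # Periodic validation: every val_iter epochs, evaluate and track the best loss.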
        if epoch % val_iter == 0:
            model.eval()
            losses = []
            with torch.no_grad():
                with tqdm(valid_loader, unit='batch', desc='Validation') as tdata:
                    for img, ann, _, _ in tdata:
                        img = img.to(device)
                        ann = ann.to(device).unsqueeze(1).long()
                        ann = one_hot_embedding_3d(ann, class_num=cfg.model.class_num)
                        pred, pred_quality = model(img)
                        loss = cfg.losses.ce.weight * ce_loss(ann, pred) + cfg.losses.mse.weight * mse_loss(ann, pred)
                        tdata.set_postfix(loss=loss.item())
                        losses.append(loss.item())
            avg_loss = np.mean(losses)
            logger.info(f"\tvalidation loss: {avg_loss}")
            val_losses.append([epoch + 1, avg_loss])
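            # Snapshot the weights whenever the validation loss improves.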
            if avg_loss < best_val_loss:
                best_val_loss = avg_loss
                logger.info("\tsave best endosam model")
                torch.save({
                    'epoch': epoch,
                    'best_val_loss': best_val_loss,
                    'train_losses': train_losses,
                    'val_losses': val_losses,
                    'endosam_state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, os.path.join(model_exp_path, 'model.pth'))
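
        # Always write a rolling checkpoint and refresh the loss curves so that
        # --resume can continue from the latest epoch.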
        save_dir = os.path.join(ckpt_exp_path, 'ckpt.pth')
        save_checkpoint(model, optimizer, epoch, best_val_loss, train_losses, val_losses, save_dir)
        plot_progress(logger, plot_exp_path, train_losses, val_losses, 'loss')