|
import haienv |
|
haienv.set_env('lavt2') |
|
import torch.multiprocessing as mp |
|
import torch.distributed as dist |
|
|
|
import datetime |
|
import os |
|
import time |
|
|
|
import torch |
|
import torch.utils.data |
|
from torch import nn |
|
|
|
from functools import reduce |
|
import operator |
|
from bert.modeling_bert import BertModel |
|
|
|
import torchvision |
|
from lib import segmentation |
|
|
|
import transforms as T |
|
import utils |
|
import numpy as np |
|
|
|
import torch.nn.functional as F |
|
|
|
import gc |
|
from collections import OrderedDict |
|
|
|
import torch.backends.cudnn as cudnn |
|
|
|
from ffrecord.torch import DataLoader,Dataset |
|
def get_dataset(image_set, transform, args):
    """Build the ``ReferDataset`` for one split.

    Args:
        image_set: Split name, forwarded to ``ReferDataset`` as ``split``.
        transform: Image transform, forwarded as ``image_transforms``.
        args: Parsed options, forwarded as the first positional argument.

    Returns:
        ``(dataset, num_classes)``. ``num_classes`` is always 2
        (presumably background / referred object -- confirm against the model).
    """
    # Kept as a function-local import, as in the original, so the dataset
    # module is only loaded when a dataset is actually requested.
    from data.dataset_refer_bert import ReferDataset

    dataset_kwargs = {
        'split': image_set,
        'image_transforms': transform,
        'target_transforms': None,  # no target-side transform is applied
    }
    return ReferDataset(args, **dataset_kwargs), 2
|
|
|
|
|
|
|
def IoU(pred, gt):
    """Compute intersection-over-union between predicted labels and ``gt``.

    Args:
        pred: Score tensor; the predicted label is the argmax over dim 1.
        gt: Tensor combined element-wise with the predicted labels
            (presumably a binary {0, 1} mask -- confirm with the dataset).

    Returns:
        ``(iou, intersection, union)``. ``iou`` is the int ``0`` when either
        the intersection or the union is zero, otherwise a Python float.
        ``intersection`` and ``union`` are returned as the summed tensors.
    """
    labels = pred.argmax(1)

    overlap = (labels * gt).sum()
    # sum(labels + gt) counts overlapping pixels twice, so subtract one copy.
    combined = (labels + gt).sum() - overlap

    if overlap == 0 or combined == 0:
        return 0, overlap, combined
    return float(overlap) / float(combined), overlap, combined
|
|
|
|
|
def get_transform(args):
    """Return the image preprocessing pipeline used for both splits.

    Args:
        args: Options object; only ``args.img_size`` is read and it is used
            for both arguments of ``T.Resize``.

    Returns:
        A ``T.Compose`` of resize, tensor conversion and normalisation.
    """
    side = args.img_size
    pipeline = T.Compose([
        T.Resize(side, side),
        T.ToTensor(),
        # These values match the widely used ImageNet channel mean / std.
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return pipeline
|
|
|
|
|
def criterion(input, target):
    """Class-weighted cross-entropy over two classes.

    Class 0 is weighted 0.9 and class 1 is weighted 1.1.

    The weight tensor is created with the same dtype and device as ``input``
    instead of being forced onto the current CUDA device as float32, so the
    loss also works on CPU and with non-float32 logits. Behaviour for
    float32 CUDA logits is unchanged.

    Args:
        input: Logits with the class dimension at dim 1 (two classes).
        target: Integer class indices accepted by ``cross_entropy``.

    Returns:
        The scalar weighted-mean cross-entropy loss.
    """
    weight = input.new_tensor([0.9, 1.1])
    return nn.functional.cross_entropy(input, target, weight=weight)
|
|
|
|
|
def evaluate(model, data_loader, bert_model):
    """Run ``model`` over ``data_loader`` without gradients and report IoU.

    Each item yielded by the loader is unpacked as
    ``(image, target, sentences, attentions)`` and moved to the current CUDA
    device. One IoU value is computed per loader item, so the precision@X
    numbers are per-sample only if the loader uses batch size 1
    (NOTE(review): confirm at the call site).

    Args:
        model: Segmentation model; it is switched to eval mode here.
        data_loader: Iterable of 4-tuples as described above.
        bert_model: Optional text encoder. When given, its first output is
            permuted with ``(0, 2, 1)`` and passed to ``model`` as the
            language input; when ``None`` the raw ``sentences`` are passed.

    Returns:
        ``(100 * mean per-item IoU, 100 * cumulative_I / cumulative_U)``.
    """
    model.eval()
    logger = utils.MetricLogger(delimiter=" ")

    thresholds = [.5, .6, .7, .8, .9]
    hits = np.zeros(len(thresholds), dtype=np.int32)
    per_item_iou = []
    iou_sum = 0
    inter_sum, union_sum = 0, 0
    n_items = 0
    n_scored = 0

    with torch.no_grad():
        for batch in logger.log_every(data_loader, 100, 'Test:'):
            n_items += 1
            image, target, sentences, attentions = (
                t.cuda(non_blocking=True) for t in batch
            )

            sentences = sentences.squeeze(1)
            attentions = attentions.squeeze(1)

            if bert_model is None:
                output = model(image, sentences, l_mask=attentions)
            else:
                hidden = bert_model(sentences, attention_mask=attentions)[0]
                lang_feats = hidden.permute(0, 2, 1)
                lang_mask = attentions.unsqueeze(dim=-1)
                output = model(image, lang_feats, l_mask=lang_mask)

            item_iou, inter, union = IoU(output, target)
            iou_sum += item_iou
            per_item_iou.append(item_iou)
            inter_sum += inter
            union_sum += union
            for idx, threshold in enumerate(thresholds):
                hits[idx] += (item_iou >= threshold)
            n_scored += 1
        avg_iou = iou_sum / n_items

    mIoU = np.mean(np.array(per_item_iou))
    print('Final results:')
    print('Mean IoU is %.2f\n' % (mIoU * 100.))

    report = [
        ' precision@%s = %.2f\n' % (str(threshold), hits[idx] * 100. / n_scored)
        for idx, threshold in enumerate(thresholds)
    ]
    report.append(' overall IoU = %.2f\n' % (inter_sum * 100. / union_sum))
    print(''.join(report))

    return 100 * avg_iou, 100 * inter_sum / union_sum
|
|
|
|
|
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, print_freq,
                    iterations, bert_model):
    """Train ``model`` for one pass over ``data_loader``.

    Every loader item is unpacked as ``(image, target, sentences, attentions)``
    and moved to the current CUDA device. The optimizer and the LR scheduler
    are both stepped once per item.

    Args:
        model: Segmentation model; switched to train mode here.
        criterion: Callable ``(output, target) -> loss``.
        optimizer: Optimizer to step.
        data_loader: Iterable of 4-tuples as described above.
        lr_scheduler: Scheduler stepped after every optimizer step.
        epoch: Epoch index, used only in the log header.
        print_freq: Logging interval passed to ``MetricLogger.log_every``.
        iterations: Starting iteration count. It is incremented locally only;
            the caller's value is not updated and nothing is returned.
        bert_model: Optional text encoder, used as in ``evaluate``.
    """
    model.train()
    logger = utils.MetricLogger(delimiter=" ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
    banner = 'Epoch: [{}]'.format(epoch)
    running_loss = 0
    steps = 0
    use_bert = bert_model is not None

    for data in logger.log_every(data_loader, print_freq, banner):
        steps += 1
        image, target, sentences, attentions = (
            t.cuda(non_blocking=True) for t in data
        )

        sentences = sentences.squeeze(1)
        attentions = attentions.squeeze(1)

        if use_bert:
            last_hidden_states = bert_model(sentences, attention_mask=attentions)[0]
            embedding = last_hidden_states.permute(0, 2, 1)
            attentions = attentions.unsqueeze(dim=-1)
            output = model(image, embedding, l_mask=attentions)
        else:
            output = model(image, sentences, l_mask=attentions)

        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        torch.cuda.synchronize()
        loss_value = loss.item()
        running_loss += loss_value
        iterations += 1
        current_lr = optimizer.param_groups[0]["lr"]
        logger.update(loss=loss_value, lr=current_lr)

        # Drop references to this step's tensors before the next item is
        # fetched, as the original does.
        del image, target, sentences, attentions, loss, output, data
        if use_bert:
            del last_hidden_states, embedding

        torch.cuda.synchronize()
|
|
|
|
|
|
|
def _checkpoint_state(single_model, single_bert_model, single_head, optimizer, lr_scheduler,
                      epoch, args):
    """Collect everything needed to resume training into a single dict."""
    state = {
        'model': single_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'args': args,
        'lr_scheduler': lr_scheduler.state_dict(),
        # Always stored so that every checkpoint written here can be resumed.
        'head_model': single_head.state_dict(),
    }
    if single_bert_model is not None:
        state['bert_model'] = single_bert_model.state_dict()
    return state


def _save_checkpoint_via_temp(state, checkpoint_path):
    """Save to ``<path>_TEMP`` and rename, so the final name is only ever a complete file."""
    temp_path = str(checkpoint_path) + '_TEMP'
    utils.save_on_master(state, temp_path)
    if utils.is_main_process():
        os.rename(temp_path, str(checkpoint_path))


def main(local_rank, args):
    """Per-process training entry point (one process per local GPU).

    Reads ``MASTER_IP``, ``MASTER_PORT``, ``WORLD_SIZE`` and ``RANK`` from the
    environment. ``WORLD_SIZE`` / ``RANK`` are treated as node count / node
    index: the process group is created with ``WORLD_SIZE * gpus`` ranks and
    this process is rank ``RANK * gpus + local_rank``.

    Fixes relative to the previous version:
      * ``train_one_epoch`` / ``evaluate`` were called with an extra
        ``single_head`` argument that their definitions do not accept.
      * DDP ``device_ids`` and the dataset log line used ``args.local_rank``;
        the authoritative value is the ``local_rank`` passed by the spawner
        (the same one given to ``torch.cuda.set_device``).
      * ``head_model`` is now written to every checkpoint and only loaded
        when present, so resuming no longer raises ``KeyError``.
      * Checkpoint pruning is skipped for a non-positive ``args.max_ckpt``.

    Args:
        local_rank: Index of this process on the node (first argument that
            ``mp.spawn`` supplies); also used as the CUDA device index.
        args: Parsed command-line options. ``distributed``, ``gpu`` and
            possibly ``resume`` are overwritten on it.
    """
    ip = os.environ['MASTER_IP']
    port = os.environ['MASTER_PORT']
    hosts = int(os.environ['WORLD_SIZE'])
    rank = int(os.environ['RANK'])
    gpus = torch.cuda.device_count()
    num_tasks = hosts * gpus
    global_rank = rank * gpus + local_rank
    print(local_rank, rank, gpus)
    dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}',
                            world_size=num_tasks, rank=global_rank)
    torch.cuda.set_device(local_rank)
    dist.barrier()

    args.distributed = True
    args.gpu = local_rank
    print(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    # Offset the seed by rank so each process draws different random numbers.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    dataset, _ = get_dataset("train", get_transform(args=args), args=args)
    dataset_test, _ = get_dataset("val", get_transform(args=args), args=args)

    print(f"local rank {local_rank} / global rank {utils.get_rank()} successfully built train dataset.")

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True)
    # The validation set is not sharded: every rank evaluates all of it.
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers, pin_memory=True, drop_last=True)

    data_loader_test = DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, pin_memory=True, num_workers=args.workers)

    print(args.model)
    model = segmentation.__dict__[args.model](pretrained=args.pretrained_swin_weights,
                                              args=args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
                                                      find_unused_parameters=True)
    single_model = model.module

    if args.model != 'lavt_one':
        bert_model = BertModel.from_pretrained(args.ck_bert)
        bert_model.pooler = None
        bert_model.cuda()
        bert_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(bert_model)
        bert_model = torch.nn.parallel.DistributedDataParallel(bert_model, device_ids=[local_rank])
        single_bert_model = bert_model.module
    else:
        bert_model = None
        single_bert_model = None

    # NOTE(review): `Dict` and `MaskFormerHead` are not imported anywhere in
    # the visible part of this file; unless they are provided elsewhere this
    # section raises NameError. Their origin cannot be determined from here.
    input_shape = dict()
    input_shape['s1'] = Dict({'channel': 128, 'stride': 4})
    input_shape['s2'] = Dict({'channel': 256, 'stride': 8})
    input_shape['s3'] = Dict({'channel': 512, 'stride': 16})
    input_shape['s4'] = Dict({'channel': 1024, 'stride': 32})

    cfg = Dict()
    cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
    cfg.MODEL.MASK_FORMER.DROPOUT = 0.0
    cfg.MODEL.MASK_FORMER.NHEADS = 8
    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 4
    cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
    cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES = ["s1", "s2", "s3", "s4"]

    cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 1
    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 10
    cfg.MODEL.MASK_FORMER.PRE_NORM = False

    maskformer_head = MaskFormerHead(cfg, input_shape)
    maskformer_head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(maskformer_head)
    maskformer_head.cuda()
    maskformer_head = torch.nn.parallel.DistributedDataParallel(
        maskformer_head, device_ids=[local_rank], find_unused_parameters=False)
    single_head = maskformer_head.module
    print(single_head)

    # "auto" resumes from the highest-numbered per-epoch checkpoint, if any.
    if args.resume == "auto":
        last_ckpt = ""
        for e in range(args.epochs):
            ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
            if os.path.exists(ckpt_path):
                last_ckpt = ckpt_path
        args.resume = last_ckpt

    checkpoint = None
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        single_model.load_state_dict(checkpoint['model'])
        # Checkpoints written by earlier versions may lack the head weights.
        if 'head_model' in checkpoint:
            single_head.load_state_dict(checkpoint['head_model'])
        if single_bert_model is not None:
            single_bert_model.load_state_dict(checkpoint['bert_model'])

    # Backbone norm layers and position embeddings/bias tables get no weight decay.
    backbone_no_decay = list()
    backbone_decay = list()
    for name, m in single_model.backbone.named_parameters():
        if 'norm' in name or 'absolute_pos_embed' in name or 'relative_position_bias_table' in name:
            backbone_no_decay.append(m)
        else:
            backbone_decay.append(m)

    if single_bert_model is not None:
        text_layers = single_bert_model.encoder.layer
    else:
        text_layers = single_model.text_encoder.encoder.layer

    # Group order is kept identical to the previous version so that saved
    # optimizer states still line up when resuming.
    params_to_optimize = [
        {'params': backbone_no_decay, 'weight_decay': 0.0},
        {'params': backbone_decay},
        {"params": [p for p in single_model.classifier.parameters() if p.requires_grad]},
        # Only the first 10 text-encoder layers are optimized.
        {"params": [p for i in range(10) for p in text_layers[i].parameters() if p.requires_grad]},
    ]
    if single_bert_model is not None:
        params_to_optimize.append({"params": list(single_head.parameters())})

    optimizer = torch.optim.AdamW(params_to_optimize,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  amsgrad=args.amsgrad
                                  )

    # Polynomial decay (power 0.9) over the total number of training steps.
    total_steps = len(data_loader) * args.epochs
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / total_steps) ** 0.9)

    start_time = time.time()
    iterations = 0
    # NOTE(review): the best score is not stored in checkpoints, so after a
    # resume the first evaluated epoch always overwrites the "best" model.
    best_oIoU = -0.1

    start_epoch = 0
    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = max(0, checkpoint['epoch'] + 1)

    for epoch in range(start_epoch, args.epochs):
        data_loader.sampler.set_epoch(epoch)
        # NOTE(review): the MaskFormer head is built, optimized and
        # checkpointed, but train_one_epoch/evaluate as defined in this file
        # take no head argument, so it is not passed to them.
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, epoch, args.print_freq,
                        iterations, bert_model)
        iou, overallIoU = evaluate(model, data_loader_test, bert_model)

        print('Average object IoU {}'.format(iou))
        print('Overall IoU {}'.format(overallIoU))

        state = _checkpoint_state(single_model, single_bert_model, single_head,
                                  optimizer, lr_scheduler, epoch, args)

        checkpoint_path = os.path.join(args.output_dir, 'checkpoint-{}.pth'.format(epoch))
        _save_checkpoint_via_temp(state, checkpoint_path)

        if utils.is_main_process():
            ckpt_paths = []
            for e in range(args.epochs):
                ckpt_path = os.path.join(args.output_dir, f'checkpoint-{e}.pth')
                print(ckpt_path)
                if os.path.exists(ckpt_path):
                    ckpt_paths.append(ckpt_path)
            print(ckpt_paths)
            # Keep only the newest `max_ckpt` per-epoch checkpoints. A
            # non-positive value keeps everything (`[:-0]` is empty anyway,
            # and a negative slice bound would delete the wrong files).
            stale_paths = ckpt_paths[:-args.max_ckpt] if args.max_ckpt > 0 else []
            for ckpt_path in stale_paths:
                os.remove(ckpt_path)
                print("remove {:s}".format(ckpt_path))

        if best_oIoU < overallIoU:
            print('Better epoch: {}\n'.format(epoch))
            checkpoint_path = os.path.join(args.output_dir, 'model_best_{}.pth'.format(args.model_id))
            _save_checkpoint_via_temp(state, checkpoint_path)
            best_oIoU = overallIoU

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
|
|
|
|
if __name__ == "__main__":
    # Script entry: parse options, make sure the output directory exists,
    # then start one `main` process per visible CUDA device. `mp.spawn`
    # passes the process index as the first argument (`local_rank`).
    from args import get_parser

    parser = get_parser()
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print('Image size: {}'.format(str(args.img_size)))

    gpu_count = torch.cuda.device_count()
    mp.spawn(main, args=(args,), nprocs=gpu_count)
|
|