import os
import time
import gc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pathlib import Path
from skimage import io

from models.ormbg import ORMBG
from basics import f1_mae_torch
from data_loader_cache import (
    get_im_gt_name_dict,
    create_dataloaders,
    GOSGridDropout,
    GOSRandomHFlip,
)

device = "cuda" if torch.cuda.is_available() else "cpu"


def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
    """Evaluate the model on every validation set and collect F1/MAE stats."""
    net.eval()
    print("Validating...")

    val_loss = 0.0
    tar_loss = 0.0

    tmp_f1 = []
    tmp_mae = []
    tmp_time = []

    for k in range(len(valid_dataloaders)):
        valid_dataloader = valid_dataloaders[k]
        valid_dataset = valid_datasets[k]

        # Per-image precision/recall curves over 255 threshold bins.
        val_num = len(valid_dataset)
        mybins = np.arange(0, 256)
        PRE = np.zeros((val_num, len(mybins) - 1))
        REC = np.zeros((val_num, len(mybins) - 1))
        F1 = np.zeros((val_num, len(mybins) - 1))
        MAE = np.zeros(val_num)

        for i_val, data_val in enumerate(valid_dataloader):
            imidx_val, inputs_val, labels_val, shapes_val = (
                data_val["imidx"],
                data_val["image"],
                data_val["label"],
                data_val["shape"],
            )

            # Cast to the precision requested by the config.
            if hypar["model_digit"] == "full":
                inputs_val = inputs_val.type(torch.FloatTensor)
                labels_val = labels_val.type(torch.FloatTensor)
            else:
                inputs_val = inputs_val.type(torch.HalfTensor)
                labels_val = labels_val.type(torch.HalfTensor)

            # Move to the target device; gradients are not needed here.
            inputs_val_v = inputs_val.to(device)
            labels_val_v = labels_val.to(device)

            # Time the forward pass only.
            t_start = time.time()
            with torch.no_grad():
                ds_val = net(inputs_val_v)[0]
            t_end = time.time() - t_start
            tmp_time.append(t_end)

            loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)

            # Score each image in the batch at its original resolution.
            for t in range(hypar["batch_size_valid"]):
                i_test = imidx_val[t].data.numpy()

                pred_val = ds_val[0][t, :, :, :]

                # Resize the prediction back to the original image shape.
                pred_val = torch.squeeze(
                    F.interpolate(
                        torch.unsqueeze(pred_val, 0),
                        (shapes_val[t][0], shapes_val[t][1]),
                        mode="bilinear",
                    )
                )

                # Min-max normalize the prediction to [0, 1].
                ma = torch.max(pred_val)
                mi = torch.min(pred_val)
                pred_val = (pred_val - mi) / (ma - mi)

                # Load the ground truth when available, otherwise score
                # against an all-zero mask.
                if len(valid_dataset.dataset["ori_gt_path"]) != 0:
                    gt = np.squeeze(
                        io.imread(valid_dataset.dataset["ori_gt_path"][i_test])
                    )
                    if gt.max() == 1:
                        gt = gt * 255
                else:
                    gt = np.zeros((shapes_val[t][0], shapes_val[t][1]))
                gt = torch.tensor(gt).to(device)

                pre, rec, f1, mae = f1_mae_torch(
                    pred_val * 255, gt, valid_dataset, i_test, mybins, hypar
                )

                PRE[i_test, :] = pre
                REC[i_test, :] = rec
                F1[i_test, :] = f1
                MAE[i_test] = mae

                del gt

            # Free the batch prediction only after every sample in it
            # has been scored.
            del ds_val
            gc.collect()
            torch.cuda.empty_cache()

            val_loss += loss_val.item()
            tar_loss += loss2_val.item()

            print(
                "[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"
                % (
                    i_val,
                    val_num,
                    val_loss / (i_val + 1),
                    tar_loss / (i_val + 1),
                    np.amax(F1[i_test, :]),
                    MAE[i_test],
                    t_end,
                )
            )

            del loss2_val, loss_val

        print("============================")
        # Mean precision/recall per threshold, combined into an F-beta
        # score with beta^2 = 0.3 (the usual saliency-detection setting).
        PRE_m = np.mean(PRE, 0)
        REC_m = np.mean(REC, 0)
        f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)

        tmp_f1.append(np.amax(f1_m))
        tmp_mae.append(np.mean(MAE))

    return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time


def train(
    net,
    optimizer,
    train_dataloaders,
    train_datasets,
    valid_dataloaders,
    valid_datasets,
    hypar,
):
    model_path = hypar["model_path"]
    model_save_fre = hypar["model_save_fre"]
    max_ite = hypar["max_ite"]
    batch_size_train = hypar["batch_size_train"]
    batch_size_valid = hypar["batch_size_valid"]

    if not os.path.exists(model_path):
        os.mkdir(model_path)

    ite_num = hypar["start_ite"]  # resume from this iteration if restoring
    ite_num4val = 0  # iterations since the last validation run
    running_loss = 0.0
    running_tar_loss = 0.0
    last_f1 = [0 for x in range(len(valid_dataloaders))]

    train_num = len(train_datasets[0])

    net.train()

    start_last = time.time()
    gos_dataloader = train_dataloaders[0]
    epoch_num = hypar["max_epoch_num"]
    notgood_cnt = 0  # validation runs without an F1 improvement

    for epoch in range(epoch_num):
        for i, data in enumerate(gos_dataloader):
            if ite_num >= max_ite:
                print("Training Reached the Maximal Iteration Number ", max_ite)
                exit()

            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs, labels = data["image"], data["label"]

            # Cast to the precision requested by the config.
            if hypar["model_digit"] == "full":
                inputs = inputs.type(torch.FloatTensor)
                labels = labels.type(torch.FloatTensor)
            else:
                inputs = inputs.type(torch.HalfTensor)
                labels = labels.type(torch.HalfTensor)

            inputs_v = inputs.to(device)
            labels_v = labels.to(device)

            start_inf_loss_back = time.time()
            optimizer.zero_grad()

            ds, _ = net(inputs_v)
            loss2, loss = net.compute_loss(ds, labels_v)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_tar_loss += loss2.item()

            del ds, loss2, loss
            end_inf_loss_back = time.time() - start_inf_loss_back

            print(
                ">>>"
                + model_path.split("/")[-1]
                + " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f"
                % (
                    epoch + 1,
                    epoch_num,
                    (i + 1) * batch_size_train,
                    train_num,
                    ite_num,
                    running_loss / ite_num4val,
                    running_tar_loss / ite_num4val,
                    time.time() - start_last,
                    time.time() - start_last - end_inf_loss_back,
                )
            )
            start_last = time.time()

            # Periodically validate and, on improvement, save a checkpoint.
            if ite_num % model_save_fre == 0:
                notgood_cnt += 1
                net.eval()
                tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(
                    net, valid_dataloaders, valid_datasets, hypar, epoch
                )
                net.train()

                # Check whether any validation set improved its best F1.
                tmp_out = 0
                print("last_f1:", last_f1)
                print("tmp_f1:", tmp_f1)
                for fi in range(len(last_f1)):
                    if tmp_f1[fi] > last_f1[fi]:
                        tmp_out = 1
                print("tmp_out:", tmp_out)
                if tmp_out:
                    notgood_cnt = 0
                    last_f1 = tmp_f1
                    tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
                    tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
                    maxf1 = "_".join(tmp_f1_str)
                    meanM = "_".join(tmp_mae_str)

                    # Encode the key metrics into the checkpoint file name.
                    model_name = (
                        "/gpu_itr_"
                        + str(ite_num)
                        + "_traLoss_"
                        + str(np.round(running_loss / ite_num4val, 4))
                        + "_traTarLoss_"
                        + str(np.round(running_tar_loss / ite_num4val, 4))
                        + "_valLoss_"
                        + str(np.round(val_loss / (i_val + 1), 4))
                        + "_valTarLoss_"
                        + str(np.round(tar_loss / (i_val + 1), 4))
                        + "_maxF1_"
                        + maxf1
                        + "_mae_"
                        + meanM
                        + "_time_"
                        + str(
                            np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)
                        )
                        + ".pth"
                    )
                    torch.save(net.state_dict(), model_path + model_name)

                running_loss = 0.0
                running_tar_loss = 0.0
                ite_num4val = 0

                if notgood_cnt >= hypar["early_stop"]:
                    print(
                        "No improvements in the last "
                        + str(notgood_cnt)
                        + " validation periods, so training stopped!"
                    )
                    exit()

    print("Training Reaches The Maximum Epoch Number")


def main(train_datasets, valid_datasets, hypar):
    print("--- create training dataloader ---")

    train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")

    train_dataloaders, train_datasets = create_dataloaders(
        train_nm_im_gt_list,
        cache_size=hypar["cache_size"],
        cache_boost=hypar["cache_boost_train"],
        my_transforms=[GOSGridDropout(), GOSRandomHFlip()],
        batch_size=hypar["batch_size_train"],
        shuffle=True,
    )

    valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")

    valid_dataloaders, valid_datasets = create_dataloaders(
        valid_nm_im_gt_list,
        cache_size=hypar["cache_size"],
        cache_boost=hypar["cache_boost_valid"],
        my_transforms=[],
        batch_size=hypar["batch_size_valid"],
        shuffle=False,
    )

    net = hypar["model"]

    # Half precision: keep BatchNorm layers in float32 for numerical stability.
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    if torch.cuda.is_available():
        net.cuda()

    # Optionally resume from a saved checkpoint.
    if hypar["restore_model"] != "":
        print("restore model from:")
        print(hypar["model_path"] + "/" + hypar["restore_model"])
        if torch.cuda.is_available():
            net.load_state_dict(
                torch.load(hypar["model_path"] + "/" + hypar["restore_model"])
            )
        else:
            net.load_state_dict(
                torch.load(
                    hypar["model_path"] + "/" + hypar["restore_model"],
                    map_location="cpu",
                )
            )

    optimizer = optim.Adam(
        net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
    )

    train(
        net,
        optimizer,
        train_dataloaders,
        train_datasets,
        valid_dataloaders,
        valid_datasets,
        hypar,
    )


if __name__ == "__main__":

    output_model_folder = "saved_models"
    Path(output_model_folder).mkdir(parents=True, exist_ok=True)

    dataset_training = {
        "name": "ormbg-training",
        "im_dir": str(Path("dataset", "training", "im")),
        "gt_dir": str(Path("dataset", "training", "gt")),
        "im_ext": ".png",
        "gt_ext": ".png",
        "cache_dir": str(Path("cache", "teacher", "training")),
    }

    dataset_validation = {
        "name": "ormbg-validation",
        "im_dir": str(Path("dataset", "validation", "im")),
        "gt_dir": str(Path("dataset", "validation", "gt")),
        "im_ext": ".png",
        "gt_ext": ".png",
        "cache_dir": str(Path("cache", "teacher", "validation")),
    }

    train_datasets = [dataset_training]
    valid_datasets = [dataset_validation]

    hypar = {}

    hypar["model"] = ORMBG()
    hypar["seed"] = 0

    hypar["model_path"] = "saved_models"

    # Checkpoint file name (relative to model_path) to resume from, if any.
    hypar["restore_model"] = ""
    hypar["start_ite"] = 0

    # "full" trains in float32, "half" in float16 (BatchNorm stays float32).
    hypar["model_digit"] = "full"

    # Spatial size used when caching/resizing the inputs.
    hypar["cache_size"] = [1024, 1024]
    hypar["cache_boost_train"] = False
    hypar["cache_boost_valid"] = False

    # Stop after this many validation runs without an F1 improvement.
    hypar["early_stop"] = 20

    # Validate and (on improvement) checkpoint every N iterations.
    hypar["model_save_fre"] = 2000

    hypar["batch_size_train"] = 8
    hypar["batch_size_valid"] = 1

    hypar["max_ite"] = 10000000
    hypar["max_epoch_num"] = 1000000

    main(train_datasets, valid_datasets, hypar=hypar)