import os
import time
import torch, gc
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from pathlib import Path
from models.ormbg import ORMBG
from skimage import io
from basics import f1_mae_torch
from data_loader_cache import (
get_im_gt_name_dict,
create_dataloaders,
GOSGridDropout,
GOSRandomHFlip,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
def valid(net, valid_dataloaders, valid_datasets, hypar, epoch=0):
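    """Run validation on all validation dataloaders.

    Computes the validation and target losses, per-threshold precision/recall/F1
    and MAE for every sample, and returns the per-dataset maximum F-measure,
    mean MAE, accumulated losses, the last batch index and per-batch inference times.
    """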
net.eval()
print("Validating...")
epoch_num = hypar["max_epoch_num"]
val_loss = 0.0
tar_loss = 0.0
val_cnt = 0.0
tmp_f1 = []
tmp_mae = []
tmp_time = []
start_valid = time.time()
for k in range(len(valid_dataloaders)):
valid_dataloader = valid_dataloaders[k]
valid_dataset = valid_datasets[k]
val_num = valid_dataset.__len__()
mybins = np.arange(0, 256)
PRE = np.zeros((val_num, len(mybins) - 1))
REC = np.zeros((val_num, len(mybins) - 1))
F1 = np.zeros((val_num, len(mybins) - 1))
MAE = np.zeros((val_num))
for i_val, data_val in enumerate(valid_dataloader):
val_cnt = val_cnt + 1.0
imidx_val, inputs_val, labels_val, shapes_val = (
data_val["imidx"],
data_val["image"],
data_val["label"],
data_val["shape"],
)
if hypar["model_digit"] == "full":
inputs_val = inputs_val.type(torch.FloatTensor)
labels_val = labels_val.type(torch.FloatTensor)
else:
inputs_val = inputs_val.type(torch.HalfTensor)
labels_val = labels_val.type(torch.HalfTensor)
            # move tensors to the target device (torch.autograd.Variable is deprecated)
            inputs_val_v = inputs_val.to(device)
            labels_val_v = labels_val.to(device)
t_start = time.time()
ds_val = net(inputs_val_v)[0]
t_end = time.time() - t_start
tmp_time.append(t_end)
# loss2_val, loss_val = muti_loss_fusion(ds_val, labels_val_v)
loss2_val, loss_val = net.compute_loss(ds_val, labels_val_v)
# compute F measure
for t in range(hypar["batch_size_valid"]):
i_test = imidx_val[t].data.numpy()
pred_val = ds_val[0][t, :, :, :] # B x 1 x H x W
                ## recover the prediction spatial size to the original image size
                pred_val = torch.squeeze(
                    F.interpolate(
                        torch.unsqueeze(pred_val, 0),
                        (shapes_val[t][0], shapes_val[t][1]),
                        mode="bilinear",
                    )
                )
# pred_val = normPRED(pred_val)
ma = torch.max(pred_val)
mi = torch.min(pred_val)
pred_val = (pred_val - mi) / (ma - mi) # max = 1
if len(valid_dataset.dataset["ori_gt_path"]) != 0:
gt = np.squeeze(
io.imread(valid_dataset.dataset["ori_gt_path"][i_test])
) # max = 255
if gt.max() == 1:
gt = gt * 255
else:
gt = np.zeros((shapes_val[t][0], shapes_val[t][1]))
with torch.no_grad():
gt = torch.tensor(gt).to(device)
pre, rec, f1, mae = f1_mae_torch(
pred_val * 255, gt, valid_dataset, i_test, mybins, hypar
)
PRE[i_test, :] = pre
REC[i_test, :] = rec
F1[i_test, :] = f1
MAE[i_test] = mae
del ds_val, gt
gc.collect()
torch.cuda.empty_cache()
# if(loss_val.data[0]>1):
val_loss += loss_val.item() # data[0]
tar_loss += loss2_val.item() # data[0]
print(
"[validating: %5d/%5d] val_ls:%f, tar_ls: %f, f1: %f, mae: %f, time: %f"
% (
i_val,
val_num,
val_loss / (i_val + 1),
tar_loss / (i_val + 1),
np.amax(F1[i_test, :]),
MAE[i_test],
t_end,
)
)
del loss2_val, loss_val
print("============================")
PRE_m = np.mean(PRE, 0)
REC_m = np.mean(REC, 0)
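        # F-measure with beta^2 = 0.3, as commonly used in salient object detection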
f1_m = (1 + 0.3) * PRE_m * REC_m / (0.3 * PRE_m + REC_m + 1e-8)
tmp_f1.append(np.amax(f1_m))
tmp_mae.append(np.mean(MAE))
return tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time
def train(
net,
optimizer,
train_dataloaders,
train_datasets,
valid_dataloaders,
valid_datasets,
hypar,
):
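    """Main training loop.

    Iterates over the first training dataloader for up to max_epoch_num epochs,
    runs validation every model_save_fre iterations, saves a checkpoint whenever
    any validation F1 improves, and stops early after hypar["early_stop"]
    validation periods without improvement (or once max_ite iterations are reached).
    """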
model_path = hypar["model_path"]
model_save_fre = hypar["model_save_fre"]
max_ite = hypar["max_ite"]
batch_size_train = hypar["batch_size_train"]
batch_size_valid = hypar["batch_size_valid"]
if not os.path.exists(model_path):
os.mkdir(model_path)
    ite_num = hypar["start_ite"]  # count the total iteration number
    ite_num4val = 0  # iterations since the last validation/checkpoint
    running_loss = 0.0  # count the total loss
    running_tar_loss = 0.0  # count the target output loss
last_f1 = [0 for x in range(len(valid_dataloaders))]
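    # F1 scores of the best validation so far; a checkpoint is saved whenever any of them improves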
train_num = train_datasets[0].__len__()
net.train()
start_last = time.time()
gos_dataloader = train_dataloaders[0]
epoch_num = hypar["max_epoch_num"]
notgood_cnt = 0
for epoch in range(epoch_num):
for i, data in enumerate(gos_dataloader):
if ite_num >= max_ite:
print("Training Reached the Maximal Iteration Number ", max_ite)
exit()
# start_read = time.time()
ite_num = ite_num + 1
ite_num4val = ite_num4val + 1
# get the inputs
inputs, labels = data["image"], data["label"]
if hypar["model_digit"] == "full":
inputs = inputs.type(torch.FloatTensor)
labels = labels.type(torch.FloatTensor)
else:
inputs = inputs.type(torch.HalfTensor)
labels = labels.type(torch.HalfTensor)
            # move tensors to the target device (torch.autograd.Variable is deprecated)
            inputs_v = inputs.to(device)
            labels_v = labels.to(device)
            # zero the parameter gradients
start_inf_loss_back = time.time()
optimizer.zero_grad()
ds, _ = net(inputs_v)
loss2, loss = net.compute_loss(ds, labels_v)
loss.backward()
optimizer.step()
# # print statistics
running_loss += loss.item()
running_tar_loss += loss2.item()
# del outputs, loss
del ds, loss2, loss
end_inf_loss_back = time.time() - start_inf_loss_back
print(
">>>"
+ model_path.split("/")[-1]
+ " - [epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f, time-per-iter: %3f s, time_read: %3f"
% (
epoch + 1,
epoch_num,
(i + 1) * batch_size_train,
train_num,
ite_num,
running_loss / ite_num4val,
running_tar_loss / ite_num4val,
time.time() - start_last,
time.time() - start_last - end_inf_loss_back,
)
)
start_last = time.time()
            if ite_num % model_save_fre == 0:  # validate every model_save_fre iterations
notgood_cnt += 1
net.eval()
tmp_f1, tmp_mae, val_loss, tar_loss, i_val, tmp_time = valid(
net, valid_dataloaders, valid_datasets, hypar, epoch
)
net.train() # resume train
tmp_out = 0
print("last_f1:", last_f1)
print("tmp_f1:", tmp_f1)
for fi in range(len(last_f1)):
if tmp_f1[fi] > last_f1[fi]:
tmp_out = 1
print("tmp_out:", tmp_out)
if tmp_out:
notgood_cnt = 0
last_f1 = tmp_f1
tmp_f1_str = [str(round(f1x, 4)) for f1x in tmp_f1]
tmp_mae_str = [str(round(mx, 4)) for mx in tmp_mae]
maxf1 = "_".join(tmp_f1_str)
meanM = "_".join(tmp_mae_str)
# .cpu().detach().numpy()
model_name = (
"/gpu_itr_"
+ str(ite_num)
+ "_traLoss_"
+ str(np.round(running_loss / ite_num4val, 4))
+ "_traTarLoss_"
+ str(np.round(running_tar_loss / ite_num4val, 4))
+ "_valLoss_"
+ str(np.round(val_loss / (i_val + 1), 4))
+ "_valTarLoss_"
+ str(np.round(tar_loss / (i_val + 1), 4))
+ "_maxF1_"
+ maxf1
+ "_mae_"
+ meanM
+ "_time_"
+ str(
np.round(np.mean(np.array(tmp_time)) / batch_size_valid, 6)
)
+ ".pth"
)
torch.save(net.state_dict(), model_path + model_name)
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
if notgood_cnt >= hypar["early_stop"]:
print(
"No improvements in the last "
+ str(notgood_cnt)
+ " validation periods, so training stopped !"
)
exit()
print("Training Reaches The Maximum Epoch Number")
def main(train_datasets, valid_datasets, hypar):
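    """Build the training and validation dataloaders, set up the model and optimizer, and start training."""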
print("--- create training dataloader ---")
train_nm_im_gt_list = get_im_gt_name_dict(train_datasets, flag="train")
## build dataloader for training datasets
train_dataloaders, train_datasets = create_dataloaders(
train_nm_im_gt_list,
cache_size=hypar["cache_size"],
cache_boost=hypar["cache_boost_train"],
my_transforms=[GOSGridDropout(), GOSRandomHFlip()],
batch_size=hypar["batch_size_train"],
shuffle=True,
)
valid_nm_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
valid_dataloaders, valid_datasets = create_dataloaders(
valid_nm_im_gt_list,
cache_size=hypar["cache_size"],
cache_boost=hypar["cache_boost_valid"],
my_transforms=[],
batch_size=hypar["batch_size_valid"],
shuffle=False,
)
net = hypar["model"]
if hypar["model_digit"] == "half":
net.half()
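        # keep BatchNorm layers in float32 for numerical stability when the rest of the model runs in half precision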
for layer in net.modules():
if isinstance(layer, nn.BatchNorm2d):
layer.float()
if torch.cuda.is_available():
net.cuda()
if hypar["restore_model"] != "":
print("restore model from:")
print(hypar["model_path"] + "/" + hypar["restore_model"])
if torch.cuda.is_available():
net.load_state_dict(
torch.load(hypar["model_path"] + "/" + hypar["restore_model"])
)
else:
net.load_state_dict(
torch.load(
hypar["model_path"] + "/" + hypar["restore_model"],
map_location="cpu",
)
)
optimizer = optim.Adam(
net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0
)
train(
net,
optimizer,
train_dataloaders,
train_datasets,
valid_dataloaders,
valid_datasets,
hypar,
)
if __name__ == "__main__":
output_model_folder = "saved_models"
Path(output_model_folder).mkdir(parents=True, exist_ok=True)
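    ### --------------- STEP 1: Configuring the training and validation datasets ---------------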
train_datasets, valid_datasets = [], []
dataset_training = {
"name": "ormbg-training",
"im_dir": str(Path("dataset", "training", "im")),
"gt_dir": str(Path("dataset", "training", "gt")),
"im_ext": ".png",
"gt_ext": ".png",
"cache_dir": str(Path("cache", "teacher", "training")),
}
dataset_validation = {
"name": "ormbg-training",
"im_dir": str(Path("dataset", "validation", "im")),
"gt_dir": str(Path("dataset", "validation", "gt")),
"im_ext": ".png",
"gt_ext": ".png",
"cache_dir": str(Path("cache", "teacher", "validation")),
}
train_datasets = [dataset_training]
valid_datasets = [dataset_validation]
    ### --------------- STEP 2: Configuring the hyperparameters for training, validation and inference ---------------
hypar = {}
hypar["model"] = ORMBG()
hypar["seed"] = 0
## model weights path
hypar["model_path"] = "saved_models"
    ## name of the segmentation model weights (.pth) used to resume training from the last checkpoint or for inference
hypar["restore_model"] = ""
    ## start iteration for training; can be changed to match a restored training process
hypar["start_ite"] = 0
## indicates "half" or "full" accuracy of float number
hypar["model_digit"] = "full"
    ## To handle large input images, which take a long time to load during training,
    # we use a cache mechanism that pre-converts and resizes the jpg and png images into .pt files
hypar["cache_size"] = [
1024,
1024,
]
    ## cached input spatial resolution; can be configured to a different size
## "True" or "False", indicates wheather to load all the training datasets into RAM, True will greatly speed the training process while requires more RAM
hypar["cache_boost_train"] = False
## "True" or "False", indicates wheather to load all the validation datasets into RAM, True will greatly speed the training process while requires more RAM
hypar["cache_boost_valid"] = False
    ## stop training when there is no improvement in the past 20 validation periods; smaller numbers such as 5 or 10 can also be used
hypar["early_stop"] = 20
    ## validate and save model weights every 2000 iterations
hypar["model_save_fre"] = 2000
## batch size for training
hypar["batch_size_train"] = 8
    ## batch size for validation and inference
hypar["batch_size_valid"] = 1
    ## if early stopping does not stop the training process, stop it after max_ite iterations
hypar["max_ite"] = 10000000
    ## if early stopping and max_ite do not stop the training process, stop it after max_epoch_num epochs
hypar["max_epoch_num"] = 1000000
main(train_datasets, valid_datasets, hypar=hypar)