# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from collections import deque, defaultdict
import pickle
import shutil
import numpy as np
import paddle
import paddle.nn.functional as F
from paddleseg.utils import TimeAverager, calculate_eta, resume, logger
from .val import evaluate


def visual_in_traning(log_writer, vis_dict, step):
    """
    Visualize tensors in VisualDL during training.

    Args:
        log_writer (LogWriter): The log writer of VisualDL.
        vis_dict (dict): Dict of tensors. The shape of each tensor is (C, H, W).
        step (int): The current training step.
"""
    for key, value in vis_dict.items():
        value_shape = value.shape
        if value_shape[0] not in [1, 3]:
            # Keep only the first channel and restore the channel axis.
            value = value[0]
            value = value.unsqueeze(0)
        value = paddle.transpose(value, (1, 2, 0))  # CHW -> HWC
        min_v = paddle.min(value)
        max_v = paddle.max(value)
        if (min_v > 0) and (max_v < 1):
            # Values in (0, 1): scale to [0, 255].
            value = value * 255
        elif (min_v < 0 and min_v >= -1) and (max_v <= 1):
            # Values in [-1, 1]: map to [0, 1], then scale to [0, 255].
            value = (1 + value) / 2 * 255
        else:
            # Otherwise fall back to min-max normalization.
            value = (value - min_v) / (max_v - min_v) * 255
        value = value.astype('uint8')
        value = value.numpy()
        log_writer.add_image(tag=key, img=value, step=step)


def save_best(best_model_dir, metrics_data, iter):
    """Save the best metrics and the iter they were reached at to a text file."""
    with open(os.path.join(best_model_dir, 'best_metrics.txt'), 'w') as f:
        for key, value in metrics_data.items():
            f.write(key + ' ' + str(value) + '\n')
        f.write('iter' + ' ' + str(iter) + '\n')
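

# `best_metrics.txt` written by `save_best` has the following format
# (metric values here are illustrative):
#
#     sad 32.1573
#     mse 0.0097
#     iter 9000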
def get_best(best_file, metrics, resume_model=None):
    '''Get the best metrics and iter previously recorded in `best_file`.'''
    best_metrics_data = {}
    best_iter = -1
    if os.path.exists(best_file) and (resume_model is not None):
        with open(best_file, 'r') as f:
            for line in f.readlines():
                key, value = line.strip().split(' ')
                if key == 'iter':
                    best_iter = int(value)
                else:
                    # float() is safer than eval() for parsing metric values,
                    # and it also handles 'inf' correctly.
                    best_metrics_data[key] = float(value)
    else:
        for key in metrics:
            best_metrics_data[key] = np.inf
    return best_metrics_data, best_iter


def train(model,
train_dataset,
val_dataset=None,
optimizer=None,
save_dir='output',
iters=10000,
batch_size=2,
resume_model=None,
save_interval=1000,
log_iters=10,
log_image_iters=1000,
num_workers=0,
use_vdl=False,
losses=None,
keep_checkpoint_max=5,
eval_begin_iters=None,
metrics='sad'):
"""
Launch training.
Args:
model(nn.Layer): A matting model.
train_dataset (paddle.io.Dataset): Used to read and process training datasets.
val_dataset (paddle.io.Dataset, optional): Used to read and process validation datasets.
optimizer (paddle.optimizer.Optimizer): The optimizer.
save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
iters (int, optional): How may iters to train the model. Defualt: 10000.
batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
resume_model (str, optional): The path of resume model.
save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
log_iters (int, optional): Display logging information at every log_iters. Default: 10.
log_image_iters (int, optional): Log image to vdl. Default: 1000.
num_workers (int, optional): Num workers for data loader. Default: 0.
use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
losses (dict, optional): A dict of loss, refer to the loss function of the model for details. Default: None.
keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
eval_begin_iters (int): The iters begin evaluation. It will evaluate at iters/2 if it is None. Defalust: None.
metrics(str|list, optional): The metrics to evaluate, it may be the combination of ("sad", "mse", "grad", "conn").
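
    Example:
        A minimal usage sketch. `MattingModel`, `train_dataset` and
        `val_dataset` are placeholders; substitute the model and datasets
        from your own configuration:

            import paddle
            model = MattingModel()
            optimizer = paddle.optimizer.Adam(
                learning_rate=0.001, parameters=model.parameters())
            train(model,
                  train_dataset,
                  val_dataset=val_dataset,
                  optimizer=optimizer,
                  iters=10000,
                  batch_size=2,
                  use_vdl=True)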
"""
model.train()
nranks = paddle.distributed.ParallelEnv().nranks
local_rank = paddle.distributed.ParallelEnv().local_rank
start_iter = 0
if resume_model is not None:
start_iter = resume(model, optimizer, resume_model)
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
    if nranks > 1:
        # Initialize the parallel environment if not done, then wrap the
        # model for data-parallel training.
        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
        ):
            paddle.distributed.init_parallel_env()
        ddp_model = paddle.DataParallel(model)
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        return_list=True)
if use_vdl:
from visualdl import LogWriter
log_writer = LogWriter(save_dir)
if isinstance(metrics, str):
metrics = [metrics]
elif not isinstance(metrics, list):
metrics = ['sad']
best_metrics_data, best_iter = get_best(
os.path.join(save_dir, 'best_model', 'best_metrics.txt'),
metrics,
resume_model=resume_model)
avg_loss = defaultdict(float)
iters_per_epoch = len(batch_sampler)
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
save_models = deque()
batch_start = time.time()
iter = start_iter
while iter < iters:
for data in loader:
iter += 1
if iter > iters:
break
reader_cost_averager.record(time.time() - batch_start)
logit_dict, loss_dict = ddp_model(data) if nranks > 1 else model(
data)
loss_dict['all'].backward()
optimizer.step()
lr = optimizer.get_lr()
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
model.clear_gradients()
for key, value in loss_dict.items():
avg_loss[key] += value.numpy()[0]
batch_cost_averager.record(
time.time() - batch_start, num_samples=batch_size)
            if iter % log_iters == 0 and local_rank == 0:
for key, value in avg_loss.items():
avg_loss[key] = value / log_iters
remain_iters = iters - iter
avg_train_batch_cost = batch_cost_averager.get_average()
avg_train_reader_cost = reader_cost_averager.get_average()
eta = calculate_eta(remain_iters, avg_train_batch_cost)
                # Compose the per-loss log string (the aggregate 'all' loss
                # is reported separately in the main line below).
                loss_str = ' ' * 26 + '\t[LOSSES]'
                for key, value in avg_loss.items():
                    if key != 'all':
                        loss_str += ' {}={:.4f}'.format(key, value)
                logger.info(
                    "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, "
                    "batch_cost={:.4f}, reader_cost={:.5f}, "
                    "ips={:.4f} samples/sec | ETA {}\n{}\n".format(
                        (iter - 1) // iters_per_epoch + 1, iter, iters,
                        avg_loss['all'], lr, avg_train_batch_cost,
                        avg_train_reader_cost,
                        batch_cost_averager.get_ips_average(), eta, loss_str))
if use_vdl:
for key, value in avg_loss.items():
log_tag = 'Train/' + key
log_writer.add_scalar(log_tag, value, iter)
log_writer.add_scalar('Train/lr', lr, iter)
log_writer.add_scalar('Train/batch_cost',
avg_train_batch_cost, iter)
log_writer.add_scalar('Train/reader_cost',
avg_train_reader_cost, iter)
if iter % log_image_iters == 0:
vis_dict = {}
# ground truth
vis_dict['ground truth/img'] = data['img'][0]
for key in data['gt_fields']:
key = key[0]
vis_dict['/'.join(['ground truth', key])] = data[
key][0]
# predict
                        for key, value in logit_dict.items():
                            vis_dict['/'.join(['predict', key])] = value[0]
visual_in_traning(
log_writer=log_writer, vis_dict=vis_dict, step=iter)
for key in avg_loss.keys():
avg_loss[key] = 0.
reader_cost_averager.reset()
batch_cost_averager.reset()
# save model
if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
current_save_dir = os.path.join(save_dir,
"iter_{}".format(iter))
if not os.path.isdir(current_save_dir):
os.makedirs(current_save_dir)
paddle.save(model.state_dict(),
os.path.join(current_save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(),
os.path.join(current_save_dir, 'model.pdopt'))
save_models.append(current_save_dir)
if len(save_models) > keep_checkpoint_max > 0:
model_to_remove = save_models.popleft()
shutil.rmtree(model_to_remove)
# eval model
if eval_begin_iters is None:
eval_begin_iters = iters // 2
if (iter % save_interval == 0 or iter == iters) and (
val_dataset is not None
) and local_rank == 0 and iter >= eval_begin_iters:
                num_workers = 1 if num_workers > 0 else 0
                metrics_data = evaluate(
                    model,
                    val_dataset,
                    num_workers=num_workers,
                    print_detail=True,
                    save_results=False,
                    metrics=metrics)
model.train()
# save best model and add evaluation results to vdl
if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
if val_dataset is not None and iter >= eval_begin_iters:
                    # All supported metrics measure error, so lower is better.
                    if metrics_data[metrics[0]] < best_metrics_data[metrics[0]]:
best_iter = iter
best_metrics_data = metrics_data.copy()
best_model_dir = os.path.join(save_dir, "best_model")
paddle.save(
model.state_dict(),
os.path.join(best_model_dir, 'model.pdparams'))
save_best(best_model_dir, best_metrics_data, iter)
show_list = []
for key, value in best_metrics_data.items():
show_list.append((key, value))
log_str = '[EVAL] The model with the best validation {} ({:.4f}) was saved at iter {}.'.format(
show_list[0][0], show_list[0][1], best_iter)
if len(show_list) > 1:
log_str += " While"
for i in range(1, len(show_list)):
log_str = log_str + ' {}: {:.4f},'.format(
show_list[i][0], show_list[i][1])
log_str = log_str[:-1]
logger.info(log_str)
if use_vdl:
for key, value in metrics_data.items():
log_writer.add_scalar('Evaluate/' + key, value,
iter)
batch_start = time.time()
# Sleep for half a second to let dataloader release resources.
time.sleep(0.5)
if use_vdl:
log_writer.close()