Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import datetime | |
import logging | |
import time | |
import torch | |
import torch.distributed as dist | |
from maskrcnn_benchmark.utils.comm import get_world_size | |
from maskrcnn_benchmark.utils.metric_logger import MetricLogger | |
def reduce_loss_dict(all_loss_dict):
    """
    Reduce the per-task loss dictionaries from all processes so that the
    process with rank 0 has the averaged results.

    Args:
        all_loss_dict: non-empty sequence of dicts (one per task in the
            caller), each mapping a loss name to a tensor. All values are
            combined with ``torch.stack`` and must therefore share a shape.

    Returns:
        A dict mapping each loss name to the sum of its values over the
        given dicts divided by ``len(all_loss_dict)``. When more than one
        process is running, the values are additionally summed across
        processes onto rank 0 and divided by the world size there; the
        values seen by other ranks are not averaged and should only be
        relied on for rank 0.
    """
    world_size = get_world_size()
    num_dicts = len(all_loss_dict)
    with torch.no_grad():
        loss_names = []
        all_losses = []
        for loss_dict in all_loss_dict:
            # Sorted keys give a deterministic order, so position i in the
            # stacked tensor refers to the same loss in every process that
            # takes part in the element-wise reduction below.
            for k in sorted(loss_dict.keys()):
                loss_names.append(k)
                all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        if world_size > 1:
            dist.reduce(all_losses, dst=0)
            if dist.get_rank() == 0:
                # only main process gets accumulated, so only divide by
                # world_size in this case
                all_losses /= world_size
        reduced_losses = {}
        for k, v in zip(loss_names, all_losses):
            contribution = v / num_dicts
            # Each occurrence must be counted exactly once: start the entry
            # with the first contribution and add only subsequent ones.
            if k in reduced_losses:
                reduced_losses[k] = reduced_losses[k] + contribution
            else:
                reduced_losses[k] = contribution
    return reduced_losses
def do_train(
    model,
    data_loader,
    optimizer,
    scheduler,
    checkpointer,
    device,
    checkpoint_period,
    arguments,
):
    """
    Run the multi-task training loop.

    Each iteration draws one batch from every task loader, sums all the
    losses the model returns for those batches, and performs a single
    optimizer/scheduler step on the total.

    Args:
        model: called as ``model(images, targets, task)`` with a 1-based
            task index; expected to return a dict of loss tensors.
        data_loader: sequence of per-task loaders. Each must support
            ``len()`` and yield ``(images, targets, _)`` triples; the
            loaders are iterated in lockstep, so the shortest one bounds
            the number of iterations.
        optimizer: optimizer stepped once per iteration.
        scheduler: learning-rate scheduler stepped after the optimizer.
        checkpointer: object whose ``save(name, **arguments)`` is called
            every ``checkpoint_period`` iterations and at the last one.
        device: device the images and targets are moved to.
        checkpoint_period: number of iterations between checkpoints.
        arguments: dict holding at least ``"iteration"`` (the iteration to
            resume from); it is updated in place every iteration and passed
            to the checkpointer.

    Raises:
        RuntimeError: in a multi-process run, if every task batch of an
            iteration was skipped on this process (see the in-line note).
    """
    logger = logging.getLogger("maskrcnn_benchmark.trainer")
    logger.info("Start training")
    meters = MetricLogger(delimiter=" ")
    max_iter = min(len(task_loader) for task_loader in data_loader)
    start_iter = arguments["iteration"]
    model.train()
    start_training_time = time.time()
    end = time.time()
    for iteration, task_batches in enumerate(zip(*data_loader), start_iter):
        data_time = time.time() - end
        iteration = iteration + 1
        arguments["iteration"] = iteration

        all_task_loss_dict = []
        for task, (images, targets, _) in enumerate(task_batches, 1):
            if all(len(target) < 1 for target in targets):
                logger.warning('Sampled all negative batches, skip')
                continue

            images = images.to(device)
            targets = [target.to(device) for target in targets]

            loss_dict = model(images, targets, task)
            all_task_loss_dict.append(loss_dict)

        if not all_task_loss_dict:
            # Every task batch was skipped, so there is nothing to reduce or
            # back-propagate (summing no losses yields the int 0, which has
            # no ``backward``). In a multi-process run the other ranks may
            # still enter the collective reduction, so fail loudly instead
            # of silently leaving them waiting.
            if get_world_size() > 1:
                raise RuntimeError(
                    "No task produced a loss at iteration {} on this "
                    "process; cannot take part in the distributed loss "
                    "reduction".format(iteration)
                )
            logger.warning(
                "No task produced a loss at iteration %d; "
                "skipping the optimizer step",
                iteration,
            )
        else:
            losses = sum(
                loss
                for loss_dict in all_task_loss_dict
                for loss in loss_dict.values()
            )

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(all_task_loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())
            meters.update(loss=losses_reduced, **loss_dict_reduced)

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            scheduler.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (max_iter - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iter:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.0f}",
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]["lr"],
                    # bytes -> MiB
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                )
            )
        if iteration % checkpoint_period == 0:
            checkpointer.save("model_{:07d}".format(iteration), **arguments)
        if iteration == max_iter:
            checkpointer.save("model_final", **arguments)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            # Guard the divisor so an empty loader does not turn the final
            # summary into a ZeroDivisionError.
            total_time_str, total_training_time / max(max_iter, 1)
        )
    )