# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import logging
import os

import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel

logger = logging.getLogger(__name__)

rank = 0
world_size = 1


def init(args):
    """init.

    Initialize DDP using the given rendezvous file.
    """
    global rank, world_size
    if args.ddp:
        assert args.rank is not None and args.world_size is not None
        rank = args.rank
        world_size = args.world_size
    if world_size == 1:
        return
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=args.ddp_backend,
        init_method='file://' + os.path.abspath(args.rendezvous_file),
        world_size=world_size,
        rank=rank)
    logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)


def average(metrics, count=1.):
    """average.

    Average all the relevant metrics across processes.
    `metrics` should be a 1D float32 vector. Returns the average of `metrics`
    over all hosts. You can use `count` to control the weight of each worker.
    """
    if world_size == 1:
        return metrics
    # Append a 1 so that, after scaling by `count` and summing across workers,
    # the last entry holds the total weight used for normalization.
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
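
# Usage sketch: `metrics` is any iterable of floats and `count` weights this
# worker's contribution (for instance its number of examples). The names
# `local_loss`, `local_acc` and `num_examples` are illustrative.
#
#     avg_loss, avg_acc = average([local_loss, local_acc], count=num_examples)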


def wrap(model):
    """wrap.

    Wrap a model with DDP if distributed training is enabled.
    """
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())
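
# Usage sketch: wrap the model once per process after `init`; DDP then
# synchronizes gradients during `backward`. Custom attributes of the original
# model must be reached through `.module` when DDP is active. `criterion`,
# `noisy` and `clean` below are illustrative.
#
#     model = wrap(model.cuda())
#     loss = criterion(model(noisy), clean)
#     loss.backward()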


def barrier():
    """Synchronize all processes; a no-op when not running distributed."""
    if world_size > 1:
        torch.distributed.barrier()
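
# Usage sketch: a common pattern is to let rank 0 write a checkpoint while the
# other workers wait, then continue together (`save_checkpoint` is hypothetical).
#
#     if rank == 0:
#         save_checkpoint(model)
#     barrier()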


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """loader.

    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.

    :param dataset: the dataset to be parallelized
    :param args: relevant args for the loader
    :param shuffle: shuffle examples
    :param klass: loader class
    :param kwargs: relevant args
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler would otherwise replicate some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
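
# Usage sketch: training data gets `shuffle=True` so DistributedSampler shards
# and shuffles it, while validation data falls back to the manual `Subset`
# shard. `train_set`, `valid_set` and the keyword values are illustrative.
#
#     train_loader = loader(train_set, batch_size=16, shuffle=True, num_workers=4)
#     valid_loader = loader(valid_set, batch_size=1, shuffle=False)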