Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import functools
import os
import tempfile

import torch
def spawn_and_init(fn, world_size, args=None):
    """Spawn ``world_size`` worker processes that each run ``fn`` in a process group.

    Every worker executes ``init_and_run(fn, args, rank, world_size, init_file)``,
    which initialises ``torch.distributed`` through a shared ``file://`` store
    and then calls ``fn(rank, group, *args)``.

    Args:
        fn: Callable invoked in every worker as ``fn(rank, group, *args)``.
        world_size: Number of processes to spawn.
        args: Extra positional arguments forwarded to ``fn``; ``None`` means
            no extra arguments.

    Raises:
        Propagates any exception raised by ``torch.multiprocessing.spawn``
        (e.g. when a worker process fails).
    """
    if args is None:
        args = ()
    # Only the path is needed as the rendezvous file; the handle is closed right
    # away and the file is left in place (empty) for the workers to use.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        init_file = tmp_file.name
    try:
        torch.multiprocessing.spawn(
            fn=functools.partial(init_and_run, fn, args),
            args=(world_size, init_file),
            nprocs=world_size,
            join=True,
        )
    finally:
        # The file is not auto-deleted because, per the PyTorch docs, the
        # file-based store may remove it itself. Remove whatever is left so the
        # file does not leak, including when spawn raises.
        with contextlib.suppress(FileNotFoundError):
            os.remove(init_file)
def distributed_init(rank, world_size, tmp_file, backend="nccl"):
    """Initialise the default ``torch.distributed`` process group for one worker.

    Args:
        rank: Rank of this process. When CUDA is available it is also used as
            this process's CUDA device index.
        world_size: Total number of processes in the group.
        tmp_file: Path of the shared file used for ``file://`` rendezvous.
            Per the PyTorch docs it must be empty or not exist yet.
        backend: ``torch.distributed`` backend name. Defaults to ``"nccl"``
            (the previous hard-coded value); ``"gloo"`` allows CPU-only runs.
    """
    if torch.cuda.is_available():
        # Bind this process to its own GPU before the process group is created,
        # so CUDA/NCCL state created later lands on this rank's device rather
        # than the default one. Skipped on CPU-only machines, where set_device
        # would raise.
        torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=backend,
        init_method=f"file://{tmp_file}",
        world_size=world_size,
        rank=rank,
    )
def init_and_run(fn, args, rank, world_size, tmp_file):
    """Per-worker entry point: set up distributed state, then run ``fn``.

    Initialises the default process group through ``distributed_init``, creates
    a group with ``torch.distributed.new_group()`` called with default
    arguments, and finally calls ``fn(rank, group, *args)``.
    """
    distributed_init(rank, world_size, tmp_file)
    fn(rank, torch.distributed.new_group(), *args)
def objects_are_equal(a, b) -> bool:
    """Recursively compare two (possibly nested) objects for exact equality.

    The two objects must have exactly the same type (a subclass does not
    match its base). Beyond that:

    * ``dict``: same key set and recursively equal values.
    * ``list`` / ``tuple``: same length and recursively equal items, in order.
    * ``set``: compared with ``==`` (membership, not iteration order), the same
      way ``frozenset`` and other non-container types are.
    * ``torch.Tensor``: same size, dtype and device, and ``torch.equal``.
    * anything else: ``a == b``.

    Returns:
        Whether the objects are equal under the rules above. The container and
        tensor branches return a plain ``bool``; other types return whatever
        ``a == b`` returns.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        return a.keys() == b.keys() and all(
            objects_are_equal(a[k], b[k]) for k in a
        )
    if isinstance(a, (list, tuple)):
        return len(a) == len(b) and all(
            objects_are_equal(x, y) for x, y in zip(a, b)
        )
    if isinstance(a, set):
        # Equal sets can iterate in different orders (order depends on
        # insertion history), so pairing elements with zip() is unreliable.
        return a == b
    if torch.is_tensor(a):
        # Device is checked before torch.equal so tensors on different devices
        # are reported unequal instead of being compared element-wise;
        # torch.equal yields a Python bool, matching the annotation.
        return (
            a.size() == b.size()
            and a.dtype == b.dtype
            and a.device == b.device
            and torch.equal(a, b)
        )
    return a == b