# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import sys
import unittest

import torch

from fairseq.distributed import utils as dist_utils

from .utils import objects_are_equal, spawn_and_init


class DistributedTest(unittest.TestCase):
    def setUp(self):
        # These tests exercise NCCL collectives, so they require CUDA,
        # a non-Windows platform, and at least two GPUs.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")


class TestBroadcastObject(DistributedTest):
    def test_str(self):
        spawn_and_init(
            functools.partial(
                TestBroadcastObject._test_broadcast_object, "hello world"
            ),
            world_size=2,
        )

    def test_tensor(self):
        spawn_and_init(
            functools.partial(
                TestBroadcastObject._test_broadcast_object,
                torch.rand(5),
            ),
            world_size=2,
        )

    def test_complex(self):
        spawn_and_init(
            functools.partial(
                TestBroadcastObject._test_broadcast_object,
                {
                    "a": "1",
                    "b": [2, torch.rand(2, 3), 3],
                    "c": (torch.rand(2, 3), 4),
                    "d": {5, torch.rand(5)},
                    "e": torch.rand(5),
                    "f": torch.rand(5).int().cuda(),
                },
            ),
            world_size=2,
        )

    @staticmethod
    def _test_broadcast_object(ref_obj, rank, group):
        # Rank 0 supplies the reference object; every other rank passes None
        # and should receive an identical copy from the broadcast.
        obj = dist_utils.broadcast_object(
            ref_obj if rank == 0 else None, src_rank=0, group=group
        )
        assert objects_are_equal(ref_obj, obj)


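# A minimal standalone sketch (not part of the original tests), assuming a
# process group has already been initialized, e.g. by spawn_and_init above:
# rank 0 supplies the payload and the other ranks pass None, so after the call
# every rank holds an identical copy. The function name is illustrative only.
def _broadcast_sketch(rank, group):
    payload = {"step": 10, "weights": torch.rand(3)} if rank == 0 else None
    return dist_utils.broadcast_object(payload, src_rank=0, group=group)

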
class TestAllGatherList(DistributedTest):
    def test_str_equality(self):
        spawn_and_init(
            functools.partial(
                TestAllGatherList._test_all_gather_list_equality,
                "hello world",
            ),
            world_size=2,
        )

    def test_tensor_equality(self):
        spawn_and_init(
            functools.partial(
                TestAllGatherList._test_all_gather_list_equality,
                torch.rand(5),
            ),
            world_size=2,
        )

    def test_complex_equality(self):
        spawn_and_init(
            functools.partial(
                TestAllGatherList._test_all_gather_list_equality,
                {
                    "a": "1",
                    "b": [2, torch.rand(2, 3), 3],
                    "c": (torch.rand(2, 3), 4),
                    "d": {5, torch.rand(5)},
                    "e": torch.rand(5),
                    "f": torch.rand(5).int(),
                },
            ),
            world_size=2,
        )

    @staticmethod
    def _test_all_gather_list_equality(ref_obj, rank, group):
        # Every rank gathers the same object, so each gathered copy should
        # compare equal to the local reference.
        objs = dist_utils.all_gather_list(ref_obj, group)
        for obj in objs:
            assert objects_are_equal(ref_obj, obj)

    def test_rank_tensor(self):
        spawn_and_init(
            TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2
        )

    @staticmethod
    def _test_all_gather_list_rank_tensor(rank, group):
        # The gathered list is ordered by rank, so position i should hold the
        # tensor produced by rank i.
        obj = torch.tensor([rank])
        objs = dist_utils.all_gather_list(obj, group)
        for i, obj in enumerate(objs):
            assert obj.item() == i


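# A minimal standalone sketch (not part of the original tests), assuming an
# initialized process group: all_gather_list collects an arbitrary Python
# object from every rank and returns a list ordered by rank, as checked by
# _test_all_gather_list_rank_tensor above. The function name is illustrative
# only.
def _all_gather_sketch(rank, group):
    local = {"rank": rank, "value": torch.tensor([rank])}
    return dist_utils.all_gather_list(local, group)

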
if __name__ == "__main__":
    unittest.main()