Spaces:
Runtime error
Runtime error
File size: 2,086 Bytes
ee21b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from torch import nn
from fairseq.distributed import ModuleProxyWrapper
from .utils import objects_are_equal
class MockDDPWrapper(nn.Module):
    """Minimal stand-in exposing a DistributedDataParallel-like interface.

    The wrapped object is stored under the ``module`` attribute and every
    forward call is delegated to it unchanged.
    """

    def __init__(self, module):
        """Store *module* as ``self.module``."""
        super().__init__()
        self.module = module

    def forward(self, x):
        """Call the wrapped module on *x* and return its result."""
        wrapped_output = self.module(x)
        return wrapped_output
class Model(nn.Module):
    """Toy model used by the tests: one linear layer plus a plain attribute."""

    def __init__(self) -> None:
        """Create the ``xyz`` string attribute and a 5 -> 10 linear layer."""
        super().__init__()
        # Plain (non-tensor) attribute, read back through ``get_xyz``.
        self.xyz = "hello"
        self.linear = nn.Linear(in_features=5, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to *x*."""
        projected = self.linear(x)
        return projected

    def get_xyz(self) -> str:
        """Return the current value of ``self.xyz``."""
        return self.xyz
class TestModuleProxyWrapper(unittest.TestCase):
    """Tests for ``ModuleProxyWrapper`` applied on top of a DDP-style wrapper."""

    def _get_module(self):
        """Build the objects under test.

        Returns:
            A ``(wrapped_module, module)`` tuple where ``module`` is a fresh
            ``Model`` and ``wrapped_module`` is that model wrapped first in
            ``MockDDPWrapper`` and then in ``ModuleProxyWrapper``.
        """
        module = Model()
        wrapped_module = ModuleProxyWrapper(MockDDPWrapper(module))
        return wrapped_module, module

    @staticmethod
    def _assert_tensors_close(actual, expected):
        """Assert that two tensors are numerically close.

        ``torch.testing.assert_allclose`` is deprecated in favour of
        ``torch.testing.assert_close``; use the newer function when this
        torch build provides it and fall back to the old one otherwise.
        """
        compare = getattr(torch.testing, "assert_close", None)
        if compare is None:
            compare = torch.testing.assert_allclose
        compare(actual, expected)

    def test_getattr_forwarding(self):
        """Attribute reads go through the proxy; writes stay on the proxy."""
        wrapped_module, module = self._get_module()
        self.assertEqual(module.xyz, "hello")
        self.assertEqual(module.get_xyz(), "hello")
        self.assertEqual(wrapped_module.xyz, "hello")

        wrapped_module.xyz = "world"
        self.assertEqual(wrapped_module.xyz, "world")
        # The innermost module is expected to keep its original value.
        self.assertEqual(module.get_xyz(), "hello")

    def test_state_dict(self):
        """The proxy's ``state_dict`` equals the bare model's ``state_dict``."""
        wrapped_module, module = self._get_module()
        self.assertTrue(
            objects_are_equal(wrapped_module.state_dict(), module.state_dict())
        )

    def test_load_state_dict(self):
        """The proxy accepts a state dict produced by the bare model."""
        wrapped_module, module = self._get_module()
        wrapped_module.load_state_dict(module.state_dict())
        inputs = torch.rand(4, 5)
        self._assert_tensors_close(wrapped_module(inputs), module(inputs))

    def test_forward(self):
        """Calling the proxy gives the same output as calling the bare model."""
        wrapped_module, module = self._get_module()
        inputs = torch.rand(4, 5)
        self._assert_tensors_close(wrapped_module(inputs), module(inputs))
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|