Spaces:
Build error
Build error
File size: 486 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmpose.core import WeightNormClipHook
def test_weight_norm_clip():
    """Check that WeightNormClipHook rescales an over-sized weight to ``max_norm``.

    A bias-free 2x2 ``Linear`` layer has every weight entry set to 2, which
    gives a norm of sqrt(4 * 2**2) = 4. After the hook is registered with
    ``max_norm=1.0`` and one forward pass is run, the weight norm is expected
    to equal 1.0 to six decimal places.
    """
    # Seed first so the Linear initialisation and the random input below are
    # reproducible from run to run.
    torch.manual_seed(0)
    linear = torch.nn.Linear(2, 2, bias=False)
    torch.nn.init.constant_(linear.weight, 2)

    clip_hook = WeightNormClipHook(max_norm=1.0)
    clip_hook.register(linear)

    # The hook is expected to act during the forward pass; the layer's output
    # itself is not inspected.
    inputs = torch.rand(1, 2, requires_grad=True)
    linear(inputs)

    clipped_norm = float(linear.weight.norm())
    np.testing.assert_almost_equal(clipped_norm, 1.0, decimal=6)
|