File size: 486 Bytes
d7a991a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmpose.core import WeightNormClipHook


def test_weight_norm_clip():
    """Check that WeightNormClipHook rescales an over-norm weight to max_norm.

    A bias-free 2x2 Linear layer has every weight entry set to 2, giving a
    Frobenius norm of 4, well above ``max_norm=1.0``. The test relies on the
    hook taking effect when the layer is called: after one forward pass the
    weight norm must be (approximately) 1.0.
    """
    # Fixed seed keeps both the layer's default init and the random input
    # reproducible across runs.
    torch.manual_seed(0)

    layer = torch.nn.Linear(2, 2, bias=False)
    torch.nn.init.constant_(layer.weight, 2.0)

    clip_hook = WeightNormClipHook(max_norm=1.0)
    clip_hook.register(layer)

    inputs = torch.rand(1, 2, requires_grad=True)
    layer(inputs)  # output is irrelevant; the call is what triggers the hook

    clipped_norm = layer.weight.norm().item()
    np.testing.assert_almost_equal(clipped_norm, 1.0, decimal=6)