show / mmpose-0.29.0 /tests /test_regularization.py
camenduru's picture
thanks to show ❤
3bbb319
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmpose.core import WeightNormClipHook
def test_weight_norm_clip():
torch.manual_seed(0)
module = torch.nn.Linear(2, 2, bias=False)
module.weight.data.fill_(2)
WeightNormClipHook(max_norm=1.0).register(module)
x = torch.rand(1, 2).requires_grad_()
_ = module(x)
weight_norm = module.weight.norm().item()
np.testing.assert_almost_equal(weight_norm, 1.0, decimal=6)