# Source: HaMeR/vendor/ViTPose/tests/test_regularization.py
# (commit d7a991a, "Initial commit", by geopavlakos)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmpose.core import WeightNormClipHook
def test_weight_norm_clip():
    """Check that ``WeightNormClipHook`` clips a module's weight norm.

    A bias-free ``Linear(2, 2)`` layer has every weight entry set to 2, so
    the Frobenius norm of its weight starts at 4.0, well above
    ``max_norm=1.0``. After the hook is registered and one forward pass is
    run, the weight norm is expected to equal ``max_norm`` (to 6 decimals).
    """
    torch.manual_seed(0)  # make the random input below reproducible
    module = torch.nn.Linear(2, 2, bias=False)
    module.weight.data.fill_(2)
    # Sanity check: the initial norm must exceed max_norm, otherwise the
    # final assertion could pass without any clipping having happened.
    np.testing.assert_almost_equal(
        module.weight.norm().item(), 4.0, decimal=6)

    WeightNormClipHook(max_norm=1.0).register(module)
    x = torch.rand(1, 2).requires_grad_()
    # Clipping is expected to be applied as part of the forward pass; the
    # output itself is not inspected.
    _ = module(x)

    weight_norm = module.weight.norm().item()
    np.testing.assert_almost_equal(weight_norm, 1.0, decimal=6)