Spaces:
Build error
Build error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
import torch | |
from mmpose.core import WeightNormClipHook | |
def test_weight_norm_clip(): | |
torch.manual_seed(0) | |
module = torch.nn.Linear(2, 2, bias=False) | |
module.weight.data.fill_(2) | |
WeightNormClipHook(max_norm=1.0).register(module) | |
x = torch.rand(1, 2).requires_grad_() | |
_ = module(x) | |
weight_norm = module.weight.norm().item() | |
np.testing.assert_almost_equal(weight_norm, 1.0, decimal=6) | |