YourMT3 / amt /src /tests /model /ops_test.py
mimbres's picture
.
a03c9b4
raw
history blame
491 Bytes
import unittest
import torch
import numpy as np
from model.ops import minmax_normalize
class TestMinMaxNormalize(unittest.TestCase):
    """Unit tests for ``minmax_normalize`` from ``model.ops``."""

    def test_minmax_normalize(self):
        """Check a small tensor is rescaled so its min maps to 0 and max to 1.

        The input holds the values 1..6 in a (1, 2, 3) tensor. The expected
        output corresponds to ``(x - 1) / 5`` with the minimum and maximum
        taken over the whole tensor (not per row).
        """
        inputs = torch.tensor([[
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
        ]])
        expected = torch.tensor([[
            [0.0, 0.2, 0.4],
            [0.6, 0.8, 1.0],
        ]])

        actual = minmax_normalize(inputs)

        # Compare as NumPy arrays; two decimals of precision is sufficient
        # for this check and tolerates small floating-point deviations.
        np.testing.assert_almost_equal(
            actual.numpy(),
            expected.numpy(),
            decimal=2,
        )
# Run the test suite only when this file is executed directly as a script;
# importing the module (e.g. via a test runner) must not trigger it.
if __name__ == '__main__':
    unittest.main()