File size: 5,284 Bytes
9d0d223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import random

import torch

from audiocraft.adversarial import (
    AdversarialLoss,
    get_adv_criterion,
    get_real_criterion,
    get_fake_criterion,
    FeatureMatchingLoss,
    MultiScaleDiscriminator,
)


class TestAdversarialLoss:
    """End-to-end checks of ``AdversarialLoss`` wrapped around a ``MultiScaleDiscriminator``."""

    def test_adversarial_single_multidiscriminator(self):
        """Without a feature-matching loss, the feature term returned must be zero."""
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(
            adv.parameters(),
            lr=1e-4,
        )
        # Distinct names for the criteria so they are not shadowed by the
        # loss *values* computed further down.
        adv_criterion = get_adv_criterion('mse')
        real_criterion = get_real_criterion('mse')
        fake_criterion = get_fake_criterion('mse')
        adv_loss = AdversarialLoss(adv, optimizer, adv_criterion, real_criterion, fake_criterion)

        B, C, T = 4, 1, random.randint(1000, 5000)
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)

        # train_adv must return a one-element tensor (``.item()`` requires that).
        disc_loss = adv_loss.train_adv(fake, real)
        assert isinstance(disc_loss, torch.Tensor) and isinstance(disc_loss.item(), float)

        gen_loss, loss_feat = adv_loss(fake, real)
        assert isinstance(gen_loss, torch.Tensor) and isinstance(gen_loss.item(), float)
        # We did not pass a feature-matching loss, so that term is zero.
        assert loss_feat.item() == 0.

    def test_adversarial_feat_loss(self):
        """With a ``FeatureMatchingLoss``, both returned terms are one-element tensors."""
        adv = MultiScaleDiscriminator()
        optimizer = torch.optim.Adam(
            adv.parameters(),
            lr=1e-4,
        )
        adv_criterion = get_adv_criterion('mse')
        real_criterion = get_real_criterion('mse')
        fake_criterion = get_fake_criterion('mse')
        feat_loss = FeatureMatchingLoss()
        adv_loss = AdversarialLoss(
            adv, optimizer, adv_criterion, real_criterion, fake_criterion, feat_loss)

        B, C, T = 4, 1, random.randint(1000, 5000)
        real = torch.randn(B, C, T)
        fake = torch.randn(B, C, T)

        gen_loss, loss_feat = adv_loss(fake, real)

        assert isinstance(gen_loss, torch.Tensor) and isinstance(gen_loss.item(), float)
        # Check the feature-matching term itself, not the generator loss again.
        assert isinstance(loss_feat, torch.Tensor) and isinstance(loss_feat.item(), float)
        # real and fake are independent random signals, so their discriminator
        # feature maps should differ and the matching term should be non-zero.
        assert loss_feat.item() > 0.


class TestGeneratorAdversarialLoss:
    """Generator-side adversarial criteria evaluated on small hand-checked inputs."""

    def test_hinge_generator_adv_loss(self):
        hinge = get_adv_criterion(loss_type='hinge')

        empty = torch.randn(1, 2, 0)
        ramp = torch.FloatTensor([1.0, 2.0, 3.0])

        # A tensor with no elements must yield a zero loss.
        assert hinge(empty).item() == 0.0
        # -2.0 is consistent with -mean([1, 2, 3]).
        assert hinge(ramp).item() == -2.0

    def test_mse_generator_adv_loss(self):
        mse = get_adv_criterion(loss_type='mse')

        empty = torch.randn(1, 2, 0)
        ones = torch.FloatTensor([1.0, 1.0, 1.0])
        mixed = torch.FloatTensor([2.0, 5.0, 5.0])

        # Expected values are consistent with mean((x - 1) ** 2):
        # all-ones gives 0, and (1 + 16 + 16) / 3 == 11 for ``mixed``.
        cases = [(empty, 0.0), (ones, 0.0), (mixed, 11.0)]
        for logits, expected in cases:
            assert mse(logits).item() == expected


class TestDiscriminatorAdversarialLoss:
    """Discriminator-side criteria: the sum of the fake-sample and real-sample terms."""

    def _disc_loss(self, loss_type: str, fake: torch.Tensor, real: torch.Tensor):
        """Return ``fake_criterion(fake) + real_criterion(real)`` for ``loss_type``."""
        real_term = get_real_criterion(loss_type)
        fake_term = get_fake_criterion(loss_type)
        return fake_term(fake) + real_term(real)

    def test_hinge_discriminator_adv_loss(self):
        zeros = torch.FloatTensor([0.0, 0.0, 0.0])
        ramp = torch.FloatTensor([1.0, 2.0, 3.0])

        assert self._disc_loss('hinge', zeros, zeros).item() == 2.0
        assert self._disc_loss('hinge', ramp, ramp).item() == 3.0

    def test_mse_discriminator_adv_loss(self):
        zeros = torch.FloatTensor([0.0, 0.0, 0.0])
        ones = torch.FloatTensor([1.0, 1.0, 1.0])

        assert self._disc_loss('mse', zeros, zeros).item() == 1.0
        assert self._disc_loss('mse', ones, zeros).item() == 2.0


class TestFeatureMatchingLoss:
    """``FeatureMatchingLoss`` on identical, mismatched and hand-computed feature maps."""

    def test_features_matching_loss_base(self):
        criterion = FeatureMatchingLoss()
        fmap = torch.randn(1, 2, random.randrange(1, 100_000))

        result = criterion([fmap], [fmap])
        assert isinstance(result, torch.Tensor)
        # Identical feature maps match perfectly.
        assert result.item() == 0.0

    def test_features_matching_loss_raises_exception(self):
        criterion = FeatureMatchingLoss()
        n = random.randrange(1, 100_000)
        fmap = torch.randn(1, 2, n)
        longer = torch.randn(1, 2, n + 1)

        # Empty lists, lists of different lengths and maps of different
        # shapes must all be rejected with an AssertionError.
        bad_inputs = [([], []), ([fmap], [fmap, fmap]), ([fmap], [longer])]
        for fake_maps, real_maps in bad_inputs:
            with pytest.raises(AssertionError):
                criterion(fake_maps, real_maps)

    def test_features_matching_loss_output(self):
        plain = FeatureMatchingLoss(normalize=False)
        averaged = FeatureMatchingLoss(normalize=True)

        n = random.randrange(1, 100_000)
        fmap_a = torch.randn(1, 2, n)
        fmap_b = torch.randn(1, 2, n)

        # Identical lists of maps give zero loss with or without normalization.
        for criterion in (plain, averaged):
            assert criterion([fmap_a, fmap_b], [fmap_a, fmap_b]).item() == 0.0

        x = torch.FloatTensor([1.0, 2.0, 3.0])
        y = torch.FloatTensor([2.0, 10.0, 3.0])

        # Per-map value 3.0 is consistent with mean(|x - y|) = (1 + 8 + 0) / 3.
        # Unnormalized, repeated maps add up ...
        assert plain([x], [y]).item() == 3.0
        assert plain([x, x], [y, y]).item() == 6.0

        # ... normalized, the result is consistent with averaging over the maps.
        assert averaged([x], [y]).item() == 3.0
        assert averaged([x, x], [y, y]).item() == 3.0