File size: 7,041 Bytes
df1ad02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""
WaveGRU model: melspectrogram => mu-law encoded waveform
"""

import jax
import jax.numpy as jnp
import pax


class ReLU(pax.Module):
    """
    pax module wrapper around ``jax.nn.relu``.
    """

    def __call__(self, x):
        """Apply the rectified linear unit to ``x`` and return the result."""
        rectified = jax.nn.relu(x)
        return rectified


def dilated_residual_conv_block(dim, kernel, stride, dilation):
    """
    Build a two-stage conv stack whose first conv is dilated, which
    enlarges the receptive field.

    Stage 1: Conv1D(kernel, stride, dilation) >> LayerNorm >> ReLU
    Stage 2: Conv1D(kernel size 1)            >> LayerNorm >> ReLU

    Both convs map `dim` channels to `dim` channels, use "VALID" padding
    and have no bias. The returned block does not add a skip connection
    itself; any residual sum is up to the caller.
    """

    def norm_then_relu():
        # Fresh LayerNorm + ReLU pair; each stage owns its own instances.
        return [pax.LayerNorm(dim, -1, True, True), ReLU()]

    # NOTE: modules are created in the same order as they are applied.
    dilated_conv = pax.Conv1D(
        dim, dim, kernel, stride, dilation, "VALID", with_bias=False
    )
    first_stage = [dilated_conv, *norm_then_relu()]

    pointwise_conv = pax.Conv1D(dim, dim, 1, 1, 1, "VALID", with_bias=False)
    second_stage = [pointwise_conv, *norm_then_relu()]

    return pax.Sequential(*first_stage, *second_stage)


def tile_1d(x, factor):
    """
    Repeat every time step `factor` times in a row.

    Takes a tensor of shape (N, L, D) and returns one of shape
    (N, L * factor, D) where output[:, i] == x[:, i // factor].
    """
    batch, length, channels = x.shape
    # Add a repeat axis right after the time axis and fill it with copies ...
    with_repeat_axis = jnp.broadcast_to(
        x[:, :, None, :], (batch, length, factor, channels)
    )
    # ... then fold the repeat axis back into the time axis.
    return jnp.reshape(with_repeat_axis, (batch, length * factor, channels))


def up_block(dim, factor):
    """
    Tile >> Conv >> LayerNorm >> ReLU

    The input is first stretched along the time axis by `factor` (each
    step repeated `factor` times), then passed through a bias-free
    Conv1D with kernel size 2 * factor, stride 1 and "VALID" padding,
    followed by LayerNorm and ReLU. The channel count stays at `dim`.
    """
    stages = (
        lambda x: tile_1d(x, factor),
        pax.Conv1D(dim, dim, 2 * factor, stride=1, padding="VALID", with_bias=False),
        pax.LayerNorm(dim, -1, True, True),
        ReLU(),
    )
    return pax.Sequential(*stages)


class Upsample(pax.Module):
    """
    Upsample melspectrogram to match raw audio sample rate.

    Pipeline:
      1. 1x1 conv (input_dim -> 512 channels) followed by LayerNorm.
      2. Five dilated conv blocks (kernel 3, dilation 1, 2, 4, 8, 16), each
         summed with a center-cropped copy of its own input.
      3. One `up_block` per factor in `upsample_factors[:-1]`.
      4. Plain tiling by the last factor, `upsample_factors[-1]`.
    """

    def __init__(self, input_dim, upsample_factors):
        """
        Arguments:
            input_dim: number of channels of the input melspectrogram.
            upsample_factors: non-empty sequence of integer factors. All
                but the last one get a learned `up_block`; the last one is
                applied as plain tiling.
        """
        super().__init__()
        self.input_conv = pax.Sequential(
            pax.Conv1D(input_dim, 512, 1, with_bias=False),
            pax.LayerNorm(512, -1, True, True),
        )
        self.upsample_factors = upsample_factors
        self.dilated_convs = [
            dilated_residual_conv_block(512, 3, 1, 2**i) for i in range(5)
        ]
        self.up_factors = upsample_factors[:-1]
        self.up_blocks = [up_block(512, x) for x in self.up_factors]
        self.final_tile = upsample_factors[-1]

    def __call__(self, x):
        """
        Map a melspectrogram to per-sample conditioning features.

        `x` is indexed as (batch, time, channel) below.
        NOTE(review): presumably (N, L, input_dim) -- confirm against caller.
        """
        x = self.input_conv(x)
        for residual in self.dilated_convs:
            y = residual(x)
            # The "VALID" convs make `y` shorter than `x`; center-crop `x`
            # to the length of `y` before the residual sum. Slicing by an
            # explicit length (instead of `pad:-pad`) stays correct even
            # when nothing has to be cropped (`pad == 0`).
            pad = (x.shape[1] - y.shape[1]) // 2
            x = x[:, pad : pad + y.shape[1], :] + y

        for f in self.up_blocks:
            x = f(x)

        # The last factor is applied without any learned layer.
        x = tile_1d(x, self.final_tile)
        return x


class Pruner(pax.Module):
    """
    Base class for pruners.

    Provides the sparsity schedule (`compute_sparsity`) and the 4x4
    block-magnitude mask computation (`prune`). Subclasses own the masks.
    """

    def __init__(self, update_freq=500):
        """
        Arguments:
            update_freq: stored on the instance; not read inside this class.
                NOTE(review): presumably the number of training steps
                between mask updates -- confirm against the training loop.
        """
        super().__init__()
        self.update_freq = update_freq

    def compute_sparsity(self, step):
        """
        Two-stages pruning schedule: target sparsity at training `step`.

        Stage 1: cubic ramp from 0 to 0.5 between step 1_000 and 301_000.
        Stage 2: four further cubic ramps of 0.1 each; the i-th one runs
                 for 100_000 steps starting at step 401_000 + i * 200_000.

        The result is 0 before step 1_000 and saturates at 0.9.
        """
        # `jnp.clip` bounds are passed positionally: the `a_min`/`a_max`
        # keyword names are deprecated in recent JAX releases, while the
        # positional form is accepted by old and new versions alike.
        t = jnp.power(1 - (step * 1.0 - 1_000) / 300_000, 3)
        z = 0.5 * jnp.clip(1.0 - t, 0, 1)
        for i in range(4):
            t = jnp.power(1 - (step * 1.0 - 1_000 - 400_000 - i * 200_000) / 100_000, 3)
            z = z + 0.1 * jnp.clip(1 - t, 0, 1)
        return z

    def prune(self, step, weights):
        """
        Return a boolean keep-mask with the same shape as `weights`.

        `weights` must be 2-D with both dimensions divisible by 4. The
        matrix is cut into 4x4 blocks, each block is scored by the sum of
        absolute values of its entries, and blocks scoring below the
        `compute_sparsity(step)` quantile of all block scores are masked
        out (False). Every entry of a block shares the block's decision.
        """
        z = self.compute_sparsity(step)
        x = weights
        H, W = x.shape
        # (H, W) -> (H/4, 4, W/4, 4): axes 1 and 3 index inside a block.
        x = x.reshape(H // 4, 4, W // 4, 4)
        x = jnp.abs(x)
        x = jnp.sum(x, axis=(1, 3), keepdims=True)
        # Threshold = z-quantile of the block scores.
        q = jnp.quantile(jnp.reshape(x, (-1,)), z)
        x = x >= q
        # Expand each block decision back to its 4x4 entries.
        x = jnp.tile(x, (1, 4, 1, 4))
        x = jnp.reshape(x, (H, W))
        return x


class GRUPruner(Pruner):
    """
    Pruner for a GRU: keeps one boolean mask per weight matrix
    (`xh_zr_fc.weight` and `xh_h_fc.weight`). True means "keep".
    """

    def __init__(self, gru, update_freq=500):
        super().__init__(update_freq=update_freq)
        # Start from all-True masks, i.e. nothing is pruned yet.
        self.xh_zr_fc_mask = jnp.ones_like(gru.xh_zr_fc.weight, dtype=bool)
        self.xh_h_fc_mask = jnp.ones_like(gru.xh_h_fc.weight, dtype=bool)

    def __call__(self, gru: pax.GRU):
        """
        Apply mask after an optimization step.

        Returns a GRU whose masked-out weights are set to zero.
        """
        zr_weight = gru.xh_zr_fc.weight
        gru = gru.replace_node(
            zr_weight, jnp.where(self.xh_zr_fc_mask, zr_weight, 0)
        )
        h_weight = gru.xh_h_fc.weight
        gru = gru.replace_node(h_weight, jnp.where(self.xh_h_fc_mask, h_weight, 0))
        return gru

    def update_mask(self, step, gru: pax.GRU):
        """
        Update internal masks.

        The two halves of `xh_zr_fc.weight` (split along axis 1) are pruned
        independently. New masks are multiplied into the stored ones, so an
        entry that has been pruned once stays pruned.
        """
        first_half, second_half = jnp.split(gru.xh_zr_fc.weight, 2, axis=1)
        first_keep = self.prune(step, first_half)
        second_keep = self.prune(step, second_half)
        zr_keep = jnp.concatenate((first_keep, second_keep), axis=1)
        self.xh_zr_fc_mask *= zr_keep

        h_keep = self.prune(step, gru.xh_h_fc.weight)
        self.xh_h_fc_mask *= h_keep


class LinearPruner(Pruner):
    """
    Pruner for a linear layer: keeps one boolean mask for `linear.weight`.
    True means "keep".
    """

    def __init__(self, linear, update_freq=500):
        super().__init__(update_freq=update_freq)
        # Start from an all-True mask, i.e. nothing is pruned yet.
        self.mask = jnp.ones_like(linear.weight, dtype=bool)

    def __call__(self, linear: pax.Linear):
        """
        Apply mask after an optimization step.

        Returns a linear layer whose masked-out weights are set to zero.
        """
        masked_weight = jnp.where(self.mask, linear.weight, 0)
        return linear.replace(weight=masked_weight)

    def update_mask(self, step, linear: pax.Linear):
        """
        Update internal masks.

        The new mask is multiplied into the stored one, so an entry that
        has been pruned once stays pruned.
        """
        keep = self.prune(step, linear.weight)
        self.mask *= keep


class WaveGRU(pax.Module):
    """
    WaveGRU vocoder model.

    An autoregressive GRU over 256 discrete waveform classes, conditioned
    on an upsampled melspectrogram. Also owns one pruner per prunable
    layer (the GRU and the two output linear layers).
    """

    def __init__(
        self, mel_dim=80, embed_dim=32, rnn_dim=512, upsample_factors=(5, 4, 3, 5)
    ):
        """
        Arguments:
            mel_dim: number of melspectrogram channels.
            embed_dim: size of the embedding of the previous sample.
            rnn_dim: GRU hidden size. NOTE: `Upsample` emits a hard-coded
                512 channels while the GRU input size is
                `embed_dim + rnn_dim`, so only `rnn_dim == 512` is
                shape-consistent.
            upsample_factors: forwarded to `Upsample`.
        """
        super().__init__()
        self.embed = pax.Embed(256, embed_dim)
        self.upsample = Upsample(input_dim=mel_dim, upsample_factors=upsample_factors)
        self.rnn = pax.GRU(embed_dim + rnn_dim, rnn_dim)
        self.o1 = pax.Linear(rnn_dim, rnn_dim)
        self.o2 = pax.Linear(rnn_dim, 256)
        self.gru_pruner = GRUPruner(self.rnn)
        self.o1_pruner = LinearPruner(self.o1)
        self.o2_pruner = LinearPruner(self.o2)

    def output(self, x):
        """Map GRU outputs to 256 logits: Linear >> ReLU >> Linear."""
        x = self.o1(x)
        x = jax.nn.relu(x)
        x = self.o2(x)
        return x

    @jax.jit
    def inference_step(self, rnn_state, mel, rng_key, x):
        """
        One inference step.

        Embeds the previous sample `x`, concatenates it with the
        conditioning frame `mel` on the last axis, advances the GRU, and
        samples the next value from the output logits.

        Returns (new rnn_state, rng key for the next step, sampled value).
        """
        x = self.embed(x)
        x = jnp.concatenate((x, mel), axis=-1)
        rnn_state, x = self.rnn(rnn_state, x)
        x = self.output(x)
        # One half of the split key is consumed here, the other is handed
        # back so that every step samples with a fresh key.
        rng_key, next_rng_key = jax.random.split(rng_key, 2)
        x = jax.random.categorical(rng_key, x, axis=-1)
        return rnn_state, next_rng_key, x

    def inference(self, mel, no_gru=False, seed=42):
        """
        Generate waveform from melspectrogram.

        Arguments:
            mel: input melspectrogram. The initial sample and GRU state are
                created for a batch of one, so a single utterance is
                expected.
            no_gru: when True, return the upsampled conditioning features
                and skip the autoregressive loop.
            seed: seed of the sampling random key.

        Returns the sampled values concatenated along axis 0, or the
        upsampled features when `no_gru` is True.
        """

        y = self.upsample(mel)
        if no_gru:
            return y
        # NOTE(review): 127 is presumably the mu-law "silence" mid-point
        # used as the start token -- confirm.
        x = jnp.array([127], dtype=jnp.int32)
        rnn_state = self.rnn.initial_state(1)
        output = []
        rng_key = jax.random.PRNGKey(seed)
        for i in range(y.shape[1]):
            rnn_state, rng_key, x = self.inference_step(rnn_state, y[:, i], rng_key, x)
            output.append(x)
        x = jnp.concatenate(output, axis=0)
        return x

    def __call__(self, mel, x):
        """
        Teacher-forced forward pass.

        Embeds the waveform tokens `x`, center-crops them along axis 1 to
        the length of the upsampled melspectrogram, runs the GRU over the
        concatenated features and returns the output logits.
        """
        x = self.embed(x)
        y = self.upsample(mel)
        # Center-crop `x` to the length of `y`. Slicing by an explicit
        # length (instead of `pad_left:-pad_right`) stays correct when both
        # already have the same length: `x[:, 0:-0]` would be empty.
        pad_left = (x.shape[1] - y.shape[1]) // 2
        x = x[:, pad_left : pad_left + y.shape[1]]
        x = jnp.concatenate((x, y), axis=-1)
        _, x = pax.scan(
            self.rnn,
            self.rnn.initial_state(x.shape[0]),
            x,
            time_major=False,
        )
        x = self.output(x)
        return x