Spaces:
Runtime error
Runtime error
NTT123
commited on
Commit
·
41ba53f
1
Parent(s):
faa16c1
use a customized gru.
Browse files- app.py +3 -6
- extract_model.py +5 -0
- inference.py +1 -2
- wavegru.ckpt +2 -2
- wavegru.py +94 -60
- wavegru.yaml +1 -1
- wavegru_cpp.py +11 -47
- wavegru_mod.cc +29 -42
- wavegru_mod.so +2 -2
app.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
## build wavegru-cpp
|
|
|
4 |
# os.system("./bazelisk-linux-amd64 clean --expunge")
|
5 |
# os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
|
6 |
|
@@ -18,14 +17,12 @@ alphabet, tacotron_net, tacotron_config = load_tacotron_model(
|
|
18 |
wavegru_config, wavegru_net = load_wavegru_net("./wavegru.yaml", "./wavegru.ckpt")
|
19 |
|
20 |
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
|
21 |
-
wavecpp = load_wavegru_cpp(wave_cpp_weight_mask)
|
22 |
|
23 |
|
24 |
def speak(text):
|
25 |
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
26 |
-
print(mel.shape)
|
27 |
y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
|
28 |
-
print(y.shape)
|
29 |
return 24_000, y
|
30 |
|
31 |
|
@@ -38,7 +35,7 @@ gr.Interface(
|
|
38 |
examples=[
|
39 |
"this is a test!",
|
40 |
"October arrived, spreading a damp chill over the grounds and into the castle. Madam Pomfrey, the nurse, was kept busy by a sudden spate of colds among the staff and students.",
|
41 |
-
"Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans",
|
42 |
],
|
43 |
outputs="audio",
|
44 |
title=title,
|
|
|
|
|
|
|
1 |
## build wavegru-cpp
|
2 |
+
# import os
|
3 |
# os.system("./bazelisk-linux-amd64 clean --expunge")
|
4 |
# os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
|
5 |
|
|
|
17 |
wavegru_config, wavegru_net = load_wavegru_net("./wavegru.yaml", "./wavegru.ckpt")
|
18 |
|
19 |
wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
|
20 |
+
wavecpp = load_wavegru_cpp(wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1])
|
21 |
|
22 |
|
23 |
def speak(text):
|
24 |
mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
|
|
|
25 |
y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
|
|
|
26 |
return 24_000, y
|
27 |
|
28 |
|
|
|
35 |
examples=[
|
36 |
"this is a test!",
|
37 |
"October arrived, spreading a damp chill over the grounds and into the castle. Madam Pomfrey, the nurse, was kept busy by a sudden spate of colds among the staff and students.",
|
38 |
+
"Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans.",
|
39 |
],
|
40 |
outputs="audio",
|
41 |
title=title,
|
extract_model.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
|
3 |
+
dic = pickle.load(open("./wavegru.ckpt", "rb"))
|
4 |
+
del dic["optim_state_dict"]
|
5 |
+
pickle.dump(dic, open("./wavegru.ckpt", "wb"))
|
inference.py
CHANGED
@@ -49,7 +49,6 @@ def load_wavegru_net(config_file, model_file):
|
|
49 |
config = load_wavegru_config(config_file)
|
50 |
net = WaveGRU(
|
51 |
mel_dim=config["mel_dim"],
|
52 |
-
embed_dim=config["embed_dim"],
|
53 |
rnn_dim=config["rnn_dim"],
|
54 |
upsample_factors=config["upsample_factors"],
|
55 |
)
|
@@ -74,7 +73,7 @@ def mel_to_wav(net, netcpp, mel, config):
|
|
74 |
)
|
75 |
ft = wavegru_inference(net, mel)
|
76 |
ft = jax.device_get(ft[0])
|
77 |
-
wav = netcpp.inference(ft,
|
78 |
wav = np.array(wav)
|
79 |
wav = librosa.mu_expand(wav - 127, mu=255)
|
80 |
wav = librosa.effects.deemphasis(wav, coef=0.86)
|
|
|
49 |
config = load_wavegru_config(config_file)
|
50 |
net = WaveGRU(
|
51 |
mel_dim=config["mel_dim"],
|
|
|
52 |
rnn_dim=config["rnn_dim"],
|
53 |
upsample_factors=config["upsample_factors"],
|
54 |
)
|
|
|
73 |
)
|
74 |
ft = wavegru_inference(net, mel)
|
75 |
ft = jax.device_get(ft[0])
|
76 |
+
wav = netcpp.inference(ft, 0.9)
|
77 |
wav = np.array(wav)
|
78 |
wav = librosa.mu_expand(wav - 127, mu=255)
|
79 |
wav = librosa.effects.deemphasis(wav, coef=0.86)
|
wavegru.ckpt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c06310d989fd524359d5f3ecf8ea1dc146980bb594b7b90553d0d42a64c512d8
|
3 |
+
size 58039876
|
wavegru.py
CHANGED
@@ -2,9 +2,13 @@
|
|
2 |
WaveGRU model: melspectrogram => mu-law encoded waveform
|
3 |
"""
|
4 |
|
|
|
|
|
5 |
import jax
|
6 |
import jax.numpy as jnp
|
7 |
import pax
|
|
|
|
|
8 |
|
9 |
|
10 |
class ReLU(pax.Module):
|
@@ -37,16 +41,20 @@ def tile_1d(x, factor):
|
|
37 |
return x
|
38 |
|
39 |
|
40 |
-
def up_block(
|
41 |
"""
|
42 |
Tile >> Conv >> BatchNorm >> ReLU
|
43 |
"""
|
44 |
-
|
45 |
lambda x: tile_1d(x, factor),
|
46 |
-
pax.Conv1D(
|
47 |
-
|
48 |
-
|
|
|
49 |
)
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
class Upsample(pax.Module):
|
@@ -54,21 +62,24 @@ class Upsample(pax.Module):
|
|
54 |
Upsample melspectrogram to match raw audio sample rate.
|
55 |
"""
|
56 |
|
57 |
-
def __init__(self, input_dim, upsample_factors):
|
58 |
super().__init__()
|
59 |
self.input_conv = pax.Sequential(
|
60 |
-
pax.Conv1D(input_dim,
|
61 |
-
pax.LayerNorm(
|
62 |
)
|
63 |
self.upsample_factors = upsample_factors
|
64 |
self.dilated_convs = [
|
65 |
-
dilated_residual_conv_block(
|
66 |
]
|
67 |
self.up_factors = upsample_factors[:-1]
|
68 |
-
self.up_blocks = [up_block(
|
|
|
|
|
|
|
69 |
self.final_tile = upsample_factors[-1]
|
70 |
|
71 |
-
def __call__(self, x):
|
72 |
x = self.input_conv(x)
|
73 |
for residual in self.dilated_convs:
|
74 |
y = residual(x)
|
@@ -78,28 +89,55 @@ class Upsample(pax.Module):
|
|
78 |
for f in self.up_blocks:
|
79 |
x = f(x)
|
80 |
|
|
|
|
|
81 |
x = tile_1d(x, self.final_tile)
|
82 |
return x
|
83 |
|
84 |
|
85 |
-
class
|
86 |
"""
|
87 |
-
|
88 |
"""
|
89 |
|
90 |
-
|
|
|
|
|
|
|
91 |
super().__init__()
|
92 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
def compute_sparsity(self, step):
|
95 |
-
|
96 |
-
|
97 |
-
"""
|
98 |
-
t = jnp.power(1 - (step * 1.0 - 1_000) / 300_000, 3)
|
99 |
-
z = 0.5 * jnp.clip(1.0 - t, a_min=0, a_max=1)
|
100 |
-
for i in range(4):
|
101 |
-
t = jnp.power(1 - (step * 1.0 - 1_000 - 400_000 - i * 200_000) / 100_000, 3)
|
102 |
-
z = z + 0.1 * jnp.clip(1 - t, a_min=0, a_max=1)
|
103 |
return z
|
104 |
|
105 |
def prune(self, step, weights):
|
@@ -120,35 +158,32 @@ class Pruner(pax.Module):
|
|
120 |
|
121 |
|
122 |
class GRUPruner(Pruner):
|
123 |
-
def __init__(self, gru
|
124 |
-
super().__init__(
|
125 |
-
self.
|
126 |
-
self.xh_h_fc_mask = jnp.ones_like(gru.xh_h_fc.weight) == 1
|
127 |
|
128 |
def __call__(self, gru: pax.GRU):
|
129 |
"""
|
130 |
Apply mask after an optimization step
|
131 |
"""
|
132 |
-
|
133 |
-
gru = gru.replace_node(gru.
|
134 |
-
h_masked_weights = jnp.where(self.xh_h_fc_mask, gru.xh_h_fc.weight, 0)
|
135 |
-
gru = gru.replace_node(gru.xh_h_fc.weight, h_masked_weights)
|
136 |
return gru
|
137 |
|
138 |
def update_mask(self, step, gru: pax.GRU):
|
139 |
"""
|
140 |
Update internal masks
|
141 |
"""
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
self.
|
147 |
|
148 |
|
149 |
class LinearPruner(Pruner):
|
150 |
-
def __init__(self, linear
|
151 |
-
super().__init__(
|
152 |
self.mask = jnp.ones_like(linear.weight) == 1
|
153 |
|
154 |
def __call__(self, linear: pax.Linear):
|
@@ -166,16 +201,16 @@ class LinearPruner(Pruner):
|
|
166 |
|
167 |
class WaveGRU(pax.Module):
|
168 |
"""
|
169 |
-
WaveGRU vocoder model
|
170 |
"""
|
171 |
|
172 |
-
def __init__(
|
173 |
-
self, mel_dim=80, embed_dim=32, rnn_dim=512, upsample_factors=(5, 4, 3, 5)
|
174 |
-
):
|
175 |
super().__init__()
|
176 |
-
self.embed = pax.Embed(256,
|
177 |
-
self.upsample = Upsample(
|
178 |
-
|
|
|
|
|
179 |
self.o1 = pax.Linear(rnn_dim, rnn_dim)
|
180 |
self.o2 = pax.Linear(rnn_dim, 256)
|
181 |
self.gru_pruner = GRUPruner(self.rnn)
|
@@ -188,31 +223,30 @@ class WaveGRU(pax.Module):
|
|
188 |
x = self.o2(x)
|
189 |
return x
|
190 |
|
191 |
-
@jax.jit
|
192 |
-
def inference_step(self, rnn_state, mel, rng_key, x):
|
193 |
-
"""one inference step"""
|
194 |
-
x = self.embed(x)
|
195 |
-
x = jnp.concatenate((x, mel), axis=-1)
|
196 |
-
rnn_state, x = self.rnn(rnn_state, x)
|
197 |
-
x = self.output(x)
|
198 |
-
rng_key, next_rng_key = jax.random.split(rng_key, 2)
|
199 |
-
x = jax.random.categorical(rng_key, x, axis=-1)
|
200 |
-
return rnn_state, next_rng_key, x
|
201 |
-
|
202 |
def inference(self, mel, no_gru=False, seed=42):
|
203 |
"""
|
204 |
generate waveform form melspectrogram
|
205 |
"""
|
206 |
|
207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
if no_gru:
|
209 |
return y
|
210 |
x = jnp.array([127], dtype=jnp.int32)
|
211 |
rnn_state = self.rnn.initial_state(1)
|
212 |
output = []
|
213 |
rng_key = jax.random.PRNGKey(seed)
|
214 |
-
for i in range(y.shape[1]):
|
215 |
-
rnn_state, rng_key, x =
|
216 |
output.append(x)
|
217 |
x = jnp.concatenate(output, axis=0)
|
218 |
return x
|
@@ -223,7 +257,7 @@ class WaveGRU(pax.Module):
|
|
223 |
pad_left = (x.shape[1] - y.shape[1]) // 2
|
224 |
pad_right = x.shape[1] - y.shape[1] - pad_left
|
225 |
x = x[:, pad_left:-pad_right]
|
226 |
-
x =
|
227 |
_, x = pax.scan(
|
228 |
self.rnn,
|
229 |
self.rnn.initial_state(x.shape[0]),
|
|
|
2 |
WaveGRU model: melspectrogram => mu-law encoded waveform
|
3 |
"""
|
4 |
|
5 |
+
from typing import Tuple
|
6 |
+
|
7 |
import jax
|
8 |
import jax.numpy as jnp
|
9 |
import pax
|
10 |
+
from pax import GRUState
|
11 |
+
from tqdm.cli import tqdm
|
12 |
|
13 |
|
14 |
class ReLU(pax.Module):
|
|
|
41 |
return x
|
42 |
|
43 |
|
44 |
+
def up_block(in_dim, out_dim, factor, relu=True):
|
45 |
"""
|
46 |
Tile >> Conv >> BatchNorm >> ReLU
|
47 |
"""
|
48 |
+
f = pax.Sequential(
|
49 |
lambda x: tile_1d(x, factor),
|
50 |
+
pax.Conv1D(
|
51 |
+
in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
|
52 |
+
),
|
53 |
+
pax.LayerNorm(out_dim, -1, True, True),
|
54 |
)
|
55 |
+
if relu:
|
56 |
+
f >>= ReLU()
|
57 |
+
return f
|
58 |
|
59 |
|
60 |
class Upsample(pax.Module):
|
|
|
62 |
Upsample melspectrogram to match raw audio sample rate.
|
63 |
"""
|
64 |
|
65 |
+
def __init__(self, input_dim, rnn_dim, upsample_factors):
|
66 |
super().__init__()
|
67 |
self.input_conv = pax.Sequential(
|
68 |
+
pax.Conv1D(input_dim, rnn_dim, 1, with_bias=False),
|
69 |
+
pax.LayerNorm(rnn_dim, -1, True, True),
|
70 |
)
|
71 |
self.upsample_factors = upsample_factors
|
72 |
self.dilated_convs = [
|
73 |
+
dilated_residual_conv_block(rnn_dim, 3, 1, 2**i) for i in range(5)
|
74 |
]
|
75 |
self.up_factors = upsample_factors[:-1]
|
76 |
+
self.up_blocks = [up_block(rnn_dim, rnn_dim, x) for x in self.up_factors[:-1]]
|
77 |
+
self.up_blocks.append(
|
78 |
+
up_block(rnn_dim, 3 * rnn_dim, self.up_factors[-1], relu=False)
|
79 |
+
)
|
80 |
self.final_tile = upsample_factors[-1]
|
81 |
|
82 |
+
def __call__(self, x, no_repeat=False):
|
83 |
x = self.input_conv(x)
|
84 |
for residual in self.dilated_convs:
|
85 |
y = residual(x)
|
|
|
89 |
for f in self.up_blocks:
|
90 |
x = f(x)
|
91 |
|
92 |
+
if no_repeat:
|
93 |
+
return x
|
94 |
x = tile_1d(x, self.final_tile)
|
95 |
return x
|
96 |
|
97 |
|
98 |
+
class GRU(pax.Module):
|
99 |
"""
|
100 |
+
A customized GRU module.
|
101 |
"""
|
102 |
|
103 |
+
input_dim: int
|
104 |
+
hidden_dim: int
|
105 |
+
|
106 |
+
def __init__(self, hidden_dim: int):
|
107 |
super().__init__()
|
108 |
+
self.hidden_dim = hidden_dim
|
109 |
+
self.h_zrh_fc = pax.Linear(hidden_dim, hidden_dim * 3)
|
110 |
+
|
111 |
+
def initial_state(self, batch_size: int) -> GRUState:
|
112 |
+
"""Create an all zeros initial state."""
|
113 |
+
return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32))
|
114 |
+
|
115 |
+
def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
|
116 |
+
hidden = state.hidden
|
117 |
+
x_zrh = x
|
118 |
+
h_zrh = self.h_zrh_fc(hidden)
|
119 |
+
x_zr, x_h = jnp.split(x_zrh, [2 * self.hidden_dim], axis=-1)
|
120 |
+
h_zr, h_h = jnp.split(h_zrh, [2 * self.hidden_dim], axis=-1)
|
121 |
+
|
122 |
+
zr = x_zr + h_zr
|
123 |
+
zr = jax.nn.sigmoid(zr)
|
124 |
+
z, r = jnp.split(zr, 2, axis=-1)
|
125 |
+
|
126 |
+
h_hat = x_h + r * h_h
|
127 |
+
h_hat = jnp.tanh(h_hat)
|
128 |
+
|
129 |
+
h = (1 - z) * hidden + z * h_hat
|
130 |
+
return GRUState(h), h
|
131 |
+
|
132 |
+
|
133 |
+
class Pruner(pax.Module):
|
134 |
+
"""
|
135 |
+
Base class for pruners
|
136 |
+
"""
|
137 |
|
138 |
def compute_sparsity(self, step):
|
139 |
+
t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
|
140 |
+
z = 0.9 * jnp.clip(1.0 - t, a_min=0, a_max=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
return z
|
142 |
|
143 |
def prune(self, step, weights):
|
|
|
158 |
|
159 |
|
160 |
class GRUPruner(Pruner):
|
161 |
+
def __init__(self, gru):
|
162 |
+
super().__init__()
|
163 |
+
self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1
|
|
|
164 |
|
165 |
def __call__(self, gru: pax.GRU):
|
166 |
"""
|
167 |
Apply mask after an optimization step
|
168 |
"""
|
169 |
+
zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
|
170 |
+
gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights)
|
|
|
|
|
171 |
return gru
|
172 |
|
173 |
def update_mask(self, step, gru: pax.GRU):
|
174 |
"""
|
175 |
Update internal masks
|
176 |
"""
|
177 |
+
z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
|
178 |
+
z_mask = self.prune(step, z_weight)
|
179 |
+
r_mask = self.prune(step, r_weight)
|
180 |
+
h_mask = self.prune(step, h_weight)
|
181 |
+
self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1)
|
182 |
|
183 |
|
184 |
class LinearPruner(Pruner):
|
185 |
+
def __init__(self, linear):
|
186 |
+
super().__init__()
|
187 |
self.mask = jnp.ones_like(linear.weight) == 1
|
188 |
|
189 |
def __call__(self, linear: pax.Linear):
|
|
|
201 |
|
202 |
class WaveGRU(pax.Module):
|
203 |
"""
|
204 |
+
WaveGRU vocoder model.
|
205 |
"""
|
206 |
|
207 |
+
def __init__(self, mel_dim=80, rnn_dim=512, upsample_factors=(5, 3, 20)):
|
|
|
|
|
208 |
super().__init__()
|
209 |
+
self.embed = pax.Embed(256, 3 * rnn_dim)
|
210 |
+
self.upsample = Upsample(
|
211 |
+
input_dim=mel_dim, rnn_dim=rnn_dim, upsample_factors=upsample_factors
|
212 |
+
)
|
213 |
+
self.rnn = GRU(rnn_dim)
|
214 |
self.o1 = pax.Linear(rnn_dim, rnn_dim)
|
215 |
self.o2 = pax.Linear(rnn_dim, 256)
|
216 |
self.gru_pruner = GRUPruner(self.rnn)
|
|
|
223 |
x = self.o2(x)
|
224 |
return x
|
225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
def inference(self, mel, no_gru=False, seed=42):
|
227 |
"""
|
228 |
generate waveform form melspectrogram
|
229 |
"""
|
230 |
|
231 |
+
@jax.jit
|
232 |
+
def step(rnn_state, mel, rng_key, x):
|
233 |
+
x = self.embed(x)
|
234 |
+
x = x + mel
|
235 |
+
rnn_state, x = self.rnn(rnn_state, x)
|
236 |
+
x = self.output(x)
|
237 |
+
rng_key, next_rng_key = jax.random.split(rng_key, 2)
|
238 |
+
x = jax.random.categorical(rng_key, x, axis=-1)
|
239 |
+
return rnn_state, next_rng_key, x
|
240 |
+
|
241 |
+
y = self.upsample(mel, no_repeat=no_gru)
|
242 |
if no_gru:
|
243 |
return y
|
244 |
x = jnp.array([127], dtype=jnp.int32)
|
245 |
rnn_state = self.rnn.initial_state(1)
|
246 |
output = []
|
247 |
rng_key = jax.random.PRNGKey(seed)
|
248 |
+
for i in tqdm(range(y.shape[1])):
|
249 |
+
rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
|
250 |
output.append(x)
|
251 |
x = jnp.concatenate(output, axis=0)
|
252 |
return x
|
|
|
257 |
pad_left = (x.shape[1] - y.shape[1]) // 2
|
258 |
pad_right = x.shape[1] - y.shape[1] - pad_left
|
259 |
x = x[:, pad_left:-pad_right]
|
260 |
+
x = x + y
|
261 |
_, x = pax.scan(
|
262 |
self.rnn,
|
263 |
self.rnn.initial_state(x.shape[0]),
|
wavegru.yaml
CHANGED
@@ -11,4 +11,4 @@ embed_dim: 32
|
|
11 |
rnn_dim: 512
|
12 |
frames_per_sequence: 67
|
13 |
num_pad_frames: 62
|
14 |
-
upsample_factors: [5,
|
|
|
11 |
rnn_dim: 512
|
12 |
frames_per_sequence: 67
|
13 |
num_pad_frames: 62
|
14 |
+
upsample_factors: [5, 3, 20]
|
wavegru_cpp.py
CHANGED
@@ -1,18 +1,13 @@
|
|
1 |
import numpy as np
|
2 |
-
import sys
|
3 |
from wavegru_mod import WaveGRU
|
4 |
|
5 |
|
6 |
def extract_weight_mask(net):
|
7 |
data = {}
|
8 |
data["embed_weight"] = net.embed.weight
|
9 |
-
data["
|
10 |
-
data["
|
11 |
-
data["
|
12 |
-
|
13 |
-
data["gru_xh_h_weight"] = net.rnn.xh_h_fc.weight
|
14 |
-
data["gru_xh_h_mask"] = net.gru_pruner.xh_h_fc_mask
|
15 |
-
data["gru_xh_h_bias"] = net.rnn.xh_h_fc.bias
|
16 |
|
17 |
data["o1_weight"] = net.o1.weight
|
18 |
data["o1_mask"] = net.o1_pruner.mask
|
@@ -23,31 +18,16 @@ def extract_weight_mask(net):
|
|
23 |
return data
|
24 |
|
25 |
|
26 |
-
def load_wavegru_cpp(data):
|
|
|
27 |
embed = data["embed_weight"]
|
28 |
-
|
29 |
-
|
30 |
-
input_dim = data["gru_xh_zr_weight"].shape[1] - rnn_dim
|
31 |
-
net = WaveGRU(input_dim, embed_dim, rnn_dim)
|
32 |
net.load_embed(embed)
|
33 |
-
dim = embed_dim + input_dim + rnn_dim
|
34 |
-
z, r = np.split(data["gru_xh_zr_weight"].T, 2, axis=0)
|
35 |
-
h = data["gru_xh_h_weight"].T
|
36 |
-
z = np.ascontiguousarray(z)
|
37 |
-
r = np.ascontiguousarray(r)
|
38 |
-
h = np.ascontiguousarray(h)
|
39 |
-
|
40 |
-
b1, b2 = np.split(data["gru_xh_zr_bias"], 2)
|
41 |
-
b3 = data["gru_xh_h_bias"]
|
42 |
-
m1, m2, m3 = z, r, h
|
43 |
-
|
44 |
-
mask_z, mask_r = np.split(data["gru_xh_zr_mask"].T, 2, axis=0)
|
45 |
-
mask_h = data["gru_xh_h_mask"].T
|
46 |
-
mask_z = np.ascontiguousarray(mask_z)
|
47 |
-
mask_r = np.ascontiguousarray(mask_r)
|
48 |
-
mask_h = np.ascontiguousarray(mask_h)
|
49 |
|
50 |
-
|
|
|
|
|
51 |
|
52 |
o1 = np.ascontiguousarray(data["o1_weight"].T)
|
53 |
masko1 = np.ascontiguousarray(data["o1_mask"].T)
|
@@ -57,22 +37,6 @@ def load_wavegru_cpp(data):
|
|
57 |
masko2 = np.ascontiguousarray(data["o2_mask"].T)
|
58 |
o2b = data["o2_bias"]
|
59 |
|
60 |
-
net.load_weights(
|
61 |
-
m1,
|
62 |
-
mask1,
|
63 |
-
b1,
|
64 |
-
m2,
|
65 |
-
mask2,
|
66 |
-
b2,
|
67 |
-
m3,
|
68 |
-
mask3,
|
69 |
-
b3,
|
70 |
-
o1,
|
71 |
-
masko1,
|
72 |
-
o1b,
|
73 |
-
o2,
|
74 |
-
masko2,
|
75 |
-
o2b,
|
76 |
-
)
|
77 |
|
78 |
return net
|
|
|
1 |
import numpy as np
|
|
|
2 |
from wavegru_mod import WaveGRU
|
3 |
|
4 |
|
5 |
def extract_weight_mask(net):
|
6 |
data = {}
|
7 |
data["embed_weight"] = net.embed.weight
|
8 |
+
data["gru_h_zrh_weight"] = net.rnn.h_zrh_fc.weight
|
9 |
+
data["gru_h_zrh_mask"] = net.gru_pruner.h_zrh_fc_mask
|
10 |
+
data["gru_h_zrh_bias"] = net.rnn.h_zrh_fc.bias
|
|
|
|
|
|
|
|
|
11 |
|
12 |
data["o1_weight"] = net.o1.weight
|
13 |
data["o1_mask"] = net.o1_pruner.mask
|
|
|
18 |
return data
|
19 |
|
20 |
|
21 |
+
def load_wavegru_cpp(data, repeat_factor):
|
22 |
+
"""load wavegru weight to cpp object"""
|
23 |
embed = data["embed_weight"]
|
24 |
+
rnn_dim = data["gru_h_zrh_bias"].shape[0] // 3
|
25 |
+
net = WaveGRU(rnn_dim, repeat_factor)
|
|
|
|
|
26 |
net.load_embed(embed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
+
m = np.ascontiguousarray(data["gru_h_zrh_weight"].T)
|
29 |
+
mask = np.ascontiguousarray(data["gru_h_zrh_mask"].T)
|
30 |
+
b = data["gru_h_zrh_bias"]
|
31 |
|
32 |
o1 = np.ascontiguousarray(data["o1_weight"].T)
|
33 |
masko1 = np.ascontiguousarray(data["o1_mask"].T)
|
|
|
37 |
masko2 = np.ascontiguousarray(data["o2_mask"].T)
|
38 |
o2b = data["o2_bias"]
|
39 |
|
40 |
+
net.load_weights(m, mask, b, o1, masko1, o1b, o2, masko2, o2b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
return net
|
wavegru_mod.cc
CHANGED
@@ -30,12 +30,11 @@ mat create_mat(int h, int w) {
|
|
30 |
}
|
31 |
|
32 |
struct WaveGRU {
|
33 |
-
int input_dim;
|
34 |
-
int embed_dim;
|
35 |
int hidden_dim;
|
36 |
-
|
37 |
-
|
38 |
-
vec
|
|
|
39 |
vec fco1, fco2;
|
40 |
vec o1b, o2b;
|
41 |
vec t;
|
@@ -43,30 +42,26 @@ struct WaveGRU {
|
|
43 |
mat o1, o2;
|
44 |
std::vector<vec> embed;
|
45 |
|
46 |
-
WaveGRU(int
|
47 |
-
:
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
b3(hidden_dim),
|
53 |
z(hidden_dim),
|
54 |
r(hidden_dim),
|
55 |
hh(hidden_dim),
|
56 |
fco1(hidden_dim),
|
57 |
fco2(256),
|
58 |
-
t(hidden_dim + input_dim + embed_dim),
|
59 |
h(hidden_dim),
|
60 |
o1b(hidden_dim),
|
61 |
o2b(256) {
|
62 |
-
|
63 |
-
m2 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim);
|
64 |
-
m3 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim);
|
65 |
o1 = create_mat(hidden_dim, hidden_dim);
|
66 |
o2 = create_mat(hidden_dim, 256);
|
67 |
embed = std::vector<vec>();
|
68 |
for (int i = 0; i < 256; i++) {
|
69 |
-
embed.emplace_back(
|
70 |
embed[i].FillRandom();
|
71 |
}
|
72 |
}
|
@@ -74,7 +69,7 @@ struct WaveGRU {
|
|
74 |
void load_embed(fndarray embed_weights) {
|
75 |
auto a_embed = embed_weights.unchecked<2>();
|
76 |
for (int i = 0; i < 256; i++) {
|
77 |
-
for (int j = 0; j <
|
78 |
}
|
79 |
}
|
80 |
|
@@ -90,43 +85,35 @@ struct WaveGRU {
|
|
90 |
return mmm;
|
91 |
}
|
92 |
|
93 |
-
void load_weights(fndarray
|
94 |
-
|
95 |
-
|
96 |
-
indarray o1_mask, fndarray o1b, fndarray o2,
|
97 |
indarray o2_mask, fndarray o2b) {
|
98 |
-
this->
|
99 |
-
this->m2 = load_linear(this->b2, m2, m2_mask, b2);
|
100 |
-
this->m3 = load_linear(this->b3, m3, m3_mask, b3);
|
101 |
this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
|
102 |
this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
|
103 |
}
|
104 |
|
105 |
std::vector<int> inference(fndarray ft, float temperature) {
|
106 |
auto rft = ft.unchecked<2>();
|
107 |
-
std::vector<vec> xs;
|
108 |
-
for (int i = 0; i < rft.shape(0); i++) {
|
109 |
-
xs.emplace_back(input_dim);
|
110 |
-
for (int j = 0; j < input_dim; j++) xs[i][j] = rft(i, j);
|
111 |
-
}
|
112 |
-
|
113 |
int value = 127;
|
114 |
-
std::vector<int> signal(
|
115 |
h.FillZero();
|
116 |
-
for (int index = 0; index <
|
117 |
-
|
118 |
-
|
119 |
-
for (int i = 0; i < hidden_dim; i++) t[
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
122 |
z.Sigmoid();
|
123 |
r.Sigmoid();
|
124 |
|
125 |
for (int i = 0; i < hidden_dim; i++) {
|
126 |
-
|
127 |
}
|
128 |
-
|
129 |
-
m3.SpMM_bias(t, b3, &hh, false);
|
130 |
hh.Tanh();
|
131 |
for (int i = 0; i < hidden_dim; i++) {
|
132 |
h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
|
@@ -142,7 +129,7 @@ struct WaveGRU {
|
|
142 |
|
143 |
PYBIND11_MODULE(wavegru_mod, m) {
|
144 |
py::class_<WaveGRU>(m, "WaveGRU")
|
145 |
-
.def(py::init<int, int
|
146 |
.def("load_embed", &WaveGRU::load_embed)
|
147 |
.def("load_weights", &WaveGRU::load_weights)
|
148 |
.def("inference", &WaveGRU::inference);
|
|
|
30 |
}
|
31 |
|
32 |
struct WaveGRU {
|
|
|
|
|
33 |
int hidden_dim;
|
34 |
+
int repeat_factor;
|
35 |
+
mat m;
|
36 |
+
vec b;
|
37 |
+
vec z, r, hh, zrh;
|
38 |
vec fco1, fco2;
|
39 |
vec o1b, o2b;
|
40 |
vec t;
|
|
|
42 |
mat o1, o2;
|
43 |
std::vector<vec> embed;
|
44 |
|
45 |
+
WaveGRU(int hidden_dim, int repeat_factor)
|
46 |
+
: hidden_dim(hidden_dim),
|
47 |
+
repeat_factor(repeat_factor),
|
48 |
+
b(3*hidden_dim),
|
49 |
+
t(3*hidden_dim),
|
50 |
+
zrh(3*hidden_dim),
|
|
|
51 |
z(hidden_dim),
|
52 |
r(hidden_dim),
|
53 |
hh(hidden_dim),
|
54 |
fco1(hidden_dim),
|
55 |
fco2(256),
|
|
|
56 |
h(hidden_dim),
|
57 |
o1b(hidden_dim),
|
58 |
o2b(256) {
|
59 |
+
m = create_mat(hidden_dim, 3*hidden_dim);
|
|
|
|
|
60 |
o1 = create_mat(hidden_dim, hidden_dim);
|
61 |
o2 = create_mat(hidden_dim, 256);
|
62 |
embed = std::vector<vec>();
|
63 |
for (int i = 0; i < 256; i++) {
|
64 |
+
embed.emplace_back(hidden_dim * 3);
|
65 |
embed[i].FillRandom();
|
66 |
}
|
67 |
}
|
|
|
69 |
void load_embed(fndarray embed_weights) {
|
70 |
auto a_embed = embed_weights.unchecked<2>();
|
71 |
for (int i = 0; i < 256; i++) {
|
72 |
+
for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j);
|
73 |
}
|
74 |
}
|
75 |
|
|
|
85 |
return mmm;
|
86 |
}
|
87 |
|
88 |
+
void load_weights(fndarray m, indarray m_mask, fndarray b,
|
89 |
+
fndarray o1, indarray o1_mask,
|
90 |
+
fndarray o1b, fndarray o2,
|
|
|
91 |
indarray o2_mask, fndarray o2b) {
|
92 |
+
this->m = load_linear(this->b, m, m_mask, b);
|
|
|
|
|
93 |
this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
|
94 |
this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
|
95 |
}
|
96 |
|
97 |
std::vector<int> inference(fndarray ft, float temperature) {
|
98 |
auto rft = ft.unchecked<2>();
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
int value = 127;
|
100 |
+
std::vector<int> signal(rft.shape(0) * repeat_factor);
|
101 |
h.FillZero();
|
102 |
+
for (int index = 0; index < signal.size(); index++) {
|
103 |
+
m.SpMM_bias(h, b, &zrh, false);
|
104 |
+
|
105 |
+
for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i);
|
106 |
+
for (int i = 0; i < hidden_dim; i++) {
|
107 |
+
z[i] = zrh[i] + t[i];
|
108 |
+
r[i] = zrh[hidden_dim + i] + t[hidden_dim + i];
|
109 |
+
}
|
110 |
+
|
111 |
z.Sigmoid();
|
112 |
r.Sigmoid();
|
113 |
|
114 |
for (int i = 0; i < hidden_dim; i++) {
|
115 |
+
hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i];
|
116 |
}
|
|
|
|
|
117 |
hh.Tanh();
|
118 |
for (int i = 0; i < hidden_dim; i++) {
|
119 |
h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
|
|
|
129 |
|
130 |
PYBIND11_MODULE(wavegru_mod, m) {
|
131 |
py::class_<WaveGRU>(m, "WaveGRU")
|
132 |
+
.def(py::init<int, int>())
|
133 |
.def("load_embed", &WaveGRU::load_embed)
|
134 |
.def("load_weights", &WaveGRU::load_weights)
|
135 |
.def("inference", &WaveGRU::inference);
|
wavegru_mod.so
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:12c27f0ea07f8da3a3ab48bc01bb0f68971ce7d57b19ada87669eab138623a9c
|
3 |
+
size 525536
|