NTT123 commited on
Commit
41ba53f
·
1 Parent(s): faa16c1

use a customized gru.

Browse files
Files changed (9) hide show
  1. app.py +3 -6
  2. extract_model.py +5 -0
  3. inference.py +1 -2
  4. wavegru.ckpt +2 -2
  5. wavegru.py +94 -60
  6. wavegru.yaml +1 -1
  7. wavegru_cpp.py +11 -47
  8. wavegru_mod.cc +29 -42
  9. wavegru_mod.so +2 -2
app.py CHANGED
@@ -1,6 +1,5 @@
1
- import os
2
-
3
  ## build wavegru-cpp
 
4
  # os.system("./bazelisk-linux-amd64 clean --expunge")
5
  # os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
6
 
@@ -18,14 +17,12 @@ alphabet, tacotron_net, tacotron_config = load_tacotron_model(
18
  wavegru_config, wavegru_net = load_wavegru_net("./wavegru.yaml", "./wavegru.ckpt")
19
 
20
  wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
21
- wavecpp = load_wavegru_cpp(wave_cpp_weight_mask)
22
 
23
 
24
  def speak(text):
25
  mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
26
- print(mel.shape)
27
  y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
28
- print(y.shape)
29
  return 24_000, y
30
 
31
 
@@ -38,7 +35,7 @@ gr.Interface(
38
  examples=[
39
  "this is a test!",
40
  "October arrived, spreading a damp chill over the grounds and into the castle. Madam Pomfrey, the nurse, was kept busy by a sudden spate of colds among the staff and students.",
41
- "Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans",
42
  ],
43
  outputs="audio",
44
  title=title,
 
 
 
1
  ## build wavegru-cpp
2
+ # import os
3
  # os.system("./bazelisk-linux-amd64 clean --expunge")
4
  # os.system("./bazelisk-linux-amd64 build wavegru_mod -c opt --copt=-march=native")
5
 
 
17
  wavegru_config, wavegru_net = load_wavegru_net("./wavegru.yaml", "./wavegru.ckpt")
18
 
19
  wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
20
+ wavecpp = load_wavegru_cpp(wave_cpp_weight_mask, wavegru_config["upsample_factors"][-1])
21
 
22
 
23
  def speak(text):
24
  mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
 
25
  y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
 
26
  return 24_000, y
27
 
28
 
 
35
  examples=[
36
  "this is a test!",
37
  "October arrived, spreading a damp chill over the grounds and into the castle. Madam Pomfrey, the nurse, was kept busy by a sudden spate of colds among the staff and students.",
38
+ "Artificial intelligence is intelligence demonstrated by machines, as opposed to natural intelligence displayed by animals including humans.",
39
  ],
40
  outputs="audio",
41
  title=title,
extract_model.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ dic = pickle.load(open("./wavegru.ckpt", "rb"))
4
+ del dic["optim_state_dict"]
5
+ pickle.dump(dic, open("./wavegru.ckpt", "wb"))
inference.py CHANGED
@@ -49,7 +49,6 @@ def load_wavegru_net(config_file, model_file):
49
  config = load_wavegru_config(config_file)
50
  net = WaveGRU(
51
  mel_dim=config["mel_dim"],
52
- embed_dim=config["embed_dim"],
53
  rnn_dim=config["rnn_dim"],
54
  upsample_factors=config["upsample_factors"],
55
  )
@@ -74,7 +73,7 @@ def mel_to_wav(net, netcpp, mel, config):
74
  )
75
  ft = wavegru_inference(net, mel)
76
  ft = jax.device_get(ft[0])
77
- wav = netcpp.inference(ft, 1.0)
78
  wav = np.array(wav)
79
  wav = librosa.mu_expand(wav - 127, mu=255)
80
  wav = librosa.effects.deemphasis(wav, coef=0.86)
 
49
  config = load_wavegru_config(config_file)
50
  net = WaveGRU(
51
  mel_dim=config["mel_dim"],
 
52
  rnn_dim=config["rnn_dim"],
53
  upsample_factors=config["upsample_factors"],
54
  )
 
73
  )
74
  ft = wavegru_inference(net, mel)
75
  ft = jax.device_get(ft[0])
76
+ wav = netcpp.inference(ft, 0.9)
77
  wav = np.array(wav)
78
  wav = librosa.mu_expand(wav - 127, mu=255)
79
  wav = librosa.effects.deemphasis(wav, coef=0.86)
wavegru.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1439b9d8bfd62848cd53d1da08bd8f311004eab0f6bcc682144d4dc2b3c1e6fd
3
- size 56479601
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c06310d989fd524359d5f3ecf8ea1dc146980bb594b7b90553d0d42a64c512d8
3
+ size 58039876
wavegru.py CHANGED
@@ -2,9 +2,13 @@
2
  WaveGRU model: melspectrogram => mu-law encoded waveform
3
  """
4
 
 
 
5
  import jax
6
  import jax.numpy as jnp
7
  import pax
 
 
8
 
9
 
10
  class ReLU(pax.Module):
@@ -37,16 +41,20 @@ def tile_1d(x, factor):
37
  return x
38
 
39
 
40
- def up_block(dim, factor):
41
  """
42
  Tile >> Conv >> BatchNorm >> ReLU
43
  """
44
- return pax.Sequential(
45
  lambda x: tile_1d(x, factor),
46
- pax.Conv1D(dim, dim, 2 * factor, stride=1, padding="VALID", with_bias=False),
47
- pax.LayerNorm(dim, -1, True, True),
48
- ReLU(),
 
49
  )
 
 
 
50
 
51
 
52
  class Upsample(pax.Module):
@@ -54,21 +62,24 @@ class Upsample(pax.Module):
54
  Upsample melspectrogram to match raw audio sample rate.
55
  """
56
 
57
- def __init__(self, input_dim, upsample_factors):
58
  super().__init__()
59
  self.input_conv = pax.Sequential(
60
- pax.Conv1D(input_dim, 512, 1, with_bias=False),
61
- pax.LayerNorm(512, -1, True, True),
62
  )
63
  self.upsample_factors = upsample_factors
64
  self.dilated_convs = [
65
- dilated_residual_conv_block(512, 3, 1, 2**i) for i in range(5)
66
  ]
67
  self.up_factors = upsample_factors[:-1]
68
- self.up_blocks = [up_block(512, x) for x in self.up_factors]
 
 
 
69
  self.final_tile = upsample_factors[-1]
70
 
71
- def __call__(self, x):
72
  x = self.input_conv(x)
73
  for residual in self.dilated_convs:
74
  y = residual(x)
@@ -78,28 +89,55 @@ class Upsample(pax.Module):
78
  for f in self.up_blocks:
79
  x = f(x)
80
 
 
 
81
  x = tile_1d(x, self.final_tile)
82
  return x
83
 
84
 
85
- class Pruner(pax.Module):
86
  """
87
- Base class for pruners
88
  """
89
 
90
- def __init__(self, update_freq=500):
 
 
 
91
  super().__init__()
92
- self.update_freq = update_freq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  def compute_sparsity(self, step):
95
- """
96
- Two-stages pruning
97
- """
98
- t = jnp.power(1 - (step * 1.0 - 1_000) / 300_000, 3)
99
- z = 0.5 * jnp.clip(1.0 - t, a_min=0, a_max=1)
100
- for i in range(4):
101
- t = jnp.power(1 - (step * 1.0 - 1_000 - 400_000 - i * 200_000) / 100_000, 3)
102
- z = z + 0.1 * jnp.clip(1 - t, a_min=0, a_max=1)
103
  return z
104
 
105
  def prune(self, step, weights):
@@ -120,35 +158,32 @@ class Pruner(pax.Module):
120
 
121
 
122
  class GRUPruner(Pruner):
123
- def __init__(self, gru, update_freq=500):
124
- super().__init__(update_freq=update_freq)
125
- self.xh_zr_fc_mask = jnp.ones_like(gru.xh_zr_fc.weight) == 1
126
- self.xh_h_fc_mask = jnp.ones_like(gru.xh_h_fc.weight) == 1
127
 
128
  def __call__(self, gru: pax.GRU):
129
  """
130
  Apply mask after an optimization step
131
  """
132
- zr_masked_weights = jnp.where(self.xh_zr_fc_mask, gru.xh_zr_fc.weight, 0)
133
- gru = gru.replace_node(gru.xh_zr_fc.weight, zr_masked_weights)
134
- h_masked_weights = jnp.where(self.xh_h_fc_mask, gru.xh_h_fc.weight, 0)
135
- gru = gru.replace_node(gru.xh_h_fc.weight, h_masked_weights)
136
  return gru
137
 
138
  def update_mask(self, step, gru: pax.GRU):
139
  """
140
  Update internal masks
141
  """
142
- xh_z_weight, xh_r_weight = jnp.split(gru.xh_zr_fc.weight, 2, axis=1)
143
- xh_z_weight = self.prune(step, xh_z_weight)
144
- xh_r_weight = self.prune(step, xh_r_weight)
145
- self.xh_zr_fc_mask *= jnp.concatenate((xh_z_weight, xh_r_weight), axis=1)
146
- self.xh_h_fc_mask *= self.prune(step, gru.xh_h_fc.weight)
147
 
148
 
149
  class LinearPruner(Pruner):
150
- def __init__(self, linear, update_freq=500):
151
- super().__init__(update_freq=update_freq)
152
  self.mask = jnp.ones_like(linear.weight) == 1
153
 
154
  def __call__(self, linear: pax.Linear):
@@ -166,16 +201,16 @@ class LinearPruner(Pruner):
166
 
167
  class WaveGRU(pax.Module):
168
  """
169
- WaveGRU vocoder model
170
  """
171
 
172
- def __init__(
173
- self, mel_dim=80, embed_dim=32, rnn_dim=512, upsample_factors=(5, 4, 3, 5)
174
- ):
175
  super().__init__()
176
- self.embed = pax.Embed(256, embed_dim)
177
- self.upsample = Upsample(input_dim=mel_dim, upsample_factors=upsample_factors)
178
- self.rnn = pax.GRU(embed_dim + rnn_dim, rnn_dim)
 
 
179
  self.o1 = pax.Linear(rnn_dim, rnn_dim)
180
  self.o2 = pax.Linear(rnn_dim, 256)
181
  self.gru_pruner = GRUPruner(self.rnn)
@@ -188,31 +223,30 @@ class WaveGRU(pax.Module):
188
  x = self.o2(x)
189
  return x
190
 
191
- @jax.jit
192
- def inference_step(self, rnn_state, mel, rng_key, x):
193
- """one inference step"""
194
- x = self.embed(x)
195
- x = jnp.concatenate((x, mel), axis=-1)
196
- rnn_state, x = self.rnn(rnn_state, x)
197
- x = self.output(x)
198
- rng_key, next_rng_key = jax.random.split(rng_key, 2)
199
- x = jax.random.categorical(rng_key, x, axis=-1)
200
- return rnn_state, next_rng_key, x
201
-
202
  def inference(self, mel, no_gru=False, seed=42):
203
  """
204
  generate waveform form melspectrogram
205
  """
206
 
207
- y = self.upsample(mel)
 
 
 
 
 
 
 
 
 
 
208
  if no_gru:
209
  return y
210
  x = jnp.array([127], dtype=jnp.int32)
211
  rnn_state = self.rnn.initial_state(1)
212
  output = []
213
  rng_key = jax.random.PRNGKey(seed)
214
- for i in range(y.shape[1]):
215
- rnn_state, rng_key, x = self.inference_step(rnn_state, y[:, i], rng_key, x)
216
  output.append(x)
217
  x = jnp.concatenate(output, axis=0)
218
  return x
@@ -223,7 +257,7 @@ class WaveGRU(pax.Module):
223
  pad_left = (x.shape[1] - y.shape[1]) // 2
224
  pad_right = x.shape[1] - y.shape[1] - pad_left
225
  x = x[:, pad_left:-pad_right]
226
- x = jnp.concatenate((x, y), axis=-1)
227
  _, x = pax.scan(
228
  self.rnn,
229
  self.rnn.initial_state(x.shape[0]),
 
2
  WaveGRU model: melspectrogram => mu-law encoded waveform
3
  """
4
 
5
+ from typing import Tuple
6
+
7
  import jax
8
  import jax.numpy as jnp
9
  import pax
10
+ from pax import GRUState
11
+ from tqdm.cli import tqdm
12
 
13
 
14
  class ReLU(pax.Module):
 
41
  return x
42
 
43
 
44
+ def up_block(in_dim, out_dim, factor, relu=True):
45
  """
46
  Tile >> Conv >> BatchNorm >> ReLU
47
  """
48
+ f = pax.Sequential(
49
  lambda x: tile_1d(x, factor),
50
+ pax.Conv1D(
51
+ in_dim, out_dim, 2 * factor, stride=1, padding="VALID", with_bias=False
52
+ ),
53
+ pax.LayerNorm(out_dim, -1, True, True),
54
  )
55
+ if relu:
56
+ f >>= ReLU()
57
+ return f
58
 
59
 
60
  class Upsample(pax.Module):
 
62
  Upsample melspectrogram to match raw audio sample rate.
63
  """
64
 
65
+ def __init__(self, input_dim, rnn_dim, upsample_factors):
66
  super().__init__()
67
  self.input_conv = pax.Sequential(
68
+ pax.Conv1D(input_dim, rnn_dim, 1, with_bias=False),
69
+ pax.LayerNorm(rnn_dim, -1, True, True),
70
  )
71
  self.upsample_factors = upsample_factors
72
  self.dilated_convs = [
73
+ dilated_residual_conv_block(rnn_dim, 3, 1, 2**i) for i in range(5)
74
  ]
75
  self.up_factors = upsample_factors[:-1]
76
+ self.up_blocks = [up_block(rnn_dim, rnn_dim, x) for x in self.up_factors[:-1]]
77
+ self.up_blocks.append(
78
+ up_block(rnn_dim, 3 * rnn_dim, self.up_factors[-1], relu=False)
79
+ )
80
  self.final_tile = upsample_factors[-1]
81
 
82
+ def __call__(self, x, no_repeat=False):
83
  x = self.input_conv(x)
84
  for residual in self.dilated_convs:
85
  y = residual(x)
 
89
  for f in self.up_blocks:
90
  x = f(x)
91
 
92
+ if no_repeat:
93
+ return x
94
  x = tile_1d(x, self.final_tile)
95
  return x
96
 
97
 
98
+ class GRU(pax.Module):
99
  """
100
+ A customized GRU module.
101
  """
102
 
103
+ input_dim: int
104
+ hidden_dim: int
105
+
106
+ def __init__(self, hidden_dim: int):
107
  super().__init__()
108
+ self.hidden_dim = hidden_dim
109
+ self.h_zrh_fc = pax.Linear(hidden_dim, hidden_dim * 3)
110
+
111
+ def initial_state(self, batch_size: int) -> GRUState:
112
+ """Create an all zeros initial state."""
113
+ return GRUState(jnp.zeros((batch_size, self.hidden_dim), dtype=jnp.float32))
114
+
115
+ def __call__(self, state: GRUState, x) -> Tuple[GRUState, jnp.ndarray]:
116
+ hidden = state.hidden
117
+ x_zrh = x
118
+ h_zrh = self.h_zrh_fc(hidden)
119
+ x_zr, x_h = jnp.split(x_zrh, [2 * self.hidden_dim], axis=-1)
120
+ h_zr, h_h = jnp.split(h_zrh, [2 * self.hidden_dim], axis=-1)
121
+
122
+ zr = x_zr + h_zr
123
+ zr = jax.nn.sigmoid(zr)
124
+ z, r = jnp.split(zr, 2, axis=-1)
125
+
126
+ h_hat = x_h + r * h_h
127
+ h_hat = jnp.tanh(h_hat)
128
+
129
+ h = (1 - z) * hidden + z * h_hat
130
+ return GRUState(h), h
131
+
132
+
133
+ class Pruner(pax.Module):
134
+ """
135
+ Base class for pruners
136
+ """
137
 
138
  def compute_sparsity(self, step):
139
+ t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
140
+ z = 0.9 * jnp.clip(1.0 - t, a_min=0, a_max=1)
 
 
 
 
 
 
141
  return z
142
 
143
  def prune(self, step, weights):
 
158
 
159
 
160
  class GRUPruner(Pruner):
161
+ def __init__(self, gru):
162
+ super().__init__()
163
+ self.h_zrh_fc_mask = jnp.ones_like(gru.h_zrh_fc.weight) == 1
 
164
 
165
  def __call__(self, gru: pax.GRU):
166
  """
167
  Apply mask after an optimization step
168
  """
169
+ zrh_masked_weights = jnp.where(self.h_zrh_fc_mask, gru.h_zrh_fc.weight, 0)
170
+ gru = gru.replace_node(gru.h_zrh_fc.weight, zrh_masked_weights)
 
 
171
  return gru
172
 
173
  def update_mask(self, step, gru: pax.GRU):
174
  """
175
  Update internal masks
176
  """
177
+ z_weight, r_weight, h_weight = jnp.split(gru.h_zrh_fc.weight, 3, axis=1)
178
+ z_mask = self.prune(step, z_weight)
179
+ r_mask = self.prune(step, r_weight)
180
+ h_mask = self.prune(step, h_weight)
181
+ self.h_zrh_fc_mask *= jnp.concatenate((z_mask, r_mask, h_mask), axis=1)
182
 
183
 
184
  class LinearPruner(Pruner):
185
+ def __init__(self, linear):
186
+ super().__init__()
187
  self.mask = jnp.ones_like(linear.weight) == 1
188
 
189
  def __call__(self, linear: pax.Linear):
 
201
 
202
  class WaveGRU(pax.Module):
203
  """
204
+ WaveGRU vocoder model.
205
  """
206
 
207
+ def __init__(self, mel_dim=80, rnn_dim=512, upsample_factors=(5, 3, 20)):
 
 
208
  super().__init__()
209
+ self.embed = pax.Embed(256, 3 * rnn_dim)
210
+ self.upsample = Upsample(
211
+ input_dim=mel_dim, rnn_dim=rnn_dim, upsample_factors=upsample_factors
212
+ )
213
+ self.rnn = GRU(rnn_dim)
214
  self.o1 = pax.Linear(rnn_dim, rnn_dim)
215
  self.o2 = pax.Linear(rnn_dim, 256)
216
  self.gru_pruner = GRUPruner(self.rnn)
 
223
  x = self.o2(x)
224
  return x
225
 
 
 
 
 
 
 
 
 
 
 
 
226
  def inference(self, mel, no_gru=False, seed=42):
227
  """
228
  generate waveform form melspectrogram
229
  """
230
 
231
+ @jax.jit
232
+ def step(rnn_state, mel, rng_key, x):
233
+ x = self.embed(x)
234
+ x = x + mel
235
+ rnn_state, x = self.rnn(rnn_state, x)
236
+ x = self.output(x)
237
+ rng_key, next_rng_key = jax.random.split(rng_key, 2)
238
+ x = jax.random.categorical(rng_key, x, axis=-1)
239
+ return rnn_state, next_rng_key, x
240
+
241
+ y = self.upsample(mel, no_repeat=no_gru)
242
  if no_gru:
243
  return y
244
  x = jnp.array([127], dtype=jnp.int32)
245
  rnn_state = self.rnn.initial_state(1)
246
  output = []
247
  rng_key = jax.random.PRNGKey(seed)
248
+ for i in tqdm(range(y.shape[1])):
249
+ rnn_state, rng_key, x = step(rnn_state, y[:, i], rng_key, x)
250
  output.append(x)
251
  x = jnp.concatenate(output, axis=0)
252
  return x
 
257
  pad_left = (x.shape[1] - y.shape[1]) // 2
258
  pad_right = x.shape[1] - y.shape[1] - pad_left
259
  x = x[:, pad_left:-pad_right]
260
+ x = x + y
261
  _, x = pax.scan(
262
  self.rnn,
263
  self.rnn.initial_state(x.shape[0]),
wavegru.yaml CHANGED
@@ -11,4 +11,4 @@ embed_dim: 32
11
  rnn_dim: 512
12
  frames_per_sequence: 67
13
  num_pad_frames: 62
14
- upsample_factors: [5, 4, 3, 5]
 
11
  rnn_dim: 512
12
  frames_per_sequence: 67
13
  num_pad_frames: 62
14
+ upsample_factors: [5, 3, 20]
wavegru_cpp.py CHANGED
@@ -1,18 +1,13 @@
1
  import numpy as np
2
- import sys
3
  from wavegru_mod import WaveGRU
4
 
5
 
6
  def extract_weight_mask(net):
7
  data = {}
8
  data["embed_weight"] = net.embed.weight
9
- data["gru_xh_zr_weight"] = net.rnn.xh_zr_fc.weight
10
- data["gru_xh_zr_mask"] = net.gru_pruner.xh_zr_fc_mask
11
- data["gru_xh_zr_bias"] = net.rnn.xh_zr_fc.bias
12
-
13
- data["gru_xh_h_weight"] = net.rnn.xh_h_fc.weight
14
- data["gru_xh_h_mask"] = net.gru_pruner.xh_h_fc_mask
15
- data["gru_xh_h_bias"] = net.rnn.xh_h_fc.bias
16
 
17
  data["o1_weight"] = net.o1.weight
18
  data["o1_mask"] = net.o1_pruner.mask
@@ -23,31 +18,16 @@ def extract_weight_mask(net):
23
  return data
24
 
25
 
26
- def load_wavegru_cpp(data):
 
27
  embed = data["embed_weight"]
28
- embed_dim = embed.shape[1]
29
- rnn_dim = data["gru_xh_h_bias"].shape[0]
30
- input_dim = data["gru_xh_zr_weight"].shape[1] - rnn_dim
31
- net = WaveGRU(input_dim, embed_dim, rnn_dim)
32
  net.load_embed(embed)
33
- dim = embed_dim + input_dim + rnn_dim
34
- z, r = np.split(data["gru_xh_zr_weight"].T, 2, axis=0)
35
- h = data["gru_xh_h_weight"].T
36
- z = np.ascontiguousarray(z)
37
- r = np.ascontiguousarray(r)
38
- h = np.ascontiguousarray(h)
39
-
40
- b1, b2 = np.split(data["gru_xh_zr_bias"], 2)
41
- b3 = data["gru_xh_h_bias"]
42
- m1, m2, m3 = z, r, h
43
-
44
- mask_z, mask_r = np.split(data["gru_xh_zr_mask"].T, 2, axis=0)
45
- mask_h = data["gru_xh_h_mask"].T
46
- mask_z = np.ascontiguousarray(mask_z)
47
- mask_r = np.ascontiguousarray(mask_r)
48
- mask_h = np.ascontiguousarray(mask_h)
49
 
50
- mask1, mask2, mask3 = mask_z, mask_r, mask_h
 
 
51
 
52
  o1 = np.ascontiguousarray(data["o1_weight"].T)
53
  masko1 = np.ascontiguousarray(data["o1_mask"].T)
@@ -57,22 +37,6 @@ def load_wavegru_cpp(data):
57
  masko2 = np.ascontiguousarray(data["o2_mask"].T)
58
  o2b = data["o2_bias"]
59
 
60
- net.load_weights(
61
- m1,
62
- mask1,
63
- b1,
64
- m2,
65
- mask2,
66
- b2,
67
- m3,
68
- mask3,
69
- b3,
70
- o1,
71
- masko1,
72
- o1b,
73
- o2,
74
- masko2,
75
- o2b,
76
- )
77
 
78
  return net
 
1
  import numpy as np
 
2
  from wavegru_mod import WaveGRU
3
 
4
 
5
  def extract_weight_mask(net):
6
  data = {}
7
  data["embed_weight"] = net.embed.weight
8
+ data["gru_h_zrh_weight"] = net.rnn.h_zrh_fc.weight
9
+ data["gru_h_zrh_mask"] = net.gru_pruner.h_zrh_fc_mask
10
+ data["gru_h_zrh_bias"] = net.rnn.h_zrh_fc.bias
 
 
 
 
11
 
12
  data["o1_weight"] = net.o1.weight
13
  data["o1_mask"] = net.o1_pruner.mask
 
18
  return data
19
 
20
 
21
+ def load_wavegru_cpp(data, repeat_factor):
22
+ """load wavegru weight to cpp object"""
23
  embed = data["embed_weight"]
24
+ rnn_dim = data["gru_h_zrh_bias"].shape[0] // 3
25
+ net = WaveGRU(rnn_dim, repeat_factor)
 
 
26
  net.load_embed(embed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ m = np.ascontiguousarray(data["gru_h_zrh_weight"].T)
29
+ mask = np.ascontiguousarray(data["gru_h_zrh_mask"].T)
30
+ b = data["gru_h_zrh_bias"]
31
 
32
  o1 = np.ascontiguousarray(data["o1_weight"].T)
33
  masko1 = np.ascontiguousarray(data["o1_mask"].T)
 
37
  masko2 = np.ascontiguousarray(data["o2_mask"].T)
38
  o2b = data["o2_bias"]
39
 
40
+ net.load_weights(m, mask, b, o1, masko1, o1b, o2, masko2, o2b)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  return net
wavegru_mod.cc CHANGED
@@ -30,12 +30,11 @@ mat create_mat(int h, int w) {
30
  }
31
 
32
  struct WaveGRU {
33
- int input_dim;
34
- int embed_dim;
35
  int hidden_dim;
36
- mat m1, m2, m3;
37
- vec b1, b2, b3;
38
- vec z, r, hh;
 
39
  vec fco1, fco2;
40
  vec o1b, o2b;
41
  vec t;
@@ -43,30 +42,26 @@ struct WaveGRU {
43
  mat o1, o2;
44
  std::vector<vec> embed;
45
 
46
- WaveGRU(int input_dim, int embed_dim, int hidden_dim)
47
- : input_dim(input_dim),
48
- embed_dim(embed_dim),
49
- hidden_dim(hidden_dim),
50
- b1(hidden_dim),
51
- b2(hidden_dim),
52
- b3(hidden_dim),
53
  z(hidden_dim),
54
  r(hidden_dim),
55
  hh(hidden_dim),
56
  fco1(hidden_dim),
57
  fco2(256),
58
- t(hidden_dim + input_dim + embed_dim),
59
  h(hidden_dim),
60
  o1b(hidden_dim),
61
  o2b(256) {
62
- m1 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim);
63
- m2 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim);
64
- m3 = create_mat(input_dim + hidden_dim + embed_dim, hidden_dim);
65
  o1 = create_mat(hidden_dim, hidden_dim);
66
  o2 = create_mat(hidden_dim, 256);
67
  embed = std::vector<vec>();
68
  for (int i = 0; i < 256; i++) {
69
- embed.emplace_back(embed_dim);
70
  embed[i].FillRandom();
71
  }
72
  }
@@ -74,7 +69,7 @@ struct WaveGRU {
74
  void load_embed(fndarray embed_weights) {
75
  auto a_embed = embed_weights.unchecked<2>();
76
  for (int i = 0; i < 256; i++) {
77
- for (int j = 0; j < embed_dim; j++) embed[i][j] = a_embed(i, j);
78
  }
79
  }
80
 
@@ -90,43 +85,35 @@ struct WaveGRU {
90
  return mmm;
91
  }
92
 
93
- void load_weights(fndarray m1, indarray m1_mask, fndarray b1, fndarray m2,
94
- indarray m2_mask, fndarray b2, fndarray m3,
95
- indarray m3_mask, fndarray b3, fndarray o1,
96
- indarray o1_mask, fndarray o1b, fndarray o2,
97
  indarray o2_mask, fndarray o2b) {
98
- this->m1 = load_linear(this->b1, m1, m1_mask, b1);
99
- this->m2 = load_linear(this->b2, m2, m2_mask, b2);
100
- this->m3 = load_linear(this->b3, m3, m3_mask, b3);
101
  this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
102
  this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
103
  }
104
 
105
  std::vector<int> inference(fndarray ft, float temperature) {
106
  auto rft = ft.unchecked<2>();
107
- std::vector<vec> xs;
108
- for (int i = 0; i < rft.shape(0); i++) {
109
- xs.emplace_back(input_dim);
110
- for (int j = 0; j < input_dim; j++) xs[i][j] = rft(i, j);
111
- }
112
-
113
  int value = 127;
114
- std::vector<int> signal(xs.size());
115
  h.FillZero();
116
- for (int index = 0; index < xs.size(); index++) {
117
- for (int i = 0; i < embed_dim; i++) t[i] = embed[value][i];
118
- for (int i = 0; i < input_dim; i++) t[embed_dim + i] = xs[index][i];
119
- for (int i = 0; i < hidden_dim; i++) t[embed_dim + input_dim + i] = h[i];
120
- m1.SpMM_bias(t, b1, &z, false);
121
- m2.SpMM_bias(t, b2, &r, false);
 
 
 
122
  z.Sigmoid();
123
  r.Sigmoid();
124
 
125
  for (int i = 0; i < hidden_dim; i++) {
126
- t[embed_dim + input_dim + i] = h[i] * r[i];
127
  }
128
-
129
- m3.SpMM_bias(t, b3, &hh, false);
130
  hh.Tanh();
131
  for (int i = 0; i < hidden_dim; i++) {
132
  h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
@@ -142,7 +129,7 @@ struct WaveGRU {
142
 
143
  PYBIND11_MODULE(wavegru_mod, m) {
144
  py::class_<WaveGRU>(m, "WaveGRU")
145
- .def(py::init<int, int, int>())
146
  .def("load_embed", &WaveGRU::load_embed)
147
  .def("load_weights", &WaveGRU::load_weights)
148
  .def("inference", &WaveGRU::inference);
 
30
  }
31
 
32
  struct WaveGRU {
 
 
33
  int hidden_dim;
34
+ int repeat_factor;
35
+ mat m;
36
+ vec b;
37
+ vec z, r, hh, zrh;
38
  vec fco1, fco2;
39
  vec o1b, o2b;
40
  vec t;
 
42
  mat o1, o2;
43
  std::vector<vec> embed;
44
 
45
+ WaveGRU(int hidden_dim, int repeat_factor)
46
+ : hidden_dim(hidden_dim),
47
+ repeat_factor(repeat_factor),
48
+ b(3*hidden_dim),
49
+ t(3*hidden_dim),
50
+ zrh(3*hidden_dim),
 
51
  z(hidden_dim),
52
  r(hidden_dim),
53
  hh(hidden_dim),
54
  fco1(hidden_dim),
55
  fco2(256),
 
56
  h(hidden_dim),
57
  o1b(hidden_dim),
58
  o2b(256) {
59
+ m = create_mat(hidden_dim, 3*hidden_dim);
 
 
60
  o1 = create_mat(hidden_dim, hidden_dim);
61
  o2 = create_mat(hidden_dim, 256);
62
  embed = std::vector<vec>();
63
  for (int i = 0; i < 256; i++) {
64
+ embed.emplace_back(hidden_dim * 3);
65
  embed[i].FillRandom();
66
  }
67
  }
 
69
  void load_embed(fndarray embed_weights) {
70
  auto a_embed = embed_weights.unchecked<2>();
71
  for (int i = 0; i < 256; i++) {
72
+ for (int j = 0; j < hidden_dim * 3; j++) embed[i][j] = a_embed(i, j);
73
  }
74
  }
75
 
 
85
  return mmm;
86
  }
87
 
88
+ void load_weights(fndarray m, indarray m_mask, fndarray b,
89
+ fndarray o1, indarray o1_mask,
90
+ fndarray o1b, fndarray o2,
 
91
  indarray o2_mask, fndarray o2b) {
92
+ this->m = load_linear(this->b, m, m_mask, b);
 
 
93
  this->o1 = load_linear(this->o1b, o1, o1_mask, o1b);
94
  this->o2 = load_linear(this->o2b, o2, o2_mask, o2b);
95
  }
96
 
97
  std::vector<int> inference(fndarray ft, float temperature) {
98
  auto rft = ft.unchecked<2>();
 
 
 
 
 
 
99
  int value = 127;
100
+ std::vector<int> signal(rft.shape(0) * repeat_factor);
101
  h.FillZero();
102
+ for (int index = 0; index < signal.size(); index++) {
103
+ m.SpMM_bias(h, b, &zrh, false);
104
+
105
+ for (int i = 0; i < 3 * hidden_dim; i++) t[i] = embed[value][i] + rft(index / repeat_factor, i);
106
+ for (int i = 0; i < hidden_dim; i++) {
107
+ z[i] = zrh[i] + t[i];
108
+ r[i] = zrh[hidden_dim + i] + t[hidden_dim + i];
109
+ }
110
+
111
  z.Sigmoid();
112
  r.Sigmoid();
113
 
114
  for (int i = 0; i < hidden_dim; i++) {
115
+ hh[i] = zrh[hidden_dim * 2 + i] * r[i] + t[hidden_dim * 2 + i];
116
  }
 
 
117
  hh.Tanh();
118
  for (int i = 0; i < hidden_dim; i++) {
119
  h[i] = (1. - z[i]) * h[i] + z[i] * hh[i];
 
129
 
130
  PYBIND11_MODULE(wavegru_mod, m) {
131
  py::class_<WaveGRU>(m, "WaveGRU")
132
+ .def(py::init<int, int>())
133
  .def("load_embed", &WaveGRU::load_embed)
134
  .def("load_weights", &WaveGRU::load_weights)
135
  .def("inference", &WaveGRU::inference);
wavegru_mod.so CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d5fd47ca8d441ba204b3fcff514341918d5795618ad2740704d09f62c5b6c47a
3
- size 527976
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12c27f0ea07f8da3a3ab48bc01bb0f68971ce7d57b19ada87669eab138623a9c
3
+ size 525536