NTT123 commited on
Commit
012ab0b
·
1 Parent(s): 267755a

new 1024 gru unit model

Browse files
Files changed (5) hide show
  1. app.py +1 -1
  2. inference.py +1 -0
  3. wavegru.ckpt +2 -2
  4. wavegru.py +42 -10
  5. wavegru.yaml +1 -1
app.py CHANGED
@@ -49,4 +49,4 @@ gr.Interface(
49
  theme="default",
50
  allow_screenshot=False,
51
  allow_flagging="never",
52
- ).launch(debug=True, enable_queue=True)
 
49
  theme="default",
50
  allow_screenshot=False,
51
  allow_flagging="never",
52
+ ).launch(server_port=5000, debug=True, show_error=True)
inference.py CHANGED
@@ -51,6 +51,7 @@ def load_wavegru_net(config_file, model_file):
51
  mel_dim=config["mel_dim"],
52
  rnn_dim=config["rnn_dim"],
53
  upsample_factors=config["upsample_factors"],
 
54
  )
55
  _, net, _ = load_wavegru_ckpt(net, None, model_file)
56
  net = net.eval()
 
51
  mel_dim=config["mel_dim"],
52
  rnn_dim=config["rnn_dim"],
53
  upsample_factors=config["upsample_factors"],
54
+ has_linear_output=True,
55
  )
56
  _, net, _ = load_wavegru_ckpt(net, None, model_file)
57
  net = net.eval()
wavegru.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:edd8a1ccc0a74a0b63fa416699fc0991e798d1444683be4eaf6a65249c56f8de
3
- size 58039876
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1af9c38d0fffcf41942e4bd8d6c88f6b33f52695619d7e42359b267857019081
3
+ size 69717674
wavegru.py CHANGED
@@ -62,21 +62,34 @@ class Upsample(pax.Module):
62
  Upsample melspectrogram to match raw audio sample rate.
63
  """
64
 
65
- def __init__(self, input_dim, rnn_dim, upsample_factors):
 
 
66
  super().__init__()
67
  self.input_conv = pax.Sequential(
68
- pax.Conv1D(input_dim, rnn_dim, 1, with_bias=False),
69
- pax.LayerNorm(rnn_dim, -1, True, True),
70
  )
71
  self.upsample_factors = upsample_factors
72
  self.dilated_convs = [
73
- dilated_residual_conv_block(rnn_dim, 3, 1, 2**i) for i in range(5)
74
  ]
75
  self.up_factors = upsample_factors[:-1]
76
- self.up_blocks = [up_block(rnn_dim, rnn_dim, x) for x in self.up_factors[:-1]]
 
 
77
  self.up_blocks.append(
78
- up_block(rnn_dim, 3 * rnn_dim, self.up_factors[-1], relu=False)
 
 
 
 
 
79
  )
 
 
 
 
80
  self.final_tile = upsample_factors[-1]
81
 
82
  def __call__(self, x, no_repeat=False):
@@ -89,6 +102,9 @@ class Upsample(pax.Module):
89
  for f in self.up_blocks:
90
  x = f(x)
91
 
 
 
 
92
  if no_repeat:
93
  return x
94
  x = tile_1d(x, self.final_tile)
@@ -106,7 +122,13 @@ class GRU(pax.Module):
106
  def __init__(self, hidden_dim: int):
107
  super().__init__()
108
  self.hidden_dim = hidden_dim
109
- self.h_zrh_fc = pax.Linear(hidden_dim, hidden_dim * 3)
 
 
 
 
 
 
110
 
111
  def initial_state(self, batch_size: int) -> GRUState:
112
  """Create an all zeros initial state."""
@@ -137,7 +159,7 @@ class Pruner(pax.Module):
137
 
138
  def compute_sparsity(self, step):
139
  t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
140
- z = 0.9 * jnp.clip(1.0 - t, a_min=0, a_max=1)
141
  return z
142
 
143
  def prune(self, step, weights):
@@ -204,11 +226,21 @@ class WaveGRU(pax.Module):
204
  WaveGRU vocoder model.
205
  """
206
 
207
- def __init__(self, mel_dim=80, rnn_dim=512, upsample_factors=(5, 3, 20)):
 
 
 
 
 
 
208
  super().__init__()
209
  self.embed = pax.Embed(256, 3 * rnn_dim)
210
  self.upsample = Upsample(
211
- input_dim=mel_dim, rnn_dim=rnn_dim, upsample_factors=upsample_factors
 
 
 
 
212
  )
213
  self.rnn = GRU(rnn_dim)
214
  self.o1 = pax.Linear(rnn_dim, rnn_dim)
 
62
  Upsample melspectrogram to match raw audio sample rate.
63
  """
64
 
65
+ def __init__(
66
+ self, input_dim, hidden_dim, rnn_dim, upsample_factors, has_linear_output=False
67
+ ):
68
  super().__init__()
69
  self.input_conv = pax.Sequential(
70
+ pax.Conv1D(input_dim, hidden_dim, 1, with_bias=False),
71
+ pax.LayerNorm(hidden_dim, -1, True, True),
72
  )
73
  self.upsample_factors = upsample_factors
74
  self.dilated_convs = [
75
+ dilated_residual_conv_block(hidden_dim, 3, 1, 2**i) for i in range(5)
76
  ]
77
  self.up_factors = upsample_factors[:-1]
78
+ self.up_blocks = [
79
+ up_block(hidden_dim, hidden_dim, x) for x in self.up_factors[:-1]
80
+ ]
81
  self.up_blocks.append(
82
+ up_block(
83
+ hidden_dim,
84
+ hidden_dim if has_linear_output else 3 * rnn_dim,
85
+ self.up_factors[-1],
86
+ relu=False,
87
+ )
88
  )
89
+ if has_linear_output:
90
+ self.x2zrh_fc = pax.Linear(hidden_dim, rnn_dim * 3)
91
+ self.has_linear_output = has_linear_output
92
+
93
  self.final_tile = upsample_factors[-1]
94
 
95
  def __call__(self, x, no_repeat=False):
 
102
  for f in self.up_blocks:
103
  x = f(x)
104
 
105
+ if self.has_linear_output:
106
+ x = self.x2zrh_fc(x)
107
+
108
  if no_repeat:
109
  return x
110
  x = tile_1d(x, self.final_tile)
 
122
  def __init__(self, hidden_dim: int):
123
  super().__init__()
124
  self.hidden_dim = hidden_dim
125
+ self.h_zrh_fc = pax.Linear(
126
+ hidden_dim,
127
+ hidden_dim * 3,
128
+ w_init=jax.nn.initializers.variance_scaling(
129
+ 1, "fan_out", "truncated_normal"
130
+ ),
131
+ )
132
 
133
  def initial_state(self, batch_size: int) -> GRUState:
134
  """Create an all zeros initial state."""
 
159
 
160
  def compute_sparsity(self, step):
161
  t = jnp.power(1 - (step * 1.0 - 1_000) / 200_000, 3)
162
+ z = 0.95 * jnp.clip(1.0 - t, a_min=0, a_max=1)
163
  return z
164
 
165
  def prune(self, step, weights):
 
226
  WaveGRU vocoder model.
227
  """
228
 
229
+ def __init__(
230
+ self,
231
+ mel_dim=80,
232
+ rnn_dim=1024,
233
+ upsample_factors=(5, 3, 20),
234
+ has_linear_output=False,
235
+ ):
236
  super().__init__()
237
  self.embed = pax.Embed(256, 3 * rnn_dim)
238
  self.upsample = Upsample(
239
+ input_dim=mel_dim,
240
+ hidden_dim=512,
241
+ rnn_dim=rnn_dim,
242
+ upsample_factors=upsample_factors,
243
+ has_linear_output=has_linear_output,
244
  )
245
  self.rnn = GRU(rnn_dim)
246
  self.o1 = pax.Linear(rnn_dim, rnn_dim)
wavegru.yaml CHANGED
@@ -8,7 +8,7 @@ n_fft: 2048
8
 
9
  ## wavegru
10
  embed_dim: 32
11
- rnn_dim: 512
12
  frames_per_sequence: 67
13
  num_pad_frames: 62
14
  upsample_factors: [5, 3, 20]
 
8
 
9
  ## wavegru
10
  embed_dim: 32
11
+ rnn_dim: 1024
12
  frames_per_sequence: 67
13
  num_pad_frames: 62
14
  upsample_factors: [5, 3, 20]