RayeRen commited on
Commit
91c5bdb
1 Parent(s): e75aa39
inference/tts/ps_flow.py CHANGED
@@ -11,7 +11,8 @@ class PortaSpeechFlowInfer(BaseTTSInfer):
11
  word_dict_size = len(self.word_encoder)
12
  model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams)
13
  load_ckpt(model, hparams['work_dir'], 'model')
14
- model.post_flow.store_inverse()
 
15
  model.eval()
16
  return model
17
 
 
11
  word_dict_size = len(self.word_encoder)
12
  model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams)
13
  load_ckpt(model, hparams['work_dir'], 'model')
14
+ with torch.no_grad():
15
+ model.store_inverse_all()
16
  model.eval()
17
  return model
18
 
modules/tts/portaspeech/portaspeech.py CHANGED
@@ -212,4 +212,15 @@ class PortaSpeech(FastSpeech):
212
  x_pos = build_word_mask(word2word, x2word).float() # [B, T_word, T_ph]
213
  x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
214
  x_pos = self.sin_pos(x_pos.float()) # [B, T_ph, H]
215
- return x_pos
 
 
 
 
 
 
 
 
 
 
 
 
212
  x_pos = build_word_mask(word2word, x2word).float() # [B, T_word, T_ph]
213
  x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
214
  x_pos = self.sin_pos(x_pos.float()) # [B, T_ph, H]
215
+ return x_pos
216
+
217
+ def store_inverse_all(self):
218
+ def remove_weight_norm(m):
219
+ try:
220
+ if hasattr(m, 'store_inverse'):
221
+ m.store_inverse()
222
+ nn.utils.remove_weight_norm(m)
223
+ except ValueError: # this module didn't have weight norm
224
+ return
225
+
226
+ self.apply(remove_weight_norm)
tasks/tts/ps.py CHANGED
@@ -156,14 +156,7 @@ class PortaSpeechTask(FastSpeechTask):
156
  super().test_start()
157
  if hparams.get('save_attn', False):
158
  os.makedirs(f'{self.gen_dir}/attn', exist_ok=True)
159
-
160
- def remove_weight_norm(m):
161
- try:
162
- nn.utils.remove_weight_norm(m)
163
- except ValueError:
164
- return
165
-
166
- self.apply(remove_weight_norm)
167
 
168
  def test_step(self, sample, batch_idx):
169
  assert sample['txt_tokens'].shape[0] == 1, 'only support batch_size=1 in inference'
 
156
  super().test_start()
157
  if hparams.get('save_attn', False):
158
  os.makedirs(f'{self.gen_dir}/attn', exist_ok=True)
159
+ self.model.store_inverse_all()
 
 
 
 
 
 
 
160
 
161
  def test_step(self, sample, batch_idx):
162
  assert sample['txt_tokens'].shape[0] == 1, 'only support batch_size=1 in inference'
tasks/tts/ps_flow.py CHANGED
@@ -131,12 +131,4 @@ class PortaSpeechFlowTask(PortaSpeechTask):
131
  return [self.optimizer]
132
 
133
  def build_scheduler(self, optimizer):
134
- return FastSpeechTask.build_scheduler(self, optimizer[0])
135
-
136
- ############
137
- # infer
138
- ############
139
- def test_start(self):
140
- super().test_start()
141
- if hparams['use_post_flow']:
142
- self.model.post_flow.store_inverse()
 
131
  return [self.optimizer]
132
 
133
  def build_scheduler(self, optimizer):
134
+ return FastSpeechTask.build_scheduler(self, optimizer[0])