Spaces:
Runtime error
Runtime error
update
Browse files- inference/tts/ps_flow.py +2 -1
- modules/tts/portaspeech/portaspeech.py +12 -1
- tasks/tts/ps.py +1 -8
- tasks/tts/ps_flow.py +1 -9
inference/tts/ps_flow.py
CHANGED
@@ -11,7 +11,8 @@ class PortaSpeechFlowInfer(BaseTTSInfer):
|
|
11 |
word_dict_size = len(self.word_encoder)
|
12 |
model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams)
|
13 |
load_ckpt(model, hparams['work_dir'], 'model')
|
14 |
-
|
|
|
15 |
model.eval()
|
16 |
return model
|
17 |
|
|
|
11 |
word_dict_size = len(self.word_encoder)
|
12 |
model = PortaSpeechFlow(ph_dict_size, word_dict_size, self.hparams)
|
13 |
load_ckpt(model, hparams['work_dir'], 'model')
|
14 |
+
with torch.no_grad():
|
15 |
+
model.store_inverse_all()
|
16 |
model.eval()
|
17 |
return model
|
18 |
|
modules/tts/portaspeech/portaspeech.py
CHANGED
@@ -212,4 +212,15 @@ class PortaSpeech(FastSpeech):
|
|
212 |
x_pos = build_word_mask(word2word, x2word).float() # [B, T_word, T_ph]
|
213 |
x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
|
214 |
x_pos = self.sin_pos(x_pos.float()) # [B, T_ph, H]
|
215 |
-
return x_pos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
x_pos = build_word_mask(word2word, x2word).float() # [B, T_word, T_ph]
|
213 |
x_pos = (x_pos.cumsum(-1) / x_pos.sum(-1).clamp(min=1)[..., None] * x_pos).sum(1)
|
214 |
x_pos = self.sin_pos(x_pos.float()) # [B, T_ph, H]
|
215 |
+
return x_pos
|
216 |
+
|
217 |
+
def store_inverse_all(self):
|
218 |
+
def remove_weight_norm(m):
|
219 |
+
try:
|
220 |
+
if hasattr(m, 'store_inverse'):
|
221 |
+
m.store_inverse()
|
222 |
+
nn.utils.remove_weight_norm(m)
|
223 |
+
except ValueError: # this module didn't have weight norm
|
224 |
+
return
|
225 |
+
|
226 |
+
self.apply(remove_weight_norm)
|
tasks/tts/ps.py
CHANGED
@@ -156,14 +156,7 @@ class PortaSpeechTask(FastSpeechTask):
|
|
156 |
super().test_start()
|
157 |
if hparams.get('save_attn', False):
|
158 |
os.makedirs(f'{self.gen_dir}/attn', exist_ok=True)
|
159 |
-
|
160 |
-
def remove_weight_norm(m):
|
161 |
-
try:
|
162 |
-
nn.utils.remove_weight_norm(m)
|
163 |
-
except ValueError:
|
164 |
-
return
|
165 |
-
|
166 |
-
self.apply(remove_weight_norm)
|
167 |
|
168 |
def test_step(self, sample, batch_idx):
|
169 |
assert sample['txt_tokens'].shape[0] == 1, 'only support batch_size=1 in inference'
|
|
|
156 |
super().test_start()
|
157 |
if hparams.get('save_attn', False):
|
158 |
os.makedirs(f'{self.gen_dir}/attn', exist_ok=True)
|
159 |
+
self.model.store_inverse_all()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
def test_step(self, sample, batch_idx):
|
162 |
assert sample['txt_tokens'].shape[0] == 1, 'only support batch_size=1 in inference'
|
tasks/tts/ps_flow.py
CHANGED
@@ -131,12 +131,4 @@ class PortaSpeechFlowTask(PortaSpeechTask):
|
|
131 |
return [self.optimizer]
|
132 |
|
133 |
def build_scheduler(self, optimizer):
|
134 |
-
return FastSpeechTask.build_scheduler(self, optimizer[0])
|
135 |
-
|
136 |
-
############
|
137 |
-
# infer
|
138 |
-
############
|
139 |
-
def test_start(self):
|
140 |
-
super().test_start()
|
141 |
-
if hparams['use_post_flow']:
|
142 |
-
self.model.post_flow.store_inverse()
|
|
|
131 |
return [self.optimizer]
|
132 |
|
133 |
def build_scheduler(self, optimizer):
|
134 |
+
return FastSpeechTask.build_scheduler(self, optimizer[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|