zyingt committed
Commit
212c4a8
1 Parent(s): ade41ec

support timbre interpolation

Files changed (1)
  1. models/tts/vits/vits_inference.py +48 -17
models/tts/vits/vits_inference.py CHANGED
@@ -109,6 +109,7 @@ class VitsInference(TTSInference):
 
         # get phone symbol file
         phone_symbol_file = os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
+        print("phone_symbol_file:", phone_symbol_file)
         assert os.path.exists(phone_symbol_file)
         # convert text to phone sequence
         phone_extractor = phoneExtractor(self.cfg)
@@ -129,26 +130,56 @@ class VitsInference(TTSInference):
         spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
         with open(spk2id_file, 'r') as f:
             spk2id = json.load(f)
-        speaker_id = spk2id[self.args.speaker_name]
-        print("speaker name:", self.args.speaker_name)
-        print("speaker id:", speaker_id)
-        speaker_id = torch.from_numpy(
-            np.array([speaker_id], dtype=np.int32)
-        ).unsqueeze(0)
+        speaker_name_1 = self.args.speaker_name_1
+        speaker_id_1 = spk2id[speaker_name_1]
+        print("speaker 1:", speaker_name_1)
+        speaker_id_1 = torch.from_numpy(np.array([speaker_id_1], dtype=np.int32)).unsqueeze(0)
+
+        if self.args.speaker_name_2 not in (None, "None"):
+            speaker_name_2 = self.args.speaker_name_2
+            speaker_id_2 = spk2id[speaker_name_2]
+            print("speaker 2:", speaker_name_2)
+            speaker_id_2 = torch.from_numpy(np.array([speaker_id_2], dtype=np.int32)).unsqueeze(0)
+        else:
+            speaker_id_2 = None
 
         with torch.no_grad():
             x_tst = phone_id_seq.to(self.device).unsqueeze(0)
             x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
-            speaker_id = speaker_id.to(self.device)
-            outputs = self.model.infer(
-                x_tst,
-                x_tst_lengths,
-                sid=speaker_id,
-                noise_scale=noise_scale,
-                noise_scale_w=noise_scale_w,
-                length_scale=length_scale,
-            )
+            if speaker_id_1 is not None:
+                speaker_id_1 = speaker_id_1.to(self.device)
+            if speaker_id_2 is not None:
+                speaker_id_2 = speaker_id_2.to(self.device)
+                outputs_list = []
+                infer_list = [0, 1, self.args.alpha]
+                for alpha in infer_list:
+                    outputs_list.append(self.model.infer(
+                        x_tst,
+                        x_tst_lengths,
+                        sid_1=speaker_id_1,
+                        sid_2=speaker_id_2,
+                        alpha=alpha,
+                        noise_scale=noise_scale,
+                        noise_scale_w=noise_scale_w,
+                        length_scale=length_scale,
+                    ))
+                audio_s1 = outputs_list[0]["y_hat"][0, 0].data.cpu().float().numpy()
+                audio_s2 = outputs_list[1]["y_hat"][0, 0].data.cpu().float().numpy()
+                audio_interpolated = outputs_list[2]["y_hat"][0, 0].data.cpu().float().numpy()
+            else:
+                outputs = self.model.infer(
+                    x_tst,
+                    x_tst_lengths,
+                    sid_1=speaker_id_1,
+                    sid_2=speaker_id_2,
+                    alpha=self.args.alpha,
+                    noise_scale=noise_scale,
+                    noise_scale_w=noise_scale_w,
+                    length_scale=length_scale,
+                )
 
-        audio = outputs["y_hat"][0, 0].data.cpu().float().numpy()
+                audio_s1 = outputs["y_hat"][0, 0].data.cpu().float().numpy()
+                audio_s2 = None
+                audio_interpolated = None
 
-        return audio
+        return audio_s1, audio_s2, audio_interpolated
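
The model-side handling of sid_1, sid_2, and alpha is not part of this commit (only vits_inference.py changed here). A minimal sketch of how self.model.infer could consume these arguments, assuming timbre is carried by a global speaker embedding as in VITS; the SpeakerMixer class and the gin_channels name are hypothetical, not taken from the repository:

import torch
import torch.nn as nn

class SpeakerMixer(nn.Module):
    # Hypothetical illustration: linearly interpolate two speaker embeddings.
    def __init__(self, n_speakers: int, gin_channels: int):
        super().__init__()
        self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, sid_1, sid_2=None, alpha=0.0):
        g1 = self.emb_g(sid_1)           # [b, gin_channels] for sid of shape [b]
        if sid_2 is None:
            return g1.unsqueeze(-1)      # single-speaker path, alpha unused
        g2 = self.emb_g(sid_2)
        # Convention matching infer_list = [0, 1, alpha]:
        # alpha = 0 reproduces speaker 1, alpha = 1 reproduces speaker 2.
        g = (1.0 - alpha) * g1 + alpha * g2
        return g.unsqueeze(-1)           # conditioning tensor [b, gin_channels, 1]

Under this convention the three inference passes at alpha = 0, 1, and self.args.alpha yield the pure speaker-1 voice, the pure speaker-2 voice, and the blended timbre, which is what audio_s1, audio_s2, and audio_interpolated hold.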
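
A hedged driver sketch for the new three-output return value. The flag names mirror the attributes read from self.args above (speaker_name_1, speaker_name_2, alpha); the save_outputs helper and the 22050 Hz sample rate are assumptions for illustration, not part of the commit:

import argparse
import soundfile as sf

def save_outputs(audio_s1, audio_s2, audio_interpolated, sr=22050):
    # sr is assumed here; in practice read the real rate from the config.
    sf.write("speaker_1.wav", audio_s1, sr)
    if audio_s2 is not None:  # None whenever no second speaker was given
        sf.write("speaker_2.wav", audio_s2, sr)
        sf.write("interpolated.wav", audio_interpolated, sr)

parser = argparse.ArgumentParser()
parser.add_argument("--speaker_name_1", type=str, required=True)
parser.add_argument("--speaker_name_2", type=str, default=None)  # "None" also treated as absent
parser.add_argument("--alpha", type=float, default=0.5)

Passing only --speaker_name_1 exercises the single-speaker branch, which returns (audio_s1, None, None); passing both names returns all three waveforms.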