mrfakename committed · verified · Commit 57b3db8 · Parent: 753c05d

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo there.

src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -189,13 +189,13 @@ def main():
             gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
             gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
             if mel_spec_type == "vocos":
-                generated_wave = vocoder.decode(gen_mel_spec)
+                generated_wave = vocoder.decode(gen_mel_spec).cpu()
             elif mel_spec_type == "bigvgan":
-                generated_wave = vocoder(gen_mel_spec)
+                generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
 
             if ref_rms_list[i] < target_rms:
                 generated_wave = generated_wave * ref_rms_list[i] / target_rms
-            torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
+            torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
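
Note on this change: Vocos's decode() returns waveforms shaped (B, T), while BigVGAN's forward pass returns (B, 1, T), so the BigVGAN branch needs .squeeze(0) to reach the 2-D (channels, frames) layout torchaudio.save expects; moving .cpu() into the branches lets the save call stay vocoder-agnostic. A minimal sketch of the resulting pattern (mel_to_wav is a hypothetical helper name, not part of the repo):

    import torch
    import torchaudio


    def mel_to_wav(gen_mel_spec: torch.Tensor, vocoder, mel_spec_type: str) -> torch.Tensor:
        """Decode a (1, n_mels, T) mel spectrogram to a 2-D CPU waveform tensor."""
        if mel_spec_type == "vocos":
            # Vocos decode() maps (B, n_mels, T) -> (B, T); with B=1 this is
            # already the mono (channels, frames) layout torchaudio.save expects.
            wave = vocoder.decode(gen_mel_spec).cpu()
        elif mel_spec_type == "bigvgan":
            # BigVGAN's forward returns (B, 1, T); squeezing the batch dim
            # yields (1, T), again a valid mono (channels, frames) tensor.
            wave = vocoder(gen_mel_spec).squeeze(0).cpu()
        else:
            raise ValueError(f"unknown mel_spec_type: {mel_spec_type}")
        return wave


    # usage sketch (24 kHz is F5-TTS's target_sample_rate):
    # torchaudio.save("out.wav", mel_to_wav(mel, vocoder, "vocos"), 24000)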
src/f5_tts/infer/speech_edit.py CHANGED
@@ -181,13 +181,13 @@ with torch.inference_mode():
     generated = generated[:, ref_audio_len:, :]
     gen_mel_spec = generated.permute(0, 2, 1)
     if mel_spec_type == "vocos":
-        generated_wave = vocoder.decode(gen_mel_spec)
+        generated_wave = vocoder.decode(gen_mel_spec).cpu()
     elif mel_spec_type == "bigvgan":
-        generated_wave = vocoder(gen_mel_spec)
+        generated_wave = vocoder(gen_mel_spec).squeeze(0).cpu()
 
     if rms < target_rms:
         generated_wave = generated_wave * rms / target_rms
 
     save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
-    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
+    torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave, target_sample_rate)
     print(f"Generated wav: {generated_wave.shape}")
src/f5_tts/model/trainer.py CHANGED
@@ -324,26 +324,31 @@ class Trainer:
                     self.save_checkpoint(global_step)
 
                 if self.log_samples and self.accelerator.is_local_main_process:
-                    ref_audio, ref_audio_len = vocoder.decode(batch["mel"][0].unsqueeze(0)), mel_lengths[0]
-                    torchaudio.save(
-                        f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio.cpu(), target_sample_rate
-                    )
+                    ref_audio_len = mel_lengths[0]
+                    infer_text = [
+                        text_inputs[0] + ([" "] if isinstance(text_inputs[0], list) else " ") + text_inputs[0]
+                    ]
                     with torch.inference_mode():
                         generated, _ = self.accelerator.unwrap_model(self.model).sample(
                             cond=mel_spec[0][:ref_audio_len].unsqueeze(0),
-                            text=[text_inputs[0] + [" "] + text_inputs[0]],
+                            text=infer_text,
                             duration=ref_audio_len * 2,
                             steps=nfe_step,
                             cfg_strength=cfg_strength,
                             sway_sampling_coef=sway_sampling_coef,
                         )
-                    generated = generated.to(torch.float32)
-                    gen_audio = vocoder.decode(
-                        generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
-                    )
-                    torchaudio.save(
-                        f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio.cpu(), target_sample_rate
-                    )
+                    generated = generated.to(torch.float32)
+                    gen_mel_spec = generated[:, ref_audio_len:, :].permute(0, 2, 1).to(self.accelerator.device)
+                    ref_mel_spec = batch["mel"][0].unsqueeze(0)
+                    if self.vocoder_name == "vocos":
+                        gen_audio = vocoder.decode(gen_mel_spec).cpu()
+                        ref_audio = vocoder.decode(ref_mel_spec).cpu()
+                    elif self.vocoder_name == "bigvgan":
+                        gen_audio = vocoder(gen_mel_spec).squeeze(0).cpu()
+                        ref_audio = vocoder(ref_mel_spec).squeeze(0).cpu()
+
+                    torchaudio.save(f"{log_samples_path}/step_{global_step}_gen.wav", gen_audio, target_sample_rate)
+                    torchaudio.save(f"{log_samples_path}/step_{global_step}_ref.wav", ref_audio, target_sample_rate)
 
                 if global_step % self.last_per_steps == 0:
                     self.save_checkpoint(global_step, last=True)
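
The infer_text change in trainer.py also fixes a type mismatch: with a char- or pinyin-based tokenizer, text_inputs[0] is a list of tokens, but with a raw-text tokenizer it is a plain string, and the old expression text_inputs[0] + [" "] + text_inputs[0] raises TypeError in the string case. A small sketch of the corrected construction (build_infer_text is a hypothetical name for illustration):

    def build_infer_text(sample_text):
        # Duplicate the training sample's text with a pause token in between,
        # handling both tokenizer output types:
        # - char/pinyin tokenizers yield a list of tokens, e.g. ["h", "i"]
        # - a raw-text tokenizer yields a plain string, e.g. "hi"
        sep = [" "] if isinstance(sample_text, list) else " "
        return [sample_text + sep + sample_text]


    assert build_infer_text(["h", "i"]) == [["h", "i", " ", "h", "i"]]
    assert build_infer_text("hi") == ["hi hi"]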