Gregniuki committed on
Commit
3facf82
1 Parent(s): 8ec1a00

Upload 6 files

Files changed (6)
  1. cog.py +180 -0
  2. packages.txt +1 -0
  3. requirements.txt +23 -0
  4. test_infer_batch.py +202 -0
  5. test_infer_batch.sh +13 -0
  6. test_infer_single.py +162 -0
cog.py ADDED
@@ -0,0 +1,180 @@
+ # Prediction interface for Cog ⚙️
+ # https://cog.run/python
+
+ from cog import BasePredictor, Input, Path
+
+ import os
+ import re
+ import torch
+ import torchaudio
+ import numpy as np
+ import tempfile
+ import gradio as gr  # used for gr.Info / gr.Warning / gr.Error progress messages below
+ from einops import rearrange
+ from ema_pytorch import EMA
+ from vocos import Vocos
+ from pydub import AudioSegment
+ from model import CFM, UNetT, DiT, MMDiT
+ from cached_path import cached_path
+ from model.utils import (
+     get_tokenizer,
+     convert_char_to_pinyin,
+     save_spectrogram,
+ )
+ from transformers import pipeline
+ import librosa
+
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ target_rms = 0.1
+ nfe_step = 32  # 16, 32
+ cfg_strength = 2.0
+ ode_method = 'euler'
+ sway_sampling_coef = -1.0
+ speed = 1.0
+ # fix_duration = 27  # None or float (duration in seconds)
+ fix_duration = None
+
+
+ class Predictor(BasePredictor):
+     def load_model(self, exp_name, model_cls, model_cfg, ckpt_step):
+         # Download the checkpoint from the Hugging Face Hub and build the CFM model.
+         checkpoint = torch.load(str(cached_path(f"hf://SWivid/F5-TTS/{exp_name}/model_{ckpt_step}.pt")), map_location=device)
+         vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
+         model = CFM(
+             transformer=model_cls(
+                 **model_cfg,
+                 text_num_embeds=vocab_size,
+                 mel_dim=n_mel_channels,
+             ),
+             mel_spec_kwargs=dict(
+                 target_sample_rate=target_sample_rate,
+                 n_mel_channels=n_mel_channels,
+                 hop_length=hop_length,
+             ),
+             odeint_kwargs=dict(
+                 method=ode_method,
+             ),
+             vocab_char_map=vocab_char_map,
+         ).to(device)
+
+         ema_model = EMA(model, include_online_model=False).to(device)
+         ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+         ema_model.copy_params_from_ema_to_model()
+
+         return ema_model, model
+
+     def setup(self) -> None:
+         """Load the model into memory to make running multiple predictions efficient"""
+         print("Loading Whisper model...")
+         self.pipe = pipeline(
+             "automatic-speech-recognition",
+             model="openai/whisper-large-v3-turbo",
+             torch_dtype=torch.float16,
+             device=device,
+         )
+         print("Loading F5-TTS model...")
+
+         F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+         self.F5TTS_ema_model, self.F5TTS_base_model = self.load_model("F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+
+     def predict(
+         self,
+         gen_text: str = Input(description="Text to generate"),
+         ref_audio_orig: Path = Input(description="Reference audio"),
+         ref_text: str = Input(description="Reference text (leave empty to transcribe the reference audio)", default=""),
+         remove_silence: bool = Input(description="Remove silences", default=True),
+     ) -> Path:
+         """Run a single prediction on the model"""
+         print(gen_text)
+         if len(gen_text) > 200:
+             raise gr.Error("Please keep your text under 200 chars.")
+         gr.Info("Converting audio...")
+         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+             aseg = AudioSegment.from_file(ref_audio_orig)
+             audio_duration = len(aseg)  # milliseconds
+             if audio_duration > 15000:
+                 gr.Warning("Audio is over 15s, clipping to only first 15s.")
+                 aseg = aseg[:15000]
+             aseg.export(f.name, format="wav")
+             ref_audio = f.name
+         ema_model = self.F5TTS_ema_model
+         base_model = self.F5TTS_base_model
+
+         if not ref_text.strip():
+             gr.Info("No reference text provided, transcribing reference audio...")
+             ref_text = self.pipe(
+                 ref_audio,
+                 chunk_length_s=30,
+                 batch_size=128,
+                 generate_kwargs={"task": "transcribe"},
+                 return_timestamps=False,
+             )['text'].strip()
+             gr.Info("Finished transcription")
+         else:
+             gr.Info("Using custom reference text...")
+         audio, sr = torchaudio.load(ref_audio)
+
+         rms = torch.sqrt(torch.mean(torch.square(audio)))
+         if rms < target_rms:
+             audio = audio * target_rms / rms
+         if sr != target_sample_rate:
+             resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+             audio = resampler(audio)
+         audio = audio.to(device)
+
+         # Prepare the text
+         text_list = [ref_text + gen_text]
+         final_text_list = convert_char_to_pinyin(text_list)
+
+         # Calculate duration
+         ref_audio_len = audio.shape[-1] // hop_length
+         # if fix_duration is not None:
+         #     duration = int(fix_duration * target_sample_rate / hop_length)
+         # else:
+         zh_pause_punc = r"。,、;:?!"
+         ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
+         gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
+         duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+
+         # Inference
+         gr.Info("Generating audio using F5-TTS")
+         with torch.inference_mode():
+             generated, _ = base_model.sample(
+                 cond=audio,
+                 text=final_text_list,
+                 duration=duration,
+                 steps=nfe_step,
+                 cfg_strength=cfg_strength,
+                 sway_sampling_coef=sway_sampling_coef,
+             )
+
+         generated = generated[:, ref_audio_len:, :]
+         generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
+         gr.Info("Running vocoder")
+         vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+         generated_wave = vocos.decode(generated_mel_spec.cpu())
+         if rms < target_rms:
+             generated_wave = generated_wave * rms / target_rms
+
+         # wav -> numpy
+         generated_wave = generated_wave.squeeze().cpu().numpy()
+
+         if remove_silence:
+             gr.Info("Removing audio silences... This may take a moment")
+             non_silent_intervals = librosa.effects.split(generated_wave, top_db=30)
+             non_silent_wave = np.array([])
+             for interval in non_silent_intervals:
+                 start, end = interval
+                 non_silent_wave = np.concatenate([non_silent_wave, generated_wave[start:end]])
+             generated_wave = non_silent_wave
+
+         # Save the output wav to a temporary file and return it
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
+             wav_path = tmp_wav.name
+             torchaudio.save(wav_path, torch.tensor(generated_wave, dtype=torch.float32).unsqueeze(0), target_sample_rate)
+
+         return Path(wav_path)
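Note: the remove_silence branch above is plain librosa interval splitting followed by concatenation. A minimal standalone sketch of that step, using a synthetic signal (the 24 kHz rate and top_db=30 threshold come from the code above; everything else is illustrative), would look like this:

    import numpy as np
    import librosa

    sr = 24000  # same target sample rate as cog.py
    # Illustrative signal: 0.5 s of tone, 0.5 s of silence, 0.5 s of tone.
    t = np.linspace(0, 0.5, sr // 2, endpoint=False)
    tone = (0.1 * np.sin(2 * np.pi * 220 * t)).astype(np.float32)
    wave = np.concatenate([tone, np.zeros(sr // 2, dtype=np.float32), tone])

    # Same logic as the remove_silence branch: keep only the non-silent intervals.
    intervals = librosa.effects.split(wave, top_db=30)
    trimmed = np.concatenate([wave[start:end] for start, end in intervals])
    print(len(wave), len(trimmed))  # the trimmed wave drops the silent middle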
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ accelerate>=0.33.0
+ cached_path
+ click
+ datasets
+ einops>=0.8.0
+ einx>=0.3.0
+ ema_pytorch>=0.5.2
+ gradio
+ jieba
+ librosa
+ matplotlib
+ numpy<=1.26.4
+ pydub
+ pypinyin
+ safetensors
+ soundfile
+ tomli
+ torchdiffeq
+ tqdm>=4.65.0
+ transformers
+ vocos
+ wandb
+ x_transformers>=1.31.14
test_infer_batch.py ADDED
@@ -0,0 +1,202 @@
+ import os
+ import time
+ import random
+ from tqdm import tqdm
+ import argparse
+
+ import torch
+ import torchaudio
+ from accelerate import Accelerator
+ from einops import rearrange
+ from ema_pytorch import EMA
+ from vocos import Vocos
+
+ from model import CFM, UNetT, DiT
+ from model.utils import (
+     get_tokenizer,
+     get_seedtts_testset_metainfo,
+     get_librispeech_test_clean_metainfo,
+     get_inference_prompt,
+ )
+
+ accelerator = Accelerator()
+ device = f"cuda:{accelerator.process_index}"
+
+
+ # --------------------- Dataset Settings -------------------- #
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ target_rms = 0.1
+
+ tokenizer = "pinyin"
+
+
+ # ---------------------- infer setting ---------------------- #
+
+ parser = argparse.ArgumentParser(description="batch inference")
+
+ parser.add_argument('-s', '--seed', default=None, type=int)
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
+ parser.add_argument('-n', '--expname', required=True)
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
+
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
+ parser.add_argument('-o', '--odemethod', default="euler")
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
+
+ parser.add_argument('-t', '--testset', required=True)
+
+ args = parser.parse_args()
+
+
+ seed = args.seed
+ dataset_name = args.dataset
+ exp_name = args.expname
+ ckpt_step = args.ckptstep
+ checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+
+ nfe_step = args.nfestep
+ ode_method = args.odemethod
+ sway_sampling_coef = args.swaysampling
+
+ testset = args.testset
+
+
+ infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
+ cfg_strength = 2.
+ speed = 1.
+ use_truth_duration = False
+ no_ref_audio = False
+
+
+ if exp_name == "F5TTS_Base":
+     model_cls = DiT
+     model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+
+ elif exp_name == "E2TTS_Base":
+     model_cls = UNetT
+     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+
+
+ if testset == "ls_pc_test_clean":
+     metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
+     librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
+     metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
+
+ elif testset == "seedtts_test_zh":
+     metalst = "data/seedtts_testset/zh/meta.lst"
+     metainfo = get_seedtts_testset_metainfo(metalst)
+
+ elif testset == "seedtts_test_en":
+     metalst = "data/seedtts_testset/en/meta.lst"
+     metainfo = get_seedtts_testset_metainfo(metalst)
+
+
+ # path to save generated wavs
+ if seed is None: seed = random.randint(-10000, 10000)
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
+     f"seed{seed}_{ode_method}_nfe{nfe_step}" \
+     f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
+     f"_cfg{cfg_strength}_speed{speed}" \
+     f"{'_gt-dur' if use_truth_duration else ''}" \
+     f"{'_no-ref-audio' if no_ref_audio else ''}"
+
+
+ # -------------------------------------------------#
+
+ use_ema = True
+
+ prompts_all = get_inference_prompt(
+     metainfo,
+     speed = speed,
+     tokenizer = tokenizer,
+     target_sample_rate = target_sample_rate,
+     n_mel_channels = n_mel_channels,
+     hop_length = hop_length,
+     target_rms = target_rms,
+     use_truth_duration = use_truth_duration,
+     infer_batch_size = infer_batch_size,
+ )
+
+ # Vocoder model
+ local = False
+ if local:
+     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+     vocos.load_state_dict(state_dict)
+     vocos.eval()
+ else:
+     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+ # Tokenizer
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+ # Model
+ model = CFM(
+     transformer = model_cls(
+         **model_cfg,
+         text_num_embeds = vocab_size,
+         mel_dim = n_mel_channels
+     ),
+     mel_spec_kwargs = dict(
+         target_sample_rate = target_sample_rate,
+         n_mel_channels = n_mel_channels,
+         hop_length = hop_length,
+     ),
+     odeint_kwargs = dict(
+         method = ode_method,
+     ),
+     vocab_char_map = vocab_char_map,
+ ).to(device)
+
+ if use_ema:
+     ema_model = EMA(model, include_online_model = False).to(device)
+     ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+     ema_model.copy_params_from_ema_to_model()
+ else:
+     model.load_state_dict(checkpoint['model_state_dict'])
+
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
+     os.makedirs(output_dir)
+
+ # start batch inference
+ accelerator.wait_for_everyone()
+ start = time.time()
+
+ with accelerator.split_between_processes(prompts_all) as prompts:
+
+     for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
+         utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
+         ref_mels = ref_mels.to(device)
+         ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
+         total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
+
+         # Inference
+         with torch.inference_mode():
+             generated, _ = model.sample(
+                 cond = ref_mels,
+                 text = final_text_list,
+                 duration = total_mel_lens,
+                 lens = ref_mel_lens,
+                 steps = nfe_step,
+                 cfg_strength = cfg_strength,
+                 sway_sampling_coef = sway_sampling_coef,
+                 no_ref_audio = no_ref_audio,
+                 seed = seed,
+             )
+         # Final result
+         for i, gen in enumerate(generated):
+             gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
+             gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
+             generated_wave = vocos.decode(gen_mel_spec.cpu())
+             if ref_rms_list[i] < target_rms:
+                 generated_wave = generated_wave * ref_rms_list[i] / target_rms
+             torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+     timediff = time.time() - start
+     print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
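Note: the per-process sharding above relies on Accelerate's split_between_processes context manager. A minimal standalone sketch of that pattern, with a toy list standing in for prompts_all, is roughly:

    from accelerate import Accelerator

    accelerator = Accelerator()
    items = list(range(8))  # toy stand-in for prompts_all

    # Each process receives its own slice; the work inside the block runs per process.
    with accelerator.split_between_processes(items) as shard:
        results = [x * x for x in shard]

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        print("all processes finished")

Run under accelerate launch, each rank handles only its shard, which is why the script only needs the wait_for_everyone barriers before and after the loop.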
test_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
+ #!/bin/bash
+
+ # e.g. F5-TTS, 16 NFE
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
+ accelerate launch test_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
+
+ # e.g. Vanilla E2 TTS, 32 NFE
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
+ accelerate launch test_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
+
+ # etc.
test_infer_single.py ADDED
@@ -0,0 +1,162 @@
+ import os
+ import re
+
+ import torch
+ import torchaudio
+ from einops import rearrange
+ from ema_pytorch import EMA
+ from vocos import Vocos
+
+ from model import CFM, UNetT, DiT, MMDiT
+ from model.utils import (
+     get_tokenizer,
+     convert_char_to_pinyin,
+     save_spectrogram,
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # --------------------- Dataset Settings -------------------- #
+
+ target_sample_rate = 24000
+ n_mel_channels = 100
+ hop_length = 256
+ target_rms = 0.1
+
+ tokenizer = "pinyin"
+ dataset_name = "Emilia_ZH_EN"
+
+
+ # ---------------------- infer setting ---------------------- #
+
+ seed = None  # int | None
+
+ exp_name = "F5TTS_Base"  # F5TTS_Base | E2TTS_Base
+ ckpt_step = 1200000
+
+ nfe_step = 32  # 16, 32
+ cfg_strength = 2.
+ ode_method = 'euler'  # euler | midpoint
+ sway_sampling_coef = -1.
+ speed = 1.
+ fix_duration = 27  # None (linear estimate; if code-switched, consider fixing it) | float (total duration in seconds, including ref audio)
+
+ if exp_name == "F5TTS_Base":
+     model_cls = DiT
+     model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
+
+ elif exp_name == "E2TTS_Base":
+     model_cls = UNetT
+     model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
+
+ checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)
+ output_dir = "tests"
+
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
+ ref_text = "Some call me nature, others call me mother nature."
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
+
+ # ref_audio = "tests/ref_audio/test_zh_1_ref_short.wav"
+ # ref_text = "对,这就是我,万人敬仰的太乙真人。"
+ # gen_text = "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:\"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?\""
+
+
+ # -------------------------------------------------#
+
+ use_ema = True
+
+ if not os.path.exists(output_dir):
+     os.makedirs(output_dir)
+
+ # Vocoder model
+ local = False
+ if local:
+     vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
+     vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
+     state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
+     vocos.load_state_dict(state_dict)
+     vocos.eval()
+ else:
+     vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
+
+ # Tokenizer
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
+
+ # Model
+ model = CFM(
+     transformer = model_cls(
+         **model_cfg,
+         text_num_embeds = vocab_size,
+         mel_dim = n_mel_channels
+     ),
+     mel_spec_kwargs = dict(
+         target_sample_rate = target_sample_rate,
+         n_mel_channels = n_mel_channels,
+         hop_length = hop_length,
+     ),
+     odeint_kwargs = dict(
+         method = ode_method,
+     ),
+     vocab_char_map = vocab_char_map,
+ ).to(device)
+
+ if use_ema:
+     ema_model = EMA(model, include_online_model = False).to(device)
+     ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+     ema_model.copy_params_from_ema_to_model()
+ else:
+     model.load_state_dict(checkpoint['model_state_dict'])
+
+ # Audio
+ audio, sr = torchaudio.load(ref_audio)
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
+ if rms < target_rms:
+     audio = audio * target_rms / rms
+ if sr != target_sample_rate:
+     resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+     audio = resampler(audio)
+ audio = audio.to(device)
+
+ # Text
+ text_list = [ref_text + gen_text]
+ if tokenizer == "pinyin":
+     final_text_list = convert_char_to_pinyin(text_list)
+ else:
+     final_text_list = [text_list]
+ print(f"text : {text_list}")
+ print(f"pinyin: {final_text_list}")
+
+ # Duration
+ ref_audio_len = audio.shape[-1] // hop_length
+ if fix_duration is not None:
+     duration = int(fix_duration * target_sample_rate / hop_length)
+ else:  # simple linear scale calculation
+     zh_pause_punc = r"。,、;:?!"
+     ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
+     gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
+     duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
+
+ # Inference
+ with torch.inference_mode():
+     generated, trajectory = model.sample(
+         cond = audio,
+         text = final_text_list,
+         duration = duration,
+         steps = nfe_step,
+         cfg_strength = cfg_strength,
+         sway_sampling_coef = sway_sampling_coef,
+         seed = seed,
+     )
+ print(f"Generated mel: {generated.shape}")
+
+ # Final result
+ generated = generated[:, ref_audio_len:, :]
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
+ if rms < target_rms:
+     generated_wave = generated_wave * rms / target_rms
+
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single.png")
+ torchaudio.save(f"{output_dir}/test_single.wav", generated_wave, target_sample_rate)
+ print(f"Generated wav: {generated_wave.shape}")
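Note: as a worked example of the linear duration estimate in the else branch above (sample rate, hop length, and speed come from the script; the clip length and character counts are hypothetical):

    target_sample_rate = 24000
    hop_length = 256
    speed = 1.0

    ref_seconds = 5.0                    # hypothetical reference clip length
    ref_audio_len = int(ref_seconds * target_sample_rate) // hop_length  # 468 mel frames
    ref_text_len = 52                    # hypothetical character count of ref_text (incl. pause punctuation)
    gen_text_len = 130                   # hypothetical character count of gen_text

    # Same formula as the script: extend the reference length by the text-length ratio.
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
    print(duration)                      # 468 + 1170 = 1638 total mel frames requested

That works out to about 12.5 s of new audio on top of the 5 s reference, scaling with how much longer gen_text is than ref_text.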