Bluebomber182 committed on
Commit 58a09c1 · verified · 1 Parent(s): e2083a1

Upload app.py

Files changed (1):
  1. app.py +368 -0
app.py ADDED
@@ -0,0 +1,368 @@
import os

import gradio as gr
import librosa
import numpy as np
import safetensors.torch
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2BertModel

from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from models.codec.kmeans.repcodec_model import RepCodec
from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S
from utils.util import load_config

# Feature extractor that turns raw 16 kHz audio into w2v-bert-2.0 input features.
processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def g2p_(text, language):
    # Chinese and English (including code-switched text) go through the
    # dedicated bilingual frontend; other languages use the generic g2p path.
    if language in ["zh", "en"]:
        return chn_eng_g2p(text)
    else:
        return g2p(text, sentence=None, language=language)

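# A hypothetical, non-executed sketch of the g2p_ contract assumed throughout
# this file: it returns a (phonemes, phone_ids) pair, and callers below keep
# only the integer id sequence via g2p_(...)[1]. For example:
#
#     _, phone_ids = g2p_("Hello world.", "en")
#     # phone_ids -> a list of integer phone ids, one per phoneme
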
def build_t2s_model(cfg, device):
    t2s_model = MaskGCT_T2S(cfg=cfg)
    t2s_model.eval()
    t2s_model.to(device)
    return t2s_model


def build_s2a_model(cfg, device):
    soundstorm_model = MaskGCT_S2A(cfg=cfg)
    soundstorm_model.eval()
    soundstorm_model.to(device)
    return soundstorm_model

def build_semantic_model(device):
    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
    semantic_model.eval()
    semantic_model.to(device)
    # Per-dimension mean/variance of the w2v-bert features, used to whiten the
    # hidden states before semantic quantization.
    stat_mean_var = torch.load("./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt")
    semantic_mean = stat_mean_var["mean"].to(device)
    semantic_std = torch.sqrt(stat_mean_var["var"]).to(device)
    return semantic_model, semantic_mean, semantic_std

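# Note: wav2vec2bert_stats.pt ships with the repo and presumably holds
# statistics precomputed over the training corpus; standardizing features with
# a fixed mean/std keeps the semantic codec's input distribution stable.
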
def build_semantic_codec(cfg, device):
    semantic_codec = RepCodec(cfg=cfg)
    semantic_codec.eval()
    semantic_codec.to(device)
    return semantic_codec


def build_acoustic_codec(cfg, device):
    codec_encoder = CodecEncoder(cfg=cfg.encoder)
    codec_decoder = CodecDecoder(cfg=cfg.decoder)
    codec_encoder.eval()
    codec_decoder.eval()
    codec_encoder.to(device)
    codec_decoder.to(device)
    return codec_encoder, codec_decoder

@torch.no_grad()
def extract_features(speech, processor):
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
    input_features = inputs["input_features"][0]
    attention_mask = inputs["attention_mask"][0]
    return input_features, attention_mask

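# A rough illustration of the shapes involved (an assumption based on the
# w2v-bert feature extractor, which frames audio at roughly 50 frames/second):
#
#     speech, _ = librosa.load("prompt.wav", sr=16000)   # 1-D float array
#     feats, mask = extract_features(speech, processor)
#     # feats: (num_frames, feature_dim); mask: (num_frames,)
#     # Callers unsqueeze(0) both to add the batch dimension.
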
@torch.no_grad()
def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):
    vq_emb = semantic_model(
        input_features=input_features,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    feat = vq_emb.hidden_states[17]  # (B, T, C)
    feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)

    semantic_code, rec_feat = semantic_codec.quantize(feat)  # (B, T)
    return semantic_code, rec_feat

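# The 17th hidden layer of w2v-bert-2.0 serves as the "semantic" speech
# representation; RepCodec then vector-quantizes the whitened features into a
# discrete token sequence, one token per frame. semantic_model and
# semantic_codec are the module-level instances created by load_models() below.
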
@torch.no_grad()
def extract_acoustic_code(speech):
    vq_emb = codec_encoder(speech.unsqueeze(1))
    _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)
    acoustic_code = vq.permute(1, 2, 0)
    return acoustic_code

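# The acoustic codec is a residual-VQ codec: the quantizer returns one code
# index per frame per quantizer stage. Judging from the permute above and the
# vq2emb calls below, vq is laid out as (num_quantizers, B, T) and
# acoustic_code as (B, T, num_quantizers).
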
@torch.no_grad()
def text2semantic(
    device,
    prompt_speech,
    prompt_text,
    prompt_language,
    target_text,
    target_language,
    target_len=None,
    n_timesteps=50,
    cfg=2.5,
    rescale_cfg=0.75,
):
    # Phonemize the prompt and target text; keep only the integer phone ids.
    prompt_phone_id = g2p_(prompt_text, prompt_language)[1]
    target_phone_id = g2p_(target_text, target_language)[1]

    # If no target duration is given, assume the target is spoken at the same
    # phones-per-second rate as the prompt; semantic codes run at 50 frames/s.
    if target_len is None:
        target_len = int(
            (len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id))
            / 16000
            * 50
        )
    else:
        target_len = int(target_len * 50)

    prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device)
    target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device)
    phone_id = torch.cat([prompt_phone_id, target_phone_id])

    # The prompt's semantic tokens condition the T2S diffusion model.
    input_features, attention_mask = extract_features(prompt_speech, processor)
    input_features = input_features.unsqueeze(0).to(device)
    attention_mask = attention_mask.unsqueeze(0).to(device)
    semantic_code, _ = extract_semantic_code(
        semantic_mean, semantic_std, input_features, attention_mask
    )

    predict_semantic = t2s_model.reverse_diffusion(
        semantic_code,
        target_len,
        phone_id.unsqueeze(0),
        n_timesteps=n_timesteps,
        cfg=cfg,
        rescale_cfg=rescale_cfg,
    )

    combine_semantic_code = torch.cat([semantic_code, predict_semantic], dim=-1)
    prompt_semantic_code = semantic_code

    return combine_semantic_code, prompt_semantic_code

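# Worked example of the duration heuristic: for a 3-second prompt
# (48000 samples at 16 kHz) with 20 prompt phones and 30 target phones,
#
#     target_len = int((48000 * 30 / 20) / 16000 * 50)
#                = int(4.5 * 50) = 225 semantic frames (4.5 seconds),
#
# i.e. the target is assumed to be spoken at the prompt's phone rate.
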
@torch.no_grad()
def semantic2acoustic(
    device,
    combine_semantic_code,
    acoustic_code,
    n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    cfg=2.5,
    rescale_cfg=0.75,
):
    semantic_code = combine_semantic_code

    # Stage 1: predict the first (coarsest) codec layer only.
    cond = s2a_model_1layer.cond_emb(semantic_code)
    prompt = acoustic_code[:, :, :]
    predict_1layer = s2a_model_1layer.reverse_diffusion(
        cond=cond,
        prompt=prompt,
        temp=1.5,
        filter_thres=0.98,
        n_timesteps=n_timesteps[:1],
        cfg=cfg,
        rescale_cfg=rescale_cfg,
    )

    # Stage 2: predict all codec layers, with the stage-1 output fixed as
    # ground truth for the first layer.
    cond = s2a_model_full.cond_emb(semantic_code)
    prompt = acoustic_code[:, :, :]
    predict_full = s2a_model_full.reverse_diffusion(
        cond=cond,
        prompt=prompt,
        temp=1.5,
        filter_thres=0.98,
        n_timesteps=n_timesteps,
        cfg=cfg,
        rescale_cfg=rescale_cfg,
        gt_code=predict_1layer,
    )

    # Decode the predicted codes (and, separately, the prompt codes) to audio.
    vq_emb = codec_decoder.vq2emb(predict_full.permute(2, 0, 1), n_quantizers=12)
    recovered_audio = codec_decoder(vq_emb)
    prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2, 0, 1), n_quantizers=12)
    recovered_prompt_audio = codec_decoder(prompt_vq_emb)
    recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
    recovered_audio = recovered_audio[0][0].cpu().numpy()
    combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])

    return combine_audio, recovered_audio

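# Design note: n_timesteps lists one entry per codec layer. The first layers
# carry most of the signal, so they get many refinement steps (25, 10), while
# the residual layers are filled in a single step each, matching the
# coarse-to-fine schedule of SoundStorm-style iterative decoding.
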
# Build the models and load their checkpoints.
def load_models():
    cfg_path = "./models/tts/maskgct/config/maskgct.json"

    cfg = load_config(cfg_path)
    semantic_model, semantic_mean, semantic_std = build_semantic_model(device)
    semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)
    codec_encoder, codec_decoder = build_acoustic_codec(
        cfg.model.acoustic_codec, device
    )
    t2s_model = build_t2s_model(cfg.model.t2s_model, device)
    s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
    s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)

    # Download checkpoints from the amphion/MaskGCT repo on the Hugging Face Hub.
    semantic_code_ckpt = hf_hub_download(
        "amphion/MaskGCT", filename="semantic_codec/model.safetensors"
    )
    codec_encoder_ckpt = hf_hub_download(
        "amphion/MaskGCT", filename="acoustic_codec/model.safetensors"
    )
    codec_decoder_ckpt = hf_hub_download(
        "amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors"
    )
    t2s_model_ckpt = hf_hub_download(
        "amphion/MaskGCT", filename="t2s_model/model.safetensors"
    )
    s2a_1layer_ckpt = hf_hub_download(
        "amphion/MaskGCT", filename="s2a_model/s2a_model_1layer/model.safetensors"
    )
    s2a_full_ckpt = hf_hub_download(
        "amphion/MaskGCT", filename="s2a_model/s2a_model_full/model.safetensors"
    )

    safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
    safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
    safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
    safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
    safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
    safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)

    return (
        semantic_model,
        semantic_mean,
        semantic_std,
        semantic_codec,
        codec_encoder,
        codec_decoder,
        t2s_model,
        s2a_model_1layer,
        s2a_model_full,
    )

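# hf_hub_download caches files under the local Hugging Face cache, so the
# checkpoints are fetched once on first launch and reused afterwards.
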
@torch.no_grad()
def maskgct_inference(
    prompt_speech_path,
    prompt_text,
    target_text,
    language="en",
    target_language="en",
    target_len=None,
    n_timesteps=25,
    cfg=2.5,
    rescale_cfg=0.75,
    n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    cfg_s2a=2.5,
    rescale_cfg_s2a=0.75,
    device=device,  # default to the globally selected device
):
    # The semantic encoder expects 16 kHz audio; the acoustic codec runs at 24 kHz.
    speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
    speech = librosa.load(prompt_speech_path, sr=24000)[0]

    combine_semantic_code, _ = text2semantic(
        device,
        speech_16k,
        prompt_text,
        language,
        target_text,
        target_language,
        target_len,
        n_timesteps,
        cfg,
        rescale_cfg,
    )
    acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))
    _, recovered_audio = semantic2acoustic(
        device,
        combine_semantic_code,
        acoustic_code,
        n_timesteps=n_timesteps_s2a,
        cfg=cfg_s2a,
        rescale_cfg=rescale_cfg_s2a,
    )

    return recovered_audio

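# End-to-end flow: text plus the prompt transcript become phone ids; T2S
# diffusion predicts target semantic tokens conditioned on the prompt's
# semantic tokens; S2A then predicts acoustic codec tokens conditioned on those
# semantic tokens and the prompt's acoustic tokens, and the codec decoder
# renders 24 kHz audio.
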
@torch.no_grad()
def inference(
    prompt_wav,
    prompt_text,
    target_text,
    target_len,
    n_timesteps,
    language,
    target_language,
):
    save_path = "./output/output.wav"
    os.makedirs("./output", exist_ok=True)
    recovered_audio = maskgct_inference(
        prompt_wav,
        prompt_text,
        target_text,
        language,
        target_language,
        target_len=target_len,
        n_timesteps=int(n_timesteps),
        device=device,
    )
    sf.write(save_path, recovered_audio, 24000)
    return save_path

# Load models once at startup.
(
    semantic_model,
    semantic_mean,
    semantic_std,
    semantic_codec,
    codec_encoder,
    codec_decoder,
    t2s_model,
    s2a_model_1layer,
    s2a_model_full,
) = load_models()

# Language list
language_list = ["en", "zh", "ja", "ko", "fr", "de"]

# Gradio interface
iface = gr.Interface(
    fn=inference,
    inputs=[
        gr.Audio(label="Upload Prompt Wav", type="filepath"),
        gr.Textbox(label="Prompt Text"),
        gr.Textbox(label="Target Text"),
        gr.Number(label="Target Duration (in seconds)", value=None),
        gr.Slider(
            label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
        ),
        gr.Dropdown(label="Language", choices=language_list, value="en"),
        gr.Dropdown(label="Target Language", choices=language_list, value="en"),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    title="MaskGCT TTS Demo",
    description="Generate speech from text using the MaskGCT model.",
)

# Launch the interface; allow Gradio to serve the generated files from ./output.
iface.launch(allowed_paths=["./output"])
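
# For programmatic (non-UI) use, the pipeline can be driven directly instead of
# calling iface.launch(); a minimal sketch, assuming a prompt.wav and its
# transcript exist (filenames are hypothetical):
#
#     audio = maskgct_inference(
#         "prompt.wav", "transcript of the prompt", "Text to synthesize.",
#         device=device,
#     )
#     sf.write("out.wav", audio, 24000)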