Bluebomber182 committed on
Commit 5ac5977 · verified · 1 Parent(s): 58a09c1

Upload app-oct-27.py

Files changed (1)
  1. app-oct-27.py +417 -0
app-oct-27.py ADDED
@@ -0,0 +1,417 @@
+ import spaces
+ import accelerate
+ import gradio as gr
+ import torch
+ import safetensors.torch
+ from huggingface_hub import hf_hub_download
+ import soundfile as sf
+ import os
+
+ import numpy as np
+ import librosa
+ from models.codec.kmeans.repcodec_model import RepCodec
+ from models.tts.maskgct.maskgct_s2a import MaskGCT_S2A
+ from models.tts.maskgct.maskgct_t2s import MaskGCT_T2S
+ from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
+ from transformers import Wav2Vec2BertModel
+ from utils.util import load_config
+ from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
+
+ from transformers import SeamlessM4TFeatureExtractor
+ import py3langid as langid
+
+
+ processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ whisper_model = None
+ output_file_name_idx = 0
+
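+ # Pipeline overview: the reference wav is transcribed with Whisper to obtain the
+ # prompt text, prompt and target text are converted to phone IDs (g2p), the T2S
+ # model generates semantic tokens, the S2A models generate acoustic codec tokens,
+ # and the codec decoder reconstructs the 24 kHz waveform.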
+ def detect_text_language(text):
+     return langid.classify(text)[0]
+
+ def detect_speech_language(speech_file):
+     import whisper
+     global whisper_model
+     if whisper_model is None:
+         whisper_model = whisper.load_model("turbo")
+     # load audio and pad/trim it to fit 30 seconds
+     audio = whisper.load_audio(speech_file)
+     audio = whisper.pad_or_trim(audio)
+
+     # make log-Mel spectrogram and move to the same device as the model
+     mel = whisper.log_mel_spectrogram(audio, n_mels=128).to(whisper_model.device)
+
+     # detect the spoken language
+     _, probs = whisper_model.detect_language(mel)
+     return max(probs, key=probs.get)
+
+
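+ # Transcribe the prompt audio with Whisper and keep roughly the first 4 seconds of
+ # transcript (plus its end timestamp) as the short prompt used for voice cloning.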
+ @torch.no_grad()
+ def get_prompt_text(speech_16k, language):
+     full_prompt_text = ""
+     short_prompt_text = ""
+     short_prompt_end_ts = 0.0
+
+     import whisper
+     global whisper_model
+     if whisper_model is None:
+         whisper_model = whisper.load_model("turbo")
+     asr_result = whisper_model.transcribe(speech_16k, language=language)
+     full_prompt_text = asr_result["text"]  # whisper asr result
+     # text = asr_result["segments"][0]["text"]  # whisperx asr result
+     short_prompt_text = ""
+     short_prompt_end_ts = 0.0
+     for segment in asr_result["segments"]:
+         short_prompt_text = short_prompt_text + segment['text']
+         short_prompt_end_ts = segment['end']
+         if short_prompt_end_ts >= 4:
+             break
+     return full_prompt_text, short_prompt_text, short_prompt_end_ts
+
+
+ def g2p_(text, language):
+     if language in ["zh", "en"]:
+         return chn_eng_g2p(text)
+     else:
+         return g2p(text, sentence=None, language=language)
+
+
+ def build_t2s_model(cfg, device):
+     t2s_model = MaskGCT_T2S(cfg=cfg)
+     t2s_model.eval()
+     t2s_model.to(device)
+     return t2s_model
+
+
+ def build_s2a_model(cfg, device):
+     soundstorm_model = MaskGCT_S2A(cfg=cfg)
+     soundstorm_model.eval()
+     soundstorm_model.to(device)
+     return soundstorm_model
+
+
+ def build_semantic_model(device):
+     semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+     semantic_model.eval()
+     semantic_model.to(device)
+     stat_mean_var = torch.load("./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt")
+     semantic_mean = stat_mean_var["mean"]
+     semantic_std = torch.sqrt(stat_mean_var["var"])
+     semantic_mean = semantic_mean.to(device)
+     semantic_std = semantic_std.to(device)
+     return semantic_model, semantic_mean, semantic_std
+
+
+ def build_semantic_codec(cfg, device):
+     semantic_codec = RepCodec(cfg=cfg)
+     semantic_codec.eval()
+     semantic_codec.to(device)
+     return semantic_codec
+
+
+ def build_acoustic_codec(cfg, device):
+     codec_encoder = CodecEncoder(cfg=cfg.encoder)
+     codec_decoder = CodecDecoder(cfg=cfg.decoder)
+     codec_encoder.eval()
+     codec_decoder.eval()
+     codec_encoder.to(device)
+     codec_decoder.to(device)
+     return codec_encoder, codec_decoder
+
+
+ @torch.no_grad()
+ def extract_features(speech, processor):
+     inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
+     input_features = inputs["input_features"][0]
+     attention_mask = inputs["attention_mask"][0]
+     return input_features, attention_mask
+
+
+ @torch.no_grad()
+ def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):
+     vq_emb = semantic_model(
+         input_features=input_features,
+         attention_mask=attention_mask,
+         output_hidden_states=True,
+     )
+     feat = vq_emb.hidden_states[17]  # (B, T, C)
+     feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)
+
+     semantic_code, rec_feat = semantic_codec.quantize(feat)  # (B, T)
+     return semantic_code, rec_feat
+
+
+ @torch.no_grad()
+ def extract_acoustic_code(speech):
+     vq_emb = codec_encoder(speech.unsqueeze(1))
+     _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)
+     acoustic_code = vq.permute(1, 2, 0)
+     return acoustic_code
+
+
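+ # Text-to-semantic (T2S): concatenate prompt and target phone IDs, condition on the
+ # prompt's semantic tokens, and generate the target semantic tokens by iterative
+ # masked prediction (reverse_diffusion). The target length defaults to an estimate
+ # scaled from the prompt's speech rate.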
+ @torch.no_grad()
+ def text2semantic(
+     device,
+     prompt_speech,
+     prompt_text,
+     prompt_language,
+     target_text,
+     target_language,
+     target_len=None,
+     n_timesteps=50,
+     cfg=2.5,
+     rescale_cfg=0.75,
+ ):
+
+     prompt_phone_id = g2p_(prompt_text, prompt_language)[1]
+
+     target_phone_id = g2p_(target_text, target_language)[1]
+
+     if target_len is None or target_len < 0:
+         # no duration given: estimate it from the prompt's speech rate (50 semantic tokens per second)
+         target_len = int(
+             (len(prompt_speech) * len(target_phone_id) / len(prompt_phone_id))
+             / 16000
+             * 50
+         )
+     else:
+         target_len = int(target_len * 50)
+
+     prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long).to(device)
+     target_phone_id = torch.tensor(target_phone_id, dtype=torch.long).to(device)
+
+     phone_id = torch.cat([prompt_phone_id, target_phone_id])
+
+     input_features, attention_mask = extract_features(prompt_speech, processor)
+     input_features = input_features.unsqueeze(0).to(device)
+     attention_mask = attention_mask.unsqueeze(0).to(device)
+     semantic_code, _ = extract_semantic_code(
+         semantic_mean, semantic_std, input_features, attention_mask
+     )
+
+     predict_semantic = t2s_model.reverse_diffusion(
+         semantic_code[:, :],
+         target_len,
+         phone_id.unsqueeze(0),
+         n_timesteps=n_timesteps,
+         cfg=cfg,
+         rescale_cfg=rescale_cfg,
+     )
+
+     combine_semantic_code = torch.cat([semantic_code[:, :], predict_semantic], dim=-1)
+     prompt_semantic_code = semantic_code
+
+     return combine_semantic_code, prompt_semantic_code
+
+
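+ # Semantic-to-acoustic (S2A): the 1-layer model predicts the first codebook, then the
+ # full model fills in all 12 codebooks conditioned on it; the codec decoder converts
+ # the predicted codes (and the prompt codes) back to audio.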
+ @torch.no_grad()
+ def semantic2acoustic(
+     device,
+     combine_semantic_code,
+     acoustic_code,
+     n_timesteps=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+     cfg=2.5,
+     rescale_cfg=0.75,
+ ):
+
+     semantic_code = combine_semantic_code
+
+     cond = s2a_model_1layer.cond_emb(semantic_code)
+     prompt = acoustic_code[:, :, :]
+     predict_1layer = s2a_model_1layer.reverse_diffusion(
+         cond=cond,
+         prompt=prompt,
+         temp=1.5,
+         filter_thres=0.98,
+         n_timesteps=n_timesteps[:1],
+         cfg=cfg,
+         rescale_cfg=rescale_cfg,
+     )
+
+     cond = s2a_model_full.cond_emb(semantic_code)
+     prompt = acoustic_code[:, :, :]
+     predict_full = s2a_model_full.reverse_diffusion(
+         cond=cond,
+         prompt=prompt,
+         temp=1.5,
+         filter_thres=0.98,
+         n_timesteps=n_timesteps,
+         cfg=cfg,
+         rescale_cfg=rescale_cfg,
+         gt_code=predict_1layer,
+     )
+
+     vq_emb = codec_decoder.vq2emb(predict_full.permute(2, 0, 1), n_quantizers=12)
+     recovered_audio = codec_decoder(vq_emb)
+     prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2, 0, 1), n_quantizers=12)
+     recovered_prompt_audio = codec_decoder(prompt_vq_emb)
+     recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().numpy()
+     recovered_audio = recovered_audio[0][0].cpu().numpy()
+     combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
+
+     return combine_audio, recovered_audio
+
+
+ # Load the model and checkpoints
+ def load_models():
+     cfg_path = "./models/tts/maskgct/config/maskgct.json"
+
+     cfg = load_config(cfg_path)
+     semantic_model, semantic_mean, semantic_std = build_semantic_model(device)
+     semantic_codec = build_semantic_codec(cfg.model.semantic_codec, device)
+     codec_encoder, codec_decoder = build_acoustic_codec(
+         cfg.model.acoustic_codec, device
+     )
+     t2s_model = build_t2s_model(cfg.model.t2s_model, device)
+     s2a_model_1layer = build_s2a_model(cfg.model.s2a_model.s2a_1layer, device)
+     s2a_model_full = build_s2a_model(cfg.model.s2a_model.s2a_full, device)
+
+     # Download checkpoints
+     semantic_code_ckpt = hf_hub_download(
+         "amphion/MaskGCT", filename="semantic_codec/model.safetensors"
+     )
+     # codec_encoder_ckpt = hf_hub_download(
+     #     "amphion/MaskGCT", filename="acoustic_codec/model.safetensors"
+     # )
+     # codec_decoder_ckpt = hf_hub_download(
+     #     "amphion/MaskGCT", filename="acoustic_codec/model_1.safetensors"
+     # )
+     t2s_model_ckpt = hf_hub_download(
+         "amphion/MaskGCT", filename="t2s_model/model.safetensors"
+     )
+     s2a_1layer_ckpt = hf_hub_download(
+         "amphion/MaskGCT", filename="s2a_model/s2a_model_1layer/model.safetensors"
+     )
+     s2a_full_ckpt = hf_hub_download(
+         "amphion/MaskGCT", filename="s2a_model/s2a_model_full/model.safetensors"
+     )
+
+     safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
+     # safetensors.torch.load_model(codec_encoder, codec_encoder_ckpt)
+     # safetensors.torch.load_model(codec_decoder, codec_decoder_ckpt)
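+     # NOTE: the acoustic codec weights are loaded from local files rather than the
+     # Hub downloads commented out above; this assumes ./acoustic_codec/model.safetensors
+     # and ./acoustic_codec/model_1.safetensors exist alongside the app.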
+     accelerate.load_checkpoint_and_dispatch(codec_encoder, "./acoustic_codec/model.safetensors")
+     accelerate.load_checkpoint_and_dispatch(codec_decoder, "./acoustic_codec/model_1.safetensors")
+     safetensors.torch.load_model(t2s_model, t2s_model_ckpt)
+     safetensors.torch.load_model(s2a_model_1layer, s2a_1layer_ckpt)
+     safetensors.torch.load_model(s2a_model_full, s2a_full_ckpt)
+
+     return (
+         semantic_model,
+         semantic_mean,
+         semantic_std,
+         semantic_codec,
+         codec_encoder,
+         codec_decoder,
+         t2s_model,
+         s2a_model_1layer,
+         s2a_model_full,
+     )
+
+
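+ # End-to-end inference: load the prompt wav at 16 kHz (for Whisper/W2V-BERT) and at
+ # 24 kHz (for the acoustic codec), trim it to the ~4 s short prompt, run T2S then
+ # S2A, and return the generated 24 kHz waveform without the prompt audio.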
+ @torch.no_grad()
+ def maskgct_inference(
+     prompt_speech_path,
+     target_text,
+     target_len=None,
+     n_timesteps=25,
+     cfg=2.5,
+     rescale_cfg=0.75,
+     n_timesteps_s2a=[25, 10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+     cfg_s2a=2.5,
+     rescale_cfg_s2a=0.75,
+     device=torch.device("cuda:0"),
+ ):
+     speech_16k = librosa.load(prompt_speech_path, sr=16000)[0]
+     speech = librosa.load(prompt_speech_path, sr=24000)[0]
+
+     prompt_language = detect_speech_language(prompt_speech_path)
+     full_prompt_text, short_prompt_text, short_prompt_end_ts = get_prompt_text(
+         prompt_speech_path, prompt_language
+     )
+     # use the first 4+ seconds of the wav as the prompt in case the prompt wav is too long
+     speech = speech[0 : int(short_prompt_end_ts * 24000)]
+     speech_16k = speech_16k[0 : int(short_prompt_end_ts * 16000)]
+
+     target_language = detect_text_language(target_text)
+     combine_semantic_code, _ = text2semantic(
+         device,
+         speech_16k,
+         short_prompt_text,
+         prompt_language,
+         target_text,
+         target_language,
+         target_len,
+         n_timesteps,
+         cfg,
+         rescale_cfg,
+     )
+     acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))
+     _, recovered_audio = semantic2acoustic(
+         device,
+         combine_semantic_code,
+         acoustic_code,
+         n_timesteps=n_timesteps_s2a,
+         cfg=cfg_s2a,
+         rescale_cfg=rescale_cfg_s2a,
+     )
+
+     return recovered_audio
+
+
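+ # @spaces.GPU requests a GPU for the duration of this call when the app runs on a
+ # Hugging Face ZeroGPU Space.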
+ @spaces.GPU
+ def inference(
+     prompt_wav,
+     target_text,
+     target_len,
+     n_timesteps,
+ ):
+     global output_file_name_idx
+     save_path = f"./output/output_{output_file_name_idx}.wav"
+     os.makedirs("./output", exist_ok=True)
+     recovered_audio = maskgct_inference(
+         prompt_wav,
+         target_text,
+         target_len=target_len,
+         n_timesteps=int(n_timesteps),
+         device=device,
+     )
+     sf.write(save_path, recovered_audio, 24000)
+     output_file_name_idx = (output_file_name_idx + 1) % 10
+     return save_path
+
+ # Load models once
+ (
+     semantic_model,
+     semantic_mean,
+     semantic_std,
+     semantic_codec,
+     codec_encoder,
+     codec_decoder,
+     t2s_model,
+     s2a_model_1layer,
+     s2a_model_full,
+ ) = load_models()
+
+ # Language list
+ language_list = ["en", "zh", "ja", "ko", "fr", "de"]
+
+ # Gradio interface
+ iface = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Audio(label="Upload Prompt Wav", type="filepath"),
+         gr.Textbox(label="Target Text"),
+         gr.Number(
+             label="Target Duration (in seconds). If the target duration is less than 0, the system will estimate a duration.",
+             value=-1,
+         ),  # Removed 'optional=True'
+         gr.Slider(
+             label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
+         ),
+     ],
+     outputs=gr.Audio(label="Generated Audio"),
+     title="MaskGCT TTS Demo",
+     description="""
+     [![arXiv](https://img.shields.io/badge/arXiv-Paper-COLOR.svg)](https://arxiv.org/abs/2409.00750) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-model-yellow)](https://huggingface.co/amphion/maskgct) [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-demo-pink)](https://huggingface.co/spaces/amphion/maskgct) [![readme](https://img.shields.io/badge/README-Key%20Features-blue)](https://github.com/open-mmlab/Amphion/tree/main/models/tts/maskgct)
+     """
+ )
+
+ # Launch the interface
+ iface.launch(allowed_paths=["./output"])
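+ # Minimal usage sketch without the Gradio UI (hypothetical paths, not part of the
+ # original app); assumes a prompt wav at ./prompt.wav and the models loaded above:
+ # audio = maskgct_inference("./prompt.wav", "Hello, this is a test.", target_len=-1, device=device)
+ # sf.write("./output/manual_test.wav", audio, 24000)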