kiramayatu commited on
Commit
517f991
·
verified ·
1 Parent(s): 35c90fb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch import no_grad, LongTensor
5
+ import argparse
6
+ import commons
7
+ from mel_processing import spectrogram_torch
8
+ import utils
9
+ from models import SynthesizerTrn
10
+ import gradio as gr
11
+ import librosa
12
+ import webbrowser
13
+
14
+
15
+ from text import text_to_sequence, _clean_text
16
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
17
+ language_marks = {
18
+ "Japanese": "",
19
+ "日本語": "[JA]",
20
+ "简体中文": "[ZH]",
21
+ "English": "[EN]",
22
+ "Mix": "",
23
+ }
24
+ lang = ['日本語', '简体中文', 'English', 'Mix']
25
+ def get_text(text, hps, is_symbol):
26
+ text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
27
+ if hps.data.add_blank:
28
+ text_norm = commons.intersperse(text_norm, 0)
29
+ text_norm = LongTensor(text_norm)
30
+ return text_norm
31
+
32
+ def create_tts_fn(model, hps, speaker_ids):
33
+ def tts_fn(text, speaker, language, speed):
34
+ if language is not None:
35
+ text = language_marks[language] + text + language_marks[language]
36
+ speaker_id = speaker_ids[speaker]
37
+ stn_tst = get_text(text, hps, False)
38
+ with no_grad():
39
+ x_tst = stn_tst.unsqueeze(0).to(device)
40
+ x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
41
+ sid = LongTensor([speaker_id]).to(device)
42
+ audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
43
+ length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
44
+ del stn_tst, x_tst, x_tst_lengths, sid
45
+ return "Success", (hps.data.sampling_rate, audio)
46
+
47
+ return tts_fn
48
+
def create_vc_fn(model, hps, speaker_ids):
    """Build the Gradio voice-conversion callback bound to ``model``/``hps``.

    The returned function takes source/target speaker names plus a recorded
    or uploaded audio tuple ``(sampling_rate, np.ndarray)`` and returns
    ``(message, (sampling_rate, audio))`` — audio is None on input error.
    """
    def vc_fn(original_speaker, target_speaker, record_audio, upload_audio):
        # Prefer the microphone recording over the uploaded file.
        input_audio = record_audio if record_audio is not None else upload_audio
        if input_audio is None:
            return "You need to record or upload an audio", None
        sampling_rate, audio = input_audio
        original_speaker_id = speaker_ids[original_speaker]
        target_speaker_id = speaker_ids[target_speaker]

        # Gradio usually delivers int16 PCM; scale integer samples to [-1, 1].
        # Fix: np.iinfo raises ValueError on float dtypes, so only scale
        # integer input and just cast float input.
        if np.issubdtype(audio.dtype, np.integer):
            audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
        else:
            audio = audio.astype(np.float32)
        if len(audio.shape) > 1:
            # (samples, channels) -> mono; librosa expects (channels, samples).
            audio = librosa.to_mono(audio.transpose(1, 0))
        if sampling_rate != hps.data.sampling_rate:
            audio = librosa.resample(audio, orig_sr=sampling_rate,
                                     target_sr=hps.data.sampling_rate)
        with no_grad():
            y = torch.FloatTensor(audio)
            # Peak-normalize to 0.99. Fix: the original divided by 0.99
            # (peak ~1.01, clipping on int16 export) and divided by zero on
            # silent input; skip normalization when the signal is all zeros.
            peak = max(-y.min(), y.max())
            if peak > 0:
                y = y / peak * 0.99
            y = y.to(device)
            y = y.unsqueeze(0)
            spec = spectrogram_torch(y, hps.data.filter_length,
                                     hps.data.sampling_rate, hps.data.hop_length,
                                     hps.data.win_length,
                                     center=False).to(device)
            spec_lengths = LongTensor([spec.size(-1)]).to(device)
            sid_src = LongTensor([original_speaker_id]).to(device)
            sid_tgt = LongTensor([target_speaker_id]).to(device)
            audio = model.voice_conversion(spec, spec_lengths, sid_src=sid_src,
                                           sid_tgt=sid_tgt)[0][0, 0].data.cpu().float().numpy()
        del y, spec, spec_lengths, sid_src, sid_tgt
        return "Success", (hps.data.sampling_rate, audio)

    return vc_fn
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", default="./inference/G_latest.pth", help="directory to your fine-tuned model")
    parser.add_argument("--config_dir", default="./inference/finetune_speaker.json", help="directory to your model config file")
    # Fix: with default=False and no type, any explicit value — including
    # "--share False" — arrived as a non-empty (truthy) string. Parse common
    # boolean spellings; omitting the flag still yields False.
    parser.add_argument("--share", default=False,
                        type=lambda s: str(s).lower() in ("true", "1", "yes"),
                        help="make link public (used in colab)")

    args = parser.parse_args()
    hps = utils.get_hparams_from_file(args.config_dir)

    # Build the synthesizer from the fine-tune config and load the checkpoint.
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).to(device)
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.model_dir, net_g, None)
    # speaker name -> id mapping comes straight from the fine-tune config.
    speaker_ids = hps.speakers
    speakers = list(hps.speakers.keys())
    tts_fn = create_tts_fn(net_g, hps, speaker_ids)
    vc_fn = create_vc_fn(net_g, hps, speaker_ids)

    # ---------------- Gradio UI ----------------
    app = gr.Blocks()
    with app:
        with gr.Tab("Text-to-Speech"):
            with gr.Row():
                with gr.Column():
                    textbox = gr.TextArea(label="Text",
                                          placeholder="Type your sentence here",
                                          value="こんにちわ。", elem_id=f"tts-input")
                    # select character
                    char_dropdown = gr.Dropdown(choices=speakers, value=speakers[0], label='character')
                    language_dropdown = gr.Dropdown(choices=lang, value=lang[0], label='language')
                    duration_slider = gr.Slider(minimum=0.1, maximum=5, value=1, step=0.1,
                                                label='速度 Speed')
                with gr.Column():
                    text_output = gr.Textbox(label="Message")
                    audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn = gr.Button("Generate!")
                    btn.click(tts_fn,
                              inputs=[textbox, char_dropdown, language_dropdown, duration_slider,],
                              outputs=[text_output, audio_output], api_name="btn")
        with gr.Tab("Voice Conversion"):
            gr.Markdown("""
                        录制或上传声音,并选择要转换的音色。
            """)
            with gr.Column():
                record_audio = gr.Audio(label="record your voice", source="microphone")
                upload_audio = gr.Audio(label="or upload audio here", source="upload")
                source_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="source speaker")
                target_speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="target speaker")
            with gr.Column():
                message_box = gr.Textbox(label="Message")
                converted_audio = gr.Audio(label='converted audio')
            btn = gr.Button("Convert!")
            btn.click(vc_fn, inputs=[source_speaker, target_speaker, record_audio, upload_audio],
                      outputs=[message_box, converted_audio])
    # launch() blocks, so the browser tab is opened just beforehand; the page
    # loads once the local server comes up on the default port.
    webbrowser.open("http://127.0.0.1:7860")
    app.launch(share=args.share)
140
+
141
+
142
+
143
+