MZhao-LEGION
committed on
Commit 84fef35
1 Parent(s): f2c4c94
multilingual model!
Files changed:
- Data/TalkFlower_CNzh/config.json +0 -96
- app.py +1 -1
- config.yml +13 -13
- emo_gen.py +16 -23
- infer.py +18 -21
- presets.py +114 -55
- utils.py +70 -0
Data/TalkFlower_CNzh/config.json
DELETED
@@ -1,96 +0,0 @@
-{
-  "train": {
-    "log_interval": 200,
-    "eval_interval": 1000,
-    "seed": 42,
-    "epochs": 1000,
-    "learning_rate": 0.0002,
-    "betas": [
-      0.8,
-      0.99
-    ],
-    "eps": 1e-09,
-    "batch_size": 12,
-    "fp16_run": false,
-    "lr_decay": 0.99995,
-    "segment_size": 16384,
-    "init_lr_ratio": 1,
-    "warmup_epochs": 0,
-    "c_mel": 45,
-    "c_kl": 1.0,
-    "skip_optimizer": true
-  },
-  "data": {
-    "training_files": "filelists/train.list",
-    "validation_files": "filelists/val.list",
-    "max_wav_value": 32768.0,
-    "sampling_rate": 44100,
-    "filter_length": 2048,
-    "hop_length": 512,
-    "win_length": 2048,
-    "n_mel_channels": 128,
-    "mel_fmin": 0.0,
-    "mel_fmax": null,
-    "add_blank": true,
-    "n_speakers": 700,
-    "cleaned_text": true,
-    "spk2id": {
-      "TalkFlower_CNzh": 0
-    }
-  },
-  "model": {
-    "use_spk_conditioned_encoder": true,
-    "use_noise_scaled_mas": true,
-    "use_mel_posterior_encoder": false,
-    "use_duration_discriminator": true,
-    "inter_channels": 192,
-    "hidden_channels": 192,
-    "filter_channels": 768,
-    "n_heads": 2,
-    "n_layers": 6,
-    "kernel_size": 3,
-    "p_dropout": 0.1,
-    "resblock": "1",
-    "resblock_kernel_sizes": [
-      3,
-      7,
-      11
-    ],
-    "resblock_dilation_sizes": [
-      [
-        1,
-        3,
-        5
-      ],
-      [
-        1,
-        3,
-        5
-      ],
-      [
-        1,
-        3,
-        5
-      ]
-    ],
-    "upsample_rates": [
-      8,
-      8,
-      2,
-      2,
-      2
-    ],
-    "upsample_initial_channel": 512,
-    "upsample_kernel_sizes": [
-      16,
-      16,
-      8,
-      2,
-      2
-    ],
-    "n_layers_q": 3,
-    "use_spectral_norm": false,
-    "gin_channels": 256
-  },
-  "version": "2.0"
-}
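For reference, this deleted file is the standard Bert-VITS2 hyperparameter JSON; the rest of the repo reads such a file through utils.get_hparams_from_file. A minimal sketch, assuming the replacement copy lives at Data/config.json as config.yml below now specifies:

# Sketch only: loads the hyperparameter JSON the way the repo's utils does;
# Data/config.json is the path config.yml points at after this commit.
import utils

hps = utils.get_hparams_from_file("Data/config.json")
print(hps.data.sampling_rate)     # 44100 in the file deleted above
print(hps.model.hidden_channels)  # 192 in the file deleted above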
app.py
CHANGED
@@ -7,7 +7,7 @@ from presets import *
 with gr.Blocks(css=customCSS) as demo:
     exceed_flag = gr.State(value=False)
     tmp_string = gr.Textbox(value="", visible=False)
-    character_area = gr.HTML(get_character_html("
+    character_area = gr.HTML(get_character_html("你好呀!我现在支持多语言了呢!"), elem_id="character_area")
     with gr.Tab("Speak", elem_id="tab-speak"):
         speak_input = gr.Textbox(lines=1, label="Talking Flower will say:", elem_classes="wonder-card input_text", elem_id="speak_input")
         speak_button = gr.Button("Speak!", elem_id="speak_button", elem_classes="main-button wonder-card")
config.yml
CHANGED
@@ -4,7 +4,7 @@
 # 拟提供通用路径配置,统一存放数据,避免数据放得很乱
 # 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
 # 不填或者填空则路径为相对于项目根目录的路径
-dataset_path: "
+dataset_path: ""
 
 # 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
 mirror: ""
@@ -34,7 +34,7 @@ preprocess_text:
   # 验证集路径
   val_path: "filelists/val.list"
   # 配置文件路径
-  config_path: "Data/
+  config_path: "Data/config.json"
   # 每个speaker的验证集条数
   val_per_spk: 5
   # 验证集最大条数,多于的会被截断并放到训练集中
@@ -47,12 +47,12 @@ preprocess_text:
 # 注意, “:” 后需要加空格
 bert_gen:
   # 训练数据集配置文件路径
-  config_path: "Data/
+  config_path: "Data/config.json"
   # 并行数
   num_processes: 8
   # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
   # 该选项同时决定了get_bert_feature的默认设备
-  device: "
+  device: "cpu"
   # 使用多卡推理
   use_multi_device: false
 
@@ -60,11 +60,11 @@ bert_gen:
 # 注意, “:” 后需要加空格
 emo_gen:
   # 训练数据集配置文件路径
-  config_path: "Data/
+  config_path: "Data/config.json"
   # 并行数
   num_processes: 2
   # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
-  device: "
+  device: "cpu"
 
 # train 训练配置
 # 注意, “:” 后需要加空格
@@ -85,7 +85,7 @@ train_ms:
   # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
   model: "models"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "Data/config.json"
   # 训练使用的worker,不建议超过CPU核心数
   num_workers: 16
   # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
@@ -100,9 +100,9 @@ webui:
   # 推理设备
   device: "cpu"
   # 模型路径
-  model: "
+  model: "models/G_multilingual.pth"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "Data/config.json"
   # 端口号
   port: 7860
   # 是否公开部署,对外网开放
@@ -120,16 +120,16 @@ server:
   # 端口号
   port: 5000
   # 模型默认使用设备:但是当前并没有实现这个配置。
-  device: "
+  device: "cpu"
   # 需要加载的所有模型的配置
   # 注意,所有模型都必须正确配置model与config的路径,空路径会导致加载错误。
   models:
     - # 模型的路径
-      model: "models/
+      model: "models/G_multilingual.pth"
       # 模型config.json的路径
-      config: "
+      config: "Data/config.json"
       # 模型使用设备,若填写则会覆盖默认配置
-      device: "
+      device: "cpu"
       # 模型默认使用的语言
       language: "ZH"
       # 模型人物默认参数
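As the comments above state, leaving dataset_path empty makes every relative path resolve against the project root. A small illustration of that convention (the join behaviour is assumed from those comments, not read from config.py):

# Sketch only: how the config.yml paths resolve once dataset_path is "".
import os

dataset_path = ""  # value after this commit
print(os.path.join(dataset_path, "Data/config.json"))           # Data/config.json
print(os.path.join(dataset_path, "models/G_multilingual.pth"))  # models/G_multilingual.pth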
emo_gen.py
CHANGED
@@ -1,19 +1,21 @@
+import argparse
+import os
+from pathlib import Path
+
+import librosa
+import numpy as np
 import torch
 import torch.nn as nn
-from torch.utils.data import Dataset
-from 
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
 from transformers import Wav2Vec2Processor
 from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2Model,
     Wav2Vec2PreTrainedModel,
 )
-
-import numpy as np
-import argparse
-from config import config
+
 import utils
-import 
-from tqdm import tqdm
+from config import config
 
 
 class RegressionHead(nn.Module):
@@ -78,11 +80,6 @@ class AudioDataset(Dataset):
         return torch.from_numpy(processed_data)
 
 
-model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
-processor = Wav2Vec2Processor.from_pretrained(model_name)
-model = EmotionModel.from_pretrained(model_name)
-
-
 def process_func(
     x: np.ndarray,
     sampling_rate: int,
@@ -135,16 +132,12 @@ if __name__ == "__main__":
     device = config.bert_gen_config.device
 
     model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
-    [... 4 removed lines not legible in the source page ...]
-    )
-    model = (
-        EmotionModel.from_pretrained(model_name).to(device)
-        if model is None
-        else model.to(device)
-    )
+    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+    if not Path(model_name).joinpath("pytorch_model.bin").exists():
+        utils.download_emo_models(config.mirror, model_name, REPO_ID)
+
+    processor = Wav2Vec2Processor.from_pretrained(model_name)
+    model = EmotionModel.from_pretrained(model_name).to(device)
 
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
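The net effect of this hunk is that the wav2vec2 emotion model is no longer loaded at import time; it is fetched on demand and loaded only under __main__. A minimal sketch of that flow, assuming the repo modules are importable from the project root (the argument order of download_emo_models follows the call in the diff):

# Sketch only: mirrors the new lazy download-and-load logic in emo_gen.py.
from pathlib import Path

from transformers import Wav2Vec2Processor

import utils
from config import config
from emo_gen import EmotionModel

model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"

# Download the checkpoint only if it is not already present locally.
if not Path(model_name).joinpath("pytorch_model.bin").exists():
    utils.download_emo_models(config.mirror, model_name, REPO_ID)

device = config.bert_gen_config.device
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionModel.from_pretrained(model_name).to(device)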
infer.py
CHANGED
@@ -29,7 +29,7 @@ from oldVersion.V101.text import symbols as V101symbols
 from oldVersion import V111, V110, V101, V200
 
 # 当前版本信息
-latest_version = "2.
+latest_version = "2.1"
 
 # 版本兼容
 SynthesizerTrnMap = {
@@ -82,7 +82,7 @@ def get_net_g(model_path: str, version: str, device: str, hps):
     return net_g
 
 
-def get_text(text, reference_audio, emotion, language_str, hps, device):
+def get_text(text, language_str, hps, device):
     # 在此处实现当前版本的get_text
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -113,12 +113,6 @@ def get_text(text, reference_audio, emotion, language_str, hps, device):
     else:
         raise ValueError("language_str should be ZH, JP or EN")
 
-    emo = (
-        torch.from_numpy(get_emo(reference_audio))
-        if reference_audio
-        else torch.Tensor([emotion])
-    )
-
     assert bert.shape[-1] == len(
         phone
     ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
@@ -126,7 +120,16 @@ def get_text(text, reference_audio, emotion, language_str, hps, device):
     phone = torch.LongTensor(phone)
     tone = torch.LongTensor(tone)
     language = torch.LongTensor(language)
-    return bert, ja_bert, en_bert,
+    return bert, ja_bert, en_bert, phone, tone, language
+
+
+def get_emo_(reference_audio, emotion):
+    emo = (
+        torch.from_numpy(get_emo(reference_audio))
+        if reference_audio
+        else torch.Tensor([emotion])
+    )
+    return emo
 
 
 def infer(
@@ -191,9 +194,10 @@ def infer(
         device,
     )
     # 在此处实现当前版本的推理
-    bert, ja_bert, en_bert,
-        text,
+    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+        text, language, hps, device
     )
+    emo = get_emo_(reference_audio, emotion)
     if skip_start:
         phones = phones[1:]
         tones = tones[1:]
@@ -261,10 +265,8 @@ def infer_multilang(
     skip_start=False,
     skip_end=False,
 ):
-    bert, ja_bert, en_bert,
-
-    # text, language, hps, device
-    # )
+    bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
+    emo = get_emo_(reference_audio, emotion)
     for idx, (txt, lang) in enumerate(zip(text, language)):
         skip_start = (idx != 0) or (skip_start and idx == 0)
         skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
@@ -272,16 +274,14 @@ def infer_multilang(
             temp_bert,
             temp_ja_bert,
             temp_en_bert,
-            temp_emo,
             temp_phones,
             temp_tones,
             temp_lang_ids,
-        ) = get_text(txt,
+        ) = get_text(txt, lang, hps, device)
         if skip_start:
             temp_bert = temp_bert[:, 1:]
             temp_ja_bert = temp_ja_bert[:, 1:]
             temp_en_bert = temp_en_bert[:, 1:]
-            temp_emo = temp_emo[:, 1:]
             temp_phones = temp_phones[1:]
             temp_tones = temp_tones[1:]
             temp_lang_ids = temp_lang_ids[1:]
@@ -289,21 +289,18 @@ def infer_multilang(
             temp_bert = temp_bert[:, :-1]
             temp_ja_bert = temp_ja_bert[:, :-1]
             temp_en_bert = temp_en_bert[:, :-1]
-            temp_emo = temp_emo[:, :-1]
             temp_phones = temp_phones[:-1]
             temp_tones = temp_tones[:-1]
             temp_lang_ids = temp_lang_ids[:-1]
         bert.append(temp_bert)
         ja_bert.append(temp_ja_bert)
         en_bert.append(temp_en_bert)
-        emo.append(temp_emo)
         phones.append(temp_phones)
         tones.append(temp_tones)
         lang_ids.append(temp_lang_ids)
     bert = torch.concatenate(bert, dim=1)
     ja_bert = torch.concatenate(ja_bert, dim=1)
     en_bert = torch.concatenate(en_bert, dim=1)
-    emo = torch.concatenate(emo, dim=1)
     phones = torch.concatenate(phones, dim=0)
     tones = torch.concatenate(tones, dim=0)
     lang_ids = torch.concatenate(lang_ids, dim=0)
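With this change the emotion embedding is decoupled from get_text: text-side features come from get_text(text, language_str, hps, device) and the emotion vector from the new get_emo_ helper. A minimal sketch of the 2.1 call pattern, assuming the project's models and Data/config.json are in place:

# Sketch only: the call pattern infer() now uses, shown standalone.
import utils
from infer import get_text, get_emo_

hps = utils.get_hparams_from_file("Data/config.json")
device = "cpu"

bert, ja_bert, en_bert, phones, tones, lang_ids = get_text("你好呀!", "ZH", hps, device)
emo = get_emo_(None, 0)  # no reference audio -> falls back to torch.Tensor([0])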
presets.py
CHANGED
@@ -4,10 +4,11 @@ import numpy as np
 import torch
 import re_matching
 import utils
-from infer import infer, latest_version, get_net_g
+from infer import infer, latest_version, get_net_g, infer_multilang
 import gradio as gr
 from config import config
 from tools.webui import reload_javascript, get_character_html
+from tools.sentence import split_by_language
 
 logging.basicConfig(
     level=logging.INFO,
@@ -42,6 +43,7 @@ def speak_fn(
     interval_between_para=0.2,  # 段间间隔
     interval_between_sent=1,  # 句间间隔
 ):
+    audio_list = []
     while text.find("\n\n") != -1:
         text = text.replace("\n\n", "\n")
     if len(text) > 100:
@@ -54,58 +56,113 @@ def speak_fn(
         audio_value = "./assets/audios/overlength.wav"
         exceed_flag = not exceed_flag
     else:
-        [... 52 lines of the previous single-language generation block, not legible in the source page ...]
+        for idx, slice in enumerate(text.split("|")):
+            if slice == "":
+                continue
+            skip_start = idx != 0
+            skip_end = idx != len(text.split("|")) - 1
+            sentences_list = split_by_language(
+                slice, target_languages=["zh", "ja", "en"]
+            )
+            idx = 0
+            while idx < len(sentences_list):
+                text_to_generate = []
+                lang_to_generate = []
+                while True:
+                    content, lang = sentences_list[idx]
+                    temp_text = [content]
+                    lang = lang.upper()
+                    if lang == "JA":
+                        lang = "JP"
+                    if len(text_to_generate) > 0:
+                        text_to_generate[-1] += [temp_text.pop(0)]
+                        lang_to_generate[-1] += [lang]
+                    if len(temp_text) > 0:
+                        text_to_generate += [[i] for i in temp_text]
+                        lang_to_generate += [[lang]] * len(temp_text)
+                    if idx + 1 < len(sentences_list):
+                        idx += 1
+                    else:
+                        break
+                skip_start = (idx != 0) and skip_start
+                skip_end = (idx != len(sentences_list) - 1) and skip_end
+                print(text_to_generate, lang_to_generate)
+
+                with torch.no_grad():
+                    for i, piece in enumerate(text_to_generate):
+                        skip_start = (i != 0) and skip_start
+                        skip_end = (i != len(text_to_generate) - 1) and skip_end
+                        audio = infer_multilang(
+                            piece,
+                            reference_audio=reference_audio,
+                            emotion=emotion,
+                            sdp_ratio=sdp_ratio,
+                            noise_scale=noise_scale,
+                            noise_scale_w=noise_scale_w,
+                            length_scale=length_scale,
+                            sid=speaker,
+                            language=lang_to_generate[i],
+                            hps=hps,
+                            net_g=net_g,
+                            device=device,
+                            skip_start=skip_start,
+                            skip_end=skip_end,
+                        )
+                        audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+                        audio_list.append(audio16bit)
+                idx += 1
+        # 单一语言推理
+        # if len(text) > 42:
+        #     logging.info(f"Long Text: {text}")
+        #     para_list = re_matching.cut_para(text)
+        #     for p in para_list:
+        #         audio_list_sent = []
+        #         sent_list = re_matching.cut_sent(p)
+        #         for s in sent_list:
+        #             audio = infer(
+        #                 s,
+        #                 sdp_ratio=sdp_ratio,
+        #                 noise_scale=noise_scale,
+        #                 noise_scale_w=noise_scale_w,
+        #                 length_scale=length_scale,
+        #                 sid=speaker,
+        #                 language=language,
+        #                 hps=hps,
+        #                 net_g=net_g,
+        #                 device=device,
+        #                 reference_audio=reference_audio,
+        #                 emotion=emotion,
+        #             )
+        #             audio_list_sent.append(audio)
+        #             silence = np.zeros((int)(44100 * interval_between_sent))
+        #             audio_list_sent.append(silence)
+        #         if (interval_between_para - interval_between_sent) > 0:
+        #             silence = np.zeros((int)(44100 * (interval_between_para - interval_between_sent)))
+        #             audio_list_sent.append(silence)
+        #         audio16bit = gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_list_sent))  # 对完整句子做音量归一
+        #         audio_list.append(audio16bit)
+        # else:
+        #     logging.info(f"Short Text: {text}")
+        #     silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
+        #     with torch.no_grad():
+        #         for piece in text.split("|"):
+        #             audio = infer(
+        #                 piece,
+        #                 sdp_ratio=sdp_ratio,
+        #                 noise_scale=noise_scale,
+        #                 noise_scale_w=noise_scale_w,
+        #                 length_scale=length_scale,
+        #                 sid=speaker,
+        #                 language=language,
+        #                 hps=hps,
+        #                 net_g=net_g,
+        #                 device=device,
+        #                 reference_audio=reference_audio,
+        #                 emotion=emotion,
+        #             )
+        #             audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+        #             audio_list.append(audio16bit)
+        #             audio_list.append(silence)  # 将静音添加到列表中
 
     audio_concat = np.concatenate(audio_list)
     audio_value = (hps.data.sampling_rate, audio_concat)
@@ -113,13 +170,15 @@ def speak_fn(
     return gr.update(value=audio_value, autoplay=True), get_character_html(text), exceed_flag, gr.update(interactive=True)
 
 
+
 def submit_lock_fn():
     return gr.update(interactive=False)
 
 
 def init_fn():
-    gr.Info("2023-11-
-    gr.Info("
+    gr.Info("2023-11-28: 支持多语言啦!闲聊花花现在能说中、英、日语啦!")
+    # gr.Info("2023-11-24: 优化长句生成效果;增加示例;更新了一些小彩蛋;画了一些大饼)")
+    gr.Info("Support languages: ZH|EN|JA. 欢迎在 Community 中提建议~")
 
     index = random.randint(1,7)
     welcome_text = get_sentence("Welcome", index)
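The new speak_fn loop relies on tools.sentence.split_by_language returning (segment, language-code) pairs, which it then regroups and feeds to infer_multilang. A small illustration of that assumed interface (the exact segmentation shown is illustrative, not verified):

# Sketch only: assumed behaviour of split_by_language as consumed by speak_fn.
from tools.sentence import split_by_language

sentences_list = split_by_language(
    "你好!Nice to meet you. これからよろしくね。", target_languages=["zh", "ja", "en"]
)
for content, lang in sentences_list:  # e.g. ("你好!", "zh"), ("Nice to meet you. ", "en"), ...
    print(lang, content)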
utils.py
CHANGED
@@ -9,12 +9,31 @@ import numpy as np
 from huggingface_hub import hf_hub_download
 from scipy.io.wavfile import read
 import torch
+import re
 
 MATPLOTLIB_FLAG = False
 
 logger = logging.getLogger(__name__)
 
 
+def download_emo_models(mirror, repo_id, model_name):
+    if mirror == "openi":
+        import openi
+
+        openi.model.download_model(
+            "Stardust_minus/Bert-VITS2",
+            repo_id.split("/")[-1],
+            "./emotional",
+        )
+    else:
+        hf_hub_download(
+            repo_id,
+            "pytorch_model.bin",
+            local_dir=model_name,
+            local_dir_use_symlinks=False,
+        )
+
+
 def download_checkpoint(
     dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
 ):
@@ -385,3 +404,54 @@ class HParams:
 
     def __repr__(self):
         return self.__dict__.__repr__()
+
+
+def load_model(model_path, config_path):
+    hps = get_hparams_from_file(config_path)
+    net = SynthesizerTrn(
+        # len(symbols),
+        108,
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to("cpu")
+    _ = net.eval()
+    _ = load_checkpoint(model_path, net, None, skip_optimizer=True)
+    return net
+
+
+def mix_model(
+    network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
+):
+    if hasattr(network1, "module"):
+        state_dict1 = network1.module.state_dict()
+        state_dict2 = network2.module.state_dict()
+    else:
+        state_dict1 = network1.state_dict()
+        state_dict2 = network2.state_dict()
+    for k in state_dict1.keys():
+        if k not in state_dict2.keys():
+            continue
+        if "enc_p" in k:
+            state_dict1[k] = (
+                state_dict1[k].clone() * tone_ratio[0]
+                + state_dict2[k].clone() * tone_ratio[1]
+            )
+        else:
+            state_dict1[k] = (
+                state_dict1[k].clone() * voice_ratio[0]
+                + state_dict2[k].clone() * voice_ratio[1]
+            )
+    for k in state_dict2.keys():
+        if k not in state_dict1.keys():
+            state_dict1[k] = state_dict2[k].clone()
+    torch.save(
+        {"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0},
+        output_path,
+    )
+
+
+def get_steps(model_path):
+    matches = re.findall(r"\d+", model_path)
+    return matches[-1] if matches else None