Spaces:
Runtime error
Runtime error
zhzluke96
commited on
Commit
•
da8d589
1
Parent(s):
c4c6bff
update
Browse files- modules/Denoiser/AudioDenoiser.py +140 -0
- modules/Denoiser/AudioNosiseModel.py +66 -0
- modules/Denoiser/__init__.py +0 -0
- modules/Enhancer/ResembleEnhance.py +116 -0
- modules/Enhancer/__init__.py +0 -0
- modules/SynthesizeSegments.py +147 -185
- modules/api/impl/google_api.py +0 -1
- modules/api/impl/speaker_api.py +7 -3
- modules/api/impl/ssml_api.py +11 -24
- modules/api/utils.py +0 -2
- modules/denoise.py +46 -2
- modules/generate_audio.py +1 -1
- modules/models.py +1 -9
- modules/speaker.py +30 -17
- modules/ssml_parser/SSMLParser.py +178 -0
- modules/ssml_parser/__init__.py +0 -0
- modules/ssml_parser/test_ssml_parser.py +104 -0
- modules/utils/JsonObject.py +19 -0
- modules/utils/constants.py +1 -1
- modules/webui/app.py +11 -9
- modules/webui/speaker_tab.py +250 -4
- modules/webui/spliter_tab.py +2 -1
- modules/webui/system_tab.py +15 -0
- modules/webui/tts_tab.py +98 -82
- modules/webui/webui_config.py +4 -0
- modules/webui/webui_utils.py +72 -31
- webui.py +3 -1
modules/Denoiser/AudioDenoiser.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
from typing import Union
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
from torch import nn
|
7 |
+
from audio_denoiser.helpers.torch_helper import batched_apply
|
8 |
+
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
|
9 |
+
from audio_denoiser.helpers.audio_helper import (
|
10 |
+
create_spectrogram,
|
11 |
+
reconstruct_from_spectrogram,
|
12 |
+
)
|
13 |
+
|
14 |
+
_expected_t_std = 0.23
|
15 |
+
_recommended_backend = "soundfile"
|
16 |
+
|
17 |
+
|
18 |
+
# ref: https://github.com/jose-solorzano/audio-denoiser
|
19 |
+
class AudioDenoiser:
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
local_dir: str,
|
23 |
+
device: Union[str, torch.device] = None,
|
24 |
+
num_iterations: int = 100,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
if device is None:
|
28 |
+
is_cuda = torch.cuda.is_available()
|
29 |
+
if not is_cuda:
|
30 |
+
logging.warning("CUDA not available. Will use CPU.")
|
31 |
+
device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
|
32 |
+
self.device = device
|
33 |
+
self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
|
34 |
+
self.model.eval()
|
35 |
+
self.model_sample_rate = self.model.sample_rate
|
36 |
+
self.scaler = self.model.scaler
|
37 |
+
self.n_fft = self.model.n_fft
|
38 |
+
self.segment_num_frames = self.model.num_frames
|
39 |
+
self.num_iterations = num_iterations
|
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def _sp_log(spectrogram: torch.Tensor, eps=0.01):
|
43 |
+
return torch.log(spectrogram + eps)
|
44 |
+
|
45 |
+
@staticmethod
|
46 |
+
def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
|
47 |
+
return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
|
51 |
+
# Expected for training data is ~0.23
|
52 |
+
abs_waveform = torch.abs(waveform)
|
53 |
+
quantile_value = torch.quantile(abs_waveform, q).item()
|
54 |
+
trimmed_values = waveform[abs_waveform >= quantile_value]
|
55 |
+
return torch.std(trimmed_values).item()
|
56 |
+
|
57 |
+
def process_waveform(
|
58 |
+
self,
|
59 |
+
waveform: torch.Tensor,
|
60 |
+
sample_rate: int,
|
61 |
+
return_cpu_tensor: bool = False,
|
62 |
+
auto_scale: bool = False,
|
63 |
+
) -> torch.Tensor:
|
64 |
+
"""
|
65 |
+
Denoises a waveform.
|
66 |
+
@param waveform: A waveform tensor. Use torchaudio structure.
|
67 |
+
@param sample_rate: The sample rate of the waveform in Hz.
|
68 |
+
@param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
|
69 |
+
@param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio.
|
70 |
+
@return: A denoised waveform.
|
71 |
+
"""
|
72 |
+
waveform = waveform.cpu()
|
73 |
+
if auto_scale:
|
74 |
+
w_t_std = self._trimmed_dev(waveform)
|
75 |
+
waveform = waveform * _expected_t_std / w_t_std
|
76 |
+
if sample_rate != self.model_sample_rate:
|
77 |
+
transform = torchaudio.transforms.Resample(
|
78 |
+
orig_freq=sample_rate, new_freq=self.model_sample_rate
|
79 |
+
)
|
80 |
+
waveform = transform(waveform)
|
81 |
+
hop_len = self.n_fft // 2
|
82 |
+
spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
|
83 |
+
spectrogram = spectrogram.to(self.device)
|
84 |
+
num_a_channels = spectrogram.size(0)
|
85 |
+
with torch.no_grad():
|
86 |
+
results = []
|
87 |
+
for c in range(num_a_channels):
|
88 |
+
c_spectrogram = spectrogram[c]
|
89 |
+
# c_spectrogram: (257, num_frames)
|
90 |
+
fft_size, num_frames = c_spectrogram.shape
|
91 |
+
num_segments = math.ceil(num_frames / self.segment_num_frames)
|
92 |
+
adj_num_frames = num_segments * self.segment_num_frames
|
93 |
+
if adj_num_frames > num_frames:
|
94 |
+
c_spectrogram = nn.functional.pad(
|
95 |
+
c_spectrogram, (0, adj_num_frames - num_frames)
|
96 |
+
)
|
97 |
+
c_spectrogram = c_spectrogram.view(
|
98 |
+
fft_size, num_segments, self.segment_num_frames
|
99 |
+
)
|
100 |
+
# c_spectrogram: (257, num_segments, 32)
|
101 |
+
c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))
|
102 |
+
# c_spectrogram: (num_segments, 257, 32)
|
103 |
+
log_c_spectrogram = self._sp_log(c_spectrogram)
|
104 |
+
scaled_log_c_sp = self.scaler(log_c_spectrogram)
|
105 |
+
pred_noise_log_sp = batched_apply(
|
106 |
+
self.model, scaled_log_c_sp, detached=True
|
107 |
+
)
|
108 |
+
log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
|
109 |
+
denoised_sp = self._sp_exp(log_denoised_sp)
|
110 |
+
# denoised_sp: (num_segments, 257, 32)
|
111 |
+
denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
|
112 |
+
# denoised_sp: (257, num_segments, 32)
|
113 |
+
denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
|
114 |
+
# denoised_sp: (1, 257, adj_num_frames)
|
115 |
+
denoised_sp = denoised_sp[:, :, :num_frames]
|
116 |
+
denoised_sp = denoised_sp.cpu()
|
117 |
+
denoised_waveform = reconstruct_from_spectrogram(
|
118 |
+
denoised_sp, num_iterations=self.num_iterations
|
119 |
+
)
|
120 |
+
# denoised_waveform: (1, num_samples)
|
121 |
+
results.append(denoised_waveform)
|
122 |
+
cpu_results = torch.cat(results)
|
123 |
+
return cpu_results if return_cpu_tensor else cpu_results.to(self.device)
|
124 |
+
|
125 |
+
def process_audio_file(
|
126 |
+
self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
|
127 |
+
):
|
128 |
+
"""
|
129 |
+
Denoises an audio file.
|
130 |
+
@param in_audio_file: An input audio file with a format supported by torchaudio.
|
131 |
+
@param out_audio_file: Am output audio file with a format supported by torchaudio.
|
132 |
+
@param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio.
|
133 |
+
"""
|
134 |
+
waveform, sample_rate = torchaudio.load(in_audio_file)
|
135 |
+
denoised_waveform = self.process_waveform(
|
136 |
+
waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
|
137 |
+
)
|
138 |
+
torchaudio.save(
|
139 |
+
out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
|
140 |
+
)
|
modules/Denoiser/AudioNosiseModel.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from audio_denoiser.modules.Permute import Permute
|
5 |
+
from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
|
6 |
+
from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
|
7 |
+
|
8 |
+
import json
|
9 |
+
|
10 |
+
|
11 |
+
class AudioNoiseModel(nn.Module):
|
12 |
+
def __init__(self, config: dict):
|
13 |
+
super(AudioNoiseModel, self).__init__()
|
14 |
+
|
15 |
+
# Encoder layers
|
16 |
+
self.config = config
|
17 |
+
scaler_dict = config["scaler"]
|
18 |
+
self.scaler = SpectrogramScaler.from_dict(scaler_dict)
|
19 |
+
self.in_channels = config.get("in_channels", 257)
|
20 |
+
self.roberta_hidden_size = config.get("roberta_hidden_size", 768)
|
21 |
+
self.model1 = nn.Sequential(
|
22 |
+
nn.Conv1d(self.in_channels, 1024, kernel_size=1),
|
23 |
+
nn.ELU(),
|
24 |
+
nn.Conv1d(1024, 1024, kernel_size=1),
|
25 |
+
nn.ELU(),
|
26 |
+
nn.Conv1d(1024, self.in_channels, kernel_size=1),
|
27 |
+
)
|
28 |
+
self.model2 = nn.Sequential(
|
29 |
+
Permute(0, 2, 1),
|
30 |
+
nn.Linear(self.in_channels, self.roberta_hidden_size),
|
31 |
+
SimpleRoberta(num_hidden_layers=5, hidden_size=self.roberta_hidden_size),
|
32 |
+
nn.Linear(self.roberta_hidden_size, self.in_channels),
|
33 |
+
Permute(0, 2, 1),
|
34 |
+
)
|
35 |
+
|
36 |
+
@property
|
37 |
+
def sample_rate(self) -> int:
|
38 |
+
return self.config.get("sample_rate", 16000)
|
39 |
+
|
40 |
+
@property
|
41 |
+
def n_fft(self) -> int:
|
42 |
+
return self.config.get("n_fft", 512)
|
43 |
+
|
44 |
+
@property
|
45 |
+
def num_frames(self) -> int:
|
46 |
+
return self.config.get("num_frames", 32)
|
47 |
+
|
48 |
+
def forward(self, x, use_scaler: bool = False, out_scale: float = 1.0):
|
49 |
+
if use_scaler:
|
50 |
+
x = self.scaler(x)
|
51 |
+
x1 = self.model1(x)
|
52 |
+
x2 = self.model2(x)
|
53 |
+
x = x1 + x2
|
54 |
+
return x * out_scale
|
55 |
+
|
56 |
+
|
57 |
+
def load_audio_denosier_model(dir_path: str, device) -> AudioNoiseModel:
|
58 |
+
config = json.load(open(f"{dir_path}/config.json", "r"))
|
59 |
+
model = AudioNoiseModel(config)
|
60 |
+
model.load_state_dict(torch.load(f"{dir_path}/pytorch_model.bin"))
|
61 |
+
|
62 |
+
model.to(device)
|
63 |
+
model.model1.to(device)
|
64 |
+
model.model2.to(device)
|
65 |
+
|
66 |
+
return model
|
modules/Denoiser/__init__.py
ADDED
File without changes
|
modules/Enhancer/ResembleEnhance.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from resemble_enhance.enhancer.enhancer import Enhancer
|
4 |
+
from resemble_enhance.enhancer.hparams import HParams
|
5 |
+
from resemble_enhance.inference import inference
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from modules.utils.constants import MODELS_DIR
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from threading import Lock
|
13 |
+
|
14 |
+
resemble_enhance = None
|
15 |
+
lock = Lock()
|
16 |
+
|
17 |
+
|
18 |
+
def load_enhancer(device: torch.device):
|
19 |
+
global resemble_enhance
|
20 |
+
with lock:
|
21 |
+
if resemble_enhance is None:
|
22 |
+
resemble_enhance = ResembleEnhance(device)
|
23 |
+
resemble_enhance.load_model()
|
24 |
+
return resemble_enhance
|
25 |
+
|
26 |
+
|
27 |
+
class ResembleEnhance:
|
28 |
+
hparams: HParams
|
29 |
+
enhancer: Enhancer
|
30 |
+
|
31 |
+
def __init__(self, device: torch.device):
|
32 |
+
self.device = device
|
33 |
+
|
34 |
+
self.enhancer = None
|
35 |
+
self.hparams = None
|
36 |
+
|
37 |
+
def load_model(self):
|
38 |
+
hparams = HParams.load(Path(MODELS_DIR) / "resemble-enhance")
|
39 |
+
enhancer = Enhancer(hparams)
|
40 |
+
state_dict = torch.load(
|
41 |
+
Path(MODELS_DIR) / "resemble-enhance" / "mp_rank_00_model_states.pt",
|
42 |
+
map_location="cpu",
|
43 |
+
)["module"]
|
44 |
+
enhancer.load_state_dict(state_dict)
|
45 |
+
enhancer.eval()
|
46 |
+
enhancer.to(self.device)
|
47 |
+
enhancer.denoiser.to(self.device)
|
48 |
+
|
49 |
+
self.hparams = hparams
|
50 |
+
self.enhancer = enhancer
|
51 |
+
|
52 |
+
@torch.inference_mode()
|
53 |
+
def denoise(self, dwav, sr, device) -> tuple[torch.Tensor, int]:
|
54 |
+
assert self.enhancer is not None, "Model not loaded"
|
55 |
+
assert self.enhancer.denoiser is not None, "Denoiser not loaded"
|
56 |
+
enhancer = self.enhancer
|
57 |
+
return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
|
58 |
+
|
59 |
+
@torch.inference_mode()
|
60 |
+
def enhance(
|
61 |
+
self,
|
62 |
+
dwav,
|
63 |
+
sr,
|
64 |
+
device,
|
65 |
+
nfe=32,
|
66 |
+
solver="midpoint",
|
67 |
+
lambd=0.5,
|
68 |
+
tau=0.5,
|
69 |
+
) -> tuple[torch.Tensor, int]:
|
70 |
+
assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
|
71 |
+
assert solver in (
|
72 |
+
"midpoint",
|
73 |
+
"rk4",
|
74 |
+
"euler",
|
75 |
+
), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
|
76 |
+
assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
|
77 |
+
assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
|
78 |
+
assert self.enhancer is not None, "Model not loaded"
|
79 |
+
enhancer = self.enhancer
|
80 |
+
enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
|
81 |
+
return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == "__main__":
|
85 |
+
import torchaudio
|
86 |
+
from modules.models import load_chat_tts
|
87 |
+
|
88 |
+
load_chat_tts()
|
89 |
+
|
90 |
+
device = torch.device("cuda")
|
91 |
+
ench = ResembleEnhance(device)
|
92 |
+
ench.load_model()
|
93 |
+
|
94 |
+
wav, sr = torchaudio.load("test.wav")
|
95 |
+
|
96 |
+
print(wav.shape, type(wav), sr, type(sr))
|
97 |
+
exit()
|
98 |
+
|
99 |
+
wav = wav.squeeze(0).cuda()
|
100 |
+
|
101 |
+
print(wav.device)
|
102 |
+
|
103 |
+
denoised, d_sr = ench.denoise(wav.cpu(), sr, device)
|
104 |
+
denoised = denoised.unsqueeze(0)
|
105 |
+
print(denoised.shape)
|
106 |
+
torchaudio.save("denoised.wav", denoised, d_sr)
|
107 |
+
|
108 |
+
for solver in ("midpoint", "rk4", "euler"):
|
109 |
+
for lambd in (0.1, 0.5, 0.9):
|
110 |
+
for tau in (0.1, 0.5, 0.9):
|
111 |
+
enhanced, e_sr = ench.enhance(
|
112 |
+
wav.cpu(), sr, device, solver=solver, lambd=lambd, tau=tau, nfe=128
|
113 |
+
)
|
114 |
+
enhanced = enhanced.unsqueeze(0)
|
115 |
+
print(enhanced.shape)
|
116 |
+
torchaudio.save(f"enhanced_{solver}_{lambd}_{tau}.wav", enhanced, e_sr)
|
modules/Enhancer/__init__.py
ADDED
File without changes
|
modules/SynthesizeSegments.py
CHANGED
@@ -1,17 +1,18 @@
|
|
|
|
1 |
from pydub import AudioSegment
|
2 |
-
from typing import
|
3 |
from scipy.io.wavfile import write
|
4 |
import io
|
|
|
|
|
5 |
from modules.utils import rng
|
6 |
from modules.utils.audio import time_stretch, pitch_shift
|
7 |
from modules import generate_audio
|
8 |
from modules.normalization import text_normalize
|
9 |
import logging
|
10 |
import json
|
11 |
-
import copy
|
12 |
-
import numpy as np
|
13 |
|
14 |
-
from modules.speaker import Speaker
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
@@ -24,7 +25,7 @@ def audio_data_to_segment(audio_data, sr):
|
|
24 |
return AudioSegment.from_file(byte_io, format="wav")
|
25 |
|
26 |
|
27 |
-
def combine_audio_segments(audio_segments: list) -> AudioSegment:
|
28 |
combined_audio = AudioSegment.empty()
|
29 |
for segment in audio_segments:
|
30 |
combined_audio += segment
|
@@ -54,230 +55,191 @@ def to_number(value, t, default=0):
|
|
54 |
return default
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
class SynthesizeSegments:
|
58 |
def __init__(self, batch_size: int = 8):
|
59 |
self.batch_size = batch_size
|
60 |
self.batch_default_spk_seed = rng.np_rng()
|
61 |
self.batch_default_infer_seed = rng.np_rng()
|
62 |
|
63 |
-
def segment_to_generate_params(
|
|
|
|
|
|
|
|
|
|
|
64 |
if segment.get("params", None) is not None:
|
65 |
-
return segment
|
66 |
|
67 |
text = segment.get("text", "")
|
68 |
is_end = segment.get("is_end", False)
|
69 |
|
70 |
text = str(text).strip()
|
71 |
|
72 |
-
attrs = segment.
|
73 |
-
spk = attrs.
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
84 |
disable_normalize = attrs.get("normalize", "") == "False"
|
85 |
|
86 |
-
|
87 |
-
"
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
97 |
|
98 |
if not disable_normalize:
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
-
|
|
|
|
|
|
|
108 |
|
109 |
def bucket_segments(
|
110 |
-
self, segments: List[
|
111 |
-
) -> List[List[
|
112 |
-
|
113 |
-
buckets = {}
|
114 |
for segment in segments:
|
|
|
|
|
|
|
|
|
115 |
params = self.segment_to_generate_params(segment)
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
key = json.dumps(
|
121 |
-
{k: v for k, v in
|
122 |
)
|
123 |
if key not in buckets:
|
124 |
buckets[key] = []
|
125 |
buckets[key].append(segment)
|
126 |
|
127 |
-
|
128 |
-
bucket_list = list(buckets.values())
|
129 |
-
return bucket_list
|
130 |
|
131 |
-
def synthesize_segments(
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
buckets = self.bucket_segments(segments)
|
136 |
-
logger.debug(f"segments len: {len(segments)}")
|
137 |
-
logger.debug(f"bucket pool size: {len(buckets)}")
|
138 |
-
for bucket in buckets:
|
139 |
-
for i in range(0, len(bucket), self.batch_size):
|
140 |
-
batch = bucket[i : i + self.batch_size]
|
141 |
-
param_arr = [
|
142 |
-
self.segment_to_generate_params(segment) for segment in batch
|
143 |
-
]
|
144 |
-
texts = [params["text"] for params in param_arr]
|
145 |
-
|
146 |
-
params = param_arr[0] # Use the first segment to get the parameters
|
147 |
-
audio_datas = generate_audio.generate_audio_batch(
|
148 |
-
texts=texts,
|
149 |
-
temperature=params["temperature"],
|
150 |
-
top_P=params["top_P"],
|
151 |
-
top_K=params["top_K"],
|
152 |
-
spk=params["spk"],
|
153 |
-
infer_seed=params["infer_seed"],
|
154 |
-
prompt1=params["prompt1"],
|
155 |
-
prompt2=params["prompt2"],
|
156 |
-
prefix=params["prefix"],
|
157 |
-
)
|
158 |
-
for idx, segment in enumerate(batch):
|
159 |
-
(sr, audio_data) = audio_datas[idx]
|
160 |
-
rate = float(segment.get("rate", "1.0"))
|
161 |
-
volume = float(segment.get("volume", "0"))
|
162 |
-
pitch = float(segment.get("pitch", "0"))
|
163 |
-
|
164 |
-
audio_segment = audio_data_to_segment(audio_data, sr)
|
165 |
-
audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
|
166 |
-
original_index = segments.index(
|
167 |
-
segment
|
168 |
-
) # Get the original index of the segment
|
169 |
-
audio_segments[original_index] = (
|
170 |
-
audio_segment # Place the audio_segment in the correct position
|
171 |
-
)
|
172 |
-
|
173 |
-
return audio_segments
|
174 |
|
|
|
|
|
175 |
|
176 |
-
|
177 |
-
text: str,
|
178 |
-
spk: int = -1,
|
179 |
-
seed: int = -1,
|
180 |
-
top_p: float = 0.5,
|
181 |
-
top_k: int = 20,
|
182 |
-
temp: float = 0.3,
|
183 |
-
prompt1: str = "",
|
184 |
-
prompt2: str = "",
|
185 |
-
prefix: str = "",
|
186 |
-
enable_normalize=True,
|
187 |
-
is_end: bool = False,
|
188 |
-
) -> AudioSegment:
|
189 |
-
if enable_normalize:
|
190 |
-
text = text_normalize(text, is_end=is_end)
|
191 |
-
|
192 |
-
logger.debug(f"generate segment: {text}")
|
193 |
-
|
194 |
-
sample_rate, audio_data = generate_audio.generate_audio(
|
195 |
-
text=text,
|
196 |
-
temperature=temp if temp is not None else 0.3,
|
197 |
-
top_P=top_p if top_p is not None else 0.5,
|
198 |
-
top_K=top_k if top_k is not None else 20,
|
199 |
-
spk=spk if spk else -1,
|
200 |
-
infer_seed=seed if seed else -1,
|
201 |
-
prompt1=prompt1 if prompt1 else "",
|
202 |
-
prompt2=prompt2 if prompt2 else "",
|
203 |
-
prefix=prefix if prefix else "",
|
204 |
-
)
|
205 |
-
|
206 |
-
byte_io = io.BytesIO()
|
207 |
-
write(byte_io, sample_rate, audio_data)
|
208 |
-
byte_io.seek(0)
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
|
214 |
-
if "break" in segment:
|
215 |
-
pause_segment = AudioSegment.silent(duration=segment["break"])
|
216 |
-
return pause_segment
|
217 |
-
|
218 |
-
attrs = segment.get("attrs", {})
|
219 |
-
text = segment.get("text", "")
|
220 |
-
is_end = segment.get("is_end", False)
|
221 |
-
|
222 |
-
text = str(text).strip()
|
223 |
-
|
224 |
-
if text == "":
|
225 |
-
return None
|
226 |
-
|
227 |
-
spk = attrs.get("spk", "")
|
228 |
-
if isinstance(spk, str):
|
229 |
-
spk = int(spk)
|
230 |
-
seed = to_number(attrs.get("seed", ""), int, -1)
|
231 |
-
top_k = to_number(attrs.get("top_k", ""), int, None)
|
232 |
-
top_p = to_number(attrs.get("top_p", ""), float, None)
|
233 |
-
temp = to_number(attrs.get("temp", ""), float, None)
|
234 |
-
|
235 |
-
prompt1 = attrs.get("prompt1", "")
|
236 |
-
prompt2 = attrs.get("prompt2", "")
|
237 |
-
prefix = attrs.get("prefix", "")
|
238 |
-
disable_normalize = attrs.get("normalize", "") == "False"
|
239 |
-
|
240 |
-
audio_segment = generate_audio_segment(
|
241 |
-
text,
|
242 |
-
enable_normalize=not disable_normalize,
|
243 |
-
spk=spk,
|
244 |
-
seed=seed,
|
245 |
-
top_k=top_k,
|
246 |
-
top_p=top_p,
|
247 |
-
temp=temp,
|
248 |
-
prompt1=prompt1,
|
249 |
-
prompt2=prompt2,
|
250 |
-
prefix=prefix,
|
251 |
-
is_end=is_end,
|
252 |
-
)
|
253 |
-
|
254 |
-
rate = float(attrs.get("rate", "1.0"))
|
255 |
-
volume = float(attrs.get("volume", "0"))
|
256 |
-
pitch = float(attrs.get("pitch", "0"))
|
257 |
-
|
258 |
-
audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
|
259 |
|
260 |
-
|
261 |
|
262 |
|
263 |
# 示例使用
|
264 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
ssml_segments = [
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
{
|
271 |
-
"text": "大🍉,一个大🍉,嘿,你的感觉真的很奇妙 [lbreak]",
|
272 |
-
"attrs": {"spk": 2, "temp": 0.1, "seed": 42},
|
273 |
-
},
|
274 |
-
{
|
275 |
-
"text": "大🍌,一条大🍌,嘿,你的感觉真的很奇妙 [lbreak]",
|
276 |
-
"attrs": {"spk": 2, "temp": 0.3, "seed": 42},
|
277 |
-
},
|
278 |
]
|
279 |
|
280 |
synthesizer = SynthesizeSegments(batch_size=2)
|
281 |
audio_segments = synthesizer.synthesize_segments(ssml_segments)
|
|
|
282 |
combined_audio = combine_audio_segments(audio_segments)
|
283 |
combined_audio.export("output.wav", format="wav")
|
|
|
1 |
+
from box import Box
|
2 |
from pydub import AudioSegment
|
3 |
+
from typing import List, Union
|
4 |
from scipy.io.wavfile import write
|
5 |
import io
|
6 |
+
from modules.api.utils import calc_spk_style
|
7 |
+
from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
|
8 |
from modules.utils import rng
|
9 |
from modules.utils.audio import time_stretch, pitch_shift
|
10 |
from modules import generate_audio
|
11 |
from modules.normalization import text_normalize
|
12 |
import logging
|
13 |
import json
|
|
|
|
|
14 |
|
15 |
+
from modules.speaker import Speaker, speaker_mgr
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
25 |
return AudioSegment.from_file(byte_io, format="wav")
|
26 |
|
27 |
|
28 |
+
def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
|
29 |
combined_audio = AudioSegment.empty()
|
30 |
for segment in audio_segments:
|
31 |
combined_audio += segment
|
|
|
55 |
return default
|
56 |
|
57 |
|
58 |
+
class TTSAudioSegment(Box):
|
59 |
+
text: str
|
60 |
+
temperature: float
|
61 |
+
top_P: float
|
62 |
+
top_K: int
|
63 |
+
spk: int
|
64 |
+
infer_seed: int
|
65 |
+
prompt1: str
|
66 |
+
prompt2: str
|
67 |
+
prefix: str
|
68 |
+
|
69 |
+
_type: str
|
70 |
+
|
71 |
+
def __init__(self, *args, **kwargs):
|
72 |
+
super().__init__(*args, **kwargs)
|
73 |
+
|
74 |
+
|
75 |
class SynthesizeSegments:
|
76 |
def __init__(self, batch_size: int = 8):
|
77 |
self.batch_size = batch_size
|
78 |
self.batch_default_spk_seed = rng.np_rng()
|
79 |
self.batch_default_infer_seed = rng.np_rng()
|
80 |
|
81 |
+
def segment_to_generate_params(
|
82 |
+
self, segment: Union[SSMLSegment, SSMLBreak]
|
83 |
+
) -> TTSAudioSegment:
|
84 |
+
if isinstance(segment, SSMLBreak):
|
85 |
+
return TTSAudioSegment(_type="break")
|
86 |
+
|
87 |
if segment.get("params", None) is not None:
|
88 |
+
return TTSAudioSegment(**segment.get("params"))
|
89 |
|
90 |
text = segment.get("text", "")
|
91 |
is_end = segment.get("is_end", False)
|
92 |
|
93 |
text = str(text).strip()
|
94 |
|
95 |
+
attrs = segment.attrs
|
96 |
+
spk = attrs.spk
|
97 |
+
style = attrs.style
|
98 |
+
|
99 |
+
ss_params = calc_spk_style(spk, style)
|
100 |
+
|
101 |
+
if "spk" in ss_params:
|
102 |
+
spk = ss_params["spk"]
|
103 |
+
|
104 |
+
seed = to_number(attrs.seed, int, ss_params.get("seed") or -1)
|
105 |
+
top_k = to_number(attrs.top_k, int, None)
|
106 |
+
top_p = to_number(attrs.top_p, float, None)
|
107 |
+
temp = to_number(attrs.temp, float, None)
|
108 |
+
|
109 |
+
prompt1 = attrs.prompt1 or ss_params.get("prompt1")
|
110 |
+
prompt2 = attrs.prompt2 or ss_params.get("prompt2")
|
111 |
+
prefix = attrs.prefix or ss_params.get("prefix")
|
112 |
disable_normalize = attrs.get("normalize", "") == "False"
|
113 |
|
114 |
+
seg = TTSAudioSegment(
|
115 |
+
_type="voice",
|
116 |
+
text=text,
|
117 |
+
temperature=temp if temp is not None else 0.3,
|
118 |
+
top_P=top_p if top_p is not None else 0.5,
|
119 |
+
top_K=top_k if top_k is not None else 20,
|
120 |
+
spk=spk if spk else -1,
|
121 |
+
infer_seed=seed if seed else -1,
|
122 |
+
prompt1=prompt1 if prompt1 else "",
|
123 |
+
prompt2=prompt2 if prompt2 else "",
|
124 |
+
prefix=prefix if prefix else "",
|
125 |
+
)
|
126 |
|
127 |
if not disable_normalize:
|
128 |
+
seg.text = text_normalize(text, is_end=is_end)
|
129 |
+
|
130 |
+
# NOTE 每个batch的默认seed保证前后一致即使是没设置spk的情况
|
131 |
+
if seg.spk == -1:
|
132 |
+
seg.spk = self.batch_default_spk_seed
|
133 |
+
if seg.infer_seed == -1:
|
134 |
+
seg.infer_seed = self.batch_default_infer_seed
|
135 |
+
|
136 |
+
return seg
|
137 |
+
|
138 |
+
def process_break_segments(
|
139 |
+
self,
|
140 |
+
src_segments: List[SSMLBreak],
|
141 |
+
bucket_segments: List[SSMLBreak],
|
142 |
+
audio_segments: List[AudioSegment],
|
143 |
+
):
|
144 |
+
for segment in bucket_segments:
|
145 |
+
index = src_segments.index(segment)
|
146 |
+
audio_segments[index] = AudioSegment.silent(
|
147 |
+
duration=int(segment.attrs.duration)
|
148 |
+
)
|
149 |
|
150 |
+
def process_voice_segments(
|
151 |
+
self,
|
152 |
+
src_segments: List[SSMLSegment],
|
153 |
+
bucket: List[SSMLSegment],
|
154 |
+
audio_segments: List[AudioSegment],
|
155 |
+
):
|
156 |
+
for i in range(0, len(bucket), self.batch_size):
|
157 |
+
batch = bucket[i : i + self.batch_size]
|
158 |
+
param_arr = [self.segment_to_generate_params(segment) for segment in batch]
|
159 |
+
texts = [params.text for params in param_arr]
|
160 |
+
|
161 |
+
params = param_arr[0]
|
162 |
+
audio_datas = generate_audio.generate_audio_batch(
|
163 |
+
texts=texts,
|
164 |
+
temperature=params.temperature,
|
165 |
+
top_P=params.top_P,
|
166 |
+
top_K=params.top_K,
|
167 |
+
spk=params.spk,
|
168 |
+
infer_seed=params.infer_seed,
|
169 |
+
prompt1=params.prompt1,
|
170 |
+
prompt2=params.prompt2,
|
171 |
+
prefix=params.prefix,
|
172 |
+
)
|
173 |
+
for idx, segment in enumerate(batch):
|
174 |
+
sr, audio_data = audio_datas[idx]
|
175 |
+
rate = float(segment.get("rate", "1.0"))
|
176 |
+
volume = float(segment.get("volume", "0"))
|
177 |
+
pitch = float(segment.get("pitch", "0"))
|
178 |
|
179 |
+
audio_segment = audio_data_to_segment(audio_data, sr)
|
180 |
+
audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
|
181 |
+
original_index = src_segments.index(segment)
|
182 |
+
audio_segments[original_index] = audio_segment
|
183 |
|
184 |
def bucket_segments(
|
185 |
+
self, segments: List[Union[SSMLSegment, SSMLBreak]]
|
186 |
+
) -> List[List[Union[SSMLSegment, SSMLBreak]]]:
|
187 |
+
buckets = {"<break>": []}
|
|
|
188 |
for segment in segments:
|
189 |
+
if isinstance(segment, SSMLBreak):
|
190 |
+
buckets["<break>"].append(segment)
|
191 |
+
continue
|
192 |
+
|
193 |
params = self.segment_to_generate_params(segment)
|
194 |
|
195 |
+
if isinstance(params.spk, Speaker):
|
196 |
+
params.spk = str(params.spk.id)
|
197 |
+
|
198 |
key = json.dumps(
|
199 |
+
{k: v for k, v in params.items() if k != "text"}, sort_keys=True
|
200 |
)
|
201 |
if key not in buckets:
|
202 |
buckets[key] = []
|
203 |
buckets[key].append(segment)
|
204 |
|
205 |
+
return buckets
|
|
|
|
|
206 |
|
207 |
+
def synthesize_segments(
|
208 |
+
self, segments: List[Union[SSMLSegment, SSMLBreak]]
|
209 |
+
) -> List[AudioSegment]:
|
210 |
+
audio_segments = [None] * len(segments)
|
211 |
buckets = self.bucket_segments(segments)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
+
break_segments = buckets.pop("<break>")
|
214 |
+
self.process_break_segments(segments, break_segments, audio_segments)
|
215 |
|
216 |
+
buckets = list(buckets.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
217 |
|
218 |
+
for bucket in buckets:
|
219 |
+
self.process_voice_segments(segments, bucket, audio_segments)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
+
return audio_segments
|
222 |
|
223 |
|
224 |
# 示例使用
|
225 |
if __name__ == "__main__":
|
226 |
+
ctx1 = SSMLContext()
|
227 |
+
ctx1.spk = 1
|
228 |
+
ctx1.seed = 42
|
229 |
+
ctx1.temp = 0.1
|
230 |
+
ctx2 = SSMLContext()
|
231 |
+
ctx2.spk = 2
|
232 |
+
ctx2.seed = 42
|
233 |
+
ctx2.temp = 0.1
|
234 |
ssml_segments = [
|
235 |
+
SSMLSegment(text="大🍌,一条大🍌,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
|
236 |
+
SSMLBreak(duration_ms=1000),
|
237 |
+
SSMLSegment(text="大🍉,一个大🍉,嘿,你的感觉真的很奇妙", attrs=ctx1.copy()),
|
238 |
+
SSMLSegment(text="大🍊,一个大🍊,嘿,你的感觉真的很奇妙", attrs=ctx2.copy()),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
]
|
240 |
|
241 |
synthesizer = SynthesizeSegments(batch_size=2)
|
242 |
audio_segments = synthesizer.synthesize_segments(ssml_segments)
|
243 |
+
print(audio_segments)
|
244 |
combined_audio = combine_audio_segments(audio_segments)
|
245 |
combined_audio.export("output.wav", format="wav")
|
modules/api/impl/google_api.py
CHANGED
@@ -18,7 +18,6 @@ from modules.ssml import parse_ssml
|
|
18 |
from modules.SynthesizeSegments import (
|
19 |
SynthesizeSegments,
|
20 |
combine_audio_segments,
|
21 |
-
synthesize_segment,
|
22 |
)
|
23 |
|
24 |
from modules.api import utils as api_utils
|
|
|
18 |
from modules.SynthesizeSegments import (
|
19 |
SynthesizeSegments,
|
20 |
combine_audio_segments,
|
|
|
21 |
)
|
22 |
|
23 |
from modules.api import utils as api_utils
|
modules/api/impl/speaker_api.py
CHANGED
@@ -7,11 +7,11 @@ from modules.api.Api import APIManager
|
|
7 |
|
8 |
|
9 |
class CreateSpeaker(BaseModel):
|
10 |
-
seed: int
|
11 |
name: str
|
12 |
gender: str
|
13 |
describe: str
|
14 |
-
tensor: list
|
|
|
15 |
|
16 |
|
17 |
class UpdateSpeaker(BaseModel):
|
@@ -76,7 +76,7 @@ def setup(app: APIManager):
|
|
76 |
gender=request.gender,
|
77 |
describe=request.describe,
|
78 |
)
|
79 |
-
|
80 |
# from seed
|
81 |
speaker = speaker_mgr.create_speaker_from_seed(
|
82 |
seed=request.seed,
|
@@ -84,6 +84,10 @@ def setup(app: APIManager):
|
|
84 |
gender=request.gender,
|
85 |
describe=request.describe,
|
86 |
)
|
|
|
|
|
|
|
|
|
87 |
return {"message": "ok", "data": speaker.to_json()}
|
88 |
|
89 |
@app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
|
|
|
7 |
|
8 |
|
9 |
class CreateSpeaker(BaseModel):
|
|
|
10 |
name: str
|
11 |
gender: str
|
12 |
describe: str
|
13 |
+
tensor: list = None
|
14 |
+
seed: int = None
|
15 |
|
16 |
|
17 |
class UpdateSpeaker(BaseModel):
|
|
|
76 |
gender=request.gender,
|
77 |
describe=request.describe,
|
78 |
)
|
79 |
+
elif request.seed:
|
80 |
# from seed
|
81 |
speaker = speaker_mgr.create_speaker_from_seed(
|
82 |
seed=request.seed,
|
|
|
84 |
gender=request.gender,
|
85 |
describe=request.describe,
|
86 |
)
|
87 |
+
else:
|
88 |
+
raise HTTPException(
|
89 |
+
status_code=400, detail="Missing tensor or seed in request"
|
90 |
+
)
|
91 |
return {"message": "ok", "data": speaker.to_json()}
|
92 |
|
93 |
@app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
|
modules/api/impl/ssml_api.py
CHANGED
@@ -10,7 +10,6 @@ from modules.normalization import text_normalize
|
|
10 |
from modules.ssml import parse_ssml
|
11 |
from modules.SynthesizeSegments import (
|
12 |
SynthesizeSegments,
|
13 |
-
synthesize_segment,
|
14 |
combine_audio_segments,
|
15 |
)
|
16 |
|
@@ -23,6 +22,8 @@ from modules.api.Api import APIManager
|
|
23 |
class SSMLRequest(BaseModel):
|
24 |
ssml: str
|
25 |
format: str = "mp3"
|
|
|
|
|
26 |
batch_size: int = 4
|
27 |
|
28 |
|
@@ -48,29 +49,15 @@ async def synthesize_ssml(
|
|
48 |
for seg in segments:
|
49 |
seg["text"] = text_normalize(seg["text"], is_end=True)
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
return StreamingResponse(buffer, media_type=f"audio/{format}")
|
61 |
-
else:
|
62 |
-
|
63 |
-
def audio_streamer():
|
64 |
-
for segment in segments:
|
65 |
-
audio_segment = synthesize_segment(segment=segment)
|
66 |
-
buffer = io.BytesIO()
|
67 |
-
audio_segment.export(buffer, format="wav")
|
68 |
-
buffer.seek(0)
|
69 |
-
if format == "mp3":
|
70 |
-
buffer = api_utils.wav_to_mp3(buffer)
|
71 |
-
yield buffer.read()
|
72 |
-
|
73 |
-
return StreamingResponse(audio_streamer(), media_type=f"audio/{format}")
|
74 |
|
75 |
except Exception as e:
|
76 |
import logging
|
|
|
10 |
from modules.ssml import parse_ssml
|
11 |
from modules.SynthesizeSegments import (
|
12 |
SynthesizeSegments,
|
|
|
13 |
combine_audio_segments,
|
14 |
)
|
15 |
|
|
|
22 |
class SSMLRequest(BaseModel):
|
23 |
ssml: str
|
24 |
format: str = "mp3"
|
25 |
+
|
26 |
+
# NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
|
27 |
batch_size: int = 4
|
28 |
|
29 |
|
|
|
49 |
for seg in segments:
|
50 |
seg["text"] = text_normalize(seg["text"], is_end=True)
|
51 |
|
52 |
+
synthesize = SynthesizeSegments(batch_size)
|
53 |
+
audio_segments = synthesize.synthesize_segments(segments)
|
54 |
+
combined_audio = combine_audio_segments(audio_segments)
|
55 |
+
buffer = io.BytesIO()
|
56 |
+
combined_audio.export(buffer, format="wav")
|
57 |
+
buffer.seek(0)
|
58 |
+
if format == "mp3":
|
59 |
+
buffer = api_utils.wav_to_mp3(buffer)
|
60 |
+
return StreamingResponse(buffer, media_type=f"audio/{format}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
except Exception as e:
|
63 |
import logging
|
modules/api/utils.py
CHANGED
@@ -52,7 +52,6 @@ def to_number(value, t, default=0):
|
|
52 |
def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
|
53 |
voice_attrs = {
|
54 |
"spk": None,
|
55 |
-
"seed": None,
|
56 |
"prompt1": None,
|
57 |
"prompt2": None,
|
58 |
"prefix": None,
|
@@ -85,7 +84,6 @@ def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
|
|
85 |
merge_prompt(voice_attrs, params)
|
86 |
|
87 |
voice_attrs["spk"] = params.get("spk", voice_attrs.get("spk", None))
|
88 |
-
voice_attrs["seed"] = params.get("seed", voice_attrs.get("seed", None))
|
89 |
voice_attrs["temperature"] = params.get(
|
90 |
"temp", voice_attrs.get("temperature", None)
|
91 |
)
|
|
|
52 |
def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
|
53 |
voice_attrs = {
|
54 |
"spk": None,
|
|
|
55 |
"prompt1": None,
|
56 |
"prompt2": None,
|
57 |
"prefix": None,
|
|
|
84 |
merge_prompt(voice_attrs, params)
|
85 |
|
86 |
voice_attrs["spk"] = params.get("spk", voice_attrs.get("spk", None))
|
|
|
87 |
voice_attrs["temperature"] = params.get(
|
88 |
"temp", voice_attrs.get("temperature", None)
|
89 |
)
|
modules/denoise.py
CHANGED
@@ -1,7 +1,51 @@
|
|
1 |
-
|
|
|
|
|
2 |
import torch
|
3 |
import torchaudio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
class TTSAudioDenoiser:
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Union
|
3 |
+
|
4 |
import torch
|
5 |
import torchaudio
|
6 |
+
from modules.Denoiser.AudioDenoiser import AudioDenoiser
|
7 |
+
|
8 |
+
from modules.utils.constants import MODELS_DIR
|
9 |
+
|
10 |
+
from modules.devices import devices
|
11 |
+
|
12 |
+
import soundfile as sf
|
13 |
+
|
14 |
+
ad: Union[AudioDenoiser, None] = None
|
15 |
|
16 |
|
17 |
class TTSAudioDenoiser:
|
18 |
+
|
19 |
+
def load_ad(self):
|
20 |
+
global ad
|
21 |
+
if ad is None:
|
22 |
+
ad = AudioDenoiser(
|
23 |
+
os.path.join(
|
24 |
+
MODELS_DIR,
|
25 |
+
"Denoise",
|
26 |
+
"audio-denoiser-512-32-v1",
|
27 |
+
),
|
28 |
+
device=devices.device,
|
29 |
+
)
|
30 |
+
ad.model.to(devices.device)
|
31 |
+
return ad
|
32 |
+
|
33 |
+
def denoise(self, audio_data, sample_rate, auto_scale=False):
|
34 |
+
ad = self.load_ad()
|
35 |
+
sr = ad.model_sample_rate
|
36 |
+
return sr, ad.process_waveform(audio_data, sample_rate, auto_scale)
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
tts_deno = TTSAudioDenoiser()
|
41 |
+
data, sr = sf.read("test.wav")
|
42 |
+
audio_tensor = torch.from_numpy(data).unsqueeze(0).float()
|
43 |
+
print(audio_tensor)
|
44 |
+
|
45 |
+
# data, sr = torchaudio.load("test.wav")
|
46 |
+
# print(data)
|
47 |
+
# data = data.to(devices.device)
|
48 |
+
|
49 |
+
sr, denoised = tts_deno.denoise(audio_data=audio_tensor, sample_rate=sr)
|
50 |
+
denoised = denoised.cpu()
|
51 |
+
torchaudio.save("denoised.wav", denoised, sample_rate=sr)
|
modules/generate_audio.py
CHANGED
@@ -79,7 +79,7 @@ def generate_audio_batch(
|
|
79 |
params_infer_code["spk_emb"] = spk.emb
|
80 |
logger.info(("spk", spk.name))
|
81 |
else:
|
82 |
-
raise ValueError("spk must be int or Speaker")
|
83 |
|
84 |
logger.info(
|
85 |
{
|
|
|
79 |
params_infer_code["spk_emb"] = spk.emb
|
80 |
logger.info(("spk", spk.name))
|
81 |
else:
|
82 |
+
raise ValueError(f"spk must be int or Speaker, but: <{type(spk)}> {spk}")
|
83 |
|
84 |
logger.info(
|
85 |
{
|
modules/models.py
CHANGED
@@ -37,17 +37,9 @@ def load_chat_tts_in_thread():
|
|
37 |
logger.info("ChatTTS models loaded")
|
38 |
|
39 |
|
40 |
-
def
|
41 |
with lock:
|
42 |
if chat_tts is None:
|
43 |
-
model_thread = threading.Thread(target=load_chat_tts_in_thread)
|
44 |
-
model_thread.start()
|
45 |
-
model_thread.join()
|
46 |
-
|
47 |
-
|
48 |
-
def load_chat_tts():
|
49 |
-
if chat_tts is None:
|
50 |
-
with lock:
|
51 |
load_chat_tts_in_thread()
|
52 |
if chat_tts is None:
|
53 |
raise Exception("Failed to load ChatTTS models")
|
|
|
37 |
logger.info("ChatTTS models loaded")
|
38 |
|
39 |
|
40 |
+
def load_chat_tts():
|
41 |
with lock:
|
42 |
if chat_tts is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
load_chat_tts_in_thread()
|
44 |
if chat_tts is None:
|
45 |
raise Exception("Failed to load ChatTTS models")
|
modules/speaker.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
from typing import Union
|
|
|
3 |
import torch
|
4 |
|
5 |
from modules import models
|
@@ -16,6 +17,18 @@ def create_speaker_from_seed(seed):
|
|
16 |
|
17 |
|
18 |
class Speaker:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def __init__(self, seed, name="", gender="", describe=""):
|
20 |
self.id = uuid.uuid4()
|
21 |
self.seed = seed
|
@@ -24,15 +37,20 @@ class Speaker:
|
|
24 |
self.describe = describe
|
25 |
self.emb = None
|
26 |
|
|
|
|
|
|
|
27 |
def to_json(self, with_emb=False):
|
28 |
-
return
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
|
36 |
|
37 |
def fix(self):
|
38 |
is_update = False
|
@@ -78,14 +96,9 @@ class SpeakerManager:
|
|
78 |
self.speakers = {}
|
79 |
for speaker_file in os.listdir(self.speaker_dir):
|
80 |
if speaker_file.endswith(".pt"):
|
81 |
-
|
82 |
-
self.speaker_dir + speaker_file
|
83 |
)
|
84 |
-
self.speakers[speaker_file] = speaker
|
85 |
-
|
86 |
-
is_update = speaker.fix()
|
87 |
-
if is_update:
|
88 |
-
torch.save(speaker, self.speaker_dir + speaker_file)
|
89 |
|
90 |
def list_speakers(self):
|
91 |
return list(self.speakers.values())
|
@@ -103,8 +116,8 @@ class SpeakerManager:
|
|
103 |
def create_speaker_from_tensor(
|
104 |
self, tensor, filename="", name="", gender="", describe=""
|
105 |
):
|
106 |
-
if
|
107 |
-
|
108 |
speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
|
109 |
if isinstance(tensor, torch.Tensor):
|
110 |
speaker.emb = tensor
|
|
|
1 |
import os
|
2 |
from typing import Union
|
3 |
+
from box import Box
|
4 |
import torch
|
5 |
|
6 |
from modules import models
|
|
|
17 |
|
18 |
|
19 |
class Speaker:
|
20 |
+
@staticmethod
|
21 |
+
def from_file(file_like):
|
22 |
+
speaker = torch.load(file_like, map_location=torch.device("cpu"))
|
23 |
+
speaker.fix()
|
24 |
+
return speaker
|
25 |
+
|
26 |
+
@staticmethod
|
27 |
+
def from_tensor(tensor):
|
28 |
+
speaker = Speaker(seed=-2)
|
29 |
+
speaker.emb = tensor
|
30 |
+
return speaker
|
31 |
+
|
32 |
def __init__(self, seed, name="", gender="", describe=""):
|
33 |
self.id = uuid.uuid4()
|
34 |
self.seed = seed
|
|
|
37 |
self.describe = describe
|
38 |
self.emb = None
|
39 |
|
40 |
+
# TODO replace emb => tokens
|
41 |
+
self.tokens = []
|
42 |
+
|
43 |
def to_json(self, with_emb=False):
|
44 |
+
return Box(
|
45 |
+
**{
|
46 |
+
"id": str(self.id),
|
47 |
+
"seed": self.seed,
|
48 |
+
"name": self.name,
|
49 |
+
"gender": self.gender,
|
50 |
+
"describe": self.describe,
|
51 |
+
"emb": self.emb.tolist() if with_emb else None,
|
52 |
+
}
|
53 |
+
)
|
54 |
|
55 |
def fix(self):
|
56 |
is_update = False
|
|
|
96 |
self.speakers = {}
|
97 |
for speaker_file in os.listdir(self.speaker_dir):
|
98 |
if speaker_file.endswith(".pt"):
|
99 |
+
self.speakers[speaker_file] = Speaker.from_file(
|
100 |
+
self.speaker_dir + speaker_file
|
101 |
)
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
def list_speakers(self):
|
104 |
return list(self.speakers.values())
|
|
|
116 |
def create_speaker_from_tensor(
|
117 |
self, tensor, filename="", name="", gender="", describe=""
|
118 |
):
|
119 |
+
if filename == "":
|
120 |
+
filename = name
|
121 |
speaker = Speaker(seed=-2, name=name, gender=gender, describe=describe)
|
122 |
if isinstance(tensor, torch.Tensor):
|
123 |
speaker.emb = tensor
|
modules/ssml_parser/SSMLParser.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from lxml import etree
|
2 |
+
|
3 |
+
|
4 |
+
from typing import Any, List, Dict, Union
|
5 |
+
import logging
|
6 |
+
|
7 |
+
from modules.data import styles_mgr
|
8 |
+
from modules.speaker import speaker_mgr
|
9 |
+
from box import Box
|
10 |
+
import copy
|
11 |
+
|
12 |
+
|
13 |
+
class SSMLContext(Box):
|
14 |
+
def __init__(self, parent=None):
|
15 |
+
self.parent: Union[SSMLContext, None] = parent
|
16 |
+
|
17 |
+
self.style = None
|
18 |
+
self.spk = None
|
19 |
+
self.volume = None
|
20 |
+
self.rate = None
|
21 |
+
self.pitch = None
|
22 |
+
# tempurature
|
23 |
+
self.temp = None
|
24 |
+
self.top_p = None
|
25 |
+
self.top_k = None
|
26 |
+
self.seed = None
|
27 |
+
self.noramalize = None
|
28 |
+
self.prompt1 = None
|
29 |
+
self.prompt2 = None
|
30 |
+
self.prefix = None
|
31 |
+
|
32 |
+
|
33 |
+
class SSMLSegment(Box):
|
34 |
+
def __init__(self, text: str, attrs=SSMLContext()):
|
35 |
+
self.attrs = attrs
|
36 |
+
self.text = text
|
37 |
+
self.params = None
|
38 |
+
|
39 |
+
|
40 |
+
class SSMLBreak:
|
41 |
+
def __init__(self, duration_ms: Union[str, int, float]):
|
42 |
+
# TODO 支持其他单位
|
43 |
+
duration_ms = int(str(duration_ms).replace("ms", ""))
|
44 |
+
self.attrs = Box(**{"duration": duration_ms})
|
45 |
+
|
46 |
+
|
47 |
+
class SSMLParser:
|
48 |
+
|
49 |
+
def __init__(self):
|
50 |
+
self.logger = logging.getLogger(__name__)
|
51 |
+
self.logger.debug("SSMLParser.__init__()")
|
52 |
+
self.resolvers = []
|
53 |
+
|
54 |
+
def resolver(self, tag: str):
|
55 |
+
def decorator(func):
|
56 |
+
self.resolvers.append((tag, func))
|
57 |
+
return func
|
58 |
+
|
59 |
+
return decorator
|
60 |
+
|
61 |
+
def parse(self, ssml: str) -> List[Union[SSMLSegment, SSMLBreak]]:
|
62 |
+
root = etree.fromstring(ssml)
|
63 |
+
|
64 |
+
root_ctx = SSMLContext()
|
65 |
+
segments = []
|
66 |
+
self.resolve(root, root_ctx, segments)
|
67 |
+
|
68 |
+
return segments
|
69 |
+
|
70 |
+
def resolve(
|
71 |
+
self, element: etree.Element, context: SSMLContext, segments: List[SSMLSegment]
|
72 |
+
):
|
73 |
+
resolver = [resolver for tag, resolver in self.resolvers if tag == element.tag]
|
74 |
+
if len(resolver) == 0:
|
75 |
+
raise NotImplementedError(f"Tag {element.tag} not supported.")
|
76 |
+
else:
|
77 |
+
resolver = resolver[0]
|
78 |
+
|
79 |
+
resolver(element, context, segments, self)
|
80 |
+
|
81 |
+
|
82 |
+
def create_ssml_parser():
|
83 |
+
parser = SSMLParser()
|
84 |
+
|
85 |
+
@parser.resolver("speak")
|
86 |
+
def tag_speak(element, context, segments, parser):
|
87 |
+
ctx = copy.deepcopy(context)
|
88 |
+
|
89 |
+
version = element.get("version")
|
90 |
+
if version != "0.1":
|
91 |
+
raise ValueError(f"Unsupported SSML version {version}")
|
92 |
+
|
93 |
+
for child in element:
|
94 |
+
parser.resolve(child, ctx, segments)
|
95 |
+
|
96 |
+
@parser.resolver("voice")
|
97 |
+
def tag_voice(element, context, segments, parser):
|
98 |
+
ctx = copy.deepcopy(context)
|
99 |
+
|
100 |
+
ctx.spk = element.get("spk", ctx.spk)
|
101 |
+
ctx.style = element.get("style", ctx.style)
|
102 |
+
ctx.spk = element.get("spk", ctx.spk)
|
103 |
+
ctx.volume = element.get("volume", ctx.volume)
|
104 |
+
ctx.rate = element.get("rate", ctx.rate)
|
105 |
+
ctx.pitch = element.get("pitch", ctx.pitch)
|
106 |
+
# tempurature
|
107 |
+
ctx.temp = element.get("temp", ctx.temp)
|
108 |
+
ctx.top_p = element.get("top_p", ctx.top_p)
|
109 |
+
ctx.top_k = element.get("top_k", ctx.top_k)
|
110 |
+
ctx.seed = element.get("seed", ctx.seed)
|
111 |
+
ctx.noramalize = element.get("noramalize", ctx.noramalize)
|
112 |
+
ctx.prompt1 = element.get("prompt1", ctx.prompt1)
|
113 |
+
ctx.prompt2 = element.get("prompt2", ctx.prompt2)
|
114 |
+
ctx.prefix = element.get("prefix", ctx.prefix)
|
115 |
+
|
116 |
+
# 处理 voice 开头的文本
|
117 |
+
if element.text and element.text.strip():
|
118 |
+
segments.append(SSMLSegment(element.text.strip(), ctx))
|
119 |
+
|
120 |
+
for child in element:
|
121 |
+
parser.resolve(child, ctx, segments)
|
122 |
+
|
123 |
+
# 处理 voice 结尾的文本
|
124 |
+
if child.tail and child.tail.strip():
|
125 |
+
segments.append(SSMLSegment(child.tail.strip(), ctx))
|
126 |
+
|
127 |
+
@parser.resolver("break")
|
128 |
+
def tag_break(element, context, segments, parser):
|
129 |
+
time_ms = int(element.get("time", "0").replace("ms", ""))
|
130 |
+
segments.append(SSMLBreak(time_ms))
|
131 |
+
|
132 |
+
@parser.resolver("prosody")
|
133 |
+
def tag_prosody(element, context, segments, parser):
|
134 |
+
ctx = copy.deepcopy(context)
|
135 |
+
|
136 |
+
ctx.spk = element.get("spk", ctx.spk)
|
137 |
+
ctx.style = element.get("style", ctx.style)
|
138 |
+
ctx.spk = element.get("spk", ctx.spk)
|
139 |
+
ctx.volume = element.get("volume", ctx.volume)
|
140 |
+
ctx.rate = element.get("rate", ctx.rate)
|
141 |
+
ctx.pitch = element.get("pitch", ctx.pitch)
|
142 |
+
# tempurature
|
143 |
+
ctx.temp = element.get("temp", ctx.temp)
|
144 |
+
ctx.top_p = element.get("top_p", ctx.top_p)
|
145 |
+
ctx.top_k = element.get("top_k", ctx.top_k)
|
146 |
+
ctx.seed = element.get("seed", ctx.seed)
|
147 |
+
ctx.noramalize = element.get("noramalize", ctx.noramalize)
|
148 |
+
ctx.prompt1 = element.get("prompt1", ctx.prompt1)
|
149 |
+
ctx.prompt2 = element.get("prompt2", ctx.prompt2)
|
150 |
+
ctx.prefix = element.get("prefix", ctx.prefix)
|
151 |
+
|
152 |
+
if element.text and element.text.strip():
|
153 |
+
segments.append(SSMLSegment(element.text.strip(), ctx))
|
154 |
+
|
155 |
+
return parser
|
156 |
+
|
157 |
+
|
158 |
+
if __name__ == "__main__":
|
159 |
+
parser = create_ssml_parser()
|
160 |
+
|
161 |
+
ssml = """
|
162 |
+
<speak version="0.1">
|
163 |
+
<voice spk="xiaoyan" style="news">
|
164 |
+
<prosody rate="fast">你好</prosody>
|
165 |
+
<break time="500ms"/>
|
166 |
+
<prosody rate="slow">你好</prosody>
|
167 |
+
</voice>
|
168 |
+
</speak>
|
169 |
+
"""
|
170 |
+
|
171 |
+
segments = parser.parse(ssml)
|
172 |
+
for segment in segments:
|
173 |
+
if isinstance(segment, SSMLBreak):
|
174 |
+
print("<break>", segment.attrs)
|
175 |
+
elif isinstance(segment, SSMLSegment):
|
176 |
+
print(segment.text, segment.attrs)
|
177 |
+
else:
|
178 |
+
raise ValueError("Unknown segment type")
|
modules/ssml_parser/__init__.py
ADDED
File without changes
|
modules/ssml_parser/test_ssml_parser.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
from lxml import etree
|
3 |
+
from modules.ssml_parser.SSMLParser import (
|
4 |
+
create_ssml_parser,
|
5 |
+
SSMLSegment,
|
6 |
+
SSMLBreak,
|
7 |
+
SSMLContext,
|
8 |
+
)
|
9 |
+
|
10 |
+
|
11 |
+
@pytest.fixture
|
12 |
+
def parser():
|
13 |
+
return create_ssml_parser()
|
14 |
+
|
15 |
+
|
16 |
+
@pytest.mark.ssml_parser
|
17 |
+
def test_speak_tag(parser):
|
18 |
+
ssml = """
|
19 |
+
<speak version="0.1">
|
20 |
+
<voice spk="xiaoyan" style="news">
|
21 |
+
<prosody rate="fast">你好</prosody>
|
22 |
+
<break time="500ms"/>
|
23 |
+
<prosody rate="slow">你好</prosody>
|
24 |
+
</voice>
|
25 |
+
</speak>
|
26 |
+
"""
|
27 |
+
segments = parser.parse(ssml)
|
28 |
+
assert len(segments) == 3
|
29 |
+
assert isinstance(segments[0], SSMLSegment)
|
30 |
+
assert segments[0].text == "你好"
|
31 |
+
assert segments[0].params.rate == "fast"
|
32 |
+
assert isinstance(segments[1], SSMLBreak)
|
33 |
+
assert segments[1].duration == 500
|
34 |
+
assert isinstance(segments[2], SSMLSegment)
|
35 |
+
assert segments[2].text == "你好"
|
36 |
+
assert segments[2].params.rate == "slow"
|
37 |
+
|
38 |
+
|
39 |
+
@pytest.mark.ssml_parser
|
40 |
+
def test_voice_tag(parser):
|
41 |
+
ssml = """
|
42 |
+
<speak version="0.1">
|
43 |
+
<voice spk="xiaoyan" style="news">你好</voice>
|
44 |
+
</speak>
|
45 |
+
"""
|
46 |
+
segments = parser.parse(ssml)
|
47 |
+
assert len(segments) == 1
|
48 |
+
assert isinstance(segments[0], SSMLSegment)
|
49 |
+
assert segments[0].text == "你好"
|
50 |
+
assert segments[0].params.spk == "xiaoyan"
|
51 |
+
assert segments[0].params.style == "news"
|
52 |
+
|
53 |
+
|
54 |
+
@pytest.mark.ssml_parser
|
55 |
+
def test_break_tag(parser):
|
56 |
+
ssml = """
|
57 |
+
<speak version="0.1">
|
58 |
+
<break time="500ms"/>
|
59 |
+
</speak>
|
60 |
+
"""
|
61 |
+
segments = parser.parse(ssml)
|
62 |
+
assert len(segments) == 1
|
63 |
+
assert isinstance(segments[0], SSMLBreak)
|
64 |
+
assert segments[0].duration == 500
|
65 |
+
|
66 |
+
|
67 |
+
@pytest.mark.ssml_parser
|
68 |
+
def test_prosody_tag(parser):
|
69 |
+
ssml = """
|
70 |
+
<speak version="0.1">
|
71 |
+
<prosody rate="fast">你好</prosody>
|
72 |
+
</speak>
|
73 |
+
"""
|
74 |
+
segments = parser.parse(ssml)
|
75 |
+
assert len(segments) == 1
|
76 |
+
assert isinstance(segments[0], SSMLSegment)
|
77 |
+
assert segments[0].text == "你好"
|
78 |
+
assert segments[0].params.rate == "fast"
|
79 |
+
|
80 |
+
|
81 |
+
@pytest.mark.ssml_parser
|
82 |
+
def test_unsupported_version(parser):
|
83 |
+
ssml = """
|
84 |
+
<speak version="0.2">
|
85 |
+
<voice spk="xiaoyan" style="news">你好</voice>
|
86 |
+
</speak>
|
87 |
+
"""
|
88 |
+
with pytest.raises(ValueError, match=r"Unsupported SSML version 0.2"):
|
89 |
+
parser.parse(ssml)
|
90 |
+
|
91 |
+
|
92 |
+
@pytest.mark.ssml_parser
|
93 |
+
def test_unsupported_tag(parser):
|
94 |
+
ssml = """
|
95 |
+
<speak version="0.1">
|
96 |
+
<unsupported>你好</unsupported>
|
97 |
+
</speak>
|
98 |
+
"""
|
99 |
+
with pytest.raises(NotImplementedError, match=r"Tag unsupported not supported."):
|
100 |
+
parser.parse(ssml)
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
pytest.main()
|
modules/utils/JsonObject.py
CHANGED
@@ -8,6 +8,9 @@ class JsonObject:
|
|
8 |
# If no initial dictionary is provided, use an empty dictionary
|
9 |
self._dict_obj = initial_dict if initial_dict is not None else {}
|
10 |
|
|
|
|
|
|
|
11 |
def __getattr__(self, name):
|
12 |
"""
|
13 |
Get an attribute value. If the attribute does not exist,
|
@@ -111,3 +114,19 @@ class JsonObject:
|
|
111 |
:return: A list of values.
|
112 |
"""
|
113 |
return self._dict_obj.values()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
# If no initial dictionary is provided, use an empty dictionary
|
9 |
self._dict_obj = initial_dict if initial_dict is not None else {}
|
10 |
|
11 |
+
if self._dict_obj is self:
|
12 |
+
raise ValueError("JsonObject cannot be initialized with itself")
|
13 |
+
|
14 |
def __getattr__(self, name):
|
15 |
"""
|
16 |
Get an attribute value. If the attribute does not exist,
|
|
|
114 |
:return: A list of values.
|
115 |
"""
|
116 |
return self._dict_obj.values()
|
117 |
+
|
118 |
+
def clone(self):
|
119 |
+
"""
|
120 |
+
Clone the JsonObject.
|
121 |
+
|
122 |
+
:return: A new JsonObject with the same internal dictionary.
|
123 |
+
"""
|
124 |
+
return JsonObject(self._dict_obj.copy())
|
125 |
+
|
126 |
+
def merge(self, other):
|
127 |
+
"""
|
128 |
+
Merge the internal dictionary with another dictionary.
|
129 |
+
|
130 |
+
:param other: The other dictionary to merge.
|
131 |
+
"""
|
132 |
+
self._dict_obj.update(other)
|
modules/utils/constants.py
CHANGED
@@ -10,4 +10,4 @@ DATA_DIR = os.path.join(ROOT_DIR, "data")
|
|
10 |
|
11 |
MODELS_DIR = os.path.join(ROOT_DIR, "models")
|
12 |
|
13 |
-
|
|
|
10 |
|
11 |
MODELS_DIR = os.path.join(ROOT_DIR, "models")
|
12 |
|
13 |
+
SPEAKERS_DIR = os.path.join(DATA_DIR, "speakers")
|
modules/webui/app.py
CHANGED
@@ -5,7 +5,9 @@ import torch
|
|
5 |
import gradio as gr
|
6 |
|
7 |
from modules import config
|
|
|
8 |
|
|
|
9 |
from modules.webui.tts_tab import create_tts_interface
|
10 |
from modules.webui.ssml_tab import create_ssml_interface
|
11 |
from modules.webui.spliter_tab import create_spliter_tab
|
@@ -93,15 +95,15 @@ def create_interface():
|
|
93 |
with gr.TabItem("Spilter"):
|
94 |
create_spliter_tab(ssml_input, tabs=tabs)
|
95 |
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
|
106 |
with gr.TabItem("README"):
|
107 |
create_readme_tab()
|
|
|
5 |
import gradio as gr
|
6 |
|
7 |
from modules import config
|
8 |
+
from modules.webui import webui_config
|
9 |
|
10 |
+
from modules.webui.system_tab import create_system_tab
|
11 |
from modules.webui.tts_tab import create_tts_interface
|
12 |
from modules.webui.ssml_tab import create_ssml_interface
|
13 |
from modules.webui.spliter_tab import create_spliter_tab
|
|
|
95 |
with gr.TabItem("Spilter"):
|
96 |
create_spliter_tab(ssml_input, tabs=tabs)
|
97 |
|
98 |
+
with gr.TabItem("Speaker"):
|
99 |
+
create_speaker_panel()
|
100 |
+
with gr.TabItem("Inpainting", visible=webui_config.experimental):
|
101 |
+
gr.Markdown("🚧 Under construction")
|
102 |
+
with gr.TabItem("ASR", visible=webui_config.experimental):
|
103 |
+
gr.Markdown("🚧 Under construction")
|
104 |
+
|
105 |
+
with gr.TabItem("System"):
|
106 |
+
create_system_tab()
|
107 |
|
108 |
with gr.TabItem("README"):
|
109 |
create_readme_tab()
|
modules/webui/speaker_tab.py
CHANGED
@@ -1,13 +1,259 @@
|
|
|
|
1 |
import gradio as gr
|
|
|
2 |
|
3 |
-
from modules.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
|
6 |
# 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出
|
7 |
def create_speaker_panel():
|
8 |
speakers = get_speakers()
|
9 |
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
import gradio as gr
|
3 |
+
import torch
|
4 |
|
5 |
+
from modules.hf import spaces
|
6 |
+
from modules.webui.webui_utils import get_speakers, tts_generate
|
7 |
+
from modules.speaker import speaker_mgr, Speaker
|
8 |
+
|
9 |
+
import tempfile
|
10 |
+
|
11 |
+
|
12 |
+
def spk_to_tensor(spk):
|
13 |
+
spk = spk.split(" : ")[1].strip() if " : " in spk else spk
|
14 |
+
if spk == "None" or spk == "":
|
15 |
+
return None
|
16 |
+
return speaker_mgr.get_speaker(spk).emb
|
17 |
+
|
18 |
+
|
19 |
+
def get_speaker_show_name(spk):
|
20 |
+
if spk.gender == "*" or spk.gender == "":
|
21 |
+
return spk.name
|
22 |
+
return f"{spk.gender} : {spk.name}"
|
23 |
+
|
24 |
+
|
25 |
+
def merge_spk(
|
26 |
+
spk_a,
|
27 |
+
spk_a_w,
|
28 |
+
spk_b,
|
29 |
+
spk_b_w,
|
30 |
+
spk_c,
|
31 |
+
spk_c_w,
|
32 |
+
spk_d,
|
33 |
+
spk_d_w,
|
34 |
+
):
|
35 |
+
tensor_a = spk_to_tensor(spk_a)
|
36 |
+
tensor_b = spk_to_tensor(spk_b)
|
37 |
+
tensor_c = spk_to_tensor(spk_c)
|
38 |
+
tensor_d = spk_to_tensor(spk_d)
|
39 |
+
|
40 |
+
assert (
|
41 |
+
tensor_a is not None
|
42 |
+
or tensor_b is not None
|
43 |
+
or tensor_c is not None
|
44 |
+
or tensor_d is not None
|
45 |
+
), "At least one speaker should be selected"
|
46 |
+
|
47 |
+
merge_tensor = torch.zeros_like(
|
48 |
+
tensor_a
|
49 |
+
if tensor_a is not None
|
50 |
+
else (
|
51 |
+
tensor_b
|
52 |
+
if tensor_b is not None
|
53 |
+
else tensor_c if tensor_c is not None else tensor_d
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
total_weight = 0
|
58 |
+
if tensor_a is not None:
|
59 |
+
merge_tensor += spk_a_w * tensor_a
|
60 |
+
total_weight += spk_a_w
|
61 |
+
if tensor_b is not None:
|
62 |
+
merge_tensor += spk_b_w * tensor_b
|
63 |
+
total_weight += spk_b_w
|
64 |
+
if tensor_c is not None:
|
65 |
+
merge_tensor += spk_c_w * tensor_c
|
66 |
+
total_weight += spk_c_w
|
67 |
+
if tensor_d is not None:
|
68 |
+
merge_tensor += spk_d_w * tensor_d
|
69 |
+
total_weight += spk_d_w
|
70 |
+
|
71 |
+
if total_weight > 0:
|
72 |
+
merge_tensor /= total_weight
|
73 |
+
|
74 |
+
merged_spk = Speaker.from_tensor(merge_tensor)
|
75 |
+
merged_spk.name = "<MIX>"
|
76 |
+
|
77 |
+
return merged_spk
|
78 |
+
|
79 |
+
|
80 |
+
@torch.inference_mode()
|
81 |
+
@spaces.GPU
|
82 |
+
def merge_and_test_spk_voice(
|
83 |
+
spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
|
84 |
+
):
|
85 |
+
merged_spk = merge_spk(
|
86 |
+
spk_a,
|
87 |
+
spk_a_w,
|
88 |
+
spk_b,
|
89 |
+
spk_b_w,
|
90 |
+
spk_c,
|
91 |
+
spk_c_w,
|
92 |
+
spk_d,
|
93 |
+
spk_d_w,
|
94 |
+
)
|
95 |
+
return tts_generate(
|
96 |
+
spk=merged_spk,
|
97 |
+
text=test_text,
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
@torch.inference_mode()
|
102 |
+
@spaces.GPU
|
103 |
+
def merge_spk_to_file(
|
104 |
+
spk_a,
|
105 |
+
spk_a_w,
|
106 |
+
spk_b,
|
107 |
+
spk_b_w,
|
108 |
+
spk_c,
|
109 |
+
spk_c_w,
|
110 |
+
spk_d,
|
111 |
+
spk_d_w,
|
112 |
+
speaker_name,
|
113 |
+
speaker_gender,
|
114 |
+
speaker_desc,
|
115 |
+
):
|
116 |
+
merged_spk = merge_spk(
|
117 |
+
spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
|
118 |
+
)
|
119 |
+
merged_spk.name = speaker_name
|
120 |
+
merged_spk.gender = speaker_gender
|
121 |
+
merged_spk.desc = speaker_desc
|
122 |
+
|
123 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
124 |
+
torch.save(merged_spk, tmp_file)
|
125 |
+
tmp_file_path = tmp_file.name
|
126 |
+
|
127 |
+
return tmp_file_path
|
128 |
+
|
129 |
+
|
130 |
+
merge_desc = """
|
131 |
+
## Speaker Merger
|
132 |
+
|
133 |
+
在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明:
|
134 |
+
|
135 |
+
### 1. 选择说话人
|
136 |
+
您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。
|
137 |
+
|
138 |
+
### 2. 合成语音
|
139 |
+
在选择好说话人和设置好权重后,您可以在“测试文本”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。
|
140 |
+
|
141 |
+
### 3. 保存说话人
|
142 |
+
您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“保存说话人”按钮来保存合成的说话人。保存后的说话人文件将显示在“合成说话人”栏中,供下载使用。
|
143 |
+
"""
|
144 |
|
145 |
|
146 |
# 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出
|
147 |
def create_speaker_panel():
|
148 |
speakers = get_speakers()
|
149 |
|
150 |
+
speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]
|
151 |
+
|
152 |
+
with gr.Tabs():
|
153 |
+
with gr.TabItem("Merger"):
|
154 |
+
gr.Markdown(merge_desc)
|
155 |
+
|
156 |
+
with gr.Row():
|
157 |
+
with gr.Column(scale=5):
|
158 |
+
with gr.Row():
|
159 |
+
with gr.Group():
|
160 |
+
spk_a = gr.Dropdown(
|
161 |
+
choices=speaker_names, value="None", label="Speaker A"
|
162 |
+
)
|
163 |
+
spk_a_w = gr.Slider(
|
164 |
+
value=1, minimum=0, maximum=10, step=1, label="Weight A"
|
165 |
+
)
|
166 |
+
|
167 |
+
with gr.Group():
|
168 |
+
spk_b = gr.Dropdown(
|
169 |
+
choices=speaker_names, value="None", label="Speaker B"
|
170 |
+
)
|
171 |
+
spk_b_w = gr.Slider(
|
172 |
+
value=1, minimum=0, maximum=10, step=1, label="Weight B"
|
173 |
+
)
|
174 |
+
|
175 |
+
with gr.Group():
|
176 |
+
spk_c = gr.Dropdown(
|
177 |
+
choices=speaker_names, value="None", label="Speaker C"
|
178 |
+
)
|
179 |
+
spk_c_w = gr.Slider(
|
180 |
+
value=1, minimum=0, maximum=10, step=1, label="Weight C"
|
181 |
+
)
|
182 |
+
|
183 |
+
with gr.Group():
|
184 |
+
spk_d = gr.Dropdown(
|
185 |
+
choices=speaker_names, value="None", label="Speaker D"
|
186 |
+
)
|
187 |
+
spk_d_w = gr.Slider(
|
188 |
+
value=1, minimum=0, maximum=10, step=1, label="Weight D"
|
189 |
+
)
|
190 |
+
|
191 |
+
with gr.Row():
|
192 |
+
with gr.Column(scale=3):
|
193 |
+
with gr.Group():
|
194 |
+
gr.Markdown("🎤Test voice")
|
195 |
+
with gr.Row():
|
196 |
+
test_voice_btn = gr.Button(
|
197 |
+
"Test Voice", variant="secondary"
|
198 |
+
)
|
199 |
+
|
200 |
+
with gr.Column(scale=4):
|
201 |
+
test_text = gr.Textbox(
|
202 |
+
label="Test Text",
|
203 |
+
placeholder="Please input test text",
|
204 |
+
value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
|
205 |
+
)
|
206 |
+
|
207 |
+
output_audio = gr.Audio(label="Output Audio")
|
208 |
+
|
209 |
+
with gr.Column(scale=1):
|
210 |
+
with gr.Group():
|
211 |
+
gr.Markdown("🗃️Save to file")
|
212 |
+
|
213 |
+
speaker_name = gr.Textbox(
|
214 |
+
label="Name", value="forge_speaker_merged"
|
215 |
+
)
|
216 |
+
speaker_gender = gr.Textbox(label="Gender", value="*")
|
217 |
+
speaker_desc = gr.Textbox(
|
218 |
+
label="Description", value="merged speaker"
|
219 |
+
)
|
220 |
+
|
221 |
+
save_btn = gr.Button("Save Speaker", variant="primary")
|
222 |
+
|
223 |
+
merged_spker = gr.File(
|
224 |
+
label="Merged Speaker", interactive=False, type="binary"
|
225 |
+
)
|
226 |
+
|
227 |
+
test_voice_btn.click(
|
228 |
+
merge_and_test_spk_voice,
|
229 |
+
inputs=[
|
230 |
+
spk_a,
|
231 |
+
spk_a_w,
|
232 |
+
spk_b,
|
233 |
+
spk_b_w,
|
234 |
+
spk_c,
|
235 |
+
spk_c_w,
|
236 |
+
spk_d,
|
237 |
+
spk_d_w,
|
238 |
+
test_text,
|
239 |
+
],
|
240 |
+
outputs=[output_audio],
|
241 |
+
)
|
242 |
|
243 |
+
save_btn.click(
|
244 |
+
merge_spk_to_file,
|
245 |
+
inputs=[
|
246 |
+
spk_a,
|
247 |
+
spk_a_w,
|
248 |
+
spk_b,
|
249 |
+
spk_b_w,
|
250 |
+
spk_c,
|
251 |
+
spk_c_w,
|
252 |
+
spk_d,
|
253 |
+
spk_d_w,
|
254 |
+
speaker_name,
|
255 |
+
speaker_gender,
|
256 |
+
speaker_desc,
|
257 |
+
],
|
258 |
+
outputs=[merged_spker],
|
259 |
+
)
|
modules/webui/spliter_tab.py
CHANGED
@@ -9,6 +9,7 @@ from modules.webui.webui_utils import (
|
|
9 |
from modules.hf import spaces
|
10 |
|
11 |
|
|
|
12 |
@torch.inference_mode()
|
13 |
@spaces.GPU
|
14 |
def merge_dataframe_to_ssml(dataframe, spk, style, seed):
|
@@ -31,7 +32,7 @@ def merge_dataframe_to_ssml(dataframe, spk, style, seed):
|
|
31 |
if seed:
|
32 |
ssml += f' seed="{seed}"'
|
33 |
ssml += ">\n"
|
34 |
-
ssml += f"{indent}{indent}{text_normalize(row[1])}\n"
|
35 |
ssml += f"{indent}</voice>\n"
|
36 |
return f"<speak version='0.1'>\n{ssml}</speak>"
|
37 |
|
|
|
9 |
from modules.hf import spaces
|
10 |
|
11 |
|
12 |
+
# NOTE: 因为 text_normalize 需要使用 tokenizer
|
13 |
@torch.inference_mode()
|
14 |
@spaces.GPU
|
15 |
def merge_dataframe_to_ssml(dataframe, spk, style, seed):
|
|
|
32 |
if seed:
|
33 |
ssml += f' seed="{seed}"'
|
34 |
ssml += ">\n"
|
35 |
+
ssml += f"{indent}{indent}{text_normalize(row.iloc[1])}\n"
|
36 |
ssml += f"{indent}</voice>\n"
|
37 |
return f"<speak version='0.1'>\n{ssml}</speak>"
|
38 |
|
modules/webui/system_tab.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from modules.webui import webui_config
|
3 |
+
|
4 |
+
|
5 |
+
def create_system_tab():
|
6 |
+
with gr.Row():
|
7 |
+
with gr.Column(scale=1):
|
8 |
+
gr.Markdown(f"info")
|
9 |
+
|
10 |
+
with gr.Column(scale=5):
|
11 |
+
toggle_experimental = gr.Checkbox(
|
12 |
+
label="Enable Experimental Features",
|
13 |
+
value=webui_config.experimental,
|
14 |
+
interactive=False,
|
15 |
+
)
|
modules/webui/tts_tab.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
from modules.webui.webui_utils import (
|
4 |
get_speakers,
|
5 |
get_styles,
|
|
|
6 |
refine_text,
|
7 |
tts_generate,
|
8 |
)
|
@@ -10,6 +11,13 @@ from modules.webui import webui_config
|
|
10 |
from modules.webui.examples import example_texts
|
11 |
from modules import config
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def create_tts_interface():
|
15 |
speakers = get_speakers()
|
@@ -90,15 +98,18 @@ def create_tts_interface():
|
|
90 |
outputs=[spk_input_text],
|
91 |
)
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
102 |
with gr.Group():
|
103 |
gr.Markdown("💃Inference Seed")
|
104 |
infer_seed_input = gr.Number(
|
@@ -122,85 +133,62 @@ def create_tts_interface():
|
|
122 |
prompt2_input = gr.Textbox(label="Prompt 2")
|
123 |
prefix_input = gr.Textbox(label="Prefix")
|
124 |
|
125 |
-
|
126 |
-
|
|
|
127 |
|
128 |
infer_seed_rand_button.click(
|
129 |
lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
|
130 |
inputs=[infer_seed_input],
|
131 |
outputs=[infer_seed_input],
|
132 |
)
|
133 |
-
with gr.Column(scale=
|
134 |
-
with gr.
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
lambda text, tk=tk: text + " " + tk,
|
181 |
-
inputs=[text_input],
|
182 |
-
outputs=[text_input],
|
183 |
-
)
|
184 |
-
with gr.Column(scale=1):
|
185 |
-
with gr.Group():
|
186 |
-
gr.Markdown("🎶Refiner")
|
187 |
-
refine_prompt_input = gr.Textbox(
|
188 |
-
label="Refine Prompt",
|
189 |
-
value="[oral_2][laugh_0][break_6]",
|
190 |
-
)
|
191 |
-
refine_button = gr.Button("✍️Refine Text")
|
192 |
-
# TODO 分割句子,使用当前配置拼接为SSML,然后发送到SSML tab
|
193 |
-
# send_button = gr.Button("📩Split and send to SSML")
|
194 |
-
|
195 |
-
with gr.Group():
|
196 |
-
gr.Markdown("🔊Generate")
|
197 |
-
disable_normalize_input = gr.Checkbox(
|
198 |
-
value=False, label="Disable Normalize"
|
199 |
-
)
|
200 |
-
tts_button = gr.Button(
|
201 |
-
"🔊Generate Audio",
|
202 |
-
variant="primary",
|
203 |
-
elem_classes="big-button",
|
204 |
)
|
205 |
|
206 |
with gr.Group():
|
@@ -220,6 +208,31 @@ def create_tts_interface():
|
|
220 |
with gr.Group():
|
221 |
gr.Markdown("🎨Output")
|
222 |
tts_output = gr.Audio(label="Generated Audio")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
|
224 |
refine_button.click(
|
225 |
refine_text,
|
@@ -243,6 +256,9 @@ def create_tts_interface():
|
|
243 |
style_input_dropdown,
|
244 |
disable_normalize_input,
|
245 |
batch_size_input,
|
|
|
|
|
|
|
246 |
],
|
247 |
outputs=tts_output,
|
248 |
)
|
|
|
3 |
from modules.webui.webui_utils import (
|
4 |
get_speakers,
|
5 |
get_styles,
|
6 |
+
load_spk_info,
|
7 |
refine_text,
|
8 |
tts_generate,
|
9 |
)
|
|
|
11 |
from modules.webui.examples import example_texts
|
12 |
from modules import config
|
13 |
|
14 |
+
default_text_content = """
|
15 |
+
chat T T S 是一款强大的对话式文本转语音模型。它有中英混读和多说话人的能力。
|
16 |
+
chat T T S 不仅能够生成自然流畅的语音,还能控制[laugh]笑声啊[laugh],
|
17 |
+
停顿啊[uv_break]语气词啊等副语言现象[uv_break]。这个韵律超越了许多开源模型[uv_break]。
|
18 |
+
请注意,chat T T S 的使用应遵守法律和伦理准则,避免滥用的安全风险。[uv_break]
|
19 |
+
"""
|
20 |
+
|
21 |
|
22 |
def create_tts_interface():
|
23 |
speakers = get_speakers()
|
|
|
98 |
outputs=[spk_input_text],
|
99 |
)
|
100 |
|
101 |
+
with gr.Tab(label="Upload"):
|
102 |
+
spk_file_upload = gr.File(label="Speaker (Upload)")
|
103 |
+
|
104 |
+
gr.Markdown("📝Speaker info")
|
105 |
+
infos = gr.Markdown("empty")
|
106 |
+
|
107 |
+
spk_file_upload.change(
|
108 |
+
fn=load_spk_info,
|
109 |
+
inputs=[spk_file_upload],
|
110 |
+
outputs=[infos],
|
111 |
+
),
|
112 |
+
|
113 |
with gr.Group():
|
114 |
gr.Markdown("💃Inference Seed")
|
115 |
infer_seed_input = gr.Number(
|
|
|
133 |
prompt2_input = gr.Textbox(label="Prompt 2")
|
134 |
prefix_input = gr.Textbox(label="Prefix")
|
135 |
|
136 |
+
prompt_audio = gr.File(
|
137 |
+
label="prompt_audio", visible=webui_config.experimental
|
138 |
+
)
|
139 |
|
140 |
infer_seed_rand_button.click(
|
141 |
lambda x: int(torch.randint(0, 2**32 - 1, (1,)).item()),
|
142 |
inputs=[infer_seed_input],
|
143 |
outputs=[infer_seed_input],
|
144 |
)
|
145 |
+
with gr.Column(scale=4):
|
146 |
+
with gr.Group():
|
147 |
+
input_title = gr.Markdown(
|
148 |
+
"📝Text Input",
|
149 |
+
elem_id="input-title",
|
150 |
+
)
|
151 |
+
gr.Markdown(f"- 字数限制{webui_config.tts_max:,}字,超过部分截断")
|
152 |
+
gr.Markdown("- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`")
|
153 |
+
gr.Markdown(
|
154 |
+
"- If the input text is all in English, it is recommended to check disable_normalize"
|
155 |
+
)
|
156 |
+
text_input = gr.Textbox(
|
157 |
+
show_label=False,
|
158 |
+
label="Text to Speech",
|
159 |
+
lines=10,
|
160 |
+
placeholder="输入文本或选择示例",
|
161 |
+
elem_id="text-input",
|
162 |
+
value=default_text_content,
|
163 |
+
)
|
164 |
+
# TODO 字数统计,其实实现很好写,但是就是会触发loading...并且还要和后端交互...
|
165 |
+
# text_input.change(
|
166 |
+
# fn=lambda x: (
|
167 |
+
# f"📝Text Input ({len(x)} char)"
|
168 |
+
# if x
|
169 |
+
# else (
|
170 |
+
# "📝Text Input (0 char)"
|
171 |
+
# if not x
|
172 |
+
# else "📝Text Input (0 char)"
|
173 |
+
# )
|
174 |
+
# ),
|
175 |
+
# inputs=[text_input],
|
176 |
+
# outputs=[input_title],
|
177 |
+
# )
|
178 |
+
with gr.Row():
|
179 |
+
contorl_tokens = [
|
180 |
+
"[laugh]",
|
181 |
+
"[uv_break]",
|
182 |
+
"[v_break]",
|
183 |
+
"[lbreak]",
|
184 |
+
]
|
185 |
+
|
186 |
+
for tk in contorl_tokens:
|
187 |
+
t_btn = gr.Button(tk)
|
188 |
+
t_btn.click(
|
189 |
+
lambda text, tk=tk: text + " " + tk,
|
190 |
+
inputs=[text_input],
|
191 |
+
outputs=[text_input],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
)
|
193 |
|
194 |
with gr.Group():
|
|
|
208 |
with gr.Group():
|
209 |
gr.Markdown("🎨Output")
|
210 |
tts_output = gr.Audio(label="Generated Audio")
|
211 |
+
with gr.Column(scale=1):
|
212 |
+
with gr.Group():
|
213 |
+
gr.Markdown("🎶Refiner")
|
214 |
+
refine_prompt_input = gr.Textbox(
|
215 |
+
label="Refine Prompt",
|
216 |
+
value="[oral_2][laugh_0][break_6]",
|
217 |
+
)
|
218 |
+
refine_button = gr.Button("✍️Refine Text")
|
219 |
+
|
220 |
+
with gr.Group():
|
221 |
+
gr.Markdown("🔊Generate")
|
222 |
+
disable_normalize_input = gr.Checkbox(
|
223 |
+
value=False, label="Disable Normalize"
|
224 |
+
)
|
225 |
+
|
226 |
+
# FIXME: 不知道为啥,就是非常慢,单独调脚本是很快的
|
227 |
+
with gr.Group(visible=webui_config.experimental):
|
228 |
+
gr.Markdown("💪🏼Enhance")
|
229 |
+
enable_enhance = gr.Checkbox(value=False, label="Enable Enhance")
|
230 |
+
enable_de_noise = gr.Checkbox(value=False, label="Enable De-noise")
|
231 |
+
tts_button = gr.Button(
|
232 |
+
"🔊Generate Audio",
|
233 |
+
variant="primary",
|
234 |
+
elem_classes="big-button",
|
235 |
+
)
|
236 |
|
237 |
refine_button.click(
|
238 |
refine_text,
|
|
|
256 |
style_input_dropdown,
|
257 |
disable_normalize_input,
|
258 |
batch_size_input,
|
259 |
+
enable_enhance,
|
260 |
+
enable_de_noise,
|
261 |
+
spk_file_upload,
|
262 |
],
|
263 |
outputs=tts_output,
|
264 |
)
|
modules/webui/webui_config.py
CHANGED
@@ -1,4 +1,8 @@
|
|
|
|
|
|
|
|
1 |
tts_max = 1000
|
2 |
ssml_max = 1000
|
3 |
spliter_threshold = 100
|
4 |
max_batch_size = 8
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
|
3 |
+
|
4 |
tts_max = 1000
|
5 |
ssml_max = 1000
|
6 |
spliter_threshold = 100
|
7 |
max_batch_size = 8
|
8 |
+
experimental = False
|
modules/webui/webui_utils.py
CHANGED
@@ -1,37 +1,26 @@
|
|
1 |
-
import
|
2 |
-
import logging
|
3 |
-
import sys
|
4 |
-
|
5 |
import numpy as np
|
6 |
|
|
|
7 |
from modules.devices import devices
|
8 |
from modules.synthesize_audio import synthesize_audio
|
9 |
from modules.hf import spaces
|
10 |
from modules.webui import webui_config
|
11 |
|
12 |
-
logging.basicConfig(
|
13 |
-
level=os.getenv("LOG_LEVEL", "INFO"),
|
14 |
-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
15 |
-
)
|
16 |
-
|
17 |
-
|
18 |
-
import gradio as gr
|
19 |
-
|
20 |
import torch
|
21 |
|
22 |
-
from modules.
|
23 |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
24 |
|
25 |
-
from modules.speaker import speaker_mgr
|
26 |
from modules.data import styles_mgr
|
27 |
|
28 |
from modules.api.utils import calc_spk_style
|
29 |
-
import modules.generate_audio as generate
|
30 |
|
31 |
from modules.normalization import text_normalize
|
32 |
-
from modules import refiner
|
33 |
|
34 |
-
from modules.utils import
|
35 |
from modules.SentenceSplitter import SentenceSplitter
|
36 |
|
37 |
|
@@ -43,11 +32,30 @@ def get_styles():
|
|
43 |
return styles_mgr.list_items()
|
44 |
|
45 |
|
46 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
ret_segments = []
|
48 |
total_len = 0
|
49 |
for seg in segments:
|
50 |
-
if
|
|
|
51 |
continue
|
52 |
total_len += len(seg["text"])
|
53 |
if total_len > total_max:
|
@@ -56,6 +64,28 @@ def segments_length_limit(segments, total_max: int):
|
|
56 |
return ret_segments
|
57 |
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
@torch.inference_mode()
|
60 |
@spaces.GPU
|
61 |
def synthesize_ssml(ssml: str, batch_size=4):
|
@@ -69,7 +99,8 @@ def synthesize_ssml(ssml: str, batch_size=4):
|
|
69 |
if ssml == "":
|
70 |
return None
|
71 |
|
72 |
-
|
|
|
73 |
max_len = webui_config.ssml_max
|
74 |
segments = segments_length_limit(segments, max_len)
|
75 |
|
@@ -87,18 +118,21 @@ def synthesize_ssml(ssml: str, batch_size=4):
|
|
87 |
@spaces.GPU
|
88 |
def tts_generate(
|
89 |
text,
|
90 |
-
temperature,
|
91 |
-
top_p,
|
92 |
-
top_k,
|
93 |
-
spk,
|
94 |
-
infer_seed,
|
95 |
-
use_decoder,
|
96 |
-
prompt1,
|
97 |
-
prompt2,
|
98 |
-
prefix,
|
99 |
-
style,
|
100 |
disable_normalize=False,
|
101 |
batch_size=4,
|
|
|
|
|
|
|
102 |
):
|
103 |
try:
|
104 |
batch_size = int(batch_size)
|
@@ -126,12 +160,15 @@ def tts_generate(
|
|
126 |
prompt1 = prompt1 or params.get("prompt1", "")
|
127 |
prompt2 = prompt2 or params.get("prompt2", "")
|
128 |
|
129 |
-
infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.
|
130 |
infer_seed = int(infer_seed)
|
131 |
|
132 |
if not disable_normalize:
|
133 |
text = text_normalize(text)
|
134 |
|
|
|
|
|
|
|
135 |
sample_rate, audio_data = synthesize_audio(
|
136 |
text=text,
|
137 |
temperature=temperature,
|
@@ -146,6 +183,10 @@ def tts_generate(
|
|
146 |
batch_size=batch_size,
|
147 |
)
|
148 |
|
|
|
|
|
|
|
|
|
149 |
audio_data = audio.audio_to_int16(audio_data)
|
150 |
return sample_rate, audio_data
|
151 |
|
|
|
1 |
+
from typing import Union
|
|
|
|
|
|
|
2 |
import numpy as np
|
3 |
|
4 |
+
from modules.Enhancer.ResembleEnhance import load_enhancer
|
5 |
from modules.devices import devices
|
6 |
from modules.synthesize_audio import synthesize_audio
|
7 |
from modules.hf import spaces
|
8 |
from modules.webui import webui_config
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
import torch
|
11 |
|
12 |
+
from modules.ssml_parser.SSMLParser import create_ssml_parser, SSMLBreak, SSMLSegment
|
13 |
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
14 |
|
15 |
+
from modules.speaker import speaker_mgr, Speaker
|
16 |
from modules.data import styles_mgr
|
17 |
|
18 |
from modules.api.utils import calc_spk_style
|
|
|
19 |
|
20 |
from modules.normalization import text_normalize
|
21 |
+
from modules import refiner
|
22 |
|
23 |
+
from modules.utils import audio
|
24 |
from modules.SentenceSplitter import SentenceSplitter
|
25 |
|
26 |
|
|
|
32 |
return styles_mgr.list_items()
|
33 |
|
34 |
|
35 |
+
def load_spk_info(file):
|
36 |
+
if file is None:
|
37 |
+
return "empty"
|
38 |
+
try:
|
39 |
+
|
40 |
+
spk: Speaker = Speaker.from_file(file)
|
41 |
+
infos = spk.to_json()
|
42 |
+
return f"""
|
43 |
+
- name: {infos.name}
|
44 |
+
- gender: {infos.gender}
|
45 |
+
- describe: {infos.describe}
|
46 |
+
""".strip()
|
47 |
+
except:
|
48 |
+
return "load failed"
|
49 |
+
|
50 |
+
|
51 |
+
def segments_length_limit(
|
52 |
+
segments: list[Union[SSMLBreak, SSMLSegment]], total_max: int
|
53 |
+
) -> list[Union[SSMLBreak, SSMLSegment]]:
|
54 |
ret_segments = []
|
55 |
total_len = 0
|
56 |
for seg in segments:
|
57 |
+
if isinstance(seg, SSMLBreak):
|
58 |
+
ret_segments.append(seg)
|
59 |
continue
|
60 |
total_len += len(seg["text"])
|
61 |
if total_len > total_max:
|
|
|
64 |
return ret_segments
|
65 |
|
66 |
|
67 |
+
@torch.inference_mode()
|
68 |
+
@spaces.GPU
|
69 |
+
def apply_audio_enhance(audio_data, sr, enable_denoise, enable_enhance):
|
70 |
+
audio_data = torch.from_numpy(audio_data).float().squeeze().cpu()
|
71 |
+
if enable_denoise or enable_enhance:
|
72 |
+
enhancer = load_enhancer(devices.device)
|
73 |
+
if enable_denoise:
|
74 |
+
audio_data, sr = enhancer.denoise(audio_data, sr, devices.device)
|
75 |
+
if enable_enhance:
|
76 |
+
audio_data, sr = enhancer.enhance(
|
77 |
+
audio_data,
|
78 |
+
sr,
|
79 |
+
devices.device,
|
80 |
+
tau=0.9,
|
81 |
+
nfe=64,
|
82 |
+
solver="euler",
|
83 |
+
lambd=0.5,
|
84 |
+
)
|
85 |
+
audio_data = audio_data.cpu().numpy()
|
86 |
+
return audio_data, int(sr)
|
87 |
+
|
88 |
+
|
89 |
@torch.inference_mode()
|
90 |
@spaces.GPU
|
91 |
def synthesize_ssml(ssml: str, batch_size=4):
|
|
|
99 |
if ssml == "":
|
100 |
return None
|
101 |
|
102 |
+
parser = create_ssml_parser()
|
103 |
+
segments = parser.parse(ssml)
|
104 |
max_len = webui_config.ssml_max
|
105 |
segments = segments_length_limit(segments, max_len)
|
106 |
|
|
|
118 |
@spaces.GPU
|
119 |
def tts_generate(
|
120 |
text,
|
121 |
+
temperature=0.3,
|
122 |
+
top_p=0.7,
|
123 |
+
top_k=20,
|
124 |
+
spk=-1,
|
125 |
+
infer_seed=-1,
|
126 |
+
use_decoder=True,
|
127 |
+
prompt1="",
|
128 |
+
prompt2="",
|
129 |
+
prefix="",
|
130 |
+
style="",
|
131 |
disable_normalize=False,
|
132 |
batch_size=4,
|
133 |
+
enable_enhance=False,
|
134 |
+
enable_denoise=False,
|
135 |
+
spk_file=None,
|
136 |
):
|
137 |
try:
|
138 |
batch_size = int(batch_size)
|
|
|
160 |
prompt1 = prompt1 or params.get("prompt1", "")
|
161 |
prompt2 = prompt2 or params.get("prompt2", "")
|
162 |
|
163 |
+
infer_seed = np.clip(infer_seed, -1, 2**32 - 1, out=None, dtype=np.float64)
|
164 |
infer_seed = int(infer_seed)
|
165 |
|
166 |
if not disable_normalize:
|
167 |
text = text_normalize(text)
|
168 |
|
169 |
+
if spk_file:
|
170 |
+
spk = Speaker.from_file(spk_file)
|
171 |
+
|
172 |
sample_rate, audio_data = synthesize_audio(
|
173 |
text=text,
|
174 |
temperature=temperature,
|
|
|
183 |
batch_size=batch_size,
|
184 |
)
|
185 |
|
186 |
+
audio_data, sample_rate = apply_audio_enhance(
|
187 |
+
audio_data, sample_rate, enable_denoise, enable_enhance
|
188 |
+
)
|
189 |
+
|
190 |
audio_data = audio.audio_to_int16(audio_data)
|
191 |
return sample_rate, audio_data
|
192 |
|
webui.py
CHANGED
@@ -93,8 +93,10 @@ if __name__ == "__main__":
|
|
93 |
device_id = get_and_update_env(args, "device_id", None, str)
|
94 |
use_cpu = get_and_update_env(args, "use_cpu", [], list)
|
95 |
compile = get_and_update_env(args, "compile", False, bool)
|
96 |
-
webui_experimental = get_and_update_env(args, "webui_experimental", False, bool)
|
97 |
|
|
|
|
|
|
|
98 |
webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
|
99 |
webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
|
100 |
webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
|
|
|
93 |
device_id = get_and_update_env(args, "device_id", None, str)
|
94 |
use_cpu = get_and_update_env(args, "use_cpu", [], list)
|
95 |
compile = get_and_update_env(args, "compile", False, bool)
|
|
|
96 |
|
97 |
+
webui_config.experimental = get_and_update_env(
|
98 |
+
args, "webui_experimental", False, bool
|
99 |
+
)
|
100 |
webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
|
101 |
webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
|
102 |
webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)
|