upload test model
- .gitattributes +6 -0
- app.py +138 -0
- audios/audio0.wav +0 -0
- audios/audio1.wav +0 -0
- audios/audio2.wav +0 -0
- audios/audio3.wav +0 -0
- audios/audio4.wav +0 -0
- audios/audio5.wav +0 -0
- audios/audio6.wav +0 -0
- checkpoints/.keep +0 -0
- checkpoints/checkpoint_0.pt +3 -0
- checkpoints/vocoder.pt +3 -0
- config.py +52 -0
- datas/__init__.py +0 -0
- datas/dataset.py +52 -0
- datas/sampler.py +121 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/model.cpython-310.pyc +0 -0
- models/dit.py +205 -0
- models/duration_predictor.py +40 -0
- models/estimator.py +161 -0
- models/flow_matching.py +108 -0
- models/model.py +194 -0
- models/reference_encoder.py +93 -0
- models/text_encoder.py +49 -0
- monotonic_align/__init__.py +16 -0
- monotonic_align/__pycache__/__init__.cpython-310.pyc +0 -0
- monotonic_align/__pycache__/core.cpython-310.pyc +0 -0
- monotonic_align/core.py +46 -0
- requirements.txt +11 -0
- text/LICENSE +19 -0
- text/__init__.py +71 -0
- text/__pycache__/__init__.cpython-310.pyc +0 -0
- text/__pycache__/cleaners.cpython-310.pyc +0 -0
- text/__pycache__/english.cpython-310.pyc +0 -0
- text/cleaners.py +73 -0
- text/cn2an/__init__.py +16 -0
- text/cn2an/__pycache__/__init__.cpython-311.pyc +0 -0
- text/cn2an/__pycache__/__init__.cpython-38.pyc +0 -0
- text/cn2an/__pycache__/an2cn.cpython-311.pyc +0 -0
- text/cn2an/__pycache__/an2cn.cpython-38.pyc +0 -0
- text/cn2an/__pycache__/cn2an.cpython-311.pyc +0 -0
- text/cn2an/__pycache__/cn2an.cpython-38.pyc +0 -0
- text/cn2an/__pycache__/conf.cpython-311.pyc +0 -0
- text/cn2an/__pycache__/conf.cpython-38.pyc +0 -0
- text/cn2an/__pycache__/transform.cpython-311.pyc +0 -0
- text/cn2an/__pycache__/transform.cpython-38.pyc +0 -0
- text/cn2an/an2cn.py +203 -0
- text/cn2an/cn2an.py +293 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,138 @@
import os

from dataclasses import asdict
from text import symbols
import torch
import torchaudio

from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.mandarin import chinese_to_cnm3
from text import cleaned_text_to_sequence
from datas.dataset import intersperse

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

device = 'cuda' if torch.cuda.is_available() else 'cpu'

@torch.inference_mode()
def inference(text: str, ref_audio: str, checkpoint_path: str, step: int = 10):
    global last_checkpoint_path
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path

    phonemizer = chinese_to_cnm3

    # prepare input for tts model
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)

    # inference
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=1, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)

    # process output for gradio
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16))  # (samplerate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy())  # get the plot of mel
    return audio_output, mel_output

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
    # tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

def plot_mel_spectrogram(mel_spectrogram):
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # remove white edges
    return fig


def main():
    tts_model_config = ModelConfig()
    mel_config = MelConfig()
    vocoder_config = VocosConfig()

    tts_checkpoint_path = './checkpoints'  # the folder that contains StableTTS checkpoints
    vocoder_checkpoint_path = './checkpoints/vocoder.pt'

    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)

    # keep only TTS checkpoints (the original condition `'optimizer' and 'vocoder' not in path.name` only filtered 'vocoder')
    tts_checkpoint_path = [path for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
    audios = list(Path('./audios').rglob('*.wav'))

    # gradio webui
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
    with gr.Blocks(analytics_enabled=False) as demo:

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="三国杀是一款风靡全球的以三国演义为背景的策略卡牌桌面游戏,经典新三国国战玩法,百万名将任由你搭配,楚雄争霸,等你决战沙场!",
                )

                ref_audio_gr = gr.Dropdown(
                    label='reference audio',
                    choices=audios,
                    value=audios[0] if audios else None  # default to the first clip; the original `value = 0` is not a valid choice
                )

                checkpoint_gr = gr.Dropdown(
                    label='checkpoint',
                    choices=tts_checkpoint_path,
                    value=tts_checkpoint_path[0] if tts_checkpoint_path else None  # default to the first checkpoint
                )

                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=1
                )

                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

            with gr.Column():
                mel_gr = gr.Plot(label="Mel Visual")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)

        tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])

    demo.queue()
    demo.launch(debug=True, show_api=True)


if __name__ == '__main__':
    main()
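For quick checks outside the web UI, the same pipeline can be driven headlessly. A minimal sketch, assuming it is appended to app.py so the module-level globals that inference() reads line up; the output filename is a placeholder:

# Hypothetical headless driver reusing get_pipeline() and inference() above.
def synthesise_to_wav(text, ref_wav, ckpt, out_wav='output.wav', steps=25):
    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    mel_config, model_config, vocos_config = MelConfig(), ModelConfig(), VocosConfig()
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(
        len(symbols), model_config, mel_config, vocos_config,
        './checkpoints', './checkpoints/vocoder.pt')
    # inference() returns the (sample_rate, int16 numpy) pair used by gr.Audio
    (sr, audio_int16), _ = inference(text, ref_wav, ckpt, steps)
    torchaudio.save(out_wav, torch.from_numpy(audio_int16).unsqueeze(0).float() / 32767, sr)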
audios/audio0.wav
ADDED
Binary file (318 kB)
audios/audio1.wav
ADDED
Binary file (558 kB)
audios/audio2.wav
ADDED
Binary file (300 kB)
audios/audio3.wav
ADDED
Binary file (738 kB)
audios/audio4.wav
ADDED
Binary file (547 kB)
audios/audio5.wav
ADDED
Binary file (667 kB)
audios/audio6.wav
ADDED
Binary file (596 kB)
checkpoints/.keep
ADDED
File without changes
checkpoints/checkpoint_0.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f9c5d59001f1c9e308c0fa00291779722e88d5a8162734afac547c95126d4580
size 37552792
checkpoints/vocoder.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e180a1df6ca0a9e382e0915b0b1984aecfb63397c4ee21f12857447c2d76d29a
size 56666508
config.py
ADDED
@@ -0,0 +1,52 @@
from dataclasses import dataclass

@dataclass
class MelConfig:
    sample_rate: int = 44100
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 512
    f_min: float = 0.0
    f_max: float = None
    pad: int = 0
    n_mels: int = 128
    power: float = 1.0
    normalized: bool = False
    center: bool = False
    pad_mode: str = "reflect"
    mel_scale: str = "htk"

    def __post_init__(self):
        if self.pad == 0:
            self.pad = (self.n_fft - self.hop_length) // 2

@dataclass
class ModelConfig:
    hidden_channels: int = 192
    filter_channels: int = 512
    n_heads: int = 2
    n_enc_layers: int = 3
    n_dec_layers: int = 2
    kernel_size: int = 3
    p_dropout: float = 0.1  # was annotated as int; the default is a float
    gin_channels: int = 192

@dataclass
class TrainConfig:
    train_dataset_path: str = 'filelists/filelist.json'
    test_dataset_path: str = 'filelists/filelist.json'
    batch_size: int = 52
    learning_rate: float = 1e-4
    num_epochs: int = 10000
    model_save_path: str = './checkpoints'
    log_dir: str = './runs'
    log_interval: int = 128
    save_interval: int = 15
    warmup_steps: int = 200

@dataclass
class VocosConfig:
    input_channels: int = 128
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8
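These dataclasses are consumed by unpacking, as get_pipeline does in app.py; a small illustration of the derived pad value and the keyword mapping:

from dataclasses import asdict

mel_config = MelConfig()
# __post_init__ derives pad from n_fft and hop_length when pad is left at 0:
assert mel_config.pad == (2048 - 512) // 2  # 768

# ModelConfig fields map one-to-one onto StableTTS keyword arguments:
# StableTTS(n_vocab, mel_config.n_mels, **asdict(ModelConfig()))
print(asdict(ModelConfig())['hidden_channels'])  # 192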
datas/__init__.py
ADDED
File without changes
datas/dataset.py
ADDED
@@ -0,0 +1,52 @@
import os

import json
import torch
from torch.utils.data import Dataset

from text import cleaned_text_to_sequence

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

class StableDataset(Dataset):
    def __init__(self, filelist_path, hop_length):
        self.filelist_path = filelist_path
        self.hop_length = hop_length

        self._load_filelist(filelist_path)

    def _load_filelist(self, filelist_path):
        filelist, lengths = [], []
        with open(filelist_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = json.loads(line.strip())
                filelist.append((line['mel_path'], line['phone']))
                lengths.append(os.path.getsize(line['audio_path']) // (2 * self.hop_length))

        self.filelist = filelist
        self.lengths = lengths

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        mel_path, phone = self.filelist[idx]
        mel = torch.load(mel_path, map_location='cpu')
        phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
        return mel, phone

def collate_fn(batch):
    texts = [item[1] for item in batch]
    mels = [item[0] for item in batch]

    text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
    mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)

    # pad to the same length
    texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
    mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)

    return texts_padded, text_lengths, mels_padded, mel_lengths
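intersperse implements the VITS-style blank-token trick: a blank id is placed before, between, and after every phoneme id, giving the aligner room for inter-phoneme transitions. A quick check:

assert intersperse([1, 2, 3], 0) == [0, 1, 0, 2, 0, 3, 0]
# n tokens become 2n + 1 entries, which is why app.py builds x with item=0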
datas/sampler.py
ADDED
@@ -0,0 +1,121 @@
import torch

# reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        for i in range(len(buckets) - 1, 0, -1):
            # for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
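A sketch of how the sampler is typically wired up; because __iter__ yields whole batches of indices, it goes into DataLoader's batch_sampler slot. The boundaries below (mel-frame lengths) are illustrative values, not taken from this repo:

from torch.utils.data import DataLoader
from datas.dataset import StableDataset, collate_fn

dataset = StableDataset('filelists/filelist.json', hop_length=512)
sampler = DistributedBucketSampler(
    dataset, batch_size=52,
    boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],  # illustrative
    num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
sampler.set_epoch(0)  # reseed the deterministic shuffle each epoch
texts, text_lengths, mels, mel_lengths = next(iter(loader))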
models/__init__.py
ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (135 Bytes)
models/__pycache__/model.cpython-310.pyc
ADDED
Binary file (6.49 kB)
models/dit.py
ADDED
@@ -0,0 +1,205 @@
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = self.act1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)

        x = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        query = self.query_rotary_pe(query)  # [b, n_head, t, c // n_head]
        key = self.key_rotary_pe(key)

        output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output

# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel]
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)  # shape: [batch_size, 1, time, time]
        # attn_mask = attn_mask.to(torch.bool)

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)  # shape: [batch_size, channel, 1]
        x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
        x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)

        # no condition version
        # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
        # x = x + self.mlp(self.norm1(x.transpose(1,2)).transpose(1,2), x_mask)
        return x

    @staticmethod
    def modulate(x, shift, scale):
        return x * (1 + scale) + shift

class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache $\cos$ and $\sin$ values
        x = x.permute(2, 0, 1, 3)  # b h t d -> t b h d

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)  # t b h d -> b h t d

class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
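For reference, the rotation that RotaryPositionalEmbeddings applies to a feature pair at position m is, in the notation of the docstring:

\[
\theta_i = 10000^{-2(i-1)/d}, \qquad
\begin{pmatrix} x'_{2i-1} \\ x'_{2i} \end{pmatrix}
= \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i-1} \\ x_{2i} \end{pmatrix}
\]

The code realizes this with the equivalent rotate-half layout (pairing feature i with feature i + d/2 via _neg_half) and applies it only to the first self.d features of each head; here d is set to half of k_channels.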
models/duration_predictor.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm_1 = nn.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm_2 = nn.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g):
        x = x.detach()
        x = x + self.cond(g.unsqueeze(2).detach())
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
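duration_loss is a mean-squared error in log-duration space, normalized by the total number of text tokens:

\[
\mathcal{L}_{dur} = \frac{\sum_{b,t} \bigl( \log w_{b,t} - \log \hat{w}_{b,t} \bigr)^2}{\sum_b \mathrm{lengths}_b}
\]

where the targets come from the MAS durations computed in models/model.py. Note that forward detaches both its input and the style vector g, so this loss trains only the duration head.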
models/estimator.py
ADDED
@@ -0,0 +1,161 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.dit import DiTConVBlock

class DitWrapper(nn.Module):
    """ add FiLM layer to condition time embedding to DiT """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
        super().__init__()
        self.time_fusion = FiLMLayer(hidden_channels, time_channels)
        self.conv1 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.conv2 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.conv3 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.block = DiTConVBlock(hidden_channels, hidden_channels, num_heads, kernel_size, p_dropout, gin_channels)

    def forward(self, x, c, t, x_mask):
        x = self.time_fusion(x, t) * x_mask
        x = self.conv1(x, c, x_mask)
        x = self.conv2(x, c, x_mask)
        x = self.conv3(x, c, x_mask)
        x = self.block(x, c, x_mask)
        return x

class FiLMLayer(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) layer
    Reference: https://arxiv.org/abs/1709.07871
    """
    def __init__(self, in_channels, cond_channels):
        super(FiLMLayer, self).__init__()
        self.in_channels = in_channels
        self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
        return gamma * x + beta

class ConvNeXtBlock(nn.Module):
    def __init__(self, in_channels, filter_channels, gin_channels):
        super().__init__()
        self.dwconv = nn.Conv1d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
        self.norm = StyleAdaptiveLayerNorm(in_channels, gin_channels)
        self.pwconv = nn.Sequential(nn.Linear(in_channels, filter_channels),
                                    nn.GELU(),
                                    nn.Linear(filter_channels, in_channels))

    def forward(self, x, c, x_mask) -> torch.Tensor:
        residual = x
        x = self.dwconv(x) * x_mask
        x = self.norm(x.transpose(1, 2), c)
        x = self.pwconv(x).transpose(1, 2)
        x = residual + x
        return x * x_mask

class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, in_channels, cond_channels):
        """
        Style Adaptive Layer Normalization (SALN) module.

        Parameters:
        in_channels: The number of channels in the input feature maps.
        cond_channels: The number of channels in the conditioning input.
        """
        super(StyleAdaptiveLayerNorm, self).__init__()
        self.in_channels = in_channels

        self.saln = nn.Linear(cond_channels, in_channels * 2)  # original passed a stray positional `1` (bias flag); bias defaults to True
        self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)

        self.reset_parameters()

    def reset_parameters(self):
        # initialize the bias so gamma starts at 1 and beta at 0, i.e. a plain LayerNorm
        nn.init.constant_(self.saln.bias.data[:self.in_channels], 1)
        nn.init.constant_(self.saln.bias.data[self.in_channels:], 0)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1)
        return gamma * self.norm(x) + beta


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_channels, filter_channels),
            nn.SiLU(inplace=True),
            nn.Linear(filter_channels, out_channels)
        )

    def forward(self, x):
        return self.layer(x)

# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
class Decoder(nn.Module):
    def __init__(self, hidden_channels, out_channels, filter_channels, dropout=0.05, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels

        self.time_embeddings = SinusoidalPosEmb(hidden_channels)
        self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)

        self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
        self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.blocks:
            nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, mask, mu, t, c):
        """Forward pass of the flow-matching decoder.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, shape (batch_size, in_channels, time)
            t (torch.Tensor): shape (batch_size)
            c (torch.Tensor): shape (batch_size, gin_channels)

        Returns:
            torch.Tensor: estimated vector field, shape (batch_size, out_channels, time)
        """
        t = self.time_mlp(self.time_embeddings(t))
        x = torch.cat((x, mu), dim=1)

        for block in self.blocks:
            x = block(x, c, t, mask)

        output = self.final_proj(x * mask)

        return output * mask
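FiLMLayer and StyleAdaptiveLayerNorm both follow the feature-wise affine pattern from the FiLM paper: the conditioning vector is projected to a per-channel scale and shift,

\[
(\gamma, \beta) = W c + b, \qquad
\mathrm{FiLM}(x, c) = \gamma \odot x + \beta, \qquad
\mathrm{SALN}(x, c) = \gamma \odot \mathrm{LayerNorm}(x) + \beta
\]

with SALN's bias initialized so that gamma = 1 and beta = 0, meaning it starts out as a plain LayerNorm until the style condition learns to modulate it.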
models/flow_matching.py
ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.estimator import Decoder

# copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
class CFMDecoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.gin_channels = gin_channels
        self.sigma_min = 1e-4

        self.estimator = Decoder(hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            c (torch.Tensor, optional): shape: (batch_size, gin_channels)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, c=c)

    def solve_euler(self, x, t_span, mu, mask, c):
        """
        Fixed-step Euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.
                shape: (batch_size, gin_channels)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, c)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, c):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y
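compute_loss is the optimal-transport conditional flow matching objective used in Matcha-TTS: with noise z ~ N(0, I), target mel x1, and t ~ U(0, 1),

\[
\begin{aligned}
y_t &= \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\, z + t\, x_1, \\
u_t &= x_1 - (1 - \sigma_{\min})\, z, \\
\mathcal{L}_{\mathrm{CFM}} &= \mathbb{E}_{t, z, x_1}\,\bigl\| v_\theta(y_t, t \mid \mu, c) - u_t \bigr\|^2,
\end{aligned}
\]

and solve_euler integrates dx/dt = v_theta(x, t) from t = 0 to t = 1 in n_timesteps fixed Euler steps, which is why the Gradio "Step" slider trades synthesis quality against speed.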
models/model.py
ADDED
@@ -0,0 +1,194 @@
import math
import torch
import torch.nn as nn

import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

def convert_pad_shape(pad_shape):
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape

def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
    def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()

        self.n_vocab = n_vocab
        self.mel_channels = mel_channels

        self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
        self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
        self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
        self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns:
        1. encoder outputs
        2. decoder outputs
        3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            y (torch.Tensor): mel spectrogram of reference audio
                shape: (batch_size, mel_channels, time)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
                # Alignment map between text and mel spectrogram
            }
        """

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        c = self.ref_encoder(y, None)
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            "attn": attn[:, :, :y_max_length],
        }

    def forward(self, x, x_lengths, y, y_lengths):
        """
        Computes 3 losses:
        1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
        2. prior loss: loss between mel-spectrogram and encoder outputs.
        3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)
        """
        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
        c = self.ref_encoder(y, y_mask)

        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram

        # I'm not sure why the MAS code in Matcha-TTS and Grad-TTS could not align in StableTTS,
        # so I use the code from https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py and it works.
        # Anyone who can figure this out is welcome to contribute.

        with torch.no_grad():
            # const = -0.5 * math.log(2 * math.pi) * self.n_feats
            # const = -0.5 * math.log(2 * math.pi) * self.mel_channels
            # factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            # y_square = torch.matmul(factor.transpose(1, 2), y**2)
            # y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            # mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            # log_prior = y_square - y_mu_double + mu_square + const

            s_p_sq_r = torch.ones_like(mu_x)  # [b, d, t]
            # s_p_sq_r = torch.exp(-2 * logx)
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
            )
            # neg_cent1 = torch.sum(
            #     -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True
            # )  # [b, 1, t_s]
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(
                -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True
            )
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)

            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
            )

            # attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
            # attn = attn.detach()

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        # referred to as prior loss in the paper
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        # logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Align encoded text with mel-spectrogram and get mu_y segment
        attn = attn.squeeze(1).transpose(1,2)
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)
        # diff_loss = torch.tensor([0], device=mu_y.device)

        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)

        return dur_loss, diff_loss, prior_loss, attn
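The prior loss in forward is the negative log-likelihood of the masked mel frames under a unit-variance Gaussian centered on the aligned encoder output, the same score MAS maximizes when it searches the alignment:

\[
\mathcal{L}_{prior} = \frac{1}{D \sum_{b,t} m_{b,t}} \sum_{b,d,t} \frac{1}{2} \Bigl( (y_{b,d,t} - \mu_{y,b,d,t})^2 + \log 2\pi \Bigr)\, m_{b,t}
\]

with D = mel_channels and m the mel mask; setting s_p_sq_r to ones in the MAS block corresponds to fixing that variance to 1.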
models/reference_encoder.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.nn as nn

class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU (Gated Linear Unit) with residual connection.
    For GLU, see https://arxiv.org/abs/1612.08083.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x

# modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
class MelStyleEncoder(nn.Module):
    """MelStyleEncoder"""

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = nn.MultiheadAttention(
            self.hidden_dim,
            self.n_head,
            self.dropout,
            batch_first=True
        )

        self.fc = nn.Linear(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            len_ = (~mask).sum(dim=1).unsqueeze(1).type_as(x)
            return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / len_

    def forward(self, x, x_mask=None):
        x = x.transpose(1, 2)

        # spectral
        x = self.spectral(x)
        # temporal
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention
        if x_mask is not None:
            x_mask = ~x_mask.squeeze(1).to(torch.bool)
        x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=x_mask)

        return w
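A minimal usage sketch (an assumption, not part of the diff): the encoder takes a mel spectrogram [batch, n_mel_channels, frames] plus an optional [batch, 1, frames] validity mask (1 = real frame; inverted internally into a key-padding mask) and returns one style vector per utterance.

import torch
from models.reference_encoder import MelStyleEncoder  # run from the repo root

enc = MelStyleEncoder(n_mel_channels=80, style_vector_dim=256)
mel = torch.randn(4, 80, 120)   # [batch, n_mel_channels, frames]
mask = torch.ones(4, 1, 120)    # all frames valid
style = enc(mel, mask)
print(style.shape)              # torch.Size([4, 256])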
models/text_encoder.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn

from models.dit import DiTConVBlock

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
class TextEncoder(nn.Module):
    def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.scale = self.hidden_channels ** 0.5

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.encoder:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
        x = self.emb(x) * self.scale  # [b, t, h]
        x = x.transpose(1, -1)  # [b, h, t]
        x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

        for layer in self.encoder:
            x = layer(x, c, x_mask)
        mu_x = self.proj(x) * x_mask

        return x, mu_x, x_mask
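sequence_mask is the usual lengths-to-boolean-mask helper; a quick self-contained illustration (not from the diff, same logic as above):

import torch

def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

print(sequence_mask(torch.tensor([3, 5, 2]), 5).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1],
#         [1, 1, 0, 0, 0]], dtype=torch.int32)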
monotonic_align/__init__.py
ADDED
@@ -0,0 +1,16 @@
from numpy import zeros, int32, float32
from torch import from_numpy

from .core import maximum_path_jit


def maximum_path(neg_cent, mask):
    device = neg_cent.device
    dtype = neg_cent.dtype
    neg_cent = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(neg_cent.shape, dtype=int32)

    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
    maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
    return from_numpy(path).to(device=device, dtype=dtype)
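A hedged usage sketch: following the VITS convention this wrapper appears to expect neg_cent of shape [batch, mel_len, text_len] (per-pair log-likelihoods) and a joint mask of the same shape, returning a hard 0/1 monotonic alignment.

import torch
from monotonic_align import maximum_path  # assumes the repo root is on sys.path

neg_cent = torch.randn(1, 5, 3)   # [batch, mel frames, text tokens]
mask = torch.ones(1, 5, 3)
path = maximum_path(neg_cent, mask)
print(path.sum(dim=2))            # each mel frame is assigned exactly one token
# tensor([[1., 1., 1., 1., 1.]])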
monotonic_align/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (732 Bytes).
monotonic_align/__pycache__/core.cpython-310.pyc
ADDED
Binary file (985 Bytes).
monotonic_align/core.py
ADDED
@@ -0,0 +1,46 @@
import numba


@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    b = paths.shape[0]
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        index = t_x - 1

        # forward pass: accumulate the best monotonic score into value[y, x]
        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        # backtracking: recover the 0/1 alignment path from the accumulated scores
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
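The DP can also be exercised directly on NumPy buffers; a worked toy example (values are log-probabilities, the alignment is monotonic and must end at the last token; requires numba):

import numpy as np
from monotonic_align.core import maximum_path_jit

values = np.log(np.array([[[0.9, 0.1],
                           [0.6, 0.4],
                           [0.2, 0.8]]], dtype=np.float32))  # [1, 3 frames, 2 tokens]
paths = np.zeros(values.shape, dtype=np.int32)
maximum_path_jit(paths, values,
                 np.array([3], dtype=np.int32),   # frames per batch item
                 np.array([2], dtype=np.int32))   # tokens per batch item
print(paths[0])
# [[1 0]
#  [1 0]
#  [0 1]]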
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch
torchaudio
matplotlib
numpy
tensorboard
pypinyin
jieba
eng_to_ipa
unidecode
inflect
pyopenjtalk-prebuilt
text/LICENSE
ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
text/__init__.py
ADDED
@@ -0,0 +1,71 @@
""" from https://github.com/keithito/tacotron """
from text import cleaners
from text.symbols import symbols


# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


def text_to_sequence(text, symbols, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
      symbols: list of symbols defining the ID mapping
      cleaner_names: names of the cleaner functions to run the text through
    Returns:
      List of integers corresponding to the symbols in the text
    '''
    sequence = []
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    clean_text = _clean_text(text, cleaner_names)
    print(clean_text)
    print(f" length:{len(clean_text)}")
    for symbol in clean_text:
        if symbol not in symbol_to_id.keys():
            continue
        symbol_id = symbol_to_id[symbol]
        sequence += [symbol_id]
    print(f" length:{len(sequence)}")
    return sequence


def cleaned_text_to_sequence(cleaned_text):
    '''Converts a string of already-cleaned text to a sequence of symbol IDs, one per character.
    Args:
      cleaned_text: string to convert to a sequence
    Returns:
      List of integers corresponding to the symbols in the text
    '''
    # symbol_to_id = {s: i for i, s in enumerate(symbols)}
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
    return sequence

def cleaned_text_to_sequence_chinese(cleaned_text):
    '''Converts a string of space-separated cleaned symbols to a sequence of symbol IDs.
    Args:
      cleaned_text: string to convert to a sequence
    Returns:
      List of integers corresponding to the symbols in the text
    '''
    # symbol_to_id = {s: i for i, s in enumerate(symbols)}
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()]
    return sequence


def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string'''
    result = ''
    for symbol_id in sequence:
        s = _id_to_symbol[symbol_id]
        result += s
    return result


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
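A toy illustration (hypothetical symbol list, just for shape) of how the two lookup paths differ: cleaned_text_to_sequence maps character by character, while the Chinese variant splits on spaces first so multi-character phoneme symbols survive.

symbols = ['_', 'n', 'i', 'ni2', 'hao3']  # hypothetical symbol set
symbol_to_id = {s: i for i, s in enumerate(symbols)}

per_char = [symbol_to_id[s] for s in "ni" if s in symbol_to_id]
print(per_char)   # [1, 2]

per_token = [symbol_to_id[s] for s in "ni2 hao3".split(' ') if s in symbol_to_id]
print(per_token)  # [3, 4]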
text/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (2.62 kB).
text/__pycache__/cleaners.cpython-310.pyc
ADDED
Binary file (2.54 kB).
text/__pycache__/english.cpython-310.pyc
ADDED
Binary file (4.4 kB).
text/cleaners.py
ADDED
@@ -0,0 +1,73 @@
import re
import string
import numpy as np
from .langdetect import detect, LangDetectException

from text.english import english_to_ipa2
from text.mandarin import chinese_to_cnm3
from text.japanese import japanese_to_ipa2

language_module_map = {"PAD": 0, "ZH": 1, "EN": 2, "JA": 3}

# Precompiled regular expressions
ZH_PATTERN = re.compile(r'[\u3400-\u4DBF\u4e00-\u9FFF\uF900-\uFAFF\u3000-\u303F]')
EN_PATTERN = re.compile(r'[a-zA-Z.,!?\'"(){}[\]<>:;@#$%^&*-_+=/\\|~`]+')
JP_PATTERN = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF\u31F0-\u31FF\uFF00-\uFFEF\u3000-\u303F]')
CLEANER_PATTERN = re.compile(r'\[(ZH|EN|JA)\]')

def detect_language(text: str, prev_lang=None):
    """
    Detect the language of the given text.

    :param text: input text
    :param prev_lang: previously detected language
    :return: 'ZH' for Chinese, 'EN' for English, 'JA' for Japanese, or prev_lang for spaces
    """
    if ZH_PATTERN.search(text): return 'ZH'
    if EN_PATTERN.search(text): return 'EN'
    if JP_PATTERN.search(text): return 'JA'
    if text.isspace(): return prev_lang  # for whitespace, fall back to the previous language
    return None

def replace_substring(s, start_index, end_index, replacement):
    return s[:start_index] + replacement + s[end_index:]

def replace_sublist(lst, start_index, end_index, replacement_list):
    lst[start_index:end_index] = replacement_list

# convert text to ipa and prepare for language embedding
def append_tags_and_convert(match, conversion_func, tag_value, tags):
    converted_text = conversion_func(match.group(1))
    tags.extend([tag_value] * len(converted_text))
    return converted_text + ' '

# auto detect language using re
def cjke_cleaners4(text: str):
    """
    Detect the language of each span automatically and convert it to IPA.

    :param text: input text
    :return: text converted to IPA
    """
    text = CLEANER_PATTERN.sub('', text)
    pointer = 0
    output = ''
    current_language = detect_language(text[pointer])

    while pointer < len(text):
        temp_text = ''
        while pointer < len(text) and detect_language(text[pointer], current_language) == current_language:
            temp_text += text[pointer]
            pointer += 1
        if current_language == 'ZH':
            output += chinese_to_cnm3(temp_text)
        elif current_language == 'JA':
            output += japanese_to_ipa2(temp_text)
        elif current_language == 'EN':
            output += english_to_ipa2(temp_text)
        if pointer < len(text):
            current_language = detect_language(text[pointer])

    output = re.sub(r'\s+$', '', output)
    output = re.sub(r'([^\.,!\?\-…~])$', r'\1.', output)
    return output
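A hedged usage sketch (assumes the repo's text package and its language front ends, e.g. pypinyin, jieba, and eng_to_ipa, are installed): detect_language classifies one character at a time, and spaces inherit the previous segment's language so a run is not broken up.

from text.cleaners import detect_language, cjke_cleaners4

print(detect_language("你"))                 # 'ZH'
print(detect_language("a"))                  # 'EN'
print(detect_language(" ", prev_lang="EN"))  # 'EN'
print(cjke_cleaners4("你好 hello"))          # per-segment IPA, final punctuation appended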
text/cn2an/__init__.py
ADDED
@@ -0,0 +1,16 @@
__version__ = "0.5.20"

from .cn2an import Cn2An
from .an2cn import An2Cn
from .transform import Transform

cn2an = Cn2An().cn2an
an2cn = An2Cn().an2cn
transform = Transform().transform

__all__ = [
    "__version__",
    "cn2an",
    "an2cn",
    "transform"
]
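The package mirrors the public cn2an API; a short usage sketch (the import path assumes running from the repo root, and the expected outputs follow the upstream library's behavior):

from text.cn2an import cn2an, an2cn, transform

print(cn2an("一百二十三"))                # 123
print(an2cn(123))                         # 一百二十三
print(transform("我有100块钱", "an2cn"))  # 我有一百块钱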
text/cn2an/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (525 Bytes).
text/cn2an/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (354 Bytes).
text/cn2an/__pycache__/an2cn.cpython-311.pyc
ADDED
Binary file (8.64 kB).
text/cn2an/__pycache__/an2cn.cpython-38.pyc
ADDED
Binary file (4.73 kB).
text/cn2an/__pycache__/cn2an.cpython-311.pyc
ADDED
Binary file (14.8 kB).
text/cn2an/__pycache__/cn2an.cpython-38.pyc
ADDED
Binary file (7.4 kB).
text/cn2an/__pycache__/conf.cpython-311.pyc
ADDED
Binary file (2.02 kB).
text/cn2an/__pycache__/conf.cpython-38.pyc
ADDED
Binary file (1.61 kB).
text/cn2an/__pycache__/transform.cpython-311.pyc
ADDED
Binary file (9.99 kB).
text/cn2an/__pycache__/transform.cpython-38.pyc
ADDED
Binary file (4.82 kB).
text/cn2an/an2cn.py
ADDED
@@ -0,0 +1,203 @@
from typing import Union

#from proces import preprocess

from .conf import NUMBER_LOW_AN2CN, NUMBER_UP_AN2CN, UNIT_LOW_ORDER_AN2CN, UNIT_UP_ORDER_AN2CN


class An2Cn(object):
    def __init__(self) -> None:
        self.all_num = "0123456789"
        self.number_low = NUMBER_LOW_AN2CN
        self.number_up = NUMBER_UP_AN2CN
        self.mode_list = ["low", "up", "rmb", "direct"]

    def an2cn(self, inputs: Union[str, int, float] = None, mode: str = "low") -> str:
        """Convert Arabic numerals to Chinese numerals.

        :param inputs: Arabic number
        :param mode: low = lowercase numerals, up = uppercase (banker's) numerals,
                     rmb = RMB currency reading, direct = digit-by-digit reading
        :return: Chinese numeral string
        """
        if inputs is not None and inputs != "":
            if mode not in self.mode_list:
                raise ValueError(f"mode only supports {str(self.mode_list)}!")

            # Convert the number to a string; Python normalizes it automatically:
            # 1. -> 1.0    1.00 -> 1.0    -0 -> 0
            if not isinstance(inputs, str):
                inputs = self.__number_to_string(inputs)

            # Preprocessing:
            # 1. traditional -> simplified Chinese
            # 2. full-width -> half-width characters
            # inputs = preprocess(inputs, pipelines=[
            #     "traditional_to_simplified",
            #     "full_angle_to_half_angle"
            # ])

            # Check that the input is valid
            self.__check_inputs_is_valid(inputs)

            # Determine the sign
            if inputs[0] == "-":
                sign = "负"
                inputs = inputs[1:]
            else:
                sign = ""

            if mode == "direct":
                output = self.__direct_convert(inputs)
            else:
                # Split into integer and decimal parts
                split_result = inputs.split(".")
                len_split_result = len(split_result)
                if len_split_result == 1:
                    # input without a decimal part
                    integer_data = split_result[0]
                    if mode == "rmb":
                        output = self.__integer_convert(integer_data, "up") + "元整"
                    else:
                        output = self.__integer_convert(integer_data, mode)
                elif len_split_result == 2:
                    # input with a decimal part
                    integer_data, decimal_data = split_result
                    if mode == "rmb":
                        int_data = self.__integer_convert(integer_data, "up")
                        dec_data = self.__decimal_convert(decimal_data, "up")
                        len_dec_data = len(dec_data)

                        if len_dec_data == 0:
                            output = int_data + "元整"
                        elif len_dec_data == 1:
                            raise ValueError(f"unexpected output: {dec_data}")
                        elif len_dec_data == 2:
                            if dec_data[1] != "零":
                                if int_data == "零":
                                    output = dec_data[1] + "角"
                                else:
                                    output = int_data + "元" + dec_data[1] + "角"
                            else:
                                output = int_data + "元整"
                        else:
                            if dec_data[1] != "零":
                                if dec_data[2] != "零":
                                    if int_data == "零":
                                        output = dec_data[1] + "角" + dec_data[2] + "分"
                                    else:
                                        output = int_data + "元" + dec_data[1] + "角" + dec_data[2] + "分"
                                else:
                                    if int_data == "零":
                                        output = dec_data[1] + "角"
                                    else:
                                        output = int_data + "元" + dec_data[1] + "角"
                            else:
                                if dec_data[2] != "零":
                                    if int_data == "零":
                                        output = dec_data[2] + "分"
                                    else:
                                        output = int_data + "元" + "零" + dec_data[2] + "分"
                                else:
                                    output = int_data + "元整"
                    else:
                        output = self.__integer_convert(integer_data, mode) + self.__decimal_convert(decimal_data, mode)
                else:
                    raise ValueError(f"invalid input format: {inputs}!")
        else:
            raise ValueError("empty input!")

        return sign + output

    def __direct_convert(self, inputs: str) -> str:
        _output = ""
        for d in inputs:
            if d == ".":
                _output += "点"
            else:
                _output += self.number_low[int(d)]
        return _output

    @staticmethod
    def __number_to_string(number_data: Union[int, float]) -> str:
        # Handle small floats: Python renders 0.00005 as 5e-05, so str(0.00005) != "0.00005"
        string_data = str(number_data)
        if "e" in string_data:
            string_data_list = string_data.split("e")
            string_key = string_data_list[0]
            string_value = string_data_list[1]
            if string_value[0] == "-":
                string_data = "0." + "0" * (int(string_value[1:]) - 1) + string_key
            else:
                string_data = string_key + "0" * int(string_value)
        return string_data

    def __check_inputs_is_valid(self, check_data: str) -> None:
        # Check that every character is in the allowed set
        all_check_keys = self.all_num + ".-"
        for data in check_data:
            if data not in all_check_keys:
                raise ValueError(f"input data is out of the convertible range: {data}!")

    def __integer_convert(self, integer_data: str, mode: str) -> str:
        if mode == "low":
            numeral_list = NUMBER_LOW_AN2CN
            unit_list = UNIT_LOW_ORDER_AN2CN
        elif mode == "up":
            numeral_list = NUMBER_UP_AN2CN
            unit_list = UNIT_UP_ORDER_AN2CN
        else:
            raise ValueError(f"error mode: {mode}")

        # Strip leading zeros, e.g. 007 => 7
        integer_data = str(int(integer_data))

        len_integer_data = len(integer_data)
        if len_integer_data > len(unit_list):
            raise ValueError(f"out of range; at most {len(unit_list)} digits are supported")

        output_an = ""
        for i, d in enumerate(integer_data):
            if int(d):
                output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
            else:
                if not (len_integer_data - i - 1) % 4:
                    output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]

                if i > 0 and not output_an[-1] == "零":
                    output_an += numeral_list[int(d)]

        output_an = output_an.replace("零零", "零").replace("零万", "万").replace("零亿", "亿").replace("亿万", "亿") \
            .strip("零")

        # Drop the leading "一" in "一十X" readings
        if output_an[:2] in ["一十"]:
            output_an = output_an[1:]

        # decimals between 0 and 1
        if not output_an:
            output_an = "零"

        return output_an

    def __decimal_convert(self, decimal_data: str, o_mode: str) -> str:
        len_decimal_data = len(decimal_data)

        if len_decimal_data > 16:
            print(f"warning: the decimal part has {len_decimal_data} digits; only the first 16 are kept!")
            decimal_data = decimal_data[:16]

        if len_decimal_data:
            output_an = "点"
        else:
            output_an = ""

        if o_mode == "low":
            numeral_list = NUMBER_LOW_AN2CN
        elif o_mode == "up":
            numeral_list = NUMBER_UP_AN2CN
        else:
            raise ValueError(f"error mode: {o_mode}")

        for data in decimal_data:
            output_an += numeral_list[int(data)]
        return output_an
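A short sketch of the four modes handled above (expected outputs follow this file's branching; treat them as the upstream cn2an conventions):

from text.cn2an.an2cn import An2Cn

ac = An2Cn()
print(ac.an2cn(12.34))             # 十二点三四   (default mode="low")
print(ac.an2cn(12.34, "up"))       # 壹拾贰点叁肆
print(ac.an2cn(12.34, "rmb"))      # 壹拾贰元叁角肆分
print(ac.an2cn(12.34, "direct"))   # 一二点三四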
text/cn2an/cn2an.py
ADDED
@@ -0,0 +1,293 @@
import re
from typing import Union

#from proces import preprocess

from .an2cn import An2Cn
from .conf import NUMBER_CN2AN, UNIT_CN2AN, STRICT_CN_NUMBER, NORMAL_CN_NUMBER, NUMBER_LOW_AN2CN, UNIT_LOW_AN2CN


class Cn2An(object):
    def __init__(self) -> None:
        self.all_num = "".join(list(NUMBER_CN2AN.keys()))
        self.all_unit = "".join(list(UNIT_CN2AN.keys()))
        self.strict_cn_number = STRICT_CN_NUMBER
        self.normal_cn_number = NORMAL_CN_NUMBER
        self.check_key_dict = {
            "strict": "".join(self.strict_cn_number.values()) + "点负",
            "normal": "".join(self.normal_cn_number.values()) + "点负",
            "smart": "".join(self.normal_cn_number.values()) + "点负" + "01234567890.-"
        }
        self.pattern_dict = self.__get_pattern()
        self.ac = An2Cn()
        self.mode_list = ["strict", "normal", "smart"]
        self.yjf_pattern = re.compile(fr"^.*?[元圆][{self.all_num}]角([{self.all_num}]分)?$")
        self.pattern1 = re.compile(fr"^-?\d+(\.\d+)?[{self.all_unit}]?$")
        self.ptn_all_num = re.compile(f"^[{self.all_num}]+$")
        # "十?" is for special case "十一万三"
        self.ptn_speaking_mode = re.compile(f"^([{self.all_num}]{{0,2}}[{self.all_unit}])+[{self.all_num}]$")

    def cn2an(self, inputs: Union[str, int, float] = None, mode: str = "strict") -> Union[float, int]:
        """Convert Chinese numerals to Arabic numerals.

        :param inputs: Chinese numerals, Arabic numerals, or a mix of the two
        :param mode: strict / normal / smart
        :return: Arabic number
        """
        if inputs is not None and inputs != "":
            if mode not in self.mode_list:
                raise ValueError(f"mode only supports {str(self.mode_list)}!")

            # Convert the input to a string
            if not isinstance(inputs, str):
                inputs = str(inputs)

            # Preprocessing:
            # 1. traditional -> simplified Chinese
            # 2. full-width -> half-width characters
            # inputs = preprocess(inputs, pipelines=[
            #     "traditional_to_simplified",
            #     "full_angle_to_half_angle"
            # ])

            # Special conversion: 廿 (twenty)
            inputs = inputs.replace("廿", "二十")

            # Check that the input is valid
            sign, integer_data, decimal_data, is_all_num = self.__check_input_data_is_valid(inputs, mode)

            # special case under smart mode
            if sign == 0:
                return integer_data
            else:
                if not is_all_num:
                    if decimal_data is None:
                        output = self.__integer_convert(integer_data)
                    else:
                        output = self.__integer_convert(integer_data) + self.__decimal_convert(decimal_data)
                        # fix 1 + 0.57 = 1.5699999999999998
                        output = round(output, len(decimal_data))
                else:
                    if decimal_data is None:
                        output = self.__direct_convert(integer_data)
                    else:
                        output = self.__direct_convert(integer_data) + self.__decimal_convert(decimal_data)
                        # fix 1 + 0.57 = 1.5699999999999998
                        output = round(output, len(decimal_data))
        else:
            raise ValueError("empty input!")

        return sign * output

    def __get_pattern(self) -> dict:
        # strict validation of integers
        _0 = "[零]"
        _1_9 = "[一二三四五六七八九]"
        _10_99 = f"{_1_9}?[十]{_1_9}?"
        _1_99 = f"({_10_99}|{_1_9})"
        _100_999 = f"({_1_9}[百]([零]{_1_9})?|{_1_9}[百]{_10_99})"
        _1_999 = f"({_100_999}|{_1_99})"
        _1000_9999 = f"({_1_9}[千]([零]{_1_99})?|{_1_9}[千]{_100_999})"
        _1_9999 = f"({_1000_9999}|{_1_999})"
        _10000_99999999 = f"({_1_9999}[万]([零]{_1_999})?|{_1_9999}[万]{_1000_9999})"
        _1_99999999 = f"({_10000_99999999}|{_1_9999})"
        _100000000_9999999999999999 = f"({_1_99999999}[亿]([零]{_1_99999999})?|{_1_99999999}[亿]{_10000_99999999})"
        _1_9999999999999999 = f"({_100000000_9999999999999999}|{_1_99999999})"
        str_int_pattern = f"^({_0}|{_1_9999999999999999})$"
        nor_int_pattern = f"^({_0}|{_1_9999999999999999})$"

        str_dec_pattern = "^[零一二三四五六七八九]{0,15}[一二三四五六七八九]$"
        nor_dec_pattern = "^[零一二三四五六七八九]{0,16}$"

        for str_num in self.strict_cn_number.keys():
            str_int_pattern = str_int_pattern.replace(str_num, self.strict_cn_number[str_num])
            str_dec_pattern = str_dec_pattern.replace(str_num, self.strict_cn_number[str_num])
        for nor_num in self.normal_cn_number.keys():
            nor_int_pattern = nor_int_pattern.replace(nor_num, self.normal_cn_number[nor_num])
            nor_dec_pattern = nor_dec_pattern.replace(nor_num, self.normal_cn_number[nor_num])

        pattern_dict = {
            "strict": {
                "int": re.compile(str_int_pattern),
                "dec": re.compile(str_dec_pattern)
            },
            "normal": {
                "int": re.compile(nor_int_pattern),
                "dec": re.compile(nor_dec_pattern)
            }
        }
        return pattern_dict

    def __copy_num(self, num):
        cn_num = ""
        for n in num:
            cn_num += NUMBER_LOW_AN2CN[int(n)]
        return cn_num

    def __check_input_data_is_valid(self, check_data: str, mode: str) -> (int, str, str, bool):
        # Strip the suffixes 元整, 圆整, 元正, 圆正
        stop_words = ["元整", "圆整", "元正", "圆正"]
        for word in stop_words:
            if check_data[-2:] == word:
                check_data = check_data[:-2]

        # Strip 元 and 圆
        if mode != "strict":
            normal_stop_words = ["圆", "元"]
            for word in normal_stop_words:
                if check_data[-1] == word:
                    check_data = check_data[:-1]

        # Handle the 元/角/分 currency pattern
        result = self.yjf_pattern.search(check_data)
        if result:
            check_data = check_data.replace("元", "点").replace("角", "").replace("分", "")

        # Handle colloquial forms: 一千零十一, 一万零百一十一
        if "零十" in check_data:
            check_data = check_data.replace("零十", "零一十")
        if "零百" in check_data:
            check_data = check_data.replace("零百", "零一百")

        for data in check_data:
            if data not in self.check_key_dict[mode]:
                raise ValueError(f"in {mode} mode, the input data is out of the convertible range: {data}!")

        # Determine the sign
        if check_data[0] == "负":
            check_data = check_data[1:]
            sign = -1
        else:
            sign = 1

        if "点" in check_data:
            split_data = check_data.split("点")
            if len(split_data) == 2:
                integer_data, decimal_data = split_data
                # In smart mode, convert Arabic digits to Chinese numerals
                if mode == "smart":
                    integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
                    decimal_data = re.sub(r"\d+", lambda x: self.__copy_num(x.group()), decimal_data)
                    mode = "normal"
            else:
                raise ValueError("the data contains more than one decimal point!")
        else:
            integer_data = check_data
            decimal_data = None
            # In smart mode, convert Arabic digits to Chinese numerals
            if mode == "smart":
                # 10.1万 10.1
                result1 = self.pattern1.search(integer_data)
                if result1:
                    if result1.group() == integer_data:
                        if integer_data[-1] in UNIT_CN2AN.keys():
                            output = int(float(integer_data[:-1]) * UNIT_CN2AN[integer_data[-1]])
                        else:
                            output = float(integer_data)
                        return 0, output, None, None

                integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
                mode = "normal"

        result_int = self.pattern_dict[mode]["int"].search(integer_data)
        if result_int:
            if result_int.group() == integer_data:
                if decimal_data is not None:
                    result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
                    if result_dec:
                        if result_dec.group() == decimal_data:
                            return sign, integer_data, decimal_data, False
                else:
                    return sign, integer_data, decimal_data, False
        else:
            if mode == "strict":
                raise ValueError(f"invalid format: {integer_data}")
            elif mode == "normal":
                # digit-by-digit mode: 一二三
                result_all_num = self.ptn_all_num.search(integer_data)
                if result_all_num:
                    if result_all_num.group() == integer_data:
                        if decimal_data is not None:
                            result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
                            if result_dec:
                                if result_dec.group() == decimal_data:
                                    return sign, integer_data, decimal_data, True
                        else:
                            return sign, integer_data, decimal_data, True

                # speaking mode: 一万二, 两千三, 三百四, 十三万六, 一百二十五万三
                result_speaking_mode = self.ptn_speaking_mode.search(integer_data)
                if len(integer_data) >= 3 and result_speaking_mode and result_speaking_mode.group() == integer_data:
                    # len(integer_data) >= 3: because the minimum length of integer_data that can be matched is 3
                    # to find the last unit
                    last_unit = result_speaking_mode.groups()[-1][-1]
                    _unit = UNIT_LOW_AN2CN[UNIT_CN2AN[last_unit] // 10]
                    integer_data = integer_data + _unit
                    if decimal_data is not None:
                        result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
                        if result_dec:
                            if result_dec.group() == decimal_data:
                                return sign, integer_data, decimal_data, False
                    else:
                        return sign, integer_data, decimal_data, False

        raise ValueError(f"invalid format: {check_data}")

    def __integer_convert(self, integer_data: str) -> int:
        # core conversion
        output_integer = 0
        unit = 1
        ten_thousand_unit = 1
        for index, cn_num in enumerate(reversed(integer_data)):
            # digit
            if cn_num in NUMBER_CN2AN:
                num = NUMBER_CN2AN[cn_num]
                output_integer += num * unit
            # unit
            elif cn_num in UNIT_CN2AN:
                unit = UNIT_CN2AN[cn_num]
                # detect 万, 亿, 万亿
                if unit % 10000 == 0:
                    # 万 or 亿
                    if unit > ten_thousand_unit:
                        ten_thousand_unit = unit
                    # 万亿
                    else:
                        ten_thousand_unit = unit * ten_thousand_unit
                        unit = ten_thousand_unit

                if unit < ten_thousand_unit:
                    unit = unit * ten_thousand_unit

                if index == len(integer_data) - 1:
                    output_integer += unit
            else:
                raise ValueError(f"{cn_num} is out of the convertible range")

        return int(output_integer)

    def __decimal_convert(self, decimal_data: str) -> float:
        len_decimal_data = len(decimal_data)

        if len_decimal_data > 16:
            print(f"warning: the decimal part has {len_decimal_data} digits; only the first 16 are kept!")
            decimal_data = decimal_data[:16]
            len_decimal_data = 16

        output_decimal = 0
        for index in range(len(decimal_data) - 1, -1, -1):
            unit_key = NUMBER_CN2AN[decimal_data[index]]
            output_decimal += unit_key * 10 ** -(index + 1)

        # guard against floating-point precision overflow
        output_decimal = round(output_decimal, len_decimal_data)

        return output_decimal

    def __direct_convert(self, data: str) -> int:
        output_data = 0
        for index in range(len(data) - 1, -1, -1):
            unit_key = NUMBER_CN2AN[data[index]]
            output_data += unit_key * 10 ** (len(data) - index - 1)

        return output_data
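And a matching sketch of the three validation modes above; the speaking-mode branch is what turns "一万二" into 12000 rather than rejecting it:

from text.cn2an.cn2an import Cn2An

ca = Cn2An()
print(ca.cn2an("一百二十三"))        # 123    (strict)
print(ca.cn2an("一二三", "normal"))  # 123    (digit-by-digit)
print(ca.cn2an("一万二", "normal"))  # 12000  (speaking mode)
print(ca.cn2an("1.5万", "smart"))    # 15000  (mixed Arabic and Chinese)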