KdaiP committed on
Commit 09b47fc · verified · 1 parent: ece35c0

upload test model

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +6 -0
  2. app.py +138 -0
  3. audios/audio0.wav +0 -0
  4. audios/audio1.wav +0 -0
  5. audios/audio2.wav +0 -0
  6. audios/audio3.wav +0 -0
  7. audios/audio4.wav +0 -0
  8. audios/audio5.wav +0 -0
  9. audios/audio6.wav +0 -0
  10. checkpoints/.keep +0 -0
  11. checkpoints/checkpoint_0.pt +3 -0
  12. checkpoints/vocoder.pt +3 -0
  13. config.py +52 -0
  14. datas/__init__.py +0 -0
  15. datas/dataset.py +52 -0
  16. datas/sampler.py +121 -0
  17. models/__init__.py +0 -0
  18. models/__pycache__/__init__.cpython-310.pyc +0 -0
  19. models/__pycache__/model.cpython-310.pyc +0 -0
  20. models/dit.py +205 -0
  21. models/duration_predictor.py +40 -0
  22. models/estimator.py +161 -0
  23. models/flow_matching.py +108 -0
  24. models/model.py +194 -0
  25. models/reference_encoder.py +93 -0
  26. models/text_encoder.py +49 -0
  27. monotonic_align/__init__.py +16 -0
  28. monotonic_align/__pycache__/__init__.cpython-310.pyc +0 -0
  29. monotonic_align/__pycache__/core.cpython-310.pyc +0 -0
  30. monotonic_align/core.py +46 -0
  31. requirements.txt +11 -0
  32. text/LICENSE +19 -0
  33. text/__init__.py +71 -0
  34. text/__pycache__/__init__.cpython-310.pyc +0 -0
  35. text/__pycache__/cleaners.cpython-310.pyc +0 -0
  36. text/__pycache__/english.cpython-310.pyc +0 -0
  37. text/cleaners.py +73 -0
  38. text/cn2an/__init__.py +16 -0
  39. text/cn2an/__pycache__/__init__.cpython-311.pyc +0 -0
  40. text/cn2an/__pycache__/__init__.cpython-38.pyc +0 -0
  41. text/cn2an/__pycache__/an2cn.cpython-311.pyc +0 -0
  42. text/cn2an/__pycache__/an2cn.cpython-38.pyc +0 -0
  43. text/cn2an/__pycache__/cn2an.cpython-311.pyc +0 -0
  44. text/cn2an/__pycache__/cn2an.cpython-38.pyc +0 -0
  45. text/cn2an/__pycache__/conf.cpython-311.pyc +0 -0
  46. text/cn2an/__pycache__/conf.cpython-38.pyc +0 -0
  47. text/cn2an/__pycache__/transform.cpython-311.pyc +0 -0
  48. text/cn2an/__pycache__/transform.cpython-38.pyc +0 -0
  49. text/cn2an/an2cn.py +203 -0
  50. text/cn2an/cn2an.py +293 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+ text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,138 @@
+ import os
+
+ from dataclasses import asdict
+ from text import symbols
+ import torch
+ import torchaudio
+
+ from utils.audio import LogMelSpectrogram
+ from config import ModelConfig, VocosConfig, MelConfig
+ from models.model import StableTTS
+ from vocos_pytorch.models.model import Vocos
+ from text.mandarin import chinese_to_cnm3
+ from text import cleaned_text_to_sequence
+ from datas.dataset import intersperse
+
+ import gradio as gr
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ @torch.inference_mode()
+ def inference(text: str, ref_audio: str, checkpoint_path: str, step: int = 10):
+     # reload weights only when the user picks a different checkpoint
+     global last_checkpoint_path
+     if checkpoint_path != last_checkpoint_path:
+         tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+         last_checkpoint_path = checkpoint_path
+
+     phonemizer = chinese_to_cnm3
+
+     # prepare input for the tts model
+     x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
+     x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
+     waveform, sr = torchaudio.load(ref_audio)
+     if sr != sample_rate:
+         waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
+     y = mel_extractor(waveform).to(device)
+
+     # inference
+     mel = tts_model.synthesise(x, x_len, step, y=y, temperature=1, length_scale=1)['decoder_outputs']
+     audio = vocoder(mel)
+
+     # process output for gradio
+     audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16))  # (sample_rate, int16 audio) for gr.Audio
+     mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy())  # plot of the mel spectrogram
+     return audio_output, mel_output
+
+ def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
+     tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
+     mel_extractor = LogMelSpectrogram(mel_config)
+     vocoder = Vocos(vocoder_config, mel_config)
+     # tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
+     tts_model.to(device)
+     tts_model.eval()
+     vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
+     vocoder.to(device)
+     vocoder.eval()
+     return tts_model, mel_extractor, vocoder
+
+ def plot_mel_spectrogram(mel_spectrogram):
+     fig, ax = plt.subplots(figsize=(20, 8))
+     ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
+     plt.axis('off')
+     fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # remove white edges
+     return fig
+
+
+ def main():
+     tts_model_config = ModelConfig()
+     mel_config = MelConfig()
+     vocoder_config = VocosConfig()
+
+     tts_checkpoint_path = './checkpoints'  # the folder that contains StableTTS checkpoints
+     vocoder_checkpoint_path = './checkpoints/vocoder.pt'
+
+     global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
+     sample_rate = mel_config.sample_rate
+     last_checkpoint_path = None
+     tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)
+
+     # collect StableTTS checkpoints, skipping optimizer states and the vocoder
+     tts_checkpoint_path = [path for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
+     audios = list(Path('./audios').rglob('*.wav'))
+
+     # gradio web UI
+     gui_title = 'StableTTS'
+     gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
+     with gr.Blocks(analytics_enabled=False) as demo:
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown(f"# {gui_title}")
+                 gr.Markdown(gui_description)
+
+         with gr.Row():
+             with gr.Column():
+                 input_text_gr = gr.Textbox(
+                     label="Input Text",
+                     info="One or two sentences at a time works best. Up to 200 characters.",
+                     value="三国杀是一款风靡全球的以三国演义为背景的策略卡牌桌面游戏,经典新三国国战玩法,百万名将任由你搭配,楚雄争霸,等你决战沙场!",
+                 )
+
+                 ref_audio_gr = gr.Dropdown(
+                     label='Reference Audio',
+                     choices=audios,
+                     value=audios[0] if audios else None,
+                 )
+
+                 checkpoint_gr = gr.Dropdown(
+                     label='Checkpoint',
+                     choices=tts_checkpoint_path,
+                     value=tts_checkpoint_path[0] if tts_checkpoint_path else None,
+                 )
+
+                 step_gr = gr.Slider(
+                     label='Step',
+                     minimum=1,
+                     maximum=100,
+                     value=25,
+                     step=1,
+                 )
+
+                 tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
+
+             with gr.Column():
+                 mel_gr = gr.Plot(label="Mel Visual")
+                 audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
+
+         tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])
+
+     demo.queue()
+     demo.launch(debug=True, show_api=True)
+
+
+ if __name__ == '__main__':
+     main()
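One detail worth noting in `inference`: multiplying float audio by 32767 and casting to int16 wraps around if the vocoder output ever exceeds [-1, 1]. A safer conversion clips first (a small standalone sketch, not part of this commit):

import numpy as np

def float_to_int16(audio: np.ndarray) -> np.ndarray:
    # clamp to the valid range before scaling, so overshoot saturates instead of wrapping
    audio = np.clip(audio, -1.0, 1.0)
    return (audio * 32767).astype(np.int16)

print(float_to_int16(np.array([0.5, 1.2, -1.5])))  # [ 16383  32767 -32767]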
audios/audio0.wav ADDED
Binary file (318 kB).

audios/audio1.wav ADDED
Binary file (558 kB).

audios/audio2.wav ADDED
Binary file (300 kB).

audios/audio3.wav ADDED
Binary file (738 kB).

audios/audio4.wav ADDED
Binary file (547 kB).

audios/audio5.wav ADDED
Binary file (667 kB).

audios/audio6.wav ADDED
Binary file (596 kB).
checkpoints/.keep ADDED
File without changes
checkpoints/checkpoint_0.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f9c5d59001f1c9e308c0fa00291779722e88d5a8162734afac547c95126d4580
+ size 37552792
checkpoints/vocoder.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e180a1df6ca0a9e382e0915b0b1984aecfb63397c4ee21f12857447c2d76d29a
+ size 56666508
config.py ADDED
@@ -0,0 +1,52 @@
+ from dataclasses import dataclass
+
+ @dataclass
+ class MelConfig:
+     sample_rate: int = 44100
+     n_fft: int = 2048
+     win_length: int = 2048
+     hop_length: int = 512
+     f_min: float = 0.0
+     f_max: float = None
+     pad: int = 0
+     n_mels: int = 128
+     power: float = 1.0
+     normalized: bool = False
+     center: bool = False
+     pad_mode: str = "reflect"
+     mel_scale: str = "htk"
+
+     def __post_init__(self):
+         if self.pad == 0:
+             self.pad = (self.n_fft - self.hop_length) // 2
+
+ @dataclass
+ class ModelConfig:
+     hidden_channels: int = 192
+     filter_channels: int = 512
+     n_heads: int = 2
+     n_enc_layers: int = 3
+     n_dec_layers: int = 2
+     kernel_size: int = 3
+     p_dropout: float = 0.1
+     gin_channels: int = 192
+
+ @dataclass
+ class TrainConfig:
+     train_dataset_path: str = 'filelists/filelist.json'
+     test_dataset_path: str = 'filelists/filelist.json'
+     batch_size: int = 52
+     learning_rate: float = 1e-4
+     num_epochs: int = 10000
+     model_save_path: str = './checkpoints'
+     log_dir: str = './runs'
+     log_interval: int = 128
+     save_interval: int = 15
+     warmup_steps: int = 200
+
+ @dataclass
+ class VocosConfig:
+     input_channels: int = 128
+     dim: int = 512
+     intermediate_dim: int = 1536
+     num_layers: int = 8
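`MelConfig.__post_init__` derives the STFT padding from the FFT and hop sizes whenever `pad` is left at 0, which keeps frame counts aligned to `hop_length`. A quick check of the defaults (illustrative sketch, assuming only `config.py` above):

from dataclasses import asdict
from config import MelConfig, ModelConfig

mel = MelConfig()
print(mel.pad)                             # (2048 - 512) // 2 == 768
print(asdict(ModelConfig())['p_dropout'])  # 0.1 — app.py unpacks this dict into StableTTS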
datas/__init__.py ADDED
File without changes
datas/dataset.py ADDED
@@ -0,0 +1,52 @@
+ import os
+
+ import json
+ import torch
+ from torch.utils.data import Dataset
+
+ from text import cleaned_text_to_sequence
+
+ def intersperse(lst, item):
+     # insert `item` between every element of `lst` and at both ends
+     result = [item] * (len(lst) * 2 + 1)
+     result[1::2] = lst
+     return result
+
+ class StableDataset(Dataset):
+     def __init__(self, filelist_path, hop_length):
+         self.filelist_path = filelist_path
+         self.hop_length = hop_length
+
+         self._load_filelist(filelist_path)
+
+     def _load_filelist(self, filelist_path):
+         filelist, lengths = [], []
+         with open(filelist_path, 'r', encoding='utf-8') as f:
+             for line in f:
+                 line = json.loads(line.strip())
+                 filelist.append((line['mel_path'], line['phone']))
+                 # estimate the frame count from the 16-bit audio file size
+                 lengths.append(os.path.getsize(line['audio_path']) // (2 * self.hop_length))
+
+         self.filelist = filelist
+         self.lengths = lengths
+
+     def __len__(self):
+         return len(self.filelist)
+
+     def __getitem__(self, idx):
+         mel_path, phone = self.filelist[idx]
+         mel = torch.load(mel_path, map_location='cpu')
+         phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
+         return mel, phone
+
+ def collate_fn(batch):
+     texts = [item[1] for item in batch]
+     mels = [item[0] for item in batch]
+
+     text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
+     mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)
+
+     # pad to the same length
+     texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
+     mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)
+
+     return texts_padded, text_lengths, mels_padded, mel_lengths
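`intersperse` pads the phoneme sequence with a blank token between and around every symbol, a common trick (as in VITS-style models) that gives the aligner room between phones. A quick sanity check (illustrative, not part of the diff):

from datas.dataset import intersperse

print(intersperse([1, 2, 3], 0))  # [0, 1, 0, 2, 0, 3, 0] — length 2 * len(lst) + 1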
datas/sampler.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+
+ # reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+     """
+     Maintain similar input lengths in a batch.
+     Length groups are specified by boundaries.
+     Ex) boundaries = [b1, b2, b3] -> any batch contains only samples with b1 < length(x) <= b2 or b2 < length(x) <= b3.
+
+     It removes samples which are not included in the boundaries.
+     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         batch_size,
+         boundaries,
+         num_replicas=None,
+         rank=None,
+         shuffle=True,
+     ):
+         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+         self.lengths = dataset.lengths
+         self.batch_size = batch_size
+         self.boundaries = boundaries
+
+         self.buckets, self.num_samples_per_bucket = self._create_buckets()
+         self.total_size = sum(self.num_samples_per_bucket)
+         self.num_samples = self.total_size // self.num_replicas
+
+     def _create_buckets(self):
+         buckets = [[] for _ in range(len(self.boundaries) - 1)]
+         for i in range(len(self.lengths)):
+             length = self.lengths[i]
+             idx_bucket = self._bisect(length)
+             if idx_bucket != -1:
+                 buckets[idx_bucket].append(i)
+
+         # drop empty buckets (and their upper boundaries), last to first
+         for i in range(len(buckets) - 1, 0, -1):
+             # for i in range(len(buckets) - 1, -1, -1):
+             if len(buckets[i]) == 0:
+                 buckets.pop(i)
+                 self.boundaries.pop(i + 1)
+
+         num_samples_per_bucket = []
+         for i in range(len(buckets)):
+             len_bucket = len(buckets[i])
+             total_batch_size = self.num_replicas * self.batch_size
+             rem = (
+                 total_batch_size - (len_bucket % total_batch_size)
+             ) % total_batch_size
+             num_samples_per_bucket.append(len_bucket + rem)
+         return buckets, num_samples_per_bucket
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+
+         indices = []
+         if self.shuffle:
+             for bucket in self.buckets:
+                 indices.append(torch.randperm(len(bucket), generator=g).tolist())
+         else:
+             for bucket in self.buckets:
+                 indices.append(list(range(len(bucket))))
+
+         batches = []
+         for i in range(len(self.buckets)):
+             bucket = self.buckets[i]
+             len_bucket = len(bucket)
+             ids_bucket = indices[i]
+             num_samples_bucket = self.num_samples_per_bucket[i]
+
+             # add extra samples to make it evenly divisible
+             rem = num_samples_bucket - len_bucket
+             ids_bucket = (
+                 ids_bucket
+                 + ids_bucket * (rem // len_bucket)
+                 + ids_bucket[: (rem % len_bucket)]
+             )
+
+             # subsample for this rank
+             ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+
+             # batching
+             for j in range(len(ids_bucket) // self.batch_size):
+                 batch = [
+                     bucket[idx]
+                     for idx in ids_bucket[
+                         j * self.batch_size : (j + 1) * self.batch_size
+                     ]
+                 ]
+                 batches.append(batch)
+
+         if self.shuffle:
+             batch_ids = torch.randperm(len(batches), generator=g).tolist()
+             batches = [batches[i] for i in batch_ids]
+         self.batches = batches
+
+         assert len(self.batches) * self.batch_size == self.num_samples
+         return iter(self.batches)
+
+     def _bisect(self, x, lo=0, hi=None):
+         if hi is None:
+             hi = len(self.boundaries) - 1
+
+         if hi > lo:
+             mid = (hi + lo) // 2
+             if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
+                 return mid
+             elif x <= self.boundaries[mid]:
+                 return self._bisect(x, lo, mid)
+             else:
+                 return self._bisect(x, mid + 1, hi)
+         else:
+             return -1
+
+     def __len__(self):
+         return self.num_samples // self.batch_size
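The sampler groups indices into length buckets via `_bisect`, pads each bucket so it divides evenly across replicas and batches, then yields lists of dataset indices. A single-process sketch (illustrative; the toy dataset only needs a `lengths` attribute, and passing `num_replicas=1, rank=0` avoids needing a process group):

import torch
from datas.sampler import DistributedBucketSampler

class ToyDataset:
    def __init__(self, lengths):
        self.lengths = lengths
    def __len__(self):
        return len(self.lengths)

dataset = ToyDataset(lengths=[50, 60, 120, 130, 140, 300])
sampler = DistributedBucketSampler(dataset, batch_size=2, boundaries=[32, 100, 200],
                                   num_replicas=1, rank=0, shuffle=False)
for batch in sampler:
    print(batch)
# [0, 1], [2, 3], [4, 2] — the odd bucket is padded by repetition,
# and the length-300 sample falls outside the boundaries and is dropped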
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes). View file
 
models/__pycache__/model.cpython-310.pyc ADDED
Binary file (6.49 kB). View file
 
models/dit.py ADDED
@@ -0,0 +1,205 @@
+ # References:
+ # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
+ # https://github.com/jaywalnut310/vits/blob/main/attentions.py
+ # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class FFN(nn.Module):
+     def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.gin_channels = gin_channels
+
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+         self.drop = nn.Dropout(p_dropout)
+         self.act1 = nn.GELU(approximate="tanh")
+
+     def forward(self, x, x_mask):
+         x = self.conv_1(x * x_mask)
+         x = self.act1(x)
+         x = self.drop(x)
+         x = self.conv_2(x * x_mask)
+         return x * x_mask
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
+         super().__init__()
+         assert channels % n_heads == 0
+
+         self.channels = channels
+         self.out_channels = out_channels
+         self.n_heads = n_heads
+         self.p_dropout = p_dropout
+
+         self.k_channels = channels // n_heads
+         self.conv_q = torch.nn.Conv1d(channels, channels, 1)
+         self.conv_k = torch.nn.Conv1d(channels, channels, 1)
+         self.conv_v = torch.nn.Conv1d(channels, channels, 1)
+
+         # from https://nn.labml.ai/transformers/rope/index.html
+         self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+         self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+
+         self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
+         self.drop = torch.nn.Dropout(p_dropout)
+
+         torch.nn.init.xavier_uniform_(self.conv_q.weight)
+         torch.nn.init.xavier_uniform_(self.conv_k.weight)
+         torch.nn.init.xavier_uniform_(self.conv_v.weight)
+
+     def forward(self, x, attn_mask=None):
+         q = self.conv_q(x)
+         k = self.conv_k(x)
+         v = self.conv_v(x)
+
+         x = self.attention(q, k, v, mask=attn_mask)
+
+         x = self.conv_o(x)
+         return x
+
+     def attention(self, query, key, value, mask=None):
+         b, d, t_s, t_t = (*key.size(), query.size(2))
+         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+         query = self.query_rotary_pe(query)  # [b, n_head, t, c // n_head]
+         key = self.key_rotary_pe(key)
+
+         output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
+         output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+         return output
+
+ # modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
+ class DiTConVBlock(nn.Module):
+     """
+     A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+     """
+     def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
+         self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
+         self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
+         self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
+         self.adaLN_modulation = nn.Sequential(
+             nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
+             nn.SiLU(),
+             nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
+         )
+
+     def forward(self, x, c, x_mask):
+         """
+         Args:
+             x : [batch_size, channel, time]
+             c : [batch_size, channel]
+             x_mask : [batch_size, 1, time]
+         Returns a tensor of the same shape as x.
+         """
+         x = x * x_mask
+         attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)  # shape: [batch_size, 1, time, time]
+         # attn_mask = attn_mask.to(torch.bool)
+
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)  # shape: [batch_size, channel, 1]
+         x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1, 2)).transpose(1, 2), shift_msa, scale_msa), attn_mask) * x_mask
+         x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1, 2)).transpose(1, 2), shift_mlp, scale_mlp), x_mask)
+
+         # unconditional version
+         # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
+         # x = x + self.mlp(self.norm1(x.transpose(1,2)).transpose(1,2), x_mask)
+         return x
+
+     @staticmethod
+     def modulate(x, shift, scale):
+         return x * (1 + scale) + shift
+
+ class RotaryPositionalEmbeddings(nn.Module):
+     """
+     ## RoPE module
+
+     Rotary encoding transforms pairs of features by rotating in the 2D plane.
+     That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
+     Each pair can be considered a coordinate in a 2D plane, and the encoding rotates it
+     by an angle depending on the position of the token.
+     """
+
+     def __init__(self, d: int, base: int = 10_000):
+         r"""
+         * `d` is the number of features $d$
+         * `base` is the constant used for calculating $\Theta$
+         """
+         super().__init__()
+
+         self.base = base
+         self.d = int(d)
+         self.cos_cached = None
+         self.sin_cached = None
+
+     def _build_cache(self, x: torch.Tensor):
+         r"""
+         Cache $\cos$ and $\sin$ values
+         """
+         # return if the cache is already built and long enough
+         if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+             return
+
+         # get sequence length
+         seq_len = x.shape[0]
+
+         # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+         theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+
+         # create position indexes `[0, 1, ..., seq_len - 1]`
+         seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
+
+         # calculate the product of position index and $\theta_i$
+         idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
+
+         # concatenate so that for row $m$ we have
+         # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
+         idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+         # cache them
+         self.cos_cached = idx_theta2.cos()[:, None, None, :]
+         self.sin_cached = idx_theta2.sin()[:, None, None, :]
+
+     def _neg_half(self, x: torch.Tensor):
+         # $\frac{d}{2}$
+         d_2 = self.d // 2
+
+         # calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+         return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+     def forward(self, x: torch.Tensor):
+         """
+         * `x` is the tensor at the head of a key or a query with shape `[batch_size, n_heads, seq_len, d]`
+         """
+         x = x.permute(2, 0, 1, 3)  # b h t d -> t b h d
+
+         # cache $\cos$ and $\sin$ values
+         self._build_cache(x)
+
+         # split the features; rotary embeddings can be applied to only a subset of the features
+         x_rope, x_pass = x[..., : self.d], x[..., self.d :]
+
+         # calculate
+         # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+         neg_half_x = self._neg_half(x_rope)
+
+         x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+
+         return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)  # t b h d -> b h t d
+
+ class Transpose(nn.Identity):
+     """(N, T, D) -> (N, D, T)"""
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         return input.transpose(1, 2)

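RoPE applies a pure rotation to each feature pair, so it changes directions but not magnitudes, and positions enter only through relative phase. A quick property check on the module above (illustrative sketch, assuming this repo's layout):

import torch
from models.dit import RotaryPositionalEmbeddings

rope = RotaryPositionalEmbeddings(d=32)   # rotate the first 32 of 64 head channels
x = torch.randn(2, 4, 10, 64)             # [batch, heads, time, head_dim]
y = rope(x)

print(y.shape)                            # torch.Size([2, 4, 10, 64])
# rotations preserve per-position norms
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # True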
models/duration_predictor.py ADDED
@@ -0,0 +1,40 @@
+ import torch
+ import torch.nn as nn
+
+ # modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
+ class DurationPredictor(nn.Module):
+     def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.gin_channels = gin_channels
+
+         self.drop = nn.Dropout(p_dropout)
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+         self.norm_1 = nn.LayerNorm(filter_channels)
+         self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+         self.norm_2 = nn.LayerNorm(filter_channels)
+         self.proj = nn.Conv1d(filter_channels, 1, 1)
+
+         self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+
+     def forward(self, x, x_mask, g):
+         # detach so duration gradients do not flow back into the encoder or speaker condition
+         x = x.detach()
+         x = x + self.cond(g.unsqueeze(2).detach())
+         x = self.conv_1(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_1(x.transpose(1, 2)).transpose(1, 2)
+         x = self.drop(x)
+         x = self.conv_2(x * x_mask)
+         x = torch.relu(x)
+         x = self.norm_2(x.transpose(1, 2)).transpose(1, 2)
+         x = self.drop(x)
+         x = self.proj(x * x_mask)
+         return x * x_mask
+
+ def duration_loss(logw, logw_, lengths):
+     loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
+     return loss
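`duration_loss` is a sum of squared errors between predicted and MAS-extracted log-durations, normalized by the total number of tokens rather than by batch size. A tiny numeric check (illustrative):

import torch
from models.duration_predictor import duration_loss

logw = torch.log(torch.tensor([[2.0, 4.0, 1.0]]))   # predicted log-durations
logw_ = torch.log(torch.tensor([[2.0, 2.0, 1.0]]))  # durations recovered by MAS
lengths = torch.tensor([3])                          # three tokens in the batch

# only the middle token differs: (log 4 - log 2)^2 / 3 = (ln 2)^2 / 3 ≈ 0.1602
print(duration_loss(logw, logw_, lengths))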
models/estimator.py ADDED
@@ -0,0 +1,161 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from models.dit import DiTConVBlock
+
+ class DitWrapper(nn.Module):
+     """Adds a FiLM layer to condition the time embedding to DiT."""
+     def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
+         super().__init__()
+         self.time_fusion = FiLMLayer(hidden_channels, time_channels)
+         self.conv1 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
+         self.conv2 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
+         self.conv3 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
+         self.block = DiTConVBlock(hidden_channels, hidden_channels, num_heads, kernel_size, p_dropout, gin_channels)
+
+     def forward(self, x, c, t, x_mask):
+         x = self.time_fusion(x, t) * x_mask
+         x = self.conv1(x, c, x_mask)
+         x = self.conv2(x, c, x_mask)
+         x = self.conv3(x, c, x_mask)
+         x = self.block(x, c, x_mask)
+         return x
+
+ class FiLMLayer(nn.Module):
+     """
+     Feature-wise Linear Modulation (FiLM) layer
+     Reference: https://arxiv.org/abs/1709.07871
+     """
+     def __init__(self, in_channels, cond_channels):
+         super(FiLMLayer, self).__init__()
+         self.in_channels = in_channels
+         self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)
+
+     def forward(self, x, c):
+         gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
+         return gamma * x + beta
+
+ class ConvNeXtBlock(nn.Module):
+     def __init__(self, in_channels, filter_channels, gin_channels):
+         super().__init__()
+         self.dwconv = nn.Conv1d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
+         self.norm = StyleAdaptiveLayerNorm(in_channels, gin_channels)
+         self.pwconv = nn.Sequential(nn.Linear(in_channels, filter_channels),
+                                     nn.GELU(),
+                                     nn.Linear(filter_channels, in_channels))
+
+     def forward(self, x, c, x_mask) -> torch.Tensor:
+         residual = x
+         x = self.dwconv(x) * x_mask
+         x = self.norm(x.transpose(1, 2), c)
+         x = self.pwconv(x).transpose(1, 2)
+         x = residual + x
+         return x * x_mask
+
+ class StyleAdaptiveLayerNorm(nn.Module):
+     def __init__(self, in_channels, cond_channels):
+         """
+         Style Adaptive Layer Normalization (SALN) module.
+
+         Parameters:
+             in_channels: The number of channels in the input feature maps.
+             cond_channels: The number of channels in the conditioning input.
+         """
+         super(StyleAdaptiveLayerNorm, self).__init__()
+         self.in_channels = in_channels
+
+         self.saln = nn.Linear(cond_channels, in_channels * 2)
+         self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         # initialize so the layer starts as a plain LayerNorm (gamma = 1, beta = 0)
+         nn.init.constant_(self.saln.bias.data[:self.in_channels], 1)
+         nn.init.constant_(self.saln.bias.data[self.in_channels:], 0)
+
+     def forward(self, x, c):
+         gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1)
+         return gamma * self.norm(x) + beta
+
+
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+         assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
+
+     def forward(self, x, scale=1000):
+         if x.ndim < 1:
+             x = x.unsqueeze(0)
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
+         emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+ class TimestepEmbedding(nn.Module):
+     def __init__(self, in_channels, out_channels, filter_channels):
+         super().__init__()
+
+         self.layer = nn.Sequential(
+             nn.Linear(in_channels, filter_channels),
+             nn.SiLU(inplace=True),
+             nn.Linear(filter_channels, out_channels)
+         )
+
+     def forward(self, x):
+         return self.layer(x)
+
+ # reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
+ class Decoder(nn.Module):
+     def __init__(self, hidden_channels, out_channels, filter_channels, dropout=0.05, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+
+         self.time_embeddings = SinusoidalPosEmb(hidden_channels)
+         self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)
+
+         self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
+         self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # zero-init the adaLN modulation so each block starts as an identity mapping
+         for block in self.blocks:
+             nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)
+
+     def forward(self, x, mask, mu, t, c):
+         """Forward pass of the conditional vector-field estimator.
+
+         Args:
+             x (torch.Tensor): noisy sample, shape (batch_size, hidden_channels // 2, time)
+             mask (torch.Tensor): shape (batch_size, 1, time)
+             mu (torch.Tensor): encoder output, shape (batch_size, hidden_channels // 2, time);
+                 concatenated with x along the channel dim
+             t (torch.Tensor): timestep, shape (batch_size,)
+             c (torch.Tensor): condition, shape (batch_size, gin_channels)
+
+         Returns:
+             torch.Tensor: estimated vector field, shape (batch_size, out_channels, time)
+         """
+         t = self.time_mlp(self.time_embeddings(t))
+         x = torch.cat((x, mu), dim=1)
+
+         for block in self.blocks:
+             x = block(x, c, t, mask)
+
+         output = self.final_proj(x * mask)
+
+         return output * mask
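FiLM conditioning is just a learned per-channel affine transform, gamma(t) * x + beta(t); here it injects the timestep embedding into every decoder block. A minimal check of the layer above (illustrative):

import torch
from models.estimator import FiLMLayer

film = FiLMLayer(in_channels=8, cond_channels=16)
x = torch.randn(2, 8, 50)   # [batch, channels, time]
t = torch.randn(2, 16)      # one conditioning vector per batch element
out = film(x, t)
print(out.shape)            # torch.Size([2, 8, 50]) — gamma/beta broadcast over time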
models/flow_matching.py ADDED
@@ -0,0 +1,108 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from models.estimator import Decoder
+
+ # copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121
+ def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+ # modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
+ class CFMDecoder(torch.nn.Module):
+     def __init__(self, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.gin_channels = gin_channels
+         self.sigma_min = 1e-4
+
+         self.estimator = Decoder(hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)
+
+     @torch.inference_mode()
+     def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None):
+         """Generates a sample by integrating the probability-flow ODE from noise.
+
+         Args:
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output mask
+                 shape: (batch_size, 1, mel_timesteps)
+             n_timesteps (int): number of ODE solver steps
+             temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+             c (torch.Tensor, optional): shape: (batch_size, gin_channels)
+
+         Returns:
+             sample: generated mel-spectrogram
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         z = torch.randn_like(mu) * temperature
+         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, c=c)
+
+     def solve_euler(self, x, t_span, mu, mask, c):
+         """
+         Fixed-step Euler solver for the probability-flow ODE.
+         Args:
+             x (torch.Tensor): random noise
+             t_span (torch.Tensor): n_timesteps interpolated
+                 shape: (n_timesteps + 1,)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): output mask
+                 shape: (batch_size, 1, mel_timesteps)
+             c (torch.Tensor, optional): speaker condition.
+                 shape: (batch_size, gin_channels)
+         """
+         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+         # intermediate states are stored for debugging/plotting;
+         # a return_all_steps flag could expose them in the future
+         sol = []
+
+         for step in range(1, len(t_span)):
+             dphi_dt = self.estimator(x, mask, mu, t, c)
+
+             x = x + dt * dphi_dt
+             t = t + dt
+             sol.append(x)
+             if step < len(t_span) - 1:
+                 dt = t_span[step + 1] - t
+
+         return sol[-1]
+
+     def compute_loss(self, x1, mask, mu, c):
+         """Computes the conditional flow matching loss
+
+         Args:
+             x1 (torch.Tensor): Target
+                 shape: (batch_size, n_feats, mel_timesteps)
+             mask (torch.Tensor): target mask
+                 shape: (batch_size, 1, mel_timesteps)
+             mu (torch.Tensor): output of encoder
+                 shape: (batch_size, n_feats, mel_timesteps)
+             c (torch.Tensor, optional): speaker condition.
+
+         Returns:
+             loss: conditional flow matching loss
+             y: conditional flow
+                 shape: (batch_size, n_feats, mel_timesteps)
+         """
+         b, _, t = mu.shape
+
+         # random timestep
+         t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
+         # sample noise p(x_0)
+         z = torch.randn_like(x1)
+
+         # straight-line interpolation between noise and target, and its constant velocity
+         y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+         u = x1 - (1 - self.sigma_min) * z
+
+         loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c), u, reduction="sum") / (
+             torch.sum(mask) * u.shape[1]
+         )
+         return loss, y
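The training target comes from the straight-line path y_t = (1 - (1 - σ_min) t) z + t x1, whose time derivative is exactly u = x1 - (1 - σ_min) z, independent of t. A standalone finite-difference sanity check (illustrative):

import torch

sigma_min = 1e-4
z = torch.randn(4, dtype=torch.float64)   # noise sample
x1 = torch.randn(4, dtype=torch.float64)  # data sample

def y(t):
    return (1 - (1 - sigma_min) * t) * z + t * x1

u = x1 - (1 - sigma_min) * z              # analytic d y(t) / dt
t, eps = 0.3, 1e-6
u_fd = (y(t + eps) - y(t - eps)) / (2 * eps)  # central finite difference
print(torch.allclose(u, u_fd, atol=1e-8))     # True for any t in (0, 1)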
models/model.py ADDED
@@ -0,0 +1,194 @@
+ import math
+ import torch
+ import torch.nn as nn
+
+ import monotonic_align
+ from models.text_encoder import TextEncoder
+ from models.flow_matching import CFMDecoder
+ from models.reference_encoder import MelStyleEncoder
+ from models.duration_predictor import DurationPredictor, duration_loss
+
+ def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+ def convert_pad_shape(pad_shape):
+     inverted_shape = pad_shape[::-1]
+     pad_shape = [item for sublist in inverted_shape for item in sublist]
+     return pad_shape
+
+ def generate_path(duration, mask):
+     device = duration.device
+
+     b, t_x, t_y = mask.shape
+     cum_duration = torch.cumsum(duration, 1)
+     path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     # subtract the previous row's mask so each text position covers only its own frames
+     path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+     path = path * mask
+     return path
+
+ # modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
+ class StableTTS(nn.Module):
+     def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
+         super().__init__()
+
+         self.n_vocab = n_vocab
+         self.mel_channels = mel_channels
+
+         self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
+         self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
+         self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
+         self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)
+
+     @torch.inference_mode()
+     def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
+         """
+         Generates mel-spectrogram from text. Returns:
+         1. encoder outputs
+         2. decoder outputs
+         3. generated alignment
+
+         Args:
+             x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
+                 shape: (batch_size, max_text_length)
+             x_lengths (torch.Tensor): lengths of texts in batch.
+                 shape: (batch_size,)
+             n_timesteps (int): number of steps to use for reverse diffusion in decoder.
+             temperature (float, optional): controls variance of terminal distribution.
+             y (torch.Tensor): mel spectrogram of reference audio
+                 shape: (batch_size, mel_channels, time)
+             length_scale (float, optional): controls speech pace.
+                 Increase value to slow down generated speech and vice versa.
+
+         Returns:
+             dict: {
+                 "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
+                 # Average mel spectrogram generated by the encoder
+                 "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
+                 # Refined mel spectrogram improved by the CFM
+                 "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
+                 # Alignment map between text and mel spectrogram
+             }
+         """
+         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
+         c = self.ref_encoder(y, None)
+         x, mu_x, x_mask = self.encoder(x, c, x_lengths)
+         logw = self.dp(x, x_mask, c)
+
+         w = torch.exp(logw) * x_mask
+         w_ceil = torch.ceil(w) * length_scale
+         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+         y_max_length = y_lengths.max()
+
+         # Using obtained durations `w` construct alignment map `attn`
+         y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
+         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
+
+         # Align encoded text and get mu_y
+         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+         mu_y = mu_y.transpose(1, 2)
+         encoder_outputs = mu_y[:, :, :y_max_length]
+
+         # Generate sample tracing the probability flow
+         decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
+         decoder_outputs = decoder_outputs[:, :, :y_max_length]
+
+         return {
+             "encoder_outputs": encoder_outputs,
+             "decoder_outputs": decoder_outputs,
+             "attn": attn[:, :, :y_max_length],
+         }
+
+     def forward(self, x, x_lengths, y, y_lengths):
+         """
+         Computes 3 losses:
+         1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
+         2. prior loss: loss between mel-spectrogram and encoder outputs.
+         3. flow matching loss: loss between mel-spectrogram and decoder outputs.
+
+         Args:
+             x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
+                 shape: (batch_size, max_text_length)
+             x_lengths (torch.Tensor): lengths of texts in batch.
+                 shape: (batch_size,)
+             y (torch.Tensor): batch of corresponding mel-spectrograms.
+                 shape: (batch_size, n_feats, max_mel_length)
+             y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
+                 shape: (batch_size,)
+         """
+         # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
+         y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
+         c = self.ref_encoder(y, y_mask)
+
+         x, mu_x, x_mask = self.encoder(x, c, x_lengths)
+         logw = self.dp(x, x_mask, c)
+
+         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+
+         # Use MAS to find the most likely alignment `attn` between text and mel-spectrogram
+
+         # I'm not sure why the MAS code in Matcha-TTS and Grad-TTS could not align in StableTTS,
+         # so I use the code from https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py and it works.
+         # Everyone is welcome to solve this problem QAQ
+
+         with torch.no_grad():
+             # const = -0.5 * math.log(2 * math.pi) * self.n_feats
+             # const = -0.5 * math.log(2 * math.pi) * self.mel_channels
+             # factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
+             # y_square = torch.matmul(factor.transpose(1, 2), y**2)
+             # y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
+             # mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
+             # log_prior = y_square - y_mu_double + mu_square + const
+
+             s_p_sq_r = torch.ones_like(mu_x)  # [b, d, t]
+             # s_p_sq_r = torch.exp(-2 * logx)
+             neg_cent1 = torch.sum(
+                 -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
+             )
+             # neg_cent1 = torch.sum(
+             #     -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True
+             # )  # [b, 1, t_s]
+             neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
+             neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
+             neg_cent4 = torch.sum(
+                 -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True
+             )
+             neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
+
+             attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+
+             attn = (
+                 monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
+             )
+
+             # attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
+             # attn = attn.detach()
+
+         # Compute loss between predicted log-scaled durations and those obtained from MAS
+         logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
+         # logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
+         dur_loss = duration_loss(logw, logw_, x_lengths)
+
+
+         # Align encoded text with mel-spectrogram and get mu_y segment
+         attn = attn.squeeze(1).transpose(1, 2)
+         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+         mu_y = mu_y.transpose(1, 2)
+
+         # Compute loss of the decoder
+         diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)
+         # diff_loss = torch.tensor([0], device=mu_y.device)
+
+         prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
+         prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)
+
+         return dur_loss, diff_loss, prior_loss, attn
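`generate_path` turns integer durations into a hard monotonic alignment: each text position gets a contiguous block of mel frames whose width is its duration. A small check (illustrative; the import assumes this repo's layout):

import torch
from models.model import generate_path

duration = torch.tensor([[2., 3., 1.]])  # 3 tokens -> 6 frames
mask = torch.ones(1, 3, 6)               # [batch, text_len, mel_len]
print(generate_path(duration, mask)[0].long())
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1]])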
models/reference_encoder.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn as nn
+
+ class Conv1dGLU(nn.Module):
+     """
+     Conv1d + GLU (Gated Linear Unit) with residual connection.
+     For GLU, refer to the paper https://arxiv.org/abs/1612.08083.
+     """
+
+     def __init__(self, in_channels, out_channels, kernel_size, dropout):
+         super(Conv1dGLU, self).__init__()
+         self.out_channels = out_channels
+         self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         residual = x
+         x = self.conv1(x)
+         x1, x2 = torch.split(x, self.out_channels, dim=1)
+         x = x1 * torch.sigmoid(x2)
+         x = residual + self.dropout(x)
+         return x
+
+ # modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
+ class MelStyleEncoder(nn.Module):
+     """MelStyleEncoder"""
+
+     def __init__(
+         self,
+         n_mel_channels=80,
+         style_hidden=128,
+         style_vector_dim=256,
+         style_kernel_size=5,
+         style_head=2,
+         dropout=0.1,
+     ):
+         super(MelStyleEncoder, self).__init__()
+         self.in_dim = n_mel_channels
+         self.hidden_dim = style_hidden
+         self.out_dim = style_vector_dim
+         self.kernel_size = style_kernel_size
+         self.n_head = style_head
+         self.dropout = dropout
+
+         self.spectral = nn.Sequential(
+             nn.Linear(self.in_dim, self.hidden_dim),
+             nn.Mish(inplace=True),
+             nn.Dropout(self.dropout),
+             nn.Linear(self.hidden_dim, self.hidden_dim),
+             nn.Mish(inplace=True),
+             nn.Dropout(self.dropout),
+         )
+
+         self.temporal = nn.Sequential(
+             Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+             Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
+         )
+
+         self.slf_attn = nn.MultiheadAttention(
+             self.hidden_dim,
+             self.n_head,
+             self.dropout,
+             batch_first=True
+         )
+
+         self.fc = nn.Linear(self.hidden_dim, self.out_dim)
+
+     def temporal_avg_pool(self, x, mask=None):
+         if mask is None:
+             return torch.mean(x, dim=1)
+         else:
+             len_ = (~mask).sum(dim=1).unsqueeze(1).type_as(x)
+             return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / len_
+
+     def forward(self, x, x_mask=None):
+         x = x.transpose(1, 2)
+
+         # spectral
+         x = self.spectral(x)
+         # temporal
+         x = x.transpose(1, 2)
+         x = self.temporal(x)
+         x = x.transpose(1, 2)
+         # self-attention
+         if x_mask is not None:
+             x_mask = ~x_mask.squeeze(1).to(torch.bool)
+         x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask)
+         # fc
+         x = self.fc(x)
+         # temporal average pooling
+         w = self.temporal_avg_pool(x, mask=x_mask)
+
+         return w
models/text_encoder.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ import torch.nn as nn
+
+ from models.dit import DiTConVBlock
+
+ def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+ # modified from https://github.com/jaywalnut310/vits/blob/main/models.py
+ class TextEncoder(nn.Module):
+     def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
+         super().__init__()
+         self.n_vocab = n_vocab
+         self.out_channels = out_channels
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+
+         self.scale = self.hidden_channels ** 0.5
+
+         self.emb = nn.Embedding(n_vocab, hidden_channels)
+         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+
+         self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
+         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         for block in self.encoder:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+     def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
+         x = self.emb(x) * self.scale  # [b, t, h]
+         x = x.transpose(1, -1)  # [b, h, t]
+         x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
+
+         for layer in self.encoder:
+             x = layer(x, c, x_mask)
+         mu_x = self.proj(x) * x_mask
+
+         return x, mu_x, x_mask
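`sequence_mask` (defined identically here, in flow_matching.py, and in model.py) builds a boolean padding mask from per-item lengths by broadcasting an arange against them. A quick self-contained demonstration (illustrative):

import torch

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

print(sequence_mask(torch.tensor([2, 4])))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])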
monotonic_align/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from numpy import zeros, int32, float32
+ from torch import from_numpy
+
+ from .core import maximum_path_jit
+
+
+ def maximum_path(neg_cent, mask):
+     device = neg_cent.device
+     dtype = neg_cent.dtype
+     neg_cent = neg_cent.data.cpu().numpy().astype(float32)
+     path = zeros(neg_cent.shape, dtype=int32)
+
+     t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
+     t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
+     maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
+     return from_numpy(path).to(device=device, dtype=dtype)
monotonic_align/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (732 Bytes).

monotonic_align/__pycache__/core.cpython-310.pyc ADDED
Binary file (985 Bytes).
 
monotonic_align/core.py ADDED
@@ -0,0 +1,46 @@
+ import numba
+
+
+ @numba.jit(
+     numba.void(
+         numba.int32[:, :, ::1],
+         numba.float32[:, :, ::1],
+         numba.int32[::1],
+         numba.int32[::1],
+     ),
+     nopython=True,
+     nogil=True,
+ )
+ def maximum_path_jit(paths, values, t_ys, t_xs):
+     b = paths.shape[0]
+     max_neg_val = -1e9
+     for i in range(int(b)):
+         path = paths[i]
+         value = values[i]
+         t_y = t_ys[i]
+         t_x = t_xs[i]
+
+         v_prev = v_cur = 0.0
+         index = t_x - 1
+
+         # forward pass: accumulate the best attainable score for each cell
+         for y in range(t_y):
+             for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+                 if x == y:
+                     v_cur = max_neg_val
+                 else:
+                     v_cur = value[y - 1, x]
+                 if x == 0:
+                     if y == 0:
+                         v_prev = 0.0
+                     else:
+                         v_prev = max_neg_val
+                 else:
+                     v_prev = value[y - 1, x - 1]
+                 value[y, x] += max(v_prev, v_cur)
+
+         # backtracking: recover the monotonic path with the highest total score
+         for y in range(t_y - 1, -1, -1):
+             path[y, index] = 1
+             if index != 0 and (
+                 index == y or value[y - 1, index] < value[y - 1, index - 1]
+             ):
+                 index = index - 1
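MAS maximizes the summed log-likelihood along a monotonic path: the forward pass accumulates the best attainable score per cell, and the backtracking loop recovers the path from the last column. A toy run of the wrapper (illustrative; requires numba):

import torch
from monotonic_align import maximum_path

# neg_cent[b, t_mel, t_text]: higher = better match between frame and token
neg_cent = torch.tensor([[[0.9, 0.1],
                          [0.8, 0.2],
                          [0.1, 0.9]]])
mask = torch.ones(1, 3, 2)
print(maximum_path(neg_cent, mask)[0])
# tensor([[1., 0.],
#         [1., 0.],
#         [0., 1.]])  — monotonic: token 0 covers frames 0-1, token 1 covers frame 2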
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ torchaudio
+ matplotlib
+ numpy
+ tensorboard
+ pypinyin
+ jieba
+ eng_to_ipa
+ unidecode
+ inflect
+ pyopenjtalk-prebuilt
text/LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2017 Keith Ito
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
text/__init__.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """ from https://github.com/keithito/tacotron """
+ from text import cleaners
+ from text.symbols import symbols
+
+
+ # Mappings from symbol to numeric ID and vice versa:
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+
+ def text_to_sequence(text, symbols, cleaner_names):
+     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+     Args:
+         text: string to convert to a sequence
+         symbols: list of symbols making up the vocabulary
+         cleaner_names: names of the cleaner functions to run the text through
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     sequence = []
+     symbol_to_id = {s: i for i, s in enumerate(symbols)}
+     clean_text = _clean_text(text, cleaner_names)
+     for symbol in clean_text:
+         if symbol not in symbol_to_id:
+             continue
+         symbol_id = symbol_to_id[symbol]
+         sequence += [symbol_id]
+     return sequence
+
+
+ def cleaned_text_to_sequence(cleaned_text):
+     '''Converts a string of already-cleaned text to a sequence of symbol IDs.
+     Args:
+         cleaned_text: cleaned string to convert to a sequence
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id]
+     return sequence
+
+ def cleaned_text_to_sequence_chinese(cleaned_text):
+     '''Converts a string of already-cleaned, space-separated symbols to a sequence of symbol IDs.
+     Args:
+         cleaned_text: cleaned string of space-separated symbols
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id]
+     return sequence
+
+
+ def sequence_to_text(sequence):
+     '''Converts a sequence of IDs back to a string'''
+     result = ''
+     for symbol_id in sequence:
+         s = _id_to_symbol[symbol_id]
+         result += s
+     return result
+
+
+ def _clean_text(text, cleaner_names):
+     for name in cleaner_names:
+         cleaner = getattr(cleaners, name)
+         if not cleaner:
+             raise Exception('Unknown cleaner: %s' % name)
+         text = cleaner(text)
+     return text
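A quick usage sketch for the helpers above (that the symbol inventory comes from text/symbols.py and that cjke_cleaners4 is the intended cleaner are assumptions about how this repo is wired together):

    from text import text_to_sequence, sequence_to_text
    from text.symbols import symbols

    # Clean the text, then map each resulting symbol to its integer ID;
    # symbols missing from the vocabulary are silently skipped.
    ids = text_to_sequence("hello world", symbols, ["cjke_cleaners4"])
    print(sequence_to_text(ids))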
text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.62 kB).
text/__pycache__/cleaners.cpython-310.pyc ADDED
Binary file (2.54 kB).
text/__pycache__/english.cpython-310.pyc ADDED
Binary file (4.4 kB).
text/cleaners.py ADDED
@@ -0,0 +1,73 @@
+ import re
+
+ from text.english import english_to_ipa2
+ from text.mandarin import chinese_to_cnm3
+ from text.japanese import japanese_to_ipa2
+
+ language_module_map = {"PAD": 0, "ZH": 1, "EN": 2, "JA": 3}
+
+ # Precompiled regular expressions
+ ZH_PATTERN = re.compile(r'[\u3400-\u4DBF\u4e00-\u9FFF\uF900-\uFAFF\u3000-\u303F]')
+ # '0-9' is kept explicit and '-' moved to the end of the class so it is a literal, not an accidental range
+ EN_PATTERN = re.compile(r'[a-zA-Z0-9.,!?\'"(){}[\]<>:;@#$%^&*_+=/\\|~`-]+')
+ JP_PATTERN = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF\u31F0-\u31FF\uFF00-\uFFEF\u3000-\u303F]')
+ CLEANER_PATTERN = re.compile(r'\[(ZH|EN|JA)\]')
+
+ def detect_language(text: str, prev_lang=None):
+     """
+     Detect the language of the given text.
+
+     :param text: input text
+     :param prev_lang: previously detected language
+     :return: 'ZH' for Chinese, 'EN' for English, 'JA' for Japanese, or prev_lang for spaces
+     """
+     if ZH_PATTERN.search(text): return 'ZH'
+     if EN_PATTERN.search(text): return 'EN'
+     if JP_PATTERN.search(text): return 'JA'
+     if text.isspace(): return prev_lang  # for whitespace, return the previous language
+     return None
+
+ def replace_substring(s, start_index, end_index, replacement):
+     return s[:start_index] + replacement + s[end_index:]
+
+ def replace_sublist(lst, start_index, end_index, replacement_list):
+     lst[start_index:end_index] = replacement_list
+
+ # convert text to ipa and prepare for language embedding
+ def append_tags_and_convert(match, conversion_func, tag_value, tags):
+     converted_text = conversion_func(match.group(1))
+     tags.extend([tag_value] * len(converted_text))
+     return converted_text + ' '
+
+ # auto detect language using re
+ def cjke_cleaners4(text: str):
+     """
+     Automatically detect the language of the text and convert it to IPA.
+
+     :param text: input text
+     :return: text converted to IPA
+     """
+     text = CLEANER_PATTERN.sub('', text)
+     if not text:  # guard against indexing into an empty string below
+         return ''
+     pointer = 0
+     output = ''
+     current_language = detect_language(text[pointer])
+
+     while pointer < len(text):
+         temp_text = ''
+         while pointer < len(text) and detect_language(text[pointer], current_language) == current_language:
+             temp_text += text[pointer]
+             pointer += 1
+         if current_language == 'ZH':
+             output += chinese_to_cnm3(temp_text)
+         elif current_language == 'JA':
+             output += japanese_to_ipa2(temp_text)
+         elif current_language == 'EN':
+             output += english_to_ipa2(temp_text)
+         if pointer < len(text):
+             current_language = detect_language(text[pointer])
+
+     output = re.sub(r'\s+$', '', output)
+     output = re.sub(r'([^\.,!\?\-…~])$', r'\1.', output)
+     return output
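The detection heuristics above can be exercised directly; a small sketch (note that kanji share code points with the Chinese ranges, so ZH_PATTERN wins for them and the check order matters):

    from text.cleaners import detect_language, cjke_cleaners4

    assert detect_language("你") == "ZH"
    assert detect_language("a") == "EN"
    assert detect_language("あ") == "JA"
    # Whitespace inherits the previously detected language.
    assert detect_language(" ", prev_lang="EN") == "EN"

    # Mixed input is segmented per language, each run converted to IPA.
    ipa = cjke_cleaners4("你好 hello こんにちは")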
text/cn2an/__init__.py ADDED
@@ -0,0 +1,16 @@
+ __version__ = "0.5.20"
+
+ from .cn2an import Cn2An
+ from .an2cn import An2Cn
+ from .transform import Transform
+
+ cn2an = Cn2An().cn2an
+ an2cn = An2Cn().an2cn
+ transform = Transform().transform
+
+ __all__ = [
+     "__version__",
+     "cn2an",
+     "an2cn",
+     "transform"
+ ]
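The three module-level callables bound here are the public surface of this vendored copy of the cn2an library. Expected behavior, going by the upstream project's documentation (treat the outputs as illustrative):

    from text import cn2an

    print(cn2an.cn2an("一百二十三"))   # 123
    print(cn2an.an2cn(123))            # 一百二十三
    # transform rewrites numbers inside running text; the default
    # direction ("cn2an") turns Chinese numerals into Arabic digits.
    print(cn2an.transform("我捡了一百块钱"))  # expected: 我捡了100块钱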
text/cn2an/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (525 Bytes).
text/cn2an/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (354 Bytes).
text/cn2an/__pycache__/an2cn.cpython-311.pyc ADDED
Binary file (8.64 kB).
text/cn2an/__pycache__/an2cn.cpython-38.pyc ADDED
Binary file (4.73 kB).
text/cn2an/__pycache__/cn2an.cpython-311.pyc ADDED
Binary file (14.8 kB).
text/cn2an/__pycache__/cn2an.cpython-38.pyc ADDED
Binary file (7.4 kB).
text/cn2an/__pycache__/conf.cpython-311.pyc ADDED
Binary file (2.02 kB).
text/cn2an/__pycache__/conf.cpython-38.pyc ADDED
Binary file (1.61 kB).
text/cn2an/__pycache__/transform.cpython-311.pyc ADDED
Binary file (9.99 kB).
text/cn2an/__pycache__/transform.cpython-38.pyc ADDED
Binary file (4.82 kB).
text/cn2an/an2cn.py ADDED
@@ -0,0 +1,203 @@
+ from typing import Union
+
+ # from proces import preprocess
+
+ from .conf import NUMBER_LOW_AN2CN, NUMBER_UP_AN2CN, UNIT_LOW_ORDER_AN2CN, UNIT_UP_ORDER_AN2CN
+
+
+ class An2Cn(object):
+     def __init__(self) -> None:
+         self.all_num = "0123456789"
+         self.number_low = NUMBER_LOW_AN2CN
+         self.number_up = NUMBER_UP_AN2CN
+         self.mode_list = ["low", "up", "rmb", "direct"]
+
+     def an2cn(self, inputs: Union[str, int, float] = None, mode: str = "low") -> str:
+         """Convert Arabic numerals to Chinese numerals.
+
+         :param inputs: Arabic number
+         :param mode: low = everyday numerals, up = banker's numerals, rmb = RMB currency form, direct = digit-by-digit
+         :return: Chinese numeral string
+         """
+         if inputs is not None and inputs != "":
+             if mode not in self.mode_list:
+                 raise ValueError(f"mode only supports {str(self.mode_list)}!")
+
+             # Convert the number to a string; note that Python normalizes some inputs:
+             # 1. -> 1.0, 1.00 -> 1.0, -0 -> 0
+             if not isinstance(inputs, str):
+                 inputs = self.__number_to_string(inputs)
+
+             # Preprocessing (disabled here):
+             # 1. traditional -> simplified Chinese
+             # 2. full-width -> half-width characters
+             # inputs = preprocess(inputs, pipelines=[
+             #     "traditional_to_simplified",
+             #     "full_angle_to_half_angle"
+             # ])
+
+             # Validate the input
+             self.__check_inputs_is_valid(inputs)
+
+             # Determine the sign
+             if inputs[0] == "-":
+                 sign = "负"
+                 inputs = inputs[1:]
+             else:
+                 sign = ""
+
+             if mode == "direct":
+                 output = self.__direct_convert(inputs)
+             else:
+                 # Split into integer and decimal parts
+                 split_result = inputs.split(".")
+                 len_split_result = len(split_result)
+                 if len_split_result == 1:
+                     # Input without a decimal part
+                     integer_data = split_result[0]
+                     if mode == "rmb":
+                         output = self.__integer_convert(integer_data, "up") + "元整"
+                     else:
+                         output = self.__integer_convert(integer_data, mode)
+                 elif len_split_result == 2:
+                     # Input with a decimal part
+                     integer_data, decimal_data = split_result
+                     if mode == "rmb":
+                         int_data = self.__integer_convert(integer_data, "up")
+                         dec_data = self.__decimal_convert(decimal_data, "up")
+                         len_dec_data = len(dec_data)
+
+                         if len_dec_data == 0:
+                             output = int_data + "元整"
+                         elif len_dec_data == 1:
+                             raise ValueError(f"Unexpected output: {dec_data}")
+                         elif len_dec_data == 2:
+                             if dec_data[1] != "零":
+                                 if int_data == "零":
+                                     output = dec_data[1] + "角"
+                                 else:
+                                     output = int_data + "元" + dec_data[1] + "角"
+                             else:
+                                 output = int_data + "元整"
+                         else:
+                             if dec_data[1] != "零":
+                                 if dec_data[2] != "零":
+                                     if int_data == "零":
+                                         output = dec_data[1] + "角" + dec_data[2] + "分"
+                                     else:
+                                         output = int_data + "元" + dec_data[1] + "角" + dec_data[2] + "分"
+                                 else:
+                                     if int_data == "零":
+                                         output = dec_data[1] + "角"
+                                     else:
+                                         output = int_data + "元" + dec_data[1] + "角"
+                             else:
+                                 if dec_data[2] != "零":
+                                     if int_data == "零":
+                                         output = dec_data[2] + "分"
+                                     else:
+                                         output = int_data + "元" + "零" + dec_data[2] + "分"
+                                 else:
+                                     output = int_data + "元整"
+                 else:
+                     output = self.__integer_convert(integer_data, mode) + self.__decimal_convert(decimal_data, mode)
+             else:
+                 raise ValueError(f"Invalid input format: {inputs}!")
+         else:
+             raise ValueError("Input is empty!")
+
+         return sign + output
+
+     def __direct_convert(self, inputs: str) -> str:
+         _output = ""
+         for d in inputs:
+             if d == ".":
+                 _output += "点"
+             else:
+                 _output += self.number_low[int(d)]
+         return _output
+
+     @staticmethod
+     def __number_to_string(number_data: Union[int, float]) -> str:
+         # Decimal handling: Python renders 0.00005 as 5e-05, so str(0.00005) != "0.00005"
+         string_data = str(number_data)
+         if "e" in string_data:
+             string_data_list = string_data.split("e")
+             string_key = string_data_list[0]
+             string_value = string_data_list[1]
+             if string_value[0] == "-":
+                 string_data = "0." + "0" * (int(string_value[1:]) - 1) + string_key
+             else:
+                 string_data = string_key + "0" * int(string_value)
+         return string_data
+
+     def __check_inputs_is_valid(self, check_data: str) -> None:
+         # Check that every character of the input is in the allowed set
+         all_check_keys = self.all_num + ".-"
+         for data in check_data:
+             if data not in all_check_keys:
+                 raise ValueError(f"Input character outside the conversion range: {data}!")
+
+     def __integer_convert(self, integer_data: str, mode: str) -> str:
+         if mode == "low":
+             numeral_list = NUMBER_LOW_AN2CN
+             unit_list = UNIT_LOW_ORDER_AN2CN
+         elif mode == "up":
+             numeral_list = NUMBER_UP_AN2CN
+             unit_list = UNIT_UP_ORDER_AN2CN
+         else:
+             raise ValueError(f"error mode: {mode}")
+
+         # Strip leading zeros, e.g. 007 => 7
+         integer_data = str(int(integer_data))
+
+         len_integer_data = len(integer_data)
+         if len_integer_data > len(unit_list):
+             raise ValueError(f"Out of range; at most {len(unit_list)} digits are supported")
+
+         output_an = ""
+         for i, d in enumerate(integer_data):
+             if int(d):
+                 output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
+             else:
+                 if not (len_integer_data - i - 1) % 4:
+                     output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
+
+                 if i > 0 and not output_an[-1] == "零":
+                     output_an += numeral_list[int(d)]
+
+         output_an = output_an.replace("零零", "零").replace("零万", "万").replace("零亿", "亿").replace("亿万", "亿") \
+             .strip("零")
+
+         # Normalize a leading 一十 to 十 (e.g. 一十一 -> 十一)
+         if output_an[:2] in ["一十"]:
+             output_an = output_an[1:]
+
+         # Decimals between 0 and 1 leave an empty integer part
+         if not output_an:
+             output_an = "零"
+
+         return output_an
+
+     def __decimal_convert(self, decimal_data: str, o_mode: str) -> str:
+         len_decimal_data = len(decimal_data)
+
+         if len_decimal_data > 16:
+             print(f"Note: the decimal part has {len_decimal_data} digits; truncating to the first 16 significant digits!")
+             decimal_data = decimal_data[:16]
+
+         if len_decimal_data:
+             output_an = "点"
+         else:
+             output_an = ""
+
+         if o_mode == "low":
+             numeral_list = NUMBER_LOW_AN2CN
+         elif o_mode == "up":
+             numeral_list = NUMBER_UP_AN2CN
+         else:
+             raise ValueError(f"error mode: {o_mode}")
+
+         for data in decimal_data:
+             output_an += numeral_list[int(data)]
+         return output_an
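A sketch of the four An2Cn modes (expected outputs per upstream cn2an; illustrative, not verified against this copy):

    from text.cn2an import an2cn

    an2cn(123)             # "一百二十三"           (low: everyday numerals)
    an2cn(123, "up")       # "壹佰贰拾叁"           (up: banker's numerals)
    an2cn(123.45, "rmb")   # "壹佰贰拾叁元肆角伍分"  (rmb: currency form)
    an2cn(1.5, "direct")   # "一点五"               (direct: digit by digit)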
text/cn2an/cn2an.py ADDED
@@ -0,0 +1,293 @@
+ import re
+ from typing import Tuple, Union
+
+ # from proces import preprocess
+
+ from .an2cn import An2Cn
+ from .conf import NUMBER_CN2AN, UNIT_CN2AN, STRICT_CN_NUMBER, NORMAL_CN_NUMBER, NUMBER_LOW_AN2CN, UNIT_LOW_AN2CN
+
+
+ class Cn2An(object):
+     def __init__(self) -> None:
+         self.all_num = "".join(list(NUMBER_CN2AN.keys()))
+         self.all_unit = "".join(list(UNIT_CN2AN.keys()))
+         self.strict_cn_number = STRICT_CN_NUMBER
+         self.normal_cn_number = NORMAL_CN_NUMBER
+         self.check_key_dict = {
+             "strict": "".join(self.strict_cn_number.values()) + "点负",
+             "normal": "".join(self.normal_cn_number.values()) + "点负",
+             "smart": "".join(self.normal_cn_number.values()) + "点负" + "0123456789.-"
+         }
+         self.pattern_dict = self.__get_pattern()
+         self.ac = An2Cn()
+         self.mode_list = ["strict", "normal", "smart"]
+         self.yjf_pattern = re.compile(fr"^.*?[元圆][{self.all_num}]角([{self.all_num}]分)?$")
+         self.pattern1 = re.compile(fr"^-?\d+(\.\d+)?[{self.all_unit}]?$")
+         self.ptn_all_num = re.compile(f"^[{self.all_num}]+$")
+         # "十?" is for special case "十一万三"
+         self.ptn_speaking_mode = re.compile(f"^([{self.all_num}]{{0,2}}[{self.all_unit}])+[{self.all_num}]$")
+
+     def cn2an(self, inputs: Union[str, int, float] = None, mode: str = "strict") -> Union[float, int]:
+         """Convert Chinese numerals to Arabic numerals.
+
+         :param inputs: Chinese numerals, Arabic numerals, or a mixture of the two
+         :param mode: strict, normal, or smart
+         :return: Arabic number
+         """
+         if inputs is not None and inputs != "":
+             if mode not in self.mode_list:
+                 raise ValueError(f"mode only supports {str(self.mode_list)}!")
+
+             # Convert numeric input to a string
+             if not isinstance(inputs, str):
+                 inputs = str(inputs)
+
+             # Preprocessing (disabled here):
+             # 1. traditional -> simplified Chinese
+             # 2. full-width -> half-width characters
+             # inputs = preprocess(inputs, pipelines=[
+             #     "traditional_to_simplified",
+             #     "full_angle_to_half_angle"
+             # ])
+
+             # Special conversion for 廿 (= 二十)
+             inputs = inputs.replace("廿", "二十")
+
+             # Validate the input
+             sign, integer_data, decimal_data, is_all_num = self.__check_input_data_is_valid(inputs, mode)
+
+             # Special case under smart mode
+             if sign == 0:
+                 return integer_data
+             else:
+                 if not is_all_num:
+                     if decimal_data is None:
+                         output = self.__integer_convert(integer_data)
+                     else:
+                         output = self.__integer_convert(integer_data) + self.__decimal_convert(decimal_data)
+                         # fix 1 + 0.57 = 1.5699999999999998
+                         output = round(output, len(decimal_data))
+                 else:
+                     if decimal_data is None:
+                         output = self.__direct_convert(integer_data)
+                     else:
+                         output = self.__direct_convert(integer_data) + self.__decimal_convert(decimal_data)
+                         # fix 1 + 0.57 = 1.5699999999999998
+                         output = round(output, len(decimal_data))
+         else:
+             raise ValueError("Input is empty!")
+
+         return sign * output
+
+     def __get_pattern(self) -> dict:
+         # Strict integer patterns
+         _0 = "[零]"
+         _1_9 = "[一二三四五六七八九]"
+         _10_99 = f"{_1_9}?[十]{_1_9}?"
+         _1_99 = f"({_10_99}|{_1_9})"
+         _100_999 = f"({_1_9}[百]([零]{_1_9})?|{_1_9}[百]{_10_99})"
+         _1_999 = f"({_100_999}|{_1_99})"
+         _1000_9999 = f"({_1_9}[千]([零]{_1_99})?|{_1_9}[千]{_100_999})"
+         _1_9999 = f"({_1000_9999}|{_1_999})"
+         _10000_99999999 = f"({_1_9999}[万]([零]{_1_999})?|{_1_9999}[万]{_1000_9999})"
+         _1_99999999 = f"({_10000_99999999}|{_1_9999})"
+         _100000000_9999999999999999 = f"({_1_99999999}[亿]([零]{_1_99999999})?|{_1_99999999}[亿]{_10000_99999999})"
+         _1_9999999999999999 = f"({_100000000_9999999999999999}|{_1_99999999})"
+         str_int_pattern = f"^({_0}|{_1_9999999999999999})$"
+         nor_int_pattern = f"^({_0}|{_1_9999999999999999})$"
+
+         str_dec_pattern = "^[零一二三四五六七八九]{0,15}[一二三四五六七八九]$"
+         nor_dec_pattern = "^[零一二三四五六七八九]{0,16}$"
+
+         for str_num in self.strict_cn_number.keys():
+             str_int_pattern = str_int_pattern.replace(str_num, self.strict_cn_number[str_num])
+             str_dec_pattern = str_dec_pattern.replace(str_num, self.strict_cn_number[str_num])
+         for nor_num in self.normal_cn_number.keys():
+             nor_int_pattern = nor_int_pattern.replace(nor_num, self.normal_cn_number[nor_num])
+             nor_dec_pattern = nor_dec_pattern.replace(nor_num, self.normal_cn_number[nor_num])
+
+         pattern_dict = {
+             "strict": {
+                 "int": re.compile(str_int_pattern),
+                 "dec": re.compile(str_dec_pattern)
+             },
+             "normal": {
+                 "int": re.compile(nor_int_pattern),
+                 "dec": re.compile(nor_dec_pattern)
+             }
+         }
+         return pattern_dict
+
+     def __copy_num(self, num):
+         cn_num = ""
+         for n in num:
+             cn_num += NUMBER_LOW_AN2CN[int(n)]
+         return cn_num
+
+     def __check_input_data_is_valid(self, check_data: str, mode: str) -> Tuple[int, str, str, bool]:
+         # Strip trailing 元整/圆整/元正/圆正
+         stop_words = ["元整", "圆整", "元正", "圆正"]
+         for word in stop_words:
+             if check_data[-2:] == word:
+                 check_data = check_data[:-2]
+
+         # Strip trailing 元/圆
+         if mode != "strict":
+             normal_stop_words = ["圆", "元"]
+             for word in normal_stop_words:
+                 if check_data[-1] == word:
+                     check_data = check_data[:-1]
+
+         # Handle 元/角/分 currency forms
+         result = self.yjf_pattern.search(check_data)
+         if result:
+             check_data = check_data.replace("元", "点").replace("角", "").replace("分", "")
+
+         # Handle special spoken forms, e.g. 一千零十一, 一万零百一十一
+         if "零十" in check_data:
+             check_data = check_data.replace("零十", "零一十")
+         if "零百" in check_data:
+             check_data = check_data.replace("零百", "零一百")
+
+         for data in check_data:
+             if data not in self.check_key_dict[mode]:
+                 raise ValueError(f"In {mode} mode, the input contains a character outside the conversion range: {data}!")
+
+         # Determine the sign
+         if check_data[0] == "负":
+             check_data = check_data[1:]
+             sign = -1
+         else:
+             sign = 1
+
+         if "点" in check_data:
+             split_data = check_data.split("点")
+             if len(split_data) == 2:
+                 integer_data, decimal_data = split_data
+                 # In smart mode, convert Arabic digits to Chinese numerals first
+                 if mode == "smart":
+                     integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
+                     decimal_data = re.sub(r"\d+", lambda x: self.__copy_num(x.group()), decimal_data)
+                     mode = "normal"
+             else:
+                 raise ValueError("The input contains more than one decimal point!")
+         else:
+             integer_data = check_data
+             decimal_data = None
+             # In smart mode, convert Arabic digits to Chinese numerals first
+             if mode == "smart":
+                 # e.g. 10.1万, 10.1
+                 result1 = self.pattern1.search(integer_data)
+                 if result1:
+                     if result1.group() == integer_data:
+                         if integer_data[-1] in UNIT_CN2AN.keys():
+                             output = int(float(integer_data[:-1]) * UNIT_CN2AN[integer_data[-1]])
+                         else:
+                             output = float(integer_data)
+                         return 0, output, None, None
+
+                 integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
+                 mode = "normal"
+
+         result_int = self.pattern_dict[mode]["int"].search(integer_data)
+         if result_int:
+             if result_int.group() == integer_data:
+                 if decimal_data is not None:
+                     result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
+                     if result_dec:
+                         if result_dec.group() == decimal_data:
+                             return sign, integer_data, decimal_data, False
+                 else:
+                     return sign, integer_data, decimal_data, False
+         else:
+             if mode == "strict":
+                 raise ValueError(f"Data that does not match the expected format: {integer_data}")
+             elif mode == "normal":
+                 # digit-sequence mode, e.g. 一二三
+                 result_all_num = self.ptn_all_num.search(integer_data)
+                 if result_all_num:
+                     if result_all_num.group() == integer_data:
+                         if decimal_data is not None:
+                             result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
+                             if result_dec:
+                                 if result_dec.group() == decimal_data:
+                                     return sign, integer_data, decimal_data, True
+                         else:
+                             return sign, integer_data, decimal_data, True
+
+                 # spoken mode, e.g. 一万二, 两千三, 三百四, 十三万六, 一百二十五万三
+                 result_speaking_mode = self.ptn_speaking_mode.search(integer_data)
+                 if len(integer_data) >= 3 and result_speaking_mode and result_speaking_mode.group() == integer_data:
+                     # len(integer_data) >= 3: because the minimum length of integer_data that can be matched is 3
+                     # to find the last unit
+                     last_unit = result_speaking_mode.groups()[-1][-1]
+                     _unit = UNIT_LOW_AN2CN[UNIT_CN2AN[last_unit] // 10]
+                     integer_data = integer_data + _unit
+                     if decimal_data is not None:
+                         result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
+                         if result_dec:
+                             if result_dec.group() == decimal_data:
+                                 return sign, integer_data, decimal_data, False
+                     else:
+                         return sign, integer_data, decimal_data, False
+
+         raise ValueError(f"Data that does not match the expected format: {check_data}")
+
+     def __integer_convert(self, integer_data: str) -> int:
+         # Core conversion
+         output_integer = 0
+         unit = 1
+         ten_thousand_unit = 1
+         for index, cn_num in enumerate(reversed(integer_data)):
+             # numeral
+             if cn_num in NUMBER_CN2AN:
+                 num = NUMBER_CN2AN[cn_num]
+                 output_integer += num * unit
+             # unit
+             elif cn_num in UNIT_CN2AN:
+                 unit = UNIT_CN2AN[cn_num]
+                 # Distinguish 万, 亿, and 万亿
+                 if unit % 10000 == 0:
+                     # 万 or 亿
+                     if unit > ten_thousand_unit:
+                         ten_thousand_unit = unit
+                     # 万亿
+                     else:
+                         ten_thousand_unit = unit * ten_thousand_unit
+                         unit = ten_thousand_unit
+
+                 if unit < ten_thousand_unit:
+                     unit = unit * ten_thousand_unit
+
+                 if index == len(integer_data) - 1:
+                     output_integer += unit
+             else:
+                 raise ValueError(f"{cn_num} is outside the conversion range")
+
+         return int(output_integer)
+
+     def __decimal_convert(self, decimal_data: str) -> float:
+         len_decimal_data = len(decimal_data)
+
+         if len_decimal_data > 16:
+             print(f"Note: the decimal part has {len_decimal_data} digits; truncating to the first 16 significant digits!")
+             decimal_data = decimal_data[:16]
+             len_decimal_data = 16
+
+         output_decimal = 0
+         for index in range(len(decimal_data) - 1, -1, -1):
+             unit_key = NUMBER_CN2AN[decimal_data[index]]
+             output_decimal += unit_key * 10 ** -(index + 1)
+
+         # Handle floating-point precision overflow
+         output_decimal = round(output_decimal, len_decimal_data)
+
+         return output_decimal
+
+     def __direct_convert(self, data: str) -> int:
+         output_data = 0
+         for index in range(len(data) - 1, -1, -1):
+             unit_key = NUMBER_CN2AN[data[index]]
+             output_data += unit_key * 10 ** (len(data) - index - 1)
+
+         return output_data
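And the three parsing modes in the opposite direction (again, expected values per the upstream documentation):

    from text.cn2an import cn2an

    cn2an("一百二十三")           # 123   (strict: only fully well-formed numerals)
    cn2an("一二三", "normal")     # 123   (normal: also accepts plain digit sequences)
    cn2an("一万二", "smart")      # 12000 (spoken shorthand gets its unit restored)
    cn2an("1百23", "smart")       # 123   (smart: mixed Arabic/Chinese input)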