KdaiP committed on
Commit d358e26
1 Parent(s): a8e40a4

Upload 238 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete changeset.
Files changed (50)
  1. .gitattributes +6 -0
  2. app.py +138 -0
  3. audios/audio0.flac +0 -0
  4. audios/audio1.flac +0 -0
  5. audios/audio2.flac +0 -0
  6. audios/audio3.flac +0 -0
  7. audios/audio4.flac +0 -0
  8. audios/audio5.flac +0 -0
  9. audios/audio6.flac +0 -0
  10. checkpoints/.keep +0 -0
  11. checkpoints/checkpoint_0.pt +3 -0
  12. checkpoints/vocoder.pt +3 -0
  13. config.py +52 -0
  14. datas/__init__.py +0 -0
  15. datas/__pycache__/__init__.cpython-311.pyc +0 -0
  16. datas/__pycache__/dataset.cpython-311.pyc +0 -0
  17. datas/dataset.py +52 -0
  18. datas/sampler.py +121 -0
  19. models/__init__.py +0 -0
  20. models/__pycache__/__init__.cpython-310.pyc +0 -0
  21. models/__pycache__/__init__.cpython-311.pyc +0 -0
  22. models/__pycache__/dit.cpython-311.pyc +0 -0
  23. models/__pycache__/duration_predictor.cpython-311.pyc +0 -0
  24. models/__pycache__/estimator.cpython-311.pyc +0 -0
  25. models/__pycache__/flow_matching.cpython-311.pyc +0 -0
  26. models/__pycache__/model.cpython-310.pyc +0 -0
  27. models/__pycache__/model.cpython-311.pyc +0 -0
  28. models/__pycache__/reference_encoder.cpython-311.pyc +0 -0
  29. models/__pycache__/text_encoder.cpython-311.pyc +0 -0
  30. models/dit.py +205 -0
  31. models/duration_predictor.py +40 -0
  32. models/estimator.py +161 -0
  33. models/flow_matching.py +108 -0
  34. models/model.py +194 -0
  35. models/reference_encoder.py +93 -0
  36. models/text_encoder.py +49 -0
  37. monotonic_align/__init__.py +16 -0
  38. monotonic_align/__pycache__/__init__.cpython-310.pyc +0 -0
  39. monotonic_align/__pycache__/__init__.cpython-311.pyc +0 -0
  40. monotonic_align/__pycache__/core.cpython-310.pyc +0 -0
  41. monotonic_align/__pycache__/core.cpython-311.pyc +0 -0
  42. monotonic_align/core.py +46 -0
  43. requirements.txt +12 -0
  44. text/LICENSE +19 -0
  45. text/__init__.py +71 -0
  46. text/__pycache__/__init__.cpython-310.pyc +0 -0
  47. text/__pycache__/__init__.cpython-311.pyc +0 -0
  48. text/__pycache__/cleaners.cpython-310.pyc +0 -0
  49. text/__pycache__/cleaners.cpython-311.pyc +0 -0
  50. text/__pycache__/english.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_0.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_1.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
+text/custom_pypinyin_dict/__pycache__/cc_cedict_2.cpython-38.pyc filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,138 @@
import os

from dataclasses import asdict
from text import symbols
import torch
import torchaudio

from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.english import english_to_ipa2
from text import cleaned_text_to_sequence
from datas.dataset import intersperse

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

device = 'cpu'

@torch.inference_mode()
def inference(text: str, ref_audio: torch.Tensor, checkpoint_path: str, step: int=10) -> torch.Tensor:
    global last_checkpoint_path
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path

    phonemizer = english_to_ipa2

    # prepare input for tts model
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)

    # inference
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)

    # process output for gradio
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
    return audio_output, mel_output

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
    # tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

def plot_mel_spectrogram(mel_spectrogram):
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
    return fig


def main():
    tts_model_config = ModelConfig()
    mel_config = MelConfig()
    vocoder_config = VocosConfig()

    tts_checkpoint_path = './checkpoints' # the folder that contains StableTTS checkpoints
    vocoder_checkpoint_path = './checkpoints/vocoder.pt'

    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)

    # keep only StableTTS checkpoints (skip optimizer and vocoder files)
    tts_checkpoint_path = [path for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
    audios = list(Path('./audios').rglob('*.wav')) + list(Path('./audios').rglob('*.flac'))

    # gradio webui
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
    with gr.Blocks(analytics_enabled=False) as demo:

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="Today I want to tell you three stories from my life. That's it. No big deal. Just three stories.",
                )

                ref_audio_gr = gr.Dropdown(
                    label='reference audio',
                    choices=audios,
                    value = 0
                )

                checkpoint_gr = gr.Dropdown(
                    label='checkpoint',
                    choices=tts_checkpoint_path,
                    value = 0
                )

                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=40,
                    value=8,
                    step=1
                )

                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

            with gr.Column():
                mel_gr = gr.Plot(label="Mel Visual")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)

        tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])

    demo.queue()
    demo.launch(debug=True, show_api=True)


if __name__ == '__main__':
    main()
audios/audio0.flac ADDED
Binary file (84 kB).
 
audios/audio1.flac ADDED
Binary file (151 kB).
 
audios/audio2.flac ADDED
Binary file (318 kB).
 
audios/audio3.flac ADDED
Binary file (162 kB).
 
audios/audio4.flac ADDED
Binary file (260 kB).
 
audios/audio5.flac ADDED
Binary file (361 kB).
 
audios/audio6.flac ADDED
Binary file (99.1 kB).
 
checkpoints/.keep ADDED
File without changes
checkpoints/checkpoint_0.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c473c5dc40bf87e8472f8688790cb20a2ec10494fa08cb710d657cf1f892d44
size 37552627
checkpoints/vocoder.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e180a1df6ca0a9e382e0915b0b1984aecfb63397c4ee21f12857447c2d76d29a
size 56666508
config.py ADDED
@@ -0,0 +1,52 @@
from dataclasses import dataclass

@dataclass
class MelConfig:
    sample_rate: int = 44100
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 512
    f_min: float = 0.0
    f_max: float = None
    pad: int = 0
    n_mels: int = 128
    power: float = 1.0
    normalized: bool = False
    center: bool = False
    pad_mode: str = "reflect"
    mel_scale: str = "htk"

    def __post_init__(self):
        if self.pad == 0:
            self.pad = (self.n_fft - self.hop_length) // 2

@dataclass
class ModelConfig:
    hidden_channels: int = 192
    filter_channels: int = 512
    n_heads: int = 2
    n_enc_layers: int = 3
    n_dec_layers: int = 2
    kernel_size: int = 3
    p_dropout: float = 0.1
    gin_channels: int = 192

@dataclass
class TrainConfig:
    train_dataset_path: str = 'filelists/filelist.json'
    test_dataset_path: str = 'filelists/filelist.json'
    batch_size: int = 52
    learning_rate: float = 1e-4
    num_epochs: int = 10000
    model_save_path: str = './checkpoints'
    log_dir: str = './runs'
    log_interval: int = 128
    save_interval: int = 15
    warmup_steps: int = 200

@dataclass
class VocosConfig:
    input_channels: int = 128
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8
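Usage sketch (illustrative, mirroring how app.py consumes these configs via asdict; the numbers just follow the defaults above):

from dataclasses import asdict
from config import MelConfig, ModelConfig

mel = MelConfig()                                    # __post_init__ derives pad from n_fft and hop_length
assert mel.pad == (mel.n_fft - mel.hop_length) // 2  # (2048 - 512) // 2 = 768
model_kwargs = asdict(ModelConfig())                 # app.py unpacks this into StableTTS(n_vocab, n_mels, **model_kwargs)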
datas/__init__.py ADDED
File without changes
datas/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (154 Bytes).
 
datas/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (4.57 kB).
 
datas/dataset.py ADDED
@@ -0,0 +1,52 @@
import os

import json
import torch
from torch.utils.data import Dataset

from text import cleaned_text_to_sequence

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

class StableDataset(Dataset):
    def __init__(self, filelist_path, hop_length):
        self.filelist_path = filelist_path
        self.hop_length = hop_length

        self._load_filelist(filelist_path)

    def _load_filelist(self, filelist_path):
        filelist, lengths = [], []
        with open(filelist_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = json.loads(line.strip())
                filelist.append((line['mel_path'], line['phone']))
                lengths.append(os.path.getsize(line['audio_path']) // (2 * self.hop_length))

        self.filelist = filelist
        self.lengths = lengths

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        mel_path, phone = self.filelist[idx]
        mel = torch.load(mel_path, map_location='cpu')
        phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
        return mel, phone

def collate_fn(batch):
    texts = [item[1] for item in batch]
    mels = [item[0] for item in batch]

    text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
    mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)

    # pad to the same length
    texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
    mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)

    return texts_padded, text_lengths, mels_padded, mel_lengths
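Illustrative example of the intersperse helper above: it inserts a blank token (id 0) between phoneme ids and at both ends, so the output length is 2 * len(input) + 1:

from datas.dataset import intersperse

assert intersperse([5, 9, 7], 0) == [0, 5, 0, 9, 0, 7, 0]   # length 2 * 3 + 1 = 7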
datas/sampler.py ADDED
@@ -0,0 +1,121 @@
import torch

# reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        for i in range(len(buckets) - 1, 0, -1):
            # for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
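A hypothetical single-process usage sketch of the bucket sampler; the filelist path and boundary values are illustrative only and are not taken from this commit:

from torch.utils.data import DataLoader
from datas.dataset import StableDataset, collate_fn
from datas.sampler import DistributedBucketSampler

dataset = StableDataset('filelists/filelist.json', hop_length=512)
sampler = DistributedBucketSampler(dataset, batch_size=32,
                                   boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
                                   num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)

for epoch in range(10):
    sampler.set_epoch(epoch)   # reshuffles buckets deterministically per epoch
    for texts, text_lengths, mels, mel_lengths in loader:
        pass                   # training step goes here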
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (135 Bytes).
 
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (155 Bytes).
 
models/__pycache__/dit.cpython-311.pyc ADDED
Binary file (13.7 kB).
 
models/__pycache__/duration_predictor.cpython-311.pyc ADDED
Binary file (3.2 kB).
 
models/__pycache__/estimator.cpython-311.pyc ADDED
Binary file (12.7 kB).
 
models/__pycache__/flow_matching.cpython-311.pyc ADDED
Binary file (5.92 kB).
 
models/__pycache__/model.cpython-310.pyc ADDED
Binary file (6.49 kB).
 
models/__pycache__/model.cpython-311.pyc ADDED
Binary file (11.6 kB).
 
models/__pycache__/reference_encoder.cpython-311.pyc ADDED
Binary file (5.32 kB).
 
models/__pycache__/text_encoder.cpython-311.pyc ADDED
Binary file (4.22 kB).
 
models/dit.py ADDED
@@ -0,0 +1,205 @@
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = self.act1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)

        x = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
        key = self.key_rotary_pe(key)

        output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output

# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel]
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time]
        # attn_mask = attn_mask.to(torch.bool)

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1]
        x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
        x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)

        # no condition version
        # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
        # x = x + self.mlp(self.norm1(x.transpose(1,2)).transpose(1,2), x_mask)
        return x

    @staticmethod
    def modulate(x, shift, scale):
        return x * (1 + scale) + shift

class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # Return if cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # Get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # Cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
        """
        # Cache $\cos$ and $\sin$ values
        x = x.permute(2, 0, 1, 3) # b h t d -> t b h d

        self._build_cache(x)

        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d

class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
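Shape sketch for DiTConVBlock using the default ModelConfig sizes (batch and sequence length here are arbitrary placeholders):

import torch
from models.dit import DiTConVBlock

block = DiTConVBlock(hidden_channels=192, filter_channels=512, num_heads=2, gin_channels=192)
x = torch.randn(2, 192, 50)      # [batch, channel, time]
c = torch.randn(2, 192)          # [batch, gin_channels] style/speaker condition
x_mask = torch.ones(2, 1, 50)    # [batch, 1, time]
out = block(x, c, x_mask)        # adaLN-Zero modulated attention + FFN, same shape as x
assert out.shape == x.shape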
models/duration_predictor.py ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm_1 = nn.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm_2 = nn.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g):
        x = x.detach()
        x = x + self.cond(g.unsqueeze(2).detach())
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
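Shape sketch (channel sizes follow ModelConfig; the tensors are random placeholders):

import torch
from models.duration_predictor import DurationPredictor

dp = DurationPredictor(in_channels=192, filter_channels=512, kernel_size=3, p_dropout=0.1, gin_channels=192)
x = torch.randn(2, 192, 30)    # encoder hidden states (detached inside forward)
x_mask = torch.ones(2, 1, 30)
g = torch.randn(2, 192)        # style vector from the reference encoder
logw = dp(x, x_mask, g)        # (2, 1, 30) predicted log-durations per token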
models/estimator.py ADDED
@@ -0,0 +1,161 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.dit import DiTConVBlock

class DitWrapper(nn.Module):
    """ add FiLM layer to condition time embedding to DiT """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
        super().__init__()
        self.time_fusion = FiLMLayer(hidden_channels, time_channels)
        self.conv1 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.conv2 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.conv3 = ConvNeXtBlock(hidden_channels, filter_channels, gin_channels)
        self.block = DiTConVBlock(hidden_channels, hidden_channels, num_heads, kernel_size, p_dropout, gin_channels)

    def forward(self, x, c, t, x_mask):
        x = self.time_fusion(x, t) * x_mask
        x = self.conv1(x, c, x_mask)
        x = self.conv2(x, c, x_mask)
        x = self.conv3(x, c, x_mask)
        x = self.block(x, c, x_mask)
        return x

class FiLMLayer(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) layer
    Reference: https://arxiv.org/abs/1709.07871
    """
    def __init__(self, in_channels, cond_channels):
        super(FiLMLayer, self).__init__()
        self.in_channels = in_channels
        self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
        return gamma * x + beta

class ConvNeXtBlock(nn.Module):
    def __init__(self, in_channels, filter_channels, gin_channels):
        super().__init__()
        self.dwconv = nn.Conv1d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
        self.norm = StyleAdaptiveLayerNorm(in_channels, gin_channels)
        self.pwconv = nn.Sequential(nn.Linear(in_channels, filter_channels),
                                    nn.GELU(),
                                    nn.Linear(filter_channels, in_channels))

    def forward(self, x, c, x_mask) -> torch.Tensor:
        residual = x
        x = self.dwconv(x) * x_mask
        x = self.norm(x.transpose(1, 2), c)
        x = self.pwconv(x).transpose(1, 2)
        x = residual + x
        return x * x_mask

class StyleAdaptiveLayerNorm(nn.Module):
    def __init__(self, in_channels, cond_channels):
        """
        Style Adaptive Layer Normalization (SALN) module.

        Parameters:
            in_channels: The number of channels in the input feature maps.
            cond_channels: The number of channels in the conditioning input.
        """
        super(StyleAdaptiveLayerNorm, self).__init__()
        self.in_channels = in_channels

        self.saln = nn.Linear(cond_channels, in_channels * 2, 1)
        self.norm = nn.LayerNorm(in_channels, elementwise_affine=False)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.saln.bias.data[:self.in_channels], 1)
        nn.init.constant_(self.saln.bias.data[self.in_channels:], 0)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.saln(c.unsqueeze(1)), chunks=2, dim=-1)
        return gamma * self.norm(x) + beta


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_channels, filter_channels),
            nn.SiLU(inplace=True),
            nn.Linear(filter_channels, out_channels)
        )

    def forward(self, x):
        return self.layer(x)

# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
class Decoder(nn.Module):
    def __init__(self, hidden_channels, out_channels, filter_channels, dropout=0.05, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels

        self.time_embeddings = SinusoidalPosEmb(hidden_channels)
        self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)

        self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
        self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.blocks:
            nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, mask, mu, t, c):
        """Forward pass of the flow-matching estimator.

        Args:
            x (torch.Tensor): noisy input, shape (batch_size, out_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): encoder output, concatenated with x along channels, shape (batch_size, out_channels, time)
            t (torch.Tensor): timestep, shape (batch_size)
            c (torch.Tensor): condition vector, shape (batch_size, gin_channels)

        Returns:
            torch.Tensor: estimated vector field, shape (batch_size, out_channels, time)
        """

        t = self.time_mlp(self.time_embeddings(t))
        x = torch.cat((x, mu), dim=1)

        for block in self.blocks:
            x = block(x, c, t, mask)

        output = self.final_proj(x * mask)

        return output * mask
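Illustrative FiLM call, showing how the time embedding modulates the hidden sequence per channel (sizes are arbitrary placeholders):

import torch
from models.estimator import FiLMLayer

film = FiLMLayer(in_channels=256, cond_channels=256)
x = torch.randn(2, 256, 80)    # hidden sequence
t = torch.randn(2, 256)        # time embedding
out = film(x, t)               # gamma * x + beta, broadcast over time; shape preserved
assert out.shape == x.shape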
models/flow_matching.py ADDED
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.estimator import Decoder

# copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121
def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
class CFMDecoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.gin_channels = gin_channels
        self.sigma_min = 1e-4

        self.estimator = Decoder(hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            c (torch.Tensor, optional): shape: (batch_size, gin_channels)

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, c=c)

    def solve_euler(self, x, t_span, mu, mask, c):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.
                shape: (batch_size, gin_channels)
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []

        for step in range(1, len(t_span)):
            dphi_dt = self.estimator(x, mask, mu, t, c)

            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return sol[-1]

    def compute_loss(self, x1, mask, mu, c):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), c), u, reduction="sum") / (
            torch.sum(mask) * u.shape[1]
        )
        return loss, y
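In compute_loss above, the estimator is trained to predict the constant velocity u = x1 - (1 - sigma_min) * z along the straight path y_t = (1 - (1 - sigma_min) * t) * z + t * x1; at inference solve_euler integrates that field from pure noise. A shape sketch of inference (hidden_channels must be twice the mel channel count because the estimator concatenates x and mu; other sizes are illustrative):

import torch
from models.flow_matching import CFMDecoder

decoder = CFMDecoder(hidden_channels=256, out_channels=128, filter_channels=512,
                     n_heads=2, n_layers=2, kernel_size=3, p_dropout=0.05, gin_channels=192)
mu = torch.randn(2, 128, 100)     # aligned encoder output (batch, n_mels, frames)
mask = torch.ones(2, 1, 100)
c = torch.randn(2, 192)           # style vector
mel = decoder(mu, mask, n_timesteps=10, temperature=0.667, c=c)   # (2, 128, 100)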
models/model.py ADDED
@@ -0,0 +1,194 @@
import math
import torch
import torch.nn as nn

import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

def convert_pad_shape(pad_shape):
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape

def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
    def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()

        self.n_vocab = n_vocab
        self.mel_channels = mel_channels

        self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
        self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=3)
        self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, p_dropout, gin_channels)
        self.decoder = CFMDecoder(mel_channels + mel_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0):
        """
        Generates mel-spectrogram from text. Returns:
        1. encoder outputs
        2. decoder outputs
        3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            y (torch.Tensor): mel spectrogram of reference audio
                shape: (batch_size, mel_channels, time)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
                # Alignment map between text and mel spectrogram
            }
        """

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        c = self.ref_encoder(y, None)
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c)
        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            "attn": attn[:, :, :y_max_length],
        }

    def forward(self, x, x_lengths, y, y_lengths):
        """
        Computes 3 losses:
        1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
        2. prior loss: loss between mel-spectrogram and encoder outputs.
        3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)
        """
        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
        c = self.ref_encoder(y, y_mask)

        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram

        # I'm not sure why the MAS code in Matcha TTS and Grad TTS could not align in StableTTS,
        # so I use the code from https://github.com/p0p4k/pflowtts_pytorch/blob/master/pflow/models/pflow_tts.py and it works.
        # Everyone is welcome to solve this problem QAQ

        with torch.no_grad():
            # const = -0.5 * math.log(2 * math.pi) * self.n_feats
            # const = -0.5 * math.log(2 * math.pi) * self.mel_channels
            # factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
            # y_square = torch.matmul(factor.transpose(1, 2), y**2)
            # y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), y)
            # mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
            # log_prior = y_square - y_mu_double + mu_square + const

            s_p_sq_r = torch.ones_like(mu_x) # [b, d, t]
            # s_p_sq_r = torch.exp(-2 * logx)
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True
            )
            # neg_cent1 = torch.sum(
            #     -0.5 * math.log(2 * math.pi) - logx, [1], keepdim=True
            # ) # [b, 1, t_s]
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(
                -0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True
            )
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)

            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
            )

            # attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
            # attn = attn.detach()

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        # referred to as prior loss in the paper
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        # logw_ = torch.log(1e-8 + torch.sum(attn.unsqueeze(1), -1)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Align encoded text with mel-spectrogram and get mu_y segment
        attn = attn.squeeze(1).transpose(1,2)
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y, c)
        # diff_loss = torch.tensor([0], device=mu_y.device)

        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)

        return dur_loss, diff_loss, prior_loss, attn
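Hypothetical training step wiring the three losses together; the actual train script is not among the files shown in this view, n_vocab=100 is a placeholder for len(symbols), and the tensors are random stand-ins for a real batch:

import torch
from dataclasses import asdict
from config import MelConfig, ModelConfig
from models.model import StableTTS

model = StableTTS(n_vocab=100, mel_channels=MelConfig().n_mels, **asdict(ModelConfig()))
x = torch.randint(0, 100, (2, 30))    # phoneme ids
x_lengths = torch.tensor([30, 25])
y = torch.randn(2, 128, 200)          # mel targets
y_lengths = torch.tensor([200, 180])

dur_loss, diff_loss, prior_loss, attn = model(x, x_lengths, y, y_lengths)
loss = dur_loss + diff_loss + prior_loss
loss.backward()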
models/reference_encoder.py ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn

class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU (Gated Linear Unit) with residual connection.
    For GLU refer to https://arxiv.org/abs/1612.08083 paper.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x

# modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
class MelStyleEncoder(nn.Module):
    """MelStyleEncoder"""

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = nn.MultiheadAttention(
            self.hidden_dim,
            self.n_head,
            self.dropout,
            batch_first=True
        )

        self.fc = nn.Linear(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            len_ = (~mask).sum(dim=1).unsqueeze(1).type_as(x)
            return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / len_

    def forward(self, x, x_mask=None):
        x = x.transpose(1, 2)

        # spectral
        x = self.spectral(x)
        # temporal
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention
        if x_mask is not None:
            x_mask = ~x_mask.squeeze(1).to(torch.bool)
        x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=x_mask)

        return w
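Shape sketch of the reference/style encoder as model.py instantiates it (a random mel stands in for a real reference clip):

import torch
from models.reference_encoder import MelStyleEncoder

ref_enc = MelStyleEncoder(n_mel_channels=128, style_vector_dim=192, style_kernel_size=3)
mel = torch.randn(2, 128, 300)   # (batch, mel_channels, frames) reference mel
style = ref_enc(mel)             # (2, 192) utterance-level style vector, mean-pooled over time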
models/text_encoder.py ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn

from models.dit import DiTConVBlock

def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
class TextEncoder(nn.Module):
    def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.scale = self.hidden_channels ** 0.5

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.encoder:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
        x = self.emb(x) * self.scale # [b, t, h]
        x = x.transpose(1, -1) # [b, h, t]
        x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

        for layer in self.encoder:
            x = layer(x, c, x_mask)
        mu_x = self.proj(x) * x_mask

        return x, mu_x, x_mask
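Shape sketch (n_vocab=100 is a placeholder for len(symbols); the other sizes follow ModelConfig and MelConfig):

import torch
from models.text_encoder import TextEncoder

enc = TextEncoder(n_vocab=100, out_channels=128, hidden_channels=192, filter_channels=512,
                  n_heads=2, n_layers=3, kernel_size=3, p_dropout=0.1, gin_channels=192)
tokens = torch.randint(0, 100, (2, 30))
lengths = torch.tensor([30, 25])
c = torch.randn(2, 192)                      # style vector
x, mu_x, x_mask = enc(tokens, c, lengths)    # (2, 192, 30), (2, 128, 30), (2, 1, 30)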
monotonic_align/__init__.py ADDED
@@ -0,0 +1,16 @@
from numpy import zeros, int32, float32
from torch import from_numpy

from .core import maximum_path_jit


def maximum_path(neg_cent, mask):
    device = neg_cent.device
    dtype = neg_cent.dtype
    neg_cent = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(neg_cent.shape, dtype=int32)

    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
    maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
    return from_numpy(path).to(device=device, dtype=dtype)
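Illustrative call: maximum_path turns a (batch, mel_frames, text_tokens) score matrix into a hard 0/1 monotonic alignment of the same shape (random scores used here as a placeholder):

import torch
import monotonic_align

neg_cent = torch.randn(1, 100, 20)                    # alignment scores (batch, mel_frames, text_tokens)
mask = torch.ones(1, 100, 20)
attn = monotonic_align.maximum_path(neg_cent, mask)   # 0/1 monotonic path of the same shape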
monotonic_align/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (732 Bytes).
 
monotonic_align/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.64 kB).
 
monotonic_align/__pycache__/core.cpython-310.pyc ADDED
Binary file (985 Bytes).
 
monotonic_align/__pycache__/core.cpython-311.pyc ADDED
Binary file (2 kB).
 
monotonic_align/core.py ADDED
@@ -0,0 +1,46 @@
import numba


@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    b = paths.shape[0]
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        index = t_x - 1

        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
requirements.txt ADDED
@@ -0,0 +1,12 @@
torch
torchaudio
matplotlib
numpy
tensorboard
pypinyin
jieba
eng_to_ipa
unidecode
inflect
pyopenjtalk-prebuilt
numba
text/LICENSE ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
text/__init__.py ADDED
@@ -0,0 +1,71 @@
""" from https://github.com/keithito/tacotron """
from text import cleaners
from text.symbols import symbols


# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}


def text_to_sequence(text, symbols, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    sequence = []
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    clean_text = _clean_text(text, cleaner_names)
    print(clean_text)
    print(f" length:{len(clean_text)}")
    for symbol in clean_text:
        if symbol not in symbol_to_id.keys():
            continue
        symbol_id = symbol_to_id[symbol]
        sequence += [symbol_id]
    print(f" length:{len(sequence)}")
    return sequence


def cleaned_text_to_sequence(cleaned_text):
    '''Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text.
    Args:
        cleaned_text: string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    # symbol_to_id = {s: i for i, s in enumerate(symbols)}
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
    return sequence

def cleaned_text_to_sequence_chinese(cleaned_text):
    '''Converts a string of space-separated cleaned symbols to a sequence of IDs corresponding to the symbols in the text.
    Args:
        cleaned_text: string to convert to a sequence
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    # symbol_to_id = {s: i for i, s in enumerate(symbols)}
    sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()]
    return sequence


def sequence_to_text(sequence):
    '''Converts a sequence of IDs back to a string'''
    result = ''
    for symbol_id in sequence:
        s = _id_to_symbol[symbol_id]
        result += s
    return result


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
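Illustrative round-trip through the mapping functions above; which characters actually survive depends on text/symbols.py, which is not among the 50 files shown:

from text import cleaned_text_to_sequence, sequence_to_text

ids = cleaned_text_to_sequence("wʌn")   # symbols missing from the table are silently dropped
print(sequence_to_text(ids))            # prints the kept symbols back as a string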
text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (2.62 kB).
 
text/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.89 kB).
 
text/__pycache__/cleaners.cpython-310.pyc ADDED
Binary file (2.54 kB).
 
text/__pycache__/cleaners.cpython-311.pyc ADDED
Binary file (4.2 kB).
 
text/__pycache__/english.cpython-310.pyc ADDED
Binary file (4.4 kB).