Your Name committed on
Commit
c2bb5f2
1 Parent(s): dc13c66

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +15 -0
  2. customs/customsf +0 -0
  3. data/__init__.py +3 -0
  4. data/__pycache__/__init__.cpython-311.pyc +0 -0
  5. data/__pycache__/collation.cpython-311.pyc +0 -0
  6. data/__pycache__/input_strategies.cpython-311.pyc +0 -0
  7. data/__pycache__/tokenizer.cpython-311.pyc +0 -0
  8. data/collation.py +120 -0
  9. data/datamodule.py +419 -0
  10. data/dataset.py +242 -0
  11. data/fbank.py +212 -0
  12. data/input_strategies.py +159 -0
  13. data/tokenizer.py +126 -0
  14. macros.py +44 -0
  15. main.py +355 -0
  16. models/__init__.py +136 -0
  17. models/__pycache__/__init__.cpython-311.pyc +0 -0
  18. models/__pycache__/macros.cpython-311.pyc +0 -0
  19. models/__pycache__/transformer.cpython-311.pyc +0 -0
  20. models/__pycache__/vallex.cpython-311.pyc +0 -0
  21. models/__pycache__/visualizer.cpython-311.pyc +0 -0
  22. models/macros.py +11 -0
  23. models/transformer.py +394 -0
  24. models/vallex.py +853 -0
  25. models/visualizer.py +106 -0
  26. modules/__init__.py +0 -0
  27. modules/__pycache__/__init__.cpython-311.pyc +0 -0
  28. modules/__pycache__/activation.cpython-311.pyc +0 -0
  29. modules/__pycache__/embedding.cpython-311.pyc +0 -0
  30. modules/__pycache__/scaling.cpython-311.pyc +0 -0
  31. modules/__pycache__/transformer.cpython-311.pyc +0 -0
  32. modules/activation.py +612 -0
  33. modules/embedding.py +97 -0
  34. modules/optim.py +1105 -0
  35. modules/scaling.py +1401 -0
  36. modules/scheduler.py +78 -0
  37. modules/transformer.py +683 -0
  38. prompts/promptsf +0 -0
  39. requirements.txt +36 -0
  40. s2smodels.py +19 -0
  41. utils/__init__.py +15 -0
  42. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  43. utils/__pycache__/generation.cpython-311.pyc +0 -0
  44. utils/__pycache__/prompt_making.cpython-311.pyc +0 -0
  45. utils/__pycache__/sentence_cutter.cpython-311.pyc +0 -0
  46. utils/__pycache__/symbol_table.cpython-311.pyc +0 -0
  47. utils/download.py +49 -0
  48. utils/g2p/__init__.py +72 -0
  49. utils/g2p/__pycache__/__init__.cpython-311.pyc +0 -0
  50. utils/g2p/__pycache__/cleaners.cpython-311.pyc +0 -0
Dockerfile ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.11
+ WORKDIR /code
+ COPY ./requirements.txt /code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+ COPY ./s2smodels.py /code/
+ COPY ./macros.py /code/
+ COPY ./utils/ /code/utils/
+ COPY ./modules/ /code/modules/
+ COPY ./models/ /code/models/
+ COPY ./data/ /code/data/
+ COPY ./prompts/ /code/prompts/
+ COPY ./customs/ /code/customs/
+ COPY ./main.py /code/
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
customs/customsf ADDED
File without changes
data/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # from .datamodule import *
+ # from .tokenizer import *
+ from .collation import *
data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (208 Bytes).
data/__pycache__/collation.cpython-311.pyc ADDED
Binary file (7.2 kB).
data/__pycache__/input_strategies.cpython-311.pyc ADDED
Binary file (1.8 kB).
data/__pycache__/tokenizer.cpython-311.pyc ADDED
Binary file (6.77 kB).
data/collation.py ADDED
@@ -0,0 +1,120 @@
1
+ from pathlib import Path
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from utils import SymbolTable
8
+
9
+
10
+ class TextTokenCollater:
11
+ """Collate list of text tokens
12
+
13
+ Map sentences to integers. Sentences are padded to equal length.
14
+ Beginning and end-of-sequence symbols can be added.
15
+
16
+ Example:
17
+ >>> token_collater = TextTokenCollater(text_tokens)
18
+ >>> tokens_batch, tokens_lens = token_collater(text)
19
+
20
+ Returns:
21
+ tokens_batch: IntTensor of shape (B, L)
22
+ B: batch dimension, number of input sentences
23
+ L: length of the longest sentence
24
+ tokens_lens: IntTensor of shape (B,)
25
+ Length of each sentence after adding <eos> and <bos>
26
+ but before padding.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ text_tokens: List[str],
32
+ add_eos: bool = True,
33
+ add_bos: bool = True,
34
+ pad_symbol: str = "<pad>",
35
+ bos_symbol: str = "<bos>",
36
+ eos_symbol: str = "<eos>",
37
+ ):
38
+ self.pad_symbol = pad_symbol
39
+
40
+ self.add_eos = add_eos
41
+ self.add_bos = add_bos
42
+
43
+ self.bos_symbol = bos_symbol
44
+ self.eos_symbol = eos_symbol
45
+
46
+ unique_tokens = (
47
+ [pad_symbol]
48
+ + ([bos_symbol] if add_bos else [])
49
+ + ([eos_symbol] if add_eos else [])
50
+ + sorted(text_tokens)
51
+ )
52
+
53
+ self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
54
+ self.idx2token = [token for token in unique_tokens]
55
+
56
+ def index(
57
+ self, tokens_list: List[str]
58
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
59
+ seqs, seq_lens = [], []
60
+ for tokens in tokens_list:
61
+ assert all(s in self.token2idx for s in tokens)
65
+ seq = (
66
+ ([self.bos_symbol] if self.add_bos else [])
67
+ + list(tokens)
68
+ + ([self.eos_symbol] if self.add_eos else [])
69
+ )
70
+ seqs.append(seq)
71
+ seq_lens.append(len(seq))
72
+
73
+ max_len = max(seq_lens)
74
+ for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
75
+ seq.extend([self.pad_symbol] * (max_len - seq_len))
76
+
77
+ tokens = torch.from_numpy(
78
+ np.array(
79
+ [[self.token2idx[token] for token in seq] for seq in seqs],
80
+ dtype=np.int64,
81
+ )
82
+ )
83
+ tokens_lens = torch.IntTensor(seq_lens)
84
+
85
+ return tokens, tokens_lens
86
+
87
+ def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
88
+ tokens_seqs = [[p for p in text] for text in texts]
89
+ max_len = len(max(tokens_seqs, key=len))
90
+
91
+ seqs = [
92
+ ([self.bos_symbol] if self.add_bos else [])
93
+ + list(seq)
94
+ + ([self.eos_symbol] if self.add_eos else [])
95
+ + [self.pad_symbol] * (max_len - len(seq))
96
+ for seq in tokens_seqs
97
+ ]
98
+
99
+ tokens_batch = torch.from_numpy(
100
+ np.array(
101
+ [seq for seq in seqs],
102
+ dtype=np.int64,
103
+ )
104
+ )
105
+
106
+ tokens_lens = torch.IntTensor(
107
+ [
108
+ len(seq) + int(self.add_eos) + int(self.add_bos)
109
+ for seq in tokens_seqs
110
+ ]
111
+ )
112
+
113
+ return tokens_batch, tokens_lens
114
+
115
+
116
+ def get_text_token_collater() -> TextTokenCollater:
117
+ collater = TextTokenCollater(
118
+ ['0'], add_bos=False, add_eos=False
119
+ )
120
+ return collater
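A small, self-contained sketch of how the collater above maps token sequences to padded index tensors. The three-symbol vocabulary is made up purely for illustration, and it assumes the repo's utils package (which collation.py imports SymbolTable from) is importable:

```python
from data.collation import TextTokenCollater

collater = TextTokenCollater(["a", "b", "c"], add_bos=True, add_eos=True)
tokens, tokens_lens = collater.index(["ab", "abc"])

# vocabulary order: <pad>=0, <bos>=1, <eos>=2, a=3, b=4, c=5
print(tokens)       # tensor([[1, 3, 4, 2, 0],
                    #         [1, 3, 4, 5, 2]])
print(tokens_lens)  # tensor([4, 5], dtype=torch.int32)
```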
data/datamodule.py ADDED
@@ -0,0 +1,419 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import inspect
20
+ import logging
21
+ from functools import lru_cache
22
+ from pathlib import Path
23
+ from typing import Any, Dict, Optional
24
+
25
+ import torch
26
+ # from icefall.utils import str2bool
27
+ # from lhotse import CutSet, load_manifest_lazy
28
+ # from lhotse.dataset import (
29
+ # CutConcatenate,
30
+ # DynamicBucketingSampler,
31
+ # PrecomputedFeatures,
32
+ # SingleCutSampler,
33
+ # SpecAugment,
34
+ # )
35
+ # from lhotse.dataset.input_strategies import OnTheFlyFeatures
36
+ # from lhotse.utils import fix_random_seed
37
+ from torch.utils.data import DataLoader
38
+
39
+ from data.collation import get_text_token_collater
40
+ # from data.dataset import SpeechSynthesisDataset
41
+ from data.fbank import get_fbank_extractor
42
+ from data.input_strategies import PromptedPrecomputedFeatures
43
+
44
+ # PrecomputedFeatures = PrecomputedFeatures
45
+
46
+
47
+ class _SeedWorkers:
48
+ def __init__(self, seed: int):
49
+ self.seed = seed
50
+
51
+ def __call__(self, worker_id: int):
52
+ fix_random_seed(self.seed + worker_id)
53
+
54
+
55
+ def _get_input_strategy(input_strategy, dataset, cuts):
56
+ if input_strategy == "PromptedPrecomputedFeatures":
57
+ return PromptedPrecomputedFeatures(dataset, cuts)
58
+
59
+ return eval(input_strategy)()
60
+
61
+
62
+ class TtsDataModule:
63
+ """
64
+ DataModule for VALL-E TTS experiments.
65
+ It assumes there is always one train and valid dataloader.
66
+
67
+ It contains all the common data pipeline modules used in TTS
68
+ experiments, e.g.:
69
+ - dynamic batch size,
70
+ - bucketing samplers,
71
+ - cut concatenation[not used & tested yet],
72
+ - augmentation[not used & tested yet],
73
+ - on-the-fly feature extraction[not used & tested yet]
74
+
75
+ This class should be derived for specific corpora used in TTS tasks.
76
+ """
77
+
78
+ def __init__(self, args: argparse.Namespace):
79
+ self.args = args
80
+
81
+ @classmethod
82
+ def add_arguments(cls, parser: argparse.ArgumentParser):
83
+ group = parser.add_argument_group(
84
+ title="TTS data related options",
85
+ description="These options are used for the preparation of "
86
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
87
+ "effective batch sizes, sampling strategies, applied data "
88
+ "augmentations, etc.",
89
+ )
90
+ group.add_argument(
91
+ "--manifest-dir",
92
+ type=Path,
93
+ default=Path("data/tokenized"),
94
+ help="Path to directory with train/valid/test cuts.",
95
+ )
96
+ group.add_argument(
97
+ "--max-duration",
98
+ type=int,
99
+ default=40.0,
100
+ help="Maximum pooled recordings duration (seconds) in a "
101
+ "single batch. You can reduce it if it causes CUDA OOM.",
102
+ )
103
+ group.add_argument(
104
+ "--bucketing-sampler",
105
+ type=str2bool,
106
+ default=True,
107
+ help="When enabled, the batches will come from buckets of "
108
+ "similar duration (saves padding frames).",
109
+ )
110
+ group.add_argument(
111
+ "--num-buckets",
112
+ type=int,
113
+ default=10,
114
+ help="The number of buckets for the DynamicBucketingSampler"
115
+ "(you might want to increase it for larger datasets).",
116
+ )
117
+ group.add_argument(
118
+ "--concatenate-cuts",
119
+ type=str2bool,
120
+ default=False,
121
+ help="When enabled, utterances (cuts) will be concatenated "
122
+ "to minimize the amount of padding.",
123
+ )
124
+ group.add_argument(
125
+ "--duration-factor",
126
+ type=float,
127
+ default=1.0,
128
+ help="Determines the maximum duration of a concatenated cut "
129
+ "relative to the duration of the longest cut in a batch.",
130
+ )
131
+ group.add_argument(
132
+ "--gap",
133
+ type=float,
134
+ default=0.1,
135
+ help="The amount of padding (in seconds) inserted between "
136
+ "concatenated cuts. This padding is filled with noise when "
137
+ "noise augmentation is used.",
138
+ )
139
+ group.add_argument(
140
+ "--on-the-fly-feats",
141
+ type=str2bool,
142
+ default=False,
143
+ help="When enabled, use on-the-fly cut mixing and feature "
144
+ "extraction. Will drop existing precomputed feature manifests "
145
+ "if available.",
146
+ )
147
+ group.add_argument(
148
+ "--shuffle",
149
+ type=str2bool,
150
+ default=True,
151
+ help="When enabled (=default), the examples will be "
152
+ "shuffled for each epoch.",
153
+ )
154
+ group.add_argument(
155
+ "--drop-last",
156
+ type=str2bool,
157
+ default=False,
158
+ help="Whether to drop last batch. Used by sampler.",
159
+ )
160
+ group.add_argument(
161
+ "--return-cuts",
162
+ type=str2bool,
163
+ default=True,
164
+ help="When enabled, each batch will have the "
165
+ "field: batch['supervisions']['cut'] with the cuts that "
166
+ "were used to construct it.",
167
+ )
168
+
169
+ group.add_argument(
170
+ "--num-workers",
171
+ type=int,
172
+ default=8,
173
+ help="The number of training dataloader workers that "
174
+ "collect the batches.",
175
+ )
176
+
177
+ group.add_argument(
178
+ "--enable-spec-aug",
179
+ type=str2bool,
180
+ default=False,
181
+ help="When enabled, use SpecAugment for training dataset.",
182
+ )
183
+
184
+ group.add_argument(
185
+ "--spec-aug-time-warp-factor",
186
+ type=int,
187
+ default=80,
188
+ help="Used only when --enable-spec-aug is True. "
189
+ "It specifies the factor for time warping in SpecAugment. "
190
+ "Larger values mean more warping. "
191
+ "A value less than 1 means to disable time warp.",
192
+ )
193
+
194
+ group.add_argument(
195
+ "--input-strategy",
196
+ type=str,
197
+ default="PrecomputedFeatures",
198
+ help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
199
+ )
200
+
201
+ group.add_argument(
202
+ "--dataset",
203
+ type=str,
204
+ default="ljspeech",
205
+ help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--text-tokens",
210
+ type=str,
211
+ default="data/tokenized/unique_text_tokens.k2symbols",
212
+ help="Path to the unique text tokens file",
213
+ )
214
+
215
+ parser.add_argument(
216
+ "--sampling-rate",
217
+ type=int,
218
+ default=24000,
219
+ help="""Audio sampling rate.""",
220
+ )
221
+
222
+ def train_dataloaders(
223
+ self,
224
+ cuts_train: CutSet,
225
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
226
+ ) -> DataLoader:
227
+ """
228
+ Args:
229
+ cuts_train:
230
+ CutSet for training.
231
+ sampler_state_dict:
232
+ The state dict for the training sampler.
233
+ """
234
+ transforms = []
235
+
236
+ if self.args.concatenate_cuts:
237
+ logging.info(
238
+ f"Using cut concatenation with duration factor "
239
+ f"{self.args.duration_factor} and gap {self.args.gap}."
240
+ )
241
+ # Cut concatenation should be the first transform in the list,
242
+ # so that if we e.g. mix noise in, it will fill the gaps between
243
+ # different utterances.
244
+ transforms = [
245
+ CutConcatenate(
246
+ duration_factor=self.args.duration_factor, gap=self.args.gap
247
+ )
248
+ ] + transforms
249
+
250
+ input_transforms = []
251
+ if self.args.enable_spec_aug:
252
+ logging.info("Enable SpecAugment")
253
+ logging.info(
254
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
255
+ )
256
+ # Set the value of num_frame_masks according to Lhotse's version.
257
+ # In different Lhotse's versions, the default of num_frame_masks is
258
+ # different.
259
+ num_frame_masks = 10
260
+ num_frame_masks_parameter = inspect.signature(
261
+ SpecAugment.__init__
262
+ ).parameters["num_frame_masks"]
263
+ if num_frame_masks_parameter.default == 1:
264
+ num_frame_masks = 2
265
+ logging.info(f"Num frame mask: {num_frame_masks}")
266
+ input_transforms.append(
267
+ SpecAugment(
268
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
269
+ num_frame_masks=num_frame_masks,
270
+ features_mask_size=27,
271
+ num_feature_masks=2,
272
+ frames_mask_size=100,
273
+ )
274
+ )
275
+ else:
276
+ logging.info("Disable SpecAugment")
277
+
278
+ logging.info("About to create train dataset")
279
+ if self.args.on_the_fly_feats:
280
+ # NOTE: the PerturbSpeed transform should be added only if we
281
+ # remove it from data prep stage.
282
+ # Add on-the-fly speed perturbation; since originally it would
283
+ # have increased epoch size by 3, we will apply prob 2/3 and use
284
+ # 3x more epochs.
285
+ # Speed perturbation probably should come first before
286
+ # concatenation, but in principle the transforms order doesn't have
287
+ # to be strict (e.g. could be randomized)
288
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
289
+ # Drop feats to be on the safe side.
290
+ train = SpeechSynthesisDataset(
291
+ get_text_token_collater(self.args.text_tokens),
292
+ cut_transforms=transforms,
293
+ feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
294
+ feature_transforms=input_transforms,
295
+ )
296
+ else:
297
+ train = SpeechSynthesisDataset(
298
+ get_text_token_collater(self.args.text_tokens),
299
+ feature_input_strategy=_get_input_strategy(
300
+ self.args.input_strategy, self.args.dataset, cuts_train
301
+ ),
302
+ cut_transforms=transforms,
303
+ feature_transforms=input_transforms,
304
+ )
305
+
306
+ if self.args.bucketing_sampler:
307
+ logging.info("Using DynamicBucketingSampler")
308
+ train_sampler = DynamicBucketingSampler(
309
+ cuts_train,
310
+ max_duration=self.args.max_duration,
311
+ shuffle=self.args.shuffle,
312
+ num_buckets=self.args.num_buckets,
313
+ drop_last=self.args.drop_last,
314
+ )
315
+ else:
316
+ logging.info(
317
+ "Using SingleCutSampler and sort by duraton(ascending=True)."
318
+ )
319
+ cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
320
+ train_sampler = SingleCutSampler(
321
+ cuts_train,
322
+ max_duration=self.args.max_duration,
323
+ shuffle=self.args.shuffle,
324
+ )
325
+ logging.info("About to create train dataloader")
326
+
327
+ if sampler_state_dict is not None:
328
+ logging.info("Loading sampler state dict")
329
+ train_sampler.load_state_dict(sampler_state_dict)
330
+
331
+ # 'seed' is derived from the current random state, which will have
332
+ # previously been set in the main process.
333
+ seed = torch.randint(0, 100000, ()).item()
334
+ worker_init_fn = _SeedWorkers(seed)
335
+
336
+ train_dl = DataLoader(
337
+ train,
338
+ sampler=train_sampler,
339
+ batch_size=None,
340
+ num_workers=self.args.num_workers,
341
+ persistent_workers=False,
342
+ worker_init_fn=worker_init_fn,
343
+ )
344
+
345
+ return train_dl
346
+
347
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
348
+ logging.info("About to create dev dataset")
349
+ if self.args.on_the_fly_feats:
350
+ validate = SpeechSynthesisDataset(
351
+ get_text_token_collater(self.args.text_tokens),
352
+ feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
353
+ cut_transforms=[],
354
+ )
355
+ else:
356
+ validate = SpeechSynthesisDataset(
357
+ get_text_token_collater(self.args.text_tokens),
358
+ feature_input_strategy=_get_input_strategy(
359
+ self.args.input_strategy, self.args.dataset, cuts_valid
360
+ ),
361
+ cut_transforms=[],
362
+ )
363
+ valid_sampler = DynamicBucketingSampler(
364
+ cuts_valid,
365
+ max_duration=self.args.max_duration,
366
+ shuffle=False,
367
+ )
368
+ logging.info("About to create dev dataloader")
369
+ valid_dl = DataLoader(
370
+ validate,
371
+ sampler=valid_sampler,
372
+ batch_size=None,
373
+ num_workers=4,
374
+ persistent_workers=False,
375
+ )
376
+
377
+ return valid_dl
378
+
379
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
380
+ logging.debug("About to create test dataset")
381
+ test = SpeechSynthesisDataset(
382
+ get_text_token_collater(self.args.text_tokens),
383
+ feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
384
+ if self.args.on_the_fly_feats
385
+ else _get_input_strategy(
386
+ self.args.input_strategy, self.args.dataset, cuts
387
+ ),
388
+ cut_transforms=[],
389
+ )
390
+ sampler = DynamicBucketingSampler(
391
+ cuts,
392
+ max_duration=self.args.max_duration,
393
+ shuffle=False,
394
+ )
395
+ logging.debug("About to create test dataloader")
396
+ test_dl = DataLoader(
397
+ test,
398
+ batch_size=None,
399
+ sampler=sampler,
400
+ num_workers=self.args.num_workers,
401
+ )
402
+ return test_dl
403
+
404
+ @lru_cache()
405
+ def train_cuts(self) -> CutSet:
406
+ logging.info("About to get train cuts")
407
+ return load_manifest_lazy(
408
+ self.args.manifest_dir / "cuts_train.jsonl.gz"
409
+ )
410
+
411
+ @lru_cache()
412
+ def dev_cuts(self) -> CutSet:
413
+ logging.info("About to get dev cuts")
414
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
415
+
416
+ @lru_cache()
417
+ def test_cuts(self) -> CutSet:
418
+ logging.info("About to get test cuts")
419
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
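Note that the boolean flags in add_arguments above rely on str2bool while its icefall import is commented out (as are the lhotse imports this module needs). A minimal local stand-in is sketched below; this is an assumption for illustration, not icefall's exact implementation:

```python
import argparse

def str2bool(v) -> bool:
    """Parse common command-line spellings of true/false for argparse."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")
```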
data/dataset.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ modified from lhoste.dataset.speech_synthesis.py
19
+ """
20
+
21
+ import torch
22
+ import math
23
+ import h5py
24
+ from tokenizers import Tokenizer
25
+ from typing import Union, List
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+
29
+ _pad = '_'
30
+ _punctuation = ',.!?-~…'
31
+ _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
32
+ symbols = [_pad] + list(_punctuation) + list(_letters)
33
+
34
+ language_dict = {
35
+ 'en': 0,
36
+ 'zh': 1,
37
+ 'ja': 2,
38
+ }
39
+ def seq2phone(tokens: Union[List, np.ndarray]):
40
+ """
41
+ Convert tokenized phoneme ID sequence back to phoneme string
42
+ :param tokens: phoneme tokens
43
+ :return: recovered phoneme sequence
44
+ """
45
+ phones = "".join([symbols[i] for i in tokens])
46
+ return phones
47
+
48
+ class DynamicBatchSampler(torch.utils.data.Sampler):
49
+ def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
50
+ max_tokens=None, max_sentences=None, drop_last=False):
51
+ """
52
+
53
+ :param sampler:
54
+ :param num_tokens_fn: function that returns the length of the sample at a given idx
+ :param num_buckets: number of buckets used to group samples of similar length into one batch
+ :param min_size: minimum sample length; shorter samples are filtered out (buckets are laid out based on this value)
+ :param max_size: maximum sample length
+ :param max_sentences: batch size; the final batch size is controlled jointly by max_sentences and max_tokens
59
+ """
60
+ super(DynamicBatchSampler, self).__init__(sampler)
61
+ self.sampler = sampler
62
+ self.num_tokens_fn = num_tokens_fn
63
+ self.num_buckets = num_buckets
64
+
65
+ self.min_size = min_size
66
+ self.max_size = max_size
67
+
68
+ assert max_size <= max_tokens, "max_size should be smaller than max tokens"
69
+ assert max_tokens is not None or max_sentences is not None, \
70
+ "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
71
+ self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
72
+ self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
73
+ self.drop_last = drop_last
74
+
75
+ def set_epoch(self, epoch):
76
+ self.sampler.set_epoch(epoch)
77
+ def is_batch_full(self, num_tokens, batch):
78
+ if len(batch) == 0:
79
+ return False
80
+ if len(batch) == self.max_sentences:
81
+ return True
82
+ if num_tokens > self.max_tokens:
83
+ return True
84
+ return False
85
+
86
+ def __iter__(self):
87
+ buckets = [[] for _ in range(self.num_buckets)]
88
+ sample_len = [0] * self.num_buckets
89
+
90
+ for idx in self.sampler:
91
+ idx_length = self.num_tokens_fn(idx)
92
+ if not (self.min_size <= idx_length <= self.max_size):
93
+ print("sentence at index {} of size {} exceeds max_tokens, the sentence is ignored".format(idx, idx_length))
94
+ continue
95
+
96
+ index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
97
+ * self.num_buckets)
98
+ sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)
99
+
100
+ num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
101
+ if self.is_batch_full(num_tokens, buckets[index_buckets]):
102
+ # yield this batch
103
+ yield buckets[index_buckets]
104
+ buckets[index_buckets] = []
105
+ sample_len[index_buckets] = 0
106
+
107
+ buckets[index_buckets].append(idx)
108
+
109
+ # process left-over
110
+ leftover_batch = []
111
+ leftover_sample_len = 0
112
+ leftover = [idx for bucket in buckets for idx in bucket]
113
+ for idx in leftover:
114
+ idx_length = self.num_tokens_fn(idx)
115
+ leftover_sample_len = max(leftover_sample_len, idx_length)
116
+ num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
117
+ if self.is_batch_full(num_tokens, leftover_batch):
118
+ yield leftover_batch
119
+ leftover_batch = []
120
+ leftover_sample_len = 0
121
+ leftover_batch.append(idx)
122
+
123
+ if len(leftover_batch) > 0 and not self.drop_last:
124
+ yield leftover_batch
125
+
126
+ def __len__(self):
127
+ # we do not know the exactly batch size, so do not call len(dataloader)
128
+ pass
129
+
130
+
131
+ class AudioDataset(torch.utils.data.Dataset):
132
+ def __init__(self, h5_path, ann_path, tokenizer_path):
133
+ self.h5_path = h5_path
134
+ with open(ann_path, 'r', encoding='utf-8') as f:
135
+ lines = f.readlines()
136
+ ls = [l.split("|") for l in lines]
137
+ ls_T = list(zip(*ls))
138
+ del ls_T[-1]
139
+ self.h5_paths, self.durations, self.langs, self.texts = \
140
+ list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
141
+ self.durations = [float(dur) for dur in self.durations]
142
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
143
+
144
+ self._archive = None
145
+
146
+ def __len__(self):
147
+ return len(self.h5_paths)
148
+
149
+ def get_dur(self, idx):
150
+ return self.durations[idx]
151
+
152
+ @property
153
+ def archive(self):
154
+ if self._archive is None: # lazy loading here!
155
+ self._archive = h5py.File(self.h5_path, "r")
156
+ return self._archive
157
+ def __getitem__(self, idx):
158
+ archive = self.archive
159
+ h5_path = self.h5_paths[idx]
160
+ sub = archive[h5_path]
161
+ audio_tokens = sub['audio'][()]
162
+ phone_tokens = sub['text'][()]
163
+ dur = self.durations[idx]
164
+ lang = self.langs[idx]
165
+ text = self.texts[idx]
166
+ # tokenization should be done within dataloader
167
+ phones = seq2phone(phone_tokens)
168
+ phones = phones.replace(" ", "_")
169
+ if not len(phones):
170
+ cptpho_tokens = self.tokenizer.encode(text).ids
171
+ else:
172
+ cptpho_tokens = self.tokenizer.encode(phones).ids
173
+ assert len(cptpho_tokens)
174
+ return {
175
+ 'utt_id': h5_path,
176
+ 'text': text,
177
+ 'audio': None,
178
+ 'audio_lens': None,
179
+ 'audio_features': audio_tokens,
180
+ 'audio_features_lens': len(audio_tokens.T),
181
+ 'text_tokens': np.array(cptpho_tokens),
182
+ 'text_tokens_lens': len(cptpho_tokens),
183
+ 'language': language_dict[lang],
184
+ }
185
+
186
+ def collate(batch):
187
+ utt_id_s = [b['utt_id'] for b in batch]
188
+ text_s = [b['text'] for b in batch]
189
+
190
+ audio_s = [b['audio'] for b in batch]
191
+ audio_lens_s = [b['audio_lens'] for b in batch]
192
+
193
+ audio_features_lens_s = [b['audio_features_lens'] for b in batch]
194
+ # create an empty tensor with maximum audio feature length
195
+ audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1 # audio pad with -1
196
+
197
+ text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
198
+ # create an empty tensor with maximum text tokens length
199
+ text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3 # [PAD] token id 3
200
+
201
+ language_s = [b['language'] for b in batch]
202
+
203
+ for i, b in enumerate(batch):
204
+ audio_features = b['audio_features']
205
+ audio_features_lens = b['audio_features_lens']
206
+ audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)
207
+
208
+ text_tokens = b['text_tokens']
209
+ text_tokens_lens = b['text_tokens_lens']
210
+ text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)
211
+
212
+ batch = {
213
+ 'utt_id': utt_id_s,
214
+ 'text': text_s,
215
+ 'audio': audio_s,
216
+ 'audio_lens': audio_lens_s,
217
+ 'audio_features': audio_features_s,
218
+ 'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
219
+ 'text_tokens': text_tokens_s,
220
+ 'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
221
+ 'languages': torch.LongTensor(np.array(language_s)),
222
+ }
223
+ return batch
224
+
225
+ def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
226
+ train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
227
+ ann_path=f"{data_dir}/audio_ann_sum.txt",
228
+ tokenizer_path=f"{data_dir}/bpe_69.json")
229
+ ran_sampler = torch.utils.data.distributed.DistributedSampler(
230
+ train_dataset,
231
+ num_replicas=n_gpus,
232
+ rank=rank,
233
+ shuffle=True,
234
+ )
235
+ dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
236
+ max_tokens=max_duration)
237
+
238
+
239
+ train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
240
+ batch_sampler=dynamic_sampler)
241
+
242
+ return train_loader
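A toy illustration of the length-bucketing behaviour of DynamicBatchSampler above, using a plain sequential sampler instead of the DistributedSampler wired up in create_dataloader. It assumes the repository root is on PYTHONPATH and that h5py/tokenizers are installed so data.dataset imports cleanly; the durations are made up:

```python
import torch
from data.dataset import DynamicBatchSampler

durations = [3.0, 4.5, 12.0, 2.0, 9.0, 7.5, 1.5, 11.0]  # per-sample lengths
base_sampler = torch.utils.data.SequentialSampler(range(len(durations)))

batch_sampler = DynamicBatchSampler(
    base_sampler,
    num_tokens_fn=lambda idx: durations[idx],  # length lookup per index
    num_buckets=4,
    min_size=0,
    max_size=15,    # samples longer than this are skipped
    max_tokens=30,  # cap on (batch size x longest sample in the bucket)
)

# Indices of similar length end up in the same batch.
for batch in batch_sampler:
    print(batch, [durations[i] for i in batch])
```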
data/fbank.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from dataclasses import asdict, dataclass
19
+ from typing import Any, Dict, Optional, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ # from lhotse.features.base import FeatureExtractor
24
+ # from lhotse.utils import EPSILON, Seconds, compute_num_frames
25
+ from librosa.filters import mel as librosa_mel_fn
26
+
27
+
28
+ @dataclass
29
+ class BigVGANFbankConfig:
30
+ # Spectrogram-related part
31
+ # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
32
+ frame_length: Seconds = 1024 / 24000.0
33
+ frame_shift: Seconds = 256 / 24000.0
34
+ remove_dc_offset: bool = True
35
+ round_to_power_of_two: bool = True
36
+
37
+ # Fbank-related part
38
+ low_freq: float = 0.0
39
+ high_freq: float = 12000.0
40
+ num_mel_bins: int = 100
41
+ use_energy: bool = False
42
+
43
+ def to_dict(self) -> Dict[str, Any]:
44
+ return asdict(self)
45
+
46
+ @staticmethod
47
+ def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
48
+ return BigVGANFbankConfig(**data)
49
+
50
+
51
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52
+ return torch.log(torch.clamp(x, min=clip_val) * C)
53
+
54
+
55
+ def spectral_normalize_torch(magnitudes):
56
+ output = dynamic_range_compression_torch(magnitudes)
57
+ return output
58
+
59
+
60
+ # https://github.com/NVIDIA/BigVGAN
61
+ # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
62
+ class BigVGANFbank(FeatureExtractor):
63
+ name = "fbank"
64
+ config_type = BigVGANFbankConfig
65
+
66
+ def __init__(self, config: Optional[Any] = None):
67
+ super(BigVGANFbank, self).__init__(config)
68
+ sampling_rate = 24000
69
+ self.mel_basis = torch.from_numpy(
70
+ librosa_mel_fn(
71
+ sampling_rate,
72
+ 1024,
73
+ self.config.num_mel_bins,
74
+ self.config.low_freq,
75
+ self.config.high_freq,
76
+ ).astype(np.float32)
77
+ )
78
+ self.hann_window = torch.hann_window(1024)
79
+
80
+ def _feature_fn(self, samples, **kwargs):
81
+ win_length, n_fft = 1024, 1024
82
+ hop_size = 256
83
+ if True:
84
+ sampling_rate = 24000
85
+ duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
86
+ expected_num_frames = compute_num_frames(
87
+ duration=duration,
88
+ frame_shift=self.frame_shift,
89
+ sampling_rate=sampling_rate,
90
+ )
91
+ pad_size = (
92
+ (expected_num_frames - 1) * hop_size
93
+ + win_length
94
+ - samples.shape[-1]
95
+ )
96
+ assert pad_size >= 0
97
+
98
+ y = torch.nn.functional.pad(
99
+ samples,
100
+ (0, pad_size),
101
+ mode="constant",
102
+ )
103
+ else:
104
+ y = torch.nn.functional.pad(
105
+ samples,
106
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
107
+ mode="reflect",
108
+ )
109
+
110
+ y = y.squeeze(1)
111
+
112
+ # complex tensor as default, then use view_as_real for future pytorch compatibility
113
+ spec = torch.stft(
114
+ y,
115
+ n_fft,
116
+ hop_length=hop_size,
117
+ win_length=win_length,
118
+ window=self.hann_window,
119
+ center=False,
120
+ pad_mode="reflect",
121
+ normalized=False,
122
+ onesided=True,
123
+ return_complex=True,
124
+ )
125
+ spec = torch.view_as_real(spec)
126
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
127
+
128
+ spec = torch.matmul(self.mel_basis, spec)
129
+ spec = spectral_normalize_torch(spec)
130
+
131
+ return spec.transpose(2, 1).squeeze(0)
132
+
133
+ def extract(
134
+ self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
135
+ ) -> np.ndarray:
136
+ assert sampling_rate == 24000
137
+ params = asdict(self.config)
138
+ params.update({"sample_frequency": sampling_rate, "snip_edges": False})
139
+ params["frame_shift"] *= 1000.0
140
+ params["frame_length"] *= 1000.0
141
+ if not isinstance(samples, torch.Tensor):
142
+ samples = torch.from_numpy(samples)
143
+ # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
144
+ if len(samples.shape) == 1:
145
+ samples = samples.unsqueeze(0)
146
+ features = self._feature_fn(samples, **params).to(torch.float32)
147
+ return features.numpy()
148
+
149
+ @property
150
+ def frame_shift(self) -> Seconds:
151
+ return self.config.frame_shift
152
+
153
+ def feature_dim(self, sampling_rate: int) -> int:
154
+ return self.config.num_mel_bins
155
+
156
+ @staticmethod
157
+ def mix(
158
+ features_a: np.ndarray,
159
+ features_b: np.ndarray,
160
+ energy_scaling_factor_b: float,
161
+ ) -> np.ndarray:
162
+ return np.log(
163
+ np.maximum(
164
+ # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
165
+ EPSILON,
166
+ np.exp(features_a)
167
+ + energy_scaling_factor_b * np.exp(features_b),
168
+ )
169
+ )
170
+
171
+ @staticmethod
172
+ def compute_energy(features: np.ndarray) -> float:
173
+ return float(np.sum(np.exp(features)))
174
+
175
+
176
+ def get_fbank_extractor() -> BigVGANFbank:
177
+ return BigVGANFbank(BigVGANFbankConfig())
178
+
179
+
180
+ if __name__ == "__main__":
181
+ extractor = BigVGANFbank(BigVGANFbankConfig())
182
+
183
+ samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184
+ samples = torch.clip(samples, -1.0, 1.0)
185
+ fbank = extractor.extract(samples, 24000.0)
186
+ print(f"fbank {fbank.shape}")
187
+
188
+ from scipy.io.wavfile import read
189
+
190
+ MAX_WAV_VALUE = 32768.0
191
+
192
+ sampling_rate, samples = read(
193
+ "egs/libritts/prompts/5639_40744_000000_000002.wav"
194
+ )
195
+ print(f"samples: [{samples.min()}, {samples.max()}]")
196
+ fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197
+ print(f"fbank {fbank.shape}")
198
+
199
+ import matplotlib.pyplot as plt
200
+
201
+ _ = plt.figure(figsize=(18, 10))
202
+ plt.imshow(
203
+ X=fbank.transpose(1, 0),
204
+ cmap=plt.get_cmap("jet"),
205
+ aspect="auto",
206
+ interpolation="nearest",
207
+ )
208
+ plt.gca().invert_yaxis()
209
+ plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210
+ plt.close()
211
+
212
+ print("fbank test PASS!")
data/input_strategies.py ADDED
@@ -0,0 +1,159 @@
1
+ import random
2
+ from collections import defaultdict
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from typing import Tuple, Type
5
+
6
+ # from lhotse import CutSet
7
+ # from lhotse.dataset.collation import collate_features
8
+ # from lhotse.dataset.input_strategies import (
9
+ # ExecutorType,
10
+ # PrecomputedFeatures,
11
+ # _get_executor,
12
+ # )
13
+ # from lhotse.utils import fastcopy
14
+
15
+
16
+ class PromptedFeatures:
17
+ def __init__(self, prompts, features):
18
+ self.prompts = prompts
19
+ self.features = features
20
+
21
+ def to(self, device):
22
+ return PromptedFeatures(
23
+ self.prompts.to(device), self.features.to(device)
24
+ )
25
+
26
+ def sum(self):
27
+ return self.features.sum()
28
+
29
+ @property
30
+ def ndim(self):
31
+ return self.features.ndim
32
+
33
+ @property
34
+ def data(self):
35
+ return (self.prompts, self.features)
36
+
37
+
38
+ # class PromptedPrecomputedFeatures(PrecomputedFeatures):
39
+ # """
40
+ # :class:`InputStrategy` that reads pre-computed features, whose manifests
41
+ # are attached to cuts, from disk.
42
+ #
43
+ # It automatically pads the feature matrices with pre or post feature.
44
+ #
45
+ # .. automethod:: __call__
46
+ # """
47
+ #
48
+ # def __init__(
49
+ # self,
50
+ # dataset: str,
51
+ # cuts: CutSet,
52
+ # num_workers: int = 0,
53
+ # executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54
+ # ) -> None:
55
+ # super(PromptedPrecomputedFeatures, self).__init__(
56
+ # num_workers, executor_type
57
+ # )
58
+ #
59
+ # self.utt2neighbors = defaultdict(lambda: [])
60
+ #
61
+ # if dataset.lower() == "libritts":
62
+ # # 909_131041_000013_000002
63
+ # # 909_131041_000013_000003
64
+ # speaker2utts = defaultdict(lambda: [])
65
+ #
66
+ # utt2cut = {}
67
+ # for cut in cuts:
68
+ # speaker = cut.supervisions[0].speaker
69
+ # speaker2utts[speaker].append(cut.id)
70
+ # utt2cut[cut.id] = cut
71
+ #
72
+ # for spk in speaker2utts:
73
+ # uttids = sorted(speaker2utts[spk])
74
+ # # Using the property of sorted keys to find previous utterance
75
+ # # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001
76
+ # if len(uttids) == 1:
77
+ # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78
+ # continue
79
+ #
80
+ # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81
+ # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82
+ #
83
+ # for utt in utt2prevutt:
84
+ # self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85
+ #
86
+ # for utt in utt2postutt:
87
+ # self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88
+ # elif dataset.lower() == "ljspeech":
89
+ # utt2cut = {}
90
+ # uttids = []
91
+ # for cut in cuts:
92
+ # uttids.append(cut.id)
93
+ # utt2cut[cut.id] = cut
94
+ #
95
+ # if len(uttids) == 1:
96
+ # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97
+ # else:
98
+ # # Using the property of sorted keys to find previous utterance
99
+ # # The keys has structure: LJ001-0010
100
+ # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101
+ # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102
+ #
103
+ # for utt in utt2postutt:
104
+ # postutt = utt2postutt[utt]
105
+ # if utt[:5] == postutt[:5]:
106
+ # self.utt2neighbors[utt].append(utt2cut[postutt])
107
+ #
108
+ # for utt in utt2prevutt:
109
+ # prevutt = utt2prevutt[utt]
110
+ # if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111
+ # self.utt2neighbors[utt].append(utt2cut[prevutt])
112
+ # else:
113
+ # raise ValueError
114
+ #
115
+ # def __call__(
116
+ # self, cuts: CutSet
117
+ # ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118
+ # """
119
+ # Reads the pre-computed features from disk/other storage.
120
+ # The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
121
+ #
122
+ # :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123
+ # """
124
+ # features, features_lens = collate_features(
125
+ # cuts,
126
+ # executor=_get_executor(
127
+ # self.num_workers, executor_type=self._executor_type
128
+ # ),
129
+ # )
130
+ #
131
+ # prompts_cuts = []
132
+ # for k, cut in enumerate(cuts):
133
+ # prompts_cut = random.choice(self.utt2neighbors[cut.id])
134
+ # prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135
+ #
136
+ # mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137
+ # # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138
+ # # max_duration=mini_duration,
139
+ # # offset_type="random",
140
+ # # preserve_id=True,
141
+ # # )
142
+ # prompts_cuts = CutSet(
143
+ # cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144
+ # ).truncate(
145
+ # max_duration=mini_duration,
146
+ # offset_type="random",
147
+ # preserve_id=False,
148
+ # )
149
+ #
150
+ # prompts, prompts_lens = collate_features(
151
+ # prompts_cuts,
152
+ # executor=_get_executor(
153
+ # self.num_workers, executor_type=self._executor_type
154
+ # ),
155
+ # )
156
+ #
157
+ # return PromptedFeatures(prompts, features), PromptedFeatures(
158
+ # prompts_lens, features_lens
159
+ # )
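The only active part of this file is PromptedFeatures, a thin container pairing prompt features with target features (the prompted input strategy itself is commented out pending the lhotse imports). A quick sketch of its behaviour with dummy tensors:

```python
import torch
from data.input_strategies import PromptedFeatures

prompts = torch.randn(2, 100, 80)   # (batch, prompt frames, feature dim)
features = torch.randn(2, 400, 80)  # (batch, target frames, feature dim)

pf = PromptedFeatures(prompts, features)
print(pf.ndim)   # 3 -- delegates to the target features
print(pf.sum())  # scalar sum over the target features only
p, f = pf.data   # unpack back into (prompts, features)
pf = pf.to("cuda") if torch.cuda.is_available() else pf
```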
data/tokenizer.py ADDED
@@ -0,0 +1,126 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ from encodec import EncodecModel
24
+ from encodec.utils import convert_audio
25
+
26
+ try:
27
+ from pypinyin import Style, pinyin
28
+ from pypinyin.style._utils import get_finals, get_initials
29
+ except Exception:
30
+ pass
31
+
32
+
33
+ def remove_encodec_weight_norm(model):
34
+ from encodec.modules import SConv1d
35
+ from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
36
+ from torch.nn.utils import remove_weight_norm
37
+
38
+ encoder = model.encoder.model
39
+ for key in encoder._modules:
40
+ if isinstance(encoder._modules[key], SEANetResnetBlock):
41
+ remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
42
+ block_modules = encoder._modules[key].block._modules
43
+ for skey in block_modules:
44
+ if isinstance(block_modules[skey], SConv1d):
45
+ remove_weight_norm(block_modules[skey].conv.conv)
46
+ elif isinstance(encoder._modules[key], SConv1d):
47
+ remove_weight_norm(encoder._modules[key].conv.conv)
48
+
49
+ decoder = model.decoder.model
50
+ for key in decoder._modules:
51
+ if isinstance(decoder._modules[key], SEANetResnetBlock):
52
+ remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
53
+ block_modules = decoder._modules[key].block._modules
54
+ for skey in block_modules:
55
+ if isinstance(block_modules[skey], SConv1d):
56
+ remove_weight_norm(block_modules[skey].conv.conv)
57
+ elif isinstance(decoder._modules[key], SConvTranspose1d):
58
+ remove_weight_norm(decoder._modules[key].convtr.convtr)
59
+ elif isinstance(decoder._modules[key], SConv1d):
60
+ remove_weight_norm(decoder._modules[key].conv.conv)
61
+
62
+
63
+ class AudioTokenizer:
64
+ """EnCodec audio."""
65
+
66
+ def __init__(
67
+ self,
68
+ device: Any = None,
69
+ ) -> None:
70
+ # Instantiate a pretrained EnCodec model
71
+ model = EncodecModel.encodec_model_24khz()
72
+ model.set_target_bandwidth(6.0)
73
+ remove_encodec_weight_norm(model)
74
+
75
+ if not device:
76
+ device = torch.device("cpu")
77
+ if torch.cuda.is_available():
78
+ device = torch.device("cuda:0")
79
+ if torch.backends.mps.is_available():
80
+ device = torch.device("mps")
81
+
82
+ self._device = device
83
+
84
+ self.codec = model.to(device)
85
+ self.sample_rate = model.sample_rate
86
+ self.channels = model.channels
87
+
88
+ @property
89
+ def device(self):
90
+ return self._device
91
+
92
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
93
+ return self.codec.encode(wav.to(self.device))
94
+
95
+ def decode(self, frames: torch.Tensor) -> torch.Tensor:
96
+ return self.codec.decode(frames)
97
+
98
+
99
+ def tokenize_audio(tokenizer: AudioTokenizer, audio):
100
+ # Load and pre-process the audio waveform
101
+ if isinstance(audio, str):
102
+ wav, sr = torchaudio.load(audio)
103
+ else:
104
+ wav, sr = audio
105
+ wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
106
+ wav = wav.unsqueeze(0)
107
+
108
+ # Extract discrete codes from EnCodec
109
+ with torch.no_grad():
110
+ encoded_frames = tokenizer.encode(wav)
111
+ return encoded_frames
112
+
113
+
114
+ if __name__ == "__main__":
115
+ model = EncodecModel.encodec_model_24khz()
116
+ model.set_target_bandwidth(6.0)
117
+
118
+ samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
119
+ torch.float32
120
+ )
121
+ codes_raw = model.encode(samples)
122
+
123
+ remove_encodec_weight_norm(model)
124
+ codes_norm = model.encode(samples)
125
+
126
+ assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
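A minimal sketch of tokenizing a waveform with the AudioTokenizer above; it assumes encodec and torchaudio are installed, and the wav path is a placeholder:

```python
import torch
from data.tokenizer import AudioTokenizer, tokenize_audio

tokenizer = AudioTokenizer()  # picks cuda / mps / cpu automatically
encoded_frames = tokenize_audio(tokenizer, "prompts/example_24k.wav")  # hypothetical path

# EnCodec returns a list of (codes, scale) pairs; at 6 kbps there are 8 codebooks.
codes = encoded_frames[0][0]  # shape: (batch, n_quantizers, n_frames)
print(codes.shape)

with torch.no_grad():
    wav = tokenizer.decode(encoded_frames)  # back to a 24 kHz waveform tensor
```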
macros.py ADDED
@@ -0,0 +1,44 @@
+ NUM_LAYERS = 12
+ NUM_HEAD = 16
+ N_DIM = 1024
+ PREFIX_MODE = 1
+ NUM_QUANTIZERS = 8
+ SAMPLE_RATE = 24000
+
+ lang2token = {
+     'zh': "[ZH]",
+     'ja': "[JA]",
+     "en": "[EN]",
+     "AR": "[AR]",
+     'mix': "",
+ }
+
+ lang2code = {
+     'zh': 0,
+     'ja': 1,
+     "en": 2,
+     "ar": 3,
+ }
+
+ token2lang = {
+     '[ZH]': "zh",
+     '[JA]': "ja",
+     "[EN]": "en",
+     "[AR]": "ar",
+     "": "mix"
+ }
+
+ code2lang = {
+     0: 'zh',
+     1: 'ja',
+     2: "en",
+     3: "ar",
+ }
+
+ langdropdown2token = {
+     'English': "[EN]",
+     '中文': "[ZH]",
+     '日本語': "[JA]",
+     'عربي': "[AR]",
+     'Mix': "",
+ }
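These tables chain together when resolving a UI language choice down to the numeric code used by the model, for example:

```python
from macros import langdropdown2token, token2lang, lang2code

choice = "English"
token = langdropdown2token[choice]  # "[EN]"
lang = token2lang[token]            # "en"
code = lang2code[lang]              # 2
print(token, lang, code)
```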
main.py ADDED
@@ -0,0 +1,355 @@
1
+ from fastapi import FastAPI
2
+ from sqlalchemy import create_engine
3
+ from sqlalchemy.orm import sessionmaker
4
+ from s2smodels import Base, Audio_segment, AudioGeneration
5
+ from pydub import AudioSegment
6
+ import os
7
+ from fastapi import FastAPI, Response
8
+ import torch
9
+ from fastapi.responses import JSONResponse
10
+ from utils.prompt_making import make_prompt
11
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
12
+ from utils.generation import SAMPLE_RATE, generate_audio, preload_models
13
+ from io import BytesIO
14
+ from pyannote.audio import Pipeline
15
+ import soundfile as sf
16
+ from fastapi_cors import CORS
17
+ DATABASE_URL = "sqlite:///./sql_app.db"
18
+ engine = create_engine(DATABASE_URL)
19
+ Session = sessionmaker(bind=engine)
20
+
21
+ app = FastAPI()
22
+ """
23
+ origins = ["*"]
24
+
25
+ app.add_middleware(
26
+ CORS,
27
+ allow_origins=origins,
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+ """
33
+ Base.metadata.create_all(engine)
34
+
35
+
36
+ @app.get("/")
37
+ def root():
38
+ return {"message": "No result"}
39
+
40
+ # add audio segments to the Audio_segment table
41
+ def create_segment(start_time: float, end_time: float, audio: AudioSegment, type: str):
42
+ session = Session()
43
+ audio_bytes = BytesIO()
44
+ audio.export(audio_bytes, format='wav')
45
+ audio_bytes = audio_bytes.getvalue()
46
+ segment = Audio_segment(start_time=start_time, end_time=end_time, type=type, audio=audio_bytes)
47
+ session.add(segment)
48
+ session.commit()
49
+ session.close()
50
+
51
+ return {"status_code": 200, "message": "success"}
52
+
53
+
54
+ #add target audio to AudioGeneration Table
55
+ def generate_target(audio: AudioSegment):
56
+ session = Session()
57
+ audio_bytes = BytesIO()
58
+ audio.export(audio_bytes, format='wav')
59
+ audio_bytes = audio_bytes.getvalue()
60
+ target_audio = AudioGeneration(audio=audio_bytes)
61
+ session.add(target_audio)
62
+ session.commit()
63
+ session.close()
64
+
65
+ return {"status_code": 200, "message": "success"}
66
+ """
67
+ audio segmentation into speech and non-speech using a segmentation model
68
+ """
69
+ def audio_speech_nonspeech_detection(audio_url):
70
+ pipeline = Pipeline.from_pretrained(
71
+ "pyannote/speaker-diarization-3.0"
72
+ )
73
+ diarization = pipeline(audio_url)
74
+ speaker_regions=[]
75
+ for turn, _,speaker in diarization.itertracks(yield_label=True):
76
+ speaker_regions.append({"start":turn.start,"end":turn.end})
77
+ sound = AudioSegment.from_wav(audio_url)
78
+ speaker_regions.sort(key=lambda x: x['start'])
79
+ non_speech_regions = []
80
+ for i in range(1, len(speaker_regions)):
81
+ start = speaker_regions[i-1]['end']
82
+ end = speaker_regions[i]['start']
83
+ if end > start:
84
+ non_speech_regions.append({'start': start, 'end': end})
85
+ first_speech_start = speaker_regions[0]['start']
86
+ if first_speech_start > 0:
87
+ non_speech_regions.insert(0, {'start': 0, 'end': first_speech_start})
88
+ last_speech_end = speaker_regions[-1]['end']
89
+ total_audio_duration = len(sound)
90
+ if last_speech_end < total_audio_duration:
91
+ non_speech_regions.append({'start': last_speech_end, 'end': total_audio_duration})
92
+ return speaker_regions,non_speech_regions
93
+
94
+ """
95
+ save speech and non-speech segments in audio_segment table
96
+ """
97
+ def split_audio_segments(audio_url):
98
+ sound = AudioSegment.from_wav(audio_url)
99
+ speech_segments, non_speech_segment = audio_speech_nonspeech_detection(audio_url)
100
+ # Process speech segments
101
+ for i, speech_segment in enumerate(speech_segments):
102
+ start = int(speech_segment['start'] * 1000)
103
+ end = int(speech_segment['end'] * 1000)
104
+ segment = sound[start:end]
105
+ create_segment(start_time=start/1000,
106
+ end_time=end/1000,
107
+ type="speech",audio=segment)
108
+ # Process non-speech segments
109
+ for i, non_speech_segment in enumerate(non_speech_segment):
110
+ start = int(non_speech_segment['start'] * 1000)
111
+ end = int(non_speech_segment['end'] * 1000)
112
+ segment = sound[start:end]
113
+ create_segment(start_time=start/1000,
114
+ end_time=end/1000,
115
+ type="non-speech",audio=segment)
116
+
117
+ #@app.post("/translate_en_ar/")
118
+ def en_text_to_ar_text_translation(text):
119
+ pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M")
120
+ result=pipe(text,src_lang='English',tgt_lang='Egyptian Arabic')
121
+ return result[0]['translation_text']
122
+
123
+
124
+ def make_prompt_audio(name,audio_path):
125
+ make_prompt(name=name, audio_prompt_path=audio_path)
126
+
127
+ # whisper model for speech to text process (english language)
128
+ #@app.post("/en_speech_ar_text/")
129
+ def en_speech_to_en_text_process(segment):
130
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
131
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
132
+ model_id = "openai/whisper-large-v3"
133
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
134
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
135
+ model.to(device)
136
+ processor = AutoProcessor.from_pretrained(model_id)
137
+ pipe = pipeline(
138
+ "automatic-speech-recognition",
139
+ model=model,
140
+ tokenizer=processor.tokenizer,
141
+ feature_extractor=processor.feature_extractor,
142
+ max_new_tokens=128,
143
+ chunk_length_s=30,
144
+ batch_size=16,
145
+ return_timestamps=True,
146
+ torch_dtype=torch_dtype,
147
+ device=device,
148
+ )
149
+ result = pipe(segment)
150
+ return result["text"]
151
+
152
+ #text to speech using VALL-E-X model
153
+ #@app.post("/text_to_speech/")
154
+ def text_to_speech(segment_id, target_text, audio_prompt):
155
+ preload_models()
156
+ session = Session()
157
+ segment = session.query(Audio_segment).get(segment_id)
158
+ make_prompt_audio(name=f"audio_{segment_id}",audio_path=audio_prompt)
159
+ audio_array = generate_audio(target_text,f"audio_{segment_id}")
160
+ temp_file = BytesIO()
161
+ sf.write(temp_file, audio_array, SAMPLE_RATE, format='wav')
162
+ temp_file.seek(0)
163
+ segment.audio = temp_file.getvalue()
164
+ session.commit()
165
+ session.close()
166
+ temp_file.close()
167
+ #os.remove(temp_file)
168
+
169
+ """
170
+ reconstruct the target audio from all updated segments
+ in the audio_segment table, then remove all Audio_segment records
172
+ """
173
+ def construct_audio():
174
+ session = Session()
175
+ # Should be ordered by start_time
176
+ segments = session.query(Audio_segment).order_by('start_time').all()
177
+ audio_files = []
178
+ for segment in segments:
179
+ audio_files.append(AudioSegment.from_file(BytesIO(segment.audio), format='wav'))
180
+ target_audio = sum(audio_files, AudioSegment.empty())
181
+ generate_target(audio=target_audio)
182
+
183
+ # Delete all records in Audio_segment table
184
+ session.query(Audio_segment).delete()
185
+ session.commit()
186
+ session.close()
187
+
188
+ """
189
+ source => english speech
190
+ target => arabic speech
191
+ """
192
+
193
+ #@app.post("/en_speech_ar_speech/")
194
+ def speech_to_speech_translation_en_ar(audio_url):
195
+ session=Session()
196
+ target_text=None
197
+ split_audio_segments(audio_url)
198
+ #filtering by type
199
+ speech_segments = session.query(Audio_segment).filter(Audio_segment.type == "speech").all()
200
+ for segment in speech_segments:
201
+ audio_data = segment.audio
202
+ text = en_speech_to_en_text_process(audio_data)
203
+ if text:
204
+ target_text=en_text_to_ar_text_translation(text)
205
+ else:
206
+ print("speech_to_text_process function not return result. ")
207
+ if target_text is None:
208
+ print("Target text is None.")
209
+ else:
210
+ segment_id = segment.id
211
+ segment_duration = segment.end_time - segment.start_time
212
+ if segment_duration <=15:
213
+ text_to_speech(segment_id,target_text,segment.audio)
214
+ else:
215
+ audio_data=extract_15_seconds(segment.audio,segment.start_time,segment.end_time)
216
+ text_to_speech(segment_id,target_text,audio_data)
217
+ os.remove(audio_data)
218
+ construct_audio()
219
+ return JSONResponse(status_code=200, content={"status_code": "successfully"})
220
+
221
+
222
+ @app.get("/get_ar_audio/")
223
+ async def get_ar_audio(audio_url):
224
+ #speech_to_speech_translation_en_ar(audio_url)
225
+ session = Session()
226
+ # Get target audio from AudioGeneration
227
+ target_audio = session.query(AudioGeneration).order_by(AudioGeneration.id.desc()).first()
228
+ # Remove target audio from database
229
+ #session.query(AudioGeneration).delete()
230
+ #session.commit()
231
+ #session.close()
232
+ if target_audio is None:
233
+ raise ValueError("No audio found in the database")
234
+
235
+ audio_bytes = target_audio.audio
236
+ return Response(content=audio_bytes, media_type="audio/wav")
237
+
238
+
239
+ # Speech-to-speech processing from Arabic to English
240
+
241
+ #@app.post("/ar_speech_to_en_text/")
242
+ def ar_speech_to_ar_text_process(segment):
243
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
244
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
245
+ model_id = "openai/whisper-large-v3"
246
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
247
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
248
+ model.to(device)
249
+ processor = AutoProcessor.from_pretrained(model_id)
250
+ pipe = pipeline(
251
+ "automatic-speech-recognition",
252
+ model=model,
253
+ tokenizer=processor.tokenizer,
254
+ feature_extractor=processor.feature_extractor,
255
+ max_new_tokens=128,
256
+ chunk_length_s=30,
257
+ batch_size=16,
258
+ return_timestamps=True,
259
+ torch_dtype=torch_dtype,
260
+ device=device,
261
+ )
262
+ result = pipe(segment,generate_kwargs={"language": "arabic"})
263
+ return result["text"]
264
+
265
+ #@app.post("/ar_translate/")
266
+ def ar_text_to_en_text_translation(text):
267
+ pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M")
268
+ result = pipe(text, src_lang="arz_Arab", tgt_lang="eng_Latn")  # NLLB-200 expects FLORES-200 codes (Egyptian Arabic -> English)
269
+ return result[0]['translation_text']
270
+
271
+
272
+ """
273
+ source => arabic speech
274
+ target => english speeech
275
+ """
276
+ def speech_to_speech_translation_ar_en(audio_url):
277
+ session=Session()
278
+ target_text=None
279
+ split_audio_segments(audio_url)
280
+ #filtering by type
281
+ speech_segments = session.query(Audio_segment).filter(Audio_segment.type == "speech").all()
282
+ for segment in speech_segments:
283
+ audio_data = segment.audio
284
+ text = ar_speech_to_ar_text_process(audio_data)
285
+ if text:
286
+ target_text=ar_text_to_en_text_translation(text)
287
+ else:
288
+ print("speech_to_text_process function not return result. ")
289
+ if target_text is None:
290
+ print("Target text is None.")
291
+ else:
292
+ segment_id = segment.id
293
+ segment_duration = segment.end_time - segment.start_time
294
+ if segment_duration <=15:
295
+ text_to_speech(segment_id,target_text,segment.audio)
296
+ else:
297
+ audio_data=extract_15_seconds(segment.audio,segment.start_time,segment.end_time)
298
+ text_to_speech(segment_id,target_text,audio_data)
299
+ os.remove(audio_data)
300
+ construct_audio()
301
+ return JSONResponse(status_code=200, content={"status_code": "successfully"})
302
+
303
+
304
+ @app.get("/get_en_audio/")
305
+ async def get_en_audio(audio_url):
306
+ speech_to_speech_translation_ar_en(audio_url)
307
+ session = Session()
308
+ # Get target audio from AudioGeneration
309
+ target_audio = session.query(AudioGeneration).order_by(AudioGeneration.id.desc()).first()
310
+ # Remove target audio from database
311
+ #session.query(AudioGeneration).delete()
312
+ #session.commit()
313
+ #session.close()
314
+ if target_audio is None:
315
+ raise ValueError("No audio found in the database")
316
+
317
+ audio_bytes = target_audio.audio
318
+ return Response(content=audio_bytes, media_type="audio/wav")
319
+
320
+
321
+
322
+ @app.get("/audio_segments/")
323
+ def get_all_audio_segments():
324
+ session=Session()
325
+ segments = session.query(Audio_segment).all()
326
+ segment_dicts = []
327
+ for segment in segments:
328
+ if segment.audio is None:
329
+ raise ValueError("No audio found in the database")
330
+
331
+ audio_bytes = segment.audio
332
+ file_path = f"segments//segment{segment.id}_audio.wav"
333
+ with open(file_path, "wb") as file:
334
+ file.write(audio_bytes)
335
+ segment_dicts.append({
336
+ "id": segment.id,
337
+ "start_time": segment.start_time,
338
+ "end_time": segment.end_time,
339
+ "type": segment.type,
340
+ "audio_url":file_path
341
+ })
342
+ session.close()
343
+ return {"segments":segment_dicts}
344
+
345
+
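+ # The callers above use a segment's audio as a VALL-E X voice prompt only when
+ # it is at most 15 s long; this helper trims the first 15 s (or less) of a
+ # longer segment into a temporary WAV file to use as the prompt.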
346
+ def extract_15_seconds(audio_data, start_time, end_time):
347
+ audio_segment = AudioSegment.from_file(BytesIO(audio_data), format='wav')
348
+ start_ms = start_time * 1000
349
+ end_ms = min((start_time + 15) * 1000, end_time * 1000)
350
+ extracted_segment = audio_segment[start_ms:end_ms]
351
+ temp_wav_path = "temp.wav"
352
+ extracted_segment.export(temp_wav_path, format="wav")
353
+
354
+ return temp_wav_path
355
+
models/__init__.py ADDED
@@ -0,0 +1,136 @@
1
+ import argparse
2
+
3
+ import torch.nn as nn
4
+ # from icefall.utils import AttributeDict, str2bool
5
+
6
+ from .macros import (
7
+ NUM_AUDIO_TOKENS,
8
+ NUM_MEL_BINS,
9
+ NUM_SPEAKER_CLASSES,
10
+ NUM_TEXT_TOKENS,
11
+ SPEAKER_EMBEDDING_DIM,
12
+ )
13
+ from .transformer import Transformer
14
+ from .vallex import VALLE, VALLF
15
+ from .visualizer import visualize
16
+
17
+
18
+ def add_model_arguments(parser: argparse.ArgumentParser):
19
+ parser.add_argument(
20
+ "--model-name",
21
+ type=str,
22
+ default="VALL-E",
23
+ help="VALL-E, VALL-F, Transformer.",
24
+ )
25
+ parser.add_argument(
26
+ "--decoder-dim",
27
+ type=int,
28
+ default=1024,
29
+ help="Embedding dimension in the decoder model.",
30
+ )
31
+ parser.add_argument(
32
+ "--nhead",
33
+ type=int,
34
+ default=16,
35
+ help="Number of attention heads in the Decoder layers.",
36
+ )
37
+ parser.add_argument(
38
+ "--num-decoder-layers",
39
+ type=int,
40
+ default=12,
41
+ help="Number of Decoder layers.",
42
+ )
43
+ parser.add_argument(
44
+ "--scale-factor",
45
+ type=float,
46
+ default=1.0,
47
+ help="Model scale factor which will be assigned different meanings in different models.",
48
+ )
49
+ parser.add_argument(
50
+ "--norm-first",
51
+ type=bool,
52
+ default=True,
53
+ help="Pre or Post Normalization.",
54
+ )
55
+ parser.add_argument(
56
+ "--add-prenet",
57
+ type=bool,
58
+ default=False,
59
+ help="Whether add PreNet after Inputs.",
60
+ )
61
+
62
+ # VALL-E & F
63
+ parser.add_argument(
64
+ "--prefix-mode",
65
+ type=int,
66
+ default=1,
67
+ help="The mode for how to prefix VALL-E NAR Decoder, "
68
+ "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
69
+ )
70
+ parser.add_argument(
71
+ "--share-embedding",
72
+ type=bool,
73
+ default=True,
74
+ help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
75
+ )
76
+ parser.add_argument(
77
+ "--prepend-bos",
78
+ type=bool,
79
+ default=False,
80
+ help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
81
+ )
82
+ parser.add_argument(
83
+ "--num-quantizers",
84
+ type=int,
85
+ default=8,
86
+ help="Number of Audio/Semantic quantization layers.",
87
+ )
88
+
89
+ # Transformer
90
+ parser.add_argument(
91
+ "--scaling-xformers",
92
+ type=bool,
93
+ default=False,
94
+ help="Apply Reworked Conformer scaling on Transformers.",
95
+ )
96
+
97
+
98
+ def get_model(params) -> nn.Module:
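+ # `params` is expected to expose the attributes registered by
+ # add_model_arguments (model_name, decoder_dim, nhead, num_decoder_layers,
+ # norm_first, add_prenet, prefix_mode, share_embedding, scale_factor,
+ # prepend_bos, num_quantizers, scaling_xformers).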
99
+ if params.model_name.lower() in ["vall-f", "vallf"]:
100
+ model = VALLF(
101
+ params.decoder_dim,
102
+ params.nhead,
103
+ params.num_decoder_layers,
104
+ norm_first=params.norm_first,
105
+ add_prenet=params.add_prenet,
106
+ prefix_mode=params.prefix_mode,
107
+ share_embedding=params.share_embedding,
108
+ nar_scale_factor=params.scale_factor,
109
+ prepend_bos=params.prepend_bos,
110
+ num_quantizers=params.num_quantizers,
111
+ )
112
+ elif params.model_name.lower() in ["vall-e", "valle"]:
113
+ model = VALLE(
114
+ params.decoder_dim,
115
+ params.nhead,
116
+ params.num_decoder_layers,
117
+ norm_first=params.norm_first,
118
+ add_prenet=params.add_prenet,
119
+ prefix_mode=params.prefix_mode,
120
+ share_embedding=params.share_embedding,
121
+ nar_scale_factor=params.scale_factor,
122
+ prepend_bos=params.prepend_bos,
123
+ num_quantizers=params.num_quantizers,
124
+ )
125
+ else:
126
+ assert params.model_name in ["Transformer"]
127
+ model = Transformer(
128
+ params.decoder_dim,
129
+ params.nhead,
130
+ params.num_decoder_layers,
131
+ norm_first=params.norm_first,
132
+ add_prenet=params.add_prenet,
133
+ scaling_xformers=params.scaling_xformers,
134
+ )
135
+
136
+ return model
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.4 kB). View file
 
models/__pycache__/macros.cpython-311.pyc ADDED
Binary file (335 Bytes). View file
 
models/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
models/__pycache__/vallex.cpython-311.pyc ADDED
Binary file (37.6 kB). View file
 
models/__pycache__/visualizer.cpython-311.pyc ADDED
Binary file (5.17 kB). View file
 
models/macros.py ADDED
@@ -0,0 +1,11 @@
1
+ # Text
2
+ NUM_TEXT_TOKENS = 2048
3
+
4
+ # Audio
5
+ NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6
+ NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7
+
8
+
9
+ # Speaker
10
+ NUM_SPEAKER_CLASSES = 4096
11
+ SPEAKER_EMBEDDING_DIM = 64
models/transformer.py ADDED
@@ -0,0 +1,394 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ # from icefall.utils import make_pad_mask
22
+ # from torchmetrics.classification import BinaryAccuracy
23
+
24
+ from models.vallex import Transpose
25
+ from modules.embedding import SinePositionalEmbedding, TokenEmbedding
26
+ from modules.scaling import BalancedDoubleSwish, ScaledLinear
27
+ from modules.transformer import (
28
+ BalancedBasicNorm,
29
+ IdentityNorm,
30
+ TransformerDecoderLayer,
31
+ TransformerEncoder,
32
+ TransformerEncoderLayer,
33
+ )
34
+
35
+ from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
36
+ from .visualizer import visualize
37
+
38
+ IdentityNorm = IdentityNorm
39
+
40
+
41
+ class Transformer(nn.Module):
42
+ """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
43
+ Neural Speech Synthesis with Transformer Network
44
+ https://arxiv.org/abs/1809.08895
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ d_model: int,
50
+ nhead: int,
51
+ num_layers: int,
52
+ norm_first: bool = True,
53
+ add_prenet: bool = False,
54
+ scaling_xformers: bool = False,
55
+ ):
56
+ """
57
+ Args:
58
+ d_model:
59
+ The number of expected features in the input (required).
60
+ nhead:
61
+ The number of heads in the multiheadattention models (required).
62
+ num_layers:
63
+ The number of sub-decoder-layers in the decoder (required).
64
+ """
65
+ super().__init__()
66
+ self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
67
+
68
+ if add_prenet:
69
+ self.encoder_prenet = nn.Sequential(
70
+ Transpose(),
71
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
72
+ nn.BatchNorm1d(d_model),
73
+ nn.ReLU(),
74
+ nn.Dropout(0.5),
75
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
76
+ nn.BatchNorm1d(d_model),
77
+ nn.ReLU(),
78
+ nn.Dropout(0.5),
79
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
80
+ nn.BatchNorm1d(d_model),
81
+ nn.ReLU(),
82
+ nn.Dropout(0.5),
83
+ Transpose(),
84
+ nn.Linear(d_model, d_model),
85
+ )
86
+
87
+ self.decoder_prenet = nn.Sequential(
88
+ nn.Linear(NUM_MEL_BINS, 256),
89
+ nn.ReLU(),
90
+ nn.Dropout(0.5),
91
+ nn.Linear(256, 256),
92
+ nn.ReLU(),
93
+ nn.Dropout(0.5),
94
+ nn.Linear(256, d_model),
95
+ )
96
+
97
+ assert scaling_xformers is False # TODO: update this block
98
+ else:
99
+ self.encoder_prenet = nn.Identity()
100
+ if scaling_xformers:
101
+ self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
102
+ else:
103
+ self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
104
+
105
+ self.encoder_position = SinePositionalEmbedding(
106
+ d_model,
107
+ dropout=0.1,
108
+ scale=False,
109
+ )
110
+ self.decoder_position = SinePositionalEmbedding(
111
+ d_model, dropout=0.1, scale=False
112
+ )
113
+
114
+ if scaling_xformers:
115
+ self.encoder = TransformerEncoder(
116
+ TransformerEncoderLayer(
117
+ d_model,
118
+ nhead,
119
+ dim_feedforward=d_model * 4,
120
+ dropout=0.1,
121
+ batch_first=True,
122
+ norm_first=norm_first,
123
+ linear1_self_attention_cls=ScaledLinear,
124
+ linear2_self_attention_cls=partial(
125
+ ScaledLinear, initial_scale=0.01
126
+ ),
127
+ linear1_feedforward_cls=ScaledLinear,
128
+ linear2_feedforward_cls=partial(
129
+ ScaledLinear, initial_scale=0.01
130
+ ),
131
+ activation=partial(
132
+ BalancedDoubleSwish,
133
+ channel_dim=-1,
134
+ max_abs=10.0,
135
+ min_prob=0.25,
136
+ ),
137
+ layer_norm_cls=IdentityNorm,
138
+ ),
139
+ num_layers=num_layers,
140
+ norm=BalancedBasicNorm(d_model) if norm_first else None,
141
+ )
142
+
143
+ self.decoder = nn.TransformerDecoder(
144
+ TransformerDecoderLayer(
145
+ d_model,
146
+ nhead,
147
+ dim_feedforward=d_model * 4,
148
+ dropout=0.1,
149
+ batch_first=True,
150
+ norm_first=norm_first,
151
+ linear1_self_attention_cls=ScaledLinear,
152
+ linear2_self_attention_cls=partial(
153
+ ScaledLinear, initial_scale=0.01
154
+ ),
155
+ linear1_feedforward_cls=ScaledLinear,
156
+ linear2_feedforward_cls=partial(
157
+ ScaledLinear, initial_scale=0.01
158
+ ),
159
+ activation=partial(
160
+ BalancedDoubleSwish,
161
+ channel_dim=-1,
162
+ max_abs=10.0,
163
+ min_prob=0.25,
164
+ ),
165
+ layer_norm_cls=IdentityNorm,
166
+ ),
167
+ num_layers=num_layers,
168
+ norm=BalancedBasicNorm(d_model) if norm_first else None,
169
+ )
170
+
171
+ self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
172
+ self.stop_layer = nn.Linear(d_model, 1)
173
+ else:
174
+ self.encoder = nn.TransformerEncoder(
175
+ nn.TransformerEncoderLayer(
176
+ d_model,
177
+ nhead,
178
+ dim_feedforward=d_model * 4,
179
+ activation=F.relu,
180
+ dropout=0.1,
181
+ batch_first=True,
182
+ norm_first=norm_first,
183
+ ),
184
+ num_layers=num_layers,
185
+ norm=nn.LayerNorm(d_model) if norm_first else None,
186
+ )
187
+
188
+ self.decoder = nn.TransformerDecoder(
189
+ nn.TransformerDecoderLayer(
190
+ d_model,
191
+ nhead,
192
+ dim_feedforward=d_model * 4,
193
+ activation=F.relu,
194
+ dropout=0.1,
195
+ batch_first=True,
196
+ norm_first=norm_first,
197
+ ),
198
+ num_layers=num_layers,
199
+ norm=nn.LayerNorm(d_model) if norm_first else None,
200
+ )
201
+
202
+ self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203
+ self.stop_layer = nn.Linear(d_model, 1)
204
+
205
+ self.stop_accuracy_metric = BinaryAccuracy(
206
+ threshold=0.5, multidim_average="global"
207
+ )
208
+
209
+ # self.apply(self._init_weights)
210
+
211
+ # def _init_weights(self, module):
212
+ # if isinstance(module, (nn.Linear)):
213
+ # module.weight.data.normal_(mean=0.0, std=0.02)
214
+ # if isinstance(module, nn.Linear) and module.bias is not None:
215
+ # module.bias.data.zero_()
216
+ # elif isinstance(module, nn.LayerNorm):
217
+ # module.bias.data.zero_()
218
+ # module.weight.data.fill_(1.0)
219
+ # elif isinstance(module, nn.Embedding):
220
+ # module.weight.data.normal_(mean=0.0, std=0.02)
221
+
222
+ def forward(
223
+ self,
224
+ x: torch.Tensor,
225
+ x_lens: torch.Tensor,
226
+ y: torch.Tensor,
227
+ y_lens: torch.Tensor,
228
+ reduction: str = "sum",
229
+ train_stage: int = 0,
230
+ **kwargs,
231
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232
+ """
233
+ Args:
234
+ x:
235
+ A 2-D tensor of shape (N, S).
236
+ x_lens:
237
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238
+ before padding.
239
+ y:
240
+ A 3-D tensor of shape (N, T, 8).
241
+ y_lens:
242
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
243
+ before padding.
244
+ train_stage:
245
+ Not used in this model.
246
+ Returns:
247
+ Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
248
+ """
249
+ del train_stage
250
+
251
+ assert x.ndim == 2, x.shape
252
+ assert x_lens.ndim == 1, x_lens.shape
253
+ assert y.ndim == 3, y.shape
254
+ assert y_lens.ndim == 1, y_lens.shape
255
+
256
+ assert torch.all(x_lens > 0)
257
+
258
+ # NOTE: x has been padded in TextTokenCollater
259
+ x_mask = make_pad_mask(x_lens).to(x.device)
260
+
261
+ x = self.text_embedding(x)
262
+ x = self.encoder_prenet(x)
263
+ x = self.encoder_position(x)
264
+ x = self.encoder(x, src_key_padding_mask=x_mask)
265
+
266
+ total_loss, metrics = 0.0, {}
267
+
268
+ y_mask = make_pad_mask(y_lens).to(y.device)
269
+ y_mask_float = y_mask.type(torch.float32)
270
+ data_mask = 1.0 - y_mask_float.unsqueeze(-1)
271
+
272
+ # Training
273
+ # AR Decoder
274
+ def pad_y(y):
275
+ y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
276
+ # inputs, targets
277
+ return y[:, :-1], y[:, 1:]
278
+
279
+ y, targets = pad_y(y * data_mask) # mask padding as zeros
280
+
281
+ y_emb = self.decoder_prenet(y)
282
+ y_pos = self.decoder_position(y_emb)
283
+
284
+ y_len = y_lens.max()
285
+ tgt_mask = torch.triu(
286
+ torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
287
+ diagonal=1,
288
+ )
289
+ y_dec = self.decoder(
290
+ y_pos,
291
+ x,
292
+ tgt_mask=tgt_mask,
293
+ memory_key_padding_mask=x_mask,
294
+ )
295
+
296
+ predict = self.predict_layer(y_dec)
297
+ # loss
298
+ total_loss = F.mse_loss(predict, targets, reduction=reduction)
299
+
300
+ logits = self.stop_layer(y_dec).squeeze(-1)
301
+ stop_loss = F.binary_cross_entropy_with_logits(
302
+ logits,
303
+ y_mask_float.detach(),
304
+ weight=1.0 + y_mask_float.detach() * 4.0,
305
+ reduction=reduction,
306
+ )
307
+ metrics["stop_loss"] = stop_loss.detach()
308
+
309
+ stop_accuracy = self.stop_accuracy_metric(
310
+ (torch.sigmoid(logits) >= 0.5).type(torch.int64),
311
+ y_mask.type(torch.int64),
312
+ )
313
+ # icefall MetricsTracker.norm_items()
314
+ metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
315
+ torch.float32
316
+ )
317
+
318
+ return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
319
+
320
+ def inference(
321
+ self,
322
+ x: torch.Tensor,
323
+ x_lens: torch.Tensor,
324
+ y: Any = None,
325
+ **kwargs,
326
+ ) -> torch.Tensor:
327
+ """
328
+ Args:
329
+ x:
330
+ A 2-D tensor of shape (1, S).
331
+ x_lens:
332
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
333
+ before padding.
334
+ Returns:
335
+ Return the predicted audio code matrix and cross-entropy loss.
336
+ """
337
+ assert x.ndim == 2, x.shape
338
+ assert x_lens.ndim == 1, x_lens.shape
339
+
340
+ assert torch.all(x_lens > 0)
341
+
342
+ x_mask = make_pad_mask(x_lens).to(x.device)
343
+
344
+ x = self.text_embedding(x)
345
+ x = self.encoder_prenet(x)
346
+ x = self.encoder_position(x)
347
+ x = self.encoder(x, src_key_padding_mask=x_mask)
348
+
349
+ x_mask = make_pad_mask(x_lens).to(x.device)
350
+
351
+ # AR Decoder
352
+ # TODO: Managing decoder steps avoid repetitive computation
353
+ y = torch.zeros(
354
+ [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
355
+ )
356
+ while True:
357
+ y_emb = self.decoder_prenet(y)
358
+ y_pos = self.decoder_position(y_emb)
359
+
360
+ tgt_mask = torch.triu(
361
+ torch.ones(
362
+ y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
363
+ ),
364
+ diagonal=1,
365
+ )
366
+
367
+ y_dec = self.decoder(
368
+ y_pos,
369
+ x,
370
+ tgt_mask=tgt_mask,
371
+ memory_mask=None,
372
+ memory_key_padding_mask=x_mask,
373
+ )
374
+ predict = self.predict_layer(y_dec[:, -1:])
375
+
376
+ logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
377
+ if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
378
+ print(
379
+ f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
380
+ )
381
+ break
382
+
383
+ y = torch.concat([y, predict], dim=1)
384
+
385
+ return y[:, 1:]
386
+
387
+ def visualize(
388
+ self,
389
+ predicts: Tuple[torch.Tensor],
390
+ batch: Dict[str, Union[List, torch.Tensor]],
391
+ output_dir: str,
392
+ limit: int = 4,
393
+ ) -> None:
394
+ visualize(predicts, batch, output_dir, limit=limit)
models/vallex.py ADDED
@@ -0,0 +1,853 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from typing import Dict, Iterator, List, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ # from icefall.utils import make_pad_mask
23
+ # from torchmetrics.classification import MulticlassAccuracy
24
+
25
+ from data.input_strategies import PromptedFeatures
26
+ from modules.embedding import SinePositionalEmbedding, TokenEmbedding
27
+ from modules.transformer import (
28
+ AdaptiveLayerNorm,
29
+ LayerNorm,
30
+ TransformerDecoderLayer,
31
+ TransformerEncoder,
32
+ TransformerEncoderLayer,
33
+ )
34
+
35
+ from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
36
+ from .visualizer import visualize
37
+
38
+
39
+ class Transpose(nn.Identity):
40
+ """(N, T, D) -> (N, D, T)"""
41
+
42
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
43
+ return input.transpose(1, 2)
44
+
45
+
46
+ # NOTE: There are two ways to implement the model
47
+ # 1) [VALL-F] standard TransformerDecoder, use x as memory
48
+ # 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
49
+ # use x as the prefix of decoder inputs
50
+ class VALLF(nn.Module):
51
+ """It implements https://arxiv.org/abs/2301.02111
52
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ d_model: int,
58
+ nhead: int,
59
+ num_layers: int,
60
+ norm_first: bool = True,
61
+ add_prenet: bool = False,
62
+ decoder_cls: Union[
63
+ nn.TransformerDecoder, nn.TransformerEncoder
64
+ ] = nn.TransformerDecoder,
65
+ decoder_layer_cls: Union[
66
+ TransformerDecoderLayer, TransformerEncoderLayer
67
+ ] = TransformerDecoderLayer,
68
+ prefix_mode: int = 0,
69
+ share_embedding: bool = True,
70
+ nar_scale_factor: float = 1.0,
71
+ prepend_bos: bool = True,
72
+ num_quantizers: int = 8,
73
+ ):
74
+ """
75
+ Args:
76
+ d_model:
77
+ The number of expected features in the input (required).
78
+ nhead:
79
+ The number of heads in the multiheadattention models (required).
80
+ num_layers:
81
+ The number of sub-decoder-layers in the decoder (required).
82
+ """
83
+ super().__init__()
84
+ nar_d_model = int(d_model * nar_scale_factor)
85
+
86
+ self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
87
+ self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
88
+
89
+ # ID NUM_AUDIO_TOKENS -> PAD
90
+ # ID NUM_AUDIO_TOKENS + 1 -> BOS
91
+ self.ar_audio_prepend_bos = prepend_bos
92
+ self.ar_audio_embedding = TokenEmbedding(
93
+ d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
94
+ )
95
+
96
+ # PreNet
97
+ if add_prenet:
98
+ self.ar_text_prenet = nn.Sequential(
99
+ Transpose(),
100
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
101
+ nn.BatchNorm1d(d_model),
102
+ nn.ReLU(),
103
+ nn.Dropout(0.5),
104
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
105
+ nn.BatchNorm1d(d_model),
106
+ nn.ReLU(),
107
+ nn.Dropout(0.5),
108
+ nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
109
+ nn.BatchNorm1d(d_model),
110
+ nn.ReLU(),
111
+ nn.Dropout(0.5),
112
+ Transpose(),
113
+ nn.Linear(d_model, d_model),
114
+ )
115
+
116
+ self.ar_audio_prenet = nn.Sequential(
117
+ nn.Linear(d_model, 256),
118
+ nn.ReLU(),
119
+ nn.Dropout(0.25),
120
+ nn.Linear(256, 256),
121
+ nn.ReLU(),
122
+ nn.Dropout(0.25),
123
+ nn.Linear(256, d_model),
124
+ )
125
+ else:
126
+ self.ar_text_prenet = nn.Identity()
127
+ self.ar_audio_prenet = nn.Identity()
128
+
129
+ self.ar_text_position = SinePositionalEmbedding(
130
+ d_model,
131
+ dropout=0.1,
132
+ scale=False,
133
+ alpha=True,
134
+ )
135
+ self.ar_audio_position = SinePositionalEmbedding(
136
+ d_model,
137
+ dropout=0.1,
138
+ scale=False,
139
+ alpha=True,
140
+ )
141
+
142
+ self.ar_decoder = decoder_cls(
143
+ decoder_layer_cls(
144
+ d_model,
145
+ nhead,
146
+ dim_feedforward=d_model * 4,
147
+ dropout=0.1,
148
+ batch_first=True,
149
+ norm_first=norm_first,
150
+ ),
151
+ num_layers=num_layers,
152
+ norm=LayerNorm(d_model) if norm_first else None,
153
+ )
154
+ self.ar_predict_layer = nn.Linear(
155
+ d_model, NUM_AUDIO_TOKENS + 1, bias=False
156
+ )
157
+
158
+ self.rng = random.Random(0)
159
+ self.num_heads = nhead
160
+ self.prefix_mode = prefix_mode
161
+ self.num_quantizers = num_quantizers
162
+
163
+ assert num_quantizers >= 1
164
+ if num_quantizers > 1:
165
+ self.nar_audio_embeddings = nn.ModuleList(
166
+ [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
167
+ + [
168
+ TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
169
+ for i in range(num_quantizers - 1)
170
+ ]
171
+ ) # W_a
172
+
173
+ # PreNet
174
+ if add_prenet:
175
+ self.nar_text_prenet = nn.Sequential(
176
+ Transpose(),
177
+ nn.Conv1d(
178
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
179
+ ),
180
+ nn.BatchNorm1d(nar_d_model),
181
+ nn.ReLU(),
182
+ nn.Dropout(0.5),
183
+ nn.Conv1d(
184
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
185
+ ),
186
+ nn.BatchNorm1d(nar_d_model),
187
+ nn.ReLU(),
188
+ nn.Dropout(0.5),
189
+ nn.Conv1d(
190
+ nar_d_model, nar_d_model, kernel_size=5, padding="same"
191
+ ),
192
+ nn.BatchNorm1d(nar_d_model),
193
+ nn.ReLU(),
194
+ nn.Dropout(0.5),
195
+ Transpose(),
196
+ nn.Linear(nar_d_model, nar_d_model),
197
+ )
198
+ self.nar_audio_prenet = nn.Sequential(
199
+ nn.Linear(nar_d_model, 256),
200
+ nn.ReLU(),
201
+ nn.Dropout(0.25),
202
+ nn.Linear(256, 256),
203
+ nn.ReLU(),
204
+ nn.Dropout(0.25),
205
+ nn.Linear(256, nar_d_model),
206
+ )
207
+ else:
208
+ self.nar_text_prenet = nn.Identity()
209
+ self.nar_audio_prenet = nn.Identity()
210
+
211
+ self.nar_text_position = SinePositionalEmbedding(
212
+ nar_d_model,
213
+ dropout=0.0,
214
+ scale=False,
215
+ alpha=False,
216
+ )
217
+ self.nar_audio_position = SinePositionalEmbedding(
218
+ nar_d_model,
219
+ dropout=0.1,
220
+ scale=False,
221
+ alpha=False,
222
+ )
223
+
224
+ self.nar_decoder = decoder_cls(
225
+ decoder_layer_cls(
226
+ nar_d_model,
227
+ int(nhead * nar_scale_factor),
228
+ dim_feedforward=nar_d_model * 4,
229
+ dropout=0.1,
230
+ batch_first=True,
231
+ norm_first=norm_first,
232
+ adaptive_layer_norm=True,
233
+ ),
234
+ num_layers=int(num_layers * nar_scale_factor),
235
+ norm=AdaptiveLayerNorm(
236
+ nar_d_model, norm=nn.LayerNorm(nar_d_model)
237
+ )
238
+ if norm_first
239
+ else None,
240
+ )
241
+ self.nar_predict_layers = nn.ModuleList(
242
+ [
243
+ nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
244
+ for i in range(num_quantizers - 1)
245
+ ]
246
+ )
247
+ self.nar_stage_embeddings = nn.ModuleList(
248
+ [
249
+ TokenEmbedding(nar_d_model, 1)
250
+ for i in range(num_quantizers - 1)
251
+ ]
252
+ )
253
+
254
+ if share_embedding:
255
+ # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
256
+ # NOTE(Feiteng): In the experiment, this undermines accuracy
257
+ # self.ar_predict_layer.weight = self.ar_audio_embedding.weight
258
+
259
+ # We also share the parameters of the acoustic embedding layer and the output prediction layer,
260
+ # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
261
+ for j in range(0, num_quantizers - 2):
262
+ self.nar_predict_layers[
263
+ j
264
+ ].weight = self.nar_audio_embeddings[j + 2].weight
265
+
266
+ def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
267
+ assert stage > 0
268
+ if stage == 1:
269
+ for name, param in self.named_parameters():
270
+ if name.startswith("ar_"):
271
+ print(f" AR parameter: {name}")
272
+ yield param
273
+
274
+ if stage == 2:
275
+ for name, param in self.named_parameters():
276
+ if name.startswith("nar_"):
277
+ print(f"NAR parameter: {name}")
278
+ yield param
279
+
280
+ def stage_named_parameters(
281
+ self, stage: int = 1
282
+ ) -> Iterator[Tuple[str, nn.Parameter]]:
283
+ assert stage > 0
284
+ if stage == 1:
285
+ for pair in self.named_parameters():
286
+ if pair[0].startswith("ar_"):
287
+ yield pair
288
+
289
+ if stage == 2:
290
+ for pair in self.named_parameters():
291
+ if pair[0].startswith("nar_"):
292
+ yield pair
293
+
294
+ def pad_y_eos(self, y, y_mask_int, eos_id):
295
+ targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
296
+ y_mask_int, (0, 1), value=1
297
+ )
298
+ # inputs, targets
299
+ if self.ar_audio_prepend_bos:
300
+ return (
301
+ F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
302
+ targets,
303
+ )
304
+
305
+ return targets[:, :-1], targets[:, 1:]
306
+
307
+ def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
308
+ # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
309
+ # from the same utterance.
310
+ # We implement this differently.
311
+ if prefix_mode == 0:
312
+ # no prefix
313
+ prefix_len = 0
314
+ y_emb = self.nar_audio_embeddings[0](y)
315
+ for j in range(1, nar_stage):
316
+ # Formula (4) (5)
317
+ y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
318
+ elif prefix_mode == 1:
319
+ # prefix at begining
320
+ int_low = (0.25 * y_lens.min()).type(torch.int64).item()
321
+ prefix_len = torch.randint(0, int_low * 2, size=()).item()
322
+ prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
323
+
324
+ y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
325
+ y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
326
+ for j in range(1, self.num_quantizers):
327
+ y_prompts += self.nar_audio_embeddings[j](
328
+ codes[:, :prefix_len, j]
329
+ )
330
+ if j < nar_stage:
331
+ y_emb += self.nar_audio_embeddings[j](
332
+ codes[:, prefix_len:, j]
333
+ )
334
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
335
+ elif prefix_mode in [2, 4]:
336
+ if prefix_mode == 2:
337
+ # random prefix
338
+ prefix_len = min(225, int(0.25 * y_lens.min().item()))
339
+
340
+ y_prompts_codes = []
341
+ for b in range(codes.shape[0]):
342
+ start = self.rng.randint(0, y_lens[b].item() - prefix_len)
343
+ y_prompts_codes.append(
344
+ torch.clone(codes[b, start : start + prefix_len])
345
+ )
346
+ codes[
347
+ b, start : start + prefix_len, nar_stage
348
+ ] = NUM_AUDIO_TOKENS
349
+ y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
350
+ else:
351
+ prefix_len = y_prompts_codes.shape[1]
352
+
353
+ y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
354
+ y_emb = self.nar_audio_embeddings[0](y)
355
+ for j in range(1, self.num_quantizers):
356
+ y_prompts += self.nar_audio_embeddings[j](
357
+ y_prompts_codes[..., j]
358
+ )
359
+ if j < nar_stage:
360
+ y_emb += self.nar_audio_embeddings[j](codes[..., j])
361
+ y_emb = torch.concat([y_prompts, y_emb], axis=1)
362
+ else:
363
+ raise ValueError
364
+
365
+ return y_emb, prefix_len
366
+
367
+ def forward(
368
+ self,
369
+ x: torch.Tensor,
370
+ x_lens: torch.Tensor,
371
+ y: Union[torch.Tensor, PromptedFeatures],
372
+ y_lens: Union[torch.Tensor, PromptedFeatures],
373
+ reduction: str = "sum",
374
+ train_stage: int = 0,
375
+ **kwargs,
376
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
377
+ raise NotImplementedError
378
+
379
+ def inference(
380
+ self,
381
+ x: torch.Tensor,
382
+ x_lens: torch.Tensor,
383
+ y: torch.Tensor,
384
+ enroll_x_lens: Union[torch.Tensor, None] = None,
385
+ top_k: int = -100,
386
+ temperature: float = 1.0,
387
+ ) -> torch.Tensor:
388
+ raise NotImplementedError
389
+
390
+ def visualize(
391
+ self,
392
+ predicts: Tuple[torch.Tensor],
393
+ batch: Dict[str, Union[List, torch.Tensor]],
394
+ output_dir: str,
395
+ limit: int = 4,
396
+ ) -> None:
397
+ raise NotImplementedError
398
+
399
+
400
+ class VALLE(VALLF):
401
+ """It implements https://arxiv.org/abs/2301.02111
402
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
403
+ """
404
+
405
+ def __init__(
406
+ self,
407
+ d_model: int,
408
+ nhead: int,
409
+ num_layers: int,
410
+ norm_first: bool = True,
411
+ add_prenet: bool = False,
412
+ prefix_mode: int = 0,
413
+ share_embedding: bool = True,
414
+ nar_scale_factor: float = 1.0,
415
+ **kwargs,
416
+ ):
417
+ """
418
+ Args:
419
+ d_model:
420
+ The number of expected features in the input (required).
421
+ nhead:
422
+ The number of heads in the multiheadattention models (required).
423
+ num_layers:
424
+ The number of sub-decoder-layers in the decoder (required).
425
+ """
426
+ super(VALLE, self).__init__(
427
+ d_model,
428
+ nhead,
429
+ num_layers,
430
+ norm_first=norm_first,
431
+ add_prenet=add_prenet,
432
+ decoder_cls=TransformerEncoder,
433
+ decoder_layer_cls=TransformerEncoderLayer,
434
+ prefix_mode=prefix_mode,
435
+ share_embedding=share_embedding,
436
+ nar_scale_factor=nar_scale_factor,
437
+ **kwargs,
438
+ )
439
+ self.language_ID = {
440
+ 'en': 0,
441
+ 'zh': 1,
442
+ 'ja': 2,
443
+ }
444
+ self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
445
+ self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
446
+
447
+ def forward(
448
+ self,
449
+ x: torch.Tensor,
450
+ x_lens: torch.Tensor,
451
+ y: Union[torch.Tensor, PromptedFeatures],
452
+ y_lens: Union[torch.Tensor, PromptedFeatures],
453
+ reduction: str = "sum",
454
+ train_stage: int = 0,
455
+ **kwargs,
456
+ ):
457
+ raise NotImplementedError
458
+ def inference(
459
+ self,
460
+ x: torch.Tensor,
461
+ x_lens: torch.Tensor,
462
+ y: torch.Tensor,
463
+ enroll_x_lens: torch.Tensor,
464
+ top_k: int = -100,
465
+ temperature: float = 1.0,
466
+ prompt_language: str = None,
467
+ text_language: str = None,
468
+ best_of: int = 1,
469
+ length_penalty: float = 1.0,
470
+ return_worst: bool = False,
471
+ ) -> torch.Tensor:
472
+ """
473
+ Args:
474
+ x:
475
+ A 2-D tensor of shape (1, S).
476
+ x_lens:
477
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
478
+ before padding.
479
+ y:
480
+ A 3-D tensor of shape (1, T, 8).
481
+ top_k: (`optional`) int
482
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
483
+ temperature: (`optional`) float
484
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
485
+ Returns:
486
+ Return the predicted audio code matrix.
487
+ """
488
+ assert x.ndim == 2, x.shape
489
+ assert x_lens.ndim == 1, x_lens.shape
490
+ assert y.ndim == 3, y.shape
491
+ assert y.shape[0] == 1, y.shape
492
+
493
+ assert torch.all(x_lens > 0)
494
+
495
+ # NOTE: x has been padded in TextTokenCollater
496
+ text = x
497
+ x = self.ar_text_embedding(text)
498
+ # Add language embedding
499
+ prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
500
+ if isinstance(text_language, str):
501
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
502
+ elif isinstance(text_language, List):
503
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
504
+ x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
505
+ x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
506
+ x = self.ar_text_prenet(x)
507
+ x = self.ar_text_position(x)
508
+
509
+ text_len = x_lens.max()
510
+ prompts = y
511
+ prefix_len = y.shape[1]
512
+
513
+ # AR Decoder
514
+ # TODO: Managing decoder steps avoid repetitive computation
515
+ y = prompts[..., 0]
516
+ if self.ar_audio_prepend_bos:
517
+ y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
518
+
519
+ x_len = x_lens.max()
520
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
521
+
522
+ kv_cache = None
523
+ use_kv_caching = True
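+ # With KV caching enabled, after the first decoding step only the newest
+ # position of xy_pos is fed to the decoder (see `xy_pos = xy_pos[:, [-1]]`
+ # below); past keys/values are reused from `kv_cache`.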
524
+
525
+ sum_logprobs = torch.zeros(best_of, device=y.device) # implement batch decoding here
526
+ x = x.repeat(best_of, 1, 1)
527
+ y = y.repeat(best_of, 1)
528
+ while True:
529
+ y_emb = self.ar_audio_embedding(y)
530
+ y_emb = self.ar_audio_prenet(y_emb)
531
+ y_pos = self.ar_audio_position(y_emb)
532
+ xy_pos = torch.concat([x, y_pos], dim=1)
533
+
534
+ y_len = y.shape[1]
535
+ x_attn_mask_pad = F.pad(
536
+ x_attn_mask,
537
+ (0, y_len),
538
+ value=True,
539
+ )
540
+ y_attn_mask = F.pad(
541
+ torch.triu(
542
+ torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
543
+ ),
544
+ (x_len, 0),
545
+ value=False,
546
+ )
547
+ xy_attn_mask = torch.concat(
548
+ [x_attn_mask_pad, y_attn_mask], dim=0
549
+ ).to(y.device)
550
+
551
+
552
+ if use_kv_caching and kv_cache is not None:
553
+ xy_pos = xy_pos[:, [-1]]
554
+ else:
555
+ pass
556
+
557
+ xy_dec, kv_cache = self.ar_decoder.infer(
558
+ xy_pos,
559
+ mask=xy_attn_mask,
560
+ past_kv=kv_cache,
561
+ use_cache=use_kv_caching,
562
+ )
563
+ # xy_dec, _ = self.ar_decoder(
564
+ # (xy_pos, None),
565
+ # mask=xy_attn_mask,
566
+ # )
567
+
568
+ logits = self.ar_predict_layer(xy_dec[:, -1])
569
+ samples, current_logprobs = topk_sampling(
570
+ logits, top_k=top_k, top_p=1, temperature=temperature
571
+ )
572
+ sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
573
+ samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
574
+ completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
575
+ if (
576
+ completed
577
+ or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
578
+ ):
579
+ if prompts.shape[1] == y.shape[1]:
580
+ raise SyntaxError(
581
+ "well trained model shouldn't reach here."
582
+ )
583
+ lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
584
+ avg_logprobs = sum_logprobs / lengths ** length_penalty
585
+ # choose the best beam according to sum_logprobs
586
+ best_beam = y[torch.argmax(avg_logprobs), :]
587
+ worst_beam = y[torch.argmin(avg_logprobs), :]
588
+ # strip all eos tokens
589
+ best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
590
+ worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
591
+ if return_worst:
592
+ y = worst_beam.unsqueeze(0)
593
+ else:
594
+ y = best_beam.unsqueeze(0)
595
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
596
+ break
597
+
598
+ y = torch.concat([y, samples], dim=1)
599
+
600
+ codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
601
+ if self.num_quantizers == 1:
602
+ return torch.stack(codes, dim=-1)
603
+
604
+ # Non-AR Decoders
605
+ y_emb = self.nar_audio_embeddings[0](
606
+ y[:, int(self.ar_audio_prepend_bos) :]
607
+ )
608
+
609
+ if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
610
+ enrolled_len = enroll_x_lens.max().item()
611
+ # SOS + Synthesis Text + EOS
612
+ text = torch.concat(
613
+ [
614
+ text[:, :1],
615
+ text[:, enrolled_len - 1 :],
616
+ ],
617
+ dim=1,
618
+ )
619
+ text_len = text_len - (enrolled_len - 2)
620
+ assert text.shape[0] == 1
621
+
622
+ x = self.nar_text_embedding(text)
623
+ # Add language embedding
624
+ prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
625
+ if isinstance(text_language, str):
626
+ text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
627
+ elif isinstance(text_language, List):
628
+ text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
629
+ x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
630
+ x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
631
+ x = self.nar_text_prenet(x)
632
+ x = self.nar_text_position(x)
633
+
634
+ if self.prefix_mode == 0:
635
+ for i, (predict_layer, embedding_layer) in enumerate(
636
+ zip(
637
+ self.nar_predict_layers,
638
+ self.nar_audio_embeddings[1:],
639
+ )
640
+ ):
641
+ y_pos = self.nar_audio_prenet(y_emb)
642
+ y_pos = self.nar_audio_position(y_pos)
643
+ xy_pos = torch.concat([x, y_pos], dim=1)
644
+
645
+ xy_dec, _ = self.nar_decoder(
646
+ (xy_pos, self.nar_stage_embeddings[i].weight)
647
+ )
648
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
649
+
650
+ samples = torch.argmax(logits, dim=-1)
651
+ codes.append(samples)
652
+
653
+ if i < self.num_quantizers - 2:
654
+ y_emb[:, :prefix_len] += embedding_layer(
655
+ prompts[..., i + 1]
656
+ )
657
+ y_emb[:, prefix_len:] += embedding_layer(samples)
658
+ else:
659
+ for j in range(1, self.num_quantizers):
660
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
661
+ prompts[..., j]
662
+ )
663
+
664
+ for i, (predict_layer, embedding_layer) in enumerate(
665
+ zip(
666
+ self.nar_predict_layers,
667
+ self.nar_audio_embeddings[1:],
668
+ )
669
+ ):
670
+ y_pos = self.nar_audio_prenet(y_emb)
671
+ y_pos = self.nar_audio_position(y_pos)
672
+ xy_pos = torch.concat([x, y_pos], dim=1)
673
+
674
+ xy_dec, _ = self.nar_decoder(
675
+ (xy_pos, self.nar_stage_embeddings[i].weight)
676
+ )
677
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
678
+
679
+ samples = torch.argmax(logits, dim=-1)
680
+ codes.append(samples)
681
+
682
+ if i < self.num_quantizers - 2:
683
+ y_emb[:, prefix_len:] += embedding_layer(samples)
684
+
685
+ assert len(codes) == self.num_quantizers
686
+ return torch.stack(codes, dim=-1)
687
+
688
+ def continual(
689
+ self,
690
+ x: torch.Tensor,
691
+ x_lens: torch.Tensor,
692
+ y: torch.Tensor,
693
+ ) -> torch.Tensor:
694
+ """
695
+ Args:
696
+ x:
697
+ A 2-D tensor of shape (1, S).
698
+ x_lens:
699
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
700
+ before padding.
701
+ y:
702
+ A 3-D tensor of shape (1, T, 8).
703
+ Returns:
704
+ Return the predicted audio code matrix.
705
+ """
706
+ assert x.ndim == 2, x.shape
707
+ assert x_lens.ndim == 1, x_lens.shape
708
+ assert y.ndim == 3, y.shape
709
+ assert y.shape[0] == 1, y.shape
710
+
711
+ assert torch.all(x_lens > 0)
712
+ assert self.num_quantizers == 8
713
+
714
+ # NOTE: x has been padded in TextTokenCollater
715
+ text = x
716
+ x = self.ar_text_embedding(text)
717
+ x = self.ar_text_prenet(x)
718
+ x = self.ar_text_position(x)
719
+
720
+ text_len = x_lens.max()
721
+
722
+ prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
723
+
724
+ # AR Decoder
725
+ prompts = y[:, :prefix_len]
726
+
727
+ codes = [y[:, prefix_len:, 0]]
728
+ # Non-AR Decoders
729
+ x = self.nar_text_embedding(text)
730
+ x = self.nar_text_prenet(x)
731
+ x = self.nar_text_position(x)
732
+
733
+ y_emb = self.nar_audio_embeddings[0](y[..., 0])
734
+
735
+ if self.prefix_mode == 0:
736
+ for i, (predict_layer, embedding_layer) in enumerate(
737
+ zip(
738
+ self.nar_predict_layers,
739
+ self.nar_audio_embeddings[1:],
740
+ )
741
+ ):
742
+ y_pos = self.nar_audio_position(y_emb)
743
+ y_pos = self.nar_audio_prenet(y_pos)
744
+ xy_pos = torch.concat([x, y_pos], dim=1)
745
+
746
+ xy_dec, _ = self.nar_decoder(
747
+ (xy_pos, self.nar_stage_embeddings[i].weight)
748
+ )
749
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
750
+
751
+ samples = torch.argmax(logits, dim=-1)
752
+ codes.append(samples)
753
+
754
+ if i < 6:
755
+ y_emb[:, :prefix_len] += embedding_layer(
756
+ prompts[..., i + 1]
757
+ )
758
+ y_emb[:, prefix_len:] += embedding_layer(samples)
759
+ else:
760
+ for j in range(1, 8):
761
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
762
+ prompts[..., j]
763
+ )
764
+
765
+ for i, (predict_layer, embedding_layer) in enumerate(
766
+ zip(
767
+ self.nar_predict_layers,
768
+ self.nar_audio_embeddings[1:],
769
+ )
770
+ ):
771
+ y_pos = self.nar_audio_prenet(y_emb)
772
+ y_pos = self.nar_audio_position(y_pos)
773
+ xy_pos = torch.concat([x, y_pos], dim=1)
774
+
775
+ xy_dec, _ = self.nar_decoder(
776
+ (xy_pos, self.nar_stage_embeddings[i].weight)
777
+ )
778
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
779
+
780
+ samples = torch.argmax(logits, dim=-1)
781
+ codes.append(samples)
782
+
783
+ if i < 6:
784
+ y_emb[:, prefix_len:] += embedding_layer(samples)
785
+
786
+ assert len(codes) == 8
787
+ return torch.stack(codes, dim=-1)
788
+
789
+
790
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
791
+ def top_k_top_p_filtering(
792
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
793
+ ):
794
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
795
+ Args:
796
+ logits: logits distribution shape (batch size, vocabulary size)
797
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
798
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
799
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
800
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
801
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
802
+ """
803
+ if top_k > 0:
804
+ top_k = min(
805
+ max(top_k, min_tokens_to_keep), logits.size(-1)
806
+ ) # Safety check
807
+ # Remove all tokens with a probability less than the last token of the top-k
808
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
809
+ logits[indices_to_remove] = filter_value
810
+
811
+ if top_p < 1.0:
812
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
813
+ cumulative_probs = torch.cumsum(
814
+ F.softmax(sorted_logits, dim=-1), dim=-1
815
+ )
816
+
817
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
818
+ sorted_indices_to_remove = cumulative_probs > top_p
819
+ if min_tokens_to_keep > 1:
820
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
821
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
822
+ # Shift the indices to the right to keep also the first token above the threshold
823
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
824
+ ..., :-1
825
+ ].clone()
826
+ sorted_indices_to_remove[..., 0] = 0
827
+
828
+ # scatter sorted tensors to original indexing
829
+ indices_to_remove = sorted_indices_to_remove.scatter(
830
+ 1, sorted_indices, sorted_indices_to_remove
831
+ )
832
+ logits[indices_to_remove] = filter_value
833
+ return logits
834
+
835
+
836
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
837
+ # temperature: (`optional`) float
838
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
839
+ # top_k: (`optional`) int
840
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
841
+ # top_p: (`optional`) float
842
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
843
+
844
+ # Temperature (higher temperature => more likely to sample low probability tokens)
845
+ if temperature != 1.0:
846
+ logits = logits / temperature
847
+ # Top-p/top-k filtering
848
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
849
+ # Sample
850
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
851
+ logprobs = F.log_softmax(logits.float(), dim=-1)
852
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
853
+ return token, current_logprobs
models/visualizer.py ADDED
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ from typing import Dict, List, Tuple, Union
20
+
21
+ import matplotlib.pyplot as plt
22
+ import numpy as np
23
+ import torch
24
+
25
+
26
+ def visualize(
27
+ predicts: Tuple[torch.Tensor],
28
+ batch: Dict[str, Union[List, torch.Tensor]],
29
+ output_dir: str,
30
+ limit: int = 4,
31
+ ) -> None:
32
+ text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
33
+ text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
34
+ audio_features = batch["audio_features"].to("cpu").detach().numpy()
35
+ audio_features_lens = (
36
+ batch["audio_features_lens"].to("cpu").detach().numpy()
37
+ )
38
+ assert text_tokens.ndim == 2
39
+
40
+ utt_ids, texts = batch["utt_id"], batch["text"]
41
+
42
+ encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
43
+ decoder_outputs = predicts[1]
44
+ if isinstance(decoder_outputs, list):
45
+ decoder_outputs = decoder_outputs[-1]
46
+ decoder_outputs = (
47
+ decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
48
+ )
49
+
50
+ vmin, vmax = 0, 1024 # Encodec
51
+ if decoder_outputs.dtype == np.float32:
52
+ vmin, vmax = -6, 0 # Fbank
53
+
54
+ num_figures = 3
55
+ for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
56
+ _ = plt.figure(figsize=(14, 8 * num_figures))
57
+
58
+ S = text_tokens_lens[b]
59
+ T = audio_features_lens[b]
60
+
61
+ # encoder
62
+ plt.subplot(num_figures, 1, 1)
63
+ plt.title(f"Text: {text}")
64
+ plt.imshow(
65
+ X=np.transpose(encoder_outputs[b]),
66
+ cmap=plt.get_cmap("jet"),
67
+ aspect="auto",
68
+ interpolation="nearest",
69
+ )
70
+ plt.gca().invert_yaxis()
71
+ plt.axvline(x=S - 0.4, linewidth=2, color="r")
72
+ plt.xlabel("Encoder Output")
73
+ plt.colorbar()
74
+
75
+ # decoder
76
+ plt.subplot(num_figures, 1, 2)
77
+ plt.imshow(
78
+ X=np.transpose(decoder_outputs[b]),
79
+ cmap=plt.get_cmap("jet"),
80
+ aspect="auto",
81
+ interpolation="nearest",
82
+ vmin=vmin,
83
+ vmax=vmax,
84
+ )
85
+ plt.gca().invert_yaxis()
86
+ plt.axvline(x=T - 0.4, linewidth=2, color="r")
87
+ plt.xlabel("Decoder Output")
88
+ plt.colorbar()
89
+
90
+ # target
91
+ plt.subplot(num_figures, 1, 3)
92
+ plt.imshow(
93
+ X=np.transpose(audio_features[b]),
94
+ cmap=plt.get_cmap("jet"),
95
+ aspect="auto",
96
+ interpolation="nearest",
97
+ vmin=vmin,
98
+ vmax=vmax,
99
+ )
100
+ plt.gca().invert_yaxis()
101
+ plt.axvline(x=T - 0.4, linewidth=2, color="r")
102
+ plt.xlabel("Decoder Target")
103
+ plt.colorbar()
104
+
105
+ plt.savefig(f"{output_dir}/{utt_id}.png")
106
+ plt.close()
modules/__init__.py ADDED
File without changes
modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (175 Bytes). View file
 
modules/__pycache__/activation.cpython-311.pyc ADDED
Binary file (27.5 kB). View file
 
modules/__pycache__/embedding.cpython-311.pyc ADDED
Binary file (6.15 kB). View file
 
modules/__pycache__/scaling.cpython-311.pyc ADDED
Binary file (69 kB). View file
 
modules/__pycache__/transformer.cpython-311.pyc ADDED
Binary file (28.2 kB). View file
 
modules/activation.py ADDED
@@ -0,0 +1,612 @@
1
+ from typing import Optional, Tuple, List
2
+ import math
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+
12
+ def _in_projection_packed(
13
+ q: Tensor,
14
+ k: Tensor,
15
+ v: Tensor,
16
+ w: Tensor,
17
+ b: Optional[Tensor] = None,
18
+ ) -> List[Tensor]:
19
+ r"""
20
+ Performs the in-projection step of the attention operation, using packed weights.
21
+ Output is a triple containing projection tensors for query, key and value.
22
+
23
+ Args:
24
+ q, k, v: query, key and value tensors to be projected. For self-attention,
25
+ these are typically the same tensor; for encoder-decoder attention,
26
+ k and v are typically the same tensor. (We take advantage of these
27
+ identities for performance if they are present.) Regardless, q, k and v
28
+ must share a common embedding dimension; otherwise their shapes may vary.
29
+ w: projection weights for q, k and v, packed into a single tensor. Weights
30
+ are packed along dimension 0, in q, k, v order.
31
+ b: optional projection biases for q, k and v, packed into a single tensor
32
+ in q, k, v order.
33
+
34
+ Shape:
35
+ Inputs:
36
+ - q: :math:`(..., E)` where E is the embedding dimension
37
+ - k: :math:`(..., E)` where E is the embedding dimension
38
+ - v: :math:`(..., E)` where E is the embedding dimension
39
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
40
+ - b: :math:`E * 3` where E is the embedding dimension
41
+
42
+ Output:
43
+ - in output list :math:`[q', k', v']`, each output tensor will have the
44
+ same shape as the corresponding input tensor.
45
+ """
46
+ E = q.size(-1)
47
+ if k is v:
48
+ if q is k:
49
+ # self-attention
50
+ return F.linear(q, w, b).chunk(3, dim=-1)
51
+ else:
52
+ # encoder-decoder attention
53
+ w_q, w_kv = w.split([E, E * 2])
54
+ if b is None:
55
+ b_q = b_kv = None
56
+ else:
57
+ b_q, b_kv = b.split([E, E * 2])
58
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
59
+ else:
60
+ w_q, w_k, w_v = w.chunk(3)
61
+ if b is None:
62
+ b_q = b_k = b_v = None
63
+ else:
64
+ b_q, b_k, b_v = b.chunk(3)
65
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
66
+
67
+ def _scaled_dot_product_attention(
68
+ q: Tensor,
69
+ k: Tensor,
70
+ v: Tensor,
71
+ attn_mask: Optional[Tensor] = None,
72
+ dropout_p: float = 0.0,
73
+ ) -> Tuple[Tensor, Tensor]:
74
+ r"""
75
+ Computes scaled dot product attention on query, key and value tensors, using
76
+ an optional attention mask if passed, and applying dropout if a probability
77
+ greater than 0.0 is specified.
78
+ Returns a tensor pair containing attended values and attention weights.
79
+
80
+ Args:
81
+ q, k, v: query, key and value tensors. See Shape section for shape details.
82
+ attn_mask: optional tensor containing mask values to be added to calculated
83
+ attention. May be 2D or 3D; see Shape section for details.
84
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
85
+
86
+ Shape:
87
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
88
+ and E is embedding dimension.
89
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
90
+ and E is embedding dimension.
91
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
92
+ and E is embedding dimension.
93
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
94
+ shape :math:`(Nt, Ns)`.
95
+
96
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
97
+ have shape :math:`(B, Nt, Ns)`
98
+ """
99
+ B, Nt, E = q.shape
100
+ q = q / math.sqrt(E)
101
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
102
+ if attn_mask is not None:
103
+ attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
104
+ else:
105
+ attn = torch.bmm(q, k.transpose(-2, -1))
106
+
107
+ attn = F.softmax(attn, dim=-1)
108
+ if dropout_p > 0.0:
109
+ attn = F.dropout(attn, p=dropout_p)
110
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
111
+ output = torch.bmm(attn, v)
112
+ return output, attn
113
+
114
+ def multi_head_attention_forward(
115
+ x,
116
+ ipw,
117
+ ipb,
118
+ opw,
119
+ opb,
120
+ n_head,
121
+ attn_mask,
122
+ past_kv=None,
123
+ use_cache=False,
124
+ ):
125
+ # x = x.transpose(1, 0)
126
+ # tgt_len, bsz, embed_dim = x.shape
127
+ # head_dim = embed_dim // n_head
128
+ # q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
129
+ # q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
130
+ # k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
131
+ # v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)
132
+
133
+ # new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
134
+ # new_attn_mask.masked_fill_(attn_mask, float("-inf"))
135
+ # attn_mask = new_attn_mask
136
+ #
137
+ # attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
138
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
139
+ # attn_output = torch._C._nn.linear(attn_output, opw, opb)
140
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
141
+
142
+ B, T, C = x.size()
143
+
144
+ q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
145
+ k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
146
+ q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
147
+ v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
148
+ if past_kv is not None:
149
+ past_key = past_kv[0]
150
+ past_value = past_kv[1]
151
+ k = torch.cat((past_key, k), dim=-2)
152
+ v = torch.cat((past_value, v), dim=-2)
153
+
154
+ FULL_T = k.shape[-2]
155
+
156
+ if use_cache is True:
157
+ present = (k, v)
158
+ else:
159
+ present = None
160
+
161
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
162
+ att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
163
+ att = F.softmax(att, dim=-1)
164
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
165
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
166
+ y = torch._C._nn.linear(y, opw, opb)
167
+ return (y, present)
168
+
169
+
170
+ class MultiheadAttention(Module):
171
+ r"""Allows the model to jointly attend to information
172
+ from different representation subspaces as described in the paper:
173
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
174
+
175
+ Multi-Head Attention is defined as:
176
+
177
+ .. math::
178
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
179
+
180
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
181
+
182
+ ``forward()`` will use a special optimized implementation if all of the following
183
+ conditions are met:
184
+
185
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
186
+ restriction will be loosened in the future.)
187
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
188
+ - training is disabled (using ``.eval()``)
189
+ - dropout is 0
190
+ - ``add_bias_kv`` is ``False``
191
+ - ``add_zero_attn`` is ``False``
192
+ - ``batch_first`` is ``True`` and the input is batched
193
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
194
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
195
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
196
+ nor ``attn_mask`` is passed
197
+
198
+ If the optimized implementation is in use, a
199
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
200
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
201
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
202
+ will be returned, and an additional speedup proportional to the fraction of the input
203
+ that is padding can be expected.
204
+
205
+ Args:
206
+ embed_dim: Total dimension of the model.
207
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
208
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
209
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
210
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
211
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
212
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
213
+ Default: ``False``.
214
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
215
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
216
+ batch_first: If ``True``, then the input and output tensors are provided
217
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
218
+
219
+ Examples::
220
+
221
+ >>> # xdoctest: +SKIP
222
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
223
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
224
+
225
+ """
226
+ __constants__ = ["batch_first"]
227
+ bias_k: Optional[torch.Tensor]
228
+ bias_v: Optional[torch.Tensor]
229
+
230
+ def __init__(
231
+ self,
232
+ embed_dim,
233
+ num_heads,
234
+ dropout=0.0,
235
+ bias=True,
236
+ add_bias_kv=False,
237
+ add_zero_attn=False,
238
+ kdim=None,
239
+ vdim=None,
240
+ batch_first=False,
241
+ linear1_cls=Linear,
242
+ linear2_cls=Linear,
243
+ device=None,
244
+ dtype=None,
245
+ ) -> None:
246
+ factory_kwargs = {"device": device, "dtype": dtype}
247
+ super(MultiheadAttention, self).__init__()
248
+ self.embed_dim = embed_dim
249
+ self.kdim = kdim if kdim is not None else embed_dim
250
+ self.vdim = vdim if vdim is not None else embed_dim
251
+ self._qkv_same_embed_dim = (
252
+ self.kdim == embed_dim and self.vdim == embed_dim
253
+ )
254
+
255
+ self.num_heads = num_heads
256
+ self.dropout = dropout
257
+ self.batch_first = batch_first
258
+ self.head_dim = embed_dim // num_heads
259
+ assert (
260
+ self.head_dim * num_heads == self.embed_dim
261
+ ), "embed_dim must be divisible by num_heads"
262
+
263
+ if add_bias_kv:
264
+ self.bias_k = Parameter(
265
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
266
+ )
267
+ self.bias_v = Parameter(
268
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
269
+ )
270
+ else:
271
+ self.bias_k = self.bias_v = None
272
+
273
+ if linear1_cls == Linear:
274
+ if not self._qkv_same_embed_dim:
275
+ self.q_proj_weight = Parameter(
276
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
277
+ )
278
+ self.k_proj_weight = Parameter(
279
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
280
+ )
281
+ self.v_proj_weight = Parameter(
282
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
283
+ )
284
+ self.register_parameter("in_proj_weight", None)
285
+ else:
286
+ self.in_proj_weight = Parameter(
287
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
288
+ )
289
+ self.register_parameter("q_proj_weight", None)
290
+ self.register_parameter("k_proj_weight", None)
291
+ self.register_parameter("v_proj_weight", None)
292
+
293
+ if bias:
294
+ self.in_proj_bias = Parameter(
295
+ torch.empty(3 * embed_dim, **factory_kwargs)
296
+ )
297
+ else:
298
+ self.register_parameter("in_proj_bias", None)
299
+ self.out_proj = NonDynamicallyQuantizableLinear(
300
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
301
+ )
302
+
303
+ self._reset_parameters()
304
+ else:
305
+ if not self._qkv_same_embed_dim:
306
+ raise NotImplementedError
307
+ else:
308
+ self.in_proj_linear = linear1_cls(
309
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
310
+ )
311
+ self.in_proj_weight = self.in_proj_linear.weight
312
+
313
+ self.register_parameter("q_proj_weight", None)
314
+ self.register_parameter("k_proj_weight", None)
315
+ self.register_parameter("v_proj_weight", None)
316
+
317
+ if bias:
318
+ self.in_proj_bias = self.in_proj_linear.bias
319
+ else:
320
+ self.register_parameter("in_proj_bias", None)
321
+
322
+ self.out_proj = linear2_cls(
323
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
324
+ )
325
+
326
+ if self.bias_k is not None:
327
+ xavier_normal_(self.bias_k)
328
+ if self.bias_v is not None:
329
+ xavier_normal_(self.bias_v)
330
+
331
+ self.add_zero_attn = add_zero_attn
332
+
333
+ def _reset_parameters(self):
334
+ if self._qkv_same_embed_dim:
335
+ xavier_uniform_(self.in_proj_weight)
336
+ else:
337
+ xavier_uniform_(self.q_proj_weight)
338
+ xavier_uniform_(self.k_proj_weight)
339
+ xavier_uniform_(self.v_proj_weight)
340
+
341
+ if self.in_proj_bias is not None:
342
+ constant_(self.in_proj_bias, 0.0)
343
+ constant_(self.out_proj.bias, 0.0)
344
+
345
+ if self.bias_k is not None:
346
+ xavier_normal_(self.bias_k)
347
+ if self.bias_v is not None:
348
+ xavier_normal_(self.bias_v)
349
+
350
+ def __setstate__(self, state):
351
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
352
+ if "_qkv_same_embed_dim" not in state:
353
+ state["_qkv_same_embed_dim"] = True
354
+
355
+ super(MultiheadAttention, self).__setstate__(state)
356
+
357
+ def forward(
358
+ self,
359
+ query: Tensor,
360
+ key: Tensor,
361
+ value: Tensor,
362
+ key_padding_mask: Optional[Tensor] = None,
363
+ need_weights: bool = True,
364
+ attn_mask: Optional[Tensor] = None,
365
+ average_attn_weights: bool = True,
366
+ ) -> Tuple[Tensor, Optional[Tensor]]:
367
+ r"""
368
+ Args:
369
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
370
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
371
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
372
+ Queries are compared against key-value pairs to produce the output.
373
+ See "Attention Is All You Need" for more details.
374
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
375
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
376
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
377
+ See "Attention Is All You Need" for more details.
378
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
379
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
380
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
381
+ See "Attention Is All You Need" for more details.
382
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
383
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
384
+ Binary and byte masks are supported.
385
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
386
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
387
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
388
+ Default: ``True``.
389
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
390
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
391
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
392
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
393
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
394
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
395
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
396
+ the attention weight.
397
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
398
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
399
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
400
+
401
+ Outputs:
402
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
403
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
404
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
405
+ embedding dimension ``embed_dim``.
406
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
407
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
408
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
409
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
410
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
411
+
412
+ .. note::
413
+ `batch_first` argument is ignored for unbatched inputs.
414
+ """
415
+ is_batched = query.dim() == 3
416
+ if key_padding_mask is not None:
417
+ _kpm_dtype = key_padding_mask.dtype
418
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
419
+ key_padding_mask
420
+ ):
421
+ raise AssertionError(
422
+ "only bool and floating types of key_padding_mask are supported"
423
+ )
424
+ why_not_fast_path = ""
425
+ if not is_batched:
426
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
427
+ elif query is not key or key is not value:
428
+ # When lifting this restriction, don't forget to either
429
+ # enforce that the dtypes all match or test cases where
430
+ # they don't!
431
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
432
+ elif (
433
+ self.in_proj_bias is not None
434
+ and query.dtype != self.in_proj_bias.dtype
435
+ ):
436
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
437
+ elif (
438
+ self.in_proj_weight is not None
439
+ and query.dtype != self.in_proj_weight.dtype
440
+ ):
441
+ # this case will fail anyway, but at least they'll get a useful error message.
442
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
443
+ elif self.training:
444
+ why_not_fast_path = "training is enabled"
445
+ elif not self.batch_first:
446
+ why_not_fast_path = "batch_first was not True"
447
+ elif self.bias_k is not None:
448
+ why_not_fast_path = "self.bias_k was not None"
449
+ elif self.bias_v is not None:
450
+ why_not_fast_path = "self.bias_v was not None"
451
+ elif self.dropout:
452
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
453
+ elif self.add_zero_attn:
454
+ why_not_fast_path = "add_zero_attn was enabled"
455
+ elif not self._qkv_same_embed_dim:
456
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
457
+ elif attn_mask is not None:
458
+ why_not_fast_path = "attn_mask was not None"
459
+ elif query.is_nested and key_padding_mask is not None:
460
+ why_not_fast_path = (
461
+ "key_padding_mask is not supported with NestedTensor input"
462
+ )
463
+ elif self.num_heads % 2 == 1:
464
+ why_not_fast_path = "num_heads is odd"
465
+ elif torch.is_autocast_enabled():
466
+ why_not_fast_path = "autocast is enabled"
467
+
468
+ if not why_not_fast_path:
469
+ tensor_args = (
470
+ query,
471
+ key,
472
+ value,
473
+ self.in_proj_weight,
474
+ self.in_proj_bias,
475
+ self.out_proj.weight,
476
+ self.out_proj.bias,
477
+ )
478
+ # We have to use list comprehensions below because TorchScript does not support
479
+ # generator expressions.
480
+ if torch.overrides.has_torch_function(tensor_args):
481
+ why_not_fast_path = "some Tensor argument has_torch_function"
482
+ elif not all(
483
+ [
484
+ (x is None or x.is_cuda or "cpu" in str(x.device))
485
+ for x in tensor_args
486
+ ]
487
+ ):
488
+ why_not_fast_path = (
489
+ "some Tensor argument is neither CUDA nor CPU"
490
+ )
491
+ elif torch.is_grad_enabled() and any(
492
+ [x is not None and x.requires_grad for x in tensor_args]
493
+ ):
494
+ why_not_fast_path = (
495
+ "grad is enabled and at least one of query or the "
496
+ "input/output projection weights or biases requires_grad"
497
+ )
498
+ if not why_not_fast_path:
499
+ return torch._native_multi_head_attention(
500
+ query,
501
+ key,
502
+ value,
503
+ self.embed_dim,
504
+ self.num_heads,
505
+ self.in_proj_weight,
506
+ self.in_proj_bias,
507
+ self.out_proj.weight,
508
+ self.out_proj.bias,
509
+ key_padding_mask
510
+ if key_padding_mask is not None
511
+ else attn_mask,
512
+ need_weights,
513
+ average_attn_weights,
514
+ 1
515
+ if key_padding_mask is not None
516
+ else 0
517
+ if attn_mask is not None
518
+ else None,
519
+ )
520
+
521
+ any_nested = query.is_nested or key.is_nested or value.is_nested
522
+ assert not any_nested, (
523
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
524
+ + f"The fast path was not hit because {why_not_fast_path}"
525
+ )
526
+
527
+ if self.batch_first and is_batched:
528
+ # make sure that the transpose op does not affect the "is" property
529
+ if key is value:
530
+ if query is key:
531
+ query = key = value = query.transpose(1, 0)
532
+ else:
533
+ query, key = [x.transpose(1, 0) for x in (query, key)]
534
+ value = key
535
+ else:
536
+ query, key, value = [
537
+ x.transpose(1, 0) for x in (query, key, value)
538
+ ]
539
+
540
+ if not self._qkv_same_embed_dim:
541
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
542
+ query,
543
+ key,
544
+ value,
545
+ self.embed_dim,
546
+ self.num_heads,
547
+ self.in_proj_weight,
548
+ self.in_proj_bias,
549
+ self.bias_k,
550
+ self.bias_v,
551
+ self.add_zero_attn,
552
+ self.dropout,
553
+ self.out_proj.weight,
554
+ self.out_proj.bias,
555
+ training=self.training,
556
+ key_padding_mask=key_padding_mask,
557
+ need_weights=need_weights,
558
+ attn_mask=attn_mask,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj_weight,
561
+ k_proj_weight=self.k_proj_weight,
562
+ v_proj_weight=self.v_proj_weight,
563
+ average_attn_weights=average_attn_weights,
564
+ )
565
+ else:
566
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
567
+ query,
568
+ key,
569
+ value,
570
+ self.embed_dim,
571
+ self.num_heads,
572
+ self.in_proj_weight,
573
+ self.in_proj_bias,
574
+ self.bias_k,
575
+ self.bias_v,
576
+ self.add_zero_attn,
577
+ self.dropout,
578
+ self.out_proj.weight,
579
+ self.out_proj.bias,
580
+ training=self.training,
581
+ key_padding_mask=key_padding_mask,
582
+ need_weights=need_weights,
583
+ attn_mask=attn_mask,
584
+ average_attn_weights=average_attn_weights,
585
+ )
586
+ if self.batch_first and is_batched:
587
+ return attn_output.transpose(1, 0), attn_output_weights
588
+ else:
589
+ return attn_output, attn_output_weights
590
+
591
+ def infer(self,
592
+ x: Tensor,
593
+ key_padding_mask: Optional[Tensor] = None,
594
+ need_weights: bool = True,
595
+ attn_mask: Optional[Tensor] = None,
596
+ average_attn_weights: bool = True,
597
+ past_kv = None,
598
+ use_cache = False
599
+ ):
600
+ # x = x.transpose(1, 0)
601
+ y, kv = multi_head_attention_forward(
602
+ x=x,
603
+ ipw=self.in_proj_weight,
604
+ ipb=self.in_proj_bias,
605
+ opw=self.out_proj.weight,
606
+ opb=self.out_proj.bias,
607
+ n_head=self.num_heads,
608
+ attn_mask=attn_mask,
609
+ past_kv=past_kv,
610
+ use_cache=use_cache,
611
+ )
612
+ return (y, kv)
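The infer() method added at the bottom of MultiheadAttention skips the standard forward() path and calls the packed-projection multi_head_attention_forward() above with an optional key/value cache (past_kv / use_cache), which is what incremental, token-by-token decoding relies on. A minimal sketch of that usage follows; the import path, mask construction and tensor sizes are assumptions for illustration, not taken from this commit.

import torch
from modules.activation import MultiheadAttention  # import path assumed from this repo layout

d_model, n_head, max_len = 16, 4, 32
mha = MultiheadAttention(d_model, n_head, batch_first=True)

# Boolean causal mask: True marks positions that must NOT be attended to,
# because the kernel applies it with masked_fill(..., float('-inf')).
causal_mask = torch.triu(torch.ones(max_len, max_len, dtype=torch.bool), diagonal=1)

with torch.no_grad():
    prompt = torch.randn(1, 4, d_model)  # (batch, time, dim)
    y, past_kv = mha.infer(prompt, attn_mask=causal_mask, use_cache=True)

    # Decode one more step, reusing the cached keys/values instead of re-running the prompt.
    step = torch.randn(1, 1, d_model)
    y_step, past_kv = mha.infer(step, attn_mask=causal_mask, past_kv=past_kv, use_cache=True)

print(y.shape, y_step.shape)  # torch.Size([1, 4, 16]) torch.Size([1, 1, 16])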
modules/embedding.py ADDED
@@ -0,0 +1,97 @@
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
+ class TokenEmbedding(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim_model: int,
25
+ vocab_size: int,
26
+ dropout: float = 0.0,
27
+ ):
28
+ super().__init__()
29
+
30
+ self.vocab_size = vocab_size
31
+ self.dim_model = dim_model
32
+
33
+ self.dropout = torch.nn.Dropout(p=dropout)
34
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35
+
36
+ @property
37
+ def weight(self) -> torch.Tensor:
38
+ return self.word_embeddings.weight
39
+
40
+ def embedding(self, index: int) -> torch.Tensor:
41
+ return self.word_embeddings.weight[index : index + 1]
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ X = self.word_embeddings(x)
45
+ X = self.dropout(X)
46
+
47
+ return X
48
+
49
+
50
+ class SinePositionalEmbedding(nn.Module):
51
+ def __init__(
52
+ self,
53
+ dim_model: int,
54
+ dropout: float = 0.0,
55
+ scale: bool = False,
56
+ alpha: bool = False,
57
+ ):
58
+ super().__init__()
59
+ self.dim_model = dim_model
60
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
61
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62
+ self.dropout = torch.nn.Dropout(p=dropout)
63
+
64
+ self.reverse = False
65
+ self.pe = None
66
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67
+
68
+ def extend_pe(self, x):
69
+ """Reset the positional encodings."""
70
+ if self.pe is not None:
71
+ if self.pe.size(1) >= x.size(1):
72
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
73
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74
+ return
75
+ pe = torch.zeros(x.size(1), self.dim_model)
76
+ if self.reverse:
77
+ position = torch.arange(
78
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
79
+ ).unsqueeze(1)
80
+ else:
81
+ position = torch.arange(
82
+ 0, x.size(1), dtype=torch.float32
83
+ ).unsqueeze(1)
84
+ div_term = torch.exp(
85
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86
+ * -(math.log(10000.0) / self.dim_model)
87
+ )
88
+ pe[:, 0::2] = torch.sin(position * div_term)
89
+ pe[:, 1::2] = torch.cos(position * div_term)
90
+ pe = pe.unsqueeze(0)
91
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ self.extend_pe(x)
95
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
96
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97
+ return self.dropout(output)
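TokenEmbedding and SinePositionalEmbedding are the input-side building blocks: token ids are looked up in an embedding table, then fixed sinusoidal position encodings are added, optionally scaled by sqrt(dim_model) and weighted by a learnable alpha. A minimal sketch of chaining the two is given below; the sizes and the import path are assumptions for illustration.

import torch
from modules.embedding import TokenEmbedding, SinePositionalEmbedding  # import path assumed

vocab_size, d_model = 1024, 16
tok_emb = TokenEmbedding(d_model, vocab_size)
pos_emb = SinePositionalEmbedding(d_model, dropout=0.0, scale=False, alpha=True)

tokens = torch.randint(0, vocab_size, (2, 10))  # (batch, seq_len) of token ids
x = tok_emb(tokens)                             # (2, 10, 16) embeddings
x = pos_emb(x)                                  # sinusoidal positions added, weighted by the learnable alpha
print(x.shape)                                  # torch.Size([2, 10, 16])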
modules/optim.py ADDED
@@ -0,0 +1,1105 @@
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import logging
19
+ import random
20
+ from collections import defaultdict
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from lhotse.utils import fix_random_seed
25
+ from torch import Tensor
26
+ from torch.optim import Optimizer
27
+
28
+
29
+ class BatchedOptimizer(Optimizer):
30
+ """
31
+ This class adds to class Optimizer the capability to optimize parameters in batches:
32
+ it will stack the parameters and their grads for you so the optimizer can work
33
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
34
+ as it reduces the number of kernels launched in the optimizer.
35
+
36
+ Args:
37
+ params:
38
+ """
39
+
40
+ def __init__(self, params, defaults):
41
+ super(BatchedOptimizer, self).__init__(params, defaults)
42
+
43
+ @contextlib.contextmanager
44
+ def batched_params(self, param_group, group_params_names):
45
+ """
46
+ This function returns (technically, yields) a list
47
+ of tuples (p, state), where
48
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
49
+ that share the same shape, and its gradient is also stacked;
50
+ `state` is the state corresponding to this batch of parameters
51
+ (it will be physically located in the "state" for one of the real
52
+ parameters, the first one that has any particular shape and dtype).
53
+
54
+ This function is decorated as a context manager so that it can
55
+ write parameters back to their "real" locations.
56
+
57
+ The idea is, instead of doing:
58
+ <code>
59
+ for p in group["params"]:
60
+ state = self.state[p]
61
+ ...
62
+ </code>
63
+ you can do:
64
+ <code>
65
+ with self.batched_params(group["params"]) as batches:
66
+ for p, state, p_names in batches:
67
+ ...
68
+ </code>
69
+
70
+ Args:
71
+ group: a parameter group, which is a list of parameters; should be
72
+ one of self.param_groups.
73
+ group_params_names: name for each parameter in group,
74
+ which is List[str].
75
+ """
76
+ batches = defaultdict(
77
+ list
78
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
79
+ batches_names = defaultdict(
80
+ list
81
+ ) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
82
+
83
+ assert len(param_group) == len(group_params_names)
84
+ for p, named_p in zip(param_group, group_params_names):
85
+ key = (str(p.dtype), *p.shape)
86
+ batches[key].append(p)
87
+ batches_names[key].append(named_p)
88
+
89
+ batches_names_keys = list(batches_names.keys())
90
+ sorted_idx = sorted(
91
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
92
+ )
93
+ batches_names = [
94
+ batches_names[batches_names_keys[idx]] for idx in sorted_idx
95
+ ]
96
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
97
+
98
+ stacked_params_dict = dict()
99
+
100
+ # turn batches into a list, in deterministic order.
101
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
102
+ # one for each batch in `batches`.
103
+ tuples = []
104
+
105
+ for batch, batch_names in zip(batches, batches_names):
106
+ p = batch[0]
107
+ # we arbitrarily store the state in the
108
+ # state corresponding to the 1st parameter in the
109
+ # group. class Optimizer will take care of saving/loading state.
110
+ state = self.state[p]
111
+ p_stacked = torch.stack(batch)
112
+ grad = torch.stack(
113
+ [
114
+ torch.zeros_like(p) if p.grad is None else p.grad
115
+ for p in batch
116
+ ]
117
+ )
118
+ p_stacked.grad = grad
119
+ stacked_params_dict[key] = p_stacked
120
+ tuples.append((p_stacked, state, batch_names))
121
+
122
+ yield tuples # <-- calling code will do the actual optimization here!
123
+
124
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
125
+ for i, p in enumerate(batch): # batch is list of Parameter
126
+ p.copy_(stacked_params[i])
127
+
128
+
129
+ class ScaledAdam(BatchedOptimizer):
130
+ """
131
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
132
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
133
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
134
+ param = underlying_param * log_scale.exp())
135
+
136
+
137
+ Args:
138
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
139
+ lr: The learning rate. We will typically use a learning rate schedule that starts
140
+ at 0.03 and decreases over time, i.e. much higher than other common
141
+ optimizers.
142
+ clipping_scale: (e.g. 2.0)
143
+ A scale for gradient-clipping: if specified, the normalized gradients
144
+ over the whole model will be clipped to have 2-norm equal to
145
+ `clipping_scale` times the median 2-norm over the most recent period
146
+ of `clipping_update_period` minibatches. By "normalized gradients",
147
+ we mean after multiplying by the rms parameter value for this tensor
148
+ [for non-scalars]; this is appropriate because our update is scaled
149
+ by this quantity.
150
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
151
+ Must satisfy 0 < beta1 <= beta2 < 1.
152
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
153
+ scale of each parameter tensor and scalar parameters of the model.
154
+ If each parameter were decomposed
155
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
156
+ would be the scaling factor on the learning rate of p_scale.
157
+ eps: A general-purpose epsilon to prevent division by zero
158
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
159
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
160
+ parameter tensor to be >= this value)
161
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
162
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
163
+ parameter tensor to be <= this value)
164
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
165
+ model has any parameters with numel() == 1).
166
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
167
+ of the parameter tensor. This is provided to save a little time
168
+ in the update.
169
+ clipping_update_period: if clipping_scale is specified, this is the period
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ params,
175
+ lr=3e-02,
176
+ clipping_scale=None,
177
+ betas=(0.9, 0.98),
178
+ scalar_lr_scale=0.1,
179
+ eps=1.0e-08,
180
+ param_min_rms=1.0e-05,
181
+ param_max_rms=3.0,
182
+ scalar_max=10.0,
183
+ size_update_period=4,
184
+ clipping_update_period=100,
185
+ parameters_names=None,
186
+ show_dominant_parameters=True,
187
+ ):
188
+
189
+ assert parameters_names is not None, (
190
+ "Please prepare parameters_names,"
191
+ "which is a List[List[str]]. Each List[str] is for a group"
192
+ "and each str is for a parameter"
193
+ )
194
+ defaults = dict(
195
+ lr=lr,
196
+ clipping_scale=clipping_scale,
197
+ betas=betas,
198
+ scalar_lr_scale=scalar_lr_scale,
199
+ eps=eps,
200
+ param_min_rms=param_min_rms,
201
+ param_max_rms=param_max_rms,
202
+ scalar_max=scalar_max,
203
+ size_update_period=size_update_period,
204
+ clipping_update_period=clipping_update_period,
205
+ )
206
+
207
+ super(ScaledAdam, self).__init__(params, defaults)
208
+ assert len(self.param_groups) == len(parameters_names)
209
+ self.parameters_names = parameters_names
210
+ self.show_dominant_parameters = show_dominant_parameters
211
+
212
+ def __setstate__(self, state):
213
+ super(ScaledAdam, self).__setstate__(state)
214
+
215
+ @torch.no_grad()
216
+ def step(self, closure=None):
217
+ """Performs a single optimization step.
218
+
219
+ Arguments:
220
+ closure (callable, optional): A closure that reevaluates the model
221
+ and returns the loss.
222
+ """
223
+ loss = None
224
+ if closure is not None:
225
+ with torch.enable_grad():
226
+ loss = closure()
227
+
228
+ batch = True
229
+
230
+ for group, group_params_names in zip(
231
+ self.param_groups, self.parameters_names
232
+ ):
233
+
234
+ with self.batched_params(
235
+ group["params"], group_params_names
236
+ ) as batches:
237
+
238
+ # batches is a list of tuples (stacked_param, state, param_names). stacked_param is like
239
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
240
+ # a stacking dim, it is not a real dim.
241
+
242
+ if (
243
+ len(batches[0][1]) == 0
244
+ ): # if len(first state) == 0: not yet initialized
245
+ clipping_scale = 1
246
+ else:
247
+ clipping_scale = self._get_clipping_scale(group, batches)
248
+
249
+ for p, state, _ in batches:
250
+ # Perform optimization step.
251
+ # grad is not going to be None, we handled that when creating the batches.
252
+ grad = p.grad
253
+ if grad.is_sparse:
254
+ raise RuntimeError(
255
+ "ScaledAdam optimizer does not support sparse gradients"
256
+ )
257
+ # State initialization
258
+ if len(state) == 0:
259
+ self._init_state(group, p, state)
260
+
261
+ self._step_one_batch(group, p, state, clipping_scale)
262
+
263
+ return loss
264
+
265
+ def _init_state(self, group: dict, p: Tensor, state: dict):
266
+ """
267
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
268
+ is actually the batch dimension, corresponding to batched-together
269
+ parameters of a given shape.
270
+
271
+
272
+ Args:
273
+ group: Dict to look up configuration values.
274
+ p: The parameter that we are initializing the state for
275
+ state: Dict from string to whatever state we are initializing
276
+ """
277
+ size_update_period = group["size_update_period"]
278
+
279
+ state["step"] = 0
280
+
281
+ kwargs = {"device": p.device, "dtype": p.dtype}
282
+
283
+ # 'delta' implements conventional momentum. There are
284
+ # several different kinds of update going on, so rather than
285
+ # compute "exp_avg" like in Adam, we store and decay a
286
+ # parameter-change "delta", which combines all forms of
287
+ # update. this is equivalent to how it's done in Adam,
288
+ # except for the first few steps.
289
+ state["delta"] = torch.zeros_like(
290
+ p, memory_format=torch.preserve_format
291
+ )
292
+
293
+ batch_size = p.shape[0]
294
+ numel = p.numel() // batch_size
295
+ numel = p.numel()
296
+
297
+ if numel > 1:
298
+ # "param_rms" just periodically records the scalar root-mean-square value of
299
+ # the parameter tensor.
300
+ # it has a shape like (batch_size, 1, 1, 1, 1)
301
+ param_rms = (
302
+ (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
303
+ )
304
+ state["param_rms"] = param_rms
305
+
306
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
307
+ state["scale_grads"] = torch.zeros(
308
+ size_update_period, *param_rms.shape, **kwargs
309
+ )
310
+
311
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
312
+ state["exp_avg_sq"] = torch.zeros_like(
313
+ p, memory_format=torch.preserve_format
314
+ )
315
+
316
+ def _get_clipping_scale(
317
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
318
+ ) -> float:
319
+ """
320
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
321
+ by this amount before applying the rest of the update.
322
+
323
+ Args:
324
+ group: the parameter group, an item in self.param_groups
325
+ tuples: a list of tuples of (param, state, param_names)
326
+ where param is a batched set of parameters,
327
+ with a .grad (1st dim is batch dim)
328
+ and state is the state-dict where optimization parameters are kept.
329
+ param_names is a List[str] while each str is name for a parameter
330
+ in batched set of parameters "param".
331
+ """
332
+ assert len(tuples) >= 1
333
+ clipping_scale = group["clipping_scale"]
334
+ (first_p, first_state, _) = tuples[0]
335
+ step = first_state["step"]
336
+ if clipping_scale is None or step == 0:
337
+ # no clipping. return early on step == 0 because the other
338
+ # parameters' state won't have been initialized yet.
339
+ return 1.0
340
+ clipping_update_period = group["clipping_update_period"]
341
+
342
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
343
+ for (p, state, param_names) in tuples:
344
+ grad = p.grad
345
+ if grad.is_sparse:
346
+ raise RuntimeError(
347
+ "ScaledAdam optimizer does not support sparse gradients"
348
+ )
349
+ if p.numel() == p.shape[0]: # a batch of scalars
350
+ tot_sumsq += (
351
+ grad ** 2
352
+ ).sum() # sum() to change shape [1] to []
353
+ else:
354
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
355
+
356
+ tot_norm = tot_sumsq.sqrt()
357
+ if "model_norms" not in first_state:
358
+ first_state["model_norms"] = torch.zeros(
359
+ clipping_update_period, device=p.device
360
+ )
361
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
362
+
363
+ if step % clipping_update_period == 0:
364
+ # Print some stats.
365
+ # We don't reach here if step == 0 because we would have returned
366
+ # above.
367
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
368
+ quartiles = []
369
+ for n in range(0, 5):
370
+ index = min(
371
+ clipping_update_period - 1,
372
+ (clipping_update_period // 4) * n,
373
+ )
374
+ quartiles.append(sorted_norms[index].item())
375
+
376
+ median = quartiles[2]
377
+ threshold = clipping_scale * median
378
+ first_state["model_norm_threshold"] = threshold
379
+ percent_clipped = (
380
+ first_state["num_clipped"] * 100.0 / clipping_update_period
381
+ if "num_clipped" in first_state
382
+ else 0.0
383
+ )
384
+ first_state["num_clipped"] = 0
385
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
386
+ logging.info(
387
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
388
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
389
+ )
390
+
391
+ if step < clipping_update_period:
392
+ return 1.0 # We have not yet estimated a norm to clip to.
393
+ else:
394
+ try:
395
+ model_norm_threshold = first_state["model_norm_threshold"]
396
+ except KeyError:
397
+ logging.info(
398
+ "Warning: model_norm_threshold not in state: possibly "
399
+ "you changed config when restarting, adding clipping_scale option?"
400
+ )
401
+ return 1.0
402
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
403
+ if ans < 1.0:
404
+ first_state["num_clipped"] += 1
405
+ if ans < 0.1:
406
+ logging.warn(
407
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
408
+ )
409
+ if self.show_dominant_parameters:
410
+ assert p.shape[0] == len(param_names)
411
+ self._show_gradient_dominating_parameter(tuples, tot_sumsq)
412
+ return ans
413
+
414
+ def _show_gradient_dominating_parameter(
415
+ self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
416
+ ):
417
+ """
418
+ Show information about the parameter which dominates tot_sumsq.
419
+
420
+ Args:
421
+ tuples: a list of tuples of (param, state, param_names)
422
+ where param is a batched set of parameters,
423
+ with a .grad (1st dim is batch dim)
424
+ and state is the state-dict where optimization parameters are kept.
425
+ param_names is a List[str] while each str is name for a parameter
426
+ in batched set of parameters "param".
427
+ tot_sumsq: sumsq of all parameters. Though it could be calculated
428
+ from tuples, we still pass it to save some time.
429
+ """
430
+ all_sumsq_orig = {}
431
+ for (p, state, batch_param_names) in tuples:
432
+ # p is a stacked batch parameters.
433
+ batch_grad = p.grad
434
+ if p.numel() == p.shape[0]: # a batch of scalars
435
+ batch_sumsq_orig = batch_grad ** 2
436
+ # Dummy values used by the following `zip` statement.
437
+ batch_rms_orig = torch.ones(p.shape[0])
438
+ else:
439
+ batch_rms_orig = state["param_rms"]
440
+ batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
441
+ dim=list(range(1, batch_grad.ndim))
442
+ )
443
+
444
+ for name, sumsq_orig, rms, grad in zip(
445
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
446
+ ):
447
+
448
+ proportion_orig = sumsq_orig / tot_sumsq
449
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
450
+
451
+ assert torch.isclose(
452
+ sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
453
+ torch.tensor(1.0),
454
+ )
455
+ sorted_by_proportion = {
456
+ k: v
457
+ for k, v in sorted(
458
+ all_sumsq_orig.items(),
459
+ key=lambda item: item[1][0],
460
+ reverse=True,
461
+ )
462
+ }
463
+ dominant_param_name = next(iter(sorted_by_proportion))
464
+ (
465
+ dominant_proportion,
466
+ dominant_sumsq,
467
+ dominant_rms,
468
+ dominant_grad,
469
+ ) = sorted_by_proportion[dominant_param_name]
470
+ logging.info(
471
+ f"Parameter Dominanting tot_sumsq {dominant_param_name}"
472
+ f" with proportion {dominant_proportion:.2f},"
473
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
474
+ f"={dominant_sumsq:.3e},"
475
+ f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
476
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
477
+ )
478
+
479
+ def _step_one_batch(
480
+ self, group: dict, p: Tensor, state: dict, clipping_scale: float
481
+ ):
482
+ """
483
+ Do the step for one parameter, which is actually going to be a batch of
484
+ `real` parameters, with dim 0 as the batch dim.
485
+ Args:
486
+ group: dict to look up configuration values
487
+ p: parameter to update (actually multiple parameters stacked together
488
+ as a batch)
489
+ state: state-dict for p, to look up the optimizer state
490
+ """
491
+ lr = group["lr"]
492
+ size_update_period = group["size_update_period"]
493
+ beta1 = group["betas"][0]
494
+
495
+ grad = p.grad
496
+ if clipping_scale != 1.0:
497
+ grad = grad * clipping_scale
498
+ step = state["step"]
499
+ delta = state["delta"]
500
+
501
+ delta.mul_(beta1)
502
+ batch_size = p.shape[0]
503
+ numel = p.numel() // batch_size
504
+ if numel > 1:
505
+ # Update the size/scale of p, and set param_rms
506
+ scale_grads = state["scale_grads"]
507
+ scale_grads[step % size_update_period] = (p * grad).sum(
508
+ dim=list(range(1, p.ndim)), keepdim=True
509
+ )
510
+ if step % size_update_period == size_update_period - 1:
511
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
512
+ param_rms.copy_(
513
+ (p ** 2)
514
+ .mean(dim=list(range(1, p.ndim)), keepdim=True)
515
+ .sqrt()
516
+ )
517
+ if step > 0:
518
+ # self._size_update() learns the overall scale on the
519
+ # parameter, by shrinking or expanding it.
520
+ self._size_update(group, scale_grads, p, state)
521
+
522
+ if numel == 1:
523
+ # For parameters with 1 element we just use regular Adam.
524
+ # Updates delta.
525
+ self._step_scalar(group, p, state)
526
+ else:
527
+ self._step(group, p, state)
528
+
529
+ state["step"] = step + 1
530
+
531
+ def _size_update(
532
+ self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
533
+ ) -> None:
534
+ """
535
+ Called only where p.numel() > 1, this updates the scale of the parameter.
536
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
537
+ gradient descent on underlying param and on scale, this function does the update
538
+ on `scale`.
539
+
540
+ Args:
541
+ group: dict to look up configuration values
542
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
543
+ grads w.r.t. the scales.
544
+ p: The parameter to update
545
+ state: The state-dict of p
546
+ """
547
+
548
+ param_rms = state["param_rms"]
549
+ beta1, beta2 = group["betas"]
550
+ size_lr = group["lr"] * group["scalar_lr_scale"]
551
+ param_min_rms = group["param_min_rms"]
552
+ param_max_rms = group["param_max_rms"]
553
+ eps = group["eps"]
554
+ step = state["step"]
555
+ batch_size = p.shape[0]
556
+
557
+ size_update_period = scale_grads.shape[0]
558
+ # correct beta2 for the size update period: we will have
559
+ # faster decay at this level.
560
+ beta2_corr = beta2 ** size_update_period
561
+
562
+ scale_exp_avg_sq = state[
563
+ "scale_exp_avg_sq"
564
+ ] # shape: (batch_size, 1, 1, ..)
565
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
566
+ (scale_grads ** 2).mean(
567
+ dim=0
568
+ ), # mean over dim `size_update_period`
569
+ alpha=1 - beta2_corr,
570
+ ) # shape is (batch_size, 1, 1, ...)
571
+
572
+ # The 1st time we reach here is when size_step == 1.
573
+ size_step = (step + 1) // size_update_period
574
+ bias_correction2 = 1 - beta2_corr ** size_step
575
+ # we don't bother with bias_correction1; this will help prevent divergence
576
+ # at the start of training.
577
+
578
+ denom = scale_exp_avg_sq.sqrt() + eps
579
+
580
+ scale_step = (
581
+ -size_lr
582
+ * (bias_correction2 ** 0.5)
583
+ * scale_grads.sum(dim=0)
584
+ / denom
585
+ )
586
+
587
+ is_too_small = param_rms < param_min_rms
588
+ is_too_large = param_rms > param_max_rms
589
+
590
+ # when the param gets too small, just don't shrink it any further.
591
+ scale_step.masked_fill_(is_too_small, 0.0)
592
+ # when it gets too large, stop it from getting any larger.
593
+ scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
594
+ delta = state["delta"]
595
+ # the factor of (1-beta1) relates to momentum.
596
+ delta.add_(p * scale_step, alpha=(1 - beta1))
597
+
598
+ def _step(self, group: dict, p: Tensor, state: dict):
599
+ """
600
+ This function does the core update of self.step(), in the case where the members of
601
+ the batch have more than 1 element.
602
+
603
+ Args:
604
+ group: A dict which will be used to look up configuration values
605
+ p: The parameter to be updated
606
+ grad: The grad of p
607
+ state: The state-dict corresponding to parameter p
608
+
609
+ This function modifies p.
610
+ """
611
+ grad = p.grad
612
+ lr = group["lr"]
613
+ beta1, beta2 = group["betas"]
614
+ eps = group["eps"]
615
+ param_min_rms = group["param_min_rms"]
616
+ step = state["step"]
617
+
618
+ exp_avg_sq = state["exp_avg_sq"]
619
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
620
+
621
+ this_step = state["step"] - (
622
+ state["zero_step"] if "zero_step" in state else 0
623
+ )
624
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
625
+ if bias_correction2 < 0.99:
626
+ # note: not in-place.
627
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
628
+
629
+ denom = exp_avg_sq.sqrt()
630
+ denom += eps
631
+ grad = grad / denom
632
+
633
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
634
+
635
+ delta = state["delta"]
636
+ delta.add_(grad * alpha)
637
+ p.add_(delta)
638
+
639
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
640
+ """
641
+ A simplified form of the core update for scalar tensors, where we cannot get a good
642
+ estimate of the parameter rms.
643
+ """
644
+ beta1, beta2 = group["betas"]
645
+ scalar_max = group["scalar_max"]
646
+ eps = group["eps"]
647
+ lr = group["lr"] * group["scalar_lr_scale"]
648
+ grad = p.grad
649
+
650
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
651
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
652
+
653
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
654
+ # slower update at the start will help stability anyway.
655
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
656
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
657
+
658
+ delta = state["delta"]
659
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
660
+ p.clamp_(min=-scalar_max, max=scalar_max)
661
+ p.add_(delta)
662
+
663
+
664
+ class LRScheduler(object):
665
+ """
666
+ Base-class for learning rate schedulers where the learning-rate depends on both the
667
+ batch and the epoch.
668
+ """
669
+
670
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
671
+ # Attach optimizer
672
+ if not isinstance(optimizer, Optimizer):
673
+ raise TypeError(
674
+ "{} is not an Optimizer".format(type(optimizer).__name__)
675
+ )
676
+ self.optimizer = optimizer
677
+ self.verbose = verbose
678
+
679
+ for group in optimizer.param_groups:
680
+ group.setdefault("base_lr", group["lr"])
681
+
682
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
683
+
684
+ self.epoch = 0
685
+ self.batch = 0
686
+
687
+ def state_dict(self):
688
+ """Returns the state of the scheduler as a :class:`dict`.
689
+
690
+ It contains an entry for every variable in self.__dict__ which
691
+ is not the optimizer.
692
+ """
693
+ return {
694
+ "base_lrs": self.base_lrs,
695
+ "epoch": self.epoch,
696
+ "batch": self.batch,
697
+ }
698
+
699
+ def load_state_dict(self, state_dict):
700
+ """Loads the schedulers state.
701
+
702
+ Args:
703
+ state_dict (dict): scheduler state. Should be an object returned
704
+ from a call to :meth:`state_dict`.
705
+ """
706
+ self.__dict__.update(state_dict)
707
+
708
+ def get_last_lr(self) -> List[float]:
709
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
710
+ return self._last_lr
711
+
712
+ def get_lr(self):
713
+ # Compute list of learning rates from self.epoch and self.batch and
714
+ # self.base_lrs; this must be overloaded by the user.
715
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
716
+ raise NotImplementedError
717
+
718
+ def step_batch(self, batch: Optional[int] = None) -> None:
719
+ # Step the batch index, or just set it. If `batch` is specified, it
720
+ # must be the batch index from the start of training, i.e. summed over
721
+ # all epochs.
722
+ # You can call this in any order; if you don't provide 'batch', it should
723
+ # of course be called once per batch.
724
+ if batch is not None:
725
+ self.batch = batch
726
+ else:
727
+ self.batch = self.batch + 1
728
+ self._set_lrs()
729
+
730
+ def step_epoch(self, epoch: Optional[int] = None):
731
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
732
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
733
+ # arg, you should call it at the end of the epoch.
734
+ if epoch is not None:
735
+ self.epoch = epoch
736
+ else:
737
+ self.epoch = self.epoch + 1
738
+ self._set_lrs()
739
+
740
+ def _set_lrs(self):
741
+ values = self.get_lr()
742
+ assert len(values) == len(self.optimizer.param_groups)
743
+
744
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
745
+ param_group, lr = data
746
+ param_group["lr"] = lr
747
+ self.print_lr(self.verbose, i, lr)
748
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
749
+
750
+ def print_lr(self, is_verbose, group, lr):
751
+ """Display the current learning rate."""
752
+ if is_verbose:
753
+ logging.info(
754
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
755
+ f" of group {group} to {lr:.4e}."
756
+ )
757
+
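+ # A minimal sketch (illustrative, not part of the original file) of how
+ # LRScheduler is meant to be used: subclass it and overload get_lr() to return
+ # one learning rate per param group, computed from self.batch / self.epoch /
+ # self.base_lrs. The class name below is hypothetical.
+ class _ExampleHalvingScheduler(LRScheduler):
+     def get_lr(self):
+         # halve the learning rate at each epoch boundary; ignore the batch index
+         return [base_lr * (0.5 ** self.epoch) for base_lr in self.base_lrs]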
758
+
759
+ class Eden(LRScheduler):
760
+ """
761
+ Eden scheduler.
762
+ The basic formula (before warmup) is:
763
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
764
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
765
+ where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
766
+ and then stays constant at 1.
767
+
768
+
769
+ E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
770
+
771
+ Args:
772
+ optimizer: the optimizer to change the learning rates on
773
+ lr_batches: the number of batches after which we start significantly
774
+ decreasing the learning rate, suggest 5000.
775
+ lr_epochs: the number of epochs after which we start significantly
776
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
777
+ 20 to 40 epochs, but may need smaller number if dataset is huge
778
+ and you will do few epochs.
779
+ """
780
+
781
+ def __init__(
782
+ self,
783
+ optimizer: Optimizer,
784
+ lr_batches: Union[int, float],
785
+ lr_epochs: Union[int, float],
786
+ warmup_batches: Union[int, float] = 500.0,
787
+ verbose: bool = False,
788
+ ):
789
+ super(Eden, self).__init__(optimizer, verbose)
790
+ self.lr_batches = lr_batches
791
+ self.lr_epochs = lr_epochs
792
+ self.warmup_batches = warmup_batches
793
+
794
+ def get_lr(self):
795
+ factor = (
796
+ (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
797
+ ) ** -0.25 * (
798
+ ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
799
+ ** -0.25
800
+ )
801
+ warmup_factor = (
802
+ 1.0
803
+ if self.batch >= self.warmup_batches
804
+ else 0.5 + 0.5 * (self.batch / self.warmup_batches)
805
+ )
806
+
807
+ return [x * factor * warmup_factor for x in self.base_lrs]
808
+
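+ # Standalone recomputation (illustrative only) of the Eden factor for a single
+ # (batch, epoch) point, matching the formula used in get_lr() above; the
+ # function name and default arguments are just for this example.
+ def _example_eden_factor(
+     batch: int,
+     epoch: int,
+     lr_batches: float = 5000.0,
+     lr_epochs: float = 6.0,
+     warmup_batches: float = 500.0,
+ ) -> float:
+     batch_factor = ((batch ** 2 + lr_batches ** 2) / lr_batches ** 2) ** -0.25
+     epoch_factor = ((epoch ** 2 + lr_epochs ** 2) / lr_epochs ** 2) ** -0.25
+     # warmup rises linearly from 0.5 to 1.0 over the first warmup_batches batches
+     warmup = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * batch / warmup_batches
+     # the scheduler multiplies each base_lr by this combined factor
+     return batch_factor * epoch_factor * warmup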
809
+
810
+ def _test_eden():
811
+ m = torch.nn.Linear(100, 100)
812
+ optim = ScaledAdam(m.parameters(), lr=0.03)
813
+
814
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
815
+
816
+ for epoch in range(10):
817
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
818
+
819
+ for step in range(20):
820
+ x = torch.randn(200, 100).detach()
821
+ x.requires_grad = True
822
+ y = m(x)
823
+ dy = torch.randn(200, 100).detach()
824
+ f = (y * dy).sum()
825
+ f.backward()
826
+
827
+ optim.step()
828
+ scheduler.step_batch()
829
+ optim.zero_grad()
830
+
831
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
832
+ logging.info(f"state dict = {scheduler.state_dict()}")
833
+
834
+
835
+ # This is included mostly as a baseline for ScaledAdam.
836
+ class Eve(Optimizer):
837
+ """
838
+ Implements Eve algorithm. This is a modified version of AdamW with a special
839
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
840
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
841
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
842
+ will be close to invariant to the absolute scale on the parameter matrix.
843
+
844
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
845
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
846
+ Eve is unpublished so far.
847
+
848
+ Arguments:
849
+ params (iterable): iterable of parameters to optimize or dicts defining
850
+ parameter groups
851
+ lr (float, optional): learning rate (default: 1e-3)
852
+ betas (Tuple[float, float], optional): coefficients used for computing
853
+ running averages of gradient and its square (default: (0.9, 0.999))
854
+ eps (float, optional): term added to the denominator to improve
855
+ numerical stability (default: 1e-8)
856
+ weight_decay (float, optional): weight decay coefficient (default: 1e-3;
+ this value means that the weight would decay significantly after
+ about 1k minibatches). Is not multiplied by learning rate, but
859
+ is conditional on RMS-value of parameter being > target_rms.
860
+ target_rms (float, optional): target root-mean-square value of
861
+ parameters, if they fall below this we will stop applying weight decay.
862
+
863
+
864
+ .. _Adam: A Method for Stochastic Optimization:
865
+ https://arxiv.org/abs/1412.6980
866
+ .. _Decoupled Weight Decay Regularization:
867
+ https://arxiv.org/abs/1711.05101
868
+ .. _On the Convergence of Adam and Beyond:
869
+ https://openreview.net/forum?id=ryQu7f-RZ
870
+ """
871
+
872
+ def __init__(
873
+ self,
874
+ params,
875
+ lr=1e-3,
876
+ betas=(0.9, 0.98),
877
+ eps=1e-8,
878
+ weight_decay=1e-3,
879
+ target_rms=0.1,
880
+ ):
881
+ if not 0.0 <= lr:
882
+ raise ValueError("Invalid learning rate: {}".format(lr))
883
+ if not 0.0 <= eps:
884
+ raise ValueError("Invalid epsilon value: {}".format(eps))
885
+ if not 0.0 <= betas[0] < 1.0:
886
+ raise ValueError(
887
+ "Invalid beta parameter at index 0: {}".format(betas[0])
888
+ )
889
+ if not 0.0 <= betas[1] < 1.0:
890
+ raise ValueError(
891
+ "Invalid beta parameter at index 1: {}".format(betas[1])
892
+ )
893
+ if not 0 <= weight_decay <= 0.1:
894
+ raise ValueError(
895
+ "Invalid weight_decay value: {}".format(weight_decay)
896
+ )
897
+ if not 0 < target_rms <= 10.0:
898
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
899
+ defaults = dict(
900
+ lr=lr,
901
+ betas=betas,
902
+ eps=eps,
903
+ weight_decay=weight_decay,
904
+ target_rms=target_rms,
905
+ )
906
+ super(Eve, self).__init__(params, defaults)
907
+
908
+ def __setstate__(self, state):
909
+ super(Eve, self).__setstate__(state)
910
+
911
+ @torch.no_grad()
912
+ def step(self, closure=None):
913
+ """Performs a single optimization step.
914
+
915
+ Arguments:
916
+ closure (callable, optional): A closure that reevaluates the model
917
+ and returns the loss.
918
+ """
919
+ loss = None
920
+ if closure is not None:
921
+ with torch.enable_grad():
922
+ loss = closure()
923
+
924
+ for group in self.param_groups:
925
+ for p in group["params"]:
926
+ if p.grad is None:
927
+ continue
928
+
929
+ # Perform optimization step
930
+ grad = p.grad
931
+ if grad.is_sparse:
932
+ raise RuntimeError(
933
+ "Eve does not support sparse gradients"
934
+ )
935
+
936
+ state = self.state[p]
937
+
938
+ # State initialization
939
+ if len(state) == 0:
940
+ state["step"] = 0
941
+ # Exponential moving average of gradient values
942
+ state["exp_avg"] = torch.zeros_like(
943
+ p, memory_format=torch.preserve_format
944
+ )
945
+ # Exponential moving average of squared gradient values
946
+ state["exp_avg_sq"] = torch.zeros_like(
947
+ p, memory_format=torch.preserve_format
948
+ )
949
+
950
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
951
+
952
+ beta1, beta2 = group["betas"]
953
+
954
+ state["step"] += 1
955
+ bias_correction1 = 1 - beta1 ** state["step"]
956
+ bias_correction2 = 1 - beta2 ** state["step"]
957
+
958
+ # Decay the first and second moment running average coefficient
959
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
960
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
961
+ denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
962
+ group["eps"]
963
+ )
964
+
965
+ step_size = group["lr"] / bias_correction1
966
+ target_rms = group["target_rms"]
967
+ weight_decay = group["weight_decay"]
968
+
969
+ if p.numel() > 1:
970
+ # avoid applying this weight-decay on "scaling factors"
971
+ # (which are scalar).
972
+ is_above_target_rms = p.norm() > (
973
+ target_rms * (p.numel() ** 0.5)
974
+ )
975
+ p.mul_(1 - (weight_decay * is_above_target_rms))
976
+
977
+ p.addcdiv_(exp_avg, denom, value=-step_size)
978
+
979
+ # if random.random() < 0.0005:
980
+ # step = (exp_avg / denom) * step_size
981
+ # logging.info(
982
+ # f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
983
+ # )
984
+
985
+ return loss
986
+
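+ # Minimal usage sketch (illustrative only): Eve is used like any torch
+ # optimizer; its weight decay is applied only to non-scalar parameters whose
+ # RMS exceeds target_rms. The helper name below is hypothetical.
+ def _example_eve_step():
+     model = torch.nn.Linear(16, 16)
+     optim = Eve(model.parameters(), lr=1e-3, weight_decay=1e-3, target_rms=0.1)
+     loss = model(torch.randn(8, 16)).pow(2).mean()
+     loss.backward()
+     optim.step()
+     optim.zero_grad()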
987
+
988
+ def _test_scaled_adam(hidden_dim: int):
989
+ import timeit
990
+
991
+ from scaling import ScaledLinear
992
+
993
+ E = 100
994
+ B = 4
995
+ T = 2
996
+ logging.info("in test_eve_cain")
997
+ # device = torch.device('cuda')
998
+ device = torch.device("cpu")
999
+ dtype = torch.float32
1000
+
1001
+ fix_random_seed(42)
1002
+ # these input_magnitudes and output_magnitudes are to test that
1003
+ # Abel is working as we expect and is able to adjust scales of
1004
+ # different dims differently.
1005
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1006
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
1007
+
1008
+ for iter in [1, 0]:
1009
+ fix_random_seed(42)
1010
+ Linear = torch.nn.Linear if iter == 0 else ScaledLinear
1011
+
1012
+ m = torch.nn.Sequential(
1013
+ Linear(E, hidden_dim),
1014
+ torch.nn.PReLU(),
1015
+ Linear(hidden_dim, hidden_dim),
1016
+ torch.nn.PReLU(),
1017
+ Linear(hidden_dim, E),
1018
+ ).to(device)
1019
+
1020
+ train_pairs = [
1021
+ (
1022
+ 100.0
1023
+ * torch.randn(B, T, E, device=device, dtype=dtype)
1024
+ * input_magnitudes,
1025
+ torch.randn(B, T, E, device=device, dtype=dtype)
1026
+ * output_magnitudes,
1027
+ )
1028
+ for _ in range(20)
1029
+ ]
1030
+
1031
+ if iter == 0:
1032
+ optim = Eve(m.parameters(), lr=0.003)
1033
+ elif iter == 1:
1034
+ optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
1035
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
1036
+
1037
+ start = timeit.default_timer()
1038
+ avg_loss = 0.0
1039
+ for epoch in range(180):
1040
+ scheduler.step_epoch()
1041
+ # if epoch == 100 and iter in [2,3]:
1042
+ # optim.reset_speedup() # check it doesn't crash.
1043
+
1044
+ # if epoch == 130:
1045
+ # opts = diagnostics.TensorDiagnosticOptions(
1046
+ # 2 ** 22
1047
+ # ) # allow 4 megabytes per sub-module
1048
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
1049
+
1050
+ for n, (x, y) in enumerate(train_pairs):
1051
+ y_out = m(x)
1052
+ loss = ((y_out - y) ** 2).mean() * 100.0
1053
+ if epoch == 0 and n == 0:
1054
+ avg_loss = loss.item()
1055
+ else:
1056
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
1057
+ if n == 0 and epoch % 5 == 0:
1058
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
1059
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
1060
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
1061
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
1062
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
1063
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
1064
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
1065
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
1066
+ lr = scheduler.get_last_lr()[0]
1067
+ logging.info(
1068
+ f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
1069
+ ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
1070
+ loss.log().backward()
1071
+ optim.step()
1072
+ optim.zero_grad()
1073
+ scheduler.step_batch()
1074
+
1075
+ # diagnostic.print_diagnostics()
1076
+
1077
+ stop = timeit.default_timer()
1078
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
1079
+
1080
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
1081
+ # logging.info("state dict = ", scheduler.state_dict())
1082
+ # logging.info("optim state_dict = ", optim.state_dict())
1083
+ logging.info(f"input_magnitudes = {input_magnitudes}")
1084
+ logging.info(f"output_magnitudes = {output_magnitudes}")
1085
+
1086
+
1087
+ if __name__ == "__main__":
1088
+ torch.set_num_threads(1)
1089
+ torch.set_num_interop_threads(1)
1090
+ logging.getLogger().setLevel(logging.INFO)
1091
+ import subprocess
1092
+
1093
+ s = subprocess.check_output(
1094
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
1095
+ )
1096
+ logging.info(s)
1097
+ import sys
1098
+
1099
+ if len(sys.argv) > 1:
1100
+ hidden_dim = int(sys.argv[1])
1101
+ else:
1102
+ hidden_dim = 200
1103
+
1104
+ _test_scaled_adam(hidden_dim)
1105
+ _test_eden()
modules/scaling.py ADDED
@@ -0,0 +1,1401 @@
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import random
21
+ import math
22
+ from functools import reduce
23
+ from itertools import repeat
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ from torch import Tensor
30
+ from torch.nn import Embedding as ScaledEmbedding
31
+
32
+ from utils import Transpose
33
+
34
+
35
+ class ActivationBalancerFunction(torch.autograd.Function):
36
+ @staticmethod
37
+ def forward(
38
+ ctx,
39
+ x: Tensor,
40
+ scale_factor: Tensor,
41
+ sign_factor: Optional[Tensor],
42
+ channel_dim: int,
43
+ ) -> Tensor:
44
+ if channel_dim < 0:
45
+ channel_dim += x.ndim
46
+ ctx.channel_dim = channel_dim
47
+ xgt0 = x > 0
48
+ if sign_factor is None:
49
+ ctx.save_for_backward(xgt0, scale_factor)
50
+ else:
51
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
52
+ return x
53
+
54
+ @staticmethod
55
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
56
+ if len(ctx.saved_tensors) == 3:
57
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
58
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
59
+ scale_factor = scale_factor.unsqueeze(-1)
60
+ sign_factor = sign_factor.unsqueeze(-1)
61
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
62
+ else:
63
+ xgt0, scale_factor = ctx.saved_tensors
64
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
65
+ scale_factor = scale_factor.unsqueeze(-1)
66
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
67
+ neg_delta_grad = x_grad.abs() * factor
68
+ return (
69
+ x_grad - neg_delta_grad,
70
+ None,
71
+ None,
72
+ None,
73
+ )
74
+
75
+
76
+ def _compute_scale_factor(
77
+ x: Tensor,
78
+ channel_dim: int,
79
+ min_abs: float,
80
+ max_abs: float,
81
+ gain_factor: float,
82
+ max_factor: float,
83
+ ) -> Tensor:
84
+ if channel_dim < 0:
85
+ channel_dim += x.ndim
86
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
87
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
88
+
89
+ if min_abs == 0.0:
90
+ below_threshold = 0.0
91
+ else:
92
+ # below_threshold is 0 if x_abs_mean > min_abs, and can be at most max_factor
+ # when x_abs_mean << min_abs.
94
+ below_threshold = (
95
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
96
+ ).clamp(min=0, max=max_factor)
97
+
98
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
99
+ min=0, max=max_factor
100
+ )
101
+
102
+ return below_threshold - above_threshold
103
+
104
+
105
+ def _compute_sign_factor(
106
+ x: Tensor,
107
+ channel_dim: int,
108
+ min_positive: float,
109
+ max_positive: float,
110
+ gain_factor: float,
111
+ max_factor: float,
112
+ ) -> Tensor:
113
+ if channel_dim < 0:
114
+ channel_dim += x.ndim
115
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
116
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
117
+ if min_positive == 0.0:
118
+ factor1 = 0.0
119
+ else:
120
+ # 0 if proportion_positive >= min_positive, else can be
121
+ # as large as max_factor.
122
+ factor1 = (
123
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
124
+ ).clamp_(min=0, max=max_factor)
125
+
126
+ if max_positive == 1.0:
127
+ factor2 = 0.0
128
+ else:
129
+ # 0 if self.proportion_positive <= max_positive, else can be
130
+ # as large as -max_factor.
131
+ factor2 = (
132
+ (proportion_positive - max_positive)
133
+ * (gain_factor / (1.0 - max_positive))
134
+ ).clamp_(min=0, max=max_factor)
135
+ sign_factor = factor1 - factor2
136
+ # require min_positive != 0 or max_positive != 1:
137
+ assert not isinstance(sign_factor, float)
138
+ return sign_factor
139
+
140
+
141
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
142
+ """
143
+ This object is used in class ActivationBalancer when the user specified
144
+ min_positive=0, max_positive=1, so there are no constraints on the signs
145
+ of the activations and only the absolute value has a constraint.
146
+ """
147
+
148
+ @staticmethod
149
+ def forward(
150
+ ctx,
151
+ x: Tensor,
152
+ sign_factor: Tensor,
153
+ scale_factor: Tensor,
154
+ channel_dim: int,
155
+ ) -> Tensor:
156
+ if channel_dim < 0:
157
+ channel_dim += x.ndim
158
+ ctx.channel_dim = channel_dim
159
+ xgt0 = x > 0
160
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
161
+ return x
162
+
163
+ @staticmethod
164
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
165
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
166
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
167
+ sign_factor = sign_factor.unsqueeze(-1)
168
+ scale_factor = scale_factor.unsqueeze(-1)
169
+
170
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
171
+ neg_delta_grad = x_grad.abs() * factor
172
+ return (
173
+ x_grad - neg_delta_grad,
174
+ None,
175
+ None,
176
+ None,
177
+ )
178
+
179
+
180
+ class RandomClampFunction(torch.autograd.Function):
181
+ @staticmethod
182
+ def forward(
183
+ ctx,
184
+ x: Tensor,
185
+ min: Optional[float],
186
+ max: Optional[float],
187
+ prob: float,
188
+ reflect: float,
189
+ ) -> Tensor:
190
+ x_clamped = torch.clamp(x, min=min, max=max)
191
+ mask = torch.rand_like(x) < prob
192
+ ans = torch.where(mask, x_clamped, x)
193
+ if x.requires_grad:
194
+ ctx.save_for_backward(ans == x)
195
+ ctx.reflect = reflect
196
+ if reflect != 0.0:
197
+ ans = ans * (1.0 + reflect) - (x * reflect)
198
+ return ans
199
+
200
+ @staticmethod
201
+ def backward(
202
+ ctx, ans_grad: Tensor
203
+ ) -> Tuple[Tensor, None, None, None, None]:
204
+ (is_same,) = ctx.saved_tensors
205
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
206
+ reflect = ctx.reflect
207
+ if reflect != 0.0:
208
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
209
+ return x_grad, None, None, None, None
210
+
211
+
212
+ def random_clamp(
213
+ x: Tensor,
214
+ min: Optional[float] = None,
215
+ max: Optional[float] = None,
216
+ prob: float = 0.5,
217
+ reflect: float = 0.0,
218
+ ):
219
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
220
+
221
+
222
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
223
+ """
224
+ A randomized way of casting a floating point value to half precision.
225
+ """
226
+ if x.dtype == torch.float16:
227
+ return x
228
+ x_abs = x.abs()
229
+ is_too_small = x_abs < min_abs
230
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
231
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
232
+ # for those elements].
233
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
234
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
235
+
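+ # Quick sanity sketch (illustrative only): for magnitudes far below min_abs the
+ # randomized cast produces mostly zeros plus occasional +/-min_abs values, so
+ # the mean is preserved in expectation. The helper name is hypothetical.
+ def _example_random_cast_expectation():
+     x = torch.full((100000,), 1.0e-06)           # well below min_abs = 5e-06
+     y = random_cast_to_half(x, min_abs=5.0e-06)
+     # both means should be close to 1e-06
+     print(x.mean().item(), y.float().mean().item())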
236
+
237
+ class RandomGradFunction(torch.autograd.Function):
238
+ """
239
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
240
+ randomized approach that preserves expectations (intended to reduce roundoff).
241
+ """
242
+
243
+ @staticmethod
244
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
245
+ ctx.min_abs = min_abs
246
+ return x
247
+
248
+ @staticmethod
249
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
250
+ if ans_grad.dtype == torch.float16:
251
+ return (
252
+ random_cast_to_half(
253
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
254
+ ),
255
+ None,
256
+ )
257
+ else:
258
+ return ans_grad, None
259
+
260
+
261
+ class RandomGrad(torch.nn.Module):
262
+ """
263
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
264
+ accuracy of training when using amp (automatic mixed precision)
265
+ """
266
+
267
+ def __init__(self, min_abs: float = 5.0e-06):
268
+ super(RandomGrad, self).__init__()
269
+ self.min_abs = min_abs
270
+
271
+ def forward(self, x: Tensor):
272
+ if (
273
+ torch.jit.is_scripting()
274
+ or not self.training
275
+ or torch.jit.is_tracing()
276
+ ):
277
+ return x
278
+ else:
279
+ return RandomGradFunction.apply(x, self.min_abs)
280
+
281
+
282
+ class SoftmaxFunction(torch.autograd.Function):
283
+ """
284
+ Tries to handle half-precision derivatives in a randomized way that should
285
+ be more accurate for training than the default behavior.
286
+ """
287
+
288
+ @staticmethod
289
+ def forward(ctx, x: Tensor, dim: int):
290
+ ans = x.softmax(dim=dim)
291
+ # if x dtype is float16, x.softmax() returns a float32 because
292
+ # (presumably) that op does not support float16, and autocast
293
+ # is enabled.
294
+ if torch.is_autocast_enabled():
295
+ ans = ans.to(torch.float16)
296
+ ctx.save_for_backward(ans)
297
+ ctx.x_dtype = x.dtype
298
+ ctx.dim = dim
299
+ return ans
300
+
301
+ @staticmethod
302
+ def backward(ctx, ans_grad: Tensor):
303
+ (ans,) = ctx.saved_tensors
304
+ with torch.cuda.amp.autocast(enabled=False):
305
+ ans_grad = ans_grad.to(torch.float32)
306
+ ans = ans.to(torch.float32)
307
+ x_grad = ans_grad * ans
308
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
309
+ return x_grad, None
310
+
311
+
312
+ def softmax(x: Tensor, dim: int):
313
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
314
+ return x.softmax(dim)
315
+
316
+ return SoftmaxFunction.apply(x, dim)
317
+
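+ # Usage sketch (illustrative only): softmax() is a drop-in replacement for
+ # Tensor.softmax(); the forward values are identical, only the half-precision
+ # handling in backward differs. The helper name is hypothetical.
+ def _example_softmax_check():
+     x = torch.randn(4, 10, requires_grad=True)
+     y = softmax(x, dim=-1)
+     assert torch.allclose(y, x.softmax(dim=-1))
+     (y * torch.randn_like(y)).sum().backward()   # grads flow through SoftmaxFunction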
318
+
319
+ class MaxEigLimiterFunction(torch.autograd.Function):
320
+ @staticmethod
321
+ def forward(
322
+ ctx,
323
+ x: Tensor,
324
+ coeffs: Tensor,
325
+ direction: Tensor,
326
+ channel_dim: int,
327
+ grad_scale: float,
328
+ ) -> Tensor:
329
+ ctx.channel_dim = channel_dim
330
+ ctx.grad_scale = grad_scale
331
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
332
+ return x
333
+
334
+ @staticmethod
335
+ def backward(ctx, x_grad, *args):
336
+ with torch.enable_grad():
337
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
338
+ x_orig.requires_grad = True
339
+ num_channels = x_orig.shape[ctx.channel_dim]
340
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
341
+ new_direction.requires_grad = False
342
+ x = x - x.mean(dim=0)
343
+ x_var = (x ** 2).mean()
344
+ x_residual = x - coeffs * new_direction
345
+ x_residual_var = (x_residual ** 2).mean()
346
+ # `variance_proportion` is the proportion of the variance accounted for
347
+ # by the top eigen-direction. This is to be minimized.
348
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
349
+ variance_proportion.backward()
350
+ x_orig_grad = x_orig.grad
351
+ x_extra_grad = (
352
+ x_orig.grad
353
+ * ctx.grad_scale
354
+ * x_grad.norm()
355
+ / (x_orig_grad.norm() + 1.0e-20)
356
+ )
357
+ return x_grad + x_extra_grad.detach(), None, None, None, None
358
+
359
+
360
+ class BasicNorm(torch.nn.Module):
361
+ """
362
+ This is intended to be a simpler, and hopefully cheaper, replacement for
363
+ LayerNorm. The observation this is based on, is that Transformer-type
364
+ networks, especially with pre-norm, sometimes seem to set one of the
365
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
366
+ the LayerNorm because the output magnitude is then not strongly dependent
367
+ on the other (useful) features. Presumably the weight and bias of the
368
+ LayerNorm are required to allow it to do this.
369
+
370
+ So the idea is to introduce this large constant value as an explicit
371
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
372
+ doesn't have to do this trick. We make the "eps" learnable.
373
+
374
+ Args:
375
+ num_channels: the number of channels, e.g. 512.
376
+ channel_dim: the axis/dimension corresponding to the channel,
377
+ interpreted as an offset from the input's ndim if negative.
+ this is NOT the num_channels; it should typically be one of
379
+ {-2, -1, 0, 1, 2, 3}.
380
+ eps: the initial "epsilon" that we add as ballast in:
381
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
382
+ Note: our epsilon is actually large, but we keep the name
383
+ to indicate the connection with conventional LayerNorm.
384
+ learn_eps: if true, we learn epsilon; if false, we keep it
385
+ at the initial value.
386
+ eps_min: float
387
+ eps_max: float
388
+ """
389
+
390
+ def __init__(
391
+ self,
392
+ num_channels: int,
393
+ channel_dim: int = -1, # CAUTION: see documentation.
394
+ eps: float = 0.25,
395
+ learn_eps: bool = True,
396
+ eps_min: float = -3.0,
397
+ eps_max: float = 3.0,
398
+ ) -> None:
399
+ super(BasicNorm, self).__init__()
400
+ self.num_channels = num_channels
401
+ self.channel_dim = channel_dim
402
+ if learn_eps:
403
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
404
+ else:
405
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
406
+ self.eps_min = eps_min
407
+ self.eps_max = eps_max
408
+
409
+ def forward(self, x: Tensor) -> Tensor:
410
+ assert x.shape[self.channel_dim] == self.num_channels
411
+ eps = self.eps
412
+ if self.training and random.random() < 0.25:
413
+ # with probability 0.25, in training mode, clamp eps between the min
+ # and max; this will encourage it to learn parameters within the
+ # allowed range by making parameters that are outside the allowed
+ # range noisy; the unclamped iterations still provide gradients to allow
+ # the parameter to get back into the allowed region if it happens to exit it.
420
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
421
+ scales = (
422
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
423
+ ) ** -0.5
424
+ return x * scales
425
+
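+ # Usage sketch (illustrative only): BasicNorm rescales by the RMS over the
+ # channel dimension plus a learned epsilon; it has no per-channel weight/bias.
+ # The helper name is hypothetical.
+ def _example_basic_norm():
+     norm = BasicNorm(num_channels=256, channel_dim=-1)
+     x = torch.randn(10, 50, 256)
+     y = norm(x)
+     # at initialization eps = 0.25, so effectively
+     # y = x * (mean(x**2, dim=-1, keepdim=True) + 0.25) ** -0.5
+     print(y.shape)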
426
+
427
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
428
+ """
429
+ Behaves like a constructor of a modified version of nn.Linear
430
+ that gives an easy way to set the default initial parameter scale.
431
+
432
+ Args:
433
+ Accepts the standard args and kwargs that nn.Linear accepts
434
+ e.g. in_features, out_features, bias=False.
435
+
436
+ initial_scale: you can override this if you want to increase
437
+ or decrease the initial magnitude of the module's output
438
+ (affects the initialization of weight_scale and bias_scale).
439
+ Another option, if you want to do something like this, is
440
+ to re-initialize the parameters.
441
+ """
442
+ ans = nn.Linear(*args, **kwargs)
443
+ with torch.no_grad():
444
+ ans.weight[:] *= initial_scale
445
+ if ans.bias is not None:
446
+ torch.nn.init.uniform_(
447
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
448
+ )
449
+ return ans
450
+
451
+
452
+ def ScaledConv1d(
453
+ *args,
454
+ initial_scale: float = 1.0,
455
+ kernel_size: int = 3,
456
+ padding: str = "same",
457
+ **kwargs,
458
+ ) -> nn.Conv1d:
459
+ """
460
+ Behaves like a constructor of a modified version of nn.Conv1d
461
+ that gives an easy way to set the default initial parameter scale.
462
+
463
+ Args:
464
+ Accepts the standard args and kwargs that nn.Linear accepts
465
+ e.g. in_features, out_features, bias=False.
466
+
467
+ initial_scale: you can override this if you want to increase
468
+ or decrease the initial magnitude of the module's output
469
+ (affects the initialization of weight_scale and bias_scale).
470
+ Another option, if you want to do something like this, is
471
+ to re-initialize the parameters.
472
+ """
473
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
474
+ with torch.no_grad():
475
+ ans.weight[:] *= initial_scale
476
+ if ans.bias is not None:
477
+ torch.nn.init.uniform_(
478
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
479
+ )
480
+ return ans
481
+
482
+
483
+ def TransposeScaledConv1d(
484
+ *args,
485
+ initial_scale: float = 1.0,
486
+ kernel_size: int = 3,
487
+ padding: str = "same",
488
+ **kwargs,
489
+ ) -> nn.Sequential:
490
+ """
491
+ Transpose -> ScaledConv1d
492
+ """
493
+ return nn.Sequential(
494
+ Transpose(),
495
+ ScaledConv1d(
496
+ *args,
497
+ initial_scale=initial_scale,
498
+ kernel_size=kernel_size,
499
+ padding=padding,
500
+ **kwargs,
501
+ ),
502
+ )
503
+
504
+
505
+ def ScaledConv1dTranspose(
506
+ *args,
507
+ initial_scale: float = 1.0,
508
+ kernel_size: int = 3,
509
+ padding: str = "same",
510
+ **kwargs,
511
+ ) -> nn.Sequential:
512
+ """
513
+ ScaledConv1d -> Transpose
514
+ """
515
+ return nn.Sequential(
516
+ ScaledConv1d(
517
+ *args,
518
+ initial_scale=initial_scale,
519
+ kernel_size=kernel_size,
520
+ padding=padding,
521
+ **kwargs,
522
+ ),
523
+ Transpose(),
524
+ )
525
+
526
+
527
+ def TransposeConv1d(
528
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
529
+ ) -> nn.Sequential:
530
+ """
531
+ Transpose -> Conv1d
532
+ """
533
+ return nn.Sequential(
534
+ Transpose(),
535
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
536
+ )
537
+
538
+
539
+ def Conv1dTranspose(
540
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
541
+ ) -> nn.Sequential:
542
+ """
543
+ Conv1d -> Transpose
544
+ """
545
+ return nn.Sequential(
546
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
547
+ Transpose(),
548
+ )
549
+
550
+
551
+ class SRLinear(nn.Linear):
552
+ """https://arxiv.org/abs/2303.06296
553
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
554
+ """
555
+
556
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
557
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
558
+ self.register_buffer(
559
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
560
+ )
561
+ with torch.no_grad():
562
+ sigma = self.get_sigma()
563
+ self.register_buffer("spectral_norm", sigma)
564
+ self.sigma = nn.Parameter(torch.ones(1))
565
+
566
+ def get_sigma(self):
567
+ with torch.no_grad():
568
+ u = self.u
569
+ v = self.weight.mv(u)
570
+ v = nn.functional.normalize(v, dim=0)
571
+ u = self.weight.T.mv(v)
572
+ u = nn.functional.normalize(u, dim=0)
573
+ self.u.data.copy_(u)
574
+ return torch.einsum("c,cd,d->", v, self.weight, u)
575
+
576
+ def get_weight(self):
577
+ sigma = self.get_sigma()
578
+ if self.training:
579
+ self.spectral_norm.data.copy_(sigma)
580
+ weight = (self.sigma / sigma) * self.weight
581
+ return weight
582
+
583
+ def forward(self, x):
584
+ return nn.functional.linear(x, self.get_weight(), self.bias)
585
+
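+ # Sanity sketch (illustrative only): get_sigma() performs one power-iteration
+ # step, so repeated calls converge to the largest singular value of the weight
+ # matrix. The helper name is hypothetical.
+ def _example_srlinear_sigma():
+     layer = SRLinear(32, 64)
+     for _ in range(20):               # refine the power-iteration estimate
+         sigma = layer.get_sigma()
+     true_sigma = torch.linalg.matrix_norm(layer.weight, ord=2)
+     print(sigma.item(), true_sigma.item())   # should be close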
586
+
587
+ class SRConv1d(SRLinear):
588
+ def __init__(
589
+ self,
590
+ in_features,
591
+ out_features,
592
+ kernel_size,
593
+ stride: int = 1,
594
+ padding: str = "same",
595
+ bias: bool = True,
596
+ **kwargs,
597
+ ):
598
+ in_features = in_features * kernel_size
599
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
600
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
601
+ self.kernel_size = kernel_size
602
+ self.stride = stride
603
+ self.padding = padding
604
+
605
+ def forward(self, x):
606
+ in_features = self.in_features // self.kernel_size
607
+ weight = self.get_weight().view(
608
+ self.out_features, in_features, self.kernel_size
609
+ )
610
+ return nn.functional.conv1d(
611
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
612
+ )
613
+
614
+
615
+ def TransposeSRConv1d(
616
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
617
+ ) -> nn.Sequential:
618
+ """
619
+ Transpose -> SRConv1d
620
+ """
621
+ return nn.Sequential(
622
+ Transpose(),
623
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
624
+ )
625
+
626
+
627
+ def SRConv1dTranspose(
628
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
629
+ ) -> nn.Sequential:
630
+ """
631
+ SRConv1d -> Transpose
632
+ """
633
+ return nn.Sequential(
634
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
635
+ Transpose(),
636
+ )
637
+
638
+
639
+ class ActivationBalancer(torch.nn.Module):
640
+ """
641
+ Modifies the backpropped derivatives of a function to try to encourage, for
642
+ each channel, that it is positive at least a proportion `threshold` of the
643
+ time. It does this by multiplying negative derivative values by up to
644
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
645
+ interpolated from 1 at the threshold to those extremal values when none
646
+ of the inputs are positive.
647
+
648
+ Args:
649
+ num_channels: the number of channels
650
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
651
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
652
+ min_positive: the minimum, per channel, of the proportion of the time
653
+ that (x > 0), below which we start to modify the derivatives.
654
+ max_positive: the maximum, per channel, of the proportion of the time
655
+ that (x > 0), above which we start to modify the derivatives.
656
+ max_factor: the maximum factor by which we modify the derivatives for
657
+ either the sign constraint or the magnitude constraint;
658
+ e.g. with max_factor=0.02, the derivatives would be multiplied by
659
+ values in the range [0.98..1.02].
660
+ sign_gain_factor: determines the 'gain' with which we increase the
661
+ change in gradient once the constraints on min_positive and max_positive
662
+ are violated.
663
+ scale_gain_factor: determines the 'gain' with which we increase the
664
+ change in gradient once the constraints on min_abs and max_abs
665
+ are violated.
666
+ min_abs: the minimum average-absolute-value difference from the mean
667
+ value per channel, which we allow, before we start to modify
668
+ the derivatives to prevent this.
669
+ max_abs: the maximum average-absolute-value difference from the mean
670
+ value per channel, which we allow, before we start to modify
671
+ the derivatives to prevent this.
672
+ min_prob: determines the minimum probability with which we modify the
673
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
674
+ on each forward(). This is done randomly to prevent all layers
675
+ from doing it at the same time. Early in training we may use
676
+ higher probabilities than this; it will decay to this value.
677
+ """
678
+
679
+ def __init__(
680
+ self,
681
+ num_channels: int,
682
+ channel_dim: int,
683
+ min_positive: float = 0.05,
684
+ max_positive: float = 0.95,
685
+ max_factor: float = 0.04,
686
+ sign_gain_factor: float = 0.01,
687
+ scale_gain_factor: float = 0.02,
688
+ min_abs: float = 0.2,
689
+ max_abs: float = 100.0,
690
+ min_prob: float = 0.1,
691
+ ):
692
+ super(ActivationBalancer, self).__init__()
693
+ self.num_channels = num_channels
694
+ self.channel_dim = channel_dim
695
+ self.min_positive = min_positive
696
+ self.max_positive = max_positive
697
+ self.max_factor = max_factor
698
+ self.min_abs = min_abs
699
+ self.max_abs = max_abs
700
+ self.min_prob = min_prob
701
+ self.sign_gain_factor = sign_gain_factor
702
+ self.scale_gain_factor = scale_gain_factor
703
+
704
+ # count measures how many times the forward() function has been called.
705
+ # We occasionally sync this to a tensor called `count`, that exists to
706
+ # make sure it is synced to disk when we load and save the model.
707
+ self.cpu_count = 0
708
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
709
+
710
+ def forward(self, x: Tensor) -> Tensor:
711
+ if (
712
+ torch.jit.is_scripting()
713
+ or not x.requires_grad
714
+ or torch.jit.is_tracing()
715
+ ):
716
+ return _no_op(x)
717
+
718
+ count = self.cpu_count
719
+ self.cpu_count += 1
720
+
721
+ if random.random() < 0.01:
722
+ # Occasionally sync self.cpu_count with self.count.
723
+ # count affects the decay of 'prob'. don't do this on every iter,
724
+ # because syncing with the GPU is slow.
725
+ self.cpu_count = max(self.cpu_count, self.count.item())
726
+ self.count.fill_(self.cpu_count)
727
+
728
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
729
+ # a floor at min_prob (==0.1, by default)
730
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
731
+
732
+ if random.random() < prob:
733
+ sign_gain_factor = 0.5
734
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
735
+ sign_factor = _compute_sign_factor(
736
+ x,
737
+ self.channel_dim,
738
+ self.min_positive,
739
+ self.max_positive,
740
+ gain_factor=self.sign_gain_factor / prob,
741
+ max_factor=self.max_factor,
742
+ )
743
+ else:
744
+ sign_factor = None
745
+
746
+ scale_factor = _compute_scale_factor(
747
+ x.detach(),
748
+ self.channel_dim,
749
+ min_abs=self.min_abs,
750
+ max_abs=self.max_abs,
751
+ gain_factor=self.scale_gain_factor / prob,
752
+ max_factor=self.max_factor,
753
+ )
754
+ return ActivationBalancerFunction.apply(
755
+ x,
756
+ scale_factor,
757
+ sign_factor,
758
+ self.channel_dim,
759
+ )
760
+ else:
761
+ return _no_op(x)
762
+
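+ # Usage sketch (illustrative only): ActivationBalancer is an identity in the
+ # forward pass; it only (randomly, during training) modifies the backpropagated
+ # gradients to push activations toward the configured sign/magnitude ranges.
+ # The helper name is hypothetical.
+ def _example_activation_balancer():
+     balancer = ActivationBalancer(num_channels=64, channel_dim=-1,
+                                   min_positive=0.05, max_positive=0.95)
+     x = torch.randn(8, 20, 64, requires_grad=True)
+     y = balancer(x)
+     assert torch.equal(y, x)          # forward output is unchanged
+     (y * torch.randn_like(y)).sum().backward()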
763
+
764
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
765
+ """
766
+ Returns x unmodified, but in backprop will put a penalty for the excess of
767
+ the absolute values of elements of x over the limit "limit". E.g. if
768
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
769
+
770
+ Caution: the value of this penalty will be affected by grad scaling used
771
+ in automatic mixed precision training. For the purposes we use it for here,
+ that shouldn't really matter, or may even be helpful; we just use this
773
+ to disallow really implausible values of scores to be given to softmax.
774
+ """
775
+ x_sign = x.sign()
776
+ over_limit = (x.abs() - limit) > 0
777
+ # The following is a memory efficient way to penalize the absolute values of
778
+ # x that's over the limit. (The memory efficiency comes when you think
779
+ # about which items torch needs to cache for the autograd, and which ones it
780
+ # can throw away). The numerical value of aux_loss as computed here will
781
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
782
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
783
+ # limit).relu().
784
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
785
+ # note: we don't do sum() here on aux_loss, but it's as if we had done
786
+ # sum() due to how with_loss() works.
787
+ x = with_loss(x, aux_loss)
788
+ # you must use x for something, or this will be ineffective.
789
+ return x
790
+
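+ # Sanity sketch (illustrative only): the forward value is unchanged, but entries
+ # whose magnitude exceeds `limit` receive an extra gradient of penalty * sign(x).
+ # The helper name is hypothetical.
+ def _example_penalize_abs_values_gt():
+     x = torch.tensor([0.5, 20.0, -30.0], requires_grad=True)
+     y = penalize_abs_values_gt(x, limit=10.0, penalty=1.0e-03)
+     y.sum().backward()
+     # expected grad: [1.0, 1.0 + 1e-3, 1.0 - 1e-3]
+     print(x.grad)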
791
+
792
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
793
+ if x.ndim == 2:
794
+ return x.diag()
795
+ else:
796
+ (batch, dim, dim) = x.shape
797
+ x = x.reshape(batch, dim * dim)
798
+ x = x[:, :: dim + 1]
799
+ assert x.shape == (batch, dim)
800
+ return x
801
+
802
+
803
+ def _whitening_metric(x: Tensor, num_groups: int):
804
+ """
805
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
+ of the centered feature covariance are the same within each group's covariance matrix
807
+ and also between groups.
808
+ Args:
809
+ x: a Tensor of shape (*, num_channels)
810
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
811
+ Returns:
812
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
813
+ greater than 1.0 otherwise.
814
+ """
815
+ assert x.dtype != torch.float16
816
+ x = x.reshape(-1, x.shape[-1])
817
+ (num_frames, num_channels) = x.shape
818
+ assert num_channels % num_groups == 0
819
+ channels_per_group = num_channels // num_groups
820
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
821
+ # x now has shape (num_groups, num_frames, channels_per_group)
822
+ # subtract the mean so we use the centered, not uncentered, covariance.
823
+ # My experience has been that when we "mess with the gradients" like this,
824
+ # it's better not do anything that tries to move the mean around, because
825
+ # that can easily cause instability.
826
+ x = x - x.mean(dim=1, keepdim=True)
827
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
828
+ x_covar = torch.matmul(x.transpose(1, 2), x)
829
+ x_covar_mean_diag = _diag(x_covar).mean()
830
+ # the following expression is what we'd get if we took the matrix product
831
+ # of each covariance and measured the mean of its trace, i.e.
832
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
833
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
834
+ num_groups * channels_per_group
835
+ )
836
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
837
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
838
+ return metric
839
+
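+ # Sanity sketch (illustrative only): the metric is close to 1.0 for isotropic
+ # ("white") features and grows when one direction dominates the covariance.
+ # The helper name is hypothetical.
+ def _example_whitening_metric():
+     white = torch.randn(2000, 64)
+     skewed = white.clone()
+     skewed[:, 0] *= 10.0              # inflate the variance of one channel
+     print(_whitening_metric(white, num_groups=1).item())    # ~1.0
+     print(_whitening_metric(skewed, num_groups=1).item())   # noticeably > 1.0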
840
+
841
+ class WhiteningPenaltyFunction(torch.autograd.Function):
842
+ @staticmethod
843
+ def forward(
844
+ ctx,
845
+ x: Tensor,
846
+ num_groups: int,
847
+ whitening_limit: float,
848
+ grad_scale: float,
849
+ ) -> Tensor:
850
+ ctx.save_for_backward(x)
851
+ ctx.num_groups = num_groups
852
+ ctx.whitening_limit = whitening_limit
853
+ ctx.grad_scale = grad_scale
854
+ return x
855
+
856
+ @staticmethod
857
+ def backward(ctx, x_grad: Tensor):
858
+ (x_orig,) = ctx.saved_tensors
859
+ with torch.enable_grad():
860
+ with torch.cuda.amp.autocast(enabled=False):
861
+ x_detached = x_orig.to(torch.float32).detach()
862
+ x_detached.requires_grad = True
863
+
864
+ metric = _whitening_metric(x_detached, ctx.num_groups)
865
+
866
+ if random.random() < 0.005 or __name__ == "__main__":
867
+ logging.info(
868
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
869
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
870
+ )
871
+
872
+ (metric - ctx.whitening_limit).relu().backward()
873
+ penalty_grad = x_detached.grad
874
+ scale = ctx.grad_scale * (
875
+ x_grad.to(torch.float32).norm()
876
+ / (penalty_grad.norm() + 1.0e-20)
877
+ )
878
+ penalty_grad = penalty_grad * scale
879
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
880
+
881
+
882
+ class Whiten(nn.Module):
883
+ def __init__(
884
+ self,
885
+ num_groups: int,
886
+ whitening_limit: float,
887
+ prob: Union[float, Tuple[float, float]],
888
+ grad_scale: float,
889
+ ):
890
+ """
891
+ Args:
892
+ num_groups: the number of groups to divide the channel dim into before
893
+ whitening. We will attempt to make the feature covariance
894
+ within each group, after mean subtraction, as "white" as possible,
895
+ while having the same trace across all groups.
896
+ whitening_limit: a value greater than 1.0, that dictates how much
897
+ freedom we have to violate the constraints. 1.0 would mean perfectly
898
+ white, with exactly the same trace across groups; larger values
899
+ give more freedom. E.g. 2.0.
900
+ prob: the probability with which we apply the gradient modification
901
+ (also affects the grad scale). May be supplied as a float,
902
+ or as a pair (min_prob, max_prob)
903
+
904
+ grad_scale: determines the scale on the gradient term from this object,
905
+ relative to the rest of the gradient on the attention weights.
906
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
907
+ """
908
+ super(Whiten, self).__init__()
909
+ assert num_groups >= 1
910
+ assert whitening_limit >= 1
911
+ assert grad_scale >= 0
912
+ self.num_groups = num_groups
913
+ self.whitening_limit = whitening_limit
914
+ if isinstance(prob, float):
915
+ assert 0 < prob <= 1
916
+ self.prob = prob
917
+ else:
918
+ (self.min_prob, self.max_prob) = prob
919
+ assert 0 < self.min_prob < self.max_prob <= 1
920
+ self.prob = self.max_prob
921
+
922
+ self.grad_scale = grad_scale
923
+
924
+ def forward(self, x: Tensor) -> Tensor:
925
+ """
926
+ In the forward pass, this function just returns the input unmodified.
927
+ In the backward pass, it will modify the gradients to ensure that the
928
+ distribution in each group has close to (lambda times I) as the covariance
929
+ after mean subtraction, with the same lambda across groups.
930
+ For whitening_limit > 1, there will be more freedom to violate this
931
+ constraint.
932
+
933
+ Args:
934
+ x: the input of shape (*, num_channels)
935
+
936
+ Returns:
937
+ x, unmodified. You should make sure
938
+ you use the returned value, or the graph will be freed
939
+ and nothing will happen in backprop.
940
+ """
941
+ if (
942
+ not x.requires_grad
943
+ or random.random() > self.prob
944
+ or self.grad_scale == 0
945
+ ):
946
+ return _no_op(x)
947
+ else:
948
+ if hasattr(self, "min_prob") and random.random() < 0.25:
949
+ # occasionally switch between min_prob and max_prob, based on whether
950
+ # we are above or below the threshold.
951
+ if (
952
+ _whitening_metric(x.to(torch.float32), self.num_groups)
953
+ > self.whitening_limit
954
+ ):
955
+ # there would be a change to the grad.
956
+ self.prob = self.max_prob
957
+ else:
958
+ self.prob = self.min_prob
959
+
960
+ return WhiteningPenaltyFunction.apply(
961
+ x, self.num_groups, self.whitening_limit, self.grad_scale
962
+ )
963
+
964
+
965
+ class WithLoss(torch.autograd.Function):
966
+ @staticmethod
967
+ def forward(ctx, x: Tensor, y: Tensor):
968
+ ctx.y_shape = y.shape
969
+ return x
970
+
971
+ @staticmethod
972
+ def backward(ctx, ans_grad: Tensor):
973
+ return ans_grad, torch.ones(
974
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
975
+ )
976
+
977
+
978
+ def with_loss(x, y):
979
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
980
+ return x
981
+ # returns x but adds y.sum() to the loss function.
982
+ return WithLoss.apply(x, y)
983
+
984
+
985
+ def _no_op(x: Tensor) -> Tensor:
986
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
987
+ return x
988
+ else:
989
+ # a no-op function that will have a node in the autograd graph,
990
+ # to avoid certain bugs relating to backward hooks
991
+ return x.chunk(1, dim=-1)[0]
992
+
993
+
994
+ class Identity(torch.nn.Module):
995
+ def __init__(self):
996
+ super(Identity, self).__init__()
997
+
998
+ def forward(self, x):
999
+ return _no_op(x)
1000
+
1001
+
1002
+ class MaxEig(torch.nn.Module):
1003
+ """
1004
+ Modifies the backpropped derivatives of a function to try to discourage
1005
+ that any given direction in activation space accounts for more than
1006
+ a specified proportion of the covariance (e.g. 0.2).
1007
+
1008
+
1009
+ Args:
1010
+ num_channels: the number of channels
1011
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
1012
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1013
+ max_var_per_eig: the maximum proportion of the variance of the
1014
+ features/channels, after mean subtraction, that can come from
1015
+ any given eigenvalue.
1016
+ min_prob: the minimum probability with which we apply this during any invocation
1017
+ of forward(), assuming last time we applied the constraint it was
1018
+ not active; supplied for speed.
1019
+ scale: determines the scale with which we modify the gradients, relative
1020
+ to the existing / unmodified gradients
1021
+ """
1022
+
1023
+ def __init__(
1024
+ self,
1025
+ num_channels: int,
1026
+ channel_dim: int,
1027
+ max_var_per_eig: float = 0.2,
1028
+ min_prob: float = 0.01,
1029
+ scale: float = 0.01,
1030
+ ):
1031
+ super(MaxEig, self).__init__()
1032
+ self.num_channels = num_channels
1033
+ self.channel_dim = channel_dim
1034
+ self.scale = scale
1035
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1036
+ self.max_var_per_eig = max_var_per_eig
1037
+
1038
+ # we figure out the dominant direction using the power method: starting with
1039
+ # a random vector, keep multiplying by the covariance and renormalizing.
1040
+ with torch.no_grad():
1041
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1042
+ # random parameters unchanged for comparison
1043
+ direction = torch.arange(num_channels).to(torch.float)
1044
+ direction = direction / direction.norm()
1045
+ self.register_buffer("max_eig_direction", direction)
1046
+
1047
+ self.min_prob = min_prob
1048
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1049
+ # We'll regress this towards prob, each time we try to apply it and it is not
1050
+ # active.
1051
+ self.cur_prob = 1.0
1052
+
1053
+ def forward(self, x: Tensor) -> Tensor:
1054
+ if (
1055
+ torch.jit.is_scripting()
1056
+ or self.max_var_per_eig <= 0
1057
+ or random.random() > self.cur_prob
1058
+ or torch.jit.is_tracing()
1059
+ ):
1060
+ return _no_op(x)
1061
+
1062
+ with torch.cuda.amp.autocast(enabled=False):
1063
+ eps = 1.0e-20
1064
+ orig_x = x
1065
+ x = x.to(torch.float32)
1066
+ with torch.no_grad():
1067
+ x = x.transpose(self.channel_dim, -1).reshape(
1068
+ -1, self.num_channels
1069
+ )
1070
+ x = x - x.mean(dim=0)
1071
+ new_direction, coeffs = self._find_direction_coeffs(
1072
+ x, self.max_eig_direction
1073
+ )
1074
+ x_var = (x ** 2).mean()
1075
+ x_residual = x - coeffs * new_direction
1076
+ x_residual_var = (x_residual ** 2).mean()
1077
+
1078
+ # `variance_proportion` is the proportion of the variance accounted for
1079
+ # by the top eigen-direction.
1080
+ variance_proportion = (x_var - x_residual_var) / (
1081
+ x_var + 1.0e-20
1082
+ )
1083
+
1084
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1085
+ self._set_direction(
1086
+ 0.1 * self.max_eig_direction + new_direction
1087
+ )
1088
+
1089
+ if random.random() < 0.01 or __name__ == "__main__":
1090
+ logging.info(
1091
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1092
+ )
1093
+
1094
+ if variance_proportion >= self.max_var_per_eig:
1095
+ # The constraint is active. Note, we should quite rarely
1096
+ # reach here, only near the beginning of training if we are
1097
+ # starting to diverge, should this constraint be active.
1098
+ cur_prob = self.cur_prob
1099
+ self.cur_prob = (
1100
+ 1.0 # next time, do the update with probability 1.0.
1101
+ )
1102
+ return MaxEigLimiterFunction.apply(
1103
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1104
+ )
1105
+ else:
1106
+ # let self.cur_prob exponentially approach self.min_prob, as
1107
+ # long as the constraint is inactive.
1108
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1109
+ return orig_x
1110
+
1111
+ def _set_direction(self, direction: Tensor):
1112
+ """
1113
+ Sets self.max_eig_direction to a normalized version of `direction`
1114
+ """
1115
+ direction = direction.detach()
1116
+ direction = direction / direction.norm()
1117
+ direction_sum = direction.sum().item()
1118
+ if direction_sum - direction_sum == 0: # no inf/nan
1119
+ self.max_eig_direction[:] = direction
1120
+ else:
1121
+ logging.info(
1122
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1123
+ f"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1124
+ )
1125
+
1126
+ def _find_direction_coeffs(
1127
+ self, x: Tensor, prev_direction: Tensor
1128
+ ) -> Tuple[Tensor, Tensor]:
1129
+ """
1130
+ Figure out (an approximation to) the proportion of the variance of a set of
1131
+ feature vectors that can be attributed to the top eigen-direction.
1132
+ Args:
1133
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1134
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1135
+ of the top eigen-direction, or a random direction if this is the first
1136
+ iteration. Does not have to be normalized, but should be nonzero.
1137
+
1138
+ Returns: (cur_direction, coeffs), where:
1139
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1140
+ estimate of the top eigen-direction.
1141
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1142
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1143
+ """
1144
+ (num_frames, num_channels) = x.shape
1145
+ assert num_channels > 1 and num_frames > 1
1146
+ assert prev_direction.shape == (num_channels,)
1147
+ # `coeffs` are the coefficients of `prev_direction` in x.
1148
+ # actually represent the coeffs up to a constant positive factor.
1149
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1150
+ cur_direction = (x * coeffs).sum(dim=0) / (
1151
+ (coeffs ** 2).sum() + 1.0e-20
1152
+ )
1153
+ return cur_direction, coeffs
1154
+
1155
+
1156
+ class DoubleSwishFunction(torch.autograd.Function):
1157
+ """
1158
+ double_swish(x) = x * torch.sigmoid(x-1)
1159
+ This is a definition, originally motivated by its close numerical
1160
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1161
+
1162
+ Memory-efficient derivative computation:
1163
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1164
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1165
+ Now, s'(x) = s(x) * (1-s(x)).
1166
+ double_swish'(x) = x * s'(x) + s(x).
1167
+ = x * s(x) * (1-s(x)) + s(x).
1168
+ = double_swish(x) * (1-s(x)) + s(x)
1169
+ ... so we just need to remember s(x) but not x itself.
1170
+ """
1171
+
1172
+ @staticmethod
1173
+ def forward(ctx, x: Tensor) -> Tensor:
1174
+ requires_grad = x.requires_grad
1175
+ x_dtype = x.dtype
1176
+ if x.dtype == torch.float16:
1177
+ x = x.to(torch.float32)
1178
+
1179
+ s = torch.sigmoid(x - 1.0)
1180
+ y = x * s
1181
+
1182
+ if requires_grad:
1183
+ deriv = y * (1 - s) + s
1184
+ # notes on derivative of x * sigmoid(x - 1):
1185
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1186
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
1187
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1188
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1189
+ # floors), should be expectation-preserving.
1190
+ floor = -0.043637
1191
+ ceil = 1.2
1192
+ d_scaled = (deriv - floor) * (
1193
+ 255.0 / (ceil - floor)
1194
+ ) + torch.rand_like(deriv)
1195
+ if __name__ == "__main__":
1196
+ # for self-testing only.
1197
+ assert d_scaled.min() >= 0.0
1198
+ assert d_scaled.max() < 256.0
1199
+ d_int = d_scaled.to(torch.uint8)
1200
+ ctx.save_for_backward(d_int)
1201
+ if x_dtype == torch.float16 or torch.is_autocast_enabled():
1202
+ y = y.to(torch.float16)
1203
+ return y
1204
+
1205
+ @staticmethod
1206
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1207
+ (d,) = ctx.saved_tensors
1208
+ # the same constants as used in forward pass.
1209
+ floor = -0.043637
1210
+ ceil = 1.2
1211
+ d = d * ((ceil - floor) / 255.0) + floor
1212
+ return y_grad * d
1213
+
1214
+
1215
+ class DoubleSwish(torch.nn.Module):
1216
+ def forward(self, x: Tensor) -> Tensor:
1217
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1218
+ that we approximate closely with x * sigmoid(x-1).
1219
+ """
1220
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1221
+ return x * torch.sigmoid(x - 1.0)
1222
+ return DoubleSwishFunction.apply(x)
1223
+
1224
+
1225
+ def BalancedDoubleSwish(
1226
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1227
+ ) -> nn.Sequential:
1228
+ """
1229
+ ActivationBalancer -> DoubleSwish
1230
+ """
1231
+ balancer = ActivationBalancer(
1232
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1233
+ )
1234
+ return nn.Sequential(
1235
+ balancer,
1236
+ DoubleSwish(),
1237
+ )
1238
+
1239
+
1240
+ def _test_max_eig():
1241
+ for proportion in [0.1, 0.5, 10.0]:
1242
+ logging.info(f"proportion = {proportion}")
1243
+ x = torch.randn(100, 128)
1244
+ direction = torch.randn(128)
1245
+ coeffs = torch.randn(100, 1)
1246
+ x += proportion * direction * coeffs
1247
+
1248
+ x.requires_grad = True
1249
+
1250
+ num_channels = 128
1251
+ m = MaxEig(
1252
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1253
+ ) # grad_scale
1254
+
1255
+ for _ in range(4):
1256
+ y = m(x)
1257
+
1258
+ y_grad = torch.randn_like(x)
1259
+ y.backward(gradient=y_grad)
1260
+
1261
+ if proportion < 0.2:
1262
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1263
+ elif proportion > 1.0:
1264
+ assert not torch.allclose(x.grad, y_grad)
1265
+
1266
+
1267
+ def _test_whiten():
1268
+ for proportion in [0.1, 0.5, 10.0]:
1269
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1270
+ x = torch.randn(100, 128)
1271
+ direction = torch.randn(128)
1272
+ coeffs = torch.randn(100, 1)
1273
+ x += proportion * direction * coeffs
1274
+
1275
+ x.requires_grad = True
1276
+
1277
+ num_channels = 128
1278
+ m = Whiten(
1279
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1280
+ ) # grad_scale
1281
+
1282
+ for _ in range(4):
1283
+ y = m(x)
1284
+
1285
+ y_grad = torch.randn_like(x)
1286
+ y.backward(gradient=y_grad)
1287
+
1288
+ if proportion < 0.2:
1289
+ assert torch.allclose(x.grad, y_grad)
1290
+ elif proportion > 1.0:
1291
+ assert not torch.allclose(x.grad, y_grad)
1292
+
1293
+
1294
+ def _test_activation_balancer_sign():
1295
+ probs = torch.arange(0, 1, 0.01)
1296
+ N = 1000
1297
+ x = 1.0 * (
1298
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1299
+ )
1300
+ x = x.detach()
1301
+ x.requires_grad = True
1302
+ m = ActivationBalancer(
1303
+ probs.numel(),
1304
+ channel_dim=0,
1305
+ min_positive=0.05,
1306
+ max_positive=0.95,
1307
+ max_factor=0.2,
1308
+ min_abs=0.0,
1309
+ )
1310
+
1311
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1312
+
1313
+ y = m(x)
1314
+ y.backward(gradient=y_grad)
1315
+ print("_test_activation_balancer_sign: x = ", x)
1316
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1317
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1318
+
1319
+
1320
+ def _test_activation_balancer_magnitude():
1321
+ magnitudes = torch.arange(0, 1, 0.01)
1322
+ N = 1000
1323
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1324
+ -1
1325
+ )
1326
+ x = x.detach()
1327
+ x.requires_grad = True
1328
+ m = ActivationBalancer(
1329
+ magnitudes.numel(),
1330
+ channel_dim=0,
1331
+ min_positive=0.0,
1332
+ max_positive=1.0,
1333
+ max_factor=0.2,
1334
+ min_abs=0.2,
1335
+ max_abs=0.8,
1336
+ min_prob=1.0,
1337
+ )
1338
+
1339
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1340
+
1341
+ y = m(x)
1342
+ y.backward(gradient=y_grad)
1343
+ print("_test_activation_balancer_magnitude: x = ", x)
1344
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1345
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1346
+
1347
+
1348
+ def _test_basic_norm():
1349
+ num_channels = 128
1350
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1351
+
1352
+ x = torch.randn(500, num_channels)
1353
+
1354
+ y = m(x)
1355
+
1356
+ assert y.shape == x.shape
1357
+ x_rms = (x ** 2).mean().sqrt()
1358
+ y_rms = (y ** 2).mean().sqrt()
1359
+ print("x rms = ", x_rms)
1360
+ print("y rms = ", y_rms)
1361
+ assert y_rms < x_rms
1362
+ assert y_rms > 0.5 * x_rms
1363
+
1364
+
1365
+ def _test_double_swish_deriv():
1366
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1367
+ x.requires_grad = True
1368
+ m = DoubleSwish()
1369
+
1370
+ tol = (1.2 - (-0.043637)) / 255.0
1371
+ torch.autograd.gradcheck(m, x, atol=tol)
1372
+
1373
+ # for self-test.
1374
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1375
+ x.requires_grad = True
1376
+ y = m(x)
1377
+
1378
+
1379
+ def _test_softmax():
1380
+ a = torch.randn(2, 10, dtype=torch.float64)
1381
+ b = a.clone()
1382
+ a.requires_grad = True
1383
+ b.requires_grad = True
1384
+ a.softmax(dim=1)[:, 0].sum().backward()
1385
+ print("a grad = ", a.grad)
1386
+ softmax(b, dim=1)[:, 0].sum().backward()
1387
+ print("b grad = ", b.grad)
1388
+ assert torch.allclose(a.grad, b.grad)
1389
+
1390
+
1391
+ if __name__ == "__main__":
1392
+ logging.getLogger().setLevel(logging.INFO)
1393
+ torch.set_num_threads(1)
1394
+ torch.set_num_interop_threads(1)
1395
+ _test_softmax()
1396
+ _test_whiten()
1397
+ _test_max_eig()
1398
+ _test_activation_balancer_sign()
1399
+ _test_activation_balancer_magnitude()
1400
+ _test_basic_norm()
1401
+ _test_double_swish_deriv()
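
A minimal sanity-check sketch for the double-swish definition above, using only plain PyTorch (no repo imports); the printed gap is illustrative, not a documented bound.

import torch

def swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # Same definition as DoubleSwishFunction above: x * sigmoid(x - 1).
    return x * torch.sigmoid(x - 1.0)

x = torch.linspace(-6.0, 6.0, steps=241)
gap = (double_swish(x) - swish(swish(x))).abs().max().item()
print(f"max |double_swish(x) - swish(swish(x))| on [-6, 6]: {gap:.4f}")
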
modules/scheduler.py ADDED
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import torch
20
+
21
+ from modules.optim import Eden
22
+
23
+
24
+ def calc_lr(step, dim_embed, warmup_steps):
25
+ return dim_embed ** (-0.5) * min(
26
+ step ** (-0.5), step * warmup_steps ** (-1.5)
27
+ )
28
+
29
+
30
+ class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
31
+ def __init__(
32
+ self,
33
+ base_lr: float,
34
+ optimizer: torch.optim.Optimizer,
35
+ dim_embed: int,
36
+ warmup_steps: int,
37
+ last_epoch: int = -1,
38
+ verbose: bool = False,
39
+ ) -> None:
40
+
41
+ self.dim_embed = dim_embed
42
+ self.base_lr = base_lr
43
+ self.warmup_steps = warmup_steps
44
+ self.num_param_groups = len(optimizer.param_groups)
45
+
46
+ super().__init__(optimizer, last_epoch, verbose)
47
+
48
+ def get_lr(self) -> float:
49
+ lr = self.base_lr * calc_lr(
50
+ self._step_count, self.dim_embed, self.warmup_steps
51
+ )
52
+ return [lr] * self.num_param_groups
53
+
54
+ def set_step(self, step: int):
55
+ self._step_count = step
56
+
57
+
58
+ def get_scheduler(params, optimizer):
59
+ if params.scheduler_name.lower() == "eden":
60
+ scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
61
+ elif params.scheduler_name.lower() == "noam":
62
+ scheduler = NoamScheduler(
63
+ params.base_lr,
64
+ optimizer,
65
+ params.decoder_dim,
66
+ warmup_steps=params.warmup_steps,
67
+ )
68
+ # scheduler.set_step(params.start_batch or params.batch_idx_train)
69
+ elif params.scheduler_name.lower() == "cosine":
70
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
71
+ params.warmup_steps,
72
+ optimizer,
73
+ eta_min=params.base_lr,
74
+ )
75
+ else:
76
+ raise NotImplementedError(f"{params.scheduler_name}")
77
+
78
+ return scheduler
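
A minimal usage sketch for the NoamScheduler defined above, assuming the repository root is on PYTHONPATH so that modules.scheduler imports; the model, dimensions, and step counts are illustrative.

import torch
from modules.scheduler import NoamScheduler  # assumes repo root on PYTHONPATH

model = torch.nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # lr is overridden by the scheduler
scheduler = NoamScheduler(base_lr=0.05, optimizer=optimizer, dim_embed=512, warmup_steps=200)

for step in range(1, 401):
    optimizer.step()
    scheduler.step()
    if step in (1, 100, 200, 400):
        # LR rises roughly linearly during warmup, peaks near warmup_steps, then decays ~ step**-0.5.
        print(step, scheduler.get_last_lr()[0])
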
modules/transformer.py ADDED
@@ -0,0 +1,683 @@
1
+ import copy
2
+ import numbers
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from torch.nn import functional as F
9
+
10
+ from .activation import MultiheadAttention
11
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
12
+ from .scaling import BasicNorm as _BasicNorm
13
+
14
+ _shape_t = Union[int, List[int], torch.Size]
15
+
16
+
17
+ class LayerNorm(nn.Module):
18
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
19
+ normalized_shape: Tuple[int, ...]
20
+ eps: float
21
+ elementwise_affine: bool
22
+
23
+ def __init__(
24
+ self,
25
+ normalized_shape: _shape_t,
26
+ eps: float = 1e-5,
27
+ elementwise_affine: bool = True,
28
+ device=None,
29
+ dtype=None,
30
+ ) -> None:
31
+ factory_kwargs = {"device": device, "dtype": dtype}
32
+ super(LayerNorm, self).__init__()
33
+ if isinstance(normalized_shape, numbers.Integral):
34
+ # mypy error: incompatible types in assignment
35
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
36
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
37
+ self.eps = eps
38
+ self.elementwise_affine = elementwise_affine
39
+ if self.elementwise_affine:
40
+ self.weight = nn.Parameter(
41
+ torch.empty(self.normalized_shape, **factory_kwargs)
42
+ )
43
+ self.bias = nn.Parameter(
44
+ torch.empty(self.normalized_shape, **factory_kwargs)
45
+ )
46
+ else:
47
+ self.register_parameter("weight", None)
48
+ self.register_parameter("bias", None)
49
+
50
+ self.reset_parameters()
51
+
52
+ def reset_parameters(self) -> None:
53
+ if self.elementwise_affine:
54
+ nn.init.ones_(self.weight)
55
+ nn.init.zeros_(self.bias)
56
+
57
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
58
+ if isinstance(input, tuple):
59
+ input, embedding = input
60
+ return (
61
+ F.layer_norm(
62
+ input,
63
+ self.normalized_shape,
64
+ self.weight,
65
+ self.bias,
66
+ self.eps,
67
+ ),
68
+ embedding,
69
+ )
70
+
71
+ assert embedding is None
72
+ return F.layer_norm(
73
+ input, self.normalized_shape, self.weight, self.bias, self.eps
74
+ )
75
+
76
+ def extra_repr(self) -> str:
77
+ return (
78
+ "{normalized_shape}, eps={eps}, "
79
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
80
+ )
81
+
82
+
83
+ class AdaptiveLayerNorm(nn.Module):
84
+ r"""Adaptive Layer Normalization"""
85
+
86
+ def __init__(self, d_model, norm) -> None:
87
+ super(AdaptiveLayerNorm, self).__init__()
88
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
89
+ self.norm = norm
90
+ self.d_model = d_model
91
+ self.eps = self.norm.eps
92
+
93
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
94
+ if isinstance(input, tuple):
95
+ input, embedding = input
96
+ weight, bias = torch.split(
97
+ self.project_layer(embedding),
98
+ split_size_or_sections=self.d_model,
99
+ dim=-1,
100
+ )
101
+ return (weight * self.norm(input) + bias, embedding)
102
+
103
+ weight, bias = torch.split(
104
+ self.project_layer(embedding),
105
+ split_size_or_sections=self.d_model,
106
+ dim=-1,
107
+ )
108
+ return weight * self.norm(input) + bias
109
+
110
+
111
+ class BasicNorm(_BasicNorm):
112
+ def __init__(
113
+ self,
114
+ d_model: int,
115
+ eps: float = 1e-5,
116
+ device=None,
117
+ dtype=None,
118
+ ):
119
+ super(BasicNorm, self).__init__(d_model, eps=eps)
120
+
121
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
122
+ if isinstance(input, tuple):
123
+ input, embedding = input
124
+ return (
125
+ super(BasicNorm, self).forward(input),
126
+ embedding,
127
+ )
128
+
129
+ assert embedding is None
130
+ return super(BasicNorm, self).forward(input)
131
+
132
+
133
+ class BalancedBasicNorm(nn.Module):
134
+ def __init__(
135
+ self,
136
+ d_model: int,
137
+ eps: float = 1e-5,
138
+ device=None,
139
+ dtype=None,
140
+ ):
141
+ super(BalancedBasicNorm, self).__init__()
142
+ self.balancer = ActivationBalancer(
143
+ d_model,
144
+ channel_dim=-1,
145
+ min_positive=0.45,
146
+ max_positive=0.55,
147
+ max_abs=6.0,
148
+ )
149
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
150
+
151
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
152
+ if isinstance(input, tuple):
153
+ input, embedding = input
154
+ return self.norm((self.balancer(input), embedding))
155
+
156
+ assert embedding is None
157
+ return self.norm(self.balancer(input))
158
+
159
+
160
+ class IdentityNorm(nn.Module):
161
+ def __init__(
162
+ self,
163
+ d_model: int,
164
+ eps: float = 1e-5,
165
+ device=None,
166
+ dtype=None,
167
+ ) -> None:
168
+ super(IdentityNorm, self).__init__()
169
+
170
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
171
+ if isinstance(input, tuple):
172
+ return input
173
+
174
+ assert embedding is None
175
+ return input
176
+
177
+
178
+ class TransformerEncoderLayer(nn.Module):
179
+ __constants__ = ["batch_first", "norm_first"]
180
+
181
+ def __init__(
182
+ self,
183
+ d_model: int,
184
+ nhead: int,
185
+ dim_feedforward: int = 2048,
186
+ dropout: float = 0.1,
187
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
188
+ batch_first: bool = False,
189
+ norm_first: bool = False,
190
+ device=None,
191
+ dtype=None,
192
+ linear1_self_attention_cls: nn.Module = nn.Linear,
193
+ linear2_self_attention_cls: nn.Module = nn.Linear,
194
+ linear1_feedforward_cls: nn.Module = nn.Linear,
195
+ linear2_feedforward_cls: nn.Module = nn.Linear,
196
+ layer_norm_cls: nn.Module = LayerNorm,
197
+ layer_norm_eps: float = 1e-5,
198
+ adaptive_layer_norm=False,
199
+ ) -> None:
200
+ factory_kwargs = {"device": device, "dtype": dtype}
201
+ super(TransformerEncoderLayer, self).__init__()
202
+ self.self_attn = MultiheadAttention(
203
+ d_model,
204
+ nhead,
205
+ dropout=dropout,
206
+ batch_first=batch_first,
207
+ linear1_cls=linear1_self_attention_cls,
208
+ linear2_cls=linear2_self_attention_cls,
209
+ **factory_kwargs,
210
+ )
211
+
212
+ # Implementation of Feedforward model
213
+ self.linear1 = linear1_feedforward_cls(
214
+ d_model, dim_feedforward, **factory_kwargs
215
+ )
216
+ self.dropout = nn.Dropout(dropout)
217
+ self.linear2 = linear2_feedforward_cls(
218
+ dim_feedforward, d_model, **factory_kwargs
219
+ )
220
+
221
+ self.norm_first = norm_first
222
+ self.dropout1 = nn.Dropout(dropout)
223
+ self.dropout2 = nn.Dropout(dropout)
224
+
225
+ # Legacy string support for activation function.
226
+ if isinstance(activation, str):
227
+ activation = _get_activation_fn(activation)
228
+ elif isinstance(activation, partial):
229
+ activation = activation(d_model)
230
+ elif activation == BalancedDoubleSwish:
231
+ activation = BalancedDoubleSwish(d_model)
232
+
233
+ # # We can't test self.activation in forward() in TorchScript,
234
+ # # so stash some information about it instead.
235
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
236
+ # self.activation_relu_or_gelu = 1
237
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
238
+ # self.activation_relu_or_gelu = 2
239
+ # else:
240
+ # self.activation_relu_or_gelu = 0
241
+ self.activation = activation
242
+
243
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
244
+ if layer_norm_cls == IdentityNorm:
245
+ norm2 = BalancedBasicNorm(
246
+ d_model, eps=layer_norm_eps, **factory_kwargs
247
+ )
248
+ else:
249
+ norm2 = layer_norm_cls(
250
+ d_model, eps=layer_norm_eps, **factory_kwargs
251
+ )
252
+
253
+ if adaptive_layer_norm:
254
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
255
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
256
+ else:
257
+ self.norm1 = norm1
258
+ self.norm2 = norm2
259
+
260
+ def __setstate__(self, state):
261
+ super(TransformerEncoderLayer, self).__setstate__(state)
262
+ if not hasattr(self, "activation"):
263
+ self.activation = F.relu
264
+
265
+ def forward(
266
+ self,
267
+ src: Tensor,
268
+ src_mask: Optional[Tensor] = None,
269
+ src_key_padding_mask: Optional[Tensor] = None,
270
+ ) -> Tensor:
271
+ r"""Pass the input through the encoder layer.
272
+
273
+ Args:
274
+ src: the sequence to the encoder layer (required).
275
+ src_mask: the mask for the src sequence (optional).
276
+ src_key_padding_mask: the mask for the src keys per batch (optional).
277
+
278
+ Shape:
279
+ see the docs in Transformer class.
280
+ """
281
+ x, stage_embedding = src, None
282
+ is_src_tuple = False
283
+ if isinstance(src, tuple):
284
+ x, stage_embedding = src
285
+ is_src_tuple = True
286
+
287
+ if src_key_padding_mask is not None:
288
+ _skpm_dtype = src_key_padding_mask.dtype
289
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
290
+ src_key_padding_mask
291
+ ):
292
+ raise AssertionError(
293
+ "only bool and floating types of key_padding_mask are supported"
294
+ )
295
+
296
+ if self.norm_first:
297
+ x = x + self._sa_block(
298
+ self.norm1(x, stage_embedding),
299
+ src_mask,
300
+ src_key_padding_mask,
301
+ )
302
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
303
+ else:
304
+ x = self.norm1(
305
+ x + self._sa_block(x, src_mask, src_key_padding_mask),
306
+ stage_embedding,
307
+ )
308
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
309
+
310
+ if is_src_tuple:
311
+ return (x, stage_embedding)
312
+ return x
313
+
314
+ def infer(
315
+ self,
316
+ src: Tensor,
317
+ src_mask: Optional[Tensor] = None,
318
+ src_key_padding_mask: Optional[Tensor] = None,
319
+ past_kv: Optional[Tensor] = None,
320
+ use_cache: bool = False,
321
+ ):
322
+ x, stage_embedding = src, None
323
+ is_src_tuple = False
324
+ if isinstance(src, tuple):
325
+ x, stage_embedding = src
326
+ is_src_tuple = True
327
+
328
+ if src_key_padding_mask is not None:
329
+ _skpm_dtype = src_key_padding_mask.dtype
330
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
331
+ src_key_padding_mask
332
+ ):
333
+ raise AssertionError(
334
+ "only bool and floating types of key_padding_mask are supported"
335
+ )
336
+
337
+ if self.norm_first:
338
+ x_attn_out, kv = self.self_attn.infer(
339
+ self.norm1(x, stage_embedding),
340
+ attn_mask=src_mask,
341
+ key_padding_mask=src_key_padding_mask,
342
+ need_weights=False,
343
+ past_kv=past_kv,
344
+ use_cache=use_cache,
345
+ )
346
+ x = x + x_attn_out
347
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
348
+
349
+ if is_src_tuple:
350
+ return (x, stage_embedding)
351
+ return (x, kv)
352
+
353
+ # self-attention block
354
+ def _sa_block(
355
+ self,
356
+ x: Tensor,
357
+ attn_mask: Optional[Tensor],
358
+ key_padding_mask: Optional[Tensor],
359
+ ) -> Tensor:
360
+ x = self.self_attn(
361
+ x,
362
+ x,
363
+ x,
364
+ attn_mask=attn_mask,
365
+ key_padding_mask=key_padding_mask,
366
+ need_weights=False,
367
+ )[0]
368
+ return self.dropout1(x)
369
+
370
+ # feed forward block
371
+ def _ff_block(self, x: Tensor) -> Tensor:
372
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
373
+ return self.dropout2(x)
374
+
375
+
376
+ class TransformerEncoder(nn.Module):
377
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
378
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
379
+
380
+ Args:
381
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
382
+ num_layers: the number of sub-encoder-layers in the encoder (required).
383
+ norm: the layer normalization component (optional).
384
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
385
+ (and convert back on output). This will improve the overall performance of
386
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
387
+
388
+ Examples::
389
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
390
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
391
+ >>> src = torch.rand(10, 32, 512)
392
+ >>> out = transformer_encoder(src)
393
+ """
394
+ __constants__ = ["norm"]
395
+
396
+ def __init__(self, encoder_layer, num_layers, norm=None):
397
+ super(TransformerEncoder, self).__init__()
398
+ self.layers = _get_clones(encoder_layer, num_layers)
399
+ self.num_layers = num_layers
400
+ self.norm = norm
401
+
402
+ def forward(
403
+ self,
404
+ src: Tensor,
405
+ mask: Optional[Tensor] = None,
406
+ src_key_padding_mask: Optional[Tensor] = None,
407
+ return_layer_states: bool = False,
408
+ ) -> Tensor:
409
+ r"""Pass the input through the encoder layers in turn.
410
+
411
+ Args:
412
+ src: the sequence to the encoder (required).
413
+ mask: the mask for the src sequence (optional).
414
+ src_key_padding_mask: the mask for the src keys per batch (optional).
415
+ return_layer_states: return layers' state (optional).
416
+
417
+ Shape:
418
+ see the docs in Transformer class.
419
+ """
420
+ if return_layer_states:
421
+ layer_states = [] # layers' output
422
+ output = src
423
+ for mod in self.layers:
424
+ output = mod(
425
+ output,
426
+ src_mask=mask,
427
+ src_key_padding_mask=src_key_padding_mask,
428
+ )
429
+ layer_states.append(output[0])
430
+
431
+ if self.norm is not None:
432
+ output = self.norm(output)
433
+
434
+ return layer_states, output
435
+
436
+ output = src
437
+ for mod in self.layers:
438
+ output = mod(
439
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
440
+ )
441
+
442
+ if self.norm is not None:
443
+ output = self.norm(output)
444
+
445
+ return output
446
+
447
+ def infer(
448
+ self,
449
+ src: Tensor,
450
+ mask: Optional[Tensor] = None,
451
+ src_key_padding_mask: Optional[Tensor] = None,
452
+ return_layer_states: bool = False,
453
+ past_kv: Optional[Tensor] = None,
454
+ use_cache: bool = False,
455
+ ):
456
+ if past_kv is None:
457
+ past_length = 0
458
+ past_kv = tuple([None] * self.num_layers)
459
+ else:
460
+ past_length = past_kv[0][0].size(-2)
461
+ new_kv = () if use_cache else None
462
+ output = src
463
+ for mod, past_layer_kv in zip(self.layers, past_kv):
464
+ output, kv = mod.infer(
465
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
466
+ )
467
+ if use_cache:
468
+ new_kv = new_kv + (kv,)
469
+
470
+ if self.norm is not None:
471
+ output = self.norm(output)
472
+
473
+ return output, new_kv
474
+
475
+
476
+ class TransformerDecoderLayer(nn.Module):
477
+ __constants__ = ["batch_first", "norm_first"]
478
+
479
+ def __init__(
480
+ self,
481
+ d_model: int,
482
+ nhead: int,
483
+ dim_feedforward: int = 2048,
484
+ dropout: float = 0.1,
485
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
486
+ linear1_self_attention_cls: nn.Module = nn.Linear,
487
+ linear2_self_attention_cls: nn.Module = nn.Linear,
488
+ linear1_feedforward_cls: nn.Module = nn.Linear,
489
+ linear2_feedforward_cls: nn.Module = nn.Linear,
490
+ batch_first: bool = False,
491
+ norm_first: bool = False,
492
+ device=None,
493
+ dtype=None,
494
+ layer_norm_cls: nn.Module = LayerNorm,
495
+ layer_norm_eps: float = 1e-5,
496
+ adaptive_layer_norm=False,
497
+ ) -> None:
498
+ factory_kwargs = {"device": device, "dtype": dtype}
499
+ super(TransformerDecoderLayer, self).__init__()
500
+ self.self_attn = MultiheadAttention(
501
+ d_model,
502
+ nhead,
503
+ dropout=dropout,
504
+ batch_first=batch_first,
505
+ linear1_cls=linear1_self_attention_cls,
506
+ linear2_cls=linear2_self_attention_cls,
507
+ **factory_kwargs,
508
+ )
509
+ self.multihead_attn = MultiheadAttention(
510
+ d_model,
511
+ nhead,
512
+ dropout=dropout,
513
+ batch_first=batch_first,
514
+ linear1_cls=linear1_self_attention_cls,
515
+ linear2_cls=linear2_self_attention_cls,
516
+ **factory_kwargs,
517
+ )
518
+ # Implementation of Feedforward model
519
+ self.linear1 = linear1_feedforward_cls(
520
+ d_model, dim_feedforward, **factory_kwargs
521
+ )
522
+ self.dropout = nn.Dropout(dropout)
523
+ self.linear2 = linear2_feedforward_cls(
524
+ dim_feedforward, d_model, **factory_kwargs
525
+ )
526
+
527
+ self.norm_first = norm_first
528
+ self.dropout1 = nn.Dropout(dropout)
529
+ self.dropout2 = nn.Dropout(dropout)
530
+ self.dropout3 = nn.Dropout(dropout)
531
+
532
+ # Legacy string support for activation function.
533
+ if isinstance(activation, str):
534
+ self.activation = _get_activation_fn(activation)
535
+ elif isinstance(activation, partial):
536
+ self.activation = activation(d_model)
537
+ elif activation == BalancedDoubleSwish:
538
+ self.activation = BalancedDoubleSwish(d_model)
539
+ else:
540
+ self.activation = activation
541
+
542
+ if adaptive_layer_norm:
543
+ norm1 = layer_norm_cls(
544
+ d_model, eps=layer_norm_eps, **factory_kwargs
545
+ )
546
+ norm2 = layer_norm_cls(
547
+ d_model, eps=layer_norm_eps, **factory_kwargs
548
+ )
549
+ norm3 = layer_norm_cls(
550
+ d_model, eps=layer_norm_eps, **factory_kwargs
551
+ )
552
+
553
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
554
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
555
+ self.norm3 = AdaptiveLayerNorm(d_model, norm3)
556
+ else:
557
+ self.norm1 = layer_norm_cls(
558
+ d_model, eps=layer_norm_eps, **factory_kwargs
559
+ )
560
+ self.norm2 = layer_norm_cls(
561
+ d_model, eps=layer_norm_eps, **factory_kwargs
562
+ )
563
+ if layer_norm_cls == IdentityNorm:
564
+ self.norm3 = BalancedBasicNorm(
565
+ d_model, eps=layer_norm_eps, **factory_kwargs
566
+ )
567
+ else:
568
+ self.norm3 = layer_norm_cls(
569
+ d_model, eps=layer_norm_eps, **factory_kwargs
570
+ )
571
+
572
+ def forward(
573
+ self,
574
+ tgt: Tensor,
575
+ memory: Tensor,
576
+ tgt_mask: Optional[Tensor] = None,
577
+ memory_mask: Optional[Tensor] = None,
578
+ tgt_key_padding_mask: Optional[Tensor] = None,
579
+ memory_key_padding_mask: Optional[Tensor] = None,
580
+ ) -> Tensor:
581
+ r"""Pass the inputs (and mask) through the decoder layer.
582
+
583
+ Args:
584
+ tgt: the sequence to the decoder layer (required).
585
+ memory: the sequence from the last layer of the encoder (required).
586
+ tgt_mask: the mask for the tgt sequence (optional).
587
+ memory_mask: the mask for the memory sequence (optional).
588
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
589
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
590
+
591
+ Shape:
592
+ see the docs in Transformer class.
593
+ """
594
+ tgt_is_tuple = False
595
+ if isinstance(tgt, tuple):
596
+ x, stage_embedding = tgt
597
+ tgt_is_tuple = True
598
+ else:
599
+ x, stage_embedding = tgt, None
600
+
601
+ if self.norm_first:
602
+ x = x + self._sa_block(
603
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
604
+ )
605
+ x = x + self._mha_block(
606
+ self.norm2(x, stage_embedding),
607
+ memory,
608
+ memory_mask,
609
+ memory_key_padding_mask,
610
+ )
611
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
612
+ else:
613
+ x = self.norm1(
614
+ x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
615
+ stage_embedding,
616
+ )
617
+ x = self.norm2(
618
+ x
619
+ + self._mha_block(
620
+ x, memory, memory_mask, memory_key_padding_mask
621
+ ),
622
+ stage_embedding,
623
+ )
624
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
625
+
626
+ if tgt_is_tuple:
627
+ return (x, stage_embedding)
628
+ return x
629
+
630
+ # self-attention block
631
+ def _sa_block(
632
+ self,
633
+ x: Tensor,
634
+ attn_mask: Optional[Tensor],
635
+ key_padding_mask: Optional[Tensor],
636
+ ) -> Tensor:
637
+ x = self.self_attn(
638
+ x,
639
+ x,
640
+ x,
641
+ attn_mask=attn_mask,
642
+ key_padding_mask=key_padding_mask,
643
+ need_weights=False,
644
+ )[0]
645
+ return self.dropout1(x)
646
+
647
+ # multihead attention block
648
+ def _mha_block(
649
+ self,
650
+ x: Tensor,
651
+ mem: Tensor,
652
+ attn_mask: Optional[Tensor],
653
+ key_padding_mask: Optional[Tensor],
654
+ ) -> Tensor:
655
+ x = self.multihead_attn(
656
+ x,
657
+ mem,
658
+ mem,
659
+ attn_mask=attn_mask,
660
+ key_padding_mask=key_padding_mask,
661
+ need_weights=False,
662
+ )[0]
663
+ return self.dropout2(x)
664
+
665
+ # feed forward block
666
+ def _ff_block(self, x: Tensor) -> Tensor:
667
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
668
+ return self.dropout3(x)
669
+
670
+
671
+ def _get_clones(module, N):
672
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
673
+
674
+
675
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
676
+ if activation == "relu":
677
+ return F.relu
678
+ elif activation == "gelu":
679
+ return F.gelu
680
+
681
+ raise RuntimeError(
682
+ "activation should be relu/gelu, not {}".format(activation)
683
+ )
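
A shape-check sketch for the encoder classes above, assuming modules.transformer (and its modules.activation dependency) import from the repo root; the sizes are arbitrary.

import torch
from modules.transformer import TransformerEncoder, TransformerEncoderLayer  # assumes repo root on PYTHONPATH

layer = TransformerEncoderLayer(
    d_model=256, nhead=4, dim_feedforward=1024, batch_first=True, norm_first=True
)
encoder = TransformerEncoder(layer, num_layers=2)

x = torch.randn(2, 50, 256)  # (batch, time, d_model) with batch_first=True
out = encoder(x)             # plain Tensor in, plain Tensor out
print(out.shape)             # expected: torch.Size([2, 50, 256])
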
prompts/promptsf ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,36 @@
1
+ torch
2
+ transformers
3
+ pydub
4
+ soundfile
5
+ numpy
6
+ torchvision
7
+ torchaudio
8
+ tokenizers
9
+ encodec
10
+ langid
11
+ wget
12
+ unidecode
13
+ pyopenjtalk-prebuilt
14
+ pypinyin
15
+ inflect
16
+ cn2an
17
+ jieba
18
+ eng_to_ipa
19
+ openai-whisper
20
+ matplotlib
21
+ gradio==3.41.2
22
+ nltk
23
+ sudachipy
24
+ sudachidict_core
25
+ vocos
27
+ accelerate
28
+ pyannote.audio
29
+ onnxruntime
30
+ fastapi
31
+ uvicorn[standard]
32
+ pytest
33
+ fastapi-cors
34
+ sqlalchemy
36
+
s2smodels.py ADDED
@@ -0,0 +1,19 @@
1
+ from sqlalchemy import Column, Integer, String, Float, BINARY
2
+ from sqlalchemy.ext.declarative import declarative_base
3
+ Base = declarative_base()
4
+
5
+ # create audio_segments table
6
+ class Audio_segment(Base):
7
+ __tablename__ = "audioSegments"
8
+ id = Column(Integer, primary_key=True)
9
+ start_time = Column(Float)
10
+ end_time = Column(Float)
11
+ type = Column(String)
12
+ audio=Column(BINARY)
13
+
14
+
15
+ # create audio_generation table
16
+ class AudioGeneration(Base):
17
+ __tablename__ = "audioGeneration"
18
+ id = Column(Integer, primary_key=True)
19
+ audio=Column(BINARY)
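
A small sketch of creating and querying these tables, assuming s2smodels.py is importable and using a hypothetical in-memory SQLite URL (the real application presumably configures its own engine).

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from s2smodels import Audio_segment, Base  # assumes s2smodels.py is on the path

engine = create_engine("sqlite:///:memory:")  # hypothetical database URL
Base.metadata.create_all(engine)              # creates audioSegments and audioGeneration

Session = sessionmaker(bind=engine)
with Session() as session:
    session.add(Audio_segment(start_time=0.0, end_time=1.5, type="speech", audio=b"\x00\x01"))
    session.commit()
    print(session.query(Audio_segment).count())  # 1
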
utils/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ # from icefall.utils import make_pad_mask
4
+
5
+ from .symbol_table import SymbolTable
6
+
7
+ # make_pad_mask = make_pad_mask
8
+ SymbolTable = SymbolTable
9
+
10
+
11
+ class Transpose(nn.Identity):
12
+ """(N, T, D) -> (N, D, T)"""
13
+
14
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
15
+ return input.transpose(1, 2)
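
A tiny usage sketch of the Transpose helper, assuming the utils package imports from the repo root.

import torch
from utils import Transpose  # assumes repo root on PYTHONPATH

t = Transpose()
x = torch.randn(8, 100, 80)  # (N, T, D)
print(t(x).shape)            # torch.Size([8, 80, 100]), i.e. (N, D, T)
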
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (915 Bytes). View file
 
utils/__pycache__/generation.cpython-311.pyc ADDED
Binary file (15.1 kB). View file
 
utils/__pycache__/prompt_making.cpython-311.pyc ADDED
Binary file (7 kB). View file
 
utils/__pycache__/sentence_cutter.cpython-311.pyc ADDED
Binary file (3.5 kB). View file
 
utils/__pycache__/symbol_table.cpython-311.pyc ADDED
Binary file (12.8 kB). View file
 
utils/download.py ADDED
@@ -0,0 +1,49 @@
1
+ import sys
2
+ import requests
3
+
4
+
5
+ def download_file_from_google_drive(id, destination):
6
+ URL = "https://docs.google.com/uc?export=download&confirm=1"
7
+
8
+ session = requests.Session()
9
+
10
+ response = session.get(URL, params={"id": id}, stream=True)
11
+ token = get_confirm_token(response)
12
+
13
+ if token:
14
+ params = {"id": id, "confirm": token}
15
+ response = session.get(URL, params=params, stream=True)
16
+
17
+ save_response_content(response, destination)
18
+
19
+
20
+ def get_confirm_token(response):
21
+ for key, value in response.cookies.items():
22
+ if key.startswith("download_warning"):
23
+ return value
24
+
25
+ return None
26
+
27
+
28
+ def save_response_content(response, destination):
29
+ CHUNK_SIZE = 32768
30
+
31
+ with open(destination, "wb") as f:
32
+ for chunk in response.iter_content(CHUNK_SIZE):
33
+ if chunk: # filter out keep-alive new chunks
34
+ f.write(chunk)
35
+
36
+
37
+ def main():
38
+ if len(sys.argv) >= 3:
39
+ file_id = sys.argv[1]
40
+ destination = sys.argv[2]
41
+ else:
42
+ file_id = "TAKE_ID_FROM_SHAREABLE_LINK"
43
+ destination = "DESTINATION_FILE_ON_YOUR_DISK"
44
+ print(f"dowload {file_id} to {destination}")
45
+ download_file_from_google_drive(file_id, destination)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ main()
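
Besides the command-line form (python utils/download.py <file_id> <destination>), the helper can be called directly; the id and path below are placeholders, not real values.

from utils.download import download_file_from_google_drive  # assumes repo root on PYTHONPATH

# Both arguments are hypothetical placeholders.
download_file_from_google_drive("FILE_ID_FROM_SHAREABLE_LINK", "./downloaded_checkpoint.pt")
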
utils/g2p/__init__.py ADDED
@@ -0,0 +1,72 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ import utils.g2p.cleaners
3
+ from utils.g2p.symbols import symbols
4
+ from tokenizers import Tokenizer
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+
11
+ class PhonemeBpeTokenizer:
12
+ def __init__(self, tokenizer_path = "./utils/g2p/bpe_1024.json"):
13
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
14
+
15
+ def tokenize(self, text):
16
+ # 1. convert text to phoneme
17
+ phonemes, langs = _clean_text(text, ['cje_cleaners'])
18
+ # 2. replace blank space " " with "_"
19
+ phonemes = phonemes.replace(" ", "_")
20
+ # 3. tokenize phonemes
21
+ phoneme_tokens = self.tokenizer.encode(phonemes).ids
22
+ assert(len(phoneme_tokens) == len(langs))
23
+ if not len(phoneme_tokens):
24
+ raise ValueError("Empty text is given")
25
+ return phoneme_tokens, langs
26
+
27
+ def text_to_sequence(text, cleaner_names):
28
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
29
+ Args:
30
+ text: string to convert to a sequence
31
+ cleaner_names: names of the cleaner functions to run the text through
32
+ Returns:
33
+ List of integers corresponding to the symbols in the text
34
+ '''
35
+ sequence = []
36
+ symbol_to_id = {s: i for i, s in enumerate(symbols)}
37
+ clean_text = _clean_text(text, cleaner_names)
38
+ for symbol in clean_text:
39
+ if symbol not in symbol_to_id.keys():
40
+ continue
41
+ symbol_id = symbol_to_id[symbol]
42
+ sequence += [symbol_id]
43
+ return sequence
44
+
45
+
46
+ def cleaned_text_to_sequence(cleaned_text):
47
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
48
+ Args:
49
+ text: string to convert to a sequence
50
+ Returns:
51
+ List of integers corresponding to the symbols in the text
52
+ '''
53
+ sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
54
+ return sequence
55
+
56
+
57
+ def sequence_to_text(sequence):
58
+ '''Converts a sequence of IDs back to a string'''
59
+ result = ''
60
+ for symbol_id in sequence:
61
+ s = _id_to_symbol[symbol_id]
62
+ result += s
63
+ return result
64
+
65
+
66
+ def _clean_text(text, cleaner_names):
67
+ for name in cleaner_names:
68
+ cleaner = getattr(utils.g2p.cleaners, name)
69
+ if not cleaner:
70
+ raise Exception('Unknown cleaner: %s' % name)
71
+ text, langs = cleaner(text)
72
+ return text, langs
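
A usage sketch for PhonemeBpeTokenizer, assuming the bundled ./utils/g2p/bpe_1024.json is present and that the input is wrapped in language tags as the rest of this codebase appears to do; both assumptions are mine, not guarantees from this file.

from utils.g2p import PhonemeBpeTokenizer  # assumes repo root is the working directory

tokenizer = PhonemeBpeTokenizer()  # default path: ./utils/g2p/bpe_1024.json
# Language-tag wrapping mirrors how tokenize() appears to be called elsewhere in the repo (assumption).
tokens, langs = tokenizer.tokenize("_[EN]Hello world.[EN]")
print(tokens, langs)
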
utils/g2p/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.49 kB). View file
 
utils/g2p/__pycache__/cleaners.cpython-311.pyc ADDED
Binary file (4.66 kB). View file