jason-salt commited on
Commit
b971d47
·
1 Parent(s): 53b664a
Files changed (46) hide show
  1. .gitattributes +1 -0
  2. __pycache__/inference_tts_scale.cpython-310.pyc +0 -0
  3. data/__init__.py +0 -0
  4. data/__pycache__/__init__.cpython-310.pyc +0 -0
  5. data/__pycache__/tokenizer.cpython-310.pyc +0 -0
  6. data/gigaspeech.py +156 -0
  7. data/phonemize_encodec_encode_hf.py +206 -0
  8. data/tokenizer.py +149 -0
  9. demo/84_121550_000074_000000.wav +0 -0
  10. demo/generated_se/84_121550_000074_000000_new_seed1.wav +0 -0
  11. demo/generated_se/84_121550_000074_000000_orig.wav +0 -0
  12. demo/generated_tts/84_121550_000074_000000_concat_seed1.wav +0 -0
  13. demo/generated_tts/84_121550_000074_000000_gen_seed1.wav +0 -0
  14. demo/temp/84_121550_000074_000000.txt +1 -0
  15. demo/temp/84_121550_000074_000000.wav +0 -0
  16. demo/temp/mfa_alignments/84_121550_000074_000000.csv +109 -0
  17. gradio_app.py +528 -0
  18. inference_speech_editing_scale.py +226 -0
  19. inference_tts_scale.py +190 -0
  20. models/__pycache__/codebooks_patterns.cpython-310.pyc +0 -0
  21. models/__pycache__/voicecraft.cpython-310.pyc +0 -0
  22. models/codebooks_patterns.py +538 -0
  23. models/modules/__init__.py +0 -0
  24. models/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  25. models/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  26. models/modules/__pycache__/activation.cpython-310.pyc +0 -0
  27. models/modules/__pycache__/activation.cpython-39.pyc +0 -0
  28. models/modules/__pycache__/embedding.cpython-310.pyc +0 -0
  29. models/modules/__pycache__/embedding.cpython-39.pyc +0 -0
  30. models/modules/__pycache__/scaling.cpython-310.pyc +0 -0
  31. models/modules/__pycache__/scaling.cpython-39.pyc +0 -0
  32. models/modules/__pycache__/transformer.cpython-310.pyc +0 -0
  33. models/modules/__pycache__/transformer.cpython-39.pyc +0 -0
  34. models/modules/__pycache__/utils.cpython-310.pyc +0 -0
  35. models/modules/__pycache__/utils.cpython-39.pyc +0 -0
  36. models/modules/__pycache__/visualizer.cpython-39.pyc +0 -0
  37. models/modules/activation.py +653 -0
  38. models/modules/embedding.py +98 -0
  39. models/modules/sampling.py +63 -0
  40. models/modules/scaling.py +1406 -0
  41. models/modules/transformer.py +698 -0
  42. models/modules/utils.py +37 -0
  43. models/voicecraft.py +1406 -0
  44. pretrained_models/encodec_4cb2048_giga.th +3 -0
  45. pretrained_models/giga330M.pth +3 -0
  46. requirements.txt +9 -0
.gitattributes CHANGED
@@ -20,6 +20,7 @@
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
 
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
20
  *.pickle filter=lfs diff=lfs merge=lfs -text
21
  *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.th filter=lfs diff=lfs merge=lfs -text
24
  *.pth filter=lfs diff=lfs merge=lfs -text
25
  *.rar filter=lfs diff=lfs merge=lfs -text
26
  *.safetensors filter=lfs diff=lfs merge=lfs -text
__pycache__/inference_tts_scale.cpython-310.pyc ADDED
Binary file (6.8 kB). View file
 
data/__init__.py ADDED
File without changes
data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (130 Bytes). View file
 
data/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (4.83 kB). View file
 
data/gigaspeech.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ import copy
5
+ import logging
6
+ import shutil
7
+
8
+ class dataset(torch.utils.data.Dataset):
9
+ def __init__(self, args, split):
10
+ super().__init__()
11
+ self.args = args
12
+ self.split = split
13
+ assert self.split in ['train', 'validation', 'test']
14
+ manifest_fn = os.path.join(self.args.dataset_dir, self.args.manifest_name, self.split+".txt")
15
+
16
+ with open(manifest_fn, "r") as rf:
17
+ data = [l.strip().split("\t") for l in rf.readlines()]
18
+ lengths_list = [int(item[-1]) for item in data]
19
+ self.data = []
20
+ self.lengths_list = []
21
+ for d, l in zip(data, lengths_list):
22
+ if l >= self.args.encodec_sr*self.args.audio_min_length:
23
+ if self.args.drop_long and l > self.args.encodec_sr*self.args.audio_max_length:
24
+ continue
25
+ self.data.append(d)
26
+ self.lengths_list.append(l)
27
+ logging.info(f"number of data points for {self.split} split: {len(self.lengths_list)}")
28
+
29
+ # phoneme vocabulary
30
+ vocab_fn = os.path.join(self.args.dataset_dir,"vocab.txt")
31
+ shutil.copy(vocab_fn, os.path.join(self.args.exp_dir, "vocab.txt"))
32
+ with open(vocab_fn, "r") as f:
33
+ temp = [l.strip().split(" ") for l in f.readlines() if len(l) != 0]
34
+ self.phn2num = {item[1]:int(item[0]) for item in temp}
35
+
36
+ self.symbol_set = set(["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"])
37
+
38
+ def __len__(self):
39
+ return len(self.lengths_list)
40
+
41
+ def _load_phn_enc(self, index):
42
+ item = self.data[index]
43
+ pf = os.path.join(self.args.dataset_dir, self.args.phn_folder_name, item[1]+".txt")
44
+ ef = os.path.join(self.args.dataset_dir, self.args.encodec_folder_name, item[1]+".txt")
45
+ try:
46
+ with open(pf, "r") as p, open(ef, "r") as e:
47
+ phns = [l.strip() for l in p.readlines()]
48
+ assert len(phns) == 1, phns
49
+ x = [self.phn2num[item] for item in phns[0].split(" ") if item not in self.symbol_set] # drop ["<SIL>", "<MUSIC>", "<NOISE>", "<OTHER>"], as they are not in training set annotation
50
+ encos = [l.strip().split() for k, l in enumerate(e.readlines()) if k < self.args.n_codebooks]
51
+
52
+ assert len(encos) == self.args.n_codebooks, ef
53
+ if self.args.special_first:
54
+ y = [[int(n)+self.args.n_special for n in l] for l in encos]
55
+ else:
56
+ y = [[int(n) for n in l] for l in encos]
57
+ except Exception as e:
58
+ logging.info(f"loading failed for {pf} and {ef}, maybe files don't exist or are corrupted")
59
+ logging.info(f"error message: {e}")
60
+ return [], [[]]
61
+
62
+ return x, y
63
+
64
+ def __getitem__(self, index):
65
+ x, y = self._load_phn_enc(index)
66
+ x_len, y_len = len(x), len(y[0])
67
+
68
+ if x_len == 0 or y_len == 0:
69
+ return {
70
+ "x": None,
71
+ "x_len": None,
72
+ "y": None,
73
+ "y_len": None,
74
+ "y_mask_interval": None, # index y_mask_interval[1] is the position of start_of_continue token
75
+ "extra_mask_start": None # this is only used in VE1
76
+ }
77
+ while y_len < self.args.encodec_sr*self.args.audio_min_length:
78
+ assert not self.args.dynamic_batching
79
+ index = random.choice(range(len(self))) # regenerate an index
80
+ x, y = self._load_phn_enc(index)
81
+ x_len, y_len = len(x), len(y[0])
82
+ if self.args.drop_long:
83
+ while x_len > self.args.text_max_length or y_len > self.args.encodec_sr*self.args.audio_max_length:
84
+ index = random.choice(range(len(self))) # regenerate an index
85
+ x, y = self._load_phn_enc(index)
86
+ x_len, y_len = len(x), len(y[0])
87
+
88
+ ### padding and cropping below ###
89
+ ### padding and cropping below ###
90
+ # adjust the length of encodec codes, pad to max_len or randomly crop
91
+ orig_y_len = copy.copy(y_len)
92
+ max_len = int(self.args.audio_max_length * self.args.encodec_sr)
93
+ if y_len > max_len:
94
+ audio_start = random.choice(range(0, y_len-max_len))
95
+ for i in range(len(y)):
96
+ y[i] = y[i][audio_start:(audio_start+max_len)]
97
+ y_len = max_len
98
+ else:
99
+ audio_start = 0
100
+ if not self.args.dynamic_batching:
101
+ pad = [0] * (max_len - y_len) if self.args.sep_special_token else [self.args.audio_pad_token] * (max_len - y_len)
102
+ for i in range(len(y)):
103
+ y[i] = y[i] + pad
104
+
105
+ # adjust text
106
+ # if audio is cropped, and text is longer than max, crop max based on how audio is cropped
107
+ if audio_start > 0 and len(x) > self.args.text_max_length: # if audio is longer than max and text is long than max, start text the way audio started
108
+ x = x[int(len(x)*audio_start/orig_y_len):]
109
+ if len(x) > self.args.text_max_length: # if text is still longer than max, cut the end
110
+ x = x[:self.args.text_max_length]
111
+
112
+ x_len = len(x)
113
+ if x_len > self.args.text_max_length:
114
+ text_start = random.choice(range(0, x_len - self.args.text_max_length))
115
+ x = x[text_start:text_start+self.args.text_max_length]
116
+ x_len = self.args.text_max_length
117
+ elif self.args.pad_x and x_len <= self.args.text_max_length:
118
+ pad = [0] * (self.args.text_max_length - x_len) if self.args.sep_special_token else [self.args.text_pad_token] * (self.args.text_max_length - x_len)
119
+ x = x + pad
120
+ ### padding and cropping above ###
121
+ ### padding and cropping above ###
122
+
123
+ return {
124
+ "x": torch.LongTensor(x),
125
+ "x_len": x_len,
126
+ "y": torch.LongTensor(y),
127
+ "y_len": y_len
128
+ }
129
+
130
+
131
+ def collate(self, batch):
132
+ out = {key:[] for key in batch[0]}
133
+ for item in batch:
134
+ if item['x'] == None: # deal with load failure
135
+ continue
136
+ for key, val in item.items():
137
+ out[key].append(val)
138
+ res = {}
139
+ if self.args.pad_x:
140
+ res["x"] = torch.stack(out["x"], dim=0)
141
+ else:
142
+ res["x"] = torch.nn.utils.rnn.pad_sequence(out["x"], batch_first=True, padding_value=self.args.text_pad_token)
143
+ res["x_lens"] = torch.LongTensor(out["x_len"])
144
+ if self.args.dynamic_batching:
145
+ if out['y'][0].ndim==2:
146
+ res['y'] = torch.nn.utils.rnn.pad_sequence([item.transpose(1,0) for item in out['y']],padding_value=self.args.audio_pad_token)
147
+ res['y'] = res['y'].permute(1,2,0) # T B K -> B K T
148
+ else:
149
+ assert out['y'][0].ndim==1, out['y'][0].shape
150
+ res['y'] = torch.nn.utils.rnn.pad_sequence(out['y'], batch_first=True, padding_value=self.args.audio_pad_token)
151
+ else:
152
+ res['y'] = torch.stack(out['y'], dim=0)
153
+ res["y_lens"] = torch.LongTensor(out["y_len"])
154
+ res["text_padding_mask"] = torch.arange(res['x'][0].shape[-1]).unsqueeze(0) >= res['x_lens'].unsqueeze(1)
155
+ res["audio_padding_mask"] = torch.arange(res['y'][0].shape[-1]).unsqueeze(0) >= res['y_lens'].unsqueeze(1)
156
+ return res
data/phonemize_encodec_encode_hf.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ def parse_args():
3
+ parser = argparse.ArgumentParser(description="encode the librilight dataset using encodec model")
4
+ parser.add_argument("--dataset_size", type=str, default='xs', help='sizes of gigaspeech, xs, s, m, l, xl. we use xl for VoiceCraft training, xs is good for debugging')
5
+ parser.add_argument('--download_to', type=str, default="/data/scratch/pyp/datasets/gigaspeech_debug", help="dir where you want the huggingface gigaspeech dataset to be downloaded to")
6
+ parser.add_argument('--save_dir', type=str, default="/data/scratch/pyp/datasets/gigaspeech_phn_enc_manifest_debug", help="path to the manifest, phonemes, and encodec codes dirs")
7
+ parser.add_argument('--encodec_model_path', type=str, default="/data/scratch/pyp/exp_pyp/audiocraft/encodec/xps/6f79c6a8/checkpoint.th")
8
+ parser.add_argument('--n_workers', type=int, default=4, help="Number of parallel worker processes")
9
+ parser.add_argument('--mega_batch_size', type=int, default=100, help="Number of samples in each mega batch for multiprocess dataloading")
10
+ parser.add_argument('--batch_size', type=int, default=4, help="batch size for encodec encoding, decrease it if OOM. This is the sum of batch size *over each gpu*, so increase it if you are using more gpus")
11
+ parser.add_argument('--model_sr', type=int, default=16000, help='encodec input audio sample rate')
12
+ parser.add_argument('--downsample_rate', type=int, default=320, help='encodec downsample rate')
13
+ parser.add_argument('--model_code_sr', type=int, default=50, help='encodec model code sample rate')
14
+ parser.add_argument('--len_cap', type=float, default=35.0, help='will drop audios that are longer than this number')
15
+ parser.add_argument('--max_len', type=int, default=30000, help='max length of audio in samples, if exceed, will cut a batch into half to process, decrease this number if OOM on your machine')
16
+ return parser.parse_args()
17
+ if __name__ == "__main__":
18
+ import logging
19
+ formatter = (
20
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
21
+ )
22
+ logging.basicConfig(format=formatter, level=logging.INFO)
23
+ args = parse_args()
24
+
25
+ import os
26
+ import numpy as np
27
+ import torch
28
+ import tqdm
29
+ import time
30
+ from datasets import load_dataset, DownloadConfig
31
+
32
+ from tokenizer import TextTokenizer, tokenize_text
33
+
34
+ # get the path
35
+ phn_save_root = os.path.join(args.save_dir, args.dataset_size, "phonemes")
36
+ codes_save_root = os.path.join(args.save_dir, args.dataset_size, "encodec_16khz_4codebooks")
37
+ vocab_fn = os.path.join(args.save_dir, args.dataset_size, "vocab.txt")
38
+ os.makedirs(phn_save_root, exist_ok=True)
39
+ os.makedirs(codes_save_root, exist_ok=True)
40
+
41
+
42
+ def sort_by_audio_len(lens):
43
+ inds = np.argsort(lens).tolist()
44
+ logging.info(f"longest: {lens[inds[-1]]*args.model_code_sr} encodec codes, {lens[inds[-1]]:.2f} sec.")
45
+ logging.info(f"shortest: {lens[inds[0]]*args.model_code_sr} encodec codes, {lens[inds[0]]:.2f} sec.")
46
+ logging.info(f"median: {lens[inds[len(inds)//2]]*args.model_code_sr} encodec codes, {lens[inds[len(inds)//2]]:.2f} sec.")
47
+ logging.info(f"95 percentile longest: {lens[inds[int(len(inds)*0.95)]]*args.model_code_sr} encodec codes, {lens[inds[int(len(inds)*0.95)]]:.2f} sec.")
48
+ return inds[::-1]
49
+
50
+ def write_array_to_txt_file(array, filename):
51
+ with open(filename, 'w') as f:
52
+ for a in array[:-1]:
53
+ f.write(' '.join(map(str, a))+'\n')
54
+ f.write(' '.join(map(str, array[-1])))
55
+
56
+
57
+ ### phonemization
58
+ # load tokenizer
59
+ # load the encodec model
60
+ from audiocraft.solvers import CompressionSolver
61
+ model = CompressionSolver.model_from_checkpoint(args.encodec_model_path)
62
+ model = model.cuda()
63
+ model = model.eval()
64
+ text_tokenizer = TextTokenizer()
65
+
66
+
67
+ # https://github.com/SpeechColab/GigaSpeech
68
+ # there are only four different punctuations
69
+ # need to check whether there are other < started strings
70
+ punc2sym = {" <COMMA>": ",", " <PERIOD>": ".", " <QUESTIONMARK>": "?", " <EXCLAMATIONPOINT>": "!"} # note the space in front of each punc name
71
+ gar2sym = {"<SIL>": "#%#", "<MUSIC>": "##%", "<NOISE>": "%%#", "<OTHER>":"%#%"} # so that they are savely keep as the original sym when using tokenize_text
72
+ punc2sym.update(gar2sym)
73
+
74
+ word2sym = { "h æ ʃ h ɐ ʃ p ɚ s ɛ n t": "<MUSIC>", "h æ ʃ p ɚ s ɛ n t h æ ʃ": "<SIL>", "p ɚ s ɛ n t h ɐ ʃ p ɚ s ɛ n t": "<OTHER>", "p ɚ s ɛ n t p ɚ s ɛ n t h æ ʃ": "<NOISE>"}
75
+ forbidden_words = set(['#%#', '##%', '%%#', '%#%'])
76
+
77
+ dc = DownloadConfig(cache_dir=args.download_to)
78
+ stime = time.time()
79
+ logging.info("loading the dataset...")
80
+ gs = load_dataset("speechcolab/gigaspeech", args.dataset_size, use_auth_token=True, cache_dir = args.download_to, download_config=dc)
81
+ logging.info(f"time spend on loading the dataset: {time.time() - stime:.2f} seconds")
82
+
83
+ splits = ['validation', 'test', 'train']
84
+
85
+ logging.info(f"gigaspeech dataset {args.dataset_size} info: {gs}")
86
+ logging.info(f"phonemizing...")
87
+ phn_vocab = set()
88
+ all_lens = []
89
+
90
+ # you will see a ton of [WARNING] words_mismatch.py:88......, it's not a issue
91
+ for split in tqdm.tqdm(splits):
92
+ skip = 0
93
+ logging.info(f"now processing split {split}...")
94
+ for item in tqdm.tqdm(gs[split]):
95
+ save_fn = os.path.join(phn_save_root, item['segment_id']+".txt")
96
+ text = item['text']
97
+ if sum(word in forbidden_words for word in text.split(" ")):
98
+ logging.info(f"skip {item['segment_id']}, because it contains forbiden words. It's transcript: {text}")
99
+ skip += 1
100
+ continue
101
+ for k, v in punc2sym.items():
102
+ text = text.replace(k, v)
103
+ phn = tokenize_text(text_tokenizer, text)
104
+ phn_seq = " ".join(phn)
105
+ for k, v in word2sym.items():
106
+ phn_seq = phn_seq.replace(k, v)
107
+ phn_vocab.update(phn_seq.split(" "))
108
+ all_lens.append(len(phn_seq.split(" ")))
109
+ with open(save_fn, "w") as f:
110
+ f.write(phn_seq)
111
+ logging.info(f"split {split} has {len(gs[split])} samples in total, skipped {skip} due to forbiden words")
112
+
113
+ print(f"phn vocab size: {len(list(phn_vocab))}")
114
+ print("phn sequence stats: ")
115
+ print(f"longest: {max(all_lens)}")
116
+ print(f"shortest: {min(all_lens)}")
117
+ print(f"median: {np.quantile(all_lens, 0.5)}")
118
+ print(f"95 percentile longest: {np.quantile(all_lens, 0.95)}")
119
+ print("write vocabulary to ", vocab_fn)
120
+ with open(vocab_fn, "w") as f:
121
+ for i, phn in enumerate(list(phn_vocab)):
122
+ if i < len(list(phn_vocab)) - 1:
123
+ f.write(f"{str(i)} {phn}\n")
124
+ else:
125
+ f.write(f"{str(i)} {phn}")
126
+
127
+ class mydataset(torch.utils.data.Dataset):
128
+ def __init__(self, split):
129
+ super().__init__()
130
+ self.data = gs[split]
131
+ def __len__(self):
132
+ return len(self.data)
133
+ def __getitem__(self, ind):
134
+ try:
135
+ segment_id, audio, sr, text, begin_time, end_time = self.data[ind]['segment_id'], torch.from_numpy(self.data[ind]['audio']['array']).float(), self.data[ind]['audio']['sampling_rate'], self.data[ind]['text'], self.data[ind]['begin_time'], self.data[ind]['end_time']
136
+ except:
137
+ return None, None, None, None, None, None
138
+
139
+ return segment_id, audio, sr, text, begin_time, end_time
140
+ def collate(self, batch):
141
+ res = {'segment_id': [], "audio": [], "sr": [], "text": [], "begin_time": [], "end_time": []}
142
+ for item in batch:
143
+ if item[0] != None:
144
+ res['segment_id'].append(item[0])
145
+ res['audio'].append(item[1])
146
+ res['sr'].append(item[2])
147
+ res['text'].append(item[3])
148
+ res['begin_time'].append(item[4])
149
+ res['end_time'].append(item[5])
150
+ return res
151
+
152
+
153
+ ## encodec codes extraction
154
+ logging.info("encodec encoding...")
155
+ train_dataset = mydataset('train')
156
+ train_loader = torch.torch.utils.data.DataLoader(train_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=train_dataset.collate)
157
+ validation_dataset = mydataset('validation')
158
+ validation_loader = torch.torch.utils.data.DataLoader(validation_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=validation_dataset.collate)
159
+ test_dataset = mydataset('test')
160
+ test_loader = torch.torch.utils.data.DataLoader(test_dataset, batch_size=args.mega_batch_size, shuffle=False, drop_last=False, num_workers=args.n_workers, collate_fn=test_dataset.collate)
161
+ splits = ['validation', 'test', 'train']
162
+ loaders = [validation_loader, test_loader, train_loader]
163
+ # splits = ['validation'] # for debug
164
+ # loaders = [validation_loader]
165
+ for split, loader in zip(splits, loaders):
166
+ skip = 0
167
+ logging.info(f"now processing split {split}...")
168
+ mega_n_steps = int(np.ceil(len(gs[split]) / args.mega_batch_size))
169
+ logging.info(f"partition the split {split} into {mega_n_steps} parts, each has {args.mega_batch_size} samples")
170
+ for m, mega_batch in enumerate(loader):
171
+ logging.info(f"====================================")
172
+ logging.info(f"====================================")
173
+ logging.info(f"now processing mega step {m+1}/{mega_n_steps}")
174
+ lengths = np.array(mega_batch['end_time']) - np.array(mega_batch['begin_time'])
175
+ sorted_inds = sort_by_audio_len(lengths)
176
+ for j in range(len(sorted_inds))[::-1]:
177
+ if lengths[sorted_inds[j]] < 0.2 or lengths[sorted_inds[j]] > args.len_cap: # skip samples that are too short (shorter than 0.2s), or too big (bigger than 80s)
178
+ skip += 1
179
+ del sorted_inds[j]
180
+
181
+ n_steps = int(np.ceil(len(sorted_inds) / args.batch_size))
182
+ for n in tqdm.tqdm(range(n_steps), disable=True):
183
+ inds_used = sorted_inds[n*args.batch_size:(n+1)*args.batch_size]
184
+ audio_batch = [mega_batch['audio'][id] for id in inds_used]
185
+ sr_batch = [mega_batch['sr'][id] for id in inds_used]
186
+ segment_id_batch = [mega_batch['segment_id'][id] for id in inds_used]
187
+ text_batch = [mega_batch['text'][id] for id in inds_used]
188
+ padded_wav = torch.nn.utils.rnn.pad_sequence(audio_batch, batch_first=True).unsqueeze(1) # [B, T] -> [B, 1, T]
189
+ all_lens = [lengths[id] for id in inds_used]
190
+ with torch.no_grad():
191
+ if max(all_lens) > args.max_len and len(all_lens) > 1: # NOTE decrease args.max_len if OOM, or chunk it into more than 2 forward passes
192
+ codes = []
193
+ inwav = padded_wav.cuda()
194
+ codes.append(model.encode(inwav[:len(inwav)//2])[0].cpu())
195
+ codes.append(model.encode(inwav[len(inwav)//2:])[0].cpu())
196
+ codes = torch.cat(codes, dim=0)
197
+ else:
198
+ encoded_frames = model.encode(padded_wav.cuda())
199
+ # logging.info(f"encoded_frames: {encoded_frames[0].shape}")
200
+ codes = encoded_frames[0].cpu()
201
+
202
+ for i, length in enumerate(all_lens):
203
+ save_fn = os.path.join(codes_save_root, segment_id_batch[i]+".txt")
204
+ actual_len = round(length * args.model_code_sr) # 320 is downsample rate for this model
205
+ cur_code = codes[i].tolist() if type(codes) == list else codes[i, :, :actual_len].tolist()
206
+ write_array_to_txt_file(cur_code, save_fn)
data/tokenizer.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/data/tokenizer.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import re
17
+ from dataclasses import asdict, dataclass
18
+ from typing import Any, Dict, List, Optional, Pattern, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+ import torchaudio
23
+ # from lhotse.features import FeatureExtractor
24
+ # from lhotse.utils import Seconds, compute_num_frames
25
+ from phonemizer.backend import EspeakBackend
26
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
27
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
28
+ from phonemizer.punctuation import Punctuation
29
+ from phonemizer.separator import Separator
30
+
31
+
32
+
33
+ class TextTokenizer:
34
+ """Phonemize Text."""
35
+
36
+ def __init__(
37
+ self,
38
+ language="en-us",
39
+ backend="espeak",
40
+ separator=Separator(word="_", syllable="-", phone="|"),
41
+ preserve_punctuation=True,
42
+ punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
43
+ with_stress: bool = False,
44
+ tie: Union[bool, str] = False,
45
+ language_switch: LanguageSwitch = "keep-flags",
46
+ words_mismatch: WordMismatch = "ignore",
47
+ ) -> None:
48
+ phonemizer = EspeakBackend(
49
+ language,
50
+ punctuation_marks=punctuation_marks,
51
+ preserve_punctuation=preserve_punctuation,
52
+ with_stress=with_stress,
53
+ tie=tie,
54
+ language_switch=language_switch,
55
+ words_mismatch=words_mismatch,
56
+ )
57
+
58
+ self.backend = phonemizer
59
+ self.separator = separator
60
+
61
+ def to_list(self, phonemized: str) -> List[str]:
62
+ fields = []
63
+ for word in phonemized.split(self.separator.word):
64
+ # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
65
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
66
+ fields.extend(
67
+ [p for p in pp if p != self.separator.phone]
68
+ + [self.separator.word]
69
+ )
70
+ assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
71
+ self.separator.phone
72
+ )
73
+ return fields[:-1]
74
+
75
+ def __call__(self, text, strip=True) -> List[List[str]]:
76
+ if isinstance(text, str):
77
+ text = [text]
78
+
79
+ phonemized = self.backend.phonemize(
80
+ text, separator=self.separator, strip=strip, njobs=1
81
+ )
82
+ return [self.to_list(p) for p in phonemized]
83
+
84
+
85
+ def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
86
+ phonemes = tokenizer([text.strip()])
87
+ return phonemes[0] # k2symbols
88
+
89
+ def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
90
+ assert wav.shape[0] in [1, 2], "Audio must be mono or stereo."
91
+ if target_channels == 1:
92
+ wav = wav.mean(0, keepdim=True)
93
+ elif target_channels == 2:
94
+ *shape, _, length = wav.shape
95
+ wav = wav.expand(*shape, target_channels, length)
96
+ elif wav.shape[0] == 1:
97
+ wav = wav.expand(target_channels, -1)
98
+ wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
99
+ return wav
100
+
101
+ class AudioTokenizer:
102
+ """EnCodec audio."""
103
+
104
+ def __init__(
105
+ self,
106
+ device: Any = None,
107
+ signature = None
108
+ ) -> None:
109
+ from audiocraft.solvers import CompressionSolver
110
+ model = CompressionSolver.model_from_checkpoint(signature)
111
+ self.sample_rate = model.sample_rate
112
+ self.channels = model.channels
113
+
114
+ if not device:
115
+ device = torch.device("cpu")
116
+ if torch.cuda.is_available():
117
+ device = torch.device("cuda:0")
118
+
119
+ self._device = device
120
+
121
+ self.codec = model.to(device)
122
+
123
+ @property
124
+ def device(self):
125
+ return self._device
126
+
127
+ def encode(self, wav: torch.Tensor) -> torch.Tensor:
128
+ codes = self.codec.encode(wav.to(self.device))
129
+ return [(codes[0], None)]
130
+
131
+ def decode(self, frames: torch.Tensor) -> torch.Tensor:
132
+ frames = frames[0][0] # [1,4,T]
133
+ return self.codec.decode(frames)
134
+
135
+
136
+
137
+ def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str, offset = -1, num_frames=-1):
138
+ # Load and pre-process the audio waveform
139
+ if offset != -1 and num_frames!=-1:
140
+ wav, sr = torchaudio.load(audio_path, frame_offset=offset, num_frames=num_frames)
141
+ else:
142
+ wav, sr = torchaudio.load(audio_path)
143
+ wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
144
+ wav = wav.unsqueeze(0)
145
+
146
+ # Extract discrete codes from EnCodec
147
+ with torch.no_grad():
148
+ encoded_frames = tokenizer.encode(wav)
149
+ return encoded_frames
demo/84_121550_000074_000000.wav ADDED
Binary file (508 kB). View file
 
demo/generated_se/84_121550_000074_000000_new_seed1.wav ADDED
Binary file (426 kB). View file
 
demo/generated_se/84_121550_000074_000000_orig.wav ADDED
Binary file (508 kB). View file
 
demo/generated_tts/84_121550_000074_000000_concat_seed1.wav ADDED
Binary file (522 kB). View file
 
demo/generated_tts/84_121550_000074_000000_gen_seed1.wav ADDED
Binary file (329 kB). View file
 
demo/temp/84_121550_000074_000000.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,
demo/temp/84_121550_000074_000000.wav ADDED
Binary file (508 kB). View file
 
demo/temp/mfa_alignments/84_121550_000074_000000.csv ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Begin,End,Label,Type,Speaker
2
+ 0.03,0.18,but,words,temp
3
+ 0.18,0.32,when,words,temp
4
+ 0.32,0.48,i,words,temp
5
+ 0.48,0.64,had,words,temp
6
+ 0.64,1.19,approached,words,temp
7
+ 1.22,1.58,so,words,temp
8
+ 1.58,1.91,near,words,temp
9
+ 1.91,2.07,to,words,temp
10
+ 2.07,2.42,them,words,temp
11
+ 2.53,2.61,the,words,temp
12
+ 2.61,3.01,common,words,temp
13
+ 3.05,3.62,object,words,temp
14
+ 3.68,3.93,which,words,temp
15
+ 3.93,4.02,the,words,temp
16
+ 4.02,4.34,sense,words,temp
17
+ 4.34,4.97,deceives,words,temp
18
+ 5.04,5.54,lost,words,temp
19
+ 5.54,6.0,not,words,temp
20
+ 6.0,6.14,by,words,temp
21
+ 6.14,6.67,distance,words,temp
22
+ 6.79,7.05,any,words,temp
23
+ 7.05,7.18,of,words,temp
24
+ 7.18,7.34,its,words,temp
25
+ 7.34,7.87,marks,words,temp
26
+ 0.03,0.06,B,phones,temp
27
+ 0.06,0.09,AH1,phones,temp
28
+ 0.09,0.18,T,phones,temp
29
+ 0.18,0.23,W,phones,temp
30
+ 0.23,0.27,EH1,phones,temp
31
+ 0.27,0.32,N,phones,temp
32
+ 0.32,0.48,AY1,phones,temp
33
+ 0.48,0.49,HH,phones,temp
34
+ 0.49,0.6,AE1,phones,temp
35
+ 0.6,0.64,D,phones,temp
36
+ 0.64,0.7,AH0,phones,temp
37
+ 0.7,0.83,P,phones,temp
38
+ 0.83,0.88,R,phones,temp
39
+ 0.88,0.99,OW1,phones,temp
40
+ 0.99,1.12,CH,phones,temp
41
+ 1.12,1.19,T,phones,temp
42
+ 1.22,1.4,S,phones,temp
43
+ 1.4,1.58,OW1,phones,temp
44
+ 1.58,1.7,N,phones,temp
45
+ 1.7,1.84,IH1,phones,temp
46
+ 1.84,1.91,R,phones,temp
47
+ 1.91,2.01,T,phones,temp
48
+ 2.01,2.07,AH0,phones,temp
49
+ 2.07,2.13,DH,phones,temp
50
+ 2.13,2.3,EH1,phones,temp
51
+ 2.3,2.42,M,phones,temp
52
+ 2.53,2.55,DH,phones,temp
53
+ 2.55,2.61,AH0,phones,temp
54
+ 2.61,2.73,K,phones,temp
55
+ 2.73,2.85,AA1,phones,temp
56
+ 2.85,2.9,M,phones,temp
57
+ 2.9,2.95,AH0,phones,temp
58
+ 2.95,3.01,N,phones,temp
59
+ 3.05,3.22,AA1,phones,temp
60
+ 3.22,3.27,B,phones,temp
61
+ 3.27,3.34,JH,phones,temp
62
+ 3.34,3.48,EH0,phones,temp
63
+ 3.48,3.54,K,phones,temp
64
+ 3.54,3.62,T,phones,temp
65
+ 3.68,3.69,HH,phones,temp
66
+ 3.69,3.76,W,phones,temp
67
+ 3.76,3.8,IH1,phones,temp
68
+ 3.8,3.93,CH,phones,temp
69
+ 3.93,3.95,DH,phones,temp
70
+ 3.95,4.02,AH0,phones,temp
71
+ 4.02,4.12,S,phones,temp
72
+ 4.12,4.21,EH1,phones,temp
73
+ 4.21,4.27,N,phones,temp
74
+ 4.27,4.34,S,phones,temp
75
+ 4.34,4.42,D,phones,temp
76
+ 4.42,4.45,IH0,phones,temp
77
+ 4.45,4.59,S,phones,temp
78
+ 4.59,4.79,IY1,phones,temp
79
+ 4.79,4.87,V,phones,temp
80
+ 4.87,4.97,Z,phones,temp
81
+ 5.04,5.12,L,phones,temp
82
+ 5.12,5.33,AO1,phones,temp
83
+ 5.33,5.42,S,phones,temp
84
+ 5.42,5.54,T,phones,temp
85
+ 5.54,5.7,N,phones,temp
86
+ 5.7,5.89,AA1,phones,temp
87
+ 5.89,6.0,T,phones,temp
88
+ 6.0,6.05,B,phones,temp
89
+ 6.05,6.14,AY1,phones,temp
90
+ 6.14,6.24,D,phones,temp
91
+ 6.24,6.3,IH1,phones,temp
92
+ 6.3,6.38,S,phones,temp
93
+ 6.38,6.45,T,phones,temp
94
+ 6.45,6.51,AH0,phones,temp
95
+ 6.51,6.57,N,phones,temp
96
+ 6.57,6.67,S,phones,temp
97
+ 6.79,6.89,EH1,phones,temp
98
+ 6.89,6.95,N,phones,temp
99
+ 6.95,7.05,IY0,phones,temp
100
+ 7.05,7.13,AH0,phones,temp
101
+ 7.13,7.18,V,phones,temp
102
+ 7.18,7.22,IH0,phones,temp
103
+ 7.22,7.29,T,phones,temp
104
+ 7.29,7.34,S,phones,temp
105
+ 7.34,7.39,M,phones,temp
106
+ 7.39,7.5,AA1,phones,temp
107
+ 7.5,7.58,R,phones,temp
108
+ 7.58,7.7,K,phones,temp
109
+ 7.7,7.87,S,phones,temp
gradio_app.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "5" # these are only used if developping locally
4
+ import gradio as gr
5
+ import torch
6
+ import torchaudio
7
+ from data.tokenizer import (
8
+ AudioTokenizer,
9
+ TextTokenizer,
10
+ )
11
+ from models import voicecraft
12
+ import io
13
+ import numpy as np
14
+ import random
15
+ import spaces
16
+
17
+
18
+ whisper_model, voicecraft_model = None, None
19
+
20
+ @spaces.GPU(duration=20)
21
+ def seed_everything(seed):
22
+ if seed != -1:
23
+ os.environ['PYTHONHASHSEED'] = str(seed)
24
+ random.seed(seed)
25
+ np.random.seed(seed)
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.backends.cudnn.benchmark = False
29
+ torch.backends.cudnn.deterministic = True
30
+
31
+ @spaces.GPU(duration=120)
32
+ def load_models(whisper_model_choice, voicecraft_model_choice):
33
+ global whisper_model, voicecraft_model
34
+
35
+ if whisper_model_choice is not None:
36
+ import whisper
37
+ from whisper.tokenizer import get_tokenizer
38
+ whisper_model = {
39
+ "model": whisper.load_model(whisper_model_choice),
40
+ "tokenizer": get_tokenizer(multilingual=False)
41
+ }
42
+
43
+
44
+ device = "cuda" if torch.cuda.is_available() else "cpu"
45
+
46
+ voicecraft_name = f"{voicecraft_model_choice}.pth"
47
+ ckpt_fn = f"./pretrained_models/{voicecraft_name}"
48
+ encodec_fn = "./pretrained_models/encodec_4cb2048_giga.th"
49
+ if not os.path.exists(ckpt_fn):
50
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\?download\=true")
51
+ os.system(f"mv {voicecraft_name}\?download\=true ./pretrained_models/{voicecraft_name}")
52
+ if not os.path.exists(encodec_fn):
53
+ os.system(f"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th")
54
+ os.system(f"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th")
55
+
56
+ ckpt = torch.load(ckpt_fn, map_location="cpu")
57
+ model = voicecraft.VoiceCraft(ckpt["config"])
58
+ model.load_state_dict(ckpt["model"])
59
+ model.to(device)
60
+ model.eval()
61
+ voicecraft_model = {
62
+ "ckpt": ckpt,
63
+ "model": model,
64
+ "text_tokenizer": TextTokenizer(backend="espeak"),
65
+ "audio_tokenizer": AudioTokenizer(signature=encodec_fn)
66
+ }
67
+
68
+ return gr.Accordion()
69
+
70
+ @spaces.GPU(duration=60)
71
+ def transcribe(seed, audio_path):
72
+ if whisper_model is None:
73
+ raise gr.Error("Whisper model not loaded")
74
+ seed_everything(seed)
75
+
76
+ number_tokens = [
77
+ i
78
+ for i in range(whisper_model["tokenizer"].eot)
79
+ if all(c in "0123456789" for c in whisper_model["tokenizer"].decode([i]).removeprefix(" "))
80
+ ]
81
+ result = whisper_model["model"].transcribe(audio_path, suppress_tokens=[-1] + number_tokens, word_timestamps=True)
82
+ words = [word_info for segment in result["segments"] for word_info in segment["words"]]
83
+
84
+ transcript = result["text"]
85
+ transcript_with_start_time = " ".join([f"{word['start']} {word['word']}" for word in words])
86
+ transcript_with_end_time = " ".join([f"{word['word']} {word['end']}" for word in words])
87
+
88
+ choices = [f"{word['start']} {word['word']} {word['end']}" for word in words]
89
+
90
+ return [
91
+ transcript, transcript_with_start_time, transcript_with_end_time,
92
+ gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # prompt_to_word
93
+ gr.Dropdown(value=choices[0], choices=choices, interactive=True), # edit_from_word
94
+ gr.Dropdown(value=choices[-1], choices=choices, interactive=True), # edit_to_word
95
+ words
96
+ ]
97
+
98
+
99
+ def get_output_audio(audio_tensors, codec_audio_sr):
100
+ result = torch.cat(audio_tensors, 1)
101
+ buffer = io.BytesIO()
102
+ torchaudio.save(buffer, result, int(codec_audio_sr), format="wav")
103
+ buffer.seek(0)
104
+ return buffer.read()
105
+
106
+ @spaces.GPU(duration=90)
107
+ def run(seed, left_margin, right_margin, codec_audio_sr, codec_sr, top_k, top_p, temperature,
108
+ stop_repetition, sample_batch_size, kvcache, silence_tokens,
109
+ audio_path, word_info, transcript, smart_transcript,
110
+ mode, prompt_end_time, edit_start_time, edit_end_time,
111
+ split_text, selected_sentence, previous_audio_tensors):
112
+ if voicecraft_model is None:
113
+ raise gr.Error("VoiceCraft model not loaded")
114
+ if smart_transcript and (word_info is None):
115
+ raise gr.Error("Can't use smart transcript: whisper transcript not found")
116
+
117
+ seed_everything(seed)
118
+ if mode == "Long TTS":
119
+ if split_text == "Newline":
120
+ sentences = transcript.split('\n')
121
+ else:
122
+ from nltk.tokenize import sent_tokenize
123
+ sentences = sent_tokenize(transcript.replace("\n", " "))
124
+ elif mode == "Rerun":
125
+ colon_position = selected_sentence.find(':')
126
+ selected_sentence_idx = int(selected_sentence[:colon_position])
127
+ sentences = [selected_sentence[colon_position + 1:]]
128
+ else:
129
+ sentences = [transcript.replace("\n", " ")]
130
+
131
+ device = "cuda" if torch.cuda.is_available() else "cpu"
132
+ info = torchaudio.info(audio_path)
133
+ audio_dur = info.num_frames / info.sample_rate
134
+
135
+ audio_tensors = []
136
+ inference_transcript = ""
137
+ for sentence in sentences:
138
+ decode_config = {"top_k": top_k, "top_p": top_p, "temperature": temperature, "stop_repetition": stop_repetition,
139
+ "kvcache": kvcache, "codec_audio_sr": codec_audio_sr, "codec_sr": codec_sr,
140
+ "silence_tokens": silence_tokens, "sample_batch_size": sample_batch_size}
141
+ if mode != "Edit":
142
+ from inference_tts_scale import inference_one_sample
143
+
144
+ if smart_transcript:
145
+ target_transcript = ""
146
+ for word in word_info:
147
+ if word["end"] < prompt_end_time:
148
+ target_transcript += word["word"]
149
+ elif (word["start"] + word["end"]) / 2 < prompt_end_time:
150
+ # include part of the word it it's big, but adjust prompt_end_time
151
+ target_transcript += word["word"]
152
+ prompt_end_time = word["end"]
153
+ break
154
+ else:
155
+ break
156
+ target_transcript += f" {sentence}"
157
+ else:
158
+ target_transcript = sentence
159
+
160
+ inference_transcript += target_transcript + "\n"
161
+
162
+ prompt_end_frame = int(min(audio_dur, prompt_end_time) * info.sample_rate)
163
+ _, gen_audio = inference_one_sample(voicecraft_model["model"],
164
+ voicecraft_model["ckpt"]["config"],
165
+ voicecraft_model["ckpt"]["phn2num"],
166
+ voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
167
+ audio_path, target_transcript, device, decode_config,
168
+ prompt_end_frame)
169
+ else:
170
+ from inference_speech_editing_scale import inference_one_sample
171
+
172
+ if smart_transcript:
173
+ target_transcript = ""
174
+ for word in word_info:
175
+ if word["start"] < edit_start_time:
176
+ target_transcript += word["word"]
177
+ else:
178
+ break
179
+ target_transcript += f" {sentence}"
180
+ for word in word_info:
181
+ if word["end"] > edit_end_time:
182
+ target_transcript += word["word"]
183
+ else:
184
+ target_transcript = sentence
185
+
186
+ inference_transcript += target_transcript + "\n"
187
+
188
+ morphed_span = (max(edit_start_time - left_margin, 1 / codec_sr), min(edit_end_time + right_margin, audio_dur))
189
+ mask_interval = [[round(morphed_span[0]*codec_sr), round(morphed_span[1]*codec_sr)]]
190
+ mask_interval = torch.LongTensor(mask_interval)
191
+
192
+ _, gen_audio = inference_one_sample(voicecraft_model["model"],
193
+ voicecraft_model["ckpt"]["config"],
194
+ voicecraft_model["ckpt"]["phn2num"],
195
+ voicecraft_model["text_tokenizer"], voicecraft_model["audio_tokenizer"],
196
+ audio_path, target_transcript, mask_interval, device, decode_config)
197
+ gen_audio = gen_audio[0].cpu()
198
+ audio_tensors.append(gen_audio)
199
+
200
+ if mode != "Rerun":
201
+ output_audio = get_output_audio(audio_tensors, codec_audio_sr)
202
+ sentences = [f"{idx}: {text}" for idx, text in enumerate(sentences)]
203
+ component = gr.Dropdown(choices=sentences, value=sentences[0])
204
+ return output_audio, inference_transcript, component, audio_tensors
205
+ else:
206
+ previous_audio_tensors[selected_sentence_idx] = audio_tensors[0]
207
+ output_audio = get_output_audio(previous_audio_tensors, codec_audio_sr)
208
+ sentence_audio = get_output_audio(audio_tensors, codec_audio_sr)
209
+ return output_audio, inference_transcript, sentence_audio, previous_audio_tensors
210
+
211
+
212
+ def update_input_audio(audio_path):
213
+ if audio_path is None:
214
+ return 0, 0, 0
215
+
216
+ info = torchaudio.info(audio_path)
217
+ max_time = round(info.num_frames / info.sample_rate, 2)
218
+ return [
219
+ gr.Slider(maximum=max_time, value=max_time),
220
+ gr.Slider(maximum=max_time, value=0),
221
+ gr.Slider(maximum=max_time, value=max_time),
222
+ ]
223
+
224
+
225
+ def change_mode(mode):
226
+ tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
227
+ return [
228
+ gr.Group(visible=mode != "Edit"),
229
+ gr.Group(visible=mode == "Edit"),
230
+ gr.Radio(visible=mode == "Edit"),
231
+ gr.Radio(visible=mode == "Long TTS"),
232
+ gr.Group(visible=mode == "Long TTS"),
233
+ ]
234
+
235
+
236
+ def load_sentence(selected_sentence, codec_audio_sr, audio_tensors):
237
+ if selected_sentence is None:
238
+ return None
239
+ colon_position = selected_sentence.find(':')
240
+ selected_sentence_idx = int(selected_sentence[:colon_position])
241
+ return get_output_audio([audio_tensors[selected_sentence_idx]], codec_audio_sr)
242
+
243
+
244
+ def update_bound_word(is_first_word, selected_word, edit_word_mode):
245
+ if selected_word is None:
246
+ return None
247
+
248
+ word_start_time = float(selected_word.split(' ')[0])
249
+ word_end_time = float(selected_word.split(' ')[-1])
250
+ if edit_word_mode == "Replace half":
251
+ bound_time = (word_start_time + word_end_time) / 2
252
+ elif is_first_word:
253
+ bound_time = word_start_time
254
+ else:
255
+ bound_time = word_end_time
256
+
257
+ return bound_time
258
+
259
+
260
+ def update_bound_words(from_selected_word, to_selected_word, edit_word_mode):
261
+ return [
262
+ update_bound_word(True, from_selected_word, edit_word_mode),
263
+ update_bound_word(False, to_selected_word, edit_word_mode),
264
+ ]
265
+
266
+
267
+ smart_transcript_info = """
268
+ If enabled, the target transcript will be constructed for you:</br>
269
+ - In TTS and Long TTS mode just write the text you want to synthesize.</br>
270
+ - In Edit mode just write the text to replace selected editing segment.</br>
271
+ If disabled, you should write the target transcript yourself:</br>
272
+ - In TTS mode write prompt transcript followed by generation transcript.</br>
273
+ - In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
274
+ - In Edit mode write full prompt</br>
275
+ """
276
+
277
+ demo_original_transcript = " But when I had approached so near to them, the common object, which the sense deceives, lost not by distance any of its marks."
278
+
279
+ demo_text = {
280
+ "TTS": {
281
+ "smart": "I cannot believe that the same model can also do text to speech synthesis as well!",
282
+ "regular": "But when I had approached so near to them, the common I cannot believe that the same model can also do text to speech synthesis as well!"
283
+ },
284
+ "Edit": {
285
+ "smart": "saw the mirage of the lake in the distance,",
286
+ "regular": "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,"
287
+ },
288
+ "Long TTS": {
289
+ "smart": "You can run generation on a big text!\n"
290
+ "Just write it line-by-line. Or sentence-by-sentence.\n"
291
+ "If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!",
292
+ "regular": "But when I had approached so near to them, the common You can run generation on a big text!\n"
293
+ "But when I had approached so near to them, the common Just write it line-by-line. Or sentence-by-sentence.\n"
294
+ "But when I had approached so near to them, the common If some sentences sound odd, just rerun generation on them, no need to generate the whole text again!"
295
+ }
296
+ }
297
+
298
+ all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
299
+
300
+ demo_words = [
301
+ "0.03 but 0.18",
302
+ "0.18 when 0.32",
303
+ "0.32 i 0.48",
304
+ "0.48 had 0.64",
305
+ "0.64 approached 1.19",
306
+ "1.22 so 1.58",
307
+ "1.58 near 1.91",
308
+ "1.91 to 2.07",
309
+ "2.07 them 2.42",
310
+ "2.53 the 2.61",
311
+ "2.61 common 3.01",
312
+ "3.05 object 3.62",
313
+ "3.68 which 3.93",
314
+ "3.93 the 4.02",
315
+ "4.02 sense 4.34",
316
+ "4.34 deceives 4.97",
317
+ "5.04 lost 5.54",
318
+ "5.54 not 6.00",
319
+ "6.00 by 6.14",
320
+ "6.14 distance 6.67",
321
+ "6.79 any 7.05",
322
+ "7.05 of 7.18",
323
+ "7.18 its 7.34",
324
+ "7.34 marks 7.87"
325
+ ]
326
+
327
+ demo_word_info = [
328
+ {"word": "but", "start": 0.03, "end": 0.18},
329
+ {"word": "when", "start": 0.18, "end": 0.32},
330
+ {"word": "i", "start": 0.32, "end": 0.48},
331
+ {"word": "had", "start": 0.48, "end": 0.64},
332
+ {"word": "approached", "start": 0.64, "end": 1.19},
333
+ {"word": "so", "start": 1.22, "end": 1.58},
334
+ {"word": "near", "start": 1.58, "end": 1.91},
335
+ {"word": "to", "start": 1.91, "end": 2.07},
336
+ {"word": "them", "start": 2.07, "end": 2.42},
337
+ {"word": "the", "start": 2.53, "end": 2.61},
338
+ {"word": "common", "start": 2.61, "end": 3.01},
339
+ {"word": "object", "start": 3.05, "end": 3.62},
340
+ {"word": "which", "start": 3.68, "end": 3.93},
341
+ {"word": "the", "start": 3.93, "end": 4.02},
342
+ {"word": "sense", "start": 4.02, "end": 4.34},
343
+ {"word": "deceives", "start": 4.34, "end": 4.97},
344
+ {"word": "lost", "start": 5.04, "end": 5.54},
345
+ {"word": "not", "start": 5.54, "end": 6.0},
346
+ {"word": "by", "start": 6.0, "end": 6.14},
347
+ {"word": "distance", "start": 6.14, "end": 6.67},
348
+ {"word": "any", "start": 6.79, "end": 7.05},
349
+ {"word": "of", "start": 7.05, "end": 7.18},
350
+ {"word": "its", "start": 7.18, "end": 7.34},
351
+ {"word": "marks", "start": 7.34, "end": 7.87}
352
+ ]
353
+
354
+
355
+ def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
356
+ if transcript not in all_demo_texts:
357
+ return transcript, edit_from_word, edit_to_word
358
+
359
+ replace_half = edit_word_mode == "Replace half"
360
+ change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
361
+ change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
362
+ demo_edit_from_word_value = demo_words[2] if replace_half else demo_words[3]
363
+ demo_edit_to_word_value = demo_words[12] if replace_half else demo_words[11]
364
+ return [
365
+ demo_text[mode]["smart" if smart_transcript else "regular"],
366
+ demo_edit_from_word_value if change_edit_from_word else edit_from_word,
367
+ demo_edit_to_word_value if change_edit_to_word else edit_to_word,
368
+ ]
369
+
370
+
371
+ with gr.Blocks() as app:
372
+ with gr.Row():
373
+ with gr.Column(scale=2):
374
+ load_models_btn = gr.Button(value="Load models")
375
+ with gr.Column(scale=5):
376
+ with gr.Accordion("Select models", open=False) as models_selector:
377
+ with gr.Row():
378
+ voicecraft_model_choice = gr.Radio(label="VoiceCraft model", value="giga830M", choices=["giga330M", "giga830M"])
379
+ whisper_model_choice = gr.Radio(label="Whisper model", value="base.en",
380
+ choices=[None, "tiny.en", "base.en", "small.en", "medium.en", "large"])
381
+
382
+ with gr.Row():
383
+ with gr.Column(scale=2):
384
+ input_audio = gr.Audio(value="./demo/84_121550_000074_000000.wav", label="Input Audio", type="filepath")
385
+ with gr.Group():
386
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, value=demo_original_transcript, interactive=False,
387
+ info="Use whisper model to get the transcript. Fix it if necessary.")
388
+ with gr.Accordion("Word start time", open=False):
389
+ transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
390
+ with gr.Accordion("Word end time", open=False):
391
+ transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
392
+
393
+ transcribe_btn = gr.Button(value="Transcribe")
394
+
395
+ with gr.Column(scale=3):
396
+ with gr.Group():
397
+ transcript = gr.Textbox(label="Text", lines=7, value=demo_text["TTS"]["smart"])
398
+ with gr.Row():
399
+ smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
400
+ with gr.Accordion(label="?", open=False):
401
+ info = gr.Markdown(value=smart_transcript_info)
402
+
403
+ with gr.Row():
404
+ mode = gr.Radio(label="Mode", choices=["TTS", "Edit", "Long TTS"], value="TTS")
405
+ split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
406
+ info="Split text into parts and run TTS for each part.", visible=False)
407
+ edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace half",
408
+ info="What to do with first and last word", visible=False)
409
+
410
+ with gr.Group() as tts_mode_controls:
411
+ prompt_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[10], interactive=True)
412
+ prompt_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.93, step=0.01, value=3.01)
413
+
414
+ with gr.Group(visible=False) as edit_mode_controls:
415
+ with gr.Row():
416
+ edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[2], interactive=True)
417
+ edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[12], interactive=True)
418
+ with gr.Row():
419
+ edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.93, step=0.01, value=0.35)
420
+ edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.93, step=0.01, value=3.75)
421
+
422
+ run_btn = gr.Button(value="Run")
423
+
424
+ with gr.Column(scale=2):
425
+ output_audio = gr.Audio(label="Output Audio")
426
+ with gr.Accordion("Inference transcript", open=False):
427
+ inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False,
428
+ info="Inference was performed on this transcript.")
429
+ with gr.Group(visible=False) as long_tts_sentence_editor:
430
+ sentence_selector = gr.Dropdown(label="Sentence", value=None,
431
+ info="Select sentence you want to regenerate")
432
+ sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
433
+ rerun_btn = gr.Button(value="Rerun")
434
+
435
+ with gr.Row():
436
+ with gr.Accordion("VoiceCraft config", open=False):
437
+ seed = gr.Number(label="seed", value=-1, precision=0)
438
+ left_margin = gr.Number(label="left_margin", value=0.08)
439
+ right_margin = gr.Number(label="right_margin", value=0.08)
440
+ codec_audio_sr = gr.Number(label="codec_audio_sr", value=16000)
441
+ codec_sr = gr.Number(label="codec_sr", value=50)
442
+ top_k = gr.Number(label="top_k", value=0)
443
+ top_p = gr.Number(label="top_p", value=0.8)
444
+ temperature = gr.Number(label="temperature", value=1)
445
+ stop_repetition = gr.Radio(label="stop_repetition", choices=[-1, 1, 2, 3], value=3,
446
+ info="if there are long silence in the generated audio, reduce the stop_repetition to 3, 2 or even 1, -1 = disabled")
447
+ sample_batch_size = gr.Number(label="sample_batch_size", value=4, precision=0,
448
+ info="generate this many samples and choose the shortest one")
449
+ kvcache = gr.Radio(label="kvcache", choices=[0, 1], value=1,
450
+ info="set to 0 to use less VRAM, but with slower inference")
451
+ silence_tokens = gr.Textbox(label="silence tokens", value="[1388,1898,131]")
452
+
453
+
454
+ audio_tensors = gr.State()
455
+ word_info = gr.State(value=demo_word_info)
456
+
457
+
458
+ mode.change(fn=update_demo,
459
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
460
+ outputs=[transcript, edit_from_word, edit_to_word])
461
+ edit_word_mode.change(fn=update_demo,
462
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
463
+ outputs=[transcript, edit_from_word, edit_to_word])
464
+ smart_transcript.change(fn=update_demo,
465
+ inputs=[mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word],
466
+ outputs=[transcript, edit_from_word, edit_to_word])
467
+
468
+ load_models_btn.click(fn=load_models,
469
+ inputs=[whisper_model_choice, voicecraft_model_choice],
470
+ outputs=[models_selector])
471
+
472
+ input_audio.upload(fn=update_input_audio,
473
+ inputs=[input_audio],
474
+ outputs=[prompt_end_time, edit_start_time, edit_end_time])
475
+ transcribe_btn.click(fn=transcribe,
476
+ inputs=[seed, input_audio],
477
+ outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
478
+ prompt_to_word, edit_from_word, edit_to_word, word_info])
479
+
480
+ mode.change(fn=change_mode,
481
+ inputs=[mode],
482
+ outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
483
+
484
+ run_btn.click(fn=run,
485
+ inputs=[
486
+ seed, left_margin, right_margin,
487
+ codec_audio_sr, codec_sr,
488
+ top_k, top_p, temperature,
489
+ stop_repetition, sample_batch_size,
490
+ kvcache, silence_tokens,
491
+ input_audio, word_info, transcript, smart_transcript,
492
+ mode, prompt_end_time, edit_start_time, edit_end_time,
493
+ split_text, sentence_selector, audio_tensors
494
+ ],
495
+ outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
496
+
497
+ sentence_selector.change(fn=load_sentence,
498
+ inputs=[sentence_selector, codec_audio_sr, audio_tensors],
499
+ outputs=[sentence_audio])
500
+ rerun_btn.click(fn=run,
501
+ inputs=[
502
+ seed, left_margin, right_margin,
503
+ codec_audio_sr, codec_sr,
504
+ top_k, top_p, temperature,
505
+ stop_repetition, sample_batch_size,
506
+ kvcache, silence_tokens,
507
+ input_audio, word_info, transcript, smart_transcript,
508
+ gr.State(value="Rerun"), prompt_end_time, edit_start_time, edit_end_time,
509
+ split_text, sentence_selector, audio_tensors
510
+ ],
511
+ outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
512
+
513
+ prompt_to_word.change(fn=update_bound_word,
514
+ inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
515
+ outputs=[prompt_end_time])
516
+ edit_from_word.change(fn=update_bound_word,
517
+ inputs=[gr.State(True), edit_from_word, edit_word_mode],
518
+ outputs=[edit_start_time])
519
+ edit_to_word.change(fn=update_bound_word,
520
+ inputs=[gr.State(False), edit_to_word, edit_word_mode],
521
+ outputs=[edit_end_time])
522
+ edit_word_mode.change(fn=update_bound_words,
523
+ inputs=[edit_from_word, edit_to_word, edit_word_mode],
524
+ outputs=[edit_start_time, edit_end_time])
525
+
526
+
527
+ if __name__ == "__main__":
528
+ app.launch()
inference_speech_editing_scale.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, pickle
2
+ import logging
3
+ import os, random
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+
8
+ from data.tokenizer import (
9
+ AudioTokenizer,
10
+ TextTokenizer,
11
+ tokenize_audio,
12
+ tokenize_text
13
+ )
14
+
15
+ from models import voicecraft
16
+ import argparse, time, tqdm
17
+
18
+ # this script only works for the musicgen architecture
19
+ def get_args():
20
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
21
+ parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
22
+ parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
23
+ parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
24
+ parser.add_argument("--left_margin", type=float, default=0.08, help="extra space on the left to the word boundary")
25
+ parser.add_argument("--right_margin", type=float, default=0.08, help="extra space on the right to the word boundary")
26
+ parser.add_argument("--seed", type=int, default=1)
27
+ parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
28
+ parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
29
+ parser.add_argument("--top_k", type=int, default=-1, help="sampling param")
30
+ parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
31
+ parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
32
+ parser.add_argument("--output_dir", type=str, default=None)
33
+ parser.add_argument("--device", type=str, default="cuda")
34
+ parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
35
+ parser.add_argument("--stop_repetition", type=int, default=2, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
36
+ parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
37
+ parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
38
+ return parser.parse_args()
39
+
40
+ @torch.no_grad()
41
+ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, device, decode_config):
42
+ # phonemize
43
+ text_tokens = [phn2num[phn] for phn in
44
+ tokenize_text(
45
+ text_tokenizer, text=target_text.strip()
46
+ ) if phn in phn2num
47
+ ]
48
+ text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
49
+ text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
50
+
51
+ encoded_frames = tokenize_audio(audio_tokenizer, audio_fn)
52
+ original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
53
+ assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
54
+ logging.info(f"with direct encodec encoding before input, original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
55
+
56
+ # forward
57
+ stime = time.time()
58
+ encoded_frames = model.inference(
59
+ text_tokens.to(device),
60
+ text_tokens_lens.to(device),
61
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
62
+ mask_interval=mask_interval.unsqueeze(0).to(device),
63
+ top_k=decode_config['top_k'],
64
+ top_p=decode_config['top_p'],
65
+ temperature=decode_config['temperature'],
66
+ stop_repetition=decode_config['stop_repetition'],
67
+ kvcache=decode_config['kvcache'],
68
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens']) == str else decode_config['silence_tokens'],
69
+ ) # output is [1,K,T]
70
+ logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
71
+ if type(encoded_frames) == tuple:
72
+ encoded_frames = encoded_frames[0]
73
+ logging.info(f"generated encoded_frames.shape: {encoded_frames.shape}, which is {encoded_frames.shape[-1]/decode_config['codec_sr']} sec.")
74
+
75
+
76
+ # decode (both original and generated)
77
+ original_sample = audio_tokenizer.decode(
78
+ [(original_audio.transpose(2,1), None)] # [1,T,8] -> [1,8,T]
79
+ )
80
+ generated_sample = audio_tokenizer.decode(
81
+ [(encoded_frames, None)]
82
+ )
83
+
84
+ return original_sample, generated_sample
85
+
86
+ def get_model(exp_dir, device=None):
87
+ with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
88
+ model_args = pickle.load(f)
89
+
90
+ logging.info("load model weights...")
91
+ model = voicecraft.VoiceCraft(model_args)
92
+ ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
93
+ ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
94
+ phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
95
+ model.load_state_dict(ckpt)
96
+ del ckpt
97
+ logging.info("done loading weights...")
98
+ if device == None:
99
+ device = torch.device("cpu")
100
+ if torch.cuda.is_available():
101
+ device = torch.device("cuda:0")
102
+ model.to(device)
103
+ model.eval()
104
+ return model, model_args, phn2num
105
+
106
+
107
+ def get_mask_interval(ali_fn, word_span_ind, editType):
108
+ with open(ali_fn, "r") as rf:
109
+ data = [l.strip().split(",") for l in rf.readlines()]
110
+ data = data[1:]
111
+ tmp = word_span_ind.split(",")
112
+ s, e = int(tmp[0]), int(tmp[-1])
113
+ start = None
114
+ for j, item in enumerate(data):
115
+ if j == s and item[3] == "words":
116
+ if editType == 'insertion':
117
+ start = float(item[1])
118
+ else:
119
+ start = float(item[0])
120
+ if j == e and item[3] == "words":
121
+ if editType == 'insertion':
122
+ end = float(item[0])
123
+ else:
124
+ end = float(item[1])
125
+ assert start != None
126
+ break
127
+ return (start, end)
128
+
129
+ if __name__ == "__main__":
130
+ def seed_everything(seed):
131
+ os.environ['PYTHONHASHSEED'] = str(seed)
132
+ random.seed(seed)
133
+ np.random.seed(seed)
134
+ torch.manual_seed(seed)
135
+ torch.cuda.manual_seed(seed)
136
+ torch.backends.cudnn.benchmark = False
137
+ torch.backends.cudnn.deterministic = True
138
+ formatter = (
139
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
140
+ )
141
+ logging.basicConfig(format=formatter, level=logging.INFO)
142
+ args = get_args()
143
+ # args.device = 'cpu'
144
+ args.allowed_repeat_tokens = eval(args.allowed_repeat_tokens)
145
+ seed_everything(args.seed)
146
+
147
+ # load model
148
+ stime = time.time()
149
+ logging.info(f"loading model from {args.exp_dir}")
150
+ model, model_args, phn2num = get_model(args.exp_dir)
151
+ if not os.path.isfile(model_args.exp_dir):
152
+ model_args.exp_dir = args.exp_dir
153
+ logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
154
+
155
+ # setup text and audio tokenizer
156
+ text_tokenizer = TextTokenizer(backend="espeak")
157
+ audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
158
+
159
+ with open(args.manifest_fn, "r") as rf:
160
+ manifest = [l.strip().split("\t") for l in rf.readlines()]
161
+ manifest = manifest[1:]
162
+
163
+ # wav_fn txt_fn alingment_fn num_words word_span_ind
164
+ audio_fns = []
165
+ target_texts = []
166
+ mask_intervals = []
167
+ edit_types = []
168
+ new_spans = []
169
+ orig_spans = []
170
+ os.makedirs(args.output_dir, exist_ok=True)
171
+ if args.crop_concat:
172
+ mfa_temp = f"{args.output_dir}/mfa_temp"
173
+ os.makedirs(mfa_temp, exist_ok=True)
174
+ for item in manifest:
175
+ audio_fn = os.path.join(args.audio_root, item[0])
176
+ temp = torchaudio.info(audio_fn)
177
+ audio_dur = temp.num_frames/temp.sample_rate
178
+ audio_fns.append(audio_fn)
179
+ target_text = item[2].split("|")[-1]
180
+ edit_types.append(item[5].split("|"))
181
+ new_spans.append(item[4].split("|"))
182
+ orig_spans.append(item[3].split("|"))
183
+ target_texts.append(target_text) # the last transcript is the target
184
+ # mi needs to be created from word_ind_span and alignment_fn, along with args.left_margin and args.right_margin
185
+ mis = []
186
+ all_ind_intervals = item[3].split("|")
187
+ editTypes = item[5].split("|")
188
+ smaller_indx = []
189
+ alignment_fn = os.path.join(args.audio_root, "aligned", item[0].replace(".wav", ".csv"))
190
+ if not os.path.isfile(alignment_fn):
191
+ alignment_fn = alignment_fn.replace("/aligned/", "/aligned_csv/")
192
+ assert os.path.isfile(alignment_fn), alignment_fn
193
+ for ind_inter,editType in zip(all_ind_intervals, editTypes):
194
+ # print(ind_inter)
195
+ mi = get_mask_interval(alignment_fn, ind_inter, editType)
196
+ mi = (max(mi[0] - args.left_margin, 1/args.codec_sr), min(mi[1] + args.right_margin, audio_dur)) # in seconds
197
+ mis.append(mi)
198
+ smaller_indx.append(mi[0])
199
+ ind = np.argsort(smaller_indx)
200
+ mis = [mis[id] for id in ind]
201
+ mask_intervals.append(mis)
202
+
203
+
204
+
205
+ for i, (audio_fn, target_text, mask_interval) in enumerate(tqdm.tqdm(zip(audio_fns, target_texts, mask_intervals))):
206
+ orig_mask_interval = mask_interval
207
+ mask_interval = [[round(cmi[0]*args.codec_sr), round(cmi[1]*args.codec_sr)] for cmi in mask_interval]
208
+ # logging.info(f"i: {i}, mask_interval: {mask_interval}")
209
+ mask_interval = torch.LongTensor(mask_interval) # [M,2]
210
+ orig_audio, new_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, mask_interval, args.device, vars(args))
211
+
212
+ # save segments for comparison
213
+ orig_audio, new_audio = orig_audio[0].cpu(), new_audio[0].cpu()
214
+ # logging.info(f"length of the resynthesize orig audio: {orig_audio.shape}")
215
+
216
+ save_fn_new = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_new_seed{args.seed}.wav"
217
+
218
+ torchaudio.save(save_fn_new, new_audio, args.codec_audio_sr)
219
+
220
+ save_fn_orig = f"{args.output_dir}/{os.path.basename(audio_fn)[:-4]}_orig.wav"
221
+ if not os.path.isfile(save_fn_orig):
222
+ orig_audio, orig_sr = torchaudio.load(audio_fn)
223
+ if orig_sr != args.codec_audio_sr:
224
+ orig_audio = torchaudio.transforms.Resample(orig_sr, args.codec_audio_sr)(orig_audio)
225
+ torchaudio.save(save_fn_orig, orig_audio, args.codec_audio_sr)
226
+
inference_tts_scale.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, pickle
2
+ import logging
3
+ import os, random
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+
8
+ from data.tokenizer import (
9
+ AudioTokenizer,
10
+ TextTokenizer,
11
+ tokenize_audio,
12
+ tokenize_text
13
+ )
14
+
15
+ from models import voicecraft
16
+ import argparse, time, tqdm
17
+
18
+
19
+ # this script only works for the musicgen architecture
20
+ def get_args():
21
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22
+ parser.add_argument("--manifest_fn", type=str, default="path/to/eval_metadata_file")
23
+ parser.add_argument("--audio_root", type=str, default="path/to/audio_folder")
24
+ parser.add_argument("--exp_dir", type=str, default="path/to/model_folder")
25
+ parser.add_argument("--seed", type=int, default=1)
26
+ parser.add_argument("--codec_audio_sr", type=int, default=16000, help='the sample rate of audio that the codec is trained for')
27
+ parser.add_argument("--codec_sr", type=int, default=50, help='the sample rate of the codec codes')
28
+ parser.add_argument("--top_k", type=int, default=0, help="sampling param")
29
+ parser.add_argument("--top_p", type=float, default=0.8, help="sampling param")
30
+ parser.add_argument("--temperature", type=float, default=1.0, help="sampling param")
31
+ parser.add_argument("--output_dir", type=str, default=None)
32
+ parser.add_argument("--device", type=str, default="cuda")
33
+ parser.add_argument("--signature", type=str, default=None, help="path to the encodec model")
34
+ parser.add_argument("--crop_concat", type=int, default=0)
35
+ parser.add_argument("--stop_repetition", type=int, default=-1, help="used for inference, when the number of consecutive repetition of a token is bigger than this, stop it")
36
+ parser.add_argument("--kvcache", type=int, default=1, help='if true, use kv cache, which is 4-8x faster than without')
37
+ parser.add_argument("--sample_batch_size", type=int, default=1, help="batch size for sampling, NOTE that it's not running inference for several samples, but duplicate one input sample batch_size times, and during inference, we only return the shortest generation")
38
+ parser.add_argument("--silence_tokens", type=str, default="[1388,1898,131]", help="note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
39
+ return parser.parse_args()
40
+
41
+
42
+ @torch.no_grad()
43
+ def inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_text, device, decode_config, prompt_end_frame):
44
+ # phonemize
45
+ text_tokens = [phn2num[phn] for phn in
46
+ tokenize_text(
47
+ text_tokenizer, text=target_text.strip()
48
+ ) if phn in phn2num
49
+ ]
50
+ text_tokens = torch.LongTensor(text_tokens).unsqueeze(0)
51
+ text_tokens_lens = torch.LongTensor([text_tokens.shape[-1]])
52
+
53
+ # encode audio
54
+ encoded_frames = tokenize_audio(audio_tokenizer, audio_fn, offset=0, num_frames=prompt_end_frame)
55
+ original_audio = encoded_frames[0][0].transpose(2,1) # [1,T,K]
56
+ assert original_audio.ndim==3 and original_audio.shape[0] == 1 and original_audio.shape[2] == model_args.n_codebooks, original_audio.shape
57
+ logging.info(f"original audio length: {original_audio.shape[1]} codec frames, which is {original_audio.shape[1]/decode_config['codec_sr']:.2f} sec.")
58
+
59
+ # forward
60
+ stime = time.time()
61
+ if decode_config['sample_batch_size'] <= 1:
62
+ logging.info(f"running inference with batch size 1")
63
+ concat_frames, gen_frames = model.inference_tts(
64
+ text_tokens.to(device),
65
+ text_tokens_lens.to(device),
66
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
67
+ top_k=decode_config['top_k'],
68
+ top_p=decode_config['top_p'],
69
+ temperature=decode_config['temperature'],
70
+ stop_repetition=decode_config['stop_repetition'],
71
+ kvcache=decode_config['kvcache'],
72
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
73
+ ) # output is [1,K,T]
74
+ else:
75
+ logging.info(f"running inference with batch size {decode_config['sample_batch_size']}, i.e. return the shortest among {decode_config['sample_batch_size']} generations.")
76
+ concat_frames, gen_frames = model.inference_tts_batch(
77
+ text_tokens.to(device),
78
+ text_tokens_lens.to(device),
79
+ original_audio[...,:model_args.n_codebooks].to(device), # [1,T,8]
80
+ top_k=decode_config['top_k'],
81
+ top_p=decode_config['top_p'],
82
+ temperature=decode_config['temperature'],
83
+ stop_repetition=decode_config['stop_repetition'],
84
+ kvcache=decode_config['kvcache'],
85
+ batch_size = decode_config['sample_batch_size'],
86
+ silence_tokens=eval(decode_config['silence_tokens']) if type(decode_config['silence_tokens'])==str else decode_config['silence_tokens']
87
+ ) # output is [1,K,T]
88
+ logging.info(f"inference on one sample take: {time.time() - stime:.4f} sec.")
89
+
90
+ logging.info(f"generated encoded_frames.shape: {gen_frames.shape}, which is {gen_frames.shape[-1]/decode_config['codec_sr']} sec.")
91
+
92
+ # for timestamp, codes in enumerate(gen_frames[0].transpose(1,0)):
93
+ # logging.info(f"{timestamp}: {codes.tolist()}")
94
+ # decode (both original and generated)
95
+ concat_sample = audio_tokenizer.decode(
96
+ [(concat_frames, None)] # [1,T,8] -> [1,8,T]
97
+ )
98
+ gen_sample = audio_tokenizer.decode(
99
+ [(gen_frames, None)]
100
+ )
101
+
102
+ # return
103
+ return concat_sample, gen_sample
104
+
105
+ def get_model(exp_dir, device=None):
106
+ with open(os.path.join(exp_dir, "args.pkl"), "rb") as f:
107
+ model_args = pickle.load(f)
108
+
109
+ logging.info("load model weights...")
110
+ model = voicecraft.VoiceCraft(model_args)
111
+ ckpt_fn = os.path.join(exp_dir, "best_bundle.pth")
112
+ ckpt = torch.load(ckpt_fn, map_location='cpu')['model']
113
+ phn2num = torch.load(ckpt_fn, map_location='cpu')['phn2num']
114
+ model.load_state_dict(ckpt)
115
+ del ckpt
116
+ logging.info("done loading weights...")
117
+ if device == None:
118
+ device = torch.device("cpu")
119
+ if torch.cuda.is_available():
120
+ device = torch.device("cuda:0")
121
+ model.to(device)
122
+ model.eval()
123
+ return model, model_args, phn2num
124
+
125
+ if __name__ == "__main__":
126
+ def seed_everything(seed):
127
+ os.environ['PYTHONHASHSEED'] = str(seed)
128
+ random.seed(seed)
129
+ np.random.seed(seed)
130
+ torch.manual_seed(seed)
131
+ torch.cuda.manual_seed(seed)
132
+ torch.backends.cudnn.benchmark = False
133
+ torch.backends.cudnn.deterministic = True
134
+ formatter = (
135
+ "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
136
+ )
137
+ logging.basicConfig(format=formatter, level=logging.INFO)
138
+ args = get_args()
139
+ # args.device='cpu'
140
+ seed_everything(args.seed)
141
+
142
+ os.makedirs(args.output_dir, exist_ok=True)
143
+ # load model
144
+
145
+ with open(args.manifest_fn, "r") as rf:
146
+ manifest = [l.strip().split("\t") for l in rf.readlines()]
147
+ manifest = manifest[1:]
148
+ manifest = [[item[0], item[2], item[3], item[1], item[5]] for item in manifest]
149
+
150
+ stime = time.time()
151
+ logging.info(f"loading model from {args.exp_dir}")
152
+ model, model_args, phn2num = get_model(args.exp_dir)
153
+ logging.info(f"loading model done, took {time.time() - stime:.4f} sec")
154
+
155
+ # setup text and audio tokenizer
156
+ text_tokenizer = TextTokenizer(backend="espeak")
157
+ audio_tokenizer = AudioTokenizer(signature=args.signature) # will also put the neural codec model on gpu
158
+
159
+ audio_fns = []
160
+ texts = []
161
+ prompt_end_frames = []
162
+ new_audio_fns = []
163
+ text_to_syn = []
164
+
165
+ for item in manifest:
166
+ audio_fn = os.path.join(args.audio_root, item[0])
167
+ audio_fns.append(audio_fn)
168
+ temp = torchaudio.info(audio_fn)
169
+ prompt_end_frames.append(round(float(item[2])*temp.sample_rate))
170
+ texts.append(item[1])
171
+ new_audio_fns.append(item[-2])
172
+ all_text = item[1].split(" ")
173
+ start_ind = int(item[-1].split(",")[0])
174
+ text_to_syn.append(" ".join(all_text[start_ind:]))
175
+
176
+ for i, (audio_fn, text, prompt_end_frame, new_audio_fn, to_syn) in enumerate(tqdm.tqdm((zip(audio_fns, texts, prompt_end_frames, new_audio_fns, text_to_syn)))):
177
+ output_expected_sr = args.codec_audio_sr
178
+ concated_audio, gen_audio = inference_one_sample(model, model_args, phn2num, text_tokenizer, audio_tokenizer, audio_fn, text, args.device, vars(args), prompt_end_frame)
179
+
180
+ # save segments for comparison
181
+ concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()
182
+ if output_expected_sr != args.codec_audio_sr:
183
+ gen_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(gen_audio)
184
+ concated_audio = torchaudio.transforms.Resample(output_expected_sr, args.codec_audio_sr)(concated_audio)
185
+
186
+ seg_save_fn_gen = f"{args.output_dir}/gen_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
187
+ seg_save_fn_concat = f"{args.output_dir}/concat_{new_audio_fn[:-4]}_{i}_seed{args.seed}.wav"
188
+
189
+ torchaudio.save(seg_save_fn_gen, gen_audio, args.codec_audio_sr)
190
+ torchaudio.save(seg_save_fn_concat, concated_audio, args.codec_audio_sr)
models/__pycache__/codebooks_patterns.cpython-310.pyc ADDED
Binary file (25 kB). View file
 
models/__pycache__/voicecraft.cpython-310.pyc ADDED
Binary file (40.1 kB). View file
 
models/codebooks_patterns.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from collections import namedtuple
8
+ from dataclasses import dataclass
9
+ from functools import lru_cache
10
+ import logging
11
+ import typing as tp
12
+
13
+ from abc import ABC, abstractmethod
14
+ import torch
15
+
16
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
17
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
18
+
19
+
20
+ @dataclass
21
+ class Pattern:
22
+ """Base implementation of a pattern over a sequence with multiple codebooks.
23
+
24
+ The codebook pattern consists in a layout, defining for each sequence step
25
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
26
+ The first item of the pattern is always an empty list in order to properly insert a special token
27
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
28
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
29
+
30
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
31
+ ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
32
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size,
33
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
34
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
35
+ is returned along with a mask indicating valid tokens.
36
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
37
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
38
+ to fill and specify invalid positions if needed.
39
+ See the dedicated methods for more details.
40
+ """
41
+ # Pattern layout, for each sequence step, we have a list of coordinates
42
+ # corresponding to the original codebook timestep and position.
43
+ # The first list is always an empty list in order to properly insert
44
+ # a special token to start with.
45
+ layout: PatternLayout
46
+ timesteps: int
47
+ n_q: int
48
+
49
+ def __post_init__(self):
50
+ assert len(self.layout) > 0
51
+ assert self.layout[0] == []
52
+ self._validate_layout()
53
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
54
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
55
+ # logging.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
56
+
57
+ def _validate_layout(self):
58
+ """Runs checks on the layout to ensure a valid pattern is defined.
59
+ A pattern is considered invalid if:
60
+ - Multiple timesteps for a same codebook are defined in the same sequence step
61
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
62
+ (this would mean that we have future timesteps before past timesteps).
63
+ """
64
+ q_timesteps = {q: 0 for q in range(self.n_q)}
65
+ for s, seq_coords in enumerate(self.layout):
66
+ if len(seq_coords) > 0:
67
+ qs = set()
68
+ for coord in seq_coords:
69
+ qs.add(coord.q)
70
+ last_q_timestep = q_timesteps[coord.q]
71
+ assert coord.t >= last_q_timestep, \
72
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
73
+ q_timesteps[coord.q] = coord.t
74
+ # each sequence step contains at max 1 coordinate per codebook
75
+ assert len(qs) == len(seq_coords), \
76
+ f"Multiple entries for a same codebook are found at step {s}"
77
+
78
+ @property
79
+ def num_sequence_steps(self):
80
+ return len(self.layout) - 1
81
+
82
+ @property
83
+ def max_delay(self):
84
+ max_t_in_seq_coords = 0
85
+ for seq_coords in self.layout[1:]:
86
+ for coords in seq_coords:
87
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
88
+ return max_t_in_seq_coords - self.timesteps
89
+
90
+ @property
91
+ def valid_layout(self):
92
+ valid_step = len(self.layout) - self.max_delay
93
+ return self.layout[:valid_step]
94
+
95
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
96
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
97
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
98
+ and the actual codebook coordinates.
99
+ """
100
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
101
+ if q is not None:
102
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
103
+ coords = []
104
+ for s, seq_codes in enumerate(self.layout):
105
+ for code in seq_codes:
106
+ if code.t == t and (q is None or code.q == q):
107
+ coords.append((s, code))
108
+ return coords
109
+
110
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
111
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
112
+
113
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
114
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
115
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
116
+
117
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
118
+ device: tp.Union[torch.device, str] = 'cpu'):
119
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
120
+
121
+ Args:
122
+ timesteps (int): Maximum number of timesteps steps to consider.
123
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
124
+ device (Union[torch.device, str]): Device for created tensors.
125
+ Returns:
126
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128
+ """
129
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
132
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
133
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
134
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
135
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
136
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
137
+ # fill indexes with last sequence step value that will correspond to our special token
138
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
139
+ # which will correspond to the index: n_q * timesteps
140
+ indexes[:] = n_q * timesteps
141
+ # iterate over the pattern and fill scattered indexes and mask
142
+ for s, sequence_coords in enumerate(ref_layout):
143
+ for coords in sequence_coords:
144
+ if coords.t < timesteps:
145
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
146
+ mask[coords.q, s] = 1
147
+ indexes = torch.from_numpy(indexes).to(device)
148
+ mask = torch.from_numpy(mask).to(device)
149
+ return indexes, mask
150
+
151
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
152
+ """Build sequence corresponding to the pattern from the input tensor z.
153
+ The sequence is built using up to sequence_steps if specified, and non-pattern
154
+ coordinates are filled with the special token.
155
+
156
+ Args:
157
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
158
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
159
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
160
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
161
+ Returns:
162
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
163
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
164
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
165
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
166
+ """
167
+ B, K, T = z.shape
168
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
169
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
170
+ )
171
+ z = z.view(B, -1)
172
+ # we append the special token as the last index of our flattened z tensor
173
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
174
+ values = z[:, indexes.view(-1)]
175
+ values = values.view(B, K, indexes.shape[-1])
176
+ return values, indexes, mask
177
+
178
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
179
+ keep_only_valid_steps: bool = False,
180
+ is_model_output: bool = False,
181
+ device: tp.Union[torch.device, str] = 'cpu'):
182
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
183
+ from interleaving pattern.
184
+
185
+ Args:
186
+ sequence_steps (int): Sequence steps.
187
+ n_q (int): Number of codebooks.
188
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
189
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
190
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
191
+ device (Union[torch.device, str]): Device for created tensors.
192
+ Returns:
193
+ torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
194
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
195
+ """
196
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198
+ timesteps = self.timesteps
199
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200
+ assert sequence_steps <= len(ref_layout), \
201
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
+
203
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
204
+ if is_model_output:
205
+ ref_layout = ref_layout[1:]
206
+
207
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
208
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
209
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
210
+ # fill indexes with last sequence step value that will correspond to our special token
211
+ indexes[:] = n_q * sequence_steps
212
+ for s, sequence_codes in enumerate(ref_layout):
213
+ if s < sequence_steps:
214
+ for code in sequence_codes:
215
+ if code.t < timesteps:
216
+ indexes[code.q, code.t] = s + code.q * sequence_steps
217
+ mask[code.q, code.t] = 1
218
+ indexes = torch.from_numpy(indexes).to(device)
219
+ mask = torch.from_numpy(mask).to(device)
220
+ return indexes, mask
221
+
222
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
223
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
224
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
225
+ are filled with the special token.
226
+
227
+ Args:
228
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
229
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
230
+ Returns:
231
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
232
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
233
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
234
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
+ """
236
+ B, K, S = s.shape
237
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
238
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
239
+ )
240
+ s = s.view(B, -1)
241
+ # we append the special token as the last index of our flattened z tensor
242
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
243
+ values = s[:, indexes.view(-1)]
244
+ values = values.view(B, K, indexes.shape[-1])
245
+ return values, indexes, mask
246
+
247
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
248
+ """Revert model logits obtained on a sequence built from the pattern
249
+ back to a tensor matching the original sequence.
250
+
251
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
252
+ 1. It is designed to work with the extra cardinality dimension
253
+ 2. We return the logits for the first sequence item that matches the special_token and
254
+ which matching target in the original sequence is the first item of the sequence,
255
+ while we skip the last logits as there is no matching target
256
+ """
257
+ B, card, K, S = logits.shape
258
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
259
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
260
+ )
261
+ logits = logits.reshape(B, card, -1)
262
+ # we append the special token as the last index of our flattened z tensor
263
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
264
+ values = logits[:, :, indexes.view(-1)]
265
+ values = values.view(B, card, K, indexes.shape[-1])
266
+ return values, indexes, mask
267
+
268
+
269
+ class CodebooksPatternProvider(ABC):
270
+ """Abstraction around providing pattern for interleaving codebooks.
271
+
272
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
273
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
274
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
275
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
276
+ can be used to construct a new sequence from the original codes respecting the specified
277
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
278
+ being a tuple with the original timestep and codebook to build the new sequence.
279
+ Note that all patterns must start with an empty list that is then used to insert a first
280
+ sequence step of special tokens in the newly generated sequence.
281
+
282
+ Args:
283
+ n_q (int): number of codebooks.
284
+ cached (bool): if True, patterns for a given length are cached. In general
285
+ that should be true for efficiency reason to avoid synchronization points.
286
+ """
287
+ def __init__(self, n_q: int, cached: bool = True):
288
+ assert n_q > 0
289
+ self.n_q = n_q
290
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
291
+
292
+ @abstractmethod
293
+ def get_pattern(self, timesteps: int) -> Pattern:
294
+ """Builds pattern with specific interleaving between codebooks.
295
+
296
+ Args:
297
+ timesteps (int): Total numer of timesteps.
298
+ """
299
+ raise NotImplementedError()
300
+
301
+
302
+ class DelayedPatternProvider(CodebooksPatternProvider):
303
+ """Provider for delayed pattern across delayed codebooks.
304
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
305
+ from different timesteps.
306
+
307
+ Example:
308
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
309
+ [[1, 2, 3, 4],
310
+ [1, 2, 3, 4],
311
+ [1, 2, 3, 4]]
312
+ The resulting sequence obtained from the returned pattern is:
313
+ [[S, 1, 2, 3, 4],
314
+ [S, S, 1, 2, 3],
315
+ [S, S, S, 1, 2]]
316
+ (with S being a special token)
317
+
318
+ Args:
319
+ n_q (int): Number of codebooks.
320
+ delays (Optional[List[int]]): Delay for each of the codebooks.
321
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
322
+ flatten_first (int): Flatten the first N timesteps.
323
+ empty_initial (int): Prepend with N empty list of coordinates.
324
+ """
325
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
326
+ flatten_first: int = 0, empty_initial: int = 0):
327
+ super().__init__(n_q)
328
+ if delays is None:
329
+ delays = list(range(n_q))
330
+ self.delays = delays
331
+ self.flatten_first = flatten_first
332
+ self.empty_initial = empty_initial
333
+ assert len(self.delays) == self.n_q
334
+ assert sorted(self.delays) == self.delays
335
+
336
+ def get_pattern(self, timesteps: int) -> Pattern:
337
+ out: PatternLayout = [[]]
338
+ max_delay = max(self.delays)
339
+ if self.empty_initial:
340
+ out += [[] for _ in range(self.empty_initial)]
341
+ if self.flatten_first:
342
+ for t in range(min(timesteps, self.flatten_first)):
343
+ for q in range(self.n_q):
344
+ out.append([LayoutCoord(t, q)])
345
+ for t in range(self.flatten_first, timesteps + max_delay):
346
+ v = []
347
+ for q, delay in enumerate(self.delays):
348
+ t_for_q = t - delay
349
+ if t_for_q >= self.flatten_first:
350
+ v.append(LayoutCoord(t_for_q, q))
351
+ out.append(v)
352
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
353
+
354
+
355
+ class ParallelPatternProvider(DelayedPatternProvider):
356
+ """Provider for parallel pattern across codebooks.
357
+ This pattern provider is a special case of the delayed pattern with actually no delay,
358
+ hence delays=repeat(0, n_q).
359
+
360
+ Args:
361
+ n_q (int): Number of codebooks.
362
+ """
363
+ def __init__(self, n_q: int):
364
+ super().__init__(n_q, [0] * n_q)
365
+
366
+
367
+ class UnrolledPatternProvider(CodebooksPatternProvider):
368
+ """Provider for unrolling codebooks pattern.
369
+ This pattern provider enables to represent the codebook flattened completely or only to some extend
370
+ while also specifying a given delay between the flattened codebooks representation, allowing to
371
+ unroll the codebooks in the sequence.
372
+
373
+ Example:
374
+ 1. Flattening of the codebooks.
375
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
376
+ taking n_q = 3 and timesteps = 4:
377
+ [[1, 2, 3, 4],
378
+ [1, 2, 3, 4],
379
+ [1, 2, 3, 4]]
380
+ will result into:
381
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
382
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
383
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
384
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
385
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
386
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
387
+ [[1, 2, 3, 4],
388
+ [1, 2, 3, 4],
389
+ [1, 2, 3, 4]]
390
+ will result into:
391
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
392
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
393
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
394
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
395
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
396
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
397
+ and delays = [0, 3, 3]:
398
+ [[1, 2, 3, 4],
399
+ [1, 2, 3, 4],
400
+ [1, 2, 3, 4]]
401
+ will result into:
402
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
403
+ [S, S, S, 1, S, 2, S, 3, S, 4],
404
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
405
+
406
+ Args:
407
+ n_q (int): Number of codebooks.
408
+ flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
409
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
410
+ have n_q extra steps for each timestep.
411
+ delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
412
+ no delay is added and therefore will default to [0] * ``n_q``.
413
+ Note that two codebooks that will be flattened to the same inner step
414
+ should have the same delay, otherwise the pattern is considered as invalid.
415
+ """
416
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
417
+
418
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
419
+ delays: tp.Optional[tp.List[int]] = None):
420
+ super().__init__(n_q)
421
+ if flattening is None:
422
+ flattening = list(range(n_q))
423
+ if delays is None:
424
+ delays = [0] * n_q
425
+ assert len(flattening) == n_q
426
+ assert len(delays) == n_q
427
+ assert sorted(flattening) == flattening
428
+ assert sorted(delays) == delays
429
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
430
+ self.max_delay = max(delays)
431
+
432
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
433
+ """Build a flattened codebooks representation as a dictionary of inner step
434
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
435
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
436
+ """
437
+ flattened_codebooks: dict = {}
438
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
439
+ if inner_step not in flattened_codebooks:
440
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
441
+ else:
442
+ flat_codebook = flattened_codebooks[inner_step]
443
+ assert flat_codebook.delay == delay, (
444
+ "Delay and flattening between codebooks is inconsistent: ",
445
+ "two codebooks flattened to the same position should have the same delay."
446
+ )
447
+ flat_codebook.codebooks.append(q)
448
+ flattened_codebooks[inner_step] = flat_codebook
449
+ return flattened_codebooks
450
+
451
+ @property
452
+ def _num_inner_steps(self):
453
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
454
+ """
455
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
456
+
457
+ def num_virtual_steps(self, timesteps: int) -> int:
458
+ return timesteps * self._num_inner_steps + 1
459
+
460
+ def get_pattern(self, timesteps: int) -> Pattern:
461
+ """Builds pattern for delay across codebooks.
462
+
463
+ Args:
464
+ timesteps (int): Total numer of timesteps.
465
+ """
466
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
467
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
468
+ indexed_out: list = [(-1, [])]
469
+ max_timesteps = timesteps + self.max_delay
470
+ for t in range(max_timesteps):
471
+ # for each timestep, we unroll the flattened codebooks,
472
+ # emitting the sequence step with the corresponding delay
473
+ for step in range(self._num_inner_steps):
474
+ if step in self._flattened_codebooks:
475
+ # we have codebooks at this virtual step to emit
476
+ step_codebooks = self._flattened_codebooks[step]
477
+ t_for_q = t + step_codebooks.delay
478
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
479
+ if t_for_q < max_timesteps and t < max_timesteps:
480
+ indexed_out.append((t_for_q, coords))
481
+ else:
482
+ # there is no codebook in this virtual step so we emit an empty list
483
+ indexed_out.append((t, []))
484
+ out = [coords for _, coords in sorted(indexed_out)]
485
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
486
+
487
+
488
+ class VALLEPattern(CodebooksPatternProvider):
489
+ """Almost VALL-E style pattern. We futher allow some delays for the
490
+ codebooks other than the first one.
491
+
492
+ Args:
493
+ n_q (int): Number of codebooks.
494
+ delays (Optional[List[int]]): Delay for each of the codebooks.
495
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
496
+ """
497
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
498
+ super().__init__(n_q)
499
+ if delays is None:
500
+ delays = [0] * (n_q - 1)
501
+ self.delays = delays
502
+ assert len(self.delays) == self.n_q - 1
503
+ assert sorted(self.delays) == self.delays
504
+
505
+ def get_pattern(self, timesteps: int) -> Pattern:
506
+ out: PatternLayout = [[]]
507
+ for t in range(timesteps):
508
+ out.append([LayoutCoord(t, 0)])
509
+ max_delay = max(self.delays)
510
+ for t in range(timesteps + max_delay):
511
+ v = []
512
+ for q, delay in enumerate(self.delays):
513
+ t_for_q = t - delay
514
+ if t_for_q >= 0:
515
+ v.append(LayoutCoord(t_for_q, q + 1))
516
+ out.append(v)
517
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
518
+
519
+
520
+ class MusicLMPattern(CodebooksPatternProvider):
521
+ """Almost MusicLM style pattern. This is equivalent to full flattening
522
+ but in a different order.
523
+
524
+ Args:
525
+ n_q (int): Number of codebooks.
526
+ group_by (int): Number of codebooks to group together.
527
+ """
528
+ def __init__(self, n_q: int, group_by: int = 2):
529
+ super().__init__(n_q)
530
+ self.group_by = group_by
531
+
532
+ def get_pattern(self, timesteps: int) -> Pattern:
533
+ out: PatternLayout = [[]]
534
+ for offset in range(0, self.n_q, self.group_by):
535
+ for t in range(timesteps):
536
+ for q in range(offset, offset + self.group_by):
537
+ out.append([LayoutCoord(t, q)])
538
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
models/modules/__init__.py ADDED
File without changes
models/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (140 Bytes). View file
 
models/modules/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (132 Bytes). View file
 
models/modules/__pycache__/activation.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
models/modules/__pycache__/activation.cpython-39.pyc ADDED
Binary file (18.8 kB). View file
 
models/modules/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
models/modules/__pycache__/embedding.cpython-39.pyc ADDED
Binary file (3.04 kB). View file
 
models/modules/__pycache__/scaling.cpython-310.pyc ADDED
Binary file (40.4 kB). View file
 
models/modules/__pycache__/scaling.cpython-39.pyc ADDED
Binary file (40 kB). View file
 
models/modules/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (16.1 kB). View file
 
models/modules/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (15.8 kB). View file
 
models/modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.42 kB). View file
 
models/modules/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.41 kB). View file
 
models/modules/__pycache__/visualizer.cpython-39.pyc ADDED
Binary file (2.02 kB). View file
 
models/modules/activation.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py, modified by Puyuan Peng, 2024
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+ import logging
12
+ from typing import Callable, List, Optional, Tuple, Union
13
+ from typing import TYPE_CHECKING
14
+ if TYPE_CHECKING:
15
+ from torch.types import _dtype as DType
16
+ else:
17
+ # The JIT doesn't understand Union, nor torch.dtype here
18
+ DType = int
19
+
20
+ def _canonical_mask(
21
+ mask: Optional[Tensor],
22
+ mask_name: str,
23
+ other_type: Optional[DType],
24
+ other_name: str,
25
+ target_type: DType,
26
+ check_other: bool = True,
27
+ ) -> Optional[Tensor]:
28
+
29
+ if mask is not None:
30
+ _mask_dtype = mask.dtype
31
+ _mask_is_float = torch.is_floating_point(mask)
32
+ if _mask_dtype != torch.bool and not _mask_is_float:
33
+ raise AssertionError(
34
+ f"only bool and floating types of {mask_name} are supported")
35
+ if check_other and other_type is not None:
36
+ if _mask_dtype != other_type:
37
+ warnings.warn(
38
+ f"Support for mismatched {mask_name} and {other_name} "
39
+ "is deprecated. Use same type for both instead."
40
+ )
41
+ if not _mask_is_float:
42
+ mask = (
43
+ torch.zeros_like(mask, dtype=target_type)
44
+ .masked_fill_(mask, float("-inf"))
45
+ )
46
+ return mask
47
+
48
+ def _in_projection_packed(
49
+ q: Tensor,
50
+ k: Tensor,
51
+ v: Tensor,
52
+ w: Tensor,
53
+ b: Optional[Tensor] = None,
54
+ ) -> List[Tensor]:
55
+ r"""
56
+ Performs the in-projection step of the attention operation, using packed weights.
57
+ Output is a triple containing projection tensors for query, key and value.
58
+
59
+ Args:
60
+ q, k, v: query, key and value tensors to be projected. For self-attention,
61
+ these are typically the same tensor; for encoder-decoder attention,
62
+ k and v are typically the same tensor. (We take advantage of these
63
+ identities for performance if they are present.) Regardless, q, k and v
64
+ must share a common embedding dimension; otherwise their shapes may vary.
65
+ w: projection weights for q, k and v, packed into a single tensor. Weights
66
+ are packed along dimension 0, in q, k, v order.
67
+ b: optional projection biases for q, k and v, packed into a single tensor
68
+ in q, k, v order.
69
+
70
+ Shape:
71
+ Inputs:
72
+ - q: :math:`(..., E)` where E is the embedding dimension
73
+ - k: :math:`(..., E)` where E is the embedding dimension
74
+ - v: :math:`(..., E)` where E is the embedding dimension
75
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
76
+ - b: :math:`E * 3` where E is the embedding dimension
77
+
78
+ Output:
79
+ - in output list :math:`[q', k', v']`, each output tensor will have the
80
+ same shape as the corresponding input tensor.
81
+ """
82
+ E = q.size(-1)
83
+ if k is v:
84
+ if q is k:
85
+ # self-attention
86
+ proj = F.linear(q, w, b)
87
+ # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
88
+ proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
89
+ return proj[0], proj[1], proj[2]
90
+ else:
91
+ # encoder-decoder attention
92
+ w_q, w_kv = w.split([E, E * 2])
93
+ if b is None:
94
+ b_q = b_kv = None
95
+ else:
96
+ b_q, b_kv = b.split([E, E * 2])
97
+ q_proj = F.linear(q, w_q, b_q)
98
+ kv_proj = F.linear(k, w_kv, b_kv)
99
+ # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as chunk()
100
+ kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
101
+ return (q_proj, kv_proj[0], kv_proj[1])
102
+ else:
103
+ w_q, w_k, w_v = w.chunk(3)
104
+ if b is None:
105
+ b_q = b_k = b_v = None
106
+ else:
107
+ b_q, b_k, b_v = b.chunk(3)
108
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
109
+
110
+ def _none_or_dtype(input: Optional[Tensor]) -> Optional[DType]:
111
+ if input is None:
112
+ return None
113
+ elif isinstance(input, torch.Tensor):
114
+ return input.dtype
115
+ raise RuntimeError("input to _none_or_dtype() must be None or torch.Tensor")
116
+ class MultiheadAttention(Module):
117
+ r"""Allows the model to jointly attend to information
118
+ from different representation subspaces as described in the paper:
119
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
120
+
121
+ Multi-Head Attention is defined as:
122
+
123
+ .. math::
124
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
125
+
126
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
127
+
128
+ ``forward()`` will use a special optimized implementation if all of the following
129
+ conditions are met:
130
+
131
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
132
+ restriction will be loosened in the future.)
133
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
134
+ - training is disabled (using ``.eval()``)
135
+ - dropout is 0
136
+ - ``add_bias_kv`` is ``False``
137
+ - ``add_zero_attn`` is ``False``
138
+ - ``batch_first`` is ``True`` and the input is batched
139
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
140
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
141
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
142
+ nor ``attn_mask`` is passed
143
+
144
+ If the optimized implementation is in use, a
145
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
146
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
147
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
148
+ will be returned, and an additional speedup proportional to the fraction of the input
149
+ that is padding can be expected.
150
+
151
+ Args:
152
+ embed_dim: Total dimension of the model.
153
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
154
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
155
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
156
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
157
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
158
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
159
+ Default: ``False``.
160
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
161
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
162
+ batch_first: If ``True``, then the input and output tensors are provided
163
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
164
+
165
+ Examples::
166
+
167
+ >>> # xdoctest: +SKIP
168
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
169
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
170
+
171
+ """
172
+ __constants__ = ["batch_first"]
173
+ bias_k: Optional[torch.Tensor]
174
+ bias_v: Optional[torch.Tensor]
175
+
176
+ def __init__(
177
+ self,
178
+ embed_dim,
179
+ num_heads,
180
+ dropout=0.0,
181
+ bias=True,
182
+ add_bias_kv=False,
183
+ add_zero_attn=False,
184
+ kdim=None,
185
+ vdim=None,
186
+ batch_first=False,
187
+ linear1_cls=Linear,
188
+ linear2_cls=Linear,
189
+ device=None,
190
+ dtype=None,
191
+ ) -> None:
192
+ factory_kwargs = {"device": device, "dtype": dtype}
193
+ super(MultiheadAttention, self).__init__()
194
+ self.embed_dim = embed_dim
195
+ self.kdim = kdim if kdim is not None else embed_dim
196
+ self.vdim = vdim if vdim is not None else embed_dim
197
+ self._qkv_same_embed_dim = (
198
+ self.kdim == embed_dim and self.vdim == embed_dim
199
+ )
200
+
201
+ self.num_heads = num_heads
202
+ self.dropout = dropout
203
+ self.batch_first = batch_first
204
+ self.head_dim = embed_dim // num_heads
205
+ assert (
206
+ self.head_dim * num_heads == self.embed_dim
207
+ ), "embed_dim must be divisible by num_heads"
208
+
209
+ if add_bias_kv:
210
+ self.bias_k = Parameter(
211
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
212
+ )
213
+ self.bias_v = Parameter(
214
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
215
+ )
216
+ else:
217
+ self.bias_k = self.bias_v = None
218
+
219
+ if linear1_cls == Linear:
220
+ if not self._qkv_same_embed_dim:
221
+ self.q_proj_weight = Parameter(
222
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
223
+ )
224
+ self.k_proj_weight = Parameter(
225
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
226
+ )
227
+ self.v_proj_weight = Parameter(
228
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
229
+ )
230
+ self.register_parameter("in_proj_weight", None)
231
+ else:
232
+ # go down this route with voicecraft
233
+ self.in_proj_weight = Parameter(
234
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
235
+ )
236
+ self.register_parameter("q_proj_weight", None)
237
+ self.register_parameter("k_proj_weight", None)
238
+ self.register_parameter("v_proj_weight", None)
239
+
240
+ if bias: # True by default
241
+ self.in_proj_bias = Parameter(
242
+ torch.empty(3 * embed_dim, **factory_kwargs)
243
+ )
244
+ else:
245
+ self.register_parameter("in_proj_bias", None)
246
+ self.out_proj = NonDynamicallyQuantizableLinear(
247
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
248
+ )
249
+
250
+ self._reset_parameters()
251
+ else:
252
+ if not self._qkv_same_embed_dim:
253
+ raise NotImplementedError
254
+ else:
255
+ self.in_proj_linear = linear1_cls(
256
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
257
+ )
258
+ self.in_proj_weight = self.in_proj_linear.weight
259
+
260
+ self.register_parameter("q_proj_weight", None)
261
+ self.register_parameter("k_proj_weight", None)
262
+ self.register_parameter("v_proj_weight", None)
263
+
264
+ if bias:
265
+ self.in_proj_bias = self.in_proj_linear.bias
266
+ else:
267
+ self.register_parameter("in_proj_bias", None)
268
+
269
+ self.out_proj = linear2_cls(
270
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
271
+ )
272
+
273
+ if self.bias_k is not None:
274
+ xavier_normal_(self.bias_k)
275
+ if self.bias_v is not None:
276
+ xavier_normal_(self.bias_v)
277
+
278
+ self.add_zero_attn = add_zero_attn
279
+
280
+ def _reset_parameters(self):
281
+ if self._qkv_same_embed_dim:
282
+ xavier_uniform_(self.in_proj_weight)
283
+ else:
284
+ xavier_uniform_(self.q_proj_weight)
285
+ xavier_uniform_(self.k_proj_weight)
286
+ xavier_uniform_(self.v_proj_weight)
287
+
288
+ if self.in_proj_bias is not None:
289
+ constant_(self.in_proj_bias, 0.0)
290
+ constant_(self.out_proj.bias, 0.0)
291
+
292
+ if self.bias_k is not None:
293
+ xavier_normal_(self.bias_k)
294
+ if self.bias_v is not None:
295
+ xavier_normal_(self.bias_v)
296
+
297
+ def __setstate__(self, state):
298
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
299
+ if "_qkv_same_embed_dim" not in state:
300
+ state["_qkv_same_embed_dim"] = True
301
+
302
+ super(MultiheadAttention, self).__setstate__(state)
303
+
304
+ def forward(
305
+ self,
306
+ query: Tensor,
307
+ key: Tensor,
308
+ value: Tensor,
309
+ key_padding_mask: Optional[Tensor] = None,
310
+ need_weights: bool = True,
311
+ attn_mask: Optional[Tensor] = None,
312
+ average_attn_weights: bool = True,
313
+ past: Optional[Tensor] = None,
314
+ ) -> Tuple[Tensor, Optional[Tensor]]:
315
+ r"""
316
+ Args:
317
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
318
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
319
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
320
+ Queries are compared against key-value pairs to produce the output.
321
+ See "Attention Is All You Need" for more details.
322
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
323
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
324
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
325
+ See "Attention Is All You Need" for more details.
326
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
327
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
328
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
329
+ See "Attention Is All You Need" for more details.
330
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
331
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
332
+ Binary and byte masks are supported.
333
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
334
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
335
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
336
+ Default: ``True``.
337
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
338
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
339
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
340
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
341
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
342
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
343
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
344
+ the attention weight.
345
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
346
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
347
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
348
+
349
+ Outputs:
350
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
351
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
352
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
353
+ embedding dimension ``embed_dim``.
354
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
355
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
356
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
357
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
358
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
359
+
360
+ .. note::
361
+ `batch_first` argument is ignored for unbatched inputs.
362
+ """
363
+ is_batched = query.dim() == 3
364
+ if key_padding_mask is not None:
365
+ _kpm_dtype = key_padding_mask.dtype
366
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
367
+ key_padding_mask
368
+ ):
369
+ raise AssertionError(
370
+ "only bool and floating types of key_padding_mask are supported"
371
+ )
372
+ why_not_fast_path = ""
373
+ if not is_batched:
374
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
375
+ elif query is not key or key is not value:
376
+ # When lifting this restriction, don't forget to either
377
+ # enforce that the dtypes all match or test cases where
378
+ # they don't!
379
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
380
+ elif (
381
+ self.in_proj_bias is not None
382
+ and query.dtype != self.in_proj_bias.dtype
383
+ ):
384
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
385
+ elif (
386
+ self.in_proj_weight is not None
387
+ and query.dtype != self.in_proj_weight.dtype
388
+ ):
389
+ # this case will fail anyway, but at least they'll get a useful error message.
390
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
391
+ elif self.training:
392
+ why_not_fast_path = "training is enabled"
393
+ elif not self.batch_first:
394
+ why_not_fast_path = "batch_first was not True"
395
+ elif self.bias_k is not None:
396
+ why_not_fast_path = "self.bias_k was not None"
397
+ elif self.bias_v is not None:
398
+ why_not_fast_path = "self.bias_v was not None"
399
+ elif self.dropout:
400
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
401
+ elif self.add_zero_attn:
402
+ why_not_fast_path = "add_zero_attn was enabled"
403
+ elif not self._qkv_same_embed_dim:
404
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
405
+ elif attn_mask is not None:
406
+ why_not_fast_path = "attn_mask was not None"
407
+ elif query.is_nested and key_padding_mask is not None:
408
+ why_not_fast_path = (
409
+ "key_padding_mask is not supported with NestedTensor input"
410
+ )
411
+ elif self.num_heads % 2 == 1:
412
+ why_not_fast_path = "num_heads is odd"
413
+ elif torch.is_autocast_enabled():
414
+ why_not_fast_path = "autocast is enabled"
415
+
416
+ if not why_not_fast_path:
417
+ tensor_args = (
418
+ query,
419
+ key,
420
+ value,
421
+ self.in_proj_weight,
422
+ self.in_proj_bias,
423
+ self.out_proj.weight,
424
+ self.out_proj.bias,
425
+ )
426
+ # We have to use list comprehensions below because TorchScript does not support
427
+ # generator expressions.
428
+ if torch.overrides.has_torch_function(tensor_args):
429
+ why_not_fast_path = "some Tensor argument has_torch_function"
430
+ elif not all(
431
+ [
432
+ (x is None or x.is_cuda or "cpu" in str(x.device))
433
+ for x in tensor_args
434
+ ]
435
+ ):
436
+ why_not_fast_path = (
437
+ "some Tensor argument is neither CUDA nor CPU"
438
+ )
439
+ elif torch.is_grad_enabled() and any(
440
+ [x is not None and x.requires_grad for x in tensor_args]
441
+ ):
442
+ why_not_fast_path = (
443
+ "grad is enabled and at least one of query or the "
444
+ "input/output projection weights or biases requires_grad"
445
+ )
446
+ if not why_not_fast_path:
447
+ return torch._native_multi_head_attention(
448
+ query,
449
+ key,
450
+ value,
451
+ self.embed_dim,
452
+ self.num_heads,
453
+ self.in_proj_weight,
454
+ self.in_proj_bias,
455
+ self.out_proj.weight,
456
+ self.out_proj.bias,
457
+ key_padding_mask
458
+ if key_padding_mask is not None
459
+ else attn_mask,
460
+ need_weights,
461
+ average_attn_weights,
462
+ 1
463
+ if key_padding_mask is not None
464
+ else 0
465
+ if attn_mask is not None
466
+ else None,
467
+ )
468
+
469
+ any_nested = query.is_nested or key.is_nested or value.is_nested
470
+ assert not any_nested, (
471
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
472
+ + f"The fast path was not hit because {why_not_fast_path}"
473
+ )
474
+
475
+ if self.batch_first and is_batched:
476
+ # make sure that the transpose op does not affect the "is" property
477
+ if key is value:
478
+ if query is key:
479
+ query = key = value = query.transpose(1, 0)
480
+ else:
481
+ query, key = [x.transpose(1, 0) for x in (query, key)]
482
+ value = key
483
+ else:
484
+ query, key, value = [
485
+ x.transpose(1, 0) for x in (query, key, value)
486
+ ]
487
+
488
+ if not self._qkv_same_embed_dim:
489
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
490
+ query,
491
+ key,
492
+ value,
493
+ self.embed_dim,
494
+ self.num_heads,
495
+ self.in_proj_weight,
496
+ self.in_proj_bias,
497
+ self.bias_k,
498
+ self.bias_v,
499
+ self.add_zero_attn,
500
+ self.dropout,
501
+ self.out_proj.weight,
502
+ self.out_proj.bias,
503
+ training=self.training,
504
+ key_padding_mask=key_padding_mask,
505
+ need_weights=need_weights,
506
+ attn_mask=attn_mask,
507
+ use_separate_proj_weight=True,
508
+ q_proj_weight=self.q_proj_weight,
509
+ k_proj_weight=self.k_proj_weight,
510
+ v_proj_weight=self.v_proj_weight,
511
+ average_attn_weights=average_attn_weights,
512
+ )
513
+ else:
514
+ # re-write the self.attention here, to get k, v cache
515
+ tgt_len, bsz, embed_dim = query.shape
516
+ src_len, _, _ = key.shape
517
+ num_heads = self.num_heads
518
+ key_padding_mask = _canonical_mask(
519
+ mask=key_padding_mask,
520
+ mask_name="key_padding_mask",
521
+ other_type=_none_or_dtype(attn_mask),
522
+ other_name="attn_mask",
523
+ target_type=query.dtype
524
+ )
525
+ attn_mask = _canonical_mask(
526
+ mask=attn_mask,
527
+ mask_name="attn_mask",
528
+ other_type=None,
529
+ other_name="",
530
+ target_type=query.dtype,
531
+ check_other=False,
532
+ )
533
+ head_dim = self.embed_dim // self.num_heads
534
+ assert head_dim * self.num_heads == self.embed_dim, f"embed_dim {self.embed_dim} not divisible by num_heads {self.num_heads}"
535
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
536
+ q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
537
+ # k_present, v_present = k, v
538
+
539
+ #
540
+ # reshape q, k, v for multihead attention and make em batch first
541
+ #
542
+
543
+ q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
544
+ k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
545
+ v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) # (bsz * num_heads, src_len, head_dim)
546
+ src_len = k.size(1)
547
+ if past is not None and past.ndim > 2:
548
+ expected_src_len = src_len + past[0].shape[-2]
549
+ else:
550
+ expected_src_len = src_len
551
+
552
+
553
+ # ensure attn_mask's dim is 3
554
+ if attn_mask.dim() == 2:
555
+ correct_2d_size = (tgt_len, expected_src_len)
556
+ if attn_mask.shape != correct_2d_size:
557
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
558
+ attn_mask = attn_mask.unsqueeze(0)
559
+ elif attn_mask.dim() == 3:
560
+ correct_3d_size = (bsz * num_heads, tgt_len, expected_src_len)
561
+ if attn_mask.shape != correct_3d_size:
562
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
563
+ else:
564
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
565
+
566
+ if key_padding_mask is not None:
567
+ assert key_padding_mask.shape == (bsz, expected_src_len), \
568
+ f"expecting key_padding_mask shape of {(bsz, expected_src_len)}, but got {key_padding_mask.shape}"
569
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, expected_src_len). \
570
+ expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, expected_src_len)
571
+ if attn_mask is None:
572
+ attn_mask = key_padding_mask
573
+ else:
574
+ attn_mask = attn_mask + key_padding_mask
575
+
576
+ if not self.training:
577
+ dropout_p = 0.0
578
+ else:
579
+ dropout_p = self.dropout
580
+
581
+ if need_weights:
582
+ raise NotImplementedError("need_weights not implemented for voicecraft")
583
+ # B, Nt, E = q.shape
584
+ # q_scaled = q / math.sqrt(E)
585
+
586
+ # assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
587
+
588
+ # if attn_mask is not None:
589
+ # attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
590
+ # else:
591
+ # attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
592
+ # attn_output_weights = softmax(attn_output_weights, dim=-1)
593
+ # if dropout_p > 0.0:
594
+ # attn_output_weights = dropout(attn_output_weights, p=dropout_p)
595
+
596
+ # attn_output = torch.bmm(attn_output_weights, v)
597
+
598
+ # attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
599
+ # attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
600
+ # attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
601
+
602
+ # # optionally average attention weights over heads
603
+ # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
604
+ # if average_attn_weights:
605
+ # attn_output_weights = attn_output_weights.mean(dim=1)
606
+
607
+ # if not is_batched:
608
+ # # squeeze the output if input was unbatched
609
+ # attn_output = attn_output.squeeze(1)
610
+ # attn_output_weights = attn_output_weights.squeeze(0)
611
+ # return attn_output, attn_output_weights
612
+ else:
613
+ # attn_mask can be either (L,S) or (N*num_heads, L, S)
614
+ # if attn_mask's shape is (1, L, S) we need to unsqueeze to (1, 1, L, S)
615
+ # in order to match the input for SDPA of (N, num_heads, L, S)
616
+ if attn_mask is not None:
617
+ if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
618
+ attn_mask = attn_mask.unsqueeze(0)
619
+ else:
620
+ attn_mask = attn_mask.view(bsz, num_heads, -1, expected_src_len)
621
+
622
+ q = q.view(bsz, num_heads, tgt_len, head_dim)
623
+ k = k.view(bsz, num_heads, src_len, head_dim)
624
+ v = v.view(bsz, num_heads, src_len, head_dim)
625
+ # logging.info(f"shape of past: {past.shape}")
626
+ if past is not None:
627
+ present = torch.stack([k, v], dim=0) # (2, bsz, num_heads, src_len, head_dim)
628
+ if past.ndim > 2: # this means we use kvcache, otherwise we just pass in a placeholder, but not actually using kvcache
629
+ pk, pv = past
630
+ k = torch.cat([pk, k], dim=-2)
631
+ v = torch.cat([pv, v], dim=-2)
632
+ else:
633
+ present = None
634
+ attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal=False)
635
+ attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
636
+
637
+ attn_output = F.linear(attn_output, self.out_proj.weight, self.out_proj.bias)
638
+ attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
639
+ if not is_batched:
640
+ # squeeze the output if input was unbatched
641
+ attn_output = attn_output.squeeze(1)
642
+ # if self.training:
643
+ # return attn_output, None
644
+ # else:
645
+ # return (attn_output, present), None
646
+
647
+ # harded coded, the code do not support returning attn weigths yet
648
+ attn_output_weights=None
649
+ if self.batch_first and is_batched:
650
+ return attn_output.transpose(1, 0), present
651
+ else:
652
+ return attn_output, present
653
+
models/modules/embedding.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+
22
+ class TokenEmbedding(nn.Module):
23
+ def __init__(
24
+ self,
25
+ dim_model: int,
26
+ vocab_size: int,
27
+ dropout: float = 0.0,
28
+ ):
29
+ super().__init__()
30
+
31
+ self.vocab_size = vocab_size
32
+ self.dim_model = dim_model
33
+
34
+ self.dropout = torch.nn.Dropout(p=dropout)
35
+ self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
36
+
37
+ @property
38
+ def weight(self) -> torch.Tensor:
39
+ return self.word_embeddings.weight
40
+
41
+ def embedding(self, index: int) -> torch.Tensor:
42
+ return self.word_embeddings.weight[index : index + 1]
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ X = self.word_embeddings(x)
46
+ X = self.dropout(X)
47
+
48
+ return X
49
+
50
+
51
+ class SinePositionalEmbedding(nn.Module):
52
+ def __init__(
53
+ self,
54
+ dim_model: int,
55
+ dropout: float = 0.0,
56
+ scale: bool = False,
57
+ alpha: bool = False,
58
+ ):
59
+ super().__init__()
60
+ self.dim_model = dim_model
61
+ self.x_scale = math.sqrt(dim_model) if scale else 1.0
62
+ self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
63
+ self.dropout = torch.nn.Dropout(p=dropout)
64
+
65
+ self.reverse = False
66
+ self.pe = None
67
+ self.extend_pe(torch.tensor(0.0).expand(1, 4000))
68
+
69
+ def extend_pe(self, x):
70
+ """Reset the positional encodings."""
71
+ if self.pe is not None:
72
+ if self.pe.size(1) >= x.size(1):
73
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
74
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
75
+ return
76
+ pe = torch.zeros(x.size(1), self.dim_model)
77
+ if self.reverse:
78
+ position = torch.arange(
79
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
80
+ ).unsqueeze(1)
81
+ else:
82
+ position = torch.arange(
83
+ 0, x.size(1), dtype=torch.float32
84
+ ).unsqueeze(1)
85
+ div_term = torch.exp(
86
+ torch.arange(0, self.dim_model, 2, dtype=torch.float32)
87
+ * -(math.log(10000.0) / self.dim_model)
88
+ )
89
+ pe[:, 0::2] = torch.sin(position * div_term)
90
+ pe[:, 1::2] = torch.cos(position * div_term)
91
+ pe = pe.unsqueeze(0)
92
+ self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
93
+
94
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
95
+ self.extend_pe(x)
96
+ output = x.unsqueeze(-1) if x.ndim == 2 else x
97
+ output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
98
+ return self.dropout(output)
models/modules/sampling.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def top_k_top_p_filtering(
5
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
6
+ ):
7
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
8
+ Args:
9
+ logits: logits distribution shape (batch size, vocabulary size)
10
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
11
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
12
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
13
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
14
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
15
+ """
16
+ if top_k > 0:
17
+ top_k = min(
18
+ max(top_k, min_tokens_to_keep), logits.size(-1)
19
+ ) # Safety check
20
+ # Remove all tokens with a probability less than the last token of the top-k
21
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
22
+ logits[indices_to_remove] = filter_value
23
+
24
+ if top_p < 1.0:
25
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
26
+ cumulative_probs = torch.cumsum(
27
+ F.softmax(sorted_logits, dim=-1), dim=-1
28
+ )
29
+
30
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
31
+ sorted_indices_to_remove = cumulative_probs > top_p
32
+ if min_tokens_to_keep > 1:
33
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
34
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
35
+ # Shift the indices to the right to keep also the first token above the threshold
36
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
37
+ ..., :-1
38
+ ].clone()
39
+ sorted_indices_to_remove[..., 0] = 0
40
+
41
+ # scatter sorted tensors to original indexing
42
+ indices_to_remove = sorted_indices_to_remove.scatter(
43
+ 1, sorted_indices, sorted_indices_to_remove
44
+ )
45
+ logits[indices_to_remove] = filter_value
46
+ return logits
47
+
48
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
49
+ # temperature: (`optional`) float
50
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
51
+ # top_k: (`optional`) int
52
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
53
+ # top_p: (`optional`) float
54
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
55
+
56
+ # Temperature (higher temperature => more likely to sample low probability tokens)
57
+ if temperature != 1.0:
58
+ logits = logits / temperature
59
+ # Top-p/top-k filtering
60
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
61
+ # Sample
62
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
63
+ return token
models/modules/scaling.py ADDED
@@ -0,0 +1,1406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/scaling.py
2
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import collections
20
+ import logging
21
+ import random
22
+ import math
23
+ from functools import reduce
24
+ from itertools import repeat
25
+ from typing import Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch import Tensor
31
+ from torch.nn import Embedding as ScaledEmbedding
32
+
33
+ # from valle.utils import Transpose
34
+
35
+ class Transpose(nn.Identity):
36
+ """(N, T, D) -> (N, D, T)"""
37
+
38
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
39
+ return input.transpose(1, 2)
40
+
41
+ class ActivationBalancerFunction(torch.autograd.Function):
42
+ @staticmethod
43
+ def forward(
44
+ ctx,
45
+ x: Tensor,
46
+ scale_factor: Tensor,
47
+ sign_factor: Optional[Tensor],
48
+ channel_dim: int,
49
+ ) -> Tensor:
50
+ if channel_dim < 0:
51
+ channel_dim += x.ndim
52
+ ctx.channel_dim = channel_dim
53
+ xgt0 = x > 0
54
+ if sign_factor is None:
55
+ ctx.save_for_backward(xgt0, scale_factor)
56
+ else:
57
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
58
+ return x
59
+
60
+ @staticmethod
61
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
62
+ if len(ctx.saved_tensors) == 3:
63
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
64
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
65
+ scale_factor = scale_factor.unsqueeze(-1)
66
+ sign_factor = sign_factor.unsqueeze(-1)
67
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
68
+ else:
69
+ xgt0, scale_factor = ctx.saved_tensors
70
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
71
+ scale_factor = scale_factor.unsqueeze(-1)
72
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
73
+ neg_delta_grad = x_grad.abs() * factor
74
+ return (
75
+ x_grad - neg_delta_grad,
76
+ None,
77
+ None,
78
+ None,
79
+ )
80
+
81
+
82
+ def _compute_scale_factor(
83
+ x: Tensor,
84
+ channel_dim: int,
85
+ min_abs: float,
86
+ max_abs: float,
87
+ gain_factor: float,
88
+ max_factor: float,
89
+ ) -> Tensor:
90
+ if channel_dim < 0:
91
+ channel_dim += x.ndim
92
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
93
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
94
+
95
+ if min_abs == 0.0:
96
+ below_threshold = 0.0
97
+ else:
98
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
99
+ # x_abs)_mean , min_abs.
100
+ below_threshold = (
101
+ (min_abs - x_abs_mean) * (gain_factor / min_abs)
102
+ ).clamp(min=0, max=max_factor)
103
+
104
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
105
+ min=0, max=max_factor
106
+ )
107
+
108
+ return below_threshold - above_threshold
109
+
110
+
111
+ def _compute_sign_factor(
112
+ x: Tensor,
113
+ channel_dim: int,
114
+ min_positive: float,
115
+ max_positive: float,
116
+ gain_factor: float,
117
+ max_factor: float,
118
+ ) -> Tensor:
119
+ if channel_dim < 0:
120
+ channel_dim += x.ndim
121
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
122
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
123
+ if min_positive == 0.0:
124
+ factor1 = 0.0
125
+ else:
126
+ # 0 if proportion_positive >= min_positive, else can be
127
+ # as large as max_factor.
128
+ factor1 = (
129
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
130
+ ).clamp_(min=0, max=max_factor)
131
+
132
+ if max_positive == 1.0:
133
+ factor2 = 0.0
134
+ else:
135
+ # 0 if self.proportion_positive <= max_positive, else can be
136
+ # as large as -max_factor.
137
+ factor2 = (
138
+ (proportion_positive - max_positive)
139
+ * (gain_factor / (1.0 - max_positive))
140
+ ).clamp_(min=0, max=max_factor)
141
+ sign_factor = factor1 - factor2
142
+ # require min_positive != 0 or max_positive != 1:
143
+ assert not isinstance(sign_factor, float)
144
+ return sign_factor
145
+
146
+
147
+ class ActivationScaleBalancerFunction(torch.autograd.Function):
148
+ """
149
+ This object is used in class ActivationBalancer when the user specified
150
+ min_positive=0, max_positive=1, so there are no constraints on the signs
151
+ of the activations and only the absolute value has a constraint.
152
+ """
153
+
154
+ @staticmethod
155
+ def forward(
156
+ ctx,
157
+ x: Tensor,
158
+ sign_factor: Tensor,
159
+ scale_factor: Tensor,
160
+ channel_dim: int,
161
+ ) -> Tensor:
162
+ if channel_dim < 0:
163
+ channel_dim += x.ndim
164
+ ctx.channel_dim = channel_dim
165
+ xgt0 = x > 0
166
+ ctx.save_for_backward(xgt0, sign_factor, scale_factor)
167
+ return x
168
+
169
+ @staticmethod
170
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
171
+ xgt0, sign_factor, scale_factor = ctx.saved_tensors
172
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
173
+ sign_factor = sign_factor.unsqueeze(-1)
174
+ scale_factor = scale_factor.unsqueeze(-1)
175
+
176
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
177
+ neg_delta_grad = x_grad.abs() * factor
178
+ return (
179
+ x_grad - neg_delta_grad,
180
+ None,
181
+ None,
182
+ None,
183
+ )
184
+
185
+
186
+ class RandomClampFunction(torch.autograd.Function):
187
+ @staticmethod
188
+ def forward(
189
+ ctx,
190
+ x: Tensor,
191
+ min: Optional[float],
192
+ max: Optional[float],
193
+ prob: float,
194
+ reflect: float,
195
+ ) -> Tensor:
196
+ x_clamped = torch.clamp(x, min=min, max=max)
197
+ mask = torch.rand_like(x) < prob
198
+ ans = torch.where(mask, x_clamped, x)
199
+ if x.requires_grad:
200
+ ctx.save_for_backward(ans == x)
201
+ ctx.reflect = reflect
202
+ if reflect != 0.0:
203
+ ans = ans * (1.0 + reflect) - (x * reflect)
204
+ return ans
205
+
206
+ @staticmethod
207
+ def backward(
208
+ ctx, ans_grad: Tensor
209
+ ) -> Tuple[Tensor, None, None, None, None]:
210
+ (is_same,) = ctx.saved_tensors
211
+ x_grad = ans_grad * is_same.to(ans_grad.dtype)
212
+ reflect = ctx.reflect
213
+ if reflect != 0.0:
214
+ x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
215
+ return x_grad, None, None, None, None
216
+
217
+
218
+ def random_clamp(
219
+ x: Tensor,
220
+ min: Optional[float] = None,
221
+ max: Optional[float] = None,
222
+ prob: float = 0.5,
223
+ reflect: float = 0.0,
224
+ ):
225
+ return RandomClampFunction.apply(x, min, max, prob, reflect)
226
+
227
+
228
+ def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
229
+ """
230
+ A randomized way of casting a floating point value to half precision.
231
+ """
232
+ if x.dtype == torch.float16:
233
+ return x
234
+ x_abs = x.abs()
235
+ is_too_small = x_abs < min_abs
236
+ # for elements where is_too_small is true, random_val will contain +-min_abs with
237
+ # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
238
+ # for those elements].
239
+ random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
240
+ return torch.where(is_too_small, random_val, x).to(torch.float16)
241
+
242
+
243
+ class RandomGradFunction(torch.autograd.Function):
244
+ """
245
+ Does nothing in forward pass; in backward pass, gets rid of very small grads using
246
+ randomized approach that preserves expectations (intended to reduce roundoff).
247
+ """
248
+
249
+ @staticmethod
250
+ def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
251
+ ctx.min_abs = min_abs
252
+ return x
253
+
254
+ @staticmethod
255
+ def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
256
+ if ans_grad.dtype == torch.float16:
257
+ return (
258
+ random_cast_to_half(
259
+ ans_grad.to(torch.float32), min_abs=ctx.min_abs
260
+ ),
261
+ None,
262
+ )
263
+ else:
264
+ return ans_grad, None
265
+
266
+
267
+ class RandomGrad(torch.nn.Module):
268
+ """
269
+ Gets rid of very small gradients using an expectation-preserving method, intended to increase
270
+ accuracy of training when using amp (automatic mixed precision)
271
+ """
272
+
273
+ def __init__(self, min_abs: float = 5.0e-06):
274
+ super(RandomGrad, self).__init__()
275
+ self.min_abs = min_abs
276
+
277
+ def forward(self, x: Tensor):
278
+ if (
279
+ torch.jit.is_scripting()
280
+ or not self.training
281
+ or torch.jit.is_tracing()
282
+ ):
283
+ return x
284
+ else:
285
+ return RandomGradFunction.apply(x, self.min_abs)
286
+
287
+
288
+ class SoftmaxFunction(torch.autograd.Function):
289
+ """
290
+ Tries to handle half-precision derivatives in a randomized way that should
291
+ be more accurate for training than the default behavior.
292
+ """
293
+
294
+ @staticmethod
295
+ def forward(ctx, x: Tensor, dim: int):
296
+ ans = x.softmax(dim=dim)
297
+ # if x dtype is float16, x.softmax() returns a float32 because
298
+ # (presumably) that op does not support float16, and autocast
299
+ # is enabled.
300
+ if torch.is_autocast_enabled():
301
+ ans = ans.to(torch.float16)
302
+ ctx.save_for_backward(ans)
303
+ ctx.x_dtype = x.dtype
304
+ ctx.dim = dim
305
+ return ans
306
+
307
+ @staticmethod
308
+ def backward(ctx, ans_grad: Tensor):
309
+ (ans,) = ctx.saved_tensors
310
+ with torch.cuda.amp.autocast(enabled=False):
311
+ ans_grad = ans_grad.to(torch.float32)
312
+ ans = ans.to(torch.float32)
313
+ x_grad = ans_grad * ans
314
+ x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
315
+ return x_grad, None
316
+
317
+
318
+ def softmax(x: Tensor, dim: int):
319
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
320
+ return x.softmax(dim)
321
+
322
+ return SoftmaxFunction.apply(x, dim)
323
+
324
+
325
+ class MaxEigLimiterFunction(torch.autograd.Function):
326
+ @staticmethod
327
+ def forward(
328
+ ctx,
329
+ x: Tensor,
330
+ coeffs: Tensor,
331
+ direction: Tensor,
332
+ channel_dim: int,
333
+ grad_scale: float,
334
+ ) -> Tensor:
335
+ ctx.channel_dim = channel_dim
336
+ ctx.grad_scale = grad_scale
337
+ ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
338
+ return x
339
+
340
+ @staticmethod
341
+ def backward(ctx, x_grad, *args):
342
+ with torch.enable_grad():
343
+ (x_orig, coeffs, new_direction) = ctx.saved_tensors
344
+ x_orig.requires_grad = True
345
+ num_channels = x_orig.shape[ctx.channel_dim]
346
+ x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
347
+ new_direction.requires_grad = False
348
+ x = x - x.mean(dim=0)
349
+ x_var = (x ** 2).mean()
350
+ x_residual = x - coeffs * new_direction
351
+ x_residual_var = (x_residual ** 2).mean()
352
+ # `variance_proportion` is the proportion of the variance accounted for
353
+ # by the top eigen-direction. This is to be minimized.
354
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
355
+ variance_proportion.backward()
356
+ x_orig_grad = x_orig.grad
357
+ x_extra_grad = (
358
+ x_orig.grad
359
+ * ctx.grad_scale
360
+ * x_grad.norm()
361
+ / (x_orig_grad.norm() + 1.0e-20)
362
+ )
363
+ return x_grad + x_extra_grad.detach(), None, None, None, None
364
+
365
+
366
+ class BasicNorm(torch.nn.Module):
367
+ """
368
+ This is intended to be a simpler, and hopefully cheaper, replacement for
369
+ LayerNorm. The observation this is based on, is that Transformer-type
370
+ networks, especially with pre-norm, sometimes seem to set one of the
371
+ feature dimensions to a large constant value (e.g. 50), which "defeats"
372
+ the LayerNorm because the output magnitude is then not strongly dependent
373
+ on the other (useful) features. Presumably the weight and bias of the
374
+ LayerNorm are required to allow it to do this.
375
+
376
+ So the idea is to introduce this large constant value as an explicit
377
+ parameter, that takes the role of the "eps" in LayerNorm, so the network
378
+ doesn't have to do this trick. We make the "eps" learnable.
379
+
380
+ Args:
381
+ num_channels: the number of channels, e.g. 512.
382
+ channel_dim: the axis/dimension corresponding to the channel,
383
+ interprted as an offset from the input's ndim if negative.
384
+ shis is NOT the num_channels; it should typically be one of
385
+ {-2, -1, 0, 1, 2, 3}.
386
+ eps: the initial "epsilon" that we add as ballast in:
387
+ scale = ((input_vec**2).mean() + epsilon)**-0.5
388
+ Note: our epsilon is actually large, but we keep the name
389
+ to indicate the connection with conventional LayerNorm.
390
+ learn_eps: if true, we learn epsilon; if false, we keep it
391
+ at the initial value.
392
+ eps_min: float
393
+ eps_max: float
394
+ """
395
+
396
+ def __init__(
397
+ self,
398
+ num_channels: int,
399
+ channel_dim: int = -1, # CAUTION: see documentation.
400
+ eps: float = 0.25,
401
+ learn_eps: bool = True,
402
+ eps_min: float = -3.0,
403
+ eps_max: float = 3.0,
404
+ ) -> None:
405
+ super(BasicNorm, self).__init__()
406
+ self.num_channels = num_channels
407
+ self.channel_dim = channel_dim
408
+ if learn_eps:
409
+ self.eps = nn.Parameter(torch.tensor(eps).log().detach())
410
+ else:
411
+ self.register_buffer("eps", torch.tensor(eps).log().detach())
412
+ self.eps_min = eps_min
413
+ self.eps_max = eps_max
414
+
415
+ def forward(self, x: Tensor) -> Tensor:
416
+ assert x.shape[self.channel_dim] == self.num_channels
417
+ eps = self.eps
418
+ if self.training and random.random() < 0.25:
419
+ # with probability 0.25, in training mode, clamp eps between the min
420
+ # and max; this will encourage it to learn parameters within the
421
+ # allowed range by making parameters that are outside the allowed
422
+ # range noisy.
423
+
424
+ # gradients to allow the parameter to get back into the allowed region if it happens to exit it.
425
+ eps = eps.clamp(min=self.eps_min, max=self.eps_max)
426
+ scales = (
427
+ torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
428
+ ) ** -0.5
429
+ return x * scales
430
+
431
+
432
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
433
+ """
434
+ Behaves like a constructor of a modified version of nn.Linear
435
+ that gives an easy way to set the default initial parameter scale.
436
+
437
+ Args:
438
+ Accepts the standard args and kwargs that nn.Linear accepts
439
+ e.g. in_features, out_features, bias=False.
440
+
441
+ initial_scale: you can override this if you want to increase
442
+ or decrease the initial magnitude of the module's output
443
+ (affects the initialization of weight_scale and bias_scale).
444
+ Another option, if you want to do something like this, is
445
+ to re-initialize the parameters.
446
+ """
447
+ ans = nn.Linear(*args, **kwargs)
448
+ with torch.no_grad():
449
+ ans.weight[:] *= initial_scale
450
+ if ans.bias is not None:
451
+ torch.nn.init.uniform_(
452
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
453
+ )
454
+ return ans
455
+
456
+
457
+ def ScaledConv1d(
458
+ *args,
459
+ initial_scale: float = 1.0,
460
+ kernel_size: int = 3,
461
+ padding: str = "same",
462
+ **kwargs,
463
+ ) -> nn.Conv1d:
464
+ """
465
+ Behaves like a constructor of a modified version of nn.Conv1d
466
+ that gives an easy way to set the default initial parameter scale.
467
+
468
+ Args:
469
+ Accepts the standard args and kwargs that nn.Linear accepts
470
+ e.g. in_features, out_features, bias=False.
471
+
472
+ initial_scale: you can override this if you want to increase
473
+ or decrease the initial magnitude of the module's output
474
+ (affects the initialization of weight_scale and bias_scale).
475
+ Another option, if you want to do something like this, is
476
+ to re-initialize the parameters.
477
+ """
478
+ ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
479
+ with torch.no_grad():
480
+ ans.weight[:] *= initial_scale
481
+ if ans.bias is not None:
482
+ torch.nn.init.uniform_(
483
+ ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
484
+ )
485
+ return ans
486
+
487
+
488
+ def TransposeScaledConv1d(
489
+ *args,
490
+ initial_scale: float = 1.0,
491
+ kernel_size: int = 3,
492
+ padding: str = "same",
493
+ **kwargs,
494
+ ) -> nn.Sequential:
495
+ """
496
+ Transpose -> ScaledConv1d
497
+ """
498
+ return nn.Sequential(
499
+ Transpose(),
500
+ ScaledConv1d(
501
+ *args,
502
+ initial_scale=initial_scale,
503
+ kernel_size=kernel_size,
504
+ padding=padding,
505
+ **kwargs,
506
+ ),
507
+ )
508
+
509
+
510
+ def ScaledConv1dTranspose(
511
+ *args,
512
+ initial_scale: float = 1.0,
513
+ kernel_size: int = 3,
514
+ padding: str = "same",
515
+ **kwargs,
516
+ ) -> nn.Sequential:
517
+ """
518
+ Transpose -> ScaledConv1d
519
+ """
520
+ return nn.Sequential(
521
+ ScaledConv1d(
522
+ *args,
523
+ initial_scale=initial_scale,
524
+ kernel_size=kernel_size,
525
+ padding=padding,
526
+ **kwargs,
527
+ ),
528
+ Transpose(),
529
+ )
530
+
531
+
532
+ def TransposeConv1d(
533
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
534
+ ) -> nn.Sequential:
535
+ """
536
+ Transpose -> Conv1d
537
+ """
538
+ return nn.Sequential(
539
+ Transpose(),
540
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
541
+ )
542
+
543
+
544
+ def Conv1dTranspose(
545
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
546
+ ) -> nn.Sequential:
547
+ """
548
+ ScaledConv1d -> Transpose
549
+ """
550
+ return nn.Sequential(
551
+ nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
552
+ Transpose(),
553
+ )
554
+
555
+
556
+ class SRLinear(nn.Linear):
557
+ """https://arxiv.org/abs/2303.06296
558
+ Stabilizing Transformer Training by Preventing Attention Entropy Collapse
559
+ """
560
+
561
+ def __init__(self, in_features, out_features, bias=True, **kwargs):
562
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
563
+ self.register_buffer(
564
+ "u", nn.functional.normalize(torch.randn(in_features), dim=0)
565
+ )
566
+ with torch.no_grad():
567
+ sigma = self.get_sigma()
568
+ self.register_buffer("spectral_norm", sigma)
569
+ self.sigma = nn.Parameter(torch.ones(1))
570
+
571
+ def get_sigma(self):
572
+ with torch.no_grad():
573
+ u = self.u
574
+ v = self.weight.mv(u)
575
+ v = nn.functional.normalize(v, dim=0)
576
+ u = self.weight.T.mv(v)
577
+ u = nn.functional.normalize(u, dim=0)
578
+ self.u.data.copy_(u)
579
+ return torch.einsum("c,cd,d->", v, self.weight, u)
580
+
581
+ def get_weight(self):
582
+ sigma = self.get_sigma()
583
+ if self.training:
584
+ self.spectral_norm.data.copy_(sigma)
585
+ weight = (self.sigma / sigma) * self.weight
586
+ return weight
587
+
588
+ def forward(self, x):
589
+ return nn.functional.linear(x, self.get_weight(), self.bias)
590
+
591
+
592
+ class SRConv1d(SRLinear):
593
+ def __init__(
594
+ self,
595
+ in_features,
596
+ out_features,
597
+ kernel_size,
598
+ stride: int = 1,
599
+ padding: str = "same",
600
+ bias: bool = True,
601
+ **kwargs,
602
+ ):
603
+ in_features = in_features * kernel_size
604
+ super().__init__(in_features, out_features, bias=bias, **kwargs)
605
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
606
+ self.kernel_size = kernel_size
607
+ self.stride = stride
608
+ self.padding = padding
609
+
610
+ def forward(self, x):
611
+ in_features = self.in_features // self.kernel_size
612
+ weight = self.get_weight().view(
613
+ self.out_features, in_features, self.kernel_size
614
+ )
615
+ return nn.functional.conv1d(
616
+ x, weight, bias=self.bias, stride=self.stride, padding=self.padding
617
+ )
618
+
619
+
620
+ def TransposeSRConv1d(
621
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
622
+ ) -> nn.Sequential:
623
+ """
624
+ Transpose -> SRConv1d
625
+ """
626
+ return nn.Sequential(
627
+ Transpose(),
628
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
629
+ )
630
+
631
+
632
+ def SRConv1dTranspose(
633
+ *args, kernel_size: int = 3, padding: str = "same", **kwargs
634
+ ) -> nn.Sequential:
635
+ """
636
+ SRConv1d -> Transpose
637
+ """
638
+ return nn.Sequential(
639
+ SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
640
+ Transpose(),
641
+ )
642
+
643
+
644
+ class ActivationBalancer(torch.nn.Module):
645
+ """
646
+ Modifies the backpropped derivatives of a function to try to encourage, for
647
+ each channel, that it is positive at least a proportion `threshold` of the
648
+ time. It does this by multiplying negative derivative values by up to
649
+ (1+max_factor), and positive derivative values by up to (1-max_factor),
650
+ interpolated from 1 at the threshold to those extremal values when none
651
+ of the inputs are positive.
652
+
653
+ Args:
654
+ num_channels: the number of channels
655
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
656
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
657
+ min_positive: the minimum, per channel, of the proportion of the time
658
+ that (x > 0), below which we start to modify the derivatives.
659
+ max_positive: the maximum, per channel, of the proportion of the time
660
+ that (x > 0), above which we start to modify the derivatives.
661
+ max_factor: the maximum factor by which we modify the derivatives for
662
+ either the sign constraint or the magnitude constraint;
663
+ e.g. with max_factor=0.02, the the derivatives would be multiplied by
664
+ values in the range [0.98..1.02].
665
+ sign_gain_factor: determines the 'gain' with which we increase the
666
+ change in gradient once the constraints on min_positive and max_positive
667
+ are violated.
668
+ scale_gain_factor: determines the 'gain' with which we increase the
669
+ change in gradient once the constraints on min_abs and max_abs
670
+ are violated.
671
+ min_abs: the minimum average-absolute-value difference from the mean
672
+ value per channel, which we allow, before we start to modify
673
+ the derivatives to prevent this.
674
+ max_abs: the maximum average-absolute-value difference from the mean
675
+ value per channel, which we allow, before we start to modify
676
+ the derivatives to prevent this.
677
+ min_prob: determines the minimum probability with which we modify the
678
+ gradients for the {min,max}_positive and {min,max}_abs constraints,
679
+ on each forward(). This is done randomly to prevent all layers
680
+ from doing it at the same time. Early in training we may use
681
+ higher probabilities than this; it will decay to this value.
682
+ """
683
+
684
+ def __init__(
685
+ self,
686
+ num_channels: int,
687
+ channel_dim: int,
688
+ min_positive: float = 0.05,
689
+ max_positive: float = 0.95,
690
+ max_factor: float = 0.04,
691
+ sign_gain_factor: float = 0.01,
692
+ scale_gain_factor: float = 0.02,
693
+ min_abs: float = 0.2,
694
+ max_abs: float = 100.0,
695
+ min_prob: float = 0.1,
696
+ ):
697
+ super(ActivationBalancer, self).__init__()
698
+ self.num_channels = num_channels
699
+ self.channel_dim = channel_dim
700
+ self.min_positive = min_positive
701
+ self.max_positive = max_positive
702
+ self.max_factor = max_factor
703
+ self.min_abs = min_abs
704
+ self.max_abs = max_abs
705
+ self.min_prob = min_prob
706
+ self.sign_gain_factor = sign_gain_factor
707
+ self.scale_gain_factor = scale_gain_factor
708
+
709
+ # count measures how many times the forward() function has been called.
710
+ # We occasionally sync this to a tensor called `count`, that exists to
711
+ # make sure it is synced to disk when we load and save the model.
712
+ self.cpu_count = 0
713
+ self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
714
+
715
+ def forward(self, x: Tensor) -> Tensor:
716
+ if (
717
+ torch.jit.is_scripting()
718
+ or not x.requires_grad
719
+ or torch.jit.is_tracing()
720
+ ):
721
+ return _no_op(x)
722
+
723
+ count = self.cpu_count
724
+ self.cpu_count += 1
725
+
726
+ if random.random() < 0.01:
727
+ # Occasionally sync self.cpu_count with self.count.
728
+ # count affects the decay of 'prob'. don't do this on every iter,
729
+ # because syncing with the GPU is slow.
730
+ self.cpu_count = max(self.cpu_count, self.count.item())
731
+ self.count.fill_(self.cpu_count)
732
+
733
+ # the prob of doing some work exponentially decreases from 0.5 till it hits
734
+ # a floor at min_prob (==0.1, by default)
735
+ prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
736
+
737
+ if random.random() < prob:
738
+ sign_gain_factor = 0.5
739
+ if self.min_positive != 0.0 or self.max_positive != 1.0:
740
+ sign_factor = _compute_sign_factor(
741
+ x,
742
+ self.channel_dim,
743
+ self.min_positive,
744
+ self.max_positive,
745
+ gain_factor=self.sign_gain_factor / prob,
746
+ max_factor=self.max_factor,
747
+ )
748
+ else:
749
+ sign_factor = None
750
+
751
+ scale_factor = _compute_scale_factor(
752
+ x.detach(),
753
+ self.channel_dim,
754
+ min_abs=self.min_abs,
755
+ max_abs=self.max_abs,
756
+ gain_factor=self.scale_gain_factor / prob,
757
+ max_factor=self.max_factor,
758
+ )
759
+ return ActivationBalancerFunction.apply(
760
+ x,
761
+ scale_factor,
762
+ sign_factor,
763
+ self.channel_dim,
764
+ )
765
+ else:
766
+ return _no_op(x)
767
+
768
+
769
+ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
770
+ """
771
+ Returns x unmodified, but in backprop will put a penalty for the excess of
772
+ the absolute values of elements of x over the limit "limit". E.g. if
773
+ limit == 10.0, then if x has any values over 10 it will get a penalty.
774
+
775
+ Caution: the value of this penalty will be affected by grad scaling used
776
+ in automatic mixed precision training. For this reasons we use this,
777
+ it shouldn't really matter, or may even be helpful; we just use this
778
+ to disallow really implausible values of scores to be given to softmax.
779
+ """
780
+ x_sign = x.sign()
781
+ over_limit = (x.abs() - limit) > 0
782
+ # The following is a memory efficient way to penalize the absolute values of
783
+ # x that's over the limit. (The memory efficiency comes when you think
784
+ # about which items torch needs to cache for the autograd, and which ones it
785
+ # can throw away). The numerical value of aux_loss as computed here will
786
+ # actually be larger than it should be, by limit * over_limit.sum(), but it
787
+ # has the same derivative as the real aux_loss which is penalty * (x.abs() -
788
+ # limit).relu().
789
+ aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
790
+ # note: we don't do sum() here on aux)_loss, but it's as if we had done
791
+ # sum() due to how with_loss() works.
792
+ x = with_loss(x, aux_loss)
793
+ # you must use x for something, or this will be ineffective.
794
+ return x
795
+
796
+
797
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
798
+ if x.ndim == 2:
799
+ return x.diag()
800
+ else:
801
+ (batch, dim, dim) = x.shape
802
+ x = x.reshape(batch, dim * dim)
803
+ x = x[:, :: dim + 1]
804
+ assert x.shape == (batch, dim)
805
+ return x
806
+
807
+
808
+ def _whitening_metric(x: Tensor, num_groups: int):
809
+ """
810
+ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
811
+ of the centered feature covariance are the same within each group's covariance matrix
812
+ and also between groups.
813
+ Args:
814
+ x: a Tensor of shape (*, num_channels)
815
+ num_groups: the number of groups of channels, a number >=1 that divides num_channels
816
+ Returns:
817
+ Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
818
+ greater than 1.0 otherwise.
819
+ """
820
+ assert x.dtype != torch.float16
821
+ x = x.reshape(-1, x.shape[-1])
822
+ (num_frames, num_channels) = x.shape
823
+ assert num_channels % num_groups == 0
824
+ channels_per_group = num_channels // num_groups
825
+ x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
826
+ # x now has shape (num_groups, num_frames, channels_per_group)
827
+ # subtract the mean so we use the centered, not uncentered, covariance.
828
+ # My experience has been that when we "mess with the gradients" like this,
829
+ # it's better not do anything that tries to move the mean around, because
830
+ # that can easily cause instability.
831
+ x = x - x.mean(dim=1, keepdim=True)
832
+ # x_covar: (num_groups, channels_per_group, channels_per_group)
833
+ x_covar = torch.matmul(x.transpose(1, 2), x)
834
+ x_covar_mean_diag = _diag(x_covar).mean()
835
+ # the following expression is what we'd get if we took the matrix product
836
+ # of each covariance and measured the mean of its trace, i.e.
837
+ # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
838
+ x_covarsq_mean_diag = (x_covar ** 2).sum() / (
839
+ num_groups * channels_per_group
840
+ )
841
+ # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
842
+ metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
843
+ return metric
844
+
845
+
846
+ class WhiteningPenaltyFunction(torch.autograd.Function):
847
+ @staticmethod
848
+ def forward(
849
+ ctx,
850
+ x: Tensor,
851
+ num_groups: int,
852
+ whitening_limit: float,
853
+ grad_scale: float,
854
+ ) -> Tensor:
855
+ ctx.save_for_backward(x)
856
+ ctx.num_groups = num_groups
857
+ ctx.whitening_limit = whitening_limit
858
+ ctx.grad_scale = grad_scale
859
+ return x
860
+
861
+ @staticmethod
862
+ def backward(ctx, x_grad: Tensor):
863
+ (x_orig,) = ctx.saved_tensors
864
+ with torch.enable_grad():
865
+ with torch.cuda.amp.autocast(enabled=False):
866
+ x_detached = x_orig.to(torch.float32).detach()
867
+ x_detached.requires_grad = True
868
+
869
+ metric = _whitening_metric(x_detached, ctx.num_groups)
870
+
871
+ if random.random() < 0.005 or __name__ == "__main__":
872
+ logging.info(
873
+ f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
874
+ f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
875
+ )
876
+
877
+ (metric - ctx.whitening_limit).relu().backward()
878
+ penalty_grad = x_detached.grad
879
+ scale = ctx.grad_scale * (
880
+ x_grad.to(torch.float32).norm()
881
+ / (penalty_grad.norm() + 1.0e-20)
882
+ )
883
+ penalty_grad = penalty_grad * scale
884
+ return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
885
+
886
+
887
+ class Whiten(nn.Module):
888
+ def __init__(
889
+ self,
890
+ num_groups: int,
891
+ whitening_limit: float,
892
+ prob: Union[float, Tuple[float, float]],
893
+ grad_scale: float,
894
+ ):
895
+ """
896
+ Args:
897
+ num_groups: the number of groups to divide the channel dim into before
898
+ whitening. We will attempt to make the feature covariance
899
+ within each group, after mean subtraction, as "white" as possible,
900
+ while having the same trace across all groups.
901
+ whitening_limit: a value greater than 1.0, that dictates how much
902
+ freedom we have to violate the constraints. 1.0 would mean perfectly
903
+ white, with exactly the same trace across groups; larger values
904
+ give more freedom. E.g. 2.0.
905
+ prob: the probability with which we apply the gradient modification
906
+ (also affects the grad scale). May be supplied as a float,
907
+ or as a pair (min_prob, max_prob)
908
+
909
+ grad_scale: determines the scale on the gradient term from this object,
910
+ relative to the rest of the gradient on the attention weights.
911
+ E.g. 0.02 (you may want to use smaller values than this if prob is large)
912
+ """
913
+ super(Whiten, self).__init__()
914
+ assert num_groups >= 1
915
+ assert whitening_limit >= 1
916
+ assert grad_scale >= 0
917
+ self.num_groups = num_groups
918
+ self.whitening_limit = whitening_limit
919
+ if isinstance(prob, float):
920
+ assert 0 < prob <= 1
921
+ self.prob = prob
922
+ else:
923
+ (self.min_prob, self.max_prob) = prob
924
+ assert 0 < self.min_prob < self.max_prob <= 1
925
+ self.prob = self.max_prob
926
+
927
+ self.grad_scale = grad_scale
928
+
929
+ def forward(self, x: Tensor) -> Tensor:
930
+ """
931
+ In the forward pass, this function just returns the input unmodified.
932
+ In the backward pass, it will modify the gradients to ensure that the
933
+ distribution in each group has close to (lambda times I) as the covariance
934
+ after mean subtraction, with the same lambda across groups.
935
+ For whitening_limit > 1, there will be more freedom to violate this
936
+ constraint.
937
+
938
+ Args:
939
+ x: the input of shape (*, num_channels)
940
+
941
+ Returns:
942
+ x, unmodified. You should make sure
943
+ you use the returned value, or the graph will be freed
944
+ and nothing will happen in backprop.
945
+ """
946
+ if (
947
+ not x.requires_grad
948
+ or random.random() > self.prob
949
+ or self.grad_scale == 0
950
+ ):
951
+ return _no_op(x)
952
+ else:
953
+ if hasattr(self, "min_prob") and random.random() < 0.25:
954
+ # occasionally switch between min_prob and max_prob, based on whether
955
+ # we are above or below the threshold.
956
+ if (
957
+ _whitening_metric(x.to(torch.float32), self.num_groups)
958
+ > self.whitening_limit
959
+ ):
960
+ # there would be a change to the grad.
961
+ self.prob = self.max_prob
962
+ else:
963
+ self.prob = self.min_prob
964
+
965
+ return WhiteningPenaltyFunction.apply(
966
+ x, self.num_groups, self.whitening_limit, self.grad_scale
967
+ )
968
+
969
+
970
+ class WithLoss(torch.autograd.Function):
971
+ @staticmethod
972
+ def forward(ctx, x: Tensor, y: Tensor):
973
+ ctx.y_shape = y.shape
974
+ return x
975
+
976
+ @staticmethod
977
+ def backward(ctx, ans_grad: Tensor):
978
+ return ans_grad, torch.ones(
979
+ ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
980
+ )
981
+
982
+
983
+ def with_loss(x, y):
984
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
985
+ return x
986
+ # returns x but adds y.sum() to the loss function.
987
+ return WithLoss.apply(x, y)
988
+
989
+
990
+ def _no_op(x: Tensor) -> Tensor:
991
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
992
+ return x
993
+ else:
994
+ # a no-op function that will have a node in the autograd graph,
995
+ # to avoid certain bugs relating to backward hooks
996
+ return x.chunk(1, dim=-1)[0]
997
+
998
+
999
+ class Identity(torch.nn.Module):
1000
+ def __init__(self):
1001
+ super(Identity, self).__init__()
1002
+
1003
+ def forward(self, x):
1004
+ return _no_op(x)
1005
+
1006
+
1007
+ class MaxEig(torch.nn.Module):
1008
+ """
1009
+ Modifies the backpropped derivatives of a function to try to discourage
1010
+ that any given direction in activation space accounts for more than
1011
+ a specified proportion of the covariance (e.g. 0.2).
1012
+
1013
+
1014
+ Args:
1015
+ num_channels: the number of channels
1016
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
1017
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
1018
+ max_var_per_eig: the maximum proportion of the variance of the
1019
+ features/channels, after mean subtraction, that can come from
1020
+ any given eigenvalue.
1021
+ min_prob: the minimum probability with which we apply this during any invocation
1022
+ of forward(), assuming last time we applied the constraint it was
1023
+ not active; supplied for speed.
1024
+ scale: determines the scale with which we modify the gradients, relative
1025
+ to the existing / unmodified gradients
1026
+ """
1027
+
1028
+ def __init__(
1029
+ self,
1030
+ num_channels: int,
1031
+ channel_dim: int,
1032
+ max_var_per_eig: float = 0.2,
1033
+ min_prob: float = 0.01,
1034
+ scale: float = 0.01,
1035
+ ):
1036
+ super(MaxEig, self).__init__()
1037
+ self.num_channels = num_channels
1038
+ self.channel_dim = channel_dim
1039
+ self.scale = scale
1040
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1041
+ self.max_var_per_eig = max_var_per_eig
1042
+
1043
+ # we figure out the dominant direction using the power method: starting with
1044
+ # a random vector, keep multiplying by the covariance and renormalizing.
1045
+ with torch.no_grad():
1046
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1047
+ # random parameters unchanged for comparison
1048
+ direction = torch.arange(num_channels).to(torch.float)
1049
+ direction = direction / direction.norm()
1050
+ self.register_buffer("max_eig_direction", direction)
1051
+
1052
+ self.min_prob = min_prob
1053
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1054
+ # We'll regress this towards prob, each tiem we try to apply it and it is not
1055
+ # active.
1056
+ self.cur_prob = 1.0
1057
+
1058
+ def forward(self, x: Tensor) -> Tensor:
1059
+ if (
1060
+ torch.jit.is_scripting()
1061
+ or self.max_var_per_eig <= 0
1062
+ or random.random() > self.cur_prob
1063
+ or torch.jit.is_tracing()
1064
+ ):
1065
+ return _no_op(x)
1066
+
1067
+ with torch.cuda.amp.autocast(enabled=False):
1068
+ eps = 1.0e-20
1069
+ orig_x = x
1070
+ x = x.to(torch.float32)
1071
+ with torch.no_grad():
1072
+ x = x.transpose(self.channel_dim, -1).reshape(
1073
+ -1, self.num_channels
1074
+ )
1075
+ x = x - x.mean(dim=0)
1076
+ new_direction, coeffs = self._find_direction_coeffs(
1077
+ x, self.max_eig_direction
1078
+ )
1079
+ x_var = (x ** 2).mean()
1080
+ x_residual = x - coeffs * new_direction
1081
+ x_residual_var = (x_residual ** 2).mean()
1082
+
1083
+ # `variance_proportion` is the proportion of the variance accounted for
1084
+ # by the top eigen-direction.
1085
+ variance_proportion = (x_var - x_residual_var) / (
1086
+ x_var + 1.0e-20
1087
+ )
1088
+
1089
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1090
+ self._set_direction(
1091
+ 0.1 * self.max_eig_direction + new_direction
1092
+ )
1093
+
1094
+ if random.random() < 0.01 or __name__ == "__main__":
1095
+ logging.info(
1096
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1097
+ )
1098
+
1099
+ if variance_proportion >= self.max_var_per_eig:
1100
+ # The constraint is active. Note, we should quite rarely
1101
+ # reach here, only near the beginning of training if we are
1102
+ # starting to diverge, should this constraint be active.
1103
+ cur_prob = self.cur_prob
1104
+ self.cur_prob = (
1105
+ 1.0 # next time, do the update with probability 1.0.
1106
+ )
1107
+ return MaxEigLimiterFunction.apply(
1108
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1109
+ )
1110
+ else:
1111
+ # let self.cur_prob exponentially approach self.min_prob, as
1112
+ # long as the constraint is inactive.
1113
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1114
+ return orig_x
1115
+
1116
+ def _set_direction(self, direction: Tensor):
1117
+ """
1118
+ Sets self.max_eig_direction to a normalized version of `direction`
1119
+ """
1120
+ direction = direction.detach()
1121
+ direction = direction / direction.norm()
1122
+ direction_sum = direction.sum().item()
1123
+ if direction_sum - direction_sum == 0: # no inf/nan
1124
+ self.max_eig_direction[:] = direction
1125
+ else:
1126
+ logging.info(
1127
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1128
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1129
+ )
1130
+
1131
+ def _find_direction_coeffs(
1132
+ self, x: Tensor, prev_direction: Tensor
1133
+ ) -> Tuple[Tensor, Tensor, Tensor]:
1134
+ """
1135
+ Figure out (an approximation to) the proportion of the variance of a set of
1136
+ feature vectors that can be attributed to the top eigen-direction.
1137
+ Args:
1138
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1139
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1140
+ of the top eigen-direction, or a random direction if this is the first
1141
+ iteration. Does not have to be normalized, but should be nonzero.
1142
+
1143
+ Returns: (cur_direction, coeffs), where:
1144
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1145
+ estimate of the top eigen-direction.
1146
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1147
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1148
+ """
1149
+ (num_frames, num_channels) = x.shape
1150
+ assert num_channels > 1 and num_frames > 1
1151
+ assert prev_direction.shape == (num_channels,)
1152
+ # `coeffs` are the coefficients of `prev_direction` in x.
1153
+ # actually represent the coeffs up to a constant positive factor.
1154
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1155
+ cur_direction = (x * coeffs).sum(dim=0) / (
1156
+ (coeffs ** 2).sum() + 1.0e-20
1157
+ )
1158
+ return cur_direction, coeffs
1159
+
1160
+
1161
+ class DoubleSwishFunction(torch.autograd.Function):
1162
+ """
1163
+ double_swish(x) = x * torch.sigmoid(x-1)
1164
+ This is a definition, originally motivated by its close numerical
1165
+ similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
1166
+
1167
+ Memory-efficient derivative computation:
1168
+ double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
1169
+ double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
1170
+ Now, s'(x) = s(x) * (1-s(x)).
1171
+ double_swish'(x) = x * s'(x) + s(x).
1172
+ = x * s(x) * (1-s(x)) + s(x).
1173
+ = double_swish(x) * (1-s(x)) + s(x)
1174
+ ... so we just need to remember s(x) but not x itself.
1175
+ """
1176
+
1177
+ @staticmethod
1178
+ def forward(ctx, x: Tensor) -> Tensor:
1179
+ requires_grad = x.requires_grad
1180
+ x_dtype = x.dtype
1181
+ if x.dtype == torch.float16:
1182
+ x = x.to(torch.float32)
1183
+
1184
+ s = torch.sigmoid(x - 1.0)
1185
+ y = x * s
1186
+
1187
+ if requires_grad:
1188
+ deriv = y * (1 - s) + s
1189
+ # notes on derivative of x * sigmoid(x - 1):
1190
+ # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
1191
+ # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund
1192
+ # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
1193
+ # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
1194
+ # floors), should be expectation-preserving.
1195
+ floor = -0.043637
1196
+ ceil = 1.2
1197
+ d_scaled = (deriv - floor) * (
1198
+ 255.0 / (ceil - floor)
1199
+ ) + torch.rand_like(deriv)
1200
+ if __name__ == "__main__":
1201
+ # for self-testing only.
1202
+ assert d_scaled.min() >= 0.0
1203
+ assert d_scaled.max() < 256.0
1204
+ d_int = d_scaled.to(torch.uint8)
1205
+ ctx.save_for_backward(d_int)
1206
+ if x.dtype == torch.float16 or torch.is_autocast_enabled():
1207
+ y = y.to(torch.float16)
1208
+ return y
1209
+
1210
+ @staticmethod
1211
+ def backward(ctx, y_grad: Tensor) -> Tensor:
1212
+ (d,) = ctx.saved_tensors
1213
+ # the same constants as used in forward pass.
1214
+ floor = -0.043637
1215
+ ceil = 1.2
1216
+ d = d * ((ceil - floor) / 255.0) + floor
1217
+ return y_grad * d
1218
+
1219
+
1220
+ class DoubleSwish(torch.nn.Module):
1221
+ def forward(self, x: Tensor) -> Tensor:
1222
+ """Return double-swish activation function which is an approximation to Swish(Swish(x)),
1223
+ that we approximate closely with x * sigmoid(x-1).
1224
+ """
1225
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
1226
+ return x * torch.sigmoid(x - 1.0)
1227
+ return DoubleSwishFunction.apply(x)
1228
+
1229
+
1230
+ def BalancedDoubleSwish(
1231
+ d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
1232
+ ) -> nn.Sequential:
1233
+ """
1234
+ ActivationBalancer -> DoubleSwish
1235
+ """
1236
+ balancer = ActivationBalancer(
1237
+ d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
1238
+ )
1239
+ return nn.Sequential(
1240
+ balancer,
1241
+ DoubleSwish(),
1242
+ )
1243
+
1244
+
1245
+ def _test_max_eig():
1246
+ for proportion in [0.1, 0.5, 10.0]:
1247
+ logging.info(f"proportion = {proportion}")
1248
+ x = torch.randn(100, 128)
1249
+ direction = torch.randn(128)
1250
+ coeffs = torch.randn(100, 1)
1251
+ x += proportion * direction * coeffs
1252
+
1253
+ x.requires_grad = True
1254
+
1255
+ num_channels = 128
1256
+ m = MaxEig(
1257
+ num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
1258
+ ) # grad_scale
1259
+
1260
+ for _ in range(4):
1261
+ y = m(x)
1262
+
1263
+ y_grad = torch.randn_like(x)
1264
+ y.backward(gradient=y_grad)
1265
+
1266
+ if proportion < 0.2:
1267
+ assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
1268
+ elif proportion > 1.0:
1269
+ assert not torch.allclose(x.grad, y_grad)
1270
+
1271
+
1272
+ def _test_whiten():
1273
+ for proportion in [0.1, 0.5, 10.0]:
1274
+ logging.info(f"_test_whiten(): proportion = {proportion}")
1275
+ x = torch.randn(100, 128)
1276
+ direction = torch.randn(128)
1277
+ coeffs = torch.randn(100, 1)
1278
+ x += proportion * direction * coeffs
1279
+
1280
+ x.requires_grad = True
1281
+
1282
+ num_channels = 128
1283
+ m = Whiten(
1284
+ 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
1285
+ ) # grad_scale
1286
+
1287
+ for _ in range(4):
1288
+ y = m(x)
1289
+
1290
+ y_grad = torch.randn_like(x)
1291
+ y.backward(gradient=y_grad)
1292
+
1293
+ if proportion < 0.2:
1294
+ assert torch.allclose(x.grad, y_grad)
1295
+ elif proportion > 1.0:
1296
+ assert not torch.allclose(x.grad, y_grad)
1297
+
1298
+
1299
+ def _test_activation_balancer_sign():
1300
+ probs = torch.arange(0, 1, 0.01)
1301
+ N = 1000
1302
+ x = 1.0 * (
1303
+ (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
1304
+ )
1305
+ x = x.detach()
1306
+ x.requires_grad = True
1307
+ m = ActivationBalancer(
1308
+ probs.numel(),
1309
+ channel_dim=0,
1310
+ min_positive=0.05,
1311
+ max_positive=0.95,
1312
+ max_factor=0.2,
1313
+ min_abs=0.0,
1314
+ )
1315
+
1316
+ y_grad = torch.sign(torch.randn(probs.numel(), N))
1317
+
1318
+ y = m(x)
1319
+ y.backward(gradient=y_grad)
1320
+ print("_test_activation_balancer_sign: x = ", x)
1321
+ print("_test_activation_balancer_sign: y grad = ", y_grad)
1322
+ print("_test_activation_balancer_sign: x grad = ", x.grad)
1323
+
1324
+
1325
+ def _test_activation_balancer_magnitude():
1326
+ magnitudes = torch.arange(0, 1, 0.01)
1327
+ N = 1000
1328
+ x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
1329
+ -1
1330
+ )
1331
+ x = x.detach()
1332
+ x.requires_grad = True
1333
+ m = ActivationBalancer(
1334
+ magnitudes.numel(),
1335
+ channel_dim=0,
1336
+ min_positive=0.0,
1337
+ max_positive=1.0,
1338
+ max_factor=0.2,
1339
+ min_abs=0.2,
1340
+ max_abs=0.8,
1341
+ min_prob=1.0,
1342
+ )
1343
+
1344
+ y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
1345
+
1346
+ y = m(x)
1347
+ y.backward(gradient=y_grad)
1348
+ print("_test_activation_balancer_magnitude: x = ", x)
1349
+ print("_test_activation_balancer_magnitude: y grad = ", y_grad)
1350
+ print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1351
+
1352
+
1353
+ def _test_basic_norm():
1354
+ num_channels = 128
1355
+ m = BasicNorm(num_channels=num_channels, channel_dim=1)
1356
+
1357
+ x = torch.randn(500, num_channels)
1358
+
1359
+ y = m(x)
1360
+
1361
+ assert y.shape == x.shape
1362
+ x_rms = (x ** 2).mean().sqrt()
1363
+ y_rms = (y ** 2).mean().sqrt()
1364
+ print("x rms = ", x_rms)
1365
+ print("y rms = ", y_rms)
1366
+ assert y_rms < x_rms
1367
+ assert y_rms > 0.5 * x_rms
1368
+
1369
+
1370
+ def _test_double_swish_deriv():
1371
+ x = torch.randn(10, 12, dtype=torch.double) * 3.0
1372
+ x.requires_grad = True
1373
+ m = DoubleSwish()
1374
+
1375
+ tol = (1.2 - (-0.043637)) / 255.0
1376
+ torch.autograd.gradcheck(m, x, atol=tol)
1377
+
1378
+ # for self-test.
1379
+ x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
1380
+ x.requires_grad = True
1381
+ y = m(x)
1382
+
1383
+
1384
+ def _test_softmax():
1385
+ a = torch.randn(2, 10, dtype=torch.float64)
1386
+ b = a.clone()
1387
+ a.requires_grad = True
1388
+ b.requires_grad = True
1389
+ a.softmax(dim=1)[:, 0].sum().backward()
1390
+ print("a grad = ", a.grad)
1391
+ softmax(b, dim=1)[:, 0].sum().backward()
1392
+ print("b grad = ", b.grad)
1393
+ assert torch.allclose(a.grad, b.grad)
1394
+
1395
+
1396
+ if __name__ == "__main__":
1397
+ logging.getLogger().setLevel(logging.INFO)
1398
+ torch.set_num_threads(1)
1399
+ torch.set_num_interop_threads(1)
1400
+ _test_softmax()
1401
+ _test_whiten()
1402
+ _test_max_eig()
1403
+ _test_activation_balancer_sign()
1404
+ _test_activation_balancer_magnitude()
1405
+ _test_basic_norm()
1406
+ _test_double_swish_deriv()
models/modules/transformer.py ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng 2024
2
+ import copy
3
+ import numbers
4
+ from functools import partial
5
+ from typing import Any, Callable, List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import Tensor, nn
9
+ from torch.nn import functional as F
10
+
11
+ from .activation import MultiheadAttention
12
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
13
+ from .scaling import BasicNorm as _BasicNorm
14
+
15
+ _shape_t = Union[int, List[int], torch.Size]
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
20
+ normalized_shape: Tuple[int, ...]
21
+ eps: float
22
+ elementwise_affine: bool
23
+
24
+ def __init__(
25
+ self,
26
+ normalized_shape: _shape_t,
27
+ eps: float = 1e-5,
28
+ elementwise_affine: bool = True,
29
+ device=None,
30
+ dtype=None,
31
+ ) -> None:
32
+ factory_kwargs = {"device": device, "dtype": dtype}
33
+ super(LayerNorm, self).__init__()
34
+ if isinstance(normalized_shape, numbers.Integral):
35
+ # mypy error: incompatible types in assignment
36
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
37
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
38
+ self.eps = eps
39
+ self.elementwise_affine = elementwise_affine
40
+ if self.elementwise_affine:
41
+ self.weight = nn.Parameter(
42
+ torch.empty(self.normalized_shape, **factory_kwargs)
43
+ )
44
+ self.bias = nn.Parameter(
45
+ torch.empty(self.normalized_shape, **factory_kwargs)
46
+ )
47
+ else:
48
+ self.register_parameter("weight", None)
49
+ self.register_parameter("bias", None)
50
+
51
+ self.reset_parameters()
52
+
53
+ def reset_parameters(self) -> None:
54
+ if self.elementwise_affine:
55
+ nn.init.ones_(self.weight)
56
+ nn.init.zeros_(self.bias)
57
+
58
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
59
+ if isinstance(input, tuple):
60
+ input, embedding = input
61
+ return (
62
+ F.layer_norm(
63
+ input,
64
+ self.normalized_shape,
65
+ self.weight,
66
+ self.bias,
67
+ self.eps,
68
+ ),
69
+ embedding,
70
+ )
71
+
72
+ assert embedding is None
73
+ return F.layer_norm(
74
+ input, self.normalized_shape, self.weight, self.bias, self.eps
75
+ )
76
+
77
+ def extra_repr(self) -> str:
78
+ return (
79
+ "{normalized_shape}, eps={eps}, "
80
+ "elementwise_affine={elementwise_affine}".format(**self.__dict__)
81
+ )
82
+
83
+
84
+ class AdaptiveLayerNorm(nn.Module):
85
+ r"""Adaptive Layer Normalization"""
86
+
87
+ def __init__(self, d_model, norm) -> None:
88
+ super(AdaptiveLayerNorm, self).__init__()
89
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
90
+ self.norm = norm
91
+ self.d_model = d_model
92
+ self.eps = self.norm.eps
93
+
94
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
95
+ if isinstance(input, tuple):
96
+ input, embedding = input
97
+ weight, bias = torch.split(
98
+ self.project_layer(embedding),
99
+ split_size_or_sections=self.d_model,
100
+ dim=-1,
101
+ )
102
+ return (weight * self.norm(input) + bias, embedding)
103
+
104
+ weight, bias = torch.split(
105
+ self.project_layer(embedding),
106
+ split_size_or_sections=self.d_model,
107
+ dim=-1,
108
+ )
109
+ return weight * self.norm(input) + bias
110
+
111
+
112
+ class BasicNorm(_BasicNorm):
113
+ def __init__(
114
+ self,
115
+ d_model: int,
116
+ eps: float = 1e-5,
117
+ device=None,
118
+ dtype=None,
119
+ ):
120
+ super(BasicNorm, self).__init__(d_model, eps=eps)
121
+
122
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
123
+ if isinstance(input, tuple):
124
+ input, embedding = input
125
+ return (
126
+ super(BasicNorm, self).forward(input),
127
+ embedding,
128
+ )
129
+
130
+ assert embedding is None
131
+ return super(BasicNorm, self).forward(input)
132
+
133
+
134
+ class BalancedBasicNorm(nn.Module):
135
+ def __init__(
136
+ self,
137
+ d_model: int,
138
+ eps: float = 1e-5,
139
+ device=None,
140
+ dtype=None,
141
+ ):
142
+ super(BalancedBasicNorm, self).__init__()
143
+ self.balancer = ActivationBalancer(
144
+ d_model,
145
+ channel_dim=-1,
146
+ min_positive=0.45,
147
+ max_positive=0.55,
148
+ max_abs=6.0,
149
+ )
150
+ self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
151
+
152
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
153
+ if isinstance(input, tuple):
154
+ input, embedding = input
155
+ return self.norm((self.balancer(input), embedding))
156
+
157
+ assert embedding is None
158
+ return self.norm(self.balancer(input))
159
+
160
+
161
+ class IdentityNorm(nn.Module):
162
+ def __init__(
163
+ self,
164
+ d_model: int,
165
+ eps: float = 1e-5,
166
+ device=None,
167
+ dtype=None,
168
+ ) -> None:
169
+ super(IdentityNorm, self).__init__()
170
+
171
+ def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
172
+ if isinstance(input, tuple):
173
+ return input
174
+
175
+ assert embedding is None
176
+ return input
177
+
178
+
179
+ class TransformerEncoderLayer(nn.Module):
180
+ __constants__ = ["batch_first", "norm_first"]
181
+
182
+ def __init__(
183
+ self,
184
+ d_model: int,
185
+ nhead: int,
186
+ dim_feedforward: int = 2048,
187
+ dropout: float = 0.1,
188
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
189
+ batch_first: bool = False,
190
+ norm_first: bool = False,
191
+ device=None,
192
+ dtype=None,
193
+ linear1_self_attention_cls: nn.Module = nn.Linear,
194
+ linear2_self_attention_cls: nn.Module = nn.Linear,
195
+ linear1_feedforward_cls: nn.Module = nn.Linear,
196
+ linear2_feedforward_cls: nn.Module = nn.Linear,
197
+ layer_norm_cls: nn.Module = LayerNorm,
198
+ layer_norm_eps: float = 1e-5,
199
+ adaptive_layer_norm=False,
200
+ ) -> None:
201
+ factory_kwargs = {"device": device, "dtype": dtype}
202
+ super(TransformerEncoderLayer, self).__init__()
203
+ self.self_attn = MultiheadAttention(
204
+ d_model,
205
+ nhead,
206
+ dropout=dropout,
207
+ batch_first=batch_first,
208
+ linear1_cls=linear1_self_attention_cls,
209
+ linear2_cls=linear2_self_attention_cls,
210
+ **factory_kwargs,
211
+ )
212
+
213
+ # Implementation of Feedforward model
214
+ self.linear1 = linear1_feedforward_cls(
215
+ d_model, dim_feedforward, **factory_kwargs
216
+ )
217
+ self.dropout = nn.Dropout(dropout)
218
+ self.linear2 = linear2_feedforward_cls(
219
+ dim_feedforward, d_model, **factory_kwargs
220
+ )
221
+
222
+ self.norm_first = norm_first
223
+ self.dropout1 = nn.Dropout(dropout)
224
+ self.dropout2 = nn.Dropout(dropout)
225
+
226
+ # Legacy string support for activation function.
227
+ if isinstance(activation, str):
228
+ activation = _get_activation_fn(activation)
229
+ elif isinstance(activation, partial):
230
+ activation = activation(d_model)
231
+ elif activation == BalancedDoubleSwish:
232
+ activation = BalancedDoubleSwish(d_model)
233
+
234
+ # # We can't test self.activation in forward() in TorchScript,
235
+ # # so stash some information about it instead.
236
+ # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
237
+ # self.activation_relu_or_gelu = 1
238
+ # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
239
+ # self.activation_relu_or_gelu = 2
240
+ # else:
241
+ # self.activation_relu_or_gelu = 0
242
+ self.activation = activation
243
+
244
+ norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
245
+ if layer_norm_cls == IdentityNorm:
246
+ norm2 = BalancedBasicNorm(
247
+ d_model, eps=layer_norm_eps, **factory_kwargs
248
+ )
249
+ else:
250
+ norm2 = layer_norm_cls(
251
+ d_model, eps=layer_norm_eps, **factory_kwargs
252
+ )
253
+
254
+ if adaptive_layer_norm:
255
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
256
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
257
+ else:
258
+ self.norm1 = norm1
259
+ self.norm2 = norm2
260
+
261
+ def __setstate__(self, state):
262
+ super(TransformerEncoderLayer, self).__setstate__(state)
263
+ if not hasattr(self, "activation"):
264
+ self.activation = F.relu
265
+
266
+ def forward(
267
+ self,
268
+ src: Tensor,
269
+ src_mask: Optional[Tensor] = None,
270
+ src_key_padding_mask: Optional[Tensor] = None,
271
+ need_weights: Optional[bool] = False,
272
+ past: Optional[Tensor] = None,
273
+ ) -> Tensor:
274
+ r"""Pass the input through the encoder layer.
275
+
276
+ Args:
277
+ src: the sequence to the encoder layer (required).
278
+ src_mask: the mask for the src sequence (optional).
279
+ src_key_padding_mask: the mask for the src keys per batch (optional).
280
+
281
+ Shape:
282
+ see the docs in Transformer class.
283
+ """
284
+ x, stage_embedding = src, None
285
+ is_src_tuple = False
286
+ if isinstance(src, tuple):
287
+ x, stage_embedding = src
288
+ is_src_tuple = True
289
+
290
+ if src_key_padding_mask is not None:
291
+ _skpm_dtype = src_key_padding_mask.dtype
292
+ if _skpm_dtype != torch.bool and not torch.is_floating_point(
293
+ src_key_padding_mask
294
+ ):
295
+ raise AssertionError(
296
+ "only bool and floating types of key_padding_mask are supported"
297
+ )
298
+ if need_weights:
299
+ if self.norm_first:
300
+ out, attn = self._sa_block_attn(
301
+ self.norm1(x, stage_embedding),
302
+ src_mask,
303
+ src_key_padding_mask,
304
+ past
305
+ )
306
+ out, present = out # present is the kvcache of the present timestep
307
+ x = x + out
308
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
309
+ else:
310
+ out, attn = self._sa_block_attn(x, src_mask, src_key_padding_mask, past)
311
+ out, present = out # present is the kvcache of the present timestep
312
+ x = self.norm1(
313
+ x + out,
314
+ stage_embedding,
315
+ )
316
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
317
+ assert not is_src_tuple
318
+ # return (x, stage_embedding)
319
+ return (x, attn)
320
+ else:
321
+ if self.norm_first:
322
+ out = self._sa_block(
323
+ self.norm1(x, stage_embedding),
324
+ src_mask,
325
+ src_key_padding_mask, past
326
+ )
327
+ out, present = out # present is the kvcache of the present timestep
328
+ x = x + out
329
+ x = x + self._ff_block(self.norm2(x, stage_embedding))
330
+ else:
331
+ out = self._sa_block(x, src_mask, src_key_padding_mask)
332
+ out, present = out # present is the kvcache of the present timestep
333
+ x = self.norm1(
334
+ x + out,
335
+ stage_embedding, past
336
+ )
337
+ x = self.norm2(x + self._ff_block(x), stage_embedding)
338
+
339
+ if is_src_tuple:
340
+ x = (x, stage_embedding)
341
+ if present != None:
342
+ x = [x, present]
343
+ return x
344
+
345
+ # self-attention block
346
+ def _sa_block(
347
+ self,
348
+ x: Tensor,
349
+ attn_mask: Optional[Tensor],
350
+ key_padding_mask: Optional[Tensor],
351
+ past: Optional[Tensor] = None,
352
+ ) -> Tensor:
353
+ x = self.self_attn(
354
+ x,
355
+ x,
356
+ x,
357
+ attn_mask=attn_mask,
358
+ key_padding_mask=key_padding_mask,
359
+ need_weights=False,
360
+ past=past
361
+ )
362
+ x, present = x
363
+ return self.dropout1(x), present
364
+
365
+ # self-attention block, also return attention weights
366
+ def _sa_block_attn(
367
+ self,
368
+ x: Tensor,
369
+ attn_mask: Optional[Tensor],
370
+ key_padding_mask: Optional[Tensor],
371
+ past: Optional[Tensor] = None,
372
+ ) -> Tensor:
373
+ x, attn = self.self_attn(
374
+ x,
375
+ x,
376
+ x,
377
+ attn_mask=attn_mask,
378
+ key_padding_mask=key_padding_mask,
379
+ need_weights=True,
380
+ past=past
381
+ )
382
+ x, present = x
383
+ return (self.dropout1(x), present), attn
384
+
385
+ # feed forward block
386
+ def _ff_block(self, x: Tensor) -> Tensor:
387
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
388
+ return self.dropout2(x)
389
+
390
+
391
+ class TransformerEncoder(nn.Module):
392
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
393
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
394
+
395
+ Args:
396
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
397
+ num_layers: the number of sub-encoder-layers in the encoder (required).
398
+ norm: the layer normalization component (optional).
399
+ enable_nested_tensor: if True, input will automatically convert to nested tensor
400
+ (and convert back on output). This will improve the overall performance of
401
+ TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
402
+
403
+ Examples::
404
+ >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
405
+ >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
406
+ >>> src = torch.rand(10, 32, 512)
407
+ >>> out = transformer_encoder(src)
408
+ """
409
+ __constants__ = ["norm"]
410
+
411
+ def __init__(self, encoder_layer, num_layers, norm=None):
412
+ super(TransformerEncoder, self).__init__()
413
+ self.layers = _get_clones(encoder_layer, num_layers)
414
+ self.num_layers = num_layers
415
+ self.norm = norm
416
+
417
+ def forward(
418
+ self,
419
+ src: Tensor,
420
+ mask: Optional[Tensor] = None,
421
+ src_key_padding_mask: Optional[Tensor] = None,
422
+ return_layer_states: bool = False,
423
+ need_weights:Optional[bool] = False,
424
+ past: Optional[Tensor] = None,
425
+ ) -> Tensor:
426
+ r"""Pass the input through the encoder layers in turn.
427
+
428
+ Args:
429
+ src: the sequence to the encoder (required).
430
+ mask: the mask for the src sequence (optional).
431
+ src_key_padding_mask: the mask for the src keys per batch (optional).
432
+ return_layer_states: return layers' state (optional).
433
+
434
+ Shape:
435
+ see the docs in Transformer class.
436
+ """
437
+ if return_layer_states:
438
+ assert not need_weights
439
+ layer_states = [] # layers' output
440
+ output = src
441
+ for mod in self.layers:
442
+ output = mod(
443
+ output,
444
+ src_mask=mask,
445
+ src_key_padding_mask=src_key_padding_mask,
446
+ past=past
447
+ )
448
+ layer_states.append(output[0])
449
+
450
+ if self.norm is not None:
451
+ output = self.norm(output)
452
+
453
+ return layer_states, output
454
+ if need_weights:
455
+ assert not return_layer_states
456
+ layer_attn = [] # layers' output
457
+ output = src
458
+ for mod in self.layers:
459
+ output = mod(
460
+ output,
461
+ src_mask=mask,
462
+ src_key_padding_mask=src_key_padding_mask,
463
+ need_weights=True,
464
+ past=past
465
+ )
466
+ layer_attn.append(output[1])
467
+
468
+ if self.norm is not None:
469
+ output = self.norm(output)
470
+
471
+ return layer_attn, output
472
+
473
+ output = src
474
+ all_present = []
475
+ for n_layer, mod in enumerate(self.layers):
476
+ output = mod(
477
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past=None if past is None else past[n_layer]
478
+ )
479
+ if isinstance(output, list):
480
+ output, present = output
481
+ all_present.append(present)
482
+
483
+ if self.norm is not None:
484
+ output = self.norm(output)
485
+ if all_present != []:
486
+ all_present = torch.stack(all_present, dim=0) # (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
487
+ output = [output, all_present]
488
+ return output
489
+
490
+
491
+ class TransformerDecoderLayer(nn.Module):
492
+ __constants__ = ["batch_first", "norm_first"]
493
+
494
+ def __init__(
495
+ self,
496
+ d_model: int,
497
+ nhead: int,
498
+ dim_feedforward: int = 2048,
499
+ dropout: float = 0.1,
500
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
501
+ linear1_self_attention_cls: nn.Module = nn.Linear,
502
+ linear2_self_attention_cls: nn.Module = nn.Linear,
503
+ linear1_feedforward_cls: nn.Module = nn.Linear,
504
+ linear2_feedforward_cls: nn.Module = nn.Linear,
505
+ batch_first: bool = False,
506
+ norm_first: bool = False,
507
+ device=None,
508
+ dtype=None,
509
+ layer_norm_cls: nn.Module = LayerNorm,
510
+ layer_norm_eps: float = 1e-5,
511
+ adaptive_layer_norm=False,
512
+ ) -> None:
513
+ factory_kwargs = {"device": device, "dtype": dtype}
514
+ super(TransformerDecoderLayer, self).__init__()
515
+ self.self_attn = MultiheadAttention(
516
+ d_model,
517
+ nhead,
518
+ dropout=dropout,
519
+ batch_first=batch_first,
520
+ linear1_cls=linear1_self_attention_cls,
521
+ linear2_cls=linear2_self_attention_cls,
522
+ **factory_kwargs,
523
+ )
524
+ self.multihead_attn = MultiheadAttention(
525
+ d_model,
526
+ nhead,
527
+ dropout=dropout,
528
+ batch_first=batch_first,
529
+ linear1_cls=linear1_self_attention_cls,
530
+ linear2_cls=linear2_self_attention_cls,
531
+ **factory_kwargs,
532
+ )
533
+ # Implementation of Feedforward model
534
+ self.linear1 = linear1_feedforward_cls(
535
+ d_model, dim_feedforward, **factory_kwargs
536
+ )
537
+ self.dropout = nn.Dropout(dropout)
538
+ self.linear2 = linear2_feedforward_cls(
539
+ dim_feedforward, d_model, **factory_kwargs
540
+ )
541
+
542
+ self.norm_first = norm_first
543
+ self.dropout1 = nn.Dropout(dropout)
544
+ self.dropout2 = nn.Dropout(dropout)
545
+ self.dropout3 = nn.Dropout(dropout)
546
+
547
+ # Legacy string support for activation function.
548
+ if isinstance(activation, str):
549
+ self.activation = _get_activation_fn(activation)
550
+ elif isinstance(activation, partial):
551
+ self.activation = activation(d_model)
552
+ elif activation == BalancedDoubleSwish:
553
+ self.activation = BalancedDoubleSwish(d_model)
554
+ else:
555
+ self.activation = activation
556
+
557
+ if adaptive_layer_norm:
558
+ norm1 = layer_norm_cls(
559
+ d_model, eps=layer_norm_eps, **factory_kwargs
560
+ )
561
+ norm2 = layer_norm_cls(
562
+ d_model, eps=layer_norm_eps, **factory_kwargs
563
+ )
564
+ norm3 = layer_norm_cls(
565
+ d_model, eps=layer_norm_eps, **factory_kwargs
566
+ )
567
+
568
+ self.norm1 = AdaptiveLayerNorm(d_model, norm1)
569
+ self.norm2 = AdaptiveLayerNorm(d_model, norm2)
570
+ self.norm3 = AdaptiveLayerNorm(d_model, norm3)
571
+ else:
572
+ self.norm1 = layer_norm_cls(
573
+ d_model, eps=layer_norm_eps, **factory_kwargs
574
+ )
575
+ self.norm2 = layer_norm_cls(
576
+ d_model, eps=layer_norm_eps, **factory_kwargs
577
+ )
578
+ if layer_norm_cls == IdentityNorm:
579
+ self.norm3 = BalancedBasicNorm(
580
+ d_model, eps=layer_norm_eps, **factory_kwargs
581
+ )
582
+ else:
583
+ self.norm3 = layer_norm_cls(
584
+ d_model, eps=layer_norm_eps, **factory_kwargs
585
+ )
586
+
587
+ def forward(
588
+ self,
589
+ tgt: Tensor,
590
+ memory: Tensor,
591
+ tgt_mask: Optional[Tensor] = None,
592
+ memory_mask: Optional[Tensor] = None,
593
+ tgt_key_padding_mask: Optional[Tensor] = None,
594
+ memory_key_padding_mask: Optional[Tensor] = None,
595
+ ) -> Tensor:
596
+ r"""Pass the inputs (and mask) through the decoder layer.
597
+
598
+ Args:
599
+ tgt: the sequence to the decoder layer (required).
600
+ memory: the sequence from the last layer of the encoder (required).
601
+ tgt_mask: the mask for the tgt sequence (optional).
602
+ memory_mask: the mask for the memory sequence (optional).
603
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
604
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
605
+
606
+ Shape:
607
+ see the docs in Transformer class.
608
+ """
609
+ tgt_is_tuple = False
610
+ if isinstance(tgt, tuple):
611
+ x, stage_embedding = tgt
612
+ tgt_is_tuple = True
613
+ else:
614
+ x, stage_embedding = tgt, None
615
+
616
+ if self.norm_first:
617
+ x = x + self._sa_block(
618
+ self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
619
+ )
620
+ x = x + self._mha_block(
621
+ self.norm2(x, stage_embedding),
622
+ memory,
623
+ memory_mask,
624
+ memory_key_padding_mask,
625
+ )
626
+ x = x + self._ff_block(self.norm3(x, stage_embedding))
627
+ else:
628
+ x = self.norm1(
629
+ x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
630
+ stage_embedding,
631
+ )
632
+ x = self.norm2(
633
+ x
634
+ + self._mha_block(
635
+ x, memory, memory_mask, memory_key_padding_mask
636
+ ),
637
+ stage_embedding,
638
+ )
639
+ x = self.norm3(x + self._ff_block(x), stage_embedding)
640
+
641
+ if tgt_is_tuple:
642
+ return (x, stage_embedding)
643
+ return x
644
+
645
+ # self-attention block
646
+ def _sa_block(
647
+ self,
648
+ x: Tensor,
649
+ attn_mask: Optional[Tensor],
650
+ key_padding_mask: Optional[Tensor],
651
+ ) -> Tensor:
652
+ x = self.self_attn(
653
+ x,
654
+ x,
655
+ x,
656
+ attn_mask=attn_mask,
657
+ key_padding_mask=key_padding_mask,
658
+ need_weights=False,
659
+ )[0]
660
+ return self.dropout1(x)
661
+
662
+ # multihead attention block
663
+ def _mha_block(
664
+ self,
665
+ x: Tensor,
666
+ mem: Tensor,
667
+ attn_mask: Optional[Tensor],
668
+ key_padding_mask: Optional[Tensor],
669
+ ) -> Tensor:
670
+ x = self.multihead_attn(
671
+ x,
672
+ mem,
673
+ mem,
674
+ attn_mask=attn_mask,
675
+ key_padding_mask=key_padding_mask,
676
+ need_weights=False,
677
+ )[0]
678
+ return self.dropout2(x)
679
+
680
+ # feed forward block
681
+ def _ff_block(self, x: Tensor) -> Tensor:
682
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
683
+ return self.dropout3(x)
684
+
685
+
686
+ def _get_clones(module, N):
687
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
688
+
689
+
690
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
691
+ if activation == "relu":
692
+ return F.relu
693
+ elif activation == "gelu":
694
+ return F.gelu
695
+
696
+ raise RuntimeError(
697
+ "activation should be relu/gelu, not {}".format(activation)
698
+ )
models/modules/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # cp from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py, modified by Puyuan Peng
2
+ import torch
3
+
4
+
5
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
6
+ """
7
+ Args:
8
+ lengths:
9
+ A 1-D tensor containing sentence lengths.
10
+ max_len:
11
+ The length of masks.
12
+ Returns:
13
+ Return a 2-D bool tensor, where masked positions
14
+ are filled with `True` and non-masked positions are
15
+ filled with `False`.
16
+
17
+ >>> lengths = torch.tensor([1, 3, 2, 5])
18
+ >>> make_pad_mask(lengths)
19
+ tensor([[False, True, True, True, True],
20
+ [False, False, False, True, True],
21
+ [False, False, True, True, True],
22
+ [False, False, False, False, False]])
23
+ """
24
+ assert lengths.ndim == 1, lengths.ndim
25
+ max_len = max(max_len, lengths.max())
26
+ n = lengths.size(0)
27
+ seq_range = torch.arange(0, max_len, device=lengths.device)
28
+ expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
29
+
30
+ return expaned_lengths >= lengths.unsqueeze(-1)
31
+
32
+ def generate_partial_autoregressive_mask(sz, start, end):
33
+ mask = torch.zeros(sz, sz).bool()
34
+ mask[start:end, start:end] = torch.triu(torch.ones(end-start, end-start,dtype=torch.bool), diagonal=1)
35
+ mask[:start, start:end] = True
36
+ mask[end:, start:end] = True
37
+ return mask
models/voicecraft.py ADDED
@@ -0,0 +1,1406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import logging
5
+ import argparse, copy
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchmetrics.classification import MulticlassAccuracy
10
+
11
+ from .modules.utils import make_pad_mask
12
+
13
+ from .modules.embedding import SinePositionalEmbedding, TokenEmbedding
14
+ from .modules.transformer import (
15
+ LayerNorm,
16
+ TransformerEncoder,
17
+ TransformerEncoderLayer,
18
+ )
19
+ from .codebooks_patterns import DelayedPatternProvider
20
+
21
+ def top_k_top_p_filtering(
22
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
23
+ ):
24
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
25
+ Args:
26
+ logits: logits distribution shape (batch size, vocabulary size)
27
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
28
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
29
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
30
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
31
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
32
+ """
33
+ if top_k > 0:
34
+ top_k = min(
35
+ max(top_k, min_tokens_to_keep), logits.size(-1)
36
+ ) # Safety check
37
+ # Remove all tokens with a probability less than the last token of the top-k
38
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
39
+ logits[indices_to_remove] = filter_value
40
+
41
+ if top_p < 1.0:
42
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
43
+ cumulative_probs = torch.cumsum(
44
+ F.softmax(sorted_logits, dim=-1), dim=-1
45
+ )
46
+
47
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
48
+ sorted_indices_to_remove = cumulative_probs > top_p
49
+ if min_tokens_to_keep > 1:
50
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
51
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
52
+ # Shift the indices to the right to keep also the first token above the threshold
53
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
54
+ ..., :-1
55
+ ].clone()
56
+ sorted_indices_to_remove[..., 0] = 0
57
+
58
+ # scatter sorted tensors to original indexing
59
+ indices_to_remove = sorted_indices_to_remove.scatter(
60
+ 1, sorted_indices, sorted_indices_to_remove
61
+ )
62
+ logits[indices_to_remove] = filter_value
63
+ return logits
64
+
65
+
66
+ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
67
+ # temperature: (`optional`) float
68
+ # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
69
+ # top_k: (`optional`) int
70
+ # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
71
+ # top_p: (`optional`) float
72
+ # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
73
+
74
+ # Temperature (higher temperature => more likely to sample low probability tokens)
75
+ if temperature != 1.0:
76
+ logits = logits / temperature
77
+ # Top-p/top-k filtering
78
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
79
+ # Sample
80
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
81
+ return token
82
+
83
+
84
+
85
+ class VoiceCraft(nn.Module):
86
+ def __init__(self, args):
87
+ super().__init__()
88
+ self.args = copy.copy(args)
89
+ self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks)
90
+ if not getattr(self.args, "special_first", False):
91
+ self.args.special_first = 0
92
+ if not getattr(self.args, "n_special", False):
93
+ self.args.n_special = 3
94
+ self.args.eos = getattr(self.args, "eos", -1)
95
+ self.eog = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eog, dtype=torch.long), requires_grad=False) # [K 1]
96
+ if self.args.eos > 0:
97
+ assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos
98
+ self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1]
99
+ if type(self.args.audio_vocab_size) == str:
100
+ self.args.audio_vocab_size = eval(self.args.audio_vocab_size)
101
+
102
+ self.n_text_tokens = self.args.text_vocab_size + 1
103
+ assert self.args.text_pad_token == self.args.text_vocab_size, f"self.args.text_vocab_size: {self.args.text_vocab_size}, self.args.text_pad_token: {self.args.text_pad_token}"
104
+
105
+ self.n_audio_tokens = [self.args.audio_vocab_size + self.args.n_special] * self.args.n_codebooks # special tokens: empty token, EOG token, audio pad token
106
+ assert self.args.audio_vocab_size == self.args.empty_token, self.args.empty_token
107
+ assert self.args.eog == self.args.audio_vocab_size + 1, self.args.eog
108
+ assert self.args.audio_pad_token == self.args.audio_vocab_size + 2, self.args.audio_pad_token
109
+
110
+ self.text_embedding = TokenEmbedding(
111
+ dim_model=self.args.d_model,
112
+ vocab_size=self.n_text_tokens,
113
+ dropout=self.args.text_embedding_dropout
114
+ )
115
+
116
+ self.audio_embedding = nn.ModuleList(
117
+ [
118
+ TokenEmbedding(
119
+ dim_model=self.args.audio_embedding_dim,
120
+ vocab_size=self.n_audio_tokens[k],
121
+ dropout=self.args.audio_embedding_dropout
122
+ ) for k in range(self.args.n_codebooks)
123
+ ]
124
+ )
125
+ self.mask_embedding = nn.Parameter(torch.randn(self.args.max_n_spans, self.args.d_model), requires_grad=True)
126
+ self.text_positional_embedding = SinePositionalEmbedding(
127
+ self.args.d_model,
128
+ dropout=self.args.text_positional_embedding_dropout,
129
+ scale=False,
130
+ alpha=True, # learnable scaler, scale the volume of positional embedding
131
+ )
132
+ self.audio_positional_embedding = SinePositionalEmbedding(
133
+ self.args.d_model,
134
+ dropout=self.args.audio_positional_embedding_dropout,
135
+ scale=False,
136
+ alpha=True, # learnable scaler, scale the volume of positional embedding
137
+ )
138
+
139
+ dec_layer = TransformerEncoderLayer(
140
+ self.args.d_model,
141
+ self.args.nhead,
142
+ dim_feedforward=self.args.d_model * 4,
143
+ dropout=self.args.trm_dropout,
144
+ batch_first=True,
145
+ norm_first=True,
146
+ layer_norm_cls=LayerNorm
147
+ )
148
+ self.decoder = TransformerEncoder(
149
+ dec_layer,
150
+ num_layers=self.args.num_decoder_layers,
151
+ norm=LayerNorm(self.args.d_model),
152
+ )
153
+
154
+ self.predict_layer = nn.ModuleList(
155
+ [
156
+ nn.Sequential(nn.Linear(self.args.d_model, self.args.audio_vocab_size//2), nn.GELU(), nn.Linear(self.args.audio_vocab_size//2, self.n_audio_tokens[k])) for k in range(self.args.n_codebooks)
157
+ ]
158
+ )
159
+
160
+ self.accuracy_metrics = nn.ModuleList(
161
+ [MulticlassAccuracy(
162
+ self.n_audio_tokens[k],
163
+ top_k=10,
164
+ average="micro",
165
+ multidim_average="global",
166
+ ignore_index=None,
167
+ ) for k in range(self.args.n_codebooks)]
168
+ )
169
+
170
+
171
+ def prepare_mask_intervals(self, y_lens):
172
+ mask_intervals = []
173
+ non_mask_intervals = []
174
+
175
+ for i, y_len in enumerate(y_lens):
176
+ if self.args.mask_sample_dist == "uniform":
177
+ n_spans = random.choice(range(1, self.args.max_n_spans+1))
178
+ elif "poisson" in self.args.mask_sample_dist.lower():
179
+ param = float(self.args.mask_sample_dist[len("poisson"):])
180
+ poisson_sample = torch.poisson(torch.tensor([param]))
181
+ n_spans = int(poisson_sample.clamp(1, self.args.max_n_spans).item())
182
+
183
+ starts = random.sample(range(1, y_len-1-self.args.mask_len_min), n_spans)
184
+ starts = sorted(starts)
185
+
186
+ for j in range(len(starts)-1, 0, -1):
187
+ if starts[j] - starts[j-1] < self.args.min_gap:
188
+ del starts[j] # If elements are too close, delete the later one
189
+ assert len(starts) > 0, f"there is no masked span left, y_len: {y_len}, sampled n_spans: {n_spans}"
190
+
191
+ temp_starts = starts + [y_len]
192
+ gaps = [temp_starts[j+1] - temp_starts[j] for j in range(len(temp_starts)-1)]
193
+
194
+ ends = []
195
+
196
+ for j, (start, gap) in enumerate(zip(starts, gaps)):
197
+ mask_len = random.randint(self.args.mask_len_min, self.args.mask_len_max)
198
+ # if mask_len > gap * self.args.max_mask_portion: # make sure the masks are not overlapping with each other
199
+ if mask_len > gap - 1: # make sure the masks are not overlapping with each other
200
+ # temp_mask_start = int(0.6*gap*self.args.max_mask_portion)
201
+ # temp_mask_end = int(gap*self.args.max_mask_portion)
202
+ temp_mask_start = 1
203
+ temp_mask_end = gap - 1
204
+ mask_len = random.randint(temp_mask_start, temp_mask_end)
205
+ ends.append(start + mask_len)
206
+
207
+ mask_intervals.append([(s,e) for s,e in zip(starts, ends)])
208
+ non_mask_intervals.append([(ns,ne) for ns, ne in zip([0]+ends, starts+[y_len])])
209
+
210
+ return mask_intervals, non_mask_intervals
211
+
212
+ def rearrange(self, y, non_mask_intervals, mask_intervals):
213
+ reduced_eog = getattr(self.args, "reduced_eog", 0)
214
+ rearranged_y = []
215
+ for i in range(len(y)):
216
+ if self.args.eos > 0:
217
+ assert reduced_eog
218
+ cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eos], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends
219
+ else:
220
+ if reduced_eog:
221
+ cur_y = [y[i, :, item[0]: item[1]] for item in non_mask_intervals[i][:-1]] + [torch.cat([y[i, :, non_mask_intervals[i][-1][0]: non_mask_intervals[i][-1][1]], self.eog], dim=-1)] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # only insert eog to the last non-mask-interval, which is when the utterance actual ends
222
+ else:
223
+ cur_y = [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in non_mask_intervals[i]] + [torch.cat([y[i, :, item[0]: item[1]], self.eog], dim=-1) for item in mask_intervals[i]] # eog is added to each section TODO this is not correct, I should add eog to non_mask_intervals if that segment is not the ending segment (as there is no way for the model to predict eog for those segments, and this will do harm to tts experiment, where the model randomly output eog for the first segment)
224
+ rearranged_y.append(cur_y)
225
+ return rearranged_y
226
+
227
+ def shift(self, rearranged_y):
228
+ shifted_y = []
229
+ patterns = []
230
+ for i in range(len(rearranged_y)):
231
+ cur_patterns = [self.pattern.get_pattern(cur_y.shape[1]) for cur_y in rearranged_y[i]]
232
+ out = [cur_pattern.build_pattern_sequence(z=cur_y.unsqueeze(0).contiguous(), special_token=self.args.empty_token, keep_only_valid_steps=False) for cur_pattern, cur_y in zip(cur_patterns, rearranged_y[i])]
233
+ shifted_y.append([item[0].squeeze(0) for item in out]) # the first item is values, later two are indexes and mask
234
+ patterns.append(cur_patterns)
235
+ return shifted_y, patterns
236
+
237
+ def insert_mask(self, shifted_y):
238
+ inserted_y = []
239
+ mask_position = []
240
+ mask_value = []
241
+ for i in range(len(shifted_y)):
242
+ num_masks = (len(shifted_y[i]) - 1) // 2
243
+ assert num_masks == (len(shifted_y[i]) - 1) / 2, len(shifted_y[i])
244
+ emb_inds = list(range(self.args.max_n_spans))
245
+ if self.args.shuffle_mask_embedding:
246
+ random.shuffle(emb_inds)
247
+ emb_inds_use = emb_inds[:num_masks]
248
+ emb_inds_use = emb_inds_use + emb_inds_use
249
+ mask_value.append(emb_inds_use)
250
+ cur_inserted_y = []
251
+ cur_mask_position = []
252
+ for j in range(len(shifted_y[i])-1):
253
+ cur_inserted_y.append(shifted_y[i][j])
254
+ cur_mask_position.append(sum([item.shape[1] for item in cur_inserted_y])) # each item is of shape [K S], so take shape[1]
255
+ cur_inserted_y.append(self.eog) # insert mask token of shape [K, 1], BUT we are actually using the eog token as a place holder here, as the real mask will be inserted in embed_y function
256
+
257
+ cur_inserted_y.append(shifted_y[i][-1])
258
+
259
+ inserted_y.append(cur_inserted_y)
260
+ mask_position.append(cur_mask_position)
261
+ return inserted_y, mask_position, mask_value
262
+
263
+ def cat_y(self, inserted_y, mask_position, y_lens):
264
+ reduced_eog = getattr(self.args, "reduced_eog", 0)
265
+ cated_y = []
266
+ new_y_lens = []
267
+ for i in range(len(inserted_y)):
268
+ cur_cated_y = torch.cat(inserted_y[i], dim=1) #[K S]
269
+ cur_cated_y = cur_cated_y.transpose(1,0) # [S K]
270
+ cur_cated_y_len = cur_cated_y.shape[0]
271
+ if reduced_eog:
272
+ assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i])/2 + 1) ({len(mask_position[i])/2 + 1})={y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i])/2 + 1)}"
273
+ else:
274
+ assert cur_cated_y_len == y_lens[i] + len(mask_position[i]) + (len(mask_position[i]) + 1) * self.args.n_codebooks + (len(mask_position[i]) + 1), f"cur_cated_y_len == {cur_cated_y_len}, but it should be y_lens[i] ({y_lens[i]}) + len(mask_position[i]) ({len(mask_position[i])}) + (len(mask_position[i]) + 1) * self.args.n_codebooks ({(len(mask_position[i]) + 1) * self.args.n_codebooks}) + (len(mask_position[i]) + 1) ({len(mask_position[i]) + 1})" # the last term represent the inserted eog token, originally it's inserted at the end of every token, but this is wrong
275
+ new_y_lens.append(cur_cated_y_len)
276
+ cated_y.append(cur_cated_y)
277
+
278
+ cated_y = torch.nn.utils.rnn.pad_sequence(cated_y, batch_first=False, padding_value=self.args.audio_pad_token)
279
+ assert cated_y.shape == torch.Size([max(new_y_lens),len(inserted_y), self.args.n_codebooks]), f"cated_y.shape: {cated_y.shape}, but it should be {torch.Size([max(new_y_lens,len(inserted_y), self.args.n_codebooks)])}"
280
+ cated_y = cated_y.permute(2,0,1) # [T,B,K]->[K,T,B]
281
+ assert cated_y.shape[0] == self.args.n_codebooks, cated_y.shape
282
+ return cated_y, torch.LongTensor(new_y_lens).to(cated_y.device)
283
+
284
+ def embed_y(self, cated_y, mask_position, mask_value):
285
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, T, B, D]
286
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
287
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
288
+ embedded_y = embedded_y.sum(dim=0) # [K,T,B,D]->[T,B,D]
289
+ embedded_y = embedded_y.transpose(1,0) # [T,B,D]->[B,T,D]
290
+ for i in range(len(embedded_y)):
291
+ if len(mask_position[i]) > 0:
292
+ embedded_y[i, mask_position[i]] = self.mask_embedding[mask_value[i]]
293
+ return embedded_y
294
+
295
+ def prepare_input_target(self, y, y_lens):
296
+ # rearrange y
297
+ # assume y shape: [B T K], K is n_codebooks
298
+ assert y.shape[1] == self.args.n_codebooks, y.shape
299
+ # sample mask_intervals
300
+ mask_intervals, non_mask_intervals = self.prepare_mask_intervals(y_lens)
301
+
302
+ # need to have EOG in each section (SOG will be generated by the pattern class)
303
+ # but mask can be inserted later after we have shifted the input
304
+ # y could be rearranged in this way:
305
+ # [
306
+ # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
307
+ # [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
308
+ # ...
309
+ # ]
310
+ # for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part.
311
+ # NOTE #non_masked_part = #masked_part + 1
312
+ # NOTE *these are also the targets*
313
+ # added eog at the end of each segment (masked segment and unmasked segment)
314
+ rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
315
+ targets = rearranged_y # each element in each sample is of shape [K T]
316
+ assert targets[0][0].shape[0] == self.args.n_codebooks, targets[0][0].shape
317
+
318
+ # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
319
+ # [[5, 1, 2, 3, 4, 5, 5],
320
+ # [5, 5, 1, 2, 3, 4, 5],
321
+ # [5, 5, 5, 1, 2, 3, 4]]
322
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S]
323
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape[0]
324
+
325
+
326
+ # then, insert mask token at the intersection of each tensor (we want to decide the arrangement of the mask (shuffle or not)), we better have a separate nn.embedding for it
327
+ # we also need to record the position of the inserted mask
328
+ inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
329
+ assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
330
+ assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
331
+
332
+ # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
333
+ cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
334
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
335
+
336
+
337
+ # embed remember to separately embed the mask tokens
338
+ embedded_y = self.embed_y(cated_y, mask_position, mask_value) #BTD
339
+ assert embedded_y.shape[1:] == torch.Size((max(new_y_lens), self.args.d_model)), embedded_y.shape
340
+
341
+ # positional embedding
342
+ y_input = self.audio_positional_embedding(embedded_y)
343
+
344
+ # make attention mask and padding mask
345
+ y_padding_mask = make_pad_mask(new_y_lens).to(y.device)
346
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y_padding_mask.device)
347
+ return y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns
348
+
349
+ def remove_mask(self, logits, mask_position, new_y_lens):
350
+ # logits: [B K S card]
351
+ logits_use = []
352
+ for i in range(len(logits)):
353
+ non_mask_positions = [-1] + mask_position[i] + [new_y_lens[i]]
354
+ non_mask_intervals = [[non_mask_positions[i]+1, non_mask_positions[i+1]] for i in range(len(non_mask_positions)-1)]
355
+ cur_logits_use = [logits[i, :, l:r] for l,r in non_mask_intervals]
356
+ logits_use.append(cur_logits_use)
357
+
358
+ return logits_use
359
+
360
+ def revert_pattern(self, patterns, logits_use):
361
+ logits_final = []
362
+ logit_masks = []
363
+ for i in range(len(logits_use)):
364
+ cur_logits = [
365
+ item.unsqueeze(0).permute(0, 3, 1, 2).contiguous() for item in logits_use[i]
366
+ ] # each item is of shape [1 K S card] [1 card K S]
367
+ cur_logits_final = [
368
+ cur_pattern.revert_pattern_logits(
369
+ item, 0, keep_only_valid_steps=False
370
+ )
371
+ for cur_pattern, item in zip(patterns[i], cur_logits)
372
+ ] # if input output order doesn't match, this step will give an error
373
+ cur_logits_final_ret = [item[0].permute(0,2,3,1).squeeze(0) for item in cur_logits_final] # each element is of shape [K,T,card]
374
+ logits_final.append(cur_logits_final_ret)
375
+ logit_masks.append([item[2] for item in cur_logits_final])
376
+
377
+ return logits_final, logit_masks
378
+
379
+ def dec_forward(
380
+ self,
381
+ x_input,
382
+ x_lens,
383
+ x_attention_mask,
384
+ x_padding_mask,
385
+ y_input,
386
+ new_y_lens,
387
+ y_attention_mask,
388
+ y_padding_mask,
389
+ past=None,
390
+ last_3_tokens=False
391
+ ):
392
+ x_attn_mask = F.pad(
393
+ x_attention_mask,
394
+ (0, new_y_lens.max()),
395
+ value=True,
396
+ ) # x attn to all x, doesn't attn to any y, this follow figure 3 of the valle paper
397
+ y_attn_mask = F.pad(
398
+ y_attention_mask,
399
+ (x_lens.max(), 0), # y is padded at the front
400
+ value=False,
401
+ ) # y attn to all x, for y itself use lower triangle mask to ensure autoregressive
402
+ xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
403
+
404
+ # merge key padding and attention masks
405
+ bsz, src_len = x_input.shape[0], x_lens.max() + new_y_lens.max()
406
+ xy_padding_mask = torch.concat([x_padding_mask, y_padding_mask], dim=1)
407
+ _xy_padding_mask = (
408
+ xy_padding_mask.view(bsz, 1, 1, src_len)
409
+ .expand(-1, self.args.nhead, -1, -1)
410
+ .reshape(bsz * self.args.nhead, 1, src_len)
411
+ )
412
+ xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
413
+
414
+ new_attn_mask = torch.zeros_like(xy_attn_mask)
415
+ new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
416
+ xy_attn_mask = new_attn_mask
417
+
418
+ xy_input = torch.cat([x_input, y_input], dim=1)
419
+
420
+ if past == None: # do not use kvcache
421
+ out, _ = self.decoder((xy_input, None), mask=xy_attn_mask)
422
+ return out[:, x_lens.max():], None
423
+ else: # use kvcache
424
+ if past.ndim > 3: # uses kvcache, only need to pass the last tokens, this doesn't work with multi-span speech editing yet
425
+ if last_3_tokens:
426
+ xy_input = xy_input[:, -3:]
427
+ xy_attn_mask = xy_attn_mask[:, -3:]
428
+ else:
429
+ xy_input = xy_input[:, -1:]
430
+ xy_attn_mask = xy_attn_mask[:, -1:]
431
+
432
+ out, present = self.decoder((xy_input, None), mask=xy_attn_mask, past=past)
433
+ if isinstance(out, tuple): # get rid of stage_embedding
434
+ out = out[0]
435
+
436
+ if out.shape[1] > x_lens.max(): # the first pass, not kvcache yet
437
+ return out[:, x_lens.max():], present
438
+ else: # used kvcache
439
+ return out, present
440
+
441
+ def forward(self, batch):
442
+ """
443
+ Args:
444
+ x:
445
+ A 2-D tensor of shape (N, S).
446
+ x_lens:
447
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
448
+ before padding.
449
+ y:
450
+ A 3-D tensor of shape (N, K, T).
451
+ where K is the number of codebooks
452
+ y_lens:
453
+ A 1-D tensor of shape (N,). It contains the number of tokens in `x`
454
+ before padding.
455
+ """
456
+ x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
457
+ x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
458
+ y = y[:, :y_lens.max()]
459
+ assert x.ndim == 2, x.shape
460
+ assert x_lens.ndim == 1, x_lens.shape
461
+ assert y.ndim == 3 and y.shape[1] == self.args.n_codebooks, y.shape
462
+ assert y_lens.ndim == 1, y_lens.shape
463
+ # makes attention mask and padding mask for x
464
+ x_padding_mask = make_pad_mask(x_lens).to(x.device)
465
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x_padding_mask.device)
466
+ x_input = self.text_embedding(x)
467
+ x_input = self.text_positional_embedding(x_input)
468
+ y_input, new_y_lens, targets, y_padding_mask, y_attention_mask, mask_position, patterns = self.prepare_input_target(y, y_lens)
469
+ y_out = self.dec_forward(
470
+ x_input,
471
+ x_lens,
472
+ x_attention_mask,
473
+ x_padding_mask,
474
+ y_input,
475
+ new_y_lens,
476
+ y_attention_mask,
477
+ y_padding_mask
478
+ )
479
+ y_out = y_out[0] # no kv-caching during training
480
+ assert y_out.shape == y_input.shape, f"y_out.shape: {y_out.shape}, y_input.shape: {y_input.shape}" # [B S D]
481
+
482
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card]
483
+ # take out the mask token (using mask_position and new_y_lens) and revert (using function provided by self.pattern)
484
+ assert logits.shape[1] == self.args.n_codebooks and logits.shape[3] == self.n_audio_tokens[0], logits.shape
485
+
486
+ logits_use = self.remove_mask(logits, mask_position, new_y_lens)
487
+
488
+ # revert the pattern shift for each logits section in each sample
489
+ logits_final, logit_masks = self.revert_pattern(patterns, logits_use)
490
+ assert logits_final[0][0].shape[0] == self.args.n_codebooks and logits_final[0][0].shape[2] == self.n_audio_tokens[0], f"it is: {logits_final[0][0].shape}, but should be [K, T, card]"
491
+ # testing
492
+ sample_to_test = 0
493
+ assert len(logits_final[sample_to_test]) == len(targets[sample_to_test]), f"{len(logits_final[sample_to_test])}, {len(targets[sample_to_test])}"
494
+ temp = sum([logits_final[sample_to_test][i].shape[:-1] != targets[sample_to_test][i].shape for i in range(len(targets[sample_to_test]))])
495
+ assert temp == 0, f"none equal positions: {temp}, total number of elements: {len(targets[sample_to_test])}"
496
+
497
+ logit_masked = sum([(item==False).any() for cur_mask in logit_masks for item in cur_mask])
498
+ assert logit_masked == 0, logit_masks
499
+
500
+ logits = torch.cat([torch.cat(item, dim=1) for item in logits_final], dim=1) # [K, T1+T2+T3+..., card]
501
+ targets = torch.cat([torch.cat(item, dim=1) for item in targets], dim=1) # [K, T1+T2+T3+...]
502
+ assert targets.shape[0] == logits.shape[0], f"{targets.shape}, {logits.shape}"
503
+ loss = []
504
+ ntokens = []
505
+ top10acc = []
506
+ for k, (logit, target) in enumerate(zip(logits, targets)):
507
+ loss.append(F.cross_entropy(logit, target, reduction='mean'))
508
+ top10acc.append(self.accuracy_metrics[k](logit.detach(), target))
509
+ ntokens.append(len(logit))
510
+
511
+ all_ntokens = sum(ntokens)
512
+ if self.args.codebook_weight != None:
513
+ codebook_weight = eval(self.args.codebook_weight)
514
+ else:
515
+ codebook_weight = [1.] * self.args.n_codebooks
516
+ loss = sum([l*nt*cw for l, nt, cw in zip(loss, ntokens, codebook_weight)])
517
+ top10acc_by_codebook = [t10a*nt for t10a, nt in zip(top10acc, ntokens)]
518
+ top10acc = sum(top10acc_by_codebook)
519
+ ntokens = torch.tensor(all_ntokens).to(logits.device)
520
+
521
+ return {
522
+ "loss": loss,
523
+ "top10acc": top10acc,
524
+ "top10acc_by_codebook": top10acc_by_codebook,
525
+ "effective_ntoken": ntokens,
526
+ }
527
+
528
+ def inference(
529
+ self,
530
+ x: torch.Tensor,
531
+ x_lens: torch.Tensor,
532
+ y: torch.Tensor,
533
+ mask_interval: list[torch.Tensor],
534
+ top_k: int=-100,
535
+ top_p: float=1.0,
536
+ temperature: float=1.0,
537
+ stop_repetition: int=-1,
538
+ kvcache: int=1,
539
+ silence_tokens: list[int]=[1388,1898,131],
540
+ ) -> torch.Tensor:
541
+ """
542
+ Args:
543
+ x:
544
+ A 2-D tensor of shape (1, L).
545
+ x_lens:
546
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
547
+ before padding.
548
+ y:
549
+ A 3-D tensor of shape (1, T, K).
550
+ mask_interval:
551
+ a list of tensors of shape (M, 2). contains M mask_start and mask_end. list length is actually 1, because we only support single sample inference for now
552
+ top_k: (`optional`) int
553
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
554
+ top_p: (`optional`) float
555
+ For Neucleus sampling
556
+ temperature: (`optional`) float
557
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
558
+ eog_coef: (`optional`) float
559
+ if 0, no change to eog token logits, otherwise, will adjust eog token logit based on the difference between acoustic token and phn token length
560
+ stop_repetition (`optional`) int
561
+ if not -1, will set the logits of a token that repeated this many times to be -100000, to avoid generating it again. This only apply to tokens from the first codebook
562
+ allowed_repeat_tokens (`optional`) list of ints
563
+ by inspecting the validation set, get a few tokens that indeed repeat a significant amount of time, and exclude those tokens from prevent repetition
564
+ ultimate_stop_repetition (`optional`) int
565
+ no matter that token it is, stop repetition once after this number
566
+ """
567
+ assert x.ndim == 2, x.shape
568
+ assert x_lens.ndim == 1, x_lens.shape
569
+ assert y.ndim == 3, y.shape
570
+ if self.args.special_first:
571
+ y = y + int(self.args.n_special)
572
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
573
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
574
+ assert mask_interval.shape == torch.Size((1, mask_interval.shape[1], 2)), mask_interval
575
+
576
+ # make x attention mask and x_input
577
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
578
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
579
+ x_input = self.text_embedding(x)
580
+ x_input = self.text_positional_embedding(x_input)
581
+
582
+ # make initial y_input
583
+
584
+ # make mask_interval and non_mask_interval
585
+ y_len = y.shape[2]
586
+ y_lens = torch.LongTensor([y_len]).to(y.device)
587
+ mask_interval = mask_interval[0]
588
+ starts = [item[0].item() for item in mask_interval] + [y_len]
589
+ ends = [0] + [item[1].item() for item in mask_interval]
590
+ mask_intervals = [[
591
+ (item[0].item(), item[1].item()) for item in mask_interval
592
+ ]] # a werid name change, mask_interval is input, now is mask_intervals, with one more dimension
593
+ non_mask_intervals = [[
594
+ (ns, ne) for ns, ne in zip(ends, starts)
595
+ ]]
596
+
597
+ # rearrange y
598
+ # will add have EOG in each section (SOG will be generated by the pattern class)
599
+ # but mask can be inserted later after we have shifted the input
600
+ # y could be rearranged in this way:
601
+ # [
602
+ # [tensor[4, 12], tensor[4, 45], tensor[4, 102], tensor[4, 32]], tensor[4, 22]],
603
+ # [tensor[4, 44], tensor[4, 56], tensor[4, 19]],
604
+ # ...
605
+ # ]
606
+ # for the first list of tensors (4 tensors), first 3 tensors are non_masked part, last 2 are masked part.
607
+ # NOTE #non_masked_part = #masked_part + 1
608
+ rearranged_y = self.rearrange(y, non_mask_intervals, mask_intervals)
609
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
610
+
611
+ # shift each element of y
612
+ # next we need to apply pattern shifting to each tensor, after which, we'll replace the starting tokens of each section with a token that's different from the special padding token
613
+ # [
614
+ # [empty, 1, 2, 3, eog, empty, empty, empty],
615
+ # [empty, empty, 1, 2, 3, eog, empty, empty],
616
+ # [empty, empty, empty, 1, 2, 3, eog, empty],
617
+ # [empty, empty, empty, empty, 1, 2, 3, eog]
618
+ # ]
619
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
620
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
621
+
622
+ # insert mask token at the intersction of each tensor, but *actually inserted eog as place holder*
623
+ # the position of inserted mask is also recorded
624
+ # and the mask_value, the index of the mask emb is recorded
625
+ inserted_y, mask_position, mask_value = self.insert_mask(shifted_y)
626
+ assert inserted_y[0][0].shape[0] == self.args.n_codebooks, inserted_y[0][0].shape[0]
627
+ assert inserted_y[0][1].shape == torch.Size((self.args.n_codebooks, 1)), f"this should be a mask, so should have shape {(self.args.n_codebooks, 1)}, but it's {inserted_y[0][1].shape}"
628
+
629
+ # then concat tensors that belong to the same sample (in order) then get the length of each sample, and then stack them in batch dimension, pad them with pad_token
630
+ cated_y, new_y_lens = self.cat_y(inserted_y, mask_position, y_lens) # KTB
631
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], len(inserted_y)))
632
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
633
+
634
+ ### NOTE this is different from forward, as we will remove the masked tokens
635
+ ### say there are two masked region
636
+ ### the cated_y should be like
637
+ ### [empty a a a a mask0 empty b b b mask1 empty c c mask0 empty]
638
+ ### which means we need to take the part after the last empty out
639
+ num_mask = len(mask_position[0])//2
640
+ assert num_mask == len(mask_position[0])/2, mask_position
641
+ cated_y = cated_y[:, :mask_position[0][num_mask]+2] # of shape [K,T,B]
642
+ # logging.info(f"mask_position[0][num_mask]+2: {mask_position[0][num_mask]+2}")
643
+ more_mask_value = mask_value[0][num_mask+1:] # NOTE this will be used in the generation loop for reference for inserting mask embedding
644
+ new_y_lens[0] = mask_position[0][num_mask]+2
645
+ mask_position[0] = mask_position[0][:num_mask+1]
646
+ assert mask_position[0][num_mask]+2 == cated_y.shape[1], f"num_mask: {num_mask}, mask_position: {mask_position}, cated_y.shape: {cated_y.shape}"
647
+
648
+ # embed: remember to separately embed the mask tokens
649
+ embedded_y = self.embed_y(cated_y, mask_position, [mask_value[0][:num_mask+1]]) #BTD
650
+ # assert embedded_y.shape == torch.Size((y.shape[0], max(new_y_lens), self.args.d_model)), embedded_y.shape
651
+
652
+ # positional embedding
653
+ y_input = self.audio_positional_embedding(embedded_y)
654
+
655
+ # make attention mask and padding mask
656
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
657
+ # y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
658
+
659
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
660
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
661
+
662
+
663
+ codebook_eog = [False] * self.args.n_codebooks
664
+ generated = [] # doesn't contain any empty_token, contains eog
665
+ cur_generated = []
666
+ # say 0 is empty, 4 is eog
667
+ # tensor([[ 1, 2, 3, 4, 0, 0],
668
+ # [ 0, 1, 2, 3, 4, 0],
669
+ # [ 0, 0, 1, 2, 3, 4]])
670
+ num_gen = []
671
+ cur_num_gen = 0
672
+ ##################### silence repetition handling #####################
673
+ ##################### silence repetition handling #####################
674
+ logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
675
+ consec_silence_count = 0
676
+ prev_token = None
677
+ ##################### silence repetition handling #####################
678
+ ##################### silence repetition handling #####################
679
+ # prepare the cache placeholder
680
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
681
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
682
+ # handle multi-span kv-cache
683
+ new_masked_span = False
684
+
685
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
686
+ if n_eog == 0:
687
+ logits_adjust = logits
688
+ for jj in range(1,self.args.n_codebooks):
689
+ logits_adjust[jj][self.args.eog] = -10000
690
+ logits_adjust[jj][self.args.empty_token] = -10000
691
+ ##################### silence repetition handling #####################
692
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
693
+ if logits_adjust[0, prev_token] < 0:
694
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
695
+ else:
696
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
697
+ ##################### silence repetition handling #####################
698
+ if type(logits_adjust) == list:
699
+ samples_list= []
700
+ for logit in logits_adjust:
701
+ # print(logit)
702
+ # print(logit.shape)
703
+ cur_sample = topk_sampling(
704
+ logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature
705
+ ) # [1, 1]
706
+ samples_list.append(cur_sample)
707
+ samples = torch.cat(samples_list, dim=0) # [K, 1]
708
+ else:
709
+ samples = topk_sampling(
710
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
711
+ ) # [K, 1]
712
+ assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
713
+ if cur_num_gen < self.args.n_codebooks-1:
714
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
715
+ samples[-jj, 0] = self.args.empty_token
716
+
717
+ if (
718
+ samples[0,0] == self.args.eog or torch.argmax(logits[0], dim=-1) == self.args.eog or y_input.shape[1] > x_lens[0] * 10
719
+ ): # last one means y is already too long, shouldn't happen, but put it here
720
+ samples[0,0] = self.args.eog
721
+ codebook_eog[0] = True
722
+ ##################### silence repetition handling #####################
723
+ ##################### silence repetition handling #####################
724
+ if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
725
+ consec_silence_count += 1
726
+ else:
727
+ consec_silence_count = 0
728
+ prev_token = samples[0,0]
729
+ ##################### silence repetition handling #####################
730
+ ##################### silence repetition handling #####################
731
+ return samples, codebook_eog, prev_token, consec_silence_count
732
+ else:
733
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
734
+ logits_adjust = logits
735
+ for jj in range(n_eog+1,self.args.n_codebooks):
736
+ logits_adjust[jj][self.args.eog] = -10000
737
+ logits_adjust[jj][self.args.empty_token] = -10000
738
+ if type(logits_adjust) == list:
739
+ samples_list= []
740
+ for logit in logits_adjust:
741
+ cur_sample = topk_sampling(
742
+ logit.unsqueeze(0), top_k=top_k, top_p=top_p, temperature=temperature
743
+ ) # [1, 1]
744
+ samples_list.append(cur_sample)
745
+ samples = torch.cat(samples_list, dim=0) # [K, 1]
746
+ else:
747
+ samples = topk_sampling(
748
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
749
+ ) # [K, 1]
750
+ for jj in range(n_eog):
751
+ samples[jj, 0] = self.args.empty_token
752
+ samples[n_eog, 0] = self.args.eog
753
+ codebook_eog[n_eog] = True
754
+ return samples, codebook_eog, prev_token, consec_silence_count
755
+
756
+ while True:
757
+ y_out, present = self.dec_forward(
758
+ x_input,
759
+ x_lens,
760
+ x_attention_mask,
761
+ x_padding_mask,
762
+ y_input,
763
+ new_y_lens,
764
+ y_attention_mask,
765
+ y_padding_mask,
766
+ past=past,
767
+ last_3_tokens = new_masked_span
768
+ )
769
+ if new_masked_span:
770
+ new_masked_span = False
771
+
772
+ if past != None:
773
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
774
+
775
+ y_out = y_out[:, -1:] # only take the last one
776
+
777
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
778
+ logits = logits.squeeze(0).squeeze(1) # [K card]
779
+ assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
780
+
781
+ n_eog = sum(codebook_eog)
782
+ assert n_eog < self.args.n_codebooks
783
+ if self.args.eos > 0: # eos stands for end-of-sentence, which shouldn't be used as we are doing speech editing
784
+ for jj in range(self.args.n_codebooks):
785
+ logits[jj][self.args.eos] = -10000.
786
+ # need to use a helper function to hand different n_eog cases
787
+ samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
788
+ cur_num_gen += 1
789
+ cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
790
+ # get samples_emb
791
+ samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
792
+ samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
793
+
794
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
795
+ # re-init
796
+ codebook_eog = [False] * self.args.n_codebooks
797
+ num_gen.append(cur_num_gen)
798
+ cur_num_gen = 0
799
+ generated.append(cur_generated)
800
+ cur_generated = []
801
+
802
+ # if the current mask span is the last span, then all done
803
+ # else
804
+ # append the next mask token and the four empty tokens to start the next generation
805
+ if len(more_mask_value) > 0:
806
+ next_mask_ind = more_mask_value.pop(0)
807
+ mask_emb = self.mask_embedding[next_mask_ind].unsqueeze(0).unsqueeze(0) # [1,1,D]
808
+ assert mask_emb.shape == torch.Size((1,1,self.args.d_model)), mask_emb.shape
809
+ empty_token = torch.LongTensor([self.args.empty_token]).to(y.device)
810
+ empty_emb = torch.stack([
811
+ self.audio_embedding[k](empty_token) for k in range(self.args.n_codebooks)], dim=0
812
+ ).sum(dim=0, keepdim=True) # [1,1,D]
813
+ assert empty_emb.shape == torch.Size((1,1,self.args.d_model)), empty_emb.shape
814
+ extra_emb = torch.cat([mask_emb, empty_emb], dim=1) # [1,2,D]
815
+ samples_emb = torch.cat([samples_emb, extra_emb], dim=1) # [1,3,D] # prev_last_token, mask_token, empty token
816
+ assert samples_emb.shape == torch.Size((1,3,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
817
+ ##################### silence repetition handling #####################
818
+ ##################### silence repetition handling #####################
819
+ consec_silence_count = 0
820
+ prev_token = None
821
+ ##################### silence repetition handling #####################
822
+ ##################### silence repetition handling #####################
823
+
824
+ # handling kv-caching for multi-span editing
825
+ new_masked_span = True
826
+ else:
827
+ break
828
+ else:
829
+ assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
830
+
831
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
832
+ # positional embedding
833
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
834
+ # make attention mask and padding mask
835
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
836
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
837
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
838
+
839
+ assert len(generated) == num_mask, f"len(generated): {len(generated)}, num_mask: {num_mask}"
840
+
841
+ # # combine non_masked_span with generated spans
842
+ # first need to shift the generated part back
843
+ flatten_gen = []
844
+ for l, orig_span in enumerate(generated):
845
+ span = torch.stack(orig_span, dim=0) # [T K]
846
+ span = span.transpose(1,0) # [K, T]
847
+ assert span.shape[0] == self.args.n_codebooks, span.shape
848
+ unshifted_span = []
849
+ for j, s in enumerate(span):
850
+ start_from = j
851
+ end_at = - (self.args.n_codebooks - start_from)
852
+ unshifted_span.append(s[start_from:end_at])
853
+ unshifted_span = torch.stack(unshifted_span, dim=0)
854
+
855
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
856
+ flatten_gen.append(unshifted_span)
857
+ # logging.info(f"unshfited_span: {unshifted_span.shape}")
858
+ # raise
859
+ assert len(non_mask_intervals[0]) - 1 == len(flatten_gen), f"len(non_mask_intervals[0]): {len(non_mask_intervals[0])}, len(flatten_gen): {len(flatten_gen)}"
860
+ res = []
861
+ for orig_interval, gen in zip(non_mask_intervals[0], flatten_gen):
862
+ res.append(y[0, :, orig_interval[0]:orig_interval[1]])
863
+ res.append(gen)
864
+ res.append(y[0, :, non_mask_intervals[0][-1][0]:non_mask_intervals[0][-1][1]])
865
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K,new_T] -> [1, K, new_T]
866
+
867
+ expected_y_len = y_len - sum([item[1] - item[0] for item in mask_intervals[0]]) + sum([item - self.args.n_codebooks for item in num_gen])
868
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len - sum([item[1] - item[0] for item in mask_interval]) + sum([item - self.args.n_codebooks for item in num_gen]): {y_len}-{sum([item[1] - item[0] for item in mask_interval])} + {sum([item - self.args.n_codebooks for item in num_gen])}"
869
+
870
+ if self.args.special_first:
871
+ res = res - int(self.args.n_special)
872
+
873
+ return res
874
+
875
+ def inference_tts(
876
+ self,
877
+ x: torch.Tensor,
878
+ x_lens: torch.Tensor,
879
+ y: torch.Tensor,
880
+ top_k: int=-100,
881
+ top_p: float=1.0,
882
+ temperature: float=1.0,
883
+ stop_repetition: int=3,
884
+ kvcache: int=1,
885
+ silence_tokens: list[int]=[1388,1898,131],
886
+ *kargs
887
+ ) -> torch.Tensor:
888
+ """
889
+ different from inference_tts, this implementation uses kvcache, which should have significant speed up
890
+ Args:
891
+ x:
892
+ A 2-D tensor of shape (1, L).
893
+ x_lens:
894
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
895
+ before padding.
896
+ y:
897
+ A 3-D tensor of shape (1, T, K).
898
+ top_k: (`optional`) int
899
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
900
+ top_p: (`optional`) float
901
+ For Neucleus sampling
902
+ temperature: (`optional`) float
903
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
904
+ """
905
+ eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
906
+ assert x.ndim == 2, x.shape
907
+ assert x_lens.ndim == 1, x_lens.shape
908
+ assert y.ndim == 3, y.shape
909
+ if self.args.special_first:
910
+ y = y + int(self.args.n_special)
911
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
912
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
913
+
914
+ # make x attention mask and x_input
915
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
916
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
917
+ x_input = self.text_embedding(x)
918
+ x_input = self.text_positional_embedding(x_input)
919
+
920
+ y_len = y.shape[2]
921
+ y_lens = torch.LongTensor([y_len]).to(y.device)
922
+
923
+ # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
924
+ rearranged_y = [[y[0]]]
925
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
926
+
927
+ # shift y to create the delayed pattern
928
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
929
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
930
+ assert len(shifted_y[0]) == 1, len(shifted_y[0])
931
+
932
+ # below is different from forward or inference
933
+ # where we cut this shifted part
934
+ shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
935
+ assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
936
+
937
+ # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
938
+ # next section is concate tensors of each sample to one tensor, which we also don't need
939
+ cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
940
+ new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
941
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
942
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
943
+
944
+ # replace tokens in y with the embeddings, add sum codebooks up
945
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
946
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
947
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
948
+ embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
949
+ embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
950
+
951
+ # positional embedding
952
+ y_input = self.audio_positional_embedding(embedded_y)
953
+
954
+ # make attention mask and padding mask
955
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
956
+
957
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
958
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
959
+
960
+ # entering the generation stage
961
+ # starting from line 708
962
+ codebook_eog = [False] * self.args.n_codebooks
963
+ generated = [] # doesn't contain any empty token, contain eog
964
+ cur_generated = []
965
+ # say 0 is empty, 4 is eog
966
+ # tensor([[ 1, 2, 3, 4, 0, 0],
967
+ # [ 0, 1, 2, 3, 4, 0],
968
+ # [ 0, 0, 1, 2, 3, 4]])
969
+ num_gen = []
970
+ cur_num_gen = 0
971
+ ##################### silence repetition handling #####################
972
+ ##################### silence repetition handling #####################
973
+ logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
974
+ consec_silence_count = 0
975
+ prev_token = None
976
+ ##################### silence repetition handling #####################
977
+ ##################### silence repetition handling #####################
978
+
979
+ # prepare the cache placeholder
980
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
981
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
982
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
983
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
984
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
985
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen):
986
+ if n_eog == 0:
987
+ logits_adjust = logits
988
+ for jj in range(1,self.args.n_codebooks):
989
+ logits_adjust[jj][eog_inference] = -10000
990
+ logits_adjust[jj][self.args.empty_token] = -10000
991
+ if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
992
+ logits_adjust[0][eog_inference] = -10000
993
+ ##################### silence repetition handling #####################
994
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
995
+ if logits_adjust[0, prev_token] < 0:
996
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] * (consec_silence_count - (stop_repetition-1))
997
+ else:
998
+ logits_adjust[0, prev_token] = logits_adjust[0, prev_token] / (consec_silence_count - (stop_repetition-1))
999
+ ##################### silence repetition handling #####################
1000
+ samples = topk_sampling(
1001
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
1002
+ ) # [K, 1]
1003
+ assert samples.shape == torch.Size((self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
1004
+ if cur_num_gen < self.args.n_codebooks-1:
1005
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
1006
+ samples[-jj, 0] = self.args.empty_token
1007
+
1008
+ if (
1009
+ samples[0,0] == eog_inference or torch.argmax(logits[0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[0] * (self.args.encodec_sr//5)
1010
+ ): # last one means y is already too long, shouldn't happen, but put it here
1011
+ samples[0,0] = eog_inference
1012
+ codebook_eog[0] = True
1013
+ ##################### silence repetition handling #####################
1014
+ if samples[0,0] in silence_tokens and samples[0,0] == prev_token:
1015
+ consec_silence_count += 1
1016
+ else:
1017
+ consec_silence_count = 0
1018
+ prev_token = samples[0,0]
1019
+ ##################### silence repetition handling #####################
1020
+ return samples, codebook_eog, prev_token, consec_silence_count
1021
+ else:
1022
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
1023
+ logits_adjust = logits
1024
+ for jj in range(n_eog+1,self.args.n_codebooks):
1025
+ logits_adjust[jj][eog_inference] = -10000
1026
+ logits_adjust[jj][self.args.empty_token] = -10000
1027
+ samples = topk_sampling(
1028
+ logits_adjust, top_k=top_k, top_p=top_p, temperature=temperature
1029
+ ) # [K, 1]
1030
+ for jj in range(n_eog):
1031
+ samples[jj, 0] = self.args.empty_token
1032
+ samples[n_eog, 0] = eog_inference
1033
+ codebook_eog[n_eog] = True
1034
+ return samples, codebook_eog, prev_token, consec_silence_count
1035
+ while True:
1036
+ y_out, present = self.dec_forward(
1037
+ x_input,
1038
+ x_lens,
1039
+ x_attention_mask,
1040
+ x_padding_mask,
1041
+ y_input,
1042
+ new_y_lens,
1043
+ y_attention_mask,
1044
+ y_padding_mask,
1045
+ past=past
1046
+ )
1047
+ if past != None:
1048
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
1049
+
1050
+
1051
+ y_out = y_out[:, -1:] # only take the last token
1052
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], B==S==1, so [1 K 1 card]
1053
+ logits = logits.squeeze(0).squeeze(1) # [K card]
1054
+ assert logits.shape == torch.Size((self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
1055
+
1056
+ n_eog = sum(codebook_eog)
1057
+ assert n_eog < self.args.n_codebooks
1058
+ if self.args.eos > 0: # if we are using end-of-sentence token (which is used by default), eog shouldn't be used here, as there is no masked spans
1059
+ for jj in range(self.args.n_codebooks):
1060
+ logits[jj][self.args.eog] = -10000.
1061
+
1062
+ samples, codebook_eog, prev_token, consec_silence_count = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_token, consec_silence_count, stop_repetition, silence_tokens, cur_num_gen)
1063
+
1064
+ cur_num_gen += 1
1065
+ cur_generated.append(samples.squeeze(-1)) # [K,1] -> [K]
1066
+
1067
+ # samples.shape is [K,1]
1068
+ # ge samples_emb
1069
+ samples_emb = torch.stack([self.audio_embedding[k](samples[k]) for k in range(self.args.n_codebooks)], dim=0) # [K,1,D]
1070
+ samples_emb = samples_emb.sum(dim=0,keepdim=True) # [1,1,D]
1071
+
1072
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
1073
+ codebook_eog = [False] * self.args.n_codebooks
1074
+ num_gen.append(cur_num_gen)
1075
+ cur_num_gen = 0
1076
+ generated.append(cur_generated)
1077
+ cur_generated = []
1078
+ break
1079
+ else:
1080
+ assert samples_emb.shape == torch.Size((1,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
1081
+
1082
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
1083
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
1084
+ # make attention mask and padding mask
1085
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
1086
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device)
1087
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
1088
+
1089
+ assert len(generated) == 1, f"len(generated): {len(generated)}"
1090
+
1091
+ # revert the pattern
1092
+ flatten_gen = []
1093
+ for l, orig_span in enumerate(generated):
1094
+ span = torch.stack(orig_span, dim=0) # [T, K]
1095
+ span = span.transpose(1,0) # [K, T]
1096
+ assert span.shape[0] == self.args.n_codebooks, span.shape
1097
+ unshifted_span = []
1098
+ for j, s in enumerate(span):
1099
+ start_from = j
1100
+ end_at = - (self.args.n_codebooks - start_from)
1101
+ unshifted_span.append(s[start_from:end_at])
1102
+ unshifted_span = torch.stack(unshifted_span, dim=0)
1103
+
1104
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
1105
+
1106
+ flatten_gen.append(unshifted_span)
1107
+ assert len(flatten_gen) == 1, len(flatten_gen)
1108
+
1109
+ # combine
1110
+ res = [y[0], flatten_gen[0]]
1111
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
1112
+
1113
+ expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
1114
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
1115
+
1116
+ if self.args.special_first:
1117
+ res = res - int(self.args.n_special)
1118
+ flatten_gen = flatten_gen - int(self.args.n_special)
1119
+
1120
+ return res, flatten_gen[0].unsqueeze(0)
1121
+
1122
+
1123
+ def inference_tts_batch(
1124
+ self,
1125
+ x: torch.Tensor,
1126
+ x_lens: torch.Tensor,
1127
+ y: torch.Tensor,
1128
+ top_k: int=-100,
1129
+ top_p: float=1.0,
1130
+ temperature: float=1.0,
1131
+ stop_repetition: int=3,
1132
+ kvcache: int=1,
1133
+ batch_size: int=5,
1134
+ silence_tokens: list[int]=[1388,1898,131],
1135
+ *kargs
1136
+ ) -> torch.Tensor:
1137
+ """
1138
+ have a batch size when forward passing, but they are equivalant to same example but different random seed, therefore as long as one example generated eog, we can drop all other samlpes
1139
+ different from inference_tts, this implementation uses kvcache, which should have significant speed up
1140
+ Args:
1141
+ x:
1142
+ A 2-D tensor of shape (1, L).
1143
+ x_lens:
1144
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
1145
+ before padding.
1146
+ y:
1147
+ A 3-D tensor of shape (1, T, K).
1148
+ top_k: (`optional`) int
1149
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
1150
+ top_p: (`optional`) float
1151
+ For Neucleus sampling
1152
+ temperature: (`optional`) float
1153
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
1154
+ """
1155
+ eog_inference = self.args.eos if self.args.eos>0 else self.args.eog
1156
+ assert x.ndim == 2, x.shape
1157
+ assert x_lens.ndim == 1, x_lens.shape
1158
+ assert y.ndim == 3, y.shape
1159
+ if self.args.special_first:
1160
+ y = y + int(self.args.n_special)
1161
+ y = y.transpose(2,1) # [1,T,K] -> [1,K,T]
1162
+ assert y.shape[0] == 1 and y.shape[1] == self.args.n_codebooks, y.shape # there is no padding
1163
+
1164
+ # make x attention mask and x_input
1165
+ x_attention_mask = torch.triu(torch.ones(x.shape[1], x.shape[1]), diagonal=1).bool().to(x.device)
1166
+ # x_attention_mask = torch.zeros(x.shape[1], x.shape[1]).bool().to(x.device)
1167
+ x_input = self.text_embedding(x)
1168
+ x_input = self.text_positional_embedding(x_input)
1169
+
1170
+ y_len = y.shape[2]
1171
+ y_lens = torch.LongTensor([y_len]).to(y.device)
1172
+
1173
+ # rearrange y, we don't add eog to the end, this doesn't actually do anything in the tts scenario
1174
+ rearranged_y = [[y[0]]]
1175
+ assert rearranged_y[0][0].shape[0] == self.args.n_codebooks, rearranged_y[0][0].shape
1176
+
1177
+ # shift y to create the delayed pattern
1178
+ shifted_y, patterns = self.shift(rearranged_y) # each element [K S], patterns is not used, as we directly use the original input y
1179
+ assert shifted_y[0][0].shape[0] == self.args.n_codebooks, shifted_y[0][0].shape
1180
+ assert len(shifted_y[0]) == 1, len(shifted_y[0])
1181
+
1182
+ # below is different from forward or inference
1183
+ # where we cut this shifted part
1184
+ shifted_y[0][0] = shifted_y[0][0][:, :-(self.args.n_codebooks-1)]
1185
+ assert not (shifted_y[0][0][self.args.n_codebooks:] == self.args.empty_token).any() and not (shifted_y[0][0][self.args.n_codebooks:] == self.args.eog).any(), shifted_y[0][0]
1186
+
1187
+ # next section in inference is insert mask at the intersection of each tensor in a sample, but we don't need to do that
1188
+ # next section is concate tensors of each sample to one tensor, which we also don't need
1189
+ cated_y = shifted_y[0][0].unsqueeze(-1) #[K,S]->[K,S,B]
1190
+ new_y_lens = torch.LongTensor([cated_y.shape[1]]).to(cated_y.device)
1191
+ assert cated_y.shape == torch.Size((self.args.n_codebooks, cated_y.shape[1], 1))
1192
+ assert not (cated_y == self.args.audio_pad_token).any(), cated_y
1193
+
1194
+ # replace tokens in y with the embeddings, add sum codebooks up
1195
+ embedded_y = torch.stack([self.audio_embedding[k](cated_y[k]) for k in range(self.args.n_codebooks)], dim=0) # [K, S, B, D]
1196
+ assert embedded_y.shape[0] == self.args.n_codebooks, embedded_y.shape
1197
+ assert embedded_y.shape[-1] == self.args.d_model, embedded_y.shape
1198
+ embedded_y = embedded_y.sum(dim=0) # [K,S,B,D]->[S,B,D]
1199
+ embedded_y = embedded_y.transpose(1,0) # [S,B,D]->[B,S,D]
1200
+
1201
+ # positional embedding
1202
+ y_input = self.audio_positional_embedding(embedded_y)
1203
+
1204
+ # make attention mask and padding mask
1205
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
1206
+
1207
+ x_padding_mask = torch.full((1,x_lens[0]), False).to(x.device)
1208
+ y_padding_mask = torch.full((1,new_y_lens[0]), False).to(y.device)
1209
+
1210
+ # entering the generation stage
1211
+ # starting from line 708
1212
+ codebook_eog = [False] * self.args.n_codebooks
1213
+ generated = [] # doesn't contain any empty token, contain eog
1214
+ cur_generated = [[] for _ in range(batch_size)]
1215
+ # say 0 is empty, 4 is eog
1216
+ # tensor([[ 1, 2, 3, 4, 0, 0],
1217
+ # [ 0, 1, 2, 3, 4, 0],
1218
+ # [ 0, 0, 1, 2, 3, 4]])
1219
+ num_gen = []
1220
+ cur_num_gen = 0
1221
+ ##################### silence repetition handling #####################
1222
+ ##################### silence repetition handling #####################
1223
+ logging.info(f"silence tokens: {silence_tokens}, note that if you are not using the pretrained encodec 6f79c6a8, make sure you specified it yourself, rather than using the default")
1224
+ consec_silence_counts = [0 for _ in range(batch_size)]
1225
+ prev_tokens = [None for _ in range(batch_size)]
1226
+ ##################### silence repetition handling #####################
1227
+ ##################### silence repetition handling #####################
1228
+
1229
+ # prepare the cache placeholder
1230
+ # n_layers, 2, bsz, num_heads, src_len, head_dim
1231
+ past = torch.ones([self.args.num_decoder_layers, 2, x.shape[0]], device=x.device, dtype=torch.float32) if kvcache else None
1232
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1233
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1234
+ # logging.info(f"number of decoder layers: {self.args.num_decoder_layers}")
1235
+ keep = None # NOTE: this very important, tells which sample to keep
1236
+ def sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep):
1237
+ if n_eog == 0:
1238
+ logits_adjust = logits
1239
+ for jj in range(1,self.args.n_codebooks):
1240
+ logits_adjust[:,jj,eog_inference] = -10000
1241
+ logits_adjust[:,jj,self.args.empty_token] = -10000
1242
+ if cur_num_gen <= self.args.encodec_sr // 5: # this shouldn't happen, but just in case the model stopped too early
1243
+ logits_adjust[:,:,eog_inference] = -10000
1244
+ ##################### silence repetition handling #####################
1245
+ for b in range(batch_size):
1246
+ prev_token = prev_tokens[b]
1247
+ consec_silence_count = consec_silence_counts[b]
1248
+ if stop_repetition > 0 and prev_token in silence_tokens and consec_silence_count > stop_repetition:
1249
+ if logits_adjust[b, 0, prev_token] < 0:
1250
+ logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] * (consec_silence_count - (stop_repetition-1))
1251
+ else:
1252
+ logits_adjust[b, 0, prev_token] = logits_adjust[b, 0, prev_token] / (consec_silence_count - (stop_repetition-1))
1253
+ ##################### silence repetition handling #####################
1254
+ samples = topk_sampling(
1255
+ logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature
1256
+ ) # [B*K, 1]
1257
+ samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
1258
+ assert samples.shape == torch.Size((batch_size, self.args.n_codebooks, 1)), f"samples.shape: {samples.shape}"
1259
+ for b in range(batch_size):
1260
+ if cur_num_gen < self.args.n_codebooks-1:
1261
+ for jj in range(1, self.args.n_codebooks - cur_num_gen):
1262
+ samples[b, -jj, 0] = self.args.empty_token
1263
+
1264
+ if (
1265
+ samples[b,0,0] == eog_inference or torch.argmax(logits[b,0], dim=-1) == eog_inference or y_input.shape[1] > x_lens[b] * (self.args.encodec_sr//5)
1266
+ ): # last one means y is already too long, shouldn't happen, but put it here
1267
+ samples[b,0,0] = eog_inference
1268
+ codebook_eog[0] = True
1269
+ keep = b # NOTE keep is a very important variable, we only return this one, note that if eog shows up in two samples, keep will be overwritten by the later one (or the last one)
1270
+ ##################### silence repetition handling #####################
1271
+ if samples[b,0,0] in silence_tokens and samples[b,0,0] == prev_tokens[b]:
1272
+ consec_silence_counts[b] += 1
1273
+ else:
1274
+ consec_silence_counts[b] = 0
1275
+ prev_tokens[b] = samples[b,0,0]
1276
+ ##################### silence repetition handling #####################
1277
+ return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
1278
+ else:
1279
+ assert sum(codebook_eog[i] for i in range(n_eog)) == n_eog, f"codebook_eog: {codebook_eog}, but n_eog: {n_eog}"
1280
+ logits_adjust = logits
1281
+ for jj in range(n_eog+1,self.args.n_codebooks):
1282
+ logits_adjust[:,jj,eog_inference] = -10000
1283
+ logits_adjust[:,jj,self.args.empty_token] = -10000
1284
+ samples = topk_sampling(
1285
+ logits_adjust.reshape(batch_size * self.args.n_codebooks, logits_adjust.shape[-1]), top_k=top_k, top_p=top_p, temperature=temperature
1286
+ ) # [B, K, 1]
1287
+ samples = samples.reshape(batch_size, self.args.n_codebooks, 1)
1288
+ for jj in range(n_eog):
1289
+ samples[keep, jj, 0] = self.args.empty_token
1290
+ samples[keep, n_eog, 0] = eog_inference
1291
+ codebook_eog[n_eog] = True
1292
+ return samples, codebook_eog, prev_tokens, consec_silence_counts, keep
1293
+ while True:
1294
+ # if cur_num_gen > 0, should have everything in kvcache, so only pass in the last token
1295
+ # in the first generation step, we repeat each tensor to make their first dimension of length the batch size
1296
+ if cur_num_gen == 0:
1297
+ assert x_input.ndim == 3 and x_input.shape[0] == 1, x_input.shape
1298
+ assert x_padding_mask.ndim == 2 and x_padding_mask.shape[0] == 1, x_padding_mask.shape
1299
+ assert y_input.ndim == 3 and y_input.shape[0] == 1 and y_input.shape[1] == new_y_lens[0], y_input.shape
1300
+ assert embedded_y.ndim == 3 and embedded_y.shape[0] == 1 and embedded_y.shape[1] == new_y_lens[0], embedded_y.shape
1301
+ x_input = x_input.repeat(batch_size, 1, 1)
1302
+ x_lens = x_lens.repeat(batch_size)
1303
+ # x_attention_mask = x_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
1304
+ x_padding_mask = x_padding_mask.repeat(batch_size, 1)
1305
+ y_input = y_input.repeat(batch_size, 1, 1)
1306
+ new_y_lens = new_y_lens.repeat(batch_size)
1307
+ # y_attention_mask = y_attention_mask.repeat(batch_size, 1, 1) # no need to work with attention mask, it doesn't contain batch dimension
1308
+ y_padding_mask = y_padding_mask.repeat(batch_size, 1)
1309
+ embedded_y = embedded_y.repeat(batch_size, 1, 1) # will be used to concat with newly generated token embedding
1310
+ past = past.repeat(1, 1, batch_size) if past != None else None
1311
+ else:
1312
+ assert x_input.shape[0] == batch_size and x_padding_mask.shape[0] == batch_size and y_input.shape[0] == batch_size and new_y_lens.shape[0] == batch_size, f"x_input.shape: {x_input.shape}, x_padding_mask.shape: {x_padding_mask.shape}, y_input.shape: {y_input.shape}, new_y_lens.shape: {new_y_lens.shape}"
1313
+ y_out, present = self.dec_forward(
1314
+ x_input,
1315
+ x_lens,
1316
+ x_attention_mask,
1317
+ x_padding_mask,
1318
+ y_input,
1319
+ new_y_lens,
1320
+ y_attention_mask,
1321
+ y_padding_mask,
1322
+ past=past
1323
+ )
1324
+ if past != None:
1325
+ past = torch.cat([past, present.to(past.dtype)], dim=-2) if past.ndim > 3 else present.to(past.dtype)
1326
+
1327
+ # if no eog emerges, y_out should have batch size of batch_size
1328
+ if sum(codebook_eog) == 0:
1329
+ assert y_out.shape[0] == batch_size and y_out.ndim == 3, y_out.shape
1330
+ y_out = y_out[:, -1:] # only take the last token
1331
+ logits = torch.stack([self.predict_layer[i](y_out) for i in range(self.args.n_codebooks)], dim=1) # [B K S card], S==1, so [B K 1 card]
1332
+ logits = logits.squeeze(2) # [B K card]
1333
+ assert logits.shape == torch.Size((batch_size, self.args.n_codebooks, self.n_audio_tokens[0])), f"{logits.shape}"
1334
+
1335
+ n_eog = sum(codebook_eog)
1336
+ if self.args.eos > 0:
1337
+ for jj in range(self.args.n_codebooks):
1338
+ logits[:,jj,self.args.eog] = -10000.
1339
+ samples, codebook_eog, prev_tokens, consec_silence_counts, keep = sample_helper(n_eog, logits, codebook_eog, top_k, top_p, temperature, prev_tokens, consec_silence_counts, stop_repetition, silence_tokens, cur_num_gen, keep)
1340
+
1341
+ cur_num_gen += 1
1342
+ if sum(codebook_eog) == 0: # no eog yet, keep batch_size of samples
1343
+ assert keep == None
1344
+ for b in range(batch_size):
1345
+ cur_generated[b].append(samples[b].squeeze(-1))
1346
+ elif sum(codebook_eog) == 1: # the first eog just showed up in this step
1347
+ assert keep != None
1348
+ cur_generated = cur_generated[keep]
1349
+ cur_generated.append(samples[keep].squeeze(-1))
1350
+ else: # we are generating the rest eogs for the 'keep' sample
1351
+ cur_generated.append(samples[keep].squeeze(-1))
1352
+
1353
+ # samples.shape is [K,1]
1354
+ # ge samples_emb
1355
+ samples_emb = torch.stack([self.audio_embedding[k](samples[:, k]) for k in range(self.args.n_codebooks)], dim=1) # [B, K,1,D]
1356
+ assert samples_emb.shape == torch.Size([batch_size, self.args.n_codebooks, 1, self.args.d_model])
1357
+ samples_emb = samples_emb.sum(dim=1,keepdim=False) # [B,1,D]
1358
+ if sum(codebook_eog) == self.args.n_codebooks: # generation for the current span is done
1359
+ codebook_eog = [False] * self.args.n_codebooks
1360
+ num_gen.append(cur_num_gen)
1361
+ cur_num_gen = 0
1362
+ generated.append(cur_generated)
1363
+ cur_generated = [[] for _ in range(batch_size)]
1364
+ break
1365
+ else:
1366
+ assert samples_emb.shape == torch.Size((batch_size,1,self.args.d_model)), f"samples_emb.shape: {samples_emb.shape}"
1367
+
1368
+ embedded_y = torch.cat([embedded_y, samples_emb], dim=1)
1369
+ y_input = self.audio_positional_embedding(embedded_y) # [B T D]
1370
+ # make attention mask and padding mask
1371
+ y_attention_mask = torch.triu(torch.ones(y_input.shape[1], y_input.shape[1]), diagonal=1).bool().to(y.device)
1372
+ new_y_lens = torch.LongTensor([y_input.shape[1]]).to(y.device).repeat(batch_size)
1373
+ y_padding_mask = torch.full((batch_size,new_y_lens[0]), False).to(y.device)
1374
+
1375
+ assert len(generated) == 1, f"len(generated): {len(generated)}"
1376
+
1377
+ # revert the pattern
1378
+ flatten_gen = []
1379
+ for l, orig_span in enumerate(generated):
1380
+ span = torch.stack(orig_span, dim=0) # [T, K]
1381
+ span = span.transpose(1,0) # [K, T]
1382
+ assert span.shape[0] == self.args.n_codebooks, span.shape
1383
+ unshifted_span = []
1384
+ for j, s in enumerate(span):
1385
+ start_from = j
1386
+ end_at = - (self.args.n_codebooks - start_from)
1387
+ unshifted_span.append(s[start_from:end_at])
1388
+ unshifted_span = torch.stack(unshifted_span, dim=0)
1389
+
1390
+ assert unshifted_span.shape[1] == num_gen[l] - self.args.n_codebooks, f"len(unshifted_spans[0]): {len(unshifted_span[0])}, num_gen[l]: {num_gen[l]}"
1391
+
1392
+ flatten_gen.append(unshifted_span)
1393
+ assert len(flatten_gen) == 1, len(flatten_gen)
1394
+
1395
+ # combine
1396
+ res = [y[0], flatten_gen[0]]
1397
+ res = torch.cat(res, dim=1).unsqueeze(0) # [K, new_t] -> [1, K, new_T]
1398
+
1399
+ expected_y_len = y_len + sum([item - self.args.n_codebooks for item in num_gen])
1400
+ assert res.shape == torch.Size((1, self.args.n_codebooks, expected_y_len)), f"res.shape: {res.shape}, expected_y_len: {expected_y_len}. y_len + sum([item - self.args.n_codebooks for item in num_gen]): {y_len} + {sum([item - self.args.n_codebooks for item in num_gen])}"
1401
+
1402
+ if self.args.special_first:
1403
+ res = res - int(self.args.n_special)
1404
+ flatten_gen = flatten_gen - int(self.args.n_special)
1405
+
1406
+ return res, flatten_gen[0].unsqueeze(0)
pretrained_models/encodec_4cb2048_giga.th ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:caa0c595d4919527a9728d627150aa2a0b15b6d117b21855165851333dc63378
3
+ size 1167842971
pretrained_models/giga330M.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:35e028b8c5237cb4a6050ca81d4569b98e3a34ad9175fa252f7b1d13e6a9ad26
3
+ size 1746844161
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ -e git+https://github.com/facebookresearch/audiocraft.git@c5157b5bf14bf83449c17ea1eeb66c19fb4bc7f0#egg=audiocraft
2
+ xformers==0.0.22
3
+ torchaudio==2.0.2
4
+ torch==2.0.1
5
+ phonemizer==3.2.1
6
+ gradio==3.50.2
7
+ nltk>=3.8.1
8
+ openai-whisper>=20231117
9
+ spaces