diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8e27cfa1dfdcf888a72f0dea06481968f974ae87 --- /dev/null +++ b/.gitignore @@ -0,0 +1,36 @@ +.idea/ +tensorboard_logs/ +Corpora/ +Models/ +audios/ +Preprocessing/glottolog/ +Preprocessing/multilinguality/datasets/ +apex/ +pretrained_models/ +.tmp/ +.vscode/ +split/ +singing/ +toucan_conda_venv/ +venv/ +vis/ +Utility/storage_config.py +Preprocessing/multilinguality/distance_datasets + + +*_graph +app.py +gradio* +*playground* +run_phonemizer.py + +*.pt +*.out +*.wav +*.flac +*.json +*.pyc +*.png +*.pdf +*.pkl +*.gif \ No newline at end of file diff --git a/Architectures/Aligner/Aligner.py b/Architectures/Aligner/Aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..8396061e9e4d50d95c0932641803a5cff4dd3de3 --- /dev/null +++ b/Architectures/Aligner/Aligner.py @@ -0,0 +1,164 @@ +""" +taken and adapted from https://github.com/as-ideas/DeepForcedAligner +""" +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.multiprocessing +import torch.nn as nn +from torch.nn import CTCLoss +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend + + +class BatchNormConv(nn.Module): + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int): + super().__init__() + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, + stride=1, padding=kernel_size // 2, bias=False) + self.bnorm = nn.BatchNorm1d(out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = x.transpose(1, 2) + x = self.conv(x) + x = self.relu(x) + x = self.bnorm(x) + x = x.transpose(1, 2) + return x + + +class Aligner(torch.nn.Module): + + def __init__(self, + n_features=128, + num_symbols=145, + lstm_dim=512, + conv_dim=512): + super().__init__() + self.convs = nn.ModuleList([ + BatchNormConv(n_features, conv_dim, 3), + nn.Dropout(p=0.5), + BatchNormConv(conv_dim, conv_dim, 3), + nn.Dropout(p=0.5), + BatchNormConv(conv_dim, conv_dim, 3), + nn.Dropout(p=0.5), + BatchNormConv(conv_dim, conv_dim, 3), + nn.Dropout(p=0.5), + BatchNormConv(conv_dim, conv_dim, 3), + nn.Dropout(p=0.5), + ]) + self.rnn = torch.nn.LSTM(conv_dim, lstm_dim, batch_first=True, bidirectional=True) + self.proj = torch.nn.Linear(2 * lstm_dim, num_symbols) + self.tf = ArticulatoryCombinedTextFrontend(language="eng") + self.ctc_loss = CTCLoss(blank=144, zero_infinity=True) + self.vector_to_id = dict() + + def forward(self, x, lens=None): + for conv in self.convs: + x = conv(x) + + if lens is not None: + x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False) + x, _ = self.rnn(x) + if lens is not None: + x, _ = pad_packed_sequence(x, batch_first=True) + + x = self.proj(x) + + return x + + @torch.inference_mode() + def inference(self, features, tokens, save_img_for_debug=None, train=False, pathfinding="MAS", return_ctc=False): + if not train: + tokens_indexed = self.tf.text_vectors_to_id_sequence(text_vector=tokens) # first we need to convert the articulatory vectors to IDs, so we can apply dijkstra or viterbi + tokens = np.asarray(tokens_indexed) + else: + tokens = tokens.cpu().detach().numpy() + + pred = self(features.unsqueeze(0)) + if return_ctc: + ctc_loss = self.ctc_loss(pred.transpose(0, 1).log_softmax(2), torch.LongTensor(tokens), torch.LongTensor([len(pred[0])]), + torch.LongTensor([len(tokens)])).item() + pred = pred.squeeze().cpu().detach().numpy() + pred_max = pred[:, tokens] + + # run monotonic alignment search + + alignment_matrix = binarize_alignment(pred_max) + + if save_img_for_debug is not None: + phones = list() + for index in tokens: + for phone in self.tf.phone_to_id: + if self.tf.phone_to_id[phone] == index: + phones.append(phone) + fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5)) + + ax.imshow(alignment_matrix, interpolation='nearest', aspect='auto', origin="lower", cmap='cividis') + ax.set_ylabel("Mel-Frames") + ax.set_xticks(range(len(pred_max[0]))) + ax.set_xticklabels(labels=phones) + ax.set_title("MAS Path") + + plt.tight_layout() + fig.savefig(save_img_for_debug) + fig.clf() + plt.close() + + if return_ctc: + return alignment_matrix, ctc_loss + return alignment_matrix + + + +def binarize_alignment(alignment_prob): + """ + # Implementation by: + # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/alignment.py + # https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/attn_loss_function.py + + Binarizes alignment with MAS. + """ + # assumes features x text + opt = np.zeros_like(alignment_prob) + alignment_prob = alignment_prob + (np.abs(alignment_prob).max() + 1.0) # make all numbers positive and add an offset to avoid log of 0 later + alignment_prob * alignment_prob * (1.0 / alignment_prob.max()) # normalize to (0, 1] + attn_map = np.log(alignment_prob) + attn_map[0, 1:] = -np.inf + log_p = np.zeros_like(attn_map) + log_p[0, :] = attn_map[0, :] + prev_ind = np.zeros_like(attn_map, dtype=np.int64) + for i in range(1, attn_map.shape[0]): + for j in range(attn_map.shape[1]): # for each text dim + prev_log = log_p[i - 1, j] + prev_j = j + if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]: + prev_log = log_p[i - 1, j - 1] + prev_j = j - 1 + log_p[i, j] = attn_map[i, j] + prev_log + prev_ind[i, j] = prev_j + # now backtrack + curr_text_idx = attn_map.shape[1] - 1 + for i in range(attn_map.shape[0] - 1, -1, -1): + opt[i, curr_text_idx] = 1 + curr_text_idx = prev_ind[i, curr_text_idx] + opt[0, curr_text_idx] = 1 + return opt + + +if __name__ == '__main__': + tf = ArticulatoryCombinedTextFrontend(language="eng") + from Preprocessing.HiFiCodecAudioPreprocessor import CodecAudioPreprocessor + + cap = CodecAudioPreprocessor(input_sr=-1) + dummy_codebook_indexes = torch.randint(low=0, high=1023, size=[9, 20]) + codebook_frames = cap.indexes_to_codec_frames(dummy_codebook_indexes) + alignment = Aligner().inference(codebook_frames.transpose(0, 1), tokens=tf.string_to_tensor("Hello world")) + print(alignment.shape) + plt.imshow(alignment, origin="lower", cmap="GnBu") + plt.show() diff --git a/Architectures/Aligner/CodecAlignerDataset.py b/Architectures/Aligner/CodecAlignerDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..861353263ee42d73ecd8cbd6d7491826dc311aac --- /dev/null +++ b/Architectures/Aligner/CodecAlignerDataset.py @@ -0,0 +1,271 @@ +import os +import random + +import librosa +import soundfile as sf +import torch +from speechbrain.pretrained import EncoderClassifier +from torch.multiprocessing import Manager +from torch.multiprocessing import Process +from torch.utils.data import Dataset +from torchaudio.transforms import Resample +from tqdm import tqdm + +from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor +from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend +from Utility.storage_config import MODELS_DIR + + +class CodecAlignerDataset(Dataset): + + def __init__(self, + path_to_transcript_dict, + cache_dir, + lang, + loading_processes, + device, + min_len_in_seconds=1, + max_len_in_seconds=15, + rebuild_cache=False, + verbose=False, + phone_input=False, + allow_unknown_symbols=False, + gpu_count=1, + rank=0): + self.gpu_count = gpu_count + self.rank = rank + if not os.path.exists(os.path.join(cache_dir, "aligner_train_cache.pt")) or rebuild_cache: + self._build_dataset_cache(path_to_transcript_dict=path_to_transcript_dict, + cache_dir=cache_dir, + lang=lang, + loading_processes=loading_processes, + device=device, + min_len_in_seconds=min_len_in_seconds, + max_len_in_seconds=max_len_in_seconds, + verbose=verbose, + phone_input=phone_input, + allow_unknown_symbols=allow_unknown_symbols, + gpu_count=gpu_count, + rank=rank) + self.lang = lang + self.device = device + self.cache_dir = cache_dir + self.tf = ArticulatoryCombinedTextFrontend(language=self.lang) + cache = torch.load(os.path.join(self.cache_dir, "aligner_train_cache.pt"), map_location='cpu') + self.speaker_embeddings = cache[2] + self.datapoints = cache[0] + if self.gpu_count > 1: + # we only keep a chunk of the dataset in memory to avoid redundancy. Which chunk, we figure out using the rank. + while len(self.datapoints) % self.gpu_count != 0: + self.datapoints.pop(-1) # a bit unfortunate, but if you're using multiple GPUs, you probably have a ton of datapoints anyway. + chunksize = int(len(self.datapoints) / self.gpu_count) + self.datapoints = self.datapoints[chunksize * self.rank:chunksize * (self.rank + 1)] + self.speaker_embeddings = self.speaker_embeddings[chunksize * self.rank:chunksize * (self.rank + 1)] + print(f"Loaded an Aligner dataset with {len(self.datapoints)} datapoints from {cache_dir}.") + + def _build_dataset_cache(self, + path_to_transcript_dict, + cache_dir, + lang, + loading_processes, + device, + min_len_in_seconds=1, + max_len_in_seconds=15, + verbose=False, + phone_input=False, + allow_unknown_symbols=False, + gpu_count=1, + rank=0 + ): + if gpu_count != 1: + import sys + print("Please run the feature extraction using only a single GPU. Multi-GPU is only supported for training.") + sys.exit() + os.makedirs(cache_dir, exist_ok=True) + if type(path_to_transcript_dict) != dict: + path_to_transcript_dict = path_to_transcript_dict() # in this case we passed a function instead of the dict, so that the function isn't executed if not necessary. + torch.multiprocessing.set_start_method('spawn', force=True) + resource_manager = Manager() + self.path_to_transcript_dict = resource_manager.dict(path_to_transcript_dict) + key_list = list(self.path_to_transcript_dict.keys()) + with open(os.path.join(cache_dir, "files_used.txt"), encoding='utf8', mode="w") as files_used_note: + files_used_note.write(str(key_list)) + fisher_yates_shuffle(key_list) + # build cache + print("... building dataset cache ...") + self.result_pool = resource_manager.list() + # make processes + key_splits = list() + process_list = list() + for i in range(loading_processes): + key_splits.append( + key_list[i * len(key_list) // loading_processes:(i + 1) * len(key_list) // loading_processes]) + for key_split in key_splits: + process_list.append( + Process(target=self._cache_builder_process, + args=(key_split, + lang, + min_len_in_seconds, + max_len_in_seconds, + verbose, + device, + phone_input, + allow_unknown_symbols), + daemon=True)) + process_list[-1].start() + for process in process_list: + process.join() + print("pooling results...") + pooled_datapoints = list() + for chunk in self.result_pool: + for datapoint in chunk: + pooled_datapoints.append(datapoint) # unpack into a joint list + self.result_pool = pooled_datapoints + del pooled_datapoints + print("converting text to tensors...") + text_tensors = [torch.ShortTensor(x[0]) for x in self.result_pool] # turn everything back to tensors (had to turn it to np arrays to avoid multiprocessing issues) + print("converting speech to tensors...") + speech_tensors = [torch.ShortTensor(x[1]) for x in self.result_pool] + print("converting waves to tensors...") + norm_waves = [torch.Tensor(x[2]) for x in self.result_pool] + print("unpacking file list...") + filepaths = [x[3] for x in self.result_pool] + del self.result_pool + self.datapoints = list(zip(text_tensors, speech_tensors)) + del text_tensors + del speech_tensors + print("done!") + + # add speaker embeddings + self.speaker_embeddings = list() + speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", + run_opts={"device": str(device)}, + savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa")) + with torch.inference_mode(): + for wave in tqdm(norm_waves): + self.speaker_embeddings.append(speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(device).unsqueeze(0)).squeeze().cpu()) + + # save to cache + if len(self.datapoints) == 0: + raise RuntimeError # something went wrong and there are no datapoints + torch.save((self.datapoints, None, self.speaker_embeddings, filepaths), + os.path.join(cache_dir, "aligner_train_cache.pt")) + + def _cache_builder_process(self, + path_list, + lang, + min_len, + max_len, + verbose, + device, + phone_input, + allow_unknown_symbols): + process_internal_dataset_chunk = list() + torch.hub._validate_not_a_forked_repo = lambda a, b, c: True # torch 1.9 has a bug in the hub loading, this is a workaround + # careful: assumes 16kHz or 8kHz audio + silero_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', + model='silero_vad', + force_reload=False, + onnx=False, + verbose=False) + (get_speech_timestamps, + save_audio, + read_audio, + VADIterator, + collect_chunks) = utils + torch.set_grad_enabled(True) # finding this issue was very infuriating: silero sets + # this to false globally during model loading rather than using inference mode or no_grad + silero_model = silero_model.to(device) + silence = torch.zeros([16000 // 4], device=device) + tf = ArticulatoryCombinedTextFrontend(language=lang) + _, sr = sf.read(path_list[0]) + assumed_sr = sr + ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device) + resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device) + + for path in tqdm(path_list): + if self.path_to_transcript_dict[path].strip() == "": + continue + + try: + wave, sr = sf.read(path) + except: + print(f"Problem with an audio file: {path}") + continue + + wave = librosa.to_mono(wave) + + if sr != assumed_sr: + assumed_sr = sr + ap = CodecAudioPreprocessor(input_sr=assumed_sr, device=device) + resample = Resample(orig_freq=assumed_sr, new_freq=16000).to(device) + print(f"{path} has a different sampling rate --> adapting the codec processor") + + try: + norm_wave = resample(torch.tensor(wave).float().to(device)) + except ValueError: + continue + dur_in_seconds = len(norm_wave) / 16000 + if not (min_len <= dur_in_seconds <= max_len): + if verbose: + print(f"Excluding {path} because of its duration of {round(dur_in_seconds, 2)} seconds.") + continue + + # remove silences from front and back, then add constant 1/4th second silences back to front and back + with torch.no_grad(): + speech_timestamps = get_speech_timestamps(norm_wave, silero_model, sampling_rate=16000) + try: + result = norm_wave[speech_timestamps[0]['start']:speech_timestamps[-1]['end']] + except IndexError: + print("Audio might be too short to cut silences from front and back.") + continue + wave = torch.cat([silence, result, silence]) + + # raw audio preprocessing is done + transcript = self.path_to_transcript_dict[path] + + try: + try: + cached_text = tf.string_to_tensor(transcript, handle_missing=False, input_phonemes=phone_input).squeeze(0).cpu().numpy() + except KeyError: + cached_text = tf.string_to_tensor(transcript, handle_missing=True, input_phonemes=phone_input).squeeze(0).cpu().numpy() + if not allow_unknown_symbols: + continue # we skip sentences with unknown symbols + except ValueError: + # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample. + continue + except KeyError: + # this can happen for Mandarin Chinese, when the syllabification of pinyin doesn't work. In that case, we just skip the sample. + continue + + cached_speech = ap.audio_to_codebook_indexes(audio=wave, current_sampling_rate=16000).transpose(0, 1).cpu().numpy() + process_internal_dataset_chunk.append([cached_text, + cached_speech, + result.cpu().detach().numpy(), + path]) + self.result_pool.append(process_internal_dataset_chunk) + + def __getitem__(self, index): + text_vector = self.datapoints[index][0] + tokens = self.tf.text_vectors_to_id_sequence(text_vector=text_vector) + tokens = torch.LongTensor(tokens) + token_len = torch.LongTensor([len(tokens)]) + + codes = self.datapoints[index][1] + if codes.size()[0] != 24: # no clue why this is sometimes the case + codes = codes.transpose(0, 1) + + return tokens, \ + token_len, \ + codes, \ + None, \ + self.speaker_embeddings[index] + + def __len__(self): + return len(self.datapoints) + + +def fisher_yates_shuffle(lst): + for i in range(len(lst) - 1, 0, -1): + j = random.randint(0, i) + lst[i], lst[j] = lst[j], lst[i] diff --git a/Architectures/Aligner/README.md b/Architectures/Aligner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..88a0c156d8837eea2472a0ccc617fcd325c66f8b --- /dev/null +++ b/Architectures/Aligner/README.md @@ -0,0 +1 @@ +Everything that is concerned with training and using the aligner model is contained in this directory. It is recommended to use the universal aligner model that we supply in the GitHub releases. \ No newline at end of file diff --git a/Architectures/Aligner/Reconstructor.py b/Architectures/Aligner/Reconstructor.py new file mode 100644 index 0000000000000000000000000000000000000000..066f886f4080c1347f731a70a0938d2e571939ca --- /dev/null +++ b/Architectures/Aligner/Reconstructor.py @@ -0,0 +1,40 @@ +import torch +import torch.multiprocessing +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from Utility.utils import make_non_pad_mask + + +class Reconstructor(torch.nn.Module): + + def __init__(self, + n_features=128, + num_symbols=145, + speaker_embedding_dim=192, + lstm_dim=256): + super().__init__() + self.in_proj = torch.nn.Linear(num_symbols + speaker_embedding_dim, lstm_dim) + self.rnn1 = torch.nn.LSTM(lstm_dim, lstm_dim, batch_first=True, bidirectional=True) + self.rnn2 = torch.nn.LSTM(2 * lstm_dim, lstm_dim, batch_first=True, bidirectional=True) + self.out_proj = torch.nn.Linear(2 * lstm_dim, n_features) + self.l1_criterion = torch.nn.L1Loss(reduction="none") + self.l2_criterion = torch.nn.MSELoss(reduction="none") + + def forward(self, x, lens, ys): + x = self.in_proj(x) + x = pack_padded_sequence(x, lens.cpu(), batch_first=True, enforce_sorted=False) + x, _ = self.rnn1(x) + x, _ = self.rnn2(x) + x, _ = pad_packed_sequence(x, batch_first=True) + x = self.out_proj(x) + out_masks = make_non_pad_mask(lens).unsqueeze(-1).to(ys.device) + out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() + out_weights /= ys.size(0) * ys.size(2) + l1_loss = self.l1_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum() + l2_loss = self.l2_criterion(x, ys).mul(out_weights).masked_select(out_masks).sum() + return l1_loss + l2_loss + + +if __name__ == '__main__': + print(sum(p.numel() for p in Reconstructor().parameters() if p.requires_grad)) \ No newline at end of file diff --git a/Architectures/Aligner/__init__.py b/Architectures/Aligner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Architectures/Aligner/autoaligner_train_loop.py b/Architectures/Aligner/autoaligner_train_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..cf10bb6fc4997d19426c11282738a1abebe3ce00 --- /dev/null +++ b/Architectures/Aligner/autoaligner_train_loop.py @@ -0,0 +1,188 @@ +import os +import time + +import torch +import torch.multiprocessing +from torch.nn.utils.rnn import pad_sequence +from torch.optim import RAdam +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from Architectures.Aligner.Aligner import Aligner +from Architectures.Aligner.Reconstructor import Reconstructor +from Preprocessing.AudioPreprocessor import AudioPreprocessor +from Preprocessing.EnCodecAudioPreprocessor import CodecAudioPreprocessor + + +def collate_and_pad(batch): + # text, text_len, speech, speech_len, embed + return (pad_sequence([datapoint[0] for datapoint in batch], batch_first=True), + torch.stack([datapoint[1] for datapoint in batch]).squeeze(1), + [datapoint[2] for datapoint in batch], + None, + torch.stack([datapoint[4] for datapoint in batch]).squeeze()) + + +def train_loop(train_dataset, + device, + save_directory, + batch_size, + steps, + path_to_checkpoint=None, + fine_tune=False, + resume=False, + debug_img_path=None, + use_reconstruction=True, + gpu_count=1, + rank=0, + steps_per_checkpoint=None): + """ + Args: + resume: whether to resume from the most recent checkpoint + steps: How many steps to train + path_to_checkpoint: reloads a checkpoint to continue training from there + fine_tune: whether to load everything from a checkpoint, or only the model parameters + train_dataset: Pytorch Dataset Object for train data + device: Device to put the loaded tensors on + save_directory: Where to save the checkpoints + batch_size: How many elements should be loaded at once + debug_img_path: where to put images of the training progress if desired + use_reconstruction: whether to use the auxiliary reconstruction procedure/loss, which can make the alignment sharper + """ + os.makedirs(save_directory, exist_ok=True) + torch.multiprocessing.set_sharing_strategy('file_system') + torch.multiprocessing.set_start_method('spawn', force=True) + + if steps_per_checkpoint is None: + steps_per_checkpoint = len(train_dataset) // batch_size + ap = CodecAudioPreprocessor(input_sr=-1, device=device) # only used to transform features into continuous matrices + spectrogram_extractor = AudioPreprocessor(input_sr=16000, output_sr=16000, device=device) + + asr_model = Aligner().to(device) + optim_asr = RAdam(asr_model.parameters(), lr=0.0001) + + tiny_tts = Reconstructor().to(device) + optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001) + + if gpu_count > 1: + asr_model.to(rank) + tiny_tts.to(rank) + asr_model = torch.nn.parallel.DistributedDataParallel( + asr_model, + device_ids=[rank], + output_device=rank, + find_unused_parameters=True, + ).module + tiny_tts = torch.nn.parallel.DistributedDataParallel( + tiny_tts, + device_ids=[rank], + output_device=rank, + find_unused_parameters=True, + ).module + torch.distributed.barrier() + train_sampler = torch.utils.data.RandomSampler(train_dataset) + batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, batch_size, drop_last=True) + + train_loader = DataLoader(dataset=train_dataset, + num_workers=0, # unfortunately necessary for big data due to mmap errors + batch_sampler=batch_sampler_train, + prefetch_factor=None, + collate_fn=collate_and_pad) + + step_counter = 0 + loss_sum = list() + + if resume: + previous_checkpoint = os.path.join(save_directory, "aligner.pt") + path_to_checkpoint = previous_checkpoint + fine_tune = False + + if path_to_checkpoint is not None: + check_dict = torch.load(os.path.join(path_to_checkpoint), map_location=device) + asr_model.load_state_dict(check_dict["asr_model"]) + tiny_tts.load_state_dict(check_dict["tts_model"]) + if not fine_tune: + optim_asr.load_state_dict(check_dict["optimizer"]) + optim_tts.load_state_dict(check_dict["tts_optimizer"]) + step_counter = check_dict["step_counter"] + if step_counter > steps: + print("Desired steps already reached in loaded checkpoint.") + return + start_time = time.time() + + while True: + asr_model.train() + tiny_tts.train() + for batch in tqdm(train_loader): + tokens = batch[0].to(device) + tokens_len = batch[1].to(device) + speaker_embeddings = batch[4].to(device) + + mels = list() + mel_lengths = list() + for datapoint in batch[2]: + with torch.inference_mode(): + # extremely unfortunate that we have to do this over here, but multiprocessing and this don't go together well + speech = ap.indexes_to_audio(datapoint.int().to(device)) + mel = spectrogram_extractor.audio_to_mel_spec_tensor(speech, explicit_sampling_rate=16000).transpose(0, 1).cpu() + speech_len = torch.LongTensor([len(mel)]) + mels.append(mel.clone()) + mel_lengths.append(speech_len) + mel = pad_sequence(mels, batch_first=True).to(device) + mel_len = torch.stack(mel_lengths).squeeze(1).to(device) + + pred = asr_model(mel, mel_len) + + ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2), + tokens, + mel_len, + tokens_len) + + if use_reconstruction: + speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1) + tts_lambda = min([0.1, step_counter / 10000]) # super simple schedule + reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1), + # combine ASR prediction with speaker embeddings to allow for reconstruction loss on multiple speakers + lens=mel_len, + ys=mel) * tts_lambda # reconstruction loss to make the states more distinct + loss = ctc_loss + reconstruction_loss + else: + loss = ctc_loss + + optim_asr.zero_grad() + if use_reconstruction: + optim_tts.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0) + if use_reconstruction: + torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0) + optim_asr.step() + if use_reconstruction: + optim_tts.step() + + loss_sum.append(loss.item()) + step_counter += 1 + + if step_counter % steps_per_checkpoint == 0 and rank == 0: + asr_model.eval() + torch.save({ + "asr_model" : asr_model.state_dict(), + "optimizer" : optim_asr.state_dict(), + "tts_model" : tiny_tts.state_dict(), + "tts_optimizer": optim_tts.state_dict(), + "step_counter" : step_counter, + }, + os.path.join(save_directory, "aligner.pt")) + print("Total Loss: {}".format(round(sum(loss_sum) / len(loss_sum), 3))) + print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60))) + print("Steps: {}".format(step_counter)) + if debug_img_path is not None: + asr_model.inference(features=mel[0][:mel_len[0]], + tokens=tokens[0][:tokens_len[0]], + save_img_for_debug=debug_img_path + f"/{step_counter}.png", + train=True) # for testing + asr_model.train() + loss_sum = list() + + if step_counter > steps and step_counter % steps_per_checkpoint == 0: + return diff --git a/Architectures/ControllabilityGAN/GAN.py b/Architectures/ControllabilityGAN/GAN.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd8fc8e6048f2096c10db6469324b55814e2d29 --- /dev/null +++ b/Architectures/ControllabilityGAN/GAN.py @@ -0,0 +1,82 @@ +import torch + +from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan + + +class GanWrapper: + + def __init__(self, path_wgan, device): + self.device = device + self.path_wgan = path_wgan + + self.mean = None + self.std = None + self.wgan = None + self.normalize = True + + self.load_model(path_wgan) + + self.U = self.compute_controllability() + + self.z_list = list() + for _ in range(1100): + self.z_list.append(self.wgan.G.module.sample_latent(1, 32)) + self.z = self.z_list[0] + + def set_latent(self, seed): + self.z = self.z = self.z_list[seed] + + def reset_default_latent(self): + self.z = self.wgan.G.module.sample_latent(1, 32) + + def load_model(self, path): + gan_checkpoint = torch.load(path, map_location="cpu") + + self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device) + self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict']) + self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict']) + + self.mean = gan_checkpoint["dataset_mean"] + self.std = gan_checkpoint["dataset_std"] + + def compute_controllability(self, n_samples=50000): + _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True) + intermediate = intermediate.cpu() + z = z.cpu() + U = self.controllable_speakers(intermediate, z) + return U + + def controllable_speakers(self, intermediate, z): + pca = torch.pca_lowrank(intermediate) + mu = intermediate.mean() + X = torch.matmul((intermediate - mu), pca[2]) + U = torch.linalg.lstsq(X, z) + return U + + def get_original_embed(self): + self.wgan.G.eval() + embed_original = self.wgan.G.module.forward(self.z.to(self.device)) + + if self.normalize: + embed_original = inverse_normalize( + embed_original.cpu(), + self.mean.cpu().unsqueeze(0), + self.std.cpu().unsqueeze(0) + ) + return embed_original + + def modify_embed(self, x): + self.wgan.G.eval() + z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x) + embed_modified = self.wgan.G.module.forward(z_new.unsqueeze(0).to(self.device)) + if self.normalize: + embed_modified = inverse_normalize( + embed_modified.cpu(), + self.mean.cpu().unsqueeze(0), + self.std.cpu().unsqueeze(0) + ) + return embed_modified + + +def inverse_normalize(tensor, mean, std): + return tensor * std + mean diff --git a/Architectures/ControllabilityGAN/__init__.py b/Architectures/ControllabilityGAN/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Architectures/ControllabilityGAN/dataset/__init__.py b/Architectures/ControllabilityGAN/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py b/Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d6c3a8f0ed4c9c4ced3641ccffc78793a3b8bf50 --- /dev/null +++ b/Architectures/ControllabilityGAN/dataset/speaker_embeddings_dataset.py @@ -0,0 +1,94 @@ +import os + +import numpy as np +import torch + + +class SpeakerEmbeddingsDataset(torch.utils.data.Dataset): + + def __init__(self, feature_path, device, mode='utterance'): + super(SpeakerEmbeddingsDataset, self).__init__() + + modes = ['utterance', 'speaker'] + assert mode in modes, f'mode: {mode} is not supported' + if mode == 'utterance': + self.mode = 'utt' + elif mode == 'speaker': + self.mode = 'spk' + + self.device = device + + self.x, self.speakers = self._load_features(feature_path) + # unique_speakers = set(self.speakers) + # spk2class = dict(zip(unique_speakers, range(len(unique_speakers)))) + # #self.x = self._reformat_features(self.x) + # self.y = torch.tensor([spk2class[spk] for spk in self.speakers]).to(self.device) + # self.class2spk = dict(zip(spk2class.values(), spk2class.keys())) + + def __len__(self): + return len(self.speakers) + + def __getitem__(self, index): + embedding = self.normalize_embedding(self.x[index]) + # speaker_id = self.y[index] + return embedding, torch.zeros([0]) + + def normalize_embedding(self, vector): + return torch.sub(vector, self.mean) / self.std + + def get_speaker(self, label): + return self.class2spk[label] + + def get_embedding_dim(self): + return self.x.shape[-1] + + def get_num_speaker(self): + return len(torch.unique((self.y))) + + def set_labels(self, labels): + self.y_old = self.y + self.y = torch.full(size=(len(self),), fill_value=labels).to(self.device) + # if isinstance(labels, int) or isinstance(labels, float): + # self.y = torch.full(size=len(self), fill_value=labels) + # elif len(labels) == len(self): + # self.y = torch.tensor(labels) + + def _load_features(self, feature_path): + if os.path.isfile(feature_path): + vectors = torch.load(feature_path, map_location=self.device) + if isinstance(vectors, list): + vectors = torch.stack(vectors) + + self.mean = torch.mean(vectors) + self.std = torch.std(vectors) + return vectors, torch.zeros(vectors.size(0)) + else: + vectors = torch.load(feature_path, map_location=self.device) + + self.mean = torch.mean(vectors) + self.std = torch.std(vectors) + + spk2idx = {} + with open(feature_path / f'{self.mode}2idx', 'r') as f: + for line in f: + split_line = line.strip().split() + if len(split_line) == 2: + spk2idx[split_line[0].strip()] = int(split_line[1]) + + speakers, indices = zip(*spk2idx.items()) + + if (feature_path / 'utt2spk').exists(): # spk2idx contains utt_ids not speaker_ids + utt2spk = {} + with open(feature_path / 'utt2spk', 'r') as f: + for line in f: + split_line = line.strip().split() + if len(split_line) == 2: + utt2spk[split_line[0].strip()] = split_line[1].strip() + + speakers = [utt2spk[utt] for utt in speakers] + + return vectors[np.array(indices)], speakers + + def _reformat_features(self, features): + if len(features.shape) == 2: + return features.reshape(features.shape[0], 1, 1, features.shape[1]) diff --git a/Architectures/ControllabilityGAN/wgan/__init__.py b/Architectures/ControllabilityGAN/wgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Architectures/ControllabilityGAN/wgan/init_weights.py b/Architectures/ControllabilityGAN/wgan/init_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..3aabee2cc8d837c7b71dcd054e7e44978e192224 --- /dev/null +++ b/Architectures/ControllabilityGAN/wgan/init_weights.py @@ -0,0 +1,21 @@ +import torch.nn as nn + + +def weights_init_D(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # nn.init.constant_(m.bias, 0) + elif classname.find('BatchNorm') != -1: + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def weights_init_G(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu') + # nn.init.constant_(m.bias, 0) + elif classname.find('BatchNorm') != -1: + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) diff --git a/Architectures/ControllabilityGAN/wgan/init_wgan.py b/Architectures/ControllabilityGAN/wgan/init_wgan.py new file mode 100644 index 0000000000000000000000000000000000000000..da345b6af6bf237daba480efad1d57247eb26dbf --- /dev/null +++ b/Architectures/ControllabilityGAN/wgan/init_wgan.py @@ -0,0 +1,34 @@ +import torch + +from Architectures.ControllabilityGAN.wgan.resnet_init import init_resnet +from Architectures.ControllabilityGAN.wgan.wgan_qc import WassersteinGanQuadraticCost + + +def create_wgan(parameters, device, optimizer='adam'): + if parameters['model'] == "resnet": + generator, discriminator = init_resnet(parameters) + else: + raise NotImplementedError + + if optimizer == 'adam': + optimizer_g = torch.optim.Adam(generator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas']) + optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=parameters['learning_rate'], betas=parameters['betas']) + elif optimizer == 'rmsprop': + optimizer_g = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate']) + optimizer_d = torch.optim.RMSprop(generator.parameters(), lr=parameters['learning_rate']) + + criterion = torch.nn.MSELoss() + + gan = WassersteinGanQuadraticCost(generator, + discriminator, + optimizer_g, + optimizer_d, + criterion=criterion, + data_dimensions=parameters['data_dim'], + epochs=parameters['epochs'], + batch_size=parameters['batch_size'], + device=device, + n_max_iterations=parameters['n_max_iterations'], + gamma=parameters['gamma']) + + return gan diff --git a/Architectures/ControllabilityGAN/wgan/resnet_1.py b/Architectures/ControllabilityGAN/wgan/resnet_1.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5aef2f79aca66a4b61675718bde809b1ef9b48 --- /dev/null +++ b/Architectures/ControllabilityGAN/wgan/resnet_1.py @@ -0,0 +1,181 @@ +import numpy as np +import torch +import torch.utils.data +import torch.utils.data.distributed +from torch import nn + + +class ResNet_G(nn.Module): + + def __init__(self, data_dim, z_dim, size, nfilter=64, nfilter_max=512, bn=True, res_ratio=0.1, **kwargs): + super().__init__() + self.input_dim = z_dim + self.output_dim = z_dim + self.dropout_rate = 0 + + s0 = self.s0 = 4 + nf = self.nf = nfilter + nf_max = self.nf_max = nfilter_max + self.bn = bn + self.z_dim = z_dim + + # Submodules + nlayers = int(np.log2(size / s0)) + self.nf0 = min(nf_max, nf * 2 ** (nlayers + 1)) + + self.fc = nn.Linear(z_dim, self.nf0 * s0 * s0) + if self.bn: + self.bn1d = nn.BatchNorm1d(self.nf0 * s0 * s0) + self.relu = nn.LeakyReLU(0.2, inplace=True) + + blocks = [] + for i in range(nlayers, 0, -1): + nf0 = min(nf * 2 ** (i + 1), nf_max) + nf1 = min(nf * 2 ** i, nf_max) + blocks += [ + ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio), + nn.Upsample(scale_factor=2) + ] + + nf0 = min(nf * 2, nf_max) + nf1 = min(nf, nf_max) + blocks += [ + ResNetBlock(nf0, nf1, bn=self.bn, res_ratio=res_ratio), + ResNetBlock(nf1, nf1, bn=self.bn, res_ratio=res_ratio) + ] + + self.resnet = nn.Sequential(*blocks) + self.conv_img = nn.Conv2d(nf, 3, 3, padding=1) + + self.fc_out = nn.Linear(3 * size * size, data_dim) + + def forward(self, z, return_intermediate=False): + # print(z.shape) + batch_size = z.size(0) + # z = z.view(batch_size, -1) + out = self.fc(z) + if self.bn: + out = self.bn1d(out) + out = self.relu(out) + if return_intermediate: + l_1 = out.detach().clone() + out = out.view(batch_size, self.nf0, self.s0, self.s0) + # print(out.shape) + + out = self.resnet(out) + + # print(out.shape) + # out = out.view(batch_size, self.nf0*self.s0*self.s0*2) + + out = self.conv_img(out) + out = self.relu(out) + out.flatten(1) + out = self.fc_out(out.flatten(1)) + + if return_intermediate: + return out, l_1 + return out + + def sample_latent(self, n_samples, z_size): + return torch.randn((n_samples, z_size)) + + +class ResNet_D(nn.Module): + + def __init__(self, data_dim, size, nfilter=64, nfilter_max=512, res_ratio=0.1): + super().__init__() + s0 = self.s0 = 4 + nf = self.nf = nfilter + nf_max = self.nf_max = nfilter_max + self.size = size + + # Submodules + nlayers = int(np.log2(size / s0)) + self.nf0 = min(nf_max, nf * 2 ** nlayers) + + nf0 = min(nf, nf_max) + nf1 = min(nf * 2, nf_max) + blocks = [ + ResNetBlock(nf0, nf0, bn=False, res_ratio=res_ratio), + ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio) + ] + + self.fc_input = nn.Linear(data_dim, 3 * size * size) + + for i in range(1, nlayers + 1): + nf0 = min(nf * 2 ** i, nf_max) + nf1 = min(nf * 2 ** (i + 1), nf_max) + blocks += [ + nn.AvgPool2d(3, stride=2, padding=1), + ResNetBlock(nf0, nf1, bn=False, res_ratio=res_ratio), + ] + + self.conv_img = nn.Conv2d(3, 1 * nf, 3, padding=1) + self.relu = nn.LeakyReLU(0.2, inplace=True) + self.resnet = nn.Sequential(*blocks) + + self.fc = nn.Linear(self.nf0 * s0 * s0, 1) + + def forward(self, x): + batch_size = x.size(0) + + out = self.fc_input(x) + out = self.relu(out).view(batch_size, 3, self.size, self.size) + + out = self.relu((self.conv_img(out))) + out = self.resnet(out) + out = out.view(batch_size, self.nf0 * self.s0 * self.s0) + out = self.fc(out) + + return out + + +class ResNetBlock(nn.Module): + + def __init__(self, fin, fout, fhidden=None, bn=True, res_ratio=0.1): + super().__init__() + # Attributes + self.bn = bn + self.is_bias = not bn + self.learned_shortcut = (fin != fout) + self.fin = fin + self.fout = fout + if fhidden is None: + self.fhidden = min(fin, fout) + else: + self.fhidden = fhidden + self.res_ratio = res_ratio + + # Submodules + self.conv_0 = nn.Conv2d(self.fin, self.fhidden, 3, stride=1, padding=1, bias=self.is_bias) + if self.bn: + self.bn2d_0 = nn.BatchNorm2d(self.fhidden) + self.conv_1 = nn.Conv2d(self.fhidden, self.fout, 3, stride=1, padding=1, bias=self.is_bias) + if self.bn: + self.bn2d_1 = nn.BatchNorm2d(self.fout) + if self.learned_shortcut: + self.conv_s = nn.Conv2d(self.fin, self.fout, 1, stride=1, padding=0, bias=False) + if self.bn: + self.bn2d_s = nn.BatchNorm2d(self.fout) + self.relu = nn.LeakyReLU(0.2, inplace=True) + + def forward(self, x): + x_s = self._shortcut(x) + dx = self.conv_0(x) + if self.bn: + dx = self.bn2d_0(dx) + dx = self.relu(dx) + dx = self.conv_1(dx) + if self.bn: + dx = self.bn2d_1(dx) + out = self.relu(x_s + self.res_ratio * dx) + return out + + def _shortcut(self, x): + if self.learned_shortcut: + x_s = self.conv_s(x) + if self.bn: + x_s = self.bn2d_s(x_s) + else: + x_s = x + return x_s diff --git a/Architectures/ControllabilityGAN/wgan/resnet_init.py b/Architectures/ControllabilityGAN/wgan/resnet_init.py new file mode 100644 index 0000000000000000000000000000000000000000..0ff7e8f228f15d1de629a8c2355cf73f1908681b --- /dev/null +++ b/Architectures/ControllabilityGAN/wgan/resnet_init.py @@ -0,0 +1,15 @@ +from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_D +from Architectures.ControllabilityGAN.wgan.init_weights import weights_init_G +from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_D +from Architectures.ControllabilityGAN.wgan.resnet_1 import ResNet_G + + +def init_resnet(parameters): + critic = ResNet_D(parameters['data_dim'][-1], parameters['size'], nfilter=parameters['nfilter'], nfilter_max=parameters['nfilter_max']) + generator = ResNet_G(parameters['data_dim'][-1], parameters['z_dim'], parameters['size'], nfilter=parameters['nfilter'], + nfilter_max=parameters['nfilter_max']) + + generator.apply(weights_init_G) + critic.apply(weights_init_D) + + return generator, critic diff --git a/Architectures/ControllabilityGAN/wgan/wgan_qc.py b/Architectures/ControllabilityGAN/wgan/wgan_qc.py new file mode 100644 index 0000000000000000000000000000000000000000..94ff3b6c433458de61f858cfb61a07fec3f8c67c --- /dev/null +++ b/Architectures/ControllabilityGAN/wgan/wgan_qc.py @@ -0,0 +1,272 @@ +import os +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from cvxopt import matrix +from cvxopt import solvers +from cvxopt import sparse +from cvxopt import spmatrix +from torch.autograd import grad as torch_grad +from tqdm import tqdm + + +class WassersteinGanQuadraticCost: + + def __init__(self, generator, discriminator, gen_optimizer, dis_optimizer, criterion, epochs, n_max_iterations, + data_dimensions, batch_size, device, gamma=0.1, K=-1, milestones=[150000, 250000], lr_anneal=1.0): + self.G = generator + self.G_opt = gen_optimizer + self.D = discriminator + self.D_opt = dis_optimizer + self.losses = { + 'D' : [], + 'WD': [], + 'G' : [] + } + self.num_steps = 0 + self.gen_steps = 0 + self.epochs = epochs + self.n_max_iterations = n_max_iterations + # put in the shape of a dataset sample + self.data_dim = data_dimensions[0] * data_dimensions[1] * data_dimensions[2] + self.batch_size = batch_size + self.device = device + self.criterion = criterion + self.mone = torch.FloatTensor([-1]).to(device) + self.tensorboard_counter = 0 + + if K <= 0: + self.K = 1 / self.data_dim + else: + self.K = K + self.Kr = np.sqrt(self.K) + self.LAMBDA = 2 * self.Kr * gamma * 2 + + self.G = nn.DataParallel(self.G.to(self.device)) + self.D = nn.DataParallel(self.D.to(self.device)) + + self.schedulerD = self._build_lr_scheduler_(self.D_opt, milestones, lr_anneal) + self.schedulerG = self._build_lr_scheduler_(self.G_opt, milestones, lr_anneal) + + self.c, self.A, self.pStart = self._prepare_linear_programming_solver_(self.batch_size) + + def _build_lr_scheduler_(self, optimizer, milestones, lr_anneal, last_epoch=-1): + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=lr_anneal, last_epoch=-1) + return scheduler + + def _quadratic_wasserstein_distance_(self, real, generated): + num_r = real.size(0) + num_f = generated.size(0) + real_flat = real.view(num_r, -1) + fake_flat = generated.view(num_f, -1) + + real3D = real_flat.unsqueeze(1).expand(num_r, num_f, self.data_dim) + fake3D = fake_flat.unsqueeze(0).expand(num_r, num_f, self.data_dim) + # compute squared L2 distance + dif = real3D - fake3D + dist = 0.5 * dif.pow(2).sum(2).squeeze() + + return self.K * dist + + def _prepare_linear_programming_solver_(self, batch_size): + A = spmatrix(1.0, range(batch_size), [0] * batch_size, (batch_size, batch_size)) + for i in range(1, batch_size): + Ai = spmatrix(1.0, range(batch_size), [i] * batch_size, (batch_size, batch_size)) + A = sparse([A, Ai]) + + D = spmatrix(-1.0, range(batch_size), range(batch_size), (batch_size, batch_size)) + DM = D + for i in range(1, batch_size): + DM = sparse([DM, D]) + + A = sparse([[A], [DM]]) + + cr = matrix([-1.0 / batch_size] * batch_size) + cf = matrix([1.0 / batch_size] * batch_size) + c = matrix([cr, cf]) + + pStart = {} + pStart['x'] = matrix([matrix([1.0] * batch_size), matrix([-1.0] * batch_size)]) + pStart['s'] = matrix([1.0] * (2 * batch_size)) + + return c, A, pStart + + def _linear_programming_(self, distance, batch_size): + b = matrix(distance.cpu().double().detach().numpy().flatten()) + sol = solvers.lp(self.c, self.A, b, primalstart=self.pStart, solver='glpk', + options={'glpk': {'msg_lev': 'GLP_MSG_OFF'}}) + offset = 0.5 * (sum(sol['x'])) / batch_size + sol['x'] = sol['x'] - offset + self.pStart['x'] = sol['x'] + self.pStart['s'] = sol['s'] + + return sol + + def _approx_OT_(self, sol): + # Compute the OT mapping for each fake dataset + ResMat = np.array(sol['z']).reshape((self.batch_size, self.batch_size)) + mapping = torch.from_numpy(np.argmax(ResMat, axis=0)).long().to(self.device) + + return mapping + + def _optimal_transport_regularization_(self, output_fake, fake, real_fake_diff): + output_fake_grad = torch.ones(output_fake.size()).to(self.device) + gradients = torch_grad(outputs=output_fake, inputs=fake, + grad_outputs=output_fake_grad, + create_graph=True, retain_graph=True, only_inputs=True)[0] + n = gradients.size(0) + RegLoss = 0.5 * ((gradients.view(n, -1).norm(dim=1) / (2 * self.Kr) - self.Kr / 2 * real_fake_diff.view(n, + -1).norm( + dim=1)).pow(2)).mean() + fake.requires_grad = False + + return RegLoss + + def _critic_deep_regression_(self, images, opt_iterations=1): + images = images.to(self.device) + + for p in self.D.parameters(): # reset requires_grad + p.requires_grad = True # they are set to False below in netG update + + self.G.train() + self.D.train() + + # Get generated fake dataset + generated_data = self.sample_generator(self.batch_size) + + # compute wasserstein distance + distance = self._quadratic_wasserstein_distance_(images, generated_data) + # solve linear programming problem + sol = self._linear_programming_(distance, self.batch_size) + # approximate optimal transport + mapping = self._approx_OT_(sol) + real_ordered = images[mapping] # match real and fake + real_fake_diff = real_ordered - generated_data + + # construct target + target = torch.from_numpy(np.array(sol['x'])).float() + target = target.squeeze().to(self.device) + + for i in range(opt_iterations): + self.D.zero_grad() # ??? + self.D_opt.zero_grad() + generated_data.requires_grad_() + if generated_data.grad is not None: + generated_data.grad.data.zero_() + output_real = self.D(images) + output_fake = self.D(generated_data) + output_real, output_fake = output_real.squeeze(), output_fake.squeeze() + output_R_mean = output_real.mean(0).view(1) + output_F_mean = output_fake.mean(0).view(1) + + L2LossD_real = self.criterion(output_R_mean[0], target[:self.batch_size].mean()) + L2LossD_fake = self.criterion(output_fake, target[self.batch_size:]) + L2LossD = 0.5 * L2LossD_real + 0.5 * L2LossD_fake + + reg_loss_D = self._optimal_transport_regularization_(output_fake, generated_data, real_fake_diff) + + total_loss = L2LossD + self.LAMBDA * reg_loss_D + + self.losses['D'].append(float(total_loss.data)) + + total_loss.backward() + self.D_opt.step() + + # this is supposed to be the wasserstein distance + wasserstein_distance = output_R_mean - output_F_mean + self.losses['WD'].append(float(wasserstein_distance.data)) + + def _generator_train_iteration(self, batch_size): + for p in self.D.parameters(): + p.requires_grad = False # freeze critic + + self.G.zero_grad() + self.G_opt.zero_grad() + + if isinstance(self.G, torch.nn.parallel.DataParallel): + z = self.G.module.sample_latent(batch_size, self.G.module.z_dim) + else: + z = self.G.sample_latent(batch_size, self.G.z_dim) + z.requires_grad = True + + fake = self.G(z) + output_fake = self.D(fake) + output_F_mean_after = output_fake.mean(0).view(1) + + self.losses['G'].append(float(output_F_mean_after.data)) + + output_F_mean_after.backward(self.mone) + self.G_opt.step() + + self.schedulerD.step() + self.schedulerG.step() + + def _train_epoch(self, data_loader, writer, experiment): + for i, data in enumerate(tqdm(data_loader)): + images = data[0] + speaker_ids = data[1] + self.num_steps += 1 + # self.tensorboard_counter += 1 + if self.gen_steps >= self.n_max_iterations: + return + self._critic_deep_regression_(images) + self._generator_train_iteration(images.size(0)) + + D_loss_avg = np.average(self.losses['D']) + G_loss_avg = np.average(self.losses['G']) + wd_avg = np.average(self.losses['WD']) + + def train(self, data_loader, writer, experiment=None): + self.G.train() + self.D.train() + + for epoch in range(self.epochs): + if self.gen_steps >= self.n_max_iterations: + return + time_start_epoch = time.time() + self._train_epoch(data_loader, writer, experiment) + + D_loss_avg = np.average(self.losses['D']) + + time_end_epoch = time.time() + + return self + + def sample_generator(self, num_samples, nograd=False, return_intermediate=False): + self.G.eval() + if isinstance(self.G, torch.nn.parallel.DataParallel): + latent_samples = self.G.module.sample_latent(num_samples, self.G.module.z_dim) + else: + latent_samples = self.G.sample_latent(num_samples, self.G.z_dim) + latent_samples = latent_samples.to(self.device) + if nograd: + with torch.no_grad(): + generated_data = self.G(latent_samples, return_intermediate=return_intermediate) + else: + generated_data = self.G(latent_samples) + self.G.train() + if return_intermediate: + return generated_data[0].detach(), generated_data[1], latent_samples + return generated_data.detach() + + def sample(self, num_samples): + generated_data = self.sample_generator(num_samples) + # Remove color channel + return generated_data.data.cpu().numpy()[:, 0, :, :] + + def save_model_checkpoint(self, model_path, model_parameters, timestampStr): + # dateTimeObj = datetime.now() + # timestampStr = dateTimeObj.strftime("%d-%m-%Y-%H-%M-%S") + name = '%s_%s' % (timestampStr, 'wgan') + model_filename = os.path.join(model_path, name) + torch.save({ + 'generator_state_dict' : self.G.state_dict(), + 'critic_state_dict' : self.D.state_dict(), + 'gen_optimizer_state_dict' : self.G_opt.state_dict(), + 'critic_optimizer_state_dict': self.D_opt.state_dict(), + 'model_parameters' : model_parameters, + 'iterations' : self.num_steps + }, model_filename) diff --git a/Architectures/EmbeddingModel/GST.py b/Architectures/EmbeddingModel/GST.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f2435bd91aaf93a3de3ab6c779409972f0b907 --- /dev/null +++ b/Architectures/EmbeddingModel/GST.py @@ -0,0 +1,235 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch + +from Architectures.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention + + +class GSTStyleEncoder(torch.nn.Module): + """Style encoder. + This module is style encoder introduced in `Style Tokens: Unsupervised Style + Modeling, Control and Transfer in End-to-End Speech Synthesis`. + .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End + Speech Synthesis`: https://arxiv.org/abs/1803.09017 + Args: + idim (int, optional): Dimension of the input features. + gst_tokens (int, optional): The number of GST embeddings. + gst_token_dim (int, optional): Dimension of each GST embedding. + gst_heads (int, optional): The number of heads in GST multihead attention. + conv_layers (int, optional): The number of conv layers in the reference encoder. + conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the reference encoder. + conv_kernel_size (int, optional): + Kernel size of conv layers in the reference encoder. + conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + gst_layers (int, optional): The number of GRU layers in the reference encoder. + gst_units (int, optional): The number of GRU units in the reference encoder. + """ + + def __init__( + self, + idim: int = 128, + gst_tokens: int = 512, # adaspeech suggests to use many more "basis vectors", but I believe that this is already sufficient + gst_token_dim: int = 64, + gst_heads: int = 8, + conv_layers: int = 8, + conv_chans_list=(32, 32, 64, 64, 128, 128, 256, 256), + conv_kernel_size: int = 3, + conv_stride: int = 2, + gst_layers: int = 2, + gst_units: int = 256, + ): + """Initialize global style encoder module.""" + super(GSTStyleEncoder, self).__init__() + + self.num_tokens = gst_tokens + self.ref_enc = ReferenceEncoder(idim=idim, + conv_layers=conv_layers, + conv_chans_list=conv_chans_list, + conv_kernel_size=conv_kernel_size, + conv_stride=conv_stride, + gst_layers=gst_layers, + gst_units=gst_units, ) + self.stl = StyleTokenLayer(ref_embed_dim=gst_units, + gst_tokens=gst_tokens, + gst_token_dim=gst_token_dim, + gst_heads=gst_heads, ) + + def forward(self, speech): + """Calculate forward propagation. + Args: + speech (Tensor): Batch of padded target features (B, Lmax, odim). + Returns: + Tensor: Style token embeddings (B, token_dim). + """ + ref_embs = self.ref_enc(speech) + style_embs = self.stl(ref_embs) + + return style_embs + + def calculate_ada4_regularization_loss(self): + losses = list() + for emb1_index in range(self.num_tokens): + for emb2_index in range(emb1_index + 1, self.num_tokens): + if emb1_index != emb2_index: + losses.append(torch.nn.functional.cosine_similarity(self.stl.gst_embs[emb1_index], + self.stl.gst_embs[emb2_index], dim=0)) + return sum(losses) + + +class ReferenceEncoder(torch.nn.Module): + """Reference encoder module. + This module is reference encoder introduced in `Style Tokens: Unsupervised Style + Modeling, Control and Transfer in End-to-End Speech Synthesis`. + .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End + Speech Synthesis`: https://arxiv.org/abs/1803.09017 + Args: + idim (int, optional): Dimension of the input features. + conv_layers (int, optional): The number of conv layers in the reference encoder. + conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the reference encoder. + conv_kernel_size (int, optional): + Kernel size of conv layers in the reference encoder. + conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + gst_layers (int, optional): The number of GRU layers in the reference encoder. + gst_units (int, optional): The number of GRU units in the reference encoder. + """ + + def __init__( + self, + idim=80, + conv_layers: int = 6, + conv_chans_list=(32, 32, 64, 64, 128, 128), + conv_kernel_size: int = 3, + conv_stride: int = 2, + gst_layers: int = 1, + gst_units: int = 128, + ): + """Initialize reference encoder module.""" + super(ReferenceEncoder, self).__init__() + + # check hyperparameters are valid + assert conv_kernel_size % 2 == 1, "kernel size must be odd." + assert ( + len(conv_chans_list) == conv_layers), "the number of conv layers and length of channels list must be the same." + + convs = [] + padding = (conv_kernel_size - 1) // 2 + for i in range(conv_layers): + conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1] + conv_out_chans = conv_chans_list[i] + convs += [torch.nn.Conv2d(conv_in_chans, + conv_out_chans, + kernel_size=conv_kernel_size, + stride=conv_stride, + padding=padding, + # Do not use bias due to the following batch norm + bias=False, ), + torch.nn.BatchNorm2d(conv_out_chans), + torch.nn.ReLU(inplace=True), ] + self.convs = torch.nn.Sequential(*convs) + + self.conv_layers = conv_layers + self.kernel_size = conv_kernel_size + self.stride = conv_stride + self.padding = padding + + # get the number of GRU input units + gst_in_units = idim + for i in range(conv_layers): + gst_in_units = (gst_in_units - conv_kernel_size + 2 * padding) // conv_stride + 1 + gst_in_units *= conv_out_chans + self.gst = torch.nn.GRU(gst_in_units, gst_units, gst_layers, batch_first=True) + + def forward(self, speech): + """Calculate forward propagation. + Args: + speech (Tensor): Batch of padded target features (B, Lmax, idim). + Returns: + Tensor: Reference embedding (B, gst_units) + """ + batch_size = speech.size(0) + xs = speech.unsqueeze(1) # (B, 1, Lmax, idim) + hs = self.convs(xs).transpose(1, 2) # (B, Lmax', conv_out_chans, idim') + time_length = hs.size(1) + hs = hs.contiguous().view(batch_size, time_length, -1) # (B, Lmax', gst_units) + self.gst.flatten_parameters() + # pack_padded_sequence(hs, speech_lens, enforce_sorted=False, batch_first=True) + _, ref_embs = self.gst(hs) # (gst_layers, batch_size, gst_units) + ref_embs = ref_embs[-1] # (batch_size, gst_units) + + return ref_embs + + +class StyleTokenLayer(torch.nn.Module): + """Style token layer module. + This module is style token layer introduced in `Style Tokens: Unsupervised Style + Modeling, Control and Transfer in End-to-End Speech Synthesis`. + .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End + Speech Synthesis`: https://arxiv.org/abs/1803.09017 + Args: + ref_embed_dim (int, optional): Dimension of the input reference embedding. + gst_tokens (int, optional): The number of GST embeddings. + gst_token_dim (int, optional): Dimension of each GST embedding. + gst_heads (int, optional): The number of heads in GST multihead attention. + dropout_rate (float, optional): Dropout rate in multi-head attention. + """ + + def __init__( + self, + ref_embed_dim: int = 128, + gst_tokens: int = 10, + gst_token_dim: int = 128, + gst_heads: int = 4, + dropout_rate: float = 0.0, + ): + """Initialize style token layer module.""" + super(StyleTokenLayer, self).__init__() + + gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads) + self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs)) + self.mha = MultiHeadedAttention(q_dim=ref_embed_dim, + k_dim=gst_token_dim // gst_heads, + v_dim=gst_token_dim // gst_heads, + n_head=gst_heads, + n_feat=gst_token_dim, + dropout_rate=dropout_rate, ) + + def forward(self, ref_embs): + """Calculate forward propagation. + Args: + ref_embs (Tensor): Reference embeddings (B, ref_embed_dim). + Returns: + Tensor: Style token embeddings (B, gst_token_dim). + """ + batch_size = ref_embs.size(0) + # (num_tokens, token_dim) -> (batch_size, num_tokens, token_dim) + gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1) + # NOTE(kan-bayashi): Shoule we apply Tanh? + ref_embs = ref_embs.unsqueeze(1) # (batch_size, 1 ,ref_embed_dim) + style_embs = self.mha(ref_embs, gst_embs, gst_embs, None) + + return style_embs.squeeze(1) + + +class MultiHeadedAttention(BaseMultiHeadedAttention): + """Multi head attention module with different input dimension.""" + + def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0): + """Initialize multi head attention module.""" + # NOTE(kan-bayashi): Do not use super().__init__() here since we want to + # overwrite BaseMultiHeadedAttention.__init__() method. + torch.nn.Module.__init__(self) + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = torch.nn.Linear(q_dim, n_feat) + self.linear_k = torch.nn.Linear(k_dim, n_feat) + self.linear_v = torch.nn.Linear(v_dim, n_feat) + self.linear_out = torch.nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = torch.nn.Dropout(p=dropout_rate) diff --git a/Architectures/EmbeddingModel/README.md b/Architectures/EmbeddingModel/README.md new file mode 100644 index 0000000000000000000000000000000000000000..aff7920a3bf3a42646086174d3c9f4d3988b5906 --- /dev/null +++ b/Architectures/EmbeddingModel/README.md @@ -0,0 +1 @@ +Everything that is concerned with the embedding model is contained in this directory. The embedding function does not have its own train loop, because it is always trained jointly with the TTS. Most of the time however, it is used in a frozen state. We recommend using the embedding function that we publish in the GitHub releases. \ No newline at end of file diff --git a/Architectures/EmbeddingModel/StyleEmbedding.py b/Architectures/EmbeddingModel/StyleEmbedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d7154010e3b9c4945dc76a6cefdcb0fc8541ef06 --- /dev/null +++ b/Architectures/EmbeddingModel/StyleEmbedding.py @@ -0,0 +1,73 @@ +import torch + +from Architectures.EmbeddingModel.GST import GSTStyleEncoder +from Architectures.EmbeddingModel.StyleTTSEncoder import StyleEncoder as StyleTTSEncoder + + +class StyleEmbedding(torch.nn.Module): + """ + The style embedding should provide information of the speaker and their speaking style + + The feedback signal for the module will come from the TTS objective, so it doesn't have a dedicated train loop. + The train loop does however supply supervision in the form of a barlow twins objective. + + See the git history for some other approaches for style embedding, like the SWIN transformer + and a simple LSTM baseline. GST turned out to be the best. + """ + + def __init__(self, embedding_dim=16, style_tts_encoder=False): + super().__init__() + self.embedding_dim = embedding_dim + self.use_gst = not style_tts_encoder + if style_tts_encoder: + self.style_encoder = StyleTTSEncoder(style_dim=embedding_dim) + else: + self.style_encoder = GSTStyleEncoder(gst_token_dim=embedding_dim) + + def forward(self, + batch_of_feature_sequences, + batch_of_feature_sequence_lengths): + """ + Args: + batch_of_feature_sequences: b is the batch axis, 128 features per timestep + and l time-steps, which may include padding + for most elements in the batch (b, l, 128) + batch_of_feature_sequence_lengths: indicate for every element in the batch, + what the true length is, since they are + all padded to the length of the longest + element in the batch (b, 1) + Returns: + batch of n dimensional embeddings (b,n) + """ + + minimum_sequence_length = 512 + specs = list() + for index, spec_length in enumerate(batch_of_feature_sequence_lengths): + spec = batch_of_feature_sequences[index][:spec_length] + # double the length at least once, then check + spec = spec.repeat((2, 1)) + current_spec_length = len(spec) + while current_spec_length < minimum_sequence_length: + # make it longer + spec = spec.repeat((2, 1)) + current_spec_length = len(spec) + specs.append(spec[:minimum_sequence_length]) + + spec_batch = torch.stack(specs, dim=0) + return self.style_encoder(speech=spec_batch) + + +if __name__ == '__main__': + style_emb = StyleEmbedding(style_tts_encoder=False) + print(f"GST parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}") + + seq_length = 398 + print(style_emb(torch.randn(5, seq_length, 512), + torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape) + + style_emb = StyleEmbedding(style_tts_encoder=True) + print(f"StyleTTS encoder parameter count: {sum(p.numel() for p in style_emb.style_encoder.parameters() if p.requires_grad)}") + + seq_length = 398 + print(style_emb(torch.randn(5, seq_length, 512), + torch.tensor([seq_length, seq_length, seq_length, seq_length, seq_length])).shape) diff --git a/Architectures/EmbeddingModel/StyleTTSEncoder.py b/Architectures/EmbeddingModel/StyleTTSEncoder.py new file mode 100644 index 0000000000000000000000000000000000000000..38165f95b243f0ed1b9e687199721f669d36ec35 --- /dev/null +++ b/Architectures/EmbeddingModel/StyleTTSEncoder.py @@ -0,0 +1,156 @@ +""" +MIT Licensed Code + +Copyright (c) 2022 Aaron (Yinghao) Li + +https://github.com/yl4579/StyleTTS/blob/main/models.py +""" + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import spectral_norm + + +class StyleEncoder(nn.Module): + def __init__(self, dim_in=128, style_dim=64, max_conv_dim=384): + super().__init__() + blocks = [] + blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))] + + repeat_num = 4 + for _ in range(repeat_num): + dim_out = min(dim_in * 2, max_conv_dim) + blocks += [ResBlk(dim_in, dim_out, downsample='half')] + dim_in = dim_out + + blocks += [nn.LeakyReLU(0.2)] + blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))] + blocks += [nn.AdaptiveAvgPool2d(1)] + blocks += [nn.LeakyReLU(0.2)] + self.shared = nn.Sequential(*blocks) + + self.unshared = nn.Linear(dim_out, style_dim) + + def forward(self, speech): + h = self.shared(speech.unsqueeze(1)) + h = h.view(h.size(0), -1) + s = self.unshared(h) + + return s + + +class ResBlk(nn.Module): + def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), + normalize=False, downsample='none'): + super().__init__() + self.actv = actv + self.normalize = normalize + self.downsample = DownSample(downsample) + self.downsample_res = LearnedDownSample(downsample, dim_in) + self.learned_sc = dim_in != dim_out + self._build_weights(dim_in, dim_out) + + def _build_weights(self, dim_in, dim_out): + self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1)) + self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1)) + if self.normalize: + self.norm1 = nn.InstanceNorm2d(dim_in, affine=True) + self.norm2 = nn.InstanceNorm2d(dim_in, affine=True) + if self.learned_sc: + self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)) + + def _shortcut(self, x): + if self.learned_sc: + x = self.conv1x1(x) + if self.downsample: + x = self.downsample(x) + return x + + def _residual(self, x): + if self.normalize: + x = self.norm1(x) + x = self.actv(x) + x = self.conv1(x) + x = self.downsample_res(x) + if self.normalize: + x = self.norm2(x) + x = self.actv(x) + x = self.conv2(x) + return x + + def forward(self, x): + x = self._shortcut(x) + self._residual(x) + return x / math.sqrt(2) # unit variance + + +class LearnedDownSample(nn.Module): + def __init__(self, layer_type, dim_in): + super().__init__() + self.layer_type = layer_type + + if self.layer_type == 'none': + self.conv = nn.Identity() + elif self.layer_type == 'timepreserve': + self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0))) + elif self.layer_type == 'half': + self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1)) + else: + raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) + + def forward(self, x): + return self.conv(x) + + +class LearnedUpSample(nn.Module): + def __init__(self, layer_type, dim_in): + super().__init__() + self.layer_type = layer_type + + if self.layer_type == 'none': + self.conv = nn.Identity() + elif self.layer_type == 'timepreserve': + self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0)) + elif self.layer_type == 'half': + self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1) + else: + raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) + + def forward(self, x): + return self.conv(x) + + +class DownSample(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + elif self.layer_type == 'timepreserve': + return F.avg_pool2d(x, (2, 1)) + elif self.layer_type == 'half': + if x.shape[-1] % 2 != 0: + x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1) + return F.avg_pool2d(x, 2) + else: + raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) + + +class UpSample(nn.Module): + def __init__(self, layer_type): + super().__init__() + self.layer_type = layer_type + + def forward(self, x): + if self.layer_type == 'none': + return x + elif self.layer_type == 'timepreserve': + return F.interpolate(x, scale_factor=(2, 1), mode='nearest') + elif self.layer_type == 'half': + return F.interpolate(x, scale_factor=2, mode='nearest') + else: + raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type) diff --git a/Architectures/EmbeddingModel/__init__.py b/Architectures/EmbeddingModel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Architectures/GeneralLayers/Attention.py b/Architectures/GeneralLayers/Attention.py new file mode 100644 index 0000000000000000000000000000000000000000..eb241e315de718099901a075feae2ed0e31c7347 --- /dev/null +++ b/Architectures/GeneralLayers/Attention.py @@ -0,0 +1,324 @@ +# Written by Shigeki Karita, 2019 +# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux, 2021 + +"""Multi-Head Attention layer definition.""" + +import math + +import numpy +import torch +from torch import nn + +from Utility.utils import make_non_pad_mask + + +class MultiHeadedAttention(nn.Module): + """ + Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """ + Construct an MultiHeadedAttention object. + """ + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """ + Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """ + Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """ + Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """ + Multi-Head Attention layer with relative position encoding. + Details can be found in https://github.com/espnet/espnet/pull/2816. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """ + Compute relative positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + Returns: + torch.Tensor: Output tensor. + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """ + Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, 2*time1-1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class GuidedAttentionLoss(torch.nn.Module): + """ + Guided attention loss function module. + + This module calculates the guided attention loss described + in `Efficiently Trainable Text-to-Speech System Based + on Deep Convolutional Networks with Guided Attention`_, + which forces the attention to be diagonal. + + .. _`Efficiently Trainable Text-to-Speech System + Based on Deep Convolutional Networks with Guided Attention`: + https://arxiv.org/abs/1710.08969 + """ + + def __init__(self, sigma=0.4, alpha=1.0): + """ + Initialize guided attention loss module. + + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. + alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + """ + super(GuidedAttentionLoss, self).__init__() + self.sigma = sigma + self.alpha = alpha + self.guided_attn_masks = None + self.masks = None + + def _reset_masks(self): + self.guided_attn_masks = None + self.masks = None + + def forward(self, att_ws, ilens, olens): + """ + Calculate forward propagation. + + Args: + att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). + olens (LongTensor): Batch of output lenghts (B,). + + Returns: + Tensor: Guided attention loss value. + """ + self._reset_masks() + self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device) + self.masks = self._make_masks(ilens, olens).to(att_ws.device) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + self._reset_masks() + return self.alpha * loss + + def _make_guided_attention_masks(self, ilens, olens): + n_batches = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=ilens.device) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma) + return guided_attn_masks + + @staticmethod + def _make_guided_attention_mask(ilen, olen, sigma): + """ + Make guided attention mask. + """ + grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device).float(), torch.arange(ilen, device=ilen.device).float()) + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2))) + + @staticmethod + def _make_masks(ilens, olens): + """ + Make masks indicating non-padded part. + + Args: + ilens (LongTensor or List): Batch of lengths (B,). + olens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor indicating non-padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + """ + in_masks = make_non_pad_mask(ilens, device=ilens.device) # (B, T_in) + out_masks = make_non_pad_mask(olens, device=olens.device) # (B, T_out) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in) + + +class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): + """ + Guided attention loss function module for multi head attention. + + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. + alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + """ + + def forward(self, att_ws, ilens, olens): + """ + Calculate forward propagation. + + Args: + att_ws (Tensor): + Batch of multi head attention weights (B, H, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). + olens (LongTensor): Batch of output lenghts (B,). + + Returns: + Tensor: Guided attention loss value. + """ + if self.guided_attn_masks is None: + self.guided_attn_masks = (self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1)) + if self.masks is None: + self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + if self.reset_always: + self._reset_masks() + + return self.alpha * loss diff --git a/Architectures/GeneralLayers/ConditionalLayerNorm.py b/Architectures/GeneralLayers/ConditionalLayerNorm.py new file mode 100644 index 0000000000000000000000000000000000000000..844cb2a3c1a73865b970ae5c41d0033746dfe4a5 --- /dev/null +++ b/Architectures/GeneralLayers/ConditionalLayerNorm.py @@ -0,0 +1,118 @@ +""" +Code taken from https://github.com/tuanh123789/AdaSpeech/blob/main/model/adaspeech_modules.py +By https://github.com/tuanh123789 +No license specified + +Implemented as outlined in AdaSpeech https://arxiv.org/pdf/2103.00993.pdf +Used in this toolkit similar to how it is done in AdaSpeech 4 https://arxiv.org/pdf/2204.00436.pdf + +""" + +import torch +from torch import nn + + +class ConditionalLayerNorm(nn.Module): + + def __init__(self, + hidden_dim, + speaker_embedding_dim, + dim=-1): + super(ConditionalLayerNorm, self).__init__() + self.dim = dim + if isinstance(hidden_dim, int): + self.normal_shape = hidden_dim + self.speaker_embedding_dim = speaker_embedding_dim + self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape), + nn.Tanh(), + nn.Linear(self.normal_shape, self.normal_shape)) + self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape), + nn.Tanh(), + nn.Linear(self.normal_shape, self.normal_shape)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.constant_(self.W_scale[0].weight, 0.0) + torch.nn.init.constant_(self.W_scale[2].weight, 0.0) + torch.nn.init.constant_(self.W_scale[0].bias, 1.0) + torch.nn.init.constant_(self.W_scale[2].bias, 1.0) + torch.nn.init.constant_(self.W_bias[0].weight, 0.0) + torch.nn.init.constant_(self.W_bias[2].weight, 0.0) + torch.nn.init.constant_(self.W_bias[0].bias, 0.0) + torch.nn.init.constant_(self.W_bias[2].bias, 0.0) + + def forward(self, x, speaker_embedding): + + if self.dim != -1: + x = x.transpose(-1, self.dim) + + mean = x.mean(dim=-1, keepdim=True) + var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) + scale = self.W_scale(speaker_embedding) + bias = self.W_bias(speaker_embedding) + + y = scale.unsqueeze(1) * ((x - mean) / var) + bias.unsqueeze(1) + + if self.dim != -1: + y = y.transpose(-1, self.dim) + + return y + + +class SequentialWrappableConditionalLayerNorm(nn.Module): + + def __init__(self, + hidden_dim, + speaker_embedding_dim): + super(SequentialWrappableConditionalLayerNorm, self).__init__() + if isinstance(hidden_dim, int): + self.normal_shape = hidden_dim + self.speaker_embedding_dim = speaker_embedding_dim + self.W_scale = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape), + nn.Tanh(), + nn.Linear(self.normal_shape, self.normal_shape)) + self.W_bias = nn.Sequential(nn.Linear(self.speaker_embedding_dim, self.normal_shape), + nn.Tanh(), + nn.Linear(self.normal_shape, self.normal_shape)) + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.constant_(self.W_scale[0].weight, 0.0) + torch.nn.init.constant_(self.W_scale[2].weight, 0.0) + torch.nn.init.constant_(self.W_scale[0].bias, 1.0) + torch.nn.init.constant_(self.W_scale[2].bias, 1.0) + torch.nn.init.constant_(self.W_bias[0].weight, 0.0) + torch.nn.init.constant_(self.W_bias[2].weight, 0.0) + torch.nn.init.constant_(self.W_bias[0].bias, 0.0) + torch.nn.init.constant_(self.W_bias[2].bias, 0.0) + + def forward(self, packed_input): + x, speaker_embedding = packed_input + mean = x.mean(dim=-1, keepdim=True) + var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) + scale = self.W_scale(speaker_embedding) + bias = self.W_bias(speaker_embedding) + + y = scale.unsqueeze(1) * ((x - mean) / var) + bias.unsqueeze(1) + + return y + + +class AdaIN1d(nn.Module): + """ + MIT Licensed + + Copyright (c) 2022 Aaron (Yinghao) Li + https://github.com/yl4579/StyleTTS/blob/main/models.py + """ + + def __init__(self, style_dim, num_features): + super().__init__() + self.norm = nn.InstanceNorm1d(num_features, affine=False) + self.fc = nn.Linear(style_dim, num_features * 2) + + def forward(self, x, s): + h = self.fc(s) + h = h.view(h.size(0), h.size(1), 1) + gamma, beta = torch.chunk(h, chunks=2, dim=1) + return (1 + gamma.transpose(1, 2)) * self.norm(x.transpose(1, 2)).transpose(1, 2) + beta.transpose(1, 2) diff --git a/Architectures/GeneralLayers/Conformer.py b/Architectures/GeneralLayers/Conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..21a9e852b783c5243dc17e7ec3b1150f3734a3cc --- /dev/null +++ b/Architectures/GeneralLayers/Conformer.py @@ -0,0 +1,158 @@ +""" +Taken from ESPNet, but heavily modified +""" + +import torch + +from Architectures.GeneralLayers.Attention import RelPositionMultiHeadedAttention +from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d +from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm +from Architectures.GeneralLayers.Convolution import ConvolutionModule +from Architectures.GeneralLayers.EncoderLayer import EncoderLayer +from Architectures.GeneralLayers.LayerNorm import LayerNorm +from Architectures.GeneralLayers.MultiLayeredConv1d import MultiLayeredConv1d +from Architectures.GeneralLayers.MultiSequential import repeat +from Architectures.GeneralLayers.PositionalEncoding import RelPositionalEncoding +from Architectures.GeneralLayers.Swish import Swish +from Utility.utils import integrate_with_utt_embed + + +class Conformer(torch.nn.Module): + """ + Conformer encoder module. + + Args: + idim (int): Input dimension. + attention_dim (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + macaron_style (bool): Whether to use macaron style for positionwise layer. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + + """ + + def __init__(self, conformer_type, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, + attention_dropout_rate=0.0, input_layer="conv2d", normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1, + macaron_style=False, use_cnn_module=False, cnn_module_kernel=31, zero_triu=False, utt_embed=None, lang_embs=None, lang_emb_size=16, use_output_norm=True, embedding_integration="AdaIN"): + super(Conformer, self).__init__() + + activation = Swish() + self.conv_subsampling_factor = 1 + self.use_output_norm = use_output_norm + + if isinstance(input_layer, torch.nn.Module): + self.embed = input_layer + self.art_embed_norm = LayerNorm(attention_dim) + self.pos_enc = RelPositionalEncoding(attention_dim, positional_dropout_rate) + elif input_layer is None: + self.embed = None + self.pos_enc = torch.nn.Sequential(RelPositionalEncoding(attention_dim, positional_dropout_rate)) + else: + raise ValueError("unknown input_layer: " + input_layer) + + if self.use_output_norm: + self.output_norm = LayerNorm(attention_dim) + self.utt_embed = utt_embed + self.conformer_type = conformer_type + self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"] + if utt_embed is not None: + if conformer_type == "encoder": # the encoder gets an additional conditioning signal added to its output + if embedding_integration == "AdaIN": + self.encoder_embedding_projection = AdaIN1d(style_dim=utt_embed, num_features=attention_dim) + elif embedding_integration == "ConditionalLayerNorm": + self.encoder_embedding_projection = ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim) + else: + self.encoder_embedding_projection = torch.nn.Linear(attention_dim + utt_embed, attention_dim) + else: + if embedding_integration == "AdaIN": + self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: AdaIN1d(style_dim=utt_embed, num_features=attention_dim)) + elif embedding_integration == "ConditionalLayerNorm": + self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: ConditionalLayerNorm(speaker_embedding_dim=utt_embed, hidden_dim=attention_dim)) + else: + self.decoder_embedding_projections = repeat(num_blocks, lambda lnum: torch.nn.Linear(attention_dim + utt_embed, attention_dim)) + if lang_embs is not None: + self.language_embedding = torch.nn.Embedding(num_embeddings=lang_embs, embedding_dim=lang_emb_size) + self.language_embedding_projection = torch.nn.Linear(lang_emb_size, attention_dim) + self.language_emb_norm = LayerNorm(attention_dim) + # self-attention module definition + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate, zero_triu) + + # feed-forward module definition + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,) + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate, + normalize_before, concat_after)) + + def forward(self, + xs, + masks, + utterance_embedding=None, + lang_ids=None): + """ + Encode input sequence. + Args: + utterance_embedding: embedding containing lots of conditioning signals + lang_ids: ids of the languages per sample in the batch + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). + """ + + if self.embed is not None: + xs = self.embed(xs) + xs = self.art_embed_norm(xs) + + if lang_ids is not None: + lang_embs = self.language_embedding(lang_ids) + projected_lang_embs = self.language_embedding_projection(lang_embs).unsqueeze(-1).transpose(1, 2) + projected_lang_embs = self.language_emb_norm(projected_lang_embs) + xs = xs + projected_lang_embs # offset phoneme representation by language specific offset + + xs = self.pos_enc(xs) + + for encoder_index, encoder in enumerate(self.encoders): + if self.utt_embed: + if isinstance(xs, tuple): + x, pos_emb = xs[0], xs[1] + if self.conformer_type != "encoder": + x = integrate_with_utt_embed(hs=x, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration) + xs = (x, pos_emb) + else: + if self.conformer_type != "encoder": + xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, projection=self.decoder_embedding_projections[encoder_index], embedding_training=self.use_conditional_layernorm_embedding_integration) + xs, masks = encoder(xs, masks) + + if isinstance(xs, tuple): + xs = xs[0] + + if self.use_output_norm and not (self.utt_embed and self.conformer_type == "encoder"): + xs = self.output_norm(xs) + + if self.utt_embed and self.conformer_type == "encoder": + xs = integrate_with_utt_embed(hs=xs, utt_embeddings=utterance_embedding, + projection=self.encoder_embedding_projection, embedding_training=self.use_conditional_layernorm_embedding_integration) + + return xs, masks diff --git a/Architectures/GeneralLayers/Convolution.py b/Architectures/GeneralLayers/Convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..29d6c8c17a8c66fa3e57fdb02b767fc2870f2269 --- /dev/null +++ b/Architectures/GeneralLayers/Convolution.py @@ -0,0 +1,55 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux 2021 + + +from torch import nn + + +class ConvolutionModule(nn.Module): + """ + ConvolutionModule in Conformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. + + """ + + def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): + super(ConvolutionModule, self).__init__() + # kernel_size should be an odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) + self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) + self.activation = activation + + def forward(self, x): + """ + Compute convolution module. + + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) diff --git a/Architectures/GeneralLayers/DurationPredictor.py b/Architectures/GeneralLayers/DurationPredictor.py new file mode 100644 index 0000000000000000000000000000000000000000..871f3bb2e1e2f571ab2e24f3233e9a22d009b043 --- /dev/null +++ b/Architectures/GeneralLayers/DurationPredictor.py @@ -0,0 +1,171 @@ +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +# Adapted by Florian Lux 2021 + + +import torch + +from Architectures.GeneralLayers.ConditionalLayerNorm import AdaIN1d +from Architectures.GeneralLayers.ConditionalLayerNorm import ConditionalLayerNorm +from Architectures.GeneralLayers.LayerNorm import LayerNorm +from Utility.utils import integrate_with_utt_embed + + +class DurationPredictor(torch.nn.Module): + """ + Duration predictor module. + + This is a module of duration predictor described + in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The duration predictor predicts a duration of each frame in log domain + from the hidden embeddings of encoder. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + Note: + The calculation domain of outputs is different + between in `forward` and in `inference`. In `forward`, + the outputs are calculated in log domain but in `inference`, + those are calculated in linear domain. + + """ + + def __init__(self, idim, + n_layers=2, + n_chans=384, + kernel_size=3, + dropout_rate=0.1, + offset=1.0, + utt_embed_dim=None, + embedding_integration="AdaIN"): + """ + Initialize duration predictor module. + + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + offset (float, optional): Offset value to avoid nan in log domain. + + """ + super(DurationPredictor, self).__init__() + self.offset = offset + self.conv = torch.nn.ModuleList() + self.dropouts = torch.nn.ModuleList() + self.norms = torch.nn.ModuleList() + self.embedding_projections = torch.nn.ModuleList() + self.utt_embed_dim = utt_embed_dim + self.use_conditional_layernorm_embedding_integration = embedding_integration in ["AdaIN", "ConditionalLayerNorm"] + + for idx in range(n_layers): + if utt_embed_dim is not None: + if embedding_integration == "AdaIN": + self.embedding_projections += [AdaIN1d(style_dim=utt_embed_dim, num_features=idim)] + elif embedding_integration == "ConditionalLayerNorm": + self.embedding_projections += [ConditionalLayerNorm(speaker_embedding_dim=utt_embed_dim, hidden_dim=idim)] + else: + self.embedding_projections += [torch.nn.Linear(utt_embed_dim + idim, idim)] + else: + self.embedding_projections += [lambda x: x] + in_chans = idim if idx == 0 else n_chans + self.conv += [torch.nn.Sequential(torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ), + torch.nn.ReLU())] + self.norms += [LayerNorm(n_chans, dim=1)] + self.dropouts += [torch.nn.Dropout(dropout_rate)] + + self.linear = torch.nn.Linear(n_chans, 1) + + def _forward(self, xs, x_masks=None, is_inference=False, utt_embed=None): + xs = xs.transpose(1, -1) # (B, idim, Tmax) + + for f, c, d, p in zip(self.conv, self.norms, self.dropouts, self.embedding_projections): + xs = f(xs) # (B, C, Tmax) + if self.utt_embed_dim is not None: + xs = integrate_with_utt_embed(hs=xs.transpose(1, 2), utt_embeddings=utt_embed, projection=p, embedding_training=self.use_conditional_layernorm_embedding_integration).transpose(1, 2) + xs = c(xs) + xs = d(xs) + + # NOTE: targets are transformed to log domain in the loss calculation, so this will learn to predict in the log space, which makes the value range easier to handle. + xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax) + + if is_inference: + # NOTE: since we learned to predict in the log domain, we have to invert the log during inference. + xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value + else: + xs = xs.masked_fill(x_masks, 0.0) + + return xs + + def forward(self, xs, padding_mask=None, utt_embed=None): + """ + Calculate forward propagation. + + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + padding_mask (ByteTensor, optional): + Batch of masks indicating padded part (B, Tmax). + + Returns: + Tensor: Batch of predicted durations in log domain (B, Tmax). + + """ + return self._forward(xs, padding_mask, False, utt_embed=utt_embed) + + def inference(self, xs, padding_mask=None, utt_embed=None): + """ + Inference duration. + + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + padding_mask (ByteTensor, optional): + Batch of masks indicating padded part (B, Tmax). + + Returns: + LongTensor: Batch of predicted durations in linear domain (B, Tmax). + + """ + return self._forward(xs, padding_mask, True, utt_embed=utt_embed) + + +class DurationPredictorLoss(torch.nn.Module): + """ + Loss function module for duration predictor. + + The loss value is Calculated in log domain to make it Gaussian. + + """ + + def __init__(self, offset=1.0, reduction="mean"): + """ + Args: + offset (float, optional): Offset value to avoid nan in log domain. + reduction (str): Reduction type in loss calculation. + + """ + super(DurationPredictorLoss, self).__init__() + self.criterion = torch.nn.MSELoss(reduction=reduction) + self.offset = offset + + def forward(self, outputs, targets): + """ + Calculate forward propagation. + + Args: + outputs (Tensor): Batch of prediction durations in log domain (B, T) + targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) + + Returns: + Tensor: Mean squared error loss value. + + Note: + `outputs` is in log domain but `targets` is in linear domain. + + """ + # NOTE: outputs is in log domain while targets in linear + targets = torch.log(targets.float() + self.offset) + loss = self.criterion(outputs, targets) + + return loss diff --git a/Architectures/GeneralLayers/EncoderLayer.py b/Architectures/GeneralLayers/EncoderLayer.py new file mode 100644 index 0000000000000000000000000000000000000000..e21d35ae120527e049c37fd9516fd0d860ea5e7f --- /dev/null +++ b/Architectures/GeneralLayers/EncoderLayer.py @@ -0,0 +1,144 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux 2021 + + +import torch +from torch import nn + +from Architectures.GeneralLayers.LayerNorm import LayerNorm + + +class EncoderLayer(nn.Module): + """ + Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + + """ + + def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, ): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """ + Compute encoded features. + + Args: + x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. + - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. + - w/o pos emb: Tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/Architectures/GeneralLayers/LayerNorm.py b/Architectures/GeneralLayers/LayerNorm.py new file mode 100644 index 0000000000000000000000000000000000000000..715346b2e9fb551f3adad55b609c3a5b2bb3ad64 --- /dev/null +++ b/Architectures/GeneralLayers/LayerNorm.py @@ -0,0 +1,36 @@ +# Written by Shigeki Karita, 2019 +# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux, 2021 + +import torch + + +class LayerNorm(torch.nn.LayerNorm): + """ + Layer normalization module. + + Args: + nout (int): Output dim size. + dim (int): Dimension to be normalized. + """ + + def __init__(self, nout, dim=-1, eps=1e-12): + """ + Construct an LayerNorm object. + """ + super(LayerNorm, self).__init__(nout, eps=eps) + self.dim = dim + + def forward(self, x): + """ + Apply layer normalization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor. + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) diff --git a/Architectures/GeneralLayers/LengthRegulator.py b/Architectures/GeneralLayers/LengthRegulator.py new file mode 100644 index 0000000000000000000000000000000000000000..8605716aab026f83b492c05ee5285562182e0674 --- /dev/null +++ b/Architectures/GeneralLayers/LengthRegulator.py @@ -0,0 +1,61 @@ +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +# Adapted by Florian Lux 2021 + +from abc import ABC + +import torch + +from Utility.utils import pad_list + + +class LengthRegulator(torch.nn.Module, ABC): + """ + Length regulator module for feed-forward Transformer. + + This is a module of length regulator described in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The length regulator expands char or + phoneme-level embedding features to frame-level by repeating each + feature based on the corresponding predicted durations. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, pad_value=0.0): + """ + Initialize length regulator module. + + Args: + pad_value (float, optional): Value used for padding. + """ + super(LengthRegulator, self).__init__() + self.pad_value = pad_value + + def forward(self, xs, ds, alpha=1.0): + """ + Calculate forward propagation. + Args: + xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D). + ds (LongTensor): Batch of durations of each frame (B, T). + alpha (float, optional): Alpha value to control speed of speech. + Returns: + Tensor: replicated input tensor based on durations (B, T*, D). + """ + + if alpha != 1.0: + assert alpha > 0 + ds = torch.round(ds.float() * alpha).long() + + if ds.sum() == 0: + ds[ds.sum(dim=1).eq(0)] = 1 + + return pad_list([self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)], self.pad_value) + + def _repeat_one_sequence(self, x, d): + """ + Repeat each frame according to duration + """ + return torch.repeat_interleave(x, d, dim=0) diff --git a/Architectures/GeneralLayers/MultiLayeredConv1d.py b/Architectures/GeneralLayers/MultiLayeredConv1d.py new file mode 100644 index 0000000000000000000000000000000000000000..f2de4a06a06d891fbaca726959b0f0d34d93d7cc --- /dev/null +++ b/Architectures/GeneralLayers/MultiLayeredConv1d.py @@ -0,0 +1,87 @@ +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) +# Adapted by Florian Lux 2021 + +""" +Layer modules for FFT block in FastSpeech (Feed-forward Transformer). +""" + +import torch + + +class MultiLayeredConv1d(torch.nn.Module): + """ + Multi-layered conv1d for Transformer block. + + This is a module of multi-layered conv1d designed + to replace positionwise feed-forward network + in Transformer block, which is introduced in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """ + Initialize MultiLayeredConv1d module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + """ + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ) + self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """ + Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + + +class Conv1dLinear(torch.nn.Module): + """ + Conv1D + Linear for Transformer block. + + A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """ + Initialize Conv1dLinear module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + """ + super(Conv1dLinear, self).__init__() + self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2, ) + self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """ + Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x)) diff --git a/Architectures/GeneralLayers/MultiSequential.py b/Architectures/GeneralLayers/MultiSequential.py new file mode 100644 index 0000000000000000000000000000000000000000..bccf8cd18bf94a42fcc1ef94f3fb23e86a114394 --- /dev/null +++ b/Architectures/GeneralLayers/MultiSequential.py @@ -0,0 +1,33 @@ +# Written by Shigeki Karita, 2019 +# Published under Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# Adapted by Florian Lux, 2021 + +import torch + + +class MultiSequential(torch.nn.Sequential): + """ + Multi-input multi-output torch.nn.Sequential. + """ + + def forward(self, *args): + """ + Repeat. + """ + for m in self: + args = m(*args) + return args + + +def repeat(N, fn): + """ + Repeat module N times. + + Args: + N (int): Number of repeat time. + fn (Callable): Function to generate module. + + Returns: + MultiSequential: Repeated model instance. + """ + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/Architectures/GeneralLayers/PositionalEncoding.py b/Architectures/GeneralLayers/PositionalEncoding.py new file mode 100644 index 0000000000000000000000000000000000000000..8929a7fa6298f00e97fba1630524da014b738ace --- /dev/null +++ b/Architectures/GeneralLayers/PositionalEncoding.py @@ -0,0 +1,166 @@ +""" +Taken from ESPNet +""" + +import math + +import torch + + +class PositionalEncoding(torch.nn.Module): + """ + Positional encoding. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + reverse (bool): Whether to reverse the input position. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """ + Construct an PositionalEncoding object. + """ + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0, device=d_model.device).expand(1, max_len)) + + def extend_pe(self, x): + """ + Reset the positional encodings. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x): + """ + Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module (new implementation). + Details can be found in https://github.com/espnet/espnet/pull/2816. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """ + Construct an PositionalEncoding object. + """ + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i