diff --git a/TTS/__init__.py b/TTS/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e93c9b5db1fe29c17774f6c888f302fdcbda869 --- /dev/null +++ b/TTS/__init__.py @@ -0,0 +1,33 @@ +import importlib.metadata + +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +__version__ = importlib.metadata.version("coqui-tts") + + +if is_pytorch_at_least_2_4(): + import _codecs + from collections import defaultdict + + import numpy as np + import torch + + from TTS.config.shared_configs import BaseDatasetConfig + from TTS.tts.configs.xtts_config import XttsConfig + from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + from TTS.utils.radam import RAdam + + torch.serialization.add_safe_globals([dict, defaultdict, RAdam]) + + # Bark + torch.serialization.add_safe_globals( + [ + np.core.multiarray.scalar, + np.dtype, + np.dtypes.Float64DType, + _codecs.encode, # TODO: safe by default from Pytorch 2.5 + ] + ) + + # XTTS + torch.serialization.add_safe_globals([BaseDatasetConfig, XttsConfig, XttsAudioConfig, XttsArgs]) diff --git a/TTS/api.py b/TTS/api.py new file mode 100644 index 0000000000000000000000000000000000000000..250ed1a0d928de0a6d3f91ef284811f13e056a03 --- /dev/null +++ b/TTS/api.py @@ -0,0 +1,458 @@ +import logging +import tempfile +import warnings +from pathlib import Path + +from torch import nn + +from TTS.config import load_config +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer + +logger = logging.getLogger(__name__) + + +class TTS(nn.Module): + """TODO: Add voice conversion and Capacitron support.""" + + def __init__( + self, + model_name: str = "", + model_path: str = None, + config_path: str = None, + vocoder_path: str = None, + vocoder_config_path: str = None, + progress_bar: bool = True, + gpu=False, + ): + """🐸TTS python interface that allows to load and use the released models. + + Example with a multi-speaker model: + >>> from TTS.api import TTS + >>> tts = TTS(TTS.list_models()[0]) + >>> wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0]) + >>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav") + + Example with a single-speaker model: + >>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False) + >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + + Example loading a model from a path: + >>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False) + >>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav") + + Example voice cloning with YourTTS in English, French and Portuguese: + >>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav") + >>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav") + >>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav") + + Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html): + >>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True) + >>> tts.tts_to_file("This is a test.", file_path="output.wav") + + Args: + model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to None. + model_path (str, optional): Path to the model checkpoint. Defaults to None. + config_path (str, optional): Path to the model config. Defaults to None. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None. + progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + super().__init__() + self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar) + self.config = load_config(config_path) if config_path else None + self.synthesizer = None + self.voice_converter = None + self.model_name = "" + if gpu: + warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") + + if model_name is not None and len(model_name) > 0: + if "tts_models" in model_name: + self.load_tts_model_by_name(model_name, gpu) + elif "voice_conversion_models" in model_name: + self.load_vc_model_by_name(model_name, gpu) + else: + self.load_model_by_name(model_name, gpu) + + if model_path: + self.load_tts_model_by_path( + model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu + ) + + @property + def models(self): + return self.manager.list_tts_models() + + @property + def is_multi_speaker(self): + if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager: + return self.synthesizer.tts_model.speaker_manager.num_speakers > 1 + return False + + @property + def is_multi_lingual(self): + # Not sure what sets this to None, but applied a fix to prevent crashing. + if ( + isinstance(self.model_name, str) + and "xtts" in self.model_name + or self.config + and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1) + ): + return True + if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager: + return self.synthesizer.tts_model.language_manager.num_languages > 1 + return False + + @property + def speakers(self): + if not self.is_multi_speaker: + return None + return self.synthesizer.tts_model.speaker_manager.speaker_names + + @property + def languages(self): + if not self.is_multi_lingual: + return None + return self.synthesizer.tts_model.language_manager.language_names + + @staticmethod + def get_models_file_path(): + return Path(__file__).parent / ".models.json" + + @staticmethod + def list_models(): + return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models() + + def download_model_by_name(self, model_name: str): + model_path, config_path, model_item = self.manager.download_model(model_name) + if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)): + # return model directory if there are multiple files + # we assume that the model knows how to load itself + return None, None, None, None, model_path + if model_item.get("default_vocoder") is None: + return model_path, config_path, None, None, None + vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"]) + return model_path, config_path, vocoder_path, vocoder_config_path, None + + def load_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.load_tts_model_by_name(model_name, gpu) + + def load_vc_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of the voice conversion models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + self.model_name = model_name + model_path, config_path, _, _, _ = self.download_model_by_name(model_name) + self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu) + + def load_tts_model_by_name(self, model_name: str, gpu: bool = False): + """Load one of 🐸TTS models by name. + + Args: + model_name (str): Model name to load. You can list models by ```tts.models```. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + + TODO: Add tests + """ + self.synthesizer = None + self.model_name = model_name + + model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name) + + # init synthesizer + # None values are fetch from the model + self.synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=None, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint=None, + encoder_config=None, + model_dir=model_dir, + use_cuda=gpu, + ) + + def load_tts_model_by_path( + self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False + ): + """Load a model from a path. + + Args: + model_path (str): Path to the model checkpoint. + config_path (str): Path to the model config. + vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None. + vocoder_config (str, optional): Path to the vocoder config. Defaults to None. + gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. + """ + + self.synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=None, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config, + encoder_checkpoint=None, + encoder_config=None, + use_cuda=gpu, + ) + + def _check_arguments( + self, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = None, + **kwargs, + ) -> None: + """Check if the arguments are valid for the model.""" + # check for the coqui tts models + if self.is_multi_speaker and (speaker is None and speaker_wav is None): + raise ValueError("Model is multi-speaker but no `speaker` is provided.") + if self.is_multi_lingual and language is None: + raise ValueError("Model is multi-lingual but no `language` is provided.") + if not self.is_multi_speaker and speaker is not None and "voice_dir" not in kwargs: + raise ValueError("Model is not multi-speaker but `speaker` is provided.") + if not self.is_multi_lingual and language is not None: + raise ValueError("Model is not multi-lingual but `language` is provided.") + if emotion is not None and speed is not None: + raise ValueError("Emotion and speed can only be used with Coqui Studio models. Which is discontinued.") + + def tts( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = None, + split_sentences: bool = True, + **kwargs, + ): + """Convert text to speech. + + Args: + text (str): + Input text to synthesize. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + language (str): Language of the text. If None, the default language of the speaker is used. Language is only + supported by `XTTS` model. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + emotion (str, optional): + Emotion to use for 🐸Coqui Studio models. If None, Studio models use "Neutral". Defaults to None. + speed (float, optional): + Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0. + Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + kwargs (dict, optional): + Additional arguments for the model. + """ + self._check_arguments( + speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs + ) + wav = self.synthesizer.tts( + text=text, + speaker_name=speaker, + language_name=language, + speaker_wav=speaker_wav, + reference_wav=None, + style_wav=None, + style_text=None, + reference_speaker_name=None, + split_sentences=split_sentences, + **kwargs, + ) + return wav + + def tts_to_file( + self, + text: str, + speaker: str = None, + language: str = None, + speaker_wav: str = None, + emotion: str = None, + speed: float = 1.0, + pipe_out=None, + file_path: str = "output.wav", + split_sentences: bool = True, + **kwargs, + ): + """Convert text to speech. + + Args: + text (str): + Input text to synthesize. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + emotion (str, optional): + Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral". + speed (float, optional): + Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None. + pipe_out (BytesIO, optional): + Flag to stdout the generated TTS wav file for shell pipe. + file_path (str, optional): + Output file path. Defaults to "output.wav". + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + kwargs (dict, optional): + Additional arguments for the model. + """ + self._check_arguments(speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs) + + wav = self.tts( + text=text, + speaker=speaker, + language=language, + speaker_wav=speaker_wav, + split_sentences=split_sentences, + **kwargs, + ) + self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out) + return file_path + + def voice_conversion( + self, + source_wav: str, + target_wav: str, + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args:`` + source_wav (str): + Path to the source wav file. + target_wav (str):` + Path to the target wav file. + """ + wav = self.voice_converter.voice_conversion(source_wav=source_wav, target_wav=target_wav) + return wav + + def voice_conversion_to_file( + self, + source_wav: str, + target_wav: str, + file_path: str = "output.wav", + ): + """Voice conversion with FreeVC. Convert source wav to target speaker. + + Args: + source_wav (str): + Path to the source wav file. + target_wav (str): + Path to the target wav file. + file_path (str, optional): + Output file path. Defaults to "output.wav". + """ + wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) + return file_path + + def tts_with_vc( + self, + text: str, + language: str = None, + speaker_wav: str = None, + speaker: str = None, + split_sentences: bool = True, + ): + """Convert text to speech with voice conversion. + + It combines tts with voice conversion to fake voice cloning. + + - Convert text to speech with tts. + - Convert the output wav to target speaker with voice conversion. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + """ + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + # Lazy code... save it to a temp file to resample it while reading it for VC + self.tts_to_file( + text=text, speaker=speaker, language=language, file_path=fp.name, split_sentences=split_sentences + ) + if self.voice_converter is None: + self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24") + wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav) + return wav + + def tts_with_vc_to_file( + self, + text: str, + language: str = None, + speaker_wav: str = None, + file_path: str = "output.wav", + speaker: str = None, + split_sentences: bool = True, + ): + """Convert text to speech with voice conversion and save to file. + + Check `tts_with_vc` for more details. + + Args: + text (str): + Input text to synthesize. + language (str, optional): + Language code for multi-lingual models. You can check whether loaded model is multi-lingual + `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None. + speaker_wav (str, optional): + Path to a reference wav file to use for voice cloning with supporting models like YourTTS. + Defaults to None. + file_path (str, optional): + Output file path. Defaults to "output.wav". + speaker (str, optional): + Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by + `tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None. + split_sentences (bool, optional): + Split text into sentences, synthesize them separately and concatenate the file audio. + Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only + applicable to the 🐸TTS models. Defaults to True. + """ + wav = self.tts_with_vc( + text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences + ) + save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate) diff --git a/TTS/bin/__init__.py b/TTS/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/bin/collect_env_info.py b/TTS/bin/collect_env_info.py new file mode 100644 index 0000000000000000000000000000000000000000..32aa303e6e0b2918a2fd6d64033cbcd558a45cb6 --- /dev/null +++ b/TTS/bin/collect_env_info.py @@ -0,0 +1,49 @@ +"""Get detailed info about the working environment.""" + +import json +import os +import platform +import sys + +import numpy +import torch + +import TTS + +sys.path += [os.path.abspath(".."), os.path.abspath(".")] + + +def system_info(): + return { + "OS": platform.system(), + "architecture": platform.architecture(), + "version": platform.version(), + "processor": platform.processor(), + "python": platform.python_version(), + } + + +def cuda_info(): + return { + "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], + "available": torch.cuda.is_available(), + "version": torch.version.cuda, + } + + +def package_info(): + return { + "numpy": numpy.__version__, + "PyTorch_version": torch.__version__, + "PyTorch_debug": torch.version.debug, + "TTS": TTS.__version__, + } + + +def main(): + details = {"System": system_info(), "CUDA": cuda_info(), "Packages": package_info()} + print(json.dumps(details, indent=4, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py new file mode 100644 index 0000000000000000000000000000000000000000..127199186bd91c7c27a57a94da42e65f78e24735 --- /dev/null +++ b/TTS/bin/compute_attention_masks.py @@ -0,0 +1,169 @@ +import argparse +import importlib +import logging +import os +from argparse import RawTextHelpFormatter + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from trainer.io import load_checkpoint + +from TTS.config import load_config +from TTS.tts.datasets.TTSDataset import TTSDataset +from TTS.tts.models import setup_model +from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Extract attention masks from trained Tacotron/Tacotron2 models. +These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n""" + """Each attention mask is written to the same path as the input wav file with ".npy" file extension. +(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n""" + """ +Example run: + CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py + --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth + --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json + --dataset_metafile metadata.csv + --data_path /root/LJSpeech-1.1/ + --batch_size 32 + --dataset ljspeech + --use_cuda +""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") + parser.add_argument( + "--config_path", + type=str, + required=True, + help="Path to Tacotron/Tacotron2 config file.", + ) + parser.add_argument( + "--dataset", + type=str, + default="", + required=True, + help="Target dataset processor name from TTS.tts.dataset.preprocess.", + ) + + parser.add_argument( + "--dataset_metafile", + type=str, + default="", + required=True, + help="Dataset metafile inclusing file paths with transcripts.", + ) + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="enable/disable cuda.") + + parser.add_argument( + "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." + ) + args = parser.parse_args() + + C = load_config(args.config_path) + ap = AudioProcessor(**C.audio) + + # if the vocabulary was passed, replace the default + if "characters" in C.keys(): + symbols, phonemes = make_symbols(**C.characters) # noqa: F811 + + # load the model + num_chars = len(phonemes) if C.use_phonemes else len(symbols) + # TODO: handle multi-speaker + model = setup_model(C) + model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) + + # data loader + preprocessor = importlib.import_module("TTS.tts.datasets.formatters") + preprocessor = getattr(preprocessor, args.dataset) + meta_data = preprocessor(args.data_path, args.dataset_metafile) + dataset = TTSDataset( + model.decoder.r, + C.text_cleaner, + compute_linear_spec=False, + ap=ap, + meta_data=meta_data, + characters=C.characters if "characters" in C.keys() else None, + add_blank=C["add_blank"] if "add_blank" in C.keys() else False, + use_phonemes=C.use_phonemes, + phoneme_cache_path=C.phoneme_cache_path, + phoneme_language=C.phoneme_language, + enable_eos_bos=C.enable_eos_bos_chars, + ) + + dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False)) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=4, + collate_fn=dataset.collate_fn, + shuffle=False, + drop_last=False, + ) + + # compute attentions + file_paths = [] + with torch.no_grad(): + for data in tqdm(loader): + # setup input data + text_input = data[0] + text_lengths = data[1] + linear_input = data[3] + mel_input = data[4] + mel_lengths = data[5] + stop_targets = data[6] + item_idxs = data[7] + + # dispatch data to GPU + if args.use_cuda: + text_input = text_input.cuda() + text_lengths = text_lengths.cuda() + mel_input = mel_input.cuda() + mel_lengths = mel_lengths.cuda() + + model_outputs = model.forward(text_input, text_lengths, mel_input) + + alignments = model_outputs["alignments"].detach() + for idx, alignment in enumerate(alignments): + item_idx = item_idxs[idx] + # interpolate if r > 1 + alignment = ( + torch.nn.functional.interpolate( + alignment.transpose(0, 1).unsqueeze(0), + size=None, + scale_factor=model.decoder.r, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ) + .squeeze(0) + .transpose(0, 1) + ) + # remove paddings + alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() + # set file paths + wav_file_name = os.path.basename(item_idx) + align_file_name = os.path.splitext(wav_file_name)[0] + "_attn.npy" + file_path = item_idx.replace(wav_file_name, align_file_name) + # save output + wav_file_abs_path = os.path.abspath(item_idx) + file_abs_path = os.path.abspath(file_path) + file_paths.append([wav_file_abs_path, file_abs_path]) + np.save(file_path, alignment) + + # ourput metafile + metafile = os.path.join(args.data_path, "metadata_attn_mask.txt") + + with open(metafile, "w", encoding="utf-8") as f: + for p in file_paths: + f.write(f"{p[0]}|{p[1]}\n") + print(f" >> Metafile created: {metafile}") diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1bdb8d733c2a18cdf93ce4e8df91b808c1adbb10 --- /dev/null +++ b/TTS/bin/compute_embeddings.py @@ -0,0 +1,201 @@ +import argparse +import logging +import os +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.managers import save_file +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + + +def compute_embeddings( + model_path, + config_path, + output_path, + old_speakers_file=None, + old_append=False, + config_dataset_path=None, + formatter_name=None, + dataset_name=None, + dataset_path=None, + meta_file_train=None, + meta_file_val=None, + disable_cuda=False, + no_eval=False, +): + use_cuda = torch.cuda.is_available() and not disable_cuda + + if config_dataset_path is not None: + c_dataset = load_config(config_dataset_path) + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=not no_eval) + else: + c_dataset = BaseDatasetConfig() + c_dataset.formatter = formatter_name + c_dataset.dataset_name = dataset_name + c_dataset.path = dataset_path + if meta_file_train is not None: + c_dataset.meta_file_train = meta_file_train + if meta_file_val is not None: + c_dataset.meta_file_val = meta_file_val + meta_data_train, meta_data_eval = load_tts_samples(c_dataset, eval_split=not no_eval) + + if meta_data_eval is None: + samples = meta_data_train + else: + samples = meta_data_train + meta_data_eval + + encoder_manager = SpeakerManager( + encoder_model_path=model_path, + encoder_config_path=config_path, + d_vectors_file_path=old_speakers_file, + use_cuda=use_cuda, + ) + + class_name_key = encoder_manager.encoder_config.class_name_key + + # compute speaker embeddings + if old_speakers_file is not None and old_append: + speaker_mapping = encoder_manager.embeddings + else: + speaker_mapping = {} + + for fields in tqdm(samples): + class_name = fields[class_name_key] + audio_file = fields["audio_file"] + embedding_key = fields["audio_unique_name"] + + # Only update the speaker name when the embedding is already in the old file. + if embedding_key in speaker_mapping: + speaker_mapping[embedding_key]["name"] = class_name + continue + + if old_speakers_file is not None and embedding_key in encoder_manager.clip_ids: + # get the embedding from the old file + embedd = encoder_manager.get_embedding_by_clip(embedding_key) + else: + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(audio_file) + + # create speaker_mapping if target dataset is defined + speaker_mapping[embedding_key] = {} + speaker_mapping[embedding_key]["name"] = class_name + speaker_mapping[embedding_key]["embedding"] = embedd + + if speaker_mapping: + # save speaker_mapping if target dataset is defined + if os.path.isdir(output_path): + mapping_file_path = os.path.join(output_path, "speakers.pth") + else: + mapping_file_path = output_path + + if os.path.dirname(mapping_file_path) != "": + os.makedirs(os.path.dirname(mapping_file_path), exist_ok=True) + + save_file(speaker_mapping, mapping_file_path) + print("Speaker embeddings saved at:", mapping_file_path) + + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each audio file in a dataset and store them keyed by `{dataset_name}#{file_path}` in a .pth file\n\n""" + """ + Example runs: + python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --config_dataset_path dataset_config.json + + python TTS/bin/compute_embeddings.py --model_path speaker_encoder_model.pth --config_path speaker_encoder_config.json --formatter_name coqui --dataset_path /path/to/vctk/dataset --dataset_name my_vctk --meta_file_train /path/to/vctk/metafile_train.csv --meta_file_val /path/to/vctk/metafile_eval.csv + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument( + "--model_path", + type=str, + help="Path to model checkpoint file. It defaults to the released speaker encoder.", + default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar", + ) + parser.add_argument( + "--config_path", + type=str, + help="Path to model config file. It defaults to the released speaker encoder config.", + default="https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json", + ) + parser.add_argument( + "--config_dataset_path", + type=str, + help="Path to dataset config file. You either need to provide this or `formatter_name`, `dataset_name` and `dataset_path` arguments.", + default=None, + ) + parser.add_argument( + "--output_path", + type=str, + help="Path for output `pth` or `json` file.", + default="speakers.pth", + ) + parser.add_argument( + "--old_file", + type=str, + help="The old existing embedding file, from which the embeddings will be directly loaded for already computed audio clips.", + default=None, + ) + parser.add_argument( + "--old_append", + help="Append new audio clip embeddings to the old embedding file, generate a new non-duplicated merged embedding file. Default False", + default=False, + action="store_true", + ) + parser.add_argument("--disable_cuda", action="store_true", help="Flag to disable cuda.", default=False) + parser.add_argument("--no_eval", help="Do not compute eval?. Default False", default=False, action="store_true") + parser.add_argument( + "--formatter_name", + type=str, + help="Name of the formatter to use. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--dataset_name", + type=str, + help="Name of the dataset to use. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--dataset_path", + type=str, + help="Path to the dataset. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--meta_file_train", + type=str, + help="Path to the train meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`", + default=None, + ) + parser.add_argument( + "--meta_file_val", + type=str, + help="Path to the evaluation meta file. If not set, dataset formatter uses the default metafile if it is defined in the formatter. You either need to provide this or `config_dataset_path`", + default=None, + ) + args = parser.parse_args() + + compute_embeddings( + args.model_path, + args.config_path, + args.output_path, + old_speakers_file=args.old_file, + old_append=args.old_append, + config_dataset_path=args.config_dataset_path, + formatter_name=args.formatter_name, + dataset_name=args.dataset_name, + dataset_path=args.dataset_path, + meta_file_train=args.meta_file_train, + meta_file_val=args.meta_file_val, + disable_cuda=args.disable_cuda, + no_eval=args.no_eval, + ) diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py new file mode 100644 index 0000000000000000000000000000000000000000..dc5423a69177d7dc04e9343ac7babb31180bbb6c --- /dev/null +++ b/TTS/bin/compute_statistics.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import glob +import logging +import os + +import numpy as np +from tqdm import tqdm + +# from TTS.utils.io import load_config +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + + +def main(): + """Run preprocessing process.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") + parser.add_argument("config_path", type=str, help="TTS config file path to define audio processin parameters.") + parser.add_argument("out_path", type=str, help="save path (directory and filename).") + parser.add_argument( + "--data_path", + type=str, + required=False, + help="folder including the target set of wavs overriding dataset config.", + ) + args, overrides = parser.parse_known_args() + + CONFIG = load_config(args.config_path) + CONFIG.parse_known_args(overrides, relaxed_parser=True) + + # load config + CONFIG.audio.signal_norm = False # do not apply earlier normalization + CONFIG.audio.stats_path = None # discard pre-defined stats + + # load audio processor + ap = AudioProcessor(**CONFIG.audio.to_dict()) + + # load the meta data of target dataset + if args.data_path: + dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True) + else: + dataset_items = load_tts_samples(CONFIG.datasets)[0] # take only train data + print(f" > There are {len(dataset_items)} files.") + + mel_sum = 0 + mel_square_sum = 0 + linear_sum = 0 + linear_square_sum = 0 + N = 0 + for item in tqdm(dataset_items): + # compute features + wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"]) + linear = ap.spectrogram(wav) + mel = ap.melspectrogram(wav) + + # compute stats + N += mel.shape[1] + mel_sum += mel.sum(1) + linear_sum += linear.sum(1) + mel_square_sum += (mel**2).sum(axis=1) + linear_square_sum += (linear**2).sum(axis=1) + + mel_mean = mel_sum / N + mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2) + linear_mean = linear_sum / N + linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2) + + output_file_path = args.out_path + stats = {} + stats["mel_mean"] = mel_mean + stats["mel_std"] = mel_scale + stats["linear_mean"] = linear_mean + stats["linear_std"] = linear_scale + + print(f" > Avg mel spec mean: {mel_mean.mean()}") + print(f" > Avg mel spec scale: {mel_scale.mean()}") + print(f" > Avg linear spec mean: {linear_mean.mean()}") + print(f" > Avg linear spec scale: {linear_scale.mean()}") + + # set default config values for mean-var scaling + CONFIG.audio.stats_path = output_file_path + CONFIG.audio.signal_norm = True + # remove redundant values + del CONFIG.audio.max_norm + del CONFIG.audio.min_level_db + del CONFIG.audio.symmetric_norm + del CONFIG.audio.clip_norm + stats["audio_config"] = CONFIG.audio.to_dict() + np.save(output_file_path, stats, allow_pickle=True) + print(f" > stats saved to {output_file_path}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/eval_encoder.py b/TTS/bin/eval_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..711c8221db2ab43129622bf705c721a551adee41 --- /dev/null +++ b/TTS/bin/eval_encoder.py @@ -0,0 +1,92 @@ +import argparse +import logging +from argparse import RawTextHelpFormatter + +import torch +from tqdm import tqdm + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + + +def compute_encoder_accuracy(dataset_items, encoder_manager): + class_name_key = encoder_manager.encoder_config.class_name_key + map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None) + + class_acc_dict = {} + + # compute embeddings for all wav_files + for item in tqdm(dataset_items): + class_name = item[class_name_key] + wav_file = item["audio_file"] + + # extract the embedding + embedd = encoder_manager.compute_embedding_from_clip(wav_file) + if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None: + embedding = torch.FloatTensor(embedd).unsqueeze(0) + if encoder_manager.use_cuda: + embedding = embedding.cuda() + + class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item() + predicted_label = map_classid_to_classname[str(class_id)] + else: + predicted_label = None + + if class_name is not None and predicted_label is not None: + is_equal = int(class_name == predicted_label) + if class_name not in class_acc_dict: + class_acc_dict[class_name] = [is_equal] + else: + class_acc_dict[class_name].append(is_equal) + else: + raise RuntimeError("Error: class_name or/and predicted_label are None") + + acc_avg = 0 + for key, values in class_acc_dict.items(): + acc = sum(values) / len(values) + print("Class", key, "Accuracy:", acc) + acc_avg += acc + + print("Average Accuracy:", acc_avg / len(class_acc_dict)) + + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + parser = argparse.ArgumentParser( + description="""Compute the accuracy of the encoder.\n\n""" + """ + Example runs: + python TTS/bin/eval_encoder.py emotion_encoder_model.pth emotion_encoder_config.json dataset_config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") + parser.add_argument( + "config_path", + type=str, + help="Path to model config file.", + ) + + parser.add_argument( + "config_dataset_path", + type=str, + help="Path to dataset config file.", + ) + parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, help="flag to set cuda.", default=True) + parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True) + + args = parser.parse_args() + + c_dataset = load_config(args.config_dataset_path) + + meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_split=args.eval) + items = meta_data_train + meta_data_eval + + enc_manager = SpeakerManager( + encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda + ) + + compute_encoder_accuracy(items, enc_manager) diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py new file mode 100644 index 0000000000000000000000000000000000000000..86a4dce1776bc6a0461a17db2dcbce1f1061a963 --- /dev/null +++ b/TTS/bin/extract_tts_spectrograms.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +"""Extract Mel spectrograms with teacher forcing.""" + +import argparse +import logging +import os + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from trainer.generic_utils import count_parameters + +from TTS.config import load_config +from TTS.tts.datasets import TTSDataset, load_tts_samples +from TTS.tts.models import setup_model +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import quantize +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +use_cuda = torch.cuda.is_available() + + +def setup_loader(ap, r): + tokenizer, _ = TTSTokenizer.init_from_config(c) + dataset = TTSDataset( + outputs_per_step=r, + compute_linear_spec=False, + samples=meta_data, + tokenizer=tokenizer, + ap=ap, + batch_group_size=0, + min_text_len=c.min_text_len, + max_text_len=c.max_text_len, + min_audio_len=c.min_audio_len, + max_audio_len=c.max_audio_len, + phoneme_cache_path=c.phoneme_cache_path, + precompute_num_workers=0, + use_noise_augment=False, + speaker_id_mapping=speaker_manager.name_to_id if c.use_speaker_embedding else None, + d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None, + ) + + if c.use_phonemes and c.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. + dataset.compute_input_seq(c.num_loader_workers) + dataset.preprocess_samples() + + loader = DataLoader( + dataset, + batch_size=c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=None, + num_workers=c.num_loader_workers, + pin_memory=False, + ) + return loader + + +def set_filename(wav_path, out_path): + wav_file = os.path.basename(wav_path) + file_name = wav_file.split(".")[0] + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True) + os.makedirs(os.path.join(out_path, "wav"), exist_ok=True) + wavq_path = os.path.join(out_path, "quant", file_name) + mel_path = os.path.join(out_path, "mel", file_name) + wav_gl_path = os.path.join(out_path, "wav_gl", file_name + ".wav") + wav_path = os.path.join(out_path, "wav", file_name + ".wav") + return file_name, wavq_path, mel_path, wav_gl_path, wav_path + + +def format_data(data): + # setup input data + text_input = data["token_id"] + text_lengths = data["token_id_lengths"] + mel_input = data["mel"] + mel_lengths = data["mel_lengths"] + item_idx = data["item_idxs"] + d_vectors = data["d_vectors"] + speaker_ids = data["speaker_ids"] + attn_mask = data["attns"] + avg_text_length = torch.mean(text_lengths.float()) + avg_spec_length = torch.mean(mel_lengths.float()) + + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda(non_blocking=True) + text_lengths = text_lengths.cuda(non_blocking=True) + mel_input = mel_input.cuda(non_blocking=True) + mel_lengths = mel_lengths.cuda(non_blocking=True) + if speaker_ids is not None: + speaker_ids = speaker_ids.cuda(non_blocking=True) + if d_vectors is not None: + d_vectors = d_vectors.cuda(non_blocking=True) + if attn_mask is not None: + attn_mask = attn_mask.cuda(non_blocking=True) + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + avg_text_length, + avg_spec_length, + attn_mask, + item_idx, + ) + + +@torch.no_grad() +def inference( + model_name, + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=None, + d_vectors=None, +): + if model_name == "glow_tts": + speaker_c = None + if speaker_ids is not None: + speaker_c = speaker_ids + elif d_vectors is not None: + speaker_c = d_vectors + outputs = model.inference_with_MAS( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": speaker_c, "speaker_ids": speaker_ids}, + ) + model_output = outputs["model_outputs"] + model_output = model_output.detach().cpu().numpy() + + elif "tacotron" in model_name: + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input) + postnet_outputs = outputs["model_outputs"] + # normalize tacotron output + if model_name == "tacotron": + mel_specs = [] + postnet_outputs = postnet_outputs.data.cpu().numpy() + for b in range(postnet_outputs.shape[0]): + postnet_output = postnet_outputs[b] + mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T)) + model_output = torch.stack(mel_specs).cpu().numpy() + + elif model_name == "tacotron2": + model_output = postnet_outputs.detach().cpu().numpy() + return model_output + + +def extract_spectrograms( + data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt" +): + model.eval() + export_metadata = [] + for _, data in tqdm(enumerate(data_loader), total=len(data_loader)): + # format data + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + _, + _, + _, + item_idx, + ) = format_data(data) + + model_output = inference( + c.model.lower(), + model, + ap, + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids, + d_vectors, + ) + + for idx in range(text_input.shape[0]): + wav_file_path = item_idx[idx] + wav = ap.load_wav(wav_file_path) + _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) + + # quantize and save wav + if quantize_bits > 0: + wavq = quantize(wav, quantize_bits) + np.save(wavq_path, wavq) + + # save TTS mel + mel = model_output[idx] + mel_length = mel_lengths[idx] + mel = mel[:mel_length, :].T + np.save(mel_path, mel) + + export_metadata.append([wav_file_path, mel_path]) + if save_audio: + ap.save_wav(wav, wav_path) + + if debug: + print("Audio for debug saved at:", wav_gl_path) + wav = ap.inv_melspectrogram(mel) + ap.save_wav(wav, wav_gl_path) + + with open(os.path.join(output_path, metada_name), "w", encoding="utf-8") as f: + for data in export_metadata: + f.write(f"{data[0]}|{data[1]+'.npy'}\n") + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data, speaker_manager + + # Audio processor + ap = AudioProcessor(**c.audio) + + # load data instances + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + # use eval and training partitions + meta_data = meta_data_train + meta_data_eval + + # init speaker manager + if c.use_speaker_embedding: + speaker_manager = SpeakerManager(data_items=meta_data) + elif c.use_d_vector_file: + speaker_manager = SpeakerManager(d_vectors_file_path=c.d_vector_file) + else: + speaker_manager = None + + # setup model + model = setup_model(c) + + # restore model + model.load_checkpoint(c, args.checkpoint_path, eval=True) + + if use_cuda: + model.cuda() + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + # set r + r = 1 if c.model.lower() == "glow_tts" else model.decoder.r + own_loader = setup_loader(ap, r) + + extract_spectrograms( + own_loader, + model, + ap, + args.output_path, + quantize_bits=args.quantize_bits, + save_audio=args.save_audio, + debug=args.debug, + metada_name="metada.txt", + ) + + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", type=str, help="Path to config file for training.", required=True) + parser.add_argument("--checkpoint_path", type=str, help="Model file to be restored.", required=True) + parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) + parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") + parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") + parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero") + parser.add_argument("--eval", action=argparse.BooleanOptionalAction, help="compute eval.", default=True) + args = parser.parse_args() + + c = load_config(args.config_path) + c.audio.trim_silence = False + main(args) diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py new file mode 100644 index 0000000000000000000000000000000000000000..0519d43769561c1bfadcea506a35f015f94beafd --- /dev/null +++ b/TTS/bin/find_unique_chars.py @@ -0,0 +1,40 @@ +"""Find all the unique characters in a dataset""" + +import argparse +import logging +from argparse import RawTextHelpFormatter + +from TTS.config import load_config +from TTS.tts.datasets import find_unique_chars, load_tts_samples +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + + +def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_chars.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + + items = train_items + eval_items + find_unique_chars(items) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py new file mode 100644 index 0000000000000000000000000000000000000000..d99acb9893db3547aff356283366cf5c634095fb --- /dev/null +++ b/TTS/bin/find_unique_phonemes.py @@ -0,0 +1,79 @@ +"""Find all the unique characters in a dataset""" + +import argparse +import logging +import multiprocessing +from argparse import RawTextHelpFormatter + +from tqdm.contrib.concurrent import process_map + +from TTS.config import load_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.utils.text.phonemizers import Gruut +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + + +def compute_phonemes(item): + text = item["text"] + ph = phonemizer.phonemize(text).replace("|", "") + return set(ph) + + +def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + # pylint: disable=W0601 + global c, phonemizer + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """ + Example runs: + + python TTS/bin/find_unique_phonemes.py --config_path config.json + """, + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("--config_path", type=str, help="Path to dataset config file.", required=True) + args = parser.parse_args() + + c = load_config(args.config_path) + + # load all datasets + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) + items = train_items + eval_items + print("Num items:", len(items)) + + language_list = [item["language"] for item in items] + is_lang_def = all(language_list) + + if not c.phoneme_language or not is_lang_def: + raise ValueError("Phoneme language must be defined in config.") + + if not language_list.count(language_list[0]) == len(language_list): + raise ValueError( + "Currently, just one phoneme language per config file is supported !! Please split the dataset config into different configs and run it individually for each language !!" + ) + + phonemizer = Gruut(language=language_list[0], keep_puncs=True) + + phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15) + phones = [] + for ph in phonemes: + phones.extend(ph) + + phones = set(phones) + lower_phones = filter(lambda c: c.islower(), phones) + phones_force_lower = [c.lower() for c in phones] + phones_force_lower = set(phones_force_lower) + + print(f" > Number of unique phonemes: {len(phones)}") + print(f" > Unique phonemes: {''.join(sorted(phones))}") + print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}") + print(f" > Unique all forced to lower phonemes: {''.join(sorted(phones_force_lower))}") + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py new file mode 100644 index 0000000000000000000000000000000000000000..edab882db88a66cfc09ab7575fedb5dbc4294b8a --- /dev/null +++ b/TTS/bin/remove_silence_using_vad.py @@ -0,0 +1,128 @@ +import argparse +import glob +import logging +import multiprocessing +import os +import pathlib + +import torch +from tqdm import tqdm + +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger +from TTS.utils.vad import get_vad_model_and_utils, remove_silence + +torch.set_num_threads(1) + + +def adjust_path_and_remove_silence(audio_path): + output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) + # ignore if the file exists + if os.path.exists(output_path) and not args.force: + return output_path, False + + # create all directory structure + pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) + # remove the silence and save the audio + output_path, is_speech = remove_silence( + model_and_utils, + audio_path, + output_path, + trim_just_beginning_and_end=args.trim_just_beginning_and_end, + use_cuda=args.use_cuda, + ) + return output_path, is_speech + + +def preprocess_audios(): + files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True)) + print("> Number of files: ", len(files)) + if not args.force: + print("> Ignoring files that already exist in the output idrectory.") + + if args.trim_just_beginning_and_end: + print("> Trimming just the beginning and the end with nonspeech parts.") + else: + print("> Trimming all nonspeech parts.") + + filtered_files = [] + if files: + # create threads + # num_threads = multiprocessing.cpu_count() + # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15) + + if args.num_processes > 1: + with multiprocessing.Pool(processes=args.num_processes) as pool: + results = list( + tqdm( + pool.imap_unordered(adjust_path_and_remove_silence, files), + total=len(files), + desc="Processing audio files", + ) + ) + for output_path, is_speech in results: + if not is_speech: + filtered_files.append(output_path) + else: + for f in tqdm(files): + output_path, is_speech = adjust_path_and_remove_silence(f) + if not is_speech: + filtered_files.append(output_path) + + # write files that do not have speech + with open(os.path.join(args.output_dir, "filtered_files.txt"), "w", encoding="utf-8") as f: + for file in filtered_files: + f.write(str(file) + "\n") + else: + print("> No files Found !") + + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + parser = argparse.ArgumentParser( + description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end" + ) + parser.add_argument("-i", "--input_dir", type=str, help="Dataset root dir", required=True) + parser.add_argument("-o", "--output_dir", type=str, help="Output Dataset dir", default="") + parser.add_argument("-f", "--force", default=False, action="store_true", help="Force the replace of exists files") + parser.add_argument( + "-g", + "--glob", + type=str, + default="**/*.wav", + help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav", + ) + parser.add_argument( + "-t", + "--trim_just_beginning_and_end", + action=argparse.BooleanOptionalAction, + default=True, + help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trimmed.", + ) + parser.add_argument( + "-c", + "--use_cuda", + action=argparse.BooleanOptionalAction, + default=False, + help="If True use cuda", + ) + parser.add_argument( + "--use_onnx", + action=argparse.BooleanOptionalAction, + default=False, + help="If True use onnx", + ) + parser.add_argument( + "--num_processes", + type=int, + default=1, + help="Number of processes to use", + ) + args = parser.parse_args() + + if args.output_dir == "": + args.output_dir = args.input_dir + + # load the model and utils + model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda, use_onnx=args.use_onnx) + preprocess_audios() diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f28485d1fb235ab0d521ee30318c64b48fbd5a --- /dev/null +++ b/TTS/bin/resample.py @@ -0,0 +1,90 @@ +import argparse +import glob +import os +from argparse import RawTextHelpFormatter +from multiprocessing import Pool +from shutil import copytree + +import librosa +import soundfile as sf +from tqdm import tqdm + + +def resample_file(func_args): + filename, output_sr = func_args + y, sr = librosa.load(filename, sr=output_sr) + sf.write(filename, y, sr) + + +def resample_files(input_dir, output_sr, output_dir=None, file_ext="wav", n_jobs=10): + if output_dir: + print("Recursively copying the input folder...") + copytree(input_dir, output_dir) + input_dir = output_dir + + print("Resampling the audio files...") + audio_files = glob.glob(os.path.join(input_dir, f"**/*.{file_ext}"), recursive=True) + print(f"Found {len(audio_files)} files...") + audio_files = list(zip(audio_files, len(audio_files) * [output_sr])) + with Pool(processes=n_jobs) as p: + with tqdm(total=len(audio_files)) as pbar: + for _, _ in enumerate(p.imap_unordered(resample_file, audio_files)): + pbar.update() + + print("Done !") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""Resample a folder recusively with librosa + Can be used in place or create a copy of the folder as an output.\n\n + Example run: + python TTS/bin/resample.py + --input_dir /root/LJSpeech-1.1/ + --output_sr 22050 + --output_dir /root/resampled_LJSpeech-1.1/ + --file_ext wav + --n_jobs 24 + """, + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path of the folder containing the audio files to resample", + ) + + parser.add_argument( + "--output_sr", + type=int, + default=22050, + required=False, + help="Samlple rate to which the audio files should be resampled", + ) + + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=False, + help="Path of the destination folder. If not defined, the operation is done in place", + ) + + parser.add_argument( + "--file_ext", + type=str, + default="wav", + required=False, + help="Extension of the audio files to resample", + ) + + parser.add_argument( + "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" + ) + + args = parser.parse_args() + + resample_files(args.input_dir, args.output_sr, args.output_dir, args.file_ext, args.n_jobs) diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py new file mode 100644 index 0000000000000000000000000000000000000000..bc01ffd5957f9b167eac7175361947d21c329f9f --- /dev/null +++ b/TTS/bin/synthesize.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python3 + +"""Command line interface.""" + +import argparse +import contextlib +import logging +import sys +from argparse import RawTextHelpFormatter + +# pylint: disable=redefined-outer-name, unused-argument +from pathlib import Path + +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +logger = logging.getLogger(__name__) + +description = """ +Synthesize speech on command line. + +You can either use your trained model or choose a model from the provided list. + +If you don't specify any models, then it uses LJSpeech based English model. + +#### Single Speaker Models + +- List provided models: + + ``` + $ tts --list_models + ``` + +- Get model info (for both tts_models and vocoder_models): + + - Query by type/name: + The model_info_by_name uses the name as it from the --list_models. + ``` + $ tts --model_info_by_name "///" + ``` + For example: + ``` + $ tts --model_info_by_name tts_models/tr/common-voice/glow-tts + $ tts --model_info_by_name vocoder_models/en/ljspeech/hifigan_v2 + ``` + - Query by type/idx: + The model_query_idx uses the corresponding idx from --list_models. + + ``` + $ tts --model_info_by_idx "/" + ``` + + For example: + + ``` + $ tts --model_info_by_idx tts_models/3 + ``` + + - Query info for model info by full name: + ``` + $ tts --model_info_by_name "///" + ``` + +- Run TTS with default models: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav + ``` + +- Run TTS and pipe out the generated TTS wav file data: + + ``` + $ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay + ``` + +- Run a TTS model with its default vocoder model: + + ``` + $ tts --text "Text for TTS" --model_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --out_path output/path/speech.wav + ``` + +- Run with specific TTS and vocoder models from the list: + + ``` + $ tts --text "Text for TTS" --model_name "///" --vocoder_name "///" --out_path output/path/speech.wav + ``` + + For example: + + ``` + $ tts --text "Text for TTS" --model_name "tts_models/en/ljspeech/glow-tts" --vocoder_name "vocoder_models/en/ljspeech/univnet" --out_path output/path/speech.wav + ``` + +- Run your own TTS model (Using Griffin-Lim Vocoder): + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + ``` + +- Run your own TTS and Vocoder models: + + ``` + $ tts --text "Text for TTS" --model_path path/to/model.pth --config_path path/to/config.json --out_path output/path/speech.wav + --vocoder_path path/to/vocoder.pth --vocoder_config_path path/to/vocoder_config.json + ``` + +#### Multi-speaker Models + +- List the available speakers and choose a among them: + + ``` + $ tts --model_name "//" --list_speaker_idxs + ``` + +- Run the multi-speaker TTS model with the target speaker ID: + + ``` + $ tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "//" --speaker_idx + ``` + +- Run your own multi-speaker TTS model: + + ``` + $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/model.pth --config_path path/to/config.json --speakers_file_path path/to/speaker.json --speaker_idx + ``` + +### Voice Conversion Models + +``` +$ tts --out_path output/path/speech.wav --model_name "//" --source_wav --target_wav +``` +""" + + +def parse_args() -> argparse.Namespace: + """Parse arguments.""" + parser = argparse.ArgumentParser( + description=description.replace(" ```\n", ""), + formatter_class=RawTextHelpFormatter, + ) + + parser.add_argument( + "--list_models", + action="store_true", + help="list available pre-trained TTS and vocoder models.", + ) + + parser.add_argument( + "--model_info_by_idx", + type=str, + default=None, + help="model info using query format: /", + ) + + parser.add_argument( + "--model_info_by_name", + type=str, + default=None, + help="model info using query format: ///", + ) + + parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") + + # Args for running pre-trained TTS models. + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained TTS models in format //", + ) + parser.add_argument( + "--vocoder_name", + type=str, + default=None, + help="Name of one of the pre-trained vocoder models in format //", + ) + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--out_path", + type=str, + default="tts_output.wav", + help="Output wav file path.", + ) + parser.add_argument("--use_cuda", action="store_true", help="Run model on CUDA.") + parser.add_argument("--device", type=str, help="Device to run model on.", default="cpu") + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument( + "--encoder_path", + type=str, + help="Path to speaker encoder model file.", + default=None, + ) + parser.add_argument("--encoder_config_path", type=str, help="Path to speaker encoder config file.", default=None) + parser.add_argument( + "--pipe_out", + help="stdout the generated TTS wav file for shell pipe.", + action="store_true", + ) + + # args for multi-speaker synthesis + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None) + parser.add_argument( + "--speaker_idx", + type=str, + help="Target speaker ID for a multi-speaker TTS model.", + default=None, + ) + parser.add_argument( + "--language_idx", + type=str, + help="Target language ID for a multi-lingual TTS model.", + default=None, + ) + parser.add_argument( + "--speaker_wav", + nargs="+", + help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.", + default=None, + ) + parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None) + parser.add_argument( + "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None + ) + parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None) + parser.add_argument( + "--list_speaker_idxs", + help="List available speaker ids for the defined multi-speaker model.", + action="store_true", + ) + parser.add_argument( + "--list_language_idxs", + help="List available language ids for the defined multi-lingual model.", + action="store_true", + ) + # aux args + parser.add_argument( + "--save_spectogram", + action="store_true", + help="Save raw spectogram for further (vocoder) processing in out_path.", + ) + parser.add_argument( + "--reference_wav", + type=str, + help="Reference wav file to convert in the voice of the speaker_idx or speaker_wav", + default=None, + ) + parser.add_argument( + "--reference_speaker_idx", + type=str, + help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).", + default=None, + ) + parser.add_argument( + "--progress_bar", + action=argparse.BooleanOptionalAction, + help="Show a progress bar for the model download.", + default=True, + ) + + # voice conversion args + parser.add_argument( + "--source_wav", + type=str, + default=None, + help="Original audio file to convert in the voice of the target_wav", + ) + parser.add_argument( + "--target_wav", + type=str, + default=None, + help="Target audio file to convert in the voice of the source_wav", + ) + + parser.add_argument( + "--voice_dir", + type=str, + default=None, + help="Voice dir for tortoise model", + ) + + args = parser.parse_args() + + # print the description if either text or list_models is not set + check_args = [ + args.text, + args.list_models, + args.list_speaker_idxs, + args.list_language_idxs, + args.reference_wav, + args.model_info_by_idx, + args.model_info_by_name, + args.source_wav, + args.target_wav, + ] + if not any(check_args): + parser.parse_args(["-h"]) + return args + + +def main(): + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + args = parse_args() + + pipe_out = sys.stdout if args.pipe_out else None + + with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout): + # Late-import to make things load faster + from TTS.utils.manage import ModelManager + from TTS.utils.synthesizer import Synthesizer + + # load model manager + path = Path(__file__).parent / "../.models.json" + manager = ModelManager(path, progress_bar=args.progress_bar) + + tts_path = None + tts_config_path = None + speakers_file_path = None + language_ids_file_path = None + vocoder_path = None + vocoder_config_path = None + encoder_path = None + encoder_config_path = None + vc_path = None + vc_config_path = None + model_dir = None + + # CASE1 #list : list pre-trained TTS models + if args.list_models: + manager.list_models() + sys.exit() + + # CASE2 #info : model info for pre-trained TTS models + if args.model_info_by_idx: + model_query = args.model_info_by_idx + manager.model_info_by_idx(model_query) + sys.exit() + + if args.model_info_by_name: + model_query_full_name = args.model_info_by_name + manager.model_info_by_full_name(model_query_full_name) + sys.exit() + + # CASE3: load pre-trained model paths + if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + # tts model + if model_item["model_type"] == "tts_models": + tts_path = model_path + tts_config_path = config_path + if args.vocoder_name is None and "default_vocoder" in model_item: + args.vocoder_name = model_item["default_vocoder"] + + # voice conversion model + if model_item["model_type"] == "voice_conversion_models": + vc_path = model_path + vc_config_path = config_path + + # tts model with multiple files to be loaded from the directory path + if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list): + model_dir = model_path + tts_path = None + tts_config_path = None + args.vocoder_name = None + + # load vocoder + if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + + # CASE4: set custom model paths + if args.model_path is not None: + tts_path = args.model_path + tts_config_path = args.config_path + speakers_file_path = args.speakers_file_path + language_ids_file_path = args.language_ids_file_path + + if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + + if args.encoder_path is not None: + encoder_path = args.encoder_path + encoder_config_path = args.encoder_config_path + + device = args.device + if args.use_cuda: + device = "cuda" + + # load models + synthesizer = Synthesizer( + tts_path, + tts_config_path, + speakers_file_path, + language_ids_file_path, + vocoder_path, + vocoder_config_path, + encoder_path, + encoder_config_path, + vc_path, + vc_config_path, + model_dir, + args.voice_dir, + ).to(device) + + # query speaker ids of a multi-speaker model. + if args.list_speaker_idxs: + if synthesizer.tts_model.speaker_manager is None: + logger.info("Model only has a single speaker.") + return + logger.info( + "Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model." + ) + logger.info(synthesizer.tts_model.speaker_manager.name_to_id) + return + + # query langauge ids of a multi-lingual model. + if args.list_language_idxs: + if synthesizer.tts_model.language_manager is None: + logger.info("Monolingual model.") + return + logger.info( + "Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model." + ) + logger.info(synthesizer.tts_model.language_manager.name_to_id) + return + + # check the arguments against a multi-speaker model. + if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav): + logger.error( + "Looks like you use a multi-speaker model. Define `--speaker_idx` to " + "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`." + ) + return + + # RUN THE SYNTHESIS + if args.text: + logger.info("Text: %s", args.text) + + # kick it + if tts_path is not None: + wav = synthesizer.tts( + args.text, + speaker_name=args.speaker_idx, + language_name=args.language_idx, + speaker_wav=args.speaker_wav, + reference_wav=args.reference_wav, + style_wav=args.capacitron_style_wav, + style_text=args.capacitron_style_text, + reference_speaker_name=args.reference_speaker_idx, + ) + elif vc_path is not None: + wav = synthesizer.voice_conversion( + source_wav=args.source_wav, + target_wav=args.target_wav, + ) + elif model_dir is not None: + wav = synthesizer.tts( + args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav + ) + + # save the results + synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out) + logger.info("Saved output to %s", args.out_path) + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ba03c42b6d0c54f3afe076db44a25aed2bd9cbba --- /dev/null +++ b/TTS/bin/train_encoder.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import logging +import os +import sys +import time +import traceback +import warnings + +import torch +from torch.utils.data import DataLoader +from trainer.generic_utils import count_parameters, remove_experiment_folder +from trainer.io import copy_model_files, save_best_model, save_checkpoint +from trainer.torch import NoamLR +from trainer.trainer_utils import get_optimizer + +from TTS.encoder.dataset import EncoderDataset +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.encoder.utils.training import init_training +from TTS.encoder.utils.visual import plot_embeddings +from TTS.tts.datasets import load_tts_samples +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger +from TTS.utils.samplers import PerfectBatchSampler +from TTS.utils.training import check_update + +torch.backends.cudnn.enabled = True +torch.backends.cudnn.benchmark = True +torch.manual_seed(54321) +use_cuda = torch.cuda.is_available() +num_gpus = torch.cuda.device_count() +print(" > Using CUDA: ", use_cuda) +print(" > Number of GPUs: ", num_gpus) + + +def setup_loader(ap: AudioProcessor, is_val: bool = False): + num_utter_per_class = c.num_utter_per_class if not is_val else c.eval_num_utter_per_class + num_classes_in_batch = c.num_classes_in_batch if not is_val else c.eval_num_classes_in_batch + + dataset = EncoderDataset( + c, + ap, + meta_data_eval if is_val else meta_data_train, + voice_len=c.voice_len, + num_utter_per_class=num_utter_per_class, + num_classes_in_batch=num_classes_in_batch, + augmentation_config=c.audio_augmentation if not is_val else None, + use_torch_spec=c.model_params.get("use_torch_spec", False), + ) + # get classes list + classes = dataset.get_class_list() + + sampler = PerfectBatchSampler( + dataset.items, + classes, + batch_size=num_classes_in_batch * num_utter_per_class, # total batch size + num_classes_in_batch=num_classes_in_batch, + num_gpus=1, + shuffle=not is_val, + drop_last=True, + ) + + if len(classes) < num_classes_in_batch: + if is_val: + raise RuntimeError( + f"config.eval_num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Eval dataset) !" + ) + raise RuntimeError( + f"config.num_classes_in_batch ({num_classes_in_batch}) need to be <= {len(classes)} (Number total of Classes in the Train dataset) !" + ) + + # set the classes to avoid get wrong class_id when the number of training and eval classes are not equal + if is_val: + dataset.set_classes(train_classes) + + loader = DataLoader( + dataset, + num_workers=c.num_loader_workers, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + ) + + return loader, classes, dataset.get_map_classid_to_classname() + + +def evaluation(model, criterion, data_loader, global_step): + eval_loss = 0 + for _, data in enumerate(data_loader): + with torch.no_grad(): + # setup input data + inputs, labels = data + + # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose( + labels.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch), 0, 1 + ).reshape(labels.shape) + inputs = torch.transpose( + inputs.view(c.eval_num_utter_per_class, c.eval_num_classes_in_batch, -1), 0, 1 + ).reshape(inputs.shape) + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.eval_num_classes_in_batch, outputs.shape[0] // c.eval_num_classes_in_batch, -1), labels + ) + + eval_loss += loss.item() + + eval_avg_loss = eval_loss / len(data_loader) + # save stats + dashboard_logger.eval_stats(global_step, {"loss": eval_avg_loss}) + try: + # plot the last batch in the evaluation + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.eval_figures(global_step, figures) + except ImportError: + warnings.warn("Install the `umap-learn` package to see embedding plots.") + return eval_avg_loss + + +def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step): + model.train() + best_loss = {"train_loss": None, "eval_loss": float("inf")} + avg_loader_time = 0 + end_time = time.time() + for epoch in range(c.epochs): + tot_loss = 0 + epoch_time = 0 + for _, data in enumerate(data_loader): + start_time = time.time() + + # setup input data + inputs, labels = data + # agroup samples of each class in the batch. perfect sampler produces [3,2,1,3,2,1] we need [3,3,2,2,1,1] + labels = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape( + labels.shape + ) + inputs = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape( + inputs.shape + ) + # ToDo: move it to a unit test + # labels_converted = torch.transpose(labels.view(c.num_utter_per_class, c.num_classes_in_batch), 0, 1).reshape(labels.shape) + # inputs_converted = torch.transpose(inputs.view(c.num_utter_per_class, c.num_classes_in_batch, -1), 0, 1).reshape(inputs.shape) + # idx = 0 + # for j in range(0, c.num_classes_in_batch, 1): + # for i in range(j, len(labels), c.num_classes_in_batch): + # if not torch.all(labels[i].eq(labels_converted[idx])) or not torch.all(inputs[i].eq(inputs_converted[idx])): + # print("Invalid") + # print(labels) + # exit() + # idx += 1 + # labels = labels_converted + # inputs = inputs_converted + + loader_time = time.time() - end_time + global_step += 1 + + optimizer.zero_grad() + + # dispatch data to GPU + if use_cuda: + inputs = inputs.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + # forward pass model + outputs = model(inputs) + + # loss computation + loss = criterion( + outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels + ) + loss.backward() + grad_norm, _ = check_update(model, c.grad_clip) + optimizer.step() + + # setup lr + if c.lr_decay: + scheduler.step() + + step_time = time.time() - start_time + epoch_time += step_time + + # acumulate the total epoch loss + tot_loss += loss.item() + + # Averaged Loader Time + num_loader_workers = c.num_loader_workers if c.num_loader_workers > 0 else 1 + avg_loader_time = ( + 1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time + if avg_loader_time != 0 + else loader_time + ) + current_lr = optimizer.param_groups[0]["lr"] + + if global_step % c.steps_plot_stats == 0: + # Plot Training Epoch Stats + train_stats = { + "loss": loss.item(), + "lr": current_lr, + "grad_norm": grad_norm, + "step_time": step_time, + "avg_loader_time": avg_loader_time, + } + dashboard_logger.train_epoch_stats(global_step, train_stats) + figures = { + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), c.num_classes_in_batch), + } + dashboard_logger.train_figures(global_step, figures) + + if global_step % c.print_step == 0: + print( + " | > Step:{} Loss:{:.5f} GradNorm:{:.5f} " + "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( + global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr + ), + flush=True, + ) + + if global_step % c.save_step == 0: + # save model + save_checkpoint( + c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict() + ) + + end_time = time.time() + + print("") + print( + ">>> Epoch:{} AvgLoss: {:.5f} GradNorm:{:.5f} " + "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format( + epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time + ), + flush=True, + ) + # evaluation + if c.run_eval: + model.eval() + eval_loss = evaluation(model, criterion, eval_data_loader, global_step) + print("\n\n") + print("--> EVAL PERFORMANCE") + print( + " | > Epoch:{} AvgLoss: {:.5f} ".format(epoch, eval_loss), + flush=True, + ) + # save the best checkpoint + best_loss = save_best_model( + {"train_loss": None, "eval_loss": eval_loss}, + best_loss, + c, + model, + optimizer, + None, + global_step, + epoch, + OUT_PATH, + criterion=criterion.state_dict(), + ) + model.train() + + return best_loss, global_step + + +def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data_train + global meta_data_eval + global train_classes + + ap = AudioProcessor(**c.audio) + model = setup_encoder_model(c) + + optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model) + + # pylint: disable=redefined-outer-name + meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=True) + + train_data_loader, train_classes, map_classid_to_classname = setup_loader(ap, is_val=False) + if c.run_eval: + eval_data_loader, _, _ = setup_loader(ap, is_val=True) + else: + eval_data_loader = None + + num_classes = len(train_classes) + criterion = model.get_criterion(c, num_classes) + + if c.loss == "softmaxproto" and c.model != "speaker_encoder": + c.map_classid_to_classname = map_classid_to_classname + copy_model_files(c, OUT_PATH, new_fields={}) + + if args.restore_path: + criterion, args.restore_step = model.load_checkpoint( + c, args.restore_path, eval=False, use_cuda=use_cuda, criterion=criterion + ) + print(" > Model restored from step %d" % args.restore_step, flush=True) + else: + args.restore_step = 0 + + if c.lr_decay: + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) + else: + scheduler = None + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + + if use_cuda: + model = model.cuda() + criterion.cuda() + + global_step = args.restore_step + _, global_step = train(model, optimizer, scheduler, criterion, train_data_loader, eval_data_loader, global_step) + + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + args, c, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = init_training() + + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..6d6342a762a202eb39452a996d2b55fc7f961c2e --- /dev/null +++ b/TTS/bin/train_tts.py @@ -0,0 +1,75 @@ +import logging +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models import setup_model +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + + +@dataclass +class TrainTTSArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + # init trainer args + train_args = TrainTTSArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the model from config + model = setup_model(config, train_samples + eval_samples) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + model.config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..221ff4cff05541ef128d24ce24f5d148f3626598 --- /dev/null +++ b/TTS/bin/train_vocoder.py @@ -0,0 +1,81 @@ +import logging +import os +from dataclasses import dataclass, field + +from trainer import Trainer, TrainerArgs + +from TTS.config import load_config, register_config +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.models import setup_model + + +@dataclass +class TrainVocoderArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def main(): + """Run `tts` model training directly by a `config.json` file.""" + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + # init trainer args + train_args = TrainVocoderArgs() + parser = train_args.init_argparse(arg_prefix="") + + # override trainer args from comman-line args + args, config_overrides = parser.parse_known_args() + train_args.parse_args(args) + + # load config.json and register + if args.config_path or args.continue_path: + if args.config_path: + # init from a file + config = load_config(args.config_path) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + elif args.continue_path: + # continue from a prev experiment + config = load_config(os.path.join(args.continue_path, "config.json")) + if len(config_overrides) > 0: + config.parse_known_args(config_overrides, relaxed_parser=True) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(config_overrides) + config = register_config(config_base.model)() + + # load training samples + if "feature_path" in config and config.feature_path: + # load pre-computed features + print(f" > Loading features from: {config.feature_path}") + eval_samples, train_samples = load_wav_feat_data(config.data_path, config.feature_path, config.eval_split_size) + else: + # load data raw wav files + eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # init the model from config + model = setup_model(config) + + # init the trainer and 🚀 + trainer = Trainer( + train_args, + config, + config.output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, + parse_command_line_args=False, + ) + trainer.fit() + + +if __name__ == "__main__": + main() diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..df2923952d92a393a89c4019aa17166078bd072b --- /dev/null +++ b/TTS/bin/tune_wavegrad.py @@ -0,0 +1,107 @@ +"""Search a good noise schedule for WaveGrad for a given number of inference iterations""" + +import argparse +import logging +from itertools import product as cartesian_product + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from TTS.config import load_config +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger +from TTS.vocoder.datasets.preprocess import load_wav_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.models import setup_model + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") + parser.add_argument("--config_path", type=str, help="Path to model config file.") + parser.add_argument("--data_path", type=str, help="Path to data directory.") + parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") + parser.add_argument( + "--num_iter", + type=int, + help="Number of model inference iterations that you like to optimize noise schedule for.", + ) + parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.") + parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") + parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. Increasing this increases the run-time exponentially.", + ) + + # load config + args = parser.parse_args() + config = load_config(args.config_path) + + # setup audio processor + ap = AudioProcessor(**config.audio) + + # load dataset + _, train_data = load_wav_data(args.data_path, 0) + train_data = train_data[: args.num_samples] + dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + ) + loader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collate_full_clips, + drop_last=False, + num_workers=config.num_loader_workers, + pin_memory=False, + ) + + # setup the model + model = setup_model(config) + if args.use_cuda: + model.cuda() + + # setup optimization parameters + base_values = sorted(10 * np.random.uniform(size=args.search_depth)) + print(f" > base values: {base_values}") + exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) + best_error = float("inf") + best_schedule = None # pylint: disable=C0103 + total_search_iter = len(base_values) ** args.num_iter + for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): + beta = exponents * base + model.compute_noise_level(beta) + for data in loader: + mel, audio = data + y_hat = model.inference(mel.cuda() if args.use_cuda else mel) + + if args.use_cuda: + y_hat = y_hat.cpu() + y_hat = y_hat.numpy() + + mel_hat = [] + for i in range(y_hat.shape[0]): + m = ap.melspectrogram(y_hat[i, 0])[:, :-1] + mel_hat.append(torch.from_numpy(m)) + + mel_hat = torch.stack(mel_hat) + mse = torch.sum((mel - mel_hat) ** 2).mean() + if mse.item() < best_error: + best_error = mse.item() + best_schedule = {"beta": beta} + print(f" > Found a better schedule. - MSE: {mse.item()}") + np.save(args.output_path, best_schedule) diff --git a/TTS/config/__init__.py b/TTS/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5103f200b0752e83b8a012d4cb192641811fb649 --- /dev/null +++ b/TTS/config/__init__.py @@ -0,0 +1,138 @@ +import json +import os +import re +from typing import Dict + +import fsspec +import yaml +from coqpit import Coqpit + +from TTS.config.shared_configs import * +from TTS.utils.generic_utils import find_module + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments but not urls with // + input_str = re.sub( + r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str + ) + return json.loads(input_str) + + +def register_config(model_name: str) -> Coqpit: + """Find the right config for the given model name. + + Args: + model_name (str): Model name. + + Raises: + ModuleNotFoundError: No matching config for the model name. + + Returns: + Coqpit: config class. + """ + config_class = None + config_name = model_name + "_config" + + # TODO: fix this + if model_name == "xtts": + from TTS.tts.configs.xtts_config import XttsConfig + + config_class = XttsConfig + paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"] + for path in paths: + try: + config_class = find_module(path, config_name) + except ModuleNotFoundError: + pass + if config_class is None: + raise ModuleNotFoundError(f" [!] Config for {model_name} cannot be found.") + return config_class + + +def _process_model_name(config_dict: Dict) -> str: + """Format the model name as expected. It is a band-aid for the old `vocoder` model names. + + Args: + config_dict (Dict): A dictionary including the config fields. + + Returns: + str: Formatted modelname. + """ + model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"] + model_name = model_name.replace("_generator", "").replace("_discriminator", "") + return model_name + + +def load_config(config_path: str) -> Coqpit: + """Import `json` or `yaml` files as TTS configs. First, load the input file as a `dict` and check the model name + to find the corresponding Config class. Then initialize the Config. + + Args: + config_path (str): path to the config file. + + Raises: + TypeError: given config file has an unknown type. + + Returns: + Coqpit: TTS config object. + """ + config_dict = {} + ext = os.path.splitext(config_path)[1] + if ext in (".yml", ".yaml"): + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + elif ext == ".json": + try: + with fsspec.open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(config_path) + else: + raise TypeError(f" [!] Unknown config file type {ext}") + config_dict.update(data) + model_name = _process_model_name(config_dict) + config_class = register_config(model_name.lower()) + config = config_class() + config.from_dict(config_dict) + return config + + +def check_config_and_model_args(config, arg_name, value): + """Check the give argument in `config.model_args` if exist or in `config` for + the given value. + + Return False if the argument does not exist in `config.model_args` or `config`. + This is to patch up the compatibility between models with and without `model_args`. + + TODO: Remove this in the future with a unified approach. + """ + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] == value + if hasattr(config, arg_name): + return config[arg_name] == value + return False + + +def get_from_config_or_model_args(config, arg_name): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + return config[arg_name] + + +def get_from_config_or_model_args_with_default(config, arg_name, def_val): + """Get the given argument from `config.model_args` if exist or in `config`.""" + if hasattr(config, "model_args"): + if arg_name in config.model_args: + return config.model_args[arg_name] + if hasattr(config, arg_name): + return config[arg_name] + return def_val diff --git a/TTS/config/shared_configs.py b/TTS/config/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..7fae77d61361eff8c8fa521a0f4a90dc46f63c75 --- /dev/null +++ b/TTS/config/shared_configs.py @@ -0,0 +1,268 @@ +from dataclasses import asdict, dataclass +from typing import List + +from coqpit import Coqpit, check_argument +from trainer import TrainerConfig + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. + + frame_shift_ms (int): + Set ```hop_length``` based on milliseconds and sampling rate. + + frame_length_ms (int): + Set ```win_length``` based on milliseconds and sampling rate. + + stft_pad_mode (str): + Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + + sample_rate (int): + Audio sampling rate. Defaults to 22050. + + resample (bool): + Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + + preemphasis (float): + Preemphasis coefficient. Defaults to 0.0. + + ref_level_db (int): 20 + Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. + Defaults to 20. + + do_sound_norm (bool): + Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + + do_trim_silence (bool): + Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + pitch_fmax (float, optional): + Maximum frequency of the F0 frames. Defaults to ```640```. + + pitch_fmin (float, optional): + Minimum frequency of the F0 frames. Defaults to ```1```. + + trim_db (int): + Silence threshold used for silence trimming. Defaults to 45. + + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + + power (float): + Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the + artifacts in the synthesized voice. Defaults to 1.5. + + griffin_lim_iters (int): + Number of Griffing Lim iterations. Defaults to 60. + + num_mels (int): + Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. + It needs to be adjusted for a dataset. Defaults to 0. + + mel_fmax (float): + Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + + spec_gain (int): + Gain applied when converting amplitude to DB. Defaults to 20. + + signal_norm (bool): + enable/disable signal normalization. Defaults to True. + + min_level_db (int): + minimum db threshold for the computed melspectrograms. Defaults to -100. + + symmetric_norm (bool): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else + [0, k], Defaults to True. + + max_norm (float): + ```k``` defining the normalization range. Defaults to 4.0. + + clip_norm (bool): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + stats_path (str): + Path to the computed stats file. Defaults to None. + """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # rms volume normalization + do_rms_norm: bool = False + db_level: float = None + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # f0 params + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + formatter (str): + Formatter name that defines used formatter in ```TTS.tts.datasets.formatter```. Defaults to `""`. + + dataset_name (str): + Unique name for the dataset. Defaults to `""`. + + path (str): + Root path to the dataset files. Defaults to `""`. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to `""`. + + ignored_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + language (str): + Language code of the dataset. If defined, it overrides `phoneme_language`. Defaults to `""`. + + phonemizer (str): + Phonemizer used for that dataset's language. By default it uses `DEF_LANG_TO_PHONEMIZER`. Defaults to `""`. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. + + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. + """ + + formatter: str = "" + dataset_name: str = "" + path: str = "" + meta_file_train: str = "" + ignored_speakers: List[str] = None + language: str = "" + phonemizer: str = "" + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("formatter", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(TrainerConfig): + """Base config to define the basic 🐸TTS training parameters that are shared + among all the models. It is based on ```Trainer.TrainingConfig```. + + Args: + model (str): + Name of the model that is used in the training. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + """ + + model: str = None + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False diff --git a/TTS/demos/xtts_ft_demo/requirements.txt b/TTS/demos/xtts_ft_demo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b58f41c546036999f2a9c47d91dd217c7d3580f0 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/requirements.txt @@ -0,0 +1,2 @@ +faster_whisper==0.9.0 +gradio==4.7.1 diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..40e8b8ed32f2ff30e328b45b64257d64946170c7 --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -0,0 +1,161 @@ +import gc +import os + +import pandas +import torch +import torchaudio +from faster_whisper import WhisperModel +from tqdm import tqdm + +# torch.set_num_threads(1) +from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners + +torch.set_num_threads(16) + +audio_types = (".wav", ".mp3", ".flac") + + +def list_audios(basePath, contains=None): + # return the set of files that are valid + return list_files(basePath, validExts=audio_types, contains=contains) + + +def list_files(basePath, validExts=None, contains=None): + # loop over the directory structure + for rootDir, dirNames, filenames in os.walk(basePath): + # loop over the filenames in the current directory + for filename in filenames: + # if the contains string is not none and the filename does not contain + # the supplied string, then ignore the file + if contains is not None and filename.find(contains) == -1: + continue + + # determine the file extension of the current file + ext = filename[filename.rfind(".") :].lower() + + # check to see if the file is an audio and should be processed + if validExts is None or ext.endswith(validExts): + # construct the path to the audio and yield it + audioPath = os.path.join(rootDir, filename) + yield audioPath + + +def format_audio_list( + audio_files, + target_language="en", + out_path=None, + buffer=0.2, + eval_percentage=0.15, + speaker_name="coqui", + gradio_progress=None, +): + audio_total_size = 0 + # make sure that ooutput file exists + os.makedirs(out_path, exist_ok=True) + + # Loading Whisper + device = "cuda" if torch.cuda.is_available() else "cpu" + + print("Loading Whisper Model!") + asr_model = WhisperModel("large-v2", device=device, compute_type="float16") + + metadata = {"audio_file": [], "text": [], "speaker_name": []} + + if gradio_progress is not None: + tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...") + else: + tqdm_object = tqdm(audio_files) + + for audio_path in tqdm_object: + wav, sr = torchaudio.load(audio_path) + # stereo to mono if needed + if wav.size(0) != 1: + wav = torch.mean(wav, dim=0, keepdim=True) + + wav = wav.squeeze() + audio_total_size += wav.size(-1) / sr + + segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) + segments = list(segments) + i = 0 + sentence = "" + sentence_start = None + first_word = True + # added all segments words in a unique list + words_list = [] + for _, segment in enumerate(segments): + words = list(segment.words) + words_list.extend(words) + + # process each word + for word_idx, word in enumerate(words_list): + if first_word: + sentence_start = word.start + # If it is the first sentence, add buffer or get the begining of the file + if word_idx == 0: + sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start + else: + # get previous sentence end + previous_word_end = words_list[word_idx - 1].end + # add buffer or get the silence midle between the previous sentence and the current one + sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2) + + sentence = word.word + first_word = False + else: + sentence += word.word + + if word.word[-1] in ["!", ".", "?"]: + sentence = sentence[1:] + # Expand number and abbreviations plus normalization + sentence = multilingual_cleaners(sentence, target_language) + audio_file_name, _ = os.path.splitext(os.path.basename(audio_path)) + + audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav" + + # Check for the next word's existence + if word_idx + 1 < len(words_list): + next_word_start = words_list[word_idx + 1].start + else: + # If don't have more words it means that it is the last sentence then use the audio len as next word start + next_word_start = (wav.shape[0] - 1) / sr + + # Average the current word end and next word start + word_end = min((word.end + next_word_start) / 2, word.end + buffer) + + absoulte_path = os.path.join(out_path, audio_file) + os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) + i += 1 + first_word = True + + audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0) + # if the audio is too short ignore it (i.e < 0.33 seconds) + if audio.size(-1) >= sr / 3: + torchaudio.save(absoulte_path, audio, sr) + else: + continue + + metadata["audio_file"].append(audio_file) + metadata["text"].append(sentence) + metadata["speaker_name"].append(speaker_name) + + df = pandas.DataFrame(metadata) + df = df.sample(frac=1) + num_val_samples = int(len(df) * eval_percentage) + + df_eval = df[:num_val_samples] + df_train = df[num_val_samples:] + + df_train = df_train.sort_values("audio_file") + train_metadata_path = os.path.join(out_path, "metadata_train.csv") + df_train.to_csv(train_metadata_path, sep="|", index=False) + + eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") + df_eval = df_eval.sort_values("audio_file") + df_eval.to_csv(eval_metadata_path, sep="|", index=False) + + # deallocate VRAM and RAM + del asr_model, df_train, df_eval, df, metadata + gc.collect() + + return train_metadata_path, eval_metadata_path, audio_total_size diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py new file mode 100644 index 0000000000000000000000000000000000000000..7b41966b8f6b5c87849f0892476449b5ae7ac5ac --- /dev/null +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -0,0 +1,171 @@ +import gc +import os + +from trainer import Trainer, TrainerArgs + +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig +from TTS.utils.manage import ModelManager + + +def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995): + # Logging parameters + RUN_NAME = "GPT_XTTS_FT" + PROJECT_NAME = "XTTS_trainer" + DASHBOARD_LOGGER = "tensorboard" + LOGGER_URI = None + + # Set here the path that the checkpoints will be saved. Default: ./run/training/ + OUT_PATH = os.path.join(output_path, "run", "training") + + # Training Parameters + OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False + START_WITH_EVAL = False # if True it will star with evaluation + BATCH_SIZE = batch_size # set here the batch size + GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps + + # Define here the dataset that you want to use for the fine-tuning on. + config_dataset = BaseDatasetConfig( + formatter="coqui", + dataset_name="ft_dataset", + path=os.path.dirname(train_csv), + meta_file_train=train_csv, + meta_file_val=eval_csv, + language=language, + ) + + # Add here the configs of the datasets + DATASETS_CONFIG_LIST = [config_dataset] + + # Define the path where XTTS v2.0.1 files will be downloaded + CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") + os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) + + # DVAE files + DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" + MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" + + # Set the path to the downloaded files + DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK)) + MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK)) + + # download DVAE files if needed + if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): + print(" > Downloading DVAE files!") + ModelManager._download_model_files( + [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) + + # Download XTTS v2.0 checkpoint if needed + TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" + XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth" + XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json" + + # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. + TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file + XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file + XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file + + # download XTTS v2.0 files if needed + if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): + print(" > Downloading XTTS v2.0 files!") + ModelManager._download_model_files( + [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) + + # init args and config + model_args = GPTArgs( + max_conditioning_length=132300, # 6 secs + min_conditioning_length=66150, # 3 secs + debug_loading_failures=False, + max_wav_length=max_audio_length, # ~11.6 seconds + max_text_length=200, + mel_norm_file=MEL_NORM_FILE, + dvae_checkpoint=DVAE_CHECKPOINT, + xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune + tokenizer_file=TOKENIZER_FILE, + gpt_num_audio_tokens=1026, + gpt_start_audio_token=1024, + gpt_stop_audio_token=1025, + gpt_use_masking_gt_prompt_approach=True, + gpt_use_perceiver_resampler=True, + ) + # define audio config + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) + # training parameters config + config = GPTTrainerConfig( + epochs=num_epochs, + output_path=OUT_PATH, + model_args=model_args, + run_name=RUN_NAME, + project_name=PROJECT_NAME, + run_description=""" + GPT XTTS training + """, + dashboard_logger=DASHBOARD_LOGGER, + logger_uri=LOGGER_URI, + audio=audio_config, + batch_size=BATCH_SIZE, + batch_group_size=48, + eval_batch_size=BATCH_SIZE, + num_loader_workers=8, + eval_split_max_size=256, + print_step=50, + plot_step=100, + log_model_step=100, + save_step=1000, + save_n_checkpoints=1, + save_checkpoints=True, + # target_loss="loss", + print_eval=False, + # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. + optimizer="AdamW", + optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS, + optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, + lr=5e-06, # learning rate + lr_scheduler="MultiStepLR", + # it was adjusted accordly for the new step scheme + lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, + test_sentences=[], + ) + + # init the model from config + model = GPTTrainer.init_from_config(config) + + # load training samples + train_samples, eval_samples = load_tts_samples( + DATASETS_CONFIG_LIST, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) + + # init the trainer and 🚀 + trainer = Trainer( + TrainerArgs( + restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter + skip_train_epoch=False, + start_with_eval=START_WITH_EVAL, + grad_accum_steps=GRAD_ACUMM_STEPS, + ), + config, + output_path=OUT_PATH, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + ) + trainer.fit() + + # get the longest text audio file to use as speaker reference + samples_len = [len(item["text"].split(" ")) for item in train_samples] + longest_text_idx = samples_len.index(max(samples_len)) + speaker_ref = train_samples[longest_text_idx]["audio_file"] + + trainer_out_path = trainer.output_path + + # deallocate VRAM and RAM + del model, trainer, train_samples, eval_samples + gc.collect() + + return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac38ed6ee870bb548d1cf1960e72e9c1268f6ba --- /dev/null +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -0,0 +1,433 @@ +import argparse +import logging +import os +import sys +import tempfile +import traceback + +import gradio as gr +import torch +import torchaudio + +from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list +from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.models.xtts import Xtts + + +def clear_gpu_cache(): + # clear the GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +XTTS_MODEL = None + + +def load_model(xtts_checkpoint, xtts_config, xtts_vocab): + global XTTS_MODEL + clear_gpu_cache() + if not xtts_checkpoint or not xtts_config or not xtts_vocab: + return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!" + config = XttsConfig() + config.load_json(xtts_config) + XTTS_MODEL = Xtts.init_from_config(config) + print("Loading XTTS model! ") + XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False) + if torch.cuda.is_available(): + XTTS_MODEL.cuda() + + print("Model Loaded!") + return "Model Loaded!" + + +def run_tts(lang, tts_text, speaker_audio_file): + if XTTS_MODEL is None or not speaker_audio_file: + return "You need to run the previous step to load the model !!", None, None + + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents( + audio_path=speaker_audio_file, + gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, + max_ref_length=XTTS_MODEL.config.max_ref_len, + sound_norm_refs=XTTS_MODEL.config.sound_norm_refs, + ) + out = XTTS_MODEL.inference( + text=tts_text, + language=lang, + gpt_cond_latent=gpt_cond_latent, + speaker_embedding=speaker_embedding, + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + length_penalty=XTTS_MODEL.config.length_penalty, + repetition_penalty=XTTS_MODEL.config.repetition_penalty, + top_k=XTTS_MODEL.config.top_k, + top_p=XTTS_MODEL.config.top_p, + ) + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp: + out["wav"] = torch.tensor(out["wav"]).unsqueeze(0) + out_path = fp.name + torchaudio.save(out_path, out["wav"], 24000) + + return "Speech generated !", out_path, speaker_audio_file + + +# define a logger to redirect +class Logger: + def __init__(self, filename="log.out"): + self.log_file = filename + self.terminal = sys.stdout + self.log = open(self.log_file, "w") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + self.terminal.flush() + self.log.flush() + + def isatty(self): + return False + + +# redirect stdout and stderr to a file +sys.stdout = Logger() +sys.stderr = sys.stdout + + +# logging.basicConfig(stream=sys.stdout, level=logging.INFO) + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)] +) + + +def read_logs(): + sys.stdout.flush() + with open(sys.stdout.log_file, "r") as f: + return f.read() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="""XTTS fine-tuning demo\n\n""" + """ + Example runs: + python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port + """, + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--port", + type=int, + help="Port to run the gradio demo. Default: 5003", + default=5003, + ) + parser.add_argument( + "--out_path", + type=str, + help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/", + default="/tmp/xtts_ft/", + ) + + parser.add_argument( + "--num_epochs", + type=int, + help="Number of epochs to train. Default: 10", + default=10, + ) + parser.add_argument( + "--batch_size", + type=int, + help="Batch size. Default: 4", + default=4, + ) + parser.add_argument( + "--grad_acumm", + type=int, + help="Grad accumulation steps. Default: 1", + default=1, + ) + parser.add_argument( + "--max_audio_length", + type=int, + help="Max permitted audio size in seconds. Default: 11", + default=11, + ) + + args = parser.parse_args() + + with gr.Blocks() as demo: + with gr.Tab("1 - Data processing"): + out_path = gr.Textbox( + label="Output path (where data and checkpoints will be saved):", + value=args.out_path, + ) + # upload_file = gr.Audio( + # sources="upload", + # label="Select here the audio files that you want to use for XTTS trainining !", + # type="filepath", + # ) + upload_file = gr.File( + file_count="multiple", + label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)", + ) + lang = gr.Dropdown( + label="Dataset Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja", + "hi", + ], + ) + progress_data = gr.Label(label="Progress:") + logs = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs, every=1) + + prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") + + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): + clear_gpu_cache() + out_path = os.path.join(out_path, "dataset") + os.makedirs(out_path, exist_ok=True) + if audio_path is None: + return ( + "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", + "", + "", + ) + else: + try: + train_meta, eval_meta, audio_total_size = format_audio_list( + audio_path, target_language=language, out_path=out_path, gradio_progress=progress + ) + except: + traceback.print_exc() + error = traceback.format_exc() + return ( + f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", + "", + "", + ) + + clear_gpu_cache() + + # if audio total len is less than 2 minutes raise an error + if audio_total_size < 120: + message = "The sum of the duration of the audios that you provided should be at least 2 minutes!" + print(message) + return message, "", "" + + print("Dataset Processed!") + return "Dataset Processed!", train_meta, eval_meta + + with gr.Tab("2 - Fine-tuning XTTS Encoder"): + train_csv = gr.Textbox( + label="Train CSV:", + ) + eval_csv = gr.Textbox( + label="Eval CSV:", + ) + num_epochs = gr.Slider( + label="Number of epochs:", + minimum=1, + maximum=100, + step=1, + value=args.num_epochs, + ) + batch_size = gr.Slider( + label="Batch size:", + minimum=2, + maximum=512, + step=1, + value=args.batch_size, + ) + grad_acumm = gr.Slider( + label="Grad accumulation steps:", + minimum=2, + maximum=128, + step=1, + value=args.grad_acumm, + ) + max_audio_length = gr.Slider( + label="Max permitted audio size in seconds:", + minimum=2, + maximum=20, + step=1, + value=args.max_audio_length, + ) + progress_train = gr.Label(label="Progress:") + logs_tts_train = gr.Textbox( + label="Logs:", + interactive=False, + ) + demo.load(read_logs, None, logs_tts_train, every=1) + train_btn = gr.Button(value="Step 2 - Run the training") + + def train_model( + language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length + ): + clear_gpu_cache() + if not train_csv or not eval_csv: + return ( + "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", + "", + "", + "", + "", + ) + try: + # convert seconds to waveform frames + max_audio_length = int(max_audio_length * 22050) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt( + language, + num_epochs, + batch_size, + grad_acumm, + train_csv, + eval_csv, + output_path=output_path, + max_audio_length=max_audio_length, + ) + except: + traceback.print_exc() + error = traceback.format_exc() + return ( + f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", + "", + "", + "", + "", + ) + + # copy original files to avoid parameters changes issues + os.system(f"cp {config_path} {exp_path}") + os.system(f"cp {vocab_file} {exp_path}") + + ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth") + print("Model training done!") + clear_gpu_cache() + return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav + + with gr.Tab("3 - Inference"): + with gr.Row(): + with gr.Column() as col1: + xtts_checkpoint = gr.Textbox( + label="XTTS checkpoint path:", + value="", + ) + xtts_config = gr.Textbox( + label="XTTS config path:", + value="", + ) + + xtts_vocab = gr.Textbox( + label="XTTS vocab path:", + value="", + ) + progress_load = gr.Label(label="Progress:") + load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") + + with gr.Column() as col2: + speaker_reference_audio = gr.Textbox( + label="Speaker reference audio:", + value="", + ) + tts_language = gr.Dropdown( + label="Language", + value="en", + choices=[ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh", + "hu", + "ko", + "ja", + "hi", + ], + ) + tts_text = gr.Textbox( + label="Input Text.", + value="This model sounds really good and above all, it's reasonably fast.", + ) + tts_btn = gr.Button(value="Step 4 - Inference") + + with gr.Column() as col3: + progress_gen = gr.Label(label="Progress:") + tts_output_audio = gr.Audio(label="Generated Audio.") + reference_audio = gr.Audio(label="Reference audio used.") + + prompt_compute_btn.click( + fn=preprocess_dataset, + inputs=[ + upload_file, + lang, + out_path, + ], + outputs=[ + progress_data, + train_csv, + eval_csv, + ], + ) + + train_btn.click( + fn=train_model, + inputs=[ + lang, + train_csv, + eval_csv, + num_epochs, + batch_size, + grad_acumm, + out_path, + max_audio_length, + ], + outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], + ) + + load_btn.click( + fn=load_model, + inputs=[xtts_checkpoint, xtts_config, xtts_vocab], + outputs=[progress_load], + ) + + tts_btn.click( + fn=run_tts, + inputs=[ + tts_language, + tts_text, + speaker_reference_audio, + ], + outputs=[progress_gen, tts_output_audio, reference_audio], + ) + + demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0") diff --git a/TTS/encoder/README.md b/TTS/encoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9f829c9e2af0cb7d705273f5198ccc35f5bb669f --- /dev/null +++ b/TTS/encoder/README.md @@ -0,0 +1,18 @@ +### Speaker Encoder + +This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. + +With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. + +Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). + +![](umap.png) + +Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. + +To run the code, you need to follow the same flow as in TTS. + +- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. +- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda /model/path/best_model.pth model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. +- Watch training on Tensorboard as in TTS diff --git a/TTS/encoder/__init__.py b/TTS/encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/encoder/configs/base_encoder_config.py b/TTS/encoder/configs/base_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbaa0457bb55aef70d54dd36fd9b2b7f7c702bb --- /dev/null +++ b/TTS/encoder/configs/base_encoder_config.py @@ -0,0 +1,61 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import MISSING + +from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseEncoderConfig(BaseTrainingConfig): + """Defines parameters for a Generic Encoder model.""" + + model: str = None + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + # training params + epochs: int = 10000 + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + optimizer: str = "radam" + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.9, 0.999], "weight_decay": 0}) + lr_decay: bool = False + warmup_steps: int = 4000 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + save_step: int = 1000 + print_step: int = 20 + run_eval: bool = False + + # data loader + num_classes_in_batch: int = MISSING + num_utter_per_class: int = MISSING + eval_num_classes_in_batch: int = None + eval_num_utter_per_class: int = None + + num_loader_workers: int = MISSING + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." diff --git a/TTS/encoder/configs/emotion_encoder_config.py b/TTS/encoder/configs/emotion_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1d12325cf2001ca79476192339bb919a09031f6f --- /dev/null +++ b/TTS/encoder/configs/emotion_encoder_config.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class EmotionEncoderConfig(BaseEncoderConfig): + """Defines parameters for Emotion Encoder model.""" + + model: str = "emotion_encoder" + map_classid_to_classname: dict = None + class_name_key: str = "emotion_name" diff --git a/TTS/encoder/configs/speaker_encoder_config.py b/TTS/encoder/configs/speaker_encoder_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0588527a68e417285a80c9ee5b21af5ed90096b5 --- /dev/null +++ b/TTS/encoder/configs/speaker_encoder_config.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +from TTS.encoder.configs.base_encoder_config import BaseEncoderConfig + + +@dataclass +class SpeakerEncoderConfig(BaseEncoderConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + class_name_key: str = "speaker_name" diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bb780e3c1dbf91225f8f9ef16daccb48f2971b91 --- /dev/null +++ b/TTS/encoder/dataset.py @@ -0,0 +1,146 @@ +import logging +import random + +import torch +from torch.utils.data import Dataset + +from TTS.encoder.utils.generic_utils import AugmentWAV + +logger = logging.getLogger(__name__) + + +class EncoderDataset(Dataset): + def __init__( + self, + config, + ap, + meta_data, + voice_len=1.6, + num_classes_in_batch=64, + num_utter_per_class=10, + augmentation_config=None, + use_torch_spec=None, + ): + """ + Args: + ap (TTS.tts.utils.AudioProcessor): audio processor object. + meta_data (list): list of dataset instances. + seq_len (int): voice segment length in seconds. + """ + super().__init__() + self.config = config + self.items = meta_data + self.sample_rate = ap.sample_rate + self.seq_len = int(voice_len * self.sample_rate) + self.num_utter_per_class = num_utter_per_class + self.ap = ap + self.use_torch_spec = use_torch_spec + self.classes, self.items = self.__parse_items() + + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + # Data Augmentation + self.augmentator = None + self.gaussian_augmentation_config = None + if augmentation_config: + self.data_augmentation_p = augmentation_config["p"] + if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): + self.augmentator = AugmentWAV(ap, augmentation_config) + + if "gaussian" in augmentation_config.keys(): + self.gaussian_augmentation_config = augmentation_config["gaussian"] + + logger.info("DataLoader initialization") + logger.info(" | Classes per batch: %d", num_classes_in_batch) + logger.info(" | Number of instances: %d", len(self.items)) + logger.info(" | Sequence length: %d", self.seq_len) + logger.info(" | Number of classes: %d", len(self.classes)) + logger.info(" | Classes: %s", self.classes) + + def load_wav(self, filename): + audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) + return audio + + def __parse_items(self): + class_to_utters = {} + for item in self.items: + path_ = item["audio_file"] + class_name = item[self.config.class_name_key] + if class_name in class_to_utters.keys(): + class_to_utters[class_name].append(path_) + else: + class_to_utters[class_name] = [ + path_, + ] + + # skip classes with number of samples >= self.num_utter_per_class + class_to_utters = {k: v for (k, v) in class_to_utters.items() if len(v) >= self.num_utter_per_class} + + classes = list(class_to_utters.keys()) + classes.sort() + + new_items = [] + for item in self.items: + path_ = item["audio_file"] + class_name = item["emotion_name"] if self.config.model == "emotion_encoder" else item["speaker_name"] + # ignore filtered classes + if class_name not in classes: + continue + # ignore small audios + if self.load_wav(path_).shape[0] - self.seq_len <= 0: + continue + + new_items.append({"wav_file_path": path_, "class_name": class_name}) + + return classes, new_items + + def __len__(self): + return len(self.items) + + def get_num_classes(self): + return len(self.classes) + + def get_class_list(self): + return self.classes + + def set_classes(self, classes): + self.classes = classes + self.classname_to_classid = {key: i for i, key in enumerate(self.classes)} + + def get_map_classid_to_classname(self): + return dict((c_id, c_n) for c_n, c_id in self.classname_to_classid.items()) + + def __getitem__(self, idx): + return self.items[idx] + + def collate_fn(self, batch): + # get the batch class_ids + labels = [] + feats = [] + for item in batch: + utter_path = item["wav_file_path"] + class_name = item["class_name"] + + # get classid + class_id = self.classname_to_classid[class_name] + # load wav file + wav = self.load_wav(utter_path) + offset = random.randint(0, wav.shape[0] - self.seq_len) + wav = wav[offset : offset + self.seq_len] + + if self.augmentator is not None and self.data_augmentation_p: + if random.random() < self.data_augmentation_p: + wav = self.augmentator.apply_one(wav) + + if not self.use_torch_spec: + mel = self.ap.melspectrogram(wav) + feats.append(torch.FloatTensor(mel)) + else: + feats.append(torch.FloatTensor(wav)) + + labels.append(class_id) + + feats = torch.stack(feats) + labels = torch.LongTensor(labels) + + return feats, labels diff --git a/TTS/encoder/losses.py b/TTS/encoder/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2e27848c31b288fbb6440163db65466cecd9b744 --- /dev/null +++ b/TTS/encoder/losses.py @@ -0,0 +1,230 @@ +import logging + +import torch +import torch.nn.functional as F +from torch import nn + +logger = logging.getLogger(__name__) + + +# adapted from https://github.com/cvqluu/GE2E-Loss +class GE2ELoss(nn.Module): + def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): + """ + Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector (e.g. d-vector) + Args: + - init_w (float): defines the initial value of w in Equation (5) of [1] + - init_b (float): definies the initial value of b in Equation (5) of [1] + """ + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.loss_method = loss_method + + logger.info("Initialized Generalized End-to-End loss") + + assert self.loss_method in ["softmax", "contrast"] + + if self.loss_method == "softmax": + self.embed_loss = self.embed_loss_softmax + if self.loss_method == "contrast": + self.embed_loss = self.embed_loss_contrast + + # pylint: disable=R0201 + def calc_new_centroids(self, dvecs, centroids, spkr, utt): + """ + Calculates the new centroids excluding the reference utterance + """ + excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) + excl = torch.mean(excl, 0) + new_centroids = [] + for i, centroid in enumerate(centroids): + if i == spkr: + new_centroids.append(excl) + else: + new_centroids.append(centroid) + return torch.stack(new_centroids) + + def calc_cosine_sim(self, dvecs, centroids): + """ + Make the cosine similarity matrix with dims (N,M,N) + """ + cos_sim_matrix = [] + for spkr_idx, speaker in enumerate(dvecs): + cs_row = [] + for utt_idx, utterance in enumerate(speaker): + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) + # vector based cosine similarity for speed + cs_row.append( + torch.clamp( + torch.mm( + utterance.unsqueeze(1).transpose(0, 1), + new_centroids.transpose(0, 1), + ) + / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), + 1e-6, + ) + ) + cs_row = torch.cat(cs_row, dim=0) + cos_sim_matrix.append(cs_row) + return torch.stack(cos_sim_matrix) + + # pylint: disable=R0201 + def embed_loss_softmax(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by taking softmax + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + # pylint: disable=R0201 + def embed_loss_contrast(self, dvecs, cos_sim_matrix): + """ + Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid + """ + N, M, _ = dvecs.shape + L = [] + for j in range(N): + L_row = [] + for i in range(M): + centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) + L_row = torch.stack(L_row) + L.append(L_row) + return torch.stack(L) + + def forward(self, x, _label=None): + """ + Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + centroids = torch.mean(x, 1) + cos_sim_matrix = self.calc_cosine_sim(x, centroids) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = self.w * cos_sim_matrix + self.b + L = self.embed_loss(x, cos_sim_matrix) + return L.mean() + + +# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py +class AngleProtoLoss(nn.Module): + """ + Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 + Accepts an input of size (N, M, D) + where N is the number of speakers in the batch, + M is the number of utterances per speaker, + and D is the dimensionality of the embedding vector + Args: + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, init_w=10.0, init_b=-5.0): + super().__init__() + # pylint: disable=E1102 + self.w = nn.Parameter(torch.tensor(init_w)) + # pylint: disable=E1102 + self.b = nn.Parameter(torch.tensor(init_b)) + self.criterion = torch.nn.CrossEntropyLoss() + + logger.info("Initialized Angular Prototypical loss") + + def forward(self, x, _label=None): + """ + Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + assert x.size()[1] >= 2 + + out_anchor = torch.mean(x[:, 1:, :], 1) + out_positive = x[:, 0, :] + num_speakers = out_anchor.size()[0] + + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) + torch.clamp(self.w, 1e-6) + cos_sim_matrix = cos_sim_matrix * self.w + self.b + label = torch.arange(num_speakers).to(cos_sim_matrix.device) + L = self.criterion(cos_sim_matrix, label) + return L + + +class SoftmaxLoss(nn.Module): + """ + Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + """ + + def __init__(self, embedding_dim, n_speakers): + super().__init__() + + self.criterion = torch.nn.CrossEntropyLoss() + self.fc = nn.Linear(embedding_dim, n_speakers) + + logger.info("Initialised Softmax Loss") + + def forward(self, x, label=None): + # reshape for compatibility + x = x.reshape(-1, x.size()[-1]) + label = label.reshape(-1) + + x = self.fc(x) + L = self.criterion(x, label) + + return L + + def inference(self, embedding): + x = self.fc(embedding) + activations = torch.nn.functional.softmax(x, dim=1).squeeze(0) + class_id = torch.argmax(activations) + return class_id + + +class SoftmaxAngleProtoLoss(nn.Module): + """ + Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 + Args: + - embedding_dim (float): speaker embedding dim + - n_speakers (float): number of speakers + - init_w (float): defines the initial value of w + - init_b (float): definies the initial value of b + """ + + def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): + super().__init__() + + self.softmax = SoftmaxLoss(embedding_dim, n_speakers) + self.angleproto = AngleProtoLoss(init_w, init_b) + + logger.info("Initialised SoftmaxAnglePrototypical Loss") + + def forward(self, x, label=None): + """ + Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) + """ + + Lp = self.angleproto(x) + + Ls = self.softmax(x, label) + + return Ls + Lp diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f7137c21860669ec59a0e0605510cf20f0c48bda --- /dev/null +++ b/TTS/encoder/models/base_encoder.py @@ -0,0 +1,165 @@ +import logging + +import numpy as np +import torch +import torchaudio +from coqpit import Coqpit +from torch import nn +from trainer.io import load_fsspec + +from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss +from TTS.utils.generic_utils import set_init_dict + +logger = logging.getLogger(__name__) + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + +class BaseEncoder(nn.Module): + """Base `encoder` class. Every new `encoder` model must inherit this. + + It defines common `encoder` specific functions. + """ + + # pylint: disable=W0102 + def __init__(self): + super(BaseEncoder, self).__init__() + + def get_torch_mel_spectrogram_class(self, audio_config): + return torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + # TorchSTFT( + # n_fft=audio_config["fft_size"], + # hop_length=audio_config["hop_length"], + # win_length=audio_config["win_length"], + # sample_rate=audio_config["sample_rate"], + # window="hamming_window", + # mel_fmin=0.0, + # mel_fmax=None, + # use_htk=True, + # do_amp_to_db=False, + # n_mels=audio_config["num_mels"], + # power=2.0, + # use_mel=True, + # mel_norm=None, + # ) + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + @torch.no_grad() + def inference(self, x, l2_norm=True): + return self.forward(x, l2_norm) + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True, l2_norm=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + # map to the waveform size + if self.use_torch_spec: + num_frames = num_frames * self.audio_config["hop_length"] + + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch, l2_norm=l2_norm) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + return embeddings + + def get_criterion(self, c: Coqpit, num_classes=None): + if c.loss == "ge2e": + criterion = GE2ELoss(loss_method="softmax") + elif c.loss == "angleproto": + criterion = AngleProtoLoss() + elif c.loss == "softmaxproto": + criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes) + else: + raise Exception("The %s not is a loss supported" % c.loss) + return criterion + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + logger.info("Model fully restored. ") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + logger.info("Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"], c) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + logger.exception("Criterion load ignored because of: %s", error) + + # instance and load the criterion for the encoder classifier in inference time + if ( + eval + and criterion is None + and "criterion" in state + and getattr(config, "map_classid_to_classname", None) is not None + ): + criterion = self.get_criterion(config, len(config.map_classid_to_classname)) + criterion.load_state_dict(state["criterion"]) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion diff --git a/TTS/encoder/models/lstm.py b/TTS/encoder/models/lstm.py new file mode 100644 index 0000000000000000000000000000000000000000..51852b5b820d181824b0db1a205cd5d7bd4fb20d --- /dev/null +++ b/TTS/encoder/models/lstm.py @@ -0,0 +1,99 @@ +import torch +from torch import nn + +from TTS.encoder.models.base_encoder import BaseEncoder + + +class LSTMWithProjection(nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + +class LSTMWithoutProjection(nn.Module): + def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): + super().__init__() + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) + self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) + self.relu = nn.ReLU() + + def forward(self, x): + _, (hidden, _) = self.lstm(x) + return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder(BaseEncoder): + def __init__( + self, + input_dim, + proj_dim=256, + lstm_dim=768, + num_lstm_layers=3, + use_lstm_with_projection=True, + use_torch_spec=False, + audio_config=None, + ): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = nn.Sequential(*layers) + else: + self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0.0) + elif "weight" in name: + nn.init.xavier_normal_(param) + + def forward(self, x, l2_norm=True): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.use_torch_spec: + x.squeeze_(1) + x = self.torch_spec(x) + x = self.instancenorm(x).transpose(1, 2) + d = self.layers(x) + if self.use_lstm_with_projection: + d = d[:, -1] + if l2_norm: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5eafcd6005739fcdc454fb20def3e66791766a53 --- /dev/null +++ b/TTS/encoder/models/resnet.py @@ -0,0 +1,198 @@ +import torch +from torch import nn + +# from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.encoder.models.base_encoder import BaseEncoder + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(BaseEncoder): + """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = self.get_torch_mel_spectrogram_class(audio_config) + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x diff --git a/TTS/encoder/requirements.txt b/TTS/encoder/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a486cc45ddb44591bd03c9c0df294fbe98c13884 --- /dev/null +++ b/TTS/encoder/requirements.txt @@ -0,0 +1,2 @@ +umap-learn +numpy>=1.17.0 diff --git a/TTS/encoder/utils/__init__.py b/TTS/encoder/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/encoder/utils/generic_utils.py b/TTS/encoder/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..495b4def5aa86765ebfdcfb3c36b87531849ceca --- /dev/null +++ b/TTS/encoder/utils/generic_utils.py @@ -0,0 +1,141 @@ +import glob +import logging +import os +import random + +import numpy as np +from scipy import signal + +from TTS.encoder.models.lstm import LSTMSpeakerEncoder +from TTS.encoder.models.resnet import ResNetSpeakerEncoder + +logger = logging.getLogger(__name__) + + +class AugmentWAV(object): + def __init__(self, ap, augmentation_config): + self.ap = ap + self.use_additive_noise = False + + if "additive" in augmentation_config.keys(): + self.additive_noise_config = augmentation_config["additive"] + additive_path = self.additive_noise_config["sounds_path"] + if additive_path: + self.use_additive_noise = True + # get noise types + self.additive_noise_types = [] + for key in self.additive_noise_config.keys(): + if isinstance(self.additive_noise_config[key], dict): + self.additive_noise_types.append(key) + + additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True) + + self.noise_list = {} + + for wav_file in additive_files: + noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0] + # ignore not listed directories + if noise_dir not in self.additive_noise_types: + continue + if noise_dir not in self.noise_list: + self.noise_list[noise_dir] = [] + self.noise_list[noise_dir].append(wav_file) + + logger.info( + "Using Additive Noise Augmentation: with %d audios instances from %s", + len(additive_files), + self.additive_noise_types, + ) + + self.use_rir = False + + if "rir" in augmentation_config.keys(): + self.rir_config = augmentation_config["rir"] + if self.rir_config["rir_path"]: + self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) + self.use_rir = True + + logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files)) + + self.create_augmentation_global_list() + + def create_augmentation_global_list(self): + if self.use_additive_noise: + self.global_noise_list = self.additive_noise_types + else: + self.global_noise_list = [] + if self.use_rir: + self.global_noise_list.append("RIR_AUG") + + def additive_noise(self, noise_type, audio): + clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4) + + noise_list = random.sample( + self.noise_list[noise_type], + random.randint( + self.additive_noise_config[noise_type]["min_num_noises"], + self.additive_noise_config[noise_type]["max_num_noises"], + ), + ) + + audio_len = audio.shape[0] + noises_wav = None + for noise in noise_list: + noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len] + + if noiseaudio.shape[0] < audio_len: + continue + + noise_snr = random.uniform( + self.additive_noise_config[noise_type]["min_snr_in_db"], + self.additive_noise_config[noise_type]["max_num_noises"], + ) + noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4) + noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio + + if noises_wav is None: + noises_wav = noise_wav + else: + noises_wav += noise_wav + + # if all possible files is less than audio, choose other files + if noises_wav is None: + return self.additive_noise(noise_type, audio) + + return audio + noises_wav + + def reverberate(self, audio): + audio_len = audio.shape[0] + + rir_file = random.choice(self.rir_files) + rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) + rir = rir / np.sqrt(np.sum(rir**2)) + return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] + + def apply_one(self, audio): + noise_type = random.choice(self.global_noise_list) + if noise_type == "RIR_AUG": + return self.reverberate(audio) + + return self.additive_noise(noise_type, audio) + + +def setup_encoder_model(config: "Coqpit"): + if config.model_params["model_name"].lower() == "lstm": + model = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + elif config.model_params["model_name"].lower() == "resnet": + model = ResNetSpeakerEncoder( + input_dim=config.model_params["input_dim"], + proj_dim=config.model_params["proj_dim"], + log_input=config.model_params.get("log_input", False), + use_torch_spec=config.model_params.get("use_torch_spec", False), + audio_config=config.audio, + ) + return model diff --git a/TTS/encoder/utils/prepare_voxceleb.py b/TTS/encoder/utils/prepare_voxceleb.py new file mode 100644 index 0000000000000000000000000000000000000000..da7522a51243c7a817b76abe7ce2428ebaca2e6b --- /dev/null +++ b/TTS/encoder/utils/prepare_voxceleb.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Only support eager mode and TF>=2.0.0 +# pylint: disable=no-member, invalid-name, relative-beyond-top-level +# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes +""" voxceleb 1 & 2 """ + +import csv +import hashlib +import logging +import os +import subprocess +import sys +import zipfile + +import soundfile as sf + +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger + +logger = logging.getLogger(__name__) + +SUBSETS = { + "vox1_dev_wav": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], +} + +MD5SUM = { + "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", + "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", + "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", +} + +USER = {"user": "", "password": ""} + +speaker_id_dict = {} + + +def download_and_extract(directory, subset, urls): + """Download and extract the given split of dataset. + + Args: + directory: the directory where to put the downloaded data. + subset: subset name of the corpus. + urls: the list of urls to download the data file. + """ + os.makedirs(directory, exist_ok=True) + + try: + for url in urls: + zip_filepath = os.path.join(directory, url.split("/")[-1]) + if os.path.exists(zip_filepath): + continue + logger.info("Downloading %s to %s" % (url, zip_filepath)) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) + + statinfo = os.stat(zip_filepath) + logger.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) + + # concatenate all parts into zip files + if ".zip" not in zip_filepath: + zip_filepath = "_".join(zip_filepath.split("_")[:-1]) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) + zip_filepath += ".zip" + extract_path = zip_filepath.strip(".zip") + + # check zip file md5sum + with open(zip_filepath, "rb") as f_zip: + md5 = hashlib.md5(f_zip.read()).hexdigest() + if md5 != MD5SUM[subset]: + raise ValueError("md5sum of %s mismatch" % zip_filepath) + + with zipfile.ZipFile(zip_filepath, "r") as zfile: + zfile.extractall(directory) + extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) + finally: + # os.remove(zip_filepath) + pass + + +def exec_cmd(cmd): + """Run a command in a subprocess. + Args: + cmd: command line to be executed. + Return: + int, the return code. + """ + try: + retcode = subprocess.call(cmd, shell=True) + if retcode < 0: + logger.info(f"Child was terminated by signal {retcode}") + except OSError as e: + logger.info(f"Execution failed: {e}") + retcode = -999 + return retcode + + +def decode_aac_with_ffmpeg(aac_file, wav_file): + """Decode a given AAC file into WAV using ffmpeg. + Args: + aac_file: file path to input AAC file. + wav_file: file path to output WAV file. + Return: + bool, True if success. + """ + cmd = f"ffmpeg -i {aac_file} {wav_file}" + logger.info(f"Decoding aac file using command line: {cmd}") + ret = exec_cmd(cmd) + if ret != 0: + logger.error(f"Failed to decode aac file with retcode {ret}") + logger.error("Please check your ffmpeg installation.") + return False + return True + + +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): + """Optionally convert AAC to WAV and make speaker labels. + Args: + input_dir: the directory which holds the input dataset. + subset: the name of the specified subset. e.g. vox1_dev_wav + output_dir: the directory to place the newly generated csv files. + output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv + """ + + logger.info("Preprocessing audio and label for subset %s" % subset) + source_dir = os.path.join(input_dir, subset) + + files = [] + # Convert all AAC file into WAV format. At the same time, generate the csv + for root, _, filenames in os.walk(source_dir): + for filename in filenames: + name, ext = os.path.splitext(filename) + if ext.lower() == ".wav": + _, ext2 = os.path.splitext(name) + if ext2: + continue + wav_file = os.path.join(root, filename) + elif ext.lower() == ".m4a": + # Convert AAC to WAV. + aac_file = os.path.join(root, filename) + wav_file = aac_file + ".wav" + if not os.path.exists(wav_file): + if not decode_aac_with_ffmpeg(aac_file, wav_file): + raise RuntimeError("Audio decoding failed.") + else: + continue + speaker_name = root.split(os.path.sep)[-2] + if speaker_name not in speaker_id_dict: + num = len(speaker_id_dict) + speaker_id_dict[speaker_name] = num + # wav_filesize = os.path.getsize(wav_file) + wav_length = len(sf.read(wav_file)[0]) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) + + # Write to CSV file which contains four columns: + # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". + csv_file_path = os.path.join(output_dir, output_file) + with open(csv_file_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerow(["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + for wav_file in files: + writer.writerow(wav_file) + logger.info("Successfully generated csv file {}".format(csv_file_path)) + + +def processor(directory, subset, force_process): + """download and process""" + urls = SUBSETS + if subset not in urls: + raise ValueError(subset, "is not in voxceleb") + + subset_csv = os.path.join(directory, subset + ".csv") + if not force_process and os.path.exists(subset_csv): + return subset_csv + + logger.info("Downloading and process the voxceleb in %s", directory) + logger.info("Preparing subset %s", subset) + download_and_extract(directory, subset, urls[subset]) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") + logger.info("Finished downloading and processing") + return subset_csv + + +if __name__ == "__main__": + setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + if len(sys.argv) != 4: + print("Usage: python prepare_data.py save_directory user password") + sys.exit() + + DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3] + for SUBSET in SUBSETS: + processor(DIR, SUBSET, False) diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3a78b084890e42032121d9ed66ad1c96ac06bc --- /dev/null +++ b/TTS/encoder/utils/training.py @@ -0,0 +1,99 @@ +import os +from dataclasses import dataclass, field + +from coqpit import Coqpit +from trainer import TrainerArgs, get_last_checkpoint +from trainer.generic_utils import get_experiment_folder_path, get_git_branch +from trainer.io import copy_model_files +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger + +from TTS.config import load_config, register_config +from TTS.tts.utils.text.characters import parse_symbols + + +@dataclass +class TrainArgs(TrainerArgs): + config_path: str = field(default=None, metadata={"help": "Path to the config file."}) + + +def getarguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def process_args(args, config=None): + """Process parsed comand line arguments and initialize the config if not provided. + Args: + args (argparse.Namespace or dict like): Parsed input arguments. + config (Coqpit): Model config. If none, it is generated from `args`. Defaults to None. + Returns: + c (Coqpit): Config paramaters. + out_path (str): Path to save models and logging. + audio_path (str): Path to save generated test audios. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does + logging to the console. + dashboard_logger (WandbLogger or TensorboardLogger): Class that does the dashboard Logging + TODO: + - Interactive config definition. + """ + if isinstance(args, tuple): + args, coqpit_overrides = args + if args.continue_path: + # continue a previous training from its output folder + experiment_path = args.continue_path + args.config_path = os.path.join(args.continue_path, "config.json") + args.restore_path, best_model = get_last_checkpoint(args.continue_path) + if not args.best_path: + args.best_path = best_model + # init config if not already defined + if config is None: + if args.config_path: + # init from a file + config = load_config(args.config_path) + else: + # init from console args + from TTS.config.shared_configs import BaseTrainingConfig # pylint: disable=import-outside-toplevel + + config_base = BaseTrainingConfig() + config_base.parse_known_args(coqpit_overrides) + config = register_config(config_base.model)() + # override values from command-line args + config.parse_known_args(coqpit_overrides, relaxed_parser=True) + experiment_path = args.continue_path + if not experiment_path: + experiment_path = get_experiment_folder_path(config.output_path, config.run_name) + audio_path = os.path.join(experiment_path, "test_audios") + config.output_log_path = experiment_path + # setup rank 0 process in distributed training + dashboard_logger = None + if args.rank == 0: + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + # if model characters are not set in the config file + # save the default set to the config file for future + # compatibility. + if config.has("characters") and config.characters is None: + used_characters = parse_symbols() + new_fields["characters"] = used_characters + copy_model_files(config, experiment_path, new_fields) + dashboard_logger = logger_factory(config, experiment_path) + c_logger = ConsoleLogger() + return config, experiment_path, audio_path, c_logger, dashboard_logger + + +def init_arguments(): + train_config = TrainArgs() + parser = train_config.init_argparse(arg_prefix="") + return parser + + +def init_training(config: Coqpit = None): + """Initialization of a training run.""" + parser = init_arguments() + args = parser.parse_known_args() + config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger = process_args(args, config) + return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, dashboard_logger diff --git a/TTS/encoder/utils/visual.py b/TTS/encoder/utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe40605dfd163abb907228e414f46989285e91d --- /dev/null +++ b/TTS/encoder/utils/visual.py @@ -0,0 +1,53 @@ +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + +matplotlib.use("Agg") + + +colormap = ( + np.array( + [ + [76, 255, 0], + [0, 127, 70], + [255, 0, 0], + [255, 217, 38], + [0, 135, 255], + [165, 0, 165], + [255, 167, 255], + [0, 255, 255], + [255, 96, 38], + [142, 76, 0], + [33, 0, 127], + [0, 0, 0], + [183, 183, 183], + ], + dtype=float, + ) + / 255 +) + + +def plot_embeddings(embeddings, num_classes_in_batch): + try: + import umap + except ImportError as e: + raise ImportError("Package not installed: umap-learn") from e + num_utter_per_class = embeddings.shape[0] // num_classes_in_batch + + # if necessary get just the first 10 classes + if num_classes_in_batch > 10: + num_classes_in_batch = 10 + embeddings = embeddings[: num_classes_in_batch * num_utter_per_class] + + model = umap.UMAP() + projection = model.fit_transform(embeddings) + ground_truth = np.repeat(np.arange(num_classes_in_batch), num_utter_per_class) + colors = [colormap[i] for i in ground_truth] + fig, ax = plt.subplots(figsize=(16, 10)) + _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) + plt.gca().set_aspect("equal", "datalim") + plt.title("UMAP projection") + plt.tight_layout() + plt.savefig("umap") + return fig diff --git a/TTS/model.py b/TTS/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c3707c85ae1e430769d4499aa9c1c54ebc4d092b --- /dev/null +++ b/TTS/model.py @@ -0,0 +1,66 @@ +import os +from abc import abstractmethod +from typing import Any, Union + +import torch +from coqpit import Coqpit +from trainer import TrainerModel + +# pylint: skip-file + + +class BaseTrainerModel(TrainerModel): + """BaseTrainerModel model expanding TrainerModel with required functions by 🐸TTS. + + Every new 🐸TTS model must inherit it. + """ + + @staticmethod + @abstractmethod + def init_from_config(config: Coqpit) -> "BaseTrainerModel": + """Init the model and all its attributes from the given config. + + Override this depending on your model. + """ + ... + + @abstractmethod + def inference(self, input: torch.Tensor, aux_input: dict[str, Any] = {}) -> dict[str, Any]: + """Forward pass for inference. + + It must return a dictionary with the main model output and all the auxiliary outputs. The key ```model_outputs``` + is considered to be the main output and you can add any other auxiliary outputs as you want. + + We don't use `*kwargs` since it is problematic with the TorchScript API. + + Args: + input (torch.Tensor): [description] + aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc. + + Returns: + Dict: [description] + """ + outputs_dict = {"model_outputs": None} + ... + return outputs_dict + + @abstractmethod + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: Union[str, os.PathLike[Any]], + eval: bool = False, + strict: bool = True, + cache: bool = False, + ) -> None: + """Load a model checkpoint file and get ready for training or inference. + + Args: + config (Coqpit): Model configuration. + checkpoint_path (str | os.PathLike): Path to the model checkpoint file. + eval (bool, optional): If true, init model for inference else for training. Defaults to False. + strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. + cache (bool, optional): If True, cache the file locally for subsequent calls. + It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False. + """ + ... diff --git a/TTS/server/README.md b/TTS/server/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ae8e38a4e39e3b18f66ed784fd964dd633494c6b --- /dev/null +++ b/TTS/server/README.md @@ -0,0 +1,21 @@ +# :frog: TTS demo server +Before you use the server, make sure you +[install](https://github.com/idiap/coqui-ai-TTS/tree/dev#install-tts)) :frog: TTS +properly and install the additional dependencies with `pip install +coqui-tts[server]`. Then, you can follow the steps below. + +**Note:** If you install :frog:TTS using ```pip```, you can also use the ```tts-server``` end point on the terminal. + +Examples runs: + +List officially released models. +```python TTS/server/server.py --list_models ``` + +Run the server with the official models. +```python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan``` + +Run the server with the official models on a GPU. +```CUDA_VISIBLE_DEVICES="0" python TTS/server/server.py --model_name tts_models/en/ljspeech/tacotron2-DCA --vocoder_name vocoder_models/en/ljspeech/multiband-melgan --use_cuda``` + +Run the server with a custom models. +```python TTS/server/server.py --tts_checkpoint /path/to/tts/model.pth --tts_config /path/to/tts/config.json --vocoder_checkpoint /path/to/vocoder/model.pth --vocoder_config /path/to/vocoder/config.json``` diff --git a/TTS/server/__init__.py b/TTS/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/server/conf.json b/TTS/server/conf.json new file mode 100644 index 0000000000000000000000000000000000000000..49b6c09c3848a224dfb39a1f653aa1b289a4b6e5 --- /dev/null +++ b/TTS/server/conf.json @@ -0,0 +1,12 @@ +{ + "tts_path":"/media/erogol/data_ssd/Models/libri_tts/5049/", // tts model root folder + "tts_file":"best_model.pth", // tts checkpoint file + "tts_config":"config.json", // tts config.json file + "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding. + "vocoder_config":null, + "vocoder_file": null, + "is_wavernn_batched":true, + "port": 5002, + "use_cuda": true, + "debug": true +} diff --git a/TTS/server/server.py b/TTS/server/server.py new file mode 100644 index 0000000000000000000000000000000000000000..f410fb7539fa77e3875035817fa4cefc7c40c45f --- /dev/null +++ b/TTS/server/server.py @@ -0,0 +1,262 @@ +#!flask/bin/python + +"""TTS demo server.""" + +import argparse +import io +import json +import logging +import os +import sys +from pathlib import Path +from threading import Lock +from typing import Union +from urllib.parse import parse_qs + +try: + from flask import Flask, render_template, render_template_string, request, send_file +except ImportError as e: + msg = "Server requires requires flask, use `pip install coqui-tts[server]`" + raise ImportError(msg) from e + +from TTS.config import load_config +from TTS.utils.generic_utils import ConsoleFormatter, setup_logger +from TTS.utils.manage import ModelManager +from TTS.utils.synthesizer import Synthesizer + +logger = logging.getLogger(__name__) +setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter()) + + +def create_argparser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--list_models", + action="store_true", + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/tacotron2-DDC", + help="Name of one of the pre-trained tts models in format //", + ) + parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.") + + # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model file.", + ) + parser.add_argument( + "--vocoder_path", + type=str, + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).", + default=None, + ) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) + parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None) + parser.add_argument("--port", type=int, default=5002, help="port to listen on.") + parser.add_argument("--use_cuda", action=argparse.BooleanOptionalAction, default=False, help="true to use CUDA.") + parser.add_argument( + "--debug", action=argparse.BooleanOptionalAction, default=False, help="true to enable Flask debug mode." + ) + parser.add_argument( + "--show_details", action=argparse.BooleanOptionalAction, default=False, help="Generate model detail page." + ) + return parser + + +# parse the args +args = create_argparser().parse_args() + +path = Path(__file__).parent / "../.models.json" +manager = ModelManager(path) + +# update in-use models to the specified released models. +model_path = None +config_path = None +speakers_file_path = None +vocoder_path = None +vocoder_config_path = None + +# CASE1: list pre-trained TTS models +if args.list_models: + manager.list_models() + sys.exit() + +# CASE2: load pre-trained model paths +if args.model_name is not None and not args.model_path: + model_path, config_path, model_item = manager.download_model(args.model_name) + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name + +if args.vocoder_name is not None and not args.vocoder_path: + vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) + +# CASE3: set custom model paths +if args.model_path is not None: + model_path = args.model_path + config_path = args.config_path + speakers_file_path = args.speakers_file_path + +if args.vocoder_path is not None: + vocoder_path = args.vocoder_path + vocoder_config_path = args.vocoder_config_path + +# load models +synthesizer = Synthesizer( + tts_checkpoint=model_path, + tts_config_path=config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=None, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint="", + encoder_config="", + use_cuda=args.use_cuda, +) + +use_multi_speaker = hasattr(synthesizer.tts_model, "num_speakers") and ( + synthesizer.tts_model.num_speakers > 1 or synthesizer.tts_speakers_file is not None +) +speaker_manager = getattr(synthesizer.tts_model, "speaker_manager", None) + +use_multi_language = hasattr(synthesizer.tts_model, "num_languages") and ( + synthesizer.tts_model.num_languages > 1 or synthesizer.tts_languages_file is not None +) +language_manager = getattr(synthesizer.tts_model, "language_manager", None) + +# TODO: set this from SpeakerManager +use_gst = synthesizer.tts_config.get("use_gst", False) +app = Flask(__name__) + + +def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]: + """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer) + or a dict (gst tokens/values to be use for styling) + + Args: + style_wav (str): uri + + Returns: + Union[str, dict]: path to file (str) or gst style (dict) + """ + if style_wav: + if os.path.isfile(style_wav) and style_wav.endswith(".wav"): + return style_wav # style_wav is a .wav file located on the server + + style_wav = json.loads(style_wav) + return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...} + return None + + +@app.route("/") +def index(): + return render_template( + "index.html", + show_details=args.show_details, + use_multi_speaker=use_multi_speaker, + use_multi_language=use_multi_language, + speaker_ids=speaker_manager.name_to_id if speaker_manager is not None else None, + language_ids=language_manager.name_to_id if language_manager is not None else None, + use_gst=use_gst, + ) + + +@app.route("/details") +def details(): + if args.config_path is not None and os.path.isfile(args.config_path): + model_config = load_config(args.config_path) + elif args.model_name is not None: + model_config = load_config(config_path) + + if args.vocoder_config_path is not None and os.path.isfile(args.vocoder_config_path): + vocoder_config = load_config(args.vocoder_config_path) + elif args.vocoder_name is not None: + vocoder_config = load_config(vocoder_config_path) + else: + vocoder_config = None + + return render_template( + "details.html", + show_details=args.show_details, + model_config=model_config, + vocoder_config=vocoder_config, + args=args.__dict__, + ) + + +lock = Lock() + + +@app.route("/api/tts", methods=["GET", "POST"]) +def tts(): + with lock: + text = request.headers.get("text") or request.values.get("text", "") + speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "") + language_idx = request.headers.get("language-id") or request.values.get("language_id", "") + style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") + style_wav = style_wav_uri_to_dict(style_wav) + + logger.info("Model input: %s", text) + logger.info("Speaker idx: %s", speaker_idx) + logger.info("Language idx: %s", language_idx) + wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + return send_file(out, mimetype="audio/wav") + + +# Basic MaryTTS compatibility layer + + +@app.route("/locales", methods=["GET"]) +def mary_tts_api_locales(): + """MaryTTS-compatible /locales endpoint""" + # NOTE: We currently assume there is only one model active at the same time + if args.model_name is not None: + model_details = args.model_name.split("/") + else: + model_details = ["", "en", "", "default"] + return render_template_string("{{ locale }}\n", locale=model_details[1]) + + +@app.route("/voices", methods=["GET"]) +def mary_tts_api_voices(): + """MaryTTS-compatible /voices endpoint""" + # NOTE: We currently assume there is only one model active at the same time + if args.model_name is not None: + model_details = args.model_name.split("/") + else: + model_details = ["", "en", "", "default"] + return render_template_string( + "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u" + ) + + +@app.route("/process", methods=["GET", "POST"]) +def mary_tts_api_process(): + """MaryTTS-compatible /process endpoint""" + with lock: + if request.method == "POST": + data = parse_qs(request.get_data(as_text=True)) + # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model + text = data.get("INPUT_TEXT", [""])[0] + else: + text = request.args.get("INPUT_TEXT", "") + logger.info("Model input: %s", text) + wavs = synthesizer.tts(text) + out = io.BytesIO() + synthesizer.save_wav(wavs, out) + return send_file(out, mimetype="audio/wav") + + +def main(): + app.run(debug=args.debug, host="::", port=args.port) + + +if __name__ == "__main__": + main() diff --git a/TTS/server/static/coqui-log-green-TTS.png b/TTS/server/static/coqui-log-green-TTS.png new file mode 100644 index 0000000000000000000000000000000000000000..6ad188b8c03a170097c0393c6769996f03cf9054 Binary files /dev/null and b/TTS/server/static/coqui-log-green-TTS.png differ diff --git a/TTS/server/templates/details.html b/TTS/server/templates/details.html new file mode 100644 index 0000000000000000000000000000000000000000..85ff9595911b809471a7c4a22a4a150a6d169d91 --- /dev/null +++ b/TTS/server/templates/details.html @@ -0,0 +1,131 @@ + + + + + + + + + + + TTS engine + + + + + + + + + + Fork me on GitHub + + {% if show_details == true %} + +
+ Model details +
+ +
+
+ CLI arguments: + + + + + + + {% for key, value in args.items() %} + + + + + + + {% endfor %} +
CLI key Value
{{ key }}{{ value }}
+
+

+ +
+ + {% if model_config != None %} + +
+ Model config: + + + + + + + + + {% for key, value in model_config.items() %} + + + + + + + {% endfor %} + +
Key Value
{{ key }}{{ value }}
+
+ + {% endif %} + +

+ + + +
+ {% if vocoder_config != None %} +
+ Vocoder model config: + + + + + + + + + {% for key, value in vocoder_config.items() %} + + + + + + + {% endfor %} + + +
Key Value
{{ key }}{{ value }}
+
+ {% endif %} +

+ + {% else %} +
+ Please start server with --show_details=true to see details. +
+ + {% endif %} + + + + diff --git a/TTS/server/templates/index.html b/TTS/server/templates/index.html new file mode 100644 index 0000000000000000000000000000000000000000..6bfd5ae2cb3e23b47c6a804dac1f2152658c1d43 --- /dev/null +++ b/TTS/server/templates/index.html @@ -0,0 +1,154 @@ + + + + + + + + + + + TTS engine + + + + + + + + + + Fork me on GitHub + + + + + +
+
+
+ + +
    +
+ + {%if use_gst%} + + {%endif%} + + +

+ + {%if use_multi_speaker%} + Choose a speaker: +

+ {%endif%} + + {%if use_multi_language%} + Choose a language: +

+ {%endif%} + + + {%if show_details%} +

+ {%endif%} + +

+
+
+
+ + + + + + + diff --git a/TTS/tts/__init__.py b/TTS/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/configs/__init__.py b/TTS/tts/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3146ac1c116cb807a81889b7a9ab223b9a051036 --- /dev/null +++ b/TTS/tts/configs/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os +from inspect import isclass + +# import all files under configs/ +# configs_dir = os.path.dirname(__file__) +# for file in os.listdir(configs_dir): +# path = os.path.join(configs_dir, file) +# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): +# config_name = file[: file.find(".py")] if file.endswith(".py") else file +# module = importlib.import_module("TTS.tts.configs." + config_name) +# for attribute_name in dir(module): +# attribute = getattr(module, attribute_name) + +# if isclass(attribute): +# # Add the class to this package's variables +# globals()[attribute_name] = attribute diff --git a/TTS/tts/configs/align_tts_config.py b/TTS/tts/configs/align_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..317a01af53ce26914d83610a913eb44b5836dac2 --- /dev/null +++ b/TTS/tts/configs/align_tts_config.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.align_tts import AlignTTSArgs + + +@dataclass +class AlignTTSConfig(BaseTTSConfig): + """Defines parameters for AlignTTS model. + Example: + + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `align_tts`. + positional_encoding (bool): + enable / disable positional encoding applied to the encoder output. Defaults to True. + hidden_channels (int): + Base number of hidden channels. Defines all the layers expect ones defined by the specific encoder or decoder + parameters. Defaults to 256. + hidden_channels_dp (int): + Number of hidden channels of the duration predictor's layers. Defaults to 256. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.feed_forward.encoder` for more details. + Defaults to `fftransformer`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.feed_forward.encoder` for more details. + Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`. + decoder_type (str): + Type of the decoder used by the model. Look at `TTS.tts.layers.feed_forward.decoder` for more details. + Defaults to `fftransformer`. + decoder_params (dict): + Parameters used to define the decoder network. Look at `TTS.tts.layers.feed_forward.decoder` for more details. + Defaults to `{"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}`. + phase_start_steps (List[int]): + A list of number of steps required to start the next training phase. AlignTTS has 4 different training + phases. Thus you need to define 4 different values to enable phase based training. If None, it + trains the whole model together. Defaults to None. + ssim_alpha (float): + Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0. + duration_loss_alpha (float): + Weight for the duration predictor's loss. Defaults to 1.0. + mdn_alpha (float): + Weight for the MDN loss. Defaults to 1.0. + spec_loss_alpha (float): + Weight for the MSE spectrogram loss. If set <= 0, disables the L1 loss. Defaults to 1.0. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage.""" + + model: str = "align_tts" + # model specific params + model_args: AlignTTSArgs = field(default_factory=AlignTTSArgs) + phase_start_steps: List[int] = None + + ssim_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + mdn_alpha: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = None + lr_scheduler_params: dict = None + lr: float = 1e-4 + grad_clip: float = 5.0 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b846febe85cf4af2c750e2bd1e01861bc2ad677f --- /dev/null +++ b/TTS/tts/configs/bark_config.py @@ -0,0 +1,105 @@ +import os +from dataclasses import dataclass, field +from typing import Dict + +from trainer.io import get_user_data_dir + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.layers.bark.model import GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPTConfig +from TTS.tts.models.bark import BarkAudioConfig + + +@dataclass +class BarkConfig(BaseTTSConfig): + """Bark TTS configuration + + Args: + model (str): model name that registers the model. + audio (BarkAudioConfig): audio configuration. Defaults to BarkAudioConfig(). + num_chars (int): number of characters in the alphabet. Defaults to 0. + semantic_config (GPTConfig): semantic configuration. Defaults to GPTConfig(). + fine_config (FineGPTConfig): fine configuration. Defaults to FineGPTConfig(). + coarse_config (GPTConfig): coarse configuration. Defaults to GPTConfig(). + CONTEXT_WINDOW_SIZE (int): GPT context window size. Defaults to 1024. + SEMANTIC_RATE_HZ (float): semantic tokens rate in Hz. Defaults to 49.9. + SEMANTIC_VOCAB_SIZE (int): semantic vocabulary size. Defaults to 10_000. + CODEBOOK_SIZE (int): encodec codebook size. Defaults to 1024. + N_COARSE_CODEBOOKS (int): number of coarse codebooks. Defaults to 2. + N_FINE_CODEBOOKS (int): number of fine codebooks. Defaults to 8. + COARSE_RATE_HZ (int): coarse tokens rate in Hz. Defaults to 75. + SAMPLE_RATE (int): sample rate. Defaults to 24_000. + USE_SMALLER_MODELS (bool): use smaller models. Defaults to False. + TEXT_ENCODING_OFFSET (int): text encoding offset. Defaults to 10_048. + SEMANTIC_PAD_TOKEN (int): semantic pad token. Defaults to 10_000. + TEXT_PAD_TOKEN ([type]): text pad token. Defaults to 10_048. + TEXT_EOS_TOKEN ([type]): text end of sentence token. Defaults to 10_049. + TEXT_SOS_TOKEN ([type]): text start of sentence token. Defaults to 10_050. + SEMANTIC_INFER_TOKEN (int): semantic infer token. Defaults to 10_051. + COARSE_SEMANTIC_PAD_TOKEN (int): coarse semantic pad token. Defaults to 12_048. + COARSE_INFER_TOKEN (int): coarse infer token. Defaults to 12_050. + REMOTE_BASE_URL ([type]): remote base url. Defaults to "https://huggingface.co/erogol/bark/tree". + REMOTE_MODEL_PATHS (Dict): remote model paths. Defaults to None. + LOCAL_MODEL_PATHS (Dict): local model paths. Defaults to None. + SMALL_REMOTE_MODEL_PATHS (Dict): small remote model paths. Defaults to None. + CACHE_DIR (str): local cache directory. Defaults to get_user_data_dir(). + DEF_SPEAKER_DIR (str): default speaker directory to stoke speaker values for voice cloning. Defaults to get_user_data_dir(). + """ + + model: str = "bark" + audio: BarkAudioConfig = field(default_factory=BarkAudioConfig) + num_chars: int = 0 + semantic_config: GPTConfig = field(default_factory=GPTConfig) + fine_config: FineGPTConfig = field(default_factory=FineGPTConfig) + coarse_config: GPTConfig = field(default_factory=GPTConfig) + CONTEXT_WINDOW_SIZE: int = 1024 + SEMANTIC_RATE_HZ: float = 49.9 + SEMANTIC_VOCAB_SIZE: int = 10_000 + CODEBOOK_SIZE: int = 1024 + N_COARSE_CODEBOOKS: int = 2 + N_FINE_CODEBOOKS: int = 8 + COARSE_RATE_HZ: int = 75 + SAMPLE_RATE: int = 24_000 + USE_SMALLER_MODELS: bool = False + + TEXT_ENCODING_OFFSET: int = 10_048 + SEMANTIC_PAD_TOKEN: int = 10_000 + TEXT_PAD_TOKEN: int = 129_595 + SEMANTIC_INFER_TOKEN: int = 129_599 + COARSE_SEMANTIC_PAD_TOKEN: int = 12_048 + COARSE_INFER_TOKEN: int = 12_050 + + REMOTE_BASE_URL = "https://huggingface.co/erogol/bark/tree/main/" + REMOTE_MODEL_PATHS: Dict = None + LOCAL_MODEL_PATHS: Dict = None + SMALL_REMOTE_MODEL_PATHS: Dict = None + CACHE_DIR: str = str(get_user_data_dir("tts/suno/bark_v0")) + DEF_SPEAKER_DIR: str = str(get_user_data_dir("tts/bark_v0/speakers")) + + def __post_init__(self): + self.REMOTE_MODEL_PATHS = { + "text": { + "path": os.path.join(self.REMOTE_BASE_URL, "text_2.pt"), + "checksum": "54afa89d65e318d4f5f80e8e8799026a", + }, + "coarse": { + "path": os.path.join(self.REMOTE_BASE_URL, "coarse_2.pt"), + "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", + }, + "fine": { + "path": os.path.join(self.REMOTE_BASE_URL, "fine_2.pt"), + "checksum": "59d184ed44e3650774a2f0503a48a97b", + }, + } + self.LOCAL_MODEL_PATHS = { + "text": os.path.join(self.CACHE_DIR, "text_2.pt"), + "coarse": os.path.join(self.CACHE_DIR, "coarse_2.pt"), + "fine": os.path.join(self.CACHE_DIR, "fine_2.pt"), + "hubert_tokenizer": os.path.join(self.CACHE_DIR, "tokenizer.pth"), + } + self.SMALL_REMOTE_MODEL_PATHS = { + "text": {"path": os.path.join(self.REMOTE_BASE_URL, "text.pt")}, + "coarse": {"path": os.path.join(self.REMOTE_BASE_URL, "coarse.pt")}, + "fine": {"path": os.path.join(self.REMOTE_BASE_URL, "fine.pt")}, + } + self.sample_rate = self.SAMPLE_RATE # pylint: disable=attribute-defined-outside-init diff --git a/TTS/tts/configs/delightful_tts_config.py b/TTS/tts/configs/delightful_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..805d995369e29fce7d6aa87750356b21458cd64a --- /dev/null +++ b/TTS/tts/configs/delightful_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.delightful_tts import DelightfulTtsArgs, DelightfulTtsAudioConfig, VocoderConfig + + +@dataclass +class DelightfulTTSConfig(BaseTTSConfig): + """ + Configuration class for the DelightfulTTS model. + + Attributes: + model (str): Name of the model ("delightful_tts"). + audio (DelightfulTtsAudioConfig): Configuration for audio settings. + model_args (DelightfulTtsArgs): Configuration for model arguments. + use_attn_priors (bool): Whether to use attention priors. + vocoder (VocoderConfig): Configuration for the vocoder. + init_discriminator (bool): Whether to initialize the discriminator. + steps_to_start_discriminator (int): Number of steps to start the discriminator. + grad_clip (List[float]): Gradient clipping values. + lr_gen (float): Learning rate for the gan generator. + lr_disc (float): Learning rate for the gan discriminator. + lr_scheduler_gen (str): Name of the learning rate scheduler for the generator. + lr_scheduler_gen_params (dict): Parameters for the learning rate scheduler for the generator. + lr_scheduler_disc (str): Name of the learning rate scheduler for the discriminator. + lr_scheduler_disc_params (dict): Parameters for the learning rate scheduler for the discriminator. + scheduler_after_epoch (bool): Whether to schedule after each epoch. + optimizer (str): Name of the optimizer. + optimizer_params (dict): Parameters for the optimizer. + ssim_loss_alpha (float): Alpha value for the SSIM loss. + mel_loss_alpha (float): Alpha value for the mel loss. + aligner_loss_alpha (float): Alpha value for the aligner loss. + pitch_loss_alpha (float): Alpha value for the pitch loss. + energy_loss_alpha (float): Alpha value for the energy loss. + u_prosody_loss_alpha (float): Alpha value for the utterance prosody loss. + p_prosody_loss_alpha (float): Alpha value for the phoneme prosody loss. + dur_loss_alpha (float): Alpha value for the duration loss. + char_dur_loss_alpha (float): Alpha value for the character duration loss. + binary_align_loss_alpha (float): Alpha value for the binary alignment loss. + binary_loss_warmup_epochs (int): Number of warm-up epochs for the binary loss. + disc_loss_alpha (float): Alpha value for the discriminator loss. + gen_loss_alpha (float): Alpha value for the generator loss. + feat_loss_alpha (float): Alpha value for the feature loss. + vocoder_mel_loss_alpha (float): Alpha value for the vocoder mel loss. + multi_scale_stft_loss_alpha (float): Alpha value for the multi-scale STFT loss. + multi_scale_stft_loss_params (dict): Parameters for the multi-scale STFT loss. + return_wav (bool): Whether to return audio waveforms. + use_weighted_sampler (bool): Whether to use a weighted sampler. + weighted_sampler_attrs (dict): Attributes for the weighted sampler. + weighted_sampler_multipliers (dict): Multipliers for the weighted sampler. + r (int): Value for the `r` override. + compute_f0 (bool): Whether to compute F0 values. + f0_cache_path (str): Path to the F0 cache. + attn_prior_cache_path (str): Path to the attention prior cache. + num_speakers (int): Number of speakers. + use_speaker_embedding (bool): Whether to use speaker embedding. + speakers_file (str): Path to the speaker file. + speaker_embedding_channels (int): Number of channels for the speaker embedding. + language_ids_file (str): Path to the language IDs file. + """ + + model: str = "delightful_tts" + + # model specific params + audio: DelightfulTtsAudioConfig = field(default_factory=DelightfulTtsAudioConfig) + model_args: DelightfulTtsArgs = field(default_factory=DelightfulTtsArgs) + use_attn_priors: bool = True + + # vocoder + vocoder: VocoderConfig = field(default_factory=VocoderConfig) + init_discriminator: bool = True + + # optimizer + steps_to_start_discriminator: int = 200000 + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # acoustic model loss params + ssim_loss_alpha: float = 1.0 + mel_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 1.0 + energy_loss_alpha: float = 1.0 + u_prosody_loss_alpha: float = 0.5 + p_prosody_loss_alpha: float = 0.5 + dur_loss_alpha: float = 1.0 + char_dur_loss_alpha: float = 0.01 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 10 + + # vocoder loss params + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + vocoder_mel_loss_alpha: float = 10.0 + multi_scale_stft_loss_alpha: float = 2.5 + multi_scale_stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # data loader params + return_wav: bool = True + use_weighted_sampler: bool = False + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + attn_prior_cache_path: str = None + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: str = None + d_vector_dim: int = None + + # testing + test_sentences: List[List[str]] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fast_pitch_config.py b/TTS/tts/configs/fast_pitch_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d086d26564450c60fa04a7f3a068506f4147d3be --- /dev/null +++ b/TTS/tts/configs/fast_pitch_config.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastPitchConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fast_pitch_config import FastPitchConfig + >>> config = FastPitchConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + """ + + model: str = "fast_pitch" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fast_speech_config.py b/TTS/tts/configs/fast_speech_config.py new file mode 100644 index 0000000000000000000000000000000000000000..af6c2db6faf55ee2b15047fff86281d42dab1b87 --- /dev/null +++ b/TTS/tts/configs/fast_speech_config.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class FastSpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as FastSpeech model. + + Example: + + >>> from TTS.tts.configs.fast_speech_config import FastSpeechConfig + >>> config = FastSpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastSpeechArgs` for more details. Defaults to `FastSpeechArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "fast_speech" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=False)) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 1.0 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/fastspeech2_config.py b/TTS/tts/configs/fastspeech2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d179617fb034fff269355ce7e3d78b67db90aacd --- /dev/null +++ b/TTS/tts/configs/fastspeech2_config.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class Fastspeech2Config(BaseTTSConfig): + """Configure `ForwardTTS` as FastPitch model. + + Example: + + >>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config + >>> config = FastSpeech2Config() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `Adam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + pitch_loss_alpha (float): + Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0. + + energy_loss_alpha (float): + Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0. + + binary_align_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + + # dataset configs + compute_f0(bool): + Compute pitch. defaults to True + + f0_cache_path(str): + pith cache path. defaults to None + + # dataset configs + compute_energy(bool): + Compute energy. defaults to True + + energy_cache_path(str): + energy cache path. defaults to None + """ + + model: str = "fastspeech2" + base_model: str = "forward_tts" + + # model specific params + model_args: ForwardTTSArgs = field(default_factory=lambda: ForwardTTSArgs(use_pitch=True, use_energy=True)) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "mse" + duration_loss_type: str = "mse" + use_ssim_loss: bool = True + ssim_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + pitch_loss_alpha: float = 0.1 + energy_loss_alpha: float = 0.1 + dur_loss_alpha: float = 0.1 + binary_align_loss_alpha: float = 0.1 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = True + f0_cache_path: str = None + + # dataset configs + compute_energy: bool = True + energy_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f42f3e5a510bacf1b2312ccea7d46201bbcb774f --- /dev/null +++ b/TTS/tts/configs/glow_tts_config.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class GlowTTSConfig(BaseTTSConfig): + """Defines parameters for GlowTTS model. + + Example: + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> config = GlowTTSConfig() + + Args: + model(str): + Model name used for selecting the right model at initialization. Defaults to `glow_tts`. + encoder_type (str): + Type of the encoder used by the model. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `rel_pos_transformers`. + encoder_params (dict): + Parameters used to define the encoder network. Look at `TTS.tts.layers.glow_tts.encoder` for more details. + Defaults to `{"kernel_size": 3, "dropout_p": 0.1, "num_layers": 6, "num_heads": 2, "hidden_channels_ffn": 768}` + use_encoder_prenet (bool): + enable / disable the use of a prenet for the encoder. Defaults to True. + hidden_channels_enc (int): + Number of base hidden channels used by the encoder network. It defines the input and the output channel sizes, + and for some encoder types internal hidden channels sizes too. Defaults to 192. + hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 192 as in the original work. + hidden_channels_dp (int): + Number of layer channels of the duration predictor network. Defaults to 256 as in the original work. + mean_only (bool): + If true predict only the mean values by the decoder flow. Defaults to True. + out_channels (int): + Number of channels of the model output tensor. Defaults to 80. + num_flow_blocks_dec (int): + Number of decoder blocks. Defaults to 12. + inference_noise_scale (float): + Noise scale used at inference. Defaults to 0.33. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_block_layers (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate for decoder. Defaults to 0.1. + num_speaker (int): + Number of speaker to define the size of speaker embedding layer. Defaults to 0. + c_in_channels (int): + Number of speaker embedding channels. It is set to 512 if embeddings are learned. Defaults to 0. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + mean_only (bool): + If True, encoder only computes mean value and uses constant variance for each time step. Defaults to true. + encoder_type (str): + Encoder module type. Possible values are`["rel_pos_transformer", "gated_conv", "residual_conv_bn", "time_depth_separable"]` + Check `TTS.tts.layers.glow_tts.encoder` for more details. Defaults to `rel_pos_transformers` as in the original paper. + encoder_params (dict): + Encoder module parameters. Defaults to None. + d_vector_dim (int): + Channels of external speaker embedding vectors. Defaults to 0. + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + style_wav_for_test (str): + Path to the wav file used for changing the style of the speech. Defaults to None. + inference_noise_scale (float): + Variance used for sampling the random noise added to the decoder's input at inference. Defaults to 0.0. + length_scale (float): + Multiply the predicted durations with this value to change the speech speed. Defaults to 1. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + noam_schedule (bool): + enable / disable the use of Noam LR scheduler. Defaults to False. + warmup_steps (int): + Number of warm-up steps for the Noam scheduler. Defaults 4000. + lr (float): + Initial learning rate. Defaults to `1e-3`. + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "glow_tts" + + # model params + num_chars: int = None + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + } + ) + use_encoder_prenet: bool = True + hidden_channels_enc: int = 192 + hidden_channels_dec: int = 192 + hidden_channels_dp: int = 256 + dropout_p_dp: float = 0.1 + dropout_p_dec: float = 0.05 + mean_only: bool = True + out_channels: int = 80 + num_flow_blocks_dec: int = 12 + inference_noise_scale: float = 0.33 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_block_layers: int = 4 + num_speakers: int = 0 + c_in_channels: int = 0 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + encoder_type: str = "rel_pos_transformer" + encoder_params: dict = field( + default_factory=lambda: { + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + "input_length": None, + } + ) + d_vector_dim: int = 0 + + # training params + data_dep_init_steps: int = 10 + + # inference params + style_wav_for_test: str = None + inference_noise_scale: float = 0.0 + length_scale: float = 1.0 + + # multi-speaker settings + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_file: str = False + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + grad_clip: float = 5.0 + lr: float = 1e-3 + + # overrides + min_seq_len: int = 3 + max_seq_len: int = 500 + r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it. + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/neuralhmm_tts_config.py b/TTS/tts/configs/neuralhmm_tts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..50f72847ed3e1c7089915ef8fd77ae5775c5b260 --- /dev/null +++ b/TTS/tts/configs/neuralhmm_tts_config.py @@ -0,0 +1,170 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class NeuralhmmTTSConfig(BaseTTSConfig): + """ + Define parameters for Neural HMM TTS model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. + mel_statistics_parameter_path (str): + Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics. + Defaults to None. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + state_per_phone (int): + Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2. + encoder_in_out_features (int): + Channels of encoder input and character embedding tensors. Defaults to 512. + encoder_n_convolutions (int): + Number of convolution layers in the encoder. Defaults to 3. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + ar_order (int): + Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + sampling_temp (float): + Variation added to the sample from the latent space of neural HMM. Defaults to 0.334. + deterministic_transition (bool): + deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + duration_threshold (float): + Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate. + use_grad_checkpointing (bool): + Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True. + max_sampling_time (int): + Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dim (int): + Dimension of the Prenet. Defaults to 256. + prenet_n_layers (int): + Number of layers in the Prenet. Defaults to 2. + prenet_dropout (float): + Dropout rate of the Prenet. Defaults to 0.5. + prenet_dropout_at_inference (bool): + Use dropout at inference time. Defaults to False. + memory_rnn_dim (int): + Dimension of the memory LSTM to process the prenet output. Defaults to 1024. + outputnet_size (list[int]): + Size of the output network inside the neural HMM. Defaults to [1024]. + flat_start_params (dict): + Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`. + It will be recomputed when you pass the dataset. + std_floor (float): + Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01. + It is called `variance flooring` in standard HMM literature. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "NeuralHMM_TTS" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0 + deterministic_transition: bool = True + duration_threshold: float = 0.43 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = True + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.001 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. + + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/TTS/tts/configs/overflow_config.py b/TTS/tts/configs/overflow_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3e5548b8f62f76c88acca85d19e2cee8687ebd --- /dev/null +++ b/TTS/tts/configs/overflow_config.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class OverflowConfig(BaseTTSConfig): # The classname has to be camel case + """ + Define parameters for OverFlow model. + + Example: + + >>> from TTS.tts.configs.overflow_config import OverflowConfig + >>> config = OverflowConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Overflow`. + run_eval_steps (int): + Run evalulation epoch after N steps. If None, waits until training epoch is completed. Defaults to None. + save_step (int): + Save local checkpoint every save_step steps. Defaults to 500. + plot_step (int): + Plot training stats on the logger every plot_step steps. Defaults to 1. + model_param_stats (bool): + Log model parameters stats on the logger dashboard. Defaults to False. + force_generate_statistics (bool): + Force generate mel normalization statistics. Defaults to False. + mel_statistics_parameter_path (str): + Path to the mel normalization statistics.If the model doesn't finds a file there it will generate statistics. + Defaults to None. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + state_per_phone (int): + Generates N states per phone. Similar, to `add_blank` parameter in GlowTTS but in Overflow it is upsampled by model's encoder. Defaults to 2. + encoder_in_out_features (int): + Channels of encoder input and character embedding tensors. Defaults to 512. + encoder_n_convolutions (int): + Number of convolution layers in the encoder. Defaults to 3. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + ar_order (int): + Autoregressive order of the model. Defaults to 1. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + sampling_temp (float): + Variation added to the sample from the latent space of neural HMM. Defaults to 0.334. + deterministic_transition (bool): + deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + duration_threshold (float): + Threshold for duration quantiles. Defaults to 0.55. Tune this to change the speaking rate of the synthesis, where lower values defines a slower speaking rate and higher values defines a faster speaking rate. + use_grad_checkpointing (bool): + Use gradient checkpointing to save memory. In a multi-GPU setting currently pytorch does not supports gradient checkpoint inside a loop so we will have to turn it off then.Adjust depending on whatever get more batch size either by using a single GPU or multi-GPU. Defaults to True. + max_sampling_time (int): + Maximum sampling time while synthesising latents from neural HMM. Defaults to 1000. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dim (int): + Dimension of the Prenet. Defaults to 256. + prenet_n_layers (int): + Number of layers in the Prenet. Defaults to 2. + prenet_dropout (float): + Dropout rate of the Prenet. Defaults to 0.5. + prenet_dropout_at_inference (bool): + Use dropout at inference time. Defaults to False. + memory_rnn_dim (int): + Dimension of the memory LSTM to process the prenet output. Defaults to 1024. + outputnet_size (list[int]): + Size of the output network inside the neural HMM. Defaults to [1024]. + flat_start_params (dict): + Parameters for the flat start initialization of the neural HMM. Defaults to `{"mean": 0.0, "std": 1.0, "transition_p": 0.14}`. + It will be recomputed when you pass the dataset. + std_floor (float): + Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. Defaults to 0.01. + It is called `variance flooring` in standard HMM literature. + hidden_channels_dec (int): + Number of base hidden channels used by the decoder WaveNet network. Defaults to 150. + kernel_size_dec (int): + Decoder kernel size. Defaults to 5 + dilation_rate (int): + Rate to increase dilation by each layer in a decoder block. Defaults to 1. + num_flow_blocks_dec (int): + Number of decoder layers in each decoder block. Defaults to 4. + dropout_p_dec (float): + Dropout rate of the decoder. Defaults to 0.05. + num_splits (int): + Number of split levels in inversible conv1x1 operation. Defaults to 4. + num_squeeze (int): + Number of squeeze levels. When squeezing channels increases and time steps reduces by the factor + 'num_squeeze'. Defaults to 2. + sigmoid_scale (bool): + enable/disable sigmoid scaling in decoder. Defaults to False. + c_in_channels (int): + Unused parameter from GlowTTS's decoder. Defaults to 0. + optimizer (str): + Optimizer to use for training. Defaults to `adam`. + optimizer_params (dict): + Parameters for the optimizer. Defaults to `{"weight_decay": 1e-6}`. + grad_clip (float): + Gradient clipping threshold. Defaults to 40_000. + lr (float): + Learning rate. Defaults to 1e-3. + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `None`. + min_seq_len (int): + Minimum input sequence length to be used at training. + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "Overflow" + + # Training and Checkpoint configs + run_eval_steps: int = 100 + save_step: int = 500 + plot_step: int = 1 + model_param_stats: bool = False + + # data parameters + force_generate_statistics: bool = False + mel_statistics_parameter_path: str = None + + # Encoder parameters + num_chars: int = None + state_per_phone: int = 2 + encoder_in_out_features: int = 512 + encoder_n_convolutions: int = 3 + + # HMM parameters + out_channels: int = 80 + ar_order: int = 1 + sampling_temp: float = 0.334 + deterministic_transition: bool = True + duration_threshold: float = 0.55 + use_grad_checkpointing: bool = True + max_sampling_time: int = 1000 + + ## Prenet parameters + prenet_type: str = "original" + prenet_dim: int = 256 + prenet_n_layers: int = 2 + prenet_dropout: float = 0.5 + prenet_dropout_at_inference: bool = False + memory_rnn_dim: int = 1024 + + ## Outputnet parameters + outputnet_size: List[int] = field(default_factory=lambda: [1024]) + flat_start_params: dict = field(default_factory=lambda: {"mean": 0.0, "std": 1.0, "transition_p": 0.14}) + std_floor: float = 0.01 + + # Decoder parameters + hidden_channels_dec: int = 150 + kernel_size_dec: int = 5 + dilation_rate: int = 1 + num_flow_blocks_dec: int = 12 + num_block_layers: int = 4 + dropout_p_dec: float = 0.05 + num_splits: int = 4 + num_squeeze: int = 2 + sigmoid_scale: bool = False + c_in_channels: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"weight_decay": 1e-6}) + grad_clip: float = 40000.0 + lr: float = 1e-3 + lr_scheduler: str = None + + # overrides + min_text_len: int = 10 + max_text_len: int = 500 + min_audio_len: int = 512 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "Be a voice, not an echo.", + ] + ) + + # Extra needed config + r: int = 1 + use_d_vector_file: bool = False + use_speaker_embedding: bool = False + + def check_values(self): + """Validate the hyperparameters. + + Raises: + AssertionError: when the parameters network is not defined + AssertionError: transition probability is not between 0 and 1 + """ + assert self.ar_order > 0, "AR order must be greater than 0 it is an autoregressive model." + assert ( + len(self.outputnet_size) >= 1 + ), f"Parameter Network must have atleast one layer check the config file for parameter network. Provided: {self.parameternetwork}" + assert ( + 0 < self.flat_start_params["transition_p"] < 1 + ), f"Transition probability must be between 0 and 1. Provided: {self.flat_start_params['transition_p']}" diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..bf17322c190bb234d4e27c6196e53b276fb5f09d --- /dev/null +++ b/TTS/tts/configs/shared_configs.py @@ -0,0 +1,344 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from coqpit import Coqpit, check_argument + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class GSTConfig(Coqpit): + """Defines the Global Style Token Module + + Args: + gst_style_input_wav (str): + Path to the wav file used to define the style of the output speech at inference. Defaults to None. + + gst_style_input_weights (dict): + Defines the weights for each style token used at inference. Defaults to None. + + gst_embedding_dim (int): + Defines the size of the GST embedding vector dimensions. Defaults to 256. + + gst_num_heads (int): + Number of attention heads used by the multi-head attention. Defaults to 4. + + gst_num_style_tokens (int): + Number of style token vectors. Defaults to 10. + """ + + gst_style_input_wav: str = None + gst_style_input_weights: dict = None + gst_embedding_dim: int = 256 + gst_use_speaker_embedding: bool = False + gst_num_heads: int = 4 + gst_num_style_tokens: int = 10 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("gst_style_input_weights", c, restricted=False) + check_argument("gst_style_input_wav", c, restricted=False) + check_argument("gst_embedding_dim", c, restricted=True, min_val=0, max_val=1000) + check_argument("gst_use_speaker_embedding", c, restricted=False) + check_argument("gst_num_heads", c, restricted=True, min_val=2, max_val=10) + check_argument("gst_num_style_tokens", c, restricted=True, min_val=1, max_val=1000) + + +@dataclass +class CapacitronVAEConfig(Coqpit): + """Defines the capacitron VAE Module + Args: + capacitron_capacity (int): + Defines the variational capacity limit of the prosody embeddings. Defaults to 150. + capacitron_VAE_embedding_dim (int): + Defines the size of the Capacitron embedding vector dimension. Defaults to 128. + capacitron_use_text_summary_embeddings (bool): + If True, use a text summary embedding in Capacitron. Defaults to True. + capacitron_text_summary_embedding_dim (int): + Defines the size of the capacitron text embedding vector dimension. Defaults to 128. + capacitron_use_speaker_embedding (bool): + if True use speaker embeddings in Capacitron. Defaults to False. + capacitron_VAE_loss_alpha (float): + Weight for the VAE loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + capacitron_grad_clip (float): + Gradient clipping value for all gradients except beta. Defaults to 5.0 + """ + + capacitron_loss_alpha: int = 1 + capacitron_capacity: int = 150 + capacitron_VAE_embedding_dim: int = 128 + capacitron_use_text_summary_embeddings: bool = True + capacitron_text_summary_embedding_dim: int = 128 + capacitron_use_speaker_embedding: bool = False + capacitron_VAE_loss_alpha: float = 0.25 + capacitron_grad_clip: float = 5.0 + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + super().check_values() + check_argument("capacitron_capacity", c, restricted=True, min_val=10, max_val=500) + check_argument("capacitron_VAE_embedding_dim", c, restricted=True, min_val=16, max_val=1024) + check_argument("capacitron_use_speaker_embedding", c, restricted=False) + check_argument("capacitron_text_summary_embedding_dim", c, restricted=False, min_val=16, max_val=512) + check_argument("capacitron_VAE_loss_alpha", c, restricted=False) + check_argument("capacitron_grad_clip", c, restricted=False) + + +@dataclass +class CharactersConfig(Coqpit): + """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses. + + Args: + characters_class (str): + Defines the class of the characters used. If None, we pick ```Phonemes``` or ```Graphemes``` based on + the configuration. Defaults to None. + + vocab_dict (dict): + Defines the vocabulary dictionary used to encode the characters. Defaults to None. + + pad (str): + characters in place of empty padding. Defaults to None. + + eos (str): + characters showing the end of a sentence. Defaults to None. + + bos (str): + characters showing the beginning of a sentence. Defaults to None. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + characters (str): + character set used by the model. Characters not in this list are ignored when converting input text to + a list of sequence IDs. Defaults to None. + + punctuations (str): + characters considered as punctuation as parsing the input sentence. Defaults to None. + + phonemes (str): + characters considered as parsing phonemes. This is only for backwards compat. Use `characters` for new + models. Defaults to None. + + is_unique (bool): + remove any duplicate characters in the character lists. It is a bandaid for compatibility with the old + models trained with character lists with duplicates. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + characters_class: str = None + + # using BaseVocabulary + vocab_dict: Dict = None + + # using on BaseCharacters + pad: str = None + eos: str = None + bos: str = None + blank: str = None + characters: str = None + punctuations: str = None + phonemes: str = None + is_unique: bool = True # for backwards compatibility of models trained with char sets with duplicates + is_sorted: bool = True + + +@dataclass +class BaseTTSConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + use_phonemes (bool): + enable / disable phoneme use. + + phonemizer (str): + Name of the phonemizer to use. If set None, the phonemizer will be selected by `phoneme_language`. + Defaults to None. + + phoneme_language (str): + Language code for the phonemizer. You can check the list of supported languages by running + `python TTS/tts/utils/text/phonemizers/__init__.py`. Defaults to None. + + compute_input_seq_cache (bool): + enable / disable precomputation of the phoneme sequences. At the expense of some delay at the beginning of + the training, It allows faster data loader time and precise limitation with `max_seq_len` and + `min_seq_len`. + + text_cleaner (str): + Name of the text cleaner used for cleaning and formatting transcripts. + + enable_eos_bos_chars (bool): + enable / disable the use of eos and bos characters. + + test_senteces_file (str): + Path to a txt file that has sentences used at test time. The file must have a sentence per line. + + phoneme_cache_path (str): + Path to the output folder caching the computed phonemes for each sample. + + characters (CharactersConfig): + Instance of a CharactersConfig class. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_energy (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # phoneme settings + use_phonemes: bool = False + phonemizer: str = None + phoneme_language: str = None + compute_input_seq_cache: bool = False + text_cleaner: str = None + enable_eos_bos_chars: bool = False + test_sentences_file: str = "" + phoneme_cache_path: str = None + # vocabulary parameters + characters: CharactersConfig = None + add_blank: bool = False + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_energy: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = None + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8517dfc478a135978df19f3126313a616c14c2 --- /dev/null +++ b/TTS/tts/configs/speedy_speech_config.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.forward_tts import ForwardTTSArgs + + +@dataclass +class SpeedySpeechConfig(BaseTTSConfig): + """Configure `ForwardTTS` as SpeedySpeech model. + + Example: + + >>> from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig + >>> config = SpeedySpeechConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `speedy_speech`. + + base_model (str): + Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate + the base model rather than searching for the `model` implementation. Defaults to `forward_tts`. + + model_args (Coqpit): + Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. + + data_dep_init_steps (int): + Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses + Activation Normalization that pre-computes normalization stats at the beginning and use the same values + for the rest. Defaults to 10. + + speakers_file (str): + Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to + speaker names. Defaults to `None`. + + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Dimension of the external speaker embeddings. Defaults to 0. + + optimizer (str): + Name of the model optimizer. Defaults to `RAdam`. + + optimizer_params (dict): + Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`. + + lr_scheduler (str): + Name of the learning rate scheduler. Defaults to `Noam`. + + lr_scheduler_params (dict): + Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`. + + lr (float): + Initial learning rate. Defaults to `1e-3`. + + grad_clip (float): + Gradient norm clipping value. Defaults to `5.0`. + + spec_loss_type (str): + Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`. + + duration_loss_type (str): + Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`. + + use_ssim_loss (bool): + Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True. + + wd (float): + Weight decay coefficient. Defaults to `1e-7`. + + ssim_loss_alpha (float): + Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0. + + dur_loss_alpha (float): + Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0. + + spec_loss_alpha (float): + Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0. + + binary_loss_alpha (float): + Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0. + + binary_loss_warmup_epochs (float): + Number of epochs to gradually increase the binary loss impact. Defaults to 150. + + min_seq_len (int): + Minimum input sequence length to be used at training. + + max_seq_len (int): + Maximum input sequence length to be used at training. Larger values result in more VRAM usage. + """ + + model: str = "speedy_speech" + base_model: str = "forward_tts" + + # set model args as SpeedySpeech + model_args: ForwardTTSArgs = field( + default_factory=lambda: ForwardTTSArgs( + use_pitch=False, + encoder_type="residual_conv_bn", + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13, + }, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + out_channels=80, + hidden_channels=128, + positional_encoding=True, + detach_duration_predictor=True, + ) + ) + + # multi-speaker settings + num_speakers: int = 0 + speakers_file: str = None + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = 0 + + # optimizer parameters + optimizer: str = "Adam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + + # loss params + spec_loss_type: str = "l1" + duration_loss_type: str = "huber" + use_ssim_loss: bool = False + ssim_loss_alpha: float = 1.0 + dur_loss_alpha: float = 1.0 + spec_loss_alpha: float = 1.0 + aligner_loss_alpha: float = 1.0 + binary_align_loss_alpha: float = 0.3 + binary_loss_warmup_epochs: int = 150 + + # overrides + min_seq_len: int = 13 + max_seq_len: int = 200 + r: int = 1 # DO NOT CHANGE + + # dataset configs + compute_f0: bool = False + f0_cache_path: str = None + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def __post_init__(self): + # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there. + if self.num_speakers > 0: + self.model_args.num_speakers = self.num_speakers + + # speaker embedding settings + if self.use_speaker_embedding: + self.model_args.use_speaker_embedding = True + if self.speakers_file: + self.model_args.speakers_file = self.speakers_file + + # d-vector settings + if self.use_d_vector_file: + self.model_args.use_d_vector_file = True + if self.d_vector_dim is not None and self.d_vector_dim > 0: + self.model_args.d_vector_dim = self.d_vector_dim + if self.d_vector_file: + self.model_args.d_vector_file = self.d_vector_file diff --git a/TTS/tts/configs/tacotron2_config.py b/TTS/tts/configs/tacotron2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..95b65202218cf3aa0dd70c8d8cd55a3f913ed308 --- /dev/null +++ b/TTS/tts/configs/tacotron2_config.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from TTS.tts.configs.tacotron_config import TacotronConfig + + +@dataclass +class Tacotron2Config(TacotronConfig): + """Defines parameters for Tacotron2 based models. + + Example: + + >>> from TTS.tts.configs.tacotron2_config import Tacotron2Config + >>> config = Tacotron2Config() + + Check `TacotronConfig` for argument descriptions. + """ + + model: str = "tacotron2" + out_channels: int = 80 + encoder_in_features: int = 512 + decoder_in_features: int = 512 diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py new file mode 100644 index 0000000000000000000000000000000000000000..350b5ea99633569d6977851875d5d8d83175ac36 --- /dev/null +++ b/TTS/tts/configs/tacotron_config.py @@ -0,0 +1,235 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig + + +@dataclass +class TacotronConfig(BaseTTSConfig): + """Defines parameters for Tacotron based models. + + Example: + + >>> from TTS.tts.configs.tacotron_config import TacotronConfig + >>> config = TacotronConfig() + + Args: + model (str): + Model name used to select the right model class to initilize. Defaults to `Tacotron`. + use_gst (bool): + enable / disable the use of Global Style Token modules. Defaults to False. + gst (GSTConfig): + Instance of `GSTConfig` class. + gst_style_input (str): + Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and + this is not defined, the model uses a zero vector as an input. Defaults to None. + use_capacitron_vae (bool): + enable / disable the use of Capacitron modules. Defaults to False. + capacitron_vae (CapacitronConfig): + Instance of `CapacitronConfig` class. + num_chars (int): + Number of characters used by the model. It must be defined before initializing the model. Defaults to None. + num_speakers (int): + Number of speakers for multi-speaker models. Defaults to 1. + r (int): + Initial number of output frames that the decoder computed per iteration. Larger values makes training and inference + faster but reduces the quality of the output frames. This must be equal to the largest `r` value used in + `gradual_training` schedule. Defaults to 1. + gradual_training (List[List]): + Parameters for the gradual training schedule. It is in the form `[[a, b, c], [d ,e ,f] ..]` where `a` is + the step number to start using the rest of the values, `b` is the `r` value and `c` is the batch size. + If sets None, no gradual training is used. Defaults to None. + memory_size (int): + Defines the number of previous frames used by the Prenet. If set to < 0, then it uses only the last frame. + Defaults to -1. + prenet_type (str): + `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the + Prenet. Defaults to `original`. + prenet_dropout (bool): + enables / disables the use of dropout in the Prenet. Defaults to True. + prenet_dropout_at_inference (bool): + enable / disable the use of dropout in the Prenet at the inference time. Defaults to False. + stopnet (bool): + enable /disable the Stopnet that predicts the end of the decoder sequence. Defaults to True. + stopnet_pos_weight (float): + Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with + datasets with longer sentences. Defaults to 0.2. + max_decoder_steps (int): + Max number of steps allowed for the decoder. Defaults to 50. + encoder_in_features (int): + Channels of encoder input and character embedding tensors. Defaults to 256. + decoder_in_features (int): + Channels of decoder input and encoder output tensors. Defaults to 256. + out_channels (int): + Channels of the final model output. It must match the spectragram size. Defaults to 80. + separate_stopnet (bool): + Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True. + attention_type (str): + attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'. + attention_heads (int): + Number of attention heads for GMM attention. Defaults to 5. + windowing (bool): + It especially useful at inference to keep attention alignment diagonal. Defaults to False. + use_forward_attn (bool): + It is only valid if ```attn_type``` is ```original```. Defaults to False. + forward_attn_mask (bool): + enable/disable extra masking over forward attention. It is useful at inference to prevent + possible attention failures. Defaults to False. + transition_agent (bool): + enable/disable transition agent in forward attention. Defaults to False. + location_attn (bool): + enable/disable location sensitive attention as in the original Tacotron2 paper. + It is only valid if ```attn_type``` is ```original```. Defaults to True. + bidirectional_decoder (bool): + enable/disable bidirectional decoding. Defaults to False. + double_decoder_consistency (bool): + enable/disable double decoder consistency. Defaults to False. + ddc_r (int): + reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this + as a multiple of the `r` value. Defaults to 6. + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + use_speaker_embedding (bool): + enable / disable using speaker embeddings for multi-speaker models. If set True, the model is + in the multi-speaker mode. Defaults to False. + use_d_vector_file (bool): + enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False. + d_vector_file (str): + Path to the file including pre-computed speaker embeddings. Defaults to None. + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to `RAdam`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to `NoamLR`. + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + lr (float): + Initial learning rate. Defaults to `1e-4`. + wd (float): + Weight decay coefficient. Defaults to `1e-6`. + grad_clip (float): + Gradient clipping threshold. Defaults to `5`. + seq_len_norm (bool): + enable / disable the sequnce length normalization in the loss functions. If set True, loss of a sample + is divided by the sequence length. Defaults to False. + loss_masking (bool): + enable / disable masking the paddings of the samples in loss computation. Defaults to True. + decoder_loss_alpha (float): + Weight for the decoder loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_loss_alpha (float): + Weight for the postnet loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_diff_spec_alpha (float): + Weight for the postnet differential loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_diff_spec_alpha (float): + + Weight for the decoder differential loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + decoder_ssim_alpha (float): + Weight for the decoder SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + postnet_ssim_alpha (float): + Weight for the postnet SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the + corresponding loss function. Defaults to 0.25 + ga_alpha (float): + Weight for the guided attention loss. If set less than or equal to zero, it disables the corresponding loss + function. Defaults to 5. + """ + + model: str = "tacotron" + # model_params: TacotronArgs = field(default_factory=lambda: TacotronArgs()) + use_gst: bool = False + gst: GSTConfig = None + gst_style_input: str = None + + use_capacitron_vae: bool = False + capacitron_vae: CapacitronVAEConfig = None + + # model specific params + num_speakers: int = 1 + num_chars: int = 0 + r: int = 2 + gradual_training: List[List[int]] = None + memory_size: int = -1 + prenet_type: str = "original" + prenet_dropout: bool = True + prenet_dropout_at_inference: bool = False + stopnet: bool = True + separate_stopnet: bool = True + stopnet_pos_weight: float = 0.2 + max_decoder_steps: int = 10000 + encoder_in_features: int = 256 + decoder_in_features: int = 256 + decoder_output_dim: int = 80 + out_channels: int = 513 + + # attention layers + attention_type: str = "original" + attention_heads: int = None + attention_norm: str = "sigmoid" + attention_win: bool = False + windowing: bool = False + use_forward_attn: bool = False + forward_attn_mask: bool = False + transition_agent: bool = False + location_attn: bool = True + + # advance methods + bidirectional_decoder: bool = False + double_decoder_consistency: bool = False + ddc_r: int = 6 + + # multi-speaker settings + speakers_file: str = None + use_speaker_embedding: bool = False + speaker_embedding_dim: int = 512 + use_d_vector_file: bool = False + d_vector_file: str = False + d_vector_dim: int = None + + # optimizer parameters + optimizer: str = "RAdam" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6}) + lr_scheduler: str = "NoamLR" + lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000}) + lr: float = 1e-4 + grad_clip: float = 5.0 + seq_len_norm: bool = False + loss_masking: bool = True + + # loss params + decoder_loss_alpha: float = 0.25 + postnet_loss_alpha: float = 0.25 + postnet_diff_spec_alpha: float = 0.25 + decoder_diff_spec_alpha: float = 0.25 + decoder_ssim_alpha: float = 0.25 + postnet_ssim_alpha: float = 0.25 + ga_alpha: float = 5.0 + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) + + def check_values(self): + if self.gradual_training: + assert ( + self.gradual_training[0][1] == self.r + ), f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. {self.gradual_training[0][1]} vs {self.r}" + if self.model == "tacotron" and self.audio is not None: + assert self.out_channels == ( + self.audio.fft_size // 2 + 1 + ), f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}" + if self.model == "tacotron2" and self.audio is not None: + assert self.out_channels == self.audio.num_mels diff --git a/TTS/tts/configs/tortoise_config.py b/TTS/tts/configs/tortoise_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d60e43d71280bfa085988e31a52acfeef015c5f0 --- /dev/null +++ b/TTS/tts/configs/tortoise_config.py @@ -0,0 +1,87 @@ +from dataclasses import dataclass, field + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.tortoise import TortoiseArgs, TortoiseAudioConfig + + +@dataclass +class TortoiseConfig(BaseTTSConfig): + """Defines parameters for Tortoise TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (TortoiseArgs): + Model architecture arguments. Defaults to `TortoiseArgs()`. + + audio (TortoiseAudioConfig): + Audio processing configuration. Defaults to `TortoiseAudioConfig()`. + + model_dir (str): + Path to the folder that has all the Tortoise models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values makes predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. + + reperation_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. + + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + cond_free_k (float): + Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + Formula is: output=cond_present_output*(cond_free_k+1)-cond_absenct_output*cond_free_k. Defaults to `2.0`. + + diffusion_temperature (float): + Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + Defaults to `1.0`. + + num_autoregressive_samples (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + Defaults to `16`. + + diffusion_iterations (int): + Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively refine + the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, + however. Defaults to `30`. + + sampler (str): + Diffusion sampler to be used. `ddim` or `dpm++2m`. Defaults to `ddim`. + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> config = TortoiseConfig() + """ + + model: str = "tortoise" + # model specific params + model_args: TortoiseArgs = field(default_factory=TortoiseArgs) + audio: TortoiseAudioConfig = field(default_factory=TortoiseAudioConfig) + model_dir: str = None + + # settings + temperature: float = 0.2 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_p: float = 0.8 + cond_free_k: float = 2.0 + diffusion_temperature: float = 1.0 + + # inference params + num_autoregressive_samples: int = 16 + diffusion_iterations: int = 30 + sampler: str = "ddim" diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0242bf131a25d6b2cef7a297a3c32b283f908a --- /dev/null +++ b/TTS/tts/configs/vits_config.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.vits import VitsArgs, VitsAudioConfig + + +@dataclass +class VitsConfig(BaseTTSConfig): + """Defines parameters for VITS End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (VitsArgs): + Model architecture arguments. Defaults to `VitsArgs()`. + + audio (VitsAudioConfig): + Audio processing configuration. Defaults to `VitsAudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> config = VitsConfig() + """ + + model: str = "vits" + # model specific params + model_args: VitsArgs = field(default_factory=VitsArgs) + audio: VitsAudioConfig = field(default_factory=VitsAudioConfig) + + # optimizer + grad_clip: List[float] = field(default_factory=lambda: [1000, 1000]) + lr_gen: float = 0.0002 + lr_disc: float = 0.0002 + lr_scheduler_gen: str = "ExponentialLR" + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1}) + scheduler_after_epoch: bool = True + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01}) + + # loss params + kl_loss_alpha: float = 1.0 + disc_loss_alpha: float = 1.0 + gen_loss_alpha: float = 1.0 + feat_loss_alpha: float = 1.0 + mel_loss_alpha: float = 45.0 + dur_loss_alpha: float = 1.0 + speaker_encoder_loss_alpha: float = 1.0 + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # testing + test_sentences: List[List] = field( + default_factory=lambda: [ + ["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."], + ["Be a voice, not an echo."], + ["I'm sorry Dave. I'm afraid I can't do that."], + ["This cake is great. It's so delicious and moist."], + ["Prior to November 22, 1963."], + ] + ) + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + use_speaker_embedding: bool = False + speakers_file: str = None + speaker_embedding_channels: int = 256 + language_ids_file: str = None + use_language_embedding: bool = False + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/tts/configs/xtts_config.py b/TTS/tts/configs/xtts_config.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf048e1ab7984e0cc0c7914cf8fd991fd62ef1f --- /dev/null +++ b/TTS/tts/configs/xtts_config.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.tts.configs.shared_configs import BaseTTSConfig +from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig + + +@dataclass +class XttsConfig(BaseTTSConfig): + """Defines parameters for XTTS TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (XttsArgs): + Model architecture arguments. Defaults to `XttsArgs()`. + + audio (XttsAudioConfig): + Audio processing configuration. Defaults to `XttsAudioConfig()`. + + model_dir (str): + Path to the folder that has all the XTTS models. Defaults to None. + + temperature (float): + Temperature for the autoregressive model inference. Larger values makes predictions more creative sacrificing stability. Defaults to `0.2`. + + length_penalty (float): + Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, + which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), + length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. + + repetition_penalty (float): + The parameter for repetition penalty. 1.0 means no penalty. Defaults to `2.0`. + + top_p (float): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + Defaults to `0.8`. + + num_gpt_outputs (int): + Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As XTTS is a probabilistic model, more samples means a higher probability of creating something "great". + Defaults to `16`. + + gpt_cond_len (int): + Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`. + + gpt_cond_chunk_len (int): + Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the + latents are averaged. Chunking improves the stability. It must be <= gpt_cond_len. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to `4`. + + max_ref_len (int): + Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`. + + sound_norm_refs (bool): + Whether to normalize the conditioning audio. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> config = XttsConfig() + """ + + model: str = "xtts" + # model specific params + model_args: XttsArgs = field(default_factory=XttsArgs) + audio: XttsAudioConfig = field(default_factory=XttsAudioConfig) + model_dir: str = None + languages: List[str] = field( + default_factory=lambda: [ + "en", + "es", + "fr", + "de", + "it", + "pt", + "pl", + "tr", + "ru", + "nl", + "cs", + "ar", + "zh-cn", + "hu", + "ko", + "ja", + "hi", + ] + ) + + # inference params + temperature: float = 0.85 + length_penalty: float = 1.0 + repetition_penalty: float = 2.0 + top_k: int = 50 + top_p: float = 0.85 + num_gpt_outputs: int = 1 + + # cloning + gpt_cond_len: int = 12 + gpt_cond_chunk_len: int = 4 + max_ref_len: int = 10 + sound_norm_refs: bool = False diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f2cb2e370c7ac57eeb1c27a9ffabbd0df32d7f --- /dev/null +++ b/TTS/tts/datasets/__init__.py @@ -0,0 +1,183 @@ +import logging +import os +import sys +from collections import Counter +from pathlib import Path +from typing import Callable, Dict, List, Tuple, Union + +import numpy as np + +from TTS.tts.datasets.dataset import * +from TTS.tts.datasets.formatters import * + +logger = logging.getLogger(__name__) + + +def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): + """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. + + Args: + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + """ + speakers = [item["speaker_name"] for item in items] + is_multi_speaker = len(set(speakers)) > 1 + if eval_split_size > 1: + eval_split_size = int(eval_split_size) + else: + if eval_split_max_size: + eval_split_size = min(eval_split_max_size, int(len(items) * eval_split_size)) + else: + eval_split_size = int(len(items) * eval_split_size) + + assert ( + eval_split_size > 0 + ), " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format( + 1 / len(items) + ) + np.random.seed(0) + np.random.shuffle(items) + if is_multi_speaker: + items_eval = [] + speakers = [item["speaker_name"] for item in items] + speaker_counter = Counter(speakers) + while len(items_eval) < eval_split_size: + item_idx = np.random.randint(0, len(items)) + speaker_to_be_removed = items[item_idx]["speaker_name"] + if speaker_counter[speaker_to_be_removed] > 1: + items_eval.append(items[item_idx]) + speaker_counter[speaker_to_be_removed] -= 1 + del items[item_idx] + return items_eval, items + return items[:eval_split_size], items[eval_split_size:] + + +def add_extra_keys(metadata, language, dataset_name): + for item in metadata: + # add language name + item["language"] = language + # add unique audio name + relfilepath = os.path.splitext(os.path.relpath(item["audio_file"], item["root_path"]))[0] + audio_unique_name = f"{dataset_name}#{relfilepath}" + item["audio_unique_name"] = audio_unique_name + return metadata + + +def load_tts_samples( + datasets: Union[List[Dict], Dict], + eval_split=True, + formatter: Callable = None, + eval_split_max_size=None, + eval_split_size=0.01, +) -> Tuple[List[List], List[List]]: + """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. + If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based + on the dataset name. + + Args: + datasets (List[Dict], Dict): A list of datasets or a single dataset dictionary. If multiple datasets are + in the list, they are all merged. + + eval_split (bool, optional): If true, create a evaluation split. If an eval split provided explicitly, generate + an eval split automatically. Defaults to True. + + formatter (Callable, optional): The preprocessing function to be applied to create the list of samples. It + must take the root_path and the meta_file name and return a list of samples in the format of + `[[text, audio_path, speaker_id], ...]]`. See the available formatters in `TTS.tts.dataset.formatter` as + example. Defaults to None. + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + Returns: + Tuple[List[List], List[List]: training and evaluation splits of the dataset. + """ + meta_data_train_all = [] + meta_data_eval_all = [] if eval_split else None + if not isinstance(datasets, list): + datasets = [datasets] + for dataset in datasets: + formatter_name = dataset["formatter"] + dataset_name = dataset["dataset_name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] + ignored_speakers = dataset["ignored_speakers"] + language = dataset["language"] + + # setup the right data processor + if formatter is None: + formatter = _get_formatter_by_name(formatter_name) + # load train set + meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers) + assert len(meta_data_train) > 0, f" [!] No training samples found in {root_path}/{meta_file_train}" + + meta_data_train = add_extra_keys(meta_data_train, language, dataset_name) + + logger.info("Found %d files in %s", len(meta_data_train), Path(root_path).resolve()) + # load evaluation split if set + if eval_split: + if meta_file_val: + meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers) + meta_data_eval = add_extra_keys(meta_data_eval, language, dataset_name) + else: + eval_size_per_dataset = eval_split_max_size // len(datasets) if eval_split_max_size else None + meta_data_eval, meta_data_train = split_dataset(meta_data_train, eval_size_per_dataset, eval_split_size) + meta_data_eval_all += meta_data_eval + meta_data_train_all += meta_data_train + # load attention masks for the duration predictor training + if dataset.meta_file_attn_mask: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) + for idx, ins in enumerate(meta_data_train_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_train_all[idx].update({"alignment_file": attn_file}) + if meta_data_eval_all: + for idx, ins in enumerate(meta_data_eval_all): + attn_file = meta_data[ins["audio_file"]].strip() + meta_data_eval_all[idx].update({"alignment_file": attn_file}) + # set none for the next iter + formatter = None + return meta_data_train_all, meta_data_eval_all + + +def load_attention_mask_meta_data(metafile_path): + """Load meta data file created by compute_attention_masks.py""" + with open(metafile_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + meta_data = [] + for line in lines: + wav_file, attn_file = line.split("|") + meta_data.append([wav_file, attn_file]) + return meta_data + + +def _get_formatter_by_name(name): + """Returns the respective preprocessing function.""" + thismodule = sys.modules[__name__] + return getattr(thismodule, name.lower()) + + +def find_unique_chars(data_samples): + texts = "".join(item["text"] for item in data_samples) + chars = set(texts) + lower_chars = filter(lambda c: c.islower(), chars) + chars_force_lower = [c.lower() for c in chars] + chars_force_lower = set(chars_force_lower) + + logger.info("Number of unique characters: %d", len(chars)) + logger.info("Unique characters: %s", "".join(sorted(chars))) + logger.info("Unique lower characters: %s", "".join(sorted(lower_chars))) + logger.info("Unique all forced to lower characters: %s", "".join(sorted(chars_force_lower))) + return chars_force_lower diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..37e3a1779d2b3523847da7a1c827d9309dafe44f --- /dev/null +++ b/TTS/tts/datasets/dataset.py @@ -0,0 +1,995 @@ +import base64 +import collections +import logging +import os +import random +from typing import Any, Optional, Union + +import numpy as np +import numpy.typing as npt +import torch +import torchaudio +import tqdm +from torch.utils.data import Dataset + +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy + +logger = logging.getLogger(__name__) + +# to prevent too many open files error as suggested here +# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 +torch.multiprocessing.set_sharing_strategy("file_system") + + +def _parse_sample(item): + language_name = None + attn_file = None + if len(item) == 5: + text, wav_file, speaker_name, language_name, attn_file = item + elif len(item) == 4: + text, wav_file, speaker_name, language_name = item + elif len(item) == 3: + text, wav_file, speaker_name = item + else: + msg = "Dataset cannot parse the sample." + raise ValueError(msg) + return text, wav_file, speaker_name, language_name, attn_file + + +def noise_augment_audio(wav: npt.NDArray) -> npt.NDArray: + return wav + (1.0 / 32768.0) * np.random.rand(*wav.shape) + + +def string2filename(string: str) -> str: + # generate a safe and reversible filename based on a string + return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore") + + +def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: + """Return the number of samples in the audio file.""" + if not isinstance(audiopath, str): + audiopath = str(audiopath) + extension = audiopath.rpartition(".")[-1].lower() + if extension not in {"mp3", "wav", "flac"}: + msg = f"The audio format {extension} is not supported, please convert the audio files to mp3, flac, or wav format!" + raise RuntimeError(msg) + + try: + return torchaudio.info(audiopath).num_frames + except RuntimeError as e: + msg = f"Failed to decode {audiopath}" + raise RuntimeError(msg) from e + + +class TTSDataset(Dataset): + def __init__( + self, + outputs_per_step: int = 1, + compute_linear_spec: bool = False, + ap: AudioProcessor = None, + samples: Optional[list[dict]] = None, + tokenizer: "TTSTokenizer" = None, + compute_f0: bool = False, + compute_energy: bool = False, + f0_cache_path: Optional[str] = None, + energy_cache_path: Optional[str] = None, + return_wav: bool = False, + batch_group_size: int = 0, + min_text_len: int = 0, + max_text_len: int = float("inf"), + min_audio_len: int = 0, + max_audio_len: int = float("inf"), + phoneme_cache_path: Optional[str] = None, + precompute_num_workers: int = 0, + speaker_id_mapping: Optional[dict] = None, + d_vector_mapping: Optional[dict] = None, + language_id_mapping: Optional[dict] = None, + use_noise_augment: bool = False, + start_by_longest: bool = False, + ) -> None: + """Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs. + + If you need something different, you can subclass and override. + + Args: + ---- + outputs_per_step (int): Number of time frames predicted per step. + + compute_linear_spec (bool): compute linear spectrogram if True. + + ap (TTS.tts.utils.AudioProcessor): Audio processor object. + + samples (list): List of dataset samples. + + tokenizer (TTSTokenizer): tokenizer to convert text to sequence IDs. If None init internally else + use the given. Defaults to None. + + compute_f0 (bool): compute f0 if True. Defaults to False. + + compute_energy (bool): compute energy if True. Defaults to False. + + f0_cache_path (str): Path to store f0 cache. Defaults to None. + + energy_cache_path (str): Path to store energy cache. Defaults to None. + + return_wav (bool): Return the waveform of the sample. Defaults to False. + + batch_group_size (int): Range of batch randomization after sorting + sequences by length. It shuffles each batch with bucketing to gather similar lenght sequences in a + batch. Set 0 to disable. Defaults to 0. + + min_text_len (int): Minimum length of input text to be used. All shorter samples will be ignored. + Defaults to 0. + + max_text_len (int): Maximum length of input text to be used. All longer samples will be ignored. + Defaults to float("inf"). + + min_audio_len (int): Minimum length of input audio to be used. All shorter samples will be ignored. + Defaults to 0. + + max_audio_len (int): Maximum length of input audio to be used. All longer samples will be ignored. + The maximum length in the dataset defines the VRAM used in the training. Hence, pay attention to + this value if you encounter an OOM error in training. Defaults to float("inf"). + + phoneme_cache_path (str): Path to cache computed phonemes. It writes phonemes of each sample to a + separate file. Defaults to None. + + precompute_num_workers (int): Number of workers to precompute features. Defaults to 0. + + speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the + embedding layer. Defaults to None. + + d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None. + + use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False. + + start_by_longest (bool): Start by longest sequence. It is especially useful to check OOM. Defaults to False. + + """ + super().__init__() + self.batch_group_size = batch_group_size + self._samples = samples + self.outputs_per_step = outputs_per_step + self.compute_linear_spec = compute_linear_spec + self.return_wav = return_wav + self.compute_f0 = compute_f0 + self.compute_energy = compute_energy + self.f0_cache_path = f0_cache_path + self.energy_cache_path = energy_cache_path + self.min_audio_len = min_audio_len + self.max_audio_len = max_audio_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len + self.ap = ap + self.phoneme_cache_path = phoneme_cache_path + self.speaker_id_mapping = speaker_id_mapping + self.d_vector_mapping = d_vector_mapping + self.language_id_mapping = language_id_mapping + self.use_noise_augment = use_noise_augment + self.start_by_longest = start_by_longest + + self.rescue_item_idx = 1 + self.pitch_computed = False + self.tokenizer = tokenizer + + if self.tokenizer.use_phonemes: + self.phoneme_dataset = PhonemeDataset( + self.samples, + self.tokenizer, + phoneme_cache_path, + precompute_num_workers=precompute_num_workers, + ) + + if compute_f0: + self.f0_dataset = F0Dataset( + self.samples, + self.ap, + cache_path=f0_cache_path, + precompute_num_workers=precompute_num_workers, + ) + if compute_energy: + self.energy_dataset = EnergyDataset( + self.samples, + self.ap, + cache_path=energy_cache_path, + precompute_num_workers=precompute_num_workers, + ) + self.print_logs() + + @property + def lengths(self) -> list[int]: + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + try: + audio_len = get_audio_size(wav_file) + except RuntimeError: + logger.warning(f"Failed to compute length for {item['audio_file']}") + audio_len = 0 + lens.append(audio_len) + return lens + + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, new_samples) -> None: + self._samples = new_samples + if hasattr(self, "f0_dataset"): + self.f0_dataset.samples = new_samples + if hasattr(self, "energy_dataset"): + self.energy_dataset.samples = new_samples + if hasattr(self, "phoneme_dataset"): + self.phoneme_dataset.samples = new_samples + + def __len__(self) -> int: + return len(self.samples) + + def __getitem__(self, idx): + return self.load_data(idx) + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + logger.info("%sDataLoader initialization", indent) + logger.info("%s| Tokenizer:", indent) + self.tokenizer.print_logs(level + 1) + logger.info("%s| Number of instances : %d", indent, len(self.samples)) + + def load_wav(self, filename): + waveform = self.ap.load_wav(filename) + assert waveform.size > 0 + return waveform + + def get_phonemes(self, idx, text): + out_dict = self.phoneme_dataset[idx] + assert text == out_dict["text"], f"{text} != {out_dict['text']}" + assert len(out_dict["token_ids"]) > 0 + return out_dict + + def get_f0(self, idx): + out_dict = self.f0_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + + def get_energy(self, idx): + out_dict = self.energy_dataset[idx] + item = self.samples[idx] + assert item["audio_unique_name"] == out_dict["audio_unique_name"] + return out_dict + + @staticmethod + def get_attn_mask(attn_file): + return np.load(attn_file) + + def get_token_ids(self, idx, text): + if self.tokenizer.use_phonemes: + token_ids = self.get_phonemes(idx, text)["token_ids"] + else: + token_ids = self.tokenizer.text_to_ids(text) + return np.array(token_ids, dtype=np.int32) + + def load_data(self, idx) -> dict[str, Any]: + item = self.samples[idx] + + raw_text = item["text"] + + wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32) + + # apply noise for augmentation + if self.use_noise_augment: + wav = noise_augment_audio(wav) + + # get token ids + token_ids = self.get_token_ids(idx, item["text"]) + + # get pre-computed attention maps + attn = None + if "alignment_file" in item: + attn = self.get_attn_mask(item["alignment_file"]) + + # after phonemization the text length may change + # this is a shareful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or len(wav) < self.min_audio_len: + self.rescue_item_idx += 1 + return self.load_data(self.rescue_item_idx) + + # get f0 values + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + energy = None + if self.compute_energy: + energy = self.get_energy(idx)["energy"] + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "wav": wav, + "pitch": f0, + "energy": energy, + "attn": attn, + "item_idx": item["audio_file"], + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "wav_file_name": os.path.basename(item["audio_file"]), + "audio_unique_name": item["audio_unique_name"], + } + + @staticmethod + def _compute_lengths(samples): + new_samples = [] + for item in samples: + try: + audio_length = get_audio_size(item["audio_file"]) + except RuntimeError: + logger.warning(f"Failed to compute length, skipping {item['audio_file']}") + continue + text_lenght = len(item["text"]) + item["audio_length"] = audio_length + item["text_length"] = text_lenght + new_samples += [item] + return new_samples + + @staticmethod + def filter_by_length(lengths: list[int], min_len: int, max_len: int): + idxs = np.argsort(lengths) # ascending order + ignore_idx = [] + keep_idx = [] + for idx in idxs: + length = lengths[idx] + if length < min_len or length > max_len: + ignore_idx.append(idx) + else: + keep_idx.append(idx) + return ignore_idx, keep_idx + + @staticmethod + def sort_by_length(samples: list[list]): + audio_lengths = [s["audio_length"] for s in samples] + return np.argsort(audio_lengths) # ascending order + + @staticmethod + def create_buckets(samples, batch_group_size: int): + assert batch_group_size > 0 + for i in range(len(samples) // batch_group_size): + offset = i * batch_group_size + end_offset = offset + batch_group_size + temp_items = samples[offset:end_offset] + random.shuffle(temp_items) + samples[offset:end_offset] = temp_items + return samples + + @staticmethod + def _select_samples_by_idx(idxs, samples): + samples_new = [] + for idx in idxs: + samples_new.append(samples[idx]) + return samples_new + + def preprocess_samples(self) -> None: + r"""Sort `items` based on text length or audio length in ascending order. Filter out samples out or the length + range. + """ + samples = self._compute_lengths(self.samples) + + # sort items based on the sequence length in ascending order + text_lengths = [i["text_length"] for i in samples] + audio_lengths = [i["audio_length"] for i in samples] + text_ignore_idx, text_keep_idx = self.filter_by_length(text_lengths, self.min_text_len, self.max_text_len) + audio_ignore_idx, audio_keep_idx = self.filter_by_length(audio_lengths, self.min_audio_len, self.max_audio_len) + keep_idx = list(set(audio_keep_idx) & set(text_keep_idx)) + ignore_idx = list(set(audio_ignore_idx) | set(text_ignore_idx)) + + samples = self._select_samples_by_idx(keep_idx, samples) + + sorted_idxs = self.sort_by_length(samples) + + if self.start_by_longest: + longest_idxs = sorted_idxs[-1] + sorted_idxs[-1] = sorted_idxs[0] + sorted_idxs[0] = longest_idxs + + samples = self._select_samples_by_idx(sorted_idxs, samples) + + if len(samples) == 0: + msg = "No samples left." + raise RuntimeError(msg) + + # shuffle batch groups + # create batches with similar length items + # the larger the `batch_group_size`, the higher the length variety in a batch. + if self.batch_group_size > 0: + samples = self.create_buckets(samples, self.batch_group_size) + + # update items to the new sorted items + audio_lengths = [s["audio_length"] for s in samples] + text_lengths = [s["text_length"] for s in samples] + self.samples = samples + + logger.info("Preprocessing samples") + logger.info(f"Max text length: {np.max(text_lengths)}") + logger.info(f"Min text length: {np.min(text_lengths)}") + logger.info(f"Avg text length: {np.mean(text_lengths)}") + logger.info(f"Max audio length: {np.max(audio_lengths)}") + logger.info(f"Min audio length: {np.min(audio_lengths)}") + logger.info(f"Avg audio length: {np.mean(audio_lengths)}") + logger.info("Num. instances discarded samples: %d", len(ignore_idx)) + logger.info(f"Batch group size: {self.batch_group_size}.") + + @staticmethod + def _sort_batch(batch, text_lengths): + """Sort the batch by the input text length for RNN efficiency. + + Args: + ---- + batch (Dict): Batch returned by `__getitem__`. + text_lengths (List[int]): Lengths of the input character sequences. + + """ + text_lengths, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lengths), dim=0, descending=True) + batch = [batch[idx] for idx in ids_sorted_decreasing] + return batch, text_lengths, ids_sorted_decreasing + + def collate_fn(self, batch): + """Perform preprocessing and create a final data batch. + + 1. Sort batch instances by text-length + 2. Convert Audio signal to features. + 3. PAD sequences wrt r. + 4. Load to Torch. + """ + # Puts each data field into a tensor with outer dimension batch size + if isinstance(batch[0], collections.abc.Mapping): + token_ids_lengths = np.array([len(d["token_ids"]) for d in batch]) + + # sort items with text input length for RNN efficiency + batch, token_ids_lengths, ids_sorted_decreasing = self._sort_batch(batch, token_ids_lengths) + + # convert list of dicts to dict of lists + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # get language ids from language names + if self.language_id_mapping is not None: + language_ids = [self.language_id_mapping[ln] for ln in batch["language_name"]] + else: + language_ids = None + # get pre-computed d-vectors + if self.d_vector_mapping is not None: + embedding_keys = list(batch["audio_unique_name"]) + d_vectors = [self.d_vector_mapping[w]["embedding"] for w in embedding_keys] + else: + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_id_mapping: + speaker_ids = [self.speaker_id_mapping[sn] for sn in batch["speaker_name"]] + else: + speaker_ids = None + # compute features + mel = [self.ap.melspectrogram(w).astype("float32") for w in batch["wav"]] + + mel_lengths = [m.shape[1] for m in mel] + + # lengths adjusted by the reduction factor + mel_lengths_adjusted = [ + ( + m.shape[1] + (self.outputs_per_step - (m.shape[1] % self.outputs_per_step)) + if m.shape[1] % self.outputs_per_step + else m.shape[1] + ) + for m in mel + ] + + # compute 'stop token' targets + stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] + + # PAD stop targets + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) + + # PAD sequences with longest instance in the batch + token_ids = prepare_data(batch["token_ids"]).astype(np.int32) + + # PAD features with longest instance + mel = prepare_tensor(mel, self.outputs_per_step) + + # B x D x T --> B x T x D + mel = mel.transpose(0, 2, 1) + + # convert things to pytorch + token_ids_lengths = torch.LongTensor(token_ids_lengths) + token_ids = torch.LongTensor(token_ids) + mel = torch.FloatTensor(mel).contiguous() + mel_lengths = torch.LongTensor(mel_lengths) + stop_targets = torch.FloatTensor(stop_targets) + + # speaker vectors + if d_vectors is not None: + d_vectors = torch.FloatTensor(d_vectors) + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + # compute linear spectrogram + linear = None + if self.compute_linear_spec: + linear = [self.ap.spectrogram(w).astype("float32") for w in batch["wav"]] + linear = prepare_tensor(linear, self.outputs_per_step) + linear = linear.transpose(0, 2, 1) + assert mel.shape[1] == linear.shape[1] + linear = torch.FloatTensor(linear).contiguous() + + # format waveforms + wav_padded = None + if self.return_wav: + wav_lengths = [w.shape[0] for w in batch["wav"]] + max_wav_len = max(mel_lengths_adjusted) * self.ap.hop_length + wav_lengths = torch.LongTensor(wav_lengths) + wav_padded = torch.zeros(len(batch["wav"]), 1, max_wav_len) + for i, w in enumerate(batch["wav"]): + mel_length = mel_lengths_adjusted[i] + w = np.pad(w, (0, self.ap.hop_length * self.outputs_per_step), mode="edge") + w = w[: mel_length * self.ap.hop_length] + wav_padded[i, :, : w.shape[0]] = torch.from_numpy(w) + wav_padded.transpose_(1, 2) + + # format F0 + if self.compute_f0: + pitch = prepare_data(batch["pitch"]) + assert mel.shape[1] == pitch.shape[1], f"[!] {mel.shape} vs {pitch.shape}" + pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT + else: + pitch = None + # format energy + if self.compute_energy: + energy = prepare_data(batch["energy"]) + assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}" + energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT + else: + energy = None + # format attention masks + attns = None + if batch["attn"][0] is not None: + attns = [batch["attn"][idx].T for idx in ids_sorted_decreasing] + for idx, attn in enumerate(attns): + pad2 = mel.shape[1] - attn.shape[1] + pad1 = token_ids.shape[1] - attn.shape[0] + assert pad1 >= 0 and pad2 >= 0, f"[!] Negative padding - {pad1} and {pad2}" + attn = np.pad(attn, [[0, pad1], [0, pad2]]) + attns[idx] = attn + attns = prepare_tensor(attns, self.outputs_per_step) + attns = torch.FloatTensor(attns).unsqueeze(1) + + return { + "token_id": token_ids, + "token_id_lengths": token_ids_lengths, + "speaker_names": batch["speaker_name"], + "linear": linear, + "mel": mel, + "mel_lengths": mel_lengths, + "stop_targets": stop_targets, + "item_idxs": batch["item_idx"], + "d_vectors": d_vectors, + "speaker_ids": speaker_ids, + "attns": attns, + "waveform": wav_padded, + "raw_text": batch["raw_text"], + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_name"], + } + + msg = f"batch must contain tensors, numbers, dicts or lists; found {type(batch[0])}" + raise TypeError(msg) + + +class PhonemeDataset(Dataset): + """Phoneme Dataset for converting input text to phonemes and then token IDs. + + At initialization, it pre-computes the phonemes under `cache_path` and loads them in training to reduce data + loading latency. If `cache_path` is already present, it skips the pre-computation. + + Args: + ---- + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + tokenizer (TTSTokenizer): + Tokenizer to convert input text to phonemes. + + cache_path (str): + Path to cache phonemes. If `cache_path` is already present or None, it skips the pre-computation. + + precompute_num_workers (int): + Number of workers used for pre-computing the phonemes. Defaults to 0. + + """ + + def __init__( + self, + samples: Union[list[dict], list[list]], + tokenizer: "TTSTokenizer", + cache_path: str, + precompute_num_workers: int = 0, + ) -> None: + self.samples = samples + self.tokenizer = tokenizer + self.cache_path = cache_path + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + + def __getitem__(self, index) -> dict[str, Any]: + item = self.samples[index] + ids = self.compute_or_load(string2filename(item["audio_unique_name"]), item["text"], item["language"]) + ph_hat = self.tokenizer.ids_to_text(ids) + return {"text": item["text"], "ph_hat": ph_hat, "token_ids": ids, "token_ids_len": len(ids)} + + def __len__(self) -> int: + return len(self.samples) + + def compute_or_load(self, file_name: str, text: str, language: str) -> list[int]: + """Compute phonemes for the given text. + + If the phonemes are already cached, load them from cache. + """ + file_ext = "_phoneme.npy" + cache_path = os.path.join(self.cache_path, file_name + file_ext) + try: + ids = np.load(cache_path) + except FileNotFoundError: + ids = self.tokenizer.text_to_ids(text, language=language) + np.save(cache_path, ids) + return ids + + def get_pad_id(self) -> int: + """Get pad token ID for sequence padding.""" + return self.tokenizer.pad_id + + def precompute(self, num_workers: int = 1) -> None: + """Precompute phonemes for all samples. + + We use pytorch dataloader because we are lazy. + """ + logger.info("Pre-computing phonemes...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, + ) + for _ in dataloder: + pbar.update(batch_size) + + def collate_fn(self, batch): + ids = [item["token_ids"] for item in batch] + ids_lens = [item["token_ids_len"] for item in batch] + texts = [item["text"] for item in batch] + texts_hat = [item["ph_hat"] for item in batch] + ids_lens_max = max(ids_lens) + ids_torch = torch.LongTensor(len(ids), ids_lens_max).fill_(self.get_pad_id()) + for i, ids_len in enumerate(ids_lens): + ids_torch[i, :ids_len] = torch.LongTensor(ids[i]) + return {"text": texts, "ph_hat": texts_hat, "token_ids": ids_torch} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + logger.info("%sPhonemeDataset", indent) + logger.info("%s| Tokenizer:", indent) + self.tokenizer.print_logs(level + 1) + logger.info("%s| Number of instances : %d", indent, len(self.samples)) + + +class F0Dataset: + """F0 Dataset for computing F0 from wav files in CPU. + + Pre-compute F0 values for all the samples at initialization if `cache_path` is not None or already present. It + also computes the mean and std of F0 values if `normalize_f0` is True. + + Args: + ---- + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute F0 from wav files. + + cache_path (str): + Path to cache F0 values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the F0 values. Defaults to 0. + + normalize_f0 (bool): + Whether to normalize F0 values by mean and std. Defaults to True. + + """ + + def __init__( + self, + samples: Union[list[list], list[dict]], + ap: "AudioProcessor", + audio_config=None, # pylint: disable=unused-argument + cache_path: Optional[str] = None, + precompute_num_workers: int = 0, + normalize_f0: bool = True, + ) -> None: + self.samples = samples + self.ap = ap + self.cache_path = cache_path + self.normalize_f0 = normalize_f0 + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_f0: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + f0 = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) + if self.normalize_f0: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + f0 = self.normalize(f0) + return {"audio_unique_name": item["audio_unique_name"], "f0": f0} + + def __len__(self) -> int: + return len(self.samples) + + def precompute(self, num_workers: int = 0) -> None: + logger.info("Pre-computing F0s...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_f0 = self.normalize_f0 + self.normalize_f0 = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, + ) + computed_data = [] + for batch in dataloder: + f0 = batch["f0"] + computed_data.append(f for f in f0) + pbar.update(batch_size) + self.normalize_f0 = normalize_f0 + + if self.normalize_f0: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + pitch_mean, pitch_std = self.compute_pitch_stats(computed_data) + pitch_stats = {"mean": pitch_mean, "std": pitch_std} + np.save(os.path.join(self.cache_path, "pitch_stats"), pitch_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_pitch_file_path(file_name: str, cache_path: str) -> str: + return os.path.join(cache_path, file_name + "_pitch.npy") + + @staticmethod + def _compute_and_save_pitch(ap, wav_file, pitch_file=None): + wav = ap.load_wav(wav_file) + pitch = ap.compute_f0(wav) + if pitch_file: + np.save(pitch_file, pitch) + return pitch + + @staticmethod + def compute_pitch_stats(pitch_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path) -> None: + stats_path = os.path.join(cache_path, "pitch_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch = pitch - self.mean + pitch = pitch / self.std + pitch[zero_idxs] = 0.0 + return pitch + + def denormalize(self, pitch): + zero_idxs = np.where(pitch == 0.0)[0] + pitch *= self.std + pitch += self.mean + pitch[zero_idxs] = 0.0 + return pitch + + def compute_or_load(self, wav_file, audio_unique_name): + """Compute pitch and return a numpy array of pitch values.""" + pitch_file = self.create_pitch_file_path(audio_unique_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(self.ap, wav_file, pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + def collate_fn(self, batch): + audio_unique_name = [item["audio_unique_name"] for item in batch] + f0s = [item["f0"] for item in batch] + f0_lens = [len(item["f0"]) for item in batch] + f0_lens_max = max(f0_lens) + f0s_torch = torch.LongTensor(len(f0s), f0_lens_max).fill_(self.get_pad_id()) + for i, f0_len in enumerate(f0_lens): + f0s_torch[i, :f0_len] = torch.LongTensor(f0s[i]) + return {"audio_unique_name": audio_unique_name, "f0": f0s_torch, "f0_lens": f0_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + logger.info("%sF0Dataset", indent) + logger.info("%s| Number of instances : %d", indent, len(self.samples)) + + +class EnergyDataset: + """Energy Dataset for computing Energy from wav files in CPU. + + Pre-compute Energy values for all the samples at initialization if `cache_path` is not None or already present. It + also computes the mean and std of Energy values if `normalize_Energy` is True. + + Args: + ---- + samples (Union[List[List], List[Dict]]): + List of samples. Each sample is a list or a dict. + + ap (AudioProcessor): + AudioProcessor to compute Energy from wav files. + + cache_path (str): + Path to cache Energy values. If `cache_path` is already present or None, it skips the pre-computation. + Defaults to None. + + precompute_num_workers (int): + Number of workers used for pre-computing the Energy values. Defaults to 0. + + normalize_Energy (bool): + Whether to normalize Energy values by mean and std. Defaults to True. + + """ + + def __init__( + self, + samples: Union[list[list], list[dict]], + ap: "AudioProcessor", + cache_path: Optional[str] = None, + precompute_num_workers=0, + normalize_energy=True, + ) -> None: + self.samples = samples + self.ap = ap + self.cache_path = cache_path + self.normalize_energy = normalize_energy + self.pad_id = 0.0 + self.mean = None + self.std = None + if cache_path is not None and not os.path.exists(cache_path): + os.makedirs(cache_path) + self.precompute(precompute_num_workers) + if normalize_energy: + self.load_stats(cache_path) + + def __getitem__(self, idx): + item = self.samples[idx] + energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"])) + if self.normalize_energy: + assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available" + energy = self.normalize(energy) + return {"audio_unique_name": item["audio_unique_name"], "energy": energy} + + def __len__(self) -> int: + return len(self.samples) + + def precompute(self, num_workers=0) -> None: + logger.info("Pre-computing energys...") + with tqdm.tqdm(total=len(self)) as pbar: + batch_size = num_workers if num_workers > 0 else 1 + # we do not normalize at preproessing + normalize_energy = self.normalize_energy + self.normalize_energy = False + dataloder = torch.utils.data.DataLoader( + batch_size=batch_size, + dataset=self, + shuffle=False, + num_workers=num_workers, + collate_fn=self.collate_fn, + ) + computed_data = [] + for batch in dataloder: + energy = batch["energy"] + computed_data.append(e for e in energy) + pbar.update(batch_size) + self.normalize_energy = normalize_energy + + if self.normalize_energy: + computed_data = [tensor for batch in computed_data for tensor in batch] # flatten + energy_mean, energy_std = self.compute_energy_stats(computed_data) + energy_stats = {"mean": energy_mean, "std": energy_std} + np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True) + + def get_pad_id(self): + return self.pad_id + + @staticmethod + def create_energy_file_path(wav_file, cache_path): + file_name = os.path.splitext(os.path.basename(wav_file))[0] + return os.path.join(cache_path, file_name + "_energy.npy") + + @staticmethod + def _compute_and_save_energy(ap, wav_file, energy_file=None): + wav = ap.load_wav(wav_file) + energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length) + if energy_file: + np.save(energy_file, energy) + return energy + + @staticmethod + def compute_energy_stats(energy_vecs): + nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs]) + mean, std = np.mean(nonzeros), np.std(nonzeros) + return mean, std + + def load_stats(self, cache_path) -> None: + stats_path = os.path.join(cache_path, "energy_stats.npy") + stats = np.load(stats_path, allow_pickle=True).item() + self.mean = stats["mean"].astype(np.float32) + self.std = stats["std"].astype(np.float32) + + def normalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy = energy - self.mean + energy = energy / self.std + energy[zero_idxs] = 0.0 + return energy + + def denormalize(self, energy): + zero_idxs = np.where(energy == 0.0)[0] + energy *= self.std + energy += self.mean + energy[zero_idxs] = 0.0 + return energy + + def compute_or_load(self, wav_file, audio_unique_name): + """Compute energy and return a numpy array of energy values.""" + energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path) + if not os.path.exists(energy_file): + energy = self._compute_and_save_energy(self.ap, wav_file, energy_file) + else: + energy = np.load(energy_file) + return energy.astype(np.float32) + + def collate_fn(self, batch): + audio_unique_name = [item["audio_unique_name"] for item in batch] + energys = [item["energy"] for item in batch] + energy_lens = [len(item["energy"]) for item in batch] + energy_lens_max = max(energy_lens) + energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id()) + for i, energy_len in enumerate(energy_lens): + energys_torch[i, :energy_len] = torch.LongTensor(energys[i]) + return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens} + + def print_logs(self, level: int = 0) -> None: + indent = "\t" * level + logger.info("%senergyDataset") + logger.info("%s| Number of instances : %d", indent, len(self.samples)) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py new file mode 100644 index 0000000000000000000000000000000000000000..ff1a76e2c9b08e2d614fc272bbdc56ff2be9f75f --- /dev/null +++ b/TTS/tts/datasets/formatters.py @@ -0,0 +1,662 @@ +import csv +import logging +import os +import re +import xml.etree.ElementTree as ET +from glob import glob +from pathlib import Path +from typing import List + +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +######################## +# DATASETS +######################## + + +def cml_tts(root_path, meta_file, ignored_speakers=None): + """Normalizes the CML-TTS meta data file to TTS format + https://github.com/freds0/CML-TTS-Dataset/""" + filepath = os.path.join(root_path, meta_file) + # ensure there are 4 columns for every line + with open(filepath, "r", encoding="utf8") as f: + lines = f.readlines() + num_cols = len(lines[0].split("|")) # take the first row as reference + for idx, line in enumerate(lines[1:]): + if len(line.split("|")) != num_cols: + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) + # load metadata + with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter="|") + metadata = list(reader) + assert all(x in metadata[0] for x in ["wav_filename", "transcript"]) + client_id = None if "client_id" in metadata[0] else "default" + emotion_name = None if "emotion_name" in metadata[0] else "neutral" + items = [] + not_found_counter = 0 + for row in metadata: + if client_id is None and ignored_speakers is not None and row["client_id"] in ignored_speakers: + continue + audio_path = os.path.join(root_path, row["wav_filename"]) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row["transcript"], + "audio_file": audio_path, + "speaker_name": client_id if client_id is not None else row["client_id"], + "emotion_name": emotion_name if emotion_name is not None else row["emotion_name"], + "root_path": root_path, + } + ) + if not_found_counter > 0: + logger.warning("%d files not found", not_found_counter) + return items + + +def coqui(root_path, meta_file, ignored_speakers=None): + """Interal dataset formatter.""" + filepath = os.path.join(root_path, meta_file) + # ensure there are 4 columns for every line + with open(filepath, "r", encoding="utf8") as f: + lines = f.readlines() + num_cols = len(lines[0].split("|")) # take the first row as reference + for idx, line in enumerate(lines[1:]): + if len(line.split("|")) != num_cols: + logger.warning("Missing column in line %d -> %s", idx + 1, line.strip()) + # load metadata + with open(Path(root_path) / meta_file, newline="", encoding="utf-8") as f: + reader = csv.DictReader(f, delimiter="|") + metadata = list(reader) + assert all(x in metadata[0] for x in ["audio_file", "text"]) + speaker_name = None if "speaker_name" in metadata[0] else "coqui" + emotion_name = None if "emotion_name" in metadata[0] else "neutral" + items = [] + not_found_counter = 0 + for row in metadata: + if speaker_name is None and ignored_speakers is not None and row["speaker_name"] in ignored_speakers: + continue + audio_path = os.path.join(root_path, row["audio_file"]) + if not os.path.exists(audio_path): + not_found_counter += 1 + continue + items.append( + { + "text": row["text"], + "audio_file": audio_path, + "speaker_name": speaker_name if speaker_name is not None else row["speaker_name"], + "emotion_name": emotion_name if emotion_name is not None else row["emotion_name"], + "root_path": root_path, + } + ) + if not_found_counter > 0: + logger.warning("%d files not found", not_found_counter) + return items + + +def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalize TWEB dataset. + https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "tweb" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + wav_file = os.path.join(root_path, cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = cols[1].strip() + text = cols[0].strip() + wav_file = os.path.join(root_path, "wavs", wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, "r", encoding="ISO 8859-1") as ttf: + for line in ttf: + cols = line.strip().split("|") + wav_file = cols[0].strip() + text = cols[1].strip() + folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" + wav_file = os.path.join(root_path, folder_name, wav_file) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def mailabs(root_path, meta_files=None, ignored_speakers=None): + """Normalizes M-AI-Labs meta data files to TTS format + + Args: + root_path (str): root folder of the MAILAB language folder. + meta_files (str): list of meta files to be used in the training. If None, finds all the csv files + recursively. Defaults to None + """ + speaker_regex = re.compile(f"by_book{os.sep}(male|female){os.sep}(?P[^{os.sep}]+){os.sep}") + if not meta_files: + csv_files = glob(root_path + f"{os.sep}**{os.sep}metadata.csv", recursive=True) + else: + csv_files = meta_files + + # meta_files = [f.strip() for f in meta_files.split(",")] + items = [] + for csv_file in csv_files: + if os.path.isfile(csv_file): + txt_file = csv_file + else: + txt_file = os.path.join(root_path, csv_file) + + folder = os.path.dirname(txt_file) + # determine speaker based on folder structure... + speaker_name_match = speaker_regex.search(txt_file) + if speaker_name_match is None: + continue + speaker_name = speaker_name_match.group("speaker_name") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + logger.info(csv_file) + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + if not meta_files: + wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") + else: + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") + if os.path.isfile(wav_file): + text = cols[1].strip() + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path} + ) + else: + # M-AI-Labs have some missing samples, so just print the warning + logger.warning("File %s does not exist!", wav_file) + return items + + +def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file to TTS format + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ljspeech" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the LJSpeech meta data file for TTS testing + https://keithito.com/LJ-Speech-Dataset/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + speaker_id = 0 + for idx, line in enumerate(ttf): + # 2 samples per speaker to avoid eval split issues + if idx % 2 == 0: + speaker_id += 1 + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2] + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path} + ) + return items + + +def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the thorsten meta data file to TTS format + https://github.com/thorstenMueller/deep-learning-german-tts/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "thorsten" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the sam-accenture meta data file to TTS format + https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" + xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) + xml_root = ET.parse(xml_file).getroot() + items = [] + speaker_name = "sam_accenture" + for item in xml_root.findall("./fileid"): + text = item.text + wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") + if not os.path.exists(wav_file): + logger.warning("%s in metafile does not exist. Skipping...", wav_file) + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the RUSLAN meta data file to TTS format + https://ruslan-corpus.github.io/""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "ruslan" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the CSS10 dataset file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "css10" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Normalizes the Nancy meta data file to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "nancy" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + utt_id = line.split()[1] + text = line[line.find('"') + 1 : line.rfind('"') - 1] + wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def common_voice(root_path, meta_file, ignored_speakers=None): + """Normalize the common voice meta data file to TTS format.""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("client_id"): + continue + cols = line.split("\t") + text = cols[2] + speaker_name = cols[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path} + ) + return items + + +def libri_tts(root_path, meta_files=None, ignored_speakers=None): + """https://ai.google/tools/datasets/libri-tts/""" + items = [] + if not meta_files: + meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) + else: + if isinstance(meta_files, str): + meta_files = [os.path.join(root_path, meta_files)] + + for meta_file in meta_files: + _meta_file = os.path.basename(meta_file).split(".")[0] + with open(meta_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("\t") + file_name = cols[0] + speaker_name, chapter_id, *_ = cols[0].split("_") + _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") + wav_file = os.path.join(_root_path, file_name + ".wav") + text = cols[2] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_name in ignored_speakers: + continue + items.append( + { + "text": text, + "audio_file": wav_file, + "speaker_name": f"LTTS_{speaker_name}", + "root_path": root_path, + } + ) + for item in items: + assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}" + return items + + +def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "turkish-female" + skipped_files = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav") + if not os.path.exists(wav_file): + skipped_files.append(wav_file) + continue + text = cols[1].strip() + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + logger.warning("%d files skipped. They don't exist...") + return items + + +# ToDo: add the dataset link when the dataset is released publicly +def brspeech(root_path, meta_file, ignored_speakers=None): + """BRSpeech 3.0 beta""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("wav_filename"): + continue + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] + speaker_id = cols[3] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path}) + return items + + +def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None): + """VCTK dataset v0.92. + + URL: + https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip + + This dataset has 2 recordings per speaker that are annotated with ```mic1``` and ```mic2```. + It is believed that (😄 ) ```mic1``` files are the same as the previous version of the dataset. + + mic1: + Audio recorded using an omni-directional microphone (DPA 4035). + Contains very low frequency noises. + This is the same audio released in previous versions of VCTK: + https://doi.org/10.7488/ds/1994 + + mic2: + Audio recorded using a small diaphragm condenser microphone with + very wide bandwidth (Sennheiser MKH 800). + Two speakers, p280 and p315 had technical issues of the audio + recordings using MKH 800. + """ + file_ext = "flac" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + # p280 has no mic2 recordings + if speaker_id == "p280": + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_mic1.{file_ext}") + else: + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") + if os.path.exists(wav_file): + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} + ) + else: + logger.warning("Wav file doesn't exist - %s", wav_file) + return items + + +def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): + """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" + items = [] + meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path} + ) + return items + + +def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-argument + items = [] + speaker_name = "synpaflex" + root_path = os.path.join(root_path, "") + wav_files = glob(f"{root_path}**/*.wav", recursive=True) + for wav_file in wav_files: + if os.sep + "wav" + os.sep in wav_file: + txt_file = wav_file.replace("wav", "txt") + else: + txt_file = os.path.join( + os.path.dirname(wav_file), "txt", os.path.basename(wav_file).replace(".wav", ".txt") + ) + if os.path.exists(txt_file) and os.path.exists(wav_file): + with open(txt_file, "r", encoding="utf-8") as file_text: + text = file_text.readlines()[0] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, ignored_speakers=None): + """ToDo: Refer the paper when available""" + items = [] + split_dir = meta_files + meta_files = glob(f"{os.path.join(root_path, split_dir)}/**/*.txt", recursive=True) + for meta_file in meta_files: + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + with open(meta_file, "r", encoding="utf-8") as file_text: + text = file_text.readline().replace("\n", "") + # ignore sentences that contains digits + if ignore_digits_sentences and any(map(str.isdigit, text)): + continue + wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac") + items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path}) + return items + + +def mls(root_path, meta_files=None, ignored_speakers=None): + """http://www.openslr.org/94/""" + items = [] + with open(os.path.join(root_path, meta_files), "r", encoding="utf-8") as meta: + for line in meta: + file, text = line.split("\t") + text = text[:-1] + speaker, book, *_ = file.split("_") + wav_file = os.path.join(root_path, os.path.dirname(meta_files), "audio", speaker, book, file + ".wav") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker in ignored_speakers: + continue + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path} + ) + return items + + +# ======================================== VOX CELEB =========================================== +def voxceleb2(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="2") + + +def voxceleb1(root_path, meta_file=None, **kwargs): # pylint: disable=unused-argument + """ + :param meta_file Used only for consistency with load_tts_samples api + """ + return _voxcel_x(root_path, meta_file, voxcel_idx="1") + + +def _voxcel_x(root_path, meta_file, voxcel_idx): + assert voxcel_idx in ["1", "2"] + expected_count = 148_000 if voxcel_idx == "1" else 1_000_000 + voxceleb_path = Path(root_path) + cache_to = voxceleb_path / f"metafile_voxceleb{voxcel_idx}.csv" + cache_to.parent.mkdir(exist_ok=True) + + # if not exists meta file, crawl recursively for 'wav' files + if meta_file is not None: + with open(str(meta_file), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + elif not cache_to.exists(): + cnt = 0 + meta_data = [] + wav_files = voxceleb_path.rglob("**/*.wav") + for path in tqdm( + wav_files, + desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", + total=expected_count, + ): + speaker_id = str(Path(path).parent.parent.stem) + assert speaker_id.startswith("id") + text = None # VoxCel does not provide transciptions, and they are not needed for training the SE + meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") + cnt += 1 + with open(str(cache_to), "w", encoding="utf-8") as f: + f.write("".join(meta_data)) + if cnt < expected_count: + raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}") + + with open(str(cache_to), "r", encoding="utf-8") as f: + return [x.strip().split("|") for x in f.readlines()] + + +def emotion(root_path, meta_file, ignored_speakers=None): + """Generic emotion dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + if line.startswith("file_path"): + continue + cols = line.split(",") + wav_file = os.path.join(root_path, cols[0]) + speaker_id = cols[1] + emotion_id = cols[2].replace("\n", "") + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: + continue + items.append( + {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path} + ) + return items + + +def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylint: disable=unused-argument + """Normalizes the Baker meta data file to TTS format + + Args: + root_path (str): path to the baker dataset + meta_file (str): name of the meta dataset containing names of wav to select and the transcript of the sentence + Returns: + List[List[str]]: List of (text, wav_path, speaker_name) associated with each sentences + """ + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "baker" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + wav_name, text = line.rstrip("\n").split("|") + wav_path = os.path.join(root_path, "clips_22", wav_name) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Japanese single-speaker dataset from https://github.com/kaiidams/Kokoro-Speech-Dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kokoro" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") + text = cols[2].replace(" ", "") + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def kss(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + """Korean single-speaker dataset from https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "kss" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[2] # cols[1] => 6월, cols[2] => 유월 + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items + + +def bel_tts_formatter(root_path, meta_file, **kwargs): # pylint: disable=unused-argument + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "bel_tts" + with open(txt_file, "r", encoding="utf-8") as ttf: + for line in ttf: + cols = line.split("|") + wav_file = os.path.join(root_path, cols[0]) + text = cols[1] + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) + return items diff --git a/TTS/tts/layers/__init__.py b/TTS/tts/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f93efdb7fc41109ec3497d8e5e37ba05b0a4315e --- /dev/null +++ b/TTS/tts/layers/__init__.py @@ -0,0 +1 @@ +from TTS.tts.layers.losses import * diff --git a/TTS/tts/layers/align_tts/__init__.py b/TTS/tts/layers/align_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/align_tts/duration_predictor.py b/TTS/tts/layers/align_tts/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b83894cc3f87575a89ea8fd7bf4a584ca22c28 --- /dev/null +++ b/TTS/tts/layers/align_tts/duration_predictor.py @@ -0,0 +1,21 @@ +from torch import nn + +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.generic.transformer import FFTransformerBlock + + +class DurationPredictor(nn.Module): + def __init__(self, num_chars, hidden_channels, hidden_channels_ffn, num_heads): + super().__init__() + self.embed = nn.Embedding(num_chars, hidden_channels) + self.pos_enc = PositionalEncoding(hidden_channels, dropout_p=0.1) + self.FFT = FFTransformerBlock(hidden_channels, num_heads, hidden_channels_ffn, 2, 0.1) + self.out_layer = nn.Conv1d(hidden_channels, 1, 1) + + def forward(self, text, text_lengths): + # B, L -> B, L + emb = self.embed(text) + emb = self.pos_enc(emb.transpose(1, 2)) + x = self.FFT(emb, text_lengths) + x = self.out_layer(x).squeeze(-1) + return x diff --git a/TTS/tts/layers/align_tts/mdn.py b/TTS/tts/layers/align_tts/mdn.py new file mode 100644 index 0000000000000000000000000000000000000000..cdb332524bf7a5fec6a23da9e7977de6325a0324 --- /dev/null +++ b/TTS/tts/layers/align_tts/mdn.py @@ -0,0 +1,30 @@ +from torch import nn + + +class MDNBlock(nn.Module): + """Mixture of Density Network implementation + https://arxiv.org/pdf/2003.01950.pdf + """ + + def __init__(self, in_channels, out_channels): + super().__init__() + self.out_channels = out_channels + self.conv1 = nn.Conv1d(in_channels, in_channels, 1) + self.norm = nn.LayerNorm(in_channels) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + self.conv2 = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x): + o = self.conv1(x) + o = o.transpose(1, 2) + o = self.norm(o) + o = o.transpose(1, 2) + o = self.relu(o) + o = self.dropout(o) + mu_sigma = self.conv2(o) + # TODO: check this sigmoid + # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) + mu = mu_sigma[:, : self.out_channels // 2, :] + log_sigma = mu_sigma[:, self.out_channels // 2 :, :] + return mu, log_sigma diff --git a/TTS/tts/layers/bark/__init__.py b/TTS/tts/layers/bark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/bark/hubert/__init__.py b/TTS/tts/layers/bark/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/bark/hubert/hubert_manager.py b/TTS/tts/layers/bark/hubert/hubert_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fd936a9157910537a80472b3c0f98f59a202488f --- /dev/null +++ b/TTS/tts/layers/bark/hubert/hubert_manager.py @@ -0,0 +1,38 @@ +# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer + +import logging +import os.path +import shutil +import urllib.request + +import huggingface_hub + +logger = logging.getLogger(__name__) + + +class HubertManager: + @staticmethod + def make_sure_hubert_installed( + download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", model_path: str = "" + ): + if not os.path.isfile(model_path): + logger.info("Downloading HuBERT base model") + urllib.request.urlretrieve(download_url, model_path) + logger.info("Downloaded HuBERT") + return model_path + return None + + @staticmethod + def make_sure_tokenizer_installed( + model: str = "quantifier_hubert_base_ls960_14.pth", + repo: str = "GitMylo/bark-voice-cloning", + model_path: str = "", + ): + model_dir = os.path.dirname(model_path) + if not os.path.isfile(model_path): + logger.info("Downloading HuBERT custom tokenizer") + huggingface_hub.hf_hub_download(repo, model, local_dir=model_dir, local_dir_use_symlinks=False) + shutil.move(os.path.join(model_dir, model), model_path) + logger.info("Downloaded tokenizer") + return model_path + return None diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..58a614cb874339bed5115083810600171f5e2e5e --- /dev/null +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -0,0 +1,80 @@ +""" +Modified HuBERT model without kmeans. +Original author: https://github.com/lucidrains/ +Modified by: https://www.github.com/gitmylo/ +License: MIT +""" + +# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py + + +import torch +from einops import pack, unpack +from torch import nn +from torchaudio.functional import resample +from transformers import HubertModel + + +def round_down_nearest_multiple(num, divisor): + return num // divisor * divisor + + +def curtail_to_multiple(t, mult, from_left=False): + data_len = t.shape[-1] + rounded_seq_len = round_down_nearest_multiple(data_len, mult) + seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) + return t[..., seq_slice] + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +class CustomHubert(nn.Module): + """ + checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert + or you can train your own + """ + + def __init__(self, target_sample_hz=16000, seq_len_multiple_of=None, output_layer=9, device=None): + super().__init__() + self.target_sample_hz = target_sample_hz + self.seq_len_multiple_of = seq_len_multiple_of + self.output_layer = output_layer + if device is not None: + self.to(device) + self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960") + if device is not None: + self.model.to(device) + self.model.eval() + + @property + def groups(self): + return 1 + + @torch.no_grad() + def forward(self, wav_input, flatten=True, input_sample_hz=None): + device = wav_input.device + + if exists(input_sample_hz): + wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) + + if exists(self.seq_len_multiple_of): + wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) + + outputs = self.model.forward( + wav_input, + output_hidden_states=True, + ) + embed = outputs["hidden_states"][self.output_layer] + embed, packed_shape = pack([embed], "* d") + codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) + if flatten: + return codebook_indices + + (codebook_indices,) = unpack(codebook_indices, packed_shape, "*") + return codebook_indices diff --git a/TTS/tts/layers/bark/hubert/tokenizer.py b/TTS/tts/layers/bark/hubert/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cd9579799a4bee34f7ba3479ba8e81e8048ed61d --- /dev/null +++ b/TTS/tts/layers/bark/hubert/tokenizer.py @@ -0,0 +1,198 @@ +""" +Custom tokenizer model. +Author: https://www.github.com/gitmylo/ +License: MIT +""" + +import json +import logging +import os.path +from zipfile import ZipFile + +import numpy +import torch +from torch import nn, optim + +logger = logging.getLogger(__name__) + + +class HubertTokenizer(nn.Module): + def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): + super().__init__() + next_size = input_size + if version == 0: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + next_size = hidden_size + if version == 1: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + self.intermediate = nn.Linear(hidden_size, 4096) + next_size = 4096 + + self.fc = nn.Linear(next_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + self.optimizer: optim.Optimizer = None + self.lossfunc = nn.CrossEntropyLoss() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + def forward(self, x): + x, _ = self.lstm(x) + if self.version == 1: + x = self.intermediate(x) + x = self.fc(x) + x = self.softmax(x) + return x + + @torch.no_grad() + def get_token(self, x): + """ + Used to get the token for the first + :param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model. + :return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model. + """ + return torch.argmax(self(x), dim=1) + + def prepare_training(self): + self.optimizer = optim.Adam(self.parameters(), 0.001) + + def train_step(self, x_train, y_train, log_loss=False): + # y_train = y_train[:-1] + # y_train = y_train[1:] + + optimizer = self.optimizer + lossfunc = self.lossfunc + # Zero the gradients + self.zero_grad() + + # Forward pass + y_pred = self(x_train) + + y_train_len = len(y_train) + y_pred_len = y_pred.shape[0] + + if y_train_len > y_pred_len: + diff = y_train_len - y_pred_len + y_train = y_train[diff:] + elif y_train_len < y_pred_len: + diff = y_pred_len - y_train_len + y_pred = y_pred[:-diff, :] + + y_train_hot = torch.zeros(len(y_train), self.output_size) + y_train_hot[range(len(y_train)), y_train] = 1 + y_train_hot = y_train_hot.to("cuda") + + # Calculate the loss + loss = lossfunc(y_pred, y_train_hot) + + # Print loss + if log_loss: + logger.info("Loss %.3f", loss.item()) + + # Backward pass + loss.backward() + + # Update the weights + optimizer.step() + + def save(self, path): + info_path = ".".join(os.path.basename(path).split(".")[:-1]) + "/.info" + torch.save(self.state_dict(), path) + data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version) + with ZipFile(path, "a") as model_zip: + model_zip.writestr(info_path, data_from_model.save()) + model_zip.close() + + @staticmethod + def load_from_checkpoint(path, map_location=None): + old = True + with ZipFile(path) as model_zip: + filesMatch = [file for file in model_zip.namelist() if file.endswith("/.info")] + file = filesMatch[0] if filesMatch else None + if file: + old = False + data_from_model = Data.load(model_zip.read(file).decode("utf-8")) + model_zip.close() + if old: + model = HubertTokenizer() + else: + model = HubertTokenizer( + data_from_model.hidden_size, + data_from_model.input_size, + data_from_model.output_size, + data_from_model.version, + ) + model.load_state_dict(torch.load(path, map_location=map_location)) + if map_location: + model = model.to(map_location) + return model + + +class Data: + input_size: int + hidden_size: int + output_size: int + version: int + + def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0): + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + @staticmethod + def load(string): + data = json.loads(string) + return Data(data["input_size"], data["hidden_size"], data["output_size"], data["version"]) + + def save(self): + data = { + "input_size": self.input_size, + "hidden_size": self.hidden_size, + "output_size": self.output_size, + "version": self.version, + } + return json.dumps(data) + + +def auto_train(data_path, save_path="model.pth", load_model: str = None, save_epochs=1): + data_x, data_y = [], [] + + if load_model and os.path.isfile(load_model): + logger.info("Loading model from %s", load_model) + model_training = HubertTokenizer.load_from_checkpoint(load_model, "cuda") + else: + logger.info("Creating new model.") + model_training = HubertTokenizer(version=1).to("cuda") # Settings for the model to run without lstm + save_path = os.path.join(data_path, save_path) + base_save_path = ".".join(save_path.split(".")[:-1]) + + sem_string = "_semantic.npy" + feat_string = "_semantic_features.npy" + + ready = os.path.join(data_path, "ready") + for input_file in os.listdir(ready): + full_path = os.path.join(ready, input_file) + if input_file.endswith(sem_string): + data_y.append(numpy.load(full_path)) + elif input_file.endswith(feat_string): + data_x.append(numpy.load(full_path)) + model_training.prepare_training() + + epoch = 1 + + while 1: + for _ in range(save_epochs): + j = 0 + for x, y in zip(data_x, data_y): + model_training.train_step( + torch.tensor(x).to("cuda"), torch.tensor(y).to("cuda"), j % 50 == 0 + ) # Print loss every 50 steps + j += 1 + save_p = save_path + save_p_2 = f"{base_save_path}_epoch_{epoch}.pth" + model_training.save(save_p) + model_training.save(save_p_2) + logger.info("Epoch %d completed", epoch) + epoch += 1 diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..65c7800dcf8807b18deeeb01fd577fa122f85ff7 --- /dev/null +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -0,0 +1,609 @@ +import logging +import os +import re +from glob import glob +from typing import Dict, List, Optional, Tuple + +import librosa +import numpy as np +import numpy.typing as npt +import torch +import torchaudio +import tqdm +from encodec.utils import convert_audio +from scipy.special import softmax +from torch.nn import functional as F + +from TTS.tts.layers.bark.hubert.hubert_manager import HubertManager +from TTS.tts.layers.bark.hubert.kmeans_hubert import CustomHubert +from TTS.tts.layers.bark.hubert.tokenizer import HubertTokenizer +from TTS.tts.layers.bark.load_model import clear_cuda_cache, inference_mode + +logger = logging.getLogger(__name__) + + +def _tokenize(tokenizer, text): + return tokenizer.encode(text, add_special_tokens=False) + + +def _detokenize(tokenizer, enc_text): + return tokenizer.decode(enc_text) + + +def _normalize_whitespace(text): + return re.sub(r"\s+", " ", text).strip() + + +def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value + dirs = extra_voice_dirs + voices: Dict[str, List[str]] = {} + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] = list(glob(f"{subj}/*.npz")) + # fetch audio files if no npz files are found + if len(voices[sub]) == 0: + voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + return voices + + +def load_npz(npz_file: str) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]: + x_history = np.load(npz_file) + semantic = x_history["semantic_prompt"] + coarse = x_history["coarse_prompt"] + fine = x_history["fine_prompt"] + return semantic, coarse, fine + + +def load_voice( + model, voice: str, extra_voice_dirs: List[str] = [] +) -> Tuple[ + Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]] +]: # pylint: disable=dangerous-default-value + if voice == "random": + return None, None, None + + voices = get_voices(extra_voice_dirs) + paths = voices[voice] + + # bark only uses a single sample for cloning + if len(paths) > 1: + raise ValueError(f"Voice {voice} has multiple paths: {paths}") + + try: + path = voices[voice] + except KeyError as e: + raise KeyError(f"Voice {voice} not found in {extra_voice_dirs}") from e + + if len(paths) == 1 and paths[0].endswith(".npz"): + return load_npz(path[0]) + + audio_path = paths[0] + # replace the file extension with .npz + output_path = os.path.splitext(audio_path)[0] + ".npz" + generate_voice(audio=audio_path, model=model, output_path=output_path) + return load_voice(model, voice, extra_voice_dirs) + + +def zero_crossing_rate(audio, frame_length=1024, hop_length=512): + zero_crossings = np.sum(np.abs(np.diff(np.sign(audio))) / 2) + total_frames = 1 + int((len(audio) - frame_length) / hop_length) + return zero_crossings / total_frames + + +def compute_spectral_contrast(audio_data, sample_rate, n_bands=6, fmin=200.0): + spectral_contrast = librosa.feature.spectral_contrast(y=audio_data, sr=sample_rate, n_bands=n_bands, fmin=fmin) + return np.mean(spectral_contrast) + + +def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250): + stft = librosa.stft(audio_data) + power_spectrogram = np.abs(stft) ** 2 + frequencies = librosa.fft_frequencies(sr=sample_rate, n_fft=stft.shape[0]) + bass_mask = frequencies <= max_bass_freq + bass_energy = power_spectrogram[np.ix_(bass_mask, np.arange(power_spectrogram.shape[1]))].mean() + return bass_energy + + +def generate_voice( + audio, + model, + output_path, +): + """Generate a new voice from a given audio. + + Args: + audio (np.ndarray): The audio to use as a base for the new voice. + model (BarkModel): The BarkModel to use for generating the new voice. + output_path (str): The path to save the generated voice to. + """ + if isinstance(audio, str): + audio, sr = torchaudio.load(audio) + audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels) + audio = audio.unsqueeze(0).to(model.device) + + with torch.no_grad(): + encoded_frames = model.encodec.encode(audio) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T] + + # move codes to cpu + codes = codes.cpu().numpy() + + # generate semantic tokens + # Load the HuBERT model + hubert_manager = HubertManager() + hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) + + hubert_model = CustomHubert().to(model.device) + + # Load the CustomTokenizer model + tokenizer = HubertTokenizer.load_from_checkpoint( + model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device + ) + # semantic_tokens = model.text_to_semantic( + # text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7 + # ) # not 100% + semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate) + semantic_tokens = tokenizer.get_token(semantic_vectors) + semantic_tokens = semantic_tokens.cpu().numpy() + + np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens) + + +def generate_text_semantic( + text, + model, + history_prompt=None, + temp=0.7, + top_k=None, + top_p=None, + silent=False, + min_eos_p=0.2, + max_gen_duration_s=None, + allow_early_stop=True, + base=None, + use_kv_caching=True, + **kwargs, # pylint: disable=unused-argument +): + """Generate semantic tokens from text. + + Args: + text (str): The text to generate semantic tokens from. + model (BarkModel): The BarkModel to use for generating the semantic tokens. + history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation. + temp (float): The temperature to use for the generation. + top_k (int): The number of top tokens to consider for the generation. + top_p (float): The cumulative probability to consider for the generation. + silent (bool): Whether to silence the tqdm progress bar. + min_eos_p (float): The minimum probability to consider for the end of sentence token. + max_gen_duration_s (float): The maximum duration in seconds to generate for. + allow_early_stop (bool): Whether to allow the generation to stop early. + base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation. + use_kv_caching (bool): Whether to use key-value caching for the generation. + **kwargs: Additional keyword arguments. They are ignored. + + Returns: + np.ndarray: The generated semantic tokens. + """ + assert isinstance(text, str) + text = _normalize_whitespace(text) + assert len(text.strip()) > 0 + if all(v is not None for v in history_prompt) or base is not None: + if history_prompt is not None: + semantic_history = history_prompt[0] + if base is not None: + semantic_history = base[0] + assert ( + isinstance(semantic_history, np.ndarray) + and len(semantic_history.shape) == 1 + and len(semantic_history) > 0 + and semantic_history.min() >= 0 + and semantic_history.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + ) + else: + semantic_history = None + encoded_text = np.array(_tokenize(model.tokenizer, text)) + model.config.TEXT_ENCODING_OFFSET + if len(encoded_text) > 256: + p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1) + logger.warning(f"warning, text too long, lopping of last {p}%") + encoded_text = encoded_text[:256] + encoded_text = np.pad( + encoded_text, + (0, 256 - len(encoded_text)), + constant_values=model.config.TEXT_PAD_TOKEN, + mode="constant", + ) + if semantic_history is not None: + semantic_history = semantic_history.astype(np.int64) + # lop off if history is too long, pad if needed + semantic_history = semantic_history[-256:] + semantic_history = np.pad( + semantic_history, + (0, 256 - len(semantic_history)), + constant_values=model.config.SEMANTIC_PAD_TOKEN, + mode="constant", + ) + else: + semantic_history = np.array([model.config.SEMANTIC_PAD_TOKEN] * 256) + x = torch.from_numpy( + np.hstack([encoded_text, semantic_history, np.array([model.config.SEMANTIC_INFER_TOKEN])]).astype(np.int64) + )[None] + assert x.shape[1] == 256 + 256 + 1 + with inference_mode(): + x = x.to(model.device) + n_tot_steps = 768 + # custom tqdm updates since we don't know when eos will occur + pbar = tqdm.tqdm(disable=silent, total=100) + pbar_state = 0 + tot_generated_duration_s = 0 + kv_cache = None + for n in range(n_tot_steps): + if use_kv_caching and kv_cache is not None: + x_input = x[:, [-1]] + else: + x_input = x + logits, kv_cache = model.semantic_model( + x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache + ) + relevant_logits = logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE] + if allow_early_stop: + relevant_logits = torch.hstack( + (relevant_logits, logits[0, 0, [model.config.SEMANTIC_PAD_TOKEN]]) + ) # eos + if top_p is not None: + # faster to convert to numpy + logits_device = relevant_logits.device + logits_dtype = relevant_logits.type() + relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() + sorted_indices = np.argsort(relevant_logits)[::-1] + sorted_logits = relevant_logits[sorted_indices] + cumulative_probs = np.cumsum(softmax(sorted_logits)) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() + sorted_indices_to_remove[0] = False + relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf + relevant_logits = torch.from_numpy(relevant_logits) + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: + v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) + relevant_logits[relevant_logits < v[-1]] = -float("Inf") + probs = torch.softmax(relevant_logits / temp, dim=-1) + item_next = torch.multinomial(probs, num_samples=1) + if allow_early_stop and ( + item_next == model.config.SEMANTIC_VOCAB_SIZE or (min_eos_p is not None and probs[-1] >= min_eos_p) + ): + # eos found, so break + pbar.update(100 - pbar_state) + break + x = torch.cat((x, item_next[None]), dim=1) + tot_generated_duration_s += 1 / model.config.SEMANTIC_RATE_HZ + if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s: + pbar.update(100 - pbar_state) + break + if n == n_tot_steps - 1: + pbar.update(100 - pbar_state) + break + del logits, relevant_logits, probs, item_next + req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))]) + if req_pbar_state > pbar_state: + pbar.update(req_pbar_state - pbar_state) + pbar_state = req_pbar_state + pbar.close() + out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :] + assert all(out >= 0) and all(out < model.config.SEMANTIC_VOCAB_SIZE) + clear_cuda_cache() + return out + + +def _flatten_codebooks(arr, offset_size): + assert len(arr.shape) == 2 + arr = arr.copy() + if offset_size is not None: + for n in range(1, arr.shape[0]): + arr[n, :] += offset_size * n + flat_arr = arr.ravel("F") + return flat_arr + + +def generate_coarse( + x_semantic, + model, + history_prompt=None, + temp=0.7, + top_k=None, + top_p=None, + silent=False, + max_coarse_history=630, # min 60 (faster), max 630 (more context) + sliding_window_len=60, + base=None, + use_kv_caching=True, +): + """Generate coarse audio codes from semantic tokens. + + Args: + x_semantic (np.ndarray): The semantic tokens to generate coarse audio codes from. + model (BarkModel): The BarkModel to use for generating the coarse audio codes. + history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation. + temp (float): The temperature to use for the generation. + top_k (int): The number of top tokens to consider for the generation. + top_p (float): The cumulative probability to consider for the generation. + silent (bool): Whether to silence the tqdm progress bar. + max_coarse_history (int): The maximum number of coarse audio codes to use as history. + sliding_window_len (int): The length of the sliding window to use for the generation. + base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation. + use_kv_caching (bool): Whether to use key-value caching for the generation. + + Returns: + np.ndarray: The generated coarse audio codes. + """ + assert ( + isinstance(x_semantic, np.ndarray) + and len(x_semantic.shape) == 1 + and len(x_semantic) > 0 + and x_semantic.min() >= 0 + and x_semantic.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + ) + assert 60 <= max_coarse_history <= 630 + assert max_coarse_history + sliding_window_len <= 1024 - 256 + semantic_to_coarse_ratio = ( + model.config.COARSE_RATE_HZ / model.config.SEMANTIC_RATE_HZ * model.config.N_COARSE_CODEBOOKS + ) + max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) + if all(v is not None for v in history_prompt) or base is not None: + if history_prompt is not None: + x_history = history_prompt + x_semantic_history = x_history[0] + x_coarse_history = x_history[1] + if base is not None: + x_semantic_history = base[0] + x_coarse_history = base[1] + assert ( + isinstance(x_semantic_history, np.ndarray) + and len(x_semantic_history.shape) == 1 + and len(x_semantic_history) > 0 + and x_semantic_history.min() >= 0 + and x_semantic_history.max() <= model.config.SEMANTIC_VOCAB_SIZE - 1 + and isinstance(x_coarse_history, np.ndarray) + and len(x_coarse_history.shape) == 2 + and x_coarse_history.shape[0] == model.config.N_COARSE_CODEBOOKS + and x_coarse_history.shape[-1] >= 0 + and x_coarse_history.min() >= 0 + and x_coarse_history.max() <= model.config.CODEBOOK_SIZE - 1 + and ( + round(x_coarse_history.shape[-1] / len(x_semantic_history), 1) + == round(semantic_to_coarse_ratio / model.config.N_COARSE_CODEBOOKS, 1) + ) + ) + x_coarse_history = ( + _flatten_codebooks(x_coarse_history, model.config.CODEBOOK_SIZE) + model.config.SEMANTIC_VOCAB_SIZE + ) + # trim histories correctly + n_semantic_hist_provided = np.min( + [ + max_semantic_history, + len(x_semantic_history) - len(x_semantic_history) % 2, + int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)), + ] + ) + n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio)) + x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32) + x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32) + # TODO: bit of a hack for time alignment (sounds better) + x_coarse_history = x_coarse_history[:-2] + else: + x_semantic_history = np.array([], dtype=np.int32) + x_coarse_history = np.array([], dtype=np.int32) + # start loop + n_steps = int( + round( + np.floor(len(x_semantic) * semantic_to_coarse_ratio / model.config.N_COARSE_CODEBOOKS) + * model.config.N_COARSE_CODEBOOKS + ) + ) + assert n_steps > 0 and n_steps % model.config.N_COARSE_CODEBOOKS == 0 + x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32) + x_coarse = x_coarse_history.astype(np.int32) + base_semantic_idx = len(x_semantic_history) + with inference_mode(): + x_semantic_in = torch.from_numpy(x_semantic)[None].to(model.device) + x_coarse_in = torch.from_numpy(x_coarse)[None].to(model.device) + n_window_steps = int(np.ceil(n_steps / sliding_window_len)) + n_step = 0 + for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent): + semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio)) + # pad from right side + x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :] + x_in = x_in[:, :256] + x_in = F.pad( + x_in, + (0, 256 - x_in.shape[-1]), + "constant", + model.config.COARSE_SEMANTIC_PAD_TOKEN, + ) + x_in = torch.hstack( + [ + x_in, + torch.tensor([model.config.COARSE_INFER_TOKEN])[None].to(model.device), + x_coarse_in[:, -max_coarse_history:], + ] + ) + kv_cache = None + for _ in range(sliding_window_len): + if n_step >= n_steps: + continue + is_major_step = n_step % model.config.N_COARSE_CODEBOOKS == 0 + + if use_kv_caching and kv_cache is not None: + x_input = x_in[:, [-1]] + else: + x_input = x_in + + logits, kv_cache = model.coarse_model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) + logit_start_idx = ( + model.config.SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * model.config.CODEBOOK_SIZE + ) + logit_end_idx = model.config.SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * model.config.CODEBOOK_SIZE + relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx] + if top_p is not None: + # faster to convert to numpy + logits_device = relevant_logits.device + logits_dtype = relevant_logits.type() + relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() + sorted_indices = np.argsort(relevant_logits)[::-1] + sorted_logits = relevant_logits[sorted_indices] + cumulative_probs = np.cumsum(torch.nn.functional.softmax(sorted_logits)) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() + sorted_indices_to_remove[0] = False + relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf + relevant_logits = torch.from_numpy(relevant_logits) + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: + v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) + relevant_logits[relevant_logits < v[-1]] = -float("Inf") + probs = torch.nn.functional.softmax(relevant_logits / temp, dim=-1) + item_next = torch.multinomial(probs, num_samples=1) + item_next += logit_start_idx + x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1) + x_in = torch.cat((x_in, item_next[None]), dim=1) + del logits, relevant_logits, probs, item_next + n_step += 1 + del x_in + del x_semantic_in + gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :] + del x_coarse_in + assert len(gen_coarse_arr) == n_steps + gen_coarse_audio_arr = ( + gen_coarse_arr.reshape(-1, model.config.N_COARSE_CODEBOOKS).T - model.config.SEMANTIC_VOCAB_SIZE + ) + for n in range(1, model.config.N_COARSE_CODEBOOKS): + gen_coarse_audio_arr[n, :] -= n * model.config.CODEBOOK_SIZE + clear_cuda_cache() + return gen_coarse_audio_arr + + +def generate_fine( + x_coarse_gen, + model, + history_prompt=None, + temp=0.5, + silent=True, + base=None, +): + """Generate full audio codes from coarse audio codes. + + Args: + x_coarse_gen (np.ndarray): The coarse audio codes to generate full audio codes from. + model (BarkModel): The BarkModel to use for generating the full audio codes. + history_prompt (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a prompt for the generation. + temp (float): The temperature to use for the generation. + silent (bool): Whether to silence the tqdm progress bar. + base (tuple): A tuple of (semantic_history, coarse_history, fine_history) to use as a base for the generation. + + Returns: + np.ndarray: The generated full audio codes. + """ + assert ( + isinstance(x_coarse_gen, np.ndarray) + and len(x_coarse_gen.shape) == 2 + and 1 <= x_coarse_gen.shape[0] <= model.config.N_FINE_CODEBOOKS - 1 + and x_coarse_gen.shape[1] > 0 + and x_coarse_gen.min() >= 0 + and x_coarse_gen.max() <= model.config.CODEBOOK_SIZE - 1 + ) + if all(v is not None for v in history_prompt) or base is not None: + if history_prompt is not None: + x_fine_history = history_prompt[2] + if base is not None: + x_fine_history = base[2] + assert ( + isinstance(x_fine_history, np.ndarray) + and len(x_fine_history.shape) == 2 + and x_fine_history.shape[0] == model.config.N_FINE_CODEBOOKS + and x_fine_history.shape[1] >= 0 + and x_fine_history.min() >= 0 + and x_fine_history.max() <= model.config.CODEBOOK_SIZE - 1 + ) + else: + x_fine_history = None + n_coarse = x_coarse_gen.shape[0] + # make input arr + in_arr = np.vstack( + [ + x_coarse_gen, + np.zeros((model.config.N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1])) + + model.config.CODEBOOK_SIZE, # padding + ] + ).astype(np.int32) + # prepend history if available (max 512) + if x_fine_history is not None: + x_fine_history = x_fine_history.astype(np.int32) + in_arr = np.hstack( + [ + x_fine_history[:, -512:].astype(np.int32), + in_arr, + ] + ) + n_history = x_fine_history[:, -512:].shape[1] + else: + n_history = 0 + n_remove_from_end = 0 + # need to pad if too short (since non-causal model) + if in_arr.shape[1] < 1024: + n_remove_from_end = 1024 - in_arr.shape[1] + in_arr = np.hstack( + [ + in_arr, + np.zeros((model.config.N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + + model.config.CODEBOOK_SIZE, + ] + ) + # we can be lazy about fractional loop and just keep overwriting codebooks + n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1 + with inference_mode(): + in_arr = torch.tensor(in_arr.T).to(model.device) + for n in tqdm.tqdm(range(n_loops), disable=silent): + start_idx = np.min([n * 512, in_arr.shape[0] - 1024]) + start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512]) + rel_start_fill_idx = start_fill_idx - start_idx + in_buffer = in_arr[start_idx : start_idx + 1024, :][None] + for nn in range(n_coarse, model.config.N_FINE_CODEBOOKS): + logits = model.fine_model(nn, in_buffer) + if temp is None: + relevant_logits = logits[0, rel_start_fill_idx:, : model.config.CODEBOOK_SIZE] + codebook_preds = torch.argmax(relevant_logits, -1) + else: + relevant_logits = logits[0, :, : model.config.CODEBOOK_SIZE] / temp + probs = F.softmax(relevant_logits, dim=-1) + codebook_preds = torch.hstack( + [torch.multinomial(probs[n], num_samples=1) for n in range(rel_start_fill_idx, 1024)] + ) + in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds + del logits, codebook_preds + # transfer over info into model_in and convert to numpy + for nn in range(n_coarse, model.config.N_FINE_CODEBOOKS): + in_arr[start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn] = in_buffer[ + 0, rel_start_fill_idx:, nn + ] + del in_buffer + gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T + del in_arr + gen_fine_arr = gen_fine_arr[:, n_history:] + if n_remove_from_end > 0: + gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end] + assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1] + clear_cuda_cache() + return gen_fine_arr + + +def codec_decode(fine_tokens, model): + """Turn quantized audio codes into audio array using encodec.""" + arr = torch.from_numpy(fine_tokens)[None] + arr = arr.to(model.device) + arr = arr.transpose(0, 1) + emb = model.encodec.quantizer.decode(arr) + out = model.encodec.decoder(emb) + audio_arr = out.detach().cpu().numpy().squeeze() + return audio_arr diff --git a/TTS/tts/layers/bark/load_model.py b/TTS/tts/layers/bark/load_model.py new file mode 100644 index 0000000000000000000000000000000000000000..72eca30ac68660598669e9f929667978a910ecfc --- /dev/null +++ b/TTS/tts/layers/bark/load_model.py @@ -0,0 +1,161 @@ +import contextlib +import functools +import hashlib +import logging +import os + +import requests +import torch +import tqdm + +from TTS.tts.layers.bark.model import GPT, GPTConfig +from TTS.tts.layers.bark.model_fine import FineGPT, FineGPTConfig +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +if ( + torch.cuda.is_available() + and hasattr(torch.cuda, "amp") + and hasattr(torch.cuda.amp, "autocast") + and torch.cuda.is_bf16_supported() +): + autocast = functools.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16) +else: + + @contextlib.contextmanager + def autocast(): + yield + + +# hold models in global scope to lazy load + +logger = logging.getLogger(__name__) + + +if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + logger.warning( + "torch version does not support flash attention. You will get significantly faster" + + " inference speed by upgrade torch to newest version / nightly." + ) + + +def _md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def _download(from_s3_path, to_local_path, CACHE_DIR): + os.makedirs(CACHE_DIR, exist_ok=True) + response = requests.get(from_s3_path, stream=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + with open(to_local_path, "wb") as file: + for data in response.iter_content(block_size): + progress_bar.update(len(data)) + file.write(data) + progress_bar.close() + if total_size_in_bytes not in [0, progress_bar.n]: + raise ValueError("ERROR, something went wrong") + + +class InferenceContext: + def __init__(self, benchmark=False): + # we can't expect inputs to be the same length, so disable benchmarking by default + self._chosen_cudnn_benchmark = benchmark + self._cudnn_benchmark = None + + def __enter__(self): + self._cudnn_benchmark = torch.backends.cudnn.benchmark + torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark + + def __exit__(self, exc_type, exc_value, exc_traceback): + torch.backends.cudnn.benchmark = self._cudnn_benchmark + + +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +@contextlib.contextmanager +def inference_mode(): + with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast(): + yield + + +def clear_cuda_cache(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def load_model(ckpt_path, device, config, model_type="text"): + logger.info(f"loading {model_type} model from {ckpt_path}...") + + if device == "cpu": + logger.warning("No GPU being used. Careful, Inference might be extremely slow!") + if model_type == "text": + ConfigClass = GPTConfig + ModelClass = GPT + elif model_type == "coarse": + ConfigClass = GPTConfig + ModelClass = GPT + elif model_type == "fine": + ConfigClass = FineGPTConfig + ModelClass = FineGPT + else: + raise NotImplementedError() + if ( + not config.USE_SMALLER_MODELS + and os.path.exists(ckpt_path) + and _md5(ckpt_path) != config.REMOTE_MODEL_PATHS[model_type]["checksum"] + ): + logger.warning(f"found outdated {model_type} model, removing...") + os.remove(ckpt_path) + if not os.path.exists(ckpt_path): + logger.info(f"{model_type} model not found, downloading...") + _download(config.REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path, config.CACHE_DIR) + + checkpoint = torch.load(ckpt_path, map_location=device, weights_only=is_pytorch_at_least_2_4()) + # this is a hack + model_args = checkpoint["model_args"] + if "input_vocab_size" not in model_args: + model_args["input_vocab_size"] = model_args["vocab_size"] + model_args["output_vocab_size"] = model_args["vocab_size"] + del model_args["vocab_size"] + + gptconf = ConfigClass(**checkpoint["model_args"]) + if model_type == "text": + config.semantic_config = gptconf + elif model_type == "coarse": + config.coarse_config = gptconf + elif model_type == "fine": + config.fine_config = gptconf + + model = ModelClass(gptconf) + state_dict = checkpoint["model"] + # fixup checkpoint + unwanted_prefix = "_orig_mod." + for k, _ in list(state_dict.items()): + if k.startswith(unwanted_prefix): + state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) + extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) + extra_keys = set(k for k in extra_keys if not k.endswith(".attn.bias")) + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + missing_keys = set(k for k in missing_keys if not k.endswith(".attn.bias")) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + model.load_state_dict(state_dict, strict=False) + n_params = model.get_num_params() + val_loss = checkpoint["best_val_loss"].item() + logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") + model.eval() + model.to(device) + del checkpoint, state_dict + clear_cuda_cache() + return model, config diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py new file mode 100644 index 0000000000000000000000000000000000000000..68c50dbdbdeb0e8ae3ade906f008614ba1de2d0d --- /dev/null +++ b/TTS/tts/layers/bark/model.py @@ -0,0 +1,234 @@ +""" +Much of this code is adapted from Andrej Karpathy's NanoGPT +(https://github.com/karpathy/nanoGPT) +""" + +import math +from dataclasses import dataclass + +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, x): + return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") + if not self.flash: + # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x, past_kv=None, use_cache=False): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + if past_kv is not None: + # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains + # the query for the last token. scaled_dot_product_attention interprets this as the first token in the + # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so + # to work around this we set is_causal=False. + is_causal = False + else: + is_causal = True + + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, FULL_T - T : FULL_T, :FULL_T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return (y, present) + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + self.gelu = nn.GELU() + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + self.layer_idx = layer_idx + + def forward(self, x, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) + x = x + attn_output + x = x + self.mlp(self.ln_2(x)) + return (x, prev_kvs) + + +@dataclass +class GPTConfig(Coqpit): + block_size: int = 1024 + input_vocab_size: int = 10_048 + output_vocab_size: int = 10_048 + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + assert config.input_vocab_size is not None + assert config.output_vocab_size is not None + assert config.block_size is not None + self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.input_vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False) + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wte.weight.numel() + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): + device = idx.device + _, t = idx.size() + if past_kv is not None: + assert t == 1 + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + else: + if merge_context: + assert idx.shape[1] >= 256 + 256 + 1 + t = idx.shape[1] - 256 + else: + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the GPT model itself + if merge_context: + tok_emb = torch.cat( + [ + self.transformer.wte(idx[:, :256]) + self.transformer.wte(idx[:, 256 : 256 + 256]), + self.transformer.wte(idx[:, 256 + 256 :]), + ], + dim=1, + ) + else: + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * len(self.transformer.h)) + else: + past_length = past_kv[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, t) + assert position_ids.shape == (1, t) + + pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) + + x = self.transformer.drop(tok_emb + pos_emb) + + new_kv = () if use_cache else None + + for _, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) + + if use_cache: + new_kv = new_kv + (kv,) + + x = self.transformer.ln_f(x) + + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + + return (logits, new_kv) diff --git a/TTS/tts/layers/bark/model_fine.py b/TTS/tts/layers/bark/model_fine.py new file mode 100644 index 0000000000000000000000000000000000000000..29126b41ab1b49b3c67e3eb30707cccd553f2a06 --- /dev/null +++ b/TTS/tts/layers/bark/model_fine.py @@ -0,0 +1,143 @@ +""" +Much of this code is adapted from Andrej Karpathy's NanoGPT +(https://github.com/karpathy/nanoGPT) +""" + +import math +from dataclasses import dataclass + +import torch +from torch import nn +from torch.nn import functional as F + +from .model import GPT, MLP, GPTConfig + + +class NonCausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0 + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class FineBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embd) + self.attn = NonCausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embd) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class FineGPT(GPT): + def __init__(self, config): + super().__init__(config) + del self.lm_head + self.config = config + self.n_codes_total = config.n_codes_total + self.transformer = nn.ModuleDict( + dict( + wtes=nn.ModuleList( + [nn.Embedding(config.input_vocab_size, config.n_embd) for _ in range(config.n_codes_total)] + ), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]), + ln_f=nn.LayerNorm(config.n_embd), + ) + ) + self.lm_heads = nn.ModuleList( + [ + nn.Linear(config.n_embd, config.output_vocab_size, bias=False) + for _ in range(config.n_codes_given, self.n_codes_total) + ] + ) + for i in range(self.n_codes_total - config.n_codes_given): + self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight + + def forward(self, pred_idx, idx): + device = idx.device + b, t, codes = idx.size() + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + assert pred_idx > 0, "cannot predict 0th codebook" + assert codes == self.n_codes_total, (b, t, codes) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_embs = [ + wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes) + ] # token embeddings of shape (b, t, n_embd) + tok_emb = torch.cat(tok_embs, dim=-1) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) + x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1) + x = self.transformer.drop(x + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_heads[pred_idx - self.config.n_codes_given](x) + return logits + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + for wte in self.transformer.wtes: + n_params -= wte.weight.numel() + n_params -= self.transformer.wpe.weight.numel() + return n_params + + +@dataclass +class FineGPTConfig(GPTConfig): + n_codes_total: int = 8 + n_codes_given: int = 1 diff --git a/TTS/tts/layers/delightful_tts/__init__.py b/TTS/tts/layers/delightful_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py new file mode 100644 index 0000000000000000000000000000000000000000..83989f9ba485647249597f074ca75b7e614fa2aa --- /dev/null +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -0,0 +1,566 @@ +### credit: https://github.com/dunky11/voicesmith +import logging +from typing import Callable, Dict, Tuple + +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn + +from TTS.tts.layers.delightful_tts.conformer import Conformer +from TTS.tts.layers.delightful_tts.encoders import ( + PhonemeLevelProsodyEncoder, + UtteranceLevelProsodyEncoder, + get_mask_from_lengths, +) +from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor +from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding +from TTS.tts.layers.delightful_tts.phoneme_prosody_predictor import PhonemeProsodyPredictor +from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask + +logger = logging.getLogger(__name__) + + +class AcousticModel(torch.nn.Module): + def __init__( + self, + args: "ModelArgs", + tokenizer: "TTSTokenizer" = None, + speaker_manager: "SpeakerManager" = None, + ): + super().__init__() + self.args = args + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + + self.init_multispeaker(args) + # self.set_embedding_dims() + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb_dim = args.n_hidden_conformer_encoder + self.encoder = Conformer( + dim=self.args.n_hidden_conformer_encoder, + n_layers=self.args.n_layers_conformer_encoder, + n_heads=self.args.n_heads_conformer_encoder, + speaker_embedding_dim=self.embedded_speaker_dim, + p_dropout=self.args.dropout_conformer_encoder, + kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + self.pitch_adaptor = PitchAdaptor( + n_input=self.args.n_hidden_conformer_encoder, + n_hidden=self.args.n_hidden_variance_adaptor, + n_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, + p_dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + self.energy_adaptor = EnergyAdaptor( + channels_in=self.args.n_hidden_conformer_encoder, + channels_hidden=self.args.n_hidden_variance_adaptor, + channels_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + emb_kernel_size=self.args.emb_kernel_size_variance_adaptor, + dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, + in_key_channels=self.args.n_hidden_conformer_encoder, + ) + + self.duration_predictor = VariancePredictor( + channels_in=self.args.n_hidden_conformer_encoder, + channels=self.args.n_hidden_variance_adaptor, + channels_out=1, + kernel_size=self.args.kernel_size_variance_adaptor, + p_dropout=self.args.dropout_variance_adaptor, + lrelu_slope=self.args.lrelu_slope, + ) + + self.utterance_prosody_encoder = UtteranceLevelProsodyEncoder( + num_mels=self.args.num_mels, + ref_enc_filters=self.args.ref_enc_filters_reference_encoder, + ref_enc_size=self.args.ref_enc_size_reference_encoder, + ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, + ref_enc_strides=self.args.ref_enc_strides_reference_encoder, + n_hidden=self.args.n_hidden_conformer_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size_u=self.args.bottleneck_size_u_reference_encoder, + token_num=self.args.token_num_reference_encoder, + ) + + self.utterance_prosody_predictor = PhonemeProsodyPredictor( + hidden_size=self.args.n_hidden_conformer_encoder, + kernel_size=self.args.predictor_kernel_size_reference_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size=self.args.bottleneck_size_u_reference_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + + self.phoneme_prosody_encoder = PhonemeLevelProsodyEncoder( + num_mels=self.args.num_mels, + ref_enc_filters=self.args.ref_enc_filters_reference_encoder, + ref_enc_size=self.args.ref_enc_size_reference_encoder, + ref_enc_gru_size=self.args.ref_enc_gru_size_reference_encoder, + ref_enc_strides=self.args.ref_enc_strides_reference_encoder, + n_hidden=self.args.n_hidden_conformer_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size_p=self.args.bottleneck_size_p_reference_encoder, + n_heads=self.args.n_heads_conformer_encoder, + ) + + self.phoneme_prosody_predictor = PhonemeProsodyPredictor( + hidden_size=self.args.n_hidden_conformer_encoder, + kernel_size=self.args.predictor_kernel_size_reference_encoder, + dropout=self.args.dropout_conformer_encoder, + bottleneck_size=self.args.bottleneck_size_p_reference_encoder, + lrelu_slope=self.args.lrelu_slope, + ) + + self.u_bottle_out = nn.Linear( + self.args.bottleneck_size_u_reference_encoder, + self.args.n_hidden_conformer_encoder, + ) + + self.u_norm = nn.InstanceNorm1d(self.args.bottleneck_size_u_reference_encoder) + self.p_bottle_out = nn.Linear( + self.args.bottleneck_size_p_reference_encoder, + self.args.n_hidden_conformer_encoder, + ) + self.p_norm = nn.InstanceNorm1d( + self.args.bottleneck_size_p_reference_encoder, + ) + self.decoder = Conformer( + dim=self.args.n_hidden_conformer_decoder, + n_layers=self.args.n_layers_conformer_decoder, + n_heads=self.args.n_heads_conformer_decoder, + speaker_embedding_dim=self.embedded_speaker_dim, + p_dropout=self.args.dropout_conformer_decoder, + kernel_size_conv_mod=self.args.kernel_size_conv_mod_conformer_decoder, + lrelu_slope=self.args.lrelu_slope, + ) + + padding_idx = self.tokenizer.characters.pad_id + self.src_word_emb = EmbeddingPadded( + self.args.num_chars, self.args.n_hidden_conformer_encoder, padding_idx=padding_idx + ) + self.to_mel = nn.Linear( + self.args.n_hidden_conformer_decoder, + self.args.num_mels, + ) + + self.energy_scaler = torch.nn.BatchNorm1d(1, affine=False, track_running_stats=True, momentum=None) + self.energy_scaler.requires_grad_(False) + + def init_multispeaker(self, args: Coqpit): # pylint: disable=unused-argument + """Init for multi-speaker training.""" + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]) # .unsqueeze_(-1) + if g.ndim == 2: + g = g # .unsqueeze_(0) # pylint: disable=self-assigning-variable + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + # def set_embedding_dims(self): + # if self.embedded_speaker_dim > 0: + # self.embedding_dims = self.embedded_speaker_dim + # else: + # self.embedding_dims = 0 + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + logger.info("Initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the linear scale durations. + + Args: + dr (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations + if None. Defaults to None. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def _expand_encoder_with_durations( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.IntTensor, + y_lengths: torch.IntTensor, + ): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en]) + return y_mask, o_en_ex, attn.transpose(1, 2) + + def _forward_aligner( + self, + x: torch.FloatTensor, + y: torch.FloatTensor, + x_mask: torch.IntTensor, + y_mask: torch.IntTensor, + attn_priors: torch.FloatTensor, + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + attn_priors (torch.FloatTensor): Prior for the aligner network map. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + - attn_priors: :math:`[B, T_de, T_en]` + + - aligner_durations: :math:`[B, T_en]` + - aligner_soft: :math:`[B, T_de, T_en]` + - aligner_logprob: :math:`[B, 1, T_de, T_en]` + - aligner_mas: :math:`[B, T_de, T_en]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # [B, 1, T_en, T_de] + aligner_soft, aligner_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, attn_priors) + aligner_mas = maximum_path( + aligner_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + aligner_durations = torch.sum(aligner_mas, -1).int() + aligner_soft = aligner_soft.squeeze(1) # [B, T_max2, T_max] + aligner_mas = aligner_mas.transpose(1, 2) # [B, T_max, T_max2] -> [B, T_max2, T_max] + return aligner_durations, aligner_soft, aligner_logprob, aligner_mas + + def average_utterance_prosody( # pylint: disable=no-self-use + self, u_prosody_pred: torch.Tensor, src_mask: torch.Tensor + ) -> torch.Tensor: + lengths = ((~src_mask) * 1.0).sum(1) + u_prosody_pred = u_prosody_pred.sum(1, keepdim=True) / lengths.view(-1, 1, 1) + return u_prosody_pred + + def forward( + self, + tokens: torch.Tensor, + src_lens: torch.Tensor, + mels: torch.Tensor, + mel_lens: torch.Tensor, + pitches: torch.Tensor, + energies: torch.Tensor, + attn_priors: torch.Tensor, + use_ground_truth: bool = True, + d_vectors: torch.Tensor = None, + speaker_idx: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable + {"d_vectors": d_vectors, "speaker_ids": speaker_idx} + ) # pylint: disable=unused-variable + + src_mask = get_mask_from_lengths(src_lens) # [B, T_src] + mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel] + + # Token embeddings + token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden] + token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) + + # Alignment network and durations + aligner_durations, aligner_soft, aligner_logprob, aligner_mas = self._forward_aligner( + x=token_embeddings, + y=mels.transpose(1, 2), + x_mask=~src_mask[:, None], + y_mask=~mel_mask[:, None], + attn_priors=attn_priors, + ) + dr = aligner_durations # [B, T_en] + + # Embeddings + speaker_embedding = None + if d_vectors is not None: + speaker_embedding = g + elif speaker_idx is not None: + speaker_embedding = F.normalize(self.emb_g(sid)) + + pos_encoding = positional_encoding( + self.emb_dim, + max(token_embeddings.shape[1], *mel_lens), + device=token_embeddings.device, + ) + encoder_outputs = self.encoder( + token_embeddings, + src_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + + u_prosody_ref = self.u_norm(self.utterance_prosody_encoder(mels=mels, mel_lens=mel_lens)) + u_prosody_pred = self.u_norm( + self.average_utterance_prosody( + u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), + src_mask=src_mask, + ) + ) + + if use_ground_truth: + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_ref) + else: + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred) + + p_prosody_ref = self.p_norm( + self.phoneme_prosody_encoder( + x=encoder_outputs, src_mask=src_mask, mels=mels, mel_lens=mel_lens, encoding=pos_encoding + ) + ) + p_prosody_pred = self.p_norm(self.phoneme_prosody_predictor(x=encoder_outputs, mask=src_mask)) + + if use_ground_truth: + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_ref) + else: + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred) + + encoder_outputs_res = encoder_outputs + + pitch_pred, avg_pitch_target, pitch_emb = self.pitch_adaptor.get_pitch_embedding_train( + x=encoder_outputs, + target=pitches, + dr=dr, + mask=src_mask, + ) + + energy_pred, avg_energy_target, energy_emb = self.energy_adaptor.get_energy_embedding_train( + x=encoder_outputs, + target=energies, + dr=dr, + mask=src_mask, + ) + + encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb + log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask) + + mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( + o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None] + ) + + x = self.decoder( + encoder_outputs_ex.transpose(1, 2), + mel_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + x = self.to_mel(x) + + dr = torch.log(dr + 1) + + dr_pred = torch.exp(log_duration_prediction) - 1 + alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] + + return { + "model_outputs": x, + "pitch_pred": pitch_pred, + "pitch_target": avg_pitch_target, + "energy_pred": energy_pred, + "energy_target": avg_energy_target, + "u_prosody_pred": u_prosody_pred, + "u_prosody_ref": u_prosody_ref, + "p_prosody_pred": p_prosody_pred, + "p_prosody_ref": p_prosody_ref, + "alignments_dp": alignments_dp, + "alignments": alignments, # [B, T_de, T_en] + "aligner_soft": aligner_soft, + "aligner_mas": aligner_mas, + "aligner_durations": aligner_durations, + "aligner_logprob": aligner_logprob, + "dr_log_pred": log_duration_prediction.squeeze(1), # [B, T] + "dr_log_target": dr.squeeze(1), # [B, T] + "spk_emb": speaker_embedding, + } + + @torch.no_grad() + def inference( + self, + tokens: torch.Tensor, + speaker_idx: torch.Tensor, + p_control: float = None, # TODO # pylint: disable=unused-argument + d_control: float = None, # TODO # pylint: disable=unused-argument + d_vectors: torch.Tensor = None, + pitch_transform: Callable = None, + energy_transform: Callable = None, + ) -> torch.Tensor: + src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) + src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable + sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable + {"d_vectors": d_vectors, "speaker_ids": speaker_idx} + ) # pylint: disable=unused-variable + + token_embeddings = self.src_word_emb(tokens) + token_embeddings = token_embeddings.masked_fill(src_mask.unsqueeze(-1), 0.0) + + # Embeddings + speaker_embedding = None + if d_vectors is not None: + speaker_embedding = g + elif speaker_idx is not None: + speaker_embedding = F.normalize(self.emb_g(sid)) + + pos_encoding = positional_encoding( + self.emb_dim, + token_embeddings.shape[1], + device=token_embeddings.device, + ) + encoder_outputs = self.encoder( + token_embeddings, + src_mask, + speaker_embedding=speaker_embedding, + encoding=pos_encoding, + ) + + u_prosody_pred = self.u_norm( + self.average_utterance_prosody( + u_prosody_pred=self.utterance_prosody_predictor(x=encoder_outputs, mask=src_mask), + src_mask=src_mask, + ) + ) + encoder_outputs = encoder_outputs + self.u_bottle_out(u_prosody_pred).expand_as(encoder_outputs) + + p_prosody_pred = self.p_norm( + self.phoneme_prosody_predictor( + x=encoder_outputs, + mask=src_mask, + ) + ) + encoder_outputs = encoder_outputs + self.p_bottle_out(p_prosody_pred).expand_as(encoder_outputs) + + encoder_outputs_res = encoder_outputs + + pitch_emb_pred, pitch_pred = self.pitch_adaptor.get_pitch_embedding( + x=encoder_outputs, + mask=src_mask, + pitch_transform=pitch_transform, + pitch_mean=self.pitch_mean if hasattr(self, "pitch_mean") else None, + pitch_std=self.pitch_std if hasattr(self, "pitch_std") else None, + ) + + energy_emb_pred, energy_pred = self.energy_adaptor.get_energy_embedding( + x=encoder_outputs, mask=src_mask, energy_transform=energy_transform + ) + encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb_pred + energy_emb_pred + + log_duration_pred = self.duration_predictor( + x=encoder_outputs_res.detach(), mask=src_mask + ) # [B, C_hidden, T_src] -> [B, T_src] + duration_pred = (torch.exp(log_duration_pred) - 1) * (~src_mask) * self.length_scale # -> [B, T_src] + duration_pred[duration_pred < 1] = 1.0 # -> [B, T_src] + duration_pred = torch.round(duration_pred) # -> [B, T_src] + mel_lens = duration_pred.sum(1) # -> [B,] + + _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( + o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None] + ) + + mel_mask = get_mask_from_lengths( + torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device) + ) + + if encoder_outputs_ex.shape[1] > pos_encoding.shape[1]: + encoding = positional_encoding(self.emb_dim, encoder_outputs_ex.shape[2], device=tokens.device) + + # [B, C_hidden, T_src], [B, 1, T_src], [B, C_emb], [B, T_src, C_hidden] -> [B, C_hidden, T_src] + x = self.decoder( + encoder_outputs_ex.transpose(1, 2), + mel_mask, + speaker_embedding=speaker_embedding, + encoding=encoding, + ) + x = self.to_mel(x) + outputs = { + "model_outputs": x, + "alignments": alignments, + # "pitch": pitch_emb_pred, + "durations": duration_pred, + "pitch": pitch_pred, + "energy": energy_pred, + "spk_emb": speaker_embedding, + } + return outputs diff --git a/TTS/tts/layers/delightful_tts/conformer.py b/TTS/tts/layers/delightful_tts/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b2175b3b965c6b100846e87d405a753dc272c9e7 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/conformer.py @@ -0,0 +1,450 @@ +### credit: https://github.com/dunky11/voicesmith +import math +from typing import Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d +from TTS.tts.layers.delightful_tts.networks import GLUActivation + + +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class Conformer(nn.Module): + def __init__( + self, + dim: int, + n_layers: int, + n_heads: int, + speaker_embedding_dim: int, + p_dropout: float, + kernel_size_conv_mod: int, + lrelu_slope: float, + ): + """ + A Transformer variant that integrates both CNNs and Transformers components. + Conformer proposes a novel combination of self-attention and convolution, in which self-attention + learns the global interaction while the convolutions efficiently capture the local correlations. + + Args: + dim (int): Number of the dimensions for the model. + n_layers (int): Number of model layers. + n_heads (int): The number of attention heads. + speaker_embedding_dim (int): Number of speaker embedding dimensions. + p_dropout (float): Probabilty of dropout. + kernel_size_conv_mod (int): Size of kernels for convolution modules. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **encoding** (batch, time, dim): Positional embedding tensor + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produced by Conformer Encoder. + """ + super().__init__() + d_k = d_v = dim // n_heads + self.layer_stack = nn.ModuleList( + [ + ConformerBlock( + dim, + n_heads, + d_k, + d_v, + kernel_size_conv_mod=kernel_size_conv_mod, + dropout=p_dropout, + speaker_embedding_dim=speaker_embedding_dim, + lrelu_slope=lrelu_slope, + ) + for _ in range(n_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + speaker_embedding: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + Shapes: + - x: :math:`[B, T_src, C]` + - mask: :math: `[B]` + - speaker_embedding: :math: `[B, C]` + - encoding: :math: `[B, T_max2, C]` + """ + + attn_mask = mask.view((mask.shape[0], 1, 1, mask.shape[1])) + for enc_layer in self.layer_stack: + x = enc_layer( + x, + mask=mask, + slf_attn_mask=attn_mask, + speaker_embedding=speaker_embedding, + encoding=encoding, + ) + return x + + +class ConformerBlock(torch.nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + d_k: int, # pylint: disable=unused-argument + d_v: int, # pylint: disable=unused-argument + kernel_size_conv_mod: int, + speaker_embedding_dim: int, + dropout: float, + lrelu_slope: float = 0.3, + ): + """ + A Conformer block is composed of four modules stacked together, + A feed-forward module, a self-attention module, a convolution module, + and a second feed-forward module in the end. The block starts with two Feed forward + modules sandwiching the Multi-Headed Self-Attention module and the Conv module. + + Args: + d_model (int): The dimension of model + n_head (int): The number of attention heads. + kernel_size_conv_mod (int): Size of kernels for convolution modules. + speaker_embedding_dim (int): Number of speaker embedding dimensions. + emotion_embedding_dim (int): Number of emotion embedding dimensions. + dropout (float): Probabilty of dropout. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **encoding** (batch, time, dim): Positional embedding tensor + - **slf_attn_mask** (batch, 1, 1, time1): Tensor containing indices to be masked in self attention module + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produced by the Conformer Block. + """ + super().__init__() + if isinstance(speaker_embedding_dim, int): + self.conditioning = Conv1dGLU( + d_model=d_model, + kernel_size=kernel_size_conv_mod, + padding=kernel_size_conv_mod // 2, + embedding_dim=speaker_embedding_dim, + ) + + self.ff = FeedForward(d_model=d_model, dropout=dropout, kernel_size=3, lrelu_slope=lrelu_slope) + self.conformer_conv_1 = ConformerConvModule( + d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope + ) + self.ln = nn.LayerNorm(d_model) + self.slf_attn = ConformerMultiHeadedSelfAttention(d_model=d_model, num_heads=n_head, dropout_p=dropout) + self.conformer_conv_2 = ConformerConvModule( + d_model, kernel_size=kernel_size_conv_mod, dropout=dropout, lrelu_slope=lrelu_slope + ) + + def forward( + self, + x: torch.Tensor, + speaker_embedding: torch.Tensor, + mask: torch.Tensor, + slf_attn_mask: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + Shapes: + - x: :math:`[B, T_src, C]` + - mask: :math: `[B]` + - slf_attn_mask: :math: `[B, 1, 1, T_src]` + - speaker_embedding: :math: `[B, C]` + - emotion_embedding: :math: `[B, C]` + - encoding: :math: `[B, T_max2, C]` + """ + if speaker_embedding is not None: + x = self.conditioning(x, embeddings=speaker_embedding) + x = self.ff(x) + x + x = self.conformer_conv_1(x) + x + res = x + x = self.ln(x) + x, _ = self.slf_attn(query=x, key=x, value=x, mask=slf_attn_mask, encoding=encoding) + x = x + res + x = x.masked_fill(mask.unsqueeze(-1), 0) + + x = self.conformer_conv_2(x) + x + return x + + +class FeedForward(nn.Module): + def __init__( + self, + d_model: int, + kernel_size: int, + dropout: float, + lrelu_slope: float, + expansion_factor: int = 4, + ): + """ + Feed Forward module for conformer block. + + Args: + d_model (int): The dimension of model. + kernel_size (int): Size of the kernels for conv layers. + dropout (float): probability of dropout. + expansion_factor (int): The factor by which to project the number of channels. + lrelu_slope (int): the negative slope factor for the leaky relu activation. + + Inputs: inputs + - **inputs** (batch, time, dim): Tensor containing input vector + Returns: + - **outputs** (batch, time, dim): Tensor produced by the feed forward module. + """ + super().__init__() + self.dropout = nn.Dropout(dropout) + self.ln = nn.LayerNorm(d_model) + self.conv_1 = nn.Conv1d( + d_model, + d_model * expansion_factor, + kernel_size=kernel_size, + padding=kernel_size // 2, + ) + self.act = nn.LeakyReLU(lrelu_slope) + self.conv_2 = nn.Conv1d(d_model * expansion_factor, d_model, kernel_size=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, C]` + """ + x = self.ln(x) + x = x.permute((0, 2, 1)) + x = self.conv_1(x) + x = x.permute((0, 2, 1)) + x = self.act(x) + x = self.dropout(x) + x = x.permute((0, 2, 1)) + x = self.conv_2(x) + x = x.permute((0, 2, 1)) + x = self.dropout(x) + x = 0.5 * x + return x + + +class ConformerConvModule(nn.Module): + def __init__( + self, + d_model: int, + expansion_factor: int = 2, + kernel_size: int = 7, + dropout: float = 0.1, + lrelu_slope: float = 0.3, + ): + """ + Convolution module for conformer. Starts with a gating machanism. + a pointwise convolution and a gated linear unit (GLU). This is followed + by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution + to help with training. it also contains an expansion factor to project the number of channels. + + Args: + d_model (int): The dimension of model. + expansion_factor (int): The factor by which to project the number of channels. + kernel_size (int): Size of kernels for convolution modules. + dropout (float): Probabilty of dropout. + lrelu_slope (float): The slope coefficient for leaky relu activation. + + Inputs: inputs + - **inputs** (batch, time, dim): Tensor containing input vector + Returns: + - **outputs** (batch, time, dim): Tensor produced by the conv module. + + """ + super().__init__() + inner_dim = d_model * expansion_factor + self.ln_1 = nn.LayerNorm(d_model) + self.conv_1 = PointwiseConv1d(d_model, inner_dim * 2) + self.conv_act = GLUActivation(slope=lrelu_slope) + self.depthwise = DepthWiseConv1d( + inner_dim, + inner_dim, + kernel_size=kernel_size, + padding=calc_same_padding(kernel_size)[0], + ) + self.ln_2 = nn.GroupNorm(1, inner_dim) + self.activation = nn.LeakyReLU(lrelu_slope) + self.conv_2 = PointwiseConv1d(inner_dim, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, C]` + """ + x = self.ln_1(x) + x = x.permute(0, 2, 1) + x = self.conv_1(x) + x = self.conv_act(x) + x = self.depthwise(x) + x = self.ln_2(x) + x = self.activation(x) + x = self.conv_2(x) + x = x.permute(0, 2, 1) + x = self.dropout(x) + return x + + +class ConformerMultiHeadedSelfAttention(nn.Module): + """ + Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL, + the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention + module to generalize better on different input length and the resulting encoder is more robust to the variance of + the utterance length. Conformer use prenorm residual units with dropout which helps training + and regularizing deeper models. + Args: + d_model (int): The dimension of model + num_heads (int): The number of attention heads. + dropout_p (float): probability of dropout + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module. + """ + + def __init__(self, d_model: int, num_heads: int, dropout_p: float): + super().__init__() + self.attention = RelativeMultiHeadAttention(d_model=d_model, num_heads=num_heads) + self.dropout = nn.Dropout(p=dropout_p) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + encoding: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_length, _ = key.size() # pylint: disable=unused-variable + encoding = encoding[:, : key.shape[1]] + encoding = encoding.repeat(batch_size, 1, 1) + outputs, attn = self.attention(query, key, value, pos_embedding=encoding, mask=mask) + outputs = self.dropout(outputs) + return outputs, attn + + +class RelativeMultiHeadAttention(nn.Module): + """ + Multi-head attention with relative positional encoding. + This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Args: + d_model (int): The dimension of model + num_heads (int): The number of attention heads. + Inputs: query, key, value, pos_embedding, mask + - **query** (batch, time, dim): Tensor containing query vector + - **key** (batch, time, dim): Tensor containing key vector + - **value** (batch, time, dim): Tensor containing value vector + - **pos_embedding** (batch, time, dim): Positional embedding tensor + - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked + Returns: + - **outputs**: Tensor produces by relative multi head attention module. + """ + + def __init__( + self, + d_model: int = 512, + num_heads: int = 16, + ): + super().__init__() + assert d_model % num_heads == 0, "d_model % num_heads should be zero." + self.d_model = d_model + self.d_head = int(d_model / num_heads) + self.num_heads = num_heads + self.sqrt_dim = math.sqrt(d_model) + + self.query_proj = nn.Linear(d_model, d_model) + self.key_proj = nn.Linear(d_model, d_model, bias=False) + self.value_proj = nn.Linear(d_model, d_model, bias=False) + self.pos_proj = nn.Linear(d_model, d_model, bias=False) + + self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) + torch.nn.init.xavier_uniform_(self.u_bias) + torch.nn.init.xavier_uniform_(self.v_bias) + self.out_proj = nn.Linear(d_model, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_embedding: torch.Tensor, + mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = query.shape[0] + query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) + key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) + value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) + pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) + u_bias = self.u_bias.expand_as(query) + v_bias = self.v_bias.expand_as(query) + a = (query + u_bias).transpose(1, 2) + content_score = a @ key.transpose(2, 3) + b = (query + v_bias).transpose(1, 2) + pos_score = b @ pos_embedding.permute(0, 2, 3, 1) + pos_score = self._relative_shift(pos_score) + + score = content_score + pos_score + score = score * (1.0 / self.sqrt_dim) + + score.masked_fill_(mask, -1e9) + + attn = F.softmax(score, -1) + + context = (attn @ value).transpose(1, 2) + context = context.contiguous().view(batch_size, -1, self.d_model) + + return self.out_proj(context), attn + + def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor: # pylint: disable=no-self-use + batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() + zeros = torch.zeros((batch_size, num_heads, seq_length1, 1), device=pos_score.device) + padded_pos_score = torch.cat([zeros, pos_score], dim=-1) + padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) + pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) + return pos_score + + +class MultiHeadAttention(nn.Module): + """ + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + """ + + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor: + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, num_units] + values = self.W_value(key) + split_size = self.num_units // self.num_heads + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + # score = softmax(QK^T / (d_k ** 0.5)) + scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim**0.5) + scores = F.softmax(scores, dim=3) + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + return out diff --git a/TTS/tts/layers/delightful_tts/conv_layers.py b/TTS/tts/layers/delightful_tts/conv_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9aa4495fd9b25fb88ce0dd1493cfa1f2c47d5c --- /dev/null +++ b/TTS/tts/layers/delightful_tts/conv_layers.py @@ -0,0 +1,671 @@ +from typing import Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F +from torch.nn.utils import parametrize + +from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor + + +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +class ConvNorm(nn.Module): + """A 1-dimensional convolutional layer with optional weight normalization. + + This layer wraps a 1D convolutional layer from PyTorch and applies + optional weight normalization. The layer can be used in a similar way to + the convolutional layers in PyTorch's `torch.nn` module. + + Args: + in_channels (int): The number of channels in the input signal. + out_channels (int): The number of channels in the output signal. + kernel_size (int, optional): The size of the convolving kernel. + Defaults to 1. + stride (int, optional): The stride of the convolution. Defaults to 1. + padding (int, optional): Zero-padding added to both sides of the input. + If `None`, the padding will be calculated so that the output has + the same length as the input. Defaults to `None`. + dilation (int, optional): Spacing between kernel elements. Defaults to 1. + bias (bool, optional): If `True`, add bias after convolution. Defaults to `True`. + w_init_gain (str, optional): The weight initialization function to use. + Can be either 'linear' or 'relu'. Defaults to 'linear'. + use_weight_norm (bool, optional): If `True`, apply weight normalization + to the convolutional weights. Defaults to `False`. + + Shapes: + - Input: :math:`[N, D, T]` + + - Output: :math:`[N, out_dim, T]` where `out_dim` is the number of output dimensions. + + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=None, + dilation=1, + bias=True, + w_init_gain="linear", + use_weight_norm=False, + ): + super(ConvNorm, self).__init__() # pylint: disable=super-with-arguments + if padding is None: + assert kernel_size % 2 == 1 + padding = int(dilation * (kernel_size - 1) / 2) + self.kernel_size = kernel_size + self.dilation = dilation + self.use_weight_norm = use_weight_norm + conv_fn = nn.Conv1d + self.conv = conv_fn( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + ) + nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain)) + if self.use_weight_norm: + self.conv = nn.utils.parametrizations.weight_norm(self.conv) + + def forward(self, signal, mask=None): + conv_signal = self.conv(signal) + if mask is not None: + # always re-zero output if mask is + # available to match zero-padding + conv_signal = conv_signal * mask + return conv_signal + + +class ConvLSTMLinear(nn.Module): + def __init__( + self, + in_dim, + out_dim, + n_layers=2, + n_channels=256, + kernel_size=3, + p_dropout=0.1, + lstm_type="bilstm", + use_linear=True, + ): + super(ConvLSTMLinear, self).__init__() # pylint: disable=super-with-arguments + self.out_dim = out_dim + self.lstm_type = lstm_type + self.use_linear = use_linear + self.dropout = nn.Dropout(p=p_dropout) + + convolutions = [] + for i in range(n_layers): + conv_layer = ConvNorm( + in_dim if i == 0 else n_channels, + n_channels, + kernel_size=kernel_size, + stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, + w_init_gain="relu", + ) + conv_layer = nn.utils.parametrizations.weight_norm(conv_layer.conv, name="weight") + convolutions.append(conv_layer) + + self.convolutions = nn.ModuleList(convolutions) + + if not self.use_linear: + n_channels = out_dim + + if self.lstm_type != "": + use_bilstm = False + lstm_channels = n_channels + if self.lstm_type == "bilstm": + use_bilstm = True + lstm_channels = int(n_channels // 2) + + self.bilstm = nn.LSTM(n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm) + lstm_norm_fn_pntr = nn.utils.spectral_norm + self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0") + if self.lstm_type == "bilstm": + self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse") + + if self.use_linear: + self.dense = nn.Linear(n_channels, out_dim) + + def run_padded_sequence(self, context, lens): + context_embedded = [] + for b_ind in range(context.size()[0]): # TODO: speed up + curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() + for conv in self.convolutions: + curr_context = self.dropout(F.relu(conv(curr_context))) + context_embedded.append(curr_context[0].transpose(0, 1)) + context = nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) + return context + + def run_unsorted_inputs(self, fn, context, lens): # pylint: disable=no-self-use + lens_sorted, ids_sorted = torch.sort(lens, descending=True) + unsort_ids = [0] * lens.size(0) + for i in range(len(ids_sorted)): # pylint: disable=consider-using-enumerate + unsort_ids[ids_sorted[i]] = i + lens_sorted = lens_sorted.long().cpu() + + context = context[ids_sorted] + context = nn.utils.rnn.pack_padded_sequence(context, lens_sorted, batch_first=True) + context = fn(context)[0] + context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0] + + # map back to original indices + context = context[unsort_ids] + return context + + def forward(self, context, lens): + if context.size()[0] > 1: + context = self.run_padded_sequence(context, lens) + # to B, D, T + context = context.transpose(1, 2) + else: + for conv in self.convolutions: + context = self.dropout(F.relu(conv(context))) + + if self.lstm_type != "": + context = context.transpose(1, 2) + self.bilstm.flatten_parameters() + if lens is not None: + context = self.run_unsorted_inputs(self.bilstm, context, lens) + else: + context = self.bilstm(context)[0] + context = context.transpose(1, 2) + + x_hat = context + if self.use_linear: + x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2) + + return x_hat + + +class DepthWiseConv1d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int): + super().__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=in_channels) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class PointwiseConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class BSConv1d(nn.Module): + """https://arxiv.org/pdf/2003.13549.pdf""" + + def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): + super().__init__() + self.pointwise = nn.Conv1d(channels_in, channels_out, kernel_size=1) + self.depthwise = nn.Conv1d( + channels_out, + channels_out, + kernel_size=kernel_size, + padding=padding, + groups=channels_out, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.pointwise(x) + x2 = self.depthwise(x1) + return x2 + + +class BSConv2d(nn.Module): + """https://arxiv.org/pdf/2003.13549.pdf""" + + def __init__(self, channels_in: int, channels_out: int, kernel_size: int, padding: int): + super().__init__() + self.pointwise = nn.Conv2d(channels_in, channels_out, kernel_size=1) + self.depthwise = nn.Conv2d( + channels_out, + channels_out, + kernel_size=kernel_size, + padding=padding, + groups=channels_out, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x1 = self.pointwise(x) + x2 = self.depthwise(x1) + return x2 + + +class Conv1dGLU(nn.Module): + """From DeepVoice 3""" + + def __init__(self, d_model: int, kernel_size: int, padding: int, embedding_dim: int): + super().__init__() + self.conv = BSConv1d(d_model, 2 * d_model, kernel_size=kernel_size, padding=padding) + self.embedding_proj = nn.Linear(embedding_dim, d_model) + self.register_buffer("sqrt", torch.sqrt(torch.FloatTensor([0.5])).squeeze(0)) + self.softsign = torch.nn.Softsign() + + def forward(self, x: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor: + x = x.permute((0, 2, 1)) + residual = x + x = self.conv(x) + splitdim = 1 + a, b = x.split(x.size(splitdim) // 2, dim=splitdim) + embeddings = self.embedding_proj(embeddings).unsqueeze(2) + softsign = self.softsign(embeddings) + softsign = softsign.expand_as(a) + a = a + softsign + x = a * torch.sigmoid(b) + x = x + residual + x = x * self.sqrt + x = x.permute((0, 2, 1)) + return x + + +class ConvTransposed(nn.Module): + """ + A 1D convolutional transposed layer for PyTorch. + This layer applies a 1D convolutional transpose operation to its input tensor, + where the number of channels of the input tensor is the same as the number of channels of the output tensor. + + Attributes: + in_channels (int): The number of channels in the input tensor. + out_channels (int): The number of channels in the output tensor. + kernel_size (int): The size of the convolutional kernel. Default: 1. + padding (int): The number of padding elements to add to the input tensor. Default: 0. + conv (BSConv1d): The 1D convolutional transpose layer. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 1, + padding: int = 0, + ): + super().__init__() + self.conv = BSConv1d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.contiguous().transpose(1, 2) + x = self.conv(x) + x = x.contiguous().transpose(1, 2) + return x + + +class DepthwiseConvModule(nn.Module): + def __init__(self, dim: int, kernel_size: int = 7, expansion: int = 4, lrelu_slope: float = 0.3): + super().__init__() + padding = calc_same_padding(kernel_size) + self.depthwise = nn.Conv1d( + dim, + dim * expansion, + kernel_size=kernel_size, + padding=padding[0], + groups=dim, + ) + self.act = nn.LeakyReLU(lrelu_slope) + self.out = nn.Conv1d(dim * expansion, dim, 1, 1, 0) + self.ln = nn.LayerNorm(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln(x) + x = x.permute((0, 2, 1)) + x = self.depthwise(x) + x = self.act(x) + x = self.out(x) + x = x.permute((0, 2, 1)) + return x + + +class AddCoords(nn.Module): + def __init__(self, rank: int, with_r: bool = False): + super().__init__() + self.rank = rank + self.with_r = with_r + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.rank == 1: + batch_size_shape, channel_in_shape, dim_x = x.shape # pylint: disable=unused-variable + xx_range = torch.arange(dim_x, dtype=torch.int32) + xx_channel = xx_range[None, None, :] + + xx_channel = xx_channel.float() / (dim_x - 1) + xx_channel = xx_channel * 2 - 1 + xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) + + xx_channel = xx_channel.to(x.device) + out = torch.cat([x, xx_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) + out = torch.cat([out, rr], dim=1) + + elif self.rank == 2: + batch_size_shape, channel_in_shape, dim_y, dim_x = x.shape + xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) + + xx_range = torch.arange(dim_y, dtype=torch.int32) + yy_range = torch.arange(dim_x, dtype=torch.int32) + xx_range = xx_range[None, None, :, None] + yy_range = yy_range[None, None, :, None] + + xx_channel = torch.matmul(xx_range, xx_ones) + yy_channel = torch.matmul(yy_range, yy_ones) + + # transpose y + yy_channel = yy_channel.permute(0, 1, 3, 2) + + xx_channel = xx_channel.float() / (dim_y - 1) + yy_channel = yy_channel.float() / (dim_x - 1) + + xx_channel = xx_channel * 2 - 1 + yy_channel = yy_channel * 2 - 1 + + xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) + yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) + + xx_channel = xx_channel.to(x.device) + yy_channel = yy_channel.to(x.device) + + out = torch.cat([x, xx_channel, yy_channel], dim=1) + + if self.with_r: + rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)) + out = torch.cat([out, rr], dim=1) + + elif self.rank == 3: + batch_size_shape, channel_in_shape, dim_z, dim_y, dim_x = x.shape + xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) + yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) + zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) + + xy_range = torch.arange(dim_y, dtype=torch.int32) + xy_range = xy_range[None, None, None, :, None] + + yz_range = torch.arange(dim_z, dtype=torch.int32) + yz_range = yz_range[None, None, None, :, None] + + zx_range = torch.arange(dim_x, dtype=torch.int32) + zx_range = zx_range[None, None, None, :, None] + + xy_channel = torch.matmul(xy_range, xx_ones) + xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2) + + yz_channel = torch.matmul(yz_range, yy_ones) + yz_channel = yz_channel.permute(0, 1, 3, 4, 2) + yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4) + + zx_channel = torch.matmul(zx_range, zz_ones) + zx_channel = zx_channel.permute(0, 1, 4, 2, 3) + zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3) + + xx_channel = xx_channel.to(x.device) + yy_channel = yy_channel.to(x.device) + zz_channel = zz_channel.to(x.device) + out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1) + + if self.with_r: + rr = torch.sqrt( + torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2) + torch.pow(zz_channel - 0.5, 2) + ) + out = torch.cat([out, rr], dim=1) + else: + raise NotImplementedError + + return out + + +class CoordConv1d(nn.modules.conv.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.rank = 1 + self.addcoords = AddCoords(self.rank, with_r) + self.conv = nn.Conv1d( + in_channels + self.rank + int(with_r), + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.addcoords(x) + x = self.conv(x) + return x + + +class CoordConv2d(nn.modules.conv.Conv2d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + with_r: bool = False, + ): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + self.rank = 2 + self.addcoords = AddCoords(self.rank, with_r) + self.conv = nn.Conv2d( + in_channels + self.rank + int(with_r), + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.addcoords(x) + x = self.conv(x) + return x + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( # pylint: disable=dangerous-default-value + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + in_channels, + in_channels, + conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, + dilation=dilation, + ) + ), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): # pylint: disable=no-self-use + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + parametrize.remove_parametrizations(self.convt_pre[1], "weight") + for block in self.conv_blocks: + parametrize.remove_parametrizations(block[1], "weight") diff --git a/TTS/tts/layers/delightful_tts/encoders.py b/TTS/tts/layers/delightful_tts/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..0878f0677a29d092597a46e8a3b11e4a521769b8 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/encoders.py @@ -0,0 +1,261 @@ +from typing import List, Tuple, Union + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention +from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d +from TTS.tts.layers.delightful_tts.networks import STL + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +class ReferenceEncoder(nn.Module): + """ + Referance encoder for utterance and phoneme prosody encoders. Reference encoder + made up of convolution and RNN layers. + + Args: + num_mels (int): Number of mel frames to produce. + ref_enc_filters (list[int]): List of channel sizes for encoder layers. + ref_enc_size (int): Size of the kernel for the conv layers. + ref_enc_strides (List[int]): List of strides to use for conv layers. + ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. + + Inputs: inputs, mask + - **inputs** (batch, dim, time): Tensor containing mel vector + - **lengths** (batch): Tensor containing the mel lengths. + Returns: + - **outputs** (batch, time, dim): Tensor produced by Reference Encoder. + """ + + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + ): + super().__init__() + + n_mel_channels = num_mels + self.n_mel_channels = n_mel_channels + K = len(ref_enc_filters) + filters = [self.n_mel_channels] + ref_enc_filters + strides = [1] + ref_enc_strides + # Use CoordConv at the first layer to better preserve positional information: https://arxiv.org/pdf/1811.02122.pdf + convs = [ + CoordConv1d( + in_channels=filters[0], + out_channels=filters[0 + 1], + kernel_size=ref_enc_size, + stride=strides[0], + padding=ref_enc_size // 2, + with_r=True, + ) + ] + convs2 = [ + nn.Conv1d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=ref_enc_size, + stride=strides[i], + padding=ref_enc_size // 2, + ) + for i in range(1, K) + ] + convs.extend(convs2) + self.convs = nn.ModuleList(convs) + + self.norms = nn.ModuleList([nn.InstanceNorm1d(num_features=ref_enc_filters[i], affine=True) for i in range(K)]) + + self.gru = nn.GRU( + input_size=ref_enc_filters[-1], + hidden_size=ref_enc_gru_size, + batch_first=True, + ) + + def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + inputs --- [N, n_mels, timesteps] + outputs --- [N, E//2] + """ + + mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1) + x = x.masked_fill(mel_masks, 0) + for conv, norm in zip(self.convs, self.norms): + x = conv(x) + x = F.leaky_relu(x, 0.3) # [N, 128, Ty//2^K, n_mels//2^K] + x = norm(x) + + for _ in range(2): + mel_lens = stride_lens(mel_lens) + + mel_masks = get_mask_from_lengths(mel_lens) + + x = x.masked_fill(mel_masks.unsqueeze(1), 0) + x = x.permute((0, 2, 1)) + x = torch.nn.utils.rnn.pack_padded_sequence(x, mel_lens.cpu().int(), batch_first=True, enforce_sorted=False) + + self.gru.flatten_parameters() + x, memory = self.gru(x) # memory --- [N, Ty, E//2], out --- [1, N, E//2] + x, _ = torch.nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + + return x, memory, mel_masks + + def calculate_channels( # pylint: disable=no-self-use + self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int + ) -> int: + for _ in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class UtteranceLevelProsodyEncoder(nn.Module): + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + dropout: float, + n_hidden: int, + bottleneck_size_u: int, + token_num: int, + ): + """ + Encoder to extract prosody from utterance. it is made up of a reference encoder + with a couple of linear layers and style token layer with dropout. + + Args: + num_mels (int): Number of mel frames to produce. + ref_enc_filters (list[int]): List of channel sizes for ref encoder layers. + ref_enc_size (int): Size of the kernel for the ref encoder conv layers. + ref_enc_strides (List[int]): List of strides to use for teh ref encoder conv layers. + ref_enc_gru_size (int): Number of hidden features for the gated recurrent unit. + dropout (float): Probability of dropout. + n_hidden (int): Size of hidden layers. + bottleneck_size_u (int): Size of the bottle neck layer. + + Inputs: inputs, mask + - **inputs** (batch, dim, time): Tensor containing mel vector + - **lengths** (batch): Tensor containing the mel lengths. + Returns: + - **outputs** (batch, 1, dim): Tensor produced by Utterance Level Prosody Encoder. + """ + super().__init__() + + self.E = n_hidden + self.d_q = self.d_k = n_hidden + bottleneck_size = bottleneck_size_u + + self.encoder = ReferenceEncoder( + ref_enc_filters=ref_enc_filters, + ref_enc_gru_size=ref_enc_gru_size, + ref_enc_size=ref_enc_size, + ref_enc_strides=ref_enc_strides, + num_mels=num_mels, + ) + self.encoder_prj = nn.Linear(ref_enc_gru_size, self.E // 2) + self.stl = STL(n_hidden=n_hidden, token_num=token_num) + self.encoder_bottleneck = nn.Linear(self.E, bottleneck_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, mels: torch.Tensor, mel_lens: torch.Tensor) -> torch.Tensor: + """ + Shapes: + mels: :math: `[B, C, T]` + mel_lens: :math: `[B]` + + out --- [N, seq_len, E] + """ + _, embedded_prosody, _ = self.encoder(mels, mel_lens) + + # Bottleneck + embedded_prosody = self.encoder_prj(embedded_prosody) + + # Style Token + out = self.encoder_bottleneck(self.stl(embedded_prosody)) + out = self.dropout(out) + + out = out.view((-1, 1, out.shape[3])) + return out + + +class PhonemeLevelProsodyEncoder(nn.Module): + def __init__( + self, + num_mels: int, + ref_enc_filters: List[Union[int, int, int, int, int, int]], + ref_enc_size: int, + ref_enc_strides: List[Union[int, int, int, int, int]], + ref_enc_gru_size: int, + dropout: float, + n_hidden: int, + n_heads: int, + bottleneck_size_p: int, + ): + super().__init__() + + self.E = n_hidden + self.d_q = self.d_k = n_hidden + bottleneck_size = bottleneck_size_p + + self.encoder = ReferenceEncoder( + ref_enc_filters=ref_enc_filters, + ref_enc_gru_size=ref_enc_gru_size, + ref_enc_size=ref_enc_size, + ref_enc_strides=ref_enc_strides, + num_mels=num_mels, + ) + self.encoder_prj = nn.Linear(ref_enc_gru_size, n_hidden) + self.attention = ConformerMultiHeadedSelfAttention( + d_model=n_hidden, + num_heads=n_heads, + dropout_p=dropout, + ) + self.encoder_bottleneck = nn.Linear(n_hidden, bottleneck_size) + + def forward( + self, + x: torch.Tensor, + src_mask: torch.Tensor, + mels: torch.Tensor, + mel_lens: torch.Tensor, + encoding: torch.Tensor, + ) -> torch.Tensor: + """ + x --- [N, seq_len, encoder_embedding_dim] + mels --- [N, Ty/r, n_mels*r], r=1 + out --- [N, seq_len, bottleneck_size] + attn --- [N, seq_len, ref_len], Ty/r = ref_len + """ + embedded_prosody, _, mel_masks = self.encoder(mels, mel_lens) + + # Bottleneck + embedded_prosody = self.encoder_prj(embedded_prosody) + + attn_mask = mel_masks.view((mel_masks.shape[0], 1, 1, -1)) + x, _ = self.attention( + query=x, + key=embedded_prosody, + value=embedded_prosody, + mask=attn_mask, + encoding=encoding, + ) + x = self.encoder_bottleneck(x) + x = x.masked_fill(src_mask.unsqueeze(-1), 0.0) + return x diff --git a/TTS/tts/layers/delightful_tts/energy_adaptor.py b/TTS/tts/layers/delightful_tts/energy_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..ea0d1e47214d81a42b934bbaaa4b3ebb9f63bcc6 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/energy_adaptor.py @@ -0,0 +1,82 @@ +from typing import Callable, Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.utils.helpers import average_over_durations + + +class EnergyAdaptor(nn.Module): # pylint: disable=abstract-method + """Variance Adaptor with an added 1D conv layer. Used to + get energy embeddings. + + Args: + channels_in (int): Number of in channels for conv layers. + channels_out (int): Number of out channels. + kernel_size (int): Size the kernel for the conv layers. + dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + emb_kernel_size (int): Size the kernel for the pitch embedding. + + Inputs: inputs, mask + - **inputs** (batch, time1, dim): Tensor containing input vector + - **target** (batch, 1, time2): Tensor containing the energy target + - **dr** (batch, time1): Tensor containing aligner durations vector + - **mask** (batch, time1): Tensor containing indices to be masked + Returns: + - **energy prediction** (batch, 1, time1): Tensor produced by energy predictor + - **energy embedding** (batch, channels, time1): Tensor produced energy adaptor + - **average energy target(train only)** (batch, 1, time1): Tensor produced after averaging over durations + + """ + + def __init__( + self, + channels_in: int, + channels_hidden: int, + channels_out: int, + kernel_size: int, + dropout: float, + lrelu_slope: float, + emb_kernel_size: int, + ): + super().__init__() + self.energy_predictor = VariancePredictor( + channels_in=channels_in, + channels=channels_hidden, + channels_out=channels_out, + kernel_size=kernel_size, + p_dropout=dropout, + lrelu_slope=lrelu_slope, + ) + self.energy_emb = nn.Conv1d( + 1, + channels_hidden, + kernel_size=emb_kernel_size, + padding=int((emb_kernel_size - 1) / 2), + ) + + def get_energy_embedding_train( + self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shapes: + x: :math: `[B, T_src, C]` + target: :math: `[B, 1, T_max2]` + dr: :math: `[B, T_src]` + mask: :math: `[B, T_src]` + """ + energy_pred = self.energy_predictor(x, mask) + energy_pred.unsqueeze_(1) + avg_energy_target = average_over_durations(target, dr) + energy_emb = self.energy_emb(avg_energy_target) + return energy_pred, avg_energy_target, energy_emb + + def get_energy_embedding(self, x: torch.Tensor, mask: torch.Tensor, energy_transform: Callable) -> torch.Tensor: + energy_pred = self.energy_predictor(x, mask) + energy_pred.unsqueeze_(1) + if energy_transform is not None: + energy_pred = energy_transform(energy_pred, (~mask).sum(dim=(1, 2)), self.pitch_mean, self.pitch_std) + energy_emb_pred = self.energy_emb(energy_pred) + return energy_emb_pred, energy_pred diff --git a/TTS/tts/layers/delightful_tts/kernel_predictor.py b/TTS/tts/layers/delightful_tts/kernel_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..96c550b6c2609c52762bb3eaca373e31f0599bf8 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/kernel_predictor.py @@ -0,0 +1,128 @@ +import torch.nn as nn # pylint: disable=consider-using-from-import +from torch.nn.utils import parametrize + + +class KernelPredictor(nn.Module): + """Kernel predictor for the location-variable convolutions + + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + + """ + + def __init__( # pylint: disable=dangerous-default-value + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + nn.utils.parametrizations.weight_norm( + nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + parametrize.remove_parametrizations(self.input_conv[0], "weight") + parametrize.remove_parametrizations(self.kernel_conv, "weight") + parametrize.remove_parametrizations(self.bias_conv, "weight") + for block in self.residual_convs: + parametrize.remove_parametrizations(block[1], "weight") + parametrize.remove_parametrizations(block[3], "weight") diff --git a/TTS/tts/layers/delightful_tts/networks.py b/TTS/tts/layers/delightful_tts/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..4305022f18cf95565b2da2553740276818fb486c --- /dev/null +++ b/TTS/tts/layers/delightful_tts/networks.py @@ -0,0 +1,219 @@ +import math +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import +import torch.nn.functional as F + +from TTS.tts.layers.delightful_tts.conv_layers import ConvNorm + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." + # Kaiming initialization + return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +def positional_encoding(d_model: int, length: int, device: torch.device) -> torch.Tensor: + pe = torch.zeros(length, d_model, device=device) + position = torch.arange(0, length, dtype=torch.float, device=device).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2, device=device).float() * -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + return pe + + +class BottleneckLayer(nn.Module): + """ + Bottleneck layer for reducing the dimensionality of a tensor. + + Args: + in_dim: The number of input dimensions. + reduction_factor: The factor by which to reduce the number of dimensions. + norm: The normalization method to use. Can be "weightnorm" or "instancenorm". + non_linearity: The non-linearity to use. Can be "relu" or "leakyrelu". + kernel_size: The size of the convolutional kernel. + use_partial_padding: Whether to use partial padding with the convolutional kernel. + + Shape: + - Input: :math:`[N, in_dim]` where `N` is the batch size and `in_dim` is the number of input dimensions. + + - Output: :math:`[N, out_dim]` where `out_dim` is the number of output dimensions. + """ + + def __init__( + self, + in_dim, + reduction_factor, + norm="weightnorm", + non_linearity="relu", + kernel_size=3, + use_partial_padding=False, # pylint: disable=unused-argument + ): + super(BottleneckLayer, self).__init__() # pylint: disable=super-with-arguments + + self.reduction_factor = reduction_factor + reduced_dim = int(in_dim / reduction_factor) + self.out_dim = reduced_dim + if self.reduction_factor > 1: + fn = ConvNorm(in_dim, reduced_dim, kernel_size=kernel_size, use_weight_norm=(norm == "weightnorm")) + if norm == "instancenorm": + fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True)) + + self.projection_fn = fn + self.non_linearity = nn.ReLU() + if non_linearity == "leakyrelu": + self.non_linearity = nn.LeakyReLU() + + def forward(self, x): + if self.reduction_factor > 1: + x = self.projection_fn(x) + x = self.non_linearity(x) + return x + + +class GLUActivation(nn.Module): + """Class that implements the Gated Linear Unit (GLU) activation function. + + The GLU activation function is a variant of the Leaky ReLU activation function, + where the output of the activation function is gated by an input tensor. + + """ + + def __init__(self, slope: float): + super().__init__() + self.lrelu = nn.LeakyReLU(slope) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out, gate = x.chunk(2, dim=1) + x = out * self.lrelu(gate) + return x + + +class StyleEmbedAttention(nn.Module): + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query: torch.Tensor, key_soft: torch.Tensor) -> torch.Tensor: + values = self.W_value(key_soft) + split_size = self.num_units // self.num_heads + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) + + out_soft = scores_soft = None + querys = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key_soft) # [N, T_k, num_units] + + # [h, N, T_q, num_units/h] + querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) + # [h, N, T_k, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) + # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k ** 0.5)) + scores_soft = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores_soft = scores_soft / (self.key_dim**0.5) + scores_soft = F.softmax(scores_soft, dim=3) + + # out = score * V + # [h, N, T_q, num_units/h] + out_soft = torch.matmul(scores_soft, values) + out_soft = torch.cat(torch.split(out_soft, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out_soft # , scores_soft + + +class EmbeddingPadded(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): + super().__init__() + padding_mult = torch.ones((num_embeddings, 1), dtype=torch.int64) + padding_mult[padding_idx] = 0 + self.register_buffer("padding_mult", padding_mult) + self.embeddings = nn.parameter.Parameter(initialize_embeddings((num_embeddings, embedding_dim))) + + def forward(self, idx: torch.Tensor) -> torch.Tensor: + embeddings_zeroed = self.embeddings * self.padding_mult + x = F.embedding(idx, embeddings_zeroed) + return x + + +class EmbeddingProjBlock(nn.Module): + def __init__(self, embedding_dim: int): + super().__init__() + self.layers = nn.ModuleList( + [ + nn.Linear(embedding_dim, embedding_dim), + nn.LeakyReLU(0.3), + nn.Linear(embedding_dim, embedding_dim), + nn.LeakyReLU(0.3), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + res = x + for layer in self.layers: + x = layer(x) + x = x + res + return x + + +class LinearNorm(nn.Module): + def __init__(self, in_features: int, out_features: int, bias: bool = False): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias) + + nn.init.xavier_uniform_(self.linear.weight) + if bias: + nn.init.constant_(self.linear.bias, 0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear(x) + return x + + +class STL(nn.Module): + """ + A PyTorch module for the Style Token Layer (STL) as described in + "A Style-Based Generator Architecture for Generative Adversarial Networks" + (https://arxiv.org/abs/1812.04948) + + The STL applies a multi-headed attention mechanism over the learned style tokens, + using the text input as the query and the style tokens as the keys and values. + The output of the attention mechanism is used as the text's style embedding. + + Args: + token_num (int): The number of style tokens. + n_hidden (int): Number of hidden dimensions. + """ + + def __init__(self, n_hidden: int, token_num: int): + super(STL, self).__init__() # pylint: disable=super-with-arguments + + num_heads = 1 + E = n_hidden + self.token_num = token_num + self.embed = nn.Parameter(torch.FloatTensor(self.token_num, E // num_heads)) + d_q = E // 2 + d_k = E // num_heads + self.attention = StyleEmbedAttention(query_dim=d_q, key_dim=d_k, num_units=E, num_heads=num_heads) + + torch.nn.init.normal_(self.embed, mean=0, std=0.5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + N = x.size(0) + query = x.unsqueeze(1) # [N, 1, E//2] + + keys_soft = torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1) # [N, token_num, E // num_heads] + + # Weighted sum + emotion_embed_soft = self.attention(query, keys_soft) + + return emotion_embed_soft diff --git a/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py b/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..28418f7163361120914f277446f76ac9f0363254 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/phoneme_prosody_predictor.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed + + +class PhonemeProsodyPredictor(nn.Module): + """Non-parallel Prosody Predictor inspired by: https://arxiv.org/pdf/2102.00851.pdf + It consists of 2 layers of 1D convolutions each followed by a relu activation, layer norm + and dropout, then finally a linear layer. + + Args: + hidden_size (int): Size of hidden channels. + kernel_size (int): Kernel size for the conv layers. + dropout: (float): Probability of dropout. + bottleneck_size (int): bottleneck size for last linear layer. + lrelu_slope (float): Slope of the leaky relu. + """ + + def __init__( + self, + hidden_size: int, + kernel_size: int, + dropout: float, + bottleneck_size: int, + lrelu_slope: float, + ): + super().__init__() + self.d_model = hidden_size + self.layers = nn.ModuleList( + [ + ConvTransposed( + self.d_model, + self.d_model, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(self.d_model), + nn.Dropout(dropout), + ConvTransposed( + self.d_model, + self.d_model, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(self.d_model), + nn.Dropout(dropout), + ] + ) + self.predictor_bottleneck = nn.Linear(self.d_model, bottleneck_size) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T, D]` + mask: :math: `[B, T]` + """ + mask = mask.unsqueeze(2) + for layer in self.layers: + x = layer(x) + x = x.masked_fill(mask, 0.0) + x = self.predictor_bottleneck(x) + return x diff --git a/TTS/tts/layers/delightful_tts/pitch_adaptor.py b/TTS/tts/layers/delightful_tts/pitch_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..9031369e0f019cf115d0d43b288bb97d9db48467 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/pitch_adaptor.py @@ -0,0 +1,88 @@ +from typing import Callable, Tuple + +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor +from TTS.tts.utils.helpers import average_over_durations + + +class PitchAdaptor(nn.Module): # pylint: disable=abstract-method + """Module to get pitch embeddings via pitch predictor + + Args: + n_input (int): Number of pitch predictor input channels. + n_hidden (int): Number of pitch predictor hidden channels. + n_out (int): Number of pitch predictor out channels. + kernel size (int): Size of the kernel for conv layers. + emb_kernel_size (int): Size the kernel for the pitch embedding. + p_dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + + Inputs: inputs, mask + - **inputs** (batch, time1, dim): Tensor containing input vector + - **target** (batch, 1, time2): Tensor containing the pitch target + - **dr** (batch, time1): Tensor containing aligner durations vector + - **mask** (batch, time1): Tensor containing indices to be masked + Returns: + - **pitch prediction** (batch, 1, time1): Tensor produced by pitch predictor + - **pitch embedding** (batch, channels, time1): Tensor produced pitch pitch adaptor + - **average pitch target(train only)** (batch, 1, time1): Tensor produced after averaging over durations + """ + + def __init__( + self, + n_input: int, + n_hidden: int, + n_out: int, + kernel_size: int, + emb_kernel_size: int, + p_dropout: float, + lrelu_slope: float, + ): + super().__init__() + self.pitch_predictor = VariancePredictor( + channels_in=n_input, + channels=n_hidden, + channels_out=n_out, + kernel_size=kernel_size, + p_dropout=p_dropout, + lrelu_slope=lrelu_slope, + ) + self.pitch_emb = nn.Conv1d( + 1, + n_input, + kernel_size=emb_kernel_size, + padding=int((emb_kernel_size - 1) / 2), + ) + + def get_pitch_embedding_train( + self, x: torch.Tensor, target: torch.Tensor, dr: torch.IntTensor, mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shapes: + x: :math: `[B, T_src, C]` + target: :math: `[B, 1, T_max2]` + dr: :math: `[B, T_src]` + mask: :math: `[B, T_src]` + """ + pitch_pred = self.pitch_predictor(x, mask) # [B, T_src, C_hidden], [B, T_src] --> [B, T_src] + pitch_pred.unsqueeze_(1) # --> [B, 1, T_src] + avg_pitch_target = average_over_durations(target, dr) # [B, 1, T_mel], [B, T_src] --> [B, 1, T_src] + pitch_emb = self.pitch_emb(avg_pitch_target) # [B, 1, T_src] --> [B, C_hidden, T_src] + return pitch_pred, avg_pitch_target, pitch_emb + + def get_pitch_embedding( + self, + x: torch.Tensor, + mask: torch.Tensor, + pitch_transform: Callable, + pitch_mean: torch.Tensor, + pitch_std: torch.Tensor, + ) -> torch.Tensor: + pitch_pred = self.pitch_predictor(x, mask) + if pitch_transform is not None: + pitch_pred = pitch_transform(pitch_pred, (~mask).sum(), pitch_mean, pitch_std) + pitch_pred.unsqueeze_(1) + pitch_emb_pred = self.pitch_emb(pitch_pred) + return pitch_emb_pred, pitch_pred diff --git a/TTS/tts/layers/delightful_tts/variance_predictor.py b/TTS/tts/layers/delightful_tts/variance_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..68303a1bd1148089eab7ee8be12d4f37ddf420e1 --- /dev/null +++ b/TTS/tts/layers/delightful_tts/variance_predictor.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn # pylint: disable=consider-using-from-import + +from TTS.tts.layers.delightful_tts.conv_layers import ConvTransposed + + +class VariancePredictor(nn.Module): + """ + Network is 2-layer 1D convolutions with leaky relu activation and then + followed by layer normalization then a dropout layer and finally an + extra linear layer to project the hidden states into the output sequence. + + Args: + channels_in (int): Number of in channels for conv layers. + channels_out (int): Number of out channels for the last linear layer. + kernel_size (int): Size the kernel for the conv layers. + p_dropout (float): Probability of dropout. + lrelu_slope (float): Slope for the leaky relu. + + Inputs: inputs, mask + - **inputs** (batch, time, dim): Tensor containing input vector + - **mask** (batch, time): Tensor containing indices to be masked + Returns: + - **outputs** (batch, time): Tensor produced by last linear layer. + """ + + def __init__( + self, channels_in: int, channels: int, channels_out: int, kernel_size: int, p_dropout: float, lrelu_slope: float + ): + super().__init__() + + self.layers = nn.ModuleList( + [ + ConvTransposed( + channels_in, + channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(channels), + nn.Dropout(p_dropout), + ConvTransposed( + channels, + channels, + kernel_size=kernel_size, + padding=(kernel_size - 1) // 2, + ), + nn.LeakyReLU(lrelu_slope), + nn.LayerNorm(channels), + nn.Dropout(p_dropout), + ] + ) + + self.linear_layer = nn.Linear(channels, channels_out) + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Shapes: + x: :math: `[B, T_src, C]` + mask: :math: `[B, T_src]` + """ + for layer in self.layers: + x = layer(x) + x = self.linear_layer(x) + x = x.squeeze(-1) + x = x.masked_fill(mask, 0.0) + return x diff --git a/TTS/tts/layers/feed_forward/__init__.py b/TTS/tts/layers/feed_forward/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0376e2e3926e65254c3a81d085d48c97df033958 --- /dev/null +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -0,0 +1,228 @@ +import torch +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN, Conv1dBNBlock, ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.generic.wavenet import WNBlocks +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer + + +class WaveNetDecoder(nn.Module): + """WaveNet based decoder with a prenet and a postnet. + + prenet: conv1d_1x1 + postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1 + + TODO: Integrate speaker conditioning vector. + + Note: + default wavenet parameters; + params = { + "num_blocks": 12, + "hidden_channels":192, + "kernel_size": 5, + "dilation_rate": 1, + "num_layers": 4, + "dropout_p": 0.05 + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels for prenet and postnet. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params): + super().__init__() + # prenet + self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1) + # wavenet layers + self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params) + # postnet + self.postnet = [ + torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, hidden_channels, 1), + torch.nn.ReLU(), + torch.nn.Conv1d(hidden_channels, out_channels, 1), + ] + self.postnet = nn.Sequential(*self.postnet) + + def forward(self, x, x_mask=None, g=None): + x = self.prenet(x) * x_mask + x = self.wn(x, x_mask, g) + o = self.postnet(x) * x_mask + return o + + +class RelativePositionTransformerDecoder(nn.Module): + """Decoder with Relative Positional Transformer. + + Note: + Default params + params={ + 'hidden_channels_ffn': 128, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 8, + "rel_attn_window_size": 4, + "input_length": None + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1) + self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + o = self.prenet(x) * x_mask + o = self.rel_pos_transformer(o, x_mask) + return o + + +class FFTransformerDecoder(nn.Module): + """Decoder with FeedForwardTransformer. + + Default params + params={ + 'hidden_channels_ffn': 1024, + 'num_heads': 2, + "dropout_p": 0.1, + "num_layers": 6, + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including Transformer layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, params): + super().__init__() + self.transformer_block = FFTransformerBlock(in_channels, **params) + self.postnet = nn.Conv1d(in_channels, out_channels, 1) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + # TODO: handle multi-speaker + x_mask = 1 if x_mask is None else x_mask + o = self.transformer_block(x) * x_mask + o = self.postnet(o) * x_mask + return o + + +class ResidualConv1dBNDecoder(nn.Module): + """Residual Convolutional Decoder as in the original Speedy Speech paper + + TODO: Integrate speaker conditioning vector. + + Note: + Default params + params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + } + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers. + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params) + self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) + self.postnet = nn.Sequential( + Conv1dBNBlock( + hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2 + ), + nn.Conv1d(hidden_channels, out_channels, 1), + ) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + o = self.res_conv_block(x, x_mask) + o = self.post_conv(o) + x + return self.postnet(o) * x_mask + + +class Decoder(nn.Module): + """Decodes the expanded phoneme encoding into spectrograms + Args: + out_channels (int): number of output channels. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. + decoder_type (str): decoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + decoder_params (dict): model parameters for specified decoder type. + c_in_channels (int): number of channels for conditional input. + + Shapes: + - input: (B, C, T) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + out_channels, + in_hidden_channels, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + c_in_channels=0, + ): + super().__init__() + + if decoder_type.lower() == "relative_position_transformer": + self.decoder = RelativePositionTransformerDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "residual_conv_bn": + self.decoder = ResidualConv1dBNDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "wavenet": + self.decoder = WaveNetDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + c_in_channels=c_in_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "fftransformer": + self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) + else: + raise ValueError(f"[!] Unknown decoder type - {decoder_type}") + + def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument + """ + Args: + x: [B, C, T] + x_mask: [B, 1, T] + g: [B, C_g, 1] + """ + # TODO: implement multi-speaker + o = self.decoder(x, x_mask, g) + return o diff --git a/TTS/tts/layers/feed_forward/duration_predictor.py b/TTS/tts/layers/feed_forward/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..4422648f4337e48aab39671836fcfb5e12ff4be7 --- /dev/null +++ b/TTS/tts/layers/feed_forward/duration_predictor.py @@ -0,0 +1,41 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import Conv1dBN + + +class DurationPredictor(nn.Module): + """Speedy Speech duration predictor model. + Predicts phoneme durations from encoder outputs. + + Note: + Outputs interpreted as log(durations) + To get actual durations, do exp transformation + + conv_BN_4x1 -> conv_BN_3x1 -> conv_BN_1x1 -> conv_1x1 + + Args: + hidden_channels (int): number of channels in the inner layers. + """ + + def __init__(self, hidden_channels): + super().__init__() + + self.layers = nn.ModuleList( + [ + Conv1dBN(hidden_channels, hidden_channels, 4, 1), + Conv1dBN(hidden_channels, hidden_channels, 3, 1), + Conv1dBN(hidden_channels, hidden_channels, 1, 1), + nn.Conv1d(hidden_channels, 1, 1), + ] + ) + + def forward(self, x, x_mask): + """ + Shapes: + x: [B, C, T] + x_mask: [B, 1, T] + """ + o = x + for layer in self.layers: + o = layer(o) * x_mask + return o diff --git a/TTS/tts/layers/feed_forward/encoder.py b/TTS/tts/layers/feed_forward/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..caf939ffc73fedac299228e090b2df3bb4cc553c --- /dev/null +++ b/TTS/tts/layers/feed_forward/encoder.py @@ -0,0 +1,162 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.transformer import FFTransformerBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer + + +class RelativePositionTransformerEncoder(nn.Module): + """Speedy speech encoder built on Transformer with Relative Position encoding. + + TODO: Integrate speaker conditioning vector. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = ResidualConv1dBNBlock( + in_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_res_blocks=3, + num_conv_blocks=1, + dilations=[1, 1, 1], + ) + self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + o = self.prenet(x) * x_mask + o = self.rel_pos_transformer(o, x_mask) + return o + + +class ResidualConv1dBNEncoder(nn.Module): + """Residual Convolutional Encoder as in the original Speedy Speech paper + + TODO: Integrate speaker conditioning vector. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of hidden channels + params (dict): dictionary for residual convolutional blocks. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, params): + super().__init__() + self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU()) + self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params) + + self.postnet = nn.Sequential( + *[ + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.ReLU(), + nn.BatchNorm1d(hidden_channels), + nn.Conv1d(hidden_channels, out_channels, 1), + ] + ) + + def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument + if x_mask is None: + x_mask = 1 + o = self.prenet(x) * x_mask + o = self.res_conv_block(o, x_mask) + o = self.postnet(o + x) * x_mask + return o * x_mask + + +class Encoder(nn.Module): + # pylint: disable=dangerous-default-value + """Factory class for Speedy Speech encoder enables different encoder types internally. + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers. + encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + encoder_params (dict): model parameters for specified encoder type. + c_in_channels (int): number of channels for conditional input. + + Note: + Default encoder_params to be set in config.json... + + ```python + # for 'relative_position_transformer' + encoder_params={ + 'hidden_channels_ffn': 128, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "rel_attn_window_size": 4, + "input_length": None + }, + + # for 'residual_conv_bn' + encoder_params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + # for 'fftransformer' + encoder_params = { + "hidden_channels_ffn": 1024 , + "num_heads": 2, + "num_layers": 6, + "dropout_p": 0.1 + } + ``` + """ + + def __init__( + self, + in_hidden_channels, + out_channels, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + c_in_channels=0, + ): + super().__init__() + self.out_channels = out_channels + self.in_channels = in_hidden_channels + self.hidden_channels = in_hidden_channels + self.encoder_type = encoder_type + self.c_in_channels = c_in_channels + + # init encoder + if encoder_type.lower() == "relative_position_transformer": + # text encoder + # pylint: disable=unexpected-keyword-arg + self.encoder = RelativePositionTransformerEncoder( + in_hidden_channels, out_channels, in_hidden_channels, encoder_params + ) + elif encoder_type.lower() == "residual_conv_bn": + self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params) + elif encoder_type.lower() == "fftransformer": + assert ( + in_hidden_channels == out_channels + ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" + # pylint: disable=unexpected-keyword-arg + self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) + else: + raise NotImplementedError(" [!] unknown encoder type.") + + def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument + """ + Shapes: + x: [B, C, T] + x_mask: [B, 1, T] + g: [B, C, 1] + """ + o = self.encoder(x, x_mask) + return o * x_mask diff --git a/TTS/tts/layers/generic/__init__.py b/TTS/tts/layers/generic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/generic/aligner.py b/TTS/tts/layers/generic/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..baa6f0e9c4879207695b2de1193c9147b5a3fa4b --- /dev/null +++ b/TTS/tts/layers/generic/aligner.py @@ -0,0 +1,92 @@ +from typing import Tuple + +import torch +from torch import nn + + +class AlignmentNetwork(torch.nn.Module): + """Aligner Network for learning alignment between the input text and the model output with Gaussian Attention. + + :: + + query -> conv1d -> relu -> conv1d -> relu -> conv1d -> L2_dist -> softmax -> alignment + key -> conv1d -> relu -> conv1d -----------------------^ + + Args: + in_query_channels (int): Number of channels in the query network. Defaults to 80. + in_key_channels (int): Number of channels in the key network. Defaults to 512. + attn_channels (int): Number of inner channels in the attention layers. Defaults to 80. + temperature (float): Temperature for the softmax. Defaults to 0.0005. + """ + + def __init__( + self, + in_query_channels=80, + in_key_channels=512, + attn_channels=80, + temperature=0.0005, + ): + super().__init__() + self.temperature = temperature + self.softmax = torch.nn.Softmax(dim=3) + self.log_softmax = torch.nn.LogSoftmax(dim=3) + + self.key_layer = nn.Sequential( + nn.Conv1d( + in_key_channels, + in_key_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.query_layer = nn.Sequential( + nn.Conv1d( + in_query_channels, + in_query_channels * 2, + kernel_size=3, + padding=1, + bias=True, + ), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), + torch.nn.ReLU(), + nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), + ) + + self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_(self.key_layer[0].weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.key_layer[2].weight, gain=torch.nn.init.calculate_gain("linear")) + torch.nn.init.xavier_uniform_(self.query_layer[0].weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.query_layer[2].weight, gain=torch.nn.init.calculate_gain("linear")) + torch.nn.init.xavier_uniform_(self.query_layer[4].weight, gain=torch.nn.init.calculate_gain("linear")) + + def forward( + self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None + ) -> Tuple[torch.tensor, torch.tensor]: + """Forward pass of the aligner encoder. + Shapes: + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` + Output: + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. + """ + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_logp = -self.temperature * attn_factor.sum(1, keepdim=True) + if attn_prior is not None: + attn_logp = self.log_softmax(attn_logp) + torch.log(attn_prior[:, None] + 1e-8) + + if mask is not None: + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + + attn = self.softmax(attn_logp) + return attn, attn_logp diff --git a/TTS/tts/layers/generic/gated_conv.py b/TTS/tts/layers/generic/gated_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9a29c4499f970db538a4b99c3c05cba22576195f --- /dev/null +++ b/TTS/tts/layers/generic/gated_conv.py @@ -0,0 +1,37 @@ +from torch import nn + +from .normalization import LayerNorm + + +class GatedConvBlock(nn.Module): + """Gated convolutional block as in https://arxiv.org/pdf/1612.08083.pdf + Args: + in_out_channels (int): number of input/output channels. + kernel_size (int): convolution kernel size. + dropout_p (float): dropout rate. + """ + + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): + super().__init__() + # class arguments + self.dropout_p = dropout_p + self.num_layers = num_layers + # define layers + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)] + self.norm_layers += [LayerNorm(2 * in_out_channels)] + + def forward(self, x, x_mask): + o = x + res = x + for idx in range(self.num_layers): + o = nn.functional.dropout(o, p=self.dropout_p, training=self.training) + o = self.conv_layers[idx](o * x_mask) + o = self.norm_layers[idx](o) + o = nn.functional.glu(o, dim=1) + o = res + o + res = o + return o diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..c0270e405e4246e47b7bc0787e4cd4b069533f92 --- /dev/null +++ b/TTS/tts/layers/generic/normalization.py @@ -0,0 +1,123 @@ +import torch +from torch import nn + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-4): + """Layer norm for the 2nd dimension of the input. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(1, channels, 1) * 0.1) + self.beta = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x): + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + x = x * self.gamma + self.beta + return x + + +class LayerNorm2(nn.Module): + """Layer norm for the 2nd dimension of the input using torch primitive. + Args: + channels (int): number of channels (2nd dimension) of the input. + eps (float): to prevent 0 division + + Shapes: + - input: (B, C, T) + - output: (B, C, T) + """ + + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = torch.nn.functional.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class TemporalBatchNorm1d(nn.BatchNorm1d): + """Normalize each channel separately over time and batch.""" + + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) + + def forward(self, x): + return super().forward(x.transpose(2, 1)).transpose(2, 1) + + +class ActNorm(nn.Module): + """Activation Normalization bijector as an alternative to Batch Norm. It computes + mean and std from a sample data in advance and it uses these values + for normalization at training. + + Args: + channels (int): input channels. + ddi (False): data depended initialization flag. + + Shapes: + - inputs: (B, C, T) + - outputs: (B, C, T) + """ + + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.channels = channels + self.initialized = not ddi + + self.logs = nn.Parameter(torch.zeros(1, channels, 1)) + self.bias = nn.Parameter(torch.zeros(1, channels, 1)) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + if x_mask is None: + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) + x_len = torch.sum(x_mask, [1, 2]) + if not self.initialized: + self.initialize(x, x_mask) + self.initialized = True + + if reverse: + z = (x - self.bias) * torch.exp(-self.logs) * x_mask + logdet = None + else: + z = (self.bias + torch.exp(self.logs) * x) * x_mask + logdet = torch.sum(self.logs) * x_len # [b] + + return z, logdet + + def store_inverse(self): + pass + + def set_ddi(self, ddi): + self.initialized = not ddi + + def initialize(self, x, x_mask): + with torch.no_grad(): + denom = torch.sum(x_mask, [0, 2]) + m = torch.sum(x * x_mask, [0, 2]) / denom + m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom + v = m_sq - (m**2) + logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) + + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) + + self.bias.data.copy_(bias_init) + self.logs.data.copy_(logs_init) diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..913add0d14332bf70c3ecd2a95869d0071310bd4 --- /dev/null +++ b/TTS/tts/layers/generic/pos_encoding.py @@ -0,0 +1,69 @@ +import math + +import torch +from torch import nn + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for non-recurrent neural networks. + Implementation based on "Attention Is All You Need" + + Args: + channels (int): embedding size + dropout_p (float): dropout rate applied to the output. + max_len (int): maximum sequence length. + use_scale (bool): whether to use a learnable scaling coefficient. + """ + + def __init__(self, channels, dropout_p=0.0, max_len=5000, use_scale=False): + super().__init__() + if channels % 2 != 0: + raise ValueError( + "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) + ) + self.use_scale = use_scale + if use_scale: + self.scale = torch.nn.Parameter(torch.ones(1)) + pe = torch.zeros(max_len, channels) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0).transpose(1, 2) + self.register_buffer("pe", pe) + if dropout_p > 0: + self.dropout = nn.Dropout(p=dropout_p) + self.channels = channels + + def forward(self, x, mask=None, first_idx=None, last_idx=None): + """ + Shapes: + x: [B, C, T] + mask: [B, 1, T] + first_idx: int + last_idx: int + """ + + x = x * math.sqrt(self.channels) + if first_idx is None: + if self.pe.size(2) < x.size(2): + raise RuntimeError( + f"Sequence is {x.size(2)} but PositionalEncoding is" + f" limited to {self.pe.size(2)}. See max_len argument." + ) + if mask is not None: + pos_enc = self.pe[:, :, : x.size(2)] * mask + else: + pos_enc = self.pe[:, :, : x.size(2)] + if self.use_scale: + x = x + self.scale * pos_enc + else: + x = x + pos_enc + else: + if self.use_scale: + x = x + self.scale * self.pe[:, :, first_idx:last_idx] + else: + x = x + self.pe[:, :, first_idx:last_idx] + if hasattr(self, "dropout"): + x = self.dropout(x) + return x diff --git a/TTS/tts/layers/generic/res_conv_bn.py b/TTS/tts/layers/generic/res_conv_bn.py new file mode 100644 index 0000000000000000000000000000000000000000..4beda291aa15398024b5b16cd6bf12b88898a0a9 --- /dev/null +++ b/TTS/tts/layers/generic/res_conv_bn.py @@ -0,0 +1,127 @@ +from torch import nn + + +class ZeroTemporalPad(nn.Module): + """Pad sequences to equal lentgh in the temporal dimension""" + + def __init__(self, kernel_size, dilation): + super().__init__() + total_pad = dilation * (kernel_size - 1) + begin = total_pad // 2 + end = total_pad - begin + self.pad_layer = nn.ZeroPad2d((0, 0, begin, end)) + + def forward(self, x): + return self.pad_layer(x) + + +class Conv1dBN(nn.Module): + """1d convolutional with batch norm. + conv1d -> relu -> BN blocks. + + Note: + Batch normalization is applied after ReLU regarding the original implementation. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): kernel size for convolutional filters. + dilation (int): dilation for convolution layers. + """ + + def __init__(self, in_channels, out_channels, kernel_size, dilation): + super().__init__() + padding = dilation * (kernel_size - 1) + pad_s = padding // 2 + pad_e = padding - pad_s + self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation) + self.pad = nn.ZeroPad2d((pad_s, pad_e, 0, 0)) # uneven left and right padding + self.norm = nn.BatchNorm1d(out_channels) + + def forward(self, x): + o = self.conv1d(x) + o = self.pad(o) + o = nn.functional.relu(o) + o = self.norm(o) + return o + + +class Conv1dBNBlock(nn.Module): + """1d convolutional block with batch norm. It is a set of conv1d -> relu -> BN blocks. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of inner convolution channels. + kernel_size (int): kernel size for convolutional filters. + dilation (int): dilation for convolution layers. + num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2): + super().__init__() + self.conv_bn_blocks = [] + for idx in range(num_conv_blocks): + layer = Conv1dBN( + in_channels if idx == 0 else hidden_channels, + out_channels if idx == (num_conv_blocks - 1) else hidden_channels, + kernel_size, + dilation, + ) + self.conv_bn_blocks.append(layer) + self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks) + + def forward(self, x): + """ + Shapes: + x: (B, D, T) + """ + return self.conv_bn_blocks(x) + + +class ResidualConv1dBNBlock(nn.Module): + """Residual Convolutional Blocks with BN + Each block has 'num_conv_block' conv layers and 'num_res_blocks' such blocks are connected + with residual connections. + + conv_block = (conv1d -> relu -> bn) x 'num_conv_blocks' + residuak_conv_block = (x -> conv_block -> + ->) x 'num_res_blocks' + ' - - - - - - - - - ^ + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + hidden_channels (int): number of inner convolution channels. + kernel_size (int): kernel size for convolutional filters. + dilations (list): dilations for each convolution layer. + num_res_blocks (int, optional): number of residual blocks. Defaults to 13. + num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2. + """ + + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2 + ): + super().__init__() + assert len(dilations) == num_res_blocks + self.res_blocks = nn.ModuleList() + for idx, dilation in enumerate(dilations): + block = Conv1dBNBlock( + in_channels if idx == 0 else hidden_channels, + out_channels if (idx + 1) == len(dilations) else hidden_channels, + hidden_channels, + kernel_size, + dilation, + num_conv_blocks, + ) + self.res_blocks.append(block) + + def forward(self, x, x_mask=None): + if x_mask is None: + x_mask = 1.0 + o = x * x_mask + for block in self.res_blocks: + res = o + o = block(o) + o = o + res + if x_mask is not None: + o = o * x_mask + return o diff --git a/TTS/tts/layers/generic/time_depth_sep_conv.py b/TTS/tts/layers/generic/time_depth_sep_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..186cea02e75e156c40923de91086c369a9ea02ee --- /dev/null +++ b/TTS/tts/layers/generic/time_depth_sep_conv.py @@ -0,0 +1,84 @@ +import torch +from torch import nn + + +class TimeDepthSeparableConv(nn.Module): + """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf + It shows competative results with less computation and memory footprint.""" + + def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hid_channels = hid_channels + self.kernel_size = kernel_size + + self.time_conv = nn.Conv1d( + in_channels, + 2 * hid_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm1 = nn.BatchNorm1d(2 * hid_channels) + self.depth_conv = nn.Conv1d( + hid_channels, + hid_channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=hid_channels, + bias=bias, + ) + self.norm2 = nn.BatchNorm1d(hid_channels) + self.time_conv2 = nn.Conv1d( + hid_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.norm3 = nn.BatchNorm1d(out_channels) + + def forward(self, x): + x_res = x + x = self.time_conv(x) + x = self.norm1(x) + x = nn.functional.glu(x, dim=1) + x = self.depth_conv(x) + x = self.norm2(x) + x = x * torch.sigmoid(x) + x = self.time_conv2(x) + x = self.norm3(x) + x = x_res + x + return x + + +class TimeDepthSeparableConvBlock(nn.Module): + def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True): + super().__init__() + assert (kernel_size - 1) % 2 == 0 + assert num_layers > 1 + + self.layers = nn.ModuleList() + layer = TimeDepthSeparableConv( + in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias + ) + self.layers.append(layer) + for idx in range(num_layers - 1): + layer = TimeDepthSeparableConv( + hid_channels, + hid_channels, + out_channels if (idx + 1) == (num_layers - 1) else hid_channels, + kernel_size, + bias, + ) + self.layers.append(layer) + + def forward(self, x, mask): + for layer in self.layers: + x = layer(x * mask) + return x diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7ecee2bacb68cd330e18630531c97bc6f2e6a3 --- /dev/null +++ b/TTS/tts/layers/generic/transformer.py @@ -0,0 +1,89 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class FFTransformer(nn.Module): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1): + super().__init__() + self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p) + + padding = (kernel_size_fft - 1) // 2 + self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding) + self.conv2 = nn.Conv1d(hidden_channels_ffn, in_out_channels, kernel_size=kernel_size_fft, padding=padding) + + self.norm1 = nn.LayerNorm(in_out_channels) + self.norm2 = nn.LayerNorm(in_out_channels) + + self.dropout1 = nn.Dropout(dropout_p) + self.dropout2 = nn.Dropout(dropout_p) + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + """😦 ugly looking with all the transposing""" + src = src.permute(2, 0, 1) + src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src + src2) + # T x B x D -> B x D x T + src = src.permute(1, 2, 0) + src2 = self.conv2(F.relu(self.conv1(src))) + src2 = self.dropout2(src2) + src = src + src2 + src = src.transpose(1, 2) + src = self.norm2(src) + src = src.transpose(1, 2) + return src, enc_align + + +class FFTransformerBlock(nn.Module): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): + super().__init__() + self.fft_layers = nn.ModuleList( + [ + FFTransformer( + in_out_channels=in_out_channels, + num_heads=num_heads, + hidden_channels_ffn=hidden_channels_ffn, + dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + TODO: handle multi-speaker + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T] or [B, T]` + """ + if mask is not None and mask.ndim == 3: + mask = mask.squeeze(1) + # mask is negated, torch uses 1s and 0s reversely. + mask = ~mask.bool() + alignments = [] + for layer in self.fft_layers: + x, align = layer(x, src_key_padding_mask=mask) + alignments.append(align.unsqueeze(1)) + alignments = torch.cat(alignments, 1) + return x + + +class FFTDurationPredictor: + def __init__( + self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None + ): # pylint: disable=unused-argument + self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) + self.proj = nn.Linear(in_channels, 1) + + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - mask: :math:`[B, 1, T]` + + TODO: Handle the cond input + """ + x = self.fft(x, mask=mask) + x = self.proj(x) + return x diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..f8de63b49fea8a7140ddd0493446f0541abe6a0a --- /dev/null +++ b/TTS/tts/layers/generic/wavenet.py @@ -0,0 +1,176 @@ +import torch +from torch import nn +from torch.nn.utils import parametrize + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class WN(torch.nn.Module): + """Wavenet layers with weight norm and no input conditioning. + + |-----------------------------------------------------------------------------| + | |-> tanh -| | + res -|- conv1d(dilation) -> dropout -> + -| * -> conv1d1x1 -> split -|- + -> res + g -------------------------------------| |-> sigmoid -| | + o --------------------------------------------------------------------------- + --------- o + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_layers (int): number of wavenet layers. + c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): + super().__init__() + assert kernel_size % 2 == 1 + assert hidden_channels % 2 == 0 + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.dropout = nn.Dropout(dropout_p) + + # init conditioning layer + if c_in_channels > 0: + cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) + self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight") + # intermediate layers + for i in range(num_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + if i == 0: + in_layer = torch.nn.Conv1d( + in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + else: + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + if i < num_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + # setup weight norm + if not weight_norm: + self.remove_weight_norm() + + def forward(self, x, x_mask=None, g=None, **kwargs): # pylint: disable=unused-argument + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + x_mask = 1.0 if x_mask is None else x_mask + if g is not None: + g = self.cond_layer(g) + for i in range(self.num_layers): + x_in = self.in_layers[i](x) + x_in = self.dropout(x_in) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.num_layers - 1: + x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.c_in_channels != 0: + parametrize.remove_parametrizations(self.cond_layer, "weight") + for l in self.in_layers: + parametrize.remove_parametrizations(l, "weight") + for l in self.res_skip_layers: + parametrize.remove_parametrizations(l, "weight") + + +class WNBlocks(nn.Module): + """Wavenet blocks. + + Note: After each block dilation resets to 1 and it increases in each block + along the dilation rate. + + Args: + in_channels (int): number of input channels. + hidden_channes (int): number of hidden channels. + kernel_size (int): filter kernel size for the first conv layer. + dilation_rate (int): dilations rate to increase dilation per layer. + If it is 2, dilations are 1, 2, 4, 8 for the next 4 layers. + num_blocks (int): number of wavenet blocks. + num_layers (int): number of wavenet layers. + c_in_channels (int): number of channels of conditioning input. + dropout_p (float): dropout rate. + weight_norm (bool): enable/disable weight norm for convolution layers. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_blocks, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): + super().__init__() + self.wn_blocks = nn.ModuleList() + for idx in range(num_blocks): + layer = WN( + in_channels=in_channels if idx == 0 else hidden_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + weight_norm=weight_norm, + ) + self.wn_blocks.append(layer) + + def forward(self, x, x_mask=None, g=None): + o = x + for layer in self.wn_blocks: + o = layer(o, x_mask, g) + return o diff --git a/TTS/tts/layers/glow_tts/__init__.py b/TTS/tts/layers/glow_tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..61c5174ac5e67885288043885290c2906656c99c --- /dev/null +++ b/TTS/tts/layers/glow_tts/decoder.py @@ -0,0 +1,141 @@ +import torch +from torch import nn + +from TTS.tts.layers.generic.normalization import ActNorm +from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear + + +def squeeze(x, x_mask=None, num_sqz=2): + """GlowTTS squeeze operation + Increase number of channels and reduce number of time steps + by the same factor. + + Note: + each 's' is a n-dimensional vector. + ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]`` + """ + b, c, t = x.size() + + t = (t // num_sqz) * num_sqz + x = x[:, :, :t] + x_sqz = x.view(b, c, t // num_sqz, num_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz) + + if x_mask is not None: + x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz] + else: + x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype) + return x_sqz * x_mask, x_mask + + +def unsqueeze(x, x_mask=None, num_sqz=2): + """GlowTTS unsqueeze operation (revert the squeeze) + + Note: + each 's' is a n-dimensional vector. + ``[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5, s2, s4, s6]]`` + """ + b, c, t = x.size() + + x_unsqz = x.view(b, num_sqz, c // num_sqz, t) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz) + + if x_mask is not None: + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz) + else: + x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype) + return x_unsqz * x_mask, x_mask + + +class Decoder(nn.Module): + """Stack of Glow Decoder Modules. + + :: + + Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze + + Args: + in_channels (int): channels of input tensor. + hidden_channels (int): hidden decoder channels. + kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_flow_blocks (int): number of decoder blocks. + num_coupling_layers (int): number coupling layers. (number of wavenet layers.) + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): + super().__init__() + + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_flow_blocks = num_flow_blocks + self.num_coupling_layers = num_coupling_layers + self.dropout_p = dropout_p + self.num_splits = num_splits + self.num_squeeze = num_squeeze + self.sigmoid_scale = sigmoid_scale + self.c_in_channels = c_in_channels + + self.flows = nn.ModuleList() + for _ in range(num_flow_blocks): + self.flows.append(ActNorm(channels=in_channels * num_squeeze)) + self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits)) + self.flows.append( + CouplingBlock( + in_channels * num_squeeze, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_coupling_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + sigmoid_scale=sigmoid_scale, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1 ,T]` + - g: :math:`[B, C]` + """ + if not reverse: + flows = self.flows + logdet_tot = 0 + else: + flows = reversed(self.flows) + logdet_tot = None + + if self.num_squeeze > 1: + x, x_mask = squeeze(x, x_mask, self.num_squeeze) + for f in flows: + if not reverse: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + logdet_tot += logdet + else: + x, logdet = f(x, x_mask, g=g, reverse=reverse) + if self.num_squeeze > 1: + x, x_mask = unsqueeze(x, x_mask, self.num_squeeze) + return x, logdet_tot + + def store_inverse(self): + for f in self.flows: + f.store_inverse() diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..e766ed6ab5a0348eaca8d1482be124003d8b8c68 --- /dev/null +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -0,0 +1,69 @@ +import torch +from torch import nn + +from ..generic.normalization import LayerNorm + + +class DurationPredictor(nn.Module): + """Glow-TTS duration prediction model. + + :: + + [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs + + Args: + in_channels (int): Number of channels of the input tensor. + hidden_channels (int): Number of hidden channels of the network. + kernel_size (int): Kernel size for the conv layers. + dropout_p (float): Dropout rate used after each conv layer. + """ + + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p, cond_channels=None, language_emb_dim=None): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # class arguments + self.in_channels = in_channels + self.filter_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + # layers + self.drop = nn.Dropout(dropout_p) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = LayerNorm(hidden_channels) + self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = LayerNorm(hidden_channels) + # output layer + self.proj = nn.Conv1d(hidden_channels, 1, 1) + if cond_channels is not None and cond_channels != 0: + self.cond = nn.Conv1d(cond_channels, in_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, in_channels, 1) + + def forward(self, x, x_mask, g=None, lang_emb=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3b43e527f5e9ca2bd0880bf204e04a1526bc8dfb --- /dev/null +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -0,0 +1,179 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.generic.gated_conv import GatedConvBlock +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlock +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + + +class Encoder(nn.Module): + """Glow-TTS encoder module. + + :: + + embedding -> -> encoder_module -> --> proj_mean + | + |-> proj_var + | + |-> concat -> duration_predictor + ↑ + speaker_embed + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + hidden_channels (int): encoder's embedding size. + hidden_channels_ffn (int): transformer's feed-forward channels. + kernel_size (int): kernel size for conv layers and duration predictor. + dropout_p (float): dropout rate for any dropout layer. + mean_only (bool): if True, output only mean values and use constant std. + use_prenet (bool): if True, use pre-convolutional layers before transformer layers. + c_in_channels (int): number of channels in conditional input. + + Shapes: + - input: (B, T, C) + + :: + + suggested encoder params... + + for encoder_type == 'rel_pos_transformer' + encoder_params={ + 'kernel_size':3, + 'dropout_p': 0.1, + 'num_layers': 6, + 'num_heads': 2, + 'hidden_channels_ffn': 768, # 4 times the hidden_channels + 'input_length': None + } + + for encoder_type == 'gated_conv' + encoder_params={ + 'kernel_size':5, + 'dropout_p': 0.1, + 'num_layers': 9, + } + + for encoder_type == 'residual_conv_bn' + encoder_params={ + "kernel_size": 4, + "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + for encoder_type == 'time_depth_separable' + encoder_params={ + "kernel_size": 5, + 'num_layers': 9, + } + """ + + def __init__( + self, + num_chars, + out_channels, + hidden_channels, + hidden_channels_dp, + encoder_type, + encoder_params, + dropout_p_dp=0.1, + mean_only=False, + use_prenet=True, + c_in_channels=0, + ): + super().__init__() + # class arguments + self.num_chars = num_chars + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.hidden_channels_dp = hidden_channels_dp + self.dropout_p_dp = dropout_p_dp + self.mean_only = mean_only + self.use_prenet = use_prenet + self.c_in_channels = c_in_channels + self.encoder_type = encoder_type + # embedding layer + self.emb = nn.Embedding(num_chars, hidden_channels) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + # init encoder module + if encoder_type.lower() == "rel_pos_transformer": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = RelativePositionTransformer( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + elif encoder_type.lower() == "gated_conv": + self.encoder = GatedConvBlock(hidden_channels, **encoder_params) + elif encoder_type.lower() == "residual_conv_bn": + if use_prenet: + self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) + self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params) + self.postnet = nn.Sequential( + nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels) + ) + elif encoder_type.lower() == "time_depth_separable": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = TimeDepthSeparableConvBlock( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + else: + raise ValueError(" [!] Unkown encoder type.") + + # final projection layers + self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1) + if not mean_only: + self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) + # duration predictor + self.duration_predictor = DurationPredictor( + hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp + ) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B]` + - g (optional): :math:`[B, 1, T]` + """ + # embedding layer + # [B ,T, D] + x = self.emb(x) * math.sqrt(self.hidden_channels) + # [B, D, T] + x = torch.transpose(x, 1, -1) + # compute input sequence mask + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + # prenet + if hasattr(self, "prenet") and self.use_prenet: + x = self.prenet(x, x_mask) + # encoder + x = self.encoder(x, x_mask) + # postnet + if hasattr(self, "postnet"): + x = self.postnet(x) * x_mask + # set duration predictor input + if g is not None: + g_exp = g.expand(-1, -1, x.size(-1)) + x_dp = torch.cat([x.detach(), g_exp], 1) + else: + x_dp = x.detach() + # final projection layer + x_m = self.proj_m(x) * x_mask + if not self.mean_only: + x_logs = self.proj_s(x) * x_mask + else: + x_logs = torch.zeros_like(x_m) + # duration predictor + logw = self.duration_predictor(x_dp, x_mask) + return x_m, x_logs, logw, x_mask diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py new file mode 100644 index 0000000000000000000000000000000000000000..77a796473b6fe783b702ba4980417168e0febeb4 --- /dev/null +++ b/TTS/tts/layers/glow_tts/glow.py @@ -0,0 +1,229 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.wavenet import WN + +from ..generic.normalization import LayerNorm + + +class ResidualConv1dLayerNormBlock(nn.Module): + """Conv1d with Layer Normalization and residual connection as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + :: + + x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o + |---------------> conv1d_1x1 ------------------| + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of inner layer channels. + out_channels (int): number of output tensor channels. + kernel_size (int): kernel size of conv1d filter. + num_layers (int): number of blocks. + dropout_p (float): dropout rate for each block. + """ + + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.num_layers = num_layers + self.dropout_p = dropout_p + assert num_layers > 1, " [!] number of layers should be > 0." + assert kernel_size % 2 == 1, " [!] kernel size should be odd number." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + + for idx in range(num_layers): + self.conv_layers.append( + nn.Conv1d( + in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + x_res = x + for i in range(self.num_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x * x_mask) + x = F.dropout(F.relu(x), self.dropout_p, training=self.training) + x = x_res + self.proj(x) + return x * x_mask + + +class InvConvNear(nn.Module): + """Invertible Convolution with input splitting as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + Args: + channels (int): input and output channels. + num_splits (int): number of splits, also H and W of conv layer. + no_jacobian (bool): enable/disable jacobian computations. + + Note: + Split the input into groups of size self.num_splits and + perform 1x1 convolution separately. Cast 1x1 conv operation + to 2d by reshaping the input for efficiency. + """ + + def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument + super().__init__() + assert num_splits % 2 == 0 + self.channels = channels + self.num_splits = num_splits + self.no_jacobian = no_jacobian + self.weight_inv = None + + w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0] + + if torch.det(w_init) < 0: + w_init[:, 0] = -1 * w_init[:, 0] + self.weight = nn.Parameter(w_init) + + def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + b, c, t = x.size() + assert c % self.num_splits == 0 + if x_mask is None: + x_mask = 1 + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t + else: + x_len = torch.sum(x_mask, [1, 2]) + + x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t) + + if reverse: + if self.weight_inv is not None: + weight = self.weight_inv + else: + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + logdet = None + else: + weight = self.weight + if self.no_jacobian: + logdet = 0 + else: + logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len # [b] + + weight = weight.view(self.num_splits, self.num_splits, 1, 1) + z = F.conv2d(x, weight) + + z = z.view(b, 2, self.num_splits // 2, c // self.num_splits, t) + z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask + return z, logdet + + def store_inverse(self): + weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) + self.weight_inv = nn.Parameter(weight_inv, requires_grad=False) + + +class CouplingBlock(nn.Module): + """Glow Affine Coupling block as in GlowTTS paper. + https://arxiv.org/pdf/1811.00002.pdf + + :: + + x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o + '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ + + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of hidden channels. + kernel_size (int): WaveNet filter kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_layers (int): number of WaveNet layers. + c_in_channels (int): number of conditioning input channels. + dropout_p (int): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. + + Note: + It does not use the conditional inputs differently from WaveGlow. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + sigmoid_scale=False, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.c_in_channels = c_in_channels + self.dropout_p = dropout_p + self.sigmoid_scale = sigmoid_scale + # input layer + start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1) + start = torch.nn.utils.parametrizations.weight_norm(start) + self.start = start + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + end = torch.nn.Conv1d(hidden_channels, in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + # coupling layers + self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) + + def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if x_mask is None: + x_mask = 1 + x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] + + x = self.start(x_0) * x_mask + x = self.wn(x, x_mask, g) + out = self.end(x) + + z_0 = x_0 + t = out[:, : self.in_channels // 2, :] + s = out[:, self.in_channels // 2 :, :] + if self.sigmoid_scale: + s = torch.log(1e-6 + torch.sigmoid(s + 2)) + + if reverse: + z_1 = (x_1 - t) * torch.exp(-s) * x_mask + logdet = None + else: + z_1 = (t + torch.exp(s) * x_1) * x_mask + logdet = torch.sum(s * x_mask, [1, 2]) + + z = torch.cat([z_0, z_1], 1) + return z, logdet + + def store_inverse(self): + self.wn.remove_weight_norm() diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c97d070a959472bff416c7d4401a897b7d4eed57 --- /dev/null +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -0,0 +1,427 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2 +from TTS.tts.utils.helpers import convert_pad_shape + + +class RelativePositionMultiHeadAttention(nn.Module): + """Multi-head attention with Relative Positional embedding. + https://arxiv.org/pdf/1809.04281.pdf + + It learns positional embeddings for a window of neighbours. For keys and values, + it learns different set of embeddings. Key embeddings are agregated with the attention + scores and value embeddings are aggregated with the output. + + Note: + Example with relative attention window size 2 + + - input = [a, b, c, d, e] + - rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)] + + So it learns 4 embedding vectors (in total 8) separately for key and value vectors. + + Considering the input c + + - e(t-2) corresponds to c -> a + - e(t-2) corresponds to c -> b + - e(t-2) corresponds to c -> d + - e(t-2) corresponds to c -> e + + These embeddings are shared among different time steps. So input a, b, d and e also uses + the same embeddings. + + Embeddings are ignored when the relative window is out of limit for the first and the last + n items. + + Args: + channels (int): input and inner layer channels. + out_channels (int): output channels. + num_heads (int): number of attention heads. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + heads_share (bool, optional): [description]. Defaults to True. + dropout_p (float, optional): dropout rate. Defaults to 0.. + input_length (int, optional): intput length for positional encoding. Defaults to None. + proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False. + proximal_init (bool, optional): enable/disable poximal init as in the paper. + Init key and query layer weights the same. Defaults to False. + """ + + def __init__( + self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0.0, + input_length=None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." + # class attributes + self.channels = channels + self.out_channels = out_channels + self.num_heads = num_heads + self.rel_attn_window_size = rel_attn_window_size + self.heads_share = heads_share + self.input_length = input_length + self.proximal_bias = proximal_bias + self.dropout_p = dropout_p + self.attn = None + # query, key, value layers + self.k_channels = channels // num_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + # output layers + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.dropout = nn.Dropout(dropout_p) + # relative positional encoding layers + if rel_attn_window_size is not None: + n_heads_rel = 1 if heads_share else num_heads + rel_stddev = self.k_channels**-0.5 + emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.register_parameter("emb_rel_k", emb_rel_k) + self.register_parameter("emb_rel_v", emb_rel_v) + + # init layers + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + # proximal bias + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - c: :math:`[B, C, T]` + - attn_mask: :math:`[B, 1, T, T]` + """ + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + x, self.attn = self.attention(q, k, v, mask=attn_mask) + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) + # compute raw attention scores + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + # relative positional encoding for scores + if self.rel_attn_window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + # get relative key embeddings + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + # proximan bias + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype) + # attention score masking + if mask is not None: + # add small value to prevent oor error. + scores = scores.masked_fill(mask == 0, -1e4) + if self.input_length is not None: + block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length) + scores = scores * block_mask + -1e4 * (1 - block_mask) + # attention score normalization + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + # apply dropout to attention weights + p_attn = self.dropout(p_attn) + # compute output + output = torch.matmul(p_attn, value) + # relative positional encoding for values + if self.rel_attn_window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + @staticmethod + def _matmul_with_relative_values(p_attn, re): + """ + Args: + p_attn (Tensor): attention weights. + re (Tensor): relative value embedding vector. (a_(i,j)^V) + + Shapes: + -p_attn: :math:`[B, H, T, V]` + -re: :math:`[H or 1, V, D]` + -logits: :math:`[B, H, T, D]` + """ + logits = torch.matmul(p_attn, re.unsqueeze(0)) + return logits + + @staticmethod + def _matmul_with_relative_keys(query, re): + """ + Args: + query (Tensor): batch of query vectors. (x*W^Q) + re (Tensor): relative key embedding vector. (a_(i,j)^K) + + Shapes: + - query: :math:`[B, H, T, D]` + - re: :math:`[H or 1, V, D]` + - logits: :math:`[B, H, T, V]` + """ + # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)]) + logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1)) + return logits + + def _get_relative_embeddings(self, relative_embeddings, length): + """Convert embedding vestors to a tensor of embeddings""" + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.rel_attn_window_size + 1), 0) + slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] + return used_relative_embeddings + + @staticmethod + def _relative_position_to_absolute_position(x): + """Converts tensor from relative to absolute indexing for local attention. + Shapes: + x: :math:`[B, C, T, 2 * T - 1]` + Returns: + A Tensor of shape :math:`[B, C, T, T]` + """ + batch, heads, length, _ = x.size() + # Pad to shift from relative to absolute indexing. + x = F.pad(x, [0, 1, 0, 0, 0, 0, 0, 0]) + # Pad extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] + return x_final + + @staticmethod + def _absolute_position_to_relative_position(x): + """ + Shapes: + - x: :math:`[B, C, T, T]` + - ret: :math:`[B, C, T, 2*T-1]` + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + @staticmethod + def _attn_proximity_bias(length): + """Produce an attention mask that discourages distant + attention values. + Args: + length (int): an integer scalar. + Returns: + a Tensor with shape :math:`[1, 1, T, T]` + """ + # L + r = torch.arange(length, dtype=torch.float32) + # L x L + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + # scale mask values + diff = -torch.log1p(torch.abs(diff)) + # 1 x 1 x L x L + return diff.unsqueeze(0).unsqueeze(0) + + +class FeedForwardNetwork(nn.Module): + """Feed Forward Inner layers for Transformer. + + Args: + in_channels (int): input tensor channels. + out_channels (int): output tensor channels. + hidden_channels (int): inner layers hidden channels. + kernel_size (int): conv1d filter kernel size. + dropout_p (float, optional): dropout rate. Defaults to 0. + """ + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dropout_p = dropout_p + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + x = torch.relu(x) + x = self.dropout(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x + + +class RelativePositionTransformer(nn.Module): + """Transformer with Relative Potional Encoding. + https://arxiv.org/abs/1803.02155 + + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. + layer_norm_type (str, optional): type "1" uses torch tensor operations and type "2" uses torch layer_norm + primitive. Use type "2", type "1: is for backward compat. Defaults to "1". + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size: int = None, + input_length: int = None, + layer_norm_type: str = "1", + ): + super().__init__() + self.hidden_channels = hidden_channels + self.hidden_channels_ffn = hidden_channels_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.kernel_size = kernel_size + self.dropout_p = dropout_p + self.rel_attn_window_size = rel_attn_window_size + + self.dropout = nn.Dropout(dropout_p) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + + for idx in range(self.num_layers): + self.attn_layers.append( + RelativePositionMultiHeadAttention( + hidden_channels if idx != 0 else in_channels, + hidden_channels, + num_heads, + rel_attn_window_size=rel_attn_window_size, + dropout_p=dropout_p, + input_length=input_length, + ) + ) + if layer_norm_type == "1": + self.norm_layers_1.append(LayerNorm(hidden_channels)) + elif layer_norm_type == "2": + self.norm_layers_1.append(LayerNorm2(hidden_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + if hidden_channels != out_channels and (idx + 1) == self.num_layers: + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + + self.ffn_layers.append( + FeedForwardNetwork( + hidden_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, + hidden_channels_ffn, + kernel_size, + dropout_p=dropout_p, + ) + ) + + if layer_norm_type == "1": + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + elif layer_norm_type == "2": + self.norm_layers_2.append(LayerNorm2(hidden_channels if (idx + 1) != self.num_layers else out_channels)) + else: + raise ValueError(" [!] Unknown layer norm type") + + def forward(self, x, x_mask): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.num_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.dropout(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.dropout(y) + + if (i + 1) == self.num_layers and hasattr(self, "proj"): + x = self.proj(x) + + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..5ebed81ddabd73e60b3165a1d001d20630fdac55 --- /dev/null +++ b/TTS/tts/layers/losses.py @@ -0,0 +1,892 @@ +import logging +import math + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional + +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.ssim import SSIMLoss as _SSIMLoss +from TTS.utils.audio.torch_transforms import TorchSTFT + +logger = logging.getLogger(__name__) + + +# pylint: disable=abstract-method +# relates https://github.com/pytorch/pytorch/issues/42305 +class L1LossMasked(nn.Module): + def __init__(self, seq_len_norm): + super().__init__() + self.seq_len_norm = seq_len_norm + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len, dim) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len, dim) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + x: B x T X D + target: B x T x D + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.l1_loss(x * mask, target * mask, reduction="none") + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.l1_loss(x * mask, target * mask, reduction="sum") + loss = loss / mask.sum() + return loss + + +class MSELossMasked(nn.Module): + def __init__(self, seq_len_norm): + super().__init__() + self.seq_len_norm = seq_len_norm + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len, dim) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len, dim) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + - x: :math:`[B, T, D]` + - target: :math:`[B, T, D]` + - length: :math:`B` + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.mse_loss(x * mask, target * mask, reduction="none") + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.mse_loss(x * mask, target * mask, reduction="sum") + loss = loss / mask.sum() + return loss + + +def sample_wise_min_max(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Min-Max normalize tensor through first dimension + Shapes: + - x: :math:`[B, D1, D2]` + - m: :math:`[B, D1, 1]` + """ + maximum = torch.amax(x.masked_fill(~mask, 0), dim=(1, 2), keepdim=True) + minimum = torch.amin(x.masked_fill(~mask, np.inf), dim=(1, 2), keepdim=True) + return (x - minimum) / (maximum - minimum + 1e-8) + + +class SSIMLoss(torch.nn.Module): + """SSIM loss as (1 - SSIM) + SSIM is explained here https://en.wikipedia.org/wiki/Structural_similarity + """ + + def __init__(self): + super().__init__() + self.loss_func = _SSIMLoss() + + def forward(self, y_hat, y, length): + """ + Args: + y_hat (tensor): model prediction values. + y (tensor): target values. + length (tensor): length of each sample in a batch for masking. + + Shapes: + y_hat: B x T X D + y: B x T x D + length: B + + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2) + y_norm = sample_wise_min_max(y, mask) + y_hat_norm = sample_wise_min_max(y_hat, mask) + ssim_loss = self.loss_func((y_norm * mask).unsqueeze(1), (y_hat_norm * mask).unsqueeze(1)) + + if ssim_loss.item() > 1.0: + logger.info("SSIM loss is out-of-range (%.2f), setting it to 1.0", ssim_loss.item()) + ssim_loss = torch.tensor(1.0, device=ssim_loss.device) + + if ssim_loss.item() < 0.0: + logger.info("SSIM loss is out-of-range (%.2f), setting it to 0.0", ssim_loss.item()) + ssim_loss = torch.tensor(0.0, device=ssim_loss.device) + + return ssim_loss + + +class AttentionEntropyLoss(nn.Module): + # pylint: disable=R0201 + def forward(self, align): + """ + Forces attention to be more decisive by penalizing + soft attention weights + """ + entropy = torch.distributions.Categorical(probs=align).entropy() + loss = (entropy / np.log(align.shape[1])).mean() + return loss + + +class BCELossMasked(nn.Module): + """BCE loss with masking. + + Used mainly for stopnet in autoregressive models. + + Args: + pos_weight (float): weight for positive samples. If set < 1, penalize early stopping. Defaults to None. + """ + + def __init__(self, pos_weight: float = None): + super().__init__() + self.register_buffer("pos_weight", torch.tensor([pos_weight])) + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + target.requires_grad = False + if length is not None: + # mask: (batch, max_len, 1) + mask = sequence_mask(sequence_length=length, max_len=target.size(1)) + num_items = mask.sum() + loss = functional.binary_cross_entropy_with_logits( + x.masked_select(mask), + target.masked_select(mask), + pos_weight=self.pos_weight.to(x.device), + reduction="sum", + ) + else: + loss = functional.binary_cross_entropy_with_logits( + x, target, pos_weight=self.pos_weight.to(x.device), reduction="sum" + ) + num_items = torch.numel(x) + loss = loss / num_items + return loss + + +class DifferentialSpectralLoss(nn.Module): + """Differential Spectral Loss + https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" + + def __init__(self, loss_func): + super().__init__() + self.loss_func = loss_func + + def forward(self, x, target, length=None): + """ + Shapes: + x: B x T + target: B x T + length: B + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + x_diff = x[:, 1:] - x[:, :-1] + target_diff = target[:, 1:] - target[:, :-1] + if length is None: + return self.loss_func(x_diff, target_diff) + return self.loss_func(x_diff, target_diff, length - 1) + + +class GuidedAttentionLoss(torch.nn.Module): + def __init__(self, sigma=0.4): + super().__init__() + self.sigma = sigma + + def _make_ga_masks(self, ilens, olens): + B = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + ga_masks = torch.zeros((B, max_olen, max_ilen)) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma) + return ga_masks + + def forward(self, att_ws, ilens, olens): + ga_masks = self._make_ga_masks(ilens, olens).to(att_ws.device) + seq_masks = self._make_masks(ilens, olens).to(att_ws.device) + losses = ga_masks * att_ws + loss = torch.mean(losses.masked_select(seq_masks)) + return loss + + @staticmethod + def _make_ga_mask(ilen, olen, sigma): + grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij") + grid_x, grid_y = grid_x.float(), grid_y.float() + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))) + + @staticmethod + def _make_masks(ilens, olens): + in_masks = sequence_mask(ilens) + out_masks = sequence_mask(olens) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) + + +class Huber(nn.Module): + # pylint: disable=R0201 + def forward(self, x, y, length=None): + """ + Shapes: + x: B x T + y: B x T + length: B + """ + mask = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float() + return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum() + + +class ForwardSumLoss(nn.Module): + def __init__(self, blank_logprob=-1): + super().__init__() + self.log_softmax = torch.nn.LogSoftmax(dim=3) + self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True) + self.blank_logprob = blank_logprob + + def forward(self, attn_logprob, in_lens, out_lens): + key_lens = in_lens + query_lens = out_lens + attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob) + + total_loss = 0.0 + for bid in range(attn_logprob.shape[0]): + target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0) + curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1] + + curr_logprob = self.log_softmax(curr_logprob[None])[0] + loss = self.ctc_loss( + curr_logprob, + target_seq, + input_lengths=query_lens[bid : bid + 1], + target_lengths=key_lens[bid : bid + 1], + ) + total_loss = total_loss + loss + + total_loss = total_loss / attn_logprob.shape[0] + return total_loss + + +######################## +# MODEL LOSS LAYERS +######################## + + +class TacotronLoss(torch.nn.Module): + """Collection of Tacotron set-up based on provided config.""" + + def __init__(self, c, ga_sigma=0.4): + super().__init__() + self.stopnet_pos_weight = c.stopnet_pos_weight + self.use_capacitron_vae = c.use_capacitron_vae + if self.use_capacitron_vae: + self.capacitron_capacity = c.capacitron_vae.capacitron_capacity + self.capacitron_vae_loss_alpha = c.capacitron_vae.capacitron_VAE_loss_alpha + self.ga_alpha = c.ga_alpha + self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha + self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha + self.decoder_alpha = c.decoder_loss_alpha + self.postnet_alpha = c.postnet_loss_alpha + self.decoder_ssim_alpha = c.decoder_ssim_alpha + self.postnet_ssim_alpha = c.postnet_ssim_alpha + self.config = c + + # postnet and decoder loss + if c.loss_masking: + self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm) + else: + self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss() + # guided attention loss + if c.ga_alpha > 0: + self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) + # differential spectral loss + if c.postnet_diff_spec_alpha > 0 or c.decoder_diff_spec_alpha > 0: + self.criterion_diff_spec = DifferentialSpectralLoss(loss_func=self.criterion) + # ssim loss + if c.postnet_ssim_alpha > 0 or c.decoder_ssim_alpha > 0: + self.criterion_ssim = SSIMLoss() + # stopnet loss + # pylint: disable=not-callable + self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None + + # For dev pruposes only + self.criterion_capacitron_reconstruction_loss = nn.L1Loss(reduction="sum") + + def forward( + self, + postnet_output, + decoder_output, + mel_input, + linear_input, + stopnet_output, + stopnet_target, + stop_target_length, + capacitron_vae_outputs, + output_lens, + decoder_b_output, + alignments, + alignment_lens, + alignments_backwards, + input_lens, + ): + # decoder outputs linear or mel spectrograms for Tacotron and Tacotron2 + # the target should be set acccordingly + postnet_target = linear_input if self.config.model.lower() in ["tacotron"] else mel_input + + return_dict = {} + # remove lengths if no masking is applied + if not self.config.loss_masking: + output_lens = None + # decoder and postnet losses + if self.config.loss_masking: + if self.decoder_alpha > 0: + decoder_loss = self.criterion(decoder_output, mel_input, output_lens) + if self.postnet_alpha > 0: + postnet_loss = self.criterion(postnet_output, postnet_target, output_lens) + else: + if self.decoder_alpha > 0: + decoder_loss = self.criterion(decoder_output, mel_input) + if self.postnet_alpha > 0: + postnet_loss = self.criterion(postnet_output, postnet_target) + loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss + return_dict["decoder_loss"] = decoder_loss + return_dict["postnet_loss"] = postnet_loss + + if self.use_capacitron_vae: + # extract capacitron vae infos + posterior_distribution, prior_distribution, beta = capacitron_vae_outputs + + # KL divergence term between the posterior and the prior + kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution)) + + # Limit the mutual information between the data and latent space by the variational capacity limit + kl_capacity = kl_term - self.capacitron_capacity + + # pass beta through softplus to keep it positive + beta = torch.nn.functional.softplus(beta)[0] + + # This is the term going to the main ADAM optimiser, we detach beta because + # beta is optimised by a separate, SGD optimiser below + capacitron_vae_loss = beta.detach() * kl_capacity + + # normalize the capacitron_vae_loss as in L1Loss or MSELoss. + # After this, both the standard loss and capacitron_vae_loss will be in the same scale. + # For this reason we don't need use L1Loss and MSELoss in "sum" reduction mode. + # Note: the batch is not considered because the L1Loss was calculated in "sum" mode + # divided by the batch size, So not dividing the capacitron_vae_loss by B is legitimate. + + # get B T D dimension from input + B, T, D = mel_input.size() + # normalize + if self.config.loss_masking: + # if mask loss get T using the mask + T = output_lens.sum() / B + + # Only for dev purposes to be able to compare the reconstruction loss with the values in the + # original Capacitron paper + return_dict["capaciton_reconstruction_loss"] = ( + self.criterion_capacitron_reconstruction_loss(decoder_output, mel_input) / decoder_output.size(0) + ) + kl_capacity + + capacitron_vae_loss = capacitron_vae_loss / (T * D) + capacitron_vae_loss = capacitron_vae_loss * self.capacitron_vae_loss_alpha + + # This is the term to purely optimise beta and to pass into the SGD optimizer + beta_loss = torch.negative(beta) * kl_capacity.detach() + + loss += capacitron_vae_loss + + return_dict["capacitron_vae_loss"] = capacitron_vae_loss + return_dict["capacitron_vae_beta_loss"] = beta_loss + return_dict["capacitron_vae_kl_term"] = kl_term + return_dict["capacitron_beta"] = beta + + stop_loss = ( + self.criterion_st(stopnet_output, stopnet_target, stop_target_length) + if self.config.stopnet + else torch.zeros(1) + ) + loss += stop_loss + return_dict["stopnet_loss"] = stop_loss + + # backward decoder loss (if enabled) + if self.config.bidirectional_decoder: + if self.config.loss_masking: + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens) + else: + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output) + loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss) + return_dict["decoder_b_loss"] = decoder_b_loss + return_dict["decoder_c_loss"] = decoder_c_loss + + # double decoder consistency loss (if enabled) + if self.config.double_decoder_consistency: + if self.config.loss_masking: + decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) + else: + decoder_b_loss = self.criterion(decoder_b_output, mel_input) + # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) + attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) + loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss) + return_dict["decoder_coarse_loss"] = decoder_b_loss + return_dict["decoder_ddc_loss"] = attention_c_loss + + # guided attention loss (if enabled) + if self.config.ga_alpha > 0: + ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens) + loss += ga_loss * self.ga_alpha + return_dict["ga_loss"] = ga_loss + + # decoder differential spectral loss + if self.config.decoder_diff_spec_alpha > 0: + decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens) + loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha + return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss + + # postnet differential spectral loss + if self.config.postnet_diff_spec_alpha > 0: + postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens) + loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha + return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss + + # decoder ssim loss + if self.config.decoder_ssim_alpha > 0: + decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens) + loss += decoder_ssim_loss * self.postnet_ssim_alpha + return_dict["decoder_ssim_loss"] = decoder_ssim_loss + + # postnet ssim loss + if self.config.postnet_ssim_alpha > 0: + postnet_ssim_loss = self.criterion_ssim(postnet_output, postnet_target, output_lens) + loss += postnet_ssim_loss * self.postnet_ssim_alpha + return_dict["postnet_ssim_loss"] = postnet_ssim_loss + + return_dict["loss"] = loss + return return_dict + + +class GlowTTSLoss(torch.nn.Module): + def __init__(self): + super().__init__() + self.constant_factor = 0.5 * math.log(2 * math.pi) + + def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths): + return_dict = {} + # flow loss - neg log likelihood + pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[2]) + # duration loss - MSE + loss_dur = torch.sum((o_dur_log - o_attn_dur) ** 2) / torch.sum(x_lengths) + # duration loss - huber loss + # loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) + return_dict["loss"] = log_mle + loss_dur + return_dict["log_mle"] = log_mle + return_dict["loss_dur"] = loss_dur + + # check if any loss is NaN + for key, loss in return_dict.items(): + if torch.isnan(loss): + raise RuntimeError(f" [!] NaN loss with {key}.") + return return_dict + + +def mse_loss_custom(x, y): + """MSE loss using the torch back-end without reduction. + It uses less VRAM than the raw code""" + expanded_x, expanded_y = torch.broadcast_tensors(x, y) + return torch._C._nn.mse_loss(expanded_x, expanded_y, 0) # pylint: disable=protected-access, c-extension-no-member + + +class MDNLoss(nn.Module): + """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.""" + + def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use + """ + Shapes: + mu: [B, D, T] + log_sigma: [B, D, T] + mel_spec: [B, D, T] + """ + B, T_seq, T_mel = logp.shape + log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4) + log_alpha[:, 0, 0] = logp[:, 0, 0] + for t in range(1, T_mel): + prev_step = torch.cat( + [log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)], + dim=-1, + ) + log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t] + alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1] + mdn_loss = -alpha_last.mean() / T_seq + return mdn_loss # , log_prob_matrix + + +class AlignTTSLoss(nn.Module): + """Modified AlignTTS Loss. + Computes + - L1 and SSIM losses from output spectrograms. + - Huber loss for duration predictor. + - MDNLoss for Mixture of Density Network. + + All loss values are aggregated by a weighted sum of the alpha values. + + Args: + c (dict): TTS model configuration. + """ + + def __init__(self, c): + super().__init__() + self.mdn_loss = MDNLoss() + self.spec_loss = MSELossMasked(False) + self.ssim = SSIMLoss() + self.dur_loss = MSELossMasked(False) + + self.ssim_alpha = c.ssim_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.spec_loss_alpha = c.spec_loss_alpha + self.mdn_alpha = c.mdn_alpha + + def forward( + self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase + ): + # ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) + spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0 + if phase == 0: + mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) + elif phase == 1: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + elif phase == 2: + mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) + spec_loss = self.spec_lossX(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + elif phase == 3: + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) + else: + mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) + loss = ( + self.spec_loss_alpha * spec_loss + + self.ssim_alpha * ssim_loss + + self.dur_loss_alpha * dur_loss + + self.mdn_alpha * mdn_loss + ) + return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} + + +class VitsGeneratorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.kl_loss_alpha = c.kl_loss_alpha + self.gen_loss_alpha = c.gen_loss_alpha + self.feat_loss_alpha = c.feat_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.mel_loss_alpha = c.mel_loss_alpha + self.spk_encoder_loss_alpha = c.speaker_encoder_loss_alpha + self.stft = TorchSTFT( + c.audio.fft_size, + c.audio.hop_length, + c.audio.win_length, + sample_rate=c.audio.sample_rate, + mel_fmin=c.audio.mel_fmin, + mel_fmax=c.audio.mel_fmax, + n_mels=c.audio.num_mels, + use_mel=True, + do_amp_to_db=True, + ) + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + @staticmethod + def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l + + @staticmethod + def cosine_similarity_loss(gt_spk_emb, syn_spk_emb): + return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() + + def forward( + self, + mel_slice, + mel_slice_hat, + z_p, + logs_q, + m_p, + logs_p, + z_len, + scores_disc_fake, + feats_disc_fake, + feats_disc_real, + loss_duration, + use_speaker_encoder_as_loss=False, + gt_spk_emb=None, + syn_spk_emb=None, + ): + """ + Shapes: + - mel_slice : :math:`[B, 1, T]` + - mel_slice_hat: :math:`[B, 1, T]` + - z_p: :math:`[B, C, T]` + - logs_q: :math:`[B, C, T]` + - m_p: :math:`[B, C, T]` + - logs_p: :math:`[B, C, T]` + - z_len: :math:`[B]` + - scores_disc_fake[i]: :math:`[B, C]` + - feats_disc_fake[i][j]: :math:`[B, C, T', P]` + - feats_disc_real[i][j]: :math:`[B, C, T', P]` + """ + loss = 0.0 + return_dict = {} + z_mask = sequence_mask(z_len).float() + # compute losses + loss_kl = ( + self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1)) + * self.kl_loss_alpha + ) + loss_feat = ( + self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha + ) + loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha + loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha + loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha + loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration + + if use_speaker_encoder_as_loss: + loss_se = self.cosine_similarity_loss(gt_spk_emb, syn_spk_emb) * self.spk_encoder_loss_alpha + loss = loss + loss_se + return_dict["loss_spk_encoder"] = loss_se + # pass losses to the dict + return_dict["loss_gen"] = loss_gen + return_dict["loss_kl"] = loss_kl + return_dict["loss_feat"] = loss_feat + return_dict["loss_mel"] = loss_mel + return_dict["loss_duration"] = loss_duration + return_dict["loss"] = loss + return return_dict + + +class VitsDiscriminatorLoss(nn.Module): + def __init__(self, c: Coqpit): + super().__init__() + self.disc_loss_alpha = c.disc_loss_alpha + + @staticmethod + def discriminator_loss(scores_real, scores_fake): + loss = 0 + real_losses = [] + fake_losses = [] + for dr, dg in zip(scores_real, scores_fake): + dr = dr.float() + dg = dg.float() + real_loss = torch.mean((1 - dr) ** 2) + fake_loss = torch.mean(dg**2) + loss += real_loss + fake_loss + real_losses.append(real_loss.item()) + fake_losses.append(fake_loss.item()) + return loss, real_losses, fake_losses + + def forward(self, scores_disc_real, scores_disc_fake): + loss = 0.0 + return_dict = {} + loss_disc, loss_disc_real, _ = self.discriminator_loss( + scores_real=scores_disc_real, scores_fake=scores_disc_fake + ) + return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha + loss = loss + return_dict["loss_disc"] + return_dict["loss"] = loss + + for i, ldr in enumerate(loss_disc_real): + return_dict[f"loss_disc_real_{i}"] = ldr + return return_dict + + +class ForwardTTSLoss(nn.Module): + """Generic configurable ForwardTTS loss.""" + + def __init__(self, c): + super().__init__() + if c.spec_loss_type == "mse": + self.spec_loss = MSELossMasked(False) + elif c.spec_loss_type == "l1": + self.spec_loss = L1LossMasked(False) + else: + raise ValueError(" [!] Unknown spec_loss_type {}".format(c.spec_loss_type)) + + if c.duration_loss_type == "mse": + self.dur_loss = MSELossMasked(False) + elif c.duration_loss_type == "l1": + self.dur_loss = L1LossMasked(False) + elif c.duration_loss_type == "huber": + self.dur_loss = Huber() + else: + raise ValueError(" [!] Unknown duration_loss_type {}".format(c.duration_loss_type)) + + if c.model_args.use_aligner: + self.aligner_loss = ForwardSumLoss() + self.aligner_loss_alpha = c.aligner_loss_alpha + + if c.model_args.use_pitch: + self.pitch_loss = MSELossMasked(False) + self.pitch_loss_alpha = c.pitch_loss_alpha + + if c.model_args.use_energy: + self.energy_loss = MSELossMasked(False) + self.energy_loss_alpha = c.energy_loss_alpha + + if c.use_ssim_loss: + self.ssim = SSIMLoss() if c.use_ssim_loss else None + self.ssim_loss_alpha = c.ssim_loss_alpha + + self.spec_loss_alpha = c.spec_loss_alpha + self.dur_loss_alpha = c.dur_loss_alpha + self.binary_alignment_loss_alpha = c.binary_align_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + def forward( + self, + decoder_output, + decoder_target, + decoder_output_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + energy_output, + energy_target, + input_lens, + alignment_logprob=None, + alignment_hard=None, + alignment_soft=None, + binary_loss_weight=None, + ): + loss = 0 + return_dict = {} + if hasattr(self, "ssim_loss") and self.ssim_loss_alpha > 0: + ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.ssim_loss_alpha * ssim_loss + return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss + + if self.spec_loss_alpha > 0: + spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens) + loss = loss + self.spec_loss_alpha * spec_loss + return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss + + if self.dur_loss_alpha > 0: + log_dur_tgt = torch.log(dur_target.float() + 1) + dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens) + loss = loss + self.dur_loss_alpha * dur_loss + return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss + + if hasattr(self, "pitch_loss") and self.pitch_loss_alpha > 0: + pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens) + loss = loss + self.pitch_loss_alpha * pitch_loss + return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + + if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0: + energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens) + loss = loss + self.energy_loss_alpha * energy_loss + return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + + if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0: + aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens) + loss = loss + self.aligner_loss_alpha * aligner_loss + return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss + + if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss + if binary_loss_weight: + return_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + return_dict["loss"] = loss + return return_dict diff --git a/TTS/tts/layers/overflow/__init__.py b/TTS/tts/layers/overflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/overflow/common_layers.py b/TTS/tts/layers/overflow/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..9f77af293ca0da1e8969d21b1c56496f5f836756 --- /dev/null +++ b/TTS/tts/layers/overflow/common_layers.py @@ -0,0 +1,326 @@ +import logging +from typing import List, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from tqdm.auto import tqdm + +from TTS.tts.layers.tacotron.common_layers import Linear +from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock + +logger = logging.getLogger(__name__) + + +class Encoder(nn.Module): + r"""Neural HMM Encoder + + Same as Tacotron 2 encoder but increases the input length by states per phone + + Args: + num_chars (int): Number of characters in the input. + state_per_phone (int): Number of states per phone. + in_out_channels (int): number of input and output channels. + n_convolutions (int): number of convolutional layers. + """ + + def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3): + super().__init__() + + self.state_per_phone = state_per_phone + self.in_out_channels = in_out_channels + + self.emb = nn.Embedding(num_chars, in_out_channels) + self.convolutions = nn.ModuleList() + for _ in range(n_convolutions): + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, + int(in_out_channels / 2) * state_per_phone, + num_layers=1, + batch_first=True, + bias=True, + bidirectional=True, + ) + self.rnn_state = None + + def forward(self, x: torch.FloatTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """Forward pass to the encoder. + + Args: + x (torch.FloatTensor): input text indices. + - shape: :math:`(b, T_{in})` + x_len (torch.LongTensor): input text lengths. + - shape: :math:`(b,)` + + Returns: + Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths. + -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))` + """ + b, T = x.shape + o = self.emb(x).transpose(1, 2) + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True) + self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + o = o.reshape(b, T * self.state_per_phone, self.in_out_channels) + x_len = x_len * self.state_per_phone + return o, x_len + + def inference(self, x, x_len): + """Inference to the encoder. + + Args: + x (torch.FloatTensor): input text indices. + - shape: :math:`(b, T_{in})` + x_len (torch.LongTensor): input text lengths. + - shape: :math:`(b,)` + + Returns: + Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths. + -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))` + """ + b, T = x.shape + o = self.emb(x).transpose(1, 2) + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o = o.reshape(b, T * self.state_per_phone, self.in_out_channels) + x_len = x_len * self.state_per_phone + return o, x_len + + +class ParameterModel(nn.Module): + r"""Main neural network of the outputnet + + Note: Do not put dropout layers here, the model will not converge. + + Args: + outputnet_size (List[int]): the architecture of the parameter model + input_size (int): size of input for the first layer + output_size (int): size of output i.e size of the feature dim + frame_channels (int): feature dim to set the flat start bias + flat_start_params (dict): flat start parameters to set the bias + """ + + def __init__( + self, + outputnet_size: List[int], + input_size: int, + output_size: int, + frame_channels: int, + flat_start_params: dict, + ): + super().__init__() + self.frame_channels = frame_channels + + self.layers = nn.ModuleList( + [Linear(inp, out) for inp, out in zip([input_size] + outputnet_size[:-1], outputnet_size)] + ) + self.last_layer = nn.Linear(outputnet_size[-1], output_size) + self.flat_start_output_layer( + flat_start_params["mean"], flat_start_params["std"], flat_start_params["transition_p"] + ) + + def flat_start_output_layer(self, mean, std, transition_p): + self.last_layer.weight.data.zero_() + self.last_layer.bias.data[0 : self.frame_channels] = mean + self.last_layer.bias.data[self.frame_channels : 2 * self.frame_channels] = OverflowUtils.inverse_softplus(std) + self.last_layer.bias.data[2 * self.frame_channels :] = OverflowUtils.inverse_sigmod(transition_p) + + def forward(self, x): + for layer in self.layers: + x = F.relu(layer(x)) + x = self.last_layer(x) + return x + + +class Outputnet(nn.Module): + r""" + This network takes current state and previous observed values as input + and returns its parameters, mean, standard deviation and probability + of transition to the next state + """ + + def __init__( + self, + encoder_dim: int, + memory_rnn_dim: int, + frame_channels: int, + outputnet_size: List[int], + flat_start_params: dict, + std_floor: float = 1e-2, + ): + super().__init__() + + self.frame_channels = frame_channels + self.flat_start_params = flat_start_params + self.std_floor = std_floor + + input_size = memory_rnn_dim + encoder_dim + output_size = 2 * frame_channels + 1 + + self.parametermodel = ParameterModel( + outputnet_size=outputnet_size, + input_size=input_size, + output_size=output_size, + flat_start_params=flat_start_params, + frame_channels=frame_channels, + ) + + def forward(self, ar_mels, inputs): + r"""Inputs observation and returns the means, stds and transition probability for the current state + + Args: + ar_mel_inputs (torch.FloatTensor): shape (batch, prenet_dim) + states (torch.FloatTensor): (batch, hidden_states, hidden_state_dim) + + Returns: + means: means for the emission observation for each feature + - shape: (B, hidden_states, feature_size) + stds: standard deviations for the emission observation for each feature + - shape: (batch, hidden_states, feature_size) + transition_vectors: transition vector for the current hidden state + - shape: (batch, hidden_states) + """ + batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1] + N = inputs.shape[1] + + ar_mels = ar_mels.unsqueeze(1).expand(batch_size, N, prenet_dim) + ar_mels = torch.cat((ar_mels, inputs), dim=2) + ar_mels = self.parametermodel(ar_mels) + + mean, std, transition_vector = ( + ar_mels[:, :, 0 : self.frame_channels], + ar_mels[:, :, self.frame_channels : 2 * self.frame_channels], + ar_mels[:, :, 2 * self.frame_channels :].squeeze(2), + ) + std = F.softplus(std) + std = self._floor_std(std) + return mean, std, transition_vector + + def _floor_std(self, std): + r""" + It clamps the standard deviation to not to go below some level + This removes the problem when the model tries to cheat for higher likelihoods by converting + one of the gaussians to a point mass. + + Args: + std (float Tensor): tensor containing the standard deviation to be + """ + original_tensor = std.clone().detach() + std = torch.clamp(std, min=self.std_floor) + if torch.any(original_tensor != std): + logger.info( + "Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about" + ) + return std + + +class OverflowUtils: + @staticmethod + def get_data_parameters_for_flat_start( + data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int + ): + """Generates data parameters for flat starting the HMM. + + Args: + data_loader (torch.utils.data.Dataloader): _description_ + out_channels (int): mel spectrogram channels + states_per_phone (_type_): HMM states per phone + """ + + # State related information for transition_p + total_state_len = 0 + total_mel_len = 0 + + # Useful for data mean an std + total_mel_sum = 0 + total_mel_sq_sum = 0 + + for batch in tqdm(data_loader, leave=False): + text_lengths = batch["token_id_lengths"] + mels = batch["mel"] + mel_lengths = batch["mel_lengths"] + + total_state_len += torch.sum(text_lengths) + total_mel_len += torch.sum(mel_lengths) + total_mel_sum += torch.sum(mels) + total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) + + data_mean = total_mel_sum / (total_mel_len * out_channels) + data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) + average_num_states = total_state_len / len(data_loader.dataset) + average_mel_len = total_mel_len / len(data_loader.dataset) + average_duration_each_state = average_mel_len / average_num_states + init_transition_prob = 1 / average_duration_each_state + + return data_mean, data_std, (init_transition_prob * states_per_phone) + + @staticmethod + @torch.no_grad() + def update_flat_start_transition(model, transition_p): + model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p) + + @staticmethod + def log_clamped(x, eps=1e-04): + """ + Avoids the log(0) problem + + Args: + x (torch.tensor): input tensor + eps (float, optional): lower bound. Defaults to 1e-04. + + Returns: + torch.tensor: :math:`log(x)` + """ + clamped_x = torch.clamp(x, min=eps) + return torch.log(clamped_x) + + @staticmethod + def inverse_sigmod(x): + r""" + Inverse of the sigmoid function + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + return OverflowUtils.log_clamped(x / (1.0 - x)) + + @staticmethod + def inverse_softplus(x): + r""" + Inverse of the softplus function + """ + if not torch.is_tensor(x): + x = torch.tensor(x) + return OverflowUtils.log_clamped(torch.exp(x) - 1.0) + + @staticmethod + def logsumexp(x, dim): + r""" + Differentiable LogSumExp: Does not creates nan gradients + when all the inputs are -inf yeilds 0 gradients. + Args: + x : torch.Tensor - The input tensor + dim: int - The dimension on which the log sum exp has to be applied + """ + + m, _ = x.max(dim=dim) + mask = m == -float("inf") + s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim) + return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float("inf")) + + @staticmethod + def double_pad(list_of_different_shape_tensors): + r""" + Pads the list of tensors in 2 dimensions + """ + second_dim_lens = [len(a) for a in [i[0] for i in list_of_different_shape_tensors]] + second_dim_max = max(second_dim_lens) + padded_x = [F.pad(x, (0, second_dim_max - len(x[0]))) for x in list_of_different_shape_tensors] + return nn.utils.rnn.pad_sequence(padded_x, batch_first=True) diff --git a/TTS/tts/layers/overflow/decoder.py b/TTS/tts/layers/overflow/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd7ae88068cfaffe179f2e61354cc7eb760268c --- /dev/null +++ b/TTS/tts/layers/overflow/decoder.py @@ -0,0 +1,81 @@ +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.decoder import Decoder as GlowDecoder +from TTS.tts.utils.helpers import sequence_mask + + +class Decoder(nn.Module): + """Uses glow decoder with some modifications. + :: + + Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze + + Args: + in_channels (int): channels of input tensor. + hidden_channels (int): hidden decoder channels. + kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.) + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_flow_blocks (int): number of decoder blocks. + num_coupling_layers (int): number coupling layers. (number of wavenet layers.) + dropout_p (float): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. + """ + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): + super().__init__() + + self.glow_decoder = GlowDecoder( + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p, + num_splits, + num_squeeze, + sigmoid_scale, + c_in_channels, + ) + self.n_sqz = num_squeeze + + def forward(self, x, x_len, g=None, reverse=False): + """ + Input shapes: + - x: :math:`[B, C, T]` + - x_len :math:`[B]` + - g: :math:`[B, C]` + + Output shapes: + - x: :math:`[B, C, T]` + - x_len :math:`[B]` + - logget_tot :math:`[B]` + """ + x, x_len, x_max_len = self.preprocess(x, x_len, x_len.max()) + x_mask = torch.unsqueeze(sequence_mask(x_len, x_max_len), 1).to(x.dtype) + x, logdet_tot = self.glow_decoder(x, x_mask, g, reverse) + return x, x_len, logdet_tot + + def preprocess(self, y, y_lengths, y_max_length): + if y_max_length is not None: + y_max_length = torch.div(y_max_length, self.n_sqz, rounding_mode="floor") * self.n_sqz + y = y[:, :, :y_max_length] + y_lengths = torch.div(y_lengths, self.n_sqz, rounding_mode="floor") * self.n_sqz + return y, y_lengths, y_max_length + + def store_inverse(self): + self.glow_decoder.store_inverse() diff --git a/TTS/tts/layers/overflow/neural_hmm.py b/TTS/tts/layers/overflow/neural_hmm.py new file mode 100644 index 0000000000000000000000000000000000000000..a12becef03e6c80d6869055d335c8a39e0d12779 --- /dev/null +++ b/TTS/tts/layers/overflow/neural_hmm.py @@ -0,0 +1,554 @@ +from typing import List + +import torch +import torch.distributions as tdist +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from TTS.tts.layers.overflow.common_layers import Outputnet, OverflowUtils +from TTS.tts.layers.tacotron.common_layers import Prenet +from TTS.tts.utils.helpers import sequence_mask + + +class NeuralHMM(nn.Module): + """Autoregressive left to right HMM model primarily used in "Neural HMMs are all you need (for high-quality attention-free TTS)" + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality than statistical speech synthesis using + HMMs. However, neural TTS is generally not probabilistic and uses non-monotonic attention. Attention failures increase + training time and can make synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with an autoregressive left-right + no-skip hidden Markov model defined by a neural network. Based on this proposal, we modify Tacotron 2 to obtain an + HMM-based neural TTS model with monotonic alignment, trained to maximise the full sequence likelihood without + approximation. We also describe how to combine ideas from classical and contemporary TTS for best results. The resulting + example system is smaller and simpler than Tacotron 2, and learns to speak with fewer iterations and less data, whilst + achieving comparable naturalness prior to the post-net. Our approach also allows easy control over speaking rate. + + Args: + frame_channels (int): Output dimension to generate. + ar_order (int): Autoregressive order of the model. In ablations of Neural HMM it was found that more autoregression while giving more variation hurts naturalness of the synthesised audio. + deterministic_transition (bool): deterministic duration generation based on duration quantiles as defiend in "S. Ronanki, O. Watts, S. King, and G. E. Henter, “Medianbased generation of synthetic speech durations using a nonparametric approach,” in Proc. SLT, 2016.". Defaults to True. + encoder_dim (int): Channels of encoder input and character embedding tensors. Defaults to 512. + prenet_type (str): `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the Prenet. + prenet_dim (int): Dimension of the Prenet. + prenet_n_layers (int): Number of layers in the Prenet. + prenet_dropout (float): Dropout probability of the Prenet. + prenet_dropout_at_inference (bool): If True, dropout is applied at inference time. + memory_rnn_dim (int): Size of the memory RNN to process output of prenet. + outputnet_size (List[int]): Size of the output network inside the neural HMM. + flat_start_params (dict): Parameters for the flat start initialization of the neural HMM. + std_floor (float): Floor value for the standard deviation of the neural HMM. Prevents model cheating by putting point mass and getting infinite likelihood at any datapoint. + use_grad_checkpointing (bool, optional): Use gradient checkpointing to save memory. Defaults to True. + """ + + def __init__( + self, + frame_channels: int, + ar_order: int, + deterministic_transition: bool, + encoder_dim: int, + prenet_type: str, + prenet_dim: int, + prenet_n_layers: int, + prenet_dropout: float, + prenet_dropout_at_inference: bool, + memory_rnn_dim: int, + outputnet_size: List[int], + flat_start_params: dict, + std_floor: float, + use_grad_checkpointing: bool = True, + ): + super().__init__() + + self.frame_channels = frame_channels + self.ar_order = ar_order + self.deterministic_transition = deterministic_transition + self.prenet_dim = prenet_dim + self.memory_rnn_dim = memory_rnn_dim + self.use_grad_checkpointing = use_grad_checkpointing + + self.transition_model = TransitionModel() + self.emission_model = EmissionModel() + + assert ar_order > 0, f"AR order must be greater than 0 provided {ar_order}" + + self.ar_order = ar_order + self.prenet = Prenet( + in_features=frame_channels * ar_order, + prenet_type=prenet_type, + prenet_dropout=prenet_dropout, + dropout_at_inference=prenet_dropout_at_inference, + out_features=[self.prenet_dim for _ in range(prenet_n_layers)], + bias=False, + ) + self.memory_rnn = nn.LSTMCell(input_size=prenet_dim, hidden_size=memory_rnn_dim) + self.output_net = Outputnet( + encoder_dim, memory_rnn_dim, frame_channels, outputnet_size, flat_start_params, std_floor + ) + self.register_buffer("go_tokens", torch.zeros(ar_order, 1)) + + def forward(self, inputs, inputs_len, mels, mel_lens): + r"""HMM forward algorithm for training uses logarithmic version of Rabiner (1989) forward algorithm. + + Args: + inputs (torch.FloatTensor): Encoder outputs + inputs_len (torch.LongTensor): Encoder output lengths + mels (torch.FloatTensor): Mel inputs + mel_lens (torch.LongTensor): Length of mel inputs + + Shapes: + - inputs: (B, T, D_out_enc) + - inputs_len: (B) + - mels: (B, D_mel, T_mel) + - mel_lens: (B) + + Returns: + log_prob (torch.FloatTensor): Log probability of the sequence + """ + # Get dimensions of inputs + batch_size, N, _ = inputs.shape + T_max = torch.max(mel_lens) + mels = mels.permute(0, 2, 1) + + # Intialize forward algorithm + log_state_priors = self._initialize_log_state_priors(inputs) + log_c, log_alpha_scaled, transition_matrix, means = self._initialize_forward_algorithm_variables(mels, N) + + # Initialize autoregression elements + ar_inputs = self._add_go_token(mels) + h_memory, c_memory = self._init_lstm_states(batch_size, self.memory_rnn_dim, mels) + + for t in range(T_max): + # Process Autoregression + h_memory, c_memory = self._process_ar_timestep(t, ar_inputs, h_memory, c_memory) + # Get mean, std and transition vector from decoder for this timestep + # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop + if self.use_grad_checkpointing and self.training: + # TODO: use_reentrant=False is recommended + mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True) + else: + mean, std, transition_vector = self.output_net(h_memory, inputs) + + if t == 0: + log_alpha_temp = log_state_priors + self.emission_model(mels[:, 0], mean, std, inputs_len) + else: + log_alpha_temp = self.emission_model(mels[:, t], mean, std, inputs_len) + self.transition_model( + log_alpha_scaled[:, t - 1, :], transition_vector, inputs_len + ) + log_c[:, t] = torch.logsumexp(log_alpha_temp, dim=1) + log_alpha_scaled[:, t, :] = log_alpha_temp - log_c[:, t].unsqueeze(1) + transition_matrix[:, t] = transition_vector # needed for absorption state calculation + + # Save for plotting + means.append(mean.detach()) + + log_c, log_alpha_scaled = self._mask_lengths(mel_lens, log_c, log_alpha_scaled) + + sum_final_log_c = self.get_absorption_state_scaling_factor( + mel_lens, log_alpha_scaled, inputs_len, transition_matrix + ) + + log_probs = torch.sum(log_c, dim=1) + sum_final_log_c + + return log_probs, log_alpha_scaled, transition_matrix, means + + @staticmethod + def _mask_lengths(mel_lens, log_c, log_alpha_scaled): + """ + Mask the lengths of the forward variables so that the variable lenghts + do not contribute in the loss calculation + Args: + mel_inputs (torch.FloatTensor): (batch, T, frame_channels) + mel_inputs_lengths (torch.IntTensor): (batch) + log_c (torch.FloatTensor): (batch, T) + Returns: + log_c (torch.FloatTensor) : scaled probabilities (batch, T) + log_alpha_scaled (torch.FloatTensor): forward probabilities (batch, T, N) + """ + mask_log_c = sequence_mask(mel_lens) + log_c = log_c * mask_log_c + mask_log_alpha_scaled = mask_log_c.unsqueeze(2) + log_alpha_scaled = log_alpha_scaled * mask_log_alpha_scaled + return log_c, log_alpha_scaled + + def _process_ar_timestep( + self, + t, + ar_inputs, + h_memory, + c_memory, + ): + """ + Process autoregression in timestep + 1. At a specific t timestep + 2. Perform data dropout if applied (we did not use it) + 3. Run the autoregressive frame through the prenet (has dropout) + 4. Run the prenet output through the post prenet rnn + + Args: + t (int): mel-spec timestep + ar_inputs (torch.FloatTensor): go-token appended mel-spectrograms + - shape: (b, D_out, T_out) + h_post_prenet (torch.FloatTensor): previous timestep rnn hidden state + - shape: (b, memory_rnn_dim) + c_post_prenet (torch.FloatTensor): previous timestep rnn cell state + - shape: (b, memory_rnn_dim) + + Returns: + h_post_prenet (torch.FloatTensor): rnn hidden state of the current timestep + c_post_prenet (torch.FloatTensor): rnn cell state of the current timestep + """ + prenet_input = ar_inputs[:, t : t + self.ar_order].flatten(1) + memory_inputs = self.prenet(prenet_input) + h_memory, c_memory = self.memory_rnn(memory_inputs, (h_memory, c_memory)) + return h_memory, c_memory + + def _add_go_token(self, mel_inputs): + """Append the go token to create the autoregressive input + Args: + mel_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) + Returns: + ar_inputs (torch.FloatTensor): (batch_size, T, n_mel_channel) + """ + batch_size, T, _ = mel_inputs.shape + go_tokens = self.go_tokens.unsqueeze(0).expand(batch_size, self.ar_order, self.frame_channels) + ar_inputs = torch.cat((go_tokens, mel_inputs), dim=1)[:, :T] + return ar_inputs + + @staticmethod + def _initialize_forward_algorithm_variables(mel_inputs, N): + r"""Initialize placeholders for forward algorithm variables, to use a stable + version we will use log_alpha_scaled and the scaling constant + + Args: + mel_inputs (torch.FloatTensor): (b, T_max, frame_channels) + N (int): number of states + Returns: + log_c (torch.FloatTensor): Scaling constant (b, T_max) + """ + b, T_max, _ = mel_inputs.shape + log_alpha_scaled = mel_inputs.new_zeros((b, T_max, N)) + log_c = mel_inputs.new_zeros(b, T_max) + transition_matrix = mel_inputs.new_zeros((b, T_max, N)) + + # Saving for plotting later, will not have gradient tapes + means = [] + return log_c, log_alpha_scaled, transition_matrix, means + + @staticmethod + def _init_lstm_states(batch_size, hidden_state_dim, device_tensor): + r""" + Initialize Hidden and Cell states for LSTM Cell + + Args: + batch_size (Int): batch size + hidden_state_dim (Int): dimensions of the h and c + device_tensor (torch.FloatTensor): useful for the device and type + + Returns: + (torch.FloatTensor): shape (batch_size, hidden_state_dim) + can be hidden state for LSTM + (torch.FloatTensor): shape (batch_size, hidden_state_dim) + can be the cell state for LSTM + """ + return ( + device_tensor.new_zeros(batch_size, hidden_state_dim), + device_tensor.new_zeros(batch_size, hidden_state_dim), + ) + + def get_absorption_state_scaling_factor(self, mels_len, log_alpha_scaled, inputs_len, transition_vector): + """Returns the final scaling factor of absorption state + + Args: + mels_len (torch.IntTensor): Input size of mels to + get the last timestep of log_alpha_scaled + log_alpha_scaled (torch.FloatTEnsor): State probabilities + text_lengths (torch.IntTensor): length of the states to + mask the values of states lengths + ( + Useful when the batch has very different lengths, + when the length of an observation is less than + the number of max states, then the log alpha after + the state value is filled with -infs. So we mask + those values so that it only consider the states + which are needed for that length + ) + transition_vector (torch.FloatTensor): transtiion vector for each state per timestep + + Shapes: + - mels_len: (batch_size) + - log_alpha_scaled: (batch_size, N, T) + - text_lengths: (batch_size) + - transition_vector: (batch_size, N, T) + + Returns: + sum_final_log_c (torch.FloatTensor): (batch_size) + + """ + N = torch.max(inputs_len) + max_inputs_len = log_alpha_scaled.shape[2] + state_lengths_mask = sequence_mask(inputs_len, max_len=max_inputs_len) + + last_log_alpha_scaled_index = ( + (mels_len - 1).unsqueeze(-1).expand(-1, N).unsqueeze(1) + ) # Batch X Hidden State Size + last_log_alpha_scaled = torch.gather(log_alpha_scaled, 1, last_log_alpha_scaled_index).squeeze(1) + last_log_alpha_scaled = last_log_alpha_scaled.masked_fill(~state_lengths_mask, -float("inf")) + + last_transition_vector = torch.gather(transition_vector, 1, last_log_alpha_scaled_index).squeeze(1) + last_transition_probability = torch.sigmoid(last_transition_vector) + log_probability_of_transitioning = OverflowUtils.log_clamped(last_transition_probability) + + last_transition_probability_index = self.get_mask_for_last_item(inputs_len, inputs_len.device) + log_probability_of_transitioning = log_probability_of_transitioning.masked_fill( + ~last_transition_probability_index, -float("inf") + ) + final_log_c = last_log_alpha_scaled + log_probability_of_transitioning + + # If the length of the mel is less than the number of states it will select the -inf values leading to nan gradients + # Ideally, we should clean the dataset otherwise this is a little hack uncomment the line below + final_log_c = final_log_c.clamp(min=torch.finfo(final_log_c.dtype).min) + + sum_final_log_c = torch.logsumexp(final_log_c, dim=1) + return sum_final_log_c + + @staticmethod + def get_mask_for_last_item(lengths, device, out_tensor=None): + """Returns n-1 mask for the last item in the sequence. + + Args: + lengths (torch.IntTensor): lengths in a batch + device (str, optional): Defaults to "cpu". + out_tensor (torch.Tensor, optional): uses the memory of a specific tensor. + Defaults to None. + + Returns: + - Shape: :math:`(b, max_len)` + """ + max_len = torch.max(lengths).item() + ids = ( + torch.arange(0, max_len, device=device) if out_tensor is None else torch.arange(0, max_len, out=out_tensor) + ) + mask = ids == lengths.unsqueeze(1) - 1 + return mask + + @torch.inference_mode() + def inference( + self, + inputs: torch.FloatTensor, + input_lens: torch.LongTensor, + sampling_temp: float, + max_sampling_time: int, + duration_threshold: float, + ): + """Inference from autoregressive neural HMM + + Args: + inputs (torch.FloatTensor): input states + - shape: :math:`(b, T, d)` + input_lens (torch.LongTensor): input state lengths + - shape: :math:`(b)` + sampling_temp (float): sampling temperature + max_sampling_temp (int): max sampling temperature + duration_threshold (float): duration threshold to switch to next state + - Use this to change the spearking rate of the synthesised audio + """ + + b = inputs.shape[0] + outputs = { + "hmm_outputs": [], + "hmm_outputs_len": [], + "alignments": [], + "input_parameters": [], + "output_parameters": [], + } + for i in range(b): + neural_hmm_outputs, states_travelled, input_parameters, output_parameters = self.sample( + inputs[i : i + 1], input_lens[i], sampling_temp, max_sampling_time, duration_threshold + ) + + outputs["hmm_outputs"].append(neural_hmm_outputs) + outputs["hmm_outputs_len"].append(neural_hmm_outputs.shape[0]) + outputs["alignments"].append(states_travelled) + outputs["input_parameters"].append(input_parameters) + outputs["output_parameters"].append(output_parameters) + + outputs["hmm_outputs"] = nn.utils.rnn.pad_sequence(outputs["hmm_outputs"], batch_first=True) + outputs["hmm_outputs_len"] = torch.tensor( + outputs["hmm_outputs_len"], dtype=input_lens.dtype, device=input_lens.device + ) + return outputs + + @torch.inference_mode() + def sample(self, inputs, input_lens, sampling_temp, max_sampling_time, duration_threshold): + """Samples an output from the parameter models + + Args: + inputs (torch.FloatTensor): input states + - shape: :math:`(1, T, d)` + input_lens (torch.LongTensor): input state lengths + - shape: :math:`(1)` + sampling_temp (float): sampling temperature + max_sampling_time (int): max sampling time + duration_threshold (float): duration threshold to switch to next state + + Returns: + outputs (torch.FloatTensor): Output Observations + - Shape: :math:`(T, output_dim)` + states_travelled (list[int]): Hidden states travelled + - Shape: :math:`(T)` + input_parameters (list[torch.FloatTensor]): Input parameters + output_parameters (list[torch.FloatTensor]): Output parameters + """ + states_travelled, outputs, t = [], [], 0 + + # Sample initial state + current_state = 0 + states_travelled.append(current_state) + + # Prepare autoregression + prenet_input = self.go_tokens.unsqueeze(0).expand(1, self.ar_order, self.frame_channels) + h_memory, c_memory = self._init_lstm_states(1, self.memory_rnn_dim, prenet_input) + + input_parameter_values = [] + output_parameter_values = [] + quantile = 1 + while True: + memory_input = self.prenet(prenet_input.flatten(1).unsqueeze(0)) + # will be 1 while sampling + h_memory, c_memory = self.memory_rnn(memory_input.squeeze(0), (h_memory, c_memory)) + + z_t = inputs[:, current_state].unsqueeze(0) # Add fake time dimension + mean, std, transition_vector = self.output_net(h_memory, z_t) + + transition_probability = torch.sigmoid(transition_vector.flatten()) + staying_probability = torch.sigmoid(-transition_vector.flatten()) + + # Save for plotting + input_parameter_values.append([prenet_input, current_state]) + output_parameter_values.append([mean, std, transition_probability]) + + x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp) + + # Prepare autoregressive input for next iteration + prenet_input = torch.cat((prenet_input, x_t), dim=1)[:, 1:] + + outputs.append(x_t.flatten()) + + transition_matrix = torch.cat((staying_probability, transition_probability)) + quantile *= staying_probability + if not self.deterministic_transition: + switch = transition_matrix.multinomial(1)[0].item() + else: + switch = quantile < duration_threshold + + if switch: + current_state += 1 + quantile = 1 + + states_travelled.append(current_state) + + if (current_state == input_lens) or (max_sampling_time and t == max_sampling_time - 1): + break + + t += 1 + + return ( + torch.stack(outputs, dim=0), + F.one_hot(input_lens.new_tensor(states_travelled)), + input_parameter_values, + output_parameter_values, + ) + + @staticmethod + def _initialize_log_state_priors(text_embeddings): + """Creates the log pi in forward algorithm. + + Args: + text_embeddings (torch.FloatTensor): used to create the log pi + on current device + + Shapes: + - text_embeddings: (B, T, D_out_enc) + """ + N = text_embeddings.shape[1] + log_state_priors = text_embeddings.new_full([N], -float("inf")) + log_state_priors[0] = 0.0 + return log_state_priors + + +class TransitionModel(nn.Module): + """Transition Model of the HMM, it represents the probability of transitioning + form current state to all other states""" + + def forward(self, log_alpha_scaled, transition_vector, inputs_len): # pylint: disable=no-self-use + r""" + product of the past state with transitional probabilities in log space + + Args: + log_alpha_scaled (torch.Tensor): Multiply previous timestep's alphas by + transition matrix (in log domain) + - shape: (batch size, N) + transition_vector (torch.tensor): transition vector for each state + - shape: (N) + inputs_len (int tensor): Lengths of states in a batch + - shape: (batch) + + Returns: + out (torch.FloatTensor): log probability of transitioning to each state + """ + transition_p = torch.sigmoid(transition_vector) + staying_p = torch.sigmoid(-transition_vector) + + log_staying_probability = OverflowUtils.log_clamped(staying_p) + log_transition_probability = OverflowUtils.log_clamped(transition_p) + + staying = log_alpha_scaled + log_staying_probability + leaving = log_alpha_scaled + log_transition_probability + leaving = leaving.roll(1, dims=1) + leaving[:, 0] = -float("inf") + inputs_len_mask = sequence_mask(inputs_len) + out = OverflowUtils.logsumexp(torch.stack((staying, leaving), dim=2), dim=2) + out = out.masked_fill(~inputs_len_mask, -float("inf")) # There are no states to contribute to the loss + return out + + +class EmissionModel(nn.Module): + """Emission Model of the HMM, it represents the probability of + emitting an observation based on the current state""" + + def __init__(self) -> None: + super().__init__() + self.distribution_function: tdist.Distribution = tdist.normal.Normal + + def sample(self, means, stds, sampling_temp): + return self.distribution_function(means, stds * sampling_temp).sample() if sampling_temp > 0 else means + + def forward(self, x_t, means, stds, state_lengths): + r"""Calculates the log probability of the the given data (x_t) + being observed from states with given means and stds + Args: + x_t (float tensor) : observation at current time step + - shape: (batch, feature_dim) + means (float tensor): means of the distributions of hidden states + - shape: (batch, hidden_state, feature_dim) + stds (float tensor): standard deviations of the distributions of the hidden states + - shape: (batch, hidden_state, feature_dim) + state_lengths (int tensor): Lengths of states in a batch + - shape: (batch) + + Returns: + out (float tensor): observation log likelihoods, + expressing the probability of an observation + being generated from a state i + shape: (batch, hidden_state) + """ + emission_dists = self.distribution_function(means, stds) + out = emission_dists.log_prob(x_t.unsqueeze(1)) + state_lengths_mask = sequence_mask(state_lengths).unsqueeze(2) + out = torch.sum(out * state_lengths_mask, dim=2) + return out diff --git a/TTS/tts/layers/overflow/plotting_utils.py b/TTS/tts/layers/overflow/plotting_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d3e3d141dfa7025fb58fd2ebbecd3f17fcf46f --- /dev/null +++ b/TTS/tts/layers/overflow/plotting_utils.py @@ -0,0 +1,79 @@ +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def validate_numpy_array(value: Any): + r""" + Validates the input and makes sure it returns a numpy array (i.e on CPU) + + Args: + value (Any): the input value + + Raises: + TypeError: if the value is not a numpy array or torch tensor + + Returns: + np.ndarray: numpy array of the value + """ + if isinstance(value, np.ndarray): + pass + elif isinstance(value, list): + value = np.array(value) + elif torch.is_tensor(value): + value = value.cpu().numpy() + else: + raise TypeError("Value must be a numpy array, a torch tensor or a list") + + return value + + +def get_spec_from_most_probable_state(log_alpha_scaled, means, decoder=None): + """Get the most probable state means from the log_alpha_scaled. + + Args: + log_alpha_scaled (torch.Tensor): Log alpha scaled values. + - Shape: :math:`(T, N)` + means (torch.Tensor): Means of the states. + - Shape: :math:`(N, T, D_out)` + decoder (torch.nn.Module): Decoder module to decode the latent to melspectrogram. Defaults to None. + """ + max_state_numbers = torch.max(log_alpha_scaled, dim=1)[1] + max_len = means.shape[0] + n_mel_channels = means.shape[2] + max_state_numbers = max_state_numbers.unsqueeze(1).unsqueeze(1).expand(max_len, 1, n_mel_channels) + means = torch.gather(means, 1, max_state_numbers).squeeze(1).to(log_alpha_scaled.dtype) + if decoder is not None: + mel = ( + decoder(means.T.unsqueeze(0), torch.tensor([means.shape[0]], device=means.device), reverse=True)[0] + .squeeze(0) + .T + ) + else: + mel = means + return mel + + +def plot_transition_probabilities_to_numpy(states, transition_probabilities, output_fig=False): + """Generates trainsition probabilities plot for the states and the probability of transition. + + Args: + states (torch.IntTensor): the states + transition_probabilities (torch.FloatTensor): the transition probabilities + """ + states = validate_numpy_array(states) + transition_probabilities = validate_numpy_array(transition_probabilities) + + fig, ax = plt.subplots(figsize=(30, 3)) + ax.plot(transition_probabilities, "o") + ax.set_title("Transition probability of state") + ax.set_xlabel("hidden state") + ax.set_ylabel("probability") + ax.set_xticks(list(range(len(transition_probabilities)))) + ax.set_xticklabels([int(x) for x in states], rotation=90) + plt.tight_layout() + if not output_fig: + plt.close() + return fig diff --git a/TTS/tts/layers/tacotron/__init__.py b/TTS/tts/layers/tacotron/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/tacotron/attentions.py b/TTS/tts/layers/tacotron/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..25c3798e6b8f5fbc66224af66c9955e245b94097 --- /dev/null +++ b/TTS/tts/layers/tacotron/attentions.py @@ -0,0 +1,486 @@ +import torch +from scipy.stats import betabinom +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.tacotron.common_layers import Linear + + +class LocationLayer(nn.Module): + """Layers for Location Sensitive Attention + + Args: + attention_dim (int): number of channels in the input tensor. + attention_n_filters (int, optional): number of filters in convolution. Defaults to 32. + attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31. + """ + + def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31): + super().__init__() + self.location_conv1d = nn.Conv1d( + in_channels=2, + out_channels=attention_n_filters, + kernel_size=attention_kernel_size, + stride=1, + padding=(attention_kernel_size - 1) // 2, + bias=False, + ) + self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh") + + def forward(self, attention_cat): + """ + Shapes: + attention_cat: [B, 2, C] + """ + processed_attention = self.location_conv1d(attention_cat) + processed_attention = self.location_dense(processed_attention.transpose(1, 2)) + return processed_attention + + +class GravesAttention(nn.Module): + """Graves Attention as is ref1 with updates from ref2. + ref1: https://arxiv.org/abs/1910.10288 + ref2: https://arxiv.org/pdf/1906.01083.pdf + + Args: + query_dim (int): number of channels in query tensor. + K (int): number of Gaussian heads to be used for computing attention. + """ + + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) + + def __init__(self, query_dim, K): + super().__init__() + self._mask_value = 1e-8 + self.K = K + # self.attention_alignment = 0.05 + self.eps = 1e-5 + self.J = None + self.N_a = nn.Sequential( + nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True) + ) + self.attention_weights = None + self.mu_prev = None + self.init_layers() + + def init_layers(self): + torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0) # bias mean + torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10) # bias std + + def init_states(self, inputs): + if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5 + self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) + self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) + + # pylint: disable=R0201 + # pylint: disable=unused-argument + def preprocess_inputs(self, inputs): + return None + + def forward(self, query, inputs, processed_inputs, mask): + """ + Shapes: + query: [B, C_attention_rnn] + inputs: [B, T_in, C_encoder] + processed_inputs: place_holder + mask: [B, T_in] + """ + gbk_t = self.N_a(query) + gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) + + # attention model parameters + # each B x K + g_t = gbk_t[:, 0, :] + b_t = gbk_t[:, 1, :] + k_t = gbk_t[:, 2, :] + + # dropout to decorrelate attention heads + g_t = torch.nn.functional.dropout(g_t, p=0.5, training=self.training) + + # attention GMM parameters + sig_t = torch.nn.functional.softplus(b_t) + self.eps + + mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) + g_t = torch.softmax(g_t, dim=-1) + self.eps + + j = self.J[: inputs.size(1) + 1] + + # attention weights + phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) + + # discritize attention weights + alpha_t = torch.sum(phi_t, 1) + alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = 1e-8 + + # apply masking + if mask is not None: + alpha_t.data.masked_fill_(~mask, self._mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) + self.attention_weights = alpha_t + self.mu_prev = mu_t + return context + + +class OriginalAttention(nn.Module): + """Bahdanau Attention with various optional modifications. + - Location sensitive attnetion: https://arxiv.org/abs/1712.05884 + - Forward Attention: https://arxiv.org/abs/1807.06736 + state masking at inference + - Using sigmoid instead of softmax normalization + - Attention windowing at inference time + + Note: + Location Sensitive Attention extends the additive attention mechanism + to use cumulative attention weights from previous decoder time steps with the current time step features. + + Forward attention computes most probable monotonic alignment. The modified attention probabilities at each + timestep are computed recursively by the forward algorithm. + + Transition agent in the forward attention explicitly gates the attention mechanism whether to move forward or + stay at each decoder timestep. + + Attention windowing is a inductive prior that prevents the model from attending to previous and future timesteps + beyond a certain window. + + Args: + query_dim (int): number of channels in the query tensor. + embedding_dim (int): number of channels in the vakue tensor. In general, the value tensor is the output of the encoder layer. + attention_dim (int): number of channels of the inner attention layers. + location_attention (bool): enable/disable location sensitive attention. + attention_location_n_filters (int): number of location attention filters. + attention_location_kernel_size (int): filter size of location attention convolution layer. + windowing (int): window size for attention windowing. if it is 5, for computing the attention, it only considers the time steps [(t-5), ..., (t+5)] of the input. + norm (str): normalization method applied to the attention weights. 'softmax' or 'sigmoid' + forward_attn (bool): enable/disable forward attention. + trans_agent (bool): enable/disable transition agent in the forward attention. + forward_attn_mask (int): enable/disable an explicit masking in forward attention. It is useful to set at especially inference time. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ): + super().__init__() + self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh") + self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh") + self.v = Linear(attention_dim, 1, bias=True) + if trans_agent: + self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True) + if location_attention: + self.location_layer = LocationLayer( + attention_dim, + attention_location_n_filters, + attention_location_kernel_size, + ) + self._mask_value = -float("inf") + self.windowing = windowing + self.win_idx = None + self.norm = norm + self.forward_attn = forward_attn + self.trans_agent = trans_agent + self.forward_attn_mask = forward_attn_mask + self.location_attention = location_attention + + def init_win_idx(self): + self.win_idx = -1 + self.win_back = 2 + self.win_front = 6 + + def init_forward_attn(self, inputs): + B = inputs.shape[0] + T = inputs.shape[1] + self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) + self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) + + def init_location_attention(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights_cum = torch.zeros([B, T], device=inputs.device) + + def init_states(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights = torch.zeros([B, T], device=inputs.device) + if self.location_attention: + self.init_location_attention(inputs) + if self.forward_attn: + self.init_forward_attn(inputs) + if self.windowing: + self.init_win_idx() + + def preprocess_inputs(self, inputs): + return self.inputs_layer(inputs) + + def update_location_attention(self, alignments): + self.attention_weights_cum += alignments + + def get_location_attention(self, query, processed_inputs): + attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_cat) + energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def get_attention(self, query, processed_inputs): + processed_query = self.query_layer(query.unsqueeze(1)) + energies = self.v(torch.tanh(processed_query + processed_inputs)) + energies = energies.squeeze(-1) + return energies, processed_query + + def apply_windowing(self, attention, inputs): + back_win = self.win_idx - self.win_back + front_win = self.win_idx + self.win_front + if back_win > 0: + attention[:, :back_win] = -float("inf") + if front_win < inputs.shape[1]: + attention[:, front_win:] = -float("inf") + # this is a trick to solve a special problem. + # but it does not hurt. + if self.win_idx == -1: + attention[:, 0] = attention.max() + # Update the window + self.win_idx = torch.argmax(attention, 1).long()[0].item() + return attention + + def apply_forward_attention(self, alignment): + # forward attention + fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) + # compute transition potentials + alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment + # force incremental alignment + if not self.training and self.forward_attn_mask: + _, n = fwd_shifted_alpha.max(1) + val, _ = alpha.max(1) + for b in range(alignment.shape[0]): + alpha[b, n[b] + 3 :] = 0 + alpha[b, : (n[b] - 1)] = 0 # ignore all previous states to prevent repetition. + alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step + # renormalize attention weights + alpha = alpha / alpha.sum(dim=1, keepdim=True) + return alpha + + def forward(self, query, inputs, processed_inputs, mask): + """ + shapes: + query: [B, C_attn_rnn] + inputs: [B, T_en, D_en] + processed_inputs: [B, T_en, D_attn] + mask: [B, T_en] + """ + if self.location_attention: + attention, _ = self.get_location_attention(query, processed_inputs) + else: + attention, _ = self.get_attention(query, processed_inputs) + # apply masking + if mask is not None: + attention.data.masked_fill_(~mask, self._mask_value) + # apply windowing - only in eval mode + if not self.training and self.windowing: + attention = self.apply_windowing(attention, inputs) + + # normalize attention values + if self.norm == "softmax": + alignment = torch.softmax(attention, dim=-1) + elif self.norm == "sigmoid": + alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True) + else: + raise ValueError("Unknown value for attention norm type") + + if self.location_attention: + self.update_location_attention(alignment) + + # apply forward attention if enabled + if self.forward_attn: + alignment = self.apply_forward_attention(alignment) + self.alpha = alignment + + context = torch.bmm(alignment.unsqueeze(1), inputs) + context = context.squeeze(1) + self.attention_weights = alignment + + # compute transition agent + if self.forward_attn and self.trans_agent: + ta_input = torch.cat([context, query.squeeze(1)], dim=-1) + self.u = torch.sigmoid(self.ta(ta_input)) + return context + + +class MonotonicDynamicConvolutionAttention(nn.Module): + """Dynamic convolution attention from + https://arxiv.org/pdf/1910.10288.pdf + + + query -> linear -> tanh -> linear ->| + | mask values + v | | + atten_w(t-1) -|-> conv1d_dynamic -> linear -|-> tanh -> + -> softmax -> * -> * -> context + |-> conv1d_static -> linear -| | + |-> conv1d_prior -> log ----------------| + + query: attention rnn output. + + Note: + Dynamic convolution attention is an alternation of the location senstive attention with + dynamically computed convolution filters from the previous attention scores and a set of + constraints to keep the attention alignment diagonal. + DCA is sensitive to mixed precision training and might cause instable training. + + Args: + query_dim (int): number of channels in the query tensor. + embedding_dim (int): number of channels in the value tensor. + static_filter_dim (int): number of channels in the convolution layer computing the static filters. + static_kernel_size (int): kernel size for the convolution layer computing the static filters. + dynamic_filter_dim (int): number of channels in the convolution layer computing the dynamic filters. + dynamic_kernel_size (int): kernel size for the convolution layer computing the dynamic filters. + prior_filter_len (int, optional): [description]. Defaults to 11 from the paper. + alpha (float, optional): [description]. Defaults to 0.1 from the paper. + beta (float, optional): [description]. Defaults to 0.9 from the paper. + """ + + def __init__( + self, + query_dim, + embedding_dim, # pylint: disable=unused-argument + attention_dim, + static_filter_dim, + static_kernel_size, + dynamic_filter_dim, + dynamic_kernel_size, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ): + super().__init__() + self._mask_value = 1e-8 + self.dynamic_filter_dim = dynamic_filter_dim + self.dynamic_kernel_size = dynamic_kernel_size + self.prior_filter_len = prior_filter_len + self.attention_weights = None + # setup key and query layers + self.query_layer = nn.Linear(query_dim, attention_dim) + self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False) + self.static_filter_conv = nn.Conv1d( + 1, + static_filter_dim, + static_kernel_size, + padding=(static_kernel_size - 1) // 2, + bias=False, + ) + self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False) + self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim) + self.v = nn.Linear(attention_dim, 1, bias=False) + + prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta) + self.register_buffer("prior", torch.FloatTensor(prior).flip(0)) + + # pylint: disable=unused-argument + def forward(self, query, inputs, processed_inputs, mask): + """ + query: [B, C_attn_rnn] + inputs: [B, T_en, D_en] + processed_inputs: place holder. + mask: [B, T_en] + """ + # compute prior filters + prior_filter = F.conv1d( + F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1) + ) + prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1) + G = self.key_layer(torch.tanh(self.query_layer(query))) + # compute dynamic filters + dynamic_filter = F.conv1d( + self.attention_weights.unsqueeze(0), + G.view(-1, 1, self.dynamic_kernel_size), + padding=(self.dynamic_kernel_size - 1) // 2, + groups=query.size(0), + ) + dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2) + # compute static filters + static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2) + alignment = ( + self.v( + torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter)) + ).squeeze(-1) + + prior_filter + ) + # compute attention weights + attention_weights = F.softmax(alignment, dim=-1) + # apply masking + if mask is not None: + attention_weights.data.masked_fill_(~mask, self._mask_value) + self.attention_weights = attention_weights + # compute context + context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1) + return context + + def preprocess_inputs(self, inputs): # pylint: disable=no-self-use + return None + + def init_states(self, inputs): + B = inputs.size(0) + T = inputs.size(1) + self.attention_weights = torch.zeros([B, T], device=inputs.device) + self.attention_weights[:, 0] = 1.0 + + +def init_attn( + attn_type, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + attn_K, +): + if attn_type == "original": + return OriginalAttention( + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ) + if attn_type == "graves": + return GravesAttention(query_dim, attn_K) + if attn_type == "dynamic_convolution": + return MonotonicDynamicConvolutionAttention( + query_dim, + embedding_dim, + attention_dim, + static_filter_dim=8, + static_kernel_size=21, + dynamic_filter_dim=8, + dynamic_kernel_size=21, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ) + + raise RuntimeError(f" [!] Given Attention Type '{attn_type}' is not exist.") diff --git a/TTS/tts/layers/tacotron/capacitron_layers.py b/TTS/tts/layers/tacotron/capacitron_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..2181ffa7ec4e1f54d86cc5865a8fa7f6b6e362af --- /dev/null +++ b/TTS/tts/layers/tacotron/capacitron_layers.py @@ -0,0 +1,205 @@ +import torch +from torch import nn +from torch.distributions.multivariate_normal import MultivariateNormal as MVN +from torch.nn import functional as F + + +class CapacitronVAE(nn.Module): + """Effective Use of Variational Embedding Capacity for prosody transfer. + + See https://arxiv.org/abs/1906.03402""" + + def __init__( + self, + num_mel, + capacitron_VAE_embedding_dim, + encoder_output_dim=256, + reference_encoder_out_dim=128, + speaker_embedding_dim=None, + text_summary_embedding_dim=None, + ): + super().__init__() + # Init distributions + self.prior_distribution = MVN( + torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim) + ) + self.approximate_posterior_distribution = None + # define output ReferenceEncoder dim to the capacitron_VAE_embedding_dim + self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim) + + # Init beta, the lagrange-like term for the KL distribution + self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True) + mlp_input_dimension = reference_encoder_out_dim + + if text_summary_embedding_dim is not None: + self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim) + mlp_input_dimension += text_summary_embedding_dim + if speaker_embedding_dim is not None: + # TODO: Test a multispeaker model! + mlp_input_dimension += speaker_embedding_dim + self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim) + + def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None): + # Use reference + if reference_mel_info is not None: + reference_mels = reference_mel_info[0] # [batch_size, num_frames, num_mels] + mel_lengths = reference_mel_info[1] # [batch_size] + enc_out = self.encoder(reference_mels, mel_lengths) + + # concat speaker_embedding and/or text summary embedding + if text_info is not None: + text_inputs = text_info[0] # [batch_size, num_characters, num_embedding] + input_lengths = text_info[1] + text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device) + enc_out = torch.cat([enc_out, text_summary_out], dim=-1) + if speaker_embedding is not None: + speaker_embedding = torch.squeeze(speaker_embedding) + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + + # Feed the output of the ref encoder and information about text/speaker into + # an MLP to produce the parameteres for the approximate poterior distributions + mu, sigma = self.post_encoder_mlp(enc_out) + # convert to cpu because prior_distribution was created on cpu + mu = mu.cpu() + sigma = sigma.cpu() + + # Sample from the posterior: z ~ q(z|x) + self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma)) + VAE_embedding = self.approximate_posterior_distribution.rsample() + # Infer from the model, bypasses encoding + else: + # Sample from the prior: z ~ p(z) + VAE_embedding = self.prior_distribution.sample().unsqueeze(0) + + # reshape to [batch_size, 1, capacitron_VAE_embedding_dim] + return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. + + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, num_mel, out_dim): + super().__init__() + self.num_mel = num_mel + filters = [1] + [32, 32, 64, 64, 128, 128] + num_layers = len(filters) - 1 + convs = [ + nn.Conv2d( + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2) + ) + for i in range(num_layers) + ] + self.convs = nn.ModuleList(convs) + self.training = False + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) + + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers) + self.recurrence = nn.LSTM( + input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False + ) + + def forward(self, inputs, input_lengths): + batch_size = inputs.size(0) + x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel] + valid_lengths = input_lengths.float() # [batch_size] + for conv, bn in zip(self.convs, self.bns): + x = conv(x) + x = bn(x) + x = F.relu(x) + + # Create the post conv width mask based on the valid lengths of the output of the convolution. + # The valid lengths for the output of a convolution on varying length inputs is + # ceil(input_length/stride) + 1 for stride=3 and padding=2 + # For example (kernel_size=3, stride=2, padding=2): + # 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d + # _____ + # x _____ + # x _____ + # x ____ + # x + # x x x x -> Output valid length = 4 + # Since every example in te batch is zero padded and therefore have separate valid_lengths, + # we need to mask off all the values AFTER the valid length for each example in the batch. + # Otherwise, the convolutions create noise and a lot of not real information + valid_lengths = (valid_lengths / 2).float() + valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size] + post_conv_max_width = x.size(2) + + mask = torch.arange(post_conv_max_width).to(inputs.device).expand( + len(valid_lengths), post_conv_max_width + ) < valid_lengths.unsqueeze(1) + mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1] + x = x * mask + + x = x.transpose(1, 2) + # x: 4D tensor [batch_size, post_conv_width, + # num_channels==128, post_conv_height] + + post_conv_width = x.size(1) + x = x.contiguous().view(batch_size, post_conv_width, -1) + # x: 3D tensor [batch_size, post_conv_width, + # num_channels*post_conv_height] + + # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding + post_conv_input_lengths = valid_lengths + packed_seqs = nn.utils.rnn.pack_padded_sequence( + x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False + ) # dynamic rnn sequence padding + self.recurrence.flatten_parameters() + _, (ht, _) = self.recurrence(packed_seqs) + last_output = ht[-1] + + return last_output.to(inputs.device) # [B, 128] + + @staticmethod + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + +class TextSummary(nn.Module): + def __init__(self, embedding_dim, encoder_output_dim): + super().__init__() + self.lstm = nn.LSTM( + encoder_output_dim, # text embedding dimension from the text encoder + embedding_dim, # fixed length output summary the lstm creates from the input + batch_first=True, + bidirectional=False, + ) + + def forward(self, inputs, input_lengths): + # Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding + packed_seqs = nn.utils.rnn.pack_padded_sequence( + inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False + ) # dynamic rnn sequence padding + self.lstm.flatten_parameters() + _, (ht, _) = self.lstm(packed_seqs) + last_output = ht[-1] + return last_output + + +class PostEncoderMLP(nn.Module): + def __init__(self, input_size, hidden_size): + super().__init__() + self.hidden_size = hidden_size + modules = [ + nn.Linear(input_size, hidden_size), # Hidden Layer + nn.Tanh(), + nn.Linear(hidden_size, hidden_size * 2), + ] # Output layer twice the size for mean and variance + self.net = nn.Sequential(*modules) + self.softplus = nn.Softplus() + + def forward(self, _input): + mlp_output = self.net(_input) + # The mean parameter is unconstrained + mu = mlp_output[:, : self.hidden_size] + # The standard deviation must be positive. Parameterise with a softplus + sigma = self.softplus(mlp_output[:, self.hidden_size :]) + return mu, sigma diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f78ff1e75f6c23eb1a0fe827247a1127bc8f9958 --- /dev/null +++ b/TTS/tts/layers/tacotron/common_layers.py @@ -0,0 +1,119 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class Linear(nn.Module): + """Linear layer with a specific initialization. + + Args: + in_features (int): number of channels in the input tensor. + out_features (int): number of channels in the output tensor. + bias (bool, optional): enable/disable bias in the layer. Defaults to True. + init_gain (str, optional): method to compute the gain in the weight initializtion based on the nonlinear activation used afterwards. Defaults to 'linear'. + """ + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) + self._init_w(init_gain) + + def _init_w(self, init_gain): + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class LinearBN(nn.Module): + """Linear layer with Batch Normalization. + + x -> linear -> BN -> o + + Args: + in_features (int): number of channels in the input tensor. + out_features (int ): number of channels in the output tensor. + bias (bool, optional): enable/disable bias in the linear layer. Defaults to True. + init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'. + """ + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) + self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5) + self._init_w(init_gain) + + def _init_w(self, init_gain): + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) + + def forward(self, x): + """ + Shapes: + x: [T, B, C] or [B, C] + """ + out = self.linear_layer(x) + if len(out.shape) == 3: + out = out.permute(1, 2, 0) + out = self.batch_normalization(out) + if len(out.shape) == 3: + out = out.permute(2, 0, 1) + return out + + +class Prenet(nn.Module): + """Tacotron specific Prenet with an optional Batch Normalization. + + Note: + Prenet with BN improves the model performance significantly especially + if it is enabled after learning a diagonal attention alignment with the original + prenet. However, if the target dataset is high quality then it also works from + the start. It is also suggested to disable dropout if BN is in use. + + prenet_type == "original" + x -> [linear -> ReLU -> Dropout]xN -> o + + prenet_type == "bn" + x -> [linear -> BN -> ReLU -> Dropout]xN -> o + + Args: + in_features (int): number of channels in the input tensor and the inner layers. + prenet_type (str, optional): prenet type "original" or "bn". Defaults to "original". + prenet_dropout (bool, optional): dropout rate. Defaults to True. + dropout_at_inference (bool, optional): use dropout at inference. It leads to a better quality for some models. + out_features (list, optional): List of output channels for each prenet block. + It also defines number of the prenet blocks based on the length of argument list. + Defaults to [256, 256]. + bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True. + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + prenet_type="original", + prenet_dropout=True, + dropout_at_inference=False, + out_features=[256, 256], + bias=True, + ): + super().__init__() + self.prenet_type = prenet_type + self.prenet_dropout = prenet_dropout + self.dropout_at_inference = dropout_at_inference + in_features = [in_features] + out_features[:-1] + if prenet_type == "bn": + self.linear_layers = nn.ModuleList( + [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) + elif prenet_type == "original": + self.linear_layers = nn.ModuleList( + [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) + + def forward(self, x): + for linear in self.linear_layers: + if self.prenet_dropout: + x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training or self.dropout_at_inference) + else: + x = F.relu(linear(x)) + return x diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..05dba7084ff5533b68779d46238530f4988db934 --- /dev/null +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -0,0 +1,149 @@ +import torch +import torch.nn.functional as F +from torch import nn + + +class GST(nn.Module): + """Global Style Token Module for factorizing prosody in speech. + + See https://arxiv.org/pdf/1803.09017""" + + def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim=None): + super().__init__() + self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) + self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, embedded_speaker_dim) + + def forward(self, inputs, speaker_embedding=None): + enc_out = self.encoder(inputs) + # concat speaker_embedding + if speaker_embedding is not None: + enc_out = torch.cat([enc_out, speaker_embedding], dim=-1) + style_embed = self.style_token_layer(enc_out) + + return style_embed + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. + + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, num_mel, embedding_dim): + super().__init__() + self.num_mel = num_mel + filters = [1] + [32, 32, 64, 64, 128, 128] + num_layers = len(filters) - 1 + convs = [ + nn.Conv2d( + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ) + for i in range(num_layers) + ] + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) + + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) + self.recurrence = nn.GRU( + input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True + ) + + def forward(self, inputs): + batch_size = inputs.size(0) + x = inputs.view(batch_size, 1, -1, self.num_mel) + # x: 4D tensor [batch_size, num_channels==1, num_frames, num_mel] + for conv, bn in zip(self.convs, self.bns): + x = conv(x) + x = bn(x) + x = F.relu(x) + + x = x.transpose(1, 2) + # x: 4D tensor [batch_size, post_conv_width, + # num_channels==128, post_conv_height] + post_conv_width = x.size(1) + x = x.contiguous().view(batch_size, post_conv_width, -1) + # x: 3D tensor [batch_size, post_conv_width, + # num_channels*post_conv_height] + self.recurrence.flatten_parameters() + _, out = self.recurrence(x) + # out: 3D tensor [seq_len==1, batch_size, encoding_size=128] + + return out.squeeze(0) + + @staticmethod + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + +class StyleTokenLayer(nn.Module): + """NN Module attending to style tokens based on prosody encodings.""" + + def __init__(self, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None): + super().__init__() + + self.query_dim = gst_embedding_dim // 2 + + if d_vector_dim: + self.query_dim += d_vector_dim + + self.key_dim = gst_embedding_dim // num_heads + self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) + nn.init.normal_(self.style_tokens, mean=0, std=0.5) + self.attention = MultiHeadAttention( + query_dim=self.query_dim, key_dim=self.key_dim, num_units=gst_embedding_dim, num_heads=num_heads + ) + + def forward(self, inputs): + batch_size = inputs.size(0) + prosody_encoding = inputs.unsqueeze(1) + # prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128] + tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1) + # tokens: 3D tensor [batch_size, num tokens, token embedding size] + style_embed = self.attention(prosody_encoding, tokens) + + return style_embed + + +class MultiHeadAttention(nn.Module): + """ + input: + query --- [N, T_q, query_dim] + key --- [N, T_k, key_dim] + output: + out --- [N, T_q, num_units] + """ + + def __init__(self, query_dim, key_dim, num_units, num_heads): + super().__init__() + self.num_units = num_units + self.num_heads = num_heads + self.key_dim = key_dim + + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + + def forward(self, query, key): + queries = self.W_query(query) # [N, T_q, num_units] + keys = self.W_key(key) # [N, T_k, num_units] + values = self.W_value(key) + + split_size = self.num_units // self.num_heads + queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + + # score = softmax(QK^T / (d_k**0.5)) + scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] + scores = scores / (self.key_dim**0.5) + scores = F.softmax(scores, dim=3) + + # out = score * V + out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] + + return out diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..32643dfcee462e03fe3b316eadfa29bfb72f3bdb --- /dev/null +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -0,0 +1,507 @@ +# coding: utf-8 +# adapted from https://github.com/r9y9/tacotron_pytorch + +import logging + +import torch +from torch import nn + +from .attentions import init_attn +from .common_layers import Prenet + +logger = logging.getLogger(__name__) + + +class BatchNormConv1d(nn.Module): + r"""A wrapper for Conv1d with BatchNorm. It sets the activation + function between Conv and BatchNorm layers. BatchNorm layer + is initialized with the TF default values for momentum and eps. + + Args: + in_channels: size of each input sample + out_channels: size of each output samples + kernel_size: kernel size of conv filters + stride: stride of conv filters + padding: padding of conv filters + activation: activation function set b/w Conv1d and BatchNorm + + Shapes: + - input: (B, D) + - output: (B, D) + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None): + super().__init__() + self.padding = padding + self.padder = nn.ConstantPad1d(padding, 0) + self.conv1d = nn.Conv1d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False + ) + # Following tensorflow's default parameters + self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) + self.activation = activation + # self.init_layers() + + def init_layers(self): + if isinstance(self.activation, torch.nn.ReLU): + w_gain = "relu" + elif isinstance(self.activation, torch.nn.Tanh): + w_gain = "tanh" + elif self.activation is None: + w_gain = "linear" + else: + raise RuntimeError("Unknown activation function") + torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) + + def forward(self, x): + x = self.padder(x) + x = self.conv1d(x) + x = self.bn(x) + if self.activation is not None: + x = self.activation(x) + return x + + +class Highway(nn.Module): + r"""Highway layers as explained in https://arxiv.org/abs/1505.00387 + + Args: + in_features (int): size of each input sample + out_feature (int): size of each output sample + + Shapes: + - input: (B, *, H_in) + - output: (B, *, H_out) + """ + + # TODO: Try GLU layer + def __init__(self, in_features, out_feature): + super().__init__() + self.H = nn.Linear(in_features, out_feature) + self.H.bias.data.zero_() + self.T = nn.Linear(in_features, out_feature) + self.T.bias.data.fill_(-1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + # self.init_layers() + + def init_layers(self): + torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.T.weight, gain=torch.nn.init.calculate_gain("sigmoid")) + + def forward(self, inputs): + H = self.relu(self.H(inputs)) + T = self.sigmoid(self.T(inputs)) + return H * T + inputs * (1.0 - T) + + +class CBHG(nn.Module): + """CBHG module: a recurrent neural network composed of: + - 1-d convolution banks + - Highway networks + residual connections + - Bidirectional gated recurrent units + + Args: + in_features (int): sample size + K (int): max filter size in conv bank + projections (list): conv channel sizes for conv projections + num_highways (int): number of highways layers + + Shapes: + - input: (B, C, T_in) + - output: (B, T_in, C*2) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ): + super().__init__() + self.in_features = in_features + self.conv_bank_features = conv_bank_features + self.highway_features = highway_features + self.gru_features = gru_features + self.conv_projections = conv_projections + self.relu = nn.ReLU() + # list of conv1d bank with filter size k=1...K + # TODO: try dilational layers instead + self.conv1d_banks = nn.ModuleList( + [ + BatchNormConv1d( + in_features, + conv_bank_features, + kernel_size=k, + stride=1, + padding=[(k - 1) // 2, k // 2], + activation=self.relu, + ) + for k in range(1, K + 1) + ] + ) + # max pooling of conv bank, with padding + # TODO: try average pooling OR larger kernel size + out_features = [K * conv_bank_features] + conv_projections[:-1] + activations = [self.relu] * (len(conv_projections) - 1) + activations += [None] + # setup conv1d projection layers + layer_set = [] + for in_size, out_size, ac in zip(out_features, conv_projections, activations): + layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac) + layer_set.append(layer) + self.conv1d_projections = nn.ModuleList(layer_set) + # setup Highway layers + if self.highway_features != conv_projections[-1]: + self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False) + self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)]) + # bi-directional GPU layer + self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True) + + def forward(self, inputs): + # (B, in_features, T_in) + x = inputs + # (B, hid_features*K, T_in) + # Concat conv1d bank outputs + outs = [] + for conv1d in self.conv1d_banks: + out = conv1d(x) + outs.append(out) + x = torch.cat(outs, dim=1) + assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks) + for conv1d in self.conv1d_projections: + x = conv1d(x) + x += inputs + x = x.transpose(1, 2) + if self.highway_features != self.conv_projections[-1]: + x = self.pre_highway(x) + # Residual connection + # TODO: try residual scaling as in Deep Voice 3 + # TODO: try plain residual layers + for highway in self.highways: + x = highway(x) + # (B, T_in, hid_features*2) + # TODO: replace GRU with convolution as in Deep Voice 3 + self.gru.flatten_parameters() + outputs, _ = self.gru(x) + return outputs + + +class EncoderCBHG(nn.Module): + r"""CBHG module with Encoder specific arguments""" + + def __init__(self): + super().__init__() + self.cbhg = CBHG( + 128, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ) + + def forward(self, x): + return self.cbhg(x) + + +class Encoder(nn.Module): + r"""Stack Prenet and CBHG module for encoder + Args: + inputs (FloatTensor): embedding features + + Shapes: + - inputs: (B, T, D_in) + - outputs: (B, T, 128 * 2) + """ + + def __init__(self, in_features): + super().__init__() + self.prenet = Prenet(in_features, out_features=[256, 128]) + self.cbhg = EncoderCBHG() + + def forward(self, inputs): + # B x T x prenet_dim + outputs = self.prenet(inputs) + outputs = self.cbhg(outputs.transpose(1, 2)) + return outputs + + +class PostCBHG(nn.Module): + def __init__(self, mel_dim): + super().__init__() + self.cbhg = CBHG( + mel_dim, + K=8, + conv_bank_features=128, + conv_projections=[256, mel_dim], + highway_features=128, + gru_features=128, + num_highways=4, + ) + + def forward(self, x): + return self.cbhg(x) + + +class Decoder(nn.Module): + """Tacotron decoder. + + Args: + in_channels (int): number of input channels. + frame_channels (int): number of feature frame channels. + r (int): number of outputs per time step (reduction rate). + memory_size (int): size of the past window. if <= 0 memory_size = r + attn_type (string): type of attention used in decoder. + attn_windowing (bool): if true, define an attention window centered to maximum + attention response. It provides more robust attention alignment especially + at interence time. + attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'. + prenet_type (string): 'original' or 'bn'. + prenet_dropout (float): prenet dropout rate. + forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736 + trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736 + forward_attn_mask (bool): if true, mask attention values smaller than a threshold. + location_attn (bool): if true, use location sensitive attention. + attn_K (int): number of attention heads for GravesAttention. + separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + d_vector_dim (int): size of speaker embedding vector, for multi-speaker training. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + + def __init__( + self, + in_channels, + frame_channels, + r, + memory_size, + attn_type, + attn_windowing, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + max_decoder_steps, + ): + super().__init__() + self.r_init = r + self.r = r + self.in_channels = in_channels + self.max_decoder_steps = max_decoder_steps + self.use_memory_queue = memory_size > 0 + self.memory_size = memory_size if memory_size > 0 else r + self.frame_channels = frame_channels + self.separate_stopnet = separate_stopnet + self.query_dim = 256 + # memory -> |Prenet| -> processed_memory + prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels + self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128]) + # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State + # attention_rnn generates queries for the attention mechanism + self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim) + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input + self.project_to_decoder_in = nn.Linear(256 + in_channels, 256) + # decoder_RNN_input -> |RNN| -> RNN_state + self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)]) + # RNN_state -> |Linear| -> mel_spec + self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init) + # learn init values instead of zero init. + self.stopnet = StopNet(256 + frame_channels * self.r_init) + + def set_r(self, new_r): + self.r = new_r + + def _reshape_memory(self, memory): + """ + Reshape the spectrograms for given 'r' + """ + # Grouping multiple frames if necessary + if memory.size(-1) == self.frame_channels: + memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) + # Time first (T_decoder, B, frame_channels) + memory = memory.transpose(0, 1) + return memory + + def _init_states(self, inputs): + """ + Initialization of decoder states + """ + B = inputs.size(0) + # go frame as zeros matrix + if self.use_memory_queue: + self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.memory_size) + else: + self.memory_input = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels) + # decoder states + self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256) + self.decoder_rnn_hiddens = [ + torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns)) + ] + self.context_vec = inputs.data.new(B, self.in_channels).zero_() + # cache attention inputs + self.processed_inputs = self.attention.preprocess_inputs(inputs) + + def _parse_outputs(self, outputs, attentions, stop_tokens): + # Back to batch first + attentions = torch.stack(attentions).transpose(0, 1) + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) + outputs = outputs.transpose(1, 2) + return outputs, attentions, stop_tokens + + def decode(self, inputs, mask=None): + # Prenet + processed_memory = self.prenet(self.memory_input) + # Attention RNN + self.attention_rnn_hidden = self.attention_rnn( + torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden + ) + self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) + # Concat RNN output and attention context vector + decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) + + # Pass through the decoder RNNs + for idx, decoder_rnn in enumerate(self.decoder_rnns): + self.decoder_rnn_hiddens[idx] = decoder_rnn(decoder_input, self.decoder_rnn_hiddens[idx]) + # Residual connection + decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input + decoder_output = decoder_input + + # predict mel vectors from decoder vectors + output = self.proj_to_mel(decoder_output) + # output = torch.sigmoid(output) + # predict stop token + stopnet_input = torch.cat([decoder_output, output], -1) + if self.separate_stopnet: + stop_token = self.stopnet(stopnet_input.detach()) + else: + stop_token = self.stopnet(stopnet_input) + output = output[:, : self.r * self.frame_channels] + return output, stop_token, self.attention.attention_weights + + def _update_memory_input(self, new_memory): + if self.use_memory_queue: + if self.memory_size > self.r: + # memory queue size is larger than number of frames per decoder iter + self.memory_input = torch.cat( + [new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()], + dim=-1, + ) + else: + # memory queue size smaller than number of frames per decoder iter + self.memory_input = new_memory[:, : self.memory_size * self.frame_channels] + else: + # use only the last frame prediction + # assert new_memory.shape[-1] == self.r * self.frame_channels + self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :] + + def forward(self, inputs, memory, mask): + """ + Args: + inputs: Encoder outputs. + memory: Decoder memory (autoregression. If None (at eval-time), + decoder outputs are used as decoder inputs. If None, it uses the last + output as the input. + mask: Attention mask for sequence padding. + + Shapes: + - inputs: (B, T, D_out_enc) + - memory: (B, T_mel, D_mel) + """ + # Run greedy decoding if memory is None + memory = self._reshape_memory(memory) + outputs = [] + attentions = [] + stop_tokens = [] + t = 0 + self._init_states(inputs) + self.attention.init_states(inputs) + while len(outputs) < memory.size(0): + if t > 0: + new_memory = memory[t - 1] + self._update_memory_input(new_memory) + + output, stop_token, attention = self.decode(inputs, mask) + outputs += [output] + attentions += [attention] + stop_tokens += [stop_token.squeeze(1)] + t += 1 + return self._parse_outputs(outputs, attentions, stop_tokens) + + def inference(self, inputs): + """ + Args: + inputs: encoder outputs. + Shapes: + - inputs: batch x time x encoder_out_dim + """ + outputs = [] + attentions = [] + stop_tokens = [] + t = 0 + self._init_states(inputs) + self.attention.init_states(inputs) + while True: + if t > 0: + new_memory = outputs[-1] + self._update_memory_input(new_memory) + output, stop_token, attention = self.decode(inputs, None) + stop_token = torch.sigmoid(stop_token.data) + outputs += [output] + attentions += [attention] + stop_tokens += [stop_token] + t += 1 + if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): + break + if t > self.max_decoder_steps: + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) + break + return self._parse_outputs(outputs, attentions, stop_tokens) + + +class StopNet(nn.Module): + r"""Stopnet signalling decoder to stop inference. + Args: + in_features (int): feature dimension of input. + """ + + def __init__(self, in_features): + super().__init__() + self.dropout = nn.Dropout(0.1) + self.linear = nn.Linear(in_features, 1) + torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear")) + + def forward(self, inputs): + outputs = self.dropout(inputs) + outputs = self.linear(outputs) + return outputs diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..727bf9ecfd34642c6bb26d1a059168b9559be50b --- /dev/null +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -0,0 +1,418 @@ +import logging + +import torch +from torch import nn +from torch.nn import functional as F + +from .attentions import init_attn +from .common_layers import Linear, Prenet + +logger = logging.getLogger(__name__) + + +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg +class ConvBNBlock(nn.Module): + r"""Convolutions with Batch Normalization and non-linear activation. + + Args: + in_channels (int): number of input channels. + out_channels (int): number of output channels. + kernel_size (int): convolution kernel size. + activation (str): 'relu', 'tanh', None (linear). + + Shapes: + - input: (B, C_in, T) + - output: (B, C_out, T) + """ + + def __init__(self, in_channels, out_channels, kernel_size, activation=None): + super().__init__() + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) + self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) + self.dropout = nn.Dropout(p=0.5) + if activation == "relu": + self.activation = nn.ReLU() + elif activation == "tanh": + self.activation = nn.Tanh() + else: + self.activation = nn.Identity() + + def forward(self, x): + o = self.convolution1d(x) + o = self.batch_normalization(o) + o = self.activation(o) + o = self.dropout(o) + return o + + +class Postnet(nn.Module): + r"""Tacotron2 Postnet + + Args: + in_out_channels (int): number of output channels. + + Shapes: + - input: (B, C_in, T) + - output: (B, C_in, T) + """ + + def __init__(self, in_out_channels, num_convs=5): + super().__init__() + self.convolutions = nn.ModuleList() + self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh")) + for _ in range(1, num_convs - 1): + self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh")) + self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) + + def forward(self, x): + o = x + for layer in self.convolutions: + o = layer(o) + return o + + +class Encoder(nn.Module): + r"""Tacotron2 Encoder + + Args: + in_out_channels (int): number of input and output channels. + + Shapes: + - input: (B, C_in, T) + - output: (B, C_in, T) + """ + + def __init__(self, in_out_channels=512): + super().__init__() + self.convolutions = nn.ModuleList() + for _ in range(3): + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True + ) + self.rnn_state = None + + def forward(self, x, input_lengths): + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True) + self.lstm.flatten_parameters() + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + return o + + def inference(self, x): + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() + o, _ = self.lstm(o) + return o + + +# adapted from https://github.com/NVIDIA/tacotron2/ +class Decoder(nn.Module): + """Tacotron2 decoder. We don't use Zoneout but Dropout between RNN layers. + + Args: + in_channels (int): number of input channels. + frame_channels (int): number of feature frame channels. + r (int): number of outputs per time step (reduction rate). + memory_size (int): size of the past window. if <= 0 memory_size = r + attn_type (string): type of attention used in decoder. + attn_win (bool): if true, define an attention window centered to maximum + attention response. It provides more robust attention alignment especially + at interence time. + attn_norm (string): attention normalization function. 'sigmoid' or 'softmax'. + prenet_type (string): 'original' or 'bn'. + prenet_dropout (float): prenet dropout rate. + forward_attn (bool): if true, use forward attention method. https://arxiv.org/abs/1807.06736 + trans_agent (bool): if true, use transition agent. https://arxiv.org/abs/1807.06736 + forward_attn_mask (bool): if true, mask attention values smaller than a threshold. + location_attn (bool): if true, use location sensitive attention. + attn_K (int): number of attention heads for GravesAttention. + separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. + max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000. + """ + + # Pylint gets confused by PyTorch conventions here + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + in_channels, + frame_channels, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + max_decoder_steps, + ): + super().__init__() + self.frame_channels = frame_channels + self.r_init = r + self.r = r + self.encoder_embedding_dim = in_channels + self.separate_stopnet = separate_stopnet + self.max_decoder_steps = max_decoder_steps + self.stop_threshold = 0.5 + + # model dimensions + self.query_dim = 1024 + self.decoder_rnn_dim = 1024 + self.prenet_dim = 256 + self.attn_dim = 128 + self.p_attention_dropout = 0.1 + self.p_decoder_dropout = 0.1 + + # memory -> |Prenet| -> processed_memory + prenet_dim = self.frame_channels + self.prenet = Prenet( + prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False + ) + + self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True) + + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_win, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) + + self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True) + + self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init) + + self.stopnet = nn.Sequential( + nn.Dropout(0.1), + Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"), + ) + self.memory_truncated = None + + def set_r(self, new_r): + self.r = new_r + + def get_go_frame(self, inputs): + B = inputs.size(0) + memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r) + return memory + + def _init_states(self, inputs, mask, keep_states=False): + B = inputs.size(0) + # T = inputs.size(1) + if not keep_states: + self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim) + self.inputs = inputs + self.processed_inputs = self.attention.preprocess_inputs(inputs) + self.mask = mask + + def _reshape_memory(self, memory): + """ + Reshape the spectrograms for given 'r' + """ + # Grouping multiple frames if necessary + if memory.size(-1) == self.frame_channels: + memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) + # Time first (T_decoder, B, frame_channels) + memory = memory.transpose(0, 1) + return memory + + def _parse_outputs(self, outputs, stop_tokens, alignments): + alignments = torch.stack(alignments).transpose(0, 1) + stop_tokens = torch.stack(stop_tokens).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) + outputs = outputs.transpose(1, 2) + return outputs, stop_tokens, alignments + + def _update_memory(self, memory): + if len(memory.shape) == 2: + return memory[:, self.frame_channels * (self.r - 1) :] + return memory[:, :, self.frame_channels * (self.r - 1) :] + + def decode(self, memory): + """ + shapes: + - memory: B x r * self.frame_channels + """ + # self.context: B x D_en + # query_input: B x D_en + (r * self.frame_channels) + query_input = torch.cat((memory, self.context), -1) + # self.query and self.attention_rnn_cell_state : B x D_attn_rnn + self.query, self.attention_rnn_cell_state = self.attention_rnn( + query_input, (self.query, self.attention_rnn_cell_state) + ) + self.query = F.dropout(self.query, self.p_attention_dropout, self.training) + self.attention_rnn_cell_state = F.dropout( + self.attention_rnn_cell_state, self.p_attention_dropout, self.training + ) + # B x D_en + self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask) + # B x (D_en + D_attn_rnn) + decoder_rnn_input = torch.cat((self.query, self.context), -1) + # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_rnn_input, (self.decoder_hidden, self.decoder_cell) + ) + self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training) + # B x (D_decoder_rnn + D_en) + decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1) + # B x (self.r * self.frame_channels) + decoder_output = self.linear_projection(decoder_hidden_context) + # B x (D_decoder_rnn + (self.r * self.frame_channels)) + stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1) + if self.separate_stopnet: + stop_token = self.stopnet(stopnet_input.detach()) + else: + stop_token = self.stopnet(stopnet_input) + # select outputs for the reduction rate self.r + decoder_output = decoder_output[:, : self.r * self.frame_channels] + return decoder_output, self.attention.attention_weights, stop_token + + def forward(self, inputs, memories, mask): + r"""Train Decoder with teacher forcing. + Args: + inputs: Encoder outputs. + memories: Feature frames for teacher-forcing. + mask: Attention mask for sequence padding. + + Shapes: + - inputs: (B, T, D_out_enc) + - memory: (B, T_mel, D_mel) + - outputs: (B, T_mel, D_mel) + - alignments: (B, T_in, T_out) + - stop_tokens: (B, T_out) + """ + memory = self.get_go_frame(inputs).unsqueeze(0) + memories = self._reshape_memory(memories) + memories = torch.cat((memory, memories), dim=0) + memories = self._update_memory(memories) + memories = self.prenet(memories) + + self._init_states(inputs, mask=mask) + self.attention.init_states(inputs) + + outputs, stop_tokens, alignments = [], [], [] + while len(outputs) < memories.size(0) - 1: + memory = memories[len(outputs)] + decoder_output, attention_weights, stop_token = self.decode(memory) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token.squeeze(1)] + alignments += [attention_weights] + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + return outputs, alignments, stop_tokens + + def inference(self, inputs): + r"""Decoder inference without teacher forcing and use + Stopnet to stop decoder. + Args: + inputs: Encoder outputs. + + Shapes: + - inputs: (B, T, D_out_enc) + - outputs: (B, T_mel, D_mel) + - alignments: (B, T_in, T_out) + - stop_tokens: (B, T_out) + """ + memory = self.get_go_frame(inputs) + memory = self._update_memory(memory) + + self._init_states(inputs, mask=None) + self.attention.init_states(inputs) + + outputs, stop_tokens, alignments, t = [], [], [], 0 + while True: + memory = self.prenet(memory) + decoder_output, alignment, stop_token = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token] + alignments += [alignment] + + if stop_token > self.stop_threshold and t > inputs.shape[0] // 2: + break + if len(outputs) == self.max_decoder_steps: + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) + break + + memory = self._update_memory(decoder_output) + t += 1 + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + + return outputs, alignments, stop_tokens + + def inference_truncated(self, inputs): + """ + Preserve decoder states for continuous inference + """ + if self.memory_truncated is None: + self.memory_truncated = self.get_go_frame(inputs) + self._init_states(inputs, mask=None, keep_states=False) + else: + self._init_states(inputs, mask=None, keep_states=True) + + self.attention.init_states(inputs) + outputs, stop_tokens, alignments, t = [], [], [], 0 + while True: + memory = self.prenet(self.memory_truncated) + decoder_output, alignment, stop_token = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + outputs += [decoder_output.squeeze(1)] + stop_tokens += [stop_token] + alignments += [alignment] + + if stop_token > 0.7: + break + if len(outputs) == self.max_decoder_steps: + logger.info("Decoder stopped with `max_decoder_steps` %d", self.max_decoder_steps) + break + + self.memory_truncated = decoder_output + t += 1 + + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) + + return outputs, alignments, stop_tokens + + def inference_step(self, inputs, t, memory=None): + """ + For debug purposes + """ + if t == 0: + memory = self.get_go_frame(inputs) + self._init_states(inputs, mask=None) + + memory = self.prenet(memory) + decoder_output, stop_token, alignment = self.decode(memory) + stop_token = torch.sigmoid(stop_token.data) + memory = decoder_output + return decoder_output, stop_token, alignment diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52c2526695bb07ce60d4c058c204ba119bdcdb1c --- /dev/null +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -0,0 +1,433 @@ +import functools +import math + +import fsspec +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from transformers import LogitsWarper + +from TTS.tts.layers.tortoise.xtransformers import ContinuousTransformerWrapper, RelativePositionBias +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, rel_pos=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + if rel_pos is not None: + weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape( + bs * self.n_heads, weight.shape[-2], weight.shape[-1] + ) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + if mask is not None: + # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs. + mask = mask.repeat(self.n_heads, 1).unsqueeze(1) + weight = weight * mask + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + do_checkpoint=True, + relative_pos_embeddings=False, + ): + super().__init__() + self.channels = channels + self.do_checkpoint = do_checkpoint + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = nn.Conv1d(channels, channels * 3, 1) + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + if relative_pos_embeddings: + self.relative_pos_embeddings = RelativePositionBias( + scale=(channels // self.num_heads) ** 0.5, + causal=False, + heads=num_heads, + num_buckets=32, + max_distance=64, + ) + else: + self.relative_pos_embeddings = None + + def forward(self, x, mask=None): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv, mask, self.relative_pos_embeddings) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + """ + + def __init__(self, channels, use_conv, out_channels=None, factor=4): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.factor = factor + if use_conv: + ksize = 5 + pad = 2 + self.conv = nn.Conv1d(self.channels, self.out_channels, ksize, padding=pad) + + def forward(self, x): + assert x.shape[1] == self.channels + x = F.interpolate(x, scale_factor=self.factor, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + """ + + def __init__(self, channels, use_conv, out_channels=None, factor=4, ksize=5, pad=2): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + + stride = factor + if use_conv: + self.op = nn.Conv1d(self.channels, self.out_channels, ksize, stride=stride, padding=pad) + else: + assert self.channels == self.out_channels + self.op = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(nn.Module): + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + up=False, + down=False, + kernel_size=3, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + padding = 1 if kernel_size == 3 else 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False) + self.x_upd = Upsample(channels, False) + elif down: + self.h_upd = Downsample(channels, False) + self.x_upd = Downsample(channels, False) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = nn.Conv1d(channels, self.out_channels, 1) + + def forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AudioMiniEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + base_channels=128, + depth=2, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + dropout=0, + downsample_factor=2, + kernel_size=3, + ): + super().__init__() + self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) + ch = base_channels + res = [] + for l in range(depth): + for r in range(resnet_blocks): + res.append(ResBlock(ch, dropout, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) + ch *= 2 + self.res = nn.Sequential(*res) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) + attn = [] + for a in range(attn_blocks): + attn.append( + AttentionBlock( + embedding_dim, + num_attn_heads, + ) + ) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + h = self.init(x) + h = self.res(h) + h = self.final(h) + h = self.attn(h) + return h[:, :, 0] + + +DEFAULT_MEL_NORM_FILE = "https://coqui.gateway.scarf.sh/v0.14.1_models/mel_norms.pth" + + +class TorchMelSpectrogram(nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + mel_fmin=0, + mel_fmax=8000, + sampling_rate=22050, + normalize=False, + mel_norm_file=DEFAULT_MEL_NORM_FILE, + ): + super().__init__() + # These are the default tacotron values for the MEL spectrogram. + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.n_mel_channels = n_mel_channels + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.sampling_rate = sampling_rate + self.mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=self.filter_length, + hop_length=self.hop_length, + win_length=self.win_length, + power=2, + normalized=normalize, + sample_rate=self.sampling_rate, + f_min=self.mel_fmin, + f_max=self.mel_fmax, + n_mels=self.n_mel_channels, + norm="slaney", + ) + self.mel_norm_file = mel_norm_file + if self.mel_norm_file is not None: + with fsspec.open(self.mel_norm_file) as f: + self.mel_norms = torch.load(f, weights_only=is_pytorch_at_least_2_4()) + else: + self.mel_norms = None + + def forward(self, inp): + if ( + len(inp.shape) == 3 + ): # Automatically squeeze out the channels dimension if it is present (assuming mono-audio) + inp = inp.squeeze(1) + assert len(inp.shape) == 2 + self.mel_stft = self.mel_stft.to(inp.device) + mel = self.mel_stft(inp) + # Perform dynamic range compression + mel = torch.log(torch.clamp(mel, min=1e-5)) + if self.mel_norms is not None: + self.mel_norms = self.mel_norms.to(mel.device) + mel = mel / self.mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class CheckpointedLayer(nn.Module): + """ + Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses + checkpoint for all other args. + """ + + def __init__(self, wrap): + super().__init__() + self.wrap = wrap + + def forward(self, x, *args, **kwargs): + for k, v in kwargs.items(): + assert not (isinstance(v, torch.Tensor) and v.requires_grad) # This would screw up checkpointing. + partial = functools.partial(self.wrap, **kwargs) + return partial(x, *args) + + +class CheckpointedXTransformerEncoder(nn.Module): + """ + Wraps a ContinuousTransformerWrapper and applies CheckpointedLayer to each layer and permutes from channels-mid + to channels-last that XTransformer expects. + """ + + def __init__(self, needs_permute=True, exit_permute=True, checkpoint=True, **xtransformer_kwargs): + super().__init__() + self.transformer = ContinuousTransformerWrapper(**xtransformer_kwargs) + self.needs_permute = needs_permute + self.exit_permute = exit_permute + + if not checkpoint: + return + for i in range(len(self.transformer.attn_layers.layers)): + n, b, r = self.transformer.attn_layers.layers[i] + self.transformer.attn_layers.layers[i] = nn.ModuleList([n, CheckpointedLayer(b), r]) + + def forward(self, x, **kwargs): + if self.needs_permute: + x = x.permute(0, 2, 1) + h = self.transformer(x, **kwargs) + if self.exit_permute: + h = h.permute(0, 2, 1) + return h + + +class TypicalLogitsWarper(LogitsWarper): + def __init__( + self, + mass: float = 0.9, + filter_value: float = -float("Inf"), + min_tokens_to_keep: int = 1, + ): + self.filter_value = filter_value + self.mass = mass + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if self.min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f299a8fd9f03a054ed3863a8c2b5bbe027ac368 --- /dev/null +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -0,0 +1,181 @@ +import logging +import os +from glob import glob +from typing import Dict, List + +import librosa +import numpy as np +import torch +import torchaudio +from scipy.io.wavfile import read + +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + if data.dtype == np.int32: + norm_fix = 2**31 + elif data.dtype == np.int16: + norm_fix = 2**15 + elif data.dtype == np.float16 or data.dtype == np.float32: + norm_fix = 1.0 + else: + raise NotImplementedError(f"Provided data dtype not supported: {data.dtype}") + return (torch.FloatTensor(data.astype(np.float32)) / norm_fix, sampling_rate) + + +def check_audio(audio, audiopath: str): + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '2' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 2) or not torch.any(audio < 0): + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) + audio.clip_(-1, 1) + + +def read_audio_file(audiopath: str): + if audiopath[-4:] == ".wav": + audio, lsr = load_wav_to_torch(audiopath) + elif audiopath[-4:] == ".mp3": + audio, lsr = librosa.load(audiopath, sr=None) + audio = torch.FloatTensor(audio) + else: + assert False, f"Unsupported audio format provided: {audiopath[-4:]}" + + # Remove any channel data. + if len(audio.shape) > 1: + if audio.shape[0] < 5: + audio = audio[0] + else: + assert audio.shape[1] < 5 + audio = audio[:, 0] + + return audio, lsr + + +def load_required_audio(audiopath: str): + audio, lsr = read_audio_file(audiopath) + + audios = [torchaudio.functional.resample(audio, lsr, sampling_rate) for sampling_rate in (22050, 24000)] + for audio in audios: + check_audio(audio, audiopath) + + return [audio.unsqueeze(0) for audio in audios] + + +def load_audio(audiopath, sampling_rate): + audio, lsr = read_audio_file(audiopath) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + check_audio(audio, audiopath) + + return audio.unsqueeze(0) + + +TACOTRON_MEL_MAX = 2.3143386840820312 +TACOTRON_MEL_MIN = -11.512925148010254 + + +def denormalize_tacotron_mel(norm_mel): + return ((norm_mel + 1) / 2) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN + + +def normalize_tacotron_mel(mel): + return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def get_voices(extra_voice_dirs: List[str] = []): + dirs = extra_voice_dirs + voices: Dict[str, List[str]] = {} + for d in dirs: + subs = os.listdir(d) + for sub in subs: + subj = os.path.join(d, sub) + if os.path.isdir(subj): + voices[sub] = list(glob(f"{subj}/*.wav")) + list(glob(f"{subj}/*.mp3")) + list(glob(f"{subj}/*.pth")) + return voices + + +def load_voice(voice: str, extra_voice_dirs: List[str] = []): + if voice == "random": + return None, None + + voices = get_voices(extra_voice_dirs) + paths = voices[voice] + if len(paths) == 1 and paths[0].endswith(".pth"): + return None, torch.load(paths[0], weights_only=is_pytorch_at_least_2_4()) + else: + conds = [] + for cond_path in paths: + c = load_required_audio(cond_path) + conds.append(c) + return conds, None + + +def load_voices(voices: List[str], extra_voice_dirs: List[str] = []): + latents = [] + clips = [] + for voice in voices: + if voice == "random": + if len(voices) > 1: + logger.warning("Cannot combine a random voice with a non-random voice. Just using a random voice.") + return None, None + clip, latent = load_voice(voice, extra_voice_dirs) + if latent is None: + assert ( + len(latents) == 0 + ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." + clips.extend(clip) + elif clip is None: + assert ( + len(clips) == 0 + ), "Can only combine raw audio voices or latent voices, not both. Do it yourself if you want this." + latents.append(latent) + if len(latents) == 0: + return clips, None + else: + latents_0 = torch.stack([l[0] for l in latents], dim=0).mean(dim=0) + latents_1 = torch.stack([l[1] for l in latents], dim=0).mean(dim=0) + latents = (latents_0, latents_1) + return None, latents + + +def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"): + stft = TorchSTFT( + n_fft=1024, + hop_length=256, + win_length=1024, + use_mel=True, + n_mels=100, + sample_rate=24000, + mel_fmin=0, + mel_fmax=12000, + ) + stft = stft.to(device) + mel = stft(wav) + mel = dynamic_range_compression(mel) + if do_normalization: + mel = normalize_tacotron_mel(mel) + return mel diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..aaae6955160001765853402a892ab93906b7c1c9 --- /dev/null +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -0,0 +1,669 @@ +# AGPL: a notification must be added stating that changes have been made to that file. +import functools +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from packaging.version import Version +from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, TypicalLogitsWarper + +if Version(transformers.__version__) >= Version("4.45"): + isin = transformers.pytorch_utils.isin_mps_friendly +else: + isin = torch.isin + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +def _p(t): + return t and (len(t), len(t[0]), t[0][0].shape) # kv_cache debug + + +class ResBlock(nn.Module): + """ + Basic residual convolutional block that uses GroupNorm. + """ + + def __init__(self, chan): + super().__init__() + self.net = nn.Sequential( + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + nn.ReLU(), + nn.Conv1d(chan, chan, kernel_size=3, padding=1), + nn.GroupNorm(chan // 8, chan), + ) + + def forward(self, x): + return F.relu(self.net(x) + x) + + +class GPT2InferenceModel(GPT2PreTrainedModel): + def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.text_pos_embedding = text_pos_emb + self.embeddings = embeddings + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + + def store_mel_emb(self, mel_emb): + self.cached_mel_emb = mel_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_mel_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Create embedding + mel_len = self.cached_mel_emb.shape[1] + if input_ids.shape[1] != 1: + text_inputs = input_ids[:, mel_len:] + text_emb = self.embeddings(text_inputs) + text_emb = text_emb + self.text_pos_embedding(text_emb) + if self.cached_mel_emb.shape[0] != text_emb.shape[0]: + mel_emb = self.cached_mel_emb.repeat_interleave(text_emb.shape[0] // self.cached_mel_emb.shape[0], 0) + else: # this outcome only occurs once per loop in most cases + mel_emb = self.cached_mel_emb + emb = torch.cat([mel_emb, text_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.text_pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - mel_len, attention_mask.device + ) + + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + +class ConditioningEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + do_checkpointing=False, + mean=False, + ): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + self.do_checkpointing = do_checkpointing + self.mean = mean + + def forward(self, x): + h = self.init(x) + h = self.attn(h) + if self.mean: + return h.mean(dim=2) + else: + return h[:, :, 0] + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=0.02): + super().__init__() + self.emb = nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + + def forward(self, x): + sl = x.shape[1] + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind] + + +def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + + gpt_config = GPT2Config( + vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len, + n_ctx=max_mel_seq_len + max_text_seq_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe # TODO: figure out relevance in fixing exported model definition: Embedding(1012, 1024) + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. + del gpt.wte + return ( + gpt, + LearnedPositionEmbeddings(max_mel_seq_len, model_dim), + LearnedPositionEmbeddings(max_text_seq_len, model_dim), + None, + None, + ) + + +class MelEncoder(nn.Module): + def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2): + super().__init__() + self.channels = channels + self.encoder = nn.Sequential( + nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1), + nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 16, channels // 2), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]), + nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(channels // 8, channels), + nn.ReLU(), + nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]), + ) + self.reduction = 4 + + def forward(self, x): + for e in self.encoder: + x = e(x) + return x.permute(0, 2, 1) + + +class UnifiedVoice(nn.Module): + def __init__( + self, + layers=8, + model_dim=512, + heads=8, + max_text_tokens=120, + max_mel_tokens=250, + max_conditioning_inputs=1, + mel_length_compression=1024, + number_text_tokens=256, + start_text_token=None, + number_mel_codes=8194, + start_mel_token=8192, + stop_mel_token=8193, + train_solo_embeddings=False, + use_mel_codes_as_input=True, + checkpointing=True, + types=1, + ): + """ + Args: + layers: Number of layers in transformer stack. + model_dim: Operating dimensions of the transformer + heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64 + max_text_tokens: Maximum number of text tokens that will be encountered by model. + max_mel_tokens: Maximum number of MEL tokens that will be encountered by model. + max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s). + mel_length_compression: The factor between and . Used to compute MEL code padding given wav input length. + number_text_tokens: + start_text_token: + stop_text_token: + number_mel_codes: + start_mel_token: + stop_mel_token: + train_solo_embeddings: + use_mel_codes_as_input: + checkpointing: + """ + super().__init__() + + self.number_text_tokens = number_text_tokens + self.start_text_token = number_text_tokens * types if start_text_token is None else start_text_token + self.stop_text_token = 0 + self.number_mel_codes = number_mel_codes + self.start_mel_token = start_mel_token + self.stop_mel_token = stop_mel_token + self.layers = layers + self.heads = heads + self.max_mel_tokens = max_mel_tokens + self.max_text_tokens = max_text_tokens + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.mel_length_compression = mel_length_compression + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim) + if use_mel_codes_as_input: + self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim) + else: + self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1) + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + self.max_mel_tokens + 2 + self.max_conditioning_inputs, + self.max_text_tokens + 2, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1) + self.mel_head = nn.Linear(model_dim, self.number_mel_codes) + + # Initialize the embeddings per the GPT-2 scheme + embeddings = [self.text_embedding] + if use_mel_codes_as_input: + embeddings.append(self.mel_embedding) + for module in embeddings: + module.weight.data.normal_(mean=0.0, std=0.02) + + def post_init_gpt2_config(self, kv_cache=True): + seq_length = self.max_mel_tokens + self.max_text_tokens + 2 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.inference_model = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head) + self.gpt.wte = self.mel_embedding + # self.inference_model.save_pretrained("") + + def build_aligned_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, wav_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). + mel_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode="trunc") + for b in range(len(mel_lengths)): + actual_end = ( + mel_lengths[b] + 1 + ) # Due to the convolutional nature of how these tokens are generated, it would be best if the model predicts a token past the actual last token. + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_mel_token + return mel_input_tokens + + def get_logits( + self, + speech_conditioning_inputs, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + get_attns=False, + return_latent=False, + ): + if second_inputs is not None: + emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1) + + gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns) + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, 1:] # The first logit is tied to the speech_conditioning_input + enc = self.final_norm(enc) + + if return_latent: + return ( + enc[ + :, + speech_conditioning_inputs.shape[1] : speech_conditioning_inputs.shape[1] + first_inputs.shape[1], + ], + enc[:, -second_inputs.shape[1] :], + ) + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def forward( + self, + speech_conditioning_latent, + text_inputs, + text_lengths, + mel_codes, + wav_lengths, + types=None, + text_first=True, + raw_mels=None, + return_attentions=False, + return_latent=False, + clip_inputs=True, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + speech_conditioning_input: MEL float tensor, (b,1024) + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + raw_mels: MEL float tensor (b,80,s) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality. + """ + # Types are expressed by expanding the text embedding space. + if types is not None: + text_inputs = text_inputs * (1 + types).unsqueeze(-1) + + if clip_inputs: + # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by + # chopping the inputs by the maximum actual length. + max_text_len = text_lengths.max() + text_inputs = text_inputs[:, :max_text_len] + max_mel_len = wav_lengths.max() // self.mel_length_compression + mel_codes = mel_codes[:, :max_mel_len] + if raw_mels is not None: + raw_mels = raw_mels[:, :, : max_mel_len * 4] + mel_codes = self.set_mel_padding(mel_codes, wav_lengths) + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token) + + conds = speech_conditioning_latent.unsqueeze(1) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + mel_codes, mel_targets = self.build_aligned_inputs_and_targets( + mel_codes, self.start_mel_token, self.stop_mel_token + ) + if raw_mels is not None: + mel_inp = F.pad(raw_mels, (0, 8)) + else: + mel_inp = mel_codes + mel_emb = self.mel_embedding(mel_inp) + mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + + if text_first: + text_logits, mel_logits = self.get_logits( + conds, + text_emb, + self.text_head, + mel_emb, + self.mel_head, + get_attns=return_attentions, + return_latent=return_latent, + ) + if return_latent: + return mel_logits[ + :, :-2 + ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + else: + mel_logits, text_logits = self.get_logits( + conds, + mel_emb, + self.mel_head, + text_emb, + self.text_head, + get_attns=return_attentions, + return_latent=return_latent, + ) + if return_latent: + return text_logits[ + :, :-2 + ] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass. + + if return_attentions: + return mel_logits + loss_text = F.cross_entropy(text_logits, text_targets.long()) + loss_mel = F.cross_entropy(mel_logits, mel_targets.long()) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def inference_speech( + self, + speech_conditioning_latent, + text_inputs, + input_tokens=None, + num_return_sequences=1, + max_generate_length=None, + typical_sampling=False, + typical_mass=0.9, + **hf_generate_kwargs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs, text_targets = self.build_aligned_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + conds = speech_conditioning_latent.unsqueeze(1) + emb = torch.cat([conds, text_emb], dim=1) + self.inference_model.store_mel_emb(emb) + + fake_inputs = torch.full( + ( + emb.shape[0], + conds.shape[1] + emb.shape[1], + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + fake_inputs[:, -1] = self.start_mel_token + trunc_index = fake_inputs.shape[1] + if input_tokens is None: + inputs = fake_inputs + else: + assert ( + num_return_sequences % input_tokens.shape[0] == 0 + ), "The number of return sequences must be divisible by the number of input sequences" + fake_inputs = fake_inputs.repeat(num_return_sequences, 1) + input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1) + inputs = torch.cat([fake_inputs, input_tokens], dim=1) + + logits_processor = ( + LogitsProcessorList([TypicalLogitsWarper(mass=typical_mass)]) if typical_sampling else LogitsProcessorList() + ) # TODO disable this + max_length = ( + trunc_index + self.max_mel_tokens - 1 if max_generate_length is None else trunc_index + max_generate_length + ) + stop_token_tensor = torch.tensor(self.stop_mel_token, device=inputs.device, dtype=torch.long) + attention_mask = _prepare_attention_mask_for_generation(inputs, stop_token_tensor, stop_token_tensor) + gen = self.inference_model.generate( + inputs, + bos_token_id=self.start_mel_token, + pad_token_id=self.stop_mel_token, + eos_token_id=self.stop_mel_token, + max_length=max_length, + logits_processor=logits_processor, + num_return_sequences=num_return_sequences, + attention_mask=attention_mask, + **hf_generate_kwargs, + ) + return gen[:, trunc_index:] + + +def _prepare_attention_mask_for_generation( + inputs: torch.Tensor, + pad_token_id: Optional[torch.Tensor], + eos_token_id: Optional[torch.Tensor], +) -> torch.LongTensor: + # No information for attention mask inference -> return default attention mask + default_attention_mask = torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) + if pad_token_id is None: + return default_attention_mask + + is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] + if not is_input_ids: + return default_attention_mask + + is_pad_token_in_inputs = (pad_token_id is not None) and (isin(elements=inputs, test_elements=pad_token_id).any()) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~( + isin(elements=eos_token_id, test_elements=pad_token_id).any() + ) + can_infer_attention_mask = is_pad_token_in_inputs * is_pad_token_not_equal_to_eos_token_id + attention_mask_from_padding = inputs.ne(pad_token_id).long() + + attention_mask = ( + attention_mask_from_padding * can_infer_attention_mask + default_attention_mask * ~can_infer_attention_mask + ) + return attention_mask + + +if __name__ == "__main__": + gpt = UnifiedVoice( + model_dim=256, + heads=4, + train_solo_embeddings=True, + use_mel_codes_as_input=True, + max_conditioning_inputs=4, + ) + l = gpt( + torch.randn(2, 3, 80, 800), + torch.randint(high=120, size=(2, 120)), + torch.tensor([32, 120]), + torch.randint(high=8192, size=(2, 250)), + torch.tensor([250 * 256, 195 * 256]), + ) + gpt.text_forward( + torch.randn(2, 80, 800), + torch.randint(high=50, size=(2, 80)), + torch.tensor([32, 80]), + ) diff --git a/TTS/tts/layers/tortoise/classifier.py b/TTS/tts/layers/tortoise/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..8764bb070b5ad8267ee2992ccc33f5bb65bad005 --- /dev/null +++ b/TTS/tts/layers/tortoise/classifier.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, Downsample, Upsample, normalization, zero_module + + +class ResBlock(nn.Module): + def __init__( + self, + channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + kernel_size=3, + do_checkpoint=True, + ): + super().__init__() + self.channels = channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.do_checkpoint = do_checkpoint + padding = 1 if kernel_size == 3 else 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, kernel_size, padding=padding) + else: + self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1) + + def forward(self, x): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AudioMiniEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + base_channels=128, + depth=2, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + dropout=0, + downsample_factor=2, + kernel_size=3, + ): + super().__init__() + self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) + ch = base_channels + res = [] + self.layers = depth + for l in range(depth): + for r in range(resnet_blocks): + res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) + ch *= 2 + self.res = nn.Sequential(*res) + self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) + attn = [] + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + h = self.init(x) + h = self.res(h) + h = self.final(h) + for blk in self.attn: + h = blk(h) + return h[:, :, 0] + + +class AudioMiniEncoderWithClassifierHead(nn.Module): + def __init__(self, classes, distribute_zero_label=True, **kwargs): + super().__init__() + self.enc = AudioMiniEncoder(**kwargs) + self.head = nn.Linear(self.enc.dim, classes) + self.num_classes = classes + self.distribute_zero_label = distribute_zero_label + + def forward(self, x, labels=None): + h = self.enc(x) + logits = self.head(h) + if labels is None: + return logits + else: + if self.distribute_zero_label: + oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes) + zeros_indices = (labels == 0).unsqueeze(-1) + # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise. + zero_extra_mass = torch.full_like( + oh_labels, + dtype=torch.float, + fill_value=0.2 / (self.num_classes - 1), + ) + zero_extra_mass[:, 0] = -0.2 + zero_extra_mass = zero_extra_mass * zeros_indices + oh_labels = oh_labels + zero_extra_mass + else: + oh_labels = labels + loss = nn.functional.cross_entropy(logits, oh_labels) + return loss diff --git a/TTS/tts/layers/tortoise/clvp.py b/TTS/tts/layers/tortoise/clvp.py new file mode 100644 index 0000000000000000000000000000000000000000..241dfdd4f4adebc030dd198c5a9ef615a342e8a1 --- /dev/null +++ b/TTS/tts/layers/tortoise/clvp.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from TTS.tts.layers.tortoise.arch_utils import CheckpointedXTransformerEncoder +from TTS.tts.layers.tortoise.transformer import Transformer +from TTS.tts.layers.tortoise.xtransformers import Encoder + + +def exists(val): + return val is not None + + +def masked_mean(t, mask, dim=1): + t = t.masked_fill(~mask[:, :, None], 0.0) + return t.sum(dim=1) / mask.sum(dim=1)[..., None] + + +class CLVP(nn.Module): + """ + CLIP model retrofitted for performing contrastive evaluation between tokenized audio data and the corresponding + transcribed text. + + Originally from https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py + """ + + def __init__( + self, + *, + dim_text=512, + dim_speech=512, + dim_latent=512, + num_text_tokens=256, + text_enc_depth=6, + text_seq_len=120, + text_heads=8, + num_speech_tokens=8192, + speech_enc_depth=6, + speech_heads=8, + speech_seq_len=250, + text_mask_percentage=0, + voice_mask_percentage=0, + wav_token_compression=1024, + use_xformers=False, + ): + super().__init__() + self.text_emb = nn.Embedding(num_text_tokens, dim_text) + self.to_text_latent = nn.Linear(dim_text, dim_latent, bias=False) + + self.speech_emb = nn.Embedding(num_speech_tokens, dim_speech) + self.to_speech_latent = nn.Linear(dim_speech, dim_latent, bias=False) + + if use_xformers: + self.text_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_text, + depth=text_enc_depth, + heads=text_heads, + ff_dropout=0.1, + ff_mult=2, + attn_dropout=0.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ), + ) + self.speech_transformer = CheckpointedXTransformerEncoder( + needs_permute=False, + exit_permute=False, + max_seq_len=-1, + attn_layers=Encoder( + dim=dim_speech, + depth=speech_enc_depth, + heads=speech_heads, + ff_dropout=0.1, + ff_mult=2, + attn_dropout=0.1, + use_rmsnorm=True, + ff_glu=True, + rotary_pos_emb=True, + ), + ) + else: + self.text_transformer = Transformer( + causal=False, seq_len=text_seq_len, dim=dim_text, depth=text_enc_depth, heads=text_heads + ) + self.speech_transformer = Transformer( + causal=False, seq_len=speech_seq_len, dim=dim_speech, depth=speech_enc_depth, heads=speech_heads + ) + + self.temperature = nn.Parameter(torch.tensor(1.0)) + self.text_mask_percentage = text_mask_percentage + self.voice_mask_percentage = voice_mask_percentage + self.wav_token_compression = wav_token_compression + self.xformers = use_xformers + if not use_xformers: + self.text_pos_emb = nn.Embedding(text_seq_len, dim_text) + self.speech_pos_emb = nn.Embedding(num_speech_tokens, dim_speech) + + def forward(self, text, speech_tokens, return_loss=False): + b, device = text.shape[0], text.device + if self.training: + text_mask = torch.rand_like(text.float()) > self.text_mask_percentage + voice_mask = torch.rand_like(speech_tokens.float()) > self.voice_mask_percentage + else: + text_mask = torch.ones_like(text.float()).bool() + voice_mask = torch.ones_like(speech_tokens.float()).bool() + + text_emb = self.text_emb(text) + speech_emb = self.speech_emb(speech_tokens) + + if not self.xformers: + text_emb += self.text_pos_emb(torch.arange(text.shape[1], device=device)) + speech_emb += self.speech_pos_emb(torch.arange(speech_emb.shape[1], device=device)) + + enc_text = self.text_transformer(text_emb, mask=text_mask) + enc_speech = self.speech_transformer(speech_emb, mask=voice_mask) + + text_latents = masked_mean(enc_text, text_mask, dim=1) + speech_latents = masked_mean(enc_speech, voice_mask, dim=1) + + text_latents = self.to_text_latent(text_latents) + speech_latents = self.to_speech_latent(speech_latents) + + text_latents, speech_latents = (F.normalize(t, p=2, dim=-1) for t in (text_latents, speech_latents)) + + temp = self.temperature.exp() + + if not return_loss: + sim = einsum("n d, n d -> n", text_latents, speech_latents) * temp + return sim + + sim = einsum("i d, j d -> i j", text_latents, speech_latents) * temp + labels = torch.arange(b, device=device) + loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 + return loss + + +if __name__ == "__main__": + clip = CLVP(text_mask_percentage=0.2, voice_mask_percentage=0.2) + clip( + torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=True, + ) + nonloss = clip( + torch.randint(0, 256, (2, 120)), + torch.tensor([50, 100]), + torch.randint(0, 8192, (2, 250)), + torch.tensor([101, 102]), + return_loss=False, + ) + print(nonloss.shape) diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..2b29091b447933e0bad65cdcccaa50ab2ff88574 --- /dev/null +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -0,0 +1,1242 @@ +""" +This is an almost carbon copy of gaussian_diffusion.py from OpenAI's ImprovedDiffusion repo, which itself: + +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +import enum +import math + +import numpy as np +import torch +import torch as th +from tqdm import tqdm + +from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper + +try: + from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral + + K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} +except ImportError: + K_DIFFUSION_SAMPLERS = None + + +SAMPLERS = ["dpm++2m", "p", "ddim"] + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = "previous_x" # the model predicts x_{t-1} + START_X = "start_x" # the model predicts x_0 + EPSILON = "epsilon" # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = "learned" + FIXED_SMALL = "fixed_small" + FIXED_LARGE = "fixed_large" + LEARNED_RANGE = "learned_range" + + +class LossType(enum.Enum): + MSE = "mse" # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances) + KL = "kl" # use the variational lower-bound + RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + conditioning_free=False, + conditioning_free_k=1, + ramp_conditioning_free=True, + sampler="p", + ): + self.sampler = sampler + self.model_mean_type = ModelMeanType(model_mean_type) + self.model_var_type = ModelVarType(model_var_type) + self.loss_type = LossType(loss_type) + self.rescale_timesteps = rescale_timesteps + self.conditioning_free = conditioning_free + self.conditioning_free_k = conditioning_free_k + self.ramp_conditioning_free = ramp_conditioning_free + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + if self.conditioning_free: + model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.conditioning_free: + model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + if self.conditioning_free: + if self.ramp_conditioning_free: + assert t.shape[0] == 1 # This should only be used in inference. + cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) + else: + cfk = self.conditioning_free_k + model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def k_diffusion_sample_loop( + self, + k_sampler, + pbar, + model, + shape, + noise=None, # all given + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + device=None, # ALL UNUSED + model_kwargs=None, # {'precomputed_aligned_embeddings': precomputed_embeddings}, + progress=False, # unused as well + ): + assert isinstance(model_kwargs, dict) + if device is None: + device = next(model.parameters()).device + s_in = noise.new_ones([noise.shape[0]]) + + def model_split(*args, **kwargs): + model_output = model(*args, **kwargs) + model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1) + return model_epsilon, model_var + + # + """ + print(self.betas) + print(th.tensor(self.betas)) + noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) + """ + noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4) + + def model_fn_prewrap(x, t, *args, **kwargs): + """ + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + print(t) + print(self.timestep_map) + exit() + """ + """ + model_output = model(x, self._scale_timesteps(t*4000), **model_kwargs) + out = self.p_mean_variance(model, x, t*4000, model_kwargs=model_kwargs) + return out['pred_xstart'] + """ + x, _ = x.chunk(2) + t, _ = (t * 1000).chunk(2) + res = torch.cat( + [ + model_split(x, t, conditioning_free=True, **model_kwargs)[0], + model_split(x, t, **model_kwargs)[0], + ] + ) + pbar.update(1) + return res + + model_fn = model_wrapper( + model_fn_prewrap, + noise_schedule, + model_type="noise", # "noise" or "x_start" or "v" or "score" + model_kwargs=model_kwargs, + guidance_type="classifier-free", + condition=th.Tensor(1), + unconditional_condition=th.Tensor(1), + guidance_scale=self.conditioning_free_k, + ) + dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + x_sample = dpm_solver.sample( + noise, + steps=self.num_timesteps, + order=2, + skip_type="time_uniform", + method="multistep", + ) + #''' + return x_sample + + def sample_loop(self, *args, **kwargs): + s = self.sampler + if s == "p": + return self.p_sample_loop(*args, **kwargs) + elif s == "ddim": + return self.ddim_sample_loop(*args, **kwargs) + elif s == "dpm++2m": + if self.conditioning_free is not True: + raise RuntimeError("cond_free must be true") + with tqdm(total=self.num_timesteps) as pbar: + if K_DIFFUSION_SAMPLERS is None: + raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers") + return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) + else: + raise RuntimeError("sampler not impl") + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + for i in tqdm(indices, disable=not progress): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) + # Equation 12. + noise = th.randn_like(x) + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, disable=not progress) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + # TODO: support multiple model outputs for this mode. + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) + if isinstance(model_outputs, tuple): + model_output = model_outputs[0] + terms["extra_outputs"] = model_outputs[1:] + else: + model_output = model_outputs + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. + elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def autoregressive_training_losses( + self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None + ): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + terms = {} + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + assert False # not currently supported for this type of diffusion. + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) + terms.update(dict(zip(model_output_keys, model_outputs))) + model_output = terms[gd_out_key] + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C, 2, *x_t.shape[2:]) + model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1] + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] + x_start_pred = torch.zeros(x_start) # Not supported. + elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs + return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model, autoregressive=False): + if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): + return model + mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel + return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + model_output = self.model(x, new_ts, **kwargs) + return model_output + + +class _WrappedAutoregressiveModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, x0, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, x0, new_ts, **kwargs) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/TTS/tts/layers/tortoise/diffusion_decoder.py b/TTS/tts/layers/tortoise/diffusion_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3cf7698a7334b4cfc8d9bdd0f5f6ee3059189d --- /dev/null +++ b/TTS/tts/layers/tortoise/diffusion_decoder.py @@ -0,0 +1,415 @@ +import math +import random +from abc import abstractmethod + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import autocast + +from TTS.tts.layers.tortoise.arch_utils import AttentionBlock, normalization + + +def is_latent(t): + return t.dtype == torch.float + + +def is_sequence(t): + return t.dtype == torch.long + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepBlock(nn.Module): + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class ResBlock(TimestepBlock): + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + dims=2, + kernel_size=3, + efficient_config=True, + use_scale_shift_norm=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_scale_shift_norm = use_scale_shift_norm + padding = {1: 0, 3: 1, 5: 2}[kernel_size] + eff_kernel = 1 if efficient_config else 3 + eff_padding = 0 if efficient_config else 1 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding), + ) + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding) + + def forward(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class DiffusionLayer(TimestepBlock): + def __init__(self, model_channels, dropout, num_heads): + super().__init__() + self.resblk = ResBlock( + model_channels, + model_channels, + dropout, + model_channels, + dims=1, + use_scale_shift_norm=True, + ) + self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True) + + def forward(self, x, time_emb): + y = self.resblk(x, time_emb) + return self.attn(y) + + +class DiffusionTts(nn.Module): + def __init__( + self, + model_channels=512, + num_layers=8, + in_channels=100, + in_latent_channels=512, + in_tokens=8193, + out_channels=200, # mean and variance + dropout=0, + use_fp16=False, + num_heads=16, + # Parameters for regularization. + layer_drop=0.1, + unconditioned_percentage=0.1, # This implements a mechanism similar to what is used in classifier-free training. + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.dropout = dropout + self.num_heads = num_heads + self.unconditioned_percentage = unconditioned_percentage + self.enable_fp16 = use_fp16 + self.layer_drop = layer_drop + + self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1) + self.time_embed = nn.Sequential( + nn.Linear(model_channels, model_channels), + nn.SiLU(), + nn.Linear(model_channels, model_channels), + ) + + # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed. + # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally + # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive + # transformer network. + self.code_embedding = nn.Embedding(in_tokens, model_channels) + self.code_converter = nn.Sequential( + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.code_norm = normalization(model_channels) + self.latent_conditioner = nn.Sequential( + nn.Conv1d(in_latent_channels, model_channels, 3, padding=1), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True), + ) + self.contextual_embedder = nn.Sequential( + nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2), + nn.Conv1d(model_channels, model_channels * 2, 3, padding=1, stride=2), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + AttentionBlock( + model_channels * 2, + num_heads, + relative_pos_embeddings=True, + do_checkpoint=False, + ), + ) + self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1)) + self.conditioning_timestep_integrator = TimestepEmbedSequential( + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + DiffusionLayer(model_channels, dropout, num_heads), + ) + + self.integrating_conv = nn.Conv1d(model_channels * 2, model_channels, kernel_size=1) + self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1) + + self.layers = nn.ModuleList( + [DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] + + [ + ResBlock( + model_channels, + model_channels, + dropout, + dims=1, + use_scale_shift_norm=True, + ) + for _ in range(3) + ] + ) + + self.out = nn.Sequential( + normalization(model_channels), + nn.SiLU(), + nn.Conv1d(model_channels, out_channels, 3, padding=1), + ) + + def get_grad_norm_parameter_groups(self): + groups = { + "minicoder": list(self.contextual_embedder.parameters()), + "layers": list(self.layers.parameters()), + "code_converters": list(self.code_embedding.parameters()) + + list(self.code_converter.parameters()) + + list(self.latent_conditioner.parameters()) + + list(self.latent_conditioner.parameters()), + "timestep_integrator": list(self.conditioning_timestep_integrator.parameters()) + + list(self.integrating_conv.parameters()), + "time_embed": list(self.time_embed.parameters()), + } + return groups + + def get_conditioning(self, conditioning_input): + speech_conditioning_input = ( + conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.contextual_embedder(speech_conditioning_input[:, j])) + conds = torch.cat(conds, dim=-1) + conds = conds.mean(dim=-1) + return conds + + def timestep_independent( + self, + aligned_conditioning, + conditioning_latent, + expected_seq_len, + return_code_pred, + ): + # Shuffle aligned_latent to BxCxS format + if is_latent(aligned_conditioning): + aligned_conditioning = aligned_conditioning.permute(0, 2, 1) + + cond_scale, cond_shift = torch.chunk(conditioning_latent, 2, dim=1) + if is_latent(aligned_conditioning): + code_emb = self.latent_conditioner(aligned_conditioning) + else: + code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1) + code_emb = self.code_converter(code_emb) + code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1) + + unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device) + # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance. + if self.training and self.unconditioned_percentage > 0: + unconditioned_batches = ( + torch.rand((code_emb.shape[0], 1, 1), device=code_emb.device) < self.unconditioned_percentage + ) + code_emb = torch.where( + unconditioned_batches, + self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1), + code_emb, + ) + expanded_code_emb = F.interpolate(code_emb, size=expected_seq_len, mode="nearest") + + if not return_code_pred: + return expanded_code_emb + else: + mel_pred = self.mel_head(expanded_code_emb) + # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss. + mel_pred = mel_pred * unconditioned_batches.logical_not() + return expanded_code_emb, mel_pred + + def forward( + self, + x, + timesteps, + aligned_conditioning=None, + conditioning_latent=None, + precomputed_aligned_embeddings=None, + conditioning_free=False, + return_code_pred=False, + ): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced. + :param conditioning_latent: a pre-computed conditioning latent; see get_conditioning(). + :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent() + :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered. + :return: an [N x C x ...] Tensor of outputs. + """ + assert precomputed_aligned_embeddings is not None or ( + aligned_conditioning is not None and conditioning_latent is not None + ) + assert not ( + return_code_pred and precomputed_aligned_embeddings is not None + ) # These two are mutually exclusive. + + unused_params = [] + if conditioning_free: + code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1]) + unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters())) + unused_params.extend(list(self.latent_conditioner.parameters())) + else: + if precomputed_aligned_embeddings is not None: + code_emb = precomputed_aligned_embeddings + else: + code_emb, mel_pred = self.timestep_independent( + aligned_conditioning, conditioning_latent, x.shape[-1], True + ) + if is_latent(aligned_conditioning): + unused_params.extend( + list(self.code_converter.parameters()) + list(self.code_embedding.parameters()) + ) + else: + unused_params.extend(list(self.latent_conditioner.parameters())) + + unused_params.append(self.unconditioned_embedding) + + time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) + x = self.inp_block(x) + x = torch.cat([x, code_emb], dim=1) + x = self.integrating_conv(x) + for i, lyr in enumerate(self.layers): + # Do layer drop where applicable. Do not drop first and last layers. + if ( + self.training + and self.layer_drop > 0 + and i != 0 + and i != (len(self.layers) - 1) + and random.random() < self.layer_drop + ): + unused_params.extend(list(lyr.parameters())) + else: + # First and last blocks will have autocast disabled for improved precision. + with autocast(x.device.type, enabled=self.enable_fp16 and i != 0): + x = lyr(x, time_emb) + + x = x.float() + out = self.out(x) + + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + for p in unused_params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + + if return_code_pred: + return out, mel_pred + return out + + +if __name__ == "__main__": + clip = torch.randn(2, 100, 400) + aligned_latent = torch.randn(2, 388, 512) + aligned_sequence = torch.randint(0, 8192, (2, 100)) + cond = torch.randn(2, 100, 400) + ts = torch.LongTensor([600, 600]) + model = DiffusionTts(512, layer_drop=0.3, unconditioned_percentage=0.5) + # Test with latent aligned conditioning + # o = model(clip, ts, aligned_latent, cond) + # Test with sequence aligned conditioning + o = model(clip, ts, aligned_sequence, cond) diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1d8ff78468b966f89e70be3582eab5168ea9bc --- /dev/null +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -0,0 +1,1565 @@ +import logging +import math + +import torch + +logger = logging.getLogger(__name__) + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear", "cosine"]: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule + ) + ) + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) + self.schedule = schedule + if schedule == "cosine": + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1.0 + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), + self.t_array.to(t.device), + self.log_alpha_array.to(t.device), + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + + def t_fn(log_alpha_t): + return ( + torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims( + torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), + dims, + ) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [ + 3, + ] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [ + 3, + ] * ( + K - 1 + ) + [1] + else: + orders = [ + 3, + ] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [ + 2, + ] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update( + self, + x, + s, + t, + r1=0.5, + model_s=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(t), + ) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = ( + torch.exp(log_alpha_s1), + torch.exp(log_alpha_s2), + torch.exp(log_alpha_t), + ) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + x, + s, + t, + order, + return_intermediate=False, + solver_type="dpmsolver", + r1=None, + r2=None, + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + r2=r2, + ) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive( + self, + x, + order, + t_T, + t_0, + h_init=0.05, + atol=0.0078, + rtol=0.05, + theta=0.9, + t_err=1e-5, + solver_type="dpmsolver", + ): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max( + torch.ones_like(x).to(x) * atol, + rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)), + ) + + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min( + theta * h * torch.float_power(E, -1.0 / order).float(), + lambda_0 - lambda_s, + ) + nfe += order + logger.debug("adaptive solver nfe %d", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + return self.sample( + x, + steps=steps, + t_start=t_0, + t_end=t_T, + order=order, + skip_type=skip_type, + method=method, + lower_order_final=lower_order_final, + denoise_to_zero=denoise_to_zero, + solver_type=solver_type, + atol=atol, + rtol=rtol, + return_intermediate=return_intermediate, + ) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + + ===================================================== + + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g., DPM-Solver: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + e.g., DPM-Solver++: + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + return_intermediate: A `bool`. Whether to save the xt at each step. + When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" + if return_intermediate: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when saving intermediate values" + if self.correcting_xt_fn is not None: + assert method in [ + "multistep", + "singlestep", + "singlestep_fixed", + ], "Cannot use adaptive solver when correcting_xt_fn is not None" + device = x.device + intermediates = [] + with torch.no_grad(): + if method == "adaptive": + x = self.dpm_solver_adaptive( + x, + order=order, + t_T=t_T, + t_0=t_0, + atol=atol, + rtol=rtol, + solver_type=solver_type, + ) + elif method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + # Init the initial values. + step = 0 + t = timesteps[step] + t_prev_list = [t] + model_prev_list = [self.model_fn(x, t)] + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + # Init the first `order` values by lower order multistep DPM-Solver. + for step in range(1, order): + t = timesteps[step] + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step, + solver_type=solver_type, + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + t_prev_list.append(t) + model_prev_list.append(self.model_fn(x, t)) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + t = timesteps[step] + # We only use lower order for steps < 10 + if lower_order_final and steps < 10: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, + model_prev_list, + t_prev_list, + t, + step_order, + solver_type=solver_type, + ) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, t) + elif method in ["singlestep", "singlestep_fixed"]: + if method == "singlestep": + ( + timesteps_outer, + orders, + ) = self.get_orders_and_timesteps_for_singlestep_solver( + steps=steps, + order=order, + skip_type=skip_type, + t_T=t_T, + t_0=t_0, + device=device, + ) + elif method == "singlestep_fixed": + K = steps // order + orders = [ + order, + ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for step, order in enumerate(orders): + s, t = timesteps_outer[step], timesteps_outer[step + 1] + timesteps_inner = self.get_time_steps( + skip_type=skip_type, + t_T=s.item(), + t_0=t.item(), + N=order, + device=device, + ) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step) + if return_intermediate: + intermediates.append(x) + else: + raise ValueError("Got wrong method {}".format(method)) + if denoise_to_zero: + t = torch.ones((1,)).to(device) * t_0 + x = self.denoise_to_zero_fn(x, t) + if self.correcting_xt_fn is not None: + x = self.correcting_xt_fn(x, t, step + 1) + if return_intermediate: + intermediates.append(x) + if return_intermediate: + return x, intermediates + else: + return x + + +############################################################# +# other utility functions +############################################################# + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] diff --git a/TTS/tts/layers/tortoise/random_latent_generator.py b/TTS/tts/layers/tortoise/random_latent_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9b39c1e4b22ee5a9ad84a1711a08a8530c4d76b7 --- /dev/null +++ b/TTS/tts/layers/tortoise/random_latent_generator.py @@ -0,0 +1,55 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): + if bias is not None: + rest_dim = [1] * (input.ndim - bias.ndim - 1) + return ( + F.leaky_relu( + input + bias.view(1, bias.shape[0], *rest_dim), + negative_slope=negative_slope, + ) + * scale + ) + else: + return F.leaky_relu(input, negative_slope=0.2) * scale + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + return out + + +class RandomLatentConverter(nn.Module): + def __init__(self, channels): + super().__init__() + self.layers = nn.Sequential( + *[EqualLinear(channels, channels, lr_mul=0.1) for _ in range(5)], nn.Linear(channels, channels) + ) + self.channels = channels + + def forward(self, ref): + r = torch.randn(ref.shape[0], self.channels, device=ref.device) + y = self.layers(r) + return y + + +if __name__ == "__main__": + model = RandomLatentConverter(512) + model(torch.randn(5, 512)) diff --git a/TTS/tts/layers/tortoise/tokenizer.py b/TTS/tts/layers/tortoise/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d243d6558d0dfcbfee59769f991ce5ad9a603678 --- /dev/null +++ b/TTS/tts/layers/tortoise/tokenizer.py @@ -0,0 +1,37 @@ +import os + +import torch +from tokenizers import Tokenizer + +from TTS.tts.utils.text.cleaners import english_cleaners + +DEFAULT_VOCAB_FILE = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json" +) + + +class VoiceBpeTokenizer: + def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None): + self.tokenizer = None + if vocab_file is not None: + self.tokenizer = Tokenizer.from_file(vocab_file) + if vocab_str is not None: + self.tokenizer = Tokenizer.from_str(vocab_str) + + def preprocess_text(self, txt): + txt = english_cleaners(txt) + return txt + + def encode(self, txt): + txt = self.preprocess_text(txt) + txt = txt.replace(" ", "[SPACE]") + return self.tokenizer.encode(txt).ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") + txt = txt.replace("[SPACE]", " ") + txt = txt.replace("[STOP]", "") + txt = txt.replace("[UNK]", "") + return txt diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb1bab96a10c0f5d4d2b7a9042a1f6a1d14cdd3 --- /dev/null +++ b/TTS/tts/layers/tortoise/transformer.py @@ -0,0 +1,229 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, depth=1): + if isinstance(val, list): + val = tuple(val) + return val if isinstance(val, tuple) else (val,) * depth + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def stable_softmax(t, dim=-1, alpha=32**2): + t = t / alpha + t = t - torch.amax(t, dim=dim, keepdim=True).detach() + return (t * alpha).softmax(dim=dim) + + +def route_args(router, args, depth): + routed_args = [(dict(), dict()) for _ in range(depth)] + matched_keys = [key for key in args.keys() if key in router] + + for key in matched_keys: + val = args[key] + for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])): + new_f_args, new_g_args = (({key: val} if route else {}) for route in routes) + routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) + return routed_args + + +# classes +class SequentialSequence(nn.Module): + def __init__(self, layers, args_route={}, layer_dropout=0.0): + super().__init__() + assert all( + len(route) == len(layers) for route in args_route.values() + ), "each argument route map must have the same depth as the number of sequential layers" + self.layers = layers + self.args_route = args_route + self.layer_dropout = layer_dropout + + def forward(self, x, **kwargs): + args = route_args(self.args_route, kwargs, len(self.layers)) + layers_and_args = list(zip(self.layers, args)) + + for (f, g), (f_args, g_args) in layers_and_args: + x = x + f(x, **f_args) + x = x + g(x, **g_args) + return x + + +class DivideMax(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + maxes = x.amax(dim=self.dim, keepdim=True).detach() + return x / maxes + + +# https://arxiv.org/abs/2103.17239 +class LayerScale(nn.Module): + def __init__(self, dim, depth, fn): + super().__init__() + if depth <= 18: + init_eps = 0.1 + elif depth > 18 and depth <= 24: + init_eps = 1e-5 + else: + init_eps = 1e-6 + + scale = torch.zeros(1, 1, dim).fill_(init_eps) + self.scale = nn.Parameter(scale) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(x, **kwargs) * self.scale + + +# layer norm + + +class PreNorm(nn.Module): + def __init__(self, dim, fn, sandwich=False): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() + self.fn = fn + + def forward(self, x, **kwargs): + x = self.norm(x) + x = self.fn(x, **kwargs) + return self.norm_out(x) + + +# feed forward + + +class GEGLU(nn.Module): + def forward(self, x): + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) + + +class FeedForward(nn.Module): + def __init__(self, dim, dropout=0.0, mult=4.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim * mult * 2), + GEGLU(), + nn.Dropout(dropout), + nn.Linear(dim * mult, dim), + ) + + def forward(self, x): + return self.net(x) + + +# Attention + + +class Attention(nn.Module): + def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.seq_len = seq_len + self.scale = dim_head**-0.5 + + self.causal = causal + + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + + def forward(self, x, mask=None): + b, n, _, h, device = *x.shape, self.heads, x.device + softmax = torch.softmax + + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in qkv) + + q = q * self.scale + + dots = torch.einsum("b h i d, b h j d -> b h i j", q, k) + mask_value = max_neg_value(dots) + + if exists(mask): + mask = rearrange(mask, "b j -> b () () j") + dots.masked_fill_(~mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool() + dots.masked_fill_(mask, mask_value) + + attn = softmax(dots, dim=-1) + + out = torch.einsum("b h i j, b h j d -> b h i d", attn, v) + out = rearrange(out, "b h n d -> b n (h d)") + out = self.to_out(out) + return out + + +# main transformer class +class Transformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + seq_len, + causal=True, + heads=8, + dim_head=64, + ff_mult=4, + attn_dropout=0.0, + ff_dropout=0.0, + sparse_attn=False, + sandwich_norm=False, + ): + super().__init__() + layers = nn.ModuleList([]) + sparse_layer = cast_tuple(sparse_attn, depth) + + for ind, sparse_attn in zip(range(depth), sparse_layer): + attn = Attention( + dim, + causal=causal, + seq_len=seq_len, + heads=heads, + dim_head=dim_head, + dropout=attn_dropout, + ) + + ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout) + + layers.append( + nn.ModuleList( + [ + LayerScale(dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm)), + LayerScale(dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm)), + ] + ) + ) + + execute_type = SequentialSequence + route_attn = ((True, False),) * depth + attn_route_map = {"mask": route_attn} + + self.layers = execute_type(layers, args_route=attn_route_map) + + def forward(self, x, **kwargs): + return self.layers(x, **kwargs) diff --git a/TTS/tts/layers/tortoise/utils.py b/TTS/tts/layers/tortoise/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..898121f793292018ebac033bf4da548208e7171f --- /dev/null +++ b/TTS/tts/layers/tortoise/utils.py @@ -0,0 +1,49 @@ +import logging +import os +from urllib import request + +from tqdm import tqdm + +logger = logging.getLogger(__name__) + +DEFAULT_MODELS_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tortoise", "models") +MODELS_DIR = os.environ.get("TORTOISE_MODELS_DIR", DEFAULT_MODELS_DIR) +MODELS_DIR = "/data/speech_synth/models/" +MODELS = { + "autoregressive.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/autoregressive.pth", + "classifier.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/classifier.pth", + "clvp2.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/clvp2.pth", + "diffusion_decoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/diffusion_decoder.pth", + "vocoder.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth", + "rlg_auto.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth", + "rlg_diffuser.pth": "https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth", +} + + +def download_models(specific_models=None): + """ + Call to download all the models that Tortoise uses. + """ + os.makedirs(MODELS_DIR, exist_ok=True) + for model_name, url in MODELS.items(): + if specific_models is not None and model_name not in specific_models: + continue + model_path = os.path.join(MODELS_DIR, model_name) + if os.path.exists(model_path): + continue + logger.info("Downloading %s from %s...", model_name, url) + with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1) as t: + request.urlretrieve(url, model_path, lambda nb, bs, fs, t=t: t.update(nb * bs - t.n)) + logger.info("Done.") + + +def get_model_path(model_name, models_dir=MODELS_DIR): + """ + Get path to given model, download it if it doesn't exist. + """ + if model_name not in MODELS: + raise ValueError(f"Model {model_name} not found in available models.") + model_path = os.path.join(models_dir, model_name) + if not os.path.exists(model_path) and models_dir == MODELS_DIR: + download_models([model_name]) + return model_path diff --git a/TTS/tts/layers/tortoise/vocoder.py b/TTS/tts/layers/tortoise/vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a5200c26738b55273a74c86e4308a6bd6783f5d7 --- /dev/null +++ b/TTS/tts/layers/tortoise/vocoder.py @@ -0,0 +1,405 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.parametrize as parametrize + +MAX_WAV_VALUE = 32768.0 + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): number of layers + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w + kpnet_bias_channels = conv_out_channels * conv_layers # l_b + + self.input_conv = nn.Sequential( + nn.utils.parametrizations.weight_norm( + nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_convs = nn.ModuleList() + padding = (kpnet_conv_size - 1) // 2 + for _ in range(3): + self.residual_convs.append( + nn.Sequential( + nn.Dropout(kpnet_dropout), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_hidden_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ), + getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + ) + self.kernel_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_kernel_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + self.bias_conv = nn.utils.parametrizations.weight_norm( + nn.Conv1d( + kpnet_hidden_channels, + kpnet_bias_channels, + kpnet_conv_size, + padding=padding, + bias=True, + ) + ) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + """ + batch, _, cond_length = c.shape + c = self.input_conv(c) + for residual_conv in self.residual_convs: + residual_conv.to(c.device) + c = c + residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + kernels = k.contiguous().view( + batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length, + ) + bias = b.contiguous().view( + batch, + self.conv_layers, + self.conv_out_channels, + cond_length, + ) + + return kernels, bias + + def remove_weight_norm(self): + parametrize.remove_parametrizations(self.input_conv[0], "weight") + parametrize.remove_parametrizations(self.kernel_conv, "weight") + parametrize.remove_parametrizations(self.bias_conv) + for block in self.residual_convs: + parametrize.remove_parametrizations(block[1], "weight") + parametrize.remove_parametrizations(block[3], "weight") + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + stride, + dilations=[1, 3, 9, 27], + lReLU_slope=0.2, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = len(dilations) + self.conv_kernel_size = conv_kernel_size + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=len(dilations), + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, + ) + + self.convt_pre = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.ConvTranspose1d( + in_channels, + in_channels, + 2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding=stride % 2, + ) + ), + ) + + self.conv_blocks = nn.ModuleList() + for dilation in dilations: + self.conv_blocks.append( + nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm( + nn.Conv1d( + in_channels, + in_channels, + conv_kernel_size, + padding=dilation * (conv_kernel_size - 1) // 2, + dilation=dilation, + ) + ), + nn.LeakyReLU(lReLU_slope), + ) + ) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + _, in_channels, _ = x.shape # (B, c_g, L') + + x = self.convt_pre(x) # (B, c_g, stride * L') + kernels, bias = self.kernel_predictor(c) + + for i, conv in enumerate(self.conv_blocks): + output = conv(x) # (B, c_g, stride * L') + + k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) + b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) + + output = self.location_variable_convolution( + output, k, b, hop_size=self.cond_hop_length + ) # (B, 2 * c_g, stride * L'): LVC + x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( + output[:, in_channels:, :] + ) # (B, c_g, stride * L'): GAU + + return x + + def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o.to(memory_format=torch.channels_last_3d) + bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) + o = o + bias + o = o.contiguous().view(batch, out_channels, -1) + + return o + + def remove_weight_norm(self): + self.kernel_predictor.remove_weight_norm() + parametrize.remove_parametrizations(self.convt_pre[1], "weight") + for block in self.conv_blocks: + parametrize.remove_parametrizations(block[1], "weight") + + +class UnivNetGenerator(nn.Module): + """ + UnivNet Generator + + Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py. + """ + + def __init__( + self, + noise_dim=64, + channel_size=32, + dilations=[1, 3, 9, 27], + strides=[8, 8, 4], + lReLU_slope=0.2, + kpnet_conv_size=3, + # Below are MEL configurations options that this generator requires. + hop_length=256, + n_mel_channels=100, + ): + super(UnivNetGenerator, self).__init__() + self.mel_channel = n_mel_channels + self.noise_dim = noise_dim + self.hop_length = hop_length + channel_size = channel_size + kpnet_conv_size = kpnet_conv_size + + self.res_stack = nn.ModuleList() + hop_length = 1 + for stride in strides: + hop_length = stride * hop_length + self.res_stack.append( + LVCBlock( + channel_size, + n_mel_channels, + stride=stride, + dilations=dilations, + lReLU_slope=lReLU_slope, + cond_hop_length=hop_length, + kpnet_conv_size=kpnet_conv_size, + ) + ) + + self.conv_pre = nn.utils.parametrizations.weight_norm( + nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect") + ) + + self.conv_post = nn.Sequential( + nn.LeakyReLU(lReLU_slope), + nn.utils.parametrizations.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")), + nn.Tanh(), + ) + + def forward(self, c, z): + """ + Args: + c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length) + z (Tensor): the noise sequence (batch, noise_dim, in_length) + + """ + z = self.conv_pre(z) # (B, c_g, L) + + for res_block in self.res_stack: + res_block.to(z.device) + z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i) + + z = self.conv_post(z) # (B, 1, L * 256) + + return z + + def eval(self, inference=False): + super(UnivNetGenerator, self).eval() + # don't remove weight norm while validation in training loop + if inference: + self.remove_weight_norm() + + def remove_weight_norm(self): + parametrize.remove_parametrizations(self.conv_pre, "weight") + + for layer in self.conv_post: + if len(layer.state_dict()) != 0: + parametrize.remove_parametrizations(layer, "weight") + + for res_block in self.res_stack: + res_block.remove_weight_norm() + + def inference(self, c, z=None): + # pad input mel with zeros to cut artifact + # see https://github.com/seungwonpark/melgan/issues/8 + zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) + mel = torch.cat((c, zero), dim=2) + + if z is None: + z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) + + audio = self.forward(mel, z) + audio = audio[:, :, : -(self.hop_length * 10)] + audio = audio.clamp(min=-1, max=1) + return audio + + +@dataclass +class VocType: + constructor: Callable[[], nn.Module] + model_path: str + subkey: Optional[str] = None + + def optionally_index(self, model_dict): + if self.subkey is not None: + return model_dict[self.subkey] + return model_dict + + +class VocConf(Enum): + Univnet = VocType(UnivNetGenerator, "vocoder.pth", "model_g") + + +if __name__ == "__main__": + model = UnivNetGenerator() + + c = torch.randn(3, 100, 10) + z = torch.randn(3, 64, 10) + print(c.shape) + + y = model(c, z) + print(y.shape) + assert y.shape == torch.Size([3, 1, 2560]) + + pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(pytorch_total_params) diff --git a/TTS/tts/layers/tortoise/wav2vec_alignment.py b/TTS/tts/layers/tortoise/wav2vec_alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..47456cc5ac41b7ed9522fe543affc8482218730c --- /dev/null +++ b/TTS/tts/layers/tortoise/wav2vec_alignment.py @@ -0,0 +1,150 @@ +import torch +import torchaudio +from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC + + +def max_alignment(s1, s2, skip_character="~", record=None): + """ + A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is + used to replace that character. + + Finally got to use my DP skills! + """ + if record is None: + record = {} + assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}" + if len(s1) == 0: + return "" + if len(s2) == 0: + return skip_character * len(s1) + if s1 == s2: + return s1 + if s1[0] == s2[0]: + return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record) + + take_s1_key = (len(s1), len(s2) - 1) + if take_s1_key in record: + take_s1, take_s1_score = record[take_s1_key] + else: + take_s1 = max_alignment(s1, s2[1:], skip_character, record) + take_s1_score = len(take_s1.replace(skip_character, "")) + record[take_s1_key] = (take_s1, take_s1_score) + + take_s2_key = (len(s1) - 1, len(s2)) + if take_s2_key in record: + take_s2, take_s2_score = record[take_s2_key] + else: + take_s2 = max_alignment(s1[1:], s2, skip_character, record) + take_s2_score = len(take_s2.replace(skip_character, "")) + record[take_s2_key] = (take_s2, take_s2_score) + + return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2 + + +class Wav2VecAlignment: + """ + Uses wav2vec2 to perform audio<->text alignment. + """ + + def __init__(self, device="cuda"): + self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu() + self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h") + self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("jbetker/tacotron-symbols") + self.device = device + + def align(self, audio, expected_text, audio_sample_rate=24000): + orig_len = audio.shape[-1] + + with torch.no_grad(): + self.model = self.model.to(self.device) + audio = audio.to(self.device) + audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000) + clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7) + logits = self.model(clip_norm).logits + self.model = self.model.cpu() + + logits = logits[0] + pred_string = self.tokenizer.decode(logits.argmax(-1).tolist()) + + fixed_expectation = max_alignment(expected_text.lower(), pred_string) + w2v_compression = orig_len // logits.shape[0] + expected_tokens = self.tokenizer.encode(fixed_expectation) + expected_chars = list(fixed_expectation) + if len(expected_tokens) == 1: + return [0] # The alignment is simple; there is only one token. + expected_tokens.pop(0) # The first token is a given. + expected_chars.pop(0) + + alignments = [0] + + def pop_till_you_win(): + if len(expected_tokens) == 0: + return None + popped = expected_tokens.pop(0) + popped_char = expected_chars.pop(0) + while popped_char == "~": + alignments.append(-1) + if len(expected_tokens) == 0: + return None + popped = expected_tokens.pop(0) + popped_char = expected_chars.pop(0) + return popped + + next_expected_token = pop_till_you_win() + for i, logit in enumerate(logits): + top = logit.argmax() + if next_expected_token == top: + alignments.append(i * w2v_compression) + if len(expected_tokens) > 0: + next_expected_token = pop_till_you_win() + else: + break + + pop_till_you_win() + if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)): + torch.save([audio, expected_text], "alignment_debug.pth") + assert False, ( + "Something went wrong with the alignment algorithm. I've dumped a file, 'alignment_debug.pth' to" + "your current working directory. Please report this along with the file so it can get fixed." + ) + + # Now fix up alignments. Anything with -1 should be interpolated. + alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable. + for i in range(len(alignments)): + if alignments[i] == -1: + for j in range(i + 1, len(alignments)): + if alignments[j] != -1: + next_found_token = j + break + for j in range(i, next_found_token): + gap = alignments[next_found_token] - alignments[i - 1] + alignments[j] = (j - i + 1) * gap // (next_found_token - i + 1) + alignments[i - 1] + + return alignments[:-1] + + def redact(self, audio, expected_text, audio_sample_rate=24000): + if "[" not in expected_text: + return audio + splitted = expected_text.split("[") + fully_split = [splitted[0]] + for spl in splitted[1:]: + assert "]" in spl, 'Every "[" character must be paired with a "]" with no nesting.' + fully_split.extend(spl.split("]")) + + # At this point, fully_split is a list of strings, with every other string being something that should be redacted. + non_redacted_intervals = [] + last_point = 0 + for i in range(len(fully_split)): + if i % 2 == 0: + end_interval = max(0, last_point + len(fully_split[i]) - 1) + non_redacted_intervals.append((last_point, end_interval)) + last_point += len(fully_split[i]) + + bare_text = "".join(fully_split) + alignments = self.align(audio, bare_text, audio_sample_rate) + + output_audio = [] + for nri in non_redacted_intervals: + start, stop = nri + output_audio.append(audio[:, alignments[start] : alignments[stop]]) + return torch.cat(output_audio, dim=-1) diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py new file mode 100644 index 0000000000000000000000000000000000000000..9325b8c720f0cb73fdac39ef1286170644df5de3 --- /dev/null +++ b/TTS/tts/layers/tortoise/xtransformers.py @@ -0,0 +1,1257 @@ +import math +from collections import namedtuple +from functools import partial +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) + +LayerIntermediates = namedtuple( + "Intermediates", + [ + "hiddens", + "attn_intermediates", + "past_key_values", + ], +) + + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cast_tuple(val, depth): + return val if isinstance(val, tuple) else (val,) * depth + + +class always: + def __init__(self, val): + self.val = val + + def __call__(self, *args, **kwargs): + return self.val + + +class not_equals: + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x != self.val + + +class equals: + def __init__(self, val): + self.val = val + + def __call__(self, x, *args, **kwargs): + return x == self.val + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +def l2norm(t): + return F.normalize(t, p=2, dim=-1) + + +# init helpers + + +def init_zero_(layer): + nn.init.constant_(layer.weight, 0.0) + if exists(layer.bias): + nn.init.constant_(layer.bias, 0.0) + + +# keyword argument helpers + + +def pick_and_pop(keys, d): + values = [d.pop(key) for key in keys] + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = {x[0][len(prefix) :]: x[1] for x in tuple(kwargs_with_prefix.items())} + return kwargs_without_prefix, kwargs + + +# activations + + +class ReluSquared(nn.Module): + def forward(self, x): + return F.relu(x) ** 2 + + +# positional embeddings + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim**-0.5 + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + pos_emb = self.emb(n) + pos_emb = rearrange(pos_emb, "n d -> () n d") + return pos_emb * self.scale + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return rearrange(emb, "n d -> () n d") + + +class RelativePositionBias(nn.Module): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): + super().__init__() + self.scale = scale + self.causal = causal + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + if not causal: + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + else: + n = torch.max(n, torch.zeros_like(n)) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = ( + max_exact + + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long() + ) + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, qk_dots): + i, j, device = *qk_dots.shape[-2:], qk_dots.device + q_pos = torch.arange(i, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) + rel_pos = k_pos[None, :] - q_pos[:, None] + rp_bucket = self._relative_position_bucket( + rel_pos, causal=self.causal, num_buckets=self.num_buckets, max_distance=self.max_distance + ) + values = self.relative_attention_bias(rp_bucket) + bias = rearrange(values, "i j h -> () h i j") + return qk_dots + (bias * self.scale) + + +class AlibiPositionalBias(nn.Module): + def __init__(self, heads, **kwargs): + super().__init__() + self.heads = heads + slopes = torch.Tensor(self._get_slopes(heads)) + slopes = rearrange(slopes, "h -> () h () ()") + self.register_buffer("slopes", slopes, persistent=False) + self.register_buffer("bias", None, persistent=False) + + @staticmethod + def _get_slopes(heads): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(heads).is_integer(): + return get_slopes_power_of_2(heads) + + closest_power_of_2 = 2 ** math.floor(math.log2(heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: heads - closest_power_of_2] + ) + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + if exists(self.bias) and self.bias.shape[-1] >= j: + return qk_dots + self.bias[..., :j] + + bias = torch.arange(j, device=device) + bias = rearrange(bias, "j -> () () () j") + bias = bias * self.slopes + + num_heads_unalibied = h - bias.shape[1] + bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied)) + + self.register_buffer("bias", bias, persistent=False) + return qk_dots + self.bias + + +class LearnedAlibiPositionalBias(AlibiPositionalBias): + def __init__(self, heads, bidirectional=False): + super().__init__(heads) + los_slopes = torch.log(self.slopes) + self.learned_logslopes = nn.Parameter(los_slopes) + + self.bidirectional = bidirectional + if self.bidirectional: + self.learned_logslopes_future = nn.Parameter(los_slopes) + + def forward(self, qk_dots): + h, i, j, device = *qk_dots.shape[-3:], qk_dots.device + + def get_slopes(param): + return F.pad(param.exp(), (0, 0, 0, 0, 0, h - param.shape[1])) + + if exists(self.bias) and self.bias.shape[-1] >= j: + bias = self.bias[..., :i, :j] + else: + i_arange = torch.arange(i, device=device) + j_arange = torch.arange(j, device=device) + bias = rearrange(j_arange, "j -> 1 1 1 j") - rearrange(i_arange, "i -> 1 1 i 1") + self.register_buffer("bias", bias, persistent=False) + + if self.bidirectional: + past_slopes = get_slopes(self.learned_logslopes) + future_slopes = get_slopes(self.learned_logslopes_future) + bias = torch.tril(bias * past_slopes) + torch.triu(bias * future_slopes) + else: + slopes = get_slopes(self.learned_logslopes) + bias = bias * slopes + + return qk_dots + bias + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, max_seq_len, device): + t = torch.arange(max_seq_len, device=device).type_as(self.inv_freq) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return rearrange(emb, "n d -> () () n d") + + +def rotate_half(x): + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + seq_len = t.shape[-2] + freqs = freqs[:, :, -seq_len:] + return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + + +# norms + + +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + scale_fn = lambda t: t * self.value + + if not isinstance(out, tuple): + return scale_fn(out) + + return (scale_fn(out[0]), *out[1:]) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + out = self.fn(x, **kwargs) + rezero_fn = lambda t: t * self.g + + if not isinstance(out, tuple): + return rezero_fn(out) + + return (rezero_fn(out[0]), *out[1:]) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSScaleShiftNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + self.scale_shift_process = nn.Linear(dim * 2, dim * 2) + + def forward(self, x, norm_scale_shift_inp): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + norm = x / norm.clamp(min=self.eps) * self.g + + ss_emb = self.scale_shift_process(norm_scale_shift_inp) + scale, shift = torch.chunk(ss_emb, 2, dim=1) + h = norm * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return h + + +# residual and residual gates + + +class Residual(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim, scale_residual=False): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None + + def forward(self, x, residual): + if exists(self.residual_scale): + residual = residual * self.residual_scale + + gated_output = self.gru(rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d")) + + return gated_output.reshape_as(x) + + +# token shifting + + +def shift(t, amount, mask=None): + if amount == 0: + return t + + if exists(mask): + t = t.masked_fill(~mask[..., None], 0.0) + + return F.pad(t, (0, 0, amount, -amount), value=0.0) + + +class ShiftTokens(nn.Module): + def __init__(self, shifts, fn): + super().__init__() + self.fn = fn + self.shifts = tuple(shifts) + + def forward(self, x, **kwargs): + mask = kwargs.get("mask", None) + shifts = self.shifts + segments = len(shifts) + feats_per_shift = x.shape[-1] // segments + splitted = x.split(feats_per_shift, dim=-1) + segments_to_shift, rest = splitted[:segments], splitted[segments:] + segments_to_shift = [shift(*args, mask=mask) for args in zip(segments_to_shift, shifts)] + x = torch.cat((*segments_to_shift, *rest), dim=-1) + return self.fn(x, **kwargs) + + +# feedforward + + +class GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + glu=False, + relu_squared=False, + post_act_ln=False, + dropout=0.0, + zero_init_output=False, + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + activation = ReluSquared() if relu_squared else nn.GELU() + + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), activation) if not glu else GLU(dim, inner_dim, activation) + ) + + self.net = nn.Sequential( + project_in, + nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out), + ) + + # init last linear layer to 0 + if zero_init_output: + init_zero_(self.net[-1]) + + def forward(self, x): + return self.net(x) + + +# attention. + + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + talking_heads=False, + head_scale=False, + collab_heads=False, + collab_compression=0.3, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + scale_init_value=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + ): + super().__init__() + self.scale = dim_head**-0.5 + + self.heads = heads + self.causal = causal + self.max_attend_past = max_attend_past + + qk_dim = v_dim = dim_head * heads + + # collaborative heads + self.collab_heads = collab_heads + if self.collab_heads: + qk_dim = int(collab_compression * qk_dim) + self.collab_mixing = nn.Parameter(torch.randn(heads, qk_dim)) + + self.to_q = nn.Linear(dim, qk_dim, bias=False) + self.to_k = nn.Linear(dim, qk_dim, bias=False) + self.to_v = nn.Linear(dim, v_dim, bias=False) + + self.dropout = nn.Dropout(dropout) + + # add GLU gating for aggregated values, from alphafold2 + self.to_v_gate = None + if gate_values: + self.to_v_gate = nn.Linear(dim, v_dim) + nn.init.constant_(self.to_v_gate.weight, 0) + nn.init.constant_(self.to_v_gate.bias, 1) + + # cosine sim attention + self.qk_norm = qk_norm + if qk_norm: + scale_init_value = default( + scale_init_value, -3 + ) # if not provided, initialize as though it were sequence length of 1024 + self.scale = nn.Parameter(torch.ones(1, heads, 1, 1) * scale_init_value) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # head scaling + self.head_scale = head_scale + if head_scale: + self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(v_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(v_dim, dim) + + self.rel_pos_bias = rel_pos_bias + if rel_pos_bias: + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" + self.rel_pos = RelativePositionBias( + scale=dim_head**0.5, + causal=causal, + heads=heads, + num_buckets=rel_pos_num_buckets, + max_distance=rel_pos_max_distance, + ) + + # init output projection 0 + if zero_init_output: + init_zero_(self.to_out) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + attn_mask=None, + sinusoidal_emb=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, + layer_past=None, + ): + b, n, _, h, talking_heads, collab_heads, head_scale, scale, device, has_context = ( + *x.shape, + self.heads, + self.talking_heads, + self.collab_heads, + self.head_scale, + self.scale, + x.device, + exists(context), + ) + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + if not collab_heads: + q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=h) for t in (q, k, v)) + else: + q = einsum("b i d, h d -> b h i d", q, self.collab_mixing) + k = rearrange(k, "b n d -> b () n d") + v = rearrange(v, "b n (h d) -> b h n d", h=h) + + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat([past_key, k], dim=-2) + v = torch.cat([past_value, v], dim=-2) + k_cache = k + v_cache = v + + if exists(rotary_pos_emb) and not has_context: + l = rotary_pos_emb.shape[-1] + (ql, qr), (kl, kr), (vl, vr) = ((t[..., :l], t[..., l:]) for t in (q, k, v)) + ql, kl, vl = (apply_rotary_pos_emb(t, rotary_pos_emb) for t in (ql, kl, vl)) + q, k, v = (torch.cat(t, dim=-1) for t in ((ql, qr), (kl, kr), (vl, vr))) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, "b i -> b () i ()") + k_mask = rearrange(k_mask, "b j -> b () () j") + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = (repeat(t, "h n d -> b h n d", b=b) for t in (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + if collab_heads: + k = k.expand(-1, h, -1, -1) + + if self.qk_norm: + q, k = map(l2norm, (q, k)) + scale = 1 / (self.scale.exp().clamp(min=1e-2)) + + dots = einsum("b h i d, b h j d -> b h i j", q, k) * scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots.clone() + + if talking_heads: + dots = einsum("b h i j, h k -> b k i j", dots, self.pre_softmax_proj).contiguous() + + if self.rel_pos_bias: + dots = self.rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if exists(attn_mask): + assert ( + 2 <= attn_mask.ndim <= 4 + ), "attention mask must have greater than 2 dimensions but less than or equal to 4" + if attn_mask.ndim == 2: + attn_mask = rearrange(attn_mask, "i j -> () () i j") + elif attn_mask.ndim == 3: + attn_mask = rearrange(attn_mask, "h i j -> () h i j") + dots.masked_fill_(~attn_mask, mask_value) + + if exists(self.max_attend_past): + i, j = dots.shape[-2:] + range_q = torch.arange(j - i, j, device=device) + range_k = torch.arange(j, device=device) + dist = rearrange(range_q, "i -> () () i ()") - rearrange(range_k, "j -> () () () j") + mask = dist > self.max_attend_past + dots.masked_fill_(mask, mask_value) + del mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, "i -> () () i ()") < rearrange(r, "j -> () () () j") + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn.clone() + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum("b h i j, h k -> b k i j", attn, self.post_softmax_proj).contiguous() + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + if head_scale: + out = out * self.head_scale_params + + out = rearrange(out, "b h n d -> b n (h d)") + + if exists(self.to_v_gate): + gates = self.to_v_gate(x) + out = out * gates.sigmoid() + + intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn) + + return self.to_out(out), intermediates, k_cache, v_cache + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rms_scaleshift_norm=False, + use_rmsnorm=False, + use_rezero=False, + alibi_pos_bias=False, + alibi_num_heads=None, + alibi_learned=False, + position_infused_attn=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + scale_residual=False, + shift_tokens=0, + sandwich_norm=False, + use_qk_norm_attn=False, + qk_norm_attn_seq_len=None, + zero_init_branch_output=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, _ = groupby_prefix_and_trim("attn_", kwargs) + + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + self.causal = causal + + rel_pos_bias = "rel_pos_bias" in attn_kwargs + self.has_pos_emb = position_infused_attn or rel_pos_bias or rotary_pos_emb + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + + rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) + self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim) if rotary_pos_emb else None + + assert not ( + alibi_pos_bias and rel_pos_bias + ), "you can only choose Alibi positional bias or T5 relative positional bias, not both" + + if alibi_pos_bias: + alibi_num_heads = default(alibi_num_heads, heads) + assert alibi_num_heads <= heads, "number of ALiBi heads must be less than the total number of heads" + alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias + self.rel_pos = alibi_pos_klass(heads=alibi_num_heads, bidirectional=not causal) + else: + self.rel_pos = None + + assert not (not pre_norm and sandwich_norm), "sandwich norm cannot be used when not using prenorm" + self.pre_norm = pre_norm + self.sandwich_norm = sandwich_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + self.cross_attend = cross_attend + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_class = RMSScaleShiftNorm if use_rms_scaleshift_norm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ("a", "c", "f") + elif cross_attend and only_cross: + default_block = ("c", "f") + else: + default_block = ("a", "f") + + if macaron: + default_block = ("f",) + default_block + + # qk normalization + + if use_qk_norm_attn: + attn_scale_init_value = ( + -math.log(math.log2(qk_norm_attn_seq_len**2 - qk_norm_attn_seq_len)) + if exists(qk_norm_attn_seq_len) + else None + ) + attn_kwargs = {**attn_kwargs, "qk_norm": True, "scale_init_value": attn_scale_init_value} + + # zero init + + if zero_init_branch_output: + attn_kwargs = {**attn_kwargs, "zero_init_output": True} + ff_kwargs = {**ff_kwargs, "zero_init_output": True} + + # calculate layer block order + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ("f",) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, "sandwich coefficient should be less than the depth" + layer_types = ("a",) * sandwich_coef + default_block * (depth - sandwich_coef) + ("f",) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) + + # calculate token shifting + + shift_tokens = cast_tuple(shift_tokens, len(layer_types)) + + # iterate and construct layers + + for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): + is_last_layer = ind == (len(self.layer_types) - 1) + + if layer_type == "a": + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == "c": + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == "f": + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f"invalid layer type {layer_type}") + + if layer_shift_tokens > 0: + shift_range_upper = layer_shift_tokens + 1 + shift_range_lower = -layer_shift_tokens if not causal else 0 + layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) + + if exists(branch_fn): + layer = branch_fn(layer) + + residual_fn = GRUGating if gate_residual else Residual + residual = residual_fn(dim, scale_residual=scale_residual) + + layer_uses_qk_norm = use_qk_norm_attn and layer_type in ("a", "c") + + pre_branch_norm = norm_fn() if pre_norm and not layer_uses_qk_norm else None + post_branch_norm = norm_fn() if sandwich_norm or layer_uses_qk_norm else None + post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None + + norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm]) + + self.layers.append(nn.ModuleList([norms, layer, residual])) + + def forward( + self, + x, + context=None, + full_context=None, # for passing a list of hidden states from an encoder + mask=None, + context_mask=None, + attn_mask=None, + mems=None, + return_hiddens=False, + norm_scale_shift_inp=None, + past_key_values=None, + expected_seq_len=None, + ): + assert not ( + self.cross_attend ^ (exists(context) or exists(full_context)) + ), "context must be passed in if cross_attend is set to True" + assert context is None or full_context is None, "only one of full_context or context can be provided" + + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + norm_args = {} + if exists(norm_scale_shift_inp): + norm_args["norm_scale_shift_inp"] = norm_scale_shift_inp + + rotary_pos_emb = None + if exists(self.rotary_pos_emb): + if not self.training and self.causal: + assert ( + expected_seq_len is not None + ), "To decode a transformer with rotary embeddings, you must specify an `expected_seq_len`" + elif expected_seq_len is None: + expected_seq_len = 0 + seq_len = x.shape[1] + if past_key_values is not None: + seq_len += past_key_values[0][0].shape[-2] + max_rotary_emb_length = max([(m.shape[1] if exists(m) else 0) + seq_len for m in mems] + [expected_seq_len]) + rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) + + present_key_values = [] + cross_attn_count = 0 + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + if layer_type == "a": + layer_mem = mems.pop(0) if mems else None + + residual = x + + pre_branch_norm, post_branch_norm, post_main_norm = norm + + if exists(pre_branch_norm): + x = pre_branch_norm(x, **norm_args) + + if layer_type == "a" or layer_type == "c": + if past_key_values is not None: + layer_kv = past_key_values.pop(0) + layer_past = tuple(s.to(x.device) for s in layer_kv) + else: + layer_past = None + + if layer_type == "a": + out, inter, k, v = block( + x, None, mask, None, attn_mask, self.pia_pos_emb, rotary_pos_emb, prev_attn, layer_mem, layer_past + ) + elif layer_type == "c": + if exists(full_context): + out, inter, k, v = block( + x, + full_context[cross_attn_count], + mask, + context_mask, + None, + None, + None, + prev_attn, + None, + layer_past, + ) + else: + out, inter, k, v = block( + x, context, mask, context_mask, None, None, None, prev_attn, None, layer_past + ) + elif layer_type == "f": + out = block(x) + + if layer_type == "a" or layer_type == "c" and present_key_values is not None: + present_key_values.append((k.detach(), v.detach())) + + if exists(post_branch_norm): + out = post_branch_norm(out, **norm_args) + + x = residual_fn(out, residual) + + if layer_type in ("a", "c"): + intermediates.append(inter) + + if layer_type == "a" and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == "c" and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if exists(post_main_norm): + x = post_main_norm(x, **norm_args) + + if layer_type == "c": + cross_attn_count += 1 + + if layer_type == "f": + hiddens.append(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, attn_intermediates=intermediates, past_key_values=present_key_values + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + + +class Decoder(AttentionLayers): + def __init__(self, **kwargs): + assert "causal" not in kwargs, "cannot set causality on decoder" + super().__init__(causal=True, **kwargs) + + +class CrossAttender(AttentionLayers): + def __init__(self, **kwargs): + super().__init__(cross_attend=True, only_cross=True, **kwargs) + + +class ViTransformerWrapper(nn.Module): + def __init__(self, *, image_size, patch_size, attn_layers, num_classes=None, dropout=0.0, emb_dropout=0.0): + super().__init__() + assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder" + assert image_size % patch_size == 0, "image dimensions must be divisible by the patch size" + dim = attn_layers.dim + num_patches = (image_size // patch_size) ** 2 + patch_dim = 3 * patch_size**2 + + self.patch_size = patch_size + + self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) + self.patch_to_embedding = nn.Linear(patch_dim, dim) + self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) + self.dropout = nn.Dropout(emb_dropout) + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + self.mlp_head = FeedForward(dim, dim_out=num_classes, dropout=dropout) if exists(num_classes) else None + + def forward(self, img, return_embeddings=False): + p = self.patch_size + + x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p) + x = self.patch_to_embedding(x) + b, n, _ = x.shape + + cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b) + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embedding[:, : (n + 1)] + x = self.dropout(x) + + x = self.attn_layers(x) + x = self.norm(x) + + if not exists(self.mlp_head) or return_embeddings: + return x + + return self.mlp_head(x[:, 0]) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + shift_mem_down=0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.shift_mem_down = shift_mem_down + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + def init_(self): + nn.init.kaiming_normal_(self.token_emb.weight) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_hiddens=False, + return_attn=False, + mems=None, + use_cache=False, + **kwargs, + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + if self.shift_mem_down and exists(mems): + mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :] + mems = [*mems_r, *mems_l] + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_hiddens: + hiddens = intermediates.hiddens + return out, hiddens + + res = [out] + if return_attn: + attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] + + +class ContinuousTransformerWrapper(nn.Module): + def __init__( + self, *, max_seq_len, attn_layers, dim_in=None, dim_out=None, emb_dim=None, emb_dropout=0.0, use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), "attention layers must be one of Encoder or Decoder" + + dim = attn_layers.dim + + self.max_seq_len = max_seq_len + + self.pos_emb = ( + AbsolutePositionalEmbedding(dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity() + + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity() + + def forward(self, x, return_embeddings=False, mask=None, return_attn=False, mems=None, use_cache=False, **kwargs): + b, n, _, device = *x.shape, x.device + + x = self.project_in(x) + x = x + self.pos_emb(x) + x = self.emb_dropout(x) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + out = self.project_out(x) if not return_embeddings else x + + res = [out] + if return_attn: + attn_maps = [t.post_softmax_attn for t in intermediates.attn_intermediates] + res.append(attn_maps) + if use_cache: + res.append(intermediates.past_key_values) + + if len(res) > 1: + return tuple(res) + return res[0] diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..3449739fdc46771b6e47a7f1929ebfc76eaaa52d --- /dev/null +++ b/TTS/tts/layers/vits/discriminator.py @@ -0,0 +1,89 @@ +import torch +from torch import nn +from torch.nn.modules.conv import Conv1d + +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = torch.nn.functional.leaky_relu(x, 0.1) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class VitsDiscriminator(nn.Module): + """VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator. + + :: + waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats + |--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^ + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + """ + + def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False): + super().__init__() + self.nets = nn.ModuleList() + self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm)) + self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]) + + def forward(self, x, x_hat=None): + """ + Args: + x (Tensor): ground truth waveform. + x_hat (Tensor): predicted waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + x_scores = [] + x_hat_scores = [] if x_hat is not None else None + x_feats = [] + x_hat_feats = [] if x_hat is not None else None + for net in self.nets: + x_score, x_feat = net(x) + x_scores.append(x_score) + x_feats.append(x_feat) + if x_hat is not None: + x_hat_score, x_hat_feat = net(x_hat) + x_hat_scores.append(x_hat_score) + x_hat_feats.append(x_hat_feat) + return x_scores, x_feats, x_hat_scores, x_hat_feats diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..50ed1024defa18297de6c917c64c30aad6183d6e --- /dev/null +++ b/TTS/tts/layers/vits/networks.py @@ -0,0 +1,272 @@ +import math + +import torch +from torch import nn + +from TTS.tts.layers.glow_tts.glow import WN +from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer +from TTS.tts.utils.helpers import sequence_mask + +LRELU_SLOPE = 0.1 + + +class TextEncoder(nn.Module): + def __init__( + self, + n_vocab: int, + out_channels: int, + hidden_channels: int, + hidden_channels_ffn: int, + num_heads: int, + num_layers: int, + kernel_size: int, + dropout_p: float, + language_emb_dim: int = None, + ): + """Text Encoder for VITS model. + + Args: + n_vocab (int): Number of characters for the embedding layer. + out_channels (int): Number of channels for the output. + hidden_channels (int): Number of channels for the hidden layers. + hidden_channels_ffn (int): Number of channels for the convolutional layers. + num_heads (int): Number of attention heads for the Transformer layers. + num_layers (int): Number of Transformer layers. + kernel_size (int): Kernel size for the FFN layers in Transformer network. + dropout_p (float): Dropout rate for the Transformer layers. + """ + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + + self.emb = nn.Embedding(n_vocab, hidden_channels) + + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + if language_emb_dim: + hidden_channels += language_emb_dim + + self.encoder = RelativePositionTransformer( + in_channels=hidden_channels, + out_channels=hidden_channels, + hidden_channels=hidden_channels, + hidden_channels_ffn=hidden_channels_ffn, + num_heads=num_heads, + num_layers=num_layers, + kernel_size=kernel_size, + dropout_p=dropout_p, + layer_norm_type="2", + rel_attn_window_size=4, + ) + + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, lang_emb=None): + """ + Shapes: + - x: :math:`[B, T]` + - x_length: :math:`[B]` + """ + assert x.shape[0] == x_lengths.shape[0] + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + + # concat the lang emb in embedding chars + if lang_emb is not None: + x = torch.cat((x, lang_emb.transpose(2, 1).expand(x.size(0), x.size(1), -1)), dim=-1) + + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # [b, 1, t] + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=0, + cond_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.half_channels = channels // 2 + self.mean_only = mean_only + # input layer + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + # coupling layers + self.enc = WN( + hidden_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + dropout_p=dropout_p, + c_in_channels=cond_channels, + ) + # output layer + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, log_scale = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + log_scale = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(log_scale) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(log_scale, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-log_scale) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ResidualCouplingBlocks(nn.Module): + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + num_flows=4, + cond_channels=0, + ): + """Redisual Coupling blocks for VITS flow layers. + + Args: + channels (int): Number of input and output tensor channels. + hidden_channels (int): Number of hidden network channels. + kernel_size (int): Kernel size of the WaveNet layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + num_flows (int, optional): Number of Residual Coupling blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of the conditioning tensor. Defaults to 0. + """ + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.num_flows = num_flows + self.cond_channels = cond_channels + + self.flows = nn.ModuleList() + for _ in range(num_flows): + self.flows.append( + ResidualCouplingBlock( + channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + cond_channels=cond_channels, + mean_only=True, + ) + ) + + def forward(self, x, x_mask, g=None, reverse=False): + """ + Note: + Set `reverse` to True for inference. + + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - g: :math:`[B, C, 1]` + """ + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + x = torch.flip(x, [1]) + else: + for flow in reversed(self.flows): + x = torch.flip(x, [1]) + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: int, + num_layers: int, + cond_channels=0, + ): + """Posterior Encoder of VITS model. + + :: + x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of output tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of the WaveNet convolution layers. + dilation_rate (int): Dilation rate of the WaveNet layers. + num_layers (int): Number of the WaveNet layers. + cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0. + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.num_layers = num_layers + self.cond_channels = cond_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels=cond_channels + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B, 1]` + - g: :math:`[B, C, 1]` + """ + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + mean, log_scale = torch.split(stats, self.out_channels, dim=1) + z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + return z, mean, log_scale, x_mask diff --git a/TTS/tts/layers/vits/stochastic_duration_predictor.py b/TTS/tts/layers/vits/stochastic_duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..98dbf0935ca0f6cd6e92fe6ecf063dde2ee4138f --- /dev/null +++ b/TTS/tts/layers/vits/stochastic_duration_predictor.py @@ -0,0 +1,294 @@ +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.vits.transforms import piecewise_rational_quadratic_transform + + +class DilatedDepthSeparableConv(nn.Module): + def __init__(self, channels, kernel_size, num_layers, dropout_p=0.0) -> torch.tensor: + """Dilated Depth-wise Separable Convolution module. + + :: + x |-> DDSConv(x) -> LayerNorm(x) -> GeLU(x) -> Conv1x1(x) -> LayerNorm(x) -> GeLU(x) -> + -> o + |-------------------------------------------------------------------------------------^ + + Args: + channels ([type]): [description] + kernel_size ([type]): [description] + num_layers ([type]): [description] + dropout_p (float, optional): [description]. Defaults to 0.0. + + Returns: + torch.tensor: Network output masked by the input sequence mask. + """ + super().__init__() + self.num_layers = num_layers + + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(num_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm2(channels)) + self.norms_2.append(LayerNorm2(channels)) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, x_mask, g=None): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + """ + if g is not None: + x = x + g + for i in range(self.num_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.dropout(y) + x = x + y + return x * x_mask + + +class ElementwiseAffine(nn.Module): + """Element-wise affine transform like no-population stats BatchNorm alternative. + + Args: + channels (int): Number of input tensor channels. + """ + + def __init__(self, channels): + super().__init__() + self.translation = nn.Parameter(torch.zeros(channels, 1)) + self.log_scale = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): # pylint: disable=unused-argument + if not reverse: + y = (x * torch.exp(self.log_scale) + self.translation) * x_mask + logdet = torch.sum(self.log_scale * x_mask, [1, 2]) + return y, logdet + x = (x - self.translation) * torch.exp(-self.log_scale) * x_mask + return x + + +class ConvFlow(nn.Module): + """Dilated depth separable convolutional based spline flow. + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of in network channels. + kernel_size (int): Convolutional kernel size. + num_layers (int): Number of convolutional layers. + num_bins (int, optional): Number of spline bins. Defaults to 10. + tail_bound (float, optional): Tail bound for PRQT. Defaults to 5.0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + num_layers: int, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.num_bins = num_bins + self.tail_bound = tail_bound + self.hidden_channels = hidden_channels + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers, dropout_p=0.0) + self.proj = nn.Conv1d(hidden_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.hidden_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + return x + + +class StochasticDurationPredictor(nn.Module): + """Stochastic duration predictor with Spline Flows. + + It applies Variational Dequantization and Variational Data Augmentation. + + Paper: + SDP: https://arxiv.org/pdf/2106.06103.pdf + Spline Flow: https://arxiv.org/abs/1906.04032 + + :: + ## Inference + + x -> TextCondEncoder() -> Flow() -> dr_hat + noise ----------------------^ + + ## Training + |---------------------| + x -> TextCondEncoder() -> + -> PosteriorEncoder() -> split() -> z_u, z_v -> (d - z_u) -> concat() -> Flow() -> noise + d -> DurCondEncoder() -> ^ | + |------------------------------------------------------------------------------| + + Args: + in_channels (int): Number of input tensor channels. + hidden_channels (int): Number of hidden channels. + kernel_size (int): Kernel size of convolutional layers. + dropout_p (float): Dropout rate. + num_flows (int, optional): Number of flow blocks. Defaults to 4. + cond_channels (int, optional): Number of channels of conditioning tensor. Defaults to 0. + """ + + def __init__( + self, + in_channels: int, + hidden_channels: int, + kernel_size: int, + dropout_p: float, + num_flows=4, + cond_channels=0, + language_emb_dim=0, + ): + super().__init__() + + # add language embedding dim in the input + if language_emb_dim: + in_channels += language_emb_dim + + # condition encoder text + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # posterior encoder + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + self.flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + # condition encoder duration + self.post_pre = nn.Conv1d(1, hidden_channels, 1) + self.post_convs = DilatedDepthSeparableConv(hidden_channels, kernel_size, num_layers=3, dropout_p=dropout_p) + self.post_proj = nn.Conv1d(hidden_channels, hidden_channels, 1) + + # flow layers + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + self.post_flows += [ConvFlow(2, hidden_channels, kernel_size, num_layers=3) for _ in range(num_flows)] + + if cond_channels != 0 and cond_channels is not None: + self.cond = nn.Conv1d(cond_channels, hidden_channels, 1) + + if language_emb_dim != 0 and language_emb_dim is not None: + self.cond_lang = nn.Conv1d(language_emb_dim, hidden_channels, 1) + + def forward(self, x, x_mask, dr=None, g=None, lang_emb=None, reverse=False, noise_scale=1.0): + """ + Shapes: + - x: :math:`[B, C, T]` + - x_mask: :math:`[B, 1, T]` + - dr: :math:`[B, 1, T]` + - g: :math:`[B, C]` + """ + # condition encoder text + x = self.pre(x) + if g is not None: + x = x + self.cond(g) + + if lang_emb is not None: + x = x + self.cond_lang(lang_emb) + + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert dr is not None + + # condition encoder duration + h = self.post_pre(dr) + h = self.post_convs(h, x_mask) + h = self.post_proj(h) * x_mask + noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = noise + + # posterior encoder + logdet_tot_q = 0.0 + for idx, flow in enumerate(self.post_flows): + z_q, logdet_q = flow(z_q, x_mask, g=(x + h)) + logdet_tot_q = logdet_tot_q + logdet_q + if idx > 0: + z_q = torch.flip(z_q, [1]) + + z_u, z_v = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (dr - u) * x_mask + + # posterior encoder - neg log likelihood + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]) + nll_posterior_encoder = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q + ) + + z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask + logdet_tot = torch.sum(-z0, [1, 2]) + z = torch.cat([z0, z_v], 1) + + # flow layers + for idx, flow in enumerate(flows): + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + if idx > 0: + z = torch.flip(z, [1]) + + # flow layers - neg log likelihood + nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot + return nll_flow_layers + nll_posterior_encoder + + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = torch.flip(z, [1]) + z = flow(z, x_mask, g=x, reverse=reverse) + + z0, _ = torch.split(z, [1, 1], 1) + logw = z0 + return logw diff --git a/TTS/tts/layers/vits/transforms.py b/TTS/tts/layers/vits/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3cac1b8d6d12fe98123ca554899978782cf3b4c5 --- /dev/null +++ b/TTS/tts/layers/vits/transforms.py @@ -0,0 +1,202 @@ +# adopted from https://github.com/bayesiains/nflows + +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/TTS/tts/layers/xtts/__init__.py b/TTS/tts/layers/xtts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py new file mode 100644 index 0000000000000000000000000000000000000000..73970fb0bfeb7cda7c2fd1acafeb5c684132beae --- /dev/null +++ b/TTS/tts/layers/xtts/dvae.py @@ -0,0 +1,398 @@ +import functools +import logging +from math import sqrt + +import torch +import torch.distributed as distributed +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from einops import rearrange + +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + + +def default(val, d): + return val if val is not None else d + + +def eval_decorator(fn): + def inner(model, *args, **kwargs): + was_training = model.training + model.eval() + out = fn(model, *args, **kwargs) + model.train(was_training) + return out + + return inner + + +def dvae_wav_to_mel( + wav, mel_norms_file="../experiments/clips_mel_norms.pth", mel_norms=None, device=torch.device("cpu") +): + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=1024, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4()) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +class Quantize(nn.Module): + def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, balancing_heuristic=False, new_return_order=False): + super().__init__() + + self.dim = dim + self.n_embed = n_embed + self.decay = decay + self.eps = eps + + self.balancing_heuristic = balancing_heuristic + self.codes = None + self.max_codes = 64000 + self.codes_full = False + self.new_return_order = new_return_order + + embed = torch.randn(dim, n_embed) + self.register_buffer("embed", embed) + self.register_buffer("cluster_size", torch.zeros(n_embed)) + self.register_buffer("embed_avg", embed.clone()) + + def forward(self, input, return_soft_codes=False): + if self.balancing_heuristic and self.codes_full: + h = torch.histc(self.codes, bins=self.n_embed, min=0, max=self.n_embed) / len(self.codes) + mask = torch.logical_or(h > 0.9, h < 0.01).unsqueeze(1) + ep = self.embed.permute(1, 0) + ea = self.embed_avg.permute(1, 0) + rand_embed = torch.randn_like(ep) * mask + self.embed = (ep * ~mask + rand_embed).permute(1, 0) + self.embed_avg = (ea * ~mask + rand_embed).permute(1, 0) + self.cluster_size = self.cluster_size * ~mask.squeeze() + if torch.any(mask): + logger.info("Reset %d embedding codes.", torch.sum(mask)) + self.codes = None + self.codes_full = False + + flatten = input.reshape(-1, self.dim) + dist = flatten.pow(2).sum(1, keepdim=True) - 2 * flatten @ self.embed + self.embed.pow(2).sum(0, keepdim=True) + soft_codes = -dist + _, embed_ind = soft_codes.max(1) + embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype) + embed_ind = embed_ind.view(*input.shape[:-1]) + quantize = self.embed_code(embed_ind) + + if self.balancing_heuristic: + if self.codes is None: + self.codes = embed_ind.flatten() + else: + self.codes = torch.cat([self.codes, embed_ind.flatten()]) + if len(self.codes) > self.max_codes: + self.codes = self.codes[-self.max_codes :] + self.codes_full = True + + if self.training: + embed_onehot_sum = embed_onehot.sum(0) + embed_sum = flatten.transpose(0, 1) @ embed_onehot + + if distributed.is_initialized() and distributed.get_world_size() > 1: + distributed.all_reduce(embed_onehot_sum) + distributed.all_reduce(embed_sum) + + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay) + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n + embed_normalized = self.embed_avg / cluster_size.unsqueeze(0) + self.embed.data.copy_(embed_normalized) + + diff = (quantize.detach() - input).pow(2).mean() + quantize = input + (quantize - input).detach() + + if return_soft_codes: + return quantize, diff, embed_ind, soft_codes.view(input.shape[:-1] + (-1,)) + elif self.new_return_order: + return quantize, embed_ind, diff + else: + return quantize, diff, embed_ind + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.embed.transpose(0, 1)) + + +# Fits a soft-discretized input to a normal-PDF across the specified dimension. +# In other words, attempts to force the discretization function to have a mean equal utilization across all discrete +# values with the specified expected variance. +class DiscretizationLoss(nn.Module): + def __init__(self, discrete_bins, dim, expected_variance, store_past=0): + super().__init__() + self.discrete_bins = discrete_bins + self.dim = dim + self.dist = torch.distributions.Normal(0, scale=expected_variance) + if store_past > 0: + self.record_past = True + self.register_buffer("accumulator_index", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator_filled", torch.zeros(1, dtype=torch.long, device="cpu")) + self.register_buffer("accumulator", torch.zeros(store_past, discrete_bins)) + else: + self.record_past = False + + def forward(self, x): + other_dims = set(range(len(x.shape))) - set([self.dim]) + averaged = x.sum(dim=tuple(other_dims)) / x.sum() + averaged = averaged - averaged.mean() + + if self.record_past: + acc_count = self.accumulator.shape[0] + avg = averaged.detach().clone() + if self.accumulator_filled > 0: + averaged = torch.mean(self.accumulator, dim=0) * (acc_count - 1) / acc_count + averaged / acc_count + + # Also push averaged into the accumulator. + self.accumulator[self.accumulator_index] = avg + self.accumulator_index += 1 + if self.accumulator_index >= acc_count: + self.accumulator_index *= 0 + if self.accumulator_filled <= 0: + self.accumulator_filled += 1 + + return torch.sum(-self.dist.log_prob(averaged)) + + +class ResBlock(nn.Module): + def __init__(self, chan, conv, activation): + super().__init__() + self.net = nn.Sequential( + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 3, padding=1), + activation(), + conv(chan, chan, 1), + ) + + def forward(self, x): + return self.net(x) + x + + +class UpsampledConv(nn.Module): + def __init__(self, conv, *args, **kwargs): + super().__init__() + assert "stride" in kwargs.keys() + self.stride = kwargs["stride"] + del kwargs["stride"] + self.conv = conv(*args, **kwargs) + + def forward(self, x): + up = nn.functional.interpolate(x, scale_factor=self.stride, mode="nearest") + return self.conv(up) + + +# DiscreteVAE partially derived from lucidrains DALLE implementation +# Credit: https://github.com/lucidrains/DALLE-pytorch +class DiscreteVAE(nn.Module): + def __init__( + self, + positional_dims=2, + num_tokens=512, + codebook_dim=512, + num_layers=3, + num_resnet_blocks=0, + hidden_dim=64, + channels=3, + stride=2, + kernel_size=4, + use_transposed_convs=True, + encoder_norm=False, + activation="relu", + smooth_l1_loss=False, + straight_through=False, + normalization=None, # ((0.5,) * 3, (0.5,) * 3), + record_codes=False, + discretization_loss_averaging_steps=100, + lr_quantizer_args={}, + ): + super().__init__() + has_resblocks = num_resnet_blocks > 0 + + self.num_tokens = num_tokens + self.num_layers = num_layers + self.straight_through = straight_through + self.positional_dims = positional_dims + self.discrete_loss = DiscretizationLoss( + num_tokens, 2, 1 / (num_tokens * 2), discretization_loss_averaging_steps + ) + + assert positional_dims > 0 and positional_dims < 3 # This VAE only supports 1d and 2d inputs for now. + if positional_dims == 2: + conv = nn.Conv2d + conv_transpose = nn.ConvTranspose2d + else: + conv = nn.Conv1d + conv_transpose = nn.ConvTranspose1d + if not use_transposed_convs: + conv_transpose = functools.partial(UpsampledConv, conv) + + if activation == "relu": + act = nn.ReLU + elif activation == "silu": + act = nn.SiLU + else: + assert NotImplementedError() + + enc_layers = [] + dec_layers = [] + + if num_layers > 0: + enc_chans = [hidden_dim * 2**i for i in range(num_layers)] + dec_chans = list(reversed(enc_chans)) + + enc_chans = [channels, *enc_chans] + + dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0] + dec_chans = [dec_init_chan, *dec_chans] + + enc_chans_io, dec_chans_io = (list(zip(t[:-1], t[1:])) for t in (enc_chans, dec_chans)) + + pad = (kernel_size - 1) // 2 + for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride=stride, padding=pad), act())) + if encoder_norm: + enc_layers.append(nn.GroupNorm(8, enc_out)) + dec_layers.append( + nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride=stride, padding=pad), act()) + ) + dec_out_chans = dec_chans[-1] + innermost_dim = dec_chans[0] + else: + enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act())) + dec_out_chans = hidden_dim + innermost_dim = hidden_dim + + for _ in range(num_resnet_blocks): + dec_layers.insert(0, ResBlock(innermost_dim, conv, act)) + enc_layers.append(ResBlock(innermost_dim, conv, act)) + + if num_resnet_blocks > 0: + dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) + + enc_layers.append(conv(innermost_dim, codebook_dim, 1)) + dec_layers.append(conv(dec_out_chans, channels, 1)) + + self.encoder = nn.Sequential(*enc_layers) + self.decoder = nn.Sequential(*dec_layers) + + self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss + self.codebook = Quantize(codebook_dim, num_tokens, new_return_order=True) + + # take care of normalization within class + self.normalization = normalization + self.record_codes = record_codes + if record_codes: + self.codes = torch.zeros((1228800,), dtype=torch.long) + self.code_ind = 0 + self.total_codes = 0 + self.internal_step = 0 + + def norm(self, images): + if not self.normalization is not None: + return images + + means, stds = (torch.as_tensor(t).to(images) for t in self.normalization) + arrange = "c -> () c () ()" if self.positional_dims == 2 else "c -> () c ()" + means, stds = (rearrange(t, arrange) for t in (means, stds)) + images = images.clone() + images.sub_(means).div_(stds) + return images + + def get_debug_values(self, step, __): + if self.record_codes and self.total_codes > 0: + # Report annealing schedule + return {"histogram_codes": self.codes[: self.total_codes]} + else: + return {} + + @torch.no_grad() + @eval_decorator + def get_codebook_indices(self, images): + img = self.norm(images) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, _ = self.codebook(logits) + self.log_codes(codes) + return codes + + def decode(self, img_seq): + self.log_codes(img_seq) + if hasattr(self.codebook, "embed_code"): + image_embeds = self.codebook.embed_code(img_seq) + else: + image_embeds = F.embedding(img_seq, self.codebook.codebook) + b, n, d = image_embeds.shape + + kwargs = {} + if self.positional_dims == 1: + arrange = "b n d -> b d n" + else: + h = w = int(sqrt(n)) + arrange = "b (h w) d -> b d h w" + kwargs = {"h": h, "w": w} + image_embeds = rearrange(image_embeds, arrange, **kwargs) + images = [image_embeds] + for layer in self.decoder: + images.append(layer(images[-1])) + return images[-1], images[-2] + + def infer(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + return self.decode(codes) + + # Note: This module is not meant to be run in forward() except while training. It has special logic which performs + # evaluation using quantized values when it detects that it is being run in eval() mode, which will be substantially + # more lossy (but useful for determining network performance). + def forward(self, img): + img = self.norm(img) + logits = self.encoder(img).permute((0, 2, 3, 1) if len(img.shape) == 4 else (0, 2, 1)) + sampled, codes, commitment_loss = self.codebook(logits) + sampled = sampled.permute((0, 3, 1, 2) if len(img.shape) == 4 else (0, 2, 1)) + + if self.training: + out = sampled + for d in self.decoder: + out = d(out) + self.log_codes(codes) + else: + # This is non-differentiable, but gives a better idea of how the network is actually performing. + out, _ = self.decode(codes) + + # reconstruction loss + recon_loss = self.loss_fn(img, out, reduction="none") + + return recon_loss, commitment_loss, out + + def log_codes(self, codes): + # This is so we can debug the distribution of codes being learned. + if self.record_codes and self.internal_step % 10 == 0: + codes = codes.flatten() + l = codes.shape[0] + i = self.code_ind if (self.codes.shape[0] - self.code_ind) > l else self.codes.shape[0] - l + self.codes[i : i + l] = codes.cpu() + self.code_ind = self.code_ind + l + if self.code_ind >= self.codes.shape[0]: + self.code_ind = 0 + self.total_codes += 1 + self.internal_step += 1 diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c3b31b47625b9f60c1a18cb4756bf32790638e --- /dev/null +++ b/TTS/tts/layers/xtts/gpt.py @@ -0,0 +1,614 @@ +# ported from: https://github.com/neonbjb/tortoise-tts + +import functools +import random + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import GPT2Config + +from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation +from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel +from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder +from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler + + +def null_position_embeddings(range, dim): + return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) + + +class LearnedPositionEmbeddings(nn.Module): + def __init__(self, seq_len, model_dim, init=0.02, relative=False): + super().__init__() + # nn.Embedding + self.emb = torch.nn.Embedding(seq_len, model_dim) + # Initializing this way is standard for GPT-2 + self.emb.weight.data.normal_(mean=0.0, std=init) + self.relative = relative + self.seq_len = seq_len + + def forward(self, x): + sl = x.shape[1] + if self.relative: + start = random.randint(sl, self.seq_len) - sl + return self.emb(torch.arange(start, start + sl, device=x.device)) + else: + return self.emb(torch.arange(0, sl, device=x.device)) + + def get_fixed_embedding(self, ind, dev): + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) + + +def build_hf_gpt_transformer( + layers, + model_dim, + heads, + max_mel_seq_len, + max_text_seq_len, + max_prompt_len, + checkpointing, +): + """ + GPT-2 implemented by the HuggingFace library. + """ + from transformers import GPT2Config, GPT2Model + + gpt_config = GPT2Config( + vocab_size=256, # Unused. + n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_embd=model_dim, + n_layer=layers, + n_head=heads, + gradient_checkpointing=checkpointing, + use_cache=not checkpointing, + ) + gpt = GPT2Model(gpt_config) + # Override the built in positional embeddings + del gpt.wpe + gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) + # Built-in token embeddings are unused. + del gpt.wte + + mel_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) + return gpt, mel_pos_emb, text_pos_emb, None, None + + +class GPT(nn.Module): + def __init__( + self, + start_text_token=261, + stop_text_token=0, + layers=8, + model_dim=512, + heads=8, + max_text_tokens=120, + max_mel_tokens=250, + max_prompt_tokens=70, + max_conditioning_inputs=1, + code_stride_len=1024, + number_text_tokens=256, + num_audio_tokens=8194, + start_audio_token=8192, + stop_audio_token=8193, + train_solo_embeddings=False, + checkpointing=False, + average_conditioning_embeddings=False, + label_smoothing=0.0, + use_perceiver_resampler=False, + perceiver_cond_length_compression=256, + ): + """ + Args: + + """ + super().__init__() + + self.label_smoothing = label_smoothing + self.number_text_tokens = number_text_tokens + self.start_text_token = start_text_token + self.stop_text_token = stop_text_token + self.num_audio_tokens = num_audio_tokens + self.start_audio_token = start_audio_token + self.stop_audio_token = stop_audio_token + self.start_prompt_token = start_audio_token + self.stop_prompt_token = stop_audio_token + self.layers = layers + self.heads = heads + self.model_dim = model_dim + self.max_conditioning_inputs = max_conditioning_inputs + self.max_gen_mel_tokens = max_mel_tokens - self.max_conditioning_inputs - 2 + self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs + self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2 + self.max_prompt_tokens = max_prompt_tokens + self.code_stride_len = code_stride_len + self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads) + self.conditioning_dropout = nn.Dropout1d(0.1) + self.average_conditioning_embeddings = average_conditioning_embeddings + self.use_perceiver_resampler = use_perceiver_resampler + self.perceiver_cond_length_compression = perceiver_cond_length_compression + + self.text_embedding = nn.Embedding(self.number_text_tokens, model_dim) + self.mel_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + + ( + self.gpt, + self.mel_pos_embedding, + self.text_pos_embedding, + self.mel_layer_pos_embedding, + self.text_layer_pos_embedding, + ) = build_hf_gpt_transformer( + layers, + model_dim, + heads, + self.max_mel_tokens, + self.max_text_tokens, + self.max_prompt_tokens, + checkpointing, + ) + if train_solo_embeddings: + self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) + else: + self.mel_solo_embedding = 0 + self.text_solo_embedding = 0 + + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.number_text_tokens) + self.mel_head = nn.Linear(model_dim, self.num_audio_tokens) + + if self.use_perceiver_resampler: + # XTTS v2 + self.conditioning_perceiver = PerceiverResampler( + dim=model_dim, + depth=2, + dim_context=model_dim, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ) + else: + # XTTS v1 + self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim) + self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim) + + def get_grad_norm_parameter_groups(self): + return { + "conditioning_encoder": list(self.conditioning_encoder.parameters()), + "conditioning_perceiver": ( + list(self.conditioning_perceiver.parameters()) if self.use_perceiver_resampler else None + ), + "gpt": list(self.gpt.parameters()), + "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()), + } + + def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False): + seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1 + gpt_config = GPT2Config( + vocab_size=self.max_mel_tokens, + n_positions=seq_length, + n_ctx=seq_length, + n_embd=self.model_dim, + n_layer=self.layers, + n_head=self.heads, + gradient_checkpointing=False, + use_cache=True, + ) + self.gpt_inference = GPT2InferenceModel( + gpt_config, + self.gpt, + self.mel_pos_embedding, + self.mel_embedding, + self.final_norm, + self.mel_head, + kv_cache=kv_cache, + ) + self.gpt.wte = self.mel_embedding + + if use_deepspeed: + import deepspeed + + self.ds_engine = deepspeed.init_inference( + model=self.gpt_inference.half(), # Transformers models + mp_size=1, # Number of GPU + dtype=torch.float32, # desired data type of output + replace_method="auto", # Lets DS autmatically identify the layer to replace + replace_with_kernel_inject=True, # replace the model with the kernel injector + ) + self.gpt_inference = self.ds_engine.module.eval() + + def set_inputs_and_targets(self, input, start_token, stop_token): + inp = F.pad(input, (1, 0), value=start_token) + tar = F.pad(input, (0, 1), value=stop_token) + return inp, tar + + def set_mel_padding(self, mel_input_tokens, code_lengths): + """ + Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in + that audio clip, reformats the tokens with stop_audio_token in place of the zero padding. This is required + preformatting to create a working TTS model. + """ + # Set padding areas within MEL (currently it is coded with the MEL code for ). + for b in range(len(code_lengths)): + actual_end = code_lengths[b] + if actual_end < mel_input_tokens.shape[-1]: + mel_input_tokens[b, actual_end:] = self.stop_audio_token + return mel_input_tokens + + def get_logits( + self, + first_inputs, + first_head, + second_inputs=None, + second_head=None, + prompt=None, + get_attns=False, + return_latent=False, + attn_mask_cond=None, + attn_mask_text=None, + attn_mask_mel=None, + ): + if prompt is not None: + offset = prompt.shape[1] + if second_inputs is not None: + emb = torch.cat([prompt, first_inputs, second_inputs], dim=1) + else: + emb = torch.cat([prompt, first_inputs], dim=1) + + # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + attn_mask = None + if attn_mask_text is not None: + attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1) + if prompt is not None: + attn_mask_cond = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device) + attn_mask = torch.cat([attn_mask_cond, attn_mask], dim=1) + + gpt_out = self.gpt( + inputs_embeds=emb, + return_dict=True, + output_attentions=get_attns, + attention_mask=attn_mask, + ) + + if get_attns: + return gpt_out.attentions + + enc = gpt_out.last_hidden_state[:, offset:] + enc = self.final_norm(enc) + + if return_latent: + return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :] + + first_logits = enc[:, : first_inputs.shape[1]] + first_logits = first_head(first_logits) + first_logits = first_logits.permute(0, 2, 1) + if second_inputs is not None: + second_logits = enc[:, -second_inputs.shape[1] :] + second_logits = second_head(second_logits) + second_logits = second_logits.permute(0, 2, 1) + return first_logits, second_logits + else: + return first_logits + + def get_conditioning(self, speech_conditioning_input): + speech_conditioning_input = ( + speech_conditioning_input.unsqueeze(1) + if len(speech_conditioning_input.shape) == 3 + else speech_conditioning_input + ) + conds = [] + for j in range(speech_conditioning_input.shape[1]): + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds = torch.stack(conds, dim=1) + conds = conds.mean(dim=1) + return conds + + def get_prompts(self, prompt_codes): + """ + Create a prompt from the mel codes. This is used to condition the model on the mel codes. + Pad the prompt with start and stop mel tokens. + """ + prompt = prompt_codes + if self.training: + lengths = [] + # Compute the real prompt length based on the first encounter with the token 83 used for padding + for i in range(prompt_codes.shape[0]): + length = 0 + for j in range(prompt_codes.shape[1]): + if prompt_codes[i, j] == 83: + break + else: + length += 1 + lengths.append(length) + + # prompt_len = random.randint(1, 9) # in secs + prompt_len = 3 + prompt_len = prompt_len * 24 # in frames + if prompt_codes.shape[-1] >= prompt_len: + for i in range(prompt_codes.shape[0]): + if lengths[i] < prompt_len: + start = 0 + else: + start = random.randint(0, lengths[i] - prompt_len) + prompt = prompt_codes[:, start : start + prompt_len] + + # add start and stop tokens + prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token) + prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token) + return prompt + + def get_style_emb(self, cond_input, return_latent=False): + """ + cond_input: (b, 80, s) or (b, 1, 80, s) + conds: (b, 1024, s) + """ + conds = None + if not return_latent: + if cond_input.ndim == 4: + cond_input = cond_input.squeeze(1) + conds = self.conditioning_encoder(cond_input) # (b, d, s) + if self.use_perceiver_resampler: + conds = self.conditioning_perceiver(conds.permute(0, 2, 1)).transpose(1, 2) # (b, d, 32) + else: + # already computed + conds = cond_input.unsqueeze(1) + return conds + + def forward( + self, + text_inputs, + text_lengths, + audio_codes, + wav_lengths, + cond_mels=None, + cond_idxs=None, + cond_lens=None, + cond_latents=None, + return_attentions=False, + return_latent=False, + ): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + cond_mels: MEL float tensor, (b, 1, 80,s) + cond_idxs: cond start and end indexs, (b, 2) + + If return_attentions is specified, only logits are returned. + If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned. + """ + # ❗ FIXIT + if self.max_conditioning_inputs == 0: + assert cond_mels is None, " ❗ cond_mels is not None, but max_conditioning_inputs == 0" + + max_text_len = text_lengths.max() + code_lengths = torch.ceil(wav_lengths / self.code_stride_len).long() + 3 + + if cond_lens is not None: + if self.use_perceiver_resampler: + cond_lens = cond_lens // self.perceiver_cond_length_compression + else: + cond_lens = cond_lens // self.code_stride_len + + if cond_idxs is not None: + # recompute cond idxs for mel lengths + for idx in range(cond_idxs.size(0)): + if self.use_perceiver_resampler: + cond_idxs[idx] = cond_idxs[idx] // self.perceiver_cond_length_compression + else: + cond_idxs[idx] = cond_idxs[idx] // self.code_stride_len + + # ensure that the cond_mel does not have padding + # if cond_lens is not None and cond_idxs is None: + # min_cond_len = torch.min(cond_lens) + # cond_mels = cond_mels[:, :, :, :min_cond_len] + + # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes. + max_mel_len = code_lengths.max() + + if max_mel_len > audio_codes.shape[-1]: + audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1])) + + # 💖 Lovely assertions + assert ( + max_mel_len <= audio_codes.shape[-1] + ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})" + assert ( + max_text_len <= text_inputs.shape[-1] + ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})" + + # Append stop token to text inputs + text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token) + + # Append silence token to mel codes + audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token) + + # Pad mel codes with stop_audio_token + audio_codes = self.set_mel_padding( + audio_codes, code_lengths - 3 + ) # -3 to get the real code lengths without consider start and stop tokens that was not added yet + + # Build input and target tensors + # Prepend start token to inputs and append stop token to targets + text_inputs, text_targets = self.set_inputs_and_targets( + text_inputs, self.start_text_token, self.stop_text_token + ) + audio_codes, mel_targets = self.set_inputs_and_targets( + audio_codes, self.start_audio_token, self.stop_audio_token + ) + + # Set attn_mask + attn_mask_cond = None + attn_mask_text = None + attn_mask_mel = None + if not return_latent: + attn_mask_cond = torch.ones( + cond_mels.shape[0], + cond_mels.shape[-1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_text = torch.ones( + text_inputs.shape[0], + text_inputs.shape[1], + dtype=torch.bool, + device=text_inputs.device, + ) + attn_mask_mel = torch.ones( + audio_codes.shape[0], + audio_codes.shape[1], + dtype=torch.bool, + device=audio_codes.device, + ) + + if cond_idxs is not None: + # use masking approach + for idx, r in enumerate(cond_idxs): + l = r[1] - r[0] + attn_mask_cond[idx, l:] = 0.0 + elif cond_lens is not None: + for idx, l in enumerate(cond_lens): + attn_mask_cond[idx, l:] = 0.0 + + for idx, l in enumerate(text_lengths): + attn_mask_text[idx, l + 1 :] = 0.0 + + for idx, l in enumerate(code_lengths): + attn_mask_mel[idx, l + 1 :] = 0.0 + + # Compute text embeddings + positional embeddings + text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + + # Compute mel embeddings + positional embeddings + mel_emb = self.mel_embedding(audio_codes) + self.mel_pos_embedding(audio_codes) + + # Compute speech conditioning input + if cond_latents is None: + cond_latents = self.get_style_emb(cond_mels).transpose(1, 2) + + # Get logits + sub = -5 # don't ask me why 😄 + if self.training: + sub = -1 + + text_logits, mel_logits = self.get_logits( + text_emb, + self.text_head, + mel_emb, + self.mel_head, + prompt=cond_latents, + get_attns=return_attentions, + return_latent=return_latent, + attn_mask_cond=attn_mask_cond, + attn_mask_text=attn_mask_text, + attn_mask_mel=attn_mask_mel, + ) + if return_latent: + return mel_logits[:, :sub] # sub to prevent bla. + + if return_attentions: + return mel_logits + + # Set paddings to -1 to ignore them in loss + for idx, l in enumerate(text_lengths): + text_targets[idx, l + 1 :] = -1 + + for idx, l in enumerate(code_lengths): + mel_targets[idx, l + 1 :] = -1 + + # check if stoptoken is in every row of mel_targets + assert (mel_targets == self.stop_audio_token).sum() >= mel_targets.shape[ + 0 + ], f" ❗ mel_targets does not contain stop token ({self.stop_audio_token}) in every row." + + # ignore the loss for the segment used for conditioning + # coin flip for the segment to be ignored + if cond_idxs is not None: + cond_start = cond_idxs[idx, 0] + cond_end = cond_idxs[idx, 1] + mel_targets[idx, cond_start:cond_end] = -1 + + # Compute losses + loss_text = F.cross_entropy( + text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + loss_mel = F.cross_entropy( + mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing + ) + return loss_text.mean(), loss_mel.mean(), mel_logits + + def inference(self, cond_latents, text_inputs, **hf_generate_kwargs): + self.compute_embeddings(cond_latents, text_inputs) + return self.generate(cond_latents, text_inputs, **hf_generate_kwargs) + + def compute_embeddings( + self, + cond_latents, + text_inputs, + ): + text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token) + text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token) + emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + emb = torch.cat([cond_latents, emb], dim=1) + self.gpt_inference.store_prefix_emb(emb) + gpt_inputs = torch.full( + ( + emb.shape[0], + emb.shape[1] + 1, # +1 for the start_audio_token + ), + fill_value=1, + dtype=torch.long, + device=text_inputs.device, + ) + gpt_inputs[:, -1] = self.start_audio_token + return gpt_inputs + + def generate( + self, + cond_latents, + text_inputs, + **hf_generate_kwargs, + ): + gpt_inputs = self.compute_embeddings(cond_latents, text_inputs) + stop_token_tensor = torch.tensor(self.stop_audio_token, device=gpt_inputs.device, dtype=torch.long) + attention_mask = _prepare_attention_mask_for_generation(gpt_inputs, stop_token_tensor, stop_token_tensor) + gen = self.gpt_inference.generate( + gpt_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_gen_mel_tokens + gpt_inputs.shape[-1], + attention_mask=attention_mask, + **hf_generate_kwargs, + ) + if "return_dict_in_generate" in hf_generate_kwargs: + return gen.sequences[:, gpt_inputs.shape[1] :], gen + return gen[:, gpt_inputs.shape[1] :] + + def get_generator(self, fake_inputs, **hf_generate_kwargs): + return self.gpt_inference.generate_stream( + fake_inputs, + bos_token_id=self.start_audio_token, + pad_token_id=self.stop_audio_token, + eos_token_id=self.stop_audio_token, + max_length=self.max_gen_mel_tokens + fake_inputs.shape[-1], + do_stream=True, + **hf_generate_kwargs, + ) diff --git a/TTS/tts/layers/xtts/gpt_inference.py b/TTS/tts/layers/xtts/gpt_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..e94683524a88ae691f7a1f372189e254a9f6f58c --- /dev/null +++ b/TTS/tts/layers/xtts/gpt_inference.py @@ -0,0 +1,137 @@ +import torch +from torch import nn +from transformers import GenerationMixin, GPT2PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + +from TTS.tts.layers.xtts.stream_generator import StreamGenerationConfig + + +class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin): + """Override GPT2LMHeadModel to allow for prefix conditioning.""" + + def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache): + super().__init__(config) + self.transformer = gpt + self.pos_embedding = pos_emb + self.embeddings = embeddings + self.final_norm = norm + self.lm_head = nn.Sequential(norm, linear) + self.kv_cache = kv_cache + self.generation_config = StreamGenerationConfig.from_model_config(config) if self.can_generate() else None + + def store_prefix_emb(self, prefix_emb): + self.cached_prefix_emb = prefix_emb + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) # usually None + if not self.kv_cache: + past_key_values = None + + # only last token for inputs_ids if past is defined in kwargs + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + assert self.cached_prefix_emb is not None + assert inputs_embeds is None # Not supported by this inference model. + assert labels is None # Training not supported by this inference model. + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1] + + # Create embedding + prefix_len = self.cached_prefix_emb.shape[1] + if input_ids.shape[1] != 1: + gen_inputs = input_ids[:, prefix_len:] + gen_emb = self.embeddings(gen_inputs) + gen_emb = gen_emb + self.pos_embedding(gen_emb) + if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]: + prefix_emb = self.cached_prefix_emb.repeat_interleave( + gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0 + ) + else: + prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype) + emb = torch.cat([prefix_emb, gen_emb], dim=1) + else: + emb = self.embeddings(input_ids) + emb = emb + self.pos_embedding.get_fixed_embedding( + attention_mask.shape[1] - (prefix_len + 1), attention_mask.device + ) + transformer_outputs = self.transformer( + inputs_embeds=emb, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + return (lm_logits,) + transformer_outputs[1:] + + return CausalLMOutputWithCrossAttentions( + loss=None, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past, beam_idx): + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5ef0030b8bb651a1a3ae5290548ac5cf8632d0fb --- /dev/null +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -0,0 +1,734 @@ +import logging + +import torch +import torchaudio +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec + +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 +from TTS.vocoder.models.hifigan_generator import get_padding + +logger = logging.getLogger(__name__) + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_parametrizations(l, "weight") + for l in self.convs2: + remove_parametrizations(l, "weight") + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_parametrizations(l, "weight") + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + cond_in_each_up_layer=False, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + self.cond_in_each_up_layer = cond_in_each_up_layer + + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_parametrizations(self.conv_pre, "weight") + + if not conv_post_weight_norm: + remove_parametrizations(self.conv_post, "weight") + + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(cond_channels, ch, 1)) + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + + if self.cond_in_each_up_layer: + o = o + self.conds[i](g) + + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + logger.info("Removing weight norm...") + for l in self.ups: + remove_parametrizations(l, "weight") + for l in self.resblocks: + l.remove_weight_norm() + remove_parametrizations(self.conv_pre, "weight") + remove_parametrizations(self.conv_post, "weight") + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + logger.warning("Layer missing in the model definition: %s", k) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) + return model_dict + + +class PreEmphasis(nn.Module): + def __init__(self, coefficient=0.97): + super().__init__() + self.coefficient = coefficient + self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) + + def forward(self, x): + assert len(x.size()) == 2 + + x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") + return torch.nn.functional.conv1d(x, self.filter).squeeze(1) + + +class ResNetSpeakerEncoder(nn.Module): + """This is copied from 🐸TTS to remove it from the dependencies.""" + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + use_torch_spec=False, + audio_config=None, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.use_torch_spec = use_torch_spec + self.audio_config = audio_config + self.proj_dim = proj_dim + + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + if self.use_torch_spec: + self.torch_spec = torch.nn.Sequential( + PreEmphasis(audio_config["preemphasis"]), + torchaudio.transforms.MelSpectrogram( + sample_rate=audio_config["sample_rate"], + n_fft=audio_config["fft_size"], + win_length=audio_config["win_length"], + hop_length=audio_config["hop_length"], + window_fn=torch.hamming_window, + n_mels=audio_config["num_mels"], + ), + ) + + else: + self.torch_spec = None + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + """Forward pass of the model. + + Args: + x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` + to compute the spectrogram on-the-fly. + l2_norm (bool): Whether to L2-normalize the outputs. + + Shapes: + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` + """ + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) + + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + def load_checkpoint( + self, + checkpoint_path: str, + eval: bool = False, + use_cuda: bool = False, + criterion=None, + cache=False, + ): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + try: + self.load_state_dict(state["model"]) + logger.info("Model fully restored.") + except (KeyError, RuntimeError) as error: + # If eval raise the error + if eval: + raise error + + logger.info("Partial model initialization.") + model_dict = self.state_dict() + model_dict = set_init_dict(model_dict, state["model"]) + self.load_state_dict(model_dict) + del model_dict + + # load the criterion for restore_path + if criterion is not None and "criterion" in state: + try: + criterion.load_state_dict(state["criterion"]) + except (KeyError, RuntimeError) as error: + logger.exception("Criterion load ignored because of: %s", error) + + if use_cuda: + self.cuda() + if criterion is not None: + criterion = criterion.cuda() + + if eval: + self.eval() + assert not self.training + + if not eval: + return criterion, state["step"] + return criterion + + +class HifiDecoder(torch.nn.Module): + def __init__( + self, + input_sample_rate=22050, + output_sample_rate=24000, + output_hop_length=256, + ar_mel_length_compression=1024, + decoder_input_dim=1024, + resblock_type_decoder="1", + resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + resblock_kernel_sizes_decoder=[3, 7, 11], + upsample_rates_decoder=[8, 8, 2, 2], + upsample_initial_channel_decoder=512, + upsample_kernel_sizes_decoder=[16, 16, 4, 4], + d_vector_dim=512, + cond_d_vector_in_each_upsampling_layer=True, + speaker_encoder_audio_config={ + "fft_size": 512, + "win_length": 400, + "hop_length": 160, + "sample_rate": 16000, + "preemphasis": 0.97, + "num_mels": 64, + }, + ): + super().__init__() + self.input_sample_rate = input_sample_rate + self.output_sample_rate = output_sample_rate + self.output_hop_length = output_hop_length + self.ar_mel_length_compression = ar_mel_length_compression + self.speaker_encoder_audio_config = speaker_encoder_audio_config + self.waveform_decoder = HifiganGenerator( + decoder_input_dim, + 1, + resblock_type_decoder, + resblock_dilation_sizes_decoder, + resblock_kernel_sizes_decoder, + upsample_kernel_sizes_decoder, + upsample_initial_channel_decoder, + upsample_rates_decoder, + inference_padding=0, + cond_channels=d_vector_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer, + ) + self.speaker_encoder = ResNetSpeakerEncoder( + input_dim=64, + proj_dim=512, + log_input=True, + use_torch_spec=True, + audio_config=speaker_encoder_audio_config, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, latents, g=None): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + + z = torch.nn.functional.interpolate( + latents.transpose(1, 2), + scale_factor=[self.ar_mel_length_compression / self.output_hop_length], + mode="linear", + ).squeeze(1) + # upsample to the right sr + if self.output_sample_rate != self.input_sample_rate: + z = torch.nn.functional.interpolate( + z, + scale_factor=[self.output_sample_rate / self.input_sample_rate], + mode="linear", + ).squeeze(0) + o = self.waveform_decoder(z, g=g) + return o + + @torch.no_grad() + def inference(self, c, g): + """ + Args: + x (Tensor): feature input tensor (GPT latent). + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + return self.forward(c, g=g) + + def load_checkpoint(self, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + # remove unused keys + state = state["model"] + states_keys = list(state.keys()) + for key in states_keys: + if "waveform_decoder." not in key and "speaker_encoder." not in key: + del state[key] + + self.load_state_dict(state) + if eval: + self.eval() + assert not self.training + self.waveform_decoder.remove_weight_norm() diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d62a36f1529ddd1e9e6fdd92afc5c9f224f827 --- /dev/null +++ b/TTS/tts/layers/xtts/latent_encoder.py @@ -0,0 +1,141 @@ +# ported from: Originally ported from: https://github.com/neonbjb/tortoise-tts + +import math + +import torch +from torch import nn +from torch.nn import functional as F + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def normalization(channels): + groups = 32 + if channels <= 16: + groups = 8 + elif channels <= 64: + groups = 16 + while channels % groups != 0: + groups = int(groups / 2) + assert groups > 2 + return GroupNorm32(groups, channels) + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +class QKVAttention(nn.Module): + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, mask=None, qk_bias=0): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = weight + qk_bias + if mask is not None: + mask = mask.repeat(self.n_heads, 1, 1) + weight[mask.logical_not()] = -torch.inf + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + + return a.reshape(bs, -1, length) + + +class AttentionBlock(nn.Module): + """An attention block that allows spatial positions to attend to each other.""" + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + out_channels=None, + do_activation=False, + ): + super().__init__() + self.channels = channels + out_channels = channels if out_channels is None else out_channels + self.do_activation = do_activation + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, out_channels * 3, 1) + self.attention = QKVAttention(self.num_heads) + + self.x_proj = nn.Identity() if out_channels == channels else conv_nd(1, channels, out_channels, 1) + self.proj_out = zero_module(conv_nd(1, out_channels, out_channels, 1)) + + def forward(self, x, mask=None, qk_bias=0): + b, c, *spatial = x.shape + if mask is not None: + if len(mask.shape) == 2: + mask = mask.unsqueeze(0).repeat(x.shape[0], 1, 1) + if mask.shape[1] != x.shape[-1]: + mask = mask[:, : x.shape[-1], : x.shape[-1]] + + x = x.reshape(b, c, -1) + x = self.norm(x) + if self.do_activation: + x = F.silu(x, inplace=True) + qkv = self.qkv(x) + h = self.attention(qkv, mask=mask, qk_bias=qk_bias) + h = self.proj_out(h) + xp = self.x_proj(x) + return (xp + h).reshape(b, xp.shape[1], *spatial) + + +class ConditioningEncoder(nn.Module): + def __init__( + self, + spec_dim, + embedding_dim, + attn_blocks=6, + num_attn_heads=4, + ): + super().__init__() + attn = [] + self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) + for a in range(attn_blocks): + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) + self.attn = nn.Sequential(*attn) + self.dim = embedding_dim + + def forward(self, x): + """ + x: (b, 80, s) + """ + h = self.init(x) + h = self.attn(h) + return h diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b6e841235592ac20d289e1e4c08dbd20a3daa2 --- /dev/null +++ b/TTS/tts/layers/xtts/perceiver_encoder.py @@ -0,0 +1,311 @@ +# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532 + +from collections import namedtuple +from functools import wraps + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from torch import einsum, nn + + +def exists(val): + return val is not None + + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +print_once = once(print) + +# main class + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, causal=False, use_flash=False): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.causal = causal + self.register_buffer("mask", None, persistent=False) + + self.use_flash = use_flash + + # determine efficient attention configs for cuda and cpu + self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]) + self.cpu_config = self.config(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once("A100 GPU detected, using flash attention if input tensor is on cuda") + self.cuda_config = self.config(True, False, False) + else: + print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda") + self.cuda_config = self.config(False, True, True) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def flash_attn(self, q, k, v, mask=None): + _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if k.ndim == 3: + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) + + if v.ndim == 3: + v = rearrange(v, "b ... -> b 1 ...").expand_as(q) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=self.causal + ) + + return out + + def forward(self, q, k, v, mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device = q.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + if self.use_flash: + return self.flash_attn(q, k, v, mask=mask) + + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" + + # similarity + + sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) + + return out + + +def Sequential(*mods): + return nn.Sequential(*filter(exists, mods)) + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class RMSNorm(nn.Module): + def __init__(self, dim, scale=True, dim_cond=None): + super().__init__() + self.cond = exists(dim_cond) + self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None + + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) if scale else None + + def forward(self, x, cond=None): + gamma = default(self.gamma, 1) + out = F.normalize(x, dim=-1) * self.scale * gamma + + if not self.cond: + return out + + assert exists(cond) + gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1) + gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta)) + return out * gamma + beta + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (kernel_size,) = self.kernel_size + (dilation,) = self.dilation + (stride,) = self.stride + + assert stride == 1 + self.causal_padding = dilation * (kernel_size - 1) + + def forward(self, x): + causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) + return super().forward(causal_padded_x) + + +class GEGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def FeedForward(dim, mult=4, causal_conv=False): + dim_inner = int(dim * mult * 2 / 3) + + conv = None + if causal_conv: + conv = nn.Sequential( + Rearrange("b n d -> b d n"), + CausalConv1d(dim_inner, dim_inner, 3), + Rearrange("b d n -> b n d"), + ) + + return Sequential(nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim)) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=2, + dim_context=None, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ): + super().__init__() + dim_context = default(dim_context, dim) + + self.proj_context = nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() + + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + nn.init.normal_(self.latents, std=0.02) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + use_flash=use_flash_attn, + cross_attn_include_queries=True, + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = RMSNorm(dim) + + def forward(self, x, mask=None): + batch = x.shape[0] + + x = self.proj_context(x) + + latents = repeat(self.latents, "n d -> b n d", b=batch) + + for attn, ff in self.layers: + latents = attn(latents, x, mask=mask) + latents + latents = ff(latents) + latents + + return self.norm(latents) + + +class Attention(nn.Module): + def __init__( + self, + dim, + *, + dim_context=None, + causal=False, + dim_head=64, + heads=8, + dropout=0.0, + use_flash=False, + cross_attn_include_queries=False, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + self.cross_attn_include_queries = cross_attn_include_queries + + dim_inner = dim_head * heads + dim_context = default(dim_context, dim) + + self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) + self.to_q = nn.Linear(dim, dim_inner, bias=False) + self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) + self.to_out = nn.Linear(dim_inner, dim, bias=False) + + def forward(self, x, context=None, mask=None): + h, has_context = self.heads, exists(context) + + context = default(context, x) + + if has_context and self.cross_attn_include_queries: + context = torch.cat((x, context), dim=-2) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = self.attend(q, k, v, mask=mask) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..44cf940c6901c7f404b26bd02dc68c0dbeb0e4af --- /dev/null +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -0,0 +1,966 @@ +# Adapted from: https://github.com/LowinLi/transformers-stream-generator + +import copy +import inspect +import random +import warnings +from typing import Callable, Optional, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch import nn +from transformers import ( + BeamSearchScorer, + ConstrainedBeamSearchScorer, + DisjunctiveConstraint, + GenerationConfig, + GenerationMixin, + LogitsProcessorList, + PhrasalConstraint, + PreTrainedModel, + StoppingCriteriaList, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) +from transformers.generation.utils import GenerateOutput, SampleOutput, logger + + +def setup_seed(seed: int) -> None: + if seed == -1: + return + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +class StreamGenerationConfig(GenerationConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.do_stream = kwargs.pop("do_stream", False) + + +class NewGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( # noqa: PLR0911 + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[StreamGenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None, + synced_gpus: Optional[bool] = False, + seed: int = 0, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + r""" + + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the + method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` + should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of + `input_ids`, `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchDecoderOnlyOutput`], + - [`~generation.SampleDecoderOnlyOutput`], + - [`~generation.BeamSearchDecoderOnlyOutput`], + - [`~generation.BeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.GreedySearchEncoderDecoderOutput`], + - [`~generation.SampleEncoderDecoderOutput`], + - [`~generation.BeamSearchEncoderDecoderOutput`], + - [`~generation.BeamSampleEncoderDecoderOutput`] + """ + # setup_seed(seed) + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + + # priority: `generation_config` argument > `model.generation_config` (the default generation config) + if generation_config is None: + # legacy: users may modify the model configuration to control generation -- update the generation config + # model attribute accordingly, if it was created from the model config + if self.generation_config._from_model_config: + new_generation_config = StreamGenerationConfig.from_model_config(self.config) + if new_generation_config != self.generation_config: + warnings.warn( + "You have modified the pretrained model configuration to control generation. This is a" + " deprecated strategy to control generation and will be removed soon, in a future version." + " Please use a generation configuration file (see" + " https://huggingface.co/docs/transformers/main_classes/text_generation)" + ) + self.generation_config = new_generation_config + generation_config = self.generation_config + + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + # self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + if generation_config.pad_token_id is None and generation_config.eos_token_id is not None: + if model_kwargs.get("attention_mask", None) is None: + logger.warning( + "The attention mask and the pad token id were not set. As a consequence, you may observe " + "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results." + ) + eos_token_id = generation_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") + generation_config.pad_token_id = eos_token_id + + # 3. Define model inputs + # inputs_tensor has to be defined + # model_input_name is defined if model-specific keyword input is passed + # otherwise model_input_name is None + # all model-specific keyword inputs are removed from `model_kwargs` + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + batch_size = inputs_tensor.shape[0] + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + model_kwargs["use_cache"] = generation_config.use_cache + model_kwargs["cache_position"] = torch.Tensor([0]).to(inputs_tensor.device) + + accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) + requires_attention_mask = "encoder_outputs" not in model_kwargs + + if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask: + setattr( + generation_config, + "_pad_token_tensor", + torch.full( + (inputs_tensor.shape[0], inputs_tensor.shape[1]), + generation_config.pad_token_id, + device=inputs_tensor.device, + ), + ) + setattr( + generation_config, + "_eos_token_tensor", + torch.full( + (inputs_tensor.shape[0], inputs_tensor.shape[1]), + generation_config.eos_token_id, + device=inputs_tensor.device, + ), + ) + model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation( + inputs_tensor, + generation_config._pad_token_tensor, + generation_config._eos_token_tensor, + ) + + # decoder-only models should use left-padding for generation + if not self.config.is_encoder_decoder: + if ( + generation_config.pad_token_id is not None + and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0 + ): + logger.warning( + "A decoder-only architecture is being used, but right-padding was detected! For correct " + "generation results, please set `padding_side='left'` when initializing the tokenizer." + ) + + if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: + # if model is encoder decoder encoder_outputs are created + # and added to `model_kwargs` + model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, model_kwargs, model_input_name + ) + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if self.config.is_encoder_decoder: + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + device=inputs_tensor.device, + ) + else: + # if decoder-only then inputs_tensor has to be `input_ids` + input_ids = inputs_tensor + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_seq_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" + " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif has_default_max_length and generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + elif not has_default_max_length and generation_config.max_new_tokens is not None: + raise ValueError( + "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a" + " limit to the generated output length. Remove one of those arguments. Please refer to the" + " documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)" + ) + + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + raise ValueError( + f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than" + f" the maximum length ({generation_config.max_length})" + ) + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 7. determine generation mode + is_constraint_gen_mode = ( + generation_config.constraints is not None or generation_config.force_words_ids is not None + ) + + is_contrastive_search_gen_mode = ( + generation_config.top_k is not None + and generation_config.top_k > 1 + and generation_config.do_sample is False + and generation_config.penalty_alpha is not None + and generation_config.penalty_alpha > 0 + ) + + is_greedy_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and generation_config.do_stream is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_sample_gen_stream_mode = ( + (generation_config.num_beams == 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_stream is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is False + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_beam_sample_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups == 1) + and generation_config.do_sample is True + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + is_group_beam_gen_mode = ( + (generation_config.num_beams > 1) + and (generation_config.num_beam_groups > 1) + and not is_constraint_gen_mode + and not is_contrastive_search_gen_mode + ) + + if generation_config.num_beam_groups > generation_config.num_beams: + raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") + if is_group_beam_gen_mode and generation_config.do_sample is True: + raise ValueError( + "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`." + ) + + if self.device.type != input_ids.device.type: + warnings.warn( + "You are calling .generate() with the `input_ids` being on a device type different" + f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" + f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." + " Please make sure that you have put `input_ids` to the" + f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" + " running `.generate()`.", + UserWarning, + ) + # 8. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + # 9. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + # 10. go into different generation modes + if is_greedy_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " greedy search." + ) + + # 11. run greedy search + return self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_contrastive_search_gen_mode: + if generation_config.num_return_sequences > 1: + raise ValueError( + f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing" + " contrastive search." + ) + + return self.contrastive_search( + input_ids, + top_k=generation_config.top_k, + penalty_alpha=generation_config.penalty_alpha, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_sample_gen_mode: + # 11. prepare logits warper + logits_warper = _get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_sample_gen_stream_mode: + # 11. prepare logits warper + logits_warper = _get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample_stream( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + elif is_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_beam_sample_gen_mode: + # 11. prepare logits warper + logits_warper = _get_logits_warper(generation_config) + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + # 12. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size * generation_config.num_return_sequences, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + ) + + # 13. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams * generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 14. run beam sample + return self.beam_sample( + input_ids, + beam_scorer, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_group_beam_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if generation_config.num_beams % generation_config.num_beam_groups != 0: + raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + has_default_typical_p = kwargs.get("typical_p") is None and generation_config.typical_p == 1.0 + if not has_default_typical_p: + raise ValueError("Decoder argument `typical_p` is not supported with beam groups.") + + # 11. prepare beam search scorer + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=generation_config.num_beams, + max_length=stopping_criteria.max_length, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + num_beam_groups=generation_config.num_beam_groups, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.group_beam_search( + input_ids, + beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + elif is_constraint_gen_mode: + if generation_config.num_return_sequences > generation_config.num_beams: + raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.") + + if stopping_criteria.max_length is None: + raise ValueError("`max_length` needs to be a stopping_criteria for now.") + + if generation_config.num_beams <= 1: + raise ValueError("`num_beams` needs to be greater than 1 for constrained generation.") + + if generation_config.do_sample: + raise ValueError("`do_sample` needs to be false for constrained generation.") + + if generation_config.num_beam_groups is not None and generation_config.num_beam_groups > 1: + raise ValueError("`num_beam_groups` not supported yet for constrained generation.") + + final_constraints = [] + if generation_config.constraints is not None: + final_constraints = generation_config.constraints + + if generation_config.force_words_ids is not None: + + def typeerror(): + raise ValueError( + "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]`" + f"of positive integers, but is {generation_config.force_words_ids}." + ) + + if ( + not isinstance(generation_config.force_words_ids, list) + or len(generation_config.force_words_ids) == 0 + ): + typeerror() + + for word_ids in generation_config.force_words_ids: + if isinstance(word_ids[0], list): + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any(not isinstance(token_ids, list) for token_ids in word_ids): + typeerror() + if any( + any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) + for token_ids in word_ids + ): + typeerror() + + constraint = DisjunctiveConstraint(word_ids) + else: + if not isinstance(word_ids, list) or len(word_ids) == 0: + typeerror() + if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids): + typeerror() + + constraint = PhrasalConstraint(word_ids) + final_constraints.append(constraint) + + # 11. prepare beam search scorer + constrained_beam_scorer = ConstrainedBeamSearchScorer( + constraints=final_constraints, + batch_size=batch_size, + num_beams=generation_config.num_beams, + device=inputs_tensor.device, + length_penalty=generation_config.length_penalty, + do_early_stopping=generation_config.early_stopping, + num_beam_hyps_to_keep=generation_config.num_return_sequences, + ) + # 12. interleave input_ids with `num_beams` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_beams, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + # 13. run beam search + return self.constrained_beam_search( + input_ids, + constrained_beam_scorer=constrained_beam_scorer, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + **model_kwargs, + ) + + @torch.no_grad() + def sample_stream( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, list[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: Optional[bool] = False, + **model_kwargs, + ) -> Union[SampleOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. + For an overview of generation strategies and code examples, check the [following + guide](./generation_strategies). + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + logits_warper (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used + to warp the prediction score distribution of the language modeling head applied before multinomial + sampling at each generation step. + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`int`, *optional*): + The id of the *end-of-sequence* token. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import ( + ... AutoTokenizer, + ... AutoModelForCausalLM, + ... LogitsProcessorList, + ... MinLengthLogitsProcessor, + ... TopKLogitsWarper, + ... TemperatureLogitsWarper, + ... StoppingCriteriaList, + ... MaxLengthCriteria, + ... ) + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + + >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token + >>> model.config.pad_token_id = model.config.eos_token_id + >>> model.generation_config.pad_token_id = model.config.eos_token_id + + >>> input_prompt = "Today is a beautiful day, and" + >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids + + >>> # instantiate logits processors + >>> logits_processor = LogitsProcessorList( + ... [ + ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), + ... ] + ... ) + >>> # instantiate logits processors + >>> logits_warper = LogitsProcessorList( + ... [ + ... TopKLogitsWarper(50), + ... TemperatureLogitsWarper(0.7), + ... ] + ... ) + + >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + + >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT + >>> outputs = model.sample( + ... input_ids, + ... logits_processor=logits_processor, + ... logits_warper=logits_warper, + ... stopping_criteria=stopping_criteria, + ... ) + + >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) + ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the'] + ```""" + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) + ) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1]) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + if not synced_gpus: + break + else: + this_peer_finished = True + + +def init_stream_support(): + """Overload PreTrainedModel for streaming.""" + PreTrainedModel.generate_stream = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + + +if __name__ == "__main__": + from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + + PreTrainedModel.generate = NewGenerationMixin.generate + PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream + model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16) + + tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + model = model.to("cuda:0") + model = model.eval() + prompt_text = "hello? \n" + input_ids = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).input_ids + input_ids = input_ids.to("cuda:0") + + with torch.no_grad(): + result = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + ) + print(tokenizer.decode(result, skip_special_tokens=True)) + generator = model.generate( + input_ids, + max_new_tokens=200, + do_sample=True, + top_k=30, + top_p=0.85, + temperature=0.35, + repetition_penalty=1.2, + early_stopping=True, + seed=0, + do_stream=True, + ) + stream_result = "" + for x in generator: + chunk = tokenizer.decode(x, skip_special_tokens=True) + stream_result += chunk + print(stream_result) + + +def _get_logits_warper(generation_config: GenerationConfig) -> LogitsProcessorList: + + warpers = LogitsProcessorList() + + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1)) + + return warpers diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4937f416f95ff9048950b6950c41958a83c0c0a --- /dev/null +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -0,0 +1,926 @@ +import logging +import os +import re +import textwrap +from functools import cached_property + +import torch +from num2words import num2words +from spacy.lang.ar import Arabic +from spacy.lang.en import English +from spacy.lang.es import Spanish +from spacy.lang.hi import Hindi +from spacy.lang.ja import Japanese +from spacy.lang.zh import Chinese +from tokenizers import Tokenizer + +from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words + +logger = logging.getLogger(__name__) + + +def get_spacy_lang(lang): + """Return Spacy language used for sentence splitting.""" + if lang == "zh": + return Chinese() + elif lang == "ja": + return Japanese() + elif lang == "ar": + return Arabic() + elif lang == "es": + return Spanish() + elif lang == "hi": + return Hindi() + else: + # For most languages, English does the job + return English() + + +def split_sentence(text, lang, text_split_length=250): + """Preprocess the input text""" + text_splits = [] + if text_split_length is not None and len(text) >= text_split_length: + text_splits.append("") + nlp = get_spacy_lang(lang) + nlp.add_pipe("sentencizer") + doc = nlp(text) + for sentence in doc.sents: + if len(text_splits[-1]) + len(str(sentence)) <= text_split_length: + # if the last sentence + the current sentence is less than the text_split_length + # then add the current sentence to the last sentence + text_splits[-1] += " " + str(sentence) + text_splits[-1] = text_splits[-1].lstrip() + elif len(str(sentence)) > text_split_length: + # if the current sentence is greater than the text_split_length + for line in textwrap.wrap( + str(sentence), + width=text_split_length, + drop_whitespace=True, + break_on_hyphens=False, + tabsize=1, + ): + text_splits.append(str(line)) + else: + text_splits.append(str(sentence)) + + if len(text_splits) > 1: + if text_splits[0] == "": + del text_splits[0] + else: + text_splits = [text.lstrip()] + + return text_splits + + +_whitespace_re = re.compile(r"\s+") + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = { + "en": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] + ], + "es": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("sra", "señora"), + ("sr", "señor"), + ("dr", "doctor"), + ("dra", "doctora"), + ("st", "santo"), + ("co", "compañía"), + ("jr", "junior"), + ("ltd", "limitada"), + ] + ], + "fr": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mme", "madame"), + ("mr", "monsieur"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "compagnie"), + ("jr", "junior"), + ("ltd", "limitée"), + ] + ], + "de": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("fr", "frau"), + ("dr", "doktor"), + ("st", "sankt"), + ("co", "firma"), + ("jr", "junior"), + ] + ], + "pt": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("sra", "senhora"), + ("sr", "senhor"), + ("dr", "doutor"), + ("dra", "doutora"), + ("st", "santo"), + ("co", "companhia"), + ("jr", "júnior"), + ("ltd", "limitada"), + ] + ], + "it": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # ("sig.ra", "signora"), + ("sig", "signore"), + ("dr", "dottore"), + ("st", "santo"), + ("co", "compagnia"), + ("jr", "junior"), + ("ltd", "limitata"), + ] + ], + "pl": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("p", "pani"), + ("m", "pan"), + ("dr", "doktor"), + ("sw", "święty"), + ("jr", "junior"), + ] + ], + "ar": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # There are not many common abbreviations in Arabic as in English. + ] + ], + "zh": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts. + ] + ], + "cs": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dr", "doktor"), # doctor + ("ing", "inženýr"), # engineer + ("p", "pan"), # Could also map to pani for woman but no easy way to do it + # Other abbreviations would be specialized and not as common. + ] + ], + "ru": [ + (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1]) + for x in [ + ("г-жа", "госпожа"), # Mrs. + ("г-н", "господин"), # Mr. + ("д-р", "доктор"), # doctor + # Other abbreviations are less common or specialized. + ] + ], + "nl": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dhr", "de heer"), # Mr. + ("mevr", "mevrouw"), # Mrs. + ("dr", "dokter"), # doctor + ("jhr", "jonkheer"), # young lord or nobleman + # Dutch uses more abbreviations, but these are the most common ones. + ] + ], + "tr": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("b", "bay"), # Mr. + ("byk", "büyük"), # büyük + ("dr", "doktor"), # doctor + # Add other Turkish abbreviations here if needed. + ] + ], + "hu": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("dr", "doktor"), # doctor + ("b", "bácsi"), # Mr. + ("nőv", "nővér"), # nurse + # Add other Hungarian abbreviations here if needed. + ] + ], + "ko": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Korean doesn't typically use abbreviations in the same way as Latin-based scripts. + ] + ], + "hi": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + # Hindi doesn't typically use abbreviations in the same way as Latin-based scripts. + ] + ], + "vi": [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("ông", "ông"), # Mr. + ("bà", "bà"), # Mrs. + ("dr", "bác sĩ"), # doctor + ("ts", "tiến sĩ"), # PhD + ("st", "số thứ tự"), # ordinal + ] + ], +} + + +def expand_abbreviations_multilingual(text, lang="en"): + for regex, replacement in _abbreviations[lang]: + text = re.sub(regex, replacement, text) + return text + + +_symbols_multilingual = { + "en": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " and "), + ("@", " at "), + ("%", " percent "), + ("#", " hash "), + ("$", " dollar "), + ("£", " pound "), + ("°", " degree "), + ] + ], + "es": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " y "), + ("@", " arroba "), + ("%", " por ciento "), + ("#", " numeral "), + ("$", " dolar "), + ("£", " libra "), + ("°", " grados "), + ] + ], + "fr": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " et "), + ("@", " arobase "), + ("%", " pour cent "), + ("#", " dièse "), + ("$", " dollar "), + ("£", " livre "), + ("°", " degrés "), + ] + ], + "de": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " und "), + ("@", " at "), + ("%", " prozent "), + ("#", " raute "), + ("$", " dollar "), + ("£", " pfund "), + ("°", " grad "), + ] + ], + "pt": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " e "), + ("@", " arroba "), + ("%", " por cento "), + ("#", " cardinal "), + ("$", " dólar "), + ("£", " libra "), + ("°", " graus "), + ] + ], + "it": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " e "), + ("@", " chiocciola "), + ("%", " per cento "), + ("#", " cancelletto "), + ("$", " dollaro "), + ("£", " sterlina "), + ("°", " gradi "), + ] + ], + "pl": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " i "), + ("@", " małpa "), + ("%", " procent "), + ("#", " krzyżyk "), + ("$", " dolar "), + ("£", " funt "), + ("°", " stopnie "), + ] + ], + "ar": [ + # Arabic + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " و "), + ("@", " على "), + ("%", " في المئة "), + ("#", " رقم "), + ("$", " دولار "), + ("£", " جنيه "), + ("°", " درجة "), + ] + ], + "zh": [ + # Chinese + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " 和 "), + ("@", " 在 "), + ("%", " 百分之 "), + ("#", " 号 "), + ("$", " 美元 "), + ("£", " 英镑 "), + ("°", " 度 "), + ] + ], + "cs": [ + # Czech + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " a "), + ("@", " na "), + ("%", " procento "), + ("#", " křížek "), + ("$", " dolar "), + ("£", " libra "), + ("°", " stupně "), + ] + ], + "ru": [ + # Russian + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " и "), + ("@", " собака "), + ("%", " процентов "), + ("#", " номер "), + ("$", " доллар "), + ("£", " фунт "), + ("°", " градус "), + ] + ], + "nl": [ + # Dutch + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " en "), + ("@", " bij "), + ("%", " procent "), + ("#", " hekje "), + ("$", " dollar "), + ("£", " pond "), + ("°", " graden "), + ] + ], + "tr": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " ve "), + ("@", " at "), + ("%", " yüzde "), + ("#", " diyez "), + ("$", " dolar "), + ("£", " sterlin "), + ("°", " derece "), + ] + ], + "hu": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " és "), + ("@", " kukac "), + ("%", " százalék "), + ("#", " kettőskereszt "), + ("$", " dollár "), + ("£", " font "), + ("°", " fok "), + ] + ], + "ko": [ + # Korean + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " 그리고 "), + ("@", " 에 "), + ("%", " 퍼센트 "), + ("#", " 번호 "), + ("$", " 달러 "), + ("£", " 파운드 "), + ("°", " 도 "), + ] + ], + "hi": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " और "), + ("@", " ऐट दी रेट "), + ("%", " प्रतिशत "), + ("#", " हैश "), + ("$", " डॉलर "), + ("£", " पाउंड "), + ("°", " डिग्री "), + ] + ], + "vi": [ + (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) + for x in [ + ("&", " và "), # and + ("@", " a còng "), # at + ("%", " phần trăm "), # percent + ("#", " dấu thăng "), # hash + ("$", " đô la "), # dollar + ("£", " bảng Anh "), # pound + ("°", " độ "), # degree + ] + ], + +} + + +def expand_symbols_multilingual(text, lang="en"): + for regex, replacement in _symbols_multilingual[lang]: + text = re.sub(regex, replacement, text) + text = text.replace(" ", " ") # Ensure there are no double spaces + return text.strip() + + +_ordinal_re = { + "en": re.compile(r"([0-9]+)(st|nd|rd|th)"), + "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"), + "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"), + "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"), + "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"), + "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"), + "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"), + "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"), + "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals. + "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"), + "nl": re.compile(r"([0-9]+)(de|ste|e)"), + "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"), + "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"), + "ko": re.compile(r"([0-9]+)(번째|번|차|째)"), + "hi": re.compile(r"([0-9]+)(st|nd|rd|th)"), # To check + "vi": re.compile(r"([0-9]+)(th|thứ)?"), # Matches "1", "thứ 1", "2", "thứ 2" +} + +_number_re = re.compile(r"[0-9]+") +_currency_re = { + "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"), + "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"), + "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"), +} + +_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b") +_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b") +_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)") + + +def _remove_commas(m): + text = m.group(0) + if "," in text: + text = text.replace(",", "") + return text + + +def _remove_dots(m): + text = m.group(0) + if "." in text: + text = text.replace(".", "") + return text + + +def _expand_decimal_point(m, lang="en"): + amount = m.group(1).replace(",", ".") + return num2words(float(amount), lang=lang if lang != "cs" else "cz") + + +def _expand_currency(m, lang="en", currency="USD"): + amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", ".")))) + full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz") + + and_equivalents = { + "en": ", ", + "es": " con ", + "fr": " et ", + "de": " und ", + "pt": " e ", + "it": " e ", + "pl": ", ", + "cs": ", ", + "ru": ", ", + "nl": ", ", + "ar": ", ", + "tr": ", ", + "hu": ", ", + "ko": ", ", + "hi": ", ", + "vi": " và ", + + } + + if amount.is_integer(): + last_and = full_amount.rfind(and_equivalents[lang]) + if last_and != -1: + full_amount = full_amount[:last_and] + + return full_amount + + +def _expand_ordinal(m, lang="en"): + return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz") + + +def _expand_number(m, lang="en"): + return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz") + + +def expand_numbers_multilingual(text, lang="en"): + if lang == "zh": + text = zh_num2words()(text) + else: + if lang in ["en", "ru"]: + text = re.sub(_comma_number_re, _remove_commas, text) + else: + text = re.sub(_dot_number_re, _remove_dots, text) + try: + text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text) + text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text) + text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text) + except: + pass + if lang != "tr": + text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text) + text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text) + text = re.sub(_number_re, lambda m: _expand_number(m, lang), text) + return text + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def multilingual_cleaners(text, lang): + text = text.replace('"', "") + if lang == "tr": + text = text.replace("İ", "i") + text = text.replace("Ö", "ö") + text = text.replace("Ü", "ü") + text = lowercase(text) + text = expand_numbers_multilingual(text, lang) + text = expand_abbreviations_multilingual(text, lang) + text = expand_symbols_multilingual(text, lang=lang) + text = collapse_whitespace(text) + return text + + +def basic_cleaners(text): + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def chinese_transliterate(text): + try: + import pypinyin + except ImportError as e: + raise ImportError("Chinese requires: pypinyin") from e + return "".join( + [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)] + ) + + +def japanese_cleaners(text, katsu): + text = katsu.romaji(text) + text = lowercase(text) + return text + + +def korean_transliterate(text): + try: + from hangul_romanize import Transliter + from hangul_romanize.rule import academic + except ImportError as e: + raise ImportError("Korean requires: hangul_romanize") from e + r = Transliter(academic) + return r.translit(text) + + +DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json") + + +class VoiceBpeTokenizer: + def __init__(self, vocab_file=None): + self.tokenizer = None + if vocab_file is not None: + self.tokenizer = Tokenizer.from_file(vocab_file) + self.char_limits = { + "en": 250, + "de": 253, + "fr": 273, + "es": 239, + "it": 213, + "pt": 203, + "pl": 224, + "zh": 82, + "ar": 166, + "cs": 186, + "ru": 182, + "nl": 251, + "tr": 226, + "ja": 71, + "hu": 224, + "ko": 95, + "hi": 150, + "vi": 250, + } + + @cached_property + def katsu(self): + import cutlet + + return cutlet.Cutlet() + + def check_input_length(self, txt, lang): + lang = lang.split("-")[0] # remove the region + limit = self.char_limits.get(lang, 250) + if len(txt) > limit: + logger.warning( + "The text length exceeds the character limit of %d for language '%s', this might cause truncated audio.", + limit, + lang, + ) + + def preprocess_text(self, txt, lang): + if lang in {"ar", "cs", "de", "en", "es", "fr", "hi", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko", "vi"}: + txt = multilingual_cleaners(txt, lang) + if lang == "zh": + txt = chinese_transliterate(txt) + if lang == "ko": + txt = korean_transliterate(txt) + elif lang == "ja": + txt = japanese_cleaners(txt, self.katsu) + else: + raise NotImplementedError(f"Language '{lang}' is not supported.") + return txt + + + def encode(self, txt, lang): + lang = lang.split("-")[0] # remove the region + self.check_input_length(txt, lang) + txt = self.preprocess_text(txt, lang) + lang = "zh-cn" if lang == "zh" else lang + txt = f"[{lang}]{txt}" + txt = txt.replace(" ", "[SPACE]") + return self.tokenizer.encode(txt).ids + + def decode(self, seq): + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") + txt = txt.replace("[SPACE]", " ") + txt = txt.replace("[STOP]", "") + txt = txt.replace("[UNK]", "") + return txt + + def __len__(self): + return self.tokenizer.get_vocab_size() + + def get_number_tokens(self): + return max(self.tokenizer.get_vocab().values()) + 1 + + +def test_expand_numbers_multilingual(): + test_cases = [ + # English + ("In 12.5 seconds.", "In twelve point five seconds.", "en"), + ("There were 50 soldiers.", "There were fifty soldiers.", "en"), + ("This is a 1st test", "This is a first test", "en"), + ("That will be $20 sir.", "That will be twenty dollars sir.", "en"), + ("That will be 20€ sir.", "That will be twenty euro sir.", "en"), + ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"), + ("That's 100,000.5.", "That's one hundred thousand point five.", "en"), + # French + ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"), + ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"), + ("Ceci est un 1er test", "Ceci est un premier test", "fr"), + ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"), + ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"), + ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"), + ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"), + # German + ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"), + ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"), + ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"), # Issue with gender + ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"), + ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"), + ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"), + # Spanish + ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"), + ("Había 50 soldados.", "Había cincuenta soldados.", "es"), + ("Este es un 1er test", "Este es un primero test", "es"), + ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"), + ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"), + ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"), + # Italian + ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"), + ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"), + ("Questo è un 1° test", "Questo è un primo test", "it"), + ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"), + ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"), + ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"), + # Portuguese + ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"), + ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"), + ("Este é um 1º teste", "Este é um primeiro teste", "pt"), + ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"), + ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"), + ( + "Isso custará 20,15€ senhor.", + "Isso custará vinte euros e quinze cêntimos senhor.", + "pt", + ), # "cêntimos" should be "centavos" num2words issue + # Polish + ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"), + ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"), + ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"), + ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"), + # Arabic + ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"), + ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"), + # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words + # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'), + # Czech + ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"), + ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"), + ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"), + ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"), + # Russian + ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"), + ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"), + ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"), + ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"), + # Dutch + ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"), + ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"), + ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"), + ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"), + # Chinese (Simplified) + ("在12.5秒内", "在十二点五秒内", "zh"), + ("有50名士兵", "有五十名士兵", "zh"), + # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work + # ("那将是20€先生", '那将是二十欧元先生', 'zh'), + # Turkish + # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR + ("50 asker vardı.", "elli asker vardı.", "tr"), + ("Bu 1. test", "Bu birinci test", "tr"), + # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'), + # Hungarian + ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"), + ("50 katona volt.", "ötven katona volt.", "hu"), + ("Ez az 1. teszt", "Ez az első teszt", "hu"), + # Korean + ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"), + ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"), + ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"), + # Hindi + ("12.5 सेकंड में।", "साढ़े बारह सेकंड में।", "hi"), + ("50 सैनिक थे।", "पचास सैनिक थे।", "hi"), + # Vietnamese + ("Trong 12,5 giây.", "Trong mười hai phẩy năm giây.", "vi"), + ("Có 50 binh sĩ.", "Có năm mươi binh sĩ.", "vi"), + ("Đây là bài kiểm tra thứ 1", "Đây là bài kiểm tra thứ nhất", "vi"), + ("Đó sẽ là $20 thưa ông.", "Đó sẽ là hai mươi đô la thưa ông.", "vi"), + ("Đó sẽ là 20€ thưa ông.", "Đó sẽ là hai mươi euro thưa ông.", "vi"), + ("Đó sẽ là 20,15€ thưa ông.", "Đó sẽ là hai mươi euro, mười lăm xu thưa ông.", "vi"), + + ] + for a, b, lang in test_cases: + out = expand_numbers_multilingual(a, lang=lang) + assert out == b, f"'{out}' vs '{b}'" + + +def test_abbreviations_multilingual(): + test_cases = [ + # English + ("Hello Mr. Smith.", "Hello mister Smith.", "en"), + ("Dr. Jones is here.", "doctor Jones is here.", "en"), + # Spanish + ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"), + ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"), + # French + ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"), + ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"), + # German + ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"), + # Portuguese + ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"), + ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"), + # Italian + ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"), + # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern + # Polish + ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"), + ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"), + # Czech + ("P. Novák", "pan Novák", "cs"), + ("Dr. Vojtěch", "doktor Vojtěch", "cs"), + # Dutch + ("Dhr. Jansen", "de heer Jansen", "nl"), + ("Mevr. de Vries", "mevrouw de Vries", "nl"), + # Russian + ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"), + ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"), + # Turkish + ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"), + ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"), + # Hungarian + ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"), + # Vietnamese + ("Chào Ông. Nguyễn.", "Chào ông Nguyễn.", "vi"), + ("Bà. Trần đã đến.", "Bà Trần đã đến.", "vi"), + ("Tiến sĩ. Lê đang đợi bạn.", "Tiến sĩ Lê đang đợi bạn.", "vi"), + + ] + + for a, b, lang in test_cases: + out = expand_abbreviations_multilingual(a, lang=lang) + assert out == b, f"'{out}' vs '{b}'" + + +def test_symbols_multilingual(): + test_cases = [ + ("I have 14% battery", "I have 14 percent battery", "en"), + ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"), + ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"), + ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"), + ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"), + ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"), + ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"), + ("Mám 14% baterie", "Mám 14 procento baterie", "cs"), + ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"), + ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"), + ("Я буду @ дома", "Я буду собака дома", "ru"), + ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"), + ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"), + ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"), + ("我的电量为 14%", "我的电量为 14 百分之", "zh"), + ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"), + ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"), + ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"), + ("मेरे पास 14% बैटरी है।", "मेरे पास चौदह प्रतिशत बैटरी है।", "hi"), + ("Tôi còn 14% pin", "Tôi còn 14 phần trăm pin", "vi"), + ("Hẹn gặp @ bữa tiệc", "Hẹn gặp tại bữa tiệc", "vi"), + ("Nhiệt độ là 36.5°", "Nhiệt độ là 36.5 độ", "vi"), + + ] + + for a, b, lang in test_cases: + out = expand_symbols_multilingual(a, lang=lang) + assert out == b, f"'{out}' vs '{b}'" + + +if __name__ == "__main__": + test_expand_numbers_multilingual() + test_abbreviations_multilingual() + test_symbols_multilingual() diff --git a/TTS/tts/layers/xtts/trainer/dataset.py b/TTS/tts/layers/xtts/trainer/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e59823266526f28252960c9ad5e839030b89058c --- /dev/null +++ b/TTS/tts/layers/xtts/trainer/dataset.py @@ -0,0 +1,244 @@ +import logging +import random +import sys + +import torch +import torch.nn.functional as F +import torch.utils.data + +from TTS.tts.models.xtts import load_audio + +logger = logging.getLogger(__name__) + +torch.set_num_threads(1) + + +def key_samples_by_col(samples, col): + """Returns a dictionary of samples keyed by language.""" + samples_by_col = {} + for sample in samples: + col_val = sample[col] + assert isinstance(col_val, str) + if col_val not in samples_by_col: + samples_by_col[col_val] = [] + samples_by_col[col_val].append(sample) + return samples_by_col + + +def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False): + rel_clip = load_audio(gt_path, sample_rate) + # if eval uses a middle size sample when it is possible to be more reproducible + if is_eval: + sample_length = int((min_sample_length + max_sample_length) / 2) + else: + sample_length = random.randint(min_sample_length, max_sample_length) + gap = rel_clip.shape[-1] - sample_length + if gap < 0: + sample_length = rel_clip.shape[-1] // 2 + gap = rel_clip.shape[-1] - sample_length + + # if eval start always from the position 0 to be more reproducible + if is_eval: + rand_start = 0 + else: + rand_start = random.randint(0, gap) + + rand_end = rand_start + sample_length + rel_clip = rel_clip[:, rand_start:rand_end] + rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1])) + cond_idxs = [rand_start, rand_end] + return rel_clip, rel_clip.shape[-1], cond_idxs + + +class XTTSDataset(torch.utils.data.Dataset): + def __init__(self, config, samples, tokenizer, sample_rate, is_eval=False): + self.config = config + model_args = config.model_args + self.failed_samples = set() + self.debug_failures = model_args.debug_loading_failures + self.max_conditioning_length = model_args.max_conditioning_length + self.min_conditioning_length = model_args.min_conditioning_length + self.is_eval = is_eval + self.tokenizer = tokenizer + self.sample_rate = sample_rate + self.max_wav_len = model_args.max_wav_length + self.max_text_len = model_args.max_text_length + self.use_masking_gt_prompt_approach = model_args.gpt_use_masking_gt_prompt_approach + assert self.max_wav_len is not None and self.max_text_len is not None + + self.samples = samples + if not is_eval: + random.seed(config.training_seed) + # random.shuffle(self.samples) + random.shuffle(self.samples) + # order by language + self.samples = key_samples_by_col(self.samples, "language") + logger.info("Sampling by language: %s", self.samples.keys()) + else: + # for evaluation load and check samples that are corrupted to ensures the reproducibility + self.check_eval_samples() + + def check_eval_samples(self): + logger.info("Filtering invalid eval samples!!") + new_samples = [] + for sample in self.samples: + try: + tseq, _, wav, _, _, _ = self.load_item(sample) + except: + continue + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + if ( + wav is None + or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) + or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len) + ): + continue + new_samples.append(sample) + self.samples = new_samples + logger.info("Total eval samples after filtering: %d", len(self.samples)) + + def get_text(self, text, lang): + tokens = self.tokenizer.encode(text, lang) + tokens = torch.IntTensor(tokens) + assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}" + # The stop token should always be sacred. + assert not torch.any(tokens == 0), f"Stop token found in {text}" + return tokens + + def load_item(self, sample): + text = str(sample["text"]) + tseq = self.get_text(text, sample["language"]) + audiopath = sample["audio_file"] + wav = load_audio(audiopath, self.sample_rate) + if text is None or len(text.strip()) == 0: + raise ValueError + if wav is None or wav.shape[-1] < (0.5 * self.sample_rate): + # Ultra short clips are also useless (and can cause problems within some models). + raise ValueError + + if self.use_masking_gt_prompt_approach: + # get a slice from GT to condition the model + cond, _, cond_idxs = get_prompt_slice( + audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval + ) + # if use masking do not use cond_len + cond_len = torch.nan + else: + ref_sample = ( + sample["reference_path"] + if "reference_path" in sample and sample["reference_path"] is not None + else audiopath + ) + cond, cond_len, _ = get_prompt_slice( + ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval + ) + # if do not use masking use cond_len + cond_idxs = torch.nan + + return tseq, audiopath, wav, cond, cond_len, cond_idxs + + def __getitem__(self, index): + if self.is_eval: + sample = self.samples[index] + sample_id = str(index) + else: + # select a random language + lang = random.choice(list(self.samples.keys())) + # select random sample + index = random.randint(0, len(self.samples[lang]) - 1) + sample = self.samples[lang][index] + # a unique id for each sampel to deal with fails + sample_id = lang + "_" + str(index) + + # ignore samples that we already know that is not valid ones + if sample_id in self.failed_samples: + if self.debug_failures: + logger.info("Ignoring sample %s because it was already ignored before !!", sample["audio_file"]) + # call get item again to get other sample + return self[1] + + # try to load the sample, if fails added it to the failed samples list + try: + tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample) + except: + if self.debug_failures: + logger.warning("Error loading %s %s", sample["audio_file"], sys.exc_info()) + self.failed_samples.add(sample_id) + return self[1] + + # check if the audio and text size limits and if it out of the limits, added it failed_samples + if ( + wav is None + or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) + or (self.max_text_len is not None and tseq.shape[0] > self.max_text_len) + ): + # Basically, this audio file is nonexistent or too long to be supported by the dataset. + # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result. + if self.debug_failures and wav is not None and tseq is not None: + logger.warning( + "Error loading %s: ranges are out of bounds: %d, %d", + sample["audio_file"], + wav.shape[-1], + tseq.shape[0], + ) + self.failed_samples.add(sample_id) + return self[1] + + res = { + # 'real_text': text, + "text": tseq, + "text_lengths": torch.tensor(tseq.shape[0], dtype=torch.long), + "wav": wav, + "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long), + "filenames": audiopath, + "conditioning": cond.unsqueeze(1), + "cond_lens": ( + torch.tensor(cond_len, dtype=torch.long) if cond_len is not torch.nan else torch.tensor([cond_len]) + ), + "cond_idxs": torch.tensor(cond_idxs) if cond_idxs is not torch.nan else torch.tensor([cond_idxs]), + } + return res + + def __len__(self): + if self.is_eval: + return len(self.samples) + return sum([len(v) for v in self.samples.values()]) + + def collate_fn(self, batch): + # convert list of dicts to dict of lists + B = len(batch) + + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + # stack for features that already have the same shape + batch["wav_lengths"] = torch.stack(batch["wav_lengths"]) + batch["text_lengths"] = torch.stack(batch["text_lengths"]) + batch["conditioning"] = torch.stack(batch["conditioning"]) + batch["cond_lens"] = torch.stack(batch["cond_lens"]) + batch["cond_idxs"] = torch.stack(batch["cond_idxs"]) + + if torch.any(batch["cond_idxs"].isnan()): + batch["cond_idxs"] = None + + if torch.any(batch["cond_lens"].isnan()): + batch["cond_lens"] = None + + max_text_len = batch["text_lengths"].max() + max_wav_len = batch["wav_lengths"].max() + + # create padding tensors + text_padded = torch.IntTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, max_wav_len) + + # initialize tensors for zero padding + text_padded = text_padded.zero_() + wav_padded = wav_padded.zero_() + for i in range(B): + text = batch["text"][i] + text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text) + wav = batch["wav"][i] + wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav) + + batch["wav"] = wav_padded + batch["padded_text"] = text_padded + return batch diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9d9edd5758270dbc7806a3752bc9ac78b3d250cf --- /dev/null +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -0,0 +1,511 @@ +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn as nn +import torchaudio +from coqpit import Coqpit +from torch.utils.data import DataLoader +from trainer.io import load_fsspec +from trainer.torch import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.xtts_config import XttsConfig +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram +from TTS.tts.layers.xtts.dvae import DiscreteVAE +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.models.xtts import Xtts, XttsArgs, XttsAudioConfig +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + + +@dataclass +class GPTTrainerConfig(XttsConfig): + lr: float = 5e-06 + training_seed: int = 1 + optimizer_wd_only_on_weights: bool = False + weighted_loss_attrs: dict = field(default_factory=lambda: {}) + weighted_loss_multipliers: dict = field(default_factory=lambda: {}) + test_sentences: List[dict] = field(default_factory=lambda: []) + + +@dataclass +class XttsAudioConfig(XttsAudioConfig): + dvae_sample_rate: int = 22050 + + +@dataclass +class GPTArgs(XttsArgs): + min_conditioning_length: int = 66150 + max_conditioning_length: int = 132300 + gpt_loss_text_ce_weight: float = 0.01 + gpt_loss_mel_ce_weight: float = 1.0 + gpt_num_audio_tokens: int = 8194 + debug_loading_failures: bool = False + max_wav_length: int = 255995 # ~11.6 seconds + max_text_length: int = 200 + tokenizer_file: str = "" + mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth" + dvae_checkpoint: str = "" + xtts_checkpoint: str = "" + gpt_checkpoint: str = "" # if defined it will replace the gpt weights on xtts model + vocoder: str = "" # overide vocoder key on the config to avoid json write issues + + +def callback_clearml_load_save(operation_type, model_info): + # return None means skip the file upload/log, returning model_info will continue with the log/upload + # you can also change the upload destination file name model_info.upload_filename or check the local file size with Path(model_info.local_model_path).stat().st_size + assert operation_type in ("load", "save") + logger.debug("%s %s", operation_type, model_info.__dict__) + + if "similarities.pth" in model_info.__dict__["local_model_path"]: + return None + + return model_info + + +class GPTTrainer(BaseTTS): + def __init__(self, config: Coqpit): + """ + Tortoise GPT training class + """ + super().__init__(config, ap=None, tokenizer=None) + self.config = config + # init XTTS model + self.xtts = Xtts(self.config) + # create the tokenizer with the target vocabulary + self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file) + # init gpt encoder and hifigan decoder + self.xtts.init_models() + + if self.args.xtts_checkpoint: + self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False) + + # set mel stats + if self.args.mel_norm_file: + self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file) + + # load GPT if available + if self.args.gpt_checkpoint: + gpt_checkpoint = torch.load( + self.args.gpt_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4() + ) + # deal with coqui Trainer exported model + if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys(): + logger.info("Coqui Trainer checkpoint detected! Converting it!") + gpt_checkpoint = gpt_checkpoint["model"] + states_keys = list(gpt_checkpoint.keys()) + for key in states_keys: + if "gpt." in key: + new_key = key.replace("gpt.", "") + gpt_checkpoint[new_key] = gpt_checkpoint[key] + del gpt_checkpoint[key] + else: + del gpt_checkpoint[key] + + # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible + if ( + "text_embedding.weight" in gpt_checkpoint + and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape + ): + num_new_tokens = ( + self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0] + ) + logger.info("Loading checkpoint with %d additional tokens.", num_new_tokens) + + # add new tokens to a linear layer (text_head) + emb_g = gpt_checkpoint["text_embedding.weight"] + new_row = torch.randn(num_new_tokens, emb_g.shape[1]) + start_token_row = emb_g[-1, :] + emb_g = torch.cat([emb_g, new_row], axis=0) + emb_g[-1, :] = start_token_row + gpt_checkpoint["text_embedding.weight"] = emb_g + + # add new weights to the linear layer (text_head) + text_head_weight = gpt_checkpoint["text_head.weight"] + start_token_row = text_head_weight[-1, :] + new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1]) + text_head_weight = torch.cat([text_head_weight, new_entry], axis=0) + text_head_weight[-1, :] = start_token_row + gpt_checkpoint["text_head.weight"] = text_head_weight + + # add new biases to the linear layer (text_head) + text_head_bias = gpt_checkpoint["text_head.bias"] + start_token_row = text_head_bias[-1] + new_bias_entry = torch.zeros(num_new_tokens) + text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0) + text_head_bias[-1] = start_token_row + gpt_checkpoint["text_head.bias"] = text_head_bias + + self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True) + logger.info("GPT weights restored from: %s", self.args.gpt_checkpoint) + + # Mel spectrogram extractor for conditioning + if self.args.gpt_use_perceiver_resampler: + self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram( + filter_length=2048, + hop_length=256, + win_length=1024, + normalize=False, + sampling_rate=config.audio.sample_rate, + mel_fmin=0, + mel_fmax=8000, + n_mel_channels=80, + mel_norm_file=self.args.mel_norm_file, + ) + else: + self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram( + filter_length=4096, + hop_length=1024, + win_length=4096, + normalize=False, + sampling_rate=config.audio.sample_rate, + mel_fmin=0, + mel_fmax=8000, + n_mel_channels=80, + mel_norm_file=self.args.mel_norm_file, + ) + + # Load DVAE + self.dvae = DiscreteVAE( + channels=80, + normalization=None, + positional_dims=1, + num_tokens=self.args.gpt_num_audio_tokens - 2, + codebook_dim=512, + hidden_dim=512, + num_resnet_blocks=3, + kernel_size=3, + num_layers=2, + use_transposed_convs=False, + ) + + self.dvae.eval() + if self.args.dvae_checkpoint: + dvae_checkpoint = torch.load( + self.args.dvae_checkpoint, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4() + ) + self.dvae.load_state_dict(dvae_checkpoint, strict=False) + logger.info("DVAE weights restored from: %s", self.args.dvae_checkpoint) + else: + raise RuntimeError( + "You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!" + ) + + # Mel spectrogram extractor for DVAE + self.torch_mel_spectrogram_dvae = TorchMelSpectrogram( + mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens): + """ + Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode + (actuated by `text_first`). + + text_inputs: long tensor, (b,t) + text_lengths: long tensor, (b,) + mel_inputs: long tensor, (b,m) + wav_lengths: long tensor, (b,) + cond_mels: MEL float tensor, (b, num_samples, 80,t_m) + cond_idxs: cond start and end indexs, (b, 2) + cond_lens: long tensor, (b,) + """ + losses = self.xtts.gpt( + text_inputs, + text_lengths, + audio_codes, + wav_lengths, + cond_mels=cond_mels, + cond_idxs=cond_idxs, + cond_lens=cond_lens, + ) + return losses + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: # pylint: disable=W0613 + test_audios = {} + if self.config.test_sentences: + # init gpt for inference mode + self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) + self.xtts.gpt.eval() + logger.info("Synthesizing test sentences.") + for idx, s_info in enumerate(self.config.test_sentences): + wav = self.xtts.synthesize( + s_info["text"], + self.config, + s_info["speaker_wav"], + s_info["language"], + gpt_cond_len=3, + )["wav"] + test_audios["{}-audio".format(idx)] = wav + + # delete inference layers + del self.xtts.gpt.gpt_inference + del self.xtts.gpt.gpt.wte + return {"audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate) + + def format_batch(self, batch: Dict) -> Dict: + return batch + + @torch.no_grad() # torch no grad to avoid gradients from the pre-processing and DVAE codes extraction + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + batch["text_lengths"] = batch["text_lengths"] + batch["wav_lengths"] = batch["wav_lengths"] + batch["text_inputs"] = batch["padded_text"] + batch["cond_idxs"] = batch["cond_idxs"] + # compute conditioning mel specs + # transform waves from torch.Size([B, num_cond_samples, 1, T] to torch.Size([B * num_cond_samples, 1, T] because if is faster than iterate the tensor + B, num_cond_samples, C, T = batch["conditioning"].size() + conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T) + paired_conditioning_mel = self.torch_mel_spectrogram_style_encoder(conditioning_reshaped) + # transform torch.Size([B * num_cond_samples, n_mel, T_mel]) in torch.Size([B, num_cond_samples, n_mel, T_mel]) + n_mel = self.torch_mel_spectrogram_style_encoder.n_mel_channels # paired_conditioning_mel.size(1) + T_mel = paired_conditioning_mel.size(2) + paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel) + # get the conditioning embeddings + batch["cond_mels"] = paired_conditioning_mel + # compute codes using DVAE + if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate: + dvae_wav = torchaudio.functional.resample( + batch["wav"], + orig_freq=self.config.audio.sample_rate, + new_freq=self.config.audio.dvae_sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method="kaiser_window", + beta=14.769656459379492, + ) + else: + dvae_wav = batch["wav"] + dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav) + codes = self.dvae.get_codebook_indices(dvae_mel_spec) + + batch["audio_codes"] = codes + # delete useless batch tensors + del batch["padded_text"] + del batch["wav"] + del batch["conditioning"] + return batch + + def train_step(self, batch, criterion): + loss_dict = {} + cond_mels = batch["cond_mels"] + text_inputs = batch["text_inputs"] + text_lengths = batch["text_lengths"] + audio_codes = batch["audio_codes"] + wav_lengths = batch["wav_lengths"] + cond_idxs = batch["cond_idxs"] + cond_lens = batch["cond_lens"] + + loss_text, loss_mel, _ = self.forward( + text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_idxs, cond_lens + ) + loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight + loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight + loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"] + return {"model_outputs": None}, loss_dict + + def eval_step(self, batch, criterion): + # ignore masking for more consistent evaluation + batch["cond_idxs"] = None + return self.train_step(batch, criterion) + + def on_train_epoch_start(self, trainer): + trainer.model.eval() # the whole model to eval + # put gpt model in training mode + if hasattr(trainer.model, "module") and hasattr(trainer.model.module, "xtts"): + trainer.model.module.xtts.gpt.train() + else: + trainer.model.xtts.gpt.train() + + def on_init_end(self, trainer): # pylint: disable=W0613 + # ignore similarities.pth on clearml save/upload + if self.config.dashboard_logger.lower() == "clearml": + from clearml.binding.frameworks import WeightsFileHandler + + WeightsFileHandler.add_pre_callback(callback_clearml_load_save) + + @torch.no_grad() + def inference( + self, + x, + aux_input=None, + ): # pylint: disable=dangerous-default-value + return None + + @staticmethod + def get_criterion(): + return None + + def get_sampler(self, dataset: TTSDataset, num_gpus=1): + # sampler for DDP + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": # pylint: disable=W0613 + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval) + + # wait all the DDP process to be ready + if num_gpus > 1: + torch.distributed.barrier() + + # sort input sequences from short to long + # dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(dataset, num_gpus) + + # ignore sampler when is eval because if we changed the sampler parameter we will not be able to compare previous runs + if sampler is None or is_eval: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, + drop_last=False, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initiate and return the optimizer based on the config parameters.""" + # ToDo: deal with multi GPU training + if self.config.optimizer_wd_only_on_weights: + # parameters to only GPT model + net = self.xtts.gpt + + # normalizations + norm_modules = ( + nn.BatchNorm2d, + nn.InstanceNorm2d, + nn.BatchNorm1d, + nn.InstanceNorm1d, + nn.BatchNorm3d, + nn.InstanceNorm3d, + nn.GroupNorm, + nn.LayerNorm, + ) + # nn.Embedding + emb_modules = (nn.Embedding, nn.EmbeddingBag) + + param_names_notweights = set() + all_param_names = set() + param_map = {} + for mn, m in net.named_modules(): + for k, v in m.named_parameters(): + v.is_bias = k.endswith(".bias") + v.is_weight = k.endswith(".weight") + v.is_norm = isinstance(m, norm_modules) + v.is_emb = isinstance(m, emb_modules) + + fpn = "%s.%s" % (mn, k) if mn else k # full param name + all_param_names.add(fpn) + param_map[fpn] = v + if v.is_bias or v.is_norm or v.is_emb: + param_names_notweights.add(fpn) + + params_names_notweights = sorted(list(param_names_notweights)) + params_notweights = [param_map[k] for k in params_names_notweights] + params_names_weights = sorted(list(all_param_names ^ param_names_notweights)) + params_weights = [param_map[k] for k in params_names_weights] + + groups = [ + {"params": params_weights, "weight_decay": self.config.optimizer_params["weight_decay"]}, + {"params": params_notweights, "weight_decay": 0}, + ] + # torch.optim.AdamW + opt = get_optimizer( + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + parameters=groups, + ) + opt._group_names = [params_names_weights, params_names_notweights] + return opt + + return get_optimizer( + self.config.optimizer, + self.config.optimizer_params, + self.config.lr, + # optimize only for the GPT model + parameters=self.xtts.gpt.parameters(), + ) + + def get_scheduler(self, optimizer) -> List: + """Set the scheduler for the optimizer. + + Args: + optimizer: `torch.optim.Optimizer`. + """ + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + def load_checkpoint( + self, + config, + checkpoint_path, + eval=False, + strict=True, + cache_storage="/tmp/tts_cache", + target_protocol="s3", + target_options={"anon": True}, + ): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + + state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path) + + # load the model weights + self.xtts.load_state_dict(state, strict=strict) + + if eval: + self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False) + self.eval() + assert not self.training + + @staticmethod + def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (GPTTrainerConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + return GPTTrainer(config) diff --git a/TTS/tts/layers/xtts/xtts_manager.py b/TTS/tts/layers/xtts/xtts_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..8156b35f0d9487d24cb090c2efcce9bd6eeadc11 --- /dev/null +++ b/TTS/tts/layers/xtts/xtts_manager.py @@ -0,0 +1,37 @@ +import torch + +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + + +class SpeakerManager: + def __init__(self, speaker_file_path=None): + self.speakers = torch.load(speaker_file_path, weights_only=is_pytorch_at_least_2_4()) + + @property + def name_to_id(self): + return self.speakers + + @property + def num_speakers(self): + return len(self.name_to_id) + + @property + def speaker_names(self): + return list(self.name_to_id.keys()) + + +class LanguageManager: + def __init__(self, config): + self.langs = config["languages"] + + @property + def name_to_id(self): + return self.langs + + @property + def num_languages(self): + return len(self.name_to_id) + + @property + def language_names(self): + return list(self.name_to_id) diff --git a/TTS/tts/layers/xtts/zh_num2words.py b/TTS/tts/layers/xtts/zh_num2words.py new file mode 100644 index 0000000000000000000000000000000000000000..69b8dae9528e3c9b1af331870b5025104a600f74 --- /dev/null +++ b/TTS/tts/layers/xtts/zh_num2words.py @@ -0,0 +1,1209 @@ +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 - 2022 Jiayu DU + +import argparse +import csv +import logging +import re +import string +import sys + +logger = logging.getLogger(__name__) + +# fmt: off +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = "零一二三四五六七八九" +BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖" +BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖" +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万" +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬" +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载" +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載" +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万" +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬" + +ZERO_ALT = "〇" +ONE_ALT = "幺" +TWO_ALTS = ["两", "兩"] + +POSITIVE = ["正", "正"] +NEGATIVE = ["负", "負"] +POINT = ["点", "點"] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +FILLER_CHARS = ["呃", "啊"] + +ER_WHITELIST = ( + "(儿女|儿子|儿孙|女儿|儿媳|妻儿|" + "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|" + "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|" + "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)" +) +ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST) + +# 中文数字系统类型 +NUMBERING_TYPES = ["low", "mid", "high"] + +CURRENCY_NAMES = "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)" +CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)" +COM_QUANTIFIERS = ( + "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|" + "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|" + "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|" + "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|" + "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|" + "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)" +) + + +# Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CN_PUNCS_STOP = "!?。。" +CN_PUNCS_NONSTOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-" +CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP + +PUNCS = CN_PUNCS + string.punctuation +PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma + + +# https://zh.wikipedia.org/wiki/全行和半行 +QJ2BJ = { + " ": " ", + "!": "!", + """: '"', + "#": "#", + "$": "$", + "%": "%", + "&": "&", + "'": "'", + "(": "(", + ")": ")", + "*": "*", + "+": "+", + ",": ",", + "-": "-", + ".": ".", + "/": "/", + "0": "0", + "1": "1", + "2": "2", + "3": "3", + "4": "4", + "5": "5", + "6": "6", + "7": "7", + "8": "8", + "9": "9", + ":": ":", + ";": ";", + "<": "<", + "=": "=", + ">": ">", + "?": "?", + "@": "@", + "A": "A", + "B": "B", + "C": "C", + "D": "D", + "E": "E", + "F": "F", + "G": "G", + "H": "H", + "I": "I", + "J": "J", + "K": "K", + "L": "L", + "M": "M", + "N": "N", + "O": "O", + "P": "P", + "Q": "Q", + "R": "R", + "S": "S", + "T": "T", + "U": "U", + "V": "V", + "W": "W", + "X": "X", + "Y": "Y", + "Z": "Z", + "[": "[", + "\": "\\", + "]": "]", + "^": "^", + "_": "_", + "`": "`", + "a": "a", + "b": "b", + "c": "c", + "d": "d", + "e": "e", + "f": "f", + "g": "g", + "h": "h", + "i": "i", + "j": "j", + "k": "k", + "l": "l", + "m": "m", + "n": "n", + "o": "o", + "p": "p", + "q": "q", + "r": "r", + "s": "s", + "t": "t", + "u": "u", + "v": "v", + "w": "w", + "x": "x", + "y": "y", + "z": "z", + "{": "{", + "|": "|", + "}": "}", + "~": "~", +} +QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "") + + +# 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources: +# https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total +CN_CHARS_COMMON = ( + "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举" + "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互" + "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从" + "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优" + "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚" + "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣" + "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯" + "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌" + "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚" + "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六" + "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况" + "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈" + "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐" + "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼" + "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹" + "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵" + "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔" + "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊" + "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆" + "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎" + "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌" + "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛" + "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴" + "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌" + "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡" + "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓" + "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢" + "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥" + "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩" + "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基" + "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填" + "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复" + "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖" + "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮" + "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈" + "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻" + "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱" + "嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽" + "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾" + "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝" + "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山" + "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃" + "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧" + "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉" + "巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡" + "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐" + "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷" + "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖" + "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循" + "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀" + "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓" + "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟" + "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭" + "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥" + "慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我" + "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔" + "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡" + "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥" + "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫" + "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎" + "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭" + "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘" + "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢" + "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整" + "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗" + "旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝" + "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡" + "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜" + "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽" + "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅" + "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔" + "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩" + "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯" + "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘" + "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂" + "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱" + "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞" + "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正" + "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒" + "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮" + "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽" + "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽" + "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼" + "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿" + "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸" + "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅" + "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟" + "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆" + "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕" + "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶" + "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧" + "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵" + "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈" + "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯" + "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵" + "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍" + "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷" + "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎" + "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎" + "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊" + "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊" + "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙" + "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱" + "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯" + "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐" + "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒" + "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩" + "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙" + "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾" + "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦" + "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知" + "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮" + "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍" + "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷" + "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭" + "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣" + "秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗" + "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立" + "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯" + "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅" + "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾" + "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮" + "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢" + "縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁" + "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩" + "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕" + "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐" + "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲" + "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者" + "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋" + "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸" + "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳" + "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒" + "腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻" + "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般" + "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊" + "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈" + "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆" + "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐" + "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛" + "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥" + "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著" + "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺" + "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷" + "蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸" + "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱" + "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆" + "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗" + "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃" + "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡" + "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒" + "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂" + "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉" + "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认" + "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词" + "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请" + "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡" + "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹" + "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼" + "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤" + "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑" + "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮" + "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇" + "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较" + "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄" + "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆" + "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒" + "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱" + "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴" + "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝" + "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭" + "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒" + "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼" + "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨" + "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐" + "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹" + "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣" + "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼" + "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶" + "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈" + "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳" + "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰" + "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵" + "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额" + "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰" + "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥" + "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑" + "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高" + "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾" + "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨" + "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓" + "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶" + "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟" + "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄" + "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷" + "鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃" + "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡" + "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽" + "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯" + "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟" + "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟" + "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟" + "𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓" +) +CN_CHARS_EXT = "吶诶屌囧飚屄" + +CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT +IN_CH_CHARS = {c: True for c in CN_CHARS} + +EN_CHARS = string.ascii_letters + string.digits +IN_EN_CHARS = {c: True for c in EN_CHARS} + +VALID_CHARS = CN_CHARS + EN_CHARS + " " +IN_VALID_CHARS = {c: True for c in VALID_CHARS} + + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + # self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return "10^{}".format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + if small_unit: + return ChineseNumberUnit( + power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit( + power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit( + power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit( + power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1] + ) + else: + raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. + positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x) + point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, "" + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. '两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = "".join([str(d.value) for d in dec_part]) + if dec_part: + return "{0}.{1}".format(int_str, dec_str) + else: + return int_str + + +def num2chn( + number_string, + numbering_type=NUMBERING_TYPES[1], + big=False, + traditional=False, + alt_zero=False, + alt_one=False, + alt_two=True, + use_zeros=True, + use_units=True, +): + def get_value(value_string, use_zeros=True): + striped_string = value_string.lstrip("0") + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string)) + result_string = value_string[: -result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power :]) + + system = create_system(numbering_type) + + int_dec = number_string.split(".") + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError("invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = "big_" + if traditional: + attr_name += "t" + else: + attr_name += "s" + else: + if traditional: + attr_name = "traditional" + else: + attr_name = "simplified" + + result = "".join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if ( + len(result) >= 2 + and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] + and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]] + ): + result = result[1:] + + return result + + +# ================================================================================ # +# different types of rewriters +# ================================================================================ # +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + if fixed: + sil_parts = self.telephone.split("-") + self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts]) + self.chntext = self.raw_chntext.replace("", "") + else: + sp_parts = self.telephone.strip("+").split() + self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts]) + self.chntext = self.raw_chntext.replace("", "") + return self.chntext + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split("分之") + return chn2num(numerator) + "/" + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split("/") + return num2chn(denominator) + "分之" + num2chn(numerator) + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split("年", 1) + year = Digit(digit=year).digit2chntext() + "年" + except ValueError: + other = date + year = "" + if other: + try: + month, day = other.strip().split("月", 1) + month = Cardinal(cardinal=month).cardinal2chntext() + "月" + except ValueError: + day = date + month = "" + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = "" + day = "" + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip("百分之")) + "%" + + def percentage2chntext(self): + return "百分之" + num2chn(self.percentage.strip().strip("%")) + + +def normalize_nsw(raw_text): + text = "^" + raw_text + "$" + logger.debug(text) + + # 规范化日期 + pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + logger.debug("date") + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") + matchers = pattern.findall(text) + if matchers: + logger.debug("money") + for matcher in matchers: + text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + logger.debug("telephone") + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + logger.debug("fixed telephone") + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + logger.debug("fraction") + for matcher in matchers: + text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) + + # 规范化百分数 + text = text.replace("%", "%") + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + logger.debug("percentage") + for matcher in matchers: + text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + logger.debug("cardinal+quantifier") + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + logger.debug("digit") + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + logger.debug("cardinal") + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # restore P2P, O2O, B2C, B2B etc + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + logger.debug("particular") + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1) + + return text.lstrip("^").rstrip("$") + + +def remove_erhua(text): + """ + 去除儿化音词中的儿: + 他女儿在那边儿 -> 他女儿在那边 + """ + + new_str = "" + while re.search("儿", text): + a = re.search("儿", text).span() + remove_er_flag = 0 + + if ER_WHITELIST_PATTERN.search(text): + b = ER_WHITELIST_PATTERN.search(text).span() + if b[0] <= a[0]: + remove_er_flag = 1 + + if remove_er_flag == 0: + new_str = new_str + text[0 : a[0]] + text = text[a[1] :] + else: + new_str = new_str + text[0 : b[1]] + text = text[b[1] :] + + text = new_str + text + return text + + +def remove_space(text): + tokens = text.split() + new = [] + for k, t in enumerate(tokens): + if k != 0: + if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]): + new.append(" ") + new.append(t) + return "".join(new) + + +class TextNorm: + def __init__( + self, + to_banjiao: bool = False, + to_upper: bool = False, + to_lower: bool = False, + remove_fillers: bool = False, + remove_erhua: bool = False, + check_chars: bool = False, + remove_space: bool = False, + cc_mode: str = "", + ): + self.to_banjiao = to_banjiao + self.to_upper = to_upper + self.to_lower = to_lower + self.remove_fillers = remove_fillers + self.remove_erhua = remove_erhua + self.check_chars = check_chars + self.remove_space = remove_space + + self.cc = None + if cc_mode: + from opencc import OpenCC # Open Chinese Convert: pip install opencc + + self.cc = OpenCC(cc_mode) + + def __call__(self, text): + if self.cc: + text = self.cc.convert(text) + + if self.to_banjiao: + text = text.translate(QJ2BJ_TRANSFORM) + + if self.to_upper: + text = text.upper() + + if self.to_lower: + text = text.lower() + + if self.remove_fillers: + for c in FILLER_CHARS: + text = text.replace(c, "") + + if self.remove_erhua: + text = remove_erhua(text) + + text = normalize_nsw(text) + + text = text.translate(PUNCS_TRANSFORM) + + if self.check_chars: + for c in text: + if not IN_VALID_CHARS.get(c): + logger.warning("Illegal char %s in: %s", c, text) + return "" + + if self.remove_space: + text = remove_space(text) + + return text + + +if __name__ == "__main__": + p = argparse.ArgumentParser() + + # normalizer options + p.add_argument("--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao") + p.add_argument("--to_upper", action="store_true", help="convert to upper case") + p.add_argument("--to_lower", action="store_true", help="convert to lower case") + p.add_argument("--remove_fillers", action="store_true", help='remove filler chars such as "呃, 啊"') + p.add_argument("--remove_erhua", action="store_true", help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"') + p.add_argument("--check_chars", action="store_true", help="skip sentences containing illegal chars") + p.add_argument("--remove_space", action="store_true", help="remove whitespace") + p.add_argument( + "--cc_mode", choices=["", "t2s", "s2t"], default="", help="convert between traditional to simplified" + ) + + # I/O options + p.add_argument("--log_interval", type=int, default=10000, help="log interval in number of processed lines") + p.add_argument("--has_key", action="store_true", help="will be deprecated, set --format ark instead") + p.add_argument("--format", type=str, choices=["txt", "ark", "tsv"], default="txt", help="input format") + p.add_argument("ifile", help="input filename, assume utf-8 encoding") + p.add_argument("ofile", help="output filename") + + args = p.parse_args() + + if args.has_key: + args.format = "ark" + + normalizer = TextNorm( + to_banjiao=args.to_banjiao, + to_upper=args.to_upper, + to_lower=args.to_lower, + remove_fillers=args.remove_fillers, + remove_erhua=args.remove_erhua, + check_chars=args.check_chars, + remove_space=args.remove_space, + cc_mode=args.cc_mode, + ) + + normalizer = TextNorm( + to_banjiao=args.to_banjiao, + to_upper=args.to_upper, + to_lower=args.to_lower, + remove_fillers=args.remove_fillers, + remove_erhua=args.remove_erhua, + check_chars=args.check_chars, + remove_space=args.remove_space, + cc_mode=args.cc_mode, + ) + + ndone = 0 + with open(args.ifile, "r", encoding="utf8") as istream, open(args.ofile, "w+", encoding="utf8") as ostream: + if args.format == "tsv": + reader = csv.DictReader(istream, delimiter="\t") + assert "TEXT" in reader.fieldnames + print("\t".join(reader.fieldnames), file=ostream) + + for item in reader: + text = item["TEXT"] + + if text: + text = normalizer(text) + + if text: + item["TEXT"] = text + print("\t".join([item[f] for f in reader.fieldnames]), file=ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True) + else: + for l in istream: + key, text = "", "" + if args.format == "ark": # KALDI archive, line format: "key text" + cols = l.strip().split(maxsplit=1) + key, text = cols[0], cols[1] if len(cols) == 2 else "" + else: + text = l.strip() + + if text: + text = normalizer(text) + + if text: + if args.format == "ark": + print(key + "\t" + text, file=ostream) + else: + print(text, file=ostream) + + ndone += 1 + if ndone % args.log_interval == 0: + print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True) + print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True) diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ebfa171c807f0046e7190310df9bef8b4d79bc2d --- /dev/null +++ b/TTS/tts/models/__init__.py @@ -0,0 +1,17 @@ +import logging +from typing import Dict, List, Union + +from TTS.utils.generic_utils import find_module + +logger = logging.getLogger(__name__) + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseTTS": + logger.info("Using model: %s", config.model) + # fetch the right model implementation. + if "base_model" in config and config["base_model"] is not None: + MyModel = find_module("TTS.tts.models", config.base_model.lower()) + else: + MyModel = find_module("TTS.tts.models", config.model.lower()) + model = MyModel.init_from_config(config=config, samples=samples) + return model diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..2d27a5785054021e18e51f7022322d590d7f541c --- /dev/null +++ b/TTS/tts/models/align_tts.py @@ -0,0 +1,448 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.io import load_fsspec + +from TTS.tts.layers.align_tts.mdn import MDNBlock +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + + +@dataclass +class AlignTTSArgs(Coqpit): + """ + Args: + num_chars (int): + number of unique input to characters + out_channels (int): + number of output tensor channels. It is equal to the expected spectrogram size. + hidden_channels (int): + number of channels in all the model layers. + hidden_channels_ffn (int): + number of channels in transformer's conv layers. + hidden_channels_dp (int): + number of channels in duration predictor network. + num_heads (int): + number of attention heads in transformer networks. + num_transformer_layers (int): + number of layers in encoder and decoder transformer blocks. + dropout_p (int): + dropout rate in transformer layers. + length_scale (int, optional): + coefficient to set the speech speed. <1 slower, >1 faster. Defaults to 1. + num_speakers (int, optional): + number of speakers for multi-speaker training. Defaults to 0. + external_c (bool, optional): + enable external speaker embeddings. Defaults to False. + c_in_channels (int, optional): + number of channels in speaker embedding vectors. Defaults to 0. + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 256 + hidden_channels_dp: int = 256 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1} + ) + length_scale: float = 1.0 + num_speakers: int = 0 + use_speaker_embedding: bool = False + use_d_vector_file: bool = False + d_vector_dim: int = 0 + + +class AlignTTS(BaseTTS): + """AlignTTS with modified duration predictor. + https://arxiv.org/pdf/2003.01950.pdf + + Encoder -> DurationPredictor -> Decoder + + Check :class:`AlignTTSArgs` for the class arguments. + + Paper Abstract: + Targeting at both high efficiency and performance, we propose AlignTTS to predict the + mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a + sequence of characters, and the duration of each character is determined by a duration predictor.Instead of + adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented + to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset s + how that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean + option score (MOS), but also a high efficiency which is more than 50 times faster than real-time. + + Note: + Original model uses a separate character embedding layer for duration predictor. However, it causes the + duration predictor to overfit and prevents learning higher level interactions among characters. Therefore, + we predict durations based on encoder outputs which has higher level information about input characters. This + enables training without phases as in the original paper. + + Original model uses Transormers in encoder and decoder layers. However, here you can set the architecture + differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters. + + Examples: + >>> from TTS.tts.configs.align_tts_config import AlignTTSConfig + >>> config = AlignTTSConfig() + >>> model = AlignTTS(config) + + """ + + # pylint: disable=dangerous-default-value + + def __init__( + self, + config: "AlignTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self.speaker_manager = speaker_manager + self.phase = -1 + self.length_scale = ( + float(config.model_args.length_scale) + if isinstance(config.model_args.length_scale, int) + else config.model_args.length_scale + ) + + self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels) + + self.embedded_speaker_dim = 0 + self.init_multispeaker(config) + + self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.encoder = Encoder( + config.model_args.hidden_channels, + config.model_args.hidden_channels, + config.model_args.encoder_type, + config.model_args.encoder_params, + self.embedded_speaker_dim, + ) + self.decoder = Decoder( + config.model_args.out_channels, + config.model_args.hidden_channels, + config.model_args.decoder_type, + config.model_args.decoder_params, + ) + self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp) + + self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1) + + self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels) + + if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels: + self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1) + + @staticmethod + def compute_log_probs(mu, log_sigma, y): + # pylint: disable=protected-access, c-extension-no-member + y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] + mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) + exponential = -0.5 * torch.mean( + torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1 + ) # B, L, T + logp = exponential - 0.5 * log_sigma.mean(dim=-1) + return logp + + def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask): + # find the max alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + log_p = self.compute_log_probs(mu, log_sigma, y) + # [B, T_en, T_dec] + attn = maximum_path(log_p, attn_mask.squeeze(1)).unsqueeze(1) + dr_mas = torch.sum(attn, -1) + return dr_mas.squeeze(1), log_p + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Examples:: + - encoder output: [a,b,c,d] + - durations: [1, 3, 2, 1] + + - expanded: [a, b, b, b, c, c, d] + - attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + @staticmethod + def _concat_speaker_embedding(o_en, g): + g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en] + o_en = torch.cat([o_en, g_exp], 1) + return o_en + + def _sum_speaker_embedding(self, x, g): + # project g to decoder dim. + if hasattr(self, "proj_g"): + g = self.proj_g(g) + + return x + g + + def _forward_encoder(self, x, x_lengths, g=None): + if hasattr(self, "emb_g"): + g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1] + + if g is not None: + g = g.unsqueeze(-1) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # speaker conditioning for duration predictor + if g is not None: + o_en_dp = self._concat_speaker_embedding(o_en, g) + else: + o_en_dp = o_en + return o_en, o_en_dp, x_mask, g + + def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # speaker embedding + if g is not None: + o_en_ex = self._sum_speaker_embedding(o_en_ex, g) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de, attn.transpose(1, 2) + + def _forward_mdn(self, o_en, y, y_lengths, x_mask): + # MAS potentials and alignment + mu, log_sigma = self.mdn_block(o_en) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask) + return dr_mas, mu, log_sigma, logp + + def forward( + self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None + ): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + """ + y = y.transpose(1, 2) + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None + if phase == 0: + # train encoder and MDN + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + attn = self.generate_attn(dr_mas, x_mask, y_mask) + elif phase == 1: + # train decoder + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g) + elif phase == 2: + # train the whole except duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + elif phase == 3: + # train duration predictor + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(x, x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + else: + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) + o_dr_log = o_dr_log.squeeze(1) + dr_mas_log = torch.log(dr_mas + 1).squeeze(1) + outputs = { + "model_outputs": o_de.transpose(1, 2), + "alignments": attn, + "durations_log": o_dr_log, + "durations_mas_log": dr_mas_log, + "mu": mu, + "log_sigma": log_sigma, + "logp": logp, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument + """ + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - g: :math:`[B, C]` + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + # pad input to prevent dropping the last word + # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) + o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + # o_dr_log = self.duration_predictor(x, x_mask) + o_dr_log = self.duration_predictor(o_en_dp, x_mask) + # duration predictor pass + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn} + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase) + loss_dict = criterion( + outputs["logp"], + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + outputs["durations_mas_log"], + text_lengths, + phase=self.phase, + ) + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel + + return AlignTTSLoss(self.config) + + @staticmethod + def _set_phase(config, global_step): + """Decide AlignTTS training phase""" + if isinstance(config.phase_start_steps, list): + vals = [i < global_step for i in config.phase_start_steps] + if True not in vals: + phase = 0 + else: + phase = ( + len(config.phase_start_steps) + - [i < global_step for i in config.phase_start_steps][::-1].index(True) + - 1 + ) + else: + phase = None + return phase + + def on_epoch_start(self, trainer): + """Set AlignTTS training phase on epoch start.""" + self.phase = self._set_phase(trainer.config, trainer.total_steps_done) + + @staticmethod + def init_from_config(config: "AlignTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (AlignTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return AlignTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py new file mode 100644 index 0000000000000000000000000000000000000000..ced8f60ed86cdfa947bc175aec8d9ecbf17ceab9 --- /dev/null +++ b/TTS/tts/models/bark.py @@ -0,0 +1,278 @@ +import os +from dataclasses import dataclass +from typing import Optional + +import numpy as np +from coqpit import Coqpit +from encodec import EncodecModel +from transformers import BertTokenizer + +from TTS.tts.layers.bark.inference_funcs import ( + codec_decode, + generate_coarse, + generate_fine, + generate_text_semantic, + generate_voice, + load_voice, +) +from TTS.tts.layers.bark.load_model import load_model +from TTS.tts.layers.bark.model import GPT +from TTS.tts.layers.bark.model_fine import FineGPT +from TTS.tts.models.base_tts import BaseTTS + + +@dataclass +class BarkAudioConfig(Coqpit): + sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +class Bark(BaseTTS): + def __init__( + self, + config: Coqpit, + tokenizer: BertTokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased"), + ) -> None: + super().__init__(config=config, ap=None, tokenizer=None, speaker_manager=None, language_manager=None) + self.config.num_chars = len(tokenizer) + self.tokenizer = tokenizer + self.semantic_model = GPT(config.semantic_config) + self.coarse_model = GPT(config.coarse_config) + self.fine_model = FineGPT(config.fine_config) + self.encodec = EncodecModel.encodec_model_24khz() + self.encodec.set_target_bandwidth(6.0) + + @property + def device(self): + return next(self.parameters()).device + + def load_bark_models(self): + self.semantic_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["text"], device=self.device, config=self.config, model_type="text" + ) + self.coarse_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["coarse"], + device=self.device, + config=self.config, + model_type="coarse", + ) + self.fine_model, self.config = load_model( + ckpt_path=self.config.LOCAL_MODEL_PATHS["fine"], device=self.device, config=self.config, model_type="fine" + ) + + def train_step( + self, + ): + pass + + def text_to_semantic( + self, + text: str, + history_prompt: Optional[str] = None, + temp: float = 0.7, + base=None, + allow_early_stop=True, + **kwargs, + ): + """Generate semantic array from text. + + Args: + text: text to be turned into audio + history_prompt: history choice for audio cloning + temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy semantic array to be fed into `semantic_to_waveform` + """ + x_semantic = generate_text_semantic( + text, + self, + history_prompt=history_prompt, + temp=temp, + base=base, + allow_early_stop=allow_early_stop, + **kwargs, + ) + return x_semantic + + def semantic_to_waveform( + self, + semantic_tokens: np.ndarray, + history_prompt: Optional[str] = None, + temp: float = 0.7, + base=None, + ): + """Generate audio array from semantic input. + + Args: + semantic_tokens: semantic token output from `text_to_semantic` + history_prompt: history choice for audio cloning + temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy audio array at sample frequency 24khz + """ + x_coarse_gen = generate_coarse( + semantic_tokens, + self, + history_prompt=history_prompt, + temp=temp, + base=base, + ) + x_fine_gen = generate_fine( + x_coarse_gen, + self, + history_prompt=history_prompt, + temp=0.5, + base=base, + ) + audio_arr = codec_decode(x_fine_gen, self) + return audio_arr, x_coarse_gen, x_fine_gen + + def generate_audio( + self, + text: str, + history_prompt: Optional[str] = None, + text_temp: float = 0.7, + waveform_temp: float = 0.7, + base=None, + allow_early_stop=True, + **kwargs, + ): + """Generate audio array from input text. + + Args: + text: text to be turned into audio + history_prompt: history choice for audio cloning + text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + + Returns: + numpy audio array at sample frequency 24khz + """ + x_semantic = self.text_to_semantic( + text, + history_prompt=history_prompt, + temp=text_temp, + base=base, + allow_early_stop=allow_early_stop, + **kwargs, + ) + audio_arr, c, f = self.semantic_to_waveform( + x_semantic, history_prompt=history_prompt, temp=waveform_temp, base=base + ) + return audio_arr, [x_semantic, c, f] + + def generate_voice(self, audio, speaker_id, voice_dir): + """Generate a voice from the given audio. + + Args: + audio (str): Path to the audio file. + speaker_id (str): Speaker name. + voice_dir (str): Path to the directory to save the generate voice. + """ + if voice_dir is not None: + voice_dirs = [voice_dir] + try: + _ = load_voice(self, speaker_id, voice_dirs) + except (KeyError, FileNotFoundError): + output_path = os.path.join(voice_dir, speaker_id + ".npz") + os.makedirs(voice_dir, exist_ok=True) + generate_voice(audio, self, output_path) + + def _set_voice_dirs(self, voice_dirs): + def_voice_dir = None + if isinstance(self.config.DEF_SPEAKER_DIR, str): + os.makedirs(self.config.DEF_SPEAKER_DIR, exist_ok=True) + if os.path.isdir(self.config.DEF_SPEAKER_DIR): + def_voice_dir = self.config.DEF_SPEAKER_DIR + _voice_dirs = [def_voice_dir] if def_voice_dir is not None else [] + if voice_dirs is not None: + if isinstance(voice_dirs, str): + voice_dirs = [voice_dirs] + _voice_dirs = voice_dirs + _voice_dirs + return _voice_dirs + + # TODO: remove config from synthesize + def synthesize( + self, text, config, speaker_id="random", voice_dirs=None, **kwargs + ): # pylint: disable=unused-argument + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (BarkConfig): Config with inference parameters. + speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker. + speaker_wav (str): Path to the speaker audio file for cloning a new voice. It is cloned and saved in + `voice_dirs` with the name `speaker_id`. Defaults to None. + voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None. + **kwargs: Model specific inference settings used by `generate_audio()` and `TTS.tts.layers.bark.inference_funcs.generate_text_semantic(). + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + speaker_id = "random" if speaker_id is None else speaker_id + voice_dirs = self._set_voice_dirs(voice_dirs) + history_prompt = load_voice(self, speaker_id, voice_dirs) + outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs) + return_dict = { + "wav": outputs[0], + "text_inputs": text, + } + + return return_dict + + def eval_step(self): ... + + def forward(self): ... + + def inference(self): ... + + @staticmethod + def init_from_config(config: "BarkConfig", **kwargs): # pylint: disable=unused-argument + return Bark(config) + + # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint( + self, + config, + checkpoint_dir, + text_model_path=None, + coarse_model_path=None, + fine_model_path=None, + hubert_tokenizer_path=None, + eval=False, + strict=True, + **kwargs, + ): + """Load a model checkpoints from a directory. This model is with multiple checkpoint files and it + expects to have all the files to be under the given `checkpoint_dir` with the rigth names. + If eval is True, set the model to eval mode. + + Args: + config (TortoiseConfig): The model config. + checkpoint_dir (str): The directory where the checkpoints are stored. + ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None. + diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None. + clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None. + vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None. + eval (bool, optional): Whether to set the model to eval mode. Defaults to False. + strict (bool, optional): Whether to load the model strictly. Defaults to True. + """ + text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt") + coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt") + fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt") + hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth") + + self.config.LOCAL_MODEL_PATHS["text"] = text_model_path + self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path + self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path + self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path + + self.load_bark_models() + + if eval: + self.eval() diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..79cdf1a7d4aaa339835ee9e5c87800a0c355443c --- /dev/null +++ b/TTS/tts/models/base_tacotron.py @@ -0,0 +1,309 @@ +import copy +import logging +from abc import abstractmethod +from typing import Dict, Tuple + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.io import load_fsspec + +from TTS.tts.layers.losses import TacotronLoss +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input +from TTS.utils.training import gradual_training_scheduler + +logger = logging.getLogger(__name__) + + +class BaseTacotron(BaseTTS): + """Base class shared by Tacotron and Tacotron2""" + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields as class attributes + for key in config: + setattr(self, key, config[key]) + + # layers + self.embedding = None + self.encoder = None + self.decoder = None + self.postnet = None + + # init tensors + self.embedded_speakers = None + self.embedded_speakers_projected = None + + # global style token + if self.gst and self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim + self.gst_layer = None + + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim + self.capacitron_vae_layer = None + + # additional layers + self.decoder_backward = None + self.coarse_decoder = None + + @staticmethod + def _format_aux_input(aux_input: Dict) -> Dict: + """Set missing fields to their default values""" + if aux_input: + return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input) + return None + + ############################# + # INIT FUNCTIONS + ############################# + + def _init_backward_decoder(self): + """Init the backward decoder for Forward-Backward decoding.""" + self.decoder_backward = copy.deepcopy(self.decoder) + + def _init_coarse_decoder(self): + """Init the coarse decoder for Double-Decoder Consistency.""" + self.coarse_decoder = copy.deepcopy(self.decoder) + self.coarse_decoder.r_init = self.ddc_r + self.coarse_decoder.set_r(self.ddc_r) + + ############################# + # CORE FUNCTIONS + ############################# + + @abstractmethod + def forward(self): + pass + + @abstractmethod + def inference(self): + pass + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load model checkpoint and set up internals. + + Args: + config (Coqpi): model configuration. + checkpoint_path (str): path to checkpoint file. + eval (bool, optional): whether to load model for evaluation. + cache (bool, optional): If True, cache the file locally for subsequent calls. + It is cached under `trainer.io.get_user_data_dir()/tts_cache`. Defaults to False. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + # TODO: set r in run-time by taking it from the new config + if "r" in state: + # set r from the state (for compatibility with older checkpoints) + self.decoder.set_r(state["r"]) + elif "config" in state: + # set r from config used at training time (for inference) + self.decoder.set_r(state["config"]["r"]) + else: + # set r from the new config (for new-models) + self.decoder.set_r(config.r) + if eval: + self.eval() + logger.info("Model's reduction rate `r` is set to: %d", self.decoder.r) + assert not self.training + + def get_criterion(self) -> nn.Module: + """Get the model criterion used in training.""" + return TacotronLoss(self.config) + + @staticmethod + def init_from_config(config: Coqpit): + """Initialize model from config.""" + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config) + return BaseTacotron(config, ap, tokenizer, speaker_manager) + + ########################## + # TEST AND LOG FUNCTIONS # + ########################## + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + logger.info("Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + ############################# + # COMMON COMPUTE FUNCTIONS + ############################# + + def compute_masks(self, text_lengths, mel_lengths): + """Compute masks against sequence paddings.""" + # B x T_in_max (boolean) + input_mask = sequence_mask(text_lengths) + output_mask = None + if mel_lengths is not None: + max_len = mel_lengths.max() + r = self.decoder.r + max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len + output_mask = sequence_mask(mel_lengths, max_len=max_len) + return input_mask, output_mask + + def _backward_pass(self, mel_specs, encoder_outputs, mask): + """Run backwards decoder""" + decoder_outputs_b, alignments_b, _ = self.decoder_backward( + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) + decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() + return decoder_outputs_b, alignments_b + + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): + """Double Decoder Consistency""" + T = mel_specs.shape[1] + if T % self.coarse_decoder.r > 0: + padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) + decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( + encoder_outputs.detach(), mel_specs, input_mask + ) + # scale_factor = self.decoder.r_init / self.decoder.r + alignments_backward = torch.nn.functional.interpolate( + alignments_backward.transpose(1, 2), + size=alignments.shape[1], + mode="nearest", + ).transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) + decoder_outputs_backward = decoder_outputs_backward[:, :T, :] + return decoder_outputs_backward, alignments_backward + + ############################# + # EMBEDDING FUNCTIONS + ############################# + + def compute_gst(self, inputs, style_input, speaker_embedding=None): + """Compute global style token""" + if isinstance(style_input, dict): + # multiply each style token with a weight + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) + if speaker_embedding is not None: + query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) + + _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + for k_token, v_amplifier in style_input.items(): + key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) + gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) + gst_outputs = gst_outputs + gst_outputs_att * v_amplifier + elif style_input is None: + # ignore style token and return zero tensor + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) + else: + # compute style tokens + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + inputs = self._concat_speaker_embedding(inputs, gst_outputs) + return inputs + + def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None): + """Capacitron Variational Autoencoder""" + ( + VAE_outputs, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) = self.capacitron_vae_layer( + reference_mel_info, + text_info, + speaker_embedding, # pylint: disable=not-callable + ) + + VAE_outputs = VAE_outputs.to(inputs.device) + encoder_output = self._concat_speaker_embedding( + inputs, VAE_outputs + ) # concatenate to the output of the basic tacotron encoder + return ( + encoder_output, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) + + @staticmethod + def _add_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = outputs + embedded_speakers_ + return outputs + + @staticmethod + def _concat_speaker_embedding(outputs, embedded_speakers): + embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1) + outputs = torch.cat([outputs, embedded_speakers_], dim=-1) + return outputs + + ############################# + # CALLBACKS + ############################# + + def on_epoch_start(self, trainer): + """Callback for setting values wrt gradual training schedule. + + Args: + trainer (TrainerTTS): TTS trainer object that is used to train this model. + """ + if self.gradual_training: + r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config) + trainer.config.r = r + self.decoder.set_r(r) + if trainer.config.bidirectional_decoder: + trainer.model.decoder_backward.set_r(r) + logger.info("Number of output frames: %d", self.decoder.r) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb023ce8424d9d436df8eef200812fe38a3ddd7 --- /dev/null +++ b/TTS/tts/models/base_tts.py @@ -0,0 +1,463 @@ +import logging +import os +import random +from typing import Dict, List, Tuple, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +logger = logging.getLogger(__name__) + +# pylint: skip-file + + +class BaseTTS(BaseTrainerModel): + """Base `tts` class. Every new `tts` model must inherit this. + + It defines common `tts` specific functions on top of `Model` implementation. + """ + + MODEL_TYPE = "tts" + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor", + tokenizer: "TTSTokenizer", + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__() + self.config = config + self.ap = ap + self.tokenizer = tokenizer + self.speaker_manager = speaker_manager + self.language_manager = language_manager + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + config_num_chars = ( + self.config.model_args.num_chars if hasattr(self.config, "model_args") else self.config.num_chars + ) + num_chars = config_num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + if "characters" in config: + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + + def init_multispeaker(self, config: Coqpit, data: List = None): + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. + + Args: + config (Coqpit): Model configuration. + """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + logger.info("Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs) -> Dict: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if self.speaker_manager is not None: + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if self.language_manager is not None and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: Dict) -> Dict: + """Generic batch formatting for `TTSDataset`. + + You must override this if you use a custom dataset. + + Args: + batch (Dict): [description] + + Returns: + Dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + energy = batch["energy"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_names"], + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if self.language_manager is not None: + language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=config.shuffle if sampler is None else False, # if there is no other sampler + collate_fn=dataset.collate_fn, + drop_last=config.drop_last, # setting this False might cause issues in AMP training. + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> Dict: + d_vector = None + if self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": ( + None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1) + ), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + logger.info("Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + if isinstance(sen, list): + aux_inputs = self.get_aux_input_from_test_sentences(sen) + sen = aux_inputs["text"] + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer): + """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + logger.info("`speakers.pth` is saved to: %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") + + if self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + logger.info("`language_ids.json` is saved to: %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") + + +class BaseTTSE2E(BaseTTS): + def _set_model_args(self, config: Coqpit): + self.config = config + if "Config" in config.__class__.__name__: + num_chars = ( + self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars + ) + self.config.model_args.num_chars = num_chars + self.config.num_chars = num_chars + self.args = config.model_args + self.args.num_chars = num_chars + elif "Args" in config.__class__.__name__: + self.args = config + self.args.num_chars = self.args.num_chars + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..a938a3a4ab6a68d991de599279281144a21353fc --- /dev/null +++ b/TTS/tts/models/delightful_tts.py @@ -0,0 +1,1768 @@ +import logging +import os +from dataclasses import dataclass, field +from itertools import chain +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.io import load_fsspec +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel +from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.models.base_tts import BaseTTSE2E +from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram +from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 +from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy +from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy +from TTS.utils.audio.processor import AudioProcessor +from TTS.vocoder.layers.losses import MultiScaleSTFTLoss +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + +logger = logging.getLogger(__name__) + + +def id_to_torch(aux_id, cuda=False): + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id) + if cuda: + return aux_id.cuda() + return aux_id + + +def embedding_to_torch(d_vector, cuda=False): + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).float() + d_vector = d_vector.squeeze().unsqueeze(0) + if cuda: + return d_vector.cuda() + return d_vector + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: + batch_size = lengths.shape[0] + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) + mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) + return mask + + +def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: + out_list = torch.jit.annotate(List[torch.Tensor], []) + for batch in input_ele: + if len(batch.shape) == 1: + one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) + else: + one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) + out_list.append(one_batch_padded) + out_padded = torch.stack(out_list) + return out_padded + + +def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: + return torch.ceil(lens / stride).int() + + +def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: + assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." + return torch.randn(shape) * np.sqrt(2 / shape[1]) + + +# pylint: disable=redefined-outer-name +def calc_same_padding(kernel_size: int) -> Tuple[int, int]: + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path: str): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load( + file_path, + ) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): + y = y.squeeze(1) + + if torch.min(y) < -1.0: + logger.info("min value is %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("max value is %.3f", torch.max(y)) + + global hann_window # pylint: disable=global-statement + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + return spec + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def wav_to_energy(y, n_fft, hop_length, win_length, center=False): + spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return torch.norm(spec, dim=1, keepdim=True) + + +def name_mel_basis(spec, n_fft, fmax): + n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" + return n_fft_len + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis # pylint: disable=global-statement + mel_basis_key = name_mel_basis(spec, n_fft, fmax) + # pylint: disable=too-many-function-args + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[mel_basis_key], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T_y]` + + Return Shapes: + - spec : :math:`[B,C,T_spec]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + logger.info("min value is %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("max value is %.3f", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + mel_basis_key = name_mel_basis(y, n_fft, fmax) + wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) + if mel_basis_key not in mel_basis: + # pylint: disable=missing-kwoa + mel = librosa_mel_fn( + sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) # pylint: disable=too-many-function-args + mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[mel_basis_key], spec) + spec = amp_to_db(spec) + return spec + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create balancer weight for torch WeightedSampler""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class ForwardTTSE2eF0Dataset(F0Dataset): + """Override F0Dataset to avoid slow computing of pitches""" + + def __init__( + self, + ap, + samples: Union[List[List], List[Dict]], + cache_path: str = None, + precompute_num_workers=0, + normalize_f0=True, + ): + super().__init__( + samples=samples, + ap=ap, + cache_path=cache_path, + precompute_num_workers=precompute_num_workers, + normalize_f0=normalize_f0, + ) + + def _compute_and_save_pitch(self, wav_file, pitch_file=None): + wav, _ = load_audio(wav_file) + f0 = compute_f0( + x=wav.numpy()[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + pitch_fmin=self.ap.pitch_fmin, + win_length=self.ap.win_length, + ) + # skip the last F0 value to align with the spectrogram + if wav.shape[1] % self.ap.hop_length != 0: + f0 = f0[:-1] + if pitch_file: + np.save(pitch_file, f0) + return f0 + + def compute_or_load(self, wav_file, audio_name): + """ + compute pitch and return a numpy array of pitch values + """ + pitch_file = self.create_pitch_file_path(audio_name, self.cache_path) + if not os.path.exists(pitch_file): + pitch = self._compute_and_save_pitch(wav_file=wav_file, pitch_file=pitch_file) + else: + pitch = np.load(pitch_file) + return pitch.astype(np.float32) + + +class ForwardTTSE2eDataset(TTSDataset): + def __init__(self, *args, **kwargs): + # don't init the default F0Dataset in TTSDataset + compute_f0 = kwargs.pop("compute_f0", False) + kwargs["compute_f0"] = False + self.attn_prior_cache_path = kwargs.pop("attn_prior_cache_path") + + super().__init__(*args, **kwargs) + + self.compute_f0 = compute_f0 + self.pad_id = self.tokenizer.characters.pad_id + self.ap = kwargs["ap"] + + if self.compute_f0: + self.f0_dataset = ForwardTTSE2eF0Dataset( + ap=self.ap, + samples=self.samples, + cache_path=kwargs["f0_cache_path"], + precompute_num_workers=kwargs["precompute_num_workers"], + ) + + if self.attn_prior_cache_path is not None: + os.makedirs(self.attn_prior_cache_path, exist_ok=True) + + def __getitem__(self, idx): + item = self.samples[idx] + + rel_wav_path = Path(item["audio_file"]).relative_to(item["root_path"]).with_suffix("") + rel_wav_path = str(rel_wav_path).replace("/", "_") + + raw_text = item["text"] + wav, _ = load_audio(item["audio_file"]) + wav_filename = os.path.basename(item["audio_file"]) + + try: + token_ids = self.get_token_ids(idx, item["text"]) + except: + logger.exception("%s %s", idx, item) + # pylint: disable=raise-missing-from + raise OSError + f0 = None + if self.compute_f0: + f0 = self.get_f0(idx)["f0"] + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + attn_prior = None + if self.attn_prior_cache_path is not None: + attn_prior = self.load_or_compute_attn_prior(token_ids, wav, rel_wav_path) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "pitch": f0, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "attn_prior": attn_prior, + "audio_unique_name": item["audio_unique_name"], + } + + def load_or_compute_attn_prior(self, token_ids, wav, rel_wav_path): + """Load or compute and save the attention prior.""" + attn_prior_file = os.path.join(self.attn_prior_cache_path, f"{rel_wav_path}.npy") + # pylint: disable=no-else-return + if os.path.exists(attn_prior_file): + return np.load(attn_prior_file) + else: + token_len = len(token_ids) + mel_len = wav.shape[1] // self.ap.hop_length + attn_prior = compute_attn_prior(token_len, mel_len) + np.save(attn_prior_file, attn_prior) + return attn_prior + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - pitch :math:`[B, T]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - attn_prior: :math:`[[T_token, T_mel]]` + """ + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + pitch_padded = None + if self.compute_f0: + pitch_lens = [p.shape[0] for p in batch["pitch"]] + pitch_lens = torch.LongTensor(pitch_lens) + pitch_lens_max = torch.max(pitch_lens) + pitch_padded = torch.FloatTensor(B, 1, pitch_lens_max) + pitch_padded = pitch_padded.zero_() + self.pad_id + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + + for i in range(B): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + if self.compute_f0: + pitch = batch["pitch"][i] + pitch_padded[i, 0, : len(pitch)] = torch.FloatTensor(pitch) + + return { + "text_input": token_padded, + "text_lengths": token_lens, + "text_rel_lens": token_rel_lens, + "pitch": pitch_padded, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_unique_names": batch["audio_unique_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "attn_priors": batch["attn_prior"] if batch["attn_prior"][0] is not None else None, + } + + +############################## +# CONFIG DEFINITIONS +############################## + + +@dataclass +class VocoderConfig(Coqpit): + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + use_spectral_norm_discriminator: bool = False + upsampling_rates_discriminator: List[int] = field(default_factory=lambda: [4, 4, 4, 4]) + periods_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + pretrained_model_path: Optional[str] = None + + +@dataclass +class DelightfulTtsAudioConfig(Coqpit): + sample_rate: int = 22050 + hop_length: int = 256 + win_length: int = 1024 + fft_size: int = 1024 + mel_fmin: float = 0.0 + mel_fmax: float = 8000 + num_mels: int = 100 + pitch_fmax: float = 640.0 + pitch_fmin: float = 1.0 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + do_trim_silence: bool = True + trim_db: int = 45 + do_rms_norm: bool = False + db_level: float = None + power: float = 1.5 + griffin_lim_iters: int = 60 + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + min_level_db: int = -100 + max_norm: float = 4.0 + + +@dataclass +class DelightfulTtsArgs(Coqpit): + num_chars: int = 100 + spec_segment_size: int = 32 + n_hidden_conformer_encoder: int = 512 + n_layers_conformer_encoder: int = 6 + n_heads_conformer_encoder: int = 8 + dropout_conformer_encoder: float = 0.1 + kernel_size_conv_mod_conformer_encoder: int = 7 + kernel_size_depthwise_conformer_encoder: int = 7 + lrelu_slope: float = 0.3 + n_hidden_conformer_decoder: int = 512 + n_layers_conformer_decoder: int = 6 + n_heads_conformer_decoder: int = 8 + dropout_conformer_decoder: float = 0.1 + kernel_size_conv_mod_conformer_decoder: int = 11 + kernel_size_depthwise_conformer_decoder: int = 11 + bottleneck_size_p_reference_encoder: int = 4 + bottleneck_size_u_reference_encoder: int = 512 + ref_enc_filters_reference_encoder = [32, 32, 64, 64, 128, 128] + ref_enc_size_reference_encoder: int = 3 + ref_enc_strides_reference_encoder = [1, 2, 1, 2, 1] + ref_enc_pad_reference_encoder = [1, 1] + ref_enc_gru_size_reference_encoder: int = 32 + ref_attention_dropout_reference_encoder: float = 0.2 + token_num_reference_encoder: int = 32 + predictor_kernel_size_reference_encoder: int = 5 + n_hidden_variance_adaptor: int = 512 + kernel_size_variance_adaptor: int = 5 + dropout_variance_adaptor: float = 0.5 + n_bins_variance_adaptor: int = 256 + emb_kernel_size_variance_adaptor: int = 3 + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: str = None + speaker_embedding_channels: int = 384 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + freeze_vocoder: bool = False + freeze_text_encoder: bool = False + freeze_duration_predictor: bool = False + freeze_pitch_predictor: bool = False + freeze_energy_predictor: bool = False + freeze_basis_vectors_predictor: bool = False + freeze_decoder: bool = False + length_scale: float = 1.0 + + +############################## +# MODEL DEFINITION +############################## +class DelightfulTTS(BaseTTSE2E): + """ + Paper:: + https://arxiv.org/pdf/2110.12612.pdf + + Paper Abstract:: + This paper describes the Microsoft end-to-end neural text to speech (TTS) system: DelightfulTTS for Blizzard Challenge 2021. + The goal of this challenge is to synthesize natural and high-quality speech from text, and we approach this goal in two perspectives: + The first is to directly model and generate waveform in 48 kHz sampling rate, which brings higher perception quality than previous systems + with 16 kHz or 24 kHz sampling rate; The second is to model the variation information in speech through a systematic design, which improves + the prosody and naturalness. Specifically, for 48 kHz modeling, we predict 16 kHz mel-spectrogram in acoustic model, and + propose a vocoder called HiFiNet to directly generate 48 kHz waveform from predicted 16 kHz mel-spectrogram, which can better trade off training + efficiency, modelling stability and voice quality. We model variation information systematically from both explicit (speaker ID, language ID, pitch and duration) and + implicit (utterance-level and phoneme-level prosody) perspectives: 1) For speaker and language ID, we use lookup embedding in training and + inference; 2) For pitch and duration, we extract the values from paired text-speech data in training and use two predictors to predict the values in inference; 3) + For utterance-level and phoneme-level prosody, we use two reference encoders to extract the values in training, and use two separate predictors to predict the values in inference. + Additionally, we introduce an improved Conformer block to better model the local and global dependency in acoustic model. For task SH1, DelightfulTTS achieves 4.17 mean score in MOS test + and 4.35 in SMOS test, which indicates the effectiveness of our proposed system + + + Model training:: + text --> ForwardTTS() --> spec_hat --> rand_seg_select()--> GANVocoder() --> waveform_seg + spec --------^ + + Examples: + >>> from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eConfig + >>> config = ForwardTTSE2eConfig() + >>> model = ForwardTTSE2e(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config=config, ap=ap, tokenizer=tokenizer, speaker_manager=speaker_manager) + self.ap = ap + + self._set_model_args(config) + self.init_multispeaker(config) + self.binary_loss_weight = None + + self.args.out_channels = self.config.audio.num_mels + self.args.num_mels = self.config.audio.num_mels + self.acoustic_model = AcousticModel(args=self.args, tokenizer=tokenizer, speaker_manager=speaker_manager) + + self.waveform_decoder = HifiganGenerator( + self.config.audio.num_mels, + 1, + self.config.vocoder.resblock_type_decoder, + self.config.vocoder.resblock_dilation_sizes_decoder, + self.config.vocoder.resblock_kernel_sizes_decoder, + self.config.vocoder.upsample_kernel_sizes_decoder, + self.config.vocoder.upsample_initial_channel_decoder, + self.config.vocoder.upsample_rates_decoder, + inference_padding=0, + # cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.config.init_discriminator: + self.disc = VitsDiscriminator( + use_spectral_norm=self.config.vocoder.use_spectral_norm_discriminator, + periods=self.config.vocoder.periods_discriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + @property + def energy_scaler(self): + return self.acoustic_model.energy_scaler + + @property + def length_scale(self): + return self.acoustic_model.length_scale + + @length_scale.setter + def length_scale(self, value): + self.acoustic_model.length_scale = value + + @property + def pitch_mean(self): + return self.acoustic_model.pitch_mean + + @pitch_mean.setter + def pitch_mean(self, value): + self.acoustic_model.pitch_mean = value + + @property + def pitch_std(self): + return self.acoustic_model.pitch_std + + @pitch_std.setter + def pitch_std(self, value): + self.acoustic_model.pitch_std = value + + @property + def mel_basis(self): + return build_mel_basis( + sample_rate=self.ap.sample_rate, + fft_size=self.ap.fft_size, + num_mels=self.ap.num_mels, + mel_fmax=self.ap.mel_fmax, + mel_fmin=self.ap.mel_fmin, + ) # pylint: disable=function-redefined + + def init_for_training(self) -> None: + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + self.config.steps_to_start_discriminator <= 0 + ) # pylint: disable=attribute-defined-outside-init + self.update_energy_scaler = True # pylint: disable=attribute-defined-outside-init + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + self.args.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + logger.info("Initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.args.embedded_speaker_dim = self.args.speaker_embedding_channels + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + self.args.embedded_speaker_dim = self.args.d_vector_dim + + def _freeze_layers(self): + if self.args.freeze_vocoder: + for param in self.vocoder.paramseters(): + param.requires_grad = False + + if self.args.freeze_text_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_duration_predictor: + for param in self.durarion_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_pitch_predictor: + for param in self.pitch_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_energy_predictor: + for param in self.energy_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_decoder: + for param in self.decoder.parameters(): + param.requires_grad = False + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + spec_lengths: torch.LongTensor, + spec: torch.FloatTensor, + waveform: torch.FloatTensor, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + attn_priors: torch.FloatTensor = None, + d_vectors: torch.FloatTensor = None, + speaker_idx: torch.LongTensor = None, + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + spec_lengths (torch.LongTensor): Spectrogram sequnce lengths. Defaults to None. + spec (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + waveform (torch.FloatTensor): Waveform. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): Spectral energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + attn_priors (torch.FloatTentrasor): Attention priors for the aligner network. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - spec_lengths: :math:`[B]` + - spec: :math:`[B, T_max2, C_spec]` + - waveform: :math:`[B, 1, T_max2 * hop_length]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T_max2]` + - energy: :math:`[B, 1, T_max2]` + """ + encoder_outputs = self.acoustic_model( + tokens=x, + src_lens=x_lengths, + mel_lens=spec_lengths, + mels=spec, + pitches=pitch, + energies=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_idx, + ) + + # use mel-spec from the decoder + vocoder_input = encoder_outputs["model_outputs"] # [B, T_max2, C_mel] + + vocoder_input_slices, slice_ids = rand_segments( + x=vocoder_input.transpose(1, 2), + x_lengths=spec_lengths, + segment_size=self.args.spec_segment_size, + let_short_samples=True, + pad_short=True, + ) + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input_slices.detach(), g=g) + wav_seg = segment( + waveform, + slice_ids * self.ap.hop_length, + self.args.spec_segment_size * self.ap.hop_length, + pad_short=True, + ) + model_outputs = {**encoder_outputs} + model_outputs["acoustic_model_outputs"] = encoder_outputs["model_outputs"] + model_outputs["model_outputs"] = vocoder_output + model_outputs["waveform_seg"] = wav_seg + model_outputs["slice_ids"] = slice_ids + return model_outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"d_vectors": None, "speaker_ids": None}, pitch_transform=None, energy_transform=None + ): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + pitch_transform=pitch_transform, + energy_transform=energy_transform, + p_control=None, + d_control=None, + ) + vocoder_input = encoder_outputs["model_outputs"].transpose(1, 2) # [B, T_max2, C_mel] -> [B, C_mel, T_max2] + if encoder_outputs["spk_emb"] is not None: + g = encoder_outputs["spk_emb"].unsqueeze(-1) + else: + g = None + + vocoder_output = self.waveform_decoder(x=vocoder_input, g=g) + model_outputs = {**encoder_outputs} + model_outputs["model_outputs"] = vocoder_output + return model_outputs + + @torch.no_grad() + def inference_spec_decoder(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): + encoder_outputs = self.acoustic_model.inference( + tokens=x, + d_vectors=aux_input["d_vectors"], + speaker_idx=aux_input["speaker_ids"], + ) + model_outputs = {**encoder_outputs} + return model_outputs + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + if optimizer_idx == 0: + tokens = batch["text_input"] + token_lenghts = batch["text_lengths"] + mel = batch["mel_input"] + mel_lens = batch["mel_lengths"] + waveform = batch["waveform"] # [B, T, C] -> [B, C, T] + pitch = batch["pitch"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_priors = batch["attn_priors"] + energy = batch["energy"] + + # generator pass + outputs = self.forward( + x=tokens, + x_lengths=token_lenghts, + spec_lengths=mel_lens, + spec=mel, + waveform=waveform, + pitch=pitch, + energy=energy, + attn_priors=attn_priors, + d_vectors=d_vectors, + speaker_idx=speaker_ids, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + if self.train_disc: + # compute scores and features + scores_d_fake, _, scores_d_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_fake=scores_d_fake, + scores_disc_real=scores_d_real, + ) + return outputs, loss_dict + return None, None + + if optimizer_idx == 1: + mel = batch["mel_input"] + # compute melspec segment + with autocast(enabled=False): + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], self.args.spec_segment_size, pad_short=True + ) + + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + ) + + scores_d_fake = None + feats_d_fake = None + feats_d_real = None + + if self.train_disc: + # compute discriminator scores and features + scores_d_fake, feats_d_fake, _, feats_d_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=True): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_output=self.model_outputs_cache["acoustic_model_outputs"].transpose(1, 2), + mel_target=batch["mel_input"], + mel_lens=batch["mel_lengths"], + dur_output=self.model_outputs_cache["dr_log_pred"], + dur_target=self.model_outputs_cache["dr_log_target"].detach(), + pitch_output=self.model_outputs_cache["pitch_pred"], + pitch_target=self.model_outputs_cache["pitch_target"], + energy_output=self.model_outputs_cache["energy_pred"], + energy_target=self.model_outputs_cache["energy_target"], + src_lens=batch["text_lengths"], + waveform=self.model_outputs_cache["waveform_seg"], + waveform_hat=self.model_outputs_cache["model_outputs"], + p_prosody_ref=self.model_outputs_cache["p_prosody_ref"], + p_prosody_pred=self.model_outputs_cache["p_prosody_pred"], + u_prosody_ref=self.model_outputs_cache["u_prosody_ref"], + u_prosody_pred=self.model_outputs_cache["u_prosody_pred"], + aligner_logprob=self.model_outputs_cache["aligner_logprob"], + aligner_hard=self.model_outputs_cache["aligner_mas"], + aligner_soft=self.model_outputs_cache["aligner_soft"], + binary_loss_weight=self.binary_loss_weight, + feats_fake=feats_d_fake, + feats_real=feats_d_real, + scores_fake=scores_d_fake, + spec_slice=mel_slice, + spec_slice_hat=mel_slice_hat, + skip_disc=not self.train_disc, + ) + + loss_dict["avg_text_length"] = batch["text_lengths"].float().mean() + loss_dict["avg_mel_length"] = batch["mel_lengths"].float().mean() + loss_dict["avg_text_batch_occupancy"] = ( + batch["text_lengths"].float() / batch["text_lengths"].float().max() + ).mean() + loss_dict["avg_mel_batch_occupancy"] = ( + batch["mel_lengths"].float() / batch["mel_lengths"].float().max() + ).mean() + + return self.model_outputs_cache, loss_dict + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def _log(self, batch, outputs, name_prefix="train"): + figures, audios = {}, {} + + # encoder outputs + model_outputs = outputs[1]["acoustic_model_outputs"] + alignments = outputs[1]["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, None, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec.T, None, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + pitch_avg = abs(outputs[1]["pitch_target"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs[1]["pitch_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + energy_avg = abs(outputs[1]["energy_target"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs[1]["energy_pred"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_pitch(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_pitch(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + alignments_hat = outputs[1]["alignments_dp"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + encoder_audio = mel_to_wav_numpy( + mel=db_to_amp_numpy(x=pred_spec.T, gain=1, base=None), mel_basis=self.mel_basis, **self.config.audio + ) + audios[f"{name_prefix}/encoder_audio"] = encoder_audio + + # vocoder outputs + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + + vocoder_figures = plot_results(y_hat=y_hat, y=y, ap=self.ap, name_prefix=name_prefix) + figures.update(vocoder_figures) + + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios[f"{name_prefix}/vocoder_audio"] = sample_voice + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use, unused-argument + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previous training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. + """ + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(batch=batch, outputs=outputs, name_prefix="vocoder/") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav = None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector = None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector} + + def plot_outputs(self, text, wav, alignment, outputs): + figures = {} + pitch_avg_pred = outputs["pitch"].cpu() + energy_avg_pred = outputs["energy"].cpu() + spec = wav_to_mel( + y=torch.from_numpy(wav[None, :]), + n_fft=self.ap.fft_size, + sample_rate=self.ap.sample_rate, + num_mels=self.ap.num_mels, + hop_length=self.ap.hop_length, + win_length=self.ap.win_length, + fmin=self.ap.mel_fmin, + fmax=self.ap.mel_fmax, + center=False, + )[0].transpose(0, 1) + pitch = compute_f0( + x=wav[0], + sample_rate=self.ap.sample_rate, + hop_length=self.ap.hop_length, + pitch_fmax=self.ap.pitch_fmax, + ) + input_text = self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(text, language="en")) + input_text = input_text.replace("", "_") + durations = outputs["durations"] + pitch_avg = average_over_durations(torch.from_numpy(pitch)[None, None, :], durations.cpu()) # [1, 1, n_frames] + pitch_avg_pred_denorm = (pitch_avg_pred * self.pitch_std) + self.pitch_mean + figures["alignment"] = plot_alignment(alignment.transpose(1, 2), output_fig=False) + figures["spectrogram"] = plot_spectrogram(spec) + figures["pitch_from_wav"] = plot_pitch(pitch, spec) + figures["pitch_avg_from_wav"] = plot_avg_pitch(pitch_avg.squeeze(), input_text) + figures["pitch_avg_pred"] = plot_avg_pitch(pitch_avg_pred_denorm.squeeze(), input_text) + figures["energy_avg_pred"] = plot_avg_pitch(energy_avg_pred.squeeze(), input_text) + return figures + + def synthesize( + self, + text: str, + speaker_id: str = None, + d_vector: torch.tensor = None, + pitch_transform=None, + **kwargs, + ): # pylint: disable=unused-argument + # TODO: add cloning support with ref_waveform + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + + # set speaker inputs + _speaker_id = None + if speaker_id is not None and self.args.use_speaker_embedding: + if isinstance(speaker_id, str) and self.args.use_speaker_embedding: + # get the speaker id for the speaker embedding layer + _speaker_id = self.speaker_manager.name_to_id[speaker_id] + _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) + + if speaker_id is not None and self.args.use_d_vector_file: + # get the average d_vector for the speaker + d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference( + text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": _speaker_id}, + pitch_transform=pitch_transform, + # energy_transform=energy_transform + ) + + # collect outputs + wav = outputs["model_outputs"][0].data.cpu().numpy() + alignments = outputs["alignments"] + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + def synthesize_with_gl(self, text: str, speaker_id, d_vector): + is_cuda = next(self.parameters()).is_cuda + + # convert text to sequence of token IDs + text_inputs = np.asarray( + self.tokenizer.text_to_ids(text, language=None), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, cuda=is_cuda) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + + text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = text_inputs.unsqueeze(0) + + # synthesize voice + outputs = self.inference_spec_decoder( + x=text_inputs, + aux_input={"d_vectors": d_vector, "speaker_ids": speaker_id}, + ) + + # collect outputs + S = outputs["model_outputs"].cpu().numpy()[0].T + S = db_to_amp_numpy(x=S, gain=1, base=None) + wav = mel_to_wav_numpy(mel=S, mel_basis=self.mel_basis, **self.config.audio) + alignments = outputs["alignments"] + return_dict = { + "wav": wav[None, :], + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + logger.info("Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + outputs = self.synthesize( + aux_inputs["text"], + config=self.config, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + outputs_gl = self.synthesize_with_gl( + aux_inputs["text"], + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + ) + # speaker_name = self.speaker_manager.speaker_names[aux_inputs["speaker_id"]] + test_audios["{}-audio".format(idx)] = outputs["wav"].T + test_audios["{}-audio_encoder".format(idx)] = outputs_gl["wav"].T + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.config.audio.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.speaker_names and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + batch["speaker_ids"] = speaker_ids + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + + ac = self.ap + + # compute spectrograms + batch["mel_input"] = wav_to_mel( + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + center=False, + ) + + # TODO: Align pitch properly + # assert ( + # batch["pitch"].shape[2] == batch["mel_input"].shape[2] + # ), f"{batch['pitch'].shape[2]}, {batch['mel_input'].shape[2]}" + batch["pitch"] = batch["pitch"][:, :, : batch["mel_input"].shape[2]] if batch["pitch"] is not None else None + batch["mel_lengths"] = (batch["mel_input"].shape[2] * batch["waveform_rel_lens"]).int() + + # zero the padding frames + batch["mel_input"] = batch["mel_input"] * sequence_mask(batch["mel_lengths"]).unsqueeze(1) + + # format attn priors as we now the max mel length + # TODO: fix 1 diff b/w mel_lengths and attn_priors + + if self.config.use_attn_priors: + attn_priors_np = batch["attn_priors"] + + batch["attn_priors"] = torch.zeros( + batch["mel_input"].shape[0], + batch["mel_lengths"].max(), + batch["text_lengths"].max(), + device=batch["mel_input"].device, + ) + + for i in range(batch["mel_input"].shape[0]): + batch["attn_priors"][i, : attn_priors_np[i].shape[0], : attn_priors_np[i].shape[1]] = torch.from_numpy( + attn_priors_np[i] + ) + + batch["energy"] = None + batch["energy"] = wav_to_energy( # [B, 1, T_max2] + batch["waveform"], + hop_length=ac.hop_length, + win_length=ac.win_length, + n_fft=ac.fft_size, + center=False, + ) + batch["energy"] = self.energy_scaler(batch["energy"]) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + logger.info("Using weighted sampler for attribute '%s' with alpha %.2f", attr_name, alpha) + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + logger.info(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = ForwardTTSE2eDataset( + samples=samples, + ap=self.ap, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + compute_f0=config.compute_f0, + f0_cache_path=config.f0_cache_path, + attn_prior_cache_path=config.attn_prior_cache_path if config.use_attn_priors else None, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences ascendingly by length + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + drop_last=False, # setting this False might cause issues in AMP training. + sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=True, + ) + + # get pitch mean and std + self.pitch_mean = dataset.f0_dataset.mean + self.pitch_std = dataset.f0_dataset.std + return loader + + def get_criterion(self): + return [VitsDiscriminatorLoss(self.config), DelightfulTTSLoss(self.config)] + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + Returns: + List: optimizers. + """ + optimizer_disc = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc + ) + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer_gen = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer_disc, optimizer_gen] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def on_epoch_end(self, trainer): # pylint: disable=unused-argument + # stop updating mean and var + # TODO: do the same for F0 + self.energy_scaler.eval() + + @staticmethod + def init_from_config( + config: "DelightfulTTSConfig", samples: Union[List[List], List[Dict]] = None + ): # pylint: disable=unused-argument + """Initiate model from config + + Args: + config (ForwardTTSE2eConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config.model_args, samples) + ap = AudioProcessor.init_from_config(config=config) + return DelightfulTTS(config=new_config, tokenizer=tokenizer, speaker_manager=speaker_manager, ap=ap) + + def load_checkpoint(self, config, checkpoint_path, eval=False): + """Load model from a checkpoint created by the 👟""" + # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_state_dict(self): + """Custom state dict of the model with all the necessary components for inference.""" + save_state = {"config": self.config.to_dict(), "args": self.args.to_dict(), "model": self.state_dict} + + if hasattr(self, "emb_g"): + save_state["speaker_ids"] = self.speaker_manager.speaker_names + + if self.args.use_d_vector_file: + # TODO: implement saving of d_vectors + ... + return save_state + + def save(self, config, checkpoint_path): + """Save model to a file.""" + save_state = self.get_state_dict(config, checkpoint_path) # pylint: disable=too-many-function-args + save_state["pitch_mean"] = self.pitch_mean + save_state["pitch_std"] = self.pitch_std + torch.save(save_state, checkpoint_path) + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator` + + Args: + trainer (Trainer): Trainer object. + """ + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + self.train_disc = ( # pylint: disable=attribute-defined-outside-init + trainer.total_steps_done >= self.config.steps_to_start_discriminator + ) + + +class DelightfulTTSLoss(nn.Module): + def __init__(self, config): + super().__init__() + + self.mse_loss = nn.MSELoss() + self.mae_loss = nn.L1Loss() + self.forward_sum_loss = ForwardSumLoss() + self.multi_scale_stft_loss = MultiScaleSTFTLoss(**config.multi_scale_stft_loss_params) + + self.mel_loss_alpha = config.mel_loss_alpha + self.aligner_loss_alpha = config.aligner_loss_alpha + self.pitch_loss_alpha = config.pitch_loss_alpha + self.energy_loss_alpha = config.energy_loss_alpha + self.u_prosody_loss_alpha = config.u_prosody_loss_alpha + self.p_prosody_loss_alpha = config.p_prosody_loss_alpha + self.dur_loss_alpha = config.dur_loss_alpha + self.char_dur_loss_alpha = config.char_dur_loss_alpha + self.binary_alignment_loss_alpha = config.binary_align_loss_alpha + + self.vocoder_mel_loss_alpha = config.vocoder_mel_loss_alpha + self.feat_loss_alpha = config.feat_loss_alpha + self.gen_loss_alpha = config.gen_loss_alpha + self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha + + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): + """Binary loss that forces soft alignments to match the hard alignments as + explained in `https://arxiv.org/pdf/2108.10447.pdf`. + """ + log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() + return -log_sum / alignment_hard.sum() + + @staticmethod + def feature_loss(feats_real, feats_generated): + loss = 0 + for dr, dg in zip(feats_real, feats_generated): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + return loss * 2 + + @staticmethod + def generator_loss(scores_fake): + loss = 0 + gen_losses = [] + for dg in scores_fake: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + def forward( + self, + mel_output, + mel_target, + mel_lens, + dur_output, + dur_target, + pitch_output, + pitch_target, + energy_output, + energy_target, + src_lens, + waveform, + waveform_hat, + p_prosody_ref, + p_prosody_pred, + u_prosody_ref, + u_prosody_pred, + aligner_logprob, + aligner_hard, + aligner_soft, + binary_loss_weight=None, + feats_fake=None, + feats_real=None, + scores_fake=None, + spec_slice=None, + spec_slice_hat=None, + skip_disc=False, + ): + """ + Shapes: + - mel_output: :math:`(B, C_mel, T_mel)` + - mel_target: :math:`(B, C_mel, T_mel)` + - mel_lens: :math:`(B)` + - dur_output: :math:`(B, T_src)` + - dur_target: :math:`(B, T_src)` + - pitch_output: :math:`(B, 1, T_src)` + - pitch_target: :math:`(B, 1, T_src)` + - energy_output: :math:`(B, 1, T_src)` + - energy_target: :math:`(B, 1, T_src)` + - src_lens: :math:`(B)` + - waveform: :math:`(B, 1, T_wav)` + - waveform_hat: :math:`(B, 1, T_wav)` + - p_prosody_ref: :math:`(B, T_src, 4)` + - p_prosody_pred: :math:`(B, T_src, 4)` + - u_prosody_ref: :math:`(B, 1, 256) + - u_prosody_pred: :math:`(B, 1, 256) + - aligner_logprob: :math:`(B, 1, T_mel, T_src)` + - aligner_hard: :math:`(B, T_mel, T_src)` + - aligner_soft: :math:`(B, T_mel, T_src)` + - spec_slice: :math:`(B, C_mel, T_mel)` + - spec_slice_hat: :math:`(B, C_mel, T_mel)` + """ + loss_dict = {} + src_mask = sequence_mask(src_lens).to(mel_output.device) # (B, T_src) + mel_mask = sequence_mask(mel_lens).to(mel_output.device) # (B, T_mel) + + dur_target.requires_grad = False + mel_target.requires_grad = False + pitch_target.requires_grad = False + + masked_mel_predictions = mel_output.masked_select(mel_mask[:, None]) + mel_targets = mel_target.masked_select(mel_mask[:, None]) + mel_loss = self.mae_loss(masked_mel_predictions, mel_targets) + + p_prosody_ref = p_prosody_ref.detach() + p_prosody_loss = 0.5 * self.mae_loss( + p_prosody_ref.masked_select(src_mask.unsqueeze(-1)), + p_prosody_pred.masked_select(src_mask.unsqueeze(-1)), + ) + + u_prosody_ref = u_prosody_ref.detach() + u_prosody_loss = 0.5 * self.mae_loss(u_prosody_ref, u_prosody_pred) + + duration_loss = self.mse_loss(dur_output, dur_target) + + pitch_output = pitch_output.masked_select(src_mask[:, None]) + pitch_target = pitch_target.masked_select(src_mask[:, None]) + pitch_loss = self.mse_loss(pitch_output, pitch_target) + + energy_output = energy_output.masked_select(src_mask[:, None]) + energy_target = energy_target.masked_select(src_mask[:, None]) + energy_loss = self.mse_loss(energy_output, energy_target) + + forward_sum_loss = self.forward_sum_loss(aligner_logprob, src_lens, mel_lens) + + total_loss = ( + (mel_loss * self.mel_loss_alpha) + + (duration_loss * self.dur_loss_alpha) + + (u_prosody_loss * self.u_prosody_loss_alpha) + + (p_prosody_loss * self.p_prosody_loss_alpha) + + (pitch_loss * self.pitch_loss_alpha) + + (energy_loss * self.energy_loss_alpha) + + (forward_sum_loss * self.aligner_loss_alpha) + ) + + if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: + binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) + total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + if binary_loss_weight: + loss_dict["loss_binary_alignment"] = ( + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight + ) + else: + loss_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss + + loss_dict["loss_aligner"] = self.aligner_loss_alpha * forward_sum_loss + loss_dict["loss_mel"] = self.mel_loss_alpha * mel_loss + loss_dict["loss_duration"] = self.dur_loss_alpha * duration_loss + loss_dict["loss_u_prosody"] = self.u_prosody_loss_alpha * u_prosody_loss + loss_dict["loss_p_prosody"] = self.p_prosody_loss_alpha * p_prosody_loss + loss_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss + loss_dict["loss_energy"] = self.energy_loss_alpha * energy_loss + loss_dict["loss"] = total_loss + + # vocoder losses + if not skip_disc: + loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_dict["vocoder_loss_feat"] = loss_feat + loss_dict["vocoder_loss_gen"] = loss_gen + loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen + + loss_mel = torch.nn.functional.l1_loss(spec_slice, spec_slice_hat) * self.vocoder_mel_loss_alpha + loss_stft_mg, loss_stft_sc = self.multi_scale_stft_loss(y_hat=waveform_hat, y=waveform) + loss_stft_mg = loss_stft_mg * self.multi_scale_stft_loss_alpha + loss_stft_sc = loss_stft_sc * self.multi_scale_stft_loss_alpha + + loss_dict["vocoder_loss_mel"] = loss_mel + loss_dict["vocoder_loss_stft_mg"] = loss_stft_mg + loss_dict["vocoder_loss_stft_sc"] = loss_stft_sc + + loss_dict["loss"] = loss_dict["loss"] + loss_mel + loss_stft_sc + loss_stft_mg + return loss_dict diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..4b74462dd57b5cf2fbd01a29cb34b9f816756693 --- /dev/null +++ b/TTS/tts/models/forward_tts.py @@ -0,0 +1,865 @@ +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.io import load_fsspec + +from TTS.tts.layers.feed_forward.decoder import Decoder +from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork +from TTS.tts.layers.generic.pos_encoding import PositionalEncoding +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram + +logger = logging.getLogger(__name__) + + +@dataclass +class ForwardTTSArgs(Coqpit): + """ForwardTTS Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + use_aligner (bool): + Whether to use aligner network to learn the text to speech alignment or use pre-computed durations. + If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the + pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True. + + use_pitch (bool): + Use pitch predictor to learn the pitch. Defaults to True. + + use_energy (bool): + Use energy predictor to learn the energy. Defaults to True. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + energy_predictor_hidden_channels (int): + Number of hidden channels in the energy predictor. Defaults to 256. + + energy_predictor_dropout_p (float): + Dropout rate for the energy predictor. Defaults to 0.1. + + energy_predictor_kernel_size (int): + Kernel size of conv layers in the energy predictor. Defaults to 3. + + energy_embedding_kernel_size (int): + Kernel size of the projection layer in the energy predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + """ + + num_chars: int = None + out_channels: int = 80 + hidden_channels: int = 384 + use_aligner: bool = True + # pitch params + use_pitch: bool = True + pitch_predictor_hidden_channels: int = 256 + pitch_predictor_kernel_size: int = 3 + pitch_predictor_dropout_p: float = 0.1 + pitch_embedding_kernel_size: int = 3 + + # energy params + use_energy: bool = False + energy_predictor_hidden_channels: int = 256 + energy_predictor_kernel_size: int = 3 + energy_predictor_dropout_p: float = 0.1 + energy_embedding_kernel_size: int = 3 + + # duration params + duration_predictor_hidden_channels: int = 256 + duration_predictor_kernel_size: int = 3 + duration_predictor_dropout_p: float = 0.1 + + positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True + length_scale: int = 1 + encoder_type: str = "fftransformer" + encoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + decoder_type: str = "fftransformer" + decoder_params: dict = field( + default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1} + ) + detach_duration_predictor: bool = False + max_duration: int = 75 + num_speakers: int = 1 + use_speaker_embedding: bool = False + speakers_file: str = None + use_d_vector_file: bool = False + d_vector_dim: int = None + d_vector_file: str = None + + +class ForwardTTS(BaseTTS): + """General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment + network and a pitch predictor. + + If the alignment network is used, the model learns the text-to-speech alignment + from the data instead of using pre-computed durations. + + If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each + input character as in the FastPitch model. + + `ForwardTTS` can be configured to one of these architectures, + + - FastPitch + - SpeedySpeech + - FastSpeech + - FastSpeech2 (requires average speech energy predictor) + + Args: + config (Coqpit): Model coqpit class. + speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models. + Defaults to None. + + Examples: + >>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs + >>> config = ForwardTTSArgs() + >>> model = ForwardTTS(config) + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + self._set_model_args(config) + + self.init_multispeaker(config) + + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_pitch = self.args.use_pitch + self.use_energy = self.args.use_energy + self.binary_loss_weight = 0.0 + + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale + ) + + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.embedded_speaker_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) + + self.decoder = Decoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, + ) + + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, + ) + + if self.args.use_pitch: + self.pitch_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, + ) + self.pitch_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_energy: + self.energy_predictor = DurationPredictor( + self.args.hidden_channels, + self.args.energy_predictor_hidden_channels, + self.args.energy_predictor_kernel_size, + self.args.energy_predictor_dropout_p, + ) + self.energy_emb = nn.Conv1d( + 1, + self.args.hidden_channels, + kernel_size=self.args.energy_embedding_kernel_size, + padding=int((self.args.energy_embedding_kernel_size - 1) / 2), + ) + + if self.args.use_aligner: + self.aligner = AlignmentNetwork( + in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels + ) + + def init_multispeaker(self, config: Coqpit): + """Init for multi-speaker training. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # init speaker manager + if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding): + raise ValueError( + " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model." + ) + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # init d-vector embedding + if config.use_d_vector_file: + self.embedded_speaker_dim = config.d_vector_dim + if self.args.d_vector_dim != self.args.hidden_channels: + # self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) + self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + logger.info("Init speaker_embedding layer.") + self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + + @staticmethod + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + """Generate attention alignment map from durations and + expand encoder outputs + + Shapes: + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + def _forward_encoder( + self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Sum encoder outputs and speaker embeddings + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ + if hasattr(self, "emb_g"): + g = g.type(torch.LongTensor) + g = self.emb_g(g) # [B, C, 1] + if g is not None: + g = g.unsqueeze(-1) + # [B, T, C] + x_emb = self.emb(x) + # encoder pass + # o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) + o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g) + # speaker conditioning + # TODO: try different ways of conditioning + if g is not None: + if hasattr(self, "proj_g"): + g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1) + o_en = o_en + g + return o_en, x_mask, g, x_emb + + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + # positional encoding + if hasattr(self, "pos_encoder"): + o_en_ex = self.pos_encoder(o_en_ex, y_mask) + # decoder pass + o_de = self.decoder(o_en_ex, y_mask, g=g) + return o_de.transpose(1, 2), attn.transpose(1, 2) + + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_pitch = self.pitch_predictor(o_en, x_mask) + if pitch is not None: + avg_pitch = average_over_durations(pitch, dr) + o_pitch_emb = self.pitch_emb(avg_pitch) + return o_pitch_emb, o_pitch, avg_pitch + o_pitch_emb = self.pitch_emb(o_pitch) + return o_pitch_emb, o_pitch + + def _forward_energy_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + energy: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Energy predictor forward pass. + + 1. Predict energy from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average energy values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ + o_energy = self.energy_predictor(o_en, x_mask) + if energy is not None: + avg_energy = average_over_durations(energy, dr) + o_energy_emb = self.energy_emb(avg_energy) + return o_energy_emb, o_energy, avg_energy + o_energy_emb = self.energy_emb(o_energy) + return o_energy_emb, o_energy + + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) + alignment_mas = maximum_path( + alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() + ) + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward( + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + energy: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None. + energy (torch.FloatTensor): energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None. + aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + + Shapes: + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` + """ + g = self._set_speaker_input(aux_input) + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float() + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float() + # encoder pass + o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner + o_alignment_dur = None + alignment_soft = None + alignment_logprob = None + alignment_mas = None + if self.use_aligner: + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) + alignment_soft = alignment_soft.transpose(1, 2) + alignment_mas = alignment_mas.transpose(1, 2) + dr = o_alignment_dur + # pitch predictor pass + o_pitch = None + avg_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + avg_energy = None + if self.args.use_energy: + o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr) + o_en = o_en + o_energy_emb + # decoder pass + o_de, attn = self._forward_decoder( + o_en, dr, x_mask, y_lengths, g=None + ) # TODO: maybe pass speaker embedding (g) too + outputs = { + "model_outputs": o_de, # [B, T, C] + "durations_log": o_dr_log.squeeze(1), # [B, T] + "durations": o_dr.squeeze(1), # [B, T] + "attn_durations": o_attn, # for visualization [B, T_en, T_de'] + "pitch_avg": o_pitch, + "pitch_avg_gt": avg_pitch, + "energy_avg": o_energy, + "energy_avg_gt": avg_energy, + "alignments": attn, # [B, T_de, T_en] + "alignment_soft": alignment_soft, + "alignment_mas": alignment_mas, + "o_alignment_dur": o_alignment_dur, + "alignment_logprob": alignment_logprob, + "x_mask": x_mask, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + + Shapes: + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] + """ + g = self._set_speaker_input(aux_input) + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) + # duration predictor pass + o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) + + # pitch predictor pass + o_pitch = None + if self.args.use_pitch: + o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask) + o_en = o_en + o_pitch_emb + # energy predictor pass + o_energy = None + if self.args.use_energy: + o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask) + o_en = o_en + o_energy_emb + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None) + outputs = { + "model_outputs": o_de, + "alignments": attn, + "pitch": o_pitch, + "energy": o_energy, + "durations_log": o_dr_log, + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + pitch = batch["pitch"] if self.args.use_pitch else None + energy = batch["energy"] if self.args.use_energy else None + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + durations = batch["durations"] + aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + + # forward pass + outputs = self.forward( + text_input, + text_lengths, + mel_lengths, + y=mel_input, + dr=durations, + pitch=pitch, + energy=energy, + aux_input=aux_input, + ) + # use aligner's output as the duration target + if self.use_aligner: + durations = outputs["o_alignment_dur"] + # use float32 in AMP + with autocast(enabled=False): + # compute loss + loss_dict = criterion( + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch_avg"] if self.use_pitch else None, + pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None, + energy_output=outputs["energy_avg"] if self.use_energy else None, + energy_target=outputs["energy_avg_gt"] if self.use_energy else None, + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None, + alignment_soft=outputs["alignment_soft"], + alignment_hard=outputs["alignment_mas"], + binary_loss_weight=self.binary_loss_weight, + ) + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error + + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + """Create common logger outputs.""" + model_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + mel_input = batch["mel_input"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # plot pitch figures + if self.args.use_pitch: + pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy()) + pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + pitch_figures = { + "pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False), + "pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False), + } + figures.update(pitch_figures) + + # plot energy figures + if self.args.use_energy: + energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy()) + energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy()) + chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy()) + energy_figures = { + "energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False), + "energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False), + } + figures.update(energy_figures) + + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def get_criterion(self): + from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel + + return ForwardTTSLoss(self.config) + + def on_train_step_start(self, trainer): + """Schedule binary loss weight.""" + self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0 + + @staticmethod + def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (ForwardTTSConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return ForwardTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..64954d283ccc3763b6e37c257e440a280d269665 --- /dev/null +++ b/TTS/tts/models/glow_tts.py @@ -0,0 +1,559 @@ +import logging +import math +from typing import Dict, List, Tuple, Union + +import torch +from coqpit import Coqpit +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from trainer.io import load_fsspec + +from TTS.tts.configs.glow_tts_config import GlowTTSConfig +from TTS.tts.layers.glow_tts.decoder import Decoder +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram + +logger = logging.getLogger(__name__) + + +class GlowTTS(BaseTTS): + """GlowTTS model. + + Paper:: + https://arxiv.org/abs/2005.11129 + + Paper abstract:: + Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate + mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained + without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS, + a flow-based generative model for parallel TTS that does not require any external aligner. By combining the + properties of flows and dynamic programming, the proposed model searches for the most probable monotonic + alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard + monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows + enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over + the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our + model can be easily extended to a multi-speaker setting. + + Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments. + + Examples: + Init only model layers. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig(num_chars=2) + >>> model = GlowTTS(config) + + Fully init a model ready for action. All the class attributes and class members + (e.g Tokenizer, AudioProcessor, etc.). are initialized internally based on config values. + + >>> from TTS.tts.configs.glow_tts_config import GlowTTSConfig + >>> from TTS.tts.models.glow_tts import GlowTTS + >>> config = GlowTTSConfig() + >>> model = GlowTTS.init_from_config(config) + """ + + def __init__( + self, + config: GlowTTSConfig, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + # init multi-speaker layers if necessary + self.init_multispeaker(config) + + self.run_data_dep_init = config.data_dep_init_steps > 0 + self.encoder = Encoder( + self.num_chars, + out_channels=self.out_channels, + hidden_channels=self.hidden_channels_enc, + hidden_channels_dp=self.hidden_channels_dp, + encoder_type=self.encoder_type, + encoder_params=self.encoder_params, + mean_only=self.mean_only, + use_prenet=self.use_encoder_prenet, + dropout_p_dp=self.dropout_p_dp, + c_in_channels=self.c_in_channels, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + def init_multispeaker(self, config: Coqpit): + """Init speaker embedding layer if `use_speaker_embedding` is True and set the expected speaker embedding + vector dimension to the encoder layer channel size. If model uses d-vectors, then it only sets + speaker embedding vector dimension to the d-vector dimension from the config. + + Args: + config (Coqpit): Model configuration. + """ + self.embedded_speaker_dim = 0 + # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + # set ultimate speaker embedding size + if config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + if self.speaker_manager is not None: + assert ( + config.d_vector_dim == self.speaker_manager.embedding_dim + ), " [!] d-vector dimension mismatch b/w config and speaker manager." + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + logger.info("Init speaker_embedding layer.") + self.embedded_speaker_dim = self.hidden_channels_enc + self.emb_g = nn.Embedding(self.num_speakers, self.hidden_channels_enc) + nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) + # set conditioning dimensions + self.c_in_channels = self.embedded_speaker_dim + + @staticmethod + def compute_outputs(attn, o_mean, o_log_scale, x_mask): + """Compute and format the mode outputs with the given alignment map""" + y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + # compute total duration with adjustment + o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask + return y_mean, y_log_scale, o_attn_dur + + def unlock_act_norm_layers(self): + """Unlock activation normalization layers for data depended initalization.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + + def lock_act_norm_layers(self): + """Lock activation normalization layers.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + + def _set_speaker_input(self, aux_input: Dict): + if aux_input is None: + d_vectors = None + speaker_ids = None + else: + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def _speaker_embedding(self, aux_input: Dict) -> Union[torch.tensor, None]: + g = self._set_speaker_input(aux_input) + # speaker embedding + if g is not None: + if hasattr(self, "emb_g"): + # use speaker embedding layer + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] + else: + # use d-vector + g = F.normalize(g).unsqueeze(-1) # [b, h, 1] + return g + + def forward( + self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Args: + x (torch.Tensor): + Input text sequence ids. :math:`[B, T_en]` + + x_lengths (torch.Tensor): + Lengths of input text sequences. :math:`[B]` + + y (torch.Tensor): + Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` + + y_lengths (torch.Tensor): + Lengths of target mel-spectrogram frames. :math:`[B]` + + aux_input (Dict): + Auxiliary inputs. `d_vectors` is speaker embedding vectors for a multi-speaker model. + :math:`[B, D_vec]`. `speaker_ids` is speaker ids for a multi-speaker model usind speaker-embedding + layer. :math:`B` + + Returns: + Dict: + - z: :math: `[B, T_de, C]` + - logdet: :math:`B` + - y_mean: :math:`[B, T_de, C]` + - y_log_scale: :math:`[B, T_de, C]` + - alignments: :math:`[B, T_en, T_de]` + - durations_log: :math:`[B, T_en, 1]` + - total_durations_log: :math:`[B, T_en, 1]` + """ + # [B, T, C] -> [B, C, T] + y = y.transpose(1, 2) + y_max_length = y.size(2) + # norm speaker embeddings + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # drop redisual frames wrt num_squeeze and set y_lengths. + y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + # [B, 1, T_en, T_de] + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path + with torch.no_grad(): + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + outputs = { + "z": z.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + @torch.no_grad() + def inference_with_MAS( + self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + It's similar to the teacher forcing in Tacotron. + It was proposed in: https://arxiv.org/abs/2104.05557 + + Shapes: + - x: :math:`[B, T]` + - x_lenghts: :math:`B` + - y: :math:`[B, T, C]` + - y_lengths: :math:`B` + - g: :math:`[B, C] or B` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + # norm speaker embeddings + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # drop redisual frames wrt num_squeeze and set y_lengths. + y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) + # create masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # find the alignment path between z and encoder output + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + attn = attn.squeeze(1).permute(0, 2, 1) + + # get predited aligned distribution + z = y_mean * y_mask + + # reverse the decoder and predict using the aligned distribution + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = { + "model_outputs": z.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + @torch.no_grad() + def decoder_inference( + self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + """ + Shapes: + - y: :math:`[B, T, C]` + - y_lengths: :math:`B` + - g: :math:`[B, C] or B` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + g = self._speaker_embedding(aux_input) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype) + # decoder pass + z, logdet = self.decoder(y, y_mask, g=g, reverse=False) + # reverse decoder and predict + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + outputs = {} + outputs["model_outputs"] = y.transpose(1, 2) + outputs["logdet"] = logdet + return outputs + + @torch.no_grad() + def inference( + self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None} + ): # pylint: disable=dangerous-default-value + x_lengths = aux_input["x_lengths"] + g = self._speaker_embedding(aux_input) + # embedding pass + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) + # compute output durations + w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale + w_ceil = torch.clamp_min(torch.ceil(w), 1) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = None + # compute masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + # compute attention mask + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) + + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask + # decoder pass + y, logdet = self.decoder(z, y_mask, g=g, reverse=True) + attn = attn.squeeze(1).permute(0, 2, 1) + outputs = { + "model_outputs": y.transpose(1, 2), + "logdet": logdet, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), + "alignments": attn, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), + } + return outputs + + def train_step(self, batch: dict, criterion: nn.Module): + """A single training step. Forward pass and loss computation. Run data depended initialization for the + first `config.data_dep_init_steps` steps. + + Args: + batch (dict): [description] + criterion (nn.Module): [description] + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + if self.run_data_dep_init and self.training: + # compute data-dependent initialization of activation norm layers + self.unlock_act_norm_layers() + with torch.no_grad(): + _ = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + outputs = None + loss_dict = None + self.lock_act_norm_layers() + else: + # normal training step + outputs = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + + with autocast(enabled=False): # avoid mixed_precision in criterion + loss_dict = criterion( + outputs["z"].float(), + outputs["y_mean"].float(), + outputs["y_log_scale"].float(), + outputs["logdet"].float(), + mel_lengths, + outputs["durations_log"].float(), + outputs["total_durations_log"].float(), + text_lengths, + ) + return outputs, loss_dict + + def _create_logs(self, batch, outputs, ap): + alignments = outputs["alignments"] + text_input = batch["text_input"][:1] if batch["text_input"] is not None else None + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + d_vectors = batch["d_vectors"][:1] if batch["d_vectors"] is not None else None + speaker_ids = batch["speaker_ids"][:1] if batch["speaker_ids"] is not None else None + + # model runs reverse flow to predict spectrograms + pred_outputs = self.inference( + text_input, + aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + model_outputs = pred_outputs["model_outputs"] + + pred_spec = model_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": train_audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + logger.info("Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + if len(test_sentences) == 0: + logger.warning("No test sentences provided.") + else: + for idx, sen in enumerate(test_sentences): + outputs = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + + test_audios["{}-audio".format(idx)] = outputs["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + return test_figures, test_audios + + def preprocess(self, y, y_lengths, y_max_length, attn=None): + if y_max_length is not None: + y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze + y = y[:, :, :y_max_length] + if attn is not None: + attn = attn[:, :, :, :y_max_length] + y_lengths = torch.div(y_lengths, self.num_squeeze, rounding_mode="floor") * self.num_squeeze + return y, y_lengths, y_max_length, attn + + def store_inverse(self): + self.decoder.store_inverse() + + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + self.store_inverse() + assert not self.training + + @staticmethod + def get_criterion(): + from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel + + return GlowTTSLoss() + + def on_train_step_start(self, trainer): + """Decide on every training step wheter enable/disable data depended initialization.""" + self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps + + @staticmethod + def init_from_config(config: "GlowTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return GlowTTS(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..de5401aac705c67ca6c303fbb6dc6789a10653f9 --- /dev/null +++ b/TTS/tts/models/neuralhmm_tts.py @@ -0,0 +1,393 @@ +import logging +import os +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.io import load_fsspec +from trainer.logging.tensorboard_logger import TensorboardLogger + +from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils +from TTS.tts.layers.overflow.neural_hmm import NeuralHMM +from TTS.tts.layers.overflow.plotting_utils import ( + get_spec_from_most_probable_state, + plot_transition_probabilities_to_numpy, +) +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + + +class NeuralhmmTTS(BaseTTS): + """Neural HMM TTS model. + + Paper:: + https://arxiv.org/abs/2108.13320 + + Paper abstract:: + Neural sequence-to-sequence TTS has achieved significantly better output quality + than statistical speech synthesis using HMMs.However, neural TTS is generally not probabilistic + and uses non-monotonic attention. Attention failures increase training time and can make + synthesis babble incoherently. This paper describes how the old and new paradigms can be + combined to obtain the advantages of both worlds, by replacing attention in neural TTS with + an autoregressive left-right no-skip hidden Markov model defined by a neural network. + Based on this proposal, we modify Tacotron 2 to obtain an HMM-based neural TTS model with + monotonic alignment, trained to maximise the full sequence likelihood without approximation. + We also describe how to combine ideas from classical and contemporary TTS for best results. + The resulting example system is smaller and simpler than Tacotron 2, and learns to speak with + fewer iterations and less data, whilst achieving comparable naturalness prior to the post-net. + Our approach also allows easy control over speaking rate. Audio examples and code + are available at https://shivammehta25.github.io/Neural-HMM/ . + + Note: + - This is a parameter efficient version of OverFlow (15.3M vs 28.6M). Since it has half the + number of parameters as OverFlow the synthesis output quality is suboptimal (but comparable to Tacotron2 + without Postnet), but it learns to speak with even lesser amount of data and is still significantly faster + than other attention-based methods. + + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.neuralhmm_tts_config.NeuralhmmTTSConfig` for class arguments. + """ + + def __init__( + self, + config: "NeuralhmmTTSConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. + + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, mels.transpose(1, 2), mel_len + ) + + outputs = { + "log_probs": log_probs, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict = default_input_dict.copy() + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(default_input_dict, aux_input) + return default_input_dict + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. + """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + mels, mel_outputs_len = outputs["hmm_outputs"], outputs["hmm_outputs_len"] + + mels = self.inverse_normalize(mels) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "NeuralhmmTTSConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return NeuralhmmTTS(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, + ) + statistics = torch.load( + trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4() + ) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0]), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + logger.info("Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) + + +class NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py new file mode 100644 index 0000000000000000000000000000000000000000..b72f4877cf8363879350f2dac0fecfd1dfe72698 --- /dev/null +++ b/TTS/tts/models/overflow.py @@ -0,0 +1,409 @@ +import logging +import os +from typing import Dict, List, Union + +import torch +from coqpit import Coqpit +from torch import nn +from trainer.io import load_fsspec +from trainer.logging.tensorboard_logger import TensorboardLogger + +from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils +from TTS.tts.layers.overflow.decoder import Decoder +from TTS.tts.layers.overflow.neural_hmm import NeuralHMM +from TTS.tts.layers.overflow.plotting_utils import ( + get_spec_from_most_probable_state, + plot_transition_probabilities_to_numpy, +) +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.generic_utils import format_aux_input, is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + + +class Overflow(BaseTTS): + """OverFlow TTS model. + + Paper:: + https://arxiv.org/abs/2211.06892 + + Paper abstract:: + Neural HMMs are a type of neural transducer recently proposed for + sequence-to-sequence modelling in text-to-speech. They combine the best features + of classic statistical speech synthesis and modern neural TTS, requiring less + data and fewer training updates, and are less prone to gibberish output caused + by neural attention failures. In this paper, we combine neural HMM TTS with + normalising flows for describing the highly non-Gaussian distribution of speech + acoustics. The result is a powerful, fully probabilistic model of durations and + acoustics that can be trained using exact maximum likelihood. Compared to + dominant flow-based acoustic models, our approach integrates autoregression for + improved modelling of long-range dependences such as utterance-level prosody. + Experiments show that a system based on our proposal gives more accurate + pronunciations and better subjective speech quality than comparable methods, + whilst retaining the original advantages of neural HMMs. Audio examples and code + are available at https://shivammehta25.github.io/OverFlow/. + + Note: + - Neural HMMs uses flat start initialization i.e it computes the means and std and transition probabilities + of the dataset and uses them to initialize the model. This benefits the model and helps with faster learning + If you change the dataset or want to regenerate the parameters change the `force_generate_statistics` and + `mel_statistics_parameter_path` accordingly. + + - To enable multi-GPU training, set the `use_grad_checkpointing=False` in config. + This will significantly increase the memory usage. This is because to compute + the actual data likelihood (not an approximation using MAS/Viterbi) we must use + all the states at the previous time step during the forward pass to decide the + probability distribution at the current step i.e the difference between the forward + algorithm and viterbi approximation. + + Check :class:`TTS.tts.configs.overflow.OverFlowConfig` for class arguments. + """ + + def __init__( + self, + config: "OverFlowConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + self.config = config + for key in config: + setattr(self, key, config[key]) + + self.decoder_output_dim = config.out_channels + + self.encoder = Encoder(config.num_chars, config.state_per_phone, config.encoder_in_out_features) + self.neural_hmm = NeuralHMM( + frame_channels=self.out_channels, + ar_order=self.ar_order, + deterministic_transition=self.deterministic_transition, + encoder_dim=self.encoder_in_out_features, + prenet_type=self.prenet_type, + prenet_dim=self.prenet_dim, + prenet_n_layers=self.prenet_n_layers, + prenet_dropout=self.prenet_dropout, + prenet_dropout_at_inference=self.prenet_dropout_at_inference, + memory_rnn_dim=self.memory_rnn_dim, + outputnet_size=self.outputnet_size, + flat_start_params=self.flat_start_params, + std_floor=self.std_floor, + use_grad_checkpointing=self.use_grad_checkpointing, + ) + + self.decoder = Decoder( + self.out_channels, + self.hidden_channels_dec, + self.kernel_size_dec, + self.dilation_rate, + self.num_flow_blocks_dec, + self.num_block_layers, + dropout_p=self.dropout_p_dec, + num_splits=self.num_splits, + num_squeeze=self.num_squeeze, + sigmoid_scale=self.sigmoid_scale, + c_in_channels=self.c_in_channels, + ) + + self.register_buffer("mean", torch.tensor(0)) + self.register_buffer("std", torch.tensor(1)) + + def update_mean_std(self, statistics_dict: Dict): + self.mean.data = torch.tensor(statistics_dict["mean"]) + self.std.data = torch.tensor(statistics_dict["std"]) + + def preprocess_batch(self, text, text_len, mels, mel_len): + if self.mean.item() == 0 or self.std.item() == 1: + statistics_dict = torch.load(self.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4()) + self.update_mean_std(statistics_dict) + + mels = self.normalize(mels) + return text, text_len, mels, mel_len + + def normalize(self, x): + return x.sub(self.mean).div(self.std) + + def inverse_normalize(self, x): + return x.mul(self.std).add(self.mean) + + def forward(self, text, text_len, mels, mel_len): + """ + Forward pass for training and computing the log likelihood of a given batch. + + Shapes: + Shapes: + text: :math:`[B, T_in]` + text_len: :math:`[B]` + mels: :math:`[B, T_out, C]` + mel_len: :math:`[B]` + """ + text, text_len, mels, mel_len = self.preprocess_batch(text, text_len, mels, mel_len) + encoder_outputs, encoder_output_len = self.encoder(text, text_len) + z, z_lengths, logdet = self.decoder(mels.transpose(1, 2), mel_len) + log_probs, fwd_alignments, transition_vectors, means = self.neural_hmm( + encoder_outputs, encoder_output_len, z, z_lengths + ) + + outputs = { + "log_probs": log_probs + logdet, + "alignments": fwd_alignments, + "transition_vectors": transition_vectors, + "means": means, + } + + return outputs + + @staticmethod + def _training_stats(batch): + stats = {} + stats["avg_text_length"] = batch["text_lengths"].float().mean() + stats["avg_spec_length"] = batch["mel_lengths"].float().mean() + stats["avg_text_batch_occupancy"] = (batch["text_lengths"].float() / batch["text_lengths"].float().max()).mean() + stats["avg_spec_batch_occupancy"] = (batch["mel_lengths"].float() / batch["mel_lengths"].float().max()).mean() + return stats + + def train_step(self, batch: dict, criterion: nn.Module): + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + + outputs = self.forward( + text=text_input, + text_len=text_lengths, + mels=mel_input, + mel_len=mel_lengths, + ) + loss_dict = criterion(outputs["log_probs"] / (mel_lengths.sum() + text_lengths.sum())) + + # for printing useful statistics on terminal + loss_dict.update(self._training_stats(batch)) + return outputs, loss_dict + + def eval_step(self, batch: Dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def _format_aux_input(self, aux_input: Dict, default_input_dict): + """Set missing fields to their default value. + + Args: + aux_inputs (Dict): Dictionary containing the auxiliary inputs. + """ + default_input_dict = default_input_dict.copy() + default_input_dict.update( + { + "sampling_temp": self.sampling_temp, + "max_sampling_time": self.max_sampling_time, + "duration_threshold": self.duration_threshold, + } + ) + if aux_input: + return format_aux_input(default_input_dict, aux_input) + return default_input_dict + + @torch.no_grad() + def inference( + self, + text: torch.Tensor, + aux_input={"x_lengths": None, "sampling_temp": None, "max_sampling_time": None, "duration_threshold": None}, + ): # pylint: disable=dangerous-default-value + """Sampling from the model + + Args: + text (torch.Tensor): :math:`[B, T_in]` + aux_inputs (_type_, optional): _description_. Defaults to None. + + Returns: + outputs: Dictionary containing the following + - mel (torch.Tensor): :math:`[B, T_out, C]` + - hmm_outputs_len (torch.Tensor): :math:`[B]` + - state_travelled (List[List[int]]): List of lists containing the state travelled for each sample in the batch. + - input_parameters (list[torch.FloatTensor]): Input parameters to the neural HMM. + - output_parameters (list[torch.FloatTensor]): Output parameters to the neural HMM. + """ + default_input_dict = { + "x_lengths": torch.sum(text != 0, dim=1), + } + aux_input = self._format_aux_input(aux_input, default_input_dict) + encoder_outputs, encoder_output_len = self.encoder.inference(text, aux_input["x_lengths"]) + outputs = self.neural_hmm.inference( + encoder_outputs, + encoder_output_len, + sampling_temp=aux_input["sampling_temp"], + max_sampling_time=aux_input["max_sampling_time"], + duration_threshold=aux_input["duration_threshold"], + ) + + mels, mel_outputs_len, _ = self.decoder( + outputs["hmm_outputs"].transpose(1, 2), outputs["hmm_outputs_len"], reverse=True + ) + mels = self.inverse_normalize(mels.transpose(1, 2)) + outputs.update({"model_outputs": mels, "model_outputs_len": mel_outputs_len}) + outputs["alignments"] = OverflowUtils.double_pad(outputs["alignments"]) + return outputs + + @staticmethod + def get_criterion(): + return NLLLoss() + + @staticmethod + def init_from_config(config: "OverFlowConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Overflow(new_config, ap, tokenizer, speaker_manager) + + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if eval: + self.eval() + self.decoder.store_inverse() + assert not self.training + + def on_init_start(self, trainer): + """If the current dataset does not have normalisation statistics and initialisation transition_probability it computes them otherwise loads.""" + if not os.path.isfile(trainer.config.mel_statistics_parameter_path) or trainer.config.force_generate_statistics: + dataloader = trainer.get_train_dataloader( + training_assets=None, samples=trainer.train_samples, verbose=False + ) + logger.info( + "Data parameters not found for: %s. Computing mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, + ) + data_mean, data_std, init_transition_prob = OverflowUtils.get_data_parameters_for_flat_start( + dataloader, trainer.config.out_channels, trainer.config.state_per_phone + ) + logger.info( + "Saving data parameters to: %s: value: %s", + trainer.config.mel_statistics_parameter_path, + (data_mean, data_std, init_transition_prob), + ) + statistics = { + "mean": data_mean.item(), + "std": data_std.item(), + "init_transition_prob": init_transition_prob.item(), + } + torch.save(statistics, trainer.config.mel_statistics_parameter_path) + + else: + logger.info( + "Data parameters found for: %s. Loading mel normalization parameters...", + trainer.config.mel_statistics_parameter_path, + ) + statistics = torch.load( + trainer.config.mel_statistics_parameter_path, weights_only=is_pytorch_at_least_2_4() + ) + data_mean, data_std, init_transition_prob = ( + statistics["mean"], + statistics["std"], + statistics["init_transition_prob"], + ) + logger.info("Data parameters loaded with value: %s", (data_mean, data_std, init_transition_prob)) + + trainer.config.flat_start_params["transition_p"] = ( + init_transition_prob.item() if torch.is_tensor(init_transition_prob) else init_transition_prob + ) + OverflowUtils.update_flat_start_transition(trainer.model, init_transition_prob) + trainer.model.update_mean_std(statistics) + + @torch.inference_mode() + def _create_logs(self, batch, outputs, ap): # pylint: disable=no-self-use, unused-argument + alignments, transition_vectors = outputs["alignments"], outputs["transition_vectors"] + means = torch.stack(outputs["means"], dim=1) + + figures = { + "alignment": plot_alignment(alignments[0].exp(), title="Forward alignment", fig_size=(20, 20)), + "log_alignment": plot_alignment( + alignments[0].exp(), title="Forward log alignment", plot_log=True, fig_size=(20, 20) + ), + "transition_vectors": plot_alignment(transition_vectors[0], title="Transition vectors", fig_size=(20, 20)), + "mel_from_most_probable_state": plot_spectrogram( + get_spec_from_most_probable_state(alignments[0], means[0], self.decoder), fig_size=(12, 3) + ), + "mel_target": plot_spectrogram(batch["mel_input"][0], fig_size=(12, 3)), + } + + # sample one item from the batch -1 will give the smalles item + logger.info("Synthesising audio from the model...") + inference_output = self.inference( + batch["text_input"][-1].unsqueeze(0), aux_input={"x_lengths": batch["text_lengths"][-1].unsqueeze(0)} + ) + figures["synthesised"] = plot_spectrogram(inference_output["model_outputs"][0], fig_size=(12, 3)) + + states = [p[1] for p in inference_output["input_parameters"][0]] + transition_probability_synthesising = [p[2].cpu().numpy() for p in inference_output["output_parameters"][0]] + + for i in range((len(transition_probability_synthesising) // 200) + 1): + start = i * 200 + end = (i + 1) * 200 + figures[f"synthesised_transition_probabilities/{i}"] = plot_transition_probabilities_to_numpy( + states[start:end], transition_probability_synthesising[start:end] + ) + + audio = ap.inv_melspectrogram(inference_output["model_outputs"][0].T.cpu().numpy()) + return figures, {"audios": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=unused-argument + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int + ): # pylint: disable=unused-argument + """Compute and log evaluation metrics.""" + # Plot model parameters histograms + if isinstance(logger, TensorboardLogger): + # I don't know if any other loggers supports this + for tag, value in self.named_parameters(): + tag = tag.replace(".", "/") + logger.writer.add_histogram(tag, value.data.cpu().numpy(), steps) + + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs[1], self.ap.sample_rate) + logger.test_figures(steps, outputs[0]) + + +class NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. + + Args: + logits (Tensor): [B, T, D] + + Returns: + Tensor: [1] + + """ + return_dict = {} + return_dict["loss"] = -log_prob.mean() + return return_dict diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py new file mode 100644 index 0000000000000000000000000000000000000000..400a86d042aabd00efc44f0fb44c4f6551c1fc74 --- /dev/null +++ b/TTS/tts/models/tacotron.py @@ -0,0 +1,413 @@ +# coding: utf-8 + +from typing import Dict, List, Tuple, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron(BaseTacotron): + """Tacotron as in https://arxiv.org/abs/1703.10135 + It's an autoregressive encoder-attention-decoder-postnet architecture. + Check `TacotronConfig` for the arguments. + + Args: + config (TacotronConfig): Configuration for the Tacotron model. + speaker_manager (SpeakerManager): Speaker manager to handle multi-speaker settings. Only use if the model is + a multi-speaker model. Defaults to None. + """ + + def __init__( + self, + config: "TacotronConfig", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # set speaker embedding channel size for determining `in_channels` for the connected layers. + # `init_multispeaker` needs to be called once more in training to initialize the speaker embedding layer based + # on the number of speakers infered from the dataset. + if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0) + self.embedding.weight.data.normal_(0, 0.3) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = PostCBHG(self.decoder_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=( + self.embedded_speaker_dim + if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding + else None + ), + text_summary_embedding_dim=( + self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None + ), + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.memory_size, + self.attention_type, + self.windowing, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """ + Shapes: + text: [B, T_in] + text_lengths: [B] + mel_specs: [B, T_out, C] + mel_lengths: [B] + aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C] + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + inputs = self.embedding(text) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x T_in x encoder_in_features + encoder_outputs = self.encoder(inputs) + # sequence masking + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + # global style token + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + # speaker embedding + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + # Capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=( + [inputs, text_lengths] if self.capacitron_vae.capacitron_use_text_summary_embeddings else None + ), + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + # decoder_outputs: B x decoder_in_features x T_out + # alignments: B x T_in x encoder_in_features + # stop_tokens: B x T_in + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if output_mask is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x T_out x decoder_in_features + postnet_outputs = self.postnet(decoder_outputs) + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs) + # B x T_out x posnet_dim + postnet_outputs = self.last_linear(postnet_outputs) + # B x T_out x decoder_in_features + decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text_input, aux_input=None): + aux_input = self._format_aux_input(aux_input) + inputs = self.embedding(text_input) + encoder_outputs = self.encoder(inputs) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=( + [aux_input["style_mel"], reference_mel_length] if aux_input["style_mel"] is not None else None + ), + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=( + aux_input["d_vectors"] if self.capacitron_vae.capacitron_use_speaker_embedding else None + ), + ) + if self.num_speakers > 1: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"]) + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = self.last_linear(postnet_outputs) + decoder_outputs = decoder_outputs.transpose(1, 2) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]: + """Perform a single training step by fetching the right set of samples from the batch. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([torch.nn.Module]): Callable criterion to compute model loss. + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + linear_input = batch["linear_input"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + linear_input.float(), + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + postnet_outputs = outputs["model_outputs"] + decoder_outputs = outputs["decoder_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + linear_input = batch["linear_input"] + + pred_linear_spec = postnet_outputs[0].data.cpu().numpy() + pred_mel_spec = decoder_outputs[0].data.cpu().numpy() + gt_linear_spec = linear_input[0].data.cpu().numpy() + gt_mel_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "pred_linear_spec": plot_spectrogram(pred_linear_spec, ap, output_fig=False), + "real_linear_spec": plot_spectrogram(gt_linear_spec, ap, output_fig=False), + "pred_mel_spec": plot_spectrogram(pred_mel_spec, ap, output_fig=False), + "real_mel_spec": plot_spectrogram(gt_mel_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_spectrogram(pred_linear_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "TacotronConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (TacotronConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + return Tacotron(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..4b1317f4405e4b759696d10dbbcb860024f350d4 --- /dev/null +++ b/TTS/tts/models/tacotron2.py @@ -0,0 +1,437 @@ +# coding: utf-8 + +from typing import Dict, List, Union + +import torch +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE +from TTS.tts.layers.tacotron.gst_layers import GST +from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet +from TTS.tts.models.base_tacotron import BaseTacotron +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.capacitron_optimizer import CapacitronOptimizer + + +class Tacotron2(BaseTacotron): + """Tacotron2 model implementation inherited from :class:`TTS.tts.models.base_tacotron.BaseTacotron`. + + Paper:: + https://arxiv.org/abs/1712.05884 + + Paper abstract:: + This paper describes Tacotron 2, a neural network architecture for speech synthesis directly from text. + The system is composed of a recurrent sequence-to-sequence feature prediction network that maps character + embeddings to mel-scale spectrograms, followed by a modified WaveNet model acting as a vocoder to synthesize + timedomain waveforms from those spectrograms. Our model achieves a mean opinion score (MOS) of 4.53 comparable + to a MOS of 4.58 for professionally recorded speech. To validate our design choices, we present ablation + studies of key components of our system and evaluate the impact of using mel spectrograms as the input to + WaveNet instead of linguistic, duration, and F0 features. We further demonstrate that using a compact acoustic + intermediate representation enables significant simplification of the WaveNet architecture. + + Check :class:`TTS.tts.configs.tacotron2_config.Tacotron2Config` for model arguments. + + Args: + config (TacotronConfig): + Configuration for the Tacotron2 model. + speaker_manager (SpeakerManager): + Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None. + """ + + def __init__( + self, + config: "Tacotron2Config", + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager) + + self.decoder_output_dim = config.out_channels + + # pass all config fields to `self` + # for fewer code change + for key in config: + setattr(self, key, config[key]) + + # init multi-speaker layers + if self.use_speaker_embedding or self.use_d_vector_file: + self.init_multispeaker(config) + self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim + + if self.use_gst: + self.decoder_in_features += self.gst.gst_embedding_dim + + if self.use_capacitron_vae: + self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim + + # embedding layer + self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0) + + # base model layers + self.encoder = Encoder(self.encoder_in_features) + + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + self.postnet = Postnet(self.out_channels) + + # setup prenet dropout + self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference + + # global style token layers + if self.gst and self.use_gst: + self.gst_layer = GST( + num_mel=self.decoder_output_dim, + num_heads=self.gst.gst_num_heads, + num_style_tokens=self.gst.gst_num_style_tokens, + gst_embedding_dim=self.gst.gst_embedding_dim, + ) + + # Capacitron VAE Layers + if self.capacitron_vae and self.use_capacitron_vae: + self.capacitron_vae_layer = CapacitronVAE( + num_mel=self.decoder_output_dim, + encoder_output_dim=self.encoder_in_features, + capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim, + speaker_embedding_dim=( + self.embedded_speaker_dim if self.capacitron_vae.capacitron_use_speaker_embedding else None + ), + text_summary_embedding_dim=( + self.capacitron_vae.capacitron_text_summary_embedding_dim + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None + ), + ) + + # backward pass decoder + if self.bidirectional_decoder: + self._init_backward_decoder() + # setup DDC + if self.double_decoder_consistency: + self.coarse_decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + self.ddc_r, + self.attention_type, + self.attention_win, + self.attention_norm, + self.prenet_type, + self.prenet_dropout, + self.use_forward_attn, + self.transition_agent, + self.forward_attn_mask, + self.location_attn, + self.attention_heads, + self.separate_stopnet, + self.max_decoder_steps, + ) + + @staticmethod + def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): + """Final reshape of the model output tensors.""" + mel_outputs = mel_outputs.transpose(1, 2) + mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) + return mel_outputs, mel_outputs_postnet, alignments + + def forward( # pylint: disable=dangerous-default-value + self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input={"speaker_ids": None, "d_vectors": None} + ): + """Forward pass for training with Teacher Forcing. + + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + mel_specs: :math:`[B, T_out, C]` + mel_lengths: :math:`[B]` + aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]` + """ + aux_input = self._format_aux_input(aux_input) + outputs = {"alignments_backward": None, "decoder_outputs_backward": None} + # compute mask for padding + # B x T_in_max (boolean) + input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) + # B x D_embed x T_in_max + embedded_inputs = self.embedding(text).transpose(1, 2) + # B x T_in_max x D_en + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs) + + if self.use_speaker_embedding or self.use_d_vector_file: + if not self.use_d_vector_file: + # B x 1 x speaker_embed_dim + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None] + else: + # B x 1 x speaker_embed_dim + embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1) + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + # capacitron + if self.capacitron_vae and self.use_capacitron_vae: + # B x capacitron_VAE_embedding_dim + encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=[mel_specs, mel_lengths], + text_info=( + [embedded_inputs.transpose(1, 2), text_lengths] + if self.capacitron_vae.capacitron_use_text_summary_embeddings + else None + ), + speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None, + ) + else: + capacitron_vae_outputs = None + + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + + # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + # sequence masking + if mel_lengths is not None: + decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + # B x mel_dim x T_out + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + # sequence masking + if output_mask is not None: + postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) + # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + if self.bidirectional_decoder: + decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + if self.double_decoder_consistency: + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + outputs["alignments_backward"] = alignments_backward + outputs["decoder_outputs_backward"] = decoder_outputs_backward + outputs.update( + { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + "capacitron_vae_outputs": capacitron_vae_outputs, + } + ) + return outputs + + @torch.no_grad() + def inference(self, text, aux_input=None): + """Forward pass for inference with no Teacher-Forcing. + + Shapes: + text: :math:`[B, T_in]` + text_lengths: :math:`[B]` + """ + aux_input = self._format_aux_input(aux_input) + embedded_inputs = self.embedding(text).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + + if self.gst and self.use_gst: + # B x gst_dim + encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"]) + + if self.capacitron_vae and self.use_capacitron_vae: + if aux_input["style_text"] is not None: + style_text_embedding = self.embedding(aux_input["style_text"]) + style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to( + encoder_outputs.device + ) # pylint: disable=not-callable + reference_mel_length = ( + torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device) + if aux_input["style_mel"] is not None + else None + ) # pylint: disable=not-callable + # B x capacitron_VAE_embedding_dim + encoder_outputs, *_ = self.compute_capacitron_VAE_embedding( + encoder_outputs, + reference_mel_info=( + [aux_input["style_mel"], reference_mel_length] if aux_input["style_mel"] is not None else None + ), + text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None, + speaker_embedding=( + aux_input["d_vectors"] if self.capacitron_vae.capacitron_use_speaker_embedding else None + ), + ) + + if self.num_speakers > 1: + if not self.use_d_vector_file: + embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None] + # reshape embedded_speakers + if embedded_speakers.ndim == 1: + embedded_speakers = embedded_speakers[None, None, :] + elif embedded_speakers.ndim == 2: + embedded_speakers = embedded_speakers[None, :] + else: + embedded_speakers = aux_input["d_vectors"] + + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers) + + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + postnet_outputs = self.postnet(decoder_outputs) + postnet_outputs = decoder_outputs + postnet_outputs + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + outputs = { + "model_outputs": postnet_outputs, + "decoder_outputs": decoder_outputs, + "alignments": alignments, + "stop_tokens": stop_tokens, + } + return outputs + + def before_backward_pass(self, loss_dict, optimizer) -> None: + # Extracting custom training specific operations for capacitron + # from the trainer + if self.use_capacitron_vae: + loss_dict["capacitron_vae_beta_loss"].backward() + optimizer.first_step() + + def train_step(self, batch: Dict, criterion: torch.nn.Module): + """A single training step. Forward pass and loss computation. + + Args: + batch ([Dict]): A dictionary of input tensors. + criterion ([type]): Callable criterion to compute model loss. + """ + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] + mel_input = batch["mel_input"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] + speaker_ids = batch["speaker_ids"] + d_vectors = batch["d_vectors"] + + aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors} + outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input) + + # set the [alignment] lengths wrt reduction factor for guided attention + if mel_lengths.max() % self.decoder.r != 0: + alignment_lengths = ( + mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r)) + ) // self.decoder.r + else: + alignment_lengths = mel_lengths // self.decoder.r + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion( + outputs["model_outputs"].float(), + outputs["decoder_outputs"].float(), + mel_input.float(), + None, + outputs["stop_tokens"].float(), + stop_targets.float(), + stop_target_lengths, + outputs["capacitron_vae_outputs"] if self.capacitron_vae else None, + mel_lengths, + None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(), + outputs["alignments"].float(), + alignment_lengths, + None if outputs["alignments_backward"] is None else outputs["alignments_backward"].float(), + text_lengths, + ) + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(outputs["alignments"]) + loss_dict["align_error"] = align_error + return outputs, loss_dict + + def get_optimizer(self) -> List: + if self.use_capacitron_vae: + return CapacitronOptimizer(self.config, self.named_parameters()) + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer: object): + opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt) + + def before_gradient_clipping(self): + if self.use_capacitron_vae: + # Capacitron model specific gradient clipping + model_params_to_clip = [] + for name, param in self.named_parameters(): + if param.requires_grad: + if name != "capacitron_vae_layer.beta": + model_params_to_clip.append(param) + torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip) + + def _create_logs(self, batch, outputs, ap): + """Create dashboard log information.""" + postnet_outputs = outputs["model_outputs"] + alignments = outputs["alignments"] + alignments_backward = outputs["alignments_backward"] + mel_input = batch["mel_input"] + + pred_spec = postnet_outputs[0].data.cpu().numpy() + gt_spec = mel_input[0].data.cpu().numpy() + align_img = alignments[0].data.cpu().numpy() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), + } + + if self.bidirectional_decoder or self.double_decoder_consistency: + figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + + # Sample audio + audio = ap.inv_melspectrogram(pred_spec.T) + return figures, {"audio": audio} + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + """Log training progress.""" + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + def eval_step(self, batch: dict, criterion: nn.Module): + return self.train_step(batch, criterion) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._create_logs(batch, outputs, self.ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def init_from_config(config: "Tacotron2Config", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (Tacotron2Config): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(new_config, samples) + return Tacotron2(new_config, ap, tokenizer, speaker_manager) diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py new file mode 100644 index 0000000000000000000000000000000000000000..01629b5d2a29f810cdba009ff7f1bef26b1cd3f5 --- /dev/null +++ b/TTS/tts/models/tortoise.py @@ -0,0 +1,924 @@ +import logging +import os +import random +from contextlib import contextmanager +from dataclasses import dataclass +from time import time + +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit +from tqdm import tqdm + +from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram +from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, load_voice, wav_to_univnet_mel +from TTS.tts.layers.tortoise.autoregressive import UnifiedVoice +from TTS.tts.layers.tortoise.classifier import AudioMiniEncoderWithClassifierHead +from TTS.tts.layers.tortoise.clvp import CLVP +from TTS.tts.layers.tortoise.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps +from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts +from TTS.tts.layers.tortoise.random_latent_generator import RandomLatentConverter +from TTS.tts.layers.tortoise.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.tortoise.vocoder import VocConf, VocType +from TTS.tts.layers.tortoise.wav2vec_alignment import Wav2VecAlignment +from TTS.tts.models.base_tts import BaseTTS +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +def deterministic_state(seed=None): + """ + Sets the random seeds that tortoise uses to the current time() and returns that seed so results can be + reproduced. + """ + seed = int(time()) if seed is None else seed + torch.manual_seed(seed) + random.seed(seed) + # Can't currently set this because of CUBLAS. TODO: potentially enable it if necessary. + # torch.use_deterministic_algorithms(True) + + return seed + + +def load_discrete_vocoder_diffuser( + trained_diffusion_steps=4000, + desired_diffusion_steps=200, + cond_free=True, + cond_free_k=1, + sampler="ddim", +): + """ + Helper function to load a GaussianDiffusion instance configured for use as a vocoder. + """ + return SpacedDiffusion( + use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]), + model_mean_type="epsilon", + model_var_type="learned_range", + loss_type="mse", + betas=get_named_beta_schedule("linear", trained_diffusion_steps), + conditioning_free=cond_free, + conditioning_free_k=cond_free_k, + sampler=sampler, + ) + + +def format_conditioning(clip, cond_length=132300, device="cuda", **kwargs): + """ + Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. + """ + gap = clip.shape[-1] - cond_length + if gap < 0: + clip = F.pad(clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + clip = clip[:, rand_start : rand_start + cond_length] + mel_clip = TorchMelSpectrogram(**kwargs)(clip.unsqueeze(0)).squeeze(0) + return mel_clip.unsqueeze(0).to(device) + + +def fix_autoregressive_output(codes, stop_token, complain=True): + """ + This function performs some padding on coded audio that fixes a mismatch issue between what the diffusion model was + trained on and what the autoregressive code generator creates (which has no padding or end). + This is highly specific to the DVAE being used, so this particular coding will not necessarily work if used with + a different DVAE. This can be inferred by feeding a audio clip padded with lots of zeros on the end through the DVAE + and copying out the last few codes. + + Failing to do this padding will produce speech with a harsh end that sounds like "BLAH" or similar. + """ + # Strip off the autoregressive stop token and add padding. + stop_token_indices = (codes == stop_token).nonzero() + if len(stop_token_indices) == 0: + if complain: + logger.warning( + "No stop tokens found in one of the generated voice clips. This typically means the spoken audio is " + "too long. In some cases, the output will still be good, though. Listen to it and if it is missing words, " + "try breaking up your input text." + ) + return codes + codes[stop_token_indices] = 83 + stm = stop_token_indices.min().item() + codes[stm:] = 83 + if stm - 3 < codes.shape[0]: + codes[-3] = 45 + codes[-2] = 45 + codes[-1] = 248 + return codes + + +def do_spectrogram_diffusion( + diffusion_model, + diffuser, + latents, + conditioning_latents, + temperature=1, + verbose=True, +): + """ + Uses the specified diffusion model to convert discrete codes into a spectrogram. + """ + with torch.no_grad(): + output_seq_len = ( + latents.shape[1] * 4 * 24000 // 22050 + ) # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal. + output_shape = (latents.shape[0], 100, output_seq_len) + precomputed_embeddings = diffusion_model.timestep_independent( + latents, conditioning_latents, output_seq_len, False + ) + + noise = torch.randn(output_shape, device=latents.device) * temperature + mel = diffuser.sample_loop( + diffusion_model, + output_shape, + noise=noise, + model_kwargs={"precomputed_aligned_embeddings": precomputed_embeddings}, + progress=verbose, + ) + return denormalize_tacotron_mel(mel)[:, :, :output_seq_len] + + +def classify_audio_clip(clip, model_dir): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. + """ + classifier = AudioMiniEncoderWithClassifierHead( + 2, + spec_dim=1, + embedding_dim=512, + depth=5, + downsample_factor=4, + resnet_blocks=2, + attn_blocks=4, + num_attn_heads=4, + base_channels=32, + dropout=0, + kernel_size=5, + distribute_zero_label=False, + ) + classifier.load_state_dict( + torch.load( + os.path.join(model_dir, "classifier.pth"), + map_location=torch.device("cpu"), + weights_only=is_pytorch_at_least_2_4(), + ) + ) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + +def pick_best_batch_size_for_gpu(): + """ + Tries to pick a batch size that will fit in your GPU. These sizes aren't guaranteed to work, but they should give + you a good shot. + """ + if torch.cuda.is_available(): + _, available = torch.cuda.mem_get_info() + availableGb = available / (1024**3) + batch_size = 1 + if availableGb > 14: + batch_size = 16 + elif availableGb > 10: + batch_size = 8 + elif availableGb > 7: + batch_size = 4 + return batch_size + + +@dataclass +class TortoiseAudioConfig(Coqpit): + sample_rate: int = 22050 + diffusion_sample_rate: int = 24000 + output_sample_rate: int = 24000 + + +@dataclass +class TortoiseArgs(Coqpit): + """A dataclass to represent Tortoise model arguments that define the model structure. + + Args: + autoregressive_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + high_vram (bool, optional): Whether to use high VRAM. Defaults to False. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + ar_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + diff_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet. + + For UnifiedVoice model: + ar_max_mel_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + ar_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + ar_max_conditioning_inputs (int, optional): The maximum conditioning inputs for the autoregressive model. Defaults to 2. + ar_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + ar_model_dim (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + ar_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + ar_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + ar_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + ar_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + ar_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + + For DiffTTS model: + diff_model_channels (int, optional): The number of channels for the DiffTTS model. Defaults to 1024. + diff_num_layers (int, optional): The number of layers for the DiffTTS model. Defaults to 10. + diff_in_channels (int, optional): The input channels for the DiffTTS model. Defaults to 100. + diff_out_channels (int, optional): The output channels for the DiffTTS model. Defaults to 200. + diff_in_latent_channels (int, optional): The input latent channels for the DiffTTS model. Defaults to 1024. + diff_in_tokens (int, optional): The input tokens for the DiffTTS model. Defaults to 8193. + diff_dropout (int, optional): The dropout percentage for the DiffTTS model. Defaults to 0. + diff_use_fp16 (bool, optional): Whether to use fp16 for the DiffTTS model. Defaults to False. + diff_num_heads (int, optional): The number of heads for the DiffTTS model. Defaults to 16. + diff_layer_drop (int, optional): The layer dropout percentage for the DiffTTS model. Defaults to 0. + diff_unconditioned_percentage (int, optional): The percentage of unconditioned inputs for the DiffTTS model. Defaults to 0. + + For ConditionalLatentVariablePerseq model: + clvp_dim_text (int): The dimension of the text input for the CLVP module. Defaults to 768. + clvp_dim_speech (int): The dimension of the speech input for the CLVP module. Defaults to 768. + clvp_dim_latent (int): The dimension of the latent representation for the CLVP module. Defaults to 768. + clvp_num_text_tokens (int): The number of text tokens used by the CLVP module. Defaults to 256. + clvp_text_enc_depth (int): The depth of the text encoder in the CLVP module. Defaults to 20. + clvp_text_seq_len (int): The maximum sequence length of the text input for the CLVP module. Defaults to 350. + clvp_text_heads (int): The number of attention heads used by the text encoder in the CLVP module. Defaults to 12. + clvp_num_speech_tokens (int): The number of speech tokens used by the CLVP module. Defaults to 8192. + clvp_speech_enc_depth (int): The depth of the speech encoder in the CLVP module. Defaults to 20. + clvp_speech_heads (int): The number of attention heads used by the speech encoder in the CLVP module. Defaults to 12. + clvp_speech_seq_len (int): The maximum sequence length of the speech input for the CLVP module. Defaults to 430. + clvp_use_xformers (bool): A flag indicating whether the model uses transformers in the CLVP module. Defaults to True. + duration_const (int): A constant value used in the model. Defaults to 102400. + """ + + autoregressive_batch_size: int = 1 + enable_redaction: bool = False + high_vram: bool = False + kv_cache: bool = True + ar_checkpoint: str = None + clvp_checkpoint: str = None + diff_checkpoint: str = None + num_chars: int = 255 + vocoder: VocType = VocConf.Univnet + + # UnifiedVoice params + ar_max_mel_tokens: int = 604 + ar_max_text_tokens: int = 402 + ar_max_conditioning_inputs: int = 2 + ar_layers: int = 30 + ar_model_dim: int = 1024 + ar_heads: int = 16 + ar_number_text_tokens: int = 255 + ar_start_text_token: int = 255 + ar_checkpointing: bool = False + ar_train_solo_embeddings: bool = False + + # DiffTTS params + diff_model_channels: int = 1024 + diff_num_layers: int = 10 + diff_in_channels: int = 100 + diff_out_channels: int = 200 + diff_in_latent_channels: int = 1024 + diff_in_tokens: int = 8193 + diff_dropout: int = 0 + diff_use_fp16: bool = False + diff_num_heads: int = 16 + diff_layer_drop: int = 0 + diff_unconditioned_percentage: int = 0 + + # clvp params + clvp_dim_text: int = 768 + clvp_dim_speech: int = 768 + clvp_dim_latent: int = 768 + clvp_num_text_tokens: int = 256 + clvp_text_enc_depth: int = 20 + clvp_text_seq_len: int = 350 + clvp_text_heads: int = 12 + clvp_num_speech_tokens: int = 8192 + clvp_speech_enc_depth: int = 20 + clvp_speech_heads: int = 12 + clvp_speech_seq_len: int = 430 + clvp_use_xformers: bool = True + # constants + duration_const: int = 102400 + + +class Tortoise(BaseTTS): + """Tortoise model class. + + Currently only supports inference. + + Examples: + >>> from TTS.tts.configs.tortoise_config import TortoiseConfig + >>> from TTS.tts.models.tortoise import Tortoise + >>> config = TortoiseConfig() + >>> model = Tortoise.inif_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.mel_norm_path = None + self.config = config + self.ar_checkpoint = self.args.ar_checkpoint + self.diff_checkpoint = self.args.diff_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.autoregressive_batch_size = ( + pick_best_batch_size_for_gpu() + if self.args.autoregressive_batch_size is None + else self.args.autoregressive_batch_size + ) + self.enable_redaction = self.args.enable_redaction + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.enable_redaction: + self.aligner = Wav2VecAlignment() + + self.tokenizer = VoiceBpeTokenizer() + + self.autoregressive = UnifiedVoice( + max_mel_tokens=self.args.ar_max_mel_tokens, + max_text_tokens=self.args.ar_max_text_tokens, + max_conditioning_inputs=self.args.ar_max_conditioning_inputs, + layers=self.args.ar_layers, + model_dim=self.args.ar_model_dim, + heads=self.args.ar_heads, + number_text_tokens=self.args.ar_number_text_tokens, + start_text_token=self.args.ar_start_text_token, + checkpointing=self.args.ar_checkpointing, + train_solo_embeddings=self.args.ar_train_solo_embeddings, + ).cpu() + + self.diffusion = DiffusionTts( + model_channels=self.args.diff_model_channels, + num_layers=self.args.diff_num_layers, + in_channels=self.args.diff_in_channels, + out_channels=self.args.diff_out_channels, + in_latent_channels=self.args.diff_in_latent_channels, + in_tokens=self.args.diff_in_tokens, + dropout=self.args.diff_dropout, + use_fp16=self.args.diff_use_fp16, + num_heads=self.args.diff_num_heads, + layer_drop=self.args.diff_layer_drop, + unconditioned_percentage=self.args.diff_unconditioned_percentage, + ).cpu() + + self.clvp = CLVP( + dim_text=self.args.clvp_dim_text, + dim_speech=self.args.clvp_dim_speech, + dim_latent=self.args.clvp_dim_latent, + num_text_tokens=self.args.clvp_num_text_tokens, + text_enc_depth=self.args.clvp_text_enc_depth, + text_seq_len=self.args.clvp_text_seq_len, + text_heads=self.args.clvp_text_heads, + num_speech_tokens=self.args.clvp_num_speech_tokens, + speech_enc_depth=self.args.clvp_speech_enc_depth, + speech_heads=self.args.clvp_speech_heads, + speech_seq_len=self.args.clvp_speech_seq_len, + use_xformers=self.args.clvp_use_xformers, + ).cpu() + + self.vocoder = self.args.vocoder.value.constructor().cpu() + + # Random latent generators (RLGs) are loaded lazily. + self.rlg_auto = None + self.rlg_diffusion = None + + if self.args.high_vram: + self.autoregressive = self.autoregressive.to(self.device) + self.diffusion = self.diffusion.to(self.device) + self.clvp = self.clvp.to(self.device) + self.vocoder = self.vocoder.to(self.device) + self.high_vram = self.args.high_vram + + @contextmanager + def temporary_cuda(self, model): + if self.high_vram: + yield model + else: + m = model.to(self.device) + yield m + m = model.cpu() + + def get_conditioning_latents( + self, + voice_samples, + return_mels=False, + latent_averaging_mode=0, + original_tortoise=False, + ): + """ + Transforms one or more voice_samples into a tuple (autoregressive_conditioning_latent, diffusion_conditioning_latent). + These are expressive learned latents that encode aspects of the provided clips like voice, intonation, and acoustic + properties. + :param voice_samples: List of arbitrary reference clips, which should be *pairs* of torch tensors containing arbitrary kHz waveform data. + :param latent_averaging_mode: 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + """ + assert latent_averaging_mode in [ + 0, + 1, + 2, + ], "latent_averaging mode has to be one of (0, 1, 2)" + + with torch.no_grad(): + voice_samples = [[v.to(self.device) for v in ls] for ls in voice_samples] + + auto_conds = [] + for ls in voice_samples: + auto_conds.append(format_conditioning(ls[0], device=self.device, mel_norm_file=self.mel_norm_path)) + auto_conds = torch.stack(auto_conds, dim=1) + with self.temporary_cuda(self.autoregressive) as ar: + auto_latent = ar.get_conditioning(auto_conds) + + diffusion_conds = [] + + DURS_CONST = self.args.duration_const + for ls in voice_samples: + # The diffuser operates at a sample rate of 24000 (except for the latent inputs) + sample = torchaudio.functional.resample(ls[0], 22050, 24000) if original_tortoise else ls[1] + if latent_averaging_mode == 0: + sample = pad_or_truncate(sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + sample.to(self.device), + do_normalization=False, + device=self.device, + ) + diffusion_conds.append(cond_mel) + else: + from math import ceil + + if latent_averaging_mode == 2: + temp_diffusion_conds = [] + for chunk in range(ceil(sample.shape[1] / DURS_CONST)): + current_sample = sample[:, chunk * DURS_CONST : (chunk + 1) * DURS_CONST] + current_sample = pad_or_truncate(current_sample, DURS_CONST) + cond_mel = wav_to_univnet_mel( + current_sample.to(self.device), + do_normalization=False, + device=self.device, + ) + if latent_averaging_mode == 1: + diffusion_conds.append(cond_mel) + elif latent_averaging_mode == 2: + temp_diffusion_conds.append(cond_mel) + if latent_averaging_mode == 2: + diffusion_conds.append(torch.stack(temp_diffusion_conds).mean(0)) + diffusion_conds = torch.stack(diffusion_conds, dim=1) + + with self.temporary_cuda(self.diffusion) as diffusion: + diffusion_latent = diffusion.get_conditioning(diffusion_conds) + + if return_mels: + return auto_latent, diffusion_latent, auto_conds, diffusion_conds + return auto_latent, diffusion_latent + + def get_random_conditioning_latents(self): + # Lazy-load the RLG models. + if self.rlg_auto is None: + self.rlg_auto = RandomLatentConverter(1024).eval() + self.rlg_auto.load_state_dict( + torch.load( + os.path.join(self.models_dir, "rlg_auto.pth"), + map_location=torch.device("cpu"), + weights_only=is_pytorch_at_least_2_4(), + ) + ) + self.rlg_diffusion = RandomLatentConverter(2048).eval() + self.rlg_diffusion.load_state_dict( + torch.load( + os.path.join(self.models_dir, "rlg_diffuser.pth"), + map_location=torch.device("cpu"), + weights_only=is_pytorch_at_least_2_4(), + ) + ) + with torch.no_grad(): + return self.rlg_auto(torch.tensor([0.0])), self.rlg_diffusion(torch.tensor([0.0])) + + def synthesize(self, text, config, speaker_id="random", voice_dirs=None, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (TortoiseConfig): Config with inference parameters. + speaker_id (str): One of the available speaker names. If `random`, it generates a random speaker. + voice_dirs (List[str]): List of paths that host reference audio files for speakers. Defaults to None. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + + speaker_id = "random" if speaker_id is None else speaker_id + + if voice_dirs is not None: + voice_dirs = [voice_dirs] + voice_samples, conditioning_latents = load_voice(speaker_id, voice_dirs) + + else: + voice_samples, conditioning_latents = load_voice(speaker_id) + + outputs = self.inference_with_config( + text, config, voice_samples=voice_samples, conditioning_latents=conditioning_latents, **kwargs + ) + + return_dict = { + "wav": outputs["wav"], + "deterministic_seed": outputs["deterministic_seed"], + "text_inputs": outputs["text"], + "voice_samples": outputs["voice_samples"], + "conditioning_latents": outputs["conditioning_latents"], + } + + return return_dict + + def inference_with_config(self, text, config, **kwargs): + """ + inference with config + #TODO describe in detail + """ + # Use generally found best tuning knobs for generation. + settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_p": config.top_p, + "cond_free_k": config.cond_free_k, + "diffusion_temperature": config.diffusion_temperature, + "sampler": config.sampler, + } + # Presets are defined here. + presets = { + "single_sample": { + "num_autoregressive_samples": 8, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 10, + "sampler": "ddim", + }, + "ultra_fast_old": { + "num_autoregressive_samples": 16, + "diffusion_iterations": 30, + "cond_free": False, + }, + "very_fast": { + "num_autoregressive_samples": 32, + "diffusion_iterations": 30, + "sampler": "dpm++2m", + }, + "fast": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 50, + "sampler": "ddim", + }, + "fast_old": {"num_autoregressive_samples": 96, "diffusion_iterations": 80}, + "standard": { + "num_autoregressive_samples": 5, + "diffusion_iterations": 200, + }, + "high_quality": { + "num_autoregressive_samples": 256, + "diffusion_iterations": 400, + }, + } + if "preset" in kwargs: + settings.update(presets[kwargs["preset"]]) + kwargs.pop("preset") + settings.update(kwargs) # allow overriding of preset settings with kwargs + return self.inference(text, **settings) + + def inference( + self, + text, + voice_samples=None, + conditioning_latents=None, + k=1, + verbose=True, + use_deterministic_seed=None, + return_deterministic_state=False, + latent_averaging_mode=0, + # autoregressive generation parameters follow + num_autoregressive_samples=16, + temperature=0.8, + length_penalty=1, + repetition_penalty=2.0, + top_p=0.8, + max_mel_tokens=500, + # diffusion generation parameters follow + diffusion_iterations=100, + cond_free=True, + cond_free_k=2, + diffusion_temperature=1.0, + sampler="ddim", + half=True, + original_tortoise=False, + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + voice_samples: (List[Tuple[torch.Tensor]]) List of an arbitrary number of reference clips, which should be tuple-pairs + of torch tensors containing arbitrary kHz waveform data. + conditioning_latents: (Tuple[autoregressive_conditioning_latent, diffusion_conditioning_latent]) A tuple of + (autoregressive_conditioning_latent, diffusion_conditioning_latent), which can be provided in lieu + of voice_samples. This is ignored unless `voice_samples=None`. Conditioning latents can be retrieved + via `get_conditioning_latents()`. + k: (int) The number of returned clips. The most likely (as determined by Tortoises' CLVP model) clips are returned. + latent_averaging_mode: (int) 0/1/2 for following modes: + 0 - latents will be generated as in original tortoise, using ~4.27s from each voice sample, averaging latent across all samples + 1 - latents will be generated using (almost) entire voice samples, averaged across all the ~4.27s chunks + 2 - latents will be generated using (almost) entire voice samples, averaged per voice sample + verbose: (bool) Whether or not to print log messages indicating the progress of creating a clip. Default=true. + num_autoregressive_samples: (int) Number of samples taken from the autoregressive model, all of which are filtered using CLVP. + As Tortoise is a probabilistic model, more samples means a higher probability of creating something "great". + temperature: (float) The softmax temperature of the autoregressive model. + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the model to produce more terse outputs. + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during decoding. Can be used to reduce + the incidence of long silences or "uhhhhhhs", etc. + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" (aka boring) outputs. + max_mel_tokens: (int) Restricts the output length. (0,600] integer. Each unit is 1/20 of a second. + typical_sampling: (bool) Turns typical sampling on or off. This sampling mode is discussed in this paper: https://arxiv.org/abs/2202.00666 + I was interested in the premise, but the results were not as good as I was hoping. This is off by default, but could use some tuning. + typical_mass: (float) The typical_mass parameter from the typical_sampling algorithm. + diffusion_iterations: (int) Number of diffusion steps to perform. [0,4000]. More steps means the network has more chances to iteratively + refine the output, which should theoretically mean a higher quality output. Generally a value above 250 is not noticeably better, however. + cond_free: (bool) Whether or not to perform conditioning-free diffusion. Conditioning-free diffusion performs two forward passes for + each diffusion step: one with the outputs of the autoregressive model and one with no conditioning priors. The output of the two + is blended according to the cond_free_k value below. Conditioning-free diffusion is the real deal, and dramatically improves realism. + cond_free_k: (float) Knob that determines how to balance the conditioning free signal with the conditioning-present signal. [0,inf]. + As cond_free_k increases, the output becomes dominated by the conditioning-free signal. + diffusion_temperature: (float) Controls the variance of the noise fed into the diffusion model. [0,1]. Values at 0 + are the "mean" prediction of the diffusion network and will sound bland and smeared. + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive transformer. + Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + deterministic_seed = deterministic_state(seed=use_deterministic_seed) + + text_tokens = torch.IntTensor(self.tokenizer.encode(text)).unsqueeze(0).to(self.device) + text_tokens = F.pad(text_tokens, (0, 1)) # This may not be necessary. + assert ( + text_tokens.shape[-1] < 400 + ), "Too much text provided. Break the text up into separate segments and re-try inference." + + if voice_samples is not None: + ( + auto_conditioning, + diffusion_conditioning, + _, + _, + ) = self.get_conditioning_latents( + voice_samples, + return_mels=True, + latent_averaging_mode=latent_averaging_mode, + original_tortoise=original_tortoise, + ) + elif conditioning_latents is not None: + auto_conditioning, diffusion_conditioning = conditioning_latents + else: + ( + auto_conditioning, + diffusion_conditioning, + ) = self.get_random_conditioning_latents() + auto_conditioning = auto_conditioning.to(self.device) + diffusion_conditioning = diffusion_conditioning.to(self.device) + + diffuser = load_discrete_vocoder_diffuser( + desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k, sampler=sampler + ) + + # in the case of single_sample, + orig_batch_size = self.autoregressive_batch_size + while num_autoregressive_samples % self.autoregressive_batch_size: + self.autoregressive_batch_size //= 2 + with torch.no_grad(): + samples = [] + num_batches = num_autoregressive_samples // self.autoregressive_batch_size + stop_mel_token = self.autoregressive.stop_mel_token + calm_token = ( + 83 # This is the token for coding silence, which is fixed in place with "fix_autoregressive_output" + ) + self.autoregressive = self.autoregressive.to(self.device) + logger.info("Generating autoregressive samples..") + with ( + self.temporary_cuda(self.autoregressive) as autoregressive, + torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), + ): + for b in tqdm(range(num_batches), disable=not verbose): + codes = autoregressive.inference_speech( + auto_conditioning, + text_tokens, + do_sample=True, + top_p=top_p, + temperature=temperature, + num_return_sequences=self.autoregressive_batch_size, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + max_generate_length=max_mel_tokens, + **hf_generate_kwargs, + ) + padding_needed = max_mel_tokens - codes.shape[1] + codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) + samples.append(codes) + self.autoregressive_batch_size = orig_batch_size # in the case of single_sample + + clip_results = [] + with ( + self.temporary_cuda(self.clvp) as clvp, + torch.autocast(device_type="cuda", dtype=torch.float16, enabled=half), + ): + for batch in tqdm(samples, disable=not verbose): + for i in range(batch.shape[0]): + batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) + clvp_res = clvp( + text_tokens.repeat(batch.shape[0], 1), + batch, + return_loss=False, + ) + clip_results.append(clvp_res) + + clip_results = torch.cat(clip_results, dim=0) + samples = torch.cat(samples, dim=0) + best_results = samples[torch.topk(clip_results, k=k).indices] + del samples + + # The diffusion model actually wants the last hidden layer from the autoregressive model as conditioning + # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these + # results, but will increase memory usage. + with self.temporary_cuda(self.autoregressive) as autoregressive: + best_latents = autoregressive( + auto_conditioning.repeat(k, 1), + text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), + best_results, + torch.tensor( + [best_results.shape[-1] * self.autoregressive.mel_length_compression], + device=text_tokens.device, + ), + return_latent=True, + clip_inputs=False, + ) + del auto_conditioning + + logger.info("Transforming autoregressive outputs into audio..") + wav_candidates = [] + for b in range(best_results.shape[0]): + codes = best_results[b].unsqueeze(0) + latents = best_latents[b].unsqueeze(0) + + # Find the first occurrence of the "calm" token and trim the codes to that. + ctokens = 0 + for code in range(codes.shape[-1]): + if codes[0, code] == calm_token: + ctokens += 1 + else: + ctokens = 0 + if ctokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech. + latents = latents[:, :code] + break + with self.temporary_cuda(self.diffusion) as diffusion: + mel = do_spectrogram_diffusion( + diffusion, + diffuser, + latents, + diffusion_conditioning, + temperature=diffusion_temperature, + verbose=verbose, + ) + with self.temporary_cuda(self.vocoder) as vocoder: + wav = vocoder.inference(mel) + wav_candidates.append(wav.cpu()) + + def potentially_redact(clip, text): + if self.enable_redaction: + return self.aligner.redact(clip.squeeze(1), text).unsqueeze(1) + return clip + + wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates] + + if len(wav_candidates) > 1: + res = wav_candidates + else: + res = wav_candidates[0] + + return_dict = { + "wav": res, + "deterministic_seed": None, + "text": None, + "voice_samples": None, + "conditioning_latents": None, + } + if return_deterministic_state: + return_dict = { + "wav": res, + "deterministic_seed": deterministic_seed, + "text": text, + "voice_samples": voice_samples, + "conditioning_latents": conditioning_latents, + } + return return_dict + + def forward(self): + raise NotImplementedError("Tortoise Training is not implemented") + + def eval_step(self): + raise NotImplementedError("Tortoise Training is not implemented") + + @staticmethod + def init_from_config(config: "TortoiseConfig", **kwargs): # pylint: disable=unused-argument + return Tortoise(config) + + def load_checkpoint( + self, + config, + checkpoint_dir, + ar_checkpoint_path=None, + diff_checkpoint_path=None, + clvp_checkpoint_path=None, + vocoder_checkpoint_path=None, + eval=False, + strict=True, + **kwargs, + ): # pylint: disable=unused-argument, redefined-builtin + """Load a model checkpoints from a directory. This model is with multiple checkpoint files and it + expects to have all the files to be under the given `checkpoint_dir` with the rigth names. + If eval is True, set the model to eval mode. + + Args: + config (TortoiseConfig): The model config. + checkpoint_dir (str): The directory where the checkpoints are stored. + ar_checkpoint_path (str, optional): The path to the autoregressive checkpoint. Defaults to None. + diff_checkpoint_path (str, optional): The path to the diffusion checkpoint. Defaults to None. + clvp_checkpoint_path (str, optional): The path to the CLVP checkpoint. Defaults to None. + vocoder_checkpoint_path (str, optional): The path to the vocoder checkpoint. Defaults to None. + eval (bool, optional): Whether to set the model to eval mode. Defaults to False. + strict (bool, optional): Whether to load the model strictly. Defaults to True. + """ + if self.models_dir is None: + self.models_dir = checkpoint_dir + ar_path = ar_checkpoint_path or os.path.join(checkpoint_dir, "autoregressive.pth") + diff_path = diff_checkpoint_path or os.path.join(checkpoint_dir, "diffusion_decoder.pth") + clvp_path = clvp_checkpoint_path or os.path.join(checkpoint_dir, "clvp2.pth") + vocoder_checkpoint_path = vocoder_checkpoint_path or os.path.join(checkpoint_dir, "vocoder.pth") + self.mel_norm_path = os.path.join(checkpoint_dir, "mel_norms.pth") + + if os.path.exists(ar_path): + # remove keys from the checkpoint that are not in the model + checkpoint = torch.load(ar_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()) + + # strict set False + # due to removed `bias` and `masked_bias` changes in Transformers + self.autoregressive.load_state_dict(checkpoint, strict=False) + + if os.path.exists(diff_path): + self.diffusion.load_state_dict(torch.load(diff_path, weights_only=is_pytorch_at_least_2_4()), strict=strict) + + if os.path.exists(clvp_path): + self.clvp.load_state_dict(torch.load(clvp_path, weights_only=is_pytorch_at_least_2_4()), strict=strict) + + if os.path.exists(vocoder_checkpoint_path): + self.vocoder.load_state_dict( + config.model_args.vocoder.value.optionally_index( + torch.load( + vocoder_checkpoint_path, + map_location=torch.device("cpu"), + weights_only=is_pytorch_at_least_2_4(), + ) + ) + ) + + if eval: + self.autoregressive.post_init_gpt2_config(self.args.kv_cache) + self.autoregressive.eval() + self.diffusion.eval() + self.clvp.eval() + self.vocoder.eval() + + def train_step(self): + raise NotImplementedError("Tortoise Training is not implemented") diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py new file mode 100644 index 0000000000000000000000000000000000000000..b014e4fdde2637b33f76dc03385687e5b2d579cd --- /dev/null +++ b/TTS/tts/models/vits.py @@ -0,0 +1,2009 @@ +import logging +import math +import os +from dataclasses import dataclass, field, replace +from itertools import chain +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torchaudio +from coqpit import Coqpit +from librosa.filters import mel as librosa_mel_fn +from torch import nn +from torch.cuda.amp.autocast_mode import autocast +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.io import load_fsspec +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.tts.configs.shared_configs import CharactersConfig +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor +from TTS.tts.layers.vits.discriminator import VitsDiscriminator +from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder +from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint +from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.languages import LanguageManager +from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations +from TTS.tts.utils.text.tokenizer import TTSTokenizer +from TTS.tts.utils.visual import plot_alignment +from TTS.utils.samplers import BucketBatchSampler +from TTS.vocoder.models.hifigan_generator import HifiganGenerator +from TTS.vocoder.utils.generic_utils import plot_results + +logger = logging.getLogger(__name__) + +############################## +# IO / Feature extraction +############################## + +# pylint: disable=global-statement +hann_window = {} +mel_basis = {} + + +@torch.no_grad() +def weights_reset(m: nn.Module): + # check if the current module has reset_parameters and if it is reset the weight + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + + +def get_module_weights_sum(mdl: nn.Module): + dict_sums = {} + for name, w in mdl.named_parameters(): + if "weight" in name: + value = w.data.sum().item() + dict_sums[name] = value + return dict_sums + + +def load_audio(file_path): + """Load the audio file normalized in [-1, 1] + + Return Shapes: + - x: :math:`[1, T]` + """ + x, sr = torchaudio.load(file_path) + assert (x > 1).sum() + (x < -1).sum() == 0 + return x, sr + + +def _amp_to_db(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _db_to_amp(x, C=1): + return torch.exp(x) / C + + +def amp_to_db(magnitudes): + output = _amp_to_db(magnitudes) + return output + + +def db_to_amp(magnitudes): + output = _db_to_amp(magnitudes) + return output + + +def wav_to_spec(y, n_fft, hop_length, win_length, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + logger.info("min value is %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("max value is %.3f", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + mel = amp_to_db(mel) + return mel + + +def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + logger.info("min value is %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("max value is %.3f", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_length) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = amp_to_db(spec) + return spec + + +############################# +# CONFIGS +############################# + + +@dataclass +class VitsAudioConfig(Coqpit): + fft_size: int = 1024 + sample_rate: int = 22050 + win_length: int = 1024 + hop_length: int = 256 + num_mels: int = 80 + mel_fmin: int = 0 + mel_fmax: int = None + + +############################## +# DATASET +############################## + + +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): + """Create inverse frequency weights for balancing the dataset. + Use `multi_dict` to scale relative weights.""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + # check if all keys are in the multi_dict + for k in multi_dict: + assert k in unique_attr_names, f"{k} not in {unique_attr_names}" + # scale weights + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + +class VitsDataset(TTSDataset): + def __init__(self, model_args, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pad_id = self.tokenizer.characters.pad_id + self.model_args = model_args + + def __getitem__(self, idx): + item = self.samples[idx] + raw_text = item["text"] + + wav, _ = load_audio(item["audio_file"]) + if self.model_args.encoder_sample_rate is not None: + if wav.size(1) % self.model_args.encoder_sample_rate != 0: + wav = wav[:, : -int(wav.size(1) % self.model_args.encoder_sample_rate)] + + wav_filename = os.path.basename(item["audio_file"]) + + token_ids = self.get_token_ids(idx, item["text"]) + + # after phonemization the text length may change + # this is a shameful 🤭 hack to prevent longer phonemes + # TODO: find a better fix + if len(token_ids) > self.max_text_len or wav.shape[1] < self.min_audio_len: + self.rescue_item_idx += 1 + return self.__getitem__(self.rescue_item_idx) + + return { + "raw_text": raw_text, + "token_ids": token_ids, + "token_len": len(token_ids), + "wav": wav, + "wav_file": wav_filename, + "speaker_name": item["speaker_name"], + "language_name": item["language"], + "audio_unique_name": item["audio_unique_name"], + } + + @property + def lengths(self): + lens = [] + for item in self.samples: + _, wav_file, *_ = _parse_sample(item) + audio_len = os.path.getsize(wav_file) / 16 * 8 # assuming 16bit audio + lens.append(audio_len) + return lens + + def collate_fn(self, batch): + """ + Return Shapes: + - tokens: :math:`[B, T]` + - token_lens :math:`[B]` + - token_rel_lens :math:`[B]` + - waveform: :math:`[B, 1, T]` + - waveform_lens: :math:`[B]` + - waveform_rel_lens: :math:`[B]` + - speaker_names: :math:`[B]` + - language_names: :math:`[B]` + - audiofile_paths: :math:`[B]` + - raw_texts: :math:`[B]` + - audio_unique_names: :math:`[B]` + """ + # convert list of dicts to dict of lists + B = len(batch) + batch = {k: [dic[k] for dic in batch] for k in batch[0]} + + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x.size(1) for x in batch["wav"]]), dim=0, descending=True + ) + + max_text_len = max([len(x) for x in batch["token_ids"]]) + token_lens = torch.LongTensor(batch["token_len"]) + token_rel_lens = token_lens / token_lens.max() + + wav_lens = [w.shape[1] for w in batch["wav"]] + wav_lens = torch.LongTensor(wav_lens) + wav_lens_max = torch.max(wav_lens) + wav_rel_lens = wav_lens / wav_lens_max + + token_padded = torch.LongTensor(B, max_text_len) + wav_padded = torch.FloatTensor(B, 1, wav_lens_max) + token_padded = token_padded.zero_() + self.pad_id + wav_padded = wav_padded.zero_() + self.pad_id + for i in range(len(ids_sorted_decreasing)): + token_ids = batch["token_ids"][i] + token_padded[i, : batch["token_len"][i]] = torch.LongTensor(token_ids) + + wav = batch["wav"][i] + wav_padded[i, :, : wav.size(1)] = torch.FloatTensor(wav) + + return { + "tokens": token_padded, + "token_lens": token_lens, + "token_rel_lens": token_rel_lens, + "waveform": wav_padded, # (B x T) + "waveform_lens": wav_lens, # (B) + "waveform_rel_lens": wav_rel_lens, + "speaker_names": batch["speaker_name"], + "language_names": batch["language_name"], + "audio_files": batch["wav_file"], + "raw_text": batch["raw_text"], + "audio_unique_names": batch["audio_unique_name"], + } + + +############################## +# MODEL DEFINITION +############################## + + +@dataclass +class VitsArgs(Coqpit): + """VITS model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels of the decoder. Defaults to 513. + + spec_segment_size (int): + Decoder input segment size. Defaults to 32 `(32 * hoplength = waveform length)`. + + hidden_channels (int): + Number of hidden channels of the model. Defaults to 192. + + hidden_channels_ffn_text_encoder (int): + Number of hidden channels of the feed-forward layers of the text encoder transformer. Defaults to 256. + + num_heads_text_encoder (int): + Number of attention heads of the text encoder transformer. Defaults to 2. + + num_layers_text_encoder (int): + Number of transformer layers in the text encoder. Defaults to 6. + + kernel_size_text_encoder (int): + Kernel size of the text encoder transformer FFN layers. Defaults to 3. + + dropout_p_text_encoder (float): + Dropout rate of the text encoder. Defaults to 0.1. + + dropout_p_duration_predictor (float): + Dropout rate of the duration predictor. Defaults to 0.1. + + kernel_size_posterior_encoder (int): + Kernel size of the posterior encoder's WaveNet layers. Defaults to 5. + + dilatation_posterior_encoder (int): + Dilation rate of the posterior encoder's WaveNet layers. Defaults to 1. + + num_layers_posterior_encoder (int): + Number of posterior encoder's WaveNet layers. Defaults to 16. + + kernel_size_flow (int): + Kernel size of the Residual Coupling layers of the flow network. Defaults to 5. + + dilatation_flow (int): + Dilation rate of the Residual Coupling WaveNet layers of the flow network. Defaults to 1. + + num_layers_flow (int): + Number of Residual Coupling WaveNet layers of the flow network. Defaults to 6. + + resblock_type_decoder (str): + Type of the residual block in the decoder network. Defaults to "1". + + resblock_kernel_sizes_decoder (List[int]): + Kernel sizes of the residual blocks in the decoder network. Defaults to `[3, 7, 11]`. + + resblock_dilation_sizes_decoder (List[List[int]]): + Dilation sizes of the residual blocks in the decoder network. Defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`. + + upsample_rates_decoder (List[int]): + Upsampling rates for each concecutive upsampling layer in the decoder network. The multiply of these + values must be equal to the kop length used for computing spectrograms. Defaults to `[8, 8, 2, 2]`. + + upsample_initial_channel_decoder (int): + Number of hidden channels of the first upsampling convolution layer of the decoder network. Defaults to 512. + + upsample_kernel_sizes_decoder (List[int]): + Kernel sizes for each upsampling layer of the decoder network. Defaults to `[16, 16, 4, 4]`. + + periods_multi_period_discriminator (List[int]): + Periods values for Vits Multi-Period Discriminator. Defaults to `[2, 3, 5, 7, 11]`. + + use_sdp (bool): + Use Stochastic Duration Predictor. Defaults to True. + + noise_scale (float): + Noise scale used for the sample noise tensor in training. Defaults to 1.0. + + inference_noise_scale (float): + Noise scale used for the sample noise tensor in inference. Defaults to 0.667. + + length_scale (float): + Scale factor for the predicted duration values. Smaller values result faster speech. Defaults to 1. + + noise_scale_dp (float): + Noise scale used by the Stochastic Duration Predictor sample noise in training. Defaults to 1.0. + + inference_noise_scale_dp (float): + Noise scale for the Stochastic Duration Predictor in inference. Defaults to 0.8. + + max_inference_len (int): + Maximum inference length to limit the memory use. Defaults to None. + + init_discriminator (bool): + Initialize the disciminator network if set True. Set False for inference. Defaults to True. + + use_spectral_norm_disriminator (bool): + Use spectral normalization over weight norm in the discriminator. Defaults to False. + + use_speaker_embedding (bool): + Enable/Disable speaker embedding for multi-speaker models. Defaults to False. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + speakers_file (str): + Path to the speaker mapping file for the Speaker Manager. Defaults to None. + + speaker_embedding_channels (int): + Number of speaker embedding channels. Defaults to 256. + + use_d_vector_file (bool): + Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False. + + d_vector_file (List[str]): + List of paths to the files including pre-computed speaker embeddings. Defaults to None. + + d_vector_dim (int): + Number of d-vector channels. Defaults to 0. + + detach_dp_input (bool): + Detach duration predictor's input from the network for stopping the gradients. Defaults to True. + + use_language_embedding (bool): + Enable/Disable language embedding for multilingual models. Defaults to False. + + embedded_language_dim (int): + Number of language embedding channels. Defaults to 4. + + num_languages (int): + Number of languages for the language embedding layer. Defaults to 0. + + language_ids_file (str): + Path to the language mapping file for the Language Manager. Defaults to None. + + use_speaker_encoder_as_loss (bool): + Enable/Disable Speaker Consistency Loss (SCL). Defaults to False. + + speaker_encoder_config_path (str): + Path to the file speaker encoder config file, to use for SCL. Defaults to "". + + speaker_encoder_model_path (str): + Path to the file speaker encoder checkpoint file, to use for SCL. Defaults to "". + + condition_dp_on_speaker (bool): + Condition the duration predictor on the speaker embedding. Defaults to True. + + freeze_encoder (bool): + Freeze the encoder weigths during training. Defaults to False. + + freeze_DP (bool): + Freeze the duration predictor weigths during training. Defaults to False. + + freeze_PE (bool): + Freeze the posterior encoder weigths during training. Defaults to False. + + freeze_flow_encoder (bool): + Freeze the flow encoder weigths during training. Defaults to False. + + freeze_waveform_decoder (bool): + Freeze the waveform decoder weigths during training. Defaults to False. + + encoder_sample_rate (int): + If not None this sample rate will be used for training the Posterior Encoder, + flow, text_encoder and duration predictor. The decoder part (vocoder) will be + trained with the `config.audio.sample_rate`. Defaults to None. + + interpolate_z (bool): + If `encoder_sample_rate` not None and this parameter True the nearest interpolation + will be used to upsampling the latent variable z with the sampling rate `encoder_sample_rate` + to the `config.audio.sample_rate`. If it is False you will need to add extra + `upsample_rates_decoder` to match the shape. Defaults to True. + + """ + + num_chars: int = 100 + out_channels: int = 513 + spec_segment_size: int = 32 + hidden_channels: int = 192 + hidden_channels_ffn_text_encoder: int = 768 + num_heads_text_encoder: int = 2 + num_layers_text_encoder: int = 6 + kernel_size_text_encoder: int = 3 + dropout_p_text_encoder: float = 0.1 + dropout_p_duration_predictor: float = 0.5 + kernel_size_posterior_encoder: int = 5 + dilation_rate_posterior_encoder: int = 1 + num_layers_posterior_encoder: int = 16 + kernel_size_flow: int = 5 + dilation_rate_flow: int = 1 + num_layers_flow: int = 4 + resblock_type_decoder: str = "1" + resblock_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes_decoder: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates_decoder: List[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel_decoder: int = 512 + upsample_kernel_sizes_decoder: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + periods_multi_period_discriminator: List[int] = field(default_factory=lambda: [2, 3, 5, 7, 11]) + use_sdp: bool = True + noise_scale: float = 1.0 + inference_noise_scale: float = 0.667 + length_scale: float = 1 + noise_scale_dp: float = 1.0 + inference_noise_scale_dp: float = 1.0 + max_inference_len: int = None + init_discriminator: bool = True + use_spectral_norm_disriminator: bool = False + use_speaker_embedding: bool = False + num_speakers: int = 0 + speakers_file: str = None + d_vector_file: List[str] = None + speaker_embedding_channels: int = 256 + use_d_vector_file: bool = False + d_vector_dim: int = 0 + detach_dp_input: bool = True + use_language_embedding: bool = False + embedded_language_dim: int = 4 + num_languages: int = 0 + language_ids_file: str = None + use_speaker_encoder_as_loss: bool = False + speaker_encoder_config_path: str = "" + speaker_encoder_model_path: str = "" + condition_dp_on_speaker: bool = True + freeze_encoder: bool = False + freeze_DP: bool = False + freeze_PE: bool = False + freeze_flow_decoder: bool = False + freeze_waveform_decoder: bool = False + encoder_sample_rate: int = None + interpolate_z: bool = True + reinit_DP: bool = False + reinit_text_encoder: bool = False + + +class Vits(BaseTTS): + """VITS TTS model + + Paper:: + https://arxiv.org/pdf/2106.06103.pdf + + Paper Abstract:: + Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel + sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. + In this work, we present a parallel endto-end TTS method that generates more natural sounding audio than + current two-stage models. Our method adopts variational inference augmented with normalizing flows and + an adversarial training process, which improves the expressive power of generative modeling. We also propose a + stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the + uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the + natural one-to-many relationship in which a text input can be spoken in multiple ways + with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) + on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly + available TTS systems and achieves a MOS comparable to ground truth. + + Check :class:`TTS.tts.configs.vits_config.VitsConfig` for class arguments. + + Examples: + >>> from TTS.tts.configs.vits_config import VitsConfig + >>> from TTS.tts.models.vits import Vits + >>> config = VitsConfig() + >>> model = Vits(config) + """ + + def __init__( + self, + config: Coqpit, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None, + ): + super().__init__(config, ap, tokenizer, speaker_manager, language_manager) + + self.init_multispeaker(config) + self.init_multilingual(config) + self.init_upsampling() + + self.length_scale = self.args.length_scale + self.noise_scale = self.args.noise_scale + self.inference_noise_scale = self.args.inference_noise_scale + self.inference_noise_scale_dp = self.args.inference_noise_scale_dp + self.noise_scale_dp = self.args.noise_scale_dp + self.max_inference_len = self.args.max_inference_len + self.spec_segment_size = self.args.spec_segment_size + + self.text_encoder = TextEncoder( + self.args.num_chars, + self.args.hidden_channels, + self.args.hidden_channels, + self.args.hidden_channels_ffn_text_encoder, + self.args.num_heads_text_encoder, + self.args.num_layers_text_encoder, + self.args.kernel_size_text_encoder, + self.args.dropout_p_text_encoder, + language_emb_dim=self.embedded_language_dim, + ) + + self.posterior_encoder = PosteriorEncoder( + self.args.out_channels, + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_posterior_encoder, + dilation_rate=self.args.dilation_rate_posterior_encoder, + num_layers=self.args.num_layers_posterior_encoder, + cond_channels=self.embedded_speaker_dim, + ) + + self.flow = ResidualCouplingBlocks( + self.args.hidden_channels, + self.args.hidden_channels, + kernel_size=self.args.kernel_size_flow, + dilation_rate=self.args.dilation_rate_flow, + num_layers=self.args.num_layers_flow, + cond_channels=self.embedded_speaker_dim, + ) + + if self.args.use_sdp: + self.duration_predictor = StochasticDurationPredictor( + self.args.hidden_channels, + 192, + 3, + self.args.dropout_p_duration_predictor, + 4, + cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0, + language_emb_dim=self.embedded_language_dim, + ) + else: + self.duration_predictor = DurationPredictor( + self.args.hidden_channels, + 256, + 3, + self.args.dropout_p_duration_predictor, + cond_channels=self.embedded_speaker_dim, + language_emb_dim=self.embedded_language_dim, + ) + + self.waveform_decoder = HifiganGenerator( + self.args.hidden_channels, + 1, + self.args.resblock_type_decoder, + self.args.resblock_dilation_sizes_decoder, + self.args.resblock_kernel_sizes_decoder, + self.args.upsample_kernel_sizes_decoder, + self.args.upsample_initial_channel_decoder, + self.args.upsample_rates_decoder, + inference_padding=0, + cond_channels=self.embedded_speaker_dim, + conv_pre_weight_norm=False, + conv_post_weight_norm=False, + conv_post_bias=False, + ) + + if self.args.init_discriminator: + self.disc = VitsDiscriminator( + periods=self.args.periods_multi_period_discriminator, + use_spectral_norm=self.args.use_spectral_norm_disriminator, + ) + + @property + def device(self): + return next(self.parameters()).device + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.embedded_speaker_dim = 0 + self.num_speakers = self.args.num_speakers + self.audio_transform = None + + if self.speaker_manager: + self.num_speakers = self.speaker_manager.num_speakers + + if self.args.use_speaker_embedding: + self._init_speaker_embedding() + + if self.args.use_d_vector_file: + self._init_d_vector() + + # TODO: make this a function + if self.args.use_speaker_encoder_as_loss: + if self.speaker_manager.encoder is None and ( + not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path + ): + raise RuntimeError( + " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!" + ) + + self.speaker_manager.encoder.eval() + logger.info("External Speaker Encoder Loaded !!") + + if ( + hasattr(self.speaker_manager.encoder, "audio_config") + and self.config.audio.sample_rate != self.speaker_manager.encoder.audio_config["sample_rate"] + ): + self.audio_transform = torchaudio.transforms.Resample( + orig_freq=self.config.audio.sample_rate, + new_freq=self.speaker_manager.encoder.audio_config["sample_rate"], + ) + + def _init_speaker_embedding(self): + # pylint: disable=attribute-defined-outside-init + if self.num_speakers > 0: + logger.info("Initialization of speaker-embedding layers.") + self.embedded_speaker_dim = self.args.speaker_embedding_channels + self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + + def _init_d_vector(self): + # pylint: disable=attribute-defined-outside-init + if hasattr(self, "emb_g"): + raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") + self.embedded_speaker_dim = self.args.d_vector_dim + + def init_multilingual(self, config: Coqpit): + """Initialize multilingual modules of a model. + + Args: + config (Coqpit): Model configuration. + """ + if self.args.language_ids_file is not None: + self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) + + if self.args.use_language_embedding and self.language_manager: + logger.info("Initialization of language-embedding layers.") + self.num_languages = self.language_manager.num_languages + self.embedded_language_dim = self.args.embedded_language_dim + self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim) + torch.nn.init.xavier_uniform_(self.emb_l.weight) + else: + self.embedded_language_dim = 0 + + def init_upsampling(self): + """ + Initialize upsampling modules of a model. + """ + if self.args.encoder_sample_rate: + self.interpolate_factor = self.config.audio["sample_rate"] / self.args.encoder_sample_rate + self.audio_resampler = torchaudio.transforms.Resample( + orig_freq=self.config.audio["sample_rate"], new_freq=self.args.encoder_sample_rate + ) # pylint: disable=W0201 + + def on_epoch_start(self, trainer): # pylint: disable=W0613 + """Freeze layers at the beginning of an epoch""" + self._freeze_layers() + # set the device of speaker encoder + if self.args.use_speaker_encoder_as_loss: + self.speaker_manager.encoder = self.speaker_manager.encoder.to(self.device) + + def on_init_end(self, trainer): # pylint: disable=W0613 + """Reinit layes if needed""" + if self.args.reinit_DP: + before_dict = get_module_weights_sum(self.duration_predictor) + # Applies weights_reset recursively to every submodule of the duration predictor + self.duration_predictor.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.duration_predictor) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Duration Predictor was not reinit check it !") + logger.info("Duration Predictor was reinit.") + + if self.args.reinit_text_encoder: + before_dict = get_module_weights_sum(self.text_encoder) + # Applies weights_reset recursively to every submodule of the duration predictor + self.text_encoder.apply(fn=weights_reset) + after_dict = get_module_weights_sum(self.text_encoder) + for key, value in after_dict.items(): + if value == before_dict[key]: + raise RuntimeError(" [!] The weights of Text Encoder was not reinit check it !") + logger.info("Text Encoder was reinit.") + + def get_aux_input(self, aux_input: Dict): + sid, g, lid, _ = self._set_cond_input(aux_input) + return {"speaker_ids": sid, "style_wav": None, "d_vectors": g, "language_ids": lid} + + def _freeze_layers(self): + if self.args.freeze_encoder: + for param in self.text_encoder.parameters(): + param.requires_grad = False + + if hasattr(self, "emb_l"): + for param in self.emb_l.parameters(): + param.requires_grad = False + + if self.args.freeze_PE: + for param in self.posterior_encoder.parameters(): + param.requires_grad = False + + if self.args.freeze_DP: + for param in self.duration_predictor.parameters(): + param.requires_grad = False + + if self.args.freeze_flow_decoder: + for param in self.flow.parameters(): + param.requires_grad = False + + if self.args.freeze_waveform_decoder: + for param in self.waveform_decoder.parameters(): + param.requires_grad = False + + @staticmethod + def _set_cond_input(aux_input: Dict): + """Set the speaker conditioning input based on the multi-speaker mode.""" + sid, g, lid, durations = None, None, None, None + if "speaker_ids" in aux_input and aux_input["speaker_ids"] is not None: + sid = aux_input["speaker_ids"] + if sid.ndim == 0: + sid = sid.unsqueeze_(0) + if "d_vectors" in aux_input and aux_input["d_vectors"] is not None: + g = F.normalize(aux_input["d_vectors"]).unsqueeze(-1) + if g.ndim == 2: + g = g.unsqueeze_(0) + + if "language_ids" in aux_input and aux_input["language_ids"] is not None: + lid = aux_input["language_ids"] + if lid.ndim == 0: + lid = lid.unsqueeze_(0) + + if "durations" in aux_input and aux_input["durations"] is not None: + durations = aux_input["durations"] + + return sid, g, lid, durations + + def _set_speaker_input(self, aux_input: Dict): + d_vectors = aux_input.get("d_vectors", None) + speaker_ids = aux_input.get("speaker_ids", None) + + if d_vectors is not None and speaker_ids is not None: + raise ValueError("[!] Cannot use d-vectors and speaker-ids together.") + + if speaker_ids is not None and not hasattr(self, "emb_g"): + raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.") + + g = speaker_ids if speaker_ids is not None else d_vectors + return g + + def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb): + # find the alignment path + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + with torch.no_grad(): + o_scale = torch.exp(-2 * logs_p) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)]) + logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p]) + logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp = logp2 + logp3 + logp1 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() # [b, 1, t, t'] + + # duration predictor + attn_durations = attn.sum(3) + if self.args.use_sdp: + loss_duration = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + attn_durations, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = loss_duration / torch.sum(x_mask) + else: + attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask + log_durations = self.duration_predictor( + x.detach() if self.args.detach_dp_input else x, + x_mask, + g=g.detach() if self.args.detach_dp_input and g is not None else g, + lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb, + ) + loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) + outputs["loss_duration"] = loss_duration + return outputs, attn + + def upsampling_z(self, z, slice_ids=None, y_lengths=None, y_mask=None): + spec_segment_size = self.spec_segment_size + if self.args.encoder_sample_rate: + # recompute the slices and spec_segment_size if needed + slice_ids = slice_ids * int(self.interpolate_factor) if slice_ids is not None else slice_ids + spec_segment_size = spec_segment_size * int(self.interpolate_factor) + # interpolate z if needed + if self.args.interpolate_z: + z = torch.nn.functional.interpolate(z, scale_factor=[self.interpolate_factor], mode="linear").squeeze(0) + # recompute the mask if needed + if y_lengths is not None and y_mask is not None: + y_mask = ( + sequence_mask(y_lengths * self.interpolate_factor, None).to(y_mask.dtype).unsqueeze(1) + ) # [B, 1, T_dec_resampled] + + return z, spec_segment_size, slice_ids, y_mask + + def forward( # pylint: disable=dangerous-default-value + self, + x: torch.tensor, + x_lengths: torch.tensor, + y: torch.tensor, + y_lengths: torch.tensor, + waveform: torch.tensor, + aux_input={"d_vectors": None, "speaker_ids": None, "language_ids": None}, + ) -> Dict: + """Forward pass of the model. + + Args: + x (torch.tensor): Batch of input character sequence IDs. + x_lengths (torch.tensor): Batch of input character sequence lengths. + y (torch.tensor): Batch of input spectrograms. + y_lengths (torch.tensor): Batch of input spectrogram lengths. + waveform (torch.tensor): Batch of ground truth waveforms per sample. + aux_input (dict, optional): Auxiliary inputs for multi-speaker and multi-lingual training. + Defaults to {"d_vectors": None, "speaker_ids": None, "language_ids": None}. + + Returns: + Dict: model outputs keyed by the output name. + + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - y: :math:`[B, C, T_spec]` + - y_lengths: :math:`[B]` + - waveform: :math:`[B, 1, T_wav]` + - d_vectors: :math:`[B, C, 1]` + - speaker_ids: :math:`[B]` + - language_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + - m_q: :math:`[B, C, T_dec]` + - logs_q: :math:`[B, C, T_dec]` + - waveform_seg: :math:`[B, 1, spec_seg_size * hop_length]` + - gt_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + - syn_spk_emb: :math:`[B, 1, speaker_encoder.proj_dim]` + """ + outputs = {} + sid, g, lid, _ = self._set_cond_input(aux_input) + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + # posterior encoder + z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g) + + # flow layers + z_p = self.flow(z, y_mask, g=g) + + # duration predictor + outputs, attn = self.forward_mas(outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g=g, lang_emb=lang_emb) + + # expand prior + m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) + logs_p = torch.einsum("klmn, kjm -> kjn", [attn, logs_p]) + + # select a random feature segment for the waveform decoder + z_slice, slice_ids = rand_segments(z, y_lengths, self.spec_segment_size, let_short_samples=True, pad_short=True) + + # interpolate z if needed + z_slice, spec_segment_size, slice_ids, _ = self.upsampling_z(z_slice, slice_ids=slice_ids) + + o = self.waveform_decoder(z_slice, g=g) + + wav_seg = segment( + waveform, + slice_ids * self.config.audio.hop_length, + spec_segment_size * self.config.audio.hop_length, + pad_short=True, + ) + + if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None: + # concate generated and GT waveforms + wavs_batch = torch.cat((wav_seg, o), dim=0) + + # resample audio to speaker encoder sample_rate + # pylint: disable=W0105 + if self.audio_transform is not None: + wavs_batch = self.audio_transform(wavs_batch) + + pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True) + + # split generated and GT speaker embeddings + gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0) + else: + gt_spk_emb, syn_spk_emb = None, None + + outputs.update( + { + "model_outputs": o, + "alignments": attn.squeeze(1), + "m_p": m_p, + "logs_p": logs_p, + "z": z, + "z_p": z_p, + "m_q": m_q, + "logs_q": logs_q, + "waveform_seg": wav_seg, + "gt_spk_emb": gt_spk_emb, + "syn_spk_emb": syn_spk_emb, + "slice_ids": slice_ids, + } + ) + return outputs + + @staticmethod + def _set_x_lengths(x, aux_input): + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x, + aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None, "durations": None}, + ): # pylint: disable=dangerous-default-value + """ + Note: + To run in batch mode, provide `x_lengths` else model assumes that the batch size is 1. + + Shapes: + - x: :math:`[B, T_seq]` + - x_lengths: :math:`[B]` + - d_vectors: :math:`[B, C]` + - speaker_ids: :math:`[B]` + + Return Shapes: + - model_outputs: :math:`[B, 1, T_wav]` + - alignments: :math:`[B, T_seq, T_dec]` + - z: :math:`[B, C, T_dec]` + - z_p: :math:`[B, C, T_dec]` + - m_p: :math:`[B, C, T_dec]` + - logs_p: :math:`[B, C, T_dec]` + """ + sid, g, lid, durations = self._set_cond_input(aux_input) + x_lengths = self._set_x_lengths(x, aux_input) + + # speaker embedding + if self.args.use_speaker_embedding and sid is not None: + g = self.emb_g(sid).unsqueeze(-1) + + # language embedding + lang_emb = None + if self.args.use_language_embedding and lid is not None: + lang_emb = self.emb_l(lid).unsqueeze(-1) + + x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb) + + if durations is None: + if self.args.use_sdp: + logw = self.duration_predictor( + x, + x_mask, + g=g if self.args.condition_dp_on_speaker else None, + reverse=True, + noise_scale=self.inference_noise_scale_dp, + lang_emb=lang_emb, + ) + else: + logw = self.duration_predictor( + x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb + ) + w = torch.exp(logw) * x_mask * self.length_scale + else: + assert durations.shape[-1] == x.shape[-1] + w = durations.unsqueeze(0) + + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = sequence_mask(y_lengths, None).to(x_mask.dtype).unsqueeze(1) # [B, 1, T_dec] + + attn_mask = x_mask * y_mask.transpose(1, 2) # [B, 1, T_enc] * [B, T_dec, 1] + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1).transpose(1, 2)) + + m_p = torch.matmul(attn.transpose(1, 2), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.transpose(1, 2), logs_p.transpose(1, 2)).transpose(1, 2) + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * self.inference_noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + + # upsampling if needed + z, _, _, y_mask = self.upsampling_z(z, y_lengths=y_lengths, y_mask=y_mask) + + o = self.waveform_decoder((z * y_mask)[:, :, : self.max_inference_len], g=g) + + outputs = { + "model_outputs": o, + "alignments": attn.squeeze(1), + "durations": w_ceil, + "z": z, + "z_p": z_p, + "m_p": m_p, + "logs_p": logs_p, + "y_mask": y_mask, + } + return outputs + + @torch.no_grad() + def inference_voice_conversion( + self, reference_wav, speaker_id=None, d_vector=None, reference_speaker_id=None, reference_d_vector=None + ): + """Inference for voice conversion + + Args: + reference_wav (Tensor): Reference wavform. Tensor of shape [B, T] + speaker_id (Tensor): speaker_id of the target speaker. Tensor of shape [B] + d_vector (Tensor): d_vector embedding of target speaker. Tensor of shape `[B, C]` + reference_speaker_id (Tensor): speaker_id of the reference_wav speaker. Tensor of shape [B] + reference_d_vector (Tensor): d_vector embedding of the reference_wav speaker. Tensor of shape `[B, C]` + """ + # compute spectrograms + y = wav_to_spec( + reference_wav, + self.config.audio.fft_size, + self.config.audio.hop_length, + self.config.audio.win_length, + center=False, + ) + y_lengths = torch.tensor([y.size(-1)]).to(y.device) + speaker_cond_src = reference_speaker_id if reference_speaker_id is not None else reference_d_vector + speaker_cond_tgt = speaker_id if speaker_id is not None else d_vector + wav, _, _ = self.voice_conversion(y, y_lengths, speaker_cond_src, speaker_cond_tgt) + return wav + + def voice_conversion(self, y, y_lengths, speaker_cond_src, speaker_cond_tgt): + """Forward pass for voice conversion + + TODO: create an end-point for voice conversion + + Args: + y (Tensor): Reference spectrograms. Tensor of shape [B, T, C] + y_lengths (Tensor): Length of each reference spectrogram. Tensor of shape [B] + speaker_cond_src (Tensor): Reference speaker ID. Tensor of shape [B,] + speaker_cond_tgt (Tensor): Target speaker ID. Tensor of shape [B,] + """ + assert self.num_speakers > 0, "num_speakers have to be larger than 0." + # speaker embedding + if self.args.use_speaker_embedding and not self.args.use_d_vector_file: + g_src = self.emb_g(torch.from_numpy((np.array(speaker_cond_src))).unsqueeze(0)).unsqueeze(-1) + g_tgt = self.emb_g(torch.from_numpy((np.array(speaker_cond_tgt))).unsqueeze(0)).unsqueeze(-1) + elif not self.args.use_speaker_embedding and self.args.use_d_vector_file: + g_src = F.normalize(speaker_cond_src).unsqueeze(-1) + g_tgt = F.normalize(speaker_cond_tgt).unsqueeze(-1) + else: + raise RuntimeError(" [!] Voice conversion is only supported on multi-speaker models.") + + z, _, _, y_mask = self.posterior_encoder(y, y_lengths, g=g_src) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.waveform_decoder(z_hat * y_mask, g=g_tgt) + return o_hat, y_mask, (z, z_p, z_hat) + + def train_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Perform a single training step. Run the model forward pass and compute losses. + + Args: + batch (Dict): Input tensors. + criterion (nn.Module): Loss layer designed for the model. + optimizer_idx (int): Index of optimizer to use. 0 for the discriminator and 1 for the generator networks. + + Returns: + Tuple[Dict, Dict]: Model ouputs and computed losses. + """ + + spec_lens = batch["spec_lens"] + + if optimizer_idx == 0: + tokens = batch["tokens"] + token_lenghts = batch["token_lens"] + spec = batch["spec"] + + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + language_ids = batch["language_ids"] + waveform = batch["waveform"] + + # generator pass + outputs = self.forward( + tokens, + token_lenghts, + spec, + spec_lens, + waveform, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids}, + ) + + # cache tensors for the generator pass + self.model_outputs_cache = outputs # pylint: disable=attribute-defined-outside-init + + # compute scores and features + scores_disc_fake, _, scores_disc_real, _ = self.disc( + outputs["model_outputs"].detach(), outputs["waveform_seg"] + ) + + # compute loss + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + scores_disc_real, + scores_disc_fake, + ) + return outputs, loss_dict + + if optimizer_idx == 1: + mel = batch["mel"] + + # compute melspec segment + with autocast(enabled=False): + if self.args.encoder_sample_rate: + spec_segment_size = self.spec_segment_size * int(self.interpolate_factor) + else: + spec_segment_size = self.spec_segment_size + + mel_slice = segment( + mel.float(), self.model_outputs_cache["slice_ids"], spec_segment_size, pad_short=True + ) + mel_slice_hat = wav_to_mel( + y=self.model_outputs_cache["model_outputs"].float(), + n_fft=self.config.audio.fft_size, + sample_rate=self.config.audio.sample_rate, + num_mels=self.config.audio.num_mels, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + fmin=self.config.audio.mel_fmin, + fmax=self.config.audio.mel_fmax, + center=False, + ) + + # compute discriminator scores and features + scores_disc_fake, feats_disc_fake, _, feats_disc_real = self.disc( + self.model_outputs_cache["model_outputs"], self.model_outputs_cache["waveform_seg"] + ) + + # compute losses + with autocast(enabled=False): # use float32 for the criterion + loss_dict = criterion[optimizer_idx]( + mel_slice_hat=mel_slice.float(), + mel_slice=mel_slice_hat.float(), + z_p=self.model_outputs_cache["z_p"].float(), + logs_q=self.model_outputs_cache["logs_q"].float(), + m_p=self.model_outputs_cache["m_p"].float(), + logs_p=self.model_outputs_cache["logs_p"].float(), + z_len=spec_lens, + scores_disc_fake=scores_disc_fake, + feats_disc_fake=feats_disc_fake, + feats_disc_real=feats_disc_real, + loss_duration=self.model_outputs_cache["loss_duration"], + use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss, + gt_spk_emb=self.model_outputs_cache["gt_spk_emb"], + syn_spk_emb=self.model_outputs_cache["syn_spk_emb"], + ) + + return self.model_outputs_cache, loss_dict + + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + def _log(self, ap, batch, outputs, name_prefix="train"): # pylint: disable=unused-argument,no-self-use + y_hat = outputs[1]["model_outputs"] + y = outputs[1]["waveform_seg"] + figures = plot_results(y_hat, y, ap, name_prefix) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name_prefix}/audio": sample_voice} + + alignments = outputs[1]["alignments"] + align_img = alignments[0].data.cpu().numpy().T + + figures.update( + { + "alignment": plot_alignment(align_img, output_fig=False), + } + ) + return figures, audios + + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ): # pylint: disable=no-self-use + """Create visualizations and waveform examples. + + For example, here you can plot spectrograms and generate sample sample waveforms from these spectrograms to + be projected onto Tensorboard. + + Args: + ap (AudioProcessor): audio processor used at training. + batch (Dict): Model inputs used at the previous training step. + outputs (Dict): Model outputs generated at the previoud training step. + + Returns: + Tuple[Dict, np.ndarray]: training plots and output waveform. + """ + figures, audios = self._log(self.ap, batch, outputs, "train") + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int): + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + figures, audios = self._log(self.ap, batch, outputs, "eval") + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def get_aux_input_from_test_sentences(self, sentence_info): + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if hasattr(self, "speaker_manager"): + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + "language_name": language_name, + } + + @torch.no_grad() + def test_run(self, assets) -> Tuple[Dict, Dict]: + """Generic test run for `tts` models used by `Trainer`. + + You can override this for a different behaviour. + + Returns: + Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard. + """ + logger.info("Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + for idx, s_info in enumerate(test_sentences): + aux_inputs = self.get_aux_input_from_test_sentences(s_info) + wav, alignment, _, _ = synthesis( + self, + aux_inputs["text"], + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + language_id=aux_inputs["language_id"], + use_griffin_lim=True, + do_trim_silence=False, + ).values() + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment.T, output_fig=False) + return {"figures": test_figures, "audios": test_audios} + + def test_log( + self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument + ) -> None: + logger.test_audios(steps, outputs["audios"], self.ap.sample_rate) + logger.test_figures(steps, outputs["figures"]) + + def format_batch(self, batch: Dict) -> Dict: + """Compute speaker, langugage IDs and d_vector for the batch if necessary.""" + speaker_ids = None + language_ids = None + d_vectors = None + + # get numerical speaker ids from speaker names + if self.speaker_manager is not None and self.speaker_manager.name_to_id and self.args.use_speaker_embedding: + speaker_ids = [self.speaker_manager.name_to_id[sn] for sn in batch["speaker_names"]] + + if speaker_ids is not None: + speaker_ids = torch.LongTensor(speaker_ids) + + # get d_vectors from audio file names + if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file: + d_vector_mapping = self.speaker_manager.embeddings + d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_unique_names"]] + d_vectors = torch.FloatTensor(d_vectors) + + # get language ids from language names + if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding: + language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]] + + if language_ids is not None: + language_ids = torch.LongTensor(language_ids) + + batch["language_ids"] = language_ids + batch["d_vectors"] = d_vectors + batch["speaker_ids"] = speaker_ids + return batch + + def format_batch_on_device(self, batch): + """Compute spectrograms on the device.""" + ac = self.config.audio + + if self.args.encoder_sample_rate: + wav = self.audio_resampler(batch["waveform"]) + else: + wav = batch["waveform"] + + # compute spectrograms + batch["spec"] = wav_to_spec(wav, ac.fft_size, ac.hop_length, ac.win_length, center=False) + + if self.args.encoder_sample_rate: + # recompute spec with high sampling rate to the loss + spec_mel = wav_to_spec(batch["waveform"], ac.fft_size, ac.hop_length, ac.win_length, center=False) + # remove extra stft frames if needed + if spec_mel.size(2) > int(batch["spec"].size(2) * self.interpolate_factor): + spec_mel = spec_mel[:, :, : int(batch["spec"].size(2) * self.interpolate_factor)] + else: + batch["spec"] = batch["spec"][:, :, : int(spec_mel.size(2) / self.interpolate_factor)] + else: + spec_mel = batch["spec"] + + batch["mel"] = spec_to_mel( + spec=spec_mel, + n_fft=ac.fft_size, + num_mels=ac.num_mels, + sample_rate=ac.sample_rate, + fmin=ac.mel_fmin, + fmax=ac.mel_fmax, + ) + + if self.args.encoder_sample_rate: + assert batch["spec"].shape[2] == int( + batch["mel"].shape[2] / self.interpolate_factor + ), f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + else: + assert batch["spec"].shape[2] == batch["mel"].shape[2], f"{batch['spec'].shape[2]}, {batch['mel'].shape[2]}" + + # compute spectrogram frame lengths + batch["spec_lens"] = (batch["spec"].shape[2] * batch["waveform_rel_lens"]).int() + batch["mel_lens"] = (batch["mel"].shape[2] * batch["waveform_rel_lens"]).int() + + if self.args.encoder_sample_rate: + assert (batch["spec_lens"] - (batch["mel_lens"] / self.interpolate_factor).int()).sum() == 0 + else: + assert (batch["spec_lens"] - batch["mel_lens"]).sum() == 0 + + # zero the padding frames + batch["spec"] = batch["spec"] * sequence_mask(batch["spec_lens"]).unsqueeze(1) + batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1) + return batch + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False): + weights = None + data_items = dataset.samples + if getattr(config, "use_weighted_sampler", False): + for attr_name, alpha in config.weighted_sampler_attrs.items(): + logger.info("Using weighted sampler for attribute '%s' with alpha %.3f", attr_name, alpha) + multi_dict = config.weighted_sampler_multipliers.get(attr_name, None) + logger.info(multi_dict) + weights, attr_names, attr_weights = get_attribute_balancer_weights( + attr_name=attr_name, items=data_items, multi_dict=multi_dict + ) + weights = weights * alpha + logger.info("Attribute weights for '%s' \n | > %s", attr_names, attr_weights) + + # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items] + + if weights is not None: + w_sampler = WeightedRandomSampler(weights, len(weights)) + batch_sampler = BucketBatchSampler( + w_sampler, + data=data_items, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + sort_key=lambda x: os.path.getsize(x["audio_file"]), + drop_last=True, + ) + else: + batch_sampler = None + # sampler for DDP + if batch_sampler is None: + batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + batch_sampler = ( + DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler + ) # TODO: check batch_sampler with multi-gpu + return batch_sampler + + def get_data_loader( + self, + config: Coqpit, + assets: Dict, + is_eval: bool, + samples: Union[List[Dict], List[List]], + verbose: bool, + num_gpus: int, + rank: int = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # init dataloader + dataset = VitsDataset( + model_args=self.args, + samples=samples, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + tokenizer=self.tokenizer, + start_by_longest=config.start_by_longest, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + if sampler is None: + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=False, # shuffle is done in the dataset. + collate_fn=dataset.collate_fn, + drop_last=False, # setting this False might cause issues in AMP training. + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + if num_gpus > 1: + loader = DataLoader( + dataset, + sampler=sampler, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + else: + loader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + + It returns 2 optimizers in a list. First one is for the discriminator + and the second one is for the generator. + + Returns: + List: optimizers. + """ + optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc) + + # select generator parameters + gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc.")) + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters + ) + return [optimizer0, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler_D = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[0]) + scheduler_G = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[1]) + return [scheduler_D, scheduler_G] + + def get_criterion(self): + """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in + `train_step()`""" + from TTS.tts.layers.losses import ( # pylint: disable=import-outside-toplevel + VitsDiscriminatorLoss, + VitsGeneratorLoss, + ) + + return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] + + def load_checkpoint( + self, config, checkpoint_path, eval=False, strict=True, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + """Load the model checkpoint and setup for training or inference""" + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + # compat band-aid for the pre-trained models to not use the encoder baked into the model + # TODO: consider baking the speaker encoder into the model and call it from there. + # as it is probably easier for model distribution. + state["model"] = {k: v for k, v in state["model"].items() if "speaker_encoder" not in k} + + if self.args.encoder_sample_rate is not None and eval: + # audio resampler is not used in inference time + self.audio_resampler = None + + # handle fine-tuning from a checkpoint with additional speakers + if hasattr(self, "emb_g") and state["model"]["emb_g.weight"].shape != self.emb_g.weight.shape: + num_new_speakers = self.emb_g.weight.shape[0] - state["model"]["emb_g.weight"].shape[0] + logger.info("Loading checkpoint with %d additional speakers.", num_new_speakers) + emb_g = state["model"]["emb_g.weight"] + new_row = torch.randn(num_new_speakers, emb_g.shape[1]) + emb_g = torch.cat([emb_g, new_row], axis=0) + state["model"]["emb_g.weight"] = emb_g + # load the model weights + self.load_state_dict(state["model"], strict=strict) + + if eval: + self.eval() + assert not self.training + + def load_fairseq_checkpoint( + self, config, checkpoint_dir, eval=False, strict=True + ): # pylint: disable=unused-argument, redefined-builtin + """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms + Performs some changes for compatibility. + + Args: + config (Coqpit): 🐸TTS model config. + checkpoint_dir (str): Path to the checkpoint directory. + eval (bool, optional): Set to True for evaluation. Defaults to False. + """ + import json + + from TTS.tts.utils.text.cleaners import basic_cleaners + + self.disc = None + # set paths + config_file = os.path.join(checkpoint_dir, "config.json") + checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth") + vocab_file = os.path.join(checkpoint_dir, "vocab.txt") + # set config params + with open(config_file, "r", encoding="utf-8") as file: + # Load the JSON data as a dictionary + config_org = json.load(file) + self.config.audio.sample_rate = config_org["data"]["sampling_rate"] + # self.config.add_blank = config['add_blank'] + # set tokenizer + vocab = FairseqVocab(vocab_file) + self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) + self.tokenizer = TTSTokenizer( + use_phonemes=False, + text_cleaner=basic_cleaners, + characters=vocab, + phonemizer=None, + add_blank=config_org["data"]["add_blank"], + use_eos_bos=False, + ) + # load fairseq checkpoint + new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) + self.load_state_dict(new_chk, strict=strict) + if eval: + self.eval() + assert not self.training + + @staticmethod + def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None): + """Initiate model from config + + Args: + config (VitsConfig): Model config. + samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training. + Defaults to None. + """ + from TTS.utils.audio import AudioProcessor + + upsample_rate = torch.prod(torch.as_tensor(config.model_args.upsample_rates_decoder)).item() + + if not config.model_args.encoder_sample_rate: + assert ( + upsample_rate == config.audio.hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {config.audio.hop_length}" + else: + encoder_to_vocoder_upsampling_factor = config.audio.sample_rate / config.model_args.encoder_sample_rate + effective_hop_length = config.audio.hop_length * encoder_to_vocoder_upsampling_factor + assert ( + upsample_rate == effective_hop_length + ), f" [!] Product of upsample rates must be equal to the hop length - {upsample_rate} vs {effective_hop_length}" + + ap = AudioProcessor.init_from_config(config) + tokenizer, new_config = TTSTokenizer.init_from_config(config) + speaker_manager = SpeakerManager.init_from_config(config, samples) + language_manager = LanguageManager.init_from_config(config) + + if config.model_args.speaker_encoder_model_path: + speaker_manager.init_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) + return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + + def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True): + """Export model to ONNX format for inference + + Args: + output_path (str): Path to save the exported model. + verbose (bool): Print verbose information. Defaults to True. + """ + + # rollback values + _forward = self.forward + disc = None + if hasattr(self, "disc"): + disc = self.disc + training = self.training + + # set export mode + self.disc = None + self.eval() + + def onnx_inference(text, text_lengths, scales, sid=None, langid=None): + noise_scale = scales[0] + length_scale = scales[1] + noise_scale_dp = scales[2] + self.noise_scale = noise_scale + self.length_scale = length_scale + self.noise_scale_dp = noise_scale_dp + return self.inference( + text, + aux_input={ + "x_lengths": text_lengths, + "d_vectors": None, + "speaker_ids": sid, + "language_ids": langid, + "durations": None, + }, + )["model_outputs"] + + self.forward = onnx_inference + + # set dummy inputs + dummy_input_length = 100 + sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long) + sequence_lengths = torch.LongTensor([sequences.size(1)]) + scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp]) + dummy_input = (sequences, sequence_lengths, scales) + input_names = ["input", "input_lengths", "scales"] + + if self.num_speakers > 0: + speaker_id = torch.LongTensor([0]) + dummy_input += (speaker_id,) + input_names.append("sid") + + if hasattr(self, "num_languages") and self.num_languages > 0 and self.embedded_language_dim > 0: + language_id = torch.LongTensor([0]) + dummy_input += (language_id,) + input_names.append("langid") + + # export to ONNX + torch.onnx.export( + model=self, + args=dummy_input, + opset_version=15, + f=output_path, + verbose=verbose, + input_names=input_names, + output_names=["output"], + dynamic_axes={ + "input": {0: "batch_size", 1: "phonemes"}, + "input_lengths": {0: "batch_size"}, + "output": {0: "batch_size", 1: "time1", 2: "time2"}, + }, + ) + + # rollback + self.forward = _forward + if training: + self.train() + if disc is not None: + self.disc = disc + + def load_onnx(self, model_path: str, cuda=False): + import onnxruntime as ort + + providers = [ + ( + "CPUExecutionProvider" + if cuda is False + else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}) + ) + ] + sess_options = ort.SessionOptions() + self.onnx_sess = ort.InferenceSession( + model_path, + sess_options=sess_options, + providers=providers, + ) + + def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None): + """ONNX inference""" + + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + + if x_lengths is None: + x_lengths = np.array([x.shape[1]], dtype=np.int64) + + if isinstance(x_lengths, torch.Tensor): + x_lengths = x_lengths.cpu().numpy() + scales = np.array( + [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp], + dtype=np.float32, + ) + input_params = {"input": x, "input_lengths": x_lengths, "scales": scales} + if speaker_id is not None: + input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() + if language_id is not None: + input_params["langid"] = torch.tensor([language_id]).cpu().numpy() + + audio = self.onnx_sess.run( + ["output"], + input_params, + ) + return audio[0][0] + + +################################## +# VITS CHARACTERS +################################## + + +class VitsCharacters(BaseCharacters): + """Characters class for VITs model for compatibility with pre-trained models""" + + def __init__( + self, + graphemes: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + ipa_characters: str = _phonemes, + ) -> None: + if ipa_characters is not None: + graphemes += ipa_characters + super().__init__(graphemes, punctuations, pad, None, None, "", is_unique=False, is_sorted=True) + + def _create_vocab(self): + self._vocab = [self._pad] + list(self._punctuations) + list(self._characters) + [self._blank] + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = dict(enumerate(self.vocab)) + + @staticmethod + def init_from_config(config: Coqpit): + if config.characters is not None: + _pad = config.characters["pad"] + _punctuations = config.characters["punctuations"] + _letters = config.characters["characters"] + _letters_ipa = config.characters["phonemes"] + return ( + VitsCharacters(graphemes=_letters, ipa_characters=_letters_ipa, punctuations=_punctuations, pad=_pad), + config, + ) + characters = VitsCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=None, + bos=None, + blank=self._blank, + is_unique=False, + is_sorted=True, + ) + + +class FairseqVocab(BaseVocabulary): + def __init__(self, vocab: str): + super(FairseqVocab).__init__() + self.vocab = vocab + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab_file): + with open(vocab_file, encoding="utf-8") as f: + self._vocab = [x.replace("\n", "") for x in f.readlines()] + self.blank = self._vocab[0] + self.pad = " " + self._char_to_id = {s: i for i, s in enumerate(self._vocab)} # pylint: disable=unnecessary-comprehension + self._id_to_char = dict(enumerate(self._vocab)) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4a76ad7d2d0816394eb00d908933ad490930c6 --- /dev/null +++ b/TTS/tts/models/xtts.py @@ -0,0 +1,803 @@ +import logging +import os +from dataclasses import dataclass +from pathlib import Path + +import librosa +import torch +import torch.nn.functional as F +import torchaudio +from coqpit import Coqpit +from trainer.io import load_fsspec + +from TTS.tts.layers.xtts.gpt import GPT +from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder +from TTS.tts.layers.xtts.stream_generator import init_stream_support +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence +from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager +from TTS.tts.models.base_tts import BaseTTS +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + +logger = logging.getLogger(__name__) + +init_stream_support() + + +def wav_to_mel_cloning( + wav, + mel_norms_file="../experiments/clips_mel_norms.pth", + mel_norms=None, + device=torch.device("cpu"), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, +): + """ + Convert waveform to mel-spectrogram with hard-coded parameters for cloning. + + Args: + wav (torch.Tensor): Input waveform tensor. + mel_norms_file (str): Path to mel-spectrogram normalization file. + mel_norms (torch.Tensor): Mel-spectrogram normalization tensor. + device (torch.device): Device to use for computation. + + Returns: + torch.Tensor: Mel-spectrogram tensor. + """ + mel_stft = torchaudio.transforms.MelSpectrogram( + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + power=power, + normalized=normalized, + sample_rate=sample_rate, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + norm="slaney", + ).to(device) + wav = wav.to(device) + mel = mel_stft(wav) + mel = torch.log(torch.clamp(mel, min=1e-5)) + if mel_norms is None: + mel_norms = torch.load(mel_norms_file, map_location=device, weights_only=is_pytorch_at_least_2_4()) + mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1) + return mel + + +def load_audio(audiopath, sampling_rate): + # better load setting following: https://github.com/faroit/python_audio_loading_benchmark + + # torchaudio should chose proper backend to load audio depending on platform + audio, lsr = torchaudio.load(audiopath) + + # stereo to mono if needed + if audio.size(0) != 1: + audio = torch.mean(audio, dim=0, keepdim=True) + + if lsr != sampling_rate: + audio = torchaudio.functional.resample(audio, lsr, sampling_rate) + + # Check some assumptions about audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk. + # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds. + if torch.any(audio > 10) or not torch.any(audio < 0): + logger.error("Error with %s. Max=%.2f min=%.2f", audiopath, audio.max(), audio.min()) + # clip audio invalid values + audio.clip_(-1, 1) + return audio + + +def pad_or_truncate(t, length): + """ + Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. + + Args: + t (torch.Tensor): The input tensor to be padded or truncated. + length (int): The desired length of the tensor. + + Returns: + torch.Tensor: The padded or truncated tensor. + """ + tp = t[..., :length] + if t.shape[-1] == length: + tp = t + elif t.shape[-1] < length: + tp = F.pad(t, (0, length - t.shape[-1])) + return tp + + +@dataclass +class XttsAudioConfig(Coqpit): + """ + Configuration class for audio-related parameters in the XTTS model. + + Args: + sample_rate (int): The sample rate in which the GPT operates. + output_sample_rate (int): The sample rate of the output audio waveform. + """ + + sample_rate: int = 22050 + output_sample_rate: int = 24000 + + +@dataclass +class XttsArgs(Coqpit): + """A dataclass to represent XTTS model arguments that define the model structure. + + Args: + gpt_batch_size (int): The size of the auto-regressive batch. + enable_redaction (bool, optional): Whether to enable redaction. Defaults to True. + kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True. + gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None. + clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. + decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. + num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. + + For GPT model: + gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. + gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402. + gpt_max_prompt_tokens (int, optional): The maximum prompt tokens or the autoregressive model. Defaults to 70. + gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30. + gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024. + gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16. + gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255. + gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255. + gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False. + gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False. + gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024. + gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True. + gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False. + """ + + gpt_batch_size: int = 1 + enable_redaction: bool = False + kv_cache: bool = True + gpt_checkpoint: str = None + clvp_checkpoint: str = None + decoder_checkpoint: str = None + num_chars: int = 255 + + # XTTS GPT Encoder params + tokenizer_file: str = "" + gpt_max_audio_tokens: int = 605 + gpt_max_text_tokens: int = 402 + gpt_max_prompt_tokens: int = 70 + gpt_layers: int = 30 + gpt_n_model_channels: int = 1024 + gpt_n_heads: int = 16 + gpt_number_text_tokens: int = None + gpt_start_text_token: int = None + gpt_stop_text_token: int = None + gpt_num_audio_tokens: int = 8194 + gpt_start_audio_token: int = 8192 + gpt_stop_audio_token: int = 8193 + gpt_code_stride_len: int = 1024 + gpt_use_masking_gt_prompt_approach: bool = True + gpt_use_perceiver_resampler: bool = False + + # HifiGAN Decoder params + input_sample_rate: int = 22050 + output_sample_rate: int = 24000 + output_hop_length: int = 256 + decoder_input_dim: int = 1024 + d_vector_dim: int = 512 + cond_d_vector_in_each_upsampling_layer: bool = True + + # constants + duration_const: int = 102400 + + +class Xtts(BaseTTS): + """ⓍTTS model implementation. + + ❗ Currently it only supports inference. + + Examples: + >>> from TTS.tts.configs.xtts_config import XttsConfig + >>> from TTS.tts.models.xtts import Xtts + >>> config = XttsConfig() + >>> model = Xtts.init_from_config(config) + >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True) + """ + + def __init__(self, config: Coqpit): + super().__init__(config, ap=None, tokenizer=None) + self.mel_stats_path = None + self.config = config + self.gpt_checkpoint = self.args.gpt_checkpoint + self.decoder_checkpoint = self.args.decoder_checkpoint # TODO: check if this is even needed + self.models_dir = config.model_dir + self.gpt_batch_size = self.args.gpt_batch_size + + self.tokenizer = VoiceBpeTokenizer() + self.gpt = None + self.init_models() + self.register_buffer("mel_stats", torch.ones(80)) + + def init_models(self): + """Initialize the models. We do it here since we need to load the tokenizer first.""" + if self.tokenizer.tokenizer is not None: + self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens() + self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]") + self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]") + + if self.args.gpt_number_text_tokens: + self.gpt = GPT( + layers=self.args.gpt_layers, + model_dim=self.args.gpt_n_model_channels, + start_text_token=self.args.gpt_start_text_token, + stop_text_token=self.args.gpt_stop_text_token, + heads=self.args.gpt_n_heads, + max_text_tokens=self.args.gpt_max_text_tokens, + max_mel_tokens=self.args.gpt_max_audio_tokens, + max_prompt_tokens=self.args.gpt_max_prompt_tokens, + number_text_tokens=self.args.gpt_number_text_tokens, + num_audio_tokens=self.args.gpt_num_audio_tokens, + start_audio_token=self.args.gpt_start_audio_token, + stop_audio_token=self.args.gpt_stop_audio_token, + use_perceiver_resampler=self.args.gpt_use_perceiver_resampler, + code_stride_len=self.args.gpt_code_stride_len, + ) + + self.hifigan_decoder = HifiDecoder( + input_sample_rate=self.args.input_sample_rate, + output_sample_rate=self.args.output_sample_rate, + output_hop_length=self.args.output_hop_length, + ar_mel_length_compression=self.args.gpt_code_stride_len, + decoder_input_dim=self.args.decoder_input_dim, + d_vector_dim=self.args.d_vector_dim, + cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, + ) + + @property + def device(self): + return next(self.parameters()).device + + @torch.inference_mode() + def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6): + """Compute the conditioning latents for the GPT model from the given audio. + + Args: + audio (tensor): audio tensor. + sr (int): Sample rate of the audio. + length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30. + chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio + is being used without chunking. It must be < `length`. Defaults to 6. + """ + if sr != 22050: + audio = torchaudio.functional.resample(audio, sr, 22050) + if length > 0: + audio = audio[:, : 22050 * length] + if self.args.gpt_use_perceiver_resampler: + style_embs = [] + for i in range(0, audio.shape[1], 22050 * chunk_length): + audio_chunk = audio[:, i : i + 22050 * chunk_length] + + # if the chunk is too short ignore it + if audio_chunk.size(-1) < 22050 * 0.33: + continue + + mel_chunk = wav_to_mel_cloning( + audio_chunk, + mel_norms=self.mel_stats.cpu(), + n_fft=2048, + hop_length=256, + win_length=1024, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None) + style_embs.append(style_emb) + + # mean style embedding + cond_latent = torch.stack(style_embs).mean(dim=0) + else: + mel = wav_to_mel_cloning( + audio, + mel_norms=self.mel_stats.cpu(), + n_fft=4096, + hop_length=1024, + win_length=4096, + power=2, + normalized=False, + sample_rate=22050, + f_min=0, + f_max=8000, + n_mels=80, + ) + cond_latent = self.gpt.get_style_emb(mel.to(self.device)) + return cond_latent.transpose(1, 2) + + @torch.inference_mode() + def get_speaker_embedding(self, audio, sr): + audio_16k = torchaudio.functional.resample(audio, sr, 16000) + return ( + self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True) + .unsqueeze(-1) + .to(self.device) + ) + + @torch.inference_mode() + def get_conditioning_latents( + self, + audio_path, + max_ref_length=30, + gpt_cond_len=6, + gpt_cond_chunk_len=6, + librosa_trim_db=None, + sound_norm_refs=False, + load_sr=22050, + ): + """Get the conditioning latents for the GPT model from the given audio. + + Args: + audio_path (str or List[str]): Path to reference audio file(s). + max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30. + gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6. + gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_conf_len. Defaults to 6. + librosa_trim_db (int, optional): Trim the audio using this value. If None, not trimming. Defaults to None. + sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False. + load_sr (int, optional): Sample rate to load the audio. Defaults to 24000. + """ + # deal with multiples references + if not isinstance(audio_path, list): + audio_paths = [audio_path] + else: + audio_paths = audio_path + + speaker_embeddings = [] + audios = [] + speaker_embedding = None + for file_path in audio_paths: + audio = load_audio(file_path, load_sr) + audio = audio[:, : load_sr * max_ref_length].to(self.device) + if sound_norm_refs: + audio = (audio / torch.abs(audio).max()) * 0.75 + if librosa_trim_db is not None: + audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] + + # compute latents for the decoder + speaker_embedding = self.get_speaker_embedding(audio, load_sr) + speaker_embeddings.append(speaker_embedding) + + audios.append(audio) + + # merge all the audios and compute the latents for the gpt + full_audio = torch.cat(audios, dim=-1) + gpt_cond_latents = self.get_gpt_cond_latents( + full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len + ) # [1, 1024, T] + + if speaker_embeddings: + speaker_embedding = torch.stack(speaker_embeddings) + speaker_embedding = speaker_embedding.mean(dim=0) + + return gpt_cond_latents, speaker_embedding + + def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs): + """Synthesize speech with the given input text. + + Args: + text (str): Input text. + config (XttsConfig): Config with inference parameters. + speaker_wav (list): List of paths to the speaker audio files to be used for cloning. + language (str): Language ID of the speaker. + **kwargs: Inference settings. See `inference()`. + + Returns: + A dictionary of the output values with `wav` as output waveform, `deterministic_seed` as seed used at inference, + `text_input` as text token IDs after tokenizer, `voice_samples` as samples used for cloning, `conditioning_latents` + as latents used at inference. + + """ + assert ( + "zh-cn" if language == "zh" else language in self.config.languages + ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" + # Use generally found best tuning knobs for generation. + settings = { + "temperature": config.temperature, + "length_penalty": config.length_penalty, + "repetition_penalty": config.repetition_penalty, + "top_k": config.top_k, + "top_p": config.top_p, + } + settings.update(kwargs) # allow overriding of preset settings with kwargs + if speaker_id is not None: + gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values() + return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings) + settings.update( + { + "gpt_cond_len": config.gpt_cond_len, + "gpt_cond_chunk_len": config.gpt_cond_chunk_len, + "max_ref_len": config.max_ref_len, + "sound_norm_refs": config.sound_norm_refs, + } + ) + return self.full_inference(text, speaker_wav, language, **settings) + + @torch.inference_mode() + def full_inference( + self, + text, + ref_audio_path, + language, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + # Cloning + gpt_cond_len=30, + gpt_cond_chunk_len=6, + max_ref_len=10, + sound_norm_refs=False, + **hf_generate_kwargs, + ): + """ + This function produces an audio clip of the given text being spoken with the given reference voice. + + Args: + text: (str) Text to be spoken. + + ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3 + seconds long. + + language: (str) Language of the voice to be generated. + + temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.65. + + length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings causes the + model to produce more terse outputs. Defaults to 1.0. + + repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during + decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 2.0. + + top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 50. + + top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely" + (aka boring) outputs. Defaults to 0.8. + + gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used + else the first `gpt_cond_len` secs is used. Defaults to 30 seconds. + + gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`. + If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds. + + hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive + transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation + here: https://huggingface.co/docs/transformers/internal/generation_utils + + Returns: + Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. + Sample rate is 24kHz. + """ + (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents( + audio_path=ref_audio_path, + gpt_cond_len=gpt_cond_len, + gpt_cond_chunk_len=gpt_cond_chunk_len, + max_ref_length=max_ref_len, + sound_norm_refs=sound_norm_refs, + ) + + return self.inference( + text, + language, + gpt_cond_latent, + speaker_embedding, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + do_sample=do_sample, + **hf_generate_kwargs, + ) + + @torch.inference_mode() + def inference( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + num_beams=1, + speed=1.0, + enable_text_splitting=False, + **hf_generate_kwargs, + ): + language = language.split("-")[0] # remove the country code + length_scale = 1.0 / max(speed, 0.05) + gpt_cond_latent = gpt_cond_latent.to(self.device) + speaker_embedding = speaker_embedding.to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + wavs = [] + gpt_latents_list = [] + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) + + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) + + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + + gpt_latents_list.append(gpt_latents.cpu()) + wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) + + return { + "wav": torch.cat(wavs, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), + "speaker_embedding": speaker_embedding, + } + + def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): + """Handle chunk formatting in streaming mode""" + wav_chunk = wav_gen[:-overlap_len] + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len] + if wav_overlap is not None: + # cross fade the overlap section + if overlap_len > len(wav_chunk): + # wav_chunk is smaller than overlap_len, pass on last wav_gen + if wav_gen_prev is not None: + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :] + else: + # not expecting will hit here as problem happens on last chunk + wav_chunk = wav_gen[-overlap_len:] + return wav_chunk, wav_gen, None + else: + crossfade_wav = wav_chunk[:overlap_len] + crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) + wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) + wav_chunk[:overlap_len] += crossfade_wav + + wav_overlap = wav_gen[-overlap_len:] + wav_gen_prev = wav_gen + return wav_chunk, wav_gen_prev, wav_overlap + + @torch.inference_mode() + def inference_stream( + self, + text, + language, + gpt_cond_latent, + speaker_embedding, + # Streaming + stream_chunk_size=20, + overlap_wav_len=1024, + # GPT inference + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, + top_k=50, + top_p=0.85, + do_sample=True, + speed=1.0, + enable_text_splitting=False, + **hf_generate_kwargs, + ): + language = language.split("-")[0] # remove the country code + length_scale = 1.0 / max(speed, 0.05) + gpt_cond_latent = gpt_cond_latent.to(self.device) + speaker_embedding = speaker_embedding.to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + return_dict_in_generate=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk + + def forward(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://coqui-tts.readthedocs.io/en/latest/models/xtts.html#training" + ) + + def eval_step(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://coqui-tts.readthedocs.io/en/latest/models/xtts.html#training" + ) + + @staticmethod + def init_from_config(config: "XttsConfig", **kwargs): # pylint: disable=unused-argument + return Xtts(config) + + def eval(self): # pylint: disable=redefined-builtin + """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode.""" + self.gpt.init_gpt_for_inference() + super().eval() + + def get_compatible_checkpoint_state_dict(self, model_path): + checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] + # remove xtts gpt trainer extra keys + ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"] + for key in list(checkpoint.keys()): + # check if it is from the coqui Trainer if so convert it + if key.startswith("xtts."): + new_key = key.replace("xtts.", "") + checkpoint[new_key] = checkpoint[key] + del checkpoint[key] + key = new_key + + # remove unused keys + if key.split(".")[0] in ignore_keys: + del checkpoint[key] + + return checkpoint + + def load_checkpoint( + self, + config, + checkpoint_dir=None, + checkpoint_path=None, + vocab_path=None, + eval=True, + strict=True, + use_deepspeed=False, + speaker_file_path=None, + ): + """ + Loads a checkpoint from disk and initializes the model's state and tokenizer. + + Args: + config (dict): The configuration dictionary for the model. + checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None. + checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None. + vocab_path (str, optional): The path to the vocabulary file. Defaults to None. + eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True. + strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True. + + Returns: + None + """ + + model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") + if vocab_path is None: + if checkpoint_dir is not None and (Path(checkpoint_dir) / "vocab.json").is_file(): + vocab_path = str(Path(checkpoint_dir) / "vocab.json") + else: + vocab_path = config.model_args.tokenizer_file + + if speaker_file_path is None and checkpoint_dir is not None: + speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth") + + self.language_manager = LanguageManager(config) + self.speaker_manager = None + if speaker_file_path is not None and os.path.exists(speaker_file_path): + self.speaker_manager = SpeakerManager(speaker_file_path) + + if os.path.exists(vocab_path): + self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path) + + self.init_models() + + checkpoint = self.get_compatible_checkpoint_state_dict(model_path) + + # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 do not + try: + self.load_state_dict(checkpoint, strict=strict) + except: + if eval: + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache) + self.load_state_dict(checkpoint, strict=strict) + + if eval: + self.hifigan_decoder.eval() + self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) + self.gpt.eval() + + def train_step(self): + raise NotImplementedError( + "XTTS has a dedicated trainer, please check the XTTS docs: https://coqui-tts.readthedocs.io/en/latest/models/xtts.html#training" + ) diff --git a/TTS/tts/utils/__init__.py b/TTS/tts/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/assets/tortoise/tokenizer.json b/TTS/tts/utils/assets/tortoise/tokenizer.json new file mode 100644 index 0000000000000000000000000000000000000000..c2fb44a729234c642deac6ed3e83e702c54ce29d --- /dev/null +++ b/TTS/tts/utils/assets/tortoise/tokenizer.json @@ -0,0 +1 @@ +{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..22e46b683adfc7f6c7c8a57fb5b697e422cd915c --- /dev/null +++ b/TTS/tts/utils/data.py @@ -0,0 +1,79 @@ +import bisect + +import numpy as np +import torch + + +def _pad_data(x, length): + _pad = 0 + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) + + +def prepare_data(inputs): + max_len = max((len(x) for x in inputs)) + return np.stack([_pad_data(x, max_len) for x in inputs]) + + +def _pad_tensor(x, length): + _pad = 0.0 + assert x.ndim == 2 + x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad) + return x + + +def prepare_tensor(inputs, out_steps): + max_len = max((x.shape[1] for x in inputs)) + remainder = max_len % out_steps + pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len + return np.stack([_pad_tensor(x, pad_len) for x in inputs]) + + +def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray: + """Pad stop target array. + + Args: + x (np.ndarray): Stop target array. + length (int): Length after padding. + pad_val (int, optional): Padding value. Defaults to 1. + + Returns: + np.ndarray: Padded stop target array. + """ + assert x.ndim == 1 + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val) + + +def prepare_stop_target(inputs, out_steps): + """Pad row vectors with 1.""" + max_len = max((x.shape[0] for x in inputs)) + remainder = max_len % out_steps + pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len + return np.stack([_pad_stop_target(x, pad_len) for x in inputs]) + + +def pad_per_step(inputs, pad_len): + return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) + + +def get_length_balancer_weights(items: list, num_buckets=10): + # get all durations + audio_lengths = np.array([item["audio_length"] for item in items]) + # create the $num_buckets buckets classes based in the dataset max and min length + max_length = int(max(audio_lengths)) + min_length = int(min(audio_lengths)) + step = int((max_length - min_length) / num_buckets) + 1 + buckets_classes = [i + step for i in range(min_length, (max_length - step) + num_buckets + 1, step)] + # add each sample in their respective length bucket + buckets_names = np.array( + [buckets_classes[bisect.bisect_left(buckets_classes, item["audio_length"])] for item in items] + ) + # count and compute the weights_bucket for each sample + unique_buckets_names = np.unique(buckets_names).tolist() + bucket_ids = [unique_buckets_names.index(l) for l in buckets_names] + bucket_count = np.array([len(np.where(buckets_names == l)[0]) for l in unique_buckets_names]) + weight_bucket = 1.0 / bucket_count + dataset_samples_weight = np.array([weight_bucket[l] for l in bucket_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/fairseq.py b/TTS/tts/utils/fairseq.py new file mode 100644 index 0000000000000000000000000000000000000000..20907a05323d763051aa068dd55419d9dd15e3da --- /dev/null +++ b/TTS/tts/utils/fairseq.py @@ -0,0 +1,50 @@ +import torch + +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + + +def rehash_fairseq_vits_checkpoint(checkpoint_file): + chk = torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4())["model"] + new_chk = {} + for k, v in chk.items(): + if "enc_p." in k: + new_chk[k.replace("enc_p.", "text_encoder.")] = v + elif "dec." in k: + new_chk[k.replace("dec.", "waveform_decoder.")] = v + elif "enc_q." in k: + new_chk[k.replace("enc_q.", "posterior_encoder.")] = v + elif "flow.flows.2." in k: + new_chk[k.replace("flow.flows.2.", "flow.flows.1.")] = v + elif "flow.flows.4." in k: + new_chk[k.replace("flow.flows.4.", "flow.flows.2.")] = v + elif "flow.flows.6." in k: + new_chk[k.replace("flow.flows.6.", "flow.flows.3.")] = v + elif "dp.flows.0.m" in k: + new_chk[k.replace("dp.flows.0.m", "duration_predictor.flows.0.translation")] = v + elif "dp.flows.0.logs" in k: + new_chk[k.replace("dp.flows.0.logs", "duration_predictor.flows.0.log_scale")] = v + elif "dp.flows.1" in k: + new_chk[k.replace("dp.flows.1", "duration_predictor.flows.1")] = v + elif "dp.flows.3" in k: + new_chk[k.replace("dp.flows.3", "duration_predictor.flows.2")] = v + elif "dp.flows.5" in k: + new_chk[k.replace("dp.flows.5", "duration_predictor.flows.3")] = v + elif "dp.flows.7" in k: + new_chk[k.replace("dp.flows.7", "duration_predictor.flows.4")] = v + elif "dp.post_flows.0.m" in k: + new_chk[k.replace("dp.post_flows.0.m", "duration_predictor.post_flows.0.translation")] = v + elif "dp.post_flows.0.logs" in k: + new_chk[k.replace("dp.post_flows.0.logs", "duration_predictor.post_flows.0.log_scale")] = v + elif "dp.post_flows.1" in k: + new_chk[k.replace("dp.post_flows.1", "duration_predictor.post_flows.1")] = v + elif "dp.post_flows.3" in k: + new_chk[k.replace("dp.post_flows.3", "duration_predictor.post_flows.2")] = v + elif "dp.post_flows.5" in k: + new_chk[k.replace("dp.post_flows.5", "duration_predictor.post_flows.3")] = v + elif "dp.post_flows.7" in k: + new_chk[k.replace("dp.post_flows.7", "duration_predictor.post_flows.4")] = v + elif "dp." in k: + new_chk[k.replace("dp.", "duration_predictor.")] = v + else: + new_chk[k] = v + return new_chk diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7429d0fcc88e2c0250101038654835c991ef740e --- /dev/null +++ b/TTS/tts/utils/helpers.py @@ -0,0 +1,257 @@ +import numpy as np +import torch +from scipy.stats import betabinom +from torch.nn import functional as F + +try: + from TTS.tts.utils.monotonic_align.core import maximum_path_c + + CYTHON = True +except ModuleNotFoundError: + CYTHON = False + + +class StandardScaler: + """StandardScaler for mean-scale normalization with the given mean and scale values.""" + + def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: + self.mean_ = mean + self.scale_ = scale + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, "mean_") + delattr(self, "scale_") + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X + + +# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 +def sequence_mask(sequence_length, max_len=None): + """Create a sequence mask for filtering padding in a sequence tensor. + + Args: + sequence_length (torch.tensor): Sequence lengths. + max_len (int, Optional): Maximum sequence length. Defaults to None. + + Shapes: + - mask: :math:`[B, T_max]` + """ + if max_len is None: + max_len = sequence_length.max() + seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) + # B x T_max + return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) + + +def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): + """Segment each sample in a batch based on the provided segment indices + + Args: + x (torch.tensor): Input tensor. + segment_indices (torch.tensor): Segment indices. + segment_size (int): Expected output segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. + """ + # pad the input tensor if it is shorter than the segment size + if pad_short and x.shape[-1] < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - x.size(2))) + + segments = torch.zeros_like(x[:, :, :segment_size]) + + for i in range(x.size(0)): + index_start = segment_indices[i] + index_end = index_start + segment_size + x_i = x[i] + if pad_short and index_end >= x.size(2): + # pad the sample if it is shorter than the segment size + x_i = torch.nn.functional.pad(x_i, (0, (index_end + 1) - x.size(2))) + segments[i] = x_i[:, index_start:index_end] + return segments + + +def rand_segments( + x: torch.tensor, x_lengths: torch.tensor = None, segment_size=4, let_short_samples=False, pad_short=False +): + """Create random segments based on the input lengths. + + Args: + x (torch.tensor): Input tensor. + x_lengths (torch.tensor): Input lengths. + segment_size (int): Expected output segment size. + let_short_samples (bool): Allow shorter samples than the segment size. + pad_short (bool): Pad the end of input tensor with zeros if shorter than the segment size. + + Shapes: + - x: :math:`[B, C, T]` + - x_lengths: :math:`[B]` + """ + _x_lenghts = x_lengths.clone() + B, _, T = x.size() + if pad_short: + if T < segment_size: + x = torch.nn.functional.pad(x, (0, segment_size - T)) + T = segment_size + if _x_lenghts is None: + _x_lenghts = T + len_diff = _x_lenghts - segment_size + if let_short_samples: + _x_lenghts[len_diff < 0] = segment_size + len_diff = _x_lenghts - segment_size + else: + assert all( + len_diff > 0 + ), f" [!] At least one sample is shorter than the segment size ({segment_size}). \n {_x_lenghts}" + segment_indices = (torch.rand([B]).type_as(x) * (len_diff + 1)).long() + ret = segment(x, segment_indices, segment_size, pad_short=pad_short) + return ret, segment_indices + + +def average_over_durations(values, durs): + """Average values over durations. + + Shapes: + - values: :math:`[B, 1, T_de]` + - durs: :math:`[B, T_en]` + - avg: :math:`[B, 1, T_en]` + """ + durs_cums_ends = torch.cumsum(durs, dim=1).long() + durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) + values_nonzero_cums = torch.nn.functional.pad(torch.cumsum(values != 0.0, dim=2), (1, 0)) + values_cums = torch.nn.functional.pad(torch.cumsum(values, dim=2), (1, 0)) + + bs, l = durs_cums_ends.size() + n_formants = values.size(1) + dcs = durs_cums_starts[:, None, :].expand(bs, n_formants, l) + dce = durs_cums_ends[:, None, :].expand(bs, n_formants, l) + + values_sums = (torch.gather(values_cums, 2, dce) - torch.gather(values_cums, 2, dcs)).float() + values_nelems = (torch.gather(values_nonzero_cums, 2, dce) - torch.gather(values_nonzero_cums, 2, dcs)).float() + + avg = torch.where(values_nelems == 0.0, values_nelems, values_sums / values_nelems) + return avg + + +def convert_pad_shape(pad_shape: list[list]) -> list: + l = pad_shape[::-1] + return [item for sublist in l for item in sublist] + + +def generate_path(duration, mask): + """ + Shapes: + - duration: :math:`[B, T_en]` + - mask: :math:'[B, T_en, T_de]` + - path: :math:`[B, T_en, T_de]` + """ + b, t_x, t_y = mask.shape + cum_duration = torch.cumsum(duration, 1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path * mask + return path + + +def maximum_path(value, mask): + if CYTHON: + return maximum_path_cython(value, mask) + return maximum_path_numpy(value, mask) + + +def maximum_path_cython(value, mask): + """Cython optimised version. + Shapes: + - value: :math:`[B, T_en, T_de]` + - mask: :math:`[B, T_en, T_de]` + """ + value = value * mask + device = value.device + dtype = value.dtype + value = value.data.cpu().numpy().astype(np.float32) + path = np.zeros_like(value).astype(np.int32) + mask = mask.data.cpu().numpy() + + t_x_max = mask.sum(1)[:, 0].astype(np.int32) + t_y_max = mask.sum(2)[:, 0].astype(np.int32) + maximum_path_c(path, value, t_x_max, t_y_max) + return torch.from_numpy(path).to(device=device, dtype=dtype) + + +def maximum_path_numpy(value, mask, max_neg_val=None): + """ + Monotonic alignment search algorithm + Numpy-friendly version. It's about 4 times faster than torch version. + value: [b, t_x, t_y] + mask: [b, t_x, t_y] + """ + if max_neg_val is None: + max_neg_val = -np.inf # Patch for Sphinx complaint + value = value * mask + + device = value.device + dtype = value.dtype + value = value.cpu().detach().numpy() + mask = mask.cpu().detach().numpy().astype(bool) + + b, t_x, t_y = value.shape + direction = np.zeros(value.shape, dtype=np.int64) + v = np.zeros((b, t_x), dtype=np.float32) + x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1) + for j in range(t_y): + v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1] + v1 = v + max_mask = v1 >= v0 + v_max = np.where(max_mask, v1, v0) + direction[:, :, j] = max_mask + + index_mask = x_range <= j + v = np.where(index_mask, v_max + value[:, :, j], max_neg_val) + direction = np.where(mask, direction, 1) + + path = np.zeros(value.shape, dtype=np.float32) + index = mask[:, :, 0].sum(1).astype(np.int64) - 1 + index_range = np.arange(b) + for j in reversed(range(t_y)): + path[index_range, index, j] = 1 + index = index + direction[index_range, index, j] - 1 + path = path * mask.astype(np.float32) + path = torch.from_numpy(path).to(device=device, dtype=dtype) + return path + + +def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): + P, M = phoneme_count, mel_count + x = np.arange(0, P) + mel_text_probs = [] + for i in range(1, M + 1): + a, b = scaling_factor * i, scaling_factor * (M + 1 - i) + rv = betabinom(P, a, b) + mel_i_prob = rv.pmf(x) + mel_text_probs.append(mel_i_prob) + return np.array(mel_text_probs) + + +def compute_attn_prior(x_len, y_len, scaling_factor=1.0): + """Compute attention priors for the alignment network.""" + attn_prior = beta_binomial_prior_distribution( + x_len, + y_len, + scaling_factor, + ) + return attn_prior # [y_len, x_len] diff --git a/TTS/tts/utils/languages.py b/TTS/tts/utils/languages.py new file mode 100644 index 0000000000000000000000000000000000000000..f134daf58e68db70e5c56c669f6530f7a5a24722 --- /dev/null +++ b/TTS/tts/utils/languages.py @@ -0,0 +1,125 @@ +import os +from typing import Any, Dict, List, Optional + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit + +from TTS.config import check_config_and_model_args +from TTS.tts.utils.managers import BaseIDManager + + +class LanguageManager(BaseIDManager): + """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by language. + + Args: + language_ids_file_path (str, optional): Path to the metafile that maps language names to ids used by + TTS models. Defaults to "". + config (Coqpit, optional): Coqpit config that contains the language information in the datasets filed. + Defaults to None. + + Examples: + >>> manager = LanguageManager(language_ids_file_path=language_ids_file_path) + >>> language_id_mapper = manager.language_ids + """ + + def __init__( + self, + language_ids_file_path: str = "", + config: Coqpit = None, + ): + super().__init__(id_file_path=language_ids_file_path) + + if config: + self.set_language_ids_from_config(config) + + @property + def num_languages(self) -> int: + return len(list(self.name_to_id.keys())) + + @property + def language_names(self) -> List: + return list(self.name_to_id.keys()) + + @staticmethod + def parse_language_ids_from_config(c: Coqpit) -> Dict: + """Set language id from config. + + Args: + c (Coqpit): Config + + Returns: + Tuple[Dict, int]: Language ID mapping and the number of languages. + """ + languages = set({}) + for dataset in c.datasets: + if "language" in dataset: + languages.add(dataset["language"]) + else: + raise ValueError(f"Dataset {dataset['name']} has no language specified.") + return {name: i for i, name in enumerate(sorted(languages))} + + def set_language_ids_from_config(self, c: Coqpit) -> None: + """Set language IDs from config samples. + + Args: + c (Coqpit): Config. + """ + self.name_to_id = self.parse_language_ids_from_config(c) + + @staticmethod + def parse_ids_from_data(items: List, parse_key: str) -> Any: + raise NotImplementedError + + def set_ids_from_data(self, items: List, parse_key: str) -> Any: + raise NotImplementedError + + def save_ids_to_file(self, file_path: str) -> None: + """Save language IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + self._save_json(file_path, self.name_to_id) + + @staticmethod + def init_from_config(config: Coqpit) -> Optional["LanguageManager"]: + """Initialize the language manager from a Coqpit config. + + Args: + config (Coqpit): Coqpit config. + """ + if check_config_and_model_args(config, "use_language_embedding", True): + if config.get("language_ids_file", None): + return LanguageManager(language_ids_file_path=config.language_ids_file) + # Fall back to parse language IDs from the config + return LanguageManager(config=config) + return None + + +def _set_file_path(path): + """Find the language_ids.json under the given path or the above it. + Intended to band aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "language_ids.json") + path_continue = os.path.join(path, "language_ids.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + return None + + +def get_language_balancer_weights(items: list): + language_names = np.array([item["language"] for item in items]) + unique_language_names = np.unique(language_names).tolist() + language_ids = [unique_language_names.index(l) for l in language_names] + language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names]) + weight_language = 1.0 / language_count + # get weight for each sample + dataset_samples_weight = np.array([weight_language[l] for l in language_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/managers.py b/TTS/tts/utils/managers.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2f7df67b31b14f66c9ce698c5010c325477956 --- /dev/null +++ b/TTS/tts/utils/managers.py @@ -0,0 +1,384 @@ +import json +import random +from typing import Any, Dict, List, Tuple, Union + +import fsspec +import numpy as np +import torch + +from TTS.config import load_config +from TTS.encoder.utils.generic_utils import setup_encoder_model +from TTS.utils.audio import AudioProcessor +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 + + +def load_file(path: str): + if path.endswith(".json"): + with fsspec.open(path, "r") as f: + return json.load(f) + elif path.endswith(".pth"): + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location="cpu", weights_only=is_pytorch_at_least_2_4()) + else: + raise ValueError("Unsupported file type") + + +def save_file(obj: Any, path: str): + if path.endswith(".json"): + with fsspec.open(path, "w") as f: + json.dump(obj, f, indent=4) + elif path.endswith(".pth"): + with fsspec.open(path, "wb") as f: + torch.save(obj, f) + else: + raise ValueError("Unsupported file type") + + +class BaseIDManager: + """Base `ID` Manager class. Every new `ID` manager must inherit this. + It defines common `ID` manager specific functions. + """ + + def __init__(self, id_file_path: str = ""): + self.name_to_id = {} + + if id_file_path: + self.load_ids_from_file(id_file_path) + + @staticmethod + def _load_json(json_file_path: str) -> Dict: + with fsspec.open(json_file_path, "r") as f: + return json.load(f) + + @staticmethod + def _save_json(json_file_path: str, data: dict) -> None: + with fsspec.open(json_file_path, "w") as f: + json.dump(data, f, indent=4) + + def set_ids_from_data(self, items: List, parse_key: str) -> None: + """Set IDs from data samples. + + Args: + items (List): Data sampled returned by `load_tts_samples()`. + """ + self.name_to_id = self.parse_ids_from_data(items, parse_key=parse_key) + + def load_ids_from_file(self, file_path: str) -> None: + """Set IDs from a file. + + Args: + file_path (str): Path to the file. + """ + self.name_to_id = load_file(file_path) + + def save_ids_to_file(self, file_path: str) -> None: + """Save IDs to a json file. + + Args: + file_path (str): Path to the output file. + """ + save_file(self.name_to_id, file_path) + + def get_random_id(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.name_to_id: + return self.name_to_id[random.choices(list(self.name_to_id.keys()))[0]] + + return None + + @staticmethod + def parse_ids_from_data(items: List, parse_key: str) -> Tuple[Dict]: + """Parse IDs from data samples retured by `load_tts_samples()`. + + Args: + items (list): Data sampled returned by `load_tts_samples()`. + parse_key (str): The key to being used to parse the data. + Returns: + Tuple[Dict]: speaker IDs. + """ + classes = sorted({item[parse_key] for item in items}) + ids = {name: i for i, name in enumerate(classes)} + return ids + + +class EmbeddingManager(BaseIDManager): + """Base `Embedding` Manager class. Every new `Embedding` manager must inherit this. + It defines common `Embedding` manager specific functions. + + It expects embeddings files in the following format: + + :: + + { + 'audio_file_key':{ + 'name': 'category_name', + 'embedding'[] + }, + ... + } + + `audio_file_key` is a unique key to the audio file in the dataset. It can be the path to the file or any other unique key. + `embedding` is the embedding vector of the audio file. + `name` can be name of the speaker of the audio file. + """ + + def __init__( + self, + embedding_file_path: Union[str, List[str]] = "", + id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__(id_file_path=id_file_path) + + self.embeddings = {} + self.embeddings_by_names = {} + self.clip_ids = [] + self.encoder = None + self.encoder_ap = None + self.use_cuda = use_cuda + + if embedding_file_path: + if isinstance(embedding_file_path, list): + self.load_embeddings_from_list_of_files(embedding_file_path) + else: + self.load_embeddings_from_file(embedding_file_path) + + if encoder_model_path and encoder_config_path: + self.init_encoder(encoder_model_path, encoder_config_path, use_cuda) + + @property + def num_embeddings(self): + """Get number of embeddings.""" + return len(self.embeddings) + + @property + def num_names(self): + """Get number of embeddings.""" + return len(self.embeddings_by_names) + + @property + def embedding_dim(self): + """Dimensionality of embeddings. If embeddings are not loaded, returns zero.""" + if self.embeddings: + return len(self.embeddings[list(self.embeddings.keys())[0]]["embedding"]) + return 0 + + @property + def embedding_names(self): + """Get embedding names.""" + return list(self.embeddings_by_names.keys()) + + def save_embeddings_to_file(self, file_path: str) -> None: + """Save embeddings to a json file. + + Args: + file_path (str): Path to the output file. + """ + save_file(self.embeddings, file_path) + + @staticmethod + def read_embeddings_from_file(file_path: str): + """Load embeddings from a json file. + + Args: + file_path (str): Path to the file. + """ + embeddings = load_file(file_path) + speakers = sorted({x["name"] for x in embeddings.values()}) + name_to_id = {name: i for i, name in enumerate(speakers)} + clip_ids = list(set(clip_name for clip_name in embeddings.keys())) + # cache embeddings_by_names for fast inference using a bigger speakers.json + embeddings_by_names = {} + for x in embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return name_to_id, clip_ids, embeddings, embeddings_by_names + + def load_embeddings_from_file(self, file_path: str) -> None: + """Load embeddings from a json file. + + Args: + file_path (str): Path to the target json file. + """ + self.name_to_id, self.clip_ids, self.embeddings, self.embeddings_by_names = self.read_embeddings_from_file( + file_path + ) + + def load_embeddings_from_list_of_files(self, file_paths: List[str]) -> None: + """Load embeddings from a list of json files and don't allow duplicate keys. + + Args: + file_paths (List[str]): List of paths to the target json files. + """ + self.name_to_id = {} + self.clip_ids = [] + self.embeddings_by_names = {} + self.embeddings = {} + for file_path in file_paths: + ids, clip_ids, embeddings, embeddings_by_names = self.read_embeddings_from_file(file_path) + # check colliding keys + duplicates = set(self.embeddings.keys()) & set(embeddings.keys()) + if duplicates: + raise ValueError(f" [!] Duplicate embedding names <{duplicates}> in {file_path}") + # store values + self.name_to_id.update(ids) + self.clip_ids.extend(clip_ids) + self.embeddings_by_names.update(embeddings_by_names) + self.embeddings.update(embeddings) + + # reset name_to_id to get the right speaker ids + self.name_to_id = {name: i for i, name in enumerate(self.name_to_id)} + + def get_embedding_by_clip(self, clip_idx: str) -> List: + """Get embedding by clip ID. + + Args: + clip_idx (str): Target clip ID. + + Returns: + List: embedding as a list. + """ + return self.embeddings[clip_idx]["embedding"] + + def get_embeddings_by_name(self, idx: str) -> List[List]: + """Get all embeddings of a speaker. + + Args: + idx (str): Target name. + + Returns: + List[List]: all the embeddings of the given speaker. + """ + return self.embeddings_by_names[idx] + + def get_embeddings_by_names(self) -> Dict: + """Get all embeddings by names. + + Returns: + Dict: all the embeddings of each speaker. + """ + embeddings_by_names = {} + for x in self.embeddings.values(): + if x["name"] not in embeddings_by_names.keys(): + embeddings_by_names[x["name"]] = [x["embedding"]] + else: + embeddings_by_names[x["name"]].append(x["embedding"]) + return embeddings_by_names + + def get_mean_embedding(self, idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray: + """Get mean embedding of a idx. + + Args: + idx (str): Target name. + num_samples (int, optional): Number of samples to be averaged. Defaults to None. + randomize (bool, optional): Pick random `num_samples` of embeddings. Defaults to False. + + Returns: + np.ndarray: Mean embedding. + """ + embeddings = self.get_embeddings_by_name(idx) + if num_samples is None: + embeddings = np.stack(embeddings).mean(0) + else: + assert len(embeddings) >= num_samples, f" [!] {idx} has number of samples < {num_samples}" + if randomize: + embeddings = np.stack(random.choices(embeddings, k=num_samples)).mean(0) + else: + embeddings = np.stack(embeddings[:num_samples]).mean(0) + return embeddings + + def get_random_embedding(self) -> Any: + """Get a random embedding. + + Args: + + Returns: + np.ndarray: embedding. + """ + if self.embeddings: + return self.embeddings[random.choices(list(self.embeddings.keys()))[0]]["embedding"] + + return None + + def get_clips(self) -> List: + return sorted(self.embeddings.keys()) + + def init_encoder(self, model_path: str, config_path: str, use_cuda=False) -> None: + """Initialize a speaker encoder model. + + Args: + model_path (str): Model file path. + config_path (str): Model config file path. + use_cuda (bool, optional): Use CUDA. Defaults to False. + """ + self.use_cuda = use_cuda + self.encoder_config = load_config(config_path) + self.encoder = setup_encoder_model(self.encoder_config) + self.encoder_criterion = self.encoder.load_checkpoint( + self.encoder_config, model_path, eval=True, use_cuda=use_cuda, cache=True + ) + self.encoder_ap = AudioProcessor(**self.encoder_config.audio) + + def compute_embedding_from_clip(self, wav_file: Union[str, List[str]]) -> list: + """Compute a embedding from a given audio file. + + Args: + wav_file (Union[str, List[str]]): Target file path. + + Returns: + list: Computed embedding. + """ + + def _compute(wav_file: str): + waveform = self.encoder_ap.load_wav(wav_file, sr=self.encoder_ap.sample_rate) + if not self.encoder_config.model_params.get("use_torch_spec", False): + m_input = self.encoder_ap.melspectrogram(waveform) + m_input = torch.from_numpy(m_input) + else: + m_input = torch.from_numpy(waveform) + + if self.use_cuda: + m_input = m_input.cuda() + m_input = m_input.unsqueeze(0) + embedding = self.encoder.compute_embedding(m_input) + return embedding + + if isinstance(wav_file, list): + # compute the mean embedding + embeddings = None + for wf in wav_file: + embedding = _compute(wf) + if embeddings is None: + embeddings = embedding + else: + embeddings += embedding + return (embeddings / len(wav_file))[0].tolist() + embedding = _compute(wav_file) + return embedding[0].tolist() + + def compute_embeddings(self, feats: Union[torch.Tensor, np.ndarray]) -> List: + """Compute embedding from features. + + Args: + feats (Union[torch.Tensor, np.ndarray]): Input features. + + Returns: + List: computed embedding. + """ + if isinstance(feats, np.ndarray): + feats = torch.from_numpy(feats) + if feats.ndim == 2: + feats = feats.unsqueeze(0) + if self.use_cuda: + feats = feats.cuda() + return self.encoder.compute_embedding(feats) diff --git a/TTS/tts/utils/measures.py b/TTS/tts/utils/measures.py new file mode 100644 index 0000000000000000000000000000000000000000..90e862e1190bdb8443933580b3ff47321f70cecd --- /dev/null +++ b/TTS/tts/utils/measures.py @@ -0,0 +1,15 @@ +def alignment_diagonal_score(alignments, binary=False): + """ + Compute how diagonal alignment predictions are. It is useful + to measure the alignment consistency of a model + Args: + alignments (torch.Tensor): batch of alignments. + binary (bool): if True, ignore scores and consider attention + as a binary mask. + Shape: + - alignments : :math:`[B, T_de, T_en]` + """ + maxs = alignments.max(dim=1)[0] + if binary: + maxs[maxs > 0] = 1 + return maxs.mean(dim=1).mean(dim=0).item() diff --git a/TTS/tts/utils/monotonic_align/__init__.py b/TTS/tts/utils/monotonic_align/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/monotonic_align/core.c b/TTS/tts/utils/monotonic_align/core.c new file mode 100644 index 0000000000000000000000000000000000000000..ba3916635d06df2df018202d913b40c2cc55ebdd --- /dev/null +++ b/TTS/tts/utils/monotonic_align/core.c @@ -0,0 +1,30453 @@ +/* Generated by Cython 3.0.11 */ + +/* BEGIN: Cython Metadata +{ + "distutils": { + "depends": [], + "name": "TTS.tts.utils.monotonic_align.core", + "sources": [ + "TTS/tts/utils/monotonic_align/core.pyx" + ] + }, + "module_name": "TTS.tts.utils.monotonic_align.core" +} +END: Cython Metadata */ + +#ifndef PY_SSIZE_T_CLEAN +#define PY_SSIZE_T_CLEAN +#endif /* PY_SSIZE_T_CLEAN */ +#if defined(CYTHON_LIMITED_API) && 0 + #ifndef Py_LIMITED_API + #if CYTHON_LIMITED_API+0 > 0x03030000 + #define Py_LIMITED_API CYTHON_LIMITED_API + #else + #define Py_LIMITED_API 0x03030000 + #endif + #endif +#endif + +#include "Python.h" +#ifndef Py_PYTHON_H + #error Python headers needed to compile C extensions, please install development version of Python. +#elif PY_VERSION_HEX < 0x02070000 || (0x03000000 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x03030000) + #error Cython requires Python 2.7+ or Python 3.3+. +#else +#if defined(CYTHON_LIMITED_API) && CYTHON_LIMITED_API +#define __PYX_EXTRA_ABI_MODULE_NAME "limited" +#else +#define __PYX_EXTRA_ABI_MODULE_NAME "" +#endif +#define CYTHON_ABI "3_0_11" __PYX_EXTRA_ABI_MODULE_NAME +#define __PYX_ABI_MODULE_NAME "_cython_" CYTHON_ABI +#define __PYX_TYPE_MODULE_PREFIX __PYX_ABI_MODULE_NAME "." +#define CYTHON_HEX_VERSION 0x03000BF0 +#define CYTHON_FUTURE_DIVISION 1 +#include +#ifndef offsetof + #define offsetof(type, member) ( (size_t) & ((type*)0) -> member ) +#endif +#if !defined(_WIN32) && !defined(WIN32) && !defined(MS_WINDOWS) + #ifndef __stdcall + #define __stdcall + #endif + #ifndef __cdecl + #define __cdecl + #endif + #ifndef __fastcall + #define __fastcall + #endif +#endif +#ifndef DL_IMPORT + #define DL_IMPORT(t) t +#endif +#ifndef DL_EXPORT + #define DL_EXPORT(t) t +#endif +#define __PYX_COMMA , +#ifndef HAVE_LONG_LONG + #define HAVE_LONG_LONG +#endif +#ifndef PY_LONG_LONG + #define PY_LONG_LONG LONG_LONG +#endif +#ifndef Py_HUGE_VAL + #define Py_HUGE_VAL HUGE_VAL +#endif +#define __PYX_LIMITED_VERSION_HEX PY_VERSION_HEX +#if defined(GRAALVM_PYTHON) + /* For very preliminary testing purposes. Most variables are set the same as PyPy. + The existence of this section does not imply that anything works or is even tested */ + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 1 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #undef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #undef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 1 + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3) + #endif + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif + #undef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 +#elif defined(PYPY_VERSION) + #define CYTHON_COMPILING_IN_PYPY 1 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #undef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 1 + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS (PY_MAJOR_VERSION >= 3) + #endif + #if PY_VERSION_HEX < 0x03090000 + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT) + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE (PY_VERSION_HEX >= 0x030400a1 && PYPY_VERSION_NUM >= 0x07030C00) + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif + #undef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 +#elif defined(CYTHON_LIMITED_API) + #ifdef Py_LIMITED_API + #undef __PYX_LIMITED_VERSION_HEX + #define __PYX_LIMITED_VERSION_HEX Py_LIMITED_API + #endif + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 1 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #undef CYTHON_CLINE_IN_TRACEBACK + #define CYTHON_CLINE_IN_TRACEBACK 0 + #undef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 0 + #undef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 1 + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #undef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 0 + #ifndef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #endif + #undef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #undef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 0 + #undef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 0 + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #undef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 0 + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #undef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 1 + #ifndef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #endif + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 0 + #endif + #undef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 +#elif defined(Py_GIL_DISABLED) || defined(Py_NOGIL) + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 0 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 1 + #ifndef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 1 + #endif + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #undef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 0 + #ifndef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #ifndef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 0 + #endif + #undef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 0 + #ifndef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 1 + #endif + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #ifndef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 1 + #endif + #ifndef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 1 + #endif + #undef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 0 + #undef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL 0 + #ifndef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL 1 + #endif + #undef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 0 + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #ifndef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #ifndef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #endif + #ifndef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 1 + #endif + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 1 + #endif + #ifndef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 0 + #endif +#else + #define CYTHON_COMPILING_IN_PYPY 0 + #define CYTHON_COMPILING_IN_CPYTHON 1 + #define CYTHON_COMPILING_IN_LIMITED_API 0 + #define CYTHON_COMPILING_IN_GRAAL 0 + #define CYTHON_COMPILING_IN_NOGIL 0 + #ifndef CYTHON_USE_TYPE_SLOTS + #define CYTHON_USE_TYPE_SLOTS 1 + #endif + #ifndef CYTHON_USE_TYPE_SPECS + #define CYTHON_USE_TYPE_SPECS 0 + #endif + #ifndef CYTHON_USE_PYTYPE_LOOKUP + #define CYTHON_USE_PYTYPE_LOOKUP 1 + #endif + #if PY_MAJOR_VERSION < 3 + #undef CYTHON_USE_ASYNC_SLOTS + #define CYTHON_USE_ASYNC_SLOTS 0 + #elif !defined(CYTHON_USE_ASYNC_SLOTS) + #define CYTHON_USE_ASYNC_SLOTS 1 + #endif + #ifndef CYTHON_USE_PYLONG_INTERNALS + #define CYTHON_USE_PYLONG_INTERNALS 1 + #endif + #ifndef CYTHON_USE_PYLIST_INTERNALS + #define CYTHON_USE_PYLIST_INTERNALS 1 + #endif + #ifndef CYTHON_USE_UNICODE_INTERNALS + #define CYTHON_USE_UNICODE_INTERNALS 1 + #endif + #if PY_VERSION_HEX < 0x030300F0 || PY_VERSION_HEX >= 0x030B00A2 + #undef CYTHON_USE_UNICODE_WRITER + #define CYTHON_USE_UNICODE_WRITER 0 + #elif !defined(CYTHON_USE_UNICODE_WRITER) + #define CYTHON_USE_UNICODE_WRITER 1 + #endif + #ifndef CYTHON_AVOID_BORROWED_REFS + #define CYTHON_AVOID_BORROWED_REFS 0 + #endif + #ifndef CYTHON_ASSUME_SAFE_MACROS + #define CYTHON_ASSUME_SAFE_MACROS 1 + #endif + #ifndef CYTHON_UNPACK_METHODS + #define CYTHON_UNPACK_METHODS 1 + #endif + #ifndef CYTHON_FAST_THREAD_STATE + #define CYTHON_FAST_THREAD_STATE 1 + #endif + #ifndef CYTHON_FAST_GIL + #define CYTHON_FAST_GIL (PY_MAJOR_VERSION < 3 || PY_VERSION_HEX >= 0x03060000 && PY_VERSION_HEX < 0x030C00A6) + #endif + #ifndef CYTHON_METH_FASTCALL + #define CYTHON_METH_FASTCALL (PY_VERSION_HEX >= 0x030700A1) + #endif + #ifndef CYTHON_FAST_PYCALL + #define CYTHON_FAST_PYCALL 1 + #endif + #ifndef CYTHON_PEP487_INIT_SUBCLASS + #define CYTHON_PEP487_INIT_SUBCLASS 1 + #endif + #if PY_VERSION_HEX < 0x03050000 + #undef CYTHON_PEP489_MULTI_PHASE_INIT + #define CYTHON_PEP489_MULTI_PHASE_INIT 0 + #elif !defined(CYTHON_PEP489_MULTI_PHASE_INIT) + #define CYTHON_PEP489_MULTI_PHASE_INIT 1 + #endif + #ifndef CYTHON_USE_MODULE_STATE + #define CYTHON_USE_MODULE_STATE 0 + #endif + #if PY_VERSION_HEX < 0x030400a1 + #undef CYTHON_USE_TP_FINALIZE + #define CYTHON_USE_TP_FINALIZE 0 + #elif !defined(CYTHON_USE_TP_FINALIZE) + #define CYTHON_USE_TP_FINALIZE 1 + #endif + #if PY_VERSION_HEX < 0x030600B1 + #undef CYTHON_USE_DICT_VERSIONS + #define CYTHON_USE_DICT_VERSIONS 0 + #elif !defined(CYTHON_USE_DICT_VERSIONS) + #define CYTHON_USE_DICT_VERSIONS (PY_VERSION_HEX < 0x030C00A5) + #endif + #if PY_VERSION_HEX < 0x030700A3 + #undef CYTHON_USE_EXC_INFO_STACK + #define CYTHON_USE_EXC_INFO_STACK 0 + #elif !defined(CYTHON_USE_EXC_INFO_STACK) + #define CYTHON_USE_EXC_INFO_STACK 1 + #endif + #ifndef CYTHON_UPDATE_DESCRIPTOR_DOC + #define CYTHON_UPDATE_DESCRIPTOR_DOC 1 + #endif + #ifndef CYTHON_USE_FREELISTS + #define CYTHON_USE_FREELISTS 1 + #endif +#endif +#if !defined(CYTHON_FAST_PYCCALL) +#define CYTHON_FAST_PYCCALL (CYTHON_FAST_PYCALL && PY_VERSION_HEX >= 0x030600B1) +#endif +#if !defined(CYTHON_VECTORCALL) +#define CYTHON_VECTORCALL (CYTHON_FAST_PYCCALL && PY_VERSION_HEX >= 0x030800B1) +#endif +#define CYTHON_BACKPORT_VECTORCALL (CYTHON_METH_FASTCALL && PY_VERSION_HEX < 0x030800B1) +#if CYTHON_USE_PYLONG_INTERNALS + #if PY_MAJOR_VERSION < 3 + #include "longintrepr.h" + #endif + #undef SHIFT + #undef BASE + #undef MASK + #ifdef SIZEOF_VOID_P + enum { __pyx_check_sizeof_voidp = 1 / (int)(SIZEOF_VOID_P == sizeof(void*)) }; + #endif +#endif +#ifndef __has_attribute + #define __has_attribute(x) 0 +#endif +#ifndef __has_cpp_attribute + #define __has_cpp_attribute(x) 0 +#endif +#ifndef CYTHON_RESTRICT + #if defined(__GNUC__) + #define CYTHON_RESTRICT __restrict__ + #elif defined(_MSC_VER) && _MSC_VER >= 1400 + #define CYTHON_RESTRICT __restrict + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define CYTHON_RESTRICT restrict + #else + #define CYTHON_RESTRICT + #endif +#endif +#ifndef CYTHON_UNUSED + #if defined(__cplusplus) + /* for clang __has_cpp_attribute(maybe_unused) is true even before C++17 + * but leads to warnings with -pedantic, since it is a C++17 feature */ + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #if __has_cpp_attribute(maybe_unused) + #define CYTHON_UNUSED [[maybe_unused]] + #endif + #endif + #endif +#endif +#ifndef CYTHON_UNUSED +# if defined(__GNUC__) +# if !(defined(__cplusplus)) || (__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ >= 4)) +# define CYTHON_UNUSED __attribute__ ((__unused__)) +# else +# define CYTHON_UNUSED +# endif +# elif defined(__ICC) || (defined(__INTEL_COMPILER) && !defined(_MSC_VER)) +# define CYTHON_UNUSED __attribute__ ((__unused__)) +# else +# define CYTHON_UNUSED +# endif +#endif +#ifndef CYTHON_UNUSED_VAR +# if defined(__cplusplus) + template void CYTHON_UNUSED_VAR( const T& ) { } +# else +# define CYTHON_UNUSED_VAR(x) (void)(x) +# endif +#endif +#ifndef CYTHON_MAYBE_UNUSED_VAR + #define CYTHON_MAYBE_UNUSED_VAR(x) CYTHON_UNUSED_VAR(x) +#endif +#ifndef CYTHON_NCP_UNUSED +# if CYTHON_COMPILING_IN_CPYTHON +# define CYTHON_NCP_UNUSED +# else +# define CYTHON_NCP_UNUSED CYTHON_UNUSED +# endif +#endif +#ifndef CYTHON_USE_CPP_STD_MOVE + #if defined(__cplusplus) && (\ + __cplusplus >= 201103L || (defined(_MSC_VER) && _MSC_VER >= 1600)) + #define CYTHON_USE_CPP_STD_MOVE 1 + #else + #define CYTHON_USE_CPP_STD_MOVE 0 + #endif +#endif +#define __Pyx_void_to_None(void_result) ((void)(void_result), Py_INCREF(Py_None), Py_None) +#ifdef _MSC_VER + #ifndef _MSC_STDINT_H_ + #if _MSC_VER < 1300 + typedef unsigned char uint8_t; + typedef unsigned short uint16_t; + typedef unsigned int uint32_t; + #else + typedef unsigned __int8 uint8_t; + typedef unsigned __int16 uint16_t; + typedef unsigned __int32 uint32_t; + #endif + #endif + #if _MSC_VER < 1300 + #ifdef _WIN64 + typedef unsigned long long __pyx_uintptr_t; + #else + typedef unsigned int __pyx_uintptr_t; + #endif + #else + #ifdef _WIN64 + typedef unsigned __int64 __pyx_uintptr_t; + #else + typedef unsigned __int32 __pyx_uintptr_t; + #endif + #endif +#else + #include + typedef uintptr_t __pyx_uintptr_t; +#endif +#ifndef CYTHON_FALLTHROUGH + #if defined(__cplusplus) + /* for clang __has_cpp_attribute(fallthrough) is true even before C++17 + * but leads to warnings with -pedantic, since it is a C++17 feature */ + #if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) + #if __has_cpp_attribute(fallthrough) + #define CYTHON_FALLTHROUGH [[fallthrough]] + #endif + #endif + #ifndef CYTHON_FALLTHROUGH + #if __has_cpp_attribute(clang::fallthrough) + #define CYTHON_FALLTHROUGH [[clang::fallthrough]] + #elif __has_cpp_attribute(gnu::fallthrough) + #define CYTHON_FALLTHROUGH [[gnu::fallthrough]] + #endif + #endif + #endif + #ifndef CYTHON_FALLTHROUGH + #if __has_attribute(fallthrough) + #define CYTHON_FALLTHROUGH __attribute__((fallthrough)) + #else + #define CYTHON_FALLTHROUGH + #endif + #endif + #if defined(__clang__) && defined(__apple_build_version__) + #if __apple_build_version__ < 7000000 + #undef CYTHON_FALLTHROUGH + #define CYTHON_FALLTHROUGH + #endif + #endif +#endif +#ifdef __cplusplus + template + struct __PYX_IS_UNSIGNED_IMPL {static const bool value = T(0) < T(-1);}; + #define __PYX_IS_UNSIGNED(type) (__PYX_IS_UNSIGNED_IMPL::value) +#else + #define __PYX_IS_UNSIGNED(type) (((type)-1) > 0) +#endif +#if CYTHON_COMPILING_IN_PYPY == 1 + #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x030A0000) +#else + #define __PYX_NEED_TP_PRINT_SLOT (PY_VERSION_HEX >= 0x030800b4 && PY_VERSION_HEX < 0x03090000) +#endif +#define __PYX_REINTERPRET_FUNCION(func_pointer, other_pointer) ((func_pointer)(void(*)(void))(other_pointer)) + +#ifndef CYTHON_INLINE + #if defined(__clang__) + #define CYTHON_INLINE __inline__ __attribute__ ((__unused__)) + #elif defined(__GNUC__) + #define CYTHON_INLINE __inline__ + #elif defined(_MSC_VER) + #define CYTHON_INLINE __inline + #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define CYTHON_INLINE inline + #else + #define CYTHON_INLINE + #endif +#endif + +#define __PYX_BUILD_PY_SSIZE_T "n" +#define CYTHON_FORMAT_SSIZE_T "z" +#if PY_MAJOR_VERSION < 3 + #define __Pyx_BUILTIN_MODULE_NAME "__builtin__" + #define __Pyx_DefaultClassType PyClass_Type + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_New(a+k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#else + #define __Pyx_BUILTIN_MODULE_NAME "builtins" + #define __Pyx_DefaultClassType PyType_Type +#if CYTHON_COMPILING_IN_LIMITED_API + static CYTHON_INLINE PyObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f, + PyObject *code, PyObject *c, PyObject* n, PyObject *v, + PyObject *fv, PyObject *cell, PyObject* fn, + PyObject *name, int fline, PyObject *lnos) { + PyObject *exception_table = NULL; + PyObject *types_module=NULL, *code_type=NULL, *result=NULL; + #if __PYX_LIMITED_VERSION_HEX < 0x030B0000 + PyObject *version_info; + PyObject *py_minor_version = NULL; + #endif + long minor_version = 0; + PyObject *type, *value, *traceback; + PyErr_Fetch(&type, &value, &traceback); + #if __PYX_LIMITED_VERSION_HEX >= 0x030B0000 + minor_version = 11; + #else + if (!(version_info = PySys_GetObject("version_info"))) goto end; + if (!(py_minor_version = PySequence_GetItem(version_info, 1))) goto end; + minor_version = PyLong_AsLong(py_minor_version); + Py_DECREF(py_minor_version); + if (minor_version == -1 && PyErr_Occurred()) goto end; + #endif + if (!(types_module = PyImport_ImportModule("types"))) goto end; + if (!(code_type = PyObject_GetAttrString(types_module, "CodeType"))) goto end; + if (minor_version <= 7) { + (void)p; + result = PyObject_CallFunction(code_type, "iiiiiOOOOOOiOO", a, k, l, s, f, code, + c, n, v, fn, name, fline, lnos, fv, cell); + } else if (minor_version <= 10) { + result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOiOO", a,p, k, l, s, f, code, + c, n, v, fn, name, fline, lnos, fv, cell); + } else { + if (!(exception_table = PyBytes_FromStringAndSize(NULL, 0))) goto end; + result = PyObject_CallFunction(code_type, "iiiiiiOOOOOOOiOO", a,p, k, l, s, f, code, + c, n, v, fn, name, name, fline, lnos, exception_table, fv, cell); + } + end: + Py_XDECREF(code_type); + Py_XDECREF(exception_table); + Py_XDECREF(types_module); + if (type) { + PyErr_Restore(type, value, traceback); + } + return result; + } + #ifndef CO_OPTIMIZED + #define CO_OPTIMIZED 0x0001 + #endif + #ifndef CO_NEWLOCALS + #define CO_NEWLOCALS 0x0002 + #endif + #ifndef CO_VARARGS + #define CO_VARARGS 0x0004 + #endif + #ifndef CO_VARKEYWORDS + #define CO_VARKEYWORDS 0x0008 + #endif + #ifndef CO_ASYNC_GENERATOR + #define CO_ASYNC_GENERATOR 0x0200 + #endif + #ifndef CO_GENERATOR + #define CO_GENERATOR 0x0020 + #endif + #ifndef CO_COROUTINE + #define CO_COROUTINE 0x0080 + #endif +#elif PY_VERSION_HEX >= 0x030B0000 + static CYTHON_INLINE PyCodeObject* __Pyx_PyCode_New(int a, int p, int k, int l, int s, int f, + PyObject *code, PyObject *c, PyObject* n, PyObject *v, + PyObject *fv, PyObject *cell, PyObject* fn, + PyObject *name, int fline, PyObject *lnos) { + PyCodeObject *result; + PyObject *empty_bytes = PyBytes_FromStringAndSize("", 0); + if (!empty_bytes) return NULL; + result = + #if PY_VERSION_HEX >= 0x030C0000 + PyUnstable_Code_NewWithPosOnlyArgs + #else + PyCode_NewWithPosOnlyArgs + #endif + (a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, name, fline, lnos, empty_bytes); + Py_DECREF(empty_bytes); + return result; + } +#elif PY_VERSION_HEX >= 0x030800B2 && !CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_NewWithPosOnlyArgs(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#else + #define __Pyx_PyCode_New(a, p, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos)\ + PyCode_New(a, k, l, s, f, code, c, n, v, fv, cell, fn, name, fline, lnos) +#endif +#endif +#if PY_VERSION_HEX >= 0x030900A4 || defined(Py_IS_TYPE) + #define __Pyx_IS_TYPE(ob, type) Py_IS_TYPE(ob, type) +#else + #define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type)) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_Is) + #define __Pyx_Py_Is(x, y) Py_Is(x, y) +#else + #define __Pyx_Py_Is(x, y) ((x) == (y)) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsNone) + #define __Pyx_Py_IsNone(ob) Py_IsNone(ob) +#else + #define __Pyx_Py_IsNone(ob) __Pyx_Py_Is((ob), Py_None) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsTrue) + #define __Pyx_Py_IsTrue(ob) Py_IsTrue(ob) +#else + #define __Pyx_Py_IsTrue(ob) __Pyx_Py_Is((ob), Py_True) +#endif +#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsFalse) + #define __Pyx_Py_IsFalse(ob) Py_IsFalse(ob) +#else + #define __Pyx_Py_IsFalse(ob) __Pyx_Py_Is((ob), Py_False) +#endif +#define __Pyx_NoneAsNull(obj) (__Pyx_Py_IsNone(obj) ? NULL : (obj)) +#if PY_VERSION_HEX >= 0x030900F0 && !CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyObject_GC_IsFinalized(o) PyObject_GC_IsFinalized(o) +#else + #define __Pyx_PyObject_GC_IsFinalized(o) _PyGC_FINALIZED(o) +#endif +#ifndef CO_COROUTINE + #define CO_COROUTINE 0x80 +#endif +#ifndef CO_ASYNC_GENERATOR + #define CO_ASYNC_GENERATOR 0x200 +#endif +#ifndef Py_TPFLAGS_CHECKTYPES + #define Py_TPFLAGS_CHECKTYPES 0 +#endif +#ifndef Py_TPFLAGS_HAVE_INDEX + #define Py_TPFLAGS_HAVE_INDEX 0 +#endif +#ifndef Py_TPFLAGS_HAVE_NEWBUFFER + #define Py_TPFLAGS_HAVE_NEWBUFFER 0 +#endif +#ifndef Py_TPFLAGS_HAVE_FINALIZE + #define Py_TPFLAGS_HAVE_FINALIZE 0 +#endif +#ifndef Py_TPFLAGS_SEQUENCE + #define Py_TPFLAGS_SEQUENCE 0 +#endif +#ifndef Py_TPFLAGS_MAPPING + #define Py_TPFLAGS_MAPPING 0 +#endif +#ifndef METH_STACKLESS + #define METH_STACKLESS 0 +#endif +#if PY_VERSION_HEX <= 0x030700A3 || !defined(METH_FASTCALL) + #ifndef METH_FASTCALL + #define METH_FASTCALL 0x80 + #endif + typedef PyObject *(*__Pyx_PyCFunctionFast) (PyObject *self, PyObject *const *args, Py_ssize_t nargs); + typedef PyObject *(*__Pyx_PyCFunctionFastWithKeywords) (PyObject *self, PyObject *const *args, + Py_ssize_t nargs, PyObject *kwnames); +#else + #if PY_VERSION_HEX >= 0x030d00A4 + # define __Pyx_PyCFunctionFast PyCFunctionFast + # define __Pyx_PyCFunctionFastWithKeywords PyCFunctionFastWithKeywords + #else + # define __Pyx_PyCFunctionFast _PyCFunctionFast + # define __Pyx_PyCFunctionFastWithKeywords _PyCFunctionFastWithKeywords + #endif +#endif +#if CYTHON_METH_FASTCALL + #define __Pyx_METH_FASTCALL METH_FASTCALL + #define __Pyx_PyCFunction_FastCall __Pyx_PyCFunctionFast + #define __Pyx_PyCFunction_FastCallWithKeywords __Pyx_PyCFunctionFastWithKeywords +#else + #define __Pyx_METH_FASTCALL METH_VARARGS + #define __Pyx_PyCFunction_FastCall PyCFunction + #define __Pyx_PyCFunction_FastCallWithKeywords PyCFunctionWithKeywords +#endif +#if CYTHON_VECTORCALL + #define __pyx_vectorcallfunc vectorcallfunc + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET PY_VECTORCALL_ARGUMENTS_OFFSET + #define __Pyx_PyVectorcall_NARGS(n) PyVectorcall_NARGS((size_t)(n)) +#elif CYTHON_BACKPORT_VECTORCALL + typedef PyObject *(*__pyx_vectorcallfunc)(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames); + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET ((size_t)1 << (8 * sizeof(size_t) - 1)) + #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(((size_t)(n)) & ~__Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET)) +#else + #define __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET 0 + #define __Pyx_PyVectorcall_NARGS(n) ((Py_ssize_t)(n)) +#endif +#if PY_MAJOR_VERSION >= 0x030900B1 +#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_CheckExact(func) +#else +#define __Pyx_PyCFunction_CheckExact(func) PyCFunction_Check(func) +#endif +#define __Pyx_CyOrPyCFunction_Check(func) PyCFunction_Check(func) +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) (((PyCFunctionObject*)(func))->m_ml->ml_meth) +#elif !CYTHON_COMPILING_IN_LIMITED_API +#define __Pyx_CyOrPyCFunction_GET_FUNCTION(func) PyCFunction_GET_FUNCTION(func) +#endif +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_CyOrPyCFunction_GET_FLAGS(func) (((PyCFunctionObject*)(func))->m_ml->ml_flags) +static CYTHON_INLINE PyObject* __Pyx_CyOrPyCFunction_GET_SELF(PyObject *func) { + return (__Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_STATIC) ? NULL : ((PyCFunctionObject*)func)->m_self; +} +#endif +static CYTHON_INLINE int __Pyx__IsSameCFunction(PyObject *func, void *cfunc) { +#if CYTHON_COMPILING_IN_LIMITED_API + return PyCFunction_Check(func) && PyCFunction_GetFunction(func) == (PyCFunction) cfunc; +#else + return PyCFunction_Check(func) && PyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc; +#endif +} +#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCFunction(func, cfunc) +#if __PYX_LIMITED_VERSION_HEX < 0x030900B1 + #define __Pyx_PyType_FromModuleAndSpec(m, s, b) ((void)m, PyType_FromSpecWithBases(s, b)) + typedef PyObject *(*__Pyx_PyCMethod)(PyObject *, PyTypeObject *, PyObject *const *, size_t, PyObject *); +#else + #define __Pyx_PyType_FromModuleAndSpec(m, s, b) PyType_FromModuleAndSpec(m, s, b) + #define __Pyx_PyCMethod PyCMethod +#endif +#ifndef METH_METHOD + #define METH_METHOD 0x200 +#endif +#if CYTHON_COMPILING_IN_PYPY && !defined(PyObject_Malloc) + #define PyObject_Malloc(s) PyMem_Malloc(s) + #define PyObject_Free(p) PyMem_Free(p) + #define PyObject_Realloc(p) PyMem_Realloc(p) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) + #define __Pyx_PyFrame_SetLineNumber(frame, lineno) +#else + #define __Pyx_PyCode_HasFreeVars(co) (PyCode_GetNumFree(co) > 0) + #define __Pyx_PyFrame_SetLineNumber(frame, lineno) (frame)->f_lineno = (lineno) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyThreadState_Current PyThreadState_Get() +#elif !CYTHON_FAST_THREAD_STATE + #define __Pyx_PyThreadState_Current PyThreadState_GET() +#elif PY_VERSION_HEX >= 0x030d00A1 + #define __Pyx_PyThreadState_Current PyThreadState_GetUnchecked() +#elif PY_VERSION_HEX >= 0x03060000 + #define __Pyx_PyThreadState_Current _PyThreadState_UncheckedGet() +#elif PY_VERSION_HEX >= 0x03000000 + #define __Pyx_PyThreadState_Current PyThreadState_GET() +#else + #define __Pyx_PyThreadState_Current _PyThreadState_Current +#endif +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE void *__Pyx_PyModule_GetState(PyObject *op) +{ + void *result; + result = PyModule_GetState(op); + if (!result) + Py_FatalError("Couldn't find the module state"); + return result; +} +#endif +#define __Pyx_PyObject_GetSlot(obj, name, func_ctype) __Pyx_PyType_GetSlot(Py_TYPE(obj), name, func_ctype) +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((func_ctype) PyType_GetSlot((type), Py_##name)) +#else + #define __Pyx_PyType_GetSlot(type, name, func_ctype) ((type)->name) +#endif +#if PY_VERSION_HEX < 0x030700A2 && !defined(PyThread_tss_create) && !defined(Py_tss_NEEDS_INIT) +#include "pythread.h" +#define Py_tss_NEEDS_INIT 0 +typedef int Py_tss_t; +static CYTHON_INLINE int PyThread_tss_create(Py_tss_t *key) { + *key = PyThread_create_key(); + return 0; +} +static CYTHON_INLINE Py_tss_t * PyThread_tss_alloc(void) { + Py_tss_t *key = (Py_tss_t *)PyObject_Malloc(sizeof(Py_tss_t)); + *key = Py_tss_NEEDS_INIT; + return key; +} +static CYTHON_INLINE void PyThread_tss_free(Py_tss_t *key) { + PyObject_Free(key); +} +static CYTHON_INLINE int PyThread_tss_is_created(Py_tss_t *key) { + return *key != Py_tss_NEEDS_INIT; +} +static CYTHON_INLINE void PyThread_tss_delete(Py_tss_t *key) { + PyThread_delete_key(*key); + *key = Py_tss_NEEDS_INIT; +} +static CYTHON_INLINE int PyThread_tss_set(Py_tss_t *key, void *value) { + return PyThread_set_key_value(*key, value); +} +static CYTHON_INLINE void * PyThread_tss_get(Py_tss_t *key) { + return PyThread_get_key_value(*key); +} +#endif +#if PY_MAJOR_VERSION < 3 + #if CYTHON_COMPILING_IN_PYPY + #if PYPY_VERSION_NUM < 0x07030600 + #if defined(__cplusplus) && __cplusplus >= 201402L + [[deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")]] + #elif defined(__GNUC__) || defined(__clang__) + __attribute__ ((__deprecated__("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6"))) + #elif defined(_MSC_VER) + __declspec(deprecated("`with nogil:` inside a nogil function will not release the GIL in PyPy2 < 7.3.6")) + #endif + static CYTHON_INLINE int PyGILState_Check(void) { + return 0; + } + #else // PYPY_VERSION_NUM < 0x07030600 + #endif // PYPY_VERSION_NUM < 0x07030600 + #else + static CYTHON_INLINE int PyGILState_Check(void) { + PyThreadState * tstate = _PyThreadState_Current; + return tstate && (tstate == PyGILState_GetThisThreadState()); + } + #endif +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030d0000 || defined(_PyDict_NewPresized) +#define __Pyx_PyDict_NewPresized(n) ((n <= 8) ? PyDict_New() : _PyDict_NewPresized(n)) +#else +#define __Pyx_PyDict_NewPresized(n) PyDict_New() +#endif +#if PY_MAJOR_VERSION >= 3 || CYTHON_FUTURE_DIVISION + #define __Pyx_PyNumber_Divide(x,y) PyNumber_TrueDivide(x,y) + #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceTrueDivide(x,y) +#else + #define __Pyx_PyNumber_Divide(x,y) PyNumber_Divide(x,y) + #define __Pyx_PyNumber_InPlaceDivide(x,y) PyNumber_InPlaceDivide(x,y) +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX > 0x030600B4 && PY_VERSION_HEX < 0x030d0000 && CYTHON_USE_UNICODE_INTERNALS +#define __Pyx_PyDict_GetItemStrWithError(dict, name) _PyDict_GetItem_KnownHash(dict, name, ((PyASCIIObject *) name)->hash) +static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStr(PyObject *dict, PyObject *name) { + PyObject *res = __Pyx_PyDict_GetItemStrWithError(dict, name); + if (res == NULL) PyErr_Clear(); + return res; +} +#elif PY_MAJOR_VERSION >= 3 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07020000) +#define __Pyx_PyDict_GetItemStrWithError PyDict_GetItemWithError +#define __Pyx_PyDict_GetItemStr PyDict_GetItem +#else +static CYTHON_INLINE PyObject * __Pyx_PyDict_GetItemStrWithError(PyObject *dict, PyObject *name) { +#if CYTHON_COMPILING_IN_PYPY + return PyDict_GetItem(dict, name); +#else + PyDictEntry *ep; + PyDictObject *mp = (PyDictObject*) dict; + long hash = ((PyStringObject *) name)->ob_shash; + assert(hash != -1); + ep = (mp->ma_lookup)(mp, name, hash); + if (ep == NULL) { + return NULL; + } + return ep->me_value; +#endif +} +#define __Pyx_PyDict_GetItemStr PyDict_GetItem +#endif +#if CYTHON_USE_TYPE_SLOTS + #define __Pyx_PyType_GetFlags(tp) (((PyTypeObject *)tp)->tp_flags) + #define __Pyx_PyType_HasFeature(type, feature) ((__Pyx_PyType_GetFlags(type) & (feature)) != 0) + #define __Pyx_PyObject_GetIterNextFunc(obj) (Py_TYPE(obj)->tp_iternext) +#else + #define __Pyx_PyType_GetFlags(tp) (PyType_GetFlags((PyTypeObject *)tp)) + #define __Pyx_PyType_HasFeature(type, feature) PyType_HasFeature(type, feature) + #define __Pyx_PyObject_GetIterNextFunc(obj) PyIter_Next +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_SetItemOnTypeDict(tp, k, v) PyObject_GenericSetAttr((PyObject*)tp, k, v) +#else + #define __Pyx_SetItemOnTypeDict(tp, k, v) PyDict_SetItem(tp->tp_dict, k, v) +#endif +#if CYTHON_USE_TYPE_SPECS && PY_VERSION_HEX >= 0x03080000 +#define __Pyx_PyHeapTypeObject_GC_Del(obj) {\ + PyTypeObject *type = Py_TYPE((PyObject*)obj);\ + assert(__Pyx_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE));\ + PyObject_GC_Del(obj);\ + Py_DECREF(type);\ +} +#else +#define __Pyx_PyHeapTypeObject_GC_Del(obj) PyObject_GC_Del(obj) +#endif +#if CYTHON_COMPILING_IN_LIMITED_API + #define CYTHON_PEP393_ENABLED 1 + #define __Pyx_PyUnicode_READY(op) (0) + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GetLength(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_ReadChar(u, i) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((void)u, 1114111U) + #define __Pyx_PyUnicode_KIND(u) ((void)u, (0)) + #define __Pyx_PyUnicode_DATA(u) ((void*)u) + #define __Pyx_PyUnicode_READ(k, d, i) ((void)k, PyUnicode_ReadChar((PyObject*)(d), i)) + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GetLength(u)) +#elif PY_VERSION_HEX > 0x03030000 && defined(PyUnicode_KIND) + #define CYTHON_PEP393_ENABLED 1 + #if PY_VERSION_HEX >= 0x030C0000 + #define __Pyx_PyUnicode_READY(op) (0) + #else + #define __Pyx_PyUnicode_READY(op) (likely(PyUnicode_IS_READY(op)) ?\ + 0 : _PyUnicode_Ready((PyObject *)(op))) + #endif + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_LENGTH(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) PyUnicode_MAX_CHAR_VALUE(u) + #define __Pyx_PyUnicode_KIND(u) ((int)PyUnicode_KIND(u)) + #define __Pyx_PyUnicode_DATA(u) PyUnicode_DATA(u) + #define __Pyx_PyUnicode_READ(k, d, i) PyUnicode_READ(k, d, i) + #define __Pyx_PyUnicode_WRITE(k, d, i, ch) PyUnicode_WRITE(k, d, i, (Py_UCS4) ch) + #if PY_VERSION_HEX >= 0x030C0000 + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_LENGTH(u)) + #else + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x03090000 + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : ((PyCompactUnicodeObject *)(u))->wstr_length)) + #else + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != (likely(PyUnicode_IS_READY(u)) ? PyUnicode_GET_LENGTH(u) : PyUnicode_GET_SIZE(u))) + #endif + #endif +#else + #define CYTHON_PEP393_ENABLED 0 + #define PyUnicode_1BYTE_KIND 1 + #define PyUnicode_2BYTE_KIND 2 + #define PyUnicode_4BYTE_KIND 4 + #define __Pyx_PyUnicode_READY(op) (0) + #define __Pyx_PyUnicode_GET_LENGTH(u) PyUnicode_GET_SIZE(u) + #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i])) + #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u) ((sizeof(Py_UNICODE) == 2) ? 65535U : 1114111U) + #define __Pyx_PyUnicode_KIND(u) ((int)sizeof(Py_UNICODE)) + #define __Pyx_PyUnicode_DATA(u) ((void*)PyUnicode_AS_UNICODE(u)) + #define __Pyx_PyUnicode_READ(k, d, i) ((void)(k), (Py_UCS4)(((Py_UNICODE*)d)[i])) + #define __Pyx_PyUnicode_WRITE(k, d, i, ch) (((void)(k)), ((Py_UNICODE*)d)[i] = (Py_UNICODE) ch) + #define __Pyx_PyUnicode_IS_TRUE(u) (0 != PyUnicode_GET_SIZE(u)) +#endif +#if CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyUnicode_Concat(a, b) PyNumber_Add(a, b) + #define __Pyx_PyUnicode_ConcatSafe(a, b) PyNumber_Add(a, b) +#else + #define __Pyx_PyUnicode_Concat(a, b) PyUnicode_Concat(a, b) + #define __Pyx_PyUnicode_ConcatSafe(a, b) ((unlikely((a) == Py_None) || unlikely((b) == Py_None)) ?\ + PyNumber_Add(a, b) : __Pyx_PyUnicode_Concat(a, b)) +#endif +#if CYTHON_COMPILING_IN_PYPY + #if !defined(PyUnicode_DecodeUnicodeEscape) + #define PyUnicode_DecodeUnicodeEscape(s, size, errors) PyUnicode_Decode(s, size, "unicode_escape", errors) + #endif + #if !defined(PyUnicode_Contains) || (PY_MAJOR_VERSION == 2 && PYPY_VERSION_NUM < 0x07030500) + #undef PyUnicode_Contains + #define PyUnicode_Contains(u, s) PySequence_Contains(u, s) + #endif + #if !defined(PyByteArray_Check) + #define PyByteArray_Check(obj) PyObject_TypeCheck(obj, &PyByteArray_Type) + #endif + #if !defined(PyObject_Format) + #define PyObject_Format(obj, fmt) PyObject_CallMethod(obj, "__format__", "O", fmt) + #endif +#endif +#define __Pyx_PyString_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyString_Check(b) && !PyString_CheckExact(b)))) ? PyNumber_Remainder(a, b) : __Pyx_PyString_Format(a, b)) +#define __Pyx_PyUnicode_FormatSafe(a, b) ((unlikely((a) == Py_None || (PyUnicode_Check(b) && !PyUnicode_CheckExact(b)))) ? PyNumber_Remainder(a, b) : PyUnicode_Format(a, b)) +#if PY_MAJOR_VERSION >= 3 + #define __Pyx_PyString_Format(a, b) PyUnicode_Format(a, b) +#else + #define __Pyx_PyString_Format(a, b) PyString_Format(a, b) +#endif +#if PY_MAJOR_VERSION < 3 && !defined(PyObject_ASCII) + #define PyObject_ASCII(o) PyObject_Repr(o) +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyBaseString_Type PyUnicode_Type + #define PyStringObject PyUnicodeObject + #define PyString_Type PyUnicode_Type + #define PyString_Check PyUnicode_Check + #define PyString_CheckExact PyUnicode_CheckExact +#ifndef PyObject_Unicode + #define PyObject_Unicode PyObject_Str +#endif +#endif +#if PY_MAJOR_VERSION >= 3 + #define __Pyx_PyBaseString_Check(obj) PyUnicode_Check(obj) + #define __Pyx_PyBaseString_CheckExact(obj) PyUnicode_CheckExact(obj) +#else + #define __Pyx_PyBaseString_Check(obj) (PyString_Check(obj) || PyUnicode_Check(obj)) + #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj)) +#endif +#if CYTHON_COMPILING_IN_CPYTHON + #define __Pyx_PySequence_ListKeepNew(obj)\ + (likely(PyList_CheckExact(obj) && Py_REFCNT(obj) == 1) ? __Pyx_NewRef(obj) : PySequence_List(obj)) +#else + #define __Pyx_PySequence_ListKeepNew(obj) PySequence_List(obj) +#endif +#ifndef PySet_CheckExact + #define PySet_CheckExact(obj) __Pyx_IS_TYPE(obj, &PySet_Type) +#endif +#if PY_VERSION_HEX >= 0x030900A4 + #define __Pyx_SET_REFCNT(obj, refcnt) Py_SET_REFCNT(obj, refcnt) + #define __Pyx_SET_SIZE(obj, size) Py_SET_SIZE(obj, size) +#else + #define __Pyx_SET_REFCNT(obj, refcnt) Py_REFCNT(obj) = (refcnt) + #define __Pyx_SET_SIZE(obj, size) Py_SIZE(obj) = (size) +#endif +#if CYTHON_ASSUME_SAFE_MACROS + #define __Pyx_PySequence_ITEM(o, i) PySequence_ITEM(o, i) + #define __Pyx_PySequence_SIZE(seq) Py_SIZE(seq) + #define __Pyx_PyTuple_SET_ITEM(o, i, v) (PyTuple_SET_ITEM(o, i, v), (0)) + #define __Pyx_PyList_SET_ITEM(o, i, v) (PyList_SET_ITEM(o, i, v), (0)) + #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_GET_SIZE(o) + #define __Pyx_PyList_GET_SIZE(o) PyList_GET_SIZE(o) + #define __Pyx_PySet_GET_SIZE(o) PySet_GET_SIZE(o) + #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_GET_SIZE(o) + #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_GET_SIZE(o) +#else + #define __Pyx_PySequence_ITEM(o, i) PySequence_GetItem(o, i) + #define __Pyx_PySequence_SIZE(seq) PySequence_Size(seq) + #define __Pyx_PyTuple_SET_ITEM(o, i, v) PyTuple_SetItem(o, i, v) + #define __Pyx_PyList_SET_ITEM(o, i, v) PyList_SetItem(o, i, v) + #define __Pyx_PyTuple_GET_SIZE(o) PyTuple_Size(o) + #define __Pyx_PyList_GET_SIZE(o) PyList_Size(o) + #define __Pyx_PySet_GET_SIZE(o) PySet_Size(o) + #define __Pyx_PyBytes_GET_SIZE(o) PyBytes_Size(o) + #define __Pyx_PyByteArray_GET_SIZE(o) PyByteArray_Size(o) +#endif +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + #define __Pyx_PyImport_AddModuleRef(name) PyImport_AddModuleRef(name) +#else + static CYTHON_INLINE PyObject *__Pyx_PyImport_AddModuleRef(const char *name) { + PyObject *module = PyImport_AddModule(name); + Py_XINCREF(module); + return module; + } +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyIntObject PyLongObject + #define PyInt_Type PyLong_Type + #define PyInt_Check(op) PyLong_Check(op) + #define PyInt_CheckExact(op) PyLong_CheckExact(op) + #define __Pyx_Py3Int_Check(op) PyLong_Check(op) + #define __Pyx_Py3Int_CheckExact(op) PyLong_CheckExact(op) + #define PyInt_FromString PyLong_FromString + #define PyInt_FromUnicode PyLong_FromUnicode + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t + #define PyInt_FromSsize_t PyLong_FromSsize_t + #define PyInt_AsLong PyLong_AsLong + #define PyInt_AS_LONG PyLong_AS_LONG + #define PyInt_AsSsize_t PyLong_AsSsize_t + #define PyInt_AsUnsignedLongMask PyLong_AsUnsignedLongMask + #define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask + #define PyNumber_Int PyNumber_Long +#else + #define __Pyx_Py3Int_Check(op) (PyLong_Check(op) || PyInt_Check(op)) + #define __Pyx_Py3Int_CheckExact(op) (PyLong_CheckExact(op) || PyInt_CheckExact(op)) +#endif +#if PY_MAJOR_VERSION >= 3 + #define PyBoolObject PyLongObject +#endif +#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_PYPY + #ifndef PyUnicode_InternFromString + #define PyUnicode_InternFromString(s) PyUnicode_FromString(s) + #endif +#endif +#if PY_VERSION_HEX < 0x030200A4 + typedef long Py_hash_t; + #define __Pyx_PyInt_FromHash_t PyInt_FromLong + #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsHash_t +#else + #define __Pyx_PyInt_FromHash_t PyInt_FromSsize_t + #define __Pyx_PyInt_AsHash_t __Pyx_PyIndex_AsSsize_t +#endif +#if CYTHON_USE_ASYNC_SLOTS + #if PY_VERSION_HEX >= 0x030500B1 + #define __Pyx_PyAsyncMethodsStruct PyAsyncMethods + #define __Pyx_PyType_AsAsync(obj) (Py_TYPE(obj)->tp_as_async) + #else + #define __Pyx_PyType_AsAsync(obj) ((__Pyx_PyAsyncMethodsStruct*) (Py_TYPE(obj)->tp_reserved)) + #endif +#else + #define __Pyx_PyType_AsAsync(obj) NULL +#endif +#ifndef __Pyx_PyAsyncMethodsStruct + typedef struct { + unaryfunc am_await; + unaryfunc am_aiter; + unaryfunc am_anext; + } __Pyx_PyAsyncMethodsStruct; +#endif + +#if defined(_WIN32) || defined(WIN32) || defined(MS_WINDOWS) + #if !defined(_USE_MATH_DEFINES) + #define _USE_MATH_DEFINES + #endif +#endif +#include +#ifdef NAN +#define __PYX_NAN() ((float) NAN) +#else +static CYTHON_INLINE float __PYX_NAN() { + float value; + memset(&value, 0xFF, sizeof(value)); + return value; +} +#endif +#if defined(__CYGWIN__) && defined(_LDBL_EQ_DBL) +#define __Pyx_truncl trunc +#else +#define __Pyx_truncl truncl +#endif + +#define __PYX_MARK_ERR_POS(f_index, lineno) \ + { __pyx_filename = __pyx_f[f_index]; (void)__pyx_filename; __pyx_lineno = lineno; (void)__pyx_lineno; __pyx_clineno = __LINE__; (void)__pyx_clineno; } +#define __PYX_ERR(f_index, lineno, Ln_error) \ + { __PYX_MARK_ERR_POS(f_index, lineno) goto Ln_error; } + +#ifdef CYTHON_EXTERN_C + #undef __PYX_EXTERN_C + #define __PYX_EXTERN_C CYTHON_EXTERN_C +#elif defined(__PYX_EXTERN_C) + #ifdef _MSC_VER + #pragma message ("Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead.") + #else + #warning Please do not define the '__PYX_EXTERN_C' macro externally. Use 'CYTHON_EXTERN_C' instead. + #endif +#else + #ifdef __cplusplus + #define __PYX_EXTERN_C extern "C" + #else + #define __PYX_EXTERN_C extern + #endif +#endif + +#define __PYX_HAVE__TTS__tts__utils__monotonic_align__core +#define __PYX_HAVE_API__TTS__tts__utils__monotonic_align__core +/* Early includes */ +#include +#include + + /* Using NumPy API declarations from "numpy/__init__.cython-30.pxd" */ + +#include "numpy/arrayobject.h" +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/arrayscalars.h" +#include "numpy/ufuncobject.h" +#include "pythread.h" +#include +#ifdef _OPENMP +#include +#endif /* _OPENMP */ + +#if defined(PYREX_WITHOUT_ASSERTIONS) && !defined(CYTHON_WITHOUT_ASSERTIONS) +#define CYTHON_WITHOUT_ASSERTIONS +#endif + +typedef struct {PyObject **p; const char *s; const Py_ssize_t n; const char* encoding; + const char is_unicode; const char is_str; const char intern; } __Pyx_StringTabEntry; + +#define __PYX_DEFAULT_STRING_ENCODING_IS_ASCII 0 +#define __PYX_DEFAULT_STRING_ENCODING_IS_UTF8 0 +#define __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT (PY_MAJOR_VERSION >= 3 && __PYX_DEFAULT_STRING_ENCODING_IS_UTF8) +#define __PYX_DEFAULT_STRING_ENCODING "" +#define __Pyx_PyObject_FromString __Pyx_PyBytes_FromString +#define __Pyx_PyObject_FromStringAndSize __Pyx_PyBytes_FromStringAndSize +#define __Pyx_uchar_cast(c) ((unsigned char)c) +#define __Pyx_long_cast(x) ((long)x) +#define __Pyx_fits_Py_ssize_t(v, type, is_signed) (\ + (sizeof(type) < sizeof(Py_ssize_t)) ||\ + (sizeof(type) > sizeof(Py_ssize_t) &&\ + likely(v < (type)PY_SSIZE_T_MAX ||\ + v == (type)PY_SSIZE_T_MAX) &&\ + (!is_signed || likely(v > (type)PY_SSIZE_T_MIN ||\ + v == (type)PY_SSIZE_T_MIN))) ||\ + (sizeof(type) == sizeof(Py_ssize_t) &&\ + (is_signed || likely(v < (type)PY_SSIZE_T_MAX ||\ + v == (type)PY_SSIZE_T_MAX))) ) +static CYTHON_INLINE int __Pyx_is_valid_index(Py_ssize_t i, Py_ssize_t limit) { + return (size_t) i < (size_t) limit; +} +#if defined (__cplusplus) && __cplusplus >= 201103L + #include + #define __Pyx_sst_abs(value) std::abs(value) +#elif SIZEOF_INT >= SIZEOF_SIZE_T + #define __Pyx_sst_abs(value) abs(value) +#elif SIZEOF_LONG >= SIZEOF_SIZE_T + #define __Pyx_sst_abs(value) labs(value) +#elif defined (_MSC_VER) + #define __Pyx_sst_abs(value) ((Py_ssize_t)_abs64(value)) +#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L + #define __Pyx_sst_abs(value) llabs(value) +#elif defined (__GNUC__) + #define __Pyx_sst_abs(value) __builtin_llabs(value) +#else + #define __Pyx_sst_abs(value) ((value<0) ? -value : value) +#endif +static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s); +static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject*); +static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject*, Py_ssize_t* length); +static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char*); +#define __Pyx_PyByteArray_FromStringAndSize(s, l) PyByteArray_FromStringAndSize((const char*)s, l) +#define __Pyx_PyBytes_FromString PyBytes_FromString +#define __Pyx_PyBytes_FromStringAndSize PyBytes_FromStringAndSize +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char*); +#if PY_MAJOR_VERSION < 3 + #define __Pyx_PyStr_FromString __Pyx_PyBytes_FromString + #define __Pyx_PyStr_FromStringAndSize __Pyx_PyBytes_FromStringAndSize +#else + #define __Pyx_PyStr_FromString __Pyx_PyUnicode_FromString + #define __Pyx_PyStr_FromStringAndSize __Pyx_PyUnicode_FromStringAndSize +#endif +#define __Pyx_PyBytes_AsWritableString(s) ((char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsWritableSString(s) ((signed char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsWritableUString(s) ((unsigned char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsString(s) ((const char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsSString(s) ((const signed char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyBytes_AsUString(s) ((const unsigned char*) PyBytes_AS_STRING(s)) +#define __Pyx_PyObject_AsWritableString(s) ((char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsWritableSString(s) ((signed char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsWritableUString(s) ((unsigned char*)(__pyx_uintptr_t) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsSString(s) ((const signed char*) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_AsUString(s) ((const unsigned char*) __Pyx_PyObject_AsString(s)) +#define __Pyx_PyObject_FromCString(s) __Pyx_PyObject_FromString((const char*)s) +#define __Pyx_PyBytes_FromCString(s) __Pyx_PyBytes_FromString((const char*)s) +#define __Pyx_PyByteArray_FromCString(s) __Pyx_PyByteArray_FromString((const char*)s) +#define __Pyx_PyStr_FromCString(s) __Pyx_PyStr_FromString((const char*)s) +#define __Pyx_PyUnicode_FromCString(s) __Pyx_PyUnicode_FromString((const char*)s) +#define __Pyx_PyUnicode_FromOrdinal(o) PyUnicode_FromOrdinal((int)o) +#define __Pyx_PyUnicode_AsUnicode PyUnicode_AsUnicode +#define __Pyx_NewRef(obj) (Py_INCREF(obj), obj) +#define __Pyx_Owned_Py_None(b) __Pyx_NewRef(Py_None) +static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b); +static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject*); +static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject*); +static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x); +#define __Pyx_PySequence_Tuple(obj)\ + (likely(PyTuple_CheckExact(obj)) ? __Pyx_NewRef(obj) : PySequence_Tuple(obj)) +static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject*); +static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t); +static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject*); +#if CYTHON_ASSUME_SAFE_MACROS +#define __pyx_PyFloat_AsDouble(x) (PyFloat_CheckExact(x) ? PyFloat_AS_DOUBLE(x) : PyFloat_AsDouble(x)) +#else +#define __pyx_PyFloat_AsDouble(x) PyFloat_AsDouble(x) +#endif +#define __pyx_PyFloat_AsFloat(x) ((float) __pyx_PyFloat_AsDouble(x)) +#if PY_MAJOR_VERSION >= 3 +#define __Pyx_PyNumber_Int(x) (PyLong_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Long(x)) +#else +#define __Pyx_PyNumber_Int(x) (PyInt_CheckExact(x) ? __Pyx_NewRef(x) : PyNumber_Int(x)) +#endif +#if CYTHON_USE_PYLONG_INTERNALS + #if PY_VERSION_HEX >= 0x030C00A7 + #ifndef _PyLong_SIGN_MASK + #define _PyLong_SIGN_MASK 3 + #endif + #ifndef _PyLong_NON_SIZE_BITS + #define _PyLong_NON_SIZE_BITS 3 + #endif + #define __Pyx_PyLong_Sign(x) (((PyLongObject*)x)->long_value.lv_tag & _PyLong_SIGN_MASK) + #define __Pyx_PyLong_IsNeg(x) ((__Pyx_PyLong_Sign(x) & 2) != 0) + #define __Pyx_PyLong_IsNonNeg(x) (!__Pyx_PyLong_IsNeg(x)) + #define __Pyx_PyLong_IsZero(x) (__Pyx_PyLong_Sign(x) & 1) + #define __Pyx_PyLong_IsPos(x) (__Pyx_PyLong_Sign(x) == 0) + #define __Pyx_PyLong_CompactValueUnsigned(x) (__Pyx_PyLong_Digits(x)[0]) + #define __Pyx_PyLong_DigitCount(x) ((Py_ssize_t) (((PyLongObject*)x)->long_value.lv_tag >> _PyLong_NON_SIZE_BITS)) + #define __Pyx_PyLong_SignedDigitCount(x)\ + ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * __Pyx_PyLong_DigitCount(x)) + #if defined(PyUnstable_Long_IsCompact) && defined(PyUnstable_Long_CompactValue) + #define __Pyx_PyLong_IsCompact(x) PyUnstable_Long_IsCompact((PyLongObject*) x) + #define __Pyx_PyLong_CompactValue(x) PyUnstable_Long_CompactValue((PyLongObject*) x) + #else + #define __Pyx_PyLong_IsCompact(x) (((PyLongObject*)x)->long_value.lv_tag < (2 << _PyLong_NON_SIZE_BITS)) + #define __Pyx_PyLong_CompactValue(x) ((1 - (Py_ssize_t) __Pyx_PyLong_Sign(x)) * (Py_ssize_t) __Pyx_PyLong_Digits(x)[0]) + #endif + typedef Py_ssize_t __Pyx_compact_pylong; + typedef size_t __Pyx_compact_upylong; + #else + #define __Pyx_PyLong_IsNeg(x) (Py_SIZE(x) < 0) + #define __Pyx_PyLong_IsNonNeg(x) (Py_SIZE(x) >= 0) + #define __Pyx_PyLong_IsZero(x) (Py_SIZE(x) == 0) + #define __Pyx_PyLong_IsPos(x) (Py_SIZE(x) > 0) + #define __Pyx_PyLong_CompactValueUnsigned(x) ((Py_SIZE(x) == 0) ? 0 : __Pyx_PyLong_Digits(x)[0]) + #define __Pyx_PyLong_DigitCount(x) __Pyx_sst_abs(Py_SIZE(x)) + #define __Pyx_PyLong_SignedDigitCount(x) Py_SIZE(x) + #define __Pyx_PyLong_IsCompact(x) (Py_SIZE(x) == 0 || Py_SIZE(x) == 1 || Py_SIZE(x) == -1) + #define __Pyx_PyLong_CompactValue(x)\ + ((Py_SIZE(x) == 0) ? (sdigit) 0 : ((Py_SIZE(x) < 0) ? -(sdigit)__Pyx_PyLong_Digits(x)[0] : (sdigit)__Pyx_PyLong_Digits(x)[0])) + typedef sdigit __Pyx_compact_pylong; + typedef digit __Pyx_compact_upylong; + #endif + #if PY_VERSION_HEX >= 0x030C00A5 + #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->long_value.ob_digit) + #else + #define __Pyx_PyLong_Digits(x) (((PyLongObject*)x)->ob_digit) + #endif +#endif +#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII +#include +static int __Pyx_sys_getdefaultencoding_not_ascii; +static int __Pyx_init_sys_getdefaultencoding_params(void) { + PyObject* sys; + PyObject* default_encoding = NULL; + PyObject* ascii_chars_u = NULL; + PyObject* ascii_chars_b = NULL; + const char* default_encoding_c; + sys = PyImport_ImportModule("sys"); + if (!sys) goto bad; + default_encoding = PyObject_CallMethod(sys, (char*) "getdefaultencoding", NULL); + Py_DECREF(sys); + if (!default_encoding) goto bad; + default_encoding_c = PyBytes_AsString(default_encoding); + if (!default_encoding_c) goto bad; + if (strcmp(default_encoding_c, "ascii") == 0) { + __Pyx_sys_getdefaultencoding_not_ascii = 0; + } else { + char ascii_chars[128]; + int c; + for (c = 0; c < 128; c++) { + ascii_chars[c] = (char) c; + } + __Pyx_sys_getdefaultencoding_not_ascii = 1; + ascii_chars_u = PyUnicode_DecodeASCII(ascii_chars, 128, NULL); + if (!ascii_chars_u) goto bad; + ascii_chars_b = PyUnicode_AsEncodedString(ascii_chars_u, default_encoding_c, NULL); + if (!ascii_chars_b || !PyBytes_Check(ascii_chars_b) || memcmp(ascii_chars, PyBytes_AS_STRING(ascii_chars_b), 128) != 0) { + PyErr_Format( + PyExc_ValueError, + "This module compiled with c_string_encoding=ascii, but default encoding '%.200s' is not a superset of ascii.", + default_encoding_c); + goto bad; + } + Py_DECREF(ascii_chars_u); + Py_DECREF(ascii_chars_b); + } + Py_DECREF(default_encoding); + return 0; +bad: + Py_XDECREF(default_encoding); + Py_XDECREF(ascii_chars_u); + Py_XDECREF(ascii_chars_b); + return -1; +} +#endif +#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT && PY_MAJOR_VERSION >= 3 +#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_DecodeUTF8(c_str, size, NULL) +#else +#define __Pyx_PyUnicode_FromStringAndSize(c_str, size) PyUnicode_Decode(c_str, size, __PYX_DEFAULT_STRING_ENCODING, NULL) +#if __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT +#include +static char* __PYX_DEFAULT_STRING_ENCODING; +static int __Pyx_init_sys_getdefaultencoding_params(void) { + PyObject* sys; + PyObject* default_encoding = NULL; + char* default_encoding_c; + sys = PyImport_ImportModule("sys"); + if (!sys) goto bad; + default_encoding = PyObject_CallMethod(sys, (char*) (const char*) "getdefaultencoding", NULL); + Py_DECREF(sys); + if (!default_encoding) goto bad; + default_encoding_c = PyBytes_AsString(default_encoding); + if (!default_encoding_c) goto bad; + __PYX_DEFAULT_STRING_ENCODING = (char*) malloc(strlen(default_encoding_c) + 1); + if (!__PYX_DEFAULT_STRING_ENCODING) goto bad; + strcpy(__PYX_DEFAULT_STRING_ENCODING, default_encoding_c); + Py_DECREF(default_encoding); + return 0; +bad: + Py_XDECREF(default_encoding); + return -1; +} +#endif +#endif + + +/* Test for GCC > 2.95 */ +#if defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))) + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) +#else /* !__GNUC__ or GCC < 2.95 */ + #define likely(x) (x) + #define unlikely(x) (x) +#endif /* __GNUC__ */ +static CYTHON_INLINE void __Pyx_pretend_to_initialize(void* ptr) { (void)ptr; } + +#if !CYTHON_USE_MODULE_STATE +static PyObject *__pyx_m = NULL; +#endif +static int __pyx_lineno; +static int __pyx_clineno = 0; +static const char * __pyx_cfilenm = __FILE__; +static const char *__pyx_filename; + +/* Header.proto */ +#if !defined(CYTHON_CCOMPLEX) + #if defined(__cplusplus) + #define CYTHON_CCOMPLEX 1 + #elif (defined(_Complex_I) && !defined(_MSC_VER)) || ((defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) && !defined(__STDC_NO_COMPLEX__) && !defined(_MSC_VER)) + #define CYTHON_CCOMPLEX 1 + #else + #define CYTHON_CCOMPLEX 0 + #endif +#endif +#if CYTHON_CCOMPLEX + #ifdef __cplusplus + #include + #else + #include + #endif +#endif +#if CYTHON_CCOMPLEX && !defined(__cplusplus) && defined(__sun__) && defined(__GNUC__) + #undef _Complex_I + #define _Complex_I 1.0fj +#endif + +/* #### Code section: filename_table ### */ + +static const char *__pyx_f[] = { + "TTS/tts/utils/monotonic_align/core.pyx", + "", + "__init__.cython-30.pxd", + "type.pxd", +}; +/* #### Code section: utility_code_proto_before_types ### */ +/* ForceInitThreads.proto */ +#ifndef __PYX_FORCE_INIT_THREADS + #define __PYX_FORCE_INIT_THREADS 0 +#endif + +/* NoFastGil.proto */ +#define __Pyx_PyGILState_Ensure PyGILState_Ensure +#define __Pyx_PyGILState_Release PyGILState_Release +#define __Pyx_FastGIL_Remember() +#define __Pyx_FastGIL_Forget() +#define __Pyx_FastGilFuncInit() + +/* BufferFormatStructs.proto */ +struct __Pyx_StructField_; +#define __PYX_BUF_FLAGS_PACKED_STRUCT (1 << 0) +typedef struct { + const char* name; + struct __Pyx_StructField_* fields; + size_t size; + size_t arraysize[8]; + int ndim; + char typegroup; + char is_unsigned; + int flags; +} __Pyx_TypeInfo; +typedef struct __Pyx_StructField_ { + __Pyx_TypeInfo* type; + const char* name; + size_t offset; +} __Pyx_StructField; +typedef struct { + __Pyx_StructField* field; + size_t parent_offset; +} __Pyx_BufFmt_StackElem; +typedef struct { + __Pyx_StructField root; + __Pyx_BufFmt_StackElem* head; + size_t fmt_offset; + size_t new_count, enc_count; + size_t struct_alignment; + int is_complex; + char enc_type; + char new_packmode; + char enc_packmode; + char is_valid_array; +} __Pyx_BufFmt_Context; + +/* Atomics.proto */ +#include +#ifndef CYTHON_ATOMICS + #define CYTHON_ATOMICS 1 +#endif +#define __PYX_CYTHON_ATOMICS_ENABLED() CYTHON_ATOMICS +#define __pyx_atomic_int_type int +#define __pyx_nonatomic_int_type int +#if CYTHON_ATOMICS && (defined(__STDC_VERSION__) &&\ + (__STDC_VERSION__ >= 201112L) &&\ + !defined(__STDC_NO_ATOMICS__)) + #include +#elif CYTHON_ATOMICS && (defined(__cplusplus) && (\ + (__cplusplus >= 201103L) ||\ + (defined(_MSC_VER) && _MSC_VER >= 1700))) + #include +#endif +#if CYTHON_ATOMICS && (defined(__STDC_VERSION__) &&\ + (__STDC_VERSION__ >= 201112L) &&\ + !defined(__STDC_NO_ATOMICS__) &&\ + ATOMIC_INT_LOCK_FREE == 2) + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type atomic_int + #define __pyx_atomic_incr_aligned(value) atomic_fetch_add_explicit(value, 1, memory_order_relaxed) + #define __pyx_atomic_decr_aligned(value) atomic_fetch_sub_explicit(value, 1, memory_order_acq_rel) + #if defined(__PYX_DEBUG_ATOMICS) && defined(_MSC_VER) + #pragma message ("Using standard C atomics") + #elif defined(__PYX_DEBUG_ATOMICS) + #warning "Using standard C atomics" + #endif +#elif CYTHON_ATOMICS && (defined(__cplusplus) && (\ + (__cplusplus >= 201103L) ||\ +\ + (defined(_MSC_VER) && _MSC_VER >= 1700)) &&\ + ATOMIC_INT_LOCK_FREE == 2) + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type std::atomic_int + #define __pyx_atomic_incr_aligned(value) std::atomic_fetch_add_explicit(value, 1, std::memory_order_relaxed) + #define __pyx_atomic_decr_aligned(value) std::atomic_fetch_sub_explicit(value, 1, std::memory_order_acq_rel) + #if defined(__PYX_DEBUG_ATOMICS) && defined(_MSC_VER) + #pragma message ("Using standard C++ atomics") + #elif defined(__PYX_DEBUG_ATOMICS) + #warning "Using standard C++ atomics" + #endif +#elif CYTHON_ATOMICS && (__GNUC__ >= 5 || (__GNUC__ == 4 &&\ + (__GNUC_MINOR__ > 1 ||\ + (__GNUC_MINOR__ == 1 && __GNUC_PATCHLEVEL__ >= 2)))) + #define __pyx_atomic_incr_aligned(value) __sync_fetch_and_add(value, 1) + #define __pyx_atomic_decr_aligned(value) __sync_fetch_and_sub(value, 1) + #ifdef __PYX_DEBUG_ATOMICS + #warning "Using GNU atomics" + #endif +#elif CYTHON_ATOMICS && defined(_MSC_VER) + #include + #undef __pyx_atomic_int_type + #define __pyx_atomic_int_type long + #undef __pyx_nonatomic_int_type + #define __pyx_nonatomic_int_type long + #pragma intrinsic (_InterlockedExchangeAdd) + #define __pyx_atomic_incr_aligned(value) _InterlockedExchangeAdd(value, 1) + #define __pyx_atomic_decr_aligned(value) _InterlockedExchangeAdd(value, -1) + #ifdef __PYX_DEBUG_ATOMICS + #pragma message ("Using MSVC atomics") + #endif +#else + #undef CYTHON_ATOMICS + #define CYTHON_ATOMICS 0 + #ifdef __PYX_DEBUG_ATOMICS + #warning "Not using atomics" + #endif +#endif +#if CYTHON_ATOMICS + #define __pyx_add_acquisition_count(memview)\ + __pyx_atomic_incr_aligned(__pyx_get_slice_count_pointer(memview)) + #define __pyx_sub_acquisition_count(memview)\ + __pyx_atomic_decr_aligned(__pyx_get_slice_count_pointer(memview)) +#else + #define __pyx_add_acquisition_count(memview)\ + __pyx_add_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock) + #define __pyx_sub_acquisition_count(memview)\ + __pyx_sub_acquisition_count_locked(__pyx_get_slice_count_pointer(memview), memview->lock) +#endif + +/* MemviewSliceStruct.proto */ +struct __pyx_memoryview_obj; +typedef struct { + struct __pyx_memoryview_obj *memview; + char *data; + Py_ssize_t shape[8]; + Py_ssize_t strides[8]; + Py_ssize_t suboffsets[8]; +} __Pyx_memviewslice; +#define __Pyx_MemoryView_Len(m) (m.shape[0]) + +/* #### Code section: numeric_typedefs ### */ + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":770 + * # in Cython to enable them only on the right systems. + * + * ctypedef npy_int8 int8_t # <<<<<<<<<<<<<< + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t + */ +typedef npy_int8 __pyx_t_5numpy_int8_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":771 + * + * ctypedef npy_int8 int8_t + * ctypedef npy_int16 int16_t # <<<<<<<<<<<<<< + * ctypedef npy_int32 int32_t + * ctypedef npy_int64 int64_t + */ +typedef npy_int16 __pyx_t_5numpy_int16_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":772 + * ctypedef npy_int8 int8_t + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t # <<<<<<<<<<<<<< + * ctypedef npy_int64 int64_t + * #ctypedef npy_int96 int96_t + */ +typedef npy_int32 __pyx_t_5numpy_int32_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":773 + * ctypedef npy_int16 int16_t + * ctypedef npy_int32 int32_t + * ctypedef npy_int64 int64_t # <<<<<<<<<<<<<< + * #ctypedef npy_int96 int96_t + * #ctypedef npy_int128 int128_t + */ +typedef npy_int64 __pyx_t_5numpy_int64_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":777 + * #ctypedef npy_int128 int128_t + * + * ctypedef npy_uint8 uint8_t # <<<<<<<<<<<<<< + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t + */ +typedef npy_uint8 __pyx_t_5numpy_uint8_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":778 + * + * ctypedef npy_uint8 uint8_t + * ctypedef npy_uint16 uint16_t # <<<<<<<<<<<<<< + * ctypedef npy_uint32 uint32_t + * ctypedef npy_uint64 uint64_t + */ +typedef npy_uint16 __pyx_t_5numpy_uint16_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":779 + * ctypedef npy_uint8 uint8_t + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t # <<<<<<<<<<<<<< + * ctypedef npy_uint64 uint64_t + * #ctypedef npy_uint96 uint96_t + */ +typedef npy_uint32 __pyx_t_5numpy_uint32_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":780 + * ctypedef npy_uint16 uint16_t + * ctypedef npy_uint32 uint32_t + * ctypedef npy_uint64 uint64_t # <<<<<<<<<<<<<< + * #ctypedef npy_uint96 uint96_t + * #ctypedef npy_uint128 uint128_t + */ +typedef npy_uint64 __pyx_t_5numpy_uint64_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":784 + * #ctypedef npy_uint128 uint128_t + * + * ctypedef npy_float32 float32_t # <<<<<<<<<<<<<< + * ctypedef npy_float64 float64_t + * #ctypedef npy_float80 float80_t + */ +typedef npy_float32 __pyx_t_5numpy_float32_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":785 + * + * ctypedef npy_float32 float32_t + * ctypedef npy_float64 float64_t # <<<<<<<<<<<<<< + * #ctypedef npy_float80 float80_t + * #ctypedef npy_float128 float128_t + */ +typedef npy_float64 __pyx_t_5numpy_float64_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":792 + * ctypedef double complex complex128_t + * + * ctypedef npy_longlong longlong_t # <<<<<<<<<<<<<< + * ctypedef npy_ulonglong ulonglong_t + * + */ +typedef npy_longlong __pyx_t_5numpy_longlong_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":793 + * + * ctypedef npy_longlong longlong_t + * ctypedef npy_ulonglong ulonglong_t # <<<<<<<<<<<<<< + * + * ctypedef npy_intp intp_t + */ +typedef npy_ulonglong __pyx_t_5numpy_ulonglong_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":795 + * ctypedef npy_ulonglong ulonglong_t + * + * ctypedef npy_intp intp_t # <<<<<<<<<<<<<< + * ctypedef npy_uintp uintp_t + * + */ +typedef npy_intp __pyx_t_5numpy_intp_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":796 + * + * ctypedef npy_intp intp_t + * ctypedef npy_uintp uintp_t # <<<<<<<<<<<<<< + * + * ctypedef npy_double float_t + */ +typedef npy_uintp __pyx_t_5numpy_uintp_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":798 + * ctypedef npy_uintp uintp_t + * + * ctypedef npy_double float_t # <<<<<<<<<<<<<< + * ctypedef npy_double double_t + * ctypedef npy_longdouble longdouble_t + */ +typedef npy_double __pyx_t_5numpy_float_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":799 + * + * ctypedef npy_double float_t + * ctypedef npy_double double_t # <<<<<<<<<<<<<< + * ctypedef npy_longdouble longdouble_t + * + */ +typedef npy_double __pyx_t_5numpy_double_t; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":800 + * ctypedef npy_double float_t + * ctypedef npy_double double_t + * ctypedef npy_longdouble longdouble_t # <<<<<<<<<<<<<< + * + * ctypedef float complex cfloat_t + */ +typedef npy_longdouble __pyx_t_5numpy_longdouble_t; +/* #### Code section: complex_type_declarations ### */ +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< float > __pyx_t_float_complex; + #else + typedef float _Complex __pyx_t_float_complex; + #endif +#else + typedef struct { float real, imag; } __pyx_t_float_complex; +#endif +static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float, float); + +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< double > __pyx_t_double_complex; + #else + typedef double _Complex __pyx_t_double_complex; + #endif +#else + typedef struct { double real, imag; } __pyx_t_double_complex; +#endif +static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double, double); + +/* Declarations.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + typedef ::std::complex< long double > __pyx_t_long_double_complex; + #else + typedef long double _Complex __pyx_t_long_double_complex; + #endif +#else + typedef struct { long double real, imag; } __pyx_t_long_double_complex; +#endif +static CYTHON_INLINE __pyx_t_long_double_complex __pyx_t_long_double_complex_from_parts(long double, long double); + +/* #### Code section: type_declarations ### */ + +/*--- Type declarations ---*/ +struct __pyx_array_obj; +struct __pyx_MemviewEnum_obj; +struct __pyx_memoryview_obj; +struct __pyx_memoryviewslice_obj; + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1096 + * + * # Iterator API added in v1.6 + * ctypedef int (*NpyIter_IterNextFunc)(NpyIter* it) noexcept nogil # <<<<<<<<<<<<<< + * ctypedef void (*NpyIter_GetMultiIndexFunc)(NpyIter* it, npy_intp* outcoords) noexcept nogil + * + */ +typedef int (*__pyx_t_5numpy_NpyIter_IterNextFunc)(NpyIter *); + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1097 + * # Iterator API added in v1.6 + * ctypedef int (*NpyIter_IterNextFunc)(NpyIter* it) noexcept nogil + * ctypedef void (*NpyIter_GetMultiIndexFunc)(NpyIter* it, npy_intp* outcoords) noexcept nogil # <<<<<<<<<<<<<< + * + * cdef extern from "numpy/arrayobject.h": + */ +typedef void (*__pyx_t_5numpy_NpyIter_GetMultiIndexFunc)(NpyIter *, npy_intp *); +struct __pyx_opt_args_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c; + +/* "TTS/tts/utils/monotonic_align/core.pyx":42 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: # <<<<<<<<<<<<<< + * cdef int b = values.shape[0] + * + */ +struct __pyx_opt_args_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c { + int __pyx_n; + float max_neg_val; +}; + +/* "View.MemoryView":114 + * @cython.collection_type("sequence") + * @cname("__pyx_array") + * cdef class array: # <<<<<<<<<<<<<< + * + * cdef: + */ +struct __pyx_array_obj { + PyObject_HEAD + struct __pyx_vtabstruct_array *__pyx_vtab; + char *data; + Py_ssize_t len; + char *format; + int ndim; + Py_ssize_t *_shape; + Py_ssize_t *_strides; + Py_ssize_t itemsize; + PyObject *mode; + PyObject *_format; + void (*callback_free_data)(void *); + int free_data; + int dtype_is_object; +}; + + +/* "View.MemoryView":302 + * + * @cname('__pyx_MemviewEnum') + * cdef class Enum(object): # <<<<<<<<<<<<<< + * cdef object name + * def __init__(self, name): + */ +struct __pyx_MemviewEnum_obj { + PyObject_HEAD + PyObject *name; +}; + + +/* "View.MemoryView":337 + * + * @cname('__pyx_memoryview') + * cdef class memoryview: # <<<<<<<<<<<<<< + * + * cdef object obj + */ +struct __pyx_memoryview_obj { + PyObject_HEAD + struct __pyx_vtabstruct_memoryview *__pyx_vtab; + PyObject *obj; + PyObject *_size; + PyObject *_array_interface; + PyThread_type_lock lock; + __pyx_atomic_int_type acquisition_count; + Py_buffer view; + int flags; + int dtype_is_object; + __Pyx_TypeInfo *typeinfo; +}; + + +/* "View.MemoryView":952 + * @cython.collection_type("sequence") + * @cname('__pyx_memoryviewslice') + * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<< + * "Internal class for passing memoryview slices to Python" + * + */ +struct __pyx_memoryviewslice_obj { + struct __pyx_memoryview_obj __pyx_base; + __Pyx_memviewslice from_slice; + PyObject *from_object; + PyObject *(*to_object_func)(char *); + int (*to_dtype_func)(char *, PyObject *); +}; + + + +/* "View.MemoryView":114 + * @cython.collection_type("sequence") + * @cname("__pyx_array") + * cdef class array: # <<<<<<<<<<<<<< + * + * cdef: + */ + +struct __pyx_vtabstruct_array { + PyObject *(*get_memview)(struct __pyx_array_obj *); +}; +static struct __pyx_vtabstruct_array *__pyx_vtabptr_array; + + +/* "View.MemoryView":337 + * + * @cname('__pyx_memoryview') + * cdef class memoryview: # <<<<<<<<<<<<<< + * + * cdef object obj + */ + +struct __pyx_vtabstruct_memoryview { + char *(*get_item_pointer)(struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*is_slice)(struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*setitem_slice_assignment)(struct __pyx_memoryview_obj *, PyObject *, PyObject *); + PyObject *(*setitem_slice_assign_scalar)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *); + PyObject *(*setitem_indexed)(struct __pyx_memoryview_obj *, PyObject *, PyObject *); + PyObject *(*convert_item_to_object)(struct __pyx_memoryview_obj *, char *); + PyObject *(*assign_item_from_object)(struct __pyx_memoryview_obj *, char *, PyObject *); + PyObject *(*_get_base)(struct __pyx_memoryview_obj *); +}; +static struct __pyx_vtabstruct_memoryview *__pyx_vtabptr_memoryview; + + +/* "View.MemoryView":952 + * @cython.collection_type("sequence") + * @cname('__pyx_memoryviewslice') + * cdef class _memoryviewslice(memoryview): # <<<<<<<<<<<<<< + * "Internal class for passing memoryview slices to Python" + * + */ + +struct __pyx_vtabstruct__memoryviewslice { + struct __pyx_vtabstruct_memoryview __pyx_base; +}; +static struct __pyx_vtabstruct__memoryviewslice *__pyx_vtabptr__memoryviewslice; +/* #### Code section: utility_code_proto ### */ + +/* --- Runtime support code (head) --- */ +/* Refnanny.proto */ +#ifndef CYTHON_REFNANNY + #define CYTHON_REFNANNY 0 +#endif +#if CYTHON_REFNANNY + typedef struct { + void (*INCREF)(void*, PyObject*, Py_ssize_t); + void (*DECREF)(void*, PyObject*, Py_ssize_t); + void (*GOTREF)(void*, PyObject*, Py_ssize_t); + void (*GIVEREF)(void*, PyObject*, Py_ssize_t); + void* (*SetupContext)(const char*, Py_ssize_t, const char*); + void (*FinishContext)(void**); + } __Pyx_RefNannyAPIStruct; + static __Pyx_RefNannyAPIStruct *__Pyx_RefNanny = NULL; + static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname); + #define __Pyx_RefNannyDeclarations void *__pyx_refnanny = NULL; +#ifdef WITH_THREAD + #define __Pyx_RefNannySetupContext(name, acquire_gil)\ + if (acquire_gil) {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\ + PyGILState_Release(__pyx_gilstate_save);\ + } else {\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__));\ + } + #define __Pyx_RefNannyFinishContextNogil() {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __Pyx_RefNannyFinishContext();\ + PyGILState_Release(__pyx_gilstate_save);\ + } +#else + #define __Pyx_RefNannySetupContext(name, acquire_gil)\ + __pyx_refnanny = __Pyx_RefNanny->SetupContext((name), (__LINE__), (__FILE__)) + #define __Pyx_RefNannyFinishContextNogil() __Pyx_RefNannyFinishContext() +#endif + #define __Pyx_RefNannyFinishContextNogil() {\ + PyGILState_STATE __pyx_gilstate_save = PyGILState_Ensure();\ + __Pyx_RefNannyFinishContext();\ + PyGILState_Release(__pyx_gilstate_save);\ + } + #define __Pyx_RefNannyFinishContext()\ + __Pyx_RefNanny->FinishContext(&__pyx_refnanny) + #define __Pyx_INCREF(r) __Pyx_RefNanny->INCREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_DECREF(r) __Pyx_RefNanny->DECREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_GOTREF(r) __Pyx_RefNanny->GOTREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_GIVEREF(r) __Pyx_RefNanny->GIVEREF(__pyx_refnanny, (PyObject *)(r), (__LINE__)) + #define __Pyx_XINCREF(r) do { if((r) == NULL); else {__Pyx_INCREF(r); }} while(0) + #define __Pyx_XDECREF(r) do { if((r) == NULL); else {__Pyx_DECREF(r); }} while(0) + #define __Pyx_XGOTREF(r) do { if((r) == NULL); else {__Pyx_GOTREF(r); }} while(0) + #define __Pyx_XGIVEREF(r) do { if((r) == NULL); else {__Pyx_GIVEREF(r);}} while(0) +#else + #define __Pyx_RefNannyDeclarations + #define __Pyx_RefNannySetupContext(name, acquire_gil) + #define __Pyx_RefNannyFinishContextNogil() + #define __Pyx_RefNannyFinishContext() + #define __Pyx_INCREF(r) Py_INCREF(r) + #define __Pyx_DECREF(r) Py_DECREF(r) + #define __Pyx_GOTREF(r) + #define __Pyx_GIVEREF(r) + #define __Pyx_XINCREF(r) Py_XINCREF(r) + #define __Pyx_XDECREF(r) Py_XDECREF(r) + #define __Pyx_XGOTREF(r) + #define __Pyx_XGIVEREF(r) +#endif +#define __Pyx_Py_XDECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; Py_XDECREF(tmp);\ + } while (0) +#define __Pyx_XDECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; __Pyx_XDECREF(tmp);\ + } while (0) +#define __Pyx_DECREF_SET(r, v) do {\ + PyObject *tmp = (PyObject *) r;\ + r = v; __Pyx_DECREF(tmp);\ + } while (0) +#define __Pyx_CLEAR(r) do { PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);} while(0) +#define __Pyx_XCLEAR(r) do { if((r) != NULL) {PyObject* tmp = ((PyObject*)(r)); r = NULL; __Pyx_DECREF(tmp);}} while(0) + +/* PyErrExceptionMatches.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyErr_ExceptionMatches(err) __Pyx_PyErr_ExceptionMatchesInState(__pyx_tstate, err) +static CYTHON_INLINE int __Pyx_PyErr_ExceptionMatchesInState(PyThreadState* tstate, PyObject* err); +#else +#define __Pyx_PyErr_ExceptionMatches(err) PyErr_ExceptionMatches(err) +#endif + +/* PyThreadStateGet.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyThreadState_declare PyThreadState *__pyx_tstate; +#define __Pyx_PyThreadState_assign __pyx_tstate = __Pyx_PyThreadState_Current; +#if PY_VERSION_HEX >= 0x030C00A6 +#define __Pyx_PyErr_Occurred() (__pyx_tstate->current_exception != NULL) +#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->current_exception ? (PyObject*) Py_TYPE(__pyx_tstate->current_exception) : (PyObject*) NULL) +#else +#define __Pyx_PyErr_Occurred() (__pyx_tstate->curexc_type != NULL) +#define __Pyx_PyErr_CurrentExceptionType() (__pyx_tstate->curexc_type) +#endif +#else +#define __Pyx_PyThreadState_declare +#define __Pyx_PyThreadState_assign +#define __Pyx_PyErr_Occurred() (PyErr_Occurred() != NULL) +#define __Pyx_PyErr_CurrentExceptionType() PyErr_Occurred() +#endif + +/* PyErrFetchRestore.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_PyErr_Clear() __Pyx_ErrRestore(NULL, NULL, NULL) +#define __Pyx_ErrRestoreWithState(type, value, tb) __Pyx_ErrRestoreInState(PyThreadState_GET(), type, value, tb) +#define __Pyx_ErrFetchWithState(type, value, tb) __Pyx_ErrFetchInState(PyThreadState_GET(), type, value, tb) +#define __Pyx_ErrRestore(type, value, tb) __Pyx_ErrRestoreInState(__pyx_tstate, type, value, tb) +#define __Pyx_ErrFetch(type, value, tb) __Pyx_ErrFetchInState(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); +static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A6 +#define __Pyx_PyErr_SetNone(exc) (Py_INCREF(exc), __Pyx_ErrRestore((exc), NULL, NULL)) +#else +#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) +#endif +#else +#define __Pyx_PyErr_Clear() PyErr_Clear() +#define __Pyx_PyErr_SetNone(exc) PyErr_SetNone(exc) +#define __Pyx_ErrRestoreWithState(type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetchWithState(type, value, tb) PyErr_Fetch(type, value, tb) +#define __Pyx_ErrRestoreInState(tstate, type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetchInState(tstate, type, value, tb) PyErr_Fetch(type, value, tb) +#define __Pyx_ErrRestore(type, value, tb) PyErr_Restore(type, value, tb) +#define __Pyx_ErrFetch(type, value, tb) PyErr_Fetch(type, value, tb) +#endif + +/* PyObjectGetAttrStr.proto */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GetAttrStr(o,n) PyObject_GetAttr(o,n) +#endif + +/* PyObjectGetAttrStrNoError.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name); + +/* GetBuiltinName.proto */ +static PyObject *__Pyx_GetBuiltinName(PyObject *name); + +/* TupleAndListFromArray.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n); +static CYTHON_INLINE PyObject* __Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n); +#endif + +/* IncludeStringH.proto */ +#include + +/* BytesEquals.proto */ +static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals); + +/* UnicodeEquals.proto */ +static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals); + +/* fastcall.proto */ +#if CYTHON_AVOID_BORROWED_REFS + #define __Pyx_Arg_VARARGS(args, i) PySequence_GetItem(args, i) +#elif CYTHON_ASSUME_SAFE_MACROS + #define __Pyx_Arg_VARARGS(args, i) PyTuple_GET_ITEM(args, i) +#else + #define __Pyx_Arg_VARARGS(args, i) PyTuple_GetItem(args, i) +#endif +#if CYTHON_AVOID_BORROWED_REFS + #define __Pyx_Arg_NewRef_VARARGS(arg) __Pyx_NewRef(arg) + #define __Pyx_Arg_XDECREF_VARARGS(arg) Py_XDECREF(arg) +#else + #define __Pyx_Arg_NewRef_VARARGS(arg) arg + #define __Pyx_Arg_XDECREF_VARARGS(arg) +#endif +#define __Pyx_NumKwargs_VARARGS(kwds) PyDict_Size(kwds) +#define __Pyx_KwValues_VARARGS(args, nargs) NULL +#define __Pyx_GetKwValue_VARARGS(kw, kwvalues, s) __Pyx_PyDict_GetItemStrWithError(kw, s) +#define __Pyx_KwargsAsDict_VARARGS(kw, kwvalues) PyDict_Copy(kw) +#if CYTHON_METH_FASTCALL + #define __Pyx_Arg_FASTCALL(args, i) args[i] + #define __Pyx_NumKwargs_FASTCALL(kwds) PyTuple_GET_SIZE(kwds) + #define __Pyx_KwValues_FASTCALL(args, nargs) ((args) + (nargs)) + static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s); +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues); + #else + #define __Pyx_KwargsAsDict_FASTCALL(kw, kwvalues) _PyStack_AsDict(kwvalues, kw) + #endif + #define __Pyx_Arg_NewRef_FASTCALL(arg) arg /* no-op, __Pyx_Arg_FASTCALL is direct and this needs + to have the same reference counting */ + #define __Pyx_Arg_XDECREF_FASTCALL(arg) +#else + #define __Pyx_Arg_FASTCALL __Pyx_Arg_VARARGS + #define __Pyx_NumKwargs_FASTCALL __Pyx_NumKwargs_VARARGS + #define __Pyx_KwValues_FASTCALL __Pyx_KwValues_VARARGS + #define __Pyx_GetKwValue_FASTCALL __Pyx_GetKwValue_VARARGS + #define __Pyx_KwargsAsDict_FASTCALL __Pyx_KwargsAsDict_VARARGS + #define __Pyx_Arg_NewRef_FASTCALL(arg) __Pyx_Arg_NewRef_VARARGS(arg) + #define __Pyx_Arg_XDECREF_FASTCALL(arg) __Pyx_Arg_XDECREF_VARARGS(arg) +#endif +#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS +#define __Pyx_ArgsSlice_VARARGS(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_VARARGS(args, start), stop - start) +#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) __Pyx_PyTuple_FromArray(&__Pyx_Arg_FASTCALL(args, start), stop - start) +#else +#define __Pyx_ArgsSlice_VARARGS(args, start, stop) PyTuple_GetSlice(args, start, stop) +#define __Pyx_ArgsSlice_FASTCALL(args, start, stop) PyTuple_GetSlice(args, start, stop) +#endif + +/* RaiseArgTupleInvalid.proto */ +static void __Pyx_RaiseArgtupleInvalid(const char* func_name, int exact, + Py_ssize_t num_min, Py_ssize_t num_max, Py_ssize_t num_found); + +/* RaiseDoubleKeywords.proto */ +static void __Pyx_RaiseDoubleKeywordsError(const char* func_name, PyObject* kw_name); + +/* ParseKeywords.proto */ +static int __Pyx_ParseOptionalKeywords(PyObject *kwds, PyObject *const *kwvalues, + PyObject **argnames[], + PyObject *kwds2, PyObject *values[], Py_ssize_t num_pos_args, + const char* function_name); + +/* ArgTypeTest.proto */ +#define __Pyx_ArgTypeTest(obj, type, none_allowed, name, exact)\ + ((likely(__Pyx_IS_TYPE(obj, type) | (none_allowed && (obj == Py_None)))) ? 1 :\ + __Pyx__ArgTypeTest(obj, type, name, exact)) +static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact); + +/* RaiseException.proto */ +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); + +/* PyFunctionFastCall.proto */ +#if CYTHON_FAST_PYCALL +#if !CYTHON_VECTORCALL +#define __Pyx_PyFunction_FastCall(func, args, nargs)\ + __Pyx_PyFunction_FastCallDict((func), (args), (nargs), NULL) +static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs); +#endif +#define __Pyx_BUILD_ASSERT_EXPR(cond)\ + (sizeof(char [1 - 2*!(cond)]) - 1) +#ifndef Py_MEMBER_SIZE +#define Py_MEMBER_SIZE(type, member) sizeof(((type *)0)->member) +#endif +#if !CYTHON_VECTORCALL +#if PY_VERSION_HEX >= 0x03080000 + #include "frameobject.h" +#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API + #ifndef Py_BUILD_CORE + #define Py_BUILD_CORE 1 + #endif + #include "internal/pycore_frame.h" +#endif + #define __Pxy_PyFrame_Initialize_Offsets() + #define __Pyx_PyFrame_GetLocalsplus(frame) ((frame)->f_localsplus) +#else + static size_t __pyx_pyframe_localsplus_offset = 0; + #include "frameobject.h" + #define __Pxy_PyFrame_Initialize_Offsets()\ + ((void)__Pyx_BUILD_ASSERT_EXPR(sizeof(PyFrameObject) == offsetof(PyFrameObject, f_localsplus) + Py_MEMBER_SIZE(PyFrameObject, f_localsplus)),\ + (void)(__pyx_pyframe_localsplus_offset = ((size_t)PyFrame_Type.tp_basicsize) - Py_MEMBER_SIZE(PyFrameObject, f_localsplus))) + #define __Pyx_PyFrame_GetLocalsplus(frame)\ + (assert(__pyx_pyframe_localsplus_offset), (PyObject **)(((char *)(frame)) + __pyx_pyframe_localsplus_offset)) +#endif +#endif +#endif + +/* PyObjectCall.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw); +#else +#define __Pyx_PyObject_Call(func, arg, kw) PyObject_Call(func, arg, kw) +#endif + +/* PyObjectCallMethO.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg); +#endif + +/* PyObjectFastCall.proto */ +#define __Pyx_PyObject_FastCall(func, args, nargs) __Pyx_PyObject_FastCallDict(func, args, (size_t)(nargs), NULL) +static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs); + +/* RaiseUnexpectedTypeError.proto */ +static int __Pyx_RaiseUnexpectedTypeError(const char *expected, PyObject *obj); + +/* GCCDiagnostics.proto */ +#if !defined(__INTEL_COMPILER) && defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6)) +#define __Pyx_HAS_GCC_DIAGNOSTIC +#endif + +/* BuildPyUnicode.proto */ +static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength, + int prepend_sign, char padding_char); + +/* CIntToPyUnicode.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_int(int value, Py_ssize_t width, char padding_char, char format_char); + +/* CIntToPyUnicode.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_Py_ssize_t(Py_ssize_t value, Py_ssize_t width, char padding_char, char format_char); + +/* JoinPyUnicode.proto */ +static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength, + Py_UCS4 max_char); + +/* StrEquals.proto */ +#if PY_MAJOR_VERSION >= 3 +#define __Pyx_PyString_Equals __Pyx_PyUnicode_Equals +#else +#define __Pyx_PyString_Equals __Pyx_PyBytes_Equals +#endif + +/* PyObjectFormatSimple.proto */ +#if CYTHON_COMPILING_IN_PYPY + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + PyObject_Format(s, f)) +#elif PY_MAJOR_VERSION < 3 + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + likely(PyString_CheckExact(s)) ? PyUnicode_FromEncodedObject(s, NULL, "strict") :\ + PyObject_Format(s, f)) +#elif CYTHON_USE_TYPE_SLOTS + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + likely(PyLong_CheckExact(s)) ? PyLong_Type.tp_repr(s) :\ + likely(PyFloat_CheckExact(s)) ? PyFloat_Type.tp_repr(s) :\ + PyObject_Format(s, f)) +#else + #define __Pyx_PyObject_FormatSimple(s, f) (\ + likely(PyUnicode_CheckExact(s)) ? (Py_INCREF(s), s) :\ + PyObject_Format(s, f)) +#endif + +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *); /*proto*/ +/* GetAttr.proto */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *, PyObject *); + +/* GetItemInt.proto */ +#define __Pyx_GetItemInt(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_Fast(o, (Py_ssize_t)i, is_list, wraparound, boundscheck) :\ + (is_list ? (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL) :\ + __Pyx_GetItemInt_Generic(o, to_py_func(i)))) +#define __Pyx_GetItemInt_List(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_List_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ + (PyErr_SetString(PyExc_IndexError, "list index out of range"), (PyObject*)NULL)) +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, + int wraparound, int boundscheck); +#define __Pyx_GetItemInt_Tuple(o, i, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_GetItemInt_Tuple_Fast(o, (Py_ssize_t)i, wraparound, boundscheck) :\ + (PyErr_SetString(PyExc_IndexError, "tuple index out of range"), (PyObject*)NULL)) +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, + int wraparound, int boundscheck); +static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j); +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, + int is_list, int wraparound, int boundscheck); + +/* PyObjectCallOneArg.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg); + +/* ObjectGetItem.proto */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key); +#else +#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key) +#endif + +/* KeywordStringCheck.proto */ +static int __Pyx_CheckKeywordStrings(PyObject *kw, const char* function_name, int kw_allowed); + +/* DivInt[Py_ssize_t].proto */ +static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t, Py_ssize_t); + +/* UnaryNegOverflows.proto */ +#define __Pyx_UNARY_NEG_WOULD_OVERFLOW(x)\ + (((x) < 0) & ((unsigned long)(x) == 0-(unsigned long)(x))) + +/* GetAttr3.proto */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *, PyObject *, PyObject *); + +/* PyDictVersioning.proto */ +#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS +#define __PYX_DICT_VERSION_INIT ((PY_UINT64_T) -1) +#define __PYX_GET_DICT_VERSION(dict) (((PyDictObject*)(dict))->ma_version_tag) +#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var)\ + (version_var) = __PYX_GET_DICT_VERSION(dict);\ + (cache_var) = (value); +#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) {\ + static PY_UINT64_T __pyx_dict_version = 0;\ + static PyObject *__pyx_dict_cached_value = NULL;\ + if (likely(__PYX_GET_DICT_VERSION(DICT) == __pyx_dict_version)) {\ + (VAR) = __pyx_dict_cached_value;\ + } else {\ + (VAR) = __pyx_dict_cached_value = (LOOKUP);\ + __pyx_dict_version = __PYX_GET_DICT_VERSION(DICT);\ + }\ +} +static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj); +static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj); +static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version); +#else +#define __PYX_GET_DICT_VERSION(dict) (0) +#define __PYX_UPDATE_DICT_CACHE(dict, value, cache_var, version_var) +#define __PYX_PY_DICT_LOOKUP_IF_MODIFIED(VAR, DICT, LOOKUP) (VAR) = (LOOKUP); +#endif + +/* GetModuleGlobalName.proto */ +#if CYTHON_USE_DICT_VERSIONS +#define __Pyx_GetModuleGlobalName(var, name) do {\ + static PY_UINT64_T __pyx_dict_version = 0;\ + static PyObject *__pyx_dict_cached_value = NULL;\ + (var) = (likely(__pyx_dict_version == __PYX_GET_DICT_VERSION(__pyx_d))) ?\ + (likely(__pyx_dict_cached_value) ? __Pyx_NewRef(__pyx_dict_cached_value) : __Pyx_GetBuiltinName(name)) :\ + __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ +} while(0) +#define __Pyx_GetModuleGlobalNameUncached(var, name) do {\ + PY_UINT64_T __pyx_dict_version;\ + PyObject *__pyx_dict_cached_value;\ + (var) = __Pyx__GetModuleGlobalName(name, &__pyx_dict_version, &__pyx_dict_cached_value);\ +} while(0) +static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value); +#else +#define __Pyx_GetModuleGlobalName(var, name) (var) = __Pyx__GetModuleGlobalName(name) +#define __Pyx_GetModuleGlobalNameUncached(var, name) (var) = __Pyx__GetModuleGlobalName(name) +static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name); +#endif + +/* AssertionsEnabled.proto */ +#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x02070600 && !defined(Py_OptimizeFlag) + #define __Pyx_init_assertions_enabled() (0) + #define __pyx_assertions_enabled() (1) +#elif CYTHON_COMPILING_IN_LIMITED_API || (CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030C0000) + static int __pyx_assertions_enabled_flag; + #define __pyx_assertions_enabled() (__pyx_assertions_enabled_flag) + static int __Pyx_init_assertions_enabled(void) { + PyObject *builtins, *debug, *debug_str; + int flag; + builtins = PyEval_GetBuiltins(); + if (!builtins) goto bad; + debug_str = PyUnicode_FromStringAndSize("__debug__", 9); + if (!debug_str) goto bad; + debug = PyObject_GetItem(builtins, debug_str); + Py_DECREF(debug_str); + if (!debug) goto bad; + flag = PyObject_IsTrue(debug); + Py_DECREF(debug); + if (flag == -1) goto bad; + __pyx_assertions_enabled_flag = flag; + return 0; + bad: + __pyx_assertions_enabled_flag = 1; + return -1; + } +#else + #define __Pyx_init_assertions_enabled() (0) + #define __pyx_assertions_enabled() (!Py_OptimizeFlag) +#endif + +/* RaiseTooManyValuesToUnpack.proto */ +static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected); + +/* RaiseNeedMoreValuesToUnpack.proto */ +static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index); + +/* RaiseNoneIterError.proto */ +static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void); + +/* ExtTypeTest.proto */ +static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); + +/* GetTopmostException.proto */ +#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE +static _PyErr_StackItem * __Pyx_PyErr_GetTopmostException(PyThreadState *tstate); +#endif + +/* SaveResetException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_ExceptionSave(type, value, tb) __Pyx__ExceptionSave(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#define __Pyx_ExceptionReset(type, value, tb) __Pyx__ExceptionReset(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb); +#else +#define __Pyx_ExceptionSave(type, value, tb) PyErr_GetExcInfo(type, value, tb) +#define __Pyx_ExceptionReset(type, value, tb) PyErr_SetExcInfo(type, value, tb) +#endif + +/* GetException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_GetException(type, value, tb) __Pyx__GetException(__pyx_tstate, type, value, tb) +static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#else +static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb); +#endif + +/* SwapException.proto */ +#if CYTHON_FAST_THREAD_STATE +#define __Pyx_ExceptionSwap(type, value, tb) __Pyx__ExceptionSwap(__pyx_tstate, type, value, tb) +static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb); +#else +static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb); +#endif + +/* Import.proto */ +static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level); + +/* ImportDottedModule.proto */ +static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple); +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple); +#endif + +/* FastTypeChecks.proto */ +#if CYTHON_COMPILING_IN_CPYTHON +#define __Pyx_TypeCheck(obj, type) __Pyx_IsSubtype(Py_TYPE(obj), (PyTypeObject *)type) +#define __Pyx_TypeCheck2(obj, type1, type2) __Pyx_IsAnySubtype2(Py_TYPE(obj), (PyTypeObject *)type1, (PyTypeObject *)type2) +static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b); +static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b); +static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches(PyObject *err, PyObject *type); +static CYTHON_INLINE int __Pyx_PyErr_GivenExceptionMatches2(PyObject *err, PyObject *type1, PyObject *type2); +#else +#define __Pyx_TypeCheck(obj, type) PyObject_TypeCheck(obj, (PyTypeObject *)type) +#define __Pyx_TypeCheck2(obj, type1, type2) (PyObject_TypeCheck(obj, (PyTypeObject *)type1) || PyObject_TypeCheck(obj, (PyTypeObject *)type2)) +#define __Pyx_PyErr_GivenExceptionMatches(err, type) PyErr_GivenExceptionMatches(err, type) +#define __Pyx_PyErr_GivenExceptionMatches2(err, type1, type2) (PyErr_GivenExceptionMatches(err, type1) || PyErr_GivenExceptionMatches(err, type2)) +#endif +#define __Pyx_PyErr_ExceptionMatches2(err1, err2) __Pyx_PyErr_GivenExceptionMatches2(__Pyx_PyErr_CurrentExceptionType(), err1, err2) +#define __Pyx_PyException_Check(obj) __Pyx_TypeCheck(obj, PyExc_Exception) + +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +/* ListCompAppend.proto */ +#if CYTHON_USE_PYLIST_INTERNALS && CYTHON_ASSUME_SAFE_MACROS +static CYTHON_INLINE int __Pyx_ListComp_Append(PyObject* list, PyObject* x) { + PyListObject* L = (PyListObject*) list; + Py_ssize_t len = Py_SIZE(list); + if (likely(L->allocated > len)) { + Py_INCREF(x); + #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 + L->ob_item[len] = x; + #else + PyList_SET_ITEM(list, len, x); + #endif + __Pyx_SET_SIZE(list, len + 1); + return 0; + } + return PyList_Append(list, x); +} +#else +#define __Pyx_ListComp_Append(L,x) PyList_Append(L,x) +#endif + +/* PySequenceMultiply.proto */ +#define __Pyx_PySequence_Multiply_Left(mul, seq) __Pyx_PySequence_Multiply(seq, mul) +static CYTHON_INLINE PyObject* __Pyx_PySequence_Multiply(PyObject *seq, Py_ssize_t mul); + +/* SetItemInt.proto */ +#define __Pyx_SetItemInt(o, i, v, type, is_signed, to_py_func, is_list, wraparound, boundscheck)\ + (__Pyx_fits_Py_ssize_t(i, type, is_signed) ?\ + __Pyx_SetItemInt_Fast(o, (Py_ssize_t)i, v, is_list, wraparound, boundscheck) :\ + (is_list ? (PyErr_SetString(PyExc_IndexError, "list assignment index out of range"), -1) :\ + __Pyx_SetItemInt_Generic(o, to_py_func(i), v))) +static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v); +static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, + int is_list, int wraparound, int boundscheck); + +/* RaiseUnboundLocalError.proto */ +static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname); + +/* DivInt[long].proto */ +static CYTHON_INLINE long __Pyx_div_long(long, long); + +/* PySequenceContains.proto */ +static CYTHON_INLINE int __Pyx_PySequence_ContainsTF(PyObject* item, PyObject* seq, int eq) { + int result = PySequence_Contains(seq, item); + return unlikely(result < 0) ? result : (result == (eq == Py_EQ)); +} + +/* ImportFrom.proto */ +static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name); + +/* HasAttr.proto */ +static CYTHON_INLINE int __Pyx_HasAttr(PyObject *, PyObject *); + +/* ErrOccurredWithGIL.proto */ +static CYTHON_INLINE int __Pyx_ErrOccurredWithGIL(void); + +/* PyObject_GenericGetAttrNoDict.proto */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GenericGetAttrNoDict PyObject_GenericGetAttr +#endif + +/* PyObject_GenericGetAttr.proto */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name); +#else +#define __Pyx_PyObject_GenericGetAttr PyObject_GenericGetAttr +#endif + +/* IncludeStructmemberH.proto */ +#include + +/* FixUpExtensionType.proto */ +#if CYTHON_USE_TYPE_SPECS +static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type); +#endif + +/* PyObjectCallNoArg.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func); + +/* PyObjectGetMethod.proto */ +static int __Pyx_PyObject_GetMethod(PyObject *obj, PyObject *name, PyObject **method); + +/* PyObjectCallMethod0.proto */ +static PyObject* __Pyx_PyObject_CallMethod0(PyObject* obj, PyObject* method_name); + +/* ValidateBasesTuple.proto */ +#if CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API || CYTHON_USE_TYPE_SPECS +static int __Pyx_validate_bases_tuple(const char *type_name, Py_ssize_t dictoffset, PyObject *bases); +#endif + +/* PyType_Ready.proto */ +CYTHON_UNUSED static int __Pyx_PyType_Ready(PyTypeObject *t); + +/* SetVTable.proto */ +static int __Pyx_SetVtable(PyTypeObject* typeptr , void* vtable); + +/* GetVTable.proto */ +static void* __Pyx_GetVtable(PyTypeObject *type); + +/* MergeVTables.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_MergeVtables(PyTypeObject *type); +#endif + +/* SetupReduce.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_setup_reduce(PyObject* type_obj); +#endif + +/* TypeImport.proto */ +#ifndef __PYX_HAVE_RT_ImportType_proto_3_0_11 +#define __PYX_HAVE_RT_ImportType_proto_3_0_11 +#if defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L +#include +#endif +#if (defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L) || __cplusplus >= 201103L +#define __PYX_GET_STRUCT_ALIGNMENT_3_0_11(s) alignof(s) +#else +#define __PYX_GET_STRUCT_ALIGNMENT_3_0_11(s) sizeof(void*) +#endif +enum __Pyx_ImportType_CheckSize_3_0_11 { + __Pyx_ImportType_CheckSize_Error_3_0_11 = 0, + __Pyx_ImportType_CheckSize_Warn_3_0_11 = 1, + __Pyx_ImportType_CheckSize_Ignore_3_0_11 = 2 +}; +static PyTypeObject *__Pyx_ImportType_3_0_11(PyObject* module, const char *module_name, const char *class_name, size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_11 check_size); +#endif + +/* FetchSharedCythonModule.proto */ +static PyObject *__Pyx_FetchSharedCythonABIModule(void); + +/* FetchCommonType.proto */ +#if !CYTHON_USE_TYPE_SPECS +static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type); +#else +static PyTypeObject* __Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases); +#endif + +/* PyMethodNew.proto */ +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) { + PyObject *typesModule=NULL, *methodType=NULL, *result=NULL; + CYTHON_UNUSED_VAR(typ); + if (!self) + return __Pyx_NewRef(func); + typesModule = PyImport_ImportModule("types"); + if (!typesModule) return NULL; + methodType = PyObject_GetAttrString(typesModule, "MethodType"); + Py_DECREF(typesModule); + if (!methodType) return NULL; + result = PyObject_CallFunctionObjArgs(methodType, func, self, NULL); + Py_DECREF(methodType); + return result; +} +#elif PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_PyMethod_New(PyObject *func, PyObject *self, PyObject *typ) { + CYTHON_UNUSED_VAR(typ); + if (!self) + return __Pyx_NewRef(func); + return PyMethod_New(func, self); +} +#else + #define __Pyx_PyMethod_New PyMethod_New +#endif + +/* PyVectorcallFastCallDict.proto */ +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw); +#endif + +/* CythonFunctionShared.proto */ +#define __Pyx_CyFunction_USED +#define __Pyx_CYFUNCTION_STATICMETHOD 0x01 +#define __Pyx_CYFUNCTION_CLASSMETHOD 0x02 +#define __Pyx_CYFUNCTION_CCLASS 0x04 +#define __Pyx_CYFUNCTION_COROUTINE 0x08 +#define __Pyx_CyFunction_GetClosure(f)\ + (((__pyx_CyFunctionObject *) (f))->func_closure) +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + #define __Pyx_CyFunction_GetClassObj(f)\ + (((__pyx_CyFunctionObject *) (f))->func_classobj) +#else + #define __Pyx_CyFunction_GetClassObj(f)\ + ((PyObject*) ((PyCMethodObject *) (f))->mm_class) +#endif +#define __Pyx_CyFunction_SetClassObj(f, classobj)\ + __Pyx__CyFunction_SetClassObj((__pyx_CyFunctionObject *) (f), (classobj)) +#define __Pyx_CyFunction_Defaults(type, f)\ + ((type *)(((__pyx_CyFunctionObject *) (f))->defaults)) +#define __Pyx_CyFunction_SetDefaultsGetter(f, g)\ + ((__pyx_CyFunctionObject *) (f))->defaults_getter = (g) +typedef struct { +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject_HEAD + PyObject *func; +#elif PY_VERSION_HEX < 0x030900B1 + PyCFunctionObject func; +#else + PyCMethodObject func; +#endif +#if CYTHON_BACKPORT_VECTORCALL + __pyx_vectorcallfunc func_vectorcall; +#endif +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API + PyObject *func_weakreflist; +#endif + PyObject *func_dict; + PyObject *func_name; + PyObject *func_qualname; + PyObject *func_doc; + PyObject *func_globals; + PyObject *func_code; + PyObject *func_closure; +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + PyObject *func_classobj; +#endif + void *defaults; + int defaults_pyobjects; + size_t defaults_size; + int flags; + PyObject *defaults_tuple; + PyObject *defaults_kwdict; + PyObject *(*defaults_getter)(PyObject *); + PyObject *func_annotations; + PyObject *func_is_coroutine; +} __pyx_CyFunctionObject; +#undef __Pyx_CyOrPyCFunction_Check +#define __Pyx_CyFunction_Check(obj) __Pyx_TypeCheck(obj, __pyx_CyFunctionType) +#define __Pyx_CyOrPyCFunction_Check(obj) __Pyx_TypeCheck2(obj, __pyx_CyFunctionType, &PyCFunction_Type) +#define __Pyx_CyFunction_CheckExact(obj) __Pyx_IS_TYPE(obj, __pyx_CyFunctionType) +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc); +#undef __Pyx_IsSameCFunction +#define __Pyx_IsSameCFunction(func, cfunc) __Pyx__IsSameCyOrCFunction(func, cfunc) +static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject* op, PyMethodDef *ml, + int flags, PyObject* qualname, + PyObject *closure, + PyObject *module, PyObject *globals, + PyObject* code); +static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj); +static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m, + size_t size, + int pyobjects); +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *m, + PyObject *tuple); +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *m, + PyObject *dict); +static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *m, + PyObject *dict); +static int __pyx_CyFunction_init(PyObject *module); +#if CYTHON_METH_FASTCALL +static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames); +#if CYTHON_BACKPORT_VECTORCALL +#define __Pyx_CyFunction_func_vectorcall(f) (((__pyx_CyFunctionObject*)f)->func_vectorcall) +#else +#define __Pyx_CyFunction_func_vectorcall(f) (((PyCFunctionObject*)f)->vectorcall) +#endif +#endif + +/* CythonFunction.proto */ +static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, + int flags, PyObject* qualname, + PyObject *closure, + PyObject *module, PyObject *globals, + PyObject* code); + +/* CLineInTraceback.proto */ +#ifdef CYTHON_CLINE_IN_TRACEBACK +#define __Pyx_CLineForTraceback(tstate, c_line) (((CYTHON_CLINE_IN_TRACEBACK)) ? c_line : 0) +#else +static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line); +#endif + +/* CodeObjectCache.proto */ +#if !CYTHON_COMPILING_IN_LIMITED_API +typedef struct { + PyCodeObject* code_object; + int code_line; +} __Pyx_CodeObjectCacheEntry; +struct __Pyx_CodeObjectCache { + int count; + int max_count; + __Pyx_CodeObjectCacheEntry* entries; +}; +static struct __Pyx_CodeObjectCache __pyx_code_cache = {0,0,NULL}; +static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line); +static PyCodeObject *__pyx_find_code_object(int code_line); +static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object); +#endif + +/* AddTraceback.proto */ +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename); + +#if PY_MAJOR_VERSION < 3 + static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags); + static void __Pyx_ReleaseBuffer(Py_buffer *view); +#else + #define __Pyx_GetBuffer PyObject_GetBuffer + #define __Pyx_ReleaseBuffer PyBuffer_Release +#endif + + +/* BufferStructDeclare.proto */ +typedef struct { + Py_ssize_t shape, strides, suboffsets; +} __Pyx_Buf_DimInfo; +typedef struct { + size_t refcount; + Py_buffer pybuffer; +} __Pyx_Buffer; +typedef struct { + __Pyx_Buffer *rcbuffer; + char *data; + __Pyx_Buf_DimInfo diminfo[8]; +} __Pyx_LocalBuf_ND; + +/* MemviewSliceIsContig.proto */ +static int __pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim); + +/* OverlappingSlices.proto */ +static int __pyx_slices_overlap(__Pyx_memviewslice *slice1, + __Pyx_memviewslice *slice2, + int ndim, size_t itemsize); + +/* IsLittleEndian.proto */ +static CYTHON_INLINE int __Pyx_Is_Little_Endian(void); + +/* BufferFormatCheck.proto */ +static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts); +static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx, + __Pyx_BufFmt_StackElem* stack, + __Pyx_TypeInfo* type); + +/* TypeInfoCompare.proto */ +static int __pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b); + +/* MemviewSliceValidateAndInit.proto */ +static int __Pyx_ValidateAndInit_memviewslice( + int *axes_specs, + int c_or_f_flag, + int buf_flags, + int ndim, + __Pyx_TypeInfo *dtype, + __Pyx_BufFmt_StackElem stack[], + __Pyx_memviewslice *memviewslice, + PyObject *original_obj); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_int(PyObject *, int writable_flag); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_float(PyObject *, int writable_flag); + +/* ObjectToMemviewSlice.proto */ +static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dc_int(PyObject *, int writable_flag); + +/* RealImag.proto */ +#if CYTHON_CCOMPLEX + #ifdef __cplusplus + #define __Pyx_CREAL(z) ((z).real()) + #define __Pyx_CIMAG(z) ((z).imag()) + #else + #define __Pyx_CREAL(z) (__real__(z)) + #define __Pyx_CIMAG(z) (__imag__(z)) + #endif +#else + #define __Pyx_CREAL(z) ((z).real) + #define __Pyx_CIMAG(z) ((z).imag) +#endif +#if defined(__cplusplus) && CYTHON_CCOMPLEX\ + && (defined(_WIN32) || defined(__clang__) || (defined(__GNUC__) && (__GNUC__ >= 5 || __GNUC__ == 4 && __GNUC_MINOR__ >= 4 )) || __cplusplus >= 201103) + #define __Pyx_SET_CREAL(z,x) ((z).real(x)) + #define __Pyx_SET_CIMAG(z,y) ((z).imag(y)) +#else + #define __Pyx_SET_CREAL(z,x) __Pyx_CREAL(z) = (x) + #define __Pyx_SET_CIMAG(z,y) __Pyx_CIMAG(z) = (y) +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_float(a, b) ((a)==(b)) + #define __Pyx_c_sum_float(a, b) ((a)+(b)) + #define __Pyx_c_diff_float(a, b) ((a)-(b)) + #define __Pyx_c_prod_float(a, b) ((a)*(b)) + #define __Pyx_c_quot_float(a, b) ((a)/(b)) + #define __Pyx_c_neg_float(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_float(z) ((z)==(float)0) + #define __Pyx_c_conj_float(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_float(z) (::std::abs(z)) + #define __Pyx_c_pow_float(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_float(z) ((z)==0) + #define __Pyx_c_conj_float(z) (conjf(z)) + #if 1 + #define __Pyx_c_abs_float(z) (cabsf(z)) + #define __Pyx_c_pow_float(a, b) (cpowf(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex, __pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex); + #if 1 + static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex); + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex, __pyx_t_float_complex); + #endif +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_double(a, b) ((a)==(b)) + #define __Pyx_c_sum_double(a, b) ((a)+(b)) + #define __Pyx_c_diff_double(a, b) ((a)-(b)) + #define __Pyx_c_prod_double(a, b) ((a)*(b)) + #define __Pyx_c_quot_double(a, b) ((a)/(b)) + #define __Pyx_c_neg_double(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_double(z) ((z)==(double)0) + #define __Pyx_c_conj_double(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_double(z) (::std::abs(z)) + #define __Pyx_c_pow_double(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_double(z) ((z)==0) + #define __Pyx_c_conj_double(z) (conj(z)) + #if 1 + #define __Pyx_c_abs_double(z) (cabs(z)) + #define __Pyx_c_pow_double(a, b) (cpow(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex, __pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex); + #if 1 + static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex); + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex, __pyx_t_double_complex); + #endif +#endif + +/* Arithmetic.proto */ +#if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #define __Pyx_c_eq_long__double(a, b) ((a)==(b)) + #define __Pyx_c_sum_long__double(a, b) ((a)+(b)) + #define __Pyx_c_diff_long__double(a, b) ((a)-(b)) + #define __Pyx_c_prod_long__double(a, b) ((a)*(b)) + #define __Pyx_c_quot_long__double(a, b) ((a)/(b)) + #define __Pyx_c_neg_long__double(a) (-(a)) + #ifdef __cplusplus + #define __Pyx_c_is_zero_long__double(z) ((z)==(long double)0) + #define __Pyx_c_conj_long__double(z) (::std::conj(z)) + #if 1 + #define __Pyx_c_abs_long__double(z) (::std::abs(z)) + #define __Pyx_c_pow_long__double(a, b) (::std::pow(a, b)) + #endif + #else + #define __Pyx_c_is_zero_long__double(z) ((z)==0) + #define __Pyx_c_conj_long__double(z) (conjl(z)) + #if 1 + #define __Pyx_c_abs_long__double(z) (cabsl(z)) + #define __Pyx_c_pow_long__double(a, b) (cpowl(a, b)) + #endif + #endif +#else + static CYTHON_INLINE int __Pyx_c_eq_long__double(__pyx_t_long_double_complex, __pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_sum_long__double(__pyx_t_long_double_complex, __pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_diff_long__double(__pyx_t_long_double_complex, __pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_prod_long__double(__pyx_t_long_double_complex, __pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_quot_long__double(__pyx_t_long_double_complex, __pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_neg_long__double(__pyx_t_long_double_complex); + static CYTHON_INLINE int __Pyx_c_is_zero_long__double(__pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_conj_long__double(__pyx_t_long_double_complex); + #if 1 + static CYTHON_INLINE long double __Pyx_c_abs_long__double(__pyx_t_long_double_complex); + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_pow_long__double(__pyx_t_long_double_complex, __pyx_t_long_double_complex); + #endif +#endif + +/* MemviewSliceCopyTemplate.proto */ +static __Pyx_memviewslice +__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, + const char *mode, int ndim, + size_t sizeof_dtype, int contig_flag, + int dtype_is_object); + +/* MemviewSliceInit.proto */ +#define __Pyx_BUF_MAX_NDIMS %(BUF_MAX_NDIMS)d +#define __Pyx_MEMVIEW_DIRECT 1 +#define __Pyx_MEMVIEW_PTR 2 +#define __Pyx_MEMVIEW_FULL 4 +#define __Pyx_MEMVIEW_CONTIG 8 +#define __Pyx_MEMVIEW_STRIDED 16 +#define __Pyx_MEMVIEW_FOLLOW 32 +#define __Pyx_IS_C_CONTIG 1 +#define __Pyx_IS_F_CONTIG 2 +static int __Pyx_init_memviewslice( + struct __pyx_memoryview_obj *memview, + int ndim, + __Pyx_memviewslice *memviewslice, + int memview_is_new_reference); +static CYTHON_INLINE int __pyx_add_acquisition_count_locked( + __pyx_atomic_int_type *acquisition_count, PyThread_type_lock lock); +static CYTHON_INLINE int __pyx_sub_acquisition_count_locked( + __pyx_atomic_int_type *acquisition_count, PyThread_type_lock lock); +#define __pyx_get_slice_count_pointer(memview) (&memview->acquisition_count) +#define __PYX_INC_MEMVIEW(slice, have_gil) __Pyx_INC_MEMVIEW(slice, have_gil, __LINE__) +#define __PYX_XCLEAR_MEMVIEW(slice, have_gil) __Pyx_XCLEAR_MEMVIEW(slice, have_gil, __LINE__) +static CYTHON_INLINE void __Pyx_INC_MEMVIEW(__Pyx_memviewslice *, int, int); +static CYTHON_INLINE void __Pyx_XCLEAR_MEMVIEW(__Pyx_memviewslice *, int, int); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value); + +/* CIntFromPy.proto */ +static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *); + +/* CIntToPy.proto */ +static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value); + +/* CIntFromPy.proto */ +static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *); + +/* CIntFromPy.proto */ +static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *); + +/* FormatTypeName.proto */ +#if CYTHON_COMPILING_IN_LIMITED_API +typedef PyObject *__Pyx_TypeName; +#define __Pyx_FMT_TYPENAME "%U" +static __Pyx_TypeName __Pyx_PyType_GetName(PyTypeObject* tp); +#define __Pyx_DECREF_TypeName(obj) Py_XDECREF(obj) +#else +typedef const char *__Pyx_TypeName; +#define __Pyx_FMT_TYPENAME "%.200s" +#define __Pyx_PyType_GetName(tp) ((tp)->tp_name) +#define __Pyx_DECREF_TypeName(obj) +#endif + +/* CheckBinaryVersion.proto */ +static unsigned long __Pyx_get_runtime_version(void); +static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer); + +/* InitStrings.proto */ +static int __Pyx_InitStrings(__Pyx_StringTabEntry *t); + +/* #### Code section: module_declarations ### */ +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self); /* proto*/ +static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto*/ +static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj); /* proto*/ +static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src); /* proto*/ +static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/ +static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryview__get_base(struct __pyx_memoryview_obj *__pyx_v_self); /* proto*/ +static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp); /* proto*/ +static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value); /* proto*/ +static PyObject *__pyx_memoryviewslice__get_base(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_5dtype_8itemsize_itemsize(PyArray_Descr *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_5dtype_9alignment_alignment(PyArray_Descr *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyObject *__pyx_f_5numpy_5dtype_6fields_fields(PyArray_Descr *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyObject *__pyx_f_5numpy_5dtype_5names_names(PyArray_Descr *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyArray_ArrayDescr *__pyx_f_5numpy_5dtype_8subarray_subarray(PyArray_Descr *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_uint64 __pyx_f_5numpy_5dtype_5flags_flags(PyArray_Descr *__pyx_v_self); /* proto*/ +static CYTHON_INLINE int __pyx_f_5numpy_9broadcast_7numiter_numiter(PyArrayMultiIterObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_9broadcast_4size_size(PyArrayMultiIterObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_9broadcast_5index_index(PyArrayMultiIterObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE int __pyx_f_5numpy_9broadcast_2nd_nd(PyArrayMultiIterObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_9broadcast_10dimensions_dimensions(PyArrayMultiIterObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE void **__pyx_f_5numpy_9broadcast_5iters_iters(PyArrayMultiIterObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self); /* proto*/ +static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self); /* proto*/ + +/* Module declarations from "cython.view" */ + +/* Module declarations from "cython.dataclasses" */ + +/* Module declarations from "cython" */ + +/* Module declarations from "libc.string" */ + +/* Module declarations from "libc.stdio" */ + +/* Module declarations from "__builtin__" */ + +/* Module declarations from "cpython.type" */ + +/* Module declarations from "cpython" */ + +/* Module declarations from "cpython.object" */ + +/* Module declarations from "cpython.ref" */ + +/* Module declarations from "numpy" */ + +/* Module declarations from "numpy" */ + +/* Module declarations from "TTS.tts.utils.monotonic_align.core" */ +static PyObject *__pyx_collections_abc_Sequence = 0; +static PyObject *generic = 0; +static PyObject *strided = 0; +static PyObject *indirect = 0; +static PyObject *contiguous = 0; +static PyObject *indirect_contiguous = 0; +static int __pyx_memoryview_thread_locks_used; +static PyThread_type_lock __pyx_memoryview_thread_locks[8]; +static void __pyx_f_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_each(__Pyx_memviewslice, __Pyx_memviewslice, int, int, float); /*proto*/ +static void __pyx_f_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c(__Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice, __Pyx_memviewslice, int __pyx_skip_dispatch, struct __pyx_opt_args_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c *__pyx_optional_args); /*proto*/ +static int __pyx_array_allocate_buffer(struct __pyx_array_obj *); /*proto*/ +static struct __pyx_array_obj *__pyx_array_new(PyObject *, Py_ssize_t, char *, char *, char *); /*proto*/ +static PyObject *__pyx_memoryview_new(PyObject *, int, int, __Pyx_TypeInfo *); /*proto*/ +static CYTHON_INLINE int __pyx_memoryview_check(PyObject *); /*proto*/ +static PyObject *_unellipsify(PyObject *, int); /*proto*/ +static int assert_direct_dimensions(Py_ssize_t *, int); /*proto*/ +static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *, PyObject *); /*proto*/ +static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int *, Py_ssize_t, Py_ssize_t, Py_ssize_t, int, int, int, int); /*proto*/ +static char *__pyx_pybuffer_index(Py_buffer *, char *, Py_ssize_t, Py_ssize_t); /*proto*/ +static int __pyx_memslice_transpose(__Pyx_memviewslice *); /*proto*/ +static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice, int, PyObject *(*)(char *), int (*)(char *, PyObject *), int); /*proto*/ +static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *); /*proto*/ +static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *, __Pyx_memviewslice *); /*proto*/ +static Py_ssize_t abs_py_ssize_t(Py_ssize_t); /*proto*/ +static char __pyx_get_best_slice_order(__Pyx_memviewslice *, int); /*proto*/ +static void _copy_strided_to_strided(char *, Py_ssize_t *, char *, Py_ssize_t *, Py_ssize_t *, Py_ssize_t *, int, size_t); /*proto*/ +static void copy_strided_to_strided(__Pyx_memviewslice *, __Pyx_memviewslice *, int, size_t); /*proto*/ +static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *, int); /*proto*/ +static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *, Py_ssize_t *, Py_ssize_t, int, char); /*proto*/ +static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *, __Pyx_memviewslice *, char, int); /*proto*/ +static int __pyx_memoryview_err_extents(int, Py_ssize_t, Py_ssize_t); /*proto*/ +static int __pyx_memoryview_err_dim(PyObject *, PyObject *, int); /*proto*/ +static int __pyx_memoryview_err(PyObject *, PyObject *); /*proto*/ +static int __pyx_memoryview_err_no_memory(void); /*proto*/ +static int __pyx_memoryview_copy_contents(__Pyx_memviewslice, __Pyx_memviewslice, int, int, int); /*proto*/ +static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *, int, int); /*proto*/ +static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *, int, int, int); /*proto*/ +static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/ +static void __pyx_memoryview_refcount_objects_in_slice(char *, Py_ssize_t *, Py_ssize_t *, int, int); /*proto*/ +static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *, int, size_t, void *, int); /*proto*/ +static void __pyx_memoryview__slice_assign_scalar(char *, Py_ssize_t *, Py_ssize_t *, int, size_t, void *); /*proto*/ +static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *, PyObject *); /*proto*/ +/* #### Code section: typeinfo ### */ +static __Pyx_TypeInfo __Pyx_TypeInfo_int = { "int", NULL, sizeof(int), { 0 }, 0, __PYX_IS_UNSIGNED(int) ? 'U' : 'I', __PYX_IS_UNSIGNED(int), 0 }; +static __Pyx_TypeInfo __Pyx_TypeInfo_float = { "float", NULL, sizeof(float), { 0 }, 0, 'R', 0, 0 }; +/* #### Code section: before_global_var ### */ +#define __Pyx_MODULE_NAME "TTS.tts.utils.monotonic_align.core" +extern int __pyx_module_is_main_TTS__tts__utils__monotonic_align__core; +int __pyx_module_is_main_TTS__tts__utils__monotonic_align__core = 0; + +/* Implementation of "TTS.tts.utils.monotonic_align.core" */ +/* #### Code section: global_var ### */ +static PyObject *__pyx_builtin_range; +static PyObject *__pyx_builtin___import__; +static PyObject *__pyx_builtin_ValueError; +static PyObject *__pyx_builtin_MemoryError; +static PyObject *__pyx_builtin_enumerate; +static PyObject *__pyx_builtin_TypeError; +static PyObject *__pyx_builtin_AssertionError; +static PyObject *__pyx_builtin_Ellipsis; +static PyObject *__pyx_builtin_id; +static PyObject *__pyx_builtin_IndexError; +static PyObject *__pyx_builtin_ImportError; +/* #### Code section: string_decls ### */ +static const char __pyx_k_[] = ": "; +static const char __pyx_k_O[] = "O"; +static const char __pyx_k_c[] = "c"; +static const char __pyx_k__2[] = "."; +static const char __pyx_k__3[] = "*"; +static const char __pyx_k__6[] = "'"; +static const char __pyx_k__7[] = ")"; +static const char __pyx_k_gc[] = "gc"; +static const char __pyx_k_id[] = "id"; +static const char __pyx_k_np[] = "np"; +static const char __pyx_k__25[] = "?"; +static const char __pyx_k_abc[] = "abc"; +static const char __pyx_k_and[] = " and "; +static const char __pyx_k_got[] = " (got "; +static const char __pyx_k_new[] = "__new__"; +static const char __pyx_k_obj[] = "obj"; +static const char __pyx_k_sys[] = "sys"; +static const char __pyx_k_base[] = "base"; +static const char __pyx_k_dict[] = "__dict__"; +static const char __pyx_k_main[] = "__main__"; +static const char __pyx_k_mode[] = "mode"; +static const char __pyx_k_name[] = "name"; +static const char __pyx_k_ndim[] = "ndim"; +static const char __pyx_k_pack[] = "pack"; +static const char __pyx_k_size[] = "size"; +static const char __pyx_k_spec[] = "__spec__"; +static const char __pyx_k_step[] = "step"; +static const char __pyx_k_stop[] = "stop"; +static const char __pyx_k_t_xs[] = "t_xs"; +static const char __pyx_k_t_ys[] = "t_ys"; +static const char __pyx_k_test[] = "__test__"; +static const char __pyx_k_ASCII[] = "ASCII"; +static const char __pyx_k_class[] = "__class__"; +static const char __pyx_k_count[] = "count"; +static const char __pyx_k_error[] = "error"; +static const char __pyx_k_flags[] = "flags"; +static const char __pyx_k_index[] = "index"; +static const char __pyx_k_numpy[] = "numpy"; +static const char __pyx_k_paths[] = "paths"; +static const char __pyx_k_range[] = "range"; +static const char __pyx_k_shape[] = "shape"; +static const char __pyx_k_start[] = "start"; +static const char __pyx_k_enable[] = "enable"; +static const char __pyx_k_encode[] = "encode"; +static const char __pyx_k_format[] = "format"; +static const char __pyx_k_import[] = "__import__"; +static const char __pyx_k_name_2[] = "__name__"; +static const char __pyx_k_pickle[] = "pickle"; +static const char __pyx_k_reduce[] = "__reduce__"; +static const char __pyx_k_struct[] = "struct"; +static const char __pyx_k_unpack[] = "unpack"; +static const char __pyx_k_update[] = "update"; +static const char __pyx_k_values[] = "values"; +static const char __pyx_k_disable[] = "disable"; +static const char __pyx_k_fortran[] = "fortran"; +static const char __pyx_k_memview[] = "memview"; +static const char __pyx_k_Ellipsis[] = "Ellipsis"; +static const char __pyx_k_Sequence[] = "Sequence"; +static const char __pyx_k_getstate[] = "__getstate__"; +static const char __pyx_k_itemsize[] = "itemsize"; +static const char __pyx_k_pyx_type[] = "__pyx_type"; +static const char __pyx_k_register[] = "register"; +static const char __pyx_k_setstate[] = "__setstate__"; +static const char __pyx_k_TypeError[] = "TypeError"; +static const char __pyx_k_enumerate[] = "enumerate"; +static const char __pyx_k_isenabled[] = "isenabled"; +static const char __pyx_k_pyx_state[] = "__pyx_state"; +static const char __pyx_k_reduce_ex[] = "__reduce_ex__"; +static const char __pyx_k_IndexError[] = "IndexError"; +static const char __pyx_k_ValueError[] = "ValueError"; +static const char __pyx_k_pyx_result[] = "__pyx_result"; +static const char __pyx_k_pyx_vtable[] = "__pyx_vtable__"; +static const char __pyx_k_ImportError[] = "ImportError"; +static const char __pyx_k_MemoryError[] = "MemoryError"; +static const char __pyx_k_PickleError[] = "PickleError"; +static const char __pyx_k_collections[] = "collections"; +static const char __pyx_k_max_neg_val[] = "max_neg_val"; +static const char __pyx_k_initializing[] = "_initializing"; +static const char __pyx_k_is_coroutine[] = "_is_coroutine"; +static const char __pyx_k_pyx_checksum[] = "__pyx_checksum"; +static const char __pyx_k_stringsource[] = ""; +static const char __pyx_k_version_info[] = "version_info"; +static const char __pyx_k_class_getitem[] = "__class_getitem__"; +static const char __pyx_k_reduce_cython[] = "__reduce_cython__"; +static const char __pyx_k_AssertionError[] = "AssertionError"; +static const char __pyx_k_maximum_path_c[] = "maximum_path_c"; +static const char __pyx_k_View_MemoryView[] = "View.MemoryView"; +static const char __pyx_k_allocate_buffer[] = "allocate_buffer"; +static const char __pyx_k_collections_abc[] = "collections.abc"; +static const char __pyx_k_dtype_is_object[] = "dtype_is_object"; +static const char __pyx_k_pyx_PickleError[] = "__pyx_PickleError"; +static const char __pyx_k_setstate_cython[] = "__setstate_cython__"; +static const char __pyx_k_pyx_unpickle_Enum[] = "__pyx_unpickle_Enum"; +static const char __pyx_k_asyncio_coroutines[] = "asyncio.coroutines"; +static const char __pyx_k_cline_in_traceback[] = "cline_in_traceback"; +static const char __pyx_k_strided_and_direct[] = ""; +static const char __pyx_k_strided_and_indirect[] = ""; +static const char __pyx_k_Invalid_shape_in_axis[] = "Invalid shape in axis "; +static const char __pyx_k_contiguous_and_direct[] = ""; +static const char __pyx_k_Cannot_index_with_type[] = "Cannot index with type '"; +static const char __pyx_k_MemoryView_of_r_object[] = ""; +static const char __pyx_k_MemoryView_of_r_at_0x_x[] = ""; +static const char __pyx_k_contiguous_and_indirect[] = ""; +static const char __pyx_k_Dimension_d_is_not_direct[] = "Dimension %d is not direct"; +static const char __pyx_k_Index_out_of_bounds_axis_d[] = "Index out of bounds (axis %d)"; +static const char __pyx_k_Step_may_not_be_zero_axis_d[] = "Step may not be zero (axis %d)"; +static const char __pyx_k_itemsize_0_for_cython_array[] = "itemsize <= 0 for cython.array"; +static const char __pyx_k_unable_to_allocate_array_data[] = "unable to allocate array data."; +static const char __pyx_k_strided_and_direct_or_indirect[] = ""; +static const char __pyx_k_All_dimensions_preceding_dimensi[] = "All dimensions preceding dimension %d must be indexed and not sliced"; +static const char __pyx_k_Buffer_view_does_not_expose_stri[] = "Buffer view does not expose strides"; +static const char __pyx_k_Can_only_create_a_buffer_that_is[] = "Can only create a buffer that is contiguous in memory."; +static const char __pyx_k_Cannot_assign_to_read_only_memor[] = "Cannot assign to read-only memoryview"; +static const char __pyx_k_Cannot_create_writable_memory_vi[] = "Cannot create writable memory view from read-only memoryview"; +static const char __pyx_k_Cannot_transpose_memoryview_with[] = "Cannot transpose memoryview with indirect dimensions"; +static const char __pyx_k_Empty_shape_tuple_for_cython_arr[] = "Empty shape tuple for cython.array"; +static const char __pyx_k_Incompatible_checksums_0x_x_vs_0[] = "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))"; +static const char __pyx_k_Indirect_dimensions_not_supporte[] = "Indirect dimensions not supported"; +static const char __pyx_k_Invalid_mode_expected_c_or_fortr[] = "Invalid mode, expected 'c' or 'fortran', got "; +static const char __pyx_k_Out_of_bounds_on_buffer_access_a[] = "Out of bounds on buffer access (axis "; +static const char __pyx_k_TTS_tts_utils_monotonic_align_co[] = "TTS/tts/utils/monotonic_align/core.pyx"; +static const char __pyx_k_Unable_to_convert_item_to_object[] = "Unable to convert item to object"; +static const char __pyx_k_got_differing_extents_in_dimensi[] = "got differing extents in dimension "; +static const char __pyx_k_no_default___reduce___due_to_non[] = "no default __reduce__ due to non-trivial __cinit__"; +static const char __pyx_k_numpy__core_multiarray_failed_to[] = "numpy._core.multiarray failed to import"; +static const char __pyx_k_numpy__core_umath_failed_to_impo[] = "numpy._core.umath failed to import"; +static const char __pyx_k_unable_to_allocate_shape_and_str[] = "unable to allocate shape and strides."; +static const char __pyx_k_TTS_tts_utils_monotonic_align_co_2[] = "TTS.tts.utils.monotonic_align.core"; +/* #### Code section: decls ### */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer); /* proto */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */ +static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr); /* proto */ +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item); /* proto */ +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /* proto */ +static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name); /* proto */ +static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object); /* proto */ +static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /* proto */ +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self); /* proto */ +static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state); /* proto */ +static PyObject *__pyx_pf_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c(CYTHON_UNUSED PyObject *__pyx_self, __Pyx_memviewslice __pyx_v_paths, __Pyx_memviewslice __pyx_v_values, __Pyx_memviewslice __pyx_v_t_xs, __Pyx_memviewslice __pyx_v_t_ys, float __pyx_v_max_neg_val); /* proto */ +static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k); /*proto*/ +/* #### Code section: late_includes ### */ +/* #### Code section: module_state ### */ +typedef struct { + PyObject *__pyx_d; + PyObject *__pyx_b; + PyObject *__pyx_cython_runtime; + PyObject *__pyx_empty_tuple; + PyObject *__pyx_empty_bytes; + PyObject *__pyx_empty_unicode; + #ifdef __Pyx_CyFunction_USED + PyTypeObject *__pyx_CyFunctionType; + #endif + #ifdef __Pyx_FusedFunction_USED + PyTypeObject *__pyx_FusedFunctionType; + #endif + #ifdef __Pyx_Generator_USED + PyTypeObject *__pyx_GeneratorType; + #endif + #ifdef __Pyx_IterableCoroutine_USED + PyTypeObject *__pyx_IterableCoroutineType; + #endif + #ifdef __Pyx_Coroutine_USED + PyTypeObject *__pyx_CoroutineAwaitType; + #endif + #ifdef __Pyx_Coroutine_USED + PyTypeObject *__pyx_CoroutineType; + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + PyTypeObject *__pyx_ptype_7cpython_4type_type; + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + #if CYTHON_USE_MODULE_STATE + #endif + PyTypeObject *__pyx_ptype_5numpy_dtype; + PyTypeObject *__pyx_ptype_5numpy_flatiter; + PyTypeObject *__pyx_ptype_5numpy_broadcast; + PyTypeObject *__pyx_ptype_5numpy_ndarray; + PyTypeObject *__pyx_ptype_5numpy_generic; + PyTypeObject *__pyx_ptype_5numpy_number; + PyTypeObject *__pyx_ptype_5numpy_integer; + PyTypeObject *__pyx_ptype_5numpy_signedinteger; + PyTypeObject *__pyx_ptype_5numpy_unsignedinteger; + PyTypeObject *__pyx_ptype_5numpy_inexact; + PyTypeObject *__pyx_ptype_5numpy_floating; + PyTypeObject *__pyx_ptype_5numpy_complexfloating; + PyTypeObject *__pyx_ptype_5numpy_flexible; + PyTypeObject *__pyx_ptype_5numpy_character; + PyTypeObject *__pyx_ptype_5numpy_ufunc; + #if CYTHON_USE_MODULE_STATE + PyObject *__pyx_type___pyx_array; + PyObject *__pyx_type___pyx_MemviewEnum; + PyObject *__pyx_type___pyx_memoryview; + PyObject *__pyx_type___pyx_memoryviewslice; + #endif + PyTypeObject *__pyx_array_type; + PyTypeObject *__pyx_MemviewEnum_type; + PyTypeObject *__pyx_memoryview_type; + PyTypeObject *__pyx_memoryviewslice_type; + PyObject *__pyx_kp_u_; + PyObject *__pyx_n_s_ASCII; + PyObject *__pyx_kp_s_All_dimensions_preceding_dimensi; + PyObject *__pyx_n_s_AssertionError; + PyObject *__pyx_kp_s_Buffer_view_does_not_expose_stri; + PyObject *__pyx_kp_s_Can_only_create_a_buffer_that_is; + PyObject *__pyx_kp_s_Cannot_assign_to_read_only_memor; + PyObject *__pyx_kp_s_Cannot_create_writable_memory_vi; + PyObject *__pyx_kp_u_Cannot_index_with_type; + PyObject *__pyx_kp_s_Cannot_transpose_memoryview_with; + PyObject *__pyx_kp_s_Dimension_d_is_not_direct; + PyObject *__pyx_n_s_Ellipsis; + PyObject *__pyx_kp_s_Empty_shape_tuple_for_cython_arr; + PyObject *__pyx_n_s_ImportError; + PyObject *__pyx_kp_s_Incompatible_checksums_0x_x_vs_0; + PyObject *__pyx_n_s_IndexError; + PyObject *__pyx_kp_s_Index_out_of_bounds_axis_d; + PyObject *__pyx_kp_s_Indirect_dimensions_not_supporte; + PyObject *__pyx_kp_u_Invalid_mode_expected_c_or_fortr; + PyObject *__pyx_kp_u_Invalid_shape_in_axis; + PyObject *__pyx_n_s_MemoryError; + PyObject *__pyx_kp_s_MemoryView_of_r_at_0x_x; + PyObject *__pyx_kp_s_MemoryView_of_r_object; + PyObject *__pyx_n_b_O; + PyObject *__pyx_kp_u_Out_of_bounds_on_buffer_access_a; + PyObject *__pyx_n_s_PickleError; + PyObject *__pyx_n_s_Sequence; + PyObject *__pyx_kp_s_Step_may_not_be_zero_axis_d; + PyObject *__pyx_kp_s_TTS_tts_utils_monotonic_align_co; + PyObject *__pyx_n_s_TTS_tts_utils_monotonic_align_co_2; + PyObject *__pyx_n_s_TypeError; + PyObject *__pyx_kp_s_Unable_to_convert_item_to_object; + PyObject *__pyx_n_s_ValueError; + PyObject *__pyx_n_s_View_MemoryView; + PyObject *__pyx_kp_u__2; + PyObject *__pyx_n_s__25; + PyObject *__pyx_n_s__3; + PyObject *__pyx_kp_u__6; + PyObject *__pyx_kp_u__7; + PyObject *__pyx_n_s_abc; + PyObject *__pyx_n_s_allocate_buffer; + PyObject *__pyx_kp_u_and; + PyObject *__pyx_n_s_asyncio_coroutines; + PyObject *__pyx_n_s_base; + PyObject *__pyx_n_s_c; + PyObject *__pyx_n_u_c; + PyObject *__pyx_n_s_class; + PyObject *__pyx_n_s_class_getitem; + PyObject *__pyx_n_s_cline_in_traceback; + PyObject *__pyx_n_s_collections; + PyObject *__pyx_kp_s_collections_abc; + PyObject *__pyx_kp_s_contiguous_and_direct; + PyObject *__pyx_kp_s_contiguous_and_indirect; + PyObject *__pyx_n_s_count; + PyObject *__pyx_n_s_dict; + PyObject *__pyx_kp_u_disable; + PyObject *__pyx_n_s_dtype_is_object; + PyObject *__pyx_kp_u_enable; + PyObject *__pyx_n_s_encode; + PyObject *__pyx_n_s_enumerate; + PyObject *__pyx_n_s_error; + PyObject *__pyx_n_s_flags; + PyObject *__pyx_n_s_format; + PyObject *__pyx_n_s_fortran; + PyObject *__pyx_n_u_fortran; + PyObject *__pyx_kp_u_gc; + PyObject *__pyx_n_s_getstate; + PyObject *__pyx_kp_u_got; + PyObject *__pyx_kp_u_got_differing_extents_in_dimensi; + PyObject *__pyx_n_s_id; + PyObject *__pyx_n_s_import; + PyObject *__pyx_n_s_index; + PyObject *__pyx_n_s_initializing; + PyObject *__pyx_n_s_is_coroutine; + PyObject *__pyx_kp_u_isenabled; + PyObject *__pyx_n_s_itemsize; + PyObject *__pyx_kp_s_itemsize_0_for_cython_array; + PyObject *__pyx_n_s_main; + PyObject *__pyx_n_s_max_neg_val; + PyObject *__pyx_n_s_maximum_path_c; + PyObject *__pyx_n_s_memview; + PyObject *__pyx_n_s_mode; + PyObject *__pyx_n_s_name; + PyObject *__pyx_n_s_name_2; + PyObject *__pyx_n_s_ndim; + PyObject *__pyx_n_s_new; + PyObject *__pyx_kp_s_no_default___reduce___due_to_non; + PyObject *__pyx_n_s_np; + PyObject *__pyx_n_s_numpy; + PyObject *__pyx_kp_u_numpy__core_multiarray_failed_to; + PyObject *__pyx_kp_u_numpy__core_umath_failed_to_impo; + PyObject *__pyx_n_s_obj; + PyObject *__pyx_n_s_pack; + PyObject *__pyx_n_s_paths; + PyObject *__pyx_n_s_pickle; + PyObject *__pyx_n_s_pyx_PickleError; + PyObject *__pyx_n_s_pyx_checksum; + PyObject *__pyx_n_s_pyx_result; + PyObject *__pyx_n_s_pyx_state; + PyObject *__pyx_n_s_pyx_type; + PyObject *__pyx_n_s_pyx_unpickle_Enum; + PyObject *__pyx_n_s_pyx_vtable; + PyObject *__pyx_n_s_range; + PyObject *__pyx_n_s_reduce; + PyObject *__pyx_n_s_reduce_cython; + PyObject *__pyx_n_s_reduce_ex; + PyObject *__pyx_n_s_register; + PyObject *__pyx_n_s_setstate; + PyObject *__pyx_n_s_setstate_cython; + PyObject *__pyx_n_s_shape; + PyObject *__pyx_n_s_size; + PyObject *__pyx_n_s_spec; + PyObject *__pyx_n_s_start; + PyObject *__pyx_n_s_step; + PyObject *__pyx_n_s_stop; + PyObject *__pyx_kp_s_strided_and_direct; + PyObject *__pyx_kp_s_strided_and_direct_or_indirect; + PyObject *__pyx_kp_s_strided_and_indirect; + PyObject *__pyx_kp_s_stringsource; + PyObject *__pyx_n_s_struct; + PyObject *__pyx_n_s_sys; + PyObject *__pyx_n_s_t_xs; + PyObject *__pyx_n_s_t_ys; + PyObject *__pyx_n_s_test; + PyObject *__pyx_kp_s_unable_to_allocate_array_data; + PyObject *__pyx_kp_s_unable_to_allocate_shape_and_str; + PyObject *__pyx_n_s_unpack; + PyObject *__pyx_n_s_update; + PyObject *__pyx_n_s_values; + PyObject *__pyx_n_s_version_info; + PyObject *__pyx_int_0; + PyObject *__pyx_int_1; + PyObject *__pyx_int_3; + PyObject *__pyx_int_112105877; + PyObject *__pyx_int_136983863; + PyObject *__pyx_int_184977713; + PyObject *__pyx_int_neg_1; + float __pyx_k__11; + PyObject *__pyx_slice__5; + PyObject *__pyx_tuple__4; + PyObject *__pyx_tuple__8; + PyObject *__pyx_tuple__9; + PyObject *__pyx_tuple__10; + PyObject *__pyx_tuple__12; + PyObject *__pyx_tuple__13; + PyObject *__pyx_tuple__14; + PyObject *__pyx_tuple__15; + PyObject *__pyx_tuple__16; + PyObject *__pyx_tuple__17; + PyObject *__pyx_tuple__18; + PyObject *__pyx_tuple__19; + PyObject *__pyx_tuple__20; + PyObject *__pyx_tuple__21; + PyObject *__pyx_tuple__23; + PyObject *__pyx_codeobj__22; + PyObject *__pyx_codeobj__24; +} __pyx_mstate; + +#if CYTHON_USE_MODULE_STATE +#ifdef __cplusplus +namespace { + extern struct PyModuleDef __pyx_moduledef; +} /* anonymous namespace */ +#else +static struct PyModuleDef __pyx_moduledef; +#endif + +#define __pyx_mstate(o) ((__pyx_mstate *)__Pyx_PyModule_GetState(o)) + +#define __pyx_mstate_global (__pyx_mstate(PyState_FindModule(&__pyx_moduledef))) + +#define __pyx_m (PyState_FindModule(&__pyx_moduledef)) +#else +static __pyx_mstate __pyx_mstate_global_static = +#ifdef __cplusplus + {}; +#else + {0}; +#endif +static __pyx_mstate *__pyx_mstate_global = &__pyx_mstate_global_static; +#endif +/* #### Code section: module_state_clear ### */ +#if CYTHON_USE_MODULE_STATE +static int __pyx_m_clear(PyObject *m) { + __pyx_mstate *clear_module_state = __pyx_mstate(m); + if (!clear_module_state) return 0; + Py_CLEAR(clear_module_state->__pyx_d); + Py_CLEAR(clear_module_state->__pyx_b); + Py_CLEAR(clear_module_state->__pyx_cython_runtime); + Py_CLEAR(clear_module_state->__pyx_empty_tuple); + Py_CLEAR(clear_module_state->__pyx_empty_bytes); + Py_CLEAR(clear_module_state->__pyx_empty_unicode); + #ifdef __Pyx_CyFunction_USED + Py_CLEAR(clear_module_state->__pyx_CyFunctionType); + #endif + #ifdef __Pyx_FusedFunction_USED + Py_CLEAR(clear_module_state->__pyx_FusedFunctionType); + #endif + Py_CLEAR(clear_module_state->__pyx_ptype_7cpython_4type_type); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_dtype); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flatiter); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_broadcast); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ndarray); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_generic); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_number); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_integer); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_signedinteger); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_unsignedinteger); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_inexact); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_floating); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_complexfloating); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_flexible); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_character); + Py_CLEAR(clear_module_state->__pyx_ptype_5numpy_ufunc); + Py_CLEAR(clear_module_state->__pyx_array_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_array); + Py_CLEAR(clear_module_state->__pyx_MemviewEnum_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_MemviewEnum); + Py_CLEAR(clear_module_state->__pyx_memoryview_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_memoryview); + Py_CLEAR(clear_module_state->__pyx_memoryviewslice_type); + Py_CLEAR(clear_module_state->__pyx_type___pyx_memoryviewslice); + Py_CLEAR(clear_module_state->__pyx_kp_u_); + Py_CLEAR(clear_module_state->__pyx_n_s_ASCII); + Py_CLEAR(clear_module_state->__pyx_kp_s_All_dimensions_preceding_dimensi); + Py_CLEAR(clear_module_state->__pyx_n_s_AssertionError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Buffer_view_does_not_expose_stri); + Py_CLEAR(clear_module_state->__pyx_kp_s_Can_only_create_a_buffer_that_is); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_assign_to_read_only_memor); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_create_writable_memory_vi); + Py_CLEAR(clear_module_state->__pyx_kp_u_Cannot_index_with_type); + Py_CLEAR(clear_module_state->__pyx_kp_s_Cannot_transpose_memoryview_with); + Py_CLEAR(clear_module_state->__pyx_kp_s_Dimension_d_is_not_direct); + Py_CLEAR(clear_module_state->__pyx_n_s_Ellipsis); + Py_CLEAR(clear_module_state->__pyx_kp_s_Empty_shape_tuple_for_cython_arr); + Py_CLEAR(clear_module_state->__pyx_n_s_ImportError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0); + Py_CLEAR(clear_module_state->__pyx_n_s_IndexError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Index_out_of_bounds_axis_d); + Py_CLEAR(clear_module_state->__pyx_kp_s_Indirect_dimensions_not_supporte); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_mode_expected_c_or_fortr); + Py_CLEAR(clear_module_state->__pyx_kp_u_Invalid_shape_in_axis); + Py_CLEAR(clear_module_state->__pyx_n_s_MemoryError); + Py_CLEAR(clear_module_state->__pyx_kp_s_MemoryView_of_r_at_0x_x); + Py_CLEAR(clear_module_state->__pyx_kp_s_MemoryView_of_r_object); + Py_CLEAR(clear_module_state->__pyx_n_b_O); + Py_CLEAR(clear_module_state->__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + Py_CLEAR(clear_module_state->__pyx_n_s_PickleError); + Py_CLEAR(clear_module_state->__pyx_n_s_Sequence); + Py_CLEAR(clear_module_state->__pyx_kp_s_Step_may_not_be_zero_axis_d); + Py_CLEAR(clear_module_state->__pyx_kp_s_TTS_tts_utils_monotonic_align_co); + Py_CLEAR(clear_module_state->__pyx_n_s_TTS_tts_utils_monotonic_align_co_2); + Py_CLEAR(clear_module_state->__pyx_n_s_TypeError); + Py_CLEAR(clear_module_state->__pyx_kp_s_Unable_to_convert_item_to_object); + Py_CLEAR(clear_module_state->__pyx_n_s_ValueError); + Py_CLEAR(clear_module_state->__pyx_n_s_View_MemoryView); + Py_CLEAR(clear_module_state->__pyx_kp_u__2); + Py_CLEAR(clear_module_state->__pyx_n_s__25); + Py_CLEAR(clear_module_state->__pyx_n_s__3); + Py_CLEAR(clear_module_state->__pyx_kp_u__6); + Py_CLEAR(clear_module_state->__pyx_kp_u__7); + Py_CLEAR(clear_module_state->__pyx_n_s_abc); + Py_CLEAR(clear_module_state->__pyx_n_s_allocate_buffer); + Py_CLEAR(clear_module_state->__pyx_kp_u_and); + Py_CLEAR(clear_module_state->__pyx_n_s_asyncio_coroutines); + Py_CLEAR(clear_module_state->__pyx_n_s_base); + Py_CLEAR(clear_module_state->__pyx_n_s_c); + Py_CLEAR(clear_module_state->__pyx_n_u_c); + Py_CLEAR(clear_module_state->__pyx_n_s_class); + Py_CLEAR(clear_module_state->__pyx_n_s_class_getitem); + Py_CLEAR(clear_module_state->__pyx_n_s_cline_in_traceback); + Py_CLEAR(clear_module_state->__pyx_n_s_collections); + Py_CLEAR(clear_module_state->__pyx_kp_s_collections_abc); + Py_CLEAR(clear_module_state->__pyx_kp_s_contiguous_and_direct); + Py_CLEAR(clear_module_state->__pyx_kp_s_contiguous_and_indirect); + Py_CLEAR(clear_module_state->__pyx_n_s_count); + Py_CLEAR(clear_module_state->__pyx_n_s_dict); + Py_CLEAR(clear_module_state->__pyx_kp_u_disable); + Py_CLEAR(clear_module_state->__pyx_n_s_dtype_is_object); + Py_CLEAR(clear_module_state->__pyx_kp_u_enable); + Py_CLEAR(clear_module_state->__pyx_n_s_encode); + Py_CLEAR(clear_module_state->__pyx_n_s_enumerate); + Py_CLEAR(clear_module_state->__pyx_n_s_error); + Py_CLEAR(clear_module_state->__pyx_n_s_flags); + Py_CLEAR(clear_module_state->__pyx_n_s_format); + Py_CLEAR(clear_module_state->__pyx_n_s_fortran); + Py_CLEAR(clear_module_state->__pyx_n_u_fortran); + Py_CLEAR(clear_module_state->__pyx_kp_u_gc); + Py_CLEAR(clear_module_state->__pyx_n_s_getstate); + Py_CLEAR(clear_module_state->__pyx_kp_u_got); + Py_CLEAR(clear_module_state->__pyx_kp_u_got_differing_extents_in_dimensi); + Py_CLEAR(clear_module_state->__pyx_n_s_id); + Py_CLEAR(clear_module_state->__pyx_n_s_import); + Py_CLEAR(clear_module_state->__pyx_n_s_index); + Py_CLEAR(clear_module_state->__pyx_n_s_initializing); + Py_CLEAR(clear_module_state->__pyx_n_s_is_coroutine); + Py_CLEAR(clear_module_state->__pyx_kp_u_isenabled); + Py_CLEAR(clear_module_state->__pyx_n_s_itemsize); + Py_CLEAR(clear_module_state->__pyx_kp_s_itemsize_0_for_cython_array); + Py_CLEAR(clear_module_state->__pyx_n_s_main); + Py_CLEAR(clear_module_state->__pyx_n_s_max_neg_val); + Py_CLEAR(clear_module_state->__pyx_n_s_maximum_path_c); + Py_CLEAR(clear_module_state->__pyx_n_s_memview); + Py_CLEAR(clear_module_state->__pyx_n_s_mode); + Py_CLEAR(clear_module_state->__pyx_n_s_name); + Py_CLEAR(clear_module_state->__pyx_n_s_name_2); + Py_CLEAR(clear_module_state->__pyx_n_s_ndim); + Py_CLEAR(clear_module_state->__pyx_n_s_new); + Py_CLEAR(clear_module_state->__pyx_kp_s_no_default___reduce___due_to_non); + Py_CLEAR(clear_module_state->__pyx_n_s_np); + Py_CLEAR(clear_module_state->__pyx_n_s_numpy); + Py_CLEAR(clear_module_state->__pyx_kp_u_numpy__core_multiarray_failed_to); + Py_CLEAR(clear_module_state->__pyx_kp_u_numpy__core_umath_failed_to_impo); + Py_CLEAR(clear_module_state->__pyx_n_s_obj); + Py_CLEAR(clear_module_state->__pyx_n_s_pack); + Py_CLEAR(clear_module_state->__pyx_n_s_paths); + Py_CLEAR(clear_module_state->__pyx_n_s_pickle); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_PickleError); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_checksum); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_result); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_state); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_type); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_unpickle_Enum); + Py_CLEAR(clear_module_state->__pyx_n_s_pyx_vtable); + Py_CLEAR(clear_module_state->__pyx_n_s_range); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_reduce_ex); + Py_CLEAR(clear_module_state->__pyx_n_s_register); + Py_CLEAR(clear_module_state->__pyx_n_s_setstate); + Py_CLEAR(clear_module_state->__pyx_n_s_setstate_cython); + Py_CLEAR(clear_module_state->__pyx_n_s_shape); + Py_CLEAR(clear_module_state->__pyx_n_s_size); + Py_CLEAR(clear_module_state->__pyx_n_s_spec); + Py_CLEAR(clear_module_state->__pyx_n_s_start); + Py_CLEAR(clear_module_state->__pyx_n_s_step); + Py_CLEAR(clear_module_state->__pyx_n_s_stop); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_direct); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_direct_or_indirect); + Py_CLEAR(clear_module_state->__pyx_kp_s_strided_and_indirect); + Py_CLEAR(clear_module_state->__pyx_kp_s_stringsource); + Py_CLEAR(clear_module_state->__pyx_n_s_struct); + Py_CLEAR(clear_module_state->__pyx_n_s_sys); + Py_CLEAR(clear_module_state->__pyx_n_s_t_xs); + Py_CLEAR(clear_module_state->__pyx_n_s_t_ys); + Py_CLEAR(clear_module_state->__pyx_n_s_test); + Py_CLEAR(clear_module_state->__pyx_kp_s_unable_to_allocate_array_data); + Py_CLEAR(clear_module_state->__pyx_kp_s_unable_to_allocate_shape_and_str); + Py_CLEAR(clear_module_state->__pyx_n_s_unpack); + Py_CLEAR(clear_module_state->__pyx_n_s_update); + Py_CLEAR(clear_module_state->__pyx_n_s_values); + Py_CLEAR(clear_module_state->__pyx_n_s_version_info); + Py_CLEAR(clear_module_state->__pyx_int_0); + Py_CLEAR(clear_module_state->__pyx_int_1); + Py_CLEAR(clear_module_state->__pyx_int_3); + Py_CLEAR(clear_module_state->__pyx_int_112105877); + Py_CLEAR(clear_module_state->__pyx_int_136983863); + Py_CLEAR(clear_module_state->__pyx_int_184977713); + Py_CLEAR(clear_module_state->__pyx_int_neg_1); + Py_CLEAR(clear_module_state->__pyx_slice__5); + Py_CLEAR(clear_module_state->__pyx_tuple__4); + Py_CLEAR(clear_module_state->__pyx_tuple__8); + Py_CLEAR(clear_module_state->__pyx_tuple__9); + Py_CLEAR(clear_module_state->__pyx_tuple__10); + Py_CLEAR(clear_module_state->__pyx_tuple__12); + Py_CLEAR(clear_module_state->__pyx_tuple__13); + Py_CLEAR(clear_module_state->__pyx_tuple__14); + Py_CLEAR(clear_module_state->__pyx_tuple__15); + Py_CLEAR(clear_module_state->__pyx_tuple__16); + Py_CLEAR(clear_module_state->__pyx_tuple__17); + Py_CLEAR(clear_module_state->__pyx_tuple__18); + Py_CLEAR(clear_module_state->__pyx_tuple__19); + Py_CLEAR(clear_module_state->__pyx_tuple__20); + Py_CLEAR(clear_module_state->__pyx_tuple__21); + Py_CLEAR(clear_module_state->__pyx_tuple__23); + Py_CLEAR(clear_module_state->__pyx_codeobj__22); + Py_CLEAR(clear_module_state->__pyx_codeobj__24); + return 0; +} +#endif +/* #### Code section: module_state_traverse ### */ +#if CYTHON_USE_MODULE_STATE +static int __pyx_m_traverse(PyObject *m, visitproc visit, void *arg) { + __pyx_mstate *traverse_module_state = __pyx_mstate(m); + if (!traverse_module_state) return 0; + Py_VISIT(traverse_module_state->__pyx_d); + Py_VISIT(traverse_module_state->__pyx_b); + Py_VISIT(traverse_module_state->__pyx_cython_runtime); + Py_VISIT(traverse_module_state->__pyx_empty_tuple); + Py_VISIT(traverse_module_state->__pyx_empty_bytes); + Py_VISIT(traverse_module_state->__pyx_empty_unicode); + #ifdef __Pyx_CyFunction_USED + Py_VISIT(traverse_module_state->__pyx_CyFunctionType); + #endif + #ifdef __Pyx_FusedFunction_USED + Py_VISIT(traverse_module_state->__pyx_FusedFunctionType); + #endif + Py_VISIT(traverse_module_state->__pyx_ptype_7cpython_4type_type); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_dtype); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flatiter); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_broadcast); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ndarray); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_generic); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_number); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_integer); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_signedinteger); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_unsignedinteger); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_inexact); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_floating); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_complexfloating); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_flexible); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_character); + Py_VISIT(traverse_module_state->__pyx_ptype_5numpy_ufunc); + Py_VISIT(traverse_module_state->__pyx_array_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_array); + Py_VISIT(traverse_module_state->__pyx_MemviewEnum_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_MemviewEnum); + Py_VISIT(traverse_module_state->__pyx_memoryview_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_memoryview); + Py_VISIT(traverse_module_state->__pyx_memoryviewslice_type); + Py_VISIT(traverse_module_state->__pyx_type___pyx_memoryviewslice); + Py_VISIT(traverse_module_state->__pyx_kp_u_); + Py_VISIT(traverse_module_state->__pyx_n_s_ASCII); + Py_VISIT(traverse_module_state->__pyx_kp_s_All_dimensions_preceding_dimensi); + Py_VISIT(traverse_module_state->__pyx_n_s_AssertionError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Buffer_view_does_not_expose_stri); + Py_VISIT(traverse_module_state->__pyx_kp_s_Can_only_create_a_buffer_that_is); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_assign_to_read_only_memor); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_create_writable_memory_vi); + Py_VISIT(traverse_module_state->__pyx_kp_u_Cannot_index_with_type); + Py_VISIT(traverse_module_state->__pyx_kp_s_Cannot_transpose_memoryview_with); + Py_VISIT(traverse_module_state->__pyx_kp_s_Dimension_d_is_not_direct); + Py_VISIT(traverse_module_state->__pyx_n_s_Ellipsis); + Py_VISIT(traverse_module_state->__pyx_kp_s_Empty_shape_tuple_for_cython_arr); + Py_VISIT(traverse_module_state->__pyx_n_s_ImportError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0); + Py_VISIT(traverse_module_state->__pyx_n_s_IndexError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Index_out_of_bounds_axis_d); + Py_VISIT(traverse_module_state->__pyx_kp_s_Indirect_dimensions_not_supporte); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_mode_expected_c_or_fortr); + Py_VISIT(traverse_module_state->__pyx_kp_u_Invalid_shape_in_axis); + Py_VISIT(traverse_module_state->__pyx_n_s_MemoryError); + Py_VISIT(traverse_module_state->__pyx_kp_s_MemoryView_of_r_at_0x_x); + Py_VISIT(traverse_module_state->__pyx_kp_s_MemoryView_of_r_object); + Py_VISIT(traverse_module_state->__pyx_n_b_O); + Py_VISIT(traverse_module_state->__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + Py_VISIT(traverse_module_state->__pyx_n_s_PickleError); + Py_VISIT(traverse_module_state->__pyx_n_s_Sequence); + Py_VISIT(traverse_module_state->__pyx_kp_s_Step_may_not_be_zero_axis_d); + Py_VISIT(traverse_module_state->__pyx_kp_s_TTS_tts_utils_monotonic_align_co); + Py_VISIT(traverse_module_state->__pyx_n_s_TTS_tts_utils_monotonic_align_co_2); + Py_VISIT(traverse_module_state->__pyx_n_s_TypeError); + Py_VISIT(traverse_module_state->__pyx_kp_s_Unable_to_convert_item_to_object); + Py_VISIT(traverse_module_state->__pyx_n_s_ValueError); + Py_VISIT(traverse_module_state->__pyx_n_s_View_MemoryView); + Py_VISIT(traverse_module_state->__pyx_kp_u__2); + Py_VISIT(traverse_module_state->__pyx_n_s__25); + Py_VISIT(traverse_module_state->__pyx_n_s__3); + Py_VISIT(traverse_module_state->__pyx_kp_u__6); + Py_VISIT(traverse_module_state->__pyx_kp_u__7); + Py_VISIT(traverse_module_state->__pyx_n_s_abc); + Py_VISIT(traverse_module_state->__pyx_n_s_allocate_buffer); + Py_VISIT(traverse_module_state->__pyx_kp_u_and); + Py_VISIT(traverse_module_state->__pyx_n_s_asyncio_coroutines); + Py_VISIT(traverse_module_state->__pyx_n_s_base); + Py_VISIT(traverse_module_state->__pyx_n_s_c); + Py_VISIT(traverse_module_state->__pyx_n_u_c); + Py_VISIT(traverse_module_state->__pyx_n_s_class); + Py_VISIT(traverse_module_state->__pyx_n_s_class_getitem); + Py_VISIT(traverse_module_state->__pyx_n_s_cline_in_traceback); + Py_VISIT(traverse_module_state->__pyx_n_s_collections); + Py_VISIT(traverse_module_state->__pyx_kp_s_collections_abc); + Py_VISIT(traverse_module_state->__pyx_kp_s_contiguous_and_direct); + Py_VISIT(traverse_module_state->__pyx_kp_s_contiguous_and_indirect); + Py_VISIT(traverse_module_state->__pyx_n_s_count); + Py_VISIT(traverse_module_state->__pyx_n_s_dict); + Py_VISIT(traverse_module_state->__pyx_kp_u_disable); + Py_VISIT(traverse_module_state->__pyx_n_s_dtype_is_object); + Py_VISIT(traverse_module_state->__pyx_kp_u_enable); + Py_VISIT(traverse_module_state->__pyx_n_s_encode); + Py_VISIT(traverse_module_state->__pyx_n_s_enumerate); + Py_VISIT(traverse_module_state->__pyx_n_s_error); + Py_VISIT(traverse_module_state->__pyx_n_s_flags); + Py_VISIT(traverse_module_state->__pyx_n_s_format); + Py_VISIT(traverse_module_state->__pyx_n_s_fortran); + Py_VISIT(traverse_module_state->__pyx_n_u_fortran); + Py_VISIT(traverse_module_state->__pyx_kp_u_gc); + Py_VISIT(traverse_module_state->__pyx_n_s_getstate); + Py_VISIT(traverse_module_state->__pyx_kp_u_got); + Py_VISIT(traverse_module_state->__pyx_kp_u_got_differing_extents_in_dimensi); + Py_VISIT(traverse_module_state->__pyx_n_s_id); + Py_VISIT(traverse_module_state->__pyx_n_s_import); + Py_VISIT(traverse_module_state->__pyx_n_s_index); + Py_VISIT(traverse_module_state->__pyx_n_s_initializing); + Py_VISIT(traverse_module_state->__pyx_n_s_is_coroutine); + Py_VISIT(traverse_module_state->__pyx_kp_u_isenabled); + Py_VISIT(traverse_module_state->__pyx_n_s_itemsize); + Py_VISIT(traverse_module_state->__pyx_kp_s_itemsize_0_for_cython_array); + Py_VISIT(traverse_module_state->__pyx_n_s_main); + Py_VISIT(traverse_module_state->__pyx_n_s_max_neg_val); + Py_VISIT(traverse_module_state->__pyx_n_s_maximum_path_c); + Py_VISIT(traverse_module_state->__pyx_n_s_memview); + Py_VISIT(traverse_module_state->__pyx_n_s_mode); + Py_VISIT(traverse_module_state->__pyx_n_s_name); + Py_VISIT(traverse_module_state->__pyx_n_s_name_2); + Py_VISIT(traverse_module_state->__pyx_n_s_ndim); + Py_VISIT(traverse_module_state->__pyx_n_s_new); + Py_VISIT(traverse_module_state->__pyx_kp_s_no_default___reduce___due_to_non); + Py_VISIT(traverse_module_state->__pyx_n_s_np); + Py_VISIT(traverse_module_state->__pyx_n_s_numpy); + Py_VISIT(traverse_module_state->__pyx_kp_u_numpy__core_multiarray_failed_to); + Py_VISIT(traverse_module_state->__pyx_kp_u_numpy__core_umath_failed_to_impo); + Py_VISIT(traverse_module_state->__pyx_n_s_obj); + Py_VISIT(traverse_module_state->__pyx_n_s_pack); + Py_VISIT(traverse_module_state->__pyx_n_s_paths); + Py_VISIT(traverse_module_state->__pyx_n_s_pickle); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_PickleError); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_checksum); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_result); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_state); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_type); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_unpickle_Enum); + Py_VISIT(traverse_module_state->__pyx_n_s_pyx_vtable); + Py_VISIT(traverse_module_state->__pyx_n_s_range); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_reduce_ex); + Py_VISIT(traverse_module_state->__pyx_n_s_register); + Py_VISIT(traverse_module_state->__pyx_n_s_setstate); + Py_VISIT(traverse_module_state->__pyx_n_s_setstate_cython); + Py_VISIT(traverse_module_state->__pyx_n_s_shape); + Py_VISIT(traverse_module_state->__pyx_n_s_size); + Py_VISIT(traverse_module_state->__pyx_n_s_spec); + Py_VISIT(traverse_module_state->__pyx_n_s_start); + Py_VISIT(traverse_module_state->__pyx_n_s_step); + Py_VISIT(traverse_module_state->__pyx_n_s_stop); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_direct); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_direct_or_indirect); + Py_VISIT(traverse_module_state->__pyx_kp_s_strided_and_indirect); + Py_VISIT(traverse_module_state->__pyx_kp_s_stringsource); + Py_VISIT(traverse_module_state->__pyx_n_s_struct); + Py_VISIT(traverse_module_state->__pyx_n_s_sys); + Py_VISIT(traverse_module_state->__pyx_n_s_t_xs); + Py_VISIT(traverse_module_state->__pyx_n_s_t_ys); + Py_VISIT(traverse_module_state->__pyx_n_s_test); + Py_VISIT(traverse_module_state->__pyx_kp_s_unable_to_allocate_array_data); + Py_VISIT(traverse_module_state->__pyx_kp_s_unable_to_allocate_shape_and_str); + Py_VISIT(traverse_module_state->__pyx_n_s_unpack); + Py_VISIT(traverse_module_state->__pyx_n_s_update); + Py_VISIT(traverse_module_state->__pyx_n_s_values); + Py_VISIT(traverse_module_state->__pyx_n_s_version_info); + Py_VISIT(traverse_module_state->__pyx_int_0); + Py_VISIT(traverse_module_state->__pyx_int_1); + Py_VISIT(traverse_module_state->__pyx_int_3); + Py_VISIT(traverse_module_state->__pyx_int_112105877); + Py_VISIT(traverse_module_state->__pyx_int_136983863); + Py_VISIT(traverse_module_state->__pyx_int_184977713); + Py_VISIT(traverse_module_state->__pyx_int_neg_1); + Py_VISIT(traverse_module_state->__pyx_slice__5); + Py_VISIT(traverse_module_state->__pyx_tuple__4); + Py_VISIT(traverse_module_state->__pyx_tuple__8); + Py_VISIT(traverse_module_state->__pyx_tuple__9); + Py_VISIT(traverse_module_state->__pyx_tuple__10); + Py_VISIT(traverse_module_state->__pyx_tuple__12); + Py_VISIT(traverse_module_state->__pyx_tuple__13); + Py_VISIT(traverse_module_state->__pyx_tuple__14); + Py_VISIT(traverse_module_state->__pyx_tuple__15); + Py_VISIT(traverse_module_state->__pyx_tuple__16); + Py_VISIT(traverse_module_state->__pyx_tuple__17); + Py_VISIT(traverse_module_state->__pyx_tuple__18); + Py_VISIT(traverse_module_state->__pyx_tuple__19); + Py_VISIT(traverse_module_state->__pyx_tuple__20); + Py_VISIT(traverse_module_state->__pyx_tuple__21); + Py_VISIT(traverse_module_state->__pyx_tuple__23); + Py_VISIT(traverse_module_state->__pyx_codeobj__22); + Py_VISIT(traverse_module_state->__pyx_codeobj__24); + return 0; +} +#endif +/* #### Code section: module_state_defines ### */ +#define __pyx_d __pyx_mstate_global->__pyx_d +#define __pyx_b __pyx_mstate_global->__pyx_b +#define __pyx_cython_runtime __pyx_mstate_global->__pyx_cython_runtime +#define __pyx_empty_tuple __pyx_mstate_global->__pyx_empty_tuple +#define __pyx_empty_bytes __pyx_mstate_global->__pyx_empty_bytes +#define __pyx_empty_unicode __pyx_mstate_global->__pyx_empty_unicode +#ifdef __Pyx_CyFunction_USED +#define __pyx_CyFunctionType __pyx_mstate_global->__pyx_CyFunctionType +#endif +#ifdef __Pyx_FusedFunction_USED +#define __pyx_FusedFunctionType __pyx_mstate_global->__pyx_FusedFunctionType +#endif +#ifdef __Pyx_Generator_USED +#define __pyx_GeneratorType __pyx_mstate_global->__pyx_GeneratorType +#endif +#ifdef __Pyx_IterableCoroutine_USED +#define __pyx_IterableCoroutineType __pyx_mstate_global->__pyx_IterableCoroutineType +#endif +#ifdef __Pyx_Coroutine_USED +#define __pyx_CoroutineAwaitType __pyx_mstate_global->__pyx_CoroutineAwaitType +#endif +#ifdef __Pyx_Coroutine_USED +#define __pyx_CoroutineType __pyx_mstate_global->__pyx_CoroutineType +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#define __pyx_ptype_7cpython_4type_type __pyx_mstate_global->__pyx_ptype_7cpython_4type_type +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#if CYTHON_USE_MODULE_STATE +#endif +#define __pyx_ptype_5numpy_dtype __pyx_mstate_global->__pyx_ptype_5numpy_dtype +#define __pyx_ptype_5numpy_flatiter __pyx_mstate_global->__pyx_ptype_5numpy_flatiter +#define __pyx_ptype_5numpy_broadcast __pyx_mstate_global->__pyx_ptype_5numpy_broadcast +#define __pyx_ptype_5numpy_ndarray __pyx_mstate_global->__pyx_ptype_5numpy_ndarray +#define __pyx_ptype_5numpy_generic __pyx_mstate_global->__pyx_ptype_5numpy_generic +#define __pyx_ptype_5numpy_number __pyx_mstate_global->__pyx_ptype_5numpy_number +#define __pyx_ptype_5numpy_integer __pyx_mstate_global->__pyx_ptype_5numpy_integer +#define __pyx_ptype_5numpy_signedinteger __pyx_mstate_global->__pyx_ptype_5numpy_signedinteger +#define __pyx_ptype_5numpy_unsignedinteger __pyx_mstate_global->__pyx_ptype_5numpy_unsignedinteger +#define __pyx_ptype_5numpy_inexact __pyx_mstate_global->__pyx_ptype_5numpy_inexact +#define __pyx_ptype_5numpy_floating __pyx_mstate_global->__pyx_ptype_5numpy_floating +#define __pyx_ptype_5numpy_complexfloating __pyx_mstate_global->__pyx_ptype_5numpy_complexfloating +#define __pyx_ptype_5numpy_flexible __pyx_mstate_global->__pyx_ptype_5numpy_flexible +#define __pyx_ptype_5numpy_character __pyx_mstate_global->__pyx_ptype_5numpy_character +#define __pyx_ptype_5numpy_ufunc __pyx_mstate_global->__pyx_ptype_5numpy_ufunc +#if CYTHON_USE_MODULE_STATE +#define __pyx_type___pyx_array __pyx_mstate_global->__pyx_type___pyx_array +#define __pyx_type___pyx_MemviewEnum __pyx_mstate_global->__pyx_type___pyx_MemviewEnum +#define __pyx_type___pyx_memoryview __pyx_mstate_global->__pyx_type___pyx_memoryview +#define __pyx_type___pyx_memoryviewslice __pyx_mstate_global->__pyx_type___pyx_memoryviewslice +#endif +#define __pyx_array_type __pyx_mstate_global->__pyx_array_type +#define __pyx_MemviewEnum_type __pyx_mstate_global->__pyx_MemviewEnum_type +#define __pyx_memoryview_type __pyx_mstate_global->__pyx_memoryview_type +#define __pyx_memoryviewslice_type __pyx_mstate_global->__pyx_memoryviewslice_type +#define __pyx_kp_u_ __pyx_mstate_global->__pyx_kp_u_ +#define __pyx_n_s_ASCII __pyx_mstate_global->__pyx_n_s_ASCII +#define __pyx_kp_s_All_dimensions_preceding_dimensi __pyx_mstate_global->__pyx_kp_s_All_dimensions_preceding_dimensi +#define __pyx_n_s_AssertionError __pyx_mstate_global->__pyx_n_s_AssertionError +#define __pyx_kp_s_Buffer_view_does_not_expose_stri __pyx_mstate_global->__pyx_kp_s_Buffer_view_does_not_expose_stri +#define __pyx_kp_s_Can_only_create_a_buffer_that_is __pyx_mstate_global->__pyx_kp_s_Can_only_create_a_buffer_that_is +#define __pyx_kp_s_Cannot_assign_to_read_only_memor __pyx_mstate_global->__pyx_kp_s_Cannot_assign_to_read_only_memor +#define __pyx_kp_s_Cannot_create_writable_memory_vi __pyx_mstate_global->__pyx_kp_s_Cannot_create_writable_memory_vi +#define __pyx_kp_u_Cannot_index_with_type __pyx_mstate_global->__pyx_kp_u_Cannot_index_with_type +#define __pyx_kp_s_Cannot_transpose_memoryview_with __pyx_mstate_global->__pyx_kp_s_Cannot_transpose_memoryview_with +#define __pyx_kp_s_Dimension_d_is_not_direct __pyx_mstate_global->__pyx_kp_s_Dimension_d_is_not_direct +#define __pyx_n_s_Ellipsis __pyx_mstate_global->__pyx_n_s_Ellipsis +#define __pyx_kp_s_Empty_shape_tuple_for_cython_arr __pyx_mstate_global->__pyx_kp_s_Empty_shape_tuple_for_cython_arr +#define __pyx_n_s_ImportError __pyx_mstate_global->__pyx_n_s_ImportError +#define __pyx_kp_s_Incompatible_checksums_0x_x_vs_0 __pyx_mstate_global->__pyx_kp_s_Incompatible_checksums_0x_x_vs_0 +#define __pyx_n_s_IndexError __pyx_mstate_global->__pyx_n_s_IndexError +#define __pyx_kp_s_Index_out_of_bounds_axis_d __pyx_mstate_global->__pyx_kp_s_Index_out_of_bounds_axis_d +#define __pyx_kp_s_Indirect_dimensions_not_supporte __pyx_mstate_global->__pyx_kp_s_Indirect_dimensions_not_supporte +#define __pyx_kp_u_Invalid_mode_expected_c_or_fortr __pyx_mstate_global->__pyx_kp_u_Invalid_mode_expected_c_or_fortr +#define __pyx_kp_u_Invalid_shape_in_axis __pyx_mstate_global->__pyx_kp_u_Invalid_shape_in_axis +#define __pyx_n_s_MemoryError __pyx_mstate_global->__pyx_n_s_MemoryError +#define __pyx_kp_s_MemoryView_of_r_at_0x_x __pyx_mstate_global->__pyx_kp_s_MemoryView_of_r_at_0x_x +#define __pyx_kp_s_MemoryView_of_r_object __pyx_mstate_global->__pyx_kp_s_MemoryView_of_r_object +#define __pyx_n_b_O __pyx_mstate_global->__pyx_n_b_O +#define __pyx_kp_u_Out_of_bounds_on_buffer_access_a __pyx_mstate_global->__pyx_kp_u_Out_of_bounds_on_buffer_access_a +#define __pyx_n_s_PickleError __pyx_mstate_global->__pyx_n_s_PickleError +#define __pyx_n_s_Sequence __pyx_mstate_global->__pyx_n_s_Sequence +#define __pyx_kp_s_Step_may_not_be_zero_axis_d __pyx_mstate_global->__pyx_kp_s_Step_may_not_be_zero_axis_d +#define __pyx_kp_s_TTS_tts_utils_monotonic_align_co __pyx_mstate_global->__pyx_kp_s_TTS_tts_utils_monotonic_align_co +#define __pyx_n_s_TTS_tts_utils_monotonic_align_co_2 __pyx_mstate_global->__pyx_n_s_TTS_tts_utils_monotonic_align_co_2 +#define __pyx_n_s_TypeError __pyx_mstate_global->__pyx_n_s_TypeError +#define __pyx_kp_s_Unable_to_convert_item_to_object __pyx_mstate_global->__pyx_kp_s_Unable_to_convert_item_to_object +#define __pyx_n_s_ValueError __pyx_mstate_global->__pyx_n_s_ValueError +#define __pyx_n_s_View_MemoryView __pyx_mstate_global->__pyx_n_s_View_MemoryView +#define __pyx_kp_u__2 __pyx_mstate_global->__pyx_kp_u__2 +#define __pyx_n_s__25 __pyx_mstate_global->__pyx_n_s__25 +#define __pyx_n_s__3 __pyx_mstate_global->__pyx_n_s__3 +#define __pyx_kp_u__6 __pyx_mstate_global->__pyx_kp_u__6 +#define __pyx_kp_u__7 __pyx_mstate_global->__pyx_kp_u__7 +#define __pyx_n_s_abc __pyx_mstate_global->__pyx_n_s_abc +#define __pyx_n_s_allocate_buffer __pyx_mstate_global->__pyx_n_s_allocate_buffer +#define __pyx_kp_u_and __pyx_mstate_global->__pyx_kp_u_and +#define __pyx_n_s_asyncio_coroutines __pyx_mstate_global->__pyx_n_s_asyncio_coroutines +#define __pyx_n_s_base __pyx_mstate_global->__pyx_n_s_base +#define __pyx_n_s_c __pyx_mstate_global->__pyx_n_s_c +#define __pyx_n_u_c __pyx_mstate_global->__pyx_n_u_c +#define __pyx_n_s_class __pyx_mstate_global->__pyx_n_s_class +#define __pyx_n_s_class_getitem __pyx_mstate_global->__pyx_n_s_class_getitem +#define __pyx_n_s_cline_in_traceback __pyx_mstate_global->__pyx_n_s_cline_in_traceback +#define __pyx_n_s_collections __pyx_mstate_global->__pyx_n_s_collections +#define __pyx_kp_s_collections_abc __pyx_mstate_global->__pyx_kp_s_collections_abc +#define __pyx_kp_s_contiguous_and_direct __pyx_mstate_global->__pyx_kp_s_contiguous_and_direct +#define __pyx_kp_s_contiguous_and_indirect __pyx_mstate_global->__pyx_kp_s_contiguous_and_indirect +#define __pyx_n_s_count __pyx_mstate_global->__pyx_n_s_count +#define __pyx_n_s_dict __pyx_mstate_global->__pyx_n_s_dict +#define __pyx_kp_u_disable __pyx_mstate_global->__pyx_kp_u_disable +#define __pyx_n_s_dtype_is_object __pyx_mstate_global->__pyx_n_s_dtype_is_object +#define __pyx_kp_u_enable __pyx_mstate_global->__pyx_kp_u_enable +#define __pyx_n_s_encode __pyx_mstate_global->__pyx_n_s_encode +#define __pyx_n_s_enumerate __pyx_mstate_global->__pyx_n_s_enumerate +#define __pyx_n_s_error __pyx_mstate_global->__pyx_n_s_error +#define __pyx_n_s_flags __pyx_mstate_global->__pyx_n_s_flags +#define __pyx_n_s_format __pyx_mstate_global->__pyx_n_s_format +#define __pyx_n_s_fortran __pyx_mstate_global->__pyx_n_s_fortran +#define __pyx_n_u_fortran __pyx_mstate_global->__pyx_n_u_fortran +#define __pyx_kp_u_gc __pyx_mstate_global->__pyx_kp_u_gc +#define __pyx_n_s_getstate __pyx_mstate_global->__pyx_n_s_getstate +#define __pyx_kp_u_got __pyx_mstate_global->__pyx_kp_u_got +#define __pyx_kp_u_got_differing_extents_in_dimensi __pyx_mstate_global->__pyx_kp_u_got_differing_extents_in_dimensi +#define __pyx_n_s_id __pyx_mstate_global->__pyx_n_s_id +#define __pyx_n_s_import __pyx_mstate_global->__pyx_n_s_import +#define __pyx_n_s_index __pyx_mstate_global->__pyx_n_s_index +#define __pyx_n_s_initializing __pyx_mstate_global->__pyx_n_s_initializing +#define __pyx_n_s_is_coroutine __pyx_mstate_global->__pyx_n_s_is_coroutine +#define __pyx_kp_u_isenabled __pyx_mstate_global->__pyx_kp_u_isenabled +#define __pyx_n_s_itemsize __pyx_mstate_global->__pyx_n_s_itemsize +#define __pyx_kp_s_itemsize_0_for_cython_array __pyx_mstate_global->__pyx_kp_s_itemsize_0_for_cython_array +#define __pyx_n_s_main __pyx_mstate_global->__pyx_n_s_main +#define __pyx_n_s_max_neg_val __pyx_mstate_global->__pyx_n_s_max_neg_val +#define __pyx_n_s_maximum_path_c __pyx_mstate_global->__pyx_n_s_maximum_path_c +#define __pyx_n_s_memview __pyx_mstate_global->__pyx_n_s_memview +#define __pyx_n_s_mode __pyx_mstate_global->__pyx_n_s_mode +#define __pyx_n_s_name __pyx_mstate_global->__pyx_n_s_name +#define __pyx_n_s_name_2 __pyx_mstate_global->__pyx_n_s_name_2 +#define __pyx_n_s_ndim __pyx_mstate_global->__pyx_n_s_ndim +#define __pyx_n_s_new __pyx_mstate_global->__pyx_n_s_new +#define __pyx_kp_s_no_default___reduce___due_to_non __pyx_mstate_global->__pyx_kp_s_no_default___reduce___due_to_non +#define __pyx_n_s_np __pyx_mstate_global->__pyx_n_s_np +#define __pyx_n_s_numpy __pyx_mstate_global->__pyx_n_s_numpy +#define __pyx_kp_u_numpy__core_multiarray_failed_to __pyx_mstate_global->__pyx_kp_u_numpy__core_multiarray_failed_to +#define __pyx_kp_u_numpy__core_umath_failed_to_impo __pyx_mstate_global->__pyx_kp_u_numpy__core_umath_failed_to_impo +#define __pyx_n_s_obj __pyx_mstate_global->__pyx_n_s_obj +#define __pyx_n_s_pack __pyx_mstate_global->__pyx_n_s_pack +#define __pyx_n_s_paths __pyx_mstate_global->__pyx_n_s_paths +#define __pyx_n_s_pickle __pyx_mstate_global->__pyx_n_s_pickle +#define __pyx_n_s_pyx_PickleError __pyx_mstate_global->__pyx_n_s_pyx_PickleError +#define __pyx_n_s_pyx_checksum __pyx_mstate_global->__pyx_n_s_pyx_checksum +#define __pyx_n_s_pyx_result __pyx_mstate_global->__pyx_n_s_pyx_result +#define __pyx_n_s_pyx_state __pyx_mstate_global->__pyx_n_s_pyx_state +#define __pyx_n_s_pyx_type __pyx_mstate_global->__pyx_n_s_pyx_type +#define __pyx_n_s_pyx_unpickle_Enum __pyx_mstate_global->__pyx_n_s_pyx_unpickle_Enum +#define __pyx_n_s_pyx_vtable __pyx_mstate_global->__pyx_n_s_pyx_vtable +#define __pyx_n_s_range __pyx_mstate_global->__pyx_n_s_range +#define __pyx_n_s_reduce __pyx_mstate_global->__pyx_n_s_reduce +#define __pyx_n_s_reduce_cython __pyx_mstate_global->__pyx_n_s_reduce_cython +#define __pyx_n_s_reduce_ex __pyx_mstate_global->__pyx_n_s_reduce_ex +#define __pyx_n_s_register __pyx_mstate_global->__pyx_n_s_register +#define __pyx_n_s_setstate __pyx_mstate_global->__pyx_n_s_setstate +#define __pyx_n_s_setstate_cython __pyx_mstate_global->__pyx_n_s_setstate_cython +#define __pyx_n_s_shape __pyx_mstate_global->__pyx_n_s_shape +#define __pyx_n_s_size __pyx_mstate_global->__pyx_n_s_size +#define __pyx_n_s_spec __pyx_mstate_global->__pyx_n_s_spec +#define __pyx_n_s_start __pyx_mstate_global->__pyx_n_s_start +#define __pyx_n_s_step __pyx_mstate_global->__pyx_n_s_step +#define __pyx_n_s_stop __pyx_mstate_global->__pyx_n_s_stop +#define __pyx_kp_s_strided_and_direct __pyx_mstate_global->__pyx_kp_s_strided_and_direct +#define __pyx_kp_s_strided_and_direct_or_indirect __pyx_mstate_global->__pyx_kp_s_strided_and_direct_or_indirect +#define __pyx_kp_s_strided_and_indirect __pyx_mstate_global->__pyx_kp_s_strided_and_indirect +#define __pyx_kp_s_stringsource __pyx_mstate_global->__pyx_kp_s_stringsource +#define __pyx_n_s_struct __pyx_mstate_global->__pyx_n_s_struct +#define __pyx_n_s_sys __pyx_mstate_global->__pyx_n_s_sys +#define __pyx_n_s_t_xs __pyx_mstate_global->__pyx_n_s_t_xs +#define __pyx_n_s_t_ys __pyx_mstate_global->__pyx_n_s_t_ys +#define __pyx_n_s_test __pyx_mstate_global->__pyx_n_s_test +#define __pyx_kp_s_unable_to_allocate_array_data __pyx_mstate_global->__pyx_kp_s_unable_to_allocate_array_data +#define __pyx_kp_s_unable_to_allocate_shape_and_str __pyx_mstate_global->__pyx_kp_s_unable_to_allocate_shape_and_str +#define __pyx_n_s_unpack __pyx_mstate_global->__pyx_n_s_unpack +#define __pyx_n_s_update __pyx_mstate_global->__pyx_n_s_update +#define __pyx_n_s_values __pyx_mstate_global->__pyx_n_s_values +#define __pyx_n_s_version_info __pyx_mstate_global->__pyx_n_s_version_info +#define __pyx_int_0 __pyx_mstate_global->__pyx_int_0 +#define __pyx_int_1 __pyx_mstate_global->__pyx_int_1 +#define __pyx_int_3 __pyx_mstate_global->__pyx_int_3 +#define __pyx_int_112105877 __pyx_mstate_global->__pyx_int_112105877 +#define __pyx_int_136983863 __pyx_mstate_global->__pyx_int_136983863 +#define __pyx_int_184977713 __pyx_mstate_global->__pyx_int_184977713 +#define __pyx_int_neg_1 __pyx_mstate_global->__pyx_int_neg_1 +#define __pyx_k__11 __pyx_mstate_global->__pyx_k__11 +#define __pyx_slice__5 __pyx_mstate_global->__pyx_slice__5 +#define __pyx_tuple__4 __pyx_mstate_global->__pyx_tuple__4 +#define __pyx_tuple__8 __pyx_mstate_global->__pyx_tuple__8 +#define __pyx_tuple__9 __pyx_mstate_global->__pyx_tuple__9 +#define __pyx_tuple__10 __pyx_mstate_global->__pyx_tuple__10 +#define __pyx_tuple__12 __pyx_mstate_global->__pyx_tuple__12 +#define __pyx_tuple__13 __pyx_mstate_global->__pyx_tuple__13 +#define __pyx_tuple__14 __pyx_mstate_global->__pyx_tuple__14 +#define __pyx_tuple__15 __pyx_mstate_global->__pyx_tuple__15 +#define __pyx_tuple__16 __pyx_mstate_global->__pyx_tuple__16 +#define __pyx_tuple__17 __pyx_mstate_global->__pyx_tuple__17 +#define __pyx_tuple__18 __pyx_mstate_global->__pyx_tuple__18 +#define __pyx_tuple__19 __pyx_mstate_global->__pyx_tuple__19 +#define __pyx_tuple__20 __pyx_mstate_global->__pyx_tuple__20 +#define __pyx_tuple__21 __pyx_mstate_global->__pyx_tuple__21 +#define __pyx_tuple__23 __pyx_mstate_global->__pyx_tuple__23 +#define __pyx_codeobj__22 __pyx_mstate_global->__pyx_codeobj__22 +#define __pyx_codeobj__24 __pyx_mstate_global->__pyx_codeobj__24 +/* #### Code section: module_code ### */ + +/* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + +/* Python wrapper */ +static int __pyx_array___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_array___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_shape = 0; + Py_ssize_t __pyx_v_itemsize; + PyObject *__pyx_v_format = 0; + PyObject *__pyx_v_mode = 0; + int __pyx_v_allocate_buffer; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[5] = {0,0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_shape,&__pyx_n_s_itemsize,&__pyx_n_s_format,&__pyx_n_s_mode,&__pyx_n_s_allocate_buffer,0}; + values[3] = __Pyx_Arg_NewRef_VARARGS(((PyObject *)__pyx_n_s_c)); + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_VARARGS(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_VARARGS(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_shape)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_itemsize)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, 1); __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_format)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, 2); __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_mode); + if (value) { values[3] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 4: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_allocate_buffer); + if (value) { values[4] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__cinit__") < 0)) __PYX_ERR(1, 131, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_VARARGS(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_VARARGS(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_shape = ((PyObject*)values[0]); + __pyx_v_itemsize = __Pyx_PyIndex_AsSsize_t(values[1]); if (unlikely((__pyx_v_itemsize == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 131, __pyx_L3_error) + __pyx_v_format = values[2]; + __pyx_v_mode = values[3]; + if (values[4]) { + __pyx_v_allocate_buffer = __Pyx_PyObject_IsTrue(values[4]); if (unlikely((__pyx_v_allocate_buffer == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 132, __pyx_L3_error) + } else { + + /* "View.MemoryView":132 + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, + * mode="c", bint allocate_buffer=True): # <<<<<<<<<<<<<< + * + * cdef int idx + */ + __pyx_v_allocate_buffer = ((int)1); + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 3, 5, __pyx_nargs); __PYX_ERR(1, 131, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.array.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + if (unlikely(!__Pyx_ArgTypeTest(((PyObject *)__pyx_v_shape), (&PyTuple_Type), 1, "shape", 1))) __PYX_ERR(1, 131, __pyx_L1_error) + if (unlikely(((PyObject *)__pyx_v_format) == Py_None)) { + PyErr_Format(PyExc_TypeError, "Argument '%.200s' must not be None", "format"); __PYX_ERR(1, 131, __pyx_L1_error) + } + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(((struct __pyx_array_obj *)__pyx_v_self), __pyx_v_shape, __pyx_v_itemsize, __pyx_v_format, __pyx_v_mode, __pyx_v_allocate_buffer); + + /* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __pyx_r = -1; + __pyx_L0:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array___cinit__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, PyObject *__pyx_v_format, PyObject *__pyx_v_mode, int __pyx_v_allocate_buffer) { + int __pyx_v_idx; + Py_ssize_t __pyx_v_dim; + char __pyx_v_order; + int __pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + unsigned int __pyx_t_7; + char *__pyx_t_8; + int __pyx_t_9; + Py_ssize_t __pyx_t_10; + Py_UCS4 __pyx_t_11; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__cinit__", 0); + __Pyx_INCREF(__pyx_v_format); + + /* "View.MemoryView":137 + * cdef Py_ssize_t dim + * + * self.ndim = len(shape) # <<<<<<<<<<<<<< + * self.itemsize = itemsize + * + */ + if (unlikely(__pyx_v_shape == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 137, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_PyTuple_GET_SIZE(__pyx_v_shape); if (unlikely(__pyx_t_1 == ((Py_ssize_t)-1))) __PYX_ERR(1, 137, __pyx_L1_error) + __pyx_v_self->ndim = ((int)__pyx_t_1); + + /* "View.MemoryView":138 + * + * self.ndim = len(shape) + * self.itemsize = itemsize # <<<<<<<<<<<<<< + * + * if not self.ndim: + */ + __pyx_v_self->itemsize = __pyx_v_itemsize; + + /* "View.MemoryView":140 + * self.itemsize = itemsize + * + * if not self.ndim: # <<<<<<<<<<<<<< + * raise ValueError, "Empty shape tuple for cython.array" + * + */ + __pyx_t_2 = (!(__pyx_v_self->ndim != 0)); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":141 + * + * if not self.ndim: + * raise ValueError, "Empty shape tuple for cython.array" # <<<<<<<<<<<<<< + * + * if itemsize <= 0: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Empty_shape_tuple_for_cython_arr, 0, 0); + __PYX_ERR(1, 141, __pyx_L1_error) + + /* "View.MemoryView":140 + * self.itemsize = itemsize + * + * if not self.ndim: # <<<<<<<<<<<<<< + * raise ValueError, "Empty shape tuple for cython.array" + * + */ + } + + /* "View.MemoryView":143 + * raise ValueError, "Empty shape tuple for cython.array" + * + * if itemsize <= 0: # <<<<<<<<<<<<<< + * raise ValueError, "itemsize <= 0 for cython.array" + * + */ + __pyx_t_2 = (__pyx_v_itemsize <= 0); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":144 + * + * if itemsize <= 0: + * raise ValueError, "itemsize <= 0 for cython.array" # <<<<<<<<<<<<<< + * + * if not isinstance(format, bytes): + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_itemsize_0_for_cython_array, 0, 0); + __PYX_ERR(1, 144, __pyx_L1_error) + + /* "View.MemoryView":143 + * raise ValueError, "Empty shape tuple for cython.array" + * + * if itemsize <= 0: # <<<<<<<<<<<<<< + * raise ValueError, "itemsize <= 0 for cython.array" + * + */ + } + + /* "View.MemoryView":146 + * raise ValueError, "itemsize <= 0 for cython.array" + * + * if not isinstance(format, bytes): # <<<<<<<<<<<<<< + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + */ + __pyx_t_2 = PyBytes_Check(__pyx_v_format); + __pyx_t_3 = (!__pyx_t_2); + if (__pyx_t_3) { + + /* "View.MemoryView":147 + * + * if not isinstance(format, bytes): + * format = format.encode('ASCII') # <<<<<<<<<<<<<< + * self._format = format # keep a reference to the byte string + * self.format = self._format + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_format, __pyx_n_s_encode); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = NULL; + __pyx_t_7 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_6 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_6)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_6); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_7 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_6, __pyx_n_s_ASCII}; + __pyx_t_4 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_7, 1+__pyx_t_7); + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 147, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __Pyx_DECREF_SET(__pyx_v_format, __pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":146 + * raise ValueError, "itemsize <= 0 for cython.array" + * + * if not isinstance(format, bytes): # <<<<<<<<<<<<<< + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + */ + } + + /* "View.MemoryView":148 + * if not isinstance(format, bytes): + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string # <<<<<<<<<<<<<< + * self.format = self._format + * + */ + if (!(likely(PyBytes_CheckExact(__pyx_v_format))||((__pyx_v_format) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_v_format))) __PYX_ERR(1, 148, __pyx_L1_error) + __pyx_t_4 = __pyx_v_format; + __Pyx_INCREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + __Pyx_GOTREF(__pyx_v_self->_format); + __Pyx_DECREF(__pyx_v_self->_format); + __pyx_v_self->_format = ((PyObject*)__pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":149 + * format = format.encode('ASCII') + * self._format = format # keep a reference to the byte string + * self.format = self._format # <<<<<<<<<<<<<< + * + * + */ + if (unlikely(__pyx_v_self->_format == Py_None)) { + PyErr_SetString(PyExc_TypeError, "expected bytes, NoneType found"); + __PYX_ERR(1, 149, __pyx_L1_error) + } + __pyx_t_8 = __Pyx_PyBytes_AsWritableString(__pyx_v_self->_format); if (unlikely((!__pyx_t_8) && PyErr_Occurred())) __PYX_ERR(1, 149, __pyx_L1_error) + __pyx_v_self->format = __pyx_t_8; + + /* "View.MemoryView":152 + * + * + * self._shape = PyObject_Malloc(sizeof(Py_ssize_t)*self.ndim*2) # <<<<<<<<<<<<<< + * self._strides = self._shape + self.ndim + * + */ + __pyx_v_self->_shape = ((Py_ssize_t *)PyObject_Malloc((((sizeof(Py_ssize_t)) * __pyx_v_self->ndim) * 2))); + + /* "View.MemoryView":153 + * + * self._shape = PyObject_Malloc(sizeof(Py_ssize_t)*self.ndim*2) + * self._strides = self._shape + self.ndim # <<<<<<<<<<<<<< + * + * if not self._shape: + */ + __pyx_v_self->_strides = (__pyx_v_self->_shape + __pyx_v_self->ndim); + + /* "View.MemoryView":155 + * self._strides = self._shape + self.ndim + * + * if not self._shape: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate shape and strides." + * + */ + __pyx_t_3 = (!(__pyx_v_self->_shape != 0)); + if (unlikely(__pyx_t_3)) { + + /* "View.MemoryView":156 + * + * if not self._shape: + * raise MemoryError, "unable to allocate shape and strides." # <<<<<<<<<<<<<< + * + * + */ + __Pyx_Raise(__pyx_builtin_MemoryError, __pyx_kp_s_unable_to_allocate_shape_and_str, 0, 0); + __PYX_ERR(1, 156, __pyx_L1_error) + + /* "View.MemoryView":155 + * self._strides = self._shape + self.ndim + * + * if not self._shape: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate shape and strides." + * + */ + } + + /* "View.MemoryView":159 + * + * + * for idx, dim in enumerate(shape): # <<<<<<<<<<<<<< + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + */ + __pyx_t_9 = 0; + __pyx_t_4 = __pyx_v_shape; __Pyx_INCREF(__pyx_t_4); + __pyx_t_1 = 0; + for (;;) { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_4); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 159, __pyx_L1_error) + #endif + if (__pyx_t_1 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyTuple_GET_ITEM(__pyx_t_4, __pyx_t_1); __Pyx_INCREF(__pyx_t_5); __pyx_t_1++; if (unlikely((0 < 0))) __PYX_ERR(1, 159, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_4, __pyx_t_1); __pyx_t_1++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 159, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + __pyx_t_10 = __Pyx_PyIndex_AsSsize_t(__pyx_t_5); if (unlikely((__pyx_t_10 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 159, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_v_dim = __pyx_t_10; + __pyx_v_idx = __pyx_t_9; + __pyx_t_9 = (__pyx_t_9 + 1); + + /* "View.MemoryView":160 + * + * for idx, dim in enumerate(shape): + * if dim <= 0: # <<<<<<<<<<<<<< + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim + */ + __pyx_t_3 = (__pyx_v_dim <= 0); + if (unlikely(__pyx_t_3)) { + + /* "View.MemoryView":161 + * for idx, dim in enumerate(shape): + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." # <<<<<<<<<<<<<< + * self._shape[idx] = dim + * + */ + __pyx_t_5 = PyTuple_New(5); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_10 = 0; + __pyx_t_11 = 127; + __Pyx_INCREF(__pyx_kp_u_Invalid_shape_in_axis); + __pyx_t_10 += 22; + __Pyx_GIVEREF(__pyx_kp_u_Invalid_shape_in_axis); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_kp_u_Invalid_shape_in_axis); + __pyx_t_6 = __Pyx_PyUnicode_From_int(__pyx_v_idx, 0, ' ', 'd'); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_10 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_6); + __Pyx_GIVEREF(__pyx_t_6); + PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_6); + __pyx_t_6 = 0; + __Pyx_INCREF(__pyx_kp_u_); + __pyx_t_10 += 2; + __Pyx_GIVEREF(__pyx_kp_u_); + PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_kp_u_); + __pyx_t_6 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_10 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_6); + __Pyx_GIVEREF(__pyx_t_6); + PyTuple_SET_ITEM(__pyx_t_5, 3, __pyx_t_6); + __pyx_t_6 = 0; + __Pyx_INCREF(__pyx_kp_u__2); + __pyx_t_10 += 1; + __Pyx_GIVEREF(__pyx_kp_u__2); + PyTuple_SET_ITEM(__pyx_t_5, 4, __pyx_kp_u__2); + __pyx_t_6 = __Pyx_PyUnicode_Join(__pyx_t_5, 5, __pyx_t_10, __pyx_t_11); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 161, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_6, 0, 0); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __PYX_ERR(1, 161, __pyx_L1_error) + + /* "View.MemoryView":160 + * + * for idx, dim in enumerate(shape): + * if dim <= 0: # <<<<<<<<<<<<<< + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim + */ + } + + /* "View.MemoryView":162 + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + * self._shape[idx] = dim # <<<<<<<<<<<<<< + * + * cdef char order + */ + (__pyx_v_self->_shape[__pyx_v_idx]) = __pyx_v_dim; + + /* "View.MemoryView":159 + * + * + * for idx, dim in enumerate(shape): # <<<<<<<<<<<<<< + * if dim <= 0: + * raise ValueError, f"Invalid shape in axis {idx}: {dim}." + */ + } + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "View.MemoryView":165 + * + * cdef char order + * if mode == 'c': # <<<<<<<<<<<<<< + * order = b'C' + * self.mode = u'c' + */ + __pyx_t_3 = (__Pyx_PyString_Equals(__pyx_v_mode, __pyx_n_s_c, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(1, 165, __pyx_L1_error) + if (__pyx_t_3) { + + /* "View.MemoryView":166 + * cdef char order + * if mode == 'c': + * order = b'C' # <<<<<<<<<<<<<< + * self.mode = u'c' + * elif mode == 'fortran': + */ + __pyx_v_order = 'C'; + + /* "View.MemoryView":167 + * if mode == 'c': + * order = b'C' + * self.mode = u'c' # <<<<<<<<<<<<<< + * elif mode == 'fortran': + * order = b'F' + */ + __Pyx_INCREF(__pyx_n_u_c); + __Pyx_GIVEREF(__pyx_n_u_c); + __Pyx_GOTREF(__pyx_v_self->mode); + __Pyx_DECREF(__pyx_v_self->mode); + __pyx_v_self->mode = __pyx_n_u_c; + + /* "View.MemoryView":165 + * + * cdef char order + * if mode == 'c': # <<<<<<<<<<<<<< + * order = b'C' + * self.mode = u'c' + */ + goto __pyx_L11; + } + + /* "View.MemoryView":168 + * order = b'C' + * self.mode = u'c' + * elif mode == 'fortran': # <<<<<<<<<<<<<< + * order = b'F' + * self.mode = u'fortran' + */ + __pyx_t_3 = (__Pyx_PyString_Equals(__pyx_v_mode, __pyx_n_s_fortran, Py_EQ)); if (unlikely((__pyx_t_3 < 0))) __PYX_ERR(1, 168, __pyx_L1_error) + if (likely(__pyx_t_3)) { + + /* "View.MemoryView":169 + * self.mode = u'c' + * elif mode == 'fortran': + * order = b'F' # <<<<<<<<<<<<<< + * self.mode = u'fortran' + * else: + */ + __pyx_v_order = 'F'; + + /* "View.MemoryView":170 + * elif mode == 'fortran': + * order = b'F' + * self.mode = u'fortran' # <<<<<<<<<<<<<< + * else: + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" + */ + __Pyx_INCREF(__pyx_n_u_fortran); + __Pyx_GIVEREF(__pyx_n_u_fortran); + __Pyx_GOTREF(__pyx_v_self->mode); + __Pyx_DECREF(__pyx_v_self->mode); + __pyx_v_self->mode = __pyx_n_u_fortran; + + /* "View.MemoryView":168 + * order = b'C' + * self.mode = u'c' + * elif mode == 'fortran': # <<<<<<<<<<<<<< + * order = b'F' + * self.mode = u'fortran' + */ + goto __pyx_L11; + } + + /* "View.MemoryView":172 + * self.mode = u'fortran' + * else: + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" # <<<<<<<<<<<<<< + * + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) + */ + /*else*/ { + __pyx_t_4 = __Pyx_PyObject_FormatSimple(__pyx_v_mode, __pyx_empty_unicode); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 172, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_6 = __Pyx_PyUnicode_Concat(__pyx_kp_u_Invalid_mode_expected_c_or_fortr, __pyx_t_4); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 172, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_6, 0, 0); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __PYX_ERR(1, 172, __pyx_L1_error) + } + __pyx_L11:; + + /* "View.MemoryView":174 + * raise ValueError, f"Invalid mode, expected 'c' or 'fortran', got {mode}" + * + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) # <<<<<<<<<<<<<< + * + * self.free_data = allocate_buffer + */ + __pyx_v_self->len = __pyx_fill_contig_strides_array(__pyx_v_self->_shape, __pyx_v_self->_strides, __pyx_v_itemsize, __pyx_v_self->ndim, __pyx_v_order); + + /* "View.MemoryView":176 + * self.len = fill_contig_strides_array(self._shape, self._strides, itemsize, self.ndim, order) + * + * self.free_data = allocate_buffer # <<<<<<<<<<<<<< + * self.dtype_is_object = format == b'O' + * + */ + __pyx_v_self->free_data = __pyx_v_allocate_buffer; + + /* "View.MemoryView":177 + * + * self.free_data = allocate_buffer + * self.dtype_is_object = format == b'O' # <<<<<<<<<<<<<< + * + * if allocate_buffer: + */ + __pyx_t_6 = PyObject_RichCompare(__pyx_v_format, __pyx_n_b_O, Py_EQ); __Pyx_XGOTREF(__pyx_t_6); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 177, __pyx_L1_error) + __pyx_t_3 = __Pyx_PyObject_IsTrue(__pyx_t_6); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 177, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __pyx_v_self->dtype_is_object = __pyx_t_3; + + /* "View.MemoryView":179 + * self.dtype_is_object = format == b'O' + * + * if allocate_buffer: # <<<<<<<<<<<<<< + * _allocate_buffer(self) + * + */ + if (__pyx_v_allocate_buffer) { + + /* "View.MemoryView":180 + * + * if allocate_buffer: + * _allocate_buffer(self) # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + __pyx_t_9 = __pyx_array_allocate_buffer(__pyx_v_self); if (unlikely(__pyx_t_9 == ((int)-1))) __PYX_ERR(1, 180, __pyx_L1_error) + + /* "View.MemoryView":179 + * self.dtype_is_object = format == b'O' + * + * if allocate_buffer: # <<<<<<<<<<<<<< + * _allocate_buffer(self) + * + */ + } + + /* "View.MemoryView":131 + * cdef bint dtype_is_object + * + * def __cinit__(array self, tuple shape, Py_ssize_t itemsize, format not None, # <<<<<<<<<<<<<< + * mode="c", bint allocate_buffer=True): + * + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.array.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_format); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":182 + * _allocate_buffer(self) + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + */ + +/* Python wrapper */ +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +CYTHON_UNUSED static int __pyx_array_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getbuffer__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(((struct __pyx_array_obj *)__pyx_v_self), ((Py_buffer *)__pyx_v_info), ((int)__pyx_v_flags)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_2__getbuffer__(struct __pyx_array_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + int __pyx_v_bufmode; + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + char *__pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + Py_ssize_t *__pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + if (unlikely(__pyx_v_info == NULL)) { + PyErr_SetString(PyExc_BufferError, "PyObject_GetBuffer: view==NULL argument is obsolete"); + return -1; + } + __Pyx_RefNannySetupContext("__getbuffer__", 0); + __pyx_v_info->obj = Py_None; __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(__pyx_v_info->obj); + + /* "View.MemoryView":184 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 # <<<<<<<<<<<<<< + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": + */ + __pyx_v_bufmode = -1; + + /* "View.MemoryView":185 + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): # <<<<<<<<<<<<<< + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + __pyx_t_1 = ((__pyx_v_flags & ((PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS) | PyBUF_ANY_CONTIGUOUS)) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":186 + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": # <<<<<<<<<<<<<< + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + */ + __pyx_t_1 = (__Pyx_PyUnicode_Equals(__pyx_v_self->mode, __pyx_n_u_c, Py_EQ)); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 186, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":187 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS # <<<<<<<<<<<<<< + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + __pyx_v_bufmode = (PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS); + + /* "View.MemoryView":186 + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): + * if self.mode == u"c": # <<<<<<<<<<<<<< + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + */ + goto __pyx_L4; + } + + /* "View.MemoryView":188 + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": # <<<<<<<<<<<<<< + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + */ + __pyx_t_1 = (__Pyx_PyUnicode_Equals(__pyx_v_self->mode, __pyx_n_u_fortran, Py_EQ)); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 188, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":189 + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS # <<<<<<<<<<<<<< + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." + */ + __pyx_v_bufmode = (PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS); + + /* "View.MemoryView":188 + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * elif self.mode == u"fortran": # <<<<<<<<<<<<<< + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + */ + } + __pyx_L4:; + + /* "View.MemoryView":190 + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): # <<<<<<<<<<<<<< + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + */ + __pyx_t_1 = (!((__pyx_v_flags & __pyx_v_bufmode) != 0)); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":191 + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." # <<<<<<<<<<<<<< + * info.buf = self.data + * info.len = self.len + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Can_only_create_a_buffer_that_is, 0, 0); + __PYX_ERR(1, 191, __pyx_L1_error) + + /* "View.MemoryView":190 + * elif self.mode == u"fortran": + * bufmode = PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + * if not (flags & bufmode): # <<<<<<<<<<<<<< + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + */ + } + + /* "View.MemoryView":185 + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + * if flags & (PyBUF_C_CONTIGUOUS | PyBUF_F_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS): # <<<<<<<<<<<<<< + * if self.mode == u"c": + * bufmode = PyBUF_C_CONTIGUOUS | PyBUF_ANY_CONTIGUOUS + */ + } + + /* "View.MemoryView":192 + * if not (flags & bufmode): + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data # <<<<<<<<<<<<<< + * info.len = self.len + * + */ + __pyx_t_2 = __pyx_v_self->data; + __pyx_v_info->buf = __pyx_t_2; + + /* "View.MemoryView":193 + * raise ValueError, "Can only create a buffer that is contiguous in memory." + * info.buf = self.data + * info.len = self.len # <<<<<<<<<<<<<< + * + * if flags & PyBUF_STRIDES: + */ + __pyx_t_3 = __pyx_v_self->len; + __pyx_v_info->len = __pyx_t_3; + + /* "View.MemoryView":195 + * info.len = self.len + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.ndim = self.ndim + * info.shape = self._shape + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_STRIDES) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":196 + * + * if flags & PyBUF_STRIDES: + * info.ndim = self.ndim # <<<<<<<<<<<<<< + * info.shape = self._shape + * info.strides = self._strides + */ + __pyx_t_4 = __pyx_v_self->ndim; + __pyx_v_info->ndim = __pyx_t_4; + + /* "View.MemoryView":197 + * if flags & PyBUF_STRIDES: + * info.ndim = self.ndim + * info.shape = self._shape # <<<<<<<<<<<<<< + * info.strides = self._strides + * else: + */ + __pyx_t_5 = __pyx_v_self->_shape; + __pyx_v_info->shape = __pyx_t_5; + + /* "View.MemoryView":198 + * info.ndim = self.ndim + * info.shape = self._shape + * info.strides = self._strides # <<<<<<<<<<<<<< + * else: + * info.ndim = 1 + */ + __pyx_t_5 = __pyx_v_self->_strides; + __pyx_v_info->strides = __pyx_t_5; + + /* "View.MemoryView":195 + * info.len = self.len + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.ndim = self.ndim + * info.shape = self._shape + */ + goto __pyx_L6; + } + + /* "View.MemoryView":200 + * info.strides = self._strides + * else: + * info.ndim = 1 # <<<<<<<<<<<<<< + * info.shape = &self.len if flags & PyBUF_ND else NULL + * info.strides = NULL + */ + /*else*/ { + __pyx_v_info->ndim = 1; + + /* "View.MemoryView":201 + * else: + * info.ndim = 1 + * info.shape = &self.len if flags & PyBUF_ND else NULL # <<<<<<<<<<<<<< + * info.strides = NULL + * + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_ND) != 0); + if (__pyx_t_1) { + __pyx_t_5 = (&__pyx_v_self->len); + } else { + __pyx_t_5 = NULL; + } + __pyx_v_info->shape = __pyx_t_5; + + /* "View.MemoryView":202 + * info.ndim = 1 + * info.shape = &self.len if flags & PyBUF_ND else NULL + * info.strides = NULL # <<<<<<<<<<<<<< + * + * info.suboffsets = NULL + */ + __pyx_v_info->strides = NULL; + } + __pyx_L6:; + + /* "View.MemoryView":204 + * info.strides = NULL + * + * info.suboffsets = NULL # <<<<<<<<<<<<<< + * info.itemsize = self.itemsize + * info.readonly = 0 + */ + __pyx_v_info->suboffsets = NULL; + + /* "View.MemoryView":205 + * + * info.suboffsets = NULL + * info.itemsize = self.itemsize # <<<<<<<<<<<<<< + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL + */ + __pyx_t_3 = __pyx_v_self->itemsize; + __pyx_v_info->itemsize = __pyx_t_3; + + /* "View.MemoryView":206 + * info.suboffsets = NULL + * info.itemsize = self.itemsize + * info.readonly = 0 # <<<<<<<<<<<<<< + * info.format = self.format if flags & PyBUF_FORMAT else NULL + * info.obj = self + */ + __pyx_v_info->readonly = 0; + + /* "View.MemoryView":207 + * info.itemsize = self.itemsize + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL # <<<<<<<<<<<<<< + * info.obj = self + * + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + __pyx_t_2 = __pyx_v_self->format; + } else { + __pyx_t_2 = NULL; + } + __pyx_v_info->format = __pyx_t_2; + + /* "View.MemoryView":208 + * info.readonly = 0 + * info.format = self.format if flags & PyBUF_FORMAT else NULL + * info.obj = self # <<<<<<<<<<<<<< + * + * def __dealloc__(array self): + */ + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); + __pyx_v_info->obj = ((PyObject *)__pyx_v_self); + + /* "View.MemoryView":182 + * _allocate_buffer(self) + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * cdef int bufmode = -1 + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__getbuffer__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + if (__pyx_v_info->obj != NULL) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + goto __pyx_L2; + __pyx_L0:; + if (__pyx_v_info->obj == Py_None) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + __pyx_L2:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":210 + * info.obj = self + * + * def __dealloc__(array self): # <<<<<<<<<<<<<< + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + */ + +/* Python wrapper */ +static void __pyx_array___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_array___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_array___pyx_pf_15View_dot_MemoryView_5array_4__dealloc__(struct __pyx_array_obj *__pyx_v_self) { + int __pyx_t_1; + int __pyx_t_2; + + /* "View.MemoryView":211 + * + * def __dealloc__(array self): + * if self.callback_free_data != NULL: # <<<<<<<<<<<<<< + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + */ + __pyx_t_1 = (__pyx_v_self->callback_free_data != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":212 + * def __dealloc__(array self): + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) # <<<<<<<<<<<<<< + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: + */ + __pyx_v_self->callback_free_data(__pyx_v_self->data); + + /* "View.MemoryView":211 + * + * def __dealloc__(array self): + * if self.callback_free_data != NULL: # <<<<<<<<<<<<<< + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":213 + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + */ + if (__pyx_v_self->free_data) { + } else { + __pyx_t_1 = __pyx_v_self->free_data; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_self->data != NULL); + __pyx_t_1 = __pyx_t_2; + __pyx_L4_bool_binop_done:; + if (__pyx_t_1) { + + /* "View.MemoryView":214 + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":215 + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) # <<<<<<<<<<<<<< + * free(self.data) + * PyObject_Free(self._shape) + */ + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_self->data, __pyx_v_self->_shape, __pyx_v_self->_strides, __pyx_v_self->ndim, 0); + + /* "View.MemoryView":214 + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + */ + } + + /* "View.MemoryView":216 + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) # <<<<<<<<<<<<<< + * PyObject_Free(self._shape) + * + */ + free(__pyx_v_self->data); + + /* "View.MemoryView":213 + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + * elif self.free_data and self.data is not NULL: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + */ + } + __pyx_L3:; + + /* "View.MemoryView":217 + * refcount_objects_in_slice(self.data, self._shape, self._strides, self.ndim, inc=False) + * free(self.data) + * PyObject_Free(self._shape) # <<<<<<<<<<<<<< + * + * @property + */ + PyObject_Free(__pyx_v_self->_shape); + + /* "View.MemoryView":210 + * info.obj = self + * + * def __dealloc__(array self): # <<<<<<<<<<<<<< + * if self.callback_free_data != NULL: + * self.callback_free_data(self.data) + */ + + /* function exit code */ +} + +/* "View.MemoryView":219 + * PyObject_Free(self._shape) + * + * @property # <<<<<<<<<<<<<< + * def memview(self): + * return self.get_memview() + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_5array_7memview___get__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_5array_7memview___get__(struct __pyx_array_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":221 + * @property + * def memview(self): + * return self.get_memview() # <<<<<<<<<<<<<< + * + * @cname('get_memview') + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((struct __pyx_vtabstruct_array *)__pyx_v_self->__pyx_vtab)->get_memview(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 221, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":219 + * PyObject_Free(self._shape) + * + * @property # <<<<<<<<<<<<<< + * def memview(self): + * return self.get_memview() + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.array.memview.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":224 + * + * @cname('get_memview') + * cdef get_memview(self): # <<<<<<<<<<<<<< + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) + */ + +static PyObject *__pyx_array_get_memview(struct __pyx_array_obj *__pyx_v_self) { + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_memview", 1); + + /* "View.MemoryView":225 + * @cname('get_memview') + * cdef get_memview(self): + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE # <<<<<<<<<<<<<< + * return memoryview(self, flags, self.dtype_is_object) + * + */ + __pyx_v_flags = ((PyBUF_ANY_CONTIGUOUS | PyBUF_FORMAT) | PyBUF_WRITABLE); + + /* "View.MemoryView":226 + * cdef get_memview(self): + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) # <<<<<<<<<<<<<< + * + * def __len__(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_flags); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_self->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, ((PyObject *)__pyx_v_self))) __PYX_ERR(1, 226, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 226, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_3, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 226, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":224 + * + * @cname('get_memview') + * cdef get_memview(self): # <<<<<<<<<<<<<< + * flags = PyBUF_ANY_CONTIGUOUS|PyBUF_FORMAT|PyBUF_WRITABLE + * return memoryview(self, flags, self.dtype_is_object) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.array.get_memview", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":228 + * return memoryview(self, flags, self.dtype_is_object) + * + * def __len__(self): # <<<<<<<<<<<<<< + * return self._shape[0] + * + */ + +/* Python wrapper */ +static Py_ssize_t __pyx_array___len__(PyObject *__pyx_v_self); /*proto*/ +static Py_ssize_t __pyx_array___len__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + Py_ssize_t __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static Py_ssize_t __pyx_array___pyx_pf_15View_dot_MemoryView_5array_6__len__(struct __pyx_array_obj *__pyx_v_self) { + Py_ssize_t __pyx_r; + + /* "View.MemoryView":229 + * + * def __len__(self): + * return self._shape[0] # <<<<<<<<<<<<<< + * + * def __getattr__(self, attr): + */ + __pyx_r = (__pyx_v_self->_shape[0]); + goto __pyx_L0; + + /* "View.MemoryView":228 + * return memoryview(self, flags, self.dtype_is_object) + * + * def __len__(self): # <<<<<<<<<<<<<< + * return self._shape[0] + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":231 + * return self._shape[0] + * + * def __getattr__(self, attr): # <<<<<<<<<<<<<< + * return getattr(self.memview, attr) + * + */ + +/* Python wrapper */ +static PyObject *__pyx_array___getattr__(PyObject *__pyx_v_self, PyObject *__pyx_v_attr); /*proto*/ +static PyObject *__pyx_array___getattr__(PyObject *__pyx_v_self, PyObject *__pyx_v_attr) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getattr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_attr)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_8__getattr__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_attr) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getattr__", 1); + + /* "View.MemoryView":232 + * + * def __getattr__(self, attr): + * return getattr(self.memview, attr) # <<<<<<<<<<<<<< + * + * def __getitem__(self, item): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 232, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_GetAttr(__pyx_t_1, __pyx_v_attr); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 232, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":231 + * return self._shape[0] + * + * def __getattr__(self, attr): # <<<<<<<<<<<<<< + * return getattr(self.memview, attr) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.array.__getattr__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":234 + * return getattr(self.memview, attr) + * + * def __getitem__(self, item): # <<<<<<<<<<<<<< + * return self.memview[item] + * + */ + +/* Python wrapper */ +static PyObject *__pyx_array___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item); /*proto*/ +static PyObject *__pyx_array___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_item)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_array___pyx_pf_15View_dot_MemoryView_5array_10__getitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getitem__", 1); + + /* "View.MemoryView":235 + * + * def __getitem__(self, item): + * return self.memview[item] # <<<<<<<<<<<<<< + * + * def __setitem__(self, item, value): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 235, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetItem(__pyx_t_1, __pyx_v_item); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 235, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":234 + * return getattr(self.memview, attr) + * + * def __getitem__(self, item): # <<<<<<<<<<<<<< + * return self.memview[item] + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.array.__getitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":237 + * return self.memview[item] + * + * def __setitem__(self, item, value): # <<<<<<<<<<<<<< + * self.memview[item] = value + * + */ + +/* Python wrapper */ +static int __pyx_array___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value); /*proto*/ +static int __pyx_array___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(((struct __pyx_array_obj *)__pyx_v_self), ((PyObject *)__pyx_v_item), ((PyObject *)__pyx_v_value)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_array___pyx_pf_15View_dot_MemoryView_5array_12__setitem__(struct __pyx_array_obj *__pyx_v_self, PyObject *__pyx_v_item, PyObject *__pyx_v_value) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setitem__", 1); + + /* "View.MemoryView":238 + * + * def __setitem__(self, item, value): + * self.memview[item] = value # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_memview); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 238, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (unlikely((PyObject_SetItem(__pyx_t_1, __pyx_v_item, __pyx_v_value) < 0))) __PYX_ERR(1, 238, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "View.MemoryView":237 + * return self.memview[item] + * + * def __setitem__(self, item, value): # <<<<<<<<<<<<<< + * self.memview[item] = value + * + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.array.__setitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_array_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_array_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_array___reduce_cython__(((struct __pyx_array_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_array___reduce_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_array_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_array_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.array.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_array_2__setstate_cython__(((struct __pyx_array_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_array_2__setstate_cython__(CYTHON_UNUSED struct __pyx_array_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.array.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":248 + * + * @cname("__pyx_array_allocate_buffer") + * cdef int _allocate_buffer(array self) except -1: # <<<<<<<<<<<<<< + * + * + */ + +static int __pyx_array_allocate_buffer(struct __pyx_array_obj *__pyx_v_self) { + Py_ssize_t __pyx_v_i; + PyObject **__pyx_v_p; + int __pyx_r; + int __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "View.MemoryView":254 + * cdef PyObject **p + * + * self.free_data = True # <<<<<<<<<<<<<< + * self.data = malloc(self.len) + * if not self.data: + */ + __pyx_v_self->free_data = 1; + + /* "View.MemoryView":255 + * + * self.free_data = True + * self.data = malloc(self.len) # <<<<<<<<<<<<<< + * if not self.data: + * raise MemoryError, "unable to allocate array data." + */ + __pyx_v_self->data = ((char *)malloc(__pyx_v_self->len)); + + /* "View.MemoryView":256 + * self.free_data = True + * self.data = malloc(self.len) + * if not self.data: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate array data." + * + */ + __pyx_t_1 = (!(__pyx_v_self->data != 0)); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":257 + * self.data = malloc(self.len) + * if not self.data: + * raise MemoryError, "unable to allocate array data." # <<<<<<<<<<<<<< + * + * if self.dtype_is_object: + */ + __Pyx_Raise(__pyx_builtin_MemoryError, __pyx_kp_s_unable_to_allocate_array_data, 0, 0); + __PYX_ERR(1, 257, __pyx_L1_error) + + /* "View.MemoryView":256 + * self.free_data = True + * self.data = malloc(self.len) + * if not self.data: # <<<<<<<<<<<<<< + * raise MemoryError, "unable to allocate array data." + * + */ + } + + /* "View.MemoryView":259 + * raise MemoryError, "unable to allocate array data." + * + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * p = self.data + * for i in range(self.len // self.itemsize): + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":260 + * + * if self.dtype_is_object: + * p = self.data # <<<<<<<<<<<<<< + * for i in range(self.len // self.itemsize): + * p[i] = Py_None + */ + __pyx_v_p = ((PyObject **)__pyx_v_self->data); + + /* "View.MemoryView":261 + * if self.dtype_is_object: + * p = self.data + * for i in range(self.len // self.itemsize): # <<<<<<<<<<<<<< + * p[i] = Py_None + * Py_INCREF(Py_None) + */ + if (unlikely(__pyx_v_self->itemsize == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 261, __pyx_L1_error) + } + else if (sizeof(Py_ssize_t) == sizeof(long) && (!(((Py_ssize_t)-1) > 0)) && unlikely(__pyx_v_self->itemsize == (Py_ssize_t)-1) && unlikely(__Pyx_UNARY_NEG_WOULD_OVERFLOW(__pyx_v_self->len))) { + PyErr_SetString(PyExc_OverflowError, "value too large to perform division"); + __PYX_ERR(1, 261, __pyx_L1_error) + } + __pyx_t_2 = __Pyx_div_Py_ssize_t(__pyx_v_self->len, __pyx_v_self->itemsize); + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":262 + * p = self.data + * for i in range(self.len // self.itemsize): + * p[i] = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * return 0 + */ + (__pyx_v_p[__pyx_v_i]) = Py_None; + + /* "View.MemoryView":263 + * for i in range(self.len // self.itemsize): + * p[i] = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * return 0 + * + */ + Py_INCREF(Py_None); + } + + /* "View.MemoryView":259 + * raise MemoryError, "unable to allocate array data." + * + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * p = self.data + * for i in range(self.len // self.itemsize): + */ + } + + /* "View.MemoryView":264 + * p[i] = Py_None + * Py_INCREF(Py_None) + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":248 + * + * @cname("__pyx_array_allocate_buffer") + * cdef int _allocate_buffer(array self) except -1: # <<<<<<<<<<<<<< + * + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._allocate_buffer", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":268 + * + * @cname("__pyx_array_new") + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): # <<<<<<<<<<<<<< + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + */ + +static struct __pyx_array_obj *__pyx_array_new(PyObject *__pyx_v_shape, Py_ssize_t __pyx_v_itemsize, char *__pyx_v_format, char *__pyx_v_c_mode, char *__pyx_v_buf) { + struct __pyx_array_obj *__pyx_v_result = 0; + PyObject *__pyx_v_mode = 0; + struct __pyx_array_obj *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("array_cwrapper", 1); + + /* "View.MemoryView":270 + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. # <<<<<<<<<<<<<< + * + * if buf is NULL: + */ + __pyx_t_2 = ((__pyx_v_c_mode[0]) == 'f'); + if (__pyx_t_2) { + __Pyx_INCREF(__pyx_n_s_fortran); + __pyx_t_1 = __pyx_n_s_fortran; + } else { + __Pyx_INCREF(__pyx_n_s_c); + __pyx_t_1 = __pyx_n_s_c; + } + __pyx_v_mode = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":272 + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + * + * if buf is NULL: # <<<<<<<<<<<<<< + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + */ + __pyx_t_2 = (__pyx_v_buf == NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":273 + * + * if buf is NULL: + * result = array.__new__(array, shape, itemsize, format, mode) # <<<<<<<<<<<<<< + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) + */ + __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_itemsize); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_format); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PyTuple_New(4); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_INCREF(__pyx_v_shape); + __Pyx_GIVEREF(__pyx_v_shape); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_v_shape)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_mode); + __Pyx_GIVEREF(__pyx_v_mode); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 3, __pyx_v_mode)) __PYX_ERR(1, 273, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_3 = 0; + __pyx_t_3 = ((PyObject *)__pyx_tp_new_array(((PyTypeObject *)__pyx_array_type), __pyx_t_4, NULL)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 273, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_3); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_v_result = ((struct __pyx_array_obj *)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":272 + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + * + * if buf is NULL: # <<<<<<<<<<<<<< + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":275 + * result = array.__new__(array, shape, itemsize, format, mode) + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) # <<<<<<<<<<<<<< + * result.data = buf + * + */ + /*else*/ { + __pyx_t_3 = PyInt_FromSsize_t(__pyx_v_itemsize); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = __Pyx_PyBytes_FromString(__pyx_v_format); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_1 = PyTuple_New(4); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v_shape); + __Pyx_GIVEREF(__pyx_v_shape); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v_shape)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_mode); + __Pyx_GIVEREF(__pyx_v_mode); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 3, __pyx_v_mode)) __PYX_ERR(1, 275, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyDict_NewPresized(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + if (PyDict_SetItem(__pyx_t_4, __pyx_n_s_allocate_buffer, Py_False) < 0) __PYX_ERR(1, 275, __pyx_L1_error) + __pyx_t_3 = ((PyObject *)__pyx_tp_new_array(((PyTypeObject *)__pyx_array_type), __pyx_t_1, __pyx_t_4)); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 275, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_v_result = ((struct __pyx_array_obj *)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":276 + * else: + * result = array.__new__(array, shape, itemsize, format, mode, allocate_buffer=False) + * result.data = buf # <<<<<<<<<<<<<< + * + * return result + */ + __pyx_v_result->data = __pyx_v_buf; + } + __pyx_L3:; + + /* "View.MemoryView":278 + * result.data = buf + * + * return result # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = __pyx_v_result; + goto __pyx_L0; + + /* "View.MemoryView":268 + * + * @cname("__pyx_array_new") + * cdef array array_cwrapper(tuple shape, Py_ssize_t itemsize, char *format, char *c_mode, char *buf): # <<<<<<<<<<<<<< + * cdef array result + * cdef str mode = "fortran" if c_mode[0] == b'f' else "c" # this often comes from a constant C string. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.array_cwrapper", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XDECREF(__pyx_v_mode); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":304 + * cdef class Enum(object): + * cdef object name + * def __init__(self, name): # <<<<<<<<<<<<<< + * self.name = name + * def __repr__(self): + */ + +/* Python wrapper */ +static int __pyx_MemviewEnum___init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_MemviewEnum___init__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_name = 0; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_name,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_name)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 304, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__init__") < 0)) __PYX_ERR(1, 304, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + } + __pyx_v_name = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__init__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 304, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.Enum.__init__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self), __pyx_v_name); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum___init__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v_name) { + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__init__", 1); + + /* "View.MemoryView":305 + * cdef object name + * def __init__(self, name): + * self.name = name # <<<<<<<<<<<<<< + * def __repr__(self): + * return self.name + */ + __Pyx_INCREF(__pyx_v_name); + __Pyx_GIVEREF(__pyx_v_name); + __Pyx_GOTREF(__pyx_v_self->name); + __Pyx_DECREF(__pyx_v_self->name); + __pyx_v_self->name = __pyx_v_name; + + /* "View.MemoryView":304 + * cdef class Enum(object): + * cdef object name + * def __init__(self, name): # <<<<<<<<<<<<<< + * self.name = name + * def __repr__(self): + */ + + /* function exit code */ + __pyx_r = 0; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":306 + * def __init__(self, name): + * self.name = name + * def __repr__(self): # <<<<<<<<<<<<<< + * return self.name + * + */ + +/* Python wrapper */ +static PyObject *__pyx_MemviewEnum___repr__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_MemviewEnum___repr__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_MemviewEnum___pyx_pf_15View_dot_MemoryView_4Enum_2__repr__(struct __pyx_MemviewEnum_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__", 1); + + /* "View.MemoryView":307 + * self.name = name + * def __repr__(self): + * return self.name # <<<<<<<<<<<<<< + * + * cdef generic = Enum("") + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->name); + __pyx_r = __pyx_v_self->name; + goto __pyx_L0; + + /* "View.MemoryView":306 + * def __init__(self, name): + * self.name = name + * def __repr__(self): # <<<<<<<<<<<<<< + * return self.name + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_MemviewEnum_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_MemviewEnum_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_MemviewEnum___reduce_cython__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_MemviewEnum___reduce_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self) { + PyObject *__pyx_v_state = 0; + PyObject *__pyx_v__dict = 0; + int __pyx_v_use_setstate; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":5 + * cdef object _dict + * cdef bint use_setstate + * state = (self.name,) # <<<<<<<<<<<<<< + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + */ + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v_self->name); + __Pyx_GIVEREF(__pyx_v_self->name); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v_self->name)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_v_state = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "(tree fragment)":6 + * cdef bint use_setstate + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) # <<<<<<<<<<<<<< + * if _dict is not None: + * state += (_dict,) + */ + __pyx_t_1 = __Pyx_GetAttr3(((PyObject *)__pyx_v_self), __pyx_n_s_dict, Py_None); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v__dict = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":7 + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + __pyx_t_2 = (__pyx_v__dict != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":8 + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: + * state += (_dict,) # <<<<<<<<<<<<<< + * use_setstate = True + * else: + */ + __pyx_t_1 = PyTuple_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_v__dict); + __Pyx_GIVEREF(__pyx_v__dict); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_v__dict)) __PYX_ERR(1, 8, __pyx_L1_error); + __pyx_t_3 = PyNumber_InPlaceAdd(__pyx_v_state, __pyx_t_1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 8, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF_SET(__pyx_v_state, ((PyObject*)__pyx_t_3)); + __pyx_t_3 = 0; + + /* "(tree fragment)":9 + * if _dict is not None: + * state += (_dict,) + * use_setstate = True # <<<<<<<<<<<<<< + * else: + * use_setstate = self.name is not None + */ + __pyx_v_use_setstate = 1; + + /* "(tree fragment)":7 + * state = (self.name,) + * _dict = getattr(self, '__dict__', None) + * if _dict is not None: # <<<<<<<<<<<<<< + * state += (_dict,) + * use_setstate = True + */ + goto __pyx_L3; + } + + /* "(tree fragment)":11 + * use_setstate = True + * else: + * use_setstate = self.name is not None # <<<<<<<<<<<<<< + * if use_setstate: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + */ + /*else*/ { + __pyx_t_2 = (__pyx_v_self->name != Py_None); + __pyx_v_use_setstate = __pyx_t_2; + } + __pyx_L3:; + + /* "(tree fragment)":12 + * else: + * use_setstate = self.name is not None + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + */ + if (__pyx_v_use_setstate) { + + /* "(tree fragment)":13 + * use_setstate = self.name is not None + * if use_setstate: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state # <<<<<<<<<<<<<< + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_3, __pyx_n_s_pyx_unpickle_Enum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_136983863); + __Pyx_GIVEREF(__pyx_int_136983863); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_136983863)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, Py_None)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_4 = PyTuple_New(3); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 13, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 1, __pyx_t_1)) __PYX_ERR(1, 13, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 2, __pyx_v_state)) __PYX_ERR(1, 13, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_1 = 0; + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + + /* "(tree fragment)":12 + * else: + * use_setstate = self.name is not None + * if use_setstate: # <<<<<<<<<<<<<< + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + */ + } + + /* "(tree fragment)":15 + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, None), state + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_GetModuleGlobalName(__pyx_t_4, __pyx_n_s_pyx_unpickle_Enum); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_1 = PyTuple_New(3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + __Pyx_GIVEREF(((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self)))); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 0, ((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))))) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_136983863); + __Pyx_GIVEREF(__pyx_int_136983863); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_int_136983863)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_INCREF(__pyx_v_state); + __Pyx_GIVEREF(__pyx_v_state); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_v_state)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 15, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_4); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_4)) __PYX_ERR(1, 15, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 15, __pyx_L1_error); + __pyx_t_4 = 0; + __pyx_t_1 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + } + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * cdef tuple state + * cdef object _dict + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.Enum.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_state); + __Pyx_XDECREF(__pyx_v__dict); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":16 + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_MemviewEnum_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_MemviewEnum_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 16, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 16, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 16, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.Enum.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_MemviewEnum_2__setstate_cython__(((struct __pyx_MemviewEnum_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_MemviewEnum_2__setstate_cython__(struct __pyx_MemviewEnum_obj *__pyx_v_self, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":17 + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): + * __pyx_unpickle_Enum__set_state(self, __pyx_state) # <<<<<<<<<<<<<< + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 17, __pyx_L1_error) + __pyx_t_1 = __pyx_unpickle_Enum__set_state(__pyx_v_self, ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 17, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":16 + * else: + * return __pyx_unpickle_Enum, (type(self), 0x82a3537, state) + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state(self, __pyx_state) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.Enum.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":349 + * cdef __Pyx_TypeInfo *typeinfo + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): # <<<<<<<<<<<<<< + * self.obj = obj + * self.flags = flags + */ + +/* Python wrapper */ +static int __pyx_memoryview___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/ +static int __pyx_memoryview___cinit__(PyObject *__pyx_v_self, PyObject *__pyx_args, PyObject *__pyx_kwds) { + PyObject *__pyx_v_obj = 0; + int __pyx_v_flags; + int __pyx_v_dtype_is_object; + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__cinit__ (wrapper)", 0); + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return -1; + #endif + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_obj,&__pyx_n_s_flags,&__pyx_n_s_dtype_is_object,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_VARARGS(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_obj)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_flags)) != 0)) { + (void)__Pyx_Arg_NewRef_VARARGS(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 2, 3, 1); __PYX_ERR(1, 349, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_VARARGS(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_dtype_is_object); + if (value) { values[2] = __Pyx_Arg_NewRef_VARARGS(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__cinit__") < 0)) __PYX_ERR(1, 349, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_VARARGS(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_VARARGS(__pyx_args, 1); + values[0] = __Pyx_Arg_VARARGS(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_obj = values[0]; + __pyx_v_flags = __Pyx_PyInt_As_int(values[1]); if (unlikely((__pyx_v_flags == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + if (values[2]) { + __pyx_v_dtype_is_object = __Pyx_PyObject_IsTrue(values[2]); if (unlikely((__pyx_v_dtype_is_object == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 349, __pyx_L3_error) + } else { + __pyx_v_dtype_is_object = ((int)0); + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__cinit__", 0, 2, 3, __pyx_nargs); __PYX_ERR(1, 349, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.memoryview.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return -1; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_obj, __pyx_v_flags, __pyx_v_dtype_is_object); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_VARARGS(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview___cinit__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj, int __pyx_v_flags, int __pyx_v_dtype_is_object) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + Py_intptr_t __pyx_t_4; + size_t __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__cinit__", 1); + + /* "View.MemoryView":350 + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): + * self.obj = obj # <<<<<<<<<<<<<< + * self.flags = flags + * if type(self) is memoryview or obj is not None: + */ + __Pyx_INCREF(__pyx_v_obj); + __Pyx_GIVEREF(__pyx_v_obj); + __Pyx_GOTREF(__pyx_v_self->obj); + __Pyx_DECREF(__pyx_v_self->obj); + __pyx_v_self->obj = __pyx_v_obj; + + /* "View.MemoryView":351 + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): + * self.obj = obj + * self.flags = flags # <<<<<<<<<<<<<< + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + */ + __pyx_v_self->flags = __pyx_v_flags; + + /* "View.MemoryView":352 + * self.obj = obj + * self.flags = flags + * if type(self) is memoryview or obj is not None: # <<<<<<<<<<<<<< + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + */ + __pyx_t_2 = (((PyObject *)Py_TYPE(((PyObject *)__pyx_v_self))) == ((PyObject *)__pyx_memoryview_type)); + if (!__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_obj != Py_None); + __pyx_t_1 = __pyx_t_2; + __pyx_L4_bool_binop_done:; + if (__pyx_t_1) { + + /* "View.MemoryView":353 + * self.flags = flags + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) # <<<<<<<<<<<<<< + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None + */ + __pyx_t_3 = __Pyx_GetBuffer(__pyx_v_obj, (&__pyx_v_self->view), __pyx_v_flags); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 353, __pyx_L1_error) + + /* "View.MemoryView":354 + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) + */ + __pyx_t_1 = (((PyObject *)__pyx_v_self->view.obj) == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":355 + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_self->view))->obj = Py_None; + + /* "View.MemoryView":356 + * if self.view.obj == NULL: + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + */ + Py_INCREF(Py_None); + + /* "View.MemoryView":354 + * if type(self) is memoryview or obj is not None: + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &self.view).obj = Py_None + * Py_INCREF(Py_None) + */ + } + + /* "View.MemoryView":352 + * self.obj = obj + * self.flags = flags + * if type(self) is memoryview or obj is not None: # <<<<<<<<<<<<<< + * __Pyx_GetBuffer(obj, &self.view, flags) + * if self.view.obj == NULL: + */ + } + + /* "View.MemoryView":358 + * Py_INCREF(Py_None) + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): # <<<<<<<<<<<<<< + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + */ + __pyx_t_1 = (!__PYX_CYTHON_ATOMICS_ENABLED()); + if (__pyx_t_1) { + + /* "View.MemoryView":360 + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: # <<<<<<<<<<<<<< + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + */ + __pyx_t_1 = (__pyx_memoryview_thread_locks_used < 8); + if (__pyx_t_1) { + + /* "View.MemoryView":361 + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: + */ + __pyx_v_self->lock = (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]); + + /* "View.MemoryView":362 + * if __pyx_memoryview_thread_locks_used < 8: + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 # <<<<<<<<<<<<<< + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + */ + __pyx_memoryview_thread_locks_used = (__pyx_memoryview_thread_locks_used + 1); + + /* "View.MemoryView":360 + * if not __PYX_CYTHON_ATOMICS_ENABLED(): + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: # <<<<<<<<<<<<<< + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + */ + } + + /* "View.MemoryView":363 + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: # <<<<<<<<<<<<<< + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + */ + __pyx_t_1 = (__pyx_v_self->lock == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":364 + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() # <<<<<<<<<<<<<< + * if self.lock is NULL: + * raise MemoryError + */ + __pyx_v_self->lock = PyThread_allocate_lock(); + + /* "View.MemoryView":365 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + __pyx_t_1 = (__pyx_v_self->lock == NULL); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":366 + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + * raise MemoryError # <<<<<<<<<<<<<< + * + * if flags & PyBUF_FORMAT: + */ + PyErr_NoMemory(); __PYX_ERR(1, 366, __pyx_L1_error) + + /* "View.MemoryView":365 + * if self.lock is NULL: + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + } + + /* "View.MemoryView":363 + * self.lock = __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] + * __pyx_memoryview_thread_locks_used += 1 + * if self.lock is NULL: # <<<<<<<<<<<<<< + * self.lock = PyThread_allocate_lock() + * if self.lock is NULL: + */ + } + + /* "View.MemoryView":358 + * Py_INCREF(Py_None) + * + * if not __PYX_CYTHON_ATOMICS_ENABLED(): # <<<<<<<<<<<<<< + * global __pyx_memoryview_thread_locks_used + * if __pyx_memoryview_thread_locks_used < 8: + */ + } + + /* "View.MemoryView":368 + * raise MemoryError + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":369 + * + * if flags & PyBUF_FORMAT: + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') # <<<<<<<<<<<<<< + * else: + * self.dtype_is_object = dtype_is_object + */ + __pyx_t_2 = ((__pyx_v_self->view.format[0]) == 'O'); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L12_bool_binop_done; + } + __pyx_t_2 = ((__pyx_v_self->view.format[1]) == '\x00'); + __pyx_t_1 = __pyx_t_2; + __pyx_L12_bool_binop_done:; + __pyx_v_self->dtype_is_object = __pyx_t_1; + + /* "View.MemoryView":368 + * raise MemoryError + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + */ + goto __pyx_L11; + } + + /* "View.MemoryView":371 + * self.dtype_is_object = (self.view.format[0] == b'O' and self.view.format[1] == b'\0') + * else: + * self.dtype_is_object = dtype_is_object # <<<<<<<<<<<<<< + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 + */ + /*else*/ { + __pyx_v_self->dtype_is_object = __pyx_v_dtype_is_object; + } + __pyx_L11:; + + /* "View.MemoryView":373 + * self.dtype_is_object = dtype_is_object + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 # <<<<<<<<<<<<<< + * self.typeinfo = NULL + * + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_4 = ((Py_intptr_t)((void *)(&__pyx_v_self->acquisition_count))); + __pyx_t_5 = (sizeof(__pyx_atomic_int_type)); + if (unlikely(__pyx_t_5 == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 373, __pyx_L1_error) + } + __pyx_t_1 = ((__pyx_t_4 % __pyx_t_5) == 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(1, 373, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(1, 373, __pyx_L1_error) + #endif + + /* "View.MemoryView":374 + * + * assert (&self.acquisition_count) % sizeof(__pyx_atomic_int_type) == 0 + * self.typeinfo = NULL # <<<<<<<<<<<<<< + * + * def __dealloc__(memoryview self): + */ + __pyx_v_self->typeinfo = NULL; + + /* "View.MemoryView":349 + * cdef __Pyx_TypeInfo *typeinfo + * + * def __cinit__(memoryview self, object obj, int flags, bint dtype_is_object=False): # <<<<<<<<<<<<<< + * self.obj = obj + * self.flags = flags + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__cinit__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":376 + * self.typeinfo = NULL + * + * def __dealloc__(memoryview self): # <<<<<<<<<<<<<< + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + */ + +/* Python wrapper */ +static void __pyx_memoryview___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_memoryview___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_2__dealloc__(struct __pyx_memoryview_obj *__pyx_v_self) { + int __pyx_v_i; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + PyThread_type_lock __pyx_t_5; + PyThread_type_lock __pyx_t_6; + + /* "View.MemoryView":377 + * + * def __dealloc__(memoryview self): + * if self.obj is not None: # <<<<<<<<<<<<<< + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + */ + __pyx_t_1 = (__pyx_v_self->obj != Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":378 + * def __dealloc__(memoryview self): + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) # <<<<<<<<<<<<<< + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + * + */ + __Pyx_ReleaseBuffer((&__pyx_v_self->view)); + + /* "View.MemoryView":377 + * + * def __dealloc__(memoryview self): + * if self.obj is not None: # <<<<<<<<<<<<<< + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":379 + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: # <<<<<<<<<<<<<< + * + * (<__pyx_buffer *> &self.view).obj = NULL + */ + __pyx_t_1 = (((Py_buffer *)(&__pyx_v_self->view))->obj == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":381 + * elif (<__pyx_buffer *> &self.view).obj == Py_None: + * + * (<__pyx_buffer *> &self.view).obj = NULL # <<<<<<<<<<<<<< + * Py_DECREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_self->view))->obj = NULL; + + /* "View.MemoryView":382 + * + * (<__pyx_buffer *> &self.view).obj = NULL + * Py_DECREF(Py_None) # <<<<<<<<<<<<<< + * + * cdef int i + */ + Py_DECREF(Py_None); + + /* "View.MemoryView":379 + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + * elif (<__pyx_buffer *> &self.view).obj == Py_None: # <<<<<<<<<<<<<< + * + * (<__pyx_buffer *> &self.view).obj = NULL + */ + } + __pyx_L3:; + + /* "View.MemoryView":386 + * cdef int i + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: # <<<<<<<<<<<<<< + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + */ + __pyx_t_1 = (__pyx_v_self->lock != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":387 + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): # <<<<<<<<<<<<<< + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + */ + __pyx_t_2 = __pyx_memoryview_thread_locks_used; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":388 + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + */ + __pyx_t_1 = ((__pyx_memoryview_thread_locks[__pyx_v_i]) == __pyx_v_self->lock); + if (__pyx_t_1) { + + /* "View.MemoryView":389 + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 # <<<<<<<<<<<<<< + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + */ + __pyx_memoryview_thread_locks_used = (__pyx_memoryview_thread_locks_used - 1); + + /* "View.MemoryView":390 + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + */ + __pyx_t_1 = (__pyx_v_i != __pyx_memoryview_thread_locks_used); + if (__pyx_t_1) { + + /* "View.MemoryView":392 + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) # <<<<<<<<<<<<<< + * break + * else: + */ + __pyx_t_5 = (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]); + __pyx_t_6 = (__pyx_memoryview_thread_locks[__pyx_v_i]); + + /* "View.MemoryView":391 + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + * break + */ + (__pyx_memoryview_thread_locks[__pyx_v_i]) = __pyx_t_5; + (__pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used]) = __pyx_t_6; + + /* "View.MemoryView":390 + * if __pyx_memoryview_thread_locks[i] is self.lock: + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + */ + } + + /* "View.MemoryView":393 + * __pyx_memoryview_thread_locks[i], __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used] = ( + * __pyx_memoryview_thread_locks[__pyx_memoryview_thread_locks_used], __pyx_memoryview_thread_locks[i]) + * break # <<<<<<<<<<<<<< + * else: + * PyThread_free_lock(self.lock) + */ + goto __pyx_L6_break; + + /* "View.MemoryView":388 + * if self.lock != NULL: + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: # <<<<<<<<<<<<<< + * __pyx_memoryview_thread_locks_used -= 1 + * if i != __pyx_memoryview_thread_locks_used: + */ + } + } + /*else*/ { + + /* "View.MemoryView":395 + * break + * else: + * PyThread_free_lock(self.lock) # <<<<<<<<<<<<<< + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: + */ + PyThread_free_lock(__pyx_v_self->lock); + } + __pyx_L6_break:; + + /* "View.MemoryView":386 + * cdef int i + * global __pyx_memoryview_thread_locks_used + * if self.lock != NULL: # <<<<<<<<<<<<<< + * for i in range(__pyx_memoryview_thread_locks_used): + * if __pyx_memoryview_thread_locks[i] is self.lock: + */ + } + + /* "View.MemoryView":376 + * self.typeinfo = NULL + * + * def __dealloc__(memoryview self): # <<<<<<<<<<<<<< + * if self.obj is not None: + * __Pyx_ReleaseBuffer(&self.view) + */ + + /* function exit code */ +} + +/* "View.MemoryView":397 + * PyThread_free_lock(self.lock) + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: # <<<<<<<<<<<<<< + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf + */ + +static char *__pyx_memoryview_get_item_pointer(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index) { + Py_ssize_t __pyx_v_dim; + char *__pyx_v_itemp; + PyObject *__pyx_v_idx = NULL; + char *__pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t __pyx_t_3; + PyObject *(*__pyx_t_4)(PyObject *); + PyObject *__pyx_t_5 = NULL; + Py_ssize_t __pyx_t_6; + char *__pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_item_pointer", 1); + + /* "View.MemoryView":399 + * cdef char *get_item_pointer(memoryview self, object index) except NULL: + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf # <<<<<<<<<<<<<< + * + * for dim, idx in enumerate(index): + */ + __pyx_v_itemp = ((char *)__pyx_v_self->view.buf); + + /* "View.MemoryView":401 + * cdef char *itemp = self.view.buf + * + * for dim, idx in enumerate(index): # <<<<<<<<<<<<<< + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + */ + __pyx_t_1 = 0; + if (likely(PyList_CheckExact(__pyx_v_index)) || PyTuple_CheckExact(__pyx_v_index)) { + __pyx_t_2 = __pyx_v_index; __Pyx_INCREF(__pyx_t_2); + __pyx_t_3 = 0; + __pyx_t_4 = NULL; + } else { + __pyx_t_3 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_index); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 401, __pyx_L1_error) + } + for (;;) { + if (likely(!__pyx_t_4)) { + if (likely(PyList_CheckExact(__pyx_t_2))) { + { + Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #endif + if (__pyx_t_3 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_5); __pyx_t_3++; if (unlikely((0 < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + } else { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #endif + if (__pyx_t_3 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_5 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_3); __Pyx_INCREF(__pyx_t_5); __pyx_t_3++; if (unlikely((0 < 0))) __PYX_ERR(1, 401, __pyx_L1_error) + #else + __pyx_t_5 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_3); __pyx_t_3++; if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 401, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + #endif + } + } else { + __pyx_t_5 = __pyx_t_4(__pyx_t_2); + if (unlikely(!__pyx_t_5)) { + PyObject* exc_type = PyErr_Occurred(); + if (exc_type) { + if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); + else __PYX_ERR(1, 401, __pyx_L1_error) + } + break; + } + __Pyx_GOTREF(__pyx_t_5); + } + __Pyx_XDECREF_SET(__pyx_v_idx, __pyx_t_5); + __pyx_t_5 = 0; + __pyx_v_dim = __pyx_t_1; + __pyx_t_1 = (__pyx_t_1 + 1); + + /* "View.MemoryView":402 + * + * for dim, idx in enumerate(index): + * itemp = pybuffer_index(&self.view, itemp, idx, dim) # <<<<<<<<<<<<<< + * + * return itemp + */ + __pyx_t_6 = __Pyx_PyIndex_AsSsize_t(__pyx_v_idx); if (unlikely((__pyx_t_6 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 402, __pyx_L1_error) + __pyx_t_7 = __pyx_pybuffer_index((&__pyx_v_self->view), __pyx_v_itemp, __pyx_t_6, __pyx_v_dim); if (unlikely(__pyx_t_7 == ((char *)NULL))) __PYX_ERR(1, 402, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_7; + + /* "View.MemoryView":401 + * cdef char *itemp = self.view.buf + * + * for dim, idx in enumerate(index): # <<<<<<<<<<<<<< + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + */ + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":404 + * itemp = pybuffer_index(&self.view, itemp, idx, dim) + * + * return itemp # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_itemp; + goto __pyx_L0; + + /* "View.MemoryView":397 + * PyThread_free_lock(self.lock) + * + * cdef char *get_item_pointer(memoryview self, object index) except NULL: # <<<<<<<<<<<<<< + * cdef Py_ssize_t dim + * cdef char *itemp = self.view.buf + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.get_item_pointer", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_idx); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":407 + * + * + * def __getitem__(memoryview self, object index): # <<<<<<<<<<<<<< + * if index is Ellipsis: + * return self + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index); /*proto*/ +static PyObject *__pyx_memoryview___getitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((PyObject *)__pyx_v_index)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_4__getitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index) { + PyObject *__pyx_v_have_slices = NULL; + PyObject *__pyx_v_indices = NULL; + char *__pyx_v_itemp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + char *__pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__getitem__", 1); + + /* "View.MemoryView":408 + * + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: # <<<<<<<<<<<<<< + * return self + * + */ + __pyx_t_1 = (__pyx_v_index == __pyx_builtin_Ellipsis); + if (__pyx_t_1) { + + /* "View.MemoryView":409 + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: + * return self # <<<<<<<<<<<<<< + * + * have_slices, indices = _unellipsify(index, self.view.ndim) + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_self); + __pyx_r = ((PyObject *)__pyx_v_self); + goto __pyx_L0; + + /* "View.MemoryView":408 + * + * def __getitem__(memoryview self, object index): + * if index is Ellipsis: # <<<<<<<<<<<<<< + * return self + * + */ + } + + /* "View.MemoryView":411 + * return self + * + * have_slices, indices = _unellipsify(index, self.view.ndim) # <<<<<<<<<<<<<< + * + * cdef char *itemp + */ + __pyx_t_2 = _unellipsify(__pyx_v_index, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + if (likely(__pyx_t_2 != Py_None)) { + PyObject* sequence = __pyx_t_2; + Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); + if (unlikely(size != 2)) { + if (size > 2) __Pyx_RaiseTooManyValuesError(2); + else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); + __PYX_ERR(1, 411, __pyx_L1_error) + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_3 = PyTuple_GET_ITEM(sequence, 0); + __pyx_t_4 = PyTuple_GET_ITEM(sequence, 1); + __Pyx_INCREF(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + #else + __pyx_t_3 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 411, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + #endif + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + } else { + __Pyx_RaiseNoneNotIterableError(); __PYX_ERR(1, 411, __pyx_L1_error) + } + __pyx_v_have_slices = __pyx_t_3; + __pyx_t_3 = 0; + __pyx_v_indices = __pyx_t_4; + __pyx_t_4 = 0; + + /* "View.MemoryView":414 + * + * cdef char *itemp + * if have_slices: # <<<<<<<<<<<<<< + * return memview_slice(self, indices) + * else: + */ + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_v_have_slices); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 414, __pyx_L1_error) + if (__pyx_t_1) { + + /* "View.MemoryView":415 + * cdef char *itemp + * if have_slices: + * return memview_slice(self, indices) # <<<<<<<<<<<<<< + * else: + * itemp = self.get_item_pointer(indices) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = ((PyObject *)__pyx_memview_slice(__pyx_v_self, __pyx_v_indices)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 415, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":414 + * + * cdef char *itemp + * if have_slices: # <<<<<<<<<<<<<< + * return memview_slice(self, indices) + * else: + */ + } + + /* "View.MemoryView":417 + * return memview_slice(self, indices) + * else: + * itemp = self.get_item_pointer(indices) # <<<<<<<<<<<<<< + * return self.convert_item_to_object(itemp) + * + */ + /*else*/ { + __pyx_t_5 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->get_item_pointer(__pyx_v_self, __pyx_v_indices); if (unlikely(__pyx_t_5 == ((char *)NULL))) __PYX_ERR(1, 417, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_5; + + /* "View.MemoryView":418 + * else: + * itemp = self.get_item_pointer(indices) + * return self.convert_item_to_object(itemp) # <<<<<<<<<<<<<< + * + * def __setitem__(memoryview self, object index, object value): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->convert_item_to_object(__pyx_v_self, __pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 418, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":407 + * + * + * def __getitem__(memoryview self, object index): # <<<<<<<<<<<<<< + * if index is Ellipsis: + * return self + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.memoryview.__getitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_have_slices); + __Pyx_XDECREF(__pyx_v_indices); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":420 + * return self.convert_item_to_object(itemp) + * + * def __setitem__(memoryview self, object index, object value): # <<<<<<<<<<<<<< + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" + */ + +/* Python wrapper */ +static int __pyx_memoryview___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value); /*proto*/ +static int __pyx_memoryview___setitem__(PyObject *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setitem__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((PyObject *)__pyx_v_index), ((PyObject *)__pyx_v_value)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_6__setitem__(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + PyObject *__pyx_v_have_slices = NULL; + PyObject *__pyx_v_obj = NULL; + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setitem__", 0); + __Pyx_INCREF(__pyx_v_index); + + /* "View.MemoryView":421 + * + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: # <<<<<<<<<<<<<< + * raise TypeError, "Cannot assign to read-only memoryview" + * + */ + if (unlikely(__pyx_v_self->view.readonly)) { + + /* "View.MemoryView":422 + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" # <<<<<<<<<<<<<< + * + * have_slices, index = _unellipsify(index, self.view.ndim) + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_Cannot_assign_to_read_only_memor, 0, 0); + __PYX_ERR(1, 422, __pyx_L1_error) + + /* "View.MemoryView":421 + * + * def __setitem__(memoryview self, object index, object value): + * if self.view.readonly: # <<<<<<<<<<<<<< + * raise TypeError, "Cannot assign to read-only memoryview" + * + */ + } + + /* "View.MemoryView":424 + * raise TypeError, "Cannot assign to read-only memoryview" + * + * have_slices, index = _unellipsify(index, self.view.ndim) # <<<<<<<<<<<<<< + * + * if have_slices: + */ + __pyx_t_1 = _unellipsify(__pyx_v_index, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (likely(__pyx_t_1 != Py_None)) { + PyObject* sequence = __pyx_t_1; + Py_ssize_t size = __Pyx_PySequence_SIZE(sequence); + if (unlikely(size != 2)) { + if (size > 2) __Pyx_RaiseTooManyValuesError(2); + else if (size >= 0) __Pyx_RaiseNeedMoreValuesError(size); + __PYX_ERR(1, 424, __pyx_L1_error) + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_2 = PyTuple_GET_ITEM(sequence, 0); + __pyx_t_3 = PyTuple_GET_ITEM(sequence, 1); + __Pyx_INCREF(__pyx_t_2); + __Pyx_INCREF(__pyx_t_3); + #else + __pyx_t_2 = PySequence_ITEM(sequence, 0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PySequence_ITEM(sequence, 1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 424, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + #endif + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } else { + __Pyx_RaiseNoneNotIterableError(); __PYX_ERR(1, 424, __pyx_L1_error) + } + __pyx_v_have_slices = __pyx_t_2; + __pyx_t_2 = 0; + __Pyx_DECREF_SET(__pyx_v_index, __pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":426 + * have_slices, index = _unellipsify(index, self.view.ndim) + * + * if have_slices: # <<<<<<<<<<<<<< + * obj = self.is_slice(value) + * if obj is not None: + */ + __pyx_t_4 = __Pyx_PyObject_IsTrue(__pyx_v_have_slices); if (unlikely((__pyx_t_4 < 0))) __PYX_ERR(1, 426, __pyx_L1_error) + if (__pyx_t_4) { + + /* "View.MemoryView":427 + * + * if have_slices: + * obj = self.is_slice(value) # <<<<<<<<<<<<<< + * if obj is not None: + * self.setitem_slice_assignment(self[index], obj) + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->is_slice(__pyx_v_self, __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 427, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_obj = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":428 + * if have_slices: + * obj = self.is_slice(value) + * if obj is not None: # <<<<<<<<<<<<<< + * self.setitem_slice_assignment(self[index], obj) + * else: + */ + __pyx_t_4 = (__pyx_v_obj != Py_None); + if (__pyx_t_4) { + + /* "View.MemoryView":429 + * obj = self.is_slice(value) + * if obj is not None: + * self.setitem_slice_assignment(self[index], obj) # <<<<<<<<<<<<<< + * else: + * self.setitem_slice_assign_scalar(self[index], value) + */ + __pyx_t_1 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_self), __pyx_v_index); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 429, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_slice_assignment(__pyx_v_self, __pyx_t_1, __pyx_v_obj); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 429, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "View.MemoryView":428 + * if have_slices: + * obj = self.is_slice(value) + * if obj is not None: # <<<<<<<<<<<<<< + * self.setitem_slice_assignment(self[index], obj) + * else: + */ + goto __pyx_L5; + } + + /* "View.MemoryView":431 + * self.setitem_slice_assignment(self[index], obj) + * else: + * self.setitem_slice_assign_scalar(self[index], value) # <<<<<<<<<<<<<< + * else: + * self.setitem_indexed(index, value) + */ + /*else*/ { + __pyx_t_3 = __Pyx_PyObject_GetItem(((PyObject *)__pyx_v_self), __pyx_v_index); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 431, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + if (!(likely(((__pyx_t_3) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_3, __pyx_memoryview_type))))) __PYX_ERR(1, 431, __pyx_L1_error) + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_slice_assign_scalar(__pyx_v_self, ((struct __pyx_memoryview_obj *)__pyx_t_3), __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 431, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + __pyx_L5:; + + /* "View.MemoryView":426 + * have_slices, index = _unellipsify(index, self.view.ndim) + * + * if have_slices: # <<<<<<<<<<<<<< + * obj = self.is_slice(value) + * if obj is not None: + */ + goto __pyx_L4; + } + + /* "View.MemoryView":433 + * self.setitem_slice_assign_scalar(self[index], value) + * else: + * self.setitem_indexed(index, value) # <<<<<<<<<<<<<< + * + * cdef is_slice(self, obj): + */ + /*else*/ { + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->setitem_indexed(__pyx_v_self, __pyx_v_index, __pyx_v_value); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 433, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + } + __pyx_L4:; + + /* "View.MemoryView":420 + * return self.convert_item_to_object(itemp) + * + * def __setitem__(memoryview self, object index, object value): # <<<<<<<<<<<<<< + * if self.view.readonly: + * raise TypeError, "Cannot assign to read-only memoryview" + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.__setitem__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_have_slices); + __Pyx_XDECREF(__pyx_v_obj); + __Pyx_XDECREF(__pyx_v_index); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":435 + * self.setitem_indexed(index, value) + * + * cdef is_slice(self, obj): # <<<<<<<<<<<<<< + * if not isinstance(obj, memoryview): + * try: + */ + +static PyObject *__pyx_memoryview_is_slice(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_obj) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_slice", 0); + __Pyx_INCREF(__pyx_v_obj); + + /* "View.MemoryView":436 + * + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): # <<<<<<<<<<<<<< + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + */ + __pyx_t_1 = __Pyx_TypeCheck(__pyx_v_obj, __pyx_memoryview_type); + __pyx_t_2 = (!__pyx_t_1); + if (__pyx_t_2) { + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_4, &__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_5); + /*try:*/ { + + /* "View.MemoryView":438 + * if not isinstance(obj, memoryview): + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, # <<<<<<<<<<<<<< + * self.dtype_is_object) + * except TypeError: + */ + __pyx_t_6 = __Pyx_PyInt_From_int(((__pyx_v_self->flags & (~PyBUF_WRITABLE)) | PyBUF_ANY_CONTIGUOUS)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_6); + + /* "View.MemoryView":439 + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) # <<<<<<<<<<<<<< + * except TypeError: + * return None + */ + __pyx_t_7 = __Pyx_PyBool_FromLong(__pyx_v_self->dtype_is_object); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 439, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_7); + + /* "View.MemoryView":438 + * if not isinstance(obj, memoryview): + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, # <<<<<<<<<<<<<< + * self.dtype_is_object) + * except TypeError: + */ + __pyx_t_8 = PyTuple_New(3); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_INCREF(__pyx_v_obj); + __Pyx_GIVEREF(__pyx_v_obj); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 0, __pyx_v_obj)) __PYX_ERR(1, 438, __pyx_L4_error); + __Pyx_GIVEREF(__pyx_t_6); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 1, __pyx_t_6)) __PYX_ERR(1, 438, __pyx_L4_error); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_8, 2, __pyx_t_7)) __PYX_ERR(1, 438, __pyx_L4_error); + __pyx_t_6 = 0; + __pyx_t_7 = 0; + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_8, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 438, __pyx_L4_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __Pyx_DECREF_SET(__pyx_v_obj, __pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + goto __pyx_L9_try_end; + __pyx_L4_error:; + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_XDECREF(__pyx_t_8); __pyx_t_8 = 0; + + /* "View.MemoryView":440 + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + * except TypeError: # <<<<<<<<<<<<<< + * return None + * + */ + __pyx_t_9 = __Pyx_PyErr_ExceptionMatches(__pyx_builtin_TypeError); + if (__pyx_t_9) { + __Pyx_AddTraceback("View.MemoryView.memoryview.is_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_7, &__pyx_t_8, &__pyx_t_6) < 0) __PYX_ERR(1, 440, __pyx_L6_except_error) + __Pyx_XGOTREF(__pyx_t_7); + __Pyx_XGOTREF(__pyx_t_8); + __Pyx_XGOTREF(__pyx_t_6); + + /* "View.MemoryView":441 + * self.dtype_is_object) + * except TypeError: + * return None # <<<<<<<<<<<<<< + * + * return obj + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L7_except_return; + } + goto __pyx_L6_except_error; + + /* "View.MemoryView":437 + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): + * try: # <<<<<<<<<<<<<< + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + * self.dtype_is_object) + */ + __pyx_L6_except_error:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_XGIVEREF(__pyx_t_5); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); + goto __pyx_L1_error; + __pyx_L7_except_return:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_XGIVEREF(__pyx_t_5); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_4, __pyx_t_5); + goto __pyx_L0; + __pyx_L9_try_end:; + } + + /* "View.MemoryView":436 + * + * cdef is_slice(self, obj): + * if not isinstance(obj, memoryview): # <<<<<<<<<<<<<< + * try: + * obj = memoryview(obj, self.flags & ~PyBUF_WRITABLE | PyBUF_ANY_CONTIGUOUS, + */ + } + + /* "View.MemoryView":443 + * return None + * + * return obj # <<<<<<<<<<<<<< + * + * cdef setitem_slice_assignment(self, dst, src): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_obj); + __pyx_r = __pyx_v_obj; + goto __pyx_L0; + + /* "View.MemoryView":435 + * self.setitem_indexed(index, value) + * + * cdef is_slice(self, obj): # <<<<<<<<<<<<<< + * if not isinstance(obj, memoryview): + * try: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_obj); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":445 + * return obj + * + * cdef setitem_slice_assignment(self, dst, src): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + */ + +static PyObject *__pyx_memoryview_setitem_slice_assignment(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_dst, PyObject *__pyx_v_src) { + __Pyx_memviewslice __pyx_v_dst_slice; + __Pyx_memviewslice __pyx_v_src_slice; + __Pyx_memviewslice __pyx_v_msrc; + __Pyx_memviewslice __pyx_v_mdst; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_slice_assignment", 1); + + /* "View.MemoryView":448 + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + * cdef __Pyx_memviewslice msrc = get_slice_from_memview(src, &src_slice)[0] # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] + * + */ + if (!(likely(((__pyx_v_src) == Py_None) || likely(__Pyx_TypeTest(__pyx_v_src, __pyx_memoryview_type))))) __PYX_ERR(1, 448, __pyx_L1_error) + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(((struct __pyx_memoryview_obj *)__pyx_v_src), (&__pyx_v_src_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 448, __pyx_L1_error) + __pyx_v_msrc = (__pyx_t_1[0]); + + /* "View.MemoryView":449 + * cdef __Pyx_memviewslice src_slice + * cdef __Pyx_memviewslice msrc = get_slice_from_memview(src, &src_slice)[0] + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] # <<<<<<<<<<<<<< + * + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + */ + if (!(likely(((__pyx_v_dst) == Py_None) || likely(__Pyx_TypeTest(__pyx_v_dst, __pyx_memoryview_type))))) __PYX_ERR(1, 449, __pyx_L1_error) + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(((struct __pyx_memoryview_obj *)__pyx_v_dst), (&__pyx_v_dst_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 449, __pyx_L1_error) + __pyx_v_mdst = (__pyx_t_1[0]); + + /* "View.MemoryView":451 + * cdef __Pyx_memviewslice mdst = get_slice_from_memview(dst, &dst_slice)[0] + * + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) # <<<<<<<<<<<<<< + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): + */ + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_src, __pyx_n_s_ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_3 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_v_dst, __pyx_n_s_ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = __Pyx_PyInt_As_int(__pyx_t_2); if (unlikely((__pyx_t_4 == (int)-1) && PyErr_Occurred())) __PYX_ERR(1, 451, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_5 = __pyx_memoryview_copy_contents(__pyx_v_msrc, __pyx_v_mdst, __pyx_t_3, __pyx_t_4, __pyx_v_self->dtype_is_object); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 451, __pyx_L1_error) + + /* "View.MemoryView":445 + * return obj + * + * cdef setitem_slice_assignment(self, dst, src): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice dst_slice + * cdef __Pyx_memviewslice src_slice + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_slice_assignment", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":453 + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): # <<<<<<<<<<<<<< + * cdef int array[128] + * cdef void *tmp = NULL + */ + +static PyObject *__pyx_memoryview_setitem_slice_assign_scalar(struct __pyx_memoryview_obj *__pyx_v_self, struct __pyx_memoryview_obj *__pyx_v_dst, PyObject *__pyx_v_value) { + int __pyx_v_array[0x80]; + void *__pyx_v_tmp; + void *__pyx_v_item; + __Pyx_memviewslice *__pyx_v_dst_slice; + __Pyx_memviewslice __pyx_v_tmp_slice; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + int __pyx_t_5; + char const *__pyx_t_6; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + PyObject *__pyx_t_9 = NULL; + PyObject *__pyx_t_10 = NULL; + PyObject *__pyx_t_11 = NULL; + PyObject *__pyx_t_12 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_slice_assign_scalar", 1); + + /* "View.MemoryView":455 + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): + * cdef int array[128] + * cdef void *tmp = NULL # <<<<<<<<<<<<<< + * cdef void *item + * + */ + __pyx_v_tmp = NULL; + + /* "View.MemoryView":460 + * cdef __Pyx_memviewslice *dst_slice + * cdef __Pyx_memviewslice tmp_slice + * dst_slice = get_slice_from_memview(dst, &tmp_slice) # <<<<<<<<<<<<<< + * + * if self.view.itemsize > sizeof(array): + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_dst, (&__pyx_v_tmp_slice)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 460, __pyx_L1_error) + __pyx_v_dst_slice = __pyx_t_1; + + /* "View.MemoryView":462 + * dst_slice = get_slice_from_memview(dst, &tmp_slice) + * + * if self.view.itemsize > sizeof(array): # <<<<<<<<<<<<<< + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + */ + __pyx_t_2 = (((size_t)__pyx_v_self->view.itemsize) > (sizeof(__pyx_v_array))); + if (__pyx_t_2) { + + /* "View.MemoryView":463 + * + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) # <<<<<<<<<<<<<< + * if tmp == NULL: + * raise MemoryError + */ + __pyx_v_tmp = PyMem_Malloc(__pyx_v_self->view.itemsize); + + /* "View.MemoryView":464 + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * item = tmp + */ + __pyx_t_2 = (__pyx_v_tmp == NULL); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":465 + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + * raise MemoryError # <<<<<<<<<<<<<< + * item = tmp + * else: + */ + PyErr_NoMemory(); __PYX_ERR(1, 465, __pyx_L1_error) + + /* "View.MemoryView":464 + * if self.view.itemsize > sizeof(array): + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: # <<<<<<<<<<<<<< + * raise MemoryError + * item = tmp + */ + } + + /* "View.MemoryView":466 + * if tmp == NULL: + * raise MemoryError + * item = tmp # <<<<<<<<<<<<<< + * else: + * item = array + */ + __pyx_v_item = __pyx_v_tmp; + + /* "View.MemoryView":462 + * dst_slice = get_slice_from_memview(dst, &tmp_slice) + * + * if self.view.itemsize > sizeof(array): # <<<<<<<<<<<<<< + * tmp = PyMem_Malloc(self.view.itemsize) + * if tmp == NULL: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":468 + * item = tmp + * else: + * item = array # <<<<<<<<<<<<<< + * + * try: + */ + /*else*/ { + __pyx_v_item = ((void *)__pyx_v_array); + } + __pyx_L3:; + + /* "View.MemoryView":470 + * item = array + * + * try: # <<<<<<<<<<<<<< + * if self.dtype_is_object: + * ( item)[0] = value + */ + /*try:*/ { + + /* "View.MemoryView":471 + * + * try: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * ( item)[0] = value + * else: + */ + if (__pyx_v_self->dtype_is_object) { + + /* "View.MemoryView":472 + * try: + * if self.dtype_is_object: + * ( item)[0] = value # <<<<<<<<<<<<<< + * else: + * self.assign_item_from_object( item, value) + */ + (((PyObject **)__pyx_v_item)[0]) = ((PyObject *)__pyx_v_value); + + /* "View.MemoryView":471 + * + * try: + * if self.dtype_is_object: # <<<<<<<<<<<<<< + * ( item)[0] = value + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":474 + * ( item)[0] = value + * else: + * self.assign_item_from_object( item, value) # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __pyx_t_3 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->assign_item_from_object(__pyx_v_self, ((char *)__pyx_v_item), __pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 474, __pyx_L6_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_L8:; + + /* "View.MemoryView":478 + * + * + * if self.view.suboffsets != NULL: # <<<<<<<<<<<<<< + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + */ + __pyx_t_2 = (__pyx_v_self->view.suboffsets != NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":479 + * + * if self.view.suboffsets != NULL: + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) # <<<<<<<<<<<<<< + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + * item, self.dtype_is_object) + */ + __pyx_t_4 = assert_direct_dimensions(__pyx_v_self->view.suboffsets, __pyx_v_self->view.ndim); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 479, __pyx_L6_error) + + /* "View.MemoryView":478 + * + * + * if self.view.suboffsets != NULL: # <<<<<<<<<<<<<< + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, + */ + } + + /* "View.MemoryView":480 + * if self.view.suboffsets != NULL: + * assert_direct_dimensions(self.view.suboffsets, self.view.ndim) + * slice_assign_scalar(dst_slice, dst.view.ndim, self.view.itemsize, # <<<<<<<<<<<<<< + * item, self.dtype_is_object) + * finally: + */ + __pyx_memoryview_slice_assign_scalar(__pyx_v_dst_slice, __pyx_v_dst->view.ndim, __pyx_v_self->view.itemsize, __pyx_v_item, __pyx_v_self->dtype_is_object); + } + + /* "View.MemoryView":483 + * item, self.dtype_is_object) + * finally: + * PyMem_Free(tmp) # <<<<<<<<<<<<<< + * + * cdef setitem_indexed(self, index, value): + */ + /*finally:*/ { + /*normal exit:*/{ + PyMem_Free(__pyx_v_tmp); + goto __pyx_L7; + } + __pyx_L6_error:; + /*exception exit:*/{ + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __pyx_t_7 = 0; __pyx_t_8 = 0; __pyx_t_9 = 0; __pyx_t_10 = 0; __pyx_t_11 = 0; __pyx_t_12 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + if (PY_MAJOR_VERSION >= 3) __Pyx_ExceptionSwap(&__pyx_t_10, &__pyx_t_11, &__pyx_t_12); + if ((PY_MAJOR_VERSION < 3) || unlikely(__Pyx_GetException(&__pyx_t_7, &__pyx_t_8, &__pyx_t_9) < 0)) __Pyx_ErrFetch(&__pyx_t_7, &__pyx_t_8, &__pyx_t_9); + __Pyx_XGOTREF(__pyx_t_7); + __Pyx_XGOTREF(__pyx_t_8); + __Pyx_XGOTREF(__pyx_t_9); + __Pyx_XGOTREF(__pyx_t_10); + __Pyx_XGOTREF(__pyx_t_11); + __Pyx_XGOTREF(__pyx_t_12); + __pyx_t_4 = __pyx_lineno; __pyx_t_5 = __pyx_clineno; __pyx_t_6 = __pyx_filename; + { + PyMem_Free(__pyx_v_tmp); + } + if (PY_MAJOR_VERSION >= 3) { + __Pyx_XGIVEREF(__pyx_t_10); + __Pyx_XGIVEREF(__pyx_t_11); + __Pyx_XGIVEREF(__pyx_t_12); + __Pyx_ExceptionReset(__pyx_t_10, __pyx_t_11, __pyx_t_12); + } + __Pyx_XGIVEREF(__pyx_t_7); + __Pyx_XGIVEREF(__pyx_t_8); + __Pyx_XGIVEREF(__pyx_t_9); + __Pyx_ErrRestore(__pyx_t_7, __pyx_t_8, __pyx_t_9); + __pyx_t_7 = 0; __pyx_t_8 = 0; __pyx_t_9 = 0; __pyx_t_10 = 0; __pyx_t_11 = 0; __pyx_t_12 = 0; + __pyx_lineno = __pyx_t_4; __pyx_clineno = __pyx_t_5; __pyx_filename = __pyx_t_6; + goto __pyx_L1_error; + } + __pyx_L7:; + } + + /* "View.MemoryView":453 + * memoryview_copy_contents(msrc, mdst, src.ndim, dst.ndim, self.dtype_is_object) + * + * cdef setitem_slice_assign_scalar(self, memoryview dst, value): # <<<<<<<<<<<<<< + * cdef int array[128] + * cdef void *tmp = NULL + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_slice_assign_scalar", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":485 + * PyMem_Free(tmp) + * + * cdef setitem_indexed(self, index, value): # <<<<<<<<<<<<<< + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) + */ + +static PyObject *__pyx_memoryview_setitem_indexed(struct __pyx_memoryview_obj *__pyx_v_self, PyObject *__pyx_v_index, PyObject *__pyx_v_value) { + char *__pyx_v_itemp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + char *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("setitem_indexed", 1); + + /* "View.MemoryView":486 + * + * cdef setitem_indexed(self, index, value): + * cdef char *itemp = self.get_item_pointer(index) # <<<<<<<<<<<<<< + * self.assign_item_from_object(itemp, value) + * + */ + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->get_item_pointer(__pyx_v_self, __pyx_v_index); if (unlikely(__pyx_t_1 == ((char *)NULL))) __PYX_ERR(1, 486, __pyx_L1_error) + __pyx_v_itemp = __pyx_t_1; + + /* "View.MemoryView":487 + * cdef setitem_indexed(self, index, value): + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) # <<<<<<<<<<<<<< + * + * cdef convert_item_to_object(self, char *itemp): + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->assign_item_from_object(__pyx_v_self, __pyx_v_itemp, __pyx_v_value); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 487, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":485 + * PyMem_Free(tmp) + * + * cdef setitem_indexed(self, index, value): # <<<<<<<<<<<<<< + * cdef char *itemp = self.get_item_pointer(index) + * self.assign_item_from_object(itemp, value) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.setitem_indexed", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":489 + * self.assign_item_from_object(itemp, value) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + +static PyObject *__pyx_memoryview_convert_item_to_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp) { + PyObject *__pyx_v_struct = NULL; + PyObject *__pyx_v_bytesitem = 0; + PyObject *__pyx_v_result = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + unsigned int __pyx_t_8; + Py_ssize_t __pyx_t_9; + int __pyx_t_10; + int __pyx_t_11; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("convert_item_to_object", 1); + + /* "View.MemoryView":492 + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + * import struct # <<<<<<<<<<<<<< + * cdef bytes bytesitem + * + */ + __pyx_t_1 = __Pyx_ImportDottedModule(__pyx_n_s_struct, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 492, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_struct = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":495 + * cdef bytes bytesitem + * + * bytesitem = itemp[:self.view.itemsize] # <<<<<<<<<<<<<< + * try: + * result = struct.unpack(self.view.format, bytesitem) + */ + __pyx_t_1 = __Pyx_PyBytes_FromStringAndSize(__pyx_v_itemp + 0, __pyx_v_self->view.itemsize - 0); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 495, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_bytesitem = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_2, &__pyx_t_3, &__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_4); + /*try:*/ { + + /* "View.MemoryView":497 + * bytesitem = itemp[:self.view.itemsize] + * try: + * result = struct.unpack(self.view.format, bytesitem) # <<<<<<<<<<<<<< + * except struct.error: + * raise ValueError, "Unable to convert item to object" + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_unpack); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_6); + __pyx_t_7 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_7)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_7); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_7, __pyx_t_6, __pyx_v_bytesitem}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_8, 2+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 497, __pyx_L3_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + __pyx_v_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + } + + /* "View.MemoryView":501 + * raise ValueError, "Unable to convert item to object" + * else: + * if len(self.view.format) == 1: # <<<<<<<<<<<<<< + * return result[0] + * return result + */ + /*else:*/ { + __pyx_t_9 = __Pyx_ssize_strlen(__pyx_v_self->view.format); if (unlikely(__pyx_t_9 == ((Py_ssize_t)-1))) __PYX_ERR(1, 501, __pyx_L5_except_error) + __pyx_t_10 = (__pyx_t_9 == 1); + if (__pyx_t_10) { + + /* "View.MemoryView":502 + * else: + * if len(self.view.format) == 1: + * return result[0] # <<<<<<<<<<<<<< + * return result + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_GetItemInt(__pyx_v_result, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 502, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L6_except_return; + + /* "View.MemoryView":501 + * raise ValueError, "Unable to convert item to object" + * else: + * if len(self.view.format) == 1: # <<<<<<<<<<<<<< + * return result[0] + * return result + */ + } + + /* "View.MemoryView":503 + * if len(self.view.format) == 1: + * return result[0] + * return result # <<<<<<<<<<<<<< + * + * cdef assign_item_from_object(self, char *itemp, object value): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_result); + __pyx_r = __pyx_v_result; + goto __pyx_L6_except_return; + } + __pyx_L3_error:; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_6); __pyx_t_6 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":498 + * try: + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: # <<<<<<<<<<<<<< + * raise ValueError, "Unable to convert item to object" + * else: + */ + __Pyx_ErrFetch(&__pyx_t_1, &__pyx_t_5, &__pyx_t_6); + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_error); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 498, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_11 = __Pyx_PyErr_GivenExceptionMatches(__pyx_t_1, __pyx_t_7); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_ErrRestore(__pyx_t_1, __pyx_t_5, __pyx_t_6); + __pyx_t_1 = 0; __pyx_t_5 = 0; __pyx_t_6 = 0; + if (__pyx_t_11) { + __Pyx_AddTraceback("View.MemoryView.memoryview.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_6, &__pyx_t_5, &__pyx_t_1) < 0) __PYX_ERR(1, 498, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_1); + + /* "View.MemoryView":499 + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + * raise ValueError, "Unable to convert item to object" # <<<<<<<<<<<<<< + * else: + * if len(self.view.format) == 1: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Unable_to_convert_item_to_object, 0, 0); + __PYX_ERR(1, 499, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "View.MemoryView":496 + * + * bytesitem = itemp[:self.view.itemsize] + * try: # <<<<<<<<<<<<<< + * result = struct.unpack(self.view.format, bytesitem) + * except struct.error: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); + goto __pyx_L1_error; + __pyx_L6_except_return:; + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_4); + __Pyx_ExceptionReset(__pyx_t_2, __pyx_t_3, __pyx_t_4); + goto __pyx_L0; + } + + /* "View.MemoryView":489 + * self.assign_item_from_object(itemp, value) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView.memoryview.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_struct); + __Pyx_XDECREF(__pyx_v_bytesitem); + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":505 + * return result + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + +static PyObject *__pyx_memoryview_assign_item_from_object(struct __pyx_memoryview_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value) { + PyObject *__pyx_v_struct = NULL; + char __pyx_v_c; + PyObject *__pyx_v_bytesvalue = 0; + Py_ssize_t __pyx_v_i; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + unsigned int __pyx_t_6; + Py_ssize_t __pyx_t_7; + PyObject *__pyx_t_8 = NULL; + char *__pyx_t_9; + char *__pyx_t_10; + char *__pyx_t_11; + char *__pyx_t_12; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("assign_item_from_object", 1); + + /* "View.MemoryView":508 + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + * import struct # <<<<<<<<<<<<<< + * cdef char c + * cdef bytes bytesvalue + */ + __pyx_t_1 = __Pyx_ImportDottedModule(__pyx_n_s_struct, NULL); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 508, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_v_struct = __pyx_t_1; + __pyx_t_1 = 0; + + /* "View.MemoryView":513 + * cdef Py_ssize_t i + * + * if isinstance(value, tuple): # <<<<<<<<<<<<<< + * bytesvalue = struct.pack(self.view.format, *value) + * else: + */ + __pyx_t_2 = PyTuple_Check(__pyx_v_value); + if (__pyx_t_2) { + + /* "View.MemoryView":514 + * + * if isinstance(value, tuple): + * bytesvalue = struct.pack(self.view.format, *value) # <<<<<<<<<<<<<< + * else: + * bytesvalue = struct.pack(self.view.format, value) + */ + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_pack); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_3); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error); + __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PySequence_Tuple(__pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = PyNumber_Add(__pyx_t_4, __pyx_t_3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_t_3 = __Pyx_PyObject_Call(__pyx_t_1, __pyx_t_5, NULL); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 514, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (!(likely(PyBytes_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_t_3))) __PYX_ERR(1, 514, __pyx_L1_error) + __pyx_v_bytesvalue = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":513 + * cdef Py_ssize_t i + * + * if isinstance(value, tuple): # <<<<<<<<<<<<<< + * bytesvalue = struct.pack(self.view.format, *value) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":516 + * bytesvalue = struct.pack(self.view.format, *value) + * else: + * bytesvalue = struct.pack(self.view.format, value) # <<<<<<<<<<<<<< + * + * for i, c in enumerate(bytesvalue): + */ + /*else*/ { + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_v_struct, __pyx_n_s_pack); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 = __Pyx_PyBytes_FromString(__pyx_v_self->view.format); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_4 = NULL; + __pyx_t_6 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_5))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_5); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_5); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_5, function); + __pyx_t_6 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[3] = {__pyx_t_4, __pyx_t_1, __pyx_v_value}; + __pyx_t_3 = __Pyx_PyObject_FastCall(__pyx_t_5, __pyx_callargs+1-__pyx_t_6, 2+__pyx_t_6); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 516, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + if (!(likely(PyBytes_CheckExact(__pyx_t_3))||((__pyx_t_3) == Py_None) || __Pyx_RaiseUnexpectedTypeError("bytes", __pyx_t_3))) __PYX_ERR(1, 516, __pyx_L1_error) + __pyx_v_bytesvalue = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + } + __pyx_L3:; + + /* "View.MemoryView":518 + * bytesvalue = struct.pack(self.view.format, value) + * + * for i, c in enumerate(bytesvalue): # <<<<<<<<<<<<<< + * itemp[i] = c + * + */ + __pyx_t_7 = 0; + if (unlikely(__pyx_v_bytesvalue == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' is not iterable"); + __PYX_ERR(1, 518, __pyx_L1_error) + } + __Pyx_INCREF(__pyx_v_bytesvalue); + __pyx_t_8 = __pyx_v_bytesvalue; + __pyx_t_10 = PyBytes_AS_STRING(__pyx_t_8); + __pyx_t_11 = (__pyx_t_10 + PyBytes_GET_SIZE(__pyx_t_8)); + for (__pyx_t_12 = __pyx_t_10; __pyx_t_12 < __pyx_t_11; __pyx_t_12++) { + __pyx_t_9 = __pyx_t_12; + __pyx_v_c = (__pyx_t_9[0]); + + /* "View.MemoryView":519 + * + * for i, c in enumerate(bytesvalue): + * itemp[i] = c # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + __pyx_v_i = __pyx_t_7; + + /* "View.MemoryView":518 + * bytesvalue = struct.pack(self.view.format, value) + * + * for i, c in enumerate(bytesvalue): # <<<<<<<<<<<<<< + * itemp[i] = c + * + */ + __pyx_t_7 = (__pyx_t_7 + 1); + + /* "View.MemoryView":519 + * + * for i, c in enumerate(bytesvalue): + * itemp[i] = c # <<<<<<<<<<<<<< + * + * @cname('getbuffer') + */ + (__pyx_v_itemp[__pyx_v_i]) = __pyx_v_c; + } + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + + /* "View.MemoryView":505 + * return result + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * """Only used if instantiated manually by the user, or if Cython doesn't + * know how to convert the type""" + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memoryview.assign_item_from_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_struct); + __Pyx_XDECREF(__pyx_v_bytesvalue); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":521 + * itemp[i] = c + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + */ + +/* Python wrapper */ +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags); /*proto*/ +CYTHON_UNUSED static int __pyx_memoryview_getbuffer(PyObject *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + int __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__getbuffer__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(((struct __pyx_memoryview_obj *)__pyx_v_self), ((Py_buffer *)__pyx_v_info), ((int)__pyx_v_flags)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static int __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_8__getbuffer__(struct __pyx_memoryview_obj *__pyx_v_self, Py_buffer *__pyx_v_info, int __pyx_v_flags) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + Py_ssize_t *__pyx_t_3; + char *__pyx_t_4; + void *__pyx_t_5; + int __pyx_t_6; + Py_ssize_t __pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + if (unlikely(__pyx_v_info == NULL)) { + PyErr_SetString(PyExc_BufferError, "PyObject_GetBuffer: view==NULL argument is obsolete"); + return -1; + } + __Pyx_RefNannySetupContext("__getbuffer__", 0); + __pyx_v_info->obj = Py_None; __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(__pyx_v_info->obj); + + /* "View.MemoryView":523 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: # <<<<<<<<<<<<<< + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + */ + __pyx_t_2 = ((__pyx_v_flags & PyBUF_WRITABLE) != 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_1 = __pyx_v_self->view.readonly; + __pyx_L4_bool_binop_done:; + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":524 + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + * raise ValueError, "Cannot create writable memory view from read-only memoryview" # <<<<<<<<<<<<<< + * + * if flags & PyBUF_ND: + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Cannot_create_writable_memory_vi, 0, 0); + __PYX_ERR(1, 524, __pyx_L1_error) + + /* "View.MemoryView":523 + * @cname('getbuffer') + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: # <<<<<<<<<<<<<< + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + */ + } + + /* "View.MemoryView":526 + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + * if flags & PyBUF_ND: # <<<<<<<<<<<<<< + * info.shape = self.view.shape + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_ND) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":527 + * + * if flags & PyBUF_ND: + * info.shape = self.view.shape # <<<<<<<<<<<<<< + * else: + * info.shape = NULL + */ + __pyx_t_3 = __pyx_v_self->view.shape; + __pyx_v_info->shape = __pyx_t_3; + + /* "View.MemoryView":526 + * raise ValueError, "Cannot create writable memory view from read-only memoryview" + * + * if flags & PyBUF_ND: # <<<<<<<<<<<<<< + * info.shape = self.view.shape + * else: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":529 + * info.shape = self.view.shape + * else: + * info.shape = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_STRIDES: + */ + /*else*/ { + __pyx_v_info->shape = NULL; + } + __pyx_L6:; + + /* "View.MemoryView":531 + * info.shape = NULL + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.strides = self.view.strides + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_STRIDES) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":532 + * + * if flags & PyBUF_STRIDES: + * info.strides = self.view.strides # <<<<<<<<<<<<<< + * else: + * info.strides = NULL + */ + __pyx_t_3 = __pyx_v_self->view.strides; + __pyx_v_info->strides = __pyx_t_3; + + /* "View.MemoryView":531 + * info.shape = NULL + * + * if flags & PyBUF_STRIDES: # <<<<<<<<<<<<<< + * info.strides = self.view.strides + * else: + */ + goto __pyx_L7; + } + + /* "View.MemoryView":534 + * info.strides = self.view.strides + * else: + * info.strides = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_INDIRECT: + */ + /*else*/ { + __pyx_v_info->strides = NULL; + } + __pyx_L7:; + + /* "View.MemoryView":536 + * info.strides = NULL + * + * if flags & PyBUF_INDIRECT: # <<<<<<<<<<<<<< + * info.suboffsets = self.view.suboffsets + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_INDIRECT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":537 + * + * if flags & PyBUF_INDIRECT: + * info.suboffsets = self.view.suboffsets # <<<<<<<<<<<<<< + * else: + * info.suboffsets = NULL + */ + __pyx_t_3 = __pyx_v_self->view.suboffsets; + __pyx_v_info->suboffsets = __pyx_t_3; + + /* "View.MemoryView":536 + * info.strides = NULL + * + * if flags & PyBUF_INDIRECT: # <<<<<<<<<<<<<< + * info.suboffsets = self.view.suboffsets + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":539 + * info.suboffsets = self.view.suboffsets + * else: + * info.suboffsets = NULL # <<<<<<<<<<<<<< + * + * if flags & PyBUF_FORMAT: + */ + /*else*/ { + __pyx_v_info->suboffsets = NULL; + } + __pyx_L8:; + + /* "View.MemoryView":541 + * info.suboffsets = NULL + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * info.format = self.view.format + * else: + */ + __pyx_t_1 = ((__pyx_v_flags & PyBUF_FORMAT) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":542 + * + * if flags & PyBUF_FORMAT: + * info.format = self.view.format # <<<<<<<<<<<<<< + * else: + * info.format = NULL + */ + __pyx_t_4 = __pyx_v_self->view.format; + __pyx_v_info->format = __pyx_t_4; + + /* "View.MemoryView":541 + * info.suboffsets = NULL + * + * if flags & PyBUF_FORMAT: # <<<<<<<<<<<<<< + * info.format = self.view.format + * else: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":544 + * info.format = self.view.format + * else: + * info.format = NULL # <<<<<<<<<<<<<< + * + * info.buf = self.view.buf + */ + /*else*/ { + __pyx_v_info->format = NULL; + } + __pyx_L9:; + + /* "View.MemoryView":546 + * info.format = NULL + * + * info.buf = self.view.buf # <<<<<<<<<<<<<< + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize + */ + __pyx_t_5 = __pyx_v_self->view.buf; + __pyx_v_info->buf = __pyx_t_5; + + /* "View.MemoryView":547 + * + * info.buf = self.view.buf + * info.ndim = self.view.ndim # <<<<<<<<<<<<<< + * info.itemsize = self.view.itemsize + * info.len = self.view.len + */ + __pyx_t_6 = __pyx_v_self->view.ndim; + __pyx_v_info->ndim = __pyx_t_6; + + /* "View.MemoryView":548 + * info.buf = self.view.buf + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize # <<<<<<<<<<<<<< + * info.len = self.view.len + * info.readonly = self.view.readonly + */ + __pyx_t_7 = __pyx_v_self->view.itemsize; + __pyx_v_info->itemsize = __pyx_t_7; + + /* "View.MemoryView":549 + * info.ndim = self.view.ndim + * info.itemsize = self.view.itemsize + * info.len = self.view.len # <<<<<<<<<<<<<< + * info.readonly = self.view.readonly + * info.obj = self + */ + __pyx_t_7 = __pyx_v_self->view.len; + __pyx_v_info->len = __pyx_t_7; + + /* "View.MemoryView":550 + * info.itemsize = self.view.itemsize + * info.len = self.view.len + * info.readonly = self.view.readonly # <<<<<<<<<<<<<< + * info.obj = self + * + */ + __pyx_t_1 = __pyx_v_self->view.readonly; + __pyx_v_info->readonly = __pyx_t_1; + + /* "View.MemoryView":551 + * info.len = self.view.len + * info.readonly = self.view.readonly + * info.obj = self # <<<<<<<<<<<<<< + * + * + */ + __Pyx_INCREF((PyObject *)__pyx_v_self); + __Pyx_GIVEREF((PyObject *)__pyx_v_self); + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); + __pyx_v_info->obj = ((PyObject *)__pyx_v_self); + + /* "View.MemoryView":521 + * itemp[i] = c + * + * @cname('getbuffer') # <<<<<<<<<<<<<< + * def __getbuffer__(self, Py_buffer *info, int flags): + * if flags & PyBUF_WRITABLE and self.view.readonly: + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__getbuffer__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + if (__pyx_v_info->obj != NULL) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + goto __pyx_L2; + __pyx_L0:; + if (__pyx_v_info->obj == Py_None) { + __Pyx_GOTREF(__pyx_v_info->obj); + __Pyx_DECREF(__pyx_v_info->obj); __pyx_v_info->obj = 0; + } + __pyx_L2:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":554 + * + * + * @property # <<<<<<<<<<<<<< + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_1T___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + struct __pyx_memoryviewslice_obj *__pyx_v_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":556 + * @property + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) # <<<<<<<<<<<<<< + * transpose_memslice(&result.from_slice) + * return result + */ + __pyx_t_1 = __pyx_memoryview_copy_object(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 556, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + if (!(likely(((__pyx_t_1) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_1, __pyx_memoryviewslice_type))))) __PYX_ERR(1, 556, __pyx_L1_error) + __pyx_v_result = ((struct __pyx_memoryviewslice_obj *)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":557 + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + * transpose_memslice(&result.from_slice) # <<<<<<<<<<<<<< + * return result + * + */ + __pyx_t_2 = __pyx_memslice_transpose((&__pyx_v_result->from_slice)); if (unlikely(__pyx_t_2 == ((int)-1))) __PYX_ERR(1, 557, __pyx_L1_error) + + /* "View.MemoryView":558 + * cdef _memoryviewslice result = memoryview_copy(self) + * transpose_memslice(&result.from_slice) + * return result # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":554 + * + * + * @property # <<<<<<<<<<<<<< + * def T(self): + * cdef _memoryviewslice result = memoryview_copy(self) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.T.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":560 + * return result + * + * @property # <<<<<<<<<<<<<< + * def base(self): + * return self._get_base() + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4base___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":562 + * @property + * def base(self): + * return self._get_base() # <<<<<<<<<<<<<< + * + * cdef _get_base(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = ((struct __pyx_vtabstruct_memoryview *)__pyx_v_self->__pyx_vtab)->_get_base(__pyx_v_self); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 562, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":560 + * return result + * + * @property # <<<<<<<<<<<<<< + * def base(self): + * return self._get_base() + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.base.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":564 + * return self._get_base() + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.obj + * + */ + +static PyObject *__pyx_memoryview__get_base(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_base", 1); + + /* "View.MemoryView":565 + * + * cdef _get_base(self): + * return self.obj # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->obj); + __pyx_r = __pyx_v_self->obj; + goto __pyx_L0; + + /* "View.MemoryView":564 + * return self._get_base() + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.obj + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":567 + * return self.obj + * + * @property # <<<<<<<<<<<<<< + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_5shape___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_7genexpr__pyx_v_length; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":569 + * @property + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_1 = PyList_New(0); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_3 = (__pyx_v_self->view.shape + __pyx_v_self->view.ndim); + for (__pyx_t_4 = __pyx_v_self->view.shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_7genexpr__pyx_v_length = (__pyx_t_2[0]); + __pyx_t_5 = PyInt_FromSsize_t(__pyx_7genexpr__pyx_v_length); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_1, (PyObject*)__pyx_t_5))) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + } + } /* exit inner scope */ + __pyx_t_5 = PyList_AsTuple(((PyObject*)__pyx_t_1)); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 569, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_r = __pyx_t_5; + __pyx_t_5 = 0; + goto __pyx_L0; + + /* "View.MemoryView":567 + * return self.obj + * + * @property # <<<<<<<<<<<<<< + * def shape(self): + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.shape.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":571 + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def strides(self): + * if self.view.strides == NULL: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_7strides___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_8genexpr1__pyx_v_stride; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + Py_ssize_t *__pyx_t_5; + PyObject *__pyx_t_6 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":573 + * @property + * def strides(self): + * if self.view.strides == NULL: # <<<<<<<<<<<<<< + * + * raise ValueError, "Buffer view does not expose strides" + */ + __pyx_t_1 = (__pyx_v_self->view.strides == NULL); + if (unlikely(__pyx_t_1)) { + + /* "View.MemoryView":575 + * if self.view.strides == NULL: + * + * raise ValueError, "Buffer view does not expose strides" # <<<<<<<<<<<<<< + * + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Buffer_view_does_not_expose_stri, 0, 0); + __PYX_ERR(1, 575, __pyx_L1_error) + + /* "View.MemoryView":573 + * @property + * def strides(self): + * if self.view.strides == NULL: # <<<<<<<<<<<<<< + * + * raise ValueError, "Buffer view does not expose strides" + */ + } + + /* "View.MemoryView":577 + * raise ValueError, "Buffer view does not expose strides" + * + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = (__pyx_v_self->view.strides + __pyx_v_self->view.ndim); + for (__pyx_t_5 = __pyx_v_self->view.strides; __pyx_t_5 < __pyx_t_4; __pyx_t_5++) { + __pyx_t_3 = __pyx_t_5; + __pyx_8genexpr1__pyx_v_stride = (__pyx_t_3[0]); + __pyx_t_6 = PyInt_FromSsize_t(__pyx_8genexpr1__pyx_v_stride); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + } /* exit inner scope */ + __pyx_t_6 = PyList_AsTuple(((PyObject*)__pyx_t_2)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 577, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_6; + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "View.MemoryView":571 + * return tuple([length for length in self.view.shape[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def strides(self): + * if self.view.strides == NULL: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.memoryview.strides.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":579 + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def suboffsets(self): + * if self.view.suboffsets == NULL: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_10suboffsets___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_8genexpr2__pyx_v_suboffset; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + Py_ssize_t *__pyx_t_5; + PyObject *__pyx_t_6 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":581 + * @property + * def suboffsets(self): + * if self.view.suboffsets == NULL: # <<<<<<<<<<<<<< + * return (-1,) * self.view.ndim + * + */ + __pyx_t_1 = (__pyx_v_self->view.suboffsets == NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":582 + * def suboffsets(self): + * if self.view.suboffsets == NULL: + * return (-1,) * self.view.ndim # <<<<<<<<<<<<<< + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PySequence_Multiply(__pyx_tuple__4, __pyx_v_self->view.ndim); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 582, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":581 + * @property + * def suboffsets(self): + * if self.view.suboffsets == NULL: # <<<<<<<<<<<<<< + * return (-1,) * self.view.ndim + * + */ + } + + /* "View.MemoryView":584 + * return (-1,) * self.view.ndim + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + { /* enter inner scope */ + __pyx_t_2 = PyList_New(0); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_4 = (__pyx_v_self->view.suboffsets + __pyx_v_self->view.ndim); + for (__pyx_t_5 = __pyx_v_self->view.suboffsets; __pyx_t_5 < __pyx_t_4; __pyx_t_5++) { + __pyx_t_3 = __pyx_t_5; + __pyx_8genexpr2__pyx_v_suboffset = (__pyx_t_3[0]); + __pyx_t_6 = PyInt_FromSsize_t(__pyx_8genexpr2__pyx_v_suboffset); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + if (unlikely(__Pyx_ListComp_Append(__pyx_t_2, (PyObject*)__pyx_t_6))) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + } /* exit inner scope */ + __pyx_t_6 = PyList_AsTuple(((PyObject*)__pyx_t_2)); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 584, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_6; + __pyx_t_6 = 0; + goto __pyx_L0; + + /* "View.MemoryView":579 + * return tuple([stride for stride in self.view.strides[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def suboffsets(self): + * if self.view.suboffsets == NULL: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_AddTraceback("View.MemoryView.memoryview.suboffsets.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":586 + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def ndim(self): + * return self.view.ndim + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4ndim___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":588 + * @property + * def ndim(self): + * return self.view.ndim # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_self->view.ndim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 588, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":586 + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + * + * @property # <<<<<<<<<<<<<< + * def ndim(self): + * return self.view.ndim + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.ndim.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":590 + * return self.view.ndim + * + * @property # <<<<<<<<<<<<<< + * def itemsize(self): + * return self.view.itemsize + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_8itemsize___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":592 + * @property + * def itemsize(self): + * return self.view.itemsize # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyInt_FromSsize_t(__pyx_v_self->view.itemsize); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 592, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":590 + * return self.view.ndim + * + * @property # <<<<<<<<<<<<<< + * def itemsize(self): + * return self.view.itemsize + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview.itemsize.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":594 + * return self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def nbytes(self): + * return self.size * self.view.itemsize + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_6nbytes___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":596 + * @property + * def nbytes(self): + * return self.size * self.view.itemsize # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_size); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_self->view.itemsize); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyNumber_Multiply(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 596, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_3; + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "View.MemoryView":594 + * return self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def nbytes(self): + * return self.size * self.view.itemsize + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.nbytes.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":598 + * return self.size * self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def size(self): + * if self._size is None: + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__get__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView_10memoryview_4size___get__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_v_result = NULL; + PyObject *__pyx_v_length = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__get__", 1); + + /* "View.MemoryView":600 + * @property + * def size(self): + * if self._size is None: # <<<<<<<<<<<<<< + * result = 1 + * + */ + __pyx_t_1 = (__pyx_v_self->_size == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":601 + * def size(self): + * if self._size is None: + * result = 1 # <<<<<<<<<<<<<< + * + * for length in self.view.shape[:self.view.ndim]: + */ + __Pyx_INCREF(__pyx_int_1); + __pyx_v_result = __pyx_int_1; + + /* "View.MemoryView":603 + * result = 1 + * + * for length in self.view.shape[:self.view.ndim]: # <<<<<<<<<<<<<< + * result *= length + * + */ + __pyx_t_3 = (__pyx_v_self->view.shape + __pyx_v_self->view.ndim); + for (__pyx_t_4 = __pyx_v_self->view.shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_t_5 = PyInt_FromSsize_t((__pyx_t_2[0])); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 603, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_XDECREF_SET(__pyx_v_length, __pyx_t_5); + __pyx_t_5 = 0; + + /* "View.MemoryView":604 + * + * for length in self.view.shape[:self.view.ndim]: + * result *= length # <<<<<<<<<<<<<< + * + * self._size = result + */ + __pyx_t_5 = PyNumber_InPlaceMultiply(__pyx_v_result, __pyx_v_length); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 604, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF_SET(__pyx_v_result, __pyx_t_5); + __pyx_t_5 = 0; + } + + /* "View.MemoryView":606 + * result *= length + * + * self._size = result # <<<<<<<<<<<<<< + * + * return self._size + */ + __Pyx_INCREF(__pyx_v_result); + __Pyx_GIVEREF(__pyx_v_result); + __Pyx_GOTREF(__pyx_v_self->_size); + __Pyx_DECREF(__pyx_v_self->_size); + __pyx_v_self->_size = __pyx_v_result; + + /* "View.MemoryView":600 + * @property + * def size(self): + * if self._size is None: # <<<<<<<<<<<<<< + * result = 1 + * + */ + } + + /* "View.MemoryView":608 + * self._size = result + * + * return self._size # <<<<<<<<<<<<<< + * + * def __len__(self): + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->_size); + __pyx_r = __pyx_v_self->_size; + goto __pyx_L0; + + /* "View.MemoryView":598 + * return self.size * self.view.itemsize + * + * @property # <<<<<<<<<<<<<< + * def size(self): + * if self._size is None: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.memoryview.size.__get__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XDECREF(__pyx_v_length); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":610 + * return self._size + * + * def __len__(self): # <<<<<<<<<<<<<< + * if self.view.ndim >= 1: + * return self.view.shape[0] + */ + +/* Python wrapper */ +static Py_ssize_t __pyx_memoryview___len__(PyObject *__pyx_v_self); /*proto*/ +static Py_ssize_t __pyx_memoryview___len__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + Py_ssize_t __pyx_r; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__len__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static Py_ssize_t __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_10__len__(struct __pyx_memoryview_obj *__pyx_v_self) { + Py_ssize_t __pyx_r; + int __pyx_t_1; + + /* "View.MemoryView":611 + * + * def __len__(self): + * if self.view.ndim >= 1: # <<<<<<<<<<<<<< + * return self.view.shape[0] + * + */ + __pyx_t_1 = (__pyx_v_self->view.ndim >= 1); + if (__pyx_t_1) { + + /* "View.MemoryView":612 + * def __len__(self): + * if self.view.ndim >= 1: + * return self.view.shape[0] # <<<<<<<<<<<<<< + * + * return 0 + */ + __pyx_r = (__pyx_v_self->view.shape[0]); + goto __pyx_L0; + + /* "View.MemoryView":611 + * + * def __len__(self): + * if self.view.ndim >= 1: # <<<<<<<<<<<<<< + * return self.view.shape[0] + * + */ + } + + /* "View.MemoryView":614 + * return self.view.shape[0] + * + * return 0 # <<<<<<<<<<<<<< + * + * def __repr__(self): + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":610 + * return self._size + * + * def __len__(self): # <<<<<<<<<<<<<< + * if self.view.ndim >= 1: + * return self.view.shape[0] + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":616 + * return 0 + * + * def __repr__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__, + * id(self)) + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___repr__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_memoryview___repr__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__repr__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_12__repr__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__repr__", 1); + + /* "View.MemoryView":617 + * + * def __repr__(self): + * return "" % (self.base.__class__.__name__, # <<<<<<<<<<<<<< + * id(self)) + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_base); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_class); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_name_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":618 + * def __repr__(self): + * return "" % (self.base.__class__.__name__, + * id(self)) # <<<<<<<<<<<<<< + * + * def __str__(self): + */ + __pyx_t_2 = __Pyx_PyObject_CallOneArg(__pyx_builtin_id, ((PyObject *)__pyx_v_self)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 618, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + + /* "View.MemoryView":617 + * + * def __repr__(self): + * return "" % (self.base.__class__.__name__, # <<<<<<<<<<<<<< + * id(self)) + * + */ + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(1, 617, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyString_Format(__pyx_kp_s_MemoryView_of_r_at_0x_x, __pyx_t_3); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 617, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":616 + * return 0 + * + * def __repr__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__, + * id(self)) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview.__repr__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":620 + * id(self)) + * + * def __str__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__,) + * + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview___str__(PyObject *__pyx_v_self); /*proto*/ +static PyObject *__pyx_memoryview___str__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__str__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_14__str__(struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__str__", 1); + + /* "View.MemoryView":621 + * + * def __str__(self): + * return "" % (self.base.__class__.__name__,) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v_self), __pyx_n_s_base); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyObject_GetAttrStr(__pyx_t_1, __pyx_n_s_class); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyObject_GetAttrStr(__pyx_t_2, __pyx_n_s_name_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_2 = PyTuple_New(1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_2, 0, __pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_MemoryView_of_r_object, __pyx_t_2); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 621, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":620 + * id(self)) + * + * def __str__(self): # <<<<<<<<<<<<<< + * return "" % (self.base.__class__.__name__,) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.__str__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":624 + * + * + * def is_c_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_is_c_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_is_c_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("is_c_contig (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("is_c_contig", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "is_c_contig", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_16is_c_contig(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice *__pyx_v_mslice; + __Pyx_memviewslice __pyx_v_tmp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_c_contig", 1); + + /* "View.MemoryView":627 + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) # <<<<<<<<<<<<<< + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_self, (&__pyx_v_tmp)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 627, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":628 + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) + * return slice_is_contig(mslice[0], 'C', self.view.ndim) # <<<<<<<<<<<<<< + * + * def is_f_contig(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_memviewslice_is_contig((__pyx_v_mslice[0]), 'C', __pyx_v_self->view.ndim)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 628, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":624 + * + * + * def is_c_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_c_contig", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":630 + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + * def is_f_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_is_f_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_is_f_contig(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("is_f_contig (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("is_f_contig", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "is_f_contig", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_18is_f_contig(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice *__pyx_v_mslice; + __Pyx_memviewslice __pyx_v_tmp; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice *__pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("is_f_contig", 1); + + /* "View.MemoryView":633 + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) # <<<<<<<<<<<<<< + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + */ + __pyx_t_1 = __pyx_memoryview_get_slice_from_memoryview(__pyx_v_self, (&__pyx_v_tmp)); if (unlikely(__pyx_t_1 == ((__Pyx_memviewslice *)NULL))) __PYX_ERR(1, 633, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":634 + * cdef __Pyx_memviewslice tmp + * mslice = get_slice_from_memview(self, &tmp) + * return slice_is_contig(mslice[0], 'F', self.view.ndim) # <<<<<<<<<<<<<< + * + * def copy(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_memviewslice_is_contig((__pyx_v_mslice[0]), 'F', __pyx_v_self->view.ndim)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 634, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":630 + * return slice_is_contig(mslice[0], 'C', self.view.ndim) + * + * def is_f_contig(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice *mslice + * cdef __Pyx_memviewslice tmp + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.is_f_contig", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":636 + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + * def copy(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_copy(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_copy(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("copy (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("copy", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "copy", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_20copy(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice __pyx_v_mslice; + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("copy", 1); + + /* "View.MemoryView":638 + * def copy(self): + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS # <<<<<<<<<<<<<< + * + * slice_copy(self, &mslice) + */ + __pyx_v_flags = (__pyx_v_self->flags & (~PyBUF_F_CONTIGUOUS)); + + /* "View.MemoryView":640 + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + * + * slice_copy(self, &mslice) # <<<<<<<<<<<<<< + * mslice = slice_copy_contig(&mslice, "c", self.view.ndim, + * self.view.itemsize, + */ + __pyx_memoryview_slice_copy(__pyx_v_self, (&__pyx_v_mslice)); + + /* "View.MemoryView":641 + * + * slice_copy(self, &mslice) + * mslice = slice_copy_contig(&mslice, "c", self.view.ndim, # <<<<<<<<<<<<<< + * self.view.itemsize, + * flags|PyBUF_C_CONTIGUOUS, + */ + __pyx_t_1 = __pyx_memoryview_copy_new_contig((&__pyx_v_mslice), ((char *)"c"), __pyx_v_self->view.ndim, __pyx_v_self->view.itemsize, (__pyx_v_flags | PyBUF_C_CONTIGUOUS), __pyx_v_self->dtype_is_object); if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 641, __pyx_L1_error) + __pyx_v_mslice = __pyx_t_1; + + /* "View.MemoryView":646 + * self.dtype_is_object) + * + * return memoryview_copy_from_slice(self, &mslice) # <<<<<<<<<<<<<< + * + * def copy_fortran(self): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_copy_object_from_slice(__pyx_v_self, (&__pyx_v_mslice)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 646, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":636 + * return slice_is_contig(mslice[0], 'F', self.view.ndim) + * + * def copy(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice mslice + * cdef int flags = self.flags & ~PyBUF_F_CONTIGUOUS + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.copy", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":648 + * return memoryview_copy_from_slice(self, &mslice) + * + * def copy_fortran(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + */ + +/* Python wrapper */ +static PyObject *__pyx_memoryview_copy_fortran(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_memoryview_copy_fortran(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("copy_fortran (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("copy_fortran", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "copy_fortran", 0))) return NULL; + __pyx_r = __pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_memoryview___pyx_pf_15View_dot_MemoryView_10memoryview_22copy_fortran(struct __pyx_memoryview_obj *__pyx_v_self) { + __Pyx_memviewslice __pyx_v_src; + __Pyx_memviewslice __pyx_v_dst; + int __pyx_v_flags; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_memviewslice __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("copy_fortran", 1); + + /* "View.MemoryView":650 + * def copy_fortran(self): + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS # <<<<<<<<<<<<<< + * + * slice_copy(self, &src) + */ + __pyx_v_flags = (__pyx_v_self->flags & (~PyBUF_C_CONTIGUOUS)); + + /* "View.MemoryView":652 + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + * + * slice_copy(self, &src) # <<<<<<<<<<<<<< + * dst = slice_copy_contig(&src, "fortran", self.view.ndim, + * self.view.itemsize, + */ + __pyx_memoryview_slice_copy(__pyx_v_self, (&__pyx_v_src)); + + /* "View.MemoryView":653 + * + * slice_copy(self, &src) + * dst = slice_copy_contig(&src, "fortran", self.view.ndim, # <<<<<<<<<<<<<< + * self.view.itemsize, + * flags|PyBUF_F_CONTIGUOUS, + */ + __pyx_t_1 = __pyx_memoryview_copy_new_contig((&__pyx_v_src), ((char *)"fortran"), __pyx_v_self->view.ndim, __pyx_v_self->view.itemsize, (__pyx_v_flags | PyBUF_F_CONTIGUOUS), __pyx_v_self->dtype_is_object); if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 653, __pyx_L1_error) + __pyx_v_dst = __pyx_t_1; + + /* "View.MemoryView":658 + * self.dtype_is_object) + * + * return memoryview_copy_from_slice(self, &dst) # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_copy_object_from_slice(__pyx_v_self, (&__pyx_v_dst)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 658, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":648 + * return memoryview_copy_from_slice(self, &mslice) + * + * def copy_fortran(self): # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice src, dst + * cdef int flags = self.flags & ~PyBUF_C_CONTIGUOUS + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.memoryview.copy_fortran", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryview_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryview_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_memoryview___reduce_cython__(((struct __pyx_memoryview_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryview___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryview_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryview_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.memoryview.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_memoryview_2__setstate_cython__(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryview_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryview_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.memoryview.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":662 + * + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): # <<<<<<<<<<<<<< + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + */ + +static PyObject *__pyx_memoryview_new(PyObject *__pyx_v_o, int __pyx_v_flags, int __pyx_v_dtype_is_object, __Pyx_TypeInfo *__pyx_v_typeinfo) { + struct __pyx_memoryview_obj *__pyx_v_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_cwrapper", 1); + + /* "View.MemoryView":663 + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): + * cdef memoryview result = memoryview(o, flags, dtype_is_object) # <<<<<<<<<<<<<< + * result.typeinfo = typeinfo + * return result + */ + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_flags); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(__pyx_v_o); + __Pyx_GIVEREF(__pyx_v_o); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_o)) __PYX_ERR(1, 663, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_1)) __PYX_ERR(1, 663, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_2 = 0; + __pyx_t_2 = __Pyx_PyObject_Call(((PyObject *)__pyx_memoryview_type), __pyx_t_3, NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 663, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":664 + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo # <<<<<<<<<<<<<< + * return result + * + */ + __pyx_v_result->typeinfo = __pyx_v_typeinfo; + + /* "View.MemoryView":665 + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + * return result # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_check') + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":662 + * + * @cname('__pyx_memoryview_new') + * cdef memoryview_cwrapper(object o, int flags, bint dtype_is_object, __Pyx_TypeInfo *typeinfo): # <<<<<<<<<<<<<< + * cdef memoryview result = memoryview(o, flags, dtype_is_object) + * result.typeinfo = typeinfo + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview_cwrapper", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":668 + * + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: # <<<<<<<<<<<<<< + * return isinstance(o, memoryview) + * + */ + +static CYTHON_INLINE int __pyx_memoryview_check(PyObject *__pyx_v_o) { + int __pyx_r; + int __pyx_t_1; + + /* "View.MemoryView":669 + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: + * return isinstance(o, memoryview) # <<<<<<<<<<<<<< + * + * cdef tuple _unellipsify(object index, int ndim): + */ + __pyx_t_1 = __Pyx_TypeCheck(__pyx_v_o, __pyx_memoryview_type); + __pyx_r = __pyx_t_1; + goto __pyx_L0; + + /* "View.MemoryView":668 + * + * @cname('__pyx_memoryview_check') + * cdef inline bint memoryview_check(object o) noexcept: # <<<<<<<<<<<<<< + * return isinstance(o, memoryview) + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":671 + * return isinstance(o, memoryview) + * + * cdef tuple _unellipsify(object index, int ndim): # <<<<<<<<<<<<<< + * """ + * Replace all ellipses with full slices and fill incomplete indices with + */ + +static PyObject *_unellipsify(PyObject *__pyx_v_index, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_idx; + PyObject *__pyx_v_tup = NULL; + PyObject *__pyx_v_result = NULL; + int __pyx_v_have_slices; + int __pyx_v_seen_ellipsis; + PyObject *__pyx_v_item = NULL; + Py_ssize_t __pyx_v_nslices; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_UCS4 __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("_unellipsify", 1); + + /* "View.MemoryView":677 + * """ + * cdef Py_ssize_t idx + * tup = index if isinstance(index, tuple) else (index,) # <<<<<<<<<<<<<< + * + * result = [slice(None)] * ndim + */ + __pyx_t_2 = PyTuple_Check(__pyx_v_index); + if (__pyx_t_2) { + __Pyx_INCREF(((PyObject*)__pyx_v_index)); + __pyx_t_1 = __pyx_v_index; + } else { + __pyx_t_3 = PyTuple_New(1); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 677, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(__pyx_v_index); + __Pyx_GIVEREF(__pyx_v_index); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_v_index)) __PYX_ERR(1, 677, __pyx_L1_error); + __pyx_t_1 = __pyx_t_3; + __pyx_t_3 = 0; + } + __pyx_v_tup = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":679 + * tup = index if isinstance(index, tuple) else (index,) + * + * result = [slice(None)] * ndim # <<<<<<<<<<<<<< + * have_slices = False + * seen_ellipsis = False + */ + __pyx_t_1 = PyList_New(1 * ((__pyx_v_ndim<0) ? 0:__pyx_v_ndim)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 679, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + { Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < __pyx_v_ndim; __pyx_temp++) { + __Pyx_INCREF(__pyx_slice__5); + __Pyx_GIVEREF(__pyx_slice__5); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, __pyx_temp, __pyx_slice__5)) __PYX_ERR(1, 679, __pyx_L1_error); + } + } + __pyx_v_result = ((PyObject*)__pyx_t_1); + __pyx_t_1 = 0; + + /* "View.MemoryView":680 + * + * result = [slice(None)] * ndim + * have_slices = False # <<<<<<<<<<<<<< + * seen_ellipsis = False + * idx = 0 + */ + __pyx_v_have_slices = 0; + + /* "View.MemoryView":681 + * result = [slice(None)] * ndim + * have_slices = False + * seen_ellipsis = False # <<<<<<<<<<<<<< + * idx = 0 + * for item in tup: + */ + __pyx_v_seen_ellipsis = 0; + + /* "View.MemoryView":682 + * have_slices = False + * seen_ellipsis = False + * idx = 0 # <<<<<<<<<<<<<< + * for item in tup: + * if item is Ellipsis: + */ + __pyx_v_idx = 0; + + /* "View.MemoryView":683 + * seen_ellipsis = False + * idx = 0 + * for item in tup: # <<<<<<<<<<<<<< + * if item is Ellipsis: + * if not seen_ellipsis: + */ + if (unlikely(__pyx_v_tup == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable"); + __PYX_ERR(1, 683, __pyx_L1_error) + } + __pyx_t_1 = __pyx_v_tup; __Pyx_INCREF(__pyx_t_1); + __pyx_t_4 = 0; + for (;;) { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_1); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 683, __pyx_L1_error) + #endif + if (__pyx_t_4 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_3 = PyTuple_GET_ITEM(__pyx_t_1, __pyx_t_4); __Pyx_INCREF(__pyx_t_3); __pyx_t_4++; if (unlikely((0 < 0))) __PYX_ERR(1, 683, __pyx_L1_error) + #else + __pyx_t_3 = __Pyx_PySequence_ITEM(__pyx_t_1, __pyx_t_4); __pyx_t_4++; if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 683, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + #endif + __Pyx_XDECREF_SET(__pyx_v_item, __pyx_t_3); + __pyx_t_3 = 0; + + /* "View.MemoryView":684 + * idx = 0 + * for item in tup: + * if item is Ellipsis: # <<<<<<<<<<<<<< + * if not seen_ellipsis: + * idx += ndim - len(tup) + */ + __pyx_t_2 = (__pyx_v_item == __pyx_builtin_Ellipsis); + if (__pyx_t_2) { + + /* "View.MemoryView":685 + * for item in tup: + * if item is Ellipsis: + * if not seen_ellipsis: # <<<<<<<<<<<<<< + * idx += ndim - len(tup) + * seen_ellipsis = True + */ + __pyx_t_2 = (!__pyx_v_seen_ellipsis); + if (__pyx_t_2) { + + /* "View.MemoryView":686 + * if item is Ellipsis: + * if not seen_ellipsis: + * idx += ndim - len(tup) # <<<<<<<<<<<<<< + * seen_ellipsis = True + * have_slices = True + */ + if (unlikely(__pyx_v_tup == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 686, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_PyTuple_GET_SIZE(__pyx_v_tup); if (unlikely(__pyx_t_5 == ((Py_ssize_t)-1))) __PYX_ERR(1, 686, __pyx_L1_error) + __pyx_v_idx = (__pyx_v_idx + (__pyx_v_ndim - __pyx_t_5)); + + /* "View.MemoryView":687 + * if not seen_ellipsis: + * idx += ndim - len(tup) + * seen_ellipsis = True # <<<<<<<<<<<<<< + * have_slices = True + * else: + */ + __pyx_v_seen_ellipsis = 1; + + /* "View.MemoryView":685 + * for item in tup: + * if item is Ellipsis: + * if not seen_ellipsis: # <<<<<<<<<<<<<< + * idx += ndim - len(tup) + * seen_ellipsis = True + */ + } + + /* "View.MemoryView":688 + * idx += ndim - len(tup) + * seen_ellipsis = True + * have_slices = True # <<<<<<<<<<<<<< + * else: + * if isinstance(item, slice): + */ + __pyx_v_have_slices = 1; + + /* "View.MemoryView":684 + * idx = 0 + * for item in tup: + * if item is Ellipsis: # <<<<<<<<<<<<<< + * if not seen_ellipsis: + * idx += ndim - len(tup) + */ + goto __pyx_L5; + } + + /* "View.MemoryView":690 + * have_slices = True + * else: + * if isinstance(item, slice): # <<<<<<<<<<<<<< + * have_slices = True + * elif not PyIndex_Check(item): + */ + /*else*/ { + __pyx_t_2 = PySlice_Check(__pyx_v_item); + if (__pyx_t_2) { + + /* "View.MemoryView":691 + * else: + * if isinstance(item, slice): + * have_slices = True # <<<<<<<<<<<<<< + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" + */ + __pyx_v_have_slices = 1; + + /* "View.MemoryView":690 + * have_slices = True + * else: + * if isinstance(item, slice): # <<<<<<<<<<<<<< + * have_slices = True + * elif not PyIndex_Check(item): + */ + goto __pyx_L7; + } + + /* "View.MemoryView":692 + * if isinstance(item, slice): + * have_slices = True + * elif not PyIndex_Check(item): # <<<<<<<<<<<<<< + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + */ + __pyx_t_2 = (!(PyIndex_Check(__pyx_v_item) != 0)); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":693 + * have_slices = True + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" # <<<<<<<<<<<<<< + * result[idx] = item + * idx += 1 + */ + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_5 = 0; + __pyx_t_6 = 127; + __Pyx_INCREF(__pyx_kp_u_Cannot_index_with_type); + __pyx_t_5 += 24; + __Pyx_GIVEREF(__pyx_kp_u_Cannot_index_with_type); + PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_kp_u_Cannot_index_with_type); + __pyx_t_7 = __Pyx_PyObject_FormatSimple(((PyObject *)Py_TYPE(__pyx_v_item)), __pyx_empty_unicode); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_6 = (__Pyx_PyUnicode_MAX_CHAR_VALUE(__pyx_t_7) > __pyx_t_6) ? __Pyx_PyUnicode_MAX_CHAR_VALUE(__pyx_t_7) : __pyx_t_6; + __pyx_t_5 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_7); + __pyx_t_7 = 0; + __Pyx_INCREF(__pyx_kp_u__6); + __pyx_t_5 += 1; + __Pyx_GIVEREF(__pyx_kp_u__6); + PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_kp_u__6); + __pyx_t_7 = __Pyx_PyUnicode_Join(__pyx_t_3, 3, __pyx_t_5, __pyx_t_6); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 693, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_t_7, 0, 0); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __PYX_ERR(1, 693, __pyx_L1_error) + + /* "View.MemoryView":692 + * if isinstance(item, slice): + * have_slices = True + * elif not PyIndex_Check(item): # <<<<<<<<<<<<<< + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + */ + } + __pyx_L7:; + + /* "View.MemoryView":694 + * elif not PyIndex_Check(item): + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item # <<<<<<<<<<<<<< + * idx += 1 + * + */ + if (unlikely((__Pyx_SetItemInt(__pyx_v_result, __pyx_v_idx, __pyx_v_item, Py_ssize_t, 1, PyInt_FromSsize_t, 1, 1, 1) < 0))) __PYX_ERR(1, 694, __pyx_L1_error) + } + __pyx_L5:; + + /* "View.MemoryView":695 + * raise TypeError, f"Cannot index with type '{type(item)}'" + * result[idx] = item + * idx += 1 # <<<<<<<<<<<<<< + * + * nslices = ndim - idx + */ + __pyx_v_idx = (__pyx_v_idx + 1); + + /* "View.MemoryView":683 + * seen_ellipsis = False + * idx = 0 + * for item in tup: # <<<<<<<<<<<<<< + * if item is Ellipsis: + * if not seen_ellipsis: + */ + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "View.MemoryView":697 + * idx += 1 + * + * nslices = ndim - idx # <<<<<<<<<<<<<< + * return have_slices or nslices, tuple(result) + * + */ + __pyx_v_nslices = (__pyx_v_ndim - __pyx_v_idx); + + /* "View.MemoryView":698 + * + * nslices = ndim - idx + * return have_slices or nslices, tuple(result) # <<<<<<<<<<<<<< + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + */ + __Pyx_XDECREF(__pyx_r); + if (!__pyx_v_have_slices) { + } else { + __pyx_t_7 = __Pyx_PyBool_FromLong(__pyx_v_have_slices); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_1 = __pyx_t_7; + __pyx_t_7 = 0; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_7 = PyInt_FromSsize_t(__pyx_v_nslices); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_1 = __pyx_t_7; + __pyx_t_7 = 0; + __pyx_L9_bool_binop_done:; + __pyx_t_7 = PyList_AsTuple(__pyx_v_result); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_3 = PyTuple_New(2); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 698, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_t_1)) __PYX_ERR(1, 698, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_7)) __PYX_ERR(1, 698, __pyx_L1_error); + __pyx_t_1 = 0; + __pyx_t_7 = 0; + __pyx_r = ((PyObject*)__pyx_t_3); + __pyx_t_3 = 0; + goto __pyx_L0; + + /* "View.MemoryView":671 + * return isinstance(o, memoryview) + * + * cdef tuple _unellipsify(object index, int ndim): # <<<<<<<<<<<<<< + * """ + * Replace all ellipses with full slices and fill incomplete indices with + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView._unellipsify", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v_tup); + __Pyx_XDECREF(__pyx_v_result); + __Pyx_XDECREF(__pyx_v_item); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":700 + * return have_slices or nslices, tuple(result) + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: # <<<<<<<<<<<<<< + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + */ + +static int assert_direct_dimensions(Py_ssize_t *__pyx_v_suboffsets, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_suboffset; + int __pyx_r; + Py_ssize_t *__pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + int __pyx_t_4; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "View.MemoryView":701 + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: # <<<<<<<<<<<<<< + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" + */ + __pyx_t_2 = (__pyx_v_suboffsets + __pyx_v_ndim); + for (__pyx_t_3 = __pyx_v_suboffsets; __pyx_t_3 < __pyx_t_2; __pyx_t_3++) { + __pyx_t_1 = __pyx_t_3; + __pyx_v_suboffset = (__pyx_t_1[0]); + + /* "View.MemoryView":702 + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag + */ + __pyx_t_4 = (__pyx_v_suboffset >= 0); + if (unlikely(__pyx_t_4)) { + + /* "View.MemoryView":703 + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" # <<<<<<<<<<<<<< + * return 0 # return type just used as an error flag + * + */ + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_kp_s_Indirect_dimensions_not_supporte, 0, 0); + __PYX_ERR(1, 703, __pyx_L1_error) + + /* "View.MemoryView":702 + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag + */ + } + } + + /* "View.MemoryView":704 + * if suboffset >= 0: + * raise ValueError, "Indirect dimensions not supported" + * return 0 # return type just used as an error flag # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":700 + * return have_slices or nslices, tuple(result) + * + * cdef int assert_direct_dimensions(Py_ssize_t *suboffsets, int ndim) except -1: # <<<<<<<<<<<<<< + * for suboffset in suboffsets[:ndim]: + * if suboffset >= 0: + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView.assert_direct_dimensions", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":711 + * + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): # <<<<<<<<<<<<<< + * cdef int new_ndim = 0, suboffset_dim = -1, dim + * cdef bint negative_step + */ + +static struct __pyx_memoryview_obj *__pyx_memview_slice(struct __pyx_memoryview_obj *__pyx_v_memview, PyObject *__pyx_v_indices) { + int __pyx_v_new_ndim; + int __pyx_v_suboffset_dim; + int __pyx_v_dim; + __Pyx_memviewslice __pyx_v_src; + __Pyx_memviewslice __pyx_v_dst; + __Pyx_memviewslice *__pyx_v_p_src; + struct __pyx_memoryviewslice_obj *__pyx_v_memviewsliceobj = 0; + __Pyx_memviewslice *__pyx_v_p_dst; + int *__pyx_v_p_suboffset_dim; + Py_ssize_t __pyx_v_start; + Py_ssize_t __pyx_v_stop; + Py_ssize_t __pyx_v_step; + Py_ssize_t __pyx_v_cindex; + int __pyx_v_have_start; + int __pyx_v_have_stop; + int __pyx_v_have_step; + PyObject *__pyx_v_index = NULL; + struct __pyx_memoryview_obj *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + struct __pyx_memoryview_obj *__pyx_t_3; + char *__pyx_t_4; + int __pyx_t_5; + Py_ssize_t __pyx_t_6; + PyObject *(*__pyx_t_7)(PyObject *); + PyObject *__pyx_t_8 = NULL; + Py_ssize_t __pyx_t_9; + int __pyx_t_10; + Py_ssize_t __pyx_t_11; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memview_slice", 1); + + /* "View.MemoryView":712 + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): + * cdef int new_ndim = 0, suboffset_dim = -1, dim # <<<<<<<<<<<<<< + * cdef bint negative_step + * cdef __Pyx_memviewslice src, dst + */ + __pyx_v_new_ndim = 0; + __pyx_v_suboffset_dim = -1; + + /* "View.MemoryView":719 + * + * + * memset(&dst, 0, sizeof(dst)) # <<<<<<<<<<<<<< + * + * cdef _memoryviewslice memviewsliceobj + */ + (void)(memset((&__pyx_v_dst), 0, (sizeof(__pyx_v_dst)))); + + /* "View.MemoryView":723 + * cdef _memoryviewslice memviewsliceobj + * + * assert memview.view.ndim > 0 # <<<<<<<<<<<<<< + * + * if isinstance(memview, _memoryviewslice): + */ + #ifndef CYTHON_WITHOUT_ASSERTIONS + if (unlikely(__pyx_assertions_enabled())) { + __pyx_t_1 = (__pyx_v_memview->view.ndim > 0); + if (unlikely(!__pyx_t_1)) { + __Pyx_Raise(__pyx_builtin_AssertionError, 0, 0, 0); + __PYX_ERR(1, 723, __pyx_L1_error) + } + } + #else + if ((1)); else __PYX_ERR(1, 723, __pyx_L1_error) + #endif + + /* "View.MemoryView":725 + * assert memview.view.ndim > 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":726 + * + * if isinstance(memview, _memoryviewslice): + * memviewsliceobj = memview # <<<<<<<<<<<<<< + * p_src = &memviewsliceobj.from_slice + * else: + */ + if (!(likely(((((PyObject *)__pyx_v_memview)) == Py_None) || likely(__Pyx_TypeTest(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type))))) __PYX_ERR(1, 726, __pyx_L1_error) + __pyx_t_2 = ((PyObject *)__pyx_v_memview); + __Pyx_INCREF(__pyx_t_2); + __pyx_v_memviewsliceobj = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":727 + * if isinstance(memview, _memoryviewslice): + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice # <<<<<<<<<<<<<< + * else: + * slice_copy(memview, &src) + */ + __pyx_v_p_src = (&__pyx_v_memviewsliceobj->from_slice); + + /* "View.MemoryView":725 + * assert memview.view.ndim > 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * memviewsliceobj = memview + * p_src = &memviewsliceobj.from_slice + */ + goto __pyx_L3; + } + + /* "View.MemoryView":729 + * p_src = &memviewsliceobj.from_slice + * else: + * slice_copy(memview, &src) # <<<<<<<<<<<<<< + * p_src = &src + * + */ + /*else*/ { + __pyx_memoryview_slice_copy(__pyx_v_memview, (&__pyx_v_src)); + + /* "View.MemoryView":730 + * else: + * slice_copy(memview, &src) + * p_src = &src # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_p_src = (&__pyx_v_src); + } + __pyx_L3:; + + /* "View.MemoryView":736 + * + * + * dst.memview = p_src.memview # <<<<<<<<<<<<<< + * dst.data = p_src.data + * + */ + __pyx_t_3 = __pyx_v_p_src->memview; + __pyx_v_dst.memview = __pyx_t_3; + + /* "View.MemoryView":737 + * + * dst.memview = p_src.memview + * dst.data = p_src.data # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_4 = __pyx_v_p_src->data; + __pyx_v_dst.data = __pyx_t_4; + + /* "View.MemoryView":742 + * + * + * cdef __Pyx_memviewslice *p_dst = &dst # <<<<<<<<<<<<<< + * cdef int *p_suboffset_dim = &suboffset_dim + * cdef Py_ssize_t start, stop, step, cindex + */ + __pyx_v_p_dst = (&__pyx_v_dst); + + /* "View.MemoryView":743 + * + * cdef __Pyx_memviewslice *p_dst = &dst + * cdef int *p_suboffset_dim = &suboffset_dim # <<<<<<<<<<<<<< + * cdef Py_ssize_t start, stop, step, cindex + * cdef bint have_start, have_stop, have_step + */ + __pyx_v_p_suboffset_dim = (&__pyx_v_suboffset_dim); + + /* "View.MemoryView":747 + * cdef bint have_start, have_stop, have_step + * + * for dim, index in enumerate(indices): # <<<<<<<<<<<<<< + * if PyIndex_Check(index): + * cindex = index + */ + __pyx_t_5 = 0; + if (likely(PyList_CheckExact(__pyx_v_indices)) || PyTuple_CheckExact(__pyx_v_indices)) { + __pyx_t_2 = __pyx_v_indices; __Pyx_INCREF(__pyx_t_2); + __pyx_t_6 = 0; + __pyx_t_7 = NULL; + } else { + __pyx_t_6 = -1; __pyx_t_2 = PyObject_GetIter(__pyx_v_indices); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_7 = __Pyx_PyObject_GetIterNextFunc(__pyx_t_2); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 747, __pyx_L1_error) + } + for (;;) { + if (likely(!__pyx_t_7)) { + if (likely(PyList_CheckExact(__pyx_t_2))) { + { + Py_ssize_t __pyx_temp = __Pyx_PyList_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #endif + if (__pyx_t_6 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_8 = PyList_GET_ITEM(__pyx_t_2, __pyx_t_6); __Pyx_INCREF(__pyx_t_8); __pyx_t_6++; if (unlikely((0 < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #else + __pyx_t_8 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + #endif + } else { + { + Py_ssize_t __pyx_temp = __Pyx_PyTuple_GET_SIZE(__pyx_t_2); + #if !CYTHON_ASSUME_SAFE_MACROS + if (unlikely((__pyx_temp < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #endif + if (__pyx_t_6 >= __pyx_temp) break; + } + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + __pyx_t_8 = PyTuple_GET_ITEM(__pyx_t_2, __pyx_t_6); __Pyx_INCREF(__pyx_t_8); __pyx_t_6++; if (unlikely((0 < 0))) __PYX_ERR(1, 747, __pyx_L1_error) + #else + __pyx_t_8 = __Pyx_PySequence_ITEM(__pyx_t_2, __pyx_t_6); __pyx_t_6++; if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 747, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + #endif + } + } else { + __pyx_t_8 = __pyx_t_7(__pyx_t_2); + if (unlikely(!__pyx_t_8)) { + PyObject* exc_type = PyErr_Occurred(); + if (exc_type) { + if (likely(__Pyx_PyErr_GivenExceptionMatches(exc_type, PyExc_StopIteration))) PyErr_Clear(); + else __PYX_ERR(1, 747, __pyx_L1_error) + } + break; + } + __Pyx_GOTREF(__pyx_t_8); + } + __Pyx_XDECREF_SET(__pyx_v_index, __pyx_t_8); + __pyx_t_8 = 0; + __pyx_v_dim = __pyx_t_5; + __pyx_t_5 = (__pyx_t_5 + 1); + + /* "View.MemoryView":748 + * + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): # <<<<<<<<<<<<<< + * cindex = index + * slice_memviewslice( + */ + __pyx_t_1 = (PyIndex_Check(__pyx_v_index) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":749 + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): + * cindex = index # <<<<<<<<<<<<<< + * slice_memviewslice( + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + */ + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_v_index); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 749, __pyx_L1_error) + __pyx_v_cindex = __pyx_t_9; + + /* "View.MemoryView":750 + * if PyIndex_Check(index): + * cindex = index + * slice_memviewslice( # <<<<<<<<<<<<<< + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + * dim, new_ndim, p_suboffset_dim, + */ + __pyx_t_10 = __pyx_memoryview_slice_memviewslice(__pyx_v_p_dst, (__pyx_v_p_src->shape[__pyx_v_dim]), (__pyx_v_p_src->strides[__pyx_v_dim]), (__pyx_v_p_src->suboffsets[__pyx_v_dim]), __pyx_v_dim, __pyx_v_new_ndim, __pyx_v_p_suboffset_dim, __pyx_v_cindex, 0, 0, 0, 0, 0, 0); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(1, 750, __pyx_L1_error) + + /* "View.MemoryView":748 + * + * for dim, index in enumerate(indices): + * if PyIndex_Check(index): # <<<<<<<<<<<<<< + * cindex = index + * slice_memviewslice( + */ + goto __pyx_L6; + } + + /* "View.MemoryView":756 + * 0, 0, 0, # have_{start,stop,step} + * False) + * elif index is None: # <<<<<<<<<<<<<< + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + */ + __pyx_t_1 = (__pyx_v_index == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":757 + * False) + * elif index is None: + * p_dst.shape[new_ndim] = 1 # <<<<<<<<<<<<<< + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 + */ + (__pyx_v_p_dst->shape[__pyx_v_new_ndim]) = 1; + + /* "View.MemoryView":758 + * elif index is None: + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 # <<<<<<<<<<<<<< + * p_dst.suboffsets[new_ndim] = -1 + * new_ndim += 1 + */ + (__pyx_v_p_dst->strides[__pyx_v_new_ndim]) = 0; + + /* "View.MemoryView":759 + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 # <<<<<<<<<<<<<< + * new_ndim += 1 + * else: + */ + (__pyx_v_p_dst->suboffsets[__pyx_v_new_ndim]) = -1L; + + /* "View.MemoryView":760 + * p_dst.strides[new_ndim] = 0 + * p_dst.suboffsets[new_ndim] = -1 + * new_ndim += 1 # <<<<<<<<<<<<<< + * else: + * start = index.start or 0 + */ + __pyx_v_new_ndim = (__pyx_v_new_ndim + 1); + + /* "View.MemoryView":756 + * 0, 0, 0, # have_{start,stop,step} + * False) + * elif index is None: # <<<<<<<<<<<<<< + * p_dst.shape[new_ndim] = 1 + * p_dst.strides[new_ndim] = 0 + */ + goto __pyx_L6; + } + + /* "View.MemoryView":762 + * new_ndim += 1 + * else: + * start = index.start or 0 # <<<<<<<<<<<<<< + * stop = index.stop or 0 + * step = index.step or 0 + */ + /*else*/ { + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_start); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 762, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 762, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 762, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L7_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L7_bool_binop_done:; + __pyx_v_start = __pyx_t_9; + + /* "View.MemoryView":763 + * else: + * start = index.start or 0 + * stop = index.stop or 0 # <<<<<<<<<<<<<< + * step = index.step or 0 + * + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_stop); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 763, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 763, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 763, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L9_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L9_bool_binop_done:; + __pyx_v_stop = __pyx_t_9; + + /* "View.MemoryView":764 + * start = index.start or 0 + * stop = index.stop or 0 + * step = index.step or 0 # <<<<<<<<<<<<<< + * + * have_start = index.start is not None + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_step); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 764, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = __Pyx_PyObject_IsTrue(__pyx_t_8); if (unlikely((__pyx_t_1 < 0))) __PYX_ERR(1, 764, __pyx_L1_error) + if (!__pyx_t_1) { + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + } else { + __pyx_t_11 = __Pyx_PyIndex_AsSsize_t(__pyx_t_8); if (unlikely((__pyx_t_11 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 764, __pyx_L1_error) + __pyx_t_9 = __pyx_t_11; + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + goto __pyx_L11_bool_binop_done; + } + __pyx_t_9 = 0; + __pyx_L11_bool_binop_done:; + __pyx_v_step = __pyx_t_9; + + /* "View.MemoryView":766 + * step = index.step or 0 + * + * have_start = index.start is not None # <<<<<<<<<<<<<< + * have_stop = index.stop is not None + * have_step = index.step is not None + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_start); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 766, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_start = __pyx_t_1; + + /* "View.MemoryView":767 + * + * have_start = index.start is not None + * have_stop = index.stop is not None # <<<<<<<<<<<<<< + * have_step = index.step is not None + * + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_stop); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 767, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_stop = __pyx_t_1; + + /* "View.MemoryView":768 + * have_start = index.start is not None + * have_stop = index.stop is not None + * have_step = index.step is not None # <<<<<<<<<<<<<< + * + * slice_memviewslice( + */ + __pyx_t_8 = __Pyx_PyObject_GetAttrStr(__pyx_v_index, __pyx_n_s_step); if (unlikely(!__pyx_t_8)) __PYX_ERR(1, 768, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_8); + __pyx_t_1 = (__pyx_t_8 != Py_None); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __pyx_v_have_step = __pyx_t_1; + + /* "View.MemoryView":770 + * have_step = index.step is not None + * + * slice_memviewslice( # <<<<<<<<<<<<<< + * p_dst, p_src.shape[dim], p_src.strides[dim], p_src.suboffsets[dim], + * dim, new_ndim, p_suboffset_dim, + */ + __pyx_t_10 = __pyx_memoryview_slice_memviewslice(__pyx_v_p_dst, (__pyx_v_p_src->shape[__pyx_v_dim]), (__pyx_v_p_src->strides[__pyx_v_dim]), (__pyx_v_p_src->suboffsets[__pyx_v_dim]), __pyx_v_dim, __pyx_v_new_ndim, __pyx_v_p_suboffset_dim, __pyx_v_start, __pyx_v_stop, __pyx_v_step, __pyx_v_have_start, __pyx_v_have_stop, __pyx_v_have_step, 1); if (unlikely(__pyx_t_10 == ((int)-1))) __PYX_ERR(1, 770, __pyx_L1_error) + + /* "View.MemoryView":776 + * have_start, have_stop, have_step, + * True) + * new_ndim += 1 # <<<<<<<<<<<<<< + * + * if isinstance(memview, _memoryviewslice): + */ + __pyx_v_new_ndim = (__pyx_v_new_ndim + 1); + } + __pyx_L6:; + + /* "View.MemoryView":747 + * cdef bint have_start, have_stop, have_step + * + * for dim, index in enumerate(indices): # <<<<<<<<<<<<<< + * if PyIndex_Check(index): + * cindex = index + */ + } + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + + /* "View.MemoryView":778 + * new_ndim += 1 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":779 + * + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, # <<<<<<<<<<<<<< + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + + /* "View.MemoryView":780 + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, # <<<<<<<<<<<<<< + * memviewsliceobj.to_dtype_func, + * memview.dtype_is_object) + */ + if (unlikely(!__pyx_v_memviewsliceobj)) { __Pyx_RaiseUnboundLocalError("memviewsliceobj"); __PYX_ERR(1, 780, __pyx_L1_error) } + + /* "View.MemoryView":781 + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * else: + */ + if (unlikely(!__pyx_v_memviewsliceobj)) { __Pyx_RaiseUnboundLocalError("memviewsliceobj"); __PYX_ERR(1, 781, __pyx_L1_error) } + + /* "View.MemoryView":779 + * + * if isinstance(memview, _memoryviewslice): + * return memoryview_fromslice(dst, new_ndim, # <<<<<<<<<<<<<< + * memviewsliceobj.to_object_func, + * memviewsliceobj.to_dtype_func, + */ + __pyx_t_2 = __pyx_memoryview_fromslice(__pyx_v_dst, __pyx_v_new_ndim, __pyx_v_memviewsliceobj->to_object_func, __pyx_v_memviewsliceobj->to_dtype_func, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 779, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_memoryview_type))))) __PYX_ERR(1, 779, __pyx_L1_error) + __pyx_r = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":778 + * new_ndim += 1 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * return memoryview_fromslice(dst, new_ndim, + * memviewsliceobj.to_object_func, + */ + } + + /* "View.MemoryView":784 + * memview.dtype_is_object) + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * + */ + /*else*/ { + __Pyx_XDECREF((PyObject *)__pyx_r); + + /* "View.MemoryView":785 + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, + * memview.dtype_is_object) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = __pyx_memoryview_fromslice(__pyx_v_dst, __pyx_v_new_ndim, NULL, NULL, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 784, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + + /* "View.MemoryView":784 + * memview.dtype_is_object) + * else: + * return memoryview_fromslice(dst, new_ndim, NULL, NULL, # <<<<<<<<<<<<<< + * memview.dtype_is_object) + * + */ + if (!(likely(((__pyx_t_2) == Py_None) || likely(__Pyx_TypeTest(__pyx_t_2, __pyx_memoryview_type))))) __PYX_ERR(1, 784, __pyx_L1_error) + __pyx_r = ((struct __pyx_memoryview_obj *)__pyx_t_2); + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":711 + * + * @cname('__pyx_memview_slice') + * cdef memoryview memview_slice(memoryview memview, object indices): # <<<<<<<<<<<<<< + * cdef int new_ndim = 0, suboffset_dim = -1, dim + * cdef bint negative_step + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("View.MemoryView.memview_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_memviewsliceobj); + __Pyx_XDECREF(__pyx_v_index); + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":793 + * + * @cname('__pyx_memoryview_slice_memviewslice') + * cdef int slice_memviewslice( # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * Py_ssize_t shape, Py_ssize_t stride, Py_ssize_t suboffset, + */ + +static int __pyx_memoryview_slice_memviewslice(__Pyx_memviewslice *__pyx_v_dst, Py_ssize_t __pyx_v_shape, Py_ssize_t __pyx_v_stride, Py_ssize_t __pyx_v_suboffset, int __pyx_v_dim, int __pyx_v_new_ndim, int *__pyx_v_suboffset_dim, Py_ssize_t __pyx_v_start, Py_ssize_t __pyx_v_stop, Py_ssize_t __pyx_v_step, int __pyx_v_have_start, int __pyx_v_have_stop, int __pyx_v_have_step, int __pyx_v_is_slice) { + Py_ssize_t __pyx_v_new_shape; + int __pyx_v_negative_step; + int __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":813 + * cdef bint negative_step + * + * if not is_slice: # <<<<<<<<<<<<<< + * + * if start < 0: + */ + __pyx_t_1 = (!__pyx_v_is_slice); + if (__pyx_t_1) { + + /* "View.MemoryView":815 + * if not is_slice: + * + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if not 0 <= start < shape: + */ + __pyx_t_1 = (__pyx_v_start < 0); + if (__pyx_t_1) { + + /* "View.MemoryView":816 + * + * if start < 0: + * start += shape # <<<<<<<<<<<<<< + * if not 0 <= start < shape: + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + */ + __pyx_v_start = (__pyx_v_start + __pyx_v_shape); + + /* "View.MemoryView":815 + * if not is_slice: + * + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if not 0 <= start < shape: + */ + } + + /* "View.MemoryView":817 + * if start < 0: + * start += shape + * if not 0 <= start < shape: # <<<<<<<<<<<<<< + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + * else: + */ + __pyx_t_1 = (0 <= __pyx_v_start); + if (__pyx_t_1) { + __pyx_t_1 = (__pyx_v_start < __pyx_v_shape); + } + __pyx_t_2 = (!__pyx_t_1); + if (__pyx_t_2) { + + /* "View.MemoryView":818 + * start += shape + * if not 0 <= start < shape: + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) # <<<<<<<<<<<<<< + * else: + * + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_IndexError, __pyx_kp_s_Index_out_of_bounds_axis_d, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 818, __pyx_L1_error) + + /* "View.MemoryView":817 + * if start < 0: + * start += shape + * if not 0 <= start < shape: # <<<<<<<<<<<<<< + * _err_dim(PyExc_IndexError, "Index out of bounds (axis %d)", dim) + * else: + */ + } + + /* "View.MemoryView":813 + * cdef bint negative_step + * + * if not is_slice: # <<<<<<<<<<<<<< + * + * if start < 0: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":821 + * else: + * + * if have_step: # <<<<<<<<<<<<<< + * negative_step = step < 0 + * if step == 0: + */ + /*else*/ { + __pyx_t_2 = (__pyx_v_have_step != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":822 + * + * if have_step: + * negative_step = step < 0 # <<<<<<<<<<<<<< + * if step == 0: + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + */ + __pyx_v_negative_step = (__pyx_v_step < 0); + + /* "View.MemoryView":823 + * if have_step: + * negative_step = step < 0 + * if step == 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + */ + __pyx_t_2 = (__pyx_v_step == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":824 + * negative_step = step < 0 + * if step == 0: + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) # <<<<<<<<<<<<<< + * else: + * negative_step = False + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_ValueError, __pyx_kp_s_Step_may_not_be_zero_axis_d, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 824, __pyx_L1_error) + + /* "View.MemoryView":823 + * if have_step: + * negative_step = step < 0 + * if step == 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + */ + } + + /* "View.MemoryView":821 + * else: + * + * if have_step: # <<<<<<<<<<<<<< + * negative_step = step < 0 + * if step == 0: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":826 + * _err_dim(PyExc_ValueError, "Step may not be zero (axis %d)", dim) + * else: + * negative_step = False # <<<<<<<<<<<<<< + * step = 1 + * + */ + /*else*/ { + __pyx_v_negative_step = 0; + + /* "View.MemoryView":827 + * else: + * negative_step = False + * step = 1 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_step = 1; + } + __pyx_L6:; + + /* "View.MemoryView":830 + * + * + * if have_start: # <<<<<<<<<<<<<< + * if start < 0: + * start += shape + */ + __pyx_t_2 = (__pyx_v_have_start != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":831 + * + * if have_start: + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if start < 0: + */ + __pyx_t_2 = (__pyx_v_start < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":832 + * if have_start: + * if start < 0: + * start += shape # <<<<<<<<<<<<<< + * if start < 0: + * start = 0 + */ + __pyx_v_start = (__pyx_v_start + __pyx_v_shape); + + /* "View.MemoryView":833 + * if start < 0: + * start += shape + * if start < 0: # <<<<<<<<<<<<<< + * start = 0 + * elif start >= shape: + */ + __pyx_t_2 = (__pyx_v_start < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":834 + * start += shape + * if start < 0: + * start = 0 # <<<<<<<<<<<<<< + * elif start >= shape: + * if negative_step: + */ + __pyx_v_start = 0; + + /* "View.MemoryView":833 + * if start < 0: + * start += shape + * if start < 0: # <<<<<<<<<<<<<< + * start = 0 + * elif start >= shape: + */ + } + + /* "View.MemoryView":831 + * + * if have_start: + * if start < 0: # <<<<<<<<<<<<<< + * start += shape + * if start < 0: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":835 + * if start < 0: + * start = 0 + * elif start >= shape: # <<<<<<<<<<<<<< + * if negative_step: + * start = shape - 1 + */ + __pyx_t_2 = (__pyx_v_start >= __pyx_v_shape); + if (__pyx_t_2) { + + /* "View.MemoryView":836 + * start = 0 + * elif start >= shape: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + if (__pyx_v_negative_step) { + + /* "View.MemoryView":837 + * elif start >= shape: + * if negative_step: + * start = shape - 1 # <<<<<<<<<<<<<< + * else: + * start = shape + */ + __pyx_v_start = (__pyx_v_shape - 1); + + /* "View.MemoryView":836 + * start = 0 + * elif start >= shape: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + goto __pyx_L11; + } + + /* "View.MemoryView":839 + * start = shape - 1 + * else: + * start = shape # <<<<<<<<<<<<<< + * else: + * if negative_step: + */ + /*else*/ { + __pyx_v_start = __pyx_v_shape; + } + __pyx_L11:; + + /* "View.MemoryView":835 + * if start < 0: + * start = 0 + * elif start >= shape: # <<<<<<<<<<<<<< + * if negative_step: + * start = shape - 1 + */ + } + __pyx_L9:; + + /* "View.MemoryView":830 + * + * + * if have_start: # <<<<<<<<<<<<<< + * if start < 0: + * start += shape + */ + goto __pyx_L8; + } + + /* "View.MemoryView":841 + * start = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + /*else*/ { + if (__pyx_v_negative_step) { + + /* "View.MemoryView":842 + * else: + * if negative_step: + * start = shape - 1 # <<<<<<<<<<<<<< + * else: + * start = 0 + */ + __pyx_v_start = (__pyx_v_shape - 1); + + /* "View.MemoryView":841 + * start = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * start = shape - 1 + * else: + */ + goto __pyx_L12; + } + + /* "View.MemoryView":844 + * start = shape - 1 + * else: + * start = 0 # <<<<<<<<<<<<<< + * + * if have_stop: + */ + /*else*/ { + __pyx_v_start = 0; + } + __pyx_L12:; + } + __pyx_L8:; + + /* "View.MemoryView":846 + * start = 0 + * + * if have_stop: # <<<<<<<<<<<<<< + * if stop < 0: + * stop += shape + */ + __pyx_t_2 = (__pyx_v_have_stop != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":847 + * + * if have_stop: + * if stop < 0: # <<<<<<<<<<<<<< + * stop += shape + * if stop < 0: + */ + __pyx_t_2 = (__pyx_v_stop < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":848 + * if have_stop: + * if stop < 0: + * stop += shape # <<<<<<<<<<<<<< + * if stop < 0: + * stop = 0 + */ + __pyx_v_stop = (__pyx_v_stop + __pyx_v_shape); + + /* "View.MemoryView":849 + * if stop < 0: + * stop += shape + * if stop < 0: # <<<<<<<<<<<<<< + * stop = 0 + * elif stop > shape: + */ + __pyx_t_2 = (__pyx_v_stop < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":850 + * stop += shape + * if stop < 0: + * stop = 0 # <<<<<<<<<<<<<< + * elif stop > shape: + * stop = shape + */ + __pyx_v_stop = 0; + + /* "View.MemoryView":849 + * if stop < 0: + * stop += shape + * if stop < 0: # <<<<<<<<<<<<<< + * stop = 0 + * elif stop > shape: + */ + } + + /* "View.MemoryView":847 + * + * if have_stop: + * if stop < 0: # <<<<<<<<<<<<<< + * stop += shape + * if stop < 0: + */ + goto __pyx_L14; + } + + /* "View.MemoryView":851 + * if stop < 0: + * stop = 0 + * elif stop > shape: # <<<<<<<<<<<<<< + * stop = shape + * else: + */ + __pyx_t_2 = (__pyx_v_stop > __pyx_v_shape); + if (__pyx_t_2) { + + /* "View.MemoryView":852 + * stop = 0 + * elif stop > shape: + * stop = shape # <<<<<<<<<<<<<< + * else: + * if negative_step: + */ + __pyx_v_stop = __pyx_v_shape; + + /* "View.MemoryView":851 + * if stop < 0: + * stop = 0 + * elif stop > shape: # <<<<<<<<<<<<<< + * stop = shape + * else: + */ + } + __pyx_L14:; + + /* "View.MemoryView":846 + * start = 0 + * + * if have_stop: # <<<<<<<<<<<<<< + * if stop < 0: + * stop += shape + */ + goto __pyx_L13; + } + + /* "View.MemoryView":854 + * stop = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * stop = -1 + * else: + */ + /*else*/ { + if (__pyx_v_negative_step) { + + /* "View.MemoryView":855 + * else: + * if negative_step: + * stop = -1 # <<<<<<<<<<<<<< + * else: + * stop = shape + */ + __pyx_v_stop = -1L; + + /* "View.MemoryView":854 + * stop = shape + * else: + * if negative_step: # <<<<<<<<<<<<<< + * stop = -1 + * else: + */ + goto __pyx_L16; + } + + /* "View.MemoryView":857 + * stop = -1 + * else: + * stop = shape # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __pyx_v_stop = __pyx_v_shape; + } + __pyx_L16:; + } + __pyx_L13:; + + /* "View.MemoryView":861 + * + * with cython.cdivision(True): + * new_shape = (stop - start) // step # <<<<<<<<<<<<<< + * + * if (stop - start) - step * new_shape: + */ + __pyx_v_new_shape = ((__pyx_v_stop - __pyx_v_start) / __pyx_v_step); + + /* "View.MemoryView":863 + * new_shape = (stop - start) // step + * + * if (stop - start) - step * new_shape: # <<<<<<<<<<<<<< + * new_shape += 1 + * + */ + __pyx_t_2 = (((__pyx_v_stop - __pyx_v_start) - (__pyx_v_step * __pyx_v_new_shape)) != 0); + if (__pyx_t_2) { + + /* "View.MemoryView":864 + * + * if (stop - start) - step * new_shape: + * new_shape += 1 # <<<<<<<<<<<<<< + * + * if new_shape < 0: + */ + __pyx_v_new_shape = (__pyx_v_new_shape + 1); + + /* "View.MemoryView":863 + * new_shape = (stop - start) // step + * + * if (stop - start) - step * new_shape: # <<<<<<<<<<<<<< + * new_shape += 1 + * + */ + } + + /* "View.MemoryView":866 + * new_shape += 1 + * + * if new_shape < 0: # <<<<<<<<<<<<<< + * new_shape = 0 + * + */ + __pyx_t_2 = (__pyx_v_new_shape < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":867 + * + * if new_shape < 0: + * new_shape = 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_new_shape = 0; + + /* "View.MemoryView":866 + * new_shape += 1 + * + * if new_shape < 0: # <<<<<<<<<<<<<< + * new_shape = 0 + * + */ + } + + /* "View.MemoryView":870 + * + * + * dst.strides[new_ndim] = stride * step # <<<<<<<<<<<<<< + * dst.shape[new_ndim] = new_shape + * dst.suboffsets[new_ndim] = suboffset + */ + (__pyx_v_dst->strides[__pyx_v_new_ndim]) = (__pyx_v_stride * __pyx_v_step); + + /* "View.MemoryView":871 + * + * dst.strides[new_ndim] = stride * step + * dst.shape[new_ndim] = new_shape # <<<<<<<<<<<<<< + * dst.suboffsets[new_ndim] = suboffset + * + */ + (__pyx_v_dst->shape[__pyx_v_new_ndim]) = __pyx_v_new_shape; + + /* "View.MemoryView":872 + * dst.strides[new_ndim] = stride * step + * dst.shape[new_ndim] = new_shape + * dst.suboffsets[new_ndim] = suboffset # <<<<<<<<<<<<<< + * + * + */ + (__pyx_v_dst->suboffsets[__pyx_v_new_ndim]) = __pyx_v_suboffset; + } + __pyx_L3:; + + /* "View.MemoryView":875 + * + * + * if suboffset_dim[0] < 0: # <<<<<<<<<<<<<< + * dst.data += start * stride + * else: + */ + __pyx_t_2 = ((__pyx_v_suboffset_dim[0]) < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":876 + * + * if suboffset_dim[0] < 0: + * dst.data += start * stride # <<<<<<<<<<<<<< + * else: + * dst.suboffsets[suboffset_dim[0]] += start * stride + */ + __pyx_v_dst->data = (__pyx_v_dst->data + (__pyx_v_start * __pyx_v_stride)); + + /* "View.MemoryView":875 + * + * + * if suboffset_dim[0] < 0: # <<<<<<<<<<<<<< + * dst.data += start * stride + * else: + */ + goto __pyx_L19; + } + + /* "View.MemoryView":878 + * dst.data += start * stride + * else: + * dst.suboffsets[suboffset_dim[0]] += start * stride # <<<<<<<<<<<<<< + * + * if suboffset >= 0: + */ + /*else*/ { + __pyx_t_3 = (__pyx_v_suboffset_dim[0]); + (__pyx_v_dst->suboffsets[__pyx_t_3]) = ((__pyx_v_dst->suboffsets[__pyx_t_3]) + (__pyx_v_start * __pyx_v_stride)); + } + __pyx_L19:; + + /* "View.MemoryView":880 + * dst.suboffsets[suboffset_dim[0]] += start * stride + * + * if suboffset >= 0: # <<<<<<<<<<<<<< + * if not is_slice: + * if new_ndim == 0: + */ + __pyx_t_2 = (__pyx_v_suboffset >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":881 + * + * if suboffset >= 0: + * if not is_slice: # <<<<<<<<<<<<<< + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset + */ + __pyx_t_2 = (!__pyx_v_is_slice); + if (__pyx_t_2) { + + /* "View.MemoryView":882 + * if suboffset >= 0: + * if not is_slice: + * if new_ndim == 0: # <<<<<<<<<<<<<< + * dst.data = ( dst.data)[0] + suboffset + * else: + */ + __pyx_t_2 = (__pyx_v_new_ndim == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":883 + * if not is_slice: + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset # <<<<<<<<<<<<<< + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " + */ + __pyx_v_dst->data = ((((char **)__pyx_v_dst->data)[0]) + __pyx_v_suboffset); + + /* "View.MemoryView":882 + * if suboffset >= 0: + * if not is_slice: + * if new_ndim == 0: # <<<<<<<<<<<<<< + * dst.data = ( dst.data)[0] + suboffset + * else: + */ + goto __pyx_L22; + } + + /* "View.MemoryView":885 + * dst.data = ( dst.data)[0] + suboffset + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " # <<<<<<<<<<<<<< + * "must be indexed and not sliced", dim) + * else: + */ + /*else*/ { + + /* "View.MemoryView":886 + * else: + * _err_dim(PyExc_IndexError, "All dimensions preceding dimension %d " + * "must be indexed and not sliced", dim) # <<<<<<<<<<<<<< + * else: + * suboffset_dim[0] = new_ndim + */ + __pyx_t_3 = __pyx_memoryview_err_dim(PyExc_IndexError, __pyx_kp_s_All_dimensions_preceding_dimensi, __pyx_v_dim); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 885, __pyx_L1_error) + } + __pyx_L22:; + + /* "View.MemoryView":881 + * + * if suboffset >= 0: + * if not is_slice: # <<<<<<<<<<<<<< + * if new_ndim == 0: + * dst.data = ( dst.data)[0] + suboffset + */ + goto __pyx_L21; + } + + /* "View.MemoryView":888 + * "must be indexed and not sliced", dim) + * else: + * suboffset_dim[0] = new_ndim # <<<<<<<<<<<<<< + * + * return 0 + */ + /*else*/ { + (__pyx_v_suboffset_dim[0]) = __pyx_v_new_ndim; + } + __pyx_L21:; + + /* "View.MemoryView":880 + * dst.suboffsets[suboffset_dim[0]] += start * stride + * + * if suboffset >= 0: # <<<<<<<<<<<<<< + * if not is_slice: + * if new_ndim == 0: + */ + } + + /* "View.MemoryView":890 + * suboffset_dim[0] = new_ndim + * + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":793 + * + * @cname('__pyx_memoryview_slice_memviewslice') + * cdef int slice_memviewslice( # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * Py_ssize_t shape, Py_ssize_t stride, Py_ssize_t suboffset, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.slice_memviewslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":896 + * + * @cname('__pyx_pybuffer_index') + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, # <<<<<<<<<<<<<< + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + */ + +static char *__pyx_pybuffer_index(Py_buffer *__pyx_v_view, char *__pyx_v_bufp, Py_ssize_t __pyx_v_index, Py_ssize_t __pyx_v_dim) { + Py_ssize_t __pyx_v_shape; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_v_suboffset; + Py_ssize_t __pyx_v_itemsize; + char *__pyx_v_resultp; + char *__pyx_r; + __Pyx_RefNannyDeclarations + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + Py_UCS4 __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("pybuffer_index", 1); + + /* "View.MemoryView":898 + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 # <<<<<<<<<<<<<< + * cdef Py_ssize_t itemsize = view.itemsize + * cdef char *resultp + */ + __pyx_v_suboffset = -1L; + + /* "View.MemoryView":899 + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + * cdef Py_ssize_t itemsize = view.itemsize # <<<<<<<<<<<<<< + * cdef char *resultp + * + */ + __pyx_t_1 = __pyx_v_view->itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":902 + * cdef char *resultp + * + * if view.ndim == 0: # <<<<<<<<<<<<<< + * shape = view.len // itemsize + * stride = itemsize + */ + __pyx_t_2 = (__pyx_v_view->ndim == 0); + if (__pyx_t_2) { + + /* "View.MemoryView":903 + * + * if view.ndim == 0: + * shape = view.len // itemsize # <<<<<<<<<<<<<< + * stride = itemsize + * else: + */ + if (unlikely(__pyx_v_itemsize == 0)) { + PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero"); + __PYX_ERR(1, 903, __pyx_L1_error) + } + else if (sizeof(Py_ssize_t) == sizeof(long) && (!(((Py_ssize_t)-1) > 0)) && unlikely(__pyx_v_itemsize == (Py_ssize_t)-1) && unlikely(__Pyx_UNARY_NEG_WOULD_OVERFLOW(__pyx_v_view->len))) { + PyErr_SetString(PyExc_OverflowError, "value too large to perform division"); + __PYX_ERR(1, 903, __pyx_L1_error) + } + __pyx_v_shape = __Pyx_div_Py_ssize_t(__pyx_v_view->len, __pyx_v_itemsize); + + /* "View.MemoryView":904 + * if view.ndim == 0: + * shape = view.len // itemsize + * stride = itemsize # <<<<<<<<<<<<<< + * else: + * shape = view.shape[dim] + */ + __pyx_v_stride = __pyx_v_itemsize; + + /* "View.MemoryView":902 + * cdef char *resultp + * + * if view.ndim == 0: # <<<<<<<<<<<<<< + * shape = view.len // itemsize + * stride = itemsize + */ + goto __pyx_L3; + } + + /* "View.MemoryView":906 + * stride = itemsize + * else: + * shape = view.shape[dim] # <<<<<<<<<<<<<< + * stride = view.strides[dim] + * if view.suboffsets != NULL: + */ + /*else*/ { + __pyx_v_shape = (__pyx_v_view->shape[__pyx_v_dim]); + + /* "View.MemoryView":907 + * else: + * shape = view.shape[dim] + * stride = view.strides[dim] # <<<<<<<<<<<<<< + * if view.suboffsets != NULL: + * suboffset = view.suboffsets[dim] + */ + __pyx_v_stride = (__pyx_v_view->strides[__pyx_v_dim]); + + /* "View.MemoryView":908 + * shape = view.shape[dim] + * stride = view.strides[dim] + * if view.suboffsets != NULL: # <<<<<<<<<<<<<< + * suboffset = view.suboffsets[dim] + * + */ + __pyx_t_2 = (__pyx_v_view->suboffsets != NULL); + if (__pyx_t_2) { + + /* "View.MemoryView":909 + * stride = view.strides[dim] + * if view.suboffsets != NULL: + * suboffset = view.suboffsets[dim] # <<<<<<<<<<<<<< + * + * if index < 0: + */ + __pyx_v_suboffset = (__pyx_v_view->suboffsets[__pyx_v_dim]); + + /* "View.MemoryView":908 + * shape = view.shape[dim] + * stride = view.strides[dim] + * if view.suboffsets != NULL: # <<<<<<<<<<<<<< + * suboffset = view.suboffsets[dim] + * + */ + } + } + __pyx_L3:; + + /* "View.MemoryView":911 + * suboffset = view.suboffsets[dim] + * + * if index < 0: # <<<<<<<<<<<<<< + * index += view.shape[dim] + * if index < 0: + */ + __pyx_t_2 = (__pyx_v_index < 0); + if (__pyx_t_2) { + + /* "View.MemoryView":912 + * + * if index < 0: + * index += view.shape[dim] # <<<<<<<<<<<<<< + * if index < 0: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + */ + __pyx_v_index = (__pyx_v_index + (__pyx_v_view->shape[__pyx_v_dim])); + + /* "View.MemoryView":913 + * if index < 0: + * index += view.shape[dim] + * if index < 0: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + __pyx_t_2 = (__pyx_v_index < 0); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":914 + * index += view.shape[dim] + * if index < 0: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" # <<<<<<<<<<<<<< + * + * if index >= shape: + */ + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = 0; + __pyx_t_4 = 127; + __Pyx_INCREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_1 += 37; + __Pyx_GIVEREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + PyTuple_SET_ITEM(__pyx_t_3, 0, __pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_5 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_5); + __Pyx_GIVEREF(__pyx_t_5); + PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_t_5); + __pyx_t_5 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_1 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_kp_u__7); + __pyx_t_5 = __Pyx_PyUnicode_Join(__pyx_t_3, 3, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 914, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_builtin_IndexError, __pyx_t_5, 0, 0); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __PYX_ERR(1, 914, __pyx_L1_error) + + /* "View.MemoryView":913 + * if index < 0: + * index += view.shape[dim] + * if index < 0: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + } + + /* "View.MemoryView":911 + * suboffset = view.suboffsets[dim] + * + * if index < 0: # <<<<<<<<<<<<<< + * index += view.shape[dim] + * if index < 0: + */ + } + + /* "View.MemoryView":916 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * if index >= shape: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + __pyx_t_2 = (__pyx_v_index >= __pyx_v_shape); + if (unlikely(__pyx_t_2)) { + + /* "View.MemoryView":917 + * + * if index >= shape: + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" # <<<<<<<<<<<<<< + * + * resultp = bufp + index * stride + */ + __pyx_t_5 = PyTuple_New(3); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_1 = 0; + __pyx_t_4 = 127; + __Pyx_INCREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_1 += 37; + __Pyx_GIVEREF(__pyx_kp_u_Out_of_bounds_on_buffer_access_a); + PyTuple_SET_ITEM(__pyx_t_5, 0, __pyx_kp_u_Out_of_bounds_on_buffer_access_a); + __pyx_t_3 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_dim, 0, ' ', 'd'); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_3); + __Pyx_GIVEREF(__pyx_t_3); + PyTuple_SET_ITEM(__pyx_t_5, 1, __pyx_t_3); + __pyx_t_3 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_1 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_5, 2, __pyx_kp_u__7); + __pyx_t_3 = __Pyx_PyUnicode_Join(__pyx_t_5, 3, __pyx_t_1, __pyx_t_4); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 917, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_Raise(__pyx_builtin_IndexError, __pyx_t_3, 0, 0); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __PYX_ERR(1, 917, __pyx_L1_error) + + /* "View.MemoryView":916 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * if index >= shape: # <<<<<<<<<<<<<< + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + */ + } + + /* "View.MemoryView":919 + * raise IndexError, f"Out of bounds on buffer access (axis {dim})" + * + * resultp = bufp + index * stride # <<<<<<<<<<<<<< + * if suboffset >= 0: + * resultp = ( resultp)[0] + suboffset + */ + __pyx_v_resultp = (__pyx_v_bufp + (__pyx_v_index * __pyx_v_stride)); + + /* "View.MemoryView":920 + * + * resultp = bufp + index * stride + * if suboffset >= 0: # <<<<<<<<<<<<<< + * resultp = ( resultp)[0] + suboffset + * + */ + __pyx_t_2 = (__pyx_v_suboffset >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":921 + * resultp = bufp + index * stride + * if suboffset >= 0: + * resultp = ( resultp)[0] + suboffset # <<<<<<<<<<<<<< + * + * return resultp + */ + __pyx_v_resultp = ((((char **)__pyx_v_resultp)[0]) + __pyx_v_suboffset); + + /* "View.MemoryView":920 + * + * resultp = bufp + index * stride + * if suboffset >= 0: # <<<<<<<<<<<<<< + * resultp = ( resultp)[0] + suboffset + * + */ + } + + /* "View.MemoryView":923 + * resultp = ( resultp)[0] + suboffset + * + * return resultp # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_resultp; + goto __pyx_L0; + + /* "View.MemoryView":896 + * + * @cname('__pyx_pybuffer_index') + * cdef char *pybuffer_index(Py_buffer *view, char *bufp, Py_ssize_t index, # <<<<<<<<<<<<<< + * Py_ssize_t dim) except NULL: + * cdef Py_ssize_t shape, stride, suboffset = -1 + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_AddTraceback("View.MemoryView.pybuffer_index", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":929 + * + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: # <<<<<<<<<<<<<< + * cdef int ndim = memslice.memview.view.ndim + * + */ + +static int __pyx_memslice_transpose(__Pyx_memviewslice *__pyx_v_memslice) { + int __pyx_v_ndim; + Py_ssize_t *__pyx_v_shape; + Py_ssize_t *__pyx_v_strides; + int __pyx_v_i; + int __pyx_v_j; + int __pyx_r; + int __pyx_t_1; + Py_ssize_t *__pyx_t_2; + long __pyx_t_3; + long __pyx_t_4; + Py_ssize_t __pyx_t_5; + Py_ssize_t __pyx_t_6; + int __pyx_t_7; + int __pyx_t_8; + int __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":930 + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: + * cdef int ndim = memslice.memview.view.ndim # <<<<<<<<<<<<<< + * + * cdef Py_ssize_t *shape = memslice.shape + */ + __pyx_t_1 = __pyx_v_memslice->memview->view.ndim; + __pyx_v_ndim = __pyx_t_1; + + /* "View.MemoryView":932 + * cdef int ndim = memslice.memview.view.ndim + * + * cdef Py_ssize_t *shape = memslice.shape # <<<<<<<<<<<<<< + * cdef Py_ssize_t *strides = memslice.strides + * + */ + __pyx_t_2 = __pyx_v_memslice->shape; + __pyx_v_shape = __pyx_t_2; + + /* "View.MemoryView":933 + * + * cdef Py_ssize_t *shape = memslice.shape + * cdef Py_ssize_t *strides = memslice.strides # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = __pyx_v_memslice->strides; + __pyx_v_strides = __pyx_t_2; + + /* "View.MemoryView":937 + * + * cdef int i, j + * for i in range(ndim // 2): # <<<<<<<<<<<<<< + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] + */ + __pyx_t_3 = __Pyx_div_long(__pyx_v_ndim, 2); + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_1 = 0; __pyx_t_1 < __pyx_t_4; __pyx_t_1+=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":938 + * cdef int i, j + * for i in range(ndim // 2): + * j = ndim - 1 - i # <<<<<<<<<<<<<< + * strides[i], strides[j] = strides[j], strides[i] + * shape[i], shape[j] = shape[j], shape[i] + */ + __pyx_v_j = ((__pyx_v_ndim - 1) - __pyx_v_i); + + /* "View.MemoryView":939 + * for i in range(ndim // 2): + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] # <<<<<<<<<<<<<< + * shape[i], shape[j] = shape[j], shape[i] + * + */ + __pyx_t_5 = (__pyx_v_strides[__pyx_v_j]); + __pyx_t_6 = (__pyx_v_strides[__pyx_v_i]); + (__pyx_v_strides[__pyx_v_i]) = __pyx_t_5; + (__pyx_v_strides[__pyx_v_j]) = __pyx_t_6; + + /* "View.MemoryView":940 + * j = ndim - 1 - i + * strides[i], strides[j] = strides[j], strides[i] + * shape[i], shape[j] = shape[j], shape[i] # <<<<<<<<<<<<<< + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: + */ + __pyx_t_6 = (__pyx_v_shape[__pyx_v_j]); + __pyx_t_5 = (__pyx_v_shape[__pyx_v_i]); + (__pyx_v_shape[__pyx_v_i]) = __pyx_t_6; + (__pyx_v_shape[__pyx_v_j]) = __pyx_t_5; + + /* "View.MemoryView":942 + * shape[i], shape[j] = shape[j], shape[i] + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: # <<<<<<<<<<<<<< + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + */ + __pyx_t_8 = ((__pyx_v_memslice->suboffsets[__pyx_v_i]) >= 0); + if (!__pyx_t_8) { + } else { + __pyx_t_7 = __pyx_t_8; + goto __pyx_L6_bool_binop_done; + } + __pyx_t_8 = ((__pyx_v_memslice->suboffsets[__pyx_v_j]) >= 0); + __pyx_t_7 = __pyx_t_8; + __pyx_L6_bool_binop_done:; + if (__pyx_t_7) { + + /* "View.MemoryView":943 + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") # <<<<<<<<<<<<<< + * + * return 0 + */ + __pyx_t_9 = __pyx_memoryview_err(PyExc_ValueError, __pyx_kp_s_Cannot_transpose_memoryview_with); if (unlikely(__pyx_t_9 == ((int)-1))) __PYX_ERR(1, 943, __pyx_L1_error) + + /* "View.MemoryView":942 + * shape[i], shape[j] = shape[j], shape[i] + * + * if memslice.suboffsets[i] >= 0 or memslice.suboffsets[j] >= 0: # <<<<<<<<<<<<<< + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + */ + } + } + + /* "View.MemoryView":945 + * _err(PyExc_ValueError, "Cannot transpose memoryview with indirect dimensions") + * + * return 0 # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":929 + * + * @cname('__pyx_memslice_transpose') + * cdef int transpose_memslice(__Pyx_memviewslice *memslice) except -1 nogil: # <<<<<<<<<<<<<< + * cdef int ndim = memslice.memview.view.ndim + * + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.transpose_memslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":963 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * def __dealloc__(self): # <<<<<<<<<<<<<< + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + */ + +/* Python wrapper */ +static void __pyx_memoryviewslice___dealloc__(PyObject *__pyx_v_self); /*proto*/ +static void __pyx_memoryviewslice___dealloc__(PyObject *__pyx_v_self) { + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__dealloc__ (wrapper)", 0); + __pyx_kwvalues = __Pyx_KwValues_VARARGS(__pyx_args, __pyx_nargs); + __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); +} + +static void __pyx_memoryviewslice___pyx_pf_15View_dot_MemoryView_16_memoryviewslice___dealloc__(struct __pyx_memoryviewslice_obj *__pyx_v_self) { + + /* "View.MemoryView":964 + * + * def __dealloc__(self): + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) # <<<<<<<<<<<<<< + * + * cdef convert_item_to_object(self, char *itemp): + */ + __PYX_XCLEAR_MEMVIEW((&__pyx_v_self->from_slice), 1); + + /* "View.MemoryView":963 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * def __dealloc__(self): # <<<<<<<<<<<<<< + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + */ + + /* function exit code */ +} + +/* "View.MemoryView":966 + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) + */ + +static PyObject *__pyx_memoryviewslice_convert_item_to_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("convert_item_to_object", 1); + + /* "View.MemoryView":967 + * + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: # <<<<<<<<<<<<<< + * return self.to_object_func(itemp) + * else: + */ + __pyx_t_1 = (__pyx_v_self->to_object_func != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":968 + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) # <<<<<<<<<<<<<< + * else: + * return memoryview.convert_item_to_object(self, itemp) + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_v_self->to_object_func(__pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 968, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* "View.MemoryView":967 + * + * cdef convert_item_to_object(self, char *itemp): + * if self.to_object_func != NULL: # <<<<<<<<<<<<<< + * return self.to_object_func(itemp) + * else: + */ + } + + /* "View.MemoryView":970 + * return self.to_object_func(itemp) + * else: + * return memoryview.convert_item_to_object(self, itemp) # <<<<<<<<<<<<<< + * + * cdef assign_item_from_object(self, char *itemp, object value): + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_memoryview_convert_item_to_object(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_itemp); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 970, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + } + + /* "View.MemoryView":966 + * __PYX_XCLEAR_MEMVIEW(&self.from_slice, 1) + * + * cdef convert_item_to_object(self, char *itemp): # <<<<<<<<<<<<<< + * if self.to_object_func != NULL: + * return self.to_object_func(itemp) + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.convert_item_to_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":972 + * return memoryview.convert_item_to_object(self, itemp) + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) + */ + +static PyObject *__pyx_memoryviewslice_assign_item_from_object(struct __pyx_memoryviewslice_obj *__pyx_v_self, char *__pyx_v_itemp, PyObject *__pyx_v_value) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("assign_item_from_object", 1); + + /* "View.MemoryView":973 + * + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: # <<<<<<<<<<<<<< + * self.to_dtype_func(itemp, value) + * else: + */ + __pyx_t_1 = (__pyx_v_self->to_dtype_func != NULL); + if (__pyx_t_1) { + + /* "View.MemoryView":974 + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) # <<<<<<<<<<<<<< + * else: + * memoryview.assign_item_from_object(self, itemp, value) + */ + __pyx_t_2 = __pyx_v_self->to_dtype_func(__pyx_v_itemp, __pyx_v_value); if (unlikely(__pyx_t_2 == ((int)0))) __PYX_ERR(1, 974, __pyx_L1_error) + + /* "View.MemoryView":973 + * + * cdef assign_item_from_object(self, char *itemp, object value): + * if self.to_dtype_func != NULL: # <<<<<<<<<<<<<< + * self.to_dtype_func(itemp, value) + * else: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":976 + * self.to_dtype_func(itemp, value) + * else: + * memoryview.assign_item_from_object(self, itemp, value) # <<<<<<<<<<<<<< + * + * cdef _get_base(self): + */ + /*else*/ { + __pyx_t_3 = __pyx_memoryview_assign_item_from_object(((struct __pyx_memoryview_obj *)__pyx_v_self), __pyx_v_itemp, __pyx_v_value); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 976, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_L3:; + + /* "View.MemoryView":972 + * return memoryview.convert_item_to_object(self, itemp) + * + * cdef assign_item_from_object(self, char *itemp, object value): # <<<<<<<<<<<<<< + * if self.to_dtype_func != NULL: + * self.to_dtype_func(itemp, value) + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.assign_item_from_object", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":978 + * memoryview.assign_item_from_object(self, itemp, value) + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.from_object + * + */ + +static PyObject *__pyx_memoryviewslice__get_base(struct __pyx_memoryviewslice_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("_get_base", 1); + + /* "View.MemoryView":979 + * + * cdef _get_base(self): + * return self.from_object # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v_self->from_object); + __pyx_r = __pyx_v_self->from_object; + goto __pyx_L0; + + /* "View.MemoryView":978 + * memoryview.assign_item_from_object(self, itemp, value) + * + * cdef _get_base(self): # <<<<<<<<<<<<<< + * return self.from_object + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryviewslice_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryviewslice_1__reduce_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__reduce_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + if (unlikely(__pyx_nargs > 0)) { + __Pyx_RaiseArgtupleInvalid("__reduce_cython__", 1, 0, 0, __pyx_nargs); return NULL;} + if (unlikely(__pyx_kwds) && __Pyx_NumKwargs_FASTCALL(__pyx_kwds) && unlikely(!__Pyx_CheckKeywordStrings(__pyx_kwds, "__reduce_cython__", 0))) return NULL; + __pyx_r = __pyx_pf___pyx_memoryviewslice___reduce_cython__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self)); + + /* function exit code */ + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryviewslice___reduce_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__reduce_cython__", 1); + + /* "(tree fragment)":2 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 2, __pyx_L1_error) + + /* "(tree fragment)":1 + * def __reduce_cython__(self): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__reduce_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + +/* Python wrapper */ +static PyObject *__pyx_pw___pyx_memoryviewslice_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyObject *__pyx_pw___pyx_memoryviewslice_3__setstate_cython__(PyObject *__pyx_v_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + CYTHON_UNUSED PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[1] = {0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__setstate_cython__ (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 3, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__setstate_cython__") < 0)) __PYX_ERR(1, 3, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 1)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + } + __pyx_v___pyx_state = values[0]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__setstate_cython__", 1, 1, 1, __pyx_nargs); __PYX_ERR(1, 3, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf___pyx_memoryviewslice_2__setstate_cython__(((struct __pyx_memoryviewslice_obj *)__pyx_v_self), __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf___pyx_memoryviewslice_2__setstate_cython__(CYTHON_UNUSED struct __pyx_memoryviewslice_obj *__pyx_v_self, CYTHON_UNUSED PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__setstate_cython__", 1); + + /* "(tree fragment)":4 + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" # <<<<<<<<<<<<<< + */ + __Pyx_Raise(__pyx_builtin_TypeError, __pyx_kp_s_no_default___reduce___due_to_non, 0, 0); + __PYX_ERR(1, 4, __pyx_L1_error) + + /* "(tree fragment)":3 + * def __reduce_cython__(self): + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + * def __setstate_cython__(self, __pyx_state): # <<<<<<<<<<<<<< + * raise TypeError, "no default __reduce__ due to non-trivial __cinit__" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._memoryviewslice.__setstate_cython__", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":999 + * + * @cname('__pyx_memoryview_fromslice') + * cdef memoryview_fromslice(__Pyx_memviewslice memviewslice, # <<<<<<<<<<<<<< + * int ndim, + * object (*to_object_func)(char *), + */ + +static PyObject *__pyx_memoryview_fromslice(__Pyx_memviewslice __pyx_v_memviewslice, int __pyx_v_ndim, PyObject *(*__pyx_v_to_object_func)(char *), int (*__pyx_v_to_dtype_func)(char *, PyObject *), int __pyx_v_dtype_is_object) { + struct __pyx_memoryviewslice_obj *__pyx_v_result = 0; + Py_ssize_t __pyx_v_suboffset; + PyObject *__pyx_v_length = NULL; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + __Pyx_TypeInfo *__pyx_t_4; + Py_buffer __pyx_t_5; + Py_ssize_t *__pyx_t_6; + Py_ssize_t *__pyx_t_7; + Py_ssize_t *__pyx_t_8; + Py_ssize_t __pyx_t_9; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_fromslice", 1); + + /* "View.MemoryView":1007 + * cdef _memoryviewslice result + * + * if memviewslice.memview == Py_None: # <<<<<<<<<<<<<< + * return None + * + */ + __pyx_t_1 = (((PyObject *)__pyx_v_memviewslice.memview) == Py_None); + if (__pyx_t_1) { + + /* "View.MemoryView":1008 + * + * if memviewslice.memview == Py_None: + * return None # <<<<<<<<<<<<<< + * + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + + /* "View.MemoryView":1007 + * cdef _memoryviewslice result + * + * if memviewslice.memview == Py_None: # <<<<<<<<<<<<<< + * return None + * + */ + } + + /* "View.MemoryView":1013 + * + * + * result = _memoryviewslice.__new__(_memoryviewslice, None, 0, dtype_is_object) # <<<<<<<<<<<<<< + * + * result.from_slice = memviewslice + */ + __pyx_t_2 = __Pyx_PyBool_FromLong(__pyx_v_dtype_is_object); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyTuple_New(3); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_INCREF(Py_None); + __Pyx_GIVEREF(Py_None); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 0, Py_None)) __PYX_ERR(1, 1013, __pyx_L1_error); + __Pyx_INCREF(__pyx_int_0); + __Pyx_GIVEREF(__pyx_int_0); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 1, __pyx_int_0)) __PYX_ERR(1, 1013, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_t_2); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_3, 2, __pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error); + __pyx_t_2 = 0; + __pyx_t_2 = ((PyObject *)__pyx_tp_new__memoryviewslice(((PyTypeObject *)__pyx_memoryviewslice_type), __pyx_t_3, NULL)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1013, __pyx_L1_error) + __Pyx_GOTREF((PyObject *)__pyx_t_2); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1015 + * result = _memoryviewslice.__new__(_memoryviewslice, None, 0, dtype_is_object) + * + * result.from_slice = memviewslice # <<<<<<<<<<<<<< + * __PYX_INC_MEMVIEW(&memviewslice, 1) + * + */ + __pyx_v_result->from_slice = __pyx_v_memviewslice; + + /* "View.MemoryView":1016 + * + * result.from_slice = memviewslice + * __PYX_INC_MEMVIEW(&memviewslice, 1) # <<<<<<<<<<<<<< + * + * result.from_object = ( memviewslice.memview)._get_base() + */ + __PYX_INC_MEMVIEW((&__pyx_v_memviewslice), 1); + + /* "View.MemoryView":1018 + * __PYX_INC_MEMVIEW(&memviewslice, 1) + * + * result.from_object = ( memviewslice.memview)._get_base() # <<<<<<<<<<<<<< + * result.typeinfo = memviewslice.memview.typeinfo + * + */ + __pyx_t_2 = ((struct __pyx_vtabstruct_memoryview *)((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)->__pyx_vtab)->_get_base(((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1018, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_GIVEREF(__pyx_t_2); + __Pyx_GOTREF(__pyx_v_result->from_object); + __Pyx_DECREF(__pyx_v_result->from_object); + __pyx_v_result->from_object = __pyx_t_2; + __pyx_t_2 = 0; + + /* "View.MemoryView":1019 + * + * result.from_object = ( memviewslice.memview)._get_base() + * result.typeinfo = memviewslice.memview.typeinfo # <<<<<<<<<<<<<< + * + * result.view = memviewslice.memview.view + */ + __pyx_t_4 = __pyx_v_memviewslice.memview->typeinfo; + __pyx_v_result->__pyx_base.typeinfo = __pyx_t_4; + + /* "View.MemoryView":1021 + * result.typeinfo = memviewslice.memview.typeinfo + * + * result.view = memviewslice.memview.view # <<<<<<<<<<<<<< + * result.view.buf = memviewslice.data + * result.view.ndim = ndim + */ + __pyx_t_5 = __pyx_v_memviewslice.memview->view; + __pyx_v_result->__pyx_base.view = __pyx_t_5; + + /* "View.MemoryView":1022 + * + * result.view = memviewslice.memview.view + * result.view.buf = memviewslice.data # <<<<<<<<<<<<<< + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None + */ + __pyx_v_result->__pyx_base.view.buf = ((void *)__pyx_v_memviewslice.data); + + /* "View.MemoryView":1023 + * result.view = memviewslice.memview.view + * result.view.buf = memviewslice.data + * result.view.ndim = ndim # <<<<<<<<<<<<<< + * (<__pyx_buffer *> &result.view).obj = Py_None + * Py_INCREF(Py_None) + */ + __pyx_v_result->__pyx_base.view.ndim = __pyx_v_ndim; + + /* "View.MemoryView":1024 + * result.view.buf = memviewslice.data + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None # <<<<<<<<<<<<<< + * Py_INCREF(Py_None) + * + */ + ((Py_buffer *)(&__pyx_v_result->__pyx_base.view))->obj = Py_None; + + /* "View.MemoryView":1025 + * result.view.ndim = ndim + * (<__pyx_buffer *> &result.view).obj = Py_None + * Py_INCREF(Py_None) # <<<<<<<<<<<<<< + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: + */ + Py_INCREF(Py_None); + + /* "View.MemoryView":1027 + * Py_INCREF(Py_None) + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: # <<<<<<<<<<<<<< + * result.flags = PyBUF_RECORDS + * else: + */ + __pyx_t_1 = ((((struct __pyx_memoryview_obj *)__pyx_v_memviewslice.memview)->flags & PyBUF_WRITABLE) != 0); + if (__pyx_t_1) { + + /* "View.MemoryView":1028 + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: + * result.flags = PyBUF_RECORDS # <<<<<<<<<<<<<< + * else: + * result.flags = PyBUF_RECORDS_RO + */ + __pyx_v_result->__pyx_base.flags = PyBUF_RECORDS; + + /* "View.MemoryView":1027 + * Py_INCREF(Py_None) + * + * if (memviewslice.memview).flags & PyBUF_WRITABLE: # <<<<<<<<<<<<<< + * result.flags = PyBUF_RECORDS + * else: + */ + goto __pyx_L4; + } + + /* "View.MemoryView":1030 + * result.flags = PyBUF_RECORDS + * else: + * result.flags = PyBUF_RECORDS_RO # <<<<<<<<<<<<<< + * + * result.view.shape = result.from_slice.shape + */ + /*else*/ { + __pyx_v_result->__pyx_base.flags = PyBUF_RECORDS_RO; + } + __pyx_L4:; + + /* "View.MemoryView":1032 + * result.flags = PyBUF_RECORDS_RO + * + * result.view.shape = result.from_slice.shape # <<<<<<<<<<<<<< + * result.view.strides = result.from_slice.strides + * + */ + __pyx_v_result->__pyx_base.view.shape = ((Py_ssize_t *)__pyx_v_result->from_slice.shape); + + /* "View.MemoryView":1033 + * + * result.view.shape = result.from_slice.shape + * result.view.strides = result.from_slice.strides # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_result->__pyx_base.view.strides = ((Py_ssize_t *)__pyx_v_result->from_slice.strides); + + /* "View.MemoryView":1036 + * + * + * result.view.suboffsets = NULL # <<<<<<<<<<<<<< + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: + */ + __pyx_v_result->__pyx_base.view.suboffsets = NULL; + + /* "View.MemoryView":1037 + * + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: # <<<<<<<<<<<<<< + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets + */ + __pyx_t_7 = (__pyx_v_result->from_slice.suboffsets + __pyx_v_ndim); + for (__pyx_t_8 = __pyx_v_result->from_slice.suboffsets; __pyx_t_8 < __pyx_t_7; __pyx_t_8++) { + __pyx_t_6 = __pyx_t_8; + __pyx_v_suboffset = (__pyx_t_6[0]); + + /* "View.MemoryView":1038 + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * result.view.suboffsets = result.from_slice.suboffsets + * break + */ + __pyx_t_1 = (__pyx_v_suboffset >= 0); + if (__pyx_t_1) { + + /* "View.MemoryView":1039 + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_result->__pyx_base.view.suboffsets = ((Py_ssize_t *)__pyx_v_result->from_slice.suboffsets); + + /* "View.MemoryView":1040 + * if suboffset >= 0: + * result.view.suboffsets = result.from_slice.suboffsets + * break # <<<<<<<<<<<<<< + * + * result.view.len = result.view.itemsize + */ + goto __pyx_L6_break; + + /* "View.MemoryView":1038 + * result.view.suboffsets = NULL + * for suboffset in result.from_slice.suboffsets[:ndim]: + * if suboffset >= 0: # <<<<<<<<<<<<<< + * result.view.suboffsets = result.from_slice.suboffsets + * break + */ + } + } + __pyx_L6_break:; + + /* "View.MemoryView":1042 + * break + * + * result.view.len = result.view.itemsize # <<<<<<<<<<<<<< + * for length in result.view.shape[:ndim]: + * result.view.len *= length + */ + __pyx_t_9 = __pyx_v_result->__pyx_base.view.itemsize; + __pyx_v_result->__pyx_base.view.len = __pyx_t_9; + + /* "View.MemoryView":1043 + * + * result.view.len = result.view.itemsize + * for length in result.view.shape[:ndim]: # <<<<<<<<<<<<<< + * result.view.len *= length + * + */ + __pyx_t_7 = (__pyx_v_result->__pyx_base.view.shape + __pyx_v_ndim); + for (__pyx_t_8 = __pyx_v_result->__pyx_base.view.shape; __pyx_t_8 < __pyx_t_7; __pyx_t_8++) { + __pyx_t_6 = __pyx_t_8; + __pyx_t_2 = PyInt_FromSsize_t((__pyx_t_6[0])); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1043, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_XDECREF_SET(__pyx_v_length, __pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1044 + * result.view.len = result.view.itemsize + * for length in result.view.shape[:ndim]: + * result.view.len *= length # <<<<<<<<<<<<<< + * + * result.to_object_func = to_object_func + */ + __pyx_t_2 = PyInt_FromSsize_t(__pyx_v_result->__pyx_base.view.len); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_t_3 = PyNumber_InPlaceMultiply(__pyx_t_2, __pyx_v_length); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __pyx_t_9 = __Pyx_PyIndex_AsSsize_t(__pyx_t_3); if (unlikely((__pyx_t_9 == (Py_ssize_t)-1) && PyErr_Occurred())) __PYX_ERR(1, 1044, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __pyx_v_result->__pyx_base.view.len = __pyx_t_9; + } + + /* "View.MemoryView":1046 + * result.view.len *= length + * + * result.to_object_func = to_object_func # <<<<<<<<<<<<<< + * result.to_dtype_func = to_dtype_func + * + */ + __pyx_v_result->to_object_func = __pyx_v_to_object_func; + + /* "View.MemoryView":1047 + * + * result.to_object_func = to_object_func + * result.to_dtype_func = to_dtype_func # <<<<<<<<<<<<<< + * + * return result + */ + __pyx_v_result->to_dtype_func = __pyx_v_to_dtype_func; + + /* "View.MemoryView":1049 + * result.to_dtype_func = to_dtype_func + * + * return result # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF((PyObject *)__pyx_v_result); + __pyx_r = ((PyObject *)__pyx_v_result); + goto __pyx_L0; + + /* "View.MemoryView":999 + * + * @cname('__pyx_memoryview_fromslice') + * cdef memoryview_fromslice(__Pyx_memviewslice memviewslice, # <<<<<<<<<<<<<< + * int ndim, + * object (*to_object_func)(char *), + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_AddTraceback("View.MemoryView.memoryview_fromslice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_result); + __Pyx_XDECREF(__pyx_v_length); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1052 + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + * cdef __Pyx_memviewslice *get_slice_from_memview(memoryview memview, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + */ + +static __Pyx_memviewslice *__pyx_memoryview_get_slice_from_memoryview(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_mslice) { + struct __pyx_memoryviewslice_obj *__pyx_v_obj = 0; + __Pyx_memviewslice *__pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("get_slice_from_memview", 1); + + /* "View.MemoryView":1055 + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * obj = memview + * return &obj.from_slice + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":1056 + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): + * obj = memview # <<<<<<<<<<<<<< + * return &obj.from_slice + * else: + */ + if (!(likely(((((PyObject *)__pyx_v_memview)) == Py_None) || likely(__Pyx_TypeTest(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type))))) __PYX_ERR(1, 1056, __pyx_L1_error) + __pyx_t_2 = ((PyObject *)__pyx_v_memview); + __Pyx_INCREF(__pyx_t_2); + __pyx_v_obj = ((struct __pyx_memoryviewslice_obj *)__pyx_t_2); + __pyx_t_2 = 0; + + /* "View.MemoryView":1057 + * if isinstance(memview, _memoryviewslice): + * obj = memview + * return &obj.from_slice # <<<<<<<<<<<<<< + * else: + * slice_copy(memview, mslice) + */ + __pyx_r = (&__pyx_v_obj->from_slice); + goto __pyx_L0; + + /* "View.MemoryView":1055 + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * obj = memview + * return &obj.from_slice + */ + } + + /* "View.MemoryView":1059 + * return &obj.from_slice + * else: + * slice_copy(memview, mslice) # <<<<<<<<<<<<<< + * return mslice + * + */ + /*else*/ { + __pyx_memoryview_slice_copy(__pyx_v_memview, __pyx_v_mslice); + + /* "View.MemoryView":1060 + * else: + * slice_copy(memview, mslice) + * return mslice # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_slice_copy') + */ + __pyx_r = __pyx_v_mslice; + goto __pyx_L0; + } + + /* "View.MemoryView":1052 + * + * @cname('__pyx_memoryview_get_slice_from_memoryview') + * cdef __Pyx_memviewslice *get_slice_from_memview(memoryview memview, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *mslice) except NULL: + * cdef _memoryviewslice obj + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView.get_slice_from_memview", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF((PyObject *)__pyx_v_obj); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1063 + * + * @cname('__pyx_memoryview_slice_copy') + * cdef void slice_copy(memoryview memview, __Pyx_memviewslice *dst) noexcept: # <<<<<<<<<<<<<< + * cdef int dim + * cdef (Py_ssize_t*) shape, strides, suboffsets + */ + +static void __pyx_memoryview_slice_copy(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_dst) { + int __pyx_v_dim; + Py_ssize_t *__pyx_v_shape; + Py_ssize_t *__pyx_v_strides; + Py_ssize_t *__pyx_v_suboffsets; + Py_ssize_t *__pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + Py_ssize_t __pyx_t_5; + int __pyx_t_6; + + /* "View.MemoryView":1067 + * cdef (Py_ssize_t*) shape, strides, suboffsets + * + * shape = memview.view.shape # <<<<<<<<<<<<<< + * strides = memview.view.strides + * suboffsets = memview.view.suboffsets + */ + __pyx_t_1 = __pyx_v_memview->view.shape; + __pyx_v_shape = __pyx_t_1; + + /* "View.MemoryView":1068 + * + * shape = memview.view.shape + * strides = memview.view.strides # <<<<<<<<<<<<<< + * suboffsets = memview.view.suboffsets + * + */ + __pyx_t_1 = __pyx_v_memview->view.strides; + __pyx_v_strides = __pyx_t_1; + + /* "View.MemoryView":1069 + * shape = memview.view.shape + * strides = memview.view.strides + * suboffsets = memview.view.suboffsets # <<<<<<<<<<<<<< + * + * dst.memview = <__pyx_memoryview *> memview + */ + __pyx_t_1 = __pyx_v_memview->view.suboffsets; + __pyx_v_suboffsets = __pyx_t_1; + + /* "View.MemoryView":1071 + * suboffsets = memview.view.suboffsets + * + * dst.memview = <__pyx_memoryview *> memview # <<<<<<<<<<<<<< + * dst.data = memview.view.buf + * + */ + __pyx_v_dst->memview = ((struct __pyx_memoryview_obj *)__pyx_v_memview); + + /* "View.MemoryView":1072 + * + * dst.memview = <__pyx_memoryview *> memview + * dst.data = memview.view.buf # <<<<<<<<<<<<<< + * + * for dim in range(memview.view.ndim): + */ + __pyx_v_dst->data = ((char *)__pyx_v_memview->view.buf); + + /* "View.MemoryView":1074 + * dst.data = memview.view.buf + * + * for dim in range(memview.view.ndim): # <<<<<<<<<<<<<< + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] + */ + __pyx_t_2 = __pyx_v_memview->view.ndim; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_dim = __pyx_t_4; + + /* "View.MemoryView":1075 + * + * for dim in range(memview.view.ndim): + * dst.shape[dim] = shape[dim] # <<<<<<<<<<<<<< + * dst.strides[dim] = strides[dim] + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 + */ + (__pyx_v_dst->shape[__pyx_v_dim]) = (__pyx_v_shape[__pyx_v_dim]); + + /* "View.MemoryView":1076 + * for dim in range(memview.view.ndim): + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] # <<<<<<<<<<<<<< + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 + * + */ + (__pyx_v_dst->strides[__pyx_v_dim]) = (__pyx_v_strides[__pyx_v_dim]); + + /* "View.MemoryView":1077 + * dst.shape[dim] = shape[dim] + * dst.strides[dim] = strides[dim] + * dst.suboffsets[dim] = suboffsets[dim] if suboffsets else -1 # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_object') + */ + __pyx_t_6 = (__pyx_v_suboffsets != 0); + if (__pyx_t_6) { + __pyx_t_5 = (__pyx_v_suboffsets[__pyx_v_dim]); + } else { + __pyx_t_5 = -1L; + } + (__pyx_v_dst->suboffsets[__pyx_v_dim]) = __pyx_t_5; + } + + /* "View.MemoryView":1063 + * + * @cname('__pyx_memoryview_slice_copy') + * cdef void slice_copy(memoryview memview, __Pyx_memviewslice *dst) noexcept: # <<<<<<<<<<<<<< + * cdef int dim + * cdef (Py_ssize_t*) shape, strides, suboffsets + */ + + /* function exit code */ +} + +/* "View.MemoryView":1080 + * + * @cname('__pyx_memoryview_copy_object') + * cdef memoryview_copy(memoryview memview): # <<<<<<<<<<<<<< + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + */ + +static PyObject *__pyx_memoryview_copy_object(struct __pyx_memoryview_obj *__pyx_v_memview) { + __Pyx_memviewslice __pyx_v_memviewslice; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_copy", 1); + + /* "View.MemoryView":1083 + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + * slice_copy(memview, &memviewslice) # <<<<<<<<<<<<<< + * return memoryview_copy_from_slice(memview, &memviewslice) + * + */ + __pyx_memoryview_slice_copy(__pyx_v_memview, (&__pyx_v_memviewslice)); + + /* "View.MemoryView":1084 + * cdef __Pyx_memviewslice memviewslice + * slice_copy(memview, &memviewslice) + * return memoryview_copy_from_slice(memview, &memviewslice) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_object_from_slice') + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = __pyx_memoryview_copy_object_from_slice(__pyx_v_memview, (&__pyx_v_memviewslice)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1084, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "View.MemoryView":1080 + * + * @cname('__pyx_memoryview_copy_object') + * cdef memoryview_copy(memoryview memview): # <<<<<<<<<<<<<< + * "Create a new memoryview object" + * cdef __Pyx_memviewslice memviewslice + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("View.MemoryView.memoryview_copy", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1087 + * + * @cname('__pyx_memoryview_copy_object_from_slice') + * cdef memoryview_copy_from_slice(memoryview memview, __Pyx_memviewslice *memviewslice): # <<<<<<<<<<<<<< + * """ + * Create a new memoryview object from a given memoryview object and slice. + */ + +static PyObject *__pyx_memoryview_copy_object_from_slice(struct __pyx_memoryview_obj *__pyx_v_memview, __Pyx_memviewslice *__pyx_v_memviewslice) { + PyObject *(*__pyx_v_to_object_func)(char *); + int (*__pyx_v_to_dtype_func)(char *, PyObject *); + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *(*__pyx_t_2)(char *); + int (*__pyx_t_3)(char *, PyObject *); + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("memoryview_copy_from_slice", 1); + + /* "View.MemoryView":1094 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + */ + __pyx_t_1 = __Pyx_TypeCheck(((PyObject *)__pyx_v_memview), __pyx_memoryviewslice_type); + if (__pyx_t_1) { + + /* "View.MemoryView":1095 + * + * if isinstance(memview, _memoryviewslice): + * to_object_func = (<_memoryviewslice> memview).to_object_func # <<<<<<<<<<<<<< + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + * else: + */ + __pyx_t_2 = ((struct __pyx_memoryviewslice_obj *)__pyx_v_memview)->to_object_func; + __pyx_v_to_object_func = __pyx_t_2; + + /* "View.MemoryView":1096 + * if isinstance(memview, _memoryviewslice): + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func # <<<<<<<<<<<<<< + * else: + * to_object_func = NULL + */ + __pyx_t_3 = ((struct __pyx_memoryviewslice_obj *)__pyx_v_memview)->to_dtype_func; + __pyx_v_to_dtype_func = __pyx_t_3; + + /* "View.MemoryView":1094 + * cdef int (*to_dtype_func)(char *, object) except 0 + * + * if isinstance(memview, _memoryviewslice): # <<<<<<<<<<<<<< + * to_object_func = (<_memoryviewslice> memview).to_object_func + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1098 + * to_dtype_func = (<_memoryviewslice> memview).to_dtype_func + * else: + * to_object_func = NULL # <<<<<<<<<<<<<< + * to_dtype_func = NULL + * + */ + /*else*/ { + __pyx_v_to_object_func = NULL; + + /* "View.MemoryView":1099 + * else: + * to_object_func = NULL + * to_dtype_func = NULL # <<<<<<<<<<<<<< + * + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, + */ + __pyx_v_to_dtype_func = NULL; + } + __pyx_L3:; + + /* "View.MemoryView":1101 + * to_dtype_func = NULL + * + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, # <<<<<<<<<<<<<< + * to_object_func, to_dtype_func, + * memview.dtype_is_object) + */ + __Pyx_XDECREF(__pyx_r); + + /* "View.MemoryView":1103 + * return memoryview_fromslice(memviewslice[0], memview.view.ndim, + * to_object_func, to_dtype_func, + * memview.dtype_is_object) # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_4 = __pyx_memoryview_fromslice((__pyx_v_memviewslice[0]), __pyx_v_memview->view.ndim, __pyx_v_to_object_func, __pyx_v_to_dtype_func, __pyx_v_memview->dtype_is_object); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_r = __pyx_t_4; + __pyx_t_4 = 0; + goto __pyx_L0; + + /* "View.MemoryView":1087 + * + * @cname('__pyx_memoryview_copy_object_from_slice') + * cdef memoryview_copy_from_slice(memoryview memview, __Pyx_memviewslice *memviewslice): # <<<<<<<<<<<<<< + * """ + * Create a new memoryview object from a given memoryview object and slice. + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.memoryview_copy_from_slice", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "View.MemoryView":1109 + * + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: # <<<<<<<<<<<<<< + * return -arg if arg < 0 else arg + * + */ + +static Py_ssize_t abs_py_ssize_t(Py_ssize_t __pyx_v_arg) { + Py_ssize_t __pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + + /* "View.MemoryView":1110 + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: + * return -arg if arg < 0 else arg # <<<<<<<<<<<<<< + * + * @cname('__pyx_get_best_slice_order') + */ + __pyx_t_2 = (__pyx_v_arg < 0); + if (__pyx_t_2) { + __pyx_t_1 = (-__pyx_v_arg); + } else { + __pyx_t_1 = __pyx_v_arg; + } + __pyx_r = __pyx_t_1; + goto __pyx_L0; + + /* "View.MemoryView":1109 + * + * + * cdef Py_ssize_t abs_py_ssize_t(Py_ssize_t arg) noexcept nogil: # <<<<<<<<<<<<<< + * return -arg if arg < 0 else arg + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1113 + * + * @cname('__pyx_get_best_slice_order') + * cdef char get_best_order(__Pyx_memviewslice *mslice, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * Figure out the best memory access order for a given slice. + */ + +static char __pyx_get_best_slice_order(__Pyx_memviewslice *__pyx_v_mslice, int __pyx_v_ndim) { + int __pyx_v_i; + Py_ssize_t __pyx_v_c_stride; + Py_ssize_t __pyx_v_f_stride; + char __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1118 + * """ + * cdef int i + * cdef Py_ssize_t c_stride = 0 # <<<<<<<<<<<<<< + * cdef Py_ssize_t f_stride = 0 + * + */ + __pyx_v_c_stride = 0; + + /* "View.MemoryView":1119 + * cdef int i + * cdef Py_ssize_t c_stride = 0 + * cdef Py_ssize_t f_stride = 0 # <<<<<<<<<<<<<< + * + * for i in range(ndim - 1, -1, -1): + */ + __pyx_v_f_stride = 0; + + /* "View.MemoryView":1121 + * cdef Py_ssize_t f_stride = 0 + * + * for i in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] + */ + for (__pyx_t_1 = (__pyx_v_ndim - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":1122 + * + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * c_stride = mslice.strides[i] + * break + */ + __pyx_t_2 = ((__pyx_v_mslice->shape[__pyx_v_i]) > 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1123 + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_c_stride = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1124 + * if mslice.shape[i] > 1: + * c_stride = mslice.strides[i] + * break # <<<<<<<<<<<<<< + * + * for i in range(ndim): + */ + goto __pyx_L4_break; + + /* "View.MemoryView":1122 + * + * for i in range(ndim - 1, -1, -1): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * c_stride = mslice.strides[i] + * break + */ + } + } + __pyx_L4_break:; + + /* "View.MemoryView":1126 + * break + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] + */ + __pyx_t_1 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_1; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1127 + * + * for i in range(ndim): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * f_stride = mslice.strides[i] + * break + */ + __pyx_t_2 = ((__pyx_v_mslice->shape[__pyx_v_i]) > 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1128 + * for i in range(ndim): + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] # <<<<<<<<<<<<<< + * break + * + */ + __pyx_v_f_stride = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1129 + * if mslice.shape[i] > 1: + * f_stride = mslice.strides[i] + * break # <<<<<<<<<<<<<< + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): + */ + goto __pyx_L7_break; + + /* "View.MemoryView":1127 + * + * for i in range(ndim): + * if mslice.shape[i] > 1: # <<<<<<<<<<<<<< + * f_stride = mslice.strides[i] + * break + */ + } + } + __pyx_L7_break:; + + /* "View.MemoryView":1131 + * break + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): # <<<<<<<<<<<<<< + * return 'C' + * else: + */ + __pyx_t_2 = (abs_py_ssize_t(__pyx_v_c_stride) <= abs_py_ssize_t(__pyx_v_f_stride)); + if (__pyx_t_2) { + + /* "View.MemoryView":1132 + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): + * return 'C' # <<<<<<<<<<<<<< + * else: + * return 'F' + */ + __pyx_r = 'C'; + goto __pyx_L0; + + /* "View.MemoryView":1131 + * break + * + * if abs_py_ssize_t(c_stride) <= abs_py_ssize_t(f_stride): # <<<<<<<<<<<<<< + * return 'C' + * else: + */ + } + + /* "View.MemoryView":1134 + * return 'C' + * else: + * return 'F' # <<<<<<<<<<<<<< + * + * @cython.cdivision(True) + */ + /*else*/ { + __pyx_r = 'F'; + goto __pyx_L0; + } + + /* "View.MemoryView":1113 + * + * @cname('__pyx_get_best_slice_order') + * cdef char get_best_order(__Pyx_memviewslice *mslice, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * Figure out the best memory access order for a given slice. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1137 + * + * @cython.cdivision(True) + * cdef void _copy_strided_to_strided(char *src_data, Py_ssize_t *src_strides, # <<<<<<<<<<<<<< + * char *dst_data, Py_ssize_t *dst_strides, + * Py_ssize_t *src_shape, Py_ssize_t *dst_shape, + */ + +static void _copy_strided_to_strided(char *__pyx_v_src_data, Py_ssize_t *__pyx_v_src_strides, char *__pyx_v_dst_data, Py_ssize_t *__pyx_v_dst_strides, Py_ssize_t *__pyx_v_src_shape, Py_ssize_t *__pyx_v_dst_shape, int __pyx_v_ndim, size_t __pyx_v_itemsize) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + CYTHON_UNUSED Py_ssize_t __pyx_v_src_extent; + Py_ssize_t __pyx_v_dst_extent; + Py_ssize_t __pyx_v_src_stride; + Py_ssize_t __pyx_v_dst_stride; + int __pyx_t_1; + int __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + Py_ssize_t __pyx_t_5; + + /* "View.MemoryView":1144 + * + * cdef Py_ssize_t i + * cdef Py_ssize_t src_extent = src_shape[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] + */ + __pyx_v_src_extent = (__pyx_v_src_shape[0]); + + /* "View.MemoryView":1145 + * cdef Py_ssize_t i + * cdef Py_ssize_t src_extent = src_shape[0] + * cdef Py_ssize_t dst_extent = dst_shape[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t src_stride = src_strides[0] + * cdef Py_ssize_t dst_stride = dst_strides[0] + */ + __pyx_v_dst_extent = (__pyx_v_dst_shape[0]); + + /* "View.MemoryView":1146 + * cdef Py_ssize_t src_extent = src_shape[0] + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + */ + __pyx_v_src_stride = (__pyx_v_src_strides[0]); + + /* "View.MemoryView":1147 + * cdef Py_ssize_t dst_extent = dst_shape[0] + * cdef Py_ssize_t src_stride = src_strides[0] + * cdef Py_ssize_t dst_stride = dst_strides[0] # <<<<<<<<<<<<<< + * + * if ndim == 1: + */ + __pyx_v_dst_stride = (__pyx_v_dst_strides[0]); + + /* "View.MemoryView":1149 + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + */ + __pyx_t_1 = (__pyx_v_ndim == 1); + if (__pyx_t_1) { + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + __pyx_t_2 = (__pyx_v_src_stride > 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L5_bool_binop_done; + } + __pyx_t_2 = (__pyx_v_dst_stride > 0); + if (__pyx_t_2) { + } else { + __pyx_t_1 = __pyx_t_2; + goto __pyx_L5_bool_binop_done; + } + + /* "View.MemoryView":1151 + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): # <<<<<<<<<<<<<< + * memcpy(dst_data, src_data, itemsize * dst_extent) + * else: + */ + __pyx_t_2 = (((size_t)__pyx_v_src_stride) == __pyx_v_itemsize); + if (__pyx_t_2) { + __pyx_t_2 = (__pyx_v_itemsize == ((size_t)__pyx_v_dst_stride)); + } + __pyx_t_1 = __pyx_t_2; + __pyx_L5_bool_binop_done:; + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + if (__pyx_t_1) { + + /* "View.MemoryView":1152 + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) # <<<<<<<<<<<<<< + * else: + * for i in range(dst_extent): + */ + (void)(memcpy(__pyx_v_dst_data, __pyx_v_src_data, (__pyx_v_itemsize * __pyx_v_dst_extent))); + + /* "View.MemoryView":1150 + * + * if ndim == 1: + * if (src_stride > 0 and dst_stride > 0 and # <<<<<<<<<<<<<< + * src_stride == itemsize == dst_stride): + * memcpy(dst_data, src_data, itemsize * dst_extent) + */ + goto __pyx_L4; + } + + /* "View.MemoryView":1154 + * memcpy(dst_data, src_data, itemsize * dst_extent) + * else: + * for i in range(dst_extent): # <<<<<<<<<<<<<< + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride + */ + /*else*/ { + __pyx_t_3 = __pyx_v_dst_extent; + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { + __pyx_v_i = __pyx_t_5; + + /* "View.MemoryView":1155 + * else: + * for i in range(dst_extent): + * memcpy(dst_data, src_data, itemsize) # <<<<<<<<<<<<<< + * src_data += src_stride + * dst_data += dst_stride + */ + (void)(memcpy(__pyx_v_dst_data, __pyx_v_src_data, __pyx_v_itemsize)); + + /* "View.MemoryView":1156 + * for i in range(dst_extent): + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride # <<<<<<<<<<<<<< + * dst_data += dst_stride + * else: + */ + __pyx_v_src_data = (__pyx_v_src_data + __pyx_v_src_stride); + + /* "View.MemoryView":1157 + * memcpy(dst_data, src_data, itemsize) + * src_data += src_stride + * dst_data += dst_stride # <<<<<<<<<<<<<< + * else: + * for i in range(dst_extent): + */ + __pyx_v_dst_data = (__pyx_v_dst_data + __pyx_v_dst_stride); + } + } + __pyx_L4:; + + /* "View.MemoryView":1149 + * cdef Py_ssize_t dst_stride = dst_strides[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * if (src_stride > 0 and dst_stride > 0 and + * src_stride == itemsize == dst_stride): + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1159 + * dst_data += dst_stride + * else: + * for i in range(dst_extent): # <<<<<<<<<<<<<< + * _copy_strided_to_strided(src_data, src_strides + 1, + * dst_data, dst_strides + 1, + */ + /*else*/ { + __pyx_t_3 = __pyx_v_dst_extent; + __pyx_t_4 = __pyx_t_3; + for (__pyx_t_5 = 0; __pyx_t_5 < __pyx_t_4; __pyx_t_5+=1) { + __pyx_v_i = __pyx_t_5; + + /* "View.MemoryView":1160 + * else: + * for i in range(dst_extent): + * _copy_strided_to_strided(src_data, src_strides + 1, # <<<<<<<<<<<<<< + * dst_data, dst_strides + 1, + * src_shape + 1, dst_shape + 1, + */ + _copy_strided_to_strided(__pyx_v_src_data, (__pyx_v_src_strides + 1), __pyx_v_dst_data, (__pyx_v_dst_strides + 1), (__pyx_v_src_shape + 1), (__pyx_v_dst_shape + 1), (__pyx_v_ndim - 1), __pyx_v_itemsize); + + /* "View.MemoryView":1164 + * src_shape + 1, dst_shape + 1, + * ndim - 1, itemsize) + * src_data += src_stride # <<<<<<<<<<<<<< + * dst_data += dst_stride + * + */ + __pyx_v_src_data = (__pyx_v_src_data + __pyx_v_src_stride); + + /* "View.MemoryView":1165 + * ndim - 1, itemsize) + * src_data += src_stride + * dst_data += dst_stride # <<<<<<<<<<<<<< + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, + */ + __pyx_v_dst_data = (__pyx_v_dst_data + __pyx_v_dst_stride); + } + } + __pyx_L3:; + + /* "View.MemoryView":1137 + * + * @cython.cdivision(True) + * cdef void _copy_strided_to_strided(char *src_data, Py_ssize_t *src_strides, # <<<<<<<<<<<<<< + * char *dst_data, Py_ssize_t *dst_strides, + * Py_ssize_t *src_shape, Py_ssize_t *dst_shape, + */ + + /* function exit code */ +} + +/* "View.MemoryView":1167 + * dst_data += dst_stride + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + */ + +static void copy_strided_to_strided(__Pyx_memviewslice *__pyx_v_src, __Pyx_memviewslice *__pyx_v_dst, int __pyx_v_ndim, size_t __pyx_v_itemsize) { + + /* "View.MemoryView":1170 + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + * _copy_strided_to_strided(src.data, src.strides, dst.data, dst.strides, # <<<<<<<<<<<<<< + * src.shape, dst.shape, ndim, itemsize) + * + */ + _copy_strided_to_strided(__pyx_v_src->data, __pyx_v_src->strides, __pyx_v_dst->data, __pyx_v_dst->strides, __pyx_v_src->shape, __pyx_v_dst->shape, __pyx_v_ndim, __pyx_v_itemsize); + + /* "View.MemoryView":1167 + * dst_data += dst_stride + * + * cdef void copy_strided_to_strided(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *dst, + * int ndim, size_t itemsize) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1174 + * + * @cname('__pyx_memoryview_slice_get_size') + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + */ + +static Py_ssize_t __pyx_memoryview_slice_get_size(__Pyx_memviewslice *__pyx_v_src, int __pyx_v_ndim) { + Py_ssize_t __pyx_v_shape; + Py_ssize_t __pyx_v_size; + Py_ssize_t __pyx_r; + Py_ssize_t __pyx_t_1; + Py_ssize_t *__pyx_t_2; + Py_ssize_t *__pyx_t_3; + Py_ssize_t *__pyx_t_4; + + /* "View.MemoryView":1176 + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize # <<<<<<<<<<<<<< + * + * for shape in src.shape[:ndim]: + */ + __pyx_t_1 = __pyx_v_src->memview->view.itemsize; + __pyx_v_size = __pyx_t_1; + + /* "View.MemoryView":1178 + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + * + * for shape in src.shape[:ndim]: # <<<<<<<<<<<<<< + * size *= shape + * + */ + __pyx_t_3 = (__pyx_v_src->shape + __pyx_v_ndim); + for (__pyx_t_4 = __pyx_v_src->shape; __pyx_t_4 < __pyx_t_3; __pyx_t_4++) { + __pyx_t_2 = __pyx_t_4; + __pyx_v_shape = (__pyx_t_2[0]); + + /* "View.MemoryView":1179 + * + * for shape in src.shape[:ndim]: + * size *= shape # <<<<<<<<<<<<<< + * + * return size + */ + __pyx_v_size = (__pyx_v_size * __pyx_v_shape); + } + + /* "View.MemoryView":1181 + * size *= shape + * + * return size # <<<<<<<<<<<<<< + * + * @cname('__pyx_fill_contig_strides_array') + */ + __pyx_r = __pyx_v_size; + goto __pyx_L0; + + /* "View.MemoryView":1174 + * + * @cname('__pyx_memoryview_slice_get_size') + * cdef Py_ssize_t slice_get_size(__Pyx_memviewslice *src, int ndim) noexcept nogil: # <<<<<<<<<<<<<< + * "Return the size of the memory occupied by the slice in number of bytes" + * cdef Py_ssize_t shape, size = src.memview.view.itemsize + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1184 + * + * @cname('__pyx_fill_contig_strides_array') + * cdef Py_ssize_t fill_contig_strides_array( # <<<<<<<<<<<<<< + * Py_ssize_t *shape, Py_ssize_t *strides, Py_ssize_t stride, + * int ndim, char order) noexcept nogil: + */ + +static Py_ssize_t __pyx_fill_contig_strides_array(Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, Py_ssize_t __pyx_v_stride, int __pyx_v_ndim, char __pyx_v_order) { + int __pyx_v_idx; + Py_ssize_t __pyx_r; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1193 + * cdef int idx + * + * if order == 'F': # <<<<<<<<<<<<<< + * for idx in range(ndim): + * strides[idx] = stride + */ + __pyx_t_1 = (__pyx_v_order == 'F'); + if (__pyx_t_1) { + + /* "View.MemoryView":1194 + * + * if order == 'F': + * for idx in range(ndim): # <<<<<<<<<<<<<< + * strides[idx] = stride + * stride *= shape[idx] + */ + __pyx_t_2 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_idx = __pyx_t_4; + + /* "View.MemoryView":1195 + * if order == 'F': + * for idx in range(ndim): + * strides[idx] = stride # <<<<<<<<<<<<<< + * stride *= shape[idx] + * else: + */ + (__pyx_v_strides[__pyx_v_idx]) = __pyx_v_stride; + + /* "View.MemoryView":1196 + * for idx in range(ndim): + * strides[idx] = stride + * stride *= shape[idx] # <<<<<<<<<<<<<< + * else: + * for idx in range(ndim - 1, -1, -1): + */ + __pyx_v_stride = (__pyx_v_stride * (__pyx_v_shape[__pyx_v_idx])); + } + + /* "View.MemoryView":1193 + * cdef int idx + * + * if order == 'F': # <<<<<<<<<<<<<< + * for idx in range(ndim): + * strides[idx] = stride + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1198 + * stride *= shape[idx] + * else: + * for idx in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * strides[idx] = stride + * stride *= shape[idx] + */ + /*else*/ { + for (__pyx_t_2 = (__pyx_v_ndim - 1); __pyx_t_2 > -1; __pyx_t_2-=1) { + __pyx_v_idx = __pyx_t_2; + + /* "View.MemoryView":1199 + * else: + * for idx in range(ndim - 1, -1, -1): + * strides[idx] = stride # <<<<<<<<<<<<<< + * stride *= shape[idx] + * + */ + (__pyx_v_strides[__pyx_v_idx]) = __pyx_v_stride; + + /* "View.MemoryView":1200 + * for idx in range(ndim - 1, -1, -1): + * strides[idx] = stride + * stride *= shape[idx] # <<<<<<<<<<<<<< + * + * return stride + */ + __pyx_v_stride = (__pyx_v_stride * (__pyx_v_shape[__pyx_v_idx])); + } + } + __pyx_L3:; + + /* "View.MemoryView":1202 + * stride *= shape[idx] + * + * return stride # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_copy_data_to_temp') + */ + __pyx_r = __pyx_v_stride; + goto __pyx_L0; + + /* "View.MemoryView":1184 + * + * @cname('__pyx_fill_contig_strides_array') + * cdef Py_ssize_t fill_contig_strides_array( # <<<<<<<<<<<<<< + * Py_ssize_t *shape, Py_ssize_t *strides, Py_ssize_t stride, + * int ndim, char order) noexcept nogil: + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1205 + * + * @cname('__pyx_memoryview_copy_data_to_temp') + * cdef void *copy_data_to_temp(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *tmpslice, + * char order, + */ + +static void *__pyx_memoryview_copy_data_to_temp(__Pyx_memviewslice *__pyx_v_src, __Pyx_memviewslice *__pyx_v_tmpslice, char __pyx_v_order, int __pyx_v_ndim) { + int __pyx_v_i; + void *__pyx_v_result; + size_t __pyx_v_itemsize; + size_t __pyx_v_size; + void *__pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + struct __pyx_memoryview_obj *__pyx_t_4; + int __pyx_t_5; + int __pyx_t_6; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":1216 + * cdef void *result + * + * cdef size_t itemsize = src.memview.view.itemsize # <<<<<<<<<<<<<< + * cdef size_t size = slice_get_size(src, ndim) + * + */ + __pyx_t_1 = __pyx_v_src->memview->view.itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":1217 + * + * cdef size_t itemsize = src.memview.view.itemsize + * cdef size_t size = slice_get_size(src, ndim) # <<<<<<<<<<<<<< + * + * result = malloc(size) + */ + __pyx_v_size = __pyx_memoryview_slice_get_size(__pyx_v_src, __pyx_v_ndim); + + /* "View.MemoryView":1219 + * cdef size_t size = slice_get_size(src, ndim) + * + * result = malloc(size) # <<<<<<<<<<<<<< + * if not result: + * _err_no_memory() + */ + __pyx_v_result = malloc(__pyx_v_size); + + /* "View.MemoryView":1220 + * + * result = malloc(size) + * if not result: # <<<<<<<<<<<<<< + * _err_no_memory() + * + */ + __pyx_t_2 = (!(__pyx_v_result != 0)); + if (__pyx_t_2) { + + /* "View.MemoryView":1221 + * result = malloc(size) + * if not result: + * _err_no_memory() # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_3 = __pyx_memoryview_err_no_memory(); if (unlikely(__pyx_t_3 == ((int)-1))) __PYX_ERR(1, 1221, __pyx_L1_error) + + /* "View.MemoryView":1220 + * + * result = malloc(size) + * if not result: # <<<<<<<<<<<<<< + * _err_no_memory() + * + */ + } + + /* "View.MemoryView":1224 + * + * + * tmpslice.data = result # <<<<<<<<<<<<<< + * tmpslice.memview = src.memview + * for i in range(ndim): + */ + __pyx_v_tmpslice->data = ((char *)__pyx_v_result); + + /* "View.MemoryView":1225 + * + * tmpslice.data = result + * tmpslice.memview = src.memview # <<<<<<<<<<<<<< + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] + */ + __pyx_t_4 = __pyx_v_src->memview; + __pyx_v_tmpslice->memview = __pyx_t_4; + + /* "View.MemoryView":1226 + * tmpslice.data = result + * tmpslice.memview = src.memview + * for i in range(ndim): # <<<<<<<<<<<<<< + * tmpslice.shape[i] = src.shape[i] + * tmpslice.suboffsets[i] = -1 + */ + __pyx_t_3 = __pyx_v_ndim; + __pyx_t_5 = __pyx_t_3; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "View.MemoryView":1227 + * tmpslice.memview = src.memview + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] # <<<<<<<<<<<<<< + * tmpslice.suboffsets[i] = -1 + * + */ + (__pyx_v_tmpslice->shape[__pyx_v_i]) = (__pyx_v_src->shape[__pyx_v_i]); + + /* "View.MemoryView":1228 + * for i in range(ndim): + * tmpslice.shape[i] = src.shape[i] + * tmpslice.suboffsets[i] = -1 # <<<<<<<<<<<<<< + * + * fill_contig_strides_array(&tmpslice.shape[0], &tmpslice.strides[0], itemsize, ndim, order) + */ + (__pyx_v_tmpslice->suboffsets[__pyx_v_i]) = -1L; + } + + /* "View.MemoryView":1230 + * tmpslice.suboffsets[i] = -1 + * + * fill_contig_strides_array(&tmpslice.shape[0], &tmpslice.strides[0], itemsize, ndim, order) # <<<<<<<<<<<<<< + * + * + */ + (void)(__pyx_fill_contig_strides_array((&(__pyx_v_tmpslice->shape[0])), (&(__pyx_v_tmpslice->strides[0])), __pyx_v_itemsize, __pyx_v_ndim, __pyx_v_order)); + + /* "View.MemoryView":1233 + * + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if tmpslice.shape[i] == 1: + * tmpslice.strides[i] = 0 + */ + __pyx_t_3 = __pyx_v_ndim; + __pyx_t_5 = __pyx_t_3; + for (__pyx_t_6 = 0; __pyx_t_6 < __pyx_t_5; __pyx_t_6+=1) { + __pyx_v_i = __pyx_t_6; + + /* "View.MemoryView":1234 + * + * for i in range(ndim): + * if tmpslice.shape[i] == 1: # <<<<<<<<<<<<<< + * tmpslice.strides[i] = 0 + * + */ + __pyx_t_2 = ((__pyx_v_tmpslice->shape[__pyx_v_i]) == 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1235 + * for i in range(ndim): + * if tmpslice.shape[i] == 1: + * tmpslice.strides[i] = 0 # <<<<<<<<<<<<<< + * + * if slice_is_contig(src[0], order, ndim): + */ + (__pyx_v_tmpslice->strides[__pyx_v_i]) = 0; + + /* "View.MemoryView":1234 + * + * for i in range(ndim): + * if tmpslice.shape[i] == 1: # <<<<<<<<<<<<<< + * tmpslice.strides[i] = 0 + * + */ + } + } + + /* "View.MemoryView":1237 + * tmpslice.strides[i] = 0 + * + * if slice_is_contig(src[0], order, ndim): # <<<<<<<<<<<<<< + * memcpy(result, src.data, size) + * else: + */ + __pyx_t_2 = __pyx_memviewslice_is_contig((__pyx_v_src[0]), __pyx_v_order, __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1238 + * + * if slice_is_contig(src[0], order, ndim): + * memcpy(result, src.data, size) # <<<<<<<<<<<<<< + * else: + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) + */ + (void)(memcpy(__pyx_v_result, __pyx_v_src->data, __pyx_v_size)); + + /* "View.MemoryView":1237 + * tmpslice.strides[i] = 0 + * + * if slice_is_contig(src[0], order, ndim): # <<<<<<<<<<<<<< + * memcpy(result, src.data, size) + * else: + */ + goto __pyx_L9; + } + + /* "View.MemoryView":1240 + * memcpy(result, src.data, size) + * else: + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) # <<<<<<<<<<<<<< + * + * return result + */ + /*else*/ { + copy_strided_to_strided(__pyx_v_src, __pyx_v_tmpslice, __pyx_v_ndim, __pyx_v_itemsize); + } + __pyx_L9:; + + /* "View.MemoryView":1242 + * copy_strided_to_strided(src, tmpslice, ndim, itemsize) + * + * return result # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = __pyx_v_result; + goto __pyx_L0; + + /* "View.MemoryView":1205 + * + * @cname('__pyx_memoryview_copy_data_to_temp') + * cdef void *copy_data_to_temp(__Pyx_memviewslice *src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice *tmpslice, + * char order, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.copy_data_to_temp", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1247 + * + * @cname('__pyx_memoryview_err_extents') + * cdef int _err_extents(int i, Py_ssize_t extent1, # <<<<<<<<<<<<<< + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" + */ + +static int __pyx_memoryview_err_extents(int __pyx_v_i, Py_ssize_t __pyx_v_extent1, Py_ssize_t __pyx_v_extent2) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + Py_ssize_t __pyx_t_2; + Py_UCS4 __pyx_t_3; + PyObject *__pyx_t_4 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err_extents", 0); + + /* "View.MemoryView":1249 + * cdef int _err_extents(int i, Py_ssize_t extent1, + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err_dim') + */ + __pyx_t_1 = PyTuple_New(7); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = 0; + __pyx_t_3 = 127; + __Pyx_INCREF(__pyx_kp_u_got_differing_extents_in_dimensi); + __pyx_t_2 += 35; + __Pyx_GIVEREF(__pyx_kp_u_got_differing_extents_in_dimensi); + PyTuple_SET_ITEM(__pyx_t_1, 0, __pyx_kp_u_got_differing_extents_in_dimensi); + __pyx_t_4 = __Pyx_PyUnicode_From_int(__pyx_v_i, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 1, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u_got); + __pyx_t_2 += 6; + __Pyx_GIVEREF(__pyx_kp_u_got); + PyTuple_SET_ITEM(__pyx_t_1, 2, __pyx_kp_u_got); + __pyx_t_4 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_extent1, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 3, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u_and); + __pyx_t_2 += 5; + __Pyx_GIVEREF(__pyx_kp_u_and); + PyTuple_SET_ITEM(__pyx_t_1, 4, __pyx_kp_u_and); + __pyx_t_4 = __Pyx_PyUnicode_From_Py_ssize_t(__pyx_v_extent2, 0, ' ', 'd'); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_2 += __Pyx_PyUnicode_GET_LENGTH(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + PyTuple_SET_ITEM(__pyx_t_1, 5, __pyx_t_4); + __pyx_t_4 = 0; + __Pyx_INCREF(__pyx_kp_u__7); + __pyx_t_2 += 1; + __Pyx_GIVEREF(__pyx_kp_u__7); + PyTuple_SET_ITEM(__pyx_t_1, 6, __pyx_kp_u__7); + __pyx_t_4 = __Pyx_PyUnicode_Join(__pyx_t_1, 7, __pyx_t_2, __pyx_t_3); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 1249, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_Raise(__pyx_builtin_ValueError, __pyx_t_4, 0, 0); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __PYX_ERR(1, 1249, __pyx_L1_error) + + /* "View.MemoryView":1247 + * + * @cname('__pyx_memoryview_err_extents') + * cdef int _err_extents(int i, Py_ssize_t extent1, # <<<<<<<<<<<<<< + * Py_ssize_t extent2) except -1 with gil: + * raise ValueError, f"got differing extents in dimension {i} (got {extent1} and {extent2})" + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView._err_extents", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1252 + * + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg % dim + * + */ + +static int __pyx_memoryview_err_dim(PyObject *__pyx_v_error, PyObject *__pyx_v_msg, int __pyx_v_dim) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err_dim", 0); + __Pyx_INCREF(__pyx_v_msg); + + /* "View.MemoryView":1253 + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: + * raise error, msg % dim # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err') + */ + __pyx_t_1 = __Pyx_PyInt_From_int(__pyx_v_dim); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 1253, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = __Pyx_PyString_FormatSafe(__pyx_v_msg, __pyx_t_1); if (unlikely(!__pyx_t_2)) __PYX_ERR(1, 1253, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_Raise(((PyObject *)__pyx_v_error), __pyx_t_2, 0, 0); + __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0; + __PYX_ERR(1, 1253, __pyx_L1_error) + + /* "View.MemoryView":1252 + * + * @cname('__pyx_memoryview_err_dim') + * cdef int _err_dim(PyObject *error, str msg, int dim) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg % dim + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("View.MemoryView._err_dim", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_XDECREF(__pyx_v_msg); + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1256 + * + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg + * + */ + +static int __pyx_memoryview_err(PyObject *__pyx_v_error, PyObject *__pyx_v_msg) { + int __pyx_r; + __Pyx_RefNannyDeclarations + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_RefNannySetupContext("_err", 0); + __Pyx_INCREF(__pyx_v_msg); + + /* "View.MemoryView":1257 + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: + * raise error, msg # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_err_no_memory') + */ + __Pyx_Raise(((PyObject *)__pyx_v_error), __pyx_v_msg, 0, 0); + __PYX_ERR(1, 1257, __pyx_L1_error) + + /* "View.MemoryView":1256 + * + * @cname('__pyx_memoryview_err') + * cdef int _err(PyObject *error, str msg) except -1 with gil: # <<<<<<<<<<<<<< + * raise error, msg + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._err", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __Pyx_XDECREF(__pyx_v_msg); + __Pyx_RefNannyFinishContext(); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1260 + * + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + +static int __pyx_memoryview_err_no_memory(void) { + int __pyx_r; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + + /* "View.MemoryView":1261 + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: + * raise MemoryError # <<<<<<<<<<<<<< + * + * + */ + PyErr_NoMemory(); __PYX_ERR(1, 1261, __pyx_L1_error) + + /* "View.MemoryView":1260 + * + * @cname('__pyx_memoryview_err_no_memory') + * cdef int _err_no_memory() except -1 with gil: # <<<<<<<<<<<<<< + * raise MemoryError + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_AddTraceback("View.MemoryView._err_no_memory", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + return __pyx_r; +} + +/* "View.MemoryView":1265 + * + * @cname('__pyx_memoryview_copy_contents') + * cdef int memoryview_copy_contents(__Pyx_memviewslice src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice dst, + * int src_ndim, int dst_ndim, + */ + +static int __pyx_memoryview_copy_contents(__Pyx_memviewslice __pyx_v_src, __Pyx_memviewslice __pyx_v_dst, int __pyx_v_src_ndim, int __pyx_v_dst_ndim, int __pyx_v_dtype_is_object) { + void *__pyx_v_tmpdata; + size_t __pyx_v_itemsize; + int __pyx_v_i; + char __pyx_v_order; + int __pyx_v_broadcasting; + int __pyx_v_direct_copy; + __Pyx_memviewslice __pyx_v_tmp; + int __pyx_v_ndim; + int __pyx_r; + Py_ssize_t __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + int __pyx_t_4; + int __pyx_t_5; + int __pyx_t_6; + void *__pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + + /* "View.MemoryView":1273 + * Check for overlapping memory and verify the shapes. + * """ + * cdef void *tmpdata = NULL # <<<<<<<<<<<<<< + * cdef size_t itemsize = src.memview.view.itemsize + * cdef int i + */ + __pyx_v_tmpdata = NULL; + + /* "View.MemoryView":1274 + * """ + * cdef void *tmpdata = NULL + * cdef size_t itemsize = src.memview.view.itemsize # <<<<<<<<<<<<<< + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) + */ + __pyx_t_1 = __pyx_v_src.memview->view.itemsize; + __pyx_v_itemsize = __pyx_t_1; + + /* "View.MemoryView":1276 + * cdef size_t itemsize = src.memview.view.itemsize + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) # <<<<<<<<<<<<<< + * cdef bint broadcasting = False + * cdef bint direct_copy = False + */ + __pyx_v_order = __pyx_get_best_slice_order((&__pyx_v_src), __pyx_v_src_ndim); + + /* "View.MemoryView":1277 + * cdef int i + * cdef char order = get_best_order(&src, src_ndim) + * cdef bint broadcasting = False # <<<<<<<<<<<<<< + * cdef bint direct_copy = False + * cdef __Pyx_memviewslice tmp + */ + __pyx_v_broadcasting = 0; + + /* "View.MemoryView":1278 + * cdef char order = get_best_order(&src, src_ndim) + * cdef bint broadcasting = False + * cdef bint direct_copy = False # <<<<<<<<<<<<<< + * cdef __Pyx_memviewslice tmp + * + */ + __pyx_v_direct_copy = 0; + + /* "View.MemoryView":1281 + * cdef __Pyx_memviewslice tmp + * + * if src_ndim < dst_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + */ + __pyx_t_2 = (__pyx_v_src_ndim < __pyx_v_dst_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1282 + * + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) # <<<<<<<<<<<<<< + * elif dst_ndim < src_ndim: + * broadcast_leading(&dst, dst_ndim, src_ndim) + */ + __pyx_memoryview_broadcast_leading((&__pyx_v_src), __pyx_v_src_ndim, __pyx_v_dst_ndim); + + /* "View.MemoryView":1281 + * cdef __Pyx_memviewslice tmp + * + * if src_ndim < dst_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1283 + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + */ + __pyx_t_2 = (__pyx_v_dst_ndim < __pyx_v_src_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1284 + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: + * broadcast_leading(&dst, dst_ndim, src_ndim) # <<<<<<<<<<<<<< + * + * cdef int ndim = max(src_ndim, dst_ndim) + */ + __pyx_memoryview_broadcast_leading((&__pyx_v_dst), __pyx_v_dst_ndim, __pyx_v_src_ndim); + + /* "View.MemoryView":1283 + * if src_ndim < dst_ndim: + * broadcast_leading(&src, src_ndim, dst_ndim) + * elif dst_ndim < src_ndim: # <<<<<<<<<<<<<< + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + */ + } + __pyx_L3:; + + /* "View.MemoryView":1286 + * broadcast_leading(&dst, dst_ndim, src_ndim) + * + * cdef int ndim = max(src_ndim, dst_ndim) # <<<<<<<<<<<<<< + * + * for i in range(ndim): + */ + __pyx_t_3 = __pyx_v_dst_ndim; + __pyx_t_4 = __pyx_v_src_ndim; + __pyx_t_2 = (__pyx_t_3 > __pyx_t_4); + if (__pyx_t_2) { + __pyx_t_5 = __pyx_t_3; + } else { + __pyx_t_5 = __pyx_t_4; + } + __pyx_v_ndim = __pyx_t_5; + + /* "View.MemoryView":1288 + * cdef int ndim = max(src_ndim, dst_ndim) + * + * for i in range(ndim): # <<<<<<<<<<<<<< + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: + */ + __pyx_t_5 = __pyx_v_ndim; + __pyx_t_3 = __pyx_t_5; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1289 + * + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: # <<<<<<<<<<<<<< + * if src.shape[i] == 1: + * broadcasting = True + */ + __pyx_t_2 = ((__pyx_v_src.shape[__pyx_v_i]) != (__pyx_v_dst.shape[__pyx_v_i])); + if (__pyx_t_2) { + + /* "View.MemoryView":1290 + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: # <<<<<<<<<<<<<< + * broadcasting = True + * src.strides[i] = 0 + */ + __pyx_t_2 = ((__pyx_v_src.shape[__pyx_v_i]) == 1); + if (__pyx_t_2) { + + /* "View.MemoryView":1291 + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: + * broadcasting = True # <<<<<<<<<<<<<< + * src.strides[i] = 0 + * else: + */ + __pyx_v_broadcasting = 1; + + /* "View.MemoryView":1292 + * if src.shape[i] == 1: + * broadcasting = True + * src.strides[i] = 0 # <<<<<<<<<<<<<< + * else: + * _err_extents(i, dst.shape[i], src.shape[i]) + */ + (__pyx_v_src.strides[__pyx_v_i]) = 0; + + /* "View.MemoryView":1290 + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: + * if src.shape[i] == 1: # <<<<<<<<<<<<<< + * broadcasting = True + * src.strides[i] = 0 + */ + goto __pyx_L7; + } + + /* "View.MemoryView":1294 + * src.strides[i] = 0 + * else: + * _err_extents(i, dst.shape[i], src.shape[i]) # <<<<<<<<<<<<<< + * + * if src.suboffsets[i] >= 0: + */ + /*else*/ { + __pyx_t_6 = __pyx_memoryview_err_extents(__pyx_v_i, (__pyx_v_dst.shape[__pyx_v_i]), (__pyx_v_src.shape[__pyx_v_i])); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 1294, __pyx_L1_error) + } + __pyx_L7:; + + /* "View.MemoryView":1289 + * + * for i in range(ndim): + * if src.shape[i] != dst.shape[i]: # <<<<<<<<<<<<<< + * if src.shape[i] == 1: + * broadcasting = True + */ + } + + /* "View.MemoryView":1296 + * _err_extents(i, dst.shape[i], src.shape[i]) + * + * if src.suboffsets[i] >= 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + */ + __pyx_t_2 = ((__pyx_v_src.suboffsets[__pyx_v_i]) >= 0); + if (__pyx_t_2) { + + /* "View.MemoryView":1297 + * + * if src.suboffsets[i] >= 0: + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) # <<<<<<<<<<<<<< + * + * if slices_overlap(&src, &dst, ndim, itemsize): + */ + __pyx_t_6 = __pyx_memoryview_err_dim(PyExc_ValueError, __pyx_kp_s_Dimension_d_is_not_direct, __pyx_v_i); if (unlikely(__pyx_t_6 == ((int)-1))) __PYX_ERR(1, 1297, __pyx_L1_error) + + /* "View.MemoryView":1296 + * _err_extents(i, dst.shape[i], src.shape[i]) + * + * if src.suboffsets[i] >= 0: # <<<<<<<<<<<<<< + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + */ + } + } + + /* "View.MemoryView":1299 + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + * if slices_overlap(&src, &dst, ndim, itemsize): # <<<<<<<<<<<<<< + * + * if not slice_is_contig(src, order, ndim): + */ + __pyx_t_2 = __pyx_slices_overlap((&__pyx_v_src), (&__pyx_v_dst), __pyx_v_ndim, __pyx_v_itemsize); + if (__pyx_t_2) { + + /* "View.MemoryView":1301 + * if slices_overlap(&src, &dst, ndim, itemsize): + * + * if not slice_is_contig(src, order, ndim): # <<<<<<<<<<<<<< + * order = get_best_order(&dst, ndim) + * + */ + __pyx_t_2 = (!__pyx_memviewslice_is_contig(__pyx_v_src, __pyx_v_order, __pyx_v_ndim)); + if (__pyx_t_2) { + + /* "View.MemoryView":1302 + * + * if not slice_is_contig(src, order, ndim): + * order = get_best_order(&dst, ndim) # <<<<<<<<<<<<<< + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) + */ + __pyx_v_order = __pyx_get_best_slice_order((&__pyx_v_dst), __pyx_v_ndim); + + /* "View.MemoryView":1301 + * if slices_overlap(&src, &dst, ndim, itemsize): + * + * if not slice_is_contig(src, order, ndim): # <<<<<<<<<<<<<< + * order = get_best_order(&dst, ndim) + * + */ + } + + /* "View.MemoryView":1304 + * order = get_best_order(&dst, ndim) + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) # <<<<<<<<<<<<<< + * src = tmp + * + */ + __pyx_t_7 = __pyx_memoryview_copy_data_to_temp((&__pyx_v_src), (&__pyx_v_tmp), __pyx_v_order, __pyx_v_ndim); if (unlikely(__pyx_t_7 == ((void *)NULL))) __PYX_ERR(1, 1304, __pyx_L1_error) + __pyx_v_tmpdata = __pyx_t_7; + + /* "View.MemoryView":1305 + * + * tmpdata = copy_data_to_temp(&src, &tmp, order, ndim) + * src = tmp # <<<<<<<<<<<<<< + * + * if not broadcasting: + */ + __pyx_v_src = __pyx_v_tmp; + + /* "View.MemoryView":1299 + * _err_dim(PyExc_ValueError, "Dimension %d is not direct", i) + * + * if slices_overlap(&src, &dst, ndim, itemsize): # <<<<<<<<<<<<<< + * + * if not slice_is_contig(src, order, ndim): + */ + } + + /* "View.MemoryView":1307 + * src = tmp + * + * if not broadcasting: # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = (!__pyx_v_broadcasting); + if (__pyx_t_2) { + + /* "View.MemoryView":1310 + * + * + * if slice_is_contig(src, 'C', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + */ + __pyx_t_2 = __pyx_memviewslice_is_contig(__pyx_v_src, 'C', __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1311 + * + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) # <<<<<<<<<<<<<< + * elif slice_is_contig(src, 'F', ndim): + * direct_copy = slice_is_contig(dst, 'F', ndim) + */ + __pyx_v_direct_copy = __pyx_memviewslice_is_contig(__pyx_v_dst, 'C', __pyx_v_ndim); + + /* "View.MemoryView":1310 + * + * + * if slice_is_contig(src, 'C', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + */ + goto __pyx_L12; + } + + /* "View.MemoryView":1312 + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + */ + __pyx_t_2 = __pyx_memviewslice_is_contig(__pyx_v_src, 'F', __pyx_v_ndim); + if (__pyx_t_2) { + + /* "View.MemoryView":1313 + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): + * direct_copy = slice_is_contig(dst, 'F', ndim) # <<<<<<<<<<<<<< + * + * if direct_copy: + */ + __pyx_v_direct_copy = __pyx_memviewslice_is_contig(__pyx_v_dst, 'F', __pyx_v_ndim); + + /* "View.MemoryView":1312 + * if slice_is_contig(src, 'C', ndim): + * direct_copy = slice_is_contig(dst, 'C', ndim) + * elif slice_is_contig(src, 'F', ndim): # <<<<<<<<<<<<<< + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + */ + } + __pyx_L12:; + + /* "View.MemoryView":1315 + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + * if direct_copy: # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + if (__pyx_v_direct_copy) { + + /* "View.MemoryView":1317 + * if direct_copy: + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1318 + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) # <<<<<<<<<<<<<< + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) + */ + (void)(memcpy(__pyx_v_dst.data, __pyx_v_src.data, __pyx_memoryview_slice_get_size((&__pyx_v_src), __pyx_v_ndim))); + + /* "View.MemoryView":1319 + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * free(tmpdata) + * return 0 + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1320 + * memcpy(dst.data, src.data, slice_get_size(&src, ndim)) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) # <<<<<<<<<<<<<< + * return 0 + * + */ + free(__pyx_v_tmpdata); + + /* "View.MemoryView":1321 + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * free(tmpdata) + * return 0 # <<<<<<<<<<<<<< + * + * if order == 'F' == get_best_order(&dst, ndim): + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":1315 + * direct_copy = slice_is_contig(dst, 'F', ndim) + * + * if direct_copy: # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + } + + /* "View.MemoryView":1307 + * src = tmp + * + * if not broadcasting: # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":1323 + * return 0 + * + * if order == 'F' == get_best_order(&dst, ndim): # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_2 = (__pyx_v_order == 'F'); + if (__pyx_t_2) { + __pyx_t_2 = ('F' == __pyx_get_best_slice_order((&__pyx_v_dst), __pyx_v_ndim)); + } + if (__pyx_t_2) { + + /* "View.MemoryView":1326 + * + * + * transpose_memslice(&src) # <<<<<<<<<<<<<< + * transpose_memslice(&dst) + * + */ + __pyx_t_5 = __pyx_memslice_transpose((&__pyx_v_src)); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 1326, __pyx_L1_error) + + /* "View.MemoryView":1327 + * + * transpose_memslice(&src) + * transpose_memslice(&dst) # <<<<<<<<<<<<<< + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + */ + __pyx_t_5 = __pyx_memslice_transpose((&__pyx_v_dst)); if (unlikely(__pyx_t_5 == ((int)-1))) __PYX_ERR(1, 1327, __pyx_L1_error) + + /* "View.MemoryView":1323 + * return 0 + * + * if order == 'F' == get_best_order(&dst, ndim): # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":1329 + * transpose_memslice(&dst) + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * copy_strided_to_strided(&src, &dst, ndim, itemsize) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1330 + * + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * copy_strided_to_strided(&src, &dst, ndim, itemsize) # <<<<<<<<<<<<<< + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * + */ + copy_strided_to_strided((&__pyx_v_src), (&__pyx_v_dst), __pyx_v_ndim, __pyx_v_itemsize); + + /* "View.MemoryView":1331 + * refcount_copying(&dst, dtype_is_object, ndim, inc=False) + * copy_strided_to_strided(&src, &dst, ndim, itemsize) + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * + * free(tmpdata) + */ + __pyx_memoryview_refcount_copying((&__pyx_v_dst), __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1333 + * refcount_copying(&dst, dtype_is_object, ndim, inc=True) + * + * free(tmpdata) # <<<<<<<<<<<<<< + * return 0 + * + */ + free(__pyx_v_tmpdata); + + /* "View.MemoryView":1334 + * + * free(tmpdata) + * return 0 # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_broadcast_leading') + */ + __pyx_r = 0; + goto __pyx_L0; + + /* "View.MemoryView":1265 + * + * @cname('__pyx_memoryview_copy_contents') + * cdef int memoryview_copy_contents(__Pyx_memviewslice src, # <<<<<<<<<<<<<< + * __Pyx_memviewslice dst, + * int src_ndim, int dst_ndim, + */ + + /* function exit code */ + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_AddTraceback("View.MemoryView.memoryview_copy_contents", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + return __pyx_r; +} + +/* "View.MemoryView":1337 + * + * @cname('__pyx_memoryview_broadcast_leading') + * cdef void broadcast_leading(__Pyx_memviewslice *mslice, # <<<<<<<<<<<<<< + * int ndim, + * int ndim_other) noexcept nogil: + */ + +static void __pyx_memoryview_broadcast_leading(__Pyx_memviewslice *__pyx_v_mslice, int __pyx_v_ndim, int __pyx_v_ndim_other) { + int __pyx_v_i; + int __pyx_v_offset; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + + /* "View.MemoryView":1341 + * int ndim_other) noexcept nogil: + * cdef int i + * cdef int offset = ndim_other - ndim # <<<<<<<<<<<<<< + * + * for i in range(ndim - 1, -1, -1): + */ + __pyx_v_offset = (__pyx_v_ndim_other - __pyx_v_ndim); + + /* "View.MemoryView":1343 + * cdef int offset = ndim_other - ndim + * + * for i in range(ndim - 1, -1, -1): # <<<<<<<<<<<<<< + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] + */ + for (__pyx_t_1 = (__pyx_v_ndim - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_i = __pyx_t_1; + + /* "View.MemoryView":1344 + * + * for i in range(ndim - 1, -1, -1): + * mslice.shape[i + offset] = mslice.shape[i] # <<<<<<<<<<<<<< + * mslice.strides[i + offset] = mslice.strides[i] + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + */ + (__pyx_v_mslice->shape[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->shape[__pyx_v_i]); + + /* "View.MemoryView":1345 + * for i in range(ndim - 1, -1, -1): + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] # <<<<<<<<<<<<<< + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + * + */ + (__pyx_v_mslice->strides[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->strides[__pyx_v_i]); + + /* "View.MemoryView":1346 + * mslice.shape[i + offset] = mslice.shape[i] + * mslice.strides[i + offset] = mslice.strides[i] + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] # <<<<<<<<<<<<<< + * + * for i in range(offset): + */ + (__pyx_v_mslice->suboffsets[(__pyx_v_i + __pyx_v_offset)]) = (__pyx_v_mslice->suboffsets[__pyx_v_i]); + } + + /* "View.MemoryView":1348 + * mslice.suboffsets[i + offset] = mslice.suboffsets[i] + * + * for i in range(offset): # <<<<<<<<<<<<<< + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] + */ + __pyx_t_1 = __pyx_v_offset; + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "View.MemoryView":1349 + * + * for i in range(offset): + * mslice.shape[i] = 1 # <<<<<<<<<<<<<< + * mslice.strides[i] = mslice.strides[0] + * mslice.suboffsets[i] = -1 + */ + (__pyx_v_mslice->shape[__pyx_v_i]) = 1; + + /* "View.MemoryView":1350 + * for i in range(offset): + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] # <<<<<<<<<<<<<< + * mslice.suboffsets[i] = -1 + * + */ + (__pyx_v_mslice->strides[__pyx_v_i]) = (__pyx_v_mslice->strides[0]); + + /* "View.MemoryView":1351 + * mslice.shape[i] = 1 + * mslice.strides[i] = mslice.strides[0] + * mslice.suboffsets[i] = -1 # <<<<<<<<<<<<<< + * + * + */ + (__pyx_v_mslice->suboffsets[__pyx_v_i]) = -1L; + } + + /* "View.MemoryView":1337 + * + * @cname('__pyx_memoryview_broadcast_leading') + * cdef void broadcast_leading(__Pyx_memviewslice *mslice, # <<<<<<<<<<<<<< + * int ndim, + * int ndim_other) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1359 + * + * @cname('__pyx_memoryview_refcount_copying') + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: # <<<<<<<<<<<<<< + * + * if dtype_is_object: + */ + +static void __pyx_memoryview_refcount_copying(__Pyx_memviewslice *__pyx_v_dst, int __pyx_v_dtype_is_object, int __pyx_v_ndim, int __pyx_v_inc) { + + /* "View.MemoryView":1361 + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: + * + * if dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) + * + */ + if (__pyx_v_dtype_is_object) { + + /* "View.MemoryView":1362 + * + * if dtype_is_object: + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + */ + __pyx_memoryview_refcount_objects_in_slice_with_gil(__pyx_v_dst->data, __pyx_v_dst->shape, __pyx_v_dst->strides, __pyx_v_ndim, __pyx_v_inc); + + /* "View.MemoryView":1361 + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: + * + * if dtype_is_object: # <<<<<<<<<<<<<< + * refcount_objects_in_slice_with_gil(dst.data, dst.shape, dst.strides, ndim, inc) + * + */ + } + + /* "View.MemoryView":1359 + * + * @cname('__pyx_memoryview_refcount_copying') + * cdef void refcount_copying(__Pyx_memviewslice *dst, bint dtype_is_object, int ndim, bint inc) noexcept nogil: # <<<<<<<<<<<<<< + * + * if dtype_is_object: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1365 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + * cdef void refcount_objects_in_slice_with_gil(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + */ + +static void __pyx_memoryview_refcount_objects_in_slice_with_gil(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, int __pyx_v_inc) { + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + + /* "View.MemoryView":1368 + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + * refcount_objects_in_slice(data, shape, strides, ndim, inc) # <<<<<<<<<<<<<< + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + */ + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_data, __pyx_v_shape, __pyx_v_strides, __pyx_v_ndim, __pyx_v_inc); + + /* "View.MemoryView":1365 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice_with_gil') + * cdef void refcount_objects_in_slice_with_gil(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * bint inc) noexcept with gil: + */ + + /* function exit code */ + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif +} + +/* "View.MemoryView":1371 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + * cdef void refcount_objects_in_slice(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + */ + +static void __pyx_memoryview_refcount_objects_in_slice(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, int __pyx_v_inc) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + + /* "View.MemoryView":1374 + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] # <<<<<<<<<<<<<< + * + * for i in range(shape[0]): + */ + __pyx_v_stride = (__pyx_v_strides[0]); + + /* "View.MemoryView":1376 + * cdef Py_ssize_t stride = strides[0] + * + * for i in range(shape[0]): # <<<<<<<<<<<<<< + * if ndim == 1: + * if inc: + */ + __pyx_t_1 = (__pyx_v_shape[0]); + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_i = __pyx_t_3; + + /* "View.MemoryView":1377 + * + * for i in range(shape[0]): + * if ndim == 1: # <<<<<<<<<<<<<< + * if inc: + * Py_INCREF(( data)[0]) + */ + __pyx_t_4 = (__pyx_v_ndim == 1); + if (__pyx_t_4) { + + /* "View.MemoryView":1378 + * for i in range(shape[0]): + * if ndim == 1: + * if inc: # <<<<<<<<<<<<<< + * Py_INCREF(( data)[0]) + * else: + */ + if (__pyx_v_inc) { + + /* "View.MemoryView":1379 + * if ndim == 1: + * if inc: + * Py_INCREF(( data)[0]) # <<<<<<<<<<<<<< + * else: + * Py_DECREF(( data)[0]) + */ + Py_INCREF((((PyObject **)__pyx_v_data)[0])); + + /* "View.MemoryView":1378 + * for i in range(shape[0]): + * if ndim == 1: + * if inc: # <<<<<<<<<<<<<< + * Py_INCREF(( data)[0]) + * else: + */ + goto __pyx_L6; + } + + /* "View.MemoryView":1381 + * Py_INCREF(( data)[0]) + * else: + * Py_DECREF(( data)[0]) # <<<<<<<<<<<<<< + * else: + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) + */ + /*else*/ { + Py_DECREF((((PyObject **)__pyx_v_data)[0])); + } + __pyx_L6:; + + /* "View.MemoryView":1377 + * + * for i in range(shape[0]): + * if ndim == 1: # <<<<<<<<<<<<<< + * if inc: + * Py_INCREF(( data)[0]) + */ + goto __pyx_L5; + } + + /* "View.MemoryView":1383 + * Py_DECREF(( data)[0]) + * else: + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) # <<<<<<<<<<<<<< + * + * data += stride + */ + /*else*/ { + __pyx_memoryview_refcount_objects_in_slice(__pyx_v_data, (__pyx_v_shape + 1), (__pyx_v_strides + 1), (__pyx_v_ndim - 1), __pyx_v_inc); + } + __pyx_L5:; + + /* "View.MemoryView":1385 + * refcount_objects_in_slice(data, shape + 1, strides + 1, ndim - 1, inc) + * + * data += stride # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + + /* "View.MemoryView":1371 + * + * @cname('__pyx_memoryview_refcount_objects_in_slice') + * cdef void refcount_objects_in_slice(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, bint inc) noexcept: + * cdef Py_ssize_t i + */ + + /* function exit code */ +} + +/* "View.MemoryView":1391 + * + * @cname('__pyx_memoryview_slice_assign_scalar') + * cdef void slice_assign_scalar(__Pyx_memviewslice *dst, int ndim, # <<<<<<<<<<<<<< + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + */ + +static void __pyx_memoryview_slice_assign_scalar(__Pyx_memviewslice *__pyx_v_dst, int __pyx_v_ndim, size_t __pyx_v_itemsize, void *__pyx_v_item, int __pyx_v_dtype_is_object) { + + /* "View.MemoryView":1394 + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + * refcount_copying(dst, dtype_is_object, ndim, inc=False) # <<<<<<<<<<<<<< + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) + * refcount_copying(dst, dtype_is_object, ndim, inc=True) + */ + __pyx_memoryview_refcount_copying(__pyx_v_dst, __pyx_v_dtype_is_object, __pyx_v_ndim, 0); + + /* "View.MemoryView":1395 + * bint dtype_is_object) noexcept nogil: + * refcount_copying(dst, dtype_is_object, ndim, inc=False) + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) # <<<<<<<<<<<<<< + * refcount_copying(dst, dtype_is_object, ndim, inc=True) + * + */ + __pyx_memoryview__slice_assign_scalar(__pyx_v_dst->data, __pyx_v_dst->shape, __pyx_v_dst->strides, __pyx_v_ndim, __pyx_v_itemsize, __pyx_v_item); + + /* "View.MemoryView":1396 + * refcount_copying(dst, dtype_is_object, ndim, inc=False) + * _slice_assign_scalar(dst.data, dst.shape, dst.strides, ndim, itemsize, item) + * refcount_copying(dst, dtype_is_object, ndim, inc=True) # <<<<<<<<<<<<<< + * + * + */ + __pyx_memoryview_refcount_copying(__pyx_v_dst, __pyx_v_dtype_is_object, __pyx_v_ndim, 1); + + /* "View.MemoryView":1391 + * + * @cname('__pyx_memoryview_slice_assign_scalar') + * cdef void slice_assign_scalar(__Pyx_memviewslice *dst, int ndim, # <<<<<<<<<<<<<< + * size_t itemsize, void *item, + * bint dtype_is_object) noexcept nogil: + */ + + /* function exit code */ +} + +/* "View.MemoryView":1400 + * + * @cname('__pyx_memoryview__slice_assign_scalar') + * cdef void _slice_assign_scalar(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * size_t itemsize, void *item) noexcept nogil: + */ + +static void __pyx_memoryview__slice_assign_scalar(char *__pyx_v_data, Py_ssize_t *__pyx_v_shape, Py_ssize_t *__pyx_v_strides, int __pyx_v_ndim, size_t __pyx_v_itemsize, void *__pyx_v_item) { + CYTHON_UNUSED Py_ssize_t __pyx_v_i; + Py_ssize_t __pyx_v_stride; + Py_ssize_t __pyx_v_extent; + int __pyx_t_1; + Py_ssize_t __pyx_t_2; + Py_ssize_t __pyx_t_3; + Py_ssize_t __pyx_t_4; + + /* "View.MemoryView":1404 + * size_t itemsize, void *item) noexcept nogil: + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] # <<<<<<<<<<<<<< + * cdef Py_ssize_t extent = shape[0] + * + */ + __pyx_v_stride = (__pyx_v_strides[0]); + + /* "View.MemoryView":1405 + * cdef Py_ssize_t i + * cdef Py_ssize_t stride = strides[0] + * cdef Py_ssize_t extent = shape[0] # <<<<<<<<<<<<<< + * + * if ndim == 1: + */ + __pyx_v_extent = (__pyx_v_shape[0]); + + /* "View.MemoryView":1407 + * cdef Py_ssize_t extent = shape[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * for i in range(extent): + * memcpy(data, item, itemsize) + */ + __pyx_t_1 = (__pyx_v_ndim == 1); + if (__pyx_t_1) { + + /* "View.MemoryView":1408 + * + * if ndim == 1: + * for i in range(extent): # <<<<<<<<<<<<<< + * memcpy(data, item, itemsize) + * data += stride + */ + __pyx_t_2 = __pyx_v_extent; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1409 + * if ndim == 1: + * for i in range(extent): + * memcpy(data, item, itemsize) # <<<<<<<<<<<<<< + * data += stride + * else: + */ + (void)(memcpy(__pyx_v_data, __pyx_v_item, __pyx_v_itemsize)); + + /* "View.MemoryView":1410 + * for i in range(extent): + * memcpy(data, item, itemsize) + * data += stride # <<<<<<<<<<<<<< + * else: + * for i in range(extent): + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + + /* "View.MemoryView":1407 + * cdef Py_ssize_t extent = shape[0] + * + * if ndim == 1: # <<<<<<<<<<<<<< + * for i in range(extent): + * memcpy(data, item, itemsize) + */ + goto __pyx_L3; + } + + /* "View.MemoryView":1412 + * data += stride + * else: + * for i in range(extent): # <<<<<<<<<<<<<< + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) + * data += stride + */ + /*else*/ { + __pyx_t_2 = __pyx_v_extent; + __pyx_t_3 = __pyx_t_2; + for (__pyx_t_4 = 0; __pyx_t_4 < __pyx_t_3; __pyx_t_4+=1) { + __pyx_v_i = __pyx_t_4; + + /* "View.MemoryView":1413 + * else: + * for i in range(extent): + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) # <<<<<<<<<<<<<< + * data += stride + * + */ + __pyx_memoryview__slice_assign_scalar(__pyx_v_data, (__pyx_v_shape + 1), (__pyx_v_strides + 1), (__pyx_v_ndim - 1), __pyx_v_itemsize, __pyx_v_item); + + /* "View.MemoryView":1414 + * for i in range(extent): + * _slice_assign_scalar(data, shape + 1, strides + 1, ndim - 1, itemsize, item) + * data += stride # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_data = (__pyx_v_data + __pyx_v_stride); + } + } + __pyx_L3:; + + /* "View.MemoryView":1400 + * + * @cname('__pyx_memoryview__slice_assign_scalar') + * cdef void _slice_assign_scalar(char *data, Py_ssize_t *shape, # <<<<<<<<<<<<<< + * Py_ssize_t *strides, int ndim, + * size_t itemsize, void *item) noexcept nogil: + */ + + /* function exit code */ +} + +/* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + +/* Python wrapper */ +static PyObject *__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_15View_dot_MemoryView_1__pyx_unpickle_Enum = {"__pyx_unpickle_Enum", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_15View_dot_MemoryView_1__pyx_unpickle_Enum(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + PyObject *__pyx_v___pyx_type = 0; + long __pyx_v___pyx_checksum; + PyObject *__pyx_v___pyx_state = 0; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[3] = {0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_pyx_type,&__pyx_n_s_pyx_checksum,&__pyx_n_s_pyx_state,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_type)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_checksum)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, 1); __PYX_ERR(1, 1, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_pyx_state)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, 2); __PYX_ERR(1, 1, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "__pyx_unpickle_Enum") < 0)) __PYX_ERR(1, 1, __pyx_L3_error) + } + } else if (unlikely(__pyx_nargs != 3)) { + goto __pyx_L5_argtuple_error; + } else { + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + } + __pyx_v___pyx_type = values[0]; + __pyx_v___pyx_checksum = __Pyx_PyInt_As_long(values[1]); if (unlikely((__pyx_v___pyx_checksum == (long)-1) && PyErr_Occurred())) __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_v___pyx_state = values[2]; + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("__pyx_unpickle_Enum", 1, 3, 3, __pyx_nargs); __PYX_ERR(1, 1, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(__pyx_self, __pyx_v___pyx_type, __pyx_v___pyx_checksum, __pyx_v___pyx_state); + + /* function exit code */ + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_15View_dot_MemoryView___pyx_unpickle_Enum(CYTHON_UNUSED PyObject *__pyx_self, PyObject *__pyx_v___pyx_type, long __pyx_v___pyx_checksum, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_v___pyx_PickleError = 0; + PyObject *__pyx_v___pyx_result = 0; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + unsigned int __pyx_t_5; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum", 1); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + __pyx_t_1 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_t_2 = (__Pyx_PySequence_ContainsTF(__pyx_t_1, __pyx_tuple__8, Py_NE)); if (unlikely((__pyx_t_2 < 0))) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + if (__pyx_t_2) { + + /* "(tree fragment)":5 + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): + * from pickle import PickleError as __pyx_PickleError # <<<<<<<<<<<<<< + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + */ + __pyx_t_1 = PyList_New(1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_n_s_PickleError); + __Pyx_GIVEREF(__pyx_n_s_PickleError); + if (__Pyx_PyList_SET_ITEM(__pyx_t_1, 0, __pyx_n_s_PickleError)) __PYX_ERR(1, 5, __pyx_L1_error); + __pyx_t_3 = __Pyx_Import(__pyx_n_s_pickle, __pyx_t_1, 0); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = __Pyx_ImportFrom(__pyx_t_3, __pyx_n_s_PickleError); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 5, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_INCREF(__pyx_t_1); + __pyx_v___pyx_PickleError = __pyx_t_1; + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + + /* "(tree fragment)":6 + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum # <<<<<<<<<<<<<< + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: + */ + __pyx_t_3 = __Pyx_PyInt_From_long(__pyx_v___pyx_checksum); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_1 = __Pyx_PyString_Format(__pyx_kp_s_Incompatible_checksums_0x_x_vs_0, __pyx_t_3); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 6, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_Raise(__pyx_v___pyx_PickleError, __pyx_t_1, 0, 0); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __PYX_ERR(1, 6, __pyx_L1_error) + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + } + + /* "(tree fragment)":7 + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) # <<<<<<<<<<<<<< + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + */ + __pyx_t_3 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_MemviewEnum_type), __pyx_n_s_new); if (unlikely(!__pyx_t_3)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_3); + __pyx_t_4 = NULL; + __pyx_t_5 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_3))) { + __pyx_t_4 = PyMethod_GET_SELF(__pyx_t_3); + if (likely(__pyx_t_4)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_3); + __Pyx_INCREF(__pyx_t_4); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_3, function); + __pyx_t_5 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_4, __pyx_v___pyx_type}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_3, __pyx_callargs+1-__pyx_t_5, 1+__pyx_t_5); + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 7, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_3); __pyx_t_3 = 0; + } + __pyx_v___pyx_result = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + __pyx_t_2 = (__pyx_v___pyx_state != Py_None); + if (__pyx_t_2) { + + /* "(tree fragment)":9 + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) # <<<<<<<<<<<<<< + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + */ + if (!(likely(PyTuple_CheckExact(__pyx_v___pyx_state))||((__pyx_v___pyx_state) == Py_None) || __Pyx_RaiseUnexpectedTypeError("tuple", __pyx_v___pyx_state))) __PYX_ERR(1, 9, __pyx_L1_error) + __pyx_t_1 = __pyx_unpickle_Enum__set_state(((struct __pyx_MemviewEnum_obj *)__pyx_v___pyx_result), ((PyObject*)__pyx_v___pyx_state)); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":8 + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + * __pyx_result = Enum.__new__(__pyx_type) + * if __pyx_state is not None: # <<<<<<<<<<<<<< + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + */ + } + + /* "(tree fragment)":10 + * if __pyx_state is not None: + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result # <<<<<<<<<<<<<< + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_v___pyx_result); + __pyx_r = __pyx_v___pyx_result; + goto __pyx_L0; + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_3); + __Pyx_XDECREF(__pyx_t_4); + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XDECREF(__pyx_v___pyx_PickleError); + __Pyx_XDECREF(__pyx_v___pyx_result); + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "(tree fragment)":11 + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + */ + +static PyObject *__pyx_unpickle_Enum__set_state(struct __pyx_MemviewEnum_obj *__pyx_v___pyx_result, PyObject *__pyx_v___pyx_state) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_t_2; + Py_ssize_t __pyx_t_3; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + unsigned int __pyx_t_8; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__pyx_unpickle_Enum__set_state", 1); + + /* "(tree fragment)":12 + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] # <<<<<<<<<<<<<< + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 12, __pyx_L1_error) + } + __pyx_t_1 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 0, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 12, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_GIVEREF(__pyx_t_1); + __Pyx_GOTREF(__pyx_v___pyx_result->name); + __Pyx_DECREF(__pyx_v___pyx_result->name); + __pyx_v___pyx_result->name = __pyx_t_1; + __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "object of type 'NoneType' has no len()"); + __PYX_ERR(1, 13, __pyx_L1_error) + } + __pyx_t_3 = __Pyx_PyTuple_GET_SIZE(__pyx_v___pyx_state); if (unlikely(__pyx_t_3 == ((Py_ssize_t)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_4 = (__pyx_t_3 > 1); + if (__pyx_t_4) { + } else { + __pyx_t_2 = __pyx_t_4; + goto __pyx_L4_bool_binop_done; + } + __pyx_t_4 = __Pyx_HasAttr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(1, 13, __pyx_L1_error) + __pyx_t_2 = __pyx_t_4; + __pyx_L4_bool_binop_done:; + if (__pyx_t_2) { + + /* "(tree fragment)":14 + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + * __pyx_result.__dict__.update(__pyx_state[1]) # <<<<<<<<<<<<<< + */ + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(((PyObject *)__pyx_v___pyx_result), __pyx_n_s_dict); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_6 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_update); if (unlikely(!__pyx_t_6)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_6); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(__pyx_v___pyx_state == Py_None)) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not subscriptable"); + __PYX_ERR(1, 14, __pyx_L1_error) + } + __pyx_t_5 = __Pyx_GetItemInt_Tuple(__pyx_v___pyx_state, 1, long, 1, __Pyx_PyInt_From_long, 0, 0, 1); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_5); + __pyx_t_7 = NULL; + __pyx_t_8 = 0; + #if CYTHON_UNPACK_METHODS + if (likely(PyMethod_Check(__pyx_t_6))) { + __pyx_t_7 = PyMethod_GET_SELF(__pyx_t_6); + if (likely(__pyx_t_7)) { + PyObject* function = PyMethod_GET_FUNCTION(__pyx_t_6); + __Pyx_INCREF(__pyx_t_7); + __Pyx_INCREF(function); + __Pyx_DECREF_SET(__pyx_t_6, function); + __pyx_t_8 = 1; + } + } + #endif + { + PyObject *__pyx_callargs[2] = {__pyx_t_7, __pyx_t_5}; + __pyx_t_1 = __Pyx_PyObject_FastCall(__pyx_t_6, __pyx_callargs+1-__pyx_t_8, 1+__pyx_t_8); + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 14, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0; + } + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + + /* "(tree fragment)":13 + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): # <<<<<<<<<<<<<< + * __pyx_result.__dict__.update(__pyx_state[1]) + */ + } + + /* "(tree fragment)":11 + * __pyx_unpickle_Enum__set_state( __pyx_result, __pyx_state) + * return __pyx_result + * cdef __pyx_unpickle_Enum__set_state(Enum __pyx_result, tuple __pyx_state): # <<<<<<<<<<<<<< + * __pyx_result.name = __pyx_state[0] + * if len(__pyx_state) > 1 and hasattr(__pyx_result, '__dict__'): + */ + + /* function exit code */ + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_AddTraceback("View.MemoryView.__pyx_unpickle_Enum__set_state", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":286 + * + * @property + * cdef inline npy_intp itemsize(self) noexcept nogil: # <<<<<<<<<<<<<< + * return PyDataType_ELSIZE(self) + * + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_5dtype_8itemsize_itemsize(PyArray_Descr *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":287 + * @property + * cdef inline npy_intp itemsize(self) noexcept nogil: + * return PyDataType_ELSIZE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyDataType_ELSIZE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":286 + * + * @property + * cdef inline npy_intp itemsize(self) noexcept nogil: # <<<<<<<<<<<<<< + * return PyDataType_ELSIZE(self) + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":290 + * + * @property + * cdef inline npy_intp alignment(self) noexcept nogil: # <<<<<<<<<<<<<< + * return PyDataType_ALIGNMENT(self) + * + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_5dtype_9alignment_alignment(PyArray_Descr *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":291 + * @property + * cdef inline npy_intp alignment(self) noexcept nogil: + * return PyDataType_ALIGNMENT(self) # <<<<<<<<<<<<<< + * + * # Use fields/names with care as they may be NULL. You must check + */ + __pyx_r = PyDataType_ALIGNMENT(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":290 + * + * @property + * cdef inline npy_intp alignment(self) noexcept nogil: # <<<<<<<<<<<<<< + * return PyDataType_ALIGNMENT(self) + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":296 + * # for this using PyDataType_HASFIELDS. + * @property + * cdef inline object fields(self): # <<<<<<<<<<<<<< + * return PyDataType_FIELDS(self) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_5dtype_6fields_fields(PyArray_Descr *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1; + __Pyx_RefNannySetupContext("fields", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":297 + * @property + * cdef inline object fields(self): + * return PyDataType_FIELDS(self) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyDataType_FIELDS(__pyx_v_self); + __Pyx_INCREF(((PyObject *)__pyx_t_1)); + __pyx_r = ((PyObject *)__pyx_t_1); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":296 + * # for this using PyDataType_HASFIELDS. + * @property + * cdef inline object fields(self): # <<<<<<<<<<<<<< + * return PyDataType_FIELDS(self) + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":300 + * + * @property + * cdef inline tuple names(self): # <<<<<<<<<<<<<< + * return PyDataType_NAMES(self) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_5dtype_5names_names(PyArray_Descr *__pyx_v_self) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1; + __Pyx_RefNannySetupContext("names", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":301 + * @property + * cdef inline tuple names(self): + * return PyDataType_NAMES(self) # <<<<<<<<<<<<<< + * + * # Use PyDataType_HASSUBARRAY to test whether this field is + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyDataType_NAMES(__pyx_v_self); + __Pyx_INCREF(((PyObject*)__pyx_t_1)); + __pyx_r = ((PyObject*)__pyx_t_1); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":300 + * + * @property + * cdef inline tuple names(self): # <<<<<<<<<<<<<< + * return PyDataType_NAMES(self) + * + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":307 + * # this field via the inline helper method PyDataType_SHAPE. + * @property + * cdef inline PyArray_ArrayDescr* subarray(self) noexcept nogil: # <<<<<<<<<<<<<< + * return PyDataType_SUBARRAY(self) + * + */ + +static CYTHON_INLINE PyArray_ArrayDescr *__pyx_f_5numpy_5dtype_8subarray_subarray(PyArray_Descr *__pyx_v_self) { + PyArray_ArrayDescr *__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":308 + * @property + * cdef inline PyArray_ArrayDescr* subarray(self) noexcept nogil: + * return PyDataType_SUBARRAY(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyDataType_SUBARRAY(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":307 + * # this field via the inline helper method PyDataType_SHAPE. + * @property + * cdef inline PyArray_ArrayDescr* subarray(self) noexcept nogil: # <<<<<<<<<<<<<< + * return PyDataType_SUBARRAY(self) + * + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":311 + * + * @property + * cdef inline npy_uint64 flags(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The data types flags.""" + * return PyDataType_FLAGS(self) + */ + +static CYTHON_INLINE npy_uint64 __pyx_f_5numpy_5dtype_5flags_flags(PyArray_Descr *__pyx_v_self) { + npy_uint64 __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":313 + * cdef inline npy_uint64 flags(self) noexcept nogil: + * """The data types flags.""" + * return PyDataType_FLAGS(self) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyDataType_FLAGS(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":311 + * + * @property + * cdef inline npy_uint64 flags(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The data types flags.""" + * return PyDataType_FLAGS(self) + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":323 + * + * @property + * cdef inline int numiter(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The number of arrays that need to be broadcast to the same shape.""" + * return PyArray_MultiIter_NUMITER(self) + */ + +static CYTHON_INLINE int __pyx_f_5numpy_9broadcast_7numiter_numiter(PyArrayMultiIterObject *__pyx_v_self) { + int __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":325 + * cdef inline int numiter(self) noexcept nogil: + * """The number of arrays that need to be broadcast to the same shape.""" + * return PyArray_MultiIter_NUMITER(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_MultiIter_NUMITER(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":323 + * + * @property + * cdef inline int numiter(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The number of arrays that need to be broadcast to the same shape.""" + * return PyArray_MultiIter_NUMITER(self) + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":328 + * + * @property + * cdef inline npy_intp size(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The total broadcasted size.""" + * return PyArray_MultiIter_SIZE(self) + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_9broadcast_4size_size(PyArrayMultiIterObject *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":330 + * cdef inline npy_intp size(self) noexcept nogil: + * """The total broadcasted size.""" + * return PyArray_MultiIter_SIZE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_MultiIter_SIZE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":328 + * + * @property + * cdef inline npy_intp size(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The total broadcasted size.""" + * return PyArray_MultiIter_SIZE(self) + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":333 + * + * @property + * cdef inline npy_intp index(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The current (1-d) index into the broadcasted result.""" + * return PyArray_MultiIter_INDEX(self) + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_9broadcast_5index_index(PyArrayMultiIterObject *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":335 + * cdef inline npy_intp index(self) noexcept nogil: + * """The current (1-d) index into the broadcasted result.""" + * return PyArray_MultiIter_INDEX(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_MultiIter_INDEX(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":333 + * + * @property + * cdef inline npy_intp index(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The current (1-d) index into the broadcasted result.""" + * return PyArray_MultiIter_INDEX(self) + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":338 + * + * @property + * cdef inline int nd(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The number of dimensions in the broadcasted result.""" + * return PyArray_MultiIter_NDIM(self) + */ + +static CYTHON_INLINE int __pyx_f_5numpy_9broadcast_2nd_nd(PyArrayMultiIterObject *__pyx_v_self) { + int __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":340 + * cdef inline int nd(self) noexcept nogil: + * """The number of dimensions in the broadcasted result.""" + * return PyArray_MultiIter_NDIM(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_MultiIter_NDIM(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":338 + * + * @property + * cdef inline int nd(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The number of dimensions in the broadcasted result.""" + * return PyArray_MultiIter_NDIM(self) + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":343 + * + * @property + * cdef inline npy_intp* dimensions(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The shape of the broadcasted result.""" + * return PyArray_MultiIter_DIMS(self) + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_9broadcast_10dimensions_dimensions(PyArrayMultiIterObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":345 + * cdef inline npy_intp* dimensions(self) noexcept nogil: + * """The shape of the broadcasted result.""" + * return PyArray_MultiIter_DIMS(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_MultiIter_DIMS(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":343 + * + * @property + * cdef inline npy_intp* dimensions(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The shape of the broadcasted result.""" + * return PyArray_MultiIter_DIMS(self) + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":348 + * + * @property + * cdef inline void** iters(self) noexcept nogil: # <<<<<<<<<<<<<< + * """An array of iterator objects that holds the iterators for the arrays to be broadcast together. + * On return, the iterators are adjusted for broadcasting.""" + */ + +static CYTHON_INLINE void **__pyx_f_5numpy_9broadcast_5iters_iters(PyArrayMultiIterObject *__pyx_v_self) { + void **__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":351 + * """An array of iterator objects that holds the iterators for the arrays to be broadcast together. + * On return, the iterators are adjusted for broadcasting.""" + * return PyArray_MultiIter_ITERS(self) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyArray_MultiIter_ITERS(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":348 + * + * @property + * cdef inline void** iters(self) noexcept nogil: # <<<<<<<<<<<<<< + * """An array of iterator objects that holds the iterators for the arrays to be broadcast together. + * On return, the iterators are adjusted for broadcasting.""" + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":366 + * + * @property + * cdef inline PyObject* base(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns a borrowed reference to the object owning the data/memory. + * """ + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_7ndarray_4base_base(PyArrayObject *__pyx_v_self) { + PyObject *__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":369 + * """Returns a borrowed reference to the object owning the data/memory. + * """ + * return PyArray_BASE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_BASE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":366 + * + * @property + * cdef inline PyObject* base(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns a borrowed reference to the object owning the data/memory. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":372 + * + * @property + * cdef inline dtype descr(self): # <<<<<<<<<<<<<< + * """Returns an owned reference to the dtype of the array. + * """ + */ + +static CYTHON_INLINE PyArray_Descr *__pyx_f_5numpy_7ndarray_5descr_descr(PyArrayObject *__pyx_v_self) { + PyArray_Descr *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyArray_Descr *__pyx_t_1; + __Pyx_RefNannySetupContext("descr", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":375 + * """Returns an owned reference to the dtype of the array. + * """ + * return PyArray_DESCR(self) # <<<<<<<<<<<<<< + * + * @property + */ + __Pyx_XDECREF((PyObject *)__pyx_r); + __pyx_t_1 = PyArray_DESCR(__pyx_v_self); + __Pyx_INCREF((PyObject *)((PyArray_Descr *)__pyx_t_1)); + __pyx_r = ((PyArray_Descr *)__pyx_t_1); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":372 + * + * @property + * cdef inline dtype descr(self): # <<<<<<<<<<<<<< + * """Returns an owned reference to the dtype of the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF((PyObject *)__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":378 + * + * @property + * cdef inline int ndim(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns the number of dimensions in the array. + * """ + */ + +static CYTHON_INLINE int __pyx_f_5numpy_7ndarray_4ndim_ndim(PyArrayObject *__pyx_v_self) { + int __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":381 + * """Returns the number of dimensions in the array. + * """ + * return PyArray_NDIM(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_NDIM(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":378 + * + * @property + * cdef inline int ndim(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns the number of dimensions in the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":384 + * + * @property + * cdef inline npy_intp *shape(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the dimensions/shape of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_5shape_shape(PyArrayObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":389 + * Can return NULL for 0-dimensional arrays. + * """ + * return PyArray_DIMS(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_DIMS(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":384 + * + * @property + * cdef inline npy_intp *shape(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the dimensions/shape of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":392 + * + * @property + * cdef inline npy_intp *strides(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the strides of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + +static CYTHON_INLINE npy_intp *__pyx_f_5numpy_7ndarray_7strides_strides(PyArrayObject *__pyx_v_self) { + npy_intp *__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":396 + * The number of elements matches the number of dimensions of the array (ndim). + * """ + * return PyArray_STRIDES(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_STRIDES(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":392 + * + * @property + * cdef inline npy_intp *strides(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns a pointer to the strides of the array. + * The number of elements matches the number of dimensions of the array (ndim). + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":399 + * + * @property + * cdef inline npy_intp size(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns the total size (in number of elements) of the array. + * """ + */ + +static CYTHON_INLINE npy_intp __pyx_f_5numpy_7ndarray_4size_size(PyArrayObject *__pyx_v_self) { + npy_intp __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":402 + * """Returns the total size (in number of elements) of the array. + * """ + * return PyArray_SIZE(self) # <<<<<<<<<<<<<< + * + * @property + */ + __pyx_r = PyArray_SIZE(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":399 + * + * @property + * cdef inline npy_intp size(self) noexcept nogil: # <<<<<<<<<<<<<< + * """Returns the total size (in number of elements) of the array. + * """ + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":405 + * + * @property + * cdef inline char* data(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The pointer to the data buffer as a char*. + * This is provided for legacy reasons to avoid direct struct field access. + */ + +static CYTHON_INLINE char *__pyx_f_5numpy_7ndarray_4data_data(PyArrayObject *__pyx_v_self) { + char *__pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":411 + * of `PyArray_DATA()` instead, which returns a 'void*'. + * """ + * return PyArray_BYTES(self) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyArray_BYTES(__pyx_v_self); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":405 + * + * @property + * cdef inline char* data(self) noexcept nogil: # <<<<<<<<<<<<<< + * """The pointer to the data buffer as a char*. + * This is provided for legacy reasons to avoid direct struct field access. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":807 + * ctypedef long double complex clongdouble_t + * + * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(1, a) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew1(PyObject *__pyx_v_a) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew1", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":808 + * + * cdef inline object PyArray_MultiIterNew1(a): + * return PyArray_MultiIterNew(1, a) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew2(a, b): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(1, ((void *)__pyx_v_a)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 808, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":807 + * ctypedef long double complex clongdouble_t + * + * cdef inline object PyArray_MultiIterNew1(a): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(1, a) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew1", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":810 + * return PyArray_MultiIterNew(1, a) + * + * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(2, a, b) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew2(PyObject *__pyx_v_a, PyObject *__pyx_v_b) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew2", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":811 + * + * cdef inline object PyArray_MultiIterNew2(a, b): + * return PyArray_MultiIterNew(2, a, b) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(2, ((void *)__pyx_v_a), ((void *)__pyx_v_b)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 811, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":810 + * return PyArray_MultiIterNew(1, a) + * + * cdef inline object PyArray_MultiIterNew2(a, b): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(2, a, b) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew2", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":813 + * return PyArray_MultiIterNew(2, a, b) + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(3, a, b, c) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew3(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew3", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":814 + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): + * return PyArray_MultiIterNew(3, a, b, c) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(3, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 814, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":813 + * return PyArray_MultiIterNew(2, a, b) + * + * cdef inline object PyArray_MultiIterNew3(a, b, c): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(3, a, b, c) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew3", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":816 + * return PyArray_MultiIterNew(3, a, b, c) + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(4, a, b, c, d) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew4(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew4", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":817 + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): + * return PyArray_MultiIterNew(4, a, b, c, d) # <<<<<<<<<<<<<< + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(4, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 817, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":816 + * return PyArray_MultiIterNew(3, a, b, c) + * + * cdef inline object PyArray_MultiIterNew4(a, b, c, d): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(4, a, b, c, d) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew4", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":819 + * return PyArray_MultiIterNew(4, a, b, c, d) + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyArray_MultiIterNew5(PyObject *__pyx_v_a, PyObject *__pyx_v_b, PyObject *__pyx_v_c, PyObject *__pyx_v_d, PyObject *__pyx_v_e) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("PyArray_MultiIterNew5", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":820 + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): + * return PyArray_MultiIterNew(5, a, b, c, d, e) # <<<<<<<<<<<<<< + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_1 = PyArray_MultiIterNew(5, ((void *)__pyx_v_a), ((void *)__pyx_v_b), ((void *)__pyx_v_c), ((void *)__pyx_v_d), ((void *)__pyx_v_e)); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 820, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_r = __pyx_t_1; + __pyx_t_1 = 0; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":819 + * return PyArray_MultiIterNew(4, a, b, c, d) + * + * cdef inline object PyArray_MultiIterNew5(a, b, c, d, e): # <<<<<<<<<<<<<< + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + */ + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_AddTraceback("numpy.PyArray_MultiIterNew5", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = 0; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":822 + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<< + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_PyDataType_SHAPE(PyArray_Descr *__pyx_v_d) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + PyObject *__pyx_t_2; + __Pyx_RefNannySetupContext("PyDataType_SHAPE", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":823 + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<< + * return d.subarray.shape + * else: + */ + __pyx_t_1 = PyDataType_HASSUBARRAY(__pyx_v_d); + if (__pyx_t_1) { + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":824 + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape # <<<<<<<<<<<<<< + * else: + * return () + */ + __Pyx_XDECREF(__pyx_r); + __pyx_t_2 = __pyx_f_5numpy_5dtype_8subarray_subarray(__pyx_v_d)->shape; + __Pyx_INCREF(((PyObject*)__pyx_t_2)); + __pyx_r = ((PyObject*)__pyx_t_2); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":823 + * + * cdef inline tuple PyDataType_SHAPE(dtype d): + * if PyDataType_HASSUBARRAY(d): # <<<<<<<<<<<<<< + * return d.subarray.shape + * else: + */ + } + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":826 + * return d.subarray.shape + * else: + * return () # <<<<<<<<<<<<<< + * + * + */ + /*else*/ { + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(__pyx_empty_tuple); + __pyx_r = __pyx_empty_tuple; + goto __pyx_L0; + } + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":822 + * return PyArray_MultiIterNew(5, a, b, c, d, e) + * + * cdef inline tuple PyDataType_SHAPE(dtype d): # <<<<<<<<<<<<<< + * if PyDataType_HASSUBARRAY(d): + * return d.subarray.shape + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1010 + * int _import_umath() except -1 + * + * cdef inline void set_array_base(ndarray arr, object base) except *: # <<<<<<<<<<<<<< + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) + */ + +static CYTHON_INLINE void __pyx_f_5numpy_set_array_base(PyArrayObject *__pyx_v_arr, PyObject *__pyx_v_base) { + int __pyx_t_1; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1011 + * + * cdef inline void set_array_base(ndarray arr, object base) except *: + * Py_INCREF(base) # important to do this before stealing the reference below! # <<<<<<<<<<<<<< + * PyArray_SetBaseObject(arr, base) + * + */ + Py_INCREF(__pyx_v_base); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1012 + * cdef inline void set_array_base(ndarray arr, object base) except *: + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) # <<<<<<<<<<<<<< + * + * cdef inline object get_array_base(ndarray arr): + */ + __pyx_t_1 = PyArray_SetBaseObject(__pyx_v_arr, __pyx_v_base); if (unlikely(__pyx_t_1 == ((int)-1))) __PYX_ERR(2, 1012, __pyx_L1_error) + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1010 + * int _import_umath() except -1 + * + * cdef inline void set_array_base(ndarray arr, object base) except *: # <<<<<<<<<<<<<< + * Py_INCREF(base) # important to do this before stealing the reference below! + * PyArray_SetBaseObject(arr, base) + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_AddTraceback("numpy.set_array_base", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_L0:; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1014 + * PyArray_SetBaseObject(arr, base) + * + * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<< + * base = PyArray_BASE(arr) + * if base is NULL: + */ + +static CYTHON_INLINE PyObject *__pyx_f_5numpy_get_array_base(PyArrayObject *__pyx_v_arr) { + PyObject *__pyx_v_base; + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + __Pyx_RefNannySetupContext("get_array_base", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1015 + * + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) # <<<<<<<<<<<<<< + * if base is NULL: + * return None + */ + __pyx_v_base = PyArray_BASE(__pyx_v_arr); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1016 + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) + * if base is NULL: # <<<<<<<<<<<<<< + * return None + * return base + */ + __pyx_t_1 = (__pyx_v_base == NULL); + if (__pyx_t_1) { + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1017 + * base = PyArray_BASE(arr) + * if base is NULL: + * return None # <<<<<<<<<<<<<< + * return base + * + */ + __Pyx_XDECREF(__pyx_r); + __pyx_r = Py_None; __Pyx_INCREF(Py_None); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1016 + * cdef inline object get_array_base(ndarray arr): + * base = PyArray_BASE(arr) + * if base is NULL: # <<<<<<<<<<<<<< + * return None + * return base + */ + } + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1018 + * if base is NULL: + * return None + * return base # <<<<<<<<<<<<<< + * + * # Versions of the import_* functions which are more suitable for + */ + __Pyx_XDECREF(__pyx_r); + __Pyx_INCREF(((PyObject *)__pyx_v_base)); + __pyx_r = ((PyObject *)__pyx_v_base); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1014 + * PyArray_SetBaseObject(arr, base) + * + * cdef inline object get_array_base(ndarray arr): # <<<<<<<<<<<<<< + * base = PyArray_BASE(arr) + * if base is NULL: + */ + + /* function exit code */ + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1022 + * # Versions of the import_* functions which are more suitable for + * # Cython code. + * cdef inline int import_array() except -1: # <<<<<<<<<<<<<< + * try: + * __pyx_import_array() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_array(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_array", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1023 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1024 + * cdef inline int import_array() except -1: + * try: + * __pyx_import_array() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy._core.multiarray failed to import") + */ + __pyx_t_4 = _import_array(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 1024, __pyx_L3_error) + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1023 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1025 + * try: + * __pyx_import_array() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy._core.multiarray failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 1025, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1026 + * __pyx_import_array() + * except Exception: + * raise ImportError("numpy._core.multiarray failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_umath() except -1: + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__9, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 1026, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 1026, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1023 + * # Cython code. + * cdef inline int import_array() except -1: + * try: # <<<<<<<<<<<<<< + * __pyx_import_array() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1022 + * # Versions of the import_* functions which are more suitable for + * # Cython code. + * cdef inline int import_array() except -1: # <<<<<<<<<<<<<< + * try: + * __pyx_import_array() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_array", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1028 + * raise ImportError("numpy._core.multiarray failed to import") + * + * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_umath(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_umath", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1030 + * cdef inline int import_umath() except -1: + * try: + * _import_umath() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy._core.umath failed to import") + */ + __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 1030, __pyx_L3_error) + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1031 + * try: + * _import_umath() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy._core.umath failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 1031, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1032 + * _import_umath() + * except Exception: + * raise ImportError("numpy._core.umath failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_ufunc() except -1: + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 1032, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 1032, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1029 + * + * cdef inline int import_umath() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1028 + * raise ImportError("numpy._core.multiarray failed to import") + * + * cdef inline int import_umath() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_umath", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1034 + * raise ImportError("numpy._core.umath failed to import") + * + * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + +static CYTHON_INLINE int __pyx_f_5numpy_import_ufunc(void) { + int __pyx_r; + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + int __pyx_t_4; + PyObject *__pyx_t_5 = NULL; + PyObject *__pyx_t_6 = NULL; + PyObject *__pyx_t_7 = NULL; + PyObject *__pyx_t_8 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("import_ufunc", 1); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1035 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1036 + * cdef inline int import_ufunc() except -1: + * try: + * _import_umath() # <<<<<<<<<<<<<< + * except Exception: + * raise ImportError("numpy._core.umath failed to import") + */ + __pyx_t_4 = _import_umath(); if (unlikely(__pyx_t_4 == ((int)-1))) __PYX_ERR(2, 1036, __pyx_L3_error) + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1035 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L8_try_end; + __pyx_L3_error:; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1037 + * try: + * _import_umath() + * except Exception: # <<<<<<<<<<<<<< + * raise ImportError("numpy._core.umath failed to import") + * + */ + __pyx_t_4 = __Pyx_PyErr_ExceptionMatches(((PyObject *)(&((PyTypeObject*)PyExc_Exception)[0]))); + if (__pyx_t_4) { + __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_6, &__pyx_t_7) < 0) __PYX_ERR(2, 1037, __pyx_L5_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_6); + __Pyx_XGOTREF(__pyx_t_7); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1038 + * _import_umath() + * except Exception: + * raise ImportError("numpy._core.umath failed to import") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_8 = __Pyx_PyObject_Call(__pyx_builtin_ImportError, __pyx_tuple__10, NULL); if (unlikely(!__pyx_t_8)) __PYX_ERR(2, 1038, __pyx_L5_except_error) + __Pyx_GOTREF(__pyx_t_8); + __Pyx_Raise(__pyx_t_8, 0, 0, 0); + __Pyx_DECREF(__pyx_t_8); __pyx_t_8 = 0; + __PYX_ERR(2, 1038, __pyx_L5_except_error) + } + goto __pyx_L5_except_error; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1035 + * + * cdef inline int import_ufunc() except -1: + * try: # <<<<<<<<<<<<<< + * _import_umath() + * except Exception: + */ + __pyx_L5_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L8_try_end:; + } + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1034 + * raise ImportError("numpy._core.umath failed to import") + * + * cdef inline int import_ufunc() except -1: # <<<<<<<<<<<<<< + * try: + * _import_umath() + */ + + /* function exit code */ + __pyx_r = 0; + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_6); + __Pyx_XDECREF(__pyx_t_7); + __Pyx_XDECREF(__pyx_t_8); + __Pyx_AddTraceback("numpy.import_ufunc", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = -1; + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1041 + * + * + * cdef inline bint is_timedelta64_object(object obj) noexcept: # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.timedelta64)` + */ + +static CYTHON_INLINE int __pyx_f_5numpy_is_timedelta64_object(PyObject *__pyx_v_obj) { + int __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1053 + * bool + * """ + * return PyObject_TypeCheck(obj, &PyTimedeltaArrType_Type) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyTimedeltaArrType_Type)); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1041 + * + * + * cdef inline bint is_timedelta64_object(object obj) noexcept: # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.timedelta64)` + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1056 + * + * + * cdef inline bint is_datetime64_object(object obj) noexcept: # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.datetime64)` + */ + +static CYTHON_INLINE int __pyx_f_5numpy_is_datetime64_object(PyObject *__pyx_v_obj) { + int __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1068 + * bool + * """ + * return PyObject_TypeCheck(obj, &PyDatetimeArrType_Type) # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = PyObject_TypeCheck(__pyx_v_obj, (&PyDatetimeArrType_Type)); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1056 + * + * + * cdef inline bint is_datetime64_object(object obj) noexcept: # <<<<<<<<<<<<<< + * """ + * Cython equivalent of `isinstance(obj, np.datetime64)` + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1071 + * + * + * cdef inline npy_datetime get_datetime64_value(object obj) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy datetime64 object + */ + +static CYTHON_INLINE npy_datetime __pyx_f_5numpy_get_datetime64_value(PyObject *__pyx_v_obj) { + npy_datetime __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1078 + * also needed. That can be found using `get_datetime64_unit`. + * """ + * return (obj).obval # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((PyDatetimeScalarObject *)__pyx_v_obj)->obval; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1071 + * + * + * cdef inline npy_datetime get_datetime64_value(object obj) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy datetime64 object + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1081 + * + * + * cdef inline npy_timedelta get_timedelta64_value(object obj) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy timedelta64 object + */ + +static CYTHON_INLINE npy_timedelta __pyx_f_5numpy_get_timedelta64_value(PyObject *__pyx_v_obj) { + npy_timedelta __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1085 + * returns the int64 value underlying scalar numpy timedelta64 object + * """ + * return (obj).obval # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((PyTimedeltaScalarObject *)__pyx_v_obj)->obval; + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1081 + * + * + * cdef inline npy_timedelta get_timedelta64_value(object obj) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * returns the int64 value underlying scalar numpy timedelta64 object + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1088 + * + * + * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * returns the unit part of the dtype for a numpy datetime64 object. + */ + +static CYTHON_INLINE NPY_DATETIMEUNIT __pyx_f_5numpy_get_datetime64_unit(PyObject *__pyx_v_obj) { + NPY_DATETIMEUNIT __pyx_r; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1092 + * returns the unit part of the dtype for a numpy datetime64 object. + * """ + * return (obj).obmeta.base # <<<<<<<<<<<<<< + * + * + */ + __pyx_r = ((NPY_DATETIMEUNIT)((PyDatetimeScalarObject *)__pyx_v_obj)->obmeta.base); + goto __pyx_L0; + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1088 + * + * + * cdef inline NPY_DATETIMEUNIT get_datetime64_unit(object obj) noexcept nogil: # <<<<<<<<<<<<<< + * """ + * returns the unit part of the dtype for a numpy datetime64 object. + */ + + /* function exit code */ + __pyx_L0:; + return __pyx_r; +} + +/* "TTS/tts/utils/monotonic_align/core.pyx":11 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: # <<<<<<<<<<<<<< + * cdef int x + * cdef int y + */ + +static void __pyx_f_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_each(__Pyx_memviewslice __pyx_v_path, __Pyx_memviewslice __pyx_v_value, int __pyx_v_t_x, int __pyx_v_t_y, float __pyx_v_max_neg_val) { + int __pyx_v_x; + int __pyx_v_y; + float __pyx_v_v_prev; + float __pyx_v_v_cur; + int __pyx_v_index; + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + long __pyx_t_4; + int __pyx_t_5; + long __pyx_t_6; + int __pyx_t_7; + long __pyx_t_8; + Py_ssize_t __pyx_t_9; + Py_ssize_t __pyx_t_10; + float __pyx_t_11; + float __pyx_t_12; + float __pyx_t_13; + Py_ssize_t __pyx_t_14; + Py_ssize_t __pyx_t_15; + int __pyx_t_16; + + /* "TTS/tts/utils/monotonic_align/core.pyx":17 + * cdef float v_cur + * cdef float tmp + * cdef int index = t_x - 1 # <<<<<<<<<<<<<< + * + * for y in range(t_y): + */ + __pyx_v_index = (__pyx_v_t_x - 1); + + /* "TTS/tts/utils/monotonic_align/core.pyx":19 + * cdef int index = t_x - 1 + * + * for y in range(t_y): # <<<<<<<<<<<<<< + * for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + * if x == y: + */ + __pyx_t_1 = __pyx_v_t_y; + __pyx_t_2 = __pyx_t_1; + for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) { + __pyx_v_y = __pyx_t_3; + + /* "TTS/tts/utils/monotonic_align/core.pyx":20 + * + * for y in range(t_y): + * for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): # <<<<<<<<<<<<<< + * if x == y: + * v_cur = max_neg_val + */ + __pyx_t_4 = (__pyx_v_y + 1); + __pyx_t_5 = __pyx_v_t_x; + __pyx_t_7 = (__pyx_t_4 < __pyx_t_5); + if (__pyx_t_7) { + __pyx_t_6 = __pyx_t_4; + } else { + __pyx_t_6 = __pyx_t_5; + } + __pyx_t_4 = __pyx_t_6; + __pyx_t_5 = ((__pyx_v_t_x + __pyx_v_y) - __pyx_v_t_y); + __pyx_t_6 = 0; + __pyx_t_7 = (__pyx_t_5 > __pyx_t_6); + if (__pyx_t_7) { + __pyx_t_8 = __pyx_t_5; + } else { + __pyx_t_8 = __pyx_t_6; + } + __pyx_t_6 = __pyx_t_4; + for (__pyx_t_5 = __pyx_t_8; __pyx_t_5 < __pyx_t_6; __pyx_t_5+=1) { + __pyx_v_x = __pyx_t_5; + + /* "TTS/tts/utils/monotonic_align/core.pyx":21 + * for y in range(t_y): + * for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + * if x == y: # <<<<<<<<<<<<<< + * v_cur = max_neg_val + * else: + */ + __pyx_t_7 = (__pyx_v_x == __pyx_v_y); + if (__pyx_t_7) { + + /* "TTS/tts/utils/monotonic_align/core.pyx":22 + * for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + * if x == y: + * v_cur = max_neg_val # <<<<<<<<<<<<<< + * else: + * v_cur = value[x, y-1] + */ + __pyx_v_v_cur = __pyx_v_max_neg_val; + + /* "TTS/tts/utils/monotonic_align/core.pyx":21 + * for y in range(t_y): + * for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + * if x == y: # <<<<<<<<<<<<<< + * v_cur = max_neg_val + * else: + */ + goto __pyx_L7; + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":24 + * v_cur = max_neg_val + * else: + * v_cur = value[x, y-1] # <<<<<<<<<<<<<< + * if x == 0: + * if y == 0: + */ + /*else*/ { + __pyx_t_9 = __pyx_v_x; + __pyx_t_10 = (__pyx_v_y - 1); + __pyx_v_v_cur = (*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_value.data + __pyx_t_9 * __pyx_v_value.strides[0]) )) + __pyx_t_10)) ))); + } + __pyx_L7:; + + /* "TTS/tts/utils/monotonic_align/core.pyx":25 + * else: + * v_cur = value[x, y-1] + * if x == 0: # <<<<<<<<<<<<<< + * if y == 0: + * v_prev = 0. + */ + __pyx_t_7 = (__pyx_v_x == 0); + if (__pyx_t_7) { + + /* "TTS/tts/utils/monotonic_align/core.pyx":26 + * v_cur = value[x, y-1] + * if x == 0: + * if y == 0: # <<<<<<<<<<<<<< + * v_prev = 0. + * else: + */ + __pyx_t_7 = (__pyx_v_y == 0); + if (__pyx_t_7) { + + /* "TTS/tts/utils/monotonic_align/core.pyx":27 + * if x == 0: + * if y == 0: + * v_prev = 0. # <<<<<<<<<<<<<< + * else: + * v_prev = max_neg_val + */ + __pyx_v_v_prev = 0.; + + /* "TTS/tts/utils/monotonic_align/core.pyx":26 + * v_cur = value[x, y-1] + * if x == 0: + * if y == 0: # <<<<<<<<<<<<<< + * v_prev = 0. + * else: + */ + goto __pyx_L9; + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":29 + * v_prev = 0. + * else: + * v_prev = max_neg_val # <<<<<<<<<<<<<< + * else: + * v_prev = value[x-1, y-1] + */ + /*else*/ { + __pyx_v_v_prev = __pyx_v_max_neg_val; + } + __pyx_L9:; + + /* "TTS/tts/utils/monotonic_align/core.pyx":25 + * else: + * v_cur = value[x, y-1] + * if x == 0: # <<<<<<<<<<<<<< + * if y == 0: + * v_prev = 0. + */ + goto __pyx_L8; + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":31 + * v_prev = max_neg_val + * else: + * v_prev = value[x-1, y-1] # <<<<<<<<<<<<<< + * value[x, y] = max(v_cur, v_prev) + value[x, y] + * + */ + /*else*/ { + __pyx_t_10 = (__pyx_v_x - 1); + __pyx_t_9 = (__pyx_v_y - 1); + __pyx_v_v_prev = (*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_value.data + __pyx_t_10 * __pyx_v_value.strides[0]) )) + __pyx_t_9)) ))); + } + __pyx_L8:; + + /* "TTS/tts/utils/monotonic_align/core.pyx":32 + * else: + * v_prev = value[x-1, y-1] + * value[x, y] = max(v_cur, v_prev) + value[x, y] # <<<<<<<<<<<<<< + * + * for y in range(t_y - 1, -1, -1): + */ + __pyx_t_11 = __pyx_v_v_prev; + __pyx_t_12 = __pyx_v_v_cur; + __pyx_t_7 = (__pyx_t_11 > __pyx_t_12); + if (__pyx_t_7) { + __pyx_t_13 = __pyx_t_11; + } else { + __pyx_t_13 = __pyx_t_12; + } + __pyx_t_9 = __pyx_v_x; + __pyx_t_10 = __pyx_v_y; + __pyx_t_14 = __pyx_v_x; + __pyx_t_15 = __pyx_v_y; + *((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_value.data + __pyx_t_14 * __pyx_v_value.strides[0]) )) + __pyx_t_15)) )) = (__pyx_t_13 + (*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_value.data + __pyx_t_9 * __pyx_v_value.strides[0]) )) + __pyx_t_10)) )))); + } + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":34 + * value[x, y] = max(v_cur, v_prev) + value[x, y] + * + * for y in range(t_y - 1, -1, -1): # <<<<<<<<<<<<<< + * path[index, y] = 1 + * if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + */ + for (__pyx_t_1 = (__pyx_v_t_y - 1); __pyx_t_1 > -1; __pyx_t_1-=1) { + __pyx_v_y = __pyx_t_1; + + /* "TTS/tts/utils/monotonic_align/core.pyx":35 + * + * for y in range(t_y - 1, -1, -1): + * path[index, y] = 1 # <<<<<<<<<<<<<< + * if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + * index = index - 1 + */ + __pyx_t_10 = __pyx_v_index; + __pyx_t_9 = __pyx_v_y; + *((int *) ( /* dim=1 */ ((char *) (((int *) ( /* dim=0 */ (__pyx_v_path.data + __pyx_t_10 * __pyx_v_path.strides[0]) )) + __pyx_t_9)) )) = 1; + + /* "TTS/tts/utils/monotonic_align/core.pyx":36 + * for y in range(t_y - 1, -1, -1): + * path[index, y] = 1 + * if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): # <<<<<<<<<<<<<< + * index = index - 1 + * + */ + __pyx_t_16 = (__pyx_v_index != 0); + if (__pyx_t_16) { + } else { + __pyx_t_7 = __pyx_t_16; + goto __pyx_L13_bool_binop_done; + } + __pyx_t_16 = (__pyx_v_index == __pyx_v_y); + if (!__pyx_t_16) { + } else { + __pyx_t_7 = __pyx_t_16; + goto __pyx_L13_bool_binop_done; + } + __pyx_t_9 = __pyx_v_index; + __pyx_t_10 = (__pyx_v_y - 1); + __pyx_t_15 = (__pyx_v_index - 1); + __pyx_t_14 = (__pyx_v_y - 1); + __pyx_t_16 = ((*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_value.data + __pyx_t_9 * __pyx_v_value.strides[0]) )) + __pyx_t_10)) ))) < (*((float *) ( /* dim=1 */ ((char *) (((float *) ( /* dim=0 */ (__pyx_v_value.data + __pyx_t_15 * __pyx_v_value.strides[0]) )) + __pyx_t_14)) )))); + __pyx_t_7 = __pyx_t_16; + __pyx_L13_bool_binop_done:; + if (__pyx_t_7) { + + /* "TTS/tts/utils/monotonic_align/core.pyx":37 + * path[index, y] = 1 + * if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + * index = index - 1 # <<<<<<<<<<<<<< + * + * + */ + __pyx_v_index = (__pyx_v_index - 1); + + /* "TTS/tts/utils/monotonic_align/core.pyx":36 + * for y in range(t_y - 1, -1, -1): + * path[index, y] = 1 + * if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): # <<<<<<<<<<<<<< + * index = index - 1 + * + */ + } + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":11 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: # <<<<<<<<<<<<<< + * cdef int x + * cdef int y + */ + + /* function exit code */ +} + +/* "TTS/tts/utils/monotonic_align/core.pyx":42 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: # <<<<<<<<<<<<<< + * cdef int b = values.shape[0] + * + */ + +static PyObject *__pyx_pw_3TTS_3tts_5utils_15monotonic_align_4core_1maximum_path_c(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static void __pyx_f_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c(__Pyx_memviewslice __pyx_v_paths, __Pyx_memviewslice __pyx_v_values, __Pyx_memviewslice __pyx_v_t_xs, __Pyx_memviewslice __pyx_v_t_ys, CYTHON_UNUSED int __pyx_skip_dispatch, struct __pyx_opt_args_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c *__pyx_optional_args) { + float __pyx_v_max_neg_val = __pyx_k__11; + CYTHON_UNUSED int __pyx_v_b; + int __pyx_v_i; + __Pyx_RefNannyDeclarations + int __pyx_t_1; + int __pyx_t_2; + int __pyx_t_3; + __Pyx_memviewslice __pyx_t_4 = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_t_5 = { 0, 0, { 0 }, { 0 }, { 0 } }; + Py_ssize_t __pyx_t_6; + Py_ssize_t __pyx_t_7; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save; + #endif + __Pyx_RefNannySetupContext("maximum_path_c", 1); + if (__pyx_optional_args) { + if (__pyx_optional_args->__pyx_n > 0) { + __pyx_v_max_neg_val = __pyx_optional_args->max_neg_val; + } + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":43 + * @cython.wraparound(False) + * cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + * cdef int b = values.shape[0] # <<<<<<<<<<<<<< + * + * cdef int i + */ + __pyx_v_b = (__pyx_v_values.shape[0]); + + /* "TTS/tts/utils/monotonic_align/core.pyx":46 + * + * cdef int i + * for i in prange(b, nogil=True): # <<<<<<<<<<<<<< + * maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) + */ + { + #ifdef WITH_THREAD + PyThreadState *_save; + _save = NULL; + if (PyGILState_Check()) { + Py_UNBLOCK_THREADS + } + __Pyx_FastGIL_Remember(); + #endif + /*try:*/ { + __pyx_t_1 = __pyx_v_b; + { + int __pyx_parallel_temp0 = ((int)0xbad0bad0); + const char *__pyx_parallel_filename = NULL; int __pyx_parallel_lineno = 0, __pyx_parallel_clineno = 0; + PyObject *__pyx_parallel_exc_type = NULL, *__pyx_parallel_exc_value = NULL, *__pyx_parallel_exc_tb = NULL; + int __pyx_parallel_why; + __pyx_parallel_why = 0; + #if ((defined(__APPLE__) || defined(__OSX__)) && (defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))))) + #undef likely + #undef unlikely + #define likely(x) (x) + #define unlikely(x) (x) + #endif + __pyx_t_3 = (__pyx_t_1 - 0 + 1 - 1/abs(1)) / 1; + if (__pyx_t_3 > 0) + { + #ifdef _OPENMP + #pragma omp parallel private(__pyx_t_6, __pyx_t_7) firstprivate(__pyx_t_4, __pyx_t_5) private(__pyx_filename, __pyx_lineno, __pyx_clineno) shared(__pyx_parallel_why, __pyx_parallel_exc_type, __pyx_parallel_exc_value, __pyx_parallel_exc_tb) + #endif /* _OPENMP */ + { + #ifdef _OPENMP + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + Py_BEGIN_ALLOW_THREADS + #endif /* _OPENMP */ + #ifdef _OPENMP + #pragma omp for firstprivate(__pyx_v_i) lastprivate(__pyx_v_i) + #endif /* _OPENMP */ + for (__pyx_t_2 = 0; __pyx_t_2 < __pyx_t_3; __pyx_t_2++){ + if (__pyx_parallel_why < 2) + { + __pyx_v_i = (int)(0 + 1 * __pyx_t_2); + + /* "TTS/tts/utils/monotonic_align/core.pyx":47 + * cdef int i + * for i in prange(b, nogil=True): + * maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) # <<<<<<<<<<<<<< + */ + __pyx_t_4.data = __pyx_v_paths.data; + __pyx_t_4.memview = __pyx_v_paths.memview; + __PYX_INC_MEMVIEW(&__pyx_t_4, 0); + { + Py_ssize_t __pyx_tmp_idx = __pyx_v_i; + Py_ssize_t __pyx_tmp_stride = __pyx_v_paths.strides[0]; + __pyx_t_4.data += __pyx_tmp_idx * __pyx_tmp_stride; +} + +__pyx_t_4.shape[0] = __pyx_v_paths.shape[1]; +__pyx_t_4.strides[0] = __pyx_v_paths.strides[1]; + __pyx_t_4.suboffsets[0] = -1; + +__pyx_t_4.shape[1] = __pyx_v_paths.shape[2]; +__pyx_t_4.strides[1] = __pyx_v_paths.strides[2]; + __pyx_t_4.suboffsets[1] = -1; + +__pyx_t_5.data = __pyx_v_values.data; + __pyx_t_5.memview = __pyx_v_values.memview; + __PYX_INC_MEMVIEW(&__pyx_t_5, 0); + { + Py_ssize_t __pyx_tmp_idx = __pyx_v_i; + Py_ssize_t __pyx_tmp_stride = __pyx_v_values.strides[0]; + __pyx_t_5.data += __pyx_tmp_idx * __pyx_tmp_stride; +} + +__pyx_t_5.shape[0] = __pyx_v_values.shape[1]; +__pyx_t_5.strides[0] = __pyx_v_values.strides[1]; + __pyx_t_5.suboffsets[0] = -1; + +__pyx_t_5.shape[1] = __pyx_v_values.shape[2]; +__pyx_t_5.strides[1] = __pyx_v_values.strides[2]; + __pyx_t_5.suboffsets[1] = -1; + +__pyx_t_6 = __pyx_v_i; + __pyx_t_7 = __pyx_v_i; + __pyx_f_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_each(__pyx_t_4, __pyx_t_5, (*((int *) ( /* dim=0 */ ((char *) (((int *) __pyx_v_t_xs.data) + __pyx_t_6)) ))), (*((int *) ( /* dim=0 */ ((char *) (((int *) __pyx_v_t_ys.data) + __pyx_t_7)) ))), __pyx_v_max_neg_val); if (unlikely(__Pyx_ErrOccurredWithGIL())) __PYX_ERR(0, 47, __pyx_L8_error) + __PYX_XCLEAR_MEMVIEW(&__pyx_t_4, 0); + __pyx_t_4.memview = NULL; __pyx_t_4.data = NULL; + __PYX_XCLEAR_MEMVIEW(&__pyx_t_5, 0); + __pyx_t_5.memview = NULL; __pyx_t_5.data = NULL; + goto __pyx_L11; + __pyx_L8_error:; + { + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + #ifdef _OPENMP + #pragma omp flush(__pyx_parallel_exc_type) + #endif /* _OPENMP */ + if (!__pyx_parallel_exc_type) { + __Pyx_ErrFetchWithState(&__pyx_parallel_exc_type, &__pyx_parallel_exc_value, &__pyx_parallel_exc_tb); + __pyx_parallel_filename = __pyx_filename; __pyx_parallel_lineno = __pyx_lineno; __pyx_parallel_clineno = __pyx_clineno; + __Pyx_GOTREF(__pyx_parallel_exc_type); + } + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + } + __pyx_parallel_why = 4; + goto __pyx_L10; + __pyx_L10:; + #ifdef _OPENMP + #pragma omp critical(__pyx_parallel_lastprivates0) + #endif /* _OPENMP */ + { + __pyx_parallel_temp0 = __pyx_v_i; + } + __pyx_L11:; + #ifdef _OPENMP + #pragma omp flush(__pyx_parallel_why) + #endif /* _OPENMP */ + } + } + #ifdef _OPENMP + Py_END_ALLOW_THREADS + #else +{ +#ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + #endif /* _OPENMP */ + /* Clean up any temporaries */ + __PYX_XCLEAR_MEMVIEW(&__pyx_t_4, 0); + __pyx_t_4.memview = NULL; __pyx_t_4.data = NULL; + __PYX_XCLEAR_MEMVIEW(&__pyx_t_5, 0); + __pyx_t_5.memview = NULL; __pyx_t_5.data = NULL; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + #ifndef _OPENMP +} +#endif /* _OPENMP */ + } + } + if (__pyx_parallel_exc_type) { + /* This may have been overridden by a continue, break or return in another thread. Prefer the error. */ + __pyx_parallel_why = 4; + } + if (__pyx_parallel_why) { + __pyx_v_i = __pyx_parallel_temp0; + switch (__pyx_parallel_why) { + case 4: + { + #ifdef WITH_THREAD + PyGILState_STATE __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __Pyx_GIVEREF(__pyx_parallel_exc_type); + __Pyx_ErrRestoreWithState(__pyx_parallel_exc_type, __pyx_parallel_exc_value, __pyx_parallel_exc_tb); + __pyx_filename = __pyx_parallel_filename; __pyx_lineno = __pyx_parallel_lineno; __pyx_clineno = __pyx_parallel_clineno; + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + } + goto __pyx_L4_error; + } + } + } + #if ((defined(__APPLE__) || defined(__OSX__)) && (defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && (__GNUC_MINOR__ > 95))))) + #undef likely + #undef unlikely + #define likely(x) __builtin_expect(!!(x), 1) + #define unlikely(x) __builtin_expect(!!(x), 0) + #endif + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":46 + * + * cdef int i + * for i in prange(b, nogil=True): # <<<<<<<<<<<<<< + * maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) + */ + /*finally:*/ { + /*normal exit:*/{ + #ifdef WITH_THREAD + __Pyx_FastGIL_Forget(); + if (_save) { + Py_BLOCK_THREADS + } + #endif + goto __pyx_L5; + } + __pyx_L4_error: { + #ifdef WITH_THREAD + __Pyx_FastGIL_Forget(); + if (_save) { + Py_BLOCK_THREADS + } + #endif + goto __pyx_L1_error; + } + __pyx_L5:; + } + } + + /* "TTS/tts/utils/monotonic_align/core.pyx":42 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: # <<<<<<<<<<<<<< + * cdef int b = values.shape[0] + * + */ + + /* function exit code */ + goto __pyx_L0; + __pyx_L1_error:; + #ifdef WITH_THREAD + __pyx_gilstate_save = __Pyx_PyGILState_Ensure(); + #endif + __PYX_XCLEAR_MEMVIEW(&__pyx_t_4, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_t_5, 1); + __Pyx_AddTraceback("TTS.tts.utils.monotonic_align.core.maximum_path_c", __pyx_clineno, __pyx_lineno, __pyx_filename); + #ifdef WITH_THREAD + __Pyx_PyGILState_Release(__pyx_gilstate_save); + #endif + __pyx_L0:; + __Pyx_RefNannyFinishContextNogil() +} + +/* Python wrapper */ +static PyObject *__pyx_pw_3TTS_3tts_5utils_15monotonic_align_4core_1maximum_path_c(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +); /*proto*/ +static PyMethodDef __pyx_mdef_3TTS_3tts_5utils_15monotonic_align_4core_1maximum_path_c = {"maximum_path_c", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw_3TTS_3tts_5utils_15monotonic_align_4core_1maximum_path_c, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}; +static PyObject *__pyx_pw_3TTS_3tts_5utils_15monotonic_align_4core_1maximum_path_c(PyObject *__pyx_self, +#if CYTHON_METH_FASTCALL +PyObject *const *__pyx_args, Py_ssize_t __pyx_nargs, PyObject *__pyx_kwds +#else +PyObject *__pyx_args, PyObject *__pyx_kwds +#endif +) { + __Pyx_memviewslice __pyx_v_paths = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_v_values = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_v_t_xs = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_memviewslice __pyx_v_t_ys = { 0, 0, { 0 }, { 0 }, { 0 } }; + float __pyx_v_max_neg_val; + #if !CYTHON_METH_FASTCALL + CYTHON_UNUSED Py_ssize_t __pyx_nargs; + #endif + CYTHON_UNUSED PyObject *const *__pyx_kwvalues; + PyObject* values[5] = {0,0,0,0,0}; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + PyObject *__pyx_r = 0; + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("maximum_path_c (wrapper)", 0); + #if !CYTHON_METH_FASTCALL + #if CYTHON_ASSUME_SAFE_MACROS + __pyx_nargs = PyTuple_GET_SIZE(__pyx_args); + #else + __pyx_nargs = PyTuple_Size(__pyx_args); if (unlikely(__pyx_nargs < 0)) return NULL; + #endif + #endif + __pyx_kwvalues = __Pyx_KwValues_FASTCALL(__pyx_args, __pyx_nargs); + { + PyObject **__pyx_pyargnames[] = {&__pyx_n_s_paths,&__pyx_n_s_values,&__pyx_n_s_t_xs,&__pyx_n_s_t_ys,&__pyx_n_s_max_neg_val,0}; + if (__pyx_kwds) { + Py_ssize_t kw_args; + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + CYTHON_FALLTHROUGH; + case 3: values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + CYTHON_FALLTHROUGH; + case 2: values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + CYTHON_FALLTHROUGH; + case 1: values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + CYTHON_FALLTHROUGH; + case 0: break; + default: goto __pyx_L5_argtuple_error; + } + kw_args = __Pyx_NumKwargs_FASTCALL(__pyx_kwds); + switch (__pyx_nargs) { + case 0: + if (likely((values[0] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_paths)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[0]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L3_error) + else goto __pyx_L5_argtuple_error; + CYTHON_FALLTHROUGH; + case 1: + if (likely((values[1] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_values)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[1]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("maximum_path_c", 0, 4, 5, 1); __PYX_ERR(0, 42, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 2: + if (likely((values[2] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_t_xs)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[2]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("maximum_path_c", 0, 4, 5, 2); __PYX_ERR(0, 42, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 3: + if (likely((values[3] = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_t_ys)) != 0)) { + (void)__Pyx_Arg_NewRef_FASTCALL(values[3]); + kw_args--; + } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L3_error) + else { + __Pyx_RaiseArgtupleInvalid("maximum_path_c", 0, 4, 5, 3); __PYX_ERR(0, 42, __pyx_L3_error) + } + CYTHON_FALLTHROUGH; + case 4: + if (kw_args > 0) { + PyObject* value = __Pyx_GetKwValue_FASTCALL(__pyx_kwds, __pyx_kwvalues, __pyx_n_s_max_neg_val); + if (value) { values[4] = __Pyx_Arg_NewRef_FASTCALL(value); kw_args--; } + else if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L3_error) + } + } + if (unlikely(kw_args > 0)) { + const Py_ssize_t kwd_pos_args = __pyx_nargs; + if (unlikely(__Pyx_ParseOptionalKeywords(__pyx_kwds, __pyx_kwvalues, __pyx_pyargnames, 0, values + 0, kwd_pos_args, "maximum_path_c") < 0)) __PYX_ERR(0, 42, __pyx_L3_error) + } + } else { + switch (__pyx_nargs) { + case 5: values[4] = __Pyx_Arg_FASTCALL(__pyx_args, 4); + CYTHON_FALLTHROUGH; + case 4: values[3] = __Pyx_Arg_FASTCALL(__pyx_args, 3); + values[2] = __Pyx_Arg_FASTCALL(__pyx_args, 2); + values[1] = __Pyx_Arg_FASTCALL(__pyx_args, 1); + values[0] = __Pyx_Arg_FASTCALL(__pyx_args, 0); + break; + default: goto __pyx_L5_argtuple_error; + } + } + __pyx_v_paths = __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_int(values[0], PyBUF_WRITABLE); if (unlikely(!__pyx_v_paths.memview)) __PYX_ERR(0, 42, __pyx_L3_error) + __pyx_v_values = __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_float(values[1], PyBUF_WRITABLE); if (unlikely(!__pyx_v_values.memview)) __PYX_ERR(0, 42, __pyx_L3_error) + __pyx_v_t_xs = __Pyx_PyObject_to_MemoryviewSlice_dc_int(values[2], PyBUF_WRITABLE); if (unlikely(!__pyx_v_t_xs.memview)) __PYX_ERR(0, 42, __pyx_L3_error) + __pyx_v_t_ys = __Pyx_PyObject_to_MemoryviewSlice_dc_int(values[3], PyBUF_WRITABLE); if (unlikely(!__pyx_v_t_ys.memview)) __PYX_ERR(0, 42, __pyx_L3_error) + if (values[4]) { + __pyx_v_max_neg_val = __pyx_PyFloat_AsFloat(values[4]); if (unlikely((__pyx_v_max_neg_val == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L3_error) + } else { + __pyx_v_max_neg_val = __pyx_k__11; + } + } + goto __pyx_L6_skip; + __pyx_L5_argtuple_error:; + __Pyx_RaiseArgtupleInvalid("maximum_path_c", 0, 4, 5, __pyx_nargs); __PYX_ERR(0, 42, __pyx_L3_error) + __pyx_L6_skip:; + goto __pyx_L4_argument_unpacking_done; + __pyx_L3_error:; + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __PYX_XCLEAR_MEMVIEW(&__pyx_v_paths, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_values, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_t_xs, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_t_ys, 1); + __Pyx_AddTraceback("TTS.tts.utils.monotonic_align.core.maximum_path_c", __pyx_clineno, __pyx_lineno, __pyx_filename); + __Pyx_RefNannyFinishContext(); + return NULL; + __pyx_L4_argument_unpacking_done:; + __pyx_r = __pyx_pf_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c(__pyx_self, __pyx_v_paths, __pyx_v_values, __pyx_v_t_xs, __pyx_v_t_ys, __pyx_v_max_neg_val); + + /* function exit code */ + __PYX_XCLEAR_MEMVIEW(&__pyx_v_paths, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_values, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_t_xs, 1); + __PYX_XCLEAR_MEMVIEW(&__pyx_v_t_ys, 1); + { + Py_ssize_t __pyx_temp; + for (__pyx_temp=0; __pyx_temp < (Py_ssize_t)(sizeof(values)/sizeof(values[0])); ++__pyx_temp) { + __Pyx_Arg_XDECREF_FASTCALL(values[__pyx_temp]); + } + } + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} + +static PyObject *__pyx_pf_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c(CYTHON_UNUSED PyObject *__pyx_self, __Pyx_memviewslice __pyx_v_paths, __Pyx_memviewslice __pyx_v_values, __Pyx_memviewslice __pyx_v_t_xs, __Pyx_memviewslice __pyx_v_t_ys, float __pyx_v_max_neg_val) { + PyObject *__pyx_r = NULL; + __Pyx_RefNannyDeclarations + struct __pyx_opt_args_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c __pyx_t_1; + PyObject *__pyx_t_2 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("maximum_path_c", 1); + __Pyx_XDECREF(__pyx_r); + if (unlikely(!__pyx_v_paths.memview)) { __Pyx_RaiseUnboundLocalError("paths"); __PYX_ERR(0, 42, __pyx_L1_error) } + if (unlikely(!__pyx_v_values.memview)) { __Pyx_RaiseUnboundLocalError("values"); __PYX_ERR(0, 42, __pyx_L1_error) } + if (unlikely(!__pyx_v_t_xs.memview)) { __Pyx_RaiseUnboundLocalError("t_xs"); __PYX_ERR(0, 42, __pyx_L1_error) } + if (unlikely(!__pyx_v_t_ys.memview)) { __Pyx_RaiseUnboundLocalError("t_ys"); __PYX_ERR(0, 42, __pyx_L1_error) } + __pyx_t_1.__pyx_n = 1; + __pyx_t_1.max_neg_val = __pyx_v_max_neg_val; + __pyx_f_3TTS_3tts_5utils_15monotonic_align_4core_maximum_path_c(__pyx_v_paths, __pyx_v_values, __pyx_v_t_xs, __pyx_v_t_ys, 0, &__pyx_t_1); if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 42, __pyx_L1_error) + __pyx_t_2 = __Pyx_void_to_None(NULL); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_2); + __pyx_r = __pyx_t_2; + __pyx_t_2 = 0; + goto __pyx_L0; + + /* function exit code */ + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_2); + __Pyx_AddTraceback("TTS.tts.utils.monotonic_align.core.maximum_path_c", __pyx_clineno, __pyx_lineno, __pyx_filename); + __pyx_r = NULL; + __pyx_L0:; + __Pyx_XGIVEREF(__pyx_r); + __Pyx_RefNannyFinishContext(); + return __pyx_r; +} +static struct __pyx_vtabstruct_array __pyx_vtable_array; + +static PyObject *__pyx_tp_new_array(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_array_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_array_obj *)o); + p->__pyx_vtab = __pyx_vtabptr_array; + p->mode = ((PyObject*)Py_None); Py_INCREF(Py_None); + p->_format = ((PyObject*)Py_None); Py_INCREF(Py_None); + if (unlikely(__pyx_array___cinit__(o, a, k) < 0)) goto bad; + return o; + bad: + Py_DECREF(o); o = 0; + return NULL; +} + +static void __pyx_tp_dealloc_array(PyObject *o) { + struct __pyx_array_obj *p = (struct __pyx_array_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && (!PyType_IS_GC(Py_TYPE(o)) || !__Pyx_PyObject_GC_IsFinalized(o))) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_array) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_array___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->mode); + Py_CLEAR(p->_format); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} +static PyObject *__pyx_sq_item_array(PyObject *o, Py_ssize_t i) { + PyObject *r; + PyObject *x = PyInt_FromSsize_t(i); if(!x) return 0; + r = Py_TYPE(o)->tp_as_mapping->mp_subscript(o, x); + Py_DECREF(x); + return r; +} + +static int __pyx_mp_ass_subscript_array(PyObject *o, PyObject *i, PyObject *v) { + if (v) { + return __pyx_array___setitem__(o, i, v); + } + else { + __Pyx_TypeName o_type_name; + o_type_name = __Pyx_PyType_GetName(Py_TYPE(o)); + PyErr_Format(PyExc_NotImplementedError, + "Subscript deletion not supported by " __Pyx_FMT_TYPENAME, o_type_name); + __Pyx_DECREF_TypeName(o_type_name); + return -1; + } +} + +static PyObject *__pyx_tp_getattro_array(PyObject *o, PyObject *n) { + PyObject *v = __Pyx_PyObject_GenericGetAttr(o, n); + if (!v && PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + v = __pyx_array___getattr__(o, n); + } + return v; +} + +static PyObject *__pyx_getprop___pyx_array_memview(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_5array_7memview_1__get__(o); +} + +static PyMethodDef __pyx_methods_array[] = { + {"__getattr__", (PyCFunction)__pyx_array___getattr__, METH_O|METH_COEXIST, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_array_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_array_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; + +static struct PyGetSetDef __pyx_getsets_array[] = { + {(char *)"memview", __pyx_getprop___pyx_array_memview, 0, (char *)0, 0}, + {0, 0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +#if !CYTHON_COMPILING_IN_LIMITED_API + +static PyBufferProcs __pyx_tp_as_buffer_array = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_array_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; +#endif +static PyType_Slot __pyx_type___pyx_array_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_array}, + {Py_sq_length, (void *)__pyx_array___len__}, + {Py_sq_item, (void *)__pyx_sq_item_array}, + {Py_mp_length, (void *)__pyx_array___len__}, + {Py_mp_subscript, (void *)__pyx_array___getitem__}, + {Py_mp_ass_subscript, (void *)__pyx_mp_ass_subscript_array}, + {Py_tp_getattro, (void *)__pyx_tp_getattro_array}, + #if defined(Py_bf_getbuffer) + {Py_bf_getbuffer, (void *)__pyx_array_getbuffer}, + #endif + {Py_tp_methods, (void *)__pyx_methods_array}, + {Py_tp_getset, (void *)__pyx_getsets_array}, + {Py_tp_new, (void *)__pyx_tp_new_array}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_array_spec = { + "TTS.tts.utils.monotonic_align.core.array", + sizeof(struct __pyx_array_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_SEQUENCE, + __pyx_type___pyx_array_slots, +}; +#else + +static PySequenceMethods __pyx_tp_as_sequence_array = { + __pyx_array___len__, /*sq_length*/ + 0, /*sq_concat*/ + 0, /*sq_repeat*/ + __pyx_sq_item_array, /*sq_item*/ + 0, /*sq_slice*/ + 0, /*sq_ass_item*/ + 0, /*sq_ass_slice*/ + 0, /*sq_contains*/ + 0, /*sq_inplace_concat*/ + 0, /*sq_inplace_repeat*/ +}; + +static PyMappingMethods __pyx_tp_as_mapping_array = { + __pyx_array___len__, /*mp_length*/ + __pyx_array___getitem__, /*mp_subscript*/ + __pyx_mp_ass_subscript_array, /*mp_ass_subscript*/ +}; + +static PyBufferProcs __pyx_tp_as_buffer_array = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_array_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; + +static PyTypeObject __pyx_type___pyx_array = { + PyVarObject_HEAD_INIT(0, 0) + "TTS.tts.utils.monotonic_align.core.""array", /*tp_name*/ + sizeof(struct __pyx_array_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_array, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + &__pyx_tp_as_sequence_array, /*tp_as_sequence*/ + &__pyx_tp_as_mapping_array, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + __pyx_tp_getattro_array, /*tp_getattro*/ + 0, /*tp_setattro*/ + &__pyx_tp_as_buffer_array, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_SEQUENCE, /*tp_flags*/ + 0, /*tp_doc*/ + 0, /*tp_traverse*/ + 0, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_array, /*tp_methods*/ + 0, /*tp_members*/ + __pyx_getsets_array, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_array, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if PY_VERSION_HEX >= 0x030d00A4 + 0, /*tp_versions_used*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif + +static PyObject *__pyx_tp_new_Enum(PyTypeObject *t, CYTHON_UNUSED PyObject *a, CYTHON_UNUSED PyObject *k) { + struct __pyx_MemviewEnum_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_MemviewEnum_obj *)o); + p->name = Py_None; Py_INCREF(Py_None); + return o; +} + +static void __pyx_tp_dealloc_Enum(PyObject *o) { + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_Enum) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + Py_CLEAR(p->name); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static int __pyx_tp_traverse_Enum(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + if (p->name) { + e = (*v)(p->name, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_Enum(PyObject *o) { + PyObject* tmp; + struct __pyx_MemviewEnum_obj *p = (struct __pyx_MemviewEnum_obj *)o; + tmp = ((PyObject*)p->name); + p->name = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + return 0; +} + +static PyObject *__pyx_specialmethod___pyx_MemviewEnum___repr__(PyObject *self, CYTHON_UNUSED PyObject *arg) { + return __pyx_MemviewEnum___repr__(self); +} + +static PyMethodDef __pyx_methods_Enum[] = { + {"__repr__", (PyCFunction)__pyx_specialmethod___pyx_MemviewEnum___repr__, METH_NOARGS|METH_COEXIST, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_MemviewEnum_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_MemviewEnum_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type___pyx_MemviewEnum_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_Enum}, + {Py_tp_repr, (void *)__pyx_MemviewEnum___repr__}, + {Py_tp_traverse, (void *)__pyx_tp_traverse_Enum}, + {Py_tp_clear, (void *)__pyx_tp_clear_Enum}, + {Py_tp_methods, (void *)__pyx_methods_Enum}, + {Py_tp_init, (void *)__pyx_MemviewEnum___init__}, + {Py_tp_new, (void *)__pyx_tp_new_Enum}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_MemviewEnum_spec = { + "TTS.tts.utils.monotonic_align.core.Enum", + sizeof(struct __pyx_MemviewEnum_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, + __pyx_type___pyx_MemviewEnum_slots, +}; +#else + +static PyTypeObject __pyx_type___pyx_MemviewEnum = { + PyVarObject_HEAD_INIT(0, 0) + "TTS.tts.utils.monotonic_align.core.""Enum", /*tp_name*/ + sizeof(struct __pyx_MemviewEnum_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_Enum, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + __pyx_MemviewEnum___repr__, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_Enum, /*tp_traverse*/ + __pyx_tp_clear_Enum, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_Enum, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + __pyx_MemviewEnum___init__, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_Enum, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if PY_VERSION_HEX >= 0x030d00A4 + 0, /*tp_versions_used*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct_memoryview __pyx_vtable_memoryview; + +static PyObject *__pyx_tp_new_memoryview(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_memoryview_obj *p; + PyObject *o; + #if CYTHON_COMPILING_IN_LIMITED_API + allocfunc alloc_func = (allocfunc)PyType_GetSlot(t, Py_tp_alloc); + o = alloc_func(t, 0); + #else + if (likely(!__Pyx_PyType_HasFeature(t, Py_TPFLAGS_IS_ABSTRACT))) { + o = (*t->tp_alloc)(t, 0); + } else { + o = (PyObject *) PyBaseObject_Type.tp_new(t, __pyx_empty_tuple, 0); + } + if (unlikely(!o)) return 0; + #endif + p = ((struct __pyx_memoryview_obj *)o); + p->__pyx_vtab = __pyx_vtabptr_memoryview; + p->obj = Py_None; Py_INCREF(Py_None); + p->_size = Py_None; Py_INCREF(Py_None); + p->_array_interface = Py_None; Py_INCREF(Py_None); + p->view.obj = NULL; + if (unlikely(__pyx_memoryview___cinit__(o, a, k) < 0)) goto bad; + return o; + bad: + Py_DECREF(o); o = 0; + return NULL; +} + +static void __pyx_tp_dealloc_memoryview(PyObject *o) { + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc_memoryview) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_memoryview___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->obj); + Py_CLEAR(p->_size); + Py_CLEAR(p->_array_interface); + #if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + (*Py_TYPE(o)->tp_free)(o); + #else + { + freefunc tp_free = (freefunc)PyType_GetSlot(Py_TYPE(o), Py_tp_free); + if (tp_free) tp_free(o); + } + #endif +} + +static int __pyx_tp_traverse_memoryview(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + if (p->obj) { + e = (*v)(p->obj, a); if (e) return e; + } + if (p->_size) { + e = (*v)(p->_size, a); if (e) return e; + } + if (p->_array_interface) { + e = (*v)(p->_array_interface, a); if (e) return e; + } + if (p->view.obj) { + e = (*v)(p->view.obj, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear_memoryview(PyObject *o) { + PyObject* tmp; + struct __pyx_memoryview_obj *p = (struct __pyx_memoryview_obj *)o; + tmp = ((PyObject*)p->obj); + p->obj = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->_size); + p->_size = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + tmp = ((PyObject*)p->_array_interface); + p->_array_interface = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + Py_CLEAR(p->view.obj); + return 0; +} +static PyObject *__pyx_sq_item_memoryview(PyObject *o, Py_ssize_t i) { + PyObject *r; + PyObject *x = PyInt_FromSsize_t(i); if(!x) return 0; + r = Py_TYPE(o)->tp_as_mapping->mp_subscript(o, x); + Py_DECREF(x); + return r; +} + +static int __pyx_mp_ass_subscript_memoryview(PyObject *o, PyObject *i, PyObject *v) { + if (v) { + return __pyx_memoryview___setitem__(o, i, v); + } + else { + __Pyx_TypeName o_type_name; + o_type_name = __Pyx_PyType_GetName(Py_TYPE(o)); + PyErr_Format(PyExc_NotImplementedError, + "Subscript deletion not supported by " __Pyx_FMT_TYPENAME, o_type_name); + __Pyx_DECREF_TypeName(o_type_name); + return -1; + } +} + +static PyObject *__pyx_getprop___pyx_memoryview_T(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_1T_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_base(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4base_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_shape(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_5shape_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_strides(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_7strides_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_suboffsets(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_10suboffsets_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_ndim(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4ndim_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_itemsize(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_8itemsize_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_nbytes(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_6nbytes_1__get__(o); +} + +static PyObject *__pyx_getprop___pyx_memoryview_size(PyObject *o, CYTHON_UNUSED void *x) { + return __pyx_pw_15View_dot_MemoryView_10memoryview_4size_1__get__(o); +} + +static PyObject *__pyx_specialmethod___pyx_memoryview___repr__(PyObject *self, CYTHON_UNUSED PyObject *arg) { + return __pyx_memoryview___repr__(self); +} + +static PyMethodDef __pyx_methods_memoryview[] = { + {"__repr__", (PyCFunction)__pyx_specialmethod___pyx_memoryview___repr__, METH_NOARGS|METH_COEXIST, 0}, + {"is_c_contig", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_is_c_contig, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"is_f_contig", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_is_f_contig, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"copy", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_copy, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"copy_fortran", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_memoryview_copy_fortran, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryview_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryview_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; + +static struct PyGetSetDef __pyx_getsets_memoryview[] = { + {(char *)"T", __pyx_getprop___pyx_memoryview_T, 0, (char *)0, 0}, + {(char *)"base", __pyx_getprop___pyx_memoryview_base, 0, (char *)0, 0}, + {(char *)"shape", __pyx_getprop___pyx_memoryview_shape, 0, (char *)0, 0}, + {(char *)"strides", __pyx_getprop___pyx_memoryview_strides, 0, (char *)0, 0}, + {(char *)"suboffsets", __pyx_getprop___pyx_memoryview_suboffsets, 0, (char *)0, 0}, + {(char *)"ndim", __pyx_getprop___pyx_memoryview_ndim, 0, (char *)0, 0}, + {(char *)"itemsize", __pyx_getprop___pyx_memoryview_itemsize, 0, (char *)0, 0}, + {(char *)"nbytes", __pyx_getprop___pyx_memoryview_nbytes, 0, (char *)0, 0}, + {(char *)"size", __pyx_getprop___pyx_memoryview_size, 0, (char *)0, 0}, + {0, 0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +#if !CYTHON_COMPILING_IN_LIMITED_API + +static PyBufferProcs __pyx_tp_as_buffer_memoryview = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_memoryview_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; +#endif +static PyType_Slot __pyx_type___pyx_memoryview_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc_memoryview}, + {Py_tp_repr, (void *)__pyx_memoryview___repr__}, + {Py_sq_length, (void *)__pyx_memoryview___len__}, + {Py_sq_item, (void *)__pyx_sq_item_memoryview}, + {Py_mp_length, (void *)__pyx_memoryview___len__}, + {Py_mp_subscript, (void *)__pyx_memoryview___getitem__}, + {Py_mp_ass_subscript, (void *)__pyx_mp_ass_subscript_memoryview}, + {Py_tp_str, (void *)__pyx_memoryview___str__}, + #if defined(Py_bf_getbuffer) + {Py_bf_getbuffer, (void *)__pyx_memoryview_getbuffer}, + #endif + {Py_tp_traverse, (void *)__pyx_tp_traverse_memoryview}, + {Py_tp_clear, (void *)__pyx_tp_clear_memoryview}, + {Py_tp_methods, (void *)__pyx_methods_memoryview}, + {Py_tp_getset, (void *)__pyx_getsets_memoryview}, + {Py_tp_new, (void *)__pyx_tp_new_memoryview}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_memoryview_spec = { + "TTS.tts.utils.monotonic_align.core.memoryview", + sizeof(struct __pyx_memoryview_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, + __pyx_type___pyx_memoryview_slots, +}; +#else + +static PySequenceMethods __pyx_tp_as_sequence_memoryview = { + __pyx_memoryview___len__, /*sq_length*/ + 0, /*sq_concat*/ + 0, /*sq_repeat*/ + __pyx_sq_item_memoryview, /*sq_item*/ + 0, /*sq_slice*/ + 0, /*sq_ass_item*/ + 0, /*sq_ass_slice*/ + 0, /*sq_contains*/ + 0, /*sq_inplace_concat*/ + 0, /*sq_inplace_repeat*/ +}; + +static PyMappingMethods __pyx_tp_as_mapping_memoryview = { + __pyx_memoryview___len__, /*mp_length*/ + __pyx_memoryview___getitem__, /*mp_subscript*/ + __pyx_mp_ass_subscript_memoryview, /*mp_ass_subscript*/ +}; + +static PyBufferProcs __pyx_tp_as_buffer_memoryview = { + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getreadbuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getwritebuffer*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getsegcount*/ + #endif + #if PY_MAJOR_VERSION < 3 + 0, /*bf_getcharbuffer*/ + #endif + __pyx_memoryview_getbuffer, /*bf_getbuffer*/ + 0, /*bf_releasebuffer*/ +}; + +static PyTypeObject __pyx_type___pyx_memoryview = { + PyVarObject_HEAD_INIT(0, 0) + "TTS.tts.utils.monotonic_align.core.""memoryview", /*tp_name*/ + sizeof(struct __pyx_memoryview_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc_memoryview, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + __pyx_memoryview___repr__, /*tp_repr*/ + 0, /*tp_as_number*/ + &__pyx_tp_as_sequence_memoryview, /*tp_as_sequence*/ + &__pyx_tp_as_mapping_memoryview, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + __pyx_memoryview___str__, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + &__pyx_tp_as_buffer_memoryview, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ + 0, /*tp_doc*/ + __pyx_tp_traverse_memoryview, /*tp_traverse*/ + __pyx_tp_clear_memoryview, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods_memoryview, /*tp_methods*/ + 0, /*tp_members*/ + __pyx_getsets_memoryview, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new_memoryview, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if PY_VERSION_HEX >= 0x030d00A4 + 0, /*tp_versions_used*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif +static struct __pyx_vtabstruct__memoryviewslice __pyx_vtable__memoryviewslice; + +static PyObject *__pyx_tp_new__memoryviewslice(PyTypeObject *t, PyObject *a, PyObject *k) { + struct __pyx_memoryviewslice_obj *p; + PyObject *o = __pyx_tp_new_memoryview(t, a, k); + if (unlikely(!o)) return 0; + p = ((struct __pyx_memoryviewslice_obj *)o); + p->__pyx_base.__pyx_vtab = (struct __pyx_vtabstruct_memoryview*)__pyx_vtabptr__memoryviewslice; + p->from_object = Py_None; Py_INCREF(Py_None); + p->from_slice.memview = NULL; + return o; +} + +static void __pyx_tp_dealloc__memoryviewslice(PyObject *o) { + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + #if CYTHON_USE_TP_FINALIZE + if (unlikely((PY_VERSION_HEX >= 0x03080000 || __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_HAVE_FINALIZE)) && __Pyx_PyObject_GetSlot(o, tp_finalize, destructor)) && !__Pyx_PyObject_GC_IsFinalized(o)) { + if (__Pyx_PyObject_GetSlot(o, tp_dealloc, destructor) == __pyx_tp_dealloc__memoryviewslice) { + if (PyObject_CallFinalizerFromDealloc(o)) return; + } + } + #endif + PyObject_GC_UnTrack(o); + { + PyObject *etype, *eval, *etb; + PyErr_Fetch(&etype, &eval, &etb); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) + 1); + __pyx_memoryviewslice___dealloc__(o); + __Pyx_SET_REFCNT(o, Py_REFCNT(o) - 1); + PyErr_Restore(etype, eval, etb); + } + Py_CLEAR(p->from_object); + PyObject_GC_Track(o); + __pyx_tp_dealloc_memoryview(o); +} + +static int __pyx_tp_traverse__memoryviewslice(PyObject *o, visitproc v, void *a) { + int e; + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + e = __pyx_tp_traverse_memoryview(o, v, a); if (e) return e; + if (p->from_object) { + e = (*v)(p->from_object, a); if (e) return e; + } + return 0; +} + +static int __pyx_tp_clear__memoryviewslice(PyObject *o) { + PyObject* tmp; + struct __pyx_memoryviewslice_obj *p = (struct __pyx_memoryviewslice_obj *)o; + __pyx_tp_clear_memoryview(o); + tmp = ((PyObject*)p->from_object); + p->from_object = Py_None; Py_INCREF(Py_None); + Py_XDECREF(tmp); + __PYX_XCLEAR_MEMVIEW(&p->from_slice, 1); + return 0; +} + +static PyMethodDef __pyx_methods__memoryviewslice[] = { + {"__reduce_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryviewslice_1__reduce_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {"__setstate_cython__", (PyCFunction)(void*)(__Pyx_PyCFunction_FastCallWithKeywords)__pyx_pw___pyx_memoryviewslice_3__setstate_cython__, __Pyx_METH_FASTCALL|METH_KEYWORDS, 0}, + {0, 0, 0, 0} +}; +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_type___pyx_memoryviewslice_slots[] = { + {Py_tp_dealloc, (void *)__pyx_tp_dealloc__memoryviewslice}, + {Py_tp_doc, (void *)PyDoc_STR("Internal class for passing memoryview slices to Python")}, + {Py_tp_traverse, (void *)__pyx_tp_traverse__memoryviewslice}, + {Py_tp_clear, (void *)__pyx_tp_clear__memoryviewslice}, + {Py_tp_methods, (void *)__pyx_methods__memoryviewslice}, + {Py_tp_new, (void *)__pyx_tp_new__memoryviewslice}, + {0, 0}, +}; +static PyType_Spec __pyx_type___pyx_memoryviewslice_spec = { + "TTS.tts.utils.monotonic_align.core._memoryviewslice", + sizeof(struct __pyx_memoryviewslice_obj), + 0, + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC|Py_TPFLAGS_SEQUENCE, + __pyx_type___pyx_memoryviewslice_slots, +}; +#else + +static PyTypeObject __pyx_type___pyx_memoryviewslice = { + PyVarObject_HEAD_INIT(0, 0) + "TTS.tts.utils.monotonic_align.core.""_memoryviewslice", /*tp_name*/ + sizeof(struct __pyx_memoryviewslice_obj), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + __pyx_tp_dealloc__memoryviewslice, /*tp_dealloc*/ + #if PY_VERSION_HEX < 0x030800b4 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030800b4 + 0, /*tp_vectorcall_offset*/ + #endif + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + #if PY_MAJOR_VERSION < 3 + 0, /*tp_compare*/ + #endif + #if PY_MAJOR_VERSION >= 3 + 0, /*tp_as_async*/ + #endif + #if CYTHON_COMPILING_IN_PYPY || 0 + __pyx_memoryview___repr__, /*tp_repr*/ + #else + 0, /*tp_repr*/ + #endif + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash*/ + 0, /*tp_call*/ + #if CYTHON_COMPILING_IN_PYPY || 0 + __pyx_memoryview___str__, /*tp_str*/ + #else + 0, /*tp_str*/ + #endif + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT|Py_TPFLAGS_HAVE_VERSION_TAG|Py_TPFLAGS_CHECKTYPES|Py_TPFLAGS_HAVE_NEWBUFFER|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC|Py_TPFLAGS_SEQUENCE, /*tp_flags*/ + PyDoc_STR("Internal class for passing memoryview slices to Python"), /*tp_doc*/ + __pyx_tp_traverse__memoryviewslice, /*tp_traverse*/ + __pyx_tp_clear__memoryviewslice, /*tp_clear*/ + 0, /*tp_richcompare*/ + 0, /*tp_weaklistoffset*/ + 0, /*tp_iter*/ + 0, /*tp_iternext*/ + __pyx_methods__memoryviewslice, /*tp_methods*/ + 0, /*tp_members*/ + 0, /*tp_getset*/ + 0, /*tp_base*/ + 0, /*tp_dict*/ + 0, /*tp_descr_get*/ + 0, /*tp_descr_set*/ + #if !CYTHON_USE_TYPE_SPECS + 0, /*tp_dictoffset*/ + #endif + 0, /*tp_init*/ + 0, /*tp_alloc*/ + __pyx_tp_new__memoryviewslice, /*tp_new*/ + 0, /*tp_free*/ + 0, /*tp_is_gc*/ + 0, /*tp_bases*/ + 0, /*tp_mro*/ + 0, /*tp_cache*/ + 0, /*tp_subclasses*/ + 0, /*tp_weaklist*/ + 0, /*tp_del*/ + 0, /*tp_version_tag*/ + #if PY_VERSION_HEX >= 0x030400a1 + #if CYTHON_USE_TP_FINALIZE + 0, /*tp_finalize*/ + #else + NULL, /*tp_finalize*/ + #endif + #endif + #if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, /*tp_vectorcall*/ + #endif + #if __PYX_NEED_TP_PRINT_SLOT == 1 + 0, /*tp_print*/ + #endif + #if PY_VERSION_HEX >= 0x030C0000 + 0, /*tp_watched*/ + #endif + #if PY_VERSION_HEX >= 0x030d00A4 + 0, /*tp_versions_used*/ + #endif + #if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, /*tp_pypy_flags*/ + #endif +}; +#endif + +static PyMethodDef __pyx_methods[] = { + {0, 0, 0, 0} +}; +#ifndef CYTHON_SMALL_CODE +#if defined(__clang__) + #define CYTHON_SMALL_CODE +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define CYTHON_SMALL_CODE __attribute__((cold)) +#else + #define CYTHON_SMALL_CODE +#endif +#endif +/* #### Code section: pystring_table ### */ + +static int __Pyx_CreateStringTabAndInitStrings(void) { + __Pyx_StringTabEntry __pyx_string_tab[] = { + {&__pyx_kp_u_, __pyx_k_, sizeof(__pyx_k_), 0, 1, 0, 0}, + {&__pyx_n_s_ASCII, __pyx_k_ASCII, sizeof(__pyx_k_ASCII), 0, 0, 1, 1}, + {&__pyx_kp_s_All_dimensions_preceding_dimensi, __pyx_k_All_dimensions_preceding_dimensi, sizeof(__pyx_k_All_dimensions_preceding_dimensi), 0, 0, 1, 0}, + {&__pyx_n_s_AssertionError, __pyx_k_AssertionError, sizeof(__pyx_k_AssertionError), 0, 0, 1, 1}, + {&__pyx_kp_s_Buffer_view_does_not_expose_stri, __pyx_k_Buffer_view_does_not_expose_stri, sizeof(__pyx_k_Buffer_view_does_not_expose_stri), 0, 0, 1, 0}, + {&__pyx_kp_s_Can_only_create_a_buffer_that_is, __pyx_k_Can_only_create_a_buffer_that_is, sizeof(__pyx_k_Can_only_create_a_buffer_that_is), 0, 0, 1, 0}, + {&__pyx_kp_s_Cannot_assign_to_read_only_memor, __pyx_k_Cannot_assign_to_read_only_memor, sizeof(__pyx_k_Cannot_assign_to_read_only_memor), 0, 0, 1, 0}, + {&__pyx_kp_s_Cannot_create_writable_memory_vi, __pyx_k_Cannot_create_writable_memory_vi, sizeof(__pyx_k_Cannot_create_writable_memory_vi), 0, 0, 1, 0}, + {&__pyx_kp_u_Cannot_index_with_type, __pyx_k_Cannot_index_with_type, sizeof(__pyx_k_Cannot_index_with_type), 0, 1, 0, 0}, + {&__pyx_kp_s_Cannot_transpose_memoryview_with, __pyx_k_Cannot_transpose_memoryview_with, sizeof(__pyx_k_Cannot_transpose_memoryview_with), 0, 0, 1, 0}, + {&__pyx_kp_s_Dimension_d_is_not_direct, __pyx_k_Dimension_d_is_not_direct, sizeof(__pyx_k_Dimension_d_is_not_direct), 0, 0, 1, 0}, + {&__pyx_n_s_Ellipsis, __pyx_k_Ellipsis, sizeof(__pyx_k_Ellipsis), 0, 0, 1, 1}, + {&__pyx_kp_s_Empty_shape_tuple_for_cython_arr, __pyx_k_Empty_shape_tuple_for_cython_arr, sizeof(__pyx_k_Empty_shape_tuple_for_cython_arr), 0, 0, 1, 0}, + {&__pyx_n_s_ImportError, __pyx_k_ImportError, sizeof(__pyx_k_ImportError), 0, 0, 1, 1}, + {&__pyx_kp_s_Incompatible_checksums_0x_x_vs_0, __pyx_k_Incompatible_checksums_0x_x_vs_0, sizeof(__pyx_k_Incompatible_checksums_0x_x_vs_0), 0, 0, 1, 0}, + {&__pyx_n_s_IndexError, __pyx_k_IndexError, sizeof(__pyx_k_IndexError), 0, 0, 1, 1}, + {&__pyx_kp_s_Index_out_of_bounds_axis_d, __pyx_k_Index_out_of_bounds_axis_d, sizeof(__pyx_k_Index_out_of_bounds_axis_d), 0, 0, 1, 0}, + {&__pyx_kp_s_Indirect_dimensions_not_supporte, __pyx_k_Indirect_dimensions_not_supporte, sizeof(__pyx_k_Indirect_dimensions_not_supporte), 0, 0, 1, 0}, + {&__pyx_kp_u_Invalid_mode_expected_c_or_fortr, __pyx_k_Invalid_mode_expected_c_or_fortr, sizeof(__pyx_k_Invalid_mode_expected_c_or_fortr), 0, 1, 0, 0}, + {&__pyx_kp_u_Invalid_shape_in_axis, __pyx_k_Invalid_shape_in_axis, sizeof(__pyx_k_Invalid_shape_in_axis), 0, 1, 0, 0}, + {&__pyx_n_s_MemoryError, __pyx_k_MemoryError, sizeof(__pyx_k_MemoryError), 0, 0, 1, 1}, + {&__pyx_kp_s_MemoryView_of_r_at_0x_x, __pyx_k_MemoryView_of_r_at_0x_x, sizeof(__pyx_k_MemoryView_of_r_at_0x_x), 0, 0, 1, 0}, + {&__pyx_kp_s_MemoryView_of_r_object, __pyx_k_MemoryView_of_r_object, sizeof(__pyx_k_MemoryView_of_r_object), 0, 0, 1, 0}, + {&__pyx_n_b_O, __pyx_k_O, sizeof(__pyx_k_O), 0, 0, 0, 1}, + {&__pyx_kp_u_Out_of_bounds_on_buffer_access_a, __pyx_k_Out_of_bounds_on_buffer_access_a, sizeof(__pyx_k_Out_of_bounds_on_buffer_access_a), 0, 1, 0, 0}, + {&__pyx_n_s_PickleError, __pyx_k_PickleError, sizeof(__pyx_k_PickleError), 0, 0, 1, 1}, + {&__pyx_n_s_Sequence, __pyx_k_Sequence, sizeof(__pyx_k_Sequence), 0, 0, 1, 1}, + {&__pyx_kp_s_Step_may_not_be_zero_axis_d, __pyx_k_Step_may_not_be_zero_axis_d, sizeof(__pyx_k_Step_may_not_be_zero_axis_d), 0, 0, 1, 0}, + {&__pyx_kp_s_TTS_tts_utils_monotonic_align_co, __pyx_k_TTS_tts_utils_monotonic_align_co, sizeof(__pyx_k_TTS_tts_utils_monotonic_align_co), 0, 0, 1, 0}, + {&__pyx_n_s_TTS_tts_utils_monotonic_align_co_2, __pyx_k_TTS_tts_utils_monotonic_align_co_2, sizeof(__pyx_k_TTS_tts_utils_monotonic_align_co_2), 0, 0, 1, 1}, + {&__pyx_n_s_TypeError, __pyx_k_TypeError, sizeof(__pyx_k_TypeError), 0, 0, 1, 1}, + {&__pyx_kp_s_Unable_to_convert_item_to_object, __pyx_k_Unable_to_convert_item_to_object, sizeof(__pyx_k_Unable_to_convert_item_to_object), 0, 0, 1, 0}, + {&__pyx_n_s_ValueError, __pyx_k_ValueError, sizeof(__pyx_k_ValueError), 0, 0, 1, 1}, + {&__pyx_n_s_View_MemoryView, __pyx_k_View_MemoryView, sizeof(__pyx_k_View_MemoryView), 0, 0, 1, 1}, + {&__pyx_kp_u__2, __pyx_k__2, sizeof(__pyx_k__2), 0, 1, 0, 0}, + {&__pyx_n_s__25, __pyx_k__25, sizeof(__pyx_k__25), 0, 0, 1, 1}, + {&__pyx_n_s__3, __pyx_k__3, sizeof(__pyx_k__3), 0, 0, 1, 1}, + {&__pyx_kp_u__6, __pyx_k__6, sizeof(__pyx_k__6), 0, 1, 0, 0}, + {&__pyx_kp_u__7, __pyx_k__7, sizeof(__pyx_k__7), 0, 1, 0, 0}, + {&__pyx_n_s_abc, __pyx_k_abc, sizeof(__pyx_k_abc), 0, 0, 1, 1}, + {&__pyx_n_s_allocate_buffer, __pyx_k_allocate_buffer, sizeof(__pyx_k_allocate_buffer), 0, 0, 1, 1}, + {&__pyx_kp_u_and, __pyx_k_and, sizeof(__pyx_k_and), 0, 1, 0, 0}, + {&__pyx_n_s_asyncio_coroutines, __pyx_k_asyncio_coroutines, sizeof(__pyx_k_asyncio_coroutines), 0, 0, 1, 1}, + {&__pyx_n_s_base, __pyx_k_base, sizeof(__pyx_k_base), 0, 0, 1, 1}, + {&__pyx_n_s_c, __pyx_k_c, sizeof(__pyx_k_c), 0, 0, 1, 1}, + {&__pyx_n_u_c, __pyx_k_c, sizeof(__pyx_k_c), 0, 1, 0, 1}, + {&__pyx_n_s_class, __pyx_k_class, sizeof(__pyx_k_class), 0, 0, 1, 1}, + {&__pyx_n_s_class_getitem, __pyx_k_class_getitem, sizeof(__pyx_k_class_getitem), 0, 0, 1, 1}, + {&__pyx_n_s_cline_in_traceback, __pyx_k_cline_in_traceback, sizeof(__pyx_k_cline_in_traceback), 0, 0, 1, 1}, + {&__pyx_n_s_collections, __pyx_k_collections, sizeof(__pyx_k_collections), 0, 0, 1, 1}, + {&__pyx_kp_s_collections_abc, __pyx_k_collections_abc, sizeof(__pyx_k_collections_abc), 0, 0, 1, 0}, + {&__pyx_kp_s_contiguous_and_direct, __pyx_k_contiguous_and_direct, sizeof(__pyx_k_contiguous_and_direct), 0, 0, 1, 0}, + {&__pyx_kp_s_contiguous_and_indirect, __pyx_k_contiguous_and_indirect, sizeof(__pyx_k_contiguous_and_indirect), 0, 0, 1, 0}, + {&__pyx_n_s_count, __pyx_k_count, sizeof(__pyx_k_count), 0, 0, 1, 1}, + {&__pyx_n_s_dict, __pyx_k_dict, sizeof(__pyx_k_dict), 0, 0, 1, 1}, + {&__pyx_kp_u_disable, __pyx_k_disable, sizeof(__pyx_k_disable), 0, 1, 0, 0}, + {&__pyx_n_s_dtype_is_object, __pyx_k_dtype_is_object, sizeof(__pyx_k_dtype_is_object), 0, 0, 1, 1}, + {&__pyx_kp_u_enable, __pyx_k_enable, sizeof(__pyx_k_enable), 0, 1, 0, 0}, + {&__pyx_n_s_encode, __pyx_k_encode, sizeof(__pyx_k_encode), 0, 0, 1, 1}, + {&__pyx_n_s_enumerate, __pyx_k_enumerate, sizeof(__pyx_k_enumerate), 0, 0, 1, 1}, + {&__pyx_n_s_error, __pyx_k_error, sizeof(__pyx_k_error), 0, 0, 1, 1}, + {&__pyx_n_s_flags, __pyx_k_flags, sizeof(__pyx_k_flags), 0, 0, 1, 1}, + {&__pyx_n_s_format, __pyx_k_format, sizeof(__pyx_k_format), 0, 0, 1, 1}, + {&__pyx_n_s_fortran, __pyx_k_fortran, sizeof(__pyx_k_fortran), 0, 0, 1, 1}, + {&__pyx_n_u_fortran, __pyx_k_fortran, sizeof(__pyx_k_fortran), 0, 1, 0, 1}, + {&__pyx_kp_u_gc, __pyx_k_gc, sizeof(__pyx_k_gc), 0, 1, 0, 0}, + {&__pyx_n_s_getstate, __pyx_k_getstate, sizeof(__pyx_k_getstate), 0, 0, 1, 1}, + {&__pyx_kp_u_got, __pyx_k_got, sizeof(__pyx_k_got), 0, 1, 0, 0}, + {&__pyx_kp_u_got_differing_extents_in_dimensi, __pyx_k_got_differing_extents_in_dimensi, sizeof(__pyx_k_got_differing_extents_in_dimensi), 0, 1, 0, 0}, + {&__pyx_n_s_id, __pyx_k_id, sizeof(__pyx_k_id), 0, 0, 1, 1}, + {&__pyx_n_s_import, __pyx_k_import, sizeof(__pyx_k_import), 0, 0, 1, 1}, + {&__pyx_n_s_index, __pyx_k_index, sizeof(__pyx_k_index), 0, 0, 1, 1}, + {&__pyx_n_s_initializing, __pyx_k_initializing, sizeof(__pyx_k_initializing), 0, 0, 1, 1}, + {&__pyx_n_s_is_coroutine, __pyx_k_is_coroutine, sizeof(__pyx_k_is_coroutine), 0, 0, 1, 1}, + {&__pyx_kp_u_isenabled, __pyx_k_isenabled, sizeof(__pyx_k_isenabled), 0, 1, 0, 0}, + {&__pyx_n_s_itemsize, __pyx_k_itemsize, sizeof(__pyx_k_itemsize), 0, 0, 1, 1}, + {&__pyx_kp_s_itemsize_0_for_cython_array, __pyx_k_itemsize_0_for_cython_array, sizeof(__pyx_k_itemsize_0_for_cython_array), 0, 0, 1, 0}, + {&__pyx_n_s_main, __pyx_k_main, sizeof(__pyx_k_main), 0, 0, 1, 1}, + {&__pyx_n_s_max_neg_val, __pyx_k_max_neg_val, sizeof(__pyx_k_max_neg_val), 0, 0, 1, 1}, + {&__pyx_n_s_maximum_path_c, __pyx_k_maximum_path_c, sizeof(__pyx_k_maximum_path_c), 0, 0, 1, 1}, + {&__pyx_n_s_memview, __pyx_k_memview, sizeof(__pyx_k_memview), 0, 0, 1, 1}, + {&__pyx_n_s_mode, __pyx_k_mode, sizeof(__pyx_k_mode), 0, 0, 1, 1}, + {&__pyx_n_s_name, __pyx_k_name, sizeof(__pyx_k_name), 0, 0, 1, 1}, + {&__pyx_n_s_name_2, __pyx_k_name_2, sizeof(__pyx_k_name_2), 0, 0, 1, 1}, + {&__pyx_n_s_ndim, __pyx_k_ndim, sizeof(__pyx_k_ndim), 0, 0, 1, 1}, + {&__pyx_n_s_new, __pyx_k_new, sizeof(__pyx_k_new), 0, 0, 1, 1}, + {&__pyx_kp_s_no_default___reduce___due_to_non, __pyx_k_no_default___reduce___due_to_non, sizeof(__pyx_k_no_default___reduce___due_to_non), 0, 0, 1, 0}, + {&__pyx_n_s_np, __pyx_k_np, sizeof(__pyx_k_np), 0, 0, 1, 1}, + {&__pyx_n_s_numpy, __pyx_k_numpy, sizeof(__pyx_k_numpy), 0, 0, 1, 1}, + {&__pyx_kp_u_numpy__core_multiarray_failed_to, __pyx_k_numpy__core_multiarray_failed_to, sizeof(__pyx_k_numpy__core_multiarray_failed_to), 0, 1, 0, 0}, + {&__pyx_kp_u_numpy__core_umath_failed_to_impo, __pyx_k_numpy__core_umath_failed_to_impo, sizeof(__pyx_k_numpy__core_umath_failed_to_impo), 0, 1, 0, 0}, + {&__pyx_n_s_obj, __pyx_k_obj, sizeof(__pyx_k_obj), 0, 0, 1, 1}, + {&__pyx_n_s_pack, __pyx_k_pack, sizeof(__pyx_k_pack), 0, 0, 1, 1}, + {&__pyx_n_s_paths, __pyx_k_paths, sizeof(__pyx_k_paths), 0, 0, 1, 1}, + {&__pyx_n_s_pickle, __pyx_k_pickle, sizeof(__pyx_k_pickle), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_PickleError, __pyx_k_pyx_PickleError, sizeof(__pyx_k_pyx_PickleError), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_checksum, __pyx_k_pyx_checksum, sizeof(__pyx_k_pyx_checksum), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_result, __pyx_k_pyx_result, sizeof(__pyx_k_pyx_result), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_state, __pyx_k_pyx_state, sizeof(__pyx_k_pyx_state), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_type, __pyx_k_pyx_type, sizeof(__pyx_k_pyx_type), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_unpickle_Enum, __pyx_k_pyx_unpickle_Enum, sizeof(__pyx_k_pyx_unpickle_Enum), 0, 0, 1, 1}, + {&__pyx_n_s_pyx_vtable, __pyx_k_pyx_vtable, sizeof(__pyx_k_pyx_vtable), 0, 0, 1, 1}, + {&__pyx_n_s_range, __pyx_k_range, sizeof(__pyx_k_range), 0, 0, 1, 1}, + {&__pyx_n_s_reduce, __pyx_k_reduce, sizeof(__pyx_k_reduce), 0, 0, 1, 1}, + {&__pyx_n_s_reduce_cython, __pyx_k_reduce_cython, sizeof(__pyx_k_reduce_cython), 0, 0, 1, 1}, + {&__pyx_n_s_reduce_ex, __pyx_k_reduce_ex, sizeof(__pyx_k_reduce_ex), 0, 0, 1, 1}, + {&__pyx_n_s_register, __pyx_k_register, sizeof(__pyx_k_register), 0, 0, 1, 1}, + {&__pyx_n_s_setstate, __pyx_k_setstate, sizeof(__pyx_k_setstate), 0, 0, 1, 1}, + {&__pyx_n_s_setstate_cython, __pyx_k_setstate_cython, sizeof(__pyx_k_setstate_cython), 0, 0, 1, 1}, + {&__pyx_n_s_shape, __pyx_k_shape, sizeof(__pyx_k_shape), 0, 0, 1, 1}, + {&__pyx_n_s_size, __pyx_k_size, sizeof(__pyx_k_size), 0, 0, 1, 1}, + {&__pyx_n_s_spec, __pyx_k_spec, sizeof(__pyx_k_spec), 0, 0, 1, 1}, + {&__pyx_n_s_start, __pyx_k_start, sizeof(__pyx_k_start), 0, 0, 1, 1}, + {&__pyx_n_s_step, __pyx_k_step, sizeof(__pyx_k_step), 0, 0, 1, 1}, + {&__pyx_n_s_stop, __pyx_k_stop, sizeof(__pyx_k_stop), 0, 0, 1, 1}, + {&__pyx_kp_s_strided_and_direct, __pyx_k_strided_and_direct, sizeof(__pyx_k_strided_and_direct), 0, 0, 1, 0}, + {&__pyx_kp_s_strided_and_direct_or_indirect, __pyx_k_strided_and_direct_or_indirect, sizeof(__pyx_k_strided_and_direct_or_indirect), 0, 0, 1, 0}, + {&__pyx_kp_s_strided_and_indirect, __pyx_k_strided_and_indirect, sizeof(__pyx_k_strided_and_indirect), 0, 0, 1, 0}, + {&__pyx_kp_s_stringsource, __pyx_k_stringsource, sizeof(__pyx_k_stringsource), 0, 0, 1, 0}, + {&__pyx_n_s_struct, __pyx_k_struct, sizeof(__pyx_k_struct), 0, 0, 1, 1}, + {&__pyx_n_s_sys, __pyx_k_sys, sizeof(__pyx_k_sys), 0, 0, 1, 1}, + {&__pyx_n_s_t_xs, __pyx_k_t_xs, sizeof(__pyx_k_t_xs), 0, 0, 1, 1}, + {&__pyx_n_s_t_ys, __pyx_k_t_ys, sizeof(__pyx_k_t_ys), 0, 0, 1, 1}, + {&__pyx_n_s_test, __pyx_k_test, sizeof(__pyx_k_test), 0, 0, 1, 1}, + {&__pyx_kp_s_unable_to_allocate_array_data, __pyx_k_unable_to_allocate_array_data, sizeof(__pyx_k_unable_to_allocate_array_data), 0, 0, 1, 0}, + {&__pyx_kp_s_unable_to_allocate_shape_and_str, __pyx_k_unable_to_allocate_shape_and_str, sizeof(__pyx_k_unable_to_allocate_shape_and_str), 0, 0, 1, 0}, + {&__pyx_n_s_unpack, __pyx_k_unpack, sizeof(__pyx_k_unpack), 0, 0, 1, 1}, + {&__pyx_n_s_update, __pyx_k_update, sizeof(__pyx_k_update), 0, 0, 1, 1}, + {&__pyx_n_s_values, __pyx_k_values, sizeof(__pyx_k_values), 0, 0, 1, 1}, + {&__pyx_n_s_version_info, __pyx_k_version_info, sizeof(__pyx_k_version_info), 0, 0, 1, 1}, + {0, 0, 0, 0, 0, 0, 0} + }; + return __Pyx_InitStrings(__pyx_string_tab); +} +/* #### Code section: cached_builtins ### */ +static CYTHON_SMALL_CODE int __Pyx_InitCachedBuiltins(void) { + __pyx_builtin_range = __Pyx_GetBuiltinName(__pyx_n_s_range); if (!__pyx_builtin_range) __PYX_ERR(0, 19, __pyx_L1_error) + __pyx_builtin___import__ = __Pyx_GetBuiltinName(__pyx_n_s_import); if (!__pyx_builtin___import__) __PYX_ERR(1, 100, __pyx_L1_error) + __pyx_builtin_ValueError = __Pyx_GetBuiltinName(__pyx_n_s_ValueError); if (!__pyx_builtin_ValueError) __PYX_ERR(1, 141, __pyx_L1_error) + __pyx_builtin_MemoryError = __Pyx_GetBuiltinName(__pyx_n_s_MemoryError); if (!__pyx_builtin_MemoryError) __PYX_ERR(1, 156, __pyx_L1_error) + __pyx_builtin_enumerate = __Pyx_GetBuiltinName(__pyx_n_s_enumerate); if (!__pyx_builtin_enumerate) __PYX_ERR(1, 159, __pyx_L1_error) + __pyx_builtin_TypeError = __Pyx_GetBuiltinName(__pyx_n_s_TypeError); if (!__pyx_builtin_TypeError) __PYX_ERR(1, 2, __pyx_L1_error) + __pyx_builtin_AssertionError = __Pyx_GetBuiltinName(__pyx_n_s_AssertionError); if (!__pyx_builtin_AssertionError) __PYX_ERR(1, 373, __pyx_L1_error) + __pyx_builtin_Ellipsis = __Pyx_GetBuiltinName(__pyx_n_s_Ellipsis); if (!__pyx_builtin_Ellipsis) __PYX_ERR(1, 408, __pyx_L1_error) + __pyx_builtin_id = __Pyx_GetBuiltinName(__pyx_n_s_id); if (!__pyx_builtin_id) __PYX_ERR(1, 618, __pyx_L1_error) + __pyx_builtin_IndexError = __Pyx_GetBuiltinName(__pyx_n_s_IndexError); if (!__pyx_builtin_IndexError) __PYX_ERR(1, 914, __pyx_L1_error) + __pyx_builtin_ImportError = __Pyx_GetBuiltinName(__pyx_n_s_ImportError); if (!__pyx_builtin_ImportError) __PYX_ERR(2, 1026, __pyx_L1_error) + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: cached_constants ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitCachedConstants(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_InitCachedConstants", 0); + + /* "View.MemoryView":582 + * def suboffsets(self): + * if self.view.suboffsets == NULL: + * return (-1,) * self.view.ndim # <<<<<<<<<<<<<< + * + * return tuple([suboffset for suboffset in self.view.suboffsets[:self.view.ndim]]) + */ + __pyx_tuple__4 = PyTuple_New(1); if (unlikely(!__pyx_tuple__4)) __PYX_ERR(1, 582, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__4); + __Pyx_INCREF(__pyx_int_neg_1); + __Pyx_GIVEREF(__pyx_int_neg_1); + if (__Pyx_PyTuple_SET_ITEM(__pyx_tuple__4, 0, __pyx_int_neg_1)) __PYX_ERR(1, 582, __pyx_L1_error); + __Pyx_GIVEREF(__pyx_tuple__4); + + /* "View.MemoryView":679 + * tup = index if isinstance(index, tuple) else (index,) + * + * result = [slice(None)] * ndim # <<<<<<<<<<<<<< + * have_slices = False + * seen_ellipsis = False + */ + __pyx_slice__5 = PySlice_New(Py_None, Py_None, Py_None); if (unlikely(!__pyx_slice__5)) __PYX_ERR(1, 679, __pyx_L1_error) + __Pyx_GOTREF(__pyx_slice__5); + __Pyx_GIVEREF(__pyx_slice__5); + + /* "(tree fragment)":4 + * cdef object __pyx_PickleError + * cdef object __pyx_result + * if __pyx_checksum not in (0x82a3537, 0x6ae9995, 0xb068931): # <<<<<<<<<<<<<< + * from pickle import PickleError as __pyx_PickleError + * raise __pyx_PickleError, "Incompatible checksums (0x%x vs (0x82a3537, 0x6ae9995, 0xb068931) = (name))" % __pyx_checksum + */ + __pyx_tuple__8 = PyTuple_Pack(3, __pyx_int_136983863, __pyx_int_112105877, __pyx_int_184977713); if (unlikely(!__pyx_tuple__8)) __PYX_ERR(1, 4, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__8); + __Pyx_GIVEREF(__pyx_tuple__8); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1026 + * __pyx_import_array() + * except Exception: + * raise ImportError("numpy._core.multiarray failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_umath() except -1: + */ + __pyx_tuple__9 = PyTuple_Pack(1, __pyx_kp_u_numpy__core_multiarray_failed_to); if (unlikely(!__pyx_tuple__9)) __PYX_ERR(2, 1026, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__9); + __Pyx_GIVEREF(__pyx_tuple__9); + + /* "../../../../../tmp/build-env-j3n9x19u/lib/python3.9/site-packages/numpy/__init__.cython-30.pxd":1032 + * _import_umath() + * except Exception: + * raise ImportError("numpy._core.umath failed to import") # <<<<<<<<<<<<<< + * + * cdef inline int import_ufunc() except -1: + */ + __pyx_tuple__10 = PyTuple_Pack(1, __pyx_kp_u_numpy__core_umath_failed_to_impo); if (unlikely(!__pyx_tuple__10)) __PYX_ERR(2, 1032, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__10); + __Pyx_GIVEREF(__pyx_tuple__10); + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + __pyx_tuple__12 = PyTuple_Pack(1, __pyx_n_s_sys); if (unlikely(!__pyx_tuple__12)) __PYX_ERR(1, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__12); + __Pyx_GIVEREF(__pyx_tuple__12); + __pyx_tuple__13 = PyTuple_Pack(2, __pyx_int_3, __pyx_int_3); if (unlikely(!__pyx_tuple__13)) __PYX_ERR(1, 100, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__13); + __Pyx_GIVEREF(__pyx_tuple__13); + + /* "View.MemoryView":101 + * try: + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence # <<<<<<<<<<<<<< + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + */ + __pyx_tuple__14 = PyTuple_Pack(1, __pyx_kp_s_collections_abc); if (unlikely(!__pyx_tuple__14)) __PYX_ERR(1, 101, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__14); + __Pyx_GIVEREF(__pyx_tuple__14); + + /* "View.MemoryView":103 + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence # <<<<<<<<<<<<<< + * except: + * + */ + __pyx_tuple__15 = PyTuple_Pack(1, __pyx_n_s_collections); if (unlikely(!__pyx_tuple__15)) __PYX_ERR(1, 103, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__15); + __Pyx_GIVEREF(__pyx_tuple__15); + + /* "View.MemoryView":309 + * return self.name + * + * cdef generic = Enum("") # <<<<<<<<<<<<<< + * cdef strided = Enum("") # default + * cdef indirect = Enum("") + */ + __pyx_tuple__16 = PyTuple_Pack(1, __pyx_kp_s_strided_and_direct_or_indirect); if (unlikely(!__pyx_tuple__16)) __PYX_ERR(1, 309, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__16); + __Pyx_GIVEREF(__pyx_tuple__16); + + /* "View.MemoryView":310 + * + * cdef generic = Enum("") + * cdef strided = Enum("") # default # <<<<<<<<<<<<<< + * cdef indirect = Enum("") + * + */ + __pyx_tuple__17 = PyTuple_Pack(1, __pyx_kp_s_strided_and_direct); if (unlikely(!__pyx_tuple__17)) __PYX_ERR(1, 310, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__17); + __Pyx_GIVEREF(__pyx_tuple__17); + + /* "View.MemoryView":311 + * cdef generic = Enum("") + * cdef strided = Enum("") # default + * cdef indirect = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_tuple__18 = PyTuple_Pack(1, __pyx_kp_s_strided_and_indirect); if (unlikely(!__pyx_tuple__18)) __PYX_ERR(1, 311, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__18); + __Pyx_GIVEREF(__pyx_tuple__18); + + /* "View.MemoryView":314 + * + * + * cdef contiguous = Enum("") # <<<<<<<<<<<<<< + * cdef indirect_contiguous = Enum("") + * + */ + __pyx_tuple__19 = PyTuple_Pack(1, __pyx_kp_s_contiguous_and_direct); if (unlikely(!__pyx_tuple__19)) __PYX_ERR(1, 314, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__19); + __Pyx_GIVEREF(__pyx_tuple__19); + + /* "View.MemoryView":315 + * + * cdef contiguous = Enum("") + * cdef indirect_contiguous = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_tuple__20 = PyTuple_Pack(1, __pyx_kp_s_contiguous_and_indirect); if (unlikely(!__pyx_tuple__20)) __PYX_ERR(1, 315, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__20); + __Pyx_GIVEREF(__pyx_tuple__20); + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_tuple__21 = PyTuple_Pack(5, __pyx_n_s_pyx_type, __pyx_n_s_pyx_checksum, __pyx_n_s_pyx_state, __pyx_n_s_pyx_PickleError, __pyx_n_s_pyx_result); if (unlikely(!__pyx_tuple__21)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__21); + __Pyx_GIVEREF(__pyx_tuple__21); + __pyx_codeobj__22 = (PyObject*)__Pyx_PyCode_New(3, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__21, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_stringsource, __pyx_n_s_pyx_unpickle_Enum, 1, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__22)) __PYX_ERR(1, 1, __pyx_L1_error) + + /* "TTS/tts/utils/monotonic_align/core.pyx":42 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: # <<<<<<<<<<<<<< + * cdef int b = values.shape[0] + * + */ + __pyx_tuple__23 = PyTuple_Pack(5, __pyx_n_s_paths, __pyx_n_s_values, __pyx_n_s_t_xs, __pyx_n_s_t_ys, __pyx_n_s_max_neg_val); if (unlikely(!__pyx_tuple__23)) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_GOTREF(__pyx_tuple__23); + __Pyx_GIVEREF(__pyx_tuple__23); + __pyx_codeobj__24 = (PyObject*)__Pyx_PyCode_New(5, 0, 0, 5, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__23, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_TTS_tts_utils_monotonic_align_co, __pyx_n_s_maximum_path_c, 42, __pyx_empty_bytes); if (unlikely(!__pyx_codeobj__24)) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_RefNannyFinishContext(); + return -1; +} +/* #### Code section: init_constants ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitConstants(void) { + if (__Pyx_CreateStringTabAndInitStrings() < 0) __PYX_ERR(0, 1, __pyx_L1_error); + __pyx_int_0 = PyInt_FromLong(0); if (unlikely(!__pyx_int_0)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_1 = PyInt_FromLong(1); if (unlikely(!__pyx_int_1)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_3 = PyInt_FromLong(3); if (unlikely(!__pyx_int_3)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_112105877 = PyInt_FromLong(112105877L); if (unlikely(!__pyx_int_112105877)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_136983863 = PyInt_FromLong(136983863L); if (unlikely(!__pyx_int_136983863)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_184977713 = PyInt_FromLong(184977713L); if (unlikely(!__pyx_int_184977713)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_int_neg_1 = PyInt_FromLong(-1); if (unlikely(!__pyx_int_neg_1)) __PYX_ERR(0, 1, __pyx_L1_error) + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: init_globals ### */ + +static CYTHON_SMALL_CODE int __Pyx_InitGlobals(void) { + /* AssertionsEnabled.init */ + if (likely(__Pyx_init_assertions_enabled() == 0)); else + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + /* NumpyImportArray.init */ + /* + * Cython has automatically inserted a call to _import_array since + * you didn't include one when you cimported numpy. To disable this + * add the line + * numpy._import_array + */ +#ifdef NPY_FEATURE_VERSION +#ifndef NO_IMPORT_ARRAY +if (unlikely(_import_array() == -1)) { + PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import " + "(auto-generated because you didn't call 'numpy.import_array()' after cimporting numpy; " + "use 'numpy._import_array' to disable if you are certain you don't need it)."); +} +#endif +#endif + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + /* InitThreads.init */ + #if defined(WITH_THREAD) && PY_VERSION_HEX < 0x030700F0 +PyEval_InitThreads(); +#endif + +if (unlikely(PyErr_Occurred())) __PYX_ERR(0, 1, __pyx_L1_error) + + return 0; + __pyx_L1_error:; + return -1; +} +/* #### Code section: init_module ### */ + +static CYTHON_SMALL_CODE int __Pyx_modinit_global_init_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_variable_export_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_function_export_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_type_init_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_type_import_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_variable_import_code(void); /*proto*/ +static CYTHON_SMALL_CODE int __Pyx_modinit_function_import_code(void); /*proto*/ + +static int __Pyx_modinit_global_init_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_global_init_code", 0); + /*--- Global init code ---*/ + __pyx_collections_abc_Sequence = Py_None; Py_INCREF(Py_None); + generic = Py_None; Py_INCREF(Py_None); + strided = Py_None; Py_INCREF(Py_None); + indirect = Py_None; Py_INCREF(Py_None); + contiguous = Py_None; Py_INCREF(Py_None); + indirect_contiguous = Py_None; Py_INCREF(Py_None); + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_variable_export_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_variable_export_code", 0); + /*--- Variable export code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_function_export_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_function_export_code", 0); + /*--- Function export code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_type_init_code(void) { + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__Pyx_modinit_type_init_code", 0); + /*--- Type init code ---*/ + __pyx_vtabptr_array = &__pyx_vtable_array; + __pyx_vtable_array.get_memview = (PyObject *(*)(struct __pyx_array_obj *))__pyx_array_get_memview; + #if CYTHON_USE_TYPE_SPECS + __pyx_array_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_array_spec, NULL); if (unlikely(!__pyx_array_type)) __PYX_ERR(1, 114, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_array_type->tp_as_buffer = &__pyx_tp_as_buffer_array; + if (!__pyx_array_type->tp_as_buffer->bf_releasebuffer && __pyx_array_type->tp_base->tp_as_buffer && __pyx_array_type->tp_base->tp_as_buffer->bf_releasebuffer) { + __pyx_array_type->tp_as_buffer->bf_releasebuffer = __pyx_array_type->tp_base->tp_as_buffer->bf_releasebuffer; + } + #elif defined(Py_bf_getbuffer) && defined(Py_bf_releasebuffer) + /* PY_VERSION_HEX >= 0x03090000 || Py_LIMITED_API >= 0x030B0000 */ + #elif defined(_MSC_VER) + #pragma message ("The buffer protocol is not supported in the Limited C-API < 3.11.") + #else + #warning "The buffer protocol is not supported in the Limited C-API < 3.11." + #endif + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_array_spec, __pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #else + __pyx_array_type = &__pyx_type___pyx_array; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_array_type->tp_print = 0; + #endif + if (__Pyx_SetVtable(__pyx_array_type, __pyx_vtabptr_array) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_array_type) < 0) __PYX_ERR(1, 114, __pyx_L1_error) + #endif + #if CYTHON_USE_TYPE_SPECS + __pyx_MemviewEnum_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_MemviewEnum_spec, NULL); if (unlikely(!__pyx_MemviewEnum_type)) __PYX_ERR(1, 302, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_MemviewEnum_spec, __pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #else + __pyx_MemviewEnum_type = &__pyx_type___pyx_MemviewEnum; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_MemviewEnum_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_MemviewEnum_type->tp_dictoffset && __pyx_MemviewEnum_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_MemviewEnum_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_MemviewEnum_type) < 0) __PYX_ERR(1, 302, __pyx_L1_error) + #endif + __pyx_vtabptr_memoryview = &__pyx_vtable_memoryview; + __pyx_vtable_memoryview.get_item_pointer = (char *(*)(struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_get_item_pointer; + __pyx_vtable_memoryview.is_slice = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_is_slice; + __pyx_vtable_memoryview.setitem_slice_assignment = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *, PyObject *))__pyx_memoryview_setitem_slice_assignment; + __pyx_vtable_memoryview.setitem_slice_assign_scalar = (PyObject *(*)(struct __pyx_memoryview_obj *, struct __pyx_memoryview_obj *, PyObject *))__pyx_memoryview_setitem_slice_assign_scalar; + __pyx_vtable_memoryview.setitem_indexed = (PyObject *(*)(struct __pyx_memoryview_obj *, PyObject *, PyObject *))__pyx_memoryview_setitem_indexed; + __pyx_vtable_memoryview.convert_item_to_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *))__pyx_memoryview_convert_item_to_object; + __pyx_vtable_memoryview.assign_item_from_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *, PyObject *))__pyx_memoryview_assign_item_from_object; + __pyx_vtable_memoryview._get_base = (PyObject *(*)(struct __pyx_memoryview_obj *))__pyx_memoryview__get_base; + #if CYTHON_USE_TYPE_SPECS + __pyx_memoryview_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_memoryview_spec, NULL); if (unlikely(!__pyx_memoryview_type)) __PYX_ERR(1, 337, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_memoryview_type->tp_as_buffer = &__pyx_tp_as_buffer_memoryview; + if (!__pyx_memoryview_type->tp_as_buffer->bf_releasebuffer && __pyx_memoryview_type->tp_base->tp_as_buffer && __pyx_memoryview_type->tp_base->tp_as_buffer->bf_releasebuffer) { + __pyx_memoryview_type->tp_as_buffer->bf_releasebuffer = __pyx_memoryview_type->tp_base->tp_as_buffer->bf_releasebuffer; + } + #elif defined(Py_bf_getbuffer) && defined(Py_bf_releasebuffer) + /* PY_VERSION_HEX >= 0x03090000 || Py_LIMITED_API >= 0x030B0000 */ + #elif defined(_MSC_VER) + #pragma message ("The buffer protocol is not supported in the Limited C-API < 3.11.") + #else + #warning "The buffer protocol is not supported in the Limited C-API < 3.11." + #endif + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_memoryview_spec, __pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #else + __pyx_memoryview_type = &__pyx_type___pyx_memoryview; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_memoryview_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_memoryview_type->tp_dictoffset && __pyx_memoryview_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_memoryview_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_memoryview_type, __pyx_vtabptr_memoryview) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_memoryview_type) < 0) __PYX_ERR(1, 337, __pyx_L1_error) + #endif + __pyx_vtabptr__memoryviewslice = &__pyx_vtable__memoryviewslice; + __pyx_vtable__memoryviewslice.__pyx_base = *__pyx_vtabptr_memoryview; + __pyx_vtable__memoryviewslice.__pyx_base.convert_item_to_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *))__pyx_memoryviewslice_convert_item_to_object; + __pyx_vtable__memoryviewslice.__pyx_base.assign_item_from_object = (PyObject *(*)(struct __pyx_memoryview_obj *, char *, PyObject *))__pyx_memoryviewslice_assign_item_from_object; + __pyx_vtable__memoryviewslice.__pyx_base._get_base = (PyObject *(*)(struct __pyx_memoryview_obj *))__pyx_memoryviewslice__get_base; + #if CYTHON_USE_TYPE_SPECS + __pyx_t_1 = PyTuple_Pack(1, (PyObject *)__pyx_memoryview_type); if (unlikely(!__pyx_t_1)) __PYX_ERR(1, 952, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_memoryviewslice_type = (PyTypeObject *) __Pyx_PyType_FromModuleAndSpec(__pyx_m, &__pyx_type___pyx_memoryviewslice_spec, __pyx_t_1); + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + if (unlikely(!__pyx_memoryviewslice_type)) __PYX_ERR(1, 952, __pyx_L1_error) + if (__Pyx_fix_up_extension_type_from_spec(&__pyx_type___pyx_memoryviewslice_spec, __pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #else + __pyx_memoryviewslice_type = &__pyx_type___pyx_memoryviewslice; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + __pyx_memoryviewslice_type->tp_base = __pyx_memoryview_type; + #endif + #if !CYTHON_USE_TYPE_SPECS + if (__Pyx_PyType_Ready(__pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + #if PY_MAJOR_VERSION < 3 + __pyx_memoryviewslice_type->tp_print = 0; + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if ((CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP) && likely(!__pyx_memoryviewslice_type->tp_dictoffset && __pyx_memoryviewslice_type->tp_getattro == PyObject_GenericGetAttr)) { + __pyx_memoryviewslice_type->tp_getattro = __Pyx_PyObject_GenericGetAttr; + } + #endif + if (__Pyx_SetVtable(__pyx_memoryviewslice_type, __pyx_vtabptr__memoryviewslice) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_MergeVtables(__pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + #if !CYTHON_COMPILING_IN_LIMITED_API + if (__Pyx_setup_reduce((PyObject *) __pyx_memoryviewslice_type) < 0) __PYX_ERR(1, 952, __pyx_L1_error) + #endif + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_RefNannyFinishContext(); + return -1; +} + +static int __Pyx_modinit_type_import_code(void) { + __Pyx_RefNannyDeclarations + PyObject *__pyx_t_1 = NULL; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannySetupContext("__Pyx_modinit_type_import_code", 0); + /*--- Type import code ---*/ + __pyx_t_1 = PyImport_ImportModule(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_t_1)) __PYX_ERR(3, 9, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_ptype_7cpython_4type_type = __Pyx_ImportType_3_0_11(__pyx_t_1, __Pyx_BUILTIN_MODULE_NAME, "type", + #if defined(PYPY_VERSION_NUM) && PYPY_VERSION_NUM < 0x050B0000 + sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyTypeObject), + #elif CYTHON_COMPILING_IN_LIMITED_API + sizeof(PyTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyTypeObject), + #else + sizeof(PyHeapTypeObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyHeapTypeObject), + #endif + __Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_7cpython_4type_type) __PYX_ERR(3, 9, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __pyx_t_1 = PyImport_ImportModule("numpy"); if (unlikely(!__pyx_t_1)) __PYX_ERR(2, 271, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_1); + __pyx_ptype_5numpy_dtype = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "dtype", sizeof(PyArray_Descr), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyArray_Descr),__Pyx_ImportType_CheckSize_Ignore_3_0_11); if (!__pyx_ptype_5numpy_dtype) __PYX_ERR(2, 271, __pyx_L1_error) + __pyx_ptype_5numpy_flatiter = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "flatiter", sizeof(PyArrayIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyArrayIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_11); if (!__pyx_ptype_5numpy_flatiter) __PYX_ERR(2, 316, __pyx_L1_error) + __pyx_ptype_5numpy_broadcast = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "broadcast", sizeof(PyArrayMultiIterObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyArrayMultiIterObject),__Pyx_ImportType_CheckSize_Ignore_3_0_11); if (!__pyx_ptype_5numpy_broadcast) __PYX_ERR(2, 320, __pyx_L1_error) + __pyx_ptype_5numpy_ndarray = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "ndarray", sizeof(PyArrayObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyArrayObject),__Pyx_ImportType_CheckSize_Ignore_3_0_11); if (!__pyx_ptype_5numpy_ndarray) __PYX_ERR(2, 359, __pyx_L1_error) + __pyx_ptype_5numpy_generic = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "generic", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_generic) __PYX_ERR(2, 848, __pyx_L1_error) + __pyx_ptype_5numpy_number = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "number", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_number) __PYX_ERR(2, 850, __pyx_L1_error) + __pyx_ptype_5numpy_integer = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "integer", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_integer) __PYX_ERR(2, 852, __pyx_L1_error) + __pyx_ptype_5numpy_signedinteger = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "signedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_signedinteger) __PYX_ERR(2, 854, __pyx_L1_error) + __pyx_ptype_5numpy_unsignedinteger = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "unsignedinteger", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_unsignedinteger) __PYX_ERR(2, 856, __pyx_L1_error) + __pyx_ptype_5numpy_inexact = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "inexact", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_inexact) __PYX_ERR(2, 858, __pyx_L1_error) + __pyx_ptype_5numpy_floating = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "floating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_floating) __PYX_ERR(2, 860, __pyx_L1_error) + __pyx_ptype_5numpy_complexfloating = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "complexfloating", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_complexfloating) __PYX_ERR(2, 862, __pyx_L1_error) + __pyx_ptype_5numpy_flexible = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "flexible", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_flexible) __PYX_ERR(2, 864, __pyx_L1_error) + __pyx_ptype_5numpy_character = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "character", sizeof(PyObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyObject),__Pyx_ImportType_CheckSize_Warn_3_0_11); if (!__pyx_ptype_5numpy_character) __PYX_ERR(2, 866, __pyx_L1_error) + __pyx_ptype_5numpy_ufunc = __Pyx_ImportType_3_0_11(__pyx_t_1, "numpy", "ufunc", sizeof(PyUFuncObject), __PYX_GET_STRUCT_ALIGNMENT_3_0_11(PyUFuncObject),__Pyx_ImportType_CheckSize_Ignore_3_0_11); if (!__pyx_ptype_5numpy_ufunc) __PYX_ERR(2, 930, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_RefNannyFinishContext(); + return 0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_1); + __Pyx_RefNannyFinishContext(); + return -1; +} + +static int __Pyx_modinit_variable_import_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_variable_import_code", 0); + /*--- Variable import code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + +static int __Pyx_modinit_function_import_code(void) { + __Pyx_RefNannyDeclarations + __Pyx_RefNannySetupContext("__Pyx_modinit_function_import_code", 0); + /*--- Function import code ---*/ + __Pyx_RefNannyFinishContext(); + return 0; +} + + +#if PY_MAJOR_VERSION >= 3 +#if CYTHON_PEP489_MULTI_PHASE_INIT +static PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def); /*proto*/ +static int __pyx_pymod_exec_core(PyObject* module); /*proto*/ +static PyModuleDef_Slot __pyx_moduledef_slots[] = { + {Py_mod_create, (void*)__pyx_pymod_create}, + {Py_mod_exec, (void*)__pyx_pymod_exec_core}, + {0, NULL} +}; +#endif + +#ifdef __cplusplus +namespace { + struct PyModuleDef __pyx_moduledef = + #else + static struct PyModuleDef __pyx_moduledef = + #endif + { + PyModuleDef_HEAD_INIT, + "core", + 0, /* m_doc */ + #if CYTHON_PEP489_MULTI_PHASE_INIT + 0, /* m_size */ + #elif CYTHON_USE_MODULE_STATE + sizeof(__pyx_mstate), /* m_size */ + #else + -1, /* m_size */ + #endif + __pyx_methods /* m_methods */, + #if CYTHON_PEP489_MULTI_PHASE_INIT + __pyx_moduledef_slots, /* m_slots */ + #else + NULL, /* m_reload */ + #endif + #if CYTHON_USE_MODULE_STATE + __pyx_m_traverse, /* m_traverse */ + __pyx_m_clear, /* m_clear */ + NULL /* m_free */ + #else + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL /* m_free */ + #endif + }; + #ifdef __cplusplus +} /* anonymous namespace */ +#endif +#endif + +#ifndef CYTHON_NO_PYINIT_EXPORT +#define __Pyx_PyMODINIT_FUNC PyMODINIT_FUNC +#elif PY_MAJOR_VERSION < 3 +#ifdef __cplusplus +#define __Pyx_PyMODINIT_FUNC extern "C" void +#else +#define __Pyx_PyMODINIT_FUNC void +#endif +#else +#ifdef __cplusplus +#define __Pyx_PyMODINIT_FUNC extern "C" PyObject * +#else +#define __Pyx_PyMODINIT_FUNC PyObject * +#endif +#endif + + +#if PY_MAJOR_VERSION < 3 +__Pyx_PyMODINIT_FUNC initcore(void) CYTHON_SMALL_CODE; /*proto*/ +__Pyx_PyMODINIT_FUNC initcore(void) +#else +__Pyx_PyMODINIT_FUNC PyInit_core(void) CYTHON_SMALL_CODE; /*proto*/ +__Pyx_PyMODINIT_FUNC PyInit_core(void) +#if CYTHON_PEP489_MULTI_PHASE_INIT +{ + return PyModuleDef_Init(&__pyx_moduledef); +} +static CYTHON_SMALL_CODE int __Pyx_check_single_interpreter(void) { + #if PY_VERSION_HEX >= 0x030700A1 + static PY_INT64_T main_interpreter_id = -1; + PY_INT64_T current_id = PyInterpreterState_GetID(PyThreadState_Get()->interp); + if (main_interpreter_id == -1) { + main_interpreter_id = current_id; + return (unlikely(current_id == -1)) ? -1 : 0; + } else if (unlikely(main_interpreter_id != current_id)) + #else + static PyInterpreterState *main_interpreter = NULL; + PyInterpreterState *current_interpreter = PyThreadState_Get()->interp; + if (!main_interpreter) { + main_interpreter = current_interpreter; + } else if (unlikely(main_interpreter != current_interpreter)) + #endif + { + PyErr_SetString( + PyExc_ImportError, + "Interpreter change detected - this module can only be loaded into one interpreter per process."); + return -1; + } + return 0; +} +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *module, const char* from_name, const char* to_name, int allow_none) +#else +static CYTHON_SMALL_CODE int __Pyx_copy_spec_to_module(PyObject *spec, PyObject *moddict, const char* from_name, const char* to_name, int allow_none) +#endif +{ + PyObject *value = PyObject_GetAttrString(spec, from_name); + int result = 0; + if (likely(value)) { + if (allow_none || value != Py_None) { +#if CYTHON_COMPILING_IN_LIMITED_API + result = PyModule_AddObject(module, to_name, value); +#else + result = PyDict_SetItemString(moddict, to_name, value); +#endif + } + Py_DECREF(value); + } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + } else { + result = -1; + } + return result; +} +static CYTHON_SMALL_CODE PyObject* __pyx_pymod_create(PyObject *spec, PyModuleDef *def) { + PyObject *module = NULL, *moddict, *modname; + CYTHON_UNUSED_VAR(def); + if (__Pyx_check_single_interpreter()) + return NULL; + if (__pyx_m) + return __Pyx_NewRef(__pyx_m); + modname = PyObject_GetAttrString(spec, "name"); + if (unlikely(!modname)) goto bad; + module = PyModule_NewObject(modname); + Py_DECREF(modname); + if (unlikely(!module)) goto bad; +#if CYTHON_COMPILING_IN_LIMITED_API + moddict = module; +#else + moddict = PyModule_GetDict(module); + if (unlikely(!moddict)) goto bad; +#endif + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "loader", "__loader__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "origin", "__file__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "parent", "__package__", 1) < 0)) goto bad; + if (unlikely(__Pyx_copy_spec_to_module(spec, moddict, "submodule_search_locations", "__path__", 0) < 0)) goto bad; + return module; +bad: + Py_XDECREF(module); + return NULL; +} + + +static CYTHON_SMALL_CODE int __pyx_pymod_exec_core(PyObject *__pyx_pyinit_module) +#endif +#endif +{ + int stringtab_initialized = 0; + #if CYTHON_USE_MODULE_STATE + int pystate_addmodule_run = 0; + #endif + PyObject *__pyx_t_1 = NULL; + PyObject *__pyx_t_2 = NULL; + PyObject *__pyx_t_3 = NULL; + PyObject *__pyx_t_4 = NULL; + PyObject *__pyx_t_5 = NULL; + int __pyx_t_6; + PyObject *__pyx_t_7 = NULL; + static PyThread_type_lock __pyx_t_8[8]; + int __pyx_lineno = 0; + const char *__pyx_filename = NULL; + int __pyx_clineno = 0; + __Pyx_RefNannyDeclarations + #if CYTHON_PEP489_MULTI_PHASE_INIT + if (__pyx_m) { + if (__pyx_m == __pyx_pyinit_module) return 0; + PyErr_SetString(PyExc_RuntimeError, "Module 'core' has already been imported. Re-initialisation is not supported."); + return -1; + } + #elif PY_MAJOR_VERSION >= 3 + if (__pyx_m) return __Pyx_NewRef(__pyx_m); + #endif + /*--- Module creation code ---*/ + #if CYTHON_PEP489_MULTI_PHASE_INIT + __pyx_m = __pyx_pyinit_module; + Py_INCREF(__pyx_m); + #else + #if PY_MAJOR_VERSION < 3 + __pyx_m = Py_InitModule4("core", __pyx_methods, 0, 0, PYTHON_API_VERSION); Py_XINCREF(__pyx_m); + if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) + #elif CYTHON_USE_MODULE_STATE + __pyx_t_1 = PyModule_Create(&__pyx_moduledef); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error) + { + int add_module_result = PyState_AddModule(__pyx_t_1, &__pyx_moduledef); + __pyx_t_1 = 0; /* transfer ownership from __pyx_t_1 to "core" pseudovariable */ + if (unlikely((add_module_result < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + pystate_addmodule_run = 1; + } + #else + __pyx_m = PyModule_Create(&__pyx_moduledef); + if (unlikely(!__pyx_m)) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #endif + CYTHON_UNUSED_VAR(__pyx_t_1); + __pyx_d = PyModule_GetDict(__pyx_m); if (unlikely(!__pyx_d)) __PYX_ERR(0, 1, __pyx_L1_error) + Py_INCREF(__pyx_d); + __pyx_b = __Pyx_PyImport_AddModuleRef(__Pyx_BUILTIN_MODULE_NAME); if (unlikely(!__pyx_b)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_cython_runtime = __Pyx_PyImport_AddModuleRef((const char *) "cython_runtime"); if (unlikely(!__pyx_cython_runtime)) __PYX_ERR(0, 1, __pyx_L1_error) + if (PyObject_SetAttrString(__pyx_m, "__builtins__", __pyx_b) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #if CYTHON_REFNANNY +__Pyx_RefNanny = __Pyx_RefNannyImportAPI("refnanny"); +if (!__Pyx_RefNanny) { + PyErr_Clear(); + __Pyx_RefNanny = __Pyx_RefNannyImportAPI("Cython.Runtime.refnanny"); + if (!__Pyx_RefNanny) + Py_FatalError("failed to import 'refnanny' module"); +} +#endif + __Pyx_RefNannySetupContext("__Pyx_PyMODINIT_FUNC PyInit_core(void)", 0); + if (__Pyx_check_binary_version(__PYX_LIMITED_VERSION_HEX, __Pyx_get_runtime_version(), CYTHON_COMPILING_IN_LIMITED_API) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #ifdef __Pxy_PyFrame_Initialize_Offsets + __Pxy_PyFrame_Initialize_Offsets(); + #endif + __pyx_empty_tuple = PyTuple_New(0); if (unlikely(!__pyx_empty_tuple)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_empty_bytes = PyBytes_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_bytes)) __PYX_ERR(0, 1, __pyx_L1_error) + __pyx_empty_unicode = PyUnicode_FromStringAndSize("", 0); if (unlikely(!__pyx_empty_unicode)) __PYX_ERR(0, 1, __pyx_L1_error) + #ifdef __Pyx_CyFunction_USED + if (__pyx_CyFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_FusedFunction_USED + if (__pyx_FusedFunction_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_Coroutine_USED + if (__pyx_Coroutine_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_Generator_USED + if (__pyx_Generator_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_AsyncGen_USED + if (__pyx_AsyncGen_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + #ifdef __Pyx_StopAsyncIteration_USED + if (__pyx_StopAsyncIteration_init(__pyx_m) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + /*--- Library function declarations ---*/ + /*--- Threads initialization code ---*/ + #if defined(WITH_THREAD) && PY_VERSION_HEX < 0x030700F0 && defined(__PYX_FORCE_INIT_THREADS) && __PYX_FORCE_INIT_THREADS + PyEval_InitThreads(); + #endif + /*--- Initialize various global constants etc. ---*/ + if (__Pyx_InitConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + stringtab_initialized = 1; + if (__Pyx_InitGlobals() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #if PY_MAJOR_VERSION < 3 && (__PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT) + if (__Pyx_init_sys_getdefaultencoding_params() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + if (__pyx_module_is_main_TTS__tts__utils__monotonic_align__core) { + if (PyObject_SetAttr(__pyx_m, __pyx_n_s_name_2, __pyx_n_s_main) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + } + #if PY_MAJOR_VERSION >= 3 + { + PyObject *modules = PyImport_GetModuleDict(); if (unlikely(!modules)) __PYX_ERR(0, 1, __pyx_L1_error) + if (!PyDict_GetItemString(modules, "TTS.tts.utils.monotonic_align.core")) { + if (unlikely((PyDict_SetItemString(modules, "TTS.tts.utils.monotonic_align.core", __pyx_m) < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + } + } + #endif + /*--- Builtin init code ---*/ + if (__Pyx_InitCachedBuiltins() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + /*--- Constants init code ---*/ + if (__Pyx_InitCachedConstants() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + /*--- Global type/function init code ---*/ + (void)__Pyx_modinit_global_init_code(); + (void)__Pyx_modinit_variable_export_code(); + (void)__Pyx_modinit_function_export_code(); + if (unlikely((__Pyx_modinit_type_init_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + if (unlikely((__Pyx_modinit_type_import_code() < 0))) __PYX_ERR(0, 1, __pyx_L1_error) + (void)__Pyx_modinit_variable_import_code(); + (void)__Pyx_modinit_function_import_code(); + /*--- Execution code ---*/ + #if defined(__Pyx_Generator_USED) || defined(__Pyx_Coroutine_USED) + if (__Pyx_patch_abc() < 0) __PYX_ERR(0, 1, __pyx_L1_error) + #endif + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__12, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_version_info); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = PyObject_RichCompare(__pyx_t_5, __pyx_tuple__13, Py_GE); __Pyx_XGOTREF(__pyx_t_4); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_t_4); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(1, 100, __pyx_L2_error) + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + if (__pyx_t_6) { + + /* "View.MemoryView":101 + * try: + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence # <<<<<<<<<<<<<< + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + */ + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__14, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_abc); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_t_5, __pyx_n_s_Sequence); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 101, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, __pyx_t_4); + __Pyx_GIVEREF(__pyx_t_4); + __pyx_t_4 = 0; + + /* "View.MemoryView":100 + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: + * if __import__("sys").version_info >= (3, 3): # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + */ + goto __pyx_L8; + } + + /* "View.MemoryView":103 + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence # <<<<<<<<<<<<<< + * except: + * + */ + /*else*/ { + __pyx_t_4 = __Pyx_PyObject_Call(__pyx_builtin___import__, __pyx_tuple__15, NULL); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 103, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_5 = __Pyx_PyObject_GetAttrStr(__pyx_t_4, __pyx_n_s_Sequence); if (unlikely(!__pyx_t_5)) __PYX_ERR(1, 103, __pyx_L2_error) + __Pyx_GOTREF(__pyx_t_5); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, __pyx_t_5); + __Pyx_GIVEREF(__pyx_t_5); + __pyx_t_5 = 0; + } + __pyx_L8:; + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L7_try_end; + __pyx_L2_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + + /* "View.MemoryView":104 + * else: + * __pyx_collections_abc_Sequence = __import__("collections").Sequence + * except: # <<<<<<<<<<<<<< + * + * __pyx_collections_abc_Sequence = None + */ + /*except:*/ { + __Pyx_AddTraceback("View.MemoryView", __pyx_clineno, __pyx_lineno, __pyx_filename); + if (__Pyx_GetException(&__pyx_t_5, &__pyx_t_4, &__pyx_t_7) < 0) __PYX_ERR(1, 104, __pyx_L4_except_error) + __Pyx_XGOTREF(__pyx_t_5); + __Pyx_XGOTREF(__pyx_t_4); + __Pyx_XGOTREF(__pyx_t_7); + + /* "View.MemoryView":106 + * except: + * + * __pyx_collections_abc_Sequence = None # <<<<<<<<<<<<<< + * + * + */ + __Pyx_INCREF(Py_None); + __Pyx_XGOTREF(__pyx_collections_abc_Sequence); + __Pyx_DECREF_SET(__pyx_collections_abc_Sequence, Py_None); + __Pyx_GIVEREF(Py_None); + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + goto __pyx_L3_exception_handled; + } + + /* "View.MemoryView":99 + * + * cdef object __pyx_collections_abc_Sequence "__pyx_collections_abc_Sequence" + * try: # <<<<<<<<<<<<<< + * if __import__("sys").version_info >= (3, 3): + * __pyx_collections_abc_Sequence = __import__("collections.abc").abc.Sequence + */ + __pyx_L4_except_error:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + goto __pyx_L1_error; + __pyx_L3_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + __pyx_L7_try_end:; + } + + /* "View.MemoryView":241 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_2, &__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_1); + /*try:*/ { + + /* "View.MemoryView":242 + * + * try: + * count = __pyx_collections_abc_Sequence.count # <<<<<<<<<<<<<< + * index = __pyx_collections_abc_Sequence.index + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_count); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 242, __pyx_L11_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_array_type, __pyx_n_s_count, __pyx_t_7) < 0) __PYX_ERR(1, 242, __pyx_L11_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_array_type); + + /* "View.MemoryView":243 + * try: + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index # <<<<<<<<<<<<<< + * except: + * pass + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_index); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 243, __pyx_L11_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_array_type, __pyx_n_s_index, __pyx_t_7) < 0) __PYX_ERR(1, 243, __pyx_L11_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_array_type); + + /* "View.MemoryView":241 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + goto __pyx_L16_try_end; + __pyx_L11_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":244 + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + * except: # <<<<<<<<<<<<<< + * pass + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L12_exception_handled; + } + __pyx_L12_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_2, __pyx_t_1); + __pyx_L16_try_end:; + } + + /* "View.MemoryView":309 + * return self.name + * + * cdef generic = Enum("") # <<<<<<<<<<<<<< + * cdef strided = Enum("") # default + * cdef indirect = Enum("") + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__16, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 309, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(generic); + __Pyx_DECREF_SET(generic, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":310 + * + * cdef generic = Enum("") + * cdef strided = Enum("") # default # <<<<<<<<<<<<<< + * cdef indirect = Enum("") + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__17, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 310, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(strided); + __Pyx_DECREF_SET(strided, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":311 + * cdef generic = Enum("") + * cdef strided = Enum("") # default + * cdef indirect = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__18, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 311, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(indirect); + __Pyx_DECREF_SET(indirect, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":314 + * + * + * cdef contiguous = Enum("") # <<<<<<<<<<<<<< + * cdef indirect_contiguous = Enum("") + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__19, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 314, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(contiguous); + __Pyx_DECREF_SET(contiguous, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":315 + * + * cdef contiguous = Enum("") + * cdef indirect_contiguous = Enum("") # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_7 = __Pyx_PyObject_Call(((PyObject *)__pyx_MemviewEnum_type), __pyx_tuple__20, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 315, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_XGOTREF(indirect_contiguous); + __Pyx_DECREF_SET(indirect_contiguous, __pyx_t_7); + __Pyx_GIVEREF(__pyx_t_7); + __pyx_t_7 = 0; + + /* "View.MemoryView":323 + * + * + * cdef int __pyx_memoryview_thread_locks_used = 0 # <<<<<<<<<<<<<< + * cdef PyThread_type_lock[8] __pyx_memoryview_thread_locks = [ + * PyThread_allocate_lock(), + */ + __pyx_memoryview_thread_locks_used = 0; + + /* "View.MemoryView":324 + * + * cdef int __pyx_memoryview_thread_locks_used = 0 + * cdef PyThread_type_lock[8] __pyx_memoryview_thread_locks = [ # <<<<<<<<<<<<<< + * PyThread_allocate_lock(), + * PyThread_allocate_lock(), + */ + __pyx_t_8[0] = PyThread_allocate_lock(); + __pyx_t_8[1] = PyThread_allocate_lock(); + __pyx_t_8[2] = PyThread_allocate_lock(); + __pyx_t_8[3] = PyThread_allocate_lock(); + __pyx_t_8[4] = PyThread_allocate_lock(); + __pyx_t_8[5] = PyThread_allocate_lock(); + __pyx_t_8[6] = PyThread_allocate_lock(); + __pyx_t_8[7] = PyThread_allocate_lock(); + memcpy(&(__pyx_memoryview_thread_locks[0]), __pyx_t_8, sizeof(__pyx_memoryview_thread_locks[0]) * (8)); + + /* "View.MemoryView":982 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_1, &__pyx_t_2, &__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_3); + /*try:*/ { + + /* "View.MemoryView":983 + * + * try: + * count = __pyx_collections_abc_Sequence.count # <<<<<<<<<<<<<< + * index = __pyx_collections_abc_Sequence.index + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_count); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 983, __pyx_L17_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_memoryviewslice_type, __pyx_n_s_count, __pyx_t_7) < 0) __PYX_ERR(1, 983, __pyx_L17_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_memoryviewslice_type); + + /* "View.MemoryView":984 + * try: + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index # <<<<<<<<<<<<<< + * except: + * pass + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_index); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 984, __pyx_L17_error) + __Pyx_GOTREF(__pyx_t_7); + if (__Pyx_SetItemOnTypeDict(__pyx_memoryviewslice_type, __pyx_n_s_index, __pyx_t_7) < 0) __PYX_ERR(1, 984, __pyx_L17_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + PyType_Modified(__pyx_memoryviewslice_type); + + /* "View.MemoryView":982 + * + * + * try: # <<<<<<<<<<<<<< + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + */ + } + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + goto __pyx_L22_try_end; + __pyx_L17_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":985 + * count = __pyx_collections_abc_Sequence.count + * index = __pyx_collections_abc_Sequence.index + * except: # <<<<<<<<<<<<<< + * pass + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L18_exception_handled; + } + __pyx_L18_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_ExceptionReset(__pyx_t_1, __pyx_t_2, __pyx_t_3); + __pyx_L22_try_end:; + } + + /* "View.MemoryView":988 + * pass + * + * try: # <<<<<<<<<<<<<< + * if __pyx_collections_abc_Sequence: + * + */ + { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ExceptionSave(&__pyx_t_3, &__pyx_t_2, &__pyx_t_1); + __Pyx_XGOTREF(__pyx_t_3); + __Pyx_XGOTREF(__pyx_t_2); + __Pyx_XGOTREF(__pyx_t_1); + /*try:*/ { + + /* "View.MemoryView":989 + * + * try: + * if __pyx_collections_abc_Sequence: # <<<<<<<<<<<<<< + * + * + */ + __pyx_t_6 = __Pyx_PyObject_IsTrue(__pyx_collections_abc_Sequence); if (unlikely((__pyx_t_6 < 0))) __PYX_ERR(1, 989, __pyx_L23_error) + if (__pyx_t_6) { + + /* "View.MemoryView":993 + * + * + * __pyx_collections_abc_Sequence.register(_memoryviewslice) # <<<<<<<<<<<<<< + * __pyx_collections_abc_Sequence.register(array) + * except: + */ + __pyx_t_7 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_register); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 993, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_4 = __Pyx_PyObject_CallOneArg(__pyx_t_7, ((PyObject *)__pyx_memoryviewslice_type)); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 993, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + + /* "View.MemoryView":994 + * + * __pyx_collections_abc_Sequence.register(_memoryviewslice) + * __pyx_collections_abc_Sequence.register(array) # <<<<<<<<<<<<<< + * except: + * pass # ignore failure, it's a minor issue + */ + __pyx_t_4 = __Pyx_PyObject_GetAttrStr(__pyx_collections_abc_Sequence, __pyx_n_s_register); if (unlikely(!__pyx_t_4)) __PYX_ERR(1, 994, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_4); + __pyx_t_7 = __Pyx_PyObject_CallOneArg(__pyx_t_4, ((PyObject *)__pyx_array_type)); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 994, __pyx_L23_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":989 + * + * try: + * if __pyx_collections_abc_Sequence: # <<<<<<<<<<<<<< + * + * + */ + } + + /* "View.MemoryView":988 + * pass + * + * try: # <<<<<<<<<<<<<< + * if __pyx_collections_abc_Sequence: + * + */ + } + __Pyx_XDECREF(__pyx_t_3); __pyx_t_3 = 0; + __Pyx_XDECREF(__pyx_t_2); __pyx_t_2 = 0; + __Pyx_XDECREF(__pyx_t_1); __pyx_t_1 = 0; + goto __pyx_L28_try_end; + __pyx_L23_error:; + __Pyx_XDECREF(__pyx_t_4); __pyx_t_4 = 0; + __Pyx_XDECREF(__pyx_t_5); __pyx_t_5 = 0; + __Pyx_XDECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "View.MemoryView":995 + * __pyx_collections_abc_Sequence.register(_memoryviewslice) + * __pyx_collections_abc_Sequence.register(array) + * except: # <<<<<<<<<<<<<< + * pass # ignore failure, it's a minor issue + * + */ + /*except:*/ { + __Pyx_ErrRestore(0,0,0); + goto __pyx_L24_exception_handled; + } + __pyx_L24_exception_handled:; + __Pyx_XGIVEREF(__pyx_t_3); + __Pyx_XGIVEREF(__pyx_t_2); + __Pyx_XGIVEREF(__pyx_t_1); + __Pyx_ExceptionReset(__pyx_t_3, __pyx_t_2, __pyx_t_1); + __pyx_L28_try_end:; + } + + /* "(tree fragment)":1 + * def __pyx_unpickle_Enum(__pyx_type, long __pyx_checksum, __pyx_state): # <<<<<<<<<<<<<< + * cdef object __pyx_PickleError + * cdef object __pyx_result + */ + __pyx_t_7 = PyCFunction_NewEx(&__pyx_mdef_15View_dot_MemoryView_1__pyx_unpickle_Enum, NULL, __pyx_n_s_View_MemoryView); if (unlikely(!__pyx_t_7)) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_pyx_unpickle_Enum, __pyx_t_7) < 0) __PYX_ERR(1, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "TTS/tts/utils/monotonic_align/core.pyx":1 + * import numpy as np # <<<<<<<<<<<<<< + * + * cimport cython + */ + __pyx_t_7 = __Pyx_ImportDottedModule(__pyx_n_s_numpy, NULL); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_np, __pyx_t_7) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "TTS/tts/utils/monotonic_align/core.pyx":42 + * @cython.boundscheck(False) + * @cython.wraparound(False) + * cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: # <<<<<<<<<<<<<< + * cdef int b = values.shape[0] + * + */ + __pyx_k__11 = (-1e9); + __pyx_t_7 = PyFloat_FromDouble((-1e9)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __pyx_t_4 = PyTuple_New(1); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_4); + __Pyx_GIVEREF(__pyx_t_7); + if (__Pyx_PyTuple_SET_ITEM(__pyx_t_4, 0, __pyx_t_7)) __PYX_ERR(0, 42, __pyx_L1_error); + __pyx_t_7 = 0; + __pyx_t_7 = __Pyx_CyFunction_New(&__pyx_mdef_3TTS_3tts_5utils_15monotonic_align_4core_1maximum_path_c, 0, __pyx_n_s_maximum_path_c, NULL, __pyx_n_s_TTS_tts_utils_monotonic_align_co_2, __pyx_d, ((PyObject *)__pyx_codeobj__24)); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + __Pyx_CyFunction_SetDefaultsTuple(__pyx_t_7, __pyx_t_4); + __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0; + if (PyDict_SetItem(__pyx_d, __pyx_n_s_maximum_path_c, __pyx_t_7) < 0) __PYX_ERR(0, 42, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /* "TTS/tts/utils/monotonic_align/core.pyx":1 + * import numpy as np # <<<<<<<<<<<<<< + * + * cimport cython + */ + __pyx_t_7 = __Pyx_PyDict_NewPresized(0); if (unlikely(!__pyx_t_7)) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_GOTREF(__pyx_t_7); + if (PyDict_SetItem(__pyx_d, __pyx_n_s_test, __pyx_t_7) < 0) __PYX_ERR(0, 1, __pyx_L1_error) + __Pyx_DECREF(__pyx_t_7); __pyx_t_7 = 0; + + /*--- Wrapped vars code ---*/ + + goto __pyx_L0; + __pyx_L1_error:; + __Pyx_XDECREF(__pyx_t_4); + __Pyx_XDECREF(__pyx_t_5); + __Pyx_XDECREF(__pyx_t_7); + if (__pyx_m) { + if (__pyx_d && stringtab_initialized) { + __Pyx_AddTraceback("init TTS.tts.utils.monotonic_align.core", __pyx_clineno, __pyx_lineno, __pyx_filename); + } + #if !CYTHON_USE_MODULE_STATE + Py_CLEAR(__pyx_m); + #else + Py_DECREF(__pyx_m); + if (pystate_addmodule_run) { + PyObject *tp, *value, *tb; + PyErr_Fetch(&tp, &value, &tb); + PyState_RemoveModule(&__pyx_moduledef); + PyErr_Restore(tp, value, tb); + } + #endif + } else if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_ImportError, "init TTS.tts.utils.monotonic_align.core"); + } + __pyx_L0:; + __Pyx_RefNannyFinishContext(); + #if CYTHON_PEP489_MULTI_PHASE_INIT + return (__pyx_m != NULL) ? 0 : -1; + #elif PY_MAJOR_VERSION >= 3 + return __pyx_m; + #else + return; + #endif +} +/* #### Code section: cleanup_globals ### */ +/* #### Code section: cleanup_module ### */ +/* #### Code section: main_method ### */ +/* #### Code section: utility_code_pragmas ### */ +#ifdef _MSC_VER +#pragma warning( push ) +/* Warning 4127: conditional expression is constant + * Cython uses constant conditional expressions to allow in inline functions to be optimized at + * compile-time, so this warning is not useful + */ +#pragma warning( disable : 4127 ) +#endif + + + +/* #### Code section: utility_code_def ### */ + +/* --- Runtime support code --- */ +/* Refnanny */ +#if CYTHON_REFNANNY +static __Pyx_RefNannyAPIStruct *__Pyx_RefNannyImportAPI(const char *modname) { + PyObject *m = NULL, *p = NULL; + void *r = NULL; + m = PyImport_ImportModule(modname); + if (!m) goto end; + p = PyObject_GetAttrString(m, "RefNannyAPI"); + if (!p) goto end; + r = PyLong_AsVoidPtr(p); +end: + Py_XDECREF(p); + Py_XDECREF(m); + return (__Pyx_RefNannyAPIStruct *)r; +} +#endif + +/* PyErrExceptionMatches */ +#if CYTHON_FAST_THREAD_STATE +static int __Pyx_PyErr_ExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(tuple); +#if PY_MAJOR_VERSION >= 3 + for (i=0; i= 0x030C00A6 + PyObject *current_exception = tstate->current_exception; + if (unlikely(!current_exception)) return 0; + exc_type = (PyObject*) Py_TYPE(current_exception); + if (exc_type == err) return 1; +#else + exc_type = tstate->curexc_type; + if (exc_type == err) return 1; + if (unlikely(!exc_type)) return 0; +#endif + #if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(exc_type); + #endif + if (unlikely(PyTuple_Check(err))) { + result = __Pyx_PyErr_ExceptionMatchesTuple(exc_type, err); + } else { + result = __Pyx_PyErr_GivenExceptionMatches(exc_type, err); + } + #if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(exc_type); + #endif + return result; +} +#endif + +/* PyErrFetchRestore */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx_ErrRestoreInState(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { +#if PY_VERSION_HEX >= 0x030C00A6 + PyObject *tmp_value; + assert(type == NULL || (value != NULL && type == (PyObject*) Py_TYPE(value))); + if (value) { + #if CYTHON_COMPILING_IN_CPYTHON + if (unlikely(((PyBaseExceptionObject*) value)->traceback != tb)) + #endif + PyException_SetTraceback(value, tb); + } + tmp_value = tstate->current_exception; + tstate->current_exception = value; + Py_XDECREF(tmp_value); + Py_XDECREF(type); + Py_XDECREF(tb); +#else + PyObject *tmp_type, *tmp_value, *tmp_tb; + tmp_type = tstate->curexc_type; + tmp_value = tstate->curexc_value; + tmp_tb = tstate->curexc_traceback; + tstate->curexc_type = type; + tstate->curexc_value = value; + tstate->curexc_traceback = tb; + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); +#endif +} +static CYTHON_INLINE void __Pyx_ErrFetchInState(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { +#if PY_VERSION_HEX >= 0x030C00A6 + PyObject* exc_value; + exc_value = tstate->current_exception; + tstate->current_exception = 0; + *value = exc_value; + *type = NULL; + *tb = NULL; + if (exc_value) { + *type = (PyObject*) Py_TYPE(exc_value); + Py_INCREF(*type); + #if CYTHON_COMPILING_IN_CPYTHON + *tb = ((PyBaseExceptionObject*) exc_value)->traceback; + Py_XINCREF(*tb); + #else + *tb = PyException_GetTraceback(exc_value); + #endif + } +#else + *type = tstate->curexc_type; + *value = tstate->curexc_value; + *tb = tstate->curexc_traceback; + tstate->curexc_type = 0; + tstate->curexc_value = 0; + tstate->curexc_traceback = 0; +#endif +} +#endif + +/* PyObjectGetAttrStr */ +#if CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStr(PyObject* obj, PyObject* attr_name) { + PyTypeObject* tp = Py_TYPE(obj); + if (likely(tp->tp_getattro)) + return tp->tp_getattro(obj, attr_name); +#if PY_MAJOR_VERSION < 3 + if (likely(tp->tp_getattr)) + return tp->tp_getattr(obj, PyString_AS_STRING(attr_name)); +#endif + return PyObject_GetAttr(obj, attr_name); +} +#endif + +/* PyObjectGetAttrStrNoError */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static void __Pyx_PyObject_GetAttrStr_ClearAttributeError(void) { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + if (likely(__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) + __Pyx_PyErr_Clear(); +} +#endif +static CYTHON_INLINE PyObject* __Pyx_PyObject_GetAttrStrNoError(PyObject* obj, PyObject* attr_name) { + PyObject *result; +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + (void) PyObject_GetOptionalAttr(obj, attr_name, &result); + return result; +#else +#if CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_TYPE_SLOTS && PY_VERSION_HEX >= 0x030700B1 + PyTypeObject* tp = Py_TYPE(obj); + if (likely(tp->tp_getattro == PyObject_GenericGetAttr)) { + return _PyObject_GenericGetAttrWithDict(obj, attr_name, NULL, 1); + } +#endif + result = __Pyx_PyObject_GetAttrStr(obj, attr_name); + if (unlikely(!result)) { + __Pyx_PyObject_GetAttrStr_ClearAttributeError(); + } + return result; +#endif +} + +/* GetBuiltinName */ +static PyObject *__Pyx_GetBuiltinName(PyObject *name) { + PyObject* result = __Pyx_PyObject_GetAttrStrNoError(__pyx_b, name); + if (unlikely(!result) && !PyErr_Occurred()) { + PyErr_Format(PyExc_NameError, +#if PY_MAJOR_VERSION >= 3 + "name '%U' is not defined", name); +#else + "name '%.200s' is not defined", PyString_AS_STRING(name)); +#endif + } + return result; +} + +/* TupleAndListFromArray */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE void __Pyx_copy_object_array(PyObject *const *CYTHON_RESTRICT src, PyObject** CYTHON_RESTRICT dest, Py_ssize_t length) { + PyObject *v; + Py_ssize_t i; + for (i = 0; i < length; i++) { + v = dest[i] = src[i]; + Py_INCREF(v); + } +} +static CYTHON_INLINE PyObject * +__Pyx_PyTuple_FromArray(PyObject *const *src, Py_ssize_t n) +{ + PyObject *res; + if (n <= 0) { + Py_INCREF(__pyx_empty_tuple); + return __pyx_empty_tuple; + } + res = PyTuple_New(n); + if (unlikely(res == NULL)) return NULL; + __Pyx_copy_object_array(src, ((PyTupleObject*)res)->ob_item, n); + return res; +} +static CYTHON_INLINE PyObject * +__Pyx_PyList_FromArray(PyObject *const *src, Py_ssize_t n) +{ + PyObject *res; + if (n <= 0) { + return PyList_New(0); + } + res = PyList_New(n); + if (unlikely(res == NULL)) return NULL; + __Pyx_copy_object_array(src, ((PyListObject*)res)->ob_item, n); + return res; +} +#endif + +/* BytesEquals */ +static CYTHON_INLINE int __Pyx_PyBytes_Equals(PyObject* s1, PyObject* s2, int equals) { +#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API + return PyObject_RichCompareBool(s1, s2, equals); +#else + if (s1 == s2) { + return (equals == Py_EQ); + } else if (PyBytes_CheckExact(s1) & PyBytes_CheckExact(s2)) { + const char *ps1, *ps2; + Py_ssize_t length = PyBytes_GET_SIZE(s1); + if (length != PyBytes_GET_SIZE(s2)) + return (equals == Py_NE); + ps1 = PyBytes_AS_STRING(s1); + ps2 = PyBytes_AS_STRING(s2); + if (ps1[0] != ps2[0]) { + return (equals == Py_NE); + } else if (length == 1) { + return (equals == Py_EQ); + } else { + int result; +#if CYTHON_USE_UNICODE_INTERNALS && (PY_VERSION_HEX < 0x030B0000) + Py_hash_t hash1, hash2; + hash1 = ((PyBytesObject*)s1)->ob_shash; + hash2 = ((PyBytesObject*)s2)->ob_shash; + if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { + return (equals == Py_NE); + } +#endif + result = memcmp(ps1, ps2, (size_t)length); + return (equals == Py_EQ) ? (result == 0) : (result != 0); + } + } else if ((s1 == Py_None) & PyBytes_CheckExact(s2)) { + return (equals == Py_NE); + } else if ((s2 == Py_None) & PyBytes_CheckExact(s1)) { + return (equals == Py_NE); + } else { + int result; + PyObject* py_result = PyObject_RichCompare(s1, s2, equals); + if (!py_result) + return -1; + result = __Pyx_PyObject_IsTrue(py_result); + Py_DECREF(py_result); + return result; + } +#endif +} + +/* UnicodeEquals */ +static CYTHON_INLINE int __Pyx_PyUnicode_Equals(PyObject* s1, PyObject* s2, int equals) { +#if CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API + return PyObject_RichCompareBool(s1, s2, equals); +#else +#if PY_MAJOR_VERSION < 3 + PyObject* owned_ref = NULL; +#endif + int s1_is_unicode, s2_is_unicode; + if (s1 == s2) { + goto return_eq; + } + s1_is_unicode = PyUnicode_CheckExact(s1); + s2_is_unicode = PyUnicode_CheckExact(s2); +#if PY_MAJOR_VERSION < 3 + if ((s1_is_unicode & (!s2_is_unicode)) && PyString_CheckExact(s2)) { + owned_ref = PyUnicode_FromObject(s2); + if (unlikely(!owned_ref)) + return -1; + s2 = owned_ref; + s2_is_unicode = 1; + } else if ((s2_is_unicode & (!s1_is_unicode)) && PyString_CheckExact(s1)) { + owned_ref = PyUnicode_FromObject(s1); + if (unlikely(!owned_ref)) + return -1; + s1 = owned_ref; + s1_is_unicode = 1; + } else if (((!s2_is_unicode) & (!s1_is_unicode))) { + return __Pyx_PyBytes_Equals(s1, s2, equals); + } +#endif + if (s1_is_unicode & s2_is_unicode) { + Py_ssize_t length; + int kind; + void *data1, *data2; + if (unlikely(__Pyx_PyUnicode_READY(s1) < 0) || unlikely(__Pyx_PyUnicode_READY(s2) < 0)) + return -1; + length = __Pyx_PyUnicode_GET_LENGTH(s1); + if (length != __Pyx_PyUnicode_GET_LENGTH(s2)) { + goto return_ne; + } +#if CYTHON_USE_UNICODE_INTERNALS + { + Py_hash_t hash1, hash2; + #if CYTHON_PEP393_ENABLED + hash1 = ((PyASCIIObject*)s1)->hash; + hash2 = ((PyASCIIObject*)s2)->hash; + #else + hash1 = ((PyUnicodeObject*)s1)->hash; + hash2 = ((PyUnicodeObject*)s2)->hash; + #endif + if (hash1 != hash2 && hash1 != -1 && hash2 != -1) { + goto return_ne; + } + } +#endif + kind = __Pyx_PyUnicode_KIND(s1); + if (kind != __Pyx_PyUnicode_KIND(s2)) { + goto return_ne; + } + data1 = __Pyx_PyUnicode_DATA(s1); + data2 = __Pyx_PyUnicode_DATA(s2); + if (__Pyx_PyUnicode_READ(kind, data1, 0) != __Pyx_PyUnicode_READ(kind, data2, 0)) { + goto return_ne; + } else if (length == 1) { + goto return_eq; + } else { + int result = memcmp(data1, data2, (size_t)(length * kind)); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_EQ) ? (result == 0) : (result != 0); + } + } else if ((s1 == Py_None) & s2_is_unicode) { + goto return_ne; + } else if ((s2 == Py_None) & s1_is_unicode) { + goto return_ne; + } else { + int result; + PyObject* py_result = PyObject_RichCompare(s1, s2, equals); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + if (!py_result) + return -1; + result = __Pyx_PyObject_IsTrue(py_result); + Py_DECREF(py_result); + return result; + } +return_eq: + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_EQ); +return_ne: + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(owned_ref); + #endif + return (equals == Py_NE); +#endif +} + +/* fastcall */ +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE PyObject * __Pyx_GetKwValue_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues, PyObject *s) +{ + Py_ssize_t i, n = PyTuple_GET_SIZE(kwnames); + for (i = 0; i < n; i++) + { + if (s == PyTuple_GET_ITEM(kwnames, i)) return kwvalues[i]; + } + for (i = 0; i < n; i++) + { + int eq = __Pyx_PyUnicode_Equals(s, PyTuple_GET_ITEM(kwnames, i), Py_EQ); + if (unlikely(eq != 0)) { + if (unlikely(eq < 0)) return NULL; + return kwvalues[i]; + } + } + return NULL; +} +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030d0000 +CYTHON_UNUSED static PyObject *__Pyx_KwargsAsDict_FASTCALL(PyObject *kwnames, PyObject *const *kwvalues) { + Py_ssize_t i, nkwargs = PyTuple_GET_SIZE(kwnames); + PyObject *dict; + dict = PyDict_New(); + if (unlikely(!dict)) + return NULL; + for (i=0; i= 3 + "%s() got multiple values for keyword argument '%U'", func_name, kw_name); + #else + "%s() got multiple values for keyword argument '%s'", func_name, + PyString_AsString(kw_name)); + #endif +} + +/* ParseKeywords */ +static int __Pyx_ParseOptionalKeywords( + PyObject *kwds, + PyObject *const *kwvalues, + PyObject **argnames[], + PyObject *kwds2, + PyObject *values[], + Py_ssize_t num_pos_args, + const char* function_name) +{ + PyObject *key = 0, *value = 0; + Py_ssize_t pos = 0; + PyObject*** name; + PyObject*** first_kw_arg = argnames + num_pos_args; + int kwds_is_tuple = CYTHON_METH_FASTCALL && likely(PyTuple_Check(kwds)); + while (1) { + Py_XDECREF(key); key = NULL; + Py_XDECREF(value); value = NULL; + if (kwds_is_tuple) { + Py_ssize_t size; +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(kwds); +#else + size = PyTuple_Size(kwds); + if (size < 0) goto bad; +#endif + if (pos >= size) break; +#if CYTHON_AVOID_BORROWED_REFS + key = __Pyx_PySequence_ITEM(kwds, pos); + if (!key) goto bad; +#elif CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kwds, pos); +#else + key = PyTuple_GetItem(kwds, pos); + if (!key) goto bad; +#endif + value = kwvalues[pos]; + pos++; + } + else + { + if (!PyDict_Next(kwds, &pos, &key, &value)) break; +#if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(key); +#endif + } + name = first_kw_arg; + while (*name && (**name != key)) name++; + if (*name) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + Py_INCREF(value); + Py_DECREF(key); +#endif + key = NULL; + value = NULL; + continue; + } +#if !CYTHON_AVOID_BORROWED_REFS + Py_INCREF(key); +#endif + Py_INCREF(value); + name = first_kw_arg; + #if PY_MAJOR_VERSION < 3 + if (likely(PyString_Check(key))) { + while (*name) { + if ((CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**name) == PyString_GET_SIZE(key)) + && _PyString_Eq(**name, key)) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + value = NULL; +#endif + break; + } + name++; + } + if (*name) continue; + else { + PyObject*** argname = argnames; + while (argname != first_kw_arg) { + if ((**argname == key) || ( + (CYTHON_COMPILING_IN_PYPY || PyString_GET_SIZE(**argname) == PyString_GET_SIZE(key)) + && _PyString_Eq(**argname, key))) { + goto arg_passed_twice; + } + argname++; + } + } + } else + #endif + if (likely(PyUnicode_Check(key))) { + while (*name) { + int cmp = ( + #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 + (__Pyx_PyUnicode_GET_LENGTH(**name) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : + #endif + PyUnicode_Compare(**name, key) + ); + if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; + if (cmp == 0) { + values[name-argnames] = value; +#if CYTHON_AVOID_BORROWED_REFS + value = NULL; +#endif + break; + } + name++; + } + if (*name) continue; + else { + PyObject*** argname = argnames; + while (argname != first_kw_arg) { + int cmp = (**argname == key) ? 0 : + #if !CYTHON_COMPILING_IN_PYPY && PY_MAJOR_VERSION >= 3 + (__Pyx_PyUnicode_GET_LENGTH(**argname) != __Pyx_PyUnicode_GET_LENGTH(key)) ? 1 : + #endif + PyUnicode_Compare(**argname, key); + if (cmp < 0 && unlikely(PyErr_Occurred())) goto bad; + if (cmp == 0) goto arg_passed_twice; + argname++; + } + } + } else + goto invalid_keyword_type; + if (kwds2) { + if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad; + } else { + goto invalid_keyword; + } + } + Py_XDECREF(key); + Py_XDECREF(value); + return 0; +arg_passed_twice: + __Pyx_RaiseDoubleKeywordsError(function_name, key); + goto bad; +invalid_keyword_type: + PyErr_Format(PyExc_TypeError, + "%.200s() keywords must be strings", function_name); + goto bad; +invalid_keyword: + #if PY_MAJOR_VERSION < 3 + PyErr_Format(PyExc_TypeError, + "%.200s() got an unexpected keyword argument '%.200s'", + function_name, PyString_AsString(key)); + #else + PyErr_Format(PyExc_TypeError, + "%s() got an unexpected keyword argument '%U'", + function_name, key); + #endif +bad: + Py_XDECREF(key); + Py_XDECREF(value); + return -1; +} + +/* ArgTypeTest */ +static int __Pyx__ArgTypeTest(PyObject *obj, PyTypeObject *type, const char *name, int exact) +{ + __Pyx_TypeName type_name; + __Pyx_TypeName obj_type_name; + if (unlikely(!type)) { + PyErr_SetString(PyExc_SystemError, "Missing type object"); + return 0; + } + else if (exact) { + #if PY_MAJOR_VERSION == 2 + if ((type == &PyBaseString_Type) && likely(__Pyx_PyBaseString_CheckExact(obj))) return 1; + #endif + } + else { + if (likely(__Pyx_TypeCheck(obj, type))) return 1; + } + type_name = __Pyx_PyType_GetName(type); + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "Argument '%.200s' has incorrect type (expected " __Pyx_FMT_TYPENAME + ", got " __Pyx_FMT_TYPENAME ")", name, type_name, obj_type_name); + __Pyx_DECREF_TypeName(type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return 0; +} + +/* RaiseException */ +#if PY_MAJOR_VERSION < 3 +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { + __Pyx_PyThreadState_declare + CYTHON_UNUSED_VAR(cause); + Py_XINCREF(type); + if (!value || value == Py_None) + value = NULL; + else + Py_INCREF(value); + if (!tb || tb == Py_None) + tb = NULL; + else { + Py_INCREF(tb); + if (!PyTraceBack_Check(tb)) { + PyErr_SetString(PyExc_TypeError, + "raise: arg 3 must be a traceback or None"); + goto raise_error; + } + } + if (PyType_Check(type)) { +#if CYTHON_COMPILING_IN_PYPY + if (!value) { + Py_INCREF(Py_None); + value = Py_None; + } +#endif + PyErr_NormalizeException(&type, &value, &tb); + } else { + if (value) { + PyErr_SetString(PyExc_TypeError, + "instance exception may not have a separate value"); + goto raise_error; + } + value = type; + type = (PyObject*) Py_TYPE(type); + Py_INCREF(type); + if (!PyType_IsSubtype((PyTypeObject *)type, (PyTypeObject *)PyExc_BaseException)) { + PyErr_SetString(PyExc_TypeError, + "raise: exception class must be a subclass of BaseException"); + goto raise_error; + } + } + __Pyx_PyThreadState_assign + __Pyx_ErrRestore(type, value, tb); + return; +raise_error: + Py_XDECREF(value); + Py_XDECREF(type); + Py_XDECREF(tb); + return; +} +#else +static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) { + PyObject* owned_instance = NULL; + if (tb == Py_None) { + tb = 0; + } else if (tb && !PyTraceBack_Check(tb)) { + PyErr_SetString(PyExc_TypeError, + "raise: arg 3 must be a traceback or None"); + goto bad; + } + if (value == Py_None) + value = 0; + if (PyExceptionInstance_Check(type)) { + if (value) { + PyErr_SetString(PyExc_TypeError, + "instance exception may not have a separate value"); + goto bad; + } + value = type; + type = (PyObject*) Py_TYPE(value); + } else if (PyExceptionClass_Check(type)) { + PyObject *instance_class = NULL; + if (value && PyExceptionInstance_Check(value)) { + instance_class = (PyObject*) Py_TYPE(value); + if (instance_class != type) { + int is_subclass = PyObject_IsSubclass(instance_class, type); + if (!is_subclass) { + instance_class = NULL; + } else if (unlikely(is_subclass == -1)) { + goto bad; + } else { + type = instance_class; + } + } + } + if (!instance_class) { + PyObject *args; + if (!value) + args = PyTuple_New(0); + else if (PyTuple_Check(value)) { + Py_INCREF(value); + args = value; + } else + args = PyTuple_Pack(1, value); + if (!args) + goto bad; + owned_instance = PyObject_Call(type, args, NULL); + Py_DECREF(args); + if (!owned_instance) + goto bad; + value = owned_instance; + if (!PyExceptionInstance_Check(value)) { + PyErr_Format(PyExc_TypeError, + "calling %R should have returned an instance of " + "BaseException, not %R", + type, Py_TYPE(value)); + goto bad; + } + } + } else { + PyErr_SetString(PyExc_TypeError, + "raise: exception class must be a subclass of BaseException"); + goto bad; + } + if (cause) { + PyObject *fixed_cause; + if (cause == Py_None) { + fixed_cause = NULL; + } else if (PyExceptionClass_Check(cause)) { + fixed_cause = PyObject_CallObject(cause, NULL); + if (fixed_cause == NULL) + goto bad; + } else if (PyExceptionInstance_Check(cause)) { + fixed_cause = cause; + Py_INCREF(fixed_cause); + } else { + PyErr_SetString(PyExc_TypeError, + "exception causes must derive from " + "BaseException"); + goto bad; + } + PyException_SetCause(value, fixed_cause); + } + PyErr_SetObject(type, value); + if (tb) { + #if PY_VERSION_HEX >= 0x030C00A6 + PyException_SetTraceback(value, tb); + #elif CYTHON_FAST_THREAD_STATE + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject* tmp_tb = tstate->curexc_traceback; + if (tb != tmp_tb) { + Py_INCREF(tb); + tstate->curexc_traceback = tb; + Py_XDECREF(tmp_tb); + } +#else + PyObject *tmp_type, *tmp_value, *tmp_tb; + PyErr_Fetch(&tmp_type, &tmp_value, &tmp_tb); + Py_INCREF(tb); + PyErr_Restore(tmp_type, tmp_value, tb); + Py_XDECREF(tmp_tb); +#endif + } +bad: + Py_XDECREF(owned_instance); + return; +} +#endif + +/* PyFunctionFastCall */ +#if CYTHON_FAST_PYCALL && !CYTHON_VECTORCALL +static PyObject* __Pyx_PyFunction_FastCallNoKw(PyCodeObject *co, PyObject **args, Py_ssize_t na, + PyObject *globals) { + PyFrameObject *f; + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject **fastlocals; + Py_ssize_t i; + PyObject *result; + assert(globals != NULL); + /* XXX Perhaps we should create a specialized + PyFrame_New() that doesn't take locals, but does + take builtins without sanity checking them. + */ + assert(tstate != NULL); + f = PyFrame_New(tstate, co, globals, NULL); + if (f == NULL) { + return NULL; + } + fastlocals = __Pyx_PyFrame_GetLocalsplus(f); + for (i = 0; i < na; i++) { + Py_INCREF(*args); + fastlocals[i] = *args++; + } + result = PyEval_EvalFrameEx(f,0); + ++tstate->recursion_depth; + Py_DECREF(f); + --tstate->recursion_depth; + return result; +} +static PyObject *__Pyx_PyFunction_FastCallDict(PyObject *func, PyObject **args, Py_ssize_t nargs, PyObject *kwargs) { + PyCodeObject *co = (PyCodeObject *)PyFunction_GET_CODE(func); + PyObject *globals = PyFunction_GET_GLOBALS(func); + PyObject *argdefs = PyFunction_GET_DEFAULTS(func); + PyObject *closure; +#if PY_MAJOR_VERSION >= 3 + PyObject *kwdefs; +#endif + PyObject *kwtuple, **k; + PyObject **d; + Py_ssize_t nd; + Py_ssize_t nk; + PyObject *result; + assert(kwargs == NULL || PyDict_Check(kwargs)); + nk = kwargs ? PyDict_Size(kwargs) : 0; + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) { + return NULL; + } + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) { + return NULL; + } + #endif + if ( +#if PY_MAJOR_VERSION >= 3 + co->co_kwonlyargcount == 0 && +#endif + likely(kwargs == NULL || nk == 0) && + co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) { + if (argdefs == NULL && co->co_argcount == nargs) { + result = __Pyx_PyFunction_FastCallNoKw(co, args, nargs, globals); + goto done; + } + else if (nargs == 0 && argdefs != NULL + && co->co_argcount == Py_SIZE(argdefs)) { + /* function called with no arguments, but all parameters have + a default value: use default values as arguments .*/ + args = &PyTuple_GET_ITEM(argdefs, 0); + result =__Pyx_PyFunction_FastCallNoKw(co, args, Py_SIZE(argdefs), globals); + goto done; + } + } + if (kwargs != NULL) { + Py_ssize_t pos, i; + kwtuple = PyTuple_New(2 * nk); + if (kwtuple == NULL) { + result = NULL; + goto done; + } + k = &PyTuple_GET_ITEM(kwtuple, 0); + pos = i = 0; + while (PyDict_Next(kwargs, &pos, &k[i], &k[i+1])) { + Py_INCREF(k[i]); + Py_INCREF(k[i+1]); + i += 2; + } + nk = i / 2; + } + else { + kwtuple = NULL; + k = NULL; + } + closure = PyFunction_GET_CLOSURE(func); +#if PY_MAJOR_VERSION >= 3 + kwdefs = PyFunction_GET_KW_DEFAULTS(func); +#endif + if (argdefs != NULL) { + d = &PyTuple_GET_ITEM(argdefs, 0); + nd = Py_SIZE(argdefs); + } + else { + d = NULL; + nd = 0; + } +#if PY_MAJOR_VERSION >= 3 + result = PyEval_EvalCodeEx((PyObject*)co, globals, (PyObject *)NULL, + args, (int)nargs, + k, (int)nk, + d, (int)nd, kwdefs, closure); +#else + result = PyEval_EvalCodeEx(co, globals, (PyObject *)NULL, + args, (int)nargs, + k, (int)nk, + d, (int)nd, closure); +#endif + Py_XDECREF(kwtuple); +done: + Py_LeaveRecursiveCall(); + return result; +} +#endif + +/* PyObjectCall */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg, PyObject *kw) { + PyObject *result; + ternaryfunc call = Py_TYPE(func)->tp_call; + if (unlikely(!call)) + return PyObject_Call(func, arg, kw); + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) + return NULL; + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) + return NULL; + #endif + result = (*call)(func, arg, kw); + Py_LeaveRecursiveCall(); + if (unlikely(!result) && unlikely(!PyErr_Occurred())) { + PyErr_SetString( + PyExc_SystemError, + "NULL result without error in PyObject_Call"); + } + return result; +} +#endif + +/* PyObjectCallMethO */ +#if CYTHON_COMPILING_IN_CPYTHON +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallMethO(PyObject *func, PyObject *arg) { + PyObject *self, *result; + PyCFunction cfunc; + cfunc = __Pyx_CyOrPyCFunction_GET_FUNCTION(func); + self = __Pyx_CyOrPyCFunction_GET_SELF(func); + #if PY_MAJOR_VERSION < 3 + if (unlikely(Py_EnterRecursiveCall((char*)" while calling a Python object"))) + return NULL; + #else + if (unlikely(Py_EnterRecursiveCall(" while calling a Python object"))) + return NULL; + #endif + result = cfunc(self, arg); + Py_LeaveRecursiveCall(); + if (unlikely(!result) && unlikely(!PyErr_Occurred())) { + PyErr_SetString( + PyExc_SystemError, + "NULL result without error in PyObject_Call"); + } + return result; +} +#endif + +/* PyObjectFastCall */ +#if PY_VERSION_HEX < 0x03090000 || CYTHON_COMPILING_IN_LIMITED_API +static PyObject* __Pyx_PyObject_FastCall_fallback(PyObject *func, PyObject **args, size_t nargs, PyObject *kwargs) { + PyObject *argstuple; + PyObject *result = 0; + size_t i; + argstuple = PyTuple_New((Py_ssize_t)nargs); + if (unlikely(!argstuple)) return NULL; + for (i = 0; i < nargs; i++) { + Py_INCREF(args[i]); + if (__Pyx_PyTuple_SET_ITEM(argstuple, (Py_ssize_t)i, args[i]) < 0) goto bad; + } + result = __Pyx_PyObject_Call(func, argstuple, kwargs); + bad: + Py_DECREF(argstuple); + return result; +} +#endif +static CYTHON_INLINE PyObject* __Pyx_PyObject_FastCallDict(PyObject *func, PyObject **args, size_t _nargs, PyObject *kwargs) { + Py_ssize_t nargs = __Pyx_PyVectorcall_NARGS(_nargs); +#if CYTHON_COMPILING_IN_CPYTHON + if (nargs == 0 && kwargs == NULL) { + if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_NOARGS)) + return __Pyx_PyObject_CallMethO(func, NULL); + } + else if (nargs == 1 && kwargs == NULL) { + if (__Pyx_CyOrPyCFunction_Check(func) && likely( __Pyx_CyOrPyCFunction_GET_FLAGS(func) & METH_O)) + return __Pyx_PyObject_CallMethO(func, args[0]); + } +#endif + #if PY_VERSION_HEX < 0x030800B1 + #if CYTHON_FAST_PYCCALL + if (PyCFunction_Check(func)) { + if (kwargs) { + return _PyCFunction_FastCallDict(func, args, nargs, kwargs); + } else { + return _PyCFunction_FastCallKeywords(func, args, nargs, NULL); + } + } + #if PY_VERSION_HEX >= 0x030700A1 + if (!kwargs && __Pyx_IS_TYPE(func, &PyMethodDescr_Type)) { + return _PyMethodDescr_FastCallKeywords(func, args, nargs, NULL); + } + #endif + #endif + #if CYTHON_FAST_PYCALL + if (PyFunction_Check(func)) { + return __Pyx_PyFunction_FastCallDict(func, args, nargs, kwargs); + } + #endif + #endif + if (kwargs == NULL) { + #if CYTHON_VECTORCALL + #if PY_VERSION_HEX < 0x03090000 + vectorcallfunc f = _PyVectorcall_Function(func); + #else + vectorcallfunc f = PyVectorcall_Function(func); + #endif + if (f) { + return f(func, args, (size_t)nargs, NULL); + } + #elif defined(__Pyx_CyFunction_USED) && CYTHON_BACKPORT_VECTORCALL + if (__Pyx_CyFunction_CheckExact(func)) { + __pyx_vectorcallfunc f = __Pyx_CyFunction_func_vectorcall(func); + if (f) return f(func, args, (size_t)nargs, NULL); + } + #endif + } + if (nargs == 0) { + return __Pyx_PyObject_Call(func, __pyx_empty_tuple, kwargs); + } + #if PY_VERSION_HEX >= 0x03090000 && !CYTHON_COMPILING_IN_LIMITED_API + return PyObject_VectorcallDict(func, args, (size_t)nargs, kwargs); + #else + return __Pyx_PyObject_FastCall_fallback(func, args, (size_t)nargs, kwargs); + #endif +} + +/* RaiseUnexpectedTypeError */ +static int +__Pyx_RaiseUnexpectedTypeError(const char *expected, PyObject *obj) +{ + __Pyx_TypeName obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, "Expected %s, got " __Pyx_FMT_TYPENAME, + expected, obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return 0; +} + +/* CIntToDigits */ +static const char DIGIT_PAIRS_10[2*10*10+1] = { + "00010203040506070809" + "10111213141516171819" + "20212223242526272829" + "30313233343536373839" + "40414243444546474849" + "50515253545556575859" + "60616263646566676869" + "70717273747576777879" + "80818283848586878889" + "90919293949596979899" +}; +static const char DIGIT_PAIRS_8[2*8*8+1] = { + "0001020304050607" + "1011121314151617" + "2021222324252627" + "3031323334353637" + "4041424344454647" + "5051525354555657" + "6061626364656667" + "7071727374757677" +}; +static const char DIGITS_HEX[2*16+1] = { + "0123456789abcdef" + "0123456789ABCDEF" +}; + +/* BuildPyUnicode */ +static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength, + int prepend_sign, char padding_char) { + PyObject *uval; + Py_ssize_t uoffset = ulength - clength; +#if CYTHON_USE_UNICODE_INTERNALS + Py_ssize_t i; +#if CYTHON_PEP393_ENABLED + void *udata; + uval = PyUnicode_New(ulength, 127); + if (unlikely(!uval)) return NULL; + udata = PyUnicode_DATA(uval); +#else + Py_UNICODE *udata; + uval = PyUnicode_FromUnicode(NULL, ulength); + if (unlikely(!uval)) return NULL; + udata = PyUnicode_AS_UNICODE(uval); +#endif + if (uoffset > 0) { + i = 0; + if (prepend_sign) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, 0, '-'); + i++; + } + for (; i < uoffset; i++) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, i, padding_char); + } + } + for (i=0; i < clength; i++) { + __Pyx_PyUnicode_WRITE(PyUnicode_1BYTE_KIND, udata, uoffset+i, chars[i]); + } +#else + { + PyObject *sign = NULL, *padding = NULL; + uval = NULL; + if (uoffset > 0) { + prepend_sign = !!prepend_sign; + if (uoffset > prepend_sign) { + padding = PyUnicode_FromOrdinal(padding_char); + if (likely(padding) && uoffset > prepend_sign + 1) { + PyObject *tmp; + PyObject *repeat = PyInt_FromSsize_t(uoffset - prepend_sign); + if (unlikely(!repeat)) goto done_or_error; + tmp = PyNumber_Multiply(padding, repeat); + Py_DECREF(repeat); + Py_DECREF(padding); + padding = tmp; + } + if (unlikely(!padding)) goto done_or_error; + } + if (prepend_sign) { + sign = PyUnicode_FromOrdinal('-'); + if (unlikely(!sign)) goto done_or_error; + } + } + uval = PyUnicode_DecodeASCII(chars, clength, NULL); + if (likely(uval) && padding) { + PyObject *tmp = PyNumber_Add(padding, uval); + Py_DECREF(uval); + uval = tmp; + } + if (likely(uval) && sign) { + PyObject *tmp = PyNumber_Add(sign, uval); + Py_DECREF(uval); + uval = tmp; + } +done_or_error: + Py_XDECREF(padding); + Py_XDECREF(sign); + } +#endif + return uval; +} + +/* CIntToPyUnicode */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_int(int value, Py_ssize_t width, char padding_char, char format_char) { + char digits[sizeof(int)*3+2]; + char *dpos, *end = digits + sizeof(int)*3+2; + const char *hex_digits = DIGITS_HEX; + Py_ssize_t length, ulength; + int prepend_sign, last_one_off; + int remaining; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (format_char == 'X') { + hex_digits += 16; + format_char = 'x'; + } + remaining = value; + last_one_off = 0; + dpos = end; + do { + int digit_pos; + switch (format_char) { + case 'o': + digit_pos = abs((int)(remaining % (8*8))); + remaining = (int) (remaining / (8*8)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_8 + digit_pos * 2, 2); + last_one_off = (digit_pos < 8); + break; + case 'd': + digit_pos = abs((int)(remaining % (10*10))); + remaining = (int) (remaining / (10*10)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_10 + digit_pos * 2, 2); + last_one_off = (digit_pos < 10); + break; + case 'x': + *(--dpos) = hex_digits[abs((int)(remaining % 16))]; + remaining = (int) (remaining / 16); + break; + default: + assert(0); + break; + } + } while (unlikely(remaining != 0)); + assert(!last_one_off || *dpos == '0'); + dpos += last_one_off; + length = end - dpos; + ulength = length; + prepend_sign = 0; + if (!is_unsigned && value <= neg_one) { + if (padding_char == ' ' || width <= length + 1) { + *(--dpos) = '-'; + ++length; + } else { + prepend_sign = 1; + } + ++ulength; + } + if (width > ulength) { + ulength = width; + } + if (ulength == 1) { + return PyUnicode_FromOrdinal(*dpos); + } + return __Pyx_PyUnicode_BuildFromAscii(ulength, dpos, (int) length, prepend_sign, padding_char); +} + +/* CIntToPyUnicode */ +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_From_Py_ssize_t(Py_ssize_t value, Py_ssize_t width, char padding_char, char format_char) { + char digits[sizeof(Py_ssize_t)*3+2]; + char *dpos, *end = digits + sizeof(Py_ssize_t)*3+2; + const char *hex_digits = DIGITS_HEX; + Py_ssize_t length, ulength; + int prepend_sign, last_one_off; + Py_ssize_t remaining; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const Py_ssize_t neg_one = (Py_ssize_t) -1, const_zero = (Py_ssize_t) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (format_char == 'X') { + hex_digits += 16; + format_char = 'x'; + } + remaining = value; + last_one_off = 0; + dpos = end; + do { + int digit_pos; + switch (format_char) { + case 'o': + digit_pos = abs((int)(remaining % (8*8))); + remaining = (Py_ssize_t) (remaining / (8*8)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_8 + digit_pos * 2, 2); + last_one_off = (digit_pos < 8); + break; + case 'd': + digit_pos = abs((int)(remaining % (10*10))); + remaining = (Py_ssize_t) (remaining / (10*10)); + dpos -= 2; + memcpy(dpos, DIGIT_PAIRS_10 + digit_pos * 2, 2); + last_one_off = (digit_pos < 10); + break; + case 'x': + *(--dpos) = hex_digits[abs((int)(remaining % 16))]; + remaining = (Py_ssize_t) (remaining / 16); + break; + default: + assert(0); + break; + } + } while (unlikely(remaining != 0)); + assert(!last_one_off || *dpos == '0'); + dpos += last_one_off; + length = end - dpos; + ulength = length; + prepend_sign = 0; + if (!is_unsigned && value <= neg_one) { + if (padding_char == ' ' || width <= length + 1) { + *(--dpos) = '-'; + ++length; + } else { + prepend_sign = 1; + } + ++ulength; + } + if (width > ulength) { + ulength = width; + } + if (ulength == 1) { + return PyUnicode_FromOrdinal(*dpos); + } + return __Pyx_PyUnicode_BuildFromAscii(ulength, dpos, (int) length, prepend_sign, padding_char); +} + +/* JoinPyUnicode */ +static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength, + Py_UCS4 max_char) { +#if CYTHON_USE_UNICODE_INTERNALS && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + PyObject *result_uval; + int result_ukind, kind_shift; + Py_ssize_t i, char_pos; + void *result_udata; + CYTHON_MAYBE_UNUSED_VAR(max_char); +#if CYTHON_PEP393_ENABLED + result_uval = PyUnicode_New(result_ulength, max_char); + if (unlikely(!result_uval)) return NULL; + result_ukind = (max_char <= 255) ? PyUnicode_1BYTE_KIND : (max_char <= 65535) ? PyUnicode_2BYTE_KIND : PyUnicode_4BYTE_KIND; + kind_shift = (result_ukind == PyUnicode_4BYTE_KIND) ? 2 : result_ukind - 1; + result_udata = PyUnicode_DATA(result_uval); +#else + result_uval = PyUnicode_FromUnicode(NULL, result_ulength); + if (unlikely(!result_uval)) return NULL; + result_ukind = sizeof(Py_UNICODE); + kind_shift = (result_ukind == 4) ? 2 : result_ukind - 1; + result_udata = PyUnicode_AS_UNICODE(result_uval); +#endif + assert(kind_shift == 2 || kind_shift == 1 || kind_shift == 0); + char_pos = 0; + for (i=0; i < value_count; i++) { + int ukind; + Py_ssize_t ulength; + void *udata; + PyObject *uval = PyTuple_GET_ITEM(value_tuple, i); + if (unlikely(__Pyx_PyUnicode_READY(uval))) + goto bad; + ulength = __Pyx_PyUnicode_GET_LENGTH(uval); + if (unlikely(!ulength)) + continue; + if (unlikely((PY_SSIZE_T_MAX >> kind_shift) - ulength < char_pos)) + goto overflow; + ukind = __Pyx_PyUnicode_KIND(uval); + udata = __Pyx_PyUnicode_DATA(uval); + if (!CYTHON_PEP393_ENABLED || ukind == result_ukind) { + memcpy((char *)result_udata + (char_pos << kind_shift), udata, (size_t) (ulength << kind_shift)); + } else { + #if PY_VERSION_HEX >= 0x030d0000 + if (unlikely(PyUnicode_CopyCharacters(result_uval, char_pos, uval, 0, ulength) < 0)) goto bad; + #elif CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030300F0 || defined(_PyUnicode_FastCopyCharacters) + _PyUnicode_FastCopyCharacters(result_uval, char_pos, uval, 0, ulength); + #else + Py_ssize_t j; + for (j=0; j < ulength; j++) { + Py_UCS4 uchar = __Pyx_PyUnicode_READ(ukind, udata, j); + __Pyx_PyUnicode_WRITE(result_ukind, result_udata, char_pos+j, uchar); + } + #endif + } + char_pos += ulength; + } + return result_uval; +overflow: + PyErr_SetString(PyExc_OverflowError, "join() result is too long for a Python string"); +bad: + Py_DECREF(result_uval); + return NULL; +#else + CYTHON_UNUSED_VAR(max_char); + CYTHON_UNUSED_VAR(result_ulength); + CYTHON_UNUSED_VAR(value_count); + return PyUnicode_Join(__pyx_empty_unicode, value_tuple); +#endif +} + +/* GetAttr */ +static CYTHON_INLINE PyObject *__Pyx_GetAttr(PyObject *o, PyObject *n) { +#if CYTHON_USE_TYPE_SLOTS +#if PY_MAJOR_VERSION >= 3 + if (likely(PyUnicode_Check(n))) +#else + if (likely(PyString_Check(n))) +#endif + return __Pyx_PyObject_GetAttrStr(o, n); +#endif + return PyObject_GetAttr(o, n); +} + +/* GetItemInt */ +static PyObject *__Pyx_GetItemInt_Generic(PyObject *o, PyObject* j) { + PyObject *r; + if (unlikely(!j)) return NULL; + r = PyObject_GetItem(o, j); + Py_DECREF(j); + return r; +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_List_Fast(PyObject *o, Py_ssize_t i, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + Py_ssize_t wrapped_i = i; + if (wraparound & unlikely(i < 0)) { + wrapped_i += PyList_GET_SIZE(o); + } + if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyList_GET_SIZE(o)))) { + PyObject *r = PyList_GET_ITEM(o, wrapped_i); + Py_INCREF(r); + return r; + } + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +#else + return PySequence_GetItem(o, i); +#endif +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Tuple_Fast(PyObject *o, Py_ssize_t i, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + Py_ssize_t wrapped_i = i; + if (wraparound & unlikely(i < 0)) { + wrapped_i += PyTuple_GET_SIZE(o); + } + if ((!boundscheck) || likely(__Pyx_is_valid_index(wrapped_i, PyTuple_GET_SIZE(o)))) { + PyObject *r = PyTuple_GET_ITEM(o, wrapped_i); + Py_INCREF(r); + return r; + } + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +#else + return PySequence_GetItem(o, i); +#endif +} +static CYTHON_INLINE PyObject *__Pyx_GetItemInt_Fast(PyObject *o, Py_ssize_t i, int is_list, + CYTHON_NCP_UNUSED int wraparound, + CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS + if (is_list || PyList_CheckExact(o)) { + Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyList_GET_SIZE(o); + if ((!boundscheck) || (likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o))))) { + PyObject *r = PyList_GET_ITEM(o, n); + Py_INCREF(r); + return r; + } + } + else if (PyTuple_CheckExact(o)) { + Py_ssize_t n = ((!wraparound) | likely(i >= 0)) ? i : i + PyTuple_GET_SIZE(o); + if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyTuple_GET_SIZE(o)))) { + PyObject *r = PyTuple_GET_ITEM(o, n); + Py_INCREF(r); + return r; + } + } else { + PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping; + PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence; + if (mm && mm->mp_subscript) { + PyObject *r, *key = PyInt_FromSsize_t(i); + if (unlikely(!key)) return NULL; + r = mm->mp_subscript(o, key); + Py_DECREF(key); + return r; + } + if (likely(sm && sm->sq_item)) { + if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) { + Py_ssize_t l = sm->sq_length(o); + if (likely(l >= 0)) { + i += l; + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + return NULL; + PyErr_Clear(); + } + } + return sm->sq_item(o, i); + } + } +#else + if (is_list || !PyMapping_Check(o)) { + return PySequence_GetItem(o, i); + } +#endif + return __Pyx_GetItemInt_Generic(o, PyInt_FromSsize_t(i)); +} + +/* PyObjectCallOneArg */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallOneArg(PyObject *func, PyObject *arg) { + PyObject *args[2] = {NULL, arg}; + return __Pyx_PyObject_FastCall(func, args+1, 1 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET); +} + +/* ObjectGetItem */ +#if CYTHON_USE_TYPE_SLOTS +static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) { + PyObject *runerr = NULL; + Py_ssize_t key_value; + key_value = __Pyx_PyIndex_AsSsize_t(index); + if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) { + return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1); + } + if (PyErr_GivenExceptionMatches(runerr, PyExc_OverflowError)) { + __Pyx_TypeName index_type_name = __Pyx_PyType_GetName(Py_TYPE(index)); + PyErr_Clear(); + PyErr_Format(PyExc_IndexError, + "cannot fit '" __Pyx_FMT_TYPENAME "' into an index-sized integer", index_type_name); + __Pyx_DECREF_TypeName(index_type_name); + } + return NULL; +} +static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) { + __Pyx_TypeName obj_type_name; + if (likely(PyType_Check(obj))) { + PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, __pyx_n_s_class_getitem); + if (!meth) { + PyErr_Clear(); + } else { + PyObject *result = __Pyx_PyObject_CallOneArg(meth, key); + Py_DECREF(meth); + return result; + } + } + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' object is not subscriptable", obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return NULL; +} +static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) { + PyTypeObject *tp = Py_TYPE(obj); + PyMappingMethods *mm = tp->tp_as_mapping; + PySequenceMethods *sm = tp->tp_as_sequence; + if (likely(mm && mm->mp_subscript)) { + return mm->mp_subscript(obj, key); + } + if (likely(sm && sm->sq_item)) { + return __Pyx_PyObject_GetIndex(obj, key); + } + return __Pyx_PyObject_GetItem_Slow(obj, key); +} +#endif + +/* KeywordStringCheck */ +static int __Pyx_CheckKeywordStrings( + PyObject *kw, + const char* function_name, + int kw_allowed) +{ + PyObject* key = 0; + Py_ssize_t pos = 0; +#if CYTHON_COMPILING_IN_PYPY + if (!kw_allowed && PyDict_Next(kw, &pos, &key, 0)) + goto invalid_keyword; + return 1; +#else + if (CYTHON_METH_FASTCALL && likely(PyTuple_Check(kw))) { + Py_ssize_t kwsize; +#if CYTHON_ASSUME_SAFE_MACROS + kwsize = PyTuple_GET_SIZE(kw); +#else + kwsize = PyTuple_Size(kw); + if (kwsize < 0) return 0; +#endif + if (unlikely(kwsize == 0)) + return 1; + if (!kw_allowed) { +#if CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kw, 0); +#else + key = PyTuple_GetItem(kw, pos); + if (!key) return 0; +#endif + goto invalid_keyword; + } +#if PY_VERSION_HEX < 0x03090000 + for (pos = 0; pos < kwsize; pos++) { +#if CYTHON_ASSUME_SAFE_MACROS + key = PyTuple_GET_ITEM(kw, pos); +#else + key = PyTuple_GetItem(kw, pos); + if (!key) return 0; +#endif + if (unlikely(!PyUnicode_Check(key))) + goto invalid_keyword_type; + } +#endif + return 1; + } + while (PyDict_Next(kw, &pos, &key, 0)) { + #if PY_MAJOR_VERSION < 3 + if (unlikely(!PyString_Check(key))) + #endif + if (unlikely(!PyUnicode_Check(key))) + goto invalid_keyword_type; + } + if (!kw_allowed && unlikely(key)) + goto invalid_keyword; + return 1; +invalid_keyword_type: + PyErr_Format(PyExc_TypeError, + "%.200s() keywords must be strings", function_name); + return 0; +#endif +invalid_keyword: + #if PY_MAJOR_VERSION < 3 + PyErr_Format(PyExc_TypeError, + "%.200s() got an unexpected keyword argument '%.200s'", + function_name, PyString_AsString(key)); + #else + PyErr_Format(PyExc_TypeError, + "%s() got an unexpected keyword argument '%U'", + function_name, key); + #endif + return 0; +} + +/* DivInt[Py_ssize_t] */ +static CYTHON_INLINE Py_ssize_t __Pyx_div_Py_ssize_t(Py_ssize_t a, Py_ssize_t b) { + Py_ssize_t q = a / b; + Py_ssize_t r = a - q*b; + q -= ((r != 0) & ((r ^ b) < 0)); + return q; +} + +/* GetAttr3 */ +#if __PYX_LIMITED_VERSION_HEX < 0x030d00A1 +static PyObject *__Pyx_GetAttr3Default(PyObject *d) { + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + if (unlikely(!__Pyx_PyErr_ExceptionMatches(PyExc_AttributeError))) + return NULL; + __Pyx_PyErr_Clear(); + Py_INCREF(d); + return d; +} +#endif +static CYTHON_INLINE PyObject *__Pyx_GetAttr3(PyObject *o, PyObject *n, PyObject *d) { + PyObject *r; +#if __PYX_LIMITED_VERSION_HEX >= 0x030d00A1 + int res = PyObject_GetOptionalAttr(o, n, &r); + return (res != 0) ? r : __Pyx_NewRef(d); +#else + #if CYTHON_USE_TYPE_SLOTS + if (likely(PyString_Check(n))) { + r = __Pyx_PyObject_GetAttrStrNoError(o, n); + if (unlikely(!r) && likely(!PyErr_Occurred())) { + r = __Pyx_NewRef(d); + } + return r; + } + #endif + r = PyObject_GetAttr(o, n); + return (likely(r)) ? r : __Pyx_GetAttr3Default(d); +#endif +} + +/* PyDictVersioning */ +#if CYTHON_USE_DICT_VERSIONS && CYTHON_USE_TYPE_SLOTS +static CYTHON_INLINE PY_UINT64_T __Pyx_get_tp_dict_version(PyObject *obj) { + PyObject *dict = Py_TYPE(obj)->tp_dict; + return likely(dict) ? __PYX_GET_DICT_VERSION(dict) : 0; +} +static CYTHON_INLINE PY_UINT64_T __Pyx_get_object_dict_version(PyObject *obj) { + PyObject **dictptr = NULL; + Py_ssize_t offset = Py_TYPE(obj)->tp_dictoffset; + if (offset) { +#if CYTHON_COMPILING_IN_CPYTHON + dictptr = (likely(offset > 0)) ? (PyObject **) ((char *)obj + offset) : _PyObject_GetDictPtr(obj); +#else + dictptr = _PyObject_GetDictPtr(obj); +#endif + } + return (dictptr && *dictptr) ? __PYX_GET_DICT_VERSION(*dictptr) : 0; +} +static CYTHON_INLINE int __Pyx_object_dict_version_matches(PyObject* obj, PY_UINT64_T tp_dict_version, PY_UINT64_T obj_dict_version) { + PyObject *dict = Py_TYPE(obj)->tp_dict; + if (unlikely(!dict) || unlikely(tp_dict_version != __PYX_GET_DICT_VERSION(dict))) + return 0; + return obj_dict_version == __Pyx_get_object_dict_version(obj); +} +#endif + +/* GetModuleGlobalName */ +#if CYTHON_USE_DICT_VERSIONS +static PyObject *__Pyx__GetModuleGlobalName(PyObject *name, PY_UINT64_T *dict_version, PyObject **dict_cached_value) +#else +static CYTHON_INLINE PyObject *__Pyx__GetModuleGlobalName(PyObject *name) +#endif +{ + PyObject *result; +#if !CYTHON_AVOID_BORROWED_REFS +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030500A1 && PY_VERSION_HEX < 0x030d0000 + result = _PyDict_GetItem_KnownHash(__pyx_d, name, ((PyASCIIObject *) name)->hash); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } else if (unlikely(PyErr_Occurred())) { + return NULL; + } +#elif CYTHON_COMPILING_IN_LIMITED_API + if (unlikely(!__pyx_m)) { + return NULL; + } + result = PyObject_GetAttr(__pyx_m, name); + if (likely(result)) { + return result; + } +#else + result = PyDict_GetItem(__pyx_d, name); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } +#endif +#else + result = PyObject_GetItem(__pyx_d, name); + __PYX_UPDATE_DICT_CACHE(__pyx_d, result, *dict_cached_value, *dict_version) + if (likely(result)) { + return __Pyx_NewRef(result); + } + PyErr_Clear(); +#endif + return __Pyx_GetBuiltinName(name); +} + +/* RaiseTooManyValuesToUnpack */ +static CYTHON_INLINE void __Pyx_RaiseTooManyValuesError(Py_ssize_t expected) { + PyErr_Format(PyExc_ValueError, + "too many values to unpack (expected %" CYTHON_FORMAT_SSIZE_T "d)", expected); +} + +/* RaiseNeedMoreValuesToUnpack */ +static CYTHON_INLINE void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t index) { + PyErr_Format(PyExc_ValueError, + "need more than %" CYTHON_FORMAT_SSIZE_T "d value%.1s to unpack", + index, (index == 1) ? "" : "s"); +} + +/* RaiseNoneIterError */ +static CYTHON_INLINE void __Pyx_RaiseNoneNotIterableError(void) { + PyErr_SetString(PyExc_TypeError, "'NoneType' object is not iterable"); +} + +/* ExtTypeTest */ +static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { + __Pyx_TypeName obj_type_name; + __Pyx_TypeName type_name; + if (unlikely(!type)) { + PyErr_SetString(PyExc_SystemError, "Missing type object"); + return 0; + } + if (likely(__Pyx_TypeCheck(obj, type))) + return 1; + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + type_name = __Pyx_PyType_GetName(type); + PyErr_Format(PyExc_TypeError, + "Cannot convert " __Pyx_FMT_TYPENAME " to " __Pyx_FMT_TYPENAME, + obj_type_name, type_name); + __Pyx_DECREF_TypeName(obj_type_name); + __Pyx_DECREF_TypeName(type_name); + return 0; +} + +/* GetTopmostException */ +#if CYTHON_USE_EXC_INFO_STACK && CYTHON_FAST_THREAD_STATE +static _PyErr_StackItem * +__Pyx_PyErr_GetTopmostException(PyThreadState *tstate) +{ + _PyErr_StackItem *exc_info = tstate->exc_info; + while ((exc_info->exc_value == NULL || exc_info->exc_value == Py_None) && + exc_info->previous_item != NULL) + { + exc_info = exc_info->previous_item; + } + return exc_info; +} +#endif + +/* SaveResetException */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx__ExceptionSave(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); + PyObject *exc_value = exc_info->exc_value; + if (exc_value == NULL || exc_value == Py_None) { + *value = NULL; + *type = NULL; + *tb = NULL; + } else { + *value = exc_value; + Py_INCREF(*value); + *type = (PyObject*) Py_TYPE(exc_value); + Py_INCREF(*type); + *tb = PyException_GetTraceback(exc_value); + } + #elif CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = __Pyx_PyErr_GetTopmostException(tstate); + *type = exc_info->exc_type; + *value = exc_info->exc_value; + *tb = exc_info->exc_traceback; + Py_XINCREF(*type); + Py_XINCREF(*value); + Py_XINCREF(*tb); + #else + *type = tstate->exc_type; + *value = tstate->exc_value; + *tb = tstate->exc_traceback; + Py_XINCREF(*type); + Py_XINCREF(*value); + Py_XINCREF(*tb); + #endif +} +static CYTHON_INLINE void __Pyx__ExceptionReset(PyThreadState *tstate, PyObject *type, PyObject *value, PyObject *tb) { + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = tstate->exc_info; + PyObject *tmp_value = exc_info->exc_value; + exc_info->exc_value = value; + Py_XDECREF(tmp_value); + Py_XDECREF(type); + Py_XDECREF(tb); + #else + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = type; + exc_info->exc_value = value; + exc_info->exc_traceback = tb; + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = type; + tstate->exc_value = value; + tstate->exc_traceback = tb; + #endif + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); + #endif +} +#endif + +/* GetException */ +#if CYTHON_FAST_THREAD_STATE +static int __Pyx__GetException(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) +#else +static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb) +#endif +{ + PyObject *local_type = NULL, *local_value, *local_tb = NULL; +#if CYTHON_FAST_THREAD_STATE + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if PY_VERSION_HEX >= 0x030C00A6 + local_value = tstate->current_exception; + tstate->current_exception = 0; + if (likely(local_value)) { + local_type = (PyObject*) Py_TYPE(local_value); + Py_INCREF(local_type); + local_tb = PyException_GetTraceback(local_value); + } + #else + local_type = tstate->curexc_type; + local_value = tstate->curexc_value; + local_tb = tstate->curexc_traceback; + tstate->curexc_type = 0; + tstate->curexc_value = 0; + tstate->curexc_traceback = 0; + #endif +#else + PyErr_Fetch(&local_type, &local_value, &local_tb); +#endif + PyErr_NormalizeException(&local_type, &local_value, &local_tb); +#if CYTHON_FAST_THREAD_STATE && PY_VERSION_HEX >= 0x030C00A6 + if (unlikely(tstate->current_exception)) +#elif CYTHON_FAST_THREAD_STATE + if (unlikely(tstate->curexc_type)) +#else + if (unlikely(PyErr_Occurred())) +#endif + goto bad; + #if PY_MAJOR_VERSION >= 3 + if (local_tb) { + if (unlikely(PyException_SetTraceback(local_value, local_tb) < 0)) + goto bad; + } + #endif + Py_XINCREF(local_tb); + Py_XINCREF(local_type); + Py_XINCREF(local_value); + *type = local_type; + *value = local_value; + *tb = local_tb; +#if CYTHON_FAST_THREAD_STATE + #if CYTHON_USE_EXC_INFO_STACK + { + _PyErr_StackItem *exc_info = tstate->exc_info; + #if PY_VERSION_HEX >= 0x030B00a4 + tmp_value = exc_info->exc_value; + exc_info->exc_value = local_value; + tmp_type = NULL; + tmp_tb = NULL; + Py_XDECREF(local_type); + Py_XDECREF(local_tb); + #else + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = local_type; + exc_info->exc_value = local_value; + exc_info->exc_traceback = local_tb; + #endif + } + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = local_type; + tstate->exc_value = local_value; + tstate->exc_traceback = local_tb; + #endif + Py_XDECREF(tmp_type); + Py_XDECREF(tmp_value); + Py_XDECREF(tmp_tb); +#else + PyErr_SetExcInfo(local_type, local_value, local_tb); +#endif + return 0; +bad: + *type = 0; + *value = 0; + *tb = 0; + Py_XDECREF(local_type); + Py_XDECREF(local_value); + Py_XDECREF(local_tb); + return -1; +} + +/* SwapException */ +#if CYTHON_FAST_THREAD_STATE +static CYTHON_INLINE void __Pyx__ExceptionSwap(PyThreadState *tstate, PyObject **type, PyObject **value, PyObject **tb) { + PyObject *tmp_type, *tmp_value, *tmp_tb; + #if CYTHON_USE_EXC_INFO_STACK && PY_VERSION_HEX >= 0x030B00a4 + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_value = exc_info->exc_value; + exc_info->exc_value = *value; + if (tmp_value == NULL || tmp_value == Py_None) { + Py_XDECREF(tmp_value); + tmp_value = NULL; + tmp_type = NULL; + tmp_tb = NULL; + } else { + tmp_type = (PyObject*) Py_TYPE(tmp_value); + Py_INCREF(tmp_type); + #if CYTHON_COMPILING_IN_CPYTHON + tmp_tb = ((PyBaseExceptionObject*) tmp_value)->traceback; + Py_XINCREF(tmp_tb); + #else + tmp_tb = PyException_GetTraceback(tmp_value); + #endif + } + #elif CYTHON_USE_EXC_INFO_STACK + _PyErr_StackItem *exc_info = tstate->exc_info; + tmp_type = exc_info->exc_type; + tmp_value = exc_info->exc_value; + tmp_tb = exc_info->exc_traceback; + exc_info->exc_type = *type; + exc_info->exc_value = *value; + exc_info->exc_traceback = *tb; + #else + tmp_type = tstate->exc_type; + tmp_value = tstate->exc_value; + tmp_tb = tstate->exc_traceback; + tstate->exc_type = *type; + tstate->exc_value = *value; + tstate->exc_traceback = *tb; + #endif + *type = tmp_type; + *value = tmp_value; + *tb = tmp_tb; +} +#else +static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb) { + PyObject *tmp_type, *tmp_value, *tmp_tb; + PyErr_GetExcInfo(&tmp_type, &tmp_value, &tmp_tb); + PyErr_SetExcInfo(*type, *value, *tb); + *type = tmp_type; + *value = tmp_value; + *tb = tmp_tb; +} +#endif + +/* Import */ +static PyObject *__Pyx_Import(PyObject *name, PyObject *from_list, int level) { + PyObject *module = 0; + PyObject *empty_dict = 0; + PyObject *empty_list = 0; + #if PY_MAJOR_VERSION < 3 + PyObject *py_import; + py_import = __Pyx_PyObject_GetAttrStr(__pyx_b, __pyx_n_s_import); + if (unlikely(!py_import)) + goto bad; + if (!from_list) { + empty_list = PyList_New(0); + if (unlikely(!empty_list)) + goto bad; + from_list = empty_list; + } + #endif + empty_dict = PyDict_New(); + if (unlikely(!empty_dict)) + goto bad; + { + #if PY_MAJOR_VERSION >= 3 + if (level == -1) { + if (strchr(__Pyx_MODULE_NAME, '.') != NULL) { + module = PyImport_ImportModuleLevelObject( + name, __pyx_d, empty_dict, from_list, 1); + if (unlikely(!module)) { + if (unlikely(!PyErr_ExceptionMatches(PyExc_ImportError))) + goto bad; + PyErr_Clear(); + } + } + level = 0; + } + #endif + if (!module) { + #if PY_MAJOR_VERSION < 3 + PyObject *py_level = PyInt_FromLong(level); + if (unlikely(!py_level)) + goto bad; + module = PyObject_CallFunctionObjArgs(py_import, + name, __pyx_d, empty_dict, from_list, py_level, (PyObject *)NULL); + Py_DECREF(py_level); + #else + module = PyImport_ImportModuleLevelObject( + name, __pyx_d, empty_dict, from_list, level); + #endif + } + } +bad: + Py_XDECREF(empty_dict); + Py_XDECREF(empty_list); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(py_import); + #endif + return module; +} + +/* ImportDottedModule */ +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx__ImportDottedModule_Error(PyObject *name, PyObject *parts_tuple, Py_ssize_t count) { + PyObject *partial_name = NULL, *slice = NULL, *sep = NULL; + if (unlikely(PyErr_Occurred())) { + PyErr_Clear(); + } + if (likely(PyTuple_GET_SIZE(parts_tuple) == count)) { + partial_name = name; + } else { + slice = PySequence_GetSlice(parts_tuple, 0, count); + if (unlikely(!slice)) + goto bad; + sep = PyUnicode_FromStringAndSize(".", 1); + if (unlikely(!sep)) + goto bad; + partial_name = PyUnicode_Join(sep, slice); + } + PyErr_Format( +#if PY_MAJOR_VERSION < 3 + PyExc_ImportError, + "No module named '%s'", PyString_AS_STRING(partial_name)); +#else +#if PY_VERSION_HEX >= 0x030600B1 + PyExc_ModuleNotFoundError, +#else + PyExc_ImportError, +#endif + "No module named '%U'", partial_name); +#endif +bad: + Py_XDECREF(sep); + Py_XDECREF(slice); + Py_XDECREF(partial_name); + return NULL; +} +#endif +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx__ImportDottedModule_Lookup(PyObject *name) { + PyObject *imported_module; +#if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400) + PyObject *modules = PyImport_GetModuleDict(); + if (unlikely(!modules)) + return NULL; + imported_module = __Pyx_PyDict_GetItemStr(modules, name); + Py_XINCREF(imported_module); +#else + imported_module = PyImport_GetModule(name); +#endif + return imported_module; +} +#endif +#if PY_MAJOR_VERSION >= 3 +static PyObject *__Pyx_ImportDottedModule_WalkParts(PyObject *module, PyObject *name, PyObject *parts_tuple) { + Py_ssize_t i, nparts; + nparts = PyTuple_GET_SIZE(parts_tuple); + for (i=1; i < nparts && module; i++) { + PyObject *part, *submodule; +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + part = PyTuple_GET_ITEM(parts_tuple, i); +#else + part = PySequence_ITEM(parts_tuple, i); +#endif + submodule = __Pyx_PyObject_GetAttrStrNoError(module, part); +#if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) + Py_DECREF(part); +#endif + Py_DECREF(module); + module = submodule; + } + if (unlikely(!module)) { + return __Pyx__ImportDottedModule_Error(name, parts_tuple, i); + } + return module; +} +#endif +static PyObject *__Pyx__ImportDottedModule(PyObject *name, PyObject *parts_tuple) { +#if PY_MAJOR_VERSION < 3 + PyObject *module, *from_list, *star = __pyx_n_s__3; + CYTHON_UNUSED_VAR(parts_tuple); + from_list = PyList_New(1); + if (unlikely(!from_list)) + return NULL; + Py_INCREF(star); + PyList_SET_ITEM(from_list, 0, star); + module = __Pyx_Import(name, from_list, 0); + Py_DECREF(from_list); + return module; +#else + PyObject *imported_module; + PyObject *module = __Pyx_Import(name, NULL, 0); + if (!parts_tuple || unlikely(!module)) + return module; + imported_module = __Pyx__ImportDottedModule_Lookup(name); + if (likely(imported_module)) { + Py_DECREF(module); + return imported_module; + } + PyErr_Clear(); + return __Pyx_ImportDottedModule_WalkParts(module, name, parts_tuple); +#endif +} +static PyObject *__Pyx_ImportDottedModule(PyObject *name, PyObject *parts_tuple) { +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030400B1 + PyObject *module = __Pyx__ImportDottedModule_Lookup(name); + if (likely(module)) { + PyObject *spec = __Pyx_PyObject_GetAttrStrNoError(module, __pyx_n_s_spec); + if (likely(spec)) { + PyObject *unsafe = __Pyx_PyObject_GetAttrStrNoError(spec, __pyx_n_s_initializing); + if (likely(!unsafe || !__Pyx_PyObject_IsTrue(unsafe))) { + Py_DECREF(spec); + spec = NULL; + } + Py_XDECREF(unsafe); + } + if (likely(!spec)) { + PyErr_Clear(); + return module; + } + Py_DECREF(spec); + Py_DECREF(module); + } else if (PyErr_Occurred()) { + PyErr_Clear(); + } +#endif + return __Pyx__ImportDottedModule(name, parts_tuple); +} + +/* FastTypeChecks */ +#if CYTHON_COMPILING_IN_CPYTHON +static int __Pyx_InBases(PyTypeObject *a, PyTypeObject *b) { + while (a) { + a = __Pyx_PyType_GetSlot(a, tp_base, PyTypeObject*); + if (a == b) + return 1; + } + return b == &PyBaseObject_Type; +} +static CYTHON_INLINE int __Pyx_IsSubtype(PyTypeObject *a, PyTypeObject *b) { + PyObject *mro; + if (a == b) return 1; + mro = a->tp_mro; + if (likely(mro)) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(mro); + for (i = 0; i < n; i++) { + if (PyTuple_GET_ITEM(mro, i) == (PyObject *)b) + return 1; + } + return 0; + } + return __Pyx_InBases(a, b); +} +static CYTHON_INLINE int __Pyx_IsAnySubtype2(PyTypeObject *cls, PyTypeObject *a, PyTypeObject *b) { + PyObject *mro; + if (cls == a || cls == b) return 1; + mro = cls->tp_mro; + if (likely(mro)) { + Py_ssize_t i, n; + n = PyTuple_GET_SIZE(mro); + for (i = 0; i < n; i++) { + PyObject *base = PyTuple_GET_ITEM(mro, i); + if (base == (PyObject *)a || base == (PyObject *)b) + return 1; + } + return 0; + } + return __Pyx_InBases(cls, a) || __Pyx_InBases(cls, b); +} +#if PY_MAJOR_VERSION == 2 +static int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject* exc_type2) { + PyObject *exception, *value, *tb; + int res; + __Pyx_PyThreadState_declare + __Pyx_PyThreadState_assign + __Pyx_ErrFetch(&exception, &value, &tb); + res = exc_type1 ? PyObject_IsSubclass(err, exc_type1) : 0; + if (unlikely(res == -1)) { + PyErr_WriteUnraisable(err); + res = 0; + } + if (!res) { + res = PyObject_IsSubclass(err, exc_type2); + if (unlikely(res == -1)) { + PyErr_WriteUnraisable(err); + res = 0; + } + } + __Pyx_ErrRestore(exception, value, tb); + return res; +} +#else +static CYTHON_INLINE int __Pyx_inner_PyErr_GivenExceptionMatches2(PyObject *err, PyObject* exc_type1, PyObject *exc_type2) { + if (exc_type1) { + return __Pyx_IsAnySubtype2((PyTypeObject*)err, (PyTypeObject*)exc_type1, (PyTypeObject*)exc_type2); + } else { + return __Pyx_IsSubtype((PyTypeObject*)err, (PyTypeObject*)exc_type2); + } +} +#endif +static int __Pyx_PyErr_GivenExceptionMatchesTuple(PyObject *exc_type, PyObject *tuple) { + Py_ssize_t i, n; + assert(PyExceptionClass_Check(exc_type)); + n = PyTuple_GET_SIZE(tuple); +#if PY_MAJOR_VERSION >= 3 + for (i=0; itp_as_sequence && type->tp_as_sequence->sq_repeat)) { + return type->tp_as_sequence->sq_repeat(seq, mul); + } else +#endif + { + return __Pyx_PySequence_Multiply_Generic(seq, mul); + } +} + +/* SetItemInt */ +static int __Pyx_SetItemInt_Generic(PyObject *o, PyObject *j, PyObject *v) { + int r; + if (unlikely(!j)) return -1; + r = PyObject_SetItem(o, j, v); + Py_DECREF(j); + return r; +} +static CYTHON_INLINE int __Pyx_SetItemInt_Fast(PyObject *o, Py_ssize_t i, PyObject *v, int is_list, + CYTHON_NCP_UNUSED int wraparound, CYTHON_NCP_UNUSED int boundscheck) { +#if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS && CYTHON_USE_TYPE_SLOTS + if (is_list || PyList_CheckExact(o)) { + Py_ssize_t n = (!wraparound) ? i : ((likely(i >= 0)) ? i : i + PyList_GET_SIZE(o)); + if ((!boundscheck) || likely(__Pyx_is_valid_index(n, PyList_GET_SIZE(o)))) { + PyObject* old = PyList_GET_ITEM(o, n); + Py_INCREF(v); + PyList_SET_ITEM(o, n, v); + Py_DECREF(old); + return 1; + } + } else { + PyMappingMethods *mm = Py_TYPE(o)->tp_as_mapping; + PySequenceMethods *sm = Py_TYPE(o)->tp_as_sequence; + if (mm && mm->mp_ass_subscript) { + int r; + PyObject *key = PyInt_FromSsize_t(i); + if (unlikely(!key)) return -1; + r = mm->mp_ass_subscript(o, key, v); + Py_DECREF(key); + return r; + } + if (likely(sm && sm->sq_ass_item)) { + if (wraparound && unlikely(i < 0) && likely(sm->sq_length)) { + Py_ssize_t l = sm->sq_length(o); + if (likely(l >= 0)) { + i += l; + } else { + if (!PyErr_ExceptionMatches(PyExc_OverflowError)) + return -1; + PyErr_Clear(); + } + } + return sm->sq_ass_item(o, i, v); + } + } +#else + if (is_list || !PyMapping_Check(o)) + { + return PySequence_SetItem(o, i, v); + } +#endif + return __Pyx_SetItemInt_Generic(o, PyInt_FromSsize_t(i), v); +} + +/* RaiseUnboundLocalError */ +static CYTHON_INLINE void __Pyx_RaiseUnboundLocalError(const char *varname) { + PyErr_Format(PyExc_UnboundLocalError, "local variable '%s' referenced before assignment", varname); +} + +/* DivInt[long] */ +static CYTHON_INLINE long __Pyx_div_long(long a, long b) { + long q = a / b; + long r = a - q*b; + q -= ((r != 0) & ((r ^ b) < 0)); + return q; +} + +/* ImportFrom */ +static PyObject* __Pyx_ImportFrom(PyObject* module, PyObject* name) { + PyObject* value = __Pyx_PyObject_GetAttrStr(module, name); + if (unlikely(!value) && PyErr_ExceptionMatches(PyExc_AttributeError)) { + const char* module_name_str = 0; + PyObject* module_name = 0; + PyObject* module_dot = 0; + PyObject* full_name = 0; + PyErr_Clear(); + module_name_str = PyModule_GetName(module); + if (unlikely(!module_name_str)) { goto modbad; } + module_name = PyUnicode_FromString(module_name_str); + if (unlikely(!module_name)) { goto modbad; } + module_dot = PyUnicode_Concat(module_name, __pyx_kp_u__2); + if (unlikely(!module_dot)) { goto modbad; } + full_name = PyUnicode_Concat(module_dot, name); + if (unlikely(!full_name)) { goto modbad; } + #if PY_VERSION_HEX < 0x030700A1 || (CYTHON_COMPILING_IN_PYPY && PYPY_VERSION_NUM < 0x07030400) + { + PyObject *modules = PyImport_GetModuleDict(); + if (unlikely(!modules)) + goto modbad; + value = PyObject_GetItem(modules, full_name); + } + #else + value = PyImport_GetModule(full_name); + #endif + modbad: + Py_XDECREF(full_name); + Py_XDECREF(module_dot); + Py_XDECREF(module_name); + } + if (unlikely(!value)) { + PyErr_Format(PyExc_ImportError, + #if PY_MAJOR_VERSION < 3 + "cannot import name %.230s", PyString_AS_STRING(name)); + #else + "cannot import name %S", name); + #endif + } + return value; +} + +/* HasAttr */ +static CYTHON_INLINE int __Pyx_HasAttr(PyObject *o, PyObject *n) { + PyObject *r; + if (unlikely(!__Pyx_PyBaseString_Check(n))) { + PyErr_SetString(PyExc_TypeError, + "hasattr(): attribute name must be string"); + return -1; + } + r = __Pyx_GetAttr(o, n); + if (!r) { + PyErr_Clear(); + return 0; + } else { + Py_DECREF(r); + return 1; + } +} + +/* ErrOccurredWithGIL */ +static CYTHON_INLINE int __Pyx_ErrOccurredWithGIL(void) { + int err; + #ifdef WITH_THREAD + PyGILState_STATE _save = PyGILState_Ensure(); + #endif + err = !!PyErr_Occurred(); + #ifdef WITH_THREAD + PyGILState_Release(_save); + #endif + return err; +} + +/* PyObject_GenericGetAttrNoDict */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject *__Pyx_RaiseGenericGetAttributeError(PyTypeObject *tp, PyObject *attr_name) { + __Pyx_TypeName type_name = __Pyx_PyType_GetName(tp); + PyErr_Format(PyExc_AttributeError, +#if PY_MAJOR_VERSION >= 3 + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%U'", + type_name, attr_name); +#else + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%.400s'", + type_name, PyString_AS_STRING(attr_name)); +#endif + __Pyx_DECREF_TypeName(type_name); + return NULL; +} +static CYTHON_INLINE PyObject* __Pyx_PyObject_GenericGetAttrNoDict(PyObject* obj, PyObject* attr_name) { + PyObject *descr; + PyTypeObject *tp = Py_TYPE(obj); + if (unlikely(!PyString_Check(attr_name))) { + return PyObject_GenericGetAttr(obj, attr_name); + } + assert(!tp->tp_dictoffset); + descr = _PyType_Lookup(tp, attr_name); + if (unlikely(!descr)) { + return __Pyx_RaiseGenericGetAttributeError(tp, attr_name); + } + Py_INCREF(descr); + #if PY_MAJOR_VERSION < 3 + if (likely(PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_HAVE_CLASS))) + #endif + { + descrgetfunc f = Py_TYPE(descr)->tp_descr_get; + if (unlikely(f)) { + PyObject *res = f(descr, obj, (PyObject *)tp); + Py_DECREF(descr); + return res; + } + } + return descr; +} +#endif + +/* PyObject_GenericGetAttr */ +#if CYTHON_USE_TYPE_SLOTS && CYTHON_USE_PYTYPE_LOOKUP && PY_VERSION_HEX < 0x03070000 +static PyObject* __Pyx_PyObject_GenericGetAttr(PyObject* obj, PyObject* attr_name) { + if (unlikely(Py_TYPE(obj)->tp_dictoffset)) { + return PyObject_GenericGetAttr(obj, attr_name); + } + return __Pyx_PyObject_GenericGetAttrNoDict(obj, attr_name); +} +#endif + +/* FixUpExtensionType */ +#if CYTHON_USE_TYPE_SPECS +static int __Pyx_fix_up_extension_type_from_spec(PyType_Spec *spec, PyTypeObject *type) { +#if PY_VERSION_HEX > 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + CYTHON_UNUSED_VAR(spec); + CYTHON_UNUSED_VAR(type); +#else + const PyType_Slot *slot = spec->slots; + while (slot && slot->slot && slot->slot != Py_tp_members) + slot++; + if (slot && slot->slot == Py_tp_members) { + int changed = 0; +#if !(PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON) + const +#endif + PyMemberDef *memb = (PyMemberDef*) slot->pfunc; + while (memb && memb->name) { + if (memb->name[0] == '_' && memb->name[1] == '_') { +#if PY_VERSION_HEX < 0x030900b1 + if (strcmp(memb->name, "__weaklistoffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); + type->tp_weaklistoffset = memb->offset; + changed = 1; + } + else if (strcmp(memb->name, "__dictoffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); + type->tp_dictoffset = memb->offset; + changed = 1; + } +#if CYTHON_METH_FASTCALL + else if (strcmp(memb->name, "__vectorcalloffset__") == 0) { + assert(memb->type == T_PYSSIZET); + assert(memb->flags == READONLY); +#if PY_VERSION_HEX >= 0x030800b4 + type->tp_vectorcall_offset = memb->offset; +#else + type->tp_print = (printfunc) memb->offset; +#endif + changed = 1; + } +#endif +#else + if ((0)); +#endif +#if PY_VERSION_HEX <= 0x030900b1 && CYTHON_COMPILING_IN_CPYTHON + else if (strcmp(memb->name, "__module__") == 0) { + PyObject *descr; + assert(memb->type == T_OBJECT); + assert(memb->flags == 0 || memb->flags == READONLY); + descr = PyDescr_NewMember(type, memb); + if (unlikely(!descr)) + return -1; + if (unlikely(PyDict_SetItem(type->tp_dict, PyDescr_NAME(descr), descr) < 0)) { + Py_DECREF(descr); + return -1; + } + Py_DECREF(descr); + changed = 1; + } +#endif + } + memb++; + } + if (changed) + PyType_Modified(type); + } +#endif + return 0; +} +#endif + +/* PyObjectCallNoArg */ +static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) { + PyObject *arg[2] = {NULL, NULL}; + return __Pyx_PyObject_FastCall(func, arg + 1, 0 | __Pyx_PY_VECTORCALL_ARGUMENTS_OFFSET); +} + +/* PyObjectGetMethod */ +static int __Pyx_PyObject_GetMethod(PyObject *obj, PyObject *name, PyObject **method) { + PyObject *attr; +#if CYTHON_UNPACK_METHODS && CYTHON_COMPILING_IN_CPYTHON && CYTHON_USE_PYTYPE_LOOKUP + __Pyx_TypeName type_name; + PyTypeObject *tp = Py_TYPE(obj); + PyObject *descr; + descrgetfunc f = NULL; + PyObject **dictptr, *dict; + int meth_found = 0; + assert (*method == NULL); + if (unlikely(tp->tp_getattro != PyObject_GenericGetAttr)) { + attr = __Pyx_PyObject_GetAttrStr(obj, name); + goto try_unpack; + } + if (unlikely(tp->tp_dict == NULL) && unlikely(PyType_Ready(tp) < 0)) { + return 0; + } + descr = _PyType_Lookup(tp, name); + if (likely(descr != NULL)) { + Py_INCREF(descr); +#if defined(Py_TPFLAGS_METHOD_DESCRIPTOR) && Py_TPFLAGS_METHOD_DESCRIPTOR + if (__Pyx_PyType_HasFeature(Py_TYPE(descr), Py_TPFLAGS_METHOD_DESCRIPTOR)) +#elif PY_MAJOR_VERSION >= 3 + #ifdef __Pyx_CyFunction_USED + if (likely(PyFunction_Check(descr) || __Pyx_IS_TYPE(descr, &PyMethodDescr_Type) || __Pyx_CyFunction_Check(descr))) + #else + if (likely(PyFunction_Check(descr) || __Pyx_IS_TYPE(descr, &PyMethodDescr_Type))) + #endif +#else + #ifdef __Pyx_CyFunction_USED + if (likely(PyFunction_Check(descr) || __Pyx_CyFunction_Check(descr))) + #else + if (likely(PyFunction_Check(descr))) + #endif +#endif + { + meth_found = 1; + } else { + f = Py_TYPE(descr)->tp_descr_get; + if (f != NULL && PyDescr_IsData(descr)) { + attr = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto try_unpack; + } + } + } + dictptr = _PyObject_GetDictPtr(obj); + if (dictptr != NULL && (dict = *dictptr) != NULL) { + Py_INCREF(dict); + attr = __Pyx_PyDict_GetItemStr(dict, name); + if (attr != NULL) { + Py_INCREF(attr); + Py_DECREF(dict); + Py_XDECREF(descr); + goto try_unpack; + } + Py_DECREF(dict); + } + if (meth_found) { + *method = descr; + return 1; + } + if (f != NULL) { + attr = f(descr, obj, (PyObject *)Py_TYPE(obj)); + Py_DECREF(descr); + goto try_unpack; + } + if (likely(descr != NULL)) { + *method = descr; + return 0; + } + type_name = __Pyx_PyType_GetName(tp); + PyErr_Format(PyExc_AttributeError, +#if PY_MAJOR_VERSION >= 3 + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%U'", + type_name, name); +#else + "'" __Pyx_FMT_TYPENAME "' object has no attribute '%.400s'", + type_name, PyString_AS_STRING(name)); +#endif + __Pyx_DECREF_TypeName(type_name); + return 0; +#else + attr = __Pyx_PyObject_GetAttrStr(obj, name); + goto try_unpack; +#endif +try_unpack: +#if CYTHON_UNPACK_METHODS + if (likely(attr) && PyMethod_Check(attr) && likely(PyMethod_GET_SELF(attr) == obj)) { + PyObject *function = PyMethod_GET_FUNCTION(attr); + Py_INCREF(function); + Py_DECREF(attr); + *method = function; + return 1; + } +#endif + *method = attr; + return 0; +} + +/* PyObjectCallMethod0 */ +static PyObject* __Pyx_PyObject_CallMethod0(PyObject* obj, PyObject* method_name) { + PyObject *method = NULL, *result = NULL; + int is_method = __Pyx_PyObject_GetMethod(obj, method_name, &method); + if (likely(is_method)) { + result = __Pyx_PyObject_CallOneArg(method, obj); + Py_DECREF(method); + return result; + } + if (unlikely(!method)) goto bad; + result = __Pyx_PyObject_CallNoArg(method); + Py_DECREF(method); +bad: + return result; +} + +/* ValidateBasesTuple */ +#if CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API || CYTHON_USE_TYPE_SPECS +static int __Pyx_validate_bases_tuple(const char *type_name, Py_ssize_t dictoffset, PyObject *bases) { + Py_ssize_t i, n; +#if CYTHON_ASSUME_SAFE_MACROS + n = PyTuple_GET_SIZE(bases); +#else + n = PyTuple_Size(bases); + if (n < 0) return -1; +#endif + for (i = 1; i < n; i++) + { +#if CYTHON_AVOID_BORROWED_REFS + PyObject *b0 = PySequence_GetItem(bases, i); + if (!b0) return -1; +#elif CYTHON_ASSUME_SAFE_MACROS + PyObject *b0 = PyTuple_GET_ITEM(bases, i); +#else + PyObject *b0 = PyTuple_GetItem(bases, i); + if (!b0) return -1; +#endif + PyTypeObject *b; +#if PY_MAJOR_VERSION < 3 + if (PyClass_Check(b0)) + { + PyErr_Format(PyExc_TypeError, "base class '%.200s' is an old-style class", + PyString_AS_STRING(((PyClassObject*)b0)->cl_name)); +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } +#endif + b = (PyTypeObject*) b0; + if (!__Pyx_PyType_HasFeature(b, Py_TPFLAGS_HEAPTYPE)) + { + __Pyx_TypeName b_name = __Pyx_PyType_GetName(b); + PyErr_Format(PyExc_TypeError, + "base class '" __Pyx_FMT_TYPENAME "' is not a heap type", b_name); + __Pyx_DECREF_TypeName(b_name); +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } + if (dictoffset == 0) + { + Py_ssize_t b_dictoffset = 0; +#if CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY + b_dictoffset = b->tp_dictoffset; +#else + PyObject *py_b_dictoffset = PyObject_GetAttrString((PyObject*)b, "__dictoffset__"); + if (!py_b_dictoffset) goto dictoffset_return; + b_dictoffset = PyLong_AsSsize_t(py_b_dictoffset); + Py_DECREF(py_b_dictoffset); + if (b_dictoffset == -1 && PyErr_Occurred()) goto dictoffset_return; +#endif + if (b_dictoffset) { + { + __Pyx_TypeName b_name = __Pyx_PyType_GetName(b); + PyErr_Format(PyExc_TypeError, + "extension type '%.200s' has no __dict__ slot, " + "but base type '" __Pyx_FMT_TYPENAME "' has: " + "either add 'cdef dict __dict__' to the extension type " + "or add '__slots__ = [...]' to the base type", + type_name, b_name); + __Pyx_DECREF_TypeName(b_name); + } +#if !(CYTHON_USE_TYPE_SLOTS || CYTHON_COMPILING_IN_PYPY) + dictoffset_return: +#endif +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + return -1; + } + } +#if CYTHON_AVOID_BORROWED_REFS + Py_DECREF(b0); +#endif + } + return 0; +} +#endif + +/* PyType_Ready */ +static int __Pyx_PyType_Ready(PyTypeObject *t) { +#if CYTHON_USE_TYPE_SPECS || !(CYTHON_COMPILING_IN_CPYTHON || CYTHON_COMPILING_IN_LIMITED_API) || defined(PYSTON_MAJOR_VERSION) + (void)__Pyx_PyObject_CallMethod0; +#if CYTHON_USE_TYPE_SPECS + (void)__Pyx_validate_bases_tuple; +#endif + return PyType_Ready(t); +#else + int r; + PyObject *bases = __Pyx_PyType_GetSlot(t, tp_bases, PyObject*); + if (bases && unlikely(__Pyx_validate_bases_tuple(t->tp_name, t->tp_dictoffset, bases) == -1)) + return -1; +#if PY_VERSION_HEX >= 0x03050000 && !defined(PYSTON_MAJOR_VERSION) + { + int gc_was_enabled; + #if PY_VERSION_HEX >= 0x030A00b1 + gc_was_enabled = PyGC_Disable(); + (void)__Pyx_PyObject_CallMethod0; + #else + PyObject *ret, *py_status; + PyObject *gc = NULL; + #if PY_VERSION_HEX >= 0x030700a1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM+0 >= 0x07030400) + gc = PyImport_GetModule(__pyx_kp_u_gc); + #endif + if (unlikely(!gc)) gc = PyImport_Import(__pyx_kp_u_gc); + if (unlikely(!gc)) return -1; + py_status = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_isenabled); + if (unlikely(!py_status)) { + Py_DECREF(gc); + return -1; + } + gc_was_enabled = __Pyx_PyObject_IsTrue(py_status); + Py_DECREF(py_status); + if (gc_was_enabled > 0) { + ret = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_disable); + if (unlikely(!ret)) { + Py_DECREF(gc); + return -1; + } + Py_DECREF(ret); + } else if (unlikely(gc_was_enabled == -1)) { + Py_DECREF(gc); + return -1; + } + #endif + t->tp_flags |= Py_TPFLAGS_HEAPTYPE; +#if PY_VERSION_HEX >= 0x030A0000 + t->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE; +#endif +#else + (void)__Pyx_PyObject_CallMethod0; +#endif + r = PyType_Ready(t); +#if PY_VERSION_HEX >= 0x03050000 && !defined(PYSTON_MAJOR_VERSION) + t->tp_flags &= ~Py_TPFLAGS_HEAPTYPE; + #if PY_VERSION_HEX >= 0x030A00b1 + if (gc_was_enabled) + PyGC_Enable(); + #else + if (gc_was_enabled) { + PyObject *tp, *v, *tb; + PyErr_Fetch(&tp, &v, &tb); + ret = __Pyx_PyObject_CallMethod0(gc, __pyx_kp_u_enable); + if (likely(ret || r == -1)) { + Py_XDECREF(ret); + PyErr_Restore(tp, v, tb); + } else { + Py_XDECREF(tp); + Py_XDECREF(v); + Py_XDECREF(tb); + r = -1; + } + } + Py_DECREF(gc); + #endif + } +#endif + return r; +#endif +} + +/* SetVTable */ +static int __Pyx_SetVtable(PyTypeObject *type, void *vtable) { + PyObject *ob = PyCapsule_New(vtable, 0, 0); + if (unlikely(!ob)) + goto bad; +#if CYTHON_COMPILING_IN_LIMITED_API + if (unlikely(PyObject_SetAttr((PyObject *) type, __pyx_n_s_pyx_vtable, ob) < 0)) +#else + if (unlikely(PyDict_SetItem(type->tp_dict, __pyx_n_s_pyx_vtable, ob) < 0)) +#endif + goto bad; + Py_DECREF(ob); + return 0; +bad: + Py_XDECREF(ob); + return -1; +} + +/* GetVTable */ +static void* __Pyx_GetVtable(PyTypeObject *type) { + void* ptr; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *ob = PyObject_GetAttr((PyObject *)type, __pyx_n_s_pyx_vtable); +#else + PyObject *ob = PyObject_GetItem(type->tp_dict, __pyx_n_s_pyx_vtable); +#endif + if (!ob) + goto bad; + ptr = PyCapsule_GetPointer(ob, 0); + if (!ptr && !PyErr_Occurred()) + PyErr_SetString(PyExc_RuntimeError, "invalid vtable found for imported type"); + Py_DECREF(ob); + return ptr; +bad: + Py_XDECREF(ob); + return NULL; +} + +/* MergeVTables */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_MergeVtables(PyTypeObject *type) { + int i; + void** base_vtables; + __Pyx_TypeName tp_base_name; + __Pyx_TypeName base_name; + void* unknown = (void*)-1; + PyObject* bases = type->tp_bases; + int base_depth = 0; + { + PyTypeObject* base = type->tp_base; + while (base) { + base_depth += 1; + base = base->tp_base; + } + } + base_vtables = (void**) malloc(sizeof(void*) * (size_t)(base_depth + 1)); + base_vtables[0] = unknown; + for (i = 1; i < PyTuple_GET_SIZE(bases); i++) { + void* base_vtable = __Pyx_GetVtable(((PyTypeObject*)PyTuple_GET_ITEM(bases, i))); + if (base_vtable != NULL) { + int j; + PyTypeObject* base = type->tp_base; + for (j = 0; j < base_depth; j++) { + if (base_vtables[j] == unknown) { + base_vtables[j] = __Pyx_GetVtable(base); + base_vtables[j + 1] = unknown; + } + if (base_vtables[j] == base_vtable) { + break; + } else if (base_vtables[j] == NULL) { + goto bad; + } + base = base->tp_base; + } + } + } + PyErr_Clear(); + free(base_vtables); + return 0; +bad: + tp_base_name = __Pyx_PyType_GetName(type->tp_base); + base_name = __Pyx_PyType_GetName((PyTypeObject*)PyTuple_GET_ITEM(bases, i)); + PyErr_Format(PyExc_TypeError, + "multiple bases have vtable conflict: '" __Pyx_FMT_TYPENAME "' and '" __Pyx_FMT_TYPENAME "'", tp_base_name, base_name); + __Pyx_DECREF_TypeName(tp_base_name); + __Pyx_DECREF_TypeName(base_name); + free(base_vtables); + return -1; +} +#endif + +/* SetupReduce */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __Pyx_setup_reduce_is_named(PyObject* meth, PyObject* name) { + int ret; + PyObject *name_attr; + name_attr = __Pyx_PyObject_GetAttrStrNoError(meth, __pyx_n_s_name_2); + if (likely(name_attr)) { + ret = PyObject_RichCompareBool(name_attr, name, Py_EQ); + } else { + ret = -1; + } + if (unlikely(ret < 0)) { + PyErr_Clear(); + ret = 0; + } + Py_XDECREF(name_attr); + return ret; +} +static int __Pyx_setup_reduce(PyObject* type_obj) { + int ret = 0; + PyObject *object_reduce = NULL; + PyObject *object_getstate = NULL; + PyObject *object_reduce_ex = NULL; + PyObject *reduce = NULL; + PyObject *reduce_ex = NULL; + PyObject *reduce_cython = NULL; + PyObject *setstate = NULL; + PyObject *setstate_cython = NULL; + PyObject *getstate = NULL; +#if CYTHON_USE_PYTYPE_LOOKUP + getstate = _PyType_Lookup((PyTypeObject*)type_obj, __pyx_n_s_getstate); +#else + getstate = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_getstate); + if (!getstate && PyErr_Occurred()) { + goto __PYX_BAD; + } +#endif + if (getstate) { +#if CYTHON_USE_PYTYPE_LOOKUP + object_getstate = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_getstate); +#else + object_getstate = __Pyx_PyObject_GetAttrStrNoError((PyObject*)&PyBaseObject_Type, __pyx_n_s_getstate); + if (!object_getstate && PyErr_Occurred()) { + goto __PYX_BAD; + } +#endif + if (object_getstate != getstate) { + goto __PYX_GOOD; + } + } +#if CYTHON_USE_PYTYPE_LOOKUP + object_reduce_ex = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; +#else + object_reduce_ex = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce_ex); if (!object_reduce_ex) goto __PYX_BAD; +#endif + reduce_ex = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce_ex); if (unlikely(!reduce_ex)) goto __PYX_BAD; + if (reduce_ex == object_reduce_ex) { +#if CYTHON_USE_PYTYPE_LOOKUP + object_reduce = _PyType_Lookup(&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; +#else + object_reduce = __Pyx_PyObject_GetAttrStr((PyObject*)&PyBaseObject_Type, __pyx_n_s_reduce); if (!object_reduce) goto __PYX_BAD; +#endif + reduce = __Pyx_PyObject_GetAttrStr(type_obj, __pyx_n_s_reduce); if (unlikely(!reduce)) goto __PYX_BAD; + if (reduce == object_reduce || __Pyx_setup_reduce_is_named(reduce, __pyx_n_s_reduce_cython)) { + reduce_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_reduce_cython); + if (likely(reduce_cython)) { + ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce, reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_reduce_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + } else if (reduce == object_reduce || PyErr_Occurred()) { + goto __PYX_BAD; + } + setstate = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate); + if (!setstate) PyErr_Clear(); + if (!setstate || __Pyx_setup_reduce_is_named(setstate, __pyx_n_s_setstate_cython)) { + setstate_cython = __Pyx_PyObject_GetAttrStrNoError(type_obj, __pyx_n_s_setstate_cython); + if (likely(setstate_cython)) { + ret = PyDict_SetItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate, setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + ret = PyDict_DelItem(((PyTypeObject*)type_obj)->tp_dict, __pyx_n_s_setstate_cython); if (unlikely(ret < 0)) goto __PYX_BAD; + } else if (!setstate || PyErr_Occurred()) { + goto __PYX_BAD; + } + } + PyType_Modified((PyTypeObject*)type_obj); + } + } + goto __PYX_GOOD; +__PYX_BAD: + if (!PyErr_Occurred()) { + __Pyx_TypeName type_obj_name = + __Pyx_PyType_GetName((PyTypeObject*)type_obj); + PyErr_Format(PyExc_RuntimeError, + "Unable to initialize pickling for " __Pyx_FMT_TYPENAME, type_obj_name); + __Pyx_DECREF_TypeName(type_obj_name); + } + ret = -1; +__PYX_GOOD: +#if !CYTHON_USE_PYTYPE_LOOKUP + Py_XDECREF(object_reduce); + Py_XDECREF(object_reduce_ex); + Py_XDECREF(object_getstate); + Py_XDECREF(getstate); +#endif + Py_XDECREF(reduce); + Py_XDECREF(reduce_ex); + Py_XDECREF(reduce_cython); + Py_XDECREF(setstate); + Py_XDECREF(setstate_cython); + return ret; +} +#endif + +/* TypeImport */ +#ifndef __PYX_HAVE_RT_ImportType_3_0_11 +#define __PYX_HAVE_RT_ImportType_3_0_11 +static PyTypeObject *__Pyx_ImportType_3_0_11(PyObject *module, const char *module_name, const char *class_name, + size_t size, size_t alignment, enum __Pyx_ImportType_CheckSize_3_0_11 check_size) +{ + PyObject *result = 0; + char warning[200]; + Py_ssize_t basicsize; + Py_ssize_t itemsize; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *py_basicsize; + PyObject *py_itemsize; +#endif + result = PyObject_GetAttrString(module, class_name); + if (!result) + goto bad; + if (!PyType_Check(result)) { + PyErr_Format(PyExc_TypeError, + "%.200s.%.200s is not a type object", + module_name, class_name); + goto bad; + } +#if !CYTHON_COMPILING_IN_LIMITED_API + basicsize = ((PyTypeObject *)result)->tp_basicsize; + itemsize = ((PyTypeObject *)result)->tp_itemsize; +#else + py_basicsize = PyObject_GetAttrString(result, "__basicsize__"); + if (!py_basicsize) + goto bad; + basicsize = PyLong_AsSsize_t(py_basicsize); + Py_DECREF(py_basicsize); + py_basicsize = 0; + if (basicsize == (Py_ssize_t)-1 && PyErr_Occurred()) + goto bad; + py_itemsize = PyObject_GetAttrString(result, "__itemsize__"); + if (!py_itemsize) + goto bad; + itemsize = PyLong_AsSsize_t(py_itemsize); + Py_DECREF(py_itemsize); + py_itemsize = 0; + if (itemsize == (Py_ssize_t)-1 && PyErr_Occurred()) + goto bad; +#endif + if (itemsize) { + if (size % alignment) { + alignment = size % alignment; + } + if (itemsize < (Py_ssize_t)alignment) + itemsize = (Py_ssize_t)alignment; + } + if ((size_t)(basicsize + itemsize) < size) { + PyErr_Format(PyExc_ValueError, + "%.200s.%.200s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd from PyObject", + module_name, class_name, size, basicsize+itemsize); + goto bad; + } + if (check_size == __Pyx_ImportType_CheckSize_Error_3_0_11 && + ((size_t)basicsize > size || (size_t)(basicsize + itemsize) < size)) { + PyErr_Format(PyExc_ValueError, + "%.200s.%.200s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd-%zd from PyObject", + module_name, class_name, size, basicsize, basicsize+itemsize); + goto bad; + } + else if (check_size == __Pyx_ImportType_CheckSize_Warn_3_0_11 && (size_t)basicsize > size) { + PyOS_snprintf(warning, sizeof(warning), + "%s.%s size changed, may indicate binary incompatibility. " + "Expected %zd from C header, got %zd from PyObject", + module_name, class_name, size, basicsize); + if (PyErr_WarnEx(NULL, warning, 0) < 0) goto bad; + } + return (PyTypeObject *)result; +bad: + Py_XDECREF(result); + return NULL; +} +#endif + +/* FetchSharedCythonModule */ +static PyObject *__Pyx_FetchSharedCythonABIModule(void) { + return __Pyx_PyImport_AddModuleRef((char*) __PYX_ABI_MODULE_NAME); +} + +/* FetchCommonType */ +static int __Pyx_VerifyCachedType(PyObject *cached_type, + const char *name, + Py_ssize_t basicsize, + Py_ssize_t expected_basicsize) { + if (!PyType_Check(cached_type)) { + PyErr_Format(PyExc_TypeError, + "Shared Cython type %.200s is not a type object", name); + return -1; + } + if (basicsize != expected_basicsize) { + PyErr_Format(PyExc_TypeError, + "Shared Cython type %.200s has the wrong size, try recompiling", + name); + return -1; + } + return 0; +} +#if !CYTHON_USE_TYPE_SPECS +static PyTypeObject* __Pyx_FetchCommonType(PyTypeObject* type) { + PyObject* abi_module; + const char* object_name; + PyTypeObject *cached_type = NULL; + abi_module = __Pyx_FetchSharedCythonABIModule(); + if (!abi_module) return NULL; + object_name = strrchr(type->tp_name, '.'); + object_name = object_name ? object_name+1 : type->tp_name; + cached_type = (PyTypeObject*) PyObject_GetAttrString(abi_module, object_name); + if (cached_type) { + if (__Pyx_VerifyCachedType( + (PyObject *)cached_type, + object_name, + cached_type->tp_basicsize, + type->tp_basicsize) < 0) { + goto bad; + } + goto done; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; + PyErr_Clear(); + if (PyType_Ready(type) < 0) goto bad; + if (PyObject_SetAttrString(abi_module, object_name, (PyObject *)type) < 0) + goto bad; + Py_INCREF(type); + cached_type = type; +done: + Py_DECREF(abi_module); + return cached_type; +bad: + Py_XDECREF(cached_type); + cached_type = NULL; + goto done; +} +#else +static PyTypeObject *__Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec *spec, PyObject *bases) { + PyObject *abi_module, *cached_type = NULL; + const char* object_name = strrchr(spec->name, '.'); + object_name = object_name ? object_name+1 : spec->name; + abi_module = __Pyx_FetchSharedCythonABIModule(); + if (!abi_module) return NULL; + cached_type = PyObject_GetAttrString(abi_module, object_name); + if (cached_type) { + Py_ssize_t basicsize; +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *py_basicsize; + py_basicsize = PyObject_GetAttrString(cached_type, "__basicsize__"); + if (unlikely(!py_basicsize)) goto bad; + basicsize = PyLong_AsSsize_t(py_basicsize); + Py_DECREF(py_basicsize); + py_basicsize = 0; + if (unlikely(basicsize == (Py_ssize_t)-1) && PyErr_Occurred()) goto bad; +#else + basicsize = likely(PyType_Check(cached_type)) ? ((PyTypeObject*) cached_type)->tp_basicsize : -1; +#endif + if (__Pyx_VerifyCachedType( + cached_type, + object_name, + basicsize, + spec->basicsize) < 0) { + goto bad; + } + goto done; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) goto bad; + PyErr_Clear(); + CYTHON_UNUSED_VAR(module); + cached_type = __Pyx_PyType_FromModuleAndSpec(abi_module, spec, bases); + if (unlikely(!cached_type)) goto bad; + if (unlikely(__Pyx_fix_up_extension_type_from_spec(spec, (PyTypeObject *) cached_type) < 0)) goto bad; + if (PyObject_SetAttrString(abi_module, object_name, cached_type) < 0) goto bad; +done: + Py_DECREF(abi_module); + assert(cached_type == NULL || PyType_Check(cached_type)); + return (PyTypeObject *) cached_type; +bad: + Py_XDECREF(cached_type); + cached_type = NULL; + goto done; +} +#endif + +/* PyVectorcallFastCallDict */ +#if CYTHON_METH_FASTCALL +static PyObject *__Pyx_PyVectorcall_FastCallDict_kw(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw) +{ + PyObject *res = NULL; + PyObject *kwnames; + PyObject **newargs; + PyObject **kwvalues; + Py_ssize_t i, pos; + size_t j; + PyObject *key, *value; + unsigned long keys_are_strings; + Py_ssize_t nkw = PyDict_GET_SIZE(kw); + newargs = (PyObject **)PyMem_Malloc((nargs + (size_t)nkw) * sizeof(args[0])); + if (unlikely(newargs == NULL)) { + PyErr_NoMemory(); + return NULL; + } + for (j = 0; j < nargs; j++) newargs[j] = args[j]; + kwnames = PyTuple_New(nkw); + if (unlikely(kwnames == NULL)) { + PyMem_Free(newargs); + return NULL; + } + kwvalues = newargs + nargs; + pos = i = 0; + keys_are_strings = Py_TPFLAGS_UNICODE_SUBCLASS; + while (PyDict_Next(kw, &pos, &key, &value)) { + keys_are_strings &= Py_TYPE(key)->tp_flags; + Py_INCREF(key); + Py_INCREF(value); + PyTuple_SET_ITEM(kwnames, i, key); + kwvalues[i] = value; + i++; + } + if (unlikely(!keys_are_strings)) { + PyErr_SetString(PyExc_TypeError, "keywords must be strings"); + goto cleanup; + } + res = vc(func, newargs, nargs, kwnames); +cleanup: + Py_DECREF(kwnames); + for (i = 0; i < nkw; i++) + Py_DECREF(kwvalues[i]); + PyMem_Free(newargs); + return res; +} +static CYTHON_INLINE PyObject *__Pyx_PyVectorcall_FastCallDict(PyObject *func, __pyx_vectorcallfunc vc, PyObject *const *args, size_t nargs, PyObject *kw) +{ + if (likely(kw == NULL) || PyDict_GET_SIZE(kw) == 0) { + return vc(func, args, nargs, NULL); + } + return __Pyx_PyVectorcall_FastCallDict_kw(func, vc, args, nargs, kw); +} +#endif + +/* CythonFunctionShared */ +#if CYTHON_COMPILING_IN_LIMITED_API +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) { + if (__Pyx_CyFunction_Check(func)) { + return PyCFunction_GetFunction(((__pyx_CyFunctionObject*)func)->func) == (PyCFunction) cfunc; + } else if (PyCFunction_Check(func)) { + return PyCFunction_GetFunction(func) == (PyCFunction) cfunc; + } + return 0; +} +#else +static CYTHON_INLINE int __Pyx__IsSameCyOrCFunction(PyObject *func, void *cfunc) { + return __Pyx_CyOrPyCFunction_Check(func) && __Pyx_CyOrPyCFunction_GET_FUNCTION(func) == (PyCFunction) cfunc; +} +#endif +static CYTHON_INLINE void __Pyx__CyFunction_SetClassObj(__pyx_CyFunctionObject* f, PyObject* classobj) { +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + __Pyx_Py_XDECREF_SET( + __Pyx_CyFunction_GetClassObj(f), + ((classobj) ? __Pyx_NewRef(classobj) : NULL)); +#else + __Pyx_Py_XDECREF_SET( + ((PyCMethodObject *) (f))->mm_class, + (PyTypeObject*)((classobj) ? __Pyx_NewRef(classobj) : NULL)); +#endif +} +static PyObject * +__Pyx_CyFunction_get_doc(__pyx_CyFunctionObject *op, void *closure) +{ + CYTHON_UNUSED_VAR(closure); + if (unlikely(op->func_doc == NULL)) { +#if CYTHON_COMPILING_IN_LIMITED_API + op->func_doc = PyObject_GetAttrString(op->func, "__doc__"); + if (unlikely(!op->func_doc)) return NULL; +#else + if (((PyCFunctionObject*)op)->m_ml->ml_doc) { +#if PY_MAJOR_VERSION >= 3 + op->func_doc = PyUnicode_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc); +#else + op->func_doc = PyString_FromString(((PyCFunctionObject*)op)->m_ml->ml_doc); +#endif + if (unlikely(op->func_doc == NULL)) + return NULL; + } else { + Py_INCREF(Py_None); + return Py_None; + } +#endif + } + Py_INCREF(op->func_doc); + return op->func_doc; +} +static int +__Pyx_CyFunction_set_doc(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (value == NULL) { + value = Py_None; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_doc, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_name(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(op->func_name == NULL)) { +#if CYTHON_COMPILING_IN_LIMITED_API + op->func_name = PyObject_GetAttrString(op->func, "__name__"); +#elif PY_MAJOR_VERSION >= 3 + op->func_name = PyUnicode_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name); +#else + op->func_name = PyString_InternFromString(((PyCFunctionObject*)op)->m_ml->ml_name); +#endif + if (unlikely(op->func_name == NULL)) + return NULL; + } + Py_INCREF(op->func_name); + return op->func_name; +} +static int +__Pyx_CyFunction_set_name(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); +#if PY_MAJOR_VERSION >= 3 + if (unlikely(value == NULL || !PyUnicode_Check(value))) +#else + if (unlikely(value == NULL || !PyString_Check(value))) +#endif + { + PyErr_SetString(PyExc_TypeError, + "__name__ must be set to a string object"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_name, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_qualname(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + Py_INCREF(op->func_qualname); + return op->func_qualname; +} +static int +__Pyx_CyFunction_set_qualname(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); +#if PY_MAJOR_VERSION >= 3 + if (unlikely(value == NULL || !PyUnicode_Check(value))) +#else + if (unlikely(value == NULL || !PyString_Check(value))) +#endif + { + PyErr_SetString(PyExc_TypeError, + "__qualname__ must be set to a string object"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_qualname, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_dict(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(op->func_dict == NULL)) { + op->func_dict = PyDict_New(); + if (unlikely(op->func_dict == NULL)) + return NULL; + } + Py_INCREF(op->func_dict); + return op->func_dict; +} +static int +__Pyx_CyFunction_set_dict(__pyx_CyFunctionObject *op, PyObject *value, void *context) +{ + CYTHON_UNUSED_VAR(context); + if (unlikely(value == NULL)) { + PyErr_SetString(PyExc_TypeError, + "function's dictionary may not be deleted"); + return -1; + } + if (unlikely(!PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "setting function's dictionary to a non-dict"); + return -1; + } + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->func_dict, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_globals(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(context); + Py_INCREF(op->func_globals); + return op->func_globals; +} +static PyObject * +__Pyx_CyFunction_get_closure(__pyx_CyFunctionObject *op, void *context) +{ + CYTHON_UNUSED_VAR(op); + CYTHON_UNUSED_VAR(context); + Py_INCREF(Py_None); + return Py_None; +} +static PyObject * +__Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op, void *context) +{ + PyObject* result = (op->func_code) ? op->func_code : Py_None; + CYTHON_UNUSED_VAR(context); + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_init_defaults(__pyx_CyFunctionObject *op) { + int result = 0; + PyObject *res = op->defaults_getter((PyObject *) op); + if (unlikely(!res)) + return -1; + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + op->defaults_tuple = PyTuple_GET_ITEM(res, 0); + Py_INCREF(op->defaults_tuple); + op->defaults_kwdict = PyTuple_GET_ITEM(res, 1); + Py_INCREF(op->defaults_kwdict); + #else + op->defaults_tuple = __Pyx_PySequence_ITEM(res, 0); + if (unlikely(!op->defaults_tuple)) result = -1; + else { + op->defaults_kwdict = __Pyx_PySequence_ITEM(res, 1); + if (unlikely(!op->defaults_kwdict)) result = -1; + } + #endif + Py_DECREF(res); + return result; +} +static int +__Pyx_CyFunction_set_defaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value) { + value = Py_None; + } else if (unlikely(value != Py_None && !PyTuple_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__defaults__ must be set to a tuple object"); + return -1; + } + PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__defaults__ will not " + "currently affect the values used in function calls", 1); + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->defaults_tuple, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_defaults(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->defaults_tuple; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + if (op->defaults_getter) { + if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL; + result = op->defaults_tuple; + } else { + result = Py_None; + } + } + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_set_kwdefaults(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value) { + value = Py_None; + } else if (unlikely(value != Py_None && !PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__kwdefaults__ must be set to a dict object"); + return -1; + } + PyErr_WarnEx(PyExc_RuntimeWarning, "changes to cyfunction.__kwdefaults__ will not " + "currently affect the values used in function calls", 1); + Py_INCREF(value); + __Pyx_Py_XDECREF_SET(op->defaults_kwdict, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_kwdefaults(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->defaults_kwdict; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + if (op->defaults_getter) { + if (unlikely(__Pyx_CyFunction_init_defaults(op) < 0)) return NULL; + result = op->defaults_kwdict; + } else { + result = Py_None; + } + } + Py_INCREF(result); + return result; +} +static int +__Pyx_CyFunction_set_annotations(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + if (!value || value == Py_None) { + value = NULL; + } else if (unlikely(!PyDict_Check(value))) { + PyErr_SetString(PyExc_TypeError, + "__annotations__ must be set to a dict object"); + return -1; + } + Py_XINCREF(value); + __Pyx_Py_XDECREF_SET(op->func_annotations, value); + return 0; +} +static PyObject * +__Pyx_CyFunction_get_annotations(__pyx_CyFunctionObject *op, void *context) { + PyObject* result = op->func_annotations; + CYTHON_UNUSED_VAR(context); + if (unlikely(!result)) { + result = PyDict_New(); + if (unlikely(!result)) return NULL; + op->func_annotations = result; + } + Py_INCREF(result); + return result; +} +static PyObject * +__Pyx_CyFunction_get_is_coroutine(__pyx_CyFunctionObject *op, void *context) { + int is_coroutine; + CYTHON_UNUSED_VAR(context); + if (op->func_is_coroutine) { + return __Pyx_NewRef(op->func_is_coroutine); + } + is_coroutine = op->flags & __Pyx_CYFUNCTION_COROUTINE; +#if PY_VERSION_HEX >= 0x03050000 + if (is_coroutine) { + PyObject *module, *fromlist, *marker = __pyx_n_s_is_coroutine; + fromlist = PyList_New(1); + if (unlikely(!fromlist)) return NULL; + Py_INCREF(marker); +#if CYTHON_ASSUME_SAFE_MACROS + PyList_SET_ITEM(fromlist, 0, marker); +#else + if (unlikely(PyList_SetItem(fromlist, 0, marker) < 0)) { + Py_DECREF(marker); + Py_DECREF(fromlist); + return NULL; + } +#endif + module = PyImport_ImportModuleLevelObject(__pyx_n_s_asyncio_coroutines, NULL, NULL, fromlist, 0); + Py_DECREF(fromlist); + if (unlikely(!module)) goto ignore; + op->func_is_coroutine = __Pyx_PyObject_GetAttrStr(module, marker); + Py_DECREF(module); + if (likely(op->func_is_coroutine)) { + return __Pyx_NewRef(op->func_is_coroutine); + } +ignore: + PyErr_Clear(); + } +#endif + op->func_is_coroutine = __Pyx_PyBool_FromLong(is_coroutine); + return __Pyx_NewRef(op->func_is_coroutine); +} +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject * +__Pyx_CyFunction_get_module(__pyx_CyFunctionObject *op, void *context) { + CYTHON_UNUSED_VAR(context); + return PyObject_GetAttrString(op->func, "__module__"); +} +static int +__Pyx_CyFunction_set_module(__pyx_CyFunctionObject *op, PyObject* value, void *context) { + CYTHON_UNUSED_VAR(context); + return PyObject_SetAttrString(op->func, "__module__", value); +} +#endif +static PyGetSetDef __pyx_CyFunction_getsets[] = { + {(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, + {(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, + {(char *) "func_name", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, + {(char *) "__name__", (getter)__Pyx_CyFunction_get_name, (setter)__Pyx_CyFunction_set_name, 0, 0}, + {(char *) "__qualname__", (getter)__Pyx_CyFunction_get_qualname, (setter)__Pyx_CyFunction_set_qualname, 0, 0}, + {(char *) "func_dict", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, + {(char *) "__dict__", (getter)__Pyx_CyFunction_get_dict, (setter)__Pyx_CyFunction_set_dict, 0, 0}, + {(char *) "func_globals", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, + {(char *) "__globals__", (getter)__Pyx_CyFunction_get_globals, 0, 0, 0}, + {(char *) "func_closure", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, + {(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, + {(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, + {(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, + {(char *) "func_defaults", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, + {(char *) "__defaults__", (getter)__Pyx_CyFunction_get_defaults, (setter)__Pyx_CyFunction_set_defaults, 0, 0}, + {(char *) "__kwdefaults__", (getter)__Pyx_CyFunction_get_kwdefaults, (setter)__Pyx_CyFunction_set_kwdefaults, 0, 0}, + {(char *) "__annotations__", (getter)__Pyx_CyFunction_get_annotations, (setter)__Pyx_CyFunction_set_annotations, 0, 0}, + {(char *) "_is_coroutine", (getter)__Pyx_CyFunction_get_is_coroutine, 0, 0, 0}, +#if CYTHON_COMPILING_IN_LIMITED_API + {"__module__", (getter)__Pyx_CyFunction_get_module, (setter)__Pyx_CyFunction_set_module, 0, 0}, +#endif + {0, 0, 0, 0, 0} +}; +static PyMemberDef __pyx_CyFunction_members[] = { +#if !CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__module__", T_OBJECT, offsetof(PyCFunctionObject, m_module), 0, 0}, +#endif +#if CYTHON_USE_TYPE_SPECS + {(char *) "__dictoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_dict), READONLY, 0}, +#if CYTHON_METH_FASTCALL +#if CYTHON_BACKPORT_VECTORCALL + {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_vectorcall), READONLY, 0}, +#else +#if !CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__vectorcalloffset__", T_PYSSIZET, offsetof(PyCFunctionObject, vectorcall), READONLY, 0}, +#endif +#endif +#endif +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API + {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(__pyx_CyFunctionObject, func_weakreflist), READONLY, 0}, +#else + {(char *) "__weaklistoffset__", T_PYSSIZET, offsetof(PyCFunctionObject, m_weakreflist), READONLY, 0}, +#endif +#endif + {0, 0, 0, 0, 0} +}; +static PyObject * +__Pyx_CyFunction_reduce(__pyx_CyFunctionObject *m, PyObject *args) +{ + CYTHON_UNUSED_VAR(args); +#if PY_MAJOR_VERSION >= 3 + Py_INCREF(m->func_qualname); + return m->func_qualname; +#else + return PyString_FromString(((PyCFunctionObject*)m)->m_ml->ml_name); +#endif +} +static PyMethodDef __pyx_CyFunction_methods[] = { + {"__reduce__", (PyCFunction)__Pyx_CyFunction_reduce, METH_VARARGS, 0}, + {0, 0, 0, 0} +}; +#if PY_VERSION_HEX < 0x030500A0 || CYTHON_COMPILING_IN_LIMITED_API +#define __Pyx_CyFunction_weakreflist(cyfunc) ((cyfunc)->func_weakreflist) +#else +#define __Pyx_CyFunction_weakreflist(cyfunc) (((PyCFunctionObject*)cyfunc)->m_weakreflist) +#endif +static PyObject *__Pyx_CyFunction_Init(__pyx_CyFunctionObject *op, PyMethodDef *ml, int flags, PyObject* qualname, + PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { +#if !CYTHON_COMPILING_IN_LIMITED_API + PyCFunctionObject *cf = (PyCFunctionObject*) op; +#endif + if (unlikely(op == NULL)) + return NULL; +#if CYTHON_COMPILING_IN_LIMITED_API + op->func = PyCFunction_NewEx(ml, (PyObject*)op, module); + if (unlikely(!op->func)) return NULL; +#endif + op->flags = flags; + __Pyx_CyFunction_weakreflist(op) = NULL; +#if !CYTHON_COMPILING_IN_LIMITED_API + cf->m_ml = ml; + cf->m_self = (PyObject *) op; +#endif + Py_XINCREF(closure); + op->func_closure = closure; +#if !CYTHON_COMPILING_IN_LIMITED_API + Py_XINCREF(module); + cf->m_module = module; +#endif + op->func_dict = NULL; + op->func_name = NULL; + Py_INCREF(qualname); + op->func_qualname = qualname; + op->func_doc = NULL; +#if PY_VERSION_HEX < 0x030900B1 || CYTHON_COMPILING_IN_LIMITED_API + op->func_classobj = NULL; +#else + ((PyCMethodObject*)op)->mm_class = NULL; +#endif + op->func_globals = globals; + Py_INCREF(op->func_globals); + Py_XINCREF(code); + op->func_code = code; + op->defaults_pyobjects = 0; + op->defaults_size = 0; + op->defaults = NULL; + op->defaults_tuple = NULL; + op->defaults_kwdict = NULL; + op->defaults_getter = NULL; + op->func_annotations = NULL; + op->func_is_coroutine = NULL; +#if CYTHON_METH_FASTCALL + switch (ml->ml_flags & (METH_VARARGS | METH_FASTCALL | METH_NOARGS | METH_O | METH_KEYWORDS | METH_METHOD)) { + case METH_NOARGS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_NOARGS; + break; + case METH_O: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_O; + break; + case METH_METHOD | METH_FASTCALL | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD; + break; + case METH_FASTCALL | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS; + break; + case METH_VARARGS | METH_KEYWORDS: + __Pyx_CyFunction_func_vectorcall(op) = NULL; + break; + default: + PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction"); + Py_DECREF(op); + return NULL; + } +#endif + return (PyObject *) op; +} +static int +__Pyx_CyFunction_clear(__pyx_CyFunctionObject *m) +{ + Py_CLEAR(m->func_closure); +#if CYTHON_COMPILING_IN_LIMITED_API + Py_CLEAR(m->func); +#else + Py_CLEAR(((PyCFunctionObject*)m)->m_module); +#endif + Py_CLEAR(m->func_dict); + Py_CLEAR(m->func_name); + Py_CLEAR(m->func_qualname); + Py_CLEAR(m->func_doc); + Py_CLEAR(m->func_globals); + Py_CLEAR(m->func_code); +#if !CYTHON_COMPILING_IN_LIMITED_API +#if PY_VERSION_HEX < 0x030900B1 + Py_CLEAR(__Pyx_CyFunction_GetClassObj(m)); +#else + { + PyObject *cls = (PyObject*) ((PyCMethodObject *) (m))->mm_class; + ((PyCMethodObject *) (m))->mm_class = NULL; + Py_XDECREF(cls); + } +#endif +#endif + Py_CLEAR(m->defaults_tuple); + Py_CLEAR(m->defaults_kwdict); + Py_CLEAR(m->func_annotations); + Py_CLEAR(m->func_is_coroutine); + if (m->defaults) { + PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); + int i; + for (i = 0; i < m->defaults_pyobjects; i++) + Py_XDECREF(pydefaults[i]); + PyObject_Free(m->defaults); + m->defaults = NULL; + } + return 0; +} +static void __Pyx__CyFunction_dealloc(__pyx_CyFunctionObject *m) +{ + if (__Pyx_CyFunction_weakreflist(m) != NULL) + PyObject_ClearWeakRefs((PyObject *) m); + __Pyx_CyFunction_clear(m); + __Pyx_PyHeapTypeObject_GC_Del(m); +} +static void __Pyx_CyFunction_dealloc(__pyx_CyFunctionObject *m) +{ + PyObject_GC_UnTrack(m); + __Pyx__CyFunction_dealloc(m); +} +static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, void *arg) +{ + Py_VISIT(m->func_closure); +#if CYTHON_COMPILING_IN_LIMITED_API + Py_VISIT(m->func); +#else + Py_VISIT(((PyCFunctionObject*)m)->m_module); +#endif + Py_VISIT(m->func_dict); + Py_VISIT(m->func_name); + Py_VISIT(m->func_qualname); + Py_VISIT(m->func_doc); + Py_VISIT(m->func_globals); + Py_VISIT(m->func_code); +#if !CYTHON_COMPILING_IN_LIMITED_API + Py_VISIT(__Pyx_CyFunction_GetClassObj(m)); +#endif + Py_VISIT(m->defaults_tuple); + Py_VISIT(m->defaults_kwdict); + Py_VISIT(m->func_is_coroutine); + if (m->defaults) { + PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); + int i; + for (i = 0; i < m->defaults_pyobjects; i++) + Py_VISIT(pydefaults[i]); + } + return 0; +} +static PyObject* +__Pyx_CyFunction_repr(__pyx_CyFunctionObject *op) +{ +#if PY_MAJOR_VERSION >= 3 + return PyUnicode_FromFormat("", + op->func_qualname, (void *)op); +#else + return PyString_FromFormat("", + PyString_AsString(op->func_qualname), (void *)op); +#endif +} +static PyObject * __Pyx_CyFunction_CallMethod(PyObject *func, PyObject *self, PyObject *arg, PyObject *kw) { +#if CYTHON_COMPILING_IN_LIMITED_API + PyObject *f = ((__pyx_CyFunctionObject*)func)->func; + PyObject *py_name = NULL; + PyCFunction meth; + int flags; + meth = PyCFunction_GetFunction(f); + if (unlikely(!meth)) return NULL; + flags = PyCFunction_GetFlags(f); + if (unlikely(flags < 0)) return NULL; +#else + PyCFunctionObject* f = (PyCFunctionObject*)func; + PyCFunction meth = f->m_ml->ml_meth; + int flags = f->m_ml->ml_flags; +#endif + Py_ssize_t size; + switch (flags & (METH_VARARGS | METH_KEYWORDS | METH_NOARGS | METH_O)) { + case METH_VARARGS: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) + return (*meth)(self, arg); + break; + case METH_VARARGS | METH_KEYWORDS: + return (*(PyCFunctionWithKeywords)(void*)meth)(self, arg, kw); + case METH_NOARGS: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) { +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(arg); +#else + size = PyTuple_Size(arg); + if (unlikely(size < 0)) return NULL; +#endif + if (likely(size == 0)) + return (*meth)(self, NULL); +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, + "%.200S() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + py_name, size); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, + "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + f->m_ml->ml_name, size); +#endif + return NULL; + } + break; + case METH_O: + if (likely(kw == NULL || PyDict_Size(kw) == 0)) { +#if CYTHON_ASSUME_SAFE_MACROS + size = PyTuple_GET_SIZE(arg); +#else + size = PyTuple_Size(arg); + if (unlikely(size < 0)) return NULL; +#endif + if (likely(size == 1)) { + PyObject *result, *arg0; + #if CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS + arg0 = PyTuple_GET_ITEM(arg, 0); + #else + arg0 = __Pyx_PySequence_ITEM(arg, 0); if (unlikely(!arg0)) return NULL; + #endif + result = (*meth)(self, arg0); + #if !(CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS) + Py_DECREF(arg0); + #endif + return result; + } +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, + "%.200S() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + py_name, size); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, + "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + f->m_ml->ml_name, size); +#endif + return NULL; + } + break; + default: + PyErr_SetString(PyExc_SystemError, "Bad call flags for CyFunction"); + return NULL; + } +#if CYTHON_COMPILING_IN_LIMITED_API + py_name = __Pyx_CyFunction_get_name((__pyx_CyFunctionObject*)func, NULL); + if (!py_name) return NULL; + PyErr_Format(PyExc_TypeError, "%.200S() takes no keyword arguments", + py_name); + Py_DECREF(py_name); +#else + PyErr_Format(PyExc_TypeError, "%.200s() takes no keyword arguments", + f->m_ml->ml_name); +#endif + return NULL; +} +static CYTHON_INLINE PyObject *__Pyx_CyFunction_Call(PyObject *func, PyObject *arg, PyObject *kw) { + PyObject *self, *result; +#if CYTHON_COMPILING_IN_LIMITED_API + self = PyCFunction_GetSelf(((__pyx_CyFunctionObject*)func)->func); + if (unlikely(!self) && PyErr_Occurred()) return NULL; +#else + self = ((PyCFunctionObject*)func)->m_self; +#endif + result = __Pyx_CyFunction_CallMethod(func, self, arg, kw); + return result; +} +static PyObject *__Pyx_CyFunction_CallAsMethod(PyObject *func, PyObject *args, PyObject *kw) { + PyObject *result; + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *) func; +#if CYTHON_METH_FASTCALL + __pyx_vectorcallfunc vc = __Pyx_CyFunction_func_vectorcall(cyfunc); + if (vc) { +#if CYTHON_ASSUME_SAFE_MACROS + return __Pyx_PyVectorcall_FastCallDict(func, vc, &PyTuple_GET_ITEM(args, 0), (size_t)PyTuple_GET_SIZE(args), kw); +#else + (void) &__Pyx_PyVectorcall_FastCallDict; + return PyVectorcall_Call(func, args, kw); +#endif + } +#endif + if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { + Py_ssize_t argc; + PyObject *new_args; + PyObject *self; +#if CYTHON_ASSUME_SAFE_MACROS + argc = PyTuple_GET_SIZE(args); +#else + argc = PyTuple_Size(args); + if (unlikely(!argc) < 0) return NULL; +#endif + new_args = PyTuple_GetSlice(args, 1, argc); + if (unlikely(!new_args)) + return NULL; + self = PyTuple_GetItem(args, 0); + if (unlikely(!self)) { + Py_DECREF(new_args); +#if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_TypeError, + "unbound method %.200S() needs an argument", + cyfunc->func_qualname); +#else + PyErr_SetString(PyExc_TypeError, + "unbound method needs an argument"); +#endif + return NULL; + } + result = __Pyx_CyFunction_CallMethod(func, self, new_args, kw); + Py_DECREF(new_args); + } else { + result = __Pyx_CyFunction_Call(func, args, kw); + } + return result; +} +#if CYTHON_METH_FASTCALL +static CYTHON_INLINE int __Pyx_CyFunction_Vectorcall_CheckArgs(__pyx_CyFunctionObject *cyfunc, Py_ssize_t nargs, PyObject *kwnames) +{ + int ret = 0; + if ((cyfunc->flags & __Pyx_CYFUNCTION_CCLASS) && !(cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD)) { + if (unlikely(nargs < 1)) { + PyErr_Format(PyExc_TypeError, "%.200s() needs an argument", + ((PyCFunctionObject*)cyfunc)->m_ml->ml_name); + return -1; + } + ret = 1; + } + if (unlikely(kwnames) && unlikely(PyTuple_GET_SIZE(kwnames))) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes no keyword arguments", ((PyCFunctionObject*)cyfunc)->m_ml->ml_name); + return -1; + } + return ret; +} +static PyObject * __Pyx_CyFunction_Vectorcall_NOARGS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + if (unlikely(nargs != 0)) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes no arguments (%" CYTHON_FORMAT_SSIZE_T "d given)", + def->ml_name, nargs); + return NULL; + } + return def->ml_meth(self, NULL); +} +static PyObject * __Pyx_CyFunction_Vectorcall_O(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, kwnames)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + if (unlikely(nargs != 1)) { + PyErr_Format(PyExc_TypeError, + "%.200s() takes exactly one argument (%" CYTHON_FORMAT_SSIZE_T "d given)", + def->ml_name, nargs); + return NULL; + } + return def->ml_meth(self, args[0]); +} +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + return ((__Pyx_PyCFunctionFastWithKeywords)(void(*)(void))def->ml_meth)(self, args, nargs, kwnames); +} +static PyObject * __Pyx_CyFunction_Vectorcall_FASTCALL_KEYWORDS_METHOD(PyObject *func, PyObject *const *args, size_t nargsf, PyObject *kwnames) +{ + __pyx_CyFunctionObject *cyfunc = (__pyx_CyFunctionObject *)func; + PyMethodDef* def = ((PyCFunctionObject*)cyfunc)->m_ml; + PyTypeObject *cls = (PyTypeObject *) __Pyx_CyFunction_GetClassObj(cyfunc); +#if CYTHON_BACKPORT_VECTORCALL + Py_ssize_t nargs = (Py_ssize_t)nargsf; +#else + Py_ssize_t nargs = PyVectorcall_NARGS(nargsf); +#endif + PyObject *self; + switch (__Pyx_CyFunction_Vectorcall_CheckArgs(cyfunc, nargs, NULL)) { + case 1: + self = args[0]; + args += 1; + nargs -= 1; + break; + case 0: + self = ((PyCFunctionObject*)cyfunc)->m_self; + break; + default: + return NULL; + } + return ((__Pyx_PyCMethod)(void(*)(void))def->ml_meth)(self, cls, args, (size_t)nargs, kwnames); +} +#endif +#if CYTHON_USE_TYPE_SPECS +static PyType_Slot __pyx_CyFunctionType_slots[] = { + {Py_tp_dealloc, (void *)__Pyx_CyFunction_dealloc}, + {Py_tp_repr, (void *)__Pyx_CyFunction_repr}, + {Py_tp_call, (void *)__Pyx_CyFunction_CallAsMethod}, + {Py_tp_traverse, (void *)__Pyx_CyFunction_traverse}, + {Py_tp_clear, (void *)__Pyx_CyFunction_clear}, + {Py_tp_methods, (void *)__pyx_CyFunction_methods}, + {Py_tp_members, (void *)__pyx_CyFunction_members}, + {Py_tp_getset, (void *)__pyx_CyFunction_getsets}, + {Py_tp_descr_get, (void *)__Pyx_PyMethod_New}, + {0, 0}, +}; +static PyType_Spec __pyx_CyFunctionType_spec = { + __PYX_TYPE_MODULE_PREFIX "cython_function_or_method", + sizeof(__pyx_CyFunctionObject), + 0, +#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR + Py_TPFLAGS_METHOD_DESCRIPTOR | +#endif +#if (defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL) + _Py_TPFLAGS_HAVE_VECTORCALL | +#endif + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + __pyx_CyFunctionType_slots +}; +#else +static PyTypeObject __pyx_CyFunctionType_type = { + PyVarObject_HEAD_INIT(0, 0) + __PYX_TYPE_MODULE_PREFIX "cython_function_or_method", + sizeof(__pyx_CyFunctionObject), + 0, + (destructor) __Pyx_CyFunction_dealloc, +#if !CYTHON_METH_FASTCALL + 0, +#elif CYTHON_BACKPORT_VECTORCALL + (printfunc)offsetof(__pyx_CyFunctionObject, func_vectorcall), +#else + offsetof(PyCFunctionObject, vectorcall), +#endif + 0, + 0, +#if PY_MAJOR_VERSION < 3 + 0, +#else + 0, +#endif + (reprfunc) __Pyx_CyFunction_repr, + 0, + 0, + 0, + 0, + __Pyx_CyFunction_CallAsMethod, + 0, + 0, + 0, + 0, +#ifdef Py_TPFLAGS_METHOD_DESCRIPTOR + Py_TPFLAGS_METHOD_DESCRIPTOR | +#endif +#if defined(_Py_TPFLAGS_HAVE_VECTORCALL) && CYTHON_METH_FASTCALL + _Py_TPFLAGS_HAVE_VECTORCALL | +#endif + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_BASETYPE, + 0, + (traverseproc) __Pyx_CyFunction_traverse, + (inquiry) __Pyx_CyFunction_clear, + 0, +#if PY_VERSION_HEX < 0x030500A0 + offsetof(__pyx_CyFunctionObject, func_weakreflist), +#else + offsetof(PyCFunctionObject, m_weakreflist), +#endif + 0, + 0, + __pyx_CyFunction_methods, + __pyx_CyFunction_members, + __pyx_CyFunction_getsets, + 0, + 0, + __Pyx_PyMethod_New, + 0, + offsetof(__pyx_CyFunctionObject, func_dict), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, +#if PY_VERSION_HEX >= 0x030400a1 + 0, +#endif +#if PY_VERSION_HEX >= 0x030800b1 && (!CYTHON_COMPILING_IN_PYPY || PYPY_VERSION_NUM >= 0x07030800) + 0, +#endif +#if __PYX_NEED_TP_PRINT_SLOT + 0, +#endif +#if PY_VERSION_HEX >= 0x030C0000 + 0, +#endif +#if PY_VERSION_HEX >= 0x030d00A4 + 0, +#endif +#if CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX >= 0x03090000 && PY_VERSION_HEX < 0x030a0000 + 0, +#endif +}; +#endif +static int __pyx_CyFunction_init(PyObject *module) { +#if CYTHON_USE_TYPE_SPECS + __pyx_CyFunctionType = __Pyx_FetchCommonTypeFromSpec(module, &__pyx_CyFunctionType_spec, NULL); +#else + CYTHON_UNUSED_VAR(module); + __pyx_CyFunctionType = __Pyx_FetchCommonType(&__pyx_CyFunctionType_type); +#endif + if (unlikely(__pyx_CyFunctionType == NULL)) { + return -1; + } + return 0; +} +static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults = PyObject_Malloc(size); + if (unlikely(!m->defaults)) + return PyErr_NoMemory(); + memset(m->defaults, 0, size); + m->defaults_pyobjects = pyobjects; + m->defaults_size = size; + return m->defaults; +} +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *func, PyObject *tuple) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults_tuple = tuple; + Py_INCREF(tuple); +} +static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsKwDict(PyObject *func, PyObject *dict) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->defaults_kwdict = dict; + Py_INCREF(dict); +} +static CYTHON_INLINE void __Pyx_CyFunction_SetAnnotationsDict(PyObject *func, PyObject *dict) { + __pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func; + m->func_annotations = dict; + Py_INCREF(dict); +} + +/* CythonFunction */ +static PyObject *__Pyx_CyFunction_New(PyMethodDef *ml, int flags, PyObject* qualname, + PyObject *closure, PyObject *module, PyObject* globals, PyObject* code) { + PyObject *op = __Pyx_CyFunction_Init( + PyObject_GC_New(__pyx_CyFunctionObject, __pyx_CyFunctionType), + ml, flags, qualname, closure, module, globals, code + ); + if (likely(op)) { + PyObject_GC_Track(op); + } + return op; +} + +/* CLineInTraceback */ +#ifndef CYTHON_CLINE_IN_TRACEBACK +static int __Pyx_CLineForTraceback(PyThreadState *tstate, int c_line) { + PyObject *use_cline; + PyObject *ptype, *pvalue, *ptraceback; +#if CYTHON_COMPILING_IN_CPYTHON + PyObject **cython_runtime_dict; +#endif + CYTHON_MAYBE_UNUSED_VAR(tstate); + if (unlikely(!__pyx_cython_runtime)) { + return c_line; + } + __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); +#if CYTHON_COMPILING_IN_CPYTHON + cython_runtime_dict = _PyObject_GetDictPtr(__pyx_cython_runtime); + if (likely(cython_runtime_dict)) { + __PYX_PY_DICT_LOOKUP_IF_MODIFIED( + use_cline, *cython_runtime_dict, + __Pyx_PyDict_GetItemStr(*cython_runtime_dict, __pyx_n_s_cline_in_traceback)) + } else +#endif + { + PyObject *use_cline_obj = __Pyx_PyObject_GetAttrStrNoError(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback); + if (use_cline_obj) { + use_cline = PyObject_Not(use_cline_obj) ? Py_False : Py_True; + Py_DECREF(use_cline_obj); + } else { + PyErr_Clear(); + use_cline = NULL; + } + } + if (!use_cline) { + c_line = 0; + (void) PyObject_SetAttr(__pyx_cython_runtime, __pyx_n_s_cline_in_traceback, Py_False); + } + else if (use_cline == Py_False || (use_cline != Py_True && PyObject_Not(use_cline) != 0)) { + c_line = 0; + } + __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); + return c_line; +} +#endif + +/* CodeObjectCache */ +#if !CYTHON_COMPILING_IN_LIMITED_API +static int __pyx_bisect_code_objects(__Pyx_CodeObjectCacheEntry* entries, int count, int code_line) { + int start = 0, mid = 0, end = count - 1; + if (end >= 0 && code_line > entries[end].code_line) { + return count; + } + while (start < end) { + mid = start + (end - start) / 2; + if (code_line < entries[mid].code_line) { + end = mid; + } else if (code_line > entries[mid].code_line) { + start = mid + 1; + } else { + return mid; + } + } + if (code_line <= entries[mid].code_line) { + return mid; + } else { + return mid + 1; + } +} +static PyCodeObject *__pyx_find_code_object(int code_line) { + PyCodeObject* code_object; + int pos; + if (unlikely(!code_line) || unlikely(!__pyx_code_cache.entries)) { + return NULL; + } + pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); + if (unlikely(pos >= __pyx_code_cache.count) || unlikely(__pyx_code_cache.entries[pos].code_line != code_line)) { + return NULL; + } + code_object = __pyx_code_cache.entries[pos].code_object; + Py_INCREF(code_object); + return code_object; +} +static void __pyx_insert_code_object(int code_line, PyCodeObject* code_object) { + int pos, i; + __Pyx_CodeObjectCacheEntry* entries = __pyx_code_cache.entries; + if (unlikely(!code_line)) { + return; + } + if (unlikely(!entries)) { + entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Malloc(64*sizeof(__Pyx_CodeObjectCacheEntry)); + if (likely(entries)) { + __pyx_code_cache.entries = entries; + __pyx_code_cache.max_count = 64; + __pyx_code_cache.count = 1; + entries[0].code_line = code_line; + entries[0].code_object = code_object; + Py_INCREF(code_object); + } + return; + } + pos = __pyx_bisect_code_objects(__pyx_code_cache.entries, __pyx_code_cache.count, code_line); + if ((pos < __pyx_code_cache.count) && unlikely(__pyx_code_cache.entries[pos].code_line == code_line)) { + PyCodeObject* tmp = entries[pos].code_object; + entries[pos].code_object = code_object; + Py_DECREF(tmp); + return; + } + if (__pyx_code_cache.count == __pyx_code_cache.max_count) { + int new_max = __pyx_code_cache.max_count + 64; + entries = (__Pyx_CodeObjectCacheEntry*)PyMem_Realloc( + __pyx_code_cache.entries, ((size_t)new_max) * sizeof(__Pyx_CodeObjectCacheEntry)); + if (unlikely(!entries)) { + return; + } + __pyx_code_cache.entries = entries; + __pyx_code_cache.max_count = new_max; + } + for (i=__pyx_code_cache.count; i>pos; i--) { + entries[i] = entries[i-1]; + } + entries[pos].code_line = code_line; + entries[pos].code_object = code_object; + __pyx_code_cache.count++; + Py_INCREF(code_object); +} +#endif + +/* AddTraceback */ +#include "compile.h" +#include "frameobject.h" +#include "traceback.h" +#if PY_VERSION_HEX >= 0x030b00a6 && !CYTHON_COMPILING_IN_LIMITED_API + #ifndef Py_BUILD_CORE + #define Py_BUILD_CORE 1 + #endif + #include "internal/pycore_frame.h" +#endif +#if CYTHON_COMPILING_IN_LIMITED_API +static PyObject *__Pyx_PyCode_Replace_For_AddTraceback(PyObject *code, PyObject *scratch_dict, + PyObject *firstlineno, PyObject *name) { + PyObject *replace = NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "co_firstlineno", firstlineno))) return NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "co_name", name))) return NULL; + replace = PyObject_GetAttrString(code, "replace"); + if (likely(replace)) { + PyObject *result; + result = PyObject_Call(replace, __pyx_empty_tuple, scratch_dict); + Py_DECREF(replace); + return result; + } + PyErr_Clear(); + #if __PYX_LIMITED_VERSION_HEX < 0x030780000 + { + PyObject *compiled = NULL, *result = NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "code", code))) return NULL; + if (unlikely(PyDict_SetItemString(scratch_dict, "type", (PyObject*)(&PyType_Type)))) return NULL; + compiled = Py_CompileString( + "out = type(code)(\n" + " code.co_argcount, code.co_kwonlyargcount, code.co_nlocals, code.co_stacksize,\n" + " code.co_flags, code.co_code, code.co_consts, code.co_names,\n" + " code.co_varnames, code.co_filename, co_name, co_firstlineno,\n" + " code.co_lnotab)\n", "", Py_file_input); + if (!compiled) return NULL; + result = PyEval_EvalCode(compiled, scratch_dict, scratch_dict); + Py_DECREF(compiled); + if (!result) PyErr_Print(); + Py_DECREF(result); + result = PyDict_GetItemString(scratch_dict, "out"); + if (result) Py_INCREF(result); + return result; + } + #else + return NULL; + #endif +} +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename) { + PyObject *code_object = NULL, *py_py_line = NULL, *py_funcname = NULL, *dict = NULL; + PyObject *replace = NULL, *getframe = NULL, *frame = NULL; + PyObject *exc_type, *exc_value, *exc_traceback; + int success = 0; + if (c_line) { + (void) __pyx_cfilenm; + (void) __Pyx_CLineForTraceback(__Pyx_PyThreadState_Current, c_line); + } + PyErr_Fetch(&exc_type, &exc_value, &exc_traceback); + code_object = Py_CompileString("_getframe()", filename, Py_eval_input); + if (unlikely(!code_object)) goto bad; + py_py_line = PyLong_FromLong(py_line); + if (unlikely(!py_py_line)) goto bad; + py_funcname = PyUnicode_FromString(funcname); + if (unlikely(!py_funcname)) goto bad; + dict = PyDict_New(); + if (unlikely(!dict)) goto bad; + { + PyObject *old_code_object = code_object; + code_object = __Pyx_PyCode_Replace_For_AddTraceback(code_object, dict, py_py_line, py_funcname); + Py_DECREF(old_code_object); + } + if (unlikely(!code_object)) goto bad; + getframe = PySys_GetObject("_getframe"); + if (unlikely(!getframe)) goto bad; + if (unlikely(PyDict_SetItemString(dict, "_getframe", getframe))) goto bad; + frame = PyEval_EvalCode(code_object, dict, dict); + if (unlikely(!frame) || frame == Py_None) goto bad; + success = 1; + bad: + PyErr_Restore(exc_type, exc_value, exc_traceback); + Py_XDECREF(code_object); + Py_XDECREF(py_py_line); + Py_XDECREF(py_funcname); + Py_XDECREF(dict); + Py_XDECREF(replace); + if (success) { + PyTraceBack_Here( + (struct _frame*)frame); + } + Py_XDECREF(frame); +} +#else +static PyCodeObject* __Pyx_CreateCodeObjectForTraceback( + const char *funcname, int c_line, + int py_line, const char *filename) { + PyCodeObject *py_code = NULL; + PyObject *py_funcname = NULL; + #if PY_MAJOR_VERSION < 3 + PyObject *py_srcfile = NULL; + py_srcfile = PyString_FromString(filename); + if (!py_srcfile) goto bad; + #endif + if (c_line) { + #if PY_MAJOR_VERSION < 3 + py_funcname = PyString_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); + if (!py_funcname) goto bad; + #else + py_funcname = PyUnicode_FromFormat( "%s (%s:%d)", funcname, __pyx_cfilenm, c_line); + if (!py_funcname) goto bad; + funcname = PyUnicode_AsUTF8(py_funcname); + if (!funcname) goto bad; + #endif + } + else { + #if PY_MAJOR_VERSION < 3 + py_funcname = PyString_FromString(funcname); + if (!py_funcname) goto bad; + #endif + } + #if PY_MAJOR_VERSION < 3 + py_code = __Pyx_PyCode_New( + 0, + 0, + 0, + 0, + 0, + 0, + __pyx_empty_bytes, /*PyObject *code,*/ + __pyx_empty_tuple, /*PyObject *consts,*/ + __pyx_empty_tuple, /*PyObject *names,*/ + __pyx_empty_tuple, /*PyObject *varnames,*/ + __pyx_empty_tuple, /*PyObject *freevars,*/ + __pyx_empty_tuple, /*PyObject *cellvars,*/ + py_srcfile, /*PyObject *filename,*/ + py_funcname, /*PyObject *name,*/ + py_line, + __pyx_empty_bytes /*PyObject *lnotab*/ + ); + Py_DECREF(py_srcfile); + #else + py_code = PyCode_NewEmpty(filename, funcname, py_line); + #endif + Py_XDECREF(py_funcname); + return py_code; +bad: + Py_XDECREF(py_funcname); + #if PY_MAJOR_VERSION < 3 + Py_XDECREF(py_srcfile); + #endif + return NULL; +} +static void __Pyx_AddTraceback(const char *funcname, int c_line, + int py_line, const char *filename) { + PyCodeObject *py_code = 0; + PyFrameObject *py_frame = 0; + PyThreadState *tstate = __Pyx_PyThreadState_Current; + PyObject *ptype, *pvalue, *ptraceback; + if (c_line) { + c_line = __Pyx_CLineForTraceback(tstate, c_line); + } + py_code = __pyx_find_code_object(c_line ? -c_line : py_line); + if (!py_code) { + __Pyx_ErrFetchInState(tstate, &ptype, &pvalue, &ptraceback); + py_code = __Pyx_CreateCodeObjectForTraceback( + funcname, c_line, py_line, filename); + if (!py_code) { + /* If the code object creation fails, then we should clear the + fetched exception references and propagate the new exception */ + Py_XDECREF(ptype); + Py_XDECREF(pvalue); + Py_XDECREF(ptraceback); + goto bad; + } + __Pyx_ErrRestoreInState(tstate, ptype, pvalue, ptraceback); + __pyx_insert_code_object(c_line ? -c_line : py_line, py_code); + } + py_frame = PyFrame_New( + tstate, /*PyThreadState *tstate,*/ + py_code, /*PyCodeObject *code,*/ + __pyx_d, /*PyObject *globals,*/ + 0 /*PyObject *locals*/ + ); + if (!py_frame) goto bad; + __Pyx_PyFrame_SetLineNumber(py_frame, py_line); + PyTraceBack_Here(py_frame); +bad: + Py_XDECREF(py_code); + Py_XDECREF(py_frame); +} +#endif + +#if PY_MAJOR_VERSION < 3 +static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) { + __Pyx_TypeName obj_type_name; + if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags); + if (__Pyx_TypeCheck(obj, __pyx_array_type)) return __pyx_array_getbuffer(obj, view, flags); + if (__Pyx_TypeCheck(obj, __pyx_memoryview_type)) return __pyx_memoryview_getbuffer(obj, view, flags); + obj_type_name = __Pyx_PyType_GetName(Py_TYPE(obj)); + PyErr_Format(PyExc_TypeError, + "'" __Pyx_FMT_TYPENAME "' does not have the buffer interface", + obj_type_name); + __Pyx_DECREF_TypeName(obj_type_name); + return -1; +} +static void __Pyx_ReleaseBuffer(Py_buffer *view) { + PyObject *obj = view->obj; + if (!obj) return; + if (PyObject_CheckBuffer(obj)) { + PyBuffer_Release(view); + return; + } + if ((0)) {} + view->obj = NULL; + Py_DECREF(obj); +} +#endif + + +/* MemviewSliceIsContig */ +static int +__pyx_memviewslice_is_contig(const __Pyx_memviewslice mvs, char order, int ndim) +{ + int i, index, step, start; + Py_ssize_t itemsize = mvs.memview->view.itemsize; + if (order == 'F') { + step = 1; + start = 0; + } else { + step = -1; + start = ndim - 1; + } + for (i = 0; i < ndim; i++) { + index = start + step * i; + if (mvs.suboffsets[index] >= 0 || mvs.strides[index] != itemsize) + return 0; + itemsize *= mvs.shape[index]; + } + return 1; +} + +/* OverlappingSlices */ +static void +__pyx_get_array_memory_extents(__Pyx_memviewslice *slice, + void **out_start, void **out_end, + int ndim, size_t itemsize) +{ + char *start, *end; + int i; + start = end = slice->data; + for (i = 0; i < ndim; i++) { + Py_ssize_t stride = slice->strides[i]; + Py_ssize_t extent = slice->shape[i]; + if (extent == 0) { + *out_start = *out_end = start; + return; + } else { + if (stride > 0) + end += stride * (extent - 1); + else + start += stride * (extent - 1); + } + } + *out_start = start; + *out_end = end + itemsize; +} +static int +__pyx_slices_overlap(__Pyx_memviewslice *slice1, + __Pyx_memviewslice *slice2, + int ndim, size_t itemsize) +{ + void *start1, *end1, *start2, *end2; + __pyx_get_array_memory_extents(slice1, &start1, &end1, ndim, itemsize); + __pyx_get_array_memory_extents(slice2, &start2, &end2, ndim, itemsize); + return (start1 < end2) && (start2 < end1); +} + +/* IsLittleEndian */ +static CYTHON_INLINE int __Pyx_Is_Little_Endian(void) +{ + union { + uint32_t u32; + uint8_t u8[4]; + } S; + S.u32 = 0x01020304; + return S.u8[0] == 4; +} + +/* BufferFormatCheck */ +static void __Pyx_BufFmt_Init(__Pyx_BufFmt_Context* ctx, + __Pyx_BufFmt_StackElem* stack, + __Pyx_TypeInfo* type) { + stack[0].field = &ctx->root; + stack[0].parent_offset = 0; + ctx->root.type = type; + ctx->root.name = "buffer dtype"; + ctx->root.offset = 0; + ctx->head = stack; + ctx->head->field = &ctx->root; + ctx->fmt_offset = 0; + ctx->head->parent_offset = 0; + ctx->new_packmode = '@'; + ctx->enc_packmode = '@'; + ctx->new_count = 1; + ctx->enc_count = 0; + ctx->enc_type = 0; + ctx->is_complex = 0; + ctx->is_valid_array = 0; + ctx->struct_alignment = 0; + while (type->typegroup == 'S') { + ++ctx->head; + ctx->head->field = type->fields; + ctx->head->parent_offset = 0; + type = type->fields->type; + } +} +static int __Pyx_BufFmt_ParseNumber(const char** ts) { + int count; + const char* t = *ts; + if (*t < '0' || *t > '9') { + return -1; + } else { + count = *t++ - '0'; + while (*t >= '0' && *t <= '9') { + count *= 10; + count += *t++ - '0'; + } + } + *ts = t; + return count; +} +static int __Pyx_BufFmt_ExpectNumber(const char **ts) { + int number = __Pyx_BufFmt_ParseNumber(ts); + if (number == -1) + PyErr_Format(PyExc_ValueError,\ + "Does not understand character buffer dtype format string ('%c')", **ts); + return number; +} +static void __Pyx_BufFmt_RaiseUnexpectedChar(char ch) { + PyErr_Format(PyExc_ValueError, + "Unexpected format string character: '%c'", ch); +} +static const char* __Pyx_BufFmt_DescribeTypeChar(char ch, int is_complex) { + switch (ch) { + case '?': return "'bool'"; + case 'c': return "'char'"; + case 'b': return "'signed char'"; + case 'B': return "'unsigned char'"; + case 'h': return "'short'"; + case 'H': return "'unsigned short'"; + case 'i': return "'int'"; + case 'I': return "'unsigned int'"; + case 'l': return "'long'"; + case 'L': return "'unsigned long'"; + case 'q': return "'long long'"; + case 'Q': return "'unsigned long long'"; + case 'f': return (is_complex ? "'complex float'" : "'float'"); + case 'd': return (is_complex ? "'complex double'" : "'double'"); + case 'g': return (is_complex ? "'complex long double'" : "'long double'"); + case 'T': return "a struct"; + case 'O': return "Python object"; + case 'P': return "a pointer"; + case 's': case 'p': return "a string"; + case 0: return "end"; + default: return "unparsable format string"; + } +} +static size_t __Pyx_BufFmt_TypeCharToStandardSize(char ch, int is_complex) { + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return 2; + case 'i': case 'I': case 'l': case 'L': return 4; + case 'q': case 'Q': return 8; + case 'f': return (is_complex ? 8 : 4); + case 'd': return (is_complex ? 16 : 8); + case 'g': { + PyErr_SetString(PyExc_ValueError, "Python does not define a standard format string size for long double ('g').."); + return 0; + } + case 'O': case 'P': return sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +static size_t __Pyx_BufFmt_TypeCharToNativeSize(char ch, int is_complex) { + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(short); + case 'i': case 'I': return sizeof(int); + case 'l': case 'L': return sizeof(long); + #ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(PY_LONG_LONG); + #endif + case 'f': return sizeof(float) * (is_complex ? 2 : 1); + case 'd': return sizeof(double) * (is_complex ? 2 : 1); + case 'g': return sizeof(long double) * (is_complex ? 2 : 1); + case 'O': case 'P': return sizeof(void*); + default: { + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } + } +} +typedef struct { char c; short x; } __Pyx_st_short; +typedef struct { char c; int x; } __Pyx_st_int; +typedef struct { char c; long x; } __Pyx_st_long; +typedef struct { char c; float x; } __Pyx_st_float; +typedef struct { char c; double x; } __Pyx_st_double; +typedef struct { char c; long double x; } __Pyx_st_longdouble; +typedef struct { char c; void *x; } __Pyx_st_void_p; +#ifdef HAVE_LONG_LONG +typedef struct { char c; PY_LONG_LONG x; } __Pyx_st_longlong; +#endif +static size_t __Pyx_BufFmt_TypeCharToAlignment(char ch, int is_complex) { + CYTHON_UNUSED_VAR(is_complex); + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(__Pyx_st_short) - sizeof(short); + case 'i': case 'I': return sizeof(__Pyx_st_int) - sizeof(int); + case 'l': case 'L': return sizeof(__Pyx_st_long) - sizeof(long); +#ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(__Pyx_st_longlong) - sizeof(PY_LONG_LONG); +#endif + case 'f': return sizeof(__Pyx_st_float) - sizeof(float); + case 'd': return sizeof(__Pyx_st_double) - sizeof(double); + case 'g': return sizeof(__Pyx_st_longdouble) - sizeof(long double); + case 'P': case 'O': return sizeof(__Pyx_st_void_p) - sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +/* These are for computing the padding at the end of the struct to align + on the first member of the struct. This will probably the same as above, + but we don't have any guarantees. + */ +typedef struct { short x; char c; } __Pyx_pad_short; +typedef struct { int x; char c; } __Pyx_pad_int; +typedef struct { long x; char c; } __Pyx_pad_long; +typedef struct { float x; char c; } __Pyx_pad_float; +typedef struct { double x; char c; } __Pyx_pad_double; +typedef struct { long double x; char c; } __Pyx_pad_longdouble; +typedef struct { void *x; char c; } __Pyx_pad_void_p; +#ifdef HAVE_LONG_LONG +typedef struct { PY_LONG_LONG x; char c; } __Pyx_pad_longlong; +#endif +static size_t __Pyx_BufFmt_TypeCharToPadding(char ch, int is_complex) { + CYTHON_UNUSED_VAR(is_complex); + switch (ch) { + case '?': case 'c': case 'b': case 'B': case 's': case 'p': return 1; + case 'h': case 'H': return sizeof(__Pyx_pad_short) - sizeof(short); + case 'i': case 'I': return sizeof(__Pyx_pad_int) - sizeof(int); + case 'l': case 'L': return sizeof(__Pyx_pad_long) - sizeof(long); +#ifdef HAVE_LONG_LONG + case 'q': case 'Q': return sizeof(__Pyx_pad_longlong) - sizeof(PY_LONG_LONG); +#endif + case 'f': return sizeof(__Pyx_pad_float) - sizeof(float); + case 'd': return sizeof(__Pyx_pad_double) - sizeof(double); + case 'g': return sizeof(__Pyx_pad_longdouble) - sizeof(long double); + case 'P': case 'O': return sizeof(__Pyx_pad_void_p) - sizeof(void*); + default: + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } +} +static char __Pyx_BufFmt_TypeCharToGroup(char ch, int is_complex) { + switch (ch) { + case 'c': + return 'H'; + case 'b': case 'h': case 'i': + case 'l': case 'q': case 's': case 'p': + return 'I'; + case '?': case 'B': case 'H': case 'I': case 'L': case 'Q': + return 'U'; + case 'f': case 'd': case 'g': + return (is_complex ? 'C' : 'R'); + case 'O': + return 'O'; + case 'P': + return 'P'; + default: { + __Pyx_BufFmt_RaiseUnexpectedChar(ch); + return 0; + } + } +} +static void __Pyx_BufFmt_RaiseExpected(__Pyx_BufFmt_Context* ctx) { + if (ctx->head == NULL || ctx->head->field == &ctx->root) { + const char* expected; + const char* quote; + if (ctx->head == NULL) { + expected = "end"; + quote = ""; + } else { + expected = ctx->head->field->type->name; + quote = "'"; + } + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch, expected %s%s%s but got %s", + quote, expected, quote, + __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex)); + } else { + __Pyx_StructField* field = ctx->head->field; + __Pyx_StructField* parent = (ctx->head - 1)->field; + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch, expected '%s' but got %s in '%s.%s'", + field->type->name, __Pyx_BufFmt_DescribeTypeChar(ctx->enc_type, ctx->is_complex), + parent->type->name, field->name); + } +} +static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) { + char group; + size_t size, offset, arraysize = 1; + if (ctx->enc_type == 0) return 0; + if (ctx->head->field->type->arraysize[0]) { + int i, ndim = 0; + if (ctx->enc_type == 's' || ctx->enc_type == 'p') { + ctx->is_valid_array = ctx->head->field->type->ndim == 1; + ndim = 1; + if (ctx->enc_count != ctx->head->field->type->arraysize[0]) { + PyErr_Format(PyExc_ValueError, + "Expected a dimension of size %zu, got %zu", + ctx->head->field->type->arraysize[0], ctx->enc_count); + return -1; + } + } + if (!ctx->is_valid_array) { + PyErr_Format(PyExc_ValueError, "Expected %d dimensions, got %d", + ctx->head->field->type->ndim, ndim); + return -1; + } + for (i = 0; i < ctx->head->field->type->ndim; i++) { + arraysize *= ctx->head->field->type->arraysize[i]; + } + ctx->is_valid_array = 0; + ctx->enc_count = 1; + } + group = __Pyx_BufFmt_TypeCharToGroup(ctx->enc_type, ctx->is_complex); + do { + __Pyx_StructField* field = ctx->head->field; + __Pyx_TypeInfo* type = field->type; + if (ctx->enc_packmode == '@' || ctx->enc_packmode == '^') { + size = __Pyx_BufFmt_TypeCharToNativeSize(ctx->enc_type, ctx->is_complex); + } else { + size = __Pyx_BufFmt_TypeCharToStandardSize(ctx->enc_type, ctx->is_complex); + } + if (ctx->enc_packmode == '@') { + size_t align_at = __Pyx_BufFmt_TypeCharToAlignment(ctx->enc_type, ctx->is_complex); + size_t align_mod_offset; + if (align_at == 0) return -1; + align_mod_offset = ctx->fmt_offset % align_at; + if (align_mod_offset > 0) ctx->fmt_offset += align_at - align_mod_offset; + if (ctx->struct_alignment == 0) + ctx->struct_alignment = __Pyx_BufFmt_TypeCharToPadding(ctx->enc_type, + ctx->is_complex); + } + if (type->size != size || type->typegroup != group) { + if (type->typegroup == 'C' && type->fields != NULL) { + size_t parent_offset = ctx->head->parent_offset + field->offset; + ++ctx->head; + ctx->head->field = type->fields; + ctx->head->parent_offset = parent_offset; + continue; + } + if ((type->typegroup == 'H' || group == 'H') && type->size == size) { + } else { + __Pyx_BufFmt_RaiseExpected(ctx); + return -1; + } + } + offset = ctx->head->parent_offset + field->offset; + if (ctx->fmt_offset != offset) { + PyErr_Format(PyExc_ValueError, + "Buffer dtype mismatch; next field is at offset %" CYTHON_FORMAT_SSIZE_T "d but %" CYTHON_FORMAT_SSIZE_T "d expected", + (Py_ssize_t)ctx->fmt_offset, (Py_ssize_t)offset); + return -1; + } + ctx->fmt_offset += size; + if (arraysize) + ctx->fmt_offset += (arraysize - 1) * size; + --ctx->enc_count; + while (1) { + if (field == &ctx->root) { + ctx->head = NULL; + if (ctx->enc_count != 0) { + __Pyx_BufFmt_RaiseExpected(ctx); + return -1; + } + break; + } + ctx->head->field = ++field; + if (field->type == NULL) { + --ctx->head; + field = ctx->head->field; + continue; + } else if (field->type->typegroup == 'S') { + size_t parent_offset = ctx->head->parent_offset + field->offset; + if (field->type->fields->type == NULL) continue; + field = field->type->fields; + ++ctx->head; + ctx->head->field = field; + ctx->head->parent_offset = parent_offset; + break; + } else { + break; + } + } + } while (ctx->enc_count); + ctx->enc_type = 0; + ctx->is_complex = 0; + return 0; +} +static int +__pyx_buffmt_parse_array(__Pyx_BufFmt_Context* ctx, const char** tsp) +{ + const char *ts = *tsp; + int i = 0, number, ndim; + ++ts; + if (ctx->new_count != 1) { + PyErr_SetString(PyExc_ValueError, + "Cannot handle repeated arrays in format string"); + return -1; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return -1; + ndim = ctx->head->field->type->ndim; + while (*ts && *ts != ')') { + switch (*ts) { + case ' ': case '\f': case '\r': case '\n': case '\t': case '\v': continue; + default: break; + } + number = __Pyx_BufFmt_ExpectNumber(&ts); + if (number == -1) return -1; + if (i < ndim && (size_t) number != ctx->head->field->type->arraysize[i]) { + PyErr_Format(PyExc_ValueError, + "Expected a dimension of size %zu, got %d", + ctx->head->field->type->arraysize[i], number); + return -1; + } + if (*ts != ',' && *ts != ')') { + PyErr_Format(PyExc_ValueError, + "Expected a comma in format string, got '%c'", *ts); + return -1; + } + if (*ts == ',') ts++; + i++; + } + if (i != ndim) { + PyErr_Format(PyExc_ValueError, "Expected %d dimension(s), got %d", + ctx->head->field->type->ndim, i); + return -1; + } + if (!*ts) { + PyErr_SetString(PyExc_ValueError, + "Unexpected end of format string, expected ')'"); + return -1; + } + ctx->is_valid_array = 1; + ctx->new_count = 1; + *tsp = ++ts; + return 0; +} +static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts) { + int got_Z = 0; + while (1) { + switch(*ts) { + case 0: + if (ctx->enc_type != 0 && ctx->head == NULL) { + __Pyx_BufFmt_RaiseExpected(ctx); + return NULL; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + if (ctx->head != NULL) { + __Pyx_BufFmt_RaiseExpected(ctx); + return NULL; + } + return ts; + case ' ': + case '\r': + case '\n': + ++ts; + break; + case '<': + if (!__Pyx_Is_Little_Endian()) { + PyErr_SetString(PyExc_ValueError, "Little-endian buffer not supported on big-endian compiler"); + return NULL; + } + ctx->new_packmode = '='; + ++ts; + break; + case '>': + case '!': + if (__Pyx_Is_Little_Endian()) { + PyErr_SetString(PyExc_ValueError, "Big-endian buffer not supported on little-endian compiler"); + return NULL; + } + ctx->new_packmode = '='; + ++ts; + break; + case '=': + case '@': + case '^': + ctx->new_packmode = *ts++; + break; + case 'T': + { + const char* ts_after_sub; + size_t i, struct_count = ctx->new_count; + size_t struct_alignment = ctx->struct_alignment; + ctx->new_count = 1; + ++ts; + if (*ts != '{') { + PyErr_SetString(PyExc_ValueError, "Buffer acquisition: Expected '{' after 'T'"); + return NULL; + } + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_type = 0; + ctx->enc_count = 0; + ctx->struct_alignment = 0; + ++ts; + ts_after_sub = ts; + for (i = 0; i != struct_count; ++i) { + ts_after_sub = __Pyx_BufFmt_CheckString(ctx, ts); + if (!ts_after_sub) return NULL; + } + ts = ts_after_sub; + if (struct_alignment) ctx->struct_alignment = struct_alignment; + } + break; + case '}': + { + size_t alignment = ctx->struct_alignment; + ++ts; + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_type = 0; + if (alignment && ctx->fmt_offset % alignment) { + ctx->fmt_offset += alignment - (ctx->fmt_offset % alignment); + } + } + return ts; + case 'x': + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->fmt_offset += ctx->new_count; + ctx->new_count = 1; + ctx->enc_count = 0; + ctx->enc_type = 0; + ctx->enc_packmode = ctx->new_packmode; + ++ts; + break; + case 'Z': + got_Z = 1; + ++ts; + if (*ts != 'f' && *ts != 'd' && *ts != 'g') { + __Pyx_BufFmt_RaiseUnexpectedChar('Z'); + return NULL; + } + CYTHON_FALLTHROUGH; + case '?': case 'c': case 'b': case 'B': case 'h': case 'H': case 'i': case 'I': + case 'l': case 'L': case 'q': case 'Q': + case 'f': case 'd': case 'g': + case 'O': case 'p': + if ((ctx->enc_type == *ts) && (got_Z == ctx->is_complex) && + (ctx->enc_packmode == ctx->new_packmode) && (!ctx->is_valid_array)) { + ctx->enc_count += ctx->new_count; + ctx->new_count = 1; + got_Z = 0; + ++ts; + break; + } + CYTHON_FALLTHROUGH; + case 's': + if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL; + ctx->enc_count = ctx->new_count; + ctx->enc_packmode = ctx->new_packmode; + ctx->enc_type = *ts; + ctx->is_complex = got_Z; + ++ts; + ctx->new_count = 1; + got_Z = 0; + break; + case ':': + ++ts; + while(*ts != ':') ++ts; + ++ts; + break; + case '(': + if (__pyx_buffmt_parse_array(ctx, &ts) < 0) return NULL; + break; + default: + { + int number = __Pyx_BufFmt_ExpectNumber(&ts); + if (number == -1) return NULL; + ctx->new_count = (size_t)number; + } + } + } +} + +/* TypeInfoCompare */ + static int +__pyx_typeinfo_cmp(__Pyx_TypeInfo *a, __Pyx_TypeInfo *b) +{ + int i; + if (!a || !b) + return 0; + if (a == b) + return 1; + if (a->size != b->size || a->typegroup != b->typegroup || + a->is_unsigned != b->is_unsigned || a->ndim != b->ndim) { + if (a->typegroup == 'H' || b->typegroup == 'H') { + return a->size == b->size; + } else { + return 0; + } + } + if (a->ndim) { + for (i = 0; i < a->ndim; i++) + if (a->arraysize[i] != b->arraysize[i]) + return 0; + } + if (a->typegroup == 'S') { + if (a->flags != b->flags) + return 0; + if (a->fields || b->fields) { + if (!(a->fields && b->fields)) + return 0; + for (i = 0; a->fields[i].type && b->fields[i].type; i++) { + __Pyx_StructField *field_a = a->fields + i; + __Pyx_StructField *field_b = b->fields + i; + if (field_a->offset != field_b->offset || + !__pyx_typeinfo_cmp(field_a->type, field_b->type)) + return 0; + } + return !a->fields[i].type && !b->fields[i].type; + } + } + return 1; +} + +/* MemviewSliceValidateAndInit */ + static int +__pyx_check_strides(Py_buffer *buf, int dim, int ndim, int spec) +{ + if (buf->shape[dim] <= 1) + return 1; + if (buf->strides) { + if (spec & __Pyx_MEMVIEW_CONTIG) { + if (spec & (__Pyx_MEMVIEW_PTR|__Pyx_MEMVIEW_FULL)) { + if (unlikely(buf->strides[dim] != sizeof(void *))) { + PyErr_Format(PyExc_ValueError, + "Buffer is not indirectly contiguous " + "in dimension %d.", dim); + goto fail; + } + } else if (unlikely(buf->strides[dim] != buf->itemsize)) { + PyErr_SetString(PyExc_ValueError, + "Buffer and memoryview are not contiguous " + "in the same dimension."); + goto fail; + } + } + if (spec & __Pyx_MEMVIEW_FOLLOW) { + Py_ssize_t stride = buf->strides[dim]; + if (stride < 0) + stride = -stride; + if (unlikely(stride < buf->itemsize)) { + PyErr_SetString(PyExc_ValueError, + "Buffer and memoryview are not contiguous " + "in the same dimension."); + goto fail; + } + } + } else { + if (unlikely(spec & __Pyx_MEMVIEW_CONTIG && dim != ndim - 1)) { + PyErr_Format(PyExc_ValueError, + "C-contiguous buffer is not contiguous in " + "dimension %d", dim); + goto fail; + } else if (unlikely(spec & (__Pyx_MEMVIEW_PTR))) { + PyErr_Format(PyExc_ValueError, + "C-contiguous buffer is not indirect in " + "dimension %d", dim); + goto fail; + } else if (unlikely(buf->suboffsets)) { + PyErr_SetString(PyExc_ValueError, + "Buffer exposes suboffsets but no strides"); + goto fail; + } + } + return 1; +fail: + return 0; +} +static int +__pyx_check_suboffsets(Py_buffer *buf, int dim, int ndim, int spec) +{ + CYTHON_UNUSED_VAR(ndim); + if (spec & __Pyx_MEMVIEW_DIRECT) { + if (unlikely(buf->suboffsets && buf->suboffsets[dim] >= 0)) { + PyErr_Format(PyExc_ValueError, + "Buffer not compatible with direct access " + "in dimension %d.", dim); + goto fail; + } + } + if (spec & __Pyx_MEMVIEW_PTR) { + if (unlikely(!buf->suboffsets || (buf->suboffsets[dim] < 0))) { + PyErr_Format(PyExc_ValueError, + "Buffer is not indirectly accessible " + "in dimension %d.", dim); + goto fail; + } + } + return 1; +fail: + return 0; +} +static int +__pyx_verify_contig(Py_buffer *buf, int ndim, int c_or_f_flag) +{ + int i; + if (c_or_f_flag & __Pyx_IS_F_CONTIG) { + Py_ssize_t stride = 1; + for (i = 0; i < ndim; i++) { + if (unlikely(stride * buf->itemsize != buf->strides[i] && buf->shape[i] > 1)) { + PyErr_SetString(PyExc_ValueError, + "Buffer not fortran contiguous."); + goto fail; + } + stride = stride * buf->shape[i]; + } + } else if (c_or_f_flag & __Pyx_IS_C_CONTIG) { + Py_ssize_t stride = 1; + for (i = ndim - 1; i >- 1; i--) { + if (unlikely(stride * buf->itemsize != buf->strides[i] && buf->shape[i] > 1)) { + PyErr_SetString(PyExc_ValueError, + "Buffer not C contiguous."); + goto fail; + } + stride = stride * buf->shape[i]; + } + } + return 1; +fail: + return 0; +} +static int __Pyx_ValidateAndInit_memviewslice( + int *axes_specs, + int c_or_f_flag, + int buf_flags, + int ndim, + __Pyx_TypeInfo *dtype, + __Pyx_BufFmt_StackElem stack[], + __Pyx_memviewslice *memviewslice, + PyObject *original_obj) +{ + struct __pyx_memoryview_obj *memview, *new_memview; + __Pyx_RefNannyDeclarations + Py_buffer *buf; + int i, spec = 0, retval = -1; + __Pyx_BufFmt_Context ctx; + int from_memoryview = __pyx_memoryview_check(original_obj); + __Pyx_RefNannySetupContext("ValidateAndInit_memviewslice", 0); + if (from_memoryview && __pyx_typeinfo_cmp(dtype, ((struct __pyx_memoryview_obj *) + original_obj)->typeinfo)) { + memview = (struct __pyx_memoryview_obj *) original_obj; + new_memview = NULL; + } else { + memview = (struct __pyx_memoryview_obj *) __pyx_memoryview_new( + original_obj, buf_flags, 0, dtype); + new_memview = memview; + if (unlikely(!memview)) + goto fail; + } + buf = &memview->view; + if (unlikely(buf->ndim != ndim)) { + PyErr_Format(PyExc_ValueError, + "Buffer has wrong number of dimensions (expected %d, got %d)", + ndim, buf->ndim); + goto fail; + } + if (new_memview) { + __Pyx_BufFmt_Init(&ctx, stack, dtype); + if (unlikely(!__Pyx_BufFmt_CheckString(&ctx, buf->format))) goto fail; + } + if (unlikely((unsigned) buf->itemsize != dtype->size)) { + PyErr_Format(PyExc_ValueError, + "Item size of buffer (%" CYTHON_FORMAT_SSIZE_T "u byte%s) " + "does not match size of '%s' (%" CYTHON_FORMAT_SSIZE_T "u byte%s)", + buf->itemsize, + (buf->itemsize > 1) ? "s" : "", + dtype->name, + dtype->size, + (dtype->size > 1) ? "s" : ""); + goto fail; + } + if (buf->len > 0) { + for (i = 0; i < ndim; i++) { + spec = axes_specs[i]; + if (unlikely(!__pyx_check_strides(buf, i, ndim, spec))) + goto fail; + if (unlikely(!__pyx_check_suboffsets(buf, i, ndim, spec))) + goto fail; + } + if (unlikely(buf->strides && !__pyx_verify_contig(buf, ndim, c_or_f_flag))) + goto fail; + } + if (unlikely(__Pyx_init_memviewslice(memview, ndim, memviewslice, + new_memview != NULL) == -1)) { + goto fail; + } + retval = 0; + goto no_fail; +fail: + Py_XDECREF(new_memview); + retval = -1; +no_fail: + __Pyx_RefNannyFinishContext(); + return retval; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_int(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_FOLLOW), (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_FOLLOW), (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_CONTIG) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, __Pyx_IS_C_CONTIG, + (PyBUF_C_CONTIGUOUS | PyBUF_FORMAT) | writable_flag, 3, + &__Pyx_TypeInfo_int, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_d_d_dc_float(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_FOLLOW), (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_FOLLOW), (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_CONTIG) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, __Pyx_IS_C_CONTIG, + (PyBUF_C_CONTIGUOUS | PyBUF_FORMAT) | writable_flag, 3, + &__Pyx_TypeInfo_float, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* ObjectToMemviewSlice */ + static CYTHON_INLINE __Pyx_memviewslice __Pyx_PyObject_to_MemoryviewSlice_dc_int(PyObject *obj, int writable_flag) { + __Pyx_memviewslice result = { 0, 0, { 0 }, { 0 }, { 0 } }; + __Pyx_BufFmt_StackElem stack[1]; + int axes_specs[] = { (__Pyx_MEMVIEW_DIRECT | __Pyx_MEMVIEW_CONTIG) }; + int retcode; + if (obj == Py_None) { + result.memview = (struct __pyx_memoryview_obj *) Py_None; + return result; + } + retcode = __Pyx_ValidateAndInit_memviewslice(axes_specs, __Pyx_IS_C_CONTIG, + (PyBUF_C_CONTIGUOUS | PyBUF_FORMAT) | writable_flag, 1, + &__Pyx_TypeInfo_int, stack, + &result, obj); + if (unlikely(retcode == -1)) + goto __pyx_fail; + return result; +__pyx_fail: + result.memview = NULL; + result.data = NULL; + return result; +} + +/* CIntFromPyVerify */ + #define __PYX_VERIFY_RETURN_INT(target_type, func_type, func_value)\ + __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 0) +#define __PYX_VERIFY_RETURN_INT_EXC(target_type, func_type, func_value)\ + __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, 1) +#define __PYX__VERIFY_RETURN_INT(target_type, func_type, func_value, exc)\ + {\ + func_type value = func_value;\ + if (sizeof(target_type) < sizeof(func_type)) {\ + if (unlikely(value != (func_type) (target_type) value)) {\ + func_type zero = 0;\ + if (exc && unlikely(value == (func_type)-1 && PyErr_Occurred()))\ + return (target_type) -1;\ + if (is_unsigned && unlikely(value < zero))\ + goto raise_neg_overflow;\ + else\ + goto raise_overflow;\ + }\ + }\ + return (target_type) value;\ + } + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + return ::std::complex< float >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + return x + y*(__pyx_t_float_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_float_complex __pyx_t_float_complex_from_parts(float x, float y) { + __pyx_t_float_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_sum_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_diff_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_prod_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + if (b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabsf(b.real) >= fabsf(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + float r = b.imag / b.real; + float s = (float)(1.0) / (b.real + b.imag * r); + return __pyx_t_float_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + float r = b.real / b.imag; + float s = (float)(1.0) / (b.imag + b.real * r); + return __pyx_t_float_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_quot_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + if (b.imag == 0) { + return __pyx_t_float_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + float denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_float_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_neg_float(__pyx_t_float_complex a) { + __pyx_t_float_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_float(__pyx_t_float_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_conj_float(__pyx_t_float_complex a) { + __pyx_t_float_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE float __Pyx_c_abs_float(__pyx_t_float_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrtf(z.real*z.real + z.imag*z.imag); + #else + return hypotf(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_float_complex __Pyx_c_pow_float(__pyx_t_float_complex a, __pyx_t_float_complex b) { + __pyx_t_float_complex z; + float r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + float denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_float(a, a); + case 3: + z = __Pyx_c_prod_float(a, a); + return __Pyx_c_prod_float(z, a); + case 4: + z = __Pyx_c_prod_float(a, a); + return __Pyx_c_prod_float(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = powf(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2f(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_float(a); + theta = atan2f(a.imag, a.real); + } + lnr = logf(r); + z_r = expf(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cosf(z_theta); + z.imag = z_r * sinf(z_theta); + return z; + } + #endif +#endif + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + return ::std::complex< double >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + return x + y*(__pyx_t_double_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_double_complex __pyx_t_double_complex_from_parts(double x, double y) { + __pyx_t_double_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_sum_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_diff_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_prod_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + if (b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabs(b.real) >= fabs(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + double r = b.imag / b.real; + double s = (double)(1.0) / (b.real + b.imag * r); + return __pyx_t_double_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + double r = b.real / b.imag; + double s = (double)(1.0) / (b.imag + b.real * r); + return __pyx_t_double_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_quot_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + if (b.imag == 0) { + return __pyx_t_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + double denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_double_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_neg_double(__pyx_t_double_complex a) { + __pyx_t_double_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_double(__pyx_t_double_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_conj_double(__pyx_t_double_complex a) { + __pyx_t_double_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE double __Pyx_c_abs_double(__pyx_t_double_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrt(z.real*z.real + z.imag*z.imag); + #else + return hypot(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_double_complex __Pyx_c_pow_double(__pyx_t_double_complex a, __pyx_t_double_complex b) { + __pyx_t_double_complex z; + double r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + double denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_double(a, a); + case 3: + z = __Pyx_c_prod_double(a, a); + return __Pyx_c_prod_double(z, a); + case 4: + z = __Pyx_c_prod_double(a, a); + return __Pyx_c_prod_double(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = pow(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_double(a); + theta = atan2(a.imag, a.real); + } + lnr = log(r); + z_r = exp(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cos(z_theta); + z.imag = z_r * sin(z_theta); + return z; + } + #endif +#endif + +/* Declarations */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) + #ifdef __cplusplus + static CYTHON_INLINE __pyx_t_long_double_complex __pyx_t_long_double_complex_from_parts(long double x, long double y) { + return ::std::complex< long double >(x, y); + } + #else + static CYTHON_INLINE __pyx_t_long_double_complex __pyx_t_long_double_complex_from_parts(long double x, long double y) { + return x + y*(__pyx_t_long_double_complex)_Complex_I; + } + #endif +#else + static CYTHON_INLINE __pyx_t_long_double_complex __pyx_t_long_double_complex_from_parts(long double x, long double y) { + __pyx_t_long_double_complex z; + z.real = x; + z.imag = y; + return z; + } +#endif + +/* Arithmetic */ + #if CYTHON_CCOMPLEX && (1) && (!0 || __cplusplus) +#else + static CYTHON_INLINE int __Pyx_c_eq_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + return (a.real == b.real) && (a.imag == b.imag); + } + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_sum_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + __pyx_t_long_double_complex z; + z.real = a.real + b.real; + z.imag = a.imag + b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_diff_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + __pyx_t_long_double_complex z; + z.real = a.real - b.real; + z.imag = a.imag - b.imag; + return z; + } + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_prod_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + __pyx_t_long_double_complex z; + z.real = a.real * b.real - a.imag * b.imag; + z.imag = a.real * b.imag + a.imag * b.real; + return z; + } + #if 1 + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_quot_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + if (b.imag == 0) { + return __pyx_t_long_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else if (fabsl(b.real) >= fabsl(b.imag)) { + if (b.real == 0 && b.imag == 0) { + return __pyx_t_long_double_complex_from_parts(a.real / b.real, a.imag / b.imag); + } else { + long double r = b.imag / b.real; + long double s = (long double)(1.0) / (b.real + b.imag * r); + return __pyx_t_long_double_complex_from_parts( + (a.real + a.imag * r) * s, (a.imag - a.real * r) * s); + } + } else { + long double r = b.real / b.imag; + long double s = (long double)(1.0) / (b.imag + b.real * r); + return __pyx_t_long_double_complex_from_parts( + (a.real * r + a.imag) * s, (a.imag * r - a.real) * s); + } + } + #else + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_quot_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + if (b.imag == 0) { + return __pyx_t_long_double_complex_from_parts(a.real / b.real, a.imag / b.real); + } else { + long double denom = b.real * b.real + b.imag * b.imag; + return __pyx_t_long_double_complex_from_parts( + (a.real * b.real + a.imag * b.imag) / denom, + (a.imag * b.real - a.real * b.imag) / denom); + } + } + #endif + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_neg_long__double(__pyx_t_long_double_complex a) { + __pyx_t_long_double_complex z; + z.real = -a.real; + z.imag = -a.imag; + return z; + } + static CYTHON_INLINE int __Pyx_c_is_zero_long__double(__pyx_t_long_double_complex a) { + return (a.real == 0) && (a.imag == 0); + } + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_conj_long__double(__pyx_t_long_double_complex a) { + __pyx_t_long_double_complex z; + z.real = a.real; + z.imag = -a.imag; + return z; + } + #if 1 + static CYTHON_INLINE long double __Pyx_c_abs_long__double(__pyx_t_long_double_complex z) { + #if !defined(HAVE_HYPOT) || defined(_MSC_VER) + return sqrtl(z.real*z.real + z.imag*z.imag); + #else + return hypotl(z.real, z.imag); + #endif + } + static CYTHON_INLINE __pyx_t_long_double_complex __Pyx_c_pow_long__double(__pyx_t_long_double_complex a, __pyx_t_long_double_complex b) { + __pyx_t_long_double_complex z; + long double r, lnr, theta, z_r, z_theta; + if (b.imag == 0 && b.real == (int)b.real) { + if (b.real < 0) { + long double denom = a.real * a.real + a.imag * a.imag; + a.real = a.real / denom; + a.imag = -a.imag / denom; + b.real = -b.real; + } + switch ((int)b.real) { + case 0: + z.real = 1; + z.imag = 0; + return z; + case 1: + return a; + case 2: + return __Pyx_c_prod_long__double(a, a); + case 3: + z = __Pyx_c_prod_long__double(a, a); + return __Pyx_c_prod_long__double(z, a); + case 4: + z = __Pyx_c_prod_long__double(a, a); + return __Pyx_c_prod_long__double(z, z); + } + } + if (a.imag == 0) { + if (a.real == 0) { + return a; + } else if ((b.imag == 0) && (a.real >= 0)) { + z.real = powl(a.real, b.real); + z.imag = 0; + return z; + } else if (a.real > 0) { + r = a.real; + theta = 0; + } else { + r = -a.real; + theta = atan2l(0.0, -1.0); + } + } else { + r = __Pyx_c_abs_long__double(a); + theta = atan2l(a.imag, a.real); + } + lnr = logl(r); + z_r = expl(lnr * b.real - theta * b.imag); + z_theta = theta * b.real + lnr * b.imag; + z.real = z_r * cosl(z_theta); + z.imag = z_r * sinl(z_theta); + return z; + } + #endif +#endif + +/* MemviewSliceCopyTemplate */ + static __Pyx_memviewslice +__pyx_memoryview_copy_new_contig(const __Pyx_memviewslice *from_mvs, + const char *mode, int ndim, + size_t sizeof_dtype, int contig_flag, + int dtype_is_object) +{ + __Pyx_RefNannyDeclarations + int i; + __Pyx_memviewslice new_mvs = { 0, 0, { 0 }, { 0 }, { 0 } }; + struct __pyx_memoryview_obj *from_memview = from_mvs->memview; + Py_buffer *buf = &from_memview->view; + PyObject *shape_tuple = NULL; + PyObject *temp_int = NULL; + struct __pyx_array_obj *array_obj = NULL; + struct __pyx_memoryview_obj *memview_obj = NULL; + __Pyx_RefNannySetupContext("__pyx_memoryview_copy_new_contig", 0); + for (i = 0; i < ndim; i++) { + if (unlikely(from_mvs->suboffsets[i] >= 0)) { + PyErr_Format(PyExc_ValueError, "Cannot copy memoryview slice with " + "indirect dimensions (axis %d)", i); + goto fail; + } + } + shape_tuple = PyTuple_New(ndim); + if (unlikely(!shape_tuple)) { + goto fail; + } + __Pyx_GOTREF(shape_tuple); + for(i = 0; i < ndim; i++) { + temp_int = PyInt_FromSsize_t(from_mvs->shape[i]); + if(unlikely(!temp_int)) { + goto fail; + } else { + PyTuple_SET_ITEM(shape_tuple, i, temp_int); + temp_int = NULL; + } + } + array_obj = __pyx_array_new(shape_tuple, sizeof_dtype, buf->format, (char *) mode, NULL); + if (unlikely(!array_obj)) { + goto fail; + } + __Pyx_GOTREF(array_obj); + memview_obj = (struct __pyx_memoryview_obj *) __pyx_memoryview_new( + (PyObject *) array_obj, contig_flag, + dtype_is_object, + from_mvs->memview->typeinfo); + if (unlikely(!memview_obj)) + goto fail; + if (unlikely(__Pyx_init_memviewslice(memview_obj, ndim, &new_mvs, 1) < 0)) + goto fail; + if (unlikely(__pyx_memoryview_copy_contents(*from_mvs, new_mvs, ndim, ndim, + dtype_is_object) < 0)) + goto fail; + goto no_fail; +fail: + __Pyx_XDECREF(new_mvs.memview); + new_mvs.memview = NULL; + new_mvs.data = NULL; +no_fail: + __Pyx_XDECREF(shape_tuple); + __Pyx_XDECREF(temp_int); + __Pyx_XDECREF(array_obj); + __Pyx_RefNannyFinishContext(); + return new_mvs; +} + +/* MemviewSliceInit */ + static int +__Pyx_init_memviewslice(struct __pyx_memoryview_obj *memview, + int ndim, + __Pyx_memviewslice *memviewslice, + int memview_is_new_reference) +{ + __Pyx_RefNannyDeclarations + int i, retval=-1; + Py_buffer *buf = &memview->view; + __Pyx_RefNannySetupContext("init_memviewslice", 0); + if (unlikely(memviewslice->memview || memviewslice->data)) { + PyErr_SetString(PyExc_ValueError, + "memviewslice is already initialized!"); + goto fail; + } + if (buf->strides) { + for (i = 0; i < ndim; i++) { + memviewslice->strides[i] = buf->strides[i]; + } + } else { + Py_ssize_t stride = buf->itemsize; + for (i = ndim - 1; i >= 0; i--) { + memviewslice->strides[i] = stride; + stride *= buf->shape[i]; + } + } + for (i = 0; i < ndim; i++) { + memviewslice->shape[i] = buf->shape[i]; + if (buf->suboffsets) { + memviewslice->suboffsets[i] = buf->suboffsets[i]; + } else { + memviewslice->suboffsets[i] = -1; + } + } + memviewslice->memview = memview; + memviewslice->data = (char *)buf->buf; + if (__pyx_add_acquisition_count(memview) == 0 && !memview_is_new_reference) { + Py_INCREF(memview); + } + retval = 0; + goto no_fail; +fail: + memviewslice->memview = 0; + memviewslice->data = 0; + retval = -1; +no_fail: + __Pyx_RefNannyFinishContext(); + return retval; +} +#ifndef Py_NO_RETURN +#define Py_NO_RETURN +#endif +static void __pyx_fatalerror(const char *fmt, ...) Py_NO_RETURN { + va_list vargs; + char msg[200]; +#if PY_VERSION_HEX >= 0x030A0000 || defined(HAVE_STDARG_PROTOTYPES) + va_start(vargs, fmt); +#else + va_start(vargs); +#endif + vsnprintf(msg, 200, fmt, vargs); + va_end(vargs); + Py_FatalError(msg); +} +static CYTHON_INLINE int +__pyx_add_acquisition_count_locked(__pyx_atomic_int_type *acquisition_count, + PyThread_type_lock lock) +{ + int result; + PyThread_acquire_lock(lock, 1); + result = (*acquisition_count)++; + PyThread_release_lock(lock); + return result; +} +static CYTHON_INLINE int +__pyx_sub_acquisition_count_locked(__pyx_atomic_int_type *acquisition_count, + PyThread_type_lock lock) +{ + int result; + PyThread_acquire_lock(lock, 1); + result = (*acquisition_count)--; + PyThread_release_lock(lock); + return result; +} +static CYTHON_INLINE void +__Pyx_INC_MEMVIEW(__Pyx_memviewslice *memslice, int have_gil, int lineno) +{ + __pyx_nonatomic_int_type old_acquisition_count; + struct __pyx_memoryview_obj *memview = memslice->memview; + if (unlikely(!memview || (PyObject *) memview == Py_None)) { + return; + } + old_acquisition_count = __pyx_add_acquisition_count(memview); + if (unlikely(old_acquisition_count <= 0)) { + if (likely(old_acquisition_count == 0)) { + if (have_gil) { + Py_INCREF((PyObject *) memview); + } else { + PyGILState_STATE _gilstate = PyGILState_Ensure(); + Py_INCREF((PyObject *) memview); + PyGILState_Release(_gilstate); + } + } else { + __pyx_fatalerror("Acquisition count is %d (line %d)", + old_acquisition_count+1, lineno); + } + } +} +static CYTHON_INLINE void __Pyx_XCLEAR_MEMVIEW(__Pyx_memviewslice *memslice, + int have_gil, int lineno) { + __pyx_nonatomic_int_type old_acquisition_count; + struct __pyx_memoryview_obj *memview = memslice->memview; + if (unlikely(!memview || (PyObject *) memview == Py_None)) { + memslice->memview = NULL; + return; + } + old_acquisition_count = __pyx_sub_acquisition_count(memview); + memslice->data = NULL; + if (likely(old_acquisition_count > 1)) { + memslice->memview = NULL; + } else if (likely(old_acquisition_count == 1)) { + if (have_gil) { + Py_CLEAR(memslice->memview); + } else { + PyGILState_STATE _gilstate = PyGILState_Ensure(); + Py_CLEAR(memslice->memview); + PyGILState_Release(_gilstate); + } + } else { + __pyx_fatalerror("Acquisition count is %d (line %d)", + old_acquisition_count-1, lineno); + } +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_int(int value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(int) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(int) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(int) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(int) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4 + if (is_unsigned) { + return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1); + } else { + return PyLong_FromNativeBytes(bytes, sizeof(value), -1); + } +#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + int one = 1; int little = (int)*(unsigned char *)&one; + return _PyLong_FromByteArray(bytes, sizeof(int), + little, !is_unsigned); +#else + int one = 1; int little = (int)*(unsigned char *)&one; + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(int)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntFromPy */ + static CYTHON_INLINE int __Pyx_PyInt_As_int(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const int neg_one = (int) -1, const_zero = (int) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(int) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(int, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (int) val; + } + } +#endif + if (unlikely(!PyLong_Check(x))) { + int val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (int) -1; + val = __Pyx_PyInt_As_int(tmp); + Py_DECREF(tmp); + return val; + } + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 2 * PyLong_SHIFT)) { + return (int) (((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 3 * PyLong_SHIFT)) { + return (int) (((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) >= 4 * PyLong_SHIFT)) { + return (int) (((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (int) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(int) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(int, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(int) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(int) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + return (int) ((((((int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(int) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(int) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + return (int) ((((((((int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(int) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) (((int)-1)*(((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(int) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(int, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(int) - 1 > 4 * PyLong_SHIFT)) { + return (int) ((((((((((int)digits[3]) << PyLong_SHIFT) | (int)digits[2]) << PyLong_SHIFT) | (int)digits[1]) << PyLong_SHIFT) | (int)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(int) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(int, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(int) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(int, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + int val; + int ret = -1; +#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API + Py_ssize_t bytes_copied = PyLong_AsNativeBytes( + x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0)); + if (unlikely(bytes_copied == -1)) { + } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) { + goto raise_overflow; + } else { + ret = 0; + } +#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)x, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *v; + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (likely(PyLong_CheckExact(x))) { + v = __Pyx_NewRef(x); + } else { + v = PyNumber_Long(x); + if (unlikely(!v)) return (int) -1; + assert(PyLong_CheckExact(v)); + } + { + int result = PyObject_RichCompareBool(v, Py_False, Py_LT); + if (unlikely(result < 0)) { + Py_DECREF(v); + return (int) -1; + } + is_negative = result == 1; + } + if (is_unsigned && unlikely(is_negative)) { + Py_DECREF(v); + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + Py_DECREF(v); + if (unlikely(!stepval)) + return (int) -1; + } else { + stepval = v; + } + v = NULL; + val = (int) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(int) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + long idigit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + val |= ((int) idigit) << bits; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + } + Py_DECREF(shift); shift = NULL; + Py_DECREF(mask); mask = NULL; + { + long idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(int) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((int) idigit) << bits; + } + if (!is_unsigned) { + if (unlikely(val & (((int) 1) << (sizeof(int) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + if (unlikely(ret)) + return (int) -1; + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to int"); + return (int) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to int"); + return (int) -1; +} + +/* CIntToPy */ + static CYTHON_INLINE PyObject* __Pyx_PyInt_From_long(long value) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const long neg_one = (long) -1, const_zero = (long) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; + if (is_unsigned) { + if (sizeof(long) < sizeof(long)) { + return PyInt_FromLong((long) value); + } else if (sizeof(long) <= sizeof(unsigned long)) { + return PyLong_FromUnsignedLong((unsigned long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(long) <= sizeof(unsigned PY_LONG_LONG)) { + return PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG) value); +#endif + } + } else { + if (sizeof(long) <= sizeof(long)) { + return PyInt_FromLong((long) value); +#ifdef HAVE_LONG_LONG + } else if (sizeof(long) <= sizeof(PY_LONG_LONG)) { + return PyLong_FromLongLong((PY_LONG_LONG) value); +#endif + } + } + { + unsigned char *bytes = (unsigned char *)&value; +#if !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX >= 0x030d00A4 + if (is_unsigned) { + return PyLong_FromUnsignedNativeBytes(bytes, sizeof(value), -1); + } else { + return PyLong_FromNativeBytes(bytes, sizeof(value), -1); + } +#elif !CYTHON_COMPILING_IN_LIMITED_API && PY_VERSION_HEX < 0x030d0000 + int one = 1; int little = (int)*(unsigned char *)&one; + return _PyLong_FromByteArray(bytes, sizeof(long), + little, !is_unsigned); +#else + int one = 1; int little = (int)*(unsigned char *)&one; + PyObject *from_bytes, *result = NULL; + PyObject *py_bytes = NULL, *arg_tuple = NULL, *kwds = NULL, *order_str = NULL; + from_bytes = PyObject_GetAttrString((PyObject*)&PyLong_Type, "from_bytes"); + if (!from_bytes) return NULL; + py_bytes = PyBytes_FromStringAndSize((char*)bytes, sizeof(long)); + if (!py_bytes) goto limited_bad; + order_str = PyUnicode_FromString(little ? "little" : "big"); + if (!order_str) goto limited_bad; + arg_tuple = PyTuple_Pack(2, py_bytes, order_str); + if (!arg_tuple) goto limited_bad; + if (!is_unsigned) { + kwds = PyDict_New(); + if (!kwds) goto limited_bad; + if (PyDict_SetItemString(kwds, "signed", __Pyx_NewRef(Py_True))) goto limited_bad; + } + result = PyObject_Call(from_bytes, arg_tuple, kwds); + limited_bad: + Py_XDECREF(kwds); + Py_XDECREF(arg_tuple); + Py_XDECREF(order_str); + Py_XDECREF(py_bytes); + Py_XDECREF(from_bytes); + return result; +#endif + } +} + +/* CIntFromPy */ + static CYTHON_INLINE long __Pyx_PyInt_As_long(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const long neg_one = (long) -1, const_zero = (long) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(long) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(long, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (long) val; + } + } +#endif + if (unlikely(!PyLong_Check(x))) { + long val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (long) -1; + val = __Pyx_PyInt_As_long(tmp); + Py_DECREF(tmp); + return val; + } + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 2 * PyLong_SHIFT)) { + return (long) (((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 3 * PyLong_SHIFT)) { + return (long) (((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) >= 4 * PyLong_SHIFT)) { + return (long) (((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (long) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(long) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(long) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(long, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(long) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(long) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + return (long) ((((((long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(long) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(long) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + return (long) ((((((((long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(long) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) (((long)-1)*(((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(long) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(long, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(long) - 1 > 4 * PyLong_SHIFT)) { + return (long) ((((((((((long)digits[3]) << PyLong_SHIFT) | (long)digits[2]) << PyLong_SHIFT) | (long)digits[1]) << PyLong_SHIFT) | (long)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(long) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(long, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(long) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(long, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + long val; + int ret = -1; +#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API + Py_ssize_t bytes_copied = PyLong_AsNativeBytes( + x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0)); + if (unlikely(bytes_copied == -1)) { + } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) { + goto raise_overflow; + } else { + ret = 0; + } +#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)x, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *v; + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (likely(PyLong_CheckExact(x))) { + v = __Pyx_NewRef(x); + } else { + v = PyNumber_Long(x); + if (unlikely(!v)) return (long) -1; + assert(PyLong_CheckExact(v)); + } + { + int result = PyObject_RichCompareBool(v, Py_False, Py_LT); + if (unlikely(result < 0)) { + Py_DECREF(v); + return (long) -1; + } + is_negative = result == 1; + } + if (is_unsigned && unlikely(is_negative)) { + Py_DECREF(v); + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + Py_DECREF(v); + if (unlikely(!stepval)) + return (long) -1; + } else { + stepval = v; + } + v = NULL; + val = (long) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(long) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + long idigit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + val |= ((long) idigit) << bits; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + } + Py_DECREF(shift); shift = NULL; + Py_DECREF(mask); mask = NULL; + { + long idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(long) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((long) idigit) << bits; + } + if (!is_unsigned) { + if (unlikely(val & (((long) 1) << (sizeof(long) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + if (unlikely(ret)) + return (long) -1; + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to long"); + return (long) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to long"); + return (long) -1; +} + +/* CIntFromPy */ + static CYTHON_INLINE char __Pyx_PyInt_As_char(PyObject *x) { +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wconversion" +#endif + const char neg_one = (char) -1, const_zero = (char) 0; +#ifdef __Pyx_HAS_GCC_DIAGNOSTIC +#pragma GCC diagnostic pop +#endif + const int is_unsigned = neg_one > const_zero; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x))) { + if ((sizeof(char) < sizeof(long))) { + __PYX_VERIFY_RETURN_INT(char, long, PyInt_AS_LONG(x)) + } else { + long val = PyInt_AS_LONG(x); + if (is_unsigned && unlikely(val < 0)) { + goto raise_neg_overflow; + } + return (char) val; + } + } +#endif + if (unlikely(!PyLong_Check(x))) { + char val; + PyObject *tmp = __Pyx_PyNumber_IntOrLong(x); + if (!tmp) return (char) -1; + val = __Pyx_PyInt_As_char(tmp); + Py_DECREF(tmp); + return val; + } + if (is_unsigned) { +#if CYTHON_USE_PYLONG_INTERNALS + if (unlikely(__Pyx_PyLong_IsNeg(x))) { + goto raise_neg_overflow; + } else if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(char, __Pyx_compact_upylong, __Pyx_PyLong_CompactValueUnsigned(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_DigitCount(x)) { + case 2: + if ((8 * sizeof(char) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 2 * PyLong_SHIFT)) { + return (char) (((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + case 3: + if ((8 * sizeof(char) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 3 * PyLong_SHIFT)) { + return (char) (((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + case 4: + if ((8 * sizeof(char) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) >= 4 * PyLong_SHIFT)) { + return (char) (((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0])); + } + } + break; + } + } +#endif +#if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX < 0x030C00A7 + if (unlikely(Py_SIZE(x) < 0)) { + goto raise_neg_overflow; + } +#else + { + int result = PyObject_RichCompareBool(x, Py_False, Py_LT); + if (unlikely(result < 0)) + return (char) -1; + if (unlikely(result == 1)) + goto raise_neg_overflow; + } +#endif + if ((sizeof(char) <= sizeof(unsigned long))) { + __PYX_VERIFY_RETURN_INT_EXC(char, unsigned long, PyLong_AsUnsignedLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(char) <= sizeof(unsigned PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(char, unsigned PY_LONG_LONG, PyLong_AsUnsignedLongLong(x)) +#endif + } + } else { +#if CYTHON_USE_PYLONG_INTERNALS + if (__Pyx_PyLong_IsCompact(x)) { + __PYX_VERIFY_RETURN_INT(char, __Pyx_compact_pylong, __Pyx_PyLong_CompactValue(x)) + } else { + const digit* digits = __Pyx_PyLong_Digits(x); + assert(__Pyx_PyLong_DigitCount(x) > 1); + switch (__Pyx_PyLong_SignedDigitCount(x)) { + case -2: + if ((8 * sizeof(char) - 1 > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 2: + if ((8 * sizeof(char) > 1 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 2 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + return (char) ((((((char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case -3: + if ((8 * sizeof(char) - 1 > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 3: + if ((8 * sizeof(char) > 2 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 3 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + return (char) ((((((((char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case -4: + if ((8 * sizeof(char) - 1 > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, long, -(long) (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 4 * PyLong_SHIFT)) { + return (char) (((char)-1)*(((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + case 4: + if ((8 * sizeof(char) > 3 * PyLong_SHIFT)) { + if ((8 * sizeof(unsigned long) > 4 * PyLong_SHIFT)) { + __PYX_VERIFY_RETURN_INT(char, unsigned long, (((((((((unsigned long)digits[3]) << PyLong_SHIFT) | (unsigned long)digits[2]) << PyLong_SHIFT) | (unsigned long)digits[1]) << PyLong_SHIFT) | (unsigned long)digits[0]))) + } else if ((8 * sizeof(char) - 1 > 4 * PyLong_SHIFT)) { + return (char) ((((((((((char)digits[3]) << PyLong_SHIFT) | (char)digits[2]) << PyLong_SHIFT) | (char)digits[1]) << PyLong_SHIFT) | (char)digits[0]))); + } + } + break; + } + } +#endif + if ((sizeof(char) <= sizeof(long))) { + __PYX_VERIFY_RETURN_INT_EXC(char, long, PyLong_AsLong(x)) +#ifdef HAVE_LONG_LONG + } else if ((sizeof(char) <= sizeof(PY_LONG_LONG))) { + __PYX_VERIFY_RETURN_INT_EXC(char, PY_LONG_LONG, PyLong_AsLongLong(x)) +#endif + } + } + { + char val; + int ret = -1; +#if PY_VERSION_HEX >= 0x030d00A6 && !CYTHON_COMPILING_IN_LIMITED_API + Py_ssize_t bytes_copied = PyLong_AsNativeBytes( + x, &val, sizeof(val), Py_ASNATIVEBYTES_NATIVE_ENDIAN | (is_unsigned ? Py_ASNATIVEBYTES_UNSIGNED_BUFFER | Py_ASNATIVEBYTES_REJECT_NEGATIVE : 0)); + if (unlikely(bytes_copied == -1)) { + } else if (unlikely(bytes_copied > (Py_ssize_t) sizeof(val))) { + goto raise_overflow; + } else { + ret = 0; + } +#elif PY_VERSION_HEX < 0x030d0000 && !(CYTHON_COMPILING_IN_PYPY || CYTHON_COMPILING_IN_LIMITED_API) || defined(_PyLong_AsByteArray) + int one = 1; int is_little = (int)*(unsigned char *)&one; + unsigned char *bytes = (unsigned char *)&val; + ret = _PyLong_AsByteArray((PyLongObject *)x, + bytes, sizeof(val), + is_little, !is_unsigned); +#else + PyObject *v; + PyObject *stepval = NULL, *mask = NULL, *shift = NULL; + int bits, remaining_bits, is_negative = 0; + int chunk_size = (sizeof(long) < 8) ? 30 : 62; + if (likely(PyLong_CheckExact(x))) { + v = __Pyx_NewRef(x); + } else { + v = PyNumber_Long(x); + if (unlikely(!v)) return (char) -1; + assert(PyLong_CheckExact(v)); + } + { + int result = PyObject_RichCompareBool(v, Py_False, Py_LT); + if (unlikely(result < 0)) { + Py_DECREF(v); + return (char) -1; + } + is_negative = result == 1; + } + if (is_unsigned && unlikely(is_negative)) { + Py_DECREF(v); + goto raise_neg_overflow; + } else if (is_negative) { + stepval = PyNumber_Invert(v); + Py_DECREF(v); + if (unlikely(!stepval)) + return (char) -1; + } else { + stepval = v; + } + v = NULL; + val = (char) 0; + mask = PyLong_FromLong((1L << chunk_size) - 1); if (unlikely(!mask)) goto done; + shift = PyLong_FromLong(chunk_size); if (unlikely(!shift)) goto done; + for (bits = 0; bits < (int) sizeof(char) * 8 - chunk_size; bits += chunk_size) { + PyObject *tmp, *digit; + long idigit; + digit = PyNumber_And(stepval, mask); + if (unlikely(!digit)) goto done; + idigit = PyLong_AsLong(digit); + Py_DECREF(digit); + if (unlikely(idigit < 0)) goto done; + val |= ((char) idigit) << bits; + tmp = PyNumber_Rshift(stepval, shift); + if (unlikely(!tmp)) goto done; + Py_DECREF(stepval); stepval = tmp; + } + Py_DECREF(shift); shift = NULL; + Py_DECREF(mask); mask = NULL; + { + long idigit = PyLong_AsLong(stepval); + if (unlikely(idigit < 0)) goto done; + remaining_bits = ((int) sizeof(char) * 8) - bits - (is_unsigned ? 0 : 1); + if (unlikely(idigit >= (1L << remaining_bits))) + goto raise_overflow; + val |= ((char) idigit) << bits; + } + if (!is_unsigned) { + if (unlikely(val & (((char) 1) << (sizeof(char) * 8 - 1)))) + goto raise_overflow; + if (is_negative) + val = ~val; + } + ret = 0; + done: + Py_XDECREF(shift); + Py_XDECREF(mask); + Py_XDECREF(stepval); +#endif + if (unlikely(ret)) + return (char) -1; + return val; + } +raise_overflow: + PyErr_SetString(PyExc_OverflowError, + "value too large to convert to char"); + return (char) -1; +raise_neg_overflow: + PyErr_SetString(PyExc_OverflowError, + "can't convert negative value to char"); + return (char) -1; +} + +/* FormatTypeName */ + #if CYTHON_COMPILING_IN_LIMITED_API +static __Pyx_TypeName +__Pyx_PyType_GetName(PyTypeObject* tp) +{ + PyObject *name = __Pyx_PyObject_GetAttrStr((PyObject *)tp, + __pyx_n_s_name_2); + if (unlikely(name == NULL) || unlikely(!PyUnicode_Check(name))) { + PyErr_Clear(); + Py_XDECREF(name); + name = __Pyx_NewRef(__pyx_n_s__25); + } + return name; +} +#endif + +/* CheckBinaryVersion */ + static unsigned long __Pyx_get_runtime_version(void) { +#if __PYX_LIMITED_VERSION_HEX >= 0x030B00A4 + return Py_Version & ~0xFFUL; +#else + const char* rt_version = Py_GetVersion(); + unsigned long version = 0; + unsigned long factor = 0x01000000UL; + unsigned int digit = 0; + int i = 0; + while (factor) { + while ('0' <= rt_version[i] && rt_version[i] <= '9') { + digit = digit * 10 + (unsigned int) (rt_version[i] - '0'); + ++i; + } + version += factor * digit; + if (rt_version[i] != '.') + break; + digit = 0; + factor >>= 8; + ++i; + } + return version; +#endif +} +static int __Pyx_check_binary_version(unsigned long ct_version, unsigned long rt_version, int allow_newer) { + const unsigned long MAJOR_MINOR = 0xFFFF0000UL; + if ((rt_version & MAJOR_MINOR) == (ct_version & MAJOR_MINOR)) + return 0; + if (likely(allow_newer && (rt_version & MAJOR_MINOR) > (ct_version & MAJOR_MINOR))) + return 1; + { + char message[200]; + PyOS_snprintf(message, sizeof(message), + "compile time Python version %d.%d " + "of module '%.100s' " + "%s " + "runtime version %d.%d", + (int) (ct_version >> 24), (int) ((ct_version >> 16) & 0xFF), + __Pyx_MODULE_NAME, + (allow_newer) ? "was newer than" : "does not match", + (int) (rt_version >> 24), (int) ((rt_version >> 16) & 0xFF) + ); + return PyErr_WarnEx(NULL, message, 1); + } +} + +/* InitStrings */ + #if PY_MAJOR_VERSION >= 3 +static int __Pyx_InitString(__Pyx_StringTabEntry t, PyObject **str) { + if (t.is_unicode | t.is_str) { + if (t.intern) { + *str = PyUnicode_InternFromString(t.s); + } else if (t.encoding) { + *str = PyUnicode_Decode(t.s, t.n - 1, t.encoding, NULL); + } else { + *str = PyUnicode_FromStringAndSize(t.s, t.n - 1); + } + } else { + *str = PyBytes_FromStringAndSize(t.s, t.n - 1); + } + if (!*str) + return -1; + if (PyObject_Hash(*str) == -1) + return -1; + return 0; +} +#endif +static int __Pyx_InitStrings(__Pyx_StringTabEntry *t) { + while (t->p) { + #if PY_MAJOR_VERSION >= 3 + __Pyx_InitString(*t, t->p); + #else + if (t->is_unicode) { + *t->p = PyUnicode_DecodeUTF8(t->s, t->n - 1, NULL); + } else if (t->intern) { + *t->p = PyString_InternFromString(t->s); + } else { + *t->p = PyString_FromStringAndSize(t->s, t->n - 1); + } + if (!*t->p) + return -1; + if (PyObject_Hash(*t->p) == -1) + return -1; + #endif + ++t; + } + return 0; +} + +#include +static CYTHON_INLINE Py_ssize_t __Pyx_ssize_strlen(const char *s) { + size_t len = strlen(s); + if (unlikely(len > (size_t) PY_SSIZE_T_MAX)) { + PyErr_SetString(PyExc_OverflowError, "byte string is too long"); + return -1; + } + return (Py_ssize_t) len; +} +static CYTHON_INLINE PyObject* __Pyx_PyUnicode_FromString(const char* c_str) { + Py_ssize_t len = __Pyx_ssize_strlen(c_str); + if (unlikely(len < 0)) return NULL; + return __Pyx_PyUnicode_FromStringAndSize(c_str, len); +} +static CYTHON_INLINE PyObject* __Pyx_PyByteArray_FromString(const char* c_str) { + Py_ssize_t len = __Pyx_ssize_strlen(c_str); + if (unlikely(len < 0)) return NULL; + return PyByteArray_FromStringAndSize(c_str, len); +} +static CYTHON_INLINE const char* __Pyx_PyObject_AsString(PyObject* o) { + Py_ssize_t ignore; + return __Pyx_PyObject_AsStringAndSize(o, &ignore); +} +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT +#if !CYTHON_PEP393_ENABLED +static const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { + char* defenc_c; + PyObject* defenc = _PyUnicode_AsDefaultEncodedString(o, NULL); + if (!defenc) return NULL; + defenc_c = PyBytes_AS_STRING(defenc); +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + { + char* end = defenc_c + PyBytes_GET_SIZE(defenc); + char* c; + for (c = defenc_c; c < end; c++) { + if ((unsigned char) (*c) >= 128) { + PyUnicode_AsASCIIString(o); + return NULL; + } + } + } +#endif + *length = PyBytes_GET_SIZE(defenc); + return defenc_c; +} +#else +static CYTHON_INLINE const char* __Pyx_PyUnicode_AsStringAndSize(PyObject* o, Py_ssize_t *length) { + if (unlikely(__Pyx_PyUnicode_READY(o) == -1)) return NULL; +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + if (likely(PyUnicode_IS_ASCII(o))) { + *length = PyUnicode_GET_LENGTH(o); + return PyUnicode_AsUTF8(o); + } else { + PyUnicode_AsASCIIString(o); + return NULL; + } +#else + return PyUnicode_AsUTF8AndSize(o, length); +#endif +} +#endif +#endif +static CYTHON_INLINE const char* __Pyx_PyObject_AsStringAndSize(PyObject* o, Py_ssize_t *length) { +#if __PYX_DEFAULT_STRING_ENCODING_IS_ASCII || __PYX_DEFAULT_STRING_ENCODING_IS_DEFAULT + if ( +#if PY_MAJOR_VERSION < 3 && __PYX_DEFAULT_STRING_ENCODING_IS_ASCII + __Pyx_sys_getdefaultencoding_not_ascii && +#endif + PyUnicode_Check(o)) { + return __Pyx_PyUnicode_AsStringAndSize(o, length); + } else +#endif +#if (!CYTHON_COMPILING_IN_PYPY && !CYTHON_COMPILING_IN_LIMITED_API) || (defined(PyByteArray_AS_STRING) && defined(PyByteArray_GET_SIZE)) + if (PyByteArray_Check(o)) { + *length = PyByteArray_GET_SIZE(o); + return PyByteArray_AS_STRING(o); + } else +#endif + { + char* result; + int r = PyBytes_AsStringAndSize(o, &result, length); + if (unlikely(r < 0)) { + return NULL; + } else { + return result; + } + } +} +static CYTHON_INLINE int __Pyx_PyObject_IsTrue(PyObject* x) { + int is_true = x == Py_True; + if (is_true | (x == Py_False) | (x == Py_None)) return is_true; + else return PyObject_IsTrue(x); +} +static CYTHON_INLINE int __Pyx_PyObject_IsTrueAndDecref(PyObject* x) { + int retval; + if (unlikely(!x)) return -1; + retval = __Pyx_PyObject_IsTrue(x); + Py_DECREF(x); + return retval; +} +static PyObject* __Pyx_PyNumber_IntOrLongWrongResultType(PyObject* result, const char* type_name) { + __Pyx_TypeName result_type_name = __Pyx_PyType_GetName(Py_TYPE(result)); +#if PY_MAJOR_VERSION >= 3 + if (PyLong_Check(result)) { + if (PyErr_WarnFormat(PyExc_DeprecationWarning, 1, + "__int__ returned non-int (type " __Pyx_FMT_TYPENAME "). " + "The ability to return an instance of a strict subclass of int is deprecated, " + "and may be removed in a future version of Python.", + result_type_name)) { + __Pyx_DECREF_TypeName(result_type_name); + Py_DECREF(result); + return NULL; + } + __Pyx_DECREF_TypeName(result_type_name); + return result; + } +#endif + PyErr_Format(PyExc_TypeError, + "__%.4s__ returned non-%.4s (type " __Pyx_FMT_TYPENAME ")", + type_name, type_name, result_type_name); + __Pyx_DECREF_TypeName(result_type_name); + Py_DECREF(result); + return NULL; +} +static CYTHON_INLINE PyObject* __Pyx_PyNumber_IntOrLong(PyObject* x) { +#if CYTHON_USE_TYPE_SLOTS + PyNumberMethods *m; +#endif + const char *name = NULL; + PyObject *res = NULL; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_Check(x) || PyLong_Check(x))) +#else + if (likely(PyLong_Check(x))) +#endif + return __Pyx_NewRef(x); +#if CYTHON_USE_TYPE_SLOTS + m = Py_TYPE(x)->tp_as_number; + #if PY_MAJOR_VERSION < 3 + if (m && m->nb_int) { + name = "int"; + res = m->nb_int(x); + } + else if (m && m->nb_long) { + name = "long"; + res = m->nb_long(x); + } + #else + if (likely(m && m->nb_int)) { + name = "int"; + res = m->nb_int(x); + } + #endif +#else + if (!PyBytes_CheckExact(x) && !PyUnicode_CheckExact(x)) { + res = PyNumber_Int(x); + } +#endif + if (likely(res)) { +#if PY_MAJOR_VERSION < 3 + if (unlikely(!PyInt_Check(res) && !PyLong_Check(res))) { +#else + if (unlikely(!PyLong_CheckExact(res))) { +#endif + return __Pyx_PyNumber_IntOrLongWrongResultType(res, name); + } + } + else if (!PyErr_Occurred()) { + PyErr_SetString(PyExc_TypeError, + "an integer is required"); + } + return res; +} +static CYTHON_INLINE Py_ssize_t __Pyx_PyIndex_AsSsize_t(PyObject* b) { + Py_ssize_t ival; + PyObject *x; +#if PY_MAJOR_VERSION < 3 + if (likely(PyInt_CheckExact(b))) { + if (sizeof(Py_ssize_t) >= sizeof(long)) + return PyInt_AS_LONG(b); + else + return PyInt_AsSsize_t(b); + } +#endif + if (likely(PyLong_CheckExact(b))) { + #if CYTHON_USE_PYLONG_INTERNALS + if (likely(__Pyx_PyLong_IsCompact(b))) { + return __Pyx_PyLong_CompactValue(b); + } else { + const digit* digits = __Pyx_PyLong_Digits(b); + const Py_ssize_t size = __Pyx_PyLong_SignedDigitCount(b); + switch (size) { + case 2: + if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { + return (Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -2: + if (8 * sizeof(Py_ssize_t) > 2 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case 3: + if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { + return (Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -3: + if (8 * sizeof(Py_ssize_t) > 3 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((((size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case 4: + if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { + return (Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + case -4: + if (8 * sizeof(Py_ssize_t) > 4 * PyLong_SHIFT) { + return -(Py_ssize_t) (((((((((size_t)digits[3]) << PyLong_SHIFT) | (size_t)digits[2]) << PyLong_SHIFT) | (size_t)digits[1]) << PyLong_SHIFT) | (size_t)digits[0])); + } + break; + } + } + #endif + return PyLong_AsSsize_t(b); + } + x = PyNumber_Index(b); + if (!x) return -1; + ival = PyInt_AsSsize_t(x); + Py_DECREF(x); + return ival; +} +static CYTHON_INLINE Py_hash_t __Pyx_PyIndex_AsHash_t(PyObject* o) { + if (sizeof(Py_hash_t) == sizeof(Py_ssize_t)) { + return (Py_hash_t) __Pyx_PyIndex_AsSsize_t(o); +#if PY_MAJOR_VERSION < 3 + } else if (likely(PyInt_CheckExact(o))) { + return PyInt_AS_LONG(o); +#endif + } else { + Py_ssize_t ival; + PyObject *x; + x = PyNumber_Index(o); + if (!x) return -1; + ival = PyInt_AsLong(x); + Py_DECREF(x); + return ival; + } +} +static CYTHON_INLINE PyObject * __Pyx_PyBool_FromLong(long b) { + return b ? __Pyx_NewRef(Py_True) : __Pyx_NewRef(Py_False); +} +static CYTHON_INLINE PyObject * __Pyx_PyInt_FromSize_t(size_t ival) { + return PyInt_FromSize_t(ival); +} + + +/* #### Code section: utility_code_pragmas_end ### */ +#ifdef _MSC_VER +#pragma warning( pop ) +#endif + + + +/* #### Code section: end ### */ +#endif /* Py_PYTHON_H */ diff --git a/TTS/tts/utils/monotonic_align/core.pyx b/TTS/tts/utils/monotonic_align/core.pyx new file mode 100644 index 0000000000000000000000000000000000000000..091fcc3a50a51f3d3fee47a70825260757e6d885 --- /dev/null +++ b/TTS/tts/utils/monotonic_align/core.pyx @@ -0,0 +1,47 @@ +import numpy as np + +cimport cython +cimport numpy as np + +from cython.parallel import prange + + +@cython.boundscheck(False) +@cython.wraparound(False) +cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: + cdef int x + cdef int y + cdef float v_prev + cdef float v_cur + cdef float tmp + cdef int index = t_x - 1 + + for y in range(t_y): + for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): + if x == y: + v_cur = max_neg_val + else: + v_cur = value[x, y-1] + if x == 0: + if y == 0: + v_prev = 0. + else: + v_prev = max_neg_val + else: + v_prev = value[x-1, y-1] + value[x, y] = max(v_cur, v_prev) + value[x, y] + + for y in range(t_y - 1, -1, -1): + path[index, y] = 1 + if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): + index = index - 1 + + +@cython.boundscheck(False) +@cython.wraparound(False) +cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: + cdef int b = values.shape[0] + + cdef int i + for i in prange(b, nogil=True): + maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py new file mode 100644 index 0000000000000000000000000000000000000000..5229af81c5631b04df188b5b2c9230d9ca389b90 --- /dev/null +++ b/TTS/tts/utils/speakers.py @@ -0,0 +1,227 @@ +import json +import logging +import os +from typing import Any, Dict, List, Union + +import fsspec +import numpy as np +import torch +from coqpit import Coqpit + +from TTS.config import get_from_config_or_model_args_with_default +from TTS.tts.utils.managers import EmbeddingManager + +logger = logging.getLogger(__name__) + + +class SpeakerManager(EmbeddingManager): + """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information + in a way that can be queried by speaker or clip. + + There are 3 different scenarios considered: + + 1. Models using speaker embedding layers. The datafile only maps speaker names to ids used by the embedding layer. + 2. Models using d-vectors. The datafile includes a dictionary in the following format. + + :: + + { + 'clip_name.wav':{ + 'name': 'speakerA', + 'embedding'[] + }, + ... + } + + + 3. Computing the d-vectors by the speaker encoder. It loads the speaker encoder model and + computes the d-vectors for a given clip or speaker. + + Args: + d_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "". + speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by + TTS models. Defaults to "". + encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "". + encoder_config_path (str, optional): Path to the spealer encoder config file. Defaults to "". + + Examples: + >>> # load audio processor and speaker encoder + >>> ap = AudioProcessor(**config.audio) + >>> manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path) + >>> # load a sample audio and compute embedding + >>> waveform = ap.load_wav(sample_wav_path) + >>> mel = ap.melspectrogram(waveform) + >>> d_vector = manager.compute_embeddings(mel.T) + """ + + def __init__( + self, + data_items: List[List[Any]] = None, + d_vectors_file_path: str = "", + speaker_id_file_path: str = "", + encoder_model_path: str = "", + encoder_config_path: str = "", + use_cuda: bool = False, + ): + super().__init__( + embedding_file_path=d_vectors_file_path, + id_file_path=speaker_id_file_path, + encoder_model_path=encoder_model_path, + encoder_config_path=encoder_config_path, + use_cuda=use_cuda, + ) + + if data_items: + self.set_ids_from_data(data_items, parse_key="speaker_name") + + @property + def num_speakers(self): + return len(self.name_to_id) + + @property + def speaker_names(self): + return list(self.name_to_id.keys()) + + def get_speakers(self) -> List: + return self.name_to_id + + @staticmethod + def init_from_config(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "SpeakerManager": + """Initialize a speaker manager from config + + Args: + config (Coqpit): Config object. + samples (Union[List[List], List[Dict]], optional): List of data samples to parse out the speaker names. + Defaults to None. + + Returns: + SpeakerEncoder: Speaker encoder object. + """ + speaker_manager = None + if get_from_config_or_model_args_with_default(config, "use_speaker_embedding", False): + if samples: + speaker_manager = SpeakerManager(data_items=samples) + if get_from_config_or_model_args_with_default(config, "speaker_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speaker_file", None) + ) + if get_from_config_or_model_args_with_default(config, "speakers_file", None): + speaker_manager = SpeakerManager( + speaker_id_file_path=get_from_config_or_model_args_with_default(config, "speakers_file", None) + ) + + if get_from_config_or_model_args_with_default(config, "use_d_vector_file", False): + speaker_manager = SpeakerManager() + if get_from_config_or_model_args_with_default(config, "d_vector_file", None): + speaker_manager = SpeakerManager( + d_vectors_file_path=get_from_config_or_model_args_with_default(config, "d_vector_file", None) + ) + return speaker_manager + + +def _set_file_path(path): + """Find the speakers.json under the given path or the above it. + Intended to band aid the different paths returned in restored and continued training.""" + path_restore = os.path.join(os.path.dirname(path), "speakers.json") + path_continue = os.path.join(path, "speakers.json") + fs = fsspec.get_mapper(path).fs + if fs.exists(path_restore): + return path_restore + if fs.exists(path_continue): + return path_continue + raise FileNotFoundError(f" [!] `speakers.json` not found in {path}") + + +def load_speaker_mapping(out_path): + """Loads speaker mapping if already present.""" + if os.path.splitext(out_path)[1] == ".json": + json_file = out_path + else: + json_file = _set_file_path(out_path) + with fsspec.open(json_file, "r") as f: + return json.load(f) + + +def save_speaker_mapping(out_path, speaker_mapping): + """Saves speaker mapping if not yet present.""" + if out_path is not None: + speakers_json_path = _set_file_path(out_path) + with fsspec.open(speakers_json_path, "w") as f: + json.dump(speaker_mapping, f, indent=4) + + +def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> SpeakerManager: + """Initiate a `SpeakerManager` instance by the provided config. + + Args: + c (Coqpit): Model configuration. + restore_path (str): Path to a previous training folder. + data (List): Data samples used in training to infer speakers from. It must be provided if speaker embedding + layers is used. Defaults to None. + out_path (str, optional): Save the generated speaker IDs to a output path. Defaults to None. + + Returns: + SpeakerManager: initialized and ready to use instance. + """ + speaker_manager = SpeakerManager() + if c.use_speaker_embedding: + if data is not None: + speaker_manager.set_ids_from_data(data, parse_key="speaker_name") + if restore_path: + speakers_file = _set_file_path(restore_path) + # restoring speaker manager from a previous run. + if c.use_d_vector_file: + # restore speaker manager with the embedding file + if not os.path.exists(speakers_file): + logger.warning( + "speakers.json was not found in %s, trying to use CONFIG.d_vector_file", restore_path + ) + if not os.path.exists(c.d_vector_file): + raise RuntimeError( + "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file" + ) + speaker_manager.load_embeddings_from_file(c.d_vector_file) + speaker_manager.load_embeddings_from_file(speakers_file) + elif not c.use_d_vector_file: # restor speaker manager with speaker ID file. + speaker_ids_from_data = speaker_manager.name_to_id + speaker_manager.load_ids_from_file(speakers_file) + assert all( + speaker in speaker_manager.name_to_id for speaker in speaker_ids_from_data + ), " [!] You cannot introduce new speakers to a pre-trained model." + elif c.use_d_vector_file and c.d_vector_file: + # new speaker manager with external speaker embeddings. + speaker_manager.load_embeddings_from_file(c.d_vector_file) + elif c.use_d_vector_file and not c.d_vector_file: + raise "use_d_vector_file is True, so you need pass a external speaker embedding file." + elif c.use_speaker_embedding and "speakers_file" in c and c.speakers_file: + # new speaker manager with speaker IDs file. + speaker_manager.load_ids_from_file(c.speakers_file) + + if speaker_manager.num_speakers > 0: + logger.info( + "Speaker manager is loaded with %d speakers: %s", + speaker_manager.num_speakers, + ", ".join(speaker_manager.name_to_id), + ) + + # save file if path is defined + if out_path: + out_file_path = os.path.join(out_path, "speakers.json") + logger.info("Saving `speakers.json` to %s", out_file_path) + if c.use_d_vector_file and c.d_vector_file: + speaker_manager.save_embeddings_to_file(out_file_path) + else: + speaker_manager.save_ids_to_file(out_file_path) + return speaker_manager + + +def get_speaker_balancer_weights(items: list): + speaker_names = np.array([item["speaker_name"] for item in items]) + unique_speaker_names = np.unique(speaker_names).tolist() + speaker_ids = [unique_speaker_names.index(l) for l in speaker_names] + speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names]) + weight_speaker = 1.0 / speaker_count + dataset_samples_weight = np.array([weight_speaker[l] for l in speaker_ids]) + # normalize + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + return torch.from_numpy(dataset_samples_weight).float() diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..eddf05db3f2c8d04fd195143433e1f1d8c59fd96 --- /dev/null +++ b/TTS/tts/utils/ssim.py @@ -0,0 +1,384 @@ +# Adopted from https://github.com/photosynthesis-team/piq + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn.modules.loss import _Loss + + +def _reduce(x: torch.Tensor, reduction: str = "mean") -> torch.Tensor: + r"""Reduce input in batch dimension if needed. + Args: + x: Tensor with shape (N, *). + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'`` + """ + if reduction == "none": + return x + if reduction == "mean": + return x.mean(dim=0) + if reduction == "sum": + return x.sum(dim=0) + raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}") + + +def _validate_input( + tensors: List[torch.Tensor], + dim_range: Tuple[int, int] = (0, -1), + data_range: Tuple[float, float] = (0.0, -1.0), + # size_dim_range: Tuple[float, float] = (0., -1.), + size_range: Optional[Tuple[int, int]] = None, +) -> None: + r"""Check that input(-s) satisfies the requirements + Args: + tensors: Tensors to check + dim_range: Allowed number of dimensions. (min, max) + data_range: Allowed range of values in tensors. (min, max) + size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1) + """ + + if not __debug__: + return + + x = tensors[0] + + for t in tensors: + assert torch.is_tensor(t), f"Expected torch.Tensor, got {type(t)}" + assert t.device == x.device, f"Expected tensors to be on {x.device}, got {t.device}" + + if size_range is None: + assert t.size() == x.size(), f"Expected tensors with same size, got {t.size()} and {x.size()}" + else: + assert ( + t.size()[size_range[0] : size_range[1]] == x.size()[size_range[0] : size_range[1]] + ), f"Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}" + + if dim_range[0] == dim_range[1]: + assert t.dim() == dim_range[0], f"Expected number of dimensions to be {dim_range[0]}, got {t.dim()}" + elif dim_range[0] < dim_range[1]: + assert ( + dim_range[0] <= t.dim() <= dim_range[1] + ), f"Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}" + + if data_range[0] < data_range[1]: + assert data_range[0] <= t.min(), f"Expected values to be greater or equal to {data_range[0]}, got {t.min()}" + assert t.max() <= data_range[1], f"Expected values to be lower or equal to {data_range[1]}, got {t.max()}" + + +def gaussian_filter(kernel_size: int, sigma: float) -> torch.Tensor: + r"""Returns 2D Gaussian kernel N(0,`sigma`^2) + Args: + size: Size of the kernel + sigma: Std of the distribution + Returns: + gaussian_kernel: Tensor with shape (1, kernel_size, kernel_size) + """ + coords = torch.arange(kernel_size, dtype=torch.float32) + coords -= (kernel_size - 1) / 2.0 + + g = coords**2 + g = (-(g.unsqueeze(0) + g.unsqueeze(1)) / (2 * sigma**2)).exp() + + g /= g.sum() + return g.unsqueeze(0) + + +def ssim( + x: torch.Tensor, + y: torch.Tensor, + kernel_size: int = 11, + kernel_sigma: float = 1.5, + data_range: Union[int, float] = 1.0, + reduction: str = "mean", + full: bool = False, + downsample: bool = True, + k1: float = 0.01, + k2: float = 0.03, +) -> List[torch.Tensor]: + r"""Interface of Structural Similarity (SSIM) index. + Inputs supposed to be in range ``[0, data_range]``. + To match performance with skimage and tensorflow set ``'downsample' = True``. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + kernel_size: The side-length of the sliding window used in comparison. Must be an odd value. + kernel_sigma: Sigma of normal distribution. + data_range: Maximum value range of images (usually 1.0 or 255). + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + full: Return cs map or not. + downsample: Perform average pool before SSIM computation. Default: True + k1: Algorithm parameter, K1 (small constant). + k2: Algorithm parameter, K2 (small constant). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Value of Structural Similarity (SSIM) index. In case of 5D input tensors, complex value is returned + as a tensor of size 2. + + References: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). + Image quality assessment: From error visibility to structural similarity. + IEEE Transactions on Image Processing, 13, 600-612. + https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI: `10.1109/TIP.2003.819861` + """ + assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]" + _validate_input([x, y], dim_range=(4, 5), data_range=(0, data_range)) + + x = x / float(data_range) + y = y / float(data_range) + + # Averagepool image if the size is large enough + f = max(1, round(min(x.size()[-2:]) / 256)) + if (f > 1) and downsample: + x = F.avg_pool2d(x, kernel_size=f) + y = F.avg_pool2d(y, kernel_size=f) + + kernel = gaussian_filter(kernel_size, kernel_sigma).repeat(x.size(1), 1, 1, 1).to(y) + _compute_ssim_per_channel = _ssim_per_channel_complex if x.dim() == 5 else _ssim_per_channel + ssim_map, cs_map = _compute_ssim_per_channel(x=x, y=y, kernel=kernel, k1=k1, k2=k2) + ssim_val = ssim_map.mean(1) + cs = cs_map.mean(1) + + ssim_val = _reduce(ssim_val, reduction) + cs = _reduce(cs, reduction) + + if full: + return [ssim_val, cs] + + return ssim_val + + +class SSIMLoss(_Loss): + r"""Creates a criterion that measures the structural similarity index error between + each element in the input :math:`x` and target :math:`y`. + + To match performance with skimage and tensorflow set ``'downsample' = True``. + + The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: + + .. math:: + SSIM = \{ssim_1,\dots,ssim_{N \times C}\}\\ + ssim_{l}(x, y) = \frac{(2 \mu_x \mu_y + c_1) (2 \sigma_{xy} + c_2)} + {(\mu_x^2 +\mu_y^2 + c_1)(\sigma_x^2 +\sigma_y^2 + c_2)}, + + where :math:`N` is the batch size, `C` is the channel size. If :attr:`reduction` is not ``'none'`` + (default ``'mean'``), then: + + .. math:: + SSIMLoss(x, y) = + \begin{cases} + \operatorname{mean}(1 - SSIM), & \text{if reduction} = \text{'mean';}\\ + \operatorname{sum}(1 - SSIM), & \text{if reduction} = \text{'sum'.} + \end{cases} + + :math:`x` and :math:`y` are tensors of arbitrary shapes with a total + of :math:`n` elements each. + + The sum operation still operates over all the elements, and divides by :math:`n`. + The division by :math:`n` can be avoided if one sets ``reduction = 'sum'``. + In case of 5D input tensors, complex value is returned as a tensor of size 2. + + Args: + kernel_size: By default, the mean and covariance of a pixel is obtained + by convolution with given filter_size. + kernel_sigma: Standard deviation for Gaussian kernel. + k1: Coefficient related to c1 in the above equation. + k2: Coefficient related to c2 in the above equation. + downsample: Perform average pool before SSIM computation. Default: True + reduction: Specifies the reduction type: + ``'none'`` | ``'mean'`` | ``'sum'``. Default:``'mean'`` + data_range: Maximum value range of images (usually 1.0 or 255). + + Examples: + >>> loss = SSIMLoss() + >>> x = torch.rand(3, 3, 256, 256, requires_grad=True) + >>> y = torch.rand(3, 3, 256, 256) + >>> output = loss(x, y) + >>> output.backward() + + References: + Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004). + Image quality assessment: From error visibility to structural similarity. + IEEE Transactions on Image Processing, 13, 600-612. + https://ece.uwaterloo.ca/~z70wang/publications/ssim.pdf, + DOI:`10.1109/TIP.2003.819861` + """ + + __constants__ = ["kernel_size", "k1", "k2", "sigma", "kernel", "reduction"] + + def __init__( + self, + kernel_size: int = 11, + kernel_sigma: float = 1.5, + k1: float = 0.01, + k2: float = 0.03, + downsample: bool = True, + reduction: str = "mean", + data_range: Union[int, float] = 1.0, + ) -> None: + super().__init__() + + # Generic loss parameters. + self.reduction = reduction + + # Loss-specific parameters. + self.kernel_size = kernel_size + + # This check might look redundant because kernel size is checked within the ssim function anyway. + # However, this check allows to fail fast when the loss is being initialised and training has not been started. + assert kernel_size % 2 == 1, f"Kernel size must be odd, got [{kernel_size}]" + self.kernel_sigma = kernel_sigma + self.k1 = k1 + self.k2 = k2 + self.downsample = downsample + self.data_range = data_range + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + r"""Computation of Structural Similarity (SSIM) index as a loss function. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W)` or :math:`(N, C, H, W, 2)`. + + Returns: + Value of SSIM loss to be minimized, i.e ``1 - ssim`` in [0, 1] range. In case of 5D input tensors, + complex value is returned as a tensor of size 2. + """ + + score = ssim( + x=x, + y=y, + kernel_size=self.kernel_size, + kernel_sigma=self.kernel_sigma, + downsample=self.downsample, + data_range=self.data_range, + reduction=self.reduction, + full=False, + k1=self.k1, + k2=self.k2, + ) + return torch.ones_like(score) - score + + +def _ssim_per_channel( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + r"""Calculate Structural Similarity (SSIM) index for X and Y per channel. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W)`. + y: A target tensor. Shape :math:`(N, C, H, W)`. + kernel: 2D Gaussian kernel. + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Full Value of Structural Similarity (SSIM) index. + """ + if x.size(-1) < kernel.size(-1) or x.size(-2) < kernel.size(-2): + raise ValueError( + f"Kernel size can't be greater than actual input size. Input size: {x.size()}. " + f"Kernel size: {kernel.size()}" + ) + + c1 = k1**2 + c2 = k2**2 + n_channels = x.size(1) + mu_x = F.conv2d(x, weight=kernel, stride=1, padding=0, groups=n_channels) + mu_y = F.conv2d(y, weight=kernel, stride=1, padding=0, groups=n_channels) + + mu_xx = mu_x**2 + mu_yy = mu_y**2 + mu_xy = mu_x * mu_y + + sigma_xx = F.conv2d(x**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xx + sigma_yy = F.conv2d(y**2, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_yy + sigma_xy = F.conv2d(x * y, weight=kernel, stride=1, padding=0, groups=n_channels) - mu_xy + + # Contrast sensitivity (CS) with alpha = beta = gamma = 1. + cs = (2.0 * sigma_xy + c2) / (sigma_xx + sigma_yy + c2) + + # Structural similarity (SSIM) + ss = (2.0 * mu_xy + c1) / (mu_xx + mu_yy + c1) * cs + + ssim_val = ss.mean(dim=(-1, -2)) + cs = cs.mean(dim=(-1, -2)) + return ssim_val, cs + + +def _ssim_per_channel_complex( + x: torch.Tensor, + y: torch.Tensor, + kernel: torch.Tensor, + k1: float = 0.01, + k2: float = 0.03, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + r"""Calculate Structural Similarity (SSIM) index for Complex X and Y per channel. + + Args: + x: An input tensor. Shape :math:`(N, C, H, W, 2)`. + y: A target tensor. Shape :math:`(N, C, H, W, 2)`. + kernel: 2-D gauss kernel. + k1: Algorithm parameter, K1 (small constant, see [1]). + k2: Algorithm parameter, K2 (small constant, see [1]). + Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results. + + Returns: + Full Value of Complex Structural Similarity (SSIM) index. + """ + n_channels = x.size(1) + if x.size(-2) < kernel.size(-1) or x.size(-3) < kernel.size(-2): + raise ValueError( + f"Kernel size can't be greater than actual input size. Input size: {x.size()}. " + f"Kernel size: {kernel.size()}" + ) + + c1 = k1**2 + c2 = k2**2 + + x_real = x[..., 0] + x_imag = x[..., 1] + y_real = y[..., 0] + y_imag = y[..., 1] + + mu1_real = F.conv2d(x_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu1_imag = F.conv2d(x_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_real = F.conv2d(y_real, weight=kernel, stride=1, padding=0, groups=n_channels) + mu2_imag = F.conv2d(y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) + + mu1_sq = mu1_real.pow(2) + mu1_imag.pow(2) + mu2_sq = mu2_real.pow(2) + mu2_imag.pow(2) + mu1_mu2_real = mu1_real * mu2_real - mu1_imag * mu2_imag + mu1_mu2_imag = mu1_real * mu2_imag + mu1_imag * mu2_real + + compensation = 1.0 + + x_sq = x_real.pow(2) + x_imag.pow(2) + y_sq = y_real.pow(2) + y_imag.pow(2) + x_y_real = x_real * y_real - x_imag * y_imag + x_y_imag = x_real * y_imag + x_imag * y_real + + sigma1_sq = F.conv2d(x_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_sq + sigma2_sq = F.conv2d(y_sq, weight=kernel, stride=1, padding=0, groups=n_channels) - mu2_sq + sigma12_real = F.conv2d(x_y_real, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_real + sigma12_imag = F.conv2d(x_y_imag, weight=kernel, stride=1, padding=0, groups=n_channels) - mu1_mu2_imag + sigma12 = torch.stack((sigma12_imag, sigma12_real), dim=-1) + mu1_mu2 = torch.stack((mu1_mu2_real, mu1_mu2_imag), dim=-1) + # Set alpha = beta = gamma = 1. + cs_map = (sigma12 * 2 + c2 * compensation) / (sigma1_sq.unsqueeze(-1) + sigma2_sq.unsqueeze(-1) + c2 * compensation) + ssim_map = (mu1_mu2 * 2 + c1 * compensation) / (mu1_sq.unsqueeze(-1) + mu2_sq.unsqueeze(-1) + c1 * compensation) + ssim_map = ssim_map * cs_map + + ssim_val = ssim_map.mean(dim=(-2, -3)) + cs = cs_map.mean(dim=(-2, -3)) + + return ssim_val, cs diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py new file mode 100644 index 0000000000000000000000000000000000000000..797151c2540fe86b26df2b52bf734f0a93389f72 --- /dev/null +++ b/TTS/tts/utils/synthesis.py @@ -0,0 +1,343 @@ +from typing import Dict + +import numpy as np +import torch +from torch import nn + + +def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"): + if cuda: + device = "cuda" + if np_array is None: + return None + tensor = torch.as_tensor(np_array, dtype=dtype, device=device) + return tensor + + +def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): + if cuda: + device = "cuda" + style_mel = torch.FloatTensor( + ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), + device=device, + ).unsqueeze(0) + return style_mel + + +def run_model_torch( + model: nn.Module, + inputs: torch.Tensor, + speaker_id: int = None, + style_mel: torch.Tensor = None, + style_text: str = None, + d_vector: torch.Tensor = None, + language_id: torch.Tensor = None, +) -> Dict: + """Run a torch model for inference. It does not support batch inference. + + Args: + model (nn.Module): The model to run inference. + inputs (torch.Tensor): Input tensor with character ids. + speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None. + style_mel (torch.Tensor, optional): Spectrograms used for voice styling . Defaults to None. + d_vector (torch.Tensor, optional): d-vector for multi-speaker models . Defaults to None. + + Returns: + Dict: model outputs. + """ + input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) + if hasattr(model, "module"): + _func = model.module.inference + else: + _func = model.inference + outputs = _func( + inputs, + aux_input={ + "x_lengths": input_lengths, + "speaker_ids": speaker_id, + "d_vectors": d_vector, + "style_mel": style_mel, + "style_text": style_text, + "language_ids": language_id, + }, + ) + return outputs + + +def trim_silence(wav, ap): + return wav[: ap.find_endpoint(wav)] + + +def inv_spectrogram(postnet_output, ap, CONFIG): + if CONFIG.model.lower() in ["tacotron"]: + wav = ap.inv_spectrogram(postnet_output.T) + else: + wav = ap.inv_melspectrogram(postnet_output.T) + return wav + + +def id_to_torch(aux_id, cuda=False, device="cpu"): + if cuda: + device = "cuda" + if aux_id is not None: + aux_id = np.asarray(aux_id) + aux_id = torch.from_numpy(aux_id).to(device) + return aux_id + + +def embedding_to_torch(d_vector, cuda=False, device="cpu"): + if cuda: + device = "cuda" + if d_vector is not None: + d_vector = np.asarray(d_vector) + d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) + d_vector = d_vector.squeeze().unsqueeze(0).to(device) + return d_vector + + +# TODO: perform GL with pytorch for batching +def apply_griffin_lim(inputs, input_lens, CONFIG, ap): + """Apply griffin-lim to each sample iterating throught the first dimension. + Args: + inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size. + input_lens (Tensor or np.Array): 1D array of sample lengths. + CONFIG (Dict): TTS config. + ap (AudioProcessor): TTS audio processor. + """ + wavs = [] + for idx, spec in enumerate(inputs): + wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding + wav = inv_spectrogram(spec, ap, CONFIG) + # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}" + wavs.append(wav[:wav_len]) + return wavs + + +def synthesis( + model, + text, + CONFIG, + use_cuda, + speaker_id=None, + style_wav=None, + style_text=None, + use_griffin_lim=False, + do_trim_silence=False, + d_vector=None, + language_id=None, +): + """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to + the vocoder model. + + Args: + model (TTS.tts.models): + The TTS model to synthesize audio with. + + text (str): + The input text to convert to speech. + + CONFIG (Coqpit): + Model configuration. + + use_cuda (bool): + Enable/disable CUDA. + + speaker_id (int): + Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + style_wav (str | Dict[str, float]): + Path or tensor to/of a waveform used for computing the style embedding based on GST or Capacitron. + Defaults to None, meaning that Capacitron models will sample from the prior distribution to + generate random but realistic prosody. + + style_text (str): + Transcription of style_wav for Capacitron models. Defaults to None. + + enable_eos_bos_chars (bool): + enable special chars for end of sentence and start of sentence. Defaults to False. + + do_trim_silence (bool): + trim silence after synthesis. Defaults to False. + + d_vector (torch.Tensor): + d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + language_id (int): + Language ID passed to the language embedding layer in multi-langual model. Defaults to None. + """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + + # GST or Capacitron processing + # TODO: need to handle the case of setting both gst and capacitron to true somewhere + style_mel = None + if CONFIG.has("gst") and CONFIG.gst and style_wav is not None: + if isinstance(style_wav, dict): + style_mel = style_wav + else: + style_mel = compute_style_mel(style_wav, model.ap, device=device) + + if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None: + style_mel = compute_style_mel(style_wav, model.ap, device=device) + style_mel = style_mel.transpose(1, 2) # [1, time, depth] + + language_name = None + if language_id is not None: + language = [k for k, v in model.language_manager.name_to_id.items() if v == language_id] + assert len(language) == 1, "language_id must be a valid language" + language_name = language[0] + + # convert text to sequence of token IDs + text_inputs = np.asarray( + model.tokenizer.text_to_ids(text, language=language_name), + dtype=np.int32, + ) + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, device=device) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, device=device) + + if language_id is not None: + language_id = id_to_torch(language_id, device=device) + + if not isinstance(style_mel, dict): + # GST or Capacitron style mel + style_mel = numpy_to_torch(style_mel, torch.float, device=device) + if style_text is not None: + style_text = np.asarray( + model.tokenizer.text_to_ids(style_text, language=language_id), + dtype=np.int32, + ) + style_text = numpy_to_torch(style_text, torch.long, device=device) + style_text = style_text.unsqueeze(0) + + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) + text_inputs = text_inputs.unsqueeze(0) + # synthesize voice + outputs = run_model_torch( + model, + text_inputs, + speaker_id, + style_mel, + style_text, + d_vector=d_vector, + language_id=language_id, + ) + model_outputs = outputs["model_outputs"] + model_outputs = model_outputs[0].data.cpu().numpy() + alignments = outputs["alignments"] + + # convert outputs to numpy + # plot results + wav = None + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs + return_dict = { + "wav": wav, + "alignments": alignments, + "text_inputs": text_inputs, + "outputs": outputs, + } + return return_dict + + +def transfer_voice( + model, + CONFIG, + use_cuda, + reference_wav, + speaker_id=None, + d_vector=None, + reference_speaker_id=None, + reference_d_vector=None, + do_trim_silence=False, + use_griffin_lim=False, +): + """Synthesize voice for the given text using Griffin-Lim vocoder or just compute output features to be passed to + the vocoder model. + + Args: + model (TTS.tts.models): + The TTS model to synthesize audio with. + + CONFIG (Coqpit): + Model configuration. + + use_cuda (bool): + Enable/disable CUDA. + + reference_wav (str): + Path of reference_wav to be used to voice conversion. + + speaker_id (int): + Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + d_vector (torch.Tensor): + d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + reference_speaker_id (int): + Reference Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None. + + reference_d_vector (torch.Tensor): + Reference d-vector for multi-speaker models in share :math:`[1, D]`. Defaults to None. + + enable_eos_bos_chars (bool): + enable special chars for end of sentence and start of sentence. Defaults to False. + + do_trim_silence (bool): + trim silence after synthesis. Defaults to False. + """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + + # pass tensors to backend + if speaker_id is not None: + speaker_id = id_to_torch(speaker_id, device=device) + + if d_vector is not None: + d_vector = embedding_to_torch(d_vector, device=device) + + if reference_d_vector is not None: + reference_d_vector = embedding_to_torch(reference_d_vector, device=device) + + # load reference_wav audio + reference_wav = embedding_to_torch( + model.ap.load_wav( + reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate + ), + device=device, + ) + + if hasattr(model, "module"): + _func = model.module.inference_voice_conversion + else: + _func = model.inference_voice_conversion + model_outputs = _func(reference_wav, speaker_id, d_vector, reference_speaker_id, reference_d_vector) + + # convert outputs to numpy + # plot results + wav = None + model_outputs = model_outputs.squeeze() + if model_outputs.ndim == 2: # [T, C_spec] + if use_griffin_lim: + wav = inv_spectrogram(model_outputs, model.ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, model.ap) + else: # [T,] + wav = model_outputs + + return wav diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..593372dc7cb2fba240eb5f08e8e2cfae5a4b4e45 --- /dev/null +++ b/TTS/tts/utils/text/__init__.py @@ -0,0 +1 @@ +from TTS.tts.utils.text.tokenizer import TTSTokenizer diff --git a/TTS/tts/utils/text/bangla/__init__.py b/TTS/tts/utils/text/bangla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/bangla/phonemizer.py b/TTS/tts/utils/text/bangla/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cddcb00fd54db13964a48628fb7395083e9abbc5 --- /dev/null +++ b/TTS/tts/utils/text/bangla/phonemizer.py @@ -0,0 +1,124 @@ +import re + +try: + import bangla + from bnnumerizer import numerize + from bnunicodenormalizer import Normalizer +except ImportError as e: + raise ImportError("Bangla requires: bangla, bnnumerizer, bnunicodenormalizer") from e + +# initialize +bnorm = Normalizer() + + +attribution_dict = { + "সাঃ": "সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম", + "আঃ": "আলাইহিস সালাম", + "রাঃ": "রাদিআল্লাহু আনহু", + "রহঃ": "রহমাতুল্লাহি আলাইহি", + "রহিঃ": "রহিমাহুল্লাহ", + "হাফিঃ": "হাফিযাহুল্লাহ", + "বায়ান": "বাইআন", + "দাঃবাঃ": "দামাত বারাকাতুহুম,দামাত বারাকাতুল্লাহ", + # "আয়াত" : "আইআত",#আইআত + # "ওয়া" : "ওআ", + # "ওয়াসাল্লাম" : "ওআসাল্লাম", + # "কেন" : "কেনো", + # "কোন" : "কোনো", + # "বল" : "বলো", + # "চল" : "চলো", + # "কর" : "করো", + # "রাখ" : "রাখো", + "’": "", + "‘": "", + # "য়" : "অ", + # "সম্প্রদায়" : "সম্প্রদাই", + # "রয়েছে" : "রইছে", + # "রয়েছ" : "রইছ", + "/": " বাই ", +} + + +def tag_text(text: str): + # remove multiple spaces + text = re.sub(" +", " ", text) + # create start and end + text = "start" + text + "end" + # tag text + parts = re.split("[\u0600-\u06FF]+", text) + # remove non chars + parts = [p for p in parts if p.strip()] + # unique parts + parts = set(parts) + # tag the text + for m in parts: + if len(m.strip()) > 1: + text = text.replace(m, f"{m}") + # clean-tags + text = text.replace("start", "") + text = text.replace("end", "") + return text + + +def normalize(sen): + global bnorm # pylint: disable=global-statement + _words = [bnorm(word)["normalized"] for word in sen.split()] + return " ".join([word for word in _words if word is not None]) + + +def expand_full_attribution(text): + for word, attr in attribution_dict.items(): + if word in text: + text = text.replace(word, normalize(attr)) + return text + + +def collapse_whitespace(text): + # Regular expression matching whitespace: + _whitespace_re = re.compile(r"\s+") + return re.sub(_whitespace_re, " ", text) + + +def bangla_text_to_phonemes(text: str) -> str: + # english numbers to bangla conversion + res = re.search("[0-9]", text) + if res is not None: + text = bangla.convert_english_digit_to_bangla_digit(text) + + # replace ':' in between two bangla numbers with ' এর ' + pattern = r"[০, ১, ২, ৩, ৪, ৫, ৬, ৭, ৮, ৯]:[০, ১, ২, ৩, ৪, ৫, ৬, ৭, ৮, ৯]" + matches = re.findall(pattern, text) + for m in matches: + r = m.replace(":", " এর ") + text = text.replace(m, r) + + # numerize text + text = numerize(text) + + # tag sections + text = tag_text(text) + + # text blocks + # blocks = text.split("") + # blocks = [b for b in blocks if b.strip()] + + # create tuple of (lang,text) + if "" in text: + text = text.replace("", "").replace("", "") + # Split based on sentence ending Characters + bn_text = text.strip() + + sentenceEnders = re.compile("[।!?]") + sentences = sentenceEnders.split(str(bn_text)) + + data = "" + for sent in sentences: + res = re.sub("\n", "", sent) + res = normalize(res) + # expand attributes + res = expand_full_attribution(res) + + res = collapse_whitespace(res) + res += "।" + data += res + return data diff --git a/TTS/tts/utils/text/belarusian/__init__.py b/TTS/tts/utils/text/belarusian/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/belarusian/phonemizer.py b/TTS/tts/utils/text/belarusian/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1922577e5b479980a8e11ac3ae15549cfeb178db --- /dev/null +++ b/TTS/tts/utils/text/belarusian/phonemizer.py @@ -0,0 +1,37 @@ +import os + +finder = None + + +def init(): + try: + import jpype + import jpype.imports + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Belarusian phonemizer requires to install module 'jpype1' manually. Try `pip install jpype1`." + ) + + try: + jar_path = os.environ["BEL_FANETYKA_JAR"] + except KeyError: + raise KeyError("You need to define 'BEL_FANETYKA_JAR' environment variable as path to the fanetyka.jar file") + + jpype.startJVM(classpath=[jar_path]) + + # import the Java modules + from org.alex73.korpus.base import GrammarDB2, GrammarFinder + + grammar_db = GrammarDB2.initializeFromJar() + global finder + finder = GrammarFinder(grammar_db) + + +def belarusian_text_to_phonemes(text: str) -> str: + # Initialize only on first run + if finder is None: + init() + + from org.alex73.fanetyka.impl import FanetykaText + + return str(FanetykaText(finder, text).ipa) diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py new file mode 100644 index 0000000000000000000000000000000000000000..4bf9bf6bd5c2f20e87f64d0048230cbdf159c98e --- /dev/null +++ b/TTS/tts/utils/text/characters.py @@ -0,0 +1,500 @@ +import logging +from dataclasses import replace +from typing import Dict + +from TTS.tts.configs.shared_configs import CharactersConfig + +logger = logging.getLogger(__name__) + + +def parse_symbols(): + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } + + +# DEFAULT SET OF GRAPHEMES +_pad = "" +_eos = "" +_bos = "" +_blank = "" # TODO: check if we need this alongside with PAD +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" +_punctuations = "!'(),-.:;? " + + +# DEFAULT SET OF IPA PHONEMES +# Phonemes definition (All IPA characters) +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacritics = "̃ɚ˞ɫ" +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacritics + + +class BaseVocabulary: + """Base Vocabulary class. + + This class only needs a vocabulary dictionary without specifying the characters. + + Args: + vocab (Dict): A dictionary of characters and their corresponding indices. + """ + + def __init__(self, vocab: Dict, pad: str = None, blank: str = None, bos: str = None, eos: str = None): + self.vocab = vocab + self.pad = pad + self.blank = blank + self.bos = bos + self.eos = eos + + @property + def pad_id(self) -> int: + """Return the index of the padding character. If the padding character is not specified, return the length + of the vocabulary.""" + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + """Return the index of the blank character. If the blank character is not specified, return the length of + the vocabulary.""" + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def bos_id(self) -> int: + """Return the index of the bos character. If the bos character is not specified, return the length of the + vocabulary.""" + return self.char_to_id(self.bos) if self.bos else len(self.vocab) + + @property + def eos_id(self) -> int: + """Return the index of the eos character. If the eos character is not specified, return the length of the + vocabulary.""" + return self.char_to_id(self.eos) if self.eos else len(self.vocab) + + @property + def vocab(self): + """Return the vocabulary dictionary.""" + return self._vocab + + @vocab.setter + def vocab(self, vocab): + """Set the vocabulary dictionary and character mapping dictionaries.""" + self._vocab, self._char_to_id, self._id_to_char = None, None, None + if vocab is not None: + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)} + self._id_to_char = dict(enumerate(self._vocab)) + + @staticmethod + def init_from_config(config, **kwargs): + """Initialize from the given config.""" + if config.characters is not None and "vocab_dict" in config.characters and config.characters.vocab_dict: + return ( + BaseVocabulary( + config.characters.vocab_dict, + config.characters.pad, + config.characters.blank, + config.characters.bos, + config.characters.eos, + ), + config, + ) + return BaseVocabulary(**kwargs), config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + vocab_dict=self._vocab, + pad=self.pad, + eos=self.eos, + bos=self.bos, + blank=self.blank, + is_unique=False, + is_sorted=False, + ) + + @property + def num_chars(self): + """Return number of tokens in the vocabulary.""" + return len(self._vocab) + + def char_to_id(self, char: str) -> int: + """Map a character to an token ID.""" + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + """Map an token ID to a character.""" + return self._id_to_char[idx] + + +class BaseCharacters: + """🐸BaseCharacters class + + Every new character class should inherit from this. + + Characters are oredered as follows ```[PAD, EOS, BOS, BLANK, CHARACTERS, PUNCTUATIONS]```. + + If you need a custom order, you need to define inherit from this class and override the ```_create_vocab``` method. + + Args: + characters (str): + Main set of characters to be used in the vocabulary. + + punctuations (str): + Characters to be treated as punctuation. + + pad (str): + Special padding character that would be ignored by the model. + + eos (str): + End of the sentence character. + + bos (str): + Beginning of the sentence character. + + blank (str): + Optional character used between characters by some models for better prosody. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + el + is_sorted (bool): + Sort the characters in alphabetical order. Only applies to `self.characters`. Defaults to True. + """ + + def __init__( + self, + characters: str = None, + punctuations: str = None, + pad: str = None, + eos: str = None, + bos: str = None, + blank: str = None, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + self._characters = characters + self._punctuations = punctuations + self._pad = pad + self._eos = eos + self._bos = bos + self._blank = blank + self.is_unique = is_unique + self.is_sorted = is_sorted + self._create_vocab() + + @property + def pad_id(self) -> int: + return self.char_to_id(self.pad) if self.pad else len(self.vocab) + + @property + def blank_id(self) -> int: + return self.char_to_id(self.blank) if self.blank else len(self.vocab) + + @property + def eos_id(self) -> int: + return self.char_to_id(self.eos) if self.eos else len(self.vocab) + + @property + def bos_id(self) -> int: + return self.char_to_id(self.bos) if self.bos else len(self.vocab) + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, characters): + self._characters = characters + self._create_vocab() + + @property + def punctuations(self): + return self._punctuations + + @punctuations.setter + def punctuations(self, punctuations): + self._punctuations = punctuations + self._create_vocab() + + @property + def pad(self): + return self._pad + + @pad.setter + def pad(self, pad): + self._pad = pad + self._create_vocab() + + @property + def eos(self): + return self._eos + + @eos.setter + def eos(self, eos): + self._eos = eos + self._create_vocab() + + @property + def bos(self): + return self._bos + + @bos.setter + def bos(self, bos): + self._bos = bos + self._create_vocab() + + @property + def blank(self): + return self._blank + + @blank.setter + def blank(self, blank): + self._blank = blank + self._create_vocab() + + @property + def vocab(self): + return self._vocab + + @vocab.setter + def vocab(self, vocab): + self._vocab = vocab + self._char_to_id = {char: idx for idx, char in enumerate(self.vocab)} + self._id_to_char = dict(enumerate(self.vocab)) + + @property + def num_chars(self): + return len(self._vocab) + + def _create_vocab(self): + _vocab = self._characters + if self.is_unique: + _vocab = list(set(_vocab)) + if self.is_sorted: + _vocab = sorted(_vocab) + _vocab = list(_vocab) + _vocab = [self._blank] + _vocab if self._blank is not None and len(self._blank) > 0 else _vocab + _vocab = [self._bos] + _vocab if self._bos is not None and len(self._bos) > 0 else _vocab + _vocab = [self._eos] + _vocab if self._eos is not None and len(self._eos) > 0 else _vocab + _vocab = [self._pad] + _vocab if self._pad is not None and len(self._pad) > 0 else _vocab + self.vocab = _vocab + list(self._punctuations) + if self.is_unique: + duplicates = {x for x in self.vocab if self.vocab.count(x) > 1} + assert ( + len(self.vocab) == len(self._char_to_id) == len(self._id_to_char) + ), f" [!] There are duplicate characters in the character set. {duplicates}" + + def char_to_id(self, char: str) -> int: + try: + return self._char_to_id[char] + except KeyError as e: + raise KeyError(f" [!] {repr(char)} is not in the vocabulary.") from e + + def id_to_char(self, idx: int) -> str: + return self._id_to_char[idx] + + def print_log(self, level: int = 0): + """ + Prints the vocabulary in a nice format. + """ + indent = "\t" * level + logger.info("%s| Characters: %s", indent, self._characters) + logger.info("%s| Punctuations: %s", indent, self._punctuations) + logger.info("%s| Pad: %s", indent, self._pad) + logger.info("%s| EOS: %s", indent, self._eos) + logger.info("%s| BOS: %s", indent, self._bos) + logger.info("%s| Blank: %s", indent, self._blank) + logger.info("%s| Vocab: %s", indent, self.vocab) + logger.info("%s| Num chars: %d", indent, self.num_chars) + + @staticmethod + def init_from_config(config: "Coqpit"): # pylint: disable=unused-argument + """Init your character class from a config. + + Implement this method for your subclass. + """ + # use character set from config + if config.characters is not None: + return BaseCharacters(**config.characters), config + # return default character set + characters = BaseCharacters() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + def to_config(self) -> "CharactersConfig": + return CharactersConfig( + characters=self._characters, + punctuations=self._punctuations, + pad=self._pad, + eos=self._eos, + bos=self._bos, + blank=self._blank, + is_unique=self.is_unique, + is_sorted=self.is_sorted, + ) + + +class IPAPhonemes(BaseCharacters): + """🐸IPAPhonemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using IPAPhonemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_phonemes`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + blank (str): + Optional character used between characters by some models for better prosody. Defaults to `_blank`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _phonemes, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a IPAPhonemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. + """ + # band-aid for compatibility with old models + if "characters" in config and config.characters is not None: + if "phonemes" in config.characters and config.characters.phonemes is not None: + config.characters["characters"] = config.characters["phonemes"] + return ( + IPAPhonemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + # use character set from config + if config.characters is not None: + return IPAPhonemes(**config.characters), config + # return default character set + characters = IPAPhonemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +class Graphemes(BaseCharacters): + """🐸Graphemes class to manage `TTS.tts` model vocabulary + + Intended to be used with models using graphemes as input. + It uses system defaults for the undefined class arguments. + + Args: + characters (str): + Main set of case-sensitive characters to be used in the vocabulary. Defaults to `_characters`. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `_punctuations`. + + pad (str): + Special padding character that would be ignored by the model. Defaults to `_pad`. + + eos (str): + End of the sentence character. Defaults to `_eos`. + + bos (str): + Beginning of the sentence character. Defaults to `_bos`. + + is_unique (bool): + Remove duplicates from the provided characters. Defaults to True. + + is_sorted (bool): + Sort the characters in alphabetical order. Defaults to True. + """ + + def __init__( + self, + characters: str = _characters, + punctuations: str = _punctuations, + pad: str = _pad, + eos: str = _eos, + bos: str = _bos, + blank: str = _blank, + is_unique: bool = False, + is_sorted: bool = True, + ) -> None: + super().__init__(characters, punctuations, pad, eos, bos, blank, is_unique, is_sorted) + + @staticmethod + def init_from_config(config: "Coqpit"): + """Init a Graphemes object from a model config + + If characters are not defined in the config, it will be set to the default characters and the config + will be updated. + """ + if config.characters is not None: + # band-aid for compatibility with old models + if "phonemes" in config.characters: + return ( + Graphemes( + characters=config.characters["characters"], + punctuations=config.characters["punctuations"], + pad=config.characters["pad"], + eos=config.characters["eos"], + bos=config.characters["bos"], + blank=config.characters["blank"], + is_unique=config.characters["is_unique"], + is_sorted=config.characters["is_sorted"], + ), + config, + ) + return Graphemes(**config.characters), config + characters = Graphemes() + new_config = replace(config, characters=characters.to_config()) + return characters, new_config + + +if __name__ == "__main__": + gr = Graphemes() + ph = IPAPhonemes() + gr.print_log() + ph.print_log() diff --git a/TTS/tts/utils/text/chinese_mandarin/__init__.py b/TTS/tts/utils/text/chinese_mandarin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/chinese_mandarin/numbers.py b/TTS/tts/utils/text/chinese_mandarin/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..4787ea61007656819eb57d52d5865b38c7afa915 --- /dev/null +++ b/TTS/tts/utils/text/chinese_mandarin/numbers.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Licensed under WTFPL or the Unlicense or CC0. +# This uses Python 3, but it's easy to port to Python 2 by changing +# strings to u'xx'. + +import itertools +import re + + +def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: + """Convert numerical arabic numbers (0->9) to chinese hanzi numbers (〇 -> 九) + + Args: + num (str): arabic number to convert + big (bool, optional): use financial characters. Defaults to False. + simp (bool, optional): use simplified characters instead of tradictional characters. Defaults to True. + o (bool, optional): use 〇 for 'zero'. Defaults to False. + twoalt (bool, optional): use 两/兩 for 'two' when appropriate. Defaults to False. + + Raises: + ValueError: if number is more than 1e48 + ValueError: if 'e' exposent in number + + Returns: + str: converted number as hanzi characters + """ + + # check num first + nd = str(num) + if abs(float(nd)) >= 1e48: + raise ValueError("number out of range") + if "e" in nd: + raise ValueError("scientific notation is not supported") + c_symbol = "正负点" if simp else "正負點" + if o: # formal + twoalt = False + if big: + c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖" + c_unit1 = "拾佰仟" + c_twoalt = "贰" if simp else "貳" + else: + c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九" + c_unit1 = "十百千" + if twoalt: + c_twoalt = "两" if simp else "兩" + else: + c_twoalt = "二" + c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載" + revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l))) + nd = str(num) + result = [] + if nd[0] == "+": + result.append(c_symbol[0]) + elif nd[0] == "-": + result.append(c_symbol[1]) + if "." in nd: + integer, remainder = nd.lstrip("+-").split(".") + else: + integer, remainder = nd.lstrip("+-"), None + if int(integer): + splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)] + intresult = [] + for nu, unit in enumerate(splitted): + # special cases + if int(unit) == 0: # 0000 + intresult.append(c_basic[0]) + continue + if nu > 0 and int(unit) == 2: # 0002 + intresult.append(c_twoalt + c_unit2[nu - 1]) + continue + ulist = [] + unit = unit.zfill(4) + for nc, ch in enumerate(reversed(unit)): + if ch == "0": + if ulist: # ???0 + ulist.append(c_basic[0]) + elif nc == 0: + ulist.append(c_basic[int(ch)]) + elif nc == 1 and ch == "1" and unit[1] == "0": + # special case for tens + # edit the 'elif' if you don't like + # 十四, 三千零十四, 三千三百一十四 + ulist.append(c_unit1[0]) + elif nc > 1 and ch == "2": + ulist.append(c_twoalt + c_unit1[nc - 1]) + else: + ulist.append(c_basic[int(ch)] + c_unit1[nc - 1]) + ustr = revuniq(ulist) + if nu == 0: + intresult.append(ustr) + else: + intresult.append(ustr + c_unit2[nu - 1]) + result.append(revuniq(intresult).strip(c_basic[0])) + else: + result.append(c_basic[0]) + if remainder: + result.append(c_symbol[2]) + result.append("".join(c_basic[int(ch)] for ch in remainder)) + return "".join(result) + + +def _number_replace(match) -> str: + """function to apply in a match, transform all numbers in a match by chinese characters + + Args: + match (re.Match): numbers regex matches + + Returns: + str: replaced characters for the numbers + """ + match_str: str = match.group() + return _num2chinese(match_str) + + +def replace_numbers_to_characters_in_text(text: str) -> str: + """Replace all arabic numbers in a text by their equivalent in chinese characters (simplified) + + Args: + text (str): input text to transform + + Returns: + str: output text + """ + text = re.sub(r"[0-9]+", _number_replace, text) + return text diff --git a/TTS/tts/utils/text/chinese_mandarin/phonemizer.py b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d62e9d06638f55cd13734ef3e698d9b8554438 --- /dev/null +++ b/TTS/tts/utils/text/chinese_mandarin/phonemizer.py @@ -0,0 +1,40 @@ +from typing import List + +try: + import jieba + import pypinyin +except ImportError as e: + raise ImportError("Chinese requires: jieba, pypinyin") from e + +from .pinyinToPhonemes import PINYIN_DICT + + +def _chinese_character_to_pinyin(text: str) -> List[str]: + pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True) + pinyins_flat_list = [item for sublist in pinyins for item in sublist] + return pinyins_flat_list + + +def _chinese_pinyin_to_phoneme(pinyin: str) -> str: + segment = pinyin[:-1] + tone = pinyin[-1] + phoneme = PINYIN_DICT.get(segment, [""])[0] + return phoneme + tone + + +def chinese_text_to_phonemes(text: str, seperator: str = "|") -> str: + tokenized_text = jieba.cut(text, HMM=False) + tokenized_text = " ".join(tokenized_text) + pinyined_text: List[str] = _chinese_character_to_pinyin(tokenized_text) + + results: List[str] = [] + + for token in pinyined_text: + if token[-1] in "12345": # TODO transform to is_pinyin() + pinyin_phonemes = _chinese_pinyin_to_phoneme(token) + + results += list(pinyin_phonemes) + else: # is ponctuation or other + results += list(token) + + return seperator.join(results) diff --git a/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py b/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py new file mode 100644 index 0000000000000000000000000000000000000000..89dd654ab16c5817cc28efd23b8a21466f814395 --- /dev/null +++ b/TTS/tts/utils/text/chinese_mandarin/pinyinToPhonemes.py @@ -0,0 +1,419 @@ +PINYIN_DICT = { + "a": ["a"], + "ai": ["ai"], + "an": ["an"], + "ang": ["ɑŋ"], + "ao": ["aʌ"], + "ba": ["ba"], + "bai": ["bai"], + "ban": ["ban"], + "bang": ["bɑŋ"], + "bao": ["baʌ"], + # "be": ["be"], doesnt exist + "bei": ["bɛi"], + "ben": ["bœn"], + "beng": ["bɵŋ"], + "bi": ["bi"], + "bian": ["biɛn"], + "biao": ["biaʌ"], + "bie": ["bie"], + "bin": ["bin"], + "bing": ["bɨŋ"], + "bo": ["bo"], + "bu": ["bu"], + "ca": ["tsa"], + "cai": ["tsai"], + "can": ["tsan"], + "cang": ["tsɑŋ"], + "cao": ["tsaʌ"], + "ce": ["tsø"], + "cen": ["tsœn"], + "ceng": ["tsɵŋ"], + "cha": ["ʈʂa"], + "chai": ["ʈʂai"], + "chan": ["ʈʂan"], + "chang": ["ʈʂɑŋ"], + "chao": ["ʈʂaʌ"], + "che": ["ʈʂø"], + "chen": ["ʈʂœn"], + "cheng": ["ʈʂɵŋ"], + "chi": ["ʈʂʏ"], + "chong": ["ʈʂoŋ"], + "chou": ["ʈʂou"], + "chu": ["ʈʂu"], + "chua": ["ʈʂua"], + "chuai": ["ʈʂuai"], + "chuan": ["ʈʂuan"], + "chuang": ["ʈʂuɑŋ"], + "chui": ["ʈʂuei"], + "chun": ["ʈʂun"], + "chuo": ["ʈʂuo"], + "ci": ["tsɪ"], + "cong": ["tsoŋ"], + "cou": ["tsou"], + "cu": ["tsu"], + "cuan": ["tsuan"], + "cui": ["tsuei"], + "cun": ["tsun"], + "cuo": ["tsuo"], + "da": ["da"], + "dai": ["dai"], + "dan": ["dan"], + "dang": ["dɑŋ"], + "dao": ["daʌ"], + "de": ["dø"], + "dei": ["dei"], + # "den": ["dœn"], + "deng": ["dɵŋ"], + "di": ["di"], + "dia": ["dia"], + "dian": ["diɛn"], + "diao": ["diaʌ"], + "die": ["die"], + "ding": ["dɨŋ"], + "diu": ["dio"], + "dong": ["doŋ"], + "dou": ["dou"], + "du": ["du"], + "duan": ["duan"], + "dui": ["duei"], + "dun": ["dun"], + "duo": ["duo"], + "e": ["ø"], + "ei": ["ei"], + "en": ["œn"], + # "ng": ["œn"], + # "eng": ["ɵŋ"], + "er": ["er"], + "fa": ["fa"], + "fan": ["fan"], + "fang": ["fɑŋ"], + "fei": ["fei"], + "fen": ["fœn"], + "feng": ["fɵŋ"], + "fo": ["fo"], + "fou": ["fou"], + "fu": ["fu"], + "ga": ["ɡa"], + "gai": ["ɡai"], + "gan": ["ɡan"], + "gang": ["ɡɑŋ"], + "gao": ["ɡaʌ"], + "ge": ["ɡø"], + "gei": ["ɡei"], + "gen": ["ɡœn"], + "geng": ["ɡɵŋ"], + "gong": ["ɡoŋ"], + "gou": ["ɡou"], + "gu": ["ɡu"], + "gua": ["ɡua"], + "guai": ["ɡuai"], + "guan": ["ɡuan"], + "guang": ["ɡuɑŋ"], + "gui": ["ɡuei"], + "gun": ["ɡun"], + "guo": ["ɡuo"], + "ha": ["xa"], + "hai": ["xai"], + "han": ["xan"], + "hang": ["xɑŋ"], + "hao": ["xaʌ"], + "he": ["xø"], + "hei": ["xei"], + "hen": ["xœn"], + "heng": ["xɵŋ"], + "hong": ["xoŋ"], + "hou": ["xou"], + "hu": ["xu"], + "hua": ["xua"], + "huai": ["xuai"], + "huan": ["xuan"], + "huang": ["xuɑŋ"], + "hui": ["xuei"], + "hun": ["xun"], + "huo": ["xuo"], + "ji": ["dʑi"], + "jia": ["dʑia"], + "jian": ["dʑiɛn"], + "jiang": ["dʑiɑŋ"], + "jiao": ["dʑiaʌ"], + "jie": ["dʑie"], + "jin": ["dʑin"], + "jing": ["dʑɨŋ"], + "jiong": ["dʑioŋ"], + "jiu": ["dʑio"], + "ju": ["dʑy"], + "juan": ["dʑyɛn"], + "jue": ["dʑye"], + "jun": ["dʑyn"], + "ka": ["ka"], + "kai": ["kai"], + "kan": ["kan"], + "kang": ["kɑŋ"], + "kao": ["kaʌ"], + "ke": ["kø"], + "kei": ["kei"], + "ken": ["kœn"], + "keng": ["kɵŋ"], + "kong": ["koŋ"], + "kou": ["kou"], + "ku": ["ku"], + "kua": ["kua"], + "kuai": ["kuai"], + "kuan": ["kuan"], + "kuang": ["kuɑŋ"], + "kui": ["kuei"], + "kun": ["kun"], + "kuo": ["kuo"], + "la": ["la"], + "lai": ["lai"], + "lan": ["lan"], + "lang": ["lɑŋ"], + "lao": ["laʌ"], + "le": ["lø"], + "lei": ["lei"], + "leng": ["lɵŋ"], + "li": ["li"], + "lia": ["lia"], + "lian": ["liɛn"], + "liang": ["liɑŋ"], + "liao": ["liaʌ"], + "lie": ["lie"], + "lin": ["lin"], + "ling": ["lɨŋ"], + "liu": ["lio"], + "lo": ["lo"], + "long": ["loŋ"], + "lou": ["lou"], + "lu": ["lu"], + "lv": ["ly"], + "luan": ["luan"], + "lve": ["lye"], + "lue": ["lue"], + "lun": ["lun"], + "luo": ["luo"], + "ma": ["ma"], + "mai": ["mai"], + "man": ["man"], + "mang": ["mɑŋ"], + "mao": ["maʌ"], + "me": ["mø"], + "mei": ["mei"], + "men": ["mœn"], + "meng": ["mɵŋ"], + "mi": ["mi"], + "mian": ["miɛn"], + "miao": ["miaʌ"], + "mie": ["mie"], + "min": ["min"], + "ming": ["mɨŋ"], + "miu": ["mio"], + "mo": ["mo"], + "mou": ["mou"], + "mu": ["mu"], + "na": ["na"], + "nai": ["nai"], + "nan": ["nan"], + "nang": ["nɑŋ"], + "nao": ["naʌ"], + "ne": ["nø"], + "nei": ["nei"], + "nen": ["nœn"], + "neng": ["nɵŋ"], + "ni": ["ni"], + "nia": ["nia"], + "nian": ["niɛn"], + "niang": ["niɑŋ"], + "niao": ["niaʌ"], + "nie": ["nie"], + "nin": ["nin"], + "ning": ["nɨŋ"], + "niu": ["nio"], + "nong": ["noŋ"], + "nou": ["nou"], + "nu": ["nu"], + "nv": ["ny"], + "nuan": ["nuan"], + "nve": ["nye"], + "nue": ["nye"], + "nuo": ["nuo"], + "o": ["o"], + "ou": ["ou"], + "pa": ["pa"], + "pai": ["pai"], + "pan": ["pan"], + "pang": ["pɑŋ"], + "pao": ["paʌ"], + "pe": ["pø"], + "pei": ["pei"], + "pen": ["pœn"], + "peng": ["pɵŋ"], + "pi": ["pi"], + "pian": ["piɛn"], + "piao": ["piaʌ"], + "pie": ["pie"], + "pin": ["pin"], + "ping": ["pɨŋ"], + "po": ["po"], + "pou": ["pou"], + "pu": ["pu"], + "qi": ["tɕi"], + "qia": ["tɕia"], + "qian": ["tɕiɛn"], + "qiang": ["tɕiɑŋ"], + "qiao": ["tɕiaʌ"], + "qie": ["tɕie"], + "qin": ["tɕin"], + "qing": ["tɕɨŋ"], + "qiong": ["tɕioŋ"], + "qiu": ["tɕio"], + "qu": ["tɕy"], + "quan": ["tɕyɛn"], + "que": ["tɕye"], + "qun": ["tɕyn"], + "ran": ["ʐan"], + "rang": ["ʐɑŋ"], + "rao": ["ʐaʌ"], + "re": ["ʐø"], + "ren": ["ʐœn"], + "reng": ["ʐɵŋ"], + "ri": ["ʐʏ"], + "rong": ["ʐoŋ"], + "rou": ["ʐou"], + "ru": ["ʐu"], + "rua": ["ʐua"], + "ruan": ["ʐuan"], + "rui": ["ʐuei"], + "run": ["ʐun"], + "ruo": ["ʐuo"], + "sa": ["sa"], + "sai": ["sai"], + "san": ["san"], + "sang": ["sɑŋ"], + "sao": ["saʌ"], + "se": ["sø"], + "sen": ["sœn"], + "seng": ["sɵŋ"], + "sha": ["ʂa"], + "shai": ["ʂai"], + "shan": ["ʂan"], + "shang": ["ʂɑŋ"], + "shao": ["ʂaʌ"], + "she": ["ʂø"], + "shei": ["ʂei"], + "shen": ["ʂœn"], + "sheng": ["ʂɵŋ"], + "shi": ["ʂʏ"], + "shou": ["ʂou"], + "shu": ["ʂu"], + "shua": ["ʂua"], + "shuai": ["ʂuai"], + "shuan": ["ʂuan"], + "shuang": ["ʂuɑŋ"], + "shui": ["ʂuei"], + "shun": ["ʂun"], + "shuo": ["ʂuo"], + "si": ["sɪ"], + "song": ["soŋ"], + "sou": ["sou"], + "su": ["su"], + "suan": ["suan"], + "sui": ["suei"], + "sun": ["sun"], + "suo": ["suo"], + "ta": ["ta"], + "tai": ["tai"], + "tan": ["tan"], + "tang": ["tɑŋ"], + "tao": ["taʌ"], + "te": ["tø"], + "tei": ["tei"], + "teng": ["tɵŋ"], + "ti": ["ti"], + "tian": ["tiɛn"], + "tiao": ["tiaʌ"], + "tie": ["tie"], + "ting": ["tɨŋ"], + "tong": ["toŋ"], + "tou": ["tou"], + "tu": ["tu"], + "tuan": ["tuan"], + "tui": ["tuei"], + "tun": ["tun"], + "tuo": ["tuo"], + "wa": ["wa"], + "wai": ["wai"], + "wan": ["wan"], + "wang": ["wɑŋ"], + "wei": ["wei"], + "wen": ["wœn"], + "weng": ["wɵŋ"], + "wo": ["wo"], + "wu": ["wu"], + "xi": ["ɕi"], + "xia": ["ɕia"], + "xian": ["ɕiɛn"], + "xiang": ["ɕiɑŋ"], + "xiao": ["ɕiaʌ"], + "xie": ["ɕie"], + "xin": ["ɕin"], + "xing": ["ɕɨŋ"], + "xiong": ["ɕioŋ"], + "xiu": ["ɕio"], + "xu": ["ɕy"], + "xuan": ["ɕyɛn"], + "xue": ["ɕye"], + "xun": ["ɕyn"], + "ya": ["ia"], + "yan": ["iɛn"], + "yang": ["iɑŋ"], + "yao": ["iaʌ"], + "ye": ["ie"], + "yi": ["i"], + "yin": ["in"], + "ying": ["ɨŋ"], + "yo": ["io"], + "yong": ["ioŋ"], + "you": ["io"], + "yu": ["y"], + "yuan": ["yɛn"], + "yue": ["ye"], + "yun": ["yn"], + "za": ["dza"], + "zai": ["dzai"], + "zan": ["dzan"], + "zang": ["dzɑŋ"], + "zao": ["dzaʌ"], + "ze": ["dzø"], + "zei": ["dzei"], + "zen": ["dzœn"], + "zeng": ["dzɵŋ"], + "zha": ["dʒa"], + "zhai": ["dʒai"], + "zhan": ["dʒan"], + "zhang": ["dʒɑŋ"], + "zhao": ["dʒaʌ"], + "zhe": ["dʒø"], + # "zhei": ["dʒei"], it doesn't exist + "zhen": ["dʒœn"], + "zheng": ["dʒɵŋ"], + "zhi": ["dʒʏ"], + "zhong": ["dʒoŋ"], + "zhou": ["dʒou"], + "zhu": ["dʒu"], + "zhua": ["dʒua"], + "zhuai": ["dʒuai"], + "zhuan": ["dʒuan"], + "zhuang": ["dʒuɑŋ"], + "zhui": ["dʒuei"], + "zhun": ["dʒun"], + "zhuo": ["dʒuo"], + "zi": ["dzɪ"], + "zong": ["dzoŋ"], + "zou": ["dzou"], + "zu": ["dzu"], + "zuan": ["dzuan"], + "zui": ["dzuei"], + "zun": ["dzun"], + "zuo": ["dzuo"], +} diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..f496b9f0dd50e5bb466cf56310e90eac6952249e --- /dev/null +++ b/TTS/tts/utils/text/cleaners.py @@ -0,0 +1,204 @@ +"""Set of default text cleaners""" + +import re +from typing import Optional +from unicodedata import normalize + +from anyascii import anyascii + +from TTS.tts.utils.text.chinese_mandarin.numbers import replace_numbers_to_characters_in_text + +from .english.abbreviations import abbreviations_en +from .english.number_norm import normalize_numbers as en_normalize_numbers +from .english.time_norm import expand_time_english +from .french.abbreviations import abbreviations_fr + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r"\s+") + + +def expand_abbreviations(text: str, lang: str = "en") -> str: + if lang == "en": + _abbreviations = abbreviations_en + elif lang == "fr": + _abbreviations = abbreviations_fr + else: + msg = f"Language {lang} not supported in expand_abbreviations" + raise ValueError(msg) + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def lowercase(text: str) -> str: + return text.lower() + + +def collapse_whitespace(text: str) -> str: + return re.sub(_whitespace_re, " ", text).strip() + + +def convert_to_ascii(text: str) -> str: + return anyascii(text) + + +def remove_aux_symbols(text: str) -> str: + text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) + return text + + +def replace_symbols(text: str, lang: Optional[str] = "en") -> str: + """Replace symbols based on the language tag. + + Args: + text: + Input text. + lang: + Lenguage identifier. ex: "en", "fr", "pt", "ca". + + Returns: + The modified text + example: + input args: + text: "si l'avi cau, diguem-ho" + lang: "ca" + Output: + text: "si lavi cau, diguemho" + """ + text = text.replace(";", ",") + text = text.replace("-", " ") if lang != "ca" else text.replace("-", "") + text = text.replace(":", ",") + if lang == "en": + text = text.replace("&", " and ") + elif lang == "fr": + text = text.replace("&", " et ") + elif lang == "pt": + text = text.replace("&", " e ") + elif lang == "ca": + text = text.replace("&", " i ") + text = text.replace("'", "") + return text + + +def basic_cleaners(text: str) -> str: + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" + text = normalize_unicode(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text: str) -> str: + """Pipeline for non-English text that transliterates to ASCII.""" + text = normalize_unicode(text) + # text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def basic_german_cleaners(text: str) -> str: + """Pipeline for German text""" + text = normalize_unicode(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +# TODO: elaborate it +def basic_turkish_cleaners(text: str) -> str: + """Pipeline for Turkish text""" + text = normalize_unicode(text) + text = text.replace("I", "ı") + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text: str) -> str: + """Pipeline for English text, including number and abbreviation expansion.""" + text = normalize_unicode(text) + # text = convert_to_ascii(text) + text = lowercase(text) + text = expand_time_english(text) + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def phoneme_cleaners(text: str) -> str: + """Pipeline for phonemes mode, including number and abbreviation expansion. + + NB: This cleaner converts numbers into English words, for other languages + use multilingual_phoneme_cleaners(). + """ + text = normalize_unicode(text) + text = en_normalize_numbers(text) + text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def multilingual_phoneme_cleaners(text: str) -> str: + """Pipeline for phonemes mode, including number and abbreviation expansion.""" + text = normalize_unicode(text) + text = replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def french_cleaners(text: str) -> str: + """Pipeline for French text. There is no need to expand numbers, phonemizer already does that""" + text = normalize_unicode(text) + text = expand_abbreviations(text, lang="fr") + text = lowercase(text) + text = replace_symbols(text, lang="fr") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def portuguese_cleaners(text: str) -> str: + """Basic pipeline for Portuguese text. There is no need to expand abbreviation and + numbers, phonemizer already does that""" + text = normalize_unicode(text) + text = lowercase(text) + text = replace_symbols(text, lang="pt") + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def chinese_mandarin_cleaners(text: str) -> str: + """Basic pipeline for chinese""" + text = normalize_unicode(text) + text = replace_numbers_to_characters_in_text(text) + return text + + +def multilingual_cleaners(text: str) -> str: + """Pipeline for multilingual text""" + text = normalize_unicode(text) + text = lowercase(text) + text = replace_symbols(text, lang=None) + text = remove_aux_symbols(text) + text = collapse_whitespace(text) + return text + + +def no_cleaners(text: str) -> str: + # remove newline characters + text = text.replace("\n", "") + return text + + +def normalize_unicode(text: str) -> str: + """Normalize Unicode characters.""" + text = normalize("NFC", text) + return text diff --git a/TTS/tts/utils/text/cmudict.py b/TTS/tts/utils/text/cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..f206fb043be1d478fa6ace36fefdefa30b0acb02 --- /dev/null +++ b/TTS/tts/utils/text/cmudict.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +import re + +VALID_SYMBOLS = [ + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", +] + + +class CMUDict: + """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" + + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + """Returns list of ARPAbet pronunciations of the given word.""" + return self._entries.get(word.upper()) + + @staticmethod + def get_arpabet(word, cmudict, punctuation_symbols): + first_symbol, last_symbol = "", "" + if word and word[0] in punctuation_symbols: + first_symbol = word[0] + word = word[1:] + if word and word[-1] in punctuation_symbols: + last_symbol = word[-1] + word = word[:-1] + arpabet = cmudict.lookup(word) + if arpabet is not None: + return first_symbol + "{%s}" % arpabet[0] + last_symbol + return first_symbol + word + last_symbol + + +_alt_re = re.compile(r"\([0-9]+\)") + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(" ") + for part in parts: + if part not in VALID_SYMBOLS: + return None + return " ".join(parts) diff --git a/TTS/tts/utils/text/english/__init__.py b/TTS/tts/utils/text/english/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/english/abbreviations.py b/TTS/tts/utils/text/english/abbreviations.py new file mode 100644 index 0000000000000000000000000000000000000000..cd93c13c8ecfbc0df2d0c6d2fa348388940c213a --- /dev/null +++ b/TTS/tts/utils/text/english/abbreviations.py @@ -0,0 +1,26 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in english: +abbreviations_en = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] diff --git a/TTS/tts/utils/text/english/number_norm.py b/TTS/tts/utils/text/english/number_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..e8377ede87ebc9d1bb9cffbbb290aa7787caea4f --- /dev/null +++ b/TTS/tts/utils/text/english/number_norm.py @@ -0,0 +1,97 @@ +""" from https://github.com/keithito/tacotron """ + +import re +from typing import Dict + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"-?[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def __expand_currency(value: str, inflection: Dict[float, str]) -> str: + parts = value.replace(",", "").split(".") + if len(parts) > 2: + return f"{value} {inflection[2]}" # Unexpected format + text = [] + integer = int(parts[0]) if parts[0] else 0 + if integer > 0: + integer_unit = inflection.get(integer, inflection[2]) + text.append(f"{integer} {integer_unit}") + fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if fraction > 0: + fraction_unit = inflection.get(fraction / 100, inflection[0.02]) + text.append(f"{fraction} {fraction_unit}") + if len(text) == 0: + return f"zero {inflection[2]}" + return " ".join(text) + + +def _expand_currency(m: "re.Match") -> str: + currencies = { + "$": { + 0.01: "cent", + 0.02: "cents", + 1: "dollar", + 2: "dollars", + }, + "€": { + 0.01: "cent", + 0.02: "cents", + 1: "euro", + 2: "euros", + }, + "£": { + 0.01: "penny", + 0.02: "pence", + 1: "pound sterling", + 2: "pounds sterling", + }, + "¥": { + # TODO rin + 0.02: "sen", + 2: "yen", + }, + } + unit = m.group(1) + currency = currencies[unit] + value = m.group(2) + return __expand_currency(value, currency) + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if 1000 < num < 3000: + if num == 2000: + return "two thousand" + if 2000 < num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + if num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_currency_re, _expand_currency, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/TTS/tts/utils/text/english/time_norm.py b/TTS/tts/utils/text/english/time_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..c8ac09e79db4a239a7f72f101503dbf0d6feb3ae --- /dev/null +++ b/TTS/tts/utils/text/english/time_norm.py @@ -0,0 +1,47 @@ +import re + +import inflect + +_inflect = inflect.engine() + +_time_re = re.compile( + r"""\b + ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours + : + ([0-5][0-9]) # minutes + \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm + \b""", + re.IGNORECASE | re.X, +) + + +def _expand_num(n: int) -> str: + return _inflect.number_to_words(n) + + +def _expand_time_english(match: "re.Match") -> str: + hour = int(match.group(1)) + past_noon = hour >= 12 + time = [] + if hour > 12: + hour -= 12 + elif hour == 0: + hour = 12 + past_noon = True + time.append(_expand_num(hour)) + + minute = int(match.group(6)) + if minute > 0: + if minute < 10: + time.append("oh") + time.append(_expand_num(minute)) + am_pm = match.group(7) + if am_pm is None: + time.append("p m" if past_noon else "a m") + else: + time.extend(list(am_pm.replace(".", ""))) + return " ".join(time) + + +def expand_time_english(text: str) -> str: + return re.sub(_time_re, _expand_time_english, text) diff --git a/TTS/tts/utils/text/french/__init__.py b/TTS/tts/utils/text/french/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/french/abbreviations.py b/TTS/tts/utils/text/french/abbreviations.py new file mode 100644 index 0000000000000000000000000000000000000000..f580dfed7b4576a9f87b0a4145cb729e70050d50 --- /dev/null +++ b/TTS/tts/utils/text/french/abbreviations.py @@ -0,0 +1,48 @@ +import re + +# List of (regular expression, replacement) pairs for abbreviations in french: +abbreviations_fr = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] diff --git a/TTS/tts/utils/text/japanese/__init__.py b/TTS/tts/utils/text/japanese/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/japanese/phonemizer.py b/TTS/tts/utils/text/japanese/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..30072ae5010df59c750f8f7f6047e85a053de5cf --- /dev/null +++ b/TTS/tts/utils/text/japanese/phonemizer.py @@ -0,0 +1,470 @@ +# Convert Japanese text to phonemes which is +# compatible with Julius https://github.com/julius-speech/segmentation-kit + +import re +import unicodedata + +try: + import MeCab +except ImportError as e: + raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e +from num2words import num2words + +_CONVRULES = [ + # Conversion of 2 letters + "アァ/ a a", + "イィ/ i i", + "イェ/ i e", + "イャ/ y a", + "ウゥ/ u:", + "エェ/ e e", + "オォ/ o:", + "カァ/ k a:", + "キィ/ k i:", + "クゥ/ k u:", + "クャ/ ky a", + "クュ/ ky u", + "クョ/ ky o", + "ケェ/ k e:", + "コォ/ k o:", + "ガァ/ g a:", + "ギィ/ g i:", + "グゥ/ g u:", + "グャ/ gy a", + "グュ/ gy u", + "グョ/ gy o", + "ゲェ/ g e:", + "ゴォ/ g o:", + "サァ/ s a:", + "シィ/ sh i:", + "スゥ/ s u:", + "スャ/ sh a", + "スュ/ sh u", + "スョ/ sh o", + "セェ/ s e:", + "ソォ/ s o:", + "ザァ/ z a:", + "ジィ/ j i:", + "ズゥ/ z u:", + "ズャ/ zy a", + "ズュ/ zy u", + "ズョ/ zy o", + "ゼェ/ z e:", + "ゾォ/ z o:", + "タァ/ t a:", + "チィ/ ch i:", + "ツァ/ ts a", + "ツィ/ ts i", + "ツゥ/ ts u:", + "ツャ/ ch a", + "ツュ/ ch u", + "ツョ/ ch o", + "ツェ/ ts e", + "ツォ/ ts o", + "テェ/ t e:", + "トォ/ t o:", + "ダァ/ d a:", + "ヂィ/ j i:", + "ヅゥ/ d u:", + "ヅャ/ zy a", + "ヅュ/ zy u", + "ヅョ/ zy o", + "デェ/ d e:", + "ドォ/ d o:", + "ナァ/ n a:", + "ニィ/ n i:", + "ヌゥ/ n u:", + "ヌャ/ ny a", + "ヌュ/ ny u", + "ヌョ/ ny o", + "ネェ/ n e:", + "ノォ/ n o:", + "ハァ/ h a:", + "ヒィ/ h i:", + "フゥ/ f u:", + "フャ/ hy a", + "フュ/ hy u", + "フョ/ hy o", + "ヘェ/ h e:", + "ホォ/ h o:", + "バァ/ b a:", + "ビィ/ b i:", + "ブゥ/ b u:", + "フャ/ hy a", + "ブュ/ by u", + "フョ/ hy o", + "ベェ/ b e:", + "ボォ/ b o:", + "パァ/ p a:", + "ピィ/ p i:", + "プゥ/ p u:", + "プャ/ py a", + "プュ/ py u", + "プョ/ py o", + "ペェ/ p e:", + "ポォ/ p o:", + "マァ/ m a:", + "ミィ/ m i:", + "ムゥ/ m u:", + "ムャ/ my a", + "ムュ/ my u", + "ムョ/ my o", + "メェ/ m e:", + "モォ/ m o:", + "ヤァ/ y a:", + "ユゥ/ y u:", + "ユャ/ y a:", + "ユュ/ y u:", + "ユョ/ y o:", + "ヨォ/ y o:", + "ラァ/ r a:", + "リィ/ r i:", + "ルゥ/ r u:", + "ルャ/ ry a", + "ルュ/ ry u", + "ルョ/ ry o", + "レェ/ r e:", + "ロォ/ r o:", + "ワァ/ w a:", + "ヲォ/ o:", + "ディ/ d i", + "デェ/ d e:", + "デャ/ dy a", + "デュ/ dy u", + "デョ/ dy o", + "ティ/ t i", + "テェ/ t e:", + "テャ/ ty a", + "テュ/ ty u", + "テョ/ ty o", + "スィ/ s i", + "ズァ/ z u a", + "ズィ/ z i", + "ズゥ/ z u", + "ズャ/ zy a", + "ズュ/ zy u", + "ズョ/ zy o", + "ズェ/ z e", + "ズォ/ z o", + "キャ/ ky a", + "キュ/ ky u", + "キョ/ ky o", + "シャ/ sh a", + "シュ/ sh u", + "シェ/ sh e", + "ショ/ sh o", + "チャ/ ch a", + "チュ/ ch u", + "チェ/ ch e", + "チョ/ ch o", + "トゥ/ t u", + "トャ/ ty a", + "トュ/ ty u", + "トョ/ ty o", + "ドァ/ d o a", + "ドゥ/ d u", + "ドャ/ dy a", + "ドュ/ dy u", + "ドョ/ dy o", + "ドォ/ d o:", + "ニャ/ ny a", + "ニュ/ ny u", + "ニョ/ ny o", + "ヒャ/ hy a", + "ヒュ/ hy u", + "ヒョ/ hy o", + "ミャ/ my a", + "ミュ/ my u", + "ミョ/ my o", + "リャ/ ry a", + "リュ/ ry u", + "リョ/ ry o", + "ギャ/ gy a", + "ギュ/ gy u", + "ギョ/ gy o", + "ヂェ/ j e", + "ヂャ/ j a", + "ヂュ/ j u", + "ヂョ/ j o", + "ジェ/ j e", + "ジャ/ j a", + "ジュ/ j u", + "ジョ/ j o", + "ビャ/ by a", + "ビュ/ by u", + "ビョ/ by o", + "ピャ/ py a", + "ピュ/ py u", + "ピョ/ py o", + "ウァ/ u a", + "ウィ/ w i", + "ウェ/ w e", + "ウォ/ w o", + "ファ/ f a", + "フィ/ f i", + "フゥ/ f u", + "フャ/ hy a", + "フュ/ hy u", + "フョ/ hy o", + "フェ/ f e", + "フォ/ f o", + "ヴァ/ b a", + "ヴィ/ b i", + "ヴェ/ b e", + "ヴォ/ b o", + "ヴュ/ by u", + # Conversion of 1 letter + "ア/ a", + "イ/ i", + "ウ/ u", + "エ/ e", + "オ/ o", + "カ/ k a", + "キ/ k i", + "ク/ k u", + "ケ/ k e", + "コ/ k o", + "サ/ s a", + "シ/ sh i", + "ス/ s u", + "セ/ s e", + "ソ/ s o", + "タ/ t a", + "チ/ ch i", + "ツ/ ts u", + "テ/ t e", + "ト/ t o", + "ナ/ n a", + "ニ/ n i", + "ヌ/ n u", + "ネ/ n e", + "ノ/ n o", + "ハ/ h a", + "ヒ/ h i", + "フ/ f u", + "ヘ/ h e", + "ホ/ h o", + "マ/ m a", + "ミ/ m i", + "ム/ m u", + "メ/ m e", + "モ/ m o", + "ラ/ r a", + "リ/ r i", + "ル/ r u", + "レ/ r e", + "ロ/ r o", + "ガ/ g a", + "ギ/ g i", + "グ/ g u", + "ゲ/ g e", + "ゴ/ g o", + "ザ/ z a", + "ジ/ j i", + "ズ/ z u", + "ゼ/ z e", + "ゾ/ z o", + "ダ/ d a", + "ヂ/ j i", + "ヅ/ z u", + "デ/ d e", + "ド/ d o", + "バ/ b a", + "ビ/ b i", + "ブ/ b u", + "ベ/ b e", + "ボ/ b o", + "パ/ p a", + "ピ/ p i", + "プ/ p u", + "ペ/ p e", + "ポ/ p o", + "ヤ/ y a", + "ユ/ y u", + "ヨ/ y o", + "ワ/ w a", + "ヰ/ i", + "ヱ/ e", + "ヲ/ o", + "ン/ N", + "ッ/ q", + "ヴ/ b u", + "ー/:", + # Try converting broken text + "ァ/ a", + "ィ/ i", + "ゥ/ u", + "ェ/ e", + "ォ/ o", + "ヮ/ w a", + "ォ/ o", + # Symbols + "、/ ,", + "。/ .", + "!/ !", + "?/ ?", + "・/ ,", +] + +_COLON_RX = re.compile(":+") +_REJECT_RX = re.compile("[^ a-zA-Z:,.?]") + + +def _makerulemap(): + l = [tuple(x.split("/")) for x in _CONVRULES] + return tuple({k: v for k, v in l if len(k) == i} for i in (1, 2)) + + +_RULEMAP1, _RULEMAP2 = _makerulemap() + + +def kata2phoneme(text: str) -> str: + """Convert katakana text to phonemes.""" + text = text.strip() + res = "" + while text: + if len(text) >= 2: + x = _RULEMAP2.get(text[:2]) + if x is not None: + text = text[2:] + res += x + continue + x = _RULEMAP1.get(text[0]) + if x is not None: + text = text[1:] + res += x + continue + res += " " + text[0] + text = text[1:] + res = _COLON_RX.sub(":", res) + return res[1:] + + +_KATAKANA = "".join(chr(ch) for ch in range(ord("ァ"), ord("ン") + 1)) +_HIRAGANA = "".join(chr(ch) for ch in range(ord("ぁ"), ord("ん") + 1)) +_HIRA2KATATRANS = str.maketrans(_HIRAGANA, _KATAKANA) + + +def hira2kata(text: str) -> str: + text = text.translate(_HIRA2KATATRANS) + return text.replace("う゛", "ヴ") + + +_SYMBOL_TOKENS = set("・、。?!") +_NO_YOMI_TOKENS = set("「」『』―()[][] …") +_TAGGER = MeCab.Tagger() + + +def text2kata(text: str) -> str: + parsed = _TAGGER.parse(text) + res = [] + for line in parsed.split("\n"): + if line == "EOS": + break + parts = line.split("\t") + + word, yomi = parts[0], parts[1] + if yomi: + res.append(yomi) + else: + if word in _SYMBOL_TOKENS: + res.append(word) + elif word in ("っ", "ッ"): + res.append("ッ") + elif word in _NO_YOMI_TOKENS: + pass + else: + res.append(word) + return hira2kata("".join(res)) + + +_ALPHASYMBOL_YOMI = { + "#": "シャープ", + "%": "パーセント", + "&": "アンド", + "+": "プラス", + "-": "マイナス", + ":": "コロン", + ";": "セミコロン", + "<": "小なり", + "=": "イコール", + ">": "大なり", + "@": "アット", + "a": "エー", + "b": "ビー", + "c": "シー", + "d": "ディー", + "e": "イー", + "f": "エフ", + "g": "ジー", + "h": "エイチ", + "i": "アイ", + "j": "ジェー", + "k": "ケー", + "l": "エル", + "m": "エム", + "n": "エヌ", + "o": "オー", + "p": "ピー", + "q": "キュー", + "r": "アール", + "s": "エス", + "t": "ティー", + "u": "ユー", + "v": "ブイ", + "w": "ダブリュー", + "x": "エックス", + "y": "ワイ", + "z": "ゼット", + "α": "アルファ", + "β": "ベータ", + "γ": "ガンマ", + "δ": "デルタ", + "ε": "イプシロン", + "ζ": "ゼータ", + "η": "イータ", + "θ": "シータ", + "ι": "イオタ", + "κ": "カッパ", + "λ": "ラムダ", + "μ": "ミュー", + "ν": "ニュー", + "ξ": "クサイ", + "ο": "オミクロン", + "π": "パイ", + "ρ": "ロー", + "σ": "シグマ", + "τ": "タウ", + "υ": "ウプシロン", + "φ": "ファイ", + "χ": "カイ", + "ψ": "プサイ", + "ω": "オメガ", +} + + +_NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+") +_CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"} +_CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])") +_NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?") + + +def japanese_convert_numbers_to_words(text: str) -> str: + res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text) + res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res) + res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res) + return res + + +def japanese_convert_alpha_symbols_to_words(text: str) -> str: + return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()]) + + +def japanese_text_to_phonemes(text: str) -> str: + """Convert Japanese text to phonemes.""" + res = unicodedata.normalize("NFKC", text) + res = japanese_convert_numbers_to_words(res) + res = japanese_convert_alpha_symbols_to_words(res) + res = text2kata(res) + res = kata2phoneme(res) + return res.replace(" ", "") diff --git a/TTS/tts/utils/text/korean/__init__.py b/TTS/tts/utils/text/korean/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/tts/utils/text/korean/ko_dictionary.py b/TTS/tts/utils/text/korean/ko_dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..9b739339c6750443e7362aab1762147eda33fb53 --- /dev/null +++ b/TTS/tts/utils/text/korean/ko_dictionary.py @@ -0,0 +1,44 @@ +# coding: utf-8 +# Add the word you want to the dictionary. +etc_dictionary = {"1+1": "원플러스원", "2+1": "투플러스원"} + + +english_dictionary = { + "KOREA": "코리아", + "IDOL": "아이돌", + "IT": "아이티", + "IQ": "아이큐", + "UP": "업", + "DOWN": "다운", + "PC": "피씨", + "CCTV": "씨씨티비", + "SNS": "에스엔에스", + "AI": "에이아이", + "CEO": "씨이오", + "A": "에이", + "B": "비", + "C": "씨", + "D": "디", + "E": "이", + "F": "에프", + "G": "지", + "H": "에이치", + "I": "아이", + "J": "제이", + "K": "케이", + "L": "엘", + "M": "엠", + "N": "엔", + "O": "오", + "P": "피", + "Q": "큐", + "R": "알", + "S": "에스", + "T": "티", + "U": "유", + "V": "브이", + "W": "더블유", + "X": "엑스", + "Y": "와이", + "Z": "제트", +} diff --git a/TTS/tts/utils/text/korean/korean.py b/TTS/tts/utils/text/korean/korean.py new file mode 100644 index 0000000000000000000000000000000000000000..423aeed377e3ba2d22b574418aa45714c322f670 --- /dev/null +++ b/TTS/tts/utils/text/korean/korean.py @@ -0,0 +1,32 @@ +# coding: utf-8 +# Code based on https://github.com/carpedm20/multi-speaker-tacotron-tensorflow/blob/master/text/korean.py +import re + +from TTS.tts.utils.text.korean.ko_dictionary import english_dictionary, etc_dictionary + + +def normalize(text): + text = text.strip() + text = re.sub("[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text) + text = normalize_with_dictionary(text, etc_dictionary) + text = normalize_english(text) + text = text.lower() + return text + + +def normalize_with_dictionary(text, dic): + if any(key in text for key in dic.keys()): + pattern = re.compile("|".join(re.escape(key) for key in dic.keys())) + return pattern.sub(lambda x: dic[x.group()], text) + return text + + +def normalize_english(text): + def fn(m): + word = m.group() + if word in english_dictionary: + return english_dictionary.get(word) + return word + + text = re.sub("([A-Za-z]+)", fn, text) + return text diff --git a/TTS/tts/utils/text/korean/phonemizer.py b/TTS/tts/utils/text/korean/phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..dde039b0f5325662e6065da0782e0256c553ee73 --- /dev/null +++ b/TTS/tts/utils/text/korean/phonemizer.py @@ -0,0 +1,39 @@ +try: + from jamo import hangul_to_jamo +except ImportError as e: + raise ImportError("Korean requires: g2pkk, jamo") from e + +from TTS.tts.utils.text.korean.korean import normalize + +g2p = None + + +def korean_text_to_phonemes(text, character: str = "hangeul") -> str: + """ + + The input and output values look the same, but they are different in Unicode. + + example : + + input = '하늘' (Unicode : \ud558\ub298), (하 + 늘) + output = '하늘' (Unicode :\u1112\u1161\u1102\u1173\u11af), (ᄒ + ᅡ + ᄂ + ᅳ + ᆯ) + + """ + global g2p # pylint: disable=global-statement + if g2p is None: + from g2pkk import G2p + + g2p = G2p() + + if character == "english": + from anyascii import anyascii + + text = normalize(text) + text = g2p(text) + text = anyascii(text) + return text + + text = normalize(text) + text = g2p(text) + text = list(hangul_to_jamo(text)) # '하늘' --> ['ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆯ'] + return "".join(text) diff --git a/TTS/tts/utils/text/phonemizers/__init__.py b/TTS/tts/utils/text/phonemizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf62bab3d098cf18dc0800f7775e215d45dfba8 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/__init__.py @@ -0,0 +1,100 @@ +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.phonemizers.belarusian_phonemizer import BEL_Phonemizer +from TTS.tts.utils.text.phonemizers.espeak_wrapper import ESpeak +from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + +try: + from TTS.tts.utils.text.phonemizers.bangla_phonemizer import BN_Phonemizer +except ImportError: + BN_Phonemizer = None + +try: + from TTS.tts.utils.text.phonemizers.ja_jp_phonemizer import JA_JP_Phonemizer +except ImportError: + JA_JP_Phonemizer = None + +try: + from TTS.tts.utils.text.phonemizers.ko_kr_phonemizer import KO_KR_Phonemizer +except ImportError: + KO_KR_Phonemizer = None + +try: + from TTS.tts.utils.text.phonemizers.zh_cn_phonemizer import ZH_CN_Phonemizer +except ImportError: + ZH_CN_Phonemizer = None + +PHONEMIZERS = {b.name(): b for b in (ESpeak, Gruut)} + + +ESPEAK_LANGS = list(ESpeak.supported_languages().keys()) +GRUUT_LANGS = list(Gruut.supported_languages()) + + +# Dict setting default phonemizers for each language +# Add Gruut languages +_ = [Gruut.name()] * len(GRUUT_LANGS) +DEF_LANG_TO_PHONEMIZER = dict(list(zip(GRUUT_LANGS, _))) + + +# Add ESpeak languages and override any existing ones +_ = [ESpeak.name()] * len(ESPEAK_LANGS) +_new_dict = dict(list(zip(list(ESPEAK_LANGS), _))) +DEF_LANG_TO_PHONEMIZER.update(_new_dict) + + +# Force default for some languages +DEF_LANG_TO_PHONEMIZER["en"] = DEF_LANG_TO_PHONEMIZER["en-us"] +DEF_LANG_TO_PHONEMIZER["be"] = BEL_Phonemizer.name() + + +if BN_Phonemizer is not None: + PHONEMIZERS[BN_Phonemizer.name()] = BN_Phonemizer + DEF_LANG_TO_PHONEMIZER["bn"] = BN_Phonemizer.name() +if JA_JP_Phonemizer is not None: + PHONEMIZERS[JA_JP_Phonemizer.name()] = JA_JP_Phonemizer + DEF_LANG_TO_PHONEMIZER["ja-jp"] = JA_JP_Phonemizer.name() +if KO_KR_Phonemizer is not None: + PHONEMIZERS[KO_KR_Phonemizer.name()] = KO_KR_Phonemizer + DEF_LANG_TO_PHONEMIZER["ko-kr"] = KO_KR_Phonemizer.name() +if ZH_CN_Phonemizer is not None: + PHONEMIZERS[ZH_CN_Phonemizer.name()] = ZH_CN_Phonemizer + DEF_LANG_TO_PHONEMIZER["zh-cn"] = ZH_CN_Phonemizer.name() + + +def get_phonemizer_by_name(name: str, **kwargs) -> BasePhonemizer: + """Initiate a phonemizer by name + + Args: + name (str): + Name of the phonemizer that should match `phonemizer.name()`. + + kwargs (dict): + Extra keyword arguments that should be passed to the phonemizer. + """ + if name == "espeak": + return ESpeak(**kwargs) + if name == "gruut": + return Gruut(**kwargs) + if name == "zh_cn_phonemizer": + if ZH_CN_Phonemizer is None: + raise ValueError("You need to install ZH phonemizer dependencies. Try `pip install coqui-tts[zh]`.") + return ZH_CN_Phonemizer(**kwargs) + if name == "ja_jp_phonemizer": + if JA_JP_Phonemizer is None: + raise ValueError("You need to install JA phonemizer dependencies. Try `pip install coqui-tts[ja]`.") + return JA_JP_Phonemizer(**kwargs) + if name == "ko_kr_phonemizer": + if KO_KR_Phonemizer is None: + raise ValueError("You need to install KO phonemizer dependencies. Try `pip install coqui-tts[ko]`.") + return KO_KR_Phonemizer(**kwargs) + if name == "bn_phonemizer": + if BN_Phonemizer is None: + raise ValueError("You need to install BN phonemizer dependencies. Try `pip install coqui-tts[bn]`.") + return BN_Phonemizer(**kwargs) + if name == "be_phonemizer": + return BEL_Phonemizer(**kwargs) + raise ValueError(f"Phonemizer {name} not found") + + +if __name__ == "__main__": + print(DEF_LANG_TO_PHONEMIZER) diff --git a/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py b/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4a35bbfa4fc63a7cbb2a2caa152ef59b4afa63 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/bangla_phonemizer.py @@ -0,0 +1,62 @@ +from typing import Dict + +from TTS.tts.utils.text.bangla.phonemizer import bangla_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class BN_Phonemizer(BasePhonemizer): + """🐸TTS bn phonemizer using functions in `TTS.tts.utils.text.bangla.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. + + Example :: + + "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。` + + TODO: someone with Bangla knowledge should check this implementation + """ + + language = "bn" + + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "bn_phonemizer" + + @staticmethod + def phonemize_bn(text: str, separator: str = "|") -> str: # pylint: disable=unused-argument + ph = bangla_text_to_phonemes(text) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_bn(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"bn": "Bangla"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +if __name__ == "__main__": + txt = "রাসূলুল্লাহ সাল্লাল্লাহু আলাইহি ওয়া সাল্লাম শিক্ষা দিয়েছেন যে, কেউ যদি কোন খারাপ কিছুর সম্মুখীন হয়, তখনও যেন বলে." + e = BN_Phonemizer() + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + print("`" + e.phonemize(txt) + "`") diff --git a/TTS/tts/utils/text/phonemizers/base.py b/TTS/tts/utils/text/phonemizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5e701df458d7acd9ed8663ab83017e1ff02dbd02 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/base.py @@ -0,0 +1,143 @@ +import abc +import logging +from typing import List, Tuple + +from TTS.tts.utils.text.punctuation import Punctuation + +logger = logging.getLogger(__name__) + + +class BasePhonemizer(abc.ABC): + """Base phonemizer class + + Phonemization follows the following steps: + 1. Preprocessing: + - remove empty lines + - remove punctuation + - keep track of punctuation marks + + 2. Phonemization: + - convert text to phonemes + + 3. Postprocessing: + - join phonemes + - restore punctuation marks + + Args: + language (str): + Language used by the phonemizer. + + punctuations (List[str]): + List of punctuation marks to be preserved. + + keep_puncs (bool): + Whether to preserve punctuation marks or not. + """ + + def __init__(self, language, punctuations=Punctuation.default_puncs(), keep_puncs=False): + # ensure the backend is installed on the system + if not self.is_available(): + raise RuntimeError("{} not installed on your system".format(self.name())) # pragma: nocover + + # ensure the backend support the requested language + self._language = self._init_language(language) + + # setup punctuation processing + self._keep_puncs = keep_puncs + self._punctuator = Punctuation(punctuations) + + def _init_language(self, language): + """Language initialization + + This method may be overloaded in child classes (see Segments backend) + + """ + if not self.is_supported_language(language): + raise RuntimeError(f'language "{language}" is not supported by the ' f"{self.name()} backend") + return language + + @property + def language(self): + """The language code configured to be used for phonemization""" + return self._language + + @staticmethod + @abc.abstractmethod + def name(): + """The name of the backend""" + ... + + @classmethod + @abc.abstractmethod + def is_available(cls): + """Returns True if the backend is installed, False otherwise""" + ... + + @classmethod + @abc.abstractmethod + def version(cls): + """Return the backend version as a tuple (major, minor, patch)""" + ... + + @staticmethod + @abc.abstractmethod + def supported_languages(): + """Return a dict of language codes -> name supported by the backend""" + ... + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return language in self.supported_languages() + + @abc.abstractmethod + def _phonemize(self, text, separator): + """The main phonemization method""" + + def _phonemize_preprocess(self, text) -> Tuple[List[str], List]: + """Preprocess the text before phonemization + + 1. remove spaces + 2. remove punctuation + + Override this if you need a different behaviour + """ + text = text.strip() + if self._keep_puncs: + # a tuple (text, punctuation marks) + return self._punctuator.strip_to_restore(text) + return [self._punctuator.strip(text)], [] + + def _phonemize_postprocess(self, phonemized, punctuations) -> str: + """Postprocess the raw phonemized output + + Override this if you need a different behaviour + """ + if self._keep_puncs: + return self._punctuator.restore(phonemized, punctuations)[0] + return phonemized[0] + + def phonemize(self, text: str, separator="|", language: str = None) -> str: # pylint: disable=unused-argument + """Returns the `text` phonemized for the given language + + Args: + text (str): + Text to be phonemized. + + separator (str): + string separator used between phonemes. Default to '_'. + + Returns: + (str): Phonemized text + """ + text, punctuations = self._phonemize_preprocess(text) + phonemized = [] + for t in text: + p = self._phonemize(t, separator) + phonemized.append(p) + phonemized = self._phonemize_postprocess(phonemized, punctuations) + return phonemized + + def print_logs(self, level: int = 0): + indent = "\t" * level + logger.info("%s| phoneme language: %s", indent, self.language) + logger.info("%s| phoneme backend: %s", indent, self.name()) diff --git a/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..e5fcab6e09b7e9fad06c381ebc7239e36c4cb8db --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/belarusian_phonemizer.py @@ -0,0 +1,55 @@ +from typing import Dict + +from TTS.tts.utils.text.belarusian.phonemizer import belarusian_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_BE_PUNCS = ",!." # TODO + + +class BEL_Phonemizer(BasePhonemizer): + """🐸TTS be phonemizer using functions in `TTS.tts.utils.text.belarusian.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_BE_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. + """ + + language = "be" + + def __init__(self, punctuations=_DEF_BE_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "be_phonemizer" + + @staticmethod + def phonemize_be(text: str, separator: str = "|") -> str: # pylint: disable=unused-argument + return belarusian_text_to_phonemes(text) + + def _phonemize(self, text, separator): + return self.phonemize_be(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"be": "Belarusian"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +if __name__ == "__main__": + txt = "тэст" + e = BEL_Phonemizer() + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + print("`" + e.phonemize(txt) + "`") diff --git a/TTS/tts/utils/text/phonemizers/espeak_wrapper.py b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..a15df716e75fffe5b9e7a86ea39e578872aca410 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/espeak_wrapper.py @@ -0,0 +1,261 @@ +"""Wrapper to call the espeak/espeak-ng phonemizer.""" + +import logging +import re +import subprocess +import tempfile +from pathlib import Path +from typing import Optional + +from packaging.version import Version + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + +logger = logging.getLogger(__name__) + + +def _is_tool(name) -> bool: + from shutil import which + + return which(name) is not None + + +# Use a regex pattern to match the espeak version, because it may be +# symlinked to espeak-ng, which moves the version bits to another spot. +espeak_version_pattern = re.compile(r"text-to-speech:\s(?P\d+\.\d+(\.\d+)?)") + + +def get_espeak_version() -> str: + """Return version of the `espeak` binary.""" + output = subprocess.getoutput("espeak --version") + match = espeak_version_pattern.search(output) + + return match.group("version") + + +def get_espeakng_version() -> str: + """Return version of the `espeak-ng` binary.""" + output = subprocess.getoutput("espeak-ng --version") + return output.split()[3] + + +# priority: espeakng > espeak +if _is_tool("espeak-ng"): + _DEF_ESPEAK_LIB = "espeak-ng" + _DEF_ESPEAK_VER = get_espeakng_version() +elif _is_tool("espeak"): + _DEF_ESPEAK_LIB = "espeak" + _DEF_ESPEAK_VER = get_espeak_version() +else: + _DEF_ESPEAK_LIB = None + _DEF_ESPEAK_VER = None + + +def _espeak_exe(espeak_lib: str, args: list) -> list[str]: + """Run espeak with the given arguments.""" + cmd = [ + espeak_lib, + "-q", + "-b", + "1", # UTF8 text encoding + ] + cmd.extend(args) + logger.debug("Executing: %s", repr(cmd)) + + p = subprocess.run(cmd, capture_output=True, encoding="utf8", check=True) + for line in p.stderr.strip().split("\n"): + if line.strip() != "": + logger.warning("%s: %s", espeak_lib, line.strip()) + res = [] + for line in p.stdout.strip().split("\n"): + if line.strip() != "": + logger.debug("%s: %s", espeak_lib, line.strip()) + res.append(line.strip()) + return res + + +class ESpeak(BasePhonemizer): + """Wrapper calling `espeak` or `espeak-ng` from the command-line to perform G2P. + + Args: + language (str): + Valid language code for the used backend. + + backend (str): + Name of the backend library to use. `espeak` or `espeak-ng`. If None, set automatically + prefering `espeak-ng` over `espeak`. Defaults to None. + + punctuations (str): + Characters to be treated as punctuation. Defaults to Punctuation.default_puncs(). + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to True. + + Example: + + >>> from TTS.tts.utils.text.phonemizers import ESpeak + >>> phonemizer = ESpeak("tr") + >>> phonemizer.phonemize("Bu Türkçe, bir örnektir.", separator="|") + 'b|ʊ t|ˈø|r|k|tʃ|ɛ, b|ɪ|r œ|r|n|ˈɛ|c|t|ɪ|r.' + + """ + + def __init__( + self, + language: str, + backend: Optional[str] = None, + punctuations: str = Punctuation.default_puncs(), + keep_puncs: bool = True, + ): + if _DEF_ESPEAK_LIB is None: + msg = "[!] No espeak backend found. Install espeak-ng or espeak to your system." + raise FileNotFoundError(msg) + self.backend = _DEF_ESPEAK_LIB + + # band-aid for backwards compatibility + if language == "en": + language = "en-us" + if language == "zh-cn": + language = "cmn" + + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + if backend is not None: + self.backend = backend + + @property + def backend(self) -> str: + return self._ESPEAK_LIB + + @property + def backend_version(self) -> str: + return self._ESPEAK_VER + + @backend.setter + def backend(self, backend: str) -> None: + if backend not in ["espeak", "espeak-ng"]: + msg = f"Unknown backend: {backend}" + raise ValueError(msg) + self._ESPEAK_LIB = backend + self._ESPEAK_VER = get_espeakng_version() if backend == "espeak-ng" else get_espeak_version() + + def auto_set_espeak_lib(self) -> None: + if _is_tool("espeak-ng"): + self._ESPEAK_LIB = "espeak-ng" + self._ESPEAK_VER = get_espeakng_version() + elif _is_tool("espeak"): + self._ESPEAK_LIB = "espeak" + self._ESPEAK_VER = get_espeak_version() + else: + msg = "Cannot set backend automatically. espeak-ng or espeak not found" + raise FileNotFoundError(msg) + + @staticmethod + def name() -> str: + return "espeak" + + def phonemize_espeak(self, text: str, separator: str = "|", *, tie: bool = False) -> str: + """Convert input text to phonemes. + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + # set arguments + args = ["-v", f"{self._language}"] + # espeak and espeak-ng parses `ipa` differently + if tie: + # use '͡' between phonemes + if self.backend == "espeak": + args.append("--ipa=1") + else: + args.append("--ipa=3") + else: + # split with '_' + if self.backend == "espeak": + if Version(self.backend_version) >= Version("1.48.15"): + args.append("--ipa=1") + else: + args.append("--ipa=3") + else: + args.append("--ipa=1") + if tie: + args.append("--tie=%s" % tie) + + tmp = tempfile.NamedTemporaryFile(mode="w+t", delete=False, encoding="utf8") + tmp.write(text) + tmp.close() + args.append("-f") + args.append(tmp.name) + + # compute phonemes + phonemes = "" + for line in _espeak_exe(self.backend, args): + # espeak: + # version 1.48.15: " p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + # espeak-ng: + # "p_ɹ_ˈaɪ_ɚ t_ə n_oʊ_v_ˈɛ_m_b_ɚ t_w_ˈɛ_n_t_i t_ˈuː\n" + + # espeak-ng backend can add language flags that need to be removed: + # "sɛʁtˈɛ̃ mˈo kɔm (en)fˈʊtbɔːl(fr) ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ." + # phonemize needs to remove the language flags of the returned text: + # "sɛʁtˈɛ̃ mˈo kɔm fˈʊtbɔːl ʒenˈɛʁ de- flˈaɡ də- lˈɑ̃ɡ." + ph_decoded = re.sub(r"\(.+?\)", "", line) + + phonemes += ph_decoded.strip() + Path(tmp.name).unlink() + return phonemes.replace("_", separator) + + def _phonemize(self, text: str, separator: str = "") -> str: + return self.phonemize_espeak(text, separator, tie=False) + + @staticmethod + def supported_languages() -> dict[str, str]: + """Get a dictionary of supported languages. + + Returns: + Dict: Dictionary of language codes. + """ + if _DEF_ESPEAK_LIB is None: + return {} + args = ["--voices"] + langs = {} + for count, line in enumerate(_espeak_exe(_DEF_ESPEAK_LIB, args)): + if count > 0: + cols = line.split() + lang_code = cols[1] + lang_name = cols[3] + langs[lang_code] = lang_name + return langs + + def version(self) -> str: + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + return self.backend_version + + @classmethod + def is_available(cls) -> bool: + """Return true if ESpeak is available else false.""" + return _is_tool("espeak") or _is_tool("espeak-ng") + + +if __name__ == "__main__": + e = ESpeak(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = ESpeak(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = ESpeak(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how are you today?") + "`") diff --git a/TTS/tts/utils/text/phonemizers/gruut_wrapper.py b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f3e9c9abd4c41935ed07ec10ed883d75b42a6bc8 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/gruut_wrapper.py @@ -0,0 +1,151 @@ +import importlib +from typing import List + +import gruut +from gruut_ipa import IPA + +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer +from TTS.tts.utils.text.punctuation import Punctuation + +# Table for str.translate to fix gruut/TTS phoneme mismatch +GRUUT_TRANS_TABLE = str.maketrans("g", "ɡ") + + +class Gruut(BasePhonemizer): + """Gruut wrapper for G2P + + Args: + language (str): + Valid language code for the used backend. + + punctuations (str): + Characters to be treated as punctuation. Defaults to `Punctuation.default_puncs()`. + + keep_puncs (bool): + If true, keep the punctuations after phonemization. Defaults to True. + + use_espeak_phonemes (bool): + If true, use espeak lexicons instead of default Gruut lexicons. Defaults to False. + + keep_stress (bool): + If true, keep the stress characters after phonemization. Defaults to False. + + Example: + + >>> from TTS.tts.utils.text.phonemizers.gruut_wrapper import Gruut + >>> phonemizer = Gruut('en-us') + >>> phonemizer.phonemize("Be a voice, not an! echo?", separator="|") + 'b|i| ə| v|ɔ|ɪ|s, n|ɑ|t| ə|n! ɛ|k|o|ʊ?' + """ + + def __init__( + self, + language: str, + punctuations=Punctuation.default_puncs(), + keep_puncs=True, + use_espeak_phonemes=False, + keep_stress=False, + ): + super().__init__(language, punctuations=punctuations, keep_puncs=keep_puncs) + self.use_espeak_phonemes = use_espeak_phonemes + self.keep_stress = keep_stress + + @staticmethod + def name(): + return "gruut" + + def phonemize_gruut(self, text: str, separator: str = "|", tie=False) -> str: # pylint: disable=unused-argument + """Convert input text to phonemes. + + Gruut phonemizes the given `str` by seperating each phoneme character with `separator`, even for characters + that constitude a single sound. + + It doesn't affect 🐸TTS since it individually converts each character to token IDs. + + Examples:: + "hello how are you today?" -> `h|ɛ|l|o|ʊ| h|a|ʊ| ɑ|ɹ| j|u| t|ə|d|e|ɪ` + + Args: + text (str): + Text to be converted to phonemes. + + tie (bool, optional) : When True use a '͡' character between + consecutive characters of a single phoneme. Else separate phoneme + with '_'. This option requires espeak>=1.49. Default to False. + """ + ph_list = [] + for sentence in gruut.sentences(text, lang=self.language, espeak=self.use_espeak_phonemes): + for word in sentence: + if word.is_break: + # Use actual character for break phoneme (e.g., comma) + if ph_list: + # Join with previous word + ph_list[-1].append(word.text) + else: + # First word is punctuation + ph_list.append([word.text]) + elif word.phonemes: + # Add phonemes for word + word_phonemes = [] + + for word_phoneme in word.phonemes: + if not self.keep_stress: + # Remove primary/secondary stress + word_phoneme = IPA.without_stress(word_phoneme) + + word_phoneme = word_phoneme.translate(GRUUT_TRANS_TABLE) + + if word_phoneme: + # Flatten phonemes + word_phonemes.extend(word_phoneme) + + if word_phonemes: + ph_list.append(word_phonemes) + + ph_words = [separator.join(word_phonemes) for word_phonemes in ph_list] + ph = f"{separator} ".join(ph_words) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_gruut(text, separator, tie=False) + + def is_supported_language(self, language): + """Returns True if `language` is supported by the backend""" + return gruut.is_language_supported(language) + + @staticmethod + def supported_languages() -> List: + """Get a dictionary of supported languages. + + Returns: + List: List of language codes. + """ + return list(gruut.get_supported_languages()) + + def version(self): + """Get the version of the used backend. + + Returns: + str: Version of the used backend. + """ + return gruut.__version__ + + @classmethod + def is_available(cls): + """Return true if ESpeak is available else false""" + return importlib.util.find_spec("gruut") is not None + + +if __name__ == "__main__": + e = Gruut(language="en-us") + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + + e = Gruut(language="en-us", keep_puncs=False) + print("`" + e.phonemize("hello how are you today?") + "`") + + e = Gruut(language="en-us", keep_puncs=True) + print("`" + e.phonemize("hello how, are you today?") + "`") diff --git a/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..878e5e52969f740ae74d875c91294666095e89ac --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/ja_jp_phonemizer.py @@ -0,0 +1,72 @@ +from typing import Dict + +from TTS.tts.utils.text.japanese.phonemizer import japanese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_JA_PUNCS = "、.,[]()?!〽~『』「」【】" + +_TRANS_TABLE = {"、": ","} + + +def trans(text): + for i, j in _TRANS_TABLE.items(): + text = text.replace(i, j) + return text + + +class JA_JP_Phonemizer(BasePhonemizer): + """🐸TTS Ja-Jp phonemizer using functions in `TTS.tts.utils.text.japanese.phonemizer` + + TODO: someone with JA knowledge should check this implementation + + Example: + + >>> from TTS.tts.utils.text.phonemizers import JA_JP_Phonemizer + >>> phonemizer = JA_JP_Phonemizer() + >>> phonemizer.phonemize("どちらに行きますか?", separator="|") + 'd|o|c|h|i|r|a|n|i|i|k|i|m|a|s|u|k|a|?' + + """ + + language = "ja-jp" + + def __init__(self, punctuations=_DEF_JA_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "ja_jp_phonemizer" + + def _phonemize(self, text: str, separator: str = "|") -> str: + ph = japanese_text_to_phonemes(text) + if separator is not None or separator != "": + return separator.join(ph) + return ph + + def phonemize(self, text: str, separator="|", language=None) -> str: + """Custom phonemize for JP_JA + + Skip pre-post processing steps used by the other phonemizers. + """ + return self._phonemize(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"ja-jp": "Japanese (Japan)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "これは、電話をかけるための私の日本語の例のテキストです。" +# e = JA_JP_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py b/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdba2137b039025491a08d8b8d52454b82a72fd --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/ko_kr_phonemizer.py @@ -0,0 +1,65 @@ +from typing import Dict + +from TTS.tts.utils.text.korean.phonemizer import korean_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_KO_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class KO_KR_Phonemizer(BasePhonemizer): + """🐸TTS ko_kr_phonemizer using functions in `TTS.tts.utils.text.korean.phonemizer` + + TODO: Add Korean to character (ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ) + + Example: + + >>> from TTS.tts.utils.text.phonemizers import KO_KR_Phonemizer + >>> phonemizer = KO_KR_Phonemizer() + >>> phonemizer.phonemize("이 문장은 음성합성 테스트를 위한 문장입니다.", separator="|") + 'ᄋ|ᅵ| |ᄆ|ᅮ|ᆫ|ᄌ|ᅡ|ᆼ|ᄋ|ᅳ| |ᄂ|ᅳ|ᆷ|ᄉ|ᅥ|ᆼ|ᄒ|ᅡ|ᆸ|ᄊ|ᅥ|ᆼ| |ᄐ|ᅦ|ᄉ|ᅳ|ᄐ|ᅳ|ᄅ|ᅳ| |ᄅ|ᅱ|ᄒ|ᅡ|ᆫ| |ᄆ|ᅮ|ᆫ|ᄌ|ᅡ|ᆼ|ᄋ|ᅵ|ᆷ|ᄂ|ᅵ|ᄃ|ᅡ|.' + + >>> from TTS.tts.utils.text.phonemizers import KO_KR_Phonemizer + >>> phonemizer = KO_KR_Phonemizer() + >>> phonemizer.phonemize("이 문장은 음성합성 테스트를 위한 문장입니다.", separator="|", character='english') + 'I| |M|u|n|J|a|n|g|E|u| |N|e|u|m|S|e|o|n|g|H|a|b|S|s|e|o|n|g| |T|e|S|e|u|T|e|u|L|e|u| |L|w|i|H|a|n| |M|u|n|J|a|n|g|I|m|N|i|D|a|.' + + """ + + language = "ko-kr" + + def __init__(self, punctuations=_DEF_KO_PUNCS, keep_puncs=True, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "ko_kr_phonemizer" + + def _phonemize(self, text: str, separator: str = "", character: str = "hangeul") -> str: + ph = korean_text_to_phonemes(text, character=character) + if separator is not None or separator != "": + return separator.join(ph) + return ph + + def phonemize(self, text: str, separator: str = "", character: str = "hangeul", language=None) -> str: + return self._phonemize(text, separator, character) + + @staticmethod + def supported_languages() -> Dict: + return {"ko-kr": "hangeul(korean)"} + + def version(self) -> str: + return "0.0.2" + + def is_available(self) -> bool: + return True + + +if __name__ == "__main__": + texts = "이 문장은 음성합성 테스트를 위한 문장입니다." + e = KO_KR_Phonemizer() + print(e.supported_languages()) + print(e.version()) + print(e.language) + print(e.name()) + print(e.is_available()) + print(e.phonemize(texts)) diff --git a/TTS/tts/utils/text/phonemizers/multi_phonemizer.py b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9e98b091e9acece2a0f703fb7385b619936879 --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/multi_phonemizer.py @@ -0,0 +1,68 @@ +import logging +from typing import Dict, List + +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name + +logger = logging.getLogger(__name__) + + +class MultiPhonemizer: + """🐸TTS multi-phonemizer that operates phonemizers for multiple langugages + + Args: + custom_lang_to_phonemizer (Dict): + Custom phonemizer mapping if you want to change the defaults. In the format of + `{"lang_code", "phonemizer_name"}`. When it is None, `DEF_LANG_TO_PHONEMIZER` is used. Defaults to `{}`. + + TODO: find a way to pass custom kwargs to the phonemizers + """ + + lang_to_phonemizer = {} + + def __init__(self, lang_to_phonemizer_name: Dict = {}) -> None: # pylint: disable=dangerous-default-value + for k, v in lang_to_phonemizer_name.items(): + if v == "" and k in DEF_LANG_TO_PHONEMIZER.keys(): + lang_to_phonemizer_name[k] = DEF_LANG_TO_PHONEMIZER[k] + elif v == "": + raise ValueError(f"Phonemizer wasn't set for language {k} and doesn't have a default.") + self.lang_to_phonemizer_name = lang_to_phonemizer_name + self.lang_to_phonemizer = self.init_phonemizers(self.lang_to_phonemizer_name) + + @staticmethod + def init_phonemizers(lang_to_phonemizer_name: Dict) -> Dict: + lang_to_phonemizer = {} + for k, v in lang_to_phonemizer_name.items(): + lang_to_phonemizer[k] = get_phonemizer_by_name(v, language=k) + return lang_to_phonemizer + + @staticmethod + def name(): + return "multi-phonemizer" + + def phonemize(self, text, separator="|", language=""): + if language == "": + raise ValueError("Language must be set for multi-phonemizer to phonemize.") + return self.lang_to_phonemizer[language].phonemize(text, separator) + + def supported_languages(self) -> List: + return list(self.lang_to_phonemizer.keys()) + + def print_logs(self, level: int = 0): + indent = "\t" * level + logger.info("%s| phoneme language: %s", indent, self.supported_languages()) + logger.info("%s| phoneme backend: %s", indent, self.name()) + + +# if __name__ == "__main__": +# texts = { +# "tr": "Merhaba, bu Türkçe bit örnek!", +# "en-us": "Hello, this is English example!", +# "de": "Hallo, das ist ein Deutches Beipiel!", +# "zh-cn": "这是中国的例子", +# } +# phonemes = {} +# ph = MultiPhonemizer({"tr": "espeak", "en-us": "", "de": "gruut", "zh-cn": ""}) +# for lang, text in texts.items(): +# phoneme = ph.phonemize(text, lang) +# phonemes[lang] = phoneme +# print(phonemes) diff --git a/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py new file mode 100644 index 0000000000000000000000000000000000000000..41480c417356fd941e71e3eff0099eb38ac7296a --- /dev/null +++ b/TTS/tts/utils/text/phonemizers/zh_cn_phonemizer.py @@ -0,0 +1,62 @@ +from typing import Dict + +from TTS.tts.utils.text.chinese_mandarin.phonemizer import chinese_text_to_phonemes +from TTS.tts.utils.text.phonemizers.base import BasePhonemizer + +_DEF_ZH_PUNCS = "、.,[]()?!〽~『』「」【】" + + +class ZH_CN_Phonemizer(BasePhonemizer): + """🐸TTS Zh-Cn phonemizer using functions in `TTS.tts.utils.text.chinese_mandarin.phonemizer` + + Args: + punctuations (str): + Set of characters to be treated as punctuation. Defaults to `_DEF_ZH_PUNCS`. + + keep_puncs (bool): + If True, keep the punctuations after phonemization. Defaults to False. + + Example :: + + "这是,样本中文。" -> `d|ʒ|ø|4| |ʂ|ʏ|4| |,| |i|ɑ|ŋ|4|b|œ|n|3| |d|ʒ|o|ŋ|1|w|œ|n|2| |。` + + TODO: someone with Mandarin knowledge should check this implementation + """ + + language = "zh-cn" + + def __init__(self, punctuations=_DEF_ZH_PUNCS, keep_puncs=False, **kwargs): # pylint: disable=unused-argument + super().__init__(self.language, punctuations=punctuations, keep_puncs=keep_puncs) + + @staticmethod + def name(): + return "zh_cn_phonemizer" + + @staticmethod + def phonemize_zh_cn(text: str, separator: str = "|") -> str: + ph = chinese_text_to_phonemes(text, separator) + return ph + + def _phonemize(self, text, separator): + return self.phonemize_zh_cn(text, separator) + + @staticmethod + def supported_languages() -> Dict: + return {"zh-cn": "Chinese (China)"} + + def version(self) -> str: + return "0.0.1" + + def is_available(self) -> bool: + return True + + +# if __name__ == "__main__": +# text = "这是,样本中文。" +# e = ZH_CN_Phonemizer() +# print(e.supported_languages()) +# print(e.version()) +# print(e.language) +# print(e.name()) +# print(e.is_available()) +# print("`" + e.phonemize(text) + "`") diff --git a/TTS/tts/utils/text/punctuation.py b/TTS/tts/utils/text/punctuation.py new file mode 100644 index 0000000000000000000000000000000000000000..36c467d08335ea24ad1c0e6705e0c351f7fb8d65 --- /dev/null +++ b/TTS/tts/utils/text/punctuation.py @@ -0,0 +1,171 @@ +import collections +import re +from enum import Enum + +import six + +_DEF_PUNCS = ';:,.!?¡¿—…"«»“”' + +_PUNC_IDX = collections.namedtuple("_punc_index", ["punc", "position"]) + + +class PuncPosition(Enum): + """Enum for the punctuations positions""" + + BEGIN = 0 + END = 1 + MIDDLE = 2 + + +class Punctuation: + """Handle punctuations in text. + + Just strip punctuations from text or strip and restore them later. + + Args: + puncs (str): The punctuations to be processed. Defaults to `_DEF_PUNCS`. + + Example: + >>> punc = Punctuation() + >>> punc.strip("This is. example !") + 'This is example' + + >>> text_striped, punc_map = punc.strip_to_restore("This is. example !") + >>> ' '.join(text_striped) + 'This is example' + + >>> text_restored = punc.restore(text_striped, punc_map) + >>> text_restored[0] + 'This is. example !' + """ + + def __init__(self, puncs: str = _DEF_PUNCS): + self.puncs = puncs + + @staticmethod + def default_puncs(): + """Return default set of punctuations.""" + return _DEF_PUNCS + + @property + def puncs(self): + return self._puncs + + @puncs.setter + def puncs(self, value): + if not isinstance(value, six.string_types): + raise ValueError("[!] Punctuations must be of type str.") + self._puncs = "".join(list(dict.fromkeys(list(value)))) # remove duplicates without changing the oreder + self.puncs_regular_exp = re.compile(rf"(\s*[{re.escape(self._puncs)}]+\s*)+") + + def strip(self, text): + """Remove all the punctuations by replacing with `space`. + + Args: + text (str): The text to be processed. + + Example:: + + "This is. example !" -> "This is example " + """ + return re.sub(self.puncs_regular_exp, " ", text).rstrip().lstrip() + + def strip_to_restore(self, text): + """Remove punctuations from text to restore them later. + + Args: + text (str): The text to be processed. + + Examples :: + + "This is. example !" -> [["This is", "example"], [".", "!"]] + + """ + text, puncs = self._strip_to_restore(text) + return text, puncs + + def _strip_to_restore(self, text): + """Auxiliary method for Punctuation.preserve()""" + matches = list(re.finditer(self.puncs_regular_exp, text)) + if not matches: + return [text], [] + # the text is only punctuations + if len(matches) == 1 and matches[0].group() == text: + return [], [_PUNC_IDX(text, PuncPosition.BEGIN)] + # build a punctuation map to be used later to restore punctuations + puncs = [] + for match in matches: + position = PuncPosition.MIDDLE + if match == matches[0] and text.startswith(match.group()): + position = PuncPosition.BEGIN + elif match == matches[-1] and text.endswith(match.group()): + position = PuncPosition.END + puncs.append(_PUNC_IDX(match.group(), position)) + # convert str text to a List[str], each item is separated by a punctuation + splitted_text = [] + for idx, punc in enumerate(puncs): + split = text.split(punc.punc) + prefix, suffix = split[0], punc.punc.join(split[1:]) + text = suffix + if prefix == "": + # We don't want to insert an empty string in case of initial punctuation + continue + splitted_text.append(prefix) + # if the text does not end with a punctuation, add it to the last item + if idx == len(puncs) - 1 and len(suffix) > 0: + splitted_text.append(suffix) + return splitted_text, puncs + + @classmethod + def restore(cls, text, puncs): + """Restore punctuation in a text. + + Args: + text (str): The text to be processed. + puncs (List[str]): The list of punctuations map to be used for restoring. + + Examples :: + + ['This is', 'example'], ['.', '!'] -> "This is. example!" + + """ + return cls._restore(text, puncs) + + @classmethod + def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements + """Auxiliary method for Punctuation.restore()""" + if not puncs: + return text + + # nothing have been phonemized, returns the puncs alone + if not text: + return ["".join(m.punc for m in puncs)] + + current = puncs[0] + + if current.position == PuncPosition.BEGIN: + return cls._restore([current.punc + text[0]] + text[1:], puncs[1:]) + + if current.position == PuncPosition.END: + return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:]) + + # POSITION == MIDDLE + if len(text) == 1: # pragma: nocover + # a corner case where the final part of an intermediate + # mark (I) has not been phonemized + return cls._restore([text[0] + current.punc], puncs[1:]) + + return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:]) + + +# if __name__ == "__main__": +# punc = Punctuation() +# text = "This is. This is, example!" + +# print(punc.strip(text)) + +# split_text, puncs = punc.strip_to_restore(text) +# print(split_text, " ---- ", puncs) + +# restored_text = punc.restore(split_text, puncs) +# print(restored_text) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f653cdf13f0d690bbc4d6348c76cddf938d1aa6e --- /dev/null +++ b/TTS/tts/utils/text/tokenizer.py @@ -0,0 +1,222 @@ +import logging +from typing import Callable, Dict, List, Union + +from TTS.tts.utils.text import cleaners +from TTS.tts.utils.text.characters import Graphemes, IPAPhonemes +from TTS.tts.utils.text.phonemizers import DEF_LANG_TO_PHONEMIZER, get_phonemizer_by_name +from TTS.tts.utils.text.phonemizers.multi_phonemizer import MultiPhonemizer +from TTS.utils.generic_utils import get_import_path, import_class + +logger = logging.getLogger(__name__) + + +class TTSTokenizer: + """🐸TTS tokenizer to convert input characters to token IDs and back. + + Token IDs for OOV chars are discarded but those are stored in `self.not_found_characters` for later. + + Args: + use_phonemes (bool): + Whether to use phonemes instead of characters. Defaults to False. + + characters (Characters): + A Characters object to use for character-to-ID and ID-to-character mappings. + + text_cleaner (callable): + A function to pre-process the text before tokenization and phonemization. Defaults to None. + + phonemizer (Phonemizer): + A phonemizer object or a dict that maps language codes to phonemizer objects. Defaults to None. + + Example: + + >>> from TTS.tts.utils.text.tokenizer import TTSTokenizer + >>> tokenizer = TTSTokenizer(use_phonemes=False, characters=Graphemes()) + >>> text = "Hello world!" + >>> ids = tokenizer.text_to_ids(text) + >>> text_hat = tokenizer.ids_to_text(ids) + >>> assert text == text_hat + """ + + def __init__( + self, + use_phonemes=False, + text_cleaner: Callable = None, + characters: "BaseCharacters" = None, + phonemizer: Union["Phonemizer", Dict] = None, + add_blank: bool = False, + use_eos_bos=False, + ): + self.text_cleaner = text_cleaner + self.use_phonemes = use_phonemes + self.add_blank = add_blank + self.use_eos_bos = use_eos_bos + self.characters = characters + self.not_found_characters = [] + self.phonemizer = phonemizer + + @property + def characters(self): + return self._characters + + @characters.setter + def characters(self, new_characters): + self._characters = new_characters + self.pad_id = self.characters.char_to_id(self.characters.pad) if self.characters.pad else None + self.blank_id = self.characters.char_to_id(self.characters.blank) if self.characters.blank else None + + def encode(self, text: str) -> List[int]: + """Encodes a string of text as a sequence of IDs.""" + token_ids = [] + for char in text: + try: + idx = self.characters.char_to_id(char) + token_ids.append(idx) + except KeyError: + # discard but store not found characters + if char not in self.not_found_characters: + self.not_found_characters.append(char) + logger.warning(text) + logger.warning("Character %s not found in the vocabulary. Discarding it.", repr(char)) + return token_ids + + def decode(self, token_ids: List[int]) -> str: + """Decodes a sequence of IDs to a string of text.""" + text = "" + for token_id in token_ids: + text += self.characters.id_to_char(token_id) + return text + + def text_to_ids(self, text: str, language: str = None) -> List[int]: # pylint: disable=unused-argument + """Converts a string of text to a sequence of token IDs. + + Args: + text(str): + The text to convert to token IDs. + + language(str): + The language code of the text. Defaults to None. + + TODO: + - Add support for language-specific processing. + + 1. Text normalizatin + 2. Phonemization (if use_phonemes is True) + 3. Add blank char between characters + 4. Add BOS and EOS characters + 5. Text to token IDs + """ + # TODO: text cleaner should pick the right routine based on the language + logger.debug("Tokenizer input text: %s", text) + if self.text_cleaner is not None: + text = self.text_cleaner(text) + logger.debug("Cleaned text: %s", text) + if self.use_phonemes: + text = self.phonemizer.phonemize(text, separator="", language=language) + logger.debug("Phonemes: %s", text) + text = self.encode(text) + if self.add_blank: + text = self.intersperse_blank_char(text, True) + if self.use_eos_bos: + text = self.pad_with_bos_eos(text) + return text + + def ids_to_text(self, id_sequence: List[int]) -> str: + """Converts a sequence of token IDs to a string of text.""" + return self.decode(id_sequence) + + def pad_with_bos_eos(self, char_sequence: List[str]): + """Pads a sequence with the special BOS and EOS characters.""" + return [self.characters.bos_id] + list(char_sequence) + [self.characters.eos_id] + + def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): + """Intersperses the blank character between characters in a sequence. + + Use the ```blank``` character if defined else use the ```pad``` character. + """ + char_to_use = self.characters.blank_id if use_blank_char else self.characters.pad + result = [char_to_use] * (len(char_sequence) * 2 + 1) + result[1::2] = char_sequence + return result + + def print_logs(self, level: int = 0): + indent = "\t" * level + logger.info("%s| add_blank: %s", indent, self.add_blank) + logger.info("%s| use_eos_bos: %s", indent, self.use_eos_bos) + logger.info("%s| use_phonemes: %s", indent, self.use_phonemes) + if self.use_phonemes: + logger.info("%s| phonemizer:", indent) + self.phonemizer.print_logs(level + 1) + if len(self.not_found_characters) > 0: + logger.info("%s| %d characters not found:", indent, len(self.not_found_characters)) + for char in self.not_found_characters: + logger.info("%s| %s", indent, char) + + @staticmethod + def init_from_config(config: "Coqpit", characters: "BaseCharacters" = None): + """Init Tokenizer object from config + + Args: + config (Coqpit): Coqpit model config. + characters (BaseCharacters): Defines the model character set. If not set, use the default options based on + the config values. Defaults to None. + """ + # init cleaners + text_cleaner = None + if isinstance(config.text_cleaner, (str, list)): + text_cleaner = getattr(cleaners, config.text_cleaner) + + # init characters + if characters is None: + # set characters based on defined characters class + if config.characters and config.characters.characters_class: + CharactersClass = import_class(config.characters.characters_class) + characters, new_config = CharactersClass.init_from_config(config) + # set characters based on config + else: + if config.use_phonemes: + # init phoneme set + characters, new_config = IPAPhonemes().init_from_config(config) + else: + # init character set + characters, new_config = Graphemes().init_from_config(config) + + else: + characters, new_config = characters.init_from_config(config) + + # set characters class + new_config.characters.characters_class = get_import_path(characters) + + # init phonemizer + phonemizer = None + if config.use_phonemes: + if "phonemizer" in config and config.phonemizer == "multi_phonemizer": + lang_to_phonemizer_name = {} + for dataset in config.datasets: + if dataset.language != "": + lang_to_phonemizer_name[dataset.language] = dataset.phonemizer + else: + raise ValueError("Multi phonemizer requires language to be set for each dataset.") + phonemizer = MultiPhonemizer(lang_to_phonemizer_name) + else: + phonemizer_kwargs = {"language": config.phoneme_language} + if "phonemizer" in config and config.phonemizer: + phonemizer = get_phonemizer_by_name(config.phonemizer, **phonemizer_kwargs) + else: + try: + phonemizer = get_phonemizer_by_name( + DEF_LANG_TO_PHONEMIZER[config.phoneme_language], **phonemizer_kwargs + ) + new_config.phonemizer = phonemizer.name() + except KeyError as e: + raise ValueError( + f"""No phonemizer found for language {config.phoneme_language}. + You may need to install a third party library for this language.""" + ) from e + + return ( + TTSTokenizer( + config.use_phonemes, text_cleaner, characters, phonemizer, config.add_blank, config.enable_eos_bos_chars + ), + new_config, + ) diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py new file mode 100644 index 0000000000000000000000000000000000000000..fba7bc508ef962d5a93c794ca868acd46d07ec16 --- /dev/null +++ b/TTS/tts/utils/visual.py @@ -0,0 +1,238 @@ +import librosa +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.colors import LogNorm + +matplotlib.use("Agg") + + +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False, plot_log=False): + if isinstance(alignment, torch.Tensor): + alignment_ = alignment.detach().cpu().numpy().squeeze() + else: + alignment_ = alignment + alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_ + fig, ax = plt.subplots(figsize=fig_size) + im = ax.imshow( + alignment_.T, aspect="auto", origin="lower", interpolation="none", norm=LogNorm() if plot_log else None + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + # plt.yticks(range(len(text)), list(text)) + plt.tight_layout() + if title is not None: + plt.title(title) + if not output_fig: + plt.close() + return fig + + +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + fig = plt.figure(figsize=fig_size) + plt.imshow(spectrogram_, aspect="auto", origin="lower") + plt.colorbar() + plt.tight_layout() + if not output_fig: + plt.close() + return fig + + +def plot_pitch(pitch, spectrogram, ap=None, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the spectrogram. + + Args: + pitch (np.array): Pitch values. + spectrogram (np.array): Spectrogram values. + + Shapes: + pitch: :math:`(T,)` + spec: :math:`(C, T)` + """ + + if isinstance(spectrogram, torch.Tensor): + spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T + else: + spectrogram_ = spectrogram.T + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + if ap is not None: + spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access + + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + ax.imshow(spectrogram_, aspect="auto", origin="lower") + ax.set_xlabel("time") + ax.set_ylabel("spec_freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False): + """Plot pitch curves on top of the input characters. + + Args: + pitch (np.array): Pitch values. + chars (str): Characters to place to the x-axis. + + Shapes: + pitch: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(pitch, linewidth=5.0, color="red") + ax2.set_ylabel("F0") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False): + """Plot energy curves on top of the input characters. + + Args: + energy (np.array): energy values. + chars (str): Characters to place to the x-axis. + + Shapes: + energy: :math:`(T,)` + """ + old_fig_size = plt.rcParams["figure.figsize"] + if fig_size is not None: + plt.rcParams["figure.figsize"] = fig_size + + fig, ax = plt.subplots() + + x = np.array(range(len(chars))) + my_xticks = chars + plt.xticks(x, my_xticks) + + ax.set_xlabel("characters") + ax.set_ylabel("freq") + + ax2 = ax.twinx() + ax2.plot(energy, linewidth=5.0, color="red") + ax2.set_ylabel("energy") + + plt.rcParams["figure.figsize"] = old_fig_size + if not output_fig: + plt.close() + return fig + + +def visualize( + alignment, + postnet_output, + text, + hop_length, + CONFIG, + tokenizer, + stop_tokens=None, + decoder_output=None, + output_path=None, + figsize=(8, 24), + output_fig=False, +): + """Intended to be used in Notebooks.""" + + if decoder_output is not None: + num_plot = 4 + else: + num_plot = 3 + + label_fontsize = 16 + fig = plt.figure(figsize=figsize) + + plt.subplot(num_plot, 1, 1) + plt.imshow(alignment.T, aspect="auto", origin="lower", interpolation=None) + plt.xlabel("Decoder timestamp", fontsize=label_fontsize) + plt.ylabel("Encoder timestamp", fontsize=label_fontsize) + # compute phoneme representation and back + if CONFIG.use_phonemes: + seq = tokenizer.text_to_ids(text) + text = tokenizer.ids_to_text(seq) + print(text) + plt.yticks(range(len(text)), list(text)) + plt.colorbar() + + if stop_tokens is not None: + # plot stopnet predictions + plt.subplot(num_plot, 1, 2) + plt.plot(range(len(stop_tokens)), list(stop_tokens)) + + # plot postnet spectrogram + plt.subplot(num_plot, 1, 3) + librosa.display.specshow( + postnet_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) + + plt.xlabel("Time", fontsize=label_fontsize) + plt.ylabel("Hz", fontsize=label_fontsize) + plt.tight_layout() + plt.colorbar() + + if decoder_output is not None: + plt.subplot(num_plot, 1, 4) + librosa.display.specshow( + decoder_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) + plt.xlabel("Time", fontsize=label_fontsize) + plt.ylabel("Hz", fontsize=label_fontsize) + plt.tight_layout() + plt.colorbar() + + if output_path: + print(output_path) + fig.savefig(output_path) + plt.close() + + if not output_fig: + plt.close() diff --git a/TTS/utils/__init__.py b/TTS/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/utils/audio/__init__.py b/TTS/utils/audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f18f22199908ee0dd5445e34527f5fddb65cfed8 --- /dev/null +++ b/TTS/utils/audio/__init__.py @@ -0,0 +1 @@ +from TTS.utils.audio.processor import AudioProcessor diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8972480c43b9a78a83fc2bf42f97cde345601d --- /dev/null +++ b/TTS/utils/audio/numpy_transforms.py @@ -0,0 +1,488 @@ +import logging +from io import BytesIO +from typing import Tuple + +import librosa +import numpy as np +import scipy +import soundfile as sf +from librosa import magphase, pyin + +logger = logging.getLogger(__name__) + +# For using kwargs +# pylint: disable=unused-argument + + +def build_mel_basis( + *, + sample_rate: int = None, + fft_size: int = None, + num_mels: int = None, + mel_fmax: int = None, + mel_fmin: int = None, + **kwargs, +) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if mel_fmax is not None: + assert mel_fmax <= sample_rate // 2 + assert mel_fmax - mel_fmin > 0 + return librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=mel_fmin, fmax=mel_fmax) + + +def millisec_to_length( + *, frame_length_ms: int = None, frame_shift_ms: int = None, sample_rate: int = None, **kwargs +) -> Tuple[int, int]: + """Compute hop and window length from milliseconds. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = frame_length_ms / frame_shift_ms + assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" + win_length = int(frame_length_ms / 1000.0 * sample_rate) + hop_length = int(win_length / float(factor)) + return win_length, hop_length + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) + + +def amp_to_db(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Decibels spectrogram. + """ + assert (x < 0).sum() == 0, " [!] Input values must be non-negative." + return gain * _log(np.maximum(1e-8, x), base) + + +# pylint: disable=no-self-use +def db_to_amp(*, x: np.ndarray = None, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + gain (float): Gain factor. Defaults to 1. + base (int): Logarithm base. Defaults to 10. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / gain, base) + + +def preemphasis(*, x: np.ndarray, coef: float = 0.97, **kwargs) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -coef], [1], x) + + +def deemphasis(*, x: np.ndarray = None, coef: float = 0.97, **kwargs) -> np.ndarray: + """Reverse pre-emphasis.""" + if coef == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -coef], x) + + +def spec_to_mel(*, spec: np.ndarray, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + spec (np.ndarray): Normalized full scale linear spectrogram. + + Shapes: + - spec: :math:`[C, T]` + + Returns: + np.ndarray: Normalized melspectrogram. + """ + return np.dot(mel_basis, spec) + + +def mel_to_spec(*, mel: np.ndarray = None, mel_basis: np.ndarray = None, **kwargs) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + assert (mel < 0).sum() == 0, " [!] Input values must be non-negative." + inv_mel_basis = np.linalg.pinv(mel_basis) + return np.maximum(1e-10, np.dot(inv_mel_basis, mel)) + + +def wav_to_spec(*, wav: np.ndarray = None, **kwargs) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + wav (np.ndarray): Waveform. Shape :math:`[T_wav,]` + + Returns: + np.ndarray: Spectrogram. Shape :math:`[C, T_spec]`. :math:`T_spec == T_wav / hop_length` + """ + D = stft(y=wav, **kwargs) + S = np.abs(D) + return S.astype(np.float32) + + +def wav_to_mel(*, wav: np.ndarray = None, mel_basis=None, **kwargs) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + D = stft(y=wav, **kwargs) + S = spec_to_mel(spec=np.abs(D), mel_basis=mel_basis, **kwargs) + return S.astype(np.float32) + + +def spec_to_wav(*, spec: np.ndarray, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = spec.copy() + return griffin_lim(spec=S**power, **kwargs) + + +def mel_to_wav(*, mel: np.ndarray = None, power: float = 1.5, **kwargs) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + S = mel.copy() + S = mel_to_spec(mel=S, mel_basis=kwargs["mel_basis"]) # Convert back to linear + return griffin_lim(spec=S**power, **kwargs) + + +### STFT and ISTFT ### +def stft( + *, + y: np.ndarray = None, + fft_size: int = None, + hop_length: int = None, + win_length: int = None, + pad_mode: str = "reflect", + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa STFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.stft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=fft_size, + hop_length=hop_length, + win_length=win_length, + pad_mode=pad_mode, + window=window, + center=center, + ) + + +def istft( + *, + y: np.ndarray = None, + hop_length: int = None, + win_length: int = None, + window: str = "hann", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Librosa iSTFT wrapper. + + Check http://librosa.org/doc/main/generated/librosa.istft.html argument details. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.istft(y, hop_length=hop_length, win_length=win_length, center=center, window=window) + + +def griffin_lim(*, spec: np.ndarray = None, num_iter=60, **kwargs) -> np.ndarray: + angles = np.exp(2j * np.pi * np.random.rand(*spec.shape)) + S_complex = np.abs(spec).astype(complex) + y = istft(y=S_complex * angles, **kwargs) + if not np.isfinite(y).all(): + logger.warning("Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(num_iter): + angles = np.exp(1j * np.angle(stft(y=y, **kwargs))) + y = istft(y=S_complex * angles, **kwargs) + return y + + +def compute_stft_paddings( + *, x: np.ndarray = None, hop_length: int = None, pad_two_sides: bool = False, **kwargs +) -> Tuple[int, int]: + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + pad = (x.shape[0] // hop_length + 1) * hop_length - x.shape[0] + if not pad_two_sides: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + +def compute_f0( + *, + x: np.ndarray = None, + pitch_fmax: float = None, + pitch_fmin: float = None, + hop_length: int = None, + win_length: int = None, + sample_rate: int = None, + stft_pad_mode: str = "reflect", + center: bool = True, + **kwargs, +) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + pitch_fmax (float): Pitch max value. + pitch_fmin (float): Pitch min value. + hop_length (int): Number of frames between STFT columns. + win_length (int): STFT window length. + sample_rate (int): Audio sampling rate. + stft_pad_mode (str): Padding mode for STFT. + center (bool): Centered padding. + + Returns: + np.ndarray: Pitch. Shape :math:`[T_pitch,]`. :math:`T_pitch == T_wav / hop_length` + + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> pitch = ap.compute_f0(wav) + """ + assert pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`." + assert pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`." + + f0, voiced_mask, _ = pyin( + y=x.astype(np.double), + fmin=pitch_fmin, + fmax=pitch_fmax, + sr=sample_rate, + frame_length=win_length, + win_length=win_length // 2, + hop_length=hop_length, + pad_mode=stft_pad_mode, + center=center, + n_thresholds=100, + beta_parameters=(2, 18), + boltzmann_parameter=2, + resolution=0.1, + max_transition_rate=35.92, + switch_prob=0.01, + no_trough_prob=0.01, + ) + f0[~voiced_mask] = 0.0 + + return f0 + + +def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray: + """Compute energy of a waveform using the same parameters used for computing melspectrogram. + Args: + x (np.ndarray): Waveform. Shape :math:`[T_wav,]` + Returns: + np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length` + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig() + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> energy = ap.compute_energy(wav) + """ + x = stft(y=y, **kwargs) + mag, _ = magphase(x) + energy = np.sqrt(np.sum(mag**2, axis=0)) + return energy + + +### Audio Processing ### +def find_endpoint( + *, + wav: np.ndarray = None, + trim_db: float = -40, + sample_rate: int = None, + min_silence_sec=0.8, + gain: float = None, + base: int = None, + **kwargs, +) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + gian (float, optional): Gain to be used to convert trim_db to trim_amp. Defaults to None. + base (int, optional): Base of the logarithm used to convert trim_db to trim_amp. Defaults to 10. + + Returns: + int: Last point without silence. + """ + window_length = int(sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = db_to_amp(x=-trim_db, gain=gain, base=base) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + +def trim_silence( + *, + wav: np.ndarray = None, + sample_rate: int = None, + trim_db: float = None, + win_length: int = None, + hop_length: int = None, + **kwargs, +) -> np.ndarray: + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=trim_db, frame_length=win_length, hop_length=hop_length)[0] + + +def volume_norm(*, x: np.ndarray = None, coef: float = 0.95, **kwargs) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + coef (float): Coefficient to rescale the maximum value. Defaults to 0.95. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * coef + + +def rms_norm(*, wav: np.ndarray = None, db_level: float = -27.0, **kwargs) -> np.ndarray: + r = 10 ** (db_level / 20) + a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) + return wav * a + + +def rms_volume_norm(*, x: np.ndarray, db_level: float = -27.0, **kwargs) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + db_level (float): Target dB level in RMS. Defaults to -27.0. + + Returns: + np.ndarray: RMS normalized waveform. + """ + assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" + wav = rms_norm(wav=x, db_level=db_level) + return wav + + +def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False, **kwargs) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + resample (bool, optional): Resample the audio file when loading. Slows down the I/O time. Defaults to False. + + Returns: + np.ndarray: Loaded waveform. + """ + if resample: + # loading with resampling. It is significantly slower. + x, _ = librosa.load(filename, sr=sample_rate) + else: + # SF is faster than librosa for loading files + x, _ = sf.read(filename) + return x + + +def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out=None, **kwargs) -> None: + """Save float waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform with float values in range [-1, 1] to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + + wav_norm = wav_norm.astype(np.int16) + if pipe_out: + wav_buffer = BytesIO() + scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm) + wav_buffer.seek(0) + pipe_out.buffer.write(wav_buffer.read()) + scipy.io.wavfile.write(path, sample_rate, wav_norm) + + +def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray: + mu = 2**mulaw_qc - 1 + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + +def mulaw_decode(*, wav, mulaw_qc: int, **kwargs) -> np.ndarray: + """Recovers waveform from quantized values.""" + mu = 2**mulaw_qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + +def encode_16bits(*, x: np.ndarray, **kwargs) -> np.ndarray: + return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) + + +def quantize(*, x: np.ndarray, quantize_bits: int, **kwargs) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + quantize_bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2**quantize_bits - 1) / 2 + + +def dequantize(*, x, quantize_bits, **kwargs) -> np.ndarray: + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2**quantize_bits - 1) - 1 diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..680e29debce6a634561e0e91002e92fa8f1736ab --- /dev/null +++ b/TTS/utils/audio/processor.py @@ -0,0 +1,630 @@ +import logging +from io import BytesIO +from typing import Dict, Tuple + +import librosa +import numpy as np +import scipy.io.wavfile +import scipy.signal + +from TTS.tts.utils.helpers import StandardScaler +from TTS.utils.audio.numpy_transforms import ( + amp_to_db, + build_mel_basis, + compute_f0, + db_to_amp, + deemphasis, + find_endpoint, + griffin_lim, + load_wav, + mel_to_spec, + millisec_to_length, + preemphasis, + rms_volume_norm, + spec_to_mel, + stft, + trim_silence, + volume_norm, +) + +logger = logging.getLogger(__name__) + +# pylint: disable=too-many-public-methods + + +class AudioProcessor(object): + """Audio Processor for TTS. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + pitch_fmin (int, optional): + minimum filter frequency for computing pitch. Defaults to None. + + pitch_fmax (int, optional): + maximum filter frequency for computing pitch. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + do_rms_norm (bool, optional): + enable/disable RMS volume normalization when loading an audio file. Defaults to False. + + db_level (int, optional): + dB level used for rms normalization. The range is -99 to 0. Defaults to None. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + pitch_fmax=None, + pitch_fmin=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + do_rms_norm=False, + db_level=None, + stats_path=None, + **_, + ): + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.pitch_fmin = pitch_fmin + self.pitch_fmax = pitch_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.do_rms_norm = do_rms_norm + self.db_level = db_level + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.win_length, self.hop_length = millisec_to_length( + frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate + ) + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert ( + self.win_length <= self.fft_size + ), f" [!] win_length cannot be larger than fft_size - {self.win_length} vs {self.fft_size}" + members = vars(self) + logger.info("Setting up Audio Processor...") + for key, value in members.items(): + logger.info(" | %s: %s", key, value) + # create spectrogram utils + self.mel_basis = build_mel_basis( + sample_rate=self.sample_rate, + fft_size=self.fft_size, + num_mels=self.num_mels, + mel_fmax=self.mel_fmax, + mel_fmin=self.mel_fmin, + ) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + @staticmethod + def init_from_config(config: "Coqpit"): + if "audio" in config: + return AudioProcessor(**config.audio) + return AudioProcessor(**config) + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. + + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. + """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + return preemphasis(x=x, coef=self.preemphasis) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + return deemphasis(x=x, coef=self.preemphasis) + + ### SPECTROGRAMs ### + def spectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. + """ + if self.preemphasis != 0: + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + if self.do_amp_to_db_linear: + S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base) + else: + S = np.abs(D) + return self.normalize(S).astype(np.float32) + + def melspectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + if self.preemphasis != 0: + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis) + if self.do_amp_to_db_mel: + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + + return self.normalize(S).astype(np.float32) + + def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = self.denormalize(spectrogram) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + # Reconstruct phase + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W + + def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + D = self.denormalize(mel_spectrogram) + S = db_to_amp(x=D, gain=self.spec_gain, base=self.base) + S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W + + def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + S = self.denormalize(linear_spec) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + mel = self.normalize(S) + return mel + + def _griffin_lim(self, S): + return griffin_lim( + spec=S, + num_iter=self.griffin_lim_iters, + hop_length=self.hop_length, + win_length=self.win_length, + fft_size=self.fft_size, + pad_mode=self.stft_pad_mode, + ) + + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. + + Examples: + >>> WAV_FILE = filename = librosa.example('vibeace') + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(pitch_fmax=640, pitch_fmin=1) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] + >>> pitch = ap.compute_f0(wav) + """ + # align F0 length to the spectrogram length + if len(x) % self.hop_length == 0: + x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode) + + f0 = compute_f0( + x=x, + pitch_fmax=self.pitch_fmax, + pitch_fmin=self.pitch_fmin, + hop_length=self.hop_length, + win_length=self.win_length, + sample_rate=self.sample_rate, + stft_pad_mode=self.stft_pad_mode, + center=True, + ) + + return f0 + + ### Audio Processing ### + def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. + """ + return find_endpoint( + wav=wav, + trim_db=self.trim_db, + sample_rate=self.sample_rate, + min_silence_sec=min_silence_sec, + gain=self.spec_gain, + base=self.base, + ) + + def trim_silence(self, wav): + """Trim silent parts with a threshold and 0.01 sec margin""" + return trim_silence( + wav=wav, + sample_rate=self.sample_rate, + trim_db=self.trim_db, + win_length=self.win_length, + hop_length=self.hop_length, + ) + + @staticmethod + def sound_norm(x: np.ndarray) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return volume_norm(x=x) + + def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: + """Normalize the volume based on RMS of the signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: RMS normalized waveform. + """ + if db_level is None: + db_level = self.db_level + return rms_volume_norm(x=x, db_level=db_level) + + ### save and load ### + def load_wav(self, filename: str, sr: int = None) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Resampling slows down loading the file significantly. Therefore it is recommended to resample the file before. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + + Returns: + np.ndarray: Loaded waveform. + """ + if sr is not None: + x = load_wav(filename=filename, sample_rate=sr, resample=True) + else: + x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + logger.exception("File cannot be trimmed for silence - %s", filename) + if self.do_sound_norm: + x = self.sound_norm(x) + if self.do_rms_norm: + x = self.rms_volume_norm(x, self.db_level) + return x + + def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out=None) -> None: + """Save a waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + if self.do_rms_norm: + wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767 + else: + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + + wav_norm = wav_norm.astype(np.int16) + if pipe_out: + wav_buffer = BytesIO() + scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm) + wav_buffer.seek(0) + pipe_out.buffer.write(wav_buffer.read()) + scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm) + + def get_duration(self, filename: str) -> float: + """Get the duration of a wav file using Librosa. + + Args: + filename (str): Path to the wav file. + """ + return librosa.get_duration(filename=filename) diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..632969c51a0426728912bdffd0193c42398d6a21 --- /dev/null +++ b/TTS/utils/audio/torch_transforms.py @@ -0,0 +1,167 @@ +import librosa +import torch +from torch import nn + + +class TorchSTFT(nn.Module): # pylint: disable=abstract-method + """Some of the audio processing funtions using Torch for faster batch processing. + + Args: + + n_fft (int): + FFT window size for STFT. + + hop_length (int): + number of frames between STFT columns. + + win_length (int, optional): + STFT window length. + + pad_wav (bool, optional): + If True pad the audio with (n_fft - hop_length) / 2). Defaults to False. + + window (str, optional): + The name of a function to create a window tensor that is applied/multiplied to each frame/window. Defaults to "hann_window" + + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms. Defaults to None. + + n_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + use_mel (bool, optional): + If True compute the melspectrograms otherwise. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to False. + + spec_gain (float, optional): + gain applied when converting amplitude to DB. Defaults to 1.0. + + power (float, optional): + Exponent for the magnitude spectrogram, e.g., 1 for energy, 2 for power, etc. Defaults to None. + + use_htk (bool, optional): + Use HTK formula in mel filter instead of Slaney. + + mel_norm (None, 'slaney', or number, optional): + If 'slaney', divide the triangular mel weights by the width of the mel band + (area normalization). + + If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. + See `librosa.util.normalize` for a full description of supported norm values + (including `+-np.inf`). + + Otherwise, leave all the triangles aiming for a peak value of 1.0. Defaults to "slaney". + """ + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + do_amp_to_db=False, + spec_gain=1.0, + power=None, + use_htk=False, + mel_norm="slaney", + normalized=False, + ): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.pad_wav = pad_wav + self.sample_rate = sample_rate + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.n_mels = n_mels + self.use_mel = use_mel + self.do_amp_to_db = do_amp_to_db + self.spec_gain = spec_gain + self.power = power + self.use_htk = use_htk + self.mel_norm = mel_norm + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) + self.mel_basis = None + self.normalized = normalized + if use_mel: + self._build_mel_basis() + + def __call__(self, x): + """Compute spectrogram frames by torch based stft. + + Args: + x (Tensor): input waveform + + Returns: + Tensor: spectrogram frames. + + Shapes: + x: [B x T] or [:math:`[B, 1, T]`] + """ + if x.ndim == 2: + x = x.unsqueeze(1) + if self.pad_wav: + padding = int((self.n_fft - self.hop_length) / 2) + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") + # B x D x T x 2 + o = torch.view_as_real( + torch.stft( + x.squeeze(1), + self.n_fft, + self.hop_length, + self.win_length, + self.window, + center=True, + pad_mode="reflect", # compatible with audio.py + normalized=self.normalized, + onesided=True, + return_complex=True, + ) + ) + M = o[:, :, :, 0] + P = o[:, :, :, 1] + S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + + if self.power is not None: + S = S**self.power + + if self.use_mel: + S = torch.matmul(self.mel_basis.to(x), S) + if self.do_amp_to_db: + S = self._amp_to_db(S, spec_gain=self.spec_gain) + return S + + def _build_mel_basis(self): + mel_basis = librosa.filters.mel( + sr=self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=self.mel_fmin, + fmax=self.mel_fmax, + htk=self.use_htk, + norm=self.mel_norm, + ) + self.mel_basis = torch.from_numpy(mel_basis).float() + + @staticmethod + def _amp_to_db(x, spec_gain=1.0): + return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + + @staticmethod + def _db_to_amp(x, spec_gain=1.0): + return torch.exp(x) / spec_gain diff --git a/TTS/utils/callbacks.py b/TTS/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..511d215c656f1ce3ed31484963db64fae4dc77d4 --- /dev/null +++ b/TTS/utils/callbacks.py @@ -0,0 +1,105 @@ +class TrainerCallback: + @staticmethod + def on_init_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_start"): + trainer.model.module.on_init_start(trainer) + else: + if hasattr(trainer.model, "on_init_start"): + trainer.model.on_init_start(trainer) + + if hasattr(trainer.criterion, "on_init_start"): + trainer.criterion.on_init_start(trainer) + + if hasattr(trainer.optimizer, "on_init_start"): + trainer.optimizer.on_init_start(trainer) + + @staticmethod + def on_init_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_init_end"): + trainer.model.module.on_init_end(trainer) + else: + if hasattr(trainer.model, "on_init_end"): + trainer.model.on_init_end(trainer) + + if hasattr(trainer.criterion, "on_init_end"): + trainer.criterion.on_init_end(trainer) + + if hasattr(trainer.optimizer, "on_init_end"): + trainer.optimizer.on_init_end(trainer) + + @staticmethod + def on_epoch_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_start"): + trainer.model.module.on_epoch_start(trainer) + else: + if hasattr(trainer.model, "on_epoch_start"): + trainer.model.on_epoch_start(trainer) + + if hasattr(trainer.criterion, "on_epoch_start"): + trainer.criterion.on_epoch_start(trainer) + + if hasattr(trainer.optimizer, "on_epoch_start"): + trainer.optimizer.on_epoch_start(trainer) + + @staticmethod + def on_epoch_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_epoch_end"): + trainer.model.module.on_epoch_end(trainer) + else: + if hasattr(trainer.model, "on_epoch_end"): + trainer.model.on_epoch_end(trainer) + + if hasattr(trainer.criterion, "on_epoch_end"): + trainer.criterion.on_epoch_end(trainer) + + if hasattr(trainer.optimizer, "on_epoch_end"): + trainer.optimizer.on_epoch_end(trainer) + + @staticmethod + def on_train_step_start(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_start"): + trainer.model.module.on_train_step_start(trainer) + else: + if hasattr(trainer.model, "on_train_step_start"): + trainer.model.on_train_step_start(trainer) + + if hasattr(trainer.criterion, "on_train_step_start"): + trainer.criterion.on_train_step_start(trainer) + + if hasattr(trainer.optimizer, "on_train_step_start"): + trainer.optimizer.on_train_step_start(trainer) + + @staticmethod + def on_train_step_end(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_train_step_end"): + trainer.model.module.on_train_step_end(trainer) + else: + if hasattr(trainer.model, "on_train_step_end"): + trainer.model.on_train_step_end(trainer) + + if hasattr(trainer.criterion, "on_train_step_end"): + trainer.criterion.on_train_step_end(trainer) + + if hasattr(trainer.optimizer, "on_train_step_end"): + trainer.optimizer.on_train_step_end(trainer) + + @staticmethod + def on_keyboard_interrupt(trainer) -> None: + if hasattr(trainer.model, "module"): + if hasattr(trainer.model.module, "on_keyboard_interrupt"): + trainer.model.module.on_keyboard_interrupt(trainer) + else: + if hasattr(trainer.model, "on_keyboard_interrupt"): + trainer.model.on_keyboard_interrupt(trainer) + + if hasattr(trainer.criterion, "on_keyboard_interrupt"): + trainer.criterion.on_keyboard_interrupt(trainer) + + if hasattr(trainer.optimizer, "on_keyboard_interrupt"): + trainer.optimizer.on_keyboard_interrupt(trainer) diff --git a/TTS/utils/capacitron_optimizer.py b/TTS/utils/capacitron_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7206ffd508896cab96a22288f33a93e999c5f009 --- /dev/null +++ b/TTS/utils/capacitron_optimizer.py @@ -0,0 +1,67 @@ +from typing import Generator + +from trainer.trainer_utils import get_optimizer + + +class CapacitronOptimizer: + """Double optimizer class for the Capacitron model.""" + + def __init__(self, config: dict, model_params: Generator) -> None: + self.primary_params, self.secondary_params = self.split_model_parameters(model_params) + + optimizer_names = list(config.optimizer_params.keys()) + optimizer_parameters = list(config.optimizer_params.values()) + + self.primary_optimizer = get_optimizer( + optimizer_names[0], + optimizer_parameters[0], + config.lr, + parameters=self.primary_params, + ) + + self.secondary_optimizer = get_optimizer( + optimizer_names[1], + self.extract_optimizer_parameters(optimizer_parameters[1]), + optimizer_parameters[1]["lr"], + parameters=self.secondary_params, + ) + + self.param_groups = self.primary_optimizer.param_groups + + def first_step(self): + self.secondary_optimizer.step() + self.secondary_optimizer.zero_grad() + self.primary_optimizer.zero_grad() + + def step(self): + # Update param groups to display the correct learning rate + self.param_groups = self.primary_optimizer.param_groups + self.primary_optimizer.step() + + def zero_grad(self, set_to_none=False): + self.primary_optimizer.zero_grad(set_to_none) + self.secondary_optimizer.zero_grad(set_to_none) + + def load_state_dict(self, state_dict): + self.primary_optimizer.load_state_dict(state_dict[0]) + self.secondary_optimizer.load_state_dict(state_dict[1]) + + def state_dict(self): + return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()] + + @staticmethod + def split_model_parameters(model_params: Generator) -> list: + primary_params = [] + secondary_params = [] + for name, param in model_params: + if param.requires_grad: + if name == "capacitron_vae_layer.beta": + secondary_params.append(param) + else: + primary_params.append(param) + return [iter(primary_params), iter(secondary_params)] + + @staticmethod + def extract_optimizer_parameters(params: dict) -> dict: + """Extract parameters that are not the learning rate""" + return {k: v for k, v in params.items() if k != "lr"} diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py new file mode 100644 index 0000000000000000000000000000000000000000..a51ef7661ece97c87c165ad1aba4c9d9700379dc --- /dev/null +++ b/TTS/utils/distribute.py @@ -0,0 +1,20 @@ +# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py +import torch +import torch.distributed as dist + + +def reduce_tensor(tensor, num_gpus): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.reduce_op.SUM) + rt /= num_gpus + return rt + + +def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): + assert torch.cuda.is_available(), "Distributed mode requires CUDA." + + # Set cuda device so everything is done on the right GPU. + torch.cuda.set_device(rank % torch.cuda.device_count()) + + # Initialize distributed communication + dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) diff --git a/TTS/utils/download.py b/TTS/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..e94b1d68c81980862aa7df792a086d426b1dc643 --- /dev/null +++ b/TTS/utils/download.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/pytorch/audio/ + +import hashlib +import logging +import os +import tarfile +import urllib +import urllib.request +import zipfile +from os.path import expanduser +from typing import Any, Iterable, List, Optional + +from torch.utils.model_zoo import tqdm + +logger = logging.getLogger(__name__) + + +def stream_url( + url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True +) -> Iterable: + """Stream url by chunk + + Args: + url (str): Url. + start_byte (int or None, optional): Start streaming at that point (Default: ``None``). + block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + """ + + # If we already have the whole file, there is no need to download it again + req = urllib.request.Request(url, method="HEAD") + with urllib.request.urlopen(req) as response: + url_size = int(response.info().get("Content-Length", -1)) + if url_size == start_byte: + return + + req = urllib.request.Request(url) + if start_byte: + req.headers["Range"] = "bytes={}-".format(start_byte) + + with ( + urllib.request.urlopen(req) as upointer, + tqdm( + unit="B", + unit_scale=True, + unit_divisor=1024, + total=url_size, + disable=not progress_bar, + ) as pbar, + ): + num_bytes = 0 + while True: + chunk = upointer.read(block_size) + if not chunk: + break + yield chunk + num_bytes += len(chunk) + pbar.update(len(chunk)) + + +def download_url( + url: str, + download_folder: str, + filename: Optional[str] = None, + hash_value: Optional[str] = None, + hash_type: str = "sha256", + progress_bar: bool = True, + resume: bool = False, +) -> None: + """Download file to disk. + + Args: + url (str): Url. + download_folder (str): Folder to download file. + filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url + (Default: ``None``). + hash_value (str or None, optional): Hash for url (Default: ``None``). + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + progress_bar (bool, optional): Display a progress bar (Default: ``True``). + resume (bool, optional): Enable resuming download (Default: ``False``). + """ + + req = urllib.request.Request(url, method="HEAD") + req_info = urllib.request.urlopen(req).info() # pylint: disable=consider-using-with + + # Detect filename + filename = filename or req_info.get_filename() or os.path.basename(url) + filepath = os.path.join(download_folder, filename) + if resume and os.path.exists(filepath): + mode = "ab" + local_size: Optional[int] = os.path.getsize(filepath) + + elif not resume and os.path.exists(filepath): + raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath)) + else: + mode = "wb" + local_size = None + + if hash_value and local_size == int(req_info.get("Content-Length", -1)): + with open(filepath, "rb") as file_obj: + if validate_file(file_obj, hash_value, hash_type): + return + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + with open(filepath, mode) as fpointer: + for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar): + fpointer.write(chunk) + + with open(filepath, "rb") as file_obj: + if hash_value and not validate_file(file_obj, hash_value, hash_type): + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath)) + + +def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool: + """Validate a given file object with its hash. + + Args: + file_obj: File object to read from. + hash_value (str): Hash for url. + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + + Returns: + bool: return True if its a valid file, else False. + """ + + if hash_type == "sha256": + hash_func = hashlib.sha256() + elif hash_type == "md5": + hash_func = hashlib.md5() + else: + raise ValueError + + while True: + # Read by chunk to avoid filling memory + chunk = file_obj.read(1024**2) + if not chunk: + break + hash_func.update(chunk) + + return hash_func.hexdigest() == hash_value + + +def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]: + """Extract archive. + Args: + from_path (str): the path of the archive. + to_path (str or None, optional): the root path of the extraced files (directory of from_path) + (Default: ``None``) + overwrite (bool, optional): overwrite existing files (Default: ``False``) + + Returns: + list: List of paths to extracted files even if not overwritten. + """ + logger.info("Extracting archive file...") + if to_path is None: + to_path = os.path.dirname(from_path) + + try: + with tarfile.open(from_path, "r") as tar: + logger.info("Opened tar file %s.", from_path) + files = [] + for file_ in tar: # type: Any + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + logger.info("%s already extracted.", file_path) + if not overwrite: + continue + tar.extract(file_, to_path) + return files + except tarfile.ReadError: + pass + + try: + with zipfile.ZipFile(from_path, "r") as zfile: + logger.info("Opened zip file %s.", from_path) + files = zfile.namelist() + for file_ in files: + file_path = os.path.join(to_path, file_) + if os.path.exists(file_path): + logger.info("%s already extracted.", file_path) + if not overwrite: + continue + zfile.extract(file_, to_path) + return files + except zipfile.BadZipFile: + pass + + raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip achives.") + + +def download_kaggle_dataset(dataset_path: str, dataset_name: str, output_path: str): + """Download dataset from kaggle. + Args: + dataset_path (str): + This the kaggle link to the dataset. for example vctk is 'mfekadu/english-multispeaker-corpus-for-voice-cloning' + dataset_name (str): Name of the folder the dataset will be saved in. + output_path (str): Path of the location you want the dataset folder to be saved to. + """ + data_path = os.path.join(output_path, dataset_name) + try: + import kaggle # pylint: disable=import-outside-toplevel + + kaggle.api.authenticate() + logger.info("Downloading %s...", dataset_name) + kaggle.api.dataset_download_files(dataset_path, path=data_path, unzip=True) + except OSError: + logger.exception( + "In order to download kaggle datasets, you need to have a kaggle api token stored in your %s", + os.path.join(expanduser("~"), ".kaggle/kaggle.json"), + ) diff --git a/TTS/utils/downloaders.py b/TTS/utils/downloaders.py new file mode 100644 index 0000000000000000000000000000000000000000..8705873982708978d08ba9924d27db7d3f528e3e --- /dev/null +++ b/TTS/utils/downloaders.py @@ -0,0 +1,123 @@ +import logging +import os +from typing import Optional + +from TTS.utils.download import download_kaggle_dataset, download_url, extract_archive + +logger = logging.getLogger(__name__) + + +def download_ljspeech(path: str): + """Download and extract LJSpeech dataset + + Args: + path (str): path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + extract_archive(archive) + + +def download_vctk(path: str, use_kaggle: Optional[bool] = False): + """Download and extract VCTK dataset. + + Args: + path (str): path to the directory where the dataset will be stored. + + use_kaggle (bool, optional): Downloads vctk dataset from kaggle. Is generally faster. Defaults to False. + """ + if use_kaggle: + download_kaggle_dataset("mfekadu/english-multispeaker-corpus-for-voice-cloning", "VCTK", path) + else: + os.makedirs(path, exist_ok=True) + url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + extract_archive(archive) + + +def download_tweb(path: str): + """Download and extract Tweb dataset + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + download_kaggle_dataset("bryanpark/the-world-english-bible-speech-dataset", "TWEB", path) + + +def download_libri_tts(path: str, subset: Optional[str] = "all"): + """Download and extract libri tts dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + subset (str, optional): Name of the subset to download. If you only want to download a certain + portion specify it here. Defaults to 'all'. + """ + + subset_dict = { + "libri-tts-clean-100": "http://www.openslr.org/resources/60/train-clean-100.tar.gz", + "libri-tts-clean-360": "http://www.openslr.org/resources/60/train-clean-360.tar.gz", + "libri-tts-other-500": "http://www.openslr.org/resources/60/train-other-500.tar.gz", + "libri-tts-dev-clean": "http://www.openslr.org/resources/60/dev-clean.tar.gz", + "libri-tts-dev-other": "http://www.openslr.org/resources/60/dev-other.tar.gz", + "libri-tts-test-clean": "http://www.openslr.org/resources/60/test-clean.tar.gz", + "libri-tts-test-other": "http://www.openslr.org/resources/60/test-other.tar.gz", + } + + os.makedirs(path, exist_ok=True) + if subset == "all": + for sub, val in subset_dict.items(): + logger.info("Downloading %s...", sub) + download_url(val, path) + basename = os.path.basename(val) + archive = os.path.join(path, basename) + extract_archive(archive) + logger.info("All subsets downloaded") + else: + url = subset_dict[subset] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + extract_archive(archive) + + +def download_thorsten_de(path: str): + """Download and extract Thorsten german male voice dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + """ + os.makedirs(path, exist_ok=True) + url = "https://www.openslr.org/resources/95/thorsten-de_v02.tgz" + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + extract_archive(archive) + + +def download_mailabs(path: str, language: str = "english"): + """Download and extract Mailabs dataset. + + Args: + path (str): Path to the directory where the dataset will be stored. + + language (str): Language subset to download. Defaults to english. + """ + language_dict = { + "english": "https://data.solak.de/data/Training/stt_tts/en_US.tgz", + "german": "https://data.solak.de/data/Training/stt_tts/de_DE.tgz", + "french": "https://data.solak.de/data/Training/stt_tts/fr_FR.tgz", + "italian": "https://data.solak.de/data/Training/stt_tts/it_IT.tgz", + "spanish": "https://data.solak.de/data/Training/stt_tts/es_ES.tgz", + } + os.makedirs(path, exist_ok=True) + url = language_dict[language] + download_url(url, path) + basename = os.path.basename(url) + archive = os.path.join(path, basename) + extract_archive(archive) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee285232f4efa796c48fab99a9423b33abe67b9 --- /dev/null +++ b/TTS/utils/generic_utils.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +import datetime +import importlib +import logging +import re +from pathlib import Path +from typing import Dict, Optional + +import torch +from packaging.version import Version + +logger = logging.getLogger(__name__) + + +def to_camel(text): + text = text.capitalize() + text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + text = text.replace("Tts", "TTS") + text = text.replace("vc", "VC") + return text + + +def find_module(module_path: str, module_name: str) -> object: + module_name = module_name.lower() + module = importlib.import_module(module_path + "." + module_name) + class_name = to_camel(module_name) + return getattr(module, class_name) + + +def import_class(module_path: str) -> object: + """Import a class from a module path. + + Args: + module_path (str): The module path of the class. + + Returns: + object: The imported class. + """ + class_name = module_path.split(".")[-1] + module_path = ".".join(module_path.split(".")[:-1]) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def get_import_path(obj: object) -> str: + """Get the import path of a class. + + Args: + obj (object): The class object. + + Returns: + str: The import path of the class. + """ + return ".".join([type(obj).__module__, type(obj).__name__]) + + +def set_init_dict(model_dict, checkpoint_state, c): + # Partial initialization: if there is a mismatch with new and old layer, it is skipped. + for k, v in checkpoint_state.items(): + if k not in model_dict: + logger.warning("Layer missing in the model finition %s", k) + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} + # 2. filter out different size layers + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} + # 3. skip reinit layers + if c.has("reinit_layers") and c.reinit_layers is not None: + for reinit_layer_name in c.reinit_layers: + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} + # 4. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) + return model_dict + + +def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: + """Format kwargs to hande auxilary inputs to models. + + Args: + def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`. + kwargs (Dict): A `dict` or `kwargs` that includes auxilary inputs to the model. + + Returns: + Dict: arguments with formatted auxilary inputs. + """ + kwargs = kwargs.copy() + for name in def_args: + if name not in kwargs or kwargs[name] is None: + kwargs[name] = def_args[name] + return kwargs + + +def get_timestamp() -> str: + return datetime.datetime.now().strftime("%y%m%d-%H%M%S") + + +class ConsoleFormatter(logging.Formatter): + """Custom formatter that prints logging.INFO messages without the level name. + + Source: https://stackoverflow.com/a/62488520 + """ + + def format(self, record): + if record.levelno == logging.INFO: + self._style._fmt = "%(message)s" + else: + self._style._fmt = "%(levelname)s: %(message)s" + return super().format(record) + + +def setup_logger( + logger_name: str, + level: int = logging.INFO, + *, + formatter: Optional[logging.Formatter] = None, + screen: bool = False, + tofile: bool = False, + log_dir: str = "logs", + log_name: str = "log", +) -> None: + lg = logging.getLogger(logger_name) + if formatter is None: + formatter = logging.Formatter( + "%(asctime)s.%(msecs)03d - %(levelname)-8s - %(name)s: %(message)s", datefmt="%y-%m-%d %H:%M:%S" + ) + lg.setLevel(level) + if tofile: + Path(log_dir).mkdir(exist_ok=True, parents=True) + log_file = Path(log_dir) / f"{log_name}_{get_timestamp()}.log" + fh = logging.FileHandler(log_file, mode="w") + fh.setFormatter(formatter) + lg.addHandler(fh) + if screen: + sh = logging.StreamHandler() + sh.setFormatter(formatter) + lg.addHandler(sh) + + +def is_pytorch_at_least_2_4() -> bool: + """Check if the installed Pytorch version is 2.4 or higher.""" + return Version(torch.__version__) >= Version("2.4") diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5071d9b01ee5f7cc50142b95f89a11f4ebad81 --- /dev/null +++ b/TTS/utils/manage.py @@ -0,0 +1,611 @@ +import json +import logging +import os +import re +import tarfile +import zipfile +from pathlib import Path +from shutil import copyfile, rmtree +from typing import Dict, Tuple + +import fsspec +import requests +from tqdm import tqdm +from trainer.io import get_user_data_dir + +from TTS.config import load_config, read_json_with_comments + +logger = logging.getLogger(__name__) + +LICENSE_URLS = { + "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/", + "mpl": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/", + "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/", + "mit": "https://choosealicense.com/licenses/mit/", + "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/", + "apache2": "https://choosealicense.com/licenses/apache-2.0/", + "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/", + "cpml": "https://coqui.ai/cpml.txt", +} + + +class ModelManager(object): + tqdm_progress = None + """Manage TTS models defined in .models.json. + It provides an interface to list and download + models defines in '.model.json' + + Models are downloaded under '.TTS' folder in the user's + home path. + + Args: + models_file (str): path to .model.json file. Defaults to None. + output_prefix (str): prefix to `tts` to download models. Defaults to None + progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. + """ + + def __init__(self, models_file=None, output_prefix=None, progress_bar=False): + super().__init__() + self.progress_bar = progress_bar + if output_prefix is None: + self.output_prefix = get_user_data_dir("tts") + else: + self.output_prefix = os.path.join(output_prefix, "tts") + self.models_dict = None + if models_file is not None: + self.read_models_file(models_file) + else: + # try the default location + path = Path(__file__).parent / "../.models.json" + self.read_models_file(path) + + def read_models_file(self, file_path): + """Read .models.json as a dict + + Args: + file_path (str): path to .models.json. + """ + self.models_dict = read_json_with_comments(file_path) + + def _list_models(self, model_type, model_count=0): + logger.info("") + logger.info("Name format: type/language/dataset/model") + model_list = [] + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + output_path = Path(self.output_prefix) / model_full_name + downloaded = " [already downloaded]" if output_path.is_dir() else "" + logger.info(" %2d: %s/%s/%s/%s%s", model_count, model_type, lang, dataset, model, downloaded) + model_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + return model_list + + def _list_for_model_type(self, model_type): + models_name_list = [] + model_count = 1 + models_name_list.extend(self._list_models(model_type, model_count)) + return models_name_list + + def list_models(self): + models_name_list = [] + model_count = 1 + for model_type in self.models_dict: + model_list = self._list_models(model_type, model_count) + models_name_list.extend(model_list) + return models_name_list + + def log_model_details(self, model_type, lang, dataset, model): + logger.info("Model type: %s", model_type) + logger.info("Language supported: %s", lang) + logger.info("Dataset used: %s", dataset) + logger.info("Model name: %s", model) + if "description" in self.models_dict[model_type][lang][dataset][model]: + logger.info("Description: %s", self.models_dict[model_type][lang][dataset][model]["description"]) + else: + logger.info("Description: coming soon") + if "default_vocoder" in self.models_dict[model_type][lang][dataset][model]: + logger.info( + "Default vocoder: %s", + self.models_dict[model_type][lang][dataset][model]["default_vocoder"], + ) + + def model_info_by_idx(self, model_query): + """Print the description of the model from .models.json file using model_query_idx + + Args: + model_query (str): / + """ + model_name_list = [] + model_type, model_query_idx = model_query.split("/") + try: + model_query_idx = int(model_query_idx) + if model_query_idx <= 0: + logger.error("model_query_idx [%d] should be a positive integer!", model_query_idx) + return + except (TypeError, ValueError): + logger.error("model_query_idx [%s] should be an integer!", model_query_idx) + return + model_count = 0 + if model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + for model in self.models_dict[model_type][lang][dataset]: + model_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") + model_count += 1 + else: + logger.error("Model type %s does not exist in the list.", model_type) + return + if model_query_idx > model_count: + logger.error("model_query_idx exceeds the number of available models [%d]", model_count) + else: + model_type, lang, dataset, model = model_name_list[model_query_idx - 1].split("/") + self.log_model_details(model_type, lang, dataset, model) + + def model_info_by_full_name(self, model_query_name): + """Print the description of the model from .models.json file using model_full_name + + Args: + model_query_name (str): Format is /// + """ + model_type, lang, dataset, model = model_query_name.split("/") + if model_type not in self.models_dict: + logger.error("Model type %s does not exist in the list.", model_type) + return + if lang not in self.models_dict[model_type]: + logger.error("Language %s does not exist for %s.", lang, model_type) + return + if dataset not in self.models_dict[model_type][lang]: + logger.error("Dataset %s does not exist for %s/%s.", dataset, model_type, lang) + return + if model not in self.models_dict[model_type][lang][dataset]: + logger.error("Model %s does not exist for %s/%s/%s.", model, model_type, lang, dataset) + return + self.log_model_details(model_type, lang, dataset, model) + + def list_tts_models(self): + """Print all `TTS` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("tts_models") + + def list_vocoder_models(self): + """Print all the `vocoder` models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("vocoder_models") + + def list_vc_models(self): + """Print all the voice conversion models and return a list of model names + + Format is `language/dataset/model` + """ + return self._list_for_model_type("voice_conversion_models") + + def list_langs(self): + """Print all the available languages""" + logger.info("Name format: type/language") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + logger.info(" %s/%s", model_type, lang) + + def list_datasets(self): + """Print all the datasets""" + logger.info("Name format: type/language/dataset") + for model_type in self.models_dict: + for lang in self.models_dict[model_type]: + for dataset in self.models_dict[model_type][lang]: + logger.info(" %s/%s/%s", model_type, lang, dataset) + + @staticmethod + def print_model_license(model_item: Dict): + """Print the license of a model + + Args: + model_item (dict): model item in the models.json + """ + if "license" in model_item and model_item["license"].strip() != "": + logger.info("Model's license - %s", model_item["license"]) + if model_item["license"].lower() in LICENSE_URLS: + logger.info("Check %s for more info.", LICENSE_URLS[model_item["license"].lower()]) + else: + logger.info("Check https://opensource.org/licenses for more info.") + else: + logger.info("Model's license - No license information available") + + def _download_github_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["github_rls_url"], list): + self._download_model_files(model_item["github_rls_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) + + def _download_hf_model(self, model_item: Dict, output_path: str): + if isinstance(model_item["hf_url"], list): + self._download_model_files(model_item["hf_url"], output_path, self.progress_bar) + else: + self._download_zip_file(model_item["hf_url"], output_path, self.progress_bar) + + def download_fairseq_model(self, model_name, output_path): + URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/" + _, lang, _, _ = model_name.split("/") + model_download_uri = os.path.join(URI_PREFIX, f"{lang}.tar.gz") + self._download_tar_file(model_download_uri, output_path, self.progress_bar) + + @staticmethod + def set_model_url(model_item: Dict): + model_item["model_url"] = None + if "github_rls_url" in model_item: + model_item["model_url"] = model_item["github_rls_url"] + elif "hf_url" in model_item: + model_item["model_url"] = model_item["hf_url"] + elif "fairseq" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/" + elif "xtts" in model_item["model_name"]: + model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/" + return model_item + + def _set_model_item(self, model_name): + # fetch model info from the dict + if "fairseq" in model_name: + model_type, lang, dataset, model = model_name.split("/") + model_item = { + "model_type": "tts_models", + "license": "CC BY-NC 4.0", + "default_vocoder": None, + "author": "fairseq", + "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.", + } + model_item["model_name"] = model_name + elif "xtts" in model_name and len(model_name.split("/")) != 4: + # loading xtts models with only model name (e.g. xtts_v2.0.2) + # check model name has the version number with regex + version_regex = r"v\d+\.\d+\.\d+" + if re.search(version_regex, model_name): + model_version = model_name.split("_")[-1] + else: + model_version = "main" + model_type = "tts_models" + lang = "multilingual" + dataset = "multi-dataset" + model = model_name + model_item = { + "default_vocoder": None, + "license": "CPML", + "contact": "info@coqui.ai", + "tos_required": True, + "hf_url": [ + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5", + f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth", + ], + } + else: + # get model from models.json + model_type, lang, dataset, model = model_name.split("/") + model_item = self.models_dict[model_type][lang][dataset][model] + model_item["model_type"] = model_type + + model_full_name = f"{model_type}--{lang}--{dataset}--{model}" + md5hash = model_item["model_hash"] if "model_hash" in model_item else None + model_item = self.set_model_url(model_item) + return model_item, model_full_name, model, md5hash + + @staticmethod + def ask_tos(model_full_path): + """Ask the user to agree to the terms of service""" + tos_path = os.path.join(model_full_path, "tos_agreed.txt") + print(" > You must confirm the following:") + print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"') + print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]') + answer = input(" | | > ") + if answer.lower() == "y": + with open(tos_path, "w", encoding="utf-8") as f: + f.write("I have read, understood and agreed to the Terms and Conditions.") + return True + return False + + @staticmethod + def tos_agreed(model_item, model_full_path): + """Check if the user has agreed to the terms of service""" + if "tos_required" in model_item and model_item["tos_required"]: + tos_path = os.path.join(model_full_path, "tos_agreed.txt") + if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1": + return True + return False + return True + + def create_dir_and_download_model(self, model_name, model_item, output_path): + os.makedirs(output_path, exist_ok=True) + # handle TOS + if not self.tos_agreed(model_item, output_path): + if not self.ask_tos(output_path): + os.rmdir(output_path) + raise Exception(" [!] You must agree to the terms of service to use this model.") + logger.info("Downloading model to %s", output_path) + try: + if "fairseq" in model_name: + self.download_fairseq_model(model_name, output_path) + elif "github_rls_url" in model_item: + self._download_github_model(model_item, output_path) + elif "hf_url" in model_item: + self._download_hf_model(model_item, output_path) + + except requests.RequestException as e: + logger.exception("Failed to download the model file to %s", output_path) + rmtree(output_path) + raise e + self.print_model_license(model_item=model_item) + + def check_if_configs_are_equal(self, model_name, model_item, output_path): + with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f: + config_local = json.load(f) + remote_url = None + for url in model_item["hf_url"]: + if "config.json" in url: + remote_url = url + break + + with fsspec.open(remote_url, "r", encoding="utf-8") as f: + config_remote = json.load(f) + + if not config_local == config_remote: + logger.info("%s is already downloaded however it has been changed. Redownloading it...", model_name) + self.create_dir_and_download_model(model_name, model_item, output_path) + + def download_model(self, model_name): + """Download model files given the full model name. + Model name is in the format + 'type/language/dataset/model' + e.g. 'tts_model/en/ljspeech/tacotron' + + Every model must have the following files: + - *.pth : pytorch model checkpoint file. + - config.json : model config file. + - scale_stats.npy (if exist): scale values for preprocessing. + + Args: + model_name (str): model name as explained above. + """ + model_item, model_full_name, model, md5sum = self._set_model_item(model_name) + # set the model specific output path + output_path = os.path.join(self.output_prefix, model_full_name) + if os.path.exists(output_path): + if md5sum is not None: + md5sum_file = os.path.join(output_path, "hash.md5") + if os.path.isfile(md5sum_file): + with open(md5sum_file, mode="r") as f: + if not f.read() == md5sum: + logger.info("%s has been updated, clearing model cache...", model_name) + self.create_dir_and_download_model(model_name, model_item, output_path) + else: + logger.info("%s is already downloaded.", model_name) + else: + logger.info("%s has been updated, clearing model cache...", model_name) + self.create_dir_and_download_model(model_name, model_item, output_path) + # if the configs are different, redownload it + # ToDo: we need a better way to handle it + if "xtts" in model_name: + try: + self.check_if_configs_are_equal(model_name, model_item, output_path) + except: + pass + else: + logger.info("%s is already downloaded.", model_name) + else: + self.create_dir_and_download_model(model_name, model_item, output_path) + + # find downloaded files + output_model_path = output_path + output_config_path = None + if ( + model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name + ): # TODO:This is stupid but don't care for now. + output_model_path, output_config_path = self._find_files(output_path) + # update paths in the config.json + self._update_paths(output_path, output_config_path) + return output_model_path, output_config_path, model_item + + @staticmethod + def _find_files(output_path: str) -> Tuple[str, str]: + """Find the model and config files in the output path + + Args: + output_path (str): path to the model files + + Returns: + Tuple[str, str]: path to the model file and config file + """ + model_file = None + config_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + model_file = os.path.join(output_path, file_name) + elif file_name == "config.json": + config_file = os.path.join(output_path, file_name) + if model_file is None: + raise ValueError(" [!] Model file not found in the output path") + if config_file is None: + raise ValueError(" [!] Config file not found in the output path") + return model_file, config_file + + @staticmethod + def _find_speaker_encoder(output_path: str) -> str: + """Find the speaker encoder file in the output path + + Args: + output_path (str): path to the model files + + Returns: + str: path to the speaker encoder file + """ + speaker_encoder_file = None + for file_name in os.listdir(output_path): + if file_name in ["model_se.pth", "model_se.pth.tar"]: + speaker_encoder_file = os.path.join(output_path, file_name) + return speaker_encoder_file + + def _update_paths(self, output_path: str, config_path: str) -> None: + """Update paths for certain files in config.json after download. + + Args: + output_path (str): local path the model is downloaded to. + config_path (str): local config.json path. + """ + output_stats_path = os.path.join(output_path, "scale_stats.npy") + output_d_vector_file_path = os.path.join(output_path, "speakers.json") + output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth") + output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth") + speaker_encoder_config_path = os.path.join(output_path, "config_se.json") + speaker_encoder_model_path = self._find_speaker_encoder(output_path) + + # update the scale_path.npy file path in the model config.json + self._update_path("audio.stats_path", output_stats_path, config_path) + + # update the speakers.json file path in the model config.json to the current path + self._update_path("d_vector_file", output_d_vector_file_path, config_path) + self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path) + + # update the speaker_ids.json file path in the model config.json to the current path + self._update_path("speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path) + + # update the speaker_encoder file path in the model config.json to the current path + self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path) + self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path) + self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path) + + @staticmethod + def _update_path(field_name, new_path, config_path): + """Update the path in the model config.json for the current environment after download""" + if new_path and os.path.exists(new_path): + config = load_config(config_path) + field_names = field_name.split(".") + if len(field_names) > 1: + # field name points to a sub-level field + sub_conf = config + for fd in field_names[:-1]: + if fd in sub_conf: + sub_conf = sub_conf[fd] + else: + return + if isinstance(sub_conf[field_names[-1]], list): + sub_conf[field_names[-1]] = [new_path] + else: + sub_conf[field_names[-1]] = new_path + else: + # field name points to a top-level field + if field_name not in config: + return + if isinstance(config[field_name], list): + config[field_name] = [new_path] + else: + config[field_name] = new_path + config.save_json(config_path) + + @staticmethod + def _download_zip_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_zip_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + with zipfile.ZipFile(temp_zip_name) as z: + z.extractall(output_folder) + os.remove(temp_zip_name) # delete zip after extract + except zipfile.BadZipFile: + logger.exception("Bad zip file - %s", file_url) + raise zipfile.BadZipFile # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in z.namelist(): + src_path = os.path.join(output_folder, file_path) + if os.path.isfile(src_path): + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove redundant (hidden or not) folders + for file_path in z.namelist(): + if os.path.isdir(os.path.join(output_folder, file_path)): + rmtree(os.path.join(output_folder, file_path)) + + @staticmethod + def _download_tar_file(file_url, output_folder, progress_bar): + """Download the github releases""" + # download the file + r = requests.get(file_url, stream=True) + # extract the file + try: + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1]) + with open(temp_tar_name, "wb") as file: + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + with tarfile.open(temp_tar_name) as t: + t.extractall(output_folder) + tar_names = t.getnames() + os.remove(temp_tar_name) # delete tar after extract + except tarfile.ReadError: + logger.exception("Bad tar file - %s", file_url) + raise tarfile.ReadError # pylint: disable=raise-missing-from + # move the files to the outer path + for file_path in os.listdir(os.path.join(output_folder, tar_names[0])): + src_path = os.path.join(output_folder, tar_names[0], file_path) + dst_path = os.path.join(output_folder, os.path.basename(file_path)) + if src_path != dst_path: + copyfile(src_path, dst_path) + # remove the extracted folder + rmtree(os.path.join(output_folder, tar_names[0])) + + @staticmethod + def _download_model_files(file_urls, output_folder, progress_bar): + """Download the github releases""" + for file_url in file_urls: + # download the file + r = requests.get(file_url, stream=True) + # extract the file + bease_filename = file_url.split("/")[-1] + temp_zip_name = os.path.join(output_folder, bease_filename) + total_size_in_bytes = int(r.headers.get("content-length", 0)) + block_size = 1024 # 1 Kibibyte + with open(temp_zip_name, "wb") as file: + if progress_bar: + ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + for data in r.iter_content(block_size): + if progress_bar: + ModelManager.tqdm_progress.update(len(data)) + file.write(data) + + @staticmethod + def _check_dict_key(my_dict, key): + if key in my_dict.keys() and my_dict[key] is not None: + if not isinstance(key, str): + return True + if isinstance(key, str) and len(my_dict[key]) > 0: + return True + return False diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd14990f33cb671f030e401a3a2f9b96c2710cd --- /dev/null +++ b/TTS/utils/radam.py @@ -0,0 +1,105 @@ +# modified from https://github.com/LiyuanLucasLiu/RAdam + +import math + +import torch +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if eps < 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)] + ) + super().__init__(params, defaults) + + def __setstate__(self, state): # pylint: disable=useless-super-delegation + super().__setstate__(state) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError("RAdam does not support sparse gradients") + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state["step"]) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) + p.data.copy_(p_data_fp32) + + return loss diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..b08a763a33e40d00a32577e31ffc12b8e228bc46 --- /dev/null +++ b/TTS/utils/samplers.py @@ -0,0 +1,201 @@ +import math +import random +from typing import Callable, List, Union + +from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler + + +class SubsetSampler(Sampler): + """ + Samples elements sequentially from a given list of indices. + + Args: + indices (list): a sequence of indices + """ + + def __init__(self, indices): + super().__init__(indices) + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in range(len(self.indices))) + + def __len__(self): + return len(self.indices) + + +class PerfectBatchSampler(Sampler): + """ + Samples a mini-batch of indices for a balanced class batching + + Args: + dataset_items(list): dataset items to sample from. + classes (list): list of classes of dataset_items to sample from. + batch_size (int): total number of samples to be sampled in a mini-batch. + num_gpus (int): number of GPU in the data parallel mode. + shuffle (bool): if True, samples randomly, otherwise samples sequentially. + drop_last (bool): if True, drops last incomplete batch. + """ + + def __init__( + self, + dataset_items, + classes, + batch_size, + num_classes_in_batch, + num_gpus=1, + shuffle=True, + drop_last=False, + label_key="class_name", + ): + super().__init__(dataset_items) + assert ( + batch_size % (num_classes_in_batch * num_gpus) == 0 + ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)." + + label_indices = {} + for idx, item in enumerate(dataset_items): + label = item[label_key] + if label not in label_indices.keys(): + label_indices[label] = [idx] + else: + label_indices[label].append(idx) + + if shuffle: + self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes] + else: + self._samplers = [SubsetSampler(label_indices[key]) for key in classes] + + self._batch_size = batch_size + self._drop_last = drop_last + self._dp_devices = num_gpus + self._num_classes_in_batch = num_classes_in_batch + + def __iter__(self): + batch = [] + if self._num_classes_in_batch != len(self._samplers): + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + else: + valid_samplers_idx = None + + iters = [iter(s) for s in self._samplers] + done = False + + while True: + b = [] + for i, it in enumerate(iters): + if valid_samplers_idx is not None and i not in valid_samplers_idx: + continue + idx = next(it, None) + if idx is None: + done = True + break + b.append(idx) + if done: + break + batch += b + if len(batch) == self._batch_size: + yield batch + batch = [] + if valid_samplers_idx is not None: + valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch) + + if not self._drop_last: + if len(batch) > 0: + groups = len(batch) // self._num_classes_in_batch + if groups % self._dp_devices == 0: + yield batch + else: + batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch] + if len(batch) > 0: + yield batch + + def __len__(self): + class_batch_size = self._batch_size // self._num_classes_in_batch + return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers) + + +def identity(x): + return x + + +class SortedSampler(Sampler): + """Samples elements sequentially, always in the same order. + + Taken from https://github.com/PetrochukM/PyTorch-NLP + + Args: + data (iterable): Iterable data. + sort_key (callable): Specifies a function of one argument that is used to extract a + numerical comparison key from each list element. + + Example: + >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] + + """ + + def __init__(self, data, sort_key: Callable = identity): + super().__init__(data) + self.data = data + self.sort_key = sort_key + zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] + zip_ = sorted(zip_, key=lambda r: r[1]) + self.sorted_indexes = [item[0] for item in zip_] + + def __iter__(self): + return iter(self.sorted_indexes) + + def __len__(self): + return len(self.data) + + +class BucketBatchSampler(BatchSampler): + """Bucket batch sampler + + Adapted from https://github.com/PetrochukM/PyTorch-NLP + + Args: + sampler (torch.data.utils.sampler.Sampler): + batch_size (int): Size of mini-batch. + drop_last (bool): If `True` the sampler will drop the last batch if its size would be less + than `batch_size`. + data (list): List of data samples. + sort_key (callable, optional): Callable to specify a comparison key for sorting. + bucket_size_multiplier (int, optional): Buckets are of size + `batch_size * bucket_size_multiplier`. + + Example: + >>> sampler = WeightedRandomSampler(weights, len(weights)) + >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True) + """ + + def __init__( + self, + sampler, + data, + batch_size, + drop_last, + sort_key: Union[Callable, List] = identity, + bucket_size_multiplier=100, + ): + super().__init__(sampler, batch_size, drop_last) + self.data = data + self.sort_key = sort_key + _bucket_size = batch_size * bucket_size_multiplier + if hasattr(sampler, "__len__"): + _bucket_size = min(_bucket_size, len(sampler)) + self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) + + def __iter__(self): + for idxs in self.bucket_sampler: + bucket_data = [self.data[idx] for idx in idxs] + sorted_sampler = SortedSampler(bucket_data, self.sort_key) + for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))): + sorted_idxs = [idxs[i] for i in batch_idx] + yield sorted_idxs + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + return math.ceil(len(self.sampler) / self.batch_size) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..90af4f48f9d9990dda3cead5084bf139c2c820c4 --- /dev/null +++ b/TTS/utils/synthesizer.py @@ -0,0 +1,505 @@ +import logging +import os +import time +from typing import List + +import numpy as np +import pysbd +import torch +from torch import nn + +from TTS.config import load_config +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.models import setup_model as setup_tts_model +from TTS.tts.models.vits import Vits +from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.models import setup_model as setup_vc_model +from TTS.vocoder.models import setup_model as setup_vocoder_model +from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input + +logger = logging.getLogger(__name__) + + +class Synthesizer(nn.Module): + def __init__( + self, + tts_checkpoint: str = "", + tts_config_path: str = "", + tts_speakers_file: str = "", + tts_languages_file: str = "", + vocoder_checkpoint: str = "", + vocoder_config: str = "", + encoder_checkpoint: str = "", + encoder_config: str = "", + vc_checkpoint: str = "", + vc_config: str = "", + model_dir: str = "", + voice_dir: str = None, + use_cuda: bool = False, + ) -> None: + """General 🐸 TTS interface for inference. It takes a tts and a vocoder + model and synthesize speech from the provided text. + + The text is divided into a list of sentences using `pysbd` and synthesize + speech on each sentence separately. + + If you have certain special characters in your text, you need to handle + them before providing the text to Synthesizer. + + TODO: set the segmenter based on the source language + + Args: + tts_checkpoint (str, optional): path to the tts model file. + tts_config_path (str, optional): path to the tts config file. + vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None. + vocoder_config (str, optional): path to the vocoder config file. Defaults to None. + encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`, + encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`, + vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`, + vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, + use_cuda (bool, optional): enable/disable cuda. Defaults to False. + """ + super().__init__() + self.tts_checkpoint = tts_checkpoint + self.tts_config_path = tts_config_path + self.tts_speakers_file = tts_speakers_file + self.tts_languages_file = tts_languages_file + self.vocoder_checkpoint = vocoder_checkpoint + self.vocoder_config = vocoder_config + self.encoder_checkpoint = encoder_checkpoint + self.encoder_config = encoder_config + self.vc_checkpoint = vc_checkpoint + self.vc_config = vc_config + self.use_cuda = use_cuda + + self.tts_model = None + self.vocoder_model = None + self.vc_model = None + self.speaker_manager = None + self.tts_speakers = {} + self.language_manager = None + self.num_languages = 0 + self.tts_languages = {} + self.d_vector_dim = 0 + self.seg = self._get_segmenter("en") + self.use_cuda = use_cuda + self.voice_dir = voice_dir + if self.use_cuda: + assert torch.cuda.is_available(), "CUDA is not availabe on this machine." + + if tts_checkpoint: + self._load_tts(tts_checkpoint, tts_config_path, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + + if vocoder_checkpoint: + self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] + + if vc_checkpoint: + self._load_vc(vc_checkpoint, vc_config, use_cuda) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] + + if model_dir: + if "fairseq" in model_dir: + self._load_fairseq_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] + else: + self._load_tts_from_dir(model_dir, use_cuda) + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] + + @staticmethod + def _get_segmenter(lang: str): + """get the sentence segmenter for the given language. + + Args: + lang (str): target language code. + + Returns: + [type]: [description] + """ + return pysbd.Segmenter(language=lang, clean=True) + + def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None: + """Load the voice conversion model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + vc_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.vc_config = load_config(vc_config_path) + self.vc_model = setup_vc_model(config=self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) + if use_cuda: + self.vc_model.cuda() + + def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the fairseq model from a directory. + + We assume it is VITS and the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + self.tts_config = VitsConfig() + self.tts_model = Vits.init_from_config(self.tts_config) + self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) + self.tts_config = self.tts_model.config + if use_cuda: + self.tts_model.cuda() + + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: + """Load the TTS model from a directory. + + We assume the model knows how to load itself from the directory and there is a config.json file in the directory. + """ + config = load_config(os.path.join(model_dir, "config.json")) + self.tts_config = config + self.tts_model = setup_tts_model(config) + self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) + if use_cuda: + self.tts_model.cuda() + + def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None: + """Load the TTS model. + + 1. Load the model config. + 2. Init the model from the config. + 3. Load the model weights. + 4. Move the model to the GPU if CUDA is enabled. + 5. Init the speaker manager in the model. + + Args: + tts_checkpoint (str): path to the model checkpoint. + tts_config_path (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + # pylint: disable=global-statement + self.tts_config = load_config(tts_config_path) + if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: + raise ValueError("Phonemizer is not defined in the TTS config.") + + self.tts_model = setup_tts_model(config=self.tts_config) + + if not self.encoder_checkpoint: + self._set_speaker_encoder_paths_from_tts_config() + + self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True) + if use_cuda: + self.tts_model.cuda() + + if self.encoder_checkpoint and hasattr(self.tts_model, "speaker_manager"): + self.tts_model.speaker_manager.init_encoder(self.encoder_checkpoint, self.encoder_config, use_cuda) + + def _set_speaker_encoder_paths_from_tts_config(self): + """Set the encoder paths from the tts model config for models with speaker encoders.""" + if hasattr(self.tts_config, "model_args") and hasattr( + self.tts_config.model_args, "speaker_encoder_config_path" + ): + self.encoder_checkpoint = self.tts_config.model_args.speaker_encoder_model_path + self.encoder_config = self.tts_config.model_args.speaker_encoder_config_path + + def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None: + """Load the vocoder model. + + 1. Load the vocoder config. + 2. Init the AudioProcessor for the vocoder. + 3. Init the vocoder model from the config. + 4. Move the model to the GPU if CUDA is enabled. + + Args: + model_file (str): path to the model checkpoint. + model_config (str): path to the model config file. + use_cuda (bool): enable/disable CUDA use. + """ + self.vocoder_config = load_config(model_config) + self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) + self.vocoder_model = setup_vocoder_model(self.vocoder_config) + self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) + if use_cuda: + self.vocoder_model.cuda() + + def split_into_sentences(self, text) -> List[str]: + """Split give text into sentences. + + Args: + text (str): input text in string format. + + Returns: + List[str]: list of sentences. + """ + return self.seg.segment(text) + + def save_wav(self, wav: List[int], path: str, pipe_out=None) -> None: + """Save the waveform as a file. + + Args: + wav (List[int]): waveform as a list of values. + path (str): output path to save the waveform. + pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe. + """ + # if tensor convert to numpy + if torch.is_tensor(wav): + wav = wav.cpu().numpy() + if isinstance(wav, list): + wav = np.array(wav) + save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out) + + def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]: + output_wav = self.vc_model.voice_conversion(source_wav, target_wav) + return output_wav + + def tts( + self, + text: str = "", + speaker_name: str = "", + language_name: str = "", + speaker_wav=None, + style_wav=None, + style_text=None, + reference_wav=None, + reference_speaker_name=None, + split_sentences: bool = True, + **kwargs, + ) -> List[int]: + """🐸 TTS magic. Run all the models and generate speech. + + Args: + text (str): input text. + speaker_name (str, optional): speaker id for multi-speaker models. Defaults to "". + language_name (str, optional): language id for multi-language models. Defaults to "". + speaker_wav (Union[str, List[str]], optional): path to the speaker wav for voice cloning. Defaults to None. + style_wav ([type], optional): style waveform for GST. Defaults to None. + style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None. + reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None. + reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None. + split_sentences (bool, optional): split the input text into sentences. Defaults to True. + **kwargs: additional arguments to pass to the TTS model. + Returns: + List[int]: [description] + """ + start_time = time.time() + wavs = [] + + if not text and not reference_wav: + raise ValueError( + "You need to define either `text` (for sythesis) or a `reference_wav` (for voice conversion) to use the Coqui TTS API." + ) + + if text: + sens = [text] + if split_sentences: + sens = self.split_into_sentences(text) + logger.info("Text split into sentences.") + logger.info("Input: %s", sens) + + # handle multi-speaker + if "voice_dir" in kwargs: + self.voice_dir = kwargs["voice_dir"] + kwargs.pop("voice_dir") + speaker_embedding = None + speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): + if speaker_name and isinstance(speaker_name, str) and not self.tts_config.model == "xtts": + if self.tts_config.use_d_vector_file: + # get the average speaker embedding from the saved d_vectors. + speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding( + speaker_name, num_samples=None, randomize=False + ) + speaker_embedding = np.array(speaker_embedding)[None, :] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name] + # handle Neon models with single speaker. + elif len(self.tts_model.speaker_manager.name_to_id) == 1: + speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0] + elif not speaker_name and not speaker_wav: + raise ValueError( + " [!] Looks like you are using a multi-speaker model. " + "You need to define either a `speaker_idx` or a `speaker_wav` to use a multi-speaker model." + ) + else: + speaker_embedding = None + else: + if speaker_name and self.voice_dir is None: + raise ValueError( + f" [!] Missing speakers.json file path for selecting speaker {speaker_name}." + "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. " + ) + + # handle multi-lingual + language_id = None + if self.tts_languages_file or ( + hasattr(self.tts_model, "language_manager") + and self.tts_model.language_manager is not None + and not self.tts_config.model == "xtts" + ): + if len(self.tts_model.language_manager.name_to_id) == 1: + language_id = list(self.tts_model.language_manager.name_to_id.values())[0] + + elif language_name and isinstance(language_name, str): + try: + language_id = self.tts_model.language_manager.name_to_id[language_name] + except KeyError as e: + raise ValueError( + f" [!] Looks like you use a multi-lingual model. " + f"Language {language_name} is not in the available languages: " + f"{self.tts_model.language_manager.name_to_id.keys()}." + ) from e + + elif not language_name: + raise ValueError( + " [!] Look like you use a multi-lingual model. " + "You need to define either a `language_name` or a `style_wav` to use a multi-lingual model." + ) + + else: + raise ValueError( + f" [!] Missing language_ids.json file path for selecting language {language_name}." + "Define path for language_ids.json if it is a multi-lingual model or remove defined language idx. " + ) + + # compute a new d_vector from the given clip. + if ( + speaker_wav is not None + and self.tts_model.speaker_manager is not None + and hasattr(self.tts_model.speaker_manager, "encoder_ap") + and self.tts_model.speaker_manager.encoder_ap is not None + ): + speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + + vocoder_device = "cpu" + use_gl = self.vocoder_model is None + if not use_gl: + vocoder_device = next(self.vocoder_model.parameters()).device + if self.use_cuda: + vocoder_device = "cuda" + + if not reference_wav: # not voice conversion + for sen in sens: + if hasattr(self.tts_model, "synthesize"): + outputs = self.tts_model.synthesize( + text=sen, + config=self.tts_config, + speaker_id=speaker_name, + voice_dirs=self.voice_dir, + d_vector=speaker_embedding, + speaker_wav=speaker_wav, + language=language_name, + **kwargs, + ) + else: + # synthesize voice + outputs = synthesis( + model=self.tts_model, + text=sen, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + speaker_id=speaker_id, + style_wav=style_wav, + style_text=style_text, + use_griffin_lim=use_gl, + d_vector=speaker_embedding, + language_id=language_id, + ) + waveform = outputs["wav"] + if not use_gl: + mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + logger.info("Interpolating TTS model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl: + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + waveform = waveform.squeeze() + + # trim silence + if "do_trim_silence" in self.tts_config.audio and self.tts_config.audio["do_trim_silence"]: + waveform = trim_silence(waveform, self.tts_model.ap) + + wavs += list(waveform) + wavs += [0] * 10000 + else: + # get the speaker embedding or speaker id for the reference wav file + reference_speaker_embedding = None + reference_speaker_id = None + if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"): + if reference_speaker_name and isinstance(reference_speaker_name, str): + if self.tts_config.use_d_vector_file: + # get the speaker embedding from the saved d_vectors. + reference_speaker_embedding = self.tts_model.speaker_manager.get_embeddings_by_name( + reference_speaker_name + )[0] + reference_speaker_embedding = np.array(reference_speaker_embedding)[ + None, : + ] # [1 x embedding_dim] + else: + # get speaker idx from the speaker name + reference_speaker_id = self.tts_model.speaker_manager.name_to_id[reference_speaker_name] + else: + reference_speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip( + reference_wav + ) + outputs = transfer_voice( + model=self.tts_model, + CONFIG=self.tts_config, + use_cuda=self.use_cuda, + reference_wav=reference_wav, + speaker_id=speaker_id, + d_vector=speaker_embedding, + use_griffin_lim=use_gl, + reference_speaker_id=reference_speaker_id, + reference_d_vector=reference_speaker_embedding, + ) + waveform = outputs + if not use_gl: + mel_postnet_spec = outputs[0].detach().cpu().numpy() + # denormalize tts output based on tts audio config + mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T + # renormalize spectrogram based on vocoder config + vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) + # compute scale factor for possible sample rate mismatch + scale_factor = [ + 1, + self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, + ] + if scale_factor[1] != 1: + logger.info("Interpolating TTS model output.") + vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) + else: + vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable + # run vocoder model + # [1, T, C] + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"): + waveform = waveform.cpu() + if not use_gl: + waveform = waveform.numpy() + wavs = waveform.squeeze() + + # compute stats + process_time = time.time() - start_time + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] + logger.info("Processing time: %.3f", process_time) + logger.info("Real-time factor: %.3f", process_time / audio_time) + return wavs diff --git a/TTS/utils/training.py b/TTS/utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..57885005f110932015da9bc2a44682ba34601b72 --- /dev/null +++ b/TTS/utils/training.py @@ -0,0 +1,48 @@ +import logging + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): + r"""Check model gradient against unexpected jumps and failures""" + skip_flag = False + if ignore_stopnet: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_( + [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + else: + if not amp_opt_params: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) + else: + grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) + + # compatibility with different torch versions + if isinstance(grad_norm, float): + if np.isinf(grad_norm): + logger.warning("Gradient is INF !!") + skip_flag = True + else: + if torch.isinf(grad_norm): + logger.warning("Gradient is INF !!") + skip_flag = True + return grad_norm, skip_flag + + +def gradual_training_scheduler(global_step, config): + """Setup the gradual training schedule wrt number + of active GPUs""" + num_gpus = torch.cuda.device_count() + if num_gpus == 0: + num_gpus = 1 + new_values = None + # we set the scheduling wrt num_gpus + for values in config.gradual_training: + if global_step * num_gpus >= values[0]: + new_values = values + return new_values[1], new_values[2] diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py new file mode 100644 index 0000000000000000000000000000000000000000..49c8dc6b6647a86ab2f1d851542f7ae897dc427c --- /dev/null +++ b/TTS/utils/vad.py @@ -0,0 +1,92 @@ +import logging + +import torch +import torchaudio + +logger = logging.getLogger(__name__) + + +def read_audio(path): + wav, sr = torchaudio.load(path) + + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + + return wav.squeeze(0), sr + + +def resample_wav(wav, sr, new_sr): + wav = wav.unsqueeze(0) + transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr) + wav = transform(wav) + return wav.squeeze(0) + + +def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False): + factor = new_sr / vad_sr + new_timestamps = [] + if just_begging_end and timestamps: + # get just the start and end timestamps + new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)} + new_timestamps.append(new_dict) + else: + for ts in timestamps: + # map to the new SR + new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)} + new_timestamps.append(new_dict) + + return new_timestamps + + +def get_vad_model_and_utils(use_cuda=False, use_onnx=False): + model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=use_onnx, force_onnx_cpu=True + ) + if use_cuda: + model = model.cuda() + + get_speech_timestamps, save_audio, _, _, collect_chunks = utils + return model, get_speech_timestamps, save_audio, collect_chunks + + +def remove_silence( + model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False +): + # get the VAD model and utils functions + model, get_speech_timestamps, _, collect_chunks = model_and_utils + + # read ground truth wav and resample the audio for the VAD + try: + wav, gt_sample_rate = read_audio(audio_path) + except Exception: + logger.exception("Failed to read %s", audio_path) + return None, False + + # if needed, resample the audio for the VAD model + if gt_sample_rate != vad_sample_rate: + wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate) + else: + wav_vad = wav + + if use_cuda: + wav_vad = wav_vad.cuda() + + # get speech timestamps from full audio file + speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768) + + # map the current speech_timestamps to the sample rate of the ground truth audio + new_speech_timestamps = map_timestamps_to_new_sr( + vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end + ) + + # if have speech timestamps else save the wav + if new_speech_timestamps: + wav = collect_chunks(new_speech_timestamps, wav) + is_speech = True + else: + logger.warning("The file %s probably does not have speech please check it!", audio_path) + is_speech = False + + # save + torchaudio.save(out_path, wav[None, :], gt_sample_rate) + return out_path, is_speech diff --git a/TTS/vc/configs/__init__.py b/TTS/vc/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py new file mode 100644 index 0000000000000000000000000000000000000000..207181b303982f260c46619bc8ac470f5e950223 --- /dev/null +++ b/TTS/vc/configs/freevc_config.py @@ -0,0 +1,278 @@ +from dataclasses import dataclass, field +from typing import List, Optional + +from coqpit import Coqpit + +from TTS.vc.configs.shared_configs import BaseVCConfig + + +@dataclass +class FreeVCAudioConfig(Coqpit): + """Audio configuration + + Args: + max_wav_value (float): + The maximum value of the waveform. + + input_sample_rate (int): + The sampling rate of the input waveform. + + output_sample_rate (int): + The sampling rate of the output waveform. + + filter_length (int): + The length of the filter. + + hop_length (int): + The hop length. + + win_length (int): + The window length. + + n_mel_channels (int): + The number of mel channels. + + mel_fmin (float): + The minimum frequency of the mel filterbank. + + mel_fmax (Optional[float]): + The maximum frequency of the mel filterbank. + """ + + max_wav_value: float = field(default=32768.0) + input_sample_rate: int = field(default=16000) + output_sample_rate: int = field(default=24000) + filter_length: int = field(default=1280) + hop_length: int = field(default=320) + win_length: int = field(default=1280) + n_mel_channels: int = field(default=80) + mel_fmin: float = field(default=0.0) + mel_fmax: Optional[float] = field(default=None) + + +@dataclass +class FreeVCArgs(Coqpit): + """FreeVC model arguments + + Args: + spec_channels (int): + The number of channels in the spectrogram. + + inter_channels (int): + The number of channels in the intermediate layers. + + hidden_channels (int): + The number of channels in the hidden layers. + + filter_channels (int): + The number of channels in the filter layers. + + n_heads (int): + The number of attention heads. + + n_layers (int): + The number of layers. + + kernel_size (int): + The size of the kernel. + + p_dropout (float): + The dropout probability. + + resblock (str): + The type of residual block. + + resblock_kernel_sizes (List[int]): + The kernel sizes for the residual blocks. + + resblock_dilation_sizes (List[List[int]]): + The dilation sizes for the residual blocks. + + upsample_rates (List[int]): + The upsample rates. + + upsample_initial_channel (int): + The number of channels in the initial upsample layer. + + upsample_kernel_sizes (List[int]): + The kernel sizes for the upsample layers. + + n_layers_q (int): + The number of layers in the quantization network. + + use_spectral_norm (bool): + Whether to use spectral normalization. + + gin_channels (int): + The number of channels in the global conditioning vector. + + ssl_dim (int): + The dimension of the self-supervised learning embedding. + + use_spk (bool): + Whether to use external speaker encoder. + """ + + spec_channels: int = field(default=641) + inter_channels: int = field(default=192) + hidden_channels: int = field(default=192) + filter_channels: int = field(default=768) + n_heads: int = field(default=2) + n_layers: int = field(default=6) + kernel_size: int = field(default=3) + p_dropout: float = field(default=0.1) + resblock: str = field(default="1") + resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2]) + upsample_initial_channel: int = field(default=512) + upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4]) + n_layers_q: int = field(default=3) + use_spectral_norm: bool = field(default=False) + gin_channels: int = field(default=256) + ssl_dim: int = field(default=1024) + use_spk: bool = field(default=False) + num_spks: int = field(default=0) + segment_size: int = field(default=8960) + + +@dataclass +class FreeVCConfig(BaseVCConfig): + """Defines parameters for FreeVC End2End TTS model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (FreeVCArgs): + Model architecture arguments. Defaults to `FreeVCArgs()`. + + audio (FreeVCAudioConfig): + Audio processing configuration. Defaults to `FreeVCAudioConfig()`. + + grad_clip (List): + Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`. + + lr_gen (float): + Initial learning rate for the generator. Defaults to 0.0002. + + lr_disc (float): + Initial learning rate for the discriminator. Defaults to 0.0002. + + lr_scheduler_gen (str): + Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_gen_params (dict): + Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + lr_scheduler_disc (str): + Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to + `ExponentialLR`. + + lr_scheduler_disc_params (dict): + Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`. + + scheduler_after_epoch (bool): + If true, step the schedulers after each epoch else after each step. Defaults to `False`. + + optimizer (str): + Name of the optimizer to use with both the generator and the discriminator networks. One of the + `torch.optim.*`. Defaults to `AdamW`. + + kl_loss_alpha (float): + Loss weight for KL loss. Defaults to 1.0. + + disc_loss_alpha (float): + Loss weight for the discriminator loss. Defaults to 1.0. + + gen_loss_alpha (float): + Loss weight for the generator loss. Defaults to 1.0. + + feat_loss_alpha (float): + Loss weight for the feature matching loss. Defaults to 1.0. + + mel_loss_alpha (float): + Loss weight for the mel loss. Defaults to 45.0. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities + by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`. + It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + test_sentences (List[List]): + List of sentences with speaker and language information to be used for testing. + + language_ids_file (str): + Path to the language ids file. + + use_language_embedding (bool): + If true, language embedding is used. Defaults to `False`. + + Note: + Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + + Example: + + >>> from TTS.vc.configs.freevc_config import FreeVCConfig + >>> config = FreeVCConfig() + """ + + model: str = "freevc" + # model specific params + model_args: FreeVCArgs = field(default_factory=FreeVCArgs) + audio: FreeVCAudioConfig = field(default_factory=FreeVCAudioConfig) + + # optimizer + # TODO with training support + + # loss params + # TODO with training support + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + speakers_file: str = None + speaker_embedding_channels: int = 256 + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: List[str] = None + d_vector_dim: int = None + + def __post_init__(self): + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/vc/configs/shared_configs.py b/TTS/vc/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..b2fe63d29d3d2e36ddefd8da3b9599d5267adcc0 --- /dev/null +++ b/TTS/vc/configs/shared_configs.py @@ -0,0 +1,153 @@ +from dataclasses import dataclass, field +from typing import List + +from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig + + +@dataclass +class BaseVCConfig(BaseTrainingConfig): + """Shared parameters among all the tts models. + + Args: + + audio (BaseAudioConfig): + Audio processor config object instance. + + batch_group_size (int): + Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence + length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to + prevent using the same batches for each epoch. + + loss_masking (bool): + enable / disable masking loss values against padded segments of samples in a batch. + + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0. + + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). + + compute_f0 (int): + (Not in use yet). + + compute_energy (int): + (Not in use yet). + + compute_linear_spec (bool): + If True data loader computes and returns linear spectrograms alongside the other data. + + precompute_num_workers (int): + Number of workers to precompute features. Defaults to 0. + + use_noise_augment (bool): + Augment the input audio with random noise. + + start_by_longest (bool): + If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues. + Defaults to False. + + shuffle (bool): + If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True. + + drop_last (bool): + If True, the data loader will drop the last batch if it is not complete. It helps to prevent + issues that emerge from the partial batch statistics. Defaults to True. + + add_blank (bool): + Add blank characters between each other two characters. It improves performance for some models at expense + of slower run-time due to the longer input sequence. + + datasets (List[BaseDatasetConfig]): + List of datasets used for training. If multiple datasets are provided, they are merged and used together + for training. + + optimizer (str): + Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`. + Defaults to ``. + + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + + lr_scheduler (str): + Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or + `TTS.utils.training`. Defaults to ``. + + lr_scheduler_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`. + + test_sentences (List[str]): + List of sentences to be used at testing. Defaults to '[]' + + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + + use_speaker_weighted_sampler (bool): + Enable / Disable the batch balancer by speaker. Defaults to ```False```. + + speaker_weighted_sampler_alpha (float): + Number that control the influence of the speaker sampler weights. Defaults to ```1.0```. + + use_language_weighted_sampler (bool): + Enable / Disable the batch balancer by language. Defaults to ```False```. + + language_weighted_sampler_alpha (float): + Number that control the influence of the language sampler weights. Defaults to ```1.0```. + + use_length_weighted_sampler (bool): + Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided + into 10 buckets considering the min and max audio of the dataset. The sampler weights will be + computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```. + + length_weighted_sampler_alpha (float): + Number that control the influence of the length sampler weights. Defaults to ```1.0```. + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # training params + batch_group_size: int = 0 + loss_masking: bool = None + # dataloading + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") + compute_f0: bool = False + compute_energy: bool = False + compute_linear_spec: bool = False + precompute_num_workers: int = 0 + use_noise_augment: bool = False + start_by_longest: bool = False + shuffle: bool = False + drop_last: bool = False + # dataset + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # optimizer + optimizer: str = "radam" + optimizer_params: dict = None + # scheduler + lr_scheduler: str = None + lr_scheduler_params: dict = field(default_factory=lambda: {}) + # testing + test_sentences: List[str] = field(default_factory=lambda: []) + # evaluation + eval_split_max_size: int = None + eval_split_size: float = 0.01 + # weighted samplers + use_speaker_weighted_sampler: bool = False + speaker_weighted_sampler_alpha: float = 1.0 + use_language_weighted_sampler: bool = False + language_weighted_sampler_alpha: float = 1.0 + use_length_weighted_sampler: bool = False + length_weighted_sampler_alpha: float = 1.0 diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a498b292b7ef40738b83fada8cd32de31d48be2f --- /dev/null +++ b/TTS/vc/models/__init__.py @@ -0,0 +1,20 @@ +import importlib +import logging +import re +from typing import Dict, List, Union + +logger = logging.getLogger(__name__) + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": + logger.info("Using model: %s", config.model) + # fetch the right model implementation. + if "model" in config and config["model"].lower() == "freevc": + MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC + model = MyModel.init_from_config(config, samples) + return model diff --git a/TTS/vc/models/base_vc.py b/TTS/vc/models/base_vc.py new file mode 100644 index 0000000000000000000000000000000000000000..22ffd0095ccf9b0cf21e7135bc18541e74d0687b --- /dev/null +++ b/TTS/vc/models/base_vc.py @@ -0,0 +1,435 @@ +import logging +import os +import random +from typing import Any, Optional, Union + +import torch +import torch.distributed as dist +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.sampler import WeightedRandomSampler +from trainer.torch import DistributedSampler, DistributedSamplerWrapper +from trainer.trainer import Trainer + +from TTS.model import BaseTrainerModel +from TTS.tts.datasets.dataset import TTSDataset +from TTS.tts.utils.data import get_length_balancer_weights +from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights +from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio.processor import AudioProcessor + +# pylint: skip-file + +logger = logging.getLogger(__name__) + + +class BaseVC(BaseTrainerModel): + """Base `vc` class. Every new `vc` model must inherit this. + + It defines common `vc` specific functions on top of `Model` implementation. + """ + + MODEL_TYPE = "vc" + + def __init__( + self, + config: Coqpit, + ap: AudioProcessor, + speaker_manager: Optional[SpeakerManager] = None, + language_manager: Optional[LanguageManager] = None, + ) -> None: + super().__init__() + self.config = config + self.ap = ap + self.speaker_manager = speaker_manager + self.language_manager = language_manager + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit) -> None: + """Setup model args based on the config type (`ModelConfig` or `ModelArgs`). + + `ModelArgs` has all the fields reuqired to initialize the model architecture. + + `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + self.config = config + self.args = config.model_args + elif "Args" in config.__class__.__name__: + self.args = config + else: + raise ValueError("config must be either a *Config or *Args") + + def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None: + """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining + `in_channels` size of the connected layers. + + This implementation yields 3 possible outcomes: + + 1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing. + 2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512. + 3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of + `config.d_vector_dim` or 512. + + You can override this function for new models. + + Args: + config (Coqpit): Model configuration. + """ + # set number of speakers + if self.speaker_manager is not None: + self.num_speakers = self.speaker_manager.num_speakers + elif hasattr(config, "num_speakers"): + self.num_speakers = config.num_speakers + + # set ultimate speaker embedding size + if config.use_speaker_embedding or config.use_d_vector_file: + self.embedded_speaker_dim = ( + config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 + ) + # init speaker embedding layer + if config.use_speaker_embedding and not config.use_d_vector_file: + logger.info("Init speaker_embedding layer.") + self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) + self.speaker_embedding.weight.data.normal_(0, 0.3) + + def get_aux_input(self, **kwargs: Any) -> dict[str, Any]: + """Prepare and return `aux_input` used by `forward()`""" + return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None} + + def get_aux_input_from_test_sentences(self, sentence_info: Union[str, list[str]]) -> dict[str, Any]: + if hasattr(self.config, "model_args"): + config = self.config.model_args + else: + config = self.config + + # extract speaker and language info + text, speaker_name, style_wav, language_name = None, None, None, None + + if isinstance(sentence_info, list): + if len(sentence_info) == 1: + text = sentence_info[0] + elif len(sentence_info) == 2: + text, speaker_name = sentence_info + elif len(sentence_info) == 3: + text, speaker_name, style_wav = sentence_info + elif len(sentence_info) == 4: + text, speaker_name, style_wav, language_name = sentence_info + else: + text = sentence_info + + # get speaker id/d_vector + speaker_id, d_vector, language_id = None, None, None + if self.speaker_manager is not None: + if config.use_d_vector_file: + if speaker_name is None: + d_vector = self.speaker_manager.get_random_embedding() + else: + d_vector = self.speaker_manager.get_mean_embedding(speaker_name) + elif config.use_speaker_embedding: + if speaker_name is None: + speaker_id = self.speaker_manager.get_random_id() + else: + speaker_id = self.speaker_manager.name_to_id[speaker_name] + + # get language id + if self.language_manager is not None and config.use_language_embedding and language_name is not None: + language_id = self.language_manager.name_to_id[language_name] + + return { + "text": text, + "speaker_id": speaker_id, + "style_wav": style_wav, + "d_vector": d_vector, + "language_id": language_id, + } + + def format_batch(self, batch: dict[str, Any]) -> dict[str, Any]: + """Generic batch formatting for `VCDataset`. + + You must override this if you use a custom dataset. + + Args: + batch (dict): [description] + + Returns: + dict: [description] + """ + # setup input batch + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] + speaker_names = batch["speaker_names"] + linear_input = batch["linear"] + mel_input = batch["mel"] + mel_lengths = batch["mel_lengths"] + stop_targets = batch["stop_targets"] + item_idx = batch["item_idxs"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + attn_mask = batch["attns"] + waveform = batch["waveform"] + pitch = batch["pitch"] + energy = batch["energy"] + language_ids = batch["language_ids"] + max_text_length = torch.max(text_lengths.float()) + max_spec_length = torch.max(mel_lengths.float()) + + # compute durations from attention masks + durations = None + if attn_mask is not None: + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur + + # set stop targets wrt reduction factor + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() + + return { + "text_input": text_input, + "text_lengths": text_lengths, + "speaker_names": speaker_names, + "mel_input": mel_input, + "mel_lengths": mel_lengths, + "linear_input": linear_input, + "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, + "attn_mask": attn_mask, + "durations": durations, + "speaker_ids": speaker_ids, + "d_vectors": d_vectors, + "max_text_length": float(max_text_length), + "max_spec_length": float(max_spec_length), + "item_idx": item_idx, + "waveform": waveform, + "pitch": pitch, + "energy": energy, + "language_ids": language_ids, + "audio_unique_names": batch["audio_unique_names"], + } + + def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus: int = 1): + weights = None + data_items = dataset.samples + + if getattr(config, "use_language_weighted_sampler", False): + alpha = getattr(config, "language_weighted_sampler_alpha", 1.0) + logger.info("Using Language weighted sampler with alpha: %.2f", alpha) + weights = get_language_balancer_weights(data_items) * alpha + + if getattr(config, "use_speaker_weighted_sampler", False): + alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0) + logger.info("Using Speaker weighted sampler with alpha: %.2f", alpha) + if weights is not None: + weights += get_speaker_balancer_weights(data_items) * alpha + else: + weights = get_speaker_balancer_weights(data_items) * alpha + + if getattr(config, "use_length_weighted_sampler", False): + alpha = getattr(config, "length_weighted_sampler_alpha", 1.0) + logger.info("Using Length weighted sampler with alpha: %.2f", alpha) + if weights is not None: + weights += get_length_balancer_weights(data_items) * alpha + else: + weights = get_length_balancer_weights(data_items) * alpha + + if weights is not None: + sampler = WeightedRandomSampler(weights, len(weights)) + else: + sampler = None + + # sampler for DDP + if sampler is None: + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + else: # If a sampler is already defined use this sampler and DDP sampler together + sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler + + return sampler + + def get_data_loader( + self, + config: Coqpit, + assets: dict, + is_eval: bool, + samples: Union[list[dict], list[list]], + verbose: bool, + num_gpus: int, + rank: Optional[int] = None, + ) -> "DataLoader": + if is_eval and not config.run_eval: + loader = None + else: + # setup multi-speaker attributes + if self.speaker_manager is not None: + if hasattr(config, "model_args"): + speaker_id_mapping = ( + self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None + ) + d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None + config.use_d_vector_file = config.model_args.use_d_vector_file + else: + speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None + d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None + else: + speaker_id_mapping = None + d_vector_mapping = None + + # setup multi-lingual attributes + if self.language_manager is not None: + language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None + else: + language_id_mapping = None + + # init dataloader + dataset = TTSDataset( + outputs_per_step=config.r if "r" in config else 1, + compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec, + compute_f0=config.get("compute_f0", False), + f0_cache_path=config.get("f0_cache_path", None), + compute_energy=config.get("compute_energy", False), + energy_cache_path=config.get("energy_cache_path", None), + samples=samples, + ap=self.ap, + return_wav=config.return_wav if "return_wav" in config else False, + batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, + phoneme_cache_path=config.phoneme_cache_path, + precompute_num_workers=config.precompute_num_workers, + use_noise_augment=False if is_eval else config.use_noise_augment, + speaker_id_mapping=speaker_id_mapping, + d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None, + tokenizer=None, + start_by_longest=config.start_by_longest, + language_id_mapping=language_id_mapping, + ) + + # wait all the DDP process to be ready + if num_gpus > 1: + dist.barrier() + + # sort input sequences from short to long + dataset.preprocess_samples() + + # get samplers + sampler = self.get_sampler(config, dataset, num_gpus) + + loader = DataLoader( + dataset, + batch_size=config.eval_batch_size if is_eval else config.batch_size, + shuffle=config.shuffle if sampler is None else False, # if there is no other sampler + collate_fn=dataset.collate_fn, + drop_last=config.drop_last, # setting this False might cause issues in AMP training. + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def _get_test_aux_input( + self, + ) -> dict[str, Any]: + d_vector = None + if self.speaker_manager is not None and self.config.use_d_vector_file: + d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings] + d_vector = (random.sample(sorted(d_vector), 1),) + + aux_inputs = { + "speaker_id": ( + None + if not self.config.use_speaker_embedding + else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1) + ), + "d_vector": d_vector, + "style_wav": None, # TODO: handle GST style input + } + return aux_inputs + + def test_run(self, assets: dict) -> tuple[dict, dict]: + """Generic test run for `vc` models used by `Trainer`. + + You can override this for a different behaviour. + + Args: + assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`. + + Returns: + tuple[dict, dict]: Test figures and audios to be projected to Tensorboard. + """ + logger.info("Synthesizing test sentences.") + test_audios = {} + test_figures = {} + test_sentences = self.config.test_sentences + aux_inputs = self._get_test_aux_input() + for idx, sen in enumerate(test_sentences): + if isinstance(sen, list): + aux_inputs = self.get_aux_input_from_test_sentences(sen) + sen = aux_inputs["text"] + outputs_dict = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + use_griffin_lim=True, + do_trim_silence=False, + ) + test_audios["{}-audio".format(idx)] = outputs_dict["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment( + outputs_dict["outputs"]["alignments"], output_fig=False + ) + return test_figures, test_audios + + def on_init_start(self, trainer: Trainer) -> None: + """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths.""" + if self.speaker_manager is not None: + output_path = os.path.join(trainer.output_path, "speakers.pth") + self.speaker_manager.save_ids_to_file(output_path) + trainer.config.speakers_file = output_path + # some models don't have `model_args` set + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.speakers_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + logger.info("`speakers.pth` is saved to %s", output_path) + logger.info("`speakers_file` is updated in the config.json.") + + if self.language_manager is not None: + output_path = os.path.join(trainer.output_path, "language_ids.json") + self.language_manager.save_ids_to_file(output_path) + trainer.config.language_ids_file = output_path + if hasattr(trainer.config, "model_args"): + trainer.config.model_args.language_ids_file = output_path + trainer.config.save_json(os.path.join(trainer.output_path, "config.json")) + logger.info("`language_ids.json` is saved to %s", output_path) + logger.info("`language_ids_file` is updated in the config.json.") diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py new file mode 100644 index 0000000000000000000000000000000000000000..e5cfdc1e61692d5521eaae8160ee7e9ba98c84c0 --- /dev/null +++ b/TTS/vc/models/freevc.py @@ -0,0 +1,565 @@ +import logging +from typing import Dict, List, Optional, Tuple, Union + +import librosa +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import Conv1d, Conv2d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec + +import TTS.vc.modules.freevc.commons as commons +import TTS.vc.modules.freevc.modules as modules +from TTS.tts.utils.helpers import sequence_mask +from TTS.tts.utils.speakers import SpeakerManager +from TTS.vc.configs.freevc_config import FreeVCConfig +from TTS.vc.models.base_vc import BaseVC +from TTS.vc.modules.freevc.commons import init_weights +from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch +from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx +from TTS.vc.modules.freevc.wavlm import get_wavlm +from TTS.vocoder.models.hifigan_generator import get_padding + +logger = logging.getLogger(__name__) + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class Encoder(nn.Module): + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0 + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + logger.info("Removing weight norm...") + for l in self.ups: + remove_parametrizations(l, "weight") + for l in self.resblocks: + remove_parametrizations(l, "weight") + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class SpeakerEncoder(torch.nn.Module): + def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256): + super(SpeakerEncoder, self).__init__() + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + def forward(self, mels): + self.lstm.flatten_parameters() + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + def compute_partial_slices(self, total_frames, partial_frames, partial_hop): + mel_slices = [] + for i in range(0, total_frames - partial_frames, partial_hop): + mel_range = torch.arange(i, i + partial_frames) + mel_slices.append(mel_range) + + return mel_slices + + def embed_utterance(self, mel, partial_frames=128, partial_hop=64): + mel_len = mel.size(1) + last_mel = mel[:, -partial_frames:] + + if mel_len > partial_frames: + mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop) + mels = list(mel[:, s] for s in mel_slices) + mels.append(last_mel) + mels = torch.stack(tuple(mels), 0).squeeze(1) + + with torch.no_grad(): + partial_embeds = self(mels) + embed = torch.mean(partial_embeds, axis=0).unsqueeze(0) + # embed = embed / torch.linalg.norm(embed, 2) + else: + with torch.no_grad(): + embed = self(last_mel) + + return embed + + +class FreeVC(BaseVC): + """ + + Papaer:: + https://arxiv.org/abs/2210.15418# + + Paper Abstract:: + Voice conversion (VC) can be achieved by first extracting source content information and target speaker + information, and then reconstructing waveform with these information. However, current approaches normally + either extract dirty content information with speaker information leaked in, or demand a large amount of + annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the + mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for + high-quality waveform reconstruction, and propose strategies for clean content information extraction without + text annotation. We disentangle content information by imposing an information bottleneck to WavLM features, + and propose the spectrogram-resize based data augmentation to improve the purity of extracted content + information. Experimental results show that the proposed method outperforms the latest VC models trained with + annotated data and has greater robustness. + + Original Code:: + https://github.com/OlaWod/FreeVC + + Examples: + >>> from TTS.vc.configs.freevc_config import FreeVCConfig + >>> from TTS.vc.models.freevc import FreeVC + >>> config = FreeVCConfig() + >>> model = FreeVC(config) + """ + + def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None): + super().__init__(config, None, speaker_manager, None) + + self.init_multispeaker(config) + + self.spec_channels = self.args.spec_channels + self.inter_channels = self.args.inter_channels + self.hidden_channels = self.args.hidden_channels + self.filter_channels = self.args.filter_channels + self.n_heads = self.args.n_heads + self.n_layers = self.args.n_layers + self.kernel_size = self.args.kernel_size + self.p_dropout = self.args.p_dropout + self.resblock = self.args.resblock + self.resblock_kernel_sizes = self.args.resblock_kernel_sizes + self.resblock_dilation_sizes = self.args.resblock_dilation_sizes + self.upsample_rates = self.args.upsample_rates + self.upsample_initial_channel = self.args.upsample_initial_channel + self.upsample_kernel_sizes = self.args.upsample_kernel_sizes + self.segment_size = self.args.segment_size + self.gin_channels = self.args.gin_channels + self.ssl_dim = self.args.ssl_dim + self.use_spk = self.args.use_spk + + self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16) + self.dec = Generator( + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + gin_channels=self.gin_channels, + ) + self.enc_q = Encoder( + self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels + ) + self.flow = ResidualCouplingBlock( + self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels + ) + if not self.use_spk: + self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels) + else: + self.load_pretrained_speaker_encoder() + + self.wavlm = get_wavlm() + + @property + def device(self): + return next(self.parameters()).device + + def load_pretrained_speaker_encoder(self): + """Load pretrained speaker encoder model as mentioned in the paper.""" + logger.info("Loading pretrained speaker encoder model ...") + self.enc_spk_ex = SpeakerEncoderEx( + "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt", device=self.device + ) + + def init_multispeaker(self, config: Coqpit): + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (List, optional): Dataset items to infer number of speakers. Defaults to None. + """ + self.num_spks = self.args.num_spks + if self.speaker_manager: + self.num_spks = self.speaker_manager.num_spks + + def forward( + self, + c: torch.Tensor, + spec: torch.Tensor, + g: Optional[torch.Tensor] = None, + mel: Optional[torch.Tensor] = None, + c_lengths: Optional[torch.Tensor] = None, + spec_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ]: + """ + Forward pass of the model. + + Args: + c: WavLM features. Shape: (batch_size, c_seq_len). + spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim). + g: The speaker embedding. Shape: (batch_size, spk_emb_dim). + mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim). + c_lengths: The lengths of the WavLM features. Shape: (batch_size,). + spec_lengths: The lengths of the spectrogram. Shape: (batch_size,). + + Returns: + o: The output spectrogram. Shape: (batch_size, spec_seq_len, spec_dim). + ids_slice: The slice indices. Shape: (batch_size, num_slices). + spec_mask: The spectrogram mask. Shape: (batch_size, spec_seq_len). + (z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables. + """ + + # If c_lengths is None, set it to the length of the last dimension of c + if c_lengths is None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + + # If spec_lengths is None, set it to the length of the last dimension of spec + if spec_lengths is None: + spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device) + + # If use_spk is False, compute g from mel using enc_spk + g = None + if not self.use_spk: + g = self.enc_spk(mel).unsqueeze(-1) + + # Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q + _, m_p, logs_p, _ = self.enc_p(c, c_lengths) + z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g) + + # Compute z_p using flow + z_p = self.flow(z, spec_mask, g=g) + + # Randomly slice z and compute o using dec + z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + + return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + @torch.no_grad() + def inference(self, c, g=None, mel=None, c_lengths=None): + """ + Inference pass of the model + + Args: + c (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len). + g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim). + c_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,). + + Returns: + torch.Tensor: Output tensor. + """ + if c_lengths is None: + c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device) + if not self.use_spk: + g = self.enc_spk.embed_utterance(mel) + g = g.unsqueeze(-1) + z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths) + z = self.flow(z_p, c_mask, g=g, reverse=True) + o = self.dec(z * c_mask, g=g) + return o + + def extract_wavlm_features(self, y): + """Extract WavLM features from an audio tensor. + + Args: + y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len). + """ + + with torch.no_grad(): + c = self.wavlm.extract_features(y)[0] + c = c.transpose(1, 2) + return c + + def load_audio(self, wav): + """Read and format the input audio.""" + if isinstance(wav, str): + wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate) + if isinstance(wav, np.ndarray): + wav = torch.from_numpy(wav).to(self.device) + if isinstance(wav, torch.Tensor): + wav = wav.to(self.device) + if isinstance(wav, list): + wav = torch.from_numpy(np.array(wav)).to(self.device) + return wav.float() + + @torch.inference_mode() + def voice_conversion(self, src, tgt): + """ + Voice conversion pass of the model. + + Args: + src (str or torch.Tensor): Source utterance. + tgt (str or torch.Tensor): Target utterance. + + Returns: + torch.Tensor: Output tensor. + """ + + wav_tgt = self.load_audio(tgt).cpu().numpy() + wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20) + + if self.config.model_args.use_spk: + g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt) + g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device) + else: + wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device) + mel_tgt = mel_spectrogram_torch( + wav_tgt, + self.config.audio.filter_length, + self.config.audio.n_mel_channels, + self.config.audio.input_sample_rate, + self.config.audio.hop_length, + self.config.audio.win_length, + self.config.audio.mel_fmin, + self.config.audio.mel_fmax, + ) + # src + wav_src = self.load_audio(src) + c = self.extract_wavlm_features(wav_src[None, :]) + + if self.config.model_args.use_spk: + audio = self.inference(c, g=g_tgt) + else: + audio = self.inference(c, mel=mel_tgt.transpose(1, 2)) + audio = audio[0][0].data.cpu().float().numpy() + return audio + + def eval_step(): ... + + @staticmethod + def init_from_config(config: FreeVCConfig, samples: Union[List[List], List[Dict]] = None): + model = FreeVC(config) + return model + + def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"], strict=strict) + if eval: + self.eval() + + def train_step(): ... diff --git a/TTS/vc/modules/__init__.py b/TTS/vc/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vc/modules/freevc/__init__.py b/TTS/vc/modules/freevc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..feea7f34dc8020099947a6e1317b51298652fb07 --- /dev/null +++ b/TTS/vc/modules/freevc/commons.py @@ -0,0 +1,145 @@ +import math + +import torch +from torch.nn import functional as F + +from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask + + +def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def rand_spec_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e251891a8b4cbcad8bb25aaccf0a83a5ec7d15 --- /dev/null +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -0,0 +1,133 @@ +import logging + +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +logger = logging.getLogger(__name__) + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.0: + logger.info("Min value is: %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("Max value is: %.3f", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + return spec + + +def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + logger.info("Min value is: %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("Max value is: %.3f", torch.max(y)) + + global mel_basis, hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..722444a30358851d8d250147945d24d79a1d3e5e --- /dev/null +++ b/TTS/vc/modules/freevc/modules.py @@ -0,0 +1,314 @@ +import torch +from torch import nn +from torch.nn import Conv1d +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + +import TTS.vc.modules.freevc.commons as commons +from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.vc.modules.freevc.commons import init_weights +from TTS.vocoder.models.hifigan_generator import get_padding + +LRELU_SLOPE = 0.1 + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm2(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout)) + for _ in range(n_layers - 1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) + self.norm_layers.append(LayerNorm2(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + remove_parametrizations(self.cond_layer, "weight") + for l in self.in_layers: + remove_parametrizations(l, "weight") + for l in self.res_skip_layers: + remove_parametrizations(l, "weight") + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_parametrizations(l, "weight") + for l in self.convs2: + remove_parametrizations(l, "weight") + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_parametrizations(l, "weight") + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x diff --git a/TTS/vc/modules/freevc/speaker_encoder/__init__.py b/TTS/vc/modules/freevc/speaker_encoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/modules/freevc/speaker_encoder/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..5b23a4dbb69278d603cc3a67c0dde1c24f28386f --- /dev/null +++ b/TTS/vc/modules/freevc/speaker_encoder/audio.py @@ -0,0 +1,69 @@ +from pathlib import Path +from typing import Optional, Union + +# import webrtcvad +import librosa +import numpy as np + +from TTS.vc.modules.freevc.speaker_encoder.hparams import ( + audio_norm_target_dBFS, + mel_n_channels, + mel_window_length, + mel_window_step, + sampling_rate, +) + +int16_max = (2**15) - 1 + + +def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None): + """ + Applies the preprocessing operations used in training the Speaker Encoder to a waveform + either on disk or in memory. The waveform will be resampled to match the data hyperparameters. + + :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not + just .wav), either the waveform as a numpy array of floats. + :param source_sr: if passing an audio waveform, the sampling rate of the waveform before + preprocessing. After preprocessing, the waveform's sampling rate will match the data + hyperparameters. If passing a filepath, the sampling rate will be automatically detected and + this argument will be ignored. + """ + # Load the wav from disk if needed + if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): + wav, source_sr = librosa.load(fpath_or_wav, sr=None) + else: + wav = fpath_or_wav + + # Resample the wav if needed + if source_sr is not None and source_sr != sampling_rate: + wav = librosa.resample(wav, source_sr, sampling_rate) + + # Apply the preprocessing: normalize volume and shorten long silences + wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) + wav = trim_long_silences(wav) + + return wav + + +def wav_to_mel_spectrogram(wav): + """ + Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. + Note: this not a log-mel spectrogram. + """ + frames = librosa.feature.melspectrogram( + y=wav, + sr=sampling_rate, + n_fft=int(sampling_rate * mel_window_length / 1000), + hop_length=int(sampling_rate * mel_window_step / 1000), + n_mels=mel_n_channels, + ) + return frames.astype(np.float32).T + + +def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): + if increase_only and decrease_only: + raise ValueError("Both increase only and decrease only are set") + dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2)) + if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): + return wav + return wav * (10 ** (dBFS_change / 20)) diff --git a/TTS/vc/modules/freevc/speaker_encoder/hparams.py b/TTS/vc/modules/freevc/speaker_encoder/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..2c536ae16cf8134d66c83aaf978ed01fc396b680 --- /dev/null +++ b/TTS/vc/modules/freevc/speaker_encoder/hparams.py @@ -0,0 +1,31 @@ +## Mel-filterbank +mel_window_length = 25 # In milliseconds +mel_window_step = 10 # In milliseconds +mel_n_channels = 40 + + +## Audio +sampling_rate = 16000 +# Number of spectrogram frames in a partial utterance +partials_n_frames = 160 # 1600 ms + + +## Voice Activation Detection +# Window size of the VAD. Must be either 10, 20 or 30 milliseconds. +# This sets the granularity of the VAD. Should not need to be changed. +vad_window_length = 30 # In milliseconds +# Number of frames to average together when performing the moving average smoothing. +# The larger this value, the larger the VAD variations must be to not get smoothed out. +vad_moving_average_width = 8 +# Maximum number of consecutive silent frames a segment can have. +vad_max_silence_length = 6 + + +## Audio volume normalization +audio_norm_target_dBFS = -30 + + +## Model parameters +model_hidden_size = 256 +model_embedding_size = 256 +model_num_layers = 3 diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..294bf322cb1466f0b219082cb9a1c5157d0dbdb8 --- /dev/null +++ b/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py @@ -0,0 +1,183 @@ +import logging +from time import perf_counter as timer +from typing import List, Union + +import numpy as np +import torch +from torch import nn +from trainer.io import load_fsspec + +from TTS.vc.modules.freevc.speaker_encoder import audio +from TTS.vc.modules.freevc.speaker_encoder.hparams import ( + mel_n_channels, + mel_window_step, + model_embedding_size, + model_hidden_size, + model_num_layers, + partials_n_frames, + sampling_rate, +) + +logger = logging.getLogger(__name__) + + +class SpeakerEncoder(nn.Module): + def __init__(self, weights_fpath, device: Union[str, torch.device] = None): + """ + :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). + If None, defaults to cuda if it is available on your machine, otherwise the model will + run on cpu. Outputs are always returned on the cpu, as numpy arrays. + """ + super().__init__() + + # Define the network + self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True) + self.linear = nn.Linear(model_hidden_size, model_embedding_size) + self.relu = nn.ReLU() + + # Get the target device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + device = torch.device(device) + self.device = device + + # Load the pretrained model'speaker weights + # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt") + # if not weights_fpath.exists(): + # raise Exception("Couldn't find the voice encoder pretrained model at %s." % + # weights_fpath) + + start = timer() + checkpoint = load_fsspec(weights_fpath, map_location="cpu") + + self.load_state_dict(checkpoint["model_state"], strict=False) + self.to(device) + logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start) + + def forward(self, mels: torch.FloatTensor): + """ + Computes the embeddings of a batch of utterance spectrograms. + :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape + (batch_size, n_frames, n_channels) + :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size). + Embeddings are positive and L2-normed, thus they lay in the range [0, 1]. + """ + # Pass the input through the LSTM layers and retrieve the final hidden state of the last + # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings. + _, (hidden, _) = self.lstm(mels) + embeds_raw = self.relu(self.linear(hidden[-1])) + return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) + + @staticmethod + def compute_partial_slices(n_samples: int, rate, min_coverage): + """ + Computes where to split an utterance waveform and its corresponding mel spectrogram to + obtain partial utterances of each. Both the waveform and the + mel spectrogram slices are returned, so as to make each partial utterance waveform + correspond to its spectrogram. + + The returned ranges may be indexing further than the length of the waveform. It is + recommended that you pad the waveform with zeros up to wav_slices[-1].stop. + + :param n_samples: the number of samples in the waveform + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the waveform slices and mel spectrogram slices as lists of array slices. Index + respectively the waveform and the mel spectrogram with these slices to obtain the partial + utterances. + """ + assert 0 < min_coverage <= 1 + + # Compute how many frames separate two partial utterances + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + assert 0 < frame_step, "The rate is too high" + assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % ( + sampling_rate / (samples_per_frame * partials_n_frames) + ) + + # Compute the slices + wav_slices, mel_slices = [], [] + steps = max(1, n_frames - partials_n_frames + frame_step + 1) + for i in range(0, steps, frame_step): + mel_range = np.array([i, i + partials_n_frames]) + wav_range = mel_range * samples_per_frame + mel_slices.append(slice(*mel_range)) + wav_slices.append(slice(*wav_range)) + + # Evaluate whether extra padding is warranted or not + last_wav_range = wav_slices[-1] + coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) + if coverage < min_coverage and len(mel_slices) > 1: + mel_slices = mel_slices[:-1] + wav_slices = wav_slices[:-1] + + return wav_slices, mel_slices + + def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75): + """ + Computes an embedding for a single utterance. The utterance is divided in partial + utterances and an embedding is computed for each. The complete utterance embedding is the + L2-normed average embedding of the partial utterances. + + TODO: independent batched version of this function + + :param wav: a preprocessed utterance waveform as a numpy array of float32 + :param return_partials: if True, the partial embeddings will also be returned along with + the wav slices corresponding to each partial utterance. + :param rate: how many partial utterances should occur per second. Partial utterances must + cover the span of the entire utterance, thus the rate should not be lower than the inverse + of the duration of a partial utterance. By default, partial utterances are 1.6s long and + the minimum rate is thus 0.625. + :param min_coverage: when reaching the last partial utterance, it may or may not have + enough frames. If at least of are present, + then the last partial utterance will be considered by zero-padding the audio. Otherwise, + it will be discarded. If there aren't enough frames for one partial utterance, + this parameter is ignored so that the function always returns at least one slice. + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If + is True, the partial utterances as a numpy array of float32 of shape + (n_partials, model_embedding_size) and the wav partials as a list of slices will also be + returned. + """ + # Compute where to split the utterance into partials and pad the waveform with zeros if + # the partial utterances cover a larger range. + wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage) + max_wave_length = wav_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials and forward them through the model + mel = audio.wav_to_mel_spectrogram(wav) + mels = np.array([mel[s] for s in mel_slices]) + with torch.no_grad(): + mels = torch.from_numpy(mels).to(self.device) + partial_embeds = self(mels).cpu().numpy() + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + if return_partials: + return embed, partial_embeds, wav_slices + return embed + + def embed_speaker(self, wavs: List[np.ndarray], **kwargs): + """ + Compute the embedding of a collection of wavs (presumably from the same speaker) by + averaging their embedding and L2-normalizing it. + + :param wavs: list of wavs a numpy arrays of float32. + :param kwargs: extra arguments to embed_utterance() + :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). + """ + raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0) + return raw_embed / np.linalg.norm(raw_embed, 2) diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/modules/freevc/wavlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4046e137f5fa482bb7d3f19b25346b490bb62090 --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/__init__.py @@ -0,0 +1,39 @@ +import logging +import os +import urllib.request + +import torch +from trainer.io import get_user_data_dir + +from TTS.utils.generic_utils import is_pytorch_at_least_2_4 +from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig + +logger = logging.getLogger(__name__) + +model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt" + + +def get_wavlm(device="cpu"): + """Download the model and return the model object.""" + + output_path = get_user_data_dir("tts") + + output_path = os.path.join(output_path, "wavlm") + if not os.path.exists(output_path): + os.makedirs(output_path) + + output_path = os.path.join(output_path, "WavLM-Large.pt") + if not os.path.exists(output_path): + logger.info("Downloading WavLM model to %s ...", output_path) + urllib.request.urlretrieve(model_uri, output_path) + + checkpoint = torch.load(output_path, map_location=torch.device(device), weights_only=is_pytorch_at_least_2_4()) + cfg = WavLMConfig(checkpoint["cfg"]) + wavlm = WavLM(cfg).to(device) + wavlm.load_state_dict(checkpoint["model"]) + wavlm.eval() + return wavlm + + +if __name__ == "__main__": + wavlm = get_wavlm() diff --git a/TTS/vc/modules/freevc/wavlm/config.json b/TTS/vc/modules/freevc/wavlm/config.json new file mode 100644 index 0000000000000000000000000000000000000000..c2e414cf0b78db9ce54d91db1627d59651c747bd --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/config.json @@ -0,0 +1,99 @@ +{ + "_name_or_path": "./wavlm-large/", + "activation_dropout": 0.0, + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ + "WavLMModel" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "classifier_proj_size": 256, + "codevector_dim": 768, + "contrastive_logits_temperature": 0.1, + "conv_bias": false, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.0, + "gradient_checkpointing": false, + "hidden_act": "gelu", + "hidden_dropout": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.1, + "mask_channel_length": 10, + "mask_channel_min_space": 1, + "mask_channel_other": 0.0, + "mask_channel_prob": 0.0, + "mask_channel_selection": "static", + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_min_space": 1, + "mask_time_other": 0.0, + "mask_time_prob": 0.075, + "mask_time_selection": "static", + "max_bucket_distance": 800, + "model_type": "wavlm", + "num_adapter_layers": 3, + "num_attention_heads": 16, + "num_buckets": 320, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_ctc_classes": 80, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "num_negatives": 100, + "output_hidden_size": 1024, + "pad_token_id": 0, + "proj_codevector_dim": 768, + "replace_prob": 0.5, + "tokenizer_class": "Wav2Vec2CTCTokenizer", + "torch_dtype": "float32", + "transformers_version": "4.15.0.dev0", + "use_weighted_layer_sum": false, + "vocab_size": 32 + } diff --git a/TTS/vc/modules/freevc/wavlm/modules.py b/TTS/vc/modules/freevc/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..37c1a6e8774cdfd439baa38a8a7ad55fd79ebf7c --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/modules.py @@ -0,0 +1,768 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): + """Swish function""" + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2] + else: + x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2]) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate") + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). + """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros(in_features // block_size * out_features, device=weight.device) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) + mask.bernoulli_(p) + mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + + # scale weights and apply mask + mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim**-0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size) + self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size) + self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size) + + self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid( + self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = attn_weights + position_bias + + attn_weights_float = F.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/modules/freevc/wavlm/wavlm.py new file mode 100644 index 0000000000000000000000000000000000000000..10dd09ed0cdae90f54b71de5425d901f54c07315 --- /dev/null +++ b/TTS/vc/modules/freevc/wavlm/wavlm.py @@ -0,0 +1,723 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from TTS.vc.modules.freevc.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = ( + "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + ) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = ( + "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + ) + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + ) + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = ( + 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + ) + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_()) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + # padding_mask = padding_mask.all(-1) + padding_mask = padding_mask.any(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask(features, padding_mask) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default", + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride)) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1)) + self.conv_layers.append(torch.nn.LayerNorm([dim, idim])) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.parametrizations.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x += x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer( + x, + self_attn_padding_mask=padding_mask, + need_weights=False, + self_attn_mask=streaming_mask, + pos_bias=pos_bias, + ) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias, + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias diff --git a/TTS/vocoder/README.md b/TTS/vocoder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b9fb17c8f09fa6e8c217087e31fb8c52d96da536 --- /dev/null +++ b/TTS/vocoder/README.md @@ -0,0 +1,39 @@ +# Mozilla TTS Vocoders (Experimental) + +Here there are vocoder model implementations which can be combined with the other TTS models. + +Currently, following models are implemented: + +- Melgan +- MultiBand-Melgan +- ParallelWaveGAN +- GAN-TTS (Discriminator Only) + +It is also very easy to adapt different vocoder models as we provide a flexible and modular (but not too modular) framework. + +## Training a model + +You can see here an example (Soon)[Colab Notebook]() training MelGAN with LJSpeech dataset. + +In order to train a new model, you need to gather all wav files into a folder and give this folder to `data_path` in '''config.json''' + +You need to define other relevant parameters in your ```config.json``` and then start traning with the following command. + +```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --config_path path/to/config.json``` + +Example config files can be found under `tts/vocoder/configs/` folder. + +You can continue a previous training run by the following command. + +```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --continue_path path/to/your/model/folder``` + +You can fine-tune a pre-trained model by the following command. + +```CUDA_VISIBLE_DEVICES='0' python tts/bin/train_vocoder.py --restore_path path/to/your/model.pth``` + +Restoring a model starts a new training in a different folder. It only restores model weights with the given checkpoint file. However, continuing a training starts from the same directory where the previous training run left off. + +You can also follow your training runs on Tensorboard as you do with our TTS models. + +## Acknowledgement +Thanks to @kan-bayashi for his [repository](https://github.com/kan-bayashi/ParallelWaveGAN) being the start point of our work. diff --git a/TTS/vocoder/__init__.py b/TTS/vocoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vocoder/configs/__init__.py b/TTS/vocoder/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e11b990c6d7294e7cb00c3e024bbb5f94a8105 --- /dev/null +++ b/TTS/vocoder/configs/__init__.py @@ -0,0 +1,17 @@ +import importlib +import os +from inspect import isclass + +# import all files under configs/ +configs_dir = os.path.dirname(__file__) +for file in os.listdir(configs_dir): + path = os.path.join(configs_dir, file) + if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): + config_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("TTS.vocoder.configs." + config_name) + for attribute_name in dir(module): + attribute = getattr(module, attribute_name) + + if isclass(attribute): + # Add the class to this package's variables + globals()[attribute_name] = attribute diff --git a/TTS/vocoder/configs/fullband_melgan_config.py b/TTS/vocoder/configs/fullband_melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab83aace678e328a8f99a5f0dc63e54ed99d4c4 --- /dev/null +++ b/TTS/vocoder/configs/fullband_melgan_config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseGANVocoderConfig + + +@dataclass +class FullbandMelganConfig(BaseGANVocoderConfig): + """Defines parameters for FullBand MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import FullbandMelganConfig + >>> config = FullbandMelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "fullband_melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]} + ) + generator_model: str = "melgan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 4} + ) + + # Training - overrides + batch_size: int = 16 + seq_len: int = 8192 + pad_short: int = 2000 + use_noise_augment: bool = True + use_cache: bool = True + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0.0 diff --git a/TTS/vocoder/configs/hifigan_config.py b/TTS/vocoder/configs/hifigan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9a102f0c89588b1a7fe270225e4b0fefa2e4bc71 --- /dev/null +++ b/TTS/vocoder/configs/hifigan_config.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class HifiganConfig(BaseGANVocoderConfig): + """Defines parameters for FullBand MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `hifigan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'hifigan_discriminator`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `hifigan_generator`. + generator_model_params (dict): Parameters of the generator model. Defaults to + ` + { + "upsample_factors": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_type": "1", + } + ` + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): + STFT loss parameters. Default to + `{ + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }` + l1_spec_loss_params (dict): + L1 spectrogram loss parameters. Default to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "hifigan" + # model specific params + discriminator_model: str = "hifigan_discriminator" + generator_model: str = "hifigan_generator" + generator_model_params: dict = field( + default_factory=lambda: { + "upsample_factors": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + "upsample_initial_channel": 512, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "resblock_type": "1", + } + ) + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = False + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = True + + # loss weights - overrides + stft_loss_weight: float = 0 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 45 + l1_spec_loss_params: dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + # optimizer parameters + lr: float = 1e-4 + wd: float = 1e-6 diff --git a/TTS/vocoder/configs/melgan_config.py b/TTS/vocoder/configs/melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc35b6f8b70891d4904baefad802d9c62fe67925 --- /dev/null +++ b/TTS/vocoder/configs/melgan_config.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class MelganConfig(BaseGANVocoderConfig): + """Defines parameters for MelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import MelganConfig + >>> config = MelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 1024, "downsample_factors": [4, 4, 4, 4]} + ) + generator_model: str = "melgan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [8, 8, 2, 2], "num_res_blocks": 3} + ) + + # Training - overrides + batch_size: int = 16 + seq_len: int = 8192 + pad_short: int = 2000 + use_noise_augment: bool = True + use_cache: bool = True + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0 diff --git a/TTS/vocoder/configs/multiband_melgan_config.py b/TTS/vocoder/configs/multiband_melgan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..763113537f36a8615b2b77369bf5bde01527fe53 --- /dev/null +++ b/TTS/vocoder/configs/multiband_melgan_config.py @@ -0,0 +1,144 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class MultibandMelganConfig(BaseGANVocoderConfig): + """Defines parameters for MultiBandMelGAN vocoder. + + Example: + + >>> from TTS.vocoder.configs import MultibandMelganConfig + >>> config = MultibandMelganConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'melgan_multiscale_discriminator`. + discriminator_model_params (dict): The discriminator model parameters. Defaults to + '{ + "base_channels": 16, + "max_channels": 512, + "downsample_factors": [4, 4, 4] + }` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `melgan_generator`. + generator_model_param (dict): + The generator model parameters. Defaults to `{"upsample_factors": [8, 4, 2], "num_res_blocks": 4}`. + use_pqmf (bool): + enable / disable PQMF modulation for multi-band training. Defaults to True. + lr_gen (float): + Initial learning rate for the generator model. Defaults to 0.0001. + lr_disc (float): + Initial learning rate for the discriminator model. Defaults to 0.0001. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `MultiStepLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to + `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `MultiStepLR`. + lr_scheduler_dict_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to + `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + use_stft_loss (bool):` + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "multiband_melgan" + + # Model specific params + discriminator_model: str = "melgan_multiscale_discriminator" + discriminator_model_params: dict = field( + default_factory=lambda: {"base_channels": 16, "max_channels": 512, "downsample_factors": [4, 4, 4]} + ) + generator_model: str = "multiband_melgan_generator" + generator_model_params: dict = field(default_factory=lambda: {"upsample_factors": [8, 4, 2], "num_res_blocks": 4}) + use_pqmf: bool = True + + # optimizer - overrides + lr_gen: float = 0.0001 # Initial learning rate. + lr_disc: float = 0.0001 # Initial learning rate. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + lr_scheduler_gen: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) + lr_scheduler_disc: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) + + # Training - overrides + batch_size: int = 64 + seq_len: int = 16384 + pad_short: int = 2000 + use_noise_augment: bool = False + use_cache: bool = True + steps_to_start_discriminator: bool = 200000 + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = True + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + subband_stft_loss_params: dict = field( + default_factory=lambda: {"n_ffts": [384, 683, 171], "hop_lengths": [30, 60, 10], "win_lengths": [150, 300, 60]} + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 108 + l1_spec_loss_weight: float = 0 diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6059d7f04f7c9fadd6df1d424e8e164e54a7310e --- /dev/null +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass, field + +from .shared_configs import BaseGANVocoderConfig + + +@dataclass +class ParallelWaveganConfig(BaseGANVocoderConfig): + """Defines parameters for ParallelWavegan vocoder. + + Args: + model (str): + Model name used for selecting the right configuration at initialization. Defaults to `gan`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'parallel_wavegan_discriminator`. + discriminator_model_params (dict): The discriminator model kwargs. Defaults to + '{"num_layers": 10}` + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `parallel_wavegan_generator`. + generator_model_param (dict): + The generator model kwargs. Defaults to `{"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30}`. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 16. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + use_stft_loss (bool):` + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by HifiGAN model. Defaults to False. + stft_loss_params (dict): STFT loss parameters. Default to + `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 0. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + lr_gen (float): + Generator model initial learning rate. Defaults to 0.0002. + lr_disc (float): + Discriminator model initial learning rate. Defaults to 0.0002. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `ExponentialLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. + lr_scheduler_dict_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.5, "step_size": 200000, "last_epoch": -1}`. + """ + + model: str = "parallel_wavegan" + + # Model specific params + discriminator_model: str = "parallel_wavegan_discriminator" + discriminator_model_params: dict = field(default_factory=lambda: {"num_layers": 10}) + generator_model: str = "parallel_wavegan_generator" + generator_model_params: dict = field( + default_factory=lambda: {"upsample_factors": [4, 4, 4, 4], "stacks": 3, "num_res_blocks": 30} + ) + + # Training - overrides + batch_size: int = 6 + seq_len: int = 25600 + pad_short: int = 2000 + use_noise_augment: bool = False + use_cache: bool = True + steps_to_start_discriminator: int = 200000 + target_loss: str = "loss_1" + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = False + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + # loss weights - overrides + stft_loss_weight: float = 0.5 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 2.5 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 0 + l1_spec_loss_weight: float = 0 + + # optimizer overrides + lr_gen: float = 0.0002 # Initial learning rate. + lr_disc: float = 0.0002 # Initial learning rate. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}) + lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1} + ) + scheduler_after_epoch: bool = False diff --git a/TTS/vocoder/configs/shared_configs.py b/TTS/vocoder/configs/shared_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..a558cfcabbc2abc26be60065d3ac75cebd829f28 --- /dev/null +++ b/TTS/vocoder/configs/shared_configs.py @@ -0,0 +1,182 @@ +from dataclasses import dataclass, field + +from TTS.config import BaseAudioConfig, BaseTrainingConfig + + +@dataclass +class BaseVocoderConfig(BaseTrainingConfig): + """Shared parameters among all the vocoder models. + Args: + audio (BaseAudioConfig): + Audio processor config instance. Defaultsto `BaseAudioConfig()`. + use_noise_augment (bool): + Augment the input audio with random noise. Defaults to False/ + eval_split_size (int): + Number of instances used for evaluation. Defaults to 10. + data_path (str): + Root path of the training data. All the audio files found recursively from this root path are used for + training. Defaults to `""`. + feature_path (str): + Root path to the precomputed feature files. Defaults to None. + seq_len (int): + Length of the waveform segments used for training. Defaults to 1000. + pad_short (int): + Extra padding for the waveforms shorter than `seq_len`. Defaults to 0. + conv_path (int): + Extra padding for the feature frames against convolution of the edge frames. Defaults to MISSING. + Defaults to 0. + use_cache (bool): + enable / disable in memory caching of the computed features. If the RAM is not enough, if may cause OOM. + Defaults to False. + epochs (int): + Number of training epochs to. Defaults to 10000. + wd (float): + Weight decay. + optimizer (torch.optim.Optimizer): + Optimizer used for the training. Defaults to `AdamW`. + optimizer_params (dict): + Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}` + """ + + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + # dataloading + use_noise_augment: bool = False # enable/disable random noise augmentation in spectrograms. + eval_split_size: int = 10 # number of samples used for evaluation. + # dataset + data_path: str = "" # root data path. It finds all wav files recursively from there. + feature_path: str = None # if you use precomputed features + seq_len: int = 1000 # signal length used in training. + pad_short: int = 0 # additional padding for short wavs + conv_pad: int = 0 # additional padding against convolutions applied to spectrograms + use_cache: bool = False # use in memory cache to keep the computed features. This might cause OOM. + # OPTIMIZER + epochs: int = 10000 # total number of epochs to train. + wd: float = 0.0 # Weight decay weight. + optimizer: str = "AdamW" + optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0}) + + +@dataclass +class BaseGANVocoderConfig(BaseVocoderConfig): + """Base config class used among all the GAN based vocoders. + Args: + use_stft_loss (bool): + enable / disable the use of STFT loss. Defaults to True. + use_subband_stft_loss (bool): + enable / disable the use of Subband STFT loss. Defaults to True. + use_mse_gan_loss (bool): + enable / disable the use of Mean Squared Error based GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable the use of Hinge GAN loss. Defaults to True. + use_feat_match_loss (bool): + enable / disable feature matching loss. Defaults to True. + use_l1_spec_loss (bool): + enable / disable L1 spectrogram loss. Defaults to True. + stft_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 0. + subband_stft_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 0. + mse_G_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 1. + hinge_G_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 0. + feat_match_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 100. + l1_spec_loss_weight (float): + Loss weight that multiplies the computed loss value. Defaults to 45. + stft_loss_params (dict): + Parameters for the STFT loss. Defaults to `{"n_ffts": [1024, 2048, 512], "hop_lengths": [120, 240, 50], "win_lengths": [600, 1200, 240]}`. + l1_spec_loss_params (dict): + Parameters for the L1 spectrogram loss. Defaults to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + target_loss (str): + Target loss name that defines the quality of the model. Defaults to `G_avg_loss`. + grad_clip (list): + A list of gradient clipping theresholds for each optimizer. Any value less than 0 disables clipping. + Defaults to [5, 5]. + lr_gen (float): + Generator model initial learning rate. Defaults to 0.0002. + lr_disc (float): + Discriminator model initial learning rate. Defaults to 0.0002. + lr_scheduler_gen (torch.optim.Scheduler): + Learning rate scheduler for the generator. Defaults to `ExponentialLR`. + lr_scheduler_gen_params (dict): + Parameters for the generator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + lr_scheduler_disc (torch.optim.Scheduler): + Learning rate scheduler for the discriminator. Defaults to `ExponentialLR`. + lr_scheduler_disc_params (dict): + Parameters for the discriminator learning rate scheduler. Defaults to `{"gamma": 0.999, "last_epoch": -1}`. + scheduler_after_epoch (bool): + Whether to update the learning rate schedulers after each epoch. Defaults to True. + use_pqmf (bool): + enable / disable PQMF for subband approximation at training. Defaults to False. + steps_to_start_discriminator (int): + Number of steps required to start training the discriminator. Defaults to 0. + diff_samples_for_G_and_D (bool): + enable / disable use of different training samples for the generator and the discriminator iterations. + Enabling it results in slower iterations but faster convergance in some cases. Defaults to False. + """ + + model: str = "gan" + + # LOSS PARAMETERS + use_stft_loss: bool = True + use_subband_stft_loss: bool = True + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = True + use_feat_match_loss: bool = True # requires MelGAN Discriminators (MelGAN and HifiGAN) + use_l1_spec_loss: bool = True + + # loss weights + stft_loss_weight: float = 0 + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 100 + l1_spec_loss_weight: float = 45 + + stft_loss_params: dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + + l1_spec_loss_params: dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + target_loss: str = "loss_0" # loss value to pick the best model to save after each epoch + + # optimizer + grad_clip: float = field(default_factory=lambda: [5, 5]) + lr_gen: float = 0.0002 # Initial learning rate. + lr_disc: float = 0.0002 # Initial learning rate. + lr_scheduler_gen: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_disc: str = "ExponentialLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + scheduler_after_epoch: bool = True + + use_pqmf: bool = False # enable/disable using pqmf for multi-band training. (Multi-band MelGAN) + steps_to_start_discriminator = 0 # start training the discriminator after this number of steps. + diff_samples_for_G_and_D: bool = False # use different samples for G and D training steps. diff --git a/TTS/vocoder/configs/univnet_config.py b/TTS/vocoder/configs/univnet_config.py new file mode 100644 index 0000000000000000000000000000000000000000..67f324cfce5f701f0d7453beab81590bef6be114 --- /dev/null +++ b/TTS/vocoder/configs/univnet_config.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass, field +from typing import Dict + +from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig + + +@dataclass +class UnivnetConfig(BaseGANVocoderConfig): + """Defines parameters for UnivNet vocoder. + + Example: + + >>> from TTS.vocoder.configs import UnivNetConfig + >>> config = UnivNetConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `UnivNet`. + discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to + 'UnivNet_discriminator`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `UnivNet_generator`. + generator_model_params (dict): Parameters of the generator model. Defaults to + ` + { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ` + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 32. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 8192. + pad_short (int): + Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0. + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + use_stft_loss (bool): + enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True. + use_subband_stft (bool): + enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True. + use_mse_gan_loss (bool): + enable / disable using Mean Squeare Error GAN loss. Defaults to True. + use_hinge_gan_loss (bool): + enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models. + Defaults to False. + use_feat_match_loss (bool): + enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True. + use_l1_spec_loss (bool): + enable / disable using L1 spectrogram loss originally used by univnet model. Defaults to False. + stft_loss_params (dict): + STFT loss parameters. Default to + `{ + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240] + }` + l1_spec_loss_params (dict): + L1 spectrogram loss parameters. Default to + `{ + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + }` + stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total + model loss. Defaults to 0.5. + subband_stft_loss_weight (float): + Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + mse_G_loss_weight (float): + MSE generator loss weight that multiplies the computed loss before summing up the total loss. faults to 2.5. + hinge_G_loss_weight (float): + Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + feat_match_loss_weight (float): + Feature matching loss weight that multiplies the computed loss before summing up the total loss. faults to 108. + l1_spec_loss_weight (float): + L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0. + """ + + model: str = "univnet" + batch_size: int = 32 + # model specific params + discriminator_model: str = "univnet_discriminator" + generator_model: str = "univnet_generator" + generator_model_params: Dict = field( + default_factory=lambda: { + "in_channels": 64, + "out_channels": 1, + "hidden_channels": 32, + "cond_channels": 80, + "upsample_factors": [8, 8, 4], + "lvc_layers_each_block": 4, + "lvc_kernel_size": 3, + "kpnet_hidden_channels": 64, + "kpnet_conv_size": 3, + "dropout": 0.0, + } + ) + + # LOSS PARAMETERS - overrides + use_stft_loss: bool = True + use_subband_stft_loss: bool = False + use_mse_gan_loss: bool = True + use_hinge_gan_loss: bool = False + use_feat_match_loss: bool = False # requires MelGAN Discriminators (MelGAN and univnet) + use_l1_spec_loss: bool = False + + # loss weights - overrides + stft_loss_weight: float = 2.5 + stft_loss_params: Dict = field( + default_factory=lambda: { + "n_ffts": [1024, 2048, 512], + "hop_lengths": [120, 240, 50], + "win_lengths": [600, 1200, 240], + } + ) + subband_stft_loss_weight: float = 0 + mse_G_loss_weight: float = 1 + hinge_G_loss_weight: float = 0 + feat_match_loss_weight: float = 0 + l1_spec_loss_weight: float = 0 + l1_spec_loss_params: Dict = field( + default_factory=lambda: { + "use_mel": True, + "sample_rate": 22050, + "n_fft": 1024, + "hop_length": 256, + "win_length": 1024, + "n_mels": 80, + "mel_fmin": 0.0, + "mel_fmax": None, + } + ) + + # optimizer parameters + lr_gen: float = 1e-4 # Initial learning rate. + lr_disc: float = 1e-4 # Initial learning rate. + lr_scheduler_gen: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + lr_scheduler_disc: str = None # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1}) + optimizer_params: Dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0}) + steps_to_start_discriminator: int = 200000 + + def __post_init__(self): + super().__post_init__() + self.generator_model_params["cond_channels"] = self.audio.num_mels diff --git a/TTS/vocoder/configs/wavegrad_config.py b/TTS/vocoder/configs/wavegrad_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c39813ae68c3d8c77614c9a5188ac5f2a59d991d --- /dev/null +++ b/TTS/vocoder/configs/wavegrad_config.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseVocoderConfig +from TTS.vocoder.models.wavegrad import WavegradArgs + + +@dataclass +class WavegradConfig(BaseVocoderConfig): + """Defines parameters for WaveGrad vocoder. + Example: + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `wavegrad`. + generator_model (str): One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `wavegrad`. + model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values. + target_loss (str): + Target loss name that defines the quality of the model. Defaults to `avg_wavegrad_loss`. + epochs (int): + Number of epochs to traing the model. Defaults to 10000. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 96. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 6144. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + mixed_precision (bool): + enable / disable mixed precision training. Default is True. + eval_split_size (int): + Number of samples used for evalutaion. Defaults to 50. + train_noise_schedule (dict): + Training noise schedule. Defaults to + `{"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}` + test_noise_schedule (dict): + Inference noise schedule. For a better performance, you may need to use `bin/tune_wavegrad.py` to find a + better schedule. Defaults to + ` + { + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50, + } + ` + grad_clip (float): + Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 1.0 + lr (float): + Initila leraning rate. Defaults to 1e-4. + lr_scheduler (str): + One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`. + lr_scheduler_params (dict): + kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]}` + """ + + model: str = "wavegrad" + # Model specific params + generator_model: str = "wavegrad" + model_params: WavegradArgs = field(default_factory=WavegradArgs) + target_loss: str = "loss" # loss value to pick the best model to save after each epoch + + # Training - overrides + epochs: int = 10000 + batch_size: int = 96 + seq_len: int = 6144 + use_cache: bool = True + mixed_precision: bool = True + eval_split_size: int = 50 + + # NOISE SCHEDULE PARAMS + train_noise_schedule: dict = field(default_factory=lambda: {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}) + + test_noise_schedule: dict = field( + default_factory=lambda: { # inference noise schedule. Try TTS/bin/tune_wavegrad.py to find the optimal values. + "min_val": 1e-6, + "max_val": 1e-2, + "num_steps": 50, + } + ) + + # optimizer overrides + grad_clip: float = 1.0 + lr: float = 1e-4 # Initial learning rate. + lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_params: dict = field( + default_factory=lambda: {"gamma": 0.5, "milestones": [100000, 200000, 300000, 400000, 500000, 600000]} + ) diff --git a/TTS/vocoder/configs/wavernn_config.py b/TTS/vocoder/configs/wavernn_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f39400e5e50b56d4ff79c8c148fd518b3ec3b390 --- /dev/null +++ b/TTS/vocoder/configs/wavernn_config.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass, field + +from TTS.vocoder.configs.shared_configs import BaseVocoderConfig +from TTS.vocoder.models.wavernn import WavernnArgs + + +@dataclass +class WavernnConfig(BaseVocoderConfig): + """Defines parameters for Wavernn vocoder. + Example: + + >>> from TTS.vocoder.configs import WavernnConfig + >>> config = WavernnConfig() + + Args: + model (str): + Model name used for selecting the right model at initialization. Defaults to `wavernn`. + mode (str): + Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single + Gaussian Distribution and `bits` for quantized bits as the model's output. + mulaw (bool): + enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults + to `True`. + generator_model (str): + One of the generators from TTS.vocoder.models.*`. Every other non-GAN vocoder model is + considered as a generator too. Defaults to `WaveRNN`. + wavernn_model_params (dict): + kwargs for the WaveRNN model. Defaults to + `{ + "rnn_dims": 512, + "fc_dims": 512, + "compute_dims": 128, + "res_out_dims": 128, + "num_res_blocks": 10, + "use_aux_net": True, + "use_upsample_net": True, + "upsample_factors": [4, 8, 8] + }` + batched (bool): + enable / disable the batched inference. It speeds up the inference by splitting the input into segments and + processing the segments in a batch. Then it merges the outputs with a certain overlap and smoothing. If + you set it False, without CUDA, it is too slow to be practical. Defaults to True. + target_samples (int): + Size of the segments in batched mode. Defaults to 11000. + overlap_sampels (int): + Size of the overlap between consecutive segments. Defaults to 550. + batch_size (int): + Batch size used at training. Larger values use more memory. Defaults to 256. + seq_len (int): + Audio segment length used at training. Larger values use more memory. Defaults to 1280. + + use_noise_augment (bool): + enable / disable random noise added to the input waveform. The noise is added after computing the + features. Defaults to True. + use_cache (bool): + enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is + not large enough. Defaults to True. + mixed_precision (bool): + enable / disable mixed precision training. Default is True. + eval_split_size (int): + Number of samples used for evalutaion. Defaults to 50. + num_epochs_before_test (int): + Number of epochs waited to run the next evalution. Since inference takes some time, it is better to + wait some number of epochs not ot waste training time. Defaults to 10. + grad_clip (float): + Gradient clipping threshold. If <= 0.0, no clipping is applied. Defaults to 4.0 + lr (float): + Initila leraning rate. Defaults to 1e-4. + lr_scheduler (str): + One of the learning rate schedulers from `torch.optim.scheduler.*`. Defaults to `MultiStepLR`. + lr_scheduler_params (dict): + kwargs for the scheduler. Defaults to `{"gamma": 0.5, "milestones": [200000, 400000, 600000]}` + """ + + model: str = "wavernn" + + # Model specific params + model_args: WavernnArgs = field(default_factory=WavernnArgs) + target_loss: str = "loss" + + # Inference + batched: bool = True + target_samples: int = 11000 + overlap_samples: int = 550 + + # Training - overrides + epochs: int = 10000 + batch_size: int = 256 + seq_len: int = 1280 + use_noise_augment: bool = False + use_cache: bool = True + mixed_precision: bool = True + eval_split_size: int = 50 + num_epochs_before_test: int = ( + 10 # number of epochs to wait until the next test run (synthesizing a full audio clip). + ) + + # optimizer overrides + grad_clip: float = 4.0 + lr: float = 1e-4 # Initial learning rate. + lr_scheduler: str = "MultiStepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html + lr_scheduler_params: dict = field(default_factory=lambda: {"gamma": 0.5, "milestones": [200000, 400000, 600000]}) diff --git a/TTS/vocoder/datasets/__init__.py b/TTS/vocoder/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04462817a8b801ba7d04b98c8c62043ebe347c20 --- /dev/null +++ b/TTS/vocoder/datasets/__init__.py @@ -0,0 +1,55 @@ +from typing import List + +from coqpit import Coqpit +from torch.utils.data import Dataset + +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data +from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset + + +def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List) -> Dataset: + if config.model.lower() in "gan": + dataset = GANDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + ) + dataset.shuffle_mapping() + elif config.model.lower() == "wavegrad": + dataset = WaveGradDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + ) + elif config.model.lower() == "wavernn": + dataset = WaveRNNDataset( + ap=ap, + items=data_items, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad=config.model_params.pad, + mode=config.model_params.mode, + mulaw=config.model_params.mulaw, + is_training=not is_eval, + ) + else: + raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.") + return dataset diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0806c0d4963e0eedaf323f5a95fda86b4508a770 --- /dev/null +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -0,0 +1,149 @@ +import glob +import os +import random +from multiprocessing import Manager + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class GANDataset(Dataset): + """ + GAN Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly and returns + random segments of (audio, feature) couples. + """ + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + return_pairs=False, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + ): + super().__init__() + self.ap = ap + self.item_list = items + self.compute_feat = not isinstance(items[0], (tuple, list)) + self.seq_len = seq_len + self.hop_len = hop_len + self.pad_short = pad_short + self.conv_pad = conv_pad + self.return_pairs = return_pairs + self.is_training = is_training + self.return_segments = return_segments + self.use_cache = use_cache + self.use_noise_augment = use_noise_augment + + assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." + self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) + + # map G and D instances + self.G_to_D_mappings = list(range(len(self.item_list))) + self.shuffle_mapping() + + # cache acoustic features + if use_cache: + self.create_feature_cache() + + def create_feature_cache(self): + self.manager = Manager() + self.cache = self.manager.list() + self.cache += [None for _ in range(len(self.item_list))] + + @staticmethod + def find_wav_files(path): + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, idx): + """Return different items for Generator and Discriminator and + cache acoustic features""" + + # set the seed differently for each worker + if torch.utils.data.get_worker_info(): + random.seed(torch.utils.data.get_worker_info().seed) + + if self.return_segments: + item1 = self.load_item(idx) + if self.return_pairs: + idx2 = self.G_to_D_mappings[idx] + item2 = self.load_item(idx2) + return item1, item2 + return item1 + item1 = self.load_item(idx) + return item1 + + def _pad_short_samples(self, audio, mel=None): + """Pad samples shorter than the output sequence length""" + if len(audio) < self.seq_len: + audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0) + + if mel is not None and mel.shape[1] < self.feat_frame_len: + pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] + mel = np.pad( + mel, + ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), + mode="constant", + constant_values=pad_value.mean(), + ) + return audio, mel + + def shuffle_mapping(self): + random.shuffle(self.G_to_D_mappings) + + def load_item(self, idx): + """load (audio, feat) couple""" + if self.compute_feat: + # compute features from wav + wavpath = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio, mel = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + mel = self.ap.melspectrogram(audio) + audio, mel = self._pad_short_samples(audio, mel) + else: + # load precomputed features + wavpath, feat_path = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio, mel = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + mel = np.load(feat_path) + audio, mel = self._pad_short_samples(audio, mel) + + # correct the audio length wrt padding applied in stft + audio = np.pad(audio, (0, self.hop_len), mode="edge") + audio = audio[: mel.shape[-1] * self.hop_len] + assert ( + mel.shape[-1] * self.hop_len == audio.shape[-1] + ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}" + + audio = torch.from_numpy(audio).float().unsqueeze(0) + mel = torch.from_numpy(mel).float().squeeze(0) + + if self.return_segments: + max_mel_start = mel.shape[1] - self.feat_frame_len + mel_start = random.randint(0, max_mel_start) + mel_end = mel_start + self.feat_frame_len + mel = mel[:, mel_start:mel_end] + + audio_start = mel_start * self.hop_len + audio = audio[:, audio_start : audio_start + self.seq_len] + + if self.use_noise_augment and self.is_training and self.return_segments: + audio = audio + (1 / 32768) * torch.randn_like(audio) + return (mel, audio) diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..503bb04b2fba637ec3b8ab449018558898f05024 --- /dev/null +++ b/TTS/vocoder/datasets/preprocess.py @@ -0,0 +1,75 @@ +import glob +import os +from pathlib import Path + +import numpy as np +from coqpit import Coqpit +from tqdm import tqdm + +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize + + +def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): + """Process wav and compute mel and quantized wave signal. + It is mainly used by WaveRNN dataloader. + + Args: + out_path (str): Parent folder path to save the files. + config (Coqpit): Model config. + ap (AudioProcessor): Audio processor. + """ + os.makedirs(os.path.join(out_path, "quant"), exist_ok=True) + os.makedirs(os.path.join(out_path, "mel"), exist_ok=True) + wav_files = find_wav_files(config.data_path) + for path in tqdm(wav_files): + wav_name = Path(path).stem + quant_path = os.path.join(out_path, "quant", wav_name + ".npy") + mel_path = os.path.join(out_path, "mel", wav_name + ".npy") + y = ap.load_wav(path) + mel = ap.melspectrogram(y) + np.save(mel_path, mel) + if isinstance(config.mode, int): + quant = ( + mulaw_encode(wav=y, mulaw_qc=config.mode) + if config.model_args.mulaw + else quantize(x=y, quantize_bits=config.mode) + ) + np.save(quant_path, quant) + + +def find_wav_files(data_path, file_ext="wav"): + wav_paths = glob.glob(os.path.join(data_path, "**", f"*.{file_ext}"), recursive=True) + return wav_paths + + +def find_feat_files(data_path): + feat_paths = glob.glob(os.path.join(data_path, "**", "*.npy"), recursive=True) + return feat_paths + + +def load_wav_data(data_path, eval_split_size, file_ext="wav"): + wav_paths = find_wav_files(data_path, file_ext=file_ext) + assert len(wav_paths) > 0, f" [!] {data_path} is empty." + np.random.seed(0) + np.random.shuffle(wav_paths) + return wav_paths[:eval_split_size], wav_paths[eval_split_size:] + + +def load_wav_feat_data(data_path, feat_path, eval_split_size): + wav_paths = find_wav_files(data_path) + feat_paths = find_feat_files(feat_path) + + wav_paths.sort(key=lambda x: Path(x).stem) + feat_paths.sort(key=lambda x: Path(x).stem) + + assert len(wav_paths) == len(feat_paths), f" [!] {len(wav_paths)} vs {feat_paths}" + for wav, feat in zip(wav_paths, feat_paths): + wav_name = Path(wav).stem + feat_name = Path(feat).stem + assert wav_name == feat_name + + items = list(zip(wav_paths, feat_paths)) + np.random.seed(0) + np.random.shuffle(items) + return items[:eval_split_size], items[eval_split_size:] diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6f34bccb7c386642dbf3635e7e0ee677fca368ff --- /dev/null +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -0,0 +1,149 @@ +import glob +import os +import random +from multiprocessing import Manager +from typing import List, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class WaveGradDataset(Dataset): + """ + WaveGrad Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly and returns + random segments of (audio, feature) couples. + """ + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + ): + super().__init__() + self.ap = ap + self.item_list = items + self.seq_len = seq_len if return_segments else None + self.hop_len = hop_len + self.pad_short = pad_short + self.conv_pad = conv_pad + self.is_training = is_training + self.return_segments = return_segments + self.use_cache = use_cache + self.use_noise_augment = use_noise_augment + + if return_segments: + assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len." + self.feat_frame_len = seq_len // hop_len + (2 * conv_pad) + + # cache acoustic features + if use_cache: + self.create_feature_cache() + + def create_feature_cache(self): + self.manager = Manager() + self.cache = self.manager.list() + self.cache += [None for _ in range(len(self.item_list))] + + @staticmethod + def find_wav_files(path): + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, idx): + item = self.load_item(idx) + return item + + def load_test_samples(self, num_samples: int) -> List[Tuple]: + """Return test samples. + + Args: + num_samples (int): Number of samples to return. + + Returns: + List[Tuple]: melspectorgram and audio. + + Shapes: + - melspectrogram (Tensor): :math:`[C, T]` + - audio (Tensor): :math:`[T_audio]` + """ + samples = [] + return_segments = self.return_segments + self.return_segments = False + for idx in range(num_samples): + mel, audio = self.load_item(idx) + samples.append([mel, audio]) + self.return_segments = return_segments + return samples + + def load_item(self, idx): + """load (audio, feat) couple""" + # compute features from wav + wavpath = self.item_list[idx] + + if self.use_cache and self.cache[idx] is not None: + audio = self.cache[idx] + else: + audio = self.ap.load_wav(wavpath) + + if self.return_segments: + # correct audio length wrt segment length + if audio.shape[-1] < self.seq_len + self.pad_short: + audio = np.pad( + audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0 + ) + assert ( + audio.shape[-1] >= self.seq_len + self.pad_short + ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" + + # correct the audio length wrt hop length + p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1] + audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0) + + if self.use_cache: + self.cache[idx] = audio + + if self.return_segments: + max_start = len(audio) - self.seq_len + start = random.randint(0, max_start) + end = start + self.seq_len + audio = audio[start:end] + + if self.use_noise_augment and self.is_training and self.return_segments: + audio = audio + (1 / 32768) * torch.randn_like(audio) + + mel = self.ap.melspectrogram(audio) + mel = mel[..., :-1] # ignore the padding + + audio = torch.from_numpy(audio).float() + mel = torch.from_numpy(mel).float().squeeze(0) + return (mel, audio) + + @staticmethod + def collate_full_clips(batch): + """This is used in tune_wavegrad.py. + It pads sequences to the max length.""" + max_mel_length = max([b[0].shape[1] for b in batch]) if len(batch) > 1 else batch[0][0].shape[1] + max_audio_length = max([b[1].shape[0] for b in batch]) if len(batch) > 1 else batch[0][1].shape[0] + + mels = torch.zeros([len(batch), batch[0][0].shape[0], max_mel_length]) + audios = torch.zeros([len(batch), max_audio_length]) + + for idx, b in enumerate(batch): + mel = b[0] + audio = b[1] + mels[idx, :, : mel.shape[1]] = mel + audios[idx, : audio.shape[0]] = audio + + return mels, audios diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4c4f5c48dfc9bb1f0e8d803083f25f479d930686 --- /dev/null +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -0,0 +1,119 @@ +import logging + +import numpy as np +import torch +from torch.utils.data import Dataset + +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize + +logger = logging.getLogger(__name__) + + +class WaveRNNDataset(Dataset): + """ + WaveRNN Dataset searchs for all the wav files under root path + and converts them to acoustic features on the fly. + """ + + def __init__(self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, return_segments=True): + super().__init__() + self.ap = ap + self.compute_feat = not isinstance(items[0], (tuple, list)) + self.item_list = items + self.seq_len = seq_len + self.hop_len = hop_len + self.mel_len = seq_len // hop_len + self.pad = pad + self.mode = mode + self.mulaw = mulaw + self.is_training = is_training + self.return_segments = return_segments + + assert self.seq_len % self.hop_len == 0 + + def __len__(self): + return len(self.item_list) + + def __getitem__(self, index): + item = self.load_item(index) + return item + + def load_test_samples(self, num_samples): + samples = [] + return_segments = self.return_segments + self.return_segments = False + for idx in range(num_samples): + mel, audio, _ = self.load_item(idx) + samples.append([mel, audio]) + self.return_segments = return_segments + return samples + + def load_item(self, index): + """ + load (audio, feat) couple if feature_path is set + else compute it on the fly + """ + if self.compute_feat: + wavpath = self.item_list[index] + audio = self.ap.load_wav(wavpath) + if self.return_segments: + min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len) + else: + min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len) + if audio.shape[0] < min_audio_len: + logger.warning("Instance is too short: %s", wavpath) + audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len]) + mel = self.ap.melspectrogram(audio) + + if self.mode in ["gauss", "mold"]: + x_input = audio + elif isinstance(self.mode, int): + x_input = ( + mulaw_encode(wav=audio, mulaw_qc=self.mode) + if self.mulaw + else quantize(x=audio, quantize_bits=self.mode) + ) + else: + raise RuntimeError("Unknown dataset mode - ", self.mode) + + else: + wavpath, feat_path = self.item_list[index] + mel = np.load(feat_path.replace("/quant/", "/mel/")) + + if mel.shape[-1] < self.mel_len + 2 * self.pad: + logger.warning("Instance is too short: %s", wavpath) + self.item_list[index] = self.item_list[index + 1] + feat_path = self.item_list[index] + mel = np.load(feat_path.replace("/quant/", "/mel/")) + if self.mode in ["gauss", "mold"]: + x_input = self.ap.load_wav(wavpath) + elif isinstance(self.mode, int): + x_input = np.load(feat_path.replace("/mel/", "/quant/")) + else: + raise RuntimeError("Unknown dataset mode - ", self.mode) + + return mel, x_input, wavpath + + def collate(self, batch): + mel_win = self.seq_len // self.hop_len + 2 * self.pad + max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch] + + mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] + sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets] + + mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)] + + coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)] + + mels = np.stack(mels).astype(np.float32) + if self.mode in ["gauss", "mold"]: + coarse = np.stack(coarse).astype(np.float32) + coarse = torch.FloatTensor(coarse) + x_input = coarse[:, : self.seq_len] + elif isinstance(self.mode, int): + coarse = np.stack(coarse).astype(np.int64) + coarse = torch.LongTensor(coarse) + x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0 + y_coarse = coarse[:, 1:] + mels = torch.FloatTensor(mels) + return x_input, mels, y_coarse diff --git a/TTS/vocoder/layers/__init__.py b/TTS/vocoder/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..8dd75133bbc38b95ce6ea6a9a05079b9b205a539 --- /dev/null +++ b/TTS/vocoder/layers/hifigan.py @@ -0,0 +1,56 @@ +from torch import nn +from torch.nn.utils.parametrize import remove_parametrizations + + +# pylint: disable=dangerous-default-value +class ResStack(nn.Module): + def __init__(self, kernel, channel, padding, dilations=[1, 3, 5]): + super().__init__() + resstack = [] + for dilation in dilations: + resstack += [ + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(dilation), + nn.utils.parametrizations.weight_norm( + nn.Conv1d(channel, channel, kernel_size=kernel, dilation=dilation) + ), + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(padding), + nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)), + ] + self.resstack = nn.Sequential(*resstack) + + self.shortcut = nn.utils.parametrizations.weight_norm(nn.Conv1d(channel, channel, kernel_size=1)) + + def forward(self, x): + x1 = self.shortcut(x) + x2 = self.resstack(x) + return x1 + x2 + + def remove_weight_norm(self): + remove_parametrizations(self.shortcut, "weight") + remove_parametrizations(self.resstack[2], "weight") + remove_parametrizations(self.resstack[5], "weight") + remove_parametrizations(self.resstack[8], "weight") + remove_parametrizations(self.resstack[11], "weight") + remove_parametrizations(self.resstack[14], "weight") + remove_parametrizations(self.resstack[17], "weight") + + +class MRF(nn.Module): + def __init__(self, kernels, channel, dilations=[1, 3, 5]): # # pylint: disable=dangerous-default-value + super().__init__() + self.resblock1 = ResStack(kernels[0], channel, 0, dilations) + self.resblock2 = ResStack(kernels[1], channel, 6, dilations) + self.resblock3 = ResStack(kernels[2], channel, 12, dilations) + + def forward(self, x): + x1 = self.resblock1(x) + x2 = self.resblock2(x) + x3 = self.resblock3(x) + return x1 + x2 + x3 + + def remove_weight_norm(self): + self.resblock1.remove_weight_norm() + self.resblock2.remove_weight_norm() + self.resblock3.remove_weight_norm() diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4dd725ef32a8369f890e6bdd041054cbd45d15 --- /dev/null +++ b/TTS/vocoder/layers/losses.py @@ -0,0 +1,368 @@ +from typing import Dict, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss + +################################# +# GENERATOR LOSSES +################################# + + +class STFTLoss(nn.Module): + """STFT loss. Input generate and real waveforms are converted + to spectrograms compared with L1 and Spectral convergence losses. + It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + + def __init__(self, n_fft, hop_length, win_length): + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.stft = TorchSTFT(n_fft, hop_length, win_length) + + def forward(self, y_hat, y): + y_hat_M = self.stft(y_hat) + y_M = self.stft(y) + # magnitude loss + loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) + # spectral convergence loss + loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro") + return loss_mag, loss_sc + + +class MultiScaleSTFTLoss(torch.nn.Module): + """Multi-scale STFT loss. Input generate and real waveforms are converted + to spectrograms compared with L1 and Spectral convergence losses. + It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + + def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)): + super().__init__() + self.loss_funcs = torch.nn.ModuleList() + for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths): + self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length)) + + def forward(self, y_hat, y): + N = len(self.loss_funcs) + loss_sc = 0 + loss_mag = 0 + for f in self.loss_funcs: + lm, lsc = f(y_hat, y) + loss_mag += lm + loss_sc += lsc + loss_sc /= N + loss_mag /= N + return loss_mag, loss_sc + + +class L1SpecLoss(nn.Module): + """L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf""" + + def __init__( + self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True + ): + super().__init__() + self.use_mel = use_mel + self.stft = TorchSTFT( + n_fft, + hop_length, + win_length, + sample_rate=sample_rate, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + n_mels=n_mels, + use_mel=use_mel, + ) + + def forward(self, y_hat, y): + y_hat_M = self.stft(y_hat) + y_M = self.stft(y) + # magnitude loss + loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) + return loss_mag + + +class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): + """Multiscale STFT loss for multi band model outputs. + From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106""" + + # pylint: disable=no-self-use + def forward(self, y_hat, y): + y_hat = y_hat.view(-1, 1, y_hat.shape[2]) + y = y.view(-1, 1, y.shape[2]) + return super().forward(y_hat.squeeze(1), y.squeeze(1)) + + +class MSEGLoss(nn.Module): + """Mean Squared Generator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_real): + loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape)) + return loss_fake + + +class HingeGLoss(nn.Module): + """Hinge Discriminator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_real): + # TODO: this might be wrong + loss_fake = torch.mean(F.relu(1.0 - score_real)) + return loss_fake + + +################################## +# DISCRIMINATOR LOSSES +################################## + + +class MSEDLoss(nn.Module): + """Mean Squared Discriminator Loss""" + + def __init__( + self, + ): + super().__init__() + self.loss_func = nn.MSELoss() + + # pylint: disable=no-self-use + def forward(self, score_fake, score_real): + loss_real = self.loss_func(score_real, score_real.new_ones(score_real.shape)) + loss_fake = self.loss_func(score_fake, score_fake.new_zeros(score_fake.shape)) + loss_d = loss_real + loss_fake + return loss_d, loss_real, loss_fake + + +class HingeDLoss(nn.Module): + """Hinge Discriminator Loss""" + + # pylint: disable=no-self-use + def forward(self, score_fake, score_real): + loss_real = torch.mean(F.relu(1.0 - score_real)) + loss_fake = torch.mean(F.relu(1.0 + score_fake)) + loss_d = loss_real + loss_fake + return loss_d, loss_real, loss_fake + + +class MelganFeatureLoss(nn.Module): + def __init__( + self, + ): + super().__init__() + self.loss_func = nn.L1Loss() + + # pylint: disable=no-self-use + def forward(self, fake_feats, real_feats): + loss_feats = 0 + num_feats = 0 + for idx, _ in enumerate(fake_feats): + for fake_feat, real_feat in zip(fake_feats[idx], real_feats[idx]): + loss_feats += self.loss_func(fake_feat, real_feat) + num_feats += 1 + loss_feats = loss_feats / num_feats + return loss_feats + + +##################################### +# LOSS WRAPPERS +##################################### + + +def _apply_G_adv_loss(scores_fake, loss_func): + """Compute G adversarial loss function + and normalize values""" + adv_loss = 0 + if isinstance(scores_fake, list): + for score_fake in scores_fake: + fake_loss = loss_func(score_fake) + adv_loss += fake_loss + adv_loss /= len(scores_fake) + else: + fake_loss = loss_func(scores_fake) + adv_loss = fake_loss + return adv_loss + + +def _apply_D_loss(scores_fake, scores_real, loss_func): + """Compute D loss func and normalize loss values""" + loss = 0 + real_loss = 0 + fake_loss = 0 + if isinstance(scores_fake, list): + # multi-scale loss + for score_fake, score_real in zip(scores_fake, scores_real): + total_loss, real_loss_, fake_loss_ = loss_func(score_fake=score_fake, score_real=score_real) + loss += total_loss + real_loss += real_loss_ + fake_loss += fake_loss_ + # normalize loss values with number of scales (discriminators) + loss /= len(scores_fake) + real_loss /= len(scores_real) + fake_loss /= len(scores_fake) + else: + # single scale loss + total_loss, real_loss, fake_loss = loss_func(scores_fake, scores_real) + loss = total_loss + return loss, real_loss, fake_loss + + +################################## +# MODEL LOSSES +################################## + + +class GeneratorLoss(nn.Module): + """Generator Loss Wrapper. Based on model configuration it sets a right set of loss functions and computes + losses. It allows to experiment with different combinations of loss functions with different models by just + changing configurations. + + Args: + C (Coqpit): model configuration. + """ + + def __init__(self, C): + super().__init__() + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." + + self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False + self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False + self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False + self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False + self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False + self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False + + self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0 + self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0 + self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0 + self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0 + self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0 + self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0 + + if C.use_stft_loss: + self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params) + if C.use_subband_stft_loss: + self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params) + if C.use_mse_gan_loss: + self.mse_loss = MSEGLoss() + if C.use_hinge_gan_loss: + self.hinge_loss = HingeGLoss() + if C.use_feat_match_loss: + self.feat_match_loss = MelganFeatureLoss() + if C.use_l1_spec_loss: + assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"] + self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params) + + def forward( + self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None + ): + gen_loss = 0 + adv_loss = 0 + return_dict = {} + + # STFT Loss + if self.use_stft_loss: + stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1)) + return_dict["G_stft_loss_mg"] = stft_loss_mg + return_dict["G_stft_loss_sc"] = stft_loss_sc + gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) + + # L1 Spec loss + if self.use_l1_spec_loss: + l1_spec_loss = self.l1_spec_loss(y_hat, y) + return_dict["G_l1_spec_loss"] = l1_spec_loss + gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss + + # subband STFT Loss + if self.use_subband_stft_loss: + subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub) + return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg + return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc + gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) + + # multiscale MSE adversarial loss + if self.use_mse_gan_loss and scores_fake is not None: + mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss) + return_dict["G_mse_fake_loss"] = mse_fake_loss + adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss + + # multiscale Hinge adversarial loss + if self.use_hinge_gan_loss and not scores_fake is not None: + hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss) + return_dict["G_hinge_fake_loss"] = hinge_fake_loss + adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss + + # Feature Matching Loss + if self.use_feat_match_loss and feats_fake is not None: + feat_match_loss = self.feat_match_loss(feats_fake, feats_real) + return_dict["G_feat_match_loss"] = feat_match_loss + adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss + return_dict["loss"] = gen_loss + adv_loss + return_dict["G_gen_loss"] = gen_loss + return_dict["G_adv_loss"] = adv_loss + return return_dict + + +class DiscriminatorLoss(nn.Module): + """Like ```GeneratorLoss```""" + + def __init__(self, C): + super().__init__() + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." + + self.use_mse_gan_loss = C.use_mse_gan_loss + self.use_hinge_gan_loss = C.use_hinge_gan_loss + + if C.use_mse_gan_loss: + self.mse_loss = MSEDLoss() + if C.use_hinge_gan_loss: + self.hinge_loss = HingeDLoss() + + def forward(self, scores_fake, scores_real): + loss = 0 + return_dict = {} + + if self.use_mse_gan_loss: + mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss + ) + return_dict["D_mse_gan_loss"] = mse_D_loss + return_dict["D_mse_gan_real_loss"] = mse_D_real_loss + return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss + loss += mse_D_loss + + if self.use_hinge_gan_loss: + hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss( + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss + ) + return_dict["D_hinge_gan_loss"] = hinge_D_loss + return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss + return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss + loss += hinge_D_loss + + return_dict["loss"] = loss + return return_dict + + +class WaveRNNLoss(nn.Module): + def __init__(self, wave_rnn_mode: Union[str, int]): + super().__init__() + if wave_rnn_mode == "mold": + self.loss_func = discretized_mix_logistic_loss + elif wave_rnn_mode == "gauss": + self.loss_func = gaussian_loss + elif isinstance(wave_rnn_mode, int): + self.loss_func = torch.nn.CrossEntropyLoss() + else: + raise ValueError(" [!] Unknown mode for Wavernn.") + + def forward(self, y_hat, y) -> Dict: + loss = self.loss_func(y_hat, y) + return {"loss": loss} diff --git a/TTS/vocoder/layers/lvc_block.py b/TTS/vocoder/layers/lvc_block.py new file mode 100644 index 0000000000000000000000000000000000000000..8913a1132ec769fd304077412289c01c0d1cb17b --- /dev/null +++ b/TTS/vocoder/layers/lvc_block.py @@ -0,0 +1,198 @@ +import torch +import torch.nn.functional as F + + +class KernelPredictor(torch.nn.Module): + """Kernel predictor for the location-variable convolutions""" + + def __init__( # pylint: disable=dangerous-default-value + self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1}, + ): + """ + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): + kpnet_ + """ + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers + l_b = conv_out_channels * conv_layers + + padding = (kpnet_conv_size - 1) // 2 + self.input_conv = torch.nn.Sequential( + torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_conv = torch.nn.Sequential( + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True) + self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True) + + def forward(self, c): + """ + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + Returns: + """ + batch, _, cond_length = c.shape + + c = self.input_conv(c) + c = c + self.residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + + kernels = k.contiguous().view( + batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length + ) + bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length) + return kernels, bias + + +class LVCBlock(torch.nn.Module): + """the location-variable convolutions""" + + def __init__( + self, + in_channels, + cond_channels, + upsample_ratio, + conv_layers=4, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = conv_layers + self.conv_kernel_size = conv_kernel_size + self.convs = torch.nn.ModuleList() + + self.upsample = torch.nn.ConvTranspose1d( + in_channels, + in_channels, + kernel_size=upsample_ratio * 2, + stride=upsample_ratio, + padding=upsample_ratio // 2 + upsample_ratio % 2, + output_padding=upsample_ratio % 2, + ) + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=conv_layers, + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout, + ) + + for i in range(conv_layers): + padding = (3**i) * int((conv_kernel_size - 1) / 2) + conv = torch.nn.Conv1d( + in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i + ) + + self.convs.append(conv) + + def forward(self, x, c): + """forward propagation of the location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + """ + in_channels = x.shape[1] + kernels, bias = self.kernel_predictor(c) + + x = F.leaky_relu(x, 0.2) + x = self.upsample(x) + + for i in range(self.conv_layers): + y = F.leaky_relu(x, 0.2) + y = self.convs[i](y) + y = F.leaky_relu(y, 0.2) + + k = kernels[:, i, :, :, :, :] + b = bias[:, i, :, :] + y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length) + x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :]) + return x + + @staticmethod + def location_variable_convolution(x, kernel, bias, dilation, hop_size): + """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). + """ + batch, _, in_length = x.shape + batch, _, out_channels, kernel_size, kernel_length = kernel.shape + + assert in_length == ( + kernel_length * hop_size + ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), "constant", 0) + x = x.unfold( + 3, dilation, dilation + ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum("bildsk,biokl->bolsd", x, kernel) + o = o + bias.unsqueeze(-1).unsqueeze(-1) + o = o.contiguous().view(batch, out_channels, -1) + return o diff --git a/TTS/vocoder/layers/melgan.py b/TTS/vocoder/layers/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..7ad41a0f78d7e33021b445936a79318c63447284 --- /dev/null +++ b/TTS/vocoder/layers/melgan.py @@ -0,0 +1,43 @@ +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + + +class ResidualStack(nn.Module): + def __init__(self, channels, num_res_blocks, kernel_size): + super().__init__() + + assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." + base_padding = (kernel_size - 1) // 2 + + self.blocks = nn.ModuleList() + for idx in range(num_res_blocks): + layer_kernel_size = kernel_size + layer_dilation = layer_kernel_size**idx + layer_padding = base_padding * layer_dilation + self.blocks += [ + nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(layer_padding), + weight_norm( + nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True) + ), + nn.LeakyReLU(0.2), + weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)), + ) + ] + + self.shortcuts = nn.ModuleList( + [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for _ in range(num_res_blocks)] + ) + + def forward(self, x): + for block, shortcut in zip(self.blocks, self.shortcuts): + x = shortcut(x) + block(x) + return x + + def remove_weight_norm(self): + for block, shortcut in zip(self.blocks, self.shortcuts): + remove_parametrizations(block[2], "weight") + remove_parametrizations(block[4], "weight") + remove_parametrizations(shortcut, "weight") diff --git a/TTS/vocoder/layers/parallel_wavegan.py b/TTS/vocoder/layers/parallel_wavegan.py new file mode 100644 index 0000000000000000000000000000000000000000..51142e5eceb20564585635a9040a24bc8eb3b6e3 --- /dev/null +++ b/TTS/vocoder/layers/parallel_wavegan.py @@ -0,0 +1,77 @@ +import torch +from torch.nn import functional as F + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__( + self, + kernel_size=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False, + ): + super().__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = torch.nn.Conv1d( + res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias + ) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias) + self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias) + + def forward(self, x, c): + """ + x: B x D_res x T + c: B x D_aux x T + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * (0.5**2) + + return x, s diff --git a/TTS/vocoder/layers/pqmf.py b/TTS/vocoder/layers/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..6253efbbefc32222464a97bee99707d46bcdcf8b --- /dev/null +++ b/TTS/vocoder/layers/pqmf.py @@ -0,0 +1,53 @@ +import numpy as np +import torch +import torch.nn.functional as F +from scipy import signal as sig + + +# adapted from +# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan +class PQMF(torch.nn.Module): + def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): + super().__init__() + + self.N = N + self.taps = taps + self.cutoff = cutoff + self.beta = beta + + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) + H = np.zeros((N, len(QMF))) + G = np.zeros((N, len(QMF))) + for k in range(N): + constant_factor = ( + (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + ) # TODO: (taps - 1) -> taps + phase = (-1) ** k * np.pi / 4 + H[k] = 2 * QMF * np.cos(constant_factor + phase) + + G[k] = 2 * QMF * np.cos(constant_factor - phase) + + H = torch.from_numpy(H[:, None, :]).float() + G = torch.from_numpy(G[None, :, :]).float() + + self.register_buffer("H", H) + self.register_buffer("G", G) + + updown_filter = torch.zeros((N, N, N)).float() + for k in range(N): + updown_filter[k, k, 0] = 1.0 + self.register_buffer("updown_filter", updown_filter) + self.N = N + + self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) + + def forward(self, x): + return self.analysis(x) + + def analysis(self, x): + return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N) + + def synthesis(self, x): + x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N) + x = F.conv1d(x, self.G, padding=self.taps // 2) + return x diff --git a/TTS/vocoder/layers/qmf.dat b/TTS/vocoder/layers/qmf.dat new file mode 100644 index 0000000000000000000000000000000000000000..17eab1379de991c36897c2ce701802ef76849c0d --- /dev/null +++ b/TTS/vocoder/layers/qmf.dat @@ -0,0 +1,640 @@ + 0.0000000e+000 + -5.5252865e-004 + -5.6176926e-004 + -4.9475181e-004 + -4.8752280e-004 + -4.8937912e-004 + -5.0407143e-004 + -5.2265643e-004 + -5.4665656e-004 + -5.6778026e-004 + -5.8709305e-004 + -6.1327474e-004 + -6.3124935e-004 + -6.5403334e-004 + -6.7776908e-004 + -6.9416146e-004 + -7.1577365e-004 + -7.2550431e-004 + -7.4409419e-004 + -7.4905981e-004 + -7.6813719e-004 + -7.7248486e-004 + -7.8343323e-004 + -7.7798695e-004 + -7.8036647e-004 + -7.8014496e-004 + -7.7579773e-004 + -7.6307936e-004 + -7.5300014e-004 + -7.3193572e-004 + -7.2153920e-004 + -6.9179375e-004 + -6.6504151e-004 + -6.3415949e-004 + -5.9461189e-004 + -5.5645764e-004 + -5.1455722e-004 + -4.6063255e-004 + -4.0951215e-004 + -3.5011759e-004 + -2.8969812e-004 + -2.0983373e-004 + -1.4463809e-004 + -6.1733441e-005 + 1.3494974e-005 + 1.0943831e-004 + 2.0430171e-004 + 2.9495311e-004 + 4.0265402e-004 + 5.1073885e-004 + 6.2393761e-004 + 7.4580259e-004 + 8.6084433e-004 + 9.8859883e-004 + 1.1250155e-003 + 1.2577885e-003 + 1.3902495e-003 + 1.5443220e-003 + 1.6868083e-003 + 1.8348265e-003 + 1.9841141e-003 + 2.1461584e-003 + 2.3017255e-003 + 2.4625617e-003 + 2.6201759e-003 + 2.7870464e-003 + 2.9469448e-003 + 3.1125421e-003 + 3.2739613e-003 + 3.4418874e-003 + 3.6008268e-003 + 3.7603923e-003 + 3.9207432e-003 + 4.0819753e-003 + 4.2264269e-003 + 4.3730720e-003 + 4.5209853e-003 + 4.6606461e-003 + 4.7932561e-003 + 4.9137604e-003 + 5.0393023e-003 + 5.1407354e-003 + 5.2461166e-003 + 5.3471681e-003 + 5.4196776e-003 + 5.4876040e-003 + 5.5475715e-003 + 5.5938023e-003 + 5.6220643e-003 + 5.6455197e-003 + 5.6389200e-003 + 5.6266114e-003 + 5.5917129e-003 + 5.5404364e-003 + 5.4753783e-003 + 5.3838976e-003 + 5.2715759e-003 + 5.1382275e-003 + 4.9839688e-003 + 4.8109469e-003 + 4.6039530e-003 + 4.3801862e-003 + 4.1251642e-003 + 3.8456408e-003 + 3.5401247e-003 + 3.2091886e-003 + 2.8446758e-003 + 2.4508540e-003 + 2.0274176e-003 + 1.5784683e-003 + 1.0902329e-003 + 5.8322642e-004 + 2.7604519e-005 + -5.4642809e-004 + -1.1568136e-003 + -1.8039473e-003 + -2.4826724e-003 + -3.1933778e-003 + -3.9401124e-003 + -4.7222596e-003 + -5.5337211e-003 + -6.3792293e-003 + -7.2615817e-003 + -8.1798233e-003 + -9.1325330e-003 + -1.0115022e-002 + -1.1131555e-002 + -1.2185000e-002 + -1.3271822e-002 + -1.4390467e-002 + -1.5540555e-002 + -1.6732471e-002 + -1.7943338e-002 + -1.9187243e-002 + -2.0453179e-002 + -2.1746755e-002 + -2.3068017e-002 + -2.4416099e-002 + -2.5787585e-002 + -2.7185943e-002 + -2.8607217e-002 + -3.0050266e-002 + -3.1501761e-002 + -3.2975408e-002 + -3.4462095e-002 + -3.5969756e-002 + -3.7481285e-002 + -3.9005368e-002 + -4.0534917e-002 + -4.2064909e-002 + -4.3609754e-002 + -4.5148841e-002 + -4.6684303e-002 + -4.8216572e-002 + -4.9738576e-002 + -5.1255616e-002 + -5.2763075e-002 + -5.4245277e-002 + -5.5717365e-002 + -5.7161645e-002 + -5.8591568e-002 + -5.9983748e-002 + -6.1345517e-002 + -6.2685781e-002 + -6.3971590e-002 + -6.5224711e-002 + -6.6436751e-002 + -6.7607599e-002 + -6.8704383e-002 + -6.9763024e-002 + -7.0762871e-002 + -7.1700267e-002 + -7.2568258e-002 + -7.3362026e-002 + -7.4100364e-002 + -7.4745256e-002 + -7.5313734e-002 + -7.5800836e-002 + -7.6199248e-002 + -7.6499217e-002 + -7.6709349e-002 + -7.6817398e-002 + -7.6823001e-002 + -7.6720492e-002 + -7.6505072e-002 + -7.6174832e-002 + -7.5730576e-002 + -7.5157626e-002 + -7.4466439e-002 + -7.3640601e-002 + -7.2677464e-002 + -7.1582636e-002 + -7.0353307e-002 + -6.8966401e-002 + -6.7452502e-002 + -6.5769067e-002 + -6.3944481e-002 + -6.1960278e-002 + -5.9816657e-002 + -5.7515269e-002 + -5.5046003e-002 + -5.2409382e-002 + -4.9597868e-002 + -4.6630331e-002 + -4.3476878e-002 + -4.0145828e-002 + -3.6641812e-002 + -3.2958393e-002 + -2.9082401e-002 + -2.5030756e-002 + -2.0799707e-002 + -1.6370126e-002 + -1.1762383e-002 + -6.9636862e-003 + -1.9765601e-003 + 3.2086897e-003 + 8.5711749e-003 + 1.4128883e-002 + 1.9883413e-002 + 2.5822729e-002 + 3.1953127e-002 + 3.8277657e-002 + 4.4780682e-002 + 5.1480418e-002 + 5.8370533e-002 + 6.5440985e-002 + 7.2694330e-002 + 8.0137293e-002 + 8.7754754e-002 + 9.5553335e-002 + 1.0353295e-001 + 1.1168269e-001 + 1.2000780e-001 + 1.2850029e-001 + 1.3715518e-001 + 1.4597665e-001 + 1.5496071e-001 + 1.6409589e-001 + 1.7338082e-001 + 1.8281725e-001 + 1.9239667e-001 + 2.0212502e-001 + 2.1197359e-001 + 2.2196527e-001 + 2.3206909e-001 + 2.4230169e-001 + 2.5264803e-001 + 2.6310533e-001 + 2.7366340e-001 + 2.8432142e-001 + 2.9507167e-001 + 3.0590986e-001 + 3.1682789e-001 + 3.2781137e-001 + 3.3887227e-001 + 3.4999141e-001 + 3.6115899e-001 + 3.7237955e-001 + 3.8363500e-001 + 3.9492118e-001 + 4.0623177e-001 + 4.1756969e-001 + 4.2891199e-001 + 4.4025538e-001 + 4.5159965e-001 + 4.6293081e-001 + 4.7424532e-001 + 4.8552531e-001 + 4.9677083e-001 + 5.0798175e-001 + 5.1912350e-001 + 5.3022409e-001 + 5.4125534e-001 + 5.5220513e-001 + 5.6307891e-001 + 5.7385241e-001 + 5.8454032e-001 + 5.9511231e-001 + 6.0557835e-001 + 6.1591099e-001 + 6.2612427e-001 + 6.3619801e-001 + 6.4612697e-001 + 6.5590163e-001 + 6.6551399e-001 + 6.7496632e-001 + 6.8423533e-001 + 6.9332824e-001 + 7.0223887e-001 + 7.1094104e-001 + 7.1944626e-001 + 7.2774489e-001 + 7.3582118e-001 + 7.4368279e-001 + 7.5131375e-001 + 7.5870808e-001 + 7.6586749e-001 + 7.7277809e-001 + 7.7942875e-001 + 7.8583531e-001 + 7.9197358e-001 + 7.9784664e-001 + 8.0344858e-001 + 8.0876950e-001 + 8.1381913e-001 + 8.1857760e-001 + 8.2304199e-001 + 8.2722753e-001 + 8.3110385e-001 + 8.3469374e-001 + 8.3797173e-001 + 8.4095414e-001 + 8.4362383e-001 + 8.4598185e-001 + 8.4803158e-001 + 8.4978052e-001 + 8.5119715e-001 + 8.5230470e-001 + 8.5310209e-001 + 8.5357206e-001 + 8.5373856e-001 + 8.5357206e-001 + 8.5310209e-001 + 8.5230470e-001 + 8.5119715e-001 + 8.4978052e-001 + 8.4803158e-001 + 8.4598185e-001 + 8.4362383e-001 + 8.4095414e-001 + 8.3797173e-001 + 8.3469374e-001 + 8.3110385e-001 + 8.2722753e-001 + 8.2304199e-001 + 8.1857760e-001 + 8.1381913e-001 + 8.0876950e-001 + 8.0344858e-001 + 7.9784664e-001 + 7.9197358e-001 + 7.8583531e-001 + 7.7942875e-001 + 7.7277809e-001 + 7.6586749e-001 + 7.5870808e-001 + 7.5131375e-001 + 7.4368279e-001 + 7.3582118e-001 + 7.2774489e-001 + 7.1944626e-001 + 7.1094104e-001 + 7.0223887e-001 + 6.9332824e-001 + 6.8423533e-001 + 6.7496632e-001 + 6.6551399e-001 + 6.5590163e-001 + 6.4612697e-001 + 6.3619801e-001 + 6.2612427e-001 + 6.1591099e-001 + 6.0557835e-001 + 5.9511231e-001 + 5.8454032e-001 + 5.7385241e-001 + 5.6307891e-001 + 5.5220513e-001 + 5.4125534e-001 + 5.3022409e-001 + 5.1912350e-001 + 5.0798175e-001 + 4.9677083e-001 + 4.8552531e-001 + 4.7424532e-001 + 4.6293081e-001 + 4.5159965e-001 + 4.4025538e-001 + 4.2891199e-001 + 4.1756969e-001 + 4.0623177e-001 + 3.9492118e-001 + 3.8363500e-001 + 3.7237955e-001 + 3.6115899e-001 + 3.4999141e-001 + 3.3887227e-001 + 3.2781137e-001 + 3.1682789e-001 + 3.0590986e-001 + 2.9507167e-001 + 2.8432142e-001 + 2.7366340e-001 + 2.6310533e-001 + 2.5264803e-001 + 2.4230169e-001 + 2.3206909e-001 + 2.2196527e-001 + 2.1197359e-001 + 2.0212502e-001 + 1.9239667e-001 + 1.8281725e-001 + 1.7338082e-001 + 1.6409589e-001 + 1.5496071e-001 + 1.4597665e-001 + 1.3715518e-001 + 1.2850029e-001 + 1.2000780e-001 + 1.1168269e-001 + 1.0353295e-001 + 9.5553335e-002 + 8.7754754e-002 + 8.0137293e-002 + 7.2694330e-002 + 6.5440985e-002 + 5.8370533e-002 + 5.1480418e-002 + 4.4780682e-002 + 3.8277657e-002 + 3.1953127e-002 + 2.5822729e-002 + 1.9883413e-002 + 1.4128883e-002 + 8.5711749e-003 + 3.2086897e-003 + -1.9765601e-003 + -6.9636862e-003 + -1.1762383e-002 + -1.6370126e-002 + -2.0799707e-002 + -2.5030756e-002 + -2.9082401e-002 + -3.2958393e-002 + -3.6641812e-002 + -4.0145828e-002 + -4.3476878e-002 + -4.6630331e-002 + -4.9597868e-002 + -5.2409382e-002 + -5.5046003e-002 + -5.7515269e-002 + -5.9816657e-002 + -6.1960278e-002 + -6.3944481e-002 + -6.5769067e-002 + -6.7452502e-002 + -6.8966401e-002 + -7.0353307e-002 + -7.1582636e-002 + -7.2677464e-002 + -7.3640601e-002 + -7.4466439e-002 + -7.5157626e-002 + -7.5730576e-002 + -7.6174832e-002 + -7.6505072e-002 + -7.6720492e-002 + -7.6823001e-002 + -7.6817398e-002 + -7.6709349e-002 + -7.6499217e-002 + -7.6199248e-002 + -7.5800836e-002 + -7.5313734e-002 + -7.4745256e-002 + -7.4100364e-002 + -7.3362026e-002 + -7.2568258e-002 + -7.1700267e-002 + -7.0762871e-002 + -6.9763024e-002 + -6.8704383e-002 + -6.7607599e-002 + -6.6436751e-002 + -6.5224711e-002 + -6.3971590e-002 + -6.2685781e-002 + -6.1345517e-002 + -5.9983748e-002 + -5.8591568e-002 + -5.7161645e-002 + -5.5717365e-002 + -5.4245277e-002 + -5.2763075e-002 + -5.1255616e-002 + -4.9738576e-002 + -4.8216572e-002 + -4.6684303e-002 + -4.5148841e-002 + -4.3609754e-002 + -4.2064909e-002 + -4.0534917e-002 + -3.9005368e-002 + -3.7481285e-002 + -3.5969756e-002 + -3.4462095e-002 + -3.2975408e-002 + -3.1501761e-002 + -3.0050266e-002 + -2.8607217e-002 + -2.7185943e-002 + -2.5787585e-002 + -2.4416099e-002 + -2.3068017e-002 + -2.1746755e-002 + -2.0453179e-002 + -1.9187243e-002 + -1.7943338e-002 + -1.6732471e-002 + -1.5540555e-002 + -1.4390467e-002 + -1.3271822e-002 + -1.2185000e-002 + -1.1131555e-002 + -1.0115022e-002 + -9.1325330e-003 + -8.1798233e-003 + -7.2615817e-003 + -6.3792293e-003 + -5.5337211e-003 + -4.7222596e-003 + -3.9401124e-003 + -3.1933778e-003 + -2.4826724e-003 + -1.8039473e-003 + -1.1568136e-003 + -5.4642809e-004 + 2.7604519e-005 + 5.8322642e-004 + 1.0902329e-003 + 1.5784683e-003 + 2.0274176e-003 + 2.4508540e-003 + 2.8446758e-003 + 3.2091886e-003 + 3.5401247e-003 + 3.8456408e-003 + 4.1251642e-003 + 4.3801862e-003 + 4.6039530e-003 + 4.8109469e-003 + 4.9839688e-003 + 5.1382275e-003 + 5.2715759e-003 + 5.3838976e-003 + 5.4753783e-003 + 5.5404364e-003 + 5.5917129e-003 + 5.6266114e-003 + 5.6389200e-003 + 5.6455197e-003 + 5.6220643e-003 + 5.5938023e-003 + 5.5475715e-003 + 5.4876040e-003 + 5.4196776e-003 + 5.3471681e-003 + 5.2461166e-003 + 5.1407354e-003 + 5.0393023e-003 + 4.9137604e-003 + 4.7932561e-003 + 4.6606461e-003 + 4.5209853e-003 + 4.3730720e-003 + 4.2264269e-003 + 4.0819753e-003 + 3.9207432e-003 + 3.7603923e-003 + 3.6008268e-003 + 3.4418874e-003 + 3.2739613e-003 + 3.1125421e-003 + 2.9469448e-003 + 2.7870464e-003 + 2.6201759e-003 + 2.4625617e-003 + 2.3017255e-003 + 2.1461584e-003 + 1.9841141e-003 + 1.8348265e-003 + 1.6868083e-003 + 1.5443220e-003 + 1.3902495e-003 + 1.2577885e-003 + 1.1250155e-003 + 9.8859883e-004 + 8.6084433e-004 + 7.4580259e-004 + 6.2393761e-004 + 5.1073885e-004 + 4.0265402e-004 + 2.9495311e-004 + 2.0430171e-004 + 1.0943831e-004 + 1.3494974e-005 + -6.1733441e-005 + -1.4463809e-004 + -2.0983373e-004 + -2.8969812e-004 + -3.5011759e-004 + -4.0951215e-004 + -4.6063255e-004 + -5.1455722e-004 + -5.5645764e-004 + -5.9461189e-004 + -6.3415949e-004 + -6.6504151e-004 + -6.9179375e-004 + -7.2153920e-004 + -7.3193572e-004 + -7.5300014e-004 + -7.6307936e-004 + -7.7579773e-004 + -7.8014496e-004 + -7.8036647e-004 + -7.7798695e-004 + -7.8343323e-004 + -7.7248486e-004 + -7.6813719e-004 + -7.4905981e-004 + -7.4409419e-004 + -7.2550431e-004 + -7.1577365e-004 + -6.9416146e-004 + -6.7776908e-004 + -6.5403334e-004 + -6.3124935e-004 + -6.1327474e-004 + -5.8709305e-004 + -5.6778026e-004 + -5.4665656e-004 + -5.2265643e-004 + -5.0407143e-004 + -4.8937912e-004 + -4.8752280e-004 + -4.9475181e-004 + -5.6176926e-004 + -5.5252865e-004 diff --git a/TTS/vocoder/layers/upsample.py b/TTS/vocoder/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..e169db00b2749493e1cec07ee51c93178dada118 --- /dev/null +++ b/TTS/vocoder/layers/upsample.py @@ -0,0 +1,102 @@ +import torch +from torch.nn import functional as F + + +class Stretch2d(torch.nn.Module): + def __init__(self, x_scale, y_scale, mode="nearest"): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """ + x (Tensor): Input tensor (B, C, F, T). + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + """ + return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + + +class UpsampleNetwork(torch.nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + super().__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_factors: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + self.up_layers += [nonlinear] + + def forward(self, c): + """ + c : (B, C, T_in). + Tensor: (B, C, T_upsample) + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvUpsample(torch.nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False, + ): + super().__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) + self.upsample = UpsampleNetwork( + upsample_factors=upsample_factors, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """ + c : (B, C, T_in). + Tensor: (B, C, T_upsampled), + """ + c_ = self.conv_in(c) + c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..9f1512c6d4b836c02b8b04136d69a32833d9f8c3 --- /dev/null +++ b/TTS/vocoder/layers/wavegrad.py @@ -0,0 +1,166 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.orthogonal_(self.weight) + nn.init.zeros_(self.bias) + + +class PositionalEncoding(nn.Module): + """Positional encoding with noise level conditioning""" + + def __init__(self, n_channels, max_len=10000): + super().__init__() + self.n_channels = n_channels + self.max_len = max_len + self.C = 5000 + self.pe = torch.zeros(0, 0) + + def forward(self, x, noise_level): + if x.shape[2] > self.pe.shape[1]: + self.init_pe_matrix(x.shape[1], x.shape[2], x) + return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C + + def init_pe_matrix(self, n_channels, max_len, x): + pe = torch.zeros(max_len, n_channels) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels) + + pe[:, 0::2] = torch.sin(position / div_term) + pe[:, 1::2] = torch.cos(position / div_term) + self.pe = pe.transpose(0, 1).to(x) + + +class FiLM(nn.Module): + def __init__(self, input_size, output_size): + super().__init__() + self.encoding = PositionalEncoding(input_size) + self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1) + self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1) + + nn.init.xavier_uniform_(self.input_conv.weight) + nn.init.xavier_uniform_(self.output_conv.weight) + nn.init.zeros_(self.input_conv.bias) + nn.init.zeros_(self.output_conv.bias) + + def forward(self, x, noise_scale): + o = self.input_conv(x) + o = F.leaky_relu(o, 0.2) + o = self.encoding(o, noise_scale) + shift, scale = torch.chunk(self.output_conv(o), 2, dim=1) + return shift, scale + + def remove_weight_norm(self): + remove_parametrizations(self.input_conv, "weight") + remove_parametrizations(self.output_conv, "weight") + + def apply_weight_norm(self): + self.input_conv = weight_norm(self.input_conv) + self.output_conv = weight_norm(self.output_conv) + + +@torch.jit.script +def shif_and_scale(x, scale, shift): + o = shift + scale * x + return o + + +class UBlock(nn.Module): + def __init__(self, input_size, hidden_size, factor, dilation): + super().__init__() + assert isinstance(dilation, (list, tuple)) + assert len(dilation) == 4 + + self.factor = factor + self.res_block = Conv1d(input_size, hidden_size, 1) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]), + ] + ) + self.out_block = nn.ModuleList( + [ + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]), + ] + ) + + def forward(self, x, shift, scale): + x_inter = F.interpolate(x, size=x.shape[-1] * self.factor) + res = self.res_block(x_inter) + o = F.leaky_relu(x_inter, 0.2) + o = F.interpolate(o, size=x.shape[-1] * self.factor) + o = self.main_block[0](o) + o = shif_and_scale(o, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.main_block[1](o) + res2 = res + o + o = shif_and_scale(res2, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.out_block[0](o) + o = shif_and_scale(o, scale, shift) + o = F.leaky_relu(o, 0.2) + o = self.out_block[1](o) + o = o + res2 + return o + + def remove_weight_norm(self): + remove_parametrizations(self.res_block, "weight") + for _, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + remove_parametrizations(layer, "weight") + for _, layer in enumerate(self.out_block): + if len(layer.state_dict()) != 0: + remove_parametrizations(layer, "weight") + + def apply_weight_norm(self): + self.res_block = weight_norm(self.res_block) + for idx, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + self.main_block[idx] = weight_norm(layer) + for idx, layer in enumerate(self.out_block): + if len(layer.state_dict()) != 0: + self.out_block[idx] = weight_norm(layer) + + +class DBlock(nn.Module): + def __init__(self, input_size, hidden_size, factor): + super().__init__() + self.factor = factor + self.res_block = Conv1d(input_size, hidden_size, 1) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), + Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), + Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), + ] + ) + + def forward(self, x): + size = x.shape[-1] // self.factor + res = self.res_block(x) + res = F.interpolate(res, size=size) + o = F.interpolate(x, size=size) + for layer in self.main_block: + o = F.leaky_relu(o, 0.2) + o = layer(o) + return o + res + + def remove_weight_norm(self): + remove_parametrizations(self.res_block, "weight") + for _, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + remove_parametrizations(layer, "weight") + + def apply_weight_norm(self): + self.res_block = weight_norm(self.res_block) + for idx, layer in enumerate(self.main_block): + if len(layer.state_dict()) != 0: + self.main_block[idx] = weight_norm(layer) diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a1716f16d2677297e6f126dbbd0b7401f6a9996 --- /dev/null +++ b/TTS/vocoder/models/__init__.py @@ -0,0 +1,157 @@ +import importlib +import logging +import re + +from coqpit import Coqpit + +logger = logging.getLogger(__name__) + + +def to_camel(text): + text = text.capitalize() + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + + +def setup_model(config: Coqpit): + """Load models directly from configuration.""" + if "discriminator_model" in config and "generator_model" in config: + MyModel = importlib.import_module("TTS.vocoder.models.gan") + MyModel = getattr(MyModel, "GAN") + else: + MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower()) + if config.model.lower() == "wavernn": + MyModel = getattr(MyModel, "Wavernn") + elif config.model.lower() == "gan": + MyModel = getattr(MyModel, "GAN") + elif config.model.lower() == "wavegrad": + MyModel = getattr(MyModel, "Wavegrad") + else: + try: + MyModel = getattr(MyModel, to_camel(config.model)) + except ModuleNotFoundError as e: + raise ValueError(f"Model {config.model} not exist!") from e + logger.info("Vocoder model: %s", config.model) + return MyModel.init_from_config(config) + + +def setup_generator(c): + """TODO: use config object as arguments""" + logger.info("Generator model: %s", c.generator_model) + MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.generator_model)) + # this is to preserve the Wavernn class name (instead of Wavernn) + if c.generator_model.lower() in "hifigan_generator": + model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params) + elif c.generator_model.lower() in "melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model in "melgan_fb_generator": + raise ValueError("melgan_fb_generator is now fullband_melgan_generator") + elif c.generator_model.lower() in "multiband_melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "fullband_melgan_generator": + model = MyModel( + in_channels=c.audio["num_mels"], + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=c.generator_model_params["upsample_factors"], + res_kernel=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "parallel_wavegan_generator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=c.generator_model_params["num_res_blocks"], + stacks=c.generator_model_params["stacks"], + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=c.audio["num_mels"], + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=c.generator_model_params["upsample_factors"], + ) + elif c.generator_model.lower() in "univnet_generator": + model = MyModel(**c.generator_model_params) + else: + raise NotImplementedError(f"Model {c.generator_model} not implemented!") + return model + + +def setup_discriminator(c): + """TODO: use config objekt as arguments""" + logger.info("Discriminator model: %s", c.discriminator_model) + if "parallel_wavegan" in c.discriminator_model: + MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") + else: + MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower()) + MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) + if c.discriminator_model in "hifigan_discriminator": + model = MyModel() + if c.discriminator_model in "random_window_discriminator": + model = MyModel( + cond_channels=c.audio["num_mels"], + hop_length=c.audio["hop_length"], + uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"], + cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"], + cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"], + window_sizes=c.discriminator_model_params["window_sizes"], + ) + if c.discriminator_model in "melgan_multiscale_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=c.discriminator_model_params["base_channels"], + max_channels=c.discriminator_model_params["max_channels"], + downsample_factors=c.discriminator_model_params["downsample_factors"], + ) + if c.discriminator_model == "residual_parallel_wavegan_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params["num_layers"], + stacks=c.discriminator_model_params["stacks"], + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ) + if c.discriminator_model == "parallel_wavegan_discriminator": + model = MyModel( + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=c.discriminator_model_params["num_layers"], + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ) + if c.discriminator_model == "univnet_discriminator": + model = MyModel() + return model diff --git a/TTS/vocoder/models/base_vocoder.py b/TTS/vocoder/models/base_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcbe7ba1cb933c2bd3e8925c32d781aeaf79add --- /dev/null +++ b/TTS/vocoder/models/base_vocoder.py @@ -0,0 +1,55 @@ +from coqpit import Coqpit + +from TTS.model import BaseTrainerModel + +# pylint: skip-file + + +class BaseVocoder(BaseTrainerModel): + """Base `vocoder` class. Every new `vocoder` model must inherit this. + + It defines `vocoder` specific functions on top of `Model`. + + Notes on input/output tensor shapes: + Any input or output tensor of the model must be shaped as + + - 3D tensors `batch x time x channels` + - 2D tensors `batch x channels` + - 1D tensors `batch x 1` + """ + + MODEL_TYPE = "vocoder" + + def __init__(self, config): + super().__init__() + self._set_model_args(config) + + def _set_model_args(self, config: Coqpit): + """Setup model args based on the config type. + + If the config is for training with a name like "*Config", then the model args are embeded in the + config.model_args + + If the config is for the model with a name like "*Args", then we assign the directly. + """ + # don't use isintance not to import recursively + if "Config" in config.__class__.__name__: + if "characters" in config: + _, self.config, num_chars = self.get_characters(config) + self.config.num_chars = num_chars + if hasattr(self.config, "model_args"): + config.model_args.num_chars = num_chars + if "model_args" in config: + self.args = self.config.model_args + # This is for backward compatibility + if "model_params" in config: + self.args = self.config.model_params + else: + self.config = config + if "model_args" in config: + self.args = self.config.model_args + # This is for backward compatibility + if "model_params" in config: + self.args = self.config.model_params + else: + raise ValueError("config must be either a *Config or *Args") diff --git a/TTS/vocoder/models/fullband_melgan_generator.py b/TTS/vocoder/models/fullband_melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..ee25559af0d468aac535841bdfdd33b366250f43 --- /dev/null +++ b/TTS/vocoder/models/fullband_melgan_generator.py @@ -0,0 +1,33 @@ +import torch + +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class FullbandMelganGenerator(MelganGenerator): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=4, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.layers(cond_features) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py new file mode 100644 index 0000000000000000000000000000000000000000..8792950a564fcf05d9cb64dbff5297bdf9e8fede --- /dev/null +++ b/TTS/vocoder/models/gan.py @@ -0,0 +1,373 @@ +from inspect import signature +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.io import load_fsspec +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.utils.audio import AudioProcessor +from TTS.vocoder.datasets.gan_dataset import GANDataset +from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss +from TTS.vocoder.models import setup_discriminator, setup_generator +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +class GAN(BaseVocoder): + def __init__(self, config: Coqpit, ap: AudioProcessor = None): + """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. + It also helps mixing and matching different generator and disciminator networks easily. + + To implement a new GAN models, you just need to define the generator and the discriminator networks, the rest + is handled by the `GAN` class. + + Args: + config (Coqpit): Model configuration. + ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None. + + Examples: + Initializing the GAN model with HifiGAN generator and discriminator. + >>> from TTS.vocoder.configs import HifiganConfig + >>> config = HifiganConfig() + >>> model = GAN(config) + """ + super().__init__(config) + self.config = config + self.model_g = setup_generator(config) + self.model_d = setup_discriminator(config) + self.train_disc = False # if False, train only the generator. + self.y_hat_g = None # the last generator prediction to be passed onto the discriminator + self.ap = ap + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the generator's forward pass. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: output of the GAN generator network. + """ + return self.model_g.forward(x) + + def inference(self, x: torch.Tensor) -> torch.Tensor: + """Run the generator's inference pass. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: output of the GAN generator network. + """ + return self.model_g.inference(x) + + def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for + network on the current pass. + + Args: + batch (Dict): Batch of samples returned by the dataloader. + criterion (Dict): Criterion used to compute the losses. + optimizer_idx (int): ID of the optimizer in use on the current pass. + + Raises: + ValueError: `optimizer_idx` is an unexpected value. + + Returns: + Tuple[Dict, Dict]: model outputs and the computed loss values. + """ + outputs = {} + loss_dict = {} + + x = batch["input"] + y = batch["waveform"] + + if optimizer_idx not in [0, 1]: + raise ValueError(" [!] Unexpected `optimizer_idx`.") + + if optimizer_idx == 0: + # DISCRIMINATOR optimization + + # generator pass + y_hat = self.model_g(x)[:, :, : y.size(2)] + + # cache for generator loss + # pylint: disable=W0201 + self.y_hat_g = y_hat + self.y_hat_sub = None + self.y_sub_g = None + + # PQMF formatting + if y_hat.shape[1] > 1: + self.y_hat_sub = y_hat + y_hat = self.model_g.pqmf_synthesis(y_hat) + self.y_hat_g = y_hat # save for generator loss + self.y_sub_g = self.model_g.pqmf_analysis(y) + + scores_fake, feats_fake, feats_real = None, None, None + + if self.train_disc: + # use different samples for G and D trainings + if self.config.diff_samples_for_G_and_D: + x_d = batch["input_disc"] + y_d = batch["waveform_disc"] + # use a different sample than generator + with torch.no_grad(): + y_hat = self.model_g(x_d) + + # PQMF formatting + if y_hat.shape[1] > 1: + y_hat = self.model_g.pqmf_synthesis(y_hat) + else: + # use the same samples as generator + x_d = x.clone() + y_d = y.clone() + y_hat = self.y_hat_g + + # run D with or without cond. features + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(y_hat.detach().clone(), x_d) + D_out_real = self.model_d(y_d, x_d) + else: + D_out_fake = self.model_d(y_hat.detach()) + D_out_real = self.model_d(y_d) + + # format D outputs + if isinstance(D_out_fake, tuple): + # self.model_d returns scores and features + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + scores_real, feats_real = None, None + else: + scores_real, feats_real = D_out_real + else: + # model D returns only scores + scores_fake = D_out_fake + scores_real = D_out_real + + # compute losses + loss_dict = criterion[optimizer_idx](scores_fake, scores_real) + outputs = {"model_outputs": y_hat} + + if optimizer_idx == 1: + # GENERATOR loss + scores_fake, feats_fake, feats_real = None, None, None + if self.train_disc: + if len(signature(self.model_d.forward).parameters) == 2: + D_out_fake = self.model_d(self.y_hat_g, x) + else: + D_out_fake = self.model_d(self.y_hat_g) + D_out_real = None + + if self.config.use_feat_match_loss: + with torch.no_grad(): + D_out_real = self.model_d(y) + + # format D outputs + if isinstance(D_out_fake, tuple): + scores_fake, feats_fake = D_out_fake + if D_out_real is None: + feats_real = None + else: + _, feats_real = D_out_real + else: + scores_fake = D_out_fake + feats_fake, feats_real = None, None + + # compute losses + loss_dict = criterion[optimizer_idx]( + self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g + ) + outputs = {"model_outputs": self.y_hat_g} + return outputs, loss_dict + + def _log(self, name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]: + """Logging shared by the training and evaluation. + + Args: + name (str): Name of the run. `train` or `eval`, + ap (AudioProcessor): Audio processor used in training. + batch (Dict): Batch used in the last train/eval step. + outputs (Dict): Model outputs from the last train/eval step. + + Returns: + Tuple[Dict, Dict]: log figures and audio samples. + """ + y_hat = outputs[0]["model_outputs"] if self.train_disc else outputs[1]["model_outputs"] + y = batch["waveform"] + figures = plot_results(y_hat, y, ap, name) + sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() + audios = {f"{name}/audio": sample_voice} + return figures, audios + + def train_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + """Call `_log()` for training.""" + figures, audios = self._log("eval", self.ap, batch, outputs) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: + """Call `train_step()` with `no_grad()`""" + self.train_disc = True # Avoid a bug in the Training with the missing discriminator loss + return self.train_step(batch, criterion, optimizer_idx) + + def eval_log( + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + """Call `_log()` for evaluation.""" + figures, audios = self._log("eval", self.ap, batch, outputs) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + def load_checkpoint( + self, + config: Coqpit, + checkpoint_path: str, + eval: bool = False, # pylint: disable=unused-argument, redefined-builtin + cache: bool = False, + ) -> None: + """Load a GAN checkpoint and initialize model parameters. + + Args: + config (Coqpit): Model config. + checkpoint_path (str): Checkpoint file path. + eval (bool, optional): If true, load the model for inference. If falseDefaults to False. + """ + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + # band-aid for older than v0.0.15 GAN models + if "model_disc" in state: + self.model_g.load_checkpoint(config, checkpoint_path, eval) + else: + self.load_state_dict(state["model"]) + if eval: + self.model_d = None + if hasattr(self.model_g, "remove_weight_norm"): + self.model_g.remove_weight_norm() + + def on_train_step_start(self, trainer) -> None: + """Enable the discriminator training based on `steps_to_start_discriminator` + + Args: + trainer (Trainer): Trainer object. + """ + self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator + + def get_optimizer(self) -> List: + """Initiate and return the GAN optimizers based on the config parameters. + + It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator. + + Returns: + List: optimizers. + """ + optimizer1 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g + ) + optimizer2 = get_optimizer( + self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.model_d + ) + return [optimizer2, optimizer1] + + def get_lr(self) -> List: + """Set the initial learning rates for each optimizer. + + Returns: + List: learning rates for each optimizer. + """ + return [self.config.lr_disc, self.config.lr_gen] + + def get_scheduler(self, optimizer) -> List: + """Set the schedulers for each optimizer. + + Args: + optimizer (List[`torch.optim.Optimizer`]): List of optimizers. + + Returns: + List: Schedulers, one for each optimizer. + """ + scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler2, scheduler1] + + @staticmethod + def format_batch(batch: List) -> Dict: + """Format the batch for training. + + Args: + batch (List): Batch out of the dataloader. + + Returns: + Dict: formatted model inputs. + """ + if isinstance(batch[0], list): + x_G, y_G = batch[0] + x_D, y_D = batch[1] + return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D} + x, y = batch + return {"input": x, "waveform": y} + + def get_data_loader( # pylint: disable=no-self-use, unused-argument + self, + config: Coqpit, + assets: Dict, + is_eval: True, + samples: List, + verbose: bool, + num_gpus: int, + rank: int = None, # pylint: disable=unused-argument + ): + """Initiate and return the GAN dataloader. + + Args: + config (Coqpit): Model config. + ap (AudioProcessor): Audio processor. + is_eval (True): Set the dataloader for evaluation if true. + samples (List): Data samples. + verbose (bool): Log information if true. + num_gpus (int): Number of GPUs in use. + rank (int): Rank of the current GPU. Defaults to None. + + Returns: + DataLoader: Torch dataloader. + """ + dataset = GANDataset( + ap=self.ap, + items=samples, + seq_len=config.seq_len, + hop_len=self.ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False, + is_training=not is_eval, + return_segments=not is_eval, + use_noise_augment=config.use_noise_augment, + use_cache=config.use_cache, + ) + dataset.shuffle_mapping() + sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + drop_last=False, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=False, + ) + return loader + + def get_criterion(self): + """Return criterions for the optimizers""" + return [DiscriminatorLoss(self.config), GeneratorLoss(self.config)] + + @staticmethod + def init_from_config(config: Coqpit) -> "GAN": + ap = AudioProcessor.init_from_config(config) + return GAN(config, ap=ap) diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..1cbc6ab35784055fc68b2a375954226d097ada5d --- /dev/null +++ b/TTS/vocoder/models/hifigan_discriminator.py @@ -0,0 +1,218 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.vocoder.models.hifigan_generator import get_padding + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + """HiFiGAN Periodic Discriminator + + Takes every Pth value from the input waveform and applied a stack of convoluations. + + Note: + if `period` is 2 + `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat` + + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ] + ) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + feat = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + + return x, feat + + +class MultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Period Discriminator (MPD) + Wrapper for the `PeriodDiscriminator` to apply it in different periods. + Periods are suggested to be prime numbers to reduce the overlap between each discriminator. + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorP(2, use_spectral_norm=use_spectral_norm), + DiscriminatorP(3, use_spectral_norm=use_spectral_norm), + DiscriminatorP(5, use_spectral_norm=use_spectral_norm), + DiscriminatorP(7, use_spectral_norm=use_spectral_norm), + DiscriminatorP(11, use_spectral_norm=use_spectral_norm), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [List[Tensor]]: list of scores from each discriminator. + [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + scores = [] + feats = [] + for _, d in enumerate(self.discriminators): + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. + It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + + """ + + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), + norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class MultiScaleDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Scale Discriminator. + It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper. + """ + + def __init__(self): + super().__init__() + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + self.meanpools = nn.ModuleList([nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores = [] + feats = [] + for i, d in enumerate(self.discriminators): + if i != 0: + x = self.meanpools[i - 1](x) + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class HifiganDiscriminator(nn.Module): + """HiFiGAN discriminator wrapping MPD and MSD.""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..afdd59a859b819d5545243772f817cb84f79e9f0 --- /dev/null +++ b/TTS/vocoder/models/hifigan_generator.py @@ -0,0 +1,304 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py +import logging + +import torch +from torch import nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn import functional as F +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec + +logger = logging.getLogger(__name__) + +LRELU_SLOPE = 0.1 + + +def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o + |--------------------------------------------------------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + weight_norm( + Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)) + ), + ] + ) + + def forward(self, x): + """ + Args: + x (Tensor): input tensor. + Returns: + Tensor: output tensor. + Shapes: + x: [B, C, T] + """ + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_parametrizations(l, "weight") + for l in self.convs2: + remove_parametrizations(l, "weight") + + +class ResBlock2(torch.nn.Module): + """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. + + Network:: + + x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o + |---------------------------------------------------| + + + Args: + channels (int): number of hidden channels for the convolutional layers. + kernel_size (int): size of the convolution filter in each layer. + dilations (list): list of dilation value for each conv layer in a block. + """ + + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_parametrizations(l, "weight") + + +class HifiganGenerator(torch.nn.Module): + def __init__( + self, + in_channels, + out_channels, + resblock_type, + resblock_dilation_sizes, + resblock_kernel_sizes, + upsample_kernel_sizes, + upsample_initial_channel, + upsample_factors, + inference_padding=5, + cond_channels=0, + conv_pre_weight_norm=True, + conv_post_weight_norm=True, + conv_post_bias=True, + ): + r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) + + Network: + x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o + .. -> zI ---| + resblockN_kNx1 -> zN ---' + + Args: + in_channels (int): number of input tensor channels. + out_channels (int): number of output tensor channels. + resblock_type (str): type of the `ResBlock`. '1' or '2'. + resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. + resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. + upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. + upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 + for each consecutive upsampling layer. + upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. + inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. + """ + super().__init__() + self.inference_padding = inference_padding + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_factors) + # initial upsampling layers + self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if resblock_type == "1" else ResBlock2 + # upsampling layers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + # MRF blocks + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + # post convolution layer + self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) + if cond_channels > 0: + self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) + + if not conv_pre_weight_norm: + remove_parametrizations(self.conv_pre, "weight") + + if not conv_post_weight_norm: + remove_parametrizations(self.conv_post, "weight") + + def forward(self, x, g=None): + """ + Args: + x (Tensor): feature input tensor. + g (Tensor): global conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + o = self.conv_pre(x) + if hasattr(self, "cond_layer"): + o = o + self.cond_layer(g) + for i in range(self.num_upsamples): + o = F.leaky_relu(o, LRELU_SLOPE) + o = self.ups[i](o) + z_sum = None + for j in range(self.num_kernels): + if z_sum is None: + z_sum = self.resblocks[i * self.num_kernels + j](o) + else: + z_sum += self.resblocks[i * self.num_kernels + j](o) + o = z_sum / self.num_kernels + o = F.leaky_relu(o) + o = self.conv_post(o) + o = torch.tanh(o) + return o + + @torch.no_grad() + def inference(self, c): + """ + Args: + x (Tensor): conditioning input tensor. + + Returns: + Tensor: output waveform. + + Shapes: + x: [B, C, T] + Tensor: [B, 1, T] + """ + c = c.to(self.conv_pre.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + logger.info("Removing weight norm...") + for l in self.ups: + remove_parametrizations(l, "weight") + for l in self.resblocks: + l.remove_weight_norm() + remove_parametrizations(self.conv_pre, "weight") + remove_parametrizations(self.conv_post, "weight") + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/TTS/vocoder/models/melgan_discriminator.py b/TTS/vocoder/models/melgan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..e41467da3c1d2fdceb9e86cecdc848b62ee731ff --- /dev/null +++ b/TTS/vocoder/models/melgan_discriminator.py @@ -0,0 +1,84 @@ +import numpy as np +from torch import nn +from torch.nn.utils.parametrizations import weight_norm + + +class MelganDiscriminator(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4, 4), + groups_denominator=4, + ): + super().__init__() + self.layers = nn.ModuleList() + + layer_kernel_size = np.prod(kernel_sizes) + layer_padding = (layer_kernel_size - 1) // 2 + + # initial layer + self.layers += [ + nn.Sequential( + nn.ReflectionPad1d(layer_padding), + weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)), + nn.LeakyReLU(0.2, inplace=True), + ) + ] + + # downsampling layers + layer_in_channels = base_channels + for downsample_factor in downsample_factors: + layer_out_channels = min(layer_in_channels * downsample_factor, max_channels) + layer_kernel_size = downsample_factor * 10 + 1 + layer_padding = (layer_kernel_size - 1) // 2 + layer_groups = layer_in_channels // groups_denominator + self.layers += [ + nn.Sequential( + weight_norm( + nn.Conv1d( + layer_in_channels, + layer_out_channels, + kernel_size=layer_kernel_size, + stride=downsample_factor, + padding=layer_padding, + groups=layer_groups, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ) + ] + layer_in_channels = layer_out_channels + + # last 2 layers + layer_padding1 = (kernel_sizes[0] - 1) // 2 + layer_padding2 = (kernel_sizes[1] - 1) // 2 + self.layers += [ + nn.Sequential( + weight_norm( + nn.Conv1d( + layer_out_channels, + layer_out_channels, + kernel_size=kernel_sizes[0], + stride=1, + padding=layer_padding1, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ), + weight_norm( + nn.Conv1d( + layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2 + ) + ), + ] + + def forward(self, x): + feats = [] + for layer in self.layers: + x = layer(x) + feats.append(x) + return x, feats diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..03c971afa4fff9836436aef290716063f88adcaa --- /dev/null +++ b/TTS/vocoder/models/melgan_generator.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from trainer.io import load_fsspec + +from TTS.vocoder.layers.melgan import ResidualStack + + +class MelganGenerator(nn.Module): + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__() + + # assert model parameters + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + + # setup additional model parameters + base_padding = (proj_kernel - 1) // 2 + act_slope = 0.2 + self.inference_padding = 2 + + # initial layer + layers = [] + layers += [ + nn.ReflectionPad1d(base_padding), + weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)), + ] + + # upsampling layers and residual stacks + for idx, upsample_factor in enumerate(upsample_factors): + layer_in_channels = base_channels // (2**idx) + layer_out_channels = base_channels // (2 ** (idx + 1)) + layer_filter_size = upsample_factor * 2 + layer_stride = upsample_factor + layer_output_padding = upsample_factor % 2 + layer_padding = upsample_factor // 2 + layer_output_padding + layers += [ + nn.LeakyReLU(act_slope), + weight_norm( + nn.ConvTranspose1d( + layer_in_channels, + layer_out_channels, + layer_filter_size, + stride=layer_stride, + padding=layer_padding, + output_padding=layer_output_padding, + bias=True, + ) + ), + ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel), + ] + + layers += [nn.LeakyReLU(act_slope)] + + # final layer + layers += [ + nn.ReflectionPad1d(base_padding), + weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)), + nn.Tanh(), + ] + self.layers = nn.Sequential(*layers) + + def forward(self, c): + return self.layers(c) + + def inference(self, c): + c = c.to(self.layers[1].weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.layers(c) + + def remove_weight_norm(self): + for _, layer in enumerate(self.layers): + if len(layer.state_dict()) != 0: + try: + nn.utils.parametrize.remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..b4909f37c0c91c6fee8bb0baab98a8662039dea1 --- /dev/null +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -0,0 +1,50 @@ +from torch import nn + +from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator + + +class MelganMultiscaleDiscriminator(nn.Module): + def __init__( + self, + in_channels=1, + out_channels=1, + num_scales=3, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4), + pooling_kernel_size=4, + pooling_stride=2, + pooling_padding=2, + groups_denominator=4, + ): + super().__init__() + + self.discriminators = nn.ModuleList( + [ + MelganDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + base_channels=base_channels, + max_channels=max_channels, + downsample_factors=downsample_factors, + groups_denominator=groups_denominator, + ) + for _ in range(num_scales) + ] + ) + + self.pooling = nn.AvgPool1d( + kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False + ) + + def forward(self, x): + scores = [] + feats = [] + for disc in self.discriminators: + score, feat = disc(x) + scores.append(score) + feats.append(feat) + x = self.pooling(x) + return scores, feats diff --git a/TTS/vocoder/models/multiband_melgan_generator.py b/TTS/vocoder/models/multiband_melgan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..25d6590659cf5863176eb6609c7609b0e1b28d12 --- /dev/null +++ b/TTS/vocoder/models/multiband_melgan_generator.py @@ -0,0 +1,41 @@ +import torch + +from TTS.vocoder.layers.pqmf import PQMF +from TTS.vocoder.models.melgan_generator import MelganGenerator + + +class MultibandMelganGenerator(MelganGenerator): + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) + self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) + + def pqmf_analysis(self, x): + return self.pqmf_layer.analysis(x) + + def pqmf_synthesis(self, x): + return self.pqmf_layer.synthesis(x) + + @torch.no_grad() + def inference(self, cond_features): + cond_features = cond_features.to(self.layers[1].weight.device) + cond_features = torch.nn.functional.pad( + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) + return self.pqmf_synthesis(self.layers(cond_features)) diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..211d45d91c8f9b13817587f2748258ccb2291105 --- /dev/null +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -0,0 +1,190 @@ +import logging +import math + +import torch +from torch import nn +from torch.nn.utils.parametrize import remove_parametrizations + +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock + +logger = logging.getLogger(__name__) + + +class ParallelWaveganDiscriminator(nn.Module): + """PWGAN discriminator as in https://arxiv.org/abs/1910.11480. + It classifies each audio window real/fake and returns a sequence + of predictions. + It is a stack of convolutional blocks with dilation. + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ): + super().__init__() + assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." + assert dilation_factor > 0, " [!] dilation factor must be > 0." + self.conv_layers = nn.ModuleList() + conv_in_channels = in_channels + for i in range(num_layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor**i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = [ + nn.Conv1d( + conv_in_channels, + conv_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + ] + self.conv_layers += conv_layer + padding = (kernel_size - 1) // 2 + last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + self.conv_layers += [last_conv_layer] + self.apply_weight_norm() + + def forward(self, x): + """ + x : (B, 1, T). + Returns: + Tensor: (B, 1, T) + """ + for f in self.conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + logger.info("Weight norm is removed from %s", m) + remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + +class ResidualParallelWaveganDiscriminator(nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + super().__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_layers = num_layers + self.stacks = stacks + self.kernel_size = kernel_size + self.res_factor = math.sqrt(1.0 / num_layers) + + # check the number of num_layers and stacks + assert num_layers % stacks == 0 + layers_per_stack = num_layers // stacks + + # define first convolution + self.first_conv = nn.Sequential( + nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + ) + + # define residual blocks + self.conv_layers = nn.ModuleList() + for layer in range(num_layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=-1, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=False, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = nn.ModuleList( + [ + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True), + ] + ) + + # apply weight norm + self.apply_weight_norm() + + def forward(self, x): + """ + x: (B, 1, T). + """ + x = self.first_conv(x) + + skips = 0 + for f in self.conv_layers: + x, h = f(x, None) + skips += h + skips *= self.res_factor + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + logger.info("Weight norm is removed from %s", m) + remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6a4d4ca6e75257f8f12a881065f0ada45d601f8d --- /dev/null +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -0,0 +1,167 @@ +import logging +import math + +import numpy as np +import torch +from torch.nn.utils.parametrize import remove_parametrizations +from trainer.io import load_fsspec + +from TTS.vocoder.layers.parallel_wavegan import ResidualBlock +from TTS.vocoder.layers.upsample import ConvUpsample + +logger = logging.getLogger(__name__) + + +class ParallelWaveganGenerator(torch.nn.Module): + """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. + It is similar to WaveNet with no causal convolution. + It is conditioned on an aux feature (spectrogram) to generate + an output waveform from an input noise. + """ + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=[4, 4, 4, 4], + inference_padding=2, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.num_res_blocks = num_res_blocks + self.stacks = stacks + self.kernel_size = kernel_size + self.upsample_factors = upsample_factors + self.upsample_scale = np.prod(upsample_factors) + self.inference_padding = inference_padding + self.use_weight_norm = use_weight_norm + + # check the number of layers and stacks + assert num_res_blocks % stacks == 0 + layers_per_stack = num_res_blocks // stacks + + # define first convolution + self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True) + + # define conv + upsampling network + self.upsample_net = ConvUpsample(upsample_factors=upsample_factors) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(num_res_blocks): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + res_channels=res_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + dilation=dilation, + dropout=dropout, + bias=bias, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True), + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True), + ] + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """ + c: (B, C ,T'). + o: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale]) + x = x.to(self.first_conv.bias.device) + + # perform upsampling + if c is not None and self.upsample_net is not None: + c = self.upsample_net(c) + assert ( + c.shape[-1] == x.shape[-1] + ), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}" + + # encode to hidden representation + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, c) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + + return x + + @torch.no_grad() + def inference(self, c): + c = c.to(self.first_conv.weight.device) + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") + return self.forward(c) + + def remove_weight_norm(self): + def _remove_weight_norm(m): + try: + logger.info("Weight norm is removed from %s", m) + remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + logger.info("Weight norm is applied to %s", m) + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + if self.use_weight_norm: + self.remove_weight_norm() diff --git a/TTS/vocoder/models/random_window_discriminator.py b/TTS/vocoder/models/random_window_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..79b68e9780ff1703b0f4955e37fe93c02d97ea09 --- /dev/null +++ b/TTS/vocoder/models/random_window_discriminator.py @@ -0,0 +1,203 @@ +import numpy as np +from torch import nn + + +class GBlock(nn.Module): + def __init__(self, in_channels, cond_channels, downsample_factor): + super().__init__() + + self.in_channels = in_channels + self.cond_channels = cond_channels + self.downsample_factor = downsample_factor + + self.start = nn.Sequential( + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + nn.ReLU(), + nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1), + ) + self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1) + self.end = nn.Sequential( + nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2) + ) + self.residual = nn.Sequential( + nn.Conv1d(in_channels, in_channels * 2, kernel_size=1), + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + ) + + def forward(self, inputs, conditions): + outputs = self.start(inputs) + self.lc_conv1d(conditions) + outputs = self.end(outputs) + residual_outputs = self.residual(inputs) + outputs = outputs + residual_outputs + + return outputs + + +class DBlock(nn.Module): + def __init__(self, in_channels, out_channels, downsample_factor): + super().__init__() + + self.in_channels = in_channels + self.downsample_factor = downsample_factor + self.out_channels = out_channels + + self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor) + self.layers = nn.Sequential( + nn.ReLU(), + nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), + nn.ReLU(), + nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2), + ) + self.residual = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=1), + ) + + def forward(self, inputs): + if self.downsample_factor > 1: + outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs)) + else: + outputs = self.layers(inputs) + self.residual(inputs) + return outputs + + +class ConditionalDiscriminator(nn.Module): + def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)): + super().__init__() + + assert len(downsample_factors) == len(out_channels) + 1 + + self.in_channels = in_channels + self.cond_channels = cond_channels + self.downsample_factors = downsample_factors + self.out_channels = out_channels + + self.pre_cond_layers = nn.ModuleList() + self.post_cond_layers = nn.ModuleList() + + # layers before condition features + self.pre_cond_layers += [DBlock(in_channels, 64, 1)] + in_channels = 64 + for i, channel in enumerate(out_channels): + self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i])) + in_channels = channel + + # condition block + self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1]) + + # layers after condition block + self.post_cond_layers += [ + DBlock(in_channels * 2, in_channels * 2, 1), + DBlock(in_channels * 2, in_channels * 2, 1), + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(in_channels * 2, 1, kernel_size=1), + ] + + def forward(self, inputs, conditions): + batch_size = inputs.size()[0] + outputs = inputs.view(batch_size, self.in_channels, -1) + for layer in self.pre_cond_layers: + outputs = layer(outputs) + outputs = self.cond_block(outputs, conditions) + for layer in self.post_cond_layers: + outputs = layer(outputs) + + return outputs + + +class UnconditionalDiscriminator(nn.Module): + def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)): + super().__init__() + + self.downsample_factors = downsample_factors + self.in_channels = in_channels + self.downsample_factors = downsample_factors + self.out_channels = out_channels + + self.layers = nn.ModuleList() + self.layers += [DBlock(self.in_channels, base_channels, 1)] + in_channels = base_channels + for i, factor in enumerate(downsample_factors): + self.layers.append(DBlock(in_channels, out_channels[i], factor)) + in_channels *= 2 + self.layers += [ + DBlock(in_channels, in_channels, 1), + DBlock(in_channels, in_channels, 1), + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(in_channels, 1, kernel_size=1), + ] + + def forward(self, inputs): + batch_size = inputs.size()[0] + outputs = inputs.view(batch_size, self.in_channels, -1) + for layer in self.layers: + outputs = layer(outputs) + return outputs + + +class RandomWindowDiscriminator(nn.Module): + """Random Window Discriminator as described in + http://arxiv.org/abs/1909.11646""" + + def __init__( + self, + cond_channels, + hop_length, + uncond_disc_donwsample_factors=(8, 4), + cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)), + cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)), + window_sizes=(512, 1024, 2048, 4096, 8192), + ): + super().__init__() + self.cond_channels = cond_channels + self.window_sizes = window_sizes + self.hop_length = hop_length + self.base_window_size = self.hop_length * 2 + self.ks = [ws // self.base_window_size for ws in window_sizes] + + # check arguments + assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes) + for ws in window_sizes: + assert ws % hop_length == 0 + + for idx, cf in enumerate(cond_disc_downsample_factors): + assert np.prod(cf) == hop_length // self.ks[idx] + + # define layers + self.unconditional_discriminators = nn.ModuleList([]) + for k in self.ks: + layer = UnconditionalDiscriminator( + in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors + ) + self.unconditional_discriminators.append(layer) + + self.conditional_discriminators = nn.ModuleList([]) + for idx, k in enumerate(self.ks): + layer = ConditionalDiscriminator( + in_channels=k, + cond_channels=cond_channels, + downsample_factors=cond_disc_downsample_factors[idx], + out_channels=cond_disc_out_channels[idx], + ) + self.conditional_discriminators.append(layer) + + def forward(self, x, c): + scores = [] + feats = [] + # unconditional pass + for window_size, layer in zip(self.window_sizes, self.unconditional_discriminators): + index = np.random.randint(x.shape[-1] - window_size) + + score = layer(x[:, :, index : index + window_size]) + scores.append(score) + + # conditional pass + for window_size, layer in zip(self.window_sizes, self.conditional_discriminators): + frame_size = window_size // self.hop_length + lc_index = np.random.randint(c.shape[-1] - frame_size) + sample_index = lc_index * self.hop_length + x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length] + c_sub = c[:, :, lc_index : lc_index + frame_size] + + score = layer(x_sub, c_sub) + scores.append(score) + return scores, feats diff --git a/TTS/vocoder/models/univnet_discriminator.py b/TTS/vocoder/models/univnet_discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..497d67ac76d0cb797db92bba3f9ad24ab5293a64 --- /dev/null +++ b/TTS/vocoder/models/univnet_discriminator.py @@ -0,0 +1,95 @@ +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.utils import spectral_norm +from torch.nn.utils.parametrizations import weight_norm + +from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator + +LRELU_SLOPE = 0.1 + + +class SpecDiscriminator(nn.Module): + """docstring for Discriminator.""" + + def __init__(self, fft_size=1024, hop_length=120, win_length=600, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.fft_size = fft_size + self.hop_length = hop_length + self.win_length = win_length + self.stft = TorchSTFT(fft_size, hop_length, win_length) + self.discriminators = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))), + norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), + ] + ) + + self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1)) + + def forward(self, y): + fmap = [] + with torch.no_grad(): + y = y.squeeze(1) + y = self.stft(y) + y = y.unsqueeze(1) + for _, d in enumerate(self.discriminators): + y = d(y) + y = F.leaky_relu(y, LRELU_SLOPE) + fmap.append(y) + + y = self.out(y) + fmap.append(y) + + return torch.flatten(y, 1, -1), fmap + + +class MultiResSpecDiscriminator(torch.nn.Module): + def __init__( # pylint: disable=dangerous-default-value + self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann_window" + ): + super().__init__() + self.discriminators = nn.ModuleList( + [ + SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window), + SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window), + SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window), + ] + ) + + def forward(self, x): + scores = [] + feats = [] + for d in self.discriminators: + score, feat = d(x) + scores.append(score) + feats.append(feat) + + return scores, feats + + +class UnivnetDiscriminator(nn.Module): + """Univnet discriminator wrapping MPD and MSD.""" + + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiResSpecDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.mpd(x) + scores_, feats_ = self.msd(x) + return scores + scores_, feats + feats_ diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..72e57a9c3900bc9ca55da3efa8ace01e53542891 --- /dev/null +++ b/TTS/vocoder/models/univnet_generator.py @@ -0,0 +1,160 @@ +import logging +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils import parametrize + +from TTS.vocoder.layers.lvc_block import LVCBlock + +logger = logging.getLogger(__name__) + +LRELU_SLOPE = 0.1 + + +class UnivnetGenerator(torch.nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + cond_channels: int, + upsample_factors: List[int], + lvc_layers_each_block: int, + lvc_kernel_size: int, + kpnet_hidden_channels: int, + kpnet_conv_size: int, + dropout: float, + use_weight_norm=True, + ): + """Univnet Generator network. + + Paper: https://arxiv.org/pdf/2106.07889.pdf + + Args: + in_channels (int): Number of input tensor channels. + out_channels (int): Number of channels of the output tensor. + hidden_channels (int): Number of hidden network channels. + cond_channels (int): Number of channels of the conditioning tensors. + upsample_factors (List[int]): List of uplsample factors for the upsampling layers. + lvc_layers_each_block (int): Number of LVC layers in each block. + lvc_kernel_size (int): Kernel size of the LVC layers. + kpnet_hidden_channels (int): Number of hidden channels in the key-point network. + kpnet_conv_size (int): Number of convolution channels in the key-point network. + dropout (float): Dropout rate. + use_weight_norm (bool, optional): Enable/disable weight norm. Defaults to True. + """ + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.cond_channels = cond_channels + self.upsample_scale = np.prod(upsample_factors) + self.lvc_block_nums = len(upsample_factors) + + # define first convolution + self.first_conv = torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True + ) + + # define residual blocks + self.lvc_blocks = torch.nn.ModuleList() + cond_hop_length = 1 + for n in range(self.lvc_block_nums): + cond_hop_length = cond_hop_length * upsample_factors[n] + lvcb = LVCBlock( + in_channels=hidden_channels, + cond_channels=cond_channels, + upsample_ratio=upsample_factors[n], + conv_layers=lvc_layers_each_block, + conv_kernel_size=lvc_kernel_size, + cond_hop_length=cond_hop_length, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=dropout, + ) + self.lvc_blocks += [lvcb] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + hidden_channels, out_channels, kernel_size=7, padding=(7 - 1) // 2, dilation=1, bias=True + ), + ] + ) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, c): + """Calculate forward propagation. + Args: + c (Tensor): Local conditioning auxiliary features (B, C ,T'). + Returns: + Tensor: Output tensor (B, out_channels, T) + """ + # random noise + x = torch.randn([c.shape[0], self.in_channels, c.shape[2]]) + x = x.to(self.first_conv.bias.device) + x = self.first_conv(x) + + for n in range(self.lvc_block_nums): + x = self.lvc_blocks[n](x, c) + + # apply final layers + for f in self.last_conv_layers: + x = F.leaky_relu(x, LRELU_SLOPE) + x = f(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + logger.info("Weight norm is removed from %s", m) + parametrize.remove_parametrizations(m, "weight") + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): + torch.nn.utils.parametrizations.weight_norm(m) + logger.info("Weight norm is applied to %s", m) + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + """Return receptive field size.""" + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + @torch.no_grad() + def inference(self, c): + """Perform inference. + Args: + c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`. + Returns: + Tensor: Output tensor (T, out_channels) + """ + x = torch.randn([c.shape[0], self.in_channels, c.shape[2]]) + x = x.to(self.first_conv.bias.device) + + c = c.to(next(self.parameters())) + return self.forward(c) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py new file mode 100644 index 0000000000000000000000000000000000000000..c49abd22017c309209b1dcabb3fb6a765cb171f2 --- /dev/null +++ b/TTS/vocoder/models/wavegrad.py @@ -0,0 +1,344 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn.utils.parametrizations import weight_norm +from torch.nn.utils.parametrize import remove_parametrizations +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.io import load_fsspec +from trainer.trainer_utils import get_optimizer, get_scheduler + +from TTS.vocoder.datasets import WaveGradDataset +from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.generic_utils import plot_results + + +@dataclass +class WavegradArgs(Coqpit): + in_channels: int = 80 + out_channels: int = 1 + use_weight_norm: bool = False + y_conv_channels: int = 32 + x_conv_channels: int = 768 + dblock_out_channels: List[int] = field(default_factory=lambda: [128, 128, 256, 512]) + ublock_out_channels: List[int] = field(default_factory=lambda: [512, 512, 256, 128, 128]) + upsample_factors: List[int] = field(default_factory=lambda: [4, 4, 4, 2, 2]) + upsample_dilations: List[List[int]] = field( + default_factory=lambda: [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]] + ) + + +class Wavegrad(BaseVocoder): + """🐸 🌊 WaveGrad 🌊 model. + Paper - https://arxiv.org/abs/2009.00713 + + Examples: + Initializing the model. + + >>> from TTS.vocoder.configs import WavegradConfig + >>> config = WavegradConfig() + >>> model = Wavegrad(config) + + Paper Abstract: + This paper introduces WaveGrad, a conditional model for waveform generation which estimates gradients of the + data density. The model is built on prior work on score matching and diffusion probabilistic models. It starts + from a Gaussian white noise signal and iteratively refines the signal via a gradient-based sampler conditioned + on the mel-spectrogram. WaveGrad offers a natural way to trade inference speed for sample quality by adjusting + the number of refinement steps, and bridges the gap between non-autoregressive and autoregressive models in + terms of audio quality. We find that it can generate high fidelity audio samples using as few as six iterations. + Experiments reveal WaveGrad to generate high fidelity audio, outperforming adversarial non-autoregressive + baselines and matching a strong likelihood-based autoregressive baseline using fewer sequential operations. + Audio samples are available at this https URL. + """ + + # pylint: disable=dangerous-default-value + def __init__(self, config: Coqpit): + super().__init__(config) + self.config = config + self.use_weight_norm = config.model_params.use_weight_norm + self.hop_len = np.prod(config.model_params.upsample_factors) + self.noise_level = None + self.num_steps = None + self.beta = None + self.alpha = None + self.alpha_hat = None + self.c1 = None + self.c2 = None + self.sigma = None + + # dblocks + self.y_conv = Conv1d(1, config.model_params.y_conv_channels, 5, padding=2) + self.dblocks = nn.ModuleList([]) + ic = config.model_params.y_conv_channels + for oc, df in zip(config.model_params.dblock_out_channels, reversed(config.model_params.upsample_factors)): + self.dblocks.append(DBlock(ic, oc, df)) + ic = oc + + # film + self.film = nn.ModuleList([]) + ic = config.model_params.y_conv_channels + for oc in reversed(config.model_params.ublock_out_channels): + self.film.append(FiLM(ic, oc)) + ic = oc + + # ublocksn + self.ublocks = nn.ModuleList([]) + ic = config.model_params.x_conv_channels + for oc, uf, ud in zip( + config.model_params.ublock_out_channels, + config.model_params.upsample_factors, + config.model_params.upsample_dilations, + ): + self.ublocks.append(UBlock(ic, oc, uf, ud)) + ic = oc + + self.x_conv = Conv1d(config.model_params.in_channels, config.model_params.x_conv_channels, 3, padding=1) + self.out_conv = Conv1d(oc, config.model_params.out_channels, 3, padding=1) + + if config.model_params.use_weight_norm: + self.apply_weight_norm() + + def forward(self, x, spectrogram, noise_scale): + shift_and_scale = [] + + x = self.y_conv(x) + shift_and_scale.append(self.film[0](x, noise_scale)) + + for film, layer in zip(self.film[1:], self.dblocks): + x = layer(x) + shift_and_scale.append(film(x, noise_scale)) + + x = self.x_conv(spectrogram) + for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)): + x = layer(x, film_shift, film_scale) + x = self.out_conv(x) + return x + + def load_noise_schedule(self, path): + beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg + self.compute_noise_level(beta) + + @torch.no_grad() + def inference(self, x, y_n=None): + """ + Shapes: + x: :math:`[B, C , T]` + y_n: :math:`[B, 1, T]` + """ + if y_n is None: + y_n = torch.randn(x.shape[0], 1, self.hop_len * x.shape[-1]) + else: + y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0) + y_n = y_n.type_as(x) + sqrt_alpha_hat = self.noise_level.to(x) + for n in range(len(self.alpha) - 1, -1, -1): + y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) + if n > 0: + z = torch.randn_like(y_n) + y_n += self.sigma[n - 1] * z + y_n.clamp_(-1.0, 1.0) + return y_n + + def compute_y_n(self, y_0): + """Compute noisy audio based on noise schedule""" + self.noise_level = self.noise_level.to(y_0) + if len(y_0.shape) == 3: + y_0 = y_0.squeeze(1) + s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]]) + l_a, l_b = self.noise_level[s], self.noise_level[s + 1] + noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) + noise_scale = noise_scale.unsqueeze(1) + noise = torch.randn_like(y_0) + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise + return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] + + def compute_noise_level(self, beta): + """Compute noise schedule parameters""" + self.num_steps = len(beta) + alpha = 1 - beta + alpha_hat = np.cumprod(alpha) + noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0) + noise_level = alpha_hat**0.5 + + # pylint: disable=not-callable + self.beta = torch.tensor(beta.astype(np.float32)) + self.alpha = torch.tensor(alpha.astype(np.float32)) + self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) + self.noise_level = torch.tensor(noise_level.astype(np.float32)) + + self.c1 = 1 / self.alpha**0.5 + self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 + self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 + + def remove_weight_norm(self): + for _, layer in enumerate(self.dblocks): + if len(layer.state_dict()) != 0: + try: + remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + for _, layer in enumerate(self.film): + if len(layer.state_dict()) != 0: + try: + remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + for _, layer in enumerate(self.ublocks): + if len(layer.state_dict()) != 0: + try: + remove_parametrizations(layer, "weight") + except ValueError: + layer.remove_weight_norm() + + remove_parametrizations(self.x_conv, "weight") + remove_parametrizations(self.out_conv, "weight") + remove_parametrizations(self.y_conv, "weight") + + def apply_weight_norm(self): + for _, layer in enumerate(self.dblocks): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + for _, layer in enumerate(self.film): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + for _, layer in enumerate(self.ublocks): + if len(layer.state_dict()) != 0: + layer.apply_weight_norm() + + self.x_conv = weight_norm(self.x_conv) + self.out_conv = weight_norm(self.out_conv) + self.y_conv = weight_norm(self.y_conv) + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + if self.config.model_params.use_weight_norm: + self.remove_weight_norm() + betas = np.linspace( + config["test_noise_schedule"]["min_val"], + config["test_noise_schedule"]["max_val"], + config["test_noise_schedule"]["num_steps"], + ) + self.compute_noise_level(betas) + else: + betas = np.linspace( + config["train_noise_schedule"]["min_val"], + config["train_noise_schedule"]["max_val"], + config["train_noise_schedule"]["num_steps"], + ) + self.compute_noise_level(betas) + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + # format data + x = batch["input"] + y = batch["waveform"] + + # set noise scale + noise, x_noisy, noise_scale = self.compute_y_n(y) + + # forward pass + noise_hat = self.forward(x_noisy, x, noise_scale) + + # compute losses + loss = criterion(noise, noise_hat) + return {"model_output": noise_hat}, {"loss": loss} + + def train_log( # pylint: disable=no-self-use + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + pass + + @torch.no_grad() + def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + def eval_log( # pylint: disable=no-self-use + self, batch: Dict, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> None: + pass + + def test(self, assets: Dict, test_loader: "DataLoader", outputs=None): # pylint: disable=unused-argument + # setup noise schedule and inference + ap = assets["audio_processor"] + noise_schedule = self.config["test_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + samples = test_loader.dataset.load_test_samples(1) + for sample in samples: + x = sample[0] + x = x[None, :, :].to(next(self.parameters()).device) + y = sample[1] + y = y[None, :] + # compute voice + y_pred = self.inference(x) + # compute spectrograms + figures = plot_results(y_pred, y, ap, "test") + # Sample audio + sample_voice = y_pred[0].squeeze(0).detach().cpu().numpy() + return figures, {"test/audio": sample_voice} + + def get_optimizer(self): + return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self) + + def get_scheduler(self, optimizer): + return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer) + + @staticmethod + def get_criterion(): + return torch.nn.L1Loss() + + @staticmethod + def format_batch(batch: Dict) -> Dict: + # return a whole audio segment + m, y = batch[0], batch[1] + y = y.unsqueeze(1) + return {"input": m, "waveform": y} + + def get_data_loader(self, config: Coqpit, assets: Dict, is_eval: True, samples: List, verbose: bool, num_gpus: int): + ap = assets["audio_processor"] + dataset = WaveGradDataset( + ap=ap, + items=samples, + seq_len=self.config.seq_len, + hop_len=ap.hop_length, + pad_short=self.config.pad_short, + conv_pad=self.config.conv_pad, + is_training=not is_eval, + return_segments=True, + use_noise_augment=False, + use_cache=config.use_cache, + ) + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=self.config.batch_size, + shuffle=num_gpus <= 1, + drop_last=False, + sampler=sampler, + num_workers=self.config.num_eval_loader_workers if is_eval else self.config.num_loader_workers, + pin_memory=False, + ) + return loader + + def on_epoch_start(self, trainer): # pylint: disable=unused-argument + noise_schedule = self.config["train_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + self.compute_noise_level(betas) + + @staticmethod + def init_from_config(config: "WavegradConfig"): + return Wavegrad(config) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..723f18dde28099800c449e1b11a6c1b328e90282 --- /dev/null +++ b/TTS/vocoder/models/wavernn.py @@ -0,0 +1,645 @@ +import sys +import time +from dataclasses import dataclass, field +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from coqpit import Coqpit +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from trainer.io import load_fsspec + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_decode +from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset +from TTS.vocoder.layers.losses import WaveRNNLoss +from TTS.vocoder.models.base_vocoder import BaseVocoder +from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian + + +def stream(string, variables): + sys.stdout.write(f"\r{string}" % variables) + + +# pylint: disable=abstract-method +# relates https://github.com/pytorch/pytorch/issues/42305 +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + return x + residual + + +class MelResNet(nn.Module): + def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + k_size = pad * 2 + 1 + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for _ in range(num_res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__( + self, + feat_dims, + upsample_scales, + compute_dims, + num_res_blocks, + res_out_dims, + pad, + use_aux_net, + ): + super().__init__() + self.total_scale = np.cumprod(upsample_scales)[-1] + self.indent = pad * self.total_scale + self.use_aux_net = use_aux_net + if use_aux_net: + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(self.total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + k_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / k_size[1]) + self.up_layers.append(stretch) + self.up_layers.append(conv) + + def forward(self, m): + if self.use_aux_net: + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + aux = aux.transpose(1, 2) + else: + aux = None + m = m.unsqueeze(1) + for f in self.up_layers: + m = f(m) + m = m.squeeze(1)[:, :, self.indent : -self.indent] + return m.transpose(1, 2), aux + + +class Upsample(nn.Module): + def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net): + super().__init__() + self.scale = scale + self.pad = pad + self.indent = pad * scale + self.use_aux_net = use_aux_net + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) + + def forward(self, m): + if self.use_aux_net: + aux = self.resnet(m) + aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True) + aux = aux.transpose(1, 2) + else: + aux = None + m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True) + m = m[:, :, self.indent : -self.indent] + m = m * 0.045 # empirically found + + return m.transpose(1, 2), aux + + +@dataclass +class WavernnArgs(Coqpit): + """🐸 WaveRNN model arguments. + + rnn_dims (int): + Number of hidden channels in RNN layers. Defaults to 512. + fc_dims (int): + Number of hidden channels in fully-conntected layers. Defaults to 512. + compute_dims (int): + Number of hidden channels in the feature ResNet. Defaults to 128. + res_out_dim (int): + Number of hidden channels in the feature ResNet output. Defaults to 128. + num_res_blocks (int): + Number of residual blocks in the ResNet. Defaults to 10. + use_aux_net (bool): + enable/disable the feature ResNet. Defaults to True. + use_upsample_net (bool): + enable/ disable the upsampling networl. If False, basic upsampling is used. Defaults to True. + upsample_factors (list): + Upsampling factors. The multiply of the values must match the `hop_length`. Defaults to ```[4, 8, 8]```. + mode (str): + Output mode of the WaveRNN vocoder. `mold` for Mixture of Logistic Distribution, `gauss` for a single + Gaussian Distribution and `bits` for quantized bits as the model's output. + mulaw (bool): + enable / disable the use of Mulaw quantization for training. Only applicable if `mode == 'bits'`. Defaults + to `True`. + pad (int): + Padding applied to the input feature frames against the convolution layers of the feature network. + Defaults to 2. + """ + + rnn_dims: int = 512 + fc_dims: int = 512 + compute_dims: int = 128 + res_out_dims: int = 128 + num_res_blocks: int = 10 + use_aux_net: bool = True + use_upsample_net: bool = True + upsample_factors: List[int] = field(default_factory=lambda: [4, 8, 8]) + mode: str = "mold" # mold [string], gauss [string], bits [int] + mulaw: bool = True # apply mulaw if mode is bits + pad: int = 2 + feat_dims: int = 80 + + +class Wavernn(BaseVocoder): + def __init__(self, config: Coqpit): + """🐸 WaveRNN model. + Original paper - https://arxiv.org/abs/1802.08435 + Official implementation - https://github.com/fatchord/WaveRNN + + Args: + config (Coqpit): [description] + + Raises: + RuntimeError: [description] + + Examples: + >>> from TTS.vocoder.configs import WavernnConfig + >>> config = WavernnConfig() + >>> model = Wavernn(config) + + Paper Abstract: + Sequential models achieve state-of-the-art results in audio, visual and textual domains with respect to + both estimating the data distribution and generating high-quality samples. Efficient sampling for this + class of models has however remained an elusive problem. With a focus on text-to-speech synthesis, we + describe a set of general techniques for reducing sampling time while maintaining high output quality. + We first describe a single-layer recurrent neural network, the WaveRNN, with a dual softmax layer that + matches the quality of the state-of-the-art WaveNet model. The compact form of the network makes it + possible to generate 24kHz 16-bit audio 4x faster than real time on a GPU. Second, we apply a weight + pruning technique to reduce the number of weights in the WaveRNN. We find that, for a constant number of + parameters, large sparse networks perform better than small dense networks and this relationship holds for + sparsity levels beyond 96%. The small number of weights in a Sparse WaveRNN makes it possible to sample + high-fidelity audio on a mobile CPU in real time. Finally, we propose a new generation scheme based on + subscaling that folds a long sequence into a batch of shorter sequences and allows one to generate multiple + samples at once. The Subscale WaveRNN produces 16 samples per step without loss of quality and offers an + orthogonal method for increasing sampling efficiency. + """ + super().__init__(config) + + if isinstance(self.args.mode, int): + self.n_classes = 2**self.args.mode + elif self.args.mode == "mold": + self.n_classes = 3 * 10 + elif self.args.mode == "gauss": + self.n_classes = 2 + else: + raise RuntimeError("Unknown model mode value - ", self.args.mode) + + self.ap = AudioProcessor(**config.audio.to_dict()) + self.aux_dims = self.args.res_out_dims // 4 + + if self.args.use_upsample_net: + assert ( + np.cumprod(self.args.upsample_factors)[-1] == config.audio.hop_length + ), " [!] upsample scales needs to be equal to hop_length" + self.upsample = UpsampleNetwork( + self.args.feat_dims, + self.args.upsample_factors, + self.args.compute_dims, + self.args.num_res_blocks, + self.args.res_out_dims, + self.args.pad, + self.args.use_aux_net, + ) + else: + self.upsample = Upsample( + config.audio.hop_length, + self.args.pad, + self.args.num_res_blocks, + self.args.feat_dims, + self.args.compute_dims, + self.args.res_out_dims, + self.args.use_aux_net, + ) + if self.args.use_aux_net: + self.I = nn.Linear(self.args.feat_dims + self.aux_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims + self.aux_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims + self.aux_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims + self.aux_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) + else: + self.I = nn.Linear(self.args.feat_dims + 1, self.args.rnn_dims) + self.rnn1 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(self.args.rnn_dims, self.args.rnn_dims, batch_first=True) + self.fc1 = nn.Linear(self.args.rnn_dims, self.args.fc_dims) + self.fc2 = nn.Linear(self.args.fc_dims, self.args.fc_dims) + self.fc3 = nn.Linear(self.args.fc_dims, self.n_classes) + + def forward(self, x, mels): + bsize = x.size(0) + h1 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + h2 = torch.zeros(1, bsize, self.args.rnn_dims).to(x.device) + mels, aux = self.upsample(mels) + + if self.args.use_aux_net: + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] + + x = ( + torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + if self.args.use_aux_net + else torch.cat([x.unsqueeze(-1), mels], dim=2) + ) + x = self.I(x) + res = x + self.rnn1.flatten_parameters() + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) if self.args.use_aux_net else x + self.rnn2.flatten_parameters() + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) if self.args.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) if self.args.use_aux_net else x + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def inference(self, mels, batched=None, target=None, overlap=None): + self.eval() + output = [] + start = time.time() + rnn1 = self.get_gru_cell(self.rnn1) + rnn2 = self.get_gru_cell(self.rnn2) + + with torch.no_grad(): + if isinstance(mels, np.ndarray): + mels = torch.FloatTensor(mels).to(str(next(self.parameters()).device)) + + if mels.ndim == 2: + mels = mels.unsqueeze(0) + wave_len = (mels.size(-1) - 1) * self.config.audio.hop_length + + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.args.pad, side="both") + mels, aux = self.upsample(mels.transpose(1, 2)) + + if batched: + mels = self.fold_with_overlap(mels, target, overlap) + if aux is not None: + aux = self.fold_with_overlap(aux, target, overlap) + + b_size, seq_len, _ = mels.size() + + h1 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + h2 = torch.zeros(b_size, self.args.rnn_dims).type_as(mels) + x = torch.zeros(b_size, 1).type_as(mels) + + if self.args.use_aux_net: + d = self.aux_dims + aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)] + + for i in range(seq_len): + m_t = mels[:, i, :] + + if self.args.use_aux_net: + a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) + + x = torch.cat([x, m_t, a1_t], dim=1) if self.args.use_aux_net else torch.cat([x, m_t], dim=1) + x = self.I(x) + h1 = rnn1(x, h1) + + x = x + h1 + inp = torch.cat([x, a2_t], dim=1) if self.args.use_aux_net else x + h2 = rnn2(inp, h2) + + x = x + h2 + x = torch.cat([x, a3_t], dim=1) if self.args.use_aux_net else x + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4_t], dim=1) if self.args.use_aux_net else x + x = F.relu(self.fc2(x)) + + logits = self.fc3(x) + + if self.args.mode == "mold": + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).type_as(mels) + elif self.args.mode == "gauss": + sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) + output.append(sample.view(-1)) + x = sample.transpose(0, 1).type_as(mels) + elif isinstance(self.args.mode, int): + posterior = F.softmax(logits, dim=1) + distrib = torch.distributions.Categorical(posterior) + + sample = 2 * distrib.sample().float() / (self.n_classes - 1.0) - 1.0 + output.append(sample) + x = sample.unsqueeze(-1) + else: + raise RuntimeError("Unknown model mode value - ", self.args.mode) + + if i % 100 == 0: + self.gen_display(i, seq_len, b_size, start) + + output = torch.stack(output).transpose(0, 1) + output = output.cpu() + if batched: + output = output.numpy() + output = output.astype(np.float64) + + output = self.xfade_and_unfold(output, target, overlap) + else: + output = output[0] + + if self.args.mulaw and isinstance(self.args.mode, int): + output = mulaw_decode(wav=output, mulaw_qc=self.args.mode) + + # Fade-out at the end to avoid signal cutting out suddenly + fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) + output = output[:wave_len] + + if wave_len > len(fade_out): + output[-20 * self.config.audio.hop_length :] *= fade_out + + self.train() + return output + + def gen_display(self, i, seq_len, b_size, start): + gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 + realtime_ratio = gen_rate * 1000 / self.config.audio.sample_rate + stream( + "%i/%i -- batch_size: %i -- gen_rate: %.1f kHz -- x_realtime: %.1f ", + (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio), + ) + + def fold_with_overlap(self, x, target, overlap): + """Fold the tensor with overlap for quick batched inference. + Overlap will be used for crossfading in xfade_and_unfold() + Args: + x (tensor) : Upsampled conditioning features. + shape=(1, timesteps, features) + target (int) : Target timesteps for each index of batch + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (tensor) : shape=(num_folds, target + 2 * overlap, features) + Details: + x = [[h1, h2, ... hn]] + Where each h is a vector of conditioning features + Eg: target=2, overlap=1 with x.size(1)=10 + folded = [[h1, h2, h3, h4], + [h4, h5, h6, h7], + [h7, h8, h9, h10]] + """ + + _, total_len, features = x.size() + + # Calculate variables needed + num_folds = (total_len - overlap) // (target + overlap) + extended_len = num_folds * (overlap + target) + overlap + remaining = total_len - extended_len + + # Pad if some time steps poking out + if remaining != 0: + num_folds += 1 + padding = target + 2 * overlap - remaining + x = self.pad_tensor(x, padding, side="after") + + folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device) + + # Get the values for the folded tensor + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + folded[i] = x[:, start:end, :] + + return folded + + @staticmethod + def get_gru_cell(gru): + gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) + gru_cell.weight_hh.data = gru.weight_hh_l0.data + gru_cell.weight_ih.data = gru.weight_ih_l0.data + gru_cell.bias_hh.data = gru.bias_hh_l0.data + gru_cell.bias_ih.data = gru.bias_ih_l0.data + return gru_cell + + @staticmethod + def pad_tensor(x, pad, side="both"): + # NB - this is just a quick method i need right now + # i.e., it won't generalise to other shapes/dims + b, t, c = x.size() + total = t + 2 * pad if side == "both" else t + pad + padded = torch.zeros(b, total, c).to(x.device) + if side in ("before", "both"): + padded[:, pad : pad + t, :] = x + elif side == "after": + padded[:, :t, :] = x + return padded + + @staticmethod + def xfade_and_unfold(y, target, overlap): + """Applies a crossfade and unfolds into a 1d array. + Args: + y (ndarry) : Batched sequences of audio samples + shape=(num_folds, target + 2 * overlap) + dtype=np.float64 + overlap (int) : Timesteps for both xfade and rnn warmup + Return: + (ndarry) : audio samples in a 1d array + shape=(total_len) + dtype=np.float64 + Details: + y = [[seq1], + [seq2], + [seq3]] + Apply a gain envelope at both ends of the sequences + y = [[seq1_in, seq1_target, seq1_out], + [seq2_in, seq2_target, seq2_out], + [seq3_in, seq3_target, seq3_out]] + Stagger and add up the groups of samples: + [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...] + """ + + num_folds, length = y.shape + target = length - 2 * overlap + total_len = num_folds * (target + overlap) + overlap + + # Need some silence for the rnn warmup + silence_len = overlap // 2 + fade_len = overlap - silence_len + silence = np.zeros((silence_len), dtype=np.float64) + + # Equal power crossfade + t = np.linspace(-1, 1, fade_len, dtype=np.float64) + fade_in = np.sqrt(0.5 * (1 + t)) + fade_out = np.sqrt(0.5 * (1 - t)) + + # Concat the silence to the fades + fade_in = np.concatenate([silence, fade_in]) + fade_out = np.concatenate([fade_out, silence]) + + # Apply the gain to the overlap samples + y[:, :overlap] *= fade_in + y[:, -overlap:] *= fade_out + + unfolded = np.zeros((total_len), dtype=np.float64) + + # Loop to add up all the samples + for i in range(num_folds): + start = i * (target + overlap) + end = start + target + 2 * overlap + unfolded[start:end] += y[i] + + return unfolded + + def load_checkpoint( + self, config, checkpoint_path, eval=False, cache=False + ): # pylint: disable=unused-argument, redefined-builtin + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"]) + if eval: + self.eval() + assert not self.training + + def train_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + mels = batch["input"] + waveform = batch["waveform"] + waveform_coarse = batch["waveform_coarse"] + + y_hat = self.forward(waveform, mels) + if isinstance(self.args.mode, int): + y_hat = y_hat.transpose(1, 2).unsqueeze(-1) + else: + waveform_coarse = waveform_coarse.float() + waveform_coarse = waveform_coarse.unsqueeze(-1) + # compute losses + loss_dict = criterion(y_hat, waveform_coarse) + return {"model_output": y_hat}, loss_dict + + def eval_step(self, batch: Dict, criterion: Dict) -> Tuple[Dict, Dict]: + return self.train_step(batch, criterion) + + @torch.no_grad() + def test( + self, assets: Dict, test_loader: "DataLoader", output: Dict # pylint: disable=unused-argument + ) -> Tuple[Dict, Dict]: + ap = self.ap + figures = {} + audios = {} + samples = test_loader.dataset.load_test_samples(1) + for idx, sample in enumerate(samples): + x = torch.FloatTensor(sample[0]) + x = x.to(next(self.parameters()).device) + y_hat = self.inference(x, self.config.batched, self.config.target_samples, self.config.overlap_samples) + x_hat = ap.melspectrogram(y_hat) + figures.update( + { + f"test_{idx}/ground_truth": plot_spectrogram(x.T), + f"test_{idx}/prediction": plot_spectrogram(x_hat.T), + } + ) + audios.update({f"test_{idx}/audio": y_hat}) + # audios.update({f"real_{idx}/audio": y_hat}) + return figures, audios + + def test_log( + self, outputs: Dict, logger: "Logger", assets: Dict, steps: int # pylint: disable=unused-argument + ) -> Tuple[Dict, np.ndarray]: + figures, audios = outputs + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, self.ap.sample_rate) + + @staticmethod + def format_batch(batch: Dict) -> Dict: + waveform = batch[0] + mels = batch[1] + waveform_coarse = batch[2] + return {"input": mels, "waveform": waveform, "waveform_coarse": waveform_coarse} + + def get_data_loader( # pylint: disable=no-self-use + self, + config: Coqpit, + assets: Dict, + is_eval: True, + samples: List, + verbose: bool, + num_gpus: int, + ): + ap = self.ap + dataset = WaveRNNDataset( + ap=ap, + items=samples, + seq_len=config.seq_len, + hop_len=ap.hop_length, + pad=config.model_args.pad, + mode=config.model_args.mode, + mulaw=config.model_args.mulaw, + is_training=not is_eval, + ) + sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=1 if is_eval else config.batch_size, + shuffle=num_gpus == 0, + collate_fn=dataset.collate, + sampler=sampler, + num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers, + pin_memory=True, + ) + return loader + + def get_criterion(self): + # define train functions + return WaveRNNLoss(self.args.mode) + + @staticmethod + def init_from_config(config: "WavernnConfig"): + return Wavernn(config) diff --git a/TTS/vocoder/pqmf_output.wav b/TTS/vocoder/pqmf_output.wav new file mode 100644 index 0000000000000000000000000000000000000000..8a77747b00198a4adfd6c398998517df5b4bdb8d Binary files /dev/null and b/TTS/vocoder/pqmf_output.wav differ diff --git a/TTS/vocoder/utils/__init__.py b/TTS/vocoder/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..fe706ba9ffbc3f8aad75285bca34a910246666b3 --- /dev/null +++ b/TTS/vocoder/utils/distribution.py @@ -0,0 +1,154 @@ +import math + +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions.normal import Normal + + +def gaussian_loss(y_hat, y, log_std_min=-7.0): + assert y_hat.dim() == 3 + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + # TODO: replace with pytorch dist + log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))) + return log_probs.squeeze().mean() + + +def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0): + assert y_hat.size(2) == 2 + mean = y_hat[:, :, :1] + log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) + dist = Normal( + mean, + torch.exp(log_std), + ) + sample = dist.sample() + sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) + del dist + return sample + + +def log_sum_exp(x): + """numerically stable log_sum_exp implementation that prevents overflow""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + y_hat = y_hat.permute(0, 2, 1) + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(F.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - F.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + # tf equivalent + + # log_probs = tf.where(x < -0.999, log_cdf_plus, + # tf.where(x > 0.999, log_one_minus_cdf_min, + # tf.where(cdf_delta > 1e-5, + # tf.log(tf.maximum(cdf_delta, 1e-12)), + # log_pdf_mid - np.log(127.5)))) + + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - np.log((num_classes - 1) / 2) + ) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.mean(log_sum_exp(log_probs)) + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=None): + """ + Sample from discretized mixture of logistic distributions + Args: + y (Tensor): :math:`[B, C, T]` + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. + """ + if log_scale_min is None: + log_scale_min = float(np.log(1e-14)) + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x + + +def to_one_hot(tensor, n, fill_with=1.0): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor) + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac797d97f7a53dbc4223c19e928e5f70f522ce82 --- /dev/null +++ b/TTS/vocoder/utils/generic_utils.py @@ -0,0 +1,75 @@ +import logging +from typing import Dict + +import numpy as np +import torch +from matplotlib import pyplot as plt + +from TTS.tts.utils.visual import plot_spectrogram +from TTS.utils.audio import AudioProcessor + +logger = logging.getLogger(__name__) + + +def interpolate_vocoder_input(scale_factor, spec): + """Interpolate spectrogram by the scale factor. + It is mainly used to match the sampling rates of + the tts and vocoder models. + + Args: + scale_factor (float): scale factor to interpolate the spectrogram + spec (np.array): spectrogram to be interpolated + + Returns: + torch.tensor: interpolated spectrogram. + """ + logger.info("Before interpolation: %s", spec.shape) + spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable + spec = torch.nn.functional.interpolate( + spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False + ).squeeze(0) + logger.info("After interpolation: %s", spec.shape) + return spec + + +def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict: + """Plot the predicted and the real waveform and their spectrograms. + + Args: + y_hat (torch.tensor): Predicted waveform. + y (torch.tensor): Real waveform. + ap (AudioProcessor): Audio processor used to process the waveform. + name_prefix (str, optional): Name prefix used to name the figures. Defaults to None. + + Returns: + Dict: output figures keyed by the name of the figures. + """ + if name_prefix is None: + name_prefix = "" + + # select an instance from batch + y_hat = y_hat[0].squeeze().detach().cpu().numpy() + y = y[0].squeeze().detach().cpu().numpy() + + spec_fake = ap.melspectrogram(y_hat).T + spec_real = ap.melspectrogram(y).T + spec_diff = np.abs(spec_fake - spec_real) + + # plot figure and save it + fig_wave = plt.figure() + plt.subplot(2, 1, 1) + plt.plot(y) + plt.title("groundtruth speech") + plt.subplot(2, 1, 2) + plt.plot(y_hat) + plt.title("generated speech") + plt.tight_layout() + plt.close() + + figures = { + name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake), + name_prefix + "spectrogram/real": plot_spectrogram(spec_real), + name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff), + name_prefix + "speech_comparison": fig_wave, + } + return figures diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f090e85999cdba1cead565a4666a0b037e899c79 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,21 @@ +beautifulsoup4 +coqui-tts +cutlet +deep_translator +docker +ebooklib +gensim +gradio>=4.44 +hangul-romanize +indic-nlp-library +iso-639 +jieba +mecab +mecab-python3 +pydub +pypinyin +ray +transformers +translate +tqdm +unidic \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..cd339b5b0bf99729fdeb3ac28945f559c5bd7300 --- /dev/null +++ b/setup.py @@ -0,0 +1,50 @@ +import subprocess +import sys +from setuptools import setup, find_packages +from setuptools.command.develop import develop +from setuptools.command.install import install +import os + +cwd = os.path.dirname(os.path.abspath(__file__)) + +with open("README.md", "r", encoding='utf-8') as fh: + long_description = fh.read() + +with open('requirements.txt') as f: + requirements = f.read().splitlines() + +class PostInstallCommand(install): + def run(self): + install.run(self) + try: + subprocess.run([sys.executable, 'python -m', 'unidic', 'download'], check=True) + except Exception: + print("unidic download failed during installation, but it will be re-attempted a diffrent way when the app itself runs.") + + +setup( + name='ebook2audiobook', + version='2.0.0', + python_requires=">=3.10,<3.12", + author="Drew Thomasson", + description="Convert eBooks to audiobooks with chapters and metadata", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/DrewThomasson/ebook2audiobook", + packages=find_packages(), + install_requires=requirements, + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + include_package_data=True, + entry_points={ + "console_scripts": [ + "ebook2audiobook = app:main", + ], + }, + cmdclass={ + 'install': PostInstallCommand, + } +) diff --git a/tools/generate_ebooks.py b/tools/generate_ebooks.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6547d4b0f4e3df5696d8f6ec41607a5d1a9d58 --- /dev/null +++ b/tools/generate_ebooks.py @@ -0,0 +1,84 @@ +import os +import sys +import subprocess + +from iso639 import languages +from deep_translator import GoogleTranslator +from tqdm import tqdm + +parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +sys.path.append(parent_dir) + +# Your language mapping dictionary from lang.py +from lib.lang import language_mapping + +env = os.environ.copy() +env["PYTHONIOENCODING"] = "utf-8"; +env["LANG"] = "en_US.UTF-8" + +# Base text to be translated +base_text = "This is a test from the result of text file to audiobook conversion." + +# Output directory +output_dir = "../ebooks/tests" +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +# Path to your base cover image (adjust the path accordingly) +base_cover_image = "../ebooks/tests/__cover.jpg" + +# List to keep track of languages that failed +failed_languages = [] + +# Loop over languages with a progress bar +for lang_code, lang_info in tqdm(language_mapping.items(), desc="Processing languages"): + try: + lang_iso = lang_code + language_array = languages.get(part3=lang_code) + if language_array and language_array.part1: + lang_iso = language_array.part1 + if lang_iso == "zh": + lang_iso = "zh-CN" + # Translate the text + translated_text = GoogleTranslator(source='en', target=lang_iso).translate(base_text) + print(f"\nTranslated text for {lang_info['name']} ({lang_iso}): {translated_text}") + + # Write the translated text to a txt file + txt_filename = f"test_{lang_code}.txt" + txt_filepath = os.path.join(output_dir, txt_filename) + with open(txt_filepath, 'w', encoding='utf-8') as f: + f.write(translated_text) + + # Prepare the ebook-convert command + azw3_filename = f"test_{lang_code}.azw3" + azw3_filepath = os.path.join(output_dir, azw3_filename) + + title = f"Ebook {lang_info['name']} Test" + authors = "Dev Team" + language = lang_iso + + command = [ + "ebook-convert", + txt_filepath, + azw3_filepath, + "--cover", base_cover_image, + "--title", title, + "--authors", authors, + "--language", language, + "--input-encoding", "utf-8" + ] + + result = subprocess.run(command, env=env, text=True, encoding="utf-8") + print(f"Ebook generated for {lang_info['name']} at {azw3_filepath}\n") + + except Exception as e: + print(f"Erro: language {lang_code} not supported!") + failed_languages.append(lang_code) + continue + +# After processing all languages, output the list of languages that failed +if failed_languages: + print("\nThe following languages could not be processed:") + for lang_code in failed_languages: + lang_name = language_mapping[lang_code]['name'] + print(f"- {lang_name} ({lang_code})")