Plachta commited on
Commit
9a83644
·
verified ·
1 Parent(s): e7a70ec

Upload 35 files

Browse files
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import torchaudio
5
+ import librosa
6
+ from modules.commons import build_model, load_checkpoint, recursive_munch
7
+ import yaml
8
+ from hf_utils import load_custom_model_from_hf
9
+
10
+ # Load model and configuration
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
14
+ "DiT_step_315000_seed_v2_online_pruned.pth",
15
+ "config_dit_mel_seed.yml")
16
+
17
+ config = yaml.safe_load(open(dit_config_path, 'r'))
18
+ model_params = recursive_munch(config['model_params'])
19
+ model = build_model(model_params, stage='DiT')
20
+ hop_length = config['preprocess_params']['spect_params']['hop_length']
21
+ sr = config['preprocess_params']['sr']
22
+
23
+ # Load checkpoints
24
+ model, _, _, _ = load_checkpoint(model, None, dit_checkpoint_path,
25
+ load_only_params=True, ignore_modules=[], is_distributed=False)
26
+ for key in model:
27
+ model[key].eval()
28
+ model[key].to(device)
29
+ model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
30
+
31
+ # Load additional modules
32
+ from modules.campplus.DTDNN import CAMPPlus
33
+
34
+ campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
35
+ campplus_model.load_state_dict(torch.load(config['model_params']['style_encoder']['campplus_path']))
36
+ campplus_model.eval()
37
+ campplus_model.to(device)
38
+
39
+ from modules.hifigan.generator import HiFTGenerator
40
+ from modules.hifigan.f0_predictor import ConvRNNF0Predictor
41
+
42
+ hift_checkpoint_path, hift_config_path = load_custom_model_from_hf("Plachta/Seed-VC",
43
+ "hift.pt",
44
+ "hifigan.yml")
45
+ hift_config = yaml.safe_load(open(hift_config_path, 'r'))
46
+ hift_gen = HiFTGenerator(**hift_config['hift'], f0_predictor=ConvRNNF0Predictor(**hift_config['f0_predictor']))
47
+ hift_gen.load_state_dict(torch.load(hift_config['pretrained_model_path'], map_location='cpu'))
48
+ hift_gen.eval()
49
+ hift_gen.to(device)
50
+
51
+ from modules.cosyvoice_tokenizer.frontend import CosyVoiceFrontEnd
52
+
53
+ speech_tokenizer_path = load_custom_model_from_hf("Plachta/Seed-VC", "speech_tokenizer_v1.onnx", None)
54
+
55
+ cosyvoice_frontend = CosyVoiceFrontEnd(speech_tokenizer_model=speech_tokenizer_path,
56
+ device='cuda', device_id=0)
57
+ # Generate mel spectrograms
58
+ mel_fn_args = {
59
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
60
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
61
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
62
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
63
+ "sampling_rate": sr,
64
+ "fmin": 0,
65
+ "fmax": 8000,
66
+ "center": False
67
+ }
68
+ from modules.audio import mel_spectrogram
69
+
70
+ to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
71
+
72
+ @spaces.GPU
73
+ @torch.no_grad()
74
+ @torch.inference_mode()
75
+ def voice_conversion(source, target, diffusion_steps, length_adjust, inference_cfg_rate):
76
+ # Load audio
77
+ source_audio = librosa.load(source, sr=sr)[0]
78
+ ref_audio = librosa.load(target, sr=sr)[0]
79
+
80
+ # Process audio
81
+ source_audio = torch.tensor(source_audio[:sr * 30]).unsqueeze(0).float().to(device)
82
+ ref_audio = torch.tensor(ref_audio[:sr * 30]).unsqueeze(0).float().to(device)
83
+
84
+ # Resample
85
+ source_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
86
+ ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
87
+
88
+ # Extract features
89
+ S_alt = cosyvoice_frontend.extract_speech_token(source_waves_16k)[0]
90
+ S_ori = cosyvoice_frontend.extract_speech_token(ref_waves_16k)[0]
91
+
92
+ mel = to_mel(source_audio.to(device).float())
93
+ mel2 = to_mel(ref_audio.to(device).float())
94
+
95
+ target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
96
+ target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
97
+
98
+ # Style encoding
99
+ feat = torchaudio.compliance.kaldi.fbank(source_waves_16k,
100
+ num_mel_bins=80,
101
+ dither=0,
102
+ sample_frequency=16000)
103
+ feat = feat - feat.mean(dim=0, keepdim=True)
104
+ style1 = campplus_model(feat.unsqueeze(0))
105
+
106
+ feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k,
107
+ num_mel_bins=80,
108
+ dither=0,
109
+ sample_frequency=16000)
110
+ feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
111
+ style2 = campplus_model(feat2.unsqueeze(0))
112
+
113
+ # Length regulation
114
+ cond = model.length_regulator(S_alt, ylens=target_lengths)[0]
115
+ prompt_condition = model.length_regulator(S_ori, ylens=target2_lengths)[0]
116
+ cat_condition = torch.cat([prompt_condition, cond], dim=1)
117
+
118
+ # Voice Conversion
119
+ vc_target = model.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
120
+ mel2, style2, None, diffusion_steps, inference_cfg_rate=inference_cfg_rate)
121
+ vc_target = vc_target[:, :, mel2.size(-1):]
122
+
123
+ # Convert to waveform
124
+ vc_wave = hift_gen.inference(vc_target)
125
+
126
+ return (sr, vc_wave.squeeze(0).cpu().numpy())
127
+
128
+
129
+ if __name__ == "__main__":
130
+ description = "Zero-shot voice conversion with in-context learning. Check out our [GitHub repository](https://github.com/Plachtaa/seed-vc) for details and updates."
131
+ inputs = [
132
+ gr.Audio(source="upload", type="filepath", label="Source Audio"),
133
+ gr.Audio(source="upload", type="filepath", label="Reference Audio"),
134
+ gr.Slider(minimum=1, maximum=1000, value=100, step=1, label="Diffusion Steps"),
135
+ gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust"),
136
+ gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate"),
137
+ ]
138
+
139
+ outputs = gr.Audio(label="Output Audio")
140
+
141
+ gr.Interface(fn=voice_conversion, description=description, inputs=inputs, outputs=outputs, title="Seed Voice Conversion").launch()
campplus_cn_common.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3388cf5fd3493c9ac9c69851d8e7a8badcfb4f3dc631020c4961371646d5ada8
3
+ size 28036335
configs/config_dit_mel_seed.yml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "./runs/run_dit_mel_seed"
2
+ save_freq: 1
3
+ log_interval: 10
4
+ save_interval: 1000
5
+ device: "cuda"
6
+ epochs: 1000 # number of epochs for first stage training (pre-training)
7
+ batch_size: 4
8
+ batch_length: 100 # maximum duration of audio in a batch (in seconds)
9
+ max_len: 80 # maximum number of frames
10
+ pretrained_model: ""
11
+ pretrained_encoder: ""
12
+ load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "modules/JDC/bst.t7"
15
+
16
+ preprocess_params:
17
+ sr: 22050
18
+ spect_params:
19
+ n_fft: 1024
20
+ win_length: 1024
21
+ hop_length: 256
22
+ n_mels: 80
23
+
24
+ model_params:
25
+ dit_type: "DiT" # uDiT or DiT
26
+ reg_loss_type: "l2" # l1 or l2
27
+
28
+ speech_tokenizer:
29
+ path: "speech_tokenizer_v1.onnx"
30
+
31
+ style_encoder:
32
+ dim: 192
33
+ campplus_path: "campplus_cn_common.bin"
34
+
35
+ DAC:
36
+ encoder_dim: 64
37
+ encoder_rates: [2, 5, 5, 6]
38
+ decoder_dim: 1536
39
+ decoder_rates: [ 6, 5, 5, 2 ]
40
+ sr: 24000
41
+
42
+ length_regulator:
43
+ channels: 768
44
+ is_discrete: true
45
+ content_codebook_size: 4096
46
+ in_frame_rate: 50
47
+ out_frame_rate: 80
48
+ sampling_ratios: [1, 1, 1, 1]
49
+
50
+ DiT:
51
+ hidden_dim: 768
52
+ num_heads: 12
53
+ depth: 12
54
+ class_dropout_prob: 0.1
55
+ block_size: 4096
56
+ in_channels: 80
57
+ style_condition: true
58
+ final_layer_type: 'wavenet'
59
+ target: 'mel' # mel or codec
60
+ content_dim: 768
61
+ content_codebook_size: 1024
62
+ content_type: 'discrete'
63
+ f0_condition: false
64
+ n_f0_bins: 512
65
+ content_codebooks: 1
66
+ is_causal: false
67
+ long_skip_connection: true
68
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
69
+
70
+ wavenet:
71
+ hidden_dim: 768
72
+ num_layers: 8
73
+ kernel_size: 5
74
+ dilation_rate: 1
75
+ p_dropout: 0.2
76
+ style_condition: true
77
+
78
+ loss_params:
79
+ base_lr: 0.0001
configs/hifigan.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hift:
2
+ in_channels: 80
3
+ base_channels: 512
4
+ nb_harmonics: 8
5
+ sampling_rate: 22050
6
+ nsf_alpha: 0.1
7
+ nsf_sigma: 0.003
8
+ nsf_voiced_threshold: 10
9
+ upsample_rates: [8, 8]
10
+ upsample_kernel_sizes: [16, 16]
11
+ istft_params:
12
+ n_fft: 16
13
+ hop_len: 4
14
+ resblock_kernel_sizes: [3, 7, 11]
15
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
16
+ source_resblock_kernel_sizes: [7, 11]
17
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
18
+ lrelu_slope: 0.1
19
+ audio_limit: 0.99
20
+ f0_predictor:
21
+ num_class: 1
22
+ in_channels: 80
23
+ cond_channels: 512
24
+
25
+ pretrained_model_path: "hift.pt"
hf_utils.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import hf_hub_download
3
+
4
+
5
+ def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
6
+ os.makedirs("./checkpoints", exist_ok=True)
7
+ model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
8
+ if config_filename is None:
9
+ return model_path
10
+ config_path = hf_hub_download(repo_id=repo_id, filename=config_filename, cache_dir="./checkpoints")
11
+
12
+ return model_path, config_path
modules/__pycache__/audio.cpython-310.pyc ADDED
Binary file (2.43 kB). View file
 
modules/__pycache__/commons.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
modules/__pycache__/diffusion_transformer.cpython-310.pyc ADDED
Binary file (7.76 kB). View file
 
modules/__pycache__/encodec.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
modules/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
modules/__pycache__/length_regulator.cpython-310.pyc ADDED
Binary file (1.58 kB). View file
 
modules/__pycache__/wavenet.cpython-310.pyc ADDED
Binary file (5.15 kB). View file
 
modules/audio.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+
7
+ MAX_WAV_VALUE = 32768.0
8
+
9
+
10
+ def load_wav(full_path):
11
+ sampling_rate, data = read(full_path)
12
+ return data, sampling_rate
13
+
14
+
15
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
+
18
+
19
+ def dynamic_range_decompression(x, C=1):
20
+ return np.exp(x) / C
21
+
22
+
23
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24
+ return torch.log(torch.clamp(x, min=clip_val) * C)
25
+
26
+
27
+ def dynamic_range_decompression_torch(x, C=1):
28
+ return torch.exp(x) / C
29
+
30
+
31
+ def spectral_normalize_torch(magnitudes):
32
+ output = dynamic_range_compression_torch(magnitudes)
33
+ return output
34
+
35
+
36
+ def spectral_de_normalize_torch(magnitudes):
37
+ output = dynamic_range_decompression_torch(magnitudes)
38
+ return output
39
+
40
+
41
+ mel_basis = {}
42
+ hann_window = {}
43
+
44
+
45
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
46
+ if torch.min(y) < -1.0:
47
+ print("min value is ", torch.min(y))
48
+ if torch.max(y) > 1.0:
49
+ print("max value is ", torch.max(y))
50
+
51
+ global mel_basis, hann_window # pylint: disable=global-statement
52
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
53
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
54
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
55
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
56
+
57
+ y = torch.nn.functional.pad(
58
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
59
+ )
60
+ y = y.squeeze(1)
61
+
62
+ spec = torch.view_as_real(
63
+ torch.stft(
64
+ y,
65
+ n_fft,
66
+ hop_length=hop_size,
67
+ win_length=win_size,
68
+ window=hann_window[str(y.device)],
69
+ center=center,
70
+ pad_mode="reflect",
71
+ normalized=False,
72
+ onesided=True,
73
+ return_complex=True,
74
+ )
75
+ )
76
+
77
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
78
+
79
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
80
+ spec = spectral_normalize_torch(spec)
81
+
82
+ return spec
modules/campplus/DTDNN.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from modules.campplus.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
11
+
12
+
13
+ class FCM(nn.Module):
14
+ def __init__(self,
15
+ block=BasicResBlock,
16
+ num_blocks=[2, 2],
17
+ m_channels=32,
18
+ feat_dim=80):
19
+ super(FCM, self).__init__()
20
+ self.in_planes = m_channels
21
+ self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
22
+ self.bn1 = nn.BatchNorm2d(m_channels)
23
+
24
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
25
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
26
+
27
+ self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
28
+ self.bn2 = nn.BatchNorm2d(m_channels)
29
+ self.out_channels = m_channels * (feat_dim // 8)
30
+
31
+ def _make_layer(self, block, planes, num_blocks, stride):
32
+ strides = [stride] + [1] * (num_blocks - 1)
33
+ layers = []
34
+ for stride in strides:
35
+ layers.append(block(self.in_planes, planes, stride))
36
+ self.in_planes = planes * block.expansion
37
+ return nn.Sequential(*layers)
38
+
39
+ def forward(self, x):
40
+ x = x.unsqueeze(1)
41
+ out = F.relu(self.bn1(self.conv1(x)))
42
+ out = self.layer1(out)
43
+ out = self.layer2(out)
44
+ out = F.relu(self.bn2(self.conv2(out)))
45
+
46
+ shape = out.shape
47
+ out = out.reshape(shape[0], shape[1]*shape[2], shape[3])
48
+ return out
49
+
50
+ class CAMPPlus(nn.Module):
51
+ def __init__(self,
52
+ feat_dim=80,
53
+ embedding_size=512,
54
+ growth_rate=32,
55
+ bn_size=4,
56
+ init_channels=128,
57
+ config_str='batchnorm-relu',
58
+ memory_efficient=True):
59
+ super(CAMPPlus, self).__init__()
60
+
61
+ self.head = FCM(feat_dim=feat_dim)
62
+ channels = self.head.out_channels
63
+
64
+ self.xvector = nn.Sequential(
65
+ OrderedDict([
66
+
67
+ ('tdnn',
68
+ TDNNLayer(channels,
69
+ init_channels,
70
+ 5,
71
+ stride=2,
72
+ dilation=1,
73
+ padding=-1,
74
+ config_str=config_str)),
75
+ ]))
76
+ channels = init_channels
77
+ for i, (num_layers, kernel_size,
78
+ dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
79
+ block = CAMDenseTDNNBlock(num_layers=num_layers,
80
+ in_channels=channels,
81
+ out_channels=growth_rate,
82
+ bn_channels=bn_size * growth_rate,
83
+ kernel_size=kernel_size,
84
+ dilation=dilation,
85
+ config_str=config_str,
86
+ memory_efficient=memory_efficient)
87
+ self.xvector.add_module('block%d' % (i + 1), block)
88
+ channels = channels + num_layers * growth_rate
89
+ self.xvector.add_module(
90
+ 'transit%d' % (i + 1),
91
+ TransitLayer(channels,
92
+ channels // 2,
93
+ bias=False,
94
+ config_str=config_str))
95
+ channels //= 2
96
+
97
+ self.xvector.add_module(
98
+ 'out_nonlinear', get_nonlinear(config_str, channels))
99
+
100
+ self.xvector.add_module('stats', StatsPool())
101
+ self.xvector.add_module(
102
+ 'dense',
103
+ DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
104
+
105
+ for m in self.modules():
106
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
107
+ nn.init.kaiming_normal_(m.weight.data)
108
+ if m.bias is not None:
109
+ nn.init.zeros_(m.bias)
110
+
111
+ def forward(self, x):
112
+ x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
113
+ x = self.head(x)
114
+ x = self.xvector(x)
115
+ return x
modules/campplus/__pycache__/DTDNN.cpython-310.pyc ADDED
Binary file (3.45 kB). View file
 
modules/campplus/__pycache__/layers.cpython-310.pyc ADDED
Binary file (7.3 kB). View file
 
modules/campplus/classifier.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from modules.campplus.layers import DenseLayer
9
+
10
+
11
+ class CosineClassifier(nn.Module):
12
+ def __init__(
13
+ self,
14
+ input_dim,
15
+ num_blocks=0,
16
+ inter_dim=512,
17
+ out_neurons=1000,
18
+ ):
19
+
20
+ super().__init__()
21
+ self.blocks = nn.ModuleList()
22
+
23
+ for index in range(num_blocks):
24
+ self.blocks.append(
25
+ DenseLayer(input_dim, inter_dim, config_str='batchnorm')
26
+ )
27
+ input_dim = inter_dim
28
+
29
+ self.weight = nn.Parameter(
30
+ torch.FloatTensor(out_neurons, input_dim)
31
+ )
32
+ nn.init.xavier_uniform_(self.weight)
33
+
34
+ def forward(self, x):
35
+ # x: [B, dim]
36
+ for layer in self.blocks:
37
+ x = layer(x)
38
+
39
+ # normalized
40
+ x = F.linear(F.normalize(x), F.normalize(self.weight))
41
+ return x
42
+
43
+ class LinearClassifier(nn.Module):
44
+ def __init__(
45
+ self,
46
+ input_dim,
47
+ num_blocks=0,
48
+ inter_dim=512,
49
+ out_neurons=1000,
50
+ ):
51
+
52
+ super().__init__()
53
+ self.blocks = nn.ModuleList()
54
+
55
+ self.nonlinear = nn.ReLU(inplace=True)
56
+ for index in range(num_blocks):
57
+ self.blocks.append(
58
+ DenseLayer(input_dim, inter_dim, bias=True)
59
+ )
60
+ input_dim = inter_dim
61
+
62
+ self.linear = nn.Linear(input_dim, out_neurons, bias=True)
63
+
64
+ def forward(self, x):
65
+ # x: [B, dim]
66
+ x = self.nonlinear(x)
67
+ for layer in self.blocks:
68
+ x = layer(x)
69
+ x = self.linear(x)
70
+ return x
modules/campplus/layers.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2
+ # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as cp
7
+ from torch import nn
8
+
9
+
10
+ def get_nonlinear(config_str, channels):
11
+ nonlinear = nn.Sequential()
12
+ for name in config_str.split('-'):
13
+ if name == 'relu':
14
+ nonlinear.add_module('relu', nn.ReLU(inplace=True))
15
+ elif name == 'prelu':
16
+ nonlinear.add_module('prelu', nn.PReLU(channels))
17
+ elif name == 'batchnorm':
18
+ nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
19
+ elif name == 'batchnorm_':
20
+ nonlinear.add_module('batchnorm',
21
+ nn.BatchNorm1d(channels, affine=False))
22
+ else:
23
+ raise ValueError('Unexpected module ({}).'.format(name))
24
+ return nonlinear
25
+
26
+ def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
27
+ mean = x.mean(dim=dim)
28
+ std = x.std(dim=dim, unbiased=unbiased)
29
+ stats = torch.cat([mean, std], dim=-1)
30
+ if keepdim:
31
+ stats = stats.unsqueeze(dim=dim)
32
+ return stats
33
+
34
+
35
+ class StatsPool(nn.Module):
36
+ def forward(self, x):
37
+ return statistics_pooling(x)
38
+
39
+
40
+ class TDNNLayer(nn.Module):
41
+ def __init__(self,
42
+ in_channels,
43
+ out_channels,
44
+ kernel_size,
45
+ stride=1,
46
+ padding=0,
47
+ dilation=1,
48
+ bias=False,
49
+ config_str='batchnorm-relu'):
50
+ super(TDNNLayer, self).__init__()
51
+ if padding < 0:
52
+ assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
53
+ kernel_size)
54
+ padding = (kernel_size - 1) // 2 * dilation
55
+ self.linear = nn.Conv1d(in_channels,
56
+ out_channels,
57
+ kernel_size,
58
+ stride=stride,
59
+ padding=padding,
60
+ dilation=dilation,
61
+ bias=bias)
62
+ self.nonlinear = get_nonlinear(config_str, out_channels)
63
+
64
+ def forward(self, x):
65
+ x = self.linear(x)
66
+ x = self.nonlinear(x)
67
+ return x
68
+
69
+
70
+ class CAMLayer(nn.Module):
71
+ def __init__(self,
72
+ bn_channels,
73
+ out_channels,
74
+ kernel_size,
75
+ stride,
76
+ padding,
77
+ dilation,
78
+ bias,
79
+ reduction=2):
80
+ super(CAMLayer, self).__init__()
81
+ self.linear_local = nn.Conv1d(bn_channels,
82
+ out_channels,
83
+ kernel_size,
84
+ stride=stride,
85
+ padding=padding,
86
+ dilation=dilation,
87
+ bias=bias)
88
+ self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
89
+ self.relu = nn.ReLU(inplace=True)
90
+ self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
91
+ self.sigmoid = nn.Sigmoid()
92
+
93
+ def forward(self, x):
94
+ y = self.linear_local(x)
95
+ context = x.mean(-1, keepdim=True)+self.seg_pooling(x)
96
+ context = self.relu(self.linear1(context))
97
+ m = self.sigmoid(self.linear2(context))
98
+ return y*m
99
+
100
+ def seg_pooling(self, x, seg_len=100, stype='avg'):
101
+ if stype == 'avg':
102
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
103
+ elif stype == 'max':
104
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
105
+ else:
106
+ raise ValueError('Wrong segment pooling type.')
107
+ shape = seg.shape
108
+ seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
109
+ seg = seg[..., :x.shape[-1]]
110
+ return seg
111
+
112
+
113
+ class CAMDenseTDNNLayer(nn.Module):
114
+ def __init__(self,
115
+ in_channels,
116
+ out_channels,
117
+ bn_channels,
118
+ kernel_size,
119
+ stride=1,
120
+ dilation=1,
121
+ bias=False,
122
+ config_str='batchnorm-relu',
123
+ memory_efficient=False):
124
+ super(CAMDenseTDNNLayer, self).__init__()
125
+ assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
126
+ kernel_size)
127
+ padding = (kernel_size - 1) // 2 * dilation
128
+ self.memory_efficient = memory_efficient
129
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
130
+ self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
131
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
132
+ self.cam_layer = CAMLayer(bn_channels,
133
+ out_channels,
134
+ kernel_size,
135
+ stride=stride,
136
+ padding=padding,
137
+ dilation=dilation,
138
+ bias=bias)
139
+
140
+ def bn_function(self, x):
141
+ return self.linear1(self.nonlinear1(x))
142
+
143
+ def forward(self, x):
144
+ if self.training and self.memory_efficient:
145
+ x = cp.checkpoint(self.bn_function, x)
146
+ else:
147
+ x = self.bn_function(x)
148
+ x = self.cam_layer(self.nonlinear2(x))
149
+ return x
150
+
151
+
152
+ class CAMDenseTDNNBlock(nn.ModuleList):
153
+ def __init__(self,
154
+ num_layers,
155
+ in_channels,
156
+ out_channels,
157
+ bn_channels,
158
+ kernel_size,
159
+ stride=1,
160
+ dilation=1,
161
+ bias=False,
162
+ config_str='batchnorm-relu',
163
+ memory_efficient=False):
164
+ super(CAMDenseTDNNBlock, self).__init__()
165
+ for i in range(num_layers):
166
+ layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
167
+ out_channels=out_channels,
168
+ bn_channels=bn_channels,
169
+ kernel_size=kernel_size,
170
+ stride=stride,
171
+ dilation=dilation,
172
+ bias=bias,
173
+ config_str=config_str,
174
+ memory_efficient=memory_efficient)
175
+ self.add_module('tdnnd%d' % (i + 1), layer)
176
+
177
+ def forward(self, x):
178
+ for layer in self:
179
+ x = torch.cat([x, layer(x)], dim=1)
180
+ return x
181
+
182
+
183
+ class TransitLayer(nn.Module):
184
+ def __init__(self,
185
+ in_channels,
186
+ out_channels,
187
+ bias=True,
188
+ config_str='batchnorm-relu'):
189
+ super(TransitLayer, self).__init__()
190
+ self.nonlinear = get_nonlinear(config_str, in_channels)
191
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
192
+
193
+ def forward(self, x):
194
+ x = self.nonlinear(x)
195
+ x = self.linear(x)
196
+ return x
197
+
198
+
199
+ class DenseLayer(nn.Module):
200
+ def __init__(self,
201
+ in_channels,
202
+ out_channels,
203
+ bias=False,
204
+ config_str='batchnorm-relu'):
205
+ super(DenseLayer, self).__init__()
206
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
207
+ self.nonlinear = get_nonlinear(config_str, out_channels)
208
+
209
+ def forward(self, x):
210
+ if len(x.shape) == 2:
211
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
212
+ else:
213
+ x = self.linear(x)
214
+ x = self.nonlinear(x)
215
+ return x
216
+
217
+
218
+ class BasicResBlock(nn.Module):
219
+ expansion = 1
220
+
221
+ def __init__(self, in_planes, planes, stride=1):
222
+ super(BasicResBlock, self).__init__()
223
+ self.conv1 = nn.Conv2d(in_planes,
224
+ planes,
225
+ kernel_size=3,
226
+ stride=(stride, 1),
227
+ padding=1,
228
+ bias=False)
229
+ self.bn1 = nn.BatchNorm2d(planes)
230
+ self.conv2 = nn.Conv2d(planes,
231
+ planes,
232
+ kernel_size=3,
233
+ stride=1,
234
+ padding=1,
235
+ bias=False)
236
+ self.bn2 = nn.BatchNorm2d(planes)
237
+
238
+ self.shortcut = nn.Sequential()
239
+ if stride != 1 or in_planes != self.expansion * planes:
240
+ self.shortcut = nn.Sequential(
241
+ nn.Conv2d(in_planes,
242
+ self.expansion * planes,
243
+ kernel_size=1,
244
+ stride=(stride, 1),
245
+ bias=False),
246
+ nn.BatchNorm2d(self.expansion * planes))
247
+
248
+ def forward(self, x):
249
+ out = F.relu(self.bn1(self.conv1(x)))
250
+ out = self.bn2(self.conv2(out))
251
+ out += self.shortcut(x)
252
+ out = F.relu(out)
253
+ return out
modules/commons.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from munch import Munch
7
+ import json
8
+
9
+
10
+ class AttrDict(dict):
11
+ def __init__(self, *args, **kwargs):
12
+ super(AttrDict, self).__init__(*args, **kwargs)
13
+ self.__dict__ = self
14
+
15
+
16
+ def init_weights(m, mean=0.0, std=0.01):
17
+ classname = m.__class__.__name__
18
+ if classname.find("Conv") != -1:
19
+ m.weight.data.normal_(mean, std)
20
+
21
+
22
+ def get_padding(kernel_size, dilation=1):
23
+ return int((kernel_size * dilation - dilation) / 2)
24
+
25
+
26
+ def convert_pad_shape(pad_shape):
27
+ l = pad_shape[::-1]
28
+ pad_shape = [item for sublist in l for item in sublist]
29
+ return pad_shape
30
+
31
+
32
+ def intersperse(lst, item):
33
+ result = [item] * (len(lst) * 2 + 1)
34
+ result[1::2] = lst
35
+ return result
36
+
37
+
38
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
39
+ """KL(P||Q)"""
40
+ kl = (logs_q - logs_p) - 0.5
41
+ kl += (
42
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
43
+ )
44
+ return kl
45
+
46
+
47
+ def rand_gumbel(shape):
48
+ """Sample from the Gumbel distribution, protect from overflows."""
49
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
50
+ return -torch.log(-torch.log(uniform_samples))
51
+
52
+
53
+ def rand_gumbel_like(x):
54
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
55
+ return g
56
+
57
+
58
+ def slice_segments(x, ids_str, segment_size=4):
59
+ ret = torch.zeros_like(x[:, :, :segment_size])
60
+ for i in range(x.size(0)):
61
+ idx_str = ids_str[i]
62
+ idx_end = idx_str + segment_size
63
+ ret[i] = x[i, :, idx_str:idx_end]
64
+ return ret
65
+
66
+
67
+ def slice_segments_audio(x, ids_str, segment_size=4):
68
+ ret = torch.zeros_like(x[:, :segment_size])
69
+ for i in range(x.size(0)):
70
+ idx_str = ids_str[i]
71
+ idx_end = idx_str + segment_size
72
+ ret[i] = x[i, idx_str:idx_end]
73
+ return ret
74
+
75
+
76
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
77
+ b, d, t = x.size()
78
+ if x_lengths is None:
79
+ x_lengths = t
80
+ ids_str_max = x_lengths - segment_size + 1
81
+ ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
82
+ dtype=torch.long
83
+ )
84
+ ret = slice_segments(x, ids_str, segment_size)
85
+ return ret, ids_str
86
+
87
+
88
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
89
+ position = torch.arange(length, dtype=torch.float)
90
+ num_timescales = channels // 2
91
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
92
+ num_timescales - 1
93
+ )
94
+ inv_timescales = min_timescale * torch.exp(
95
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
96
+ )
97
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
98
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
99
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
100
+ signal = signal.view(1, channels, length)
101
+ return signal
102
+
103
+
104
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
105
+ b, channels, length = x.size()
106
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
107
+ return x + signal.to(dtype=x.dtype, device=x.device)
108
+
109
+
110
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
111
+ b, channels, length = x.size()
112
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
113
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
114
+
115
+
116
+ def subsequent_mask(length):
117
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
118
+ return mask
119
+
120
+
121
+ @torch.jit.script
122
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
123
+ n_channels_int = n_channels[0]
124
+ in_act = input_a + input_b
125
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
126
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
127
+ acts = t_act * s_act
128
+ return acts
129
+
130
+
131
+ def convert_pad_shape(pad_shape):
132
+ l = pad_shape[::-1]
133
+ pad_shape = [item for sublist in l for item in sublist]
134
+ return pad_shape
135
+
136
+
137
+ def shift_1d(x):
138
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
139
+ return x
140
+
141
+
142
+ def sequence_mask(length, max_length=None):
143
+ if max_length is None:
144
+ max_length = length.max()
145
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
146
+ return x.unsqueeze(0) < length.unsqueeze(1)
147
+
148
+
149
+ def avg_with_mask(x, mask):
150
+ assert mask.dtype == torch.float, "Mask should be float"
151
+
152
+ if mask.ndim == 2:
153
+ mask = mask.unsqueeze(1)
154
+
155
+ if mask.shape[1] == 1:
156
+ mask = mask.expand_as(x)
157
+
158
+ return (x * mask).sum() / mask.sum()
159
+
160
+
161
+ def generate_path(duration, mask):
162
+ """
163
+ duration: [b, 1, t_x]
164
+ mask: [b, 1, t_y, t_x]
165
+ """
166
+ device = duration.device
167
+
168
+ b, _, t_y, t_x = mask.shape
169
+ cum_duration = torch.cumsum(duration, -1)
170
+
171
+ cum_duration_flat = cum_duration.view(b * t_x)
172
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
173
+ path = path.view(b, t_x, t_y)
174
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
175
+ path = path.unsqueeze(1).transpose(2, 3) * mask
176
+ return path
177
+
178
+
179
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
180
+ if isinstance(parameters, torch.Tensor):
181
+ parameters = [parameters]
182
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
183
+ norm_type = float(norm_type)
184
+ if clip_value is not None:
185
+ clip_value = float(clip_value)
186
+
187
+ total_norm = 0
188
+ for p in parameters:
189
+ param_norm = p.grad.data.norm(norm_type)
190
+ total_norm += param_norm.item() ** norm_type
191
+ if clip_value is not None:
192
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
193
+ total_norm = total_norm ** (1.0 / norm_type)
194
+ return total_norm
195
+
196
+
197
+ def log_norm(x, mean=-4, std=4, dim=2):
198
+ """
199
+ normalized log mel -> mel -> norm -> log(norm)
200
+ """
201
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
202
+ return x
203
+
204
+
205
+ def load_F0_models(path):
206
+ # load F0 model
207
+ from .JDC.model import JDCNet
208
+
209
+ F0_model = JDCNet(num_class=1, seq_len=192)
210
+ params = torch.load(path, map_location="cpu")["net"]
211
+ F0_model.load_state_dict(params)
212
+ _ = F0_model.train()
213
+
214
+ return F0_model
215
+
216
+
217
+ def modify_w2v_forward(self, output_layer=15):
218
+ """
219
+ change forward method of w2v encoder to get its intermediate layer output
220
+ :param self:
221
+ :param layer:
222
+ :return:
223
+ """
224
+ from transformers.modeling_outputs import BaseModelOutput
225
+
226
+ def forward(
227
+ hidden_states,
228
+ attention_mask=None,
229
+ output_attentions=False,
230
+ output_hidden_states=False,
231
+ return_dict=True,
232
+ ):
233
+ all_hidden_states = () if output_hidden_states else None
234
+ all_self_attentions = () if output_attentions else None
235
+
236
+ conv_attention_mask = attention_mask
237
+ if attention_mask is not None:
238
+ # make sure padded tokens output 0
239
+ hidden_states = hidden_states.masked_fill(
240
+ ~attention_mask.bool().unsqueeze(-1), 0.0
241
+ )
242
+
243
+ # extend attention_mask
244
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(
245
+ dtype=hidden_states.dtype
246
+ )
247
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
248
+ attention_mask = attention_mask.expand(
249
+ attention_mask.shape[0],
250
+ 1,
251
+ attention_mask.shape[-1],
252
+ attention_mask.shape[-1],
253
+ )
254
+
255
+ hidden_states = self.dropout(hidden_states)
256
+
257
+ if self.embed_positions is not None:
258
+ relative_position_embeddings = self.embed_positions(hidden_states)
259
+ else:
260
+ relative_position_embeddings = None
261
+
262
+ deepspeed_zero3_is_enabled = False
263
+
264
+ for i, layer in enumerate(self.layers):
265
+ if output_hidden_states:
266
+ all_hidden_states = all_hidden_states + (hidden_states,)
267
+
268
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
269
+ dropout_probability = torch.rand([])
270
+
271
+ skip_the_layer = (
272
+ True
273
+ if self.training and (dropout_probability < self.config.layerdrop)
274
+ else False
275
+ )
276
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
277
+ # under deepspeed zero3 all gpus must run in sync
278
+ if self.gradient_checkpointing and self.training:
279
+ layer_outputs = self._gradient_checkpointing_func(
280
+ layer.__call__,
281
+ hidden_states,
282
+ attention_mask,
283
+ relative_position_embeddings,
284
+ output_attentions,
285
+ conv_attention_mask,
286
+ )
287
+ else:
288
+ layer_outputs = layer(
289
+ hidden_states,
290
+ attention_mask=attention_mask,
291
+ relative_position_embeddings=relative_position_embeddings,
292
+ output_attentions=output_attentions,
293
+ conv_attention_mask=conv_attention_mask,
294
+ )
295
+ hidden_states = layer_outputs[0]
296
+
297
+ if skip_the_layer:
298
+ layer_outputs = (None, None)
299
+
300
+ if output_attentions:
301
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
302
+
303
+ if i == output_layer - 1:
304
+ break
305
+
306
+ if output_hidden_states:
307
+ all_hidden_states = all_hidden_states + (hidden_states,)
308
+
309
+ if not return_dict:
310
+ return tuple(
311
+ v
312
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
313
+ if v is not None
314
+ )
315
+ return BaseModelOutput(
316
+ last_hidden_state=hidden_states,
317
+ hidden_states=all_hidden_states,
318
+ attentions=all_self_attentions,
319
+ )
320
+
321
+ return forward
322
+
323
+
324
+ MATPLOTLIB_FLAG = False
325
+
326
+
327
+ def plot_spectrogram_to_numpy(spectrogram):
328
+ global MATPLOTLIB_FLAG
329
+ if not MATPLOTLIB_FLAG:
330
+ import matplotlib
331
+ import logging
332
+
333
+ matplotlib.use("Agg")
334
+ MATPLOTLIB_FLAG = True
335
+ mpl_logger = logging.getLogger("matplotlib")
336
+ mpl_logger.setLevel(logging.WARNING)
337
+ import matplotlib.pylab as plt
338
+ import numpy as np
339
+
340
+ fig, ax = plt.subplots(figsize=(10, 2))
341
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
342
+ plt.colorbar(im, ax=ax)
343
+ plt.xlabel("Frames")
344
+ plt.ylabel("Channels")
345
+ plt.tight_layout()
346
+
347
+ fig.canvas.draw()
348
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
349
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
350
+ plt.close()
351
+ return data
352
+
353
+
354
+ def normalize_f0(f0_sequence):
355
+ # Remove unvoiced frames (replace with -1)
356
+ voiced_indices = np.where(f0_sequence > 0)[0]
357
+ f0_voiced = f0_sequence[voiced_indices]
358
+
359
+ # Convert to log scale
360
+ log_f0 = np.log2(f0_voiced)
361
+
362
+ # Calculate mean and standard deviation
363
+ mean_f0 = np.mean(log_f0)
364
+ std_f0 = np.std(log_f0)
365
+
366
+ # Normalize the F0 sequence
367
+ normalized_f0 = (log_f0 - mean_f0) / std_f0
368
+
369
+ # Create the normalized F0 sequence with unvoiced frames
370
+ normalized_sequence = np.zeros_like(f0_sequence)
371
+ normalized_sequence[voiced_indices] = normalized_f0
372
+ normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
373
+
374
+ return normalized_sequence
375
+
376
+
377
+ def build_model(args, stage="DiT"):
378
+ if stage == "DiT":
379
+ from modules.flow_matching import CFM
380
+ from modules.length_regulator import InterpolateRegulator
381
+
382
+ length_regulator = InterpolateRegulator(
383
+ channels=args.length_regulator.channels,
384
+ sampling_ratios=args.length_regulator.sampling_ratios,
385
+ is_discrete=args.length_regulator.is_discrete,
386
+ codebook_size=args.length_regulator.content_codebook_size,
387
+ )
388
+ cfm = CFM(args)
389
+ nets = Munch(
390
+ cfm=cfm,
391
+ length_regulator=length_regulator,
392
+ )
393
+ else:
394
+ raise ValueError(f"Unknown stage: {stage}")
395
+
396
+ return nets
397
+
398
+
399
+ def load_checkpoint(
400
+ model,
401
+ optimizer,
402
+ path,
403
+ load_only_params=True,
404
+ ignore_modules=[],
405
+ is_distributed=False,
406
+ ):
407
+ state = torch.load(path, map_location="cpu")
408
+ params = state["net"]
409
+ for key in model:
410
+ if key in params and key not in ignore_modules:
411
+ if not is_distributed:
412
+ # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
413
+ for k in list(params[key].keys()):
414
+ if k.startswith("module."):
415
+ params[key][k[len("module.") :]] = params[key][k]
416
+ del params[key][k]
417
+ model_state_dict = model[key].state_dict()
418
+ # 过滤出形状匹配的键值对
419
+ filtered_state_dict = {
420
+ k: v
421
+ for k, v in params[key].items()
422
+ if k in model_state_dict and v.shape == model_state_dict[k].shape
423
+ }
424
+ skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
425
+ if skipped_keys:
426
+ print(
427
+ f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
428
+ )
429
+ print("%s loaded" % key)
430
+ model[key].load_state_dict(filtered_state_dict, strict=False)
431
+ _ = [model[key].eval() for key in model]
432
+
433
+ if not load_only_params:
434
+ epoch = state["epoch"] + 1
435
+ iters = state["iters"]
436
+ optimizer.load_state_dict(state["optimizer"])
437
+ optimizer.load_scheduler_state_dict(state["scheduler"])
438
+
439
+ else:
440
+ epoch = 0
441
+ iters = 0
442
+
443
+ return model, optimizer, epoch, iters
444
+
445
+
446
+ def recursive_munch(d):
447
+ if isinstance(d, dict):
448
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
449
+ elif isinstance(d, list):
450
+ return [recursive_munch(v) for v in d]
451
+ else:
452
+ return d
modules/cosyvoice_tokenizer/__pycache__/frontend.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
modules/cosyvoice_tokenizer/frontend.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from functools import partial
15
+ import onnxruntime
16
+ import torch
17
+ import numpy as np
18
+ import whisper
19
+ import torchaudio.compliance.kaldi as kaldi
20
+
21
+ class CosyVoiceFrontEnd:
22
+
23
+ def __init__(self, speech_tokenizer_model: str, device: str = 'cuda', device_id: int = 0):
24
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
+ option = onnxruntime.SessionOptions()
26
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
27
+ option.intra_op_num_threads = 1
28
+ self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if device == "cuda" else "CPUExecutionProvider"])
29
+ if device == 'cuda':
30
+ self.speech_tokenizer_session.set_providers(['CUDAExecutionProvider'], [ {'device_id': device_id}])
31
+
32
+ def extract_speech_token(self, speech):
33
+ feat = whisper.log_mel_spectrogram(speech, n_mels=128)
34
+ speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
35
+ self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
36
+ speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
37
+ speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
38
+ return speech_token, speech_token_len
39
+
40
+ def _extract_spk_embedding(self, speech):
41
+ feat = kaldi.fbank(speech,
42
+ num_mel_bins=80,
43
+ dither=0,
44
+ sample_frequency=16000)
45
+ feat = feat - feat.mean(dim=0, keepdim=True)
46
+ embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
47
+ embedding = torch.tensor([embedding]).to(self.device)
48
+ return embedding
49
+
50
+ def _extract_speech_feat(self, speech):
51
+ speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
52
+ speech_feat = speech_feat.unsqueeze(dim=0)
53
+ speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
54
+ return speech_feat, speech_feat_len
modules/diffusion_transformer.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import math
4
+
5
+ from modules.gpt_fast.model import ModelArgs, Transformer
6
+ from modules.wavenet import WN
7
+ from modules.commons import sequence_mask
8
+
9
+ from torch.nn.utils import weight_norm
10
+
11
+ def modulate(x, shift, scale):
12
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
13
+
14
+
15
+ #################################################################################
16
+ # Embedding Layers for Timesteps and Class Labels #
17
+ #################################################################################
18
+
19
+ class TimestepEmbedder(nn.Module):
20
+ """
21
+ Embeds scalar timesteps into vector representations.
22
+ """
23
+ def __init__(self, hidden_size, frequency_embedding_size=256):
24
+ super().__init__()
25
+ self.mlp = nn.Sequential(
26
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
27
+ nn.SiLU(),
28
+ nn.Linear(hidden_size, hidden_size, bias=True),
29
+ )
30
+ self.frequency_embedding_size = frequency_embedding_size
31
+
32
+ @staticmethod
33
+ def timestep_embedding(t, dim, max_period=10000, scale=1000):
34
+ """
35
+ Create sinusoidal timestep embeddings.
36
+ :param t: a 1-D Tensor of N indices, one per batch element.
37
+ These may be fractional.
38
+ :param dim: the dimension of the output.
39
+ :param max_period: controls the minimum frequency of the embeddings.
40
+ :return: an (N, D) Tensor of positional embeddings.
41
+ """
42
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
43
+ half = dim // 2
44
+ freqs = torch.exp(
45
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
46
+ ).to(device=t.device)
47
+ args = scale * t[:, None].float() * freqs[None]
48
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
49
+ if dim % 2:
50
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
51
+ return embedding
52
+
53
+ def forward(self, t):
54
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
55
+ t_emb = self.mlp(t_freq)
56
+ return t_emb
57
+
58
+
59
+ class StyleEmbedder(nn.Module):
60
+ """
61
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
62
+ """
63
+ def __init__(self, input_size, hidden_size, dropout_prob):
64
+ super().__init__()
65
+ use_cfg_embedding = dropout_prob > 0
66
+ self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
67
+ self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
68
+ self.input_size = input_size
69
+ self.dropout_prob = dropout_prob
70
+
71
+ def forward(self, labels, train, force_drop_ids=None):
72
+ use_dropout = self.dropout_prob > 0
73
+ if (train and use_dropout) or (force_drop_ids is not None):
74
+ labels = self.token_drop(labels, force_drop_ids)
75
+ else:
76
+ labels = self.style_in(labels)
77
+ embeddings = labels
78
+ return embeddings
79
+
80
+ class FinalLayer(nn.Module):
81
+ """
82
+ The final layer of DiT.
83
+ """
84
+ def __init__(self, hidden_size, patch_size, out_channels):
85
+ super().__init__()
86
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
87
+ self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
88
+ self.adaLN_modulation = nn.Sequential(
89
+ nn.SiLU(),
90
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
91
+ )
92
+
93
+ def forward(self, x, c):
94
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
95
+ x = modulate(self.norm_final(x), shift, scale)
96
+ x = self.linear(x)
97
+ return x
98
+
99
+ class DiT(torch.nn.Module):
100
+ def __init__(
101
+ self,
102
+ args
103
+ ):
104
+ super(DiT, self).__init__()
105
+ self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
106
+ self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
107
+ self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
108
+ model_args = ModelArgs(
109
+ block_size=args.DiT.block_size,
110
+ n_layer=args.DiT.depth,
111
+ n_head=args.DiT.num_heads,
112
+ dim=args.DiT.hidden_dim,
113
+ head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
114
+ vocab_size=1024,
115
+ uvit_skip_connection=self.uvit_skip_connection,
116
+ )
117
+ self.transformer = Transformer(model_args)
118
+ self.in_channels = args.DiT.in_channels
119
+ self.out_channels = args.DiT.in_channels
120
+ self.num_heads = args.DiT.num_heads
121
+
122
+ self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
123
+
124
+ self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
125
+ self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
126
+ self.content_dim = args.DiT.content_dim # for continuous content
127
+ self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
128
+ self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
129
+
130
+ self.is_causal = args.DiT.is_causal
131
+
132
+ self.n_f0_bins = args.DiT.n_f0_bins
133
+ self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
134
+ self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
135
+ self.f0_condition = args.DiT.f0_condition
136
+
137
+ self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
138
+ self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
139
+ # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
140
+ # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
141
+
142
+ input_pos = torch.arange(args.DiT.block_size)
143
+ self.register_buffer("input_pos", input_pos)
144
+
145
+ self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
146
+ self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
147
+ self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
148
+ if self.final_layer_type == 'wavenet':
149
+ self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
150
+ kernel_size=args.wavenet.kernel_size,
151
+ dilation_rate=args.wavenet.dilation_rate,
152
+ n_layers=args.wavenet.num_layers,
153
+ gin_channels=args.wavenet.hidden_dim,
154
+ p_dropout=args.wavenet.p_dropout,
155
+ causal=False)
156
+ self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
157
+ else:
158
+ self.final_mlp = nn.Sequential(
159
+ nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
160
+ nn.SiLU(),
161
+ nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
162
+ )
163
+ self.final_conv = nn.Conv1d(args.DiT.in_channels, args.DiT.in_channels, kernel_size=3, padding=1)
164
+ self.transformer_style_condition = args.DiT.style_condition
165
+ self.wavenet_style_condition = args.wavenet.style_condition
166
+ assert args.DiT.style_condition == args.wavenet.style_condition
167
+
168
+ self.class_dropout_prob = args.DiT.class_dropout_prob
169
+ self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
170
+ self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) # residual connection from tranformer output to final output
171
+ self.long_skip_connection = args.DiT.long_skip_connection
172
+ self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
173
+
174
+ self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
175
+ args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
176
+ args.DiT.hidden_dim)
177
+ if self.style_as_token:
178
+ self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
179
+
180
+ def setup_caches(self, max_batch_size, max_seq_length):
181
+ self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
182
+ def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
183
+ class_dropout = False
184
+ if self.training and torch.rand(1) < self.class_dropout_prob:
185
+ class_dropout = True
186
+ if not self.training and mask_content:
187
+ class_dropout = True
188
+ # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
189
+ cond_in_module = self.cond_projection
190
+
191
+ B, _, T = x.size()
192
+
193
+
194
+ t1 = self.t_embedder(t) # (N, D)
195
+
196
+ cond = cond_in_module(cond)
197
+ if self.f0_condition and f0 is not None:
198
+ quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
199
+ cond = cond + self.f0_embedder(quantized_f0)
200
+
201
+ x = x.transpose(1, 2)
202
+ prompt_x = prompt_x.transpose(1, 2)
203
+
204
+ x_in = torch.cat([x, prompt_x, cond], dim=-1)
205
+ if self.transformer_style_condition and not self.style_as_token:
206
+ x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
207
+ if class_dropout:
208
+ x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
209
+ x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
210
+
211
+ if self.style_as_token:
212
+ style = self.style_in(style)
213
+ style = torch.zeros_like(style) if class_dropout else style
214
+ x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
215
+ if self.time_as_token:
216
+ x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
217
+ x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
218
+ input_pos = self.input_pos[:x_in.size(1)] # (T,)
219
+ x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
220
+ x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
221
+ x_res = x_res[:, 1:] if self.time_as_token else x_res
222
+ x_res = x_res[:, 1:] if self.style_as_token else x_res
223
+ if self.long_skip_connection:
224
+ x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
225
+ if self.final_layer_type == 'wavenet':
226
+ x = self.conv1(x_res)
227
+ x = x.transpose(1, 2)
228
+ t2 = self.t_embedder2(t)
229
+ x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
230
+ x_res) # long residual connection
231
+ x = self.final_layer(x, t1).transpose(1, 2)
232
+ x = self.conv2(x)
233
+ else:
234
+ x = self.final_mlp(x_res)
235
+ x = x.transpose(1, 2)
236
+ x = self.final_conv(x)
237
+ return x
modules/encodec.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Convolutional layers wrappers and utilities."""
8
+
9
+ import math
10
+ import typing as tp
11
+ import warnings
12
+
13
+ import torch
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import spectral_norm, weight_norm
17
+
18
+ import typing as tp
19
+
20
+ import einops
21
+
22
+
23
+ class ConvLayerNorm(nn.LayerNorm):
24
+ """
25
+ Convolution-friendly LayerNorm that moves channels to last dimensions
26
+ before running the normalization and moves them back to original position right after.
27
+ """
28
+ def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
29
+ super().__init__(normalized_shape, **kwargs)
30
+
31
+ def forward(self, x):
32
+ x = einops.rearrange(x, 'b ... t -> b t ...')
33
+ x = super().forward(x)
34
+ x = einops.rearrange(x, 'b t ... -> b ... t')
35
+ return
36
+
37
+
38
+ CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
39
+ 'time_layer_norm', 'layer_norm', 'time_group_norm'])
40
+
41
+
42
+ def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
43
+ assert norm in CONV_NORMALIZATIONS
44
+ if norm == 'weight_norm':
45
+ return weight_norm(module)
46
+ elif norm == 'spectral_norm':
47
+ return spectral_norm(module)
48
+ else:
49
+ # We already check was in CONV_NORMALIZATION, so any other choice
50
+ # doesn't need reparametrization.
51
+ return module
52
+
53
+
54
+ def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
55
+ """Return the proper normalization module. If causal is True, this will ensure the returned
56
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
57
+ """
58
+ assert norm in CONV_NORMALIZATIONS
59
+ if norm == 'layer_norm':
60
+ assert isinstance(module, nn.modules.conv._ConvNd)
61
+ return ConvLayerNorm(module.out_channels, **norm_kwargs)
62
+ elif norm == 'time_group_norm':
63
+ if causal:
64
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
65
+ assert isinstance(module, nn.modules.conv._ConvNd)
66
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
67
+ else:
68
+ return nn.Identity()
69
+
70
+
71
+ def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
72
+ padding_total: int = 0) -> int:
73
+ """See `pad_for_conv1d`.
74
+ """
75
+ length = x.shape[-1]
76
+ n_frames = (length - kernel_size + padding_total) / stride + 1
77
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
78
+ return ideal_length - length
79
+
80
+
81
+ def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
82
+ """Pad for a convolution to make sure that the last window is full.
83
+ Extra padding is added at the end. This is required to ensure that we can rebuild
84
+ an output of the same length, as otherwise, even with padding, some time steps
85
+ might get removed.
86
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
87
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
88
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
89
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
90
+ 1 2 3 4 # once you removed padding, we are missing one time step !
91
+ """
92
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
93
+ return F.pad(x, (0, extra_padding))
94
+
95
+
96
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
97
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
98
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
99
+ """
100
+ length = x.shape[-1]
101
+ padding_left, padding_right = paddings
102
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103
+ if mode == 'reflect':
104
+ max_pad = max(padding_left, padding_right)
105
+ extra_pad = 0
106
+ if length <= max_pad:
107
+ extra_pad = max_pad - length + 1
108
+ x = F.pad(x, (0, extra_pad))
109
+ padded = F.pad(x, paddings, mode, value)
110
+ end = padded.shape[-1] - extra_pad
111
+ return padded[..., :end]
112
+ else:
113
+ return F.pad(x, paddings, mode, value)
114
+
115
+
116
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
117
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
118
+ padding_left, padding_right = paddings
119
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
120
+ assert (padding_left + padding_right) <= x.shape[-1]
121
+ end = x.shape[-1] - padding_right
122
+ return x[..., padding_left: end]
123
+
124
+
125
+ class NormConv1d(nn.Module):
126
+ """Wrapper around Conv1d and normalization applied to this conv
127
+ to provide a uniform interface across normalization approaches.
128
+ """
129
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
130
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131
+ super().__init__()
132
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
133
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
134
+ self.norm_type = norm
135
+
136
+ def forward(self, x):
137
+ x = self.conv(x)
138
+ x = self.norm(x)
139
+ return x
140
+
141
+
142
+ class NormConv2d(nn.Module):
143
+ """Wrapper around Conv2d and normalization applied to this conv
144
+ to provide a uniform interface across normalization approaches.
145
+ """
146
+ def __init__(self, *args, norm: str = 'none',
147
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148
+ super().__init__()
149
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
150
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
151
+ self.norm_type = norm
152
+
153
+ def forward(self, x):
154
+ x = self.conv(x)
155
+ x = self.norm(x)
156
+ return x
157
+
158
+
159
+ class NormConvTranspose1d(nn.Module):
160
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
161
+ to provide a uniform interface across normalization approaches.
162
+ """
163
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
164
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165
+ super().__init__()
166
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
167
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
168
+ self.norm_type = norm
169
+
170
+ def forward(self, x):
171
+ x = self.convtr(x)
172
+ x = self.norm(x)
173
+ return x
174
+
175
+
176
+ class NormConvTranspose2d(nn.Module):
177
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
178
+ to provide a uniform interface across normalization approaches.
179
+ """
180
+ def __init__(self, *args, norm: str = 'none',
181
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182
+ super().__init__()
183
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
184
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
185
+
186
+ def forward(self, x):
187
+ x = self.convtr(x)
188
+ x = self.norm(x)
189
+ return x
190
+
191
+
192
+ class SConv1d(nn.Module):
193
+ """Conv1d with some builtin handling of asymmetric or causal padding
194
+ and normalization.
195
+ """
196
+ def __init__(self, in_channels: int, out_channels: int,
197
+ kernel_size: int, stride: int = 1, dilation: int = 1,
198
+ groups: int = 1, bias: bool = True, causal: bool = False,
199
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
200
+ pad_mode: str = 'reflect', **kwargs):
201
+ super().__init__()
202
+ # warn user on unusual setup between dilation and stride
203
+ if stride > 1 and dilation > 1:
204
+ warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
205
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
206
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
207
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
208
+ norm=norm, norm_kwargs=norm_kwargs)
209
+ self.causal = causal
210
+ self.pad_mode = pad_mode
211
+
212
+ def forward(self, x):
213
+ B, C, T = x.shape
214
+ kernel_size = self.conv.conv.kernel_size[0]
215
+ stride = self.conv.conv.stride[0]
216
+ dilation = self.conv.conv.dilation[0]
217
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
218
+ padding_total = kernel_size - stride
219
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
220
+ if self.causal:
221
+ # Left padding for causal
222
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
223
+ else:
224
+ # Asymmetric padding required for odd strides
225
+ padding_right = padding_total // 2
226
+ padding_left = padding_total - padding_right
227
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
228
+ return self.conv(x)
229
+
230
+
231
+ class SConvTranspose1d(nn.Module):
232
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
233
+ and normalization.
234
+ """
235
+ def __init__(self, in_channels: int, out_channels: int,
236
+ kernel_size: int, stride: int = 1, causal: bool = False,
237
+ norm: str = 'none', trim_right_ratio: float = 1.,
238
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
239
+ super().__init__()
240
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
241
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
242
+ self.causal = causal
243
+ self.trim_right_ratio = trim_right_ratio
244
+ assert self.causal or self.trim_right_ratio == 1., \
245
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
246
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
247
+
248
+ def forward(self, x):
249
+ kernel_size = self.convtr.convtr.kernel_size[0]
250
+ stride = self.convtr.convtr.stride[0]
251
+ padding_total = kernel_size - stride
252
+
253
+ y = self.convtr(x)
254
+
255
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
256
+ # removed at the very end, when keeping only the right length for the output,
257
+ # as removing it here would require also passing the length at the matching layer
258
+ # in the encoder.
259
+ if self.causal:
260
+ # Trim the padding on the right according to the specified ratio
261
+ # if trim_right_ratio = 1.0, trim everything from right
262
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
263
+ padding_left = padding_total - padding_right
264
+ y = unpad1d(y, (padding_left, padding_right))
265
+ else:
266
+ # Asymmetric padding required for odd strides
267
+ padding_right = padding_total // 2
268
+ padding_left = padding_total - padding_right
269
+ y = unpad1d(y, (padding_left, padding_right))
270
+ return y
271
+
272
+ class SLSTM(nn.Module):
273
+ """
274
+ LSTM without worrying about the hidden state, nor the layout of the data.
275
+ Expects input as convolutional layout.
276
+ """
277
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
278
+ super().__init__()
279
+ self.skip = skip
280
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
281
+ self.hidden = None
282
+
283
+ def forward(self, x):
284
+ x = x.permute(2, 0, 1)
285
+ if self.training:
286
+ y, _ = self.lstm(x)
287
+ else:
288
+ y, self.hidden = self.lstm(x, self.hidden)
289
+ if self.skip:
290
+ y = y + x
291
+ y = y.permute(1, 2, 0)
292
+ return y
modules/flow_matching.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from modules.diffusion_transformer import DiT
7
+ from modules.commons import sequence_mask
8
+
9
+ class BASECFM(torch.nn.Module, ABC):
10
+ def __init__(
11
+ self,
12
+ args,
13
+ ):
14
+ super().__init__()
15
+ self.sigma_min = 1e-6
16
+
17
+ self.estimator = None
18
+
19
+ self.in_channels = args.DiT.in_channels
20
+
21
+ self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
22
+
23
+ if hasattr(args.DiT, 'zero_prompt_speech_token'):
24
+ self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
25
+ else:
26
+ self.zero_prompt_speech_token = False
27
+
28
+ @torch.inference_mode()
29
+ def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
30
+ """Forward diffusion
31
+
32
+ Args:
33
+ mu (torch.Tensor): output of encoder
34
+ shape: (batch_size, n_feats, mel_timesteps)
35
+ mask (torch.Tensor): output_mask
36
+ shape: (batch_size, 1, mel_timesteps)
37
+ n_timesteps (int): number of diffusion steps
38
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
39
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
40
+ shape: (batch_size, spk_emb_dim)
41
+ cond: Not used but kept for future purposes
42
+
43
+ Returns:
44
+ sample: generated mel-spectrogram
45
+ shape: (batch_size, n_feats, mel_timesteps)
46
+ """
47
+ B, T = mu.size(0), mu.size(1)
48
+ z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
49
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50
+ return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
51
+
52
+ def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
53
+ """
54
+ Fixed euler solver for ODEs.
55
+ Args:
56
+ x (torch.Tensor): random noise
57
+ t_span (torch.Tensor): n_timesteps interpolated
58
+ shape: (n_timesteps + 1,)
59
+ mu (torch.Tensor): output of encoder
60
+ shape: (batch_size, n_feats, mel_timesteps)
61
+ mask (torch.Tensor): output_mask
62
+ shape: (batch_size, 1, mel_timesteps)
63
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
64
+ shape: (batch_size, spk_emb_dim)
65
+ cond: Not used but kept for future purposes
66
+ """
67
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
68
+
69
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
70
+ # Or in future might add like a return_all_steps flag
71
+ sol = []
72
+ # apply prompt
73
+ prompt_len = prompt.size(-1)
74
+ prompt_x = torch.zeros_like(x)
75
+ prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
76
+ x[..., :prompt_len] = 0
77
+ if self.zero_prompt_speech_token:
78
+ mu[..., :prompt_len] = 0
79
+ for step in range(1, len(t_span)):
80
+ dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
81
+ # Classifier-Free Guidance inference introduced in VoiceBox
82
+ if inference_cfg_rate > 0:
83
+ cfg_dphi_dt = self.estimator(
84
+ x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0),
85
+ torch.zeros_like(style),
86
+ torch.zeros_like(mu), None
87
+ )
88
+ dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
89
+ inference_cfg_rate * cfg_dphi_dt)
90
+ x = x + dt * dphi_dt
91
+ t = t + dt
92
+ sol.append(x)
93
+ if step < len(t_span) - 1:
94
+ dt = t_span[step + 1] - t
95
+ x[:, :, :prompt_len] = 0
96
+
97
+ return sol[-1]
98
+
99
+ def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None):
100
+ """Computes diffusion loss
101
+
102
+ Args:
103
+ x1 (torch.Tensor): Target
104
+ shape: (batch_size, n_feats, mel_timesteps)
105
+ mask (torch.Tensor): target mask
106
+ shape: (batch_size, 1, mel_timesteps)
107
+ mu (torch.Tensor): output of encoder
108
+ shape: (batch_size, n_feats, mel_timesteps)
109
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
110
+ shape: (batch_size, spk_emb_dim)
111
+
112
+ Returns:
113
+ loss: conditional flow matching loss
114
+ y: conditional flow
115
+ shape: (batch_size, n_feats, mel_timesteps)
116
+ """
117
+ b, _, t = x1.shape
118
+
119
+ # random timestep
120
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
121
+ # sample noise p(x_0)
122
+ z = torch.randn_like(x1)
123
+
124
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
125
+ u = x1 - (1 - self.sigma_min) * z
126
+
127
+ prompt = torch.zeros_like(x1)
128
+ for bib in range(b):
129
+ prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
130
+ # range covered by prompt are set to 0
131
+ y[bib, :, :prompt_lens[bib]] = 0
132
+ if self.zero_prompt_speech_token:
133
+ mu[bib, :, :prompt_lens[bib]] = 0
134
+
135
+ estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0)
136
+ loss = 0
137
+ for bib in range(b):
138
+ loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
139
+ loss /= b
140
+
141
+ return loss, y
142
+
143
+
144
+
145
+ class CFM(BASECFM):
146
+ def __init__(self, args):
147
+ super().__init__(
148
+ args
149
+ )
150
+ if args.dit_type == "DiT":
151
+ self.estimator = DiT(args)
152
+ else:
153
+ raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
modules/gpt_fast/__pycache__/model.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
modules/gpt_fast/generate.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import itertools
7
+ import sys
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Optional, Tuple
11
+
12
+ import torch
13
+ import torch._dynamo.config
14
+ import torch._inductor.config
15
+
16
+ def device_sync(device):
17
+ if "cuda" in device:
18
+ torch.cuda.synchronize(device)
19
+ elif ("cpu" in device) or ("mps" in device):
20
+ pass
21
+ else:
22
+ print(f"device={device} is not yet suppported")
23
+
24
+
25
+ torch._inductor.config.coordinate_descent_tuning = True
26
+ torch._inductor.config.triton.unique_kernel_names = True
27
+ torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
28
+
29
+ default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+ # support running without installing as a package
32
+ wd = Path(__file__).parent.parent.resolve()
33
+ sys.path.append(str(wd))
34
+
35
+ from model import Transformer
36
+ from tokenizer import get_tokenizer
37
+
38
+ def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
39
+ q = torch.empty_like(probs_sort).exponential_(1)
40
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
41
+
42
+ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
43
+ logits = logits / max(temperature, 1e-5)
44
+
45
+ if top_k is not None:
46
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
47
+ pivot = v.select(-1, -1).unsqueeze(-1)
48
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
49
+ probs = torch.nn.functional.softmax(logits, dim=-1)
50
+ return probs
51
+
52
+ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
53
+ probs = logits_to_probs(logits[0, -1], temperature, top_k)
54
+ idx_next = multinomial_sample_one_no_sync(probs)
55
+ return idx_next, probs
56
+
57
+ def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
58
+ # input_pos: [B, S]
59
+ logits = model(x, input_pos)
60
+ return sample(logits, **sampling_kwargs)[0]
61
+
62
+ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
63
+ # input_pos: [B, 1]
64
+ assert input_pos.shape[-1] == 1
65
+ logits = model(x, input_pos)
66
+ return sample(logits, **sampling_kwargs)
67
+
68
+ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
69
+ new_tokens, new_probs = [], []
70
+ for i in range(num_new_tokens):
71
+ with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
72
+ next_token, next_prob = decode_one_token(
73
+ model, cur_token, input_pos, **sampling_kwargs
74
+ )
75
+ input_pos += 1
76
+ new_tokens.append(next_token.clone())
77
+ callback(new_tokens[-1])
78
+ new_probs.append(next_prob.clone())
79
+ cur_token = next_token.view(1, -1)
80
+
81
+ return new_tokens, new_probs
82
+
83
+
84
+ def model_forward(model, x, input_pos):
85
+ return model(x, input_pos)
86
+
87
+ def speculative_decode(
88
+ model: Transformer,
89
+ draft_model: Transformer,
90
+ cur_token: torch.Tensor,
91
+ input_pos: int,
92
+ speculate_k: int,
93
+ **sampling_kwargs
94
+ ) -> torch.Tensor:
95
+ # draft model inference sequentially
96
+ device = cur_token.device
97
+ orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
98
+ draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
99
+
100
+ draft_tokens = torch.cat(draft_tokens)
101
+ # parallel inference on target model using draft tokens
102
+ target_logits = model_forward(
103
+ model,
104
+ torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
105
+ torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
106
+ )
107
+ target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
108
+ draft_probs = torch.stack(draft_probs)
109
+ # q: target prob, p: draft prob
110
+ # q >= p: always accept draft token
111
+ # q < p: q/p prob to accept draft token
112
+ p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
113
+ q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
114
+ accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k]/ p)
115
+ rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
116
+
117
+ if rejected_locations.shape[0] == 0: # All draft tokens have been accepted
118
+ accept_length = speculate_k + 1
119
+ last_token = multinomial_sample_one_no_sync(target_probs[-1])
120
+ # fill last token into draft model
121
+ model_forward(
122
+ draft_model,
123
+ draft_tokens[-1].view(1, -1),
124
+ orig_input_pos + speculate_k,
125
+ )
126
+ return torch.cat([draft_tokens, last_token])
127
+ else:
128
+ accept_length = rejected_locations[0].item()
129
+ p = draft_probs[accept_length]
130
+ q = target_probs[accept_length]
131
+ new = q - p
132
+ new = torch.where(new > 0, new, 0.0)
133
+ new = new / new.sum()
134
+ next_token = multinomial_sample_one_no_sync(new)
135
+ return torch.cat([draft_tokens[:accept_length], next_token])
136
+
137
+ @torch.no_grad()
138
+ def generate(
139
+ model: Transformer,
140
+ prompt: torch.Tensor,
141
+ max_new_tokens: int,
142
+ *,
143
+ interactive: bool,
144
+ draft_model: Transformer,
145
+ speculate_k: Optional[int] = 8,
146
+ callback = lambda x: x,
147
+ **sampling_kwargs
148
+ ) -> torch.Tensor:
149
+ """
150
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
151
+ """
152
+
153
+ is_speculative = draft_model is not None
154
+ # create an empty tensor of the expected final shape and fill in the current tokens
155
+ T = prompt.size(0)
156
+ T_new = T + max_new_tokens
157
+ if interactive:
158
+ max_seq_length = 350
159
+ else:
160
+ max_seq_length = min(T_new, model.config.block_size)
161
+
162
+ device, dtype = prompt.device, prompt.dtype
163
+ max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
164
+ with torch.device(device):
165
+ model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
166
+ if is_speculative and draft_model is not model:
167
+ draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
168
+
169
+ # create an empty tensor of the expected final shape and fill in the current tokens
170
+ empty = torch.empty(T_new, dtype=dtype, device=device)
171
+ empty[:T] = prompt
172
+ seq = empty
173
+ input_pos = torch.arange(0, T, device=device)
174
+
175
+ next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
176
+ if is_speculative:
177
+ prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
178
+ seq[T] = next_token
179
+
180
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
181
+ accept_counts = [0] * (speculate_k + 1)
182
+
183
+ if is_speculative:
184
+ input_pos = input_pos.item() # for speculative decoding easier to keep on host
185
+ while input_pos < T_new - 1:
186
+ cur_token = next_token.view(())
187
+
188
+ next_tokens = speculative_decode(
189
+ model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
190
+ )
191
+
192
+ accept_counts[len(next_tokens) - 1] += 1
193
+ num_added = min(T_new - input_pos - 1, len(next_tokens))
194
+ seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[: num_added]
195
+ for i in next_tokens[: num_added,]:
196
+ callback(i)
197
+ input_pos = input_pos + num_added
198
+ next_token = next_tokens[-1]
199
+ else:
200
+ generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
201
+ seq[T + 1:] = torch.cat(generated_tokens)
202
+
203
+ generate_stats = {
204
+ 'accept_counts': accept_counts
205
+ }
206
+ return seq, generate_stats
207
+
208
+ def encode_tokens(tokenizer, string, bos=True, device=default_device):
209
+ tokens = tokenizer.encode(string)
210
+ if bos:
211
+ tokens = [tokenizer.bos_id()] + tokens
212
+ return torch.tensor(tokens, dtype=torch.int, device=device)
213
+
214
+ def _load_model(checkpoint_path, device, precision, use_tp):
215
+ use_cuda = 'cuda' in device
216
+ with torch.device('meta'):
217
+ model = Transformer.from_name(checkpoint_path.parent.name)
218
+
219
+ if "int8" in str(checkpoint_path):
220
+ print("Using int8 weight-only quantization!")
221
+ from quantize import WeightOnlyInt8QuantHandler
222
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
223
+ model = simple_quantizer.convert_for_runtime()
224
+
225
+ if "int4" in str(checkpoint_path):
226
+ print("Using int4 weight-only quantization!")
227
+ path_comps = checkpoint_path.name.split(".")
228
+ groupsize = int(path_comps[-2][1:])
229
+ from quantize import WeightOnlyInt4QuantHandler
230
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
231
+ model = simple_quantizer.convert_for_runtime()
232
+
233
+ checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
234
+ if "model" in checkpoint and "stories" in str(checkpoint_path):
235
+ checkpoint = checkpoint["model"]
236
+ model.load_state_dict(checkpoint, assign=True)
237
+
238
+ if use_tp:
239
+ from tp import apply_tp
240
+ print("Applying tensor parallel to model ...")
241
+ apply_tp(model)
242
+
243
+ model = model.to(device=device, dtype=precision)
244
+ return model.eval()
245
+
246
+ def _get_model_size(model):
247
+ model_size = 0
248
+ for name, child in model.named_children():
249
+ if not isinstance(child, torch.nn.Embedding):
250
+ model_size += sum(
251
+ [
252
+ p.numel() * p.dtype.itemsize
253
+ for p in itertools.chain(child.parameters(), child.buffers())
254
+ ]
255
+ )
256
+ return model_size
257
+
258
+ B_INST, E_INST = "[INST]", "[/INST]"
259
+
260
+ def main(
261
+ prompt: str = "Hello, my name is",
262
+ interactive: bool = False,
263
+ num_samples: int = 5,
264
+ max_new_tokens: int = 100,
265
+ top_k: int = 200,
266
+ temperature: float = 0.8,
267
+ checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
268
+ compile: bool = True,
269
+ compile_prefill: bool = False,
270
+ profile: Optional[Path] = None,
271
+ draft_checkpoint_path: Optional[Path] = None,
272
+ speculate_k: int = 5,
273
+ device=default_device,
274
+ ) -> None:
275
+ """Generates text samples based on a pre-trained Transformer model and tokenizer.
276
+ """
277
+ assert checkpoint_path.is_file(), checkpoint_path
278
+
279
+ tokenizer_path = checkpoint_path.parent / "tokenizer.model"
280
+ assert tokenizer_path.is_file(), str(tokenizer_path)
281
+
282
+ global print
283
+ from tp import maybe_init_dist
284
+ rank = maybe_init_dist()
285
+ use_tp = rank is not None
286
+ if use_tp:
287
+ if rank != 0:
288
+ # only print on rank 0
289
+ print = lambda *args, **kwargs: None
290
+
291
+ print(f"Using device={device}")
292
+ precision = torch.bfloat16
293
+ is_speculative = draft_checkpoint_path is not None
294
+ is_chat = "chat" in str(checkpoint_path)
295
+
296
+ print("Loading model ...")
297
+ t0 = time.time()
298
+ model = _load_model(checkpoint_path, device, precision, use_tp)
299
+
300
+ if is_speculative:
301
+ draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
302
+ else:
303
+ draft_model = None
304
+
305
+ device_sync(device=device) # MKG
306
+ print(f"Time to load model: {time.time() - t0:.02f} seconds")
307
+
308
+ tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
309
+
310
+ encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
311
+ prompt_length = encoded.size(0)
312
+
313
+ torch.manual_seed(1234)
314
+ model_size = _get_model_size(model)
315
+ if compile:
316
+ if is_speculative and use_tp: # and ("cuda" in device):
317
+ torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
318
+
319
+ if is_speculative:
320
+ global model_forward, logits_to_prob
321
+ model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
322
+
323
+ global decode_one_token, prefill
324
+ decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
325
+
326
+ # Uncomment to squeeze more perf out of prefill
327
+ if compile_prefill:
328
+ prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
329
+
330
+
331
+ aggregate_metrics = {
332
+ 'tokens_per_sec': [],
333
+ 'accept_counts': [],
334
+ }
335
+ start = -1 if compile else 0
336
+
337
+ for i in range(start, num_samples):
338
+ device_sync(device=device) # MKG
339
+ if i >= 0 and interactive:
340
+ prompt = input("What is your prompt? ")
341
+ if is_chat:
342
+ prompt = f"{B_INST} {prompt.strip()} {E_INST}"
343
+ encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
344
+
345
+ if interactive and i >= 0:
346
+ buffer = []
347
+ period_id = tokenizer.encode('.')[0]
348
+ done_generating = False
349
+ def callback(x):
350
+ nonlocal done_generating
351
+ if done_generating:
352
+ return
353
+ buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
354
+ if x.item() == tokenizer.eos_id():
355
+ done_generating = True
356
+ if len(buffer) == 4 or done_generating:
357
+ print(''.join(buffer), end='', flush=True)
358
+ buffer.clear()
359
+ # print(, end='', flush=True)
360
+ else:
361
+ callback = lambda x : x
362
+ t0 = time.perf_counter()
363
+ import contextlib
364
+ if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
365
+ prof = contextlib.nullcontext()
366
+ else:
367
+ torch.profiler._utils._init_for_cuda_graphs()
368
+ prof = torch.profiler.profile()
369
+ with prof:
370
+ y, metrics = generate(
371
+ model,
372
+ encoded,
373
+ max_new_tokens,
374
+ draft_model=draft_model,
375
+ speculate_k=speculate_k,
376
+ interactive=interactive,
377
+ callback=callback,
378
+ temperature=temperature,
379
+ top_k=top_k,
380
+ )
381
+ aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
382
+ if i == -1:
383
+ print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
384
+ continue
385
+ if hasattr(prof, "export_chrome_trace"):
386
+ if use_tp:
387
+ prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
388
+ else:
389
+ prof.export_chrome_trace(f"{profile}.json")
390
+ device_sync(device=device) # MKG
391
+ t = time.perf_counter() - t0
392
+
393
+ if not interactive:
394
+ print(tokenizer.decode(y.tolist()))
395
+ else:
396
+ print()
397
+ tokens_generated = y.size(0) - prompt_length
398
+ tokens_sec = tokens_generated / t
399
+ aggregate_metrics['tokens_per_sec'].append(tokens_sec)
400
+ print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
401
+ print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
402
+ print("==========")
403
+ if is_speculative:
404
+ counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
405
+ acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
406
+ print(f"Acceptance probs: {acceptance_probs}")
407
+ print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")
408
+
409
+ print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
410
+ print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
411
+
412
+
413
+ if __name__ == '__main__':
414
+ import argparse
415
+ parser = argparse.ArgumentParser(description='Your CLI description.')
416
+
417
+ parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
418
+ parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
419
+ parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
420
+ parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
421
+ parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
422
+ parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
423
+ parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
424
+ parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
425
+ parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
426
+ parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
427
+ parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
428
+ parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
429
+ parser.add_argument('--device', type=str, default=default_device, help='Device to use')
430
+
431
+ args = parser.parse_args()
432
+ main(
433
+ args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
434
+ args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
435
+ args.speculate_k, args.device
436
+ )
modules/gpt_fast/model.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+ from torch.nn import functional as F
13
+
14
+
15
+ def find_multiple(n: int, k: int) -> int:
16
+ if n % k == 0:
17
+ return n
18
+ return n + k - (n % k)
19
+
20
+ class AdaptiveLayerNorm(nn.Module):
21
+ r"""Adaptive Layer Normalization"""
22
+
23
+ def __init__(self, d_model, norm) -> None:
24
+ super(AdaptiveLayerNorm, self).__init__()
25
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
26
+ self.norm = norm
27
+ self.d_model = d_model
28
+ self.eps = self.norm.eps
29
+
30
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
31
+ if embedding is None:
32
+ return self.norm(input)
33
+ weight, bias = torch.split(
34
+ self.project_layer(embedding),
35
+ split_size_or_sections=self.d_model,
36
+ dim=-1,
37
+ )
38
+ return weight * self.norm(input) + bias
39
+
40
+
41
+ @dataclass
42
+ class ModelArgs:
43
+ block_size: int = 2048
44
+ vocab_size: int = 32000
45
+ n_layer: int = 32
46
+ n_head: int = 32
47
+ dim: int = 4096
48
+ intermediate_size: int = None
49
+ n_local_heads: int = -1
50
+ head_dim: int = 64
51
+ rope_base: float = 10000
52
+ norm_eps: float = 1e-5
53
+ has_cross_attention: bool = False
54
+ context_dim: int = 0
55
+ uvit_skip_connection: bool = False
56
+
57
+ def __post_init__(self):
58
+ if self.n_local_heads == -1:
59
+ self.n_local_heads = self.n_head
60
+ if self.intermediate_size is None:
61
+ hidden_dim = 4 * self.dim
62
+ n_hidden = int(2 * hidden_dim / 3)
63
+ self.intermediate_size = find_multiple(n_hidden, 256)
64
+ # self.head_dim = self.dim // self.n_head
65
+
66
+ @classmethod
67
+ def from_name(cls, name: str):
68
+ if name in transformer_configs:
69
+ return cls(**transformer_configs[name])
70
+ # fuzzy search
71
+ config = [config for config in transformer_configs if config.lower() in str(name).lower()]
72
+
73
+ # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
74
+ # take longer name (as it have more symbols matched)
75
+ if len(config) > 1:
76
+ config.sort(key=len, reverse=True)
77
+ assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
78
+
79
+ return cls(**transformer_configs[config[0]])
80
+
81
+
82
+ transformer_configs = {
83
+ "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
84
+ "7B": dict(n_layer=32, n_head=32, dim=4096),
85
+ "13B": dict(n_layer=40, n_head=40, dim=5120),
86
+ "30B": dict(n_layer=60, n_head=52, dim=6656),
87
+ "34B": dict(n_layer=48, n_head=64, dim=8192, vocab_size=32000, n_local_heads=8, intermediate_size=22016,
88
+ rope_base=1000000), # CodeLlama-34B-Python-hf
89
+ "70B": dict(n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672),
90
+ "Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
91
+ "stories15M": dict(n_layer=6, n_head=6, dim=288),
92
+ "stories110M": dict(n_layer=12, n_head=12, dim=768),
93
+
94
+ "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336,
95
+ vocab_size=128256, rope_base=500000),
96
+ "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672,
97
+ vocab_size=128256, rope_base=500000),
98
+ }
99
+
100
+
101
+ class KVCache(nn.Module):
102
+ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
103
+ super().__init__()
104
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
105
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
106
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
107
+
108
+ def update(self, input_pos, k_val, v_val):
109
+ # input_pos: [S], k_val: [B, H, S, D]
110
+ assert input_pos.shape[0] == k_val.shape[2]
111
+
112
+ k_out = self.k_cache
113
+ v_out = self.v_cache
114
+ k_out[:, :, input_pos] = k_val
115
+ v_out[:, :, input_pos] = v_val
116
+
117
+ return k_out, v_out
118
+
119
+
120
+ class Transformer(nn.Module):
121
+ def __init__(self, config: ModelArgs) -> None:
122
+ super().__init__()
123
+ self.config = config
124
+
125
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
126
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
127
+
128
+ self.freqs_cis: Optional[Tensor] = None
129
+ self.mask_cache: Optional[Tensor] = None
130
+ self.max_batch_size = -1
131
+ self.max_seq_length = -1
132
+
133
+ def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=True):
134
+ if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
135
+ return
136
+ head_dim = self.config.dim // self.config.n_head
137
+ max_seq_length = find_multiple(max_seq_length, 8)
138
+ self.max_seq_length = max_seq_length
139
+ self.max_batch_size = max_batch_size
140
+ dtype = self.norm.project_layer.weight.dtype
141
+ device = self.norm.project_layer.weight.device
142
+
143
+ if not self.training and use_kv_cache:
144
+ for b in self.layers:
145
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype).to(device)
146
+
147
+ self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
148
+ self.config.rope_base, dtype).to(device)
149
+ self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
150
+ self.use_kv_cache = use_kv_cache
151
+ self.uvit_skip_connection = self.config.uvit_skip_connection
152
+ if self.uvit_skip_connection:
153
+ self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
154
+ self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
155
+ else:
156
+ self.layers_emit_skip = []
157
+ self.layers_receive_skip = []
158
+
159
+ def forward(self,
160
+ x: Tensor,
161
+ c: Tensor,
162
+ input_pos: Optional[Tensor] = None,
163
+ mask: Optional[Tensor] = None,
164
+ context: Optional[Tensor] = None,
165
+ context_input_pos: Optional[Tensor] = None,
166
+ cross_attention_mask: Optional[Tensor] = None,
167
+ ) -> Tensor:
168
+ assert self.freqs_cis is not None, "Caches must be initialized first"
169
+ if mask is None: # in case of non-causal model
170
+ if not self.training and self.use_kv_cache:
171
+ mask = self.causal_mask[None, None, input_pos]
172
+ else:
173
+ mask = self.causal_mask[None, None, input_pos]
174
+ mask = mask[..., input_pos]
175
+ freqs_cis = self.freqs_cis[input_pos]
176
+ if context is not None:
177
+ context_freqs_cis = self.freqs_cis[context_input_pos]
178
+ else:
179
+ context_freqs_cis = None
180
+ skip_in_x_list = []
181
+ for i, layer in enumerate(self.layers):
182
+ if self.uvit_skip_connection and i in self.layers_receive_skip:
183
+ skip_in_x = skip_in_x_list.pop(-1)
184
+ else:
185
+ skip_in_x = None
186
+ x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
187
+ if self.uvit_skip_connection and i in self.layers_emit_skip:
188
+ skip_in_x_list.append(x)
189
+ x = self.norm(x, c)
190
+ return x
191
+
192
+ @classmethod
193
+ def from_name(cls, name: str):
194
+ return cls(ModelArgs.from_name(name))
195
+
196
+
197
+ class TransformerBlock(nn.Module):
198
+ def __init__(self, config: ModelArgs) -> None:
199
+ super().__init__()
200
+ self.attention = Attention(config)
201
+ self.feed_forward = FeedForward(config)
202
+ self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
203
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
204
+
205
+ if config.has_cross_attention:
206
+ self.has_cross_attention = True
207
+ self.cross_attention = Attention(config, is_cross_attention=True)
208
+ self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
209
+ else:
210
+ self.has_cross_attention = False
211
+
212
+ if config.uvit_skip_connection:
213
+ self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
214
+ self.uvit_skip_connection = True
215
+ else:
216
+ self.uvit_skip_connection = False
217
+
218
+ def forward(self,
219
+ x: Tensor,
220
+ c: Tensor,
221
+ input_pos: Tensor,
222
+ freqs_cis: Tensor,
223
+ mask: Tensor,
224
+ context: Optional[Tensor] = None,
225
+ context_freqs_cis: Optional[Tensor] = None,
226
+ cross_attention_mask: Optional[Tensor] = None,
227
+ skip_in_x: Optional[Tensor] = None,
228
+ ) -> Tensor:
229
+ if self.uvit_skip_connection and skip_in_x is not None:
230
+ x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
231
+ h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
232
+ if self.has_cross_attention:
233
+ h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
234
+ out = h + self.feed_forward(self.ffn_norm(h, c))
235
+ return out
236
+
237
+
238
+ class Attention(nn.Module):
239
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
240
+ super().__init__()
241
+ assert config.dim % config.n_head == 0
242
+
243
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
244
+ # key, query, value projections for all heads, but in a batch
245
+ if is_cross_attention:
246
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
247
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
248
+ else:
249
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
250
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
251
+ self.kv_cache = None
252
+
253
+ self.n_head = config.n_head
254
+ self.head_dim = config.head_dim
255
+ self.n_local_heads = config.n_local_heads
256
+ self.dim = config.dim
257
+ # self._register_load_state_dict_pre_hook(self.load_hook)
258
+
259
+ # def load_hook(self, state_dict, prefix, *args):
260
+ # if prefix + "wq.weight" in state_dict:
261
+ # wq = state_dict.pop(prefix + "wq.weight")
262
+ # wk = state_dict.pop(prefix + "wk.weight")
263
+ # wv = state_dict.pop(prefix + "wv.weight")
264
+ # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
265
+
266
+ def forward(self,
267
+ x: Tensor,
268
+ freqs_cis: Tensor,
269
+ mask: Tensor,
270
+ input_pos: Optional[Tensor] = None,
271
+ context: Optional[Tensor] = None,
272
+ context_freqs_cis: Optional[Tensor] = None,
273
+ ) -> Tensor:
274
+ bsz, seqlen, _ = x.shape
275
+
276
+ kv_size = self.n_local_heads * self.head_dim
277
+ if context is None:
278
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
279
+ context_seqlen = seqlen
280
+ else:
281
+ q = self.wq(x)
282
+ k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
283
+ context_seqlen = context.shape[1]
284
+
285
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
286
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
287
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
288
+
289
+ q = apply_rotary_emb(q, freqs_cis)
290
+ k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
291
+
292
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
293
+
294
+ if self.kv_cache is not None:
295
+ k, v = self.kv_cache.update(input_pos, k, v)
296
+
297
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
298
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
299
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
300
+
301
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
302
+
303
+ y = self.wo(y)
304
+ return y
305
+
306
+
307
+ class FeedForward(nn.Module):
308
+ def __init__(self, config: ModelArgs) -> None:
309
+ super().__init__()
310
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
311
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
312
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
313
+
314
+ def forward(self, x: Tensor) -> Tensor:
315
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
316
+
317
+
318
+ class RMSNorm(nn.Module):
319
+ def __init__(self, dim: int, eps: float = 1e-5):
320
+ super().__init__()
321
+ self.eps = eps
322
+ self.weight = nn.Parameter(torch.ones(dim))
323
+
324
+ def _norm(self, x):
325
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
326
+
327
+ def forward(self, x: Tensor) -> Tensor:
328
+ output = self._norm(x.float()).type_as(x)
329
+ return output * self.weight
330
+
331
+
332
+ def precompute_freqs_cis(
333
+ seq_len: int, n_elem: int, base: int = 10000,
334
+ dtype: torch.dtype = torch.bfloat16
335
+ ) -> Tensor:
336
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
337
+ t = torch.arange(seq_len, device=freqs.device)
338
+ freqs = torch.outer(t, freqs)
339
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
340
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
341
+ return cache.to(dtype=dtype)
342
+
343
+
344
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
345
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
346
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
347
+ x_out2 = torch.stack(
348
+ [
349
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
350
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
351
+ ],
352
+ -1,
353
+ )
354
+
355
+ x_out2 = x_out2.flatten(3)
356
+ return x_out2.type_as(x)
modules/gpt_fast/quantize.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from tokenizer import get_tokenizer
13
+
14
+ try:
15
+ from GPTQ import GenericGPTQRunner, InputRecorder
16
+ from eval import get_task_dict, evaluate, lm_eval
17
+ except:
18
+ pass
19
+
20
+ from model import Transformer
21
+
22
+ ##### Quantization Primitives ######
23
+
24
+ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
25
+ # assumes symmetric quantization
26
+ # assumes axis == 0
27
+ # assumes dense memory format
28
+ # TODO(future): relax ^ as needed
29
+
30
+ # default setup for affine quantization of activations
31
+ eps = torch.finfo(torch.float32).eps
32
+
33
+ # get min and max
34
+ min_val, max_val = torch.aminmax(x, dim=1)
35
+
36
+ # calculate scales and zero_points based on min and max
37
+ # reference: https://fburl.com/code/srbiybme
38
+ min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
39
+ max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
40
+ device = min_val_neg.device
41
+
42
+ # reference: https://fburl.com/code/4wll53rk
43
+ max_val_pos = torch.max(-min_val_neg, max_val_pos)
44
+ scales = max_val_pos / (float(quant_max - quant_min) / 2)
45
+ # ensure scales is the same dtype as the original tensor
46
+ scales = torch.clamp(scales, min=eps).to(x.dtype)
47
+ zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
48
+
49
+ # quantize based on qmin/qmax/scales/zp
50
+ # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
51
+ x_div = x / scales.unsqueeze(-1)
52
+ x_round = torch.round(x_div)
53
+ x_zp = x_round + zero_points.unsqueeze(-1)
54
+ quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
55
+
56
+ return quant, scales, zero_points
57
+
58
+ def get_group_qparams(w, n_bit=4, groupsize=128):
59
+ # needed for GPTQ with padding
60
+ if groupsize > w.shape[-1]:
61
+ groupsize = w.shape[-1]
62
+ assert groupsize > 1
63
+ assert w.shape[-1] % groupsize == 0
64
+ assert w.dim() == 2
65
+
66
+ to_quant = w.reshape(-1, groupsize)
67
+ assert torch.isnan(to_quant).sum() == 0
68
+
69
+ max_val = to_quant.amax(dim=1, keepdim=True)
70
+ min_val = to_quant.amin(dim=1, keepdim=True)
71
+ max_int = 2**n_bit - 1
72
+ scales = (max_val - min_val).clamp(min=1e-6) / max_int
73
+ zeros = min_val + scales * (2 ** (n_bit - 1))
74
+ return scales.to(torch.bfloat16).reshape(w.shape[0], -1), zeros.to(
75
+ torch.bfloat16
76
+ ).reshape(w.shape[0], -1)
77
+
78
+
79
+ def pack_scales_and_zeros(scales, zeros):
80
+ assert scales.shape == zeros.shape
81
+ assert scales.dtype == torch.bfloat16
82
+ assert zeros.dtype == torch.bfloat16
83
+ return (
84
+ torch.cat(
85
+ [
86
+ scales.reshape(scales.size(0), scales.size(1), 1),
87
+ zeros.reshape(zeros.size(0), zeros.size(1), 1),
88
+ ],
89
+ 2,
90
+ )
91
+ .transpose(0, 1)
92
+ .contiguous()
93
+ )
94
+
95
+
96
+ def unpack_scales_and_zeros(scales_and_zeros):
97
+ assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
98
+ assert scales_and_zeros.dtype == torch.float
99
+ return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
100
+
101
+
102
+ def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
103
+ assert groupsize > 1
104
+ # needed for GPTQ single column quantize
105
+ if groupsize > w.shape[-1] and scales.shape[-1] == 1:
106
+ groupsize = w.shape[-1]
107
+
108
+ assert w.shape[-1] % groupsize == 0
109
+ assert w.dim() == 2
110
+
111
+ to_quant = w.reshape(-1, groupsize)
112
+ assert torch.isnan(to_quant).sum() == 0
113
+
114
+ scales = scales.reshape(-1, 1)
115
+ zeros = zeros.reshape(-1, 1)
116
+ min_val = zeros - scales * (2 ** (n_bit - 1))
117
+ max_int = 2**n_bit - 1
118
+ min_int = 0
119
+ w_int32 = (
120
+ to_quant.sub(min_val)
121
+ .div(scales)
122
+ .round()
123
+ .clamp_(min_int, max_int)
124
+ .to(torch.int32)
125
+ .reshape_as(w)
126
+ )
127
+
128
+ return w_int32
129
+
130
+
131
+ def group_quantize_tensor(w, n_bit=4, groupsize=128):
132
+ scales, zeros = get_group_qparams(w, n_bit, groupsize)
133
+ w_int32 = group_quantize_tensor_from_qparams(w, scales, zeros, n_bit, groupsize)
134
+ scales_and_zeros = pack_scales_and_zeros(scales, zeros)
135
+ return w_int32, scales_and_zeros
136
+
137
+
138
+ def group_dequantize_tensor_from_qparams(
139
+ w_int32, scales, zeros, n_bit=4, groupsize=128
140
+ ):
141
+ assert groupsize > 1
142
+ # needed for GPTQ single column dequantize
143
+ if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
144
+ groupsize = w_int32.shape[-1]
145
+ assert w_int32.shape[-1] % groupsize == 0
146
+ assert w_int32.dim() == 2
147
+
148
+ w_int32_grouped = w_int32.reshape(-1, groupsize)
149
+ scales = scales.reshape(-1, 1)
150
+ zeros = zeros.reshape(-1, 1)
151
+
152
+ w_dq = (
153
+ w_int32_grouped.sub(2 ** (n_bit - 1)).mul(scales).add(zeros).reshape_as(w_int32)
154
+ )
155
+ return w_dq
156
+
157
+
158
+ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
159
+ scales, zeros = unpack_scales_and_zeros(scales_and_zeros)
160
+ return group_dequantize_tensor_from_qparams(
161
+ w_int32, scales, zeros, n_bit, groupsize
162
+ )
163
+
164
+ class QuantHandler:
165
+ def __init__(self, mod):
166
+ self.mod = mod
167
+
168
+ def create_quantized_state_dict(self) -> "StateDict":
169
+ pass
170
+
171
+ def convert_for_runtime(self) -> "nn.Module":
172
+ pass
173
+
174
+ class GPTQQuantHandler(QuantHandler):
175
+ """
176
+ This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
177
+ Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
178
+ __init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
179
+
180
+ The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
181
+ create_quantized_state_dict. Here is a description of each function.
182
+
183
+ get_qparams_func:
184
+ A function that calculates the quantization qparams for an input tensor.
185
+ Args:
186
+ weight: A 2d weight tensor with non-integer dtype.
187
+ Returns:
188
+ qparams: it can have any format but will need to be handled by the other defined functions below.
189
+
190
+ quantize_func:
191
+ A function that applies quantization to an input tensor. It should be noted
192
+ that this function needs to be able to handle quantizing the entire weight tensor, a single group,
193
+ or a single column.
194
+ Args:
195
+ weight: A 2d weight tensor with non-integer dtype.
196
+ qparams: the output from get_qparams_func
197
+ Returns:
198
+ quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
199
+
200
+
201
+ dequantize_func:
202
+ A function that dequantizes an input quantized weight tensor. It should be noted
203
+ that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
204
+ or a single column.
205
+ Args:
206
+ quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
207
+ qparams: the output from get_qparams_func
208
+ Returns:
209
+ weight: A 2d weight tensor with non-integer dtype.
210
+
211
+ combine_qparams_list_func:
212
+ A function that combines several qparams into one qparam.
213
+ Args:
214
+ qparams_list: a list of qparams objects, each obtained by calling get_qparams_func
215
+ on a single group from a weight tensor
216
+ Returns:
217
+ qparams: an object of the same format as the qparams above.
218
+
219
+ skip_layer_func:
220
+ A function that determines which linear layers should be skipped during GPTQ
221
+ Args:
222
+ weight: A 2d weight tensor with non-integer dtype.
223
+ Returns:
224
+ skip: boolean indicating whether layer should be skipped
225
+
226
+ make_names_and_values_dict_func:
227
+ A function that prepares the qparams and quantized_weight and creates a dictionary indicating how they
228
+ should be inserted into the state_dict. Generally any packing of the weight and qparams should be done here.
229
+ Args:
230
+ quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
231
+ qparams: the output from get_qparams_func
232
+ Returns:
233
+ names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the
234
+ corresponding quantized weights and qparams.
235
+ """
236
+ def __init__(self):
237
+ assert self.mod is not None
238
+ assert self.get_qparams_func is not None
239
+ assert self.quantize_func is not None
240
+ assert self.dequantize_func is not None
241
+ assert self.combine_qparams_list_func is not None
242
+ assert self.make_names_and_values_dict_func is not None
243
+
244
+ @staticmethod
245
+ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput":
246
+ input_recorder = InputRecorder(
247
+ model,
248
+ tokenizer,
249
+ calibration_seq_length,
250
+ pad_calibration_inputs,
251
+ )
252
+
253
+ try:
254
+ lm_eval.tasks.initialize_tasks()
255
+ except:
256
+ pass
257
+ task_dict = get_task_dict(calibration_tasks)
258
+ print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
259
+
260
+ evaluate(
261
+ input_recorder,
262
+ task_dict,
263
+ limit=calibration_limit,
264
+ )
265
+ inputs = input_recorder.get_recorded_inputs()
266
+ assert inputs is not None, (
267
+ f"No inputs were collected, use a task other than {calibration_tasks}, "+
268
+ f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+
269
+ f"{calibration_seq_length})"
270
+ )
271
+ print(f"Obtained {len(inputs[0].values)} calibration samples")
272
+ return inputs
273
+
274
+ @torch.no_grad()
275
+ def create_quantized_state_dict(
276
+ self,
277
+ tokenizer,
278
+ blocksize,
279
+ percdamp,
280
+ groupsize,
281
+ calibration_tasks,
282
+ calibration_limit,
283
+ calibration_seq_length,
284
+ pad_calibration_inputs,
285
+ ) -> "StateDict":
286
+ inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs)
287
+ print("Tracing model for GPTQ")
288
+ GPTQ_runner = GenericGPTQRunner(
289
+ self.mod,
290
+ inputs,
291
+ blocksize,
292
+ percdamp,
293
+ groupsize,
294
+ ).configure_quantization_mode(
295
+ self.get_qparams_func,
296
+ self.quantize_func,
297
+ self.dequantize_func,
298
+ self.combine_qparams_list_func,
299
+ self.make_names_and_values_dict_func,
300
+ self.skip_layer_func
301
+ )
302
+
303
+ print("Applying GPTQ to weights")
304
+ GPTQ_runner.run()
305
+ return GPTQ_runner.get_quantized_state_dict()
306
+
307
+ def convert_for_runtime(self) -> "nn.Module":
308
+ pass
309
+
310
+ ##### Weight-only int8 per-channel quantized code ######
311
+
312
+ def replace_linear_weight_only_int8_per_channel(module):
313
+ for name, child in module.named_children():
314
+ if isinstance(child, nn.Linear):
315
+ setattr(module, name, WeightOnlyInt8Linear(child.in_features, child.out_features))
316
+ else:
317
+ replace_linear_weight_only_int8_per_channel(child)
318
+
319
+ class WeightOnlyInt8QuantHandler:
320
+ def __init__(self, mod):
321
+ self.mod = mod
322
+
323
+ @torch.no_grad()
324
+ def create_quantized_state_dict(self):
325
+ cur_state_dict = self.mod.state_dict()
326
+ for fqn, mod in self.mod.named_modules():
327
+ if isinstance(mod, torch.nn.Linear):
328
+ int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
329
+ cur_state_dict[f"{fqn}.weight"] = int8_weight
330
+ cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
331
+
332
+ return cur_state_dict
333
+
334
+ def convert_for_runtime(self):
335
+ replace_linear_weight_only_int8_per_channel(self.mod)
336
+ return self.mod
337
+
338
+
339
+ class WeightOnlyInt8Linear(torch.nn.Module):
340
+ __constants__ = ['in_features', 'out_features']
341
+ in_features: int
342
+ out_features: int
343
+ weight: torch.Tensor
344
+
345
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
346
+ device=None, dtype=None) -> None:
347
+ factory_kwargs = {'device': device, 'dtype': dtype}
348
+ super().__init__()
349
+ self.in_features = in_features
350
+ self.out_features = out_features
351
+ self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8))
352
+ self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
353
+
354
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
355
+ return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
356
+
357
+ ##### weight only int4 per channel groupwise quantized code ######
358
+
359
+ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
360
+ weight_int32, scales_and_zeros = group_quantize_tensor(
361
+ weight_bf16, n_bit=4, groupsize=groupsize
362
+ )
363
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
364
+ return weight_int4pack, scales_and_zeros
365
+
366
+
367
+ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
368
+ origin_x_size = x.size()
369
+ x = x.reshape(-1, origin_x_size[-1])
370
+ c = torch.ops.aten._weight_int4pack_mm(x, weight_int4pack, groupsize, scales_and_zeros)
371
+ new_shape = origin_x_size[:-1] + (out_features,)
372
+ c = c.reshape(new_shape)
373
+ return c
374
+
375
+
376
+ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
377
+ return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
378
+
379
+ def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
380
+ for name, child in module.named_children():
381
+ if isinstance(child, nn.Linear):
382
+ if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
383
+ setattr(module, name, WeightOnlyInt4Linear(
384
+ child.in_features, child.out_features, bias=False,
385
+ groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
386
+ ))
387
+ elif padding:
388
+ setattr(module, name, WeightOnlyInt4Linear(
389
+ child.in_features, child.out_features, bias=False,
390
+ groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
391
+ ))
392
+ else:
393
+ replace_linear_int4(child, groupsize, inner_k_tiles, padding)
394
+
395
+
396
+ class WeightOnlyInt4QuantHandler:
397
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
398
+ self.mod = mod
399
+ self.groupsize = groupsize
400
+ self.inner_k_tiles = inner_k_tiles
401
+ self.padding = padding
402
+ assert groupsize in [32, 64, 128, 256]
403
+ assert inner_k_tiles in [2, 4, 8]
404
+
405
+ @torch.no_grad()
406
+ def create_quantized_state_dict(self, use_cuda = True):
407
+ if use_cuda:
408
+ device="cuda"
409
+ else:
410
+ device="cpu"
411
+
412
+ cur_state_dict = self.mod.state_dict()
413
+ for fqn, mod in self.mod.named_modules():
414
+ if isinstance(mod, torch.nn.Linear):
415
+ assert not mod.bias
416
+ out_features = mod.out_features
417
+ in_features = mod.in_features
418
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
419
+ print(f"linear: {fqn}, in={in_features}, out={out_features}")
420
+
421
+ weight = mod.weight.data
422
+ if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
423
+ if self.padding:
424
+ from model import find_multiple
425
+ import torch.nn.functional as F
426
+ print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
427
+ padded_in_features = find_multiple(in_features, 1024)
428
+ weight = F.pad(weight, pad=(0, padded_in_features - in_features))
429
+ else:
430
+ print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
431
+ "and that groupsize and inner_k_tiles*16 evenly divide into it")
432
+ continue
433
+ weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
434
+ weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
435
+ )
436
+ cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
437
+ cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
438
+
439
+ return cur_state_dict
440
+
441
+ def convert_for_runtime(self):
442
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
443
+ return self.mod
444
+
445
+ class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
446
+ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
447
+ from model import find_multiple
448
+ self.mod = mod
449
+ self.groupsize = groupsize
450
+ self.inner_k_tiles = inner_k_tiles
451
+ self.padding = padding
452
+ self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize)
453
+ self.quantize_func = lambda w, qparams: \
454
+ group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize)
455
+ self.dequantize_func = lambda q, qparams: \
456
+ group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float()
457
+ self.combine_qparams_list_func = lambda qparams_list: \
458
+ [torch.cat(x, dim=1) for x in zip(*qparams_list)]
459
+ # skip unless padding=True or its correctly sized
460
+ self.skip_layer_func = lambda linear_weight: not (
461
+ _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding
462
+ )
463
+ # we need to do the padding here, both for q and the qparams if necessary
464
+ def make_names_and_values_dict_func(q, qparams):
465
+ k = q.shape[1]
466
+ new_k = find_multiple(k, 1024)
467
+ # how much we need to pad the weight
468
+ delta_k = new_k - q.shape[1]
469
+ final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
470
+ scales_and_zeros = pack_scales_and_zeros(*qparams)
471
+ # how many new groups we need for padded weight
472
+ delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
473
+ final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1)
474
+ return {"weight": final_q, "scales_and_zeros": final_s_and_z}
475
+ self.make_names_and_values_dict_func = make_names_and_values_dict_func
476
+ super().__init__()
477
+
478
+
479
+ def convert_for_runtime(self):
480
+ replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
481
+ return self.mod
482
+
483
+ class WeightOnlyInt4Linear(torch.nn.Module):
484
+ __constants__ = ['in_features', 'out_features']
485
+ in_features: int
486
+ out_features: int
487
+ weight: torch.Tensor
488
+
489
+ def __init__(
490
+ self, in_features: int, out_features: int,
491
+ bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
492
+ ) -> None:
493
+ super().__init__()
494
+ self.padding = padding
495
+ if padding:
496
+ from model import find_multiple
497
+ self.origin_in_features = in_features
498
+ in_features = find_multiple(in_features, 1024)
499
+
500
+ self.in_features = in_features
501
+ self.out_features = out_features
502
+ assert not bias, "require bias=False"
503
+ self.groupsize = groupsize
504
+ self.inner_k_tiles = inner_k_tiles
505
+
506
+ assert out_features % 8 == 0, "require out_features % 8 == 0"
507
+ assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
508
+ self.register_buffer(
509
+ "weight",
510
+ torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
511
+ )
512
+ self.register_buffer(
513
+ "scales_and_zeros",
514
+ torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
515
+ )
516
+
517
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
518
+ input = input.to(torch.bfloat16)
519
+ if self.padding:
520
+ import torch.nn.functional as F
521
+ input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
522
+ return linear_forward_int4(
523
+ input,
524
+ self.weight, self.scales_and_zeros, self.out_features, self.groupsize
525
+ )
526
+
527
+
528
+ def quantize(
529
+ checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
530
+ mode: str = 'int8',
531
+ # following arguments only available when setting int4 quantization.
532
+ groupsize: int = 128,
533
+ # following arguments only used for GPTQ
534
+ calibration_tasks: list = ["hellaswag"],
535
+ calibration_limit: int = 1000,
536
+ calibration_seq_length: int = 100,
537
+ pad_calibration_inputs: bool = False,
538
+ percdamp: float = .01,
539
+ blocksize: int = 128,
540
+ label: str = '',
541
+ ) -> None:
542
+ assert checkpoint_path.is_file(), checkpoint_path
543
+
544
+ device = 'cpu'
545
+ precision = torch.bfloat16
546
+
547
+ print("Loading model ...")
548
+ t0 = time.time()
549
+
550
+ with torch.device('meta'):
551
+ model = Transformer.from_name(checkpoint_path.parent.name)
552
+
553
+ checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
554
+ model.load_state_dict(checkpoint, assign=True)
555
+ model = model.to(dtype=precision, device=device)
556
+
557
+ if mode == 'int8':
558
+ print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
559
+ quant_handler = WeightOnlyInt8QuantHandler(model)
560
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
561
+
562
+ dir_name = checkpoint_path.parent
563
+ base_name = checkpoint_path.name
564
+ new_base_name = base_name.replace('.pth', f'{label}int8.pth')
565
+
566
+ elif mode == 'int4':
567
+ print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
568
+ quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
569
+ quantized_state_dict = quant_handler.create_quantized_state_dict()
570
+
571
+ dir_name = checkpoint_path.parent
572
+ base_name = checkpoint_path.name
573
+ new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
574
+
575
+ elif mode == 'int4-gptq':
576
+ print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
577
+ quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)
578
+
579
+ tokenizer_path = checkpoint_path.parent / "tokenizer.model"
580
+ assert tokenizer_path.is_file(), str(tokenizer_path)
581
+ tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
582
+
583
+ quantized_state_dict = quant_handler.create_quantized_state_dict(
584
+ tokenizer,
585
+ blocksize,
586
+ percdamp,
587
+ groupsize,
588
+ calibration_tasks,
589
+ calibration_limit,
590
+ calibration_seq_length,
591
+ pad_calibration_inputs
592
+ )
593
+
594
+ dir_name = checkpoint_path.parent
595
+ base_name = checkpoint_path.name
596
+ new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
597
+ else:
598
+ raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
599
+
600
+ quantize_path = dir_name / new_base_name
601
+ print(f"Writing quantized weights to {quantize_path}")
602
+ quantize_path.unlink(missing_ok=True) # remove existing file if one already there
603
+ torch.save(quantized_state_dict, quantize_path)
604
+ print(f"Quantization complete took {time.time() - t0:.02f} seconds")
605
+ return
606
+
607
+ if __name__ == '__main__':
608
+ import argparse
609
+ parser = argparse.ArgumentParser(description='Quantize a model.')
610
+ parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
611
+ parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
612
+ parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
613
+ parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
614
+ parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')
615
+ parser.add_argument('--calibration_seq_length', type=int, default=100, help='length of sequences to use for gptq calibration')
616
+ parser.add_argument('--pad_calibration_inputs', type=bool, default=False, help='pads sequences shorter than calibration_seq_length to that length, yielding more calibration inputs but running much slower')
617
+ parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
618
+ parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
619
+ parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
620
+
621
+ args = parser.parse_args()
622
+ quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
modules/hifigan/__pycache__/f0_predictor.cpython-310.pyc ADDED
Binary file (1.33 kB). View file
 
modules/hifigan/__pycache__/generator.cpython-310.pyc ADDED
Binary file (13.3 kB). View file
 
modules/hifigan/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
modules/hifigan/generator.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ import typing as tp
18
+ import numpy as np
19
+ from scipy.signal import get_window
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.nn import Conv1d
24
+ from torch.nn import ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+ from torch.nn.utils import weight_norm
27
+ from torch.distributions.uniform import Uniform
28
+
29
+ from torch import sin
30
+ from torch.nn.parameter import Parameter
31
+
32
+
33
+ """hifigan based generator implementation.
34
+
35
+ This code is modified from https://github.com/jik876/hifi-gan
36
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
37
+ https://github.com/NVIDIA/BigVGAN
38
+
39
+ """
40
+ class Snake(nn.Module):
41
+ '''
42
+ Implementation of a sine-based periodic activation function
43
+ Shape:
44
+ - Input: (B, C, T)
45
+ - Output: (B, C, T), same shape as the input
46
+ Parameters:
47
+ - alpha - trainable parameter
48
+ References:
49
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
50
+ https://arxiv.org/abs/2006.08195
51
+ Examples:
52
+ >>> a1 = snake(256)
53
+ >>> x = torch.randn(256)
54
+ >>> x = a1(x)
55
+ '''
56
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
57
+ '''
58
+ Initialization.
59
+ INPUT:
60
+ - in_features: shape of the input
61
+ - alpha: trainable parameter
62
+ alpha is initialized to 1 by default, higher values = higher-frequency.
63
+ alpha will be trained along with the rest of your model.
64
+ '''
65
+ super(Snake, self).__init__()
66
+ self.in_features = in_features
67
+
68
+ # initialize alpha
69
+ self.alpha_logscale = alpha_logscale
70
+ if self.alpha_logscale: # log scale alphas initialized to zeros
71
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
72
+ else: # linear scale alphas initialized to ones
73
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
74
+
75
+ self.alpha.requires_grad = alpha_trainable
76
+
77
+ self.no_div_by_zero = 0.000000001
78
+
79
+ def forward(self, x):
80
+ '''
81
+ Forward pass of the function.
82
+ Applies the function to the input elementwise.
83
+ Snake ∶= x + 1/a * sin^2 (xa)
84
+ '''
85
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
86
+ if self.alpha_logscale:
87
+ alpha = torch.exp(alpha)
88
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
89
+
90
+ return x
91
+
92
+ def get_padding(kernel_size, dilation=1):
93
+ return int((kernel_size * dilation - dilation) / 2)
94
+
95
+
96
+ def init_weights(m, mean=0.0, std=0.01):
97
+ classname = m.__class__.__name__
98
+ if classname.find("Conv") != -1:
99
+ m.weight.data.normal_(mean, std)
100
+
101
+
102
+
103
+ class ResBlock(torch.nn.Module):
104
+ """Residual block module in HiFiGAN/BigVGAN."""
105
+ def __init__(
106
+ self,
107
+ channels: int = 512,
108
+ kernel_size: int = 3,
109
+ dilations: tp.List[int] = [1, 3, 5],
110
+ ):
111
+ super(ResBlock, self).__init__()
112
+ self.convs1 = nn.ModuleList()
113
+ self.convs2 = nn.ModuleList()
114
+
115
+ for dilation in dilations:
116
+ self.convs1.append(
117
+ weight_norm(
118
+ Conv1d(
119
+ channels,
120
+ channels,
121
+ kernel_size,
122
+ 1,
123
+ dilation=dilation,
124
+ padding=get_padding(kernel_size, dilation)
125
+ )
126
+ )
127
+ )
128
+ self.convs2.append(
129
+ weight_norm(
130
+ Conv1d(
131
+ channels,
132
+ channels,
133
+ kernel_size,
134
+ 1,
135
+ dilation=1,
136
+ padding=get_padding(kernel_size, 1)
137
+ )
138
+ )
139
+ )
140
+ self.convs1.apply(init_weights)
141
+ self.convs2.apply(init_weights)
142
+ self.activations1 = nn.ModuleList([
143
+ Snake(channels, alpha_logscale=False)
144
+ for _ in range(len(self.convs1))
145
+ ])
146
+ self.activations2 = nn.ModuleList([
147
+ Snake(channels, alpha_logscale=False)
148
+ for _ in range(len(self.convs2))
149
+ ])
150
+
151
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
152
+ for idx in range(len(self.convs1)):
153
+ xt = self.activations1[idx](x)
154
+ xt = self.convs1[idx](xt)
155
+ xt = self.activations2[idx](xt)
156
+ xt = self.convs2[idx](xt)
157
+ x = xt + x
158
+ return x
159
+
160
+ def remove_weight_norm(self):
161
+ for idx in range(len(self.convs1)):
162
+ remove_weight_norm(self.convs1[idx])
163
+ remove_weight_norm(self.convs2[idx])
164
+
165
+ class SineGen(torch.nn.Module):
166
+ """ Definition of sine generator
167
+ SineGen(samp_rate, harmonic_num = 0,
168
+ sine_amp = 0.1, noise_std = 0.003,
169
+ voiced_threshold = 0,
170
+ flag_for_pulse=False)
171
+ samp_rate: sampling rate in Hz
172
+ harmonic_num: number of harmonic overtones (default 0)
173
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
174
+ noise_std: std of Gaussian noise (default 0.003)
175
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
176
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
177
+ Note: when flag_for_pulse is True, the first time step of a voiced
178
+ segment is always sin(np.pi) or cos(0)
179
+ """
180
+
181
+ def __init__(self, samp_rate, harmonic_num=0,
182
+ sine_amp=0.1, noise_std=0.003,
183
+ voiced_threshold=0):
184
+ super(SineGen, self).__init__()
185
+ self.sine_amp = sine_amp
186
+ self.noise_std = noise_std
187
+ self.harmonic_num = harmonic_num
188
+ self.sampling_rate = samp_rate
189
+ self.voiced_threshold = voiced_threshold
190
+
191
+ def _f02uv(self, f0):
192
+ # generate uv signal
193
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
194
+ return uv
195
+
196
+ @torch.no_grad()
197
+ def forward(self, f0):
198
+ """
199
+ :param f0: [B, 1, sample_len], Hz
200
+ :return: [B, 1, sample_len]
201
+ """
202
+
203
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
204
+ for i in range(self.harmonic_num + 1):
205
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
206
+
207
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
208
+ u_dist = Uniform(low=-np.pi, high=np.pi)
209
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
210
+ phase_vec[:, 0, :] = 0
211
+
212
+ # generate sine waveforms
213
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
214
+
215
+ # generate uv signal
216
+ uv = self._f02uv(f0)
217
+
218
+ # noise: for unvoiced should be similar to sine_amp
219
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
220
+ # . for voiced regions is self.noise_std
221
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
222
+ noise = noise_amp * torch.randn_like(sine_waves)
223
+
224
+ # first: set the unvoiced part to 0 by uv
225
+ # then: additive noise
226
+ sine_waves = sine_waves * uv + noise
227
+ return sine_waves, uv, noise
228
+
229
+
230
+ class SourceModuleHnNSF(torch.nn.Module):
231
+ """ SourceModule for hn-nsf
232
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
233
+ add_noise_std=0.003, voiced_threshod=0)
234
+ sampling_rate: sampling_rate in Hz
235
+ harmonic_num: number of harmonic above F0 (default: 0)
236
+ sine_amp: amplitude of sine source signal (default: 0.1)
237
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
238
+ note that amplitude of noise in unvoiced is decided
239
+ by sine_amp
240
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
241
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
242
+ F0_sampled (batchsize, length, 1)
243
+ Sine_source (batchsize, length, 1)
244
+ noise_source (batchsize, length 1)
245
+ uv (batchsize, length, 1)
246
+ """
247
+
248
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
249
+ add_noise_std=0.003, voiced_threshod=0):
250
+ super(SourceModuleHnNSF, self).__init__()
251
+
252
+ self.sine_amp = sine_amp
253
+ self.noise_std = add_noise_std
254
+
255
+ # to produce sine waveforms
256
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
257
+ sine_amp, add_noise_std, voiced_threshod)
258
+
259
+ # to merge source harmonics into a single excitation
260
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
261
+ self.l_tanh = torch.nn.Tanh()
262
+
263
+ def forward(self, x):
264
+ """
265
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
266
+ F0_sampled (batchsize, length, 1)
267
+ Sine_source (batchsize, length, 1)
268
+ noise_source (batchsize, length 1)
269
+ """
270
+ # source for harmonic branch
271
+ with torch.no_grad():
272
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
273
+ sine_wavs = sine_wavs.transpose(1, 2)
274
+ uv = uv.transpose(1, 2)
275
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
276
+
277
+ # source for noise branch, in the same shape as uv
278
+ noise = torch.randn_like(uv) * self.sine_amp / 3
279
+ return sine_merge, noise, uv
280
+
281
+
282
+ class HiFTGenerator(nn.Module):
283
+ """
284
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
285
+ https://arxiv.org/abs/2309.09493
286
+ """
287
+ def __init__(
288
+ self,
289
+ in_channels: int = 80,
290
+ base_channels: int = 512,
291
+ nb_harmonics: int = 8,
292
+ sampling_rate: int = 22050,
293
+ nsf_alpha: float = 0.1,
294
+ nsf_sigma: float = 0.003,
295
+ nsf_voiced_threshold: float = 10,
296
+ upsample_rates: tp.List[int] = [8, 8],
297
+ upsample_kernel_sizes: tp.List[int] = [16, 16],
298
+ istft_params: tp.Dict[str, int] = {"n_fft": 16, "hop_len": 4},
299
+ resblock_kernel_sizes: tp.List[int] = [3, 7, 11],
300
+ resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
301
+ source_resblock_kernel_sizes: tp.List[int] = [7, 11],
302
+ source_resblock_dilation_sizes: tp.List[tp.List[int]] = [[1, 3, 5], [1, 3, 5]],
303
+ lrelu_slope: float = 0.1,
304
+ audio_limit: float = 0.99,
305
+ f0_predictor: torch.nn.Module = None,
306
+ ):
307
+ super(HiFTGenerator, self).__init__()
308
+
309
+ self.out_channels = 1
310
+ self.nb_harmonics = nb_harmonics
311
+ self.sampling_rate = sampling_rate
312
+ self.istft_params = istft_params
313
+ self.lrelu_slope = lrelu_slope
314
+ self.audio_limit = audio_limit
315
+
316
+ self.num_kernels = len(resblock_kernel_sizes)
317
+ self.num_upsamples = len(upsample_rates)
318
+ self.m_source = SourceModuleHnNSF(
319
+ sampling_rate=sampling_rate,
320
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
321
+ harmonic_num=nb_harmonics,
322
+ sine_amp=nsf_alpha,
323
+ add_noise_std=nsf_sigma,
324
+ voiced_threshod=nsf_voiced_threshold)
325
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
326
+
327
+ self.conv_pre = weight_norm(
328
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
329
+ )
330
+
331
+ # Up
332
+ self.ups = nn.ModuleList()
333
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
334
+ self.ups.append(
335
+ weight_norm(
336
+ ConvTranspose1d(
337
+ base_channels // (2**i),
338
+ base_channels // (2**(i + 1)),
339
+ k,
340
+ u,
341
+ padding=(k - u) // 2,
342
+ )
343
+ )
344
+ )
345
+
346
+ # Down
347
+ self.source_downs = nn.ModuleList()
348
+ self.source_resblocks = nn.ModuleList()
349
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
350
+ downsample_cum_rates = np.cumprod(downsample_rates)
351
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
352
+ source_resblock_dilation_sizes)):
353
+ if u == 1:
354
+ self.source_downs.append(
355
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
356
+ )
357
+ else:
358
+ self.source_downs.append(
359
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
360
+ )
361
+
362
+ self.source_resblocks.append(
363
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
364
+ )
365
+
366
+ self.resblocks = nn.ModuleList()
367
+ for i in range(len(self.ups)):
368
+ ch = base_channels // (2**(i + 1))
369
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
370
+ self.resblocks.append(ResBlock(ch, k, d))
371
+
372
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
373
+ self.ups.apply(init_weights)
374
+ self.conv_post.apply(init_weights)
375
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
376
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
377
+ self.f0_predictor = f0_predictor
378
+
379
+ def _f02source(self, f0: torch.Tensor) -> torch.Tensor:
380
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
381
+
382
+ har_source, _, _ = self.m_source(f0)
383
+ return har_source.transpose(1, 2)
384
+
385
+ def _stft(self, x):
386
+ spec = torch.stft(
387
+ x,
388
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
389
+ return_complex=True)
390
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
391
+ return spec[..., 0], spec[..., 1]
392
+
393
+ def _istft(self, magnitude, phase):
394
+ magnitude = torch.clip(magnitude, max=1e2)
395
+ real = magnitude * torch.cos(phase)
396
+ img = magnitude * torch.sin(phase)
397
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
398
+ return inverse_transform
399
+
400
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
401
+ f0 = self.f0_predictor(x)
402
+ s = self._f02source(f0)
403
+
404
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
405
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
406
+
407
+ x = self.conv_pre(x)
408
+ for i in range(self.num_upsamples):
409
+ x = F.leaky_relu(x, self.lrelu_slope)
410
+ x = self.ups[i](x)
411
+
412
+ if i == self.num_upsamples - 1:
413
+ x = self.reflection_pad(x)
414
+
415
+ # fusion
416
+ si = self.source_downs[i](s_stft)
417
+ si = self.source_resblocks[i](si)
418
+ x = x + si
419
+
420
+ xs = None
421
+ for j in range(self.num_kernels):
422
+ if xs is None:
423
+ xs = self.resblocks[i * self.num_kernels + j](x)
424
+ else:
425
+ xs += self.resblocks[i * self.num_kernels + j](x)
426
+ x = xs / self.num_kernels
427
+
428
+ x = F.leaky_relu(x)
429
+ x = self.conv_post(x)
430
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
431
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
432
+
433
+ x = self._istft(magnitude, phase)
434
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
435
+ return x
436
+
437
+ def remove_weight_norm(self):
438
+ print('Removing weight norm...')
439
+ for l in self.ups:
440
+ remove_weight_norm(l)
441
+ for l in self.resblocks:
442
+ l.remove_weight_norm()
443
+ remove_weight_norm(self.conv_pre)
444
+ remove_weight_norm(self.conv_post)
445
+ self.source_module.remove_weight_norm()
446
+ for l in self.source_downs:
447
+ remove_weight_norm(l)
448
+ for l in self.source_resblocks:
449
+ l.remove_weight_norm()
450
+
451
+ @torch.inference_mode()
452
+ def inference(self, mel: torch.Tensor) -> torch.Tensor:
453
+ return self.forward(x=mel)
modules/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).tranpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
modules/length_regulator.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from modules.commons import sequence_mask
5
+
6
+
7
+ class InterpolateRegulator(nn.Module):
8
+ def __init__(
9
+ self,
10
+ channels: int,
11
+ sampling_ratios: Tuple,
12
+ is_discrete: bool = False,
13
+ codebook_size: int = 1024, # for discrete only
14
+ out_channels: int = None,
15
+ groups: int = 1,
16
+ ):
17
+ super().__init__()
18
+ self.sampling_ratios = sampling_ratios
19
+ out_channels = out_channels or channels
20
+ model = nn.ModuleList([])
21
+ if len(sampling_ratios) > 0:
22
+ for _ in sampling_ratios:
23
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
24
+ norm = nn.GroupNorm(groups, channels)
25
+ act = nn.Mish()
26
+ model.extend([module, norm, act])
27
+ model.append(
28
+ nn.Conv1d(channels, out_channels, 1, 1)
29
+ )
30
+ self.model = nn.Sequential(*model)
31
+ self.embedding = nn.Embedding(codebook_size, channels)
32
+ self.is_discrete = is_discrete
33
+
34
+ def forward(self, x, ylens=None):
35
+ if self.is_discrete:
36
+ x = self.embedding(x)
37
+ # x in (B, T, D)
38
+ mask = sequence_mask(ylens).unsqueeze(-1)
39
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
40
+ out = self.model(x).transpose(1, 2).contiguous()
41
+ olens = ylens
42
+ return out * mask, olens
modules/wavenet.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from modules.encodec import SConv1d
7
+
8
+ from . import commons
9
+ LRELU_SLOPE = 0.1
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ class ConvReluNorm(nn.Module):
27
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.hidden_channels = hidden_channels
31
+ self.out_channels = out_channels
32
+ self.kernel_size = kernel_size
33
+ self.n_layers = n_layers
34
+ self.p_dropout = p_dropout
35
+ assert n_layers > 1, "Number of layers should be larger than 0."
36
+
37
+ self.conv_layers = nn.ModuleList()
38
+ self.norm_layers = nn.ModuleList()
39
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
40
+ self.norm_layers.append(LayerNorm(hidden_channels))
41
+ self.relu_drop = nn.Sequential(
42
+ nn.ReLU(),
43
+ nn.Dropout(p_dropout))
44
+ for _ in range(n_layers - 1):
45
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
46
+ self.norm_layers.append(LayerNorm(hidden_channels))
47
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
48
+ self.proj.weight.data.zero_()
49
+ self.proj.bias.data.zero_()
50
+
51
+ def forward(self, x, x_mask):
52
+ x_org = x
53
+ for i in range(self.n_layers):
54
+ x = self.conv_layers[i](x * x_mask)
55
+ x = self.norm_layers[i](x)
56
+ x = self.relu_drop(x)
57
+ x = x_org + self.proj(x)
58
+ return x * x_mask
59
+
60
+
61
+ class DDSConv(nn.Module):
62
+ """
63
+ Dialted and Depth-Separable Convolution
64
+ """
65
+
66
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
67
+ super().__init__()
68
+ self.channels = channels
69
+ self.kernel_size = kernel_size
70
+ self.n_layers = n_layers
71
+ self.p_dropout = p_dropout
72
+
73
+ self.drop = nn.Dropout(p_dropout)
74
+ self.convs_sep = nn.ModuleList()
75
+ self.convs_1x1 = nn.ModuleList()
76
+ self.norms_1 = nn.ModuleList()
77
+ self.norms_2 = nn.ModuleList()
78
+ for i in range(n_layers):
79
+ dilation = kernel_size ** i
80
+ padding = (kernel_size * dilation - dilation) // 2
81
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
82
+ groups=channels, dilation=dilation, padding=padding
83
+ ))
84
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
85
+ self.norms_1.append(LayerNorm(channels))
86
+ self.norms_2.append(LayerNorm(channels))
87
+
88
+ def forward(self, x, x_mask, g=None):
89
+ if g is not None:
90
+ x = x + g
91
+ for i in range(self.n_layers):
92
+ y = self.convs_sep[i](x * x_mask)
93
+ y = self.norms_1[i](y)
94
+ y = F.gelu(y)
95
+ y = self.convs_1x1[i](y)
96
+ y = self.norms_2[i](y)
97
+ y = F.gelu(y)
98
+ y = self.drop(y)
99
+ x = x + y
100
+ return x * x_mask
101
+
102
+
103
+ class WN(torch.nn.Module):
104
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False):
105
+ super(WN, self).__init__()
106
+ conv1d_type = SConv1d
107
+ assert (kernel_size % 2 == 1)
108
+ self.hidden_channels = hidden_channels
109
+ self.kernel_size = kernel_size,
110
+ self.dilation_rate = dilation_rate
111
+ self.n_layers = n_layers
112
+ self.gin_channels = gin_channels
113
+ self.p_dropout = p_dropout
114
+
115
+ self.in_layers = torch.nn.ModuleList()
116
+ self.res_skip_layers = torch.nn.ModuleList()
117
+ self.drop = nn.Dropout(p_dropout)
118
+
119
+ if gin_channels != 0:
120
+ self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm')
121
+
122
+ for i in range(n_layers):
123
+ dilation = dilation_rate ** i
124
+ padding = int((kernel_size * dilation - dilation) / 2)
125
+ in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation,
126
+ padding=padding, norm='weight_norm', causal=causal)
127
+ self.in_layers.append(in_layer)
128
+
129
+ # last one is not necessary
130
+ if i < n_layers - 1:
131
+ res_skip_channels = 2 * hidden_channels
132
+ else:
133
+ res_skip_channels = hidden_channels
134
+
135
+ res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal)
136
+ self.res_skip_layers.append(res_skip_layer)
137
+
138
+ def forward(self, x, x_mask, g=None, **kwargs):
139
+ output = torch.zeros_like(x)
140
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
141
+
142
+ if g is not None:
143
+ g = self.cond_layer(g)
144
+
145
+ for i in range(self.n_layers):
146
+ x_in = self.in_layers[i](x)
147
+ if g is not None:
148
+ cond_offset = i * 2 * self.hidden_channels
149
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
150
+ else:
151
+ g_l = torch.zeros_like(x_in)
152
+
153
+ acts = commons.fused_add_tanh_sigmoid_multiply(
154
+ x_in,
155
+ g_l,
156
+ n_channels_tensor)
157
+ acts = self.drop(acts)
158
+
159
+ res_skip_acts = self.res_skip_layers[i](acts)
160
+ if i < self.n_layers - 1:
161
+ res_acts = res_skip_acts[:, :self.hidden_channels, :]
162
+ x = (x + res_acts) * x_mask
163
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
164
+ else:
165
+ output = output + res_skip_acts
166
+ return output * x_mask
167
+
168
+ def remove_weight_norm(self):
169
+ if self.gin_channels != 0:
170
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
171
+ for l in self.in_layers:
172
+ torch.nn.utils.remove_weight_norm(l)
173
+ for l in self.res_skip_layers:
174
+ torch.nn.utils.remove_weight_norm(l)