jbetker commited on
Commit
9043dde
·
1 Parent(s): 287debd

Integrate new diffusion network

Browse files
Files changed (3) hide show
  1. api.py +24 -25
  2. models/arch_util.py +13 -8
  3. models/diffusion_decoder.py +155 -360
api.py CHANGED
@@ -49,6 +49,15 @@ def download_models():
49
  print('Done.')
50
 
51
 
 
 
 
 
 
 
 
 
 
52
  def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
53
  """
54
  Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
@@ -96,26 +105,25 @@ def fix_autoregressive_output(codes, stop_token):
96
  return codes
97
 
98
 
99
- def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_input, temperature=1):
100
  """
101
- Uses the specified diffusion model and DVAE model to convert the provided MEL & conditioning inputs into an audio clip.
102
  """
103
  with torch.no_grad():
104
- cond_mel = wav_to_univnet_mel(conditioning_input.squeeze(1), do_normalization=False)
105
- # Pad MEL to multiples of 32
106
- msl = mel_codes.shape[-1]
107
- dsl = 32
108
- gap = dsl - (msl % dsl)
109
- if gap > 0:
110
- mel = torch.nn.functional.pad(mel_codes, (0, gap))
111
 
112
- output_shape = (mel.shape[0], 100, mel.shape[-1]*4)
113
- precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mel)
114
 
115
  noise = torch.randn(output_shape, device=mel_codes.device) * temperature
116
  mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
117
  model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
118
- return denormalize_tacotron_mel(mel)[:,:,:msl*4]
119
 
120
 
121
  class TextToSpeech:
@@ -137,12 +145,9 @@ class TextToSpeech:
137
  use_xformers=True).cpu().eval()
138
  self.clip.load_state_dict(torch.load('.models/clip.pth'))
139
 
140
- self.diffusion = DiffusionTts(model_channels=512, in_channels=100, out_channels=200, in_latent_channels=1024,
141
- channel_mult=[1, 2, 3, 4], num_res_blocks=[3, 3, 3, 3],
142
- token_conditioning_resolutions=[1, 4, 8],
143
- dropout=0, attention_resolutions=[4, 8], num_heads=8, kernel_size=3, scale_factor=2,
144
- time_embed_dim_multiplier=4, unconditioned_percentage=0, conditioning_dim_factor=2,
145
- conditioning_expansion=1).cpu().eval()
146
  self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
147
 
148
  self.vocoder = UnivNetGenerator().cpu()
@@ -164,12 +169,6 @@ class TextToSpeech:
164
  for vs in voice_samples:
165
  conds.append(load_conditioning(vs))
166
  conds = torch.stack(conds, dim=1)
167
- cond_diffusion = voice_samples[0].cuda()
168
- # The diffusion model expects = 88200 conditioning samples.
169
- if cond_diffusion.shape[-1] < 88200:
170
- cond_diffusion = F.pad(cond_diffusion, (0, 88200-cond_diffusion.shape[-1]))
171
- else:
172
- cond_diffusion = cond_diffusion[:, :88200]
173
 
174
  diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
175
 
@@ -211,7 +210,7 @@ class TextToSpeech:
211
  self.vocoder = self.vocoder.cuda()
212
  for b in range(best_results.shape[0]):
213
  code = best_results[b].unsqueeze(0)
214
- mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, cond_diffusion, temperature=diffusion_temperature)
215
  wav = self.vocoder.inference(mel)
216
  wav_candidates.append(wav.cpu())
217
  self.diffusion = self.diffusion.cpu()
 
49
  print('Done.')
50
 
51
 
52
+ def pad_or_truncate(t, length):
53
+ if t.shape[-1] == length:
54
+ return t
55
+ elif t.shape[-1] < length:
56
+ return F.pad(t, (0, length-t.shape[-1]))
57
+ else:
58
+ return t[..., :length]
59
+
60
+
61
  def load_discrete_vocoder_diffuser(trained_diffusion_steps=4000, desired_diffusion_steps=200, cond_free=True, cond_free_k=1):
62
  """
63
  Helper function to load a GaussianDiffusion instance configured for use as a vocoder.
 
105
  return codes
106
 
107
 
108
+ def do_spectrogram_diffusion(diffusion_model, diffuser, mel_codes, conditioning_samples, temperature=1):
109
  """
110
+ Uses the specified diffusion model to convert discrete codes into a spectrogram.
111
  """
112
  with torch.no_grad():
113
+ cond_mels = []
114
+ for sample in conditioning_samples:
115
+ sample = pad_or_truncate(sample, 102400)
116
+ cond_mel = wav_to_univnet_mel(sample.to(mel_codes.device), do_normalization=False)
117
+ cond_mels.append(cond_mel)
118
+ cond_mels = torch.stack(cond_mels, dim=1)
 
119
 
120
+ output_shape = (mel_codes.shape[0], 100, mel_codes.shape[-1]*4)
121
+ precomputed_embeddings = diffusion_model.timestep_independent(mel_codes, cond_mels, False)
122
 
123
  noise = torch.randn(output_shape, device=mel_codes.device) * temperature
124
  mel = diffuser.p_sample_loop(diffusion_model, output_shape, noise=noise,
125
  model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings})
126
+ return denormalize_tacotron_mel(mel)[:,:,:mel_codes.shape[-1]*4]
127
 
128
 
129
  class TextToSpeech:
 
145
  use_xformers=True).cpu().eval()
146
  self.clip.load_state_dict(torch.load('.models/clip.pth'))
147
 
148
+ self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
149
+ in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
150
+ layer_drop=0, unconditioned_percentage=0).cpu().eval()
 
 
 
151
  self.diffusion.load_state_dict(torch.load('.models/diffusion.pth'))
152
 
153
  self.vocoder = UnivNetGenerator().cpu()
 
169
  for vs in voice_samples:
170
  conds.append(load_conditioning(vs))
171
  conds = torch.stack(conds, dim=1)
 
 
 
 
 
 
172
 
173
  diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_iterations, cond_free=cond_free, cond_free_k=cond_free_k)
174
 
 
210
  self.vocoder = self.vocoder.cuda()
211
  for b in range(best_results.shape[0]):
212
  code = best_results[b].unsqueeze(0)
213
+ mel = do_spectrogram_diffusion(self.diffusion, diffuser, code, voice_samples, temperature=diffusion_temperature)
214
  wav = self.vocoder.inference(mel)
215
  wav_candidates.append(wav.cpu())
216
  self.diffusion = self.diffusion.cpu()
models/arch_util.py CHANGED
@@ -6,6 +6,7 @@ import torch.nn as nn
6
  import torch.nn.functional as F
7
  import torchaudio
8
  from x_transformers import ContinuousTransformerWrapper
 
9
 
10
 
11
  def zero_module(module):
@@ -49,7 +50,7 @@ class QKVAttentionLegacy(nn.Module):
49
  super().__init__()
50
  self.n_heads = n_heads
51
 
52
- def forward(self, qkv, mask=None):
53
  """
54
  Apply QKV attention.
55
 
@@ -64,6 +65,8 @@ class QKVAttentionLegacy(nn.Module):
64
  weight = torch.einsum(
65
  "bct,bcs->bts", q * scale, k * scale
66
  ) # More stable with f16 than dividing afterwards
 
 
67
  weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
68
  if mask is not None:
69
  # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
@@ -87,9 +90,12 @@ class AttentionBlock(nn.Module):
87
  channels,
88
  num_heads=1,
89
  num_head_channels=-1,
 
 
90
  ):
91
  super().__init__()
92
  self.channels = channels
 
93
  if num_head_channels == -1:
94
  self.num_heads = num_heads
95
  else:
@@ -99,21 +105,20 @@ class AttentionBlock(nn.Module):
99
  self.num_heads = channels // num_head_channels
100
  self.norm = normalization(channels)
101
  self.qkv = nn.Conv1d(channels, channels * 3, 1)
 
102
  self.attention = QKVAttentionLegacy(self.num_heads)
103
 
104
  self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
105
-
106
- def forward(self, x, mask=None):
107
- if mask is not None:
108
- return self._forward(x, mask)
109
  else:
110
- return self._forward(x)
111
 
112
- def _forward(self, x, mask=None):
113
  b, c, *spatial = x.shape
114
  x = x.reshape(b, c, -1)
115
  qkv = self.qkv(self.norm(x))
116
- h = self.attention(qkv, mask)
117
  h = self.proj_out(h)
118
  return (x + h).reshape(b, c, *spatial)
119
 
 
6
  import torch.nn.functional as F
7
  import torchaudio
8
  from x_transformers import ContinuousTransformerWrapper
9
+ from x_transformers.x_transformers import RelativePositionBias
10
 
11
 
12
  def zero_module(module):
 
50
  super().__init__()
51
  self.n_heads = n_heads
52
 
53
+ def forward(self, qkv, mask=None, rel_pos=None):
54
  """
55
  Apply QKV attention.
56
 
 
65
  weight = torch.einsum(
66
  "bct,bcs->bts", q * scale, k * scale
67
  ) # More stable with f16 than dividing afterwards
68
+ if rel_pos is not None:
69
+ weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
70
  weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
  if mask is not None:
72
  # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
 
90
  channels,
91
  num_heads=1,
92
  num_head_channels=-1,
93
+ do_checkpoint=True,
94
+ relative_pos_embeddings=False,
95
  ):
96
  super().__init__()
97
  self.channels = channels
98
+ self.do_checkpoint = do_checkpoint
99
  if num_head_channels == -1:
100
  self.num_heads = num_heads
101
  else:
 
105
  self.num_heads = channels // num_head_channels
106
  self.norm = normalization(channels)
107
  self.qkv = nn.Conv1d(channels, channels * 3, 1)
108
+ # split heads before split qkv
109
  self.attention = QKVAttentionLegacy(self.num_heads)
110
 
111
  self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
112
+ if relative_pos_embeddings:
113
+ self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
 
 
114
  else:
115
+ self.relative_pos_embeddings = None
116
 
117
+ def forward(self, x, mask=None):
118
  b, c, *spatial = x.shape
119
  x = x.reshape(b, c, -1)
120
  qkv = self.qkv(self.norm(x))
121
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
122
  h = self.proj_out(h)
123
  return (x + h).reshape(b, c, *spatial)
124
 
models/diffusion_decoder.py CHANGED
@@ -1,22 +1,13 @@
1
- """
2
- This model is based on OpenAI's UNet from improved diffusion, with modifications to support a MEL conditioning signal
3
- and an audio conditioning input. It has also been simplified somewhat.
4
- Credit: https://github.com/openai/improved-diffusion
5
- """
6
- import functools
7
  import math
 
8
  from abc import abstractmethod
9
 
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
  from torch import autocast
14
- from torch.nn import Linear
15
- from torch.utils.checkpoint import checkpoint
16
- from x_transformers import ContinuousTransformerWrapper, Encoder
17
 
18
- from models.arch_util import normalization, zero_module, Downsample, Upsample, AudioMiniEncoder, AttentionBlock, \
19
- CheckpointedXTransformerEncoder
20
 
21
 
22
  def is_latent(t):
@@ -27,13 +18,6 @@ def is_sequence(t):
27
  return t.dtype == torch.long
28
 
29
 
30
- def ceil_multiple(base, multiple):
31
- res = base % multiple
32
- if res == 0:
33
- return base
34
- return base + (multiple - res)
35
-
36
-
37
  def timestep_embedding(timesteps, dim, max_period=10000):
38
  """
39
  Create sinusoidal timestep embeddings.
@@ -56,10 +40,6 @@ def timestep_embedding(timesteps, dim, max_period=10000):
56
 
57
 
58
  class TimestepBlock(nn.Module):
59
- """
60
- Any module where forward() takes timestep embeddings as a second argument.
61
- """
62
-
63
  @abstractmethod
64
  def forward(self, x, emb):
65
  """
@@ -68,11 +48,6 @@ class TimestepBlock(nn.Module):
68
 
69
 
70
  class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
71
- """
72
- A sequential module that passes timestep embeddings to the children that
73
- support it as an extra input.
74
- """
75
-
76
  def forward(self, x, emb):
77
  for layer in self:
78
  if isinstance(layer, TimestepBlock):
@@ -89,6 +64,7 @@ class ResBlock(TimestepBlock):
89
  emb_channels,
90
  dropout,
91
  out_channels=None,
 
92
  kernel_size=3,
93
  efficient_config=True,
94
  use_scale_shift_norm=False,
@@ -111,7 +87,7 @@ class ResBlock(TimestepBlock):
111
 
112
  self.emb_layers = nn.Sequential(
113
  nn.SiLU(),
114
- Linear(
115
  emb_channels,
116
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
117
  ),
@@ -120,9 +96,7 @@ class ResBlock(TimestepBlock):
120
  normalization(self.out_channels),
121
  nn.SiLU(),
122
  nn.Dropout(p=dropout),
123
- zero_module(
124
- nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
125
- ),
126
  )
127
 
128
  if self.out_channels == channels:
@@ -131,18 +105,6 @@ class ResBlock(TimestepBlock):
131
  self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
132
 
133
  def forward(self, x, emb):
134
- """
135
- Apply the block to a Tensor, conditioned on a timestep embedding.
136
-
137
- :param x: an [N x C x ...] Tensor of features.
138
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
139
- :return: an [N x C x ...] Tensor of outputs.
140
- """
141
- return checkpoint(
142
- self._forward, x, emb
143
- )
144
-
145
- def _forward(self, x, emb):
146
  h = self.in_layers(x)
147
  emb_out = self.emb_layers(emb).type(h.dtype)
148
  while len(emb_out.shape) < len(h.shape):
@@ -158,372 +120,205 @@ class ResBlock(TimestepBlock):
158
  return self.skip_connection(x) + h
159
 
160
 
161
- class DiffusionTts(nn.Module):
162
- """
163
- The full UNet model with attention and timestep embedding.
164
-
165
- Customized to be conditioned on an aligned prior derived from a autoregressive
166
- GPT-style model.
167
-
168
- :param in_channels: channels in the input Tensor.
169
- :param in_latent_channels: channels from the input latent.
170
- :param model_channels: base channel count for the model.
171
- :param out_channels: channels in the output Tensor.
172
- :param num_res_blocks: number of residual blocks per downsample.
173
- :param attention_resolutions: a collection of downsample rates at which
174
- attention will take place. May be a set, list, or tuple.
175
- For example, if this contains 4, then at 4x downsampling, attention
176
- will be used.
177
- :param dropout: the dropout probability.
178
- :param channel_mult: channel multiplier for each level of the UNet.
179
- :param conv_resample: if True, use learned convolutions for upsampling and
180
- downsampling.
181
- :param num_heads: the number of attention heads in each attention layer.
182
- :param num_heads_channels: if specified, ignore num_heads and instead use
183
- a fixed channel width per attention head.
184
- :param num_heads_upsample: works with num_heads to set a different number
185
- of heads for upsampling. Deprecated.
186
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
187
- :param resblock_updown: use residual blocks for up/downsampling.
188
- :param use_new_attention_order: use a different attention pattern for potentially
189
- increased efficiency.
190
- """
191
 
 
 
 
 
 
 
192
  def __init__(
193
  self,
194
- model_channels,
195
- in_channels=1,
196
- in_latent_channels=1024,
 
197
  in_tokens=8193,
198
- conditioning_dim_factor=8,
199
- conditioning_expansion=4,
200
- out_channels=2, # mean and variance
201
  dropout=0,
202
- # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
203
- channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
204
- num_res_blocks=(1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
205
- # spec_cond: 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0)
206
- # attn: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1
207
- token_conditioning_resolutions=(1,16,),
208
- attention_resolutions=(512,1024,2048),
209
- conv_resample=True,
210
  use_fp16=False,
211
- num_heads=1,
212
- num_head_channels=-1,
213
- num_heads_upsample=-1,
214
- kernel_size=3,
215
- scale_factor=2,
216
- time_embed_dim_multiplier=4,
217
- freeze_main_net=False,
218
- efficient_convs=True, # Uses kernels with width of 1 in several places rather than 3.
219
- use_scale_shift_norm=True,
220
  # Parameters for regularization.
 
221
  unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
222
- # Parameters for super-sampling.
223
- super_sampling=False,
224
- super_sampling_max_noising_factor=.1,
225
  ):
226
  super().__init__()
227
 
228
- if num_heads_upsample == -1:
229
- num_heads_upsample = num_heads
230
-
231
- if super_sampling:
232
- in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
233
  self.in_channels = in_channels
234
  self.model_channels = model_channels
235
  self.out_channels = out_channels
236
- self.attention_resolutions = attention_resolutions
237
  self.dropout = dropout
238
- self.channel_mult = channel_mult
239
- self.conv_resample = conv_resample
240
  self.num_heads = num_heads
241
- self.num_head_channels = num_head_channels
242
- self.num_heads_upsample = num_heads_upsample
243
- self.super_sampling_enabled = super_sampling
244
- self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
245
  self.unconditioned_percentage = unconditioned_percentage
246
  self.enable_fp16 = use_fp16
247
- self.alignment_size = 2 ** (len(channel_mult)+1)
248
- self.freeze_main_net = freeze_main_net
249
- padding = 1 if kernel_size == 3 else 2
250
- down_kernel = 1 if efficient_convs else 3
251
 
252
- time_embed_dim = model_channels * time_embed_dim_multiplier
253
  self.time_embed = nn.Sequential(
254
- Linear(model_channels, time_embed_dim),
255
  nn.SiLU(),
256
- Linear(time_embed_dim, time_embed_dim),
257
  )
258
 
259
- conditioning_dim = model_channels * conditioning_dim_factor
260
  # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
261
  # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
262
  # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
263
  # transformer network.
 
264
  self.code_converter = nn.Sequential(
265
- nn.Embedding(in_tokens, conditioning_dim),
266
- CheckpointedXTransformerEncoder(
267
- needs_permute=False,
268
- max_seq_len=-1,
269
- use_pos_emb=False,
270
- attn_layers=Encoder(
271
- dim=conditioning_dim,
272
- depth=3,
273
- heads=num_heads,
274
- ff_dropout=dropout,
275
- attn_dropout=dropout,
276
- use_rmsnorm=True,
277
- ff_glu=True,
278
- rotary_emb_dim=True,
279
- )
280
- ))
281
- self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
282
- self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
283
- if in_channels > 60: # It's a spectrogram.
284
- self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,conditioning_dim,3,padding=1,stride=2),
285
- CheckpointedXTransformerEncoder(
286
- needs_permute=True,
287
- max_seq_len=-1,
288
- use_pos_emb=False,
289
- attn_layers=Encoder(
290
- dim=conditioning_dim,
291
- depth=4,
292
- heads=num_heads,
293
- ff_dropout=dropout,
294
- attn_dropout=dropout,
295
- use_rmsnorm=True,
296
- ff_glu=True,
297
- rotary_emb_dim=True,
298
- )
299
- ))
300
- else:
301
- self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
302
- attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
303
- self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
304
- self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
305
- self.conditioning_timestep_integrator = TimestepEmbedSequential(
306
- ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
307
- AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
308
- ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
309
- AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
310
- ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
311
  )
312
- self.conditioning_expansion = conditioning_expansion
313
-
314
- self.input_blocks = nn.ModuleList(
315
- [
316
- TimestepEmbedSequential(
317
- nn.Conv1d(in_channels, model_channels, kernel_size, padding=padding)
318
- )
319
- ]
320
- )
321
- token_conditioning_blocks = []
322
- self._feature_size = model_channels
323
- input_block_chans = [model_channels]
324
- ch = model_channels
325
- ds = 1
326
-
327
- for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
328
- if ds in token_conditioning_resolutions:
329
- token_conditioning_block = nn.Conv1d(conditioning_dim, ch, 1)
330
- token_conditioning_block.weight.data *= .02
331
- self.input_blocks.append(token_conditioning_block)
332
- token_conditioning_blocks.append(token_conditioning_block)
333
-
334
- for _ in range(num_blocks):
335
- layers = [
336
- ResBlock(
337
- ch,
338
- time_embed_dim,
339
- dropout,
340
- out_channels=int(mult * model_channels),
341
- kernel_size=kernel_size,
342
- efficient_config=efficient_convs,
343
- use_scale_shift_norm=use_scale_shift_norm,
344
- )
345
- ]
346
- ch = int(mult * model_channels)
347
- if ds in attention_resolutions:
348
- layers.append(
349
- AttentionBlock(
350
- ch,
351
- num_heads=num_heads,
352
- num_head_channels=num_head_channels,
353
- )
354
- )
355
- self.input_blocks.append(TimestepEmbedSequential(*layers))
356
- self._feature_size += ch
357
- input_block_chans.append(ch)
358
- if level != len(channel_mult) - 1:
359
- out_ch = ch
360
- self.input_blocks.append(
361
- TimestepEmbedSequential(
362
- Downsample(
363
- ch, conv_resample, out_channels=out_ch, factor=scale_factor, ksize=down_kernel, pad=0 if down_kernel == 1 else 1
364
- )
365
- )
366
- )
367
- ch = out_ch
368
- input_block_chans.append(ch)
369
- ds *= 2
370
- self._feature_size += ch
371
-
372
- self.middle_block = TimestepEmbedSequential(
373
- ResBlock(
374
- ch,
375
- time_embed_dim,
376
- dropout,
377
- kernel_size=kernel_size,
378
- efficient_config=efficient_convs,
379
- use_scale_shift_norm=use_scale_shift_norm,
380
- ),
381
- AttentionBlock(
382
- ch,
383
- num_heads=num_heads,
384
- num_head_channels=num_head_channels,
385
- ),
386
- ResBlock(
387
- ch,
388
- time_embed_dim,
389
- dropout,
390
- kernel_size=kernel_size,
391
- efficient_config=efficient_convs,
392
- use_scale_shift_norm=use_scale_shift_norm,
393
- ),
394
  )
395
- self._feature_size += ch
396
-
397
- self.output_blocks = nn.ModuleList([])
398
- for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
399
- for i in range(num_blocks + 1):
400
- ich = input_block_chans.pop()
401
- layers = [
402
- ResBlock(
403
- ch + ich,
404
- time_embed_dim,
405
- dropout,
406
- out_channels=int(model_channels * mult),
407
- kernel_size=kernel_size,
408
- efficient_config=efficient_convs,
409
- use_scale_shift_norm=use_scale_shift_norm,
410
- )
411
- ]
412
- ch = int(model_channels * mult)
413
- if ds in attention_resolutions:
414
- layers.append(
415
- AttentionBlock(
416
- ch,
417
- num_heads=num_heads_upsample,
418
- num_head_channels=num_head_channels,
419
- )
420
- )
421
- if level and i == num_blocks:
422
- out_ch = ch
423
- layers.append(
424
- Upsample(ch, conv_resample, out_channels=out_ch, factor=scale_factor)
425
- )
426
- ds //= 2
427
- self.output_blocks.append(TimestepEmbedSequential(*layers))
428
- self._feature_size += ch
429
 
430
  self.out = nn.Sequential(
431
- normalization(ch),
432
  nn.SiLU(),
433
- zero_module(nn.Conv1d(model_channels, out_channels, kernel_size, padding=padding)),
434
  )
435
 
436
- def fix_alignment(self, x, aligned_conditioning):
437
- """
438
- The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
439
- padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
440
- """
441
- cm = ceil_multiple(x.shape[-1], self.alignment_size)
442
- if cm != 0:
443
- pc = (cm-x.shape[-1])/x.shape[-1]
444
- x = F.pad(x, (0,cm-x.shape[-1]))
445
- # Also fix aligned_latent, which is aligned to x.
446
- if is_latent(aligned_conditioning):
447
- aligned_conditioning = torch.cat([aligned_conditioning,
448
- self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
449
- else:
450
- aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc*aligned_conditioning.shape[-1])))
451
- return x, aligned_conditioning
452
-
453
- def timestep_independent(self, aligned_conditioning, conditioning_input):
454
  # Shuffle aligned_latent to BxCxS format
455
  if is_latent(aligned_conditioning):
456
  aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
457
 
458
- with autocast(aligned_conditioning.device.type, enabled=self.enable_fp16):
459
- cond_emb = self.contextual_embedder(conditioning_input)
460
- if len(cond_emb.shape) == 3: # Just take the first element.
461
- cond_emb = cond_emb[:, :, 0]
462
- if is_latent(aligned_conditioning):
463
- code_emb = self.latent_converter(aligned_conditioning)
464
- else:
465
- code_emb = self.code_converter(aligned_conditioning)
466
- cond_emb = cond_emb.unsqueeze(-1).repeat(1, 1, code_emb.shape[-1])
467
- code_emb = self.conditioning_conv(torch.cat([cond_emb, code_emb], dim=1))
468
- return code_emb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
- def forward(self, x, timesteps, precomputed_aligned_embeddings, conditioning_free=False):
471
- assert x.shape[-1] % self.alignment_size == 0
472
 
473
- with autocast(x.device.type, enabled=self.enable_fp16):
474
- if conditioning_free:
475
- code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
476
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  code_emb = precomputed_aligned_embeddings
478
-
479
- time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
480
- code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
481
- code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
482
-
483
- first = True
484
- time_emb = time_emb.float()
485
- h = x
486
- hs = []
487
- for k, module in enumerate(self.input_blocks):
488
- if isinstance(module, nn.Conv1d):
489
- h_tok = F.interpolate(module(code_emb), size=(h.shape[-1]), mode='nearest')
490
- h = h + h_tok
491
  else:
492
- with autocast(x.device.type, enabled=self.enable_fp16 and not first):
493
- # First block has autocast disabled to allow a high precision signal to be properly vectorized.
494
- h = module(h, time_emb)
495
- hs.append(h)
496
- first = False
497
- h = self.middle_block(h, time_emb)
498
- for module in self.output_blocks:
499
- h = torch.cat([h, hs.pop()], dim=1)
500
- h = module(h, time_emb)
501
-
502
- # Last block also has autocast disabled for high-precision outputs.
503
- h = h.float()
504
- out = self.out(h)
 
 
 
505
 
 
 
 
 
 
 
 
 
 
 
 
506
  return out
507
 
508
 
509
  if __name__ == '__main__':
510
- clip = torch.randn(2, 1, 32868)
511
- aligned_latent = torch.randn(2,388,1024)
512
- aligned_sequence = torch.randint(0,8192,(2,388))
513
- cond = torch.randn(2, 1, 44000)
514
  ts = torch.LongTensor([600, 600])
515
- model = DiffusionTts(128,
516
- channel_mult=[1,1.5,2, 3, 4, 6, 8],
517
- num_res_blocks=[2, 2, 2, 2, 2, 2, 1],
518
- token_conditioning_resolutions=[1,4,16,64],
519
- attention_resolutions=[],
520
- num_heads=8,
521
- kernel_size=3,
522
- scale_factor=2,
523
- time_embed_dim_multiplier=4,
524
- super_sampling=False,
525
- efficient_convs=False)
526
  # Test with latent aligned conditioning
527
- o = model(clip, ts, aligned_latent, cond)
528
  # Test with sequence aligned conditioning
529
  o = model(clip, ts, aligned_sequence, cond)
 
 
 
 
 
 
 
 
1
  import math
2
+ import random
3
  from abc import abstractmethod
4
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  from torch import autocast
 
 
 
9
 
10
+ from models.arch_util import normalization, AttentionBlock
 
11
 
12
 
13
  def is_latent(t):
 
18
  return t.dtype == torch.long
19
 
20
 
 
 
 
 
 
 
 
21
  def timestep_embedding(timesteps, dim, max_period=10000):
22
  """
23
  Create sinusoidal timestep embeddings.
 
40
 
41
 
42
  class TimestepBlock(nn.Module):
 
 
 
 
43
  @abstractmethod
44
  def forward(self, x, emb):
45
  """
 
48
 
49
 
50
  class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
 
 
 
 
 
51
  def forward(self, x, emb):
52
  for layer in self:
53
  if isinstance(layer, TimestepBlock):
 
64
  emb_channels,
65
  dropout,
66
  out_channels=None,
67
+ dims=2,
68
  kernel_size=3,
69
  efficient_config=True,
70
  use_scale_shift_norm=False,
 
87
 
88
  self.emb_layers = nn.Sequential(
89
  nn.SiLU(),
90
+ nn.Linear(
91
  emb_channels,
92
  2 * self.out_channels if use_scale_shift_norm else self.out_channels,
93
  ),
 
96
  normalization(self.out_channels),
97
  nn.SiLU(),
98
  nn.Dropout(p=dropout),
99
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding),
 
 
100
  )
101
 
102
  if self.out_channels == channels:
 
105
  self.skip_connection = nn.Conv1d(channels, self.out_channels, eff_kernel, padding=eff_padding)
106
 
107
  def forward(self, x, emb):
 
 
 
 
 
 
 
 
 
 
 
 
108
  h = self.in_layers(x)
109
  emb_out = self.emb_layers(emb).type(h.dtype)
110
  while len(emb_out.shape) < len(h.shape):
 
120
  return self.skip_connection(x) + h
121
 
122
 
123
+ class DiffusionLayer(TimestepBlock):
124
+ def __init__(self, model_channels, dropout, num_heads):
125
+ super().__init__()
126
+ self.resblk = ResBlock(model_channels, model_channels, dropout, model_channels, dims=1, use_scale_shift_norm=True)
127
+ self.attn = AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ def forward(self, x, time_emb):
130
+ y = self.resblk(x, time_emb)
131
+ return self.attn(y)
132
+
133
+
134
+ class DiffusionTts(nn.Module):
135
  def __init__(
136
  self,
137
+ model_channels=512,
138
+ num_layers=8,
139
+ in_channels=100,
140
+ in_latent_channels=512,
141
  in_tokens=8193,
142
+ out_channels=200, # mean and variance
 
 
143
  dropout=0,
 
 
 
 
 
 
 
 
144
  use_fp16=False,
145
+ num_heads=16,
 
 
 
 
 
 
 
 
146
  # Parameters for regularization.
147
+ layer_drop=.1,
148
  unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
 
 
 
149
  ):
150
  super().__init__()
151
 
 
 
 
 
 
152
  self.in_channels = in_channels
153
  self.model_channels = model_channels
154
  self.out_channels = out_channels
 
155
  self.dropout = dropout
 
 
156
  self.num_heads = num_heads
 
 
 
 
157
  self.unconditioned_percentage = unconditioned_percentage
158
  self.enable_fp16 = use_fp16
159
+ self.layer_drop = layer_drop
 
 
 
160
 
161
+ self.inp_block = nn.Conv1d(in_channels, model_channels, 3, 1, 1)
162
  self.time_embed = nn.Sequential(
163
+ nn.Linear(model_channels, model_channels),
164
  nn.SiLU(),
165
+ nn.Linear(model_channels, model_channels),
166
  )
167
 
 
168
  # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
169
  # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
170
  # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
171
  # transformer network.
172
+ self.code_embedding = nn.Embedding(in_tokens, model_channels)
173
  self.code_converter = nn.Sequential(
174
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
175
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
176
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  )
178
+ self.code_norm = normalization(model_channels)
179
+ self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
180
+ self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
181
+ nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
182
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
183
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
184
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
185
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
186
+ AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
187
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
188
+ self.conditioning_timestep_integrator = TimestepEmbedSequential(
189
+ DiffusionLayer(model_channels, dropout, num_heads),
190
+ DiffusionLayer(model_channels, dropout, num_heads),
191
+ DiffusionLayer(model_channels, dropout, num_heads),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  )
193
+ self.integrating_conv = nn.Conv1d(model_channels*2, model_channels, kernel_size=1)
194
+ self.mel_head = nn.Conv1d(model_channels, in_channels, kernel_size=3, padding=1)
195
+
196
+ self.layers = nn.ModuleList([DiffusionLayer(model_channels, dropout, num_heads) for _ in range(num_layers)] +
197
+ [ResBlock(model_channels, model_channels, dropout, dims=1, use_scale_shift_norm=True) for _ in range(3)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  self.out = nn.Sequential(
200
+ normalization(model_channels),
201
  nn.SiLU(),
202
+ nn.Conv1d(model_channels, out_channels, 3, padding=1),
203
  )
204
 
205
+ def get_grad_norm_parameter_groups(self):
206
+ groups = {
207
+ 'minicoder': list(self.contextual_embedder.parameters()),
208
+ 'layers': list(self.layers.parameters()),
209
+ 'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
210
+ 'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
211
+ 'time_embed': list(self.time_embed.parameters()),
212
+ }
213
+ return groups
214
+
215
+ def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
 
 
 
 
 
 
 
216
  # Shuffle aligned_latent to BxCxS format
217
  if is_latent(aligned_conditioning):
218
  aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
219
 
220
+ # Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
221
+ speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
222
+ conditioning_input.shape) == 3 else conditioning_input
223
+ conds = []
224
+ for j in range(speech_conditioning_input.shape[1]):
225
+ conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
226
+ conds = torch.cat(conds, dim=-1)
227
+ cond_emb = conds.mean(dim=-1)
228
+ cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
229
+ if is_latent(aligned_conditioning):
230
+ code_emb = self.latent_converter(aligned_conditioning)
231
+ else:
232
+ code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
233
+ code_emb = self.code_converter(code_emb)
234
+ code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
235
+
236
+ unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
237
+ # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
238
+ if self.training and self.unconditioned_percentage > 0:
239
+ unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
240
+ device=code_emb.device) < self.unconditioned_percentage
241
+ code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
242
+ code_emb)
243
+ expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
244
+
245
+ if not return_code_pred:
246
+ return expanded_code_emb
247
+ else:
248
+ mel_pred = self.mel_head(expanded_code_emb)
249
+ # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
250
+ mel_pred = mel_pred * unconditioned_batches.logical_not()
251
+ return expanded_code_emb, mel_pred
252
 
 
 
253
 
254
+ def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
255
+ """
256
+ Apply the model to an input batch.
257
+
258
+ :param x: an [N x C x ...] Tensor of inputs.
259
+ :param timesteps: a 1-D batch of timesteps.
260
+ :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
261
+ :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
262
+ :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
263
+ :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
264
+ :return: an [N x C x ...] Tensor of outputs.
265
+ """
266
+ assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
267
+ assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
268
+
269
+ unused_params = []
270
+ if conditioning_free:
271
+ code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
272
+ unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
273
+ unused_params.extend(list(self.latent_converter.parameters()))
274
+ else:
275
+ if precomputed_aligned_embeddings is not None:
276
  code_emb = precomputed_aligned_embeddings
277
+ else:
278
+ code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
279
+ if is_latent(aligned_conditioning):
280
+ unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
 
 
 
 
 
 
 
 
 
281
  else:
282
+ unused_params.extend(list(self.latent_converter.parameters()))
283
+ unused_params.append(self.unconditioned_embedding)
284
+
285
+ time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
286
+ code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
287
+ x = self.inp_block(x)
288
+ x = torch.cat([x, code_emb], dim=1)
289
+ x = self.integrating_conv(x)
290
+ for i, lyr in enumerate(self.layers):
291
+ # Do layer drop where applicable. Do not drop first and last layers.
292
+ if self.training and self.layer_drop > 0 and i != 0 and i != (len(self.layers)-1) and random.random() < self.layer_drop:
293
+ unused_params.extend(list(lyr.parameters()))
294
+ else:
295
+ # First and last blocks will have autocast disabled for improved precision.
296
+ with autocast(x.device.type, enabled=self.enable_fp16 and i != 0):
297
+ x = lyr(x, time_emb)
298
 
299
+ x = x.float()
300
+ out = self.out(x)
301
+
302
+ # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
303
+ extraneous_addition = 0
304
+ for p in unused_params:
305
+ extraneous_addition = extraneous_addition + p.mean()
306
+ out = out + extraneous_addition * 0
307
+
308
+ if return_code_pred:
309
+ return out, mel_pred
310
  return out
311
 
312
 
313
  if __name__ == '__main__':
314
+ clip = torch.randn(2, 100, 400)
315
+ aligned_latent = torch.randn(2,388,512)
316
+ aligned_sequence = torch.randint(0,8192,(2,100))
317
+ cond = torch.randn(2, 100, 400)
318
  ts = torch.LongTensor([600, 600])
319
+ model = DiffusionTts(512, layer_drop=.3, unconditioned_percentage=.5)
 
 
 
 
 
 
 
 
 
 
320
  # Test with latent aligned conditioning
321
+ #o = model(clip, ts, aligned_latent, cond)
322
  # Test with sequence aligned conditioning
323
  o = model(clip, ts, aligned_sequence, cond)
324
+