marcop committed on
Commit
9838bd7
·
1 Parent(s): ade5cba

update to musika 44khz

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Marco Pasini
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎵
4
  colorFrom: purple
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 2.5.2
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-4.0
 
4
  colorFrom: purple
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.3.1
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-4.0
__pycache__/layers.cpython-39.pyc DELETED
Binary file (5.8 kB)
 
__pycache__/models.cpython-39.pyc DELETED
Binary file (13.5 kB)
 
__pycache__/parse_test.cpython-39.pyc DELETED
Binary file (3.47 kB)
 
__pycache__/utils.cpython-39.pyc DELETED
Binary file (16.5 kB)
 
app.py CHANGED
@@ -8,8 +8,8 @@ args = parse_args()
8
 
9
  # initialize networks
10
  M = Models_functions(args)
11
- models_ls_techno, models_ls_classical = M.get_networks()
12
 
13
  # test musika
14
  U = Utils_functions(args)
15
- U.render_gradio(models_ls_techno, models_ls_classical, train=False)
 
8
 
9
  # initialize networks
10
  M = Models_functions(args)
11
+ models_ls_1, models_ls_2, models_ls_3 = M.get_networks()
12
 
13
  # test musika
14
  U = Utils_functions(args)
15
+ U.render_gradio(models_ls_1, models_ls_2, models_ls_3, train=False)
checkpoints/{classical → ae}/dec.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:32e50f938685f56e1dc5e137e95c59472418194125e435cafb668584e65b0fcc
3
- size 29745512
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9aebbb210da611415c8e15b4bd7cc62c20d6a702169c709c3e6cc3912fb0aa84
3
+ size 50781776
checkpoints/{techno → ae}/dec2.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8558f7e78c0bddd5f816c1b1058cda6d1fb3af9d1b6394b98d36147c2f1abb13
3
- size 26799136
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffc4c0d6821e23247889a2252da86f9ba22ad6425d004790457e0c2e57e18c65
3
+ size 26616400
checkpoints/{classical/dec2.h5 → ae/enc.h5} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:643e1415f5ebf34cb067b2ccaced657f1f3e3f068810c2d130dfc3abb20c3cc2
3
- size 26799136
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e56882c9e62aa4703a6d088efc374bfc534a4a8b4f3cdc0418fab1f0da1795a7
3
+ size 19196936
checkpoints/{techno/dec.h5 → ae/enc2.h5} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3488041aadfda699de87ecaf10b9933264bae7d3a14c0df4ea1d02876017774
3
- size 29745512
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2f2f2b2b2879ad14690b2e64eb59f3b5dc3fbffaa88b37fdf5e8735b9dad305
3
+ size 15986152
checkpoints/{classical → metal}/gen_ema.h5 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1508fa1fad225f7c29ed64252b531e4462c97b0f0c80dd6d212014916f6261eb
3
- size 56431944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e1f59d7fc67df5fa9eb0ab4665bcc9fe27d346fcfe60a082c8937c622105d7f
3
+ size 62200720
checkpoints/misc/gen_ema.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f0cb86f0bf22a568f1ef1ad5782d11543f9779069de996b81907328eeb24c8e
3
+ size 62200720
checkpoints/techno/gen_ema.h5 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:58fb65920fd5a29c7e45646f972867998101f3d7feb36a34a273166d67fd9828
3
- size 56431944
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:582149f06f50197022758a7ca981a13164364871321ab2cf0662fc6ed7d634b0
3
+ size 62200304
layers.py CHANGED
@@ -1,12 +1,7 @@
1
  import tensorflow as tf
2
- import tensorflow.keras.backend as K
3
- from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense
4
  from tensorflow.python.eager import context
5
- from tensorflow.python.framework import tensor_shape
6
- from tensorflow.python.keras import activations, constraints, initializers, regularizers
7
- from tensorflow.python.keras.layers.convolutional import SeparableConv
8
  from tensorflow.python.ops import (
9
- array_ops,
10
  gen_math_ops,
11
  math_ops,
12
  sparse_ops,
@@ -19,9 +14,10 @@ def l2normalize(v, eps=1e-12):
19
 
20
 
21
  class ConvSN2D(tf.keras.layers.Conv2D):
22
- def __init__(self, filters, kernel_size, power_iterations=1, **kwargs):
23
  super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)
24
  self.power_iterations = power_iterations
 
25
 
26
  def build(self, input_shape):
27
  super(ConvSN2D, self).build(input_shape)
@@ -70,7 +66,11 @@ class ConvSN2D(tf.keras.layers.Conv2D):
70
  return outputs
71
 
72
 
73
- class DenseSN(Dense):
 
 
 
 
74
  def build(self, input_shape):
75
  super(DenseSN, self).build(input_shape)
76
 
@@ -79,7 +79,7 @@ class DenseSN(Dense):
79
  shape=tuple([1, self.kernel.shape.as_list()[-1]]),
80
  initializer=tf.initializers.RandomNormal(0, 1),
81
  trainable=False,
82
- dtype=self.dtype,
83
  )
84
 
85
  def compute_spectral_norm(self, W, new_u, W_shape):
@@ -116,6 +116,10 @@ class DenseSN(Dense):
116
 
117
 
118
  class AddNoise(tf.keras.layers.Layer):
 
 
 
 
119
  def build(self, input_shape):
120
  self.b = self.add_weight(
121
  shape=[
@@ -131,26 +135,24 @@ class AddNoise(tf.keras.layers.Layer):
131
  [tf.shape(inputs)[0], inputs.shape[1], inputs.shape[2], 1],
132
  mean=0.0,
133
  stddev=1.0,
134
- dtype=self.dtype,
135
  )
136
  output = inputs + self.b * rand
137
  return output
138
 
139
 
140
  class PosEnc(tf.keras.layers.Layer):
141
- def __init__(self, **kwargs):
142
  super(PosEnc, self).__init__(**kwargs)
 
143
 
144
  def call(self, inputs):
145
- # inputs shape: [bs,mel_bins,shape,1]
146
  pos = tf.repeat(
147
  tf.reshape(tf.range(inputs.shape[-3], dtype=tf.int32), [1, -1, 1, 1]),
148
  inputs.shape[-2],
149
  -2,
150
  )
151
- pos = tf.cast(tf.repeat(pos, tf.shape(inputs)[0], 0), self.dtype) / tf.cast(
152
- inputs.shape[-3], self.dtype
153
- )
154
  return tf.concat([inputs, pos], -1) # [bs,1,hop,2]
155
 
156
 
@@ -159,6 +161,3 @@ def flatten_hw(x, data_format="channels_last"):
159
  x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
160
 
161
  old_shape = tf.shape(x)
162
- new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]]
163
-
164
- return tf.reshape(x, new_shape)
 
1
  import tensorflow as tf
2
+ import tensorflow.python.keras.backend as K
 
3
  from tensorflow.python.eager import context
 
 
 
4
  from tensorflow.python.ops import (
 
5
  gen_math_ops,
6
  math_ops,
7
  sparse_ops,
 
14
 
15
 
16
  class ConvSN2D(tf.keras.layers.Conv2D):
17
+ def __init__(self, filters, kernel_size, power_iterations=1, datatype=tf.float32, **kwargs):
18
  super(ConvSN2D, self).__init__(filters, kernel_size, **kwargs)
19
  self.power_iterations = power_iterations
20
+ self.datatype = datatype
21
 
22
  def build(self, input_shape):
23
  super(ConvSN2D, self).build(input_shape)
 
66
  return outputs
67
 
68
 
69
+ class DenseSN(tf.keras.layers.Dense):
70
+ def __init__(self, datatype=tf.float32, **kwargs):
71
+ super(DenseSN, self).__init__(**kwargs)
72
+ self.datatype = datatype
73
+
74
  def build(self, input_shape):
75
  super(DenseSN, self).build(input_shape)
76
 
 
79
  shape=tuple([1, self.kernel.shape.as_list()[-1]]),
80
  initializer=tf.initializers.RandomNormal(0, 1),
81
  trainable=False,
82
+ dtype=self.datatype,
83
  )
84
 
85
  def compute_spectral_norm(self, W, new_u, W_shape):
 
116
 
117
 
118
  class AddNoise(tf.keras.layers.Layer):
119
+ def __init__(self, datatype=tf.float32, **kwargs):
120
+ super(AddNoise, self).__init__(**kwargs)
121
+ self.datatype = datatype
122
+
123
  def build(self, input_shape):
124
  self.b = self.add_weight(
125
  shape=[
 
135
  [tf.shape(inputs)[0], inputs.shape[1], inputs.shape[2], 1],
136
  mean=0.0,
137
  stddev=1.0,
138
+ dtype=self.datatype,
139
  )
140
  output = inputs + self.b * rand
141
  return output
142
 
143
 
144
  class PosEnc(tf.keras.layers.Layer):
145
+ def __init__(self, datatype=tf.float32, **kwargs):
146
  super(PosEnc, self).__init__(**kwargs)
147
+ self.datatype = datatype
148
 
149
  def call(self, inputs):
 
150
  pos = tf.repeat(
151
  tf.reshape(tf.range(inputs.shape[-3], dtype=tf.int32), [1, -1, 1, 1]),
152
  inputs.shape[-2],
153
  -2,
154
  )
155
+ pos = tf.cast(tf.repeat(pos, tf.shape(inputs)[0], 0), self.dtype) / tf.cast(inputs.shape[-3], self.datatype)
 
 
156
  return tf.concat([inputs, pos], -1) # [bs,1,hop,2]
157
 
158
 
 
161
  x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
162
 
163
  old_shape = tf.shape(x)
 
 
 
models.py CHANGED
@@ -1,60 +1,31 @@
1
  import numpy as np
2
  import tensorflow as tf
3
- import tensorflow_addons as tfa
4
- from tensorflow.keras import mixed_precision
5
- from tensorflow.keras.layers import (
6
- Add,
7
- BatchNormalization,
8
- Concatenate,
9
- Conv2D,
10
- Conv2DTranspose,
11
- Cropping1D,
12
- Cropping2D,
13
- Dense,
14
- Dot,
15
- Flatten,
16
- GlobalAveragePooling2D,
17
- Input,
18
- Lambda,
19
- LeakyReLU,
20
- Multiply,
21
- ReLU,
22
- Reshape,
23
- SeparableConv2D,
24
- UpSampling2D,
25
- ZeroPadding2D,
26
- )
27
- from tensorflow.keras.models import Model, Sequential
28
- from tensorflow.keras.optimizers import Adam
29
  from tensorflow.python.keras.utils.layer_utils import count_params
30
 
31
- from layers import ConvSN2D, DenseSN, PosEnc, AddNoise
32
 
33
 
34
  class Models_functions:
35
  def __init__(self, args):
36
 
37
  self.args = args
 
38
  if self.args.mixed_precision:
39
- self.mixed_precision = mixed_precision
40
- self.policy = mixed_precision.Policy("mixed_float16")
41
- mixed_precision.set_global_policy(self.policy)
42
  self.init = tf.keras.initializers.he_uniform()
43
 
44
  def conv_util(
45
- self,
46
- inp,
47
- filters,
48
- kernel_size=(1, 3),
49
- strides=(1, 1),
50
- noise=False,
51
- upsample=False,
52
- padding="same",
53
- bn=True,
54
  ):
55
 
56
  x = inp
57
 
 
 
 
 
58
  if upsample:
59
  x = tf.keras.layers.Conv2DTranspose(
60
  filters,
@@ -63,6 +34,7 @@ class Models_functions:
63
  activation="linear",
64
  padding=padding,
65
  kernel_initializer=self.init,
 
66
  )(x)
67
  else:
68
  x = tf.keras.layers.Conv2D(
@@ -72,19 +44,26 @@ class Models_functions:
72
  activation="linear",
73
  padding=padding,
74
  kernel_initializer=self.init,
 
75
  )(x)
76
 
77
  if noise:
78
- x = AddNoise()(x)
79
 
80
- if bn:
81
  x = tf.keras.layers.BatchNormalization()(x)
82
 
83
  x = tf.keras.activations.swish(x)
84
 
85
  return x
86
 
87
- def adain(self, x, emb):
 
 
 
 
 
 
88
  emb = tf.keras.layers.Conv2D(
89
  x.shape[-1],
90
  kernel_size=(1, 1),
@@ -93,32 +72,11 @@ class Models_functions:
93
  padding="same",
94
  kernel_initializer=self.init,
95
  use_bias=True,
 
96
  )(emb)
97
- x = x / (tf.math.reduce_std(x, -2, keepdims=True) + 1e-7)
98
  return x * emb
99
 
100
- def se_layer(self, x, filters):
101
- x = tf.reduce_mean(x, -2, keepdims=True)
102
- x = tf.keras.layers.Conv2D(
103
- filters,
104
- kernel_size=(1, 1),
105
- strides=(1, 1),
106
- activation="linear",
107
- padding="valid",
108
- kernel_initializer=self.init,
109
- use_bias=True,
110
- )(x)
111
- x = tf.keras.activations.swish(x)
112
- return tf.keras.layers.Conv2D(
113
- filters,
114
- kernel_size=(1, 1),
115
- strides=(1, 1),
116
- activation="sigmoid",
117
- padding="valid",
118
- kernel_initializer=self.init,
119
- use_bias=True,
120
- )(x)
121
-
122
  def conv_util_gen(
123
  self,
124
  inp,
@@ -129,6 +87,7 @@ class Models_functions:
129
  upsample=False,
130
  emb=None,
131
  se1=None,
 
132
  ):
133
 
134
  x = inp
@@ -142,6 +101,7 @@ class Models_functions:
142
  padding="same",
143
  kernel_initializer=self.init,
144
  use_bias=True,
 
145
  )(x)
146
  else:
147
  x = tf.keras.layers.Conv2D(
@@ -152,24 +112,22 @@ class Models_functions:
152
  padding="same",
153
  kernel_initializer=self.init,
154
  use_bias=True,
 
155
  )(x)
156
 
157
  if noise:
158
- x = AddNoise()(x)
159
 
160
  if emb is not None:
161
- x = self.adain(x, emb)
162
  else:
163
- x = tf.keras.layers.BatchNormalization()(x)
164
-
165
- x1 = tf.keras.activations.swish(x)
166
 
167
- if se1 is not None:
168
- x1 = x1 * se1
169
 
170
- return x1
171
 
172
- def res_block_disc(self, inp, filters, kernel_size=(1, 3), kernel_size_2=None, strides=(1, 1)):
173
 
174
  if kernel_size_2 is None:
175
  kernel_size_2 = kernel_size
@@ -181,6 +139,7 @@ class Models_functions:
181
  activation="linear",
182
  padding="same",
183
  kernel_initializer=self.init,
 
184
  )(inp)
185
  x = tf.keras.layers.LeakyReLU(0.2)(x)
186
  x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
@@ -191,6 +150,7 @@ class Models_functions:
191
  activation="linear",
192
  padding="same",
193
  kernel_initializer=self.init,
 
194
  )(x)
195
  x = tf.keras.layers.LeakyReLU(0.2)(x)
196
  x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
@@ -207,28 +167,32 @@ class Models_functions:
207
  padding="same",
208
  kernel_initializer=self.init,
209
  use_bias=False,
 
210
  )(inp)
211
 
212
  return x + inp
213
 
214
  def build_encoder2(self):
215
 
216
- dim = 128
217
-
218
- inpf = Input((1, self.args.shape, dim))
219
 
220
- inpfls = tf.split(inpf, 16, -2)
221
  inpb = tf.concat(inpfls, 0)
222
 
223
- g0 = self.conv_util(inpb, 256, kernel_size=(1, 1), strides=(1, 1), padding="valid")
224
- g1 = self.conv_util(g0, 256 + 256, kernel_size=(1, 3), strides=(1, 1), padding="valid")
225
- g2 = self.conv_util(g1, 512 + 128, kernel_size=(1, 3), strides=(1, 1), padding="valid")
226
- g3 = self.conv_util(g2, 512 + 128, kernel_size=(1, 1), strides=(1, 1), padding="valid")
227
- g4 = self.conv_util(g3, 512 + 256, kernel_size=(1, 3), strides=(1, 1), padding="valid")
228
- g5 = self.conv_util(g4, 512 + 256, kernel_size=(1, 2), strides=(1, 1), padding="valid")
 
 
 
 
 
229
 
230
  g = tf.keras.layers.Conv2D(
231
- 64,
232
  kernel_size=(1, 1),
233
  strides=1,
234
  padding="valid",
@@ -237,50 +201,61 @@ class Models_functions:
237
  activation="tanh",
238
  )(g5)
239
 
240
- gls = tf.split(g, 16, 0)
241
  g = tf.concat(gls, -2)
242
  gls = tf.split(g, 2, -2)
243
  g = tf.concat(gls, 0)
244
 
245
  gf = tf.cast(g, tf.float32)
246
- return Model(inpf, gf, name="ENC2")
247
 
248
- def build_decoder2(self):
249
 
250
- dim = 128
251
- bottledim = 64
252
 
253
- inpf = Input((1, self.args.shape // 16, bottledim))
254
 
255
  g = inpf
256
 
 
 
 
257
  g = self.conv_util(
258
  g,
259
- 512 + 128 + 128,
260
  kernel_size=(1, 4),
 
 
 
 
 
 
 
 
 
261
  strides=(1, 1),
262
  upsample=False,
263
  noise=True,
 
 
 
 
 
 
 
264
  )
265
  g = self.conv_util(
266
  g,
267
- 512 + 128 + 128,
268
  kernel_size=(1, 4),
269
  strides=(1, 2),
270
  upsample=True,
271
  noise=True,
 
272
  )
273
- g = self.conv_util(g, 512 + 128, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True)
274
- g = self.conv_util(g, 512, kernel_size=(1, 4), strides=(1, 1), upsample=False, noise=True)
275
- g = self.conv_util(g, 256 + 128, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True)
276
 
277
  gf = tf.keras.layers.Conv2D(
278
- dim,
279
- kernel_size=(1, 1),
280
- strides=1,
281
- padding="same",
282
- activation="tanh",
283
- kernel_initializer=self.init,
284
  )(g)
285
 
286
  gfls = tf.split(gf, 2, 0)
@@ -288,131 +263,77 @@ class Models_functions:
288
 
289
  gf = tf.cast(gf, tf.float32)
290
 
291
- return Model(inpf, gf, name="DEC2")
292
 
293
  def build_encoder(self):
294
 
295
  dim = ((4 * self.args.hop) // 2) + 1
296
 
297
- inpf = Input((dim, self.args.shape, 1))
298
 
299
  ginp = tf.transpose(inpf, [0, 3, 2, 1])
300
 
301
- g0 = self.conv_util(
302
- ginp,
303
- self.args.hop * 2 + 32,
304
- kernel_size=(1, 1),
305
- strides=(1, 1),
306
- padding="valid",
307
- )
308
-
309
- g = self.conv_util(
310
- g0,
311
- self.args.hop * 2 + 64,
312
- kernel_size=(1, 1),
313
- strides=(1, 1),
314
- padding="valid",
315
- )
316
- g = self.conv_util(
317
- g,
318
- self.args.hop * 2 + 64 + 64,
319
- kernel_size=(1, 1),
320
- strides=(1, 1),
321
- padding="valid",
322
- )
323
- g = self.conv_util(
324
- g,
325
- self.args.hop * 2 + 128 + 64,
326
- kernel_size=(1, 1),
327
- strides=(1, 1),
328
- padding="valid",
329
- )
330
- g = self.conv_util(
331
- g,
332
- self.args.hop * 2 + 128 + 128,
333
- kernel_size=(1, 1),
334
- strides=(1, 1),
335
- padding="valid",
336
- )
337
 
338
  g = tf.keras.layers.Conv2D(
339
- 128,
340
- kernel_size=(1, 1),
341
- strides=1,
342
- padding="valid",
343
- kernel_initializer=self.init,
344
- )(g)
345
- gb = tf.keras.activations.tanh(g)
346
 
347
- gbls = tf.split(gb, 2, -2)
348
- gb = tf.concat(gbls, 0)
349
 
350
- gb = tf.cast(gb, tf.float32)
351
- return Model(inpf, gb, name="ENC")
352
 
353
  def build_decoder(self):
354
 
355
  dim = ((4 * self.args.hop) // 2) + 1
356
 
357
- inpf = Input((1, self.args.shape // 2, 128))
358
 
359
  g = inpf
360
 
361
- g0 = self.conv_util(g, self.args.hop * 3, kernel_size=(1, 1), strides=(1, 1), noise=True)
 
 
 
 
362
 
363
- g1 = self.conv_util(g0, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 2), noise=True)
364
- g2 = self.conv_util(
365
- g1,
366
- self.args.hop + self.args.hop // 2,
367
- kernel_size=(1, 3),
368
- strides=(1, 2),
369
- noise=True,
370
  )
371
- g = self.conv_util(
372
- g2,
373
- self.args.hop + self.args.hop // 4,
374
- kernel_size=(1, 3),
375
- strides=(1, 2),
376
- noise=True,
377
  )
378
-
379
- g = self.conv_util(
380
- g,
381
- self.args.hop + self.args.hop // 2,
382
- kernel_size=(1, 4),
383
- strides=(1, 2),
384
- upsample=True,
385
- noise=True,
386
  )
 
 
 
 
 
 
387
  g = self.conv_util(
388
- g + g2,
389
- self.args.hop * 2,
390
- kernel_size=(1, 4),
391
- strides=(1, 2),
392
- upsample=True,
393
- noise=True,
394
  )
395
  g = self.conv_util(
396
- g + g1,
397
- self.args.hop * 3,
398
- kernel_size=(1, 4),
399
- strides=(1, 2),
400
- upsample=True,
401
- noise=True,
402
  )
403
-
404
- g = self.conv_util(g + g0, self.args.hop * 5, kernel_size=(1, 1), strides=(1, 1), noise=True)
405
-
406
- g = Conv2D(
407
- dim * 2,
408
- kernel_size=(1, 1),
409
- strides=(1, 1),
410
- kernel_initializer=self.init,
411
- padding="same",
412
- )(g)
413
- g = tf.clip_by_value(g, -1.0, 1.0)
414
-
415
- gf, pf = tf.split(g, 2, -1)
416
 
417
  gfls = tf.split(gf, self.args.shape // self.args.window, 0)
418
  gf = tf.concat(gfls, -2)
@@ -426,128 +347,60 @@ class Models_functions:
426
  s = tf.cast(tf.squeeze(s, -1), tf.float32)
427
  p = tf.cast(tf.squeeze(p, -1), tf.float32)
428
 
429
- return Model(inpf, [s, p], name="DEC")
430
 
431
  def build_critic(self):
432
 
433
- sinp = Input(shape=(1, self.args.latlen, self.args.latdepth * 2))
434
-
435
- dim = 64 * 2
436
-
437
- sf = tf.keras.layers.Conv2D(
438
- self.args.latdepth * 4,
439
- kernel_size=(1, 1),
440
- strides=(1, 1),
441
- activation="linear",
442
- padding="valid",
443
- kernel_initializer=self.init,
444
- use_bias=False,
445
- trainable=False,
446
- )(sinp)
447
 
448
  sf = tf.keras.layers.Conv2D(
449
- 256 + 128,
450
- kernel_size=(1, 3),
451
  strides=(1, 2),
452
  activation="linear",
453
  padding="same",
454
  kernel_initializer=self.init,
455
- )(sf)
 
456
  sf = tf.keras.layers.LeakyReLU(0.2)(sf)
457
- sf = self.res_block_disc(sf, 256 + 128 + 128, kernel_size=(1, 3), strides=(1, 2))
458
- sf = self.res_block_disc(sf, 512 + 128, kernel_size=(1, 3), strides=(1, 2))
459
- sf = self.res_block_disc(sf, 512 + 256, kernel_size=(1, 3), strides=(1, 2))
460
- sf = self.res_block_disc(sf, 512 + 128 + 256, kernel_size=(1, 3), strides=(1, 2))
461
- sfo = self.res_block_disc(sf, 512 + 512, kernel_size=(1, 3), strides=(1, 2), kernel_size_2=(1, 1))
462
- sf = sfo
463
 
464
- gf = tf.keras.layers.Dense(1, activation="linear", use_bias=True, kernel_initializer=self.init)(Flatten()(sf))
465
 
466
- gf = tf.cast(gf, tf.float32)
467
- sfo = tf.cast(sfo, tf.float32)
468
 
469
- return Model(sinp, [gf, sfo], name="C")
470
 
471
- def build_critic_rec(self):
472
 
473
- sinp = Input(shape=(1, self.args.latlen // 64, 512 + 512))
 
 
 
474
 
475
- dim = self.args.latdepth * 2
476
-
477
- sf = tf.keras.layers.Conv2DTranspose(
478
- 512,
479
- kernel_size=(1, 4),
480
- strides=(1, 2),
481
- activation="linear",
482
- padding="same",
483
- kernel_initializer=self.init,
484
- )(sinp)
485
- sf = tf.keras.layers.LeakyReLU(0.2)(sf)
486
-
487
- sf = tf.keras.layers.Conv2DTranspose(
488
- 256 + 128 + 64,
489
- kernel_size=(1, 4),
490
- strides=(1, 2),
491
- activation="linear",
492
- padding="same",
493
- kernel_initializer=self.init,
494
- )(sf)
495
- sf = tf.keras.layers.LeakyReLU(0.2)(sf)
496
- sf = tf.keras.layers.Conv2DTranspose(
497
- 256 + 128,
498
- kernel_size=(1, 4),
499
- strides=(1, 2),
500
- activation="linear",
501
- padding="same",
502
- kernel_initializer=self.init,
503
- )(sf)
504
- sf = tf.keras.layers.LeakyReLU(0.2)(sf)
505
- sf = tf.keras.layers.Conv2DTranspose(
506
- 256 + 64,
507
- kernel_size=(1, 4),
508
- strides=(1, 2),
509
- activation="linear",
510
- padding="same",
511
- kernel_initializer=self.init,
512
- )(sf)
513
- sf = tf.keras.layers.LeakyReLU(0.2)(sf)
514
- sf = tf.keras.layers.Conv2DTranspose(
515
- 256,
516
- kernel_size=(1, 4),
517
- strides=(1, 2),
518
- activation="linear",
519
- padding="same",
520
- kernel_initializer=self.init,
521
- )(sf)
522
- sf = tf.keras.layers.LeakyReLU(0.2)(sf)
523
- sf = tf.keras.layers.Conv2DTranspose(
524
- 128 + 64,
525
- kernel_size=(1, 4),
526
- strides=(1, 2),
527
  activation="linear",
528
  padding="same",
529
  kernel_initializer=self.init,
 
530
  )(sf)
531
  sf = tf.keras.layers.LeakyReLU(0.2)(sf)
532
 
533
- gf = tf.keras.layers.Conv2D(
534
- dim,
535
- kernel_size=(1, 1),
536
- strides=(1, 1),
537
- activation="tanh",
538
- padding="same",
539
- kernel_initializer=self.init,
540
- )(sf)
541
 
542
  gf = tf.cast(gf, tf.float32)
543
 
544
- return Model(sinp, gf, name="CR")
545
 
546
  def build_generator(self):
547
 
548
  dim = self.args.latdepth * 2
549
 
550
- inpf = Input((self.args.latlen, self.args.latdepth * 2))
551
 
552
  inpfls = tf.split(inpf, 2, -2)
553
  inpb = tf.concat(inpfls, 0)
@@ -558,112 +411,213 @@ class Models_functions:
558
  inp3 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp2)
559
  inp4 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp3)
560
  inp5 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp4)
561
- inp6 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp5)
 
562
 
563
- g = tf.keras.layers.Dense(
564
- 4 * (512 + 256 + 128),
565
- activation="linear",
566
- use_bias=True,
567
- kernel_initializer=self.init,
568
- )(Flatten()(inp6))
569
- g = tf.keras.layers.Reshape((1, 4, 512 + 256 + 128))(g)
570
- g = AddNoise()(g)
571
- g = self.adain(g, inp5)
572
- g = tf.keras.activations.swish(g)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
- g = self.conv_util_gen(
575
- g,
576
- 512 + 256,
577
- kernel_size=(1, 4),
578
- strides=(1, 2),
579
- upsample=True,
580
- noise=True,
581
- emb=inp4,
582
- )
583
- g1 = self.conv_util_gen(
584
- g,
585
- 512 + 256,
586
- kernel_size=(1, 1),
587
- strides=(1, 1),
588
- upsample=False,
589
- noise=True,
590
- emb=inp4,
591
- )
592
  g2 = self.conv_util_gen(
593
  g1,
594
- 512 + 128,
595
  kernel_size=(1, 4),
596
  strides=(1, 2),
597
  upsample=True,
598
  noise=True,
599
  emb=inp3,
 
600
  )
601
- g2b = self.conv_util_gen(
 
602
  g2,
603
- 512 + 128,
604
- kernel_size=(1, 3),
605
  strides=(1, 1),
606
  upsample=False,
607
  noise=True,
608
  emb=inp3,
 
609
  )
 
 
 
 
 
 
 
 
 
 
 
 
610
  g3 = self.conv_util_gen(
611
- g2b,
612
- 256 + 256,
613
  kernel_size=(1, 4),
614
  strides=(1, 2),
615
  upsample=True,
616
  noise=True,
617
  emb=inp2,
618
- se1=self.se_layer(g, 256 + 256),
619
  )
 
620
  g3 = self.conv_util_gen(
621
  g3,
622
- 256 + 256,
623
- kernel_size=(1, 3),
624
  strides=(1, 1),
625
  upsample=False,
626
  noise=True,
627
  emb=inp2,
628
- se1=self.se_layer(g1, 256 + 256),
629
  )
 
 
 
 
 
 
 
 
 
 
 
 
630
  g4 = self.conv_util_gen(
631
  g3,
632
- 256 + 128,
633
  kernel_size=(1, 4),
634
  strides=(1, 2),
635
  upsample=True,
636
  noise=True,
637
  emb=inp1,
638
- se1=self.se_layer(g2, 256 + 128),
639
  )
 
640
  g4 = self.conv_util_gen(
641
  g4,
642
- 256 + 128,
643
- kernel_size=(1, 3),
644
  strides=(1, 1),
645
  upsample=False,
646
  noise=True,
647
  emb=inp1,
648
- se1=self.se_layer(g2b, 256 + 128),
649
  )
 
 
 
 
 
 
 
 
 
 
 
 
650
  g5 = self.conv_util_gen(
651
  g4,
652
- 256,
653
  kernel_size=(1, 4),
654
  strides=(1, 2),
655
  upsample=True,
656
  noise=True,
657
  emb=tf.expand_dims(tf.cast(inpb, dtype=self.args.datatype), -3),
 
658
  )
659
 
660
  gf = tf.keras.layers.Conv2D(
661
- dim,
662
- kernel_size=(1, 1),
663
- strides=(1, 1),
664
- kernel_initializer=self.init,
665
- padding="same",
666
- activation="tanh",
667
  )(g5)
668
 
669
  gfls = tf.split(gf, 2, 0)
@@ -671,7 +625,7 @@ class Models_functions:
671
 
672
  gf = tf.cast(gf, tf.float32)
673
 
674
- return Model(inpf, gf, name="GEN")
675
 
676
  # Load past models from path to resume training or test
677
  def load(self, path, load_dec=False):
@@ -681,12 +635,13 @@ class Models_functions:
681
  dec = self.build_decoder()
682
  enc2 = self.build_encoder2()
683
  dec2 = self.build_decoder2()
684
- critic_rec = self.build_critic_rec()
685
  gen_ema = self.build_generator()
686
 
 
 
687
  if self.args.mixed_precision:
688
- opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
689
- opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
690
  else:
691
  opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
692
  opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)
@@ -694,9 +649,11 @@ class Models_functions:
694
  if load_dec:
695
  dec.load_weights(self.args.dec_path + "/dec.h5")
696
  dec2.load_weights(self.args.dec_path + "/dec2.h5")
 
 
697
 
698
  else:
699
- grad_vars = critic.trainable_weights + critic_rec.trainable_weights
700
  zero_grads = [tf.zeros_like(w) for w in grad_vars]
701
  opt_disc.apply_gradients(zip(zero_grads, grad_vars))
702
 
@@ -707,16 +664,15 @@ class Models_functions:
707
  if not self.args.testing:
708
  opt_disc.set_weights(np.load(path + "/opt_disc.npy", allow_pickle=True))
709
  opt_dec.set_weights(np.load(path + "/opt_dec.npy", allow_pickle=True))
710
-
711
- if not self.args.testing:
712
  critic.load_weights(path + "/critic.h5")
713
  gen.load_weights(path + "/gen.h5")
714
- enc.load_weights(path + "/enc.h5")
715
- enc2.load_weights(path + "/enc2.h5")
716
- critic_rec.load_weights(path + "/critic_rec.h5")
717
  gen_ema.load_weights(path + "/gen_ema.h5")
718
- dec.load_weights(path + "/dec.h5")
719
- dec2.load_weights(path + "/dec2.h5")
 
 
720
 
721
  return (
722
  critic,
@@ -725,9 +681,9 @@ class Models_functions:
725
  dec,
726
  enc2,
727
  dec2,
728
- critic_rec,
729
  gen_ema,
730
  [opt_dec, opt_disc],
 
731
  )
732
 
733
  def build(self):
@@ -737,18 +693,19 @@ class Models_functions:
737
  dec = self.build_decoder()
738
  enc2 = self.build_encoder2()
739
  dec2 = self.build_decoder2()
740
- critic_rec = self.build_critic_rec()
741
  gen_ema = self.build_generator()
742
 
 
 
743
  gen_ema = tf.keras.models.clone_model(gen)
744
  gen_ema.set_weights(gen.get_weights())
745
 
746
  if self.args.mixed_precision:
747
- opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
748
- opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.9))
749
  else:
750
- opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
751
- opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)
752
 
753
  return (
754
  critic,
@@ -757,9 +714,9 @@ class Models_functions:
757
  dec,
758
  enc2,
759
  dec2,
760
- critic_rec,
761
  gen_ema,
762
  [opt_dec, opt_disc],
 
763
  )
764
 
765
  def get_networks(self):
@@ -767,65 +724,60 @@ class Models_functions:
767
  critic,
768
  gen,
769
  enc,
770
- dec_techno,
771
  enc2,
772
- dec2_techno,
773
- critic_rec,
774
- gen_ema_techno,
775
  [opt_dec, opt_disc],
776
- ) = self.load(self.args.load_path_techno, load_dec=False)
777
- print(f"Techno networks loaded from {self.args.load_path_techno}")
 
778
 
779
  (
780
  critic,
781
  gen,
782
  enc,
783
- dec_classical,
784
  enc2,
785
- dec2_classical,
786
- critic_rec,
787
- gen_ema_classical,
788
  [opt_dec, opt_disc],
789
- ) = self.load(self.args.load_path_classical, load_dec=False)
790
- print(f"Classical networks loaded from {self.args.load_path_classical}")
 
791
 
792
- return [critic, gen, enc, dec_techno, enc2, dec2_techno, critic_rec, gen_ema_techno, [opt_dec, opt_disc]], [
793
  critic,
794
  gen,
795
  enc,
796
- dec_classical,
797
  enc2,
798
- dec2_classical,
799
- critic_rec,
800
- gen_ema_classical,
801
  [opt_dec, opt_disc],
802
- ]
 
 
 
 
 
 
 
 
803
 
804
  def initialize_networks(self):
805
 
806
- [critic, gen, enc, dec_techno, enc2, dec2_techno, critic_rec, gen_ema_techno, [opt_dec, opt_disc]], [
807
- critic,
808
- gen,
809
- enc,
810
- dec_classical,
811
- enc2,
812
- dec2_classical,
813
- critic_rec,
814
- gen_ema_classical,
815
- [opt_dec, opt_disc],
816
- ] = self.get_networks()
817
 
818
- print(f"Generator params: {count_params(gen_ema_techno.trainable_variables)}")
819
- print(f"Decoder params: {count_params(dec_techno.trainable_variables+dec2_techno.trainable_variables)}")
820
 
821
- return [critic, gen, enc, dec_techno, enc2, dec2_techno, critic_rec, gen_ema_techno, [opt_dec, opt_disc]], [
822
- critic,
823
- gen,
824
- enc,
825
- dec_classical,
826
- enc2,
827
- dec2_classical,
828
- critic_rec,
829
- gen_ema_classical,
830
- [opt_dec, opt_disc],
831
- ]
 
1
  import numpy as np
2
  import tensorflow as tf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from tensorflow.python.keras.utils.layer_utils import count_params
4
 
5
+ from layers import AddNoise
6
 
7
 
8
  class Models_functions:
9
def __init__(self, args):
    """Keep the run configuration and prepare shared Keras settings.

    When ``args.mixed_precision`` is truthy, the Keras mixed-precision
    module is stored on the instance, a ``"mixed_float16"`` policy is
    created and installed as the global policy. A He-uniform initializer
    is always stored as ``self.init`` for the layer builders to share.
    """
    self.args = args

    if args.mixed_precision:
        mp = tf.keras.mixed_precision
        self.mixed_precision = mp
        self.policy = mp.Policy("mixed_float16")
        # Global switch: affects every layer built after this point.
        mp.set_global_policy(self.policy)

    self.init = tf.keras.initializers.he_uniform()
18
 
19
  def conv_util(
20
+ self, inp, filters, kernel_size=(1, 3), strides=(1, 1), noise=False, upsample=False, padding="same", bnorm=True
 
 
 
 
 
 
 
 
21
  ):
22
 
23
  x = inp
24
 
25
+ bias = True
26
+ if bnorm:
27
+ bias = False
28
+
29
  if upsample:
30
  x = tf.keras.layers.Conv2DTranspose(
31
  filters,
 
34
  activation="linear",
35
  padding=padding,
36
  kernel_initializer=self.init,
37
+ use_bias=bias,
38
  )(x)
39
  else:
40
  x = tf.keras.layers.Conv2D(
 
44
  activation="linear",
45
  padding=padding,
46
  kernel_initializer=self.init,
47
+ use_bias=bias,
48
  )(x)
49
 
50
  if noise:
51
+ x = AddNoise(self.args.datatype)(x)
52
 
53
+ if bnorm:
54
  x = tf.keras.layers.BatchNormalization()(x)
55
 
56
  x = tf.keras.activations.swish(x)
57
 
58
  return x
59
 
60
def pixel_shuffle(self, x, factor=2):
    """Trade channels for width: fold ``factor`` channel groups into axis 2.

    ``x`` must be 4-D with statically known sizes on axes 1-3 (the code
    names them height, width and channels); only the leading batch axis is
    read dynamically. The result has shape
    ``[batch, height, width * factor, channels // factor]``.

    Args:
        x: 4-D tensor to rearrange.
        factor: Width upsampling factor; must divide the channel count.

    Raises:
        ValueError: if the channel count is unknown or not divisible by
            ``factor`` (previously this surfaced later as an opaque
            reshape error).
    """
    h_dim, w_dim, c_dim = x.shape[1], x.shape[2], x.shape[3]
    if c_dim is None or c_dim % factor != 0:
        raise ValueError(
            f"pixel_shuffle needs a static channel count divisible by factor={factor}, got {c_dim}"
        )

    bs_dim = tf.shape(x)[0]
    out_channels = c_dim // factor

    # Split channels into (out_channels, factor), move `factor` next to the
    # width axis, then merge it into the width.
    x = tf.reshape(x, [bs_dim, h_dim, w_dim, out_channels, factor])
    x = tf.transpose(x, [0, 1, 2, 4, 3])
    return tf.reshape(x, [bs_dim, h_dim, w_dim * factor, out_channels])
65
+
66
+ def adain(self, x, emb, name):
67
  emb = tf.keras.layers.Conv2D(
68
  x.shape[-1],
69
  kernel_size=(1, 1),
 
72
  padding="same",
73
  kernel_initializer=self.init,
74
  use_bias=True,
75
+ name=name,
76
  )(emb)
77
+ x = x / (tf.math.reduce_std(x, -2, keepdims=True) + 1e-5)
78
  return x * emb
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def conv_util_gen(
81
  self,
82
  inp,
 
87
  upsample=False,
88
  emb=None,
89
  se1=None,
90
+ name="0",
91
  ):
92
 
93
  x = inp
 
101
  padding="same",
102
  kernel_initializer=self.init,
103
  use_bias=True,
104
+ name=name + "c",
105
  )(x)
106
  else:
107
  x = tf.keras.layers.Conv2D(
 
112
  padding="same",
113
  kernel_initializer=self.init,
114
  use_bias=True,
115
+ name=name + "c",
116
  )(x)
117
 
118
  if noise:
119
+ x = AddNoise(self.args.datatype, name=name + "r")(x)
120
 
121
  if emb is not None:
122
+ x = self.adain(x, emb, name=name + "ai")
123
  else:
124
+ x = tf.keras.layers.BatchNormalization(name=name + "bn")(x)
 
 
125
 
126
+ x = tf.keras.activations.swish(x)
 
127
 
128
+ return x
129
 
130
+ def res_block_disc(self, inp, filters, kernel_size=(1, 3), kernel_size_2=None, strides=(1, 1), name="0"):
131
 
132
  if kernel_size_2 is None:
133
  kernel_size_2 = kernel_size
 
139
  activation="linear",
140
  padding="same",
141
  kernel_initializer=self.init,
142
+ name=name + "c0",
143
  )(inp)
144
  x = tf.keras.layers.LeakyReLU(0.2)(x)
145
  x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
 
150
  activation="linear",
151
  padding="same",
152
  kernel_initializer=self.init,
153
+ name=name + "c1",
154
  )(x)
155
  x = tf.keras.layers.LeakyReLU(0.2)(x)
156
  x = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * x
 
167
  padding="same",
168
  kernel_initializer=self.init,
169
  use_bias=False,
170
+ name=name + "c3",
171
  )(inp)
172
 
173
  return x + inp
174
 
175
  def build_encoder2(self):
176
 
177
+ inpf = tf.keras.layers.Input((1, self.args.shape, self.args.hop // 4))
 
 
178
 
179
+ inpfls = tf.split(inpf, 8, -2)
180
  inpb = tf.concat(inpfls, 0)
181
 
182
+ g0 = self.conv_util(inpb, self.args.hop, kernel_size=(1, 3), strides=(1, 1), padding="same", bnorm=False)
183
+ g1 = self.conv_util(
184
+ g0, self.args.hop + self.args.hop // 2, kernel_size=(1, 3), strides=(1, 2), padding="valid", bnorm=False
185
+ )
186
+ g2 = self.conv_util(
187
+ g1, self.args.hop + self.args.hop // 2, kernel_size=(1, 3), strides=(1, 1), padding="same", bnorm=False
188
+ )
189
+ g3 = self.conv_util(g2, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 2), padding="valid", bnorm=False)
190
+ g4 = self.conv_util(g3, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 1), padding="same", bnorm=False)
191
+ g5 = self.conv_util(g4, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 1), padding="valid", bnorm=False)
192
+ g5 = self.conv_util(g5, self.args.hop * 3, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)
193
 
194
  g = tf.keras.layers.Conv2D(
195
+ self.args.latdepth,
196
  kernel_size=(1, 1),
197
  strides=1,
198
  padding="valid",
 
201
  activation="tanh",
202
  )(g5)
203
 
204
+ gls = tf.split(g, 8, 0)
205
  g = tf.concat(gls, -2)
206
  gls = tf.split(g, 2, -2)
207
  g = tf.concat(gls, 0)
208
 
209
  gf = tf.cast(g, tf.float32)
 
210
 
211
+ return tf.keras.Model(inpf, gf, name="ENC2")
212
 
213
+ def build_decoder2(self):
 
214
 
215
+ inpf = tf.keras.layers.Input((1, self.args.shape // 32, self.args.latdepth))
216
 
217
  g = inpf
218
 
219
+ g = self.conv_util(
220
+ g, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 1), upsample=False, noise=True, bnorm=False
221
+ )
222
  g = self.conv_util(
223
  g,
224
+ self.args.hop * 2 + self.args.hop // 2,
225
  kernel_size=(1, 4),
226
+ strides=(1, 2),
227
+ upsample=True,
228
+ noise=True,
229
+ bnorm=False,
230
+ )
231
+ g = self.conv_util(
232
+ g,
233
+ self.args.hop * 2 + self.args.hop // 2,
234
+ kernel_size=(1, 3),
235
  strides=(1, 1),
236
  upsample=False,
237
  noise=True,
238
+ bnorm=False,
239
+ )
240
+ g = self.conv_util(
241
+ g, self.args.hop * 2, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
242
+ )
243
+ g = self.conv_util(
244
+ g, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 1), upsample=False, noise=True, bnorm=False
245
  )
246
  g = self.conv_util(
247
  g,
248
+ self.args.hop + self.args.hop // 2,
249
  kernel_size=(1, 4),
250
  strides=(1, 2),
251
  upsample=True,
252
  noise=True,
253
+ bnorm=False,
254
  )
255
+ g = self.conv_util(g, self.args.hop, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False)
 
 
256
 
257
  gf = tf.keras.layers.Conv2D(
258
+ self.args.hop // 4, kernel_size=(1, 1), strides=1, padding="same", kernel_initializer=self.init, name="cout"
 
 
 
 
 
259
  )(g)
260
 
261
  gfls = tf.split(gf, 2, 0)
 
263
 
264
  gf = tf.cast(gf, tf.float32)
265
 
266
+ return tf.keras.Model(inpf, gf, name="DEC2")
267
 
268
def build_encoder(self):
    """Assemble the first-stage encoder as a Keras functional model ("ENC").

    The input has ``(4 * hop) // 2 + 1`` rows, ``args.shape`` columns and a
    single channel. Axes 1 and 3 are swapped, the tensor goes through five
    identical 1x1 ``conv_util`` stages (``hop * 4`` filters, no batch
    norm), is projected to ``hop // 4`` channels followed by tanh, and is
    then split in two along axis -2 with the halves stacked on the batch
    axis. The output is cast to float32.
    """
    in_dim = ((4 * self.args.hop) // 2) + 1
    hidden = self.args.hop * 4

    inpf = tf.keras.layers.Input((in_dim, self.args.shape, 1))

    # Move the `in_dim` axis into the channel position.
    h = tf.transpose(inpf, [0, 3, 2, 1])

    for _ in range(5):
        h = self.conv_util(h, hidden, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)

    h = tf.keras.layers.Conv2D(
        self.args.hop // 4,
        kernel_size=(1, 1),
        strides=1,
        padding="valid",
        kernel_initializer=self.init,
    )(h)
    h = tf.keras.activations.tanh(h)

    # Halve axis -2 and stack the two halves on the batch axis.
    halves = tf.split(h, 2, -2)
    h = tf.concat(halves, 0)

    return tf.keras.Model(inpf, tf.cast(h, tf.float32), name="ENC")
 
294
 
295
  def build_decoder(self):
296
 
297
  dim = ((4 * self.args.hop) // 2) + 1
298
 
299
+ inpf = tf.keras.layers.Input((1, self.args.shape // 2, self.args.hop // 4))
300
 
301
  g = inpf
302
 
303
+ g0 = self.conv_util(g, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 1), noise=True, bnorm=False)
304
+ g1 = self.conv_util(g0, self.args.hop * 3, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
305
+ g2 = self.conv_util(g1, self.args.hop * 2, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
306
+ g3 = self.conv_util(g2, self.args.hop, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
307
+ g = self.conv_util(g3, self.args.hop, kernel_size=(1, 3), strides=(1, 2), noise=True, bnorm=False)
308
 
309
+ g33 = self.conv_util(
310
+ g, self.args.hop, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
 
 
 
 
 
311
  )
312
+ g22 = self.conv_util(
313
+ g3, self.args.hop * 2, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
 
 
 
 
314
  )
315
+ g11 = self.conv_util(
316
+ g22 + g2, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
317
+ )
318
+ g00 = self.conv_util(
319
+ g11 + g1, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
 
 
 
320
  )
321
+
322
+ g = tf.keras.layers.Conv2D(
323
+ dim, kernel_size=(1, 1), strides=(1, 1), kernel_initializer=self.init, padding="same"
324
+ )(g00 + g0)
325
+ gf = tf.clip_by_value(g, -1.0, 1.0)
326
+
327
  g = self.conv_util(
328
+ g22, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
 
 
 
 
 
329
  )
330
  g = self.conv_util(
331
+ g + g11, self.args.hop * 3, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
 
 
 
 
 
332
  )
333
+ g = tf.keras.layers.Conv2D(
334
+ dim, kernel_size=(1, 1), strides=(1, 1), kernel_initializer=self.init, padding="same"
335
+ )(g + g00)
336
+ pf = tf.clip_by_value(g, -1.0, 1.0)
 
 
 
 
 
 
 
 
 
337
 
338
  gfls = tf.split(gf, self.args.shape // self.args.window, 0)
339
  gf = tf.concat(gfls, -2)
 
347
  s = tf.cast(tf.squeeze(s, -1), tf.float32)
348
  p = tf.cast(tf.squeeze(p, -1), tf.float32)
349
 
350
+ return tf.keras.Model(inpf, [s, p], name="DEC")
351
 
352
def build_critic(self):
    """Assemble the critic as a Keras functional model ("C").

    Input shape is ``(1, latlen, latdepth * 2)``. A strided stem convolution
    ("1c") is followed by residual discriminator blocks "2"-"5" whose widths
    grow from 4x to 7x ``base_channels``; block "6" is added only when
    ``args.small`` is false. A final convolution ("7c") and a single-unit
    dense layer ("7d") on the flattened features produce the score, which
    is cast to float32.
    """
    base = self.args.base_channels

    sinp = tf.keras.layers.Input(shape=(1, self.args.latlen, self.args.latdepth * 2))

    stem = tf.keras.layers.Conv2D(
        base * 3,
        kernel_size=(1, 4),
        strides=(1, 2),
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        name="1c",
    )
    sf = stem(sinp)
    sf = tf.keras.layers.LeakyReLU(0.2)(sf)

    # (channel multiplier, layer-name prefix) for each strided residual block.
    for mult, tag in ((4, "2"), (5, "3"), (6, "4"), (7, "5")):
        sf = self.res_block_disc(sf, base * mult, kernel_size=(1, 4), strides=(1, 2), name=tag)

    if not self.args.small:
        sf = self.res_block_disc(
            sf, base * 7, kernel_size=(1, 4), strides=(1, 2), kernel_size_2=(1, 1), name="6"
        )

    last_conv = tf.keras.layers.Conv2D(
        base * 7,
        kernel_size=(1, 3),
        strides=(1, 1),
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        name="7c",
    )
    sf = last_conv(sf)
    sf = tf.keras.layers.LeakyReLU(0.2)(sf)

    head = tf.keras.layers.Dense(1, activation="linear", use_bias=True, kernel_initializer=self.init, name="7d")
    score = head(tf.keras.layers.Flatten()(sf))

    return tf.keras.Model(sinp, tf.cast(score, tf.float32), name="C")
398
 
399
  def build_generator(self):
400
 
401
  dim = self.args.latdepth * 2
402
 
403
+ inpf = tf.keras.layers.Input((self.args.latlen, self.args.latdepth * 2))
404
 
405
  inpfls = tf.split(inpf, 2, -2)
406
  inpb = tf.concat(inpfls, 0)
 
411
  inp3 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp2)
412
  inp4 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp3)
413
  inp5 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp4)
414
+ if not self.args.small:
415
+ inp6 = tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(inp5)
416
 
417
+ if not self.args.small:
418
+ g = tf.keras.layers.Dense(
419
+ 4 * (self.args.base_channels * 7),
420
+ activation="linear",
421
+ use_bias=True,
422
+ kernel_initializer=self.init,
423
+ name="00d",
424
+ )(tf.keras.layers.Flatten()(inp6))
425
+ g = tf.keras.layers.Reshape((1, 4, self.args.base_channels * 7))(g)
426
+ g = AddNoise(self.args.datatype, name="00n")(g)
427
+ g = self.adain(g, inp5, name="00ai")
428
+ g = tf.keras.activations.swish(g)
429
+ else:
430
+ g = tf.keras.layers.Dense(
431
+ 4 * (self.args.base_channels * 7),
432
+ activation="linear",
433
+ use_bias=True,
434
+ kernel_initializer=self.init,
435
+ name="00d",
436
+ )(tf.keras.layers.Flatten()(inp5))
437
+ g = tf.keras.layers.Reshape((1, 4, self.args.base_channels * 7))(g)
438
+ g = AddNoise(self.args.datatype, name="00n")(g)
439
+ g = self.adain(g, inp4, name="00ai")
440
+ g = tf.keras.activations.swish(g)
441
+
442
+ if not self.args.small:
443
+ g1 = self.conv_util_gen(
444
+ g,
445
+ self.args.base_channels * 6,
446
+ kernel_size=(1, 4),
447
+ strides=(1, 2),
448
+ upsample=True,
449
+ noise=True,
450
+ emb=inp4,
451
+ name="0",
452
+ )
453
+ g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
454
+ g1 = self.conv_util_gen(
455
+ g1,
456
+ self.args.base_channels * 6,
457
+ kernel_size=(1, 4),
458
+ strides=(1, 1),
459
+ upsample=False,
460
+ noise=True,
461
+ emb=inp4,
462
+ name="1",
463
+ )
464
+ g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
465
+ g1 = g1 + tf.keras.layers.Conv2D(
466
+ g1.shape[-1],
467
+ kernel_size=(1, 1),
468
+ strides=1,
469
+ activation="linear",
470
+ padding="same",
471
+ kernel_initializer=self.init,
472
+ use_bias=True,
473
+ name="res1c",
474
+ )(self.pixel_shuffle(g))
475
+ else:
476
+ g1 = self.conv_util_gen(
477
+ g,
478
+ self.args.base_channels * 6,
479
+ kernel_size=(1, 1),
480
+ strides=(1, 1),
481
+ upsample=False,
482
+ noise=True,
483
+ emb=inp4,
484
+ name="0_small",
485
+ )
486
+ g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
487
+ g1 = self.conv_util_gen(
488
+ g1,
489
+ self.args.base_channels * 6,
490
+ kernel_size=(1, 1),
491
+ strides=(1, 1),
492
+ upsample=False,
493
+ noise=True,
494
+ emb=inp4,
495
+ name="1_small",
496
+ )
497
+ g1 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g1
498
+ g1 = g1 + tf.keras.layers.Conv2D(
499
+ g1.shape[-1],
500
+ kernel_size=(1, 1),
501
+ strides=1,
502
+ activation="linear",
503
+ padding="same",
504
+ kernel_initializer=self.init,
505
+ use_bias=True,
506
+ name="res1c_small",
507
+ )(g)
508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  g2 = self.conv_util_gen(
510
  g1,
511
+ self.args.base_channels * 5,
512
  kernel_size=(1, 4),
513
  strides=(1, 2),
514
  upsample=True,
515
  noise=True,
516
  emb=inp3,
517
+ name="2",
518
  )
519
+ g2 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g2
520
+ g2 = self.conv_util_gen(
521
  g2,
522
+ self.args.base_channels * 5,
523
+ kernel_size=(1, 4),
524
  strides=(1, 1),
525
  upsample=False,
526
  noise=True,
527
  emb=inp3,
528
+ name="3",
529
  )
530
+ g2 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g2
531
+ g2 = g2 + tf.keras.layers.Conv2D(
532
+ g2.shape[-1],
533
+ kernel_size=(1, 1),
534
+ strides=1,
535
+ activation="linear",
536
+ padding="same",
537
+ kernel_initializer=self.init,
538
+ use_bias=True,
539
+ name="res2c",
540
+ )(self.pixel_shuffle(g1))
541
+
542
  g3 = self.conv_util_gen(
543
+ g2,
544
+ self.args.base_channels * 4,
545
  kernel_size=(1, 4),
546
  strides=(1, 2),
547
  upsample=True,
548
  noise=True,
549
  emb=inp2,
550
+ name="4",
551
  )
552
+ g3 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g3
553
  g3 = self.conv_util_gen(
554
  g3,
555
+ self.args.base_channels * 4,
556
+ kernel_size=(1, 4),
557
  strides=(1, 1),
558
  upsample=False,
559
  noise=True,
560
  emb=inp2,
561
+ name="5",
562
  )
563
+ g3 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g3
564
+ g3 = g3 + tf.keras.layers.Conv2D(
565
+ g3.shape[-1],
566
+ kernel_size=(1, 1),
567
+ strides=1,
568
+ activation="linear",
569
+ padding="same",
570
+ kernel_initializer=self.init,
571
+ use_bias=True,
572
+ name="res3c",
573
+ )(self.pixel_shuffle(g2))
574
+
575
  g4 = self.conv_util_gen(
576
  g3,
577
+ self.args.base_channels * 3,
578
  kernel_size=(1, 4),
579
  strides=(1, 2),
580
  upsample=True,
581
  noise=True,
582
  emb=inp1,
583
+ name="6",
584
  )
585
+ g4 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g4
586
  g4 = self.conv_util_gen(
587
  g4,
588
+ self.args.base_channels * 3,
589
+ kernel_size=(1, 4),
590
  strides=(1, 1),
591
  upsample=False,
592
  noise=True,
593
  emb=inp1,
594
+ name="7",
595
  )
596
+ g4 = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * g4
597
+ g4 = g4 + tf.keras.layers.Conv2D(
598
+ g4.shape[-1],
599
+ kernel_size=(1, 1),
600
+ strides=1,
601
+ activation="linear",
602
+ padding="same",
603
+ kernel_initializer=self.init,
604
+ use_bias=True,
605
+ name="res4c",
606
+ )(self.pixel_shuffle(g3))
607
+
608
  g5 = self.conv_util_gen(
609
  g4,
610
+ self.args.base_channels * 2,
611
  kernel_size=(1, 4),
612
  strides=(1, 2),
613
  upsample=True,
614
  noise=True,
615
  emb=tf.expand_dims(tf.cast(inpb, dtype=self.args.datatype), -3),
616
+ name="8",
617
  )
618
 
619
  gf = tf.keras.layers.Conv2D(
620
+ dim, kernel_size=(1, 4), strides=(1, 1), kernel_initializer=self.init, padding="same", name="9c"
 
 
 
 
 
621
  )(g5)
622
 
623
  gfls = tf.split(gf, 2, 0)
 
625
 
626
  gf = tf.cast(gf, tf.float32)
627
 
628
+ return tf.keras.Model(inpf, gf, name="GEN")
629
 
630
  # Load past models from path to resume training or test
631
  def load(self, path, load_dec=False):
 
635
  dec = self.build_decoder()
636
  enc2 = self.build_encoder2()
637
  dec2 = self.build_decoder2()
 
638
  gen_ema = self.build_generator()
639
 
640
+ switch = tf.Variable(-1.0, dtype=tf.float32)
641
+
642
  if self.args.mixed_precision:
643
+ opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
644
+ opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
645
  else:
646
  opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
647
  opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)
 
649
  if load_dec:
650
  dec.load_weights(self.args.dec_path + "/dec.h5")
651
  dec2.load_weights(self.args.dec_path + "/dec2.h5")
652
+ enc.load_weights(self.args.dec_path + "/enc.h5")
653
+ enc2.load_weights(self.args.dec_path + "/enc2.h5")
654
 
655
  else:
656
+ grad_vars = critic.trainable_weights
657
  zero_grads = [tf.zeros_like(w) for w in grad_vars]
658
  opt_disc.apply_gradients(zip(zero_grads, grad_vars))
659
 
 
664
  if not self.args.testing:
665
  opt_disc.set_weights(np.load(path + "/opt_disc.npy", allow_pickle=True))
666
  opt_dec.set_weights(np.load(path + "/opt_dec.npy", allow_pickle=True))
 
 
667
  critic.load_weights(path + "/critic.h5")
668
  gen.load_weights(path + "/gen.h5")
669
+ switch = tf.Variable(float(np.load(path + "/switch.npy", allow_pickle=True)), dtype=tf.float32)
670
+
 
671
  gen_ema.load_weights(path + "/gen_ema.h5")
672
+ dec.load_weights(self.args.dec_path + "/dec.h5")
673
+ dec2.load_weights(self.args.dec_path + "/dec2.h5")
674
+ enc.load_weights(self.args.dec_path + "/enc.h5")
675
+ enc2.load_weights(self.args.dec_path + "/enc2.h5")
676
 
677
  return (
678
  critic,
 
681
  dec,
682
  enc2,
683
  dec2,
 
684
  gen_ema,
685
  [opt_dec, opt_disc],
686
+ switch,
687
  )
688
 
689
  def build(self):
 
693
  dec = self.build_decoder()
694
  enc2 = self.build_encoder2()
695
  dec2 = self.build_decoder2()
 
696
  gen_ema = self.build_generator()
697
 
698
+ switch = tf.Variable(-1.0, dtype=tf.float32)
699
+
700
  gen_ema = tf.keras.models.clone_model(gen)
701
  gen_ema.set_weights(gen.get_weights())
702
 
703
  if self.args.mixed_precision:
704
+ opt_disc = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
705
+ opt_dec = self.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(0.0001, 0.5))
706
  else:
707
+ opt_disc = tf.keras.optimizers.Adam(0.0001, 0.5)
708
+ opt_dec = tf.keras.optimizers.Adam(0.0001, 0.5)
709
 
710
  return (
711
  critic,
 
714
  dec,
715
  enc2,
716
  dec2,
 
717
  gen_ema,
718
  [opt_dec, opt_disc],
719
+ switch,
720
  )
721
 
722
  def get_networks(self):
 
724
  critic,
725
  gen,
726
  enc,
727
+ dec,
728
  enc2,
729
+ dec2,
730
+ gen_ema_1,
 
731
  [opt_dec, opt_disc],
732
+ switch,
733
+ ) = self.load(self.args.load_path_1, load_dec=False)
734
+ print(f"Networks loaded from {self.args.load_path_1}")
735
 
736
  (
737
  critic,
738
  gen,
739
  enc,
740
+ dec,
741
  enc2,
742
+ dec2,
743
+ gen_ema_2,
 
744
  [opt_dec, opt_disc],
745
+ switch,
746
+ ) = self.load(self.args.load_path_2, load_dec=False)
747
+ print(f"Networks loaded from {self.args.load_path_2}")
748
 
749
+ (
750
  critic,
751
  gen,
752
  enc,
753
+ dec,
754
  enc2,
755
+ dec2,
756
+ gen_ema_3,
 
757
  [opt_dec, opt_disc],
758
+ switch,
759
+ ) = self.load(self.args.load_path_3, load_dec=False)
760
+ print(f"Networks loaded from {self.args.load_path_3}")
761
+
762
+ return (
763
+ (critic, gen, enc, dec, enc2, dec2, gen_ema_1, [opt_dec, opt_disc], switch),
764
+ (critic, gen, enc, dec, enc2, dec2, gen_ema_2, [opt_dec, opt_disc], switch),
765
+ (critic, gen, enc, dec, enc2, dec2, gen_ema_3, [opt_dec, opt_disc], switch),
766
+ )
767
 
768
def initialize_networks(self):
    """Load every network set, report parameter counts, and return the sets.

    Returns whatever ``get_networks()`` produced, unchanged. Previously the
    three tuples were unpacked into the same local names and rebuilt, so
    every returned tuple silently carried the *last* tuple's networks
    (apart from its ``gen_ema``); passing the result through avoids that
    cross-wiring.
    """
    networks = self.get_networks()

    # Report on the last set, which is what the shadowed names used to hold.
    critic, gen = networks[-1][0], networks[-1][1]
    print(f"Critic params: {count_params(critic.trainable_variables)}")
    print(f"Generator params: {count_params(gen.trainable_variables)}")

    return networks
 
 
 
 
 
 
musika.py → musika_test.py RENAMED
@@ -1,7 +1,10 @@
 
 
 
 
1
  from parse_test import parse_args
2
  from models import Models_functions
3
  from utils import Utils_functions
4
- import os
5
 
6
  if __name__ == "__main__":
7
 
@@ -10,8 +13,8 @@ if __name__ == "__main__":
10
 
11
  # initialize networks
12
  M = Models_functions(args)
13
- models_ls_techno, models_ls_classical = M.get_networks()
14
 
15
  # test musika
16
  U = Utils_functions(args)
17
- U.render_gradio(models_ls_techno, models_ls_classical, train=False)
 
1
+ import os
2
+
3
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
4
+
5
  from parse_test import parse_args
6
  from models import Models_functions
7
  from utils import Utils_functions
 
8
 
9
  if __name__ == "__main__":
10
 
 
13
 
14
  # initialize networks
15
  M = Models_functions(args)
16
+ models_ls_1, models_ls_2, models_ls_3 = M.get_networks()
17
 
18
  # test musika
19
  U = Utils_functions(args)
20
+ U.render_gradio(models_ls_1, models_ls_2, models_ls_3, train=False)
parse_test.py CHANGED
@@ -17,6 +17,17 @@ class EasyDict(dict):
17
  del self[name]
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  def params_args(args):
21
  parser = argparse.ArgumentParser()
22
 
@@ -35,14 +46,14 @@ def params_args(args):
35
  parser.add_argument(
36
  "--sr",
37
  type=int,
38
- default=22050,
39
  help="Sampling Rate",
40
  )
41
  parser.add_argument(
42
- "--latlen",
43
- type=int,
44
- default=256,
45
- help="Length of generated latent vectors",
46
  )
47
  parser.add_argument(
48
  "--latdepth",
@@ -50,6 +61,18 @@ def params_args(args):
50
  default=64,
51
  help="Depth of generated latent vectors",
52
  )
 
 
 
 
 
 
 
 
 
 
 
 
53
  parser.add_argument(
54
  "--shape",
55
  type=int,
@@ -64,55 +87,55 @@ def params_args(args):
64
  )
65
  parser.add_argument(
66
  "--mu_rescale",
67
- type=int,
68
  default=-25.0,
69
  help="Spectrogram mu used to normalize",
70
  )
71
  parser.add_argument(
72
  "--sigma_rescale",
73
- type=int,
74
  default=75.0,
75
  help="Spectrogram sigma used to normalize",
76
  )
77
  parser.add_argument(
78
- "--load_path_techno",
79
  type=str,
80
  default="checkpoints/techno/",
81
- help="Path of pretrained networks weights (techno)",
82
  )
83
  parser.add_argument(
84
- "--load_path_classical",
85
  type=str,
86
- default="checkpoints/classical/",
87
- help="Path of pretrained networks weights (classical)",
88
  )
89
  parser.add_argument(
90
- "--dec_path_techno",
91
  type=str,
92
- default="checkpoints/techno/",
93
- help="Path of pretrained decoders weights (techno)",
94
  )
95
  parser.add_argument(
96
- "--dec_path_classical",
97
  type=str,
98
- default="checkpoints/classical/",
99
- help="Path of pretrained decoders weights (classical)",
100
  )
101
  parser.add_argument(
102
  "--testing",
103
- type=bool,
104
  default=True,
105
  help="True if optimizers weight do not need to be loaded",
106
  )
107
  parser.add_argument(
108
  "--cpu",
109
- type=bool,
110
  default=False,
111
  help="True if you wish to use cpu",
112
  )
113
  parser.add_argument(
114
  "--mixed_precision",
115
- type=bool,
116
  default=True,
117
  help="True if your GPU supports mixed precision",
118
  )
@@ -122,20 +145,28 @@ def params_args(args):
122
  args.hop = tmp_args.hop
123
  args.mel_bins = tmp_args.mel_bins
124
  args.sr = tmp_args.sr
125
- args.latlen = tmp_args.latlen
126
  args.latdepth = tmp_args.latdepth
 
 
127
  args.shape = tmp_args.shape
128
  args.window = tmp_args.window
129
  args.mu_rescale = tmp_args.mu_rescale
130
  args.sigma_rescale = tmp_args.sigma_rescale
131
- args.load_path_techno = tmp_args.load_path_techno
132
- args.load_path_classical = tmp_args.load_path_classical
133
- args.dec_path_techno = tmp_args.dec_path_techno
134
- args.dec_path_classical = tmp_args.dec_path_classical
135
  args.testing = tmp_args.testing
136
  args.cpu = tmp_args.cpu
137
  args.mixed_precision = tmp_args.mixed_precision
138
 
 
 
 
 
 
 
139
  print()
140
 
141
  args.datatype = tf.float32
 
17
  del self[name]
18
 
19
 
20
def str2bool(v):
    """Convert a command-line value to a bool, for use as argparse ``type=``.

    Booleans are returned unchanged. Strings are matched case-insensitively,
    ignoring surrounding whitespace, against yes/true/t/y/1 and
    no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: for any unrecognised value, including
            non-string input (which previously escaped as AttributeError).
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError("Boolean value expected.")

    token = v.strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
29
+
30
+
31
  def params_args(args):
32
  parser = argparse.ArgumentParser()
33
 
 
46
  parser.add_argument(
47
  "--sr",
48
  type=int,
49
+ default=44100,
50
  help="Sampling Rate",
51
  )
52
  parser.add_argument(
53
+ "--small",
54
+ type=str2bool,
55
+ default=False,
56
+ help="If True, use model with shorter available context, useful for small datasets",
57
  )
58
  parser.add_argument(
59
  "--latdepth",
 
61
  default=64,
62
  help="Depth of generated latent vectors",
63
  )
64
+ parser.add_argument(
65
+ "--coorddepth",
66
+ type=int,
67
+ default=64,
68
+ help="Dimension of latent coordinate and style random vectors",
69
+ )
70
+ parser.add_argument(
71
+ "--base_channels",
72
+ type=int,
73
+ default=128,
74
+ help="Base channels for generator and discriminator architectures",
75
+ )
76
  parser.add_argument(
77
  "--shape",
78
  type=int,
 
87
  )
88
  parser.add_argument(
89
  "--mu_rescale",
90
+ type=float,
91
  default=-25.0,
92
  help="Spectrogram mu used to normalize",
93
  )
94
  parser.add_argument(
95
  "--sigma_rescale",
96
+ type=float,
97
  default=75.0,
98
  help="Spectrogram sigma used to normalize",
99
  )
100
  parser.add_argument(
101
+ "--load_path_1",
102
  type=str,
103
  default="checkpoints/techno/",
104
+ help="Path of pretrained networks weights 1",
105
  )
106
  parser.add_argument(
107
+ "--load_path_2",
108
  type=str,
109
+ default="checkpoints/metal/",
110
+ help="Path of pretrained networks weights 2",
111
  )
112
  parser.add_argument(
113
+ "--load_path_3",
114
  type=str,
115
+ default="checkpoints/misc/",
116
+ help="Path of pretrained networks weights 3",
117
  )
118
  parser.add_argument(
119
+ "--dec_path",
120
  type=str,
121
+ default="checkpoints/ae/",
122
+ help="Path of pretrained decoders weights",
123
  )
124
  parser.add_argument(
125
  "--testing",
126
+ type=str2bool,
127
  default=True,
128
  help="True if optimizers weight do not need to be loaded",
129
  )
130
  parser.add_argument(
131
  "--cpu",
132
+ type=str2bool,
133
  default=False,
134
  help="True if you wish to use cpu",
135
  )
136
  parser.add_argument(
137
  "--mixed_precision",
138
+ type=str2bool,
139
  default=True,
140
  help="True if your GPU supports mixed precision",
141
  )
 
145
  args.hop = tmp_args.hop
146
  args.mel_bins = tmp_args.mel_bins
147
  args.sr = tmp_args.sr
148
+ args.small = tmp_args.small
149
  args.latdepth = tmp_args.latdepth
150
+ args.coorddepth = tmp_args.coorddepth
151
+ args.base_channels = tmp_args.base_channels
152
  args.shape = tmp_args.shape
153
  args.window = tmp_args.window
154
  args.mu_rescale = tmp_args.mu_rescale
155
  args.sigma_rescale = tmp_args.sigma_rescale
156
+ args.load_path_1 = tmp_args.load_path_1
157
+ args.load_path_2 = tmp_args.load_path_2
158
+ args.load_path_3 = tmp_args.load_path_3
159
+ args.dec_path = tmp_args.dec_path
160
  args.testing = tmp_args.testing
161
  args.cpu = tmp_args.cpu
162
  args.mixed_precision = tmp_args.mixed_precision
163
 
164
+ if args.small:
165
+ args.latlen = 128
166
+ else:
167
+ args.latlen = 256
168
+ args.coordlen = (args.latlen // 2) * 3
169
+
170
  print()
171
 
172
  args.datatype = tf.float32
requirements.txt CHANGED
@@ -1,20 +1,12 @@
1
  # This file may be used to create an environment using:
2
  # $ conda create --name <env> --file <this file>
3
  # platform: linux-64
4
- audioread==2.1.9
5
- gradio==2.5.2
6
- ipython==7.29.0
7
  librosa==0.8.1
8
  matplotlib==3.4.3
9
  numpy==1.20.3
10
- pillow==8.4.0
11
- protobuf==3.20.1rc1
12
- scikit-learn==1.0.1
13
  scipy==1.7.1
14
- seaborn==0.11.2
15
- soundfile==0.10.3.post1
16
- tensorboard==2.7.0
17
- tensorflow==2.7.0
18
- tensorflow-addons==0.15.0
19
- tensorflow-io==0.22.0
20
  tqdm==4.62.3
 
 
1
  # This file may be used to create an environment using:
2
  # $ conda create --name <env> --file <this file>
3
  # platform: linux-64
4
+ gradio==3.3.1
 
 
5
  librosa==0.8.1
6
  matplotlib==3.4.3
7
  numpy==1.20.3
 
 
 
8
  scipy==1.7.1
9
+ tensorboard==2.10.0
10
+ tensorflow==2.10.0
 
 
 
 
11
  tqdm==4.62.3
12
+ pydub==0.25.1
utils.py CHANGED
@@ -1,21 +1,13 @@
1
- import io
2
  import os
3
- import random
4
  import time
 
5
  from glob import glob
6
-
7
- import IPython
8
  import librosa
9
  import matplotlib.pyplot as plt
10
  import numpy as np
11
- import soundfile as sf
12
  import tensorflow as tf
13
- import tensorflow_io as tfio
14
-
15
- from tensorflow.keras import mixed_precision
16
- from tensorflow.keras.optimizers import Adam
17
  from tensorflow.python.framework import random_seed
18
- from tqdm import tqdm
19
  import gradio as gr
20
  from scipy.io.wavfile import write as write_wav
21
 
@@ -27,14 +19,18 @@ class Utils_functions:
27
 
28
  melmat = tf.signal.linear_to_mel_weight_matrix(
29
  num_mel_bins=args.mel_bins,
30
- num_spectrogram_bins=(4 * args.hop) // 2 + 1,
31
  sample_rate=args.sr,
32
  lower_edge_hertz=0.0,
33
  upper_edge_hertz=args.sr // 2,
34
  )
35
  mel_f = tf.convert_to_tensor(librosa.mel_frequencies(n_mels=args.mel_bins + 2, fmin=0.0, fmax=args.sr // 2))
36
  enorm = tf.cast(
37
- tf.expand_dims(tf.constant(2.0 / (mel_f[2 : args.mel_bins + 2] - mel_f[: args.mel_bins])), 0,), tf.float32,
 
 
 
 
38
  )
39
  melmat = tf.multiply(melmat, enorm)
40
  melmat = tf.divide(melmat, tf.reduce_sum(melmat, axis=0))
@@ -149,7 +145,17 @@ class Utils_functions:
149
  S = self.normalize(self.power2db(tf.abs(X) ** 2, top_db=topdb))
150
  return tf.tensordot(S, self.melmat, 1)
151
 
152
- def distribute(self, x, model, bs=64, dual_out=False):
 
 
 
 
 
 
 
 
 
 
153
  outls = []
154
  if isinstance(x, list):
155
  bdim = x[0].shape[0]
@@ -161,14 +167,13 @@ class Utils_functions:
161
  outls.append(model(x[i * bs : i * bs + bs], training=False))
162
 
163
  if dual_out:
164
- return (
165
- np.concatenate([outls[k][0] for k in range(len(outls))], 0),
166
- np.concatenate([outls[k][1] for k in range(len(outls))], 0),
167
  )
168
  else:
169
  return np.concatenate(outls, 0)
170
 
171
- def distribute_enc(self, x, model, bs=64):
172
  outls = []
173
  if isinstance(x, list):
174
  bdim = x[0].shape[0]
@@ -185,9 +190,9 @@ class Utils_functions:
185
  res = tf.concat(resls, -2)
186
  outls.append(res)
187
 
188
- return np.concatenate(outls, 0)
189
 
190
- def distribute_dec(self, x, model, bs=64):
191
  outls = []
192
  bdim = x.shape[0]
193
  for i in range(((bdim - 2) // bs) + 1):
@@ -196,12 +201,11 @@ class Utils_functions:
196
  inp = tf.concat(inpls, 0)
197
  res = model(inp, training=False)
198
  outls.append(res)
199
- return (
200
- np.concatenate([outls[k][0] for k in range(len(outls))], 0),
201
- np.concatenate([outls[k][1] for k in range(len(outls))], 0),
202
  )
203
 
204
- def distribute_dec2(self, x, model, bs=64):
205
  outls = []
206
  bdim = x.shape[0]
207
  for i in range(((bdim - 2) // bs) + 1):
@@ -210,32 +214,105 @@ class Utils_functions:
210
  inp1 = tf.concat(inpls, 0)
211
  outls.append(model(inp1, training=False))
212
 
213
- return np.concatenate(outls, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  def get_noise_interp(self):
216
  noiseg = tf.random.normal([1, 64], dtype=tf.float32)
217
 
218
- noisel = tf.concat([tf.random.normal([1, 64], dtype=tf.float32), noiseg], -1)
219
- noisec = tf.concat([tf.random.normal([1, 64], dtype=tf.float32), noiseg], -1)
220
- noiser = tf.concat([tf.random.normal([1, 64], dtype=tf.float32), noiseg], -1)
221
 
222
- rl = tf.linspace(noisel, noisec, self.args.latlen + 1, axis=-2)[:, :-1, :]
223
- rr = tf.linspace(noisec, noiser, self.args.latlen + 1, axis=-2)
224
 
225
  noisetot = tf.concat([rl, rr], -2)
226
- return tf.image.random_crop(noisetot, [1, self.args.latlen, 64 + 64])
 
227
 
228
  def generate_example_stereo(self, models_ls):
229
- (critic, gen, enc, dec, enc2, dec2, critic_rec, gen_ema, [opt_dec, opt_disc],) = models_ls
 
 
 
 
 
 
 
 
 
 
230
  abb = gen_ema(self.get_noise_interp(), training=False)
231
- abbls = tf.split(abb, abb.shape[-2] // 16, -2)
232
  abb = tf.concat(abbls, 0)
233
 
234
  chls = []
235
  for channel in range(2):
236
 
237
  ab = self.distribute_dec2(
238
- abb[:, :, :, channel * self.args.latdepth : channel * self.args.latdepth + self.args.latdepth,], dec2,
 
 
 
 
 
 
239
  )
240
  abls = tf.split(ab, ab.shape[-2] // self.args.shape, -2)
241
  ab = tf.concat(abls, 0)
@@ -273,14 +350,28 @@ class Utils_functions:
273
 
274
  fig, axs = plt.subplots(nrows=4, ncols=1, figsize=(20, 20))
275
  axs[0].imshow(
276
- np.flip(np.array(tf.transpose(self.wv2spec_hop((abwv[:, 0] + abwv[:, 1]) / 2.0, 80.0, 256), [1, 0],)), -2,),
 
 
 
 
 
 
 
 
277
  cmap=None,
278
  )
279
  axs[0].axis("off")
280
  axs[0].set_title("Generated1")
281
  axs[1].imshow(
282
  np.flip(
283
- np.array(tf.transpose(self.wv2spec_hop((abwv2[:, 0] + abwv2[:, 1]) / 2.0, 80.0, 256), [1, 0],)), -2,
 
 
 
 
 
 
284
  ),
285
  cmap=None,
286
  )
@@ -288,7 +379,13 @@ class Utils_functions:
288
  axs[1].set_title("Generated2")
289
  axs[2].imshow(
290
  np.flip(
291
- np.array(tf.transpose(self.wv2spec_hop((abwv3[:, 0] + abwv3[:, 1]) / 2.0, 80.0, 256), [1, 0],)), -2,
 
 
 
 
 
 
292
  ),
293
  cmap=None,
294
  )
@@ -296,7 +393,13 @@ class Utils_functions:
296
  axs[2].set_title("Generated3")
297
  axs[3].imshow(
298
  np.flip(
299
- np.array(tf.transpose(self.wv2spec_hop((abwv4[:, 0] + abwv4[:, 1]) / 2.0, 80.0, 256), [1, 0],)), -2,
 
 
 
 
 
 
300
  ),
301
  cmap=None,
302
  )
@@ -304,18 +407,24 @@ class Utils_functions:
304
  axs[3].set_title("Generated4")
305
  # plt.show()
306
  plt.savefig(f"{path}/output.png")
 
307
 
308
- # Save in training loop
309
  def save_end(
310
- self, epoch, gloss, closs, mloss, models_ls=None, n_save=3, save_path="checkpoints",
 
 
 
 
 
 
 
311
  ):
312
- (critic, gen, enc, dec, enc2, dec2, critic_rec, gen_ema, [opt_dec, opt_disc],) = models_ls
313
  if epoch % n_save == 0:
314
  print("Saving...")
315
- path = f"{save_path}/MUSIKA!_-{str(gloss)[:9]}-{str(closs)[:9]}-{str(mloss)[:9]}"
316
  os.mkdir(path)
317
  critic.save_weights(path + "/critic.h5")
318
- critic_rec.save_weights(path + "/critic_rec.h5")
319
  gen.save_weights(path + "/gen.h5")
320
  gen_ema.save_weights(path + "/gen_ema.h5")
321
  # enc.save_weights(path + "/enc.h5")
@@ -324,84 +433,51 @@ class Utils_functions:
324
  # dec2.save_weights(path + "/dec2.h5")
325
  np.save(path + "/opt_dec.npy", opt_dec.get_weights())
326
  np.save(path + "/opt_disc.npy", opt_disc.get_weights())
 
327
  self.save_test_image_full(path, models_ls=models_ls)
328
 
329
  def truncated_normal(self, shape, bound=2.0, dtype=tf.float32):
330
  seed1, seed2 = random_seed.get_seed(tf.random.uniform((), tf.int32.min, tf.int32.max, dtype=tf.int32))
331
  return tf.random.stateless_parameterized_truncated_normal(shape, [seed1, seed2], 0.0, 1.0, -bound, bound)
332
 
333
- def distribute_gen(self, x, model, bs=64):
334
  outls = []
335
  bdim = x.shape[0]
336
  if bdim == 1:
337
  bdim = 2
338
  for i in range(((bdim - 2) // bs) + 1):
339
  outls.append(model(x[i * bs : i * bs + bs], training=False))
340
- return np.concatenate(outls, 0)
341
 
342
- def get_noise_interp_multi(self, fac=1, var=2.0):
343
- noiseg = self.truncated_normal([1, 64], var, dtype=tf.float32)
344
-
345
- if var < 1.75:
346
- var = 1.75
347
 
348
- noisels = [
349
- tf.concat([self.truncated_normal([1, 64], var, dtype=tf.float32), noiseg], -1) for i in range(2 + (fac - 1))
350
- ]
351
- rls = [
352
- tf.linspace(noisels[k], noisels[k + 1], self.args.latlen + 1, axis=-2)[:, :-1, :]
353
- for k in range(len(noisels) - 1)
354
- ]
355
- return tf.concat(rls, 0)
356
 
357
- def stfunc(self, genre, z, var, models_ls_techno, models_ls_classical):
 
358
 
359
- (
360
- critic,
361
- gen,
362
- enc,
363
- dec_techno,
364
- enc2,
365
- dec2_techno,
366
- critic_rec,
367
- gen_ema_techno,
368
- [opt_dec, opt_disc],
369
- ) = models_ls_techno
370
- (
371
- critic,
372
- gen,
373
- enc,
374
- dec_classical,
375
- enc2,
376
- dec2_classical,
377
- critic_rec,
378
- gen_ema_classical,
379
- [opt_dec, opt_disc],
380
- ) = models_ls_classical
381
 
382
- var = 0.01 + (3.5 * (var / 100.0))
 
 
383
 
384
- if z == 0:
385
- fac = 1
386
- elif z == 1:
387
- fac = 5
388
- else:
389
- fac = 10
390
 
391
- if genre == 0:
392
- dec = dec_techno
393
- dec2 = dec2_techno
394
- gen_ema = gen_ema_techno
395
- else:
396
- dec = dec_classical
397
- dec2 = dec2_classical
398
- gen_ema = gen_ema_classical
399
 
400
- bef = time.time()
401
- ab = self.distribute_gen(self.get_noise_interp_multi(fac, var), gen_ema)
402
- abls = tf.split(ab, ab.shape[0], 0)
403
- ab = tf.concat(abls, -2)
404
- abls = tf.split(ab, ab.shape[-2] // 16, -2)
405
  abi = tf.concat(abls, 0)
406
 
407
  chls = []
@@ -410,25 +486,135 @@ class Utils_functions:
410
  ab = self.distribute_dec2(
411
  abi[:, :, :, channel * self.args.latdepth : channel * self.args.latdepth + self.args.latdepth],
412
  dec2,
413
- bs=128,
414
  )
415
- # abls = tf.split(ab, ab.shape[-2] // (self.args.shape // 2), -2)
416
  abls = tf.split(ab, ab.shape[-2] // self.args.shape, -2)
417
  ab = tf.concat(abls, 0)
418
 
419
- ab_m, ab_p = self.distribute_dec(ab, dec, bs=128)
420
  abwv = self.conc_tog_specphase(ab_m, ab_p)
421
  chls.append(abwv)
422
 
423
- print(
424
- f"Time for complete generation pipeline: {time.time()-bef} s {int(np.round((fac*23.)/(time.time()-bef)))}x faster than Real Time!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  )
426
 
427
- abwvc = np.clip(np.squeeze(np.stack(chls, -1)), -1.0, 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  spec = np.flip(
429
  np.array(
430
  tf.transpose(
431
- self.wv2spec_hop((abwvc[: 23 * self.args.sr, 0] + abwvc[: 23 * self.args.sr, 1]) / 2.0, 80.0, 256),
 
 
432
  [1, 0],
433
  )
434
  ),
@@ -436,46 +622,53 @@ class Utils_functions:
436
  )
437
 
438
  return (
439
- spec,
440
  (self.args.sr, np.int16(abwvc * 32767.0)),
441
  )
442
 
443
- def render_gradio(self, models_ls_techno, models_ls_classical, train=True):
444
- article_text = "Original work by Marco Pasini ([Twitter](https://twitter.com/marco_ppasini)) and Jan Schlüter at Johannes Kepler Universität Linz."
445
 
446
- def gradio_func(x, y, z):
447
- return self.stfunc(x, y, z, models_ls_techno, models_ls_classical)
 
 
 
 
 
 
 
448
 
449
  iface = gr.Interface(
450
  fn=gradio_func,
451
  inputs=[
452
- gr.inputs.Radio(
453
- choices=["Techno/Experimental", "Classical"],
454
  type="index",
455
- default="Classical",
456
  label="Music Genre to Generate",
457
  ),
458
- gr.inputs.Radio(
459
- choices=["23s", "1m 58s", "3m 57s"], type="index", default="1m 58s", label="Generated Music Length",
 
 
 
460
  ),
461
- gr.inputs.Slider(
462
- minimum=0,
463
- maximum=100,
464
- step=1,
465
- default=25,
466
- label="Stability[left]/Variety[right] Tradeoff (Truncation Trick)",
467
  ),
468
  ],
469
  outputs=[
470
- gr.outputs.Image(label="Log-MelSpectrogram of Generated Audio (first 23 sec)"),
471
- gr.outputs.Audio(type="numpy", label="Generated Audio"),
472
  ],
473
- allow_screenshot=False,
474
  title="musika!",
475
- description="Blazingly Fast Stereo Waveform Music Generation of Arbitrary Length. Be patient and enjoy the weirdness!",
476
  article=article_text,
477
- layout="vertical",
478
- theme="huggingface",
479
  )
480
 
481
  print("--------------------------------")
 
 
1
  import os
 
2
  import time
3
+ import datetime
4
  from glob import glob
5
+ from tqdm import tqdm
 
6
  import librosa
7
  import matplotlib.pyplot as plt
8
  import numpy as np
 
9
  import tensorflow as tf
 
 
 
 
10
  from tensorflow.python.framework import random_seed
 
11
  import gradio as gr
12
  from scipy.io.wavfile import write as write_wav
13
 
 
19
 
20
  melmat = tf.signal.linear_to_mel_weight_matrix(
21
  num_mel_bins=args.mel_bins,
22
+ num_spectrogram_bins=(4 * args.hop * 2) // 2 + 1,
23
  sample_rate=args.sr,
24
  lower_edge_hertz=0.0,
25
  upper_edge_hertz=args.sr // 2,
26
  )
27
  mel_f = tf.convert_to_tensor(librosa.mel_frequencies(n_mels=args.mel_bins + 2, fmin=0.0, fmax=args.sr // 2))
28
  enorm = tf.cast(
29
+ tf.expand_dims(
30
+ tf.constant(2.0 / (mel_f[2 : args.mel_bins + 2] - mel_f[: args.mel_bins])),
31
+ 0,
32
+ ),
33
+ tf.float32,
34
  )
35
  melmat = tf.multiply(melmat, enorm)
36
  melmat = tf.divide(melmat, tf.reduce_sum(melmat, axis=0))
 
145
  S = self.normalize(self.power2db(tf.abs(X) ** 2, top_db=topdb))
146
  return tf.tensordot(S, self.melmat, 1)
147
 
148
def rand_channel_swap(self, x):
    """Randomly exchange the two halves of the last axis of ``x``.

    ``x`` is split into two equal parts along its last axis. A scalar is drawn
    from ``tf.random.uniform``; when it is above 0.5 the halves are put back in
    their original order, otherwise they are swapped.

    Args:
        x: tensor whose last axis can be split in two equal parts
            (presumably left/right channel features -- confirm with callers).

    Returns:
        Tensor with the same shape as ``x``.
    """
    first_half, second_half = tf.split(x, 2, -1)
    keep_order = tf.random.uniform((), dtype=tf.float32) > 0.5
    if keep_order:
        return tf.concat([first_half, second_half], -1)
    return tf.concat([second_half, first_half], -1)
157
+
158
+ def distribute(self, x, model, bs=32, dual_out=False):
159
  outls = []
160
  if isinstance(x, list):
161
  bdim = x[0].shape[0]
 
167
  outls.append(model(x[i * bs : i * bs + bs], training=False))
168
 
169
  if dual_out:
170
+ return np.concatenate([outls[k][0] for k in range(len(outls))], 0), np.concatenate(
171
+ [outls[k][1] for k in range(len(outls))], 0
 
172
  )
173
  else:
174
  return np.concatenate(outls, 0)
175
 
176
+ def distribute_enc(self, x, model, bs=32):
177
  outls = []
178
  if isinstance(x, list):
179
  bdim = x[0].shape[0]
 
190
  res = tf.concat(resls, -2)
191
  outls.append(res)
192
 
193
+ return tf.concat(outls, 0)
194
 
195
+ def distribute_dec(self, x, model, bs=32):
196
  outls = []
197
  bdim = x.shape[0]
198
  for i in range(((bdim - 2) // bs) + 1):
 
201
  inp = tf.concat(inpls, 0)
202
  res = model(inp, training=False)
203
  outls.append(res)
204
+ return np.concatenate([outls[k][0] for k in range(len(outls))], 0), np.concatenate(
205
+ [outls[k][1] for k in range(len(outls))], 0
 
206
  )
207
 
208
+ def distribute_dec2(self, x, model, bs=32):
209
  outls = []
210
  bdim = x.shape[0]
211
  for i in range(((bdim - 2) // bs) + 1):
 
214
  inp1 = tf.concat(inpls, 0)
215
  outls.append(model(inp1, training=False))
216
 
217
+ return tf.concat(outls, 0)
218
+
219
def center_coordinate(self, x):
    """Average every step of ``x`` with its successor along axis -2.

    The sequence is rolled by one step along axis -2, averaged element-wise
    with the unrolled sequence, and the final position (which would mix the
    last and the first step because of the roll) is dropped. The result is one
    step shorter than the input. Per the original author, this allows
    even-length sequences to have their anchor in the middle.

    Args:
        x: rank-3 tensor; axis -2 is the sequence axis.

    Returns:
        Tensor of shape ``[x.shape[0], x.shape[1] - 1, x.shape[2]]``.
    """
    shifted = tf.roll(x, -1, -2)
    pair = tf.stack([x, shifted], 0)
    averaged = tf.reduce_mean(pair, 0)
    return averaged[:, :-1, :]
223
+
224
+ # hardcoded! need to figure out how to generalize it without breaking jit compiling
225
def crop_coordinate(self, x):
    """Randomly crop a window of ``args.latlen`` steps from axis 1 of ``x``.

    An integer ``fac`` is drawn uniformly from
    ``[0, args.coordlen // (args.latlen // 2))`` and selects the window start
    ``latlen // 4 + fac * (latlen // 2)``; any value of 2 or more uses the
    third window. Per the original author, the crop keeps the anchor vector at
    the center of the generator receptive field.

    The three cases are deliberately kept as separate branches with literal
    multipliers: the file notes that generalizing them breaks jit compiling.

    Args:
        x: rank-3 conditioning tensor, long enough along axis 1 to contain
            the selected window.

    Returns:
        The crop reshaped to ``[-1, args.latlen, x.shape[-1]]``.
    """
    latlen = self.args.latlen
    quarter = latlen // 4
    half = latlen // 2

    def window(k):
        # k is always a Python literal, so the slice bounds never depend on
        # the random tensor drawn below.
        begin = quarter + k * half
        return tf.reshape(x[:, begin : begin + latlen, :], [-1, latlen, x.shape[-1]])

    fac = tf.random.uniform((), 0, self.args.coordlen // half, dtype=tf.int32)
    if fac == 0:
        return window(0)
    elif fac == 1:
        return window(1)
    else:
        return window(2)
265
+
266
def update_switch(self, switch, ca, cab, learning_rate_switch=0.0001, stable_point=0.9):
    """Nudge ``switch`` up or down by one step and clamp it to ``[-1, 0]``.

    A certainty value is computed as ``mean(ca) - mean(cab)``, clamped to
    ``[0, 2]`` and divided by 2. When it exceeds ``stable_point`` the switch is
    decreased by ``learning_rate_switch``; otherwise it is increased by the
    same amount.

    Args:
        switch: current switch value.
        ca: tensor whose mean is the first term of the difference
            (presumably critic outputs -- confirm with the training loop).
        cab: tensor whose mean is subtracted from that of ``ca``.
        learning_rate_switch: step size applied to ``switch``.
        stable_point: certainty threshold above which the switch is lowered.

    Returns:
        The updated switch, clamped to the range ``[-1.0, 0.0]``.
    """
    margin = tf.reduce_mean(ca) - tf.reduce_mean(cab)
    cert = tf.math.minimum(tf.math.maximum(margin, 0.0), 2.0) / 2.0

    if cert > stable_point:
        switch_new = switch - learning_rate_switch
    else:
        switch_new = switch + learning_rate_switch

    capped = tf.math.minimum(switch_new, 0.0)
    return tf.math.maximum(capped, -1.0)
274
 
275
  def get_noise_interp(self):
276
  noiseg = tf.random.normal([1, 64], dtype=tf.float32)
277
 
278
+ noisel = tf.concat([tf.random.normal([1, self.args.coorddepth], dtype=tf.float32), noiseg], -1)
279
+ noisec = tf.concat([tf.random.normal([1, self.args.coorddepth], dtype=tf.float32), noiseg], -1)
280
+ noiser = tf.concat([tf.random.normal([1, self.args.coorddepth], dtype=tf.float32), noiseg], -1)
281
 
282
+ rl = tf.linspace(noisel, noisec, self.args.coordlen + 1, axis=-2)[:, :-1, :]
283
+ rr = tf.linspace(noisec, noiser, self.args.coordlen + 1, axis=-2)
284
 
285
  noisetot = tf.concat([rl, rr], -2)
286
+ noisetot = self.center_coordinate(noisetot)
287
+ return self.crop_coordinate(noisetot)
288
 
289
  def generate_example_stereo(self, models_ls):
290
+ (
291
+ critic,
292
+ gen,
293
+ enc,
294
+ dec,
295
+ enc2,
296
+ dec2,
297
+ gen_ema,
298
+ [opt_dec, opt_disc],
299
+ switch,
300
+ ) = models_ls
301
  abb = gen_ema(self.get_noise_interp(), training=False)
302
+ abbls = tf.split(abb, abb.shape[-2] // 8, -2)
303
  abb = tf.concat(abbls, 0)
304
 
305
  chls = []
306
  for channel in range(2):
307
 
308
  ab = self.distribute_dec2(
309
+ abb[
310
+ :,
311
+ :,
312
+ :,
313
+ channel * self.args.latdepth : channel * self.args.latdepth + self.args.latdepth,
314
+ ],
315
+ dec2,
316
  )
317
  abls = tf.split(ab, ab.shape[-2] // self.args.shape, -2)
318
  ab = tf.concat(abls, 0)
 
350
 
351
  fig, axs = plt.subplots(nrows=4, ncols=1, figsize=(20, 20))
352
  axs[0].imshow(
353
+ np.flip(
354
+ np.array(
355
+ tf.transpose(
356
+ self.wv2spec_hop((abwv[:, 0] + abwv[:, 1]) / 2.0, 80.0, self.args.hop * 2),
357
+ [1, 0],
358
+ )
359
+ ),
360
+ -2,
361
+ ),
362
  cmap=None,
363
  )
364
  axs[0].axis("off")
365
  axs[0].set_title("Generated1")
366
  axs[1].imshow(
367
  np.flip(
368
+ np.array(
369
+ tf.transpose(
370
+ self.wv2spec_hop((abwv2[:, 0] + abwv2[:, 1]) / 2.0, 80.0, self.args.hop * 2),
371
+ [1, 0],
372
+ )
373
+ ),
374
+ -2,
375
  ),
376
  cmap=None,
377
  )
 
379
  axs[1].set_title("Generated2")
380
  axs[2].imshow(
381
  np.flip(
382
+ np.array(
383
+ tf.transpose(
384
+ self.wv2spec_hop((abwv3[:, 0] + abwv3[:, 1]) / 2.0, 80.0, self.args.hop * 2),
385
+ [1, 0],
386
+ )
387
+ ),
388
+ -2,
389
  ),
390
  cmap=None,
391
  )
 
393
  axs[2].set_title("Generated3")
394
  axs[3].imshow(
395
  np.flip(
396
+ np.array(
397
+ tf.transpose(
398
+ self.wv2spec_hop((abwv4[:, 0] + abwv4[:, 1]) / 2.0, 80.0, self.args.hop * 2),
399
+ [1, 0],
400
+ )
401
+ ),
402
+ -2,
403
  ),
404
  cmap=None,
405
  )
 
407
  axs[3].set_title("Generated4")
408
  # plt.show()
409
  plt.savefig(f"{path}/output.png")
410
+ plt.close()
411
 
 
412
  def save_end(
413
+ self,
414
+ epoch,
415
+ gloss,
416
+ closs,
417
+ mloss,
418
+ models_ls=None,
419
+ n_save=3,
420
+ save_path="checkpoints",
421
  ):
422
+ (critic, gen, enc, dec, enc2, dec2, gen_ema, [opt_dec, opt_disc], switch) = models_ls
423
  if epoch % n_save == 0:
424
  print("Saving...")
425
+ path = f"{save_path}/MUSIKA_iterations-{((epoch+1)*self.args.totsamples)//(self.args.bs*1000)}k_losses-{str(gloss)[:9]}-{str(closs)[:9]}-{str(mloss)[:9]}"
426
  os.mkdir(path)
427
  critic.save_weights(path + "/critic.h5")
 
428
  gen.save_weights(path + "/gen.h5")
429
  gen_ema.save_weights(path + "/gen_ema.h5")
430
  # enc.save_weights(path + "/enc.h5")
 
433
  # dec2.save_weights(path + "/dec2.h5")
434
  np.save(path + "/opt_dec.npy", opt_dec.get_weights())
435
  np.save(path + "/opt_disc.npy", opt_disc.get_weights())
436
+ np.save(path + "/switch.npy", switch.numpy())
437
  self.save_test_image_full(path, models_ls=models_ls)
438
 
439
  def truncated_normal(self, shape, bound=2.0, dtype=tf.float32):
440
  seed1, seed2 = random_seed.get_seed(tf.random.uniform((), tf.int32.min, tf.int32.max, dtype=tf.int32))
441
  return tf.random.stateless_parameterized_truncated_normal(shape, [seed1, seed2], 0.0, 1.0, -bound, bound)
442
 
443
def distribute_gen(self, x, model, bs=32):
    """Run ``model`` over ``x`` in batches of at most ``bs`` and join the outputs.

    Args:
        x: input whose first axis is the batch axis; ``x.shape[0]`` must be
            statically known.
        model: callable invoked as ``model(batch, training=False)``.
        bs: maximum number of samples per call.

    Returns:
        The per-batch outputs concatenated along axis 0 with ``tf.concat``.
    """
    bdim = x.shape[0]
    # Ceil division. The previous count, ((bdim - 2) // bs) + 1, was one batch
    # short whenever bdim == k * bs + 1 (k >= 1), silently dropping the last
    # sample, and needed a special case for bdim == 1.
    n_batches = (bdim + bs - 1) // bs
    outls = [model(x[i * bs : i * bs + bs], training=False) for i in range(n_batches)]
    return tf.concat(outls, 0)
451
 
452
def generate_waveform(self, inp, gen_ema, dec, dec2, batch_size=64):
    """Turn generator inputs into a clipped two-channel waveform array.

    Steps, in order:
      1. ``gen_ema`` is run over ``inp`` through ``distribute_gen``.
      2. The outputs are laid end to end along axis -2, then re-cut into
         pieces of 8 steps that are stacked on the batch axis.
      3. For each of the two channels, the matching ``args.latdepth``-wide
         slice of the last axis goes through ``distribute_dec2`` (``dec2``),
         is re-cut into pieces of ``args.shape`` steps, goes through
         ``distribute_dec`` (``dec``), and the two returned tensors are
         combined by ``conc_tog_specphase``.
      4. Both channel results are stacked on a new last axis, squeezed and
         clipped to ``[-1, 1]``.

    Args:
        inp: generator input batch.
        gen_ema: generator model.
        dec: decoder passed to ``distribute_dec``.
        dec2: decoder passed to ``distribute_dec2``.
        batch_size: batch size forwarded to all three ``distribute_*`` helpers.

    Returns:
        NumPy array with values in ``[-1.0, 1.0]``.
    """
    latdepth = self.args.latdepth

    gen_out = self.distribute_gen(inp, gen_ema, bs=batch_size)
    # Batch axis -> one long sequence along axis -2.
    joined = tf.concat(tf.split(gen_out, gen_out.shape[0], 0), -2)
    # Long sequence -> batch of 8-step pieces.
    latents = tf.concat(tf.split(joined, joined.shape[-2] // 8, -2), 0)

    channels = []
    for channel in range(2):
        lo = channel * latdepth
        decoded = self.distribute_dec2(
            latents[:, :, :, lo : lo + latdepth],
            dec2,
            bs=batch_size,
        )
        pieces = tf.split(decoded, decoded.shape[-2] // self.args.shape, -2)
        decoded = tf.concat(pieces, 0)

        ab_m, ab_p = self.distribute_dec(decoded, dec, bs=batch_size)
        channels.append(self.conc_tog_specphase(ab_m, ab_p))

    stacked = np.squeeze(np.stack(channels, -1))
    return np.clip(stacked, -1.0, 1.0)
 
 
 
 
 
476
 
477
+ def decode_waveform(self, lat, dec, dec2, batch_size=64):
 
 
 
 
 
 
 
478
 
479
+ lat = lat[:, :, : (lat.shape[-2] // 8) * 8, :]
480
+ abls = tf.split(lat, lat.shape[-2] // 8, -2)
 
 
 
481
  abi = tf.concat(abls, 0)
482
 
483
  chls = []
 
486
  ab = self.distribute_dec2(
487
  abi[:, :, :, channel * self.args.latdepth : channel * self.args.latdepth + self.args.latdepth],
488
  dec2,
489
+ bs=batch_size,
490
  )
 
491
  abls = tf.split(ab, ab.shape[-2] // self.args.shape, -2)
492
  ab = tf.concat(abls, 0)
493
 
494
+ ab_m, ab_p = self.distribute_dec(ab, dec, bs=batch_size)
495
  abwv = self.conc_tog_specphase(ab_m, ab_p)
496
  chls.append(abwv)
497
 
498
+ return np.clip(np.squeeze(np.stack(chls, -1)), -1.0, 1.0)
499
+
500
def get_noise_interp_multi(self, fac=1, var=2.0):
    """Build ``fac`` consecutive generator inputs from interpolated noise anchors.

    One shared vector of width ``args.coorddepth`` and
    ``3 + (fac - 1) // (coordlen // latlen)`` anchor vectors of width 64 are
    drawn with ``truncated_normal`` (``var`` is passed as its bound). Each
    anchor is concatenated with the shared vector. Neighbouring anchors are
    linearly interpolated over ``args.coordlen`` steps, the segments are
    joined, passed through ``center_coordinate``, trimmed by ``latlen // 4``
    steps at the start, cut to a multiple of ``args.latlen`` and split into
    ``latlen``-step windows.

    NOTE(review): the anchor width is the literal 64 while the shared width is
    ``args.coorddepth``; confirm this is intended when ``coorddepth != 64``.

    Args:
        fac: number of windows to return.
        var: truncation bound forwarded to ``truncated_normal``.

    Returns:
        The first ``fac`` windows concatenated along axis 0.
    """
    latlen = self.args.latlen
    coordlen = self.args.coordlen

    shared = self.truncated_normal([1, self.args.coorddepth], var, dtype=tf.float32)

    n_anchors = 3 + ((fac - 1) // (coordlen // latlen))
    anchors = []
    for _ in range(n_anchors):
        local = self.truncated_normal([1, 64], var, dtype=tf.float32)
        anchors.append(tf.concat([local, shared], -1))

    segments = []
    for left, right in zip(anchors, anchors[1:]):
        ramp = tf.linspace(left, right, coordlen + 1, axis=-2)
        # Drop the end point: it is the start point of the next segment.
        segments.append(ramp[:, :-1, :])
    seq = tf.concat(segments, -2)

    seq = self.center_coordinate(seq)
    seq = seq[:, latlen // 4 :, :]
    usable = (seq.shape[-2] // latlen) * latlen
    seq = seq[:, :usable, :]

    windows = tf.split(seq, seq.shape[-2] // latlen, -2)
    return tf.concat(windows[:fac], 0)
524
+
525
def get_noise_interp_loop(self, fac=1, var=2.0):
    """Build ``fac`` generator inputs that alternate between two noise anchors.

    One shared vector of width ``args.coorddepth`` and two anchor vectors of
    width 64 are drawn with ``truncated_normal`` (``var`` is passed as its
    bound). Each anchor is concatenated with the shared vector. The two anchors
    are repeated alternately ``fac + 2`` times, neighbouring anchors are
    linearly interpolated over ``args.latlen // 2`` steps, and the joined
    sequence is passed through ``center_coordinate``, trimmed by
    ``latlen // 2`` steps at the start, cut to a multiple of ``args.latlen``
    and split into ``latlen``-step windows.

    Args:
        fac: number of windows to return.
        var: truncation bound forwarded to ``truncated_normal``.

    Returns:
        The first ``fac`` windows concatenated along axis 0.
    """
    latlen = self.args.latlen

    noiseg = self.truncated_normal([1, self.args.coorddepth], var, dtype=tf.float32)

    # Exactly two anchors; the sequence keeps going back and forth between them.
    noisels_pre = [
        tf.concat([self.truncated_normal([1, 64], var, dtype=tf.float32), noiseg], -1) for _ in range(2)
    ]
    noisels = noisels_pre * (fac + 2)  # [a, b, a, b, ...]

    rls = tf.concat(
        [
            # Drop each segment's end point: it is the next segment's start.
            tf.linspace(left, right, latlen // 2 + 1, axis=-2)[:, :-1, :]
            for left, right in zip(noisels, noisels[1:])
        ],
        -2,
    )

    rls = self.center_coordinate(rls)
    rls = rls[:, latlen // 2 :, :]
    rls = rls[:, : (rls.shape[-2] // latlen) * latlen, :]

    rls = tf.split(rls, rls.shape[-2] // latlen, -2)

    return tf.concat(rls[:fac], 0)
550
+
551
def generate(self, models_ls):
    """Generate ``args.num_samples`` audio files and write them as WAV.

    For each sample, noise from ``get_noise_interp_multi`` is turned into a
    waveform by ``generate_waveform`` and written to
    ``{args.save_path}/{index}_{timestamp}.wav`` at rate ``args.sr``,
    truncated to ``args.seconds * args.sr`` samples. The output directory is
    created if missing.

    Args:
        models_ls: 9-item sequence; only items 3 (``dec``), 5 (``dec2``) and
            6 (``gen_ema``) are used here.
    """
    _, _, _, dec, _, dec2, gen_ema, [_, _], _ = models_ls
    args = self.args

    os.makedirs(args.save_path, exist_ok=True)
    # NOTE(review): 23 is hard-coded, presumably the seconds of audio per
    # generated window. The gradio UI in this file lists 11s per window when
    # args.small is set -- confirm this is still right for small models.
    fac = (args.seconds // 23) + 1
    print(f"Generating {args.num_samples} samples...")
    for idx in tqdm(range(args.num_samples)):
        noise = self.get_noise_interp_multi(fac, args.truncation)
        wv = self.generate_waveform(noise, gen_ema, dec, dec2, batch_size=64)
        dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        out_file = f"{args.save_path}/{idx}_{dt}.wav"
        write_wav(out_file, args.sr, np.squeeze(wv)[: args.seconds * args.sr])
564
+
565
def decode_path(self, models_ls):
    """Decode every ``.npy`` file in ``args.files_path`` to a WAV file.

    Each array is loaded, given two leading axes of size 1, decoded with
    ``decode_waveform`` and written to ``{args.save_path}/{name}.wav`` at rate
    ``args.sr``, where ``name`` is the input file name without its extension.
    The output directory is created if missing.

    Args:
        models_ls: 9-item sequence; only items 3 (``dec``) and 5 (``dec2``)
            are used here.
    """
    _, _, _, dec, _, dec2, _, [_, _], _ = models_ls
    os.makedirs(self.args.save_path, exist_ok=True)
    pathls = glob(self.args.files_path + "/*.npy")
    print(f"Decoding {len(pathls)} samples...")
    for p in tqdm(pathls):
        bname = os.path.splitext(os.path.basename(p))[0]
        # SECURITY: allow_pickle=True lets a crafted .npy file run arbitrary
        # code on load. Only point files_path at trusted files.
        lat = np.load(p, allow_pickle=True)
        lat = tf.expand_dims(tf.expand_dims(lat, 0), 0)
        wv = self.decode_waveform(lat, dec, dec2, batch_size=64)
        write_wav(f"{self.args.save_path}/{bname}.wav", self.args.sr, np.squeeze(wv))
579
+
580
+ def stfunc(self, genre, z, var, models_ls_1, models_ls_2, models_ls_3):
581
+
582
+ critic, gen, enc, dec, enc2, dec2, gen_ema_1, [opt_dec, opt_disc], switch = models_ls_1
583
+ critic, gen, enc, dec, enc2, dec2, gen_ema_2, [opt_dec, opt_disc], switch = models_ls_2
584
+ critic, gen, enc, dec, enc2, dec2, gen_ema_3, [opt_dec, opt_disc], switch = models_ls_3
585
+
586
+ if genre == 0:
587
+ gen_ema = gen_ema_1
588
+ elif genre == 1:
589
+ gen_ema = gen_ema_2
590
+ else:
591
+ gen_ema = gen_ema_3
592
+
593
+ var = float(var)
594
+
595
+ if z == 0:
596
+ fac = 1
597
+ elif z == 1:
598
+ fac = 5
599
+ else:
600
+ fac = 10
601
+
602
+ bef = time.time()
603
+
604
+ noiseinp = self.get_noise_interp_multi(fac, var)
605
+
606
+ abwvc = self.generate_waveform(noiseinp, gen_ema, dec, dec2, batch_size=64)
607
+
608
+ # print(
609
+ # f"Time for complete generation pipeline: {time.time()-bef} s {int(np.round((fac*23.)/(time.time()-bef)))}x faster than Real Time!"
610
+ # )
611
+
612
  spec = np.flip(
613
  np.array(
614
  tf.transpose(
615
+ self.wv2spec_hop(
616
+ (abwvc[: 23 * self.args.sr, 0] + abwvc[: 23 * self.args.sr, 1]) / 2.0, 80.0, self.args.hop * 2
617
+ ),
618
  [1, 0],
619
  )
620
  ),
 
622
  )
623
 
624
  return (
625
+ np.clip(spec, -1.0, 1.0),
626
  (self.args.sr, np.int16(abwvc * 32767.0)),
627
  )
628
 
629
+ def render_gradio(self, models_ls_1, models_ls_2, models_ls_3, train=True):
630
+ article_text = "Original work by Marco Pasini ([Twitter](https://twitter.com/marco_ppasini)) at the Institute of Computational Perception, JKU Linz. Supervised by Jan Schlüter."
631
 
632
+ def gradio_func(genre, x, y):
633
+ return self.stfunc(genre, x, y, models_ls_1, models_ls_2, models_ls_3)
634
+
635
+ if self.args.small:
636
+ durations = ["11s", "59s", "1m 58s"]
637
+ durations_default = "59s"
638
+ else:
639
+ durations = ["23s", "1m 58s", "3m 57s"]
640
+ durations_default = "1m 58s"
641
 
642
  iface = gr.Interface(
643
  fn=gradio_func,
644
  inputs=[
645
+ gr.Radio(
646
+ choices=["Techno/Experimental", "Death Metal (finetuned)", "Misc"],
647
  type="index",
648
+ value="Techno/Experimental",
649
  label="Music Genre to Generate",
650
  ),
651
+ gr.Radio(
652
+ choices=durations,
653
+ type="index",
654
+ value=durations_default,
655
+ label="Generated Music Length",
656
  ),
657
+ gr.Slider(
658
+ minimum=0.1,
659
+ maximum=3.9,
660
+ step=0.1,
661
+ value=1.8,
662
+ label="How much do you want the music style to be varied? (Stddev truncation for random vectors)",
663
  ),
664
  ],
665
  outputs=[
666
+ gr.Image(label="Log-MelSpectrogram of Generated Audio (first 23 s)"),
667
+ gr.Audio(type="numpy", label="Generated Audio"),
668
  ],
 
669
  title="musika!",
670
+ description="Blazingly Fast 44.1 kHz Stereo Waveform Music Generation of Arbitrary Length. Be patient and enjoy the weirdness!",
671
  article=article_text,
 
 
672
  )
673
 
674
  print("--------------------------------")