add out-painting feature
- README.md +5 -1
- audiodiffusion/__init__.py +32 -8
- audiodiffusion/mel.py +13 -7
- notebooks/test_model.ipynb +71 -38
README.md
CHANGED
@@ -15,8 +15,12 @@ license: gpl-3.0
 
 ---
 
-**
+**UPDATES**:
 
+4/10/2022
+It is now possible to mask parts of the input audio during generation which means you can stitch several samples together (think "out-painting").
+
+27/9/2022
 You can now generate an audio based on a previous one. You can use this to generate variations of the same audio or even to "remix" a track (via a sort of "style transfer"). You can find examples of how to do this in the [`test_model.ipynb`](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/test_model.ipynb) notebook.
 
 ---
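As a rough illustration of the new feature, the loop below stitches generated slices together by conditioning each generation on the tail of the previous slice and freezing ("masking") that tail during de-noising. This is a minimal sketch based on the notebook diff below: `generate_spectrogram_and_audio_from_audio`, `raw_audio` and `mask_start_secs` appear in this commit, while the constructor argument and the input file name are assumptions for illustration.

```python
import numpy as np

from audiodiffusion import AudioDiffusion

# Assumption: AudioDiffusion is built from a pretrained model id; the
# exact constructor signature is not part of this diff.
audio_diffusion = AudioDiffusion("teticio/audio-diffusion-256")

# Start from an existing audio file (hypothetical path).
image, (sample_rate, audio) = \
    audio_diffusion.generate_spectrogram_and_audio_from_audio(
        audio_file="example.wav")

overlap_secs = 2
overlap_samples = overlap_secs * sample_rate
track = audio
for _ in range(3):
    # Condition on the tail of the previous slice and mask it, so the
    # newly generated slice continues seamlessly from the old one.
    _, (sample_rate, audio2) = \
        audio_diffusion.generate_spectrogram_and_audio_from_audio(
            raw_audio=audio[-overlap_samples:],
            mask_start_secs=overlap_secs)
    track = np.concatenate([track, audio2[overlap_samples:]])
    audio = audio2
```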
audiodiffusion/__init__.py
CHANGED
@@ -9,7 +9,7 @@ from diffusers import DDPMPipeline, DDPMScheduler
 
 from .mel import Mel
 
-VERSION = "1.1.
+VERSION = "1.1.4"
 
 
 class AudioDiffusion:
@@ -61,7 +61,9 @@ class AudioDiffusion:
             slice: int = 0,
             start_step: int = 0,
             steps: int = None,
-            generator: torch.Generator = None
+            generator: torch.Generator = None,
+            mask_start_secs: float = 0,
+            mask_end_secs: float = 0
     ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
         """Generate random mel spectrogram from audio input and convert to audio.
 
@@ -72,6 +74,8 @@
             start_step (int): step to start from
             steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
             generator (torch.Generator): random number generator or None
+            mask_start_secs (float): number of seconds of audio to mask (not generate) at start
+            mask_end_secs (float): number of seconds of audio to mask (not generate) at end
 
         Returns:
             PIL Image: mel spectrogram
@@ -84,31 +88,51 @@
             steps = self.ddpm.scheduler.num_train_timesteps
         scheduler = DDPMScheduler(num_train_timesteps=steps)
         scheduler.set_timesteps(steps)
-        images = torch.randn(
+        mask = None
+        images = noise = torch.randn(
             (1, self.ddpm.unet.in_channels, self.ddpm.unet.sample_size,
              self.ddpm.unet.sample_size),
             generator=generator,
         )
+
         if audio_file is not None or raw_audio is not None:
             self.mel.load_audio(audio_file, raw_audio)
             input_image = self.mel.audio_slice_to_image(slice)
             input_image = np.frombuffer(input_image.tobytes(),
                                         dtype="uint8").reshape(
-                                            (input_image.width,
-                                             input_image.height))
+                                            (input_image.height,
+                                             input_image.width))
             input_image = ((input_image / 255) * 2 - 1)
+
             if start_step > 0:
                 images[0, 0] = scheduler.add_noise(
                     torch.tensor(input_image[np.newaxis, np.newaxis, :]),
-                    images, steps - start_step)
+                    noise, steps - start_step)
+
+            mask_start = int(mask_start_secs * self.mel.get_sample_rate() /
+                             self.mel.hop_length)
+            mask_end = int(mask_end_secs * self.mel.get_sample_rate() /
+                           self.mel.hop_length)
+            mask = scheduler.add_noise(
+                torch.tensor(input_image[np.newaxis, np.newaxis, :]), noise,
+                scheduler.timesteps[start_step:])
 
         images = images.to(self.ddpm.device)
-        for t in self.progress_bar(scheduler.timesteps[start_step:]):
+        for step, t in enumerate(
+                self.progress_bar(scheduler.timesteps[start_step:])):
             model_output = self.ddpm.unet(images, t)['sample']
             images = scheduler.step(model_output,
                                     t,
                                     images,
                                     generator=generator)['prev_sample']
+
+            if mask is not None:
+                if mask_start > 0:
+                    images[0, 0, :, :mask_start] = mask[step,
+                                                        0, :, :mask_start]
+                if mask_end > 0:
+                    images[0, 0, :, -mask_end:] = mask[step, 0, :, -mask_end:]
+
         images = (images / 2 + 0.5).clamp(0, 1)
         images = images.cpu().permute(0, 2, 3, 1).numpy()
 
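The mask width is measured in spectrogram columns: each column of the mel spectrogram covers `hop_length` audio samples, so a duration in seconds maps to `secs * sample_rate / hop_length` columns, exactly as in the `mask_start` and `mask_end` lines above. A standalone sketch of that conversion, with illustrative constants (the method reads them from its `Mel` instance):

```python
# Illustrative values; the pipeline reads these from its Mel instance.
sample_rate = 22050  # librosa's default rate when loading with mono=True
hop_length = 512

def secs_to_columns(secs: float) -> int:
    # Each spectrogram column spans hop_length audio samples.
    return int(secs * sample_rate / hop_length)

# Masking 2 seconds at the start freezes the first ~86 columns of the
# generated spectrogram on every de-noising step.
print(secs_to_columns(2.0))  # 86
```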
audiodiffusion/mel.py
CHANGED
@@ -37,7 +37,7 @@ class Mel:
         self.slice_size = self.x_res * self.hop_length - 1
         self.fmax = self.sr / 2
         self.top_db = top_db
-        self.
+        self.audio = None
 
     def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
         """Load audio.
@@ -47,11 +47,16 @@
             raw_audio (np.ndarray): audio as numpy array
         """
         if audio_file is not None:
-            self.
-                audio_file,
-                mono=True)
+            self.audio, _ = librosa.load(audio_file, mono=True)
         else:
-            self.
+            self.audio = raw_audio
+
+        # Pad with silence if necessary.
+        if len(self.audio) < self.x_res * self.hop_length:
+            self.audio = np.concatenate([
+                self.audio,
+                np.zeros((self.x_res * self.hop_length - len(self.audio), ))
+            ])
 
     def get_number_of_slices(self) -> int:
         """Get number of slices in audio.
@@ -59,7 +64,7 @@
         Returns:
             int: number of spectograms audio can be sliced into
         """
-        return len(self.
+        return len(self.audio) // self.slice_size
 
     def get_audio_slice(self, slice: int = 0) -> np.ndarray:
         """Get slice of audio.
@@ -70,7 +75,8 @@
         Returns:
             np.ndarray: audio as numpy array
         """
-        return self.
+        return self.audio[self.slice_size * slice:self.slice_size *
+                          (slice + 1)]
 
     def get_sample_rate(self) -> int:
         """Get sample rate:
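The padding added to `load_audio` guarantees the loaded audio covers at least one full slice, so `get_number_of_slices` never returns zero for short inputs such as the overlap tails passed in during out-painting. A standalone sketch of that arithmetic, using stand-in values for `x_res` and `hop_length`:

```python
import numpy as np

# Stand-in values for Mel's x_res and hop_length (illustrative only).
x_res, hop_length = 256, 512
slice_size = x_res * hop_length - 1  # as set in Mel.__init__

audio = np.zeros(100_000)  # shorter than one slice (256 * 512 = 131072)
if len(audio) < x_res * hop_length:
    # Pad with silence up to one full slice, mirroring load_audio.
    audio = np.concatenate(
        [audio, np.zeros((x_res * hop_length - len(audio),))])

print(len(audio))                # 131072
print(len(audio) // slice_size)  # 1 -- at least one slice is guaranteed
```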
notebooks/test_model.ipynb
CHANGED
@@ -10,7 +10,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 
+    "execution_count": null,
     "id": "6c7800a6",
     "metadata": {},
     "outputs": [],
@@ -27,7 +27,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 
+    "execution_count": null,
     "id": "b447e2c4",
     "metadata": {},
     "outputs": [],
@@ -39,7 +39,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 
+    "execution_count": null,
     "id": "c2fc0e7a",
     "metadata": {},
     "outputs": [],
@@ -63,7 +63,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 
+    "execution_count": null,
     "id": "97f24046",
     "metadata": {},
     "outputs": [],
@@ -79,7 +79,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 
+    "execution_count": null,
     "id": "a3d45c36",
     "metadata": {},
     "outputs": [],
@@ -169,6 +169,39 @@
     "display(Audio(track, rate=sample_rate))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "11f91ad3",
+   "metadata": {},
+   "source": [
+    "### Generate continuations (\"out-painting\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "756d7af5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "overlap_secs = 2 #@param {type:\"integer\"}\n",
+    "start_step = 0 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
+    "overlap_samples = overlap_secs * sample_rate\n",
+    "track = audio\n",
+    "for variation in range(12):\n",
+    "    image2, (\n",
+    "        sample_rate, audio2\n",
+    "    ) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
+    "        raw_audio=audio[-overlap_samples:],\n",
+    "        start_step=start_step,\n",
+    "        mask_start_secs=overlap_secs)\n",
+    "    display(image2)\n",
+    "    display(Audio(audio2, rate=sample_rate))\n",
+    "    track = np.concatenate([track, audio2[overlap_samples:]])\n",
+    "    audio = audio2\n",
+    "display(Audio(track, rate=sample_rate))"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "b6434d3f",
@@ -182,12 +215,12 @@
    "id": "0da030b2",
    "metadata": {},
    "source": [
-    "Alternatively, you can start from another audio altogether, resulting in a kind of style transfer."
+    "Alternatively, you can start from another audio altogether, resulting in a kind of style transfer. Maintaining the same seed during generation fixes the style, while masking helps stitch consecutive segments together more smoothly."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 
+   "execution_count": null,
    "id": "fc620a80",
    "metadata": {},
    "outputs": [],
@@ -207,41 +240,31 @@
    "metadata": {
     "scrolled": true
    },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "6e741e6bd196458fa38f86197bd16378",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "  0%|          | 0/500 [00:00<?, ?it/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "
-    "
-    "
+    "start_step = 500 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
+    "overlap_secs = 1 #@param {type:\"integer\"}\n",
+    "overlap_samples = overlap_secs * sample_rate\n",
+    "mel.load_audio(audio_file)\n",
+    "slice_size = audio_diffusion.mel.x_res * audio_diffusion.mel.hop_length\n",
+    "stride = slice_size - overlap_samples\n",
    "generator = torch.Generator()\n",
    "seed = generator.seed()\n",
-    "
+    "track = np.array([])\n",
+    "for sample in range(len(mel.audio) // stride):\n",
    "    generator.manual_seed(seed)\n",
-    "    audio = 
-    "
-    "
-    "
-    "
-    "
-    "
-    "
+    "    audio = mel.audio[sample * stride:sample * stride + slice_size]\n",
+    "    if len(track) > 0:\n",
+    "        audio[:overlap_samples] = audio2[-overlap_samples:]\n",
+    "    _, (sample_rate,\n",
+    "     audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
+    "        raw_audio=audio,\n",
+    "        start_step=start_step,\n",
+    "        generator=generator,\n",
+    "        mask_start_secs=1 if len(track) > 0 else 0)\n",
    "    display(Audio(audio, rate=sample_rate))\n",
    "    display(Audio(audio2, rate=sample_rate))\n",
-    "    track = np.concatenate([track, audio2])"
+    "    track = np.concatenate([track, audio2[overlap_samples:]])"
    ]
   },
   {
@@ -307,7 +330,17 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "
+   "id": "df112a72",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "len(audio) / mel.hop_length"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ad467206",
    "metadata": {},
    "outputs": [],
    "source": []
@@ -334,7 +367,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.
+   "version": "3.10.6"
   },
   "toc": {
    "base_numbering": 1,
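The remix cell above walks the source audio with a stride of `slice_size - overlap_samples`, re-seeds the generator on every iteration so the "style" stays fixed, and overwrites the first second of each input slice with the tail of the previous output before masking it. A standalone sketch of just the stride and overlap bookkeeping, with illustrative sizes (the notebook derives them from `Mel` and the `#@param` widgets):

```python
import numpy as np

# Illustrative sizes; the notebook reads these from Mel and the widgets.
sample_rate = 22050
slice_size = 256 * 512               # x_res * hop_length
overlap_samples = 1 * sample_rate    # overlap_secs = 1
stride = slice_size - overlap_samples

source = np.random.randn(5 * slice_size)  # stand-in for mel.audio
for sample in range(len(source) // stride):
    chunk = source[sample * stride:sample * stride + slice_size]
    # Consecutive chunks share overlap_samples of audio, which the
    # generation step re-uses (masked) to stitch outputs smoothly.
    # The final chunk may be shorter; Mel.load_audio pads it with silence.
    print(sample, sample * stride, sample * stride + slice_size, len(chunk))
```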