Spaces: Running on Zero
Stable Audio Open + progbars + mp3 + batched forward + cleanup
- .gitattributes +1 -0
- Examples/{Beethoven.wav → Beethoven.mp3} +2 -2
- Examples/{Cat_dog.wav → Beethoven_arcade.mp3} +2 -2
- Examples/{Beethoven_arcade.wav → Beethoven_piano.mp3} +2 -2
- Examples/{Beethoven_piano.wav → Beethoven_rock.mp3} +2 -2
- Examples/{Cat.wav → Cat.mp3} +2 -2
- Examples/Cat_dog.mp3 +3 -0
- Examples/ModalJazz.mp3 +3 -0
- Examples/ModalJazz.wav +0 -3
- Examples/ModalJazz_banjo.mp3 +3 -0
- Examples/ModalJazz_banjo.wav +0 -3
- Examples/Shadows.mp3 +3 -0
- Examples/Shadows_arcade.mp3 +3 -0
- README.md +4 -1
- app.py +235 -158
- inversion_utils.py +139 -381
- models.py +469 -253
- requirements.txt +3 -2
- utils.py +50 -16
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 *.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
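The new rule routes MP3 files through Git LFS, so the repository stores small pointer files instead of the audio itself; the renamed example files below show the pointer format (version, oid sha256, size). A minimal sketch of reading such a pointer in Python, where the helper name and the assertion are illustrative and not part of this repo:

from pathlib import Path

def read_lfs_pointer(path: str) -> dict:
    # Parse a Git LFS pointer file into its version/oid/size fields.
    fields = {}
    for line in Path(path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    assert fields.get("version", "").startswith("https://git-lfs.github.com/spec/")
    return {"oid": fields["oid"], "size": int(fields["size"])}

# e.g. read_lfs_pointer("Examples/Beethoven.mp3") -> {'oid': 'sha256:3dcc79fe...', 'size': 1097142}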
Examples/{Beethoven.wav → Beethoven.mp3}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:3dcc79fe071d118df3caaeeb85d7944f93a5df40bbdb72a26b67bd57da2af7c5
+size 1097142
Examples/{Cat_dog.wav → Beethoven_arcade.mp3}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:542bd61d9cc1723ccfd9bfc06b0818e77fc763013827ff1f9289e2ac6a912904
+size 563040
Examples/{Beethoven_arcade.wav → Beethoven_piano.mp3}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:000d82c39d8c41b10188d328e29cb1baa948232bacd693f22e297cc54f4bb707
+size 563040
Examples/{Beethoven_piano.wav → Beethoven_rock.mp3}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c51d75c9094a50c7892449a013b32ffde266a5abd6dad9f00bf3aeec0ee935ee
+size 1097142
Examples/{Cat.wav → Cat.mp3}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:cff7010e5fb12a57508c7a0941663f1a12bfc8b3b3d01d0973359cd42ae5eb1e
+size 402542
Examples/Cat_dog.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72ff727243606215c934552e946f7d97b5e2e39c4d6263f7f36659e3f39f3008
+size 207403
Examples/ModalJazz.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34cf145b84b6b4669050ca42932fb74ac0f28aabbe6c665f12a877c9809fa9c6
+size 4153468
Examples/ModalJazz.wav
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:846a77046d21ebc3996841404eede9d56797c82b3414025e1ccafe586eaf2959
-size 9153322
Examples/ModalJazz_banjo.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11680068427556981aa6304e6c11bd05debc820ca581c248954c1ffe3cd94569
+size 2128320
Examples/ModalJazz_banjo.wav
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:122e0078c0bf2fc96425071706fe0e8674c93cc1d2787fd02c0e2c0f12de5cc5
-size 6802106
Examples/Shadows.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2e0cab2ebda4507641d6a1b5d9b2d888a7526581b7de48540ebf86ce00579908
+size 1342693
Examples/Shadows_arcade.mp3
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:68c84805ea17d0697cd79bc85394754d70fb02f740db4bee4c6ccbb5269a5d84
+size 1342693
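The example clips were also converted from WAV to MP3, which shrinks the LFS payloads considerably (ModalJazz drops from roughly 9.2 MB to 4.2 MB). The commit does not include the conversion step itself; a minimal sketch of one way to do it with pydub, an assumption since neither pydub nor ffmpeg is a dependency of this Space:

from pathlib import Path
from pydub import AudioSegment  # assumed helper library; needs ffmpeg on the PATH

# Hypothetical one-off conversion script; the tool actually used is not shown in the commit.
for wav_path in Path("Examples").glob("*.wav"):
    AudioSegment.from_wav(wav_path).export(wav_path.with_suffix(".mp3"), format="mp3", bitrate="192k")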
README.md
CHANGED
@@ -9,7 +9,10 @@ app_file: app.py
 pinned: false
 license: cc-by-sa-4.0
 short_description: Edit audios with text prompts
+hf_oauth: true
+hf_oauth_scopes:
+- read-repos
 ---

 The 30-second limit was introduced to ensure that queue wait times remain reasonable, especially when there are a lot of users.
-For that reason pull-requests that change this limit will not be merged. Please clone or duplicate the space to work locally without limits.
+For that reason pull-requests that change this limit will not be merged. Please clone or duplicate the space to work locally without limits.
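The front-matter additions enable Hugging Face OAuth for the Space with the read-repos scope; app.py (below) uses the resulting token to check that a logged-in user has accepted the Stable Audio Open license before the gated checkpoint is used. A minimal sketch of that pattern, reduced from the new verify_model_params function and not a drop-in copy of it:

import gradio as gr
import huggingface_hub

REPO = "stabilityai/stable-audio-open-1.0"

def check_access(oauth_token: gr.OAuthToken | None):
    # With hf_oauth enabled, Gradio injects the user's token into any event
    # handler that declares a gr.OAuthToken parameter.
    if oauth_token is None:
        raise gr.Error("Please log in first.")
    try:
        huggingface_hub.get_hf_file_metadata(
            huggingface_hub.hf_hub_url(REPO, "transformer/config.json"),
            token=oauth_token.token)
    except huggingface_hub.errors.GatedRepoError:
        raise gr.Error("Accept the model license on its Hub page to get access.")

with gr.Blocks() as demo:
    gr.LoginButton()
    gr.Button("Check access").click(check_access, inputs=None, outputs=None)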
app.py
CHANGED
@@ -6,27 +6,26 @@ if os.getenv('SPACES_ZERO_GPU') == "true":
-
-from typing import Optional
-# current_loaded_model = "cvssp/audioldm2-music"
-# # current_loaded_model = "cvssp/audioldm2-music"
-
-# ldm_stable = load_model(current_loaded_model, device, 200) # deafult model
@@ -36,89 +35,136 @@ def randomize_seed_fn(seed, randomize_seed):
-def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src…
-    _, zs, wts = inversion_forward_process(ldm_stable, w0, etas=1,…
-def sample(ldm_stable, zs, wts,…
-    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, skips=steps - skip,
-    if x0_dec.dim() < 4:
-        x0_dec = x0_dec[None, :, :, :]
-        factor = …
-        factor = …
-    duration = min(utils.get_duration(input_audio), 30)
-@spaces.GPU(duration=get_duration)
-def edit(
-    # cache_dir,
-    input_audio,
-    model_id: str,
-    do_inversion: bool,
-    # wtszs_file: str,
-    wts: Optional[torch.Tensor], zs: Optional[torch.Tensor],
-    saved_inv_model: str,
-    source_prompt="",
-    target_prompt="",
-    steps=200,
-    cfg_scale_src=3.5,
-    cfg_scale_tar=12,
-    t_start=45,
-    randomize_seed=True):
@@ -130,102 +176,126 @@ def edit(
-    x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device…
-    # if not (do_inversion or randomize_seed):
-    #     if not os.path.exists(wtszs_file):
-    #         do_inversion = True
-    # Too much time has passed
-        zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,…
-        # wtszs_file = f.name
-        # wtszs_file = gr.State(value=f.name)
-        # wts = gr.State(value=wts_tensor)
-        # demo.move_resource_to_block_cache(f.name)
-        # wtszs = torch.load(wtszs_file, map_location=device)
-        # # wtszs = torch.load(wtszs_file.f, map_location=device)
-        # wts_tensor = wtszs['wts']
-        # zs_tensor = wtszs['zs']
-    output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt,
-                    tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar)
-    return output, wts.cpu(), zs.cpu(), saved_inv_model, do_inversion
-        ['Examples/Beethoven.…
-        'Examples/Beethoven_arcade.…
-        ['Examples/Beethoven.…
-        'Examples/Beethoven_piano.…
-        ['Examples/ModalJazz.…
-        'Examples/ModalJazz_banjo.…
-        ['Examples/…
-        'Examples/Cat_dog.…
-<h1 style="font-weight:…
-<h2 style="font-weight:…
-<p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
-<img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"…
@@ -233,22 +303,27 @@ help = """
-<li>Make sure that you use…
-For example, use…
-For unlimited length, duplicated the space, and…
-<code style="display:inline; background-color: lightgrey;…
@@ -267,23 +342,22 @@ with gr.Blocks(css='style.css') as demo: #, delete_cache=(3600, 3600)) as demo:
-    # cache_dir = gr.State(demo.GRADIO_CACHE)
-    # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
-    # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
-    # ldm_stable = gr.State(value=ldm_stable)
-        gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed…
@@ -293,17 +367,16 @@
-                               info="Choose a checkpoint suitable for your…
@@ -311,58 +384,62 @@
-        with gr.Row():
-            steps = gr.Number(value=50, step=1, minimum=…
-                              label="Num Diffusion Steps", interactive=True, scale=…
-        with gr.Row():
 import gradio as gr
 import random
 import torch
+import os
 from torch import inference_mode
+from typing import Optional, List
 import numpy as np
 from models import load_model
 import utils
 import spaces
+import huggingface_hub
 from inversion_utils import inversion_forward_process, inversion_reverse_process


 LDM2 = "cvssp/audioldm2"
 MUSIC = "cvssp/audioldm2-music"
 LDM2_LARGE = "cvssp/audioldm2-large"
+STABLEAUD = "stabilityai/stable-audio-open-1.0"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ldm2 = load_model(model_id=LDM2, device=device)
 ldm2_large = load_model(model_id=LDM2_LARGE, device=device)
 ldm2_music = load_model(model_id=MUSIC, device=device)
+ldm_stableaud = load_model(model_id=STABLEAUD, device=device, token=os.getenv('PRIV_TOKEN'))


 def randomize_seed_fn(seed, randomize_seed):

     return seed


+def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src, duration, save_compute):
     # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)

     with inference_mode():
         w0 = ldm_stable.vae_encode(x0)

     # find Zs and wts - forward process
+    _, zs, wts, extra_info = inversion_forward_process(ldm_stable, w0, etas=1,
+                                                       prompts=[prompt_src],
+                                                       cfg_scales=[cfg_scale_src],
+                                                       num_inference_steps=num_diffusion_steps,
+                                                       numerical_fix=True,
+                                                       duration=duration,
+                                                       save_compute=save_compute)
+    return zs, wts, extra_info


+def sample(ldm_stable, zs, wts, extra_info, prompt_tar, tstart, cfg_scale_tar, duration, save_compute):
     # reverse process (via Zs and wT)
     tstart = torch.tensor(tstart, dtype=torch.int)
+    w0, _ = inversion_reverse_process(ldm_stable, xT=wts, tstart=tstart,
                                       etas=1., prompts=[prompt_tar],
                                       neg_prompts=[""], cfg_scales=[cfg_scale_tar],
+                                      zs=zs[:int(tstart)],
+                                      duration=duration,
+                                      extra_info=extra_info,
+                                      save_compute=save_compute)

     # vae decode image
     with inference_mode():
         x0_dec = ldm_stable.vae_decode(w0)

+    if 'stable-audio' not in ldm_stable.model_id:
+        if x0_dec.dim() < 4:
+            x0_dec = x0_dec[None, :, :, :]

+        with torch.no_grad():
+            audio = ldm_stable.decode_to_mel(x0_dec)
+    else:
+        audio = x0_dec.squeeze(0).T
+
+    return (ldm_stable.get_sr(), audio.squeeze().cpu().numpy())
+
+
+def get_duration(input_audio,
+                 model_id: str,
+                 do_inversion: bool,
+                 wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
+                 saved_inv_model: str,
+                 source_prompt: str = "",
+                 target_prompt: str = "",
+                 steps: int = 200,
+                 cfg_scale_src: float = 3.5,
+                 cfg_scale_tar: float = 12,
+                 t_start: int = 45,
+                 randomize_seed: bool = True,
+                 save_compute: bool = True,
+                 oauth_token: Optional[gr.OAuthToken] = None):
     if model_id == LDM2:
+        factor = 1
     elif model_id == LDM2_LARGE:
+        factor = 2.5
+    elif model_id == STABLEAUD:
+        factor = 3.2
     else:  # MUSIC
         factor = 1

+    forwards = 0
     if do_inversion or randomize_seed:
+        forwards = steps if source_prompt == "" else steps * 2  # x2 when there is a prompt text
+    forwards += int(t_start / 100 * steps) * 2
+
+    duration = min(utils.get_duration(input_audio), utils.MAX_DURATION)
+    time_for_maxlength = factor * forwards * 0.15  # 0.25 is the time per forward pass
+    print('expected time:', time_for_maxlength / utils.MAX_DURATION * duration)
+
+    spare_time = 5
+    return max(10, time_for_maxlength / utils.MAX_DURATION * duration + spare_time)
+

+def verify_model_params(model_id: str, input_audio, src_prompt: str, tar_prompt: str, cfg_scale_src: float,
+                        oauth_token: gr.OAuthToken | None):
     if input_audio is None:
         raise gr.Error('Input audio missing!')

+    if tar_prompt == "":
+        raise gr.Error("Please provide a target prompt to edit the audio.")
+
+    if src_prompt != "":
+        if model_id == STABLEAUD and cfg_scale_src != 1:
+            gr.Info("Consider using Source Guidance Scale=1 for Stable Audio Open 1.0.")
+        elif model_id != STABLEAUD and cfg_scale_src != 3:
+            gr.Info(f"Consider using Source Guidance Scale=3 for {model_id}.")
+
+    if model_id == STABLEAUD:
+        if oauth_token is None:
+            raise gr.Error("You must be logged in to use Stable Audio Open 1.0. Please log in and try again.")
+        try:
+            huggingface_hub.get_hf_file_metadata(huggingface_hub.hf_hub_url(STABLEAUD, 'transformer/config.json'),
+                                                 token=oauth_token.token)
+            print('Has Access')
+        # except huggingface_hub.utils._errors.GatedRepoError:
+        except huggingface_hub.errors.GatedRepoError:
+            raise gr.Error("You need to accept the license agreement to use Stable Audio Open 1.0. "
+                           "Visit the <a href='https://huggingface.co/stabilityai/stable-audio-open-1.0'>"
+                           "model page</a> to get access.")

+@spaces.GPU(duration=get_duration)
+def edit(input_audio,
+         model_id: str,
+         do_inversion: bool,
+         wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], extra_info: Optional[List],
+         saved_inv_model: str,
+         source_prompt: str = "",
+         target_prompt: str = "",
+         steps: int = 200,
+         cfg_scale_src: float = 3.5,
+         cfg_scale_tar: float = 12,
+         t_start: int = 45,
+         randomize_seed: bool = True,
+         save_compute: bool = True,
+         oauth_token: Optional[gr.OAuthToken] = None):
     print(model_id)
     if model_id == LDM2:
         ldm_stable = ldm2
     elif model_id == LDM2_LARGE:
         ldm_stable = ldm2_large
+    elif model_id == STABLEAUD:
+        ldm_stable = ldm_stableaud
     else:  # MUSIC
         ldm_stable = ldm2_music

     if input_audio is None:
         raise gr.Error('Input audio missing!')
+    x0, _, duration = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device,
+                                       stft=('stable-audio' not in ldm_stable.model_id), model_sr=ldm_stable.get_sr())
     if wts is None or zs is None:
         do_inversion = True

     if do_inversion or randomize_seed:  # always re-run inversion
+        zs_tensor, wts_tensor, extra_info_list = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt,
+                                                        num_diffusion_steps=steps,
+                                                        cfg_scale_src=cfg_scale_src,
+                                                        duration=duration,
+                                                        save_compute=save_compute)
         wts = wts_tensor
         zs = zs_tensor
+        extra_info = extra_info_list
         saved_inv_model = model_id
         do_inversion = False
     else:
         wts_tensor = wts.to(device)
         zs_tensor = zs.to(device)
+        extra_info_list = [e.to(device) for e in extra_info if e is not None]

+    output = sample(ldm_stable, zs_tensor, wts_tensor, extra_info_list, prompt_tar=target_prompt,
+                    tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar, duration=duration,
+                    save_compute=save_compute)

+    return output, wts.cpu(), zs.cpu(), [e.cpu() for e in extra_info if e is not None], saved_inv_model, do_inversion
     # return output, wtszs_file, saved_inv_model, do_inversion


 def get_example():
     case = [
+        ['Examples/Beethoven.mp3',
         '',
         'A recording of an arcade game soundtrack.',
         45,
         'cvssp/audioldm2-music',
         '27s',
+        'Examples/Beethoven_arcade.mp3',
         ],
+        ['Examples/Beethoven.mp3',
         'A high quality recording of wind instruments and strings playing.',
         'A high quality recording of a piano playing.',
         45,
         'cvssp/audioldm2-music',
         '27s',
+        'Examples/Beethoven_piano.mp3',
+        ],
+        ['Examples/Beethoven.mp3',
+        '',
+        'Heavy Rock.',
+        40,
+        'stabilityai/stable-audio-open-1.0',
+        '27s',
+        'Examples/Beethoven_rock.mp3',
         ],
+        ['Examples/ModalJazz.mp3',
         'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
         'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
         45,
         'cvssp/audioldm2-music',
         '106s',
+        'Examples/ModalJazz_banjo.mp3',],
+        ['Examples/Shadows.mp3',
+        '',
+        '8-bit arcade game soundtrack.',
+        40,
+        'stabilityai/stable-audio-open-1.0',
+        '34s',
+        'Examples/Shadows_arcade.mp3',],
+        ['Examples/Cat.mp3',
         '',
         'A dog barking.',
         75,
         'cvssp/audioldm2-large',
         '10s',
+        'Examples/Cat_dog.mp3',]
     ]
     return case


 intro = """
+<h1 style="font-weight: 1000; text-align: center; margin: 0px;"> ZETA Editing 🎧 </h1>
+<h2 style="font-weight: 1000; text-align: center; margin: 0px;">
+Zero-Shot Text-Based Audio Editing Using DDPM Inversion 🎛️ </h2>
+<h3 style="margin-top: 0px; margin-bottom: 10px; text-align: center;">
 <a href="https://arxiv.org/abs/2402.10009">[Paper]</a> |
 <a href="https://hilamanor.github.io/AudioEditing/">[Project page]</a> |
 <a href="https://github.com/HilaManor/AudioEditingCode">[Code]</a>
 </h3>

+<p style="font-size: 1rem; line-height: 1.2em;">
 For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
 <a href="https://huggingface.co/spaces/hilamanor/audioEditing?duplicate=true">
+<img style="margin-top: 0em; margin-bottom: 0em; display:inline" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" >
+</a>
+</p>
+<p style="margin: 0px;">
+<b>NEW - 15.10.24:</b> You can now edit using <b>Stable Audio Open 1.0</b>.
+You must be <b>logged in</b> after accepting the
+<b><a href="https://huggingface.co/stabilityai/stable-audio-open-1.0">license agreement</a></b> to use it.</br>
+</p>
+<ul style="padding-left:40px; line-height:normal;">
+<li style="margin: 0px;">Prompts behave differently - e.g.,
+try "8-bit arcade" directly instead of "a recording of...". Check out the new examples below!</li>
+<li style="margin: 0px;">Try to play around <code>T-start=40%</code>.</li>
+<li style="margin: 0px;">Under "More Options": Use <code>Source Guidance Scale=1</code>,
+and you can try fewer timesteps (even 20!).</li>
+<li style="margin: 0px;">Stable Audio Open is a general-audio model.
+For better music editing, duplicate the space and change to a
+<a href="https://huggingface.co/models?other=base_model:finetune:stabilityai/stable-audio-open-1.0">
+fine-tuned model for music</a>.</li>
+</ul>
+<p>
+<b>NEW - 15.10.24:</b> Parallel editing is enabled by default.
+To disable, uncheck <code>Efficient editing</code> under "More Options".
+Saves a bit of time.
 </p>
 """

+
 help = """
 <div style="font-size:medium">
 <b>Instructions:</b><br>

 <li>You must provide an input audio and a target prompt to edit the audio. </li>
 <li>T<sub>start</sub> is used to control the tradeoff between fidelity to the original signal and text-adhearance.
 Lower value -> favor fidelity. Higher value -> apply a stronger edit.</li>
+<li>Make sure that you use a model version that is suitable for your input audio.
+For example, use AudioLDM2-music for music while AudioLDM2-large for general audio.
 </li>
 <li>You can additionally provide a source prompt to guide even further the editing process.</li>
 <li>Longer input will take more time.</li>
 <li><strong>Unlimited length</strong>: This space automatically trims input audio to a maximum length of 30 seconds.
+For unlimited length, duplicated the space, and change the
+<code style="display:inline; background-color: lightgrey;">MAX_DURATION</code> parameter
+inside <code style="display:inline; background-color: lightgrey;">utils.py</code>
+to <code style="display:inline; background-color: lightgrey;">None</code>.
+</li>
 </ul>
 </div>

 """

+css = '.gradio-container {max-width: 1000px !important; padding-top: 1.5rem !important;}' \
+      '.audio-upload .wrap {min-height: 0px;}'
+
+# with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css=css) as demo:
     def reset_do_inversion(do_inversion_user, do_inversion):
         # do_inversion = gr.State(value=True)
         do_inversion = True

         return do_inversion_user, do_inversion

     gr.HTML(intro)
+
     wts = gr.State()
     zs = gr.State()
+    extra_info = gr.State()
     saved_inv_model = gr.State()
     do_inversion = gr.State(value=True)  # To save some runtime when editing the same thing over and over
     do_inversion_user = gr.State(value=False)

     with gr.Group():
+        gr.Markdown("💡 **note**: input longer than **30 sec** is automatically trimmed "
+                    "(for unlimited input, see the Help section below)")
+        with gr.Row(equal_height=True):
+            input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath",
+                                   editable=True, label="Input Audio", interactive=True, scale=1, format='wav',
+                                   elem_classes=['audio-upload'])
+            output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1, format='wav')

     with gr.Row():
         tar_prompt = gr.Textbox(label="Prompt", info="Describe your desired edited output",

     with gr.Row():
         t_start = gr.Slider(minimum=15, maximum=85, value=45, step=1, label="T-start (%)", interactive=True, scale=3,
                             info="Lower T-start -> closer to original audio. Higher T-start -> stronger edit.")
+        model_id = gr.Dropdown(label="Model Version",
+                               choices=[LDM2,
+                                        LDM2_LARGE,
+                                        MUSIC,
+                                        STABLEAUD],
+                               info="Choose a checkpoint suitable for your audio and edit",
                                value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)
     with gr.Row():
+        submit = gr.Button("Edit", variant="primary", scale=3)
+        gr.LoginButton(value="Login to HF (For Stable Audio)", scale=1)

     with gr.Accordion("More Options", open=False):
         with gr.Row():
                 info="Optional: Describe the original audio input",
                 placeholder="A recording of a happy upbeat classical music piece",)

+        with gr.Row(equal_height=True):
             cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
                                       label="Source Guidance Scale", interactive=True, scale=1)
             cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
                                       label="Target Guidance Scale", interactive=True, scale=1)
+            steps = gr.Number(value=50, step=1, minimum=10, maximum=300,
                               info="Higher values (e.g. 200) yield higher-quality generation.",
+                              label="Num Diffusion Steps", interactive=True, scale=2)
+        with gr.Row(equal_height=True):
             seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
             randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
+            save_compute = gr.Checkbox(label='Efficient editing', value=True)
             length = gr.Number(label="Length", interactive=False, visible=False)

     with gr.Accordion("Help💡", open=False):
         gr.HTML(help)

     submit.click(
+        fn=verify_model_params,
+        inputs=[model_id, input_audio, src_prompt, tar_prompt, cfg_scale_src],
+        outputs=[]
+    ).success(
+        fn=randomize_seed_fn, inputs=[seed, randomize_seed], outputs=[seed], queue=False
+    ).then(
+        fn=clear_do_inversion_user, inputs=[do_inversion_user], outputs=[do_inversion_user]
+    ).then(
+        fn=edit,
+        inputs=[input_audio,
+                model_id,
+                do_inversion,
+                wts, zs, extra_info,
+                saved_inv_model,
+                src_prompt,
+                tar_prompt,
+                steps,
+                cfg_scale_src,
+                cfg_scale_tar,
+                t_start,
+                randomize_seed,
+                save_compute,
+                ],
+        outputs=[output_audio, wts, zs, extra_info, saved_inv_model, do_inversion]
+    ).success(
+        fn=post_match_do_inversion,
+        inputs=[do_inversion_user, do_inversion],
+        outputs=[do_inversion_user, do_inversion]
+    )

     # If sources changed we have to rerun inversion
+    gr.on(
+        triggers=[input_audio.change, src_prompt.change, model_id.change, cfg_scale_src.change,
+                  steps.change, save_compute.change],
+        fn=reset_do_inversion,
+        inputs=[do_inversion_user, do_inversion],
+        outputs=[do_inversion_user, do_inversion]
+    )

     gr.Examples(
         label="Examples",
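The new get_duration helper above sizes the ZeroGPU allocation from the number of UNet forward passes a request needs: inversion costs steps passes (doubled when a source prompt is given), editing adds 2 * int(t_start / 100 * steps) more, and the total is scaled by a per-model factor and by the clip length relative to the maximum. A quick re-derivation of that arithmetic for the default music checkpoint, assuming utils.MAX_DURATION is 30 to match the 30-second note in the UI:

# Mirrors the estimate in get_duration for factor=1, steps=50, t_start=45, no source prompt.
MAX_DURATION = 30                            # assumed value of utils.MAX_DURATION
factor, steps, t_start = 1, 50, 45
forwards = steps                             # single-pass inversion (empty source prompt)
forwards += int(t_start / 100 * steps) * 2   # editing passes, doubled by classifier-free guidance
time_for_maxlength = factor * forwards * 0.15
duration = 30                                # clip length after trimming, in seconds
print(max(10, time_for_maxlength / MAX_DURATION * duration + 5))   # -> about 19.1 seconds of GPU time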
inversion_utils.py
CHANGED
@@ -1,341 +1,135 @@
-from typing import List, Optional, Dict, Union
-def mu_tilde(model, xt, x0, timestep):
-    "mu_tilde(x_t, x_0) DDPM paper eq. 7"
-    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
-    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
-        else model.scheduler.final_alpha_cumprod
-    alpha_t = model.scheduler.alphas[timestep]
-    beta_t = 1 - alpha_t
-    alpha_bar = model.scheduler.alphas_cumprod[timestep]
-    return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + \
-        ((alpha_t**0.5 * (1-alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
-
-
-def sample_xts_from_x0(model, x0, num_inference_steps=50, x_prev_mode=False):
-    """
-    Samples from P(x_1:T|x_0)
-    """
-    # torch.manual_seed(43256465436)
-    alpha_bar = model.model.scheduler.alphas_cumprod
-    sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
-    alphas = model.model.scheduler.alphas
-    # betas = 1 - alphas
-    variance_noise_shape = (
-        num_inference_steps + 1,
-        model.model.unet.config.in_channels,
-        # model.unet.sample_size,
-        # model.unet.sample_size)
-        x0.shape[-2],
-        x0.shape[-1])
-
-    timesteps = model.model.scheduler.timesteps.to(model.device)
-    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
-    xts = torch.zeros(variance_noise_shape).to(x0.device)
-    xts[0] = x0
-    x_prev = x0
-    for t in reversed(timesteps):
-        # idx = t_to_idx[int(t)]
-        idx = num_inference_steps-t_to_idx[int(t)]
-        if x_prev_mode:
-            xts[idx] = x_prev * (alphas[t] ** 0.5) + torch.randn_like(x0) * ((1-alphas[t]) ** 0.5)
-            x_prev = xts[idx].clone()
-        else:
-            xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
-    # xts = torch.cat([xts, x0 ],dim = 0)
-
-    return xts
-
-
-def forward_step(model, model_output, timestep, sample):
-    next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
-                        timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
-
-    # 2. compute alphas, betas
-    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
-    # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 \
-    #     else self.scheduler.final_alpha_cumprod
-
-    beta_prod_t = 1 - alpha_prod_t
-
-    # 3. compute predicted original sample from predicted noise also called
-    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-
-    # 5. TODO: simple noising implementatiom
-    next_sample = model.scheduler.add_noise(pred_original_sample, model_output, torch.LongTensor([next_timestep]))
-    return next_sample
-                              prog_bar: bool = False,
-                              eps: Optional[float] = None,
-                              cutoff_points: Optional[List[float]] = None,
-        raise NotImplementedError("How do you split cfg_scales for hspace? TODO")
-        # text_embeddings = encode_text(model, prompt)
-        # # classifier free guidance
-    batch_size = len(prompts)
-    cfg_scales_tensor = torch.ones((batch_size, *x0.shape[1:]), device=model.device, dtype=x0.dtype)
-    # if len(prompts) > 1:
-    #     if cutoff_points is None:
-    #         cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
-    #     if len(cfg_scales) == 1:
-    #         cfg_scales *= batch_size
-    #     elif len(cfg_scales) < batch_size:
-    #         raise ValueError("Not enough target CFG scales")
-    #     cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
-    #     cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
-    # …
-    # else:
-    cfg_scales_tensor *= cfg_scales[0]
-    uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = model.encode_text([""])
-    # uncond_embedding = encode_text(model, "")
-    variance_noise_shape = (
-        num_inference_steps,
-        model.model.unet.config.in_channels,
-        # model.unet.sample_size,
-        # model.unet.sample_size)
-        x0.shape[-2],
-        x0.shape[-1])
-    if …
-    hspaces = []
-    skipconns = []
-    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
-    for t in op:
-        # idx = t_to_idx[int(t)]
-        idx = num_inference_steps - t_to_idx[int(t)] - 1
-            noise_pred = out…
-            (cfg_scales_tensor * (cond_out.sample - out.sample.expand(batch_size, -1, -1, -1))
-            ).sum(axis=0).unsqueeze(0)
-            if extract_h_space or extract_skipconns:
-                noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
-                if extract_skipconns:
-                    noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
-                                           (cond_out_skipconns[k][j] - out_skipconns[k][j])
-                                           for j in range(len(out_skipconns[k]))]
-                                       for k in out_skipconns}
-        else:
-            noise_pred = out.sample
-            if extract_h_space or extract_skipconns:
-                noise_h_space = out_hspace
-                if extract_skipconns:
-                    noise_skipconns = out_skipconns
-        if extract_h_space or extract_skipconns:
-            hspaces.append(noise_h_space)
-            if extract_skipconns:
-                skipconns.append(noise_skipconns)
-
-        if eta_is_zero:
-            # 2. compute more noisy image and set x_t -> x_t+1
-            xt = forward_step(model.model, noise_pred, t, xt)
-            alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
-            variance = model.get_variance(t, prev_timestep)
-            if model.model.scheduler.config.prediction_type == 'epsilon':
-                radom_noise_pred = noise_pred
-            elif model.model.scheduler.config.prediction_type == 'v_prediction':
-                radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
-            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * radom_noise_pred
-            mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
-            zs[idx] = z
-            # correction to avoid error accumulation
-            if numerical_fix:
-                xtm1 = mu_xt + (etas[idx] * variance ** 0.5)*z
-            xts[idx] = xtm1
-    return xt, zs, xts, hspaces
-    if extract_skipconns:
-        hspaces = torch.concat(hspaces, axis=0)
-        return xt, zs, xts, hspaces, skipconns
-    return xt, zs, xts
-
-
-def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
-    # 1. get previous step value (=t-1)
-    prev_timestep = timestep - model.model.scheduler.config.num_train_timesteps // \
-        model.model.scheduler.num_inference_steps
-    # 2. compute alphas, betas
-    alpha_prod_t = model.model.scheduler.alphas_cumprod[timestep]
-    alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
-    beta_prod_t = 1 - alpha_prod_t
-    # 3. compute predicted original sample from predicted noise also called
-    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    if model.model.scheduler.config.prediction_type == 'epsilon':
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-    elif model.model.scheduler.config.prediction_type == 'v_prediction':
-        pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
-
-    # 5. compute variance: "sigma_t(η)" -> see formula (16)
-    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-    # variance = self.scheduler._get_variance(timestep, prev_timestep)
-    variance = model.get_variance(timestep, prev_timestep)
-    # std_dev_t = eta * variance ** (0.5)
-    # Take care of asymetric reverse process (asyrp)
-    if model.model.scheduler.config.prediction_type == 'epsilon':
-        model_output_direction = model_output
-    elif model.model.scheduler.config.prediction_type == 'v_prediction':
-        model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
-    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
-    pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
-    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-    prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
-    # 8. Add noice if eta > 0
-    if eta > 0:
-        if variance_noise is None:
-            variance_noise = torch.randn(model_output.shape, device=model.device)
-        sigma_z = eta * variance ** (0.5) * variance_noise
-        prev_sample = prev_sample + sigma_z
-
-    return prev_sample
-                              fix_alpha: float = 0.1,
-                              prog_bar: bool = False,
-                              zero_out_resconns: Optional[Union[int, List]] = None,
-                              asyrp: bool = False,
-                              extract_h_space: bool = False,
-                              extract_skipconns: bool = False):
-    batch_size = len(prompts)
-    uncond_boolean_prompt_mask = model.encode_text(neg_prompts…
-    cfg_scales_tensor = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype)
-    # if batch_size > 1:
-    #     if cutoff_points is None:
-    #         cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
-    #     if len(cfg_scales) == 1:
-    #         cfg_scales *= batch_size
-    #     elif len(cfg_scales) < batch_size:
-    #         raise ValueError("Not enough target CFG scales")
-    #     cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
-    #     cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
-    #     cfg_scales_tensor[i, :, end:] = 0
-    #     cfg_scales_tensor[i, :, :start] = 0
-    #     masks[i, :, end:] = 0
-    #     masks[i, :, :start] = 0
-    #     cfg_scales_tensor[i] *= cfg_scales[i]
-    #     cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1)
-    #     masks = T.functional.gaussian_blur(masks, kernel_size=15, sigma=1)
-    # else:
-    cfg_scales_tensor *= cfg_scales[0]
@@ -344,107 +138,71 @@ def inversion_reverse_process(model: PipelineWrapper,
-    t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
-    hspaces = []
-    skipconns = []
-    for it, t in enumerate(op):
-        # idx = t_to_idx[int(t)]
-        idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t)] - \
-            (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
         # # Unconditional embedding
         with torch.no_grad():
                 timestep=t,
                 encoder_hidden_states=text_embeddings_hidden_states,
                 class_labels=text_embeddings_class_labels,
                 encoder_attention_mask=text_embeddings_boolean_prompt_mask,
-                                 (cfg_scales[0] / (cfg_scales[0] + 1)) *
-                                 (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1
-                                  else hspace_add)),
-                replace_h_space=(None if hspace_replace is None else
-                                 (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1
-                                  else hspace_replace)),
-                zero_out_resconns=zero_out_resconns,
-                replace_skip_conns=(None if skipconns_replace is None else
-                                    (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1
-                                     else skipconns_replace))
-                )  # encoder_hidden_states = text_embeddings)
         z = zs[idx] if zs is not None else None
-        # print(f'idx: {idx}')
-        # print(f't: {t}')
         z = z.unsqueeze(0)
-        # …
-        # # classifier free guidance
-        # noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
-        noise_pred = uncond_out.sample + \
-            (cfg_scales_tensor * (cond_out.sample - uncond_out.sample.expand(batch_size, -1, -1, -1))
-             ).sum(axis=0).unsqueeze(0)
-        if extract_h_space or extract_skipconns:
-            noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
-            if extract_skipconns:
-                noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
-                                       (cond_out_skipconns[k][j] - out_skipconns[k][j])
-                                       for j in range(len(out_skipconns[k]))]
-                                   for k in out_skipconns}
-        else:
-            noise_pred = uncond_out.sample
-            if extract_h_space or extract_skipconns:
-                noise_h_space = out_hspace
-                if extract_skipconns:
-                    noise_skipconns = out_skipconns
-        if extract_h_space or extract_skipconns:
-            hspaces.append(noise_h_space)
-            if extract_skipconns:
-                skipconns.append(noise_skipconns)
         # 2. compute less noisy image and set x_t -> x_t-1
-        xt = …
-        # xt = controller.step_callback(xt)
-        # "fix" xt
-        apply_fix = ((skips.max() - skips) > it)
-        if apply_fix.any():
-            apply_fix = (apply_fix * fix_alpha).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(xT.device)
-            xt = (masks * (xt.expand(batch_size, -1, -1, -1) * (1 - apply_fix) +
-                           apply_fix * xT[skips.max() - it - 1].expand(batch_size, -1, -1, -1))
-                  ).sum(axis=0).unsqueeze(0)
-    if extract_h_space:
-        return xt, zs, torch.concat(hspaces, axis=0)
-    if extract_skipconns:
-        return xt, zs, torch.concat(hspaces, axis=0), skipconns
     return xt, zs
1 |
import torch
|
2 |
from tqdm import tqdm
|
3 |
+
from typing import List, Optional, Tuple
|
|
|
4 |
from models import PipelineWrapper
|
5 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
def inversion_forward_process(model: PipelineWrapper,
|
9 |
x0: torch.Tensor,
|
10 |
etas: Optional[float] = None,
|
|
|
11 |
prompts: List[str] = [""],
|
12 |
cfg_scales: List[float] = [3.5],
|
13 |
num_inference_steps: int = 50,
|
|
|
|
|
14 |
numerical_fix: bool = False,
|
15 |
+
duration: Optional[float] = None,
|
16 |
+
first_order: bool = False,
|
17 |
+
save_compute: bool = True,
|
18 |
+
progress=gr.Progress()) -> Tuple:
|
|
|
|
|
19 |
if len(prompts) > 1 or prompts[0] != "":
|
20 |
text_embeddings_hidden_states, text_embeddings_class_labels, \
|
21 |
text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# In the forward negative prompts are not supported currently (TODO)
|
24 |
+
uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
|
25 |
+
[""], negative=True, save_compute=save_compute, cond_length=text_embeddings_class_labels.shape[1]
|
26 |
+
if text_embeddings_class_labels is not None else None)
|
27 |
+
else:
|
28 |
+
uncond_embeddings_hidden_states, uncond_embeddings_class_lables, uncond_boolean_prompt_mask = model.encode_text(
|
29 |
+
[""], negative=True, save_compute=False)
|
|
|
|
|
30 |
|
|
|
|
|
31 |
timesteps = model.model.scheduler.timesteps.to(model.device)
|
32 |
+
variance_noise_shape = model.get_noise_shape(x0, num_inference_steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
+
if type(etas) in [int, float]:
|
35 |
+
etas = [etas]*model.model.scheduler.num_inference_steps
|
36 |
+
xts = model.sample_xts_from_x0(x0, num_inference_steps=num_inference_steps)
|
37 |
+
zs = torch.zeros(size=variance_noise_shape, device=model.device)
|
38 |
+
extra_info = [None] * len(zs)
|
39 |
+
|
40 |
+
if timesteps[0].dtype == torch.int64:
|
41 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
42 |
+
elif timesteps[0].dtype == torch.float32:
|
43 |
+
t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
|
|
|
|
|
|
|
44 |
xt = x0
|
45 |
+
op = tqdm(timesteps, desc="Inverting")
|
46 |
+
model.setup_extra_inputs(xt, init_timestep=timesteps[0], audio_end_in_s=duration,
|
47 |
+
save_compute=save_compute and prompts[0] != "")
|
48 |
+
app_op = progress.tqdm(timesteps, desc="Inverting")
|
49 |
+
for t, _ in zip(op, app_op):
|
50 |
+
idx = num_inference_steps - t_to_idx[int(t) if timesteps[0].dtype == torch.int64 else float(t)] - 1
|
51 |
|
|
|
|
|
|
|
52 |
# 1. predict noise residual
|
53 |
+
xt = xts[idx+1][None]
|
54 |
+
xt_inp = model.model.scheduler.scale_model_input(xt, t)
|
55 |
|
56 |
with torch.no_grad():
|
57 |
+
if save_compute and prompts[0] != "":
|
58 |
+
comb_out, _, _ = model.unet_forward(
|
59 |
+
xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
|
60 |
+
timestep=t,
|
61 |
+
encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
|
62 |
+
], dim=0)
|
63 |
+
if uncond_embeddings_hidden_states is not None else None,
|
64 |
+
class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
|
65 |
+
if uncond_embeddings_class_lables is not None else None,
|
66 |
+
encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
|
67 |
+
], dim=0)
|
68 |
+
if uncond_boolean_prompt_mask is not None else None,
|
69 |
+
)
|
70 |
+
out, cond_out = comb_out.sample.chunk(2, dim=0)
|
71 |
+
else:
|
72 |
+
out = model.unet_forward(xt_inp, timestep=t,
|
73 |
+
encoder_hidden_states=uncond_embeddings_hidden_states,
|
74 |
+
class_labels=uncond_embeddings_class_lables,
|
75 |
+
encoder_attention_mask=uncond_boolean_prompt_mask)[0].sample
|
76 |
+
if len(prompts) > 1 or prompts[0] != "":
|
77 |
+
cond_out = model.unet_forward(
|
78 |
+
xt_inp,
|
79 |
+
timestep=t,
|
80 |
+
encoder_hidden_states=text_embeddings_hidden_states,
|
81 |
+
class_labels=text_embeddings_class_labels,
|
82 |
+
encoder_attention_mask=text_embeddings_boolean_prompt_mask)[0].sample
|
83 |
|
84 |
if len(prompts) > 1 or prompts[0] != "":
|
85 |
# # classifier free guidance
|
86 |
+
noise_pred = out + (cfg_scales[0] * (cond_out - out)).sum(axis=0).unsqueeze(0)
|
87 |
else:
|
88 |
+
noise_pred = out
|
89 |
+
|
90 |
+
# xtm1 = xts[idx+1][None]
|
91 |
+
xtm1 = xts[idx][None]
|
92 |
+
z, xtm1, extra = model.get_zs_from_xts(xt, xtm1, noise_pred, t,
|
93 |
+
eta=etas[idx], numerical_fix=numerical_fix,
|
94 |
+
first_order=first_order)
|
95 |
+
zs[idx] = z
|
96 |
+
# print(f"Fix Xt-1 distance - NORM:{torch.norm(xts[idx] - xtm1):.4g}, MSE:{((xts[idx] - xtm1)**2).mean():.4g}")
|
97 |
+
xts[idx] = xtm1
|
98 |
+
extra_info[idx] = extra
|
99 |
|
100 |
if zs is not None:
|
101 |
# zs[-1] = torch.zeros_like(zs[-1])
|
102 |
zs[0] = torch.zeros_like(zs[0])
|
103 |
# zs_cycle[0] = torch.zeros_like(zs[0])
|
104 |
|
105 |
+
del app_op.iterables[0]
|
106 |
+
return xt, zs, xts, extra_info
|
107 |
|
108 |
|
109 |
def inversion_reverse_process(model: PipelineWrapper,
|
110 |
xT: torch.Tensor,
|
111 |
+
tstart: torch.Tensor,
|
|
|
112 |
etas: float = 0,
|
113 |
prompts: List[str] = [""],
|
114 |
neg_prompts: List[str] = [""],
|
115 |
cfg_scales: Optional[List[float]] = None,
|
|
|
116 |
zs: Optional[List[torch.Tensor]] = None,
|
117 |
+
duration: Optional[float] = None,
|
118 |
+
first_order: bool = False,
|
119 |
+
extra_info: Optional[List] = None,
|
120 |
+
save_compute: bool = True,
|
121 |
+
progress=gr.Progress()) -> Tuple[torch.Tensor, torch.Tensor]:
|
122 |
|
123 |
text_embeddings_hidden_states, text_embeddings_class_labels, \
|
124 |
text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
|
125 |
+
uncond_embeddings_hidden_states, uncond_embeddings_class_lables, \
|
126 |
+
uncond_boolean_prompt_mask = model.encode_text(neg_prompts,
|
127 |
+
negative=True,
|
128 |
+
save_compute=save_compute,
|
129 |
+
cond_length=text_embeddings_class_labels.shape[1]
|
130 |
+
if text_embeddings_class_labels is not None else None)
|
131 |
|
132 |
+
xt = xT[tstart.max()].unsqueeze(0)
|
133 |
|
134 |
if etas is None:
|
135 |
etas = 0
|
|
|
138 |
assert len(etas) == model.model.scheduler.num_inference_steps
|
139 |
timesteps = model.model.scheduler.timesteps.to(model.device)
|
140 |
|
141 |
+
op = tqdm(timesteps[-zs.shape[0]:], desc="Editing")
|
142 |
+
if timesteps[0].dtype == torch.int64:
|
143 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
144 |
+
elif timesteps[0].dtype == torch.float32:
|
145 |
+
t_to_idx = {float(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
|
146 |
+
model.setup_extra_inputs(xt, extra_info=extra_info, init_timestep=timesteps[-zs.shape[0]],
|
147 |
+
audio_end_in_s=duration, save_compute=save_compute)
|
148 |
+
app_op = progress.tqdm(timesteps[-zs.shape[0]:], desc="Editing")
|
149 |
+
for it, (t, _) in enumerate(zip(op, app_op)):
|
150 |
+
idx = model.model.scheduler.num_inference_steps - t_to_idx[
|
151 |
+
int(t) if timesteps[0].dtype == torch.int64 else float(t)] - \
|
152 |
+
(model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
|
153 |
+
|
154 |
+
xt_inp = model.model.scheduler.scale_model_input(xt, t)
|
155 |
|
156 |
# # Unconditional embedding
|
157 |
with torch.no_grad():
|
158 |
+
# print(f'xt_inp.shape: {xt_inp.shape}')
|
159 |
+
# print(f't.shape: {t.shape}')
|
160 |
+
# print(f'uncond_embeddings_hidden_states.shape: {uncond_embeddings_hidden_states.shape}')
|
161 |
+
# print(f'uncond_embeddings_class_lables.shape: {uncond_embeddings_class_lables.shape}')
|
162 |
+
# print(f'uncond_boolean_prompt_mask.shape: {uncond_boolean_prompt_mask.shape}')
|
163 |
+
# print(f'text_embeddings_hidden_states.shape: {text_embeddings_hidden_states.shape}')
|
164 |
+
# print(f'text_embeddings_class_labels.shape: {text_embeddings_class_labels.shape}')
|
165 |
+
# print(f'text_embeddings_boolean_prompt_mask.shape: {text_embeddings_boolean_prompt_mask.shape}')
|
166 |
+
|
167 |
+
if save_compute:
|
168 |
+
comb_out, _, _ = model.unet_forward(
|
169 |
+
xt_inp.expand(2, -1, -1, -1) if hasattr(model.model, 'unet') else xt_inp.expand(2, -1, -1),
|
170 |
+
timestep=t,
|
171 |
+
encoder_hidden_states=torch.cat([uncond_embeddings_hidden_states, text_embeddings_hidden_states
|
172 |
+
], dim=0)
|
173 |
+
if uncond_embeddings_hidden_states is not None else None,
|
174 |
+
class_labels=torch.cat([uncond_embeddings_class_lables, text_embeddings_class_labels], dim=0)
|
175 |
+
if uncond_embeddings_class_lables is not None else None,
|
176 |
+
encoder_attention_mask=torch.cat([uncond_boolean_prompt_mask, text_embeddings_boolean_prompt_mask
|
177 |
+
], dim=0)
|
178 |
+
if uncond_boolean_prompt_mask is not None else None,
|
179 |
+
)
|
180 |
+
uncond_out, cond_out = comb_out.sample.chunk(2, dim=0)
|
181 |
+
else:
|
182 |
+
uncond_out = model.unet_forward(
|
183 |
+
xt_inp, timestep=t,
|
184 |
+
encoder_hidden_states=uncond_embeddings_hidden_states,
|
185 |
+
class_labels=uncond_embeddings_class_lables,
|
186 |
+
encoder_attention_mask=uncond_boolean_prompt_mask,
|
187 |
+
)[0].sample
|
188 |
+
|
189 |
+
# Conditional embedding
|
190 |
+
cond_out = model.unet_forward(
|
191 |
+
xt_inp,
|
192 |
timestep=t,
|
193 |
encoder_hidden_states=text_embeddings_hidden_states,
|
194 |
class_labels=text_embeddings_class_labels,
|
195 |
encoder_attention_mask=text_embeddings_boolean_prompt_mask,
|
196 |
+
)[0].sample
|
197 |
|
198 |
z = zs[idx] if zs is not None else None
|
|
|
|
|
199 |
z = z.unsqueeze(0)
|
200 |
+
# classifier free guidance
|
201 |
+
noise_pred = uncond_out + (cfg_scales[0] * (cond_out - uncond_out)).sum(axis=0).unsqueeze(0)
|
202 |
|
203 |
# 2. compute less noisy image and set x_t -> x_t-1
|
204 |
+
xt = model.reverse_step_with_custom_noise(noise_pred, t, xt, variance_noise=z,
|
205 |
+
eta=etas[idx], first_order=first_order)
|
206 |
|
207 |
+
del app_op.iterables[0]
|
208 |
return xt, zs
|
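For orientation, a minimal usage sketch of how the inversion and editing routines in this file might be chained. Only the signatures of inversion_reverse_process() and load_model() appear in this diff; the forward-inversion function name, module paths, checkpoint id, latent shape, prompts, durations and step counts below are assumptions for illustration, not part of the change.

```python
# Hypothetical sketch -- names marked "assumed" are not confirmed by this diff.
import torch
from models import load_model
from inversion_utils import inversion_forward_process, inversion_reverse_process  # forward name assumed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("cvssp/audioldm2-music", device)   # assumed checkpoint id
model.model.scheduler.set_timesteps(50)

# Stand-in for the VAE latent of the source clip (shape illustrative only).
x0 = torch.randn(1, 8, 500, 16, device=device)

# Invert: recover per-step noise maps zs and the intermediate latents xts.
xt, zs, xts, extra_info = inversion_forward_process(          # argument names assumed
    model, x0, etas=1.0, prompts=["source description"], cfg_scales=[3.0],
    num_inference_steps=50, duration=10.0)

# Edit: rerun the reverse process from an intermediate step with a new prompt, reusing zs.
tstart = torch.tensor([40])
edited_latent, _ = inversion_reverse_process(
    model, xT=xts, tstart=tstart, etas=1.0,
    prompts=["the same melody, played on a banjo"], cfg_scales=[3.0],
    zs=zs[:int(tstart.max())], duration=10.0, extra_info=extra_info)
```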
models.py
CHANGED
@@ -1,46 +1,160 @@
|
|
1 |
import torch
|
2 |
-
from diffusers import DDIMScheduler
|
3 |
-
from diffusers import
|
4 |
-
from
|
|
|
5 |
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
|
|
6 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
7 |
|
8 |
|
9 |
class PipelineWrapper(torch.nn.Module):
|
10 |
-
def __init__(self, model_id
|
|
|
|
|
|
|
11 |
super().__init__(*args, **kwargs)
|
12 |
self.model_id = model_id
|
13 |
self.device = device
|
14 |
self.double_precision = double_precision
|
|
|
15 |
|
16 |
-
def get_sigma(self, timestep) -> float:
|
17 |
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
|
18 |
return sqrt_recipm1_alphas_cumprod[timestep]
|
19 |
|
20 |
-
def load_scheduler(self):
|
21 |
pass
|
22 |
|
23 |
-
def get_fn_STFT(self):
|
24 |
pass
|
25 |
|
26 |
-
def
|
27 |
pass
|
28 |
|
29 |
-
def
|
30 |
pass
|
31 |
|
32 |
-
def
|
33 |
pass
|
34 |
|
35 |
-
def encode_text(self, prompts: List[str]
|
|
|
36 |
pass
|
37 |
|
38 |
-
def get_variance(self, timestep, prev_timestep):
|
39 |
pass
|
40 |
|
41 |
-
def get_alpha_prod_t_prev(self, prev_timestep):
|
42 |
pass
|
43 |
|
44 |
def unet_forward(self,
|
45 |
sample: torch.FloatTensor,
|
46 |
timestep: Union[torch.Tensor, float, int],
|
@@ -57,244 +171,27 @@ class PipelineWrapper(torch.nn.Module):
|
|
57 |
replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
|
58 |
return_dict: bool = True,
|
59 |
zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
|
60 |
-
|
61 |
-
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
62 |
-
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
63 |
-
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
64 |
-
# on the fly if necessary.
|
65 |
-
default_overall_up_factor = 2**self.model.unet.num_upsamplers
|
66 |
-
|
67 |
-
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
68 |
-
forward_upsample_size = False
|
69 |
-
upsample_size = None
|
70 |
-
|
71 |
-
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
72 |
-
# logger.info("Forward upsample size to force interpolation output size.")
|
73 |
-
forward_upsample_size = True
|
74 |
-
|
75 |
-
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
76 |
-
# expects mask of shape:
|
77 |
-
# [batch, key_tokens]
|
78 |
-
# adds singleton query_tokens dimension:
|
79 |
-
# [batch, 1, key_tokens]
|
80 |
-
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
81 |
-
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
82 |
-
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
83 |
-
if attention_mask is not None:
|
84 |
-
# assume that mask is expressed as:
|
85 |
-
# (1 = keep, 0 = discard)
|
86 |
-
# convert mask into a bias that can be added to attention scores:
|
87 |
-
# (keep = +0, discard = -10000.0)
|
88 |
-
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
89 |
-
attention_mask = attention_mask.unsqueeze(1)
|
90 |
-
|
91 |
-
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
92 |
-
if encoder_attention_mask is not None:
|
93 |
-
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
94 |
-
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
95 |
-
|
96 |
-
# 0. center input if necessary
|
97 |
-
if self.model.unet.config.center_input_sample:
|
98 |
-
sample = 2 * sample - 1.0
|
99 |
-
|
100 |
-
# 1. time
|
101 |
-
timesteps = timestep
|
102 |
-
if not torch.is_tensor(timesteps):
|
103 |
-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
104 |
-
# This would be a good case for the `match` statement (Python 3.10+)
|
105 |
-
is_mps = sample.device.type == "mps"
|
106 |
-
if isinstance(timestep, float):
|
107 |
-
dtype = torch.float32 if is_mps else torch.float64
|
108 |
-
else:
|
109 |
-
dtype = torch.int32 if is_mps else torch.int64
|
110 |
-
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
111 |
-
elif len(timesteps.shape) == 0:
|
112 |
-
timesteps = timesteps[None].to(sample.device)
|
113 |
-
|
114 |
-
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
115 |
-
timesteps = timesteps.expand(sample.shape[0])
|
116 |
-
|
117 |
-
t_emb = self.model.unet.time_proj(timesteps)
|
118 |
-
|
119 |
-
# `Timesteps` does not contain any weights and will always return f32 tensors
|
120 |
-
# but time_embedding might actually be running in fp16. so we need to cast here.
|
121 |
-
# there might be better ways to encapsulate this.
|
122 |
-
t_emb = t_emb.to(dtype=sample.dtype)
|
123 |
-
|
124 |
-
emb = self.model.unet.time_embedding(t_emb, timestep_cond)
|
125 |
-
|
126 |
-
if self.model.unet.class_embedding is not None:
|
127 |
-
if class_labels is None:
|
128 |
-
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
129 |
-
|
130 |
-
if self.model.unet.config.class_embed_type == "timestep":
|
131 |
-
class_labels = self.model.unet.time_proj(class_labels)
|
132 |
-
|
133 |
-
# `Timesteps` does not contain any weights and will always return f32 tensors
|
134 |
-
# there might be better ways to encapsulate this.
|
135 |
-
class_labels = class_labels.to(dtype=sample.dtype)
|
136 |
-
|
137 |
-
class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
|
138 |
-
|
139 |
-
if self.model.unet.config.class_embeddings_concat:
|
140 |
-
emb = torch.cat([emb, class_emb], dim=-1)
|
141 |
-
else:
|
142 |
-
emb = emb + class_emb
|
143 |
-
|
144 |
-
if self.model.unet.config.addition_embed_type == "text":
|
145 |
-
aug_emb = self.model.unet.add_embedding(encoder_hidden_states)
|
146 |
-
emb = emb + aug_emb
|
147 |
-
elif self.model.unet.config.addition_embed_type == "text_image":
|
148 |
-
# Kadinsky 2.1 - style
|
149 |
-
if "image_embeds" not in added_cond_kwargs:
|
150 |
-
raise ValueError(
|
151 |
-
f"{self.model.unet.__class__} has the config param `addition_embed_type` set to 'text_image' "
|
152 |
-
f"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
153 |
-
)
|
154 |
-
|
155 |
-
image_embs = added_cond_kwargs.get("image_embeds")
|
156 |
-
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
157 |
-
|
158 |
-
aug_emb = self.model.unet.add_embedding(text_embs, image_embs)
|
159 |
-
emb = emb + aug_emb
|
160 |
-
|
161 |
-
if self.model.unet.time_embed_act is not None:
|
162 |
-
emb = self.model.unet.time_embed_act(emb)
|
163 |
-
|
164 |
-
if self.model.unet.encoder_hid_proj is not None and self.model.unet.config.encoder_hid_dim_type == "text_proj":
|
165 |
-
encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states)
|
166 |
-
elif self.model.unet.encoder_hid_proj is not None and \
|
167 |
-
self.model.unet.config.encoder_hid_dim_type == "text_image_proj":
|
168 |
-
# Kadinsky 2.1 - style
|
169 |
-
if "image_embeds" not in added_cond_kwargs:
|
170 |
-
raise ValueError(
|
171 |
-
f"{self.model.unet.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
|
172 |
-
f"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
173 |
-
)
|
174 |
-
|
175 |
-
image_embeds = added_cond_kwargs.get("image_embeds")
|
176 |
-
encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
177 |
-
|
178 |
-
# 2. pre-process
|
179 |
-
sample = self.model.unet.conv_in(sample)
|
180 |
-
|
181 |
-
# 3. down
|
182 |
-
down_block_res_samples = (sample,)
|
183 |
-
for downsample_block in self.model.unet.down_blocks:
|
184 |
-
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
185 |
-
sample, res_samples = downsample_block(
|
186 |
-
hidden_states=sample,
|
187 |
-
temb=emb,
|
188 |
-
encoder_hidden_states=encoder_hidden_states,
|
189 |
-
attention_mask=attention_mask,
|
190 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
191 |
-
encoder_attention_mask=encoder_attention_mask,
|
192 |
-
)
|
193 |
-
else:
|
194 |
-
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
195 |
-
|
196 |
-
down_block_res_samples += res_samples
|
197 |
-
|
198 |
-
if down_block_additional_residuals is not None:
|
199 |
-
new_down_block_res_samples = ()
|
200 |
-
|
201 |
-
for down_block_res_sample, down_block_additional_residual in zip(
|
202 |
-
down_block_res_samples, down_block_additional_residuals
|
203 |
-
):
|
204 |
-
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
205 |
-
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
206 |
-
|
207 |
-
down_block_res_samples = new_down_block_res_samples
|
208 |
-
|
209 |
-
# 4. mid
|
210 |
-
if self.model.unet.mid_block is not None:
|
211 |
-
sample = self.model.unet.mid_block(
|
212 |
-
sample,
|
213 |
-
emb,
|
214 |
-
encoder_hidden_states=encoder_hidden_states,
|
215 |
-
attention_mask=attention_mask,
|
216 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
217 |
-
encoder_attention_mask=encoder_attention_mask,
|
218 |
-
)
|
219 |
-
|
220 |
-
# print(sample.shape)
|
221 |
-
|
222 |
-
if replace_h_space is None:
|
223 |
-
h_space = sample.clone()
|
224 |
-
else:
|
225 |
-
h_space = replace_h_space
|
226 |
-
sample = replace_h_space.clone()
|
227 |
-
|
228 |
-
if mid_block_additional_residual is not None:
|
229 |
-
sample = sample + mid_block_additional_residual
|
230 |
-
|
231 |
-
extracted_res_conns = {}
|
232 |
-
# 5. up
|
233 |
-
for i, upsample_block in enumerate(self.model.unet.up_blocks):
|
234 |
-
is_final_block = i == len(self.model.unet.up_blocks) - 1
|
235 |
-
|
236 |
-
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
237 |
-
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
238 |
-
if replace_skip_conns is not None and replace_skip_conns.get(i):
|
239 |
-
res_samples = replace_skip_conns.get(i)
|
240 |
-
|
241 |
-
if zero_out_resconns is not None:
|
242 |
-
if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \
|
243 |
-
type(zero_out_resconns) is list and i in zero_out_resconns:
|
244 |
-
res_samples = [torch.zeros_like(x) for x in res_samples]
|
245 |
-
# down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples]
|
246 |
-
|
247 |
-
extracted_res_conns[i] = res_samples
|
248 |
-
|
249 |
-
# if we have not reached the final block and need to forward the
|
250 |
-
# upsample size, we do it here
|
251 |
-
if not is_final_block and forward_upsample_size:
|
252 |
-
upsample_size = down_block_res_samples[-1].shape[2:]
|
253 |
-
|
254 |
-
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
255 |
-
sample = upsample_block(
|
256 |
-
hidden_states=sample,
|
257 |
-
temb=emb,
|
258 |
-
res_hidden_states_tuple=res_samples,
|
259 |
-
encoder_hidden_states=encoder_hidden_states,
|
260 |
-
cross_attention_kwargs=cross_attention_kwargs,
|
261 |
-
upsample_size=upsample_size,
|
262 |
-
attention_mask=attention_mask,
|
263 |
-
encoder_attention_mask=encoder_attention_mask,
|
264 |
-
)
|
265 |
-
else:
|
266 |
-
sample = upsample_block(
|
267 |
-
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
268 |
-
)
|
269 |
-
|
270 |
-
# 6. post-process
|
271 |
-
if self.model.unet.conv_norm_out:
|
272 |
-
sample = self.model.unet.conv_norm_out(sample)
|
273 |
-
sample = self.model.unet.conv_act(sample)
|
274 |
-
sample = self.model.unet.conv_out(sample)
|
275 |
-
|
276 |
-
if not return_dict:
|
277 |
-
return (sample,)
|
278 |
-
|
279 |
-
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
|
280 |
|
281 |
|
282 |
class AudioLDM2Wrapper(PipelineWrapper):
|
283 |
def __init__(self, *args, **kwargs) -> None:
|
284 |
super().__init__(*args, **kwargs)
|
285 |
if self.double_precision:
|
286 |
-
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64
|
|
|
287 |
else:
|
288 |
try:
|
289 |
-
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True
|
|
|
290 |
except FileNotFoundError:
|
291 |
-
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False
|
|
|
292 |
|
293 |
-
def load_scheduler(self):
|
294 |
-
# self.model.scheduler = DDIMScheduler.from_config(self.model_id, subfolder="scheduler")
|
295 |
self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
|
296 |
|
297 |
-
def get_fn_STFT(self):
|
298 |
from audioldm.audio import TacotronSTFT
|
299 |
return TacotronSTFT(
|
300 |
filter_length=1024,
|
@@ -306,17 +203,17 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
306 |
mel_fmax=8000,
|
307 |
)
|
308 |
|
309 |
-
def vae_encode(self, x):
|
310 |
# self.model.vae.disable_tiling()
|
311 |
if x.shape[2] % 4:
|
312 |
x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
|
313 |
return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
314 |
# return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
315 |
|
316 |
-
def vae_decode(self, x):
|
317 |
return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
|
318 |
|
319 |
-
def decode_to_mel(self, x):
|
320 |
if self.double_precision:
|
321 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
|
322 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
|
@@ -324,7 +221,9 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
324 |
tmp = tmp.unsqueeze(0)
|
325 |
return tmp
|
326 |
|
327 |
-
def encode_text(self, prompts: List[str]
|
|
|
|
|
328 |
tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
|
329 |
text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
|
330 |
prompt_embeds_list = []
|
@@ -333,8 +232,11 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
333 |
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
334 |
text_inputs = tokenizer(
|
335 |
prompts,
|
336 |
-
padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
337 |
-
|
|
|
|
|
|
|
338 |
truncation=True,
|
339 |
return_tensors="pt",
|
340 |
)
|
@@ -404,7 +306,7 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
404 |
|
405 |
return generated_prompt_embeds, prompt_embeds, attention_mask
|
406 |
|
407 |
-
def get_variance(self, timestep, prev_timestep):
|
408 |
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
|
409 |
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
410 |
beta_prod_t = 1 - alpha_prod_t
|
@@ -412,7 +314,7 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
412 |
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
413 |
return variance
|
414 |
|
415 |
-
def get_alpha_prod_t_prev(self, prev_timestep):
|
416 |
return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
|
417 |
else self.model.scheduler.final_alpha_cumprod
|
418 |
|
@@ -485,8 +387,6 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
485 |
# 1. time
|
486 |
timesteps = timestep
|
487 |
if not torch.is_tensor(timesteps):
|
488 |
-
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
489 |
-
# This would be a good case for the `match` statement (Python 3.10+)
|
490 |
is_mps = sample.device.type == "mps"
|
491 |
if isinstance(timestep, float):
|
492 |
dtype = torch.float32 if is_mps else torch.float64
|
@@ -628,12 +528,328 @@ class AudioLDM2Wrapper(PipelineWrapper):
|
|
628 |
|
629 |
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
|
630 |
|
631 |
-
|
632 |
-
|
633 |
|
634 |
|
635 |
-
def load_model(model_id, device,
|
636 |
-
|
|
|
|
|
|
|
|
|
637 |
ldm_stable.load_scheduler()
|
638 |
torch.cuda.empty_cache()
|
639 |
return ldm_stable
|
|
|
1 |
import torch
|
2 |
+
from diffusers import DDIMScheduler, CosineDPMSolverMultistepScheduler
|
3 |
+
from diffusers.schedulers.scheduling_dpmsolver_sde import BrownianTreeNoiseSampler
|
4 |
+
from diffusers import AudioLDM2Pipeline, StableAudioPipeline
|
5 |
+
from transformers import RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer
|
6 |
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
7 |
+
from diffusers.models.embeddings import get_1d_rotary_pos_embed
|
8 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
9 |
+
import gradio as gr
|
10 |
|
11 |
|
12 |
class PipelineWrapper(torch.nn.Module):
|
13 |
+
def __init__(self, model_id: str,
|
14 |
+
device: torch.device,
|
15 |
+
double_precision: bool = False,
|
16 |
+
token: Optional[str] = None, *args, **kwargs) -> None:
|
17 |
super().__init__(*args, **kwargs)
|
18 |
self.model_id = model_id
|
19 |
self.device = device
|
20 |
self.double_precision = double_precision
|
21 |
+
self.token = token
|
22 |
|
23 |
+
def get_sigma(self, timestep: int) -> float:
|
24 |
sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
|
25 |
return sqrt_recipm1_alphas_cumprod[timestep]
|
26 |
|
27 |
+
def load_scheduler(self) -> None:
|
28 |
pass
|
29 |
|
30 |
+
def get_fn_STFT(self) -> torch.nn.Module:
|
31 |
pass
|
32 |
|
33 |
+
def get_sr(self) -> int:
|
34 |
+
return 16000
|
35 |
+
|
36 |
+
def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
|
37 |
+
pass
|
38 |
+
|
39 |
+
def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
|
40 |
pass
|
41 |
|
42 |
+
def decode_to_mel(self, x: torch.Tensor) -> torch.Tensor:
|
43 |
pass
|
44 |
|
45 |
+
def setup_extra_inputs(self, *args, **kwargs) -> None:
|
46 |
pass
|
47 |
|
48 |
+
def encode_text(self, prompts: List[str], **kwargs
|
49 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
50 |
pass
|
51 |
|
52 |
+
def get_variance(self, timestep: torch.Tensor, prev_timestep: torch.Tensor) -> torch.Tensor:
|
53 |
pass
|
54 |
|
55 |
+
def get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
|
56 |
pass
|
57 |
|
58 |
+
def get_noise_shape(self, x0: torch.Tensor, num_steps: int) -> Tuple[int, ...]:
|
59 |
+
variance_noise_shape = (num_steps,
|
60 |
+
self.model.unet.config.in_channels,
|
61 |
+
x0.shape[-2],
|
62 |
+
x0.shape[-1])
|
63 |
+
return variance_noise_shape
|
64 |
+
|
65 |
+
def sample_xts_from_x0(self, x0: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
|
66 |
+
"""
|
67 |
+
Samples from P(x_1:T|x_0)
|
68 |
+
"""
|
69 |
+
alpha_bar = self.model.scheduler.alphas_cumprod
|
70 |
+
sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
|
71 |
+
|
72 |
+
variance_noise_shape = self.get_noise_shape(x0, num_inference_steps + 1)
|
73 |
+
timesteps = self.model.scheduler.timesteps.to(self.device)
|
74 |
+
t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
|
75 |
+
xts = torch.zeros(variance_noise_shape).to(x0.device)
|
76 |
+
xts[0] = x0
|
77 |
+
for t in reversed(timesteps):
|
78 |
+
idx = num_inference_steps - t_to_idx[int(t)]
|
79 |
+
xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
|
80 |
+
|
81 |
+
return xts
|
82 |
+
|
83 |
+
def get_zs_from_xts(self, xt: torch.Tensor, xtm1: torch.Tensor, noise_pred: torch.Tensor,
|
84 |
+
t: torch.Tensor, eta: float = 0, numerical_fix: bool = True, **kwargs
|
85 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
86 |
+
# pred of x0
|
87 |
+
alpha_bar = self.model.scheduler.alphas_cumprod
|
88 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
89 |
+
pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
|
90 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
91 |
+
pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
|
92 |
+
|
93 |
+
# direction to xt
|
94 |
+
prev_timestep = t - self.model.scheduler.config.num_train_timesteps // \
|
95 |
+
self.model.scheduler.num_inference_steps
|
96 |
+
|
97 |
+
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
98 |
+
variance = self.get_variance(t, prev_timestep)
|
99 |
+
|
100 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
101 |
+
radom_noise_pred = noise_pred
|
102 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
103 |
+
radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
|
104 |
+
|
105 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * radom_noise_pred
|
106 |
+
|
107 |
+
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
108 |
+
|
109 |
+
z = (xtm1 - mu_xt) / (eta * variance ** 0.5)
|
110 |
+
|
111 |
+
# correction to avoid error accumulation
|
112 |
+
if numerical_fix:
|
113 |
+
xtm1 = mu_xt + (eta * variance ** 0.5)*z
|
114 |
+
|
115 |
+
return z, xtm1, None
|
116 |
+
|
117 |
+
def reverse_step_with_custom_noise(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor,
|
118 |
+
variance_noise: Optional[torch.Tensor] = None, eta: float = 0, **kwargs
|
119 |
+
) -> torch.Tensor:
|
120 |
+
# 1. get previous step value (=t-1)
|
121 |
+
prev_timestep = timestep - self.model.scheduler.config.num_train_timesteps // \
|
122 |
+
self.model.scheduler.num_inference_steps
|
123 |
+
# 2. compute alphas, betas
|
124 |
+
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
|
125 |
+
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
126 |
+
beta_prod_t = 1 - alpha_prod_t
|
127 |
+
# 3. compute predicted original sample from predicted noise also called
|
128 |
+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
129 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
130 |
+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
131 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
132 |
+
pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
|
133 |
+
|
134 |
+
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
135 |
+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
136 |
+
# variance = self.scheduler._get_variance(timestep, prev_timestep)
|
137 |
+
variance = self.get_variance(timestep, prev_timestep)
|
138 |
+
# std_dev_t = eta * variance ** (0.5)
|
139 |
+
# Take care of asymmetric reverse process (asyrp)
|
140 |
+
if self.model.scheduler.config.prediction_type == 'epsilon':
|
141 |
+
model_output_direction = model_output
|
142 |
+
elif self.model.scheduler.config.prediction_type == 'v_prediction':
|
143 |
+
model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
144 |
+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
145 |
+
# pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
|
146 |
+
pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
|
147 |
+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
148 |
+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
149 |
+
# 8. Add noise if eta > 0
|
150 |
+
if eta > 0:
|
151 |
+
if variance_noise is None:
|
152 |
+
variance_noise = torch.randn(model_output.shape, device=self.device)
|
153 |
+
sigma_z = eta * variance ** (0.5) * variance_noise
|
154 |
+
prev_sample = prev_sample + sigma_z
|
155 |
+
|
156 |
+
return prev_sample
|
157 |
+
|
158 |
def unet_forward(self,
|
159 |
sample: torch.FloatTensor,
|
160 |
timestep: Union[torch.Tensor, float, int],
|
|
|
171 |
replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
|
172 |
return_dict: bool = True,
|
173 |
zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
|
174 |
+
pass
|
175 |
|
176 |
|
177 |
class AudioLDM2Wrapper(PipelineWrapper):
|
178 |
def __init__(self, *args, **kwargs) -> None:
|
179 |
super().__init__(*args, **kwargs)
|
180 |
if self.double_precision:
|
181 |
+
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64, token=self.token
|
182 |
+
).to(self.device)
|
183 |
else:
|
184 |
try:
|
185 |
+
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True, token=self.token
|
186 |
+
).to(self.device)
|
187 |
except FileNotFoundError:
|
188 |
+
self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False, token=self.token
|
189 |
+
).to(self.device)
|
190 |
|
191 |
+
def load_scheduler(self) -> None:
|
|
|
192 |
self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
|
193 |
|
194 |
+
def get_fn_STFT(self) -> torch.nn.Module:
|
195 |
from audioldm.audio import TacotronSTFT
|
196 |
return TacotronSTFT(
|
197 |
filter_length=1024,
|
|
|
203 |
mel_fmax=8000,
|
204 |
)
|
205 |
|
206 |
+
def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
|
207 |
# self.model.vae.disable_tiling()
|
208 |
if x.shape[2] % 4:
|
209 |
x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
|
210 |
return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
211 |
# return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
|
212 |
|
213 |
+
def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
|
214 |
return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
|
215 |
|
216 |
+
def decode_to_mel(self, x: torch.Tensor) -> torch.Tensor:
|
217 |
if self.double_precision:
|
218 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
|
219 |
tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
|
|
|
221 |
tmp = tmp.unsqueeze(0)
|
222 |
return tmp
|
223 |
|
224 |
+
def encode_text(self, prompts: List[str], negative: bool = False,
|
225 |
+
save_compute: bool = False, cond_length: int = 0, **kwargs
|
226 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
227 |
tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
|
228 |
text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
|
229 |
prompt_embeds_list = []
|
|
|
232 |
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
233 |
text_inputs = tokenizer(
|
234 |
prompts,
|
235 |
+
padding="max_length" if (save_compute and negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
|
236 |
+
else True,
|
237 |
+
max_length=tokenizer.model_max_length
|
238 |
+
if (not save_compute) or ((not negative) or isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer)))
|
239 |
+
else cond_length,
|
240 |
truncation=True,
|
241 |
return_tensors="pt",
|
242 |
)
|
|
|
306 |
|
307 |
return generated_prompt_embeds, prompt_embeds, attention_mask
|
308 |
|
309 |
+
def get_variance(self, timestep: torch.Tensor, prev_timestep: torch.Tensor) -> torch.Tensor:
|
310 |
alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
|
311 |
alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
|
312 |
beta_prod_t = 1 - alpha_prod_t
|
|
|
314 |
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
315 |
return variance
|
316 |
|
317 |
+
def get_alpha_prod_t_prev(self, prev_timestep: torch.Tensor) -> torch.Tensor:
|
318 |
return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
|
319 |
else self.model.scheduler.final_alpha_cumprod
|
320 |
|
|
|
387 |
# 1. time
|
388 |
timesteps = timestep
|
389 |
if not torch.is_tensor(timesteps):
|
|
|
|
|
390 |
is_mps = sample.device.type == "mps"
|
391 |
if isinstance(timestep, float):
|
392 |
dtype = torch.float32 if is_mps else torch.float64
|
|
|
528 |
|
529 |
return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
|
530 |
|
531 |
+
|
532 |
+
class StableAudWrapper(PipelineWrapper):
|
533 |
+
def __init__(self, *args, **kwargs) -> None:
|
534 |
+
super().__init__(*args, **kwargs)
|
535 |
+
try:
|
536 |
+
self.model = StableAudioPipeline.from_pretrained(self.model_id, token=self.token, local_files_only=True
|
537 |
+
).to(self.device)
|
538 |
+
except FileNotFoundError:
|
539 |
+
self.model = StableAudioPipeline.from_pretrained(self.model_id, token=self.token, local_files_only=False
|
540 |
+
).to(self.device)
|
541 |
+
self.model.transformer.eval()
|
542 |
+
self.model.vae.eval()
|
543 |
+
|
544 |
+
if self.double_precision:
|
545 |
+
self.model = self.model.to(torch.float64)
|
546 |
+
|
547 |
+
def load_scheduler(self) -> None:
|
548 |
+
self.model.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained(
|
549 |
+
self.model_id, subfolder="scheduler", token=self.token)
|
550 |
+
|
551 |
+
def encode_text(self, prompts: List[str], negative: bool = False, **kwargs) -> Tuple[torch.Tensor, None, torch.Tensor]:
|
552 |
+
text_inputs = self.model.tokenizer(
|
553 |
+
prompts,
|
554 |
+
padding="max_length",
|
555 |
+
max_length=self.model.tokenizer.model_max_length,
|
556 |
+
truncation=True,
|
557 |
+
return_tensors="pt",
|
558 |
+
)
|
559 |
+
|
560 |
+
text_input_ids = text_inputs.input_ids.to(self.device)
|
561 |
+
attention_mask = text_inputs.attention_mask.to(self.device)
|
562 |
+
|
563 |
+
self.model.text_encoder.eval()
|
564 |
+
with torch.no_grad():
|
565 |
+
prompt_embeds = self.model.text_encoder(text_input_ids, attention_mask=attention_mask)[0]
|
566 |
+
|
567 |
+
if negative and attention_mask is not None: # set the masked tokens to the null embed
|
568 |
+
prompt_embeds = torch.where(attention_mask.to(torch.bool).unsqueeze(2), prompt_embeds, 0.0)
|
569 |
+
|
570 |
+
prompt_embeds = self.model.projection_model(text_hidden_states=prompt_embeds).text_hidden_states
|
571 |
+
|
572 |
+
if attention_mask is None:
|
573 |
+
raise gr.Error("Shouldn't reach here. Please raise an issue if you do.")
|
574 |
+
"""prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
575 |
+
if attention_mask is not None and negative_attention_mask is None:
|
576 |
+
negative_attention_mask = torch.ones_like(attention_mask)
|
577 |
+
elif attention_mask is None and negative_attention_mask is not None:
|
578 |
+
attention_mask = torch.ones_like(negative_attention_mask)"""
|
579 |
+
|
580 |
+
if prompts == [""]: # empty
|
581 |
+
return torch.zeros_like(prompt_embeds, device=prompt_embeds.device), None, None
|
582 |
+
|
583 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
584 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype)
|
585 |
+
return prompt_embeds, None, attention_mask
|
586 |
+
|
587 |
+
def get_fn_STFT(self) -> torch.nn.Module:
|
588 |
+
from audioldm.audio import TacotronSTFT
|
589 |
+
return TacotronSTFT(
|
590 |
+
filter_length=1024,
|
591 |
+
hop_length=160,
|
592 |
+
win_length=1024,
|
593 |
+
n_mel_channels=64,
|
594 |
+
sampling_rate=44100,
|
595 |
+
mel_fmin=0,
|
596 |
+
mel_fmax=22050,
|
597 |
+
)
|
598 |
+
|
599 |
+
def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
|
600 |
+
x = x.unsqueeze(0)
|
601 |
+
|
602 |
+
audio_vae_length = int(self.model.transformer.config.sample_size * self.model.vae.hop_length)
|
603 |
+
audio_shape = (1, self.model.vae.config.audio_channels, audio_vae_length)
|
604 |
+
|
605 |
+
# check num_channels
|
606 |
+
if x.shape[1] == 1 and self.model.vae.config.audio_channels == 2:
|
607 |
+
x = x.repeat(1, 2, 1)
|
608 |
+
|
609 |
+
audio_length = x.shape[-1]
|
610 |
+
audio = x.new_zeros(audio_shape)
|
611 |
+
audio[:, :, : min(audio_length, audio_vae_length)] = x[:, :, :audio_vae_length]
|
612 |
+
|
613 |
+
encoded_audio = self.model.vae.encode(audio.to(self.device)).latent_dist
|
614 |
+
encoded_audio = encoded_audio.sample()
|
615 |
+
return encoded_audio
|
616 |
+
|
617 |
+
def vae_decode(self, x: torch.Tensor) -> torch.Tensor:
|
618 |
+
torch.cuda.empty_cache()
|
619 |
+
# return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
|
620 |
+
aud = self.model.vae.decode(x).sample
|
621 |
+
return aud[:, :, self.waveform_start:self.waveform_end]
|
622 |
+
|
623 |
+
def setup_extra_inputs(self, x: torch.Tensor, init_timestep: torch.Tensor,
|
624 |
+
extra_info: Optional[Any] = None,
|
625 |
+
audio_start_in_s: float = 0, audio_end_in_s: Optional[float] = None,
|
626 |
+
save_compute: bool = False) -> None:
|
627 |
+
max_audio_length_in_s = self.model.transformer.config.sample_size * self.model.vae.hop_length / \
|
628 |
+
self.model.vae.config.sampling_rate
|
629 |
+
if audio_end_in_s is None:
|
630 |
+
audio_end_in_s = max_audio_length_in_s
|
631 |
+
|
632 |
+
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
|
633 |
+
raise ValueError(
|
634 |
+
f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer "
|
635 |
+
f"than the model maximum possible length ({max_audio_length_in_s}). "
|
636 |
+
f"Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
|
637 |
+
)
|
638 |
+
|
639 |
+
self.waveform_start = int(audio_start_in_s * self.model.vae.config.sampling_rate)
|
640 |
+
self.waveform_end = int(audio_end_in_s * self.model.vae.config.sampling_rate)
|
641 |
+
|
642 |
+
self.seconds_start_hidden_states, self.seconds_end_hidden_states = self.model.encode_duration(
|
643 |
+
audio_start_in_s, audio_end_in_s, self.device, False, 1)
|
644 |
+
|
645 |
+
if save_compute:
|
646 |
+
self.seconds_start_hidden_states = torch.cat([self.seconds_start_hidden_states, self.seconds_start_hidden_states], dim=0)
|
647 |
+
self.seconds_end_hidden_states = torch.cat([self.seconds_end_hidden_states, self.seconds_end_hidden_states], dim=0)
|
648 |
+
|
649 |
+
self.audio_duration_embeds = torch.cat([self.seconds_start_hidden_states,
|
650 |
+
self.seconds_end_hidden_states], dim=2)
|
651 |
+
|
652 |
+
# 7. Prepare rotary positional embedding
|
653 |
+
self.rotary_embedding = get_1d_rotary_pos_embed(
|
654 |
+
self.model.rotary_embed_dim,
|
655 |
+
x.shape[2] + self.audio_duration_embeds.shape[1],
|
656 |
+
use_real=True,
|
657 |
+
repeat_interleave_real=False,
|
658 |
+
)
|
659 |
+
|
660 |
+
self.model.scheduler._init_step_index(init_timestep)
|
661 |
+
|
662 |
+
# fix lower_order_nums for the reverse step - Option 1: only start from first order
|
663 |
+
# self.model.scheduler.lower_order_nums = 0
|
664 |
+
# self.model.scheduler.model_outputs = [None] * self.model.scheduler.config.solver_order
|
665 |
+
# fix lower_order_nums for the reverse step - Option 2: start from the correct order with history
|
666 |
+
t_to_idx = {float(v): k for k, v in enumerate(self.model.scheduler.timesteps)}
|
667 |
+
idx = len(self.model.scheduler.timesteps) - t_to_idx[float(init_timestep)] - 1
|
668 |
+
self.model.scheduler.model_outputs = [None, extra_info[idx] if extra_info is not None else None]
|
669 |
+
self.model.scheduler.lower_order_nums = min(self.model.scheduler.step_index,
|
670 |
+
self.model.scheduler.config.solver_order)
|
671 |
+
|
672 |
+
# if rand check:
|
673 |
+
# x *= self.model.scheduler.init_noise_sigma
|
674 |
+
# return x
|
675 |
+
|
676 |
+
def sample_xts_from_x0(self, x0: torch.Tensor, num_inference_steps: int = 50) -> torch.Tensor:
|
677 |
+
"""
|
678 |
+
Samples from P(x_1:T|x_0)
|
679 |
+
"""
|
680 |
+
|
681 |
+
sigmas = self.model.scheduler.sigmas
|
682 |
+
shapes = self.get_noise_shape(x0, num_inference_steps + 1)
|
683 |
+
xts = torch.zeros(shapes).to(x0.device)
|
684 |
+
xts[0] = x0
|
685 |
+
|
686 |
+
timesteps = self.model.scheduler.timesteps.to(self.device)
|
687 |
+
t_to_idx = {float(v): k for k, v in enumerate(timesteps)}
|
688 |
+
for t in reversed(timesteps):
|
689 |
+
# idx = t_to_idx[int(t)]
|
690 |
+
idx = num_inference_steps - t_to_idx[float(t)]
|
691 |
+
n = torch.randn_like(x0)
|
692 |
+
xts[idx] = x0 + n * sigmas[t_to_idx[float(t)]]
|
693 |
+
return xts
|
694 |
+
|
695 |
+
def get_zs_from_xts(self, xt: torch.Tensor, xtm1: torch.Tensor, data_pred: torch.Tensor,
|
696 |
+
t: torch.Tensor, numerical_fix: bool = True, first_order: bool = False, **kwargs
|
697 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
698 |
+
# pred of x0
|
699 |
+
sigmas = self.model.scheduler.sigmas
|
700 |
+
timesteps = self.model.scheduler.timesteps
|
701 |
+
solver_order = self.model.scheduler.config.solver_order
|
702 |
+
|
703 |
+
if self.model.scheduler.step_index is None:
|
704 |
+
self.model.scheduler._init_step_index(t)
|
705 |
+
curr_step_index = self.model.scheduler.step_index
|
706 |
+
|
707 |
+
# Improve numerical stability for small number of steps
|
708 |
+
lower_order_final = (curr_step_index == len(timesteps) - 1) and (
|
709 |
+
self.model.scheduler.config.euler_at_final
|
710 |
+
or (self.model.scheduler.config.lower_order_final and len(timesteps) < 15)
|
711 |
+
or self.model.scheduler.config.final_sigmas_type == "zero")
|
712 |
+
lower_order_second = ((curr_step_index == len(timesteps) - 2) and
|
713 |
+
self.model.scheduler.config.lower_order_final and len(timesteps) < 15)
|
714 |
+
|
715 |
+
data_pred = self.model.scheduler.convert_model_output(data_pred, sample=xt)
|
716 |
+
for i in range(solver_order - 1):
|
717 |
+
self.model.scheduler.model_outputs[i] = self.model.scheduler.model_outputs[i + 1]
|
718 |
+
self.model.scheduler.model_outputs[-1] = data_pred
|
719 |
+
|
720 |
+
# instead of brownian noise, here we calculate the noise ourselves
|
721 |
+
if (curr_step_index == len(timesteps) - 1) and self.model.scheduler.config.final_sigmas_type == "zero":
|
722 |
+
z = torch.zeros_like(xt)
|
723 |
+
elif first_order or solver_order == 1 or self.model.scheduler.lower_order_nums < 1 or lower_order_final:
|
724 |
+
sigma_t, sigma_s = sigmas[curr_step_index + 1], sigmas[curr_step_index]
|
725 |
+
h = torch.log(sigma_s) - torch.log(sigma_t)
|
726 |
+
z = (xtm1 - (sigma_t / sigma_s * torch.exp(-h)) * xt - (1 - torch.exp(-2.0 * h)) * data_pred) \
|
727 |
+
/ (sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)))
|
728 |
+
elif solver_order == 2 or self.model.scheduler.lower_order_nums < 2 or lower_order_second:
|
729 |
+
sigma_t = sigmas[curr_step_index + 1]
|
730 |
+
sigma_s0 = sigmas[curr_step_index]
|
731 |
+
sigma_s1 = sigmas[curr_step_index - 1]
|
732 |
+
m0, m1 = self.model.scheduler.model_outputs[-1], self.model.scheduler.model_outputs[-2]
|
733 |
+
h, h_0 = torch.log(sigma_s0) - torch.log(sigma_t), torch.log(sigma_s1) - torch.log(sigma_s0)
|
734 |
+
r0 = h_0 / h
|
735 |
+
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
736 |
+
|
737 |
+
# sde-dpmsolver++
|
738 |
+
z = (xtm1 - (sigma_t / sigma_s0 * torch.exp(-h)) * xt
|
739 |
+
- (1 - torch.exp(-2.0 * h)) * D0
|
740 |
+
- 0.5 * (1 - torch.exp(-2.0 * h)) * D1) \
|
741 |
+
/ (sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)))
|
742 |
+
|
743 |
+
# correction to avoid error accumulation
|
744 |
+
if numerical_fix:
|
745 |
+
if first_order or solver_order == 1 or self.model.scheduler.lower_order_nums < 1 or lower_order_final:
|
746 |
+
xtm1 = self.model.scheduler.dpm_solver_first_order_update(data_pred, sample=xt, noise=z)
|
747 |
+
elif solver_order == 2 or self.model.scheduler.lower_order_nums < 2 or lower_order_second:
|
748 |
+
xtm1 = self.model.scheduler.multistep_dpm_solver_second_order_update(
|
749 |
+
self.model.scheduler.model_outputs, sample=xt, noise=z)
|
750 |
+
# If not perfect recon - maybe TODO fix self.model.scheduler.model_outputs as well?
|
751 |
+
|
752 |
+
if self.model.scheduler.lower_order_nums < solver_order:
|
753 |
+
self.model.scheduler.lower_order_nums += 1
|
754 |
+
# upon completion increase step index by one
|
755 |
+
self.model.scheduler._step_index += 1
|
756 |
+
|
757 |
+
return z, xtm1, self.model.scheduler.model_outputs[-2]
|
758 |
+
|
759 |
+
def get_sr(self) -> int:
|
760 |
+
return self.model.vae.config.sampling_rate
|
761 |
+
|
762 |
+
def get_noise_shape(self, x0: torch.Tensor, num_steps: int) -> Tuple[int, int, int]:
|
763 |
+
variance_noise_shape = (num_steps,
|
764 |
+
self.model.transformer.config.in_channels,
|
765 |
+
int(self.model.transformer.config.sample_size))
|
766 |
+
return variance_noise_shape
|
767 |
+
|
768 |
+
def reverse_step_with_custom_noise(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor,
|
769 |
+
variance_noise: Optional[torch.Tensor] = None,
|
770 |
+
first_order: bool = False, **kwargs
|
771 |
+
) -> torch.Tensor:
|
772 |
+
if self.model.scheduler.step_index is None:
|
773 |
+
self.model.scheduler._init_step_index(timestep)
|
774 |
+
|
775 |
+
# Improve numerical stability for small number of steps
|
776 |
+
lower_order_final = (self.model.scheduler.step_index == len(self.model.scheduler.timesteps) - 1) and (
|
777 |
+
self.model.scheduler.config.euler_at_final
|
778 |
+
or (self.model.scheduler.config.lower_order_final and len(self.model.scheduler.timesteps) < 15)
|
779 |
+
or self.model.scheduler.config.final_sigmas_type == "zero"
|
780 |
+
)
|
781 |
+
lower_order_second = (
|
782 |
+
(self.model.scheduler.step_index == len(self.model.scheduler.timesteps) - 2) and
|
783 |
+
self.model.scheduler.config.lower_order_final and len(self.model.scheduler.timesteps) < 15
|
784 |
+
)
|
785 |
+
|
786 |
+
model_output = self.model.scheduler.convert_model_output(model_output, sample=sample)
|
787 |
+
for i in range(self.model.scheduler.config.solver_order - 1):
|
788 |
+
self.model.scheduler.model_outputs[i] = self.model.scheduler.model_outputs[i + 1]
|
789 |
+
self.model.scheduler.model_outputs[-1] = model_output
|
790 |
+
|
791 |
+
if variance_noise is None:
|
792 |
+
if self.model.scheduler.noise_sampler is None:
|
793 |
+
self.model.scheduler.noise_sampler = BrownianTreeNoiseSampler(
|
794 |
+
model_output, sigma_min=self.model.scheduler.config.sigma_min,
|
795 |
+
sigma_max=self.model.scheduler.config.sigma_max, seed=None)
|
796 |
+
variance_noise = self.model.scheduler.noise_sampler(
|
797 |
+
self.model.scheduler.sigmas[self.model.scheduler.step_index],
|
798 |
+
self.model.scheduler.sigmas[self.model.scheduler.step_index + 1]).to(model_output.device)
|
799 |
+
|
800 |
+
if first_order or self.model.scheduler.config.solver_order == 1 or \
|
801 |
+
self.model.scheduler.lower_order_nums < 1 or lower_order_final:
|
802 |
+
prev_sample = self.model.scheduler.dpm_solver_first_order_update(
|
803 |
+
model_output, sample=sample, noise=variance_noise)
|
804 |
+
elif self.model.scheduler.config.solver_order == 2 or \
|
805 |
+
self.model.scheduler.lower_order_nums < 2 or lower_order_second:
|
806 |
+
prev_sample = self.model.scheduler.multistep_dpm_solver_second_order_update(
|
807 |
+
self.model.scheduler.model_outputs, sample=sample, noise=variance_noise)
|
808 |
+
|
809 |
+
if self.model.scheduler.lower_order_nums < self.model.scheduler.config.solver_order:
|
810 |
+
self.model.scheduler.lower_order_nums += 1
|
811 |
+
|
812 |
+
# upon completion increase step index by one
|
813 |
+
self.model.scheduler._step_index += 1
|
814 |
+
|
815 |
+
return prev_sample
|
816 |
+
|
817 |
+
def unet_forward(self,
|
818 |
+
sample: torch.FloatTensor,
|
819 |
+
timestep: Union[torch.Tensor, float, int],
|
820 |
+
encoder_hidden_states: torch.Tensor,
|
821 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
822 |
+
return_dict: bool = True,
|
823 |
+
**kwargs) -> Tuple:
|
824 |
+
|
825 |
+
# Create text_audio_duration_embeds and audio_duration_embeds
|
826 |
+
embeds = torch.cat([encoder_hidden_states, self.seconds_start_hidden_states, self.seconds_end_hidden_states],
|
827 |
+
dim=1)
|
828 |
+
if encoder_attention_mask is None:
|
829 |
+
# handle the batched case
|
830 |
+
if embeds.shape[0] > 1:
|
831 |
+
embeds[0] = torch.zeros_like(embeds[0], device=embeds.device)
|
832 |
+
else:
|
833 |
+
embeds = torch.zeros_like(embeds, device=embeds.device)
|
834 |
+
|
835 |
+
noise_pred = self.model.transformer(sample,
|
836 |
+
timestep.unsqueeze(0),
|
837 |
+
encoder_hidden_states=embeds,
|
838 |
+
global_hidden_states=self.audio_duration_embeds,
|
839 |
+
rotary_embedding=self.rotary_embedding)
|
840 |
+
|
841 |
+
if not return_dict:
|
842 |
+
return (noise_pred.sample,)
|
843 |
+
|
844 |
+
return noise_pred, None, None
|
845 |
|
846 |
|
847 |
+
def load_model(model_id: str, device: torch.device,
|
848 |
+
double_precision: bool = False, token: Optional[str] = None) -> PipelineWrapper:
|
849 |
+
if 'audioldm2' in model_id:
|
850 |
+
ldm_stable = AudioLDM2Wrapper(model_id=model_id, device=device, double_precision=double_precision, token=token)
|
851 |
+
elif 'stable-audio' in model_id:
|
852 |
+
ldm_stable = StableAudWrapper(model_id=model_id, device=device, double_precision=double_precision, token=token)
|
853 |
ldm_stable.load_scheduler()
|
854 |
torch.cuda.empty_cache()
|
855 |
return ldm_stable
|
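A short sketch of the dispatch that the new load_model() above performs. The checkpoint ids are illustrative assumptions; Stable Audio Open is a gated checkpoint on the Hub, which is why load_model() accepts a token.

```python
# Minimal sketch of load_model() dispatch; checkpoint ids and env-var name are assumptions.
import os
import torch
from models import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 'audioldm2' in the id -> AudioLDM2Wrapper (DDIMScheduler, mel-spectrogram VAE)
ldm = load_model("cvssp/audioldm2-music", device)

# 'stable-audio' in the id -> StableAudWrapper (CosineDPMSolverMultistepScheduler, waveform VAE)
sao = load_model("stabilityai/stable-audio-open-1.0", device,
                 token=os.environ.get("HF_TOKEN"))

print(ldm.get_sr(), sao.get_sr())  # 16000 vs. the VAE sampling rate (44100 for Stable Audio Open)
```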
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
torch
|
2 |
-
numpy<2
|
3 |
torchaudio
|
4 |
diffusers
|
5 |
accelerate
|
|
|
6 |
transformers
|
7 |
tqdm
|
8 |
soundfile
|
|
|
1 |
+
torch>2.2.0
|
2 |
+
numpy<2.0.0
|
3 |
torchaudio
|
4 |
diffusers
|
5 |
accelerate
|
6 |
+
torchsde
|
7 |
transformers
|
8 |
tqdm
|
9 |
soundfile
|
utils.py
CHANGED
@@ -2,8 +2,11 @@ import numpy as np
|
|
2 |
import torch
|
3 |
from typing import Optional, List, Tuple, NamedTuple, Union
|
4 |
from models import PipelineWrapper
|
|
|
5 |
from audioldm.utils import get_duration
|
6 |
|
|
|
|
|
7 |
|
8 |
class PromptEmbeddings(NamedTuple):
|
9 |
embedding_hidden_states: torch.Tensor
|
@@ -11,26 +14,57 @@ class PromptEmbeddings(NamedTuple):
|
|
11 |
boolean_prompt_mask: torch.Tensor
|
12 |
|
13 |
|
14 |
-
def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0,
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
19 |
|
20 |
-
|
|
|
|
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
|
34 |
|
35 |
|
36 |
def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
|
|
|
2 |
import torch
|
3 |
from typing import Optional, List, Tuple, NamedTuple, Union
|
4 |
from models import PipelineWrapper
|
5 |
+
import torchaudio
|
6 |
from audioldm.utils import get_duration
|
7 |
|
8 |
+
MAX_DURATION = 30
|
9 |
+
|
10 |
|
11 |
class PromptEmbeddings(NamedTuple):
|
12 |
embedding_hidden_states: torch.Tensor
|
|
|
14 |
boolean_prompt_mask: torch.Tensor
|
15 |
|
16 |
|
17 |
+
def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0,
|
18 |
+
device: Optional[torch.device] = None,
|
19 |
+
return_wav: bool = False, stft: bool = False, model_sr: Optional[int] = None) -> torch.Tensor:
|
20 |
+
if stft: # AudioLDM/tango loading to spectrogram
|
21 |
+
if type(audio_path) is str:
|
22 |
+
import audioldm
|
23 |
+
import audioldm.audio
|
24 |
|
25 |
+
duration = get_duration(audio_path)
|
26 |
+
if MAX_DURATION is not None:
|
27 |
+
duration = min(duration, MAX_DURATION)
|
28 |
|
29 |
+
mel, _, wav = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
|
30 |
+
mel = mel.unsqueeze(0)
|
31 |
+
else:
|
32 |
+
mel = audio_path
|
33 |
|
34 |
+
c, h, w = mel.shape
|
35 |
+
left = min(left, w-1)
|
36 |
+
right = min(right, w - left - 1)
|
37 |
+
mel = mel[:, :, left:w-right]
|
38 |
+
mel = mel.unsqueeze(0).to(device)
|
39 |
|
40 |
+
if return_wav:
|
41 |
+
return mel, 16000, duration, wav
|
42 |
+
|
43 |
+
return mel, model_sr, duration
|
44 |
+
else:
|
45 |
+
waveform, sr = torchaudio.load(audio_path)
|
46 |
+
if sr != model_sr:
|
47 |
+
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=model_sr)
|
48 |
+
# waveform = waveform.numpy()[0, ...]
|
49 |
+
|
50 |
+
def normalize_wav(waveform):
|
51 |
+
waveform = waveform - torch.mean(waveform)
|
52 |
+
waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
|
53 |
+
return waveform * 0.5
|
54 |
+
|
55 |
+
waveform = normalize_wav(waveform)
|
56 |
+
# waveform = waveform[None, ...]
|
57 |
+
# waveform = pad_wav(waveform, segment_length)
|
58 |
+
|
59 |
+
# waveform = waveform[0, ...]
|
60 |
+
waveform = torch.FloatTensor(waveform)
|
61 |
+
if MAX_DURATION is not None:
|
62 |
+
duration = min(waveform.shape[-1] / model_sr, MAX_DURATION)
|
63 |
+
waveform = waveform[:, :int(duration * model_sr)]
|
64 |
+
|
65 |
+
# cut waveform
|
66 |
+
duration = waveform.shape[-1] / model_sr
|
67 |
+
return waveform, model_sr, duration
|
68 |
|
69 |
|
70 |
def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
|
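To illustrate the two branches of the rewritten load_audio() above (spectrogram input for AudioLDM2, raw waveform for Stable Audio), a small hedged sketch; the file path and checkpoint ids are placeholders, not part of this diff.

```python
# Hypothetical usage of utils.load_audio(); "input.mp3" and the checkpoint ids are assumptions.
import torch
from models import load_model
from utils import load_audio

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# AudioLDM2 path: stft=True returns a mel spectrogram, clipped to MAX_DURATION (30 s).
ldm = load_model("cvssp/audioldm2-music", device)
mel, sr, duration = load_audio("input.mp3", ldm.get_fn_STFT(),
                               device=device, stft=True, model_sr=ldm.get_sr())

# Stable Audio path: stft=False returns a normalized waveform resampled to the model rate.
# sao = load_model("stabilityai/stable-audio-open-1.0", device, token=...)
# wav, sr, duration = load_audio("input.mp3", None, stft=False, model_sr=sao.get_sr())
```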