drscotthawley committed on
Commit
b887586
1 Parent(s): b46aa4b

needed to add sample.py

Browse files
Files changed (1) hide show
  1. sample.py +645 -0
sample.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ # Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
4
+
5
+ """Samples from k-diffusion models."""
6
+
7
+ import argparse
8
+ from pathlib import Path
9
+
10
+ import accelerate
11
+ import safetensors.torch as safetorch
12
+ import torch
13
+ from tqdm import trange, tqdm
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+
17
+ import k_diffusion as K
18
+
19
+ from control_toys.v_diffusion import DDPM, LogSchedule, CrashSchedule
20
+ #CHORD_BORDER = 8 # chord border size in pixels
21
+ from control_toys.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
22
+
23
+
24
+ # ---- my mangled sampler that includes repaint
25
+ import torchsde
26
+
27
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    One ``torchsde.BrownianTree`` is built per seed. When ``seed`` is a
    sequence, calling the wrapper returns the per-tree results stacked along a
    new leading dimension; when it is a single value, the lone tree's result
    is returned unstacked.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        """Build the underlying tree(s).

        Args:
            x: Tensor used as the template (``torch.zeros_like``) for the
                default initial value ``w0``.
            t0: One endpoint of the interval (either order is accepted).
            t1: The other endpoint of the interval.
            seed: ``None`` (a random seed is drawn), a single seed, or a
                sequence holding one seed per batch item, in which case
                ``len(seed)`` must equal ``x.shape[0]``.
            **kwargs: Forwarded to ``torchsde.BrownianTree``. A ``w0`` entry
                replaces the default initial value.
        """
        t0, t1, self.sign = self.sort(t0, t1)
        # pop, not get: w0 is handed to BrownianTree positionally below, so it
        # must not also travel inside **kwargs (that would pass it twice).
        w0 = kwargs.pop('w0', torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            # Each per-item tree gets a single item's worth of initial value.
            w0 = w0[0]
        except TypeError:
            # seed has no len(): treat it as one seed shared by the whole batch.
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return ``(low, high, sign)``; sign is 1 if ``a < b``, otherwise -1."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Evaluate every tree on the interval, restoring the caller's orientation."""
        t0, t1, sign = self.sort(t0, t1)
        # Combine the orientation of the constructor interval with this query's.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]
52
+
53
+
54
class BrownianTreeNoiseSampler:
    """Noise source that draws its samples from a torchsde.BrownianTree.

    Args:
        x (Tensor): Template tensor; its shape, device and dtype are used for
            the generated samples.
        sigma_min (float): Lower end of the valid sigma interval.
        sigma_max (float): Upper end of the valid sigma interval.
        seed (int or List[int]): Random seed. Passing a list of seeds instead
            of one integer makes the sampler keep one BrownianTree per batch
            item, each seeded separately.
        transform (callable): Maps a sigma to the sampler's internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        # Express both interval ends in the sampler's internal time.
        t_low = self.transform(torch.as_tensor(sigma_min))
        t_high = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t_low, t_high, seed)

    def __call__(self, sigma, sigma_next):
        """Return the tree's value over [sigma, sigma_next], divided by sqrt(|dt|)."""
        t_start = self.transform(torch.as_tensor(sigma))
        t_stop = self.transform(torch.as_tensor(sigma_next))
        increment = self.tree(t_start, t_stop)
        scale = (t_stop - t_start).abs().sqrt()
        return increment / scale
77
+
78
def append_dims(x, target_dims):
    """Pad x with trailing singleton dimensions until it has target_dims dimensions.

    Raises:
        ValueError: if x already has more than target_dims dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Ellipsis keeps every existing axis; each None adds one trailing axis.
        index = (..., *([None] * missing))
        return x[index]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
84
+
85
+
86
def to_d(x, sigma, denoised):
    """Turn a denoiser output into a Karras ODE derivative: (x - denoised) / sigma."""
    # Give sigma trailing singleton axes so it broadcasts against x.
    sigma_broadcast = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_broadcast
89
+
90
+
91
@torch.no_grad()
def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022), with RePaint-style repeats.

    Every step of the sigma schedule is executed ``repaint`` times. Between
    repeats x is re-noised as ``sqrt(1 - beta) * x + sqrt(beta) * noise`` with
    ``beta = (sigmas[i + 1] / sigmas.max()) ** 2``.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``;
            its result is used as the denoised estimate.
        x: Starting tensor; ``x.shape[0]`` sizes the per-item sigma vector.
        sigmas: Sequence of noise levels; ``len(sigmas) - 1`` steps are taken.
        extra_args: Optional dict of keyword arguments forwarded to ``model``.
        callback: Optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` after each model call.
        disable: Forwarded to ``trange``.
        s_churn: Churn amount; the per-step ``gamma`` is
            ``min(s_churn / n_steps, sqrt(2) - 1)``.
        s_tmin: Lower sigma bound of the range in which churn is applied.
        s_tmax: Upper sigma bound of the range in which churn is applied.
        s_noise: Scale applied to the churn noise.
        repaint: How many times each step is run (1 disables re-noising).

    Returns:
        The tensor x after the final step.

    Raises:
        AssertionError: if x contains NaNs after a step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        for u in range(repaint):
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            # Drawn unconditionally (even when gamma == 0) so the random
            # stream is consumed the same way on every step.
            eps = torch.randn_like(x) * s_noise
            sigma_hat = sigmas[i] * (gamma + 1)
            if gamma > 0:
                x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
            denoised = model(x, sigma_hat * s_in, **extra_args)
            d = to_d(x, sigma_hat, denoised)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
            dt = sigmas[i + 1] - sigma_hat
            # Euler method
            x = x + d * dt
            if x.isnan().any():
                # Explicit raise (same exception type as the former
                # `assert False`) so the check also runs under `python -O`.
                raise AssertionError(f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}")
            if u < repaint - 1:
                # Normalise by the largest sigma, as my_dpmpp_2m_sde does.
                # Dividing by sigmas[-1] breaks for schedules whose last entry
                # is 0: beta becomes inf and x turns into NaN.
                beta = (sigmas[i + 1] / sigmas.max()) ** 2
                x = torch.sqrt(1 - beta) * x + torch.sqrt(beta) * torch.randn_like(x)

    return x
117
+
118
def get_scalings(sigma, sigma_data=0.5):
    """Return the (c_skip, c_out, c_in) scaling factors for a given sigma."""
    variance_sum = sigma ** 2 + sigma_data ** 2
    root = variance_sum ** 0.5
    skip = sigma_data ** 2 / variance_sum
    out = sigma * sigma_data / root
    inp = 1 / root
    return skip, out, inp
123
+
124
+
125
@torch.no_grad()
def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
                    disable=None, eta=1., s_noise=1., noise_sampler=None,
                    solver_type='midpoint',
                    repaint=4):
    """DPM-Solver++(2M) SDE, but with RePaint-style repeats added.

    Every step of the sigma schedule is executed ``repaint`` times; between
    repeats x is pushed back "forward" by mixing in fresh noise.

    Args:
        model: Callable invoked as ``model(x, sigma_batch, **extra_args)``;
            its result is used as the denoised estimate.
        x: Starting tensor; ``x.shape[0]`` sizes the per-item sigma vector.
        sigmas: Tensor of noise levels; ``len(sigmas) - 1`` steps are taken.
        extra_args: Optional dict of keyword arguments forwarded to ``model``.
        callback: Optional callable; called with a dict (keys ``x``, ``i``,
            ``sigma``, ``sigma_hat``, ``denoised``) once after the model call
            and once more after the update of x.
        disable: Forwarded to ``trange``.
        eta: Multiplier on the step size used for the stochastic part; a falsy
            value skips the noise injection.
        s_noise: Scale applied to the injected noise.
        noise_sampler: Callable ``(sigma, sigma_next) -> noise``; defaults to a
            ``BrownianTreeNoiseSampler`` spanning the positive sigmas.
        solver_type: ``'heun'`` or ``'midpoint'`` second-order correction.
        repaint: How many times each step is run (1 disables re-noising).

    Returns:
        The tensor x after the final step.

    Raises:
        ValueError: if ``solver_type`` is not ``'heun'`` or ``'midpoint'``.
        AssertionError: if x contains NaNs after a step.
    """

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None
    h_last = None
    old_x = None
    # h is only assigned on steps whose target sigma is non-zero. Initialise
    # it so the end-of-step bookkeeping cannot hit an unbound local when the
    # very first step already lands on sigma == 0.
    h = None

    for i in trange(len(sigmas) - 1, disable=disable):  # time loop

        for u in range(repaint):
            denoised = model(x, sigmas[i] * s_in, **extra_args)
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
            #print("i, u, sigmas[i], sigmas[i + 1] = ", i, u, sigmas[i], sigmas[i + 1])
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

                if old_denoised is not None:
                    # Second-order correction from the previous step's estimate.
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

                if eta:
                    x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

            # Second callback, after x has been updated for this repeat.
            if callback is not None:
                callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

            if x.isnan().any():
                # Explicit raise (same exception type as the former
                # `assert False`) so the check also runs under `python -O`.
                raise AssertionError(f"x has NaNs, i = {i}, u = {u}, repaint = {repaint}")

            if u < repaint - 1:
                # RePaint: go "back" in integration via the "forward" process, by adding a little noise to x
                # ...but scaled properly!
                # But how to convert from original RePaint to k-diffusion? I'll try a few variants
                repaint_choice = 'orig'  # ['orig','var1','var2', etc...]

                sigma_diff = (sigmas[i] - sigmas[i+1]).abs()
                sigma_ratio = (sigmas[i+1] / sigma_max)  # use i+1 or i?
                if repaint_choice == 'orig':  # attempt at original RePaint algorithm, which used betas
                    # if sigmas are the std devs, then betas are variances? but beta_max = 1, so how to get that? ratio?
                    beta = sigma_ratio**2
                    x = torch.sqrt(1-beta)*x + torch.sqrt(beta)*torch.randn_like(x)  # this is from RePaint Paper
                elif repaint_choice == 'var1':  # or maybe this...?  # worse than orig
                    x = x + sigma_diff*torch.randn_like(x)
                elif repaint_choice == 'var2':  # or this...?  # yields NaNs
                    x = (1-sigma_diff)*x + sigma_diff*torch.randn_like(x)
                elif repaint_choice == 'var3':  # results similar to var1
                    x = (1.0-sigma_ratio)*x + sigmas[i+1]*torch.randn_like(x)
                elif repaint_choice == 'var4':  # NaNs  # experimental, borrowed from elsewhere
                    # Invert this: target = (input - c_skip * noised_input) / c_out, where target = model_output
                    x_tm1, x_t = x, old_x
                    # x_tm1 = ( x_0 - c_skip * noised_x0 ) / c_out
                    # So x_tm1*c_out = x_0 - c_skip * noised_x0
                    # (local renamed from `input` to avoid shadowing the builtin)
                    clean, noise = x_tm1, torch.randn_like(x)
                    noised_input = clean + noise * append_dims(sigma_diff, clean.ndim)
                    c_skip, c_out, c_in = [append_dims(c, clean.ndim) for c in get_scalings(sigmas[i])]
                    model_output = x_tm1
                    renoised_x = c_out * model_output + c_skip * noised_input
                    x = renoised_x
                elif repaint_choice == 'var5':
                    x = torch.sqrt((1-(sigma_diff/sigma_max)**2))*x + sigma_diff*torch.randn_like(x)

                # include this? guessing no.
                #old_denoised = denoised
                #h_last = h

        # Multistep history is updated once per time step, not per repeat.
        old_denoised = denoised
        h_last = h
        old_x = x
    return x
219
+
220
+
221
+
222
+
223
+ # -----from stable-audio-tools
224
+
225
+ # Define the noise schedule and sampling loop
226
def get_alphas_sigmas(t):
    """Returns the scaling factors for the clean image (alpha) and for the
    noise (sigma), given a timestep.

    Args:
        t: Tensor of timesteps.

    Returns:
        Tuple ``(alpha, sigma) = (cos(t * pi / 2), sin(t * pi / 2))``.
    """
    # Local import: `math` is not among the file-level imports visible in this
    # module, so the bare reference raised NameError.
    import math

    angle = t * math.pi / 2
    return torch.cos(angle), torch.sin(angle)
230
+
231
def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean image and for
    the noise.

    Args:
        alpha: Tensor of clean-image scaling factors.
        sigma: Tensor of noise scaling factors.

    Returns:
        ``atan2(sigma, alpha) / pi * 2``.
    """
    # Local import: `math` is not among the file-level imports visible in this
    # module, so the bare reference raised NameError.
    import math

    return torch.atan2(sigma, alpha) / math.pi * 2
235
+
236
def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean image and for the noise, given
    a timestep.

    Args:
        t: Tensor of timesteps.

    Returns:
        Tuple ``(alpha, sigma) = (cos(t * pi / 2), sin(t * pi / 2))``.
    """
    # Local import: `math` is not among the file-level imports visible in this
    # module, so the bare reference raised NameError.
    import math

    angle = t * math.pi / 2
    return torch.cos(angle), torch.sin(angle)
240
+
241
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
    """Draw samples from a v-diffusion model, starting from the noise x.

    Returns the denoised prediction computed on the last step.
    """
    batch_ones = x.new_ones([x.shape[0]])

    # Noise schedule: `steps` timesteps running from 1 towards 0 (0 excluded).
    timesteps = torch.linspace(1, 0, steps + 1)[:-1]
    alphas, sigmas = get_alphas_sigmas(timesteps)

    final_step = steps - 1
    for step in trange(steps):
        alpha_now, sigma_now = alphas[step], sigmas[step]

        # Model output is v, the predicted velocity.
        with torch.cuda.amp.autocast():
            v = model(x, batch_ones * timesteps[step], **extra_args).float()

        # Split v into the denoised-image prediction and the noise prediction.
        pred = x * alpha_now - v * sigma_now
        eps = x * sigma_now + v * alpha_now

        # Nothing more to do on the final timestep; pred is the output.
        if step >= final_step:
            continue

        alpha_next, sigma_next = alphas[step + 1], sigmas[step + 1]

        # With eta > 0 the predicted-noise weight is reduced by the amount of
        # fresh noise that is about to be added.
        ddim_sigma = eta * (sigma_next**2 / sigma_now**2).sqrt() * \
            (1 - alpha_now**2 / alpha_next**2).sqrt()
        adjusted_sigma = (sigma_next**2 - ddim_sigma**2).sqrt()

        # Recombine prediction and noise in the proportions of the next step.
        x = pred * alpha_next + eps * adjusted_sigma

        # Add the matching amount of fresh noise.
        if eta:
            x += torch.randn_like(x) * ddim_sigma

    return pred
281
+
282
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
283
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
284
def get_bmask(i, steps, mask):
    """Binarise a soft mask for step i: 1 where mask <= (i + 1) / steps, else 0."""
    threshold = (i + 1) / steps
    return torch.where(mask <= threshold, 1, 0)
289
+
290
def make_cond_model_fn(model, cond_fn):
    """Wrap model so its output is shifted by cond_fn's result scaled by sigma**2.

    The returned callable has the signature ``(x, sigma, **kwargs)``.
    """
    def cond_model_fn(x, sigma, **kwargs):
        # Gradients are enabled locally so cond_fn can differentiate w.r.t. x.
        with torch.enable_grad():
            x_tracked = x.detach().requires_grad_()
            denoised = model(x_tracked, sigma, **kwargs)
            guidance = cond_fn(x_tracked, sigma, denoised=denoised, **kwargs).detach()
            sigma_sq = K.utils.append_dims(sigma**2, x_tracked.ndim)
            return denoised.detach() + guidance * sigma_sq
    return cond_model_fn
299
+
300
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
301
+ # init_data is init_audio as latents (if this is latent diffusion)
302
+ # For sampling, set both init_data and mask to None
303
+ # For variations, set init_data
304
+ # For inpainting, set both init_data & mask
305
def sample_k(
        model_fn,
        noise,
        init_data=None,
        mask=None,
        steps=100,
        sampler_type="dpmpp-2m-sde",
        sigma_min=0.5,
        sigma_max=50,
        rho=1.0, device="cuda",
        callback=None,
        cond_fn=None,
        model_config=None,
        repaint=1,
        **extra_args
    ):
    """Run one of the k-diffusion (or local RePaint) samplers.

    Modes, chosen by which of ``init_data`` / ``mask`` are given:
      * neither (or a mask without init_data): plain sampling from ``noise``;
      * ``init_data`` only: variation, starting from ``init_data + noise``;
      * both: inpainting, where a callback re-imposes a freshly noised copy of
        ``init_data`` wherever the per-step binary mask is 1.

    Args:
        model_fn: Inner model, wrapped in ``K.Denoiser``.
        noise: Unit noise tensor; scaled here by ``sigmas[0]``.
        init_data: Optional initial data (see modes above).
        mask: Optional float mask passed to ``get_bmask``.
        steps: Number of sigma steps.
        sampler_type: Name of the sampler to run (see the table in the body).
        sigma_min: Lower end of the sigma schedule.
        sigma_max: Upper end of the sigma schedule.
        rho: Currently unused; the Karras schedule below is built with a
            hard-coded ``rho=7.``.
        device: Device for the sigma schedule.
        callback: Optional user callback, called with the sampler's info dict.
        cond_fn: Optional conditioning function (see ``make_cond_model_fn``).
        model_config: Mapping that must contain ``'sigma_data'`` (required in
            practice despite the ``None`` default).
        repaint: Repeat count passed to the ``my-*`` samplers.
        **extra_args: Forwarded to the model via the sampler's ``extra_args``.

    Returns:
        Whatever the selected sampler returns.

    Raises:
        ValueError: if ``sampler_type`` is not a known sampler name
            (previously this silently returned ``None``).
    """

    #denoiser = K.external.VDenoiser(model_fn)
    denoiser = K.Denoiser(model_fn, sigma_data=model_config['sigma_data'])

    if cond_fn is not None:
        denoiser = make_cond_model_fn(denoiser, cond_fn)

    # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
    #sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
    sigmas = K.sampling.get_sigmas_karras(steps, sigma_min, sigma_max, rho=7., device=device)
    print("sigmas[0] = ", sigmas[0])
    # Scale the initial noise by sigma
    noise = noise * sigmas[0]

    wrapped_callback = callback

    if mask is None and init_data is not None:
        # VARIATION (no inpainting)
        # set the initial latent to the init_data, and noise it with initial sigma
        x = init_data + noise
    elif mask is not None and init_data is not None:
        # INPAINTING
        bmask = get_bmask(0, steps, mask)
        # initial noising
        input_noised = init_data + noise
        # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
        x = input_noised * bmask + noise * (1-bmask)

        # Inpainting callback (Note: side effects, it mutates the sampler's x in place).
        # It receives the sampler's info dict:
        #   {'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}
        def inpainting_callback(args):
            i = args["i"]
            x = args["x"]
            sigma = args["sigma"]
            #denoised = args["denoised"]
            # noise the init_data input with this step's appropriate amount of noise
            input_noised = init_data + torch.randn_like(init_data) * sigma
            # shrinking hard mask
            bmask = get_bmask(i, steps, mask)
            # mix input_noise with x, using binary mask
            new_x = input_noised * bmask + x * (1-bmask)
            # mutate x
            x[:,:,:] = new_x[:,:,:]

        # wrap together the inpainting callback and the user-submitted callback.
        if callback is None:
            wrapped_callback = inpainting_callback
        else:
            wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
    else:
        # SAMPLING
        # set the initial latent to noise
        x = noise

    print("sample_k: x.min, x.max = ", x.min(), x.max())
    print("sample_k: key, val.dtype = ", [(key, val.dtype if val is not None else val) for key, val in extra_args.items()])

    # Keyword arguments every sampler receives.
    shared = dict(disable=False, callback=wrapped_callback, extra_args=extra_args)
    # Dispatch table: sampler name -> zero-argument launcher.
    samplers = {
        "k-heun": lambda: K.sampling.sample_heun(denoiser, x, sigmas, **shared),
        "k-lms": lambda: K.sampling.sample_lms(denoiser, x, sigmas, **shared),
        "k-dpmpp-2s-ancestral": lambda: K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, **shared),
        "k-dpm-2": lambda: K.sampling.sample_dpm_2(denoiser, x, sigmas, **shared),
        "k-dpm-fast": lambda: K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, **shared),
        "k-dpm-adaptive": lambda: K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, **shared),
        "dpmpp-2m-sde": lambda: K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, **shared),
        "my-dpmpp-2m-sde": lambda: my_dpmpp_2m_sde(denoiser, x, sigmas, repaint=repaint, **shared),
        "dpmpp-3m-sde": lambda: K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, **shared),
        "my-sample-euler": lambda: my_sample_euler(denoiser, x, sigmas, repaint=repaint, **shared),
    }
    try:
        run_sampler = samplers[sampler_type]
    except KeyError:
        raise ValueError(
            f"unknown sampler_type {sampler_type!r}; expected one of {sorted(samplers)}"
        ) from None

    with torch.cuda.amp.autocast():
        return run_sampler()
399
+
400
+
401
+ ## ---- end stable-audio-tools
402
def infer_mask_from_init_img(img, mask_with='white'):
    """Given an image with mask areas marked, extract the mask itself.

    Note, this works whether image is normalized on 0..1 or -1..1, but not 0..255.

    Args:
        img: Image tensor indexed as ``img[channel, row, col]``, or a
            non-tensor image that is first converted with
            ``transforms.ToTensor``.
        mask_with: ``'white'`` marks pixels whose first three channels all
            equal 1; ``'blue'`` marks pixels whose channel 2 equals 1.

    Returns:
        Float tensor with the image's last two dimensions: 1.0 at marked
        pixels, 0.0 elsewhere.
    """
    print("Inferring mask from init_img")
    assert mask_with in ['blue','white']
    if not torch.is_tensor(img):
        # `ToTensor` alone is not a name in this module; the file imports
        # torchvision's `transforms`, so go through that.
        img = transforms.ToTensor()(img)
    # Allocate on the image's device so the boolean indexing below also works
    # for non-CPU images.
    mask = torch.zeros(img.shape[-2:], device=img.device)
    if mask_with == 'white':
        mask[(img[0,:,:]==1) & (img[1,:,:]==1) & (img[2,:,:]==1)] = 1
    elif mask_with == 'blue':
        mask[img[2,:,:]==1] = 1  # blue
    return mask*1.0
415
+
416
def grow_mask(init_mask, grow_by=2):
    """Return a copy of the mask dilated grow_by times by one pixel (4-neighbourhood).

    Only interior pixels are updated; the outermost rows and columns keep
    their original values. grow_by=0 returns an unchanged copy.
    """
    grown = init_mask.clone()
    interior = (slice(1, -1), slice(1, -1))
    for _ in range(grow_by):
        centre = grown[1:-1, 1:-1]
        above = grown[0:-2, 1:-1]
        below = grown[2:, 1:-1]
        left = grown[1:-1, 0:-2]
        right = grown[1:-1, 2:]
        # An interior pixel is set when it or any direct neighbour is non-zero.
        grown[interior] = (centre + above + below + left + right) > 0
    return grown
423
+
424
+
425
def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
    """Adds extra noise inside the mask.

    Args:
        init_image: Image tensor indexed as ``[channel, row, col]``, or a
            non-tensor image that is first converted with
            ``transforms.ToTensor``. The input tensor is not modified.
        init_mask: 2-D mask; pixels equal to 1 are the ones seeded.
        grow_by: Number of one-pixel growth passes applied to the mask first.
        seed_scale: Multiplier on the Gaussian noise added inside the mask.

    Returns:
        A new image tensor: inside the (grown) mask it holds the image minimum
        plus scaled noise, with channel 2 forced to -1.0; outside it is the
        original image.
    """
    init_mask = grow_mask(init_mask, grow_by=grow_by)  # make the mask bigger
    if not torch.is_tensor(init_image):
        # `ToTensor` alone is not a name in this module; the file imports
        # torchvision's `transforms`, so go through that.
        init_image = transforms.ToTensor()(init_image)
    init_image = init_image.clone()
    # wherever mask is 1, first set init_image to its min value
    init_image[:,init_mask == 1] = init_image.min()
    init_image = init_image + seed_scale*torch.randn_like(init_image) * (init_mask)  # add noise where mask is 1
    # wherever the mask is 1, set the blue channel to -1.0, otherwise leave it alone
    init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
    return init_image
437
+
438
+
439
def get_init_image_and_mask(args, device):
    """Load the init image named by ``args`` and build the matching inpainting mask.

    Reads ``args.init_image`` (path), ``args.seed_scale`` and
    ``args.batch_size``. Side effects: prints progress and writes
    ``init_mask_debug.png`` and ``init_image_debug.png`` to the current
    working directory.

    Args:
        args: Parsed command-line namespace (see ``main``).
        device: Device the returned tensors are moved to.

    Returns:
        ``(init_image, init_mask)``: the image, normalised to -1..1 and
        repeated ``args.batch_size`` times along a new leading dimension, and
        the float mask repeated to ``(batch_size, 3, H, W)``.
    """
    convert_tensor = transforms.ToTensor()
    init_image = Image.open(args.init_image).convert('RGB')
    init_image = convert_tensor(init_image)
    # normalize image from 0..1 to -1..1
    init_image = (2.0 * init_image) - 1.0

    init_mask = torch.ones(init_image.shape[-2:])  # ones are where stuff will change, zeros will stay the same

    # The task is currently hard-coded; the other branches are kept for experimentation.
    inpaint_task = 'infer'  # infer mask from init_image
    assert inpaint_task in ['accomp','chords','melody','nucleation','notes','continue','infer']

    if inpaint_task in ['melody','accomp']:
        init_mask[0:70,:] = 0  # zero out a melody strip of image near top
        init_mask[128+0:128+70,:] = 0  # zero out a melody strip of image along bottom row
        if inpaint_task == 'melody':
            init_mask = 1 - init_mask
    elif inpaint_task in ['notes','chords']:
        # keep chords only
        #init_mask = torch.ones_like(x)
        init_mask[0:CHORD_BORDER,:] = 0  # top row of 256x256
        init_mask[128-CHORD_BORDER:128+CHORD_BORDER,:] = 0  # middle rows of 256x256
        init_mask[-CHORD_BORDER:,:] = 0  # bottom row of 256x256
        if inpaint_task == 'chords':
            init_mask = 1 - init_mask  # inverse: generate chords given notes
    elif inpaint_task == 'continue':
        init_mask[0:128,:] = 0  # remember it's a square, so just mask out the bottom half
    elif inpaint_task == 'nucleation':
        # set mask to wherever the blue channel is > 0.0
        init_mask = (init_image[2,:,:] > 0.0)*1.0
        # zero out init mask in top and bottom borders
        init_mask[0:CHORD_BORDER,:] = 0
        init_mask[-CHORD_BORDER:,:] = 0
        init_mask[128-CHORD_BORDER:128+CHORD_BORDER,:] = 0

        # remove all blue in init_image between the borders
        init_image[2,CHORD_BORDER:128-CHORD_BORDER,:] = -1.0
        init_image[2,128+CHORD_BORDER:-CHORD_BORDER,:] = -1.0

        # grow the sides of the mask by one pixel:
        # wherever mask is zero but is bordered by a 1, set it to 1
        init_mask[1:-1,1:-1] = (init_mask[1:-1,1:-1] + init_mask[0:-2,1:-1] + init_mask[2:,1:-1] + init_mask[1:-1,0:-2] + init_mask[1:-1,2:]) > 0
        #init_mask[1:-1,1:-1] = (init_mask[1:-1,1:-1] + init_mask[0:-2,1:-1] + init_mask[2:,1:-1] + init_mask[1:-1,0:-2] + init_mask[1:-1,2:]) > 0
    elif inpaint_task == 'infer':
        init_mask = infer_mask_from_init_img(init_image, mask_with='white')

    # Also black out init_image wherever init mask is 1
    init_image[:,init_mask == 1] = init_image.min()

    if args.seed_scale > 0:  # driving nucleation
        print("Seeding nucleation, seed_scale = ", args.seed_scale)
        init_image = add_seeding(init_image, init_mask, grow_by=0, seed_scale=args.seed_scale)

    # remove any blue in middle of init image
    print("init_image.shape = ", init_image.shape)
    init_image[2,CHORD_BORDER:128-CHORD_BORDER,:] = -1.0
    init_image[2,128+CHORD_BORDER:-CHORD_BORDER,:] = -1.0

    # Debugging: output some images so we can see what's going on
    init_mask_t = init_mask.float()*255  # convert mask to 0..255 for writing as image
    init_mask_img_numpy = init_mask_t.byte().cpu().numpy()
    init_mask_debug_img = Image.fromarray(init_mask_img_numpy)
    init_mask_debug_img.save("init_mask_debug.png")
    # Clamp before the uint8 cast: after seeding, pixel values can fall outside
    # -1..1, and casting out-of-range floats to uint8 has no defined result.
    init_image_u8 = (init_image*127.5+127.5).clamp(0, 255).byte()
    # Rearrange from (C, H, W) to (H, W, C) for PIL.
    init_image_debug_img = Image.fromarray(init_image_u8.cpu().numpy().transpose(1,2,0))
    init_image_debug_img.save("init_image_debug.png")

    # reshape image and mask to be 4D tensors
    init_image = init_image.unsqueeze(0).repeat(args.batch_size, 1, 1, 1)
    init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
    return init_image.to(device), init_mask.to(device)
511
+
512
+
513
def main():
    """Command-line entry point: load a k-diffusion model and write sampled PNGs.

    Parses the arguments, builds the model from the config, loads the
    checkpoint (safetensors first, then a ``torch.load`` fallback), samples
    ``-n`` images in batches via ``sample_k`` (inpainting when
    ``--init-image`` is given) and saves them as ``{prefix}_{index:05}.png``.
    A KeyboardInterrupt during sampling ends the run quietly.
    """
    global init_image, init_mask
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('--batch-size', type=int, default=64,
                   help='the batch size')
    p.add_argument('--checkpoint', type=Path, required=True,
                   help='the checkpoint to use')
    p.add_argument('--config', type=Path,
                   help='the model config')
    p.add_argument('-n', type=int, default=64,
                   help='the number of images to sample')
    p.add_argument('--prefix', type=str, default='out',
                   help='the output prefix')
    p.add_argument('--repaint', type=int, default=1,
                   help='number of (re)paint steps')
    p.add_argument('--steps', type=int, default=50,
                   help='the number of denoising steps')
    p.add_argument('--seed-scale', type=float, default=0.0, help='strength of nucleation seeding')
    p.add_argument('--init-image', type=Path, default=None, help='the initial image')
    p.add_argument('--init-strength', type=float, default=1., help='strength of init image')
    args = p.parse_args()
    print("args =", args, flush=True)

    config = K.config.load_config(args.config if args.config else args.checkpoint)
    model_config = config['model']
    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    print('Using device:', device, flush=True)

    inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
    cse = None  # ChordSeqEncoder().eval().requires_grad_(False).to(device)  # add chord embedding-maker to main model
    if cse is not None:
        inner_model.cse = cse
    try:
        inner_model.load_state_dict(safetorch.load_file(args.checkpoint))
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit are no longer
        # swallowed here. Fall back to a regular PyTorch checkpoint.
        # NOTE(security): torch.load unpickles the file; only use checkpoints
        # from trusted sources.
        #ckpt = torch.load(args.checkpoint).to(device)
        ckpt = torch.load(args.checkpoint, map_location='cpu')
        inner_model.load_state_dict(ckpt['model'])

    accelerator.print('Parameters:', K.utils.n_params(inner_model))
    model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])

    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']

    # SHH modified
    torch.set_float32_matmul_precision('high')
    #class_cond = torch.tensor([0]).to(device)
    #num_classes = 10
    #class_cond = torch.remainder(torch.arange(0, args.n), num_classes).int().to(device)
    #extra_args = {'class_cond':class_cond}
    extra_args = {}
    init_image, init_mask = None, None
    if args.init_image is not None:
        init_image, init_mask = get_init_image_and_mask(args, device)
        init_image = init_image.to(device)
        init_mask = init_mask.to(device)

    @torch.no_grad()
    @K.utils.eval_mode(model)
    def run():
        global init_image, init_mask
        if accelerator.is_local_main_process:
            tqdm.write('Sampling...')
        # Only used by the commented-out direct sampler calls below; sample_k
        # builds its own schedule.
        sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)

        #ddpm_sampler = DDPM(model)
        #model_fn = model
        #ddpm_sampler = K.external.VDenoiser(model_fn)

        def sample_fn(n, debug=True):
            """Produce one batch of n samples."""
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
            print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())

            if args.init_image is not None:
                init_data, mask = get_init_image_and_mask(args, device)
                # get_init_image_and_mask always repeats to args.batch_size,
                # but this call may be for fewer samples (e.g. when -n is not
                # a multiple of --batch-size). All batch items are copies, so
                # trimming to n avoids a shape mismatch with x and noise.
                init_data, mask = init_data[:n], mask[:n]
                init_data = args.seed_scale*x*mask + (1-mask)*init_data  # extra nucleation?
                if cse is not None:
                    chord_cond = img_batch_to_seq_emb(init_data, inner_model.cse).to(device)
                else:
                    chord_cond = None
                #print("init_data.shape, init_data.min, init_data.max = ", init_data.shape, init_data.min(), init_data.max())
            else:
                init_data, mask, chord_cond = None, None, None

            print("chord_cond = ", chord_cond)
            extra_args['chord_cond'] = chord_cond
            # these two work:
            #x_0 = K.sampling.sample_lms(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)
            #x_0 = K.sampling.sample_dpmpp_2m_sde(model, x, sigmas, disable=not accelerator.is_local_main_process, extra_args=extra_args)

            noise = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device)

            sampler_type = "my-dpmpp-2m-sde"  # "k-lms"
            #sampler_type="my-sample-euler"
            #sampler_type="dpmpp-2m-sde"
            #sampler_type = "dpmpp-3m-sde"
            #sampler_type = "k-dpmpp-2s-ancestral"
            print("dtypes:", [x.dtype if x is not None else None for x in [noise, init_data, mask, chord_cond]])
            x_0 = sample_k(inner_model, noise, sampler_type=sampler_type,
                           init_data=init_data, mask=mask, steps=args.steps,
                           sigma_min=sigma_min, sigma_max=sigma_max, rho=7.,
                           device=device, model_config=model_config, repaint=args.repaint,
                           **extra_args)
            #x_0 = sample_k(inner_model, noise, sampler_type="dpmpp-2m-sde", steps=100, sigma_min=0.5, sigma_max=50, rho=1., device=device, model_config=model_config, **extra_args)
            print("x_0.min, x_0.max = ", x_0.min(), x_0.max())
            if x_0.isnan().any():
                # Explicit raise (same exception type as the former
                # `assert False`) so the check also runs under `python -O`.
                raise AssertionError("x_0 has NaNs")

            # do gpu garbage collection before proceeding
            torch.cuda.empty_cache()
            return x_0

        x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
        if accelerator.is_main_process:
            for i, out in enumerate(x_0):
                filename = f'{args.prefix}_{i:05}.png'
                K.utils.to_pil_image(out).save(filename)

    try:
        run()
    except KeyboardInterrupt:
        pass
642
+
643
+
644
# Run the sampler only when this file is executed as a script.
if __name__ == "__main__":
    main()