drscotthawley commited on
Commit
6873531
1 Parent(s): 500319a

adding zerogpu decorators

Browse files
Files changed (1) hide show
  1. sample.py +17 -3
sample.py CHANGED
@@ -1,9 +1,13 @@
1
  #!/usr/bin/env python3
2
 
3
  # Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
 
4
 
5
  """Samples from k-diffusion models."""
6
 
 
 
 
7
  import argparse
8
  from pathlib import Path
9
 
@@ -24,6 +28,7 @@ from pom.chords import CHORD_BORDER, img_batch_to_seq_emb, ChordSeqEncoder
24
  # ---- my mangled sampler that includes repaint
25
  import torchsde
26
 
 
27
  class BatchedBrownianTree:
28
  """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
29
 
@@ -51,6 +56,7 @@ class BatchedBrownianTree:
51
  return w if self.batched else w[0]
52
 
53
 
 
54
  class BrownianTreeNoiseSampler:
55
  """A noise sampler backed by a torchsde.BrownianTree.
56
 
@@ -88,6 +94,7 @@ def to_d(x, sigma, denoised):
88
  return (x - denoised) / append_dims(sigma, x.ndim)
89
 
90
 
 
91
  @torch.no_grad()
92
  def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
93
  """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@@ -122,6 +129,7 @@ def get_scalings(sigma, sigma_data=0.5):
122
  return c_skip, c_out, c_in
123
 
124
 
 
125
  @torch.no_grad()
126
  def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
127
  disable=None, eta=1., s_noise=1., noise_sampler=None,
@@ -281,12 +289,14 @@ def sample(model, x, steps, eta, **extra_args):
281
 
282
  # Soft mask inpainting is just shrinking hard (binary) mask inpainting
283
  # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
 
284
  def get_bmask(i, steps, mask):
285
  strength = (i+1)/(steps)
286
  # convert to binary mask
287
  bmask = torch.where(mask<=strength,1,0)
288
  return bmask
289
 
 
290
  def make_cond_model_fn(model, cond_fn):
291
  def cond_model_fn(x, sigma, **kwargs):
292
  with torch.enable_grad():
@@ -302,6 +312,7 @@ def make_cond_model_fn(model, cond_fn):
302
  # For sampling, set both init_data and mask to None
303
  # For variations, set init_data
304
  # For inpainting, set both init_data & mask
 
305
  def sample_k(
306
  model_fn,
307
  noise,
@@ -399,6 +410,7 @@ def sample_k(
399
 
400
 
401
  ## ---- end stable-audio-tools
 
402
  def infer_mask_from_init_img(img, mask_with='white'):
403
  """given an image with mask areas marked, extract the mask itself
404
  note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
@@ -413,6 +425,7 @@ def infer_mask_from_init_img(img, mask_with='white'):
413
  mask[img[2,:,:]==1] = 1 # blue
414
  return mask*1.0
415
 
 
416
  def grow_mask(init_mask, grow_by=2):
417
  "adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
418
  new_mask = init_mask.clone()
@@ -421,7 +434,7 @@ def grow_mask(init_mask, grow_by=2):
421
  new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
422
  return new_mask
423
 
424
-
425
  def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
426
  "adds extra noise inside mask"
427
  init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
@@ -435,7 +448,7 @@ def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
435
  init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
436
  return init_image
437
 
438
-
439
  def get_init_image_and_mask(args, device):
440
  convert_tensor = transforms.ToTensor()
441
  init_image = Image.open(args.init_image).convert('RGB')
@@ -509,7 +522,7 @@ def get_init_image_and_mask(args, device):
509
  init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
510
  return init_image.to(device), init_mask.to(device)
511
 
512
-
513
  def main():
514
  global init_image, init_mask
515
  p = argparse.ArgumentParser(description=__doc__,
@@ -586,6 +599,7 @@ def main():
586
  #model_fn = model
587
  #ddpm_sampler = K.external.VDenoiser(model_fn)
588
 
 
589
  def sample_fn(n, debug=True):
590
  x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
591
  print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())
 
1
  #!/usr/bin/env python3
2
 
3
  # Code by Kat Crowson in k-diffusion repo, modified by Scott H Hawley (SHH)
4
+ # Modified by Scott H. Hawley for masking, ZeroGPU ets.
5
 
6
  """Samples from k-diffusion models."""
7
 
8
+ import gradio
9
+ import spaces
10
+ import natten
11
  import argparse
12
  from pathlib import Path
13
 
 
28
  # ---- my mangled sampler that includes repaint
29
  import torchsde
30
 
31
+ @spaces.GPU
32
  class BatchedBrownianTree:
33
  """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
34
 
 
56
  return w if self.batched else w[0]
57
 
58
 
59
+ @spaces.GPU
60
  class BrownianTreeNoiseSampler:
61
  """A noise sampler backed by a torchsde.BrownianTree.
62
 
 
94
  return (x - denoised) / append_dims(sigma, x.ndim)
95
 
96
 
97
+ @spaces.GPU
98
  @torch.no_grad()
99
  def my_sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., repaint=1):
100
  """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
 
129
  return c_skip, c_out, c_in
130
 
131
 
132
+ @spaces.GPU
133
  @torch.no_grad()
134
  def my_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None,
135
  disable=None, eta=1., s_noise=1., noise_sampler=None,
 
289
 
290
  # Soft mask inpainting is just shrinking hard (binary) mask inpainting
291
  # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
292
+ @spaces.GPU
293
  def get_bmask(i, steps, mask):
294
  strength = (i+1)/(steps)
295
  # convert to binary mask
296
  bmask = torch.where(mask<=strength,1,0)
297
  return bmask
298
 
299
+ @spaces.GPU
300
  def make_cond_model_fn(model, cond_fn):
301
  def cond_model_fn(x, sigma, **kwargs):
302
  with torch.enable_grad():
 
312
  # For sampling, set both init_data and mask to None
313
  # For variations, set init_data
314
  # For inpainting, set both init_data & mask
315
+ @spaces.GPU
316
  def sample_k(
317
  model_fn,
318
  noise,
 
410
 
411
 
412
  ## ---- end stable-audio-tools
413
+ @spaces.GPU
414
  def infer_mask_from_init_img(img, mask_with='white'):
415
  """given an image with mask areas marked, extract the mask itself
416
  note, this works whether image is normalized on 0..1 or -1..1, but not 0..255"""
 
425
  mask[img[2,:,:]==1] = 1 # blue
426
  return mask*1.0
427
 
428
+ @spaces.GPU
429
  def grow_mask(init_mask, grow_by=2):
430
  "adds a border of grow_by pixels to the mask, by growing it grow_by times. If grow_by=0, does nothing"
431
  new_mask = init_mask.clone()
 
434
  new_mask[1:-1,1:-1] = (new_mask[1:-1,1:-1] + new_mask[0:-2,1:-1] + new_mask[2:,1:-1] + new_mask[1:-1,0:-2] + new_mask[1:-1,2:]) > 0
435
  return new_mask
436
 
437
+ @spaces.GPU
438
  def add_seeding(init_image, init_mask, grow_by=0, seed_scale=1.0):
439
  "adds extra noise inside mask"
440
  init_mask = grow_mask(init_mask, grow_by=grow_by) # make the mask bigger
 
448
  init_image[2,:,:] = init_image[2,:,:] * (1-init_mask) - 1.0*init_mask
449
  return init_image
450
 
451
+ @spaces.GPU
452
  def get_init_image_and_mask(args, device):
453
  convert_tensor = transforms.ToTensor()
454
  init_image = Image.open(args.init_image).convert('RGB')
 
522
  init_mask = init_mask.unsqueeze(0).unsqueeze(1).repeat(args.batch_size,3,1,1).float()
523
  return init_image.to(device), init_mask.to(device)
524
 
525
+ @spaces.GPU
526
  def main():
527
  global init_image, init_mask
528
  p = argparse.ArgumentParser(description=__doc__,
 
599
  #model_fn = model
600
  #ddpm_sampler = K.external.VDenoiser(model_fn)
601
 
602
+ @spaces.GPU
603
  def sample_fn(n, debug=True):
604
  x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigma_max
605
  print("n, sigma_max, x.min, x.max = ", n, sigma_max, x.min(), x.max())