Andranik Sargsyan commited on
Commit
da1e12f
Β·
1 Parent(s): bfd34e9

enable fp16, move SR to cuda:1

Browse files
app.py CHANGED
@@ -64,8 +64,8 @@ inpainting_models = OrderedDict([
64
  ("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
65
  ("Stable-Inpainting 1.5", models.sd15_inp.load_model())
66
  ])
67
- sr_model = models.sd2_sr.load_model()
68
- sam_predictor = models.sam.load_model()
69
 
70
  inp_model = None
71
  cached_inp_model_name = ''
 
64
  ("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
65
  ("Stable-Inpainting 1.5", models.sd15_inp.load_model())
66
  ])
67
+ sr_model = models.sd2_sr.load_model(device='cuda:1')
68
+ sam_predictor = models.sam.load_model(device='cuda:0')
69
 
70
  inp_model = None
71
  cached_inp_model_name = ''
lib/methods/rasg.py CHANGED
@@ -38,9 +38,11 @@ def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, p
38
  unet_condition = ddim.get_inpainting_condition(image, mask)
39
  share.set_mask(mask)
40
 
 
 
41
  # Starting latent
42
  seed_everything(seed)
43
- zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
44
 
45
  # Setup unet for guidance
46
  ddim.unet.requires_grad_(True)
@@ -58,11 +60,12 @@ def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, p
58
 
59
  # Run the model
60
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
61
- eps_uncond, eps = ddim.unet(
62
- torch.cat([_zt, _zt]),
63
- timesteps = torch.tensor([timestep, timestep]).cuda(),
64
- context = context
65
- ).detach().chunk(2)
 
66
 
67
  # Unconditional guidance
68
  eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
 
38
  unet_condition = ddim.get_inpainting_condition(image, mask)
39
  share.set_mask(mask)
40
 
41
+ dtype = unet_condition.dtype
42
+
43
  # Starting latent
44
  seed_everything(seed)
45
+ zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda().to(dtype)
46
 
47
  # Setup unet for guidance
48
  ddim.unet.requires_grad_(True)
 
60
 
61
  # Run the model
62
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
63
+ with torch.autocast('cuda'):
64
+ eps_uncond, eps = ddim.unet(
65
+ torch.cat([_zt, _zt]).to(dtype),
66
+ timesteps = torch.tensor([timestep, timestep]).cuda(),
67
+ context = context
68
+ ).detach().chunk(2)
69
 
70
  # Unconditional guidance
71
  eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
lib/methods/sd.py CHANGED
@@ -43,11 +43,12 @@ def run(
43
 
44
  # Image condition
45
  unet_condition = ddim.get_inpainting_condition(image, mask)
 
46
  share.set_mask(mask)
47
 
48
  # Starting latent
49
  seed_everything(seed)
50
- zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
51
 
52
  # Turn off gradients
53
  ddim.unet.requires_grad_(False)
@@ -58,11 +59,12 @@ def run(
58
  if share.timestep <= 500: router.reset()
59
 
60
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
61
- eps_uncond, eps = ddim.unet(
62
- torch.cat([_zt, _zt]),
63
- timesteps = torch.tensor([timestep, timestep]).cuda(),
64
- context = context
65
- ).chunk(2)
 
66
 
67
  eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
68
  z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
 
43
 
44
  # Image condition
45
  unet_condition = ddim.get_inpainting_condition(image, mask)
46
+ dtype = unet_condition.dtype
47
  share.set_mask(mask)
48
 
49
  # Starting latent
50
  seed_everything(seed)
51
+ zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda().to(dtype)
52
 
53
  # Turn off gradients
54
  ddim.unet.requires_grad_(False)
 
59
  if share.timestep <= 500: router.reset()
60
 
61
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
62
+ with torch.autocast('cuda'):
63
+ eps_uncond, eps = ddim.unet(
64
+ torch.cat([_zt, _zt]).to(dtype),
65
+ timesteps = torch.tensor([timestep, timestep]).cuda(),
66
+ context = context
67
+ ).chunk(2)
68
 
69
  eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
70
  z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
lib/methods/sr.py CHANGED
@@ -59,8 +59,10 @@ def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
59
 
60
  def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
61
  blend_output = True, blend_trick = True, no_superres = False,
62
- dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False, dtype=torch.bfloat16):
63
  torch.manual_seed(seed)
 
 
64
 
65
  router.attention_forward = attentionpatch.default.forward_xformers
66
  router.basic_transformer_forward = transformerpatch.default.forward
@@ -74,7 +76,7 @@ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
74
  hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
75
  hr_mask_orig = hr_mask
76
  lr_image = lr_image.padx(64, padding_mode='reflect')
77
- lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).cuda()
78
  lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
79
 
80
  if no_superres:
@@ -89,18 +91,18 @@ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
89
 
90
  # encode hr image
91
  with torch.no_grad():
92
- hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype)).mean * ddim.config.scale_factor
93
 
94
  assert hr_z0.shape[2] == lr_image.torch().shape[2]
95
  assert hr_z0.shape[3] == lr_image.torch().shape[3]
96
 
97
- unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype)
98
- zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype)
99
 
100
  with torch.no_grad():
101
  context = ddim.encoder.encode([negative_prompt, prompt])
102
 
103
- noise_level = torch.Tensor(1 * [noise_level]).to('cuda').long()
104
  unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
105
 
106
  with torch.autocast('cuda'), torch.no_grad():
@@ -110,13 +112,13 @@ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
110
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
111
 
112
  eps_uncond, eps = ddim.unet(
113
- torch.cat([_zt, _zt]).to(dtype),
114
- timesteps = torch.tensor([t, t]).cuda(),
115
  context = context,
116
  y=torch.cat([noise_level]*2)
117
  ).chunk(2)
118
 
119
- ts = torch.full((zt.shape[0],), t, device='cuda', dtype=torch.long)
120
  model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
121
  eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
122
  z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
 
59
 
60
  def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
61
  blend_output = True, blend_trick = True, no_superres = False,
62
+ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False):
63
  torch.manual_seed(seed)
64
+ dtype = ddim.vae.encoder.conv_in.weight.dtype
65
+ device = ddim.vae.encoder.conv_in.weight.device
66
 
67
  router.attention_forward = attentionpatch.default.forward_xformers
68
  router.basic_transformer_forward = transformerpatch.default.forward
 
76
  hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
77
  hr_mask_orig = hr_mask
78
  lr_image = lr_image.padx(64, padding_mode='reflect')
79
+ lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).to(device)
80
  lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
81
 
82
  if no_superres:
 
91
 
92
  # encode hr image
93
  with torch.no_grad():
94
+ hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype=dtype, device=device)).mean * ddim.config.scale_factor
95
 
96
  assert hr_z0.shape[2] == lr_image.torch().shape[2]
97
  assert hr_z0.shape[3] == lr_image.torch().shape[3]
98
 
99
+ unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
100
+ zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype=dtype, device=device)
101
 
102
  with torch.no_grad():
103
  context = ddim.encoder.encode([negative_prompt, prompt])
104
 
105
+ noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
106
  unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
107
 
108
  with torch.autocast('cuda'), torch.no_grad():
 
112
  _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
113
 
114
  eps_uncond, eps = ddim.unet(
115
+ torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
116
+ timesteps = torch.tensor([t, t]).to(device=device),
117
  context = context,
118
  y=torch.cat([noise_level]*2)
119
  ).chunk(2)
120
 
121
+ ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
122
  model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
123
  eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
124
  z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
lib/models/ds_inp.py CHANGED
@@ -15,7 +15,7 @@ DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'
15
  download_file(DOWNLOAD_URL, MODEL_PATH)
16
 
17
 
18
- def load_model():
19
  print ("Loading model: Dreamshaper Inpainting V8")
20
 
21
  download_file(DOWNLOAD_URL, MODEL_PATH)
@@ -36,10 +36,15 @@ def load_model():
36
  encoder.load_state_dict(encoder_state)
37
  vae.load_state_dict(vae_state)
38
 
 
 
 
 
 
39
  unet = unet.requires_grad_(False)
40
  encoder = encoder.requires_grad_(False)
41
  vae = vae.requires_grad_(False)
42
-
43
  ddim = DDIM(config, vae, encoder, unet)
44
  share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
45
 
 
15
  download_file(DOWNLOAD_URL, MODEL_PATH)
16
 
17
 
18
+ def load_model(dtype=torch.float16):
19
  print ("Loading model: Dreamshaper Inpainting V8")
20
 
21
  download_file(DOWNLOAD_URL, MODEL_PATH)
 
36
  encoder.load_state_dict(encoder_state)
37
  vae.load_state_dict(vae_state)
38
 
39
+ if dtype == torch.float16:
40
+ unet.convert_to_fp16()
41
+ vae.to(dtype)
42
+ encoder.to(dtype)
43
+
44
  unet = unet.requires_grad_(False)
45
  encoder = encoder.requires_grad_(False)
46
  vae = vae.requires_grad_(False)
47
+
48
  ddim = DDIM(config, vae, encoder, unet)
49
  share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
50
 
lib/models/sam.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from segment_anything import sam_model_registry, SamPredictor
2
  from .common import *
3
 
@@ -8,12 +9,10 @@ DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939
8
  download_file(DOWNLOAD_URL, MODEL_PATH)
9
 
10
 
11
- def load_model():
12
  print ("Loading model: SAM")
13
  download_file(DOWNLOAD_URL, MODEL_PATH)
14
- model_type = "vit_h"
15
- device = "cuda"
16
- sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
17
  sam.to(device=device)
18
  sam_predictor = SamPredictor(sam)
19
  print ("SAM loaded")
 
1
+ import torch
2
  from segment_anything import sam_model_registry, SamPredictor
3
  from .common import *
4
 
 
9
  download_file(DOWNLOAD_URL, MODEL_PATH)
10
 
11
 
12
+ def load_model(device='cuda:0'):
13
  print ("Loading model: SAM")
14
  download_file(DOWNLOAD_URL, MODEL_PATH)
15
+ sam = sam_model_registry["vit_h"](checkpoint=MODEL_PATH)
 
 
16
  sam.to(device=device)
17
  sam_predictor = SamPredictor(sam)
18
  print ("SAM loaded")
lib/models/sd15_inp.py CHANGED
@@ -12,7 +12,7 @@ MODEL_PATH = f'{MODEL_FOLDER}/sd-1-5-inpainting/sd-v1-5-inpainting.ckpt'
12
  download_file(DOWNLOAD_URL, MODEL_PATH)
13
 
14
 
15
- def load_model():
16
  download_file(DOWNLOAD_URL, MODEL_PATH)
17
 
18
  state_dict = torch.load(MODEL_PATH)['state_dict']
@@ -34,6 +34,11 @@ def load_model():
34
  encoder.load_state_dict(encoder_state)
35
  vae.load_state_dict(vae_state)
36
 
 
 
 
 
 
37
  unet = unet.requires_grad_(False)
38
  encoder = encoder.requires_grad_(False)
39
  vae = vae.requires_grad_(False)
 
12
  download_file(DOWNLOAD_URL, MODEL_PATH)
13
 
14
 
15
+ def load_model(dtype=torch.float16):
16
  download_file(DOWNLOAD_URL, MODEL_PATH)
17
 
18
  state_dict = torch.load(MODEL_PATH)['state_dict']
 
34
  encoder.load_state_dict(encoder_state)
35
  vae.load_state_dict(vae_state)
36
 
37
+ if dtype == torch.float16:
38
+ unet.convert_to_fp16()
39
+ vae.to(dtype)
40
+ encoder.to(dtype)
41
+
42
  unet = unet.requires_grad_(False)
43
  encoder = encoder.requires_grad_(False)
44
  vae = vae.requires_grad_(False)
lib/models/sd2_inp.py CHANGED
@@ -13,7 +13,7 @@ DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting
13
  download_file(DOWNLOAD_URL, MODEL_PATH)
14
 
15
 
16
- def load_model():
17
  print ("Loading model: Stable-Inpainting 2.0")
18
 
19
  download_file(DOWNLOAD_URL, MODEL_PATH)
@@ -36,6 +36,13 @@ def load_model():
36
  encoder.load_state_dict(encoder_state)
37
  vae.load_state_dict(vae_state)
38
 
 
 
 
 
 
 
 
39
  unet = unet.requires_grad_(False)
40
  encoder = encoder.requires_grad_(False)
41
  vae = vae.requires_grad_(False)
 
13
  download_file(DOWNLOAD_URL, MODEL_PATH)
14
 
15
 
16
+ def load_model(dtype=torch.float16, device='cuda:0'):
17
  print ("Loading model: Stable-Inpainting 2.0")
18
 
19
  download_file(DOWNLOAD_URL, MODEL_PATH)
 
36
  encoder.load_state_dict(encoder_state)
37
  vae.load_state_dict(vae_state)
38
 
39
+ if dtype == torch.float16:
40
+ unet.convert_to_fp16()
41
+ unet.to(device=device)
42
+ vae.to(dtype=dtype, device=device)
43
+ encoder.to(dtype=dtype, device=device)
44
+ encoder.device = device
45
+
46
  unet = unet.requires_grad_(False)
47
  encoder = encoder.requires_grad_(False)
48
  vae = vae.requires_grad_(False)
lib/models/sd2_sr.py CHANGED
@@ -39,15 +39,15 @@ def extract_into_tensor(a, t, x_shape):
39
 
40
  def predict_eps_from_z_and_v(schedule, x_t, t, v):
41
  return (
42
- extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape) * v +
43
- extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape) * x_t
44
  )
45
 
46
 
47
  def predict_start_from_z_and_v(schedule, x_t, t, v):
48
  return (
49
- extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape) * x_t -
50
- extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape) * v
51
  )
52
 
53
 
@@ -153,7 +153,7 @@ def load_obj(path):
153
  return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
154
 
155
 
156
- def load_model(dtype=torch.bfloat16):
157
  print ("Loading model: SD2 superresolution...")
158
 
159
  download_file(DOWNLOAD_URL, MODEL_PATH)
@@ -180,9 +180,10 @@ def load_model(dtype=torch.bfloat16):
180
  encoder = encoder.requires_grad_(False)
181
  vae = vae.requires_grad_(False)
182
 
183
- unet.to(dtype)
184
- vae.to(dtype)
185
- encoder.to(dtype)
 
186
 
187
  ddim = DDIM(config, vae, encoder, unet)
188
 
@@ -199,6 +200,8 @@ def load_model(dtype=torch.bfloat16):
199
  for param in low_scale_model.parameters():
200
  param.requires_grad = False
201
 
 
 
202
  ddim.low_scale_model = low_scale_model
203
  print('SD2 superresolution loaded')
204
  return ddim
 
39
 
40
  def predict_eps_from_z_and_v(schedule, x_t, t, v):
41
  return (
42
+ extract_into_tensor(schedule.sqrt_alphas.to(x_t.device), t, x_t.shape) * v +
43
+ extract_into_tensor(schedule.sqrt_one_minus_alphas.to(x_t.device), t, x_t.shape) * x_t
44
  )
45
 
46
 
47
  def predict_start_from_z_and_v(schedule, x_t, t, v):
48
  return (
49
+ extract_into_tensor(schedule.sqrt_alphas.to(x_t.device), t, x_t.shape) * x_t -
50
+ extract_into_tensor(schedule.sqrt_one_minus_alphas.to(x_t.device), t, x_t.shape) * v
51
  )
52
 
53
 
 
153
  return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
154
 
155
 
156
+ def load_model(dtype=torch.bfloat16, device='cuda:0'):
157
  print ("Loading model: SD2 superresolution...")
158
 
159
  download_file(DOWNLOAD_URL, MODEL_PATH)
 
180
  encoder = encoder.requires_grad_(False)
181
  vae = vae.requires_grad_(False)
182
 
183
+ unet.to(dtype=dtype, device=device)
184
+ vae.to(dtype=dtype, device=device)
185
+ encoder.to(dtype=dtype, device=device)
186
+ encoder.device = device
187
 
188
  ddim = DDIM(config, vae, encoder, unet)
189
 
 
200
  for param in low_scale_model.parameters():
201
  param.requires_grad = False
202
 
203
+ low_scale_model = low_scale_model.to(dtype=dtype, device=device)
204
+
205
  ddim.low_scale_model = low_scale_model
206
  print('SD2 superresolution loaded')
207
  return ddim
lib/smplfusion/ddim.py CHANGED
@@ -43,11 +43,13 @@ class DDIM:
43
 
44
  def get_inpainting_condition(self, image, mask):
45
  latent_size = [x//8 for x in image.size]
 
46
  with torch.no_grad():
47
- condition_x0 = self.vae.encode(image.torch().cuda() * ~mask.torch(0).bool().cuda()).mean * self.config.scale_factor
48
-
49
- condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().float()
50
 
 
51
  condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask)
52
  return torch.cat([condition_mask, condition_x0], 1)
53
 
 
43
 
44
  def get_inpainting_condition(self, image, mask):
45
  latent_size = [x//8 for x in image.size]
46
+ dtype = self.vae.encoder.conv_in.weight.dtype
47
  with torch.no_grad():
48
+ masked_image = image.torch().cuda() * ~mask.torch(0).bool().cuda()
49
+ masked_image = masked_image.to(dtype)
50
+ condition_x0 = self.vae.encode(masked_image).mean * self.config.scale_factor
51
 
52
+ condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().to(dtype)
53
  condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask)
54
  return torch.cat([condition_mask, condition_x0], 1)
55
 
lib/smplfusion/models/unet.py CHANGED
@@ -14,7 +14,14 @@ from ..modules.attention.spatial_transformer import SpatialTransformer
14
 
15
 
16
  # dummy replace
17
- def convert_module_to_f16(x): pass
 
 
 
 
 
 
 
18
  def convert_module_to_f32(x): pass
19
 
20
 
 
14
 
15
 
16
  # dummy replace
17
+ def convert_module_to_f16(param):
18
+ """
19
+ Convert primitive modules to float16.
20
+ """
21
+ if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
22
+ param.weight.data = param.weight.data.half()
23
+ if param.bias is not None:
24
+ param.bias.data = param.bias.data.half()
25
  def convert_module_to_f32(x): pass
26
 
27
 
lib/utils/iimage.py CHANGED
@@ -59,6 +59,10 @@ class IImage:
59
  data = self.data.transpose(0, 3, 1, 2) / 255.
60
  return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
61
 
 
 
 
 
62
  def cuda(self):
63
  self.device = 'cuda'
64
  return self
 
59
  data = self.data.transpose(0, 3, 1, 2) / 255.
60
  return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
61
 
62
+ def to(self, device):
63
+ self.device = device
64
+ return self
65
+
66
  def cuda(self):
67
  self.device = 'cuda'
68
  return self