LiruiZhao committed on
Commit
b7b1d93
1 Parent(s): 5798e9a
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -75,15 +75,20 @@ def append_dims(x, target_dims):
75
 
76
  class CompVisDenoiser(K.external.CompVisDenoiser):
77
  def __init__(self, model, quantize=False, device='cpu'):
78
- super().__init__( model, quantize, device)
79
 
80
  def get_eps(self, *args, **kwargs):
81
  return self.inner_model.apply_model(*args, **kwargs)
82
 
83
  def forward(self, input_0, input_1, sigma, **kwargs):
 
 
84
  c_out, c_in = [append_dims(x, input_0.ndim) for x in self.get_scalings(sigma)]
 
 
 
85
  # eps_0, eps_1 = self.get_eps(input_0 * c_in, input_1 * c_in, self.sigma_to_t(sigma), **kwargs)
86
- eps_0, eps_1 = self.get_eps(input_0 * c_in, self.sigma_to_t(sigma), **kwargs)
87
 
88
  return input_0 + eps_0 * c_out, eps_1
89
 
@@ -112,7 +117,6 @@ def decode_mask(mask, height = 256, width = 256):
112
  mask = mask.type(torch.uint8).cpu().numpy()
113
  return mask
114
 
115
- @torch.no_grad()
116
  def sample_euler_ancestral(model, x_0, x_1, sigmas, height, width, extra_args=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
117
  """Ancestral sampling with Euler method steps."""
118
  extra_args = {} if extra_args is None else extra_args
@@ -183,17 +187,24 @@ def generate(
183
 
184
  if instruction == "":
185
  return [input_image, seed]
186
-
 
 
187
  with torch.no_grad(), autocast("cuda"), model.ema_scope():
188
  cond = {}
189
- cond["c_crossattn"] = [model.get_learned_conditioning([instruction])]
190
  input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
191
  input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
192
- cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
193
 
194
  uncond = {}
195
- uncond["c_crossattn"] = [null_token]
196
  uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
 
 
 
 
 
197
 
198
  sigmas = model_wrap.get_sigmas(steps)
199
 
@@ -204,8 +215,10 @@ def generate(
204
  "image_cfg_scale": image_cfg_scale,
205
  }
206
  torch.manual_seed(seed)
207
- z_0 = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
208
- z_1 = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
 
 
209
 
210
  z_0, z_1, image_list, mask_list = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
211
 
 
75
 
76
  class CompVisDenoiser(K.external.CompVisDenoiser):
77
  def __init__(self, model, quantize=False, device='cpu'):
78
+ super().__init__(model, quantize, device)
79
 
80
  def get_eps(self, *args, **kwargs):
81
  return self.inner_model.apply_model(*args, **kwargs)
82
 
83
  def forward(self, input_0, input_1, sigma, **kwargs):
84
+ print("input_0.device:", input_0.device)
85
+ print("input_1.device:", input_1.device)
86
  c_out, c_in = [append_dims(x, input_0.ndim) for x in self.get_scalings(sigma)]
87
+ print("c_in.device:", c_in.device)
88
+ print("c_out.device:", c_out.device)
89
+ print("sigma.device:", sigma.device)
90
  # eps_0, eps_1 = self.get_eps(input_0 * c_in, input_1 * c_in, self.sigma_to_t(sigma), **kwargs)
91
+ eps_0, eps_1 = self.get_eps(input_0 * c_in, self.sigma_to_t(sigma.cpu()).cuda(), **kwargs)
92
 
93
  return input_0 + eps_0 * c_out, eps_1
94
 
 
117
  mask = mask.type(torch.uint8).cpu().numpy()
118
  return mask
119
 
 
120
  def sample_euler_ancestral(model, x_0, x_1, sigmas, height, width, extra_args=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
121
  """Ancestral sampling with Euler method steps."""
122
  extra_args = {} if extra_args is None else extra_args
 
187
 
188
  if instruction == "":
189
  return [input_image, seed]
190
+
191
+ model.cuda()
192
+ print("model.device:", model.device)
193
  with torch.no_grad(), autocast("cuda"), model.ema_scope():
194
  cond = {}
195
+ cond["c_crossattn"] = [model.get_learned_conditioning([instruction]).to(model.device)]
196
  input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
197
  input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
198
+ cond["c_concat"] = [model.encode_first_stage(input_image).mode().to(model.device)]
199
 
200
  uncond = {}
201
+ uncond["c_crossattn"] = [null_token.to(model.device)]
202
  uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
203
+
204
+ print("cond['c_crossattn'][0].device:", cond["c_crossattn"][0].device)
205
+ print("cond['c_concat'][0].device:", cond["c_concat"][0].device)
206
+ print("uncond['c_crossattn'][0].device:", uncond["c_crossattn"][0].device)
207
+ print("uncond['c_concat'][0].device:", uncond["c_concat"][0].device)
208
 
209
  sigmas = model_wrap.get_sigmas(steps)
210
 
 
215
  "image_cfg_scale": image_cfg_scale,
216
  }
217
  torch.manual_seed(seed)
218
+ z_0 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
219
+ z_1 = torch.randn_like(cond["c_concat"][0]).to(model.device) * sigmas[0]
220
+ print("z_0.device:", z_0.device)
221
+ print("z_1.device:", z_1.device)
222
 
223
  z_0, z_1, image_list, mask_list = sample_euler_ancestral(model_wrap_cfg, z_0, z_1, sigmas, height, width, extra_args=extra_args)
224