liray commited on
Commit
ded59bb
1 Parent(s): 91159ab
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -92,13 +92,11 @@ def model_init():
92
  load=False,
93
  )
94
  model.load_state_dict(torch.load(model_checkpoint, map_location="cpu")["model"])
95
- model = model.to("cuda")
96
- return model
97
 
98
- @spaces.GPU
99
  def sam_segment(predictor, input_image, drags, foreground_points=None):
100
  image = np.asarray(input_image)
101
- predictor = predictor.to("cuda")
102
  predictor.set_image(image)
103
 
104
  with torch.no_grad():
@@ -173,7 +171,7 @@ def preprocess_image(SAM_predictor, img, chk_group, drags):
173
  processed_img = image_pil.resize((256, 256), Image.LANCZOS)
174
  return processed_img, new_drags
175
 
176
- @spaces.GPU
177
  def single_image_sample(
178
  model,
179
  diffusion,
@@ -188,8 +186,6 @@ def single_image_sample(
188
  vae=None,
189
  ):
190
  z = torch.randn(2, 4, 32, 32).to("cuda")
191
- if vae is not None:
192
- vae = vae.to("cuda")
193
 
194
  # Prepare input for classifer-free guidance
195
  rel = torch.cat([rel, rel], dim=0).to("cuda")
@@ -233,9 +229,8 @@ def single_image_sample(
233
  images = samples
234
  return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
235
 
236
- return samples
237
 
238
- @spaces.GPU
239
  def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
240
  if img_cond is None:
241
  gr.Warning("Please preprocess the image first.")
 
92
  load=False,
93
  )
94
  model.load_state_dict(torch.load(model_checkpoint, map_location="cpu")["model"])
95
+ return model.to("cuda")
 
96
 
97
+ @spaces.GPU(duration=10)
98
  def sam_segment(predictor, input_image, drags, foreground_points=None):
99
  image = np.asarray(input_image)
 
100
  predictor.set_image(image)
101
 
102
  with torch.no_grad():
 
171
  processed_img = image_pil.resize((256, 256), Image.LANCZOS)
172
  return processed_img, new_drags
173
 
174
+
175
  def single_image_sample(
176
  model,
177
  diffusion,
 
186
  vae=None,
187
  ):
188
  z = torch.randn(2, 4, 32, 32).to("cuda")
 
 
189
 
190
  # Prepare input for classifer-free guidance
191
  rel = torch.cat([rel, rel], dim=0).to("cuda")
 
229
  images = samples
230
  return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
231
 
 
232
 
233
+ @spaces.GPU(duration=20)
234
  def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
235
  if img_cond is None:
236
  gr.Warning("Please preprocess the image first.")