ChongMou commited on
Commit
f0ae51e
·
1 Parent(s): 3996268

remove save GPU memory

Browse files
Files changed (1) hide show
  1. demo/model.py +33 -33
demo/model.py CHANGED
@@ -175,11 +175,11 @@ class Model_all:
175
  im = im.float()
176
  im_edge = tensor2img(im)
177
 
178
- # save gpu memory
179
- self.base_model.model = self.base_model.model.cpu()
180
- self.model_sketch = self.model_sketch.cuda()
181
- self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
182
- self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
183
 
184
  # extract condition features
185
  c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
@@ -187,10 +187,10 @@ class Model_all:
187
  features_adapter = self.model_sketch(im.to(self.device))
188
  shape = [4, 64, 64]
189
 
190
- # save gpu memory
191
- self.model_sketch = self.model_sketch.cpu()
192
- self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
193
- self.base_model.model = self.base_model.model.cuda()
194
 
195
  # sampling
196
  samples_ddim, _ = self.sampler.sample(S=50,
@@ -205,8 +205,8 @@ class Model_all:
205
  features_adapter1=features_adapter,
206
  mode = 'sketch',
207
  con_strength = con_strength)
208
- # save gpu memory
209
- self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
210
 
211
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
212
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -246,11 +246,11 @@ class Model_all:
246
  im = im>0.5
247
  im = im.float()
248
 
249
- # save gpu memory
250
- self.base_model.model = self.base_model.model.cpu()
251
- self.model_sketch = self.model_sketch.cuda()
252
- self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
253
- self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
254
 
255
  # extract condition features
256
  c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
@@ -258,10 +258,10 @@ class Model_all:
258
  features_adapter = self.model_sketch(im.to(self.device))
259
  shape = [4, 64, 64]
260
 
261
- # save gpu memory
262
- self.model_sketch = self.model_sketch.cpu()
263
- self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
264
- self.base_model.model = self.base_model.model.cuda()
265
 
266
  # sampling
267
  samples_ddim, _ = self.sampler.sample(S=50,
@@ -277,8 +277,8 @@ class Model_all:
277
  mode = 'sketch',
278
  con_strength = con_strength)
279
 
280
- # save gpu memory
281
- self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
282
 
283
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
284
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -345,11 +345,11 @@ class Model_all:
345
  thickness=2)
346
  im_pose = cv2.resize(im_pose,(512,512))
347
 
348
- # save gpu memory
349
- self.base_model.model = self.base_model.model.cpu()
350
- self.model_pose = self.model_pose.cuda()
351
- self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
352
- self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
353
 
354
  # extract condition features
355
  c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
@@ -358,10 +358,10 @@ class Model_all:
358
  pose = pose.unsqueeze(0)
359
  features_adapter = self.model_pose(pose.to(self.device))
360
 
361
- # save gpu memory
362
- self.model_pose = self.model_pose.cpu()
363
- self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
364
- self.base_model.model = self.base_model.model.cuda()
365
 
366
  shape = [4, 64, 64]
367
 
@@ -379,8 +379,8 @@ class Model_all:
379
  mode = 'sketch',
380
  con_strength = con_strength)
381
 
382
- # save gpu memory
383
- self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
384
 
385
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
386
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
175
  im = im.float()
176
  im_edge = tensor2img(im)
177
 
178
+ # # save gpu memory
179
+ # self.base_model.model = self.base_model.model.cpu()
180
+ # self.model_sketch = self.model_sketch.cuda()
181
+ # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
182
+ # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
183
 
184
  # extract condition features
185
  c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
 
187
  features_adapter = self.model_sketch(im.to(self.device))
188
  shape = [4, 64, 64]
189
 
190
+ # # save gpu memory
191
+ # self.model_sketch = self.model_sketch.cpu()
192
+ # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
193
+ # self.base_model.model = self.base_model.model.cuda()
194
 
195
  # sampling
196
  samples_ddim, _ = self.sampler.sample(S=50,
 
205
  features_adapter1=features_adapter,
206
  mode = 'sketch',
207
  con_strength = con_strength)
208
+ # # save gpu memory
209
+ # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
210
 
211
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
212
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
246
  im = im>0.5
247
  im = im.float()
248
 
249
+ # # save gpu memory
250
+ # self.base_model.model = self.base_model.model.cpu()
251
+ # self.model_sketch = self.model_sketch.cuda()
252
+ # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
253
+ # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
254
 
255
  # extract condition features
256
  c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
 
258
  features_adapter = self.model_sketch(im.to(self.device))
259
  shape = [4, 64, 64]
260
 
261
+ # # save gpu memory
262
+ # self.model_sketch = self.model_sketch.cpu()
263
+ # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
264
+ # self.base_model.model = self.base_model.model.cuda()
265
 
266
  # sampling
267
  samples_ddim, _ = self.sampler.sample(S=50,
 
277
  mode = 'sketch',
278
  con_strength = con_strength)
279
 
280
+ # # save gpu memory
281
+ # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
282
 
283
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
284
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
345
  thickness=2)
346
  im_pose = cv2.resize(im_pose,(512,512))
347
 
348
+ # # save gpu memory
349
+ # self.base_model.model = self.base_model.model.cpu()
350
+ # self.model_pose = self.model_pose.cuda()
351
+ # self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
352
+ # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
353
 
354
  # extract condition features
355
  c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
 
358
  pose = pose.unsqueeze(0)
359
  features_adapter = self.model_pose(pose.to(self.device))
360
 
361
+ # # save gpu memory
362
+ # self.model_pose = self.model_pose.cpu()
363
+ # self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
364
+ # self.base_model.model = self.base_model.model.cuda()
365
 
366
  shape = [4, 64, 64]
367
 
 
379
  mode = 'sketch',
380
  con_strength = con_strength)
381
 
382
+ # # save gpu memory
383
+ # self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
384
 
385
  x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
386
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)