wondervictor commited on
Commit
8b9ace1
·
verified ·
1 Parent(s): 09a0762

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +14 -8
model.py CHANGED
@@ -103,10 +103,7 @@ class Model:
103
  control_strength: float,
104
  preprocessor_name: str,
105
  ) -> list[PIL.Image.Image]:
106
- self.t5_model.model.to('cuda').to(torch.bfloat16)
107
- self.load_gpt_weight('edge')
108
- self.gpt_model.to('cuda').to(torch.bfloat16)
109
- self.vq_model.to('cuda')
110
  if isinstance(image, np.ndarray):
111
  image = Image.fromarray(image)
112
  origin_W, origin_H = image.size
@@ -125,9 +122,15 @@ class Model:
125
  elif preprocessor_name == 'No preprocess':
126
  condition_img = image
127
  print('get edge')
 
 
128
  condition_img = condition_img.resize((512,512))
129
  W, H = condition_img.size
130
 
 
 
 
 
131
  condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(1,1,1,1)
132
  condition_img = condition_img.to(self.device)
133
  condition_img = 2*(condition_img/255 - 0.5)
@@ -198,10 +201,7 @@ class Model:
198
  control_strength: float,
199
  preprocessor_name: str
200
  ) -> list[PIL.Image.Image]:
201
- self.t5_model.model.to(self.device).to(torch.bfloat16)
202
- self.load_gpt_weight('depth')
203
- self.gpt_model.to('cuda').to(torch.bfloat16)
204
- self.vq_model.to(self.device)
205
  if isinstance(image, np.ndarray):
206
  image = Image.fromarray(image)
207
  origin_W, origin_H = image.size
@@ -216,9 +216,15 @@ class Model:
216
  elif preprocessor_name == 'No preprocess':
217
  condition_img = image
218
  print('get depth')
 
 
219
  condition_img = condition_img.resize((512,512))
220
  W, H = condition_img.size
221
 
 
 
 
 
222
  condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(1,1,1,1)
223
  condition_img = condition_img.to(self.device)
224
  condition_img = 2*(condition_img/255 - 0.5)
 
103
  control_strength: float,
104
  preprocessor_name: str,
105
  ) -> list[PIL.Image.Image]:
106
+
 
 
 
107
  if isinstance(image, np.ndarray):
108
  image = Image.fromarray(image)
109
  origin_W, origin_H = image.size
 
122
  elif preprocessor_name == 'No preprocess':
123
  condition_img = image
124
  print('get edge')
125
+ del self.preprocessor.model
126
+ torch.cuda.empty_cache()
127
  condition_img = condition_img.resize((512,512))
128
  W, H = condition_img.size
129
 
130
+ self.t5_model.model.to('cuda').to(torch.bfloat16)
131
+ self.load_gpt_weight('edge')
132
+ self.gpt_model.to('cuda').to(torch.bfloat16)
133
+ self.vq_model.to('cuda')
134
  condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(1,1,1,1)
135
  condition_img = condition_img.to(self.device)
136
  condition_img = 2*(condition_img/255 - 0.5)
 
201
  control_strength: float,
202
  preprocessor_name: str
203
  ) -> list[PIL.Image.Image]:
204
+
 
 
 
205
  if isinstance(image, np.ndarray):
206
  image = Image.fromarray(image)
207
  origin_W, origin_H = image.size
 
216
  elif preprocessor_name == 'No preprocess':
217
  condition_img = image
218
  print('get depth')
219
+ del self.preprocessor.model
220
+ torch.cuda.empty_cache()
221
  condition_img = condition_img.resize((512,512))
222
  W, H = condition_img.size
223
 
224
+ self.t5_model.model.to(self.device).to(torch.bfloat16)
225
+ self.load_gpt_weight('depth')
226
+ self.gpt_model.to('cuda').to(torch.bfloat16)
227
+ self.vq_model.to(self.device)
228
  condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(1,1,1,1)
229
  condition_img = condition_img.to(self.device)
230
  condition_img = 2*(condition_img/255 - 0.5)