wondervictor commited on
Commit
113349b
·
verified ·
1 Parent(s): 4005999

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +7 -7
model.py CHANGED
@@ -87,7 +87,7 @@ class Model:
87
  def load_t5(self):
88
  precision = torch.bfloat16
89
  t5_model = T5Embedder(
90
- device="cuda",
91
  local_cache=True,
92
  cache_dir='checkpoints/flan-t5-xl',
93
  dir_or_name='flan-t5-xl',
@@ -114,8 +114,8 @@ class Model:
114
  image = resize_image_to_16_multiple(image, 'canny')
115
  W, H = image.size
116
  print(W, H)
117
- self.t5_model.model.to(self.device)
118
- self.gpt_model_canny.to(self.device)
119
 
120
  condition_img = self.get_control_canny(np.array(image), low_threshold,
121
  high_threshold)
@@ -191,10 +191,10 @@ class Model:
191
  image = resize_image_to_16_multiple(image, 'depth')
192
  W, H = image.size
193
  print(W, H)
194
- self.t5_model.model.to(self.device)
195
- self.gpt_model_depth.to(self.device)
196
- self.get_control_depth.model.to(self.device)
197
- self.vq_model.to(self.device)
198
  image_tensor = torch.from_numpy(np.array(image)).to(self.device)
199
  condition_img = torch.from_numpy(
200
  self.get_control_depth(image_tensor)).unsqueeze(0)
 
87
  def load_t5(self):
88
  precision = torch.bfloat16
89
  t5_model = T5Embedder(
90
+ device=self.device,
91
  local_cache=True,
92
  cache_dir='checkpoints/flan-t5-xl',
93
  dir_or_name='flan-t5-xl',
 
114
  image = resize_image_to_16_multiple(image, 'canny')
115
  W, H = image.size
116
  print(W, H)
117
+ # self.t5_model.model.to(self.device)
118
+ # self.gpt_model_canny.to(self.device)
119
 
120
  condition_img = self.get_control_canny(np.array(image), low_threshold,
121
  high_threshold)
 
191
  image = resize_image_to_16_multiple(image, 'depth')
192
  W, H = image.size
193
  print(W, H)
194
+ # self.t5_model.model.to(self.device)
195
+ # self.gpt_model_depth.to(self.device)
196
+ # self.get_control_depth.model.to(self.device)
197
+ # self.vq_model.to(self.device)
198
  image_tensor = torch.from_numpy(np.array(image)).to(self.device)
199
  condition_img = torch.from_numpy(
200
  self.get_control_depth(image_tensor)).unsqueeze(0)