amildravid4292 commited on
Commit
4a03874
·
verified ·
1 Parent(s): 1d7cba1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -57,7 +57,6 @@ def load_models(device):
57
  unet = UNet2DConditionModel.from_pretrained(
58
  pretrained_model_name_or_path, subfolder="unet", revision=revision
59
  )
60
-
61
  unet.requires_grad_(False)
62
  unet.to(device, dtype=weight_dtype)
63
  vae.requires_grad_(False)
@@ -124,7 +123,7 @@ class main():
124
  self.vae.to(device, dtype=weight_dtype)
125
  self.text_encoder.to(device, dtype=weight_dtype)
126
  print("")
127
- print(self.text_encoder.device)
128
 
129
  self.network = None
130
 
@@ -171,7 +170,8 @@ class main():
171
  self.thick = thick
172
 
173
 
174
-
 
175
  def sample_model(self):
176
  self.unet, _, _, _, _ = load_models(self.device)
177
  self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
@@ -181,6 +181,13 @@ class main():
181
  @spaces.GPU
182
  def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
183
  device = self.device
 
 
 
 
 
 
 
184
  generator = torch.Generator(device=device).manual_seed(seed)
185
  latents = torch.randn(
186
  (1, self.unet.in_channels, 512 // 8, 512 // 8),
 
57
  unet = UNet2DConditionModel.from_pretrained(
58
  pretrained_model_name_or_path, subfolder="unet", revision=revision
59
  )
 
60
  unet.requires_grad_(False)
61
  unet.to(device, dtype=weight_dtype)
62
  vae.requires_grad_(False)
 
123
  self.vae.to(device, dtype=weight_dtype)
124
  self.text_encoder.to(device, dtype=weight_dtype)
125
  print("")
126
+
127
 
128
  self.network = None
129
 
 
170
  self.thick = thick
171
 
172
 
173
+ @torch.no_grad()
174
+ @spaces.GPU
175
  def sample_model(self):
176
  self.unet, _, _, _, _ = load_models(self.device)
177
  self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
 
181
  @spaces.GPU
182
  def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
183
  device = self.device
184
+ self.unet.to(device)
185
+ self.text_encoder.to(device)
186
+ self.vae.to(device)
187
+ self.tokenizer.to(device)
188
+
189
+
190
+
191
  generator = torch.Generator(device=device).manual_seed(seed)
192
  latents = torch.randn(
193
  (1, self.unet.in_channels, 512 // 8, 512 // 8),