vilarin commited on
Commit
2516da3
·
verified ·
1 Parent(s): c739019

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -46,7 +46,8 @@ class ModelWrapper:
46
  self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
47
  self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
48
  self.num_step = num_step
49
-
 
50
  def create_generator(self, model_id, checkpoint_path):
51
  generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
52
  state_dict = torch.load(checkpoint_path, map_location="cuda")
@@ -108,7 +109,7 @@ class ModelWrapper:
108
  eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
109
  return eval_images
110
 
111
- @spaces.GPU()
112
  @torch.no_grad()
113
  def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
114
  print("Running model inference...")
 
46
  self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
47
  self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
48
  self.num_step = num_step
49
+
50
+ @spaces.GPU()
51
  def create_generator(self, model_id, checkpoint_path):
52
  generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
53
  state_dict = torch.load(checkpoint_path, map_location="cuda")
 
109
  eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
110
  return eval_images
111
 
112
+
113
  @torch.no_grad()
114
  def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
115
  print("Running model inference...")