Update app.py
app.py CHANGED
@@ -46,7 +46,8 @@ class ModelWrapper:
         self.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
         self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
         self.num_step = num_step
-
+
+    @spaces.GPU()
     def create_generator(self, model_id, checkpoint_path):
         generator = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(self.DTYPE)
         state_dict = torch.load(checkpoint_path, map_location="cuda")
@@ -108,7 +109,7 @@ class ModelWrapper:
         eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
         return eval_images

-
+
     @torch.no_grad()
     def inference(self, prompt, seed, height, width, num_images, fast_vae_decode):
         print("Running model inference...")