Spaces: Running on Zero

Update app.py
app.py CHANGED
@@ -57,7 +57,6 @@ def load_models(device):
     unet = UNet2DConditionModel.from_pretrained(
         pretrained_model_name_or_path, subfolder="unet", revision=revision
     )
-
     unet.requires_grad_(False)
     unet.to(device, dtype=weight_dtype)
     vae.requires_grad_(False)
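The change above only drops a blank line, but the surrounding context is the standard diffusers inference setup: load the UNet from a checkpoint subfolder, freeze it, and cast it to the inference dtype. A minimal runnable sketch of that pattern, where the checkpoint id and fp16 dtype are illustrative assumptions, not this Space's actual values:

```python
import torch
from diffusers import UNet2DConditionModel

# Assumed checkpoint id for illustration; the Space's own
# pretrained_model_name_or_path may differ.
model_id = "CompVis/stable-diffusion-v1-4"

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.requires_grad_(False)            # inference only: freeze every parameter
unet.to("cuda", dtype=torch.float16)  # fp16 halves weight memory vs fp32
```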
@@ -124,7 +123,7 @@ class main():
         self.vae.to(device, dtype=weight_dtype)
         self.text_encoder.to(device, dtype=weight_dtype)
         print("")
-
+
 
         self.network = None
 
@@ -171,7 +170,8 @@ class main():
         self.thick = thick
 
 
-
+    @torch.no_grad()
+    @spaces.GPU
     def sample_model(self):
         self.unet, _, _, _, _ = load_models(self.device)
         self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
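This hunk carries the substantive change: on a ZeroGPU ("Running on Zero") Space, a GPU is attached only while a `@spaces.GPU`-decorated function executes, so `sample_model` must be decorated before it may touch CUDA, and `@torch.no_grad()` additionally skips autograd bookkeeping during sampling. A minimal sketch of the pattern, assuming the `spaces` package that ZeroGPU Spaces provide; `gpu_demo` is a hypothetical function, not from this repo:

```python
import spaces
import torch

@torch.no_grad()   # read-only inference: no gradient tracking needed
@spaces.GPU        # ZeroGPU attaches a GPU only for the duration of this call
def gpu_demo(n: int = 1024) -> float:
    # CUDA allocations are only valid inside the decorated function;
    # at module import time a ZeroGPU Space has no GPU attached.
    a = torch.randn(n, n, device="cuda")
    b = torch.randn(n, n, device="cuda")
    return (a @ b).sum().item()
```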
@@ -181,6 +181,13 @@
     @spaces.GPU
     def inference(self, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
         device = self.device
+        self.unet.to(device)
+        self.text_encoder.to(device)
+        self.vae.to(device)
+        self.tokenizer.to(device)
+
+
+
         generator = torch.Generator(device=device).manual_seed(seed)
         latents = torch.randn(
             (1, self.unet.in_channels, 512 // 8, 512 // 8),
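The added lines move the model components onto the device at the start of `inference`, inside the `@spaces.GPU` window where CUDA is actually available. One caveat, hedged because the tokenizer's type is not visible in this diff: if `self.tokenizer` is a standard `transformers` tokenizer (e.g. `CLIPTokenizer`), it is a plain Python object with no `.to()` method, and `self.tokenizer.to(device)` would raise `AttributeError`. A defensive sketch of the same step; `move_components` is a hypothetical helper, not from this repo:

```python
import torch

def move_components(app, device: torch.device) -> None:
    """Move only the components that actually hold device-resident weights."""
    for name in ("unet", "text_encoder", "vae", "tokenizer"):
        component = getattr(app, name, None)
        # unet/vae/text_encoder are nn.Modules with parameters on a device;
        # a plain tokenizer defines no .to() and is skipped rather than raising.
        if isinstance(component, torch.nn.Module):
            component.to(device)
```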
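The trailing context lines show the seeded latent initialization: a dedicated `torch.Generator` makes a run reproducible for a given seed, and `512 // 8` reflects the 8x downsampling of Stable Diffusion's VAE, so a 512x512 image starts from a 64x64 latent grid. A self-contained sketch of that step under those assumptions; `init_latents` and its defaults are illustrative:

```python
import torch

def init_latents(in_channels: int, seed: int, device: str = "cuda") -> torch.Tensor:
    # Seeded generator: identical latents (and thus images) for a given seed.
    generator = torch.Generator(device=device).manual_seed(seed)
    # The VAE downsamples by 8: 512x512 pixels -> 64x64 latent grid.
    return torch.randn(
        (1, in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    )
```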