Commit a2bfbe9

1 Parent(s): 7c9588b

Update app.py

app.py CHANGED
@@ -46,7 +46,6 @@ if torch.cuda.is_available():
 if PREVIEW_IMAGES:
     previewer = Previewer()
     previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
-    previewer.eval().requires_grad_(False).to(device).to(dtype)
 
 def callback_prior(i, t, latents):
     output = previewer(latents)
@@ -84,6 +83,7 @@ def generate(
 ) -> PIL.Image.Image:
     prior_pipeline.to("cuda")
     decoder_pipeline.to("cuda")
+    previewer.eval().requires_grad_(False).to(device).to(dtype)
     generator = torch.Generator().manual_seed(seed)
     prior_output = prior_pipeline(
         prompt=prompt,
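With this change, the previewer's eval()/requires_grad_(False)/to(device)/to(dtype) setup no longer runs at import time but inside generate(), right after the two pipelines are moved to "cuda". The chained call works because each of these torch.nn.Module methods returns the module itself. A minimal, self-contained sketch of that pattern follows; SmallPreviewer, the cpu device, and the float32 dtype below are placeholders, not the Space's actual Previewer class or runtime settings:

# Minimal sketch of the chained setup call moved by this commit.
# `SmallPreviewer`, `device`, and `dtype` are illustrative placeholders.
import torch
import torch.nn as nn


class SmallPreviewer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Conv2d(4, 3, kernel_size=1)

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        return self.proj(latents)


device = "cpu"          # in app.py this would be the CUDA device when available
dtype = torch.float32   # in app.py this may be a reduced-precision dtype

previewer = SmallPreviewer()
# eval(), requires_grad_(), and to() each return the module, so the calls chain:
# switch to inference mode, freeze parameters, then move to the device and dtype.
previewer.eval().requires_grad_(False).to(device).to(dtype)

latents = torch.randn(1, 4, 24, 24, device=device, dtype=dtype)
with torch.no_grad():
    preview = previewer(latents)
print(preview.shape)  # torch.Size([1, 3, 24, 24])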