Spaces:
Sleeping
Sleeping
Commit
•
1a89b82
1
Parent(s):
7359460
Update app.py
Browse files
app.py
CHANGED
@@ -29,8 +29,8 @@ PREVIEW_IMAGES = True
|
|
29 |
dtype = torch.bfloat16
|
30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31 |
if torch.cuda.is_available():
|
32 |
-
prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)
|
33 |
-
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)
|
34 |
|
35 |
if ENABLE_CPU_OFFLOAD:
|
36 |
prior_pipeline.enable_model_cpu_offload()
|
@@ -45,8 +45,8 @@ if torch.cuda.is_available():
|
|
45 |
|
46 |
if PREVIEW_IMAGES:
|
47 |
previewer = Previewer()
|
48 |
-
|
49 |
-
previewer.
|
50 |
def callback_prior(i, t, latents):
|
51 |
output = previewer(latents)
|
52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
|
@@ -82,9 +82,10 @@ def generate(
|
|
82 |
num_images_per_prompt: int = 2,
|
83 |
profile: gr.OAuthProfile | None = None,
|
84 |
) -> PIL.Image.Image:
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
88 |
generator = torch.Generator().manual_seed(seed)
|
89 |
prior_output = prior_pipeline(
|
90 |
prompt=prompt,
|
|
|
29 |
dtype = torch.bfloat16
|
30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
31 |
if torch.cuda.is_available():
|
32 |
+
prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
|
33 |
+
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
|
34 |
|
35 |
if ENABLE_CPU_OFFLOAD:
|
36 |
prior_pipeline.enable_model_cpu_offload()
|
|
|
45 |
|
46 |
if PREVIEW_IMAGES:
|
47 |
previewer = Previewer()
|
48 |
+
previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
|
49 |
+
previewer.load_state_dict(previewer_state_dict)
|
50 |
def callback_prior(i, t, latents):
|
51 |
output = previewer(latents)
|
52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
|
|
|
82 |
num_images_per_prompt: int = 2,
|
83 |
profile: gr.OAuthProfile | None = None,
|
84 |
) -> PIL.Image.Image:
|
85 |
+
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
86 |
+
prior_pipeline.to(device)
|
87 |
+
decoder_pipeline.to(device)
|
88 |
+
|
89 |
generator = torch.Generator().manual_seed(seed)
|
90 |
prior_output = prior_pipeline(
|
91 |
prompt=prompt,
|