Upload 2 files
app.py
CHANGED
@@ -10,7 +10,6 @@ import torch
 import gradio as gr
 import spaces
 from tqdm.auto import tqdm
-import numpy as np
 from PIL import Image
 from utils import (
     load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
@@ -18,7 +17,6 @@ from utils import (
 )
 from diffusers import StableDiffusionPipeline
 
-
 # Set device
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 if device == "mps":
@@ -31,12 +29,11 @@ def load_model():
         "runwayml/stable-diffusion-v1-5",
         torch_dtype=torch.float16,
         safety_checker=None
-    )
+    ).to(device)
 
 @spaces.GPU
 def get_pipeline():
-
-    return pipe.to("cuda")
+    return load_model()
 
 # Load concept library
 concept_embeds, concept_tokens = load_concept_library(get_pipeline())
@@ -50,6 +47,7 @@ art_concepts = {
     "comic_book": "comic book style, ink outlines, cel shading"
 }
 
+@spaces.GPU
 def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
                      vignette_loss_scale, concept_style=None, concept_strength=0.5,
                      height=512, width=512):
@@ -267,7 +265,6 @@ def create_demo():
         guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
         vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
 
-        # Combine SD concept library tokens and art concept descriptions
         all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
         concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
         concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
@@ -294,7 +291,7 @@ def create_demo():
 
     # Set up event handlers
     generate_btn.click(
-
+        generate_latents,
        inputs=[prompt, seed, num_inference_steps, guidance_scale,
                vignette_loss_scale, concept_style, concept_strength],
        outputs=output_image
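For context on the hunks above: on ZeroGPU Spaces, the GPU is only attached while a function decorated with @spaces.GPU is running, so the pipeline has to be built lazily through load_model() rather than by referencing a module-level pipe at import time. A minimal sketch of that pattern, assuming load_model wraps the from_pretrained call shown in the hunk; the _pipe cache is a hypothetical addition for illustration, not part of the diff:

import torch
import spaces
from diffusers import StableDiffusionPipeline

# Pick the device at import time; on ZeroGPU Spaces the GPU is only
# attached while a @spaces.GPU-decorated function is running.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

_pipe = None  # hypothetical cache so the weights load only once

def load_model():
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16,
            safety_checker=None,
        ).to(device)
    return _pipe

@spaces.GPU
def get_pipeline():
    # As in the new hunk: delegate to load_model() instead of
    # referencing a module-level `pipe` that may not exist yet.
    return load_model()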
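The final app.py hunk fixes the event wiring: Button.click takes the callback as its first positional argument (fn), ahead of inputs and outputs, and previously nothing was passed, so the click had no handler. A self-contained toy example of the same wiring; the handler and component names below are placeholders, not from the app:

import gradio as gr

def shout(text):
    # Stand-in for generate_latents.
    return text.upper()

with gr.Blocks() as demo:
    box = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    # The callback comes first; without it the click does nothing.
    btn.click(shout, inputs=box, outputs=out)

demo.launch()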
utils.py
CHANGED
@@ -67,12 +67,11 @@ def clear_gpu_memory():
     """Clear GPU memory cache"""
     torch.cuda.empty_cache()
     gc.collect()
-    torch.cuda.empty_cache()
 
 def set_timesteps(scheduler, num_inference_steps):
     """Set timesteps for the scheduler with MPS compatibility fix"""
     scheduler.set_timesteps(num_inference_steps)
-    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
+    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
 
 def pil_to_latent(input_im, vae, device):
     """
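On the utils.py side, the float32 cast survives the cleanup: Apple's MPS backend has no float64 support, and some diffusers schedulers emit float64 timesteps by default. A short usage sketch; the concrete scheduler class is an assumption for illustration, since the diff does not show which one the app uses:

import torch
from diffusers import LMSDiscreteScheduler

def set_timesteps(scheduler, num_inference_steps):
    """Set timesteps, downcasting to float32 for MPS compatibility."""
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)

scheduler = LMSDiscreteScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
set_timesteps(scheduler, 30)
assert scheduler.timesteps.dtype == torch.float32  # safe to use on MPS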