Shilpaj commited on
Commit
ffd9258
·
verified ·
1 Parent(s): c033de0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +4 -7
  2. utils.py +1 -2
app.py CHANGED
@@ -10,7 +10,6 @@ import torch
10
  import gradio as gr
11
  import spaces
12
  from tqdm.auto import tqdm
13
- import numpy as np
14
  from PIL import Image
15
  from utils import (
16
  load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
@@ -18,7 +17,6 @@ from utils import (
18
  )
19
  from diffusers import StableDiffusionPipeline
20
 
21
-
22
  # Set device
23
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
24
  if device == "mps":
@@ -31,12 +29,11 @@ def load_model():
31
  "runwayml/stable-diffusion-v1-5",
32
  torch_dtype=torch.float16,
33
  safety_checker=None
34
- )
35
 
36
  @spaces.GPU
37
  def get_pipeline():
38
- pipe = load_model()
39
- return pipe.to("cuda")
40
 
41
  # Load concept library
42
  concept_embeds, concept_tokens = load_concept_library(get_pipeline())
@@ -50,6 +47,7 @@ art_concepts = {
50
  "comic_book": "comic book style, ink outlines, cel shading"
51
  }
52
 
 
53
  def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
54
  vignette_loss_scale, concept_style=None, concept_strength=0.5,
55
  height=512, width=512):
@@ -267,7 +265,6 @@ def create_demo():
267
  guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
268
  vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
269
 
270
- # Combine SD concept library tokens and art concept descriptions
271
  all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
272
  concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
273
  concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
@@ -294,7 +291,7 @@ def create_demo():
294
 
295
  # Set up event handlers
296
  generate_btn.click(
297
- generate_image,
298
  inputs=[prompt, seed, num_inference_steps, guidance_scale,
299
  vignette_loss_scale, concept_style, concept_strength],
300
  outputs=output_image
 
10
  import gradio as gr
11
  import spaces
12
  from tqdm.auto import tqdm
 
13
  from PIL import Image
14
  from utils import (
15
  load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
 
17
  )
18
  from diffusers import StableDiffusionPipeline
19
 
 
20
  # Set device
21
  device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
22
  if device == "mps":
 
29
  "runwayml/stable-diffusion-v1-5",
30
  torch_dtype=torch.float16,
31
  safety_checker=None
32
+ ).to(device)
33
 
34
  @spaces.GPU
35
  def get_pipeline():
36
+ return load_model()
 
37
 
38
  # Load concept library
39
  concept_embeds, concept_tokens = load_concept_library(get_pipeline())
 
47
  "comic_book": "comic book style, ink outlines, cel shading"
48
  }
49
 
50
+ @spaces.GPU
51
  def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
52
  vignette_loss_scale, concept_style=None, concept_strength=0.5,
53
  height=512, width=512):
 
265
  guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
266
  vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
267
 
 
268
  all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
269
  concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
270
  concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
 
291
 
292
  # Set up event handlers
293
  generate_btn.click(
294
+ generate_latents,
295
  inputs=[prompt, seed, num_inference_steps, guidance_scale,
296
  vignette_loss_scale, concept_style, concept_strength],
297
  outputs=output_image
utils.py CHANGED
@@ -67,12 +67,11 @@ def clear_gpu_memory():
67
  """Clear GPU memory cache"""
68
  torch.cuda.empty_cache()
69
  gc.collect()
70
- torch.cuda.empty_cache()
71
 
72
  def set_timesteps(scheduler, num_inference_steps):
73
  """Set timesteps for the scheduler with MPS compatibility fix"""
74
  scheduler.set_timesteps(num_inference_steps)
75
- scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility
76
 
77
  def pil_to_latent(input_im, vae, device):
78
  """
 
67
  """Clear GPU memory cache"""
68
  torch.cuda.empty_cache()
69
  gc.collect()
 
70
 
71
  def set_timesteps(scheduler, num_inference_steps):
72
  """Set timesteps for the scheduler with MPS compatibility fix"""
73
  scheduler.set_timesteps(num_inference_steps)
74
+ scheduler.timesteps = scheduler.timesteps.to(torch.float32)
75
 
76
  def pil_to_latent(input_im, vae, device):
77
  """