multimodalart HF staff commited on
Commit
d465e66
1 Parent(s): 2e6feb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -8,20 +8,20 @@ from typing import List
8
  from diffusers.utils import numpy_to_pil
9
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
10
  from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
11
-
12
  #import user_history
13
 
14
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
15
 
16
  DESCRIPTION = "# Stable Cascade"
17
- #DESCRIPTION += "\n<p style=\"text-align: center\"><a href='https://huggingface.co/warp-ai/wuerstchen' target='_blank'>Würstchen</a> is a new fast and efficient high resolution text-to-image architecture and model</p>"
18
  if not torch.cuda.is_available():
19
  DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
23
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
24
- USE_TORCH_COMPILE = True
25
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
26
  PREVIEW_IMAGES = False #not working for now
27
 
@@ -39,8 +39,8 @@ if torch.cuda.is_available():
39
  decoder_pipeline.to(device)
40
 
41
  if USE_TORCH_COMPILE:
42
- #prior_pipeline.prior = torch.compile(prior_pipeline.prior)
43
- decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)
44
 
45
  if PREVIEW_IMAGES:
46
  pass
@@ -66,7 +66,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
66
  seed = random.randint(0, MAX_SEED)
67
  return seed
68
 
69
-
70
  def generate(
71
  prompt: str,
72
  negative_prompt: str = "",
@@ -82,8 +82,9 @@ def generate(
82
  num_images_per_prompt: int = 2,
83
  #profile: gr.OAuthProfile | None = None,
84
  ) -> PIL.Image.Image:
 
 
85
  generator = torch.Generator().manual_seed(seed)
86
-
87
  prior_output = prior_pipeline(
88
  prompt=prompt,
89
  height=height,
@@ -193,7 +194,7 @@ with gr.Blocks() as demo:
193
  minimum=1,
194
  maximum=2,
195
  step=1,
196
- value=2,
197
  )
198
  with gr.Row():
199
  prior_guidance_scale = gr.Slider(
 
8
  from diffusers.utils import numpy_to_pil
9
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
10
  from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
11
+ import spaces
12
  #import user_history
13
 
14
  os.environ['TOKENIZERS_PARALLELISM'] = 'false'
15
 
16
  DESCRIPTION = "# Stable Cascade"
17
+ DESCRIPTION += "\n<p style=\"text-align: center\"><a href='https://huggingface.co/stabilityai/stable-cascade' target='_blank'>Stable Casaade</a> is a new fast and efficient high resolution text-to-image architecture and model built on the Würstchen architecture</p>"
18
  if not torch.cuda.is_available():
19
  DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
23
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
24
+ USE_TORCH_COMPILE = False
25
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
26
  PREVIEW_IMAGES = False #not working for now
27
 
 
39
  decoder_pipeline.to(device)
40
 
41
  if USE_TORCH_COMPILE:
42
+ prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
43
+ decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
44
 
45
  if PREVIEW_IMAGES:
46
  pass
 
66
  seed = random.randint(0, MAX_SEED)
67
  return seed
68
 
69
+ @spaces.GPU
70
  def generate(
71
  prompt: str,
72
  negative_prompt: str = "",
 
82
  num_images_per_prompt: int = 2,
83
  #profile: gr.OAuthProfile | None = None,
84
  ) -> PIL.Image.Image:
85
+ prior_pipeline.to("cuda")
86
+ decoder_pipeline.to("cuda")
87
  generator = torch.Generator().manual_seed(seed)
 
88
  prior_output = prior_pipeline(
89
  prompt=prompt,
90
  height=height,
 
194
  minimum=1,
195
  maximum=2,
196
  step=1,
197
+ value=1,
198
  )
199
  with gr.Row():
200
  prior_guidance_scale = gr.Slider(