jbilcke-hf HF staff commited on
Commit
a076391
·
1 Parent(s): 53f3635

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -6,24 +6,29 @@ import gradio as gr
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
- from diffusers import LCMScheduler, AutoPipelineForText2Image
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
13
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
14
 
15
- MODEL_ID = "segmind/SSD-1B"
16
- ADAPTER_ID = "latent-consistency/lcm-lora-ssd-1b"
17
-
18
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
  if torch.cuda.is_available():
20
- pipe = AutoPipelineForText2Image.from_pretrained(MODEL_ID, torch_dtype=torch.float16, variant="fp16")
21
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
22
- pipe.to("cuda")
 
 
 
 
 
 
 
 
 
23
 
24
- # load and fuse
25
- pipe.load_lora_weights(ADAPTER_ID)
26
- pipe.fuse_lora()
27
  else:
28
  pipe = None
29
 
@@ -39,8 +44,8 @@ def generate(prompt: str,
39
  seed: int = 0,
40
  width: int = 1024,
41
  height: int = 1024,
42
- guidance_scale: float = 0.0,
43
- num_inference_steps: int = 4,
44
  secret_token: str = '') -> PIL.Image.Image:
45
  if secret_token != SECRET_TOKEN:
46
  raise gr.Error(
@@ -64,7 +69,7 @@ with gr.Blocks() as demo:
64
  gr.HTML("""
65
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
66
  <div style="text-align: center; color: black;">
67
- <p style="color: black;">This space is a REST API to programmatically generate images using LCM SDXL LoRA.</p>
68
  <p style="color: black;">It is not meant to be directly used through a user interface, but using code and an access key.</p>
69
  </div>
70
  </div>""")
@@ -112,16 +117,16 @@ with gr.Blocks() as demo:
112
  )
113
  guidance_scale = gr.Slider(
114
  label='Guidance scale',
115
- minimum=0,
116
- maximum=2,
117
  step=0.1,
118
- value=0.0)
119
  num_inference_steps = gr.Slider(
120
  label='Number of inference steps',
121
- minimum=1,
122
- maximum=8,
123
  step=1,
124
- value=4)
125
 
126
  use_negative_prompt.change(
127
  fn=lambda x: gr.update(visible=x),
 
6
  import numpy as np
7
  import PIL.Image
8
  import torch
9
+ from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
  MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
13
  SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
14
 
 
 
 
15
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16
  if torch.cuda.is_available():
17
+ unet = UNet2DConditionModel.from_pretrained(
18
+ "latent-consistency/lcm-ssd-1b",
19
+ torch_dtype=torch.float16,
20
+ variant="fp16"
21
+ )
22
+
23
+ pipe = DiffusionPipeline.from_pretrained(
24
+ "segmind/SSD-1B",
25
+ unet=unet,
26
+ torch_dtype=torch.float16,
27
+ variant="fp16"
28
+ )
29
 
30
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
31
+ pipe.to(device)
 
32
  else:
33
  pipe = None
34
 
 
44
  seed: int = 0,
45
  width: int = 1024,
46
  height: int = 1024,
47
+ guidance_scale: float = 1.0,
48
+ num_inference_steps: int = 6,
49
  secret_token: str = '') -> PIL.Image.Image:
50
  if secret_token != SECRET_TOKEN:
51
  raise gr.Error(
 
69
  gr.HTML("""
70
  <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
71
  <div style="text-align: center; color: black;">
72
+ <p style="color: black;">This space is a REST API to programmatically generate images using LCM-SSD-1B.</p>
73
  <p style="color: black;">It is not meant to be directly used through a user interface, but using code and an access key.</p>
74
  </div>
75
  </div>""")
 
117
  )
118
  guidance_scale = gr.Slider(
119
  label='Guidance scale',
120
+ minimum=1,
121
+ maximum=20,
122
  step=0.1,
123
+ value=1.0)
124
  num_inference_steps = gr.Slider(
125
  label='Number of inference steps',
126
+ minimum=2,
127
+ maximum=40,
128
  step=1,
129
+ value=6)
130
 
131
  use_negative_prompt.change(
132
  fn=lambda x: gr.update(visible=x),