KingNish committed
Commit b50f552 • 1 Parent(s): 58d3864

Update app.py

Files changed (1):
  1. app.py +262 -100
app.py CHANGED
@@ -1,86 +1,235 @@
-import spaces
-import argparse
 import os
-import time
-from os import path
-from safetensors.torch import load_file
-from huggingface_hub import hf_hub_download
-
-cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
-os.environ["TRANSFORMERS_CACHE"] = cache_path
-os.environ["HF_HUB_CACHE"] = cache_path
-os.environ["HF_HOME"] = cache_path
-
 import gradio as gr
 import torch
-from diffusers import StableDiffusionXLPipeline, LCMScheduler
-
-# from scheduling_tcd import TCDScheduler
-
-torch.backends.cuda.matmul.allow_tf32 = True
-
-class timer:
-    def __init__(self, method_name="timed process"):
-        self.method = method_name
-
-    def __enter__(self):
-        self.start = time.time()
-        print(f"{self.method} starts")
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        end = time.time()
-        print(f"{self.method} took {str(round(end - self.start, 2))}s")
-
-if not path.exists(cache_path):
-    os.makedirs(cache_path, exist_ok=True)
-
-pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16)
-pipe.to(device="cuda", dtype=torch.bfloat16)
-unet_state = load_file(hf_hub_download("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors"), device="cuda")
-pipe.unet.load_state_dict(unet_state)
-pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
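For context, the block removed above implemented one-step SDXL sampling: it loads the SDXL base pipeline, grafts ByteDance's Hyper-SDXL 1-step distilled UNet onto it, and switches to an LCMScheduler with trailing timestep spacing. A minimal sketch of how such a setup is invoked, using the values from the removed process_image handler further down (the prompt here is a placeholder):

    # One-step distilled sampling: classifier-free guidance is disabled and a
    # single fixed timestep replaces the usual multi-step denoising schedule.
    image = pipe(
        prompt="a corgi made of transparent glass",  # placeholder prompt
        num_inference_steps=1,
        guidance_scale=0.0,
        timesteps=[800],  # the single timestep used by the removed handler
    ).images[0]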
 with gr.Blocks() as demo:
     gr.Markdown(DESCRIPTION)
-    with gr.Row(equal_height=False):
         with gr.Group():
             with gr.Row():
-                prompt = gr.Text(
-                    label="Prompt",
-                    show_label=False,
-                    max_lines=1,
-                    placeholder="Enter your prompt",
-                    container=False,
-                )
-                run_button = gr.Button("Run", scale=0)
-            result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)
-
-        with gr.Group():
-            with gr.Row(visible=True):
-                seed = gr.Slider(
                     label="Seed",
                     minimum=0,
-                    maximum=99999999,
-                    step=1,
-                    value=0,
-                )
-                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-            with gr.Row(visible=True):
                 width = gr.Slider(
                     label="Width",
                     minimum=256,
-                    maximum=8192,
                     step=32,
-                    value=2048,
                 )
                 height = gr.Slider(
                     label="Height",
                     minimum=256,
-                    maximum=8192,
                     step=32,
-                    value=2048,
                 )
-    gr.Examples(
         examples=examples,
         inputs=prompt,
         outputs=[result, seed],
@@ -88,56 +237,69 @@ with gr.Blocks() as demo:
         cache_examples=CACHE_EXAMPLES,
     )

-    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-        if randomize_seed:
-            seed = random.randint(0, 99999999)
-        return seed
-
-    @spaces.GPU(duration=10)
-    def process_image(height, width, prompt, seed, randomize_seed):
-        global pipe
-        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
-            return pipe(
-                prompt=str,,
-                num_inference_steps=1,
-                guidance_scale=0.,
-                height=int(height),
-                width=int(width),
-                timesteps=[800],
-                randomize_seed: bool = False,
-                use_resolution_binning: bool = True,
-                progress=gr.Progress(track_tqdm=True),
-            ).images
-
-    seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.Generator().manual_seed(seed)
-
-    reactive_controls = [height, width, prompt, seed, randomize_seed]

-    btn.click(process_image, inputs=reactive_controls, outputs=[output])

 if __name__ == "__main__":
-    demo.launch()

-DESCRIPTION = """ # Instant Image
-### Super fast text to Image Generator.
-### <span style='color: red;'>You may increase the steps from 4 to 8 if you are not satisfied with the results.
-### The first image takes longer to generate; subsequent images are faster.
-"""
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

-CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
-
-examples = [
-    "A Monkey with a happy face in the Sahara desert.",
-    "Eiffel Tower was Made up of ICE.",
-    "Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.",
-    "A close-up photo of a woman. She wore a blue coat with a gray dress underneath and has blue eyes.",
-    "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
-    "an astronaut sitting in a diner, eating fries, cinematic, analog film",
-]
+from __future__ import annotations
 import os
+import random
+import uuid

 import gradio as gr
+import spaces
+import numpy as np
+from diffusers import PixArtAlphaPipeline, LCMScheduler, ConsistencyDecoderVAE
 import torch
+from typing import Tuple
+from datetime import datetime
+
+
+DESCRIPTION = """ # Instant Image
+### Super fast text to Image Generator.
+### <span style='color: red;'>You may increase the steps from 4 to 8 if you are not satisfied with the results.
+### The first image takes longer to generate; subsequent images are faster.
+"""
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+MAX_SEED = np.iinfo(np.int32).max
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "3000"))
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+PORT = int(os.getenv("DEMO_PORT", "15432"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+style_list = [
+    {
+        "name": "(No style)",
+        "prompt": "{prompt}",
+        "negative_prompt": "",
+    },
+    {
+        "name": "Cinematic",
+        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
+        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
+    },
+    {
+        "name": "Realistic",
+        "prompt": "Photorealistic {prompt} . Ultra-realistic, professional, 4k, highly detailed",
+        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly, disfigured",
+    },
+    {
+        "name": "Anime",
+        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
+        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
+    },
+    {
+        "name": "Digital Art",
+        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
+        "negative_prompt": "photo, photorealistic, realism, ugly",
+    },
+    {
+        "name": "Pixel art",
+        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
+        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
+    },
+    {
+        "name": "Fantasy art",
+        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
+        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
+    },
+    {
+        "name": "3D Model",
+        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
+        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
+    },
+]
+
+
+styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
+STYLE_NAMES = list(styles.keys())
+DEFAULT_STYLE_NAME = "(No style)"
+NUM_IMAGES_PER_PROMPT = 1
+
+def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
+    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
+    if not negative:
+        negative = ""
+    return p.replace("{prompt}", positive), n + negative
+
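As a quick illustration of apply_style, with hypothetical inputs and outputs abbreviated:

    # The style template wraps the user prompt; the user's negative prompt is
    # appended directly after the style's own negative prompt.
    p, n = apply_style("Cinematic", "a lighthouse at dusk", "blurry")
    # p -> 'cinematic still a lighthouse at dusk . emotional, harmonious, ...'
    # n -> 'anime, cartoon, ..., disfiguredblurry'  (concatenated with no separator)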
+if torch.cuda.is_available():
+
+    pipe = PixArtAlphaPipeline.from_pretrained(
+        "PixArt-alpha/PixArt-LCM-XL-2-1024-MS",
+        torch_dtype=torch.float16,
+        use_safetensors=True,
+    )
+
+    if os.getenv('CONSISTENCY_DECODER', False):
+        print("Using DALL-E 3 Consistency Decoder")
+        pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
+
+    if ENABLE_CPU_OFFLOAD:
+        pipe.enable_model_cpu_offload()
+    else:
+        pipe.to(device)
+        print("Loaded on Device!")
+
+    # speed-up T5
+    pipe.text_encoder.to_bettertransformer()
+
+    if USE_TORCH_COMPILE:
+        pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
+        print("Model Compiled!")
+
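Two implicit dependencies here are easy to miss: pipe.text_encoder.to_bettertransformer() requires the optimum package, and the CONSISTENCY_DECODER branch relies on ConsistencyDecoderVAE being imported from diffusers. Also note that os.getenv('CONSISTENCY_DECODER', False) treats any non-empty value, even "0", as truthy. The toggles are plain environment variables read once at import, so they must be set before the app starts; an illustrative sketch:

    import os

    # Illustrative only: these flags are read when app.py is imported,
    # so they have to be in the environment beforehand.
    os.environ["USE_TORCH_COMPILE"] = "1"    # compile the transformer (slow first call, faster after)
    os.environ["ENABLE_CPU_OFFLOAD"] = "0"   # "1" trades speed for lower VRAM use
    os.environ["CONSISTENCY_DECODER"] = "1"  # "1" swaps in the DALL-E 3 consistency decoder VAE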
+
+
+def save_image(img):
+    unique_name = str(uuid.uuid4()) + ".png"
+    img.save(unique_name)
+    return unique_name
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+@spaces.GPU(duration=30)
+def generate(
+    prompt: str,
+    negative_prompt: str = "",
+    style: str = DEFAULT_STYLE_NAME,
+    use_negative_prompt: bool = False,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    inference_steps: int = 4,
+    randomize_seed: bool = False,
+    use_resolution_binning: bool = True,
+    progress=gr.Progress(track_tqdm=True),
+):
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    generator = torch.Generator().manual_seed(seed)
+
+    if not use_negative_prompt:
+        negative_prompt = None  # type: ignore
+    prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
+
+    images = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        width=width,
+        height=height,
+        guidance_scale=0,
+        num_inference_steps=inference_steps,
+        generator=generator,
+        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
+        use_resolution_binning=use_resolution_binning,
+        output_type="pil",
+    ).images
+
+    image_paths = [save_image(img) for img in images]
+    print(image_paths)
+    return image_paths, seed
+
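Since generate returns the seed it actually used, a result can be reproduced by pinning that seed with randomization off. A minimal sketch of calling the handler directly, assuming the pipeline above is loaded (the spaces.GPU decorator is a no-op outside a ZeroGPU Space):

    # Let the app pick a seed; it is returned alongside the saved image paths.
    paths, used_seed = generate(prompt="an astronaut sitting in a diner", randomize_seed=True)

    # Re-run with the returned seed pinned to reproduce the same image.
    paths_again, _ = generate(prompt="an astronaut sitting in a diner", seed=used_seed, randomize_seed=False)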
+
+examples = [
+    "A Monkey with a happy face in the Sahara desert.",
+    "Eiffel Tower was Made up of ICE.",
+    "Color photo of a corgi made of transparent glass, standing on the riverside in Yosemite National Park.",
+    "A close-up photo of a woman. She wore a blue coat with a gray dress underneath and has blue eyes.",
+    "A litter of golden retriever puppies playing in the snow. Their heads pop out of the snow, covered in.",
+    "an astronaut sitting in a diner, eating fries, cinematic, analog film",
+]
+
 with gr.Blocks() as demo:
     gr.Markdown(DESCRIPTION)
+    with gr.Accordion("Advanced options", open=False):
         with gr.Group():
             with gr.Row():
+                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False, visible=True)
+                negative_prompt = gr.Text(
+                    label="Negative prompt",
+                    max_lines=1,
+                    placeholder="Enter a negative prompt",
+                    visible=True,
+                )
+
+            # num_imgs = gr.Slider(
+            #     label="Num Images",
+            #     minimum=1,
+            #     maximum=8,
+            #     step=1,
+            #     value=1,
+            # )
+            style_selection = gr.Radio(
+                show_label=True,
+                container=True,
+                interactive=True,
+                choices=STYLE_NAMES,
+                value=DEFAULT_STYLE_NAME,
+                label="Image Style",
+            )
+            seed = gr.Slider(
                 label="Seed",
                 minimum=0,
             width = gr.Slider(
                 label="Width",
                 minimum=256,
+                maximum=MAX_IMAGE_SIZE,
                 step=32,
+                value=1024,
             )
             height = gr.Slider(
                 label="Height",
                 minimum=256,
+                maximum=MAX_IMAGE_SIZE,
                 step=32,
+                value=1024,
             )
+        with gr.Row():
+            inference_steps = gr.Slider(
+                label="Steps",
+                minimum=4,
+                maximum=20,
+                step=1,
+                value=4,
+            )
+
+    gr.Examples(
         examples=examples,
         inputs=prompt,
         outputs=[result, seed],
         cache_examples=CACHE_EXAMPLES,
     )

+    use_negative_prompt.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=use_negative_prompt,
+        outputs=negative_prompt,
+        api_name=False,
+    )
+
+    gr.on(
+        triggers=[
+            prompt.submit,
+            negative_prompt.submit,
+            run_button.click,
+        ],
+        fn=generate,
+        inputs=[
+            prompt,
+            negative_prompt,
+            style_selection,
+            use_negative_prompt,
+            # num_imgs,
+            seed,
+            width,
+            height,
+            inference_steps,
+            randomize_seed,
+        ],
+        outputs=[result, seed],
+        api_name="run",
+    )

 if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
+    # demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=11900, debug=True)
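Because the main event binding uses api_name="run", the Space also exposes a programmatic endpoint. A hypothetical client call, with the Space id and the argument order inferred from the inputs list above (treat both as assumptions):

    from gradio_client import Client

    client = Client("KingNish/Instant-Image")  # assumed Space id
    gallery, seed = client.predict(
        "a corgi made of transparent glass",  # prompt
        "",            # negative_prompt
        "(No style)",  # style_selection
        False,         # use_negative_prompt
        0,             # seed
        1024,          # width
        1024,          # height
        4,             # inference_steps
        True,          # randomize_seed
        api_name="/run",
    )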