charanhu committed on
Commit 29383c6
1 Parent(s): 85f71cf

Create app.py

Files changed (1)
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
+ #!/usr/bin/env python
+
+ import os
+ import random
+ import uuid
+
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import spaces
+ import torch
+ from diffusers import DiffusionPipeline
+
+ DESCRIPTION = """# Playground v2.5"""
+ if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
+
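+ # Tunables below are overridable via environment variables.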
+ MAX_SEED = np.iinfo(np.int32).max
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ NUM_IMAGES_PER_PROMPT = 1
+
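+ # Load the fp16 Playground v2.5 weights once at startup; CPU offload and
+ # torch.compile are opt-in via the flags above.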
+ if torch.cuda.is_available():
+     pipe = DiffusionPipeline.from_pretrained(
+         "playgroundai/playground-v2.5-1024px-aesthetic",
+         torch_dtype=torch.float16,
+         use_safetensors=True,
+         add_watermarker=False,
+         variant="fp16",
+     )
+     if ENABLE_CPU_OFFLOAD:
+         pipe.enable_model_cpu_offload()
+     else:
+         pipe.to(device)
+         print("Loaded on Device!")
+
+     if USE_TORCH_COMPILE:
+         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+         print("Model Compiled!")
+
+
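+ # Write each generated PIL image to a unique filename so the Gallery can serve it.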
+ def save_image(img):
+     unique_name = str(uuid.uuid4()) + ".png"
+     img.save(unique_name)
+     return unique_name
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
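+ # On ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration of each
+ # call; on a regular GPU machine it is effectively a no-op.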
+ @spaces.GPU(enable_queue=True)
+ def generate(
+     prompt: str,
+     negative_prompt: str = "",
+     use_negative_prompt: bool = False,
+     seed: int = 0,
+     width: int = 1024,
+     height: int = 1024,
+     guidance_scale: float = 3,
+     randomize_seed: bool = False,
+     use_resolution_binning: bool = True,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     pipe.to(device)
+     seed = int(randomize_seed_fn(seed, randomize_seed))
+     generator = torch.Generator().manual_seed(seed)
+
+     if not use_negative_prompt:
+         negative_prompt = None  # type: ignore
+
+     images = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale,
+         num_inference_steps=25,
+         generator=generator,
+         num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
+         use_resolution_binning=use_resolution_binning,
+         output_type="pil",
+     ).images
+
+     image_paths = [save_image(img) for img in images]
+     print(image_paths)
+     return image_paths, seed
+
+
+ examples = [
+     "neon holography crystal cat",
+     "a cat eating a piece of cheese",
+     "an astronaut riding a horse in space",
+     "a cartoon of a boy playing with a tiger",
+     "a cute robot artist painting on an easel, concept art",
+     "a close up of a woman wearing a transparent, prismatic, elaborate nemes headdress, over the shoulder pose, brown skin-tone",
+ ]
+
+ css = '''
+ .gradio-container{max-width: 560px !important}
+ h1{text-align:center}
+ '''
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+             run_button = gr.Button("Run", scale=0)
+         result = gr.Gallery(label="Result", columns=NUM_IMAGES_PER_PROMPT, show_label=False)
+     with gr.Accordion("Advanced options", open=False):
+         with gr.Row():
+             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 visible=True,
+             )
+         seed = gr.Slider(
+             label="Seed",
+             minimum=0,
+             maximum=MAX_SEED,
+             step=1,
+             value=0,
+         )
+         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+         with gr.Row(visible=True):
+             width = gr.Slider(
+                 label="Width",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+             height = gr.Slider(
+                 label="Height",
+                 minimum=256,
+                 maximum=MAX_IMAGE_SIZE,
+                 step=32,
+                 value=1024,
+             )
+         with gr.Row():
+             guidance_scale = gr.Slider(
+                 label="Guidance Scale",
+                 minimum=0.1,
+                 maximum=20,
+                 step=0.1,
+                 value=3.0,
+             )
+
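+     # Example prompts; generated outputs are cached when CACHE_EXAMPLES is enabled.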
+     gr.Examples(
+         examples=examples,
+         inputs=prompt,
+         outputs=[result, seed],
+         fn=generate,
+         cache_examples=CACHE_EXAMPLES,
+     )
+
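+     # Reveal the negative-prompt textbox only when its checkbox is ticked.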
+     use_negative_prompt.change(
+         fn=lambda x: gr.update(visible=x),
+         inputs=use_negative_prompt,
+         outputs=negative_prompt,
+         api_name=False,
+     )
+
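+     # Route Enter in either textbox and the Run button to generate();
+     # also exposed programmatically as the "run" API endpoint.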
+     gr.on(
+         triggers=[
+             prompt.submit,
+             negative_prompt.submit,
+             run_button.click,
+         ],
+         fn=generate,
+         inputs=[
+             prompt,
+             negative_prompt,
+             use_negative_prompt,
+             seed,
+             width,
+             height,
+             guidance_scale,
+             randomize_seed,
+         ],
+         outputs=[result, seed],
+         api_name="run",
+     )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
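Because the `gr.on` handler above sets `api_name="run"`, a running copy of this Space can also be driven programmatically. A minimal sketch using `gradio_client` (the Space id `charanhu/playground-v2.5` is a placeholder; the positional arguments follow the `gr.on` input list):

from gradio_client import Client

client = Client("charanhu/playground-v2.5")  # placeholder Space id
gallery, seed = client.predict(
    "an astronaut riding a horse in space",  # prompt
    "",      # negative_prompt
    False,   # use_negative_prompt
    0,       # seed (ignored when randomize_seed is True)
    1024,    # width
    1024,    # height
    3.0,     # guidance_scale
    True,    # randomize_seed
    api_name="/run",
)
print(seed, gallery)

The call returns one value per output component: the Gallery contents (saved image paths) and the seed actually used, which is useful for reproducing a result with `randomize_seed` turned off.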