PayPeer commited on
Commit
0e56cc5
β€’
1 Parent(s): 9484df6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -0
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
2
+ import torch
3
+ import os
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex
7
+ except:
8
+ pass
9
+
10
+ from PIL import Image
11
+ import numpy as np
12
+ import gradio as gr
13
+ import psutil
14
+ import time
15
+
16
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
17
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
18
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
19
+ # check if MPS is available OSX only M1/M2/M3 chips
20
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
21
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
22
+ device = torch.device(
23
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
24
+ )
25
+ torch_device = device
26
+ torch_dtype = torch.float16
27
+
28
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
29
+ print(f"TORCH_COMPILE: {TORCH_COMPILE}")
30
+ print(f"device: {device}")
31
+
32
+ if mps_available:
33
+ device = torch.device("mps")
34
+ torch_device = "cpu"
35
+ torch_dtype = torch.float32
36
+
37
+ if SAFETY_CHECKER == "True":
38
+ i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
39
+ "stabilityai/sdxl-turbo",
40
+ torch_dtype=torch_dtype,
41
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
42
+ )
43
+ t2i_pipe = AutoPipelineForText2Image.from_pretrained(
44
+ "stabilityai/sdxl-turbo",
45
+ torch_dtype=torch_dtype,
46
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
47
+ )
48
+ else:
49
+ i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
50
+ "stabilityai/sdxl-turbo",
51
+ safety_checker=None,
52
+ torch_dtype=torch_dtype,
53
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
54
+ )
55
+ t2i_pipe = AutoPipelineForText2Image.from_pretrained(
56
+ "stabilityai/sdxl-turbo",
57
+ safety_checker=None,
58
+ torch_dtype=torch_dtype,
59
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
60
+ )
61
+
62
+
63
+ t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
64
+ t2i_pipe.set_progress_bar_config(disable=True)
65
+ i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
66
+ i2i_pipe.set_progress_bar_config(disable=True)
67
+
68
+
69
+ def resize_crop(image, size=512):
70
+ image = image.convert("RGB")
71
+ w, h = image.size
72
+ image = image.resize((size, int(size * (h / w))), Image.BICUBIC)
73
+ return image
74
+
75
+
76
+ async def predict(init_image, prompt, strength, steps, seed=1231231):
77
+ if init_image is not None:
78
+ init_image = resize_crop(init_image)
79
+ generator = torch.manual_seed(seed)
80
+ last_time = time.time()
81
+ results = i2i_pipe(
82
+ prompt=prompt,
83
+ image=init_image,
84
+ generator=generator,
85
+ num_inference_steps=steps,
86
+ guidance_scale=0.0,
87
+ strength=strength,
88
+ width=512,
89
+ height=512,
90
+ output_type="pil",
91
+ )
92
+ else:
93
+ generator = torch.manual_seed(seed)
94
+ last_time = time.time()
95
+ results = t2i_pipe(
96
+ prompt=prompt,
97
+ generator=generator,
98
+ num_inference_steps=steps,
99
+ guidance_scale=0.0,
100
+ width=512,
101
+ height=512,
102
+ output_type="pil",
103
+ )
104
+ print(f"Pipe took {time.time() - last_time} seconds")
105
+ nsfw_content_detected = (
106
+ results.nsfw_content_detected[0]
107
+ if "nsfw_content_detected" in results
108
+ else False
109
+ )
110
+ if nsfw_content_detected:
111
+ gr.Warning("NSFW content detected.")
112
+ return Image.new("RGB", (512, 512))
113
+ return results.images[0]
114
+
115
+
116
+ css = """
117
+ #container{
118
+ margin: 0 auto;
119
+ max-width: 80rem;
120
+ }
121
+ #intro{
122
+ max-width: 100%;
123
+ text-align: center;
124
+ margin: 0 auto;
125
+ }
126
+ """
127
+ with gr.Blocks(css=css) as demo:
128
+ init_image_state = gr.State()
129
+ with gr.Column(elem_id="container"):
130
+ gr.Markdown(
131
+ """# SDXL Turbo Image to Image/Text to Image
132
+ ## Unofficial Demo
133
+ SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
134
+ **Model**: https://huggingface.co/stabilityai/sdxl-turbo
135
+ """,
136
+ elem_id="intro",
137
+ )
138
+ with gr.Row():
139
+ prompt = gr.Textbox(
140
+ placeholder="Insert your prompt here:",
141
+ scale=5,
142
+ container=False,
143
+ )
144
+ generate_bt = gr.Button("Generate", scale=1)
145
+ with gr.Row():
146
+ with gr.Column():
147
+ image_input = gr.Image(
148
+ sources=["upload", "webcam", "clipboard"],
149
+ label="Webcam",
150
+ type="pil",
151
+ )
152
+ with gr.Column():
153
+ image = gr.Image(type="filepath")
154
+ with gr.Accordion("Advanced options", open=False):
155
+ strength = gr.Slider(
156
+ label="Strength",
157
+ value=0.7,
158
+ minimum=0.0,
159
+ maximum=1.0,
160
+ step=0.001,
161
+ )
162
+ steps = gr.Slider(
163
+ label="Steps", value=2, minimum=1, maximum=10, step=1
164
+ )
165
+ seed = gr.Slider(
166
+ randomize=True,
167
+ minimum=0,
168
+ maximum=12013012031030,
169
+ label="Seed",
170
+ step=1,
171
+ )
172
+
173
+ with gr.Accordion("Run with diffusers"):
174
+ gr.Markdown(
175
+ """## Running SDXL Turbo with `diffusers`
176
+ ```bash
177
+ pip install diffusers==0.23.1
178
+ ```
179
+ ```py
180
+ from diffusers import DiffusionPipeline
181
+
182
+ pipe = DiffusionPipeline.from_pretrained(
183
+ "stabilityai/sdxl-turbo"
184
+ ).to("cuda")
185
+ results = pipe(
186
+ prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
187
+ num_inference_steps=1,
188
+ guidance_scale=0.0,
189
+ )
190
+ imga = results.images[0]
191
+ imga.save("image.png")
192
+ ```
193
+ """
194
+ )
195
+
196
+ inputs = [image_input, prompt, strength, steps, seed]
197
+ generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
198
+ prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
199
+ steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
200
+ seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
201
+ strength.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
202
+ image_input.change(
203
+ fn=lambda x: x,
204
+ inputs=image_input,
205
+ outputs=init_image_state,
206
+ show_progress=False,
207
+ queue=False,
208
+ )
209
+
210
+ demo.queue()
211
+ demo.launch()