jozee commited on
Commit
064897a
·
verified ·
1 Parent(s): 46a9673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -0
app.py CHANGED
@@ -12,6 +12,234 @@ from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  with gr.Blocks() as demo:
16
  with gr.Column():
17
  with gr.Row():
@@ -27,7 +255,128 @@ with gr.Blocks() as demo:
27
  with gr.Column(scale=1):
28
  run_button = gr.Button('Generate')
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  with gr.Column():
31
  result = gr.Image(label="Generate Image", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
 
 
 
 
 
 
 
33
  demo.queue(max_size=12).launch(share=False)
 
12
  from PIL import Image, ImageDraw
13
  import numpy as np
14
 
15
+ config_file = hf_hub_download(
16
+ "xinsir/controlnet-union-sdxl-1.0",
17
+ filename="config_promax.json",
18
+ )
19
+
20
+ config = ControlNetModdel_Union.load_config(config_file)
21
+ controlnet_model = ControlNetModel_Union.from_config(config)
22
+ model_file = hf_hub_download(
23
+ "xinsir/controlnet-union-sdxl-1.0",
24
+ filename="diffusion_pytorch_model_promax.safetensors",
25
+ )
26
+ state_dict = load_state_dict(model_file)
27
+ model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
28
+ controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
29
+ )
30
+ model.to(device="cuda", dtype=torch.float16)
31
+
32
+ vae = AutoencoderKL.from_pretrained(
33
+ "madebbyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
34
+ ).to("cuda")
35
+
36
+ pipe = StableDiffusionXLFillPipeline.from_pretrained(
37
+ "SG161222/RealVisXL_V5.0_Lightning",
38
+ torch_dtype=torch.float16,
39
+ vae=vae,
40
+ controlnet=model,
41
+ variant="fp16",
42
+ ).to("cuda")
43
+
44
+ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
+
46
+ def can_expand(source_width, source_height, target_width, target_height, alignment):
47
+ """Checks if the image can be expanded based on the alignment."""
48
+ if alignment in ("Left", "Right") and source_width>=target_width:
49
+ return False
50
+ if alignment in ("Top","Bottom") and source_height>=target_height:
51
+ return False
52
+ return True
53
+
54
+ def prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
55
+ target_size = (width, height)
56
+
57
+ #Calculate the scaling factor to fit the image within the target size
58
+ scale_factor = min(target_size[0]/image.width, target_size[1]/image.height)
59
+ new_width = int(image.width * scale_factor)
60
+ new_height = int(image.height * scale_factor)
61
+
62
+ #Resize the source image to fit within target size
63
+ source = image.resize((new_width, new_height), Image.LANCZOS)
64
+
65
+ #Apply resize option using percentages
66
+ if resize_option == "Full":
67
+ resize_percentage = 100
68
+ elif resize_option == "50%":
69
+ resize_percentage = 50
70
+ elif resize_option == "33%":
71
+ resize_percentage = 33
72
+ elif resize_option == "25%":
73
+ resize_option = 25
74
+ else:
75
+ resize_percentage = custom_resize_percentage
76
+
77
+ #calculate new dimensions based on percentage
78
+ resize_factor = resize_percentage/100
79
+ new_width = int(source.width * resize_factor)
80
+ new_height = int(source.height * resize_factor)
81
+
82
+ #Ensure minimum size of 64 pixels
83
+ new_width = max(new_width, 64)
84
+ new_height = max(new_height, 64)
85
+
86
+ #Resize the image
87
+ source = source.resize((new_width, new_height), Image.LANCZOS)
88
+
89
+ #Calculate the overlap in pixels based on the percentage
90
+ overlap_x = int(new_width * (overlap_percentage/100))
91
+ overlap_y = int(new_height * (overlap_percentage/100))
92
+
93
+ #Ensure minimum overlap of 1 pixel
94
+ overlap_x = max(overlap_x, 1)
95
+ overlap_y = max(overlap_y, 1)
96
+
97
+ #Calculate margins based on alignment
98
+ if alignment == "Middle":
99
+ margin_x = (target_size[0]-new_width)//2
100
+ margin_y = (target_size[1]-new_height)//2
101
+ elif alignment == "Left":
102
+ margin_x = 0
103
+ margin_y = (target_size[1]-new_height)//2
104
+ elif alignment == "Right":
105
+ margin_x = target_size[0] - new_width
106
+ margin_y = (target_size[1]-new_height)//2
107
+ elif alignment == "Top":
108
+ margin_x = (target_size[0]-new_width)//2
109
+ margin_y = 0
110
+ elif alignment == "Bottom":
111
+ margin_x = (target_size[0]-new_width)//2
112
+ margin_y = target_size[1] - new_height
113
+
114
+ #adjust margins to eliminate gaps
115
+ margin_x = max(0, min(margin_x, target_size[0]-new_width))
116
+ margin_y = max(0, min(margin_y, target_size[1]-new_height))
117
+
118
+ #Create a new background image and paste the resized source image
119
+ background = Image.new('RGB', target_size, (255,255,255))
120
+ background.paste(source, (margin_x, margin_y))
121
+
122
+ #Create the mask
123
+ mask = Image.new('L', target_size, 255)
124
+ mask_draw = ImageDraw.Draw(mask)
125
+
126
+ #Calculate overlap areas
127
+ white_gaps_patch = 2
128
+
129
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x+white_gaps_patch
130
+ right_overlap = margin_x + new_width-overlap_x if overlap_right else margin_x+new_width-white_gaps_patch
131
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y + white_gaps_patch
132
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y+new_height-white_gaps_patch
133
+
134
+ if alignment == "Left":
135
+ left_overlap = margin_x + overlap_x if overlap_left else margin_x
136
+ elif alignment == "Right":
137
+ right_overlap = margin_x + new_width - overlap_x if overlap_right else margin_x + new_width
138
+ elif alignment == "Top":
139
+ top_overlap = margin_y + overlap_y if overlap_top else margin_y
140
+ elif alignment == "Bottom":
141
+ bottom_overlap = margin_y + new_height - overlap_y if overlap_bottom else margin_y + new_height
142
+
143
+ #Draw the mask
144
+ mask_draw.rectangle([
145
+ (left_overlap, top_overlap),
146
+ (right_overlap, bottom_overlap)
147
+ ], fill=0)
148
+
149
+ return background, mask
150
+
151
+ def preview_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
152
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
153
+
154
+ #Create a preview image showing the mask
155
+ preview = background.copy().convert('RGBA')
156
+
157
+ #Create a semi-transparent red overlay
158
+ red_overlay = Image.new('RGBA', background.size, (255, 0, 0, 64)) #Reduced alpha to 64(25% opacity)
159
+
160
+ #Convert black pixels in the mask to semi-transparent red
161
+ red_mask = Image.new('RGBA', background.size, (0,0,0,0))
162
+ red_mask.paste(red_overlay, (0,0), mask)
163
+
164
+ #Overlay the red mask on the background
165
+ preview = Image.alpha_composite(preview, red_mask)
166
+
167
+ return preview
168
+
169
+ @spaces.GPU(duration=24)
170
+ def infer(image, width, height, overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage, prompt_input, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom):
171
+ background, mask = prepare_image_and_mask(image, width, height, overlap_percentage, resize_option, custom_resize_percentage, alignment, overlap_left, overlap_right, overlap_top, overlap_bottom)
172
+
173
+ if not can_expand(background.width, background.height, width, height, alignment):
174
+ alignment = "Middle"
175
+
176
+ cnet_image = background.copy()
177
+ cnet_image.paste(0, (0,0), mask)
178
+
179
+ final_prompt = f"{prompt_input}, high quality, 4k"
180
+
181
+ (
182
+ prompt_embeds,
183
+ negative_prompt_embeds,
184
+ pooled_prompt_embeds,
185
+ negative_pooled_prompt_embeds,
186
+ ) = pipe.encode_prompt(final_prompt, "cuda", True)
187
+
188
+ for image in pipe(
189
+ prompt_embeds = prompt_embeds,
190
+ negative_prompt_embeds = negative_prompt_embeds,
191
+ pooled_prompt_embeds = pooled_prompt_embeds,
192
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds,
193
+ image = cnet_image,
194
+ num_inference_steps=num_inference_steps
195
+ ):
196
+ yield cnet_image, image
197
+
198
+ image = image.convert('RGBA')
199
+ cnet_image.paste(image, (0,0), mask)
200
+
201
+ yield background, cnet_image
202
+
203
+ def clear_result():
204
+ """Clears the result ImageSlider."""
205
+ return gr.update(value=None)
206
+
207
+ def preload_presets(target_ratio, ui_width, ui_height):
208
+ """Updates the width and height sliders based on the selected aspect ratio."""
209
+ if target_ratio == "9:16":
210
+ changed_width = 720
211
+ changed_height = 1280
212
+ return changed_width, changed_height, gr.update()
213
+ elif target_ratio == "16:9":
214
+ changed_width = 1280
215
+ changed_height = 720
216
+ return changed_width, changed_height, gr.update()
217
+ elif target_ratio == "1:1":
218
+ changed_width = 1024
219
+ changed_height = 1024
220
+ return ui_width, ui_height, gr.update(open=True)
221
+
222
+ def select_the_right_preset(user_width, user_height):
223
+ if user_width == 720 and user_height == 1280:
224
+ return "9:16"
225
+ elif user_width == 1280 and user_height == 720:
226
+ return "16:9"
227
+ elif user_width == 1024 and user_height == 1024:
228
+ return "1:1"
229
+ else:
230
+ return "Custom"
231
+
232
+ def toggle_custom_resize_slider(resize_option):
233
+ return gr.update(visible=(resize_option=="Custom"))
234
+
235
+ def update_history(new_image, history):
236
+ """Updates the history gallery with the new image."""
237
+ if history is None:
238
+ history = []
239
+ history.insert(0, new_image)
240
+ return history
241
+
242
+
243
  with gr.Blocks() as demo:
244
  with gr.Column():
245
  with gr.Row():
 
255
  with gr.Column(scale=1):
256
  run_button = gr.Button('Generate')
257
 
258
+ with gr.Row():
259
+ target_ratio = gr.Ratio(
260
+ label="Expected Ratio",
261
+ choices=["9:16", "16:9", "1:1", "Custom"],
262
+ value="9:16",
263
+ scale=2
264
+ )
265
+
266
+ alignment_dropdown = gr.Dropdown(
267
+ choices=['Middle','Left','Right','Top','Bottom'],
268
+ value='Middle',
269
+ label='Alignment'
270
+ )
271
+ #高级配置,当选择custom的时候会自动打开
272
+ with gr.Accordion(label="Advanced settings", open=False) as settings_panel:
273
+ with gr.Column():
274
+ #自定义的宽高
275
+ with gr.Row():
276
+ width_slider = gr.Slider(
277
+ label="Target Width",
278
+ minimum=720,
279
+ maximum=1536,
280
+ step=8,
281
+ value=720, #Set a default value
282
+ )
283
+ height_slider = gr.Slider(
284
+ label="Target Height",
285
+ minimum=720,
286
+ maximum=1536,
287
+ step=8,
288
+ value=1280, #Set a default value
289
+ )
290
+ #生成步数
291
+ num_inference_steps = gr.Slider(label="Steps", minimum=4, maximum=12, step=1, value=8)
292
+ #组件组
293
+ with gr.Group():
294
+ overlap_percentage = gr.Slider(
295
+ label="Mask overlap (%)",
296
+ minimum=1,
297
+ maximum=50,
298
+ value=10,
299
+ step=1
300
+ )
301
+ with gr.Row():
302
+ overlap_top = gr.Checkbox(label="Overlap Top", value=True)
303
+ overlap_right = gr.Checkbox(label="Overlap Right", value=True)
304
+ with gr.Row():
305
+ overlap_left = gr.Checkbox(label="Overlap Left", value=True)
306
+ overlap_bottom = gr.Checkbox(label="Overlap Bottom", value=True)
307
+ with gr.Row():
308
+ resize_option = gr.Radio(
309
+ label = "Resize input image",
310
+ choices = ["Full", "50%", "33%", "25%", "Custom"],
311
+ value="Full"
312
+ )
313
+ custom_resize_percentage = gr.Slider(
314
+ label="Custom resize (%)",
315
+ minimum = 1,
316
+ maximum = 100,
317
+ step = 1,
318
+ value = 50,
319
+ visible = False
320
+ )
321
+
322
+ with gr.Column():
323
+ preview_button = gr.Button("Preview alignment and mask")
324
+
325
  with gr.Column():
326
  result = gr.Image(label="Generate Image", interactive=False)
327
+ history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)
328
+ preview_image = gr.Image(label="Preview")
329
+
330
+ target_ratio.change(
331
+ fn=preload_presets, #选择ratio aspect 的单选框时,调用这个函数
332
+ inputs=[target_ratio, width_slider, height_slider],
333
+ outputs=[width_slider, height_slider, settings_panel],
334
+ queue=False
335
+ )
336
+
337
+ width_slider.change(
338
+ fn=select_the_right_preset,
339
+ inputs=[width_slider, height_slider],
340
+ outputs=[target_ratio],
341
+ queue=False
342
+ )
343
+
344
+ height_slider.change(
345
+ fn=select_the_right_preset,
346
+ inputs=[width_slider, height_slider],
347
+ outputs=[target_ratio],
348
+ queue=False
349
+ )
350
+
351
+ resize_option.change(
352
+ fn=toggle_custom_resize_slider,
353
+ inputs=[resize_option],
354
+ outputs=[custom_resize_percentage],
355
+ queue=False
356
+ )
357
+
358
+ run_button.click(#Clear the result
359
+ fn=clear_result,
360
+ inputs=None,
361
+ outputs=result,
362
+ ).then( #Generate the new image
363
+ fn=infer,
364
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, num_inference_steps,
365
+ resize_option, custom_resize_percentage, prompt_input, alignment_dropdown,
366
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
367
+ outputs=result,
368
+ ).then(#update the history gallery
369
+ fn=lambda x, history: update_history(x[1], history),
370
+ inputs=[result, history_gallery],
371
+ outputs=history_gallery,
372
+ )
373
 
374
+ preview_button.click(
375
+ fn=preview_image_and_mask,
376
+ inputs=[input_image, width_slider, height_slider, overlap_percentage, resize_option, custom_resize_percentage, alignment_dropdown,
377
+ overlap_left, overlap_right, overlap_top, overlap_bottom],
378
+ outputs=preview_image,
379
+ queue=False
380
+ )
381
+
382
  demo.queue(max_size=12).launch(share=False)