Akjava commited on
Commit
0c03a70
·
1 Parent(s): ca68f6d

support donut

Browse files
Files changed (1) hide show
  1. app.py +36 -12
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from diffusers import FluxInpaintPipeline
4
  import gradio as gr
5
  import re
6
- from PIL import Image
7
 
8
  import os
9
  import numpy as np
@@ -28,6 +28,16 @@ def adjust_to_multiple_of_32(width: int, height: int):
28
  height = height - (height % 32)
29
  return width, height
30
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  dtype = torch.bfloat16
@@ -47,7 +57,7 @@ def sanitize_prompt(prompt):
47
  return sanitized_prompt
48
 
49
  @spaces.GPU(duration=120)
50
- def process_images(image, image2=None,prompt="a girl",inpaint_model="black-forest-labs/FLUX.1-schnell",strength=0.75,seed=0,progress=gr.Progress(track_tqdm=True)):
51
  # I'm not sure when this happen
52
  progress(0, desc="start-process-images")
53
  #print("start-process-images")
@@ -78,9 +88,9 @@ def process_images(image, image2=None,prompt="a girl",inpaint_model="black-fores
78
  generator = torch.Generator("cuda").manual_seed(seed)
79
  generators.append(generator)
80
 
81
- width,height = convert_to_fit_size(image.size)
82
  #print(f"fit {width}x{height}")
83
- width,height = adjust_to_multiple_of_32(width,height)
84
  #print(f"multiple {width}x{height}")
85
  image = image.resize((width, height), Image.LANCZOS)
86
  mask_image = mask_image.resize((width, height), Image.NEAREST)
@@ -89,11 +99,24 @@ def process_images(image, image2=None,prompt="a girl",inpaint_model="black-fores
89
  output = pipe(prompt=prompt, image=image, mask_image=mask_image,generator=generator,strength=strength,width=width,height=height,
90
  guidance_scale=0,num_inference_steps=num_inference_steps,max_sequence_length=256)
91
 
92
- return output.images[0],mask_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
-
95
- output,mask_image = process_inpaint(image["background"],mask,prompt,inpaint_model,strength,seed)
96
-
97
  return output,mask_image
98
 
99
 
@@ -138,9 +161,10 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
138
  with gr.Column():
139
  image = gr.ImageEditor(height=800,sources=['upload','clipboard'],transforms=[],image_mode='RGB', layers=False, elem_id="image_upload", type="pil", label="Upload",brush=gr.Brush(colors=["#fff"], color_mode="fixed"))
140
  with gr.Row(elem_id="prompt-container", equal_height=False):
141
- with gr.Row():
142
- prompt = gr.Textbox(label="Prompt",value="a person",placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
143
-
 
144
  btn = gr.Button("Inpaint", elem_id="run_button",variant="primary")
145
 
146
  image_mask = gr.Image(sources=['upload','clipboard'], elem_id="mask_upload", type="pil", label="Mask_Upload",height=400, value=None)
@@ -159,7 +183,7 @@ with gr.Blocks(css=css, elem_id="demo-container") as demo:
159
 
160
 
161
 
162
- btn.click(fn=process_images, inputs=[image, image_mask,prompt,inpaint_model,strength,seed], outputs =[image_out,mask_out], api_name='infer')
163
  gr.Examples(
164
  examples=[
165
  ["examples/00538245.jpg", "examples/normal_mouth_mask.jpg","a beautiful girl,big-smile",0.75,"examples/normal_mouth_mask_result.jpg"],
 
3
  from diffusers import FluxInpaintPipeline
4
  import gradio as gr
5
  import re
6
+ from PIL import Image,ImageFilter
7
 
8
  import os
9
  import numpy as np
 
28
  height = height - (height % 32)
29
  return width, height
30
 
31
+ def mask_to_donut(mask,size):
32
+ if size%2 ==0:
33
+ size+=1
34
+ dilation_mask = mask.filter(ImageFilter.MaxFilter(size))
35
+
36
+ white_img = Image.new('RGB', mask.size, (255,255,255))
37
+ black_img = Image.new('RGB', mask.size, (0,0,0))
38
+ white_img.paste(black_img,(0,0),dilation_mask.convert("L"))
39
+ white_img.paste(mask,(0,0),mask.convert("L"))
40
+ return white_img
41
 
42
 
43
  dtype = torch.bfloat16
 
57
  return sanitized_prompt
58
 
59
  @spaces.GPU(duration=120)
60
+ def process_images(image, image2=None,prompt="a girl",inpaint_model="black-forest-labs/FLUX.1-schnell",strength=0.75,seed=0,donut_mask=True,donut_size=32,progress=gr.Progress(track_tqdm=True)):
61
  # I'm not sure when this happen
62
  progress(0, desc="start-process-images")
63
  #print("start-process-images")
 
88
  generator = torch.Generator("cuda").manual_seed(seed)
89
  generators.append(generator)
90
 
91
+ fit_width,fit_height = convert_to_fit_size(image.size)
92
  #print(f"fit {width}x{height}")
93
+ width,height = adjust_to_multiple_of_32(fit_width,fit_height)
94
  #print(f"multiple {width}x{height}")
95
  image = image.resize((width, height), Image.LANCZOS)
96
  mask_image = mask_image.resize((width, height), Image.NEAREST)
 
99
  output = pipe(prompt=prompt, image=image, mask_image=mask_image,generator=generator,strength=strength,width=width,height=height,
100
  guidance_scale=0,num_inference_steps=num_inference_steps,max_sequence_length=256)
101
 
102
+ return output.images[0],mask_image,image,fit_width,fit_height
103
+
104
+ if donut_mask:
105
+ original_mask = mask
106
+ mask = mask_to_donut(mask,donut_size)
107
+
108
+ #output,mask_image,image_resized,fit_width,fit_height=image["background"],mask,image["background"],512,512
109
+ output,mask_image,image_resized,fit_width,fit_height = process_inpaint(image["background"],mask,prompt,inpaint_model,strength,seed)
110
+
111
+ if donut_mask:
112
+ mask = original_mask.resize(mask_image.size)
113
+ image_resized.paste(output,(0,0),mask.convert("L"))
114
+ output = image_resized.resize((fit_width,fit_height),Image.LANCZOS)
115
+ mask_image = mask.resize(output.size)
116
+ else:
117
+ output = output.resize((fit_width,fit_height),Image.LANCZOS)
118
+ mask_image = mask_image.resize(output.size)
119
 
 
 
 
120
  return output,mask_image
121
 
122
 
 
161
  with gr.Column():
162
  image = gr.ImageEditor(height=800,sources=['upload','clipboard'],transforms=[],image_mode='RGB', layers=False, elem_id="image_upload", type="pil", label="Upload",brush=gr.Brush(colors=["#fff"], color_mode="fixed"))
163
  with gr.Row(elem_id="prompt-container", equal_height=False):
164
+ prompt = gr.Textbox(label="Prompt",value="a person",placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
165
+ with gr.Row(equal_height=True):
166
+ donut_mask = gr.Checkbox(label="Donut Mask",value=False,info="Usually improve result,but slow.Do second example things")
167
+ donut_size = gr.Slider(label="Donut Size",minimum=1,maximum=64,step=1,value=32,info="Larger value make extreamly slow")
168
  btn = gr.Button("Inpaint", elem_id="run_button",variant="primary")
169
 
170
  image_mask = gr.Image(sources=['upload','clipboard'], elem_id="mask_upload", type="pil", label="Mask_Upload",height=400, value=None)
 
183
 
184
 
185
 
186
+ btn.click(fn=process_images, inputs=[image, image_mask,prompt,inpaint_model,strength,seed,donut_mask,donut_size], outputs =[image_out,mask_out], api_name='infer')
187
  gr.Examples(
188
  examples=[
189
  ["examples/00538245.jpg", "examples/normal_mouth_mask.jpg","a beautiful girl,big-smile",0.75,"examples/normal_mouth_mask_result.jpg"],