rishh76 committed on
Commit f41d814 · verified · 1 Parent(s): aacbc56

Update app.py

Files changed (1):
  app.py +200 -107
app.py CHANGED
@@ -1,154 +1,247 @@
- from typing import Tuple, Dict
  import requests
  import random
  import numpy as np
  import gradio as gr
  import torch
- from PIL import Image
- from diffusers import StableDiffusionInpaintPipeline
-
- INFO = """
- # FLUX-Based Inpainting 🎨
-
- This interface utilizes a FLUX model variant for precise inpainting. Special thanks to the [Black Forest Labs](https://huggingface.co/black-forest-labs) team
- and [Gothos](https://github.com/Gothos) for contributing to this advanced solution.
  """

- # Constants
- MAX_SEED_VALUE = np.iinfo(np.int32).max
- TARGET_DIM = 1024
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

- # Function to clear background
- def clear_background(image: Image.Image, threshold: int = 50) -> Image.Image:
      image = image.convert("RGBA")
-     pixels = image.getdata()
-     processed_data = [
-         (0, 0, 0, 0) if sum(pixel[:3]) / 3 < threshold else pixel for pixel in pixels
-     ]
-     image.putdata(processed_data)
      return image

- # Sample data examples
  EXAMPLES = [
      [
          {
-             "background": Image.open(requests.get("https://example.com/doge-1.png", stream=True).raw),
-             "layers": [clear_background(Image.open(requests.get("https://example.com/mask-1.png", stream=True).raw))],
-             "composite": Image.open(requests.get("https://example.com/composite-1.png", stream=True).raw),
          },
-         "desert mirage",
          42,
          False,
-         0.75,
-         25
      ],
      [
          {
-             "background": Image.open(requests.get("https://example.com/doge-2.png", stream=True).raw),
-             "layers": [clear_background(Image.open(requests.get("https://example.com/mask-2.png", stream=True).raw))],
-             "composite": Image.open(requests.get("https://example.com/composite-2.png", stream=True).raw),
          },
-         "neon city",
-         100,
-         True,
-         0.9,
-         35
      ]
  ]

- # Load model
- inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
      "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)

- # Utility to adjust image size
- def get_scaled_dimensions(
-     original_size: Tuple[int, int], max_dim: int = TARGET_DIM
  ) -> Tuple[int, int]:
-     width, height = original_size
-     scaling_factor = max_dim / max(width, height)
-     return (int(width * scaling_factor) // 32 * 32, int(height * scaling_factor) // 32 * 32)

  @spaces.GPU(duration=100)
- def generate_inpainting(
-     input_data: Dict,
-     prompt_text: str,
-     chosen_seed: int,
-     use_random_seed: bool,
-     inpainting_strength: float,
-     steps: int,
      progress=gr.Progress(track_tqdm=True)
  ):
-     if not prompt_text:
-         return gr.Info("Provide a prompt to proceed."), None

-     background = input_data.get("background")
-     mask_layer = input_data.get("layers")[0]

-     if not background:
-         return gr.Info("Background image is missing."), None

-     if not mask_layer:
-         return gr.Info("Mask layer is missing."), None

-     new_width, new_height = get_scaled_dimensions(background.size)
-     resized_background = background.resize((new_width, new_height), Image.LANCZOS)
-     resized_mask = mask_layer.resize((new_width, new_height), Image.LANCZOS)

-     if use_random_seed:
-         chosen_seed = random.randint(0, MAX_SEED_VALUE)
-
-     torch.manual_seed(chosen_seed)
-     generated_image = inpainting_pipeline(
-         prompt=prompt_text,
-         image=resized_background,
          mask_image=resized_mask,
-         strength=inpainting_strength,
-         num_inference_steps=steps,
      ).images[0]

-     return generated_image, resized_mask
-
- # Build the Gradio interface
- with gr.Blocks() as flux_app:
-     gr.Markdown(INFO)

      with gr.Row():
          with gr.Column():
-             image_editor = gr.ImageEditor(
-                 label="Edit Image",
-                 type="pil",
                  sources=["upload", "webcam"],
-                 brush=gr.Brush(colors=["#FFF"], color_mode="fixed")
-             )
-
-             prompt_box = gr.Text(
-                 label="Inpainting Prompt", placeholder="Describe the change you'd like."
-             )
-             run_button = gr.Button(value="Run Inpainting")
-
-             with gr.Accordion("Settings"):
-                 seed_slider = gr.Slider(0, MAX_SEED_VALUE, step=1, value=42, label="Seed")
-                 random_seed_toggle = gr.Checkbox(label="Randomize Seed", value=True)
-                 inpainting_strength_slider = gr.Slider(0.0, 1.0, step=0.01, value=0.85, label="Inpainting Strength")
-                 steps_slider = gr.Slider(1, 50, step=1, value=25, label="Inference Steps")
-
          with gr.Column():
-             output_image = gr.Image(label="Output Image")
-             output_mask = gr.Image(label="Processed Mask")
-
-     run_button.click(
-         generate_inpainting,
-         inputs=[image_editor, prompt_box, seed_slider, random_seed_toggle, inpainting_strength_slider, steps_slider],
-         outputs=[output_image, output_mask]
-     )
-
-     gr.Examples(
-         examples=EXAMPLES,
-         fn=generate_inpainting,
-         inputs=[image_editor, prompt_box, seed_slider, random_seed_toggle, inpainting_strength_slider, steps_slider],
-         outputs=[output_image, output_mask],
-         run_on_click=True,
      )

- flux_app.launch(debug=False, show_error=True)

+ from io import BytesIO  # load_image below decodes HTTP responses via BytesIO
+ from typing import Tuple
  import requests
  import random
  import numpy as np
  import gradio as gr
+ import spaces
  import torch
+ from PIL import Image, UnidentifiedImageError
+ from diffusers import FluxInpaintPipeline
+
+ MARKDOWN = """
+ # FLUX.1 Inpainting 🔥
+ Shoutout to the [Black Forest Labs](https://huggingface.co/black-forest-labs) team for
+ creating this amazing model, and a big thanks to [Gothos](https://github.com/Gothos)
+ for taking it to the next level by enabling inpainting with FLUX.
  """

+ MAX_SEED = np.iinfo(np.int32).max
+ IMAGE_SIZE = 1024
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

+
+ def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
      image = image.convert("RGBA")
+     data = image.getdata()
+     new_data = []
+     for item in data:
+         avg = sum(item[:3]) / 3
+         if avg < threshold:
+             new_data.append((0, 0, 0, 0))
+         else:
+             new_data.append(item)
+
+     image.putdata(new_data)
      return image

+
+ def load_image(url: str) -> Image.Image:
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()  # Raise an HTTPError for bad responses
+         image = Image.open(BytesIO(response.content))
+         return image
+     except requests.HTTPError as http_err:
+         print(f"HTTP error occurred: {http_err}")
+         return None
+     except UnidentifiedImageError:
+         print("Cannot identify image file")
+         return None
+     except Exception as err:
+         print(f"Other error occurred: {err}")
+         return None
+
+
  EXAMPLES = [
      [
          {
+             "background": load_image("https://media.roboflow.com/spaces/doge-2-image.png"),
+             "layers": [remove_background(load_image("https://media.roboflow.com/spaces/doge-2-mask-2.png"))],
+             "composite": load_image("https://media.roboflow.com/spaces/doge-2-composite-2.png"),
          },
+         "little lion",
          42,
          False,
+         0.85,
+         30
      ],
      [
          {
+             "background": load_image("https://media.roboflow.com/spaces/doge-2-image.png"),
+             "layers": [remove_background(load_image("https://media.roboflow.com/spaces/doge-2-mask-3.png"))],
+             "composite": load_image("https://media.roboflow.com/spaces/doge-2-composite-3.png"),
          },
+         "tribal tattoos",
+         42,
+         False,
+         0.85,
+         30
      ]
  ]

+ pipe = FluxInpaintPipeline.from_pretrained(
      "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)

+
+ def resize_image_dimensions(
+     original_resolution_wh: Tuple[int, int],
+     maximum_dimension: int = IMAGE_SIZE
  ) -> Tuple[int, int]:
+     width, height = original_resolution_wh
+
+     if width > height:
+         scaling_factor = maximum_dimension / width
+     else:
+         scaling_factor = maximum_dimension / height
+
+     new_width = int(width * scaling_factor)
+     new_height = int(height * scaling_factor)
+
+     new_width = new_width - (new_width % 32)
+     new_height = new_height - (new_height % 32)
+
+     return new_width, new_height
+

  @spaces.GPU(duration=100)
+ def process(
+     input_image_editor: dict,
+     input_text: str,
+     seed_slicer: int,
+     randomize_seed_checkbox: bool,
+     strength_slider: float,
+     num_inference_steps_slider: int,
      progress=gr.Progress(track_tqdm=True)
  ):
+     if not input_text:
+         gr.Info("Please enter a text prompt.")
+         return None, None

+     image = input_image_editor.get('background')
+     mask = input_image_editor.get('layers', [None])[0]

+     if not image:
+         gr.Info("Please upload an image.")
+         return None, None

+     if not mask:
+         gr.Info("Please draw a mask on the image.")
+         return None, None

+     width, height = resize_image_dimensions(original_resolution_wh=image.size)
+     resized_image = image.resize((width, height), Image.LANCZOS)
+     resized_mask = mask.resize((width, height), Image.LANCZOS)

+     if randomize_seed_checkbox:
+         seed_slicer = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed_slicer)
+     result = pipe(
+         prompt=input_text,
+         image=resized_image,
          mask_image=resized_mask,
+         width=width,
+         height=height,
+         strength=strength_slider,
+         generator=generator,
+         num_inference_steps=num_inference_steps_slider
      ).images[0]
+
+     return result, resized_mask


+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
      with gr.Row():
          with gr.Column():
+             input_image_editor_component = gr.ImageEditor(
+                 label='Image',
+                 type='pil',
                  sources=["upload", "webcam"],
+                 image_mode='RGB',
+                 layers=False,
+                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
+
+             with gr.Row():
+                 input_text_component = gr.Text(
+                     label="Prompt",
+                     show_label=False,
+                     max_lines=1,
+                     placeholder="Enter your prompt",
+                     container=False,
+                 )
+                 submit_button_component = gr.Button(
+                     value='Submit', variant='primary', scale=0)
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 seed_slicer_component = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=42,
+                 )
+
+                 randomize_seed_checkbox_component = gr.Checkbox(
+                     label="Randomize seed", value=True)
+
+                 with gr.Row():
+                     strength_slider_component = gr.Slider(
+                         label="Strength",
+                         info="Indicates extent to transform the reference `image`. "
+                              "Must be between 0 and 1. `image` is used as a starting "
+                              "point and more noise is added the higher the `strength`.",
+                         minimum=0,
+                         maximum=1,
+                         step=0.01,
+                         value=0.85,
+                     )
+
+                     num_inference_steps_slider_component = gr.Slider(
+                         label="Number of inference steps",
+                         info="The number of denoising steps. More denoising steps "
+                              "usually lead to a higher quality image at the "
+                              "expense of slower inference.",
+                         minimum=1,
+                         maximum=50,
+                         step=1,
+                         value=20,
+                     )
          with gr.Column():
+             output_image_component = gr.Image(
+                 type='pil', image_mode='RGB', label='Generated image', format="png")
+             with gr.Accordion("Debug", open=False):
+                 output_mask_component = gr.Image(
+                     type='pil', image_mode='RGB', label='Input mask', format="png")
+     with gr.Row():
+         gr.Examples(
+             fn=process,
+             examples=EXAMPLES,
+             inputs=[
+                 input_image_editor_component,
+                 input_text_component,
+                 seed_slicer_component,
+                 randomize_seed_checkbox_component,
+                 strength_slider_component,
+                 num_inference_steps_slider_component
+             ],
+             outputs=[
+                 output_image_component,
+                 output_mask_component
+             ],
+             run_on_click=True,
+             cache_examples=True
+         )
+
+     submit_button_component.click(
+         fn=process,
+         inputs=[
+             input_image_editor_component,
+             input_text_component,
+             seed_slicer_component,
+             randomize_seed_checkbox_component,
+             strength_slider_component,
+             num_inference_steps_slider_component
+         ],
+         outputs=[
+             output_image_component,
+             output_mask_component
+         ]
      )

+ demo.launch(debug=False, show_error=True)
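
Note on the new `resize_image_dimensions` helper: FLUX, like most latent-diffusion pipelines, expects width and height to be multiples of the model's spatial stride, so the commit scales the longest side to `IMAGE_SIZE` and floors both dimensions to multiples of 32. A standalone sketch of the same arithmetic (the `snap_to_multiple` helper and the sample resolution are illustrative, not part of the app):

```python
from typing import Tuple

def snap_to_multiple(value: int, multiple: int = 32) -> int:
    # Floor to the nearest lower multiple, same as `new_width - (new_width % 32)`.
    return value - (value % multiple)

def scaled_dims(wh: Tuple[int, int], max_dim: int = 1024) -> Tuple[int, int]:
    width, height = wh
    scale = max_dim / max(width, height)  # the longest side becomes max_dim
    return snap_to_multiple(int(width * scale)), snap_to_multiple(int(height * scale))

# A 1920x1080 upload becomes 1024x576 -- both divisible by 32.
assert scaled_dims((1920, 1080)) == (1024, 576)
```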
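
The core of the update is the switch from `StableDiffusionInpaintPipeline` to `FluxInpaintPipeline`. To exercise that path without the Gradio/Spaces wrapper, a minimal sketch follows, using the same checkpoint and call signature as the app; the input and mask file paths are placeholders:

```python
import torch
from PIL import Image
from diffusers import FluxInpaintPipeline

pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

image = Image.open("input.png").convert("RGB")  # placeholder path
mask = Image.open("mask.png").convert("L")      # white = repaint, black = keep

# Seeding mirrors the app's torch.Generator().manual_seed(seed_slicer).
generator = torch.Generator().manual_seed(42)

result = pipe(
    prompt="little lion",
    image=image,
    mask_image=mask,
    width=image.width,    # should already be a multiple of 32
    height=image.height,
    strength=0.85,
    generator=generator,
    num_inference_steps=30,
).images[0]
result.save("output.png")
```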
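
`remove_background` thresholds the mask layer pixel by pixel through `getdata()`/`putdata()`, a pure-Python loop over as many as a million pixels at 1024px. An equivalent NumPy version (a sketch, not part of the commit) performs the same mean-brightness threshold in one vectorized pass:

```python
import numpy as np
from PIL import Image

def remove_background_np(image: Image.Image, threshold: int = 50) -> Image.Image:
    rgba = np.array(image.convert("RGBA"))
    # Mean of the R, G, B channels per pixel; anything darker than
    # `threshold` becomes fully transparent, matching the loop version.
    dark = rgba[..., :3].mean(axis=-1) < threshold
    rgba[dark] = (0, 0, 0, 0)
    return Image.fromarray(rgba)
```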
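
Finally, the `EXAMPLES` entries are dicts because `gr.ImageEditor` hands `process` its value as a dict with `background`, `layers`, and `composite` keys, the white brush strokes landing on the first layer. A sketch of building such a value programmatically, e.g. for testing outside the UI (the paths and the drawn ellipse are illustrative):

```python
from PIL import Image, ImageDraw

background = Image.open("input.png").convert("RGB")  # placeholder path

# White strokes on a transparent layer mark the region to repaint,
# mirroring gr.Brush(colors=["#FFFFFF"], color_mode="fixed") in the app.
layer = Image.new("RGBA", background.size, (0, 0, 0, 0))
ImageDraw.Draw(layer).ellipse((100, 100, 300, 300), fill="white")

editor_value = {"background": background, "layers": [layer], "composite": None}
# result, mask = process(editor_value, "little lion", 42, False, 0.85, 30)
```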