israelweiss commited on
Commit
2d8c11a
·
1 Parent(s): 0d20806

no prompt button, api call to int

Browse files
Files changed (1) hide show
  1. app.py +57 -114
app.py CHANGED
@@ -1,26 +1,45 @@
1
  import gradio as gr
2
- import torch
3
  import numpy as np
4
- import diffusers
5
  import os
6
- import spaces
7
- from PIL import Image
8
- hf_token = os.environ.get("HF_TOKEN")
9
- from diffusers import StableDiffusionXLInpaintPipeline, DDIMScheduler, UNet2DConditionModel
10
- from diffusers import (
11
- AutoencoderKL,
12
- LCMScheduler,
13
- )
14
- from pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
15
- from controlnet import ControlNetModel, ControlNetConditioningEmbedding
16
- import torch
17
- import numpy as np
18
  from PIL import Image
19
  import requests
20
- import PIL
21
  from io import BytesIO
22
- from torchvision import transforms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  ratios_map = {
26
  0.5:{"width":704,"height":1408},
@@ -43,13 +62,6 @@ ratios_map = {
43
  }
44
  ratios = np.array(list(ratios_map.keys()))
45
 
46
- image_transforms = transforms.Compose(
47
- [
48
- transforms.ToTensor(),
49
- ]
50
- )
51
-
52
- default_negative_prompt = "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
53
 
54
 
55
  def get_masked_image(image, image_mask, width, height):
@@ -77,25 +89,6 @@ def get_size(init_image):
77
 
78
  return w,h
79
 
80
- device = "cuda" if torch.cuda.is_available() else "cpu"
81
-
82
- # Load, init model
83
- controlnet = ControlNetModel().from_pretrained("briaai/DEV-ControlNetInpaintingFast", torch_dtype=torch.float16)
84
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
85
- pipe = StableDiffusionXLControlNetPipeline.from_pretrained("briaai/BRIA-2.3", controlnet=controlnet.to(dtype=torch.float16), torch_dtype=torch.float16, vae=vae) #force_zeros_for_empty_prompt=False, # vae=vae)
86
-
87
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
88
- pipe.load_lora_weights("briaai/BRIA-2.3-FAST-LORA")
89
- pipe.fuse_lora()
90
-
91
- pipe = pipe.to(device)
92
- # pipe.enable_xformers_memory_efficient_attention()
93
-
94
- # generator = torch.Generator(device='cuda').manual_seed(123456)
95
-
96
- vae = pipe.vae
97
-
98
- pipe.enable_model_cpu_offload()
99
 
100
  def read_content(file_path: str) -> str:
101
  """read the content of target file
@@ -105,62 +98,19 @@ def read_content(file_path: str) -> str:
105
 
106
  return content
107
 
108
- @spaces.GPU(enable_queue=True)
109
- def predict(dict, prompt="", negative_prompt = default_negative_prompt, guidance_scale=1.2, steps=12, seed=123456):
110
- if negative_prompt == "":
111
- negative_prompt = None
112
 
113
  init_image = Image.fromarray(dict['background'][:, :, :3], 'RGB') #dict['background'].convert("RGB")#.resize((1024, 1024))
114
  mask = Image.fromarray(dict['layers'][0][:,:,3], 'L') #dict['layers'].convert("RGB")#.resize((1024, 1024))
115
-
116
-
117
- width, height = get_size(init_image)
118
-
119
- init_image = init_image.resize((width, height))
120
- mask = mask.resize((width, height))
121
-
122
-
123
- masked_image, image_mask, masked_image_to_present = get_masked_image(init_image, mask, width, height)
124
- masked_image_tensor = image_transforms(masked_image)
125
- masked_image_tensor = (masked_image_tensor - 0.5) / 0.5
126
 
127
- masked_image_tensor = masked_image_tensor.unsqueeze(0).to(device="cuda")
 
128
 
129
- control_latents = vae.encode(
130
- masked_image_tensor[:, :3, :, :].to(vae.dtype)
131
- ).latent_dist.sample()
132
 
133
- control_latents = control_latents * vae.config.scaling_factor
134
-
135
- image_mask = np.array(image_mask)[:,:]
136
- mask_tensor = torch.tensor(image_mask, dtype=torch.float32)[None, ...]
137
- # binarize the mask
138
- mask_tensor = torch.where(mask_tensor > 128.0, 255.0, 0)
139
-
140
- mask_tensor = mask_tensor / 255.0
141
-
142
- mask_tensor = mask_tensor.to(device="cuda")
143
- mask_resized = torch.nn.functional.interpolate(mask_tensor[None, ...], size=(control_latents.shape[2], control_latents.shape[3]), mode='nearest')
144
- # mask_resized = mask_resized.to(torch.float16)
145
- masked_image = torch.cat([control_latents, mask_resized], dim=1)
146
-
147
- generator = torch.Generator(device='cuda').manual_seed(int(seed))
148
-
149
- output = pipe(prompt = prompt,
150
- width=width,
151
- height=height,
152
- negative_prompt=negative_prompt,
153
- image = masked_image, # control image V
154
- init_image = init_image,
155
- mask_image = mask_tensor,
156
- guidance_scale = guidance_scale,
157
- num_inference_steps=int(steps),
158
- # strength=strength,
159
- generator=generator,
160
- controlnet_conditioning_sale=1.0)
161
-
162
- torch.cuda.empty_cache
163
- return output.images[0] #, gr.update(visible=True)
164
 
165
 
166
  css = '''
@@ -212,29 +162,22 @@ with image_blocks as demo:
212
  </p>
213
  ''')
214
  with gr.Row():
215
- with gr.Column():
216
- image = gr.ImageEditor(sources=["upload"], layers=False, transforms=[], brush=gr.Brush(colors=["#000000"], color_mode="fixed")) #gr.Image(sources=['upload'], tool='sketch', elem_id="image_upload", type="pil", label="Upload", height=400)
217
- with gr.Row(elem_id="prompt-container", equal_height=True):
218
- with gr.Row():
219
- prompt = gr.Textbox(placeholder="Your prompt (what you want in place of what is erased)", show_label=False, elem_id="prompt")
220
- btn = gr.Button("Inpaint!", elem_id="run_button")
221
-
222
- with gr.Accordion(label="Advanced Settings", open=False):
223
- with gr.Row(equal_height=True):
224
- guidance_scale = gr.Number(value=1.2, minimum=0.8, maximum=2.5, step=0.1, label="guidance_scale")
225
- steps = gr.Number(value=12, minimum=6, maximum=20, step=1, label="steps")
226
- # strength = gr.Number(value=1, minimum=0.01, maximum=1.0, step=0.01, label="strength")
227
- seed = gr.Number(value=123456, minimum=0, maximum=999999, step=1, label="seed")
228
- negative_prompt = gr.Textbox(label="negative_prompt", value=default_negative_prompt, placeholder=default_negative_prompt, info="what you don't want to see in the image")
229
-
230
-
231
- with gr.Column():
232
- image_out = gr.Image(label="Output", elem_id="output-img", height=400)
233
-
234
 
235
-
236
- btn.click(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, seed], outputs=[image_out], api_name='run')
237
- prompt.submit(fn=predict, inputs=[image, prompt, negative_prompt, guidance_scale, steps, seed], outputs=[image_out])
 
 
 
 
 
 
 
 
238
 
239
  gr.HTML(
240
  """
 
1
  import gradio as gr
 
2
  import numpy as np
3
+
4
  import os
 
 
 
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
  import requests
 
7
  from io import BytesIO
8
+ import io
9
+ import base64
10
+
11
+ hf_token = os.environ.get("HF_TOKEN")
12
+ auth_headers = {"api_token": hf_token}
13
+
14
+ def convert_mask_image_to_base64_string(mask_image):
15
+ buffer = io.BytesIO()
16
+ mask_image.save(buffer, format="PNG") # You can choose the format (e.g., "JPEG", "PNG")
17
+ # Encode the buffer in base64
18
+ image_base64_string = base64.b64encode(buffer.getvalue()).decode('utf-8')
19
+ return f",{image_base64_string}" # for some reason the funciton which downloads image from base64 expects prefix of "," which is redundant in the url
20
+
21
+ def download_image(url):
22
+ response = requests.get(url)
23
+ return Image.open(BytesIO(response.content)).convert("RGB")
24
 
25
+ def eraser_api_call(image_base64_file, mask_base64_file, seed, mask_type, original_quality, guidance_scale):
26
+
27
+ # url = "http://engine.prod.bria-api.com/v1/eraser" # TODO: use this link!
28
+ url = "http://engine.int.bria-api.com/v1/eraser" # TODO: use this link!
29
+
30
+ payload = {
31
+ "file": image_base64_file,
32
+ "mask_file": mask_base64_file,
33
+ "seed": seed,
34
+ "mask_type": mask_type,
35
+ "original_quality": original_quality,
36
+ "text_guidance_scale": guidance_scale
37
+ }
38
+ response = requests.post(url, json=payload, headers=auth_headers)
39
+ response = response.json()
40
+ res_image = download_image(response["result_url"])
41
+
42
+ return res_image
43
 
44
  ratios_map = {
45
  0.5:{"width":704,"height":1408},
 
62
  }
63
  ratios = np.array(list(ratios_map.keys()))
64
 
 
 
 
 
 
 
 
65
 
66
 
67
  def get_masked_image(image, image_mask, width, height):
 
89
 
90
  return w,h
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  def read_content(file_path: str) -> str:
94
  """read the content of target file
 
98
 
99
  return content
100
 
101
+ def predict(dict, guidance_scale=1.2, seed=123456):
 
 
 
102
 
103
  init_image = Image.fromarray(dict['background'][:, :, :3], 'RGB') #dict['background'].convert("RGB")#.resize((1024, 1024))
104
  mask = Image.fromarray(dict['layers'][0][:,:,3], 'L') #dict['layers'].convert("RGB")#.resize((1024, 1024))
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ image_base64_file = convert_mask_image_to_base64_string(init_image)
107
+ mask_base64_file = convert_mask_image_to_base64_string(mask)
108
 
109
+ mask_type = "brush"
110
+ original_quality = True
111
+ gen_img = eraser_api_call(image_base64_file, mask_base64_file, seed, mask_type, original_quality, guidance_scale)
112
 
113
+ return gen_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  css = '''
 
162
  </p>
163
  ''')
164
  with gr.Row():
165
+ with gr.Column():
166
+ image = gr.ImageEditor(sources=["upload"], layers=False, transforms=[], brush=gr.Brush(colors=["#000000"], color_mode="fixed"))
167
+ with gr.Row(elem_id="prompt-container", equal_height=True):
168
+ btn = gr.Button("Inpaint!", elem_id="run_button")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ with gr.Accordion(label="Advanced Settings", open=False):
171
+ with gr.Row(equal_height=True):
172
+ guidance_scale = gr.Number(value=1.2, minimum=0.0, maximum=2.5, step=0.1, label="guidance_scale")
173
+ seed = gr.Number(value=123456, minimum=0, maximum=999999, step=1, label="seed")
174
+
175
+ with gr.Column():
176
+ image_out = gr.Image(label="Output", elem_id="output-img", height=400)
177
+
178
+ # Button click will trigger the inpainting function (no prompt required)
179
+ btn.click(fn=predict, inputs=[image, guidance_scale, seed], outputs=[image_out], api_name='run')
180
+
181
 
182
  gr.HTML(
183
  """