dvir-bria's picture
Update app.py
49fade0 verified
raw
history blame
6.51 kB
import gradio as gr
import numpy as np
import os
from PIL import Image
import requests
from io import BytesIO
import io
import base64
# The Bria API token is read from a secret environment variable so that it
# never appears in the repository; it is sent as the "api_token" header.
hf_token = os.getenv("HF_TOKEN_API_DEMO")
auth_headers = dict(api_token=hf_token)
def convert_mask_image_to_base64_string(mask_image, image_format="PNG"):
    """Encode an image as a base64 string prefixed with a comma.

    Despite its name, this helper is used for both the source image and the
    mask.

    Args:
        mask_image: Object exposing ``save(fp, format=...)``, e.g. a PIL image.
        image_format: Encoder name forwarded to ``save`` (e.g. ``"PNG"`` or
            ``"JPEG"``). Defaults to ``"PNG"``.

    Returns:
        A ``str`` consisting of ``","`` followed by the base64 text of the
        encoded image bytes.
    """
    buffer = io.BytesIO()
    mask_image.save(buffer, format=image_format)
    # Encode the in-memory file contents as base64 text.
    image_base64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
    # The function that downloads an image from base64 on the receiving side
    # expects a "," prefix (as in a data URL), even though no header precedes
    # it here.
    return f",{image_base64_string}"
def download_image(url, timeout=60):
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: Location of the image to fetch.
        timeout: Seconds to wait for the server before giving up, passed to
            ``requests.get``. Without it a stalled server would block the
            worker indefinitely.

    Returns:
        The downloaded image converted to ``"RGB"`` mode.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    print(url)
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to PIL.
    response.raise_for_status()
    image_bytes = BytesIO(response.content)
    return Image.open(image_bytes).convert("RGB")
def gen_fill_api_call(image_base64_file, mask_base64_file, prompt, timeout=300):
    """Call the Bria ``gen_fill`` endpoint and return the generated image.

    Args:
        image_base64_file: Base64 string of the source image.
        mask_base64_file: Base64 string of the mask.
        prompt: Text prompt sent along with the image and mask.
        timeout: Seconds to wait for the API response, passed to
            ``requests.post``.

    Returns:
        The first result image, downloaded via ``download_image``.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
        RuntimeError: If the response carries no result URLs.
    """
    # NOTE(review): this is a plain-HTTP URL and the API token travels in the
    # request headers -- confirm the host is only reachable on a trusted network.
    url = "http://engine.int.bria-api.com/v1/gen_fill"
    payload = {
        "file": image_base64_file,
        "mask_file": mask_base64_file,
        "prompt": prompt,
    }
    response = requests.post(url, json=payload, headers=auth_headers, timeout=timeout)
    if not response.ok:
        # Keep the error body in the logs before raising, for diagnosis.
        print(response.text)
        response.raise_for_status()
    result = response.json()
    print(result)
    urls = result.get("urls") if isinstance(result, dict) else None
    if not urls:
        raise RuntimeError(f"gen_fill response contained no result URLs: {result!r}")
    return download_image(urls[0])
def predict(dict, prompt):
    """Run generative fill on the image editor's current value.

    Args:
        dict: Value produced by the image editor component; the code reads its
            ``"background"`` and ``"layers"`` entries. The parameter name
            shadows the builtin and is kept only for interface compatibility.
            Assumes both entries are height x width x channel arrays with an
            alpha channel at index 3 in the layer -- TODO confirm against the
            editor configuration.
        prompt: Text prompt forwarded to the API.

    Returns:
        The generated image returned by ``gen_fill_api_call``.

    Raises:
        gr.Error: If no image was uploaded or no mask layer is present.
    """
    editor_value = dict
    background = editor_value.get("background") if editor_value else None
    if background is None:
        raise gr.Error("Please upload an image first.")
    layers = editor_value.get("layers") or []
    if not layers:
        raise gr.Error("Please paint a mask over the area you want to fill.")

    # Drop any alpha channel from the background; the mask is the alpha
    # channel of the first drawn layer.
    init_image = Image.fromarray(background[:, :, :3], 'RGB')
    mask = Image.fromarray(layers[0][:, :, 3], 'L')

    image_base64_file = convert_mask_image_to_base64_string(init_image)
    mask_base64_file = convert_mask_image_to_base64_string(mask)
    return gen_fill_api_call(image_base64_file, mask_base64_file, prompt)
# Custom stylesheet for the demo page; it is passed to gr.Blocks(css=css).
# NOTE(review): several selectors below (#image_upload, #mask_radio,
# #word_mask, #share-btn*, #prompt) have no matching elem_id in the layout
# visible in this file -- they may be leftovers; confirm before pruning.
css = '''
.gradio-container{max-width: 1100px !important}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;}
div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
#share-btn-container:hover {background-color: #060606}
#share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;}
#share-btn * {all: unset}
#share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
#share-btn-container .wrap {display: none !important}
#share-btn-container.hidden {display: none!important}
#prompt input{width: calc(100% - 160px);border-top-right-radius: 0px;border-bottom-right-radius: 0px;}
#run_button {
width: 100%;
height: 50px; /* Set a fixed height for the button */
display: flex;
align-items: center;
justify-content: center;
}
#output-img img, #image_upload img {
object-fit: contain; /* Ensure aspect ratio is preserved */
width: 100%;
height: auto; /* Let height adjust automatically */
}
#prompt-container{margin-top:-18px;}
#prompt-container .form{border-top-left-radius: 0;border-top-right-radius: 0}
#image_upload{border-bottom-left-radius: 0px;border-bottom-right-radius: 0px}
'''
# Build the Gradio UI: a header, an image editor plus prompt box and "Fill!"
# button, an output image, and a footer; then start the app.
# NOTE(review): leading indentation is missing in this copy of the file, so
# the nesting of the `with` blocks below must be restored before it can run.
# Presumably the editor/prompt/button sit in the left column and the output
# image in the right column of the same row -- confirm against the live app.
image_blocks = gr.Blocks(css=css, elem_id="total-container")
with image_blocks as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("## BRIA Generative Fill API")
gr.HTML('''
<p style="margin-bottom: 10px; font-size: 94%">
This demo showcases the BRIA Generative Fill capability, which allows users to remove specific elements or objects from images.<br>
The pipeline comprises multiple components, including <a href="https://huggingface.co/briaai/BRIA-2.3" target="_blank">briaai/BRIA-2.3</a>,
<a href="https://huggingface.co/briaai/BRIA-2.3-ControlNet-Generative-Fill" target="_blank">briaai/BRIA-2.3-ControlNet-Generative-Fill</a>,
and <a href="https://huggingface.co/briaai/BRIA-2.3-FAST-LORA" target="_blank">briaai/BRIA-2.3-FAST-LORA</a>, all trained on licensed data.<br>
This ensures full legal liability coverage for copyright and privacy infringement.<br>
Notes:<br>
- High-resolution images may take longer to process.<br>
- For multiple masks, results are better if all masks are included in inference.<br>
</p>
''')
with gr.Row():
with gr.Column():
# Upload-only editor with a single fixed black brush; predict() reads the
# painted layer's alpha channel as the mask.
image = gr.ImageEditor(sources=["upload"], layers=False, transforms=[],
brush=gr.Brush(colors=["#000000"], color_mode="fixed"),
)
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
with gr.Row(elem_id="prompt-container", equal_height=True):
with gr.Column():
btn = gr.Button("Fill!", elem_id="run_button")
with gr.Column():
image_out = gr.Image(label="Output", elem_id="output-img")
# Button click will trigger the inpainting function (now with prompt included)
btn.click(fn=predict, inputs=[image, prompt], outputs=[image_out], api_name='run')
gr.HTML(
"""
<div class="footer">
<p>Model by <a href="https://huggingface.co/diffusers" style="text-decoration: underline;" target="_blank">Diffusers</a> - Gradio Demo by 🤗 Hugging Face
</p>
</div>
"""
)
# Enable request queuing (at most 25 pending jobs) and launch with the API
# closed and the API page hidden.
image_blocks.queue(max_size=25, api_open=False).launch(show_api=False)