import glob
from copy import deepcopy
import gradio as gr
import numpy as np
import PIL
import spaces
import torch
import yaml
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from transformers import AutoModelForImageSegmentation
from utils import extract_object, get_model_from_config, resize_and_center_crop
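# Aspect-ratio buckets: each key is a width/height ratio (stored as a string)
# mapping to target dimensions, all holding roughly the same ~1 MP pixel budget.
# The (width, height) ordering assumes resize_and_center_crop takes width then height.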
ASPECT_RATIOS = {
str(512 / 2048): (512, 2048),
str(1024 / 1024): (1024, 1024),
str(2048 / 512): (2048, 512),
str(896 / 1152): (896, 1152),
str(1152 / 896): (1152, 896),
str(512 / 1920): (512, 1920),
str(640 / 1536): (640, 1536),
str(768 / 1280): (768, 1280),
str(1280 / 768): (1280, 768),
str(1536 / 640): (1536, 640),
str(1920 / 512): (1920, 512),
}
# download the config and model
MODEL_PATH = hf_hub_download("jasperai/LBM_relighting", "relight.safetensors")
CONFIG_PATH = hf_hub_download("jasperai/LBM_relighting", "relight.yaml")
with open(CONFIG_PATH, "r") as f:
config = yaml.safe_load(f)
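# instantiate the LBM model from its YAML config, load the pretrained weights,
# and move it to the GPU in bfloat16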
model = get_model_from_config(**config)
sd = load_file(MODEL_PATH)
model.load_state_dict(sd, strict=True)
model.to("cuda").to(torch.bfloat16)
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
).cuda()
image_size = (1024, 1024)
@spaces.GPU
def evaluate(
fg_image: PIL.Image.Image,
bg_image: PIL.Image.Image,
num_sampling_steps: int = 1,
):
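    """Relight the foreground object so it matches the target background.

    Both images are cropped to a common aspect-ratio bucket, naively
    composited, then translated by the LBM model in `num_sampling_steps` steps.
    Returns the (composite, relit) pair as numpy arrays for the image slider.
    """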
    # PIL's Image.size is (width, height)
    ori_w_fg, ori_h_fg = fg_image.size
    ar_fg = ori_w_fg / ori_h_fg
    # snap the input to the closest supported aspect-ratio bucket
    closest_ar = min(ASPECT_RATIOS, key=lambda x: abs(float(x) - ar_fg))
    dimensions = ASPECT_RATIOS[closest_ar]
    # segment the foreground object, then crop all inputs to the bucket size
    _, fg_mask = extract_object(birefnet, deepcopy(fg_image))
    fg_image = resize_and_center_crop(fg_image, dimensions[0], dimensions[1])
    fg_mask = resize_and_center_crop(fg_mask, dimensions[0], dimensions[1])
    bg_image = resize_and_center_crop(bg_image, dimensions[0], dimensions[1])
img_pasted = Image.composite(fg_image, bg_image, fg_mask)
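    # normalize the composite from [0, 1] to [-1, 1] before VAE encoding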
img_pasted_tensor = ToTensor()(img_pasted).unsqueeze(0) * 2 - 1
batch = {
"source_image": img_pasted_tensor.cuda().to(torch.bfloat16),
}
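    # encode the composite into the VAE latent space, then run the latent
    # bridge from the source latent to the relit result in a few steps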
z_source = model.vae.encode(batch[model.source_key])
output_image = model.sample(
z=z_source,
num_steps=num_sampling_steps,
conditioner_inputs=batch,
max_samples=1,
).clamp(-1, 1)
output_image = (output_image[0].float().cpu() + 1) / 2
output_image = ToPILImage()(output_image)
# paste the output image on the background image
output_image = Image.composite(output_image, bg_image, fg_mask)
    # resize back to the original input resolution; PIL's resize returns a new
    # image, so the result must be reassigned (the original discarded it).
    # Resize both images so the before/after slider stays aligned.
    output_image = output_image.resize((ori_w_fg, ori_h_fg))
    img_pasted = img_pasted.resize((ori_w_fg, ori_h_fg))
    return (np.array(img_pasted), np.array(output_image))
with gr.Blocks(title="LBM Object Relighting") as demo:
gr.Markdown(
f"""
# Object Relighting with Latent Bridge Matching
        This is an interactive demo of [LBM: Latent Bridge Matching for Fast Image-to-Image Translation](https://arxiv.org/abs/2403.03025) *by Jasper Research*. We are internally exploring the possibility of releasing the model. If you enjoy this Space, please also support *open source* by giving a ⭐ to the <a href='https://github.com/gojasper/LBM' target='_blank'>GitHub repo</a>.
"""
)
gr.Markdown(
"💡 *Hint:* To better appreciate the low latency of our method, run the demo locally !"
)
with gr.Row():
with gr.Column():
with gr.Row():
fg_image = gr.Image(
type="pil",
label="Input Image",
image_mode="RGB",
height=360,
# width=360,
)
bg_image = gr.Image(
type="pil",
label="Target Background",
image_mode="RGB",
height=360,
# width=360,
)
with gr.Row():
submit_button = gr.Button("Relight", variant="primary")
with gr.Row():
num_inference_steps = gr.Slider(
minimum=1,
maximum=4,
value=1,
step=1,
label="Number of Inference Steps",
)
bg_gallery = gr.Gallery(
# height=450,
object_fit="contain",
label="Background List",
value=[path for path in glob.glob("examples/backgrounds/*.jpg")],
columns=5,
allow_preview=False,
)
with gr.Column():
output_slider = ImageSlider(label="Composite vs LBM", type="numpy")
output_slider.upload(
fn=evaluate,
inputs=[fg_image, bg_image, num_inference_steps],
outputs=[output_slider],
)
submit_button.click(
evaluate,
inputs=[fg_image, bg_image, num_inference_steps],
outputs=[output_slider],
show_progress=False,
show_api=False,
)
with gr.Row():
gr.Examples(
fn=evaluate,
examples=[
[
"examples/foregrounds/2.jpg",
"examples/backgrounds/14.jpg",
1,
],
[
"examples/foregrounds/10.jpg",
"examples/backgrounds/4.jpg",
1,
],
[
"examples/foregrounds/11.jpg",
"examples/backgrounds/24.jpg",
1,
],
[
"examples/foregrounds/19.jpg",
"examples/backgrounds/3.jpg",
1,
],
[
"examples/foregrounds/4.jpg",
"examples/backgrounds/6.jpg",
1,
],
[
"examples/foregrounds/14.jpg",
"examples/backgrounds/22.jpg",
1,
],
[
"examples/foregrounds/12.jpg",
"examples/backgrounds/1.jpg",
1,
],
],
inputs=[fg_image, bg_image, num_inference_steps],
outputs=[output_slider],
run_on_click=True,
)
gr.Markdown("**Disclaimer:**")
gr.Markdown(
"This demo is only for research purpose. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
)
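    # clicking a gallery thumbnail sets it as the target background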
    def bg_gallery_selected(gal, evt: gr.SelectData):
        return gal[evt.index][0]
bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=bg_image)
if __name__ == "__main__":
demo.queue().launch(share=True, show_api=False)