# Spaces: Running
import gradio as gr # pyright: ignore[reportMissingTypeStubs] | |
import pillow_heif # pyright: ignore[reportMissingTypeStubs] | |
import spaces # pyright: ignore[reportMissingTypeStubs] | |
import torch | |
from PIL import Image | |
from refiners.fluxion.utils import manual_seed, no_grad | |
from utils import LightingPreference, load_ic_light, resize_modulo_8 | |
# Let Pillow open HEIF/AVIF images through pillow_heif.
pillow_heif.register_heif_opener()  # pyright: ignore[reportUnknownMemberType]
pillow_heif.register_avif_opener()  # pyright: ignore[reportUnknownMemberType]

# Markdown rendered at the top of the demo page.
TITLE = """
# IC-Light with Refiners
"""

# initialize the IC-Light model, on the cpu
DEVICE_CPU = torch.device("cpu")
# Only query bf16 support when CUDA is reported as available: the device
# selection below explicitly allows CPU-only hosts, and some PyTorch releases
# query the current CUDA device inside is_bf16_supported(), which fails when
# there is none.
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
ic_light = load_ic_light(device=DEVICE_CPU, dtype=DTYPE)

# "move" the model to the gpu, this is handled/intercepted by Zero GPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ic_light.to(device=DEVICE, dtype=DTYPE)
# Keep the model's device/dtype attributes and its solver in sync with the move.
ic_light.device = DEVICE
ic_light.dtype = DTYPE
ic_light.solver = ic_light.solver.to(device=DEVICE, dtype=DTYPE)
def _configure_pass(num_inference_steps: int, strength: float) -> int:
    """Configure ``ic_light`` for one denoising pass and return its first step.

    The total step count is scaled by ``1 / strength`` and the pass starts at
    ``num_steps * (1 - strength)``, so that roughly ``num_inference_steps``
    steps remain between the first step and the end of the schedule.

    Args:
        num_inference_steps: Requested number of steps for the pass.
        strength: Strength of the pass; must be in ``(0, 1]``.

    Returns:
        The first step passed to ``ic_light.set_inference_steps``.
    """
    num_steps = int(round(num_inference_steps / strength))
    first_step = int(num_steps * (1 - strength))
    ic_light.set_inference_steps(num_steps, first_step)
    return first_step


def _denoise(
    x: torch.Tensor,
    clip_text_embedding: torch.Tensor,
    condition_scale: float,
) -> torch.Tensor:
    """Run ``ic_light`` over every step of its currently configured schedule."""
    for step in ic_light.steps:
        x = ic_light(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=condition_scale,
        )
    return x


@spaces.GPU
@no_grad()
def process(
    image: Image.Image,
    light_pref: str,
    prompt: str,
    negative_prompt: str,
    strength_first_pass: float,
    strength_second_pass: float,
    condition_scale: float,
    num_inference_steps: int,
    seed: int,
) -> Image.Image:
    """Relight an RGBA image with IC-Light using two denoising passes.

    The alpha channel is used as the mask of the IC-Light condition. The first
    pass starts either from random latents (no light preference) or from the
    noised latents of a lighting-preference image; the second pass re-noises
    the result of the first pass and denoises it again.

    Args:
        image: Input image; must be in ``RGBA`` mode.
        light_pref: Lighting preference name, parsed by
            ``LightingPreference.from_str``.
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        strength_first_pass: Strength of the first pass, in ``[0, 1]``. It is
            ignored (forced to ``1.0``) when no light-preference image is
            produced, and must be greater than 0 otherwise.
        strength_second_pass: Strength of the second pass, in ``(0, 1]``.
        condition_scale: Condition scale forwarded to the model.
        num_inference_steps: Number of inference steps per pass; must be > 0.
        seed: Non-negative random seed.

    Returns:
        The relit image decoded from the final latents.

    Raises:
        gr.Error: If the image is missing or not RGBA, or if a numeric
            argument is out of range.
    """
    # Validate with explicit errors instead of asserts: asserts vanish under
    # ``python -O`` and give the user no actionable message in the UI.
    if image is None:  # pyright: ignore[reportUnnecessaryComparison]
        raise gr.Error("Please provide an input image.")
    if image.mode != "RGBA":
        raise gr.Error("The input image must be in RGBA mode.")
    if not 0 <= strength_first_pass <= 1:
        raise gr.Error("The strength of the first pass must be between 0 and 1.")
    # A zero strength would divide by zero when the pass is configured.
    if not 0 < strength_second_pass <= 1:
        raise gr.Error("The strength of the second pass must be greater than 0 and at most 1.")
    if num_inference_steps <= 0:
        raise gr.Error("The number of inference steps must be positive.")
    if seed < 0:
        raise gr.Error("The seed must be non-negative.")

    # set the seed
    manual_seed(seed)

    # resize image to ~768x768
    image = resize_modulo_8(image, 768)

    # split RGB and alpha channel
    mask = image.getchannel("A")
    image = image.convert("RGB")

    # compute embeddings
    clip_text_embedding = ic_light.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
    ic_light.set_ic_light_condition(image=image, mask=mask)

    # get the light_pref_image
    light_pref_image = LightingPreference.from_str(value=light_pref).get_init_image(
        width=image.width,
        height=image.height,
        interval=(0.2, 0.8),
    )

    if light_pref_image is None:
        # if no light preference is provided, do a full strength first pass
        x = torch.randn_like(ic_light._ic_light_condition)  # pyright: ignore[reportPrivateUsage]
        strength_first_pass = 1.0
    else:
        # The first-pass strength is only used on this path, so a zero value
        # is rejected here rather than up front.
        if strength_first_pass <= 0:
            raise gr.Error(
                "The strength of the first pass must be greater than 0 when a light direction is selected."
            )
        x = ic_light.lda.image_to_latents(light_pref_image)
        x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=0)

    # first pass
    _configure_pass(num_inference_steps, strength_first_pass)
    x = _denoise(x, clip_text_embedding, condition_scale)

    # second pass: re-noise the first-pass latents at the pass' first step
    first_step = _configure_pass(num_inference_steps, strength_second_pass)
    x = ic_light.solver.add_noise(x, noise=torch.randn_like(x), step=first_step)
    x = _denoise(x, clip_text_embedding, condition_scale)

    return ic_light.lda.latents_to_image(x)
# Gradio UI: input/output images, advanced settings, the run button wiring and
# a set of lazily cached examples.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown(TITLE)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Input Image (RGBA)",
                image_mode="RGBA",
                type="pil",
            )
            run_button = gr.Button(
                value="Relight Image",
            )
        with gr.Column():
            output_image = gr.Image(
                label="Relighted Image (RGB)",
                image_mode="RGB",
                type="pil",
            )

    with gr.Accordion("Advanced Settings", open=True):
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="bright green neon light, best quality, highres",
        )
        neg_prompt = gr.Textbox(
            label="Negative Prompt",
            placeholder="worst quality, low quality, normal quality",
        )
        light_pref = gr.Radio(
            choices=["None", "Left", "Right", "Top", "Bottom"],
            label="Light direction preference",
            value="None",
        )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=100_000,
            value=69_420,
            step=1,
        )
        condition_scale = gr.Slider(
            label="Condition scale",
            minimum=0.5,
            maximum=2,
            value=1.25,
            step=0.05,
        )
        num_inference_steps = gr.Slider(
            label="Number of inference steps",
            minimum=1,
            maximum=50,
            value=25,
            step=1,
        )
        with gr.Row():
            # The minimum is 0.1 rather than 0: `process` divides the number of
            # inference steps by the strength, so 0 would divide by zero.
            strength_first_pass = gr.Slider(
                label="Strength of the first pass",
                minimum=0.1,
                maximum=1,
                value=0.9,
                step=0.1,
            )
            strength_second_pass = gr.Slider(
                label="Strength of the second pass",
                minimum=0.1,
                maximum=1,
                value=0.5,
                step=0.1,
            )

    # Components fed to `process`, in the order of its parameters. Shared by
    # the button and the examples so the two cannot drift apart.
    process_inputs = [
        input_image,
        light_pref,
        prompt,
        neg_prompt,
        strength_first_pass,
        strength_second_pass,
        condition_scale,
        num_inference_steps,
        seed,
    ]

    run_button.click(
        fn=process,
        inputs=process_inputs,
        outputs=output_image,
    )

    # Negative prompt shared by every example row.
    example_negative_prompt = "dirty, messy, worst quality, low quality, watermark, signature, jpeg artifacts, deformed, monochrome, black and white"

    # Each row follows the order of `process_inputs`.
    gr.Examples(  # pyright: ignore[reportUnknownMemberType]
        examples=[
            [
                "examples/plant.png",
                "None",
                "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                example_negative_prompt,
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Right",
                "blue purple neon light, cyberpunk city background, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                example_negative_prompt,
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/plant.png",
                "Left",
                "floor is blue ice cavern, stalactite, high-quality professional studo photography, realistic soft lighting, HEIC, CR2, NEF",
                example_negative_prompt,
                0.9,
                0.5,
                1.25,
                25,
                69_420,
            ],
            [
                "examples/chair.png",
                "Right",
                "god rays, fluffy clouds, peaceful surreal atmosphere, high-quality, HEIC, CR2, NEF",
                example_negative_prompt,
                0.9,
                0.5,
                1.25,
                25,
                69,
            ],
            [
                "examples/bunny.png",
                "Left",
                "grass field, high-quality, HEIC, CR2, NEF",
                example_negative_prompt,
                0.9,
                0.5,
                1.25,
                25,
                420,
            ],
        ],
        inputs=process_inputs,
        outputs=output_image,
        fn=process,
        cache_examples=True,
        cache_mode="lazy",
        run_on_click=False,
    )

demo.launch()