# `import spaces` must come first so ZeroGPU can patch CUDA initialization
# before torch is imported.
import spaces

import os

# OpenCV only enables its OpenEXR codec when this flag is set before cv2 is
# imported (load_image imports cv2), so set it ahead of the other imports.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

from typing import cast

import gradio as gr
import torch
import torchvision
from diffusers import DDIMScheduler
from PIL import Image

from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

current_directory = os.path.dirname(os.path.abspath(__file__))

# Load the RGB->X intrinsic-channel estimation pipeline in fp16 on the GPU.
pipe = cast(
    StableDiffusionAOVMatEstPipeline,
    StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda"),
)
# DDIM with zero-terminal-SNR beta rescaling and "trailing" timestep spacing.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
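

# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU to each call of
# the decorated function for the duration of that call.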
@spaces.GPU
def generate(
    photo,
    seed: int,
    inference_step: int,
    num_samples: int,
) -> list[tuple[Image.Image, str]]:
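    """Estimate intrinsic channels (AOVs) for an uploaded photo.

    Runs the rgb2x pipeline once per AOV per sample and returns
    (PIL.Image, caption) pairs for the output gallery.
    """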
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Load the upload as a CHW float tensor on the GPU: EXR inputs are
    # tonemapped, LDR inputs are converted from sRGB to linear.
    if photo.name.endswith(".exr"):
        # keyword spelling matches load_exr_image's signature
        photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
    elif photo.name.endswith((".png", ".jpg", ".jpeg")):
        photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
    else:
        raise gr.Error("Unsupported file type; upload a .exr, .png, or .jpg image.")
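
    # Resize so the longer side is 1000 px (preserving aspect ratio), then floor
    # both sides to multiples of 8 as the diffusion pipeline's VAE requires.
    # E.g. a 1200x900 (WxH) photo becomes 1000x750, then 1000x744 after flooring.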
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    ratio = old_height / old_width
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)
    new_width = new_width // 8 * 8
    new_height = new_height // 8 * 8
    photo = torchvision.transforms.Resize((new_height, new_width))(photo)

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    # Generate each AOV for each sample, resize it back to the input resolution,
    # and pair it with a caption for the gallery.
    return_list = []
    for i in range(num_samples):
        for aov_name in required_aovs:
            generated_image = pipe(
                prompt=prompts[aov_name],
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]  # type: ignore
            generated_image = torchvision.transforms.Resize((old_height, old_width))(
                generated_image
            )
            return_list.append((generated_image, f"Generated {aov_name} {i}"))
    return return_list
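

# A minimal direct-call sketch (hypothetical; the Gradio UI below is the real
# entry point). gr.File hands generate() an object exposing .name, so any
# stand-in with that attribute works:
#
#     class Upload:
#         name = "rgb2x/example/Castlereagh_corridor_photo.png"
#
#     pairs = generate(Upload(), seed=42, inference_step=50, num_samples=1)
#     pairs[0][0].save("albedo.png")  # first AOV of the first sample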

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given Image")
            photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])
            gr.Markdown("### Parameters")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    # Cached examples call fn with the full example row, so every input that
    # generate() takes needs a value here (42 is an arbitrary seed; 50 and 1
    # mirror the slider defaults).
    examples = gr.Examples(
        examples=[
            [
                "rgb2x/example/Castlereagh_corridor_photo.png",
                42,
                50,
                1,
            ]
        ],
        inputs=[photo, seed, inference_step, num_samples],
        outputs=[result_gallery],
        fn=generate,
        cache_mode="eager",
        cache_examples=True,
    )

    run_button.click(
        fn=generate,
        inputs=[photo, seed, inference_step, num_samples],
        outputs=result_gallery,
        queue=True,
    )

if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)