File size: 5,181 Bytes
ecddba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4ee2ca
 
 
 
 
 
 
 
 
 
 
 
ecddba8
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import spaces
import os
from typing import cast
import gradio as gr
from PIL import Image
import torch
import torchvision
from diffusers import DDIMScheduler
from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

# NOTE(review): this flag is presumably read by OpenCV when decoding .exr files
# inside load_exr_image — confirm nothing imports/initializes cv2 before this line.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

current_directory = os.path.dirname(os.path.abspath(__file__))

# Load the RGB->X material-estimation pipeline once at import time and move it
# to the GPU. fp16 halves VRAM use. The duplicate `pipe.to("cuda")` and second
# `cast(...)` from the original were redundant and have been removed.
pipe = cast(
    StableDiffusionAOVMatEstPipeline,
    StableDiffusionAOVMatEstPipeline.from_pretrained(
        "zheng95z/rgb-to-x",
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda"),
)
# Swap in DDIM with zero-terminal-SNR beta rescaling and trailing timestep
# spacing, keeping the rest of the original scheduler config.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)


@spaces.GPU
def generate(
    photo,
    seed: int,
    inference_step: int,
    num_samples: int,
) -> list[Image.Image]:
    """Run the RGB->X pipeline on an uploaded photo and return all generated
    intrinsic channels (albedo, normal, roughness, metallic, irradiance).

    Args:
        photo: Uploaded file object (``gr.File``) exposing a ``.name`` path.
            ``.exr`` inputs are tonemapped HDR; ``.png``/``.jpg``/``.jpeg``
            are loaded as sRGB and linearized.
        seed: Seed for the CUDA RNG driving the diffusion sampler.
        inference_step: Number of denoising steps per channel.
        num_samples: Number of independent samples to generate.

    Returns:
        A flat list of ``(image, caption)`` tuples suitable for ``gr.Gallery``.

    Raises:
        ValueError: If the uploaded file has an unsupported extension.
    """
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # Case-insensitive extension check; the original was case-sensitive and
    # silently fell through for unknown types, crashing later at `photo.shape`.
    lower_name = photo.name.lower()
    if lower_name.endswith(".exr"):
        photo = load_exr_image(photo.name, tonemaping=True, clamp=True).to("cuda")
    elif lower_name.endswith((".png", ".jpg", ".jpeg")):
        photo = load_ldr_image(photo.name, from_srgb=True).to("cuda")
    else:
        raise ValueError(f"Unsupported file type: {photo.name}")

    # Resize so the longest side is max_side px, preserving aspect ratio, then
    # floor both dims to multiples of 8 as the pipeline requires.
    # (assumes photo is a CHW tensor — consistent with shape[1]/shape[2] below)
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    ratio = old_height / old_width  # was misspelled "radio"
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)

    # `// 8 * 8` is the identity on multiples of 8, so no guard is needed.
    new_width = new_width // 8 * 8
    new_height = new_height // 8 * 8

    photo = torchvision.transforms.Resize((new_height, new_width))(photo)

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    return_list = []
    for i in range(num_samples):
        for aov_name in required_aovs:
            prompt = prompts[aov_name]
            generated_image = pipe(
                prompt=prompt,
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]  # type: ignore

            # Upscale the result back to the input resolution for display.
            generated_image = torchvision.transforms.Resize((old_height, old_width))(
                generated_image
            )

            return_list.append((generated_image, f"Generated {aov_name} {i}"))

    return return_list


# Gradio UI: a single photo upload on the left, generated intrinsic channels
# in a gallery on the right.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given Image")
            photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])

            gr.Markdown("### Parameters")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                # randomize=True picks a fresh seed on each page load.
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )
            # NOTE(review): `generate` takes four arguments but only `photo` is
            # listed in inputs here, while cache_examples=True makes Gradio call
            # fn at startup to cache outputs — verify this does not fail with a
            # missing-argument error when the example cache is built.
            examples = gr.Examples(
                examples=[
                    [
                        "rgb2x/example/Castlereagh_corridor_photo.png",
                    ]
                ],
                inputs=[photo],
                outputs=[result_gallery],
                fn=generate,
                cache_mode="eager",
                cache_examples=True,
            )

    # Wire the Run button to the full four-argument signature of `generate`.
    run_button.click(
        fn=generate,
        inputs=[photo, seed, inference_step, num_samples],
        outputs=result_gallery,
        queue=True,
    )


if __name__ == "__main__":
    # Local-only launch: no public share link, no debug mode, API docs hidden.
    demo.launch(show_api=False, share=False, debug=False)