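"""Gradio demo for "Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting".

Given a single input image and a text prompt, the app hallucinates a 3D scene by
iteratively outpainting novel views, estimating depth for the new content, and fusing
the results into a point cloud that is exported as a Gaussian splat (.ply) for the viewer.

This file is tailored to Hugging Face Spaces (ZeroGPU); see the GitHub repository at
https://github.com/paulengstler/invisible-stitch for the full pipeline.
"""
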
import spaces
import os

# this is an HF Spaces-specific hack, as
#  (i)  building pytorch3d with GPU support is a bit tricky here
#  (ii) installing the wheel via requirements.txt breaks ZeroGPU
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

import skimage
from PIL import Image

import gradio as gr

from utils.render import PointsRendererWithMasks, render
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting  # get_blank_gs_bundle is defined locally below

from pytorch3d.utils import opencv_from_cameras_projection
from utils.ops import focal2fov, fov2focal
from utils.models import infer_with_zoe_dc
from utils.scene import GaussianModel
from utils.demo import downsample_point_cloud
from typing import Iterable, Tuple, Dict, Optional
import itertools

from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)

from pytorch3d.io import IO

def get_blank_gs_bundle(h, w):
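    """Return an empty Gaussian splatting optimization bundle for an h x w scene."""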
    return {
        "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
        "W": w,
        "H": h,
        "pcd_points": None,
        "pcd_colors": None,
        'frames': [],
    }

@spaces.GPU(duration=30)
def extrapolate_point_cloud(prompt: str, image_size: Tuple[int, int], look_at_params: Iterable[Tuple[float, float, float, torch.Tensor]], point_cloud: Optional[Pointclouds] = None, dry_run: bool = False, discard_mask: bool = False, initial_image: Optional[Image.Image] = None, depth_scaling: float = 1.0, **render_kwargs):
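    """Render the scene from each requested viewpoint and grow the point cloud by outpainting.

    For every (azim, elev, dist, at) tuple in `look_at_params`, the current `point_cloud` is
    rendered, missing regions are outpainted with the diffusion pipeline (`pipe`), the
    corresponding depth is predicted and aligned with `zoe_dc_model`, and the newly revealed
    points are merged back into the point cloud. If `point_cloud` is None, the cloud is
    bootstrapped from `initial_image` with a plain depth prediction instead. With
    `dry_run=True`, views are only rendered and recorded; neither outpainting nor point
    cloud updates take place.

    Returns the list of per-view frame dictionaries and the (possibly updated) point cloud.
    """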
    w, h = image_size
    optimization_bundle_frames = []

    for azim, elev, dist, at in look_at_params:
        R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
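        # pinhole camera with focal length equal to the image width (roughly a 53 degree
        # horizontal field of view), principal point at the image center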
        cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32), principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,), device=device, in_ndc=False)

        if point_cloud is not None:
            images, masks, depths = render(cameras, point_cloud, **render_kwargs)

            if not dry_run:
                eroded_mask = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)  # alternatively: footprint=skimage.morphology.disk(1)
                eroded_depth = depths[0].clone()
                eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0

                outpainted_img, aligned_depth = outpaint_with_depth_estimation(images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras, dilation_size=2, depth_scaling=depth_scaling, generator=torch.Generator(device=pipe.device).manual_seed(0))

                aligned_depth = torch.from_numpy(aligned_depth).to(device)

            else:
                # in a dry run, we do not actually outpaint the image
                outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))

        else:
            assert initial_image is not None
            assert not dry_run

            # jumpstart the point cloud with a regular depth estimation
            t_initial_image = torch.from_numpy(np.asarray(initial_image)/255.).permute(2,0,1).float()
            aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))
            outpainted_img = initial_image
            images = [t_initial_image.to(device)]
            masks = [torch.ones(h, w, dtype=torch.bool).to(device)]

        if not dry_run:
            # snap high gradients to nearest neighbor, which eliminates noodle artifacts
            aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)

        c2w = cameras.get_world_to_view_transform().get_matrix()[0]

        optimization_bundle_frames.append({
            "image": outpainted_img,
            "mask": masks[0].cpu().numpy(),
            "transform_matrix": c2w.tolist(),
            "azim": azim,
            "elev": elev,
            "dist": dist,
        })

        if discard_mask:
            optimization_bundle_frames[-1].pop("mask")

        if not dry_run:
            optimization_bundle_frames[-1]["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
            optimization_bundle_frames[-1]["depth"] = aligned_depth.cpu().numpy()
            optimization_bundle_frames[-1]["mean_depth"] = aligned_depth.mean().item()

        else:
            # in a dry run, we do not modify the point cloud
            continue

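        # flatten the outpainted image into per-point RGB features in [0, 1]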
        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)

        if point_cloud is None:
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)

        else:
            # pytorch 3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
            # in theory, 1 pixel is sufficient but we use 2 to be safe
            masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)

            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])

            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return optimization_bundle_frames, point_cloud

@spaces.GPU(duration=30)
def generate_point_cloud(initial_image: Image.Image, prompt: str):
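    """Build the initial point cloud for a single input image.

    The cloud is bootstrapped from the input image at azimuth 0 and then extended by
    outpainting two additional views at +/-25 degrees azimuth.
    """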
    image_size = initial_image.size
    w, h = image_size

    optimization_bundle = get_blank_gs_bundle(h, w)

    step_size = 25

    azim_steps = [0, step_size, -step_size]
    look_at_params = [(azim, 0, 0.01, torch.zeros((1, 3))) for azim in azim_steps]

    optimization_bundle["frames"], point_cloud = extrapolate_point_cloud(prompt, image_size, look_at_params, discard_mask=True, initial_image=initial_image, depth_scaling=0.5, fill_point_cloud_holes=True)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud

@spaces.GPU(duration=30)
def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
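    """Collect jittered "supporting" views around every primary frame.

    For each non-supporting frame, a 3x3 grid of azimuth/elevation offsets (up to +/-5
    degrees) looking at the frame's center point is rendered in dry-run mode and the
    resulting frames are marked as supporting. Not called in this trimmed-down demo.
    """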
    w, h = optimization_bundle["W"], optimization_bundle["H"]

    supporting_frames = []

    for i, frame in enumerate(optimization_bundle["frames"]):
        # skip supporting views
        if frame.get("supporting", False):
            continue

        center_point = torch.tensor(frame["center_point"]).to(device)
        mean_depth = frame["mean_depth"]
        azim, elev = frame["azim"], frame["elev"]

        azim_jitters = torch.linspace(-5, 5, 3).tolist()
        elev_jitters = torch.linspace(-5, 5, 3).tolist()

        # build the product of azim and elev jitters
        camera_jitters = [{"azim": azim + azim_jitter, "elev": elev + elev_jitter} for azim_jitter, elev_jitter in itertools.product(azim_jitters, elev_jitters)]

        look_at_params = [(camera_jitter["azim"], camera_jitter["elev"], mean_depth, center_point.unsqueeze(0)) for camera_jitter in camera_jitters]

        local_supporting_frames, point_cloud = extrapolate_point_cloud(prompt, (w, h), look_at_params, point_cloud, dry_run=True, depth_scaling=0.5, antialiasing=3)

        for local_supporting_frame in local_supporting_frames:
            local_supporting_frame["supporting"] = True

        supporting_frames.extend(local_supporting_frames)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud

@spaces.GPU(duration=30)
def generate_scene(img: Image.Image, prompt: str):
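    """Gradio entry point: turn an image and a prompt into a Gaussian splat .ply file.

    The input is resized and cropped, a point cloud is hallucinated and downsampled,
    wrapped into a Gaussian splatting scene, converted into the export coordinate
    system, and written to ./output.ply; the path to that file is returned.
    """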
    assert isinstance(img, Image.Image)

    # resize image maintaining the aspect ratio so the longest side is 720 pixels
    max_size = 720
    img.thumbnail((max_size, max_size))

    # crop to ensure the image dimensions are divisible by 8
    img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))

    gs_optimization_bundle, point_cloud = generate_point_cloud(img, prompt)

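    # downsample the point cloud to keep the number of Gaussians (and the exported file size) manageable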
    downsampled_point_cloud = downsample_point_cloud(gs_optimization_bundle, device=device)

    gs_optimization_bundle["pcd_points"] = downsampled_point_cloud.points_padded()[0].cpu().numpy()
    gs_optimization_bundle["pcd_colors"] = downsampled_point_cloud.features_padded()[0].cpu().numpy()

    scene = Scene(gs_optimization_bundle, GaussianModel(gs_options.sh_degree), gs_options)

    scene.gaussians._opacity = torch.ones_like(scene.gaussians._opacity)
    # to keep this demo snappy, the Gaussian splatting optimization is skipped here;
    # run the project locally (see the GitHub repository) for higher-quality results:
    # scene = run_gaussian_splatting(scene, gs_optimization_bundle)

    # coordinate system transformation: flip the Y and Z axes before exporting the splat
    scene.gaussians._xyz = scene.gaussians._xyz.detach()
    scene.gaussians._xyz[:, 1] = -scene.gaussians._xyz[:, 1]
    scene.gaussians._xyz[:, 2] = -scene.gaussians._xyz[:, 2]

    save_path = "./output.ply"

    scene.gaussians.save_ply(save_path)

    return save_path

if __name__ == "__main__":
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  

    from utils.models import get_zoe_dc_model, get_sd_pipeline

    global zoe_dc_model
    from huggingface_hub import hf_hub_download
    zoe_dc_model = get_zoe_dc_model(ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")).to(device)

    global pipe
    pipe = get_sd_pipeline().to(device)

    demo = gr.Interface(
        fn=generate_scene,
        inputs=[
            gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
            gr.Textbox(label="Scene Hallucination Prompt")
        ],
        outputs=gr.Model3D(label="Generated Scene"),
        allow_flagging="never",
        title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
        description="Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](https://arxiv.org/abs/2404.19758) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes).",
        article="Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you to get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>",
        examples=[
            ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
            ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
            ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
            ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
            ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
            ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
            ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
        ])
    demo.queue().launch(share=True)