import spaces
import os
# this is a HF Spaces specific hack, as
# (i) building pytorch3d with GPU support is a bit tricky here
# (ii) installing the wheel via requirements.txt breaks ZeroGPU
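# note: the wheel URL below pins a pytorch3d build for Python 3.10, CUDA 12.1, and PyTorch 2.2.1
# (py310_cu121_pyt221); it needs to match the runtime the Space provides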
os.system("pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt221/download.html")
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import skimage
from PIL import Image
import gradio as gr
from utils.render import PointsRendererWithMasks, render
from utils.ops import snap_high_gradients_to_nn, project_points, get_pointcloud, merge_pointclouds, outpaint_with_depth_estimation
from utils.gs import gs_options, read_cameras_from_optimization_bundle, Scene, run_gaussian_splatting, get_blank_gs_bundle
from pytorch3d.utils import opencv_from_cameras_projection
from utils.ops import focal2fov, fov2focal
from utils.models import infer_with_zoe_dc
from utils.scene import GaussianModel
from utils.demo import downsample_point_cloud
from typing import Iterable, Tuple, Dict, Optional
import itertools
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    PerspectiveCameras,
)
from pytorch3d.io import IO
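

# a blank optimization bundle: image size, the horizontal FoV implied by a focal length equal to the
# image width, and empty slots for the point cloud and its frames (this local definition takes
# precedence over the identically named import from utils.gs above)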
def get_blank_gs_bundle(h, w):
    return {
        "camera_angle_x": focal2fov(torch.tensor([w], dtype=torch.float32), w),
        "W": w,
        "H": h,
        "pcd_points": None,
        "pcd_colors": None,
        "frames": [],
    }
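

# Core routine of the demo: for each requested viewpoint, render the current point cloud,
# outpaint the holes with the Stable Diffusion pipeline (`pipe`), align depth for the new
# pixels with the depth-completion ZoeDepth model (`zoe_dc_model`), and merge the lifted
# pixels back into the shared point cloud. With dry_run=True, views are only rendered and
# recorded as frames; nothing is outpainted and the point cloud is left untouched.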
@spaces.GPU(duration=30)
def extrapolate_point_cloud(prompt: str, image_size: Tuple[int, int],
                            look_at_params: Iterable[Tuple[float, float, float, Tuple[float, float, float]]],
                            point_cloud: Optional[Pointclouds] = None, dry_run: bool = False,
                            discard_mask: bool = False, initial_image: Optional[Image.Image] = None,
                            depth_scaling: float = 1, **render_kwargs):
    w, h = image_size

    optimization_bundle_frames = []

    for azim, elev, dist, at in look_at_params:
        R, T = look_at_view_transform(device=device, azim=azim, elev=elev, dist=dist, at=at)
        cameras = PerspectiveCameras(R=R, T=T, focal_length=torch.tensor([w], dtype=torch.float32),
                                     principal_point=(((h-1)/2, (w-1)/2),), image_size=(image_size,),
                                     device=device, in_ndc=False)

        if point_cloud is not None:
            images, masks, depths = render(cameras, point_cloud, **render_kwargs)

            if not dry_run:
                eroded_mask = skimage.morphology.binary_erosion((depths[0] > 0).cpu().numpy(), footprint=None)  # alternative footprint: skimage.morphology.disk(1)
                eroded_depth = depths[0].clone()
                eroded_depth[torch.from_numpy(eroded_mask).to(depths.device) <= 0] = 0

                outpainted_img, aligned_depth = outpaint_with_depth_estimation(
                    images[0], masks[0], eroded_depth, h, w, pipe, zoe_dc_model, prompt, cameras,
                    dilation_size=2, depth_scaling=depth_scaling,
                    generator=torch.Generator(device=pipe.device).manual_seed(0))
                aligned_depth = torch.from_numpy(aligned_depth).to(device)
            else:
                # in a dry run, we do not actually outpaint the image
                outpainted_img = Image.fromarray((255*images[0].cpu().numpy()).astype(np.uint8))
        else:
            assert initial_image is not None
            assert not dry_run

            # jumpstart the point cloud with a regular depth estimation
            t_initial_image = torch.from_numpy(np.asarray(initial_image)/255.).permute(2, 0, 1).float()
            depth = aligned_depth = infer_with_zoe_dc(zoe_dc_model, t_initial_image, torch.zeros(h, w))
            outpainted_img = initial_image

            images = [t_initial_image.to(device)]
            masks = [torch.ones(h, w, dtype=torch.bool).to(device)]

        if not dry_run:
            # snap high gradients to nearest neighbor, which eliminates noodle artifacts
            aligned_depth = snap_high_gradients_to_nn(aligned_depth, threshold=12).cpu()
            xy_depth_world = project_points(cameras, aligned_depth)

        c2w = cameras.get_world_to_view_transform().get_matrix()[0]

        optimization_bundle_frames.append({
            "image": outpainted_img,
            "mask": masks[0].cpu().numpy(),
            "transform_matrix": c2w.tolist(),
            "azim": azim,
            "elev": elev,
            "dist": dist,
        })

        if discard_mask:
            optimization_bundle_frames[-1].pop("mask")

        if not dry_run:
            optimization_bundle_frames[-1]["center_point"] = xy_depth_world[0].mean(dim=0).tolist()
            optimization_bundle_frames[-1]["depth"] = aligned_depth.cpu().numpy()
            optimization_bundle_frames[-1]["mean_depth"] = aligned_depth.mean().item()
        else:
            # in a dry run, we do not modify the point cloud
            continue

        rgb = (torch.from_numpy(np.asarray(outpainted_img).copy()).reshape(-1, 3).float() / 255).to(device)

        if point_cloud is None:
            point_cloud = get_pointcloud(xy_depth_world[0], device=device, features=rgb)
        else:
            # pytorch3d's mask might be slightly too big (subpixels), so we erode it a little to avoid seams
            # in theory, 1 pixel is sufficient but we use 2 to be safe
            masks[0] = torch.from_numpy(skimage.morphology.binary_erosion(masks[0].cpu().numpy(), footprint=skimage.morphology.disk(2))).to(device)

            partial_outpainted_point_cloud = get_pointcloud(xy_depth_world[0][~masks[0].view(-1)], device=device, features=rgb[~masks[0].view(-1)])

            point_cloud = merge_pointclouds([point_cloud, partial_outpainted_point_cloud])

    return optimization_bundle_frames, point_cloud
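

# Builds the initial scene from a single image: the input view plus two outpainted views at
# azimuth offsets of +/-25 degrees.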
@spaces.GPU(duration=30)
def generate_point_cloud(initial_image: Image.Image, prompt: str):
    image_size = initial_image.size
    w, h = image_size

    optimization_bundle = get_blank_gs_bundle(h, w)

    step_size = 25
    azim_steps = [0, step_size, -step_size]
    look_at_params = [(azim, 0, 0.01, torch.zeros((1, 3))) for azim in azim_steps]

    optimization_bundle["frames"], point_cloud = extrapolate_point_cloud(
        prompt, image_size, look_at_params, discard_mask=True, initial_image=initial_image,
        depth_scaling=0.5, fill_point_cloud_holes=True)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud
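

# Renders jittered "supporting" views around every primary frame (dry runs only, so nothing is
# inpainted). The demo's generate_scene below does not call this function, to stay within its
# time budget.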
@spaces.GPU(duration=30)
def supplement_point_cloud(optimization_bundle: Dict, point_cloud: Pointclouds, prompt: str):
    w, h = optimization_bundle["W"], optimization_bundle["H"]

    supporting_frames = []

    for i, frame in enumerate(optimization_bundle["frames"]):
        # skip supporting views
        if frame.get("supporting", False):
            continue

        center_point = torch.tensor(frame["center_point"]).to(device)
        mean_depth = frame["mean_depth"]
        azim, elev = frame["azim"], frame["elev"]

        azim_jitters = torch.linspace(-5, 5, 3).tolist()
        elev_jitters = torch.linspace(-5, 5, 3).tolist()

        # build the product of azim and elev jitters
        camera_jitters = [{"azim": azim + azim_jitter, "elev": elev + elev_jitter}
                          for azim_jitter, elev_jitter in itertools.product(azim_jitters, elev_jitters)]
        look_at_params = [(camera_jitter["azim"], camera_jitter["elev"], mean_depth, center_point.unsqueeze(0))
                          for camera_jitter in camera_jitters]

        local_supporting_frames, point_cloud = extrapolate_point_cloud(
            prompt, (w, h), look_at_params, point_cloud, dry_run=True, depth_scaling=0.5, antialiasing=3)

        for local_supporting_frame in local_supporting_frames:
            local_supporting_frame["supporting"] = True

        supporting_frames.extend(local_supporting_frames)

    # attach the supporting views to the bundle so a downstream optimization can use them
    optimization_bundle["frames"].extend(supporting_frames)

    optimization_bundle["pcd_points"] = point_cloud.points_padded()[0].cpu().numpy()
    optimization_bundle["pcd_colors"] = point_cloud.features_padded()[0].cpu().numpy()

    return optimization_bundle, point_cloud
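

# Gradio entry point: takes an image and a prompt, returns the path to a Gaussian splat .ply file.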
@spaces.GPU(duration=30)
def generate_scene(img: Image.Image, prompt: str):
    assert isinstance(img, Image.Image)

    # resize the image in place, keeping the aspect ratio, so its longest side is at most 720 pixels
    max_size = 720
    img.thumbnail((max_size, max_size))

    # crop to ensure the image dimensions are divisible by 8
    img = img.crop((0, 0, img.width - img.width % 8, img.height - img.height % 8))

    gs_optimization_bundle, point_cloud = generate_point_cloud(img, prompt)

    downsampled_point_cloud = downsample_point_cloud(gs_optimization_bundle, device=device)
    gs_optimization_bundle["pcd_points"] = downsampled_point_cloud.points_padded()[0].cpu().numpy()
    gs_optimization_bundle["pcd_colors"] = downsampled_point_cloud.features_padded()[0].cpu().numpy()

    scene = Scene(gs_optimization_bundle, GaussianModel(gs_options.sh_degree), gs_options)
    scene.gaussians._opacity = torch.ones_like(scene.gaussians._opacity)

    # skipped in this demo to keep it snappy:
    # scene = run_gaussian_splatting(scene, gs_optimization_bundle)

    # coordinate system transformation: negate the Y and Z axes
    scene.gaussians._xyz = scene.gaussians._xyz.detach()
    scene.gaussians._xyz[:, 1] = -scene.gaussians._xyz[:, 1]
    scene.gaussians._xyz[:, 2] = -scene.gaussians._xyz[:, 2]

    save_path = "./output.ply"
    scene.gaussians.save_ply(save_path)

    return save_path
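

# Module-level setup: the functions above read `device`, `zoe_dc_model`, and `pipe` as globals,
# so they are created here before the Gradio demo is launched.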
if __name__ == "__main__":
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    from utils.models import get_zoe_dc_model, get_sd_pipeline
    from huggingface_hub import hf_hub_download

    global zoe_dc_model
    zoe_dc_model = get_zoe_dc_model(ckpt_path=hf_hub_download(repo_id="paulengstler/invisible-stitch", filename="invisible-stitch.pt")).to(device)

    global pipe
    pipe = get_sd_pipeline().to(device)
    demo = gr.Interface(
        fn=generate_scene,
        inputs=[
            gr.Image(label="Input Image", sources=["upload", "clipboard"], type="pil"),
            gr.Textbox(label="Scene Hallucination Prompt"),
        ],
        outputs=gr.Model3D(label="Generated Scene"),
        allow_flagging="never",
        title="Invisible Stitch: Generating Smooth 3D Scenes with Depth Inpainting",
        description="Hallucinate geometrically coherent 3D scenes from a single input image in less than 30 seconds.<br /> [Project Page](https://research.paulengstler.com/invisible-stitch) | [GitHub](https://github.com/paulengstler/invisible-stitch) | [Paper](https://arxiv.org/abs/2404.19758) <br /><br />To keep this demo snappy, we have limited its functionality. Scenes are generated at a low resolution without densification, supporting views are not inpainted, and we do not optimize the resulting point cloud. Imperfections are to be expected, in particular around object borders. Please allow a couple of seconds for the generated scene to be downloaded (about 40 megabytes).",
        article="Please consider running this demo locally to obtain high-quality results (see the GitHub repository).<br /><br />Here are some observations we made that might help you to get better results:<ul><li>Use generic prompts that match the surroundings of your input image.</li><li>Ensure that the borders of your input image are free from partially visible objects.</li><li>Keep your prompts simple and avoid adding specific details.</li></ul>",
        examples=[
            ["examples/photo-1667788000333-4e36f948de9a.jpeg", "a street with traditional buildings in Kyoto, Japan"],
            ["examples/photo-1628624747186-a941c476b7ef.jpeg", "a suburban street in North Carolina on a bright, sunny day"],
            ["examples/photo-1469559845082-95b66baaf023.jpeg", "a view of Zion National Park"],
            ["examples/photo-1514984879728-be0aff75a6e8.jpeg", "a close-up view of a muddy path in a forest"],
            ["examples/photo-1618197345638-d2df92b39fe1.jpeg", "a close-up view of a white linen bed in a minimalistic room"],
            ["examples/photo-1546975490-e8b92a360b24.jpeg", "a warm living room with plants"],
            ["examples/photo-1499916078039-922301b0eb9b.jpeg", "a cozy bedroom on a bright day"],
        ])

    demo.queue().launch(share=True)