Stanislaw Szymanowicz commited on
Commit
e10da38
·
1 Parent(s): 4aa5114

Add model and app file

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +178 -4
  3. model_file/objaverse/.hydra/config.yaml +66 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ */__pycache__
app.py CHANGED
@@ -1,7 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+
5
+ import os
6
+ from omegaconf import OmegaConf
7
+ from PIL import Image
8
+
9
+ from utils.app_utils import (
10
+ remove_background,
11
+ resize_foreground,
12
+ set_white_background,
13
+ resize_to_128,
14
+ to_tensor,
15
+ get_source_camera_v2w_rmo_and_quats,
16
+ get_target_cameras,
17
+ export_to_obj)
18
+
19
+ import imageio
20
+
21
+ from scene.gaussian_predictor import GaussianSplatPredictor
22
+ from gaussian_renderer import render_predicted
23
+
24
  import gradio as gr
25
 
26
+ import rembg
27
+
28
+ def main():
29
+
30
+ # ============= model loading ==========
31
+ def load_model(device):
32
+ experiment_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
33
+ "model_file", "objaverse")
34
+ # load cfg
35
+ training_cfg = OmegaConf.load(os.path.join(experiment_path, ".hydra", "config.yaml"))
36
+ # load model
37
+ model = GaussianSplatPredictor(training_cfg)
38
+ ckpt_loaded = torch.load(os.path.join(experiment_path, "model_latest.pth"), map_location=device)
39
+ model.load_state_dict(ckpt_loaded["model_state_dict"])
40
+ return model, training_cfg
41
+
42
+ if torch.cuda.is_available():
43
+ device = "cuda:0"
44
+ else:
45
+ device = "cpu"
46
+ torch.cuda.set_device(device)
47
+
48
+ model, model_cfg = load_model(device)
49
+ model.to(device)
50
+
51
+ # ============= image preprocessing =============
52
+ rembg_session = rembg.new_session()
53
+
54
+ def check_input_image(input_image):
55
+ if input_image is None:
56
+ raise gr.Error("No image uploaded!")
57
+
58
+ def preprocess(input_image, preprocess_background=True, foreground_ratio=0.65):
59
+ # 0.7 seems to be a reasonable foreground ratio
60
+ if preprocess_background:
61
+ image = input_image.convert("RGB")
62
+ image = remove_background(image, rembg_session)
63
+ image = resize_foreground(image, foreground_ratio)
64
+ image = set_white_background(image)
65
+ else:
66
+ image = input_image
67
+ if image.mode == "RGBA":
68
+ image = set_white_background(image)
69
+ image = resize_to_128(image)
70
+ return image
71
+
72
+ ply_out_path="/users/stan/splatter-image/gradio_out/mesh.ply"
73
+ os.makedirs(os.path.dirname(ply_out_path), exist_ok=True)
74
+
75
+ def reconstruct_and_export(image):
76
+ """
77
+ Passes image through model, outputs reconstruction in form of a dict of tensors.
78
+ """
79
+ image = to_tensor(image).to(device)
80
+ view_to_world_source, rot_transform_quats = get_source_camera_v2w_rmo_and_quats()
81
+ view_to_world_source = view_to_world_source.to(device)
82
+ rot_transform_quats = rot_transform_quats.to(device)
83
+
84
+ reconstruction_unactivated = model(
85
+ image.unsqueeze(0).unsqueeze(0),
86
+ view_to_world_source,
87
+ rot_transform_quats,
88
+ None,
89
+ activate_output=False)
90
+
91
+ reconstruction = {k: v[0].contiguous() for k, v in reconstruction_unactivated.items()}
92
+ reconstruction["scaling"] = model.scaling_activation(reconstruction["scaling"])
93
+ reconstruction["opacity"] = model.opacity_activation(reconstruction["opacity"])
94
+
95
+ # render images in a loop
96
+ world_view_transforms, full_proj_transforms, camera_centers = get_target_cameras()
97
+ background = torch.tensor([1, 1, 1] , dtype=torch.float32, device=device)
98
+ loop_renders = []
99
+ t_to_512 = torchvision.transforms.Resize(512, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
100
+ for r_idx in range( world_view_transforms.shape[0]):
101
+ image = render_predicted(reconstruction,
102
+ world_view_transforms[r_idx].to(device),
103
+ full_proj_transforms[r_idx].to(device),
104
+ camera_centers[r_idx].to(device),
105
+ background,
106
+ model_cfg,
107
+ focals_pixels=None)["render"]
108
+ image = t_to_512(image)
109
+ loop_renders.append(torch.clamp(image * 255, 0.0, 255.0).detach().permute(1, 2, 0).cpu().numpy().astype(np.uint8))
110
+ loop_out_path = os.path.join(os.path.dirname(ply_out_path), "loop.mp4")
111
+ imageio.mimsave(loop_out_path, loop_renders, fps=25)
112
+ # export reconstruction to ply
113
+ export_to_obj(reconstruction_unactivated, ply_out_path)
114
+
115
+ return loop_out_path, ply_out_path
116
+
117
+ with gr.Blocks() as demo:
118
+ gr.Markdown(
119
+ """
120
+
121
+ # Splatter Image Demo
122
+ [Splatter Image](https://github.com/szymanowiczs/splatter-image) (CVPR 2024) is a fast, super cheap to train method for object 3D reconstruction from a single image.
123
+ The model used in the demo was trained on **Objaverse-LVIS on 2 A6000 GPUs for 3.5 days**.
124
+ On NVIDIA V100 GPU, reconstruction can be done at 38FPS and rendering at 588FPS.
125
+ Upload an image of an object to see how the Splatter Image does.
126
+
127
+ **Comments:**
128
+ 1. The first example you upload should take about 4.5 seconds (with preprocessing, saving and overhead), the following take about 1.5s.
129
+ 2. The model does not work well on photos of humans.
130
+ 3. The 3D viewer shows a .ply mesh extracted from a mix of 3D Gaussians. Artefacts might show - see video for more faithful results.
131
+ 4. Best results are achieved on the datasets described in the [repository](https://github.com/szymanowiczs/splatter-image) using that code. This demo is experimental.
132
+ 5. Our model might not be better than some state-of-the-art methods, but it is of comparable quality and is **much** cheaper to train and run.
133
+ """
134
+ )
135
+ with gr.Row(variant="panel"):
136
+ with gr.Column():
137
+ with gr.Row():
138
+ input_image = gr.Image(
139
+ label="Input Image",
140
+ image_mode="RGBA",
141
+ sources="upload",
142
+ type="pil",
143
+ elem_id="content_image",
144
+ )
145
+ processed_image = gr.Image(label="Processed Image", interactive=False)
146
+ with gr.Row():
147
+ with gr.Group():
148
+ preprocess_background = gr.Checkbox(
149
+ label="Remove Background", value=True
150
+ )
151
+ with gr.Row():
152
+ submit = gr.Button("Generate", elem_id="generate", variant="primary")
153
+ with gr.Column():
154
+ with gr.Row():
155
+ with gr.Tab("Reconstruction"):
156
+ with gr.Column():
157
+ output_video = gr.Video(value=None, width=512, label="Rendered Video", autoplay=True)
158
+ output_model = gr.Model3D(
159
+ height=512,
160
+ label="Output Model",
161
+ interactive=False
162
+ )
163
+
164
+ submit.click(fn=check_input_image, inputs=[input_image]).success(
165
+ fn=preprocess,
166
+ inputs=[input_image, preprocess_background],
167
+ outputs=[processed_image],
168
+ ).success(
169
+ fn=reconstruct_and_export,
170
+ inputs=[processed_image],
171
+ outputs=[output_video, output_model],
172
+ )
173
+
174
+ demo.queue(max_size=1)
175
+ demo.launch()
176
+
177
+
178
+ if __name__ == "__main__":
179
+ main()
180
 
181
+ # gradio app interface
 
model_file/objaverse/.hydra/config.yaml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb:
2
+ project: gs_pred
3
+ cam_embd:
4
+ embedding: null
5
+ encode_embedding: null
6
+ dimension: 0
7
+ method: null
8
+ general:
9
+ device: 0
10
+ random_seed: 0
11
+ num_devices: 2
12
+ mixed_precision: true
13
+ data:
14
+ training_resolution: 128
15
+ fov: 49.134342641202636
16
+ subset: -1
17
+ input_images: 1
18
+ znear: 0.8
19
+ zfar: 3.2
20
+ category: objaverse
21
+ white_background: true
22
+ origin_distances: false
23
+ opt:
24
+ iterations: 50001
25
+ base_lr: 6.34584421e-05
26
+ batch_size: 16
27
+ betas:
28
+ - 0.9
29
+ - 0.999
30
+ loss: l2
31
+ imgs_per_obj: 4
32
+ ema:
33
+ use: true
34
+ update_every: 10
35
+ update_after_step: 100
36
+ beta: 0.9999
37
+ lambda_lpips: 0.33814373
38
+ start_lpips_after: 0
39
+ step_lr_at: -1
40
+ model:
41
+ max_sh_degree: 1
42
+ inverted_x: false
43
+ inverted_y: true
44
+ name: SingleUNet
45
+ opacity_scale: 1.0
46
+ opacity_bias: -2.0
47
+ scale_scale: 0.01
48
+ scale_bias: 0.02
49
+ xyz_scale: 0.1
50
+ xyz_bias: 0.0
51
+ depth_scale: 1.0
52
+ depth_bias: 0.0
53
+ network_without_offset: false
54
+ network_with_offset: true
55
+ attention_resolutions:
56
+ - 16
57
+ cross_view_attention: true
58
+ isotropic: false
59
+ base_dim: 128
60
+ num_blocks: 4
61
+ logging:
62
+ ckpt_iterations: 1000
63
+ val_log: 10000
64
+ loss_log: 10
65
+ loop_log: 10000
66
+ render_log: 10000