Spaces:
Runtime error
Runtime error
amankishore
commited on
Commit
·
10528ca
1
Parent(s):
7a11626
app with gradio
Browse files
app.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1 |
import numpy as np
|
|
|
|
|
2 |
import torch
|
|
|
3 |
|
4 |
from my.utils import tqdm
|
5 |
from my.utils.seed import seed_everything
|
@@ -12,33 +15,12 @@ from run_nerf import VoxConfig
|
|
12 |
from voxnerf.utils import every
|
13 |
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
|
14 |
|
15 |
-
from run_sjc import render_one_view
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
@torch.no_grad()
|
20 |
-
def evaluate(score_model, vox, poser):
|
21 |
-
H, W = poser.H, poser.W
|
22 |
-
vox.eval()
|
23 |
-
K, poses = poser.sample_test(100)
|
24 |
-
|
25 |
-
aabb = vox.aabb.T.cpu().numpy()
|
26 |
-
vox = vox.to(device_glb)
|
27 |
-
|
28 |
-
num_imgs = len(poses)
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
pose = poses[i]
|
33 |
-
y, depth = render_one_view(vox, aabb, H, W, K, pose)
|
34 |
-
if isinstance(score_model, StableDiffusion):
|
35 |
-
y = score_model.decode(y)
|
36 |
-
pane, img, depth = vis_routine(y, depth)
|
37 |
-
|
38 |
-
# metric.put_artifact(
|
39 |
-
# "view_seq", ".mp4",
|
40 |
-
# lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "view")[1])
|
41 |
-
# )
|
42 |
|
43 |
def vis_routine(y, depth):
|
44 |
pane = nerf_vis(y, depth, final_H=256)
|
@@ -46,110 +28,193 @@ def vis_routine(y, depth):
|
|
46 |
depth = depth.cpu().numpy()
|
47 |
return pane, im, depth
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
grad = (Ds - y) / chosen_σs
|
113 |
else:
|
114 |
-
|
115 |
-
|
116 |
-
grad = grad.mean(0, keepdim=True)
|
117 |
-
|
118 |
-
y.backward(-grad, retain_graph=True)
|
119 |
|
120 |
-
|
121 |
-
center_depth = depth[7:-7, 7:-7]
|
122 |
-
border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
|
123 |
-
center_depth_mean = center_depth.mean()
|
124 |
-
depth_diff = center_depth_mean - border_depth_mean
|
125 |
-
depth_loss = - torch.log(depth_diff + 1e-12)
|
126 |
-
depth_loss = depth_weight * depth_loss
|
127 |
-
depth_loss.backward(retain_graph=True)
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
opt.step()
|
136 |
|
|
|
137 |
|
138 |
-
|
|
|
139 |
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
pane, img, depth = vis_routine(y, depth)
|
145 |
|
146 |
-
|
|
|
|
|
147 |
|
148 |
-
|
149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
|
151 |
-
|
152 |
-
ckpt = vox.state_dict()
|
153 |
-
# evaluate(model, vox, poser)
|
154 |
|
155 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import numpy as np
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
import torch
|
5 |
+
import imageio
|
6 |
|
7 |
from my.utils import tqdm
|
8 |
from my.utils.seed import seed_everything
|
|
|
15 |
from voxnerf.utils import every
|
16 |
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
|
17 |
|
18 |
+
from run_sjc import render_one_view, tsr_stats
|
19 |
|
20 |
+
import gradio as gr
|
21 |
+
import gc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
device_glb = torch.device("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def vis_routine(y, depth):
|
26 |
pane = nerf_vis(y, depth, final_H=256)
|
|
|
28 |
depth = depth.cpu().numpy()
|
29 |
return pane, im, depth
|
30 |
|
31 |
+
with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
|
32 |
+
# title
|
33 |
+
gr.Markdown('[Score Jacobian Chaining](https://github.com/pals-ttic/sjc) Lifting Pretrained 2D Diffusion Models for 3D Generation')
|
34 |
+
|
35 |
+
# inputs
|
36 |
+
prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
|
37 |
+
iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=10000, step=100)
|
38 |
+
seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
|
39 |
+
button = gr.Button('Generate')
|
40 |
+
|
41 |
+
# outputs
|
42 |
+
image = gr.Image(label="image", visible=True)
|
43 |
+
depth = gr.Image(label="depth", visible=True)
|
44 |
+
video = gr.Video(label="video", visible=False)
|
45 |
+
logs = gr.Textbox(label="logging")
|
46 |
+
|
47 |
+
def submit(prompt, iters, seed):
|
48 |
+
start_t = time.time()
|
49 |
+
seed_everything(seed)
|
50 |
+
# cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
|
51 |
+
pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
|
52 |
+
poser = pose.make()
|
53 |
+
sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
|
54 |
+
model = sd_model.make()
|
55 |
+
vox = VoxConfig(
|
56 |
+
model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
|
57 |
+
blend_bg_texture=True, bg_texture_hw=4,
|
58 |
+
bbox_len=1.0)
|
59 |
+
vox = vox.make()
|
60 |
+
|
61 |
+
lr = 0.05
|
62 |
+
n_steps = iters
|
63 |
+
emptiness_scale = 10
|
64 |
+
emptiness_weight = 10000
|
65 |
+
emptiness_step = 0.5
|
66 |
+
emptiness_multiplier = 20.0
|
67 |
+
depth_weight = 0
|
68 |
+
var_red = True
|
69 |
+
|
70 |
+
assert model.samps_centered()
|
71 |
+
_, target_H, target_W = model.data_shape()
|
72 |
+
bs = 1
|
73 |
+
aabb = vox.aabb.T.cpu().numpy()
|
74 |
+
vox = vox.to(device_glb)
|
75 |
+
opt = torch.optim.Adamax(vox.opt_params(), lr=lr)
|
76 |
+
|
77 |
+
H, W = poser.H, poser.W
|
78 |
+
Ks, poses, prompt_prefixes = poser.sample_train(n_steps)
|
79 |
+
|
80 |
+
ts = model.us[30:-10]
|
81 |
+
|
82 |
+
same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)
|
83 |
+
|
84 |
+
with tqdm(total=n_steps) as pbar:
|
85 |
+
for i in range(n_steps):
|
86 |
+
|
87 |
+
p = f"{prompt_prefixes[i]} {model.prompt}"
|
88 |
+
score_conds = model.prompts_emb([p])
|
89 |
+
|
90 |
+
y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)
|
91 |
+
|
92 |
+
if isinstance(model, StableDiffusion):
|
93 |
+
pass
|
|
|
94 |
else:
|
95 |
+
y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')
|
|
|
|
|
|
|
|
|
96 |
|
97 |
+
opt.zero_grad()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
+
with torch.no_grad():
|
100 |
+
chosen_σs = np.random.choice(ts, bs, replace=False)
|
101 |
+
chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
|
102 |
+
chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)
|
103 |
+
# chosen_σs = us[i]
|
|
|
|
|
104 |
|
105 |
+
noise = torch.randn(bs, *y.shape[1:], device=model.device)
|
106 |
|
107 |
+
zs = y + chosen_σs * noise
|
108 |
+
Ds = model.denoise(zs, chosen_σs, **score_conds)
|
109 |
|
110 |
+
if var_red:
|
111 |
+
grad = (Ds - y) / chosen_σs
|
112 |
+
else:
|
113 |
+
grad = (Ds - zs) / chosen_σs
|
|
|
114 |
|
115 |
+
grad = grad.mean(0, keepdim=True)
|
116 |
+
|
117 |
+
y.backward(-grad, retain_graph=True)
|
118 |
|
119 |
+
if depth_weight > 0:
|
120 |
+
center_depth = depth[7:-7, 7:-7]
|
121 |
+
border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
|
122 |
+
center_depth_mean = center_depth.mean()
|
123 |
+
depth_diff = center_depth_mean - border_depth_mean
|
124 |
+
depth_loss = - torch.log(depth_diff + 1e-12)
|
125 |
+
depth_loss = depth_weight * depth_loss
|
126 |
+
depth_loss.backward(retain_graph=True)
|
127 |
+
|
128 |
+
emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
|
129 |
+
emptiness_loss = emptiness_weight * emptiness_loss
|
130 |
+
if emptiness_step * n_steps <= i:
|
131 |
+
emptiness_loss *= emptiness_multiplier
|
132 |
+
emptiness_loss.backward()
|
133 |
+
|
134 |
+
opt.step()
|
135 |
+
|
136 |
+
|
137 |
+
# metric.put_scalars()
|
138 |
+
|
139 |
+
if every(pbar, percent=1):
|
140 |
+
with torch.no_grad():
|
141 |
+
if isinstance(model, StableDiffusion):
|
142 |
+
y = model.decode(y)
|
143 |
+
pane, img, depth = vis_routine(y, depth)
|
144 |
+
|
145 |
+
# TODO: Output pane, img and depth to Gradio
|
146 |
+
|
147 |
+
pbar.update()
|
148 |
+
pbar.set_description(p)
|
149 |
+
|
150 |
+
yield {
|
151 |
+
image: gr.update(value=img, visible=True),
|
152 |
+
depth: gr.update(value=depth, visible=True),
|
153 |
+
video: gr.update(visible=False),
|
154 |
+
logs: str(tsr_stats(y)),
|
155 |
+
}
|
156 |
+
|
157 |
+
# TODO: Save Checkpoint
|
158 |
+
ckpt = vox.state_dict()
|
159 |
+
H, W = poser.H, poser.W
|
160 |
+
vox.eval()
|
161 |
+
K, poses = poser.sample_test(100)
|
162 |
+
|
163 |
+
aabb = vox.aabb.T.cpu().numpy()
|
164 |
+
vox = vox.to(device_glb)
|
165 |
+
|
166 |
+
num_imgs = len(poses)
|
167 |
+
|
168 |
+
for i in (pbar := tqdm(range(num_imgs))):
|
169 |
+
|
170 |
+
pose = poses[i]
|
171 |
+
y, depth = render_one_view(vox, aabb, H, W, K, pose)
|
172 |
+
if isinstance(model, StableDiffusion):
|
173 |
+
y = model.decode(y)
|
174 |
+
pane, img, depth = vis_routine(y, depth)
|
175 |
+
|
176 |
+
# Save img to output
|
177 |
+
img.save(f"output/{i}.png")
|
178 |
+
|
179 |
+
yield {
|
180 |
+
image: gr.update(value=img, visible=True),
|
181 |
+
depth: gr.update(value=depth, visible=True),
|
182 |
+
video: gr.update(visible=False),
|
183 |
+
logs: str(tsr_stats(y)),
|
184 |
+
}
|
185 |
+
|
186 |
+
output_video = "view_seq.mp4"
|
187 |
+
|
188 |
+
def export_movie(seqs, fname, fps=30):
|
189 |
+
fname = Path(fname)
|
190 |
+
if fname.suffix == "":
|
191 |
+
fname = fname.with_suffix(".mp4")
|
192 |
+
writer = imageio.get_writer(fname, fps=fps)
|
193 |
+
for img in seqs:
|
194 |
+
writer.append_data(img)
|
195 |
+
writer.close()
|
196 |
+
|
197 |
+
def stitch_vis(save_fn, img_fnames, fps=10):
|
198 |
+
figs = [imageio.imread(fn) for fn in img_fnames]
|
199 |
+
export_movie(figs, save_fn, fps)
|
200 |
+
|
201 |
+
stitch_vis(output_video, [f"output/{i}.png" for i in range(num_imgs)])
|
202 |
|
203 |
+
end_t = time.time()
|
|
|
|
|
204 |
|
205 |
+
yield {
|
206 |
+
image: gr.update(value=img, visible=False),
|
207 |
+
depth: gr.update(value=depth, visible=False),
|
208 |
+
video: gr.update(value=output_video, visible=True),
|
209 |
+
logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
|
210 |
+
}
|
211 |
+
|
212 |
+
button.click(
|
213 |
+
submit,
|
214 |
+
[prompt, iters, seed],
|
215 |
+
[image, depth, video, logs]
|
216 |
+
)
|
217 |
+
|
218 |
+
# concurrency_count: only allow ONE running progress, else GPU will OOM.
|
219 |
+
demo.queue(concurrency_count=1)
|
220 |
+
demo.launch()
|