codestella committed
Commit 97ec1af
1 Parent: 3c9f729
code change
- .gitattributes +0 -17
- LICENSE +0 -0
- __init__.py +0 -0
- assets/lego-nerf.gif +0 -0
- configs/blender.yaml +0 -0
- configs/demo.yaml +0 -0
- configs/diet_nerf_tpu_vm_4shot.yaml +2 -1
- configs/diet_nerf_tpu_vm_few_shot.yaml +2 -1
- configs/diet_nerf_tpu_vm_test.yaml +3 -2
- configs/eval_diet_nerf_tpu_vm_few_shot.yaml +0 -0
- configs/nerf_tpu_vm_4shot.yaml +0 -0
- configs/nerf_tpu_vm_few_shot.yaml +0 -0
- configs/orig_nerf_tpu_vm_full.yaml +0 -0
- configs/orig_nerf_tpu_vm_test.yaml +0 -0
- eval.py +18 -9
- eval.sh +0 -0
- example_data/imgs/r_0.png +0 -0
- example_data/transforms_test.json +0 -0
- example_data/transforms_train.json +0 -0
- fork-of-first-touch-of-nerf-in-jax.ipynb +0 -0
- nerf/__init__.py +0 -0
- nerf/__pycache__/__init__.cpython-37.pyc +0 -0
- nerf/__pycache__/clip_utils.cpython-37.pyc +0 -0
- nerf/__pycache__/datasets.cpython-37.pyc +0 -0
- nerf/__pycache__/model_utils.cpython-37.pyc +0 -0
- nerf/__pycache__/models.cpython-37.pyc +0 -0
- nerf/__pycache__/utils.cpython-37.pyc +0 -0
- nerf/clip_utils.py +17 -23
- nerf/datasets.py +15 -9
- nerf/model_utils.py +0 -0
- nerf/models.py +2 -3
- nerf/utils.py +4 -2
- requirements.txt +0 -0
- run.sh +0 -0
- train.py +9 -21
- train.sh +0 -0
.gitattributes
DELETED
@@ -1,17 +0,0 @@
-*.bin.* filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tar.gz filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
CHANGED
File without changes

__init__.py
CHANGED
File without changes

assets/lego-nerf.gif
DELETED
Binary file (519 kB)

configs/blender.yaml
CHANGED
File without changes

configs/demo.yaml
CHANGED
File without changes
configs/diet_nerf_tpu_vm_4shot.yaml
CHANGED
@@ -8,8 +8,9 @@ white_bkgd: true
 batch_size: 1024
 randomized: true
 max_steps: 200000
+stop_sc_loss: 160000
 print_every: 100
-render_every:
+render_every: 1000
 save_every: 5000
 use_semantic_loss: true
 clip_model_name: openai/clip-vit-base-patch32
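Note: the two keys added here are consumed by the training loop changed later in this commit. stop_sc_loss caps how long the CLIP semantic-consistency (sc) update keeps firing, and render_every, previously left blank, presumably controls how often an intermediate render is produced during training. A minimal sketch of the gating condition, using the values from this config and the sc_loss_every default (16) from nerf/utils.py:

def should_run_sc_step(step, sc_loss_every=16, use_semantic_loss=True, stop_sc_loss=160000):
    # Mirrors the condition added in train.py in this commit: run the CLIP
    # semantic-consistency update every sc_loss_every steps, but stop once the
    # step counter reaches stop_sc_loss (160000 here, against max_steps 200000).
    return step % sc_loss_every == 0 and use_semantic_loss and step < stop_sc_loss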
configs/diet_nerf_tpu_vm_few_shot.yaml
CHANGED
@@ -8,8 +8,9 @@ white_bkgd: true
 batch_size: 1024
 randomized: true
 max_steps: 200000
+stop_sc_loss: 160000
 print_every: 100
-render_every:
+render_every: 1000
 save_every: 5000
 use_semantic_loss: true
 clip_model_name: openai/clip-vit-base-patch32
configs/diet_nerf_tpu_vm_test.yaml
CHANGED
@@ -2,12 +2,13 @@ dataset: blender
 batching: single_image
 factor: 0
 num_coarse_samples: 64
-num_fine_samples:
+num_fine_samples: 128
 use_viewdirs: true
 white_bkgd: true
-batch_size:
+batch_size: 1024
 randomized: true
 max_steps: 200000
+stop_sc_loss: 160000
 print_every: 100
 render_every: 1000
 save_every: 5000
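Note: num_fine_samples and batch_size previously had no value in this test config; an empty YAML value loads as None, which breaks numeric uses such as the num_fine_samples > 0 check in nerf/models.py. A toy illustration of that parsing behaviour, assuming a standard YAML loader (the repo's actual config loading is not shown in this diff):

import yaml

# An empty value parses to None, and None > 0 raises TypeError in Python 3,
# so filling in 128 / 1024 keeps the fine-sampling branch and batching usable.
cfg = yaml.safe_load("num_fine_samples:\nbatch_size: 1024\n")
assert cfg["num_fine_samples"] is None
assert cfg["batch_size"] == 1024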
configs/eval_diet_nerf_tpu_vm_few_shot.yaml
CHANGED
File without changes

configs/nerf_tpu_vm_4shot.yaml
CHANGED
File without changes

configs/nerf_tpu_vm_few_shot.yaml
CHANGED
File without changes

configs/orig_nerf_tpu_vm_full.yaml
CHANGED
File without changes

configs/orig_nerf_tpu_vm_test.yaml
CHANGED
File without changes
eval.py
CHANGED
@@ -112,30 +112,39 @@ def main(unused_argv):
   summary_writer = tensorboard.SummaryWriter(
       path.join(FLAGS.train_dir, "eval"))
 
-  def generate_spinning_gif(radius, phi,
+  def generate_spinning_gif(radius, phi, output_dir, frame_n):
     _rng = random.PRNGKey(0)
     partial_render_fn = functools.partial(render_pfn, state.optimizer.target)
     gif_images = []
+    gif_images2 = []
     for theta in tqdm(np.linspace(-math.pi, math.pi, frame_n)):
       camtoworld = np.array(clip_utils.pose_spherical(radius, theta, phi))
       rays = dataset.camtoworld_matrix_to_rays(camtoworld, downsample=4)
       _rng, key0, key1 = random.split(_rng, 3)
-      color,
+      color, disp, _ = utils.render_image(partial_render_fn, rays,
                                           _rng, False, chunk=4096)
       image = predict_to_image(color)
+      image2 = predict_to_image(disp[Ellipsis, 0])
       gif_images.append(image)
+      gif_images2.append(image2)
+
+    gif_fn = os.path.join(output_dir, 'rgb_spinning.gif')
+    gif_fn2 = os.path.join(output_dir, 'disp_spinning.gif')
     gif_images[0].save(gif_fn, save_all=True,
                        append_images=gif_images,
                        duration=100, loop=0)
-
+    gif_images2[0].save(gif_fn2, save_all=True,
+                        append_images=gif_images2,
+                        duration=100, loop=0)
+
+    #return gif_images, gif_images2
 
   if FLAGS.generate_gif_only:
     print('generate GIF file only')
     _radius = 4.
     _phi = (30 * math.pi) / 180
-
-
-    print(f'GIF file for spinning views written: {_gif_fn}')
+    generate_spinning_gif(_radius, _phi, out_dir, frame_n=30)
+    print('GIF file for spinning views written)')
     return
   else:
     print('generate GIF file AND evaluate model performance')

@@ -149,6 +158,7 @@ def main(unused_argv):
   utils.makedirs(out_dir)
   psnr_values = []
   ssim_values = []
+
   #lpips_values = []
   if not FLAGS.eval_once:
     showcase_index = np.random.randint(0, dataset.size)

@@ -225,9 +235,8 @@ def main(unused_argv):
     if not is_gif_written:
       _radius = 4.
       _phi = (30 * math.pi) / 180
-
-
-      print(f'GIF file for spinning views written: {_gif_fn}')
+      generate_spinning_gif(_radius, _phi, out_dir, frame_n=30)
+      print(f'GIF file for spinning views written')
       is_gif_written = True
 
     if FLAGS.eval_once:
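Note: the reworked generate_spinning_gif renders both an RGB and a disparity frame for each azimuth angle on a circle at fixed radius and elevation, then writes rgb_spinning.gif and disp_spinning.gif into output_dir. The GIF writing itself follows the standard Pillow pattern; a minimal sketch (the helper name save_gif is illustrative, not repo code):

def save_gif(frames, path, duration_ms=100):
    # frames is a list of PIL.Image objects; saving from the first frame with
    # save_all=True appends the remaining frames, and loop=0 loops forever.
    frames[0].save(path, save_all=True, append_images=frames[1:],
                   duration=duration_ms, loop=0)

The hunk passes the full frame list as append_images, which repeats the first frame once per loop; append_images=frames[1:] is the more conventional call.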
eval.sh
CHANGED
File without changes

example_data/imgs/r_0.png
CHANGED

example_data/transforms_test.json
CHANGED
File without changes

example_data/transforms_train.json
CHANGED
File without changes

fork-of-first-touch-of-nerf-in-jax.ipynb
CHANGED
File without changes

nerf/__init__.py
CHANGED
File without changes

nerf/__pycache__/__init__.cpython-37.pyc
DELETED
Binary file (137 Bytes)

nerf/__pycache__/clip_utils.cpython-37.pyc
DELETED
Binary file (5.16 kB)

nerf/__pycache__/datasets.cpython-37.pyc
DELETED
Binary file (18.3 kB)

nerf/__pycache__/model_utils.cpython-37.pyc
DELETED
Binary file (10 kB)

nerf/__pycache__/models.cpython-37.pyc
DELETED
Binary file (5.08 kB)

nerf/__pycache__/utils.cpython-37.pyc
DELETED
Binary file (15.8 kB)
nerf/clip_utils.py
CHANGED
@@ -15,50 +15,44 @@ FLAGS = flags.FLAGS
 
 @partial(jax.jit, static_argnums=[0])
 def semantic_loss(clip_model, src_image, target_embedding):
-
-    f_image = utils.unshard(src_image[
-
-
-    #c_image = c_image.reshape([w, w, 3])
+    c_image = utils.unshard(src_image[0])
+    f_image = utils.unshard(src_image[1])
+    w = int(math.sqrt(f_image.shape[0]))
+    c_image = c_image.reshape([w, w, 3])
     f_image = f_image.reshape([w, w, 3])
-
-    src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.
-    #src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
+
+    src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image,f_image],0).transpose(0, 3, 1, 2)))
     src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
-    sc_loss =
+    sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
     return sc_loss, f_image
 
 def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
-    random_rays =
-    target_embedding = batch["embedding"]
+    random_rays = batch["random_rays"]
+    target_embedding = batch["embedding"]
     rng, key_0, key_1 = random.split(rng,3)
-
+
     def loss_fn(variables):
-
-        sc_loss, f_image = semantic_loss(clip_model,
+        images = render_pfn(variables, key_0, key_1, random_rays)
+        sc_loss, f_image = semantic_loss(clip_model, images, target_embedding)
         return sc_loss * FLAGS.sc_loss_mult, f_image
     (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
     return sc_loss, grad, src_image
 
 @partial(jax.jit, static_argnums=[0, 1])
 def semantic_step_single(model, clip_model, rng, state, batch, lr):
-
-
-    random_rays = batch["random_rays"]
+    random_rays = jax.tree_map(lambda x: x.reshape(-1,3), batch["random_rays"])
+    target_embedding = batch["embedding"]
     rng, key_0, key_1 = random.split(rng,3)
 
     def semantic_loss(variables):
         c_image, f_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only = True)
-        # reshape flat pixel to an image (assume 3 channels & square shape)
         w = int(math.sqrt(f_image.shape[0]))
-
+        c_image = c_image.reshape([w, w, 3])
         f_image = f_image.reshape([w, w, 3])
 
-        src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.
-        # src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image, f_image]).transpose(0, 3, 1, 2)))
+        src_embedding = clip_model.get_image_features(pixel_values=preprocess_for_CLIP(jnp.stack([c_image,f_image],0).transpose(0, 3, 1, 2)))
         src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
-
-        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding)**2)
+        sc_loss = 2 - jnp.sum(src_embedding * target_embedding)
         return sc_loss * FLAGS.sc_loss_mult, f_image
     (sc_loss, src_image), grad = jax.value_and_grad(semantic_loss, has_aux = True)(jax.device_get(jax.tree_map(lambda x:x[0], state)).optimizer.target)
     return sc_loss, grad, src_image
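Note: the semantic-consistency loss now embeds both the coarse and the fine render with CLIP and switches to a dot-product form, sc_loss = 2 - jnp.sum(src_embedding * target_embedding); semantic_step_multi also actually renders the images inside loss_fn via render_pfn before embedding them. Because the source embeddings are L2-normalized on the preceding line (and assuming the precomputed target embedding is unit-normalized as well), the new loss is exactly the sum of the two cosine distances and, per image, coincides with the older 0.5 * ||e - t||^2 squared-error form. A toy numeric check (not repo code):

import numpy as np

rng = np.random.default_rng(0)
e_c, e_f, t = (v / np.linalg.norm(v) for v in rng.normal(size=(3, 512)))

src = np.stack([e_c, e_f])                        # (2, 512): coarse and fine embeddings
loss_dot = 2.0 - np.sum(src * t)                  # dot-product form used in this hunk
loss_sq = sum(0.5 * np.sum((e - t) ** 2) for e in (e_c, e_f))  # squared-error form
assert np.isclose(loss_dot, loss_sq)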
nerf/datasets.py
CHANGED
@@ -236,6 +236,7 @@ class Blender(Dataset):
     camera_angle_x = float(meta["camera_angle_x"])
     self.focal = .5 * self.w / np.tan(.5 * camera_angle_x)
     self.n_examples = self.images.shape[0]
+    self.dtype = flags.clip_output_dtype
 
     if flags.use_semantic_loss and clip_model is not None:
       embs = []

@@ -258,8 +259,8 @@ class Blender(Dataset):
 
     frames = np.arange(len(meta["frames"]))
     if few_shot > 0 and split == 'train':
-      np.random.seed(0)
-      np.random.shuffle(frames)
+      # np.random.seed(0)
+      # np.random.shuffle(frames)
       frames = frames[:few_shot]
 
     # if split == 'train':

@@ -308,16 +309,21 @@ class Blender(Dataset):
     src_seed = int(time.time())
     src_rng = jax.random.PRNGKey(src_seed)
     src_camtoworld = np.array(clip_utils.random_pose(src_rng, (self.near, self.far)))
-
-    cx = np.random.randint(
-    cy = np.random.randint(
-    d =
-
+
+    cx = np.random.randint(320, 480)
+    cy = np.random.randint(320, 480)
+    d = 140
+
+    random_rays = self.camtoworld_matrix_to_rays(src_camtoworld, downsample = 1)
+    random_rays = jax.tree_map(lambda x: x[cy-d:cy+d:4,cx-d:cx+d:4], random_rays)
+
     w = random_rays[0].shape[0] - random_rays[0].shape[0]%jax.local_device_count()
     random_rays = jax.tree_map(lambda x: x[:w,:w].reshape(-1,3), random_rays)
-    batch_dict["random_rays"] = random_rays
+    batch_dict["random_rays"] = utils.shard(random_rays)
+    if self.dtype == 'float16':
+      batch_dict = jax.tree_map(lambda x: x.astype(np.float16), batch_dict)
     return batch_dict
-
+
 class LLFF(Dataset):
   """LLFF Dataset."""
 
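Note: the random-ray batch used for the semantic loss is now a strided crop of rays rendered from a random pose: cx and cy are drawn from [320, 480), d = 140, and the [cy-d:cy+d:4, cx-d:cx+d:4] slice keeps every 4th ray in a 280x280 window, i.e. a 70x70 patch on the standard 800x800 Blender renders (factor: 0). The patch is then trimmed so it shards evenly across local devices, flattened, sharded, and optionally cast to float16. A rough sketch of the arithmetic, with illustrative values:

import numpy as np

cx, cy, d = 400, 400, 140                 # cx, cy are random in the dataset code
rows = np.arange(cy - d, cy + d, 4)       # 70 row indices, all inside an 800x800 image
cols = np.arange(cx - d, cx + d, 4)       # 70 column indices
assert len(rows) == len(cols) == 70

n_devices = 8                             # e.g. a TPU VM with 8 local devices (illustrative)
w = 70 - 70 % n_devices                   # trim 70 -> 64 so the w x w patch shards evenly
assert (w * w) % n_devices == 0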
nerf/model_utils.py
CHANGED
File without changes
nerf/models.py
CHANGED
@@ -136,7 +136,7 @@ class NerfModel(nn.Module):
         (comp_rgb, disp, acc),
     ]
 
-    if self.num_fine_samples > 0
+    if self.num_fine_samples > 0:
       z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
       key, rng_1 = random.split(rng_1)
 

@@ -191,8 +191,7 @@ class NerfModel(nn.Module):
       )
     ret.append((comp_rgb, disp, acc))
     if rgb_only:
-
-      return [None, ret[0][0]]
+      return [ret[0][0], ret[1][0]]
     return ret
 
 def construct_nerf(key, example_batch, args):
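Note: with rgb_only=True the model now returns the coarse and fine RGB maps instead of [None, coarse RGB], which is what the updated semantic_step_single in nerf/clip_utils.py unpacks as c_image, f_image. A toy illustration of the return shape (not repo code):

# ret holds one (rgb, disp, acc) tuple per sampling level: coarse first, then fine.
ret = [("rgb_coarse", "disp_c", "acc_c"), ("rgb_fine", "disp_f", "acc_f")]

old_rgb_only = [None, ret[0][0]]           # before: [None, coarse rgb]
new_rgb_only = [ret[0][0], ret[1][0]]      # after: [coarse rgb, fine rgb]
assert new_rgb_only == ["rgb_coarse", "rgb_fine"]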
nerf/utils.py
CHANGED
@@ -66,11 +66,11 @@ def define_flags():
   flags.DEFINE_bool("use_semantic_loss", True,
                     "whether use semantic loss or not")
   flags.DEFINE_string("clip_model_name", "openai/clip-vit-base-patch32", "model type for CLIP")
-  flags.DEFINE_string("clip_output_dtype", "
+  flags.DEFINE_string("clip_output_dtype", "float16",
                       "float32/ float16 (float16 for memory saving)")
   flags.DEFINE_integer("sc_loss_every", 16,
                        "no. of steps to take before performing semantic loss evaluation")
-  flags.DEFINE_float("sc_loss_mult", 1e-
+  flags.DEFINE_float("sc_loss_mult", 1e-2,
                      "weighting for semantic loss from CLIP")
 
   # Dataset Flags

@@ -166,6 +166,8 @@ def define_flags():
 
   flags.DEFINE_integer("max_steps", 1000000,
                        "the number of optimization steps.")
+  flags.DEFINE_integer("stop_sc_loss", 1000000,
+                       "the number of sc_loss optimization steps")
   flags.DEFINE_integer("save_every", 10000,
                        "the number of steps to save a checkpoint.")
   flags.DEFINE_integer("print_every", 100,
requirements.txt
CHANGED
File without changes

run.sh
CHANGED
File without changes
train.py
CHANGED
@@ -50,7 +50,6 @@ print(f"detected device: {jax.local_devices()}")
 
 
 def train_step(model, clip_model, rng, state, batch, lr, step, K,):
-  # TODO make clip_grad input enable
   """One optimization step.
 
   Args:

@@ -102,7 +101,6 @@ def train_step(model, clip_model, rng, state, batch, lr, step, K,):
 
   (_, stats), grad = (
       jax.value_and_grad(loss_fn, has_aux=True)(state.optimizer.target))
-  #grad = jax.lax.pmean(grad, axis_name="batch")
   stats = jax.lax.pmean(stats, axis_name="batch")
 
   # Clip the gradient by value.

@@ -238,26 +236,16 @@ def main(unused_argv):
 
     grad, stats, keys = train_pstep(keys, state, batch, lr, step, FLAGS.sc_loss_every)
 
-    if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss:
+    if step%FLAGS.sc_loss_every == 0 and FLAGS.use_semantic_loss and step < FLAGS.stop_sc_loss:
       sc_batch = dataset.get_clip_data()
      if jax.local_device_count() > 1:
        sc_loss, sc_grad, sc_image = clip_utils.semantic_step_multi(render_pfn_, clip_model, keys[0], state, sc_batch, lr)
      else:
        sc_loss, sc_grad, sc_image = clip_utils.semantic_step_single(model, clip_model, keys[0], state, sc_batch, lr)
 
-      if jax.host_id() == 0 and step%FLAGS.print_every:
-        for mlp_k, mlp in grad['params'].items():
-          for layer_k, layer_g in mlp.items():
-            summary_writer.scalar("%s/%s/kernel_grad"%(mlp_k, layer_k), jnp.linalg.norm(jnp.mean(layer_g['kernel'],0)), step)
-        for mlp_k, mlp in sc_grad['params'].items():
-          for layer_k, layer_g in mlp.items():
-            summary_writer.scalar("%s/%s/kernel_sc_grad"%(mlp_k, layer_k), jnp.linalg.norm(layer_g['kernel']), step)
-
      leaves, treedef = jax.tree_flatten(grad)
      sc_leaves, _ = jax.tree_flatten(sc_grad)
      grad = treedef.unflatten(g+jnp.expand_dims(sc_g,0) for g, sc_g in zip(leaves, sc_leaves))
-
-
 
    state = update_pstep(state, grad, lr)
 

@@ -276,24 +264,26 @@ def main(unused_argv):
       summary_writer.scalar("psnr/train", stats.psnr[0], step)
       summary_writer.scalar("train_coarse/loss", stats.loss_c[0], step)
       summary_writer.scalar("train_coarse/psnr", stats.psnr_c[0], step)
-
+
       avg_loss = np.mean(np.concatenate([s.loss for s in stats_trace]))
       avg_psnr = np.mean(np.concatenate([s.psnr for s in stats_trace]))
       stats_trace = []
       summary_writer.scalar("train_avg/loss", avg_loss, step)
       summary_writer.scalar("train_avg/psnr", avg_psnr, step)
-
+
       steps_per_sec = FLAGS.print_every / (time.time() - t_loop_start)
       reset_timer = True
       rays_per_sec = FLAGS.batch_size * steps_per_sec
-      summary_writer.scalar("
-      summary_writer.scalar("
+      summary_writer.scalar("stats/weight_l2", stats.weight_l2[0], step)
+      summary_writer.scalar("stats/learning_rate", lr, step)
+      summary_writer.scalar("iter_speed/train_steps_per_sec", steps_per_sec, step)
+      summary_writer.scalar("iter_speed/train_rays_per_sec", rays_per_sec, step)
       precision = int(np.ceil(np.log10(FLAGS.max_steps))) + 1
       print(("{:" + "{:d}".format(precision) + "d}").format(step) +
             f"/{FLAGS.max_steps:d}: " + f"i_loss={stats.loss[0]:0.4f}, " +
             f"avg_loss={avg_loss:0.4f}, " +
             f"weight_l2={stats.weight_l2[0]:0.2e}, " +
-
+            f"sc_loss={sc_loss:0.4f}, " +
            f"lr={lr:0.2e}, {rays_per_sec:0.0f} rays/sec")
       if step % FLAGS.save_every == 0:
         state_to_save = jax.device_get(jax.tree_map(lambda x: x[0], state))

@@ -324,12 +314,10 @@ def main(unused_argv):
      eval_time = time.time() - t_eval_start
      num_rays = jnp.prod(jnp.array(test_case["rays"].directions.shape[:-1]))
      rays_per_sec = num_rays / eval_time
-      summary_writer.scalar("test_rays_per_sec", rays_per_sec, step)
+      summary_writer.scalar("iter_speed/test_rays_per_sec", rays_per_sec, step)
      print(f"Eval {step}: {eval_time:0.3f}s., {rays_per_sec:0.0f} rays/sec")
      summary_writer.scalar("psnr/test", psnr, step)
-      summary_writer.scalar("test_psnr", psnr, step)
      summary_writer.scalar("ssim/ssim", ssim, step)
-      summary_writer.scalar("test_ssim", ssim, step)
      if sc_image is not None:
        summary_writer .image("random_ray_image", sc_image, step)
      summary_writer.image("test_pred_color", pred_color, step)
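Note: when the semantic step runs, its gradient is computed on a single unreplicated copy of the parameters, while the pmapped photometric gradient carries a leading device axis on every leaf; jnp.expand_dims(sc_g, 0) broadcasts the former onto the latter before the parameter update. A minimal sketch of that combination, assuming both gradients share the same pytree structure:

import jax
import jax.numpy as jnp

def combine_grads(grad, sc_grad):
    # grad: photometric gradient whose leaves have a leading device axis.
    # sc_grad: semantic-consistency gradient without that axis.
    leaves, treedef = jax.tree_util.tree_flatten(grad)
    sc_leaves, _ = jax.tree_util.tree_flatten(sc_grad)
    combined = [g + jnp.expand_dims(sc_g, 0) for g, sc_g in zip(leaves, sc_leaves)]
    return jax.tree_util.tree_unflatten(treedef, combined)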
train.sh
CHANGED
File without changes