# File: putting-nerf-on-a-diet / nerf/clip_utils.py
# Web-viewer residue kept from extraction: "sseung0703's picture", "update",
# e8c4ed3, "raw", "history blame", 5.09 kB
import math
from typing import Optional
from absl import flags
from functools import partial
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel
from nerf import utils
# Global absl flag registry. The semantic step functions in this module read
# FLAGS.sc_loss_mult from it to scale the semantic consistency loss.
FLAGS = flags.FLAGS
@partial(jax.jit, static_argnums=[0])
def semantic_loss(clip_model, src_image, target_embedding):
    """Score a rendered image against a target CLIP embedding.

    Only ``src_image[-1]`` is scored: it is unsharded, reshaped into a square
    3-channel image and embedded with ``clip_model``.

    Args:
        clip_model: model exposing ``get_image_features``; passed to
            ``jax.jit`` as a static argument.
        src_image: rendered output; the last entry is used (presumably the
            fine-network render -- confirm against the renderer).
        target_embedding: embedding to compare against (assumed to be already
            L2-normalised -- TODO confirm at the call site).

    Returns:
        ``(sc_loss, f_image)``: ``sc_loss`` is one minus the summed
        element-wise product of the normalised source embedding and
        ``target_embedding``; ``f_image`` is the ``[w, w, 3]`` image that was
        embedded.
    """
    last_render = src_image[-1]
    # Side length of the square image, assuming three channels per pixel.
    side = int(math.sqrt(last_render.size // 3))
    f_image = utils.unshard(last_render).reshape([side, side, 3])

    # Add a batch axis and move channels first before CLIP preprocessing.
    clip_input = preprocess_for_CLIP(f_image[None].transpose(0, 3, 1, 2))
    src_embedding = clip_model.get_image_features(pixel_values=clip_input)
    src_embedding = src_embedding / jnp.linalg.norm(
        src_embedding, axis=-1, keepdims=True)

    sc_loss = 1 - jnp.sum(src_embedding * target_embedding)
    return sc_loss, f_image
def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
    """Compute the semantic loss and its gradient using a sharded renderer.

    Args:
        render_pfn: callable invoked as
            ``render_pfn(variables, key_0, key_1, random_rays)``.
        clip_model: CLIP model forwarded to ``semantic_loss``.
        rng: JAX PRNG key; split to obtain the two keys given to the renderer.
        state: training state; the first slice of every leaf is taken and its
            ``optimizer.target`` is differentiated (presumably this selects
            replica 0 of a replicated state -- confirm with the caller).
        batch: mapping with ``"random_rays"`` (a pytree of rays) and
            ``"embedding"`` (the target embedding).
        lr: unused here; kept for interface compatibility.

    Returns:
        ``(sc_loss, grad, src_image)``: the loss scaled by
        ``FLAGS.sc_loss_mult``, its gradient with respect to the parameters,
        and the image returned by ``semantic_loss``.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias was deprecated and later removed from JAX.
    random_rays = jax.tree_util.tree_map(
        lambda x: utils.shard(x).astype(jnp.float16), batch["random_rays"])
    target_embedding = batch["embedding"].astype(jnp.float16)
    _, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        src_image = render_pfn(variables, key_0, key_1, random_rays)
        sc_loss, f_image = semantic_loss(clip_model, src_image, target_embedding)
        return sc_loss * FLAGS.sc_loss_mult, f_image

    # Take the first slice of every leaf and pull it to the host before
    # reading the parameters.
    first_slice = jax.tree_util.tree_map(lambda x: x[0], state)
    params = jax.device_get(first_slice).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
    return sc_loss, grad, src_image
@partial(jax.jit, static_argnums=[0, 1])
def semantic_step_single(model, clip_model, rng, state, batch, lr):
    """Compute the semantic loss and its gradient with an unsharded batch.

    Args:
        model: module whose ``apply`` renders the rays; static under jit.
        clip_model: model exposing ``get_image_features``; static under jit.
        rng: JAX PRNG key; split to obtain the two keys given to ``model.apply``.
        state: training state; the first slice of every leaf is taken and its
            ``optimizer.target`` is differentiated.
        batch: mapping with ``"random_rays"`` and ``"embedding"``; every leaf
            is cast to float16. The batch is used without sharding.
        lr: unused here; kept for interface compatibility.

    Returns:
        ``(sc_loss, grad, src_image)``: the loss scaled by
        ``FLAGS.sc_loss_mult``, its gradient with respect to the parameters,
        and the ``[w, w, 3]`` image that was embedded.
    """
    # jax.tree_util.tree_map is the stable spelling; the top-level
    # jax.tree_map alias was deprecated and later removed from JAX.
    batch = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), batch)
    random_rays = batch["random_rays"]
    _, key_0, key_1 = random.split(rng, 3)

    # Named loss_fn (not semantic_loss) so it does not shadow the
    # module-level function of that name.
    def loss_fn(variables):
        c_image, f_image = model.apply(
            variables, key_0, key_1, random_rays, False, rgb_only=True)
        # Reshape flat pixels to an image (assumes 3 channels and a square
        # shape); only the second output is scored.
        w = int(math.sqrt(f_image.shape[0]))
        f_image = f_image.reshape([w, w, 3])

        # Add a batch axis and move channels first before CLIP preprocessing.
        clip_input = preprocess_for_CLIP(
            jnp.expand_dims(f_image, 0).transpose(0, 3, 1, 2))
        src_embedding = clip_model.get_image_features(pixel_values=clip_input)
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)

        target_embedding = batch["embedding"]
        # Half the squared L2 distance between the two embeddings.
        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2)
        return sc_loss * FLAGS.sc_loss_mult, f_image

    first_slice = jax.tree_util.tree_map(lambda x: x[0], state)
    params = jax.device_get(first_slice).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
    return sc_loss, grad, src_image
def trans_t(t):
    """Return a 4x4 float32 homogeneous matrix translating by ``t`` along z."""
    matrix_rows = [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],  # the translation sits in the z row, last column
        [0, 0, 0, 1],
    ]
    return jnp.array(matrix_rows, dtype=jnp.float32)
def rot_phi(phi):
    """Return a 4x4 float32 homogeneous rotation about the x axis.

    ``phi`` is an angle in radians; the y/z sub-block is
    ``[[cos, sin], [-sin, cos]]``.
    """
    cos_phi = jnp.cos(phi)
    sin_phi = jnp.sin(phi)
    matrix_rows = [
        [1, 0, 0, 0],
        [0, cos_phi, sin_phi, 0],
        [0, -sin_phi, cos_phi, 0],
        [0, 0, 0, 1],
    ]
    return jnp.array(matrix_rows, dtype=jnp.float32)
def rot_theta(th):
    """Return a 4x4 float32 homogeneous rotation about the y axis.

    ``th`` is an angle in radians; the x/z sub-block is
    ``[[cos, -sin], [sin, cos]]``.
    """
    cos_th = jnp.cos(th)
    sin_th = jnp.sin(th)
    matrix_rows = [
        [cos_th, 0, -sin_th, 0],
        [0, 1, 0, 0],
        [sin_th, 0, cos_th, 0],
        [0, 0, 0, 1],
    ]
    return jnp.array(matrix_rows, dtype=jnp.float32)
def pose_spherical(radius, theta, phi):
    """Build a 4x4 pose matrix from a radius and two angles.

    The transforms are composed right-to-left: translate by ``radius`` along
    z, rotate by ``phi`` (``rot_phi``), rotate by ``theta`` (``rot_theta``),
    then apply a fixed axis remap.

    Args:
        radius: translation passed to ``trans_t``.
        theta: angle in radians passed to ``rot_theta``.
        phi: angle in radians passed to ``rot_phi``.

    Returns:
        The composed 4x4 matrix (named ``c2w`` in the original code,
        presumably camera-to-world).
    """
    # Fixed remap applied last: negates x and swaps the y and z rows.
    axis_remap = jnp.array(
        [[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    pose = rot_phi(phi) @ trans_t(radius)
    pose = rot_theta(theta) @ pose
    return axis_remap @ pose
def random_pose(rng, bds):
    """Sample a random pose via ``pose_spherical``.

    The radius is drawn uniformly from ``[bds[0], bds[1])``, ``theta`` from
    ``[-pi, pi)`` and ``phi`` from ``[0, pi/2)``.

    Args:
        rng: JAX PRNG key.
        bds: indexable whose first two entries are the minimum and maximum
            radius.

    Returns:
        The matrix returned by ``pose_spherical(radius, theta, phi)``.
    """
    # Each quantity needs its own key. Reusing one key for all three draws
    # yields the same underlying uniform sample, which makes radius, theta
    # and phi perfectly correlated.
    radius_key, theta_key, phi_key = random.split(rng, 3)
    radius = random.uniform(radius_key, minval=bds[0], maxval=bds[1])
    theta = random.uniform(theta_key, minval=-jnp.pi, maxval=jnp.pi)
    phi = random.uniform(phi_key, minval=0, maxval=jnp.pi / 2)
    return pose_spherical(radius, theta, phi)
def preprocess_for_CLIP(image):
    """Resize and normalise a batch of images for CLIP using JAX ops.

    Args:
        image: array of shape ``[B, 3, H, W]`` (channels first). Pixel values
            are presumably expected in ``[0, 1]`` -- confirm with the caller.

    Returns:
        Array of shape ``[B, 3, 224, 224]`` with the same dtype as the
        resized input: bicubically resized, then normalised per channel.
    """
    batch, channels, _, _ = image.shape

    # Bicubic resize to 224x224; the aspect ratio is not preserved, so
    # non-square inputs are stretched.
    resized = jax.image.resize(image, (batch, channels, 224, 224), 'bicubic')

    # Per-channel normalisation constants, shaped to broadcast over
    # [B, 3, H, W] and cast to the image dtype.
    channel_mean = jnp.array([0.48145466, 0.4578275, 0.40821073])[None, :, None, None]
    channel_std = jnp.array([0.26862954, 0.26130258, 0.27577711])[None, :, None, None]
    centred = resized - channel_mean.astype(resized.dtype)
    return centred / channel_std.astype(resized.dtype)
def init_CLIP(dtype: str, model_name: Optional[str] = None) -> FlaxCLIPModel:
    """Load a pretrained Flax CLIP model with the requested dtype.

    Args:
        dtype: either ``'float16'`` or ``'float32'``.
        model_name: name passed to ``FlaxCLIPModel.from_pretrained``; when
            ``None``, ``'openai/clip-vit-base-patch32'`` is used.

    Returns:
        The model returned by ``FlaxCLIPModel.from_pretrained``.

    Raises:
        ValueError: if ``dtype`` is not one of the two supported strings.
    """
    if dtype == 'float16':
        jax_dtype = jnp.float16
    elif dtype == 'float32':
        jax_dtype = jnp.float32
    else:
        # Name the offending value so a bad config is easy to diagnose.
        raise ValueError(
            f"unsupported dtype {dtype!r}; expected 'float16' or 'float32'")

    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=jax_dtype)