import math
from functools import partial
from typing import Optional

from absl import flags
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from transformers import FlaxCLIPModel

from jaxnerf.nerf import utils

FLAGS = flags.FLAGS

def semantic_loss(clip_model, src_image, target_embedding):
    """Semantic consistency loss between a rendering and a CLIP target embedding."""
    src_image = utils.unshard(src_image)
    # Reshape the flat pixel buffer into a square RGB image.
    w = int(math.sqrt(src_image.size // 3))
    src_image = src_image.reshape([w, w, 3])
    # Embed the rendering with CLIP (NCHW input) and L2-normalize.
    src_embedding = clip_model.get_image_features(
        pixel_values=preprocess_for_CLIP(
            jnp.expand_dims(src_image, 0).transpose(0, 3, 1, 2)))
    src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
    src_embedding = jnp.array(src_embedding)
    # Squared error against the (already normalized) target embedding.
    sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
    return sc_loss, src_image
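
# Illustrative sketch (not part of this module): computing a target embedding
# from a text prompt with the same CLIP model, then scoring a rendering
# against it. The tokenizer import, the prompt, and `rendered_pixels` are
# assumptions, not names from this codebase.
#
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
#   tokens = tokenizer(["a photo of a chair"], return_tensors='np')
#   target = clip_model.get_text_features(**tokens)
#   target /= jnp.linalg.norm(target, axis=-1, keepdims=True)
#   sc_loss, image = semantic_loss(clip_model, rendered_pixels, target)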

def semantic_step_multi(render_pfn, clip_model, rng, state, batch, lr):
    """One semantic-loss step using the pmapped renderer (rays are sharded)."""
    random_rays = jax.tree_map(
        lambda x: utils.shard(x).astype(jnp.float16), batch["random_rays"])
    target_embedding = batch["embedding"].astype(jnp.float16)
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        src_image = render_pfn(variables, key_0, key_1, random_rays)
        sc_loss, src_image = semantic_loss(clip_model, src_image, target_embedding)
        return sc_loss * FLAGS.sc_loss_mult, src_image

    # Unreplicate the train state (take replica 0), move it to host memory,
    # and differentiate the loss w.r.t. the model parameters.
    variables = jax.device_get(
        jax.tree_map(lambda x: x[0].astype(jnp.float16), state)).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
    return sc_loss, grad, src_image
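
# Sketch of the expected wiring (hedged; the names follow the jaxnerf training
# script's conventions rather than anything defined here): `render_pfn` is the
# pmapped rendering function and `state` the replicated train state.
#
#   sc_loss, grad, rendering = semantic_step_multi(
#       render_pfn, clip_model, rng, state, batch, lr)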

def semantic_step_single(model, clip_model, rng, state, batch, lr):
    """One semantic-loss step on a single device (the batch is not sharded)."""
    random_rays = batch["random_rays"]
    rng, key_0, key_1 = random.split(rng, 3)

    def loss_fn(variables):
        src_image = model.apply(variables, key_0, key_1, random_rays, False, rgb_only=True)
        # Reshape flat pixels into an image (assumes 3 channels and a square shape).
        w = int(math.sqrt(src_image.shape[0]))
        src_image = src_image.reshape([w, w, 3])
        src_embedding = clip_model.get_image_features(
            pixel_values=preprocess_for_CLIP(
                jnp.expand_dims(src_image, 0).transpose(0, 3, 1, 2)))
        src_embedding /= jnp.linalg.norm(src_embedding, axis=-1, keepdims=True)
        src_embedding = jnp.array(src_embedding)
        target_embedding = batch["embedding"]
        sc_loss = 0.5 * jnp.sum((src_embedding - target_embedding) ** 2) / src_embedding.shape[0]
        return sc_loss * FLAGS.sc_loss_mult, src_image

    variables = jax.device_get(jax.tree_map(lambda x: x[0], state)).optimizer.target
    (sc_loss, src_image), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
    return sc_loss, grad, src_image
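
# Minimal usage sketch (assumptions: `state` is a flax.optim-style train state
# and the caller merges gradients). The returned `grad` lives on the host, so
# it would typically be combined with the photometric-loss gradient before the
# optimizer update.
#
#   sc_loss, sc_grad, rendering = semantic_step_single(
#       model, clip_model, rng, state, batch, lr)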

def trans_t(t):
    """Translation by t along the z-axis (homogeneous 4x4 matrix)."""
    return jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1]], dtype=jnp.float32)


def rot_phi(phi):
    """Rotation by phi about the x-axis (homogeneous 4x4 matrix)."""
    return jnp.array([
        [1, 0, 0, 0],
        [0, jnp.cos(phi), jnp.sin(phi), 0],
        [0, -jnp.sin(phi), jnp.cos(phi), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)


def rot_theta(th):
    """Rotation by th about the y-axis (homogeneous 4x4 matrix)."""
    return jnp.array([
        [jnp.cos(th), 0, -jnp.sin(th), 0],
        [0, 1, 0, 0],
        [jnp.sin(th), 0, jnp.cos(th), 0],
        [0, 0, 0, 1]], dtype=jnp.float32)

def pose_spherical(radius, theta, phi):
    """Camera-to-world matrix on a sphere of the given radius, looking at the origin."""
    c2w = trans_t(radius)
    c2w = rot_phi(phi) @ c2w
    c2w = rot_theta(theta) @ c2w
    # Swap axes to match the scene's world-coordinate convention.
    c2w = jnp.array([[-1, 0, 0, 0],
                     [0, 0, 1, 0],
                     [0, 1, 0, 0],
                     [0, 0, 0, 1]]) @ c2w
    return c2w
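
# Example (angles in radians): a camera 4 units from the origin, 30 degrees
# above the equator. The specific numbers are illustrative only.
#
#   c2w = pose_spherical(radius=4.0, theta=jnp.pi / 2, phi=jnp.pi / 6)
#   # c2w is a 4x4 camera-to-world matrix usable for ray generation.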

def random_pose(rng, bds):
    """Sample a random camera pose with radius within bds on the upper hemisphere."""
    rng, *rng_inputs = jax.random.split(rng, 4)
    # Use a distinct key for each sample so the draws are independent.
    radius = random.uniform(rng_inputs[0], minval=bds[0], maxval=bds[1])
    theta = random.uniform(rng_inputs[1], minval=-jnp.pi, maxval=jnp.pi)
    phi = random.uniform(rng_inputs[2], minval=0, maxval=jnp.pi / 2)
    return pose_spherical(radius, theta, phi)
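
# Sketch: sampling a pose for an unobserved view. The bounds [2.0, 6.0] are a
# stand-in; in training they would come from the dataset's near/far radii.
#
#   rng = random.PRNGKey(0)
#   c2w = random_pose(rng, bds=[2.0, 6.0])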

def preprocess_for_CLIP(image):
    """JAX-based preprocessing for CLIP.

    Args:
        image: [B, 3, H, W] batch of images with values in [0, 1].

    Returns:
        [B, 3, 224, 224] pre-processed images for CLIP.
    """
    B, D, H, W = image.shape
    mean = jnp.array([0.48145466, 0.4578275, 0.40821073]).reshape(1, 3, 1, 1)
    std = jnp.array([0.26862954, 0.26130258, 0.27577711]).reshape(1, 3, 1, 1)
    # Resize to CLIP's input resolution; assumes square images, otherwise the
    # aspect ratio is distorted.
    image = jax.image.resize(image, (B, D, 224, 224), 'bicubic')
    image = (image - mean.astype(image.dtype)) / std.astype(image.dtype)
    return image
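
# Sketch: preparing a single square rendering for CLIP. The 168x168 resolution
# is an arbitrary stand-in; any square size works since the resize maps to
# 224x224.
#
#   img = jnp.zeros((168, 168, 3))                          # HWC, values in [0, 1]
#   batch = jnp.expand_dims(img, 0).transpose(0, 3, 1, 2)   # -> [1, 3, 168, 168]
#   clip_input = preprocess_for_CLIP(batch)                 # -> [1, 3, 224, 224]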

def init_CLIP(dtype: str, model_name: Optional[str]) -> FlaxCLIPModel:
    """Load a pretrained FlaxCLIPModel with the requested parameter dtype."""
    if dtype == 'float16':
        dtype = jnp.float16
    elif dtype == 'float32':
        dtype = jnp.float32
    else:
        raise ValueError(f"Unsupported dtype {dtype!r}; expected 'float16' or 'float32'.")
    if model_name is None:
        model_name = 'openai/clip-vit-base-patch32'
    return FlaxCLIPModel.from_pretrained(model_name, dtype=dtype)
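
# Sketch: loading CLIP in half precision and embedding a preprocessed batch.
# Weights are fetched from the Hugging Face hub on first use; `clip_input` is
# assumed to come from preprocess_for_CLIP above.
#
#   clip_model = init_CLIP('float16', None)  # defaults to clip-vit-base-patch32
#   features = clip_model.get_image_features(pixel_values=clip_input)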