## THINGS TO DO??

### Existing NeRF
1. Rewrite the model into FLAX (otherwise cant work with Flax-CLIP?) (Haiku seems like functional, but Flax is OOP?)
2. Rewrite training loop into FLAX (does FLAX provide abstraction for writing training loop?? current training loop is pretty low-level)
3. Refactor the notebook code --> module code
    - render scene, visualizing, animation ... etc.
4. Dataloading for our concerned dataset
5. consolidate all controllabe params in a class (e.g. `Config`)


### NeRF --> DietNeRF
1. Change sampling to 8 samples only
2. Add CLIP into the training loop for a new loss function
3. Check if DietNeRF can get comparable result to NeRF

### Optional?
1. Understand what the hell each operations are doing? (e.g. `get_rays` ... etc)
2. Add W&B for visualization?
3. Super large scale NeRF(ssssss) training --> get huge samples of scene for POC
4. Optimize bottleneck operations by `jax.vmap`, `jax.pmap`

In [63]:
# enable showing live "loss plot" inside notebook
!pip install livelossplot



In [64]:
%%capture
!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers
!conda install -y importlib-metadata
!pip install -U dm-haiku

In [65]:
# TPU setup
import os
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1

    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

# Module check
import jax
import flax
import haiku as hk

for _m in (jax, flax, hk):
    print(f'{_m.__name__}: {_m.__version__}')
jax.local_devices()

No TPU detected. Can be changed under "Runtime/Change runtime type".
jax: 0.2.16
flax: 0.3.4
haiku: 0.0.4


[CpuDevice(id=0)]

In [66]:
from functools import partial

import jax
from jax import random, grad, jit, vmap, flatten_util, nn
from jax.experimental import optimizers  # change due to version difference
from jax.config import config
import jax.numpy as np

import haiku as hk
from haiku._src import utils

from livelossplot import PlotLosses
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm
import cv2
import imageio
import glob
from IPython.display import clear_output
import pickle
from skimage.metrics import structural_similarity as ssim_fn

rng = jax.random.PRNGKey(42)

In [67]:
#ls ../input/pull-phototourism-images/sacre_coeur/dense/images
DATASET = 'sacre'
posedir = f'../input/phototourism/phototourism/sacre' # Directory condtains [bds.npy, c2w_mats.npy, kinv_mats.npy, res_mats.npy]
imgdir = f'../input/pull-phototourism-images/sacre_coeur/dense/images' # Directory of images

### 1. Helper Functions for Loading Data

In [68]:
posedata = {}
for f in os.listdir(posedir):
    if '.npy' not in f:
        continue
    z = np.load(os.path.join(posedir, f))
    posedata[f.split('.')[0]] = z
print('Pose data loaded - ', posedata.keys())

imgfiles = sorted(glob.glob(imgdir + '/*.jpg'))
print(f'{len(imgfiles)} images')

Pose data loaded -  dict_keys(['kinv_mats', 'res_mats', 'c2w_mats', 'bds'])
1179 images


In [69]:
@jit
def get_rays(c2w, kinv, i, j):
#     i, j = np.meshgrid(np.arange(W), np.arange(H), indexing='xy')
    pixco = np.stack([i, j, np.ones_like(i)], -1)
    dirs = pixco @ kinv.T
#     dirs = np.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -np.ones_like(i)], -1)
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
    rays_o = np.broadcast_to(c2w[:3,-1], rays_d.shape)
    return np.stack([rays_o, rays_d], 0)


def normalize(x):
    return x / np.linalg.norm(x)


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def ptstocam(pts, c2w):
    tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0]
    return tt


def poses_avg(poses):
    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    return viewmatrix(vec2, up, center)


def render_path_spiral(c2w, up, rads, focal, zrate, rots, N):
    """
    enumerate list of poses around a spiral
    used for test set visualization
    """
    render_poses = []
    rads = np.array(list(rads) + [1.])
    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
        c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 
        z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses

In [70]:
def get_example(img_idx, split='train', downsample=4):
    sc = .05
    
    # first 20 are test, next 5 are validation, the rest are training:
    # https://github.com/tancik/learnit/issues/3
    if 'train' in split:
        img_idx = img_idx + 25
    if 'val' in split:
        img_idx = img_idx + 20
        
    # uint8 --> float
    img = imageio.imread(imgfiles[img_idx])[...,:3]/255.
    
    # WHAT DO THESE MATRICES MEAN???
    # (4, 4)
    c2w = posedata['c2w_mats'][img_idx]
    # (3, 3)
    kinv = posedata['kinv_mats'][img_idx]
    c2w = np.concatenate([c2w[:3,:3], c2w[:3,3:4]*sc], -1)
    # (2, )
    bds = posedata['bds'][img_idx] * np.array([.9, 1.2]) * sc
    H, W = img.shape[:2]
    
    # (0, 4, 8, ..., H)
    # WHAT ARE THE PURPOSES OF THIS MATRIX???
    i, j = np.meshgrid(np.arange(0,W,downsample), np.arange(0,H,downsample), indexing='xy')
    
    test_images = img[j, i]
    test_rays = get_rays(c2w, kinv, i, j)
    return test_images, test_rays, bds

### 2. NeRF Renderer

In [71]:
def render_rays(
        rnd_input, model, params, 
        bvals, rays, near, far, 
        N_samples, rand=False, allret=False
    ):
    rays_o, rays_d = rays

    # Compute 3D query points
    z_vals = np.linspace(near, far, N_samples) 
    if rand:
        z_vals += random.uniform(rnd_input, shape=list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
    # r(t) = o + t*d
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]
    
    # Run network
    pts_flat = np.reshape(pts, [-1,3])
    if bvals is not None:
        pts_flat = np.concatenate([np.sin(pts_flat @ bvals.T), 
                                np.cos(pts_flat @ bvals.T)], axis=-1)
        
    raw = model.apply(params, pts_flat)
    raw = np.reshape(raw, list(pts.shape[:-1]) + [4])
    
    # Compute opacities and colors
    rgb, sigma_a = raw[...,:3], raw[...,3]
    sigma_a = jax.nn.relu(sigma_a)
    rgb = jax.nn.sigmoid(rgb) 
    
    # Do volume rendering
    dists = np.concatenate([z_vals[..., 1:] - z_vals[..., :-1], np.broadcast_to([1e10], z_vals[...,:1].shape)], -1) 
    alpha = 1. - np.exp(-sigma_a * dists)
    trans = np.minimum(1., 1. - alpha + 1e-10)
    trans = np.concatenate([np.ones_like(trans[...,:1]), trans[...,:-1]], -1)  
    weights = alpha * np.cumprod(trans, -1)
    
    rgb_map = np.sum(weights[...,None] * rgb, -2) 
    acc_map = np.sum(weights, -1)
    
    if False:
        rgb_map = rgb_map + (1.-acc_map[..., None])
    
    if not allret:
        return rgb_map
    
    depth_map = np.sum(weights * z_vals, -1) 

    return rgb_map, depth_map, acc_map


def render_fn_inner(rnd_input, model, params, bvals, rays, near, far, rand, allret, N_samples):
    return render_rays(rnd_input, model, params, bvals, rays, near, far, 
                       N_samples=N_samples, rand=rand, allret=allret)

# optimize render_fn_inner by JIT (func in, func out)
render_fn_inner = jit(render_fn_inner, static_argnums=(1, 7, 8, 9))


def render_fn(rnd_input, model, params, bvals, rays, near, far, N_samples, rand):
    chunk = 5
    for i in range(0, rays.shape[1], chunk):
        out = render_fn_inner(rnd_input, model, params, bvals, rays[:,i:i+chunk], near, far, rand, True, N_samples)
        if i==0:
            rets = out
        else:
            rets = [np.concatenate([a, b], 0) for a, b in zip(rets, out)]
    return rets

### 3. NeRF Model Architecture

In [72]:
class Model(hk.Module):
    def __init__(self):
        super().__init__()
        self.width = 256
        self.depth = 6
        self.use_viewdirs = False
                
    def __call__(self, coords, view_dirs=None):
        sh = coords.shape
        if self.use_viewdirs:
            viewdirs = None
            viewdirs = np.repeat(viewdirs[...,None,:], coords.shape[-2], axis=-2)
            viewdirs /= np.linalg.norm(viewdirs, axis=-1, keepdims=True)
            viewdirs = np.reshape(viewdirs, (-1,3))
            viewdirs = hk.Linear(output_size=self.width//2)(viewdirs)
            viewdirs = jax.nn.relu(viewdirs)
        coords = np.reshape(coords, [-1,3])
        
        # positional encoding
        x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)

        for _ in range(self.depth-1):
            x = hk.Linear(output_size=self.width)(x)
            x = jax.nn.relu(x)
            
        if self.use_viewdirs:
            density = hk.Linear(output_size=1)(x)
            x = np.concatenate([x,viewdirs], axis=-1)
            x = hk.Linear(output_size=self.width)(x)
            x = jax.nn.relu(x)
            rgb = hk.Linear(output_size=3)(x)
            out = np.concatenate([density, rgb], axis=-1)
        else:
            out = hk.Linear(output_size=4)(x)
        out = np.reshape(out, list(sh[:-1]) + [4])
        return out

### 4. Training Loop

In [73]:
batch_size = 64
N_samples = 128
inner_step_size = 1

model = hk.without_apply_rng(hk.transform(lambda x, y=None: Model()(x, y)))

mse_fn = jit(lambda x, y: np.mean((x - y)**2))
psnr_fn = jit(lambda x, y: -10 * np.log10(mse_fn(x, y)))

@jit
def single_step(rng, image, rays, params, bds):
    def sgd(param, update):
        return param - inner_step_size * update
    
    rng, rng_inputs = jax.random.split(rng)
    def loss_model(params):
        g = render_rays(rng_inputs, model, params, None, rays, bds[0], bds[1], N_samples, rand=True)
        return mse_fn(g, image)
    
    model_loss, grad = jax.value_and_grad(loss_model)(params)
    new_params = jax.tree_multimap(sgd, params, grad)
    return rng, new_params, model_loss

def update_network_weights(rng, images, rays, params, inner_steps, bds):
    for _ in range(inner_steps):
        rng, rng_input = random.split(rng)
        idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])
        image_sub = images[idx,:]
        rays_sub = rays[:,idx,:]
        
        rng, params, loss = single_step(rng, image_sub, rays_sub, params, bds)
    return rng, params, loss

In [74]:
plt_groups = {'Train PSNR':[], 'Test PSNR':[]}
plotlosses_model = PlotLosses(groups=plt_groups)

In [75]:
max_iters = 150000

inner_update_steps = 64
lr = 5e-4

exp_name = f'{DATASET}_ius_{inner_update_steps}_ilr_{inner_step_size}_olr_{lr}_bs_{batch_size}'
exp_dir = f'checkpoint/phototourism_checkpoints/{exp_name}/'

if not os.path.exists(exp_dir):
    os.makedirs(exp_dir)

params = model.init(rng, np.ones((1,3)))

opt = optimizers.adam(lr)
opt_state = opt.init_fun(params)

test_inner_steps = 64


def update_model(rng, params, opt_state, image, rays, bds):
    rng, new_params, model_loss = update_network_weights(rng, image, rays, params, inner_update_steps, bds)
    
    def calc_grad(params, new_params):
        return params - new_params
    model_grad = jax.tree_multimap(calc_grad, params, new_params)
    
    updates, opt_state = opt.update(model_grad, opt_state)
    params = optimizers.apply_updates(params, updates)
    return rng, params, opt_state, model_loss

@jit
def update_model_single(rng, params, opt_state, image, rays, bds):
    rng, new_params, model_loss = single_step(rng, image, rays, params, bds)
    
    def calc_grad(params, new_params):
        return params - new_params
    model_grad = jax.tree_multimap(calc_grad, params, new_params)
    
    updates, opt_state = opt.update(model_grad, opt_state)
    params = optimizers.apply_updates(params, updates)
    return rng, params, opt_state, model_loss



plt_groups['Train PSNR'].append(exp_name+f'_train')
plt_groups['Test PSNR'].append(exp_name+f'_test')
step = 0

train_psnrs = []
rng = jax.random.PRNGKey(0)

train_steps = []
train_psnrs_all = []
test_steps = []
test_psnrs_all = []
for step in tqdm(range(max_iters)):
    try:
        rng, rng_input = jax.random.split(rng)
        img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)        
        images, rays, bds = get_example(img_idx, downsample=1)
    except:
        print('data loading error')
        raise
        continue
        

    images = np.reshape(images, (-1,3))
    rays = np.reshape(rays, (2,-1,3))

    if inner_update_steps == 1:
        rng, rng_input = random.split(rng)
        idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])
        rng, params, opt_state, loss = update_model_single(rng, params, opt_state, 
                                                           images[idx,:], rays[:,idx,:], bds)
    else:
        rng, params, opt_state, loss = update_model(rng, params, opt_state, 
                                                    images, rays, bds)
    train_psnrs.append(-10 * np.log10(loss))
    
    # track model loss
    if step % 250 == 0:
        plotlosses_model.update({exp_name+'_train':np.mean(np.array(train_psnrs))}, current_step=step)
        train_steps.append(step)
        train_psnrs_all.append(np.mean(np.array(train_psnrs)))
        train_psnrs = []
        
    # run validation
    if step % 500 == 0 and step != 0:
        test_psnr = []
        for ti in range(5):
            test_images, test_rays, bds = get_example(ti, split='val', downsample=2)

            test_images, test_holdout_images = np.split(test_images, [test_images.shape[1]//2], axis=1)
            test_rays, test_holdout_rays = np.split(test_rays, [test_rays.shape[2]//2], axis=2)

            test_images_flat = np.reshape(test_images, (-1,3))
            test_rays = np.reshape(test_rays, (2,-1,3))

            rng, test_params, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)

            test_result = np.clip(render_fn(rng, model, test_params, None, test_holdout_rays, bds[0], bds[1], N_samples, rand=False)[0], 0, 1)
            test_psnr.append(psnr_fn(test_holdout_images, test_result))
        test_psnr = np.mean(np.array(test_psnr))

        test_steps.append(step)
        test_psnrs_all.append(test_psnr)
        
        plotlosses_model.update({exp_name+'_test':test_psnr}, current_step=step)
        plotlosses_model.send()

        plt.figure(figsize=(15,5))   
        plt.subplot(1,3, 1)
        plt.imshow(test_images)
        plt.subplot(1,3, 2)
        plt.imshow(test_holdout_images)
        plt.subplot(1,3, 3)
        plt.imshow(test_result)
        plt.show()
        
    # save model checkpoint + render sample view on test set for model check
    if step % 10000 == 0 and step != 0:
        test_images, test_rays, bds = get_example(0, split='test')
        test_images_flat = np.reshape(test_images, (-1,3))
        test_rays = np.reshape(test_rays, (2,-1,3))
        rng, test_params_1, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)

        test_images, test_rays, bds = get_example(1, split='test')
        test_images_flat = np.reshape(test_images, (-1,3))
        test_rays = np.reshape(test_rays, (2,-1,3))
        rng, test_params_2, test_inner_loss = update_network_weights(rng, test_images_flat, test_rays, params, test_inner_steps, bds)
        
        poses = posedata['c2w_mats']
        c2w = poses_avg(poses)
        focal = .8
        render_poses = render_path_spiral(c2w, c2w[:3,1], [.1, .1, .05], focal, zrate=.5, rots=2, N=120)
        
        bds = np.array([5., 25.]) * .05
        H = 128
        W = H*3//2
        f = H * 1.
        kinv = np.array(
            [1./f, 0, -W*.5/f,
             0, -1./f, H*.5/f,
             0, 0, -1.]
        ).reshape([3,3])
        i, j = np.meshgrid(np.arange(0,W), np.arange(0,H), indexing='xy')
        renders = []
        for p, c2w in enumerate(tqdm(render_poses)):
            rays = get_rays(c2w, kinv, i, j)
            interp = p / len(render_poses)
            interp_params = jax.tree_multimap(lambda x, y: y*p/len(render_poses) + x*(1-p/len(render_poses)), test_params_1, test_params_2)
            result = render_fn(rng, model, interp_params, None, rays, bds[0], bds[1], N_samples, rand=False)[0]
            renders.append(result)
        
        renders = (np.clip(np.array(renders), 0, 1)*255).astype(np.uint8)
        imageio.mimwrite(f'{exp_dir}render_sprial_{step}.mp4', renders, fps=30, quality=8)
        
        plt.plot(train_steps, train_psnrs_all)
        plt.savefig(f'{exp_dir}train_curve_{step}.png')
        
        plt.plot(test_steps, test_psnrs_all)
        plt.savefig(f'{exp_dir}test_curve_{step}.png')
        
        with open(f'{exp_dir}checkpount_{step}.pkl', 'wb') as file:
            pickle.dump(params, file)

AttributeError: 'Optimizer' object has no attribute 'init_fun'

### 1. PLAYGROUND
- optimizers API: https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html#adam

In [76]:
from flax import linen as nn

In [79]:
nn.Dense?

[0;31mInit signature:[0m
[0mnn[0m[0;34m.[0m[0mDense[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mfeatures[0m[0;34m:[0m [0mint[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0muse_bias[0m[0;34m:[0m [0mbool[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdtype[0m[0;34m:[0m [0mAny[0m [0;34m=[0m [0;34m<[0m[0;32mclass[0m [0;34m'jax._src.numpy.lax_numpy.float32'[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mprecision[0m[0;34m:[0m [0mAny[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mkernel_init[0m[0;34m:[0m [0mCallable[0m[0;34m[[0m[0;34m[[0m[0mAny[0m[0;34m,[0m [0mIterable[0m[0;34m[[0m[0mint[0m[0;34m][0m[0;34m,[0m [0mAny[0m[0;34m][0m[0;34m,[0m [0mAny[0m[0;34m][0m [0;34m=[0m [0;34m<[0m[0mfunction[0m [0mvariance_scaling[0m[0;34m.[0m[0;34m<[0m[0mlocals[0m[0;34m>[0m[0;34m.[0m[0minit[0m [0mat[0m [0;36m0x7f66d9aac050[0m[0;34m>[0m[0;34m,[0m[0;3

In [95]:
class ModelFlax(nn.Module):
    width = 256
    depth = 6
    
    @nn.compact
    def __call__(self, coords):
        sh = coords.shape
        coords = np.reshape(coords, [-1,3])
        
        # positional encoding
        x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)

        for idx in range(self.depth-1):
            #x = hk.Linear(output_size=self.width)(x)
            x = nn.Dense(self.depth, name=f'fc{idx}')(x)
            x = nn.relu(x)

        #out = hk.Linear(output_size=4)(x)
        out = nn.Dense(4, name='fc_last')(x)
        out = np.reshape(out, list(sh[:-1]) + [4])
        return out

In [96]:
model = ModelFlax()
key1, key2 = random.split(jax.random.PRNGKey(0))
dummy_x = random.normal(key1, (1, 3))
params = model.init(key2, dummy_x)

In [85]:
dummy_x = 
ModelFlax.init

<function flax.linen.module.Module.init(self, rngs: Union[Any, Dict[str, Any]], *args, method: Union[Callable[..., Any], NoneType] = None, mutable: Union[bool, str, Container[str], ForwardRef('DenyList')] = DenyList(deny='intermediates'), **kwargs) -> flax.core.frozen_dict.FrozenDict[str, typing.Mapping[str, typing.Any]]>

In [None]:
model = hk.without_apply_rng(hk.transform(lambda x, y=None: ModelHaiku()(x, y)))
params = model.init(rng, np.ones((1,3)))


In [77]:
class ModelHaiku(hk.Module):
    def __init__(self):
        super().__init__()
        self.width = 256
        self.depth = 6
                
    def __call__(self, coords):
        sh = coords.shape
        coords = np.reshape(coords, [-1,3])
        
        # positional encoding
        x = np.concatenate([np.concatenate([np.sin(coords*(2**i)), np.cos(coords*(2**i))], axis=-1) for i in np.linspace(0,8,20)], axis=-1)

        for _ in range(self.depth-1):
            x = hk.Linear(output_size=self.width)(x)
            x = jax.nn.relu(x)

        out = hk.Linear(output_size=4)(x)
        out = np.reshape(out, list(sh[:-1]) + [4])
        return out

In [78]:
## OLD JAX + HAIKU
def single_step(rng, image, rays, params, bds):
    def sgd(param, update):
        return param - inner_step_size * update
    
    rng, rng_inputs = jax.random.split(rng)
    def loss_model(params):
        g = render_rays(rng_inputs, model, params, None, rays, bds[0], bds[1], N_samples, rand=True)
        return mse_fn(g, image)
    
    model_loss, grad = jax.value_and_grad(loss_model)(params)
    new_params = jax.tree_multimap(sgd, params, grad)
    return rng, new_params, model_loss


model = hk.without_apply_rng(hk.transform(lambda x, y=None: ModelHaiku()(x, y)))
params = model.init(rng, np.ones((1,3)))
opt = optimizers.adam(lr)
opt_state = opt.init_fun(params)
updates, opt_state = opt.update(model_grad, opt_state)
params = optimizers.apply_updates(params, updates)

TypeError: __call__() takes 2 positional arguments but 3 were given

In [None]:
## NEW JAX + HAIKU
lr = 1e-3
num_steps = 3

model = hk.without_apply_rng(hk.transform(lambda x: ModelHaiku()(x)))
params = model.init(rng, np.ones((1,3)))
opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)


def single_step_v2(step, rng, image, rays, bds, opt_state):
    def loss_model(params):
        g = render_rays(rng_inputs, model, params,
                        None, rays, bds[0], bds[1], 
                        N_samples, rand=True)
        return mse_fn(g, image)
    rng, rng_inputs = jax.random.split(rng)
    value, grads = jax.value_and_grad(loss_model)(get_params(opt_state))
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state

In [None]:
rng = jax.random.PRNGKey(0)

for istep in range(num_steps):
    rng, rng_input = jax.random.split(rng)
    img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)
    images, rays, bds = get_example(img_idx, downsample=1)
    images = np.reshape(images, (-1,3))
    rays = np.reshape(rays, (2,-1,3))
    rng, rng_input = random.split(rng)
    idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])
    loss, opt_state = single_step_v2(istep, rng, images[idx,:], rays[:,idx,:], bds, opt_state)

In [100]:
## NEW JAX + FLAX!
lr = 1e-3
num_steps = 3

# model = hk.without_apply_rng(hk.transform(lambda x: ModelHaiku()(x)))
# params = model.init(rng, np.ones((1,3)))

model = ModelFlax()
key1, key2 = random.split(jax.random.PRNGKey(0))
dummy_x = random.normal(key1, (1, 3))
params = model.init(key2, dummy_x)

opt_init, opt_update, get_params = optimizers.adam(lr)
opt_state = opt_init(params)


def single_step_v2(step, rng, image, rays, bds, opt_state):
    def loss_model(params):
        g = render_rays(rng_inputs, model, params,
                        None, rays, bds[0], bds[1], 
                        N_samples, rand=True)
        return mse_fn(g, image)
    rng, rng_inputs = jax.random.split(rng)
    value, grads = jax.value_and_grad(loss_model)(get_params(opt_state))
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state

In [101]:
opt_state

OptimizerState(packed_state=([DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32), DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)], [DeviceArray([[ 0.12091248,  0.03934956,  0.09956338, -0.03747796,
              -0.04041335,  0.02077009],
             [ 0.05760278, -0.0443754 ,  0.05104939, -0.14502399,
               0.0776381 ,  0.1672416 ],
             [ 0.11950634, -0.01816947,  0.01361674,  0.03426957,
              -0.04762232, -0.01502447],
             [-0.07393871,  0.04445767,  0.02080443,  0.00191198,
               0.06556749,  0.04923652],
             [ 0.0131189 ,  0.02381099, -0.01838387,  0.12599164,
              -0.06083291, -0.05262698],
             [-0.1097906 ,  0.09892927, -0.13584226,  0.03626684,
               0.02823985, -0.10594288],
             [ 0.0871003 , -0.09209569, -0.0871902 ,  0.0108161 ,
              -0.05293119, -0.06841837],
             [ 0.04326692, -0.13521153,  0.20045498, -0.1

In [102]:
rng = jax.random.PRNGKey(0)

for istep in range(num_steps):
    rng, rng_input = jax.random.split(rng)
    img_idx = random.randint(rng_input, shape=(), minval=0, maxval=len(imgfiles)-25)
    images, rays, bds = get_example(img_idx, downsample=1)
    images = np.reshape(images, (-1,3))
    rays = np.reshape(rays, (2,-1,3))
    rng, rng_input = random.split(rng)
    idx = random.randint(rng_input, shape=(batch_size,), minval=0, maxval=images.shape[0])
    loss, opt_state = single_step_v2(istep, rng, images[idx,:], rays[:,idx,:], bds, opt_state)

In [103]:
opt_state

OptimizerState(packed_state=([DeviceArray([-0.00045145, -0.00143954, -0.00263059, -0.00230609,
              0.00044354,  0.0007933 ], dtype=float32), DeviceArray([ 7.4164245e-07,  9.0824096e-06,  1.9418590e-06,
              5.1692371e-07,  3.2076082e-06, -8.0520622e-06],            dtype=float32), DeviceArray([3.7262568e-14, 6.9343763e-12, 2.2007938e-13, 4.5521316e-14,
             4.6548199e-12, 9.6004038e-12], dtype=float32)], [DeviceArray([[ 1.19515203e-01,  3.79956812e-02,  1.00867495e-01,
              -3.65395024e-02, -3.95988673e-02,  2.29738709e-02],
             [ 5.77125698e-02, -4.45387252e-02,  4.90781777e-02,
              -1.46169439e-01,  7.75758103e-02,  1.68276310e-01],
             [ 1.20900065e-01, -1.95997935e-02,  1.10316165e-02,
               3.20435129e-02, -4.71294001e-02, -1.43090524e-02],
             [-7.43934736e-02,  4.30180728e-02,  1.81723312e-02,
              -3.99712473e-04,  6.60072789e-02,  5.00307828e-02],
             [ 1.26672154e-02,  2.237154

In [None]:
# access param by layer name
get_params(opt_state)['params']['fc0']

In [109]:
model.get_variable(name='fc_last')

TypeError: get_variable() missing 1 required positional argument: 'col'