sculpt/model/thera.py
import math

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import unfreeze, freeze
from jaxtyping import Array, ArrayLike, PyTree

from .edsr import EDSR
from .rdn import RDN
from .hyper import Hypernetwork
from .tail import build_tail
from .init import uniform_between, linear_up
from utils import make_grid, interpolate_grid, repeat_vmap


class Thermal(nn.Module):
    """Thermal activation: a sine with a learned phase, damped by the heat-equation
    decay factor exp(-(w0_scale * norm)^2 * k * t)."""
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, norm: ArrayLike, k: ArrayLike) -> Array:
        # learned per-feature phase, initialized uniformly in [0, 0.5)
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        return jnp.sin(self.w0_scale * x + phase) * jnp.exp(-(self.w0_scale * norm)**2 * k * t)
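

# Illustrative sketch (not part of the model): how `Thermal` might be exercised in
# isolation. The shapes and scalar values below are placeholder assumptions.
def _thermal_usage_example() -> Array:
    act = Thermal(w0_scale=1.)
    x = jnp.zeros((4,))        # pre-activations, one per hidden unit
    norm = jnp.ones((4,))      # norms of the shared component vectors
    variables = act.init(jax.random.PRNGKey(0), x, 0., norm, 1.)
    # at t=0 the decay factor is 1; larger t attenuates high-frequency components more
    return act.apply(variables, x, 0.5, norm, 1.)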


class TheraField(nn.Module):
    """A single Thera neural field: a shared linear first layer ("components"),
    thermal activations, and a linear readout with SIREN-style initialization."""
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # coordinate projection according to shared components ("first layer")
        x = x @ components

        # thermal activations
        norm = jnp.linalg.norm(components, axis=-2)
        x = Thermal(self.w0)(x, t, norm, k)

        # linear projection from hidden to output space ("second layer")
        w_std = math.sqrt(self.c / self.dim_hidden) / self.w0
        dense_init_fn = uniform_between(-w_std, w_std)
        x = nn.Dense(self.dim_out, kernel_init=dense_init_fn, use_bias=False)(x)

        return x
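

# Illustrative sketch (not part of the model): evaluating a single TheraField at one
# 2D coordinate. The hidden size, coordinate and component values are placeholders.
def _thera_field_usage_example() -> Array:
    field = TheraField(dim_hidden=32, dim_out=3)
    coord = jnp.array([0.1, -0.2])           # coordinate relative to a grid cell
    components = 0.1 * jnp.ones((2, 32))     # shared first-layer components
    params = field.init(jax.random.PRNGKey(0), coord, 0., 1., components)
    # in the full model, these field parameters are predicted by the hypernetwork
    return field.apply(params, coord, 0.5, 1., components)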


class Thera:

    def __init__(
            self,
            hidden_dim: int,
            out_dim: int,
            backbone: nn.Module,
            tail: nn.Module,
            k_init: float | None = None,
            components_init_scale: float | None = None
    ):
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer the output size of the hypernetwork from a sample pass through the field;
        # the key doesn't matter, as the field params are only used for size inference
        sample_params = self.field.init(
            jax.random.PRNGKey(0), jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source: ArrayLike) -> PyTree:
        """Initializes the hypernetwork parameters as well as the global `k` and the
        shared `components`."""
        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))
        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))
        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
            self,
            params: PyTree,
            encoding: ArrayLike,
            coords: ArrayLike,
            t: ArrayLike,
            return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.

        Args:
            params: Model parameters as returned by `init` (hypernetwork, `k`, `components`).
            encoding: Encoding tensor, shape (B, H, W, C).
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2).
            t: Temporal coordinates, shape (B, 1).
            return_jac: If True, also return the Jacobian of the field outputs with
                respect to the relative spatial coordinates, evaluated at t=0.
        """
        # per-cell field parameters, predicted by the hypernetwork at the query coordinates
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems: express each query coordinate relative to its
        # corresponding source grid cell and rescale to cell units
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = coords - interp_coords
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])

        # three maps over params and coords, one over t; don't map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
                          params['params']['components'])

        if return_jac:
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                            params['params']['components'])
            return out, jac

        return out

    def apply(
            self,
            params: PyTree,
            source: ArrayLike,
            coords: ArrayLike,
            t: ArrayLike,
            return_jac: bool = False,
            **kwargs
    ) -> Array | tuple[Array, Array]:
        """
        Performs a full forward pass through the Thera model: encode `source`,
        then decode at `coords` and `t`.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out
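

# Illustrative sketch (not part of the library): the encoder/decoder split above allows one
# encoding to be reused for several sampling grids, e.g. several target scales. `model`,
# `params`, `source`, and the grid sizes / t values below are placeholder assumptions.
def _multi_scale_decode_example(model: Thera, params: PyTree, source: ArrayLike) -> list:
    encoding = model.apply_encoder(params, source)
    outputs = []
    for h, w, t in [(96, 96, 0.25), (192, 192, 0.0625)]:
        coords = jnp.asarray(make_grid((h, w)))[None]      # assumed (1, h, w, 2) sampling grid
        outputs.append(model.apply_decoder(params, encoding, coords, jnp.full((1, 1), t)))
    return outputs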


def build_thera(
        out_dim: int,
        backbone: str,
        size: str,
        k_init: float | None = None,
        components_init_scale: float | None = None
):
    """
    Convenience function for building the three Thera sizes described in the paper.
    """
    hidden_dim = 32 if size == 'air' else 512

    if backbone == 'edsr-baseline':
        backbone_module = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        backbone_module = RDN()
    else:
        raise NotImplementedError(backbone)

    tail_module = build_tail(size)

    return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale)
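

# Illustrative sketch (not part of the library): one possible end-to-end call using
# `build_thera`. The input shape, target grid size, `t` value, `k_init`, and
# `components_init_scale` below are placeholders, not the settings from the paper.
def _thera_end_to_end_example() -> Array:
    model = build_thera(out_dim=3, backbone='edsr-baseline', size='air',
                        k_init=1., components_init_scale=1.)
    source = jnp.zeros((1, 48, 48, 3))                  # low-resolution input image
    params = model.init(jax.random.PRNGKey(0), source)
    coords = jnp.asarray(make_grid((96, 96)))[None]     # target-resolution sampling grid
    t = jnp.full((1, 1), 0.25)                          # per-sample "time" (placeholder value)
    return model.apply(params, source, coords, t)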