|
import math |
|
|
|
import jax |
|
from flax.core import unfreeze, freeze |
|
import jax.numpy as jnp |
|
import flax.linen as nn |
|
from jaxtyping import Array, ArrayLike, PyTree |
|
|
|
from .edsr import EDSR |
|
from .rdn import RDN |
|
from .hyper import Hypernetwork |
|
from .tail import build_tail |
|
from .init import uniform_between, linear_up |
|
from utils import make_grid, interpolate_grid, repeat_vmap |
|
|
|
|
|
class Thermal(nn.Module):
    """Sinusoidal activation damped by a diffusion-style decay term.

    Computes ``sin(w0_scale * x + phase) * exp(-(w0_scale * norm)**2 * k * t)``,
    where ``phase`` is a learnable per-feature offset. The exponential factor
    attenuates each component over time ``t`` at a rate set by its frequency
    magnitude ``norm`` and the diffusivity ``k`` (heat-kernel-like decay,
    hence the class name).
    """

    w0_scale: float = 1.  # frequency scaling applied inside the sine

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        # Learnable phase offset, one per feature, init uniform in [0, 0.5).
        phase = self.param('phase', nn.initializers.uniform(.5), x.shape[-1:])
        oscillation = jnp.sin(self.w0_scale * x + phase)
        decay = jnp.exp(-(self.w0_scale * norm) ** 2 * k * t)
        return oscillation * decay
|
|
|
|
|
class TheraField(nn.Module):
    """One neural field: frequency projection, thermal activation, linear readout.

    The input coordinates are projected onto shared frequency ``components``,
    passed through the time-dependent ``Thermal`` activation, and mapped to
    ``dim_out`` channels by a bias-free dense layer with a SIREN-style
    uniform initialization.
    """

    dim_hidden: int   # width of the projected (hidden) representation
    dim_out: int      # number of output channels
    w0: float = 1.    # frequency scale forwarded to the Thermal activation
    c: float = 6.     # constant controlling the readout init range

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # Project coordinates onto the shared frequency components.
        hidden = x @ components

        # Each component's frequency magnitude sets its thermal decay rate.
        freq_norm = jnp.linalg.norm(components, axis=-2)
        hidden = Thermal(self.w0)(hidden, t, freq_norm, k)

        # SIREN-style bound: sqrt(c / fan_in) scaled down by w0.
        bound = math.sqrt(self.c / self.dim_hidden) / self.w0
        readout = nn.Dense(
            self.dim_out,
            kernel_init=uniform_between(-bound, bound),
            use_bias=False,
        )
        return readout(hidden)
|
|
|
|
|
class Thera:
    """
    Thera model: a hypernetwork predicts, per output pixel, the parameters of
    a TheraField neural field, which is then evaluated at continuous spatial
    coordinates and a temporal coordinate.

    Note: not itself a flax Module; it wraps a `Hypernetwork` module and a
    `TheraField` module and exposes init/apply-style entry points.
    """

    def __init__(
        self,
        hidden_dim: int,
        out_dim: int,
        backbone: nn.Module,
        tail: nn.Module,
        k_init: float | None = None,
        components_init_scale: float | None = None
    ):
        # hidden_dim: hidden width of every field (and number of frequency
        # components); out_dim: number of channels each field predicts.
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        self.field = TheraField(hidden_dim, out_dim)

        # Initialize a throwaway field once to discover the shapes and pytree
        # structure of its parameters; the hypernetwork must produce exactly
        # this structure for every output pixel.
        sample_params = self.field.init(jax.random.PRNGKey(0),
                                        jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        """
        Initializes all model parameters, using `sample_source` only to fix
        input shapes. Adds the globally shared diffusivity `k` and frequency
        matrix `components` alongside the hypernetwork parameters.
        """
        keys = jax.random.split(key, 2)
        # Dummy coordinates with the same leading (spatial) layout as the source.
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))

        # Globally shared parameters, not produced by the hypernetwork.
        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Field parameters, shape (B, H, W, N)
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
            return_jac: If True, also return the spatial Jacobian of the
                fields, evaluated at t=0.
        """
        # Sample per-pixel field parameters from the encoding at `coords`.
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # Express each target coordinate relative to its nearest source-grid
        # cell, then rescale by the encoding's spatial size so the offsets
        # are in units of source pixels.
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = (coords - interp_coords)
        rel_coords = rel_coords.at[..., 0].set(rel_coords[..., 0] * encoding.shape[-3])
        rel_coords = rel_coords.at[..., 1].set(rel_coords[..., 1] * encoding.shape[-2])

        # Three nested vmap levels — presumably batch plus the two spatial
        # dims (TODO confirm against utils.repeat_vmap). `t`, `k`, and
        # `components` are broadcast; `k`/`components` are shared globally.
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, params['params']['k'],
                          params['params']['components'])

        if return_jac:
            # Jacobian w.r.t. the spatial coordinates (argnums=1), taken at
            # t=0 rather than at the requested `t`.
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), params['params']['k'],
                            params['params']['components'])
            return out, jac

        return out

    def apply(
        self,
        params: PyTree,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array:
        """
        Performs a forward pass through the Thera model.

        Encodes `source` with the hypernetwork, then decodes at spatial
        coordinates `coords` and time `t`; extra `kwargs` are forwarded to
        the encoder. Returns `(out, jac)` if `return_jac` is True.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out
|
|
|
|
|
def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float = None,
    components_init_scale: float = None
):
    """
    Build one of the Thera model variants described in the paper.

    Args:
        out_dim: Number of output channels per field.
        backbone: Encoder architecture, 'edsr-baseline' or 'rdn'.
        size: Variant tag; 'air' uses 32-dim fields, any other value 512-dim.
        k_init: Initial diffusivity, forwarded to Thera.
        components_init_scale: Initial scale of the frequency components.

    Raises:
        NotImplementedError: If `backbone` is not a known architecture.
    """
    backbone_builders = {
        'edsr-baseline': lambda: EDSR(None, num_blocks=16, num_feats=64),
        'rdn': lambda: RDN(),
    }
    if backbone not in backbone_builders:
        raise NotImplementedError(backbone)
    backbone_module = backbone_builders[backbone]()

    field_width = 32 if size == 'air' else 512

    return Thera(
        field_width,
        out_dim,
        backbone_module,
        build_tail(size),
        k_init,
        components_init_scale,
    )
|
|