File size: 6,217 Bytes
1eb87a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import math

import jax
from flax.core import unfreeze, freeze
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array, ArrayLike, PyTree

from .edsr import EDSR
from .rdn import RDN
from .hyper import Hypernetwork
from .tail import build_tail
from .init import uniform_between, linear_up
from utils import make_grid, interpolate_grid, repeat_vmap


class Thermal(nn.Module):
    """
    Sine activation with a learnable phase, multiplied by an exponential
    factor that depends on `t`, `k` and the `w0_scale`-scaled `norm`.
    """
    w0_scale: float = 1.

    @nn.compact
    def __call__(self, x: ArrayLike, t, norm, k) -> Array:
        # one learnable phase per entry along the last axis of `x`,
        # created with nn.initializers.uniform(.5)
        phase_shape = x.shape[-1:]
        phase = self.param('phase', nn.initializers.uniform(.5), phase_shape)

        # oscillating part: sin(w0 * x + phase)
        wave = jnp.sin(self.w0_scale * x + phase)

        # attenuating part: exp(-(w0 * norm)^2 * k * t)
        scaled_norm = self.w0_scale * norm
        damping = jnp.exp(-scaled_norm**2 * k * t)

        return wave * damping


class TheraField(nn.Module):
    """
    Two-stage field: coordinates are projected by the externally supplied
    `components`, passed through a `Thermal` activation, and mapped to
    `dim_out` channels by a bias-free dense layer.
    """
    dim_hidden: int
    dim_out: int
    w0: float = 1.
    c: float = 6.

    @nn.compact
    def __call__(self, x: ArrayLike, t: ArrayLike, k: ArrayLike, components: ArrayLike) -> Array:
        # "first layer": project the coordinates onto the shared components
        projected = x @ components

        # thermal activation, fed with the norm of the components
        # taken along their second-to-last axis
        component_norms = jnp.linalg.norm(components, axis=-2)
        activated = Thermal(self.w0)(projected, t, component_norms, k)

        # "second layer": bias-free linear map from hidden to output space;
        # the kernel initialiser receives the bounds +/- sqrt(c / dim_hidden) / w0
        bound = math.sqrt(self.c / self.dim_hidden) / self.w0
        readout = nn.Dense(
            self.dim_out,
            kernel_init=uniform_between(-bound, bound),
            use_bias=False,
        )
        return readout(activated)


class Thera:
    """
    Couples a hypernetwork (backbone + tail) with a single shared `TheraField`.

    The hypernetwork produces the per-location field parameters. The scalar `k`
    and the `components` matrix are stored next to the hypernetwork parameters
    under `params['params']` and are passed unmapped to every field evaluation.
    """

    def __init__(
            self,
            hidden_dim: int,
            out_dim: int,
            backbone: nn.Module,
            tail: nn.Module,
            k_init: float | None = None,
            components_init_scale: float | None = None
    ):
        """
        args:
            hidden_dim: Number of hidden components of the field
            out_dim: Number of output channels of the field
            backbone: Backbone module handed to the hypernetwork
            tail: Tail module handed to the hypernetwork
            k_init: Initial value of `k`; only needed when calling `init`
            components_init_scale: Scale handed to `linear_up` when creating
                `components` in `init`
        """
        self.hidden_dim = hidden_dim
        self.k_init = k_init
        self.components_init_scale = components_init_scale

        # single TheraField object whose `apply` method is used for all grid cells
        self.field = TheraField(hidden_dim, out_dim)

        # infer output size of the hypernetwork from a sample pass through the field;
        # key doesn't matter as field params are only used for size inference
        sample_params = self.field.init(
            jax.random.PRNGKey(0),
            jnp.zeros((2,)), 0., 0., jnp.zeros((2, hidden_dim)))
        sample_params_flat, tree_def = jax.tree_util.tree_flatten(sample_params)
        param_shapes = [p.shape for p in sample_params_flat]

        self.hypernet = Hypernetwork(backbone, tail, param_shapes, tree_def)

    def init(self, key, sample_source) -> PyTree:
        """
        Initializes the hypernetwork parameters and adds the shared entries
        `k` and `components` under `params['params']`.
        args:
            key: PRNG key; split in two, one part for the hypernetwork and one
                for `components`
            sample_source: Sample input used to initialize the hypernetwork
        raises:
            ValueError: if the model was constructed without `k_init`
        """
        # fail early with a readable message instead of the opaque error that
        # `jnp.array(None)` would otherwise produce further down
        if self.k_init is None:
            raise ValueError(
                '`k_init` must be provided to initialize Thera parameters, got None')
        # NOTE(review): `components_init_scale` is forwarded to `linear_up` as is;
        # whether None is acceptable there is not visible from this file.

        keys = jax.random.split(key, 2)
        sample_coords = jnp.zeros(sample_source.shape[:-1] + (2,))
        params = unfreeze(self.hypernet.init(keys[0], sample_source, sample_coords))

        params['params']['k'] = jnp.array(self.k_init)
        params['params']['components'] = \
            linear_up(self.components_init_scale)(keys[1], (2, self.hidden_dim))

        return freeze(params)

    def apply_encoder(self, params: PyTree, source: ArrayLike, **kwargs) -> Array:
        """
        Performs a forward pass through the hypernetwork to obtain an encoding.
        Extra keyword arguments are forwarded to `hypernet.apply`.
        """
        return self.hypernet.apply(
            params, source, method=self.hypernet.get_encoding, **kwargs)

    def apply_decoder(
        self,
        params: PyTree,
        encoding: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward prediction through a grid of HxW Thera fields,
        informed by `encoding`, at spatial and temporal coordinates
        `coords` and `t`, respectively.
        args:
            params: Model parameter PyTree; must contain `params['params']['k']`
                and `params['params']['components']` next to the hypernetwork
                parameters
            encoding: Encoding tensor, shape (B, H, W, C)
            coords: Spatial coordinates in [-0.5, 0.5], shape (B, H, W, 2)
            t: Temporal coordinates, shape (B, 1)
            return_jac: If True, additionally return the Jacobian of the field
                output with respect to the local coordinates, computed with
                `t` replaced by zeros
        returns:
            The field output, or the tuple `(output, jacobian)` when
            `return_jac` is set.
        """
        phi_params: PyTree = self.hypernet.apply(
            params, encoding, coords, method=self.hypernet.get_params_at_coords)

        # create local coordinate systems
        source_grid = jnp.asarray(make_grid(encoding.shape[-3:-1]))
        source_coords = jnp.tile(source_grid, (encoding.shape[0], 1, 1, 1))
        interp_coords = interpolate_grid(coords, source_coords)
        rel_coords = (coords - interp_coords)

        # scale the two coordinate channels by the encoding's (H, W) in a single
        # broadcast multiply rather than two separate scatter updates
        grid_size = jnp.asarray(encoding.shape[-3:-1], dtype=rel_coords.dtype)
        rel_coords = rel_coords * grid_size

        k = params['params']['k']
        components = params['params']['components']

        # three maps over params, coords; one over t; dont map k and components
        in_axes = [(0, 0, None, None, None), (0, 0, None, None, None), (0, 0, 0, None, None)]
        apply_field = repeat_vmap(self.field.apply, in_axes)
        out = apply_field(phi_params, rel_coords, t, k, components)

        if return_jac:
            # argnums=1 differentiates w.r.t. the coordinate argument of the field
            apply_jac = repeat_vmap(jax.jacrev(self.field.apply, argnums=1), in_axes)
            jac = apply_jac(phi_params, rel_coords, jnp.zeros_like(t), k, components)
            return out, jac

        return out

    def apply(
        self,
        params: PyTree,
        source: ArrayLike,
        coords: ArrayLike,
        t: ArrayLike,
        return_jac: bool = False,
        **kwargs
    ) -> Array | tuple[Array, Array]:
        """
        Performs a forward pass through the Thera model: encodes `source` and
        decodes the encoding at `coords` and `t`. Extra keyword arguments are
        forwarded to the encoder; with `return_jac` set, the result is the
        tuple `(output, jacobian)` produced by `apply_decoder`.
        """
        encoding = self.apply_encoder(params, source, **kwargs)
        out = self.apply_decoder(params, encoding, coords, t, return_jac=return_jac)
        return out


def build_thera(
    out_dim: int,
    backbone: str,
    size: str,
    k_init: float = None,
    components_init_scale: float = None
):
    """
    Convenience function for building the three Thera sizes described in the paper.
    """
    hidden_dim = 32 if size == 'air' else 512

    if backbone == 'edsr-baseline':
        backbone_module = EDSR(None, num_blocks=16, num_feats=64)
    elif backbone == 'rdn':
        backbone_module = RDN()
    else:
        raise NotImplementedError(backbone)

    tail_module = build_tail(size)

    return Thera(hidden_dim, out_dim, backbone_module, tail_module, k_init, components_init_scale)