|
import flax.linen as nn |
|
from jaxtyping import Array, ArrayLike |
|
|
|
|
|
class ConvNeXtBlock(nn.Module):

    """Residual ConvNeXt block; see Fig. 4 of "A ConvNet for the 2020s" (Liu et al.).

    https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf

    Attributes:
        n_dims: number of features; input and output widths are both n_dims.
        kernel_size: spatial extent of the leading (square) convolution.
        group_features: if True, the leading convolution uses n_dims feature
            groups (i.e. it becomes depthwise); otherwise a single group.
    """

    n_dims: int = 64

    kernel_size: int = 3

    group_features: bool = False

    def setup(self) -> None:
        spatial_kernel = (self.kernel_size, self.kernel_size)
        groups = self.n_dims if self.group_features else 1
        # Branch: conv -> LayerNorm -> 1x1 expand (4x) -> GELU -> 1x1 project back.
        self.residual = nn.Sequential([
            nn.Conv(self.n_dims, kernel_size=spatial_kernel, use_bias=False,
                    feature_group_count=groups),
            nn.LayerNorm(),
            nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
            nn.gelu,
            nn.Conv(self.n_dims, kernel_size=(1, 1)),
        ])

    def __call__(self, x: ArrayLike) -> Array:
        """Identity shortcut plus the residual branch."""
        return x + self.residual(x)
|
|
|
|
|
class Projection(nn.Module):

    """Pointwise feature projection: LayerNorm followed by a 1x1 convolution.

    Attributes:
        n_dims: number of output features produced by the 1x1 convolution.
    """

    n_dims: int

    @nn.compact
    def __call__(self, x: ArrayLike) -> Array:
        normed = nn.LayerNorm()(x)
        return nn.Conv(self.n_dims, (1, 1))(normed)
|
|
|
|
|
class ConvNeXt(nn.Module):

    """A stack of ConvNeXt blocks, inserting a Projection wherever the
    feature width changes between consecutive blocks.

    Attributes:
        block_defs: one tuple per block, expanded positionally into
            ConvNeXtBlock; element 0 is that block's feature width.
    """

    block_defs: list[tuple]

    def setup(self) -> None:
        modules = []
        width = self.block_defs[0][0]
        for spec in self.block_defs:
            if spec[0] != width:
                # Width change: project features before the next block so its
                # residual addition lines up.
                modules.append(Projection(spec[0]))
            modules.append(ConvNeXtBlock(*spec))
            width = spec[0]
        self.layers = modules

    def __call__(self, x: ArrayLike, _: bool) -> Array:
        """Apply every layer in order; the boolean argument is ignored."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
|
|
|
|