# thera/model/convnext.py
# Author: Alexander Becker
# Commit: b139995 ("Add code")
import flax.linen as nn
from jaxtyping import Array, ArrayLike
class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block; see Fig. 4 in "A ConvNet for the 2020s" by Liu et al.

    https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_A_ConvNet_for_the_2020s_CVPR_2022_paper.pdf
    """
    n_dims: int = 64  # feature width of the block's input and output
    kernel_size: int = 3  # 7 in the paper's version
    group_features: bool = False  # True -> depthwise spatial conv, as in the paper

    def setup(self) -> None:
        spatial_kernel = (self.kernel_size, self.kernel_size)
        groups = self.n_dims if self.group_features else 1
        # spatial conv -> LayerNorm -> 1x1 expand (4x) -> GELU -> 1x1 project back
        self.residual = nn.Sequential([
            nn.Conv(self.n_dims, kernel_size=spatial_kernel, use_bias=False,
                    feature_group_count=groups),
            nn.LayerNorm(),
            nn.Conv(4 * self.n_dims, kernel_size=(1, 1)),
            nn.gelu,
            nn.Conv(self.n_dims, kernel_size=(1, 1)),
        ])

    def __call__(self, x: ArrayLike) -> Array:
        # skip connection around the inner stack
        return self.residual(x) + x
class Projection(nn.Module):
    """Channel projection: LayerNorm followed by a 1x1 convolution to `n_dims`."""
    n_dims: int  # number of output features

    @nn.compact
    def __call__(self, x: ArrayLike) -> Array:
        normed = nn.LayerNorm()(x)
        return nn.Conv(self.n_dims, (1, 1))(normed)
class ConvNeXt(nn.Module):
    """Sequential ConvNeXt backbone.

    Builds one ``ConvNeXtBlock`` per entry of ``block_defs`` and inserts a
    ``Projection`` wherever the feature width changes between consecutive
    blocks. The input is assumed to already have ``block_defs[0][0]``
    features, so no initial projection is added.

    Attributes:
        block_defs: One tuple per block, unpacked positionally into
            ``ConvNeXtBlock`` — i.e. ``(n_dims, ...)`` with the first item
            being the block's feature width.
    """
    block_defs: list[tuple]

    def setup(self) -> None:
        layers = []
        # Guard against an empty definition list (the original indexed
        # block_defs[0][0] unconditionally, raising IndexError on []).
        current_size = self.block_defs[0][0] if self.block_defs else None
        for block_def in self.block_defs:
            if block_def[0] != current_size:
                # Feature width changes: project channels before the block.
                layers.append(Projection(block_def[0]))
            layers.append(ConvNeXtBlock(*block_def))
            current_size = block_def[0]
        self.layers = layers

    def __call__(self, x: ArrayLike, _: bool) -> Array:
        # The second argument (presumably a train/eval flag — confirm against
        # callers) is accepted for interface compatibility but unused.
        for layer in self.layers:
            x = layer(x)
        return x