# thera/model/swin_ir.py
import math
from typing import Callable, Optional, Iterable, Sequence, Union
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxtyping import Array
def trunc_normal(mean=0., std=1., a=-2., b=2., dtype=jnp.float32) -> Callable:
"""Truncated normal initialization function"""
def init(key, shape, dtype=dtype) -> Array:
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/weight_init.py
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
out = jax.random.uniform(key, shape, dtype=dtype, minval=2 * l - 1, maxval=2 * u - 1)
out = jax.scipy.special.erfinv(out) * std * math.sqrt(2.) + mean
return jnp.clip(out, a, b)
return init
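# Usage sketch (illustrative shape): the returned closure follows the standard
# Flax initializer signature (key, shape, dtype).
#   >>> init_fn = trunc_normal(std=.02)
#   >>> w = init_fn(jax.random.PRNGKey(0), (4, 4))
#   >>> w.shape
#   (4, 4)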
def Dense(features, use_bias=True, kernel_init=trunc_normal(std=.02), bias_init=nn.initializers.zeros):
return nn.Dense(features, use_bias=use_bias, kernel_init=kernel_init, bias_init=bias_init)
def LayerNorm():
"""torch LayerNorm uses larger epsilon by default"""
return nn.LayerNorm(epsilon=1e-05)
class Mlp(nn.Module):
in_features: int
    hidden_features: Optional[int] = None
    out_features: Optional[int] = None
act_layer: Callable = nn.gelu
drop: float = 0.0
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.hidden_features or self.in_features)(x)
x = self.act_layer(x)
x = nn.Dropout(self.drop, deterministic=not training)(x)
x = nn.Dense(self.out_features or self.in_features)(x)
x = nn.Dropout(self.drop, deterministic=not training)(x)
return x
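# Usage sketch (illustrative sizes): with training=False both dropouts are
# deterministic, so no 'dropout' rng is needed.
#   >>> mlp = Mlp(in_features=96, hidden_features=192)
#   >>> tokens = jnp.zeros((1, 64, 96))
#   >>> params = mlp.init(jax.random.PRNGKey(0), tokens, False)
#   >>> mlp.apply(params, tokens, False).shape
#   (1, 64, 96)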
def window_partition(x, window_size: int):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.reshape((B, H // window_size, window_size, W // window_size, window_size, C))
windows = x.transpose((0, 1, 3, 2, 4, 5)).reshape((-1, window_size, window_size, C))
return windows
def window_reverse(windows, window_size: int, H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.reshape((B, H // window_size, W // window_size, window_size, window_size, -1))
x = x.transpose((0, 1, 3, 2, 4, 5)).reshape((B, H, W, -1))
return x
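# Usage sketch (illustrative shapes): partition and reverse are inverses when
# H and W are divisible by the window size.
#   >>> x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)
#   >>> w = window_partition(x, 4)  # (8, 4, 4, 3): 2x2 windows per image, batch 2
#   >>> bool(jnp.all(window_reverse(w, 4, 8, 8) == x))
#   True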
class DropPath(nn.Module):
"""
Implementation referred from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
dropout_prob: float = 0.1
deterministic: Optional[bool] = None
@nn.compact
def __call__(self, input, training):
if not training:
return input
keep_prob = 1 - self.dropout_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1)
rng = self.make_rng("dropout")
random_tensor = keep_prob + jax.random.uniform(rng, shape)
random_tensor = jnp.floor(random_tensor)
return jnp.divide(input, keep_prob) * random_tensor
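# Usage sketch: during training a 'dropout' rng must be supplied; at test time
# the module is the identity. DropPath has no parameters, hence the empty
# variable dict.
#   >>> dp = DropPath(0.1)
#   >>> x = jnp.ones((4, 16, 96))
#   >>> y = dp.apply({}, x, True, rngs={'dropout': jax.random.PRNGKey(0)})
#   >>> y.shape
#   (4, 16, 96)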
class WindowAttention(nn.Module):
dim: int
window_size: Iterable[int]
num_heads: int
qkv_bias: bool = True
qk_scale: Optional[float] = None
att_drop: float = 0.0
proj_drop: float = 0.0
def make_rel_pos_index(self):
h_indices = np.arange(0, self.window_size[0])
w_indices = np.arange(0, self.window_size[1])
indices = np.stack(np.meshgrid(w_indices, h_indices, indexing="ij"))
flatten_indices = np.reshape(indices, (2, -1))
relative_indices = flatten_indices[:, :, None] - flatten_indices[:, None, :]
relative_indices = np.transpose(relative_indices, (1, 2, 0))
relative_indices[:, :, 0] += self.window_size[0] - 1
relative_indices[:, :, 1] += self.window_size[1] - 1
relative_indices[:, :, 0] *= 2 * self.window_size[1] - 1
relative_pos_index = np.sum(relative_indices, -1)
return relative_pos_index
@nn.compact
def __call__(self, inputs, mask, training):
rpbt = self.param(
"relative_position_bias_table",
trunc_normal(std=.02),
(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1),
self.num_heads,
),
)
batch, n, channels = inputs.shape
qkv = nn.Dense(self.dim * 3, use_bias=self.qkv_bias, name="qkv")(inputs)
qkv = qkv.reshape(batch, n, 3, self.num_heads, channels // self.num_heads)
qkv = jnp.transpose(qkv, (2, 0, 3, 1, 4))
q, k, v = qkv[0], qkv[1], qkv[2]
scale = self.qk_scale or (self.dim // self.num_heads) ** -0.5
q = q * scale
att = q @ jnp.swapaxes(k, -2, -1)
rel_pos_bias = jnp.reshape(
rpbt[np.reshape(self.make_rel_pos_index(), (-1))],
(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
),
)
rel_pos_bias = jnp.transpose(rel_pos_bias, (2, 0, 1))
att += jnp.expand_dims(rel_pos_bias, 0)
if mask is not None:
att = jnp.reshape(
att, (batch // mask.shape[0], mask.shape[0], self.num_heads, n, n)
)
att = att + jnp.expand_dims(jnp.expand_dims(mask, 1), 0)
att = jnp.reshape(att, (-1, self.num_heads, n, n))
att = jax.nn.softmax(att)
else:
att = jax.nn.softmax(att)
att = nn.Dropout(self.att_drop)(att, deterministic=not training)
x = jnp.reshape(jnp.swapaxes(att @ v, 1, 2), (batch, n, channels))
x = nn.Dense(self.dim, name="proj")(x)
x = nn.Dropout(self.proj_drop)(x, deterministic=not training)
return x
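# Usage sketch (illustrative sizes): inputs are already window-partitioned, so
# the batch axis is num_windows * B and the sequence axis is the window area.
#   >>> attn = WindowAttention(dim=32, window_size=(4, 4), num_heads=4)
#   >>> tokens = jnp.zeros((8, 16, 32))
#   >>> params = attn.init(jax.random.PRNGKey(0), tokens, None, False)
#   >>> attn.apply(params, tokens, None, False).shape
#   (8, 16, 32)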
class SwinTransformerBlock(nn.Module):
dim: int
    input_resolution: tuple[int, int]
num_heads: int
window_size: int = 7
shift_size: int = 0
mlp_ratio: float = 4.
qkv_bias: bool = True
qk_scale: Optional[float] = None
drop: float = 0.
attn_drop: float = 0.
drop_path: float = 0.
act_layer: Callable = nn.activation.gelu
norm_layer: Callable = LayerNorm
@staticmethod
def make_att_mask(shift_size, window_size, height, width):
if shift_size > 0:
mask = jnp.zeros([1, height, width, 1])
h_slices = (
slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None),
)
w_slices = (
slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None),
)
count = 0
for h in h_slices:
for w in w_slices:
mask = mask.at[:, h, w, :].set(count)
count += 1
mask_windows = window_partition(mask, window_size)
mask_windows = jnp.reshape(mask_windows, (-1, window_size * window_size))
att_mask = jnp.expand_dims(mask_windows, 1) - jnp.expand_dims(mask_windows, 2)
att_mask = jnp.where(att_mask != 0.0, float(-100.0), att_mask)
att_mask = jnp.where(att_mask == 0.0, float(0.0), att_mask)
else:
att_mask = None
return att_mask
@nn.compact
def __call__(self, x, x_size, training):
H, W = x_size
B, L, C = x.shape
        # Flax modules are frozen dataclasses, so resolve the effective window
        # and shift sizes into locals instead of reassigning attributes on self.
        if min(self.input_resolution) <= self.window_size:
            # if the window size is larger than the input resolution, we don't partition windows
            window_size = min(self.input_resolution)
            shift_size = 0
        else:
            window_size = self.window_size
            shift_size = self.shift_size
        assert 0 <= shift_size < window_size, "shift_size must be in [0, window_size)"
        shortcut = x
        x = self.norm_layer()(x)
        x = x.reshape((B, H, W, C))
        # cyclic shift
        if shift_size > 0:
            shifted_x = jnp.roll(x, (-shift_size, -shift_size), axis=(1, 2))
        else:
            shifted_x = x
        # partition windows
        x_windows = window_partition(shifted_x, window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.reshape((-1, window_size * window_size, C))  # nW*B, window_size*window_size, C
        attn = WindowAttention(self.dim, (window_size, window_size), self.num_heads,
                               self.qkv_bias, self.qk_scale, self.attn_drop, self.drop)
        if self.input_resolution == x_size:
            attn_mask = self.make_att_mask(shift_size, window_size, *self.input_resolution)
            attn_windows = attn(x_windows, attn_mask, training)  # nW*B, window_size*window_size, C
        else:
            # test time: rebuild the mask for the actual input size
            assert not training
            test_mask = self.make_att_mask(shift_size, window_size, *x_size)
            attn_windows = attn(x_windows, test_mask, training=False)
        # merge windows
        attn_windows = attn_windows.reshape((-1, window_size, window_size, C))
        shifted_x = window_reverse(attn_windows, window_size, H, W)  # B H' W' C
        # reverse cyclic shift
        if shift_size > 0:
            x = jnp.roll(shifted_x, (shift_size, shift_size), axis=(1, 2))
        else:
            x = shifted_x
x = x.reshape((B, H * W, C))
# FFN
x = shortcut + DropPath(self.drop_path)(x, training)
norm = self.norm_layer()(x)
mlp = Mlp(in_features=self.dim, hidden_features=int(self.dim * self.mlp_ratio),
act_layer=self.act_layer, drop=self.drop)(norm, training)
x = x + DropPath(self.drop_path)(mlp, training)
return x
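# Usage sketch (illustrative sizes): one shifted block on an 8x8 token grid.
#   >>> block = SwinTransformerBlock(dim=32, input_resolution=(8, 8), num_heads=4,
#   ...                              window_size=4, shift_size=2)
#   >>> tokens = jnp.zeros((1, 64, 32))
#   >>> params = block.init(jax.random.PRNGKey(0), tokens, (8, 8), False)
#   >>> block.apply(params, tokens, (8, 8), False).shape
#   (1, 64, 32)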
class PatchMerging(nn.Module):
inp_res: Iterable[int]
dim: int
norm_layer: Callable = LayerNorm
@nn.compact
def __call__(self, inputs):
batch, n, channels = inputs.shape
height, width = self.inp_res[0], self.inp_res[1]
x = jnp.reshape(inputs, (batch, height, width, channels))
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
x = jnp.concatenate([x0, x1, x2, x3], axis=-1)
x = jnp.reshape(x, (batch, -1, 4 * channels))
x = self.norm_layer()(x)
x = nn.Dense(2 * self.dim, use_bias=False)(x)
return x
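# Usage sketch (illustrative sizes): merging halves each spatial dimension and
# doubles the channel count.
#   >>> pm = PatchMerging(inp_res=(8, 8), dim=32)
#   >>> params = pm.init(jax.random.PRNGKey(0), jnp.zeros((1, 64, 32)))
#   >>> pm.apply(params, jnp.zeros((1, 64, 32))).shape
#   (1, 16, 64)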
class BasicLayer(nn.Module):
dim: int
    input_resolution: tuple[int, int]
depth: int
num_heads: int
window_size: int
mlp_ratio: float = 4.
qkv_bias: bool = True
qk_scale: Optional[float] = None
drop: float = 0.
attn_drop: float = 0.
    drop_path: Union[float, Sequence[float]] = 0.
norm_layer: Callable = LayerNorm
downsample: Optional[Callable] = None
@nn.compact
def __call__(self, x, x_size, training):
for i in range(self.depth):
x = SwinTransformerBlock(
self.dim,
self.input_resolution,
self.num_heads,
self.window_size,
0 if (i % 2 == 0) else self.window_size // 2,
self.mlp_ratio,
self.qkv_bias,
self.qk_scale,
self.drop,
self.attn_drop,
self.drop_path[i] if isinstance(self.drop_path, (list, tuple)) else self.drop_path,
norm_layer=self.norm_layer
)(x, x_size, training)
if self.downsample is not None:
x = self.downsample(self.input_resolution, dim=self.dim, norm_layer=self.norm_layer)(x)
return x
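# Usage sketch (illustrative sizes): a depth-2 layer alternates an unshifted
# block (shift_size 0) with a shifted one (shift_size = window_size // 2).
#   >>> layer = BasicLayer(dim=32, input_resolution=(8, 8), depth=2, num_heads=4, window_size=4)
#   >>> tokens = jnp.zeros((1, 64, 32))
#   >>> params = layer.init(jax.random.PRNGKey(0), tokens, (8, 8), False)
#   >>> layer.apply(params, tokens, (8, 8), False).shape
#   (1, 64, 32)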
class RSTB(nn.Module):
dim: int
    input_resolution: tuple[int, int]
depth: int
num_heads: int
window_size: int
mlp_ratio: float = 4.
qkv_bias: bool = True
qk_scale: Optional[float] = None
drop: float = 0.
attn_drop: float = 0.
    drop_path: Union[float, Sequence[float]] = 0.
norm_layer: Callable = LayerNorm
downsample: Optional[Callable] = None
    img_size: int = 224
    patch_size: int = 4
resi_connection: str = '1conv'
@nn.compact
def __call__(self, x, x_size, training):
res = x
x = BasicLayer(dim=self.dim,
input_resolution=self.input_resolution,
depth=self.depth,
num_heads=self.num_heads,
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
drop=self.drop, attn_drop=self.attn_drop,
drop_path=self.drop_path,
norm_layer=self.norm_layer,
downsample=self.downsample)(x, x_size, training)
x = PatchUnEmbed(embed_dim=self.dim)(x, x_size)
# resi_connection == '1conv':
x = nn.Conv(self.dim, (3, 3))(x)
x = PatchEmbed()(x)
return x + res
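# Usage sketch (illustrative sizes): the residual group preserves the token
# shape, so RSTB blocks can be stacked freely.
#   >>> rstb = RSTB(dim=32, input_resolution=(8, 8), depth=2, num_heads=4, window_size=4)
#   >>> tokens = jnp.zeros((1, 64, 32))
#   >>> params = rstb.init(jax.random.PRNGKey(0), tokens, (8, 8), False)
#   >>> rstb.apply(params, tokens, (8, 8), False).shape
#   (1, 64, 32)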
class PatchEmbed(nn.Module):
norm_layer: Optional[Callable] = None
@nn.compact
def __call__(self, x):
x = x.reshape((x.shape[0], -1, x.shape[-1])) # B Ph Pw C -> B Ph*Pw C
if self.norm_layer is not None:
x = self.norm_layer()(x)
return x
class PatchUnEmbed(nn.Module):
embed_dim: int = 96
@nn.compact
def __call__(self, x, x_size):
B, HW, C = x.shape
x = x.reshape((B, x_size[0], x_size[1], self.embed_dim))
return x
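# Usage sketch: embed/un-embed are a parameter-free flatten/unflatten pair
# (when no norm layer is configured), hence the empty variable dicts.
#   >>> feat = jnp.zeros((1, 6, 6, 96))
#   >>> tokens = PatchEmbed().apply({}, feat)  # (1, 36, 96)
#   >>> PatchUnEmbed(embed_dim=96).apply({}, tokens, (6, 6)).shape
#   (1, 6, 6, 96)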
class SwinIR(nn.Module):
r""" SwinIR JAX implementation
Args:
        img_size (int | tuple(int)): Input image size. Default: 48
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 180
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction. Default: 2
        img_range: Image range. 1. or 255. Default: 1.
"""
img_size: int = 48
patch_size: int = 1
in_chans: int = 3
embed_dim: int = 180
depths: tuple = (6, 6, 6, 6, 6, 6)
num_heads: tuple = (6, 6, 6, 6, 6, 6)
window_size: int = 8
mlp_ratio: float = 2.
qkv_bias: bool = True
qk_scale: Optional[float] = None
drop_rate: float = 0.
attn_drop_rate: float = 0.
drop_path_rate: float = 0.1
norm_layer: Callable = LayerNorm
ape: bool = False
patch_norm: bool = True
upscale: int = 2
img_range: float = 1.
num_feat: int = 64
def pad(self, x):
_, h, w, _ = x.shape
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = jnp.pad(x, ((0, 0), (0, mod_pad_h), (0, mod_pad_w), (0, 0)), 'reflect')
return x
@nn.compact
def __call__(self, x, training):
_, h_before, w_before, _ = x.shape
x = self.pad(x)
_, h, w, _ = x.shape
patches_resolution = [self.img_size // self.patch_size] * 2
num_patches = patches_resolution[0] * patches_resolution[1]
# conv_first
x = nn.Conv(self.embed_dim, (3, 3))(x)
res = x
# feature extraction
x_size = (h, w)
x = PatchEmbed(self.norm_layer if self.patch_norm else None)(x)
if self.ape:
absolute_pos_embed = \
self.param('ape', trunc_normal(std=.02), (1, num_patches, self.embed_dim))
x = x + absolute_pos_embed
x = nn.Dropout(self.drop_rate, deterministic=not training)(x)
        # stochastic depth decay rule: per-block drop-path rates increase linearly
        dpr = [float(r) for r in np.linspace(0, self.drop_path_rate, sum(self.depths))]
for i_layer in range(len(self.depths)):
x = RSTB(
dim=self.embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=self.depths[i_layer],
num_heads=self.num_heads[i_layer],
window_size=self.window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
drop=self.drop_rate, attn_drop=self.attn_drop_rate,
drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
norm_layer=self.norm_layer,
downsample=None,
img_size=self.img_size,
patch_size=self.patch_size)(x, x_size, training)
x = self.norm_layer()(x) # B L C
x = PatchUnEmbed(self.embed_dim)(x, x_size)
# conv_after_body
x = nn.Conv(self.embed_dim, (3, 3))(x)
x = x + res
# conv_before_upsample
x = nn.activation.leaky_relu(nn.Conv(self.num_feat, (3, 3))(x))
# revert padding
x = x[:, :-(h - h_before) or None, :-(w - w_before) or None]
return x
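# Usage sketch (small illustrative config, not the published one): the module
# returns num_feat-channel features at the padded-then-cropped input resolution.
#   >>> model = SwinIR(embed_dim=60, depths=(2, 2), num_heads=(6, 6))
#   >>> dummy = jnp.zeros((1, 48, 48, 3))
#   >>> variables = model.init(jax.random.PRNGKey(0), dummy, training=False)
#   >>> model.apply(variables, dummy, training=False).shape
#   (1, 48, 48, 64)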