# from https://github.com/isaaccorley/jax-enhance from functools import partial from typing import Any, Sequence, Callable import jax.numpy as jnp import flax.linen as nn from flax.core.frozen_dict import freeze import einops class PixelShuffle(nn.Module): scale_factor: int def setup(self): self.layer = partial( einops.rearrange, pattern="b h w (c h2 w2) -> b (h h2) (w w2) c", h2=self.scale_factor, w2=self.scale_factor ) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return self.layer(x) class ResidualBlock(nn.Module): channels: int kernel_size: Sequence[int] res_scale: float activation: Callable dtype: Any = jnp.float32 def setup(self): self.body = nn.Sequential([ nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype), self.activation, nn.Conv(features=self.channels, kernel_size=self.kernel_size, dtype=self.dtype), ]) def __call__(self, x: jnp.ndarray) -> jnp.ndarray: return x + self.body(x) class UpsampleBlock(nn.Module): num_upsamples: int channels: int kernel_size: Sequence[int] dtype: Any = jnp.float32 def setup(self): layers = [] for _ in range(self.num_upsamples): layers.extend([ nn.Conv(features=self.channels * 2 ** 2, kernel_size=self.kernel_size, dtype=self.dtype), PixelShuffle(scale_factor=2), ]) self.layers = layers def __call__(self, x: jnp.ndarray) -> jnp.ndarray: for layer in self.layers: x = layer(x) return x class EDSR(nn.Module): """Enhanced Deep Residual Networks for Single Image Super-Resolution https://arxiv.org/pdf/1707.02921v1.pdf""" scale_factor: int channels: int = 3 num_blocks: int = 32 num_feats: int = 256 dtype: Any = jnp.float32 def setup(self): # pre res blocks layer self.head = nn.Sequential([nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)]) # res blocks res_blocks = [ ResidualBlock(channels=self.num_feats, kernel_size=(3, 3), res_scale=0.1, activation=nn.relu, dtype=self.dtype) for i in range(self.num_blocks) ] res_blocks.append(nn.Conv(features=self.num_feats, kernel_size=(3, 3), dtype=self.dtype)) self.body = nn.Sequential(res_blocks) def __call__(self, x: jnp.ndarray, _=None) -> jnp.ndarray: x = self.head(x) x = x + self.body(x) return x def convert_edsr_checkpoint(torch_dict, no_upsampling=True): def convert(in_dict): top_keys = set([k.split('.')[0] for k in in_dict.keys()]) leaves = set([k for k in in_dict.keys() if '.' not in k]) # convert leaves out_dict = {} for l in leaves: if l == 'weight': out_dict['kernel'] = jnp.asarray(in_dict[l]).transpose((2, 3, 1, 0)) elif l == 'bias': out_dict[l] = jnp.asarray(in_dict[l]) else: out_dict[l] = in_dict[l] for top_key in top_keys.difference(leaves): new_top_key = 'layers_' + top_key if top_key.isdigit() else top_key out_dict[new_top_key] = convert( {k[len(top_key) + 1:]: v for k, v in in_dict.items() if k.startswith(top_key)}) return out_dict converted = convert(torch_dict) # remove unwanted keys if no_upsampling: del converted['tail'] for k in ('add_mean', 'sub_mean'): del converted[k] return freeze(converted)