# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
# modified from: https://github.com/thstkdgus35/EDSR-PyTorch
import jax.numpy as jnp
import flax.linen as nn


class RDB_Conv(nn.Module):
    """One dense-connection layer: kSize x kSize conv + ReLU, output concatenated to its input."""
    growRate: int
    kSize: int = 3

    @nn.compact
    def __call__(self, x):
        out = nn.Sequential([
            nn.Conv(self.growRate, (self.kSize, self.kSize), padding=(self.kSize - 1) // 2),
            nn.activation.relu
        ])(x)
        # Dense connection: stack the new growRate feature maps onto the input channels.
        return jnp.concatenate((x, out), -1)


class RDB(nn.Module):
    """Residual Dense Block: densely connected conv layers, local feature fusion, local residual learning."""
    growRate0: int
    growRate: int
    nConvLayers: int

    @nn.compact
    def __call__(self, x):
        res = x
        # Densely connected conv layers; channel count grows by growRate per layer.
        for c in range(self.nConvLayers):
            x = RDB_Conv(self.growRate)(x)
        # Local feature fusion (1x1 conv back to growRate0 channels) + local residual learning.
        x = nn.Conv(self.growRate0, (1, 1))(x)
        return x + res


class RDN(nn.Module):
    """Residual Dense Network: shallow feature extraction, D residual dense blocks,
    and global feature fusion with a global residual connection."""
    G0: int = 64
    RDNkSize: int = 3
    RDNconfig: str = 'B'
    scale: int = 2
    n_colors: int = 3

    @nn.compact
    def __call__(self, x, _=None):
        # Number of RDBs (D), conv layers per RDB (C), and growth rate (G) for each config.
        D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[self.RDNconfig]

        # Shallow feature extraction
        f_1 = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(x)
        x = nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))(f_1)

        # Residual dense blocks and dense feature fusion
        RDBs_out = []
        for i in range(D):
            x = RDB(self.G0, G, C)(x)
            RDBs_out.append(x)
        x = jnp.concatenate(RDBs_out, -1)

        # Global feature fusion: 1x1 conv to fuse all RDB outputs, 3x3 conv, then global residual learning
        x = nn.Sequential([
            nn.Conv(self.G0, (1, 1)),
            nn.Conv(self.G0, (self.RDNkSize, self.RDNkSize))
        ])(x)
        x = x + f_1
        # This variant returns G0-channel feature maps; no upsampling head is applied here,
        # so scale and n_colors are unused.
        return x
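

# --- Usage sketch (illustrative) ---
# Initializes the RDN feature extractor on a dummy NHWC image and runs one
# forward pass; the 48x48 input size and batch size of 1 are arbitrary choices.
if __name__ == "__main__":
    import jax

    model = RDN(G0=64, RDNkSize=3, RDNconfig='B')
    dummy = jnp.zeros((1, 48, 48, 3))                   # (N, H, W, C) test input
    params = model.init(jax.random.PRNGKey(0), dummy)   # initialize parameters
    feats = model.apply(params, dummy)                  # (1, 48, 48, 64) feature maps
    print(feats.shape)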