import math
import re
from collections import OrderedDict

import mlx.core as mx
import mlx.nn as nn

import blocks as B

def conv_state_pair_to_mlx(kv):
    """Convert one (key, tensor) pair from PyTorch conventions to MLX conventions."""
    k, v = kv
    if v.ndim == 4:
        # PyTorch stores conv weights as OIHW; MLX convolutions expect OHWI.
        v = v.transpose(0, 2, 3, 1)
        # Round-trip through a flat view to force a contiguous copy of the
        # transposed data.
        v = v.reshape(-1).reshape(v.shape)
    # MLX sequential containers nest children under ".layers", so e.g.
    # "model.0.weight" must become "model.layers.0.weight".
    return re.sub(r"(\.\d+\.)", r".layers\1", k), v
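

# Illustrative example (shapes are hypothetical): passing ("model.3.weight", w)
# with w of shape (64, 64, 3, 3) yields ("model.layers.3.weight", w') with w'
# of shape (64, 3, 3, 64), ready for Module.load_weights below.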


class ESRGAN(nn.Module):
    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: str = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.

        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong,
        Yu Qiao, and Chen Change Loy.

        This is the old-arch Residual in Residual Dense Block Network, not the
        newest revision available at github.com/xinntao/ESRGAN. That is
        deliberate: the newest revision severely limits the network's potential
        uses without adding any benefit. This class accepts model files from
        both the new arch and the old arch.

        Args:
            state_dict: Model weights in PyTorch (OIHW) layout, old or new arch
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode
        """
        super().__init__()

        self._raw_state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        # Maps old-arch keys (left) to the new-arch keys (right) they may come
        # from. "/NB/" is a placeholder for the number of RRDB blocks, and the
        # final entry is a regex rewrite for the dense-block convolutions.
        self.state_map = {
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            "model.3.weight": ("upconv1.weight", "conv_up1.weight"),
            "model.3.bias": ("upconv1.bias", "conv_up1.bias"),
            "model.6.weight": ("upconv2.weight", "conv_up2.weight"),
            "model.6.bias": ("upconv2.bias", "conv_up2.bias"),
            "model.8.weight": ("HRconv.weight", "conv_hr.weight"),
            "model.8.bias": ("HRconv.bias", "conv_hr.bias"),
            "model.10.weight": ("conv_last.weight",),
            "model.10.bias": ("conv_last.bias",),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
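        # Example of the regex entry above: the new-arch key
        # "body.4.rdb2.conv3.weight" rewrites to the old-arch key
        # "model.1.sub.4.RDB2.conv3.0.weight".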

        if "params_ema" in self._raw_state:
            self._raw_state = self._raw_state["params_ema"]

        self.num_blocks = self.get_num_blocks()

        # "Plus" models add 1x1 convolutions inside the dense blocks.
        self.plus = any("conv1x1" in k for k in self._raw_state)

        self._raw_state = self.new_to_old_arch(self._raw_state)

        def sort_key(key):
            # Natural sort: numeric segments compare as integers, and for a
            # given layer "weight" sorts before "bias".
            order = {"weight": 0, "bias": 1}
            return [
                order[part] if part in order else int(part) if part.isdigit() else part
                for part in key.split(".")
            ]

        self.key_arr = sorted(self._raw_state, key=sort_key)

        self.in_nc = self._raw_state[self.key_arr[0]].shape[1]
        self.out_nc = self._raw_state[self.key_arr[-1]].shape[0]

        self.scale = self.get_scale()

        self.num_filters = self._raw_state[self.key_arr[0]].shape[0]

        # Detect models built from 2x2 convolutions: in OIHW layout the kernel
        # dimensions are the last two axes of the weight.
        c2x2 = False
        if self._raw_state["model.0.weight"].shape[-2] == 2:
            c2x2 = True
            self.scale = math.ceil(self.scale ** (1.0 / 3))

        # Detect Real-ESRGAN-style models that pixel-unshuffle their input,
        # folding a 2x or 4x downscale into the channel dimension.
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
            self.in_nc / 4,
            self.in_nc / 16,
        ):
            self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
        else:
            self.shuffle_factor = None
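        # For example, a 2x Real-ESRGAN model stores in_nc=12 and out_nc=3,
        # giving shuffle_factor = 2: inputs are pixel-unshuffled in __call__
        # before they reach the first convolution.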

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        if self.scale == 3:
            # 3x cannot be factored into 2x stages, so use a single block.
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            # Power-of-two scales stack log2(scale) 2x upsample blocks.
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log(self.scale, 2)))
            ]

        self.model = B.sequential(
            # First feature-extraction conv.
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            # Trunk: the RRDB blocks plus a trailing conv, wrapped in a
            # residual shortcut around the whole stack.
            B.ShortcutBlock(
                B.sequential(
                    *[
                        B.RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # HR conv, then the final conv back to the output channel count.
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        self.load_weights(
            [conv_state_pair_to_mlx(p) for p in self._raw_state.items()],
            strict=True,
        )

    def new_to_old_arch(self, state):
        """Convert a new-arch model state dictionary to an old-arch dictionary."""
        if "params_ema" in state:
            state = state["params_ema"]

        if "conv_first.weight" not in state:
            # Already old arch; nothing to remap.
            return state

        # Bind the "/NB/" placeholder entries to the actual block count.
        for kind in ("weight", "bias"):
            self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
                f"model.1.sub./NB/.{kind}"
            ]
            del self.state_map[f"model.1.sub./NB/.{kind}"]

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    # Regex entry: rewrite every matching key in the state dict.
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                else:
                    if new_key in state:
                        old_state[old_key] = state[new_key]

        # Stable sort by the numeric index directly after the "model." prefix.
        sorted_keys = sorted(old_state.keys(), key=lambda k: int(k.split(".")[1]))

        return OrderedDict((k, old_state[k]) for k in sorted_keys)

    def get_scale(self, min_part: int = 6) -> int:
        """Infer the upscale factor from how many convs follow the trunk."""
        n = 0
        for part in list(self._raw_state):
            parts = part.split(".")[1:]
            if len(parts) == 2:
                part_num = int(parts[0])
                if part_num > min_part and parts[1] == "weight":
                    n += 1
        return 2**n
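
    # Worked example, using the 4x layout from state_map above: top-level
    # weights sit at model.0, model.3, model.6, model.8, and model.10; only
    # model.8 and model.10 lie past min_part=6, so n = 2 and scale = 2**2 = 4.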

    def get_num_blocks(self) -> int:
        """Count the RRDB blocks by scanning the dense-block conv keys."""
        nbs = []
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for state_key in state_keys:
            for k in self._raw_state:
                m = re.search(state_key, k)
                if m:
                    nbs.append(int(m.group(1)))
            if nbs:
                break
        return max(nbs) + 1
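
    # Worked example: a standard ESRGAN checkpoint carries dense-block keys
    # model.1.sub.0.* through model.1.sub.22.*, so max(nbs) is 22 and the
    # trunk holds 23 RRDB blocks.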

    def __call__(self, x):
        if self.shuffle_factor:
            # Pixel-unshuffle in NHWC layout, matching torch.pixel_unshuffle's
            # channel ordering: (B, H, W, C) -> (B, H/r, W/r, C*r*r).
            r = self.shuffle_factor
            b, h, w, c = x.shape
            x = x.reshape(b, h // r, r, w // r, r, c).transpose(0, 1, 3, 5, 2, 4)
            x = x.reshape(b, h // r, w // r, c * r * r)
        return self.model(x)
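

# A minimal usage sketch, assuming the checkpoint has been exported to a
# safetensors file (the file name below is hypothetical). mx.load returns a
# flat {key: mx.array} dict, which is the shape of input ESRGAN expects.
if __name__ == "__main__":
    state = mx.load("esrgan_4x.safetensors")
    model = ESRGAN(state)
    # One random RGB image in NHWC layout.
    x = mx.random.uniform(shape=(1, 64, 64, model.in_nc))
    y = model(x)
    print(y.shape)  # expected: (1, 64 * model.scale, 64 * model.scale, model.out_nc)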