Upload 7 files
Browse files- .gitignore +145 -0
- 4x_NMKD-YandereNeoXL_200k.safetensors +3 -0
- ESRGAN.py +264 -0
- README.md +4 -3
- blocks.py +534 -0
- requirements.txt +4 -0
- upscale.py +71 -0
.gitignore
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
98 |
+
__pypackages__/
|
99 |
+
|
100 |
+
# Celery stuff
|
101 |
+
celerybeat-schedule
|
102 |
+
celerybeat.pid
|
103 |
+
|
104 |
+
# SageMath parsed files
|
105 |
+
*.sage.py
|
106 |
+
|
107 |
+
# Environments
|
108 |
+
.env
|
109 |
+
.venv
|
110 |
+
env/
|
111 |
+
venv/
|
112 |
+
ENV/
|
113 |
+
env.bak/
|
114 |
+
venv.bak/
|
115 |
+
|
116 |
+
# Spyder project settings
|
117 |
+
.spyderproject
|
118 |
+
.spyproject
|
119 |
+
|
120 |
+
# Rope project settings
|
121 |
+
.ropeproject
|
122 |
+
|
123 |
+
# mkdocs documentation
|
124 |
+
/site
|
125 |
+
|
126 |
+
# mypy
|
127 |
+
.mypy_cache/
|
128 |
+
.dmypy.json
|
129 |
+
dmypy.json
|
130 |
+
|
131 |
+
# Pyre type checker
|
132 |
+
.pyre/
|
133 |
+
|
134 |
+
# pytype static type analyzer
|
135 |
+
.pytype/
|
136 |
+
|
137 |
+
# Cython debug symbols
|
138 |
+
cython_debug/
|
139 |
+
|
140 |
+
|
141 |
+
# Custom
|
142 |
+
*.pth
|
143 |
+
input/**/*.*
|
144 |
+
output/**/*.*
|
145 |
+
.vscode/
|
4x_NMKD-YandereNeoXL_200k.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f61458c2637415947ee7baf05cdf529d54bb8cd3e36ba47393a3f48a2d1f3d59
|
3 |
+
size 66864028
|
ESRGAN.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools, math, re
|
2 |
+
from collections import OrderedDict
|
3 |
+
import mlx.core as mx
|
4 |
+
import mlx.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import blocks as B
|
7 |
+
from mlx.utils import tree_flatten
|
8 |
+
|
9 |
+
def conv_state_pair_to_mlx(kv):
|
10 |
+
k, v = kv
|
11 |
+
if v.ndim == 4:
|
12 |
+
v = v.transpose(0, 2, 3, 1)
|
13 |
+
v = v.reshape(-1).reshape(v.shape)
|
14 |
+
return re.sub(r'(\.\d+\.)', r'.layers\1', k), v
|
15 |
+
|
16 |
+
# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
|
17 |
+
# Which enhanced stuff that was already here
|
18 |
+
class ESRGAN(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
state_dict,
|
22 |
+
norm=None,
|
23 |
+
act: str = "leakyrelu",
|
24 |
+
upsampler: str = "upconv",
|
25 |
+
mode: str = "CNA",
|
26 |
+
) -> None:
|
27 |
+
"""
|
28 |
+
ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
|
29 |
+
By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
|
30 |
+
and Chen Change Loy.
|
31 |
+
This is old-arch Residual in Residual Dense Block Network and is not
|
32 |
+
the newest revision that's available at github.com/xinntao/ESRGAN.
|
33 |
+
This is on purpose, the newest Network has severely limited the
|
34 |
+
potential use of the Network with no benefits.
|
35 |
+
This network supports model files from both new and old-arch.
|
36 |
+
Args:
|
37 |
+
norm: Normalization layer
|
38 |
+
act: Activation layer
|
39 |
+
upsampler: Upsample layer. upconv, pixel_shuffle
|
40 |
+
mode: Convolution mode
|
41 |
+
"""
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
|
45 |
+
self._raw_state = state_dict
|
46 |
+
self.norm = norm
|
47 |
+
self.act = act
|
48 |
+
self.upsampler = upsampler
|
49 |
+
self.mode = mode
|
50 |
+
|
51 |
+
self.state_map = {
|
52 |
+
# currently supports old, new, and newer RRDBNet arch models
|
53 |
+
# ESRGAN, BSRGAN/RealSR, Real-ESRGAN
|
54 |
+
"model.0.weight": ("conv_first.weight",),
|
55 |
+
"model.0.bias": ("conv_first.bias",),
|
56 |
+
"model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
|
57 |
+
"model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
|
58 |
+
"model.3.weight": ("upconv1.weight", "conv_up1.weight"),
|
59 |
+
"model.3.bias": ("upconv1.bias", "conv_up1.bias"),
|
60 |
+
"model.6.weight": ("upconv2.weight", "conv_up2.weight"),
|
61 |
+
"model.6.bias": ("upconv2.bias", "conv_up2.bias"),
|
62 |
+
"model.8.weight": ("HRconv.weight", "conv_hr.weight"),
|
63 |
+
"model.8.bias": ("HRconv.bias", "conv_hr.bias"),
|
64 |
+
"model.10.weight": ("conv_last.weight",),
|
65 |
+
"model.10.bias": ("conv_last.bias",),
|
66 |
+
r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
|
67 |
+
r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
|
68 |
+
r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
|
69 |
+
),
|
70 |
+
}
|
71 |
+
if "params_ema" in self._raw_state:
|
72 |
+
self._raw_state = self._raw_state["params_ema"]
|
73 |
+
self.num_blocks = self.get_num_blocks()
|
74 |
+
|
75 |
+
self.plus = any("conv1x1" in k for k in self._raw_state.keys())
|
76 |
+
|
77 |
+
self._raw_state = self.new_to_old_arch(self._raw_state)
|
78 |
+
|
79 |
+
self.key_arr = sorted(list(self._raw_state.keys()), key=lambda x: [1 if v == "bias" else 0 if v == "weight" else int(v) if re.match(r'^\d+$', v) else v for v in re.findall(r'[^.]+', x)])
|
80 |
+
# print(self.key_arr)
|
81 |
+
|
82 |
+
self.in_nc = self._raw_state[self.key_arr[0]].shape[1]
|
83 |
+
self.out_nc = self._raw_state[self.key_arr[-1]].shape[0]
|
84 |
+
|
85 |
+
self.scale = self.get_scale()
|
86 |
+
|
87 |
+
self.num_filters = self._raw_state[self.key_arr[0]].shape[0]
|
88 |
+
|
89 |
+
c2x2 = False
|
90 |
+
if self._raw_state["model.0.weight"].shape[-3] == 2:
|
91 |
+
c2x2 = True
|
92 |
+
self.scale = math.ceil(self.scale ** (1.0 / 3))
|
93 |
+
|
94 |
+
# Detect if pixelunshuffle was used (Real-ESRGAN)
|
95 |
+
if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
|
96 |
+
self.in_nc / 4,
|
97 |
+
self.in_nc / 16,
|
98 |
+
):
|
99 |
+
self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
|
100 |
+
else:
|
101 |
+
self.shuffle_factor = None
|
102 |
+
|
103 |
+
upsample_block = {
|
104 |
+
"upconv": B.upconv_block,
|
105 |
+
"pixel_shuffle": B.pixelshuffle_block,
|
106 |
+
}.get(self.upsampler)
|
107 |
+
if upsample_block is None:
|
108 |
+
raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
|
109 |
+
|
110 |
+
if self.scale == 3:
|
111 |
+
upsample_blocks = upsample_block(
|
112 |
+
in_nc=self.num_filters,
|
113 |
+
out_nc=self.num_filters,
|
114 |
+
upscale_factor=3,
|
115 |
+
act_type=self.act,
|
116 |
+
c2x2=c2x2,
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
upsample_blocks = [
|
120 |
+
upsample_block(
|
121 |
+
in_nc=self.num_filters,
|
122 |
+
out_nc=self.num_filters,
|
123 |
+
act_type=self.act,
|
124 |
+
c2x2=c2x2,
|
125 |
+
)
|
126 |
+
for _ in range(int(math.log(self.scale, 2)))
|
127 |
+
]
|
128 |
+
|
129 |
+
self.model = B.sequential(
|
130 |
+
# fea conv
|
131 |
+
B.conv_block(
|
132 |
+
in_nc=self.in_nc,
|
133 |
+
out_nc=self.num_filters,
|
134 |
+
kernel_size=3,
|
135 |
+
norm_type=None,
|
136 |
+
act_type=None,
|
137 |
+
c2x2=c2x2,
|
138 |
+
),
|
139 |
+
B.ShortcutBlock(
|
140 |
+
B.sequential(
|
141 |
+
# rrdb blocks
|
142 |
+
*[
|
143 |
+
B.RRDB(
|
144 |
+
nf=self.num_filters,
|
145 |
+
kernel_size=3,
|
146 |
+
gc=32,
|
147 |
+
stride=1,
|
148 |
+
bias=True,
|
149 |
+
pad_type="zero",
|
150 |
+
norm_type=self.norm,
|
151 |
+
act_type=self.act,
|
152 |
+
mode="CNA",
|
153 |
+
plus=self.plus,
|
154 |
+
c2x2=c2x2,
|
155 |
+
)
|
156 |
+
for _ in range(self.num_blocks)
|
157 |
+
],
|
158 |
+
# lr conv
|
159 |
+
B.conv_block(
|
160 |
+
in_nc=self.num_filters,
|
161 |
+
out_nc=self.num_filters,
|
162 |
+
kernel_size=3,
|
163 |
+
norm_type=self.norm,
|
164 |
+
act_type=None,
|
165 |
+
mode=self.mode,
|
166 |
+
c2x2=c2x2,
|
167 |
+
),
|
168 |
+
)
|
169 |
+
),
|
170 |
+
*upsample_blocks,
|
171 |
+
# hr_conv0
|
172 |
+
B.conv_block(
|
173 |
+
in_nc=self.num_filters,
|
174 |
+
out_nc=self.num_filters,
|
175 |
+
kernel_size=3,
|
176 |
+
norm_type=None,
|
177 |
+
act_type=self.act,
|
178 |
+
c2x2=c2x2,
|
179 |
+
),
|
180 |
+
# hr_conv1
|
181 |
+
B.conv_block(
|
182 |
+
in_nc=self.num_filters,
|
183 |
+
out_nc=self.out_nc,
|
184 |
+
kernel_size=3,
|
185 |
+
norm_type=None,
|
186 |
+
act_type=None,
|
187 |
+
c2x2=c2x2,
|
188 |
+
),
|
189 |
+
)
|
190 |
+
|
191 |
+
self.load_weights(list(conv_state_pair_to_mlx(p) for p in self._raw_state.items()), strict=True)
|
192 |
+
|
193 |
+
|
194 |
+
def new_to_old_arch(self, state):
|
195 |
+
"""Convert a new-arch model state dictionary to an old-arch dictionary."""
|
196 |
+
if "params_ema" in state:
|
197 |
+
state = state["params_ema"]
|
198 |
+
|
199 |
+
if "conv_first.weight" not in state:
|
200 |
+
# model is already old arch, this is a loose check, but should be sufficient
|
201 |
+
return state
|
202 |
+
|
203 |
+
# add nb to state keys
|
204 |
+
for kind in ("weight", "bias"):
|
205 |
+
self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
|
206 |
+
f"model.1.sub./NB/.{kind}"
|
207 |
+
]
|
208 |
+
del self.state_map[f"model.1.sub./NB/.{kind}"]
|
209 |
+
|
210 |
+
old_state = OrderedDict()
|
211 |
+
for old_key, new_keys in self.state_map.items():
|
212 |
+
for new_key in new_keys:
|
213 |
+
if r"\1" in old_key:
|
214 |
+
for k, v in state.items():
|
215 |
+
sub = re.sub(new_key, old_key, k)
|
216 |
+
if sub != k:
|
217 |
+
old_state[sub] = v
|
218 |
+
else:
|
219 |
+
if new_key in state:
|
220 |
+
old_state[old_key] = state[new_key]
|
221 |
+
|
222 |
+
# Sort by first numeric value of each layer
|
223 |
+
def compare(item1, item2):
|
224 |
+
parts1 = item1.split(".")
|
225 |
+
parts2 = item2.split(".")
|
226 |
+
int1 = int(parts1[1])
|
227 |
+
int2 = int(parts2[1])
|
228 |
+
return int1 - int2
|
229 |
+
|
230 |
+
sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
|
231 |
+
|
232 |
+
# Rebuild the output dict in the right order
|
233 |
+
out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
|
234 |
+
|
235 |
+
return out_dict
|
236 |
+
|
237 |
+
def get_scale(self, min_part: int = 6) -> int:
|
238 |
+
n = 0
|
239 |
+
for part in list(self._raw_state):
|
240 |
+
parts = part.split(".")[1:]
|
241 |
+
if len(parts) == 2:
|
242 |
+
part_num = int(parts[0])
|
243 |
+
if part_num > min_part and parts[1] == "weight":
|
244 |
+
n += 1
|
245 |
+
return 2**n
|
246 |
+
|
247 |
+
def get_num_blocks(self) -> int:
|
248 |
+
nbs = []
|
249 |
+
state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
|
250 |
+
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
|
251 |
+
)
|
252 |
+
for state_key in state_keys:
|
253 |
+
for k in self._raw_state:
|
254 |
+
m = re.search(state_key, k)
|
255 |
+
if m:
|
256 |
+
nbs.append(int(m.group(1)))
|
257 |
+
if nbs:
|
258 |
+
break
|
259 |
+
return max(*nbs) + 1
|
260 |
+
|
261 |
+
def __call__(self, x):
|
262 |
+
if self.shuffle_factor:
|
263 |
+
x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
|
264 |
+
return self.model(x)
|
README.md
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
|
|
|
1 |
+
# kaeru tiny mlx upscaler
|
2 |
+
|
3 |
+
A simple upscale script for `4x_NMKD-YandereNeoXL_200k`
|
4 |
+
Based on [joeyballentine/ESRGAN](https://github.com/JoeyBallentine/ESRGAN)
|
blocks.py
ADDED
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from collections import OrderedDict
|
5 |
+
|
6 |
+
import mlx.core as mx
|
7 |
+
import mlx.nn as nn
|
8 |
+
|
9 |
+
####################
|
10 |
+
# Basic blocks
|
11 |
+
####################
|
12 |
+
|
13 |
+
|
14 |
+
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
|
15 |
+
# helper selecting activation
|
16 |
+
# neg_slope: for leakyrelu and init of prelu
|
17 |
+
# n_prelu: for p_relu num_parameters
|
18 |
+
act_type = act_type.lower()
|
19 |
+
if act_type == "relu":
|
20 |
+
layer = nn.ReLU()#inplace)
|
21 |
+
elif act_type == "leakyrelu":
|
22 |
+
layer = nn.LeakyReLU(neg_slope)#, inplace)
|
23 |
+
elif act_type == "prelu":
|
24 |
+
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
|
25 |
+
else:
|
26 |
+
raise NotImplementedError(
|
27 |
+
"activation layer [{:s}] is not found".format(act_type)
|
28 |
+
)
|
29 |
+
return layer
|
30 |
+
|
31 |
+
|
32 |
+
def norm(norm_type, nc):
|
33 |
+
# helper selecting normalization layer
|
34 |
+
norm_type = norm_type.lower()
|
35 |
+
if norm_type == "batch":
|
36 |
+
layer = nn.BatchNorm2d(nc, affine=True)
|
37 |
+
elif norm_type == "instance":
|
38 |
+
layer = nn.InstanceNorm2d(nc, affine=False)
|
39 |
+
else:
|
40 |
+
raise NotImplementedError(
|
41 |
+
"normalization layer [{:s}] is not found".format(norm_type)
|
42 |
+
)
|
43 |
+
return layer
|
44 |
+
|
45 |
+
|
46 |
+
def pad(pad_type, padding):
|
47 |
+
# helper selecting padding layer
|
48 |
+
# if padding is 'zero', do by conv layers
|
49 |
+
pad_type = pad_type.lower()
|
50 |
+
if padding == 0:
|
51 |
+
return None
|
52 |
+
if pad_type == "reflect":
|
53 |
+
layer = nn.ReflectionPad2d(padding)
|
54 |
+
elif pad_type == "replicate":
|
55 |
+
layer = nn.ReplicationPad2d(padding)
|
56 |
+
else:
|
57 |
+
raise NotImplementedError(
|
58 |
+
"padding layer [{:s}] is not implemented".format(pad_type)
|
59 |
+
)
|
60 |
+
return layer
|
61 |
+
|
62 |
+
|
63 |
+
def get_valid_padding(kernel_size, dilation):
|
64 |
+
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
|
65 |
+
padding = (kernel_size - 1) // 2
|
66 |
+
return padding
|
67 |
+
|
68 |
+
|
69 |
+
class ConcatBlock(nn.Module):
|
70 |
+
# Concat the output of a submodule to its input
|
71 |
+
def __init__(self, submodule):
|
72 |
+
super(ConcatBlock, self).__init__()
|
73 |
+
self.sub = submodule
|
74 |
+
|
75 |
+
def __call__(self, x):
|
76 |
+
output = torch.cat((x, self.sub(x)), dim=1)
|
77 |
+
return output
|
78 |
+
|
79 |
+
def __repr__(self):
|
80 |
+
tmpstr = "Identity .. \n|"
|
81 |
+
modstr = self.sub.__repr__().replace("\n", "\n|")
|
82 |
+
tmpstr = tmpstr + modstr
|
83 |
+
return tmpstr
|
84 |
+
|
85 |
+
|
86 |
+
class ShortcutBlock(nn.Module):
|
87 |
+
# Elementwise sum the output of a submodule to its input
|
88 |
+
def __init__(self, submodule):
|
89 |
+
super(ShortcutBlock, self).__init__()
|
90 |
+
self.sub = submodule
|
91 |
+
|
92 |
+
def __call__(self, x):
|
93 |
+
output = x + self.sub(x)
|
94 |
+
return output
|
95 |
+
|
96 |
+
def __repr__(self):
|
97 |
+
tmpstr = "Identity + \n|"
|
98 |
+
modstr = self.sub.__repr__().replace("\n", "\n|")
|
99 |
+
tmpstr = tmpstr + modstr
|
100 |
+
return tmpstr
|
101 |
+
|
102 |
+
|
103 |
+
class ShortcutBlockSPSR(nn.Module):
|
104 |
+
# Elementwise sum the output of a submodule to its input
|
105 |
+
def __init__(self, submodule):
|
106 |
+
super(ShortcutBlockSPSR, self).__init__()
|
107 |
+
self.sub = submodule
|
108 |
+
|
109 |
+
def __call__(self, x):
|
110 |
+
return x, self.sub
|
111 |
+
|
112 |
+
def __repr__(self):
|
113 |
+
tmpstr = "Identity + \n|"
|
114 |
+
modstr = self.sub.__repr__().replace("\n", "\n|")
|
115 |
+
tmpstr = tmpstr + modstr
|
116 |
+
return tmpstr
|
117 |
+
|
118 |
+
|
119 |
+
def sequential(*args):
|
120 |
+
# Flatten Sequential. It unwraps nn.Sequential.
|
121 |
+
if len(args) == 1:
|
122 |
+
if isinstance(args[0], OrderedDict):
|
123 |
+
raise NotImplementedError("sequential does not support OrderedDict input.")
|
124 |
+
return args[0] # No sequential is needed.
|
125 |
+
modules = []
|
126 |
+
for module in args:
|
127 |
+
if isinstance(module, nn.Sequential):
|
128 |
+
for submodule in module.children()["layers"]:
|
129 |
+
modules.append(submodule)
|
130 |
+
elif isinstance(module, nn.Module):
|
131 |
+
modules.append(module)
|
132 |
+
return nn.Sequential(*modules)
|
133 |
+
|
134 |
+
|
135 |
+
def conv_block(
|
136 |
+
in_nc,
|
137 |
+
out_nc,
|
138 |
+
kernel_size,
|
139 |
+
stride=1,
|
140 |
+
dilation=1,
|
141 |
+
groups=1,
|
142 |
+
bias=True,
|
143 |
+
pad_type="zero",
|
144 |
+
norm_type=None,
|
145 |
+
act_type="relu",
|
146 |
+
mode="CNA",
|
147 |
+
c2x2=False,
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
Conv layer with padding, normalization, activation
|
151 |
+
mode: CNA --> Conv -> Norm -> Act
|
152 |
+
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
153 |
+
"""
|
154 |
+
|
155 |
+
if c2x2:
|
156 |
+
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
157 |
+
|
158 |
+
assert mode in ["CNA", "NAC", "CNAC"], "Wrong conv mode [{:s}]".format(mode)
|
159 |
+
padding = get_valid_padding(kernel_size, dilation)
|
160 |
+
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
161 |
+
padding = padding if pad_type == "zero" else 0
|
162 |
+
|
163 |
+
c = nn.Conv2d(
|
164 |
+
in_nc,
|
165 |
+
out_nc,
|
166 |
+
kernel_size=kernel_size,
|
167 |
+
stride=stride,
|
168 |
+
padding=padding,
|
169 |
+
dilation=dilation,
|
170 |
+
bias=bias,
|
171 |
+
**({"groups": groups} if groups != 1 else {}),
|
172 |
+
)
|
173 |
+
a = act(act_type) if act_type else None
|
174 |
+
if "CNA" in mode:
|
175 |
+
n = norm(norm_type, out_nc) if norm_type else None
|
176 |
+
return sequential(p, c, n, a)
|
177 |
+
elif mode == "NAC":
|
178 |
+
if norm_type is None and act_type is not None:
|
179 |
+
a = act(act_type, inplace=False)
|
180 |
+
# Important!
|
181 |
+
# input----ReLU(inplace)----Conv--+----output
|
182 |
+
# |________________________|
|
183 |
+
# inplace ReLU will modify the input, therefore wrong output
|
184 |
+
n = norm(norm_type, in_nc) if norm_type else None
|
185 |
+
return sequential(n, a, p, c)
|
186 |
+
|
187 |
+
|
188 |
+
# 2x2x2 Conv Block
|
189 |
+
def conv_block_2c2(
|
190 |
+
in_nc,
|
191 |
+
out_nc,
|
192 |
+
act_type="relu",
|
193 |
+
):
|
194 |
+
return sequential(
|
195 |
+
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
196 |
+
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
197 |
+
act(act_type) if act_type else None,
|
198 |
+
)
|
199 |
+
|
200 |
+
|
201 |
+
####################
|
202 |
+
# Useful blocks
|
203 |
+
####################
|
204 |
+
|
205 |
+
|
206 |
+
class ResNetBlock(nn.Module):
|
207 |
+
"""
|
208 |
+
ResNet Block, 3-3 style
|
209 |
+
with extra residual scaling used in EDSR
|
210 |
+
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
|
211 |
+
"""
|
212 |
+
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
in_nc,
|
216 |
+
mid_nc,
|
217 |
+
out_nc,
|
218 |
+
kernel_size=3,
|
219 |
+
stride=1,
|
220 |
+
dilation=1,
|
221 |
+
groups=1,
|
222 |
+
bias=True,
|
223 |
+
pad_type="zero",
|
224 |
+
norm_type=None,
|
225 |
+
act_type="relu",
|
226 |
+
mode="CNA",
|
227 |
+
res_scale=1,
|
228 |
+
):
|
229 |
+
super(ResNetBlock, self).__init__()
|
230 |
+
conv0 = conv_block(
|
231 |
+
in_nc,
|
232 |
+
mid_nc,
|
233 |
+
kernel_size,
|
234 |
+
stride,
|
235 |
+
dilation,
|
236 |
+
groups,
|
237 |
+
bias,
|
238 |
+
pad_type,
|
239 |
+
norm_type,
|
240 |
+
act_type,
|
241 |
+
mode,
|
242 |
+
)
|
243 |
+
if mode == "CNA":
|
244 |
+
act_type = None
|
245 |
+
if mode == "CNAC": # Residual path: |-CNAC-|
|
246 |
+
act_type = None
|
247 |
+
norm_type = None
|
248 |
+
conv1 = conv_block(
|
249 |
+
mid_nc,
|
250 |
+
out_nc,
|
251 |
+
kernel_size,
|
252 |
+
stride,
|
253 |
+
dilation,
|
254 |
+
groups,
|
255 |
+
bias,
|
256 |
+
pad_type,
|
257 |
+
norm_type,
|
258 |
+
act_type,
|
259 |
+
mode,
|
260 |
+
)
|
261 |
+
# if in_nc != out_nc:
|
262 |
+
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
|
263 |
+
# None, None)
|
264 |
+
# print('Need a projecter in ResNetBlock.')
|
265 |
+
# else:
|
266 |
+
# self.project = lambda x:x
|
267 |
+
self.res = sequential(conv0, conv1)
|
268 |
+
self.res_scale = res_scale
|
269 |
+
|
270 |
+
def __call__(self, x):
|
271 |
+
res = self.res(x).mul(self.res_scale)
|
272 |
+
return x + res
|
273 |
+
|
274 |
+
|
275 |
+
class RRDB(nn.Module):
|
276 |
+
"""
|
277 |
+
Residual in Residual Dense Block
|
278 |
+
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
|
279 |
+
"""
|
280 |
+
|
281 |
+
def __init__(
|
282 |
+
self,
|
283 |
+
nf,
|
284 |
+
kernel_size=3,
|
285 |
+
gc=32,
|
286 |
+
stride=1,
|
287 |
+
bias=1,
|
288 |
+
pad_type="zero",
|
289 |
+
norm_type=None,
|
290 |
+
act_type="leakyrelu",
|
291 |
+
mode="CNA",
|
292 |
+
convtype="Conv2D",
|
293 |
+
spectral_norm=False,
|
294 |
+
plus=False,
|
295 |
+
c2x2=False,
|
296 |
+
):
|
297 |
+
super(RRDB, self).__init__()
|
298 |
+
self.RDB1 = ResidualDenseBlock_5C(
|
299 |
+
nf,
|
300 |
+
kernel_size,
|
301 |
+
gc,
|
302 |
+
stride,
|
303 |
+
bias,
|
304 |
+
pad_type,
|
305 |
+
norm_type,
|
306 |
+
act_type,
|
307 |
+
mode,
|
308 |
+
plus=plus,
|
309 |
+
c2x2=c2x2,
|
310 |
+
)
|
311 |
+
self.RDB2 = ResidualDenseBlock_5C(
|
312 |
+
nf,
|
313 |
+
kernel_size,
|
314 |
+
gc,
|
315 |
+
stride,
|
316 |
+
bias,
|
317 |
+
pad_type,
|
318 |
+
norm_type,
|
319 |
+
act_type,
|
320 |
+
mode,
|
321 |
+
plus=plus,
|
322 |
+
c2x2=c2x2,
|
323 |
+
)
|
324 |
+
self.RDB3 = ResidualDenseBlock_5C(
|
325 |
+
nf,
|
326 |
+
kernel_size,
|
327 |
+
gc,
|
328 |
+
stride,
|
329 |
+
bias,
|
330 |
+
pad_type,
|
331 |
+
norm_type,
|
332 |
+
act_type,
|
333 |
+
mode,
|
334 |
+
plus=plus,
|
335 |
+
c2x2=c2x2,
|
336 |
+
)
|
337 |
+
|
338 |
+
def __call__(self, x):
|
339 |
+
out = self.RDB1(x)
|
340 |
+
out = self.RDB2(out)
|
341 |
+
out = self.RDB3(out)
|
342 |
+
return out * 0.2 + x
|
343 |
+
|
344 |
+
|
345 |
+
class ResidualDenseBlock_5C(nn.Module):
|
346 |
+
"""
|
347 |
+
Residual Dense Block
|
348 |
+
style: 5 convs
|
349 |
+
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
|
350 |
+
Modified options that can be used:
|
351 |
+
- "Partial Convolution based Padding" arXiv:1811.11718
|
352 |
+
- "Spectral normalization" arXiv:1802.05957
|
353 |
+
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
354 |
+
{Rakotonirina} and A. {Rasoanaivo}
|
355 |
+
|
356 |
+
Args:
|
357 |
+
nf (int): Channel number of intermediate features (num_feat).
|
358 |
+
gc (int): Channels for each growth (num_grow_ch: growth channel,
|
359 |
+
i.e. intermediate channels).
|
360 |
+
convtype (str): the type of convolution to use. Default: 'Conv2D'
|
361 |
+
gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
|
362 |
+
trainable parameters)
|
363 |
+
plus (bool): enable the additional residual paths from ESRGAN+
|
364 |
+
(adds trainable parameters)
|
365 |
+
"""
|
366 |
+
|
367 |
+
def __init__(
|
368 |
+
self,
|
369 |
+
nf=64,
|
370 |
+
kernel_size=3,
|
371 |
+
gc=32,
|
372 |
+
stride=1,
|
373 |
+
bias=1,
|
374 |
+
pad_type="zero",
|
375 |
+
norm_type=None,
|
376 |
+
act_type="leakyrelu",
|
377 |
+
mode="CNA",
|
378 |
+
plus=False,
|
379 |
+
c2x2=False,
|
380 |
+
):
|
381 |
+
super(ResidualDenseBlock_5C, self).__init__()
|
382 |
+
|
383 |
+
## +
|
384 |
+
self.conv1x1 = conv1x1(nf, gc) if plus else None
|
385 |
+
## +
|
386 |
+
|
387 |
+
self.conv1 = conv_block(
|
388 |
+
nf,
|
389 |
+
gc,
|
390 |
+
kernel_size,
|
391 |
+
stride,
|
392 |
+
bias=bias,
|
393 |
+
pad_type=pad_type,
|
394 |
+
norm_type=norm_type,
|
395 |
+
act_type=act_type,
|
396 |
+
mode=mode,
|
397 |
+
c2x2=c2x2,
|
398 |
+
)
|
399 |
+
self.conv2 = conv_block(
|
400 |
+
nf + gc,
|
401 |
+
gc,
|
402 |
+
kernel_size,
|
403 |
+
stride,
|
404 |
+
bias=bias,
|
405 |
+
pad_type=pad_type,
|
406 |
+
norm_type=norm_type,
|
407 |
+
act_type=act_type,
|
408 |
+
mode=mode,
|
409 |
+
c2x2=c2x2,
|
410 |
+
)
|
411 |
+
self.conv3 = conv_block(
|
412 |
+
nf + 2 * gc,
|
413 |
+
gc,
|
414 |
+
kernel_size,
|
415 |
+
stride,
|
416 |
+
bias=bias,
|
417 |
+
pad_type=pad_type,
|
418 |
+
norm_type=norm_type,
|
419 |
+
act_type=act_type,
|
420 |
+
mode=mode,
|
421 |
+
c2x2=c2x2,
|
422 |
+
)
|
423 |
+
self.conv4 = conv_block(
|
424 |
+
nf + 3 * gc,
|
425 |
+
gc,
|
426 |
+
kernel_size,
|
427 |
+
stride,
|
428 |
+
bias=bias,
|
429 |
+
pad_type=pad_type,
|
430 |
+
norm_type=norm_type,
|
431 |
+
act_type=act_type,
|
432 |
+
mode=mode,
|
433 |
+
c2x2=c2x2,
|
434 |
+
)
|
435 |
+
if mode == "CNA":
|
436 |
+
last_act = None
|
437 |
+
else:
|
438 |
+
last_act = act_type
|
439 |
+
self.conv5 = conv_block(
|
440 |
+
nf + 4 * gc,
|
441 |
+
nf,
|
442 |
+
3,
|
443 |
+
stride,
|
444 |
+
bias=bias,
|
445 |
+
pad_type=pad_type,
|
446 |
+
norm_type=norm_type,
|
447 |
+
act_type=last_act,
|
448 |
+
mode=mode,
|
449 |
+
c2x2=c2x2,
|
450 |
+
)
|
451 |
+
|
452 |
+
def __call__(self, x):
|
453 |
+
x1 = self.conv1(x)
|
454 |
+
x2 = self.conv2(mx.concatenate((x, x1), axis=3))
|
455 |
+
if self.conv1x1:
|
456 |
+
x2 = x2 + self.conv1x1(x) # +
|
457 |
+
x3 = self.conv3(mx.concatenate((x, x1, x2), axis=3))
|
458 |
+
x4 = self.conv4(mx.concatenate((x, x1, x2, x3), axis=3))
|
459 |
+
if self.conv1x1:
|
460 |
+
x4 = x4 + x2 # +
|
461 |
+
x5 = self.conv5(mx.concatenate((x, x1, x2, x3, x4), axis=3))
|
462 |
+
return x5 * 0.2 + x
|
463 |
+
|
464 |
+
|
465 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
466 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
467 |
+
|
468 |
+
|
469 |
+
####################
|
470 |
+
# Upsampler
|
471 |
+
####################
|
472 |
+
|
473 |
+
|
474 |
+
def pixelshuffle_block(
|
475 |
+
in_nc,
|
476 |
+
out_nc,
|
477 |
+
upscale_factor=2,
|
478 |
+
kernel_size=3,
|
479 |
+
stride=1,
|
480 |
+
bias=True,
|
481 |
+
pad_type="zero",
|
482 |
+
norm_type=None,
|
483 |
+
act_type="relu",
|
484 |
+
):
|
485 |
+
"""
|
486 |
+
Pixel shuffle layer
|
487 |
+
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
|
488 |
+
Neural Network, CVPR17)
|
489 |
+
"""
|
490 |
+
conv = conv_block(
|
491 |
+
in_nc,
|
492 |
+
out_nc * (upscale_factor**2),
|
493 |
+
kernel_size,
|
494 |
+
stride,
|
495 |
+
bias=bias,
|
496 |
+
pad_type=pad_type,
|
497 |
+
norm_type=None,
|
498 |
+
act_type=None,
|
499 |
+
)
|
500 |
+
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
501 |
+
|
502 |
+
n = norm(norm_type, out_nc) if norm_type else None
|
503 |
+
a = act(act_type) if act_type else None
|
504 |
+
return sequential(conv, pixel_shuffle, n, a)
|
505 |
+
|
506 |
+
|
507 |
+
def upconv_block(
|
508 |
+
in_nc,
|
509 |
+
out_nc,
|
510 |
+
upscale_factor=2,
|
511 |
+
kernel_size=3,
|
512 |
+
stride=1,
|
513 |
+
bias=True,
|
514 |
+
pad_type="zero",
|
515 |
+
norm_type=None,
|
516 |
+
act_type="relu",
|
517 |
+
mode="nearest",
|
518 |
+
c2x2=False,
|
519 |
+
):
|
520 |
+
# Up conv
|
521 |
+
# described in https://distill.pub/2016/deconv-checkerboard/
|
522 |
+
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
|
523 |
+
conv = conv_block(
|
524 |
+
in_nc,
|
525 |
+
out_nc,
|
526 |
+
kernel_size,
|
527 |
+
stride,
|
528 |
+
bias=bias,
|
529 |
+
pad_type=pad_type,
|
530 |
+
norm_type=norm_type,
|
531 |
+
act_type=act_type,
|
532 |
+
c2x2=c2x2,
|
533 |
+
)
|
534 |
+
return sequential(upsample, conv)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mlx==0.20.0
|
2 |
+
numpy==2.1.3
|
3 |
+
pillow==11.0.0
|
4 |
+
tqdm==4.67.0
|
upscale.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, re, argparse, threading
|
2 |
+
import mlx.core as mx
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image, PngImagePlugin
|
5 |
+
from tqdm import tqdm
|
6 |
+
from ESRGAN import ESRGAN
|
7 |
+
|
8 |
+
def parse_args():
|
9 |
+
parser = argparse.ArgumentParser(description="Process tile size, padding, and file paths.")
|
10 |
+
parser.add_argument('--model', metavar='file_path', type=str, default='4x_NMKD-YandereNeoXL_200k.safetensors', help='Path to the model file')
|
11 |
+
parser.add_argument('--tile_size', metavar='256', type=int, default=256, help='Size of each tile (default: 256)')
|
12 |
+
parser.add_argument('--tile_pad', metavar='10', type=int, default=10, help='Padding around each tile (default: 10)')
|
13 |
+
parser.add_argument('files', metavar='in_file_path', type=str, nargs='+', help='List of file paths to process')
|
14 |
+
return parser.parse_args()
|
15 |
+
|
16 |
+
def load_model(model_path):
|
17 |
+
model = ESRGAN(mx.load(model_path))
|
18 |
+
return mx.compile(model), model.scale
|
19 |
+
|
20 |
+
def upscale_img(args, model, file_path, scale=4.0):
|
21 |
+
ts, tp, s = (args.tile_size, args.tile_pad, scale)
|
22 |
+
|
23 |
+
img_in = Image.open(file_path)
|
24 |
+
|
25 |
+
png_info = PngImagePlugin.PngInfo()
|
26 |
+
for k, v in (getattr(img_in, "text", None) or {}).items():
|
27 |
+
png_info.add_text(k, v)
|
28 |
+
img_save_argv = {
|
29 |
+
"icc_profile": img_in.info.get('icc_profile'),
|
30 |
+
"pnginfo": png_info,
|
31 |
+
}
|
32 |
+
|
33 |
+
img_in = mx.array(np.array(img_in.convert("RGB"), dtype=np.float32))[None] / 255.0
|
34 |
+
_, H, W, C = img_in.shape
|
35 |
+
mx.eval(img_in)
|
36 |
+
|
37 |
+
img_out = mx.zeros((1, H*s, W*s, C), dtype=mx.uint8)
|
38 |
+
mx.eval(img_out)
|
39 |
+
|
40 |
+
for hi, wj in tqdm([(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]):
|
41 |
+
phs = min(hi, tp)
|
42 |
+
pws = min(wj, tp)
|
43 |
+
img_out[:, hi*4:(hi+ts)*s, wj*4:(wj+ts)*s, :] = (
|
44 |
+
model(img_in[:, max(0,hi-tp):hi+ts+tp, max(0, wj-tp):wj+ts+tp, :])[:, phs*s:(ts+phs)*s, pws*s:(ts+pws)*s, :] * 255.0
|
45 |
+
).astype(mx.uint8)
|
46 |
+
mx.eval(img_out)
|
47 |
+
|
48 |
+
img_out = np.array(img_out[0], copy=False)
|
49 |
+
img_out = Image.fromarray(img_out)
|
50 |
+
img_out.save(re.sub(r'(\.\w+)$', r'_4x.png', file_path), **img_save_argv)
|
51 |
+
|
52 |
+
def main():
|
53 |
+
print("\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m")
|
54 |
+
mx.metal.set_cache_limit(0)
|
55 |
+
args = parse_args()
|
56 |
+
model, scale = load_model(args.model)
|
57 |
+
for file_path in args.files:
|
58 |
+
print(f"Upscaling {file_path}")
|
59 |
+
upscale_img(args, model, file_path, scale=scale)
|
60 |
+
|
61 |
+
if __name__ == "__main__":
|
62 |
+
th = threading.Thread(target=main)
|
63 |
+
th.start()
|
64 |
+
th.join()
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
|