kaeru-shigure commited on
Commit
53d4b00
·
verified ·
1 Parent(s): f0c1d90

Upload 7 files

Browse files
Files changed (7) hide show
  1. .gitignore +145 -0
  2. 4x_NMKD-YandereNeoXL_200k.safetensors +3 -0
  3. ESRGAN.py +264 -0
  4. README.md +4 -3
  5. blocks.py +534 -0
  6. requirements.txt +4 -0
  7. upscale.py +71 -0
.gitignore ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ # pytype static type analyzer
135
+ .pytype/
136
+
137
+ # Cython debug symbols
138
+ cython_debug/
139
+
140
+
141
+ # Custom
142
+ *.pth
143
+ input/**/*.*
144
+ output/**/*.*
145
+ .vscode/
4x_NMKD-YandereNeoXL_200k.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f61458c2637415947ee7baf05cdf529d54bb8cd3e36ba47393a3f48a2d1f3d59
3
+ size 66864028
ESRGAN.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools, math, re
2
+ from collections import OrderedDict
3
+ import mlx.core as mx
4
+ import mlx.nn as nn
5
+ import numpy as np
6
+ import blocks as B
7
+ from mlx.utils import tree_flatten
8
+
9
+ def conv_state_pair_to_mlx(kv):
10
+ k, v = kv
11
+ if v.ndim == 4:
12
+ v = v.transpose(0, 2, 3, 1)
13
+ v = v.reshape(-1).reshape(v.shape)
14
+ return re.sub(r'(\.\d+\.)', r'.layers\1', k), v
15
+
16
+ # Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
17
+ # Which enhanced stuff that was already here
18
+ class ESRGAN(nn.Module):
19
+ def __init__(
20
+ self,
21
+ state_dict,
22
+ norm=None,
23
+ act: str = "leakyrelu",
24
+ upsampler: str = "upconv",
25
+ mode: str = "CNA",
26
+ ) -> None:
27
+ """
28
+ ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
29
+ By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
30
+ and Chen Change Loy.
31
+ This is old-arch Residual in Residual Dense Block Network and is not
32
+ the newest revision that's available at github.com/xinntao/ESRGAN.
33
+ This is on purpose, the newest Network has severely limited the
34
+ potential use of the Network with no benefits.
35
+ This network supports model files from both new and old-arch.
36
+ Args:
37
+ norm: Normalization layer
38
+ act: Activation layer
39
+ upsampler: Upsample layer. upconv, pixel_shuffle
40
+ mode: Convolution mode
41
+ """
42
+ super().__init__()
43
+
44
+
45
+ self._raw_state = state_dict
46
+ self.norm = norm
47
+ self.act = act
48
+ self.upsampler = upsampler
49
+ self.mode = mode
50
+
51
+ self.state_map = {
52
+ # currently supports old, new, and newer RRDBNet arch models
53
+ # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
54
+ "model.0.weight": ("conv_first.weight",),
55
+ "model.0.bias": ("conv_first.bias",),
56
+ "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
57
+ "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
58
+ "model.3.weight": ("upconv1.weight", "conv_up1.weight"),
59
+ "model.3.bias": ("upconv1.bias", "conv_up1.bias"),
60
+ "model.6.weight": ("upconv2.weight", "conv_up2.weight"),
61
+ "model.6.bias": ("upconv2.bias", "conv_up2.bias"),
62
+ "model.8.weight": ("HRconv.weight", "conv_hr.weight"),
63
+ "model.8.bias": ("HRconv.bias", "conv_hr.bias"),
64
+ "model.10.weight": ("conv_last.weight",),
65
+ "model.10.bias": ("conv_last.bias",),
66
+ r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
67
+ r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
68
+ r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
69
+ ),
70
+ }
71
+ if "params_ema" in self._raw_state:
72
+ self._raw_state = self._raw_state["params_ema"]
73
+ self.num_blocks = self.get_num_blocks()
74
+
75
+ self.plus = any("conv1x1" in k for k in self._raw_state.keys())
76
+
77
+ self._raw_state = self.new_to_old_arch(self._raw_state)
78
+
79
+ self.key_arr = sorted(list(self._raw_state.keys()), key=lambda x: [1 if v == "bias" else 0 if v == "weight" else int(v) if re.match(r'^\d+$', v) else v for v in re.findall(r'[^.]+', x)])
80
+ # print(self.key_arr)
81
+
82
+ self.in_nc = self._raw_state[self.key_arr[0]].shape[1]
83
+ self.out_nc = self._raw_state[self.key_arr[-1]].shape[0]
84
+
85
+ self.scale = self.get_scale()
86
+
87
+ self.num_filters = self._raw_state[self.key_arr[0]].shape[0]
88
+
89
+ c2x2 = False
90
+ if self._raw_state["model.0.weight"].shape[-3] == 2:
91
+ c2x2 = True
92
+ self.scale = math.ceil(self.scale ** (1.0 / 3))
93
+
94
+ # Detect if pixelunshuffle was used (Real-ESRGAN)
95
+ if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
96
+ self.in_nc / 4,
97
+ self.in_nc / 16,
98
+ ):
99
+ self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
100
+ else:
101
+ self.shuffle_factor = None
102
+
103
+ upsample_block = {
104
+ "upconv": B.upconv_block,
105
+ "pixel_shuffle": B.pixelshuffle_block,
106
+ }.get(self.upsampler)
107
+ if upsample_block is None:
108
+ raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
109
+
110
+ if self.scale == 3:
111
+ upsample_blocks = upsample_block(
112
+ in_nc=self.num_filters,
113
+ out_nc=self.num_filters,
114
+ upscale_factor=3,
115
+ act_type=self.act,
116
+ c2x2=c2x2,
117
+ )
118
+ else:
119
+ upsample_blocks = [
120
+ upsample_block(
121
+ in_nc=self.num_filters,
122
+ out_nc=self.num_filters,
123
+ act_type=self.act,
124
+ c2x2=c2x2,
125
+ )
126
+ for _ in range(int(math.log(self.scale, 2)))
127
+ ]
128
+
129
+ self.model = B.sequential(
130
+ # fea conv
131
+ B.conv_block(
132
+ in_nc=self.in_nc,
133
+ out_nc=self.num_filters,
134
+ kernel_size=3,
135
+ norm_type=None,
136
+ act_type=None,
137
+ c2x2=c2x2,
138
+ ),
139
+ B.ShortcutBlock(
140
+ B.sequential(
141
+ # rrdb blocks
142
+ *[
143
+ B.RRDB(
144
+ nf=self.num_filters,
145
+ kernel_size=3,
146
+ gc=32,
147
+ stride=1,
148
+ bias=True,
149
+ pad_type="zero",
150
+ norm_type=self.norm,
151
+ act_type=self.act,
152
+ mode="CNA",
153
+ plus=self.plus,
154
+ c2x2=c2x2,
155
+ )
156
+ for _ in range(self.num_blocks)
157
+ ],
158
+ # lr conv
159
+ B.conv_block(
160
+ in_nc=self.num_filters,
161
+ out_nc=self.num_filters,
162
+ kernel_size=3,
163
+ norm_type=self.norm,
164
+ act_type=None,
165
+ mode=self.mode,
166
+ c2x2=c2x2,
167
+ ),
168
+ )
169
+ ),
170
+ *upsample_blocks,
171
+ # hr_conv0
172
+ B.conv_block(
173
+ in_nc=self.num_filters,
174
+ out_nc=self.num_filters,
175
+ kernel_size=3,
176
+ norm_type=None,
177
+ act_type=self.act,
178
+ c2x2=c2x2,
179
+ ),
180
+ # hr_conv1
181
+ B.conv_block(
182
+ in_nc=self.num_filters,
183
+ out_nc=self.out_nc,
184
+ kernel_size=3,
185
+ norm_type=None,
186
+ act_type=None,
187
+ c2x2=c2x2,
188
+ ),
189
+ )
190
+
191
+ self.load_weights(list(conv_state_pair_to_mlx(p) for p in self._raw_state.items()), strict=True)
192
+
193
+
194
+ def new_to_old_arch(self, state):
195
+ """Convert a new-arch model state dictionary to an old-arch dictionary."""
196
+ if "params_ema" in state:
197
+ state = state["params_ema"]
198
+
199
+ if "conv_first.weight" not in state:
200
+ # model is already old arch, this is a loose check, but should be sufficient
201
+ return state
202
+
203
+ # add nb to state keys
204
+ for kind in ("weight", "bias"):
205
+ self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
206
+ f"model.1.sub./NB/.{kind}"
207
+ ]
208
+ del self.state_map[f"model.1.sub./NB/.{kind}"]
209
+
210
+ old_state = OrderedDict()
211
+ for old_key, new_keys in self.state_map.items():
212
+ for new_key in new_keys:
213
+ if r"\1" in old_key:
214
+ for k, v in state.items():
215
+ sub = re.sub(new_key, old_key, k)
216
+ if sub != k:
217
+ old_state[sub] = v
218
+ else:
219
+ if new_key in state:
220
+ old_state[old_key] = state[new_key]
221
+
222
+ # Sort by first numeric value of each layer
223
+ def compare(item1, item2):
224
+ parts1 = item1.split(".")
225
+ parts2 = item2.split(".")
226
+ int1 = int(parts1[1])
227
+ int2 = int(parts2[1])
228
+ return int1 - int2
229
+
230
+ sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
231
+
232
+ # Rebuild the output dict in the right order
233
+ out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
234
+
235
+ return out_dict
236
+
237
+ def get_scale(self, min_part: int = 6) -> int:
238
+ n = 0
239
+ for part in list(self._raw_state):
240
+ parts = part.split(".")[1:]
241
+ if len(parts) == 2:
242
+ part_num = int(parts[0])
243
+ if part_num > min_part and parts[1] == "weight":
244
+ n += 1
245
+ return 2**n
246
+
247
+ def get_num_blocks(self) -> int:
248
+ nbs = []
249
+ state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
250
+ r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
251
+ )
252
+ for state_key in state_keys:
253
+ for k in self._raw_state:
254
+ m = re.search(state_key, k)
255
+ if m:
256
+ nbs.append(int(m.group(1)))
257
+ if nbs:
258
+ break
259
+ return max(*nbs) + 1
260
+
261
+ def __call__(self, x):
262
+ if self.shuffle_factor:
263
+ x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
264
+ return self.model(x)
README.md CHANGED
@@ -1,3 +1,4 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
1
+ # kaeru tiny mlx upscaler
2
+
3
+ A simple upscale script for `4x_NMKD-YandereNeoXL_200k`
4
+ Based on [joeyballentine/ESRGAN](https://github.com/JoeyBallentine/ESRGAN)
blocks.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from collections import OrderedDict
5
+
6
+ import mlx.core as mx
7
+ import mlx.nn as nn
8
+
9
+ ####################
10
+ # Basic blocks
11
+ ####################
12
+
13
+
14
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
15
+ # helper selecting activation
16
+ # neg_slope: for leakyrelu and init of prelu
17
+ # n_prelu: for p_relu num_parameters
18
+ act_type = act_type.lower()
19
+ if act_type == "relu":
20
+ layer = nn.ReLU()#inplace)
21
+ elif act_type == "leakyrelu":
22
+ layer = nn.LeakyReLU(neg_slope)#, inplace)
23
+ elif act_type == "prelu":
24
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
25
+ else:
26
+ raise NotImplementedError(
27
+ "activation layer [{:s}] is not found".format(act_type)
28
+ )
29
+ return layer
30
+
31
+
32
+ def norm(norm_type, nc):
33
+ # helper selecting normalization layer
34
+ norm_type = norm_type.lower()
35
+ if norm_type == "batch":
36
+ layer = nn.BatchNorm2d(nc, affine=True)
37
+ elif norm_type == "instance":
38
+ layer = nn.InstanceNorm2d(nc, affine=False)
39
+ else:
40
+ raise NotImplementedError(
41
+ "normalization layer [{:s}] is not found".format(norm_type)
42
+ )
43
+ return layer
44
+
45
+
46
+ def pad(pad_type, padding):
47
+ # helper selecting padding layer
48
+ # if padding is 'zero', do by conv layers
49
+ pad_type = pad_type.lower()
50
+ if padding == 0:
51
+ return None
52
+ if pad_type == "reflect":
53
+ layer = nn.ReflectionPad2d(padding)
54
+ elif pad_type == "replicate":
55
+ layer = nn.ReplicationPad2d(padding)
56
+ else:
57
+ raise NotImplementedError(
58
+ "padding layer [{:s}] is not implemented".format(pad_type)
59
+ )
60
+ return layer
61
+
62
+
63
+ def get_valid_padding(kernel_size, dilation):
64
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
65
+ padding = (kernel_size - 1) // 2
66
+ return padding
67
+
68
+
69
+ class ConcatBlock(nn.Module):
70
+ # Concat the output of a submodule to its input
71
+ def __init__(self, submodule):
72
+ super(ConcatBlock, self).__init__()
73
+ self.sub = submodule
74
+
75
+ def __call__(self, x):
76
+ output = torch.cat((x, self.sub(x)), dim=1)
77
+ return output
78
+
79
+ def __repr__(self):
80
+ tmpstr = "Identity .. \n|"
81
+ modstr = self.sub.__repr__().replace("\n", "\n|")
82
+ tmpstr = tmpstr + modstr
83
+ return tmpstr
84
+
85
+
86
+ class ShortcutBlock(nn.Module):
87
+ # Elementwise sum the output of a submodule to its input
88
+ def __init__(self, submodule):
89
+ super(ShortcutBlock, self).__init__()
90
+ self.sub = submodule
91
+
92
+ def __call__(self, x):
93
+ output = x + self.sub(x)
94
+ return output
95
+
96
+ def __repr__(self):
97
+ tmpstr = "Identity + \n|"
98
+ modstr = self.sub.__repr__().replace("\n", "\n|")
99
+ tmpstr = tmpstr + modstr
100
+ return tmpstr
101
+
102
+
103
+ class ShortcutBlockSPSR(nn.Module):
104
+ # Elementwise sum the output of a submodule to its input
105
+ def __init__(self, submodule):
106
+ super(ShortcutBlockSPSR, self).__init__()
107
+ self.sub = submodule
108
+
109
+ def __call__(self, x):
110
+ return x, self.sub
111
+
112
+ def __repr__(self):
113
+ tmpstr = "Identity + \n|"
114
+ modstr = self.sub.__repr__().replace("\n", "\n|")
115
+ tmpstr = tmpstr + modstr
116
+ return tmpstr
117
+
118
+
119
+ def sequential(*args):
120
+ # Flatten Sequential. It unwraps nn.Sequential.
121
+ if len(args) == 1:
122
+ if isinstance(args[0], OrderedDict):
123
+ raise NotImplementedError("sequential does not support OrderedDict input.")
124
+ return args[0] # No sequential is needed.
125
+ modules = []
126
+ for module in args:
127
+ if isinstance(module, nn.Sequential):
128
+ for submodule in module.children()["layers"]:
129
+ modules.append(submodule)
130
+ elif isinstance(module, nn.Module):
131
+ modules.append(module)
132
+ return nn.Sequential(*modules)
133
+
134
+
135
+ def conv_block(
136
+ in_nc,
137
+ out_nc,
138
+ kernel_size,
139
+ stride=1,
140
+ dilation=1,
141
+ groups=1,
142
+ bias=True,
143
+ pad_type="zero",
144
+ norm_type=None,
145
+ act_type="relu",
146
+ mode="CNA",
147
+ c2x2=False,
148
+ ):
149
+ """
150
+ Conv layer with padding, normalization, activation
151
+ mode: CNA --> Conv -> Norm -> Act
152
+ NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
153
+ """
154
+
155
+ if c2x2:
156
+ return conv_block_2c2(in_nc, out_nc, act_type=act_type)
157
+
158
+ assert mode in ["CNA", "NAC", "CNAC"], "Wrong conv mode [{:s}]".format(mode)
159
+ padding = get_valid_padding(kernel_size, dilation)
160
+ p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
161
+ padding = padding if pad_type == "zero" else 0
162
+
163
+ c = nn.Conv2d(
164
+ in_nc,
165
+ out_nc,
166
+ kernel_size=kernel_size,
167
+ stride=stride,
168
+ padding=padding,
169
+ dilation=dilation,
170
+ bias=bias,
171
+ **({"groups": groups} if groups != 1 else {}),
172
+ )
173
+ a = act(act_type) if act_type else None
174
+ if "CNA" in mode:
175
+ n = norm(norm_type, out_nc) if norm_type else None
176
+ return sequential(p, c, n, a)
177
+ elif mode == "NAC":
178
+ if norm_type is None and act_type is not None:
179
+ a = act(act_type, inplace=False)
180
+ # Important!
181
+ # input----ReLU(inplace)----Conv--+----output
182
+ # |________________________|
183
+ # inplace ReLU will modify the input, therefore wrong output
184
+ n = norm(norm_type, in_nc) if norm_type else None
185
+ return sequential(n, a, p, c)
186
+
187
+
188
+ # 2x2x2 Conv Block
189
+ def conv_block_2c2(
190
+ in_nc,
191
+ out_nc,
192
+ act_type="relu",
193
+ ):
194
+ return sequential(
195
+ nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
196
+ nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
197
+ act(act_type) if act_type else None,
198
+ )
199
+
200
+
201
+ ####################
202
+ # Useful blocks
203
+ ####################
204
+
205
+
206
+ class ResNetBlock(nn.Module):
207
+ """
208
+ ResNet Block, 3-3 style
209
+ with extra residual scaling used in EDSR
210
+ (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ in_nc,
216
+ mid_nc,
217
+ out_nc,
218
+ kernel_size=3,
219
+ stride=1,
220
+ dilation=1,
221
+ groups=1,
222
+ bias=True,
223
+ pad_type="zero",
224
+ norm_type=None,
225
+ act_type="relu",
226
+ mode="CNA",
227
+ res_scale=1,
228
+ ):
229
+ super(ResNetBlock, self).__init__()
230
+ conv0 = conv_block(
231
+ in_nc,
232
+ mid_nc,
233
+ kernel_size,
234
+ stride,
235
+ dilation,
236
+ groups,
237
+ bias,
238
+ pad_type,
239
+ norm_type,
240
+ act_type,
241
+ mode,
242
+ )
243
+ if mode == "CNA":
244
+ act_type = None
245
+ if mode == "CNAC": # Residual path: |-CNAC-|
246
+ act_type = None
247
+ norm_type = None
248
+ conv1 = conv_block(
249
+ mid_nc,
250
+ out_nc,
251
+ kernel_size,
252
+ stride,
253
+ dilation,
254
+ groups,
255
+ bias,
256
+ pad_type,
257
+ norm_type,
258
+ act_type,
259
+ mode,
260
+ )
261
+ # if in_nc != out_nc:
262
+ # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
263
+ # None, None)
264
+ # print('Need a projecter in ResNetBlock.')
265
+ # else:
266
+ # self.project = lambda x:x
267
+ self.res = sequential(conv0, conv1)
268
+ self.res_scale = res_scale
269
+
270
+ def __call__(self, x):
271
+ res = self.res(x).mul(self.res_scale)
272
+ return x + res
273
+
274
+
275
+ class RRDB(nn.Module):
276
+ """
277
+ Residual in Residual Dense Block
278
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
279
+ """
280
+
281
+ def __init__(
282
+ self,
283
+ nf,
284
+ kernel_size=3,
285
+ gc=32,
286
+ stride=1,
287
+ bias=1,
288
+ pad_type="zero",
289
+ norm_type=None,
290
+ act_type="leakyrelu",
291
+ mode="CNA",
292
+ convtype="Conv2D",
293
+ spectral_norm=False,
294
+ plus=False,
295
+ c2x2=False,
296
+ ):
297
+ super(RRDB, self).__init__()
298
+ self.RDB1 = ResidualDenseBlock_5C(
299
+ nf,
300
+ kernel_size,
301
+ gc,
302
+ stride,
303
+ bias,
304
+ pad_type,
305
+ norm_type,
306
+ act_type,
307
+ mode,
308
+ plus=plus,
309
+ c2x2=c2x2,
310
+ )
311
+ self.RDB2 = ResidualDenseBlock_5C(
312
+ nf,
313
+ kernel_size,
314
+ gc,
315
+ stride,
316
+ bias,
317
+ pad_type,
318
+ norm_type,
319
+ act_type,
320
+ mode,
321
+ plus=plus,
322
+ c2x2=c2x2,
323
+ )
324
+ self.RDB3 = ResidualDenseBlock_5C(
325
+ nf,
326
+ kernel_size,
327
+ gc,
328
+ stride,
329
+ bias,
330
+ pad_type,
331
+ norm_type,
332
+ act_type,
333
+ mode,
334
+ plus=plus,
335
+ c2x2=c2x2,
336
+ )
337
+
338
+ def __call__(self, x):
339
+ out = self.RDB1(x)
340
+ out = self.RDB2(out)
341
+ out = self.RDB3(out)
342
+ return out * 0.2 + x
343
+
344
+
345
+ class ResidualDenseBlock_5C(nn.Module):
346
+ """
347
+ Residual Dense Block
348
+ style: 5 convs
349
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
350
+ Modified options that can be used:
351
+ - "Partial Convolution based Padding" arXiv:1811.11718
352
+ - "Spectral normalization" arXiv:1802.05957
353
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
354
+ {Rakotonirina} and A. {Rasoanaivo}
355
+
356
+ Args:
357
+ nf (int): Channel number of intermediate features (num_feat).
358
+ gc (int): Channels for each growth (num_grow_ch: growth channel,
359
+ i.e. intermediate channels).
360
+ convtype (str): the type of convolution to use. Default: 'Conv2D'
361
+ gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
362
+ trainable parameters)
363
+ plus (bool): enable the additional residual paths from ESRGAN+
364
+ (adds trainable parameters)
365
+ """
366
+
367
+ def __init__(
368
+ self,
369
+ nf=64,
370
+ kernel_size=3,
371
+ gc=32,
372
+ stride=1,
373
+ bias=1,
374
+ pad_type="zero",
375
+ norm_type=None,
376
+ act_type="leakyrelu",
377
+ mode="CNA",
378
+ plus=False,
379
+ c2x2=False,
380
+ ):
381
+ super(ResidualDenseBlock_5C, self).__init__()
382
+
383
+ ## +
384
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
385
+ ## +
386
+
387
+ self.conv1 = conv_block(
388
+ nf,
389
+ gc,
390
+ kernel_size,
391
+ stride,
392
+ bias=bias,
393
+ pad_type=pad_type,
394
+ norm_type=norm_type,
395
+ act_type=act_type,
396
+ mode=mode,
397
+ c2x2=c2x2,
398
+ )
399
+ self.conv2 = conv_block(
400
+ nf + gc,
401
+ gc,
402
+ kernel_size,
403
+ stride,
404
+ bias=bias,
405
+ pad_type=pad_type,
406
+ norm_type=norm_type,
407
+ act_type=act_type,
408
+ mode=mode,
409
+ c2x2=c2x2,
410
+ )
411
+ self.conv3 = conv_block(
412
+ nf + 2 * gc,
413
+ gc,
414
+ kernel_size,
415
+ stride,
416
+ bias=bias,
417
+ pad_type=pad_type,
418
+ norm_type=norm_type,
419
+ act_type=act_type,
420
+ mode=mode,
421
+ c2x2=c2x2,
422
+ )
423
+ self.conv4 = conv_block(
424
+ nf + 3 * gc,
425
+ gc,
426
+ kernel_size,
427
+ stride,
428
+ bias=bias,
429
+ pad_type=pad_type,
430
+ norm_type=norm_type,
431
+ act_type=act_type,
432
+ mode=mode,
433
+ c2x2=c2x2,
434
+ )
435
+ if mode == "CNA":
436
+ last_act = None
437
+ else:
438
+ last_act = act_type
439
+ self.conv5 = conv_block(
440
+ nf + 4 * gc,
441
+ nf,
442
+ 3,
443
+ stride,
444
+ bias=bias,
445
+ pad_type=pad_type,
446
+ norm_type=norm_type,
447
+ act_type=last_act,
448
+ mode=mode,
449
+ c2x2=c2x2,
450
+ )
451
+
452
+ def __call__(self, x):
453
+ x1 = self.conv1(x)
454
+ x2 = self.conv2(mx.concatenate((x, x1), axis=3))
455
+ if self.conv1x1:
456
+ x2 = x2 + self.conv1x1(x) # +
457
+ x3 = self.conv3(mx.concatenate((x, x1, x2), axis=3))
458
+ x4 = self.conv4(mx.concatenate((x, x1, x2, x3), axis=3))
459
+ if self.conv1x1:
460
+ x4 = x4 + x2 # +
461
+ x5 = self.conv5(mx.concatenate((x, x1, x2, x3, x4), axis=3))
462
+ return x5 * 0.2 + x
463
+
464
+
465
+ def conv1x1(in_planes, out_planes, stride=1):
466
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
467
+
468
+
469
+ ####################
470
+ # Upsampler
471
+ ####################
472
+
473
+
474
+ def pixelshuffle_block(
475
+ in_nc,
476
+ out_nc,
477
+ upscale_factor=2,
478
+ kernel_size=3,
479
+ stride=1,
480
+ bias=True,
481
+ pad_type="zero",
482
+ norm_type=None,
483
+ act_type="relu",
484
+ ):
485
+ """
486
+ Pixel shuffle layer
487
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
488
+ Neural Network, CVPR17)
489
+ """
490
+ conv = conv_block(
491
+ in_nc,
492
+ out_nc * (upscale_factor**2),
493
+ kernel_size,
494
+ stride,
495
+ bias=bias,
496
+ pad_type=pad_type,
497
+ norm_type=None,
498
+ act_type=None,
499
+ )
500
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
501
+
502
+ n = norm(norm_type, out_nc) if norm_type else None
503
+ a = act(act_type) if act_type else None
504
+ return sequential(conv, pixel_shuffle, n, a)
505
+
506
+
507
+ def upconv_block(
508
+ in_nc,
509
+ out_nc,
510
+ upscale_factor=2,
511
+ kernel_size=3,
512
+ stride=1,
513
+ bias=True,
514
+ pad_type="zero",
515
+ norm_type=None,
516
+ act_type="relu",
517
+ mode="nearest",
518
+ c2x2=False,
519
+ ):
520
+ # Up conv
521
+ # described in https://distill.pub/2016/deconv-checkerboard/
522
+ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
523
+ conv = conv_block(
524
+ in_nc,
525
+ out_nc,
526
+ kernel_size,
527
+ stride,
528
+ bias=bias,
529
+ pad_type=pad_type,
530
+ norm_type=norm_type,
531
+ act_type=act_type,
532
+ c2x2=c2x2,
533
+ )
534
+ return sequential(upsample, conv)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ mlx==0.20.0
2
+ numpy==2.1.3
3
+ pillow==11.0.0
4
+ tqdm==4.67.0
upscale.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, re, argparse, threading
2
+ import mlx.core as mx
3
+ import numpy as np
4
+ from PIL import Image, PngImagePlugin
5
+ from tqdm import tqdm
6
+ from ESRGAN import ESRGAN
7
+
8
+ def parse_args():
9
+ parser = argparse.ArgumentParser(description="Process tile size, padding, and file paths.")
10
+ parser.add_argument('--model', metavar='file_path', type=str, default='4x_NMKD-YandereNeoXL_200k.safetensors', help='Path to the model file')
11
+ parser.add_argument('--tile_size', metavar='256', type=int, default=256, help='Size of each tile (default: 256)')
12
+ parser.add_argument('--tile_pad', metavar='10', type=int, default=10, help='Padding around each tile (default: 10)')
13
+ parser.add_argument('files', metavar='in_file_path', type=str, nargs='+', help='List of file paths to process')
14
+ return parser.parse_args()
15
+
16
+ def load_model(model_path):
17
+ model = ESRGAN(mx.load(model_path))
18
+ return mx.compile(model), model.scale
19
+
20
+ def upscale_img(args, model, file_path, scale=4.0):
21
+ ts, tp, s = (args.tile_size, args.tile_pad, scale)
22
+
23
+ img_in = Image.open(file_path)
24
+
25
+ png_info = PngImagePlugin.PngInfo()
26
+ for k, v in (getattr(img_in, "text", None) or {}).items():
27
+ png_info.add_text(k, v)
28
+ img_save_argv = {
29
+ "icc_profile": img_in.info.get('icc_profile'),
30
+ "pnginfo": png_info,
31
+ }
32
+
33
+ img_in = mx.array(np.array(img_in.convert("RGB"), dtype=np.float32))[None] / 255.0
34
+ _, H, W, C = img_in.shape
35
+ mx.eval(img_in)
36
+
37
+ img_out = mx.zeros((1, H*s, W*s, C), dtype=mx.uint8)
38
+ mx.eval(img_out)
39
+
40
+ for hi, wj in tqdm([(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]):
41
+ phs = min(hi, tp)
42
+ pws = min(wj, tp)
43
+ img_out[:, hi*4:(hi+ts)*s, wj*4:(wj+ts)*s, :] = (
44
+ model(img_in[:, max(0,hi-tp):hi+ts+tp, max(0, wj-tp):wj+ts+tp, :])[:, phs*s:(ts+phs)*s, pws*s:(ts+pws)*s, :] * 255.0
45
+ ).astype(mx.uint8)
46
+ mx.eval(img_out)
47
+
48
+ img_out = np.array(img_out[0], copy=False)
49
+ img_out = Image.fromarray(img_out)
50
+ img_out.save(re.sub(r'(\.\w+)$', r'_4x.png', file_path), **img_save_argv)
51
+
52
+ def main():
53
+ print("\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m")
54
+ mx.metal.set_cache_limit(0)
55
+ args = parse_args()
56
+ model, scale = load_model(args.model)
57
+ for file_path in args.files:
58
+ print(f"Upscaling {file_path}")
59
+ upscale_img(args, model, file_path, scale=scale)
60
+
61
+ if __name__ == "__main__":
62
+ th = threading.Thread(target=main)
63
+ th.start()
64
+ th.join()
65
+
66
+
67
+
68
+
69
+
70
+
71
+