|
import os, re, argparse, threading |
|
import mlx.core as mx |
|
import numpy as np |
|
from PIL import Image, PngImagePlugin |
|
from tqdm import tqdm |
|
from ESRGAN import ESRGAN |
|
|
|
def parse_args(argv=None):
    """Parse the upscaler's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, exactly as with ``ArgumentParser.parse_args``.

    Returns:
        argparse.Namespace with ``model``, ``tile_size``, ``tile_pad`` and
        ``files`` attributes.

    Raises:
        SystemExit: via ``parser.error`` (exit status 2) when ``--tile_size``
            is not positive or ``--tile_pad`` is negative. A zero tile size
            would make the tiling ``range()`` step zero, and a negative pad
            would produce wrong crop offsets.
    """
    parser = argparse.ArgumentParser(description="Process tile size, padding, and file paths.")
    parser.add_argument('--model', metavar='file_path', type=str, default='4x_NMKD-YandereNeoXL_200k.safetensors', help='Path to the model file')
    parser.add_argument('--tile_size', metavar='256', type=int, default=256, help='Size of each tile (default: 256)')
    parser.add_argument('--tile_pad', metavar='10', type=int, default=10, help='Padding around each tile (default: 10)')
    parser.add_argument('files', metavar='in_file_path', type=str, nargs='+', help='List of file paths to process')

    args = parser.parse_args(argv)

    if args.tile_size <= 0:
        parser.error(f"--tile_size must be a positive integer (got {args.tile_size})")
    if args.tile_pad < 0:
        parser.error(f"--tile_pad must be non-negative (got {args.tile_pad})")

    return args
|
|
|
def load_model(model_path):
    """Build the ESRGAN network from a weights file and compile it with MLX.

    Args:
        model_path: Path handed to ``mx.load`` to read the weights.

    Returns:
        A ``(compiled_model, scale)`` pair: the ``mx.compile``-wrapped network
        and the network's ``scale`` attribute (its upscale factor).
    """
    weights = mx.load(model_path)
    network = ESRGAN(weights)
    compiled = mx.compile(network)
    return compiled, network.scale
|
|
|
def _png_save_kwargs(pil_img):
    """Build ``Image.save`` keyword arguments that carry *pil_img*'s metadata to a PNG.

    Copies the text chunks exposed as ``pil_img.text`` and the ICC profile
    (``None`` when absent). Not every image type provides ``.text``, hence
    the ``getattr`` fallback.
    """
    png_info = PngImagePlugin.PngInfo()
    for key, value in (getattr(pil_img, "text", None) or {}).items():
        png_info.add_text(key, value)
    return {
        "icc_profile": pil_img.info.get('icc_profile'),
        "pnginfo": png_info,
    }


def _output_path(file_path, scale):
    """Return *file_path* with its extension replaced by ``_<scale>x.png``.

    ``os.path.splitext`` is used instead of a regex. With a regex, an input
    that has no extension would come back unchanged and be overwritten.
    """
    return os.path.splitext(file_path)[0] + f"_{scale}x.png"


def upscale_img(args, model, file_path, scale=4.0):
    """Upscale one image tile by tile and save the result as a PNG next to it.

    Args:
        args: Parsed CLI namespace; only ``tile_size`` and ``tile_pad`` are read.
        model: Callable applied to each padded tile. It is assumed to map a
            ``(1, h, w, 3)`` float tile in [0, 1] to ``(1, h*scale, w*scale, 3)``,
            based on how its output is cropped below. TODO confirm against ESRGAN.
        file_path: Path of the image to read.
        scale: Integer upscale factor of *model*. A float such as ``4.0`` is
            accepted and converted with ``int()``.

    The result is written to ``<file_path without extension>_<scale>x.png``,
    keeping the source's PNG text chunks and ICC profile.
    """
    ts, tp = args.tile_size, args.tile_pad
    # Shapes and slice bounds must be ints. The float default 4.0 would
    # otherwise be rejected by mx.zeros().
    s = int(scale)

    # Close the source file once metadata and pixels have been read.
    # convert() loads the pixel data, so it must happen inside the block.
    with Image.open(file_path) as pil_img:
        save_kwargs = _png_save_kwargs(pil_img)
        rgb = np.array(pil_img.convert("RGB"), dtype=np.float32)

    img_in = mx.array(rgb)[None] / 255.0  # (1, H, W, 3), values in [0, 1]
    _, H, W, C = img_in.shape
    mx.eval(img_in)

    img_out = mx.zeros((1, H * s, W * s, C), dtype=mx.uint8)
    mx.eval(img_out)

    tiles = [(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]
    for hi, wj in tqdm(tiles):
        # Amount of context actually available above/left of the tile
        # (smaller at the top/left image edges). It is used to crop the
        # padding back off the model output.
        phs = min(hi, tp)
        pws = min(wj, tp)
        tile_in = img_in[:, max(0, hi - tp):hi + ts + tp, max(0, wj - tp):wj + ts + tp, :]
        tile_out = model(tile_in)[:, phs * s:(ts + phs) * s, pws * s:(ts + pws) * s, :]
        # Clamp to [0, 1] and round before casting: converting out-of-range
        # floats to uint8 is not guaranteed to saturate and can produce
        # speckle artefacts.
        img_out[:, hi * s:(hi + ts) * s, wj * s:(wj + ts) * s, :] = (
            mx.round(mx.clip(tile_out, 0.0, 1.0) * 255.0)
        ).astype(mx.uint8)
        mx.eval(img_out)

    # np.asarray copies only if needed. np.array(..., copy=False) raises on
    # NumPy 2 whenever a copy is required.
    result = Image.fromarray(np.asarray(img_out[0]))
    result.save(_output_path(file_path, s), **save_kwargs)
|
|
|
def main():
    """Command-line entry point: upscale every file named on the command line.

    Prints a banner, sets MLX's buffer-cache limit to 0, parses arguments,
    loads the model once, then calls ``upscale_img`` for each input path in
    order. An exception on one file aborts the remaining ones.
    """
    print("\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m")

    # A cache limit of 0 keeps MLX from holding on to freed buffers between
    # tiles and images. Newer MLX releases expose set_cache_limit at the top
    # level and deprecate the mx.metal variant; fall back on older releases.
    set_cache_limit = getattr(mx, "set_cache_limit", None) or mx.metal.set_cache_limit
    set_cache_limit(0)

    args = parse_args()
    model, scale = load_model(args.model)

    for file_path in args.files:
        print(f"Upscaling {file_path}")
        upscale_img(args, model, file_path, scale=scale)
|
|
|
if __name__ == "__main__":
    # Run the whole program on a separate worker thread and block until it
    # finishes.
    # NOTE(review): why main() is not called directly is not visible here.
    # Keep the thread unless that is confirmed unnecessary.
    runner = threading.Thread(target=main)
    runner.start()
    runner.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|