"""Tiny tiled ESRGAN upscaler built on MLX.

Each input image is split into ``tile_size`` x ``tile_size`` tiles; every tile
is run through the model together with up to ``tile_pad`` pixels of
surrounding context, the context is cropped off the upscaled result, and the
tile is pasted into an output canvas. The result is written next to the input
as ``<name>_<scale>x.png``.
"""
import os
import re
import argparse
import threading

import mlx.core as mx
import numpy as np
from PIL import Image, PngImagePlugin
from tqdm import tqdm

from ESRGAN import ESRGAN


def parse_args():
    """Parse command-line options.

    Returns:
        argparse.Namespace with ``model``, ``tile_size``, ``tile_pad`` and
        ``files`` attributes. Exits via ``parser.error`` on invalid sizes.
    """
    parser = argparse.ArgumentParser(description="Process tile size, padding, and file paths.")
    parser.add_argument('--model', metavar='file_path', type=str,
                        default='4x_NMKD-YandereNeoXL_200k.safetensors',
                        help='Path to the model file')
    parser.add_argument('--tile_size', metavar='256', type=int, default=256,
                        help='Size of each tile (default: 256)')
    parser.add_argument('--tile_pad', metavar='10', type=int, default=10,
                        help='Padding around each tile (default: 10)')
    parser.add_argument('files', metavar='in_file_path', type=str, nargs='+',
                        help='List of file paths to process')
    args = parser.parse_args()
    # The tile loop uses tile_size as a range() step (must be > 0), and a
    # negative pad would produce negative crop offsets in upscale_img.
    if args.tile_size < 1:
        parser.error("--tile_size must be a positive integer")
    if args.tile_pad < 0:
        parser.error("--tile_pad must be non-negative")
    return args


def load_model(model_path):
    """Load ESRGAN weights from *model_path* and compile the model.

    Returns:
        (compiled_model, scale): the ``mx.compile``-d model and the model's
        ``scale`` attribute, which main passes to upscale_img as its factor.
    """
    model = ESRGAN(mx.load(model_path))
    return mx.compile(model), model.scale


def _output_path(file_path, scale):
    """Return ``<file_path without extension>_<scale>x.png``."""
    root, _ = os.path.splitext(file_path)
    return f"{root}_{scale}x.png"


def upscale_img(args, model, file_path, scale=4.0):
    """Upscale one image file tile by tile and save the result as PNG.

    Args:
        args: parsed CLI options; ``tile_size`` and ``tile_pad`` are read.
        model: callable applied to (1, h, w, 3) float32 patches in [0, 1];
            assumed to return (1, h*scale, w*scale, 3) — TODO confirm against
            the ESRGAN implementation.
        file_path: path of the input image.
        scale: integer upscaling factor of ``model``. Integral floats such as
            the default 4.0 are accepted.

    Raises:
        ValueError: if ``scale`` is not a whole number.

    The output goes to ``<file_path without extension>_<scale>x.png``. PNG
    text chunks and the ICC profile of the input are carried over; any alpha
    channel is dropped by the RGB conversion.
    """
    ts, tp = args.tile_size, args.tile_pad
    # Array shapes and slice bounds must be ints.
    s = int(scale)
    if s != scale:
        raise ValueError(f"scale must be an integer factor, got {scale!r}")

    # Context manager so the source file handle is released per image.
    with Image.open(file_path) as src:
        png_info = PngImagePlugin.PngInfo()
        for k, v in (getattr(src, "text", None) or {}).items():
            png_info.add_text(k, v)
        img_save_argv = {
            "icc_profile": src.info.get('icc_profile'),
            "pnginfo": png_info,
        }
        img_in = mx.array(np.array(src.convert("RGB"), dtype=np.float32))[None] / 255.0

    _, H, W, C = img_in.shape
    mx.eval(img_in)

    img_out = mx.zeros((1, H * s, W * s, C), dtype=mx.uint8)
    mx.eval(img_out)

    tiles = [(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]
    for hi, wj in tqdm(tiles):
        # Context actually available above / left of the tile (less at edges).
        phs = min(hi, tp)
        pws = min(wj, tp)
        patch = img_in[:, max(0, hi - tp):hi + ts + tp, max(0, wj - tp):wj + ts + tp, :]
        # Crop the context back off the upscaled patch. Slices running past
        # the image edge are clamped, so edge tiles shrink consistently with
        # the canvas slice they are written into.
        up = model(patch)[:, phs * s:(ts + phs) * s, pws * s:(ts + pws) * s, :]
        # Round and clip before the uint8 cast so out-of-range model outputs
        # saturate instead of wrapping around.
        img_out[:, hi * s:(hi + ts) * s, wj * s:(wj + ts) * s, :] = (
            mx.clip(mx.round(up * 255.0), 0, 255)
        ).astype(mx.uint8)
        # Evaluate per tile to keep the lazy graph (and memory) small.
        mx.eval(img_out)

    # np.asarray instead of np.array(copy=False): the latter raises on
    # NumPy >= 2 whenever a copy is required.
    out_img = Image.fromarray(np.asarray(img_out[0]))
    out_img.save(_output_path(file_path, s), **img_save_argv)


def main():
    """Load the model once, then upscale every file given on the command line."""
    print("\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m")
    # Per the MLX docs, a cache limit of 0 disables the Metal buffer cache.
    mx.metal.set_cache_limit(0)
    args = parse_args()
    model, scale = load_model(args.model)
    for file_path in args.files:
        print(f"Upscaling {file_path}")
        upscale_img(args, model, file_path, scale=scale)


if __name__ == "__main__":
    # NOTE(review): main() runs on a worker thread that is immediately joined;
    # the reason is not visible here — confirm before simplifying.
    # Exceptions (including argparse's SystemExit) raised in a thread do not
    # affect the process exit status, so capture them and re-raise here.
    failure = []

    def _run_main():
        try:
            main()
        except BaseException as exc:  # forwarded to the main thread below
            failure.append(exc)

    th = threading.Thread(target=_run_main)
    th.start()
    th.join()
    if failure:
        raise failure[0]