# Uploaded by kaeru-shigure ("Upload 7 files", commit 53d4b00).
import os, re, argparse, threading
import mlx.core as mx
import numpy as np
from PIL import Image, PngImagePlugin
from tqdm import tqdm
from ESRGAN import ESRGAN
def parse_args(argv=None):
    """Parse the upscaler's command-line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``model``, ``tile_size``, ``tile_pad`` and ``files``.

    Exits (via ``parser.error``, status 2) when ``--tile_size`` is not positive or
    ``--tile_pad`` is negative. The tiling loop uses tile_size as a ``range`` step
    and tile_pad as a slice offset, so such values would otherwise crash or
    silently mis-slice the image much later.
    """
    parser = argparse.ArgumentParser(description="Process tile size, padding, and file paths.")
    parser.add_argument('--model', metavar='file_path', type=str, default='4x_NMKD-YandereNeoXL_200k.safetensors', help='Path to the model file')
    parser.add_argument('--tile_size', metavar='256', type=int, default=256, help='Size of each tile (default: 256)')
    parser.add_argument('--tile_pad', metavar='10', type=int, default=10, help='Padding around each tile (default: 10)')
    parser.add_argument('files', metavar='in_file_path', type=str, nargs='+', help='List of file paths to process')
    args = parser.parse_args(argv)
    if args.tile_size < 1:
        parser.error("--tile_size must be a positive integer")
    if args.tile_pad < 0:
        parser.error("--tile_pad must be zero or greater")
    return args
def load_model(model_path):
    """Build the ESRGAN network from a weights file and compile it with MLX.

    Args:
        model_path: path handed to ``mx.load`` (the CLI default is a
            ``.safetensors`` file).

    Returns:
        ``(compiled_model, scale)``. ``scale`` is the uncompiled network's
        ``scale`` attribute, presumably its upscale factor; it is defined in
        ESRGAN, not here.
    """
    weights = mx.load(model_path)
    network = ESRGAN(weights)
    compiled_network = mx.compile(network)
    return compiled_network, network.scale
def _output_path(file_path, scale):
    """Return the PNG path for the upscaled copy of *file_path*.

    The extension, if any, is replaced by ``_{scale}x.png``. A path without an
    extension gets the suffix appended, so the input is never overwritten.
    """
    base, _ext = os.path.splitext(file_path)
    return f"{base}_{scale}x.png"


def upscale_img(args, model, file_path, scale=4.0):
    """Upscale one image tile by tile and save the result as PNG beside the input.

    Args:
        args: parsed CLI namespace; only ``tile_size`` and ``tile_pad`` are read.
        model: callable applied to a ``(1, h, w, 3)`` float image slice in [0, 1].
            Assumes it returns a ``(1, h*scale, w*scale, 3)`` array in roughly
            [0, 1] (TODO confirm against ESRGAN).
        file_path: image file to read.
        scale: upscale factor of *model*, converted with ``int()``. The float
            default would otherwise break array shapes and slicing.

    The output goes to ``<file_path without extension>_<scale>x.png``. PNG text
    chunks and the ICC profile of the source are copied to it.
    """
    ts, tp, s = args.tile_size, args.tile_pad, int(scale)

    # Context manager releases the file handle once the pixels are decoded.
    with Image.open(file_path) as src:
        png_info = PngImagePlugin.PngInfo()
        # Only some formats (e.g. PNG) expose a `text` mapping; others contribute nothing.
        for k, v in (getattr(src, "text", None) or {}).items():
            png_info.add_text(k, v)
        img_save_argv = {
            "icc_profile": src.info.get('icc_profile'),
            "pnginfo": png_info,
        }
        pixels = np.array(src.convert("RGB"), dtype=np.float32)

    # Add a leading batch axis of one and normalise 0..255 to 0..1.
    img_in = mx.array(pixels)[None] / 255.0
    _, H, W, C = img_in.shape
    mx.eval(img_in)
    img_out = mx.zeros((1, H * s, W * s, C), dtype=mx.uint8)
    mx.eval(img_out)

    tiles = [(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]
    for hi, wj in tqdm(tiles):
        # Context actually available above/left of this tile (less at the border).
        phs = min(hi, tp)
        pws = min(wj, tp)
        # Upscale the tile with up to `tp` pixels of surrounding context, then
        # crop the context back off so only the tile's own region is kept.
        patch = img_in[:, max(0, hi - tp):hi + ts + tp, max(0, wj - tp):wj + ts + tp, :]
        upscaled = model(patch)[:, phs * s:(ts + phs) * s, pws * s:(ts + pws) * s, :]
        # Clamp and round before quantising. Values outside [0, 1] would
        # otherwise wrap around in uint8 and show up as speckle artifacts.
        img_out[:, hi * s:(hi + ts) * s, wj * s:(wj + ts) * s, :] = (
            mx.round(mx.clip(upscaled, 0.0, 1.0) * 255.0)
        ).astype(mx.uint8)
        # Evaluate per tile so the lazy graph (and its memory) does not grow across tiles.
        mx.eval(img_out)

    # np.array (not copy=False, which raises under NumPy 2 when a copy is needed).
    result = Image.fromarray(np.array(img_out[0]))
    result.save(_output_path(file_path, s), **img_save_argv)
def main():
    """Print the banner, load the model once, then upscale each input file in turn."""
    print("\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m")
    # NOTE(review): a cache limit of 0 presumably stops MLX from keeping freed
    # Metal buffers around, trading speed for lower memory; confirm with MLX docs.
    mx.metal.set_cache_limit(0)
    opts = parse_args()
    upscaler, factor = load_model(opts.model)
    for path in opts.files:
        print(f"Upscaling {path}")
        upscale_img(opts, upscaler, path, scale=factor)
def run_in_thread(target):
    """Call ``target()`` on a worker thread, wait for it, and re-raise anything it raised.

    An exception raised inside a plain ``threading.Thread`` never reaches the
    thread that started it. Ordinary errors are only printed by
    ``threading.excepthook``, and ``SystemExit`` (e.g. argparse rejecting the
    arguments) is silently ignored. Either way the process would exit with
    status 0 even on failure. Capturing the exception and re-raising it in the
    caller restores the normal exit status and traceback.
    """
    caught = []

    def _runner():
        try:
            target()
        except BaseException as exc:  # deliberately broad: must include SystemExit
            caught.append(exc)

    worker = threading.Thread(target=_runner)
    worker.start()
    worker.join()
    if caught:
        raise caught[0]


if __name__ == "__main__":
    # NOTE(review): why main() runs on a worker thread instead of the main thread
    # is not visible in this file; the threading is preserved as-is.
    run_in_thread(main)