File size: 2,481 Bytes
53d4b00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os, re, argparse, threading
import mlx.core as mx
import numpy as np
from PIL import Image, PngImagePlugin
from tqdm import tqdm
from ESRGAN import ESRGAN
def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: carries ``model`` (str), ``tile_size`` (int),
        ``tile_pad`` (int) and ``files`` (list of one or more str paths).
    """
    cli = argparse.ArgumentParser(
        description="Process tile size, padding, and file paths.",
    )
    # (flag-or-name, keyword options) pairs, registered in this order.
    argument_table = (
        ('--model', {
            'metavar': 'file_path',
            'type': str,
            'default': '4x_NMKD-YandereNeoXL_200k.safetensors',
            'help': 'Path to the model file',
        }),
        ('--tile_size', {
            'metavar': '256',
            'type': int,
            'default': 256,
            'help': 'Size of each tile (default: 256)',
        }),
        ('--tile_pad', {
            'metavar': '10',
            'type': int,
            'default': 10,
            'help': 'Padding around each tile (default: 10)',
        }),
        ('files', {
            'metavar': 'in_file_path',
            'type': str,
            'nargs': '+',
            'help': 'List of file paths to process',
        }),
    )
    for name, options in argument_table:
        cli.add_argument(name, **options)
    return cli.parse_args()
def load_model(model_path):
    """Load ESRGAN weights from ``model_path`` and wrap the network for use.

    Args:
        model_path: Path handed to ``mx.load`` to read the weights.

    Returns:
        tuple: ``(compiled_model, scale)`` -- the ``mx.compile``-wrapped
        network and the value of the network's ``scale`` attribute.
    """
    weights = mx.load(model_path)
    network = ESRGAN(weights)
    compiled = mx.compile(network)
    return compiled, network.scale
def upscale_img(args, model, file_path, scale=4.0):
    """Upscale one image tile by tile and save the result as a PNG.

    The image is split into ``tile_size`` x ``tile_size`` tiles. Each tile is
    fed to ``model`` together with up to ``tile_pad`` pixels of surrounding
    context, and the context border is cropped from the model output before
    the tile is written into the result. PNG text chunks and the ICC profile
    of the input are carried over to the output file.

    Args:
        args: Parsed options; ``args.tile_size`` and ``args.tile_pad`` are read.
        model: Callable applied to a ``(1, h, w, 3)`` float array with values
            in ``[0, 1]``.
            NOTE(review): assumed to return ``(1, h*scale, w*scale, 3)`` --
            confirm against ESRGAN.
        file_path: Path of the image to upscale. The output is written next
            to it as ``<stem>_<scale>x.png``.
        scale: Upscaling factor of ``model``. Must be a positive whole number
            (an integral float such as ``4.0`` is accepted).

    Raises:
        ValueError: If ``scale`` is not a positive whole number.
    """
    # Shapes and slice bounds must be ints; the default is the float 4.0.
    s = int(scale)
    if s != scale or s < 1:
        raise ValueError(f"scale must be a positive whole number, got {scale!r}")
    ts, tp = args.tile_size, args.tile_pad

    # Read everything needed from the file inside the context manager so the
    # underlying file handle is released deterministically.
    with Image.open(file_path) as src:
        png_info = PngImagePlugin.PngInfo()
        for k, v in (getattr(src, "text", None) or {}).items():
            png_info.add_text(k, v)
        img_save_argv = {
            "icc_profile": src.info.get('icc_profile'),
            "pnginfo": png_info,
        }
        rgb = np.array(src.convert("RGB"), dtype=np.float32)

    img_in = mx.array(rgb)[None] / 255.0
    _, H, W, C = img_in.shape
    mx.eval(img_in)
    img_out = mx.zeros((1, H * s, W * s, C), dtype=mx.uint8)
    mx.eval(img_out)

    tiles = [(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]
    for hi, wj in tqdm(tiles):
        # Context actually available above/left of the tile (less than
        # tile_pad at the image border).
        phs = min(hi, tp)
        pws = min(wj, tp)
        padded = img_in[:, hi - phs:hi + ts + tp, wj - pws:wj + ts + tp, :]
        upscaled = model(padded)
        # Drop the upscaled context border so only the tile itself remains.
        core = upscaled[:, phs * s:(phs + ts) * s, pws * s:(pws + ts) * s, :]
        # Clamp before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around and show up as speckles.
        # Destination offsets use the real scale, not a hard-coded 4.
        img_out[:, hi * s:(hi + ts) * s, wj * s:(wj + ts) * s, :] = (
            mx.clip(core, 0.0, 1.0) * 255.0
        ).astype(mx.uint8)
        # Force evaluation per tile so the lazy graph does not grow across
        # the whole image.
        mx.eval(img_out)

    # np.asarray avoids a copy when possible without raising when one is
    # required (np.array(..., copy=False) raises in that case on NumPy 2).
    result = Image.fromarray(np.asarray(img_out[0]))
    # splitext always yields a path distinct from the input, even when the
    # file has no extension, so the source image is never overwritten.
    out_path = os.path.splitext(file_path)[0] + f"_{s}x.png"
    result.save(out_path, **img_save_argv)
def main():
    """Print the banner, load the model and upscale every input file in order."""
    banner = "\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m"
    print(banner)
    # Sets the MLX Metal cache limit to 0 before any work is done.
    # NOTE(review): presumably to keep memory usage low -- confirm.
    mx.metal.set_cache_limit(0)
    cli_args = parse_args()
    loaded = load_model(cli_args.model)
    upscaler, factor = loaded
    for file_path in cli_args.files:
        print(f"Upscaling {file_path}")
        upscale_img(cli_args, upscaler, file_path, scale=factor)
if __name__ == "__main__":
    # Run main() on a dedicated thread and wait for it to finish.
    # NOTE(review): the reason for using a worker thread instead of calling
    # main() directly is not evident from this file -- confirm before
    # simplifying.
    worker = threading.Thread(target=main)
    worker.start()
    worker.join()
|