File size: 2,481 Bytes
53d4b00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os, re, argparse, threading
import mlx.core as mx
import numpy as np
from PIL import Image, PngImagePlugin
from tqdm import tqdm
from ESRGAN import ESRGAN

def parse_args():
    """Parse the command line and return the resulting namespace.

    The namespace carries ``model`` (path to the weights file),
    ``tile_size`` and ``tile_pad`` (integers) and ``files`` (a non-empty
    list of input image paths).
    """
    cli = argparse.ArgumentParser(
        description="Process tile size, padding, and file paths.",
    )
    cli.add_argument(
        '--model',
        metavar='file_path',
        type=str,
        default='4x_NMKD-YandereNeoXL_200k.safetensors',
        help='Path to the model file',
    )
    cli.add_argument(
        '--tile_size',
        metavar='256',
        type=int,
        default=256,
        help='Size of each tile (default: 256)',
    )
    cli.add_argument(
        '--tile_pad',
        metavar='10',
        type=int,
        default=10,
        help='Padding around each tile (default: 10)',
    )
    # At least one positional input path is required (nargs='+').
    cli.add_argument(
        'files',
        metavar='in_file_path',
        type=str,
        nargs='+',
        help='List of file paths to process',
    )
    return cli.parse_args()

def load_model(model_path):
    """Load ESRGAN weights from ``model_path``.

    Returns a ``(compiled_model, scale)`` tuple: the ``mx.compile``-wrapped
    network and the ``scale`` attribute read from the uncompiled network.
    """
    weights = mx.load(model_path)
    network = ESRGAN(weights)
    # Read ``scale`` from the original object; the compiled wrapper is only
    # used as a callable.
    upscale_factor = network.scale
    return mx.compile(network), upscale_factor

def upscale_img(args, model, file_path, scale=4.0):
    """Upscale one image tile by tile and save it as ``<name>_<scale>x.png``.

    Args:
        args: Object providing integer ``tile_size`` and ``tile_pad``.
        model: Callable taking a ``(1, h, w, 3)`` float array with values in
            [0, 1]; it is expected to return the same layout enlarged by
            ``scale`` along height and width.
        file_path: Path of the input image. The output is written next to it
            with the extension replaced by ``_<scale>x.png``.
        scale: Integral upscale factor of ``model`` (``4.0`` and ``4`` are
            both accepted).

    Raises:
        ValueError: If ``scale`` is not a positive whole number, if
            ``tile_size`` is not positive, or if ``tile_pad`` is negative.
    """
    ts, tp = args.tile_size, args.tile_pad
    # Shapes and slice bounds must be integers; the default is a float.
    s = int(scale)
    if s != scale or s <= 0:
        raise ValueError(f"scale must be a positive whole number, got {scale!r}")
    if ts <= 0:
        raise ValueError(f"tile_size must be positive, got {ts!r}")
    if tp < 0:
        raise ValueError(f"tile_pad must not be negative, got {tp!r}")

    # Close the source file as soon as pixels and metadata have been read.
    with Image.open(file_path) as img_src:
        png_info = PngImagePlugin.PngInfo()
        # Only some image types expose ``text``; carry the chunks over if present.
        for k, v in (getattr(img_src, "text", None) or {}).items():
            png_info.add_text(k, v)
        img_save_argv = {
            "icc_profile": img_src.info.get('icc_profile'),
            "pnginfo": png_info,
        }
        pixels = np.array(img_src.convert("RGB"), dtype=np.float32)

    img_in = mx.array(pixels)[None] / 255.0
    _, H, W, C = img_in.shape
    mx.eval(img_in)

    img_out = mx.zeros((1, H * s, W * s, C), dtype=mx.uint8)
    mx.eval(img_out)

    # Materialise the tile origins so tqdm can show a total.
    tiles = [(hi, wj) for hi in range(0, H, ts) for wj in range(0, W, ts)]
    for hi, wj in tqdm(tiles):
        # Padding actually available above / left of this tile (0 at the border).
        phs = min(hi, tp)
        pws = min(wj, tp)
        tile_in = img_in[:, hi - phs:hi + ts + tp, wj - pws:wj + ts + tp, :]
        # Drop the upscaled padding so only the tile's own area is kept.
        tile_out = model(tile_in)[:, phs * s:(ts + phs) * s, pws * s:(ts + pws) * s, :]
        # Use the real scale for every destination bound (not a literal 4) and
        # clamp before the uint8 cast so out-of-range values cannot wrap.
        img_out[:, hi * s:(hi + ts) * s, wj * s:(wj + ts) * s, :] = (
            mx.clip(tile_out, 0.0, 1.0) * 255.0
        ).astype(mx.uint8)
        # Evaluate after every tile, as the original did, instead of letting
        # the lazy graph grow across the whole image.
        mx.eval(img_out)

    # splitext also handles paths without an extension, where a regex
    # substitution would have left the input path unchanged.
    out_path = f"{os.path.splitext(file_path)[0]}_{s}x.png"
    result = Image.fromarray(np.array(img_out[0], copy=False))
    result.save(out_path, **img_save_argv)

def main():
    """Print the banner, load the model and upscale every input file in order."""
    print("\033[1;32mkaeru tiny mlx upscaler v0.1\033[0m")
    # Prefer the top-level ``mx.set_cache_limit`` when this MLX build has it;
    # otherwise fall back to the older ``mx.metal`` entry point.
    set_cache_limit = getattr(mx, "set_cache_limit", None)
    if set_cache_limit is None:
        set_cache_limit = mx.metal.set_cache_limit
    set_cache_limit(0)
    args = parse_args()
    model, scale = load_model(args.model)
    for file_path in args.files:
        print(f"Upscaling {file_path}")
        upscale_img(args, model, file_path, scale=scale)

if __name__ == "__main__":
    # Run main() on a separate thread and block until it finishes.
    # NOTE(review): the reason for using a worker thread instead of calling
    # main() directly is not evident from this file - confirm before changing.
    worker = threading.Thread(target=main)
    worker.start()
    worker.join()