|
|
|
|
|
from argparse import ArgumentParser, Namespace |
|
import pickle |
|
|
|
import jax |
|
from jax import jit |
|
import jax.numpy as jnp |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from model import build_thera |
|
from utils import make_grid, interpolate_grid |
|
|
|
# Per-channel statistics used to standardize the network input and to undo
# that standardization on the network output (one entry per colour channel).
# NOTE(review): presumably the training-set RGB mean -- confirm.
MEAN = jnp.array([0.4488, 0.4371, 0.4040])
VAR = jnp.array([0.25] * 3)

# Edge length, in output pixels, of the tiles the decoder is evaluated on.
PATCH_SIZE = 256
|
|
|
|
|
def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    """Super-resolve a single image to ``target_shape`` in one forward pass.

    The image is resampled to the target grid with ``interpolate_grid``,
    encoded once, and then decoded tile by tile (``PATCH_SIZE`` pixels per
    side) so that only one tile of query coordinates is decoded at a time.
    The de-standardized decoder output is added to the resampled image.

    Args:
        source: Unbatched image array. Axis 0 is treated as height and
            axis 1 as width; the last axis must broadcast against the
            3-element ``MEAN`` / ``VAR``.
        apply_encoder: Callable invoked as ``apply_encoder(params, batch)``.
        apply_decoder: Callable invoked as
            ``apply_decoder(params, encoding, coords, t)``.
        params: Model parameters forwarded to both callables.
        target_shape: ``(height, width)`` of the requested output.

    Returns:
        A float32 array with the same shape as the resampled source
        (leading batch axis of size 1 included).
    """
    # Decoder conditioning: the inverse squared scale factor.  ``source`` is
    # still unbatched here, so its height lives on axis 0 (axis 1 is the
    # width).  Dividing the target height by the source *width* produced a
    # wrong factor for every non-square input.
    # NOTE(review): only the height ratio is used, i.e. a uniform scale is
    # assumed -- confirm this is intended for anisotropic ``--size`` targets.
    scale = target_shape[0] / source.shape[0]
    t = jnp.float32(scale ** -2)[None]

    # Resample the raw source onto the target grid; the decoder output is
    # added on top of this at the end.
    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]

    encoding = apply_encoder(params, source)
    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])

    # NaN-initialised so that any tile left unwritten is easy to spot.
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)

    n_rows, n_cols = coords.shape[1], coords.shape[2]
    for h_min in range(0, n_rows, PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, n_rows)
        for w_min in range(0, n_cols, PATCH_SIZE):
            w_max = min(w_min + PATCH_SIZE, n_cols)
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    # Undo the input standardization, then add the resampled source.
    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    out += source_up
    return out
|
|
|
|
|
def process(source, model, params, target_shape, do_ensemble=True):
    """Super-resolve ``source``, optionally averaging over four rotations.

    The encoder and decoder of ``model`` are jit-compiled once and shared by
    every pass.  With ``do_ensemble`` enabled the image is processed at 0, 90,
    180 and 270 degrees; each result is rotated back before averaging.

    Args:
        source: Image array whose axes ``-3`` and ``-2`` are the spatial
            axes being rotated.
        model: Object exposing ``apply_encoder`` and ``apply_decoder``.
        params: Model parameters forwarded to ``process_single``.
        target_shape: ``(height, width)`` of the requested output.
        do_ensemble: Run all four rotations when true, a single pass
            otherwise.

    Returns:
        The first batch element of the averaged prediction, clipped to
        ``[0, 1]``, scaled by 255, rounded and cast to ``uint8``.
    """
    encoder_fn = jit(model.apply_encoder)
    decoder_fn = jit(model.apply_decoder)

    n_rotations = 4 if do_ensemble else 1
    predictions = []
    for quarter_turns in range(n_rotations):
        rotated = jnp.rot90(source, k=quarter_turns, axes=(-3, -2))

        # An odd number of quarter turns swaps height and width.
        if quarter_turns % 2:
            rotated_shape = tuple(reversed(target_shape))
        else:
            rotated_shape = target_shape

        prediction = process_single(
            rotated, encoder_fn, decoder_fn, params, rotated_shape)

        # Swapping the axes order rotates in the opposite direction, which
        # brings the prediction back to the original orientation.
        predictions.append(
            jnp.rot90(prediction, k=quarter_turns, axes=(-2, -3)))

    averaged = jnp.stack(predictions).mean(0).clip(0., 1.)
    return jnp.rint(averaged[0] * 255).astype(jnp.uint8)
|
|
|
|
|
def main(args: Namespace):
    """Super-resolve one image file as described by the parsed CLI arguments.

    Reads ``args.in_file``, derives the target shape from either
    ``args.scale`` or ``args.size``, loads the model checkpoint from
    ``args.checkpoint``, runs the model and writes the result to
    ``args.out_file``.

    Args:
        args: Namespace with ``in_file``, ``out_file``, ``scale``, ``size``,
            ``checkpoint`` and ``no_ensemble`` attributes.

    Raises:
        ValueError: If both or neither of ``scale`` / ``size`` are given, or
            if the resulting target shape has a side smaller than one pixel.
    """
    # Validate the arguments before touching any file.
    if args.scale is not None and args.size is not None:
        raise ValueError('Cannot specify both size and scale')
    if args.scale is None and args.size is None:
        raise ValueError('Must specify either size or scale')

    # Force three RGB channels: the model is built for 3 channels and the
    # normalization constants have 3 entries, so RGBA, grayscale or palette
    # images would otherwise fail or be misinterpreted.  The context manager
    # releases the underlying file handle.
    with Image.open(args.in_file) as image:
        source = np.asarray(image.convert('RGB')) / 255.

    if args.scale is not None:
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    else:
        # argparse delivers a list; use a tuple like the --scale branch does.
        target_shape = tuple(args.size)

    if min(target_shape) < 1:
        raise ValueError(f'Target shape must be positive, got {target_shape}')

    # NOTE(security): pickle.load can execute arbitrary code.  Only load
    # checkpoints obtained from a trusted source.
    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']

    model = build_thera(3, backbone, size)

    out = process(source, model, params, target_shape, not args.no_ensemble)

    Image.fromarray(np.asarray(out)).save(args.out_file)
|
|
|
|
|
def parse_args() -> Namespace:
    """Parse the command-line arguments of the super-resolution script.

    Exactly one of ``--scale`` / ``--size`` must be supplied, and
    ``--checkpoint`` is mandatory because the checkpoint file is always
    opened; argparse reports violations with a usage message and exits.

    Returns:
        Namespace with ``in_file``, ``out_file``, ``scale`` (float or None),
        ``size`` (list of two ints or None), ``checkpoint`` and
        ``no_ensemble`` (bool).
    """
    parser = ArgumentParser()
    parser.add_argument('in_file')
    parser.add_argument('out_file')

    # The two ways of specifying the output resolution exclude each other.
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument('--scale', type=float,
                        help='Scale factor for super-resolution')
    target.add_argument('--size', type=int, nargs=2,
                        help='Target size (h, w), mutually exclusive with --scale')

    # Required: without it the script would fail later with an opaque
    # TypeError when trying to open ``None``.
    parser.add_argument('--checkpoint', required=True,
                        help='Path to checkpoint file')
    parser.add_argument('--no-ensemble', action='store_true',
                        help='Disable geo-ensemble')
    return parser.parse_args()
|
|
|
|
|
# Script entry point: parse the command line and run the pipeline.
if __name__ == '__main__':
    main(parse_args())
|
|