Spaces:
Running
on
L4
Running
on
L4
#!/usr/bin/env python | |
from argparse import ArgumentParser, Namespace | |
import pickle | |
import jax | |
from jax import jit | |
import jax.numpy as jnp | |
import numpy as np | |
from PIL import Image | |
from model import build_thera | |
from utils import make_grid, interpolate_grid | |
MEAN = jnp.array([.4488, .4371, .4040]) | |
VAR = jnp.array([.25, .25, .25]) | |
PATCH_SIZE = 256 | |
def process_single(source, apply_encoder, apply_decoder, params, target_shape): | |
t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None] | |
coords_nearest = jnp.asarray(make_grid(target_shape)[None]) | |
source_up = interpolate_grid(coords_nearest, source[None]) | |
source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None] | |
encoding = apply_encoder(params, source) | |
coords = jnp.asarray(make_grid(source_up.shape[1:3])[None]) # global sampling coords | |
out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32) | |
for h_min in range(0, coords.shape[1], PATCH_SIZE): | |
h_max = min(h_min + PATCH_SIZE, coords.shape[1]) | |
for w_min in range(0, coords.shape[2], PATCH_SIZE): | |
# apply decoder with one patch of coordinates | |
w_max = min(w_min + PATCH_SIZE, coords.shape[2]) | |
coords_patch = coords[:, h_min:h_max, w_min:w_max] | |
out_patch = apply_decoder(params, encoding, coords_patch, t) | |
out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch) | |
out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None] | |
out += source_up | |
return out | |
def process(source, model, params, target_shape, do_ensemble=True): | |
apply_encoder = jit(model.apply_encoder) | |
apply_decoder = jit(model.apply_decoder) | |
outs = [] | |
for i_rot in range(4 if do_ensemble else 1): | |
source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2)) | |
target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape | |
out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_) | |
outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3))) | |
out = jnp.stack(outs).mean(0).clip(0., 1.) | |
return jnp.rint(out[0] * 255).astype(jnp.uint8) | |
def main(args: Namespace): | |
source = np.asarray(Image.open(args.in_file)) / 255. | |
if args.scale is not None: | |
if args.size is not None: | |
raise ValueError('Cannot specify both size and scale') | |
target_shape = ( | |
round(source.shape[0] * args.scale), | |
round(source.shape[1] * args.scale), | |
) | |
elif args.size is not None: | |
target_shape = args.size | |
else: | |
raise ValueError('Must specify either size or scale') | |
with open(args.checkpoint, 'rb') as fh: | |
check = pickle.load(fh) | |
params, backbone, size = check['model'], check['backbone'], check['size'] | |
model = build_thera(3, backbone, size) | |
out = process(source, model, params, target_shape, not args.no_ensemble) | |
Image.fromarray(np.asarray(out)).save(args.out_file) | |
def parse_args() -> Namespace: | |
parser = ArgumentParser() | |
parser.add_argument('in_file') | |
parser.add_argument('out_file') | |
parser.add_argument('--scale', type=float, help='Scale factor for super-resolution') | |
parser.add_argument('--size', type=int, nargs=2, | |
help='Target size (h, w), mutually exclusive with --scale') | |
parser.add_argument('--checkpoint', help='Path to checkpoint file') | |
parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble') | |
return parser.parse_args() | |
if __name__ == '__main__': | |
args = parse_args() | |
main(args) | |