sculpt / super_resolve.py
ds1david's picture
New logic
a02c6d7
raw
history blame
3.57 kB
#!/usr/bin/env python
from argparse import ArgumentParser, Namespace
import pickle
import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from PIL import Image
from model import build_thera
from utils import make_grid, interpolate_grid
# Per-channel statistics (3 channels, presumably RGB - confirm against training data).
# process_single standardizes the input with these before encoding and maps the
# decoder output back through them (x * sqrt(VAR) + MEAN) afterwards.
MEAN = jnp.array([0.4488, 0.4371, 0.4040])
VAR = jnp.array([0.25] * 3)  # variance 0.25 == standard deviation 0.5 per channel
# Height/width (in output-grid positions) of each coordinate tile sent to the decoder.
PATCH_SIZE = 256
def process_single(source, apply_encoder, apply_decoder, params, target_shape):
    """Super-resolve one image to ``target_shape`` with a single model pass.

    The source is upsampled with ``interpolate_grid`` onto the output grid. The
    decoder is then queried tile by tile (``PATCH_SIZE`` x ``PATCH_SIZE`` output
    positions at a time), and its output is mapped back through the input
    standardization and added to the upsampled source.

    Args:
        source: Image indexed as (height, width, channels). ``main`` passes 8-bit
            pixels divided by 255, so values are presumably in [0, 1] - confirm for
            other callers.
        apply_encoder: Callable ``(params, standardized_batch) -> encoding``.
        apply_decoder: Callable ``(params, encoding, coords_patch, t) -> patch``.
        params: Model parameters forwarded to both callables.
        target_shape: Output ``(height, width)``.

    Returns:
        Array shaped like the ``interpolate_grid`` output (presumably
        ``(1, height, width, channels)``): upsampled source plus de-standardized
        decoder output.
    """
    # Conditioning value (output size / input size) ** -2. Compare height with
    # height. Comparing target height with source width was wrong for non-square
    # images. It also gave each rotation in ``process`` a different value, because
    # rotating swaps the source's H/W together with the target's.
    scale = target_shape[0] / source.shape[0]
    t = jnp.float32(scale ** -2)[None]

    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source_std = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]
    encoding = apply_encoder(params, source_std)

    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    height, width = coords.shape[1], coords.shape[2]
    # Any element the tiling loop does not write stays NaN.
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)
    for h_min in range(0, height, PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, height)
        for w_min in range(0, width, PATCH_SIZE):
            w_max = min(w_min + PATCH_SIZE, width)
            # Decode one tile of coordinates against the full-image encoding.
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    # Undo the input-side standardization, then add the upsampled source.
    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    return out + source_up
def process(source, model, params, target_shape, do_ensemble=True):
    """Super-resolve ``source`` to ``target_shape`` and return an 8-bit image.

    With ``do_ensemble`` the input is processed at 0, 90, 180 and 270 degree
    rotations. Each prediction is rotated back and the predictions are averaged.
    Otherwise only the unrotated pass is run.

    Args:
        source: Image whose last three axes are (height, width, channels).
        model: Object exposing ``apply_encoder`` and ``apply_decoder``.
        params: Model parameters.
        target_shape: Output ``(height, width)``.
        do_ensemble: Whether to average over the four rotations.

    Returns:
        uint8 array taken from index 0 of the averaged prediction (presumably
        (height, width, channels)), with values clipped to [0, 1] and then
        rounded from ``value * 255``.
    """
    encode = jit(model.apply_encoder)
    decode = jit(model.apply_decoder)
    n_rotations = 4 if do_ensemble else 1

    predictions = []
    for quarter_turns in range(n_rotations):
        rotated = jnp.rot90(source, k=quarter_turns, axes=(-3, -2))
        # An odd number of quarter turns swaps height and width.
        if quarter_turns % 2:
            shape_here = tuple(reversed(target_shape))
        else:
            shape_here = target_shape
        prediction = process_single(rotated, encode, decode, params, shape_here)
        # Rotating over the swapped axes undoes the input rotation.
        predictions.append(jnp.rot90(prediction, k=quarter_turns, axes=(-2, -3)))

    averaged = jnp.clip(jnp.mean(jnp.stack(predictions), axis=0), 0., 1.)
    return jnp.rint(averaged[0] * 255).astype(jnp.uint8)
def main(args: Namespace):
    """Super-resolve ``args.in_file`` and save the result to ``args.out_file``.

    Exactly one of ``args.scale`` / ``args.size`` must be set. ``args.checkpoint``
    must point to a pickled mapping with ``'model'``, ``'backbone'`` and ``'size'``
    entries.

    Raises:
        ValueError: if both or neither of ``scale``/``size`` are given, if no
            checkpoint is given, or if the target size is not positive.
    """
    # Validate the arguments before doing any expensive I/O.
    if args.scale is not None and args.size is not None:
        raise ValueError('Cannot specify both size and scale')
    if args.scale is None and args.size is None:
        raise ValueError('Must specify either size or scale')
    if args.checkpoint is None:
        raise ValueError('Must specify a checkpoint')

    # Force 3 channels. The model is built with 3 channels and MEAN/VAR have 3
    # entries, so grayscale, palette or RGBA files would otherwise fail downstream.
    # The context manager closes the underlying file handle.
    with Image.open(args.in_file) as img:
        source = np.asarray(img.convert('RGB')) / 255.

    if args.scale is not None:
        target_shape = (
            round(source.shape[0] * args.scale),
            round(source.shape[1] * args.scale),
        )
    else:
        target_shape = tuple(args.size)
    if min(target_shape) <= 0:
        raise ValueError(f'Target size must be positive, got {target_shape}')

    # NOTE(security): pickle.load can execute arbitrary code - only load
    # checkpoints from trusted sources.
    with open(args.checkpoint, 'rb') as fh:
        check = pickle.load(fh)
    params, backbone, size = check['model'], check['backbone'], check['size']
    model = build_thera(3, backbone, size)

    out = process(source, model, params, target_shape, not args.no_ensemble)
    Image.fromarray(np.asarray(out)).save(args.out_file)
def parse_args(argv=None) -> Namespace:
    """Parse command-line arguments for ``main``.

    Args:
        argv: Argument list to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        Namespace with ``in_file``, ``out_file``, ``scale``, ``size``,
        ``checkpoint`` and ``no_ensemble``.
    """
    parser = ArgumentParser()
    parser.add_argument('in_file', help='Input image path')
    parser.add_argument('out_file', help='Output image path')
    # Exactly one target specification is needed; let argparse report a usage
    # error instead of failing later in main().
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument('--scale', type=float, help='Scale factor for super-resolution')
    target.add_argument('--size', type=int, nargs=2, metavar=('H', 'W'),
                        help='Target size (h, w), mutually exclusive with --scale')
    # main() always opens this path, so omitting it can never succeed.
    parser.add_argument('--checkpoint', required=True, help='Path to checkpoint file')
    parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the CLI and run the pipeline.
    main(parse_args())