|
import numpy as np |
|
import math |
|
|
|
|
|
def _tile_starts(length, patch, overlapping):
    """Return the distinct, ascending tile start offsets along one axis."""
    stride = patch - overlapping
    # Clamp to 0 so an axis shorter than the patch yields a single tile at
    # offset 0 instead of a negative (wrap-around) offset.
    last_start = max(length - patch, 0)
    starts = []
    for i in range(math.ceil(length / stride)):
        start = min(i * stride, last_start)
        # Clamping can repeat the last offset; skip repeats so the model is
        # never run twice on the same tile.
        if not starts or start != starts[-1]:
            starts.append(start)
    return starts


def _kept_span(start, tile_len, length, margin):
    """Return the (begin, end) range of a tile that is copied to the output."""
    # Tiles touching an image border keep that side in full; interior sides
    # drop `margin` pixels, which the neighbouring tile provides instead.
    begin = 0 if start == 0 else margin
    end = tile_len if start + tile_len >= length else tile_len - margin
    return begin, end


def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56), scale=2):
    """
    Run a super-resolution model tile by tile and stitch the results.

    The image is cut into overlapping patches, each patch is passed through
    the model, and the central part of every upscaled patch is written into
    the output. Half of the overlap is discarded on each interior side of a
    patch, so only patch sides lying on the image border are kept in full.

    Parameters:
    - session: an ONNX Runtime session object that contains the
      super-resolution model; only its first input and first output are used
    - lr: the low-resolution image, indexed as (batch, height, width, channels)
    - overlapping: the number of pixels to overlap between adjacent patches;
      must satisfy 0 <= overlapping < min(patch_size)
    - patch_size: a (height, width) pair that specifies the size of each
      patch; an image axis shorter than the patch is processed as one
      smaller tile along that axis
    - scale: the integer upscaling factor of the model (default 2)

    Returns:
    - a numpy array of shape (1, scale*height, scale*width, 3), in numpy's
      default float dtype, that represents the enhanced image

    Raises:
    - ValueError: if overlapping is negative or not smaller than both patch
      dimensions (the tile stride would not be positive)
    """
    patch_h, patch_w = patch_size
    if not 0 <= overlapping < min(patch_h, patch_w):
        raise ValueError(
            "overlapping must be >= 0 and smaller than both patch dimensions"
        )

    _, h, w, _ = lr.shape
    sr = np.zeros((1, scale * h, scale * w, 3))

    # Loop-invariant values, computed once instead of once per tile.
    input_name = session.get_inputs()[0].name
    margin = overlapping // 2
    tile_h = min(patch_h, h)
    tile_w = min(patch_w, w)
    col_starts = _tile_starts(w, patch_w, overlapping)

    for h_idx in _tile_starts(h, patch_h, overlapping):
        top, bottom = _kept_span(h_idx, tile_h, h, margin)
        for w_idx in col_starts:
            left, right = _kept_span(w_idx, tile_w, w, margin)

            tiling_lr = lr[..., h_idx:h_idx + tile_h, w_idx:w_idx + tile_w, :]
            sr_tiling = session.run(None, {input_name: tiling_lr})[0]

            # Output coordinates are the input coordinates multiplied by scale.
            sr[
                ...,
                scale * (h_idx + top):scale * (h_idx + bottom),
                scale * (w_idx + left):scale * (w_idx + right),
                :,
            ] = sr_tiling[
                ...,
                scale * top:scale * bottom,
                scale * left:scale * right,
                :,
            ]
    return sr