import logging
import math
from typing import Optional

import numpy as np
import torch
from tiler import Tiler, Merger

from pytorch_caney.processing import (
    normalize, global_standardization, local_standardization,
    standardize_image
)
__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch"
__email__ = "jordan.a.caraballo-vega@nasa.gov"
__status__ = "Production"
# ---------------------------------------------------------------------------
# module inference
#
# Data segmentation and prediction functions.
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Module Methods
# ---------------------------------------------------------------------------
def sliding_window_tiler_multiclass(
xraster,
model,
n_classes: int,
img_size: int,
pad_style: str = 'reflect',
overlap: float = 0.50,
constant_value: int = 600,
batch_size: int = 1024,
threshold: float = 0.50,
        standardization: Optional[str] = None,
        mean=None,
        std=None,
        normalize: float = 1.0,  # note: shadows the imported normalize()
        rescale: Optional[str] = None,
window: str = 'triang', # 'overlap-tile'
probability_map: bool = False
):
"""
Sliding window using tiler.
"""
    tile_channels = xraster.shape[-1]
    logging.info('Standardizing: %s', standardization)
    # n_classes matches the channel depth of the model's output layer
tiler_image = Tiler(
data_shape=xraster.shape,
tile_shape=(img_size, img_size, tile_channels),
channel_dimension=-1,
overlap=overlap,
mode=pad_style,
constant_value=constant_value
)
# Define the tiler and merger based on the output size of the prediction
tiler_mask = Tiler(
data_shape=(xraster.shape[0], xraster.shape[1], n_classes),
tile_shape=(img_size, img_size, n_classes),
channel_dimension=-1,
overlap=overlap,
mode=pad_style,
constant_value=constant_value
)
merger = Merger(tiler=tiler_mask, window=window)
# xraster = normalize_image(xraster, normalize)
# Iterate over the data in batches
for batch_id, batch_i in tiler_image(xraster, batch_size=batch_size):
# Standardize
batch = batch_i.copy()
if standardization is not None:
for item in range(batch.shape[0]):
batch[item, :, :, :] = standardize_image(
batch[item, :, :, :], standardization, mean, std)
input_batch = batch.astype('float32')
input_batch_tensor = torch.from_numpy(input_batch)
input_batch_tensor = input_batch_tensor.transpose(-1, 1)
# input_batch_tensor = input_batch_tensor.cuda(non_blocking=True)
        with torch.no_grad():
            y_batch = model(input_batch_tensor)
        # move predictions back to channels-last numpy before merging
        y_batch = y_batch.transpose(1, -1).cpu().numpy()
        merger.add_batch(batch_id, batch_size, y_batch)
prediction = merger.merge(unpad=True)
if not probability_map:
if prediction.shape[-1] > 1:
prediction = np.argmax(prediction, axis=-1)
else:
prediction = np.squeeze(
np.where(prediction > threshold, 1, 0).astype(np.int16)
)
else:
prediction = np.squeeze(prediction)
return prediction
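
# Example usage of sliding_window_tiler_multiclass (a minimal sketch, not a
# prescribed workflow): `trained_model` is a hypothetical torch.nn.Module
# mapping (N, C, H, W) inputs to (N, n_classes, H, W) scores, and `raster`
# an (H, W, C) numpy array.
#
#   raster = np.random.rand(1024, 1024, 4)
#   mask = sliding_window_tiler_multiclass(
#       raster, trained_model, n_classes=5, img_size=256, overlap=0.5)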
# --------------------------- Segmentation Functions ----------------------- #
def segment(image, model='model.h5', tile_size=256, channels=6,
norm_data=[], bsize=8):
"""
Applies a semantic segmentation model to an image. Ideal for non-scene
imagery. Leaves artifacts in boundaries if no post-processing is done.
:param image: image to classify (numpy array)
:param model: loaded model object
:param tile_size: tile size of patches
:param channels: number of channels
:param norm_data: numpy array with mean and std data
:param bsize: number of patches to predict at the same time
return numpy array with classified mask
"""
# Create blank array to store predicted label
seg = np.zeros((image.shape[0], image.shape[1]))
for i in range(0, image.shape[0], int(tile_size)):
for j in range(0, image.shape[1], int(tile_size)):
# If edge of tile beyond image boundary, shift it to boundary
if i + tile_size > image.shape[0]:
i = image.shape[0] - tile_size
if j + tile_size > image.shape[1]:
j = image.shape[1] - tile_size
            # Extract and normalize tile
tile = normalize(
image[i: i + tile_size, j: j + tile_size, :].astype(float),
norm_data
)
out = model.predict(
tile.reshape(
(1, tile.shape[0], tile.shape[1], tile.shape[2])
).astype(float),
                batch_size=bsize
)
            out = out.argmax(axis=3)  # per-pixel argmax over the class axis
out = out.reshape(tile_size, tile_size) # reshape to tile size
seg[i: i + tile_size, j: j + tile_size] = out
return seg
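
# Example usage of segment (a sketch under the assumption that `keras_model`
# is a loaded Keras model and `stats` holds its mean/std normalization data;
# both names are placeholders, not part of this module). `image` is an
# (H, W, C) numpy array.
#
#   mask = segment(image, model=keras_model, tile_size=256,
#                  channels=6, norm_data=stats, bsize=8)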
def segment_binary(image, model='model.h5', norm_data=[],
tile_size=256, channels=6, bsize=8
):
"""
Applies binary semantic segmentation model to an image. Ideal for non-scene
imagery. Leaves artifacts in boundaries if no post-processing is done.
:param image: image to classify (numpy array)
:param model: loaded model object
:param tile_size: tile size of patches
:param channels: number of channels
:param norm_data: numpy array with mean and std data
return numpy array with classified mask
"""
# Create blank array to store predicted label
seg = np.zeros((image.shape[0], image.shape[1]))
for i in range(0, image.shape[0], int(tile_size)):
for j in range(0, image.shape[1], int(tile_size)):
# If edge of tile beyond image boundary, shift it to boundary
if i + tile_size > image.shape[0]:
i = image.shape[0] - tile_size
if j + tile_size > image.shape[1]:
j = image.shape[1] - tile_size
            # Extract and normalize tile
tile = normalize(
image[i:i + tile_size, j:j + tile_size, :].astype(float),
norm_data
)
out = model.predict(
tile.reshape(
(1, tile.shape[0], tile.shape[1], tile.shape[2])
).astype(float),
batch_size=bsize
)
out[out >= 0.5] = 1
out[out < 0.5] = 0
out = out.reshape(tile_size, tile_size) # reshape to tile size
seg[i:i + tile_size, j:j + tile_size] = out
return seg
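
# Example usage of segment_binary (same hypothetical `keras_model` and
# `stats` as the segment sketch above; the returned mask holds 0/1 values
# from the 0.5 threshold):
#
#   binary_mask = segment_binary(image, model=keras_model,
#                                norm_data=stats, tile_size=256)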
def pad_image(img, target_size):
    """
    Pad an image up to the target size with zeros
    on the bottom and right edges.
    """
    rows_missing = max(target_size - img.shape[0], 0)
    cols_missing = max(target_size - img.shape[1], 0)
    padded_img = np.pad(
        img, ((0, rows_missing), (0, cols_missing), (0, 0)), 'constant'
    )
    return padded_img
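
# Example: pad_image(np.ones((200, 240, 3)), 256) returns a (256, 256, 3)
# array; zeros are appended on the bottom and right edges only.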
def predict_sliding(image, model='', stand_method='local',
stand_strategy='per-batch', stand_data=[],
tile_size=256, nclasses=6, overlap=0.25, spline=[]
):
"""
Predict on tiles of exactly the network input shape.
This way nothing gets squeezed.
"""
model.eval()
stride = math.ceil(tile_size * (1 - overlap))
tile_rows = max(
int(math.ceil((image.shape[0] - tile_size) / stride) + 1), 1
) # strided convolution formula
tile_cols = max(
int(math.ceil((image.shape[1] - tile_size) / stride) + 1), 1
)
logging.info("Need %i x %i prediction tiles @ stride %i px" %
(tile_cols, tile_rows, stride)
)
full_probs = np.zeros((image.shape[0], image.shape[1], nclasses))
count_predictions = np.zeros((image.shape[0], image.shape[1], nclasses))
tile_counter = 0
for row in range(tile_rows):
for col in range(tile_cols):
x1 = int(col * stride)
y1 = int(row * stride)
x2 = min(x1 + tile_size, image.shape[1])
y2 = min(y1 + tile_size, image.shape[0])
x1 = max(int(x2 - tile_size), 0)
y1 = max(int(y2 - tile_size), 0)
img = image[y1:y2, x1:x2]
padded_img = pad_image(img, tile_size)
tile_counter += 1
padded_img = np.expand_dims(padded_img, 0)
if stand_method == 'local':
imgn = local_standardization(
padded_img, ndata=stand_data, strategy=stand_strategy
)
elif stand_method == 'global':
imgn = global_standardization(
padded_img, strategy=stand_strategy
)
else:
imgn = padded_img
imgn = imgn.astype('float32')
imgn_tensor = torch.from_numpy(imgn)
imgn_tensor = imgn_tensor.transpose(-1, 1)
            with torch.no_grad():
                padded_prediction = model(imgn_tensor)
            # (1, n_classes, W, H) -> (H, W, n_classes)
            padded_prediction = padded_prediction.squeeze(0)
            padded_prediction = padded_prediction.transpose(0, -1).numpy()
prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :]
count_predictions[y1:y2, x1:x2] += 1
full_probs[y1:y2, x1:x2] += prediction # * spline
# average the predictions in the overlapping regions
full_probs /= count_predictions
return full_probs
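
# Example usage of predict_sliding (a sketch; `torch_model` is a hypothetical
# torch.nn.Module and `stats` its standardization data):
#
#   probs = predict_sliding(image, model=torch_model, stand_method='local',
#                           stand_data=stats, tile_size=256, nclasses=6)
#   mask = np.argmax(probs, axis=-1)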
def predict_sliding_binary(image, model='model.h5', tile_size=256,
nclasses=6, overlap=1/3, norm_data=[]
):
"""
Predict on tiles of exactly the network input shape.
This way nothing gets squeezed.
"""
stride = math.ceil(tile_size * (1 - overlap))
tile_rows = max(
int(math.ceil((image.shape[0] - tile_size) / stride) + 1), 1
) # strided convolution formula
tile_cols = max(
int(math.ceil((image.shape[1] - tile_size) / stride) + 1), 1
)
logging.info("Need %i x %i prediction tiles @ stride %i px" %
(tile_cols, tile_rows, stride)
)
full_probs = np.zeros((image.shape[0], image.shape[1], nclasses))
count_predictions = np.zeros((image.shape[0], image.shape[1], nclasses))
tile_counter = 0
for row in range(tile_rows):
for col in range(tile_cols):
x1 = int(col * stride)
y1 = int(row * stride)
x2 = min(x1 + tile_size, image.shape[1])
y2 = min(y1 + tile_size, image.shape[0])
x1 = max(int(x2 - tile_size), 0)
y1 = max(int(y2 - tile_size), 0)
img = image[y1:y2, x1:x2]
padded_img = pad_image(img, tile_size)
tile_counter += 1
imgn = normalize(padded_img, norm_data)
imgn = imgn.astype('float32')
padded_prediction = model.predict(np.expand_dims(imgn, 0))[0]
prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :]
count_predictions[y1:y2, x1:x2] += 1
full_probs[y1:y2, x1:x2] += prediction
# average the predictions in the overlapping regions
full_probs /= count_predictions
full_probs[full_probs >= 0.8] = 1
full_probs[full_probs < 0.8] = 0
return full_probs.reshape((image.shape[0], image.shape[1]))
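
# Example usage of predict_sliding_binary (a sketch with the same
# hypothetical `keras_model` and `stats` as above; the output is a 0/1
# mask thresholded at 0.8):
#
#   binary_mask = predict_sliding_binary(image, model=keras_model,
#                                        tile_size=256, norm_data=stats)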
def predict_windowing(x, model, stand_method='local',
stand_strategy='per-batch', stand_data=[],
patch_sz=160, n_classes=5, b_size=128, spline=[]
):
    """
    Predict by tiling the input into an extended, mirror-padded grid of
    patches, running batched inference, and reassembling the patches.
    """
img_height = x.shape[0]
img_width = x.shape[1]
n_channels = x.shape[2]
# make extended img so that it contains integer number of patches
npatches_vertical = math.ceil(img_height / patch_sz)
npatches_horizontal = math.ceil(img_width / patch_sz)
extended_height = patch_sz * npatches_vertical
extended_width = patch_sz * npatches_horizontal
ext_x = np.zeros(
shape=(extended_height, extended_width, n_channels), dtype=np.float32
)
# fill extended image with mirrors:
ext_x[:img_height, :img_width, :] = x
for i in range(img_height, extended_height):
ext_x[i, :, :] = ext_x[2 * img_height - i - 1, :, :]
for j in range(img_width, extended_width):
ext_x[:, j, :] = ext_x[:, 2 * img_width - j - 1, :]
# now we assemble all patches in one array
patches_list = []
for i in range(0, npatches_vertical):
for j in range(0, npatches_horizontal):
x0, x1 = i * patch_sz, (i + 1) * patch_sz
y0, y1 = j * patch_sz, (j + 1) * patch_sz
patches_list.append(ext_x[x0:x1, y0:y1, :])
patches_array = np.asarray(patches_list)
# normalization(patches_array, ndata)
if stand_method == 'local': # apply local zero center standardization
patches_array = local_standardization(
patches_array, ndata=stand_data, strategy=stand_strategy
)
elif stand_method == 'global': # apply global zero center standardization
patches_array = global_standardization(
patches_array, strategy=stand_strategy
)
# predictions:
patches_predict = model.predict(patches_array, batch_size=b_size)
prediction = np.zeros(
shape=(extended_height, extended_width, n_classes), dtype=np.float32
)
logging.info("prediction shape: ", prediction.shape)
    for k in range(patches_predict.shape[0]):
        i = k // npatches_horizontal
        j = k % npatches_horizontal
        x0, x1 = i * patch_sz, (i + 1) * patch_sz
        y0, y1 = j * patch_sz, (j + 1) * patch_sz
        # apply the spline window only when one was provided; the default
        # empty list would otherwise fail to broadcast
        if len(spline) > 0:
            prediction[x0:x1, y0:y1, :] = \
                patches_predict[k, :, :, :] * spline
        else:
            prediction[x0:x1, y0:y1, :] = patches_predict[k, :, :, :]
return prediction[:img_height, :img_width, :]
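
# Example usage of predict_windowing (a sketch; the 2-D spline window below
# is one plausible choice built with scipy, which this module does not
# itself import):
#
#   from scipy.signal.windows import triang
#   spline = np.outer(triang(160), triang(160))[..., None]  # (160, 160, 1)
#   probs = predict_windowing(x, keras_model, stand_data=stats,
#                             patch_sz=160, n_classes=5, spline=spline)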
# ---------------------------------------------------------------------------
# module inference Unit Tests
# ---------------------------------------------------------------------------
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
# Add unit tests here
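    # Minimal smoke test for pad_image; the model-driven functions above
    # require a trained model and are not exercised here (assumed out of
    # scope for this placeholder block).
    dummy = np.zeros((200, 240, 3), dtype=np.float32)
    padded = pad_image(dummy, 256)
    assert padded.shape == (256, 256, 3)
    logging.info('pad_image smoke test passed, shape: %s', padded.shape)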