|
import logging |
|
import math |
|
import numpy as np |
|
|
|
import torch |
|
|
|
from tiler import Tiler, Merger |
|
|
|
from pytorch_caney.processing import normalize |
|
from pytorch_caney.processing import global_standardization |
|
from pytorch_caney.processing import local_standardization |
|
from pytorch_caney.processing import standardize_image |
|
|
|
__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" |
|
__email__ = "jordan.a.caraballo-vega@nasa.gov" |
|
__status__ = "Production" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sliding_window_tiler_multiclass(
            xraster,
            model,
            n_classes: int,
            img_size: int,
            pad_style: str = 'reflect',
            overlap: float = 0.50,
            constant_value: int = 600,
            batch_size: int = 1024,
            threshold: float = 0.50,
            standardization: str = None,
            mean=None,
            std=None,
            normalize: float = 1.0,
            rescale: str = None,
            window: str = 'triang',
            probability_map: bool = False
        ):
    """
    Tiled inference over a raster using ``tiler`` for splitting and merging.

    The raster is cut into ``img_size`` x ``img_size`` tiles, every batch of
    tiles is optionally standardized and fed to ``model``, and the per-tile
    outputs are blended back into one array by a ``Merger``.

    :param xraster: channels-last raster (the tilers are configured with
        ``channel_dimension=-1``); ``shape[0]`` and ``shape[1]`` give the
        extent of the output
    :param model: callable applied to a float32 torch tensor batch; its
        result must support ``transpose``
    :param n_classes: number of channels of the merged prediction
    :param img_size: side length of the square tiles
    :param pad_style: forwarded to ``Tiler`` as ``mode``
    :param overlap: forwarded to ``Tiler`` as ``overlap``
    :param constant_value: forwarded to ``Tiler`` as ``constant_value``
    :param batch_size: number of tiles per model call
    :param threshold: cut-off applied to single-channel predictions
    :param standardization: forwarded to ``standardize_image``; ``None``
        skips standardization
    :param mean: forwarded to ``standardize_image``
    :param std: forwarded to ``standardize_image``
    :param normalize: currently unused by this function
    :param rescale: currently unused by this function
    :param window: window name handed to ``Merger``
    :param probability_map: when True, return the merged values instead of
        class labels
    :return: numpy array; with ``probability_map=False`` either the argmax
        over the last axis (multi-channel) or an int16 0/1 mask
        (single channel), otherwise the squeezed merged prediction
    """
    tile_channels = xraster.shape[-1]
    logging.info('Standardizing: %s', standardization)

    # Tiler for the input raster.
    tiler_image = Tiler(
        data_shape=xraster.shape,
        tile_shape=(img_size, img_size, tile_channels),
        channel_dimension=-1,
        overlap=overlap,
        mode=pad_style,
        constant_value=constant_value
    )

    # Tiler describing the prediction; same spatial layout, n_classes deep.
    tiler_mask = Tiler(
        data_shape=(xraster.shape[0], xraster.shape[1], n_classes),
        tile_shape=(img_size, img_size, n_classes),
        channel_dimension=-1,
        overlap=overlap,
        mode=pad_style,
        constant_value=constant_value
    )

    merger = Merger(tiler=tiler_mask, window=window)

    for batch_id, batch_i in tiler_image(xraster, batch_size=batch_size):

        batch = batch_i.copy()

        if standardization is not None:
            if not np.issubdtype(batch.dtype, np.floating):
                # The in-place assignment below casts to ``batch``'s dtype,
                # so an integer batch would truncate the standardized
                # values. Promote it before standardizing.
                batch = batch.astype('float32')
            for item in range(batch.shape[0]):
                batch[item, :, :, :] = standardize_image(
                    batch[item, :, :, :], standardization, mean, std)

        input_batch_tensor = torch.from_numpy(batch.astype('float32'))
        # NOTE(review): transpose(-1, 1) swaps axes 1 and 3, so the model
        # sees (N, C, W, H) rather than (N, C, H, W). The inverse transpose
        # below restores the layout; confirm this matches training.
        input_batch_tensor = input_batch_tensor.transpose(-1, 1)

        with torch.no_grad():
            y_batch = model(input_batch_tensor)

        # Hand the merger a plain numpy array rather than a torch tensor.
        y_batch = y_batch.transpose(1, -1).detach().cpu().numpy()
        merger.add_batch(batch_id, batch_size, y_batch)

    prediction = merger.merge(unpad=True)

    if probability_map:
        return np.squeeze(prediction)

    if prediction.shape[-1] > 1:
        return np.argmax(prediction, axis=-1)

    return np.squeeze(
        np.where(prediction > threshold, 1, 0).astype(np.int16)
    )
|
|
|
|
|
|
|
|
|
def segment(image, model='model.h5', tile_size=256, channels=6,
            norm_data=None, bsize=8):
    """
    Applies a semantic segmentation model to an image. Ideal for non-scene
    imagery. Leaves artifacts in boundaries if no post-processing is done.

    The image is walked in steps of ``tile_size``; a tile that would run
    past the bottom or right edge is shifted back so it ends on the edge.

    :param image: image to classify (numpy array, channels last)
    :param model: loaded model object exposing ``predict``; the string
        default is only a placeholder and cannot be used as is
    :param tile_size: tile size of patches
    :param channels: number of channels (not used by this function)
    :param norm_data: mean and std data forwarded to ``normalize``;
        ``None`` is treated as an empty list
    :param bsize: ``batch_size`` handed to ``model.predict``
    :return: numpy array with classified mask

    NOTE(review): an image smaller than ``tile_size`` in either dimension
    produces a negative start index; that case is not handled here.
    """
    # ``None`` replaces the former mutable ``[]`` default.
    norm_data = [] if norm_data is None else norm_data

    n_rows, n_cols = image.shape[0], image.shape[1]
    seg = np.zeros((n_rows, n_cols))

    for row_start in range(0, n_rows, int(tile_size)):
        # Shift a tile that overruns the bottom edge back inside the image.
        row = min(row_start, n_rows - tile_size)
        for col_start in range(0, n_cols, int(tile_size)):
            # Same for a tile that overruns the right edge.
            col = min(col_start, n_cols - tile_size)

            tile = normalize(
                image[row:row + tile_size, col:col + tile_size, :]
                .astype(float),
                norm_data
            )
            out = model.predict(
                tile.reshape(
                    (1, tile.shape[0], tile.shape[1], tile.shape[2])
                ).astype(float),
                batch_size=bsize
            )
            # Per-pixel class index over the channel axis.
            out = out.argmax(axis=3).reshape(tile_size, tile_size)
            seg[row:row + tile_size, col:col + tile_size] = out
    return seg
|
|
|
|
|
def segment_binary(image, model='model.h5', norm_data=None,
                   tile_size=256, channels=6, bsize=8
                   ):
    """
    Applies binary semantic segmentation model to an image. Ideal for
    non-scene imagery. Leaves artifacts in boundaries if no post-processing
    is done.

    The image is walked in steps of ``tile_size``; a tile that would run
    past the bottom or right edge is shifted back so it ends on the edge.
    Model outputs are binarized at 0.5.

    :param image: image to classify (numpy array, channels last)
    :param model: loaded model object exposing ``predict``; the string
        default is only a placeholder and cannot be used as is
    :param norm_data: mean and std data forwarded to ``normalize``;
        ``None`` is treated as an empty list
    :param tile_size: tile size of patches
    :param channels: number of channels (not used by this function)
    :param bsize: ``batch_size`` handed to ``model.predict``
    :return: numpy array with classified mask

    NOTE(review): an image smaller than ``tile_size`` in either dimension
    produces a negative start index; that case is not handled here.
    """
    # ``None`` replaces the former mutable ``[]`` default.
    norm_data = [] if norm_data is None else norm_data

    n_rows, n_cols = image.shape[0], image.shape[1]
    seg = np.zeros((n_rows, n_cols))

    for row_start in range(0, n_rows, int(tile_size)):
        # Shift a tile that overruns the bottom edge back inside the image.
        row = min(row_start, n_rows - tile_size)
        for col_start in range(0, n_cols, int(tile_size)):
            # Same for a tile that overruns the right edge.
            col = min(col_start, n_cols - tile_size)

            tile = normalize(
                image[row:row + tile_size, col:col + tile_size, :]
                .astype(float),
                norm_data
            )
            out = model.predict(
                tile.reshape(
                    (1, tile.shape[0], tile.shape[1], tile.shape[2])
                ).astype(float),
                batch_size=bsize
            )
            # Binarize in place at 0.5.
            out[out >= 0.5] = 1
            out[out < 0.5] = 0
            seg[row:row + tile_size, col:col + tile_size] = out.reshape(
                tile_size, tile_size
            )
    return seg
|
|
|
|
|
def pad_image(img, target_size):
    """
    Grow an image to ``target_size`` x ``target_size`` with constant padding.

    Rows are appended at the bottom and columns on the right; the third
    axis is left untouched.

    :param img: 3-D numpy array whose first two axes are padded
    :param target_size: side length to pad the first two axes up to
    :return: padded copy of ``img``
    """
    height, width = img.shape[0], img.shape[1]
    pad_width = (
        (0, target_size - height),  # extra rows after the last row
        (0, target_size - width),   # extra columns after the last column
        (0, 0),                     # third axis is not padded
    )
    return np.pad(img, pad_width, 'constant')
|
|
|
|
|
def predict_sliding(image, model='', stand_method='local',
                    stand_strategy='per-batch', stand_data=None,
                    tile_size=256, nclasses=6, overlap=0.25, spline=None
                    ):
    """
    Predict on tiles of exactly the network input shape.
    This way nothing gets squeezed.

    Overlapping ``tile_size`` windows are slid over the image, each one is
    zero-padded to the full tile size if needed, standardized, run through
    the model, and the outputs are averaged where windows overlap.

    :param image: image to classify (numpy array, channels last)
    :param model: torch module; ``eval()`` is called on it and it is
        invoked on a float32 tensor (the string default is a placeholder)
    :param stand_method: ``'local'`` or ``'global'`` selects the
        standardization helper; any other value skips standardization
    :param stand_strategy: forwarded to the standardization helper
    :param stand_data: forwarded to ``local_standardization`` as ``ndata``;
        ``None`` is treated as an empty list
    :param tile_size: side length of the square tiles
    :param nclasses: number of channels of the returned array
    :param overlap: fraction of a tile shared with its neighbour
    :param spline: not used by this function
    :return: numpy array (rows, cols, nclasses) with averaged model output
    """
    # ``None`` replaces the former mutable ``[]`` default.
    stand_data = [] if stand_data is None else stand_data

    model.eval()
    height, width = image.shape[0], image.shape[1]

    stride = math.ceil(tile_size * (1 - overlap))
    tile_rows = max(
        int(math.ceil((height - tile_size) / stride) + 1), 1
    )
    tile_cols = max(
        int(math.ceil((width - tile_size) / stride) + 1), 1
    )
    logging.info(
        "Need %i x %i prediction tiles @ stride %i px",
        tile_cols, tile_rows, stride
    )

    full_probs = np.zeros((height, width, nclasses))
    # One channel is enough: the count is identical for every class and
    # broadcasts over the last axis in the final division.
    count_predictions = np.zeros((height, width, 1))

    for row in range(tile_rows):
        for col in range(tile_cols):
            # Clamp the window to the image, then pull its origin back so
            # edge windows keep the full tile size whenever possible.
            x2 = min(int(col * stride) + tile_size, width)
            y2 = min(int(row * stride) + tile_size, height)
            x1 = max(int(x2 - tile_size), 0)
            y1 = max(int(y2 - tile_size), 0)

            img = image[y1:y2, x1:x2]
            padded_img = np.expand_dims(pad_image(img, tile_size), 0)

            if stand_method == 'local':
                imgn = local_standardization(
                    padded_img, ndata=stand_data, strategy=stand_strategy
                )
            elif stand_method == 'global':
                imgn = global_standardization(
                    padded_img, strategy=stand_strategy
                )
            else:
                imgn = padded_img

            imgn_tensor = torch.from_numpy(imgn.astype('float32'))
            # NOTE(review): transpose(-1, 1) swaps axes 1 and 3, so the
            # model sees (1, C, W, H); the transpose after the forward
            # pass restores (H, W, C). Confirm this matches training.
            imgn_tensor = imgn_tensor.transpose(-1, 1)
            with torch.no_grad():
                padded_prediction = model(imgn_tensor)

            # Drop only the batch axis. Squeezing every singleton axis
            # would also remove the class axis of a single-class output
            # and break the 3-D indexing below.
            padded_prediction = padded_prediction[0].transpose(0, -1)
            padded_prediction = padded_prediction.detach().cpu().numpy()

            # Discard the part of the output that corresponds to padding.
            prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :]
            count_predictions[y1:y2, x1:x2] += 1
            full_probs[y1:y2, x1:x2] += prediction

    # Average where windows overlapped.
    full_probs /= count_predictions
    return full_probs
|
|
|
|
|
def predict_sliding_binary(image, model='model.h5', tile_size=256,
                           nclasses=6, overlap=1/3, norm_data=None,
                           threshold=0.8
                           ):
    """
    Predict on tiles of exactly the network input shape.
    This way nothing gets squeezed.

    Overlapping ``tile_size`` windows are slid over the image, each one is
    zero-padded to the full tile size if needed, normalized, run through
    ``model.predict``, averaged where windows overlap and finally binarized.

    :param image: image to classify (numpy array, channels last)
    :param model: loaded model object exposing ``predict``; the string
        default is only a placeholder and cannot be used as is
    :param tile_size: side length of the square tiles
    :param nclasses: number of accumulated channels; must be 1 because the
        result is reshaped to (rows, cols)
    :param overlap: fraction of a tile shared with its neighbour
    :param norm_data: mean and std data forwarded to ``normalize``;
        ``None`` is treated as an empty list
    :param threshold: averaged values at or above it become 1, others 0
    :return: numpy array (rows, cols) holding 0/1 values
    :raises ValueError: if ``nclasses`` is not 1
    """
    if nclasses != 1:
        # The final reshape to (rows, cols) can only succeed with a single
        # channel. Fail before running the model over every tile.
        raise ValueError(
            "predict_sliding_binary returns a single-channel mask; "
            "call it with nclasses=1 (got nclasses=%r)" % (nclasses,)
        )

    # ``None`` replaces the former mutable ``[]`` default.
    norm_data = [] if norm_data is None else norm_data

    height, width = image.shape[0], image.shape[1]

    stride = math.ceil(tile_size * (1 - overlap))
    tile_rows = max(
        int(math.ceil((height - tile_size) / stride) + 1), 1
    )
    tile_cols = max(
        int(math.ceil((width - tile_size) / stride) + 1), 1
    )
    logging.info(
        "Need %i x %i prediction tiles @ stride %i px",
        tile_cols, tile_rows, stride
    )

    full_probs = np.zeros((height, width, nclasses))
    # One channel is enough; it broadcasts in the final division.
    count_predictions = np.zeros((height, width, 1))

    for row in range(tile_rows):
        for col in range(tile_cols):
            # Clamp the window to the image, then pull its origin back so
            # edge windows keep the full tile size whenever possible.
            x2 = min(int(col * stride) + tile_size, width)
            y2 = min(int(row * stride) + tile_size, height)
            x1 = max(int(x2 - tile_size), 0)
            y1 = max(int(y2 - tile_size), 0)

            img = image[y1:y2, x1:x2]
            padded_img = pad_image(img, tile_size)

            imgn = normalize(padded_img, norm_data).astype('float32')
            padded_prediction = model.predict(np.expand_dims(imgn, 0))[0]

            # Discard the part of the output that corresponds to padding.
            prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :]
            count_predictions[y1:y2, x1:x2] += 1
            full_probs[y1:y2, x1:x2] += prediction

    # Average where windows overlapped, then binarize in place.
    full_probs /= count_predictions
    full_probs[full_probs >= threshold] = 1
    full_probs[full_probs < threshold] = 0
    return full_probs.reshape((height, width))
|
|
|
|
|
def predict_windowing(x, model, stand_method='local',
                      stand_strategy='per-batch', stand_data=None,
                      patch_sz=160, n_classes=5, b_size=128, spline=None
                      ):
    """
    Predict an image patch by patch on a non-overlapping grid.

    The image is mirror-padded on the bottom and right up to a multiple of
    ``patch_sz``, cut into ``patch_sz`` x ``patch_sz`` patches, optionally
    standardized, predicted in one ``model.predict`` call and stitched back
    together. The padding is cropped from the result.

    :param x: image to classify (3-D numpy array, channels last)
    :param model: loaded model object exposing ``predict``
    :param stand_method: ``'local'`` or ``'global'`` selects the
        standardization helper; any other value skips standardization
    :param stand_strategy: forwarded to the standardization helper
    :param stand_data: forwarded to ``local_standardization`` as ``ndata``;
        ``None`` is treated as an empty list
    :param patch_sz: side length of the square patches
    :param n_classes: number of channels of the returned array
    :param b_size: ``batch_size`` handed to ``model.predict``
    :param spline: weights multiplied into every predicted patch; ``None``
        or an empty sequence applies no weighting
    :return: float32 numpy array (rows, cols, n_classes)
    """
    # ``None`` replaces the former mutable ``[]`` defaults.
    stand_data = [] if stand_data is None else stand_data
    # An empty spline cannot be broadcast against a patch, so treat a
    # missing or empty one as "no weighting".
    weights = None if spline is None or np.size(spline) == 0 else spline

    img_height, img_width = x.shape[0], x.shape[1]

    npatches_vertical = math.ceil(img_height / patch_sz)
    npatches_horizontal = math.ceil(img_width / patch_sz)
    extended_height = patch_sz * npatches_vertical
    extended_width = patch_sz * npatches_horizontal

    # Mirror-pad the bottom and right edges up to a whole number of
    # patches. 'symmetric' repeats the edge row/column first and stays
    # valid when the padding is larger than the image itself.
    ext_x = np.pad(
        np.asarray(x, dtype=np.float32),
        (
            (0, extended_height - img_height),
            (0, extended_width - img_width),
            (0, 0),
        ),
        mode='symmetric'
    )

    # Cut the padded image into patches, row-major over the patch grid.
    patches_array = np.asarray([
        ext_x[top:top + patch_sz, left:left + patch_sz, :]
        for top in range(0, extended_height, patch_sz)
        for left in range(0, extended_width, patch_sz)
    ])

    if stand_method == 'local':
        patches_array = local_standardization(
            patches_array, ndata=stand_data, strategy=stand_strategy
        )
    elif stand_method == 'global':
        patches_array = global_standardization(
            patches_array, strategy=stand_strategy
        )

    patches_predict = model.predict(patches_array, batch_size=b_size)
    prediction = np.zeros(
        shape=(extended_height, extended_width, n_classes), dtype=np.float32
    )
    logging.info("prediction shape: %s", prediction.shape)

    # Put every predicted patch back at its grid position.
    for k in range(patches_predict.shape[0]):
        grid_row, grid_col = divmod(k, npatches_horizontal)
        top, left = grid_row * patch_sz, grid_col * patch_sz
        patch = patches_predict[k, :, :, :]
        if weights is not None:
            patch = patch * weights
        prediction[top:top + patch_sz, left:left + patch_sz, :] = patch

    # Crop the mirror padding.
    return prediction[:img_height, :img_width, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the module as a script only sets the root logger to INFO.
    logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|