import numpy as np
import torch
import rasterio
import xarray as xr
import rioxarray as rxr
import cv2
from transformers import SegformerForSemanticSegmentation
from tqdm import tqdm
from scipy.ndimage import grey_dilation
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from loguru import logger

from .viz_utils import alpha_composite


def resize(img, shape=None, scaling_factor=1., order='CHW'):
    """Resize an image to a target shape or by a given scaling factor (provide only one of the two)."""
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    assert shape is None or scaling_factor == 1., "Got both shape and scaling_factor. Please provide only one of them"

    # resize image
    if order == 'CHW':
        img = np.moveaxis(img, 0, -1)  # CHW -> HWC
    if shape is not None:
        img = cv2.resize(img, shape[::-1], interpolation=cv2.INTER_LINEAR)  # cv2 expects (W, H)
    else:
        img = cv2.resize(img, None, fx=scaling_factor, fy=scaling_factor, interpolation=cv2.INTER_LINEAR)
    # NB: cv2.resize returns a HW image if the input image is HW1: restore the C dimension
    if len(img.shape) == 2:
        img = img[..., None]
    if order == 'CHW':
        img = np.moveaxis(img, -1, 0)  # HWC -> CHW
    return img


def minimum_needed_padding(img_size, patch_size: int, stride: int):
    """
    Compute the minimum padding needed to make an image evenly tileable by patches
    of a given size extracted with a given stride.

    Args:
        img_size (tuple): the shape (H, W) of the image tensor
        patch_size (int): the size of the patches to extract
        stride (int): the stride to use when extracting patches

    Returns:
        tuple: the padding (pad_t, pad_b, pad_l, pad_r) needed to make the image
        tensor evenly tileable by the patch size with the given stride
    """
    img_size = np.array(img_size)
    pad = np.where(
        img_size <= patch_size,
        (patch_size - img_size) % patch_size,  # the % patch_size handles the case img_size = (0, 0)
        (stride - (img_size - patch_size)) % stride
    )
    pad_t, pad_l = pad // 2
    pad_b, pad_r = pad[0] - pad_t, pad[1] - pad_l
    return pad_t, pad_b, pad_l, pad_r


def pad(img, pad, order='CHW'):
    """Pad an image by the given pad values, in the format (pad_t, pad_b, pad_l, pad_r)."""
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    pad_t, pad_b, pad_l, pad_r = pad

    # pad image
    if order == 'HWC':
        padded_img = np.pad(img, ((pad_t, pad_b), (pad_l, pad_r), (0, 0)), mode='constant', constant_values=0)  # can also try mode='reflect'
    else:
        padded_img = np.pad(img, ((0, 0), (pad_t, pad_b), (pad_l, pad_r)), mode='constant', constant_values=0)  # can also try mode='reflect'
    if isinstance(img, torch.Tensor):
        padded_img = torch.tensor(padded_img)
    return padded_img


def extract_patches(img, patch_size=512, stride=256, order='CHW', only_return_idx=True):
    """Extract patches from an image; patch indexes are in the format (h_start, h_end, w_start, w_end)."""
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    if order == 'HWC':
        H, W = img.shape[:2]
    else:
        H, W = img.shape[1:]

    # compute the number of patches
    n_patches = ((H - patch_size) // stride + 1) * ((W - patch_size) // stride + 1)

    # extract patches
    patches = []
    patches_idx = []
    for i in range(0, H - patch_size + 1, stride):
        for j in range(0, W - patch_size + 1, stride):
            patches_idx.append((i, i + patch_size, j, j + patch_size))
            if not only_return_idx:
                if order == 'HWC':
                    patch = img[i:i + patch_size, j:j + patch_size, :]
                else:
                    patch = img[:, i:i + patch_size, j:j + patch_size]
                patches.append(patch)
    if only_return_idx:
        return patches_idx
    return patches, patches_idx
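
# A quick worked example of the two helpers above (a sketch with illustrative
# numbers, not part of the module API):
#
#   minimum_needed_padding((1000, 1000), patch_size=512, stride=256)
#   # -> (12, 12, 12, 12): pad = (256 - (1000 - 512)) % 256 = 24 per axis, split evenly
#
# After padding, H = W = 1024, so (1024 - 512) // 256 + 1 = 3 patch positions
# fit along each axis and extract_patches() returns 3 * 3 = 9 index tuples:
#
#   padded = np.zeros((3, 1024, 1024), dtype=np.uint8)
#   len(extract_patches(padded, patch_size=512, stride=256))  # -> 9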


def segment_batch(batch, model):
    """Run the model on a batch of patches and return per-pixel confidence scores."""
    # perform prediction
    with torch.no_grad():
        out = model(batch)  # (n_patches, 1, H, W) logits
        if isinstance(model, SegformerForSemanticSegmentation):
            out = upsample(out.logits, size=batch.shape[-2:])
    # apply sigmoid
    out = torch.sigmoid(out)  # logits -> confidence scores
    return out


def upsample(x, size):
    """Bilinearly upsample a 4D (N, C, H, W) tensor to the given spatial size."""
    return torch.nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)


def merge_patches(patches, patches_idx, rotate=False, canvas_shape=None, order='CHW'):
    """Merge (possibly overlapping) patches into a single image, averaging where patches overlap."""
    assert order in ['HWC', 'CHW'], f"Got unknown order '{order}', expected one of ['HWC','CHW']"
    assert len(patches) == len(patches_idx), f"Got {len(patches)} patches and {len(patches_idx)} indexes"
    if rotate:
        # undo the 90° rotations applied at prediction time: patches come in groups
        # of four (original, 90°, 180°, 270°), so i % 4 is the rotation count of patch i
        axes_to_rotate = (0, 1) if order == 'HWC' else (1, 2)
        patches = [np.rot90(p, -i, axes=axes_to_rotate) for i, p in enumerate(patches)]

    # if canvas_shape is None, infer it from patches_idx
    if canvas_shape is None:
        patches_idx_zipped = list(zip(*patches_idx))
        canvas_H = max(patches_idx_zipped[1])
        canvas_W = max(patches_idx_zipped[3])
    else:
        canvas_H, canvas_W = canvas_shape

    # initialize canvas
    dtype = patches[0].dtype
    if order == 'HWC':
        canvas_C = patches[0].shape[-1]
        canvas = np.zeros((canvas_H, canvas_W, canvas_C), dtype=dtype)  # HWC
        n_overlapping_patches = np.zeros((canvas_H, canvas_W, 1))
    else:
        canvas_C = patches[0].shape[0]
        canvas = np.zeros((canvas_C, canvas_H, canvas_W), dtype=dtype)  # CHW
        n_overlapping_patches = np.zeros((1, canvas_H, canvas_W))

    # merge patches
    for p, (t, b, l, r) in zip(patches, patches_idx):
        if order == 'HWC':
            canvas[t:b, l:r, :] += p
            n_overlapping_patches[t:b, l:r, 0] += 1
        else:
            canvas[:, t:b, l:r] += p
            n_overlapping_patches[0, t:b, l:r] += 1

    # compute average (out= keeps pixels not covered by any patch at 0,
    # instead of leaving them uninitialized)
    canvas = np.divide(canvas, n_overlapping_patches, out=np.zeros_like(canvas), where=(n_overlapping_patches != 0))
    return canvas
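
# Minimal sketch of the overlap averaging done by merge_patches() (illustrative
# values): two 2x2 patches overlapping on one column are summed on the canvas
# and divided by the per-pixel patch count, so the overlap holds their mean and
# pixels not covered by any patch stay 0.
#
#   patches = [np.full((1, 2, 2), 1.0), np.full((1, 2, 2), 3.0)]
#   patches_idx = [(0, 2, 0, 2), (0, 2, 1, 3)]
#   merge_patches(patches, patches_idx, canvas_shape=(2, 4))
#   # -> each canvas row is [1., 2., 3., 0.]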


def segment(img, model, patch_size=512, stride=256, scaling_factor=1., rotate=False, device=None, batch_size=16, verbose=False):
    """Segment an RGB(A) image with a segmentation model. Returns a probability map
    with values in [0, 1], where nodata pixels (alpha == 0) are set to -1."""
    # some checks
    assert isinstance(img, np.ndarray), f"Input must be a numpy array. Got {type(img)}"
    assert img.shape[0] in [3, 4], f"Input image must be formatted as CHW, with C = 3,4. Got a shape of {img.shape}"
    assert img.dtype == np.uint8, f"Input image must be a numpy array with dtype np.uint8. Got {img.dtype}"

    # prepare model for evaluation
    model = model.to(device)
    model.eval()

    # prepare alpha channel
    original_shape = img.shape
    if img.shape[0] == 3:
        # create dummy alpha channel
        alpha = np.full(original_shape[1:], 255, dtype=np.uint8)
    else:
        # extract alpha channel
        img, alpha = img[:3], img[3]

    # resize image
    img = resize(img, scaling_factor=scaling_factor)

    # pad image
    pad_t, pad_b, pad_l, pad_r = minimum_needed_padding(img.shape[1:], patch_size, stride)
    padded_img = pad(img, pad=(pad_t, pad_b, pad_l, pad_r))
    padded_shape = padded_img.shape

    # extract patches indexes
    patches_idx = extract_patches(padded_img, patch_size=patch_size, stride=stride)

    ### segment
    masks = []
    masks_idx = []
    batch = []
    for i, p_idx in enumerate(tqdm(patches_idx, disable=not verbose, desc="Predicting...", total=len(patches_idx))):
        t, b, l, r = p_idx
        # extract patch
        patch = padded_img[:, t:b, l:r]
        # consider patch only if it is valid (i.e. not all black or all white)
        if np.any(patch != 0) and np.any(patch != 255):
            # convert patch to torch.tensor with float32 values in [0,1] (as required by torch)
            patch = torch.tensor(patch).float() / 255.
            # normalize patch with ImageNet mean and std
            patch = (patch - torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)) / torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            # add patch to batch
            batch.append(patch)
            masks_idx.append(p_idx)
            # (optional) for each patch extracted, consider also its rotated versions:
            # rotate by 90° each time, so the versions are at 90°, 180° and 270°
            # (matching the un-rotation done in merge_patches)
            if rotate:
                rotated = patch
                for _ in range(3):
                    rotated = torch.rot90(rotated, 1, dims=[1, 2])
                    batch.append(rotated)
                    masks_idx.append(p_idx)
        # if the batch is full (or this is the last patch), perform prediction;
        # the len(batch) > 0 guard avoids stacking an empty batch when the last
        # patches are all invalid
        if len(batch) > 0 and (len(batch) >= batch_size or i == len(patches_idx) - 1):
            # move batch to GPU
            batch = torch.stack(batch).to(device)
            # perform prediction
            out = segment_batch(batch, model)
            # append predictions to masks
            masks.append(out.cpu().numpy())
            # reset batch
            batch = []

    # concatenate predictions
    masks = np.concatenate(masks)  # (n_patches, 1, H, W)

    # merge patches
    mask = merge_patches(masks, masks_idx, rotate=rotate, canvas_shape=padded_shape[1:])  # (1, H, W)

    # undo padding
    mask = mask[:, pad_t:padded_shape[1] - pad_b, pad_l:padded_shape[2] - pad_r]

    # resize mask to original shape
    mask = resize(mask, shape=original_shape[1:])

    # apply alpha channel, i.e. set to -1 the pixels where alpha is 0
    mask = np.where(alpha == 0, -1, mask)
    return mask.squeeze()
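
# Usage sketch for segment() (illustrative: 'img' is any CHW uint8 RGB(A)
# numpy array, 'model' a segmentation network as loaded in compute_mask() below):
#
#   score_map = segment(img, model, patch_size=512, stride=256,
#                       device=torch.device('cpu'), verbose=True)
#   # score_map is HxW: confidence scores in [0, 1], -1 on nodata pixels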


def sliding_window_avg_pooling(img, window, granularity, alpha=None, min_nonblank_pixels=0., order="HWC", normalize=False, return_min_max=False, verbose=False):
    """Smooth a single-channel map by assigning each sliding window its mean value
    and averaging overlapping windows."""
    assert isinstance(img, np.ndarray), f'Input image must be a numpy array. Got {type(img)}'
    if order == "HWC":
        assert img.shape[2] == 1, f'Input image must be formatted as HWC, with C = 1. Got a shape of {img.shape}'
    elif order == "CHW":
        assert img.shape[0] == 1, f'Input image must be formatted as CHW, with C = 1. Got a shape of {img.shape}'

    # check if alpha channel was given, and cast it to np.float32 with values in [0,1]
    if alpha is not None:
        assert img.shape == alpha.shape, f'The shape of input image {img.shape} and alpha channel {alpha.shape} do not match'
        if alpha.dtype == np.uint8:
            alpha = (alpha / 255).astype(np.float32)
        elif alpha.dtype == bool:
            alpha = alpha.astype(np.float32)
    else:
        alpha = np.ones_like(img, dtype=np.float32)

    # compute threshold
    thresh = min_nonblank_pixels * window**2

    # extract patches idxs
    patches_idx = extract_patches(img, patch_size=window, stride=granularity, order=order, only_return_idx=True)

    # initialize canvas
    canvas = np.zeros_like(img, dtype=np.float32)
    n_overlapping_patches = np.zeros_like(img, dtype=np.float32)

    # cycle through patches idxs
    for t, b, l, r in tqdm(patches_idx, disable=not verbose):
        p_a = alpha[t:b, l:r]
        n_valid_pixels = p_a.sum()
        # keep only if it has more than min_nonblank_pixels
        if n_valid_pixels <= thresh:
            continue
        # compute average patch value (i.e. density inside the patch)
        p = img[t:b, l:r]
        p_density = (p * p_a).sum() / n_valid_pixels
        # add to canvas
        canvas[t:b, l:r] += p_density
        n_overlapping_patches[t:b, l:r] += 1

    # compute average density (out= keeps pixels not covered by any patch at 0)
    density_map = np.divide(canvas, n_overlapping_patches, out=np.zeros_like(canvas), where=(n_overlapping_patches != 0))

    # apply alpha
    density_map = density_map * alpha

    if normalize:
        # [0,1]-normalize
        density_map_min = density_map.min()
        density_map_max = density_map.max()
        density_map = (density_map - density_map_min) / (density_map_max - density_map_min)
        if return_min_max:
            return density_map, density_map_min, density_map_max
    return density_map


def compute_vndvi(
    raster: np.ndarray,
    mask: np.ndarray,
    dilate_rows=True,
    window_size=360,
    granularity=45,
):
    """Compute a vNDVI heatmap of the vineyard, pooled and rendered separately for rows and interrows."""
    assert isinstance(raster, np.ndarray)
    assert isinstance(mask, np.ndarray)
    assert len(raster.shape) == 3  # CHW
    assert len(mask.shape) == 2  # HW
    assert raster.shape[0] in [3, 4]  # RGB or RGBA

    # CHW -> HWC
    raster = raster.transpose(1, 2, 0)

    # Extract channels
    _raster = raster.astype(np.float32) / 255  # convert to float32 in [0,1]
    R, G, B = _raster[:, :, 0], _raster[:, :, 1], _raster[:, :, 2]
    # To avoid division by 0 due to the negative exponents, replace 0 with 1 in the R and B channels
    R = np.where(R == 0, 1, R)
    B = np.where(B == 0, 1, B)

    # Mask has values: 0=interrows, 255=rows, 1=nodata
    # Get masks for the rows and interrows
    mask_rows = (mask == 255)
    mask_interrows = (mask == 0)
    mask_valid = mask_rows | mask_interrows

    # Compute vndvi
    vndvi = 0.5268 * (R**(-0.1294) * G**(0.3389) * B**(-0.3118))

    # Clip values to [0,1]
    vndvi = np.clip(vndvi, 0, 1)

    # Compute 10th and 90th percentile on the whole vineyard vndvi heatmap
    vndvi_perc10, vndvi_perc90 = np.percentile(vndvi[mask_valid], [10, 90])

    # Clip values between 10th and 90th percentile
    vndvi_clipped = np.clip(vndvi, vndvi_perc10, vndvi_perc90)

    # Perform sliding window average pooling to smooth the heatmap
    # NB: the window takes into account only the rows
    vndvi_rows_clipped_pooled = sliding_window_avg_pooling(
        np.where(mask_rows, vndvi_clipped, 0)[..., None],
        window=int(window_size / 4),
        granularity=granularity,
        alpha=mask_rows[..., None],
        min_nonblank_pixels=0.0,
        verbose=True,
    )

    # Same, but for interrows
    vndvi_interrows_clipped_pooled = sliding_window_avg_pooling(
        np.where(mask_interrows, vndvi_clipped, 0)[..., None],
        window=int(window_size / 4),
        granularity=granularity,
        alpha=mask_interrows[..., None],
        min_nonblank_pixels=0.0,
        verbose=True,
    )

    # Apply dilation to the rows mask
    dil_factor = int(window_size / 60)
    mask_rows_dilated = grey_dilation(mask_rows, size=(dil_factor, dil_factor))
    vndvi_rows_clipped_pooled_dilated = grey_dilation(vndvi_rows_clipped_pooled, size=(dil_factor, dil_factor, 1))

    # For visualization purposes, normalize with vndvi_perc10 and vndvi_perc90
    # (because we want vndvi_perc10 to be the first color of the colormap and
    # vndvi_perc90 to be the last)
    vndvi_rows_clipped_pooled_normalized = (vndvi_rows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
    vndvi_rows_clipped_pooled_dilated_normalized = (vndvi_rows_clipped_pooled_dilated - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)
    vndvi_interrows_clipped_pooled_normalized = (vndvi_interrows_clipped_pooled - vndvi_perc10) / (vndvi_perc90 - vndvi_perc10)

    # for visualization
    vndvi_rows_img = alpha_composite(
        raster,
        vndvi_rows_clipped_pooled_dilated_normalized if dilate_rows else vndvi_rows_clipped_pooled_normalized,
        opacity=1.0,
        colormap='RdYlGn',
        alpha_image=np.zeros_like(raster[:, :, [0]]),
        alpha_mask=mask_rows_dilated[..., None] if dilate_rows else mask_rows[..., None],
    )  # HW4 RGBA
    vndvi_interrows_img = alpha_composite(
        raster,
        vndvi_interrows_clipped_pooled_normalized,
        opacity=1.0,
        colormap='RdYlGn',
        alpha_image=np.zeros_like(raster[:, :, [0]]),
        alpha_mask=mask_interrows[..., None],
    )  # HW4 RGBA

    # add colorbar
    # fig_rows, ax = plt.subplots(1, 1, figsize=(10, 10))
    # divider = make_axes_locatable(ax)
    # cax = divider.append_axes('right', size='5%', pad=0.15)
    # ax.imshow(vndvi_rows_img)
    # fig_rows.colorbar(
    #     mappable=mpl.cm.ScalarMappable(
    #         norm=mpl.colors.Normalize(vmin=vndvi_perc10, vmax=vndvi_perc90),
    #         cmap='RdYlGn'),
    #     cax=cax,
    #     orientation='vertical',
    #     label='vNDVI',
    #     shrink=1)
    # fig_interrows, ax = plt.subplots(1, 1, figsize=(10, 10))
    # divider = make_axes_locatable(ax)
    # cax = divider.append_axes('right', size='5%', pad=0.15)
    # ax.imshow(vndvi_interrows_img)
    # fig_interrows.colorbar(
    #     mappable=mpl.cm.ScalarMappable(
    #         norm=mpl.colors.Normalize(vmin=vndvi_perc10, vmax=vndvi_perc90),
    #         cmap='RdYlGn'),
    #     cax=cax,
    #     orientation='vertical',
    #     label='vNDVI',
    #     shrink=1)
    # return fig_rows, fig_interrows

    return vndvi_rows_img, vndvi_interrows_img
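
# A short numeric check of the vNDVI formula above (illustrative values): a
# green, vegetation-like pixel scores high, a soil-like pixel noticeably lower.
#
#   R, G, B = 0.2, 0.6, 0.2  # vegetation-like pixel
#   0.5268 * (R**-0.1294 * G**0.3389 * B**-0.3118)  # -> ~0.90
#   R, G, B = 0.5, 0.4, 0.3  # soil-like pixel
#   0.5268 * (R**-0.1294 * G**0.3389 * B**-0.3118)  # -> ~0.61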


def compute_vdi(
    raster: np.ndarray,
    mask: np.ndarray,
    window_size=360,
    granularity=40,
):
    """Compute a VDI heatmap, i.e. the density of row pixels inside a sliding window
    over the valid vineyard area."""
    # CHW -> HWC
    raster = raster.transpose(1, 2, 0)

    # Mask has values: 0=interrows, 255=rows, 1=nodata
    # Get masks for the rows and interrows
    mask_rows = (mask == 255)
    mask_interrows = (mask == 0)
    mask_valid = mask_rows | mask_interrows

    # compute vdi
    vdi, vdi_min, vdi_max = sliding_window_avg_pooling(
        mask_rows[..., None],
        window=window_size,
        granularity=granularity,
        alpha=mask_valid[..., None],
        min_nonblank_pixels=0.9,
        normalize=True,
        return_min_max=True,
        verbose=True,
    )

    # for visualization
    vdi_img = alpha_composite(
        raster,
        vdi,
        opacity=1,
        colormap='jet_r',
        alpha_image=mask_valid[..., None],
        alpha_mask=mask_valid[..., None],
    )

    # add colorbar
    # fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # divider = make_axes_locatable(ax)
    # cax = divider.append_axes('right', size='5%', pad=0.15)
    # ax.imshow(vdi_img)
    # fig.colorbar(
    #     mappable=mpl.cm.ScalarMappable(
    #         norm=mpl.colors.Normalize(vmin=vdi_min, vmax=vdi_max),
    #         cmap='jet_r'),
    #     cax=cax,
    #     orientation='vertical',
    #     label='VDI',
    #     shrink=1)
    # return fig

    return vdi_img
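
# Minimal numeric sketch of the pooling step at the heart of compute_vdi()
# (illustrative values): pooling a binary rows mask gives, for each pixel, the
# fraction of row pixels averaged over the windows covering it.
#
#   rows = np.zeros((4, 4, 1), dtype=np.float32)
#   rows[:, :2] = 1.0  # left half is "rows"
#   sliding_window_avg_pooling(rows, window=4, granularity=4)
#   # -> a 4x4x1 map filled with 0.5 (one window, half of its pixels are rows)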


def compute_mask(
    raster: np.ndarray,
    model: torch.nn.Module,
    patch_size=512,
    stride=256,
    scaling_factor=None,
    rotate=False,
    batch_size=16,
):
    """Segment a vineyard raster into rows (255), interrows (0) and nodata (1) pixels."""
    assert isinstance(raster, np.ndarray), f'Input raster must be a numpy array. Got {type(raster)}'
    assert len(raster.shape) == 3, f'Input raster must have 3 dimensions (bands, rows, cols). Got shape {raster.shape}'
    assert raster.shape[0] in [3, 4], f'Input raster must have 3 bands (RGB) or 4 bands (RGBA). Got {raster.shape[0]} bands'
    assert isinstance(model, torch.nn.Module), 'Model must be a torch.nn.Module'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Infer GSD
    # gsd = abs(raster.rio.transform()[0])  # ground sampling distance (NB: valid only if image is a GeoTIFF)
    # Growseg works best on orthoimages with a GSD in [1, 1.7] cm/px. You may want to
    # specify a scaling factor different from 1 if your image has a different GSD.
    # E.g.: SCALING_FACTOR = gsd / 0.015
    # logger.info(f'Image GSD: {gsd*100:.2f} cm/px')
    # scaling_factor = scaling_factor or (gsd / 0.015)
    scaling_factor = scaling_factor or 1
    logger.info(f'Applying scaling factor: {scaling_factor:.2f}')

    # segment
    logger.info('Segmenting image...')
    score_map = segment(
        raster,
        model,
        patch_size=patch_size,
        stride=stride,
        scaling_factor=scaling_factor,
        rotate=rotate,
        device=device,
        batch_size=batch_size,
        verbose=True,
    )  # score_map is a HxW float32 array with scores in [0, 1] and -1 on nodata pixels

    # apply threshold on confidence scores
    nodata = (score_map == -1)
    mask = (score_map > 0.5)

    # convert to uint8
    mask = (mask * 255).astype(np.uint8)

    # set nodata pixels to 1
    mask[nodata] = 1
    return mask
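
# End-to-end usage sketch (illustrative; the file name and checkpoint path are
# assumptions, not part of this module):
#
#   with rasterio.open('vineyard_orthoimage.tif') as src:
#       raster = src.read()  # CHW uint8
#   model = SegformerForSemanticSegmentation.from_pretrained('path/to/checkpoint')
#   mask = compute_mask(raster, model)  # HW uint8: 255=rows, 0=interrows, 1=nodata
#   vndvi_rows_img, vndvi_interrows_img = compute_vndvi(raster, mask)
#   vdi_img = compute_vdi(raster, mask)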