import os

import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks

# Model / preprocessing configuration. Only "downsize_res", "model_architecture"
# and "model_config" are read by this script; batch_size/epochs/lr are
# presumably the training hyperparameters, kept for reference.
config = {
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}

# One RGB color per class index (index i in the prediction -> colors[i]).
colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]

NUM_CLASSES = config["model_config"]["classes"]
if len(colors) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} colors (one per class), got {len(colors)}"
    )

cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model
model_architecture = getattr(smp, config["model_architecture"])
model = model_architecture(**config["model_config"])
# NOTE(review): torch.load unpickles the file; only ever point cp_path at a
# trusted checkpoint.
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
model.eval()

# transforms
downsize_t = transforms.Resize(
    (config["downsize_res"], config["downsize_res"]), antialias=True
)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)


def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a label-encoded mask to a boolean one-hot mask, class axis first.

    A 2-D ``(H, W)`` mask becomes ``(num_classes, H, W)``; a 3-D ``(N, H, W)``
    mask becomes ``(N, num_classes, H, W)``.

    Args:
        mask: Integer class ids, each in ``[0, num_classes)``.
        num_classes: Size of the one-hot axis.

    Returns:
        A ``torch.bool`` tensor with the class axis moved in front of H, W.
    """
    # one_hot appends the class axis last; permute moves it before H, W.
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )


def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
    """Generate the segmentation overlay for a satellite image.

    Args:
        image: ``(3, H, W)`` image tensor, as required by
            ``draw_segmentation_masks`` (``uint8`` here).
        preds: Class-id map that squeezes to ``(H, W)`` (e.g. ``(1, H, W)``).
            It may live on a different device than ``image``.
        alpha: Mask opacity; ``1`` paints the class colors fully opaque.

    Returns:
        ``(3, H, W)`` tensor on ``image``'s device with the masks drawn.
    """
    # draw_segmentation_masks indexes the image with the masks, so both must be
    # on the same device (preds comes from the model, possibly on CUDA).
    masks = label_to_onehot(preds.squeeze().to(image.device), NUM_CLASSES)
    return draw_segmentation_masks(image, masks=masks, alpha=alpha, colors=colors)


def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D tensor from ``(H, W, C)`` to ``(C, H, W)`` (a view, no copy)."""
    return torch.permute(image_tensor, (2, 0, 1))


def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D tensor from ``(C, H, W)`` to ``(H, W, C)`` (a view, no copy)."""
    return torch.permute(image_tensor, (1, 2, 0))


def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Segment land-cover classes in a satellite image.

    Args:
        satellite_image: ``(H, W, 3)`` image array as delivered by the Gradio
            image input (assumed ``uint8`` RGB; TODO confirm for other modes).

    Returns:
        ``(raw_segmentation, segmentation_overlay)``, both ``(H, W, 3)`` arrays
        at the input resolution: the class colors on a black background, and
        the class colors blended over the input at 20% opacity.
    """
    image_tensor = hwc_to_chw(torch.from_numpy(satellite_image))
    pil_image = transforms.functional.to_pil_image(image_tensor)

    # preprocess image: float tensor in [0, 1], add batch dim, downsize to the
    # model's working resolution
    X = transform(pil_image).unsqueeze(0).to(device)
    X_down = downsize_t(X)

    # forward pass; no autograd graph is needed for inference
    with torch.no_grad():
        logits = model(X_down)
    preds = torch.argmax(logits, 1)

    # Resize back to the original resolution. Class ids are categorical, so
    # nearest-neighbour is required: bilinear interpolation would blend
    # neighbouring ids into unrelated classes along every boundary.
    preds = transforms.functional.resize(
        preds,
        list(X.shape[-2:]),
        interpolation=transforms.InterpolationMode.NEAREST,
    )

    # get rgb formatted images
    segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
    raw_segmentation = chw_to_hwc(
        get_overlay(torch.zeros_like(image_tensor), preds, 1)
    ).numpy()
    return raw_segmentation, segmentation_overlay


inputs = gr.Image(label="Input Image")
outputs = [gr.Image(label="Raw Segmentation"), gr.Image(label="Segmentation Overlay")]

images_dir = "sample_sat_images/"
# Sorted so the example gallery order is stable across runs/filesystems.
examples = [
    os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))
]

title = "Satellite Images Landcover Classification"
description = (
    "Upload a satellite image from your computer or select one from"
    " the examples. The model will automatically segment the landcover"
    " types from a preselected set of possible types."
)
with open("article.md", "r", encoding="utf-8") as article_file:
    article = article_file.read()

iface = gr.Interface(
    segment,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    cache_examples=True,
    article=article,
)

iface.launch()