# -------------------------------------------------------- # BiomedSeg # Copyright (c) 2022 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Yu Gu (yugu1@microsoft.com), Theo Zhao (theodorezhao@microsoft.com) # -------------------------------------------------------- import os import sys this_file_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.join(this_file_dir, "../ct_seg")) import json import warnings import PIL from PIL import Image from typing import Any, Callable, Dict, List, Optional, Tuple import monai import cv2 import math import gradio as gr import torch import argparse import imageio import numpy as np import scipy from torchvision import transforms from models import dinov2_vitl_transunet from class_dict import class_dict, dataset_class from transforms import _MEAN, _STD from monai import transforms as monai_transforms from scipy.ndimage import label id2label = {v: k for k, v in class_dict.items()} np.random.seed(0) id2color = {k: list(np.random.choice(range(256), size=3)) for k,v in id2label.items()} def clean_mask(X): """ Cleans the mask for labels 1 and 2 by keeping only the largest connected component for each label. Parameters: X (numpy.ndarray): Volumetric mask of shape [N, 1, W, H] with values 0 (background), 1, or 2. Returns: numpy.ndarray: Cleaned volumetric mask with the same shape as X. """ # Extract the volume data (assuming N is the depth dimension) if X.ndim == 4: volume = X[:, 0, :, :] # Shape: [N, W, H] else: volume = X for label_value in [1, 2, 10]: # Create a binary mask for the current label mask = (volume == label_value) if not np.any(mask): continue # Skip if the label is not present # Define connectivity for 3D connected components structure = np.ones((3, 3, 3), dtype=int) # Label connected components labeled_mask, num_features = label(mask, structure=structure) if num_features == 0: continue # No connected components found # Compute sizes of all connected components component_sizes = np.bincount(labeled_mask.ravel()) component_sizes[0] = 0 # Ignore the background # Find the label of the largest connected component largest_component_label = component_sizes.argmax() # Create a mask for the largest connected component largest_component_mask = (labeled_mask == largest_component_label) # Remove all other components of the current label volume[mask] = 0 # Set all pixels of the current label to background volume[largest_component_mask] = label_value # Restore the largest component # Update the original mask if X.ndim == 4: X[:, 0, :, :] = volume else: X = volume return X def parse_option(): parser = argparse.ArgumentParser('SEEM Demo', add_help=False) parser.add_argument('--model_path', default="ckpt/model_19.pth", metavar="FILE", help='path to model file') # parser.add_argument('--model_path', default="ckpt/uw_seg_heart.pth", metavar="FILE", help='path to model file') cfg = parser.parse_args() return cfg ''' build args ''' cfg = parse_option() pretrained_pth = cfg.model_path def load_tif_images(file_path): vol = imageio.imread(file_path) if np.max(vol) <= 1: vol = vol * 255 return vol def overlay_image_with_mask(image, segmentation_map, path='test.png', ax=None): color_seg = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8) # height, width, 3 for label, color in id2color.items(): color_seg[segmentation_map == label, :] = color # Show image + mask img = np.array(image) * 0.5 + color_seg * 0.5 img = img.astype(np.uint8) return img def resize_volume(vol, size, max_frames, nearest_neighbor=False): W, H, F = vol.shape zoom_rate = size / W vol_reshape = scipy.ndimage.zoom( vol, (zoom_rate, zoom_rate, zoom_rate), order=3 if not nearest_neighbor else 0 ) resizeW, resizeH, resizeF = vol_reshape.shape if resizeF > max_frames: vol_reshape = vol_reshape[:, :, :max_frames] resizeF = max_frames else: resized_max_fr = int(math.ceil(max_frames * zoom_rate)) vol_reshape = np.concatenate([vol_reshape, np.zeros((resizeW, resizeH, resized_max_fr - resizeF))], axis=-1) return vol_reshape, resizeF, zoom_rate val_transform = monai_transforms.Compose([monai_transforms.Resized(keys=['image'], spatial_size=(256, 256), mode=['bilinear'])]) def process_volume(vol: np.ndarray, keep_frames: Callable=lambda x: x > 0.025): initial_resize = monai.transforms.ResizeWithPadOrCrop((512, 512)) transform = monai.transforms.CropForeground(keys=["pixel_values"], source_key="pixel_values", return_coords=True) crop_vol, start_coords, end_coords = transform(vol) keep_frames = np.where(keep_frames(np.mean(np.mean(crop_vol, axis=-1), axis=-1)))[0] crop_vol = crop_vol[keep_frames] W, H, F = crop_vol.shape proc_vol = cv2.equalizeHist(crop_vol.reshape(W, -1).astype(np.uint8)).reshape(W, H, F) proc_vol = initial_resize(proc_vol).detach().cpu().numpy().transpose((1, 2, 0)) proc_vol, max_fr = resize_volume(proc_vol, 256, max_frames=512)[:2] images = [] for i in range(proc_vol.shape[2]): image = torch.from_numpy(proc_vol[:, :, i]).unsqueeze(0) image_transformed = val_transform({"image": image})["image"] images.append(image_transformed) images = torch.stack(images) if images.max() > 1: images = images / 255.0 # make the images three channels images = images.repeat(1, 3, 1, 1) for c in range(len(_MEAN)): images[:, c, :, :] = (images[:, c, :, :] - _MEAN[c]) / _STD[c] return images, max_fr def untransform(img): for c in range(len(_MEAN)): img[c] = img[c] * _STD[c] + _MEAN[c] if img.max() <= 1: img = img * 255 return img.long() def process_ct(ct_path: str): vol = load_tif_images(ct_path) images, frame_indices = process_volume(vol, keep_frames=lambda x: x > 0.025) return images, frame_indices # Ensure the example file is in the same directory or provide a relative path examples = [["demo/CTseg_57_raw.tif"], ["demo/CTrec-don_1101.tif"]] ''' build model ''' class_names = dataset_class["uwseg"] class_ids = [class_dict[class_name] for class_name in class_names] model = dinov2_vitl_transunet(pretrained="", num_classes=len(class_dict), img_size=256) state_dict = torch.load(pretrained_pth) model.load_state_dict(state_dict) model = model.cuda() @torch.no_grad() def inference(image_input): if isinstance(image_input, str): # image_input is a file path file_path = image_input else: # image_input is a gr.File object file_path = image_input.name images, frame_indices = process_ct(file_path) with torch.no_grad(): with torch.cuda.amp.autocast(dtype=torch.float16): logits = model(images.cuda()) for j in range(len(class_dict)): if j + 1 not in class_ids: logits[:, j] = -1000 pred = torch.argmax(logits, dim=1) + 1 pred_mask = (torch.max(logits, dim=1)[0] > 0) pred = pred_mask * pred pred[frame_indices:] = 0 pred = torch.from_numpy(clean_mask(pred.cpu().numpy())) volume_size = torch.sum(pred==2).item() # 1 pixel = 1 mm^2, change to cm^3 volume_size = volume_size / 1000 # Compute the size of the segmented mask for each slice sizes = pred.view(pred.shape[0], -1).sum(dim=1).cpu().numpy() segmentation_results = [] raw_images = [] for i in range(len(images)): images[i] = untransform(images[i]) raw_image = Image.fromarray(images[i].cpu().permute(1, 2, 0).numpy().astype(np.uint8)) raw_images.append(raw_image) image_with_mask = overlay_image_with_mask(images[i].cpu().permute(1, 2, 0).numpy(), pred[i].squeeze(0).cpu().numpy()) image_with_mask = Image.fromarray(image_with_mask) segmentation_results.append(image_with_mask) initial_slice_index = 0 output_seg = segmentation_results[initial_slice_index] output_raw = raw_images[initial_slice_index] num_slices = len(segmentation_results) initial_size = sizes[initial_slice_index] return output_seg, output_raw, segmentation_results, raw_images, gr.update(maximum=num_slices - 1), sizes, f"Heart volume size: {volume_size} cm^3" def update_slice(slice_index, segmentation_results_state, raw_images_state, sizes_text): segmentation_results = segmentation_results_state raw_images = raw_images_state if segmentation_results is None or raw_images is None: return None, None, "" output_seg = segmentation_results[slice_index] output_raw = raw_images[slice_index] return output_seg, output_raw, size_text def load_example(example): image_file_path = example return inference(image_file_path) title = "CT Segmentation" description = """