import colorsys
import os
import tempfile

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import login, snapshot_download

from metrics import *

login(token=os.getenv('HF_TOKEN'))

model_dir = snapshot_download(repo_id="srijaydeshpande/spadesegresnet")

# RGB colours used to render each predicted class in the output mask.
color_map = {
    'outside_roi': (255, 255, 255),  # white
    'tumor': (255, 0, 0),            # red
    'stroma': (0, 0, 255),           # blue
    'inflammatory': (0, 255, 0),     # green
    'necrosis': (255, 255, 0),       # yellow
    'others': (8, 133, 161)          # cyan
}
class_labels = ['outside_roi', 'tumor', 'stroma', 'inflammatory', 'necrosis', 'others']
colors = ['white', 'red', 'blue', 'green', 'yellow', 'cyan']


class SPADE(nn.Module):
    """Spatially-adaptive normalization: a parameter-free norm whose per-pixel
    scale (gamma) and bias (beta) are predicted from the conditioning map."""

    def __init__(self, norm_nc, label_nc, norm):
        super().__init__()

        if norm == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif norm == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        ks = 3
        pw = ks // 2
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta
        return out


class SPADEResnetBlock(nn.Module):
    def __init__(self, fin, fout):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # define normalization layers
        self.norm_0 = SPADE(fin, 3, norm='instance')
        self.norm_1 = SPADE(fmiddle, 3, norm='instance')
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, 3, norm='instance')

    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)


class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       activation]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


class SPADEResNet(torch.nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=5,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(SPADEResNet, self).__init__()
        activation = nn.ReLU(True)

        downsampler = [nn.ReflectionPad2d(3),
                       nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                       norm_layer(ngf),
                       activation]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            downsampler += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                            norm_layer(ngf * mult * 2),
                            activation]
        self.downsampler = nn.Sequential(*downsampler)

        ### resnet blocks
        mult = 2 ** n_downsampling
        self.resnetblocks1 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks2 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks3 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks4 = SPADEResnetBlock(ngf * mult, ngf * mult)
        self.resnetblocks5 = SPADEResnetBlock(ngf * mult, ngf * mult)

        ### upsample
        upsampler = []
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            upsampler += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2,
                                             padding=1, output_padding=1),
                          norm_layer(int(ngf * mult / 2)),
                          activation]
        upsampler += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.upsampler = nn.Sequential(*upsampler)

    def forward(self, input):
        downsampled = self.downsampler(input)
        # Each SPADE ResNet block is conditioned on the original RGB input.
        resnet1 = self.resnetblocks1(downsampled, input)
        resnet2 = self.resnetblocks2(resnet1, input)
        resnet3 = self.resnetblocks3(resnet2, input)
        resnet4 = self.resnetblocks4(resnet3, input)
        resnet5 = self.resnetblocks5(resnet4, input)
        upsampled = self.upsampler(resnet5)
        return upsampled


def generate_colors(n):
    brightness = 0.7
    hsv = [(i / n, 1, brightness) for i in range(n)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
    return colors


def generate_colored_image(labels):
    # colors = generate_colors(6)
    # Map each predicted class index to its RGB colour.
    w, h = labels.shape
    new_mk = np.empty([w, h, 3])
    for i in range(0, w):
        for j in range(0, h):
            new_mk[i][j] = color_map[class_labels[labels[i][j]]]
    new_mk = new_mk.astype(np.uint8)
    return Image.fromarray(new_mk)


def predict_wsi(image):
    # Sliding-window inference: predictions of overlapping patches are summed,
    # then averaged by the per-pixel visit count.
    patch_size = 768
    stride = 700  # kept slightly smaller than the patch size so neighbouring tiles overlap
    generator_output_size = patch_size
    num_classes = 5
    pred_labels = torch.zeros(1, num_classes + 1, image.shape[2], image.shape[3]).cuda()
    counter_tensor = torch.zeros(1, 1, image.shape[2], image.shape[3]).cuda()
    for i in range(0, image.shape[2] - patch_size + stride, stride):
        for j in range(0, image.shape[3] - patch_size + stride, stride):
            i_lowered = min(i, image.shape[2] - patch_size)
            j_lowered = min(j, image.shape[3] - patch_size)
            patch = image[:, :, i_lowered:i_lowered + patch_size, j_lowered:j_lowered + patch_size]
            pred_labels_patch = model(patch.float())
            update_region_i = i_lowered  # + (patch_size - generator_output_size) // 2
            update_region_j = j_lowered  # + (patch_size - generator_output_size) // 2
            pred_labels[:, :, update_region_i:update_region_i + generator_output_size,
                        update_region_j:update_region_j + generator_output_size] += pred_labels_patch
            counter_tensor[:, :, update_region_i:update_region_i + generator_output_size,
                           update_region_j:update_region_j + generator_output_size] += 1
    pred_labels /= counter_tensor
    return pred_labels


@spaces.GPU(duration=120)
def segment_image(image):
    img = image
    img = np.asarray(img)

    # Resize so the shorter side is at least the patch size used by predict_wsi.
    h, w = img.shape[:2]
    min_side = 768
    if min(h, w) < min_side:
        scale = min_side / min(h, w)
        new_w, new_h = int(w * scale), int(h * scale)
        # Convert NumPy array to PIL Image
        image = Image.fromarray(img)
        # Resize the image using PIL
        resized_image = image.resize((new_w, new_h))
        img = np.array(resized_image)

    # Normalize to [0, 1] if the image is still in 8-bit range.
    if np.max(img) > 100:
        img = img / 255.0

    transform = T.Compose([T.ToTensor()])
    image = transform(img)
    image = image[None, :]

    with torch.no_grad():
        pred_labels = predict_wsi(image.float())

    pred_labels = F.softmax(pred_labels, dim=1)
    pred_labels_probs = pred_labels.cpu().numpy()
    pred_labels = np.argmax(pred_labels_probs, axis=1)
    pred_labels = pred_labels[0]

    image = generate_colored_image(pred_labels)

    # Per-class pixel percentages (outside_roi excluded).
    pixels_counts = []
    total = 0
    print(np.unique(pred_labels))
    for i in range(1, len(class_labels)):
        current_count = np.sum(pred_labels == i)
        pixels_counts.append(current_count)
        total += current_count
    pixels_counts = [(value / total) * 100 for value in pixels_counts]
    print(pixels_counts)

    plt.figure(figsize=(10, 6))
    bar_width = 0.15
    plt.bar(class_labels[1:], pixels_counts, color=colors[1:], width=bar_width)
    plt.xticks(rotation=45, ha='right')
    plt.xlabel('Tissue types', fontsize=17)
    plt.ylabel('Class Percentage', fontsize=17)
    plt.title('Class distribution', fontsize=18)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.tight_layout()

    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmpfile:
        plt.savefig(tmpfile.name)
        temp_filename = tmpfile.name
    stats = Image.open(temp_filename)
    # legend = Image.open('legend.png')

    superimposed_image = superimpose_images(img, image)

    return image, stats, superimposed_image


def superimpose_images(image1, image2):
    # Blend the input image with the predicted segmentation mask.
    if image1.dtype != np.uint8:
        image1 = (image1 * 255).astype(np.uint8) if image1.max() <= 1 else image1.astype(np.uint8)

    # Convert NumPy arrays to PIL images
    image1 = Image.fromarray(image1)
    # Resize image1 to match image2's size
    image1 = image1.resize(image2.size)

    image_np = np.array(image1)
    heatmap_np = np.array(image2)
    superimposed_np = cv2.addWeighted(heatmap_np, 0.2, image_np, 1, 0)
    superimposed_pil = Image.fromarray(superimposed_np)
    return superimposed_pil


model_path = os.path.join(model_dir, 'spaderesnet.pt')
model = SPADEResNet(input_nc=3, output_nc=6)
model = nn.DataParallel(model)
model = model.cuda()
model.load_state_dict(torch.load(model_path), strict=True)

examples = [
    ["sample1.png"],
    ["sample2.png"]
]

with gr.Row():
    # First column: input image and segmentation mask
    with gr.Column():
        input_image = gr.Image(label="Input Image")  # Input image
        output1 = gr.Image(label="Segmentation Mask")  # First output
    # Second column: remaining two outputs
    with gr.Column():
        output3 = gr.Image(label="Statistics")  # Second output
        output4 = gr.Image(label="Superimposed Map")  # Third output

demo = gr.Interface(
    segment_image,
    inputs=input_image,
    examples=examples,
    outputs=[output1, output3, output4],
    title="Breast Cancer Semantic Segmentation"
)

demo.launch()
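
# -----------------------------------------------------------------------------
# Optional local smoke test (illustrative sketch, not part of the Space).
# Kept commented out because demo.launch() above blocks until the server is
# shut down, and because it assumes a CUDA device plus one of the bundled
# example images (e.g. sample1.png) next to this script.
#
# mask, stats, overlay = segment_image(np.asarray(Image.open("sample1.png").convert("RGB")))
# mask.save("predicted_mask.png")
# overlay.save("overlay.png")
# -----------------------------------------------------------------------------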