|
|
|
from __future__ import absolute_import, division, print_function |
|
|
|
import os |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
|
|
import torchvision.models as models |
|
from feature_extractor import cl |
|
from models.GraphTransformer import Classifier |
|
from models.weight_init import weight_init |
|
from feature_extractor.build_graph_utils import ToTensor, Compose, bag_dataset, adj_matrix |
|
import torchvision.transforms.functional as VF |
|
from src.vis_graphcam import show_cam_on_image,cam_to_mask |
|
from easydict import EasyDict as edict |
|
from models.GraphTransformer import Classifier |
|
from slide_tiling import save_tiles |
|
import pickle |
|
from collections import OrderedDict |
|
import glob |
|
import openslide |
|
import numpy as np |
|
import skimage.transform |
|
import cv2 |
|
|
|
|
|
class Predictor:
    """End-to-end whole-slide-image (WSI) classifier.

    Pipeline: tile the slide into 512px patches, extract per-patch features
    with a frozen ResNet backbone, classify the resulting patch graph with a
    graph transformer, and render per-class GraphCAM heatmaps.

    Required environment variables:
        CLASS_METADATA                 pickled {label_name: label_id} dict
        FEATURE_EXTRACTOR_WEIGHT_PATH  patch feature-extractor checkpoint
        GT_WEIGHT_PATH                 graph-transformer checkpoint
        PATCHES_DIR                    directory where slide tiles are written
        GRAPHCAM_DIR                   directory for visualization output
    """

    def __init__(self):
        # NOTE(review): pickle.load assumes CLASS_METADATA points at a
        # trusted file; never load untrusted pickles.
        # Use a context manager so the file handle is not leaked.
        with open(os.environ['CLASS_METADATA'], 'rb') as f:
            self.classdict = pickle.load(f)
        # Inverse mapping: label id -> label name.
        self.label_map_inv = {label_id: label_name
                              for label_name, label_id in self.classdict.items()}

        iclf_weights = os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH']
        graph_transformer_weights = os.environ['GT_WEIGHT_PATH']
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.__init_iclf(iclf_weights, backbone='resnet18')
        self.__init_graph_transformer(graph_transformer_weights)

    def predict(self, slide_path):
        """Classify one whole-slide image.

        Args:
            slide_path: path to the WSI file.

        Returns:
            (label_name, viz_dict) where viz_dict maps each class name plus
            'ALL' and 'ORI' to a BGR visualization array.
        """
        # Tile the slide; tiles land under PATCHES_DIR/<FILEID>_files/.
        save_tiles(slide_path)

        filename = os.path.basename(slide_path)
        FILEID = filename.rsplit('.', maxsplit=1)[0]
        patches_glob_path = os.path.join(os.environ['PATCHES_DIR'],
                                         f'{FILEID}_files', '*', '*.jpeg')
        patches_paths = glob.glob(patches_glob_path)

        sample = self.iclf_predict(patches_paths)

        # GraphCAM needs gradients, so re-enable them after the no-grad
        # feature-extraction phase.
        torch.set_grad_enabled(True)
        node_feat, adjs, masks = Predictor.preparefeatureLabel(
            sample['image'], sample['adj_s'], self.device)
        pred, labels, loss, graphcam_tensors = self.model.forward(
            node_feat=node_feat, labels=None, adj=adjs, mask=masks,
            graphcam_flag=True, to_file=False)

        patches_coords = sample['c_idx'][0]
        viz_dict = self.get_graphcams(graphcam_tensors, patches_coords,
                                      slide_path, FILEID)
        return self.label_map_inv[pred.item()], viz_dict

    def iclf_predict(self, patches_paths):
        """Run the patch-level feature extractor over all tiles.

        Args:
            patches_paths: list of tile image paths named '<x>_<y>.jpeg'.

        Returns:
            dict with stacked patch features ('image'), the patch adjacency
            matrix ('adj_s') and patch grid coordinates ('c_idx'), each
            wrapped in a one-element list (batch-of-one convention expected
            by preparefeatureLabel).
        """
        feats_list = []
        args = edict({'batch_size': 128, 'num_workers': 0})
        dataloader, bag_size = bag_dataset(args, patches_paths)

        # Feature extraction is inference-only: no gradients needed here.
        with torch.no_grad():
            for iteration, batch in enumerate(dataloader):
                patches = batch['input'].float().to(self.device)
                feats, classes = self.i_classifier(patches)
                feats_list.extend(feats)
        output = torch.stack(feats_list, dim=0).to(self.device)

        adj_s = adj_matrix(patches_paths, output)

        # Tile filenames encode their grid position as '<x>_<y>.jpeg'.
        # Use os.path.basename instead of split('/') for portability.
        patch_infos = []
        for path in patches_paths:
            x, y = os.path.basename(path).split('.')[0].split('_')
            patch_infos.append((x, y))

        return {'image': [output],
                'adj_s': [adj_s],
                'c_idx': [patch_infos]}

    def get_graphcams(self, graphcam_tensors, patches_coords, slide_path, FILEID):
        """Render per-class GraphCAM heatmaps over a slide thumbnail.

        Writes one PNG per class, a merged strip, and the plain thumbnail to
        GRAPHCAM_DIR/<FILEID>/, and returns the same images keyed by class
        name plus 'ALL' (merged) and 'ORI' (original thumbnail).
        """
        label_name_from_id = self.label_map_inv
        n_class = len(self.classdict)

        p = graphcam_tensors['prob'].cpu().detach().numpy()[0]
        ori = openslide.OpenSlide(slide_path)
        width, height = ori.dimensions

        # Thumbnail downscale factor (was previously also hard-coded as a
        # bare 20 alongside the constant).
        REDUCTION_FACTOR = 20
        # Patch-grid dimensions: tiles are 512px on a side.
        w, h = int(width / 512), int(height / 512)
        w_r, h_r = int(width / REDUCTION_FACTOR), int(height / REDUCTION_FACTOR)
        resized_img = ori.get_thumbnail((width, height))
        resized_img = resized_img.resize((w_r, h_r))

        # Footprint of one 512px patch in thumbnail pixels.
        w_s = h_s = float(512 / REDUCTION_FACTOR)

        # Reconstruct the tile filenames the CAM-to-mask helper expects.
        patches = ['{}_{}.jpeg'.format(x, y) for x, y in patches_coords]

        # PIL gives RGB; OpenCV expects BGR, hence the channel reversal.
        output_img = np.asarray(resized_img)[:, :, ::-1].copy()

        # Soft assignment of graph nodes to transformer tokens.
        assign_matrix = nn.Softmax(dim=1)(graphcam_tensors['s_matrix_ori'])

        # Floor class probabilities so low-confidence heatmaps stay visible.
        p = np.clip(p, 0.4, 1)

        output_img_copy = np.copy(output_img)
        gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        # NOTE(review): assumes the thumbnail is not constant-valued;
        # a uniform image would divide by zero here.
        image_transformer_attribution = (
            (output_img_copy - output_img_copy.min())
            / (output_img_copy.max() - output_img_copy.min()))

        visualizations = []
        viz_dict = dict()

        SAMPLE_VIZ_DIR = os.path.join(os.environ['GRAPHCAM_DIR'], FILEID)
        os.makedirs(SAMPLE_VIZ_DIR, exist_ok=True)

        for class_i in range(n_class):
            # Project the token-level CAM back onto patch nodes via the
            # assignment matrix.
            cam_matrix = graphcam_tensors[f'cam_{class_i}']
            cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1, 0))
            cam_matrix = cam_matrix.cpu()

            # Min-max normalize, weight by class probability, clamp to [0, 1].
            cam_matrix = ((cam_matrix - cam_matrix.min())
                          / (cam_matrix.max() - cam_matrix.min()))
            cam_matrix = cam_matrix.detach().numpy()
            cam_matrix = p[class_i] * cam_matrix
            cam_matrix = np.clip(cam_matrix, 0, 1)

            mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s)

            vis = show_cam_on_image(image_transformer_attribution, mask)
            vis = np.uint8(255 * vis)

            visualizations.append(vis)
            viz_dict['{}'.format(label_name_from_id[class_i])] = vis
            cv2.imwrite(os.path.join(
                SAMPLE_VIZ_DIR,
                '{}_all_types_cam_{}.png'.format(FILEID, label_name_from_id[class_i])
            ), vis)

        # Stack the per-class maps along the shorter image dimension.
        # (Renamed from h/w to avoid shadowing the patch-grid dims above.)
        img_h, img_w, _ = output_img.shape
        if img_h > img_w:
            vis_merge = cv2.hconcat([output_img] + visualizations)
        else:
            vis_merge = cv2.vconcat([output_img] + visualizations)

        cv2.imwrite(os.path.join(
            SAMPLE_VIZ_DIR,
            '{}_all_types_cam_all.png'.format(FILEID)),
            vis_merge)
        viz_dict['ALL'] = vis_merge
        cv2.imwrite(os.path.join(
            SAMPLE_VIZ_DIR,
            '{}_all_types_ori.png'.format(FILEID)),
            output_img)
        viz_dict['ORI'] = output_img
        return viz_dict

    @staticmethod
    def preparefeatureLabel(batch_graph, batch_adjs, device='cpu'):
        """Zero-pad a batch of graphs to a common node count.

        Args:
            batch_graph: list of (num_nodes_i, feat_dim) node-feature tensors.
            batch_adjs: list of (num_nodes_i, num_nodes_i) adjacency tensors.
            device: target device for all returned tensors.

        Returns:
            (node_feat, adjs, masks) padded to the max node count in the
            batch; masks marks real (non-padding) nodes with 1.
        """
        batch_size = len(batch_graph)
        max_node_num = max((g.shape[0] for g in batch_graph), default=0)
        # Infer the feature dimension instead of hard-coding 512, so other
        # backbones work too (falls back to 512 for an empty batch).
        feat_dim = batch_graph[0].shape[1] if batch_size else 512

        masks = torch.zeros(batch_size, max_node_num)
        adjs = torch.zeros(batch_size, max_node_num, max_node_num)
        batch_node_feat = torch.zeros(batch_size, max_node_num, feat_dim)

        for i in range(batch_size):
            cur_node_num = batch_graph[i].shape[0]
            batch_node_feat[i, :cur_node_num] = batch_graph[i]
            adjs[i, :cur_node_num, :cur_node_num] = batch_adjs[i]
            masks[i, :cur_node_num] = 1

        # BUG FIX: node features were previously left on the CPU —
        # `.to()` with no arguments is a no-op — while adjs/masks moved
        # to `device`. Move all three consistently.
        node_feat = batch_node_feat.to(device)
        adjs = adjs.to(device)
        masks = masks.to(device)

        return node_feat, adjs, masks

    def __init_graph_transformer(self, graph_transformer_weights):
        """Build the graph-transformer classifier and load its checkpoint."""
        n_class = len(self.classdict)
        model = Classifier(n_class)
        # The checkpoint was saved from a DataParallel wrapper ('module.'
        # key prefixes), so wrap before loading.
        model = nn.DataParallel(model)
        model.load_state_dict(torch.load(graph_transformer_weights,
                                         map_location=torch.device(self.device)))
        if torch.cuda.is_available():
            model = model.cuda()
        self.model = model

    def __init_iclf(self, iclf_weights, backbone='resnet18'):
        """Build the frozen-ResNet patch feature extractor and load weights.

        Args:
            iclf_weights: checkpoint path for the feature extractor.
            backbone: one of resnet18/34/50/101.

        Raises:
            ValueError: for an unsupported backbone name (previously this
            crashed with an unrelated NameError).
        """
        backbones = {
            'resnet18': (models.resnet18, 512),
            'resnet34': (models.resnet34, 512),
            'resnet50': (models.resnet50, 2048),
            'resnet101': (models.resnet101, 2048),
        }
        if backbone not in backbones:
            raise ValueError('Unsupported backbone: {!r}'.format(backbone))
        builder, num_feats = backbones[backbone]
        resnet = builder(pretrained=False, norm_layer=nn.InstanceNorm2d)

        # The backbone is used purely as a fixed feature extractor.
        for param in resnet.parameters():
            param.requires_grad = False
        resnet.fc = nn.Identity()
        i_classifier = cl.IClassifier(resnet, num_feats, output_class=2).to(self.device)

        # Copy only the backbone ('features') weights from the checkpoint,
        # pairing source and target parameters positionally.
        state_dict_weights = torch.load(iclf_weights,
                                        map_location=torch.device(self.device))
        state_dict_init = i_classifier.state_dict()
        new_state_dict = OrderedDict()
        for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):
            if 'features' not in k:
                continue
            new_state_dict[k_0] = v
        i_classifier.load_state_dict(new_state_dict, strict=False)

        self.i_classifier = i_classifier
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import subprocess |
|
import argparse |
|
import os |
|
import shutil |
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: classify a single whole-slide image and print the
    # predicted class label.
    arg_parser = argparse.ArgumentParser(description='PyTorch Classification')
    arg_parser.add_argument('--slide_path', type=str, help='path to the WSI slide')
    cli_args = arg_parser.parse_args()

    predictor = Predictor()
    predicted_class, viz_dict = predictor.predict(cli_args.slide_path)
    print('Class prediction is: ', predicted_class)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|