#!/usr/bin/env python3

"""Code to generate plots for Extended Data Fig. 6."""

import os
import pickle

import matplotlib
import matplotlib.gridspec
import matplotlib.image
import matplotlib.offsetbox
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import sklearn.metrics
import torch
import torchvision

import echonet


def main(fig_root=os.path.join("figure", "noise"),
         video_output=os.path.join("output", "video", "r2plus1d_18_32_2_pretrained"),
         seg_output=os.path.join("output", "segmentation", "deeplabv3_resnet50_random"),
         NOISE=(0, 0.1, 0.2, 0.3, 0.4, 0.5)):
    """Generate plots for Extended Data Fig. 6.

    For every noise level, evaluates the trained EF video model and the
    trained segmentation model on the test split. The raw per-sample outputs
    are cached in ``<fig_root>/data.pkl``; R^2 (EF) and DSC (segmentation)
    are then plotted against the noise level, with one example frame per
    level shown underneath.

    Args:
        fig_root: Directory for the results cache, the example frames
            (``noise_<pct>.tif``) and the output figures (``noise.pdf``,
            ``noise.eps``, ``noise.png``).
        video_output: Directory containing the video model's ``checkpoint.pt``.
        seg_output: Directory containing the segmentation model's
            ``checkpoint.pt``.
        NOISE: Values passed as ``noise=`` to ``echonet.datasets.Echo``. They
            are plotted as ``round(100 * noise)`` on an axis labelled
            "Pixels Removed (%)".

    Raises:
        ValueError: If an existing cache holds a different number of noise
            levels than ``NOISE``.
    """
    NOISE = tuple(NOISE)  # iterated more than once below
    filename = os.path.join(fig_root, "data.pkl")  # Cache of results

    try:
        # Attempt to load the cache. This file is written by this script
        # itself; pickle must never be loaded from an untrusted source.
        with open(filename, "rb") as f:
            Y, YHAT, INTER, UNION = pickle.load(f)
    except FileNotFoundError:
        # Generate results if no cache is available.
        Y, YHAT, INTER, UNION = _simulate(fig_root, video_output, seg_output, NOISE)
        with open(filename, "wb") as f:
            pickle.dump((Y, YHAT, INTER, UNION), f)

    # The cache does not record which noise levels it was built for. At least
    # catch a count mismatch rather than plotting misaligned points.
    if not len(Y) == len(YHAT) == len(INTER) == len(UNION) == len(NOISE):
        raise ValueError(
            "Cached results in {} cover {} noise levels, but {} were requested; "
            "delete the cache to regenerate it.".format(filename, len(Y), len(NOISE)))

    _plot(fig_root, NOISE, Y, YHAT, INTER, UNION)


def _load_checkpoint(model, path, device):
    """Load ``checkpoint["state_dict"]`` from *path* into *model*, mapped onto *device*."""
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint["state_dict"]
    # On CUDA this script wraps models in DataParallel before loading, which
    # suggests the checkpoints were saved from a wrapped model. Such keys
    # carry a "module." prefix, so strip it when loading into an unwrapped
    # (CPU) model.
    if not isinstance(model, torch.nn.DataParallel):
        state_dict = {
            (key[len("module."):] if key.startswith("module.") else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)


def _prepare_model(model, output_dir, device):
    """Wrap *model* for multi-GPU use if on CUDA, move it to *device* and load its checkpoint."""
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)
    _load_checkpoint(model, os.path.join(output_dir, "checkpoint.pt"), device)
    return model


def _save_example_frame(fig_root, noise):
    """Save the first frame of the first test video at *noise* as ``noise_<pct>.tif``."""
    dataset = echonet.datasets.Echo(split="test", noise=noise)
    # NOTE(review): assumes each video is (C, F, H, W). [:, 0] selects the
    # first frame, and the transpose gives (H, W, C) for PIL. Confirm against
    # echonet.datasets.Echo.
    frame = dataset[0][0][:, 0, :, :].astype(np.uint8).transpose(1, 2, 0)
    PIL.Image.fromarray(frame).save(os.path.join(fig_root, "noise_{}.tif".format(round(100 * noise))))


def _make_loader(dataset, device):
    """Build the evaluation DataLoader used for both models."""
    # Metrics are order-independent; shuffle=False keeps the cached
    # per-sample results in dataset order.
    return torch.utils.data.DataLoader(dataset, batch_size=16, num_workers=5, shuffle=False,
                                       pin_memory=(device.type == "cuda"))


def _evaluate_segmentation(model_s, mean, std, noise, device):
    """Return per-sample (intersections, unions) over large and small frames at *noise*."""
    kwargs = {"target_type": ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"],
              "mean": mean,
              "std": std,
              "noise": noise}
    dataloader = _make_loader(echonet.datasets.Echo(split="test", **kwargs), device)
    # NOTE(review): the string "test" is passed as run_epoch's third argument.
    # If that parameter is a boolean `train` flag in the installed echonet,
    # the string is truthy. Confirm against echonet.utils.segmentation.run_epoch.
    _, large_inter, large_union, small_inter, small_union = \
        echonet.utils.segmentation.run_epoch(model_s, dataloader, "test", None, device)
    return (large_inter.tolist() + small_inter.tolist(),
            large_union.tolist() + small_union.tolist())


def _evaluate_ef(model_v, mean, std, noise, device):
    """Return per-sample (true EF, predicted EF) on the test split at *noise*."""
    kwargs = {"target_type": "EF",
              "mean": mean,
              "std": std,
              "length": 32,
              "period": 2,
              "noise": noise}
    dataloader = _make_loader(echonet.datasets.Echo(split="test", **kwargs), device)
    # NOTE(review): same "test" third-argument caveat as in
    # _evaluate_segmentation; confirm against echonet.utils.video.run_epoch.
    _, yhat, y = echonet.utils.video.run_epoch(model_v, dataloader, "test", None, device)
    return y.tolist(), yhat.tolist()


def _simulate(fig_root, video_output, seg_output, noise_levels):
    """Run both models at every noise level and return ``(Y, YHAT, INTER, UNION)``.

    Each returned value is a list with one inner list of per-sample values
    per entry of *noise_levels*. One example frame per level is also saved
    into *fig_root*.
    """
    os.makedirs(fig_root, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load trained video model
    model_v = torchvision.models.video.r2plus1d_18()
    model_v.fc = torch.nn.Linear(model_v.fc.in_features, 1)
    model_v = _prepare_model(model_v, video_output, device)

    # Load trained segmentation model
    model_s = torchvision.models.segmentation.deeplabv3_resnet50(aux_loss=False)
    model_s.classifier[-1] = torch.nn.Conv2d(model_s.classifier[-1].in_channels, 1,
                                             kernel_size=model_s.classifier[-1].kernel_size)
    model_s = _prepare_model(model_s, seg_output, device)

    # Normalization statistics come from the training split without noise,
    # so they do not depend on the noise level. Computing them once also
    # keeps the normalization identical across levels.
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train"))

    Y, YHAT, INTER, UNION = [], [], [], []
    for noise in noise_levels:
        _save_example_frame(fig_root, noise)

        inter, union = _evaluate_segmentation(model_s, mean, std, noise, device)
        INTER.append(inter)
        UNION.append(union)

        y, yhat = _evaluate_ef(model_v, mean, std, noise, device)
        Y.append(y)
        YHAT.append(yhat)

    return Y, YHAT, INTER, UNION


def _plot(fig_root, noise_levels, Y, YHAT, INTER, UNION):
    """Plot R^2 and DSC against noise percentage, with example frames, into *fig_root*."""
    echonet.utils.latexify()

    noise_pct = [round(100 * x) for x in noise_levels]
    xmin, xmax = min(noise_pct) - 5, max(noise_pct) + 5

    fig = plt.figure(figsize=(6.50, 4.75))
    gs = matplotlib.gridspec.GridSpec(3, 1, height_ratios=[2.0, 2.0, 0.75])
    ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2]))

    # Plot EF prediction results (R^2)
    r2 = [sklearn.metrics.r2_score(y, yhat) for (y, yhat) in zip(Y, YHAT)]
    ax[0].plot(noise_pct, r2, color="k", linewidth=1, marker=".")
    ax[0].set_xticks([])
    ax[0].set_ylabel("R$^2$")
    ax[0].axis([xmin, xmax, 0, 1])

    # Plot segmentation results (DSC)
    dice = [echonet.utils.dice_similarity_coefficient(inter, union) for (inter, union) in zip(INTER, UNION)]
    ax[1].plot(noise_pct, dice, color="k", linewidth=1, marker=".")
    ax[1].set_xlabel("Pixels Removed (%)")
    ax[1].set_ylabel("DSC")
    ax[1].axis([xmin, xmax, 0, 1])

    # Add example images below, centred on each noise level's x position
    for pct in noise_pct:
        image = matplotlib.image.imread(os.path.join(fig_root, "noise_{}.tif".format(pct)))
        imagebox = matplotlib.offsetbox.OffsetImage(image, zoom=0.4)
        ab = matplotlib.offsetbox.AnnotationBbox(imagebox, (pct, 0.0), frameon=False)
        ax[2].add_artist(ab)
    ax[2].axis("off")
    ax[2].axis([xmin, xmax, -1, 1])

    fig.tight_layout()
    for ext, dpi in (("pdf", 1200), ("eps", 300), ("png", 600)):
        fig.savefig(os.path.join(fig_root, "noise.{}".format(ext)), dpi=dpi)
    plt.close(fig)


if __name__ == "__main__":
    main()