Spaces:
Running
Running
File size: 6,137 Bytes
dde56f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
#!/usr/bin/env python3
"""Code to generate plots for Extended Data Fig. 6."""
import os
import pickle

import matplotlib
import matplotlib.gridspec
import matplotlib.image
import matplotlib.offsetbox
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
import sklearn
import sklearn.metrics
import torch
import torch.utils.data
import torchvision

import echonet
def _load_models(video_output, seg_output, device):
    """Load the trained EF video model and the segmentation model.

    Args:
        video_output (str): Directory holding the video model's ``checkpoint.pt``.
        seg_output (str): Directory holding the segmentation model's ``checkpoint.pt``.
        device (torch.device): Device the models are moved to.

    Returns:
        tuple: ``(model_v, model_s)``; each is wrapped in
        ``torch.nn.DataParallel`` when ``device`` is CUDA.
    """
    # Video model: R(2+1)D-18 with its final layer replaced by a 1-output linear head.
    model_v = torchvision.models.video.r2plus1d_18()
    model_v.fc = torch.nn.Linear(model_v.fc.in_features, 1)
    if device.type == "cuda":
        # Wrap before loading so the state_dict keys match how the checkpoint
        # was presumably saved (from a DataParallel model) -- TODO confirm.
        model_v = torch.nn.DataParallel(model_v)
    model_v.to(device)
    checkpoint = torch.load(os.path.join(video_output, "checkpoint.pt"), map_location=device)
    model_v.load_state_dict(checkpoint["state_dict"])

    # Segmentation model: DeepLabV3-ResNet50 with a single-channel output conv.
    model_s = torchvision.models.segmentation.deeplabv3_resnet50(aux_loss=False)
    last = model_s.classifier[-1]
    model_s.classifier[-1] = torch.nn.Conv2d(last.in_channels, 1, kernel_size=last.kernel_size)
    if device.type == "cuda":
        model_s = torch.nn.DataParallel(model_s)
    model_s.to(device)
    checkpoint = torch.load(os.path.join(seg_output, "checkpoint.pt"), map_location=device)
    model_s.load_state_dict(checkpoint["state_dict"])

    return model_v, model_s


def _make_test_loader(dataset, device):
    """Wrap ``dataset`` in the DataLoader configuration used for evaluation."""
    # No shuffling: this is evaluation only, and a fixed order keeps the cached
    # per-sample lists in dataset order.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=16,
        num_workers=5,
        shuffle=False,
        pin_memory=(device.type == "cuda"),
    )


def _run_simulation(fig_root, noise_levels, model_v, model_s, device):
    """Evaluate both models on the test split at each noise level.

    For every noise level this also saves an example frame from the first
    test video as ``noise_<percent>.tif`` inside ``fig_root``.

    Args:
        fig_root (str): Output directory for the example frames (must exist).
        noise_levels (sequence of float): Noise levels passed to
            ``echonet.datasets.Echo(noise=...)``; plotted as "Pixels Removed (%)".
        model_v: Video model from ``_load_models``.
        model_s: Segmentation model from ``_load_models``.
        device (torch.device): Device the models run on.

    Returns:
        tuple: ``(Y, YHAT, INTER, UNION)``, each a list with one inner list per
        noise level -- true EF, predicted EF, and segmentation intersection /
        union values (large-trace entries followed by small-trace entries).
    """
    # The normalization statistics come from the noise-free training split and
    # do not depend on the noise level, so compute them once.
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train"))

    Y, YHAT, INTER, UNION = [], [], [], []
    for noise in noise_levels:
        # Example image: slice [:, 0] of the first test video -- presumably its
        # first frame, assuming a (C, F, H, W) layout -- TODO confirm.
        example = echonet.datasets.Echo(split="test", noise=noise)
        frame = example[0][0][:, 0, :, :].astype(np.uint8).transpose(1, 2, 0)
        PIL.Image.fromarray(frame).save(
            os.path.join(fig_root, "noise_{}.tif".format(round(100 * noise))))

        # The third run_epoch argument is the training flag. Pass False (not a
        # truthy string like "test") so the models are evaluated, not trained.
        # NOTE(review): assumes echonet's run_epoch(model, dataloader, train,
        # optim, device) signature -- confirm against the vendored echonet.

        # Segmentation: per-trace intersection / union counts.
        seg_dataset = echonet.datasets.Echo(
            split="test",
            target_type=["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"],
            mean=mean,
            std=std,
            noise=noise,
        )
        _, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(
            model_s, _make_test_loader(seg_dataset, device), False, None, device)
        INTER.append(large_inter.tolist() + small_inter.tolist())
        UNION.append(large_union.tolist() + small_union.tolist())

        # Ejection-fraction regression on 32-frame clips sampled every 2nd frame.
        ef_dataset = echonet.datasets.Echo(
            split="test",
            target_type="EF",
            mean=mean,
            std=std,
            length=32,
            period=2,
            noise=noise,
        )
        _, yhat, y = echonet.utils.video.run_epoch(
            model_v, _make_test_loader(ef_dataset, device), False, None, device)
        Y.append(y.tolist())
        YHAT.append(yhat.tolist())

    return Y, YHAT, INTER, UNION


def _plot(fig_root, noise_levels, Y, YHAT, INTER, UNION):
    """Plot R^2 and DSC against percent of pixels removed and save the figure.

    The example frames saved by ``_run_simulation`` are drawn in a strip under
    the two curves. The figure is written to ``fig_root`` as ``noise.pdf``,
    ``noise.eps`` and ``noise.png``.
    """
    echonet.utils.latexify()
    percent = [round(100 * x) for x in noise_levels]
    x_min, x_max = min(percent) - 5, max(percent) + 5

    fig = plt.figure(figsize=(6.50, 4.75))
    gs = matplotlib.gridspec.GridSpec(3, 1, height_ratios=[2.0, 2.0, 0.75])
    ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2]))

    # EF prediction quality (R^2) per noise level.
    r2 = [sklearn.metrics.r2_score(y, yhat) for (y, yhat) in zip(Y, YHAT)]
    ax[0].plot(percent, r2, color="k", linewidth=1, marker=".")
    ax[0].set_xticks([])
    ax[0].set_ylabel("R$^2$")
    ax[0].axis([x_min, x_max, 0, 1])

    # Segmentation quality (Dice similarity coefficient) per noise level.
    dice = [echonet.utils.dice_similarity_coefficient(inter, union)
            for (inter, union) in zip(INTER, UNION)]
    ax[1].plot(percent, dice, color="k", linewidth=1, marker=".")
    ax[1].set_xlabel("Pixels Removed (%)")
    ax[1].set_ylabel("DSC")
    ax[1].axis([x_min, x_max, 0, 1])

    # Strip of example frames, each centred under its noise level.
    for p in percent:
        image = matplotlib.image.imread(os.path.join(fig_root, "noise_{}.tif".format(p)))
        imagebox = matplotlib.offsetbox.OffsetImage(image, zoom=0.4)
        ab = matplotlib.offsetbox.AnnotationBbox(imagebox, (p, 0.0), frameon=False)
        ax[2].add_artist(ab)
    ax[2].axis("off")
    ax[2].axis([x_min, x_max, -1, 1])

    fig.tight_layout()
    plt.savefig(os.path.join(fig_root, "noise.pdf"), dpi=1200)
    plt.savefig(os.path.join(fig_root, "noise.eps"), dpi=300)
    plt.savefig(os.path.join(fig_root, "noise.png"), dpi=600)
    plt.close(fig)


def main(fig_root=os.path.join("figure", "noise"),
         video_output=os.path.join("output", "video", "r2plus1d_18_32_2_pretrained"),
         seg_output=os.path.join("output", "segmentation", "deeplabv3_resnet50_random"),
         NOISE=(0, 0.1, 0.2, 0.3, 0.4, 0.5)):
    """Generate plots for Extended Data Fig. 6.

    Args:
        fig_root (str): Directory for the results cache (``data.pkl``), the
            example frames and the output figures.
        video_output (str): Directory with the trained video model checkpoint.
        seg_output (str): Directory with the trained segmentation model checkpoint.
        NOISE (sequence of float): Noise levels to simulate.

    If ``fig_root/data.pkl`` exists, its results are reused and no model is
    run. Delete the file to force recomputation; the cache is not
    invalidated when ``NOISE`` changes.
    """
    filename = os.path.join(fig_root, "data.pkl")  # Cache of results
    try:
        # NOTE: pickle is only safe here because this script writes the cache
        # itself; never point fig_root at untrusted data.
        with open(filename, "rb") as f:
            Y, YHAT, INTER, UNION = pickle.load(f)
    except FileNotFoundError:
        # No cache: load the models, run the simulation, then cache the results.
        os.makedirs(fig_root, exist_ok=True)
        device = torch.device("cuda")
        model_v, model_s = _load_models(video_output, seg_output, device)
        Y, YHAT, INTER, UNION = _run_simulation(fig_root, NOISE, model_v, model_s, device)
        with open(filename, "wb") as f:
            pickle.dump((Y, YHAT, INTER, UNION), f)

    _plot(fig_root, NOISE, Y, YHAT, INTER, UNION)


if __name__ == "__main__":
    main()
|