Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
"""Code to generate plots for Extended Data Fig. 6.""" | |
import os | |
import pickle | |
import matplotlib | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import PIL | |
import sklearn | |
import torch | |
import torchvision | |
import echonet | |
def main(fig_root=os.path.join("figure", "noise"), | |
video_output=os.path.join("output", "video", "r2plus1d_18_32_2_pretrained"), | |
seg_output=os.path.join("output", "segmentation", "deeplabv3_resnet50_random"), | |
NOISE=(0, 0.1, 0.2, 0.3, 0.4, 0.5)): | |
"""Generate plots for Extended Data Fig. 6.""" | |
device = torch.device("cuda") | |
filename = os.path.join(fig_root, "data.pkl") # Cache of results | |
try: | |
# Attempt to load cache | |
with open(filename, "rb") as f: | |
Y, YHAT, INTER, UNION = pickle.load(f) | |
except FileNotFoundError: | |
# Generate results if no cache available | |
os.makedirs(fig_root, exist_ok=True) | |
# Load trained video model | |
model_v = torchvision.models.video.r2plus1d_18() | |
model_v.fc = torch.nn.Linear(model_v.fc.in_features, 1) | |
if device.type == "cuda": | |
model_v = torch.nn.DataParallel(model_v) | |
model_v.to(device) | |
checkpoint = torch.load(os.path.join(video_output, "checkpoint.pt")) | |
model_v.load_state_dict(checkpoint['state_dict']) | |
# Load trained segmentation model | |
model_s = torchvision.models.segmentation.deeplabv3_resnet50(aux_loss=False) | |
model_s.classifier[-1] = torch.nn.Conv2d(model_s.classifier[-1].in_channels, 1, kernel_size=model_s.classifier[-1].kernel_size) | |
if device.type == "cuda": | |
model_s = torch.nn.DataParallel(model_s) | |
model_s.to(device) | |
checkpoint = torch.load(os.path.join(seg_output, "checkpoint.pt")) | |
model_s.load_state_dict(checkpoint['state_dict']) | |
# Run simulation | |
dice = [] | |
mse = [] | |
r2 = [] | |
Y = [] | |
YHAT = [] | |
INTER = [] | |
UNION = [] | |
for noise in NOISE: | |
Y.append([]) | |
YHAT.append([]) | |
INTER.append([]) | |
UNION.append([]) | |
dataset = echonet.datasets.Echo(split="test", noise=noise) | |
PIL.Image.fromarray(dataset[0][0][:, 0, :, :].astype(np.uint8).transpose(1, 2, 0)).save(os.path.join(fig_root, "noise_{}.tif".format(round(100 * noise)))) | |
mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train")) | |
tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"] | |
kwargs = { | |
"target_type": tasks, | |
"mean": mean, | |
"std": std, | |
"noise": noise | |
} | |
dataset = echonet.datasets.Echo(split="test", **kwargs) | |
dataloader = torch.utils.data.DataLoader(dataset, | |
batch_size=16, num_workers=5, shuffle=True, pin_memory=(device.type == "cuda")) | |
loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model_s, dataloader, "test", None, device) | |
inter = np.concatenate((large_inter, small_inter)).sum() | |
union = np.concatenate((large_union, small_union)).sum() | |
dice.append(2 * inter / (union + inter)) | |
INTER[-1].extend(large_inter.tolist() + small_inter.tolist()) | |
UNION[-1].extend(large_union.tolist() + small_union.tolist()) | |
kwargs = {"target_type": "EF", | |
"mean": mean, | |
"std": std, | |
"length": 32, | |
"period": 2, | |
"noise": noise | |
} | |
dataset = echonet.datasets.Echo(split="test", **kwargs) | |
dataloader = torch.utils.data.DataLoader(dataset, | |
batch_size=16, num_workers=5, shuffle=True, pin_memory=(device.type == "cuda")) | |
loss, yhat, y = echonet.utils.video.run_epoch(model_v, dataloader, "test", None, device) | |
mse.append(loss) | |
r2.append(sklearn.metrics.r2_score(y, yhat)) | |
Y[-1].extend(y.tolist()) | |
YHAT[-1].extend(yhat.tolist()) | |
# Save results in cache | |
with open(filename, "wb") as f: | |
pickle.dump((Y, YHAT, INTER, UNION), f) | |
# Set up plot | |
echonet.utils.latexify() | |
NOISE = list(map(lambda x: round(100 * x), NOISE)) | |
fig = plt.figure(figsize=(6.50, 4.75)) | |
gs = matplotlib.gridspec.GridSpec(3, 1, height_ratios=[2.0, 2.0, 0.75]) | |
ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2])) | |
# Plot EF prediction results (R^2) | |
r2 = [sklearn.metrics.r2_score(y, yhat) for (y, yhat) in zip(Y, YHAT)] | |
ax[0].plot(NOISE, r2, color="k", linewidth=1, marker=".") | |
ax[0].set_xticks([]) | |
ax[0].set_ylabel("R$^2$") | |
l, h = min(r2), max(r2) | |
l, h = l - 0.1 * (h - l), h + 0.1 * (h - l) | |
ax[0].axis([min(NOISE) - 5, max(NOISE) + 5, 0, 1]) | |
# Plot segmentation results (DSC) | |
dice = [echonet.utils.dice_similarity_coefficient(inter, union) for (inter, union) in zip(INTER, UNION)] | |
ax[1].plot(NOISE, dice, color="k", linewidth=1, marker=".") | |
ax[1].set_xlabel("Pixels Removed (%)") | |
ax[1].set_ylabel("DSC") | |
l, h = min(dice), max(dice) | |
l, h = l - 0.1 * (h - l), h + 0.1 * (h - l) | |
ax[1].axis([min(NOISE) - 5, max(NOISE) + 5, 0, 1]) | |
# Add example images below | |
for noise in NOISE: | |
image = matplotlib.image.imread(os.path.join(fig_root, "noise_{}.tif".format(noise))) | |
imagebox = matplotlib.offsetbox.OffsetImage(image, zoom=0.4) | |
ab = matplotlib.offsetbox.AnnotationBbox(imagebox, (noise, 0.0), frameon=False) | |
ax[2].add_artist(ab) | |
ax[2].axis("off") | |
ax[2].axis([min(NOISE) - 5, max(NOISE) + 5, -1, 1]) | |
fig.tight_layout() | |
plt.savefig(os.path.join(fig_root, "noise.pdf"), dpi=1200) | |
plt.savefig(os.path.join(fig_root, "noise.eps"), dpi=300) | |
plt.savefig(os.path.join(fig_root, "noise.png"), dpi=600) | |
plt.close(fig) | |
if __name__ == "__main__": | |
main() | |