File size: 4,075 Bytes
dde56f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3

"""Code to generate plots for Extended Data Fig. 3."""

import argparse
import os
import matplotlib
import matplotlib.pyplot as plt

import echonet


def main():
    """Generate plots for Extended Data Fig. 3.

    Command-line arguments:
        dir: root of the training outputs (default ``output``). EF logs are read
            from ``<dir>/video/<model>_<frames>_<period>_<pretrained|random>/log.csv``
            and segmentation logs from ``<dir>/segmentation/<model>_<pretrained|random>/log.csv``.
        fig: directory the figure is written to (default ``figure/loss``); created if missing.
        --frames, --period: values used to build the EF run directory names.

    Writes ``loss.pdf``, ``loss.eps`` and ``loss.png`` into ``fig``.
    """

    # Select paths and hyperparameter to plot
    parser = argparse.ArgumentParser()
    parser.add_argument("dir", nargs="?", default="output")
    parser.add_argument("fig", nargs="?", default=os.path.join("figure", "loss"))
    parser.add_argument("--frames", type=int, default=32)
    parser.add_argument("--period", type=int, default=2)
    args = parser.parse_args()

    # Set up figure: two rows of (train, val) panels plus a narrow legend column
    echonet.utils.latexify()
    os.makedirs(args.fig, exist_ok=True)
    fig = plt.figure(figsize=(7, 5))
    gs = matplotlib.gridspec.GridSpec(ncols=3, nrows=2, figure=fig, width_ratios=[2.75, 2.75, 1.50])

    _plot_ef_losses(fig, gs, args.dir, args.frames, args.period)
    _plot_segmentation_losses(fig, gs, args.dir)
    _plot_legend(fig, gs)

    fig.tight_layout()
    for ext in ("pdf", "eps", "png"):
        fig.savefig(os.path.join(args.fig, "loss." + ext))
    plt.close(fig)


def _plot_ef_losses(fig, gs, output_dir, frames, period):
    """Plot EF-model train/val MSE curves into panels (a) and (b) (top row of ``gs``)."""
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)

    # Axis limits are tracked over *every* curve, so no model's curve is clipped.
    max_epochs = 0
    max_loss = 0.0
    for pretrained in [True]:
        style = "-" if pretrained else "--"
        for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS):
            run = "{}_{}_{}_{}".format(model, frames, period, "pretrained" if pretrained else "random")
            loss = load(os.path.join(output_dir, "video", run, "log.csv"))
            ax0.plot(range(1, 1 + len(loss["train"])), loss["train"], style, color=color)
            ax1.plot(range(1, 1 + len(loss["val"])), loss["val"], style, color=color)
            max_epochs = max(max_epochs, len(loss["train"]), len(loss["val"]))
            # Unpacking keeps this safe when a log has no entries for a split.
            max_loss = max([max_loss, *loss["train"], *loss["val"]])

    # Apply the limits explicitly to both panels rather than to the "current" axes.
    if max_epochs > 0:
        for ax in (ax0, ax1):
            ax.axis([0, max_epochs, 0, max_loss])

    ax0.text(-0.25, 1.00, "(a)", transform=ax0.transAxes)
    ax1.text(-0.25, 1.00, "(b)", transform=ax1.transAxes)
    ax0.set_xlabel("Epochs")
    ax1.set_xlabel("Epochs")
    ax0.set_xticks([0, 15, 30, 45])
    ax1.set_xticks([0, 15, 30, 45])
    ax0.set_ylabel("Training MSE Loss")
    ax1.set_ylabel("Validation MSE Loss")


def _plot_segmentation_losses(fig, gs, output_dir):
    """Plot segmentation-model train/val loss curves into panels (c) and (d) (bottom row of ``gs``)."""
    ax0 = fig.add_subplot(gs[1, 0])
    ax1 = fig.add_subplot(gs[1, 1], sharey=ax0)
    pretrained = False
    style = "-" if pretrained else "--"
    # Colors continue after the three EF models so the legend column lines up.
    for (model, color) in zip(["deeplabv3_resnet50"], list(matplotlib.colors.TABLEAU_COLORS)[3:]):
        run = "{}_{}".format(model, "pretrained" if pretrained else "random")
        loss = load(os.path.join(output_dir, "segmentation", run, "log.csv"))
        ax0.plot(range(1, 1 + len(loss["train"])), loss["train"], style, color=color)
        ax1.plot(range(1, 1 + len(loss["val"])), loss["val"], style, color=color)

    ax0.text(-0.25, 1.00, "(c)", transform=ax0.transAxes)
    ax1.text(-0.25, 1.00, "(d)", transform=ax1.transAxes)
    ax0.set_ylim([0, 0.13])
    ax0.set_xlabel("Epochs")
    ax1.set_xlabel("Epochs")
    ax0.set_xticks([0, 25, 50])
    ax1.set_xticks([0, 25, 50])
    ax0.set_ylabel("Training Cross Entropy Loss")
    ax1.set_ylabel("Validation Cross Entropy Loss")


def _plot_legend(fig, gs):
    """Draw a stand-alone legend in the right-hand column of ``gs``."""
    ax = fig.add_subplot(gs[:, 2])
    # Line styles mirror the curves: pretrained EF models solid, randomly
    # initialised segmentation model dashed.
    entries = [
        ("EchoNet-Dynamic (EF)", "-"),
        ("R3D", "-"),
        ("MC3", "-"),
        ("EchoNet-Dynamic (Seg)", "--"),
    ]
    for ((label, style), color) in zip(entries, matplotlib.colors.TABLEAU_COLORS):
        # NaN point: creates a legend handle without drawing anything visible.
        ax.plot([float("nan")], [float("nan")], style, color=color, label=label)
    ax.set_title("")
    ax.axis("off")
    ax.legend(loc="center")


def load(filename):
    """Loads losses from specified file.

    Each relevant line is comma-separated as ``epoch,split,loss,...`` with at
    least four fields. Lines with fewer than four fields are skipped. If the
    most recent epoch of a split is logged again (presumably a resumed run),
    the later value replaces the earlier one.

    Args:
        filename: path to the log file.

    Returns:
        dict with keys "train" and "val", each a list of float losses indexed by epoch.

    Raises:
        ValueError: if the split is not "train"/"val", if an epoch is neither the
            next nor the most recent one for its split, or if the epoch/loss
            fields are not numeric.
    """

    losses = {"train": [], "val": []}
    with open(filename, "r") as f:
        for line in f:
            fields = line.split(",")
            if len(fields) < 4:
                continue
            epoch, split, loss, *_ = fields
            epoch = int(epoch)
            loss = float(loss)
            # Explicit check: an `assert` would be stripped under `python -O`.
            if split not in losses:
                raise ValueError("Unexpected split {!r} in {}.".format(split, filename))
            history = losses[split]
            if epoch == len(history):
                history.append(loss)
            elif history and epoch == len(history) - 1:
                # Same epoch logged again: keep the latest value. The `history`
                # guard stops epoch -1 on an empty split from raising IndexError.
                history[-1] = loss
            else:
                raise ValueError("File has uninterpretable formatting.")
    return losses


if __name__ == "__main__":
    # Build and save the figure only when executed as a script, not on import.
    main()