File size: 6,344 Bytes
dde56f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3

"""Code to generate plots for Extended Data Fig. 1."""

import os

import matplotlib
import matplotlib.pyplot as plt

import echonet


def _format_losses(losses):
    """Return *losses* as one line: values to one decimal, 4-space separated.

    ``None`` entries are rendered as ``"N/A"``; a bare ``None`` element would
    make ``str.join`` raise ``TypeError``.
    """
    return "    ".join("N/A" if x is None else "{:.1f}".format(x) for x in losses)


def _draw_axis_break(ax_left, ax_right, d=0.015):
    """Draw short diagonal marks where an x-axis is broken between two axes.

    Marks go at the bottom and top of the right edge of ``ax_left`` and of the
    left edge of ``ax_right``. They are drawn in axes coordinates with
    clipping disabled, so they straddle the frame. ``d`` is the half-height of
    each mark in axes units. The half-width is ``d`` times half the ratio of
    the axis' y data range to its x data range, so call this only after the
    data and limits have been set.
    """
    for (target, edge) in ((ax_left, 1), (ax_right, 0)):
        x0, x1, y0, y1 = target.axis()
        scale = (y1 - y0) / (x1 - x0) / 2
        xs = (edge - scale * d, edge + scale * d)
        style = dict(transform=target.transAxes, color='k', clip_on=False, linewidth=1)
        target.plot(xs, (-d, +d), **style)        # bottom of the edge
        target.plot(xs, (1 - d, 1 + d), **style)  # top of the edge


def main(root=os.path.join("output", "video"),
         fig_root=os.path.join("figure", "hyperparameter"),
         FRAMES=(1, 8, 16, 32, 64, 96, None),
         PERIOD=(1, 2, 4, 6, 8)
         ):
    """Generate plots for Extended Data Fig. 1.

    The figure has three panels:

    - Panel (a) plots best validation loss against clip length.
    - Panel (b) plots best validation loss against sampling period.
    - A third panel holds only the legend.

    Panels (a) and (b) cover the ``r2plus1d_18``, ``r3d_18`` and ``mc3_18``
    models. Each model appears twice: trained from pretrained weights (solid
    lines) and from random initialization (dashed lines). The losses are also
    printed to stdout. The figure is saved as ``hyperparameter.{pdf,eps,png}``
    in ``fig_root``.

    Args:
        root: directory holding one run directory per hyperparameter setting,
            read through ``load``.
        fig_root: output directory for the figure; created if missing.
        FRAMES: clip lengths for panel (a), all sampled with period 1. The
            last entry is expected to be ``None`` ("all frames"). It is drawn
            past an x-axis break and labelled "All". A list or a tuple is
            accepted.
        PERIOD: sampling periods for panel (b); each run uses
            ``64 // period`` frames.
    """

    echonet.utils.latexify()
    os.makedirs(fig_root, exist_ok=True)

    models = ["r2plus1d_18", "r3d_18", "mc3_18"]

    # Parameters for plotting length sweep
    MAX = FRAMES[-2]  # Longest finite clip length
    START = 1    # Starting point for normal range
    TERM0 = 104  # Ending point for normal range
    BREAK = 112  # Location for break
    TERM1 = 120  # Starting point for "all" section
    ALL = 128    # Location of "all" point
    END = 135    # Ending point for "all" section
    # Width ratio of the two halves, so both show the same x scale
    RATIO = (BREAK - START) / (END - BREAK)

    # Set up figure: length plot (split around an x-axis break), period plot, legend
    fig = plt.figure(figsize=(3 + 2.5 + 1.5, 2.75))
    outer = matplotlib.gridspec.GridSpec(1, 3, width_ratios=[3, 2.5, 1.50])
    ax = plt.subplot(outer[2])   # Legend
    ax2 = plt.subplot(outer[1])  # Period plot
    gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
        1, 2, subplot_spec=outer[0], width_ratios=[RATIO, 1], wspace=0.020)  # Length plot

    # Plot legend: the NaN lines draw nothing and exist only to create entries
    for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3"],
                              matplotlib.colors.TABLEAU_COLORS):
        ax.plot([float("nan")], [float("nan")], "-", color=color, label=model)
    ax.plot([float("nan")], [float("nan")], "-", color="k", label="Pretrained")
    ax.plot([float("nan")], [float("nan")], "--", color="k", label="Random")
    ax.set_title("")
    ax.axis("off")
    ax.legend(loc="center")

    # Plot length sweep (panel a)
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1], sharey=ax0)
    # x positions / tick labels; the "all frames" entry (None) is placed at ALL
    positions = [ALL if frames is None else frames for frames in FRAMES]
    labels = ["All" if frames is None else frames for frames in FRAMES]
    print("FRAMES")
    for (model, color) in zip(models, matplotlib.colors.TABLEAU_COLORS):
        for pretrained in [True, False]:
            loss = [load(root, model, frames, 1, pretrained) for frames in FRAMES]
            print(model, pretrained)
            print(_format_losses(loss))

            # Connect the longest finite clip (MAX) to "all" (ALL) by linear
            # interpolation. The line is cut at TERM0 on the left half and
            # resumed at TERM1 on the right half, leaving a gap at the break.
            l0 = loss[-2]
            l1 = loss[-1]
            linestyle = "-" if pretrained else "--"
            ax0.plot(list(FRAMES[:-1]) + [TERM0],
                     loss[:-1] + [l0 + (l1 - l0) * (TERM0 - MAX) / (ALL - MAX)],
                     linestyle, color=color)
            ax1.plot([TERM1, ALL],
                     [l0 + (l1 - l0) * (TERM1 - MAX) / (ALL - MAX), l1],
                     linestyle, color=color)
            ax0.scatter(positions, loss, color=color, s=4)
            ax1.scatter(positions, loss, color=color, s=4)

    ax0.set_xticks(positions)
    ax1.set_xticks(positions)
    ax0.set_xticklabels(labels)
    ax1.set_xticklabels(labels)

    # Broken x-axis, following
    # https://stackoverflow.com/questions/5656798/python-matplotlib-is-there-a-way-to-make-a-discontinuous-axis/43684155
    # zoom-in / limit the view to different portions of the data
    ax0.set_xlim(START, BREAK)  # most of the data
    ax1.set_xlim(BREAK, END)    # the "All" point

    # hide the spines between ax0 and ax1
    ax0.spines['right'].set_visible(False)
    ax1.spines['left'].set_visible(False)

    ax1.get_yaxis().set_visible(False)

    # Must run after data and limits are set: the mark width depends on them
    _draw_axis_break(ax0, ax1)

    # Shift the x label right (x in axes fraction), presumably so it is
    # centered under both halves of the broken axis
    ax0.xaxis.label.set_position((0.6, 0.0))
    ax0.text(-0.05, 1.10, "(a)", transform=ax0.transAxes)
    ax0.set_xlabel("Clip length (frames)")
    ax0.set_ylabel("Validation Loss")

    # Plot period sweep (panel b)
    print("PERIOD")
    for (model, color) in zip(models, matplotlib.colors.TABLEAU_COLORS):
        for pretrained in [True, False]:
            # Frames per clip shrink as the period grows (64 // period)
            loss = [load(root, model, 64 // period, period, pretrained) for period in PERIOD]
            print(model, pretrained)
            print(_format_losses(loss))

            ax2.plot(PERIOD, loss, "-" if pretrained else "--", marker=".", color=color)
    ax2.set_xticks(PERIOD)
    ax2.text(-0.05, 1.10, "(b)", transform=ax2.transAxes)
    ax2.set_xlabel("Sampling Period (frames)")
    ax2.set_ylabel("Validation Loss")

    # Save figure
    plt.tight_layout()
    plt.savefig(os.path.join(fig_root, "hyperparameter.pdf"))
    plt.savefig(os.path.join(fig_root, "hyperparameter.eps"))
    plt.savefig(os.path.join(fig_root, "hyperparameter.png"))
    plt.close(fig)


def load(root, model, frames, period, pretrained):
    """Loads best validation loss for specified hyperparameter choice.

    Reads ``<root>/<model>_<frames>_<period>_<pretrained|random>/log.csv``.
    Returns, as a float, the token immediately after the first occurrence of
    ``"Best validation loss "`` in that file.

    Args:
        root: directory containing one subdirectory per run.
        model: model name, as used in the run directory name.
        frames: clip length, as used in the run directory name (``None`` is
            formatted as the text "None").
        period: sampling period, as used in the run directory name.
        pretrained: truthy selects the ``pretrained`` run; falsy selects
            the ``random`` run.

    Returns:
        The best validation loss as a float.

    Raises:
        OSError: if the log file cannot be opened.
        ValueError: if no line contains the marker, or the token after it is
            not a number. The message includes the log path.
    """
    marker = "Best validation loss "
    init = "pretrained" if pretrained else "random"
    path = os.path.join(
        root,
        "{}_{}_{}_{}".format(model, frames, period, init),
        "log.csv")
    with open(path, "r") as log:
        for line in log:
            if marker in line:
                # Take the token right after the marker rather than a fixed
                # whitespace-split index, so leading text is tolerated.
                return float(line.partition(marker)[2].split()[0])

    raise ValueError("File missing information: {}".format(path))


if __name__ == "__main__":
    # Script entry point: build the figure using main()'s default paths.
    main()