Franny Dean commited on
Commit
dde56f0
·
1 Parent(s): 96b77b8
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. EchoNet-Dynamic/.DS_Store +0 -0
  3. EchoNet-Dynamic/FileList.csv +0 -0
  4. EchoNet-Dynamic/Videos/.DS_Store +0 -0
  5. EchoNet-Dynamic/Videos/0X2A09D7E5E6F9F9DF.avi +0 -0
  6. EchoNet-Dynamic/Videos/0X2A46A6A93DF181FA.avi +0 -0
  7. EchoNet-Dynamic/Videos/0X2A6635B01B13AAA4.avi +0 -0
  8. EchoNet-Dynamic/Videos/0X2A667FD468D528A2.avi +0 -0
  9. PSSL_app.py +566 -0
  10. dynamic/.DS_Store +0 -0
  11. dynamic/.gitignore +7 -0
  12. dynamic/.travis.yml +68 -0
  13. dynamic/LICENSE.txt +3 -0
  14. dynamic/README.md +97 -0
  15. dynamic/echonet/.DS_Store +0 -0
  16. dynamic/echonet/__init__.py +26 -0
  17. dynamic/echonet/__main__.py +7 -0
  18. dynamic/echonet/__version__.py +3 -0
  19. dynamic/echonet/config.py +24 -0
  20. dynamic/echonet/datasets/__init__.py +8 -0
  21. dynamic/echonet/datasets/echo.py +282 -0
  22. dynamic/echonet/utils/__init__.py +179 -0
  23. dynamic/echonet/utils/segmentation.py +498 -0
  24. dynamic/echonet/utils/video.py +361 -0
  25. dynamic/example.cfg +1 -0
  26. dynamic/requirements.txt +28 -0
  27. dynamic/scripts/ConvertDICOMToAVI.ipynb +215 -0
  28. dynamic/scripts/InitializationNotebook.ipynb +288 -0
  29. dynamic/scripts/beat_by_beat_analysis.R +100 -0
  30. dynamic/scripts/plot_complexity.py +92 -0
  31. dynamic/scripts/plot_hyperparameter_sweep.py +149 -0
  32. dynamic/scripts/plot_loss.py +106 -0
  33. dynamic/scripts/plot_simulated_noise.py +160 -0
  34. dynamic/scripts/run_experiments.sh +49 -0
  35. dynamic/setup.py +44 -0
  36. echonet/__init__.py +26 -0
  37. echonet/__main__.py +7 -0
  38. echonet/__pycache__/__init__.cpython-311.pyc +0 -0
  39. echonet/__pycache__/__version__.cpython-311.pyc +0 -0
  40. echonet/__pycache__/config.cpython-311.pyc +0 -0
  41. echonet/__version__.py +3 -0
  42. echonet/config.py +24 -0
  43. echonet/datasets/__init__.py +8 -0
  44. echonet/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
  45. echonet/datasets/__pycache__/echo.cpython-311.pyc +0 -0
  46. echonet/datasets/echo.py +282 -0
  47. echonet/utils/__init__.py +179 -0
  48. echonet/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  49. echonet/utils/__pycache__/segmentation.cpython-311.pyc +0 -0
  50. echonet/utils/__pycache__/video.cpython-311.pyc +0 -0
.DS_Store ADDED
Binary file (8.2 kB). View file
 
EchoNet-Dynamic/.DS_Store ADDED
Binary file (6.15 kB). View file
 
EchoNet-Dynamic/FileList.csv ADDED
The diff for this file is too large to render. See raw diff
 
EchoNet-Dynamic/Videos/.DS_Store ADDED
Binary file (6.15 kB). View file
 
EchoNet-Dynamic/Videos/0X2A09D7E5E6F9F9DF.avi ADDED
Binary file (549 kB). View file
 
EchoNet-Dynamic/Videos/0X2A46A6A93DF181FA.avi ADDED
Binary file (462 kB). View file
 
EchoNet-Dynamic/Videos/0X2A6635B01B13AAA4.avi ADDED
Binary file (484 kB). View file
 
EchoNet-Dynamic/Videos/0X2A667FD468D528A2.avi ADDED
Binary file (821 kB). View file
 
PSSL_app.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+ from scipy.integrate import odeint
5
+ import torch
6
+ from torch.utils import data
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from torch import nn, optim
9
+ import os
10
+ from skimage.transform import rescale, resize
11
+ from torch import nn, optim
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import Subset
14
+ from scipy.interpolate import interp1d
15
+
16
+ #for pvloop simulator:
17
+ import pandas as pd
18
+ from scipy.integrate import odeint
19
+ from scipy import interpolate
20
+ from scipy.interpolate import RegularGridInterpolator
21
+ from matplotlib import pyplot
22
+ import sys
23
+ import numpy as np
24
+ import collections
25
+ import pandas
26
+ import skimage.draw
27
+ import torchvision
28
+ import echonet
29
+
30
+ #odesolver:
31
+ from torch.storage import T
32
+ import argparse
33
+ import time
34
+
35
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
+
37
+ sequences_all = []
38
+ info_data_all = []
39
+ path = '/Users/FDean/Desktop/Physics_Informed_Transfer_Learning/EchoNet-Dynamic'
40
+ output_path = '/Users/FDean/Desktop/Physics_Informed_Transfer_Learning'
41
+
42
+ class Echo(torchvision.datasets.VisionDataset):
43
+ """EchoNet-Dynamic Dataset.
44
+ Args:
45
+ root (string): Root directory of dataset (defaults to `echonet.config.DATA_DIR`)
46
+ split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''}
47
+ target_type (string or list, optional): Type of target to use,
48
+ ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'',
49
+ ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'',
50
+ or ``SmallTrace''
51
+ Can also be a list to output a tuple with all specified target types.
52
+ The targets represent:
53
+ ``Filename'' (string): filename of video
54
+ ``EF'' (float): ejection fraction
55
+ ``EDV'' (float): end-diastolic volume
56
+ ``ESV'' (float): end-systolic volume
57
+ ``LargeIndex'' (int): index of large (diastolic) frame in video
58
+ ``SmallIndex'' (int): index of small (systolic) frame in video
59
+ ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame
60
+ ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame
61
+ ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation
62
+ value of 0 indicates pixel is outside left ventricle
63
+ 1 indicates pixel is inside left ventricle
64
+ ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation
65
+ value of 0 indicates pixel is outside left ventricle
66
+ 1 indicates pixel is inside left ventricle
67
+ Defaults to ``EF''.
68
+ mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel.
69
+ Used for normalizing the video. Defaults to 0 (video is not shifted).
70
+ std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel.
71
+ Used for normalizing the video. Defaults to 0 (video is not scaled).
72
+ length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned.
73
+ Defaults to 16.
74
+ period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken)
75
+ Defaults to 2.
76
+ max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively
77
+ long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video.
78
+ Defaults to 250.
79
+ clips (int, optional): Number of clips to sample. Main use is for test-time augmentation with random clips.
80
+ Defaults to 1.
81
+ pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation).
82
+ and a window of the original size is taken. If ``None'', no padding occurs.
83
+ Defaults to ``None''.
84
+ noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added.
85
+ Defaults to ``None''.
86
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
87
+ external_test_location (string): Path to videos to use for external testing.
88
+ """
89
+
90
+ def __init__(self, root=None,
91
+ split="train", target_type="EF",
92
+ mean=0., std=1.,
93
+ length=16, period=2,
94
+ max_length=250,
95
+ clips=1,
96
+ pad=None,
97
+ noise=None,
98
+ target_transform=None,
99
+ external_test_location=None):
100
+ if root is None:
101
+ root = path
102
+
103
+ super().__init__(root, target_transform=target_transform)
104
+
105
+ self.split = split.upper()
106
+ if not isinstance(target_type, list):
107
+ target_type = [target_type]
108
+ self.target_type = target_type
109
+ self.mean = mean
110
+ self.std = std
111
+ self.length = length
112
+ self.max_length = max_length
113
+ self.period = period
114
+ self.clips = clips
115
+ self.pad = pad
116
+ self.noise = noise
117
+ self.target_transform = target_transform
118
+ self.external_test_location = external_test_location
119
+
120
+ self.fnames, self.outcome = [], []
121
+
122
+ if self.split == "EXTERNAL_TEST":
123
+ self.fnames = sorted(os.listdir(self.external_test_location))
124
+ else:
125
+ # Load video-level labels
126
+ with open(os.path.join(self.root, "FileList.csv")) as f:
127
+ data = pandas.read_csv(f)
128
+ data["Split"].map(lambda x: x.upper())
129
+
130
+ if self.split != "ALL":
131
+ data = data[data["Split"] == self.split]
132
+
133
+ self.header = data.columns.tolist()
134
+ self.fnames = data["FileName"].tolist()
135
+ self.fnames = [fn + ".avi" for fn in self.fnames if os.path.splitext(fn)[1] == ""] # Assume avi if no suffix
136
+ self.outcome = data.values.tolist()
137
+
138
+ # Check that files are present
139
+ """
140
+ missing = set(self.fnames) - set(os.listdir(os.path.join(self.root, "Videos")))
141
+ if len(missing) != 0:
142
+ print("{} videos could not be found in {}:".format(len(missing), os.path.join(self.root, "Videos")))
143
+ for f in sorted(missing):
144
+ print("\t", f)
145
+ raise FileNotFoundError(os.path.join(self.root, "Videos", sorted(missing)[0]))
146
+ """
147
+
148
+ # Load traces
149
+ self.frames = collections.defaultdict(list)
150
+ self.trace = collections.defaultdict(_defaultdict_of_lists)
151
+
152
+ with open(os.path.join(self.root, "VolumeTracings.csv")) as f:
153
+ header = f.readline().strip().split(",")
154
+ assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]
155
+
156
+ for line in f:
157
+ filename, x1, y1, x2, y2, frame = line.strip().split(',')
158
+ x1 = float(x1)
159
+ y1 = float(y1)
160
+ x2 = float(x2)
161
+ y2 = float(y2)
162
+ frame = int(frame)
163
+ if frame not in self.trace[filename]:
164
+ self.frames[filename].append(frame)
165
+ self.trace[filename][frame].append((x1, y1, x2, y2))
166
+ for filename in self.frames:
167
+ for frame in self.frames[filename]:
168
+ self.trace[filename][frame] = np.array(self.trace[filename][frame])
169
+
170
+ # A small number of videos are missing traces; remove these videos
171
+ keep = [len(self.frames[f]) >= 2 for f in self.fnames]
172
+ self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
173
+ self.outcome = [f for (f, k) in zip(self.outcome, keep) if k]
174
+
175
+ def __getitem__(self, index):
176
+ # Find filename of video
177
+ if self.split == "EXTERNAL_TEST":
178
+ video = os.path.join(self.external_test_location, self.fnames[index])
179
+ elif self.split == "CLINICAL_TEST":
180
+ video = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index])
181
+ else:
182
+ video = os.path.join(self.root, "Videos", self.fnames[index])
183
+
184
+ # Load video into np.array
185
+ video = echonet.utils.loadvideo(video).astype(np.float32)
186
+
187
+ # Add simulated noise (black out random pixels)
188
+ # 0 represents black at this point (video has not been normalized yet)
189
+ if self.noise is not None:
190
+ n = video.shape[1] * video.shape[2] * video.shape[3]
191
+ ind = np.random.choice(n, round(self.noise * n), replace=False)
192
+ f = ind % video.shape[1]
193
+ ind //= video.shape[1]
194
+ i = ind % video.shape[2]
195
+ ind //= video.shape[2]
196
+ j = ind
197
+ video[:, f, i, j] = 0
198
+
199
+ # Apply normalization
200
+ if isinstance(self.mean, (float, int)):
201
+ video -= self.mean
202
+ else:
203
+ video -= self.mean.reshape(3, 1, 1, 1)
204
+
205
+ if isinstance(self.std, (float, int)):
206
+ video /= self.std
207
+ else:
208
+ video /= self.std.reshape(3, 1, 1, 1)
209
+
210
+ # Set number of frames
211
+ c, f, h, w = video.shape
212
+ if self.length is None:
213
+ # Take as many frames as possible
214
+ length = f // self.period
215
+ else:
216
+ # Take specified number of frames
217
+ length = self.length
218
+
219
+ if self.max_length is not None:
220
+ # Shorten videos to max_length
221
+ length = min(length, self.max_length)
222
+
223
+ if f < length * self.period:
224
+ # Pad video with frames filled with zeros if too short
225
+ # 0 represents the mean color (dark grey), since this is after normalization
226
+ video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
227
+ c, f, h, w = video.shape # pylint: disable=E0633
228
+
229
+ if self.clips == "all":
230
+ # Take all possible clips of desired length
231
+ start = np.arange(f - (length - 1) * self.period)
232
+ else:
233
+ # Take random clips from video
234
+ start = np.random.choice(f - (length - 1) * self.period, self.clips)
235
+
236
+ # Gather targets
237
+ target = []
238
+ for t in self.target_type:
239
+ key = self.fnames[index]
240
+ if t == "Filename":
241
+ target.append(self.fnames[index])
242
+ elif t == "LargeIndex":
243
+ # Traces are sorted by cross-sectional area
244
+ # Largest (diastolic) frame is last
245
+ target.append(int(self.frames[key][-1]))
246
+ elif t == "SmallIndex":
247
+ # Largest (diastolic) frame is first
248
+ target.append(int(self.frames[key][0]))
249
+ elif t == "LargeFrame":
250
+ target.append(video[:, self.frames[key][-1], :, :])
251
+ elif t == "SmallFrame":
252
+ target.append(video[:, self.frames[key][0], :, :])
253
+ elif t in ["LargeTrace", "SmallTrace"]:
254
+ if t == "LargeTrace":
255
+ t = self.trace[key][self.frames[key][-1]]
256
+ else:
257
+ t = self.trace[key][self.frames[key][0]]
258
+ x1, y1, x2, y2 = t[:, 0], t[:, 1], t[:, 2], t[:, 3]
259
+ x = np.concatenate((x1[1:], np.flip(x2[1:])))
260
+ y = np.concatenate((y1[1:], np.flip(y2[1:])))
261
+
262
+ r, c = skimage.draw.polygon(np.rint(y).astype(np.int), np.rint(x).astype(np.int), (video.shape[2], video.shape[3]))
263
+ mask = np.zeros((video.shape[2], video.shape[3]), np.float32)
264
+ mask[r, c] = 1
265
+ target.append(mask)
266
+ else:
267
+ if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST":
268
+ target.append(np.float32(0))
269
+ else:
270
+ target.append(np.float32(self.outcome[index][self.header.index(t)]))
271
+
272
+ if target != []:
273
+ target = tuple(target) if len(target) > 1 else target[0]
274
+ if self.target_transform is not None:
275
+ target = self.target_transform(target)
276
+
277
+ # Select clips from video
278
+ video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
279
+ if self.clips == 1:
280
+ video = video[0]
281
+ else:
282
+ video = np.stack(video)
283
+
284
+ if self.pad is not None:
285
+ # Add padding of zeros (mean color of videos)
286
+ # Crop of original size is taken out
287
+ # (Used as augmentation)
288
+ c, l, h, w = video.shape
289
+ temp = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype)
290
+ temp[:, :, self.pad:-self.pad, self.pad:-self.pad] = video # pylint: disable=E1130
291
+ i, j = np.random.randint(0, 2 * self.pad, 2)
292
+ video = temp[:, :, i:(i + h), j:(j + w)]
293
+
294
+ return video, target
295
+
296
+ def __len__(self):
297
+ return len(self.fnames)
298
+
299
+ def extra_repr(self) -> str:
300
+ """Additional information to add at end of __repr__."""
301
+ lines = ["Target type: {target_type}", "Split: {split}"]
302
+ return '\n'.join(lines).format(**self.__dict__)
303
+
304
+
305
+ def _defaultdict_of_lists():
306
+ """Returns a defaultdict of lists.
307
+ This is used to avoid issues with Windows (if this function is anonymous,
308
+ the Echo dataset cannot be used in a dataloader).
309
+ """
310
+
311
+ return collections.defaultdict(list)
312
+ ##
313
+ print("Done loading training data!")
314
+ # define normalization layer to make sure output xi in an interval [ai, bi]:
315
+ # define normalization layer to make sure output xi in an interval [ai, bi]:
316
+
317
+
318
+ class IntervalNormalizationLayer(torch.nn.Module):
319
+ def __init__(self):
320
+ super().__init__()
321
+ # new_output = [Tc, start_p, Emax, Emin, Rm, Ra, Vd]
322
+ self.a = torch.tensor([0.4, 0., 0.5, 0.02, 0.005, 0.0001, 4.], dtype=torch.float32) #HR in 20-200->Tc in [0.3, 4]
323
+ self.b = torch.tensor([1.7, 280., 3.5, 0.1, 0.1, 0.25, 16.], dtype=torch.float32)
324
+ #taken out (initial conditions): a: 20, 5, 50; b: 400, 20, 100
325
+ def forward(self, inputs):
326
+ sigmoid_output = torch.sigmoid(inputs)
327
+ scaled_output = sigmoid_output * (self.b - self.a) + self.a
328
+ return scaled_output
329
+
330
+ class NEW3DCNN(nn.Module):
331
+ def __init__(self, num_parameters):
332
+ super(NEW3DCNN, self).__init__()
333
+
334
+ self.conv1 = nn.Conv3d(3, 8, kernel_size=3, padding=1)
335
+ self.batchnorm1 = nn.BatchNorm3d(8)
336
+ self.conv2 = nn.Conv3d(8, 16, kernel_size=3, padding=1)
337
+ self.batchnorm2 = nn.BatchNorm3d(16)
338
+ self.conv3 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
339
+ self.batchnorm3 = nn.BatchNorm3d(32)
340
+ self.conv4 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
341
+ self.batchnorm4 = nn.BatchNorm3d(64)
342
+ self.conv5 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
343
+ self.batchnorm5 = nn.BatchNorm3d(128)
344
+ self.pool = nn.AdaptiveAvgPool3d(1)
345
+ self.fc1 = nn.Linear(128, 512)
346
+ self.fc2 = nn.Linear(512, num_parameters)
347
+ self.norm1 = IntervalNormalizationLayer()
348
+
349
+ def forward(self, x):
350
+ x = F.relu(self.batchnorm1(self.conv1(x)))
351
+ x = F.max_pool3d(x, kernel_size=2, stride=2)
352
+ x = F.relu(self.batchnorm2(self.conv2(x)))
353
+ x = F.max_pool3d(x, kernel_size=2, stride=2)
354
+ x = F.relu(self.batchnorm3(self.conv3(x)))
355
+ x = F.max_pool3d(x, kernel_size=2, stride=2)
356
+ x = F.relu(self.batchnorm4(self.conv4(x)))
357
+ x = F.max_pool3d(x, kernel_size=2, stride=2)
358
+ x = F.relu(self.batchnorm5(self.conv5(x)))
359
+ x = self.pool(x)
360
+ x = x.view(x.size(0), -1)
361
+ x = F.relu(self.fc1(x))
362
+ x = self.fc2(x)
363
+ x = self.norm1(x)
364
+
365
+ return x
366
+
367
+
368
+ # Define a neural network with one hidden layer
369
+ class Interpolator(nn.Module):
370
+ def __init__(self):
371
+ super().__init__()
372
+ self.fc1 = nn.Linear(6, 250).double()
373
+ self.fc2 = nn.Linear(250, 2).double()
374
+
375
+ def forward(self, x):
376
+ x = torch.relu(self.fc1(x))
377
+ x = self.fc2(x)
378
+ return x
379
+
380
+ # Initialize the neural network
381
+ net = Interpolator()
382
+ net.load_state_dict(torch.load('/Users/FDean/Desktop/Physics_Informed_Transfer_Learning/final_model_weights/interp6_7param_weight.pt'))
383
+ print("Done loading interpolator!")
384
+
385
+ weights_path = '/Users/FDean/Desktop/Physics_Informed_Transfer_Learning/final_model_weights/202_full_echonet_7param_Vloss_epoch_200_lr_0.001_weight_best_model.pt'
386
+ model = NEW3DCNN(num_parameters = 7)
387
+ model.load_state_dict(torch.load(weights_path))
388
+ model.to(device)
389
+
390
+ ## PV loops
391
+
392
+ #returns Plv at time t using Elastance(t) and Vlv(t)-Vd=x1
393
+ def Plv(volume, Emax, Emin, t, Tc, Vd):
394
+ return Elastance(Emax,Emin,t, Tc)*(volume - Vd)
395
+
396
+ #returns Elastance(t)
397
+ def Elastance(Emax,Emin, t, Tc):
398
+ t = t-int(t/Tc)*Tc #can remove this if only want 1st ED (and the 1st ES before)
399
+ tn = t/(0.2+0.15*Tc)
400
+ return (Emax-Emin)*1.55*(tn/0.7)**1.9/((tn/0.7)**1.9+1)*1/((tn/1.17)**21.9+1) + Emin
401
+
402
+ def solve_ODE_for_volume(Rm, Ra, Emax, Emin, Vd, Tc, start_v, t):
403
+
404
+ # the ODE from Simaan et al 2008
405
+ def heart_ode(y, t, Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc):
406
+ x1, x2, x3, x4, x5 = y #here y is a vector of 5 values (not functions), at time t, used for getting (dy/dt)(t)
407
+ P_lv = Plv(x1+Vd,Emax,Emin,t,Tc,Vd)
408
+ dydt = [r(x2-P_lv)/Rm-r(P_lv-x4)/Ra, (x3-x2)/(Rs*Cr)-r(x2-P_lv)/(Cr*Rm), (x2-x3)/(Rs*Cs)+x5/Cs, -x5/Ca+r(P_lv-x4)/(Ca*Ra), (x4-x3-Rc*x5)/Ls]
409
+ return dydt
410
+
411
+ # RELU for diodes
412
+ def r(u):
413
+ return max(u, 0.)
414
+
415
+ # Define fixed parameters
416
+ Rs = 1.0
417
+ Rc = 0.0398
418
+ Ca = 0.08
419
+ Cs = 1.33
420
+ Cr = 4.400
421
+ Ls = 0.0005
422
+ startp = 75.
423
+
424
+ # Initial conditions
425
+ start_pla = float(start_v*Elastance(Emax, Emin, 0, Tc))
426
+ start_pao = startp
427
+ start_pa = start_pao
428
+ start_qt = 0 #aortic flow is Q_T and is 0 at ED, also see Fig5 in simaan2008dynamical
429
+ y0 = [start_v, start_pla, start_pa, start_pao, start_qt]
430
+
431
+ # Solve
432
+ sol = odeint(heart_ode, y0, t, args = (Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc)) #t: list of values
433
+
434
+ # volume is the first state variable plus theoretical zero pressure volume
435
+ volumes = np.array(sol[:, 0]) + Vd
436
+
437
+ return volumes
438
+
439
+ def pvloop_simulator(Rm, Ra, Emax, Emin, Vd, Tc, start_v):
440
+
441
+
442
+ # Define initial parameters
443
+ init_Emax = Emax # 3.0 # .5 to 3.5
444
+ init_Emin = Emin # 0.04 # .02 to .1
445
+ # init_Tc = Tc # .4 # .4 to 1.7
446
+ init_Vd = Vd # 10.0 # 0 to 25
447
+
448
+ # DUMMY VOLUME
449
+ # def volume(t, Tc):
450
+ # return 50*np.sin(2 * np.pi * t*(1/Tc))+100
451
+
452
+ # SOLVE the ODE model for the VOLUME CURVE
453
+ N = 100
454
+ t = np.linspace(0, Tc*N, int(60000*N)) #np.linspace(1, 100, 1000000)
455
+ volumes = solve_ODE_for_volume(Rm, Ra, Emax, Emin, Vd, Tc, start_v, t)
456
+
457
+ # FUNCTIONS for PRESSURE CURVE
458
+ vectorized_Elastance = np.vectorize(Elastance)
459
+ vectorized_Plv = np.vectorize(Plv)
460
+
461
+ def pressure(t, volume, Emax, Emin, Tc, Vd):
462
+ return vectorized_Plv(volume, Emax, Emin, t, Tc, Vd)
463
+
464
+ # calculate PRESSURE
465
+ pressures = pressure(t, volumes, init_Emax, init_Emin, Tc, init_Vd)
466
+
467
+ # Create the figure and the loop that we will manipulate
468
+ fig, ax = plt.subplots()
469
+ plt.ylim((0,280))
470
+ plt.xlim((0,280))
471
+ line = ax.plot(volumes[(N-2)*60000:(N)*60000], pressures[(N-2)*60000:(N)*60000], lw=1)
472
+ #print(line)
473
+ line = line[0]
474
+ #print(line)
475
+
476
+ fig.suptitle('Predicted PI-SSL LV Pressure Volume Loop', fontsize=16)
477
+ #plt.rcParams['fig.suptitle'] = -2.0
478
+ #ax.set_title(f'Mitral valve circuit resistance (Rm): {Rm} mmHg*s/ml \n Aortic valve circuit resistance (Ra): {Ra} mmHg*s/ml', fontsize=6)
479
+ ax.set_xlabel('LV Volume (ml)')
480
+ ax.set_ylabel('LV Pressure (mmHg)')
481
+
482
+ # adjust the main plot to make room for the sliders
483
+ fig.subplots_adjust(left=0.25, bottom=0.25)
484
+
485
+ return plt, Rm, Ra, Emax, Emin, Vd, Tc, start_v
486
+
487
+ def pvloop_simulator_plot_only(Rm, Ra, Emax, Emin, Vd, Tc, start_v):
488
+ plot,_,_,_,_,_,_,_ =pvloop_simulator(Rm, Ra, Emax, Emin, Vd, Tc, start_v)
489
+ return plot
490
+
491
+ ## Demo
492
+
493
+ def generate_example():
494
+ # get random input
495
+ data_path = '/Users/FDean/Desktop/Physics_Informed_Transfer_Learning/EchoNet-Dynamic'
496
+ image_data = Echo(root = data_path, split = 'all', target_type=['Filename','LargeIndex','SmallIndex'])
497
+ image_loaded_data = DataLoader(image_data, batch_size=1, shuffle=True)
498
+ val_data = next(iter(image_loaded_data))
499
+ #create_echo_clip(val_data,'test')
500
+ val_seq = val_data[0]
501
+ filename = val_data[1][0][0]
502
+ video = os.path.join(os.getcwd(), f"EchoNet-Dynamic/Videos/{filename}")
503
+ val_tensor = torch.tensor(val_seq, dtype=torch.float32)
504
+ results = model(val_tensor)
505
+
506
+ plot, Rm, Ra, Emax, Emin, Vd,Tc, start_v = pvloop_simulator(Rm=round(results[0][4].item(),2), Ra=round(results[0][5].item(),2), Emax=results[0][2].item(), Emin=round(results[0][3].item(),2), Vd=round(results[0][6].item(),2), Tc=round(results[0][0].item(),2), start_v=round(results[0][1].item(),2))
507
+
508
+ return video, plot, Rm, Ra, Emax, Emin, Vd, Tc, start_v
509
+
510
+ title = "Physics-informed self-supervised learning for predicting cardiac digital twins with echocardiography"
511
+
512
+ description = """
513
+ <p style='text-align: center'> Keying Kuang, Frances Dean, Jack B. Jedlicki, David Ouyang, Anthony Philippakis, David Sontag, Ahmed Alaa <br>
514
+ <a href='https://github.com/AlaaLab/CardioPINN' target='_blank'>Code</a></p>
515
+ We develop methodology for predicting digital twins from non-invasive cardiac ultrasound images in <a href='https://arxiv.org/abs/2403.00177'>Non-Invasive Medical Digital Twins using Physics-Informed Self-Supervised Learning</a>. \n\n
516
+ We demonstrate the ability of our model to predict left ventricular pressure-volume loops using image data here.
517
+ """
518
+
519
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
520
+ gr.Markdown(description)
521
+
522
+ with gr.Blocks() as demo:
523
+
524
+ # text
525
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
526
+ gr.Markdown(description)
527
+
528
+ with gr.Row():
529
+ with gr.Column(scale=1.5, min_width=100):
530
+
531
+ generate_button = gr.Button("Load sample echocardiogram and generate result")
532
+ with gr.Row():
533
+ video = gr.PlayableVideo(format="avi")
534
+ plot = gr.Plot()
535
+
536
+ with gr.Row():
537
+ Rm = gr.Number(label="Mitral valve circuit resistance (Rm) mmHg*s/ml:")
538
+ Ra = gr.Number(label="Aortic valve circuit resistance (Ra) mmHg*s/ml:")
539
+ Emax = gr.Number(label="Maximum elastance (Emax) mmHg/ml:")
540
+ Emin = gr.Number(label="Minimum elastance (Emin) mmHg/ml:")
541
+ Vd = gr.Number(label="Theoretical zero pressure volume (Vd) ml:")
542
+ Tc = gr.Number(label="Cycle duration (Tc) s:")
543
+ start_v = gr.Number(label="Initial volume (start_v) ml:")
544
+
545
+ simulation_button = gr.Button("Run simulation")
546
+
547
+
548
+
549
+ with gr.Row():
550
+ sl1 = gr.Slider(0.005, 0.1, value=Rm, label="Rm")
551
+ sl2 = gr.Slider(0.0001, 0.25, value=Ra, label="Ra")
552
+ sl3 = gr.Slider(0.5, 3.5, value=Emax, label="Emax")
553
+ sl4 = gr.Slider(0.02, 0.1, value= Emin, label="Emin")
554
+ sl5 = gr.Slider(4.0, 25.0, value=Vd, label="Vd")
555
+ sl6 = gr.Slider(0.4, 1.7, value=Tc, label="Tc")
556
+ sl7 = gr.Slider(0.0, 280.0, value=start_v, label="start_v")
557
+
558
+
559
+ generate_button.click(fn=generate_example, outputs = [video,plot,Rm,Ra,Emax,Emin,Vd,Tc,start_v])
560
+
561
+
562
+ simulation_button.click(fn=pvloop_simulator_plot_only, inputs = [sl1,sl2,sl3,sl4,sl5,sl6,sl7], outputs = [gr.Plot()])
563
+
564
+
565
+
566
+ demo.launch(share=True)
dynamic/.DS_Store ADDED
Binary file (6.15 kB). View file
 
dynamic/.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .ipynb_checkpoints/
2
+ __pycache__/
3
+ *.swp
4
+ echonet.cfg
5
+ .echonet.cfg
6
+ *.pyc
7
+ echonet.egg-info/
dynamic/.travis.yml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ language: minimal
2
+
3
+ os:
4
+ - linux
5
+
6
+ env:
7
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.2 (torchvision 0.2 does not have VisionDataset)
8
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.3 (torchvision 0.3 has a cuda issue)
9
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.4
10
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.5
11
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.2
12
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.3
13
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.4
14
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.5
15
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.2
16
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.3
17
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.4
18
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.5
19
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.2
20
+ # - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.3
21
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.4
22
+ - PYTHON_VERSION=3.6 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.5
23
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.2
24
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.3
25
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.4
26
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.1 TORCHVISION_VERSION=0.5
27
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.2
28
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.3
29
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.4
30
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.2 TORCHVISION_VERSION=0.5
31
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.2
32
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.3
33
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.4
34
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.3 TORCHVISION_VERSION=0.5
35
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.2
36
+ # - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.3
37
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.4
38
+ - PYTHON_VERSION=3.7 PYTORCH_VERSION=1.4 TORCHVISION_VERSION=0.5
39
+
40
+ install:
41
+ - if [[ "$TRAVIS_OS_NAME" == "linux" ]];
42
+ then
43
+ MINICONDA_OS=Linux;
44
+ sudo apt-get update;
45
+ else
46
+ MINICONDA_OS=MacOSX;
47
+ brew update;
48
+ fi
49
+ - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-${MINICONDA_OS}-x86_64.sh -O miniconda.sh
50
+ - bash miniconda.sh -b -p $HOME/miniconda
51
+ - source "$HOME/miniconda/etc/profile.d/conda.sh"
52
+ - hash -r
53
+ - conda config --set always_yes yes --set changeps1 no
54
+ - conda update -q conda
55
+ # Useful for debugging any issues with conda
56
+ - conda info -a
57
+ - conda search pytorch || true
58
+
59
+ - conda create -q -n test-environment python=${PYTHON_VERSION} pytorch=${PYTORCH_VERSION}
60
+ - conda activate test-environment
61
+ - pip install -q torchvision==${TORCHVISION_VERSION} "pillow<7.0.0"
62
+ - pip install -q .
63
+ - pip install -q flake8 pylint
64
+
65
+ script:
66
+ - flake8 --ignore=E501
67
+ - pylint --disable=C0103,C0301,R0401,R0801,R0902,R0912,R0913,R0914,R0915 --extension-pkg-whitelist=cv2,torch --generated-members=torch.* echonet/ scripts/*.py setup.py
68
+ - python -c "import echonet"
dynamic/LICENSE.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Copyright Notice
2
+ The authors are the proprietor of certain copyrights of and to EchoNet-Dynamic software, source code and associated material. Code also contains source code created by certain third parties. Redistribution and use of the Code with or without modification is not permitted without explicit written permission by the authors.
3
+ Copyright 2019 The authors. All rights reserved.
dynamic/README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EchoNet-Dynamic:<br/>Interpretable AI for beat-to-beat cardiac function assessment
2
+ ------------------------------------------------------------------------------
3
+
4
+ EchoNet-Dynamic is a end-to-end beat-to-beat deep learning model for
5
+ 1) semantic segmentation of the left ventricle
6
+ 2) prediction of ejection fraction by entire video or subsampled clips, and
7
+ 3) assessment of cardiomyopathy with reduced ejection fraction.
8
+
9
+ For more details, see the accompanying paper,
10
+
11
+ > [**Video-based AI for beat-to-beat assessment of cardiac function**](https://www.nature.com/articles/s41586-020-2145-8)<br/>
12
+ David Ouyang, Bryan He, Amirata Ghorbani, Neal Yuan, Joseph Ebinger, Curt P. Langlotz, Paul A. Heidenreich, Robert A. Harrington, David H. Liang, Euan A. Ashley, and James Y. Zou. <b>Nature</b>, March 25, 2020. https://doi.org/10.1038/s41586-020-2145-8
13
+
14
+ Dataset
15
+ -------
16
+ We share a deidentified set of 10,030 echocardiogram images which were used for training EchoNet-Dynamic.
17
+ Preprocessing of these images, including deidentification and conversion from DICOM format to AVI format videos, were performed with OpenCV and pydicom. Additional information is at https://echonet.github.io/dynamic/. These deidentified images are shared with a non-commerical data use agreement.
18
+
19
+ Examples
20
+ --------
21
+
22
+ We show examples of our semantic segmentation for nine distinct patients below.
23
+ Three patients have normal cardiac function, three have low ejection fractions, and three have arrhythmia.
24
+ No human tracings for these patients were used by EchoNet-Dynamic.
25
+
26
+ | Normal | Low Ejection Fraction | Arrhythmia |
27
+ | ------ | --------------------- | ---------- |
28
+ | ![](docs/media/0X10A28877E97DF540.gif) | ![](docs/media/0X129133A90A61A59D.gif) | ![](docs/media/0X132C1E8DBB715D1D.gif) |
29
+ | ![](docs/media/0X1167650B8BEFF863.gif) | ![](docs/media/0X13CE2039E2D706A.gif ) | ![](docs/media/0X18BA5512BE5D6FFA.gif) |
30
+ | ![](docs/media/0X148FFCBF4D0C398F.gif) | ![](docs/media/0X16FC9AA0AD5D8136.gif) | ![](docs/media/0X1E12EEE43FD913E5.gif) |
31
+
32
+ Installation
33
+ ------------
34
+
35
+ First, clone this repository and enter the directory by running:
36
+
37
+ git clone https://github.com/echonet/dynamic.git
38
+ cd dynamic
39
+
40
+ EchoNet-Dynamic is implemented for Python 3, and depends on the following packages:
41
+ - NumPy
42
+ - PyTorch
43
+ - Torchvision
44
+ - OpenCV
45
+ - skimage
46
+ - sklearn
47
+ - tqdm
48
+
49
+ Echonet-Dynamic and its dependencies can be installed by navigating to the cloned directory and running
50
+
51
+ pip install --user .
52
+
53
+ Usage
54
+ -----
55
+ ### Preprocessing DICOM Videos
56
+
57
+ The input of EchoNet-Dynamic is an apical-4-chamber view echocardiogram video of any length. The easiest way to run our code is to use videos from our dataset, but we also provide a Jupyter Notebook, `ConvertDICOMToAVI.ipynb`, to convert DICOM files to AVI files used for input to EchoNet-Dynamic. The Notebook deidentifies the video by cropping out information outside of the ultrasound sector, resizes the input video, and saves the video in AVI format.
58
+
59
+ ### Setting Path to Data
60
+
61
+ By default, EchoNet-Dynamic assumes that a copy of the data is saved in a folder named `a4c-video-dir/` in this directory.
62
+ This path can be changed by creating a configuration file named `echonet.cfg` (an example configuration file is `example.cfg`).
63
+
64
+ ### Running Code
65
+
66
+ EchoNet-Dynamic has three main components: segmenting the left ventricle, predicting ejection fraction from subsampled clips, and assessing cardiomyopathy with beat-by-beat predictions.
67
+ Each of these components can be run with reasonable choices of hyperparameters with the scripts below.
68
+ We describe our full hyperparameter sweep in the next section.
69
+
70
+ #### Frame-by-frame Semantic Segmentation of the Left Ventricle
71
+
72
+ echonet segmentation --save_video
73
+
74
+ This creates a directory named `output/segmentation/deeplabv3_resnet50_random/`, which will contain
75
+ - log.csv: training and validation losses
76
+ - best.pt: checkpoint of weights for the model with the lowest validation loss
77
+ - size.csv: estimated size of left ventricle for each frame and indicator for beginning of beat
78
+ - videos: directory containing videos with segmentation overlay
79
+
80
+ #### Prediction of Ejection Fraction from Subsampled Clips
81
+
82
+ echonet video
83
+
84
+ This creates a directory named `output/video/r2plus1d_18_32_2_pretrained/`, which will contain
85
+ - log.csv: training and validation losses
86
+ - best.pt: checkpoint of weights for the model with the lowest validation loss
87
+ - test_predictions.csv: ejection fraction prediction for subsampled clips
88
+
89
+ #### Beat-by-beat Prediction of Ejection Fraction from Full Video and Assesment of Cardiomyopathy
90
+
91
+ The final beat-by-beat prediction and analysis is performed with `scripts/beat_analysis.R`.
92
+ This script combines the results from segmentation output in `size.csv` and the clip-level ejection fraction prediction in `test_predictions.csv`. The beginning of each systolic phase is detected by using the peak detection algorithm from scipy (`scipy.signal.find_peaks`) and a video clip centered around the beat is used for beat-by-beat prediction.
93
+
94
+ ### Hyperparameter Sweeps
95
+
96
+ The full set of hyperparameter sweeps from the paper can be run via `run_experiments.sh`.
97
+ In particular, we choose between pretrained and random initialization for the weights, the model (selected from `r2plus1d_18`, `r3d_18`, and `mc3_18`), the length of the video (1, 4, 8, 16, 32, 64, and 96 frames), and the sampling period (1, 2, 4, 6, and 8 frames).
dynamic/echonet/.DS_Store ADDED
Binary file (6.15 kB). View file
 
dynamic/echonet/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The echonet package contains code for loading echocardiogram videos, and
3
+ functions for training and testing segmentation and ejection fraction
4
+ prediction models.
5
+ """
6
+
7
+ import click
8
+
9
+ from echonet.__version__ import __version__
10
+ from echonet.config import CONFIG as config
11
+ import echonet.datasets as datasets
12
+ import echonet.utils as utils
13
+
14
+
15
+ @click.group()
16
+ def main():
17
+ """Entry point for command line interface."""
18
+
19
+
20
+ del click
21
+
22
+
23
+ main.add_command(utils.segmentation.run)
24
+ main.add_command(utils.video.run)
25
+
26
+ __all__ = ["__version__", "config", "datasets", "main", "utils"]
dynamic/echonet/__main__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Entry point for command line."""
2
+
3
+ import echonet
4
+
5
+
6
+ if __name__ == '__main__':
7
+ echonet.main()
dynamic/echonet/__version__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Version number for Echonet package."""
2
+
3
+ __version__ = "1.0.0"
dynamic/echonet/config.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sets paths based on configuration files."""
2
+
3
+ import configparser
4
+ import os
5
+ import types
6
+
7
+ _FILENAME = None
8
+ _PARAM = {}
9
+ for filename in ["echonet.cfg",
10
+ ".echonet.cfg",
11
+ os.path.expanduser("~/echonet.cfg"),
12
+ os.path.expanduser("~/.echonet.cfg"),
13
+ ]:
14
+ if os.path.isfile(filename):
15
+ _FILENAME = filename
16
+ config = configparser.ConfigParser()
17
+ with open(filename, "r") as f:
18
+ config.read_string("[config]\n" + f.read())
19
+ _PARAM = config["config"]
20
+ break
21
+
22
+ CONFIG = types.SimpleNamespace(
23
+ FILENAME=_FILENAME,
24
+ DATA_DIR=_PARAM.get("data_dir", "a4c-video-dir/"))
dynamic/echonet/datasets/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The echonet.datasets submodule defines a Pytorch dataset for loading
3
+ echocardiogram videos.
4
+ """
5
+
6
+ from .echo import Echo
7
+
8
+ __all__ = ["Echo"]
dynamic/echonet/datasets/echo.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """EchoNet-Dynamic Dataset."""
2
+
3
+ import os
4
+ import collections
5
+ import pandas
6
+
7
+ import numpy as np
8
+ import skimage.draw
9
+ import torchvision
10
+ import echonet
11
+
12
+
13
+ class Echo(torchvision.datasets.VisionDataset):
14
+ """EchoNet-Dynamic Dataset.
15
+
16
+ Args:
17
+ root (string): Root directory of dataset (defaults to `echonet.config.DATA_DIR`)
18
+ split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''}
19
+ target_type (string or list, optional): Type of target to use,
20
+ ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'',
21
+ ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'',
22
+ or ``SmallTrace''
23
+ Can also be a list to output a tuple with all specified target types.
24
+ The targets represent:
25
+ ``Filename'' (string): filename of video
26
+ ``EF'' (float): ejection fraction
27
+ ``EDV'' (float): end-diastolic volume
28
+ ``ESV'' (float): end-systolic volume
29
+ ``LargeIndex'' (int): index of large (diastolic) frame in video
30
+ ``SmallIndex'' (int): index of small (systolic) frame in video
31
+ ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame
32
+ ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame
33
+ ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation
34
+ value of 0 indicates pixel is outside left ventricle
35
+ 1 indicates pixel is inside left ventricle
36
+ ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation
37
+ value of 0 indicates pixel is outside left ventricle
38
+ 1 indicates pixel is inside left ventricle
39
+ Defaults to ``EF''.
40
+ mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel.
41
+ Used for normalizing the video. Defaults to 0 (video is not shifted).
42
+ std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel.
43
+ Used for normalizing the video. Defaults to 0 (video is not scaled).
44
+ length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned.
45
+ Defaults to 16.
46
+ period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken)
47
+ Defaults to 2.
48
+ max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively
49
+ long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video.
50
+ Defaults to 250.
51
+ clips (int, optional): Number of clips to sample. Main use is for test-time augmentation with random clips.
52
+ Defaults to 1.
53
+ pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation).
54
+ and a window of the original size is taken. If ``None'', no padding occurs.
55
+ Defaults to ``None''.
56
+ noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added.
57
+ Defaults to ``None''.
58
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
59
+ external_test_location (string): Path to videos to use for external testing.
60
+ """
61
+
62
+ def __init__(self, root=None,
63
+ split="train", target_type="EF",
64
+ mean=0., std=1.,
65
+ length=16, period=2,
66
+ max_length=250,
67
+ clips=1,
68
+ pad=None,
69
+ noise=None,
70
+ target_transform=None,
71
+ external_test_location=None):
72
+ if root is None:
73
+ root = echonet.config.DATA_DIR
74
+
75
+ super().__init__(root, target_transform=target_transform)
76
+
77
+ self.split = split.upper()
78
+ if not isinstance(target_type, list):
79
+ target_type = [target_type]
80
+ self.target_type = target_type
81
+ self.mean = mean
82
+ self.std = std
83
+ self.length = length
84
+ self.max_length = max_length
85
+ self.period = period
86
+ self.clips = clips
87
+ self.pad = pad
88
+ self.noise = noise
89
+ self.target_transform = target_transform
90
+ self.external_test_location = external_test_location
91
+
92
+ self.fnames, self.outcome = [], []
93
+
94
+ if self.split == "EXTERNAL_TEST":
95
+ self.fnames = sorted(os.listdir(self.external_test_location))
96
+ else:
97
+ # Load video-level labels
98
+ with open(os.path.join(self.root, "FileList.csv")) as f:
99
+ data = pandas.read_csv(f)
100
+ data["Split"].map(lambda x: x.upper())
101
+
102
+ if self.split != "ALL":
103
+ data = data[data["Split"] == self.split]
104
+
105
+ self.header = data.columns.tolist()
106
+ self.fnames = data["FileName"].tolist()
107
+ self.fnames = [fn + ".avi" for fn in self.fnames if os.path.splitext(fn)[1] == ""] # Assume avi if no suffix
108
+ self.outcome = data.values.tolist()
109
+
110
+ # Check that files are present
111
+ missing = set(self.fnames) - set(os.listdir(os.path.join(self.root, "Videos")))
112
+ if len(missing) != 0:
113
+ print("{} videos could not be found in {}:".format(len(missing), os.path.join(self.root, "Videos")))
114
+ for f in sorted(missing):
115
+ print("\t", f)
116
+ raise FileNotFoundError(os.path.join(self.root, "Videos", sorted(missing)[0]))
117
+
118
+ # Load traces
119
+ self.frames = collections.defaultdict(list)
120
+ self.trace = collections.defaultdict(_defaultdict_of_lists)
121
+
122
+ with open(os.path.join(self.root, "VolumeTracings.csv")) as f:
123
+ header = f.readline().strip().split(",")
124
+ assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]
125
+
126
+ for line in f:
127
+ filename, x1, y1, x2, y2, frame = line.strip().split(',')
128
+ x1 = float(x1)
129
+ y1 = float(y1)
130
+ x2 = float(x2)
131
+ y2 = float(y2)
132
+ frame = int(frame)
133
+ if frame not in self.trace[filename]:
134
+ self.frames[filename].append(frame)
135
+ self.trace[filename][frame].append((x1, y1, x2, y2))
136
+ for filename in self.frames:
137
+ for frame in self.frames[filename]:
138
+ self.trace[filename][frame] = np.array(self.trace[filename][frame])
139
+
140
+ # A small number of videos are missing traces; remove these videos
141
+ keep = [len(self.frames[f]) >= 2 for f in self.fnames]
142
+ self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
143
+ self.outcome = [f for (f, k) in zip(self.outcome, keep) if k]
144
+
145
+ def __getitem__(self, index):
146
+ # Find filename of video
147
+ if self.split == "EXTERNAL_TEST":
148
+ video = os.path.join(self.external_test_location, self.fnames[index])
149
+ elif self.split == "CLINICAL_TEST":
150
+ video = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index])
151
+ else:
152
+ video = os.path.join(self.root, "Videos", self.fnames[index])
153
+
154
+ # Load video into np.array
155
+ video = echonet.utils.loadvideo(video).astype(np.float32)
156
+
157
+ # Add simulated noise (black out random pixels)
158
+ # 0 represents black at this point (video has not been normalized yet)
159
+ if self.noise is not None:
160
+ n = video.shape[1] * video.shape[2] * video.shape[3]
161
+ ind = np.random.choice(n, round(self.noise * n), replace=False)
162
+ f = ind % video.shape[1]
163
+ ind //= video.shape[1]
164
+ i = ind % video.shape[2]
165
+ ind //= video.shape[2]
166
+ j = ind
167
+ video[:, f, i, j] = 0
168
+
169
+ # Apply normalization
170
+ if isinstance(self.mean, (float, int)):
171
+ video -= self.mean
172
+ else:
173
+ video -= self.mean.reshape(3, 1, 1, 1)
174
+
175
+ if isinstance(self.std, (float, int)):
176
+ video /= self.std
177
+ else:
178
+ video /= self.std.reshape(3, 1, 1, 1)
179
+
180
+ # Set number of frames
181
+ c, f, h, w = video.shape
182
+ if self.length is None:
183
+ # Take as many frames as possible
184
+ length = f // self.period
185
+ else:
186
+ # Take specified number of frames
187
+ length = self.length
188
+
189
+ if self.max_length is not None:
190
+ # Shorten videos to max_length
191
+ length = min(length, self.max_length)
192
+
193
+ if f < length * self.period:
194
+ # Pad video with frames filled with zeros if too short
195
+ # 0 represents the mean color (dark grey), since this is after normalization
196
+ video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
197
+ c, f, h, w = video.shape # pylint: disable=E0633
198
+
199
+ if self.clips == "all":
200
+ # Take all possible clips of desired length
201
+ start = np.arange(f - (length - 1) * self.period)
202
+ else:
203
+ # Take random clips from video
204
+ start = np.random.choice(f - (length - 1) * self.period, self.clips)
205
+
206
+ # Gather targets
207
+ target = []
208
+ for t in self.target_type:
209
+ key = self.fnames[index]
210
+ if t == "Filename":
211
+ target.append(self.fnames[index])
212
+ elif t == "LargeIndex":
213
+ # Traces are sorted by cross-sectional area
214
+ # Largest (diastolic) frame is last
215
+ target.append(np.int(self.frames[key][-1]))
216
+ elif t == "SmallIndex":
217
+ # Largest (diastolic) frame is first
218
+ target.append(np.int(self.frames[key][0]))
219
+ elif t == "LargeFrame":
220
+ target.append(video[:, self.frames[key][-1], :, :])
221
+ elif t == "SmallFrame":
222
+ target.append(video[:, self.frames[key][0], :, :])
223
+ elif t in ["LargeTrace", "SmallTrace"]:
224
+ if t == "LargeTrace":
225
+ t = self.trace[key][self.frames[key][-1]]
226
+ else:
227
+ t = self.trace[key][self.frames[key][0]]
228
+ x1, y1, x2, y2 = t[:, 0], t[:, 1], t[:, 2], t[:, 3]
229
+ x = np.concatenate((x1[1:], np.flip(x2[1:])))
230
+ y = np.concatenate((y1[1:], np.flip(y2[1:])))
231
+
232
+ r, c = skimage.draw.polygon(np.rint(y).astype(np.int), np.rint(x).astype(np.int), (video.shape[2], video.shape[3]))
233
+ mask = np.zeros((video.shape[2], video.shape[3]), np.float32)
234
+ mask[r, c] = 1
235
+ target.append(mask)
236
+ else:
237
+ if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST":
238
+ target.append(np.float32(0))
239
+ else:
240
+ target.append(np.float32(self.outcome[index][self.header.index(t)]))
241
+
242
+ if target != []:
243
+ target = tuple(target) if len(target) > 1 else target[0]
244
+ if self.target_transform is not None:
245
+ target = self.target_transform(target)
246
+
247
+ # Select clips from video
248
+ video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
249
+ if self.clips == 1:
250
+ video = video[0]
251
+ else:
252
+ video = np.stack(video)
253
+
254
+ if self.pad is not None:
255
+ # Add padding of zeros (mean color of videos)
256
+ # Crop of original size is taken out
257
+ # (Used as augmentation)
258
+ c, l, h, w = video.shape
259
+ temp = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype)
260
+ temp[:, :, self.pad:-self.pad, self.pad:-self.pad] = video # pylint: disable=E1130
261
+ i, j = np.random.randint(0, 2 * self.pad, 2)
262
+ video = temp[:, :, i:(i + h), j:(j + w)]
263
+
264
+ return video, target
265
+
266
+ def __len__(self):
267
+ return len(self.fnames)
268
+
269
+ def extra_repr(self) -> str:
270
+ """Additional information to add at end of __repr__."""
271
+ lines = ["Target type: {target_type}", "Split: {split}"]
272
+ return '\n'.join(lines).format(**self.__dict__)
273
+
274
+
275
+ def _defaultdict_of_lists():
276
+ """Returns a defaultdict of lists.
277
+
278
+ This is used to avoid issues with Windows (if this function is anonymous,
279
+ the Echo dataset cannot be used in a dataloader).
280
+ """
281
+
282
+ return collections.defaultdict(list)
dynamic/echonet/utils/__init__.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for videos, plotting and computing performance metrics."""
2
+
3
+ import os
4
+ import typing
5
+
6
+ import cv2 # pytype: disable=attribute-error
7
+ import matplotlib
8
+ import numpy as np
9
+ import torch
10
+ import tqdm
11
+
12
+ from . import video
13
+ from . import segmentation
14
+
15
+
16
+ def loadvideo(filename: str) -> np.ndarray:
17
+ """Loads a video from a file.
18
+
19
+ Args:
20
+ filename (str): filename of video
21
+
22
+ Returns:
23
+ A np.ndarray with dimensions (channels=3, frames, height, width). The
24
+ values will be uint8's ranging from 0 to 255.
25
+
26
+ Raises:
27
+ FileNotFoundError: Could not find `filename`
28
+ ValueError: An error occurred while reading the video
29
+ """
30
+
31
+ if not os.path.exists(filename):
32
+ raise FileNotFoundError(filename)
33
+ capture = cv2.VideoCapture(filename)
34
+
35
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
36
+ frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
37
+ frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
38
+
39
+ v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
40
+
41
+ for count in range(frame_count):
42
+ ret, frame = capture.read()
43
+ if not ret:
44
+ raise ValueError("Failed to load frame #{} of {}.".format(count, filename))
45
+
46
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
47
+ v[count, :, :] = frame
48
+
49
+ v = v.transpose((3, 0, 1, 2))
50
+
51
+ return v
52
+
53
+
54
+ def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
55
+ """Saves a video to a file.
56
+
57
+ Args:
58
+ filename (str): filename of video
59
+ array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
60
+ fps (float or int): frames per second
61
+
62
+ Returns:
63
+ None
64
+ """
65
+
66
+ c, _, height, width = array.shape
67
+
68
+ if c != 3:
69
+ raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))
70
+ fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
71
+ out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
72
+
73
+ for frame in array.transpose((1, 2, 3, 0)):
74
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
75
+ out.write(frame)
76
+
77
+
78
+ def get_mean_and_std(dataset: torch.utils.data.Dataset,
79
+ samples: int = 128,
80
+ batch_size: int = 8,
81
+ num_workers: int = 4):
82
+ """Computes mean and std from samples from a Pytorch dataset.
83
+
84
+ Args:
85
+ dataset (torch.utils.data.Dataset): A Pytorch dataset.
86
+ ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
87
+ should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
88
+ samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
89
+ standard deviation are computed over all elements.
90
+ Defaults to 128.
91
+ batch_size (int, optional): how many samples per batch to load
92
+ Defaults to 8.
93
+ num_workers (int, optional): how many subprocesses to use for data
94
+ loading. If 0, the data will be loaded in the main process.
95
+ Defaults to 4.
96
+
97
+ Returns:
98
+ A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,).
99
+ """
100
+
101
+ if samples is not None and len(dataset) > samples:
102
+ indices = np.random.choice(len(dataset), samples, replace=False)
103
+ dataset = torch.utils.data.Subset(dataset, indices)
104
+ dataloader = torch.utils.data.DataLoader(
105
+ dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
106
+
107
+ n = 0 # number of elements taken (should be equal to samples by end of for loop)
108
+ s1 = 0. # sum of elements along channels (ends up as np.array of dimension (channels,))
109
+ s2 = 0. # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
110
+ for (x, *_) in tqdm.tqdm(dataloader):
111
+ x = x.transpose(0, 1).contiguous().view(3, -1)
112
+ n += x.shape[1]
113
+ s1 += torch.sum(x, dim=1).numpy()
114
+ s2 += torch.sum(x ** 2, dim=1).numpy()
115
+ mean = s1 / n # type: np.ndarray
116
+ std = np.sqrt(s2 / n - mean ** 2) # type: np.ndarray
117
+
118
+ mean = mean.astype(np.float32)
119
+ std = std.astype(np.float32)
120
+
121
+ return mean, std
122
+
123
+
124
+ def bootstrap(a, b, func, samples=10000):
125
+ """Computes a bootstrapped confidence intervals for ``func(a, b)''.
126
+
127
+ Args:
128
+ a (array_like): first argument to `func`.
129
+ b (array_like): second argument to `func`.
130
+ func (callable): Function to compute confidence intervals for.
131
+ ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
132
+ should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
133
+ samples (int, optional): Number of samples to compute.
134
+ Defaults to 10000.
135
+
136
+ Returns:
137
+ A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).
138
+ """
139
+ a = np.array(a)
140
+ b = np.array(b)
141
+
142
+ bootstraps = []
143
+ for _ in range(samples):
144
+ ind = np.random.choice(len(a), len(a))
145
+ bootstraps.append(func(a[ind], b[ind]))
146
+ bootstraps = sorted(bootstraps)
147
+
148
+ return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))]
149
+
150
+
151
+ def latexify():
152
+ """Sets matplotlib params to appear more like LaTeX.
153
+
154
+ Based on https://nipunbatra.github.io/blog/2014/latexify.html
155
+ """
156
+ params = {'backend': 'pdf',
157
+ 'axes.titlesize': 8,
158
+ 'axes.labelsize': 8,
159
+ 'font.size': 8,
160
+ 'legend.fontsize': 8,
161
+ 'xtick.labelsize': 8,
162
+ 'ytick.labelsize': 8,
163
+ 'font.family': 'DejaVu Serif',
164
+ 'font.serif': 'Computer Modern',
165
+ }
166
+ matplotlib.rcParams.update(params)
167
+
168
+
169
+ def dice_similarity_coefficient(inter, union):
170
+ """Computes the dice similarity coefficient.
171
+
172
+ Args:
173
+ inter (iterable): iterable of the intersections
174
+ union (iterable): iterable of the unions
175
+ """
176
+ return 2 * sum(inter) / (sum(union) + sum(inter))
177
+
178
+
179
+ __all__ = ["video", "segmentation", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"]
dynamic/echonet/utils/segmentation.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions for training and running segmentation."""
2
+
3
+ import math
4
+ import os
5
+ import time
6
+
7
+ import click
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import scipy.signal
11
+ import skimage.draw
12
+ import torch
13
+ import torchvision
14
+ import tqdm
15
+
16
+ import echonet
17
+
18
+
19
+ @click.command("segmentation")
20
+ @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
21
+ @click.option("--output", type=click.Path(file_okay=False), default=None)
22
+ @click.option("--model_name", type=click.Choice(
23
+ sorted(name for name in torchvision.models.segmentation.__dict__
24
+ if name.islower() and not name.startswith("__") and callable(torchvision.models.segmentation.__dict__[name]))),
25
+ default="deeplabv3_resnet50")
26
+ @click.option("--pretrained/--random", default=False)
27
+ @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
28
+ @click.option("--run_test/--skip_test", default=False)
29
+ @click.option("--save_video/--skip_video", default=False)
30
+ @click.option("--num_epochs", type=int, default=50)
31
+ @click.option("--lr", type=float, default=1e-5)
32
+ @click.option("--weight_decay", type=float, default=0)
33
+ @click.option("--lr_step_period", type=int, default=None)
34
+ @click.option("--num_train_patients", type=int, default=None)
35
+ @click.option("--num_workers", type=int, default=4)
36
+ @click.option("--batch_size", type=int, default=20)
37
+ @click.option("--device", type=str, default=None)
38
+ @click.option("--seed", type=int, default=0)
39
+ def run(
40
+ data_dir=None,
41
+ output=None,
42
+
43
+ model_name="deeplabv3_resnet50",
44
+ pretrained=False,
45
+ weights=None,
46
+
47
+ run_test=False,
48
+ save_video=False,
49
+ num_epochs=50,
50
+ lr=1e-5,
51
+ weight_decay=1e-5,
52
+ lr_step_period=None,
53
+ num_train_patients=None,
54
+ num_workers=4,
55
+ batch_size=20,
56
+ device=None,
57
+ seed=0,
58
+ ):
59
+ """Trains/tests segmentation model.
60
+
61
+ Args:
62
+ data_dir (str, optional): Directory containing dataset. Defaults to
63
+ `echonet.config.DATA_DIR`.
64
+ output (str, optional): Directory to place outputs. Defaults to
65
+ output/segmentation/<model_name>_<pretrained/random>/.
66
+ model_name (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50'',
67
+ ``deeplabv3_resnet101'', ``fcn_resnet50'', or ``fcn_resnet101''
68
+ (options are torchvision.models.segmentation.<model_name>)
69
+ Defaults to ``deeplabv3_resnet50''.
70
+ pretrained (bool, optional): Whether to use pretrained weights for model
71
+ Defaults to False.
72
+ weights (str, optional): Path to checkpoint containing weights to
73
+ initialize model. Defaults to None.
74
+ run_test (bool, optional): Whether or not to run on test.
75
+ Defaults to False.
76
+ save_video (bool, optional): Whether to save videos with segmentations.
77
+ Defaults to False.
78
+ num_epochs (int, optional): Number of epochs during training
79
+ Defaults to 50.
80
+ lr (float, optional): Learning rate for SGD
81
+ Defaults to 1e-5.
82
+ weight_decay (float, optional): Weight decay for SGD
83
+ Defaults to 0.
84
+ lr_step_period (int or None, optional): Period of learning rate decay
85
+ (learning rate is decayed by a multiplicative factor of 0.1)
86
+ Defaults to math.inf (never decay learning rate).
87
+ num_train_patients (int or None, optional): Number of training patients
88
+ for ablations. Defaults to all patients.
89
+ num_workers (int, optional): Number of subprocesses to use for data
90
+ loading. If 0, the data will be loaded in the main process.
91
+ Defaults to 4.
92
+ device (str or None, optional): Name of device to run on. Options from
93
+ https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
94
+ Defaults to ``cuda'' if available, and ``cpu'' otherwise.
95
+ batch_size (int, optional): Number of samples to load per batch
96
+ Defaults to 20.
97
+ seed (int, optional): Seed for random number generator. Defaults to 0.
98
+ """
99
+
100
+ # Seed RNGs
101
+ np.random.seed(seed)
102
+ torch.manual_seed(seed)
103
+
104
+ # Set default output directory
105
+ if output is None:
106
+ output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
107
+ os.makedirs(output, exist_ok=True)
108
+
109
+ # Set device for computations
110
+ if device is None:
111
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
+
113
+ # Set up model
114
+ model = torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, aux_loss=False)
115
+
116
+ model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size) # change number of outputs to 1
117
+ if device.type == "cuda":
118
+ model = torch.nn.DataParallel(model)
119
+ model.to(device)
120
+
121
+ if weights is not None:
122
+ checkpoint = torch.load(weights)
123
+ model.load_state_dict(checkpoint['state_dict'])
124
+
125
+ # Set up optimizer
126
+ optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
127
+ if lr_step_period is None:
128
+ lr_step_period = math.inf
129
+ scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)
130
+
131
+ # Compute mean and std
132
+ mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
133
+ tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
134
+ kwargs = {"target_type": tasks,
135
+ "mean": mean,
136
+ "std": std
137
+ }
138
+
139
+ # Set up datasets and dataloaders
140
+ dataset = {}
141
+ dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs)
142
+ if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
143
+ # Subsample patients (used for ablation experiment)
144
+ indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
145
+ dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
146
+ dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)
147
+
148
+ # Run training and testing loops
149
+ with open(os.path.join(output, "log.csv"), "a") as f:
150
+ epoch_resume = 0
151
+ bestLoss = float("inf")
152
+ try:
153
+ # Attempt to load checkpoint
154
+ checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
155
+ model.load_state_dict(checkpoint['state_dict'])
156
+ optim.load_state_dict(checkpoint['opt_dict'])
157
+ scheduler.load_state_dict(checkpoint['scheduler_dict'])
158
+ epoch_resume = checkpoint["epoch"] + 1
159
+ bestLoss = checkpoint["best_loss"]
160
+ f.write("Resuming from epoch {}\n".format(epoch_resume))
161
+ except FileNotFoundError:
162
+ f.write("Starting run from scratch\n")
163
+
164
+ for epoch in range(epoch_resume, num_epochs):
165
+ print("Epoch #{}".format(epoch), flush=True)
166
+ for phase in ['train', 'val']:
167
+ start_time = time.time()
168
+ for i in range(torch.cuda.device_count()):
169
+ torch.cuda.reset_peak_memory_stats(i)
170
+
171
+ ds = dataset[phase]
172
+ dataloader = torch.utils.data.DataLoader(
173
+ ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
174
+
175
+ loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, phase == "train", optim, device)
176
+ overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
177
+ large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
178
+ small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
179
+ f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch,
180
+ phase,
181
+ loss,
182
+ overall_dice,
183
+ large_dice,
184
+ small_dice,
185
+ time.time() - start_time,
186
+ large_inter.size,
187
+ sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())),
188
+ sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())),
189
+ batch_size))
190
+ f.flush()
191
+ scheduler.step()
192
+
193
+ # Save checkpoint
194
+ save = {
195
+ 'epoch': epoch,
196
+ 'state_dict': model.state_dict(),
197
+ 'best_loss': bestLoss,
198
+ 'loss': loss,
199
+ 'opt_dict': optim.state_dict(),
200
+ 'scheduler_dict': scheduler.state_dict(),
201
+ }
202
+ torch.save(save, os.path.join(output, "checkpoint.pt"))
203
+ if loss < bestLoss:
204
+ torch.save(save, os.path.join(output, "best.pt"))
205
+ bestLoss = loss
206
+
207
+ # Load best weights
208
+ if num_epochs != 0:
209
+ checkpoint = torch.load(os.path.join(output, "best.pt"))
210
+ model.load_state_dict(checkpoint['state_dict'])
211
+ f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
212
+
213
+ if run_test:
214
+ # Run on validation and test
215
+ for split in ["val", "test"]:
216
+ dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs)
217
+ dataloader = torch.utils.data.DataLoader(dataset,
218
+ batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
219
+ loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, False, None, device)
220
+
221
+ overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
222
+ large_dice = 2 * large_inter / (large_union + large_inter)
223
+ small_dice = 2 * small_inter / (small_union + small_inter)
224
+ with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
225
+ g.write("Filename, Overall, Large, Small\n")
226
+ for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
227
+ g.write("{},{},{},{}\n".format(filename, overall, large, small))
228
+
229
+ f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
230
+ f.write("{} dice (large): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
231
+ f.write("{} dice (small): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
232
+ f.flush()
233
+
234
+ # Saving videos with segmentations
235
+ dataset = echonet.datasets.Echo(root=data_dir, split="test",
236
+ target_type=["Filename", "LargeIndex", "SmallIndex"], # Need filename for saving, and human-selected frames to annotate
237
+ mean=mean, std=std, # Normalization
238
+ length=None, max_length=None, period=1 # Take all frames
239
+ )
240
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)
241
+
242
+ # Save videos with segmentation
243
+ if save_video and not all(os.path.isfile(os.path.join(output, "videos", f)) for f in dataloader.dataset.fnames):
244
+ # Only run if missing videos
245
+
246
+ model.eval()
247
+
248
+ os.makedirs(os.path.join(output, "videos"), exist_ok=True)
249
+ os.makedirs(os.path.join(output, "size"), exist_ok=True)
250
+ echonet.utils.latexify()
251
+
252
+ with torch.no_grad():
253
+ with open(os.path.join(output, "size.csv"), "w") as g:
254
+ g.write("Filename,Frame,Size,HumanLarge,HumanSmall,ComputerSmall\n")
255
+ for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
256
+ # Run segmentation model on blocks of frames one-by-one
257
+ # The whole concatenated video may be too long to run together
258
+ y = np.concatenate([model(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])
259
+
260
+ start = 0
261
+ x = x.numpy()
262
+ for (i, (filename, offset)) in enumerate(zip(filenames, length)):
263
+ # Extract one video and segmentation predictions
264
+ video = x[start:(start + offset), ...]
265
+ logit = y[start:(start + offset), 0, :, :]
266
+
267
+ # Un-normalize video
268
+ video *= std.reshape(1, 3, 1, 1)
269
+ video += mean.reshape(1, 3, 1, 1)
270
+
271
+ # Get frames, channels, height, and width
272
+ f, c, h, w = video.shape # pylint: disable=W0612
273
+ assert c == 3
274
+
275
+ # Put two copies of the video side by side
276
+ video = np.concatenate((video, video), 3)
277
+
278
+ # If a pixel is in the segmentation, saturate blue channel
279
+ # Leave alone otherwise
280
+ video[:, 0, :, w:] = np.maximum(255. * (logit > 0), video[:, 0, :, w:]) # pylint: disable=E1111
281
+
282
+ # Add blank canvas under pair of videos
283
+ video = np.concatenate((video, np.zeros_like(video)), 2)
284
+
285
+ # Compute size of segmentation per frame
286
+ size = (logit > 0).sum((1, 2))
287
+
288
+ # Identify systole frames with peak detection
289
+ trim_min = sorted(size)[round(len(size) ** 0.05)]
290
+ trim_max = sorted(size)[round(len(size) ** 0.95)]
291
+ trim_range = trim_max - trim_min
292
+ systole = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])
293
+
294
+ # Write sizes and frames to file
295
+ for (frame, s) in enumerate(size):
296
+ g.write("{},{},{},{},{},{}\n".format(filename, frame, s, 1 if frame == large_index[i] else 0, 1 if frame == small_index[i] else 0, 1 if frame in systole else 0))
297
+
298
+ # Plot sizes
299
+ fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))
300
+ plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)
301
+ ylim = plt.ylim()
302
+ for s in systole:
303
+ plt.plot(np.array([s, s]) / 50, ylim, linewidth=1)
304
+ plt.ylim(ylim)
305
+ plt.title(os.path.splitext(filename)[0])
306
+ plt.xlabel("Seconds")
307
+ plt.ylabel("Size (pixels)")
308
+ plt.tight_layout()
309
+ plt.savefig(os.path.join(output, "size", os.path.splitext(filename)[0] + ".pdf"))
310
+ plt.close(fig)
311
+
312
+ # Normalize size to [0, 1]
313
+ size -= size.min()
314
+ size = size / size.max()
315
+ size = 1 - size
316
+
317
+ # Iterate the frames in this video
318
+ for (f, s) in enumerate(size):
319
+
320
+ # On all frames, mark a pixel for the size of the frame
321
+ video[:, :, int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))] = 255.
322
+
323
+ if f in systole:
324
+ # If frame is computer-selected systole, mark with a line
325
+ video[:, :, 115:224, int(round(f / len(size) * 200 + 10))] = 255.
326
+
327
+ def dash(start, stop, on=10, off=10):
328
+ buf = []
329
+ x = start
330
+ while x < stop:
331
+ buf.extend(range(x, x + on))
332
+ x += on
333
+ x += off
334
+ buf = np.array(buf)
335
+ buf = buf[buf < stop]
336
+ return buf
337
+ d = dash(115, 224)
338
+
339
+ if f == large_index[i]:
340
+ # If frame is human-selected diastole, mark with green dashed line on all frames
341
+ video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 225, 0]).reshape((1, 3, 1))
342
+ if f == small_index[i]:
343
+ # If frame is human-selected systole, mark with red dashed line on all frames
344
+ video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 0, 225]).reshape((1, 3, 1))
345
+
346
+ # Get pixels for a circle centered on the pixel
347
+ r, c = skimage.draw.disk((int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))), 4.1)
348
+
349
+ # On the frame that's being shown, put a circle over the pixel
350
+ video[f, :, r, c] = 255.
351
+
352
+ # Rearrange dimensions and save
353
+ video = video.transpose(1, 0, 2, 3)
354
+ video = video.astype(np.uint8)
355
+ echonet.utils.savevideo(os.path.join(output, "videos", filename), video, 50)
356
+
357
+ # Move to next video
358
+ start += offset
359
+
360
+
361
+ def run_epoch(model, dataloader, train, optim, device):
362
+ """Run one epoch of training/evaluation for segmentation.
363
+
364
+ Args:
365
+ model (torch.nn.Module): Model to train/evaulate.
366
+ dataloder (torch.utils.data.DataLoader): Dataloader for dataset.
367
+ train (bool): Whether or not to train model.
368
+ optim (torch.optim.Optimizer): Optimizer
369
+ device (torch.device): Device to run on
370
+ """
371
+
372
+ total = 0.
373
+ n = 0
374
+
375
+ pos = 0
376
+ neg = 0
377
+ pos_pix = 0
378
+ neg_pix = 0
379
+
380
+ model.train(train)
381
+
382
+ large_inter = 0
383
+ large_union = 0
384
+ small_inter = 0
385
+ small_union = 0
386
+ large_inter_list = []
387
+ large_union_list = []
388
+ small_inter_list = []
389
+ small_union_list = []
390
+
391
+ with torch.set_grad_enabled(train):
392
+ with tqdm.tqdm(total=len(dataloader)) as pbar:
393
+ for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
394
+ # Count number of pixels in/out of human segmentation
395
+ pos += (large_trace == 1).sum().item()
396
+ pos += (small_trace == 1).sum().item()
397
+ neg += (large_trace == 0).sum().item()
398
+ neg += (small_trace == 0).sum().item()
399
+
400
+ # Count number of pixels in/out of computer segmentation
401
+ pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
402
+ pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
403
+ neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
404
+ neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()
405
+
406
+ # Run prediction for diastolic frames and compute loss
407
+ large_frame = large_frame.to(device)
408
+ large_trace = large_trace.to(device)
409
+ y_large = model(large_frame)["out"]
410
+ loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")
411
+ # Compute pixel intersection and union between human and computer segmentations
412
+ large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
413
+ large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
414
+ large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
415
+ large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
416
+
417
+ # Run prediction for systolic frames and compute loss
418
+ small_frame = small_frame.to(device)
419
+ small_trace = small_trace.to(device)
420
+ y_small = model(small_frame)["out"]
421
+ loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")
422
+ # Compute pixel intersection and union between human and computer segmentations
423
+ small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
424
+ small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
425
+ small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
426
+ small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
427
+
428
+ # Take gradient step if training
429
+ loss = (loss_large + loss_small) / 2
430
+ if train:
431
+ optim.zero_grad()
432
+ loss.backward()
433
+ optim.step()
434
+
435
+ # Accumulate losses and compute baselines
436
+ total += loss.item()
437
+ n += large_trace.size(0)
438
+ p = pos / (pos + neg)
439
+ p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)
440
+
441
+ # Show info on process bar
442
+ pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter)))
443
+ pbar.update()
444
+
445
+ large_inter_list = np.array(large_inter_list)
446
+ large_union_list = np.array(large_union_list)
447
+ small_inter_list = np.array(small_inter_list)
448
+ small_union_list = np.array(small_union_list)
449
+
450
+ return (total / n / 112 / 112,
451
+ large_inter_list,
452
+ large_union_list,
453
+ small_inter_list,
454
+ small_union_list,
455
+ )
456
+
457
+
458
+ def _video_collate_fn(x):
459
+ """Collate function for Pytorch dataloader to merge multiple videos.
460
+
461
+ This function should be used in a dataloader for a dataset that returns
462
+ a video as the first element, along with some (non-zero) tuple of
463
+ targets. Then, the input x is a list of tuples:
464
+ - x[i][0] is the i-th video in the batch
465
+ - x[i][1] are the targets for the i-th video
466
+
467
+ This function returns a 3-tuple:
468
+ - The first element is the videos concatenated along the frames
469
+ dimension. This is done so that videos of different lengths can be
470
+ processed together (tensors cannot be "jagged", so we cannot have
471
+ a dimension for video, and another for frames).
472
+ - The second element is contains the targets with no modification.
473
+ - The third element is a list of the lengths of the videos in frames.
474
+ """
475
+ video, target = zip(*x) # Extract the videos and targets
476
+
477
+ # ``video'' is a tuple of length ``batch_size''
478
+ # Each element has shape (channels=3, frames, height, width)
479
+ # height and width are expected to be the same across videos, but
480
+ # frames can be different.
481
+
482
+ # ``target'' is also a tuple of length ``batch_size''
483
+ # Each element is a tuple of the targets for the item.
484
+
485
+ i = list(map(lambda t: t.shape[1], video)) # Extract lengths of videos in frames
486
+
487
+ # This contatenates the videos along the the frames dimension (basically
488
+ # playing the videos one after another). The frames dimension is then
489
+ # moved to be first.
490
+ # Resulting shape is (total frames, channels=3, height, width)
491
+ video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))
492
+
493
+ # Swap dimensions (approximately a transpose)
494
+ # Before: target[i][j] is the j-th target of element i
495
+ # After: target[i][j] is the i-th target of element j
496
+ target = zip(*target)
497
+
498
+ return video, target, i
dynamic/echonet/utils/video.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions for training and running EF prediction."""
2
+
3
+ import math
4
+ import os
5
+ import time
6
+
7
+ import click
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import sklearn.metrics
11
+ import torch
12
+ import torchvision
13
+ import tqdm
14
+
15
+ import echonet
16
+
17
+
18
+ @click.command("video")
19
+ @click.option("--data_dir", type=click.Path(exists=True, file_okay=False), default=None)
20
+ @click.option("--output", type=click.Path(file_okay=False), default=None)
21
+ @click.option("--task", type=str, default="EF")
22
+ @click.option("--model_name", type=click.Choice(
23
+ sorted(name for name in torchvision.models.video.__dict__
24
+ if name.islower() and not name.startswith("__") and callable(torchvision.models.video.__dict__[name]))),
25
+ default="r2plus1d_18")
26
+ @click.option("--pretrained/--random", default=True)
27
+ @click.option("--weights", type=click.Path(exists=True, dir_okay=False), default=None)
28
+ @click.option("--run_test/--skip_test", default=False)
29
+ @click.option("--num_epochs", type=int, default=45)
30
+ @click.option("--lr", type=float, default=1e-4)
31
+ @click.option("--weight_decay", type=float, default=1e-4)
32
+ @click.option("--lr_step_period", type=int, default=15)
33
+ @click.option("--frames", type=int, default=32)
34
+ @click.option("--period", type=int, default=2)
35
+ @click.option("--num_train_patients", type=int, default=None)
36
+ @click.option("--num_workers", type=int, default=4)
37
+ @click.option("--batch_size", type=int, default=20)
38
+ @click.option("--device", type=str, default=None)
39
+ @click.option("--seed", type=int, default=0)
40
+ def run(
41
+ data_dir=None,
42
+ output=None,
43
+ task="EF",
44
+
45
+ model_name="r2plus1d_18",
46
+ pretrained=True,
47
+ weights=None,
48
+
49
+ run_test=False,
50
+ num_epochs=45,
51
+ lr=1e-4,
52
+ weight_decay=1e-4,
53
+ lr_step_period=15,
54
+ frames=32,
55
+ period=2,
56
+ num_train_patients=None,
57
+ num_workers=4,
58
+ batch_size=20,
59
+ device=None,
60
+ seed=0,
61
+ ):
62
+ """Trains/tests EF prediction model.
63
+
64
+ \b
65
+ Args:
66
+ data_dir (str, optional): Directory containing dataset. Defaults to
67
+ `echonet.config.DATA_DIR`.
68
+ output (str, optional): Directory to place outputs. Defaults to
69
+ output/video/<model_name>_<pretrained/random>/.
70
+ task (str, optional): Name of task to predict. Options are the headers
71
+ of FileList.csv. Defaults to ``EF''.
72
+ model_name (str, optional): Name of model. One of ``mc3_18'',
73
+ ``r2plus1d_18'', or ``r3d_18''
74
+ (options are torchvision.models.video.<model_name>)
75
+ Defaults to ``r2plus1d_18''.
76
+ pretrained (bool, optional): Whether to use pretrained weights for model
77
+ Defaults to True.
78
+ weights (str, optional): Path to checkpoint containing weights to
79
+ initialize model. Defaults to None.
80
+ run_test (bool, optional): Whether or not to run on test.
81
+ Defaults to False.
82
+ num_epochs (int, optional): Number of epochs during training.
83
+ Defaults to 45.
84
+ lr (float, optional): Learning rate for SGD
85
+ Defaults to 1e-4.
86
+ weight_decay (float, optional): Weight decay for SGD
87
+ Defaults to 1e-4.
88
+ lr_step_period (int or None, optional): Period of learning rate decay
89
+ (learning rate is decayed by a multiplicative factor of 0.1)
90
+ Defaults to 15.
91
+ frames (int, optional): Number of frames to use in clip
92
+ Defaults to 32.
93
+ period (int, optional): Sampling period for frames
94
+ Defaults to 2.
95
+ n_train_patients (int or None, optional): Number of training patients
96
+ for ablations. Defaults to all patients.
97
+ num_workers (int, optional): Number of subprocesses to use for data
98
+ loading. If 0, the data will be loaded in the main process.
99
+ Defaults to 4.
100
+ device (str or None, optional): Name of device to run on. Options from
101
+ https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
102
+ Defaults to ``cuda'' if available, and ``cpu'' otherwise.
103
+ batch_size (int, optional): Number of samples to load per batch
104
+ Defaults to 20.
105
+ seed (int, optional): Seed for random number generator. Defaults to 0.
106
+ """
107
+
108
+ # Seed RNGs
109
+ np.random.seed(seed)
110
+ torch.manual_seed(seed)
111
+
112
+ # Set default output directory
113
+ if output is None:
114
+ output = os.path.join("output", "video", "{}_{}_{}_{}".format(model_name, frames, period, "pretrained" if pretrained else "random"))
115
+ os.makedirs(output, exist_ok=True)
116
+
117
+ # Set device for computations
118
+ if device is None:
119
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
120
+
121
+ # Set up model
122
+ model = torchvision.models.video.__dict__[model_name](pretrained=pretrained)
123
+
124
+ model.fc = torch.nn.Linear(model.fc.in_features, 1)
125
+ model.fc.bias.data[0] = 55.6
126
+ if device.type == "cuda":
127
+ model = torch.nn.DataParallel(model)
128
+ model.to(device)
129
+
130
+ if weights is not None:
131
+ checkpoint = torch.load(weights)
132
+ model.load_state_dict(checkpoint['state_dict'])
133
+
134
+ # Set up optimizer
135
+ optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
136
+ if lr_step_period is None:
137
+ lr_step_period = math.inf
138
+ scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)
139
+
140
+ # Compute mean and std
141
+ mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
142
+ kwargs = {"target_type": task,
143
+ "mean": mean,
144
+ "std": std,
145
+ "length": frames,
146
+ "period": period,
147
+ }
148
+
149
+ # Set up datasets and dataloaders
150
+ dataset = {}
151
+ dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs, pad=12)
152
+ if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
153
+ # Subsample patients (used for ablation experiment)
154
+ indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
155
+ dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
156
+ dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)
157
+
158
+ # Run training and testing loops
159
+ with open(os.path.join(output, "log.csv"), "a") as f:
160
+ epoch_resume = 0
161
+ bestLoss = float("inf")
162
+ try:
163
+ # Attempt to load checkpoint
164
+ checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
165
+ model.load_state_dict(checkpoint['state_dict'])
166
+ optim.load_state_dict(checkpoint['opt_dict'])
167
+ scheduler.load_state_dict(checkpoint['scheduler_dict'])
168
+ epoch_resume = checkpoint["epoch"] + 1
169
+ bestLoss = checkpoint["best_loss"]
170
+ f.write("Resuming from epoch {}\n".format(epoch_resume))
171
+ except FileNotFoundError:
172
+ f.write("Starting run from scratch\n")
173
+
174
+ for epoch in range(epoch_resume, num_epochs):
175
+ print("Epoch #{}".format(epoch), flush=True)
176
+ for phase in ['train', 'val']:
177
+ start_time = time.time()
178
+ for i in range(torch.cuda.device_count()):
179
+ torch.cuda.reset_peak_memory_stats(i)
180
+
181
+ ds = dataset[phase]
182
+ dataloader = torch.utils.data.DataLoader(
183
+ ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))
184
+
185
+ loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, phase == "train", optim, device)
186
+ f.write("{},{},{},{},{},{},{},{},{}\n".format(epoch,
187
+ phase,
188
+ loss,
189
+ sklearn.metrics.r2_score(y, yhat),
190
+ time.time() - start_time,
191
+ y.size,
192
+ sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())),
193
+ sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())),
194
+ batch_size))
195
+ f.flush()
196
+ scheduler.step()
197
+
198
+ # Save checkpoint
199
+ save = {
200
+ 'epoch': epoch,
201
+ 'state_dict': model.state_dict(),
202
+ 'period': period,
203
+ 'frames': frames,
204
+ 'best_loss': bestLoss,
205
+ 'loss': loss,
206
+ 'r2': sklearn.metrics.r2_score(y, yhat),
207
+ 'opt_dict': optim.state_dict(),
208
+ 'scheduler_dict': scheduler.state_dict(),
209
+ }
210
+ torch.save(save, os.path.join(output, "checkpoint.pt"))
211
+ if loss < bestLoss:
212
+ torch.save(save, os.path.join(output, "best.pt"))
213
+ bestLoss = loss
214
+
215
+ # Load best weights
216
+ if num_epochs != 0:
217
+ checkpoint = torch.load(os.path.join(output, "best.pt"))
218
+ model.load_state_dict(checkpoint['state_dict'])
219
+ f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))
220
+ f.flush()
221
+
222
+ if run_test:
223
+ for split in ["val", "test"]:
224
+ # Performance without test-time augmentation
225
+ dataloader = torch.utils.data.DataLoader(
226
+ echonet.datasets.Echo(root=data_dir, split=split, **kwargs),
227
+ batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"))
228
+ loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device)
229
+ f.write("{} (one clip) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.r2_score)))
230
+ f.write("{} (one clip) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_absolute_error)))
231
+ f.write("{} (one clip) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, yhat, sklearn.metrics.mean_squared_error)))))
232
+ f.flush()
233
+
234
+ # Performance with test-time augmentation
235
+ ds = echonet.datasets.Echo(root=data_dir, split=split, **kwargs, clips="all")
236
+ dataloader = torch.utils.data.DataLoader(
237
+ ds, batch_size=1, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
238
+ loss, yhat, y = echonet.utils.video.run_epoch(model, dataloader, False, None, device, save_all=True, block_size=batch_size)
239
+ f.write("{} (all clips) R2: {:.3f} ({:.3f} - {:.3f})\n".format(split, *echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.r2_score)))
240
+ f.write("{} (all clips) MAE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.mean_absolute_error)))
241
+ f.write("{} (all clips) RMSE: {:.2f} ({:.2f} - {:.2f})\n".format(split, *tuple(map(math.sqrt, echonet.utils.bootstrap(y, np.array(list(map(lambda x: x.mean(), yhat))), sklearn.metrics.mean_squared_error)))))
242
+ f.flush()
243
+
244
+ # Write full performance to file
245
+ with open(os.path.join(output, "{}_predictions.csv".format(split)), "w") as g:
246
+ for (filename, pred) in zip(ds.fnames, yhat):
247
+ for (i, p) in enumerate(pred):
248
+ g.write("{},{},{:.4f}\n".format(filename, i, p))
249
+ echonet.utils.latexify()
250
+ yhat = np.array(list(map(lambda x: x.mean(), yhat)))
251
+
252
+ # Plot actual and predicted EF
253
+ fig = plt.figure(figsize=(3, 3))
254
+ lower = min(y.min(), yhat.min())
255
+ upper = max(y.max(), yhat.max())
256
+ plt.scatter(y, yhat, color="k", s=1, edgecolor=None, zorder=2)
257
+ plt.plot([0, 100], [0, 100], linewidth=1, zorder=3)
258
+ plt.axis([lower - 3, upper + 3, lower - 3, upper + 3])
259
+ plt.gca().set_aspect("equal", "box")
260
+ plt.xlabel("Actual EF (%)")
261
+ plt.ylabel("Predicted EF (%)")
262
+ plt.xticks([10, 20, 30, 40, 50, 60, 70, 80])
263
+ plt.yticks([10, 20, 30, 40, 50, 60, 70, 80])
264
+ plt.grid(color="gainsboro", linestyle="--", linewidth=1, zorder=1)
265
+ plt.tight_layout()
266
+ plt.savefig(os.path.join(output, "{}_scatter.pdf".format(split)))
267
+ plt.close(fig)
268
+
269
+ # Plot AUROC
270
+ fig = plt.figure(figsize=(3, 3))
271
+ plt.plot([0, 1], [0, 1], linewidth=1, color="k", linestyle="--")
272
+ for thresh in [35, 40, 45, 50]:
273
+ fpr, tpr, _ = sklearn.metrics.roc_curve(y > thresh, yhat)
274
+ print(thresh, sklearn.metrics.roc_auc_score(y > thresh, yhat))
275
+ plt.plot(fpr, tpr)
276
+
277
+ plt.axis([-0.01, 1.01, -0.01, 1.01])
278
+ plt.xlabel("False Positive Rate")
279
+ plt.ylabel("True Positive Rate")
280
+ plt.tight_layout()
281
+ plt.savefig(os.path.join(output, "{}_roc.pdf".format(split)))
282
+ plt.close(fig)
283
+
284
+
285
+ def run_epoch(model, dataloader, train, optim, device, save_all=False, block_size=None):
286
+ """Run one epoch of training/evaluation for segmentation.
287
+
288
+ Args:
289
+ model (torch.nn.Module): Model to train/evaulate.
290
+ dataloder (torch.utils.data.DataLoader): Dataloader for dataset.
291
+ train (bool): Whether or not to train model.
292
+ optim (torch.optim.Optimizer): Optimizer
293
+ device (torch.device): Device to run on
294
+ save_all (bool, optional): If True, return predictions for all
295
+ test-time augmentations separately. If False, return only
296
+ the mean prediction.
297
+ Defaults to False.
298
+ block_size (int or None, optional): Maximum number of augmentations
299
+ to run on at the same time. Use to limit the amount of memory
300
+ used. If None, always run on all augmentations simultaneously.
301
+ Default is None.
302
+ """
303
+
304
+ model.train(train)
305
+
306
+ total = 0 # total training loss
307
+ n = 0 # number of videos processed
308
+ s1 = 0 # sum of ground truth EF
309
+ s2 = 0 # Sum of ground truth EF squared
310
+
311
+ yhat = []
312
+ y = []
313
+
314
+ with torch.set_grad_enabled(train):
315
+ with tqdm.tqdm(total=len(dataloader)) as pbar:
316
+ for (X, outcome) in dataloader:
317
+
318
+ y.append(outcome.numpy())
319
+ X = X.to(device)
320
+ outcome = outcome.to(device)
321
+
322
+ average = (len(X.shape) == 6)
323
+ if average:
324
+ batch, n_clips, c, f, h, w = X.shape
325
+ X = X.view(-1, c, f, h, w)
326
+
327
+ s1 += outcome.sum()
328
+ s2 += (outcome ** 2).sum()
329
+
330
+ if block_size is None:
331
+ outputs = model(X)
332
+ else:
333
+ outputs = torch.cat([model(X[j:(j + block_size), ...]) for j in range(0, X.shape[0], block_size)])
334
+
335
+ if save_all:
336
+ yhat.append(outputs.view(-1).to("cpu").detach().numpy())
337
+
338
+ if average:
339
+ outputs = outputs.view(batch, n_clips, -1).mean(1)
340
+
341
+ if not save_all:
342
+ yhat.append(outputs.view(-1).to("cpu").detach().numpy())
343
+
344
+ loss = torch.nn.functional.mse_loss(outputs.view(-1), outcome)
345
+
346
+ if train:
347
+ optim.zero_grad()
348
+ loss.backward()
349
+ optim.step()
350
+
351
+ total += loss.item() * X.size(0)
352
+ n += X.size(0)
353
+
354
+ pbar.set_postfix_str("{:.2f} ({:.2f}) / {:.2f}".format(total / n, loss.item(), s2 / n - (s1 / n) ** 2))
355
+ pbar.update()
356
+
357
+ if not save_all:
358
+ yhat = np.concatenate(yhat)
359
+ y = np.concatenate(y)
360
+
361
+ return total / n, yhat, y
dynamic/example.cfg ADDED
@@ -0,0 +1 @@
 
 
1
+ DATA_DIR = a4c-video-dir/
dynamic/requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ certifi==2020.12.5
2
+ cycler==0.10.0
3
+ decorator==4.4.2
4
+ echonet==1.0.0
5
+ imageio==2.9.0
6
+ joblib==1.0.1
7
+ kiwisolver==1.3.1
8
+ matplotlib==3.3.4
9
+ networkx==2.5
10
+ numpy==1.20.1
11
+ opencv-python==4.5.1.48
12
+ pandas==1.2.3
13
+ Pillow==8.1.2
14
+ pyparsing==2.4.7
15
+ python-dateutil==2.8.1
16
+ pytz==2021.1
17
+ PyWavelets==1.1.1
18
+ scikit-image==0.18.1
19
+ scikit-learn==0.24.1
20
+ scipy==1.6.1
21
+ six==1.15.0
22
+ sklearn==0.0
23
+ threadpoolctl==2.1.0
24
+ tifffile==2021.3.17
25
+ torch==1.8.0
26
+ torchvision==0.9.0
27
+ tqdm==4.59.0
28
+ typing-extensions==3.7.4.3
dynamic/scripts/ConvertDICOMToAVI.ipynb ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 12,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# David Ouyang 10/2/2019\n",
10
+ "\n",
11
+ "# Notebook which iterates through a folder, including subfolders, \n",
12
+ "# and convert DICOM files to AVI files of a defined size (natively 112 x 112)\n",
13
+ "\n",
14
+ "import re\n",
15
+ "import os, os.path\n",
16
+ "from os.path import splitext\n",
17
+ "import pydicom as dicom\n",
18
+ "import numpy as np\n",
19
+ "from pydicom.uid import UID, generate_uid\n",
20
+ "import shutil\n",
21
+ "from multiprocessing import dummy as multiprocessing\n",
22
+ "import time\n",
23
+ "import subprocess\n",
24
+ "import datetime\n",
25
+ "from datetime import date\n",
26
+ "import sys\n",
27
+ "import cv2\n",
28
+ "#from scipy.misc import imread\n",
29
+ "import matplotlib.pyplot as plt\n",
30
+ "import sys\n",
31
+ "from shutil import copy\n",
32
+ "import math\n",
33
+ "\n",
34
+ "destinationFolder = \"Output Folder Name\"\n"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 10,
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "name": "stdout",
44
+ "output_type": "stream",
45
+ "text": [
46
+ "Requirement already satisfied: pillow in c:\\programdata\\anaconda3\\lib\\site-packages (6.2.0)\n",
47
+ "Requirement already satisfied: scipy in c:\\programdata\\anaconda3\\lib\\site-packages (1.3.1)\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "# Dependencies you might need to run code\n",
53
+ "# Commonly missing\n",
54
+ "\n",
55
+ "#!pip install pydicom\n",
56
+ "#!pip install opencv-python\n"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 2,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "def mask(output):\n",
66
+ " dimension = output.shape[0]\n",
67
+ " \n",
68
+ " # Mask pixels outside of scanning sector\n",
69
+ " m1, m2 = np.meshgrid(np.arange(dimension), np.arange(dimension))\n",
70
+ " \n",
71
+ "\n",
72
+ " mask = ((m1+m2)>int(dimension/2) + int(dimension/10)) \n",
73
+ " mask *= ((m1-m2)<int(dimension/2) + int(dimension/10))\n",
74
+ " mask = np.reshape(mask, (dimension, dimension)).astype(np.int8)\n",
75
+ " maskedImage = cv2.bitwise_and(output, output, mask = mask)\n",
76
+ " \n",
77
+ " #print(maskedImage.shape)\n",
78
+ " \n",
79
+ " return maskedImage\n",
80
+ "\n"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 3,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "def makeVideo(fileToProcess, destinationFolder):\n",
90
+ " try:\n",
91
+ " fileName = fileToProcess.split('\\\\')[-1] #\\\\ if windows, / if on mac or sherlock\n",
92
+ " #hex(abs(hash(fileToProcess.split('/')[-1]))).upper()\n",
93
+ "\n",
94
+ " if not os.path.isdir(os.path.join(destinationFolder,fileName)):\n",
95
+ "\n",
96
+ " dataset = dicom.dcmread(fileToProcess, force=True)\n",
97
+ " testarray = dataset.pixel_array\n",
98
+ "\n",
99
+ " frame0 = testarray[0]\n",
100
+ " mean = np.mean(frame0, axis=1)\n",
101
+ " mean = np.mean(mean, axis=1)\n",
102
+ " yCrop = np.where(mean<1)[0][0]\n",
103
+ " testarray = testarray[:, yCrop:, :, :]\n",
104
+ "\n",
105
+ " bias = int(np.abs(testarray.shape[2] - testarray.shape[1])/2)\n",
106
+ " if bias>0:\n",
107
+ " if testarray.shape[1] < testarray.shape[2]:\n",
108
+ " testarray = testarray[:, :, bias:-bias, :]\n",
109
+ " else:\n",
110
+ " testarray = testarray[:, bias:-bias, :, :]\n",
111
+ "\n",
112
+ "\n",
113
+ " print(testarray.shape)\n",
114
+ " frames,height,width,channels = testarray.shape\n",
115
+ "\n",
116
+ " fps = 30\n",
117
+ "\n",
118
+ " try:\n",
119
+ " fps = dataset[(0x18, 0x40)].value\n",
120
+ " except:\n",
121
+ " print(\"couldn't find frame rate, default to 30\")\n",
122
+ "\n",
123
+ " fourcc = cv2.VideoWriter_fourcc('M','J','P','G')\n",
124
+ " video_filename = os.path.join(destinationFolder, fileName + '.avi')\n",
125
+ " out = cv2.VideoWriter(video_filename, fourcc, fps, cropSize)\n",
126
+ "\n",
127
+ "\n",
128
+ " for i in range(frames):\n",
129
+ "\n",
130
+ " outputA = testarray[i,:,:,0]\n",
131
+ " smallOutput = outputA[int(height/10):(height - int(height/10)), int(height/10):(height - int(height/10))]\n",
132
+ "\n",
133
+ " # Resize image\n",
134
+ " output = cv2.resize(smallOutput, cropSize, interpolation = cv2.INTER_CUBIC)\n",
135
+ "\n",
136
+ " finaloutput = mask(output)\n",
137
+ "\n",
138
+ "\n",
139
+ " finaloutput = cv2.merge([finaloutput,finaloutput,finaloutput])\n",
140
+ " out.write(finaloutput)\n",
141
+ "\n",
142
+ " out.release()\n",
143
+ "\n",
144
+ " else:\n",
145
+ " print(fileName,\"hasAlreadyBeenProcessed\")\n",
146
+ " except:\n",
147
+ " print(\"something filed, not sure what, have to debug\", fileName)\n",
148
+ " return 0"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "AllA4cNames = \"Input Folder Name\"\n",
158
+ "\n",
159
+ "count = 0\n",
160
+ " \n",
161
+ "cropSize = (112,112)\n",
162
+ "subfolders = os.listdir(AllA4cNames)\n",
163
+ "\n",
164
+ "\n",
165
+ "for folder in subfolders:\n",
166
+ " print(folder)\n",
167
+ "\n",
168
+ " for content in os.listdir(os.path.join(AllA4cNames, folder)):\n",
169
+ " for subcontent in os.listdir(os.path.join(AllA4cNames, folder, content)):\n",
170
+ " count += 1\n",
171
+ " \n",
172
+ "\n",
173
+ " VideoPath = os.path.join(AllA4cNames, folder, content, subcontent)\n",
174
+ "\n",
175
+ " print(count, folder, content, subcontent)\n",
176
+ "\n",
177
+ " if not os.path.exists(os.path.join(destinationFolder,subcontent + \".avi\")):\n",
178
+ " makeVideo(VideoPath, destinationFolder)\n",
179
+ " else:\n",
180
+ " print(\"Already did this file\", VideoPath)\n",
181
+ "\n",
182
+ "\n",
183
+ "print(len(AllA4cFilenames))"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": null,
189
+ "metadata": {},
190
+ "outputs": [],
191
+ "source": []
192
+ }
193
+ ],
194
+ "metadata": {
195
+ "kernelspec": {
196
+ "display_name": "Python 3",
197
+ "language": "python",
198
+ "name": "python3"
199
+ },
200
+ "language_info": {
201
+ "codemirror_mode": {
202
+ "name": "ipython",
203
+ "version": 3
204
+ },
205
+ "file_extension": ".py",
206
+ "mimetype": "text/x-python",
207
+ "name": "python",
208
+ "nbconvert_exporter": "python",
209
+ "pygments_lexer": "ipython3",
210
+ "version": "3.7.4"
211
+ }
212
+ },
213
+ "nbformat": 4,
214
+ "nbformat_minor": 2
215
+ }
dynamic/scripts/InitializationNotebook.ipynb ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# David Ouyang 12/5/2019\n",
10
+ "\n",
11
+ "# Notebook which:\n",
12
+ "# 1. Downloads weights\n",
13
+ "# 2. Initializes model and imports weights\n",
14
+ "# 3. Performs test time evaluation of videos (already preprocessed with ConvertDICOMToAVI.ipynb)\n",
15
+ "\n",
16
+ "import re\n",
17
+ "import os, os.path\n",
18
+ "from os.path import splitext\n",
19
+ "import pydicom as dicom\n",
20
+ "import numpy as np\n",
21
+ "from pydicom.uid import UID, generate_uid\n",
22
+ "import shutil\n",
23
+ "from multiprocessing import dummy as multiprocessing\n",
24
+ "import time\n",
25
+ "import subprocess\n",
26
+ "import datetime\n",
27
+ "from datetime import date\n",
28
+ "import sys\n",
29
+ "import cv2\n",
30
+ "import matplotlib.pyplot as plt\n",
31
+ "import sys\n",
32
+ "from shutil import copy\n",
33
+ "import math\n",
34
+ "import torch\n",
35
+ "import torchvision\n",
36
+ "\n",
37
+ "sys.path.append(\"..\")\n",
38
+ "import echonet\n",
39
+ "\n",
40
+ "import wget \n",
41
+ "\n",
42
+ "#destinationFolder = \"/Users/davidouyang/Dropbox/Echo Research/CodeBase/Output\"\n",
43
+ "destinationFolder = \"C:\\\\Users\\\\Windows\\\\Dropbox\\\\Echo Research\\\\CodeBase\\\\Output\"\n",
44
+ "#videosFolder = \"/Users/davidouyang/Dropbox/Echo Research/CodeBase/a4c-video-dir\"\n",
45
+ "videosFolder = \"C:\\\\Users\\\\Windows\\\\Dropbox\\\\Echo Research\\\\CodeBase\\\\a4c-video-dir\"\n",
46
+ "#DestinationForWeights = \"/Users/davidouyang/Dropbox/Echo Research/CodeBase/EchoNetDynamic-Weights\"\n",
47
+ "DestinationForWeights = \"C:\\\\Users\\\\Windows\\\\Dropbox\\\\Echo Research\\\\CodeBase\\\\EchoNetDynamic-Weights\""
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 5,
53
+ "metadata": {},
54
+ "outputs": [
55
+ {
56
+ "name": "stdout",
57
+ "output_type": "stream",
58
+ "text": [
59
+ "The weights are at C:\\Users\\Windows\\Dropbox\\Echo Research\\CodeBase\\EchoNetDynamic-Weights\n",
60
+ "Segmentation Weights already present\n",
61
+ "EF Weights already present\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "# Download model weights\n",
67
+ "\n",
68
+ "if os.path.exists(DestinationForWeights):\n",
69
+ " print(\"The weights are at\", DestinationForWeights)\n",
70
+ "else:\n",
71
+ " print(\"Creating folder at \", DestinationForWeights, \" to store weights\")\n",
72
+ " os.mkdir(DestinationForWeights)\n",
73
+ " \n",
74
+ "segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'\n",
75
+ "ejectionFractionWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/r2plus1d_18_32_2_pretrained.pt'\n",
76
+ "\n",
77
+ "\n",
78
+ "if not os.path.exists(os.path.join(DestinationForWeights, os.path.basename(segmentationWeightsURL))):\n",
79
+ " print(\"Downloading Segmentation Weights, \", segmentationWeightsURL,\" to \",os.path.join(DestinationForWeights,os.path.basename(segmentationWeightsURL)))\n",
80
+ " filename = wget.download(segmentationWeightsURL, out = DestinationForWeights)\n",
81
+ "else:\n",
82
+ " print(\"Segmentation Weights already present\")\n",
83
+ " \n",
84
+ "if not os.path.exists(os.path.join(DestinationForWeights, os.path.basename(ejectionFractionWeightsURL))):\n",
85
+ " print(\"Downloading EF Weights, \", ejectionFractionWeightsURL,\" to \",os.path.join(DestinationForWeights,os.path.basename(ejectionFractionWeightsURL)))\n",
86
+ " filename = wget.download(ejectionFractionWeightsURL, out = DestinationForWeights)\n",
87
+ "else:\n",
88
+ " print(\"EF Weights already present\")\n",
89
+ " \n"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 6,
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "name": "stdout",
99
+ "output_type": "stream",
100
+ "text": [
101
+ "loading weights from C:\\Users\\Windows\\Dropbox\\Echo Research\\CodeBase\\EchoNetDynamic-Weights\\r2plus1d_18_32_2_pretrained\n",
102
+ "cuda is available, original weights\n",
103
+ "external_test ['0X1A05DFFFCAFB253B.avi', '0X1A0A263B22CCD966.avi', '0X1A2A76BDB5B98BED.avi', '0X1A2C60147AF9FDAE.avi', '0X1A2E9496910EFF5B.avi', '0X1A3D565B371DC573.avi', '0X1A3E7BF1DFB132FB.avi', '0X1A5FAE3F9D37794E.avi', '0X1A6ACFE7B286DAFC.avi', '0X1A8D85542DBE8204.avi', '23_Apical_4_chamber_view.dcm.avi', '62_Apical_4_chamber_view.dcm.avi', '64_Apical_4_chamber_view.dcm.avi']\n"
104
+ ]
105
+ },
106
+ {
107
+ "name": "stderr",
108
+ "output_type": "stream",
109
+ "text": [
110
+ "100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:10<00:00, 1.00s/it]\n",
111
+ "100%|████████████████████████████████████████████████████████| 13/13 [00:29<00:00, 2.26s/it, 3122.29 (3440.26) / 0.00]\n"
112
+ ]
113
+ }
114
+ ],
115
+ "source": [
116
+ "# Initialize and Run EF model\n",
117
+ "\n",
118
+ "frames = 32\n",
119
+ "period = 1 #2\n",
120
+ "batch_size = 20\n",
121
+ "model = torchvision.models.video.r2plus1d_18(pretrained=False)\n",
122
+ "model.fc = torch.nn.Linear(model.fc.in_features, 1)\n",
123
+ "\n",
124
+ "\n",
125
+ "\n",
126
+ "print(\"loading weights from \", os.path.join(DestinationForWeights, \"r2plus1d_18_32_2_pretrained\"))\n",
127
+ "\n",
128
+ "if torch.cuda.is_available():\n",
129
+ " print(\"cuda is available, original weights\")\n",
130
+ " device = torch.device(\"cuda\")\n",
131
+ " model = torch.nn.DataParallel(model)\n",
132
+ " model.to(device)\n",
133
+ " checkpoint = torch.load(os.path.join(DestinationForWeights, os.path.basename(ejectionFractionWeightsURL)))\n",
134
+ " model.load_state_dict(checkpoint['state_dict'])\n",
135
+ "else:\n",
136
+ " print(\"cuda is not available, cpu weights\")\n",
137
+ " device = torch.device(\"cpu\")\n",
138
+ " checkpoint = torch.load(os.path.join(DestinationForWeights, os.path.basename(ejectionFractionWeightsURL)), map_location = \"cpu\")\n",
139
+ " state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}\n",
140
+ " model.load_state_dict(state_dict_cpu)\n",
141
+ "\n",
142
+ "\n",
143
+ "# try some random weights: final_r2+1d_model_regression_EF_sgd_skip1_32frames.pth.tar\n",
144
+ "# scp ouyangd@arthur2:~/Echo-Tracing-Analysis/final_r2+1d_model_regression_EF_sgd_skip1_32frames.pth.tar \"C:\\Users\\Windows\\Dropbox\\Echo Research\\CodeBase\\EchoNetDynamic-Weights\"\n",
145
+ "#Weights = \"final_r2+1d_model_regression_EF_sgd_skip1_32frames.pth.tar\"\n",
146
+ "\n",
147
+ "\n",
148
+ "output = os.path.join(destinationFolder, \"cedars_ef_output.csv\")\n",
149
+ "\n",
150
+ "ds = echonet.datasets.Echo(split = \"external_test\", external_test_location = videosFolder, crops=\"all\")\n",
151
+ "print(ds.split, ds.fnames)\n",
152
+ "\n",
153
+ "mean, std = echonet.utils.get_mean_and_std(ds)\n",
154
+ "\n",
155
+ "kwargs = {\"target_type\": \"EF\",\n",
156
+ " \"mean\": mean,\n",
157
+ " \"std\": std,\n",
158
+ " \"length\": frames,\n",
159
+ " \"period\": period,\n",
160
+ " }\n",
161
+ "\n",
162
+ "ds = echonet.datasets.Echo(split = \"external_test\", external_test_location = videosFolder, **kwargs, crops=\"all\")\n",
163
+ "\n",
164
+ "test_dataloader = torch.utils.data.DataLoader(ds, batch_size = 1, num_workers = 5, shuffle = True, pin_memory=(device.type == \"cuda\"))\n",
165
+ "loss, yhat, y = echonet.utils.video.run_epoch(model, test_dataloader, \"test\", None, device, save_all=True, blocks=25)\n",
166
+ "\n",
167
+ "with open(output, \"w\") as g:\n",
168
+ " for (filename, pred) in zip(ds.fnames, yhat):\n",
169
+ " for (i,p) in enumerate(pred):\n",
170
+ " g.write(\"{},{},{:.4f}\\n\".format(filename, i, p))\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "# Initialize and Run Segmentation model\n",
180
+ "\n",
181
+ "torch.cuda.empty_cache()\n",
182
+ "\n",
183
+ "\n",
184
+ "videosFolder = \"C:\\\\Users\\\\Windows\\\\Dropbox\\\\Echo Research\\\\CodeBase\\\\View Classification\\\\AppearsA4c\\\\Resized2\"\n",
185
+ "\n",
186
+ "def collate_fn(x):\n",
187
+ " x, f = zip(*x)\n",
188
+ " i = list(map(lambda t: t.shape[1], x))\n",
189
+ " x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))\n",
190
+ " return x, f, i\n",
191
+ "\n",
192
+ "dataloader = torch.utils.data.DataLoader(echonet.datasets.Echo(split=\"external_test\", external_test_location = videosFolder, target_type=[\"Filename\"], length=None, period=1, mean=mean, std=std),\n",
193
+ " batch_size=10, num_workers=0, shuffle=False, pin_memory=(device.type == \"cuda\"), collate_fn=collate_fn)\n",
194
+ "if not all([os.path.isfile(os.path.join(destinationFolder, \"labels\", os.path.splitext(f)[0] + \".npy\")) for f in dataloader.dataset.fnames]):\n",
195
+ " # Save segmentations for all frames\n",
196
+ " # Only run if missing files\n",
197
+ "\n",
198
+ " pathlib.Path(os.path.join(destinationFolder, \"labels\")).mkdir(parents=True, exist_ok=True)\n",
199
+ " block = 1024\n",
200
+ " model.eval()\n",
201
+ "\n",
202
+ " with torch.no_grad():\n",
203
+ " for (x, f, i) in tqdm.tqdm(dataloader):\n",
204
+ " x = x.to(device)\n",
205
+ " y = np.concatenate([model(x[i:(i + block), :, :, :])[\"out\"].detach().cpu().numpy() for i in range(0, x.shape[0], block)]).astype(np.float16)\n",
206
+ " start = 0\n",
207
+ " for (filename, offset) in zip(f, i):\n",
208
+ " np.save(os.path.join(destinationFolder, \"labels\", os.path.splitext(filename)[0]), y[start:(start + offset), 0, :, :])\n",
209
+ " start += offset\n",
210
+ " \n",
211
+ "dataloader = torch.utils.data.DataLoader(echonet.datasets.Echo(split=\"external_test\", external_test_location = videosFolder, target_type=[\"Filename\"], length=None, period=1, segmentation=os.path.join(destinationFolder, \"labels\")),\n",
212
+ " batch_size=1, num_workers=8, shuffle=False, pin_memory=False)\n",
213
+ "if not all(os.path.isfile(os.path.join(destinationFolder, \"videos\", f)) for f in dataloader.dataset.fnames):\n",
214
+ " pathlib.Path(os.path.join(destinationFolder, \"videos\")).mkdir(parents=True, exist_ok=True)\n",
215
+ " pathlib.Path(os.path.join(destinationFolder, \"size\")).mkdir(parents=True, exist_ok=True)\n",
216
+ " echonet.utils.latexify()\n",
217
+ " with open(os.path.join(destinationFolder, \"size.csv\"), \"w\") as g:\n",
218
+ " g.write(\"Filename,Frame,Size,ComputerSmall\\n\")\n",
219
+ " for (x, filename) in tqdm.tqdm(dataloader):\n",
220
+ " x = x.numpy()\n",
221
+ " for i in range(len(filename)):\n",
222
+ " img = x[i, :, :, :, :].copy()\n",
223
+ " logit = img[2, :, :, :].copy()\n",
224
+ " img[1, :, :, :] = img[0, :, :, :]\n",
225
+ " img[2, :, :, :] = img[0, :, :, :]\n",
226
+ " img = np.concatenate((img, img), 3)\n",
227
+ " img[0, :, :, 112:] = np.maximum(255. * (logit > 0), img[0, :, :, 112:])\n",
228
+ "\n",
229
+ " img = np.concatenate((img, np.zeros_like(img)), 2)\n",
230
+ " size = (logit > 0).sum(2).sum(1)\n",
231
+ " try:\n",
232
+ " trim_min = sorted(size)[round(len(size) ** 0.05)]\n",
233
+ " except:\n",
234
+ " import code; code.interact(local=dict(globals(), **locals()))\n",
235
+ " trim_max = sorted(size)[round(len(size) ** 0.95)]\n",
236
+ " trim_range = trim_max - trim_min\n",
237
+ " peaks = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])\n",
238
+ " for (x, y) in enumerate(size):\n",
239
+ " g.write(\"{},{},{},{}\\n\".format(filename[0], x, y, 1 if x in peaks else 0))\n",
240
+ " fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))\n",
241
+ " plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)\n",
242
+ " ylim = plt.ylim()\n",
243
+ " for p in peaks:\n",
244
+ " plt.plot(np.array([p, p]) / 50, ylim, linewidth=1)\n",
245
+ " plt.ylim(ylim)\n",
246
+ " plt.title(os.path.splitext(filename[i])[0])\n",
247
+ " plt.xlabel(\"Seconds\")\n",
248
+ " plt.ylabel(\"Size (pixels)\")\n",
249
+ " plt.tight_layout()\n",
250
+ " plt.savefig(os.path.join(destinationFolder, \"size\", os.path.splitext(filename[i])[0] + \".pdf\"))\n",
251
+ " plt.close(fig)\n",
252
+ " size -= size.min()\n",
253
+ " size = size / size.max()\n",
254
+ " size = 1 - size\n",
255
+ " for (x, y) in enumerate(size):\n",
256
+ " img[:, :, int(round(115 + 100 * y)), int(round(x / len(size) * 200 + 10))] = 255.\n",
257
+ " interval = np.array([-3, -2, -1, 0, 1, 2, 3])\n",
258
+ " for a in interval:\n",
259
+ " for b in interval:\n",
260
+ " img[:, x, a + int(round(115 + 100 * y)), b + int(round(x / len(size) * 200 + 10))] = 255.\n",
261
+ " if x in peaks:\n",
262
+ " img[:, :, 200:225, b + int(round(x / len(size) * 200 + 10))] = 255.\n",
263
+ " echonet.utils.savevideo(os.path.join(destinationFolder, \"videos\", filename[i]), img.astype(np.uint8), 50) "
264
+ ]
265
+ }
266
+ ],
267
+ "metadata": {
268
+ "kernelspec": {
269
+ "display_name": "Python 3",
270
+ "language": "python",
271
+ "name": "python3"
272
+ },
273
+ "language_info": {
274
+ "codemirror_mode": {
275
+ "name": "ipython",
276
+ "version": 3
277
+ },
278
+ "file_extension": ".py",
279
+ "mimetype": "text/x-python",
280
+ "name": "python",
281
+ "nbconvert_exporter": "python",
282
+ "pygments_lexer": "ipython3",
283
+ "version": "3.7.4"
284
+ }
285
+ },
286
+ "nbformat": 4,
287
+ "nbformat_minor": 2
288
+ }
dynamic/scripts/beat_by_beat_analysis.R ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ library(ggplot2)
2
+ library(stringr)
3
+ library(plyr)
4
+ library(dplyr)
5
+ library(lubridate)
6
+ library(reshape2)
7
+ library(scales)
8
+ library(ggthemes)
9
+ library(Metrics)
10
+
11
+ data <- read.csv("r2plus1d_18_32_2_pretrained_test_predictions.csv", header = FALSE)
12
+ str(data)
13
+
14
+
15
+ dataNoAugmentation <- data[data$V2 == 0,]
16
+ str(dataNoAugmentation)
17
+
18
+
19
+ dataGlobalAugmentation <- data %>% group_by(V1) %>% summarize(meanPrediction = mean(V3), sdPred = sd(V3))
20
+ str(dataGlobalAugmentation)
21
+
22
+
23
+ sizeData <- read.csv("size.csv")
24
+ sizeData <- sizeData[sizeData$ComputerSmall == 1,]
25
+ str(sizeData)
26
+
27
+ sizeRelevantFrames <- sizeData[c(1,2)]
28
+ sizeRelevantFrames$Frame <- sizeRelevantFrames$Frame - 32
29
+ sizeRelevantFrames[sizeRelevantFrames$Frame < 0,]$Frame <- 0
30
+
31
+
32
+ beatByBeat <- merge(sizeRelevantFrames, data, by.x = c("Filename", "Frame"), by.y = c("V1", "V2"))
33
+ beatByBeat <- beatByBeat %>% group_by(Filename) %>% summarize(meanPrediction = mean(V3), sdPred = sd(V3))
34
+ str(beatByBeat)
35
+
36
+ ### For use, need to specify file directory
37
+ fileLocation <- "/Users/davidouyang/Local Medical Data/"
38
+ ActualNumbers <- read.csv(paste0(fileLocation, "FileList.csv", sep = ""))
39
+ ActualNumbers <- ActualNumbers[c(1,2)]
40
+ str(ActualNumbers)
41
+
42
+
43
+
44
+ dataNoAugmentation <- merge(dataNoAugmentation, ActualNumbers, by.x = "V1", by.y = "Filename", all.x = TRUE)
45
+ dataNoAugmentation$AbsErr <- abs(dataNoAugmentation$V3 - dataNoAugmentation$EF)
46
+ str(dataNoAugmentation)
47
+
48
+ summary(abs(dataNoAugmentation$V3 - dataNoAugmentation$EF))
49
+ # Mean of 4.216
50
+
51
+ rmse(dataNoAugmentation$V3,dataNoAugmentation$EF)
52
+ ## 5.56
53
+
54
+ modelNoAugmentation <- lm(dataNoAugmentation$EF ~ dataNoAugmentation$V3)
55
+ summary(modelNoAugmentation)$r.squared
56
+ # 0.79475
57
+
58
+
59
+ beatByBeat <- merge(beatByBeat, ActualNumbers, by.x = "Filename", by.y = "Filename", all.x = TRUE)
60
+ summary(abs(beatByBeat$meanPrediction - beatByBeat$EF))
61
+ # Mean of 4.051697
62
+
63
+ rmse(beatByBeat$meanPrediction, beatByBeat$EF)
64
+ # 5.325237
65
+
66
+ modelBeatByBeat <- lm(beatByBeat$EF ~ beatByBeat$meanPrediction)
67
+ summary(modelBeatByBeat)$r.squared
68
+ # 0.8093174
69
+
70
+
71
+ beatByBeatAnalysis <- merge(sizeRelevantFrames, data, by.x = c("Filename", "Frame"), by.y = c("V1", "V2"))
72
+ str(beatByBeatAnalysis)
73
+
74
+
75
+ MAEdata <- data.frame(counter = 1:500)
76
+ MAEdata$sample <- -9999
77
+ MAEdata$error <- -9999
78
+
79
+ str(MAEdata)
80
+
81
+ for (i in 1:500){
82
+
83
+
84
+ samplingBeat <- sample_n(beatByBeatAnalysis %>% group_by(Filename), 1 + floor((i-1)/100), replace = TRUE) %>% group_by(Filename) %>% dplyr::summarize(meanPred = mean(V3))
85
+ samplingBeat <- merge(samplingBeat, ActualNumbers, by.x = "Filename", by.y = "Filename", all.x = TRUE)
86
+ samplingBeat$error <- abs(samplingBeat$meanPred - samplingBeat$EF)
87
+
88
+ MAEdata$sample[i] <- 1 + floor((i-1)/100)
89
+ MAEdata$error[i] <- mean(samplingBeat$error )
90
+
91
+
92
+ }
93
+
94
+ str(MAEdata)
95
+
96
+ beatBoxPlot <- ggplot(data = MAEdata) + geom_boxplot(aes(x = sample, y = error, group = sample), outlier.shape = NA
97
+ ) + theme_classic() + theme(legend.position = "none", axis.text.y = element_text( size=7)) + xlab("Number of Sampled Beats") + ylab("Mean Absolute Error") + scale_fill_brewer(palette = "Set1", direction = -1)
98
+
99
+ beatBoxPlot
100
+
dynamic/scripts/plot_complexity.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """Code to generate plots for Extended Data Fig. 4."""
4
+
5
+ import os
6
+
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+
11
+ import echonet
12
+
13
+
14
+ def main(root=os.path.join("timing", "video"),
15
+ fig_root=os.path.join("figure", "complexity"),
16
+ FRAMES=(1, 8, 16, 32, 64, 96),
17
+ pretrained=True):
18
+ """Generate plots for Extended Data Fig. 4."""
19
+
20
+ echonet.utils.latexify()
21
+
22
+ os.makedirs(fig_root, exist_ok=True)
23
+ fig = plt.figure(figsize=(6.50, 2.50))
24
+ gs = matplotlib.gridspec.GridSpec(1, 3, width_ratios=[2.5, 2.5, 1.50])
25
+ ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2]))
26
+
27
+ # Create legend
28
+ for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3"], matplotlib.colors.TABLEAU_COLORS):
29
+ ax[2].plot([float("nan")], [float("nan")], "-", color=color, label=model)
30
+ ax[2].set_title("")
31
+ ax[2].axis("off")
32
+ ax[2].legend(loc="center")
33
+
34
+ for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS):
35
+ for split in ["val"]: # ["val", "train"]:
36
+ print(model, split)
37
+ data = [load(root, model, frames, 1, pretrained, split) for frames in FRAMES]
38
+ time = np.array(list(map(lambda x: x[0], data)))
39
+ n = np.array(list(map(lambda x: x[1], data)))
40
+ mem_allocated = np.array(list(map(lambda x: x[2], data)))
41
+ # mem_cached = np.array(list(map(lambda x: x[3], data)))
42
+ batch_size = np.array(list(map(lambda x: x[4], data)))
43
+
44
+ # Plot Time (panel a)
45
+ ax[0].plot(FRAMES, time / n, "-" if pretrained else "--", marker=".", color=color, linewidth=(1 if split == "train" else None))
46
+ print("Time:\n" + "\n".join(map(lambda x: "{:8d}: {:f}".format(*x), zip(FRAMES, time / n))))
47
+
48
+ # Plot Memory (panel b)
49
+ ax[1].plot(FRAMES, mem_allocated / batch_size / 1e9, "-" if pretrained else "--", marker=".", color=color, linewidth=(1 if split == "train" else None))
50
+ print("Memory:\n" + "\n".join(map(lambda x: "{:8d}: {:f}".format(*x), zip(FRAMES, mem_allocated / batch_size / 1e9))))
51
+ print()
52
+
53
+ # Labels for panel a
54
+ ax[0].set_xticks(FRAMES)
55
+ ax[0].text(-0.05, 1.10, "(a)", transform=ax[0].transAxes)
56
+ ax[0].set_xlabel("Clip length (frames)")
57
+ ax[0].set_ylabel("Time Per Clip (seconds)")
58
+
59
+ # Labels for panel b
60
+ ax[1].set_xticks(FRAMES)
61
+ ax[1].text(-0.05, 1.10, "(b)", transform=ax[1].transAxes)
62
+ ax[1].set_xlabel("Clip length (frames)")
63
+ ax[1].set_ylabel("Memory Per Clip (GB)")
64
+
65
+ # Save figure
66
+ plt.tight_layout()
67
+ plt.savefig(os.path.join(fig_root, "complexity.pdf"))
68
+ plt.savefig(os.path.join(fig_root, "complexity.eps"))
69
+ plt.close(fig)
70
+
71
+
72
+ def load(root, model, frames, period, pretrained, split):
73
+ """Loads runtime and memory usage for specified hyperparameter choice."""
74
+ with open(os.path.join(root, "{}_{}_{}_{}".format(model, frames, period, "pretrained" if pretrained else "random"), "log.csv"), "r") as f:
75
+ for line in f:
76
+ line = line.split(",")
77
+ if len(line) < 4:
78
+ # Skip lines that are not csv (these lines log information)
79
+ continue
80
+ if line[1] == split:
81
+ *_, time, n, mem_allocated, mem_cached, batch_size = line
82
+ time = float(time)
83
+ n = int(n)
84
+ mem_allocated = int(mem_allocated)
85
+ mem_cached = int(mem_cached)
86
+ batch_size = int(batch_size)
87
+ return time, n, mem_allocated, mem_cached, batch_size
88
+ raise ValueError("File missing information.")
89
+
90
+
91
+ if __name__ == "__main__":
92
+ main()
dynamic/scripts/plot_hyperparameter_sweep.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """Code to generate plots for Extended Data Fig. 1."""
4
+
5
+ import os
6
+
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+
10
+ import echonet
11
+
12
+
13
+ def main(root=os.path.join("output", "video"),
14
+ fig_root=os.path.join("figure", "hyperparameter"),
15
+ FRAMES=(1, 8, 16, 32, 64, 96, None),
16
+ PERIOD=(1, 2, 4, 6, 8)
17
+ ):
18
+ """Generate plots for Extended Data Fig. 1."""
19
+
20
+ echonet.utils.latexify()
21
+ os.makedirs(fig_root, exist_ok=True)
22
+
23
+ # Parameters for plotting length sweep
24
+ MAX = FRAMES[-2]
25
+ START = 1 # Starting point for normal range
26
+ TERM0 = 104 # Ending point for normal range
27
+ BREAK = 112 # Location for break
28
+ TERM1 = 120 # Starting point for "all" section
29
+ ALL = 128 # Location of "all" point
30
+ END = 135 # Ending point for "all" section
31
+ RATIO = (BREAK - START) / (END - BREAK)
32
+
33
+ # Set up figure
34
+ fig = plt.figure(figsize=(3 + 2.5 + 1.5, 2.75))
35
+ outer = matplotlib.gridspec.GridSpec(1, 3, width_ratios=[3, 2.5, 1.50])
36
+ ax = plt.subplot(outer[2]) # Legend
37
+ ax2 = plt.subplot(outer[1]) # Period plot
38
+ gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
39
+ 1, 2, subplot_spec=outer[0], width_ratios=[RATIO, 1], wspace=0.020) # Length plot
40
+
41
+ # Plot legend
42
+ for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3"],
43
+ matplotlib.colors.TABLEAU_COLORS):
44
+ ax.plot([float("nan")], [float("nan")], "-", color=color, label=model)
45
+ ax.plot([float("nan")], [float("nan")], "-", color="k", label="Pretrained")
46
+ ax.plot([float("nan")], [float("nan")], "--", color="k", label="Random")
47
+ ax.set_title("")
48
+ ax.axis("off")
49
+ ax.legend(loc="center")
50
+
51
+ # Plot length sweep (panel a)
52
+ ax0 = plt.subplot(gs[0])
53
+ ax1 = plt.subplot(gs[1], sharey=ax0)
54
+ print("FRAMES")
55
+ for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"],
56
+ matplotlib.colors.TABLEAU_COLORS):
57
+ for pretrained in [True, False]:
58
+ loss = [load(root, model, frames, 1, pretrained) for frames in FRAMES]
59
+ print(model, pretrained)
60
+ print(" ".join(list(map(lambda x: "{:.1f}".format(x) if x is not None else None, loss))))
61
+
62
+ l0 = loss[-2]
63
+ l1 = loss[-1]
64
+ ax0.plot(FRAMES[:-1] + (TERM0,),
65
+ loss[:-1] + [l0 + (l1 - l0) * (TERM0 - MAX) / (ALL - MAX)],
66
+ "-" if pretrained else "--", color=color)
67
+ ax1.plot([TERM1, ALL],
68
+ [l0 + (l1 - l0) * (TERM1 - MAX) / (ALL - MAX)] + [loss[-1]],
69
+ "-" if pretrained else "--", color=color)
70
+ ax0.scatter(list(map(lambda x: x if x is not None else ALL, FRAMES)), loss, color=color, s=4)
71
+ ax1.scatter(list(map(lambda x: x if x is not None else ALL, FRAMES)), loss, color=color, s=4)
72
+
73
+ ax0.set_xticks(list(map(lambda x: x if x is not None else ALL, FRAMES)))
74
+ ax1.set_xticks(list(map(lambda x: x if x is not None else ALL, FRAMES)))
75
+ ax0.set_xticklabels(list(map(lambda x: x if x is not None else "All", FRAMES)))
76
+ ax1.set_xticklabels(list(map(lambda x: x if x is not None else "All", FRAMES)))
77
+
78
+ # https://stackoverflow.com/questions/5656798/python-matplotlib-is-there-a-way-to-make-a-discontinuous-axis/43684155
79
+ # zoom-in / limit the view to different portions of the data
80
+ ax0.set_xlim(START, BREAK) # most of the data
81
+ ax1.set_xlim(BREAK, END)
82
+
83
+ # hide the spines between ax and ax2
84
+ ax0.spines['right'].set_visible(False)
85
+ ax1.spines['left'].set_visible(False)
86
+
87
+ ax1.get_yaxis().set_visible(False)
88
+
89
+ d = 0.015 # how big to make the diagonal lines in axes coordinates
90
+ # arguments to pass plot, just so we don't keep repeating them
91
+ kwargs = dict(transform=ax0.transAxes, color='k', clip_on=False, linewidth=1)
92
+ x0, x1, y0, y1 = ax0.axis()
93
+ scale = (y1 - y0) / (x1 - x0) / 2
94
+ ax0.plot((1 - scale * d, 1 + scale * d), (-d, +d), **kwargs) # top-left diagonal
95
+ ax0.plot((1 - scale * d, 1 + scale * d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
96
+
97
+ kwargs.update(transform=ax1.transAxes) # switch to the bottom 1xes
98
+ x0, x1, y0, y1 = ax1.axis()
99
+ scale = (y1 - y0) / (x1 - x0) / 2
100
+ ax1.plot((-scale * d, scale * d), (-d, +d), **kwargs) # top-right diagonal
101
+ ax1.plot((-scale * d, scale * d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
102
+
103
+ # ax0.xaxis.label.set_transform(matplotlib.transforms.blended_transform_factory(
104
+ # matplotlib.transforms.IdentityTransform(), fig.transFigure # specify x, y transform
105
+ # )) # changed from default blend (IdentityTransform(), a[0].transAxes)
106
+ ax0.xaxis.label.set_position((0.6, 0.0))
107
+ ax0.text(-0.05, 1.10, "(a)", transform=ax0.transAxes)
108
+ ax0.set_xlabel("Clip length (frames)")
109
+ ax0.set_ylabel("Validation Loss")
110
+
111
+ # Plot period sweep (panel b)
112
+ print("PERIOD")
113
+ for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS):
114
+ for pretrained in [True, False]:
115
+ loss = [load(root, model, 64 // period, period, pretrained) for period in PERIOD]
116
+ print(model, pretrained)
117
+ print(" ".join(list(map(lambda x: "{:.1f}".format(x) if x is not None else None, loss))))
118
+
119
+ ax2.plot(PERIOD, loss, "-" if pretrained else "--", marker=".", color=color)
120
+ ax2.set_xticks(PERIOD)
121
+ ax2.text(-0.05, 1.10, "(b)", transform=ax2.transAxes)
122
+ ax2.set_xlabel("Sampling Period (frames)")
123
+ ax2.set_ylabel("Validation Loss")
124
+
125
+ # Save figure
126
+ plt.tight_layout()
127
+ plt.savefig(os.path.join(fig_root, "hyperparameter.pdf"))
128
+ plt.savefig(os.path.join(fig_root, "hyperparameter.eps"))
129
+ plt.savefig(os.path.join(fig_root, "hyperparameter.png"))
130
+ plt.close(fig)
131
+
132
+
133
+ def load(root, model, frames, period, pretrained):
134
+ """Loads best validation loss for specified hyperparameter choice."""
135
+ pretrained = ("pretrained" if pretrained else "random")
136
+ f = os.path.join(
137
+ root,
138
+ "{}_{}_{}_{}".format(model, frames, period, pretrained),
139
+ "log.csv")
140
+ with open(f, "r") as f:
141
+ for line in f:
142
+ if "Best validation loss " in line:
143
+ return float(line.split()[3])
144
+
145
+ raise ValueError("File missing information.")
146
+
147
+
148
+ if __name__ == "__main__":
149
+ main()
dynamic/scripts/plot_loss.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """Code to generate plots for Extended Data Fig. 3."""
4
+
5
+ import argparse
6
+ import os
7
+ import matplotlib
8
+ import matplotlib.pyplot as plt
9
+
10
+ import echonet
11
+
12
+
13
+ def main():
14
+ """Generate plots for Extended Data Fig. 3."""
15
+
16
+ # Select paths and hyperparameter to plot
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("dir", nargs="?", default="output")
19
+ parser.add_argument("fig", nargs="?", default=os.path.join("figure", "loss"))
20
+ parser.add_argument("--frames", type=int, default=32)
21
+ parser.add_argument("--period", type=int, default=2)
22
+ args = parser.parse_args()
23
+
24
+ # Set up figure
25
+ echonet.utils.latexify()
26
+ os.makedirs(args.fig, exist_ok=True)
27
+ fig = plt.figure(figsize=(7, 5))
28
+ gs = matplotlib.gridspec.GridSpec(ncols=3, nrows=2, figure=fig, width_ratios=[2.75, 2.75, 1.50])
29
+
30
+ # Plot EF loss curve
31
+ ax0 = fig.add_subplot(gs[0, 0])
32
+ ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)
33
+ for pretrained in [True]:
34
+ for (model, color) in zip(["r2plus1d_18", "r3d_18", "mc3_18"], matplotlib.colors.TABLEAU_COLORS):
35
+ loss = load(os.path.join(args.dir, "video", "{}_{}_{}_{}".format(model, args.frames, args.period, "pretrained" if pretrained else "random"), "log.csv"))
36
+ ax0.plot(range(1, 1 + len(loss["train"])), loss["train"], "-" if pretrained else "--", color=color)
37
+ ax1.plot(range(1, 1 + len(loss["val"])), loss["val"], "-" if pretrained else "--", color=color)
38
+
39
+ plt.axis([0, max(len(loss["train"]), len(loss["val"])), 0, max(max(loss["train"]), max(loss["val"]))])
40
+ ax0.text(-0.25, 1.00, "(a)", transform=ax0.transAxes)
41
+ ax1.text(-0.25, 1.00, "(b)", transform=ax1.transAxes)
42
+ ax0.set_xlabel("Epochs")
43
+ ax1.set_xlabel("Epochs")
44
+ ax0.set_xticks([0, 15, 30, 45])
45
+ ax1.set_xticks([0, 15, 30, 45])
46
+ ax0.set_ylabel("Training MSE Loss")
47
+ ax1.set_ylabel("Validation MSE Loss")
48
+
49
+ # Plot segmentation loss curve
50
+ ax0 = fig.add_subplot(gs[1, 0])
51
+ ax1 = fig.add_subplot(gs[1, 1], sharey=ax0)
52
+ pretrained = False
53
+ for (model, color) in zip(["deeplabv3_resnet50"], list(matplotlib.colors.TABLEAU_COLORS)[3:]):
54
+ loss = load(os.path.join(args.dir, "segmentation", "{}_{}".format(model, "pretrained" if pretrained else "random"), "log.csv"))
55
+ ax0.plot(range(1, 1 + len(loss["train"])), loss["train"], "--", color=color)
56
+ ax1.plot(range(1, 1 + len(loss["val"])), loss["val"], "--", color=color)
57
+
58
+ ax0.text(-0.25, 1.00, "(c)", transform=ax0.transAxes)
59
+ ax1.text(-0.25, 1.00, "(d)", transform=ax1.transAxes)
60
+ ax0.set_ylim([0, 0.13])
61
+ ax0.set_xlabel("Epochs")
62
+ ax1.set_xlabel("Epochs")
63
+ ax0.set_xticks([0, 25, 50])
64
+ ax1.set_xticks([0, 25, 50])
65
+ ax0.set_ylabel("Training Cross Entropy Loss")
66
+ ax1.set_ylabel("Validation Cross Entropy Loss")
67
+
68
+ # Legend
69
+ ax = fig.add_subplot(gs[:, 2])
70
+ for (model, color) in zip(["EchoNet-Dynamic (EF)", "R3D", "MC3", "EchoNet-Dynamic (Seg)"], matplotlib.colors.TABLEAU_COLORS):
71
+ ax.plot([float("nan")], [float("nan")], "-", color=color, label=model)
72
+ ax.set_title("")
73
+ ax.axis("off")
74
+ ax.legend(loc="center")
75
+
76
+ plt.tight_layout()
77
+ plt.savefig(os.path.join(args.fig, "loss.pdf"))
78
+ plt.savefig(os.path.join(args.fig, "loss.eps"))
79
+ plt.savefig(os.path.join(args.fig, "loss.png"))
80
+ plt.close(fig)
81
+
82
+
83
+ def load(filename):
84
+ """Loads losses from specified file."""
85
+
86
+ losses = {"train": [], "val": []}
87
+ with open(filename, "r") as f:
88
+ for line in f:
89
+ line = line.split(",")
90
+ if len(line) < 4:
91
+ continue
92
+ epoch, split, loss, *_ = line
93
+ epoch = int(epoch)
94
+ loss = float(loss)
95
+ assert(split in ["train", "val"])
96
+ if epoch == len(losses[split]):
97
+ losses[split].append(loss)
98
+ elif epoch == len(losses[split]) - 1:
99
+ losses[split][-1] = loss
100
+ else:
101
+ raise ValueError("File has uninterpretable formatting.")
102
+ return losses
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()
dynamic/scripts/plot_simulated_noise.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """Code to generate plots for Extended Data Fig. 6."""
4
+
5
+ import os
6
+ import pickle
7
+
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import PIL
12
+ import sklearn
13
+ import torch
14
+ import torchvision
15
+
16
+ import echonet
17
+
18
+
19
+ def main(fig_root=os.path.join("figure", "noise"),
20
+ video_output=os.path.join("output", "video", "r2plus1d_18_32_2_pretrained"),
21
+ seg_output=os.path.join("output", "segmentation", "deeplabv3_resnet50_random"),
22
+ NOISE=(0, 0.1, 0.2, 0.3, 0.4, 0.5)):
23
+ """Generate plots for Extended Data Fig. 6."""
24
+
25
+ device = torch.device("cuda")
26
+
27
+ filename = os.path.join(fig_root, "data.pkl") # Cache of results
28
+ try:
29
+ # Attempt to load cache
30
+ with open(filename, "rb") as f:
31
+ Y, YHAT, INTER, UNION = pickle.load(f)
32
+ except FileNotFoundError:
33
+ # Generate results if no cache available
34
+ os.makedirs(fig_root, exist_ok=True)
35
+
36
+ # Load trained video model
37
+ model_v = torchvision.models.video.r2plus1d_18()
38
+ model_v.fc = torch.nn.Linear(model_v.fc.in_features, 1)
39
+ if device.type == "cuda":
40
+ model_v = torch.nn.DataParallel(model_v)
41
+ model_v.to(device)
42
+
43
+ checkpoint = torch.load(os.path.join(video_output, "checkpoint.pt"))
44
+ model_v.load_state_dict(checkpoint['state_dict'])
45
+
46
+ # Load trained segmentation model
47
+ model_s = torchvision.models.segmentation.deeplabv3_resnet50(aux_loss=False)
48
+ model_s.classifier[-1] = torch.nn.Conv2d(model_s.classifier[-1].in_channels, 1, kernel_size=model_s.classifier[-1].kernel_size)
49
+ if device.type == "cuda":
50
+ model_s = torch.nn.DataParallel(model_s)
51
+ model_s.to(device)
52
+
53
+ checkpoint = torch.load(os.path.join(seg_output, "checkpoint.pt"))
54
+ model_s.load_state_dict(checkpoint['state_dict'])
55
+
56
+ # Run simulation
57
+ dice = []
58
+ mse = []
59
+ r2 = []
60
+ Y = []
61
+ YHAT = []
62
+ INTER = []
63
+ UNION = []
64
+ for noise in NOISE:
65
+ Y.append([])
66
+ YHAT.append([])
67
+ INTER.append([])
68
+ UNION.append([])
69
+
70
+ dataset = echonet.datasets.Echo(split="test", noise=noise)
71
+ PIL.Image.fromarray(dataset[0][0][:, 0, :, :].astype(np.uint8).transpose(1, 2, 0)).save(os.path.join(fig_root, "noise_{}.tif".format(round(100 * noise))))
72
+
73
+ mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(split="train"))
74
+
75
+ tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
76
+ kwargs = {
77
+ "target_type": tasks,
78
+ "mean": mean,
79
+ "std": std,
80
+ "noise": noise
81
+ }
82
+ dataset = echonet.datasets.Echo(split="test", **kwargs)
83
+
84
+ dataloader = torch.utils.data.DataLoader(dataset,
85
+ batch_size=16, num_workers=5, shuffle=True, pin_memory=(device.type == "cuda"))
86
+
87
+ loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model_s, dataloader, "test", None, device)
88
+ inter = np.concatenate((large_inter, small_inter)).sum()
89
+ union = np.concatenate((large_union, small_union)).sum()
90
+ dice.append(2 * inter / (union + inter))
91
+
92
+ INTER[-1].extend(large_inter.tolist() + small_inter.tolist())
93
+ UNION[-1].extend(large_union.tolist() + small_union.tolist())
94
+
95
+ kwargs = {"target_type": "EF",
96
+ "mean": mean,
97
+ "std": std,
98
+ "length": 32,
99
+ "period": 2,
100
+ "noise": noise
101
+ }
102
+
103
+ dataset = echonet.datasets.Echo(split="test", **kwargs)
104
+
105
+ dataloader = torch.utils.data.DataLoader(dataset,
106
+ batch_size=16, num_workers=5, shuffle=True, pin_memory=(device.type == "cuda"))
107
+ loss, yhat, y = echonet.utils.video.run_epoch(model_v, dataloader, "test", None, device)
108
+ mse.append(loss)
109
+ r2.append(sklearn.metrics.r2_score(y, yhat))
110
+ Y[-1].extend(y.tolist())
111
+ YHAT[-1].extend(yhat.tolist())
112
+
113
+ # Save results in cache
114
+ with open(filename, "wb") as f:
115
+ pickle.dump((Y, YHAT, INTER, UNION), f)
116
+
117
+ # Set up plot
118
+ echonet.utils.latexify()
119
+
120
+ NOISE = list(map(lambda x: round(100 * x), NOISE))
121
+ fig = plt.figure(figsize=(6.50, 4.75))
122
+ gs = matplotlib.gridspec.GridSpec(3, 1, height_ratios=[2.0, 2.0, 0.75])
123
+ ax = (plt.subplot(gs[0]), plt.subplot(gs[1]), plt.subplot(gs[2]))
124
+
125
+ # Plot EF prediction results (R^2)
126
+ r2 = [sklearn.metrics.r2_score(y, yhat) for (y, yhat) in zip(Y, YHAT)]
127
+ ax[0].plot(NOISE, r2, color="k", linewidth=1, marker=".")
128
+ ax[0].set_xticks([])
129
+ ax[0].set_ylabel("R$^2$")
130
+ l, h = min(r2), max(r2)
131
+ l, h = l - 0.1 * (h - l), h + 0.1 * (h - l)
132
+ ax[0].axis([min(NOISE) - 5, max(NOISE) + 5, 0, 1])
133
+
134
+ # Plot segmentation results (DSC)
135
+ dice = [echonet.utils.dice_similarity_coefficient(inter, union) for (inter, union) in zip(INTER, UNION)]
136
+ ax[1].plot(NOISE, dice, color="k", linewidth=1, marker=".")
137
+ ax[1].set_xlabel("Pixels Removed (%)")
138
+ ax[1].set_ylabel("DSC")
139
+ l, h = min(dice), max(dice)
140
+ l, h = l - 0.1 * (h - l), h + 0.1 * (h - l)
141
+ ax[1].axis([min(NOISE) - 5, max(NOISE) + 5, 0, 1])
142
+
143
+ # Add example images below
144
+ for noise in NOISE:
145
+ image = matplotlib.image.imread(os.path.join(fig_root, "noise_{}.tif".format(noise)))
146
+ imagebox = matplotlib.offsetbox.OffsetImage(image, zoom=0.4)
147
+ ab = matplotlib.offsetbox.AnnotationBbox(imagebox, (noise, 0.0), frameon=False)
148
+ ax[2].add_artist(ab)
149
+ ax[2].axis("off")
150
+ ax[2].axis([min(NOISE) - 5, max(NOISE) + 5, -1, 1])
151
+
152
+ fig.tight_layout()
153
+ plt.savefig(os.path.join(fig_root, "noise.pdf"), dpi=1200)
154
+ plt.savefig(os.path.join(fig_root, "noise.eps"), dpi=300)
155
+ plt.savefig(os.path.join(fig_root, "noise.png"), dpi=600)
156
+ plt.close(fig)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ main()
dynamic/scripts/run_experiments.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ for pretrained in True False
4
+ do
5
+ for model in r2plus1d_18 r3d_18 mc3_18
6
+ do
7
+ for frames in 96 64 32 16 8 4 1
8
+ do
9
+ batch=$((256 / frames))
10
+ batch=$(( batch > 16 ? 16 : batch ))
11
+
12
+ cmd="import echonet; echonet.utils.video.run(modelname=\"${model}\", frames=${frames}, period=1, pretrained=${pretrained}, batch_size=${batch})"
13
+ python3 -c "${cmd}"
14
+ done
15
+ for period in 2 4 6 8
16
+ do
17
+ batch=$((256 / 64 * period))
18
+ batch=$(( batch > 16 ? 16 : batch ))
19
+
20
+ cmd="import echonet; echonet.utils.video.run(modelname=\"${model}\", frames=(64 // ${period}), period=${period}, pretrained=${pretrained}, batch_size=${batch})"
21
+ python3 -c "${cmd}"
22
+ done
23
+ done
24
+ done
25
+
26
+ period=2
27
+ pretrained=True
28
+ for model in r2plus1d_18 r3d_18 mc3_18
29
+ do
30
+ cmd="import echonet; echonet.utils.video.run(modelname=\"${model}\", frames=(64 // ${period}), period=${period}, pretrained=${pretrained}, run_test=True)"
31
+ python3 -c "${cmd}"
32
+ done
33
+
34
+ python3 -c "import echonet; echonet.utils.segmentation.run(modelname=\"deeplabv3_resnet50\", save_segmentation=True, pretrained=False)"
35
+
36
+ pretrained=True
37
+ model=r2plus1d_18
38
+ period=2
39
+ batch=$((256 / 64 * period))
40
+ batch=$(( batch > 16 ? 16 : batch ))
41
+ for patients in 16 32 64 128 256 512 1024 2048 4096 7460
42
+ do
43
+ cmd="import echonet; echonet.utils.video.run(modelname=\"${model}\", frames=(64 // ${period}), period=${period}, pretrained=${pretrained}, batch_size=${batch}, num_epochs=min(50 * (8192 // ${patients}), 200), output=\"output/training_size/video/${patients}\", n_train_patients=${patients})"
44
+ python3 -c "${cmd}"
45
+ cmd="import echonet; echonet.utils.segmentation.run(modelname=\"deeplabv3_resnet50\", pretrained=False, num_epochs=min(50 * (8192 // ${patients}), 200), output=\"output/training_size/segmentation/${patients}\", n_train_patients=${patients})"
46
+ python3 -c "${cmd}"
47
+
48
+ done
49
+
dynamic/setup.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Metadata for package to allow installation with pip."""
3
+
4
+ import os
5
+
6
+ import setuptools
7
+
8
+ with open("README.md", "r") as fh:
9
+ long_description = fh.read()
10
+
11
+ # Use same version from code
12
+ # See 3 from
13
+ # https://packaging.python.org/guides/single-sourcing-package-version/
14
+ version = {}
15
+ with open(os.path.join("echonet", "__version__.py")) as f:
16
+ exec(f.read(), version) # pylint: disable=W0122
17
+
18
+ setuptools.setup(
19
+ name="echonet",
20
+ description="Video-based AI for beat-to-beat cardiac function assessment.",
21
+ version=version["__version__"],
22
+ url="https://echonet.github.io/dynamic",
23
+ packages=setuptools.find_packages(),
24
+ install_requires=[
25
+ "click",
26
+ "numpy",
27
+ "pandas",
28
+ "torch",
29
+ "torchvision",
30
+ "opencv-python",
31
+ "scikit-image",
32
+ "tqdm",
33
+ "sklearn"
34
+ ],
35
+ classifiers=[
36
+ "Programming Language :: Python :: 3",
37
+ ],
38
+ entry_points={
39
+ "console_scripts": [
40
+ "echonet=echonet:main",
41
+ ],
42
+ }
43
+
44
+ )
echonet/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The echonet package contains code for loading echocardiogram videos, and
3
+ functions for training and testing segmentation and ejection fraction
4
+ prediction models.
5
+ """
6
+
7
+ import click
8
+
9
+ from echonet.__version__ import __version__
10
+ from echonet.config import CONFIG as config
11
+ import echonet.datasets as datasets
12
+ import echonet.utils as utils
13
+
14
+
15
+ @click.group()
16
+ def main():
17
+ """Entry point for command line interface."""
18
+
19
+
20
+ del click
21
+
22
+
23
+ main.add_command(utils.segmentation.run)
24
+ main.add_command(utils.video.run)
25
+
26
+ __all__ = ["__version__", "config", "datasets", "main", "utils"]
echonet/__main__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Entry point for command line."""
2
+
3
+ import echonet
4
+
5
+
6
+ if __name__ == '__main__':
7
+ echonet.main()
echonet/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.14 kB). View file
 
echonet/__pycache__/__version__.cpython-311.pyc ADDED
Binary file (263 Bytes). View file
 
echonet/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.38 kB). View file
 
echonet/__version__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """Version number for Echonet package."""
2
+
3
+ __version__ = "1.0.0"
echonet/config.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sets paths based on configuration files."""
2
+
3
+ import configparser
4
+ import os
5
+ import types
6
+
7
+ _FILENAME = None
8
+ _PARAM = {}
9
+ for filename in ["echonet.cfg",
10
+ ".echonet.cfg",
11
+ os.path.expanduser("~/echonet.cfg"),
12
+ os.path.expanduser("~/.echonet.cfg"),
13
+ ]:
14
+ if os.path.isfile(filename):
15
+ _FILENAME = filename
16
+ config = configparser.ConfigParser()
17
+ with open(filename, "r") as f:
18
+ config.read_string("[config]\n" + f.read())
19
+ _PARAM = config["config"]
20
+ break
21
+
22
+ CONFIG = types.SimpleNamespace(
23
+ FILENAME=_FILENAME,
24
+ DATA_DIR=_PARAM.get("data_dir", "a4c-video-dir/"))
echonet/datasets/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The echonet.datasets submodule defines a Pytorch dataset for loading
3
+ echocardiogram videos.
4
+ """
5
+
6
+ from .echo import Echo
7
+
8
+ __all__ = ["Echo"]
echonet/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (385 Bytes). View file
 
echonet/datasets/__pycache__/echo.cpython-311.pyc ADDED
Binary file (18.8 kB). View file
 
echonet/datasets/echo.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """EchoNet-Dynamic Dataset."""
2
+
3
+ import os
4
+ import collections
5
+ import pandas
6
+
7
+ import numpy as np
8
+ import skimage.draw
9
+ import torchvision
10
+ import echonet
11
+
12
+
13
+ class Echo(torchvision.datasets.VisionDataset):
14
+ """EchoNet-Dynamic Dataset.
15
+
16
+ Args:
17
+ root (string): Root directory of dataset (defaults to `echonet.config.DATA_DIR`)
18
+ split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''}
19
+ target_type (string or list, optional): Type of target to use,
20
+ ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'',
21
+ ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'',
22
+ or ``SmallTrace''
23
+ Can also be a list to output a tuple with all specified target types.
24
+ The targets represent:
25
+ ``Filename'' (string): filename of video
26
+ ``EF'' (float): ejection fraction
27
+ ``EDV'' (float): end-diastolic volume
28
+ ``ESV'' (float): end-systolic volume
29
+ ``LargeIndex'' (int): index of large (diastolic) frame in video
30
+ ``SmallIndex'' (int): index of small (systolic) frame in video
31
+ ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame
32
+ ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame
33
+ ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation
34
+ value of 0 indicates pixel is outside left ventricle
35
+ 1 indicates pixel is inside left ventricle
36
+ ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation
37
+ value of 0 indicates pixel is outside left ventricle
38
+ 1 indicates pixel is inside left ventricle
39
+ Defaults to ``EF''.
40
+ mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel.
41
+ Used for normalizing the video. Defaults to 0 (video is not shifted).
42
+ std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel.
43
+ Used for normalizing the video. Defaults to 0 (video is not scaled).
44
+ length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned.
45
+ Defaults to 16.
46
+ period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken)
47
+ Defaults to 2.
48
+ max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively
49
+ long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video.
50
+ Defaults to 250.
51
+ clips (int, optional): Number of clips to sample. Main use is for test-time augmentation with random clips.
52
+ Defaults to 1.
53
+ pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation).
54
+ and a window of the original size is taken. If ``None'', no padding occurs.
55
+ Defaults to ``None''.
56
+ noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added.
57
+ Defaults to ``None''.
58
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
59
+ external_test_location (string): Path to videos to use for external testing.
60
+ """
61
+
62
+ def __init__(self, root=None,
63
+ split="train", target_type="EF",
64
+ mean=0., std=1.,
65
+ length=16, period=2,
66
+ max_length=250,
67
+ clips=1,
68
+ pad=None,
69
+ noise=None,
70
+ target_transform=None,
71
+ external_test_location=None):
72
+ if root is None:
73
+ root = echonet.config.DATA_DIR
74
+
75
+ super().__init__(root, target_transform=target_transform)
76
+
77
+ self.split = split.upper()
78
+ if not isinstance(target_type, list):
79
+ target_type = [target_type]
80
+ self.target_type = target_type
81
+ self.mean = mean
82
+ self.std = std
83
+ self.length = length
84
+ self.max_length = max_length
85
+ self.period = period
86
+ self.clips = clips
87
+ self.pad = pad
88
+ self.noise = noise
89
+ self.target_transform = target_transform
90
+ self.external_test_location = external_test_location
91
+
92
+ self.fnames, self.outcome = [], []
93
+
94
+ if self.split == "EXTERNAL_TEST":
95
+ self.fnames = sorted(os.listdir(self.external_test_location))
96
+ else:
97
+ # Load video-level labels
98
+ with open(os.path.join(self.root, "FileList.csv")) as f:
99
+ data = pandas.read_csv(f)
100
+ data["Split"].map(lambda x: x.upper())
101
+
102
+ if self.split != "ALL":
103
+ data = data[data["Split"] == self.split]
104
+
105
+ self.header = data.columns.tolist()
106
+ self.fnames = data["FileName"].tolist()
107
+ self.fnames = [fn + ".avi" for fn in self.fnames if os.path.splitext(fn)[1] == ""] # Assume avi if no suffix
108
+ self.outcome = data.values.tolist()
109
+
110
+ # Check that files are present
111
+ missing = set(self.fnames) - set(os.listdir(os.path.join(self.root, "Videos")))
112
+ if len(missing) != 0:
113
+ print("{} videos could not be found in {}:".format(len(missing), os.path.join(self.root, "Videos")))
114
+ for f in sorted(missing):
115
+ print("\t", f)
116
+ raise FileNotFoundError(os.path.join(self.root, "Videos", sorted(missing)[0]))
117
+
118
+ # Load traces
119
+ self.frames = collections.defaultdict(list)
120
+ self.trace = collections.defaultdict(_defaultdict_of_lists)
121
+
122
+ with open(os.path.join(self.root, "VolumeTracings.csv")) as f:
123
+ header = f.readline().strip().split(",")
124
+ assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]
125
+
126
+ for line in f:
127
+ filename, x1, y1, x2, y2, frame = line.strip().split(',')
128
+ x1 = float(x1)
129
+ y1 = float(y1)
130
+ x2 = float(x2)
131
+ y2 = float(y2)
132
+ frame = int(frame)
133
+ if frame not in self.trace[filename]:
134
+ self.frames[filename].append(frame)
135
+ self.trace[filename][frame].append((x1, y1, x2, y2))
136
+ for filename in self.frames:
137
+ for frame in self.frames[filename]:
138
+ self.trace[filename][frame] = np.array(self.trace[filename][frame])
139
+
140
+ # A small number of videos are missing traces; remove these videos
141
+ keep = [len(self.frames[f]) >= 2 for f in self.fnames]
142
+ self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
143
+ self.outcome = [f for (f, k) in zip(self.outcome, keep) if k]
144
+
145
+ def __getitem__(self, index):
146
+ # Find filename of video
147
+ if self.split == "EXTERNAL_TEST":
148
+ video = os.path.join(self.external_test_location, self.fnames[index])
149
+ elif self.split == "CLINICAL_TEST":
150
+ video = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index])
151
+ else:
152
+ video = os.path.join(self.root, "Videos", self.fnames[index])
153
+
154
+ # Load video into np.array
155
+ video = echonet.utils.loadvideo(video).astype(np.float32)
156
+
157
+ # Add simulated noise (black out random pixels)
158
+ # 0 represents black at this point (video has not been normalized yet)
159
+ if self.noise is not None:
160
+ n = video.shape[1] * video.shape[2] * video.shape[3]
161
+ ind = np.random.choice(n, round(self.noise * n), replace=False)
162
+ f = ind % video.shape[1]
163
+ ind //= video.shape[1]
164
+ i = ind % video.shape[2]
165
+ ind //= video.shape[2]
166
+ j = ind
167
+ video[:, f, i, j] = 0
168
+
169
+ # Apply normalization
170
+ if isinstance(self.mean, (float, int)):
171
+ video -= self.mean
172
+ else:
173
+ video -= self.mean.reshape(3, 1, 1, 1)
174
+
175
+ if isinstance(self.std, (float, int)):
176
+ video /= self.std
177
+ else:
178
+ video /= self.std.reshape(3, 1, 1, 1)
179
+
180
+ # Set number of frames
181
+ c, f, h, w = video.shape
182
+ if self.length is None:
183
+ # Take as many frames as possible
184
+ length = f // self.period
185
+ else:
186
+ # Take specified number of frames
187
+ length = self.length
188
+
189
+ if self.max_length is not None:
190
+ # Shorten videos to max_length
191
+ length = min(length, self.max_length)
192
+
193
+ if f < length * self.period:
194
+ # Pad video with frames filled with zeros if too short
195
+ # 0 represents the mean color (dark grey), since this is after normalization
196
+ video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
197
+ c, f, h, w = video.shape # pylint: disable=E0633
198
+
199
+ if self.clips == "all":
200
+ # Take all possible clips of desired length
201
+ start = np.arange(f - (length - 1) * self.period)
202
+ else:
203
+ # Take random clips from video
204
+ start = np.random.choice(f - (length - 1) * self.period, self.clips)
205
+
206
+ # Gather targets
207
+ target = []
208
+ for t in self.target_type:
209
+ key = self.fnames[index]
210
+ if t == "Filename":
211
+ target.append(self.fnames[index])
212
+ elif t == "LargeIndex":
213
+ # Traces are sorted by cross-sectional area
214
+ # Largest (diastolic) frame is last
215
+ target.append(np.int(self.frames[key][-1]))
216
+ elif t == "SmallIndex":
217
+ # Largest (diastolic) frame is first
218
+ target.append(np.int(self.frames[key][0]))
219
+ elif t == "LargeFrame":
220
+ target.append(video[:, self.frames[key][-1], :, :])
221
+ elif t == "SmallFrame":
222
+ target.append(video[:, self.frames[key][0], :, :])
223
+ elif t in ["LargeTrace", "SmallTrace"]:
224
+ if t == "LargeTrace":
225
+ t = self.trace[key][self.frames[key][-1]]
226
+ else:
227
+ t = self.trace[key][self.frames[key][0]]
228
+ x1, y1, x2, y2 = t[:, 0], t[:, 1], t[:, 2], t[:, 3]
229
+ x = np.concatenate((x1[1:], np.flip(x2[1:])))
230
+ y = np.concatenate((y1[1:], np.flip(y2[1:])))
231
+
232
+ r, c = skimage.draw.polygon(np.rint(y).astype(np.int), np.rint(x).astype(np.int), (video.shape[2], video.shape[3]))
233
+ mask = np.zeros((video.shape[2], video.shape[3]), np.float32)
234
+ mask[r, c] = 1
235
+ target.append(mask)
236
+ else:
237
+ if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST":
238
+ target.append(np.float32(0))
239
+ else:
240
+ target.append(np.float32(self.outcome[index][self.header.index(t)]))
241
+
242
+ if target != []:
243
+ target = tuple(target) if len(target) > 1 else target[0]
244
+ if self.target_transform is not None:
245
+ target = self.target_transform(target)
246
+
247
+ # Select clips from video
248
+ video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
249
+ if self.clips == 1:
250
+ video = video[0]
251
+ else:
252
+ video = np.stack(video)
253
+
254
+ if self.pad is not None:
255
+ # Add padding of zeros (mean color of videos)
256
+ # Crop of original size is taken out
257
+ # (Used as augmentation)
258
+ c, l, h, w = video.shape
259
+ temp = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype)
260
+ temp[:, :, self.pad:-self.pad, self.pad:-self.pad] = video # pylint: disable=E1130
261
+ i, j = np.random.randint(0, 2 * self.pad, 2)
262
+ video = temp[:, :, i:(i + h), j:(j + w)]
263
+
264
+ return video, target
265
+
266
+ def __len__(self):
267
+ return len(self.fnames)
268
+
269
+ def extra_repr(self) -> str:
270
+ """Additional information to add at end of __repr__."""
271
+ lines = ["Target type: {target_type}", "Split: {split}"]
272
+ return '\n'.join(lines).format(**self.__dict__)
273
+
274
+
275
+ def _defaultdict_of_lists():
276
+ """Returns a defaultdict of lists.
277
+
278
+ This is used to avoid issues with Windows (if this function is anonymous,
279
+ the Echo dataset cannot be used in a dataloader).
280
+ """
281
+
282
+ return collections.defaultdict(list)
echonet/utils/__init__.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for videos, plotting and computing performance metrics."""
2
+
3
+ import os
4
+ import typing
5
+
6
+ import cv2 # pytype: disable=attribute-error
7
+ import matplotlib
8
+ import numpy as np
9
+ import torch
10
+ import tqdm
11
+
12
+ from . import video
13
+ from . import segmentation
14
+
15
+
16
+ def loadvideo(filename: str) -> np.ndarray:
17
+ """Loads a video from a file.
18
+
19
+ Args:
20
+ filename (str): filename of video
21
+
22
+ Returns:
23
+ A np.ndarray with dimensions (channels=3, frames, height, width). The
24
+ values will be uint8's ranging from 0 to 255.
25
+
26
+ Raises:
27
+ FileNotFoundError: Could not find `filename`
28
+ ValueError: An error occurred while reading the video
29
+ """
30
+
31
+ if not os.path.exists(filename):
32
+ raise FileNotFoundError(filename)
33
+ capture = cv2.VideoCapture(filename)
34
+
35
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
36
+ frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
37
+ frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
38
+
39
+ v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
40
+
41
+ for count in range(frame_count):
42
+ ret, frame = capture.read()
43
+ if not ret:
44
+ raise ValueError("Failed to load frame #{} of {}.".format(count, filename))
45
+
46
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
47
+ v[count, :, :] = frame
48
+
49
+ v = v.transpose((3, 0, 1, 2))
50
+
51
+ return v
52
+
53
+
54
+ def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
55
+ """Saves a video to a file.
56
+
57
+ Args:
58
+ filename (str): filename of video
59
+ array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
60
+ fps (float or int): frames per second
61
+
62
+ Returns:
63
+ None
64
+ """
65
+
66
+ c, _, height, width = array.shape
67
+
68
+ if c != 3:
69
+ raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))
70
+ fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
71
+ out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
72
+
73
+ for frame in array.transpose((1, 2, 3, 0)):
74
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
75
+ out.write(frame)
76
+
77
+
78
+ def get_mean_and_std(dataset: torch.utils.data.Dataset,
79
+ samples: int = 128,
80
+ batch_size: int = 8,
81
+ num_workers: int = 4):
82
+ """Computes mean and std from samples from a Pytorch dataset.
83
+
84
+ Args:
85
+ dataset (torch.utils.data.Dataset): A Pytorch dataset.
86
+ ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
87
+ should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
88
+ samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
89
+ standard deviation are computed over all elements.
90
+ Defaults to 128.
91
+ batch_size (int, optional): how many samples per batch to load
92
+ Defaults to 8.
93
+ num_workers (int, optional): how many subprocesses to use for data
94
+ loading. If 0, the data will be loaded in the main process.
95
+ Defaults to 4.
96
+
97
+ Returns:
98
+ A tuple of the mean and standard deviation. Both are represented as np.array's of dimension (channels,).
99
+ """
100
+
101
+ if samples is not None and len(dataset) > samples:
102
+ indices = np.random.choice(len(dataset), samples, replace=False)
103
+ dataset = torch.utils.data.Subset(dataset, indices)
104
+ dataloader = torch.utils.data.DataLoader(
105
+ dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
106
+
107
+ n = 0 # number of elements taken (should be equal to samples by end of for loop)
108
+ s1 = 0. # sum of elements along channels (ends up as np.array of dimension (channels,))
109
+ s2 = 0. # sum of squares of elements along channels (ends up as np.array of dimension (channels,))
110
+ for (x, *_) in tqdm.tqdm(dataloader):
111
+ x = x.transpose(0, 1).contiguous().view(3, -1)
112
+ n += x.shape[1]
113
+ s1 += torch.sum(x, dim=1).numpy()
114
+ s2 += torch.sum(x ** 2, dim=1).numpy()
115
+ mean = s1 / n # type: np.ndarray
116
+ std = np.sqrt(s2 / n - mean ** 2) # type: np.ndarray
117
+
118
+ mean = mean.astype(np.float32)
119
+ std = std.astype(np.float32)
120
+
121
+ return mean, std
122
+
123
+
124
+ def bootstrap(a, b, func, samples=10000):
125
+ """Computes a bootstrapped confidence intervals for ``func(a, b)''.
126
+
127
+ Args:
128
+ a (array_like): first argument to `func`.
129
+ b (array_like): second argument to `func`.
130
+ func (callable): Function to compute confidence intervals for.
131
+ ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
132
+ should be a ``torch.Tensor'' of dimensions (channels=3, frames, height, width)
133
+ samples (int, optional): Number of samples to compute.
134
+ Defaults to 10000.
135
+
136
+ Returns:
137
+ A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).
138
+ """
139
+ a = np.array(a)
140
+ b = np.array(b)
141
+
142
+ bootstraps = []
143
+ for _ in range(samples):
144
+ ind = np.random.choice(len(a), len(a))
145
+ bootstraps.append(func(a[ind], b[ind]))
146
+ bootstraps = sorted(bootstraps)
147
+
148
+ return func(a, b), bootstraps[round(0.05 * len(bootstraps))], bootstraps[round(0.95 * len(bootstraps))]
149
+
150
+
151
+ def latexify():
152
+ """Sets matplotlib params to appear more like LaTeX.
153
+
154
+ Based on https://nipunbatra.github.io/blog/2014/latexify.html
155
+ """
156
+ params = {'backend': 'pdf',
157
+ 'axes.titlesize': 8,
158
+ 'axes.labelsize': 8,
159
+ 'font.size': 8,
160
+ 'legend.fontsize': 8,
161
+ 'xtick.labelsize': 8,
162
+ 'ytick.labelsize': 8,
163
+ 'font.family': 'DejaVu Serif',
164
+ 'font.serif': 'Computer Modern',
165
+ }
166
+ matplotlib.rcParams.update(params)
167
+
168
+
169
+ def dice_similarity_coefficient(inter, union):
170
+ """Computes the dice similarity coefficient.
171
+
172
+ Args:
173
+ inter (iterable): iterable of the intersections
174
+ union (iterable): iterable of the unions
175
+ """
176
+ return 2 * sum(inter) / (sum(union) + sum(inter))
177
+
178
+
179
+ __all__ = ["video", "segmentation", "loadvideo", "savevideo", "get_mean_and_std", "bootstrap", "latexify", "dice_similarity_coefficient"]
echonet/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (9.56 kB). View file
 
echonet/utils/__pycache__/segmentation.cpython-311.pyc ADDED
Binary file (39.2 kB). View file
 
echonet/utils/__pycache__/video.cpython-311.pyc ADDED
Binary file (27 kB). View file