"""Functions for training and running segmentation.""" | |
import math | |
import os | |
import time | |
import click | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy.signal | |
import skimage.draw | |
import torch | |
import torchvision | |
import tqdm | |
import echonet | |


def run(
    data_dir=None,
    output=None,
    model_name="deeplabv3_resnet50",
    pretrained=False,
    weights=None,
    run_test=False,
    save_video=False,
    num_epochs=50,
    lr=1e-5,
    weight_decay=1e-5,
    lr_step_period=None,
    num_train_patients=None,
    num_workers=4,
    batch_size=20,
    device=None,
    seed=0,
):
"""Trains/tests segmentation model. | |
Args: | |
data_dir (str, optional): Directory containing dataset. Defaults to | |
`echonet.config.DATA_DIR`. | |
output (str, optional): Directory to place outputs. Defaults to | |
output/segmentation/<model_name>_<pretrained/random>/. | |
model_name (str, optional): Name of segmentation model. One of ``deeplabv3_resnet50'', | |
``deeplabv3_resnet101'', ``fcn_resnet50'', or ``fcn_resnet101'' | |
(options are torchvision.models.segmentation.<model_name>) | |
Defaults to ``deeplabv3_resnet50''. | |
pretrained (bool, optional): Whether to use pretrained weights for model | |
Defaults to False. | |
weights (str, optional): Path to checkpoint containing weights to | |
initialize model. Defaults to None. | |
run_test (bool, optional): Whether or not to run on test. | |
Defaults to False. | |
save_video (bool, optional): Whether to save videos with segmentations. | |
Defaults to False. | |
num_epochs (int, optional): Number of epochs during training | |
Defaults to 50. | |
lr (float, optional): Learning rate for SGD | |
Defaults to 1e-5. | |
weight_decay (float, optional): Weight decay for SGD | |
Defaults to 0. | |
lr_step_period (int or None, optional): Period of learning rate decay | |
(learning rate is decayed by a multiplicative factor of 0.1) | |
Defaults to math.inf (never decay learning rate). | |
num_train_patients (int or None, optional): Number of training patients | |
for ablations. Defaults to all patients. | |
num_workers (int, optional): Number of subprocesses to use for data | |
loading. If 0, the data will be loaded in the main process. | |
Defaults to 4. | |
device (str or None, optional): Name of device to run on. Options from | |
https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device | |
Defaults to ``cuda'' if available, and ``cpu'' otherwise. | |
batch_size (int, optional): Number of samples to load per batch | |
Defaults to 20. | |
seed (int, optional): Seed for random number generator. Defaults to 0. | |
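
    Example:
        A minimal usage sketch (assumes the EchoNet-Dynamic dataset is
        available under ``echonet.config.DATA_DIR''; the checkpoint path
        shown is hypothetical)::

            >>> run(num_epochs=1)  # quick training smoke test
            >>> run(weights="output/segmentation/deeplabv3_resnet50_random/best.pt",
            ...     num_epochs=0, run_test=True)  # evaluate a saved model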
""" | |
    # Seed RNGs
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Set default output directory
    if output is None:
        output = os.path.join("output", "segmentation", "{}_{}".format(model_name, "pretrained" if pretrained else "random"))
    os.makedirs(output, exist_ok=True)

    # Set device for computations
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up model
    model = torchvision.models.segmentation.__dict__[model_name](pretrained=pretrained, aux_loss=False)
    model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)  # change number of outputs to 1
    if device.type == "cuda":
        model = torch.nn.DataParallel(model)
    model.to(device)

    if weights is not None:
        checkpoint = torch.load(weights)
        model.load_state_dict(checkpoint['state_dict'])

    # Set up optimizer
    optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if lr_step_period is None:
        lr_step_period = math.inf
    scheduler = torch.optim.lr_scheduler.StepLR(optim, lr_step_period)
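    # (StepLR multiplies the learning rate by its default gamma of 0.1 every
    # lr_step_period epochs; with the math.inf default the step never fires,
    # so the learning rate stays constant for the whole run.)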

    # Compute mean and std
    mean, std = echonet.utils.get_mean_and_std(echonet.datasets.Echo(root=data_dir, split="train"))
    tasks = ["LargeFrame", "SmallFrame", "LargeTrace", "SmallTrace"]
    kwargs = {"target_type": tasks,
              "mean": mean,
              "std": std
              }

    # Set up datasets and dataloaders
    dataset = {}
    dataset["train"] = echonet.datasets.Echo(root=data_dir, split="train", **kwargs)
    if num_train_patients is not None and len(dataset["train"]) > num_train_patients:
        # Subsample patients (used for ablation experiment)
        indices = np.random.choice(len(dataset["train"]), num_train_patients, replace=False)
        dataset["train"] = torch.utils.data.Subset(dataset["train"], indices)
    dataset["val"] = echonet.datasets.Echo(root=data_dir, split="val", **kwargs)

    # Run training and testing loops
    with open(os.path.join(output, "log.csv"), "a") as f:
        epoch_resume = 0
        bestLoss = float("inf")
        try:
            # Attempt to load checkpoint
            checkpoint = torch.load(os.path.join(output, "checkpoint.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            optim.load_state_dict(checkpoint['opt_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_dict'])
            epoch_resume = checkpoint["epoch"] + 1
            bestLoss = checkpoint["best_loss"]
            f.write("Resuming from epoch {}\n".format(epoch_resume))
        except FileNotFoundError:
            f.write("Starting run from scratch\n")

        for epoch in range(epoch_resume, num_epochs):
            print("Epoch #{}".format(epoch), flush=True)
            for phase in ['train', 'val']:
                start_time = time.time()
                for i in range(torch.cuda.device_count()):
                    torch.cuda.reset_peak_memory_stats(i)

                ds = dataset[phase]
                dataloader = torch.utils.data.DataLoader(
                    ds, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=(device.type == "cuda"), drop_last=(phase == "train"))

                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, phase == "train", optim, device)
                overall_dice = 2 * (large_inter.sum() + small_inter.sum()) / (large_union.sum() + large_inter.sum() + small_union.sum() + small_inter.sum())
                large_dice = 2 * large_inter.sum() / (large_union.sum() + large_inter.sum())
                small_dice = 2 * small_inter.sum() / (small_union.sum() + small_inter.sum())
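                # (Since |A| + |B| = |A ∪ B| + |A ∩ B|, the Dice score
                # 2|A ∩ B| / (|A| + |B|) equals 2 * inter / (union + inter),
                # which is what the three expressions above compute.)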
f.write("{},{},{},{},{},{},{},{},{},{},{}\n".format(epoch, | |
phase, | |
loss, | |
overall_dice, | |
large_dice, | |
small_dice, | |
time.time() - start_time, | |
large_inter.size, | |
sum(torch.cuda.max_memory_allocated() for i in range(torch.cuda.device_count())), | |
sum(torch.cuda.max_memory_reserved() for i in range(torch.cuda.device_count())), | |
batch_size)) | |
f.flush() | |
scheduler.step() | |

            # Save checkpoint
            save = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'best_loss': bestLoss,
                'loss': loss,
                'opt_dict': optim.state_dict(),
                'scheduler_dict': scheduler.state_dict(),
            }
            torch.save(save, os.path.join(output, "checkpoint.pt"))
            if loss < bestLoss:
                torch.save(save, os.path.join(output, "best.pt"))
                bestLoss = loss

        # Load best weights
        if num_epochs != 0:
            checkpoint = torch.load(os.path.join(output, "best.pt"))
            model.load_state_dict(checkpoint['state_dict'])
            f.write("Best validation loss {} from epoch {}\n".format(checkpoint["loss"], checkpoint["epoch"]))

        if run_test:
            # Run on validation and test
            for split in ["val", "test"]:
                dataset = echonet.datasets.Echo(root=data_dir, split=split, **kwargs)
                dataloader = torch.utils.data.DataLoader(dataset,
                                                         batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=(device.type == "cuda"))
                loss, large_inter, large_union, small_inter, small_union = echonet.utils.segmentation.run_epoch(model, dataloader, False, None, device)

                overall_dice = 2 * (large_inter + small_inter) / (large_union + large_inter + small_union + small_inter)
                large_dice = 2 * large_inter / (large_union + large_inter)
                small_dice = 2 * small_inter / (small_union + small_inter)

                with open(os.path.join(output, "{}_dice.csv".format(split)), "w") as g:
                    g.write("Filename, Overall, Large, Small\n")
                    for (filename, overall, large, small) in zip(dataset.fnames, overall_dice, large_dice, small_dice):
                        g.write("{},{},{},{}\n".format(filename, overall, large, small))

                f.write("{} dice (overall): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(np.concatenate((large_inter, small_inter)), np.concatenate((large_union, small_union)), echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (large): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(large_inter, large_union, echonet.utils.dice_similarity_coefficient)))
                f.write("{} dice (small): {:.4f} ({:.4f} - {:.4f})\n".format(split, *echonet.utils.bootstrap(small_inter, small_union, echonet.utils.dice_similarity_coefficient)))
                f.flush()

    # Saving videos with segmentations
    dataset = echonet.datasets.Echo(root=data_dir, split="test",
                                    target_type=["Filename", "LargeIndex", "SmallIndex"],  # Need filename for saving, and human-selected frames to annotate
                                    mean=mean, std=std,  # Normalization
                                    length=None, max_length=None, period=1  # Take all frames
                                    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=num_workers, shuffle=False, pin_memory=False, collate_fn=_video_collate_fn)
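    # (_video_collate_fn, defined at the bottom of this module, concatenates the
    # variable-length test videos along the frame dimension and returns each
    # video's length, so the predictions can be split back out per video below.)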

    # Save videos with segmentation
    if save_video and not all(os.path.isfile(os.path.join(output, "videos", f)) for f in dataloader.dataset.fnames):
        # Only run if missing videos

        model.eval()

        os.makedirs(os.path.join(output, "videos"), exist_ok=True)
        os.makedirs(os.path.join(output, "size"), exist_ok=True)
        echonet.utils.latexify()

        with torch.no_grad():
            with open(os.path.join(output, "size.csv"), "w") as g:
                g.write("Filename,Frame,Size,HumanLarge,HumanSmall,ComputerSmall\n")
                for (x, (filenames, large_index, small_index), length) in tqdm.tqdm(dataloader):
                    # Run segmentation model on blocks of frames one-by-one
                    # The whole concatenated video may be too long to run together
                    y = np.concatenate([model(x[i:(i + batch_size), :, :, :].to(device))["out"].detach().cpu().numpy() for i in range(0, x.shape[0], batch_size)])

                    start = 0
                    x = x.numpy()
                    for (i, (filename, offset)) in enumerate(zip(filenames, length)):
                        # Extract one video and segmentation predictions
                        video = x[start:(start + offset), ...]
                        logit = y[start:(start + offset), 0, :, :]

                        # Un-normalize video
                        video *= std.reshape(1, 3, 1, 1)
                        video += mean.reshape(1, 3, 1, 1)

                        # Get frames, channels, height, and width
                        f, c, h, w = video.shape  # pylint: disable=W0612
                        assert c == 3

                        # Put two copies of the video side by side
                        video = np.concatenate((video, video), 3)

                        # If a pixel is in the segmentation, saturate blue channel
                        # Leave alone otherwise
                        video[:, 0, :, w:] = np.maximum(255. * (logit > 0), video[:, 0, :, w:])  # pylint: disable=E1111

                        # Add blank canvas under pair of videos
                        video = np.concatenate((video, np.zeros_like(video)), 2)

                        # Compute size of segmentation per frame
                        size = (logit > 0).sum((1, 2))

                        # Identify systole frames with peak detection
                        trim_min = sorted(size)[round(len(size) * 0.05)]
                        trim_max = sorted(size)[round(len(size) * 0.95)]
                        trim_range = trim_max - trim_min
                        systole = set(scipy.signal.find_peaks(-size, distance=20, prominence=(0.50 * trim_range))[0])
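                        # (find_peaks detects local maxima, so the size signal is negated to
                        # locate its minima: end-systole is where the ventricle area is
                        # smallest. The prominence threshold, half the trimmed size range,
                        # suppresses shallow fluctuations.)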

                        # Write sizes and frames to file
                        for (frame, s) in enumerate(size):
                            g.write("{},{},{},{},{},{}\n".format(filename, frame, s, 1 if frame == large_index[i] else 0, 1 if frame == small_index[i] else 0, 1 if frame in systole else 0))

                        # Plot sizes
                        fig = plt.figure(figsize=(size.shape[0] / 50 * 1.5, 3))
                        plt.scatter(np.arange(size.shape[0]) / 50, size, s=1)
                        ylim = plt.ylim()
                        for s in systole:
                            plt.plot(np.array([s, s]) / 50, ylim, linewidth=1)
                        plt.ylim(ylim)
                        plt.title(os.path.splitext(filename)[0])
                        plt.xlabel("Seconds")
                        plt.ylabel("Size (pixels)")
                        plt.tight_layout()
                        plt.savefig(os.path.join(output, "size", os.path.splitext(filename)[0] + ".pdf"))
                        plt.close(fig)

                        # Normalize size to [0, 1]
                        size -= size.min()
                        size = size / size.max()
                        size = 1 - size
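                        # (After flipping, 0 corresponds to the largest area and 1 to the
                        # smallest, so systolic frames plot lowest on the canvas below.)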

                        # Iterate the frames in this video
                        for (f, s) in enumerate(size):
                            # On all frames, mark a pixel for the size of the frame
                            video[:, :, int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))] = 255.

                            if f in systole:
                                # If frame is computer-selected systole, mark with a line
                                video[:, :, 115:224, int(round(f / len(size) * 200 + 10))] = 255.

                            def dash(start, stop, on=10, off=10):
                                buf = []
                                x = start
                                while x < stop:
                                    buf.extend(range(x, x + on))
                                    x += on
                                    x += off
                                buf = np.array(buf)
                                buf = buf[buf < stop]
                                return buf
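                            # (dash() returns the row indices of a dashed vertical line:
                            # dash(115, 224) gives rows 115-124, 135-144, ..., alternating
                            # 10 rows on and 10 rows off, clipped below 224.)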
                            d = dash(115, 224)

                            if f == large_index[i]:
                                # If frame is human-selected diastole, mark with green dashed line on all frames
                                video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 225, 0]).reshape((1, 3, 1))
                            if f == small_index[i]:
                                # If frame is human-selected systole, mark with red dashed line on all frames
                                video[:, :, d, int(round(f / len(size) * 200 + 10))] = np.array([0, 0, 225]).reshape((1, 3, 1))

                            # Get pixels for a circle centered on the pixel
                            r, c = skimage.draw.disk((int(round(115 + 100 * s)), int(round(f / len(size) * 200 + 10))), 4.1)

                            # On the frame that's being shown, put a circle over the pixel
                            video[f, :, r, c] = 255.

                        # Rearrange dimensions and save
                        video = video.transpose(1, 0, 2, 3)
                        video = video.astype(np.uint8)
                        echonet.utils.savevideo(os.path.join(output, "videos", filename), video, 50)

                        # Move to next video
                        start += offset


def run_epoch(model, dataloader, train, optim, device):
    """Run one epoch of training/evaluation for segmentation.

    Args:
        model (torch.nn.Module): Model to train/evaluate.
        dataloader (torch.utils.data.DataLoader): Dataloader for dataset.
        train (bool): Whether or not to train model.
        optim (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to run on.

    Returns:
        Tuple of (average loss per pixel, and per-video intersection and union
        pixel counts, as numpy arrays, for the large (diastolic) and small
        (systolic) traces).
    """
    total = 0.
    n = 0

    pos = 0
    neg = 0
    pos_pix = 0
    neg_pix = 0

    model.train(train)

    large_inter = 0
    large_union = 0
    small_inter = 0
    small_union = 0
    large_inter_list = []
    large_union_list = []
    small_inter_list = []
    small_union_list = []

    with torch.set_grad_enabled(train):
        with tqdm.tqdm(total=len(dataloader)) as pbar:
            for (_, (large_frame, small_frame, large_trace, small_trace)) in dataloader:
                # Count total number of pixels in/out of human segmentation
                pos += (large_trace == 1).sum().item()
                pos += (small_trace == 1).sum().item()
                neg += (large_trace == 0).sum().item()
                neg += (small_trace == 0).sum().item()

                # Count number of pixels in/out of human segmentation at each pixel location
                pos_pix += (large_trace == 1).sum(0).to("cpu").detach().numpy()
                pos_pix += (small_trace == 1).sum(0).to("cpu").detach().numpy()
                neg_pix += (large_trace == 0).sum(0).to("cpu").detach().numpy()
                neg_pix += (small_trace == 0).sum(0).to("cpu").detach().numpy()

                # Run prediction for diastolic frames and compute loss
                large_frame = large_frame.to(device)
                large_trace = large_trace.to(device)
                y_large = model(large_frame)["out"]
                loss_large = torch.nn.functional.binary_cross_entropy_with_logits(y_large[:, 0, :, :], large_trace, reduction="sum")

                # Compute pixel intersection and union between human and computer segmentations
                large_inter += np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_union += np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                large_inter_list.extend(np.logical_and(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                large_union_list.extend(np.logical_or(y_large[:, 0, :, :].detach().cpu().numpy() > 0., large_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Run prediction for systolic frames and compute loss
                small_frame = small_frame.to(device)
                small_trace = small_trace.to(device)
                y_small = model(small_frame)["out"]
                loss_small = torch.nn.functional.binary_cross_entropy_with_logits(y_small[:, 0, :, :], small_trace, reduction="sum")

                # Compute pixel intersection and union between human and computer segmentations
                small_inter += np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_union += np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum()
                small_inter_list.extend(np.logical_and(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))
                small_union_list.extend(np.logical_or(y_small[:, 0, :, :].detach().cpu().numpy() > 0., small_trace[:, :, :].detach().cpu().numpy() > 0.).sum((1, 2)))

                # Take gradient step if training
                loss = (loss_large + loss_small) / 2
                if train:
                    optim.zero_grad()
                    loss.backward()
                    optim.step()

                # Accumulate losses and compute baselines
                total += loss.item()
                n += large_trace.size(0)
                p = pos / (pos + neg)
                p_pix = (pos_pix + 1) / (pos_pix + neg_pix + 2)

                # Show info on progress bar: running loss per pixel (batch loss
                # per pixel) / global and per-pixel entropy baselines, then
                # running Dice for the large and small traces
                pbar.set_postfix_str("{:.4f} ({:.4f}) / {:.4f} {:.4f}, {:.4f}, {:.4f}".format(total / n / 112 / 112, loss.item() / large_trace.size(0) / 112 / 112, -p * math.log(p) - (1 - p) * math.log(1 - p), (-p_pix * np.log(p_pix) - (1 - p_pix) * np.log(1 - p_pix)).mean(), 2 * large_inter / (large_union + large_inter), 2 * small_inter / (small_union + small_inter)))
                pbar.update()

    large_inter_list = np.array(large_inter_list)
    large_union_list = np.array(large_union_list)
    small_inter_list = np.array(small_inter_list)
    small_union_list = np.array(small_union_list)

    return (total / n / 112 / 112,
            large_inter_list,
            large_union_list,
            small_inter_list,
            small_union_list,
            )


def _video_collate_fn(x):
    """Collate function for Pytorch dataloader to merge multiple videos.

    This function should be used in a dataloader for a dataset that returns
    a video as the first element, along with some (non-zero) tuple of
    targets. Then, the input x is a list of tuples:
      - x[i][0] is the i-th video in the batch
      - x[i][1] are the targets for the i-th video

    This function returns a 3-tuple:
      - The first element is the videos concatenated along the frames
        dimension. This is done so that videos of different lengths can be
        processed together (tensors cannot be "jagged", so we cannot have
        a dimension for video, and another for frames).
      - The second element contains the targets with no modification.
      - The third element is a list of the lengths of the videos in frames.
    """
    video, target = zip(*x)  # Extract the videos and targets

    # ``video'' is a tuple of length ``batch_size''
    #   Each element has shape (channels=3, frames, height, width)
    #   height and width are expected to be the same across videos, but
    #   frames can be different.

    # ``target'' is also a tuple of length ``batch_size''
    #   Each element is a tuple of the targets for the item.

    i = list(map(lambda t: t.shape[1], video))  # Extract lengths of videos in frames

    # This concatenates the videos along the frames dimension (basically
    # playing the videos one after another). The frames dimension is then
    # moved to be first.
    # Resulting shape is (total frames, channels=3, height, width)
    video = torch.as_tensor(np.swapaxes(np.concatenate(video, 1), 0, 1))

    # Swap dimensions (approximately a transpose)
    # Before: target[i][j] is the j-th target of element i
    # After:  target[i][j] is the i-th target of element j
    target = zip(*target)

    return video, target, i