Franny Dean
files
dde56f0
raw
history blame
6.24 kB
"""Utility functions for videos, plotting and computing performance metrics."""
import os
import typing
import cv2 # pytype: disable=attribute-error
import matplotlib
import numpy as np
import torch
import tqdm
from . import video
from . import segmentation
def loadvideo(filename: str) -> np.ndarray:
    """Loads a video from a file.

    Args:
        filename (str): filename of video

    Returns:
        A np.ndarray with dimensions (channels=3, frames, height, width). The
        values will be uint8's ranging from 0 to 255.

    Raises:
        FileNotFoundError: Could not find `filename`
        ValueError: The video could not be opened, or an error occurred
            while reading one of its frames
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)

    capture = cv2.VideoCapture(filename)
    try:
        # An existing but undecodable file would otherwise report 0 frames
        # and silently produce an empty array.
        if not capture.isOpened():
            raise ValueError("Failed to open {}.".format(filename))

        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Buffer is (frames, height, width, channels); transposed on return.
        v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
        for count in range(frame_count):
            ret, frame = capture.read()
            if not ret:
                raise ValueError("Failed to load frame #{} of {}.".format(count, filename))
            # OpenCV decodes to BGR; callers expect RGB.
            v[count] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        # Release the native decoder handle on success and on every error path.
        capture.release()

    return v.transpose((3, 0, 1, 2))
def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
    """Saves a video to a file.

    The video is encoded with the MJPG fourcc.

    Args:
        filename (str): filename of video
        array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
        fps (float or int): frames per second

    Returns:
        None

    Raises:
        ValueError: `array` does not have 3 channels along its first axis
    """
    c, _, height, width = array.shape
    if c != 3:
        raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))

    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    # OpenCV takes the frame size as (width, height).
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
    try:
        # Iterate frames as (height, width, channels) and convert RGB -> BGR,
        # which is the channel order OpenCV writers expect.
        for frame in array.transpose((1, 2, 3, 0)):
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Without release() the container may never be finalized/flushed.
        out.release()
def get_mean_and_std(dataset: torch.utils.data.Dataset,
                     samples: typing.Optional[int] = 128,
                     batch_size: int = 8,
                     num_workers: int = 4):
    """Computes per-channel mean and std from samples of a Pytorch dataset.

    Args:
        dataset (torch.utils.data.Dataset): A Pytorch dataset.
            ``dataset[i][0]`` is expected to be the i-th video in the dataset, which
            should be a ``torch.Tensor`` of dimensions (channels, frames, height, width)
            (channels is typically 3).
        samples (int or None, optional): Number of samples to take from dataset. If ``None``, mean and
            standard deviation are computed over all elements.
            Defaults to 128.
        batch_size (int, optional): how many samples per batch to load
            Defaults to 8.
        num_workers (int, optional): how many subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.

    Returns:
        A tuple of the mean and standard deviation. Both are represented as
        float32 np.array's of dimension (channels,).
    """
    if samples is not None and len(dataset) > samples:
        indices = np.random.choice(len(dataset), samples, replace=False)
        dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

    n = 0    # number of elements seen per channel
    s1 = 0.  # per-channel sum of elements (becomes np.ndarray of shape (channels,))
    s2 = 0.  # per-channel sum of squared elements (same shape as s1)
    for (x, *_) in tqdm.tqdm(dataloader):
        # (batch, channels, ...) -> (channels, batch * ...). Accumulate in
        # float64: float32 sums over millions of pixels lose precision, and
        # integer inputs (e.g. uint8) would overflow when squared.
        x = x.transpose(0, 1).reshape(x.shape[1], -1).to(torch.float64)
        n += x.shape[1]
        s1 += torch.sum(x, dim=1).numpy()
        s2 += torch.sum(x ** 2, dim=1).numpy()

    mean = s1 / n  # type: np.ndarray
    # E[x^2] - E[x]^2 can round to a tiny negative value for near-constant
    # channels; clamp so sqrt never yields NaN.
    std = np.sqrt(np.maximum(s2 / n - mean ** 2, 0.))  # type: np.ndarray

    mean = mean.astype(np.float32)
    std = std.astype(np.float32)
    return mean, std
def bootstrap(a, b, func, samples=10000):
    """Computes a bootstrapped confidence interval for ``func(a, b)``.

    Each bootstrap replicate draws ``len(a)`` indices with replacement and
    applies the same indices to both ``a`` and ``b`` (paired resampling).

    Args:
        a (array_like): first argument to `func`.
        b (array_like): second argument to `func`; indexed with the same
            resampled indices as `a`.
        func (callable): Function to compute confidence intervals for. Its
            return values must be mutually comparable (they are sorted).
        samples (int, optional): Number of samples to compute.
            Defaults to 10000.

    Returns:
        A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).
    """
    a = np.array(a)
    b = np.array(b)

    bootstraps = []
    for _ in range(samples):
        ind = np.random.choice(len(a), len(a))
        bootstraps.append(func(a[ind], b[ind]))
    bootstraps = sorted(bootstraps)

    # round() can land exactly on len(bootstraps) for small sample counts
    # (e.g. round(0.95 * 10) == 10), so clamp to the last valid index.
    last = len(bootstraps) - 1
    lower = bootstraps[min(round(0.05 * len(bootstraps)), last)]
    upper = bootstraps[min(round(0.95 * len(bootstraps)), last)]
    return func(a, b), lower, upper
def latexify():
    """Update matplotlib's global rcParams so figures look closer to LaTeX output.

    Selects the 'pdf' backend, sets every text size to 8, and picks a serif
    font family. Based on https://nipunbatra.github.io/blog/2014/latexify.html
    """
    eight_point_keys = (
        'axes.titlesize',
        'axes.labelsize',
        'font.size',
        'legend.fontsize',
        'xtick.labelsize',
        'ytick.labelsize',
    )
    settings = {'backend': 'pdf'}
    settings.update(dict.fromkeys(eight_point_keys, 8))
    settings['font.family'] = 'DejaVu Serif'
    # NOTE(review): font.family names a concrete font, so font.serif is
    # likely unused by matplotlib here — confirm before relying on it.
    settings['font.serif'] = 'Computer Modern'
    matplotlib.rcParams.update(settings)
def dice_similarity_coefficient(inter, union):
    """Computes the dice similarity coefficient.

    Dice = 2|A ∩ B| / (|A| + |B|), and |A| + |B| = |A ∪ B| + |A ∩ B|, so the
    coefficient can be computed from intersection and union sizes alone.

    Args:
        inter (iterable): iterable of the intersections
        union (iterable): iterable of the unions

    Returns:
        2 * sum(inter) / (sum(union) + sum(inter)).
    """
    # Consume each iterable exactly once: summing a generator twice would
    # yield 0 the second time and silently corrupt the result.
    total_inter = sum(inter)
    total_union = sum(union)
    return 2 * total_inter / (total_union + total_inter)
# Public API of this module: re-exported submodules followed by utilities.
__all__ = [
    "video",
    "segmentation",
    "loadvideo",
    "savevideo",
    "get_mean_and_std",
    "bootstrap",
    "latexify",
    "dice_similarity_coefficient",
]