Spaces:
Runtime error
Runtime error
"""Utility functions for videos, plotting and computing performance metrics.""" | |
import os | |
import typing | |
import cv2 # pytype: disable=attribute-error | |
import matplotlib | |
import numpy as np | |
import torch | |
import tqdm | |
from . import video | |
from . import segmentation | |
def loadvideo(filename: str) -> np.ndarray:
    """Loads a video from a file.

    Args:
        filename (str): filename of video

    Returns:
        A np.ndarray with dimensions (channels=3, frames, height, width). The
        values will be uint8's ranging from 0 to 255.

    Raises:
        FileNotFoundError: Could not find `filename`
        ValueError: An error occurred while reading the video
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)

    capture = cv2.VideoCapture(filename)
    try:
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Buffer laid out as (frames, height, width, channels) while decoding.
        v = np.zeros((frame_count, frame_height, frame_width, 3), np.uint8)
        for count in range(frame_count):
            ret, frame = capture.read()
            if not ret:
                raise ValueError("Failed to load frame #{} of {}.".format(count, filename))
            # OpenCV hands back BGR frames; store them as RGB.
            v[count, :, :] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        # Release the capture handle on every path, including the ValueError above.
        capture.release()

    # (frames, height, width, channels) -> (channels, frames, height, width)
    return v.transpose((3, 0, 1, 2))
def savevideo(filename: str, array: np.ndarray, fps: typing.Union[float, int] = 1):
    """Saves a video to a file.

    Frames are converted from RGB to BGR and written with the MJPG codec.

    Args:
        filename (str): filename of video
        array (np.ndarray): video of uint8's with shape (channels=3, frames, height, width)
        fps (float or int): frames per second

    Returns:
        None

    Raises:
        ValueError: `array` does not have exactly 3 channels along its first axis.
    """
    c, _, height, width = array.shape
    if c != 3:
        raise ValueError("savevideo expects array of shape (channels=3, frames, height, width), got shape ({})".format(", ".join(map(str, array.shape))))

    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    out = cv2.VideoWriter(filename, fourcc, fps, (width, height))
    try:
        # Iterate frame by frame as (height, width, channels).
        for frame in array.transpose((1, 2, 3, 0)):
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Close the writer explicitly instead of relying on garbage collection
        # (which can be delayed, e.g. while an exception traceback holds `out`).
        out.release()
def get_mean_and_std(dataset: torch.utils.data.Dataset,
                     samples: typing.Optional[int] = 128,
                     batch_size: int = 8,
                     num_workers: int = 4):
    """Computes mean and std from samples from a Pytorch dataset.

    Args:
        dataset (torch.utils.data.Dataset): A Pytorch dataset.
            ``dataset[i][0]'' is expected to be the i-th video in the dataset, which
            should be a ``torch.Tensor'' of dimensions (channels, frames, height, width)
        samples (int or None, optional): Number of samples to take from dataset. If ``None'', mean and
            standard deviation are computed over all elements.
            Defaults to 128.
        batch_size (int, optional): how many samples per batch to load
            Defaults to 8.
        num_workers (int, optional): how many subprocesses to use for data
            loading. If 0, the data will be loaded in the main process.
            Defaults to 4.

    Returns:
        A tuple of the mean and standard deviation. Both are represented as
        float32 np.array's of dimension (channels,).

    Raises:
        ValueError: the dataset yields no elements.
    """
    if samples is not None and len(dataset) > samples:
        indices = np.random.choice(len(dataset), samples, replace=False)
        dataset = torch.utils.data.Subset(dataset, indices)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

    n = 0  # number of elements seen per channel
    s1 = 0.  # per-channel sum of elements (ends up as np.array of dimension (channels,))
    s2 = 0.  # per-channel sum of squares (ends up as np.array of dimension (channels,))
    for (x, *_) in tqdm.tqdm(dataloader):
        channels = x.shape[1]  # batch is (batch, channels, ...)
        # Flatten to (channels, everything else) and accumulate in float64:
        # float32 sums over millions of pixels lose precision, and integer
        # inputs (e.g. uint8) would overflow when squared.
        x = x.transpose(0, 1).reshape(channels, -1).double()
        n += x.shape[1]
        s1 += torch.sum(x, dim=1).numpy()
        s2 += torch.sum(x ** 2, dim=1).numpy()

    if n == 0:
        raise ValueError("get_mean_and_std: dataset contains no elements")

    mean = s1 / n  # type: np.ndarray
    # E[x^2] - E[x]^2 may round to a tiny negative number; clamp so sqrt is not NaN.
    std = np.sqrt(np.maximum(s2 / n - mean ** 2, 0.))  # type: np.ndarray

    return mean.astype(np.float32), std.astype(np.float32)
def bootstrap(a, b, func, samples=10000):
    """Computes a bootstrapped confidence interval for ``func(a, b)''.

    `a` and `b` are resampled jointly with replacement: the same random
    indices (drawn from ``range(len(a))'') select elements of both. `func` is
    evaluated on each resample, and the sorted results give the interval.

    Args:
        a (array_like): first argument to `func`.
        b (array_like): second argument to `func`; indexed with the same
            resampled indices as `a`.
        func (callable): Function to compute confidence intervals for. Called as
            ``func(a_resampled, b_resampled)''; its results must be sortable.
        samples (int, optional): Number of bootstrap resamples to compute.
            Must be at least 1. Defaults to 10000.

    Returns:
        A tuple of (`func(a, b)`, estimated 5-th percentile, estimated 95-th percentile).

    Raises:
        ValueError: `samples` is less than 1.
    """
    if samples < 1:
        raise ValueError("bootstrap requires samples >= 1, got {}".format(samples))

    a = np.array(a)
    b = np.array(b)

    bootstraps = []
    for _ in range(samples):
        ind = np.random.choice(len(a), len(a))
        bootstraps.append(func(a[ind], b[ind]))
    bootstraps = sorted(bootstraps)

    # round() can land exactly on len(bootstraps) for small sample counts
    # (e.g. round(0.95 * 10) == 10), so clamp to the last valid index.
    last = len(bootstraps) - 1
    lower = min(round(0.05 * len(bootstraps)), last)
    upper = min(round(0.95 * len(bootstraps)), last)
    return func(a, b), bootstraps[lower], bootstraps[upper]
def latexify():
    """Set global matplotlib rcParams so figures look more like LaTeX output.

    Selects the 'pdf' backend, uses 8pt text for titles, axis labels, legends
    and tick labels, and picks a serif font ('DejaVu Serif' family,
    'Computer Modern' as the serif entry).
    Based on https://nipunbatra.github.io/blog/2014/latexify.html
    """
    text_size = 8
    sized_keys = (
        'axes.titlesize',
        'axes.labelsize',
        'font.size',
        'legend.fontsize',
        'xtick.labelsize',
        'ytick.labelsize',
    )

    # Insertion order matches the key order used when these params are applied.
    overrides = {'backend': 'pdf'}
    overrides.update({key: text_size for key in sized_keys})
    overrides['font.family'] = 'DejaVu Serif'
    overrides['font.serif'] = 'Computer Modern'

    matplotlib.rcParams.update(overrides)
def dice_similarity_coefficient(inter, union):
    """Computes the dice similarity coefficient.

    Returns ``2 * sum(inter) / (sum(union) + sum(inter))''. Because
    |A| + |B| = |A union B| + |A intersect B|, this equals
    2 |A intersect B| / (|A| + |B|).

    Args:
        inter (iterable): iterable of the intersections
        union (iterable): iterable of the unions

    Returns:
        The dice similarity coefficient. Its numeric type follows the summed
        values. If both sums are zero, the division follows that type's
        semantics (e.g. ZeroDivisionError for Python ints/floats).
    """
    # Sum each iterable exactly once so one-shot iterables (generators,
    # iterators) are not exhausted before being used a second time.
    total_inter = sum(inter)
    total_union = sum(union)
    return 2 * total_inter / (total_union + total_inter)
# Public API of this package: submodules first, then utility functions.
__all__ = [
    "video",
    "segmentation",
    "loadvideo",
    "savevideo",
    "get_mean_and_std",
    "bootstrap",
    "latexify",
    "dice_similarity_coefficient",
]