|
import math |
|
import pathlib |
|
import warnings |
|
from types import FunctionType |
|
from typing import Any, BinaryIO, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image, ImageColor, ImageDraw, ImageFont |
|
|
|
# Public API of this module: the names exported by ``from <module> import *``.
# Helpers prefixed with an underscore below are intentionally not listed.
__all__ = [
    "make_grid",
    "save_image",
    "draw_bounding_boxes",
    "draw_segmentation_masks",
    "draw_keypoints",
    "flow_to_image",
]
|
|
|
|
|
@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: float = 0.0,
    **kwargs,
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size. A 2D (H x W) or 3D (C x H x W)
            tensor is treated as a single image. Single-channel images are replicated
            to 3 channels.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by ``value_range``. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. By default, min and max
            are computed from the tensor.
        range (tuple, optional):
            .. warning::
                This parameter was deprecated in ``0.12`` and will be removed in ``0.14``.
                Please use ``value_range`` instead.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.

    Returns:
        grid (Tensor): the tensor containing grid of images. When the (possibly
            promoted) batch holds exactly one image, that image is returned as
            (C x H x W) without any padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors, or if
            ``normalize`` is True and ``value_range`` is given but is not a tuple.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(make_grid)
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

    if "range" in kwargs:
        warnings.warn(
            "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. "
            "Please use 'value_range' instead."
        )
        value_range = kwargs["range"]

    # A list of same-sized images becomes a 4D mini-batch.
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single H x W image -> 1 x H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single C x H x W image
        if tensor.size(0) == 1:  # single channel -> replicate to 3 channels
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # batch of single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # normalization below is in-place; keep the caller's tensor intact
        if value_range is not None and not isinstance(value_range, tuple):
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

        def norm_ip(img, low, high):
            img.clamp_(min=low, max=high)
            # Guard against a zero-width range (e.g. a constant image).
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # each image gets its own (min, max)
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    assert isinstance(tensor, torch.Tensor)  # type narrowing only; always true here
    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # Lay the mini-batch out row by row on a pre-filled canvas.
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            grid.narrow(1, y * height + padding, height - padding).narrow(
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid
|
|
|
|
|
@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[str, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Write a tensor (or a batch / list of tensors) to an image file.

    A mini-batch is first tiled into a single image with ``make_grid``.

    Args:
        tensor (Tensor or list): image(s) to write.
        fp (string or file object): destination filename or writable file object.
        format (Optional): image format passed to PIL. When omitted, it is inferred
            from the filename extension; pass it explicitly when ``fp`` is a file object.
        **kwargs: forwarded to ``make_grid`` (see its documentation).
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(save_image)

    grid = make_grid(tensor, **kwargs)

    # Presumably the grid is in [0, 1]: scale to [0, 255]; adding 0.5 before the
    # truncating uint8 cast rounds to the nearest integer.
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    hwc_uint8 = scaled.permute(1, 2, 0).to("cpu", torch.uint8)

    Image.fromarray(hwc_uint8.numpy()).save(fp, format=format)
|
|
|
|
|
@torch.no_grad()
def draw_bounding_boxes(
    image: torch.Tensor,
    boxes: torch.Tensor,
    labels: Optional[List[str]] = None,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
    fill: Optional[bool] = False,
    width: int = 1,
    font: Optional[str] = None,
    font_size: int = 10,
) -> torch.Tensor:
    """
    Draw bounding boxes (and optional text labels) on an image.

    The input image must contain uint8 values in [0, 255]. When ``fill`` is True the
    boxes are filled with a translucent color, so the result should be saved as PNG.

    Args:
        image (Tensor): uint8 Tensor of shape (C x H x W) with C equal to 1 or 3.
            Grayscale input is replicated to 3 channels before drawing.
        boxes (Tensor): Tensor of size (N, 4) with boxes in (xmin, ymin, xmax, ymax)
            format, in absolute pixel coordinates: `0 <= xmin < xmax < W` and
            `0 <= ymin < ymax < H`.
        labels (List[str]): optional text label for each box.
        colors (color or list of colors, optional): one color per box, or a single
            color for all boxes, given as PIL strings (e.g. "red", "#FF00FF") or RGB
            tuples (e.g. ``(240, 10, 157)``). When omitted, colors are generated from a palette.
        fill (bool): if `True`, fill each box with its color.
        width (int): outline width of each box.
        font (str): filename of a TrueType font. If it is not found at this path, the
            loader may also search other directories, such as `fonts/` on Windows or
            `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
        font_size (int): requested font size in points (only used with ``font``).

    Returns:
        img (Tensor[C, H, W]): uint8 Image Tensor with the boxes drawn.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_bounding_boxes)

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Tensor expected, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"Tensor uint8 expected, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size(0) not in {1, 3}:
        raise ValueError("Only grayscale and RGB images are supported")

    num_boxes = boxes.shape[0]

    if labels is None:
        box_labels: Union[List[str], List[None]] = [None] * num_boxes
    elif len(labels) == num_boxes:
        box_labels = labels
    else:
        raise ValueError(
            f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box."
        )

    if colors is None:
        box_colors = _generate_color_palette(num_boxes)
    elif isinstance(colors, list):
        if len(colors) < num_boxes:
            raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ")
        box_colors = colors
    else:
        # One color shared by every box.
        box_colors = [colors] * num_boxes

    # PIL color strings are resolved to RGB tuples; tuples pass through untouched.
    rgb_colors = [ImageColor.getrgb(c) if isinstance(c, str) else c for c in box_colors]

    if image.size(0) == 1:
        image = torch.tile(image, (3, 1, 1))

    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    box_coords = boxes.to(torch.int64).tolist()
    # RGBA drawing mode lets the fill color carry an alpha component.
    painter = ImageDraw.Draw(canvas, "RGBA" if fill else None)
    txt_font = ImageFont.load_default() if font is None else ImageFont.truetype(font=font, size=font_size)
    margin = width + 1

    for bbox, color, label in zip(box_coords, rgb_colors, box_labels):
        fill_color = color + (100,) if fill else None
        painter.rectangle(bbox, width=width, outline=color, fill=fill_color)
        if label is None:
            continue
        painter.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font)

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)
|
|
|
|
|
@torch.no_grad()
def draw_segmentation_masks(
    image: torch.Tensor,
    masks: torch.Tensor,
    alpha: float = 0.8,
    colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None,
) -> torch.Tensor:
    """
    Draws segmentation masks on given RGB image.
    The values of the input image should be uint8 between 0 and 255.

    Args:
        image (Tensor): Tensor of shape (3, H, W) and dtype uint8.
        masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool.
        alpha (float): Float number between 0 and 1 denoting the transparency of the masks.
            0 means full transparency, 1 means no transparency.
        colors (color or list of colors, optional): List containing the colors
            of the masks or single color for all masks. The color can be represented as
            PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``.
            By default, colors are generated from a fixed palette.

    Returns:
        img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top.
            If ``masks`` contains no mask, a warning is emitted and ``image`` is
            returned unchanged.

    Raises:
        TypeError: if ``image`` is not a tensor.
        ValueError: if ``image`` or ``masks`` have an unsupported dtype or shape, if a
            list of colors is shorter than the number of masks, or if a color is not
            a string or an RGB 3-tuple.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_segmentation_masks)
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    elif image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    elif image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    elif image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if masks.ndim == 2:
        masks = masks[None, :, :]
    if masks.ndim != 3:
        raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)")
    if masks.dtype != torch.bool:
        raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}")
    if masks.shape[-2:] != image.shape[-2:]:
        raise ValueError("The image and the masks must have the same height and width")

    num_masks = masks.size()[0]
    if num_masks == 0:
        # Nothing to draw; previously this crashed on ``colors[0]`` below.
        warnings.warn("masks doesn't contain any mask. No mask was drawn")
        return image

    if colors is None:
        colors = _generate_color_palette(num_masks)
    elif isinstance(colors, list):
        if num_masks > len(colors):
            raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})")
    else:
        # A single color (string or RGB tuple) applies to every mask. Taking
        # ``len()`` of it would count characters / tuple items, not colors.
        colors = [colors] * num_masks

    if not isinstance(colors[0], (tuple, str)):
        raise ValueError("colors must be a tuple or a string, or a list thereof")
    if isinstance(colors[0], tuple) and len(colors[0]) != 3:
        raise ValueError("It seems that you passed a tuple of colors instead of a list of colors")

    out_dtype = torch.uint8

    colors_ = []
    for color in colors:
        if isinstance(color, str):
            color = ImageColor.getrgb(color)
        # Keep colors on the image's device so the masked assignment below never mixes devices.
        colors_.append(torch.tensor(color, dtype=out_dtype, device=image.device))

    img_to_draw = image.detach().clone()
    for mask, color in zip(masks, colors_):
        img_to_draw[:, mask] = color[:, None]

    # Alpha-blend the painted copy over the original image.
    out = image * (1 - alpha) + img_to_draw * alpha
    return out.to(out_dtype)
|
|
|
|
|
@torch.no_grad()
def draw_keypoints(
    image: torch.Tensor,
    keypoints: torch.Tensor,
    connectivity: Optional[List[Tuple[int, int]]] = None,
    colors: Optional[Union[str, Tuple[int, int, int]]] = None,
    radius: int = 2,
    width: int = 3,
) -> torch.Tensor:
    """
    Draw keypoints, and optionally lines joining them, on an RGB image.

    The input image must contain uint8 values in [0, 255].

    Args:
        image (Tensor): uint8 Tensor of shape (3, H, W).
        keypoints (Tensor): Tensor of shape (num_instances, K, 2) holding the [x, y]
            location of the K keypoints of every instance.
        connectivity (List[Tuple[int, int]]]): pairs of keypoint indices to connect
            with a line, applied to every instance.
        colors (str, Tuple): fill color of the keypoint dots, as a PIL string such as
            "red" or "#FF00FF", or an RGB tuple such as ``(240, 10, 157)``.
        radius (int): radius of each keypoint dot.
        width (int): width of the connecting lines.

    Returns:
        img (Tensor[C, H, W]): uint8 Image Tensor with the keypoints drawn.

    Note:
        ``colors`` is only used for the dots; connecting lines are drawn without an
        explicit color.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(draw_keypoints)

    if not isinstance(image, torch.Tensor):
        raise TypeError(f"The image must be a tensor, got {type(image)}")
    if image.dtype != torch.uint8:
        raise ValueError(f"The image dtype must be uint8, got {image.dtype}")
    if image.dim() != 3:
        raise ValueError("Pass individual images, not batches")
    if image.size()[0] != 3:
        raise ValueError("Pass an RGB image. Other Image formats are not supported")
    if keypoints.ndim != 3:
        raise ValueError("keypoints must be of shape (num_instances, K, 2)")

    canvas = Image.fromarray(image.permute(1, 2, 0).cpu().numpy())
    painter = ImageDraw.Draw(canvas)
    # Integer pixel coordinates as nested lists: [instance][keypoint][x or y].
    instances = keypoints.to(torch.int64).tolist()

    for points in instances:
        for point in points:
            cx, cy = point[0], point[1]
            painter.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=colors, outline=None, width=0)

        if not connectivity:
            continue
        for pair in connectivity:
            start, end = points[pair[0]], points[pair[1]]
            painter.line(((start[0], start[1]), (end[0], end[1])), width=width)

    return torch.from_numpy(np.array(canvas)).permute(2, 0, 1).to(dtype=torch.uint8)
|
|
|
|
|
|
|
@torch.no_grad()
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
    """
    Convert an optical-flow field into an RGB visualization.

    Args:
        flow (Tensor): ``torch.float`` flow of shape (N, 2, H, W) or (2, H, W).

    Returns:
        img (Tensor): uint8 image in which color encodes flow direction. Its shape is
            (N, 3, H, W) or (3, H, W), matching the rank of the input.

    Raises:
        ValueError: if ``flow`` is not ``torch.float`` or has an unsupported shape.

    Note:
        Vectors are divided by the largest magnitude found anywhere in the input,
        so every image of a batch shares the same scale.
    """
    if flow.dtype != torch.float:
        raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")

    orig_shape = flow.shape
    unbatched = flow.ndim == 3
    batch = flow[None] if unbatched else flow

    if batch.ndim != 4 or batch.shape[1] != 2:
        raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")

    magnitudes = (batch ** 2).sum(dim=1).sqrt()
    # eps keeps an all-zero flow from dividing by zero.
    scale = magnitudes.max() + torch.finfo(batch.dtype).eps
    rgb = _normalized_flow_to_image(batch / scale)

    return rgb[0] if unbatched else rgb
|
|
|
|
|
@torch.no_grad()
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
    """
    Map a batch of normalized flow vectors onto the optical-flow color wheel.

    Args:
        normalized_flow (torch.Tensor): flow tensor of shape (N, 2, H, W), presumably
            scaled so magnitudes are at most 1 (as done by ``flow_to_image``).

    Returns:
        img (Tensor(N, 3, H, W)): uint8 flow visualization.
    """
    batch, _, height, width = normalized_flow.shape
    device = normalized_flow.device
    out = torch.zeros((batch, 3, height, width), dtype=torch.uint8, device=device)

    wheel = _make_colorwheel().to(device)
    n_colors = wheel.shape[0]

    magnitude = torch.sum(normalized_flow ** 2, dim=1).sqrt()
    # Angle of the (negated) flow vector, scaled to [-1, 1].
    angle = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
    # Fractional position on the wheel and the two neighbouring wheel entries.
    pos = (angle + 1) / 2 * (n_colors - 1)
    lo = torch.floor(pos).to(torch.long)
    hi = lo + 1
    hi[hi == n_colors] = 0  # wrap around the wheel
    frac = pos - lo

    for channel in range(wheel.shape[1]):
        column = wheel[:, channel]
        start = column[lo] / 255.0
        end = column[hi] / 255.0
        blended = (1 - frac) * start + frac * end
        # Small magnitudes fade toward white.
        shaded = 1 - magnitude * (1 - blended)
        out[:, channel, :, :] = torch.floor(255 * shaded)

    return out
|
|
|
|
|
def _make_colorwheel() -> torch.Tensor: |
|
""" |
|
Generates a color wheel for optical flow visualization as presented in: |
|
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) |
|
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. |
|
|
|
Returns: |
|
colorwheel (Tensor[55, 3]): Colorwheel Tensor. |
|
""" |
|
|
|
RY = 15 |
|
YG = 6 |
|
GC = 4 |
|
CB = 11 |
|
BM = 13 |
|
MR = 6 |
|
|
|
ncols = RY + YG + GC + CB + BM + MR |
|
colorwheel = torch.zeros((ncols, 3)) |
|
col = 0 |
|
|
|
|
|
colorwheel[0:RY, 0] = 255 |
|
colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) |
|
col = col + RY |
|
|
|
colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) |
|
colorwheel[col : col + YG, 1] = 255 |
|
col = col + YG |
|
|
|
colorwheel[col : col + GC, 1] = 255 |
|
colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) |
|
col = col + GC |
|
|
|
colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) |
|
colorwheel[col : col + CB, 2] = 255 |
|
col = col + CB |
|
|
|
colorwheel[col : col + BM, 2] = 255 |
|
colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) |
|
col = col + BM |
|
|
|
colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) |
|
colorwheel[col : col + MR, 0] = 255 |
|
return colorwheel |
|
|
|
|
|
def _generate_color_palette(num_objects: int): |
|
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) |
|
return [tuple((i * palette) % 255) for i in range(num_objects)] |
|
|
|
|
|
def _log_api_usage_once(obj: Any) -> None: |
|
|
|
""" |
|
Logs API usage(module and name) within an organization. |
|
In a large ecosystem, it's often useful to track the PyTorch and |
|
TorchVision APIs usage. This API provides the similar functionality to the |
|
logging module in the Python stdlib. It can be used for debugging purpose |
|
to log which methods are used and by default it is inactive, unless the user |
|
manually subscribes a logger via the `SetAPIUsageLogger method <https://github.com/pytorch/pytorch/blob/eb3b9fe719b21fae13c7a7cf3253f970290a573e/c10/util/Logging.cpp#L114>`_. |
|
Please note it is triggered only once for the same API call within a process. |
|
It does not collect any data from open-source users since it is no-op by default. |
|
For more information, please refer to |
|
* PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; |
|
* Logging policy: https://github.com/pytorch/vision/issues/5052; |
|
|
|
Args: |
|
obj (class instance or method): an object to extract info from. |
|
""" |
|
if not obj.__module__.startswith("torchvision"): |
|
return |
|
name = obj.__class__.__name__ |
|
if isinstance(obj, FunctionType): |
|
name = obj.__name__ |
|
torch._C._log_api_usage_once(f"{obj.__module__}.{name}") |
|
|