Spaces:
Sleeping
Sleeping
File size: 6,053 Bytes
31f2f28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
"""Various utilities used in the film_net frame interpolator model."""
from typing import List, Optional
import cv2
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
def pad_batch(batch, align):
    """Zero-pads the height/width of a batch so both are multiples of `align`.

    The padding is split as evenly as possible, with any odd pixel going to the
    bottom/right side.

    Args:
      batch: A 4-D array laid out as BxHxWxC.
      align: The value that the padded height and width must be multiples of.

    Returns:
      A tuple (padded_batch, crop_region). crop_region is
      [top, left, top + H, left + W], i.e. the slice bounds of the original
      content inside the padded batch.
    """
    height, width = batch.shape[1:3]
    # (-n) % align is the amount needed to reach the next multiple (0 if aligned).
    extra_h = (-height) % align
    extra_w = (-width) % align
    top = extra_h // 2
    left = extra_w // 2
    pad_widths = (
        (0, 0),
        (top, extra_h - top),
        (left, extra_w - left),
        (0, 0),
    )
    padded = np.pad(batch, pad_widths, mode='constant')
    return padded, [top, left, top + height, left + width]
def load_image(path, align=64):
    """Loads an image file as a padded, batched float32 RGB array.

    Args:
      path: Path of the image file, read with `cv2.imread`.
      align: The padded height and width are multiples of this value
        (see `pad_batch`).

    Returns:
      A tuple (image_batch, crop_region). image_batch is a 1xHxWx3 float32 array
      in RGB order with values scaled from [0, 255] to [0, 1]. crop_region is
      [top, left, top + H, left + W], locating the unpadded image.

    Raises:
      OSError: if OpenCV cannot read or decode the file.
    """
    bgr = cv2.imread(path)
    # cv2.imread reports failure by returning None rather than raising; without
    # this check the problem only shows up later as an opaque cv2.error.
    if bgr is None:
        raise OSError(f'Could not read image file: {path}')
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) / np.float32(255)
    image_batch, crop_region = pad_batch(np.expand_dims(image, axis=0), align)
    return image_batch, crop_region
def build_image_pyramid(image: torch.Tensor, pyramid_levels: int = 3) -> List[torch.Tensor]:
    """Builds an image pyramid from a given image.

    The original image is included in the pyramid and the rest are generated by
    successively halving the resolution with 2x2 average pooling (odd sizes are
    floored, per `F.avg_pool2d`'s default `ceil_mode=False`).

    Args:
      image: the input image batch, laid out as BxCxHxW (the last two dims are
        the pooled spatial dims).
      pyramid_levels: number of levels to produce, including the input itself.

    Returns:
      A list of `pyramid_levels` images starting from the finest (the input
      tensor itself, not a copy).
    """
    pyramid = []
    for i in range(pyramid_levels):
        pyramid.append(image)
        # Skip the pooling after the last level; its result would be unused.
        if i < pyramid_levels - 1:
            image = F.avg_pool2d(image, 2, 2)
    return pyramid
def warp(image: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Backward warps the image using the given flow.

    Specifically, the output pixel in batch b, at position x, y will be computed
    as follows:
      (flowed_y, flowed_x) = (y + flow[b, 1, y, x], x + flow[b, 0, y, x])
      output[b, :, y, x] = bilinear_lookup(image, b, flowed_y, flowed_x)
    Note that the flow vectors are expected as [x, y], e.g. x in channel 0 and
    y in channel 1. Lookups outside the image are clamped to the border.

    This is the PyTorch counterpart of TF-Addons' `dense_image_warp`,
    implemented with `F.grid_sample`.

    Args:
      image: An image batch with shape BxCxHxW.
      flow: A flow with shape Bx2xHxW, with the two channels denoting the
        relative offset in pixels, in order: (dx, dy). Its spatial size must
        match the image's.

    Returns:
      A warped image with the same shape as `image`.
    """
    # Reorder channels to (dy, dx) and negate; the grid below subtracts this
    # tensor, so the net effect is base position + flow.
    flow = -flow.flip(1)
    dtype = flow.dtype
    device = flow.device
    height, width = flow.shape[2], flow.shape[3]
    # With align_corners=False, pixel centers lie at normalized coordinates
    # -1 + 1/size ... 1 - 1/size.
    x_extent = 1 - 1 / width
    y_extent = 1 - 1 / height
    # Pixel offsets -> normalized offsets: one pixel spans 2/size, i.e. divide
    # by size/2. Result is BxHxWx2 with last dim (dy, dx).
    normalized_flow = flow.permute(0, 2, 3, 1) / torch.tensor(
        [height * .5, width * .5], dtype=dtype, device=device)[None, None, None]
    # grid_sample expects the sampling grid's last dim ordered as (x, y).
    grid = torch.stack([
        torch.linspace(-x_extent, x_extent, width, dtype=dtype, device=device)[None, None, :] - normalized_flow[..., 1],
        torch.linspace(-y_extent, y_extent, height, dtype=dtype, device=device)[None, :, None] - normalized_flow[..., 0],
    ], dim=3)
    warped = F.grid_sample(image, grid,
                           mode='bilinear', padding_mode='border', align_corners=False)
    return warped.reshape(image.shape)
def multiply_pyramid(pyramid: List[torch.Tensor],
                     scalar: torch.Tensor) -> List[torch.Tensor]:
    """Multiplies all image batches in the pyramid by a batch of scalars.

    Each image in a level is multiplied by its own scalar (batch element b of
    every level is scaled by scalar[b]).

    Args:
      pyramid: Pyramid of image batches, assumed BxCxHxW as elsewhere in this
        module.
      scalar: Batch of scalars, shaped Bx1 or B (a 0-d tensor scales
        everything uniformly).

    Returns:
      An image pyramid with all images multiplied by the scalar.
    """
    # A 1-D batch of scalars would otherwise line up with the channel axis;
    # give it the Bx1 layout so it aligns with the batch axis instead.
    if scalar.dim() == 1:
        scalar = scalar[:, None]
    # Bx1 -> Bx1x1x1 so it broadcasts over channels and spatial dims. Computed
    # once rather than per level.
    scalar = scalar[..., None, None]
    return [image * scalar for image in pyramid]
def flow_pyramid_synthesis(
        residual_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Converts a residual flow pyramid into a flow pyramid.

    Both pyramids are ordered finest level first. The coarsest level is used
    as-is; walking towards the finest level, the running flow is doubled in
    magnitude, bilinearly resized to the next level's HxW (dims 2:4), and the
    residual of that level is added to it.
    """
    current = residual_pyramid[-1]
    # Built coarse-to-fine, then flipped so the result is finest-first.
    coarse_to_fine: List[torch.Tensor] = [current]
    for residual in reversed(residual_pyramid[:-1]):
        upsampled = F.interpolate(2 * current, size=residual.shape[2:4], mode='bilinear')
        current = residual + upsampled
        coarse_to_fine.append(current)
    coarse_to_fine.reverse()
    return coarse_to_fine
def pyramid_warp(feature_pyramid: List[torch.Tensor],
                 flow_pyramid: List[torch.Tensor]) -> List[torch.Tensor]:
    """Backward warps each pyramid level with the flow of the same level.

    Levels are paired in order; if the pyramids differ in length, the extra
    levels of the longer one are ignored.

    Args:
      feature_pyramid: feature pyramid starting from the finest level.
      flow_pyramid: flow fields, starting from the finest level.

    Returns:
      Reverse warped feature pyramid, finest level first.
    """
    return [
        warp(level_features, level_flow)
        for level_features, level_flow in zip(feature_pyramid, flow_pyramid)
    ]
def concatenate_pyramids(pyramid1: List[torch.Tensor],
                         pyramid2: List[torch.Tensor]) -> List[torch.Tensor]:
    """Concatenates each pyramid level together in the channel dimension.

    Levels are paired in order (dim 1 is the channel axis); the output has as
    many levels as the shorter input.
    """
    return [torch.cat([left, right], dim=1) for left, right in zip(pyramid1, pyramid2)]
class Conv2d(nn.Sequential):
    """2-D convolution with 'same' output size and an optional leaky ReLU.

    The output has the input's spatial size (stride 1). Odd kernels are padded
    symmetrically. Even kernels put the extra row/column on the bottom/right,
    matching TensorFlow's 'SAME' convention. The convolution is the
    Sequential's only child (index 0), so its parameters appear as
    `0.weight` / `0.bias` in the state dict.

    Args:
      in_channels: Number of input channels.
      out_channels: Number of output channels.
      size: Square kernel size (an int).
      activation: 'relu' applies LeakyReLU with negative slope 0.2; None gives
        a linear output.

    Raises:
      ValueError: if `activation` is neither None nor 'relu'.
    """

    def __init__(self, in_channels, out_channels, size, activation: Optional[str] = 'relu'):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if activation not in (None, 'relu'):
            raise ValueError(f"activation must be None or 'relu', got {activation!r}")
        super().__init__(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=size,
                # Even kernels are padded manually (asymmetrically) in forward().
                padding='same' if size % 2 else 0)
        )
        self.size = size
        # Note: 'relu' maps to a *leaky* ReLU with slope 0.2.
        self.activation = nn.LeakyReLU(.2) if activation == 'relu' else None

    def forward(self, x):
        if not self.size % 2:
            # Total padding of size - 1 per spatial axis, with the larger half
            # after (bottom/right); for size 2 this is (0, 1, 0, 1).
            before = (self.size - 1) // 2
            after = self.size // 2
            x = F.pad(x, (before, after, before, after))
        y = self[0](x)
        if self.activation is not None:
            y = self.activation(y)
        return y
|