# NOTE(review): removed a HuggingFace Spaces page banner ("Spaces: Running on L40S")
# that was scraped into this file — it was not part of the source code.
import random
from typing import List

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# from videoswap.utils.registry import MODEL_REGISTRY
class MLP(nn.Module):
    """Two-layer perceptron: Linear(in_dim -> mid_dim) -> SiLU -> Linear(mid_dim -> out_dim)."""

    def __init__(self, in_dim, out_dim, mid_dim=128):
        super().__init__()
        # Build the layers individually, then wrap them; the attribute name
        # `mlp` is kept so state_dict keys match existing checkpoints.
        hidden = nn.Linear(in_dim, mid_dim, bias=True)
        activation = nn.SiLU(inplace=False)
        projection = nn.Linear(mid_dim, out_dim, bias=True)
        self.mlp = nn.Sequential(hidden, activation, projection)

    def forward(self, x):
        """Project the trailing feature dimension of `x` from in_dim to out_dim."""
        return self.mlp(x)
def bilinear_interpolation(level_adapter_state, x, y, frame_idx, interpolated_value):
    """Bilinearly splat `interpolated_value` into `level_adapter_state` at (x, y).

    `level_adapter_state` is a (frames, channels, h, w) tensor that is mutated
    in place: the value is accumulated into the four grid cells surrounding the
    fractional position (x, y) of frame `frame_idx`, weighted by bilinear
    coefficients. Corner indices are clamped to the grid, so near the boundary
    two corners may coincide and their weights accumulate on the same cell
    (total weight mass is always 1). The (mutated) tensor is returned.
    """
    _, _, grid_h, grid_w = level_adapter_state.shape
    x_lo, y_lo = int(x), int(y)
    fx, fy = x - x_lo, y - y_lo
    # Clamp both corner coordinates into the valid grid range.
    xs = (max(min(x_lo, grid_w - 1), 0), max(min(x_lo + 1, grid_w - 1), 0))
    ys = (max(min(y_lo, grid_h - 1), 0), max(min(y_lo + 1, grid_h - 1), 0))
    wx = (1 - fx, fx)
    wy = (1 - fy, fy)
    # Accumulate in the same corner order as before: (y1,x1), (y1,x2), (y2,x1), (y2,x2).
    for iy in range(2):
        for ix in range(2):
            level_adapter_state[frame_idx, :, ys[iy], xs[ix]] += interpolated_value * (wx[ix] * wy[iy])
    return level_adapter_state
# @MODEL_REGISTRY.register()
class SparsePointAdapter(ModelMixin, ConfigMixin):
    """Adapter that scatters sparse per-point embeddings onto multi-scale feature maps.

    One MLP per feature level projects the (num_points, embedding_channels) point
    embeddings to that level's channel count; each visible tracked point is then
    bilinearly splatted into a zero tensor of shape
    (num_frames, channels[i], h // downsample_rate[i], w // downsample_rate[i]).
    """

    # NOTE(review): `register_to_config` was imported at the top of the file but
    # never applied — a ConfigMixin subclass needs it to record init args for
    # save/load. Added here; confirm against the project's checkpointing path.
    @register_to_config
    def __init__(
        self,
        embedding_channels=1280,
        channels=(320, 640, 1280, 1280),  # tuple defaults avoid the shared-mutable-default pitfall
        downsample_rate=(8, 16, 32, 64),
        mid_dim=128,
    ):
        super().__init__()
        # One projection MLP per feature level: embedding_channels -> channels[i].
        self.model_list = nn.ModuleList(
            MLP(embedding_channels, ch, mid_dim) for ch in channels
        )
        self.downsample_rate = list(downsample_rate)
        self.channels = list(channels)
        # Radius (in cells of the coarsest level) of the unmasked square around
        # each visible point when building the point-wise loss mask.
        self.radius = 2

    def generate_loss_mask(self, point_index_list, point_tracker, num_frames, h, w, loss_type):
        """Build a (num_frames, num_levels, h//r, w//r) 0/1 loss mask at the coarsest rate r.

        `loss_type == 'global'` returns an all-ones mask. Otherwise only a square
        of side ~2*radius around each visible point (negative coordinates mean
        invisible) is unmasked, in the frame where the point is visible.
        """
        num_levels = len(self.channels)  # was hard-coded 4; generalized to the configured level count
        mask_h = h // self.downsample_rate[0]
        mask_w = w // self.downsample_rate[0]
        if loss_type == 'global':
            return torch.ones((num_frames, num_levels, mask_h, mask_w))

        loss_mask = torch.zeros((num_frames, num_levels, mask_h, mask_w))
        for point_idx in point_index_list:
            for frame_idx in range(num_frames):
                px, py = point_tracker[frame_idx, point_idx]
                if px < 0 or py < 0:
                    continue  # point not visible in this frame
                px, py = px / self.downsample_rate[0], py / self.downsample_rate[0]
                x1 = max(min(int(px) - self.radius, mask_w - 1), 0)
                x2 = max(min(int(px) + self.radius, mask_w - 1), 0)
                y1 = max(min(int(py) - self.radius, mask_h - 1), 0)
                y2 = max(min(int(py) + self.radius, mask_h - 1), 0)
                # BUGFIX: unmask only the frame where the point is visible.
                # The original wrote `loss_mask[:, ...]`, leaking each frame's
                # point locations into every other frame's mask and defeating
                # the per-frame visibility check above.
                # NOTE(review): slice ends are exclusive, so the y2/x2 row and
                # column stay masked — footprint kept as in the original.
                loss_mask[frame_idx, :, y1:y2, x1:x2] = 1.0
        return loss_mask

    def forward(self, point_tracker, size, point_embedding, index_list=None, drop_rate=0.0, loss_type='global') -> List[torch.Tensor]:
        """Scatter point embeddings into one adapter feature map per level.

        Args:
            point_tracker: (num_frames, num_points, 2) x/y pixel coordinates;
                negative coordinates mark a point as invisible in that frame.
            size: (width, height) of the source frames in pixels.
            point_embedding: (num_points, embedding_channels) per-point features.
            index_list: optional subset of point indices to use at inference;
                None means use all points.
            drop_rate: per-point drop probability during training.
            loss_type: 'global' for an all-ones loss mask, anything else for a
                mask restricted to neighborhoods of visible points.

        Returns:
            Training: (adapter_state, loss_mask). Inference: adapter_state — a
            list with one (num_frames, channels[i], H_i, W_i) tensor per level.
        """
        w, h = size
        num_frames, num_points = point_tracker.shape[:2]

        if self.training:
            # Randomly drop points during training for robustness.
            point_index_list = [idx for idx in range(num_points) if random.random() > drop_rate]
            loss_mask = self.generate_loss_mask(point_index_list, point_tracker, num_frames, h, w, loss_type)
        else:
            point_index_list = [idx for idx in range(num_points) if index_list is None or idx in index_list]

        adapter_state = []
        for level_idx, module in enumerate(self.model_list):
            downsample_rate = self.downsample_rate[level_idx]
            level_w, level_h = w // downsample_rate, h // downsample_rate
            # (num_points, embedding_channels) -> (num_points, channels[level_idx])
            point_feat = module(point_embedding)
            level_adapter_state = torch.zeros(
                (num_frames, self.channels[level_idx], level_h, level_w)
            ).to(point_feat.device, dtype=point_feat.dtype)
            for point_idx in point_index_list:
                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]
                    if px < 0 or py < 0:
                        continue  # point not visible in this frame
                    level_adapter_state = bilinear_interpolation(
                        level_adapter_state,
                        px / downsample_rate,
                        py / downsample_rate,
                        frame_idx,
                        point_feat[point_idx],
                    )
            adapter_state.append(level_adapter_state)

        if self.training:
            return adapter_state, loss_mask
        return adapter_state