import random
from typing import List

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# from videoswap.utils.registry import MODEL_REGISTRY


class MLP(nn.Module):
    """Two-layer MLP that projects a point embedding to a target channel width."""

    def __init__(self, in_dim, out_dim, mid_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, mid_dim, bias=True),
            nn.SiLU(inplace=False),
            nn.Linear(mid_dim, out_dim, bias=True)
        )

    def forward(self, x):
        return self.mlp(x)


def bilinear_interpolation(level_adapter_state, x, y, frame_idx, interpolated_value):
    # Splat `interpolated_value` onto the four feature-map cells surrounding the
    # sub-pixel location (x, y) of frame `frame_idx`, weighted bilinearly.
    # level_adapter_state: (frames, channels, h, w)
    x1 = int(x)
    y1 = int(y)
    x2 = x1 + 1
    y2 = y1 + 1

    x_frac = x - x1
    y_frac = y - y1

    # clamp the corner indices to the feature-map boundary
    x1, x2 = max(min(x1, level_adapter_state.shape[3] - 1), 0), max(min(x2, level_adapter_state.shape[3] - 1), 0)
    y1, y2 = max(min(y1, level_adapter_state.shape[2] - 1), 0), max(min(y2, level_adapter_state.shape[2] - 1), 0)

    w11 = (1 - x_frac) * (1 - y_frac)
    w21 = x_frac * (1 - y_frac)
    w12 = (1 - x_frac) * y_frac
    w22 = x_frac * y_frac

    level_adapter_state[frame_idx, :, y1, x1] += interpolated_value * w11
    level_adapter_state[frame_idx, :, y1, x2] += interpolated_value * w21
    level_adapter_state[frame_idx, :, y2, x1] += interpolated_value * w12
    level_adapter_state[frame_idx, :, y2, x2] += interpolated_value * w22
    return level_adapter_state
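
# Worked example (illustrative only, not part of the original module): for a point at
# (x, y) = (2.25, 3.5) the splat above distributes one feature vector as
#   w11 = 0.75 * 0.5 = 0.375 -> cell (y=3, x=2)
#   w21 = 0.25 * 0.5 = 0.125 -> cell (y=3, x=3)
#   w12 = 0.75 * 0.5 = 0.375 -> cell (y=4, x=2)
#   w22 = 0.25 * 0.5 = 0.125 -> cell (y=4, x=3)
# The four weights always sum to 1, so each point feature is conserved on the grid.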


# @MODEL_REGISTRY.register()
class SparsePointAdapter(ModelMixin, ConfigMixin):

    @register_to_config
    def __init__(
        self,
        embedding_channels=1280,
        channels=[320, 640, 1280, 1280],
        downsample_rate=[8, 16, 32, 64],
        mid_dim=128,
    ):
        super().__init__()
        self.model_list = nn.ModuleList()
        for ch in channels:
            self.model_list.append(MLP(embedding_channels, ch, mid_dim))

        self.downsample_rate = downsample_rate
        self.channels = channels
        self.radius = 2

    def generate_loss_mask(self, point_index_list, point_tracker, num_frames, h, w, loss_type):
        if loss_type == 'global':
            # supervise the entire feature map
            loss_mask = torch.ones((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0]))
        else:
            # only compute the loss around visible points, with a radius that is independent of the downsampling scale
            loss_mask = torch.zeros((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0]))
            for point_idx in point_index_list:
                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]
                    if px < 0 or py < 0:
                        # negative coordinates mark an invisible / occluded point
                        continue
                    else:
                        px, py = px / self.downsample_rate[0], py / self.downsample_rate[0]
                        x1 = int(px) - self.radius
                        y1 = int(py) - self.radius
                        x2 = int(px) + self.radius
                        y2 = int(py) + self.radius
                        x1, x2 = max(min(x1, loss_mask.shape[3] - 1), 0), max(min(x2, loss_mask.shape[3] - 1), 0)
                        y1, y2 = max(min(y1, loss_mask.shape[2] - 1), 0), max(min(y2, loss_mask.shape[2] - 1), 0)
                        loss_mask[:, :, y1:y2, x1:x2] = 1.0
        return loss_mask
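
    # Shape example (illustrative): for a 512x512 clip with the default downsample_rate[0] == 8,
    # both branches above return a mask of shape (num_frames, 4, 64, 64), i.e. the loss mask
    # lives at the resolution of the first adapter level.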

    def forward(self, point_tracker, size, point_embedding, index_list=None, drop_rate=0.0, loss_type='global') -> List[torch.Tensor]:
        # # (1, frames, num_points, 2) -> (frames, num_points, 2)
        # point_tracker = point_tracker.squeeze(0)
        # # (1, num_points, 1280) -> (num_points, 1280)
        # point_embedding = point_embedding.squeeze(0)
        w, h = size

        num_frames, num_points = point_tracker.shape[:2]
        if self.training:
            # randomly drop points during training; the loss mask only covers the kept points
            point_index_list = [point_idx for point_idx in range(num_points) if random.random() > drop_rate]
            loss_mask = self.generate_loss_mask(point_index_list, point_tracker, num_frames, h, w, loss_type)
        else:
            # at inference, optionally restrict to a user-specified subset of points
            point_index_list = [point_idx for point_idx in range(num_points) if index_list is None or point_idx in index_list]

        adapter_state = []
        for level_idx, module in enumerate(self.model_list):
            downsample_rate = self.downsample_rate[level_idx]
            level_w, level_h = w // downsample_rate, h // downsample_rate
            # e.g. (num_points, 1280) -> (num_points, 320)
            point_feat = module(point_embedding)
            level_adapter_state = torch.zeros((num_frames, self.channels[level_idx], level_h, level_w)).to(point_feat.device, dtype=point_feat.dtype)
            for point_idx in point_index_list:
                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]
                    if px < 0 or py < 0:
                        continue
                    else:
                        px, py = px / downsample_rate, py / downsample_rate
                        level_adapter_state = bilinear_interpolation(level_adapter_state, px, py, frame_idx, point_feat[point_idx])
            adapter_state.append(level_adapter_state)

        if self.training:
            return adapter_state, loss_mask
        else:
            return adapter_state
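

if __name__ == '__main__':
    # Minimal smoke test, not part of the original file: the numbers below
    # (8 frames, 4 tracked points, 1280-dim point embeddings, a 512x512 clip)
    # are illustrative assumptions chosen only to exercise the forward pass.
    adapter = SparsePointAdapter()
    num_frames, num_points = 8, 4
    point_tracker = torch.rand(num_frames, num_points, 2) * 512   # (x, y) in pixels
    point_embedding = torch.randn(num_points, 1280)

    adapter.eval()
    with torch.no_grad():
        adapter_state = adapter(point_tracker, size=(512, 512), point_embedding=point_embedding)

    # one feature map per UNet level for the default config:
    # (8, 320, 64, 64), (8, 640, 32, 32), (8, 1280, 16, 16), (8, 1280, 8, 8)
    for feat in adapter_state:
        print(feat.shape)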