import random
from typing import List

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

# from videoswap.utils.registry import MODEL_REGISTRY


class MLP(nn.Module):
    """Two-layer MLP with SiLU activation that projects point embeddings to a target channel width."""
    def __init__(self, in_dim, out_dim, mid_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, mid_dim, bias=True),
            nn.SiLU(inplace=False),
            nn.Linear(mid_dim, out_dim, bias=True)
        )

    def forward(self, x):
        return self.mlp(x)


def bilinear_interpolation(level_adapter_state, x, y, frame_idx, interpolated_value):
    # Splat `interpolated_value` into `level_adapter_state` (frames, channels, h, w)
    # at the sub-pixel location (x, y) of frame `frame_idx`, distributing it over
    # the four neighbouring pixels with bilinear weights; indices are clamped to
    # the feature-map boundary.
    x1 = int(x)
    y1 = int(y)
    x2 = x1 + 1
    y2 = y1 + 1
    x_frac = x - x1
    y_frac = y - y1

    x1, x2 = max(min(x1, level_adapter_state.shape[3] - 1), 0), max(min(x2, level_adapter_state.shape[3] - 1), 0)
    y1, y2 = max(min(y1, level_adapter_state.shape[2] - 1), 0), max(min(y2, level_adapter_state.shape[2] - 1), 0)

    w11 = (1 - x_frac) * (1 - y_frac)
    w21 = x_frac * (1 - y_frac)
    w12 = (1 - x_frac) * y_frac
    w22 = x_frac * y_frac

    level_adapter_state[frame_idx, :, y1, x1] += interpolated_value * w11
    level_adapter_state[frame_idx, :, y1, x2] += interpolated_value * w21
    level_adapter_state[frame_idx, :, y2, x1] += interpolated_value * w12
    level_adapter_state[frame_idx, :, y2, x2] += interpolated_value * w22

    return level_adapter_state
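
# A quick worked example of the splatting above (values chosen for illustration):
# for (x, y) = (2.25, 3.5) the neighbours are (x1, y1) = (2, 3) and (x2, y2) = (3, 4),
# with x_frac = 0.25 and y_frac = 0.5, giving weights
# w11 = 0.375, w21 = 0.125, w12 = 0.375, w22 = 0.125 (they sum to 1).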


# @MODEL_REGISTRY.register()
class SparsePointAdapter(ModelMixin, ConfigMixin):
    """Maps sparse per-point embeddings to multi-scale spatial adapter states.

    For each UNet level, a small MLP projects every point embedding to that
    level's channel width, and the projected feature is splatted onto a
    zero-initialised feature map at the point's tracked location.
    """

    @register_to_config
    def __init__(
        self,
        embedding_channels=1280,
        channels=[320, 640, 1280, 1280],
        downsample_rate=[8, 16, 32, 64],
        mid_dim=128,
    ):
        super().__init__()

        self.model_list = nn.ModuleList()

        for ch in channels:
            self.model_list.append(MLP(embedding_channels, ch, mid_dim))

        self.downsample_rate = downsample_rate
        self.channels = channels
        self.radius = 2

    def generate_loss_mask(self, point_index_list, point_tracker, num_frames, h, w, loss_type):
        # the mask is defined at the first level's resolution: (num_frames, 4, h // rate, w // rate)
        if loss_type == 'global':
            # compute the loss over the entire feature map
            loss_mask = torch.ones((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0]))
        else:
            # only compute the loss in a window around visible points; the window
            # radius is independent of the downsampling scale
            loss_mask = torch.zeros((num_frames, 4, h // self.downsample_rate[0], w // self.downsample_rate[0]))
            for point_idx in point_index_list:
                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]

                    if px < 0 or py < 0:
                        # negative coordinates mark the point as invisible in this frame
                        continue
                    else:
                        px, py = px / self.downsample_rate[0], py / self.downsample_rate[0]

                        x1 = int(px) - self.radius
                        y1 = int(py) - self.radius
                        x2 = int(px) + self.radius
                        y2 = int(py) + self.radius

                        # clamp the window corners to the mask boundary
                        x1, x2 = max(min(x1, loss_mask.shape[3] - 1), 0), max(min(x2, loss_mask.shape[3] - 1), 0)
                        y1, y2 = max(min(y1, loss_mask.shape[2] - 1), 0), max(min(y2, loss_mask.shape[2] - 1), 0)

                        # mark the window (applied across all frames) as contributing to the loss
                        loss_mask[:, :, y1:y2, x1:x2] = 1.0
        return loss_mask

    def forward(self, point_tracker, size, point_embedding, index_list=None, drop_rate=0.0, loss_type='global') -> List[torch.Tensor]:
        # point_tracker: (frames, num_points, 2) pixel coordinates of the tracked points
        # size: (w, h) of the source frames
        # point_embedding: (num_points, embedding_channels) embedding of each point

        # # (1, frames, num_points, 2) -> (frames, num_points, 2)
        # point_tracker = point_tracker.squeeze(0)
        # # (1, num_points, 1280) -> (num_points, 1280)
        # point_embedding = point_embedding.squeeze(0)

        w, h = size
        num_frames, num_points = point_tracker.shape[:2]

        if self.training:
            # randomly drop a fraction of the points during training (point dropout)
            point_index_list = [point_idx for point_idx in range(num_points) if random.random() > drop_rate]
            loss_mask = self.generate_loss_mask(point_index_list, point_tracker, num_frames, h, w, loss_type)
        else:
            # at inference, optionally keep only the points selected in index_list
            point_index_list = [point_idx for point_idx in range(num_points) if index_list is None or point_idx in index_list]

        adapter_state = []
        for level_idx, module in enumerate(self.model_list):

            downsample_rate = self.downsample_rate[level_idx]
            level_w, level_h = w // downsample_rate, h // downsample_rate

            # e.g. (num_points, 1280) -> (num_points, 320) 
            point_feat = module(point_embedding)

            # zero-initialised feature map for this level, filled in by splatting point features
            level_adapter_state = torch.zeros((num_frames, self.channels[level_idx], level_h, level_w)).to(point_feat.device, dtype=point_feat.dtype)

            for point_idx in point_index_list:

                for frame_idx in range(num_frames):
                    px, py = point_tracker[frame_idx, point_idx]

                    if px < 0 or py < 0:
                        # the point is not visible in this frame
                        continue
                    else:
                        # scale to this level's resolution and splat the point feature
                        px, py = px / downsample_rate, py / downsample_rate
                        level_adapter_state = bilinear_interpolation(level_adapter_state, px, py, frame_idx, point_feat[point_idx])
            adapter_state.append(level_adapter_state)

        if self.training:
            return adapter_state, loss_mask
        else:
            return adapter_state
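

# Example usage: a minimal smoke test, not part of the original module. The
# concrete numbers (8 frames, 5 points, 512x512 frames) are illustrative
# assumptions only.
if __name__ == '__main__':
    adapter = SparsePointAdapter()
    adapter.eval()

    num_frames, num_points = 8, 5
    # pixel coordinates of each tracked point in every frame (512x512 frames)
    point_tracker = torch.rand(num_frames, num_points, 2) * 512
    # one 1280-dim semantic embedding per point
    point_embedding = torch.randn(num_points, 1280)

    with torch.no_grad():
        adapter_state = adapter(point_tracker, (512, 512), point_embedding)

    # expect one feature map per UNet level:
    # (8, 320, 64, 64), (8, 640, 32, 32), (8, 1280, 16, 16), (8, 1280, 8, 8)
    for state in adapter_state:
        print(tuple(state.shape))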