File size: 5,798 Bytes
85588e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from collections import defaultdict
from contextlib import contextmanager
from logging import getLogger
import math
import sys
from typing import List, Union, Iterable

import numpy as np
import torch
from torch import nn

from timm.models import VisionTransformer
from einops import rearrange

# Default number of consecutive windowed-attention blocks placed between two
# global-attention blocks (used as the default of VitDetArgs.num_windowed).
DEFAULT_NUM_WINDOWED = 5


class VitDetArgs:
    """Plain configuration container for the ViTDet windowing hooks.

    Attributes:
        window_size: side length, in patches, of each square attention window
            (a window therefore holds ``window_size ** 2`` patch tokens).
        num_summary_tokens: number of leading non-patch tokens (e.g. CLS) that
            are excluded from windowing.
        num_windowed: how many consecutive windowed blocks run before each
            global block.
    """

    def __init__(
        self,
        window_size: int,
        num_summary_tokens: int,
        num_windowed: int = DEFAULT_NUM_WINDOWED,
    ):
        self.window_size, self.num_summary_tokens, self.num_windowed = (
            window_size,
            num_summary_tokens,
            num_windowed,
        )


def apply_vitdet_arch(model: VisionTransformer, args: VitDetArgs):
    """Install ViTDet-style windowed attention hooks on a timm ``VisionTransformer``.

    The model itself is not modified; forward hooks are registered on its
    patch embedder and its ``blocks``.

    Args:
        model: the vision transformer to hook.
        args: windowing configuration.

    Returns:
        The ``ViTDetHook`` that owns the registered hooks, or ``None`` (after a
        warning on stderr) when ``model`` is not a ``VisionTransformer``.
    """
    if not isinstance(model, VisionTransformer):
        print('Warning: Unable to apply VitDet aug!', file=sys.stderr)
        return None

    # Prefer a replacement ``patch_generator`` when the model carries one and
    # only then fall back to ``patch_embed``. A ``getattr`` default would be
    # evaluated eagerly (raising if ``patch_embed`` were absent), and a
    # ``patch_generator`` set to ``None`` would otherwise be hooked directly.
    patch_embed = getattr(model, 'patch_generator', None)
    if patch_embed is None:
        patch_embed = model.patch_embed

    return ViTDetHook(patch_embed, model.blocks, args)


class ViTDetHook:
    """Run selected blocks of a ViT with windowed attention (ViTDet style).

    Everything is done through forward hooks; the wrapped modules are untouched:

    * a pre-hook on ``embedder`` records the spatial size of its input image;
    * a pre-hook on ``blocks`` permutes the tokens so that the patches of each
      ``window_size x window_size`` window are contiguous, which turns every
      later windowed <-> global switch into a plain reshape;
    * per-block pre-hooks fold the windows into the batch dimension
      (``_to_windows``) or unfold them again (``_to_global``): ``num_windowed``
      windowed blocks are followed by one global block, and the last block is
      always global;
    * a forward hook on ``blocks`` restores the original token order.

    The first ``num_summary_tokens`` tokens are set aside while in windowed
    mode and re-attached unchanged when switching back to global mode.

    NOTE(review): assumes tokens are laid out as (batch, tokens, channels) with
    the patch tokens in row-major grid order after the summary tokens, and
    that patches are square -- confirm against the embedder in use.
    """

    def __init__(self,
                 embedder: nn.Module,
                 blocks: nn.Sequential,
                 args: 'VitDetArgs',
    ):
        self.blocks = blocks
        self.num_summary_tokens = args.num_summary_tokens
        self.window_size = args.window_size

        self._input_resolution = None  # spatial size of the last embedder input
        self._num_windows = None  # windows per image at the current resolution
        self._cls_patch = None  # summary tokens parked during windowed blocks
        self._patch_order = None  # token permutation applied by _enter_blocks
        # (resolution, device) -> (patch_order, num_windows). The device is part
        # of the key because the cached index tensor lives on one device and
        # would break gather/scatter after the model is moved.
        self._order_cache = dict()
        self._handles = []  # handles of every hook registered here, see remove()

        if args.num_windowed < 1:
            # Nothing to window. Running the general scheme with
            # num_windowed == 0 would put a _to_windows hook on *every* block
            # and re-window an already windowed tensor, so register nothing.
            return

        self._handles.append(embedder.register_forward_pre_hook(self._enter_model))

        # This will decide if we window-fy the patches
        # and enable vit-det for this iteration, and if so,
        # rearrange the patches for efficient mode switching
        self._handles.append(blocks.register_forward_pre_hook(self._enter_blocks))

        is_global = True
        period = args.num_windowed + 1
        for i, layer in enumerate(blocks[:-1]):
            phase = i % period
            if phase == 0:
                self._handles.append(layer.register_forward_pre_hook(self._to_windows))
                is_global = False
            elif phase == args.num_windowed:
                self._handles.append(layer.register_forward_pre_hook(self._to_global))
                is_global = True

        # Always ensure the final layer is a global layer
        if not is_global:
            self._handles.append(blocks[-1].register_forward_pre_hook(self._to_global))

        self._handles.append(blocks.register_forward_hook(self._exit_model))

    def remove(self):
        """Unregister every hook installed by this object.

        Calling it more than once is harmless.
        """
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def _enter_model(self, _, input: List[torch.Tensor]):
        # Remember the spatial size of the image; the patch grid is derived
        # from it in _compute_patch_order.
        self._input_resolution = input[0].shape[-2:]

    def _enter_blocks(self, _, input: List[torch.Tensor]):
        patches = self._rearrange_patches(input[0])
        return (patches,) + input[1:]

    def _to_windows(self, _, input: List[torch.Tensor]):
        """Switch to windowed mode: (b, p*t, c) -> (b*p, t, c)."""
        patches = input[0]

        if self.num_summary_tokens:
            self._cls_patch = patches[:, :self.num_summary_tokens]
            patches = patches[:, self.num_summary_tokens:]

        # Tokens of one window are contiguous (see _rearrange_patches), so
        # each window simply becomes its own row of the batch. ``unflatten``
        # checks that the token count matches p * t.
        patches = patches.unflatten(
            1, (self._num_windows, self.window_size ** 2)
        ).flatten(0, 1)

        return (patches,) + input[1:]

    def _to_global(self, _, input: List[torch.Tensor]):
        """Switch back to global mode: (b*p, t, c) -> (b, p*t, c)."""
        patches = input[0]

        batch = patches.shape[0] // self._num_windows
        patches = patches.unflatten(0, (batch, self._num_windows)).flatten(1, 2)

        if self.num_summary_tokens:
            patches = torch.cat([self._cls_patch, patches], dim=1)
            # Drop the reference so the parked tensor (and any autograd graph
            # hanging off it) is not kept alive after the forward pass.
            self._cls_patch = None

        return (patches,) + input[1:]

    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
        # Undo the permutation applied in _enter_blocks: the token at permuted
        # position j goes back to position self._patch_order[j].
        index = self._patch_order.reshape(1, -1, 1).expand_as(patches)
        return torch.empty_like(patches).scatter(1, index, patches)

    def _rearrange_patches(self, patches: torch.Tensor):
        # We rearrange the patches so that we can efficiently
        # switch between windowed and global mode by just
        # reshaping the tensor
        key = (self._input_resolution, patches.device)
        cached = self._order_cache.get(key)
        if cached is None:
            cached = self._compute_patch_order(patches)
            self._order_cache[key] = cached

        self._patch_order, self._num_windows = cached

        index = self._patch_order.reshape(1, -1, 1).expand_as(patches)
        return torch.gather(patches, dim=1, index=index)

    def _compute_patch_order(self, patches: torch.Tensor):
        """Return ``(order, num_windows)`` for the current input resolution.

        ``order`` lists, for each new token position, the original token index:
        summary tokens keep their place, and patch tokens are regrouped from
        row-major grid order into window-major order.

        Raises:
            RuntimeError: if the patch grid cannot be tiled by ``window_size``.
        """
        num_feat_patches = patches.shape[1] - self.num_summary_tokens
        if num_feat_patches <= 0:
            raise RuntimeError(
                f'ViTDet windowing found no patch tokens '
                f'({patches.shape[1]} tokens, {self.num_summary_tokens} summary tokens)'
            )

        height, width = self._input_resolution[-2], self._input_resolution[-1]

        # Assumes square patches tiling the image.
        patch_size = int(round(math.sqrt(height * width / num_feat_patches)))
        rows = height // patch_size
        cols = width // patch_size

        ws = self.window_size
        w_rows = rows // ws
        w_cols = cols // ws

        if w_rows * w_cols * ws * ws != num_feat_patches:
            raise RuntimeError(
                f'ViTDet windowing requires the {rows}x{cols} patch grid '
                f'(input {height}x{width}, {num_feat_patches} patch tokens) '
                f'to be divisible by window_size={ws}'
            )

        # Row-major index = ((wy*ws + py) * cols) + wx*ws + px, i.e. the flat
        # index factors as (wy, py, wx, px); reorder to (wy, wx, py, px) so each
        # window's ws*ws patches are contiguous.
        order = torch.arange(num_feat_patches, device=patches.device)
        order = order.view(w_rows, ws, w_cols, ws).permute(0, 2, 1, 3).reshape(-1)

        if self.num_summary_tokens:
            order = torch.cat([
                torch.arange(self.num_summary_tokens, dtype=order.dtype, device=order.device),
                order + self.num_summary_tokens,
            ])

        return order, w_rows * w_cols