File size: 4,036 Bytes
d4ab5ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import Any, Callable, Optional, Union

from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import ViTForImageClassification


def _add_hooks(
    model: ViTForImageClassification, get_hook: Callable[[int], Callable[..., Any]]
) -> list[RemovableHandle]:
    """Adds a list of hooks to the model according to the get_hook function provided.

    With L = len(model.vit.encoder.layer), hooks are registered at L + 2 points
    and ``get_hook`` is called with the index of each point, in this order:

    * index 0: a forward hook on ``model.vit.embeddings.patch_embeddings``
      (called with that module's output);
    * indices 1..L: a forward *pre*-hook on encoder layer ``i - 1``
      (called with the positional inputs to that layer);
    * index L + 1: a forward hook on the last encoder layer
      (called with its output).

    If registration fails part-way (e.g. ``get_hook`` raises, or the encoder
    has no layers so ``layer[-1]`` fails), every hook registered so far is
    removed before the exception propagates, so the model is never left with
    a partial set of hooks attached.

    Args:
        model (ViTForImageClassification): the ViT instance to add hooks to
        get_hook (Callable): a function that takes an index and returns a hook

    Returns:
        a list of RemovableHandle instances, ordered by index (0 .. L + 1)
    """
    layers = model.vit.encoder.layer
    handles: list[RemovableHandle] = []
    try:
        handles.append(
            model.vit.embeddings.patch_embeddings.register_forward_hook(get_hook(0))
        )
        for i, layer in enumerate(layers, start=1):
            handles.append(layer.register_forward_pre_hook(get_hook(i)))
        handles.append(layers[-1].register_forward_hook(get_hook(len(layers) + 1)))
    except BaseException:
        # Roll back partial registration so hooks do not leak onto the model.
        for handle in handles:
            handle.remove()
        raise
    return handles


def vit_getter(
    model: ViTForImageClassification, x: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """Runs one forward pass and records the model's intermediate hidden states.

    Args:
        model (ViTForImageClassification): the ViT instance to run
        x (Tensor): the input passed to ``model``

    Returns:
        a tuple ``(logits, states)``; with L encoder layers, ``states`` holds
         L + 2 entries in order: the output of
         ``model.vit.embeddings.patch_embeddings``, the first positional input
         of each encoder layer, and the first element of the last encoder
         layer's output
    """
    captured = []
    depth = len(model.vit.encoder.layer)

    def make_hook(idx: int):
        def record(_: Module, inputs: tuple, outputs: Optional[tuple] = None):
            # Pick which tensor this hook point contributes; unknown indices
            # contribute nothing.
            if idx == 0:
                state = outputs
            elif 1 <= idx <= depth:
                state = inputs[0]
            elif idx == depth + 1:
                state = outputs[0]
            else:
                return
            captured.append(state)

        return record

    handles = _add_hooks(model, make_hook)
    try:
        logits = model(x).logits
    finally:
        # Always detach the hooks, even if the forward pass raised.
        while handles:
            handles.pop().remove()

    return logits, captured


def vit_setter(
    model: ViTForImageClassification, x: Tensor, hidden_states: list[Optional[Tensor]]
) -> tuple[Tensor, list[Tensor]]:
    """A function that sets some of the model's hidden states and returns its (new) logits
     and hidden states after another forward pass.

    With L = len(model.vit.encoder.layer), hooks are placed at L + 2 points.
    For each index ``i``, a non-None ``hidden_states[i]`` replaces the value at
    that point, and ``None`` lets the model compute it normally:

    * index 0: the output of ``model.vit.embeddings.patch_embeddings`` is
      replaced by ``hidden_states[0][:, 1:]`` (position 0 along dim 1 is
      dropped before substitution);
    * indices 1..L: the first positional input to encoder layer ``i - 1`` is
      replaced (its remaining positional inputs are passed through);
    * index L + 1: the first element of the last encoder layer's output tuple
      is replaced (its remaining elements are passed through).

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model
        hidden_states (list[Optional[Tensor]]): a list, with each element corresponding to
         a hidden state to set or None to calculate anew for that index; it must
         have at least L + 2 entries (extra trailing entries are ignored)

    Returns:
        a tuple of the model's logits and (new) hidden states, i.e. the value
         seen at each hook point (the substituted value where one was set, the
         computed value otherwise)

    Raises:
        IndexError: if ``hidden_states`` has fewer than L + 2 entries
    """
    num_layers = len(model.vit.encoder.layer)
    expected = num_layers + 2
    # Fail before the forward pass instead of from inside a hook mid-pass.
    if len(hidden_states) < expected:
        raise IndexError(
            f"hidden_states must have at least {expected} entries "
            f"(1 embedding + {num_layers} layer inputs + 1 final output), "
            f"got {len(hidden_states)}"
        )

    hidden_states_ = []

    def get_hook(i: int) -> Callable[..., Any]:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[tuple] = None
        ) -> Optional[Union[tuple, Tensor]]:
            override = hidden_states[i]
            if i == 0:
                # NOTE(review): vit_getter records the raw patch-embedding
                # output at index 0, but here the first position along dim 1 is
                # sliced off. Presumably callers pass a tensor with an extra
                # leading token (e.g. CLS) at this index — confirm intent.
                if override is not None:
                    hidden_states_.append(override[:, 1:])
                    return hidden_states_[-1]
                hidden_states_.append(outputs)
            elif 1 <= i <= num_layers:
                if override is not None:
                    hidden_states_.append(override)
                    # Swap the hidden-state input; keep the layer's other args.
                    return (override,) + inputs[1:]
                hidden_states_.append(inputs[0])
            elif i == num_layers + 1:
                if override is not None:
                    hidden_states_.append(override)
                    # Swap the hidden-state output; keep any extra outputs.
                    return (override,) + outputs[1:]
                hidden_states_.append(outputs[0])
            return None

        return hook

    handles = _add_hooks(model, get_hook)

    try:
        logits = model(x).logits
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for handle in handles:
            handle.remove()

    return logits, hidden_states_