Spaces:
Runtime error
Runtime error
File size: 4,036 Bytes
d4ab5ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import ViTForImageClassification
from typing import Optional, Union
def _add_hooks(
    model: ViTForImageClassification, get_hook: callable
) -> list[RemovableHandle]:
    """Adds a list of hooks to the model according to the get_hook function provided.

    Hook indices are assigned as follows, where ``n`` is the number of
    modules in ``model.vit.encoder.layer``:

    * ``0``: forward hook on ``model.vit.embeddings.patch_embeddings``
    * ``1 .. n``: forward *pre*-hook on each encoder layer, in order
    * ``n + 1``: forward hook on the last encoder layer

    If ``get_hook`` or any registration call raises part-way through, every
    hook registered so far is removed before the exception propagates, so a
    failed call never leaves stray hooks attached to the model.

    Args:
        model (ViTForImageClassification): the ViT instance to add hooks to
        get_hook (callable): a function that takes an index and returns a hook

    Returns:
        a list of RemovableHandle instances, ordered by hook index

    Raises:
        IndexError: if ``model.vit.encoder.layer`` is empty (there is no last
            layer to attach the final hook to)
    """
    handles: list[RemovableHandle] = []
    try:
        handles.append(
            model.vit.embeddings.patch_embeddings.register_forward_hook(get_hook(0))
        )
        layers = model.vit.encoder.layer
        for i, layer in enumerate(layers, start=1):
            handles.append(layer.register_forward_pre_hook(get_hook(i)))
        handles.append(layers[-1].register_forward_hook(get_hook(len(layers) + 1)))
    except Exception:
        # The caller never receives these handles, so nobody else could
        # remove the hooks that were already attached.
        for handle in handles:
            handle.remove()
        raise
    return handles
def vit_getter(
    model: ViTForImageClassification, x: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """A function that returns the logits and hidden states of the model.

    Hooks are attached for the duration of a single forward pass and always
    removed afterwards, even if the forward pass raises.

    The returned list is filled in the order the hooks fire. With ``n``
    encoder layers the hooks record:

    * hook ``0``: the output of ``model.vit.embeddings.patch_embeddings``
    * hooks ``1 .. n``: the first positional input of each encoder layer
    * hook ``n + 1``: the output of the last encoder layer

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model

    Returns:
        a tuple of the model's logits and hidden states
    """
    n_layers = len(model.vit.encoder.layer)
    hidden_states_ = []

    def get_hook(i: int) -> callable:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[Union[tuple, Tensor]] = None
        ):
            # Pre-hooks are called without `outputs`, hence the default.
            if i == 0:
                hidden_states_.append(outputs)
            elif 1 <= i <= n_layers:
                hidden_states_.append(inputs[0])
            elif i == n_layers + 1:
                # An encoder layer may return a tuple whose first element is
                # the hidden state, or the hidden-state tensor itself.
                # Indexing a bare tensor with [0] would silently pick the
                # first batch element instead.
                hidden_states_.append(
                    outputs[0] if isinstance(outputs, tuple) else outputs
                )

        return hook

    handles = _add_hooks(model, get_hook)
    try:
        logits = model(x).logits
    finally:
        for handle in handles:
            handle.remove()
    return logits, hidden_states_
def vit_setter(
    model: ViTForImageClassification, x: Tensor, hidden_states: list[Optional[Tensor]]
) -> tuple[Tensor, list[Tensor]]:
    """A function that sets some of the model's hidden states and returns its (new) logits
    and hidden states after another forward pass.

    ``hidden_states`` is indexed by hook index. With ``n`` encoder layers:

    * index ``0`` replaces the output of ``model.vit.embeddings.patch_embeddings``
    * indices ``1 .. n`` replace the first positional input of each encoder layer
    * index ``n + 1`` replaces the output of the last encoder layer

    A ``None`` entry leaves that position untouched and records whatever the
    model computed there. Hooks are always removed afterwards, even if the
    forward pass raises.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model
        hidden_states (list[Optional[Tensor]]): a list, with each element corresponding to
            a hidden state to set or None to calculate anew for that index;
            it is indexed for every hook index ``0 .. n + 1``

    Returns:
        a tuple of the model's logits and (new) hidden states
    """
    n_layers = len(model.vit.encoder.layer)
    hidden_states_ = []

    def get_hook(i: int) -> callable:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[Union[tuple, Tensor]] = None
        ) -> Optional[Union[tuple, Tensor]]:
            # Returning None from a hook leaves the module's input/output
            # as computed; returning a value substitutes it.
            if i == 0:
                if hidden_states[i] is None:
                    hidden_states_.append(outputs)
                    return None
                # The first position along dim 1 is dropped before the
                # tensor replaces the patch-embedding output.
                # NOTE(review): presumably this strips a CLS-token slot so the
                # shape matches the patch-only output; confirm that callers
                # pass a tensor with that extra leading position.
                hidden_states_.append(hidden_states[i][:, 1:])
                return hidden_states_[-1]

            if 1 <= i <= n_layers:
                if hidden_states[i] is None:
                    hidden_states_.append(inputs[0])
                    return None
                hidden_states_.append(hidden_states[i])
                # Swap only the first positional argument; keep the rest.
                return (hidden_states[i],) + inputs[1:]

            if i == n_layers + 1:
                # An encoder layer may return a tuple whose first element is
                # the hidden state, or the hidden-state tensor itself.
                is_tuple = isinstance(outputs, tuple)
                if hidden_states[i] is None:
                    hidden_states_.append(outputs[0] if is_tuple else outputs)
                    return None
                hidden_states_.append(hidden_states[i])
                if is_tuple:
                    return (hidden_states[i],) + outputs[1:]
                return hidden_states[i]

            return None

        return hook

    handles = _add_hooks(model, get_hook)
    try:
        logits = model(x).logits
    finally:
        for handle in handles:
            handle.remove()
    return logits, hidden_states_
|