from torch import Tensor
from torch.nn import Module
from torch.utils.hooks import RemovableHandle
from transformers import ViTForImageClassification
from typing import Any, Callable, Optional, Union


def _add_hooks(
    model: ViTForImageClassification, get_hook: Callable[[int], Callable[..., Any]]
) -> list[RemovableHandle]:
    """Register one hook per hidden-state index on ``model`` and return the handles.

    With ``n = len(model.vit.encoder.layer)`` the indices passed to ``get_hook`` are:

    * ``0`` -- forward hook on ``model.vit.embeddings.patch_embeddings``;
    * ``1 .. n`` -- forward *pre*-hook on encoder layer ``i - 1`` (sees its input);
    * ``n + 1`` -- forward hook on the last encoder layer (sees its output).

    Args:
        model (ViTForImageClassification): the ViT instance to add hooks to
        get_hook (Callable): a function that takes an index and returns the hook
            to register for that index

    Returns:
        a list of ``n + 2`` RemovableHandle instances in index order; the caller
        is responsible for calling ``remove()`` on each of them
    """
    layers = model.vit.encoder.layer
    handles = [model.vit.embeddings.patch_embeddings.register_forward_hook(get_hook(0))]
    handles.extend(
        layer.register_forward_pre_hook(get_hook(i + 1)) for i, layer in enumerate(layers)
    )
    handles.append(layers[-1].register_forward_hook(get_hook(len(layers) + 1)))
    return handles


def _layer_hidden(output: Union[tuple, Tensor]) -> Tensor:
    """Return the hidden-state tensor from an encoder layer's forward output.

    This code originally assumed the layer returns a tuple whose first element is
    the hidden state. NOTE(review): some transformers versions may return the
    Tensor itself -- confirm for the version in use. Indexing a bare Tensor with
    ``[0]`` would silently pick the first batch element, hence the explicit check.
    """
    return output[0] if isinstance(output, tuple) else output


def _run_with_hooks(
    model: ViTForImageClassification,
    x: Tensor,
    get_hook: Callable[[int], Callable[..., Any]],
) -> Tensor:
    """Run ``model(x)`` with the hooks produced by ``get_hook`` attached; return the logits.

    The hooks are always removed afterwards, even if the forward pass raises, so
    the model is left unmodified.
    """
    handles = _add_hooks(model, get_hook)
    try:
        return model(x).logits
    finally:
        for handle in handles:
            handle.remove()


def vit_getter(
    model: ViTForImageClassification, x: Tensor
) -> tuple[Tensor, list[Tensor]]:
    """Run a forward pass and return the logits together with recorded hidden states.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model

    Returns:
        a tuple ``(logits, hidden_states)``. ``hidden_states`` gets one entry per
        hook firing. With ``n`` encoder layers each called once, that is ``n + 2``
        entries:

        * index 0 is the output of ``patch_embeddings``;
        * indices ``1..n`` are the inputs to encoder layers ``0..n-1``;
        * index ``n + 1`` is the output of the last encoder layer.
    """
    n_layers = len(model.vit.encoder.layer)
    hidden_states_: list[Tensor] = []

    def get_hook(i: int) -> Callable[..., None]:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[Union[tuple, Tensor]] = None
        ) -> None:
            if i == 0:
                # Forward hook on patch_embeddings: record its output as-is.
                hidden_states_.append(outputs)
            elif 1 <= i <= n_layers:
                # Forward pre-hook: inputs[0] is the hidden state fed to the layer.
                hidden_states_.append(inputs[0])
            elif i == n_layers + 1:
                hidden_states_.append(_layer_hidden(outputs))

        return hook

    logits = _run_with_hooks(model, x, get_hook)
    return logits, hidden_states_


def vit_setter(
    model: ViTForImageClassification, x: Tensor, hidden_states: list[Optional[Tensor]]
) -> tuple[Tensor, list[Tensor]]:
    """Run a forward pass with selected hidden states overwritten.

    Returns the model's (new) logits and hidden states. Indices follow the same
    layout as ``vit_getter``. With ``n`` encoder layers:

    * index 0 replaces the output of ``patch_embeddings``;
    * indices ``1..n`` replace the inputs to encoder layers ``0..n-1``;
    * index ``n + 1`` replaces the output of the last encoder layer.

    NOTE(review): a tensor given for index 0 is sliced with ``[:, 1:]`` (its first
    position along dim 1 is dropped) before it replaces the patch-embedding output.
    The sliced tensor is what gets recorded. ``vit_getter`` records index 0
    unsliced, so passing its index-0 entry back unchanged would drop one position
    (presumably a real patch token). Confirm whether callers are meant to supply
    an extra leading position here.

    Args:
        model (ViTForImageClassification): the ViT instance to use for the forward pass
        x (Tensor): the input to the model
        hidden_states (list[Optional[Tensor]]): at least ``n + 2`` entries. Each
            element is the hidden state to set for that index, or None to
            calculate it anew.

    Returns:
        a tuple of the model's logits and (new) hidden states, recorded in the
        same layout that ``vit_getter`` returns

    Raises:
        IndexError: if ``hidden_states`` has fewer than ``n + 2`` entries
    """
    n_layers = len(model.vit.encoder.layer)
    # Fail before running the model rather than from inside a hook mid-forward.
    if len(hidden_states) < n_layers + 2:
        raise IndexError(
            f"hidden_states needs at least {n_layers + 2} entries "
            f"(one per hook index), got {len(hidden_states)}"
        )
    hidden_states_: list[Tensor] = []

    def get_hook(i: int) -> Callable[..., Optional[Union[tuple, Tensor]]]:
        def hook(
            _: Module, inputs: tuple, outputs: Optional[Union[tuple, Tensor]] = None
        ) -> Optional[Union[tuple, Tensor]]:
            # Returning None from a hook leaves the module's input/output unchanged;
            # returning a value replaces it.
            replacement = hidden_states[i]
            if i == 0:
                if replacement is None:
                    hidden_states_.append(outputs)
                    return None
                hidden_states_.append(replacement[:, 1:])
                return hidden_states_[-1]
            if 1 <= i <= n_layers:
                if replacement is None:
                    hidden_states_.append(inputs[0])
                    return None
                hidden_states_.append(replacement)
                # Pre-hook: replace only the hidden-state argument, keep the rest.
                return (replacement,) + inputs[1:]
            if i == n_layers + 1:
                if replacement is None:
                    hidden_states_.append(_layer_hidden(outputs))
                    return None
                hidden_states_.append(replacement)
                if isinstance(outputs, tuple):
                    # Keep any extra outputs (everything after the hidden state).
                    return (replacement,) + outputs[1:]
                return replacement
            return None

        return hook

    logits = _run_with_hooks(model, x, get_hook)
    return logits, hidden_states_