import logging import numpy as np import torch from torch import nn from enum import Enum, auto from transformers import BartModel, BartForConditionalGeneration, \ T5Model, T5ForConditionalGeneration, \ LEDModel, LEDForConditionalGeneration, \ AutoModelForCausalLM, AutoModelForSeq2SeqLM, \ MODEL_WITH_LM_HEAD_MAPPING, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING from typing import TypeVar, Generic from .index_building import Datastore, DatastoreBatch logger = logging.getLogger('Unlimiformer') logger.setLevel(20) ModelType = TypeVar('ModelType') class Unlimiformer(Generic[ModelType]): def __init__(self, model: ModelType, layer_begin=-1, layer_end=None, unlimiformer_head_num=None, exclude_attention=False, model_encoder_max_len=None, chunk_overlap=0, verbose=False, save_heatmap=False, tokenizer=None, unlimiformer_training=False, use_datastore=False, flat_index=False, test_datastore=False, reconstruct_embeddings=False, gpu_datastore=False, gpu_index=False, index_devices=(0,), datastore_device=0, ): super().__init__() self.model = model model.unlimiformer = self self.layer_begin = layer_begin self.layer_end = layer_end self.specific_head = unlimiformer_head_num self.exclude_attention = exclude_attention self.actual_model_window_size = None self.model_encoder_max_len = model_encoder_max_len self.chunk_overlap = chunk_overlap self.verbose = verbose self.save_heatmap = save_heatmap self.tokenizer = tokenizer self.unlimiformer_training = unlimiformer_training self.use_datastore = use_datastore self.flat_index = flat_index self.reconstruct_embeddings = reconstruct_embeddings self.gpu_datastore = gpu_datastore self.gpu_index = gpu_index # if torch.cuda.is_available() and gpu_index: # self.index_devices = [torch.device(f'cuda:{i}') for i in index_devices] # else: self.index_devices = [torch.device('cpu')] self.datastore_device = torch.device('cpu') self.test_datastore = test_datastore # flag for debugging self.device = torch.device('cpu') self.activation_capturer = None self.is_encoder_decoder = model.config.is_encoder_decoder self.hook_handles = [] self.is_input_encoding_pass = False self.is_first_test_decoding_step = False self.prev_tokens = None self.last_beam_idx = None self.heatmap = None self.cur_decoder_layer_index = None self.datastore = None self.break_into(model) def break_into(self, model): self.actual_model_window_size = self.window_size() if self.model_encoder_max_len is None: self.model_encoder_max_len = self.actual_model_window_size self.window_margin = int(self.model_encoder_max_len * self.chunk_overlap / 2) self.num_heads = model.config.num_attention_heads if self.specific_head is None: self.head_nums = Ellipsis # torch.arange(0, self.num_heads, device=self.device) else: self.head_nums = self.specific_head self.hooks_injected = False self.training_hooks_injected = False self.original_forward_func = model.forward # Activate Unlimiformer when calling model.eval(), deactivate for model.train() self.original_model_eval_func = model.eval model.eval = self.pre_eval_hook self.original_model_train_func = model.train model.train = self.pre_train_hook def pre_eval_hook(self): self.remove_training_hooks(self.model) self.inject_hooks(self.model) self.original_model_eval_func() def pre_train_hook(self, mode=True): # mode=True means model.train() is called # mode=False means model.eval() is called torch.cuda.empty_cache() if mode is True: self.break_out(self.model) if self.unlimiformer_training: self.inject_training_hooks(self.model) self.original_model_train_func(mode) def inject_hooks(self, model): if self.hooks_injected: return # Inject our activation_capturer to capture the activations at every forward pass attention_layers_to_capture = self.activation_to_capture(self.layer_begin, self.layer_end) self.activation_capturer = [] for layer in attention_layers_to_capture: if type(layer) is list: layer_capturers = [] for k_or_v in layer: capturer = ActivationCapturer(k_or_v, capture_input=False) layer_capturers.append(capturer) self.register_hook(k_or_v, capturer) self.activation_capturer.append(layer_capturers) else: capturer = ActivationCapturer(layer, capture_input=False) self.register_hook(layer, capturer) self.activation_capturer.append(capturer) # Inject our main function after the main attention function attention_layers_to_run = self.attention_op_to_run(self.layer_begin, self.layer_end) for layer in attention_layers_to_run: self.register_hook(layer, self.attention_forward_hook) decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end) self.original_decoder_layer_cross_attn_forward_funcs = [] for i, decoder_layer in enumerate(decoder_layers_to_run): decoder_layer_cross_attention = self.cross_attention(decoder_layer) self.original_decoder_layer_cross_attn_forward_funcs.append(decoder_layer_cross_attention.forward) decoder_layer_cross_attention.forward = self.create_cross_attn_pre_forward_hook(decoder_layer_cross_attention.forward, i) # Inject our hook function in the beginning of generation. # When the "model.generate()" will be called, it will first call our "reset_generation()" function, # and only then call "model.generate()" self.original_generate_func = model.generate model.generate = self.pre_generate_hook model.forward = self.pre_forward_hook self.original_reorder_cache_func = model._reorder_cache model._reorder_cache = self.reorder_cache_hook self.hooks_injected = True def inject_training_hooks(self, model): if self.training_hooks_injected: return # self.original_forward_func = model.forward model.forward = self.pre_forward_hook decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end) self.original_decoder_layer_self_attn_forward_funcs = [] for decoder_layer in decoder_layers_to_run: attention = self.self_attention(decoder_layer) self.original_decoder_layer_self_attn_forward_funcs.append(attention.forward) attention.forward = self.create_self_attn_pre_forward_hook(attention.forward) self.original_decoder_layer_cross_attn_forward_funcs = [] for i, decoder_layer in enumerate(decoder_layers_to_run): decoder_layer_cross_attention = self.cross_attention(decoder_layer) self.original_decoder_layer_cross_attn_forward_funcs.append(decoder_layer_cross_attention.forward) decoder_layer_cross_attention.forward = self.create_cross_attn_pre_forward_hook(decoder_layer_cross_attention.forward, i) self.original_decoder_layer_forward_funcs = [] for decoder_layer in decoder_layers_to_run: self.original_decoder_layer_forward_funcs.append(decoder_layer.forward) decoder_layer.forward = self.create_decoder_layer_func(decoder_layer.forward, decoder_layer) self.inject_hooks_for_unaffected_layers(model, decoder_layers_to_run) attention_layers_to_run = self.attention_op_to_run(self.layer_begin, self.layer_end) for layer in attention_layers_to_run: self.register_hook(layer, self.train_attention_forward_hook) self.training_hooks_injected = True def inject_hooks_for_unaffected_layers(self, model, decoder_layers_to_run): self.original_non_injected_decoder_layer_forward_funcs = [] non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None) if l not in decoder_layers_to_run] for decoder_layer in non_injected_decoder_layers: self.original_non_injected_decoder_layer_forward_funcs.append(decoder_layer.forward) decoder_layer.forward = self.create_noninjected_decoder_layer_func(decoder_layer.forward, decoder_layer) def create_self_attn_pre_forward_hook(self, original_self_attn_forward_func): def self_attention_pre_forward_hook(*args, **kwargs): kwargs['past_key_value'] = None return original_self_attn_forward_func(*args, **kwargs) return self_attention_pre_forward_hook def create_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer): def checkpointed_decoder_layer( hidden_states: torch.Tensor, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, layer_head_mask=None, cross_attn_layer_head_mask=None, past_key_value=None, output_attentions=False, position_bias=None, encoder_decoder_position_bias=None, use_cache=True): def forward_with_all_keys(hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, use_cache, long_inputs, long_inputs_mask, position_bias, encoder_decoder_position_bias): key, value = self.create_key_value(long_inputs, decoder_layer) decoder_layer_args = self.create_decoder_layer_args( hidden_states=hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, use_cache=use_cache, key=key,value=value) return decoder_layer_original_forward_func(**decoder_layer_args) return torch.utils.checkpoint.checkpoint( forward_with_all_keys, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, None, output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, position_bias, encoder_decoder_position_bias) return checkpointed_decoder_layer def create_noninjected_decoder_layer_func(self, decoder_layer_original_forward_func, decoder_layer): def checkpointed_decoder_layer( hidden_states: torch.Tensor, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, layer_head_mask=None, cross_attn_layer_head_mask=None, past_key_value=None, output_attentions=False, position_bias=None, encoder_decoder_position_bias=None, use_cache=True): def forward_with_all_keys(hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, use_cache, long_inputs, long_inputs_mask, position_bias, encoder_decoder_position_bias): decoder_layer_args = self.create_decoder_layer_args( hidden_states=hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, position_bias=position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias, use_cache=use_cache, key=None, value=None) return decoder_layer_original_forward_func(**decoder_layer_args) return torch.utils.checkpoint.checkpoint( forward_with_all_keys, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, None, output_attentions, use_cache, self.long_inputs_encoded, self.long_inputs_mask, position_bias, encoder_decoder_position_bias) return checkpointed_decoder_layer def register_hook(self, layer, func, pre=False): handle = layer.register_forward_pre_hook(func) if pre else layer.register_forward_hook(func) self.hook_handles.append(handle) def break_out(self, model): self.prompt_keys = [] self.prompt_values = [] self.prompt_attention_mask = [] self.generated_input_ids = [] torch.cuda.empty_cache() if not self.hooks_injected: return for h in self.hook_handles: h.remove() model.generate = self.original_generate_func model.forward = self.original_forward_func model._reorder_cache = self.original_reorder_cache_func decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end) for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs): self.cross_attention(decoder_layer).forward = original_func self.hooks_injected = False def remove_training_hooks(self, model): self.long_inputs_encoded, self.long_inputs_mask = None, None if not self.training_hooks_injected: return for h in self.hook_handles: h.remove() model.forward = self.original_forward_func decoder_layers_to_run = self.attention_layer_to_run(self.layer_begin, self.layer_end) for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_self_attn_forward_funcs): self.self_attention(decoder_layer).forward = original_func for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_cross_attn_forward_funcs): self.cross_attention(decoder_layer).forward = original_func for decoder_layer, original_func in zip(decoder_layers_to_run, self.original_decoder_layer_forward_funcs): decoder_layer.forward = original_func non_injected_decoder_layers = [l for l in self.attention_layer_to_run(0, None) if l not in decoder_layers_to_run] for decoder_layer, original_func in zip(non_injected_decoder_layers, self.original_non_injected_decoder_layer_forward_funcs): decoder_layer.forward = original_func self.training_hooks_injected = False def reset_memory(self, input_ids, attention_mask): if self.use_datastore: if self.is_encoder_decoder: self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index, gpu_index=self.gpu_index, index_device=self.index_devices[0])] self.hidden_states = [[]] else: self.datastore = [DatastoreBatch(dim=self.model.config.hidden_size, batch_size=input_ids.shape[0], flat_index=self.flat_index, gpu_index=self.gpu_index, index_device=self.index_devices[i % len(self.index_devices)]) for i in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]] self.hidden_states = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]] torch.cuda.empty_cache() self.prompt_input_ids = input_ids self.input_ids_size = input_ids.shape[-1] self.prompt_keys, self.prompt_values = None, None self.prev_tokens = [None for _ in range(len(self.original_decoder_layer_cross_attn_forward_funcs))] self.last_beam_idx = None self.cur_layer_key_value_placeholder = None self.is_input_encoding_pass = True if self.is_encoder_decoder: dummy_labels = torch.zeros((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) else: dummy_labels = None if self.save_heatmap: if self.heatmap is not None: print(f'Generated: {self.tokenizer.decode(self.generated_input_ids[0])}') self.plot_heatmap(self.heatmap[0].detach().cpu().numpy()) self.heatmap = torch.tensor([], dtype=torch.float, device=input_ids.device) self.generated_input_ids = torch.tensor([], dtype=torch.long, device=input_ids.device) self.prompt_keys = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]] self.prompt_values = [[] for _ in range(self.model.config.num_hidden_layers)[self.layer_begin:self.layer_end]] self.prompt_attention_mask = [] window_indices = self.window_indices(input_ids.shape[-1]) for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices: logger.info(f'Encoding {context_start_ind} to {context_end_ind} out of {input_ids.shape[-1]}') chunk = input_ids[:, context_start_ind:context_end_ind].to(self.device) chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind].to(self.device) with torch.inference_mode(): _ = self.model(chunk, attention_mask=chunk_attention_mask, labels=dummy_labels) # , return_dict=True, output_hidden_states=True) if self.use_datastore: # TODO: verify with BART as well # hidden_states_to_index = [hidden_states.encoder_last_hidden_state] # list of length 1 of (batch, chunked_source_len, dim) hidden_states_to_index = [ layer_capturer.captured for layer_capturer in self.activation_capturer ] # hidden_states_to_index = list(hidden_states.hidden_states)[:-1][self.layer_begin:self.layer_end] to_add = [state[:, update_start_ind:update_end_ind].detach() for state in hidden_states_to_index] to_apply_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # to_apply_mask = to_apply_mask.log().to(to_add[0].dtype) to_apply_mask = to_apply_mask.to(to_add[0].dtype) if not self.reconstruct_embeddings: to_add_embeddings = to_add if not self.gpu_datastore: to_add_embeddings = [states.cpu() for states in to_add_embeddings] to_apply_mask = to_apply_mask.cpu() for i, layer_states in enumerate(to_add_embeddings): layer_states = layer_states * to_apply_mask.unsqueeze(-1) self.hidden_states[i].append(layer_states.to(self.datastore_device)) # list of len layers, inside it there is a list of len batch, each item is (masked_time, dim) # for i, to_add_layer in enumerate(to_add): # keys = [key[mask.bool()] for key, mask in zip(to_add_layer, to_apply_mask)] # self.datastore[i].add_keys(keys) if (not self.use_datastore) or self.test_datastore: layers_kv = [ self.process_key_value(layer_capturer) # (batch, head, time, dim) for layer_capturer in self.activation_capturer ] # list of pairs of (batch, head, time, dim) # list of (batch, head, chunked_time, dim) key = [layer[0][:, :, update_start_ind:update_end_ind] for layer in layers_kv] value = [layer[1][:, :, update_start_ind:update_end_ind] for layer in layers_kv] chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time) # key = torch.stack(key, dim=0) # (num_layers, batch, head, time, dim) # value = torch.stack(value, dim=0) # (num_layers, batch, head, time, dim) for i, (layer_key, layer_value) in enumerate(zip(key, value)): self.prompt_keys[i].append(layer_key) # (num_layers, batch, head, chunked_source_len, dim) self.prompt_values[i].append(layer_value) # (num_layers, batch, head, chunked_source_len, dim) self.prompt_attention_mask.append(chunk_attention_mask) # (batch, chunked_source_len) if self.use_datastore: # keys are all in datastore already! if not self.reconstruct_embeddings: # self.hidden_states = [torch.cat(layer_hidden_states, axis=1) for layer_hidden_states in self.hidden_states] concat_hidden_states = [] for i in range(len(self.hidden_states)): concat_hidden_states.append(torch.cat(self.hidden_states[i], axis=1)) self.hidden_states[i] = None self.hidden_states = concat_hidden_states for datastore, layer_hidden_states in zip(self.datastore, self.hidden_states): datastore.train_index(layer_hidden_states) if (not self.use_datastore) or self.test_datastore: for i, (layer_keys, layer_values) in enumerate(zip(self.prompt_keys, self.prompt_values)): self.prompt_keys[i] = torch.cat(layer_keys, dim=-2) self.prompt_values[i] = torch.cat(layer_values, dim=-2) # self.prompt_keys = torch.cat(self.prompt_keys, dim=-2) # (num_layers, batch, head, source_len, dim) # self.prompt_values = torch.cat(self.prompt_values, dim=-2) # (num_layers, batch, head, source_len, dim) self.prompt_attention_mask = torch.cat(self.prompt_attention_mask, dim=-1) # (batch, source_len) self.is_input_encoding_pass = False if self.verbose: print(f'Input: ' f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| ' f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}') print() def chunked_encode_input(self, input_ids, attention_mask): long_inputs_encoded = [] long_inputs_mask = [] window_indices = self.window_indices(input_ids.shape[-1]) self.is_input_encoding_pass = True for context_start_ind, context_end_ind, update_start_ind, update_end_ind in window_indices: chunk = input_ids[:, context_start_ind:context_end_ind] chunk_attention_mask = attention_mask[:, context_start_ind:context_end_ind] output = self.model.base_model.encoder(chunk, attention_mask=chunk_attention_mask, return_dict=True, output_hidden_states=True) encoder_last_hidden_state = output.last_hidden_state # (batch, time, dim) # list of (batch, head, chunked_time, dim) encoder_last_hidden_state = encoder_last_hidden_state[:, update_start_ind:update_end_ind] # (batch, chunked_time, dim) chunk_attention_mask = chunk_attention_mask[:, update_start_ind:update_end_ind] # (batch, chunked_time) long_inputs_encoded.append(encoder_last_hidden_state) # (batch, chunked_source_len, dim) long_inputs_mask.append(chunk_attention_mask) # (batch, chunked_source_len) long_inputs_encoded = torch.cat(long_inputs_encoded, dim=1) # (batch, source_len, dim) long_inputs_mask = torch.cat(long_inputs_mask, dim=1) # (batch, source_len) self.is_input_encoding_pass = False if self.verbose: print(f'Input: ' f'{self.tokenizer.decode(input_ids[0][:self.actual_model_window_size], skip_special_tokens=True)} ||| ' f'{self.tokenizer.decode(input_ids[0][self.actual_model_window_size:], skip_special_tokens=True)}') print() return long_inputs_encoded, long_inputs_mask def window_indices(self, total_seq_len): # Copied from SLED (Ivgy et al., 2022) # https://github.com/Mivg/SLED/blob/main/sled/modeling_sled.py#L467 if total_seq_len <= self.model_encoder_max_len: return [(0, total_seq_len, 0, total_seq_len)] else: results = [] # if self.chunk_overlap == 0: # stride = self.model_encoder_max_len stride = self.model_encoder_max_len - 2 * self.window_margin context_start = update_start_ind = 0 context_end = self.model_encoder_max_len if self.is_encoder_decoder: update_end_ind = context_end - self.window_margin else: update_end_ind = context_end # first window always should update from the beginning results.append((context_start, context_end, update_start_ind, update_end_ind)) while context_end < total_seq_len: context_end = min(total_seq_len, context_end + stride) context_start = ( context_start + stride if context_end < total_seq_len else total_seq_len - self.model_encoder_max_len ) update_start_ind = max(update_start_ind + stride, update_end_ind) # last window always should update until the end update_end_ind = ( min(total_seq_len, update_end_ind + stride) if context_end < total_seq_len else total_seq_len ) cs, ce, us, ue = context_start, context_end, update_start_ind - context_start, \ update_end_ind - context_start results.append((cs, ce, us, ue)) return results def pre_generate_hook(self, input_ids, **kwargs): if 'attention_mask' not in kwargs: kwargs['attention_mask'] = torch.ones_like(input_ids) self.reset_memory(input_ids, kwargs['attention_mask']) new_kwargs = kwargs if 'attention_mask' in kwargs: new_kwargs = {k: v for k, v in kwargs.items() if k != 'attention_mask'} new_kwargs['attention_mask'] = kwargs['attention_mask'][:, :self.actual_model_window_size].to(self.device) new_kwargs['use_cache'] = True if self.is_encoder_decoder: input_ids_prefix = input_ids[:, :self.actual_model_window_size] else: input_ids_prefix = input_ids[:, -self.actual_model_window_size:] input_ids_prefix = input_ids_prefix.to(self.device) return self.original_generate_func(input_ids_prefix, **new_kwargs) def pre_forward_hook(self, input_ids=None, attention_mask=None, labels=None, **kwargs): self.set_gradient_checkpointing(False) if not self.is_input_encoding_pass: if self.model.training: # self.reset_memory(input_ids, attention_mask) self.long_inputs_encoded, self.long_inputs_mask = self.chunked_encode_input(input_ids=input_ids, attention_mask=attention_mask) input_ids = input_ids[:, :self.actual_model_window_size] attention_mask = attention_mask[:, :self.actual_model_window_size] if attention_mask is not None else None # input_ids = input_ids[:, :self.model_encoder_max_len] # labels = labels[:, :self.model_encoder_max_len] if labels is not None else None else: if kwargs.get('past_key_values') is None: self.is_first_test_decoding_step = True if input_ids is not None: # self.input_ids_size += input_ids.shape[-1] self.input_ids_size += 1 if kwargs.get('decoder_input_ids') is not None: self.generated_input_ids = torch.cat([self.generated_input_ids, kwargs['decoder_input_ids']], axis=-1) result = self.original_forward_func(input_ids=input_ids, labels=labels, attention_mask=attention_mask, **kwargs) self.is_first_test_decoding_step = False return result def create_cross_attn_pre_forward_hook(self, original_cross_attn_forward_func, cur_layer_num): def attention_pre_forward_hook(hidden_states, attention_mask=None, *args, **kwargs): self.cur_decoder_layer_index = cur_layer_num if kwargs.get('past_key_value') is not None: # it's a tuple, and we convert it to a list to be able to perform assignment # and modify its items from our attention_forward_hook self.cur_layer_key_value_placeholder = \ kwargs['past_key_value'] = list(kwargs['past_key_value']) # (batch, head, time, attn_dim) batch_size, tgt_len, dim = hidden_states.shape if self.model.training: # from: (batch, tgt_len, dim) to: (batch * tgt_len, 1, dim) hidden_states = hidden_states.reshape(-1, 1, hidden_states.shape[-1]) # from: (batch, 1, tgt_len, dim) to: (batch * tgt_len, 1, 1, dim) attention_mask = attention_mask.reshape(-1, 1, 1, attention_mask.shape[-1]) attn_output, attn_weights_reshaped, past_key_value = original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs) attn_output = attn_output.reshape(batch_size, tgt_len, dim) result = (attn_output, attn_weights_reshaped, past_key_value) else: result = original_cross_attn_forward_func(hidden_states=hidden_states, attention_mask=attention_mask, *args, **kwargs) # Uri: this part adds the generated tokens to the prompt. # However it was commented out because currently we always keep the generated tokens in the attention window # if not self.is_encoder_decoder and not self.is_input_encoding_pass and \ # past_key_value[0].shape[2] > self.prompt_keys[self.cur_decoder_layer_index].shape[2]: # self.prompt_keys[self.cur_decoder_layer_index] = torch.cat([self.prompt_keys[self.cur_decoder_layer_index], past_key_value[0][:,:,-1:]], dim=-2) # self.prompt_values[self.cur_decoder_layer_index] = torch.cat([self.prompt_values[self.cur_decoder_layer_index], past_key_value[1][:,:,-1:]], dim=-2) # if self.cur_decoder_layer_index == self.model.config.num_hidden_layers - 1: # self.prompt_attention_mask = torch.cat([ # self.prompt_attention_mask, # torch.ones([self.prompt_attention_mask.shape[0], 1], dtype=self.prompt_attention_mask.dtype).to(self.device)], dim=-1) return result return attention_pre_forward_hook def attention_forward_hook(self, module, input, output): # output: (batch, time, 3 * heads * attention_dim) if self.is_input_encoding_pass or self.is_first_test_decoding_step: return with torch.no_grad(): prompt_size = self.prompt_input_ids.shape[1] generated_size = self.input_ids_size - prompt_size window_size = self.cur_layer_key_value_placeholder[0].shape[-2] # topk = min(self.actual_model_window_size, attn_weights.shape[-1]) topk = min(prompt_size, window_size) if not self.is_encoder_decoder: topk = min(topk, window_size - generated_size + 1) if self.gpu_index: topk = min(topk, 2048) query = self.process_query(output)[:,-1] # (batch * beam, head, dim) query = query[:, self.head_nums] # (batch * beam, head, dim) if self.use_datastore: # query: (batch, beam, head, dim) # need to multiply by key vector # query.view(query.shape[0], query.shape[1] * query.shape[2]) # k_proj in attention? datastore_index = 0 if self.is_encoder_decoder else self.cur_decoder_layer_index attention_layer_list = self.get_kv_projections(self.layer_begin, self.layer_end) k_proj_layer = [layers[0] for layers in attention_layer_list][self.cur_decoder_layer_index] v_proj_layer = [layers[1] for layers in attention_layer_list][self.cur_decoder_layer_index] # modify query by k_projs k_proj = k_proj_layer.weight datastore_query = self.preprocess_query(query, k_proj) # (batch * beam, num_heads, embed_dim) batch_size = self.datastore[datastore_index].batch_size datastore_query = datastore_query.view((batch_size, -1, datastore_query.shape[2])) # (batch, beam * num_heads, embed_dim) # then search if self.reconstruct_embeddings: # embeddings: (batch, beam * head, actual_model_window_size, dim) _, top_search_key_indices, embeddings = self.datastore[datastore_index].search_and_reconstruct(datastore_query, k=topk) else: _, top_search_key_indices = self.datastore[datastore_index].search(datastore_query, k=topk) # self.embeddings: (batch, src_len, dim) # indices: (batch, beam * head, actual_model_window_size) # embeddings: (batch, beam * head, actual_model_window_size, dim) embeddings = torch.take_along_dim(input=self.hidden_states[datastore_index].unsqueeze(1), indices=top_search_key_indices.unsqueeze(-1).to(self.hidden_states[datastore_index].device), dim=-2) embeddings = embeddings.to(self.device) # (batch, beam, head, actual_model_window_size) # top_search_key_scores = top_search_key_scores.reshape(batch_size, -1, *top_search_key_scores.shape[1:]) top_search_key_indices = top_search_key_indices.reshape(batch_size, -1, *top_search_key_indices.shape[1:]) # embeddings: (batch, beam, head, actual_model_window_size, dim) embeddings = embeddings.reshape(batch_size, -1, self.num_heads, *embeddings.shape[2:]) # raw_values are actually token indices; need to look them up if (not self.use_datastore) or self.test_datastore: this_layer_prompt_keys = self.prompt_keys[self.cur_decoder_layer_index] this_layer_prompt_values = self.prompt_values[self.cur_decoder_layer_index] # query: (batch * beam, head, dim) batch_size = self.prompt_input_ids.shape[0] beam_size = query.shape[0] // batch_size # query: (batch, beam, head, dim) query = query.reshape(batch_size, beam_size, *query.shape[1:]) # this_layer_prompt_keys: (batch, head, source_len, dim) # this_layer_prompt_keys.unsqueeze(1): (batch, 1, head, source_len, dim) # query.unsqueeze(-1): (batch, beam, head, dim, 1) # attn_weights: (batch, beam, head, source_len) attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums], query.unsqueeze(-1)).squeeze(-1) # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums]).squeeze(-2) prompt_attention_mask_to_add = (1 - self.prompt_attention_mask) * -1e9 # (batch, source_len) prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1) attn_weights += prompt_attention_mask_to_add # (batch, beam, head, source_len) if self.exclude_attention and attn_weights.shape[-1] > self.actual_model_window_size: attn_weights[..., :self.actual_model_window_size] -= 1e9 # target_keys, target_values, topk = self.get_target_slices(output) top_key_scores, top_key_indices = torch.topk(attn_weights, k=topk, dim=-1, sorted=True) # (batch, beam, head, trunc_source) if self.save_heatmap: # heatrow: (beam, heads, source_len) heatrow = torch.zeros([top_key_indices.shape[1], top_key_indices.shape[2], this_layer_prompt_keys.shape[-2]], dtype=torch.float) heatrow = heatrow.scatter(index=top_key_indices[0], src=torch.ones_like(top_key_scores[0]), dim=-1) # heatrow = torch.nn.functional.softmax(heatrow, dim=-1) # self.heatmap: (beam, heads, targets, source_len) self.heatmap = torch.cat([self.heatmap, heatrow.unsqueeze(-2)], axis=-2) if self.test_datastore: assert top_key_indices.shape == top_search_key_indices.shape assert torch.mean((top_key_indices == top_search_key_indices).float()) > 0.99 if self.verbose: if self.is_encoder_decoder: for i, beam in enumerate(self.generated_input_ids): print(f'({i}) Generated: {self.tokenizer.decode(beam)}') # else: # print(f'Generated: {self.tokenizer.decode(self.input_ids)}') print() if self.use_datastore: # k_proj_layer.weight, v_proj_layer.weight: (embed_dim, embed_dim) # embeddings: (batch, beam, head, encoder_len, embed_dim) retrieved_keys, retrieved_values = self.post_process_retrieved(embeddings, k_proj_layer, v_proj_layer, top_search_key_indices) else: # this_layer_prompt_keys: (batch, head, source_len, dim) # top_key_indices: (batch, beam, head, trunc_source) retrieved_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1), dim=-2) # (batch, head, trunc_source, attn_dim) retrieved_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1), dim=-2) # (batch, head, trunc_source, attn_dim) if self.test_datastore: correct_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(1), indices=top_key_indices.unsqueeze(-1), dim=-2) # (batch, head, trunc_source, attn_dim) correct_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(1), indices=top_key_indices.unsqueeze(-1), dim=-2) # (batch, head, trunc_source, attn_dim) assert correct_keys.shape == retrieved_keys.shape assert correct_values.shape == retrieved_values.shape assert torch.mean(torch.isclose(correct_keys, retrieved_keys, rtol=1e-3, atol=1e-3).float()) > 0.99 assert torch.mean(torch.isclose(correct_values, retrieved_values, rtol=1e-3, atol=1e-3).float()) > 0.99 # retrieved_keys, retrieved_values: (batch * beam, head, encoder_len, attn_dim) retrieved_keys = retrieved_keys.flatten(0, 1)[:,:,:topk] retrieved_values = retrieved_values.flatten(0, 1)[:,:,:topk] self.cur_layer_key_value_placeholder[0] = torch.cat([retrieved_keys, self.cur_layer_key_value_placeholder[0][:,:,topk:]], dim=-2) self.cur_layer_key_value_placeholder[1] = torch.cat([retrieved_values, self.cur_layer_key_value_placeholder[1][:,:,topk:]], dim=-2) return def train_attention_forward_hook(self, module, input, output): # output: (batch, time, 3 * heads * attention_dim) if self.is_input_encoding_pass or self.is_first_test_decoding_step: return this_layer_prompt_keys = self.cur_layer_key_value_placeholder[0] this_layer_prompt_values = self.cur_layer_key_value_placeholder[1] with torch.no_grad(): query = self.process_query(output) # (batch * beam, tgt_len, head, dim) # query = query[:, :, self.head_nums] # (batch * beam, head, dim) # query: (batch * beam, tgt_len, head, dim) batch_size = this_layer_prompt_keys.shape[0] tgt_len = query.shape[0] // batch_size # query: (batch, tgt, head, dim) query = query.reshape(batch_size, tgt_len, *query.shape[2:]) # this_layer_prompt_keys: (batch, head, source_len, dim) # this_layer_prompt_keys.unsqueeze(1): (batch, 1, head, source_len, dim) # attn_weights: (batch, tgt_len, head, 1, source_len) # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1).permute(0,1,2,4,3)) attn_weights = torch.matmul(this_layer_prompt_keys.unsqueeze(1), query.unsqueeze(-1)) \ .reshape(batch_size, tgt_len, query.shape[-2], 1, this_layer_prompt_keys.shape[-2]) # attn_weights = torch.matmul(query.unsqueeze(-2), this_layer_prompt_keys.unsqueeze(1)[:, :, self.head_nums]).squeeze(-2) prompt_attention_mask_to_add = (1 - self.long_inputs_mask) * -1e9 # (batch, source_len) prompt_attention_mask_to_add = prompt_attention_mask_to_add.unsqueeze(1).unsqueeze(1).unsqueeze(1) attn_weights += prompt_attention_mask_to_add # (batch, beam, head, source_len) # target_keys, target_values, topk = self.get_target_slices(output) topk = min(self.actual_model_window_size, attn_weights.shape[-1]) top_key_scores, top_key_indices = torch.topk(attn_weights, k=min(topk, attn_weights.shape[-1]), dim=-1, sorted=True) # (batch, beam, head, tgt, trunc_source) # this_layer_prompt_keys: (batch, head, source_len, dim) # top_key_indices: (batch, tgt_len, head, 1, trunc_source) new_keys = torch.take_along_dim(this_layer_prompt_keys.unsqueeze(2).unsqueeze(1), indices=top_key_indices.unsqueeze(-1), dim=-2) # (batch, tgt_len, head, 1, trunc_source, attn_dim) new_values = torch.take_along_dim(this_layer_prompt_values.unsqueeze(2).unsqueeze(1), indices=top_key_indices.unsqueeze(-1), dim=-2) # (batch, tgt_len, head, 1, trunc_source, attn_dim) # (batch * beam, head, tgt_len, trunc_source, attn_dim) self.cur_layer_key_value_placeholder[0] = new_keys.flatten(0, 1).squeeze(2) self.cur_layer_key_value_placeholder[1] = new_values.flatten(0, 1).squeeze(2) return def preprocess_query(self, query, k_proj_weight): k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0]) # (1, num_heads, attn_dim, embed_dim) datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim) datastore_query = torch.matmul(datastore_query, k_proj) # (batch * beam, num_heads, 1, embed_dim) datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim) return datastore_query def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices): embed_dim = embeddings.shape[-1] k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, 1, heads, embed_dim, attn_dim) k_bias = 0 if k_proj_layer.bias is not None: k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0) v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, heads, embed_dim, attn_dim) v_bias = 0 if v_proj_layer.bias is not None: v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0) # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim) retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias # (beam, head, encoder_len, embed_dim) retrieved_values = torch.matmul(embeddings, v_weight) + v_bias # (beam, head, encoder_len, embed_dim) return retrieved_keys, retrieved_values def set_gradient_checkpointing(self, value): self.model.base_model.decoder.gradient_checkpointing = value def reorder_cache_hook(self, past, beam_idx): self.last_beam_idx = beam_idx self.generated_input_ids = self.generated_input_ids[beam_idx] for i, layer_prev_tokens in enumerate(self.prev_tokens): if layer_prev_tokens is not None: self.prev_tokens[i] = layer_prev_tokens.flatten(0, 1)[beam_idx].reshape(layer_prev_tokens.shape) if self.save_heatmap and self.heatmap.numel() > 0: self.heatmap = self.heatmap[beam_idx] return self.original_reorder_cache_func(past, beam_idx) @classmethod def convert_model(cls, model, *args, **kwargs): # if type(model.config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: # elif type(model.config) in MODEL_WITH_LM_HEAD_MAPPING: # else: # raise ValueError(f'Unsupported model type: {type(model.config)}') # if model.config.is_encoder_decoder: # model_clone = AutoModelForSeq2SeqLM.from_config(model.config) # else: # model_clone = AutoModelForCausalLM.from_config(model.config) # model_clone.load_state_dict(model.state_dict()).to(args.device) type_to_class = { BartModel: UnlimiformerBART, BartForConditionalGeneration: UnlimiformerBART, T5Model: UnlimiformerT5, T5ForConditionalGeneration: UnlimiformerT5, LEDModel: UnlimiformerLED, LEDForConditionalGeneration: UnlimiformerLED, # LlamaModel: UnlimiformerLLaMa, # LlamaForCausalLM: UnlimiformerLLaMa, } type_to_class[type(model)](model, *args, **kwargs) return model def plot_heatmap(self, data, xticklabels='auto', yticklabels='auto'): # data: (heads, targets, source_len) import seaborn as sb import matplotlib.pyplot as plt # print('gat = np.array([') # for row in data[0]: # rowstr = ', '.join([f'{x:.2f}' for x in row]) # print(f' [{rowstr}],') # print(']') # sb.set(font_scale=1.5, rc={'text.usetex': True}) for i in range(data.shape[0]): fig, axes = plt.subplots(1, 1, figsize=(40, 100)) cur_ax = axes axes.set_title(f'Head #{i}, length: {data.shape[2]}, target length: {data.shape[1]}') cur_ax = axes # annot = [[x for x in row] for row in data] ax = sb.heatmap(data[i], annot=False, fmt='.2f', xticklabels=512, yticklabels=yticklabels, ax=cur_ax) ax.xaxis.tick_top() plt.savefig(f'knns_head{i}.pdf') # plt.savefig('gat_s10_contrast.pdf') plt.show() class UnlimiformerBART(Unlimiformer[BartModel]): def __init__(self, model: BartModel, *args, **kwargs): super().__init__(model, *args, **kwargs) def create_key_value(self, encoder_hidden_states, decoder_layer): # (batch, time, hidden_dim) attention = decoder_layer.encoder_attn # key, value: (batch, heads, time, attn_dim) key = attention.k_proj(encoder_hidden_states) key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() value = attention.v_proj(encoder_hidden_states) value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() # key, value: (batch, heads, time, attn_dim) return key, value def process_key_value(self, capturers): key_capturer, value_capturer = capturers key, value = key_capturer.captured, value_capturer.captured # (batch, time, heads, attn_dim) attention = self.model.base_model.decoder.layers[-1].encoder_attn # query, key, value: (batch, heads, time, attn_dim) # query = query.view(query.shape[0], query.shape[1], attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() return key, value def process_query(self, output): # (batch, time, heads, attn_dim) attention = self.model.base_model.decoder.layers[-1].encoder_attn # query: (batch, heads, time, attn_dim) # query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous() return query def get_kv_projections(self, layer_begin, layer_end): return [ [layer.encoder_attn.k_proj, layer.encoder_attn.v_proj] for layer in self.model.base_model.decoder.layers[layer_begin:layer_end] ] def activation_to_capture(self, layer_begin, layer_end): if self.use_datastore: return [self.model.base_model.encoder.layers[-1]] else: return self.get_kv_projections(layer_begin, layer_end) def attention_op_to_run(self, layer_begin, layer_end): return [ layer.encoder_attn.q_proj for layer in self.model.base_model.decoder.layers[layer_begin:layer_end] ] def attention_layer_to_run(self, layer_begin, layer_end): return self.model.base_model.decoder.layers[layer_begin:layer_end] def self_attention(self, decoder_layer): return decoder_layer.self_attn def cross_attention(self, decoder_layer): return decoder_layer.encoder_attn def window_size(self): return self.model.config.max_position_embeddings def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, position_bias, encoder_decoder_position_bias, use_cache, key, value): args = {'hidden_states': hidden_states, 'attention_mask': attention_mask, 'encoder_hidden_states': encoder_hidden_states, 'encoder_attention_mask': encoder_attention_mask, 'layer_head_mask': layer_head_mask, 'cross_attn_layer_head_mask': cross_attn_layer_head_mask, 'past_key_value': (None, None, key, value), 'output_attentions': output_attentions, 'use_cache': use_cache,} if key is None and value is None: args['past_key_value'] = None return args class UnlimiformerT5(Unlimiformer[T5Model]): def __init__(self, model: T5Model, *args, **kwargs): super().__init__(model, *args, **kwargs) def create_key_value(self, encoder_hidden_states, decoder_layer): # (batch, time, hidden_dim) attention = decoder_layer.layer[1].EncDecAttention # key, value: (batch, heads, time, attn_dim) key = attention.k(encoder_hidden_states) key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous() value = attention.v(encoder_hidden_states) value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous() return key, value def process_key_value(self, capturers): key_capturer, value_capturer = capturers key, value = key_capturer.captured, value_capturer.captured # (batch, time, heads, attn_dim) attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention # query, key, value: (batch, heads, time, attn_dim) # query = query.view(query.shape[0], query.shape[1], attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() key = key.view(key.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous() value = value.view(value.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).transpose(1, 2).contiguous() return key, value def process_query(self, output): # (batch, time, heads, attn_dim) attention = self.model.base_model.decoder.block[-1].layer[1].EncDecAttention # query: (batch, heads, time, attn_dim) query = output.view(output.shape[0], -1, attention.n_heads, attention.key_value_proj_dim).contiguous() return query def get_kv_projections(self, layer_begin, layer_end): return [ [layer.layer[1].EncDecAttention.k, layer.layer[1].EncDecAttention.v] for layer in self.model.base_model.decoder.block[layer_begin:layer_end] ] def activation_to_capture(self, layer_begin, layer_end): if self.use_datastore: return [self.model.base_model.encoder.layers[-1]] else: return self.get_kv_projections(layer_begin, layer_end) def attention_op_to_run(self, layer_begin, layer_end): return [ layer.layer[1].EncDecAttention.q for layer in self.model.base_model.decoder.block[layer_begin:layer_end] ] def attention_layer_to_run(self, layer_begin, layer_end): return self.model.base_model.decoder.block[layer_begin:layer_end] def self_attention(self, decoder_layer): return decoder_layer.layer[0] def cross_attention(self, decoder_layer): return decoder_layer.layer[1] def window_size(self): try: size = self.model.config.n_positions except AttributeError: size = 1024 return size def create_decoder_layer_args(self, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, cross_attn_layer_head_mask, past_key_value, output_attentions, position_bias, encoder_decoder_position_bias, use_cache, key, value): args = {'hidden_states': hidden_states, 'attention_mask': attention_mask, 'position_bias': position_bias, 'encoder_hidden_states': encoder_hidden_states, 'encoder_attention_mask': encoder_attention_mask, 'encoder_decoder_position_bias': encoder_decoder_position_bias, 'layer_head_mask': layer_head_mask, 'cross_attn_layer_head_mask': cross_attn_layer_head_mask, 'past_key_value': (None, None, key, value), 'use_cache': use_cache, 'output_attentions': output_attentions} if key is None and value is None: args['past_key_value'] = None return args class UnlimiformerLED(UnlimiformerBART): def __init__(self, model: LEDModel, *args, **kwargs): super().__init__(model, *args, **kwargs) def window_size(self): return self.model.config.max_encoder_position_embeddings # class UnlimiformerLLaMa(Unlimiformer[LlamaModel]): # def __init__(self, model: LlamaModel, *args, **kwargs): # super().__init__(model, *args, **kwargs) # def get_kv_projections(self, layer_begin, layer_end): # return [ # [layer.self_attn.k_proj, layer.self_attn.v_proj] # for layer in self.model.base_model.layers[layer_begin:layer_end] # ] # def activation_to_capture(self, layer_begin, layer_end): # if self.use_datastore: # return [ # layer.input_layernorm # for layer in self.model.base_model.layers[layer_begin:layer_end] # ] # else: # return self.get_kv_projections(layer_begin, layer_end) # def attention_op_to_run(self, layer_begin, layer_end): # return [ # layer.self_attn.q_proj # for layer in self.model.base_model.layers[layer_begin:layer_end] # ] # def attention_layer_to_run(self, layer_begin, layer_end): # return self.model.base_model.layers[layer_begin:layer_end] # def self_attention(self, decoder_layer): # return decoder_layer.self_attn # def cross_attention(self, decoder_layer): # return decoder_layer.self_attn # def window_size(self): # return self.model.config.max_position_embeddings # def set_gradient_checkpointing(self, value): # self.model.base_model.gradient_checkpointing = value # def process_key_value(self, capturers): # key_capturer, value_capturer = capturers # # (batch, time, heads * attn_dim) # key, value = key_capturer.captured, value_capturer.captured # attention = self.model.base_model.layers[-1].self_attn # # (batch, heads, time, attn_dim) # key = key.view(key.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() # value = value.view(value.shape[0], -1, attention.num_heads, attention.head_dim).transpose(1, 2).contiguous() # return key, value # def process_query(self, output): # # output: (batch, time, heads * attn_dim) # attention = self.model.base_model.layers[-1].self_attn # # query: (batch, time, heads, attn_dim) # query = output.view(output.shape[0], output.shape[1], attention.num_heads, attention.head_dim).contiguous() # return query # def rotate_half(self, x): # """Rotates half the hidden dims of the input.""" # x1 = x[..., : x.shape[-1] // 2] # x2 = x[..., x.shape[-1] // 2 :] # return torch.cat((-x2, x1), dim=-1) # def preprocess_query(self, query, k_proj_weight): # # query: (batch * time, head, dim) # attention = self.model.base_model.layers[-1].self_attn # num_generated = min(self.input_ids_size - self.prompt_input_ids.shape[1], self.actual_model_window_size) # cos, sin = attention.rotary_emb(query, seq_len=num_generated) # cos = cos[:,:,-1] # [1, 1, dim] # sin = sin[:,:,-1] # [1, 1, dim] # # cos = cos[-1].unsqueeze(0).unsqueeze(0) # [bs, 1, seq_len, dim] # # sin = sin[-1].unsqueeze(0) # [bs, 1, seq_len, dim] # query = (query * cos) + (self.rotate_half(query) * sin) # k_proj = k_proj_weight.view(1, self.num_heads, query.shape[-1], k_proj_weight.shape[0]) # (1, num_heads, attn_dim, embed_dim) # k_proj_l = k_proj[..., :k_proj.shape[-2] // 2, :] # k_proj_r = k_proj[..., k_proj.shape[-2] // 2:, :] # k_proj_rotated = torch.cat([-k_proj_l, k_proj_r], dim=-2) # datastore_query = query.unsqueeze(-2) # (batch * beam, num_heads, 1, attn_dim) # datastore_query = torch.matmul(datastore_query, k_proj + k_proj_rotated) # (batch * beam, num_heads, 1, embed_dim) # datastore_query = datastore_query.squeeze(-2) # (batch * beam, num_heads, embed_dim) # return datastore_query # def post_process_retrieved(self, embeddings, k_proj_layer, v_proj_layer, top_search_key_indices): # embed_dim = embeddings.shape[-1] # k_weight = k_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, 1, heads, embed_dim, attn_dim) # k_bias = 0 # if k_proj_layer.bias is not None: # k_bias = k_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0) # v_weight = v_proj_layer.weight.view(1, 1, self.num_heads, embed_dim // self.num_heads, embed_dim).transpose(-2,-1) # (1, heads, embed_dim, attn_dim) # v_bias = 0 # if v_proj_layer.bias is not None: # v_bias = v_proj_layer.bias.view(1, self.num_heads, embed_dim // self.num_heads).unsqueeze(-2).unsqueeze(0) # # new_keys, new_values: (batch, beam, head, encoder_len, attn_dim) # retrieved_keys = torch.matmul(embeddings, k_weight) + k_bias # (beam, head, encoder_len, embed_dim) # retrieved_values = torch.matmul(embeddings, v_weight) + v_bias # (beam, head, encoder_len, embed_dim) # attention = self.model.base_model.layers[-1].self_attn # cos, sin = attention.rotary_emb(retrieved_values, seq_len=self.hidden_states[0].shape[1]) # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] # if self.prompt_input_ids.shape[1] > self.actual_model_window_size: # # scale the top key indices to the actual model window size, such that the model will not see # # positional embeddings that did not appear at training time # scaled_key_indices = ((top_search_key_indices / self.prompt_input_ids.shape[1]) * self.actual_model_window_size).int() # else: # scaled_key_indices = top_search_key_indices # # top_search_key_indices = top_search_key_indices.to(cos.device) # scaled_key_indices = scaled_key_indices.to(cos.device) # cos = cos[scaled_key_indices] # [bs, 1, seq_len, dim] # sin = sin[scaled_key_indices] # [bs, 1, seq_len, dim] # retrieved_keys = (retrieved_keys * cos) + (self.rotate_half(retrieved_keys) * sin) # return retrieved_keys, retrieved_values class ActivationCapturer(nn.Module): def __init__(self, layer, capture_input=False): super().__init__() self.layer = layer self.capture_input = capture_input self.captured = None def unwrap_tuple(self, t): if isinstance(t, tuple) and len(t) == 1: t = t[0] return t def forward(self, module, layer_input, layer_output): if self.capture_input: self.captured = self.unwrap_tuple(layer_input) else: self.captured = self.unwrap_tuple(layer_output)