maxoul committed
Commit d8009b5 · verified · 1 Parent(s): 60719d2

Upload COCOM

Files changed (4)
  1. adapters.pth +3 -0
  2. config.json +27 -9
  3. decoder_first_last_layers.pth +3 -0
  4. modelling_pisco.py +886 -120
adapters.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32e26734db991e2270703d5b113f3be8df8aa292e78bb762287711abb8fbbb5e
+ size 168063670
config.json CHANGED
@@ -1,18 +1,36 @@
 {
-  "_name_or_path": "/scratch/1/user/mlouis/calmar/pisco_hub_models/pisco-llama",
-  "architectures": [
-    "PISCO"
-  ],
   "auto_map": {
-    "AutoConfig": "modelling_pisco.PISCOConfig",
-    "AutoModel": "modelling_pisco.PISCO"
   },
   "compr_rate": 16,
   "decoder_model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "device_map": "auto",
   "lora_r": 16,
-  "model_type": "PISCO",
   "sep": true,
-  "torch_dtype": "bfloat16",
-  "transformers_version": "4.44.2"
 }

 {
+  "_attn_implementation_autoset": true,
+  "ae_mode": "token",
+  "attn_implementation": null,
   "auto_map": {
+    "AutoConfig": "modelling_pisco.COCOMConfig",
+    "AutoModel": "modelling_pisco.COCOM"
   },
+  "compr_base_model_name": "mistralai/Mistral-7B-Instruct-v0.2",
+  "compr_every_n_layer": null,
+  "compr_mlp_hidden_dim": 1024,
+  "compr_mode": "last_in_mask",
+  "compr_model_name": null,
+  "compr_n_layers": null,
   "compr_rate": 16,
+  "compr_rms_norm": false,
+  "compr_use_mlp": true,
   "decoder_model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
   "device_map": "auto",
+  "different_mem_tokens": true,
+  "doc_max_length": 128,
+  "generation_top_k": 1,
+  "kbtc_training": false,
+  "load_adapters": true,
+  "lora": true,
+  "lora_compressor": false,
   "lora_r": 16,
+  "lora_r_compressor": 16,
+  "max_new_tokens": 128,
+  "model_type": "COCOM",
+  "optimize_mem_tokens": true,
+  "quantization": "no",
   "sep": true,
+  "training_form": "both_separately",
+  "transformers_version": "4.48.0"
 }
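The new config switches the exported class from PISCO to COCOM and records the compression setup: with "doc_max_length": 128 and "compr_rate": 16, each document is compressed into 128 // 16 = 8 memory-token embeddings. A minimal loading sketch (not part of this commit; the repository id is a placeholder, and a CUDA machine with flash-attn is assumed since the default attn_implementation is flash_attention_2), relying on the auto_map entries above so that transformers resolves modelling_pisco.COCOM:

from transformers import AutoModel

# Placeholder repo id; substitute the repository this commit belongs to.
model = AutoModel.from_pretrained(
    "namespace/this-repo",
    trust_remote_code=True,   # required: COCOM is custom code referenced via auto_map
)
print(type(model).__name__)   # expected: COCOM
print(model.n_mem_tokens)     # 128 // 16 = 8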
decoder_first_last_layers.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0135c46cccdf74403def3c03c233887918a7a9ba1c2c0ad6ee6db8ce21ef418
+ size 2101528196
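Both .pth files above are Git LFS pointers to plain torch state dicts; COCOM.from_pretrained in modelling_pisco.py below fetches and applies them on top of the base decoder. A condensed sketch of that load path (not part of the commit), assuming the files have been downloaded to a local directory named ckpt_dir (placeholder):

import os
import torch

ckpt_dir = "ckpt_dir"  # placeholder: local download of this repository
map_location = torch.device("cpu") if not torch.cuda.is_available() else None

# decoder_first_last_layers.pth: embedding and lm_head weights, which changed when
# the memory / special tokens were added to the decoder tokenizer.
first_last = torch.load(os.path.join(ckpt_dir, "decoder_first_last_layers.pth"),
                        map_location=map_location, weights_only=True)

# adapters.pth: one LoRA state dict per adapter name ('decoder_adapter', 'encoder_adapter').
adapters = torch.load(os.path.join(ckpt_dir, "adapters.pth"),
                      map_location=map_location, weights_only=True)
print(list(adapters.keys()))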
modelling_pisco.py CHANGED
@@ -1,9 +1,383 @@
1
  import warnings
2
  import os
3
  import torch
4
- from peft import LoraConfig
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PretrainedConfig, AutoConfig, GenerationConfig
 
6
  from jinja2.exceptions import TemplateError
7
 
8
 
9
  def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor, n_mem_tokens: int, tokenizer):
@@ -21,108 +395,288 @@ def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.T
21
  return input_ids, attention_mask
22
 
23
 
24
- class PISCOConfig(PretrainedConfig):
25
 
26
- model_type = "PISCO"
27
  def __init__(self,
28
  decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
29
- compr_rate: int = 16,
30
  **kwargs):
31
  super().__init__(**kwargs)
32
 
33
  self.decoder_model_name = decoder_model_name # model name of decoder
 
 
 
 
 
34
  self.compr_rate = compr_rate # compression rate
35
- self.lora_r = 16
36
- self.sep = True
 
37
 
 
38
 
39
- class PISCO(PreTrainedModel):
40
- config_class = PISCOConfig
 
 
 
 
 
 
41
  def __init__(self, cfg):
42
  super().__init__(cfg)
43
  self.decoder_model_name = cfg.decoder_model_name
44
- self.sep = cfg.sep
45
- self.compr_rate = cfg.compr_rate
46
-
47
- self.create_tokenizer(cfg)
48
-
49
- # Base model config but we modify vocab size since we added tokens (mainly the mem tokens)
50
- decoder_config = AutoConfig.from_pretrained(cfg.decoder_model_name)
51
- decoder_config.vocab_size = len(self.tokenizer)
52
 
53
- # Initializing placeholder model:
54
- self.decoder = AutoModelForCausalLM.from_config(decoder_config,
55
- attn_implementation='flash_attention_2',
56
- torch_dtype=torch.bfloat16)
57
 
58
- peft_config = self.get_peft_config(cfg)
59
 
 
 
 
60
  self.adapter_keys = []
61
- self.decoder.add_adapter(peft_config, 'decoder_adapter')
62
- self.decoder.set_adapter('decoder_adapter')
63
- self.adapter_keys.append('decoder_adapter')
64
- self.decoder.add_adapter(peft_config, 'encoder_adapter')
65
- self.adapter_keys.append('encoder_adapter')
66
-
67
- self.generation_config = GenerationConfig(do_sample=False, top_p=None)
68
-
69
- def create_tokenizer(self, cfg):
70
- self.tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
71
-
72
- n_mem_tokens = 128 // cfg.compr_rate
73
- mem_tokens = ['<MEM' + str(i) + '>' for i in range(n_mem_tokens)]
74
- self.tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens + ['<AE>', '<ENC>', '<SEP>']})
75
- self.tokenizer.mem_tokens = mem_tokens
 
76
 
77
- self.tokenizer.mem_token_ids = [self.tokenizer.convert_tokens_to_ids(elt) for elt in self.tokenizer.mem_tokens]
78
- self.tokenizer.mem_token_ids_pt = torch.LongTensor(self.tokenizer.mem_token_ids) # required later on for operations on tensors
 
 
 
 
 
 
 
79
 
80
- self.tokenizer.ae_token = '<AE>' # token for autoencoding on decoder side
81
- self.tokenizer.ae_token_id = self.tokenizer.convert_tokens_to_ids('<AE>')
82
- self.tokenizer.enc_token = '<ENC>' # token for autoencoding on compressor side
83
- self.tokenizer.sep_token = '<SEP>' # sep token between document
84
- self.tokenizer.sep_token_id = self.tokenizer.convert_tokens_to_ids('<SEP>')
85
 
86
- # if pad token exists then use pad token, othrwise bos token
87
- if self.tokenizer.pad_token_id is None:
88
- self.tokenizer.pad_token_id = self.tokenizer.bos_token_id
89
 
90
  def set_all_adapters(self):
91
  if len(self.adapter_keys) > 0:
92
  self.decoder.set_adapter(self.adapter_keys)
 
 
 
 
93
 
94
- def get_peft_config(self, cfg: PISCOConfig) -> LoraConfig:
95
  """
96
  Builds the peft config
97
  """
98
- return LoraConfig(task_type="CAUSAL_LM", r=cfg.lora_r, lora_alpha=2* cfg.lora_r, target_modules='all-linear', lora_dropout=0.1)
99
 
100
  def compress(self, enc_input_ids, enc_attention_mask):
101
- return self.compr_decoder(enc_input_ids, enc_attention_mask)
 
 
 
102
 
103
  def replace_emb(self, compressed_embs, dec_input_ids):
104
  """
105
- Create an input embedding vector combining the compressed_embs and the dec_input_ids
106
  """
107
  indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
108
-
109
- input_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
110
- num_embs = compressed_embs.size(1)
111
- if self.sep:
112
- slot_len = num_embs + 1
113
- else:
114
- slot_len = num_embs
115
- # get first mem_token indices
116
- first_mem_token_indices = torch.argmax((dec_input_ids == self.tokenizer.mem_token_ids[0]).int(), dim=1)
117
- batch_size = input_embeds.size(0)
118
- # for each example in batch, replace them with compressed embeddings
119
- for i in range(batch_size):
120
- for j in range(indices[i], indices[i + 1]):
121
- start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
122
- assert input_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
123
- f"{input_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
124
- input_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
125
-
126
  return input_embeds
127
 
128
  def compr_decoder(self, input_ids, attention_mask):
@@ -134,28 +688,68 @@ class PISCO(PreTrainedModel):
134
  # Switch adapter if we are training two different ones:
135
  if 'encoder_adapter' in self.adapter_keys:
136
  self.decoder.set_adapter('encoder_adapter')
137
-
138
  emb = self.decoder(input_ids=input_ids,
139
  attention_mask=attention_mask,
140
  output_hidden_states=True).hidden_states[-1]
141
- mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
142
  return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
143
 
144
- def prepare_encoder_inputs_to_decoder(self, texts, max_length):
145
- inp_enc = [self.tokenizer.enc_token + self.tokenizer.bos_token + text + self.tokenizer.eos_token for text in texts]
146
- inp_enc = self.tokenizer(inp_enc, return_tensors='pt', padding="longest", max_length=max_length+3, truncation=True, add_special_tokens=False)
147
- num_mem_tokens = 128 // self.compr_rate # hardcode size
148
- assert num_mem_tokens == len(self.tokenizer.mem_tokens)
 
 
 
 
 
 
149
  inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
150
  inp_enc['attention_mask'],
151
  num_mem_tokens,
152
- tokenizer=self.tokenizer)
153
 
154
  return inp_enc
155
 
156
- def prepare_encoder_inputs(self, texts, max_length):
157
- return self.prepare_encoder_inputs_to_decoder(texts, max_length)
 
 
 
 
 
 
 
 
158
 
159
  def forward(self,
160
  enc_input_ids: torch.LongTensor = None,
161
  enc_attention_mask: torch.LongTensor = None,
@@ -185,6 +779,10 @@ class PISCO(PreTrainedModel):
185
  compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
186
  inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
187
 
 
 
 
 
188
  # decoding
189
  if 'decoder_adapter' in self.adapter_keys:
190
  self.decoder.set_adapter('decoder_adapter')
@@ -195,7 +793,179 @@ class PISCO(PreTrainedModel):
195
  self.set_all_adapters()
196
 
197
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
198
-
199
  def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
200
  """
201
  Generates answers from documents (via compression then decoding)
@@ -216,7 +986,7 @@ class PISCO(PreTrainedModel):
216
 
217
  # Creating decoder inputs
218
  instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
219
- inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
220
  model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
221
 
222
  # Generation
@@ -233,7 +1003,7 @@ class PISCO(PreTrainedModel):
233
 
234
  # Creating decoder inputs
235
  instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
236
- inp_dec = self.tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
237
  device = self.decoder.device
238
  dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
239
 
@@ -252,7 +1022,7 @@ class PISCO(PreTrainedModel):
252
  )
253
 
254
  # de-tokenizing
255
- return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
256
 
257
  def compress_documents(self, documents: list[str]) -> torch.Tensor:
258
  """
@@ -262,46 +1032,14 @@ class PISCO(PreTrainedModel):
262
  enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
263
  attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
264
  return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
265
-
266
- def generate(self, model_input, max_new_tokens=128):
267
- """
268
- Generation pipeline including compression + decoding from compressed
269
- """
270
-
271
- enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
272
-
273
- assert enc_input_ids.size() == enc_attention_mask.size()
274
-
275
- if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
276
- batch_size, top_k, seq_length = enc_input_ids.size()
277
- enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
278
- enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
279
-
280
- # Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
281
- assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
282
- f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
283
-
284
- compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
285
- inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
286
-
287
- if 'decoder_adapter' in self.adapter_keys:
288
- self.decoder.set_adapter('decoder_adapter')
289
-
290
- output_ids = self.decoder.generate(
291
- inputs_embeds=inputs_embeds,
292
- attention_mask=dec_attention_mask,
293
- generation_config=self.generation_config,
294
- max_new_tokens=max_new_tokens
295
- )
296
-
297
- return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
298
-
299
  def blend_prompt_and_memory_tokens(self, query: str):
300
  """
301
  Takes care of blending the prompt with the memory tokens:
302
  Also returns, if a label is provided, the position of the first token index of the label (for loss comp later on)
 
303
  """
304
- mem_tokens_str = ''.join(self.tokenizer.mem_tokens) + self.tokenizer.sep_token
305
 
306
  # proper names for "eval" call, don't remove these lines
307
  docs = mem_tokens_str * self.generation_top_k
@@ -318,7 +1056,7 @@ class PISCO(PreTrainedModel):
318
 
319
  # Attempt to apply the system role and catch if it's not supported
320
  try:
321
- prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
322
 
323
  except TemplateError as e:
324
  # Catch the error related to system role and handle it (e.g. gemma)
@@ -326,9 +1064,37 @@ class PISCO(PreTrainedModel):
326
  # Remove system role and proceed with only the user role
327
  messages = [{"role": "user", "content": messages[0]['content'] + '\n' + messages[1]['content']}]
328
  # Apply template again without system role
329
- prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
330
  else:
331
  # Re-raise the exception if it's unrelated to system role
332
  raise e
333
 
334
  return prompt
 
1
  import warnings
2
  import os
3
  import torch
4
+ import gc
5
+
6
+ from torch import nn
7
  from jinja2.exceptions import TemplateError
8
+ from peft import LoraConfig
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, PretrainedConfig, AutoModel, AutoConfig
10
+ from huggingface_hub import hf_hub_download
11
+
12
+
13
+ def get_first_layers_model(base_model_name: str, n_layers: int, attn_implementation: str = 'flash_attention_2'):
14
+ """
15
+ Builds a model comprising only the n_layers first layer of the base_model_name
16
+ (it keeps the embedding and head layers)
17
+ """
18
+ full_model = AutoModelForCausalLM.from_pretrained(base_model_name)
19
+
20
+ # Create a new config for a model with fewer layers (e.g., 3 layers)
21
+ custom_config = AutoConfig.from_pretrained(base_model_name)
22
+ custom_config.num_hidden_layers = n_layers
23
+ first_layers_model = AutoModelForCausalLM.from_config(config=custom_config, attn_implementation=attn_implementation, torch_dtype=torch.bfloat16)
24
+
25
+ # Load the state dict of the full model
26
+ full_state_dict = full_model.state_dict()
27
+ custom_state_dict = first_layers_model.state_dict()
28
+ kept_state_dict = {k:v for k,v in full_state_dict.items() if k in custom_state_dict}
29
+
30
+ first_layers_model.load_state_dict(kept_state_dict, strict=True)
31
+
32
+ del full_model
33
+ torch.cuda.empty_cache()
34
+ gc.collect()
35
+
36
+ return first_layers_model
37
+
38
+
39
+ def get_every_n_layer_model(base_model_name: str, every_n_layer: int, attn_implementation: str = 'flash_attention_2'):
40
+ """
41
+ Builds a model comprising 1 every every_n_layer layer of the base_model_name
42
+ (it keeps the embedding and head layers)
43
+ """
44
+ full_model = AutoModelForCausalLM.from_pretrained(base_model_name)
45
+ n_kept_layers = full_model.config.num_hidden_layers // every_n_layer
46
+
47
+ print(f'New model with 1/{every_n_layer} from {base_model_name} will have {n_kept_layers} layers')
48
+
49
+ custom_config = AutoConfig.from_pretrained(base_model_name)
50
+ custom_config.num_hidden_layers = n_kept_layers
51
+ custom_model = AutoModelForCausalLM.from_config(config=custom_config,
52
+ attn_implementation=attn_implementation,
53
+ torch_dtype=torch.bfloat16)
54
+ full_state_dict = full_model.state_dict()
55
+ custom_state_dict = custom_model.state_dict()
56
+
57
+ # Filter out every Nth layer and rename to form a new state dict
58
+ kept_state_dict = {}
59
+ for key, value in full_state_dict.items():
60
+ if ".layers." in key:
61
+ # Extract layer index
62
+ layer_idx = int(key.split(".layers.")[1].split(".")[0])
63
+ # Check if it's an Nth layer
64
+ if layer_idx % every_n_layer == 0:
65
+ # Adjust layer index for the smaller model
66
+ new_layer_idx = layer_idx // every_n_layer
67
+ # print('replacing', f".layers.{layer_idx}.", f".layers.{new_layer_idx}.")
68
+ new_key = key.replace(f".layers.{layer_idx}.", f".layers.{new_layer_idx}.")
69
+ if new_key in custom_state_dict:
70
+ kept_state_dict[new_key] = value
71
+ else:
72
+ # Keep non-layer-specific parameters
73
+ if key in custom_state_dict:
74
+ kept_state_dict[key] = value
75
+
76
+ # Load the filtered state dict into the custom model
77
+ custom_model.load_state_dict(kept_state_dict, strict=True)
78
+
79
+ del full_model
80
+ torch.cuda.empty_cache()
81
+ gc.collect()
82
+
83
+ return custom_model
84
+
85
+
86
+ class MistralTrimmed(torch.nn.Module):
87
+ """
88
+ Trimmed version of base models for faster compression
89
+ NB: despite the name, 'MistralTrimmed' is not Mistral-specific: most causal LLMs are supported.
90
+ """
91
+ def __init__(self,
92
+ n_layers: int = 15,
93
+ every_n_layer: int = None,
94
+ rms_norm: bool = False,
95
+ base_model_name: str = 'mistralai/Mistral-7B-Instruct-v0.2',
96
+ attn_implementation: str = 'flash_attention_2'):
97
+ """
98
+ you can either specify
99
+ - n_layers to some number: we take the n_layers first layers of the base model.
100
+ - every_n_layer to some number: in that case we take 1/N layer of the base model
101
+ The base_model_name is the name of the model from which this model is built.
102
+ """
103
+ assert (n_layers is None) ^ (every_n_layer is None), 'Cannot specify both n_layers and every_n_layer for MistralTrimmed'
104
+ super().__init__()
105
+
106
+ self.n_layers = n_layers
107
+ self.every_n_layer = every_n_layer
108
+ self.base_model_name = base_model_name
109
+
110
+ if n_layers is not None:
111
+ self.custom_model = get_first_layers_model(self.base_model_name,
112
+ n_layers,
113
+ attn_implementation=attn_implementation)
114
+
115
+ else:
116
+ self.custom_model = get_every_n_layer_model(self.base_model_name,
117
+ every_n_layer,
118
+ attn_implementation=attn_implementation)
119
+
120
+ self.custom_model = self.custom_model.bfloat16()
121
+ self.custom_model.cuda()
122
+
123
+ if rms_norm:
124
+ print('Compressor keeps its original rms norm')
125
+ else:
126
+ print('De-activating RMS norm in compressor')
127
+ # We deactivate the norm: we don't need it here since we want to manipulate stuff within embed space
128
+ # see https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/models/mistral/modeling_mistral.py#L699
129
+ self.custom_model.model.norm = nn.Identity()
130
+
131
+ # Piping useful methods:
132
+ self.add_adapter = self.custom_model.add_adapter
133
+ self.set_adapter = self.custom_model.set_adapter
134
+ self.load_adapter = self.custom_model.load_adapter
135
+ self.num_parameters = self.custom_model.num_parameters
136
+ self.resize_token_embeddings = self.custom_model.resize_token_embeddings
137
+ self.get_input_embeddings = self.custom_model.get_input_embeddings
138
+ self.get_adapter_state_dict = self.custom_model.get_adapter_state_dict
139
+
140
+ # self.custom_model.gradient_checkpointing_enable()
141
+
142
+ # del self.custom_model.lm_head # THIS FAILS since some models have tie_embeddings=True !
143
+ # gc.collect()
144
+ # torch.cuda.empty_cache()
145
+
146
+ def forward(self, input_ids, attention_mask=None):
147
+ return self.custom_model.model(input_ids, attention_mask, output_hidden_states=True) # call the .model attribute of the causal LM so the LM head is never computed
148
+
149
+ def __call__(self, input_ids, attention_mask=None, output_hidden_states=True):
150
+ return self.forward(input_ids, attention_mask)
151
+
152
+
153
+ class AbstractCompressor(nn.Module):
154
+ def __init__(self, compr_model_name: str, compr_rate: int, decoder_hidden_size: int):
155
+ super().__init__()
156
+ self.compr_model_name = compr_model_name
157
+ self.compr_rate = compr_rate
158
+ self.decoder_hidden_size = decoder_hidden_size
159
+
160
+ def forward(self, input_ids, attention_mask, generation_top_k):
161
+ """
162
+ input_ids of shape (batch_size, top_k, seq_length)
163
+ attention_mask of shape (batch_size, top_k, seq_length)
164
+ generation_top_k: the number of docs
165
+ """
166
+ raise NotImplementedError
167
+
168
+ def save_pretrained(self, save_directory):
169
+ raise NotImplementedError
170
+
171
+ def load_pretrained(self, load_directory):
172
+ raise NotImplementedError
173
+
174
+
175
+ class BertCompressor(AbstractCompressor):
176
+ def __init__(self,
177
+ compr_model_name: str,
178
+ compr_rate: int,
179
+ decoder_hidden_size: int,
180
+ mlp_hidden_dim: int = 8192,
181
+ use_mlp: bool = True,
182
+ doc_max_length : int = 128,
183
+ **kwargs):
184
+ # TODO use the device_map
185
+ super().__init__(compr_model_name=compr_model_name, compr_rate=compr_rate, decoder_hidden_size=decoder_hidden_size)
186
+ if compr_model_name == 'mistral_trimmed':
187
+ assert 'compr_n_layers' in kwargs
188
+ self.model = MistralTrimmed(n_layers=kwargs['compr_n_layers'],
189
+ every_n_layer=kwargs['compr_every_n_layer'],
190
+ rms_norm=kwargs['compr_rms_norm'],
191
+ base_model_name=kwargs['compr_base_model_name'],
192
+ attn_implementation=kwargs['attn_implementation'])
193
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model.base_model_name)
194
+ self.hidden_size = self.model.custom_model.config.hidden_size
195
+ else:
196
+ self.model = AutoModel.from_pretrained(compr_model_name, torch_dtype=torch.bfloat16, device_map='auto')
197
+ self.tokenizer = AutoTokenizer.from_pretrained(compr_model_name, use_fast=True)
198
+ self.tokenizer.padding_side = "left"
199
+ self.hidden_size = self.model.config.hidden_size
200
+
201
+ print('Base compressor nb parameters', self.model.num_parameters())
202
+
203
+ self.mlp_hidden_dim = mlp_hidden_dim
204
+ self.use_mlp = use_mlp
205
+ self.doc_max_length = doc_max_length
206
+
207
+ if self.use_mlp:
208
+ self.mlp = nn.Sequential(
209
+ nn.Linear(self.hidden_size, self.mlp_hidden_dim),
210
+ nn.ReLU(),
211
+ nn.Linear(self.mlp_hidden_dim, decoder_hidden_size)
212
+ ).bfloat16()
213
+ self.mlp.cuda()
214
+
215
+ self.n_emb = self.doc_max_length // self.compr_rate
216
+
217
+ mem_tokens = ['<MEM' + str(i) + '>' for i in range(self.n_emb)]
218
+ self.tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens})
219
+ self.tokenizer.mem_tokens = mem_tokens
220
+ self.tokenizer.mem_token_ids = [self.tokenizer.convert_tokens_to_ids(elt) for elt in self.tokenizer.mem_tokens]
221
+ self.tokenizer.mem_token_ids_pt = torch.LongTensor(self.tokenizer.mem_token_ids)
222
+ self.model.resize_token_embeddings(len(self.tokenizer))
223
+
224
+ if self.tokenizer.pad_token_id is None:
225
+ self.tokenizer.pad_token_id = self.tokenizer.bos_token_id
226
+
227
+ if not use_mlp:
228
+ assert decoder_hidden_size == self.hidden_size, f'MLP mandatory if hidden sizes are not equal: {decoder_hidden_size} vs {self.hidden_size}'
229
+
230
+ self.lora = False
231
+ self.lora_name = 'compr_adapter'
232
+
233
+ def prepare_mem_tokens_optimization(self):
234
+ assert self.lora, 'should only be called with lora.'
235
+ self.model.get_input_embeddings().weight.requires_grad = True
236
+ # Applying a hook zero-ing the gradients except for the mem token:
237
+ def hook(grad):
238
+ mask = torch.zeros_like(grad)
239
+ mask[self.tokenizer.mem_token_ids] = 1.0
240
+ return grad * mask
241
+ self.model.get_input_embeddings().weight.register_hook(hook)
242
+
243
+ def set_lora(self, peft_config):
244
+ self.model.add_adapter(peft_config, self.lora_name)
245
+ self.model.set_adapter(self.lora_name)
246
+ self.lora = True
247
+ self.prepare_mem_tokens_optimization()
248
+
249
+ def forward(self, input_ids, attention_mask):
250
+ assert input_ids.size() == attention_mask.size()
251
+ assert len(input_ids.size()) == 2
252
+
253
+ batch_size_times_top_k = input_ids.size(0)
254
+
255
+ last_hidden_states = self.model(input_ids=input_ids,
256
+ attention_mask=attention_mask,
257
+ output_hidden_states=True).hidden_states[-1]
258
+
259
+ # Getting the hidden states at the mem token positions, as for regular cocom:
260
+ mask = torch.isin(input_ids, self.tokenizer.mem_token_ids_pt.to(input_ids.device))
261
+ selected_n_tokens = last_hidden_states[mask].reshape(last_hidden_states.size(0), -1, last_hidden_states.size(-1))
262
+
263
+ assert selected_n_tokens.size() == (batch_size_times_top_k, self.n_emb, self.hidden_size), f"{selected_n_tokens.size()} vs {(batch_size_times_top_k, self.n_emb, self.hidden_size)}"
264
+
265
+ if self.use_mlp:
266
+ selected_n_tokens = self.mlp(selected_n_tokens) # now of shape (batch_size * top_k, n_emb, decoder_hidden_size)
267
+
268
+ assert selected_n_tokens.size() == (batch_size_times_top_k, self.n_emb, self.decoder_hidden_size), f"{selected_n_tokens.size()} vs {(batch_size_times_top_k, self.n_emb, self.decoder_hidden_size)}"
269
+
270
+ return selected_n_tokens
271
+
272
+ def get_lora_path_from_directory(self, directory):
273
+ return os.path.join(directory, 'compressor_adapters.pth')
274
+
275
+ def get_compressor_path_from_directory(self, directory):
276
+ return os.path.join(directory, 'compressor.pth')
277
+
278
+ def get_mlp_path_from_directory(self, directory):
279
+ return os.path.join(directory, 'mlp.pth')
280
+
281
+ def get_first_layer_path_from_directory(self, directory):
282
+ return os.path.join(directory, 'first_layer.pth')
283
+
284
+ def get_first_layer_state_dict(self) -> dict:
285
+ out = {}
286
+ for k, v in self.model.named_parameters():
287
+ if 'embed_tokens.weight' in k:
288
+ out[k] = v.cpu()
289
+
290
+ assert len(out) == 1, len(out) # We should get exactly one layer here
291
+ return out
292
+
293
+ def save_pretrained(self, save_directory):
294
+ """
295
+ Here we just save mlp state_dict and model state_dict
296
+ Config is handled in cocom model.
297
+ """
298
+ if not os.path.exists(save_directory):
299
+ os.makedirs(save_directory)
300
+
301
+ # Save MLP weights
302
+ if self.use_mlp:
303
+ mlp_path = self.get_mlp_path_from_directory(directory=save_directory)
304
+ torch.save(self.mlp.state_dict(), mlp_path)
305
+
306
+ # Saving the model
307
+ if not self.lora: # full training: save the full dict:
308
+ model_path = self.get_compressor_path_from_directory(directory=save_directory)
309
+ torch.save(self.model.state_dict(), model_path)
310
+ else: # lora training of the compressor
311
+ # We save the first layer:
312
+ first_layer_state_dict = self.get_first_layer_state_dict()
313
+ torch.save(first_layer_state_dict, self.get_first_layer_path_from_directory(directory=save_directory))
314
+
315
+ # We save the adapters:
316
+ adapter_state_dict = {k: v.cpu() for k, v in self.model.get_adapter_state_dict(self.lora_name).items()}
317
+ torch.save(adapter_state_dict, self.get_lora_path_from_directory(directory=save_directory))
318
+
319
+ def load_adapter(self, load_directory, peft_config):
320
+ assert peft_config is not None
321
+ map_location = torch.device("cpu") if not torch.cuda.is_available else None
322
+ adapter_state_dict = torch.load(self.get_lora_path_from_directory(directory=load_directory), map_location=map_location, weights_only=True)
323
+ print('loading compr adapter onto compressor model from', self.get_lora_path_from_directory(directory=load_directory))
324
+ self.model.load_adapter(peft_config=peft_config, adapter_name=self.lora_name, adapter_state_dict=adapter_state_dict)
325
+ self.lora = True
326
+ self.prepare_mem_tokens_optimization()
327
+
328
+ def load_first_layer(self, load_directory):
329
+ map_location = torch.device("cpu") if not torch.cuda.is_available else None
330
+ first_layer_state_dict = torch.load(self.get_first_layer_path_from_directory(load_directory), map_location=map_location, weights_only=True)
331
+ assert len(first_layer_state_dict.keys()) == 1
332
+ self.model.load_state_dict(first_layer_state_dict, strict=False)
333
+
334
+ def load_pretrained(self, load_directory, lora: bool = False, peft_config=None):
335
+ """
336
+ Loading the state dicts.
337
+ :lora: if True then the compressor was trained using lora: we just need to load the adapters
338
+ if False, the compressor was fully trained: we load it fully.
339
+ """
340
+ if self.use_mlp:
341
+ mlp_path = self.get_mlp_path_from_directory(directory=load_directory)
342
+ self.mlp.load_state_dict(torch.load(mlp_path, weights_only=True))
343
+
344
+ if lora:
345
+ self.load_first_layer(load_directory)
346
+ self.load_adapter(load_directory, peft_config)
347
+
348
+ else:
349
+ model_path = self.get_compressor_path_from_directory(directory=load_directory)
350
+ self.model.load_state_dict(torch.load(model_path, weights_only=True))
351
+
352
+ def prepare_inputs(self, texts, max_length, q_texts=None):
353
+ if q_texts is not None: # Query-dependent here:
354
+ assert len(texts) == len(q_texts), f"{len(texts)} == {len(q_texts)}"
355
+ if self.compr_model_name == 'mistral_trimmed':
356
+ # No special token, just formulating:
357
+ texts_to_encode = [ '\nQuery:\n' + query + 'Document:\n' + text for text, query in zip(texts, q_texts)]
358
+ inp_enc = self.tokenizer(texts_to_encode,
359
+ return_tensors='pt',
360
+ padding='max_length',
361
+ max_length=max_length + 8, # some margin for query/doc stuff + bos / eos
362
+ truncation=True,
363
+ add_special_tokens=True)
364
+ else:
365
+ inp_enc = self.tokenizer(q_texts, # we put the query in first position
366
+ texts,
367
+ return_tensors='pt',
368
+ padding='max_length',
369
+ max_length=max_length + 3,
370
+ truncation='only_second',
371
+ add_special_tokens=True)
372
+ else:
373
+ inp_enc = self.tokenizer(texts, return_tensors='pt', padding='max_length', max_length=max_length + 2, truncation=True, add_special_tokens=True)
374
+
375
+ inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
376
+ inp_enc['attention_mask'],
377
+ self.n_emb,
378
+ tokenizer=self.tokenizer)
379
+
380
+ return inp_enc
381
 
382
 
383
  def add_memory_tokens_to_inputs(input_ids: torch.Tensor, attention_mask: torch.Tensor, n_mem_tokens: int, tokenizer):
 
395
  return input_ids, attention_mask
396
 
397
 
398
+ class COCOMConfig(PretrainedConfig):
399
 
400
+ model_type = "COCOM"
401
  def __init__(self,
402
  decoder_model_name: str = "meta-llama/Llama-2-7b-chat-hf",
403
+ doc_max_length: int = 128,
404
+ quantization: str = 'no',
405
+ sep: bool = False,
406
+ compr_model_name: str = "google-bert/bert-base-uncased",
407
+ compr_rate: int = 64,
408
+ compr_n_layers: int = None, # only for surgical mistral compressor
409
+ compr_every_n_layer: int = None,
410
+ compr_base_model_name: str = 'mistralai/Mistral-7B-Instruct-v0.2',
411
+ compr_rms_norm: bool = False, # only for surgical mistral compressor: if true, RMS norm is applied on hidden states
412
+ compr_mlp_hidden_dim: int = 8096,
413
+ compr_use_mlp: bool = True,
414
+ lora: bool = False, # lora on decoder (and decoder as compr)
415
+ lora_compressor: bool = False, # lora only on the compressor if it exists
416
+ training_form: str = "both",
417
+ lora_r: int = 16,
418
+ lora_r_compressor: int = None,
419
+ load_adapters: bool = True,
420
+ kbtc_training: bool = False,
421
+ optimize_mem_tokens: bool = False,
422
+ different_mem_tokens: bool = False,
423
+ attn_implementation: str = 'flash_attention_2',
424
+ device_map = None,
425
  **kwargs):
426
  super().__init__(**kwargs)
427
 
428
  self.decoder_model_name = decoder_model_name # model name of decoder
429
+ self.doc_max_length = doc_max_length # maximum document length handled by this model (used to compute the number of mem tokens)
430
+ self.quantization = quantization # quantization, could be no, int4, int8
431
+ self.sep = sep # boolean type, whether to use sep token
432
+
433
+ self.compr_model_name = compr_model_name # model name of compressor
434
  self.compr_rate = compr_rate # compression rate
435
+ self.compr_use_mlp = compr_use_mlp
436
+ self.compr_mlp_hidden_dim = compr_mlp_hidden_dim
437
+ self.compr_n_layers = compr_n_layers
438
+ self.compr_every_n_layer = compr_every_n_layer
439
+ self.compr_base_model_name = compr_base_model_name
440
+ self.compr_rms_norm = compr_rms_norm
441
+
442
+ self.lora = lora # boolean, whether to use lora training
443
+ self.lora_compressor = lora_compressor
444
+ self.training_form = training_form # training form: 'compressor' trains only the compressor; 'both' trains both
445
+ # Or both_separately: training both with separate adapters
446
+ self.lora_r = lora_r # lora_r for lora training, we use 16 throughout the experiment.
447
+ self.lora_r_compressor = lora_r_compressor or lora_r # defaulting to same lora as decoder.
448
+ self.load_adapters = load_adapters # used to load pretrained model: we first load without adapters, and then load them from file.
449
+ self.optimize_mem_tokens = optimize_mem_tokens
450
+ self.different_mem_tokens = different_mem_tokens
451
+
452
+ self.kbtc_training = kbtc_training
453
 
454
+ self.device_map = device_map
455
 
456
+ self.attn_implementation = attn_implementation
457
+
458
+ if training_form == 'compressor':
459
+ assert compr_model_name is not None and not self.lora
460
+
461
+
462
+ class COCOM(PreTrainedModel):
463
+ config_class = COCOMConfig
464
  def __init__(self, cfg):
465
  super().__init__(cfg)
466
  self.decoder_model_name = cfg.decoder_model_name
467
+ self.decoder = self.create_decoder(cfg)
 
 
 
 
 
 
 
468
 
469
+ self.doc_max_length = cfg.doc_max_length
 
 
 
470
 
471
+ print('Base decoder nb parameters', self.decoder.num_parameters())
472
 
473
+ self.compr_model_name = cfg.compr_model_name
474
+ self.training_form = cfg.training_form
475
+ self.lora = cfg.lora
476
  self.adapter_keys = []
477
+
478
+ self.compr = None
479
+ # when compr_model_name is not set, a decoder-based compressor is used; otherwise a BERT-based compressor
480
+ if cfg.compr_model_name is not None:
481
+ # case bert based compressor
482
+ print('Instantiating compressor ', cfg.compr_model_name)
483
+ self.compr = BertCompressor(cfg.compr_model_name,
484
+ cfg.compr_rate,
485
+ doc_max_length=self.doc_max_length,
486
+ decoder_hidden_size=self.decoder.config.hidden_size,
487
+ mlp_hidden_dim=cfg.compr_mlp_hidden_dim,
488
+ compr_n_layers=cfg.compr_n_layers,
489
+ compr_every_n_layer=cfg.compr_every_n_layer,
490
+ compr_base_model_name=cfg.compr_base_model_name,
491
+ compr_rms_norm=cfg.compr_rms_norm,
492
+ use_mlp=cfg.compr_use_mlp,
493
+ attn_implementation=cfg.attn_implementation)
494
+
495
+ # set lora adaptors on decoder model
496
+ if cfg.lora:
497
+ peft_config = self.get_peft_config(lora_r=cfg.lora_r)
498
+
499
+ if cfg.load_adapters:
500
+ self.decoder.add_adapter(peft_config, 'decoder_adapter')
501
+ self.decoder.set_adapter('decoder_adapter') # active adapter by default
502
+ self.adapter_keys.append('decoder_adapter')
503
+
504
+ # Create separate adapters (if not BERT compressor and training_form == 'both_separately')
505
+ if self.training_form == 'both_separately' and self.compr is None:
506
+ if cfg.load_adapters:
507
+ self.decoder.add_adapter(peft_config, 'encoder_adapter')
508
+ self.adapter_keys.append('encoder_adapter')
509
+
510
+ # set lora adapters on compressor model:
511
+ if cfg.lora_compressor and self.compr is not None and cfg.load_adapters:
512
+ peft_config = self.get_peft_config(lora_r=cfg.lora_r_compressor)
513
+ self.compr.set_lora(peft_config)
514
+
515
+ self.decoder_tokenizer = COCOM.create_decoder_tokenizer(cfg)
516
+
517
+ # resize the tokenizer embedding
518
+ self.decoder.resize_token_embeddings(len(self.decoder_tokenizer))
519
+ self.decoder.generation_config.top_p = None
520
+ self.decoder.generation_config.temperature = None
521
+ self.decoder.generation_config.pad_token_id = self.decoder_tokenizer.pad_token_id
522
 
523
+ # self.decoder.gradient_checkpointing_enable()
524
+ # if self.compr is not None:
525
+ # self.compr.gradient_checkpointing_enable()
526
+
527
+ # other settings
528
+ self.generation_top_k = 1
529
+ self.sep = cfg.sep
530
+ self.compr_rate = cfg.compr_rate
531
+ self.local_rank = os.getenv('LOCAL_RANK', '0')
532
 
533
+ self.n_mem_tokens = self.doc_max_length // self.compr_rate # crucial: number of memory tokens per document
 
 
 
 
534
 
 
 
 
535
 
536
+ if self.lora:
537
+ for adapter_key in self.adapter_keys:
538
+ self.decoder.set_adapter(adapter_key)
539
+ print(f'Adapter {adapter_key} trainable parameters: {self.num_parameters(only_trainable=True)}')
540
+
541
+ # We need to activate all adapters so that they are both trained...
542
+ self.set_all_adapters()
543
+ else:
544
+ print(f'Total trainable parameters: {self.num_parameters(only_trainable=True)}')
545
+
546
+ if self.compr is not None:
547
+ print(f'Compressor number of parameters: {self.compr.model.num_parameters(only_trainable=True)}')
548
+
549
+ self.prepare_mem_tokens_optimization()
550
+
551
+ def prepare_mem_tokens_optimization(self):
552
+ if self.config.optimize_mem_tokens:
553
+ if self.compr is None:
554
+ # Enforcing gradients for input embeddings (even if lora)
555
+ self.decoder.get_input_embeddings().weight.requires_grad = True
556
+ # Applying a hook zero-ing the gradients except for the mem token:
557
+ def hook(grad):
558
+ mask = torch.zeros_like(grad)
559
+ mask[self.decoder_tokenizer.mem_token_ids] = 1.0
560
+ return grad * mask
561
+ self.decoder.get_input_embeddings().weight.register_hook(hook)
562
+
563
  def set_all_adapters(self):
564
  if len(self.adapter_keys) > 0:
565
  self.decoder.set_adapter(self.adapter_keys)
566
+
567
+ @staticmethod
568
+ def create_decoder_tokenizer(cfg: COCOMConfig):
569
+ decoder_tokenizer = AutoTokenizer.from_pretrained(cfg.decoder_model_name, use_fast=True, padding_side='left')
570
 
571
+ # define special tokens
572
+ n_mem_tokens = cfg.doc_max_length // cfg.compr_rate
573
+ if cfg.different_mem_tokens:
574
+ # estimate of the number of memory tokens needed:
575
+ mem_tokens = ['<MEM' + str(i) + '>' for i in range(n_mem_tokens)]
576
+ decoder_tokenizer.add_special_tokens({'additional_special_tokens': mem_tokens + ['<AE>', '<ENC>', '<SEP>']})
577
+ decoder_tokenizer.mem_tokens = mem_tokens
578
+ else:
579
+ decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<MEM>', '<AE>', '<ENC>', '<SEP>']})
580
+ decoder_tokenizer.mem_tokens = ['<MEM>'] * n_mem_tokens
581
+
582
+ decoder_tokenizer.mem_token_ids = [decoder_tokenizer.convert_tokens_to_ids(elt) for elt in decoder_tokenizer.mem_tokens]
583
+ decoder_tokenizer.mem_token_ids_pt = torch.LongTensor(decoder_tokenizer.mem_token_ids) # required later on for operations on tensors
584
+
585
+ decoder_tokenizer.ae_token = '<AE>' # token for autoencoding on decoder side
586
+ decoder_tokenizer.ae_token_id = decoder_tokenizer.convert_tokens_to_ids('<AE>')
587
+ decoder_tokenizer.enc_token = '<ENC>' # token for autoencoding on compressor side
588
+ decoder_tokenizer.sep_token = '<SEP>' # sep token between document
589
+ decoder_tokenizer.sep_token_id = decoder_tokenizer.convert_tokens_to_ids('<SEP>')
590
+
591
+ # If kbtc training, we add yet another special token
592
+ if cfg.kbtc_training:
593
+ decoder_tokenizer.add_special_tokens({'additional_special_tokens': ['<KBTC>']})
594
+ decoder_tokenizer.kbtc_token = '<KBTC>'
595
+ decoder_tokenizer.kbtc_token_id = decoder_tokenizer.convert_tokens_to_ids('<KBTC>')
596
+
597
+ # if pad token exists then use pad token, otherwise bos token
598
+ if decoder_tokenizer.pad_token_id is None:
599
+ decoder_tokenizer.pad_token_id = decoder_tokenizer.bos_token_id
600
+
601
+ return decoder_tokenizer
602
+
603
+ def get_peft_config(self, lora_r: int) -> LoraConfig:
604
  """
605
  Builds the peft config
606
  """
607
+ return LoraConfig(task_type="CAUSAL_LM", r=lora_r, lora_alpha=2*lora_r, target_modules='all-linear', lora_dropout=0.1)
608
+
609
+ def create_decoder(self, cfg):
610
+ """
611
+ Loads the base decoder.
612
+ """
613
+ if torch.cuda.is_available():
614
+ if cfg.quantization == "no":
615
+ return AutoModelForCausalLM.from_pretrained(
616
+ cfg.decoder_model_name,
617
+ torch_dtype=torch.bfloat16,
618
+ attn_implementation=self.config.attn_implementation,
619
+ # low_cpu_mem_usage = True,
620
+ device_map=cfg.device_map
621
+ )
622
+ elif cfg.quantization == "int4":
623
+ quant_config = BitsAndBytesConfig(
624
+ load_in_4bit=True,
625
+ bnb_4bit_quant_type='nf4',
626
+ bnb_4bit_compute_dtype='bfloat16',
627
+ # low_cpu_mem_usage = True,
628
+ )
629
+ return AutoModelForCausalLM.from_pretrained(
630
+ cfg.decoder_model_name,
631
+ quantization_config=quant_config,
632
+ attn_implementation=self.config.attn_implementation,
633
+ torch_dtype=torch.bfloat16,
634
+ resume_download=True,
635
+ # low_cpu_mem_usage = True,
636
+ trust_remote_code=True,
637
+ device_map=cfg.device_map
638
+ )
639
+ elif cfg.quantization == "int8":
640
+ quant_config = BitsAndBytesConfig(
641
+ load_in_8bit=True,
642
+ llm_int8_enable_fp32_cpu_offload=True,
643
+ bnb_4bit_compute_dtype='bfloat16',
644
+ # low_cpu_mem_usage = True,
645
+ )
646
+ return AutoModelForCausalLM.from_pretrained(
647
+ cfg.decoder_model_name,
648
+ quantization_config=quant_config,
649
+ attn_implementation=self.config.attn_implementation,
650
+ torch_dtype=torch.bfloat16,
651
+ resume_download=True,
652
+ # low_cpu_mem_usage = True,
653
+ trust_remote_code=True,
654
+ device_map=cfg.device_map
655
+ )
656
+ else:
657
+ raise NotImplementedError()
658
+ else:
659
+ return AutoModelForCausalLM.from_pretrained(
660
+ cfg.decoder_model_name,
661
+ torch_dtype=torch.bfloat16,
662
+ resume_download=True,
663
+ # low_cpu_mem_usage = True,
664
+ trust_remote_code=True,
665
+ device_map=cfg.device_map
666
+ )
667
 
668
  def compress(self, enc_input_ids, enc_attention_mask):
669
+ if self.compr:
670
+ return self.compr(enc_input_ids, enc_attention_mask)
671
+ else:
672
+ return self.compr_decoder(enc_input_ids, enc_attention_mask)
673
 
674
  def replace_emb(self, compressed_embs, dec_input_ids):
675
  """
676
+ Replaces the memory-token placeholders in dec_input_ids with the compressed embeddings
677
  """
678
  indices = range(0, compressed_embs.size(0) + 1, self.generation_top_k)
679
+ input_embeds = self.replace_embeddings(compressed_embs, dec_input_ids, indices)
 
680
  return input_embeds
681
 
682
  def compr_decoder(self, input_ids, attention_mask):
 
688
  # Switch adapter if we are training two different ones:
689
  if 'encoder_adapter' in self.adapter_keys:
690
  self.decoder.set_adapter('encoder_adapter')
691
+
692
  emb = self.decoder(input_ids=input_ids,
693
  attention_mask=attention_mask,
694
  output_hidden_states=True).hidden_states[-1]
695
+ mask = torch.isin(input_ids, self.decoder_tokenizer.mem_token_ids_pt.to(input_ids.device))
696
  return emb[mask].reshape(emb.size(0), -1, emb.size(-1))
697
 
698
+ def prepare_encoder_inputs_to_decoder(self, texts, max_length, q_texts=None):
699
+ if q_texts is not None:
700
+ texts_to_encode = [self.decoder_tokenizer.enc_token + self.decoder_tokenizer.bos_token + '\nQuery:\n' + query + 'Document:\n' + text + self.decoder_tokenizer.eos_token
701
+ for text, query in zip(texts, q_texts)]
702
+ inp_enc = self.decoder_tokenizer(texts_to_encode, return_tensors='pt', padding='max_length', max_length=max_length + 8, truncation=True, add_special_tokens=False)
703
+ else:
704
+ inp_enc = [self.decoder_tokenizer.enc_token + self.decoder_tokenizer.bos_token + text + self.decoder_tokenizer.eos_token for text in texts]
705
+ inp_enc = self.decoder_tokenizer(inp_enc, return_tensors='pt', padding="max_length", max_length=max_length+3, truncation=True, add_special_tokens=False)
706
+
707
+ num_mem_tokens = self.doc_max_length // self.compr_rate
708
+ assert num_mem_tokens == len(self.decoder_tokenizer.mem_tokens)
709
  inp_enc['input_ids'], inp_enc['attention_mask'] = add_memory_tokens_to_inputs(inp_enc['input_ids'],
710
  inp_enc['attention_mask'],
711
  num_mem_tokens,
712
+ tokenizer=self.decoder_tokenizer)
713
 
714
  return inp_enc
715
 
716
+ def prepare_encoder_inputs(self, texts: list[str], max_length: int, q_texts: list[str] = None):
717
+ """
718
+ Create the inputs to the encoder, for compression.
719
+ """
720
+ if q_texts is not None:
721
+ assert len(texts) == len(q_texts), f"{len(texts)} == {len(q_texts)}"
722
+
723
+ # Case where the encoder is the decoder with adapter:
724
+ if self.compr is None:
725
+ return self.prepare_encoder_inputs_to_decoder(texts, max_length, q_texts)
726
 
727
+ # Case where the encoder is a separate network:
728
+ else:
729
+ return self.compr.prepare_inputs(texts, max_length, q_texts)
730
+
731
+ def replace_embeddings(self, compressed_embs, dec_input_ids, indices):
732
+ """
733
+ Replace memory tokens in the decoder input with the compressed embeddings
734
+ """
735
+ inputs_embeds = self.decoder.get_input_embeddings()(dec_input_ids)
736
+ num_embs = compressed_embs.size(1)
737
+ if self.sep:
738
+ slot_len = num_embs + 1
739
+ else:
740
+ slot_len = num_embs
741
+ # get first mem_token indices
742
+ first_mem_token_indices = torch.argmax((dec_input_ids == self.decoder_tokenizer.mem_token_ids[0]).int(), dim=1)
743
+ batch_size = inputs_embeds.size(0)
744
+ # for each example in batch, replace them with compressed embeddings
745
+ for i in range(batch_size):
746
+ for j in range(indices[i], indices[i + 1]):
747
+ start_idx = first_mem_token_indices[i].item() + (j-indices[i]) * slot_len
748
+ assert inputs_embeds[i, start_idx:start_idx + num_embs, :].size() == compressed_embs[j].size(), \
749
+ f"{inputs_embeds[i, start_idx:start_idx + num_embs, :].size()} VS {compressed_embs[j].size()}"
750
+ inputs_embeds[i, start_idx:start_idx + num_embs, :] = compressed_embs[j]
751
+ return inputs_embeds
752
+
753
  def forward(self,
754
  enc_input_ids: torch.LongTensor = None,
755
  enc_attention_mask: torch.LongTensor = None,
 
779
  compressed_embs = self.compress(enc_input_ids, enc_attention_mask)
780
  inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids)
781
 
782
+ # if training_form is compressor, then detach the inputs_embeds, to make gradient not count in decoder
783
+ if (self.training_form == "compressor") and (self.compr is None):
784
+ inputs_embeds = inputs_embeds.detach()
785
+
786
  # decoding
787
  if 'decoder_adapter' in self.adapter_keys:
788
  self.decoder.set_adapter('decoder_adapter')
 
793
  self.set_all_adapters()
794
 
795
  return {"loss": decoder_outputs.loss, "logits": decoder_outputs.logits}
796
+
797
+ def generate(self, model_input, max_new_tokens=128, return_doc_embeddings: bool = False):
798
+
799
+ enc_input_ids, enc_attention_mask, dec_input_ids, dec_attention_mask = model_input['enc_input_ids'], model_input['enc_attention_mask'], model_input['dec_input_ids'], model_input['dec_attention_mask']
800
+
801
+ assert enc_input_ids.size() == enc_attention_mask.size()
802
+
803
+ if len(enc_input_ids.size()) == 3: # likely from bergen: we just flatten all of this to perform encoding in one batch
804
+ batch_size, top_k, seq_length = enc_input_ids.size()
805
+ enc_input_ids = enc_input_ids.view(batch_size * top_k, seq_length)
806
+ enc_attention_mask = enc_attention_mask.view(batch_size * top_k, seq_length)
807
+
808
+ # Here, we should have top_k times more elements in enc_input_ids than in dec_input_ids
809
+ assert enc_input_ids.size(0) == dec_input_ids.size(0) * self.generation_top_k, \
810
+ f"{enc_input_ids.size(0)} VS {dec_input_ids.size(0)} with generation_top_k={self.generation_top_k}"
811
+
812
+ compressed_embs = self.compress(enc_input_ids.to('cuda'), enc_attention_mask.to('cuda'))
813
+ inputs_embeds = self.replace_emb(compressed_embs, dec_input_ids.to('cuda'))
814
+
815
+ # Switch adapter if we are training two different ones:
816
+ if 'decoder_adapter' in self.adapter_keys:
817
+ self.decoder.set_adapter('decoder_adapter')
818
+
819
+ output_ids = self.decoder.generate(
820
+ inputs_embeds=inputs_embeds.to("cuda"),
821
+ attention_mask=dec_attention_mask.to("cuda"),
822
+ do_sample=False,
823
+ top_p=None,
824
+ max_new_tokens=max_new_tokens
825
+ )
826
+
827
+ decoded = self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
828
+
829
+ if return_doc_embeddings:
830
+ # compressed_embs is of shape (batch_size*top_k, n_mem_tokens, hidden_dim)
831
+ # We reshape to batch_size, top_k, n_mem_tokens, hidden_dim
832
+ assert batch_size is not None
833
+ assert top_k is not None
834
+ compressed_embs = compressed_embs.view(batch_size, top_k, compressed_embs.size(1), compressed_embs.size(2))
835
+ return decoded, compressed_embs
836
+ else:
837
+ return decoded
838
+
839
+ def get_all_adapters_state_dict(self):
840
+ """
841
+ Return the state dicts of the adapters
842
+ Used for saving so we go to cpu automatically
843
+ """
844
+ return {key: {k:v.cpu() for k, v in self.decoder.get_adapter_state_dict(key).items()} for key in self.adapter_keys}
845
+
846
+ def load_adapter_from_state_dict(self, peft_config: LoraConfig, adapter_name: str, adapter_state_dict: dict) -> None:
847
+ """
848
+ Creates an adapter from the state dict (used to load from pretrained)
849
+ """
850
+ # assert adapter_name not in self.adapter_keys, f'Adapter {adapter_name} already exists'
851
+ print(f'loading adapter {adapter_name}')
852
+ self.decoder.load_adapter(peft_config=peft_config, adapter_name=adapter_name, adapter_state_dict=adapter_state_dict)
853
+ self.adapter_keys.append(adapter_name)
854
+
855
+ def get_decoder_first_and_last_layer_state_dict(self) -> dict:
856
+ """
857
+ Just getting the first and last layers: the only ones which change when adding tokens
858
+ Used to save the model so we automatically move to cpu.
859
+ """
860
+ out = {}
861
+ for k, v in self.decoder.named_parameters():
862
+ if 'lm_head.weight' in k or 'embed_tokens.weight' in k:
863
+ out[k] = v.cpu()
864
+
865
+ # assert len(out) == 2, len(out) # We should get both the embedding layer and the head layer.
866
+ return out
867
+
868
+ def save_pretrained(self, save_directory: str, **kwargs):
869
+ """
870
+ Save only the LoRA adapters and their configurations.
871
+ """
872
+ if self.lora:
873
+ if not os.path.exists(save_directory):
874
+ os.makedirs(save_directory)
875
+
876
+ # Save the LoRA adapter weights
877
+ torch.save(self.get_all_adapters_state_dict(), os.path.join(save_directory, "adapters.pth"))
878
+
879
+ # Save the first and last layers of decoder (because of diffs with tokens !)
880
+ torch.save(self.get_decoder_first_and_last_layer_state_dict(), os.path.join(save_directory, "decoder_first_last_layers.pth"))
881
+
882
+ # Save the bert compressor if it exists
883
+ if self.compr_model_name is not None:
884
+ self.compr.save_pretrained(os.path.join(save_directory, 'compressor'))
885
+
886
+ # Save the configuration
887
+ self.config.save_pretrained(save_directory)
888
+ else:
889
+ super().save_pretrained(save_directory, **kwargs)
890
+
891
+ @classmethod
892
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
893
+ """
894
+ Loading: to take care of checkpoints containing only lora and not base model.
895
+ """
896
+ # Load the configuration
897
+ config = COCOMConfig.from_pretrained(pretrained_model_name_or_path)
898
+
899
+ config.attn_implementation = kwargs.get('attn_implementation', config.attn_implementation)
900
+
901
+ map_location = torch.device("cpu") if not torch.cuda.is_available() else None
902
+
903
+ if config.lora:
904
+ # We need to delay the construction of the adapters (otherwise peft complains)
905
+ config.load_adapters = False
906
+
907
+ if 'device_map' in kwargs:
908
+ config.device_map = kwargs['device_map']
909
+
910
+ # Initialize the model
911
+ model = cls(config)
912
+
913
+ # Loading first and last layers (they might have changed due to extra tokens)
914
+ try:
915
+ # If loading from Hugging Face Hub
916
+ first_and_last_layers_path = hf_hub_download(
917
+ repo_id=pretrained_model_name_or_path,
918
+ filename="decoder_first_last_layers.pth"
919
+ )
920
+ except Exception as e:
921
+ # If loading from a local directory
922
+ first_and_last_layers_path = os.path.join(pretrained_model_name_or_path, "decoder_first_last_layers.pth")
923
+
924
+ if os.path.exists(first_and_last_layers_path):
925
+ first_and_last_decoder_state_dict = torch.load(first_and_last_layers_path, map_location=map_location, weights_only=True)
926
+ for key in first_and_last_decoder_state_dict:
927
+ assert key in model.decoder.state_dict()
928
+ model.decoder.load_state_dict(first_and_last_decoder_state_dict, strict=False)
929
+
930
+ else:
931
+ print('FIRST AND LAST LAYER NOT FOUND (ok for some old models):', first_and_last_layers_path)
932
+
933
+ peft_config = model.get_peft_config(lora_r=config.lora_r)
934
+
935
+ # Load the LoRA adapters (if the file exists)
936
+ try:
937
+ # If loading from Hugging Face Hub
938
+ adapters_path = hf_hub_download(
939
+ repo_id=pretrained_model_name_or_path,
940
+ filename="adapters.pth"
941
+ )
942
+ except Exception as e:
943
+ # If loading from a local directory
944
+ adapters_path = os.path.join(pretrained_model_name_or_path, "adapters.pth")
945
+
946
+ if os.path.exists(adapters_path):
947
+ adapters_state_dict = torch.load(adapters_path, map_location=map_location, weights_only=True)
948
+
949
+ for key, val in adapters_state_dict.items():
950
+ model.load_adapter_from_state_dict(peft_config=peft_config, adapter_name=key, adapter_state_dict=val)
951
+
952
+ else:
953
+ warnings.warn(f'LoRA is enabled for this PISCO model but {adapters_path} does not exist; this may be expected \
954
+ for recent versions of transformers.')
955
+
956
+ # If there is a compressor, it's been built: we just need to load the state dict or the adapters:
957
+ if config.compr_model_name is not None:
958
+ model.compr.load_pretrained(os.path.join(pretrained_model_name_or_path, 'compressor'),
959
+ lora=config.lora_compressor,
960
+ peft_config=model.get_peft_config(lora_r=config.lora_r_compressor))
961
+
962
+ model.set_all_adapters()
963
+ model.config.load_adapters = True
964
+ return model
965
+
966
+ else:
967
+ return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
968
+
969
  def generate_from_text(self, questions: list[str], documents: list[list[str]], max_new_tokens: int = 128) -> list[str]:
970
  """
971
  Generates answers from documents (via compression then decoding)
 
986
 
987
  # Creating decoder inputs
988
  instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
989
+ inp_dec = self.decoder_tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
990
  model_input['dec_input_ids'], model_input['dec_attention_mask'] = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
991
 
992
  # Generation
 
1003
 
1004
  # Creating decoder inputs
1005
  instr = [self.blend_prompt_and_memory_tokens(query=q) for q in questions]
1006
+ inp_dec = self.decoder_tokenizer(instr, return_tensors='pt', padding="longest", add_special_tokens=False, truncation=True, max_length=2048)
1007
  device = self.decoder.device
1008
  dec_input_ids, dec_attention_mask = inp_dec['input_ids'].to(device), inp_dec['attention_mask'].to(device)
1009
 
 
1022
  )
1023
 
1024
  # de-tokenizing
1025
+ return self.decoder_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
1026
 
1027
  def compress_documents(self, documents: list[str]) -> torch.Tensor:
1028
  """
 
1032
  enc_input_ids = input_encoder['input_ids'].to(self.decoder.device)
1033
  attention_mask = input_encoder['attention_mask'].to(self.decoder.device)
1034
  return self.compress(enc_input_ids=enc_input_ids, enc_attention_mask=attention_mask)
1035
+
1036
  def blend_prompt_and_memory_tokens(self, query: str):
1037
  """
1038
  Takes care of blending the prompt with the memory tokens:
1039
  Also returns, if a label is provided, the position of the first token index of the label (for loss comp later on)
1040
+ (Used for the HUB version)
1041
  """
1042
+ mem_tokens_str = ''.join(self.decoder_tokenizer.mem_tokens) + self.decoder_tokenizer.sep_token
1043
 
1044
  # proper names for "eval" call, don't remove these lines
1045
  docs = mem_tokens_str * self.generation_top_k
 
1056
 
1057
  # Attempt to apply the system role and catch if it's not supported
1058
  try:
1059
+ prompt = self.decoder_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1060
 
1061
  except TemplateError as e:
1062
  # Catch the error related to system role and handle it (e.g. gemma)
 
1064
  # Remove system role and proceed with only the user role
1065
  messages = [{"role": "user", "content": messages[0]['content'] + '\n' + messages[1]['content']}]
1066
  # Apply template again without system role
1067
+ prompt = self.decoder_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1068
  else:
1069
  # Re-raise the exception if it's unrelated to system role
1070
  raise e
1071
 
1072
  return prompt
1073
+
1074
+
1075
+ if __name__ == '__main__':
1076
+ cfg = COCOMConfig(decoder_model_name='mistralai/Mistral-7B-Instruct-v0.2',
1077
+ compr_model_name = "mistral_trimmed",
1078
+ compr_rate = 64,
1079
+ compr_n_layers = 5,
1080
+ compr_mlp_hidden_dim = 8096,
1081
+ compr_use_mlp = False,
1082
+ lora = True, # lora on decoder (and decoder as compr)
1083
+ lora_compressor = True, # lora only on the compressor if it exists
1084
+ training_form = "both",
1085
+ load_adapters = True,
1086
+ kbtc_training = False,
1087
+ optimize_mem_tokens = True,
1088
+ different_mem_tokens = True,
1089
+ attn_implementation = 'flash_attention_2')
1090
+
1091
+ cocom = COCOM(cfg)
1092
+
1093
+ cocom.save_pretrained('test_ckpt')
1094
+
1095
+ del cocom
1096
+ torch.cuda.empty_cache()
1097
+ import gc
1098
+ gc.collect()
1099
+
1100
+ cocom = COCOM.from_pretrained('test_ckpt')
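For completeness, a minimal inference sketch (not part of the commit) using the public methods defined above. It assumes a CUDA machine (the generate path moves tensors to 'cuda' and the default attention implementation is flash_attention_2) and a checkpoint such as the 'test_ckpt' directory written by the __main__ block:

from modelling_pisco import COCOM

cocom = COCOM.from_pretrained('test_ckpt')

questions = ["What is the capital of France?"]
documents = [["Paris is the capital and most populous city of France."]]

# Compresses each document into doc_max_length // compr_rate memory-token embeddings,
# then decodes an answer conditioned on them.
answers = cocom.generate_from_text(questions, documents, max_new_tokens=32)
print(answers[0])

# Direct access to the compressed representations:
embeddings = cocom.compress_documents(["Paris is the capital of France."])
print(embeddings.shape)  # roughly (n_docs, n_mem_tokens, hidden_size)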