Fabian-David Schmidt commited on
Commit
838c37a
·
1 Parent(s): ba88283

update config and modelling files

Browse files
Files changed (3) hide show
  1. config.json +1 -3
  2. configuration_nllbllm2vec.py +15 -2
  3. modeling_nllbllm2vec.py +243 -407
config.json CHANGED
@@ -1,5 +1,4 @@
1
  {
2
- "_name_or_path": "fdschmidt93/NLLBLLM2Vec",
3
  "architectures": [
4
  "NLLBLLM2Vec"
5
  ],
@@ -37,6 +36,5 @@
37
  "vocab_size": 256206
38
  },
39
  "torch_dtype": "bfloat16",
40
- "transformers_version": "4.44.2"
41
  }
42
-
 
1
  {
 
2
  "architectures": [
3
  "NLLBLLM2Vec"
4
  ],
 
36
  "vocab_size": 256206
37
  },
38
  "torch_dtype": "bfloat16",
39
+ "transformers_version": "4.45.2"
40
  }
 
configuration_nllbllm2vec.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from transformers import AutoConfig
2
  from transformers.configuration_utils import PretrainedConfig
3
  from transformers.models.llama.configuration_llama import LlamaConfig
@@ -36,6 +37,7 @@ DEFAULT_M2M100_CONFIG = {
36
  "vocab_size": 256206,
37
  "tokenizer_class": "NllbTokenizer",
38
  "max_length": 200,
 
39
  }
40
 
41
  DEFAULT_LLAMA_CONFIG = {
@@ -61,6 +63,7 @@ DEFAULT_LLAMA_CONFIG = {
61
  "transformers_version": "4.40.0.dev0",
62
  "use_cache": False,
63
  "vocab_size": 128256,
 
64
  }
65
 
66
 
@@ -70,13 +73,23 @@ class NLLBLLM2VecConfig(PretrainedConfig):
70
 
71
  def __init__(
72
  self,
73
- nllb_config: dict = DEFAULT_M2M100_CONFIG,
74
- llm2vec_config: dict = DEFAULT_LLAMA_CONFIG,
 
 
75
  **kwargs,
76
  ):
77
  super().__init__(**kwargs)
 
78
  self.nllb_config = M2M100Config(**nllb_config)
 
79
  self.llm2vec_config = LlamaConfig(**llm2vec_config)
 
 
 
 
 
 
80
 
81
 
82
  AutoConfig.register(NLLBLLM2VEC_TYPE, NLLBLLM2VecConfig)
 
1
+ from typing import Optional, Dict
2
  from transformers import AutoConfig
3
  from transformers.configuration_utils import PretrainedConfig
4
  from transformers.models.llama.configuration_llama import LlamaConfig
 
37
  "vocab_size": 256206,
38
  "tokenizer_class": "NllbTokenizer",
39
  "max_length": 200,
40
+ "_attn_implementation": "flash_attention_2",
41
  }
42
 
43
  DEFAULT_LLAMA_CONFIG = {
 
63
  "transformers_version": "4.40.0.dev0",
64
  "use_cache": False,
65
  "vocab_size": 128256,
66
+ "_attn_implementation": "flash_attention_2",
67
  }
68
 
69
 
 
73
 
74
  def __init__(
75
  self,
76
+ nllb_config: Dict = DEFAULT_M2M100_CONFIG,
77
+ llm2vec_config: Dict = DEFAULT_LLAMA_CONFIG,
78
+ _attn_implementation="sdpa",
79
+ initializer_range: Optional[float] = None,
80
  **kwargs,
81
  ):
82
  super().__init__(**kwargs)
83
+ self._attn_implementation = _attn_implementation
84
  self.nllb_config = M2M100Config(**nllb_config)
85
+ self.nllb_config._attn_implementation = _attn_implementation
86
  self.llm2vec_config = LlamaConfig(**llm2vec_config)
87
+ self.llm2vec_config._attn_implementation = _attn_implementation
88
+ if initializer_range is None:
89
+ self.initializer_range = self.llm2vec_config.initializer_range
90
+ else:
91
+ self.initializer_range = initializer_range
92
+ self.llm2vec_config.initializer_range
93
 
94
 
95
  AutoConfig.register(NLLBLLM2VEC_TYPE, NLLBLLM2VecConfig)
modeling_nllbllm2vec.py CHANGED
@@ -1,24 +1,69 @@
1
- from typing import Any, Dict, List, Optional, Tuple, cast, Union
 
 
 
2
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- from transformers.models.auto import AutoModel, AutoModelForSequenceClassification
 
 
 
 
7
  from transformers.modeling_outputs import (
8
  BaseModelOutputWithPooling,
 
9
  SequenceClassifierOutputWithPast,
 
10
  )
11
  from transformers.modeling_utils import PreTrainedModel
 
12
  from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
13
- from transformers.cache_utils import Cache
14
 
15
  from .configuration_nllbllm2vec import NLLBLLM2VecConfig
16
  from .modeling_llama_encoder import LlamaEncoderModel
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  class NLLBLLM2Vec(PreTrainedModel):
20
  config_class = NLLBLLM2VecConfig
21
  model_type = "nllb-llm2vec"
 
 
22
  """
23
  NLLBLLM2Vec model combining NLLB and LLama encoders.
24
 
@@ -46,9 +91,13 @@ class NLLBLLM2Vec(PreTrainedModel):
46
 
47
  if config is not None:
48
  super().__init__(config, *inputs, **kwargs)
 
 
 
49
  self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
50
  self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
51
  self.config = config
 
52
  else:
53
  # Both encoders are provided
54
  self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
@@ -64,7 +113,15 @@ class NLLBLLM2Vec(PreTrainedModel):
64
  self.llm2vec.config.hidden_size,
65
  bias=False,
66
  )
67
- # Additional initialization logic can go here
 
 
 
 
 
 
 
 
68
 
69
  def forward(
70
  self,
@@ -91,14 +148,12 @@ class NLLBLLM2Vec(PreTrainedModel):
91
  else:
92
  seq_indices, seq_offsets = indices
93
 
94
- with torch.inference_mode():
95
- nllb_outputs = self.nllb_encoder(
96
- input_ids=input_ids,
97
- attention_mask=attention_mask,
98
- )
99
- nllb_last_hidden_state = nllb_outputs.last_hidden_state
100
- nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)
101
- nllb_last_hidden_state = nllb_last_hidden_state.detach().clone()
102
  outputs = self.llm2vec(
103
  inputs_embeds=nllb_last_hidden_state,
104
  attention_mask=attention_mask,
@@ -133,14 +188,22 @@ class NLLBLLM2Vec(PreTrainedModel):
133
  self,
134
  inputs: List[str],
135
  src_lang: str = "eng_Latn",
 
136
  tokenize_kwargs: Optional[Dict[str, Any]] = None,
 
137
  ) -> torch.Tensor:
138
  """
139
  Encode input texts into embeddings.
140
 
141
  Args:
142
  inputs (List[str]): List of input texts.
143
- src_lang (str): Source language code.
 
 
 
 
 
 
144
  tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
145
  Defaults to:
146
  >> tokenize_kwargs = {
@@ -149,26 +212,54 @@ class NLLBLLM2Vec(PreTrainedModel):
149
  >> "max_length": 512,
150
  >> "return_tensors": "pt",
151
  >> }
152
-
 
 
 
 
 
153
  Returns:
154
  torch.Tensor: Mean-pooled sequence embeddings of the inputs.
155
  """
156
- if tokenize_kwargs is None:
157
- tokenize_kwargs = {
158
- "padding": True,
159
- "truncation": True,
160
- "max_length": 512,
161
- "return_tensors": "pt",
162
- }
163
 
164
  tokenizer = self.tokenizer
165
  tokenizer.src_lang = src_lang
166
  device = next(self.parameters()).device
167
- batch = tokenizer(inputs, **tokenize_kwargs).to(device)
168
- device_type = device.type # e.g., 'cuda' or 'cpu'
169
 
170
- with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
171
- return self(**batch).pooler_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  @staticmethod
174
  def _get_input_offsets(
@@ -192,12 +283,8 @@ class NLLBLLM2Vec(PreTrainedModel):
192
  non_padded_lengths = attention_mask.sum(
193
  dim=1
194
  ) # Count non-padded tokens per sequence
195
- offsets = torch.cat(
196
- [
197
- torch.tensor([0], device=attention_mask.device),
198
- non_padded_lengths.cumsum(dim=0)[:-1],
199
- ]
200
- )
201
  return input_indices, offsets
202
 
203
  @staticmethod
@@ -235,10 +322,13 @@ class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
235
  config_class = NLLBLLM2VecConfig
236
  model_type = "nllb-llm2vec"
237
  base_model_prefix = "model"
 
 
238
 
239
  def __init__(self, config):
240
  super().__init__(config)
241
  self.num_labels = config.num_labels
 
242
  self.model = NLLBLLM2Vec(config)
243
  self.score = nn.Linear(
244
  config.llm2vec_config.hidden_size, self.num_labels, bias=False
@@ -247,114 +337,29 @@ class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
247
  # Initialize weights and apply final processing
248
  self.post_init()
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  def get_input_embeddings(self):
251
  return self.model.nllb.embed_tokens
252
 
253
  def set_input_embeddings(self, value):
254
  self.model.nllb.embed_tokens = value
255
-
256
- # We need to modify the adapter config and state dict at runtime
257
- # such that adapter weights are correctly loaded from an AutoModel-suitable
258
- # adapter_config.json and adapter_config.safetensors
259
- def load_adapter(
260
- self,
261
- peft_model_id: Optional[str] = None,
262
- adapter_name: Optional[str] = None,
263
- revision: Optional[str] = None,
264
- token: Optional[str] = None,
265
- device_map: Optional[str] = "auto",
266
- max_memory: Optional[str] = None,
267
- offload_folder: Optional[str] = None,
268
- offload_index: Optional[int] = None,
269
- peft_config: Optional[Dict[str, Any]] = None,
270
- adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
271
- adapter_kwargs: Optional[Dict[str, Any]] = None,
272
- ) -> None:
273
- from peft import PeftConfig, load_peft_weights # type: ignore
274
- from transformers.utils import find_adapter_config_file
275
-
276
- if adapter_kwargs is None:
277
- adapter_kwargs = {}
278
-
279
- if "device" not in adapter_kwargs:
280
- device = (
281
- self.device
282
- if not hasattr(self, "hf_device_map")
283
- else list(self.hf_device_map.values())[0]
284
- )
285
- else:
286
- device = adapter_kwargs["device"]
287
- # To avoid PEFT errors later on with safetensors.
288
- if isinstance(device, torch.device):
289
- device = str(device)
290
-
291
- # Override token with adapter_kwargs' token
292
- if "token" in adapter_kwargs:
293
- token = adapter_kwargs["token"]
294
-
295
- if peft_model_id is None and (
296
- adapter_state_dict is None and peft_config is None
297
- ):
298
- raise ValueError(
299
- "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
300
- )
301
-
302
- if peft_config is None:
303
- assert isinstance(peft_model_id, str)
304
- adapter_config_file = find_adapter_config_file(
305
- peft_model_id,
306
- token=token,
307
- **adapter_kwargs,
308
- )
309
-
310
- if adapter_config_file is None:
311
- raise ValueError(
312
- f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
313
- "adapter model."
314
- )
315
-
316
- peft_config = cast(
317
- Dict[str, Any],
318
- PeftConfig.from_pretrained(
319
- peft_model_id,
320
- token=token,
321
- **adapter_kwargs,
322
- ),
323
- )
324
- peft_config.target_modules = [ # type: ignore
325
- "model." + module
326
- for module in peft_config.target_modules # type: ignore
327
- ]
328
-
329
- if peft_model_id is not None:
330
- adapter_state_dict = load_peft_weights(
331
- peft_model_id, token=token, device=device, **adapter_kwargs
332
- )
333
-
334
- assert isinstance(adapter_state_dict, dict)
335
-
336
- # correctly set the name
337
- processed_adapter_state_dict = {}
338
- prefix = "base_model."
339
- for key, value in adapter_state_dict.items():
340
- if key.startswith(prefix):
341
- new_key = key[len(prefix) :]
342
- else:
343
- new_key = key
344
- processed_adapter_state_dict[new_key] = value
345
- return super().load_adapter(
346
- peft_model_id=None,
347
- adapter_name=adapter_name,
348
- revision=revision,
349
- token=token,
350
- device_map=device_map,
351
- max_memory=max_memory,
352
- offload_folder=offload_folder,
353
- offload_index=offload_index,
354
- peft_config=peft_config,
355
- adapter_state_dict=processed_adapter_state_dict,
356
- adapter_kwargs=adapter_kwargs,
357
- )
358
 
359
  def forward(
360
  self,
@@ -420,10 +425,110 @@ class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
420
  output = (pooled_logits,) + transformer_outputs[1:]
421
  return ((loss,) + output) if loss is not None else output
422
 
423
- return SequenceClassifierOutputWithPast(
424
  loss=loss,
425
  hidden_states=hidden_states,
426
  logits=pooled_logits,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  )
428
 
429
 
@@ -431,275 +536,6 @@ AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
431
  AutoModelForSequenceClassification.register(
432
  NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
433
  )
434
-
435
-
436
- def repl():
437
- from transformers import AutoModel
438
-
439
- cfg = NLLBLLM2VecConfig()
440
- model = NLLBLLM2Vec(cfg)
441
-
442
- nllb = AutoModel.from_pretrained(
443
- "facebook/nllb-200-distilled-600M", torch_dtype=torch.bfloat16
444
- ).encoder
445
- # llm2vec = AutoModel.from_pretrained(
446
- # "fdschmidt93/LLM2Vec-Meta-Llama-3.1-8B-Instruct-mntp-unsup-simcse",
447
- # trust_remote_code=True,
448
- # torch_dtype=torch.bfloat16,
449
- # )
450
- llama = LlamaEncoderModel.from_pretrained("../trident-nllb-llm2vec/data/model/llm2vec_llama3-1_unsupervised/", torch_dtype=torch.bfloat16)
451
- model.nllb_encoder.load_state_dict(nllb.state_dict())
452
- model.llm2vec.load_state_dict(llama.state_dict())
453
- ckpt = torch.load("./step=20000-weights.ckpt", map_location="cpu")
454
- model.up_proj.load_state_dict({"weight": ckpt["model.up_proj.weight"]})
455
-
456
- model.save_pretrained("../weights_new")
457
-
458
- from peft.mapping import get_peft_model
459
- from peft.tuners.lora.config import LoraConfig
460
-
461
- lora_config = LoraConfig(
462
- r=16,
463
- lora_alpha=32,
464
- lora_dropout=0.0,
465
- bias="none",
466
- task_type="FEATURE_EXTRACTION",
467
- target_modules=[
468
- "llm2vec.layers.0.self_attn.q_proj",
469
- "llm2vec.layers.0.self_attn.k_proj",
470
- "llm2vec.layers.0.self_attn.v_proj",
471
- "llm2vec.layers.0.self_attn.o_proj",
472
- "llm2vec.layers.0.mlp.gate_proj",
473
- "llm2vec.layers.0.mlp.up_proj",
474
- "llm2vec.layers.0.mlp.down_proj",
475
- "llm2vec.layers.1.self_attn.q_proj",
476
- "llm2vec.layers.1.self_attn.k_proj",
477
- "llm2vec.layers.1.self_attn.v_proj",
478
- "llm2vec.layers.1.self_attn.o_proj",
479
- "llm2vec.layers.1.mlp.gate_proj",
480
- "llm2vec.layers.1.mlp.up_proj",
481
- "llm2vec.layers.1.mlp.down_proj",
482
- "llm2vec.layers.2.self_attn.q_proj",
483
- "llm2vec.layers.2.self_attn.k_proj",
484
- "llm2vec.layers.2.self_attn.v_proj",
485
- "llm2vec.layers.2.self_attn.o_proj",
486
- "llm2vec.layers.2.mlp.gate_proj",
487
- "llm2vec.layers.2.mlp.up_proj",
488
- "llm2vec.layers.2.mlp.down_proj",
489
- "llm2vec.layers.3.self_attn.q_proj",
490
- "llm2vec.layers.3.self_attn.k_proj",
491
- "llm2vec.layers.3.self_attn.v_proj",
492
- "llm2vec.layers.3.self_attn.o_proj",
493
- "llm2vec.layers.3.mlp.gate_proj",
494
- "llm2vec.layers.3.mlp.up_proj",
495
- "llm2vec.layers.3.mlp.down_proj",
496
- "llm2vec.layers.4.self_attn.q_proj",
497
- "llm2vec.layers.4.self_attn.k_proj",
498
- "llm2vec.layers.4.self_attn.v_proj",
499
- "llm2vec.layers.4.self_attn.o_proj",
500
- "llm2vec.layers.4.mlp.gate_proj",
501
- "llm2vec.layers.4.mlp.up_proj",
502
- "llm2vec.layers.4.mlp.down_proj",
503
- "llm2vec.layers.5.self_attn.q_proj",
504
- "llm2vec.layers.5.self_attn.k_proj",
505
- "llm2vec.layers.5.self_attn.v_proj",
506
- "llm2vec.layers.5.self_attn.o_proj",
507
- "llm2vec.layers.5.mlp.gate_proj",
508
- "llm2vec.layers.5.mlp.up_proj",
509
- "llm2vec.layers.5.mlp.down_proj",
510
- "llm2vec.layers.6.self_attn.q_proj",
511
- "llm2vec.layers.6.self_attn.k_proj",
512
- "llm2vec.layers.6.self_attn.v_proj",
513
- "llm2vec.layers.6.self_attn.o_proj",
514
- "llm2vec.layers.6.mlp.gate_proj",
515
- "llm2vec.layers.6.mlp.up_proj",
516
- "llm2vec.layers.6.mlp.down_proj",
517
- "llm2vec.layers.7.self_attn.q_proj",
518
- "llm2vec.layers.7.self_attn.k_proj",
519
- "llm2vec.layers.7.self_attn.v_proj",
520
- "llm2vec.layers.7.self_attn.o_proj",
521
- "llm2vec.layers.7.mlp.gate_proj",
522
- "llm2vec.layers.7.mlp.up_proj",
523
- "llm2vec.layers.7.mlp.down_proj",
524
- "llm2vec.layers.8.self_attn.q_proj",
525
- "llm2vec.layers.8.self_attn.k_proj",
526
- "llm2vec.layers.8.self_attn.v_proj",
527
- "llm2vec.layers.8.self_attn.o_proj",
528
- "llm2vec.layers.8.mlp.gate_proj",
529
- "llm2vec.layers.8.mlp.up_proj",
530
- "llm2vec.layers.8.mlp.down_proj",
531
- "llm2vec.layers.9.self_attn.q_proj",
532
- "llm2vec.layers.9.self_attn.k_proj",
533
- "llm2vec.layers.9.self_attn.v_proj",
534
- "llm2vec.layers.9.self_attn.o_proj",
535
- "llm2vec.layers.9.mlp.gate_proj",
536
- "llm2vec.layers.9.mlp.up_proj",
537
- "llm2vec.layers.9.mlp.down_proj",
538
- "llm2vec.layers.10.self_attn.q_proj",
539
- "llm2vec.layers.10.self_attn.k_proj",
540
- "llm2vec.layers.10.self_attn.v_proj",
541
- "llm2vec.layers.10.self_attn.o_proj",
542
- "llm2vec.layers.10.mlp.gate_proj",
543
- "llm2vec.layers.10.mlp.up_proj",
544
- "llm2vec.layers.10.mlp.down_proj",
545
- "llm2vec.layers.11.self_attn.q_proj",
546
- "llm2vec.layers.11.self_attn.k_proj",
547
- "llm2vec.layers.11.self_attn.v_proj",
548
- "llm2vec.layers.11.self_attn.o_proj",
549
- "llm2vec.layers.11.mlp.gate_proj",
550
- "llm2vec.layers.11.mlp.up_proj",
551
- "llm2vec.layers.11.mlp.down_proj",
552
- "llm2vec.layers.12.self_attn.q_proj",
553
- "llm2vec.layers.12.self_attn.k_proj",
554
- "llm2vec.layers.12.self_attn.v_proj",
555
- "llm2vec.layers.12.self_attn.o_proj",
556
- "llm2vec.layers.12.mlp.gate_proj",
557
- "llm2vec.layers.12.mlp.up_proj",
558
- "llm2vec.layers.12.mlp.down_proj",
559
- "llm2vec.layers.13.self_attn.q_proj",
560
- "llm2vec.layers.13.self_attn.k_proj",
561
- "llm2vec.layers.13.self_attn.v_proj",
562
- "llm2vec.layers.13.self_attn.o_proj",
563
- "llm2vec.layers.13.mlp.gate_proj",
564
- "llm2vec.layers.13.mlp.up_proj",
565
- "llm2vec.layers.13.mlp.down_proj",
566
- "llm2vec.layers.14.self_attn.q_proj",
567
- "llm2vec.layers.14.self_attn.k_proj",
568
- "llm2vec.layers.14.self_attn.v_proj",
569
- "llm2vec.layers.14.self_attn.o_proj",
570
- "llm2vec.layers.14.mlp.gate_proj",
571
- "llm2vec.layers.14.mlp.up_proj",
572
- "llm2vec.layers.14.mlp.down_proj",
573
- "llm2vec.layers.15.self_attn.q_proj",
574
- "llm2vec.layers.15.self_attn.k_proj",
575
- "llm2vec.layers.15.self_attn.v_proj",
576
- "llm2vec.layers.15.self_attn.o_proj",
577
- "llm2vec.layers.15.mlp.gate_proj",
578
- "llm2vec.layers.15.mlp.up_proj",
579
- "llm2vec.layers.15.mlp.down_proj",
580
- "llm2vec.layers.16.self_attn.q_proj",
581
- "llm2vec.layers.16.self_attn.k_proj",
582
- "llm2vec.layers.16.self_attn.v_proj",
583
- "llm2vec.layers.16.self_attn.o_proj",
584
- "llm2vec.layers.16.mlp.gate_proj",
585
- "llm2vec.layers.16.mlp.up_proj",
586
- "llm2vec.layers.16.mlp.down_proj",
587
- "llm2vec.layers.17.self_attn.q_proj",
588
- "llm2vec.layers.17.self_attn.k_proj",
589
- "llm2vec.layers.17.self_attn.v_proj",
590
- "llm2vec.layers.17.self_attn.o_proj",
591
- "llm2vec.layers.17.mlp.gate_proj",
592
- "llm2vec.layers.17.mlp.up_proj",
593
- "llm2vec.layers.17.mlp.down_proj",
594
- "llm2vec.layers.18.self_attn.q_proj",
595
- "llm2vec.layers.18.self_attn.k_proj",
596
- "llm2vec.layers.18.self_attn.v_proj",
597
- "llm2vec.layers.18.self_attn.o_proj",
598
- "llm2vec.layers.18.mlp.gate_proj",
599
- "llm2vec.layers.18.mlp.up_proj",
600
- "llm2vec.layers.18.mlp.down_proj",
601
- "llm2vec.layers.19.self_attn.q_proj",
602
- "llm2vec.layers.19.self_attn.k_proj",
603
- "llm2vec.layers.19.self_attn.v_proj",
604
- "llm2vec.layers.19.self_attn.o_proj",
605
- "llm2vec.layers.19.mlp.gate_proj",
606
- "llm2vec.layers.19.mlp.up_proj",
607
- "llm2vec.layers.19.mlp.down_proj",
608
- "llm2vec.layers.20.self_attn.q_proj",
609
- "llm2vec.layers.20.self_attn.k_proj",
610
- "llm2vec.layers.20.self_attn.v_proj",
611
- "llm2vec.layers.20.self_attn.o_proj",
612
- "llm2vec.layers.20.mlp.gate_proj",
613
- "llm2vec.layers.20.mlp.up_proj",
614
- "llm2vec.layers.20.mlp.down_proj",
615
- "llm2vec.layers.21.self_attn.q_proj",
616
- "llm2vec.layers.21.self_attn.k_proj",
617
- "llm2vec.layers.21.self_attn.v_proj",
618
- "llm2vec.layers.21.self_attn.o_proj",
619
- "llm2vec.layers.21.mlp.gate_proj",
620
- "llm2vec.layers.21.mlp.up_proj",
621
- "llm2vec.layers.21.mlp.down_proj",
622
- "llm2vec.layers.22.self_attn.q_proj",
623
- "llm2vec.layers.22.self_attn.k_proj",
624
- "llm2vec.layers.22.self_attn.v_proj",
625
- "llm2vec.layers.22.self_attn.o_proj",
626
- "llm2vec.layers.22.mlp.gate_proj",
627
- "llm2vec.layers.22.mlp.up_proj",
628
- "llm2vec.layers.22.mlp.down_proj",
629
- "llm2vec.layers.23.self_attn.q_proj",
630
- "llm2vec.layers.23.self_attn.k_proj",
631
- "llm2vec.layers.23.self_attn.v_proj",
632
- "llm2vec.layers.23.self_attn.o_proj",
633
- "llm2vec.layers.23.mlp.gate_proj",
634
- "llm2vec.layers.23.mlp.up_proj",
635
- "llm2vec.layers.23.mlp.down_proj",
636
- "llm2vec.layers.24.self_attn.q_proj",
637
- "llm2vec.layers.24.self_attn.k_proj",
638
- "llm2vec.layers.24.self_attn.v_proj",
639
- "llm2vec.layers.24.self_attn.o_proj",
640
- "llm2vec.layers.24.mlp.gate_proj",
641
- "llm2vec.layers.24.mlp.up_proj",
642
- "llm2vec.layers.24.mlp.down_proj",
643
- "llm2vec.layers.25.self_attn.q_proj",
644
- "llm2vec.layers.25.self_attn.k_proj",
645
- "llm2vec.layers.25.self_attn.v_proj",
646
- "llm2vec.layers.25.self_attn.o_proj",
647
- "llm2vec.layers.25.mlp.gate_proj",
648
- "llm2vec.layers.25.mlp.up_proj",
649
- "llm2vec.layers.25.mlp.down_proj",
650
- "llm2vec.layers.26.self_attn.q_proj",
651
- "llm2vec.layers.26.self_attn.k_proj",
652
- "llm2vec.layers.26.self_attn.v_proj",
653
- "llm2vec.layers.26.self_attn.o_proj",
654
- "llm2vec.layers.26.mlp.gate_proj",
655
- "llm2vec.layers.26.mlp.up_proj",
656
- "llm2vec.layers.26.mlp.down_proj",
657
- "llm2vec.layers.27.self_attn.q_proj",
658
- "llm2vec.layers.27.self_attn.k_proj",
659
- "llm2vec.layers.27.self_attn.v_proj",
660
- "llm2vec.layers.27.self_attn.o_proj",
661
- "llm2vec.layers.27.mlp.gate_proj",
662
- "llm2vec.layers.27.mlp.up_proj",
663
- "llm2vec.layers.27.mlp.down_proj",
664
- "llm2vec.layers.28.self_attn.q_proj",
665
- "llm2vec.layers.28.self_attn.k_proj",
666
- "llm2vec.layers.28.self_attn.v_proj",
667
- "llm2vec.layers.28.self_attn.o_proj",
668
- "llm2vec.layers.28.mlp.gate_proj",
669
- "llm2vec.layers.28.mlp.up_proj",
670
- "llm2vec.layers.28.mlp.down_proj",
671
- "llm2vec.layers.29.self_attn.q_proj",
672
- "llm2vec.layers.29.self_attn.k_proj",
673
- "llm2vec.layers.29.self_attn.v_proj",
674
- "llm2vec.layers.29.self_attn.o_proj",
675
- "llm2vec.layers.29.mlp.gate_proj",
676
- "llm2vec.layers.29.mlp.up_proj",
677
- "llm2vec.layers.29.mlp.down_proj",
678
- "llm2vec.layers.30.self_attn.q_proj",
679
- "llm2vec.layers.30.self_attn.k_proj",
680
- "llm2vec.layers.30.self_attn.v_proj",
681
- "llm2vec.layers.30.self_attn.o_proj",
682
- "llm2vec.layers.30.mlp.gate_proj",
683
- "llm2vec.layers.30.mlp.up_proj",
684
- "llm2vec.layers.30.mlp.down_proj",
685
- "llm2vec.layers.31.self_attn.q_proj",
686
- "llm2vec.layers.31.self_attn.k_proj",
687
- "llm2vec.layers.31.self_attn.v_proj",
688
- "llm2vec.layers.31.self_attn.o_proj",
689
- "llm2vec.layers.31.mlp.gate_proj",
690
- "llm2vec.layers.31.mlp.up_proj",
691
- "llm2vec.layers.31.mlp.down_proj",
692
- ],
693
- )
694
- peft_model = get_peft_model(model, lora_config)
695
- peft_model.save_pretrained("../nllb-llm2vec-saved")
696
- import json
697
-
698
- with open("./model.safetensors.index.json", "r") as f:
699
- print(json.load(f))
700
-
701
- from transformers import AutoModelForSequenceClassification
702
-
703
- model = AutoModelForSequenceClassification.from_pretrained(
704
- ".", trust_remote_code=True, device_map="cuda"
705
- )
 
1
+ import math
2
+ import warnings
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ import transformers
10
+ from packaging import version
11
+ from torch.utils.data.dataloader import DataLoader
12
+ from tqdm import tqdm
13
+ from transformers.cache_utils import Cache
14
  from transformers.modeling_outputs import (
15
  BaseModelOutputWithPooling,
16
+ ModelOutput,
17
  SequenceClassifierOutputWithPast,
18
+ TokenClassifierOutput,
19
  )
20
  from transformers.modeling_utils import PreTrainedModel
21
+ from transformers.models.auto import AutoModel, AutoModelForSequenceClassification
22
  from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
23
+ from transformers.tokenization_utils import BatchEncoding
24
 
25
  from .configuration_nllbllm2vec import NLLBLLM2VecConfig
26
  from .modeling_llama_encoder import LlamaEncoderModel
27
 
28
+ DEFAULT_TOKENIZE_KWARGS = {
29
+ "padding": True,
30
+ "truncation": True,
31
+ "max_length": 512,
32
+ "return_tensors": "pt",
33
+ }
34
+
35
+ DEFAULT_DATALOADER_KWARGS = {
36
+ "shuffle": False,
37
+ "batch_size": 32,
38
+ "pin_memory": True,
39
+ }
40
+
41
+
42
+ def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
43
+ def collate_fn(batch: list[str]) -> BatchEncoding:
44
+ return tokenizer(batch, **tokenize_kwargs)
45
+ return collate_fn
46
+
47
+
48
+ def defaulter(kwd_dict: Optional[Dict], default_dict: Dict) -> Dict:
49
+ return default_dict if kwd_dict is None else {**default_dict, **kwd_dict}
50
+
51
+
52
+ @dataclass
53
+ class SequenceClassifierOutputWithPastAndPooler(ModelOutput):
54
+ loss: Optional[torch.FloatTensor] = None
55
+ logits: torch.FloatTensor = None
56
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
57
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
58
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
59
+ pooler_output: torch.FloatTensor = None
60
+
61
 
62
  class NLLBLLM2Vec(PreTrainedModel):
63
  config_class = NLLBLLM2VecConfig
64
  model_type = "nllb-llm2vec"
65
+ _supports_flash_attn_2 = True
66
+ _supports_sdpa = True
67
  """
68
  NLLBLLM2Vec model combining NLLB and LLama encoders.
69
 
 
91
 
92
  if config is not None:
93
  super().__init__(config, *inputs, **kwargs)
94
+ # from_pretrained overwrites this after config instantiation, so we make sure it's correctly set
95
+ config.nllb_config._attn_implementation = config._attn_implementation
96
+ config.llm2vec_config._attn_implementation = config._attn_implementation
97
  self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
98
  self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
99
  self.config = config
100
+
101
  else:
102
  # Both encoders are provided
103
  self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
 
113
  self.llm2vec.config.hidden_size,
114
  bias=False,
115
  )
116
+
117
+ # TODO: update this once commit is included
118
+ min_version = "4.46.0"
119
+ if self.config.nllb_config._attn_implementation == "flash_attention_2":
120
+ if version.parse(transformers.__version__) < version.parse(min_version):
121
+ warnings.warn(
122
+ f"Installed transformers version ({transformers.__version__}) never sets NLLB-encoder dropout to `False` with FlashAttention2. See https://github.com/huggingface/transformers/pull/33844 for more info. Consider upgrading to latest to {min_version} or master.",
123
+ UserWarning,
124
+ )
125
 
126
  def forward(
127
  self,
 
148
  else:
149
  seq_indices, seq_offsets = indices
150
 
151
+ nllb_outputs = self.nllb_encoder(
152
+ input_ids=input_ids,
153
+ attention_mask=attention_mask,
154
+ )
155
+ nllb_last_hidden_state = nllb_outputs.last_hidden_state
156
+ nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)
 
 
157
  outputs = self.llm2vec(
158
  inputs_embeds=nllb_last_hidden_state,
159
  attention_mask=attention_mask,
 
188
  self,
189
  inputs: List[str],
190
  src_lang: str = "eng_Latn",
191
+ dataloader_kwargs: Optional[Dict[str, Any]] = None,
192
  tokenize_kwargs: Optional[Dict[str, Any]] = None,
193
+ collate_fn_closure: Optional[Callable] = None,
194
  ) -> torch.Tensor:
195
  """
196
  Encode input texts into embeddings.
197
 
198
  Args:
199
  inputs (List[str]): List of input texts.
200
+ src_lang (str): Source language code for the tokenizer (default: `"eng_Latn"`).
201
+ dataloader_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the dataloader excl. `collate_fn`.
202
+ Defaults to:
203
+ >> dataloader_kwargs = {
204
+ >> "shuffle": False,
205
+ >> "pin_memory": True,
206
+ >> }
207
  tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
208
  Defaults to:
209
  >> tokenize_kwargs = {
 
212
  >> "max_length": 512,
213
  >> "return_tensors": "pt",
214
  >> }
215
+ collate_fn_closure (Optional[Callable]): Closure that should return a `collate_fn`.
216
+ Defaults to:
217
+ >> def default_collate_fn_closure(tokenizer, tokenize_kwargs) -> Callable:
218
+ >> def collate_fn(batch: list[str]) -> BatchEncoding:
219
+ >> return tokenizer(batch, **tokenize_kwargs)
220
+ >> return collate_fn
221
  Returns:
222
  torch.Tensor: Mean-pooled sequence embeddings of the inputs.
223
  """
224
+ # merge user kwargs with defaults, giving priority to user kwargs
225
+ tokenize_kwargs = defaulter(tokenize_kwargs, DEFAULT_TOKENIZE_KWARGS)
226
+ dataloader_kwargs = defaulter(dataloader_kwargs, DEFAULT_DATALOADER_KWARGS)
 
 
 
 
227
 
228
  tokenizer = self.tokenizer
229
  tokenizer.src_lang = src_lang
230
  device = next(self.parameters()).device
 
 
231
 
232
+ if collate_fn_closure is None:
233
+ collate_fn = default_collate_fn_closure(tokenizer, tokenize_kwargs)
234
+ else:
235
+ collate_fn = collate_fn_closure(tokenizer, tokenize_kwargs)
236
+ assert (
237
+ "collate_fn" not in dataloader_kwargs
238
+ ), "`collate_fn` should be created via `collate_fn_closure`"
239
+ self.eval()
240
+ if len(inputs) > dataloader_kwargs.get("batch_size", 1):
241
+ dataloader = DataLoader(inputs, collate_fn=collate_fn, **dataloader_kwargs) # type: ignore
242
+ all_embeddings = []
243
+ # Iterate through the dataloader with a progress bar and autocast
244
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
245
+ for batch in tqdm(dataloader, desc="Encoding"):
246
+ # Move batch to device
247
+ batch = {k: v.to(device) for k, v in batch.items()}
248
+ # Forward pass through the model (assumes model returns embeddings)
249
+ with torch.inference_mode():
250
+ pooled_embeddings = cast(
251
+ SequenceClassifierOutputWithPastAndPooler, self(**batch)
252
+ ).pooler_output # Assuming model returns sequence embeddings
253
+ all_embeddings.append(pooled_embeddings)
254
+ # Concatenate all pooled embeddings along the batch dimension
255
+ all_embeddings = torch.cat(all_embeddings, dim=0)
256
+ else:
257
+ batch = {k: v.to(device) for k, v in collate_fn(inputs)}
258
+ with torch.inference_mode():
259
+ all_embeddings = cast(
260
+ SequenceClassifierOutputWithPastAndPooler, self(**batch)
261
+ ).pooler_output # Assuming model returns sequence embeddings
262
+ return all_embeddings
263
 
264
  @staticmethod
265
  def _get_input_offsets(
 
283
  non_padded_lengths = attention_mask.sum(
284
  dim=1
285
  ) # Count non-padded tokens per sequence
286
+ offsets = non_padded_lengths.cumsum(dim=0).roll(shifts=1)
287
+ offsets[0] = 0
 
 
 
 
288
  return input_indices, offsets
289
 
290
  @staticmethod
 
322
  config_class = NLLBLLM2VecConfig
323
  model_type = "nllb-llm2vec"
324
  base_model_prefix = "model"
325
+ _supports_flash_attn_2 = True
326
+ _supports_sdpa = True
327
 
328
  def __init__(self, config):
329
  super().__init__(config)
330
  self.num_labels = config.num_labels
331
+
332
  self.model = NLLBLLM2Vec(config)
333
  self.score = nn.Linear(
334
  config.llm2vec_config.hidden_size, self.num_labels, bias=False
 
337
  # Initialize weights and apply final processing
338
  self.post_init()
339
 
340
+ def _init_weights(self, module):
341
+ if module is self.score:
342
+ # INFO:
343
+ # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
344
+ # - Initialization needs to be redone, otherwise borked
345
+ # - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
346
+ self.score = self.score.to(torch.float32)
347
+ torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
348
+ elif isinstance(module, nn.Linear):
349
+ if isinstance(module, nn.Linear):
350
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
351
+ if module.bias is not None:
352
+ module.bias.data.zero_()
353
+ elif isinstance(module, nn.Embedding):
354
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
355
+ if module.padding_idx is not None:
356
+ module.weight.data[module.padding_idx].zero_()
357
+
358
  def get_input_embeddings(self):
359
  return self.model.nllb.embed_tokens
360
 
361
  def set_input_embeddings(self, value):
362
  self.model.nllb.embed_tokens = value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  def forward(
365
  self,
 
425
  output = (pooled_logits,) + transformer_outputs[1:]
426
  return ((loss,) + output) if loss is not None else output
427
 
428
+ return SequenceClassifierOutputWithPastAndPooler(
429
  loss=loss,
430
  hidden_states=hidden_states,
431
  logits=pooled_logits,
432
+ pooler_output=transformer_outputs.pooler_output,
433
+ )
434
+
435
+
436
+ class NLLBLLM2VecForTokenClassification(PreTrainedModel):
437
+ config_class = NLLBLLM2VecConfig
438
+ model_type = "nllb-llm2vec"
439
+ base_model_prefix = "model"
440
+ _supports_flash_attn_2 = True
441
+ _supports_sdpa = True
442
+
443
+ def __init__(self, config: NLLBLLM2VecConfig):
444
+ super().__init__(config)
445
+ self.num_labels = config.num_labels
446
+
447
+ self.model = NLLBLLM2Vec(config)
448
+ self.classifier = nn.Linear(
449
+ config.llm2vec_config.hidden_size, self.num_labels, bias=False
450
+ )
451
+
452
+ # Initialize weights and apply final processing
453
+ self.post_init()
454
+
455
+ def _init_weights(self, module):
456
+ if module is self.classifier:
457
+ # INFO:
458
+ # - critical that clf head is in float32 (NusaX perf. drops funky otherwise)
459
+ # - Initialization needs to be redone, otherwise borked
460
+ # - Use kaiming uniform, b/c Llama init (cf. `nn.Linear` below) performs worse
461
+ self.classifier = self.classifier.to(torch.float32)
462
+ torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
463
+ elif isinstance(module, nn.Linear):
464
+ if isinstance(module, nn.Linear):
465
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
466
+ if module.bias is not None:
467
+ module.bias.data.zero_()
468
+ elif isinstance(module, nn.Embedding):
469
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
470
+ if module.padding_idx is not None:
471
+ module.weight.data[module.padding_idx].zero_()
472
+
473
+ def get_input_embeddings(self):
474
+ return self.model.nllb.embed_tokens
475
+
476
+ def set_input_embeddings(self, value):
477
+ self.model.nllb.embed_tokens = value
478
+
479
+ # adapted from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification
480
+ # - removed classifier dropout
481
+ # - use F.cross_entropy
482
+ def forward(
483
+ self,
484
+ input_ids: Optional[torch.LongTensor] = None,
485
+ attention_mask: Optional[torch.FloatTensor] = None,
486
+ token_type_ids: Optional[torch.LongTensor] = None,
487
+ position_ids: Optional[torch.LongTensor] = None,
488
+ head_mask: Optional[torch.FloatTensor] = None,
489
+ inputs_embeds: Optional[torch.FloatTensor] = None,
490
+ labels: Optional[torch.LongTensor] = None,
491
+ output_attentions: Optional[bool] = None,
492
+ output_hidden_states: Optional[bool] = None,
493
+ return_dict: Optional[bool] = None,
494
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
495
+ r"""
496
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
497
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
498
+ """
499
+ return_dict = (
500
+ return_dict if return_dict is not None else self.config.use_return_dict
501
+ )
502
+
503
+ outputs = self.model(
504
+ input_ids,
505
+ attention_mask=attention_mask,
506
+ token_type_ids=token_type_ids,
507
+ position_ids=position_ids,
508
+ head_mask=head_mask,
509
+ inputs_embeds=inputs_embeds,
510
+ output_attentions=output_attentions,
511
+ output_hidden_states=output_hidden_states,
512
+ return_dict=return_dict,
513
+ )
514
+ sequence_output = outputs[0]
515
+ logits = self.classifier(sequence_output)
516
+
517
+ loss = None
518
+ if labels is not None:
519
+ # move labels to correct device to enable model parallelism
520
+ labels = labels.to(logits.device)
521
+ loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
522
+
523
+ if not return_dict:
524
+ output = (logits,) + outputs[2:]
525
+ return ((loss,) + output) if loss is not None else output
526
+
527
+ return TokenClassifierOutput(
528
+ loss=loss,
529
+ logits=logits,
530
+ hidden_states=outputs.hidden_states,
531
+ attentions=outputs.attentions,
532
  )
533
 
534
 
 
536
  AutoModelForSequenceClassification.register(
537
  NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
538
  )
539
+ AutoModelForSequenceClassification.register(
540
+ NLLBLLM2VecConfig, NLLBLLM2VecForTokenClassification
541
+ )