from typing import Any, Dict, List, Optional, Tuple, cast, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel, AutoModelForSequenceClassification
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from transformers.cache_utils import Cache
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel
class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining the NLLB and Llama encoders.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized Llama encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
def __init__(
self,
config: Optional[NLLBLLM2VecConfig] = None,
nllb_encoder: Optional[M2M100Encoder] = None,
llm2vec: Optional[LlamaEncoderModel] = None,
*inputs,
**kwargs,
):
# Ensure that either config is not None or both encoders are provided
if config is None and (nllb_encoder is None or llm2vec is None):
raise ValueError(
"Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
)
if config is not None:
super().__init__(config, *inputs, **kwargs)
self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
self.config = config
else:
# Both encoders are provided
self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
self.llm2vec = cast(LlamaEncoderModel, llm2vec)
self.config = NLLBLLM2VecConfig(
nllb_config=self.nllb_encoder.config, # type: ignore
llm2vec_config=self.llm2vec.config, # type: ignore
)
super().__init__(self.config, *inputs, **kwargs)
self.up_proj = nn.Linear(
self.nllb_encoder.config.d_model,
self.llm2vec.config.hidden_size,
bias=False,
)
# Additional initialization logic can go here
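    # Construction sketch (both paths are supported by `__init__` above; the
    # variable names are illustrative):
    #
    #   model = NLLBLLM2Vec(NLLBLLM2VecConfig())
    #   # or, wrapping pre-initialized encoders:
    #   model = NLLBLLM2Vec(nllb_encoder=nllb_encoder, llm2vec=llama_encoder)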
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> BaseModelOutputWithPooling:
"""
Forward pass of the model.
Args:
input_ids (torch.Tensor): Input token IDs.
attention_mask (torch.Tensor): Attention mask.
indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed input indices and offsets.
Returns:
BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
"""
# Compute input indices and offsets if not provided
if indices is None:
seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
else:
seq_indices, seq_offsets = indices
with torch.inference_mode():
nllb_outputs = self.nllb_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
)
nllb_last_hidden_state = nllb_outputs.last_hidden_state
nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)
nllb_last_hidden_state = nllb_last_hidden_state.detach().clone()
outputs = self.llm2vec(
inputs_embeds=nllb_last_hidden_state,
attention_mask=attention_mask,
)
pooler_output = self._mean_embedding(
hidden_states=outputs.last_hidden_state,
input_indices=seq_indices,
offsets=seq_offsets,
)
return BaseModelOutputWithPooling(
last_hidden_state=outputs.last_hidden_state,
pooler_output=pooler_output,
)
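    # Rough shape flow through `forward` (a sketch; d_nllb and d_llm denote
    # `nllb_config.d_model` and `llm2vec_config.hidden_size`):
    #
    #   input_ids, attention_mask          (batch, seq_len)
    #   -> NLLB encoder (inference_mode)   (batch, seq_len, d_nllb)
    #   -> up_proj (bias-free nn.Linear)   (batch, seq_len, d_llm)
    #   -> LlamaEncoderModel               (batch, seq_len, d_llm)  last_hidden_state
    #   -> masked mean pooling             (batch, d_llm)           pooler_output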
@property
def tokenizer(self):
"""
Get the tokenizer associated with the model.
Returns:
PreTrainedTokenizer: The tokenizer instance.
"""
if not hasattr(self, "_tokenizer"):
from transformers import AutoTokenizer
self._tokenizer = AutoTokenizer.from_pretrained(
"facebook/nllb-200-distilled-600M", padding_side="right"
)
return self._tokenizer
def encode(
self,
inputs: List[str],
src_lang: str = "eng_Latn",
tokenize_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""
Encode input texts into embeddings.
Args:
inputs (List[str]): List of input texts.
src_lang (str): Source language code.
tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:

                    tokenize_kwargs = {
                        "padding": True,
                        "truncation": True,
                        "max_length": 512,
                        "return_tensors": "pt",
                    }
Returns:
torch.Tensor: Mean-pooled sequence embeddings of the inputs.
"""
if tokenize_kwargs is None:
tokenize_kwargs = {
"padding": True,
"truncation": True,
"max_length": 512,
"return_tensors": "pt",
}
tokenizer = self.tokenizer
tokenizer.src_lang = src_lang
device = next(self.parameters()).device
batch = tokenizer(inputs, **tokenize_kwargs).to(device)
device_type = device.type # e.g., 'cuda' or 'cpu'
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
return self(**batch).pooler_output
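    # Minimal usage sketch for `encode` (the checkpoint id is a placeholder for
    # the repository this file ships with):
    #
    #   from transformers import AutoModel
    #   model = AutoModel.from_pretrained(
    #       "<repo-id-of-this-checkpoint>",
    #       trust_remote_code=True,
    #       torch_dtype=torch.bfloat16,
    #   ).eval()
    #   embeddings = model.encode(
    #       ["Hello world!", "Sentence embeddings via NLLB-LLM2Vec."],
    #       src_lang="eng_Latn",
    #   )
    #   # embeddings: (2, llm2vec_config.hidden_size), mean-pooled over non-padded tokens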
@staticmethod
def _get_input_offsets(
attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute indices and offsets for mean pooling using EmbeddingBag.
Args:
attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- input_indices: Indices of non-padded tokens in the flattened input.
- offsets: Offsets indicating the start index of each sequence in the flattened input.
"""
# Find the indices of non-padded tokens in flattened hidden_states
input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
# Compute the offsets: for each sequence, where it starts in the flattened input
non_padded_lengths = attention_mask.sum(
dim=1
) # Count non-padded tokens per sequence
offsets = torch.cat(
[
torch.tensor([0], device=attention_mask.device),
non_padded_lengths.cumsum(dim=0)[:-1],
]
)
return input_indices, offsets
@staticmethod
def _mean_embedding(
hidden_states: torch.Tensor,
input_indices: torch.Tensor,
offsets: torch.Tensor,
) -> torch.Tensor:
"""
Compute the mean of non-padded embeddings using `embedding_bag`,
properly handling padding with offsets.
Args:
hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
offsets (torch.Tensor): Offsets specifying the start of each sequence.
Returns:
torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
"""
# Flatten hidden_states to 2D: shape (batch_size * seq_len, embedding_dim)
batch_size, seq_len, embed_dim = hidden_states.shape
token_embeds = hidden_states.view(-1, embed_dim)
# Use embedding_bag with mode 'mean' and appropriate indices
return F.embedding_bag(
input=input_indices, # Indices of non-padded tokens in flattened form
weight=token_embeds, # The flattened hidden states as embedding matrix
offsets=offsets, # Offsets specifying start of each sequence
mode="mean", # Aggregation mode
)
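# Worked example (sketch) of the EmbeddingBag-based mean pooling used above:
#
#   attention_mask = torch.tensor([[1, 1, 1, 0],
#                                  [1, 1, 0, 0]])
#   indices, offsets = NLLBLLM2Vec._get_input_offsets(attention_mask)
#   # indices -> tensor([0, 1, 2, 4, 5]): positions of real tokens in the
#   #            flattened (batch_size * seq_len) view
#   # offsets -> tensor([0, 3]): bag 0 starts at indices[0], bag 1 at indices[3]
#   # F.embedding_bag(indices, flat_hidden_states, offsets, mode="mean") then
#   # averages rows {0, 1, 2} for sequence 0 and rows {4, 5} for sequence 1.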
class NLLBLLM2VecForSequenceClassification(PreTrainedModel):
config_class = NLLBLLM2VecConfig
model_type = "nllb-llm2vec"
base_model_prefix = "model"
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = NLLBLLM2Vec(config)
self.score = nn.Linear(
config.llm2vec_config.hidden_size, self.num_labels, bias=False
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
        return self.model.nllb_encoder.embed_tokens
def set_input_embeddings(self, value):
        self.model.nllb_encoder.embed_tokens = value
    # We need to modify the adapter config and state dict at runtime
    # so that adapter weights are correctly loaded from an AutoModel-compatible
    # adapter_config.json and adapter_model.safetensors
def load_adapter(
self,
peft_model_id: Optional[str] = None,
adapter_name: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[str] = None,
device_map: Optional[str] = "auto",
max_memory: Optional[str] = None,
offload_folder: Optional[str] = None,
offload_index: Optional[int] = None,
peft_config: Optional[Dict[str, Any]] = None,
adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
adapter_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
from peft import PeftConfig, load_peft_weights # type: ignore
from transformers.utils import find_adapter_config_file
if adapter_kwargs is None:
adapter_kwargs = {}
if "device" not in adapter_kwargs:
device = (
self.device
if not hasattr(self, "hf_device_map")
else list(self.hf_device_map.values())[0]
)
else:
device = adapter_kwargs["device"]
# To avoid PEFT errors later on with safetensors.
if isinstance(device, torch.device):
device = str(device)
# Override token with adapter_kwargs' token
if "token" in adapter_kwargs:
token = adapter_kwargs["token"]
if peft_model_id is None and (
adapter_state_dict is None and peft_config is None
):
raise ValueError(
"You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter."
)
if peft_config is None:
assert isinstance(peft_model_id, str)
adapter_config_file = find_adapter_config_file(
peft_model_id,
token=token,
**adapter_kwargs,
)
if adapter_config_file is None:
raise ValueError(
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
"adapter model."
)
peft_config = cast(
Dict[str, Any],
PeftConfig.from_pretrained(
peft_model_id,
token=token,
**adapter_kwargs,
),
)
peft_config.target_modules = [ # type: ignore
"model." + module
for module in peft_config.target_modules # type: ignore
]
if peft_model_id is not None:
adapter_state_dict = load_peft_weights(
peft_model_id, token=token, device=device, **adapter_kwargs
)
assert isinstance(adapter_state_dict, dict)
# correctly set the name
processed_adapter_state_dict = {}
prefix = "base_model."
for key, value in adapter_state_dict.items():
if key.startswith(prefix):
new_key = key[len(prefix) :]
else:
new_key = key
processed_adapter_state_dict[new_key] = value
return super().load_adapter(
peft_model_id=None,
adapter_name=adapter_name,
revision=revision,
token=token,
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_index=offload_index,
peft_config=peft_config,
adapter_state_dict=processed_adapter_state_dict,
adapter_kwargs=adapter_kwargs,
)
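    # Intended call pattern for the override above (repository and adapter ids
    # are placeholders):
    #
    #   clf = AutoModelForSequenceClassification.from_pretrained(
    #       "<repo-id-of-this-checkpoint>", trust_remote_code=True, num_labels=3
    #   )
    #   clf.load_adapter("<peft-adapter-repo-or-path>")
    #
    # The override prefixes `target_modules` with "model." and strips the
    # "base_model." prefix from the adapter state dict so the LoRA weights
    # attach to the wrapped `NLLBLLM2Vec` submodule (`self.model`).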
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs.pooler_output
pooled_logits = self.score(hidden_states)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
if self.num_labels == 1:
loss = F.mse_loss(pooled_logits.squeeze(), labels.squeeze())
else:
loss = F.mse_loss(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss = F.cross_entropy(
pooled_logits.view(-1, self.num_labels), labels.view(-1)
)
elif self.config.problem_type == "multi_label_classification":
loss = F.binary_cross_entropy_with_logits(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
hidden_states=hidden_states,
logits=pooled_logits,
)
AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
AutoModelForSequenceClassification.register(
NLLBLLM2VecConfig, NLLBLLM2VecForSequenceClassification
)
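# The registrations above map `NLLBLLM2VecConfig` to both wrappers, so the Auto
# classes can also construct them directly from a config (sketch):
#
#   cfg = NLLBLLM2VecConfig()
#   embedder = AutoModel.from_config(cfg)  # -> NLLBLLM2Vec
#   classifier = AutoModelForSequenceClassification.from_config(cfg)
#   # -> NLLBLLM2VecForSequenceClassification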
def repl():
    # One-off scratchpad used to assemble the released checkpoint: load the
    # pretrained sub-models, copy their weights into a fresh NLLBLLM2Vec,
    # attach the trained up-projection, and save everything for upload.
    from transformers import AutoModel

    cfg = NLLBLLM2VecConfig()
    model = NLLBLLM2Vec(cfg)

    # NLLB encoder weights come from the original translation model.
    nllb = AutoModel.from_pretrained(
        "facebook/nllb-200-distilled-600M", torch_dtype=torch.bfloat16
    ).encoder
# llm2vec = AutoModel.from_pretrained(
# "fdschmidt93/LLM2Vec-Meta-Llama-3.1-8B-Instruct-mntp-unsup-simcse",
# trust_remote_code=True,
# torch_dtype=torch.bfloat16,
# )
llama = LlamaEncoderModel.from_pretrained("../trident-nllb-llm2vec/data/model/llm2vec_llama3-1_unsupervised/", torch_dtype=torch.bfloat16)
    model.nllb_encoder.load_state_dict(nllb.state_dict())
    model.llm2vec.load_state_dict(llama.state_dict())

    # Only the up-projection weight is taken from the training checkpoint.
    ckpt = torch.load("./step=20000-weights.ckpt", map_location="cpu")
    model.up_proj.load_state_dict({"weight": ckpt["model.up_proj.weight"]})
    model.save_pretrained("../weights_new")
from peft.mapping import get_peft_model
from peft.tuners.lora.config import LoraConfig
    # LoRA over all attention and MLP projections of the 32 Llama encoder layers.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        task_type="FEATURE_EXTRACTION",
        target_modules=[
            f"llm2vec.layers.{layer}.{proj}"
            for layer in range(32)
            for proj in (
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.o_proj",
                "mlp.gate_proj",
                "mlp.up_proj",
                "mlp.down_proj",
            )
        ],
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.save_pretrained("../nllb-llm2vec-saved")

    # Sanity checks: inspect the safetensors shard index and make sure the
    # saved repository also loads through AutoModelForSequenceClassification.
    import json

    with open("./model.safetensors.index.json", "r") as f:
        print(json.load(f))

    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained(
        ".", trust_remote_code=True, device_map="cuda"
    )