Finetuning using LoRA

#4
by titoghose - opened

Hello,

I'm trying to finetune the model using LoRA, starting from the sample code snippet in the model card. However, only the LoRA adapters on the embeddings and in_proj layers are getting updated, even though x_proj and out_proj are also passed in target_modules.

Library versions

  • transformers==4.39.0
  • torch==2.2.0+cu121
  • peft==0.13.2
  • causal-conv1d==1.2.0.post2
  • mamba-ssm==1.2.0.post1

The code used is as follows:

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

dataset = load_dataset("Abirate/english_quotes", split="train")

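# TrainingArguments carried over from the model-card example; not used in the
# minimal gradient check below.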
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)

lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
peft_model = get_peft_model(model, peft_config=lora_config)

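# Tokenize a single quote; the labels mask every position except the last token,
# so only that position contributes to the loss.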
sample_inp = tokenizer(dataset['quote'][0:1], return_tensors="pt")
sample_inp = {k: v.to("cuda") for k, v in sample_inp.items()}
sample_inp["labels"] = sample_inp["input_ids"].clone()
sample_inp["labels"][:, :-1] = -100

peft_model = peft_model.to("cuda")
peft_model.train()
op = peft_model(**sample_inp)
op.loss.backward()

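# Inspect the LoRA parameters' gradients after backward(); the loop stops once it
# reaches layers.2, so only the first two layers are printed.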
for name, param in peft_model.named_parameters():
    if 'lora' in name.lower():
        if param.grad is not None:
            print(f'{str(param.grad.shape):<30}{name}')
        else:
            print(f'{str(param.grad):<30}{name}')
    if 'layers.2' in name:
        break

The output is as follows:

torch.Size([8, 50280])        base_model.model.backbone.embeddings.lora_embedding_A.default
torch.Size([768, 8])          base_model.model.backbone.embeddings.lora_embedding_B.default
torch.Size([8, 768])          base_model.model.backbone.layers.0.mixer.in_proj.lora_A.default.weight
torch.Size([3072, 8])         base_model.model.backbone.layers.0.mixer.in_proj.lora_B.default.weight
None                          base_model.model.backbone.layers.0.mixer.x_proj.lora_A.default.weight
None                          base_model.model.backbone.layers.0.mixer.x_proj.lora_B.default.weight
None                          base_model.model.backbone.layers.0.mixer.out_proj.lora_A.default.weight
None                          base_model.model.backbone.layers.0.mixer.out_proj.lora_B.default.weight
torch.Size([8, 768])          base_model.model.backbone.layers.1.mixer.in_proj.lora_A.default.weight
torch.Size([3072, 8])         base_model.model.backbone.layers.1.mixer.in_proj.lora_B.default.weight
None                          base_model.model.backbone.layers.1.mixer.x_proj.lora_A.default.weight
None                          base_model.model.backbone.layers.1.mixer.x_proj.lora_B.default.weight
None                          base_model.model.backbone.layers.1.mixer.out_proj.lora_A.default.weight
None                          base_model.model.backbone.layers.1.mixer.out_proj.lora_B.default.weight

I think this is because, when the mamba_ssm and causal_conv1d packages are available, the forward pass of the MambaMixer module uses mamba_inner_fn from the mamba_ssm package. This function takes only the weight tensors of the linear layers and never calls the forward method of torch.nn.Linear. So even though the selected linear layers are wrapped by peft's LoRA layers, the LoRA matrices never enter the computational graph (mamba_inner_fn only reads the underlying weight attribute) and therefore never accumulate gradients. The embeddings and in_proj do get gradients because they are still computed through the standard forward passes of torch.nn.Embedding and torch.nn.Linear rather than inside mamba_inner_fn.
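To make the failure mode concrete, here is a minimal, self-contained sketch of the same effect with a toy module (Toy, used_as_module and weight_only are made-up names, and F.linear stands in for a fused kernel that reads the raw weight tensor):

import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used_as_module = nn.Linear(4, 4)
        self.weight_only = nn.Linear(4, 4, bias=False)

    def forward(self, x):
        x = self.used_as_module(x)                   # module call -> LoRA branch runs
        return F.linear(x, self.weight_only.weight)  # raw weight access -> LoRA branch bypassed

toy = get_peft_model(Toy(), LoraConfig(r=2, target_modules=["used_as_module", "weight_only"]))
toy(torch.randn(1, 4)).sum().backward()

for name, param in toy.named_parameters():
    if "lora" in name:
        print(f"{name}: grad is {'None' if param.grad is None else 'present'}")

The adapters on used_as_module receive gradients while those on weight_only stay at None, mirroring in_proj versus x_proj/out_proj above.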

Does anyone have a solution to this problem that doesn't involve disabling the fast path that uses mamba_inner_fn?
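In case it helps anyone reproducing this: whether the fused path is taken can be confirmed via a module-level flag in transformers' Mamba implementation. Note that is_fast_path_available is an internal, version-dependent detail rather than a public API, so the exact name may differ in other transformers releases:

# Internal flag in modeling_mamba.py; True only when the mamba_ssm and
# causal_conv1d kernels were importable (version-dependent, not a public API).
from transformers.models.mamba import modeling_mamba
print("fast path available:", modeling_mamba.is_fast_path_available)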
