Training with FA2 fails: AttributeError: 'FlashAttention2' object has no attribute 'attention_dropout'

#4
by g-ronimo - opened

Inference works fine with FA2, but when I try to train the model with the standard HF Trainer, it fails with AttributeError: 'FlashAttention2' object has no attribute 'attention_dropout'

here's a minimal example to reproduce the error:

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch

modelpath="models/stablelm-2-1_6b"

# model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    modelpath,    
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    modelpath, 
    use_fast=False,
    trust_remote_code=True,
)    

# dataset
dataset = load_dataset("yelp_review_full")["test"]

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=100)

tokenized_dataset = dataset.map(
    tokenize_function, 
    remove_columns=dataset.column_names,
    batched=True)

# train
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="test_trainer"),
    train_dataset=tokenized_dataset,
)
trainer.train()

traceback:

AttributeError                            Traceback (most recent call last)
Cell In[1], line 38
     32 # train
     33 trainer = Trainer(
     34     model=model,
     35     args=TrainingArguments(output_dir="test_trainer"),
     36     train_dataset=tokenized_dataset,
     37 )
---> 38 trainer.train()

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1537         hf_hub_utils.enable_progress_bars()
   1538 else:
-> 1539     return inner_training_loop(
   1540         args=args,
   1541         resume_from_checkpoint=resume_from_checkpoint,
   1542         trial=trial,
   1543         ignore_keys_for_eval=ignore_keys_for_eval,
   1544     )

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1869, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   1866     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   1868 with self.accelerator.accumulate(model):
-> 1869     tr_loss_step = self.training_step(model, inputs)
   1871 if (
   1872     args.logging_nan_inf_filter
   1873     and not is_torch_tpu_available()
   1874     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   1875 ):
   1876     # if loss is nan or inf simply add the average of previous logged losses
   1877     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2768, in Trainer.training_step(self, model, inputs)
   2765     return loss_mb.reduce_mean().detach().to(self.args.device)
   2767 with self.compute_loss_context_manager():
-> 2768     loss = self.compute_loss(model, inputs)
   2770 if self.args.n_gpu > 1:
   2771     loss = loss.mean()  # mean() to average on multi-gpu parallel training

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2791, in Trainer.compute_loss(self, model, inputs, return_outputs)
   2789 else:
   2790     labels = None
-> 2791 outputs = model(**inputs)
   2792 # Save past state if it exists
   2793 # TODO: this needs to be fixed and made cleaner later.
   2794 if self.args.past_index >= 0:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:818, in StableLMEpochForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    813 return_dict = (
    814     return_dict if return_dict is not None else self.config.use_return_dict
    815 )
    817 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 818 outputs = self.model(
    819     input_ids,
    820     attention_mask=attention_mask,
    821     position_ids=position_ids,
    822     past_key_values=past_key_values,
    823     inputs_embeds=inputs_embeds,
    824     use_cache=use_cache,
    825     output_attentions=output_attentions,
    826     output_hidden_states=output_hidden_states,
    827     return_dict=return_dict,
    828 )
    830 hidden_states = outputs[0]
    831 logits = self.lm_head(hidden_states).float()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:722, in StableLMEpochModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    715     layer_outputs = torch.utils.checkpoint.checkpoint(
    716         create_custom_forward(decoder_layer),
    717         hidden_states,
    718         attention_mask,
    719         position_ids,
    720     )
    721 else:
--> 722     layer_outputs = decoder_layer(
    723         hidden_states,
    724         attention_mask=attention_mask,
    725         position_ids=position_ids,
    726         past_key_value=past_key_value,
    727         output_attentions=output_attentions,
    728         use_cache=use_cache,
    729     )
    731 hidden_states = layer_outputs[0]
    733 if use_cache:

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:513, in DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    510 hidden_states = self.input_layernorm(hidden_states)
    512 # Self Attention
--> 513 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    514     hidden_states=hidden_states,
    515     attention_mask=attention_mask,
    516     position_ids=position_ids,
    517     past_key_value=past_key_value,
    518     output_attentions=output_attentions,
    519     use_cache=use_cache,
    520 )
    521 hidden_states = residual + hidden_states
    523 # Fully Connected

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:374, in FlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
    371 key_states = key_states.transpose(1, 2)
    372 value_states = value_states.transpose(1, 2)
--> 374 dropout_rate = self.attention_dropout if self.training else 0.0
    376 attn_output = self._flash_attention_forward(
    377     query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
    378 )
    379 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1695, in Module.__getattr__(self, name)
   1693     if name in modules:
   1694         return modules[name]
-> 1695 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'FlashAttention2' object has no attribute 'attention_dropout'
transformers==4.37.0
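
For anyone hitting this before a fix ships: the failing line in modeling_stablelm_epoch.py (dropout_rate = self.attention_dropout if self.training else 0.0) only dereferences the attribute when self.training is True, which is why inference works but training fails. A possible stopgap, assuming the module layout model.model.layers[i].self_attn inferred from the traceback above, is to set the missing attribute yourself after loading the model:

# Hypothetical stopgap, not an official fix: give each FlashAttention2
# module the attribute its forward() expects. 0.0 disables attention
# dropout, matching what the module already does in eval mode.
for layer in model.model.layers:
    if not hasattr(layer.self_attn, "attention_dropout"):
        layer.self_attn.attention_dropout = 0.0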

@g-ronimo did you get the training config file, or did you create one from the 3B?

which training config file?

@g-ronimo please read this: I have opened a discussion (the one that you liked)

I see! From what I understand, the missing .yml file contains the hyperparameters used for pretraining. I don't need that for fine-tuning the (already pretrained) model.

I don't know if it's related but i got the following error with FlashAttention while finetuning the stablelm-2-zephyr-1_6b:

RuntimeError: FlashAttention only support fp16 and bf16 data type

@interstellarninja FA2 apparently only works with bf16/fp16, see:

https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,  # e.g. "stabilityai/stablelm-2-zephyr-1_6b"
    torch_dtype=torch.bfloat16,  # FA2 requires fp16 or bf16 weights
    attn_implementation="flash_attention_2",
)
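
If the model was loaded without torch_dtype it will sit in float32 and FA2 will throw the error above mid-run. A small sanity check before training avoids that (a sketch; model.dtype reflects the dtype of the loaded weights):

import torch

# FA2 kernels only accept fp16/bf16; cast the weights if needed.
if model.dtype not in (torch.float16, torch.bfloat16):
    model = model.to(torch.bfloat16)
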
Stability AI org • edited Jan 23

Hey y'all! Sorry, attention dropout was never supported in modeling_stablelm_epoch and this reference slipped through. I've just added support.
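
(For context, a rough sketch of what such a fix typically looks like; this is an illustration, not the actual diff: the attention module stores the dropout probability from the config at construction time, so FlashAttention2, which subclasses it, finds the attribute during training.)

import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Illustrative only: read the dropout probability from the config
        # (defaulting to 0.0) so self.attention_dropout always exists.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)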

@g-ronimo Note re:

here's a minimal example to reproduce the error: ..

You'll need to pass labels as well to avoid "loss is None" errors:

def tokenize_function(examples):
    inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=100)
    inputs['labels'] = inputs['input_ids'].copy()  # Shift is done internally
    return inputs
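
(An equivalent route, if you'd rather not copy the ids by hand, is to let a data collator create the labels. A sketch using the stock causal-LM collator from transformers: with mlm=False it copies input_ids into labels and masks padding with -100, so the loss ignores pad tokens, which the plain copy above would train on.)

from transformers import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="test_trainer"),
    train_dataset=tokenized_dataset,
    data_collator=collator,  # fills in labels at batch time
)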

Colab notebook here.

works now! thank you.
sorry for the confusing minimal example with missing labels. i'm not actually training anything on yelp reviews 😆

g-ronimo changed discussion status to closed
