Training with FA2 fails: AttributeError: 'FlashAttention2' object has no attribute 'attention_dropout'
Inference works fine with FA2, but when I try to train the model with the standard HF Trainer it fails with: AttributeError: 'FlashAttention2' object has no attribute 'attention_dropout'
here's a minimal example to reproduce the error:
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
modelpath="models/stablelm-2-1_6b"
# model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    modelpath,
    use_fast=False,
    trust_remote_code=True,
)
# dataset
dataset = load_dataset("yelp_review_full")["test"]
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=100)

tokenized_dataset = dataset.map(
    tokenize_function,
    remove_columns=dataset.column_names,
    batched=True,
)
# train
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="test_trainer"),
    train_dataset=tokenized_dataset,
)
trainer.train()
traceback:
AttributeError Traceback (most recent call last)
Cell In[1], line 38
32 # train
33 trainer = Trainer(
34 model=model,
35 args=TrainingArguments(output_dir="test_trainer"),
36 train_dataset=tokenized_dataset,
37 )
---> 38 trainer.train()
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1539, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1537 hf_hub_utils.enable_progress_bars()
1538 else:
-> 1539 return inner_training_loop(
1540 args=args,
1541 resume_from_checkpoint=resume_from_checkpoint,
1542 trial=trial,
1543 ignore_keys_for_eval=ignore_keys_for_eval,
1544 )
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1869, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
1866 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
1868 with self.accelerator.accumulate(model):
-> 1869 tr_loss_step = self.training_step(model, inputs)
1871 if (
1872 args.logging_nan_inf_filter
1873 and not is_torch_tpu_available()
1874 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
1875 ):
1876 # if loss is nan or inf simply add the average of previous logged losses
1877 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2768, in Trainer.training_step(self, model, inputs)
2765 return loss_mb.reduce_mean().detach().to(self.args.device)
2767 with self.compute_loss_context_manager():
-> 2768 loss = self.compute_loss(model, inputs)
2770 if self.args.n_gpu > 1:
2771 loss = loss.mean() # mean() to average on multi-gpu parallel training
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2791, in Trainer.compute_loss(self, model, inputs, return_outputs)
2789 else:
2790 labels = None
-> 2791 outputs = model(**inputs)
2792 # Save past state if it exists
2793 # TODO: this needs to be fixed and made cleaner later.
2794 if self.args.past_index >= 0:
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:818, in StableLMEpochForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
813 return_dict = (
814 return_dict if return_dict is not None else self.config.use_return_dict
815 )
817 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 818 outputs = self.model(
819 input_ids,
820 attention_mask=attention_mask,
821 position_ids=position_ids,
822 past_key_values=past_key_values,
823 inputs_embeds=inputs_embeds,
824 use_cache=use_cache,
825 output_attentions=output_attentions,
826 output_hidden_states=output_hidden_states,
827 return_dict=return_dict,
828 )
830 hidden_states = outputs[0]
831 logits = self.lm_head(hidden_states).float()
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:722, in StableLMEpochModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
715 layer_outputs = torch.utils.checkpoint.checkpoint(
716 create_custom_forward(decoder_layer),
717 hidden_states,
718 attention_mask,
719 position_ids,
720 )
721 else:
--> 722 layer_outputs = decoder_layer(
723 hidden_states,
724 attention_mask=attention_mask,
725 position_ids=position_ids,
726 past_key_value=past_key_value,
727 output_attentions=output_attentions,
728 use_cache=use_cache,
729 )
731 hidden_states = layer_outputs[0]
733 if use_cache:
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:513, in DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
510 hidden_states = self.input_layernorm(hidden_states)
512 # Self Attention
--> 513 hidden_states, self_attn_weights, present_key_value = self.self_attn(
514 hidden_states=hidden_states,
515 attention_mask=attention_mask,
516 position_ids=position_ids,
517 past_key_value=past_key_value,
518 output_attentions=output_attentions,
519 use_cache=use_cache,
520 )
521 hidden_states = residual + hidden_states
523 # Fully Connected
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/.cache/huggingface/modules/transformers_modules/stablelm-2-1_6b/modeling_stablelm_epoch.py:374, in FlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)
371 key_states = key_states.transpose(1, 2)
372 value_states = value_states.transpose(1, 2)
--> 374 dropout_rate = self.attention_dropout if self.training else 0.0
376 attn_output = self._flash_attention_forward(
377 query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
378 )
379 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1695, in Module.__getattr__(self, name)
1693 if name in modules:
1694 return modules[name]
-> 1695 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'FlashAttention2' object has no attribute 'attention_dropout'
transformers==4.37.0
@g-ronimo did you get the training config file, or did you create one from the 3B?
which training config file?
@g-ronimo please read this: I have opened a discussion that you liked.
I see! From what I understand, this missing .yml file contains the hyperparameters used for pretraining. I don't need that for fine-tuning the (already pretrained) model.
I don't know if it's related, but I got the following error with FlashAttention while finetuning stablelm-2-zephyr-1_6b:
RuntimeError: FlashAttention only support fp16 and bf16 data type
@interstellarninja FA2 apparently only works with bf16/fp16:
https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
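If the dtype error still appears during training even after loading the model in bf16, enabling bf16 mixed precision in the TrainingArguments may help as well, so the training compute stays in a dtype FA2 accepts. A minimal sketch, assuming the same Trainer setup as in the repro above (not verified on this exact model):

args = TrainingArguments(
    output_dir="test_trainer",
    bf16=True,  # keep training compute in bf16 so FA2 receives bf16 tensors
)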
Hey y'all! Sorry, attention dropout was never supported in modeling_stablelm_epoch and this reference slipped through. I've just added support.
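In case anyone hits this before the updated remote code propagates, here is a workaround sketch: set the missing attribute on each attention module before training. The model.model.layers path is an assumption inferred from the traceback's DecoderLayer/self_attn structure, not something confirmed for this checkpoint.

# Hypothetical workaround: give every FlashAttention2 module the attribute it reads.
# FA2 only reads attention_dropout when self.training is True, so 0.0 keeps the
# behavior identical to inference.
for layer in model.model.layers:  # assumed layout, check your model's module names
    if not hasattr(layer.self_attn, "attention_dropout"):
        layer.self_attn.attention_dropout = 0.0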
@g-ronimo Note re: "here's a minimal example to reproduce the error: .." you'll need to pass labels as well to avoid "loss is None" errors:
def tokenize_function(examples):
    inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=100)
    inputs['labels'] = inputs['input_ids'].copy()  # shift is done internally
    return inputs
Colab notebook here.
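If you'd rather not copy the labels inside the tokenize function, an equivalent sketch (assuming the same tokenizer and tokenized_dataset as in the example above) is to let a data collator create the labels for causal LM training:

from transformers import DataCollatorForLanguageModeling

# mlm=False makes the collator copy input_ids into labels (padding tokens masked to -100)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="test_trainer"),
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)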
Works now, thank you!
Sorry for the confusing minimal example with missing labels. I'm not actually training anything on Yelp reviews.