Flash attention 2 training unstable
Hello, thanks for the update, but it made the model's training slower. With Flash Attention the speed is good, but the model predicts NaNs after the first backpropagation, as if the optimizer broke the model at the first step.
I did not have this problem when using the following code to load and train other models:
from transformers import AutoModelForCausalLM

llm_args = {'pretrained_model_name_or_path': "MILVLG/imp-v1-3b",
            'local_files_only': local_files_only, 'cache_dir': cache_dir,
            'trust_remote_code': True, 'device_map': 'auto',
            'torch_dtype': "auto", 'attn_implementation': "flash_attention_2"}
self.llm = AutoModelForCausalLM.from_pretrained(**llm_args)
As for FlashAttention-2, our early experiments showed no significant speedup from it; we will confirm this later. We also sometimes run training and testing on V100 GPUs or in older software environments, so we keep it turned off for the sake of versatility and consistency.
Also, which training framework are you using: DeepSpeed or FSDP? And have you tried bf16 or tf32 together with gradient clipping (grad_clip)?
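The suggestion above (bf16 or tf32 plus gradient clipping) can be sketched as a single training step in plain PyTorch. This is a minimal sketch, not a DeepSpeed or FSDP configuration: `train_step` and `max_grad_norm=1.0` are hypothetical choices, and the tiny `nn.Linear` stands in for the real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Allow TF32 for matmuls/convolutions on Ampere or newer GPUs (no effect on CPU).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def train_step(model, opt, inputs, targets, max_grad_norm=1.0):
    """Hypothetical helper: one step with bf16 autocast and gradient-norm clipping."""
    device_type = "cuda" if inputs.is_cuda else "cpu"
    opt.zero_grad()
    # Run the forward pass in bfloat16 where autocast allows it.
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits = model(inputs)
    # Compute the loss in float32 for numerical stability.
    loss = F.cross_entropy(logits.float(), targets)
    loss.backward()
    # Rescale gradients so their global L2 norm is at most max_grad_norm (assumed value).
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    opt.step()
    return loss.item(), grad_norm.item()


# Tiny stand-in model and data, for illustration only.
model = nn.Linear(16, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))
loss, norm = train_step(model, opt, x, y)
```

Keeping the model weights in float32 and letting autocast choose bf16 per operation is one common way to get the speed of half precision without storing optimizer state in 16 bits.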
Interesting result; for me it is twice as fast, and I saw similar gains with other models. It seems I am using DeepSpeed (according to the tags of the Docker image I chose in Azure ML).
I get "RuntimeError: FlashAttention only support fp16 and bf16 data type" if I don't pass 'torch_dtype': "auto", so I guess I am using bf16.
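With `torch_dtype="auto"`, transformers uses the `torch_dtype` recorded in the checkpoint's config (falling back to the dtype of the first floating-point weight), so the error above only shows that the model is not in float32, not which 16-bit format it ended up in. Rather than guessing, the loaded dtype can be checked directly, e.g. via `self.llm.dtype` or by counting parameter dtypes. A minimal sketch; `param_dtypes` is a hypothetical helper and the small `nn.Linear` stands in for `self.llm`.

```python
from collections import Counter

import torch
import torch.nn as nn


def param_dtypes(model: nn.Module) -> Counter:
    """Count how many parameter tensors the model holds in each dtype."""
    return Counter(p.dtype for p in model.parameters())


# Stand-in for self.llm; in practice pass the loaded model instead.
model = nn.Linear(4, 4).to(torch.float16)
counts = param_dtypes(model)
print(counts)  # Counter({torch.float16: 2})
```

Seeing `torch.float16` here would mean the model runs in fp16 rather than bf16.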
My Dockerfile:
FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu121-py310-torch22x:biweekly.202404.2
COPY requirements.txt .
RUN pip install -r requirements.txt --no-cache-dir
RUN pip install gymnasium[classic-control]
RUN ninja --version
RUN echo $?
RUN pip install flash-attn --no-build-isolation
COPY --from=mcr.microsoft.com/azureml/o16n-base/python-assets:20230419.v1 /artifacts /var/
RUN /var/requirements/install_system_requirements.sh && \
    cp /var/configuration/rsyslog.conf /etc/rsyslog.conf && \
    cp /var/configuration/nginx.conf /etc/nginx/sites-available/app && \
    ln -sf /etc/nginx/sites-available/app /etc/nginx/sites-enabled/app && \
    rm -f /etc/nginx/sites-enabled/default
ENV SVDIR=/var/runit
ENV WORKER_TIMEOUT=400
EXPOSE 5001 8883 8888
RUN apt-get update
RUN apt-get install -y openssh-server openssh-client
ENV FLASH_ATTENTION_FORCE_BUILD=TRUE
My requirements.txt:
azureml-core==1.55.0.post2
azureml-dataset-runtime==1.55.0
azureml-defaults==1.55.0
azure-ml==0.0.1
azure-ml-component==0.9.18.post2
azureml-mlflow==1.55.0
azureml-contrib-services==1.55.0
azureml-automl-common-tools==1.55.0
torch-tb-profiler~=0.4.0
azureml-inference-server-http
inference-schema
MarkupSafe==2.1.2
regex
numpy
pybind11
gymnasium
h5py
urllib3>=1.26.18
cryptography>=41.0.4
aiohttp>=3.8.5
huggingface-hub
mnist==0.2.2
tables==3.8.0
scikit-learn==1.2.1
matplotlib==3.5
protobuf==3.20.2
packaging==23.1
seaborn==0.10.0
scipy
orjson==3.8.12
scikit-image
imgaug==0.4.0
pillow
peft==0.8.2
transformers==4.40.1
sentencepiece==0.1.97
statsmodels==0.13.5
msgspec==0.15.1
openpyxl
captum==0.6.0
tokenizers==0.19.1
accelerate==0.29.3
imgaug==0.4.0
einops==0.7.0
torch==2.2.2
torchvision
chardet==3.0.4
ninja==1.11.1.1
exllamav2==0.0.11
(With transformers 4.39.2 as you recommended I also have the problem)
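Since flash-attn is installed unpinned in the Dockerfile, and DeepSpeed is only inferred from the image tags, it may help to print the versions actually present inside the container. A minimal sketch using only the standard library; `installed_versions` is a hypothetical helper and the package list is an assumption.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(
    ["torch", "transformers", "flash-attn", "accelerate", "deepspeed"]
).items():
    print(f"{name}: {ver}")
```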
So I get the following output (only when I use Flash Attention 2):
tensor(-0.0877, device='cuda:0', grad_fn=)
10.732534408569336
tensor(nan, device='cuda:0', grad_fn=)
nan
tensor(nan, device='cuda:0', grad_fn=)
nan
tensor(nan, device='cuda:0', grad_fn=)
nan
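with this code in my training loop:
import torch.nn.functional as F

# inside the training loop; pred holds the model's logits
pred = pred[:, -len_answer:]
print(pred.mean())  # first line of each pair in the output above
loss = F.cross_entropy(pred.permute((0, 2, 1)), input_ids[:, -len_answer:],
                       reduction='sum') / 100
opt.zero_grad()
loss.backward()
opt.step()
print(loss.item())  # second line of each pair in the output above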
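To see where the NaNs first appear (logits, loss, or gradients) instead of only noticing them in the next step's output, the step can check for non-finite values before calling `opt.step()`. A minimal sketch; `guarded_step` is hypothetical, and the tiny model and random targets stand in for the real logits and `input_ids`. `torch.autograd.set_detect_anomaly(True)` is a heavier alternative that reports the backward operation producing a NaN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def guarded_step(model, opt, inputs, targets):
    """Run one step, skipping opt.step() if logits, loss or any gradient is non-finite."""
    opt.zero_grad()
    logits = model(inputs)
    if not torch.isfinite(logits).all():
        return "non-finite logits"
    loss = F.cross_entropy(logits.float(), targets, reduction="sum")
    if not torch.isfinite(loss):
        return "non-finite loss"
    loss.backward()
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return f"non-finite gradient in {name}"
    opt.step()
    return "ok"


# Tiny stand-in model and data, for illustration only.
model = nn.Linear(8, 3)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8)
y = torch.randint(0, 3, (4,))
print(guarded_step(model, opt, x, y))  # ok
```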
Maybe this could help explain the situation? https://arxiv.org/abs/2405.02803
I found the cause: I should have cast everything to torch.bfloat16 instead of torch.float16.
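A likely reason bf16 fixes this is range: fp16 tops out at 65504, while bf16 keeps roughly float32's exponent range at lower precision, so large activations or gradients that overflow to inf (and then NaN) in fp16 stay finite in bf16. If `opt` is Adam or AdamW with its default `eps=1e-8` (an assumption, since the optimizer isn't shown), that epsilon rounds to zero in fp16, which would fit NaNs appearing right after the first `opt.step()`. The sketch below illustrates both points and casts a stand-in module; passing `torch_dtype=torch.bfloat16` instead of `"auto"` in `llm_args` would be the load-time equivalent.

```python
import torch
import torch.nn as nn

# Largest finite value of each 16-bit format.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # roughly 3.39e38, close to float32's maximum

# A value just above the fp16 limit overflows in fp16 but stays finite in bf16.
big = torch.tensor(70000.0)
print(big.to(torch.float16))   # tensor(inf, dtype=torch.float16)
print(big.to(torch.bfloat16))  # finite, rounded to bf16's coarser precision

# 1e-8 is below fp16's smallest subnormal (about 6e-8), so it rounds to zero,
# while bf16 can still represent it.
print(torch.tensor(1e-8, dtype=torch.float16))       # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-8, dtype=torch.bfloat16) > 0)  # tensor(True)

# Casting an already-loaded module (a small stand-in for self.llm) to bfloat16.
model = nn.Linear(4, 4).to(torch.float16)
model = model.to(torch.bfloat16)
print(next(model.parameters()).dtype)  # torch.bfloat16
```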