Gradient Checkpointing Is Broken

#2
by mallorbc - opened

File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/4bee1a61f23b8167cda9e9edddc438b1daf31b7c/modeling_dbrx.py", line 1093, in forward
block_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.8/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 472, in checkpoint
raise ValueError(
ValueError: Unexpected keyword arguments: attention_mask,position_ids,past_key_values,output_attentions,output_router_logits,use_cache,cache_position

Are you using PyTorch 2.2.2 or an older version? Most of these problems are caused by outdated package versions.

I am using 2.2.2

accelerate 0.29.0.dev0
aiohttp 3.9.3
aiosignal 1.3.1
annotated-types 0.6.0
appdirs 1.4.4
async-timeout 4.0.3
attrs 23.2.0
bitsandbytes 0.43.0
certifi 2024.2.2
charset-normalizer 3.3.2
click 8.1.7
datasets 2.18.0
deepspeed 0.14.0+ce78a632
dill 0.3.8
docker-pycreds 0.4.0
docstring_parser 0.16
einops 0.7.0
exceptiongroup 1.2.0
filelock 3.13.3
flash-attn 2.5.6
frozenlist 1.4.1
fsspec 2024.2.0
gitdb 4.0.11
GitPython 3.1.43
hf_transfer 0.1.6
hjson 3.1.0
huggingface-hub 0.22.2
idna 3.6
iniconfig 2.0.0
Jinja2 3.1.3
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
networkx 3.1
ninja 1.11.1.1
numpy 1.24.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.99
nvidia-nvtx-cu12 12.1.105
packaging 24.0
pandas 2.0.3
peft 0.10.1.dev0
pillow 10.3.0
pip 24.0
pluggy 1.4.0
protobuf 3.20.1
psutil 5.9.8
py-cpuinfo 9.0.0
pyarrow 15.0.2
pyarrow-hotfix 0.6
pydantic 2.6.4
pydantic_core 2.16.3
Pygments 2.17.2
pynvml 11.5.0
pytest 8.1.1
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
rich 13.7.1
safetensors 0.4.2
scipy 1.10.1
sentencepiece 0.2.0
sentry-sdk 1.44.0
setproctitle 1.3.3
setuptools 69.2.0
shtab 1.7.1
six 1.16.0
smmap 5.0.1
sympy 1.12
text-generation 0.7.0
tiktoken 0.6.0
tokenizers 0.15.2
tomli 2.0.1
torch 2.2.2
torchaudio 2.2.2
torchvision 0.17.2
tqdm 4.66.2
transformers 4.40.0.dev0
triton 2.2.0
trl 0.8.1
typing_extensions 4.10.0
tyro 0.7.3
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.5
wheel 0.43.0
xxhash 3.4.1
yarl 1.9.4

Okay, sorry for leading you in the wrong direction. Apparently there is a problem with gradient checkpointing.
https://huggingface.co/v2ray/dbrx-base-fixed from @v2ray
Quote:
Error when using gradient checkpointing - Fixed by using positional arguments instead because _gradient_checkpointing_func doesn't support kwargs.

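For context, here is a minimal standalone sketch (not taken from the model repo) of why this happens: PyTorch's reentrant checkpoint path, which the traceback above goes through, rejects keyword arguments outright, while the same values passed positionally go through fine.

import torch
from torch.utils.checkpoint import checkpoint

def block(hidden_states, attention_mask=None):
    # Stand-in for a transformer block; the real DBRX block takes many more arguments.
    return hidden_states if attention_mask is None else hidden_states * attention_mask

x = torch.randn(2, 4, requires_grad=True)
mask = torch.ones(2, 4)

# Reentrant checkpointing rejects kwargs, exactly as in the traceback above:
try:
    checkpoint(block, x, attention_mask=mask, use_reentrant=True)
except ValueError as e:
    print(e)  # Unexpected keyword arguments: attention_mask

# Passing the same tensor positionally works:
out = checkpoint(block, x, mask, use_reentrant=True)
out.sum().backward()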
Ah gotcha. Mind updating the code so I don't have to reupload?

        if self.gradient_checkpointing and self.training:
            block_outputs = self._gradient_checkpointing_func(
                block.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                output_router_logits,
                use_cache,
                cache_position,
            )

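An alternative on the training side, rather than editing the model code, might be to switch to non-reentrant checkpointing, which does forward keyword arguments to the wrapped module. This is only a sketch assuming the standard transformers API and an already-loaded model object; I haven't verified it against this particular checkpoint.

# Non-reentrant checkpointing forwards **kwargs to the block,
# so the "Unexpected keyword arguments" error should not be raised.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)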

done :)

This leads to another issue. I am able to fine-tune without gradient checkpointing, just not with sequences as long. I think I will likely just wait for the official implementation from Hugging Face.

File "trl_finetune.py", line 387, in
trainer.train()
File "/usr/local/lib/python3.8/dist-packages/trl/trainer/sft_trainer.py", line 360, in train
output = super().train(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1842, in train
return inner_training_loop(
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2186, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 3121, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 3144, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 825, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 813, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.8/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 1247, in forward
return self.base_model(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/peft/tuners/tuners_utils.py", line 178, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 1307, in forward
outputs = self.transformer(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 1093, in forward
block_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.8/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 482, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 261, in forward
outputs = run_function(*args)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 886, in forward
resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 668, in forward
hidden_states, attn_weights, past_key_value = self.attn(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 444, in forward
cos, sin = self.rotary_emb(value_states, position_ids)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 63, in forward
position_ids.shape[0], -1, 1)
AttributeError: 'NoneType' object has no attribute 'shape'

@mallorbc This is because the positional args are not in the order the called function expects (lines 848 and 849 should be swapped). If you use my implementation here, there shouldn't be any issue.
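To illustrate that failure mode with a generic sketch (hypothetical signature, not the actual DBRX code): when positional arguments are passed in an order that doesn't match the callee's parameter list, a later parameter such as position_ids can silently end up bound to None, and the call only blows up deeper in the stack, here inside the rotary embedding.

import torch

def block_forward(hidden_states, attention_mask=None, position_ids=None,
                  past_key_value=None):
    # Hypothetical block signature, used only to mirror the mismatch described above.
    return position_ids.shape[0]

hidden_states = torch.randn(1, 4, 8)
attention_mask = torch.ones(1, 4)
position_ids = torch.arange(4).unsqueeze(0)

# Caller passes None (the past_key_value slot) where position_ids belongs:
try:
    block_forward(hidden_states, attention_mask, None, position_ids)
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'shape'

# With the two positions swapped, the call succeeds:
block_forward(hidden_states, attention_mask, position_ids, None)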
