Gradient Checkpointing Is Broken
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/4bee1a61f23b8167cda9e9edddc438b1daf31b7c/modeling_dbrx.py", line 1093, in forward
block_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.8/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 472, in checkpoint
raise ValueError(
ValueError: Unexpected keyword arguments: attention_mask,position_ids,past_key_values,output_attentions,output_router_logits,use_cache,cache_position
Are you using PyTorch 2.2.2 or an older version? Most of these problems are caused by an outdated package version.
I am using 2.2.2
accelerate 0.29.0.dev0
aiohttp 3.9.3
aiosignal 1.3.1
annotated-types 0.6.0
appdirs 1.4.4
async-timeout 4.0.3
attrs 23.2.0
bitsandbytes 0.43.0
certifi 2024.2.2
charset-normalizer 3.3.2
click 8.1.7
datasets 2.18.0
deepspeed 0.14.0+ce78a632
dill 0.3.8
docker-pycreds 0.4.0
docstring_parser 0.16
einops 0.7.0
exceptiongroup 1.2.0
filelock 3.13.3
flash-attn 2.5.6
frozenlist 1.4.1
fsspec 2024.2.0
gitdb 4.0.11
GitPython 3.1.43
hf_transfer 0.1.6
hjson 3.1.0
huggingface-hub 0.22.2
idna 3.6
iniconfig 2.0.0
Jinja2 3.1.3
markdown-it-py 3.0.0
MarkupSafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
networkx 3.1
ninja 1.11.1.1
numpy 1.24.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.99
nvidia-nvtx-cu12 12.1.105
packaging 24.0
pandas 2.0.3
peft 0.10.1.dev0
pillow 10.3.0
pip 24.0
pluggy 1.4.0
protobuf 3.20.1
psutil 5.9.8
py-cpuinfo 9.0.0
pyarrow 15.0.2
pyarrow-hotfix 0.6
pydantic 2.6.4
pydantic_core 2.16.3
Pygments 2.17.2
pynvml 11.5.0
pytest 8.1.1
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
regex 2023.12.25
requests 2.31.0
rich 13.7.1
safetensors 0.4.2
scipy 1.10.1
sentencepiece 0.2.0
sentry-sdk 1.44.0
setproctitle 1.3.3
setuptools 69.2.0
shtab 1.7.1
six 1.16.0
smmap 5.0.1
sympy 1.12
text-generation 0.7.0
tiktoken 0.6.0
tokenizers 0.15.2
tomli 2.0.1
torch 2.2.2
torchaudio 2.2.2
torchvision 0.17.2
tqdm 4.66.2
transformers 4.40.0.dev0
triton 2.2.0
trl 0.8.1
typing_extensions 4.10.0
tyro 0.7.3
tzdata 2024.1
urllib3 2.2.1
wandb 0.16.5
wheel 0.43.0
xxhash 3.4.1
yarl 1.9.4
Okay, sorry for leading you in the wrong direction. Apparently there is a problem with gradient checkpointing.
From @v2ray (https://huggingface.co/v2ray/dbrx-base-fixed):
Quote:
Error when using gradient checkpointing - Fixed by using positional arguments instead because _gradient_checkpointing_func doesn't support kwargs.
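That matches what torch.utils.checkpoint does: the reentrant implementation forwards only positional arguments and raises exactly this ValueError for any keyword arguments. A minimal standalone repro (nothing DBRX-specific, just plain PyTorch):

import torch
from torch.utils.checkpoint import checkpoint

def block(hidden_states, attention_mask=None):
    # Stand-in for a transformer block's forward.
    return hidden_states * 2

x = torch.randn(2, 4, requires_grad=True)
mask = torch.ones(2, 4)

# Keyword arguments are rejected by the reentrant checkpoint.
try:
    checkpoint(block, x, attention_mask=mask, use_reentrant=True)
except ValueError as e:
    print(e)  # Unexpected keyword arguments: attention_mask

# Passing everything positionally works.
out = checkpoint(block, x, mask, use_reentrant=True)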
Ah gotcha. Mind updating the code so I don't have to reupload?
# All arguments are passed positionally; the reentrant checkpoint rejects keyword arguments.
if self.gradient_checkpointing and self.training:
    block_outputs = self._gradient_checkpointing_func(
        block.__call__,
        hidden_states,
        causal_mask,
        position_ids,
        past_key_values,
        output_attentions,
        output_router_logits,
        use_cache,
        cache_position,
    )
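Just as a sketch of an alternative (not what was actually committed; the keyword names below are copied from the ValueError in the first traceback): the non-tensor arguments can also be bound with functools.partial, so the reentrant checkpoint only ever sees hidden_states as a positional input.

from functools import partial

if self.gradient_checkpointing and self.training:
    # Only hidden_states is passed positionally (and is the only input that needs
    # gradients); everything else is captured by the partial and handed to the block
    # as keywords, which the reentrant checkpoint never inspects.
    block_outputs = self._gradient_checkpointing_func(
        partial(
            block.__call__,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_router_logits=output_router_logits,
            use_cache=use_cache,
            cache_position=cache_position,
        ),
        hidden_states,
    )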
done :)
This leads to another issue. I am still able to fine-tune without gradient checkpointing, just not with sequences as long. I think I will likely just wait for the official implementation from Hugging Face.
File "trl_finetune.py", line 387, in
trainer.train()
File "/usr/local/lib/python3.8/dist-packages/trl/trainer/sft_trainer.py", line 360, in train
output = super().train(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1842, in train
return inner_training_loop(
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 2186, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 3121, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 3144, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 825, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/operations.py", line 813, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.8/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/peft/peft_model.py", line 1247, in forward
return self.base_model(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/peft/tuners/tuners_utils.py", line 178, in forward
return self.model.forward(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 1307, in forward
outputs = self.transformer(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 1093, in forward
block_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.8/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 482, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.8/dist-packages/torch/utils/checkpoint.py", line 261, in forward
outputs = run_function(*args)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 886, in forward
resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 668, in forward
hidden_states, attn_weights, past_key_value = self.attn(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/accelerate/hooks.py", line 166, in new_forward
output = module._old_forward(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 444, in forward
cos, sin = self.rotary_emb(value_states, position_ids)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/SinclairSchneider/dbrx-base-quantization-fixed/27110c15205682a445c8e625cbf843c22ec3eecf/modeling_dbrx.py", line 63, in forward
position_ids.shape[0], -1, 1)
AttributeError: 'NoneType' object has no attribute 'shape'
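The AttributeError means position_ids is already None when it reaches the rotary embedding. Since the same run trains fine without checkpointing, my guess (only a guess) is that the positional call does not line up with the block's forward signature, so a None argument (for example a flash-attention causal mask that is legitimately None) lands in the position_ids slot. A quick sanity check, assuming model is the already-loaded DBRX model and that the blocks live under model.transformer.blocks as in the traceback:

import inspect

# `model` is assumed to be the DBRX model loaded with trust_remote_code=True.
block = model.transformer.blocks[0]
print(inspect.signature(block.forward))
# Compare the printed parameter order against the positional order used in the
# checkpointed call (hidden_states, causal_mask, position_ids, past_key_values,
# output_attentions, output_router_logits, use_cache, cache_position); any
# mismatch silently shifts values, and None can end up where position_ids belongs.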