Update modelling_RW.py
I've observed that loading this model in float16 or bfloat16 leads to a bug:
modelling_RW.py", line 289, in forward
attn_output = F.scaled_dot_product_attention(
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::Half instead.
I've encountered this error when running the main script from lm-eval-harness:
python main.py \
--model hf-causal \
--model_args pretrained=tiiuae/falcon-7b,dtype=float16 \
--tasks winogrande,piqa,hellaswag,arc_easy,arc_challenge \
--batch_size 1
This is because the output of the cos_sin method of the RotaryEmbedding class is float32.
In this commit I propose a simple fix so that the model works correctly in half precision.
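To illustrate the idea, here is a minimal sketch (simplified, with hypothetical helper names, not the exact file contents): cast the float32 cos/sin to the query dtype before applying the rotation, so the rotary output matches the dtype the model was loaded in.

```python
import torch

def rotate_half(x):
    # Standard rotary helper: split the last dim in half, swap and negate.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # cos/sin come out of cos_sin() in float32; cast them to the query dtype
    # so the returned q/k keep the model's half-precision dtype.
    cos, sin = cos.to(q.dtype), sin.to(q.dtype)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
```

With that, F.scaled_dot_product_attention receives query, key, and value in the same dtype.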
I can confirm this fix works for inference; however, it seems you have introduced a duplicate forward() method instead of updating the existing one?
Sorry, what I actually wanted was to slightly update forward so it outputs the same dtype as the given query and key. Fixed.
It worked for me yesterday; after cloning the new version I'm getting allocation-exceeded issues while running it on an 80GB A100:
return forward_call(*args, **kwargs)
File "/root/.cache/huggingface/modules/transformers_modules/falcon-40b/modelling_RW.py", line 93, in forward
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 79.32 GiB total capacity; 77.15 GiB already allocated; 832.00 KiB free; 78.13 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
My bad, I forgot to install xformers:
pip install xformers :P
Still the same error :(
@akashcollectiv this model occupies almost all of an A100's capacity when loaded on a single GPU. I think there is not enough memory left for anything but very short sequences.
> Sorry, what I actually wanted was to slightly update forward so it outputs the same dtype as the given query and key. Fixed.
Thanks! I can confirm this works as intended, and it also enables loading the model in 8-bit, which is great for inference :)
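For anyone curious, a rough sketch of how I load it in 8-bit (the model id and kwargs below are just my setup; it requires bitsandbytes and accelerate to be installed):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # my setup; adjust to the checkpoint you use
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,  # needed for the custom modelling_RW.py code
    load_in_8bit=True,       # bitsandbytes int8 quantization
    device_map="auto",
)
```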
This change will unfortunately affect bfloat16 due to numerical precision. Since bfloat16 is the only dtype we have properly validated the performance with, we should expect some degradation of model quality in fp16.
In particular, the cos/sin needs to be applied in full precision, as was done in the original code. I believe casting the result back to q.dtype in the rotary forward should be sufficient; I'll take care of it today.
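Roughly, a sketch of what that looks like (illustrative only, not the actual committed code): do the rotary math in float32 and cast the result back to the incoming dtype at the end.

```python
import torch

def rotate_half(x):
    # Standard rotary helper: split the last dim in half, swap and negate.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_fp32(q, k, cos, sin):
    # Apply cos/sin in float32 for numerical precision (important for bfloat16),
    # then cast the rotated q/k back to the dtype they came in with.
    q32, k32 = q.float(), k.float()
    q_rot = (q32 * cos) + (rotate_half(q32) * sin)
    k_rot = (k32 * cos) + (rotate_half(k32) * sin)
    return q_rot.to(q.dtype), k_rot.to(k.dtype)
```

This keeps the rotation itself in full precision while the attention still sees tensors in the dtype the model was loaded with.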
Should be fixed now, though as previously mentioned, inference with dtypes other than bfloat16 may incur model degradation.