training: fix type mismatch when training

#6
No description provided.
Jack477 changed pull request status to open
Jack477 changed pull request title from fix type mismatch when training to training: fix type mismatch when training

Error stack when training with fp16:

  File "/root/.cache/huggingface/modules/transformers_modules/modeling_deepseek.py", line 1252, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1570, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_deepseek.py", line 821, in forward
    q = self.q_proj(hidden_states)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1570, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
    return LinearFunctionForZeroStage3.apply(input, weight)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Float but found Half
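
The diff for this PR is not shown in the thread, so the following is only a minimal sketch of one common way to avoid this kind of error: casting the activations to the dtype of the projection weights before the linear layer is called. The helper name `match_input_dtype`, the layer sizes, and the assumption that the activations are float32 while the weights are float16 are all illustrative and are not taken from `modeling_deepseek.py`.

```python
import torch
import torch.nn as nn


def match_input_dtype(x: torch.Tensor, linear: nn.Linear) -> torch.Tensor:
    """Hypothetical helper: cast x to the dtype of linear's weight if they differ."""
    if x.dtype != linear.weight.dtype:
        x = x.to(linear.weight.dtype)
    return x


# Assumed setup: fp16 weights (as in fp16 training) and fp32 activations.
q_proj = nn.Linear(8, 8, bias=False).half()
hidden_states = torch.randn(2, 8)  # float32 by default

# After the cast, both operands of the matmul share one dtype.
hidden_states = match_input_dtype(hidden_states, q_proj)
print(hidden_states.dtype, q_proj.weight.dtype)

# The projection itself is only run on GPU here, because fp16 matmul
# support on CPU depends on the PyTorch version.
if torch.cuda.is_available():
    q = q_proj.cuda()(hidden_states.cuda())
    print(q.dtype)
```

Casting at the call site keeps the change local to the attention block; whether the merged fix does this or something else (for example, adjusting the model or DeepSpeed dtype configuration) cannot be determined from this thread.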
luofuli changed pull request status to merged
