Shape mismatch error for F.scaled_dot_product_attention in modelling_RW.py

#27
by xtliu - opened

For some reason, the Attention class produces a shape mismatch error when it reaches the branch below. It complains because self.num_heads is not equal to self.num_kv, which happens since the model config's default value for num_kv is 1.

    if alibi is None:
        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim).expand(-1, self.num_heads, -1, -1)
        attn_output = F.scaled_dot_product_attention(
            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
        )
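
To make the mismatch concrete, here is a minimal sketch of the shapes this branch produces. The dimension values are made up for illustration; falcon-7b-style multi-query attention with num_kv = 1 is assumed:

    import torch

    # Illustrative sizes only: many query heads sharing a single KV head.
    batch_size, num_heads, num_kv, seq_len, head_dim = 1, 71, 1, 16, 64

    query_layer_ = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key_layer_ = torch.randn(batch_size, num_kv, seq_len, head_dim)
    value_layer_ = torch.randn(batch_size, num_kv, seq_len, head_dim).expand(-1, num_heads, -1, -1)

    print(query_layer_.shape)  # torch.Size([1, 71, 16, 64])
    print(key_layer_.shape)    # torch.Size([1, 1, 16, 64])   <- only num_kv heads
    print(value_layer_.shape)  # torch.Size([1, 71, 16, 64])

value_layer_ gets expanded to num_heads, but key_layer_ is left with num_kv heads, so the head dimensions handed to F.scaled_dot_product_attention disagree whenever num_kv != num_heads.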

I'm running into the same issue. Did you find a way to solve it?

I just used the expand operation in the code piece above to make the shapes match; a sketch is below.
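
For reference, a sketch of that fix applied to the branch: expand key_layer_ the same way value_layer_ already is, so query, key, and value all carry self.num_heads along dim 1.

    if alibi is None:
        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
        # Broadcast the single KV head across all query heads so every
        # tensor passed to SDPA has self.num_heads along dim 1.
        key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim).expand(-1, self.num_heads, -1, -1)
        value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim).expand(-1, self.num_heads, -1, -1)
        attn_output = F.scaled_dot_product_attention(
            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
        )

Note that expand only broadcasts size-1 dimensions, so this works when num_kv == 1; for grouped KV heads (1 < num_kv < num_heads) you would need something like key_layer_.repeat_interleave(self.num_heads // self.num_kv, dim=1) instead.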
