Add SDPA attention
#2
by
Katsumata420
- opened
SDPA attention の追加
下記のようにすることで Attention 部分の処理が torch.matmul から torch の sdpa に変更されます(指定しない場合は eager)
model = AutoModel.from_pretrained("retrieva-jp/bert-1.3b", trust_remote_code=True, attn_implementation="sdpa")
SDPA Attention の検証結果
- SDPA Attention を利用した場合と、これまでの Attention(eager)を利用した場合で出力が大きく変わらないことを検証済み
- SDPA Attention を有効にすることで、秒間あたりのトークン処理数などが改善されることを確認済み
内部でも SDPA の有無で出力が変更しないことを確認できたためマージします
Katsumata420
changed pull request status to
merged