arXiv:2412.05496

Flex Attention: A Programming Model for Generating Optimized Attention Kernels

Published on Dec 7, 2024

Abstract

Over the past 7 years, attention has become one of the most important primitives in deep learning. The primary approach to optimizing attention is FlashAttention, which fuses the operation into a single kernel, drastically improving both runtime and memory consumption. However, the importance of FlashAttention combined with its monolithic nature poses a problem for researchers aiming to try new attention variants -- a "software lottery". This problem is exacerbated by the difficulty of writing efficient fused attention kernels, which resist traditional compiler-based approaches. We introduce FlexAttention, a novel compiler-driven programming model that allows implementing the majority of attention variants in a few lines of idiomatic PyTorch code. We demonstrate that many existing attention variants (e.g. Alibi, Document Masking, PagedAttention, etc.) can be implemented via FlexAttention, and that we achieve competitive performance compared to handwritten kernels. Finally, we demonstrate how FlexAttention allows for easy composition of attention variants, solving the combinatorial explosion of attention variants.
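
The abstract describes implementing attention variants in a few lines of idiomatic PyTorch. Below is a minimal sketch of what that looks like with the flex_attention API shipped in recent PyTorch releases (torch.nn.attention.flex_attention, available since PyTorch 2.5); the tensor shapes, the ALiBi slope construction, and the specific causal-plus-ALiBi composition are illustrative assumptions, not examples taken from the paper.

```python
# Minimal sketch (not from the paper) of the FlexAttention programming model,
# assuming PyTorch >= 2.5 where torch.nn.attention.flex_attention is available.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 8, 1024, 64  # illustrative shapes
device = "cuda"
q, k, v = (torch.randn(B, H, S, D, device=device, dtype=torch.float16) for _ in range(3))

# ALiBi: add a per-head linear bias proportional to the query/key distance.
# The score_mod callback is applied to each attention score before softmax.
alibi_slopes = torch.exp2(-torch.arange(1, H + 1, device=device, dtype=torch.float32))

def alibi_score_mod(score, b, h, q_idx, kv_idx):
    return score + alibi_slopes[h] * (kv_idx - q_idx)

# Composition with a masking variant: a mask_mod returns True for positions
# that may attend and is turned into a sparse block mask ahead of time.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device)

# Causal ALiBi attention: the two variants compose without a new handwritten kernel.
out = flex_attention(q, k, v, score_mod=alibi_score_mod, block_mask=block_mask)
```

In practice the call is typically wrapped in torch.compile so that the score_mod and mask_mod callbacks are fused into a single generated kernel rather than evaluated eagerly.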
