"""Attention backend configuration.

Holds the module-level ``BACKEND`` and ``DEBUG`` settings, initialises them
from the ``ATTN_BACKEND`` / ``ATTN_DEBUG`` environment variables at import
time, and exposes setters to change them afterwards.
"""
from typing import *

# Backend names accepted from the ATTN_BACKEND environment variable.
_SUPPORTED_BACKENDS = ('xformers', 'flash_attn', 'sdpa', 'naive')

BACKEND = 'flash_attn'
DEBUG = False


def __from_env():
    """Initialise ``BACKEND`` and ``DEBUG`` from the environment.

    ``ATTN_BACKEND`` replaces ``BACKEND`` only when it names one of the
    supported backends; any other value is reported and the current value is
    kept. ``ATTN_DEBUG`` enables ``DEBUG`` only when it is exactly ``'1'``;
    any other set value disables it. The selected backend is always printed.
    """
    import os

    global BACKEND
    global DEBUG

    env_attn_backend = os.environ.get('ATTN_BACKEND')
    env_attn_debug = os.environ.get('ATTN_DEBUG')

    if env_attn_backend is not None:
        if env_attn_backend in _SUPPORTED_BACKENDS:
            BACKEND = env_attn_backend
        else:
            # Previously an unrecognised value was dropped silently, which
            # made typos in ATTN_BACKEND hard to notice.
            print(
                f"[ATTENTION] Ignoring unsupported ATTN_BACKEND "
                f"{env_attn_backend!r}; expected one of {list(_SUPPORTED_BACKENDS)}"
            )
    if env_attn_debug is not None:
        DEBUG = env_attn_debug == '1'

    print(f"[ATTENTION] Using backend: {BACKEND}")


__from_env()


def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
    """Set the module-level ``BACKEND``.

    The value is stored as given; no validation is performed here.

    Args:
        backend: Name of the attention backend to use.
    """
    global BACKEND
    BACKEND = backend


def set_debug(debug: bool):
    """Set the module-level ``DEBUG`` flag.

    Args:
        debug: New value for ``DEBUG``.
    """
    global DEBUG
    DEBUG = debug


# NOTE(review): these imports sit after the configuration code in the
# original; presumably the submodules read BACKEND/DEBUG at import time --
# confirm before moving them to the top of the file.
from .full_attn import *
from .modules import *