Sapir's picture
Added tpu flash attention.
4f52f00
raw
history blame
125 Bytes
from enum import Enum
def execute_graph() -> None:
    """Flush pending device work when the TPU backend is selected.

    Calls ``xm.mark_step()`` only when the module-level ``_acceleration_type``
    compares equal to ``AccelerationType.TPU``; for any other value this
    function does nothing and returns ``None``.

    NOTE(review): ``_acceleration_type``, ``AccelerationType`` and ``xm`` are
    not defined or imported in the visible code, so calling this would raise
    ``NameError`` unless they are provided elsewhere. ``xm`` is presumably
    ``torch_xla.core.xla_model`` (whose ``mark_step`` cuts and executes the
    pending lazy graph) — confirm and add the missing definitions/imports.
    """
    running_on_tpu = _acceleration_type == AccelerationType.TPU
    if not running_on_tpu:
        # Nothing to do unless the TPU backend is active.
        return
    xm.mark_step()