from enum import Enum


def execute_graph() -> None:
    """Trigger graph execution on the active accelerator, if one needs it.

    When the module-level ``_acceleration_type`` equals
    ``AccelerationType.TPU``, this calls ``xm.mark_step()``. For any other
    acceleration type it returns without doing anything.
    """
    # Guard clause: only the TPU backend needs an explicit step marker here.
    if _acceleration_type != AccelerationType.TPU:
        return
    # NOTE(review): `xm` is presumably `torch_xla.core.xla_model`, whose
    # mark_step() cuts and executes the pending lazy graph — confirm against
    # the file's imports.
    xm.mark_step()