Spaces:
Sleeping
Sleeping
""" | |
Utility functions for interpretability. | |
""" | |
import torch | |
def get_torch_device() -> str:
    """
    Return the name of the preferred torch device available on this machine.

    Preference order: the first CUDA GPU, then the MPS backend
    (Apple Silicon), then the CPU.

    Returns:
        One of ``"cuda:0"``, ``"mps"`` or ``"cpu"``. Only the device string
        is returned; no dtype is chosen here.
    """
    if torch.cuda.is_available():
        return "cuda:0"

    # Look the MPS backend up defensively: torch builds without
    # ``torch.backends.mps`` should fall through to the CPU rather than
    # raise AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():  # for Apple Silicon
        return "mps"

    return "cpu"