frappuccino's picture
Upload folder using huggingface_hub
0140c70 verified
raw
history blame contribute delete
No virus
370 Bytes
import torch
def detect_device():
    """Detect the device to run on and a matching default dtype.

    Backends are checked in order: CUDA first, then Apple MPS, then CPU as
    the fallback.

    Returns:
        A ``(torch.device, torch.dtype)`` pair:
        ``(cuda, torch.float16)`` when CUDA is available,
        ``(mps, torch.float16)`` when the MPS backend is available,
        otherwise ``(cpu, torch.float32)``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16

    # ``torch.backends.mps`` is missing on older PyTorch builds; look it up
    # defensively so those builds fall through to CPU instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps"), torch.float16

    # CPU fallback uses full precision.
    return torch.device("cpu"), torch.float32