geekyrakshit's picture
add: utility for torch backend
70d9de4
raw
history blame
715 Bytes
from typing import Any, Dict, Tuple, Union

import torch
import wandb
def get_wandb_artifact(
    artifact_name: str, artifact_type: str, get_metadata: bool = False
) -> Union[str, Tuple[str, Dict[str, Any]]]:
    """Download a Weights & Biases artifact and return where it was saved.

    If a W&B run is currently active (``wandb.run`` is truthy) the artifact
    is obtained through ``wandb.use_artifact``; otherwise it is looked up
    through a ``wandb.Api`` client. In both cases the artifact is then
    downloaded with ``artifact.download()``.

    Args:
        artifact_name: Name/path of the artifact, passed unchanged to W&B.
        artifact_type: Expected artifact type. It is forwarded to W&B in
            both the active-run and the API code paths so the two paths
            apply the same type constraint.
        get_metadata: When true, also return ``artifact.metadata``.

    Returns:
        The value returned by ``artifact.download()`` (the local artifact
        directory), or a ``(artifact_dir, artifact.metadata)`` tuple when
        ``get_metadata`` is true.
    """
    if wandb.run:
        artifact = wandb.use_artifact(artifact_name, type=artifact_type)
    else:
        # No active run: fall back to the public API client. The type is
        # forwarded here as well so it is not silently ignored.
        artifact = wandb.Api().artifact(artifact_name, type=artifact_type)

    # The download step is identical for both lookup paths.
    artifact_dir = artifact.download()

    if get_metadata:
        return artifact_dir, artifact.metadata
    return artifact_dir
def get_torch_backend():
    """Return the name of the torch device backend to use.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true; otherwise
        ``"mps"`` when ``torch.backends.mps`` reports itself as both
        available and built; otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    # MPS is only chosen when both checks pass; `and` short-circuits, so
    # is_built() is consulted only after is_available() returned true.
    mps = torch.backends.mps
    mps_usable = mps.is_available() and mps.is_built()
    return "mps" if mps_usable else "cpu"