geekyrakshit commited on
Commit
af715dd
·
1 Parent(s): 335e8a6

update: get_torch_backend

Browse files
Files changed (1) hide show
  1. medrag_multi_modal/utils.py +2 -1
medrag_multi_modal/utils.py CHANGED
@@ -22,7 +22,8 @@ def get_wandb_artifact(
22
 
23
  def get_torch_backend():
24
  if torch.cuda.is_available():
25
- return "cuda"
 
26
  if torch.backends.mps.is_available():
27
  if torch.backends.mps.is_built():
28
  return "mps"
 
22
 
23
  def get_torch_backend():
24
  if torch.cuda.is_available():
25
+ if torch.backends.cuda.is_built():
26
+ return "cuda"
27
  if torch.backends.mps.is_available():
28
  if torch.backends.mps.is_built():
29
  return "mps"