Anton Bushuiev commited on
Commit
394e2eb
1 Parent(s): 9a58393

Move to device inside @spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -70,6 +70,9 @@ random.seed(0)
70
 
71
  @spaces.GPU
72
  def predict_ddg(models, ppi, muts, return_attn):
 
 
 
73
  if return_attn:
74
  ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn)
75
  return ddg_pred.detach().cpu(), attns.detach().cpu()
@@ -503,8 +506,8 @@ with app:
503
  download_from_zenodo('weights.zip')
504
 
505
  # Set device
506
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
507
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
508
 
509
  # Load models
510
  models = [
 
70
 
71
  @spaces.GPU
72
  def predict_ddg(models, ppi, muts, return_attn):
73
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
74
+ print(f"[INFO] Device on prediction: {device}")
75
+ models = [model.to(device) for model in models]
76
  if return_attn:
77
  ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn)
78
  return ddg_pred.detach().cpu(), attns.detach().cpu()
 
506
  download_from_zenodo('weights.zip')
507
 
508
  # Set device
 
509
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
510
+ print(f"[INFO] Device on start: {device}")
511
 
512
  # Load models
513
  models = [