Anton Bushuiev commited on
Commit
470975c
·
1 Parent(s): 5b962d1

wrap whole predict into spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -45,7 +45,8 @@ from ppiref.utils.ppi import PPIPath
45
  from ppiref.utils.residue import Residue
46
  from ppiformer.tasks.node import DDGPPIformer
47
  from ppiformer.utils.api import download_weights
48
- from ppiformer.utils.api import predict_ddg as predict_ddg_
 
49
  from ppiformer.utils.torch import fill_diagonal
50
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
51
 
@@ -59,14 +60,14 @@ logging.basicConfig(
59
  random.seed(0)
60
 
61
 
62
- @spaces.GPU
63
- def predict_ddg(models, ppi, muts, return_attn):
64
- if return_attn:
65
- ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn)
66
- return ddg_pred.detach().cpu(), attns.detach().cpu()
67
- else:
68
- ddg_pred = predict_ddg_(models, ppi, muts, return_attn=return_attn)
69
- return ddg_pred.detach().cpu()
70
 
71
 
72
  def process_inputs(inputs, temp_dir):
@@ -287,6 +288,7 @@ def plot_3dmol(pdb_path, ppi_path, mut, attn, attn_mut_id=0):
287
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
288
 
289
 
 
290
  def predict(models, temp_dir, *inputs):
291
  logging.info('Starting prediction')
292
 
 
45
  from ppiref.utils.residue import Residue
46
  from ppiformer.tasks.node import DDGPPIformer
47
  from ppiformer.utils.api import download_weights
48
+ # from ppiformer.utils.api import predict_ddg as predict_ddg_
49
+ from ppiformer.utils.api import predict_ddg
50
  from ppiformer.utils.torch import fill_diagonal
51
  from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
52
 
 
60
  random.seed(0)
61
 
62
 
63
+ # @spaces.GPU
64
+ # def predict_ddg(models, ppi, muts, return_attn):
65
+ # if return_attn:
66
+ # ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn)
67
+ # return ddg_pred.detach().cpu(), attns.detach().cpu()
68
+ # else:
69
+ # ddg_pred = predict_ddg_(models, ppi, muts, return_attn=return_attn)
70
+ # return ddg_pred.detach().cpu()
71
 
72
 
73
  def process_inputs(inputs, temp_dir):
 
288
  allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""
289
 
290
 
291
+ @spaces.GPU
292
  def predict(models, temp_dir, *inputs):
293
  logging.info('Starting prediction')
294