Anton Bushuiev commited on
Commit
21677dc
1 Parent(s): f44e701

move model ladoing back

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -393,21 +393,6 @@ def update_plot(dropdown, dropdown_choices_to_plot_args):
393
  return plot_3dmol(*dropdown_choices_to_plot_args[dropdown])
394
 
395
 
396
- # Set device
397
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
398
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
399
-
400
- # Load models
401
- models = [
402
- DDGPPIformer.load_from_checkpoint(
403
- PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
404
- map_location=torch.device('cpu')
405
- ).eval()
406
- for i in range(3)
407
- ]
408
- models = [model.to(device) for model in models]
409
-
410
-
411
  app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
412
  with app:
413
 
@@ -502,6 +487,20 @@ with app:
502
  # Download weights from Zenodo
503
  download_weights()
504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  # Create temporary directory for storing downloaded PDBs and extracted PPIs
506
  temp_dir_obj = tempfile.TemporaryDirectory()
507
  temp_dir = Path(temp_dir_obj.name)
 
393
  return plot_3dmol(*dropdown_choices_to_plot_args[dropdown])
394
 
395
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
397
  with app:
398
 
 
487
  # Download weights from Zenodo
488
  download_weights()
489
 
490
+ # Set device
491
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
492
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
493
+
494
+ # Load models
495
+ models = [
496
+ DDGPPIformer.load_from_checkpoint(
497
+ PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
498
+ map_location=torch.device('cpu')
499
+ ).eval()
500
+ for i in range(3)
501
+ ]
502
+ models = [model.to(device) for model in models]
503
+
504
  # Create temporary directory for storing downloaded PDBs and extracted PPIs
505
  temp_dir_obj = tempfile.TemporaryDirectory()
506
  temp_dir = Path(temp_dir_obj.name)