zino36 commited on
Commit
7c18755
·
verified ·
1 Parent(s): 773c11a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -31
app.py CHANGED
@@ -40,54 +40,29 @@
40
  # --------------------------------------------------------
41
  # gradio demo executable
42
  # --------------------------------------------------------
43
- import os
44
- import torch
45
- import tempfile
46
- from contextlib import nullcontext
47
-
48
- from mast3r.demo import get_args_parser, main_demo
49
-
50
- from mast3r.model import AsymmetricMASt3R
51
- from mast3r.utils.misc import hash_md5
52
-
53
- import matplotlib.pyplot as pl
54
- pl.ion()
55
-
56
- torch.backends.cuda.matmul.allow_tf32 = True # for GPU >= Ampere and PyTorch >= 1.12
57
-
58
- def get_default_weights_path(model_name):
59
- # Construct default weights path based on model_name
60
- return f"naver/{model_name}"
61
-
62
  if __name__ == '__main__':
63
  parser = get_args_parser()
64
  args = parser.parse_args()
65
 
66
- # Ensure at least one of weights or model_name is provided
67
  if args.weights is None and args.model_name is None:
68
- # Provide a default model_name if both weights and model_name are not provided
69
  args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
70
 
71
- # If weights are not provided but model_name is, construct weights_path
72
  if args.weights is None:
73
- args.weights = get_default_weights_path(args.model_name)
74
-
75
- if args.server_name is not None:
76
- server_name = args.server_name
77
- else:
78
- server_name = '0.0.0.0' if args.local_network else '127.0.0.1'
79
 
80
- # Use the provided or default weights_path
 
81
  weights_path = args.weights
82
 
83
- # Load the model with the weights_path
84
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
85
  chkpt_tag = hash_md5(weights_path)
86
 
87
  def get_context(tmp_dir):
88
  return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
89
  else nullcontext(tmp_dir)
90
-
91
  with get_context(args.tmp_dir) as tmpdirname:
92
  cache_path = os.path.join(tmpdirname, chkpt_tag)
93
  os.makedirs(cache_path, exist_ok=True)
 
40
  # --------------------------------------------------------
41
  # gradio demo executable
42
  # --------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  if __name__ == '__main__':
44
  parser = get_args_parser()
45
  args = parser.parse_args()
46
 
47
+ # Set default values for required arguments
48
  if args.weights is None and args.model_name is None:
 
49
  args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
50
 
 
51
  if args.weights is None:
52
+ args.weights = f"naver/{args.model_name}"
 
 
 
 
 
53
 
54
+ # Rest of the code for setting up the server and loading the model
55
+ server_name = args.server_name or ('0.0.0.0' if args.local_network else '127.0.0.1')
56
  weights_path = args.weights
57
 
58
+ # Load the model
59
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
60
  chkpt_tag = hash_md5(weights_path)
61
 
62
  def get_context(tmp_dir):
63
  return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
64
  else nullcontext(tmp_dir)
65
+
66
  with get_context(args.tmp_dir) as tmpdirname:
67
  cache_path = os.path.join(tmpdirname, chkpt_tag)
68
  os.makedirs(cache_path, exist_ok=True)