Update app.py
Browse files
app.py
CHANGED
@@ -40,6 +40,49 @@
|
|
40 |
# --------------------------------------------------------
|
41 |
# gradio demo executable
|
42 |
# --------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
if __name__ == '__main__':
|
44 |
parser = get_args_parser()
|
45 |
args = parser.parse_args()
|
|
|
40 |
# --------------------------------------------------------
|
41 |
# gradio demo executable
|
42 |
# --------------------------------------------------------
|
43 |
+
#!/usr/bin/env python3
|
44 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
45 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
46 |
+
#
|
47 |
+
# --------------------------------------------------------
|
48 |
+
# gradio demo executable
|
49 |
+
# --------------------------------------------------------
|
50 |
+
import os
|
51 |
+
import torch
|
52 |
+
import tempfile
|
53 |
+
from contextlib import nullcontext
|
54 |
+
|
55 |
+
from mast3r.demo import get_args_parser, main_demo
|
56 |
+
|
57 |
+
from mast3r.model import AsymmetricMASt3R
|
58 |
+
from mast3r.utils.misc import hash_md5
|
59 |
+
|
60 |
+
import matplotlib.pyplot as pl
|
61 |
+
pl.ion()
|
62 |
+
|
63 |
+
torch.backends.cuda.matmul.allow_tf32 = True # for GPU >= Ampere and PyTorch >= 1.12
|
64 |
+
|
65 |
+
import argparse
|
66 |
+
|
67 |
+
def get_args_parser():
|
68 |
+
parser = argparse.ArgumentParser(description="MASt3R Demo")
|
69 |
+
parser.add_argument("--weights", type=str, default=None, help="Path to the weights file.")
|
70 |
+
parser.add_argument("--model_name", type=str, default=None, choices=[
|
71 |
+
'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'], help="Name of the model to use.")
|
72 |
+
parser.add_argument("--device", type=str, default='cuda', help="Device to run the model on.")
|
73 |
+
parser.add_argument("--server_name", type=str, default=None, help="Server name to use.")
|
74 |
+
parser.add_argument("--local_network", action='store_true', help="Run on local network.")
|
75 |
+
parser.add_argument("--image_size", type=int, choices=[512, 224], default=512, help="Size of the images.")
|
76 |
+
parser.add_argument("--server_port", type=int, default=None, help="Port for the server.")
|
77 |
+
parser.add_argument("--tmp_dir", type=str, default=None, help="Temporary directory.")
|
78 |
+
parser.add_argument("--silent", action='store_true', help="Run silently.")
|
79 |
+
parser.add_argument("--share", action='store_true', help="Share the application.")
|
80 |
+
parser.add_argument("--gradio_delete_cache", action='store_true', help="Delete Gradio cache.")
|
81 |
+
return parser
|
82 |
+
|
83 |
+
def get_default_weights_path(model_name):
|
84 |
+
# Construct default weights path based on model_name
|
85 |
+
return f"naver/{model_name}"
|
86 |
if __name__ == '__main__':
|
87 |
parser = get_args_parser()
|
88 |
args = parser.parse_args()
|