Realcat commited on
Commit
4c930ba
·
1 Parent(s): 5f00ed5

add: caches

Browse files
common/app_class.py CHANGED
@@ -34,7 +34,6 @@ class ImageMatchingApp:
34
  )
35
  self.cfg = load_config(self.config_path)
36
  self.matcher_zoo = get_matcher_zoo(self.cfg["matcher_zoo"])
37
- # self.ransac_zoo = get_ransac_zoo(self.cfg["ransac_zoo"])
38
  self.app = None
39
  self.init_interface()
40
  # print all the keys
 
34
  )
35
  self.cfg = load_config(self.config_path)
36
  self.matcher_zoo = get_matcher_zoo(self.cfg["matcher_zoo"])
 
37
  self.app = None
38
  self.init_interface()
39
  # print all the keys
common/utils.py CHANGED
@@ -39,7 +39,7 @@ DEFAULT_MATCHING_THRESHOLD = 0.2
39
  DEFAULT_SETTING_GEOMETRY = "Homography"
40
  GRADIO_VERSION = gr.__version__.split(".")[0]
41
  MATCHER_ZOO = None
42
-
43
 
44
  def load_config(config_name: str) -> Dict[str, Any]:
45
  """
@@ -467,15 +467,23 @@ def run_matching(
467
  f"Success! Please be patient and allow for about 2-3 minutes."
468
  f" Due to CPU inference, {key} is quiet slow."
469
  )
470
-
471
  model = matcher_zoo[key]
472
  match_conf = model["matcher"]
473
  # update match config
474
  match_conf["model"]["match_threshold"] = match_threshold
475
  match_conf["model"]["max_keypoints"] = extract_max_keypoints
476
  t0 = time.time()
477
- matcher = get_model(match_conf)
 
 
 
 
 
 
 
 
478
  gr.Info(f"Loading model using: {time.time()-t0:.3f}s")
 
479
  t1 = time.time()
480
 
481
  if model["dense"]:
@@ -489,7 +497,15 @@ def run_matching(
489
  # update extract config
490
  extract_conf["model"]["max_keypoints"] = extract_max_keypoints
491
  extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
492
- extractor = get_feature_model(extract_conf)
 
 
 
 
 
 
 
 
493
  pred0 = extract_features.extract(
494
  extractor, image0, extract_conf["preprocessing"]
495
  )
@@ -499,6 +515,7 @@ def run_matching(
499
  pred = match_features.match_images(matcher, pred0, pred1)
500
  del extractor
501
  gr.Info(f"Matching images done using: {time.time()-t1:.3f}s")
 
502
  t1 = time.time()
503
  # plot images with keypoints
504
  titles = [
@@ -532,6 +549,8 @@ def run_matching(
532
  ransac_max_iter=ransac_max_iter,
533
  )
534
  gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
 
 
535
 
536
  # plot images with ransac matches
537
  titles = [
@@ -541,6 +560,8 @@ def run_matching(
541
  output_matches_ransac, num_matches_ransac = display_matches(
542
  pred, titles=titles
543
  )
 
 
544
 
545
  t1 = time.time()
546
  # plot wrapped images
@@ -552,6 +573,7 @@ def run_matching(
552
  choice_estimate_geom,
553
  )
554
  gr.Info(f"Compute geometry done using: {time.time()-t1:.3f}s")
 
555
  plt.close("all")
556
  del pred
557
  logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
 
39
  DEFAULT_SETTING_GEOMETRY = "Homography"
40
  GRADIO_VERSION = gr.__version__.split(".")[0]
41
  MATCHER_ZOO = None
42
+ models_already_loaded = {}
43
 
44
  def load_config(config_name: str) -> Dict[str, Any]:
45
  """
 
467
  f"Success! Please be patient and allow for about 2-3 minutes."
468
  f" Due to CPU inference, {key} is quiet slow."
469
  )
 
470
  model = matcher_zoo[key]
471
  match_conf = model["matcher"]
472
  # update match config
473
  match_conf["model"]["match_threshold"] = match_threshold
474
  match_conf["model"]["max_keypoints"] = extract_max_keypoints
475
  t0 = time.time()
476
+ cache_key = match_conf["model"]["name"]
477
+ if cache_key in models_already_loaded:
478
+ matcher = models_already_loaded[cache_key]
479
+ matcher.conf['max_keypoints'] = extract_max_keypoints
480
+ matcher.conf['match_threshold'] = match_threshold
481
+ logger.info(f"Loaded cached model {cache_key}")
482
+ else:
483
+ matcher = get_model(match_conf)
484
+ models_already_loaded[cache_key] = matcher
485
  gr.Info(f"Loading model using: {time.time()-t0:.3f}s")
486
+ logger.info(f"Loading model using: {time.time()-t0:.3f}s")
487
  t1 = time.time()
488
 
489
  if model["dense"]:
 
497
  # update extract config
498
  extract_conf["model"]["max_keypoints"] = extract_max_keypoints
499
  extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
500
+ cache_key = extract_conf["model"]["name"]
501
+ if cache_key in models_already_loaded:
502
+ extractor = models_already_loaded[cache_key]
503
+ extractor.conf['max_keypoints'] = extract_max_keypoints
504
+ extractor.conf['keypoint_threshold'] = keypoint_threshold
505
+ logger.info(f"Loaded cached model {cache_key}")
506
+ else:
507
+ extractor = get_feature_model(extract_conf)
508
+ models_already_loaded[cache_key] = extractor
509
  pred0 = extract_features.extract(
510
  extractor, image0, extract_conf["preprocessing"]
511
  )
 
515
  pred = match_features.match_images(matcher, pred0, pred1)
516
  del extractor
517
  gr.Info(f"Matching images done using: {time.time()-t1:.3f}s")
518
+ logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
519
  t1 = time.time()
520
  # plot images with keypoints
521
  titles = [
 
549
  ransac_max_iter=ransac_max_iter,
550
  )
551
  gr.Info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
552
+ logger.info(f"RANSAC matches done using: {time.time()-t1:.3f}s")
553
+ t1 = time.time()
554
 
555
  # plot images with ransac matches
556
  titles = [
 
560
  output_matches_ransac, num_matches_ransac = display_matches(
561
  pred, titles=titles
562
  )
563
+ gr.Info(f"Display matches done using: {time.time()-t1:.3f}s")
564
+ logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
565
 
566
  t1 = time.time()
567
  # plot wrapped images
 
573
  choice_estimate_geom,
574
  )
575
  gr.Info(f"Compute geometry done using: {time.time()-t1:.3f}s")
576
+ logger.info(f"Compute geometry done using: {time.time()-t1:.3f}s")
577
  plt.close("all")
578
  del pred
579
  logger.info(f"TOTAL time: {time.time()-t0:.3f}s")
hloc/matchers/aspanformer.py CHANGED
@@ -21,6 +21,7 @@ class ASpanFormer(BaseModel):
21
  "weights": "outdoor",
22
  "match_threshold": 0.2,
23
  "sinkhorn_iterations": 20,
 
24
  "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
25
  "model_name": "weights_aspanformer.tar",
26
  }
@@ -68,7 +69,6 @@ class ASpanFormer(BaseModel):
68
 
69
  do_system(f"cd {str(aspanformer_path)} & tar -xvf {str(tar_path)}")
70
 
71
- logger.info(f"Loading Aspanformer model...")
72
 
73
  config = get_cfg_defaults()
74
  config.merge_from_file(conf["config_path"])
@@ -86,6 +86,7 @@ class ASpanFormer(BaseModel):
86
  "state_dict"
87
  ]
88
  self.net.load_state_dict(state_dict, strict=False)
 
89
 
90
  def _forward(self, data):
91
  data_ = {
 
21
  "weights": "outdoor",
22
  "match_threshold": 0.2,
23
  "sinkhorn_iterations": 20,
24
+ "max_keypoints": 2048,
25
  "config_path": aspanformer_path / "configs/aspan/outdoor/aspan_test.py",
26
  "model_name": "weights_aspanformer.tar",
27
  }
 
69
 
70
  do_system(f"cd {str(aspanformer_path)} & tar -xvf {str(tar_path)}")
71
 
 
72
 
73
  config = get_cfg_defaults()
74
  config.merge_from_file(conf["config_path"])
 
86
  "state_dict"
87
  ]
88
  self.net.load_state_dict(state_dict, strict=False)
89
+ logger.info(f"Loaded Aspanformer model")
90
 
91
  def _forward(self, data):
92
  data_ = {
hloc/matchers/loftr.py CHANGED
@@ -10,7 +10,7 @@ class LoFTR(BaseModel):
10
  default_conf = {
11
  "weights": "outdoor",
12
  "match_threshold": 0.2,
13
- "max_num_matches": None,
14
  }
15
  required_inputs = ["image0", "image1"]
16
 
@@ -36,7 +36,7 @@ class LoFTR(BaseModel):
36
 
37
  scores = pred["confidence"]
38
 
39
- top_k = self.conf["max_num_matches"]
40
  if top_k is not None and len(scores) > top_k:
41
  keep = torch.argsort(scores, descending=True)[:top_k]
42
  pred["keypoints0"], pred["keypoints1"] = (
 
10
  default_conf = {
11
  "weights": "outdoor",
12
  "match_threshold": 0.2,
13
+ "max_keypoints": None,
14
  }
15
  required_inputs = ["image0", "image1"]
16
 
 
36
 
37
  scores = pred["confidence"]
38
 
39
+ top_k = self.conf["max_keypoints"]
40
  if top_k is not None and len(scores) > top_k:
41
  keep = torch.argsort(scores, descending=True)[:top_k]
42
  pred["keypoints0"], pred["keypoints1"] = (
hloc/matchers/roma.py CHANGED
@@ -52,7 +52,7 @@ class Roma(BaseModel):
52
  logger.info(f"Downloading the dinov2 model with `{cmd}`.")
53
  subprocess.run(cmd, check=True)
54
 
55
- logger.info(f"Loading Roma model...")
56
  # load the model
57
  weights = torch.load(model_path, map_location="cpu")
58
  dinov2_weights = torch.load(dinov2_weights, map_location="cpu")
 
52
  logger.info(f"Downloading the dinov2 model with `{cmd}`.")
53
  subprocess.run(cmd, check=True)
54
 
55
+ logger.info(f"Loading Roma model")
56
  # load the model
57
  weights = torch.load(model_path, map_location="cpu")
58
  dinov2_weights = torch.load(dinov2_weights, map_location="cpu")