Spaces:
Running
Running
Realcat
commited on
Commit
β’
9cde3b4
1
Parent(s):
d64a873
update: roma
Browse filesThis view is limited to 50 files because it contains too many changes. Β
See raw diff
- build_docker.sh +1 -0
- hloc/matchers/roma.py +3 -1
- third_party/{Roma β RoMa}/.gitignore +0 -0
- third_party/RoMa/LICENSE +21 -0
- third_party/{Roma β RoMa}/README.md +44 -15
- third_party/{Roma β RoMa}/assets/sacre_coeur_A.jpg +0 -0
- third_party/{Roma β RoMa}/assets/sacre_coeur_B.jpg +0 -0
- third_party/RoMa/assets/toronto_A.jpg +3 -0
- third_party/RoMa/assets/toronto_B.jpg +3 -0
- third_party/{Roma β RoMa}/data/.gitignore +0 -0
- third_party/RoMa/demo/demo_3D_effect.py +46 -0
- third_party/{Roma β RoMa}/demo/demo_fundamental.py +5 -10
- third_party/{Roma β RoMa}/demo/demo_match.py +11 -14
- third_party/RoMa/demo/demo_match_opencv_sift.py +43 -0
- third_party/RoMa/demo/gif/.gitignore +2 -0
- third_party/{Roma β RoMa}/pretrained/dinov2_vitl14_pretrain.pth +0 -0
- third_party/{Roma β RoMa}/pretrained/roma_outdoor.pth +0 -0
- third_party/{Roma β RoMa}/requirements.txt +1 -1
- third_party/{Roma β RoMa}/roma/__init__.py +2 -2
- third_party/{Roma β RoMa}/roma/benchmarks/__init__.py +0 -0
- third_party/{Roma β RoMa}/roma/benchmarks/hpatches_sequences_homog_benchmark.py +7 -5
- third_party/{Roma β RoMa}/roma/benchmarks/megadepth_dense_benchmark.py +11 -27
- third_party/{Roma β RoMa}/roma/benchmarks/megadepth_pose_estimation_benchmark.py +19 -49
- third_party/{Roma β RoMa}/roma/benchmarks/scannet_benchmark.py +30 -27
- third_party/{Roma β RoMa}/roma/checkpointing/__init__.py +0 -0
- third_party/{Roma β RoMa}/roma/checkpointing/checkpoint.py +4 -5
- third_party/{Roma β RoMa}/roma/datasets/__init__.py +1 -1
- third_party/{Roma β RoMa}/roma/datasets/megadepth.py +42 -81
- third_party/{Roma β RoMa}/roma/datasets/scannet.py +72 -103
- third_party/RoMa/roma/losses/__init__.py +1 -0
- third_party/{Roma β RoMa}/roma/losses/robust_loss.py +54 -119
- third_party/RoMa/roma/models/__init__.py +1 -0
- third_party/{Roma β RoMa}/roma/models/encoders.py +7 -15
- third_party/{Roma β RoMa}/roma/models/matcher.py +100 -21
- third_party/RoMa/roma/models/model_zoo/__init__.py +53 -0
- third_party/{Roma β RoMa}/roma/models/model_zoo/roma_models.py +69 -84
- third_party/{Roma β RoMa}/roma/models/transformer/__init__.py +14 -46
- third_party/{Roma β RoMa}/roma/models/transformer/dinov2.py +23 -71
- third_party/{Roma β RoMa}/roma/models/transformer/layers/__init__.py +0 -0
- third_party/{Roma β RoMa}/roma/models/transformer/layers/attention.py +1 -5
- third_party/{Roma β RoMa}/roma/models/transformer/layers/block.py +13 -45
- third_party/{Roma β RoMa}/roma/models/transformer/layers/dino_head.py +2 -11
- third_party/{Roma β RoMa}/roma/models/transformer/layers/drop_path.py +1 -3
- third_party/{Roma β RoMa}/roma/models/transformer/layers/layer_scale.py +0 -0
- third_party/{Roma β RoMa}/roma/models/transformer/layers/mlp.py +0 -0
- third_party/{Roma β RoMa}/roma/models/transformer/layers/patch_embed.py +4 -16
- third_party/{Roma β RoMa}/roma/models/transformer/layers/swiglu_ffn.py +0 -0
- third_party/{Roma β RoMa}/roma/train/__init__.py +0 -0
- third_party/{Roma β RoMa}/roma/train/train.py +15 -39
- third_party/{Roma β RoMa}/roma/utils/__init__.py +0 -0
build_docker.sh
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
docker build -t image-matching-webui:latest . --no-cache
|
2 |
docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
|
3 |
docker push vincentqin/image-matching-webui:latest
|
|
|
|
1 |
docker build -t image-matching-webui:latest . --no-cache
|
2 |
docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
|
3 |
docker push vincentqin/image-matching-webui:latest
|
4 |
+
|
hloc/matchers/roma.py
CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
|
|
6 |
from ..utils.base_model import BaseModel
|
7 |
from .. import logger
|
8 |
|
9 |
-
roma_path = Path(__file__).parent / "../../third_party/
|
10 |
sys.path.append(str(roma_path))
|
11 |
|
12 |
from roma.models.model_zoo.roma_models import roma_model
|
@@ -63,6 +63,8 @@ class Roma(BaseModel):
|
|
63 |
weights=weights,
|
64 |
dinov2_weights=dinov2_weights,
|
65 |
device=device,
|
|
|
|
|
66 |
)
|
67 |
logger.info(f"Load Roma model done.")
|
68 |
|
|
|
6 |
from ..utils.base_model import BaseModel
|
7 |
from .. import logger
|
8 |
|
9 |
+
roma_path = Path(__file__).parent / "../../third_party/RoMa"
|
10 |
sys.path.append(str(roma_path))
|
11 |
|
12 |
from roma.models.model_zoo.roma_models import roma_model
|
|
|
63 |
weights=weights,
|
64 |
dinov2_weights=dinov2_weights,
|
65 |
device=device,
|
66 |
+
#temp fix issue: https://github.com/Parskatt/RoMa/issues/26
|
67 |
+
amp_dtype=torch.float32,
|
68 |
)
|
69 |
logger.info(f"Load Roma model done.")
|
70 |
|
third_party/{Roma β RoMa}/.gitignore
RENAMED
File without changes
|
third_party/RoMa/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Johan Edstedt
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
third_party/{Roma β RoMa}/README.md
RENAMED
@@ -1,14 +1,29 @@
|
|
1 |
-
#
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
<br/>
|
4 |
-
|
5 |
-
|
6 |
-
>
|
7 |
-
>
|
8 |
-
|
9 |
-
**NOTE!!! Very early code, there might be bugs**
|
10 |
-
|
11 |
-
The codebase is in the [roma folder](roma).
|
12 |
|
13 |
## Setup/Install
|
14 |
In your python environment (tested on Linux python 3.10), run:
|
@@ -32,6 +47,19 @@ F, mask = cv2.findFundamentalMat(
|
|
32 |
kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
33 |
)
|
34 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
## Reproducing Results
|
36 |
The experiments in the paper are provided in the [experiments folder](experiments).
|
37 |
|
@@ -46,7 +74,8 @@ torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outd
|
|
46 |
python experiments/roma_outdoor.py --only_test --benchmark mega-1500
|
47 |
```
|
48 |
## License
|
49 |
-
|
|
|
50 |
|
51 |
## Acknowledgement
|
52 |
Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
|
@@ -54,10 +83,10 @@ Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
|
|
54 |
## BibTeX
|
55 |
If you find our models useful, please consider citing our paper!
|
56 |
```
|
57 |
-
@article{
|
58 |
-
title={{RoMa
|
59 |
author={Edstedt, Johan and Sun, Qiyu and BΓΆkman, Georg and WadenbΓ€ck, MΓ₯rten and Felsberg, Michael},
|
60 |
-
journal={
|
61 |
-
year={
|
62 |
}
|
63 |
```
|
|
|
1 |
+
#
|
2 |
+
<p align="center">
|
3 |
+
<h1 align="center"> <ins>RoMa</ins> ποΈ:<br> Robust Dense Feature Matching <br> βCVPR 2024β</h1>
|
4 |
+
<p align="center">
|
5 |
+
<a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
|
6 |
+
Β·
|
7 |
+
<a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
|
8 |
+
Β·
|
9 |
+
<a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg BΓΆkman</a>
|
10 |
+
Β·
|
11 |
+
<a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">MΓ₯rten WadenbΓ€ck</a>
|
12 |
+
Β·
|
13 |
+
<a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
|
14 |
+
</p>
|
15 |
+
<h2 align="center"><p>
|
16 |
+
<a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> |
|
17 |
+
<a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
|
18 |
+
</p></h2>
|
19 |
+
<div align="center"></div>
|
20 |
+
</p>
|
21 |
<br/>
|
22 |
+
<p align="center">
|
23 |
+
<img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
|
24 |
+
<br>
|
25 |
+
<em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
|
26 |
+
</p>
|
|
|
|
|
|
|
27 |
|
28 |
## Setup/Install
|
29 |
In your python environment (tested on Linux python 3.10), run:
|
|
|
47 |
kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
48 |
)
|
49 |
```
|
50 |
+
|
51 |
+
**New**: You can also match arbitrary keypoints with RoMa. A demo for this will be added soon.
|
52 |
+
## Settings
|
53 |
+
|
54 |
+
### Resolution
|
55 |
+
By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864).
|
56 |
+
You can change this at construction (see roma_outdoor kwargs).
|
57 |
+
You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
|
58 |
+
|
59 |
+
### Sampling
|
60 |
+
roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
|
61 |
+
|
62 |
+
|
63 |
## Reproducing Results
|
64 |
The experiments in the paper are provided in the [experiments folder](experiments).
|
65 |
|
|
|
74 |
python experiments/roma_outdoor.py --only_test --benchmark mega-1500
|
75 |
```
|
76 |
## License
|
77 |
+
All our code except DINOv2 is MIT license.
|
78 |
+
DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
|
79 |
|
80 |
## Acknowledgement
|
81 |
Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
|
|
|
83 |
## BibTeX
|
84 |
If you find our models useful, please consider citing our paper!
|
85 |
```
|
86 |
+
@article{edstedt2024roma,
|
87 |
+
title={{RoMa: Robust Dense Feature Matching}},
|
88 |
author={Edstedt, Johan and Sun, Qiyu and BΓΆkman, Georg and WadenbΓ€ck, MΓ₯rten and Felsberg, Michael},
|
89 |
+
journal={IEEE Conference on Computer Vision and Pattern Recognition},
|
90 |
+
year={2024}
|
91 |
}
|
92 |
```
|
third_party/{Roma β RoMa}/assets/sacre_coeur_A.jpg
RENAMED
File without changes
|
third_party/{Roma β RoMa}/assets/sacre_coeur_B.jpg
RENAMED
File without changes
|
third_party/RoMa/assets/toronto_A.jpg
ADDED
Git LFS Details
|
third_party/RoMa/assets/toronto_B.jpg
ADDED
Git LFS Details
|
third_party/{Roma β RoMa}/data/.gitignore
RENAMED
File without changes
|
third_party/RoMa/demo/demo_3D_effect.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from roma.utils.utils import tensor_to_pil
|
6 |
+
|
7 |
+
from roma import roma_outdoor
|
8 |
+
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
+
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
from argparse import ArgumentParser
|
14 |
+
parser = ArgumentParser()
|
15 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
16 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
17 |
+
parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
|
18 |
+
|
19 |
+
args, _ = parser.parse_known_args()
|
20 |
+
im1_path = args.im_A_path
|
21 |
+
im2_path = args.im_B_path
|
22 |
+
save_path = args.save_path
|
23 |
+
|
24 |
+
# Create model
|
25 |
+
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
|
26 |
+
roma_model.symmetric = False
|
27 |
+
|
28 |
+
H, W = roma_model.get_output_resolution()
|
29 |
+
|
30 |
+
im1 = Image.open(im1_path).resize((W, H))
|
31 |
+
im2 = Image.open(im2_path).resize((W, H))
|
32 |
+
|
33 |
+
# Match
|
34 |
+
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
35 |
+
# Sampling not needed, but can be done with model.sample(warp, certainty)
|
36 |
+
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
|
37 |
+
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
38 |
+
|
39 |
+
coords_A, coords_B = warp[...,:2], warp[...,2:]
|
40 |
+
for i, x in enumerate(np.linspace(0,2*np.pi,200)):
|
41 |
+
t = (1 + np.cos(x))/2
|
42 |
+
interp_warp = (1-t)*coords_A + t*coords_B
|
43 |
+
im2_transfer_rgb = F.grid_sample(
|
44 |
+
x2[None], interp_warp[None], mode="bilinear", align_corners=False
|
45 |
+
)[0]
|
46 |
+
tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
|
third_party/{Roma β RoMa}/demo/demo_fundamental.py
RENAMED
@@ -3,12 +3,11 @@ import torch
|
|
3 |
import cv2
|
4 |
from roma import roma_outdoor
|
5 |
|
6 |
-
device = torch.device(
|
7 |
|
8 |
|
9 |
if __name__ == "__main__":
|
10 |
from argparse import ArgumentParser
|
11 |
-
|
12 |
parser = ArgumentParser()
|
13 |
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
14 |
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
@@ -20,6 +19,7 @@ if __name__ == "__main__":
|
|
20 |
# Create model
|
21 |
roma_model = roma_outdoor(device=device)
|
22 |
|
|
|
23 |
W_A, H_A = Image.open(im1_path).size
|
24 |
W_B, H_B = Image.open(im2_path).size
|
25 |
|
@@ -27,12 +27,7 @@ if __name__ == "__main__":
|
|
27 |
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
28 |
# Sample matches for estimation
|
29 |
matches, certainty = roma_model.sample(warp, certainty)
|
30 |
-
kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
31 |
F, mask = cv2.findFundamentalMat(
|
32 |
-
kpts1.cpu().numpy(),
|
33 |
-
|
34 |
-
ransacReprojThreshold=0.2,
|
35 |
-
method=cv2.USAC_MAGSAC,
|
36 |
-
confidence=0.999999,
|
37 |
-
maxIters=10000,
|
38 |
-
)
|
|
|
3 |
import cv2
|
4 |
from roma import roma_outdoor
|
5 |
|
6 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
|
8 |
|
9 |
if __name__ == "__main__":
|
10 |
from argparse import ArgumentParser
|
|
|
11 |
parser = ArgumentParser()
|
12 |
parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
|
13 |
parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
|
|
|
19 |
# Create model
|
20 |
roma_model = roma_outdoor(device=device)
|
21 |
|
22 |
+
|
23 |
W_A, H_A = Image.open(im1_path).size
|
24 |
W_B, H_B = Image.open(im2_path).size
|
25 |
|
|
|
27 |
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
|
28 |
# Sample matches for estimation
|
29 |
matches, certainty = roma_model.sample(warp, certainty)
|
30 |
+
kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
|
31 |
F, mask = cv2.findFundamentalMat(
|
32 |
+
kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
|
33 |
+
)
|
|
|
|
|
|
|
|
|
|
third_party/{Roma β RoMa}/demo/demo_match.py
RENAMED
@@ -4,20 +4,17 @@ import torch.nn.functional as F
|
|
4 |
import numpy as np
|
5 |
from roma.utils.utils import tensor_to_pil
|
6 |
|
7 |
-
from roma import
|
8 |
|
9 |
-
device = torch.device(
|
10 |
|
11 |
|
12 |
if __name__ == "__main__":
|
13 |
from argparse import ArgumentParser
|
14 |
-
|
15 |
parser = ArgumentParser()
|
16 |
-
parser.add_argument("--im_A_path", default="assets/
|
17 |
-
parser.add_argument("--im_B_path", default="assets/
|
18 |
-
parser.add_argument(
|
19 |
-
"--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str
|
20 |
-
)
|
21 |
|
22 |
args, _ = parser.parse_known_args()
|
23 |
im1_path = args.im_A_path
|
@@ -25,7 +22,7 @@ if __name__ == "__main__":
|
|
25 |
save_path = args.save_path
|
26 |
|
27 |
# Create model
|
28 |
-
roma_model =
|
29 |
|
30 |
H, W = roma_model.get_output_resolution()
|
31 |
|
@@ -39,12 +36,12 @@ if __name__ == "__main__":
|
|
39 |
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
40 |
|
41 |
im2_transfer_rgb = F.grid_sample(
|
42 |
-
|
43 |
)[0]
|
44 |
im1_transfer_rgb = F.grid_sample(
|
45 |
-
|
46 |
)[0]
|
47 |
-
warp_im = torch.cat((im2_transfer_rgb,
|
48 |
-
white_im = torch.ones((H,
|
49 |
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
50 |
-
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
|
|
4 |
import numpy as np
|
5 |
from roma.utils.utils import tensor_to_pil
|
6 |
|
7 |
+
from roma import roma_outdoor
|
8 |
|
9 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
10 |
|
11 |
|
12 |
if __name__ == "__main__":
|
13 |
from argparse import ArgumentParser
|
|
|
14 |
parser = ArgumentParser()
|
15 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
16 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
17 |
+
parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
|
|
|
|
|
18 |
|
19 |
args, _ = parser.parse_known_args()
|
20 |
im1_path = args.im_A_path
|
|
|
22 |
save_path = args.save_path
|
23 |
|
24 |
# Create model
|
25 |
+
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
|
26 |
|
27 |
H, W = roma_model.get_output_resolution()
|
28 |
|
|
|
36 |
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
|
37 |
|
38 |
im2_transfer_rgb = F.grid_sample(
|
39 |
+
x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
40 |
)[0]
|
41 |
im1_transfer_rgb = F.grid_sample(
|
42 |
+
x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
43 |
)[0]
|
44 |
+
warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
|
45 |
+
white_im = torch.ones((H,2*W),device=device)
|
46 |
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
47 |
+
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
third_party/RoMa/demo/demo_match_opencv_sift.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import cv2 as cv
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
from argparse import ArgumentParser
|
12 |
+
parser = ArgumentParser()
|
13 |
+
parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
|
14 |
+
parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
|
15 |
+
parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
|
16 |
+
|
17 |
+
args, _ = parser.parse_known_args()
|
18 |
+
im1_path = args.im_A_path
|
19 |
+
im2_path = args.im_B_path
|
20 |
+
save_path = args.save_path
|
21 |
+
|
22 |
+
img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE) # queryImage
|
23 |
+
img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
|
24 |
+
# Initiate SIFT detector
|
25 |
+
sift = cv.SIFT_create()
|
26 |
+
# find the keypoints and descriptors with SIFT
|
27 |
+
kp1, des1 = sift.detectAndCompute(img1,None)
|
28 |
+
kp2, des2 = sift.detectAndCompute(img2,None)
|
29 |
+
# BFMatcher with default params
|
30 |
+
bf = cv.BFMatcher()
|
31 |
+
matches = bf.knnMatch(des1,des2,k=2)
|
32 |
+
# Apply ratio test
|
33 |
+
good = []
|
34 |
+
for m,n in matches:
|
35 |
+
if m.distance < 0.75*n.distance:
|
36 |
+
good.append([m])
|
37 |
+
# cv.drawMatchesKnn expects list of lists as matches.
|
38 |
+
draw_params = dict(matchColor = (255,0,0), # draw matches in red color
|
39 |
+
singlePointColor = None,
|
40 |
+
flags = 2)
|
41 |
+
|
42 |
+
img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
|
43 |
+
Image.fromarray(img3).save("demo/sift_matches.png")
|
third_party/RoMa/demo/gif/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
third_party/{Roma β RoMa}/pretrained/dinov2_vitl14_pretrain.pth
RENAMED
File without changes
|
third_party/{Roma β RoMa}/pretrained/roma_outdoor.pth
RENAMED
File without changes
|
third_party/{Roma β RoMa}/requirements.txt
RENAMED
@@ -10,4 +10,4 @@ matplotlib
|
|
10 |
h5py
|
11 |
wandb
|
12 |
timm
|
13 |
-
xformers # Optional, used for memefficient attention
|
|
|
10 |
h5py
|
11 |
wandb
|
12 |
timm
|
13 |
+
#xformers # Optional, used for memefficient attention
|
third_party/{Roma β RoMa}/roma/__init__.py
RENAMED
@@ -2,7 +2,7 @@ import os
|
|
2 |
from .models import roma_outdoor, roma_indoor
|
3 |
|
4 |
DEBUG_MODE = False
|
5 |
-
RANK = int(os.environ.get(
|
6 |
GLOBAL_STEP = 0
|
7 |
STEP_SIZE = 1
|
8 |
-
LOCAL_RANK = -1
|
|
|
2 |
from .models import roma_outdoor, roma_indoor
|
3 |
|
4 |
DEBUG_MODE = False
|
5 |
+
RANK = int(os.environ.get('RANK', default = 0))
|
6 |
GLOBAL_STEP = 0
|
7 |
STEP_SIZE = 1
|
8 |
+
LOCAL_RANK = -1
|
third_party/{Roma β RoMa}/roma/benchmarks/__init__.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/benchmarks/hpatches_sequences_homog_benchmark.py
RENAMED
@@ -53,7 +53,7 @@ class HpatchesHomogBenchmark:
|
|
53 |
)
|
54 |
return im_A_coords, im_A_to_im_B
|
55 |
|
56 |
-
def benchmark(self, model, model_name=None):
|
57 |
n_matches = []
|
58 |
homog_dists = []
|
59 |
for seq_idx, seq_name in tqdm(
|
@@ -69,7 +69,9 @@ class HpatchesHomogBenchmark:
|
|
69 |
H = np.loadtxt(
|
70 |
os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
|
71 |
)
|
72 |
-
dense_matches, dense_certainty = model.match(
|
|
|
|
|
73 |
good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
|
74 |
pos_a, pos_b = self.convert_coordinates(
|
75 |
good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
|
@@ -78,9 +80,9 @@ class HpatchesHomogBenchmark:
|
|
78 |
H_pred, inliers = cv2.findHomography(
|
79 |
pos_a,
|
80 |
pos_b,
|
81 |
-
method=cv2.RANSAC,
|
82 |
-
confidence=0.99999,
|
83 |
-
ransacReprojThreshold=3 * min(w2, h2) / 480,
|
84 |
)
|
85 |
except:
|
86 |
H_pred = None
|
|
|
53 |
)
|
54 |
return im_A_coords, im_A_to_im_B
|
55 |
|
56 |
+
def benchmark(self, model, model_name = None):
|
57 |
n_matches = []
|
58 |
homog_dists = []
|
59 |
for seq_idx, seq_name in tqdm(
|
|
|
69 |
H = np.loadtxt(
|
70 |
os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
|
71 |
)
|
72 |
+
dense_matches, dense_certainty = model.match(
|
73 |
+
im_A_path, im_B_path
|
74 |
+
)
|
75 |
good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
|
76 |
pos_a, pos_b = self.convert_coordinates(
|
77 |
good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
|
|
|
80 |
H_pred, inliers = cv2.findHomography(
|
81 |
pos_a,
|
82 |
pos_b,
|
83 |
+
method = cv2.RANSAC,
|
84 |
+
confidence = 0.99999,
|
85 |
+
ransacReprojThreshold = 3 * min(w2, h2) / 480,
|
86 |
)
|
87 |
except:
|
88 |
H_pred = None
|
third_party/{Roma β RoMa}/roma/benchmarks/megadepth_dense_benchmark.py
RENAMED
@@ -6,11 +6,8 @@ from roma.utils import warp_kpts
|
|
6 |
from torch.utils.data import ConcatDataset
|
7 |
import roma
|
8 |
|
9 |
-
|
10 |
class MegadepthDenseBenchmark:
|
11 |
-
def __init__(
|
12 |
-
self, data_root="data/megadepth", h=384, w=512, num_samples=2000
|
13 |
-
) -> None:
|
14 |
mega = MegadepthBuilder(data_root=data_root)
|
15 |
self.dataset = ConcatDataset(
|
16 |
mega.build_scenes(split="test_loftr", ht=h, wt=w)
|
@@ -52,15 +49,13 @@ class MegadepthDenseBenchmark:
|
|
52 |
pck_3_tot = 0.0
|
53 |
pck_5_tot = 0.0
|
54 |
sampler = torch.utils.data.WeightedRandomSampler(
|
55 |
-
torch.ones(len(self.dataset)),
|
56 |
-
replacement=False,
|
57 |
-
num_samples=self.num_samples,
|
58 |
)
|
59 |
B = batch_size
|
60 |
dataloader = torch.utils.data.DataLoader(
|
61 |
self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
|
62 |
)
|
63 |
-
for idx, data in tqdm.tqdm(enumerate(dataloader), disable=roma.RANK > 0):
|
64 |
im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
|
65 |
data["im_A"],
|
66 |
data["im_B"],
|
@@ -77,36 +72,25 @@ class MegadepthDenseBenchmark:
|
|
77 |
if roma.DEBUG_MODE:
|
78 |
from roma.utils.utils import tensor_to_pil
|
79 |
import torch.nn.functional as F
|
80 |
-
|
81 |
path = "vis"
|
82 |
H, W = model.get_output_resolution()
|
83 |
-
white_im = torch.ones((B,
|
84 |
im_B_transfer_rgb = F.grid_sample(
|
85 |
-
im_B.cuda(),
|
86 |
-
matches[:, :, :W, 2:],
|
87 |
-
mode="bilinear",
|
88 |
-
align_corners=False,
|
89 |
)
|
90 |
warp_im = im_B_transfer_rgb
|
91 |
-
c_b = certainty[
|
92 |
-
:, None
|
93 |
-
] # (certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
|
94 |
vis_im = c_b * warp_im + (1 - c_b) * white_im
|
95 |
for b in range(B):
|
96 |
import os
|
97 |
-
|
98 |
-
os.makedirs(
|
99 |
-
f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True
|
100 |
-
)
|
101 |
tensor_to_pil(vis_im[b], unnormalize=True).save(
|
102 |
-
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg"
|
103 |
-
)
|
104 |
tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
|
105 |
-
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg"
|
106 |
-
)
|
107 |
tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
|
108 |
-
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg"
|
109 |
-
|
110 |
|
111 |
gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
|
112 |
gd_tot + gd.mean(),
|
|
|
6 |
from torch.utils.data import ConcatDataset
|
7 |
import roma
|
8 |
|
|
|
9 |
class MegadepthDenseBenchmark:
|
10 |
+
def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
|
|
|
|
|
11 |
mega = MegadepthBuilder(data_root=data_root)
|
12 |
self.dataset = ConcatDataset(
|
13 |
mega.build_scenes(split="test_loftr", ht=h, wt=w)
|
|
|
49 |
pck_3_tot = 0.0
|
50 |
pck_5_tot = 0.0
|
51 |
sampler = torch.utils.data.WeightedRandomSampler(
|
52 |
+
torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
|
|
|
|
|
53 |
)
|
54 |
B = batch_size
|
55 |
dataloader = torch.utils.data.DataLoader(
|
56 |
self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
|
57 |
)
|
58 |
+
for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
|
59 |
im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
|
60 |
data["im_A"],
|
61 |
data["im_B"],
|
|
|
72 |
if roma.DEBUG_MODE:
|
73 |
from roma.utils.utils import tensor_to_pil
|
74 |
import torch.nn.functional as F
|
|
|
75 |
path = "vis"
|
76 |
H, W = model.get_output_resolution()
|
77 |
+
white_im = torch.ones((B,1,H,W),device="cuda")
|
78 |
im_B_transfer_rgb = F.grid_sample(
|
79 |
+
im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
|
|
|
|
|
|
|
80 |
)
|
81 |
warp_im = im_B_transfer_rgb
|
82 |
+
c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
|
|
|
|
|
83 |
vis_im = c_b * warp_im + (1 - c_b) * white_im
|
84 |
for b in range(B):
|
85 |
import os
|
86 |
+
os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
|
|
|
|
|
|
|
87 |
tensor_to_pil(vis_im[b], unnormalize=True).save(
|
88 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
|
|
|
89 |
tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
|
90 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
|
|
|
91 |
tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
|
92 |
+
f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
|
93 |
+
|
94 |
|
95 |
gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
|
96 |
gd_tot + gd.mean(),
|
third_party/{Roma β RoMa}/roma/benchmarks/megadepth_pose_estimation_benchmark.py
RENAMED
@@ -7,9 +7,8 @@ import torch.nn.functional as F
|
|
7 |
import roma
|
8 |
import kornia.geometry.epipolar as kepi
|
9 |
|
10 |
-
|
11 |
class MegaDepthPoseEstimationBenchmark:
|
12 |
-
def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
|
13 |
if scene_names is None:
|
14 |
self.scene_names = [
|
15 |
"0015_0.1_0.3.npz",
|
@@ -26,22 +25,13 @@ class MegaDepthPoseEstimationBenchmark:
|
|
26 |
]
|
27 |
self.data_root = data_root
|
28 |
|
29 |
-
def benchmark(
|
30 |
-
self,
|
31 |
-
model,
|
32 |
-
model_name=None,
|
33 |
-
resolution=None,
|
34 |
-
scale_intrinsics=True,
|
35 |
-
calibrated=True,
|
36 |
-
):
|
37 |
-
H, W = model.get_output_resolution()
|
38 |
with torch.no_grad():
|
39 |
data_root = self.data_root
|
40 |
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
41 |
thresholds = [5, 10, 20]
|
42 |
for scene_ind in range(len(self.scenes)):
|
43 |
import os
|
44 |
-
|
45 |
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
46 |
scene = self.scenes[scene_ind]
|
47 |
pairs = scene["pair_infos"]
|
@@ -58,22 +48,21 @@ class MegaDepthPoseEstimationBenchmark:
|
|
58 |
T2 = poses[idx2].copy()
|
59 |
R2, t2 = T2[:3, :3], T2[:3, 3]
|
60 |
R, t = compute_relative_pose(R1, t1, R2, t2)
|
61 |
-
T1_to_2 = np.concatenate((R,
|
62 |
im_A_path = f"{data_root}/{im_paths[idx1]}"
|
63 |
im_B_path = f"{data_root}/{im_paths[idx2]}"
|
64 |
dense_matches, dense_certainty = model.match(
|
65 |
im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
|
66 |
)
|
67 |
-
sparse_matches,
|
68 |
-
dense_matches, dense_certainty,
|
69 |
)
|
70 |
-
|
71 |
im_A = Image.open(im_A_path)
|
72 |
w1, h1 = im_A.size
|
73 |
im_B = Image.open(im_B_path)
|
74 |
w2, h2 = im_B.size
|
75 |
-
|
76 |
-
if scale_intrinsics:
|
77 |
scale1 = 1200 / max(w1, h1)
|
78 |
scale2 = 1200 / max(w2, h2)
|
79 |
w1, h1 = scale1 * w1, scale1 * h1
|
@@ -82,42 +71,23 @@ class MegaDepthPoseEstimationBenchmark:
|
|
82 |
K1[:2] = K1[:2] * scale1
|
83 |
K2[:2] = K2[:2] * scale2
|
84 |
|
85 |
-
kpts1 = sparse_matches
|
86 |
-
kpts1 =
|
87 |
-
(
|
88 |
-
w1 * (kpts1[:, 0] + 1) / 2,
|
89 |
-
h1 * (kpts1[:, 1] + 1) / 2,
|
90 |
-
),
|
91 |
-
axis=-1,
|
92 |
-
)
|
93 |
-
kpts2 = sparse_matches[:, 2:]
|
94 |
-
kpts2 = np.stack(
|
95 |
-
(
|
96 |
-
w2 * (kpts2[:, 0] + 1) / 2,
|
97 |
-
h2 * (kpts2[:, 1] + 1) / 2,
|
98 |
-
),
|
99 |
-
axis=-1,
|
100 |
-
)
|
101 |
-
|
102 |
for _ in range(5):
|
103 |
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
104 |
kpts1 = kpts1[shuffling]
|
105 |
kpts2 = kpts2[shuffling]
|
106 |
try:
|
107 |
-
threshold = 0.5
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
K2,
|
118 |
-
norm_threshold,
|
119 |
-
conf=0.99999,
|
120 |
-
)
|
121 |
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
122 |
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
123 |
e_pose = max(e_t, e_R)
|
|
|
7 |
import roma
|
8 |
import kornia.geometry.epipolar as kepi
|
9 |
|
|
|
10 |
class MegaDepthPoseEstimationBenchmark:
|
11 |
+
def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
|
12 |
if scene_names is None:
|
13 |
self.scene_names = [
|
14 |
"0015_0.1_0.3.npz",
|
|
|
25 |
]
|
26 |
self.data_root = data_root
|
27 |
|
28 |
+
def benchmark(self, model, model_name = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
with torch.no_grad():
|
30 |
data_root = self.data_root
|
31 |
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
32 |
thresholds = [5, 10, 20]
|
33 |
for scene_ind in range(len(self.scenes)):
|
34 |
import os
|
|
|
35 |
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
36 |
scene = self.scenes[scene_ind]
|
37 |
pairs = scene["pair_infos"]
|
|
|
48 |
T2 = poses[idx2].copy()
|
49 |
R2, t2 = T2[:3, :3], T2[:3, 3]
|
50 |
R, t = compute_relative_pose(R1, t1, R2, t2)
|
51 |
+
T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
|
52 |
im_A_path = f"{data_root}/{im_paths[idx1]}"
|
53 |
im_B_path = f"{data_root}/{im_paths[idx2]}"
|
54 |
dense_matches, dense_certainty = model.match(
|
55 |
im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
|
56 |
)
|
57 |
+
sparse_matches,_ = model.sample(
|
58 |
+
dense_matches, dense_certainty, 5_000
|
59 |
)
|
60 |
+
|
61 |
im_A = Image.open(im_A_path)
|
62 |
w1, h1 = im_A.size
|
63 |
im_B = Image.open(im_B_path)
|
64 |
w2, h2 = im_B.size
|
65 |
+
if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False.
|
|
|
66 |
scale1 = 1200 / max(w1, h1)
|
67 |
scale2 = 1200 / max(w2, h2)
|
68 |
w1, h1 = scale1 * w1, scale1 * h1
|
|
|
71 |
K1[:2] = K1[:2] * scale1
|
72 |
K2[:2] = K2[:2] * scale2
|
73 |
|
74 |
+
kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
|
75 |
+
kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
for _ in range(5):
|
77 |
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
78 |
kpts1 = kpts1[shuffling]
|
79 |
kpts2 = kpts2[shuffling]
|
80 |
try:
|
81 |
+
threshold = 0.5
|
82 |
+
norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
|
83 |
+
R_est, t_est, mask = estimate_pose(
|
84 |
+
kpts1,
|
85 |
+
kpts2,
|
86 |
+
K1,
|
87 |
+
K2,
|
88 |
+
norm_threshold,
|
89 |
+
conf=0.99999,
|
90 |
+
)
|
|
|
|
|
|
|
|
|
91 |
T1_to_2_est = np.concatenate((R_est, t_est), axis=-1) #
|
92 |
e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
|
93 |
e_pose = max(e_t, e_R)
|
third_party/{Roma β RoMa}/roma/benchmarks/scannet_benchmark.py
RENAMED
@@ -10,7 +10,7 @@ class ScanNetBenchmark:
|
|
10 |
def __init__(self, data_root="data/scannet") -> None:
|
11 |
self.data_root = data_root
|
12 |
|
13 |
-
def benchmark(self, model, model_name=None):
|
14 |
model.train(False)
|
15 |
with torch.no_grad():
|
16 |
data_root = self.data_root
|
@@ -24,20 +24,20 @@ class ScanNetBenchmark:
|
|
24 |
scene = pairs[pairind]
|
25 |
scene_name = f"scene0{scene[0]}_00"
|
26 |
im_A_path = osp.join(
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
im_A = Image.open(im_A_path)
|
34 |
im_B_path = osp.join(
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
im_B = Image.open(im_B_path)
|
42 |
T_gt = rel_pose[pairind].reshape(3, 4)
|
43 |
R, t = T_gt[:3, :3], T_gt[:3, 3]
|
@@ -76,20 +76,24 @@ class ScanNetBenchmark:
|
|
76 |
|
77 |
offset = 0.5
|
78 |
kpts1 = sparse_matches[:, :2]
|
79 |
-
kpts1 =
|
80 |
-
(
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
85 |
)
|
86 |
kpts2 = sparse_matches[:, 2:]
|
87 |
-
kpts2 =
|
88 |
-
(
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
93 |
)
|
94 |
for _ in range(5):
|
95 |
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
@@ -97,8 +101,7 @@ class ScanNetBenchmark:
|
|
97 |
kpts2 = kpts2[shuffling]
|
98 |
try:
|
99 |
norm_threshold = 0.5 / (
|
100 |
-
|
101 |
-
)
|
102 |
R_est, t_est, mask = estimate_pose(
|
103 |
kpts1,
|
104 |
kpts2,
|
|
|
10 |
def __init__(self, data_root="data/scannet") -> None:
|
11 |
self.data_root = data_root
|
12 |
|
13 |
+
def benchmark(self, model, model_name = None):
|
14 |
model.train(False)
|
15 |
with torch.no_grad():
|
16 |
data_root = self.data_root
|
|
|
24 |
scene = pairs[pairind]
|
25 |
scene_name = f"scene0{scene[0]}_00"
|
26 |
im_A_path = osp.join(
|
27 |
+
self.data_root,
|
28 |
+
"scans_test",
|
29 |
+
scene_name,
|
30 |
+
"color",
|
31 |
+
f"{scene[2]}.jpg",
|
32 |
+
)
|
33 |
im_A = Image.open(im_A_path)
|
34 |
im_B_path = osp.join(
|
35 |
+
self.data_root,
|
36 |
+
"scans_test",
|
37 |
+
scene_name,
|
38 |
+
"color",
|
39 |
+
f"{scene[3]}.jpg",
|
40 |
+
)
|
41 |
im_B = Image.open(im_B_path)
|
42 |
T_gt = rel_pose[pairind].reshape(3, 4)
|
43 |
R, t = T_gt[:3, :3], T_gt[:3, 3]
|
|
|
76 |
|
77 |
offset = 0.5
|
78 |
kpts1 = sparse_matches[:, :2]
|
79 |
+
kpts1 = (
|
80 |
+
np.stack(
|
81 |
+
(
|
82 |
+
w1 * (kpts1[:, 0] + 1) / 2 - offset,
|
83 |
+
h1 * (kpts1[:, 1] + 1) / 2 - offset,
|
84 |
+
),
|
85 |
+
axis=-1,
|
86 |
+
)
|
87 |
)
|
88 |
kpts2 = sparse_matches[:, 2:]
|
89 |
+
kpts2 = (
|
90 |
+
np.stack(
|
91 |
+
(
|
92 |
+
w2 * (kpts2[:, 0] + 1) / 2 - offset,
|
93 |
+
h2 * (kpts2[:, 1] + 1) / 2 - offset,
|
94 |
+
),
|
95 |
+
axis=-1,
|
96 |
+
)
|
97 |
)
|
98 |
for _ in range(5):
|
99 |
shuffling = np.random.permutation(np.arange(len(kpts1)))
|
|
|
101 |
kpts2 = kpts2[shuffling]
|
102 |
try:
|
103 |
norm_threshold = 0.5 / (
|
104 |
+
np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
|
|
|
105 |
R_est, t_est, mask = estimate_pose(
|
106 |
kpts1,
|
107 |
kpts2,
|
third_party/{Roma β RoMa}/roma/checkpointing/__init__.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/checkpointing/checkpoint.py
RENAMED
@@ -7,7 +7,6 @@ import gc
|
|
7 |
|
8 |
import roma
|
9 |
|
10 |
-
|
11 |
class CheckPoint:
|
12 |
def __init__(self, dir=None, name="tmp"):
|
13 |
self.name = name
|
@@ -20,7 +19,7 @@ class CheckPoint:
|
|
20 |
optimizer,
|
21 |
lr_scheduler,
|
22 |
n,
|
23 |
-
|
24 |
if roma.RANK == 0:
|
25 |
assert model is not None
|
26 |
if isinstance(model, (DataParallel, DistributedDataParallel)):
|
@@ -33,14 +32,14 @@ class CheckPoint:
|
|
33 |
}
|
34 |
torch.save(states, self.dir + self.name + f"_latest.pth")
|
35 |
logger.info(f"Saved states {list(states.keys())}, at step {n}")
|
36 |
-
|
37 |
def load(
|
38 |
self,
|
39 |
model,
|
40 |
optimizer,
|
41 |
lr_scheduler,
|
42 |
n,
|
43 |
-
|
44 |
if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
|
45 |
states = torch.load(self.dir + self.name + f"_latest.pth")
|
46 |
if "model" in states:
|
@@ -58,4 +57,4 @@ class CheckPoint:
|
|
58 |
del states
|
59 |
gc.collect()
|
60 |
torch.cuda.empty_cache()
|
61 |
-
return model, optimizer, lr_scheduler, n
|
|
|
7 |
|
8 |
import roma
|
9 |
|
|
|
10 |
class CheckPoint:
|
11 |
def __init__(self, dir=None, name="tmp"):
|
12 |
self.name = name
|
|
|
19 |
optimizer,
|
20 |
lr_scheduler,
|
21 |
n,
|
22 |
+
):
|
23 |
if roma.RANK == 0:
|
24 |
assert model is not None
|
25 |
if isinstance(model, (DataParallel, DistributedDataParallel)):
|
|
|
32 |
}
|
33 |
torch.save(states, self.dir + self.name + f"_latest.pth")
|
34 |
logger.info(f"Saved states {list(states.keys())}, at step {n}")
|
35 |
+
|
36 |
def load(
|
37 |
self,
|
38 |
model,
|
39 |
optimizer,
|
40 |
lr_scheduler,
|
41 |
n,
|
42 |
+
):
|
43 |
if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
|
44 |
states = torch.load(self.dir + self.name + f"_latest.pth")
|
45 |
if "model" in states:
|
|
|
57 |
del states
|
58 |
gc.collect()
|
59 |
torch.cuda.empty_cache()
|
60 |
+
return model, optimizer, lr_scheduler, n
|
third_party/{Roma β RoMa}/roma/datasets/__init__.py
RENAMED
@@ -1,2 +1,2 @@
|
|
1 |
from .megadepth import MegadepthBuilder
|
2 |
-
from .scannet import ScanNetBuilder
|
|
|
1 |
from .megadepth import MegadepthBuilder
|
2 |
+
from .scannet import ScanNetBuilder
|
third_party/{Roma β RoMa}/roma/datasets/megadepth.py
RENAMED
@@ -10,7 +10,6 @@ import roma
|
|
10 |
from roma.utils import *
|
11 |
import math
|
12 |
|
13 |
-
|
14 |
class MegadepthScene:
|
15 |
def __init__(
|
16 |
self,
|
@@ -23,20 +22,18 @@ class MegadepthScene:
|
|
23 |
shake_t=0,
|
24 |
rot_prob=0.0,
|
25 |
normalize=True,
|
26 |
-
max_num_pairs=100_000,
|
27 |
-
scene_name=None,
|
28 |
-
use_horizontal_flip_aug=False,
|
29 |
-
use_single_horizontal_flip_aug=False,
|
30 |
-
colorjiggle_params=None,
|
31 |
-
random_eraser=None,
|
32 |
-
use_randaug=False,
|
33 |
-
randaug_params=None,
|
34 |
-
randomize_size=False,
|
35 |
) -> None:
|
36 |
self.data_root = data_root
|
37 |
-
self.scene_name = (
|
38 |
-
os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
|
39 |
-
)
|
40 |
self.image_paths = scene_info["image_paths"]
|
41 |
self.depth_paths = scene_info["depth_paths"]
|
42 |
self.intrinsics = scene_info["intrinsics"]
|
@@ -54,18 +51,18 @@ class MegadepthScene:
|
|
54 |
self.overlaps = self.overlaps[pairinds]
|
55 |
if randomize_size:
|
56 |
area = ht * wt
|
57 |
-
s = int(16 * (math.sqrt(area)
|
58 |
-
sizes = ((ht,
|
59 |
choice = roma.RANK % 3
|
60 |
-
ht, wt = sizes[choice]
|
61 |
# counts, bins = np.histogram(self.overlaps,20)
|
62 |
# print(counts)
|
63 |
self.im_transform_ops = get_tuple_transform_ops(
|
64 |
-
resize=(ht, wt),
|
65 |
-
normalize=normalize,
|
66 |
-
colorjiggle_params=colorjiggle_params,
|
67 |
)
|
68 |
-
self.depth_transform_ops = get_depth_tuple_transform_ops(
|
|
|
|
|
69 |
self.wt, self.ht = wt, ht
|
70 |
self.shake_t = shake_t
|
71 |
self.random_eraser = random_eraser
|
@@ -78,19 +75,17 @@ class MegadepthScene:
|
|
78 |
def load_im(self, im_path):
|
79 |
im = Image.open(im_path)
|
80 |
return im
|
81 |
-
|
82 |
-
def horizontal_flip(self, im_A, im_B, depth_A, depth_B,
|
83 |
im_A = im_A.flip(-1)
|
84 |
im_B = im_B.flip(-1)
|
85 |
-
depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
|
86 |
-
flip_mat = torch.tensor([[-1, 0, self.wt],
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
K_B = flip_mat @ K_B
|
91 |
-
|
92 |
return im_A, im_B, depth_A, depth_B, K_A, K_B
|
93 |
-
|
94 |
def load_depth(self, depth_ref, crop=None):
|
95 |
depth = np.array(h5py.File(depth_ref, "r")["depth"])
|
96 |
return torch.from_numpy(depth)
|
@@ -145,31 +140,29 @@ class MegadepthScene:
|
|
145 |
depth_A, depth_B = self.depth_transform_ops(
|
146 |
(depth_A[None, None], depth_B[None, None])
|
147 |
)
|
148 |
-
|
149 |
-
[im_A, im_B, depth_A, depth_B], t = self.rand_shake(
|
150 |
-
im_A, im_B, depth_A, depth_B
|
151 |
-
)
|
152 |
K1[:2, 2] += t
|
153 |
K2[:2, 2] += t
|
154 |
-
|
155 |
im_A, im_B = im_A[None], im_B[None]
|
156 |
if self.random_eraser is not None:
|
157 |
im_A, depth_A = self.random_eraser(im_A, depth_A)
|
158 |
im_B, depth_B = self.random_eraser(im_B, depth_B)
|
159 |
-
|
160 |
if self.use_horizontal_flip_aug:
|
161 |
if np.random.rand() > 0.5:
|
162 |
-
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
|
163 |
-
im_A, im_B, depth_A, depth_B, K1, K2
|
164 |
-
)
|
165 |
if self.use_single_horizontal_flip_aug:
|
166 |
if np.random.rand() > 0.5:
|
167 |
im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
|
168 |
-
|
169 |
if roma.DEBUG_MODE:
|
170 |
-
tensor_to_pil(im_A[0], unnormalize=True).save(
|
171 |
-
|
172 |
-
|
|
|
|
|
173 |
data_dict = {
|
174 |
"im_A": im_A[0],
|
175 |
"im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
|
@@ -182,53 +175,25 @@ class MegadepthScene:
|
|
182 |
"T_1to2": T_1to2,
|
183 |
"im_A_path": im_A_ref,
|
184 |
"im_B_path": im_B_ref,
|
|
|
185 |
}
|
186 |
return data_dict
|
187 |
|
188 |
|
189 |
class MegadepthBuilder:
|
190 |
-
def __init__(
|
191 |
-
self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True
|
192 |
-
) -> None:
|
193 |
self.data_root = data_root
|
194 |
self.scene_info_root = os.path.join(data_root, "prep_scene_info")
|
195 |
self.all_scenes = os.listdir(self.scene_info_root)
|
196 |
self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
|
197 |
# LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
|
198 |
-
self.loftr_ignore_scenes = set(
|
199 |
-
|
200 |
-
"0121.npy",
|
201 |
-
"0133.npy",
|
202 |
-
"0168.npy",
|
203 |
-
"0178.npy",
|
204 |
-
"0229.npy",
|
205 |
-
"0349.npy",
|
206 |
-
"0412.npy",
|
207 |
-
"0430.npy",
|
208 |
-
"0443.npy",
|
209 |
-
"1001.npy",
|
210 |
-
"5014.npy",
|
211 |
-
"5015.npy",
|
212 |
-
"5016.npy",
|
213 |
-
]
|
214 |
-
)
|
215 |
-
self.imc21_scenes = set(
|
216 |
-
[
|
217 |
-
"0008.npy",
|
218 |
-
"0019.npy",
|
219 |
-
"0021.npy",
|
220 |
-
"0024.npy",
|
221 |
-
"0025.npy",
|
222 |
-
"0032.npy",
|
223 |
-
"0063.npy",
|
224 |
-
"1589.npy",
|
225 |
-
]
|
226 |
-
)
|
227 |
self.test_scenes_loftr = ["0015.npy", "0022.npy"]
|
228 |
self.loftr_ignore = loftr_ignore
|
229 |
self.imc21_ignore = imc21_ignore
|
230 |
|
231 |
-
def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs):
|
232 |
if split == "train":
|
233 |
scene_names = set(self.all_scenes) - set(self.test_scenes)
|
234 |
elif split == "train_loftr":
|
@@ -252,11 +217,7 @@ class MegadepthBuilder:
|
|
252 |
).item()
|
253 |
scenes.append(
|
254 |
MegadepthScene(
|
255 |
-
self.data_root,
|
256 |
-
scene_info,
|
257 |
-
min_overlap=min_overlap,
|
258 |
-
scene_name=scene_name,
|
259 |
-
**kwargs,
|
260 |
)
|
261 |
)
|
262 |
return scenes
|
|
|
10 |
from roma.utils import *
|
11 |
import math
|
12 |
|
|
|
13 |
class MegadepthScene:
|
14 |
def __init__(
|
15 |
self,
|
|
|
22 |
shake_t=0,
|
23 |
rot_prob=0.0,
|
24 |
normalize=True,
|
25 |
+
max_num_pairs = 100_000,
|
26 |
+
scene_name = None,
|
27 |
+
use_horizontal_flip_aug = False,
|
28 |
+
use_single_horizontal_flip_aug = False,
|
29 |
+
colorjiggle_params = None,
|
30 |
+
random_eraser = None,
|
31 |
+
use_randaug = False,
|
32 |
+
randaug_params = None,
|
33 |
+
randomize_size = False,
|
34 |
) -> None:
|
35 |
self.data_root = data_root
|
36 |
+
self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
|
|
|
|
|
37 |
self.image_paths = scene_info["image_paths"]
|
38 |
self.depth_paths = scene_info["depth_paths"]
|
39 |
self.intrinsics = scene_info["intrinsics"]
|
|
|
51 |
self.overlaps = self.overlaps[pairinds]
|
52 |
if randomize_size:
|
53 |
area = ht * wt
|
54 |
+
s = int(16 * (math.sqrt(area)//16))
|
55 |
+
sizes = ((ht,wt), (s,s), (wt,ht))
|
56 |
choice = roma.RANK % 3
|
57 |
+
ht, wt = sizes[choice]
|
58 |
# counts, bins = np.histogram(self.overlaps,20)
|
59 |
# print(counts)
|
60 |
self.im_transform_ops = get_tuple_transform_ops(
|
61 |
+
resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
|
|
|
|
|
62 |
)
|
63 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(
|
64 |
+
resize=(ht, wt)
|
65 |
+
)
|
66 |
self.wt, self.ht = wt, ht
|
67 |
self.shake_t = shake_t
|
68 |
self.random_eraser = random_eraser
|
|
|
75 |
def load_im(self, im_path):
|
76 |
im = Image.open(im_path)
|
77 |
return im
|
78 |
+
|
79 |
+
def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
|
80 |
im_A = im_A.flip(-1)
|
81 |
im_B = im_B.flip(-1)
|
82 |
+
depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
|
83 |
+
flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
|
84 |
+
K_A = flip_mat@K_A
|
85 |
+
K_B = flip_mat@K_B
|
86 |
+
|
|
|
|
|
87 |
return im_A, im_B, depth_A, depth_B, K_A, K_B
|
88 |
+
|
89 |
def load_depth(self, depth_ref, crop=None):
|
90 |
depth = np.array(h5py.File(depth_ref, "r")["depth"])
|
91 |
return torch.from_numpy(depth)
|
|
|
140 |
depth_A, depth_B = self.depth_transform_ops(
|
141 |
(depth_A[None, None], depth_B[None, None])
|
142 |
)
|
143 |
+
|
144 |
+
[im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
|
|
|
|
|
145 |
K1[:2, 2] += t
|
146 |
K2[:2, 2] += t
|
147 |
+
|
148 |
im_A, im_B = im_A[None], im_B[None]
|
149 |
if self.random_eraser is not None:
|
150 |
im_A, depth_A = self.random_eraser(im_A, depth_A)
|
151 |
im_B, depth_B = self.random_eraser(im_B, depth_B)
|
152 |
+
|
153 |
if self.use_horizontal_flip_aug:
|
154 |
if np.random.rand() > 0.5:
|
155 |
+
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
|
|
|
|
|
156 |
if self.use_single_horizontal_flip_aug:
|
157 |
if np.random.rand() > 0.5:
|
158 |
im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
|
159 |
+
|
160 |
if roma.DEBUG_MODE:
|
161 |
+
tensor_to_pil(im_A[0], unnormalize=True).save(
|
162 |
+
f"vis/im_A.jpg")
|
163 |
+
tensor_to_pil(im_B[0], unnormalize=True).save(
|
164 |
+
f"vis/im_B.jpg")
|
165 |
+
|
166 |
data_dict = {
|
167 |
"im_A": im_A[0],
|
168 |
"im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
|
|
|
175 |
"T_1to2": T_1to2,
|
176 |
"im_A_path": im_A_ref,
|
177 |
"im_B_path": im_B_ref,
|
178 |
+
|
179 |
}
|
180 |
return data_dict
|
181 |
|
182 |
|
183 |
class MegadepthBuilder:
|
184 |
+
def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
|
|
|
|
|
185 |
self.data_root = data_root
|
186 |
self.scene_info_root = os.path.join(data_root, "prep_scene_info")
|
187 |
self.all_scenes = os.listdir(self.scene_info_root)
|
188 |
self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
|
189 |
# LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
|
190 |
+
self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
|
191 |
+
self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
self.test_scenes_loftr = ["0015.npy", "0022.npy"]
|
193 |
self.loftr_ignore = loftr_ignore
|
194 |
self.imc21_ignore = imc21_ignore
|
195 |
|
196 |
+
def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
|
197 |
if split == "train":
|
198 |
scene_names = set(self.all_scenes) - set(self.test_scenes)
|
199 |
elif split == "train_loftr":
|
|
|
217 |
).item()
|
218 |
scenes.append(
|
219 |
MegadepthScene(
|
220 |
+
self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
|
|
|
|
|
|
|
|
|
221 |
)
|
222 |
)
|
223 |
return scenes
|
third_party/{Roma β RoMa}/roma/datasets/scannet.py
RENAMED
@@ -5,7 +5,10 @@ import cv2
|
|
5 |
import h5py
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
-
from torch.utils.data import
|
|
|
|
|
|
|
9 |
|
10 |
import torchvision.transforms.functional as tvf
|
11 |
import kornia.augmentation as K
|
@@ -16,36 +19,22 @@ from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
|
|
16 |
from roma.utils.transforms import GeometricSequential
|
17 |
from tqdm import tqdm
|
18 |
|
19 |
-
|
20 |
class ScanNetScene:
|
21 |
-
def __init__(
|
22 |
-
|
23 |
-
data_root,
|
24 |
-
scene_info
|
25 |
-
|
26 |
-
wt=512,
|
27 |
-
min_overlap=0.0,
|
28 |
-
shake_t=0,
|
29 |
-
rot_prob=0.0,
|
30 |
-
use_horizontal_flip_aug=False,
|
31 |
-
) -> None:
|
32 |
-
self.scene_root = osp.join(data_root, "scans", "scans_train")
|
33 |
-
self.data_names = scene_info["name"]
|
34 |
-
self.overlaps = scene_info["score"]
|
35 |
# Only sample 10s
|
36 |
-
valid = (self.data_names[
|
37 |
self.overlaps = self.overlaps[valid]
|
38 |
self.data_names = self.data_names[valid]
|
39 |
if len(self.data_names) > 10000:
|
40 |
-
pairinds = np.random.choice(
|
41 |
-
np.arange(0, len(self.data_names)), 10000, replace=False
|
42 |
-
)
|
43 |
self.data_names = self.data_names[pairinds]
|
44 |
self.overlaps = self.overlaps[pairinds]
|
45 |
self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
|
46 |
-
self.depth_transform_ops = get_depth_tuple_transform_ops(
|
47 |
-
resize=(ht, wt), normalize=False
|
48 |
-
)
|
49 |
self.wt, self.ht = wt, ht
|
50 |
self.shake_t = shake_t
|
51 |
self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
|
@@ -54,7 +43,7 @@ class ScanNetScene:
|
|
54 |
def load_im(self, im_B, crop=None):
|
55 |
im = Image.open(im_B)
|
56 |
return im
|
57 |
-
|
58 |
def load_depth(self, depth_ref, crop=None):
|
59 |
depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
|
60 |
depth = depth / 1000
|
@@ -63,73 +52,64 @@ class ScanNetScene:
|
|
63 |
|
64 |
def __len__(self):
|
65 |
return len(self.data_names)
|
66 |
-
|
67 |
def scale_intrinsic(self, K, wi, hi):
|
68 |
-
sx, sy = self.wt / wi, self.ht /
|
69 |
-
sK = torch.tensor([[sx, 0, 0],
|
70 |
-
|
|
|
|
|
71 |
|
72 |
-
def horizontal_flip(self, im_A, im_B, depth_A, depth_B,
|
73 |
im_A = im_A.flip(-1)
|
74 |
im_B = im_B.flip(-1)
|
75 |
-
depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
|
76 |
-
flip_mat = torch.tensor([[-1, 0, self.wt],
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
K_B = flip_mat @ K_B
|
81 |
-
|
82 |
return im_A, im_B, depth_A, depth_B, K_A, K_B
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
Returns:
|
88 |
pose_w2c (np.ndarray): (4, 4)
|
89 |
"""
|
90 |
-
cam2world = np.loadtxt(path, delimiter=
|
91 |
world2cam = np.linalg.inv(cam2world)
|
92 |
return world2cam
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
intrinsic
|
97 |
-
|
|
|
|
|
98 |
|
99 |
def __getitem__(self, pair_idx):
|
100 |
# read intrinsics of original size
|
101 |
data_name = self.data_names[pair_idx]
|
102 |
scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
|
103 |
-
scene_name = f
|
104 |
-
|
105 |
# read the intrinsic of depthmap
|
106 |
-
K1 = K2 =
|
107 |
-
|
108 |
-
|
109 |
# read and compute relative poses
|
110 |
-
T1 =
|
111 |
-
|
112 |
-
|
113 |
-
T2 =
|
114 |
-
|
115 |
-
|
116 |
-
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
|
117 |
-
:4, :4
|
118 |
-
] # (4, 4)
|
119 |
|
120 |
# Load positive pair data
|
121 |
-
im_A_ref = os.path.join(
|
122 |
-
|
123 |
-
)
|
124 |
-
|
125 |
-
self.scene_root, scene_name, "color", f"{stem_name_2}.jpg"
|
126 |
-
)
|
127 |
-
depth_A_ref = os.path.join(
|
128 |
-
self.scene_root, scene_name, "depth", f"{stem_name_1}.png"
|
129 |
-
)
|
130 |
-
depth_B_ref = os.path.join(
|
131 |
-
self.scene_root, scene_name, "depth", f"{stem_name_2}.png"
|
132 |
-
)
|
133 |
|
134 |
im_A = self.load_im(im_A_ref)
|
135 |
im_B = self.load_im(im_B_ref)
|
@@ -141,51 +121,40 @@ class ScanNetScene:
|
|
141 |
K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
|
142 |
# Process images
|
143 |
im_A, im_B = self.im_transform_ops((im_A, im_B))
|
144 |
-
depth_A, depth_B = self.depth_transform_ops(
|
145 |
-
(depth_A[None, None], depth_B[None, None])
|
146 |
-
)
|
147 |
if self.use_horizontal_flip_aug:
|
148 |
if np.random.rand() > 0.5:
|
149 |
-
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
"K2": K2,
|
160 |
-
"T_1to2": T_1to2,
|
161 |
-
}
|
162 |
return data_dict
|
163 |
|
164 |
|
165 |
class ScanNetBuilder:
|
166 |
-
def __init__(self, data_root=
|
167 |
self.data_root = data_root
|
168 |
-
self.scene_info_root = os.path.join(data_root,
|
169 |
self.all_scenes = os.listdir(self.scene_info_root)
|
170 |
-
|
171 |
-
def build_scenes(self, split=
|
172 |
# Note: split doesn't matter here as we always use same scannet_train scenes
|
173 |
scene_names = self.all_scenes
|
174 |
scenes = []
|
175 |
-
for scene_name in tqdm(scene_names, disable=roma.RANK > 0):
|
176 |
-
scene_info = np.load(
|
177 |
-
|
178 |
-
)
|
179 |
-
scenes.append(
|
180 |
-
ScanNetScene(
|
181 |
-
self.data_root, scene_info, min_overlap=min_overlap, **kwargs
|
182 |
-
)
|
183 |
-
)
|
184 |
return scenes
|
185 |
-
|
186 |
-
def weight_scenes(self, concat_dataset, alpha
|
187 |
ns = []
|
188 |
for d in concat_dataset.datasets:
|
189 |
ns.append(len(d))
|
190 |
-
ws = torch.cat([torch.ones(n)
|
191 |
return ws
|
|
|
5 |
import h5py
|
6 |
import numpy as np
|
7 |
import torch
|
8 |
+
from torch.utils.data import (
|
9 |
+
Dataset,
|
10 |
+
DataLoader,
|
11 |
+
ConcatDataset)
|
12 |
|
13 |
import torchvision.transforms.functional as tvf
|
14 |
import kornia.augmentation as K
|
|
|
19 |
from roma.utils.transforms import GeometricSequential
|
20 |
from tqdm import tqdm
|
21 |
|
|
|
22 |
class ScanNetScene:
|
23 |
+
def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
|
24 |
+
) -> None:
|
25 |
+
self.scene_root = osp.join(data_root,"scans","scans_train")
|
26 |
+
self.data_names = scene_info['name']
|
27 |
+
self.overlaps = scene_info['score']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# Only sample 10s
|
29 |
+
valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
|
30 |
self.overlaps = self.overlaps[valid]
|
31 |
self.data_names = self.data_names[valid]
|
32 |
if len(self.data_names) > 10000:
|
33 |
+
pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
|
|
|
|
|
34 |
self.data_names = self.data_names[pairinds]
|
35 |
self.overlaps = self.overlaps[pairinds]
|
36 |
self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
|
37 |
+
self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
|
|
|
|
|
38 |
self.wt, self.ht = wt, ht
|
39 |
self.shake_t = shake_t
|
40 |
self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
|
|
|
43 |
def load_im(self, im_B, crop=None):
|
44 |
im = Image.open(im_B)
|
45 |
return im
|
46 |
+
|
47 |
def load_depth(self, depth_ref, crop=None):
|
48 |
depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
|
49 |
depth = depth / 1000
|
|
|
52 |
|
53 |
def __len__(self):
|
54 |
return len(self.data_names)
|
55 |
+
|
56 |
def scale_intrinsic(self, K, wi, hi):
|
57 |
+
sx, sy = self.wt / wi, self.ht / hi
|
58 |
+
sK = torch.tensor([[sx, 0, 0],
|
59 |
+
[0, sy, 0],
|
60 |
+
[0, 0, 1]])
|
61 |
+
return sK@K
|
62 |
|
63 |
+
def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
|
64 |
im_A = im_A.flip(-1)
|
65 |
im_B = im_B.flip(-1)
|
66 |
+
depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
|
67 |
+
flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
|
68 |
+
K_A = flip_mat@K_A
|
69 |
+
K_B = flip_mat@K_B
|
70 |
+
|
|
|
|
|
71 |
return im_A, im_B, depth_A, depth_B, K_A, K_B
|
72 |
+
def read_scannet_pose(self,path):
|
73 |
+
""" Read ScanNet's Camera2World pose and transform it to World2Camera.
|
74 |
+
|
|
|
75 |
Returns:
|
76 |
pose_w2c (np.ndarray): (4, 4)
|
77 |
"""
|
78 |
+
cam2world = np.loadtxt(path, delimiter=' ')
|
79 |
world2cam = np.linalg.inv(cam2world)
|
80 |
return world2cam
|
81 |
|
82 |
+
|
83 |
+
def read_scannet_intrinsic(self,path):
|
84 |
+
""" Read ScanNet's intrinsic matrix and return the 3x3 matrix.
|
85 |
+
"""
|
86 |
+
intrinsic = np.loadtxt(path, delimiter=' ')
|
87 |
+
return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
|
88 |
|
89 |
def __getitem__(self, pair_idx):
|
90 |
# read intrinsics of original size
|
91 |
data_name = self.data_names[pair_idx]
|
92 |
scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
|
93 |
+
scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
|
94 |
+
|
95 |
# read the intrinsic of depthmap
|
96 |
+
K1 = K2 = self.read_scannet_intrinsic(osp.join(self.scene_root,
|
97 |
+
scene_name,
|
98 |
+
'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
|
99 |
# read and compute relative poses
|
100 |
+
T1 = self.read_scannet_pose(osp.join(self.scene_root,
|
101 |
+
scene_name,
|
102 |
+
'pose', f'{stem_name_1}.txt'))
|
103 |
+
T2 = self.read_scannet_pose(osp.join(self.scene_root,
|
104 |
+
scene_name,
|
105 |
+
'pose', f'{stem_name_2}.txt'))
|
106 |
+
T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4] # (4, 4)
|
|
|
|
|
107 |
|
108 |
# Load positive pair data
|
109 |
+
im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
|
110 |
+
im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
|
111 |
+
depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
|
112 |
+
depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
im_A = self.load_im(im_A_ref)
|
115 |
im_B = self.load_im(im_B_ref)
|
|
|
121 |
K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
|
122 |
# Process images
|
123 |
im_A, im_B = self.im_transform_ops((im_A, im_B))
|
124 |
+
depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
|
|
|
|
|
125 |
if self.use_horizontal_flip_aug:
|
126 |
if np.random.rand() > 0.5:
|
127 |
+
im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
|
128 |
+
|
129 |
+
data_dict = {'im_A': im_A,
|
130 |
+
'im_B': im_B,
|
131 |
+
'im_A_depth': depth_A[0,0],
|
132 |
+
'im_B_depth': depth_B[0,0],
|
133 |
+
'K1': K1,
|
134 |
+
'K2': K2,
|
135 |
+
'T_1to2':T_1to2,
|
136 |
+
}
|
|
|
|
|
|
|
137 |
return data_dict
|
138 |
|
139 |
|
140 |
class ScanNetBuilder:
|
141 |
+
def __init__(self, data_root = 'data/scannet') -> None:
|
142 |
self.data_root = data_root
|
143 |
+
self.scene_info_root = os.path.join(data_root,'scannet_indices')
|
144 |
self.all_scenes = os.listdir(self.scene_info_root)
|
145 |
+
|
146 |
+
def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
|
147 |
# Note: split doesn't matter here as we always use same scannet_train scenes
|
148 |
scene_names = self.all_scenes
|
149 |
scenes = []
|
150 |
+
for scene_name in tqdm(scene_names, disable = roma.RANK > 0):
|
151 |
+
scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
|
152 |
+
scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
return scenes
|
154 |
+
|
155 |
+
def weight_scenes(self, concat_dataset, alpha=.5):
|
156 |
ns = []
|
157 |
for d in concat_dataset.datasets:
|
158 |
ns.append(len(d))
|
159 |
+
ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
|
160 |
return ws
|
third_party/RoMa/roma/losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .robust_loss import RobustLosses
|
third_party/{Roma β RoMa}/roma/losses/robust_loss.py
RENAMED
@@ -7,7 +7,6 @@ import wandb
|
|
7 |
import roma
|
8 |
import math
|
9 |
|
10 |
-
|
11 |
class RobustLosses(nn.Module):
|
12 |
def __init__(
|
13 |
self,
|
@@ -18,12 +17,12 @@ class RobustLosses(nn.Module):
|
|
18 |
local_loss=True,
|
19 |
local_dist=4.0,
|
20 |
local_largest_scale=8,
|
21 |
-
smooth_mask=False,
|
22 |
-
depth_interpolation_mode="bilinear",
|
23 |
-
mask_depth_loss=False,
|
24 |
-
relative_depth_error_threshold=0.05,
|
25 |
-
alpha=1
|
26 |
-
c=1e-3,
|
27 |
):
|
28 |
super().__init__()
|
29 |
self.robust = robust # measured in pixels
|
@@ -46,103 +45,68 @@ class RobustLosses(nn.Module):
|
|
46 |
B, C, H, W = scale_gm_cls.shape
|
47 |
device = x2.device
|
48 |
cls_res = round(math.sqrt(C))
|
49 |
-
G = torch.meshgrid(
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
)
|
54 |
-
for _ in range(2)
|
55 |
-
]
|
56 |
-
)
|
57 |
-
G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2)
|
58 |
-
GT = (
|
59 |
-
(G[None, :, None, None, :] - x2[:, None])
|
60 |
-
.norm(dim=-1)
|
61 |
-
.min(dim=1)
|
62 |
-
.indices
|
63 |
-
)
|
64 |
-
cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99]
|
65 |
if not torch.any(cls_loss):
|
66 |
-
cls_loss = certainty_loss * 0.0 # Prevent issues where prob is 0 everywhere
|
67 |
|
68 |
-
certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,
|
69 |
losses = {
|
70 |
f"gm_certainty_loss_{scale}": certainty_loss.mean(),
|
71 |
f"gm_cls_loss_{scale}": cls_loss.mean(),
|
72 |
}
|
73 |
-
wandb.log(losses, step=roma.GLOBAL_STEP)
|
74 |
return losses
|
75 |
|
76 |
-
def delta_cls_loss(
|
77 |
-
self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale
|
78 |
-
):
|
79 |
with torch.no_grad():
|
80 |
B, C, H, W = delta_cls.shape
|
81 |
device = x2.device
|
82 |
cls_res = round(math.sqrt(C))
|
83 |
-
G = torch.meshgrid(
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
)
|
88 |
-
for _ in range(2)
|
89 |
-
]
|
90 |
-
)
|
91 |
-
G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) * offset_scale
|
92 |
-
GT = (
|
93 |
-
(G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None])
|
94 |
-
.norm(dim=-1)
|
95 |
-
.min(dim=1)
|
96 |
-
.indices
|
97 |
-
)
|
98 |
-
cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99]
|
99 |
if not torch.any(cls_loss):
|
100 |
-
cls_loss = certainty_loss * 0.0 # Prevent issues where prob is 0 everywhere
|
101 |
-
certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,
|
102 |
losses = {
|
103 |
f"delta_certainty_loss_{scale}": certainty_loss.mean(),
|
104 |
f"delta_cls_loss_{scale}": cls_loss.mean(),
|
105 |
}
|
106 |
-
wandb.log(losses, step=roma.GLOBAL_STEP)
|
107 |
return losses
|
108 |
|
109 |
-
def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
|
110 |
-
epe = (flow.permute(0,
|
111 |
if scale == 1:
|
112 |
-
pck_05 = (epe[prob > 0.99] < 0.5 * (2
|
113 |
-
wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)
|
114 |
|
115 |
ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
|
116 |
a = self.alpha
|
117 |
cs = self.c * scale
|
118 |
x = epe[prob > 0.99]
|
119 |
-
reg_loss = cs**a * ((x
|
120 |
if not torch.any(reg_loss):
|
121 |
-
reg_loss = ce_loss * 0.0 # Prevent issues where prob is 0 everywhere
|
122 |
losses = {
|
123 |
f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
|
124 |
f"{mode}_regression_loss_{scale}": reg_loss.mean(),
|
125 |
}
|
126 |
-
wandb.log(losses, step=roma.GLOBAL_STEP)
|
127 |
return losses
|
128 |
|
129 |
def forward(self, corresps, batch):
|
130 |
scales = list(corresps.keys())
|
131 |
tot_loss = 0.0
|
132 |
# scale_weights due to differences in scale for regression gradients and classification gradients
|
133 |
-
scale_weights = {1:
|
134 |
for scale in scales:
|
135 |
scale_corresps = corresps[scale]
|
136 |
-
(
|
137 |
-
scale_certainty,
|
138 |
-
flow_pre_delta,
|
139 |
-
delta_cls,
|
140 |
-
offset_scale,
|
141 |
-
scale_gm_cls,
|
142 |
-
scale_gm_certainty,
|
143 |
-
flow,
|
144 |
-
scale_gm_flow,
|
145 |
-
) = (
|
146 |
scale_corresps["certainty"],
|
147 |
scale_corresps["flow_pre_delta"],
|
148 |
scale_corresps.get("delta_cls"),
|
@@ -151,72 +115,43 @@ class RobustLosses(nn.Module):
|
|
151 |
scale_corresps.get("gm_certainty"),
|
152 |
scale_corresps["flow"],
|
153 |
scale_corresps.get("gm_flow"),
|
|
|
154 |
)
|
155 |
flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
|
156 |
b, h, w, d = flow_pre_delta.shape
|
157 |
-
gt_warp, gt_prob = get_gt_warp(
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
x2 = gt_warp.float()
|
167 |
prob = gt_prob
|
168 |
-
|
169 |
if self.local_largest_scale >= scale:
|
170 |
prob = prob * (
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
< (2 / 512) * (self.local_dist[scale] * scale)
|
175 |
-
)
|
176 |
-
|
177 |
if scale_gm_cls is not None:
|
178 |
-
gm_cls_losses = self.gm_cls_loss(
|
179 |
-
|
180 |
-
)
|
181 |
-
gm_loss = (
|
182 |
-
self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
|
183 |
-
+ gm_cls_losses[f"gm_cls_loss_{scale}"]
|
184 |
-
)
|
185 |
tot_loss = tot_loss + scale_weights[scale] * gm_loss
|
186 |
elif scale_gm_flow is not None:
|
187 |
-
gm_flow_losses = self.regression_loss(
|
188 |
-
|
189 |
-
)
|
190 |
-
gm_loss = (
|
191 |
-
self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
|
192 |
-
+ gm_flow_losses[f"gm_regression_loss_{scale}"]
|
193 |
-
)
|
194 |
tot_loss = tot_loss + scale_weights[scale] * gm_loss
|
195 |
-
|
196 |
if delta_cls is not None:
|
197 |
-
delta_cls_losses = self.delta_cls_loss(
|
198 |
-
|
199 |
-
prob,
|
200 |
-
flow_pre_delta,
|
201 |
-
delta_cls,
|
202 |
-
scale_certainty,
|
203 |
-
scale,
|
204 |
-
offset_scale,
|
205 |
-
)
|
206 |
-
delta_cls_loss = (
|
207 |
-
self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
|
208 |
-
+ delta_cls_losses[f"delta_cls_loss_{scale}"]
|
209 |
-
)
|
210 |
tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
|
211 |
else:
|
212 |
-
delta_regression_losses = self.regression_loss(
|
213 |
-
|
214 |
-
)
|
215 |
-
reg_loss = (
|
216 |
-
self.ce_weight
|
217 |
-
* delta_regression_losses[f"delta_certainty_loss_{scale}"]
|
218 |
-
+ delta_regression_losses[f"delta_regression_loss_{scale}"]
|
219 |
-
)
|
220 |
tot_loss = tot_loss + scale_weights[scale] * reg_loss
|
221 |
-
prev_epe = (flow.permute(0,
|
222 |
return tot_loss
|
|
|
7 |
import roma
|
8 |
import math
|
9 |
|
|
|
10 |
class RobustLosses(nn.Module):
|
11 |
def __init__(
|
12 |
self,
|
|
|
17 |
local_loss=True,
|
18 |
local_dist=4.0,
|
19 |
local_largest_scale=8,
|
20 |
+
smooth_mask = False,
|
21 |
+
depth_interpolation_mode = "bilinear",
|
22 |
+
mask_depth_loss = False,
|
23 |
+
relative_depth_error_threshold = 0.05,
|
24 |
+
alpha = 1.,
|
25 |
+
c = 1e-3,
|
26 |
):
|
27 |
super().__init__()
|
28 |
self.robust = robust # measured in pixels
|
|
|
45 |
B, C, H, W = scale_gm_cls.shape
|
46 |
device = x2.device
|
47 |
cls_res = round(math.sqrt(C))
|
48 |
+
G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
|
49 |
+
G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
|
50 |
+
GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
|
51 |
+
cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction = 'none')[prob > 0.99]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
if not torch.any(cls_loss):
|
53 |
+
cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
|
54 |
|
55 |
+
certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
|
56 |
losses = {
|
57 |
f"gm_certainty_loss_{scale}": certainty_loss.mean(),
|
58 |
f"gm_cls_loss_{scale}": cls_loss.mean(),
|
59 |
}
|
60 |
+
wandb.log(losses, step = roma.GLOBAL_STEP)
|
61 |
return losses
|
62 |
|
63 |
+
def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
|
|
|
|
|
64 |
with torch.no_grad():
|
65 |
B, C, H, W = delta_cls.shape
|
66 |
device = x2.device
|
67 |
cls_res = round(math.sqrt(C))
|
68 |
+
G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
|
69 |
+
G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
|
70 |
+
GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
|
71 |
+
cls_loss = F.cross_entropy(delta_cls, GT, reduction = 'none')[prob > 0.99]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
if not torch.any(cls_loss):
|
73 |
+
cls_loss = (certainty_loss * 0.0) # Prevent issues where prob is 0 everywhere
|
74 |
+
certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
|
75 |
losses = {
|
76 |
f"delta_certainty_loss_{scale}": certainty_loss.mean(),
|
77 |
f"delta_cls_loss_{scale}": cls_loss.mean(),
|
78 |
}
|
79 |
+
wandb.log(losses, step = roma.GLOBAL_STEP)
|
80 |
return losses
|
81 |
|
82 |
+
def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
|
83 |
+
epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
|
84 |
if scale == 1:
|
85 |
+
pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
|
86 |
+
wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP)
|
87 |
|
88 |
ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
|
89 |
a = self.alpha
|
90 |
cs = self.c * scale
|
91 |
x = epe[prob > 0.99]
|
92 |
+
reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
|
93 |
if not torch.any(reg_loss):
|
94 |
+
reg_loss = (ce_loss * 0.0) # Prevent issues where prob is 0 everywhere
|
95 |
losses = {
|
96 |
f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
|
97 |
f"{mode}_regression_loss_{scale}": reg_loss.mean(),
|
98 |
}
|
99 |
+
wandb.log(losses, step = roma.GLOBAL_STEP)
|
100 |
return losses
|
101 |
|
102 |
def forward(self, corresps, batch):
|
103 |
scales = list(corresps.keys())
|
104 |
tot_loss = 0.0
|
105 |
# scale_weights due to differences in scale for regression gradients and classification gradients
|
106 |
+
scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
|
107 |
for scale in scales:
|
108 |
scale_corresps = corresps[scale]
|
109 |
+
scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
scale_corresps["certainty"],
|
111 |
scale_corresps["flow_pre_delta"],
|
112 |
scale_corresps.get("delta_cls"),
|
|
|
115 |
scale_corresps.get("gm_certainty"),
|
116 |
scale_corresps["flow"],
|
117 |
scale_corresps.get("gm_flow"),
|
118 |
+
|
119 |
)
|
120 |
flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
|
121 |
b, h, w, d = flow_pre_delta.shape
|
122 |
+
gt_warp, gt_prob = get_gt_warp(
|
123 |
+
batch["im_A_depth"],
|
124 |
+
batch["im_B_depth"],
|
125 |
+
batch["T_1to2"],
|
126 |
+
batch["K1"],
|
127 |
+
batch["K2"],
|
128 |
+
H=h,
|
129 |
+
W=w,
|
130 |
+
)
|
131 |
x2 = gt_warp.float()
|
132 |
prob = gt_prob
|
133 |
+
|
134 |
if self.local_largest_scale >= scale:
|
135 |
prob = prob * (
|
136 |
+
F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
|
137 |
+
< (2 / 512) * (self.local_dist[scale] * scale))
|
138 |
+
|
|
|
|
|
|
|
139 |
if scale_gm_cls is not None:
|
140 |
+
gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
|
141 |
+
gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
|
|
|
|
|
|
|
|
|
|
|
142 |
tot_loss = tot_loss + scale_weights[scale] * gm_loss
|
143 |
elif scale_gm_flow is not None:
|
144 |
+
gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
|
145 |
+
gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
|
|
|
|
|
|
|
|
|
|
|
146 |
tot_loss = tot_loss + scale_weights[scale] * gm_loss
|
147 |
+
|
148 |
if delta_cls is not None:
|
149 |
+
delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
|
150 |
+
delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
|
152 |
else:
|
153 |
+
delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
|
154 |
+
reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
tot_loss = tot_loss + scale_weights[scale] * reg_loss
|
156 |
+
prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
|
157 |
return tot_loss
|
third_party/RoMa/roma/models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .model_zoo import roma_outdoor, roma_indoor
|
third_party/{Roma β RoMa}/roma/models/encoders.py
RENAMED
@@ -8,7 +8,8 @@ import gc
|
|
8 |
|
9 |
|
10 |
class ResNet50(nn.Module):
|
11 |
-
def __init__(self, pretrained=False, high_res = False, weights = None,
|
|
|
12 |
super().__init__()
|
13 |
if dilation is None:
|
14 |
dilation = [False,False,False]
|
@@ -24,10 +25,7 @@ class ResNet50(nn.Module):
|
|
24 |
self.freeze_bn = freeze_bn
|
25 |
self.early_exit = early_exit
|
26 |
self.amp = amp
|
27 |
-
|
28 |
-
self.amp_dtype = torch.float32
|
29 |
-
else:
|
30 |
-
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
31 |
|
32 |
def forward(self, x, **kwargs):
|
33 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
@@ -59,14 +57,11 @@ class ResNet50(nn.Module):
|
|
59 |
pass
|
60 |
|
61 |
class VGG19(nn.Module):
|
62 |
-
def __init__(self, pretrained=False, amp = False) -> None:
|
63 |
super().__init__()
|
64 |
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
65 |
self.amp = amp
|
66 |
-
|
67 |
-
self.amp_dtype = torch.float32
|
68 |
-
else:
|
69 |
-
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
70 |
|
71 |
def forward(self, x, **kwargs):
|
72 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
@@ -80,7 +75,7 @@ class VGG19(nn.Module):
|
|
80 |
return feats
|
81 |
|
82 |
class CNNandDinov2(nn.Module):
|
83 |
-
def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
|
84 |
super().__init__()
|
85 |
if dinov2_weights is None:
|
86 |
dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
|
@@ -100,10 +95,7 @@ class CNNandDinov2(nn.Module):
|
|
100 |
else:
|
101 |
self.cnn = VGG19(**cnn_kwargs)
|
102 |
self.amp = amp
|
103 |
-
|
104 |
-
self.amp_dtype = torch.float32
|
105 |
-
else:
|
106 |
-
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
107 |
if self.amp:
|
108 |
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
109 |
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|
|
|
8 |
|
9 |
|
10 |
class ResNet50(nn.Module):
|
11 |
+
def __init__(self, pretrained=False, high_res = False, weights = None,
|
12 |
+
dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
|
13 |
super().__init__()
|
14 |
if dilation is None:
|
15 |
dilation = [False,False,False]
|
|
|
25 |
self.freeze_bn = freeze_bn
|
26 |
self.early_exit = early_exit
|
27 |
self.amp = amp
|
28 |
+
self.amp_dtype = amp_dtype
|
|
|
|
|
|
|
29 |
|
30 |
def forward(self, x, **kwargs):
|
31 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
57 |
pass
|
58 |
|
59 |
class VGG19(nn.Module):
|
60 |
+
def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
|
61 |
super().__init__()
|
62 |
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
63 |
self.amp = amp
|
64 |
+
self.amp_dtype = amp_dtype
|
|
|
|
|
|
|
65 |
|
66 |
def forward(self, x, **kwargs):
|
67 |
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
|
|
75 |
return feats
|
76 |
|
77 |
class CNNandDinov2(nn.Module):
|
78 |
+
def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
|
79 |
super().__init__()
|
80 |
if dinov2_weights is None:
|
81 |
dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
|
|
|
95 |
else:
|
96 |
self.cnn = VGG19(**cnn_kwargs)
|
97 |
self.amp = amp
|
98 |
+
self.amp_dtype = amp_dtype
|
|
|
|
|
|
|
99 |
if self.amp:
|
100 |
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
101 |
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|
third_party/{Roma β RoMa}/roma/models/matcher.py
RENAMED
@@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|
7 |
from einops import rearrange
|
8 |
import warnings
|
9 |
from warnings import warn
|
|
|
10 |
|
11 |
import roma
|
12 |
from roma.utils import get_tuple_transform_ops
|
@@ -37,6 +38,7 @@ class ConvRefiner(nn.Module):
|
|
37 |
sample_mode = "bilinear",
|
38 |
norm_type = nn.BatchNorm2d,
|
39 |
bn_momentum = 0.1,
|
|
|
40 |
):
|
41 |
super().__init__()
|
42 |
self.bn_momentum = bn_momentum
|
@@ -71,12 +73,8 @@ class ConvRefiner(nn.Module):
|
|
71 |
self.disable_local_corr_grad = disable_local_corr_grad
|
72 |
self.is_classifier = is_classifier
|
73 |
self.sample_mode = sample_mode
|
74 |
-
self.
|
75 |
-
|
76 |
-
self.amp_dtype = torch.float32
|
77 |
-
else:
|
78 |
-
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
79 |
-
|
80 |
def create_block(
|
81 |
self,
|
82 |
in_dim,
|
@@ -113,8 +111,8 @@ class ConvRefiner(nn.Module):
|
|
113 |
if self.has_displacement_emb:
|
114 |
im_A_coords = torch.meshgrid(
|
115 |
(
|
116 |
-
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=
|
117 |
-
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=
|
118 |
)
|
119 |
)
|
120 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
@@ -278,7 +276,7 @@ class Decoder(nn.Module):
|
|
278 |
def __init__(
|
279 |
self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
|
280 |
num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
|
281 |
-
flow_upsample_mode = "bilinear"
|
282 |
):
|
283 |
super().__init__()
|
284 |
self.embedding_decoder = embedding_decoder
|
@@ -300,11 +298,8 @@ class Decoder(nn.Module):
|
|
300 |
self.displacement_dropout_p = displacement_dropout_p
|
301 |
self.gm_warp_dropout_p = gm_warp_dropout_p
|
302 |
self.flow_upsample_mode = flow_upsample_mode
|
303 |
-
|
304 |
-
|
305 |
-
else:
|
306 |
-
self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
307 |
-
|
308 |
def get_placeholder_flow(self, b, h, w, device):
|
309 |
coarse_coords = torch.meshgrid(
|
310 |
(
|
@@ -367,7 +362,7 @@ class Decoder(nn.Module):
|
|
367 |
corresps[ins] = {}
|
368 |
f1_s, f2_s = f1[ins], f2[ins]
|
369 |
if new_scale in self.proj:
|
370 |
-
with torch.autocast("cuda", self.amp_dtype):
|
371 |
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
372 |
|
373 |
if ins in coarse_scales:
|
@@ -429,11 +424,12 @@ class RegressionMatcher(nn.Module):
|
|
429 |
decoder,
|
430 |
h=448,
|
431 |
w=448,
|
432 |
-
sample_mode = "
|
433 |
upsample_preds = False,
|
434 |
symmetric = False,
|
435 |
name = None,
|
436 |
attenuate_cert = None,
|
|
|
437 |
):
|
438 |
super().__init__()
|
439 |
self.attenuate_cert = attenuate_cert
|
@@ -448,6 +444,7 @@ class RegressionMatcher(nn.Module):
|
|
448 |
self.upsample_res = (14*16*6, 14*16*6)
|
449 |
self.symmetric = symmetric
|
450 |
self.sample_thresh = 0.05
|
|
|
451 |
|
452 |
def get_output_resolution(self):
|
453 |
if not self.upsample_preds:
|
@@ -527,12 +524,62 @@ class RegressionMatcher(nn.Module):
|
|
527 |
scale_factor=scale_factor)
|
528 |
return corresps
|
529 |
|
530 |
-
def to_pixel_coordinates(self,
|
531 |
-
|
|
|
|
|
|
|
532 |
kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
|
533 |
kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
534 |
return kpts_A, kpts_B
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
535 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
536 |
def match(
|
537 |
self,
|
538 |
im_A_path,
|
@@ -543,9 +590,8 @@ class RegressionMatcher(nn.Module):
|
|
543 |
):
|
544 |
if device is None:
|
545 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
546 |
-
from PIL import Image
|
547 |
if isinstance(im_A_path, (str, os.PathLike)):
|
548 |
-
im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
|
549 |
else:
|
550 |
# Assume its not a path
|
551 |
im_A, im_B = im_A_path, im_B_path
|
@@ -597,7 +643,14 @@ class RegressionMatcher(nn.Module):
|
|
597 |
test_transform = get_tuple_transform_ops(
|
598 |
resize=(hs, ws), normalize=True
|
599 |
)
|
600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
601 |
im_A, im_B = test_transform((im_A, im_B))
|
602 |
im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
603 |
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|
@@ -653,4 +706,30 @@ class RegressionMatcher(nn.Module):
|
|
653 |
warp[0],
|
654 |
certainty[0, 0],
|
655 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
656 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from einops import rearrange
|
8 |
import warnings
|
9 |
from warnings import warn
|
10 |
+
from PIL import Image
|
11 |
|
12 |
import roma
|
13 |
from roma.utils import get_tuple_transform_ops
|
|
|
38 |
sample_mode = "bilinear",
|
39 |
norm_type = nn.BatchNorm2d,
|
40 |
bn_momentum = 0.1,
|
41 |
+
amp_dtype = torch.float16,
|
42 |
):
|
43 |
super().__init__()
|
44 |
self.bn_momentum = bn_momentum
|
|
|
73 |
self.disable_local_corr_grad = disable_local_corr_grad
|
74 |
self.is_classifier = is_classifier
|
75 |
self.sample_mode = sample_mode
|
76 |
+
self.amp_dtype = amp_dtype
|
77 |
+
|
|
|
|
|
|
|
|
|
78 |
def create_block(
|
79 |
self,
|
80 |
in_dim,
|
|
|
111 |
if self.has_displacement_emb:
|
112 |
im_A_coords = torch.meshgrid(
|
113 |
(
|
114 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
|
115 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
|
116 |
)
|
117 |
)
|
118 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
276 |
def __init__(
|
277 |
self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
|
278 |
num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
|
279 |
+
flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
|
280 |
):
|
281 |
super().__init__()
|
282 |
self.embedding_decoder = embedding_decoder
|
|
|
298 |
self.displacement_dropout_p = displacement_dropout_p
|
299 |
self.gm_warp_dropout_p = gm_warp_dropout_p
|
300 |
self.flow_upsample_mode = flow_upsample_mode
|
301 |
+
self.amp_dtype = amp_dtype
|
302 |
+
|
|
|
|
|
|
|
303 |
def get_placeholder_flow(self, b, h, w, device):
|
304 |
coarse_coords = torch.meshgrid(
|
305 |
(
|
|
|
362 |
corresps[ins] = {}
|
363 |
f1_s, f2_s = f1[ins], f2[ins]
|
364 |
if new_scale in self.proj:
|
365 |
+
with torch.autocast("cuda", dtype = self.amp_dtype):
|
366 |
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
367 |
|
368 |
if ins in coarse_scales:
|
|
|
424 |
decoder,
|
425 |
h=448,
|
426 |
w=448,
|
427 |
+
sample_mode = "threshold_balanced",
|
428 |
upsample_preds = False,
|
429 |
symmetric = False,
|
430 |
name = None,
|
431 |
attenuate_cert = None,
|
432 |
+
recrop_upsample = False,
|
433 |
):
|
434 |
super().__init__()
|
435 |
self.attenuate_cert = attenuate_cert
|
|
|
444 |
self.upsample_res = (14*16*6, 14*16*6)
|
445 |
self.symmetric = symmetric
|
446 |
self.sample_thresh = 0.05
|
447 |
+
self.recrop_upsample = recrop_upsample
|
448 |
|
449 |
def get_output_resolution(self):
|
450 |
if not self.upsample_preds:
|
|
|
524 |
scale_factor=scale_factor)
|
525 |
return corresps
|
526 |
|
527 |
+
def to_pixel_coordinates(self, coords, H_A, W_A, H_B, W_B):
|
528 |
+
if isinstance(coords, (list, tuple)):
|
529 |
+
kpts_A, kpts_B = coords[0], coords[1]
|
530 |
+
else:
|
531 |
+
kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
532 |
kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
|
533 |
kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
|
534 |
return kpts_A, kpts_B
|
535 |
+
|
536 |
+
def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
|
537 |
+
if isinstance(coords, (list, tuple)):
|
538 |
+
kpts_A, kpts_B = coords[0], coords[1]
|
539 |
+
else:
|
540 |
+
kpts_A, kpts_B = coords[...,:2], coords[...,2:]
|
541 |
+
kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
|
542 |
+
kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
|
543 |
+
return kpts_A, kpts_B
|
544 |
|
545 |
+
def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
|
546 |
+
x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
|
547 |
+
cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
|
548 |
+
D = torch.cdist(x_A_to_B, x_B)
|
549 |
+
inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
|
550 |
+
|
551 |
+
if return_tuple:
|
552 |
+
if return_inds:
|
553 |
+
return inds_A, inds_B
|
554 |
+
else:
|
555 |
+
return x_A[inds_A], x_B[inds_B]
|
556 |
+
else:
|
557 |
+
if return_inds:
|
558 |
+
return torch.cat((inds_A, inds_B),dim=-1)
|
559 |
+
else:
|
560 |
+
return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
|
561 |
+
|
562 |
+
def get_roi(self, certainty, W, H, thr = 0.025):
|
563 |
+
raise NotImplementedError("WIP, disable for now")
|
564 |
+
hs,ws = certainty.shape
|
565 |
+
certainty = certainty/certainty.sum(dim=(-1,-2))
|
566 |
+
cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
|
567 |
+
cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
|
568 |
+
print(cum_certainty_w)
|
569 |
+
print(torch.min(torch.nonzero(cum_certainty_w > thr)))
|
570 |
+
print(torch.min(torch.nonzero(cum_certainty_w < thr)))
|
571 |
+
left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
|
572 |
+
right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
|
573 |
+
top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
|
574 |
+
bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
|
575 |
+
print(left, right, top, bottom)
|
576 |
+
return left, top, right, bottom
|
577 |
+
|
578 |
+
def recrop(self, certainty, image_path):
|
579 |
+
roi = self.get_roi(certainty, *Image.open(image_path).size)
|
580 |
+
return Image.open(image_path).convert("RGB").crop(roi)
|
581 |
+
|
582 |
+
@torch.inference_mode()
|
583 |
def match(
|
584 |
self,
|
585 |
im_A_path,
|
|
|
590 |
):
|
591 |
if device is None:
|
592 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
593 |
if isinstance(im_A_path, (str, os.PathLike)):
|
594 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
595 |
else:
|
596 |
# Assume its not a path
|
597 |
im_A, im_B = im_A_path, im_B_path
|
|
|
643 |
test_transform = get_tuple_transform_ops(
|
644 |
resize=(hs, ws), normalize=True
|
645 |
)
|
646 |
+
if self.recrop_upsample:
|
647 |
+
certainty = corresps[finest_scale]["certainty"]
|
648 |
+
print(certainty.shape)
|
649 |
+
im_A = self.recrop(certainty[0,0], im_A_path)
|
650 |
+
im_B = self.recrop(certainty[1,0], im_B_path)
|
651 |
+
#TODO: need to adjust corresps when doing this
|
652 |
+
else:
|
653 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
654 |
im_A, im_B = test_transform((im_A, im_B))
|
655 |
im_A, im_B = im_A[None].to(device), im_B[None].to(device)
|
656 |
scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
|
|
|
706 |
warp[0],
|
707 |
certainty[0, 0],
|
708 |
)
|
709 |
+
|
710 |
+
def visualize_warp(self, warp, certainty, im_A = None, im_B = None, im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None):
|
711 |
+
assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
|
712 |
+
H,W2,_ = warp.shape
|
713 |
+
W = W2//2 if symmetric else W2
|
714 |
+
if im_A is None:
|
715 |
+
from PIL import Image
|
716 |
+
im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
|
717 |
+
im_A = im_A.resize((W,H))
|
718 |
+
im_B = im_B.resize((W,H))
|
719 |
+
|
720 |
+
x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
|
721 |
+
x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
|
722 |
|
723 |
+
im_A_transfer_rgb = F.grid_sample(
|
724 |
+
x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
|
725 |
+
)[0]
|
726 |
+
im_B_transfer_rgb = F.grid_sample(
|
727 |
+
x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
|
728 |
+
)[0]
|
729 |
+
warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
|
730 |
+
white_im = torch.ones((H,2*W),device=device)
|
731 |
+
vis_im = certainty * warp_im + (1 - certainty) * white_im
|
732 |
+
if save_path is not None:
|
733 |
+
from roma.utils import tensor_to_pil
|
734 |
+
tensor_to_pil(vis_im, unnormalize=False).save(save_path)
|
735 |
+
return vis_im
|
third_party/RoMa/roma/models/model_zoo/__init__.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
import torch
|
3 |
+
from .roma_models import roma_model
|
4 |
+
|
5 |
+
weight_urls = {
|
6 |
+
"roma": {
|
7 |
+
"outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
|
8 |
+
"indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
|
9 |
+
},
|
10 |
+
"dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
|
11 |
+
}
|
12 |
+
|
13 |
+
def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
|
14 |
+
if isinstance(coarse_res, int):
|
15 |
+
coarse_res = (coarse_res, coarse_res)
|
16 |
+
if isinstance(upsample_res, int):
|
17 |
+
upsample_res = (upsample_res, upsample_res)
|
18 |
+
|
19 |
+
assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
20 |
+
assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
21 |
+
|
22 |
+
if weights is None:
|
23 |
+
weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
|
24 |
+
map_location=device)
|
25 |
+
if dinov2_weights is None:
|
26 |
+
dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
|
27 |
+
map_location=device)
|
28 |
+
model = roma_model(resolution=coarse_res, upsample_preds=True,
|
29 |
+
weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
|
30 |
+
model.upsample_res = upsample_res
|
31 |
+
print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
|
32 |
+
return model
|
33 |
+
|
34 |
+
def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
|
35 |
+
if isinstance(coarse_res, int):
|
36 |
+
coarse_res = (coarse_res, coarse_res)
|
37 |
+
if isinstance(upsample_res, int):
|
38 |
+
upsample_res = (upsample_res, upsample_res)
|
39 |
+
|
40 |
+
assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
41 |
+
assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
|
42 |
+
|
43 |
+
if weights is None:
|
44 |
+
weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
|
45 |
+
map_location=device)
|
46 |
+
if dinov2_weights is None:
|
47 |
+
dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
|
48 |
+
map_location=device)
|
49 |
+
model = roma_model(resolution=coarse_res, upsample_preds=True,
|
50 |
+
weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
|
51 |
+
model.upsample_res = upsample_res
|
52 |
+
print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
|
53 |
+
return model
|
third_party/{Roma β RoMa}/roma/models/model_zoo/roma_models.py
RENAMED
@@ -1,98 +1,91 @@
|
|
1 |
import warnings
|
2 |
import torch.nn as nn
|
|
|
3 |
from roma.models.matcher import *
|
4 |
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
|
5 |
from roma.models.encoders import *
|
6 |
|
7 |
-
|
8 |
-
def roma_model(
|
9 |
-
resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
|
10 |
-
):
|
11 |
# roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
|
12 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
13 |
-
torch.backends.cudnn.allow_tf32 = True
|
14 |
-
warnings.filterwarnings(
|
15 |
-
"ignore", category=UserWarning, message="TypedStorage is deprecated"
|
16 |
-
)
|
17 |
gp_dim = 512
|
18 |
feat_dim = 512
|
19 |
decoder_dim = gp_dim + feat_dim
|
20 |
cls_to_coord_res = 64
|
21 |
coordinate_decoder = TransformerDecoder(
|
22 |
-
nn.Sequential(
|
23 |
-
|
24 |
-
),
|
25 |
-
decoder_dim,
|
26 |
cls_to_coord_res**2 + 1,
|
27 |
is_classifier=True,
|
28 |
-
amp=True,
|
29 |
-
pos_enc=False,
|
30 |
-
)
|
31 |
dw = True
|
32 |
hidden_blocks = 8
|
33 |
kernel_size = 5
|
34 |
displacement_emb = "linear"
|
35 |
disable_local_corr_grad = True
|
36 |
-
|
37 |
conv_refiner = nn.ModuleDict(
|
38 |
{
|
39 |
"16": ConvRefiner(
|
40 |
-
2 * 512
|
41 |
-
2 * 512
|
42 |
2 + 1,
|
43 |
kernel_size=kernel_size,
|
44 |
dw=dw,
|
45 |
hidden_blocks=hidden_blocks,
|
46 |
displacement_emb=displacement_emb,
|
47 |
displacement_emb_dim=128,
|
48 |
-
local_corr_radius=7,
|
49 |
-
corr_in_other=True,
|
50 |
-
amp=True,
|
51 |
-
disable_local_corr_grad=disable_local_corr_grad,
|
52 |
-
bn_momentum=0.01,
|
53 |
),
|
54 |
"8": ConvRefiner(
|
55 |
-
2 * 512
|
56 |
-
2 * 512
|
57 |
2 + 1,
|
58 |
kernel_size=kernel_size,
|
59 |
dw=dw,
|
60 |
hidden_blocks=hidden_blocks,
|
61 |
displacement_emb=displacement_emb,
|
62 |
displacement_emb_dim=64,
|
63 |
-
local_corr_radius=3,
|
64 |
-
corr_in_other=True,
|
65 |
-
amp=True,
|
66 |
-
disable_local_corr_grad=disable_local_corr_grad,
|
67 |
-
bn_momentum=0.01,
|
68 |
),
|
69 |
"4": ConvRefiner(
|
70 |
-
2 * 256
|
71 |
-
2 * 256
|
72 |
2 + 1,
|
73 |
kernel_size=kernel_size,
|
74 |
dw=dw,
|
75 |
hidden_blocks=hidden_blocks,
|
76 |
displacement_emb=displacement_emb,
|
77 |
displacement_emb_dim=32,
|
78 |
-
local_corr_radius=2,
|
79 |
-
corr_in_other=True,
|
80 |
-
amp=True,
|
81 |
-
disable_local_corr_grad=disable_local_corr_grad,
|
82 |
-
bn_momentum=0.01,
|
83 |
),
|
84 |
"2": ConvRefiner(
|
85 |
-
2 * 64
|
86 |
-
128
|
87 |
2 + 1,
|
88 |
kernel_size=kernel_size,
|
89 |
dw=dw,
|
90 |
hidden_blocks=hidden_blocks,
|
91 |
displacement_emb=displacement_emb,
|
92 |
displacement_emb_dim=16,
|
93 |
-
amp=True,
|
94 |
-
disable_local_corr_grad=disable_local_corr_grad,
|
95 |
-
bn_momentum=0.01,
|
96 |
),
|
97 |
"1": ConvRefiner(
|
98 |
2 * 9 + 6,
|
@@ -100,12 +93,12 @@ def roma_model(
|
|
100 |
2 + 1,
|
101 |
kernel_size=kernel_size,
|
102 |
dw=dw,
|
103 |
-
hidden_blocks=hidden_blocks,
|
104 |
-
displacement_emb=displacement_emb,
|
105 |
-
displacement_emb_dim=6,
|
106 |
-
amp=True,
|
107 |
-
disable_local_corr_grad=disable_local_corr_grad,
|
108 |
-
bn_momentum=0.01,
|
109 |
),
|
110 |
}
|
111 |
)
|
@@ -130,46 +123,38 @@ def roma_model(
|
|
130 |
proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
|
131 |
proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
|
132 |
proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
|
133 |
-
proj = nn.ModuleDict(
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
}
|
141 |
-
)
|
142 |
displacement_dropout_p = 0.0
|
143 |
gm_warp_dropout_p = 0.0
|
144 |
-
decoder = Decoder(
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
)
|
154 |
-
|
155 |
encoder = CNNandDinov2(
|
156 |
-
cnn_kwargs=dict(
|
157 |
-
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
160 |
)
|
161 |
-
h,
|
162 |
symmetric = True
|
163 |
attenuate_cert = True
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
h=h,
|
168 |
-
w=w,
|
169 |
-
upsample_preds=upsample_preds,
|
170 |
-
symmetric=symmetric,
|
171 |
-
attenuate_cert=attenuate_cert,
|
172 |
-
**kwargs
|
173 |
-
).to(device)
|
174 |
matcher.load_state_dict(weights)
|
175 |
return matcher
|
|
|
1 |
import warnings
|
2 |
import torch.nn as nn
|
3 |
+
import torch
|
4 |
from roma.models.matcher import *
|
5 |
from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
|
6 |
from roma.models.encoders import *
|
7 |
|
8 |
+
def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
|
|
|
|
|
|
|
9 |
# roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
|
10 |
+
#torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful
|
11 |
+
#torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
12 |
+
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
|
|
|
|
13 |
gp_dim = 512
|
14 |
feat_dim = 512
|
15 |
decoder_dim = gp_dim + feat_dim
|
16 |
cls_to_coord_res = 64
|
17 |
coordinate_decoder = TransformerDecoder(
|
18 |
+
nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
|
19 |
+
decoder_dim,
|
|
|
|
|
20 |
cls_to_coord_res**2 + 1,
|
21 |
is_classifier=True,
|
22 |
+
amp = True,
|
23 |
+
pos_enc = False,)
|
|
|
24 |
dw = True
|
25 |
hidden_blocks = 8
|
26 |
kernel_size = 5
|
27 |
displacement_emb = "linear"
|
28 |
disable_local_corr_grad = True
|
29 |
+
|
30 |
conv_refiner = nn.ModuleDict(
|
31 |
{
|
32 |
"16": ConvRefiner(
|
33 |
+
2 * 512+128+(2*7+1)**2,
|
34 |
+
2 * 512+128+(2*7+1)**2,
|
35 |
2 + 1,
|
36 |
kernel_size=kernel_size,
|
37 |
dw=dw,
|
38 |
hidden_blocks=hidden_blocks,
|
39 |
displacement_emb=displacement_emb,
|
40 |
displacement_emb_dim=128,
|
41 |
+
local_corr_radius = 7,
|
42 |
+
corr_in_other = True,
|
43 |
+
amp = True,
|
44 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
45 |
+
bn_momentum = 0.01,
|
46 |
),
|
47 |
"8": ConvRefiner(
|
48 |
+
2 * 512+64+(2*3+1)**2,
|
49 |
+
2 * 512+64+(2*3+1)**2,
|
50 |
2 + 1,
|
51 |
kernel_size=kernel_size,
|
52 |
dw=dw,
|
53 |
hidden_blocks=hidden_blocks,
|
54 |
displacement_emb=displacement_emb,
|
55 |
displacement_emb_dim=64,
|
56 |
+
local_corr_radius = 3,
|
57 |
+
corr_in_other = True,
|
58 |
+
amp = True,
|
59 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
60 |
+
bn_momentum = 0.01,
|
61 |
),
|
62 |
"4": ConvRefiner(
|
63 |
+
2 * 256+32+(2*2+1)**2,
|
64 |
+
2 * 256+32+(2*2+1)**2,
|
65 |
2 + 1,
|
66 |
kernel_size=kernel_size,
|
67 |
dw=dw,
|
68 |
hidden_blocks=hidden_blocks,
|
69 |
displacement_emb=displacement_emb,
|
70 |
displacement_emb_dim=32,
|
71 |
+
local_corr_radius = 2,
|
72 |
+
corr_in_other = True,
|
73 |
+
amp = True,
|
74 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
75 |
+
bn_momentum = 0.01,
|
76 |
),
|
77 |
"2": ConvRefiner(
|
78 |
+
2 * 64+16,
|
79 |
+
128+16,
|
80 |
2 + 1,
|
81 |
kernel_size=kernel_size,
|
82 |
dw=dw,
|
83 |
hidden_blocks=hidden_blocks,
|
84 |
displacement_emb=displacement_emb,
|
85 |
displacement_emb_dim=16,
|
86 |
+
amp = True,
|
87 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
88 |
+
bn_momentum = 0.01,
|
89 |
),
|
90 |
"1": ConvRefiner(
|
91 |
2 * 9 + 6,
|
|
|
93 |
2 + 1,
|
94 |
kernel_size=kernel_size,
|
95 |
dw=dw,
|
96 |
+
hidden_blocks = hidden_blocks,
|
97 |
+
displacement_emb = displacement_emb,
|
98 |
+
displacement_emb_dim = 6,
|
99 |
+
amp = True,
|
100 |
+
disable_local_corr_grad = disable_local_corr_grad,
|
101 |
+
bn_momentum = 0.01,
|
102 |
),
|
103 |
}
|
104 |
)
|
|
|
123 |
proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
|
124 |
proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
|
125 |
proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
|
126 |
+
proj = nn.ModuleDict({
|
127 |
+
"16": proj16,
|
128 |
+
"8": proj8,
|
129 |
+
"4": proj4,
|
130 |
+
"2": proj2,
|
131 |
+
"1": proj1,
|
132 |
+
})
|
|
|
|
|
133 |
displacement_dropout_p = 0.0
|
134 |
gm_warp_dropout_p = 0.0
|
135 |
+
decoder = Decoder(coordinate_decoder,
|
136 |
+
gps,
|
137 |
+
proj,
|
138 |
+
conv_refiner,
|
139 |
+
detach=True,
|
140 |
+
scales=["16", "8", "4", "2", "1"],
|
141 |
+
displacement_dropout_p = displacement_dropout_p,
|
142 |
+
gm_warp_dropout_p = gm_warp_dropout_p)
|
143 |
+
|
|
|
|
|
144 |
encoder = CNNandDinov2(
|
145 |
+
cnn_kwargs = dict(
|
146 |
+
pretrained=False,
|
147 |
+
amp = True),
|
148 |
+
amp = True,
|
149 |
+
use_vgg = True,
|
150 |
+
dinov2_weights = dinov2_weights,
|
151 |
+
amp_dtype=amp_dtype,
|
152 |
)
|
153 |
+
h,w = resolution
|
154 |
symmetric = True
|
155 |
attenuate_cert = True
|
156 |
+
sample_mode = "threshold_balanced"
|
157 |
+
matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
|
158 |
+
symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
matcher.load_state_dict(weights)
|
160 |
return matcher
|
third_party/{Roma β RoMa}/roma/models/transformer/__init__.py
RENAMED
@@ -7,23 +7,9 @@ from .layers.block import Block
|
|
7 |
from .layers.attention import MemEffAttention
|
8 |
from .dinov2 import vit_large
|
9 |
|
10 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
-
|
12 |
-
|
13 |
class TransformerDecoder(nn.Module):
|
14 |
-
def __init__(
|
15 |
-
|
16 |
-
blocks,
|
17 |
-
hidden_dim,
|
18 |
-
out_dim,
|
19 |
-
is_classifier=False,
|
20 |
-
*args,
|
21 |
-
amp=False,
|
22 |
-
pos_enc=True,
|
23 |
-
learned_embeddings=False,
|
24 |
-
embedding_dim=None,
|
25 |
-
**kwargs
|
26 |
-
) -> None:
|
27 |
super().__init__(*args, **kwargs)
|
28 |
self.blocks = blocks
|
29 |
self.to_out = nn.Linear(hidden_dim, out_dim)
|
@@ -32,48 +18,30 @@ class TransformerDecoder(nn.Module):
|
|
32 |
self._scales = [16]
|
33 |
self.is_classifier = is_classifier
|
34 |
self.amp = amp
|
35 |
-
|
36 |
-
if torch.cuda.is_bf16_supported():
|
37 |
-
self.amp_dtype = torch.bfloat16
|
38 |
-
else:
|
39 |
-
self.amp_dtype = torch.float16
|
40 |
-
else:
|
41 |
-
self.amp_dtype = torch.float32
|
42 |
-
|
43 |
self.pos_enc = pos_enc
|
44 |
self.learned_embeddings = learned_embeddings
|
45 |
if self.learned_embeddings:
|
46 |
-
self.learned_pos_embeddings = nn.Parameter(
|
47 |
-
nn.init.kaiming_normal_(
|
48 |
-
torch.empty((1, hidden_dim, embedding_dim, embedding_dim))
|
49 |
-
)
|
50 |
-
)
|
51 |
|
52 |
def scales(self):
|
53 |
return self._scales.copy()
|
54 |
|
55 |
def forward(self, gp_posterior, features, old_stuff, new_scale):
|
56 |
-
with torch.autocast(
|
57 |
-
B,
|
58 |
-
x = torch.cat((gp_posterior, features), dim=1)
|
59 |
-
B,
|
60 |
-
grid = get_grid(B, H, W, x.device).reshape(B,
|
61 |
if self.learned_embeddings:
|
62 |
-
pos_enc = (
|
63 |
-
F.interpolate(
|
64 |
-
self.learned_pos_embeddings,
|
65 |
-
size=(H, W),
|
66 |
-
mode="bilinear",
|
67 |
-
align_corners=False,
|
68 |
-
)
|
69 |
-
.permute(0, 2, 3, 1)
|
70 |
-
.reshape(1, H * W, C)
|
71 |
-
)
|
72 |
else:
|
73 |
pos_enc = 0
|
74 |
-
tokens = x.reshape(B,
|
75 |
z = self.blocks(tokens)
|
76 |
out = self.to_out(z)
|
77 |
-
out = out.permute(0,
|
78 |
warp, certainty = out[:, :-1], out[:, -1:]
|
79 |
return warp, certainty, None
|
|
|
|
|
|
7 |
from .layers.attention import MemEffAttention
|
8 |
from .dinov2 import vit_large
|
9 |
|
|
|
|
|
|
|
10 |
class TransformerDecoder(nn.Module):
|
11 |
+
def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
|
12 |
+
amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
super().__init__(*args, **kwargs)
|
14 |
self.blocks = blocks
|
15 |
self.to_out = nn.Linear(hidden_dim, out_dim)
|
|
|
18 |
self._scales = [16]
|
19 |
self.is_classifier = is_classifier
|
20 |
self.amp = amp
|
21 |
+
self.amp_dtype = amp_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
self.pos_enc = pos_enc
|
23 |
self.learned_embeddings = learned_embeddings
|
24 |
if self.learned_embeddings:
|
25 |
+
self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def scales(self):
|
28 |
return self._scales.copy()
|
29 |
|
30 |
def forward(self, gp_posterior, features, old_stuff, new_scale):
|
31 |
+
with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
|
32 |
+
B,C,H,W = gp_posterior.shape
|
33 |
+
x = torch.cat((gp_posterior, features), dim = 1)
|
34 |
+
B,C,H,W = x.shape
|
35 |
+
grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
|
36 |
if self.learned_embeddings:
|
37 |
+
pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
else:
|
39 |
pos_enc = 0
|
40 |
+
tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
|
41 |
z = self.blocks(tokens)
|
42 |
out = self.to_out(z)
|
43 |
+
out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
|
44 |
warp, certainty = out[:, :-1], out[:, -1:]
|
45 |
return warp, certainty, None
|
46 |
+
|
47 |
+
|
third_party/{Roma β RoMa}/roma/models/transformer/dinov2.py
RENAMED
@@ -18,29 +18,16 @@ import torch.nn as nn
|
|
18 |
import torch.utils.checkpoint
|
19 |
from torch.nn.init import trunc_normal_
|
20 |
|
21 |
-
from .layers import
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
NestedTensorBlock as Block,
|
27 |
-
)
|
28 |
-
|
29 |
-
|
30 |
-
def named_apply(
|
31 |
-
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
|
32 |
-
) -> nn.Module:
|
33 |
if not depth_first and include_root:
|
34 |
fn(module=module, name=name)
|
35 |
for child_name, child_module in module.named_children():
|
36 |
child_name = ".".join((name, child_name)) if name else child_name
|
37 |
-
named_apply(
|
38 |
-
fn=fn,
|
39 |
-
module=child_module,
|
40 |
-
name=child_name,
|
41 |
-
depth_first=depth_first,
|
42 |
-
include_root=True,
|
43 |
-
)
|
44 |
if depth_first and include_root:
|
45 |
fn(module=module, name=name)
|
46 |
return module
|
@@ -100,33 +87,22 @@ class DinoVisionTransformer(nn.Module):
|
|
100 |
super().__init__()
|
101 |
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
102 |
|
103 |
-
self.num_features =
|
104 |
-
self.embed_dim
|
105 |
-
) = embed_dim # num_features for consistency with other models
|
106 |
self.num_tokens = 1
|
107 |
self.n_blocks = depth
|
108 |
self.num_heads = num_heads
|
109 |
self.patch_size = patch_size
|
110 |
|
111 |
-
self.patch_embed = embed_layer(
|
112 |
-
img_size=img_size,
|
113 |
-
patch_size=patch_size,
|
114 |
-
in_chans=in_chans,
|
115 |
-
embed_dim=embed_dim,
|
116 |
-
)
|
117 |
num_patches = self.patch_embed.num_patches
|
118 |
|
119 |
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
120 |
-
self.pos_embed = nn.Parameter(
|
121 |
-
torch.zeros(1, num_patches + self.num_tokens, embed_dim)
|
122 |
-
)
|
123 |
|
124 |
if drop_path_uniform is True:
|
125 |
dpr = [drop_path_rate] * depth
|
126 |
else:
|
127 |
-
dpr = [
|
128 |
-
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
129 |
-
] # stochastic depth decay rule
|
130 |
|
131 |
if ffn_layer == "mlp":
|
132 |
ffn_layer = Mlp
|
@@ -163,9 +139,7 @@ class DinoVisionTransformer(nn.Module):
|
|
163 |
chunksize = depth // block_chunks
|
164 |
for i in range(0, depth, chunksize):
|
165 |
# this is to keep the block index consistent if we chunk the block list
|
166 |
-
chunked_blocks.append(
|
167 |
-
[nn.Identity()] * i + blocks_list[i : i + chunksize]
|
168 |
-
)
|
169 |
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
170 |
else:
|
171 |
self.chunked_blocks = False
|
@@ -179,7 +153,7 @@ class DinoVisionTransformer(nn.Module):
|
|
179 |
self.init_weights()
|
180 |
for param in self.parameters():
|
181 |
param.requires_grad = False
|
182 |
-
|
183 |
@property
|
184 |
def device(self):
|
185 |
return self.cls_token.device
|
@@ -206,29 +180,20 @@ class DinoVisionTransformer(nn.Module):
|
|
206 |
w0, h0 = w0 + 0.1, h0 + 0.1
|
207 |
|
208 |
patch_pos_embed = nn.functional.interpolate(
|
209 |
-
patch_pos_embed.reshape(
|
210 |
-
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
211 |
-
).permute(0, 3, 1, 2),
|
212 |
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
213 |
mode="bicubic",
|
214 |
)
|
215 |
|
216 |
-
assert (
|
217 |
-
int(w0) == patch_pos_embed.shape[-2]
|
218 |
-
and int(h0) == patch_pos_embed.shape[-1]
|
219 |
-
)
|
220 |
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
221 |
-
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
|
222 |
-
previous_dtype
|
223 |
-
)
|
224 |
|
225 |
def prepare_tokens_with_masks(self, x, masks=None):
|
226 |
B, nc, w, h = x.shape
|
227 |
x = self.patch_embed(x)
|
228 |
if masks is not None:
|
229 |
-
x = torch.where(
|
230 |
-
masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
|
231 |
-
)
|
232 |
|
233 |
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
234 |
x = x + self.interpolate_pos_encoding(x, w, h)
|
@@ -236,10 +201,7 @@ class DinoVisionTransformer(nn.Module):
|
|
236 |
return x
|
237 |
|
238 |
def forward_features_list(self, x_list, masks_list):
|
239 |
-
x = [
|
240 |
-
self.prepare_tokens_with_masks(x, masks)
|
241 |
-
for x, masks in zip(x_list, masks_list)
|
242 |
-
]
|
243 |
for blk in self.blocks:
|
244 |
x = blk(x)
|
245 |
|
@@ -278,34 +240,26 @@ class DinoVisionTransformer(nn.Module):
|
|
278 |
x = self.prepare_tokens_with_masks(x)
|
279 |
# If n is an int, take the n last blocks. If it's a list, take them
|
280 |
output, total_block_len = [], len(self.blocks)
|
281 |
-
blocks_to_take = (
|
282 |
-
range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
283 |
-
)
|
284 |
for i, blk in enumerate(self.blocks):
|
285 |
x = blk(x)
|
286 |
if i in blocks_to_take:
|
287 |
output.append(x)
|
288 |
-
assert len(output) == len(
|
289 |
-
blocks_to_take
|
290 |
-
), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
291 |
return output
|
292 |
|
293 |
def _get_intermediate_layers_chunked(self, x, n=1):
|
294 |
x = self.prepare_tokens_with_masks(x)
|
295 |
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
296 |
# If n is an int, take the n last blocks. If it's a list, take them
|
297 |
-
blocks_to_take = (
|
298 |
-
range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
299 |
-
)
|
300 |
for block_chunk in self.blocks:
|
301 |
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
302 |
x = blk(x)
|
303 |
if i in blocks_to_take:
|
304 |
output.append(x)
|
305 |
i += 1
|
306 |
-
assert len(output) == len(
|
307 |
-
blocks_to_take
|
308 |
-
), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
309 |
return output
|
310 |
|
311 |
def get_intermediate_layers(
|
@@ -327,9 +281,7 @@ class DinoVisionTransformer(nn.Module):
|
|
327 |
if reshape:
|
328 |
B, _, w, h = x.shape
|
329 |
outputs = [
|
330 |
-
out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
|
331 |
-
.permute(0, 3, 1, 2)
|
332 |
-
.contiguous()
|
333 |
for out in outputs
|
334 |
]
|
335 |
if return_class_token:
|
@@ -404,4 +356,4 @@ def vit_giant2(patch_size=16, **kwargs):
|
|
404 |
block_fn=partial(Block, attn_class=MemEffAttention),
|
405 |
**kwargs,
|
406 |
)
|
407 |
-
return model
|
|
|
18 |
import torch.utils.checkpoint
|
19 |
from torch.nn.init import trunc_normal_
|
20 |
|
21 |
+
from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
if not depth_first and include_root:
|
27 |
fn(module=module, name=name)
|
28 |
for child_name, child_module in module.named_children():
|
29 |
child_name = ".".join((name, child_name)) if name else child_name
|
30 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
if depth_first and include_root:
|
32 |
fn(module=module, name=name)
|
33 |
return module
|
|
|
87 |
super().__init__()
|
88 |
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
89 |
|
90 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
|
|
|
|
91 |
self.num_tokens = 1
|
92 |
self.n_blocks = depth
|
93 |
self.num_heads = num_heads
|
94 |
self.patch_size = patch_size
|
95 |
|
96 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
|
|
|
|
|
|
|
|
|
|
97 |
num_patches = self.patch_embed.num_patches
|
98 |
|
99 |
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
100 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
|
|
|
|
101 |
|
102 |
if drop_path_uniform is True:
|
103 |
dpr = [drop_path_rate] * depth
|
104 |
else:
|
105 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
|
106 |
|
107 |
if ffn_layer == "mlp":
|
108 |
ffn_layer = Mlp
|
|
|
139 |
chunksize = depth // block_chunks
|
140 |
for i in range(0, depth, chunksize):
|
141 |
# this is to keep the block index consistent if we chunk the block list
|
142 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
|
|
|
|
143 |
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
144 |
else:
|
145 |
self.chunked_blocks = False
|
|
|
153 |
self.init_weights()
|
154 |
for param in self.parameters():
|
155 |
param.requires_grad = False
|
156 |
+
|
157 |
@property
|
158 |
def device(self):
|
159 |
return self.cls_token.device
|
|
|
180 |
w0, h0 = w0 + 0.1, h0 + 0.1
|
181 |
|
182 |
patch_pos_embed = nn.functional.interpolate(
|
183 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
|
|
|
|
184 |
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
185 |
mode="bicubic",
|
186 |
)
|
187 |
|
188 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
|
|
|
|
|
|
189 |
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
190 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
|
|
|
|
191 |
|
192 |
def prepare_tokens_with_masks(self, x, masks=None):
|
193 |
B, nc, w, h = x.shape
|
194 |
x = self.patch_embed(x)
|
195 |
if masks is not None:
|
196 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
|
|
|
|
197 |
|
198 |
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
199 |
x = x + self.interpolate_pos_encoding(x, w, h)
|
|
|
201 |
return x
|
202 |
|
203 |
def forward_features_list(self, x_list, masks_list):
|
204 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
|
|
|
|
|
|
205 |
for blk in self.blocks:
|
206 |
x = blk(x)
|
207 |
|
|
|
240 |
x = self.prepare_tokens_with_masks(x)
|
241 |
# If n is an int, take the n last blocks. If it's a list, take them
|
242 |
output, total_block_len = [], len(self.blocks)
|
243 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
|
|
|
|
244 |
for i, blk in enumerate(self.blocks):
|
245 |
x = blk(x)
|
246 |
if i in blocks_to_take:
|
247 |
output.append(x)
|
248 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
|
|
|
|
249 |
return output
|
250 |
|
251 |
def _get_intermediate_layers_chunked(self, x, n=1):
|
252 |
x = self.prepare_tokens_with_masks(x)
|
253 |
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
254 |
# If n is an int, take the n last blocks. If it's a list, take them
|
255 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
|
|
|
|
256 |
for block_chunk in self.blocks:
|
257 |
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
258 |
x = blk(x)
|
259 |
if i in blocks_to_take:
|
260 |
output.append(x)
|
261 |
i += 1
|
262 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
|
|
|
|
263 |
return output
|
264 |
|
265 |
def get_intermediate_layers(
|
|
|
281 |
if reshape:
|
282 |
B, _, w, h = x.shape
|
283 |
outputs = [
|
284 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
285 |
for out in outputs
|
286 |
]
|
287 |
if return_class_token:
|
|
|
356 |
block_fn=partial(Block, attn_class=MemEffAttention),
|
357 |
**kwargs,
|
358 |
)
|
359 |
+
return model
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/__init__.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/attention.py
RENAMED
@@ -48,11 +48,7 @@ class Attention(nn.Module):
|
|
48 |
|
49 |
def forward(self, x: Tensor) -> Tensor:
|
50 |
B, N, C = x.shape
|
51 |
-
qkv = (
|
52 |
-
self.qkv(x)
|
53 |
-
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
54 |
-
.permute(2, 0, 3, 1, 4)
|
55 |
-
)
|
56 |
|
57 |
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
58 |
attn = q @ k.transpose(-2, -1)
|
|
|
48 |
|
49 |
def forward(self, x: Tensor) -> Tensor:
|
50 |
B, N, C = x.shape
|
51 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
|
|
|
|
|
52 |
|
53 |
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
54 |
attn = q @ k.transpose(-2, -1)
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/block.py
RENAMED
@@ -62,9 +62,7 @@ class Block(nn.Module):
|
|
62 |
attn_drop=attn_drop,
|
63 |
proj_drop=drop,
|
64 |
)
|
65 |
-
self.ls1 = (
|
66 |
-
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
67 |
-
)
|
68 |
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
69 |
|
70 |
self.norm2 = norm_layer(dim)
|
@@ -76,9 +74,7 @@ class Block(nn.Module):
|
|
76 |
drop=drop,
|
77 |
bias=ffn_bias,
|
78 |
)
|
79 |
-
self.ls2 = (
|
80 |
-
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
81 |
-
)
|
82 |
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
83 |
|
84 |
self.sample_drop_ratio = drop_path
|
@@ -131,9 +127,7 @@ def drop_add_residual_stochastic_depth(
|
|
131 |
residual_scale_factor = b / sample_subset_size
|
132 |
|
133 |
# 3) add the residual
|
134 |
-
x_plus_residual = torch.index_add(
|
135 |
-
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
136 |
-
)
|
137 |
return x_plus_residual.view_as(x)
|
138 |
|
139 |
|
@@ -149,16 +143,10 @@ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None
|
|
149 |
if scaling_vector is None:
|
150 |
x_flat = x.flatten(1)
|
151 |
residual = residual.flatten(1)
|
152 |
-
x_plus_residual = torch.index_add(
|
153 |
-
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
154 |
-
)
|
155 |
else:
|
156 |
x_plus_residual = scaled_index_add(
|
157 |
-
x,
|
158 |
-
brange,
|
159 |
-
residual.to(dtype=x.dtype),
|
160 |
-
scaling=scaling_vector,
|
161 |
-
alpha=residual_scale_factor,
|
162 |
)
|
163 |
return x_plus_residual
|
164 |
|
@@ -170,11 +158,7 @@ def get_attn_bias_and_cat(x_list, branges=None):
|
|
170 |
"""
|
171 |
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
172 |
"""
|
173 |
-
batch_sizes =
|
174 |
-
[b.shape[0] for b in branges]
|
175 |
-
if branges is not None
|
176 |
-
else [x.shape[0] for x in x_list]
|
177 |
-
)
|
178 |
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
179 |
if all_shapes not in attn_bias_cache.keys():
|
180 |
seqlens = []
|
@@ -186,9 +170,7 @@ def get_attn_bias_and_cat(x_list, branges=None):
|
|
186 |
attn_bias_cache[all_shapes] = attn_bias
|
187 |
|
188 |
if branges is not None:
|
189 |
-
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
|
190 |
-
1, -1, x_list[0].shape[-1]
|
191 |
-
)
|
192 |
else:
|
193 |
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
194 |
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
@@ -203,9 +185,7 @@ def drop_add_residual_stochastic_depth_list(
|
|
203 |
scaling_vector=None,
|
204 |
) -> Tensor:
|
205 |
# 1) generate random set of indices for dropping samples in the batch
|
206 |
-
branges_scales = [
|
207 |
-
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
|
208 |
-
]
|
209 |
branges = [s[0] for s in branges_scales]
|
210 |
residual_scale_factors = [s[1] for s in branges_scales]
|
211 |
|
@@ -216,14 +196,8 @@ def drop_add_residual_stochastic_depth_list(
|
|
216 |
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
217 |
|
218 |
outputs = []
|
219 |
-
for x, brange, residual, residual_scale_factor in zip(
|
220 |
-
|
221 |
-
):
|
222 |
-
outputs.append(
|
223 |
-
add_residual(
|
224 |
-
x, brange, residual, residual_scale_factor, scaling_vector
|
225 |
-
).view_as(x)
|
226 |
-
)
|
227 |
return outputs
|
228 |
|
229 |
|
@@ -246,17 +220,13 @@ class NestedTensorBlock(Block):
|
|
246 |
x_list,
|
247 |
residual_func=attn_residual_func,
|
248 |
sample_drop_ratio=self.sample_drop_ratio,
|
249 |
-
scaling_vector=self.ls1.gamma
|
250 |
-
if isinstance(self.ls1, LayerScale)
|
251 |
-
else None,
|
252 |
)
|
253 |
x_list = drop_add_residual_stochastic_depth_list(
|
254 |
x_list,
|
255 |
residual_func=ffn_residual_func,
|
256 |
sample_drop_ratio=self.sample_drop_ratio,
|
257 |
-
scaling_vector=self.ls2.gamma
|
258 |
-
if isinstance(self.ls1, LayerScale)
|
259 |
-
else None,
|
260 |
)
|
261 |
return x_list
|
262 |
else:
|
@@ -276,9 +246,7 @@ class NestedTensorBlock(Block):
|
|
276 |
if isinstance(x_or_x_list, Tensor):
|
277 |
return super().forward(x_or_x_list)
|
278 |
elif isinstance(x_or_x_list, list):
|
279 |
-
assert
|
280 |
-
XFORMERS_AVAILABLE
|
281 |
-
), "Please install xFormers for nested tensors usage"
|
282 |
return self.forward_nested(x_or_x_list)
|
283 |
else:
|
284 |
raise AssertionError
|
|
|
62 |
attn_drop=attn_drop,
|
63 |
proj_drop=drop,
|
64 |
)
|
65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
|
|
|
66 |
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
|
68 |
self.norm2 = norm_layer(dim)
|
|
|
74 |
drop=drop,
|
75 |
bias=ffn_bias,
|
76 |
)
|
77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
|
|
|
|
78 |
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
79 |
|
80 |
self.sample_drop_ratio = drop_path
|
|
|
127 |
residual_scale_factor = b / sample_subset_size
|
128 |
|
129 |
# 3) add the residual
|
130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
|
|
|
|
131 |
return x_plus_residual.view_as(x)
|
132 |
|
133 |
|
|
|
143 |
if scaling_vector is None:
|
144 |
x_flat = x.flatten(1)
|
145 |
residual = residual.flatten(1)
|
146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
|
|
|
|
147 |
else:
|
148 |
x_plus_residual = scaled_index_add(
|
149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
|
|
|
|
|
|
|
|
150 |
)
|
151 |
return x_plus_residual
|
152 |
|
|
|
158 |
"""
|
159 |
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
160 |
"""
|
161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
|
|
|
|
|
|
|
|
162 |
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
163 |
if all_shapes not in attn_bias_cache.keys():
|
164 |
seqlens = []
|
|
|
170 |
attn_bias_cache[all_shapes] = attn_bias
|
171 |
|
172 |
if branges is not None:
|
173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
|
|
|
|
174 |
else:
|
175 |
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
176 |
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
|
|
185 |
scaling_vector=None,
|
186 |
) -> Tensor:
|
187 |
# 1) generate random set of indices for dropping samples in the batch
|
188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
|
|
|
|
189 |
branges = [s[0] for s in branges_scales]
|
190 |
residual_scale_factors = [s[1] for s in branges_scales]
|
191 |
|
|
|
196 |
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
197 |
|
198 |
outputs = []
|
199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
return outputs
|
202 |
|
203 |
|
|
|
220 |
x_list,
|
221 |
residual_func=attn_residual_func,
|
222 |
sample_drop_ratio=self.sample_drop_ratio,
|
223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
|
|
|
|
224 |
)
|
225 |
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
x_list,
|
227 |
residual_func=ffn_residual_func,
|
228 |
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
|
|
|
|
230 |
)
|
231 |
return x_list
|
232 |
else:
|
|
|
246 |
if isinstance(x_or_x_list, Tensor):
|
247 |
return super().forward(x_or_x_list)
|
248 |
elif isinstance(x_or_x_list, list):
|
249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
|
|
|
|
250 |
return self.forward_nested(x_or_x_list)
|
251 |
else:
|
252 |
raise AssertionError
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/dino_head.py
RENAMED
@@ -23,14 +23,7 @@ class DINOHead(nn.Module):
|
|
23 |
):
|
24 |
super().__init__()
|
25 |
nlayers = max(nlayers, 1)
|
26 |
-
self.mlp = _build_mlp(
|
27 |
-
nlayers,
|
28 |
-
in_dim,
|
29 |
-
bottleneck_dim,
|
30 |
-
hidden_dim=hidden_dim,
|
31 |
-
use_bn=use_bn,
|
32 |
-
bias=mlp_bias,
|
33 |
-
)
|
34 |
self.apply(self._init_weights)
|
35 |
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
36 |
self.last_layer.weight_g.data.fill_(1)
|
@@ -49,9 +42,7 @@ class DINOHead(nn.Module):
|
|
49 |
return x
|
50 |
|
51 |
|
52 |
-
def _build_mlp(
|
53 |
-
nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
|
54 |
-
):
|
55 |
if nlayers == 1:
|
56 |
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
57 |
else:
|
|
|
23 |
):
|
24 |
super().__init__()
|
25 |
nlayers = max(nlayers, 1)
|
26 |
+
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
self.apply(self._init_weights)
|
28 |
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
29 |
self.last_layer.weight_g.data.fill_(1)
|
|
|
42 |
return x
|
43 |
|
44 |
|
45 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
|
|
|
|
46 |
if nlayers == 1:
|
47 |
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
48 |
else:
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/drop_path.py
RENAMED
@@ -16,9 +16,7 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
|
16 |
if drop_prob == 0.0 or not training:
|
17 |
return x
|
18 |
keep_prob = 1 - drop_prob
|
19 |
-
shape = (x.shape[0],) + (1,) * (
|
20 |
-
x.ndim - 1
|
21 |
-
) # work with diff dim tensors, not just 2D ConvNets
|
22 |
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
23 |
if keep_prob > 0.0:
|
24 |
random_tensor.div_(keep_prob)
|
|
|
16 |
if drop_prob == 0.0 or not training:
|
17 |
return x
|
18 |
keep_prob = 1 - drop_prob
|
19 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
|
|
|
|
20 |
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
21 |
if keep_prob > 0.0:
|
22 |
random_tensor.div_(keep_prob)
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/layer_scale.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/mlp.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/patch_embed.py
RENAMED
@@ -63,21 +63,15 @@ class PatchEmbed(nn.Module):
|
|
63 |
|
64 |
self.flatten_embedding = flatten_embedding
|
65 |
|
66 |
-
self.proj = nn.Conv2d(
|
67 |
-
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
|
68 |
-
)
|
69 |
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
70 |
|
71 |
def forward(self, x: Tensor) -> Tensor:
|
72 |
_, _, H, W = x.shape
|
73 |
patch_H, patch_W = self.patch_size
|
74 |
|
75 |
-
assert
|
76 |
-
|
77 |
-
), f"Input image height {H} is not a multiple of patch height {patch_H}"
|
78 |
-
assert (
|
79 |
-
W % patch_W == 0
|
80 |
-
), f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
81 |
|
82 |
x = self.proj(x) # B C H W
|
83 |
H, W = x.size(2), x.size(3)
|
@@ -89,13 +83,7 @@ class PatchEmbed(nn.Module):
|
|
89 |
|
90 |
def flops(self) -> float:
|
91 |
Ho, Wo = self.patches_resolution
|
92 |
-
flops = (
|
93 |
-
Ho
|
94 |
-
* Wo
|
95 |
-
* self.embed_dim
|
96 |
-
* self.in_chans
|
97 |
-
* (self.patch_size[0] * self.patch_size[1])
|
98 |
-
)
|
99 |
if self.norm is not None:
|
100 |
flops += Ho * Wo * self.embed_dim
|
101 |
return flops
|
|
|
63 |
|
64 |
self.flatten_embedding = flatten_embedding
|
65 |
|
66 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
|
|
|
|
67 |
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
68 |
|
69 |
def forward(self, x: Tensor) -> Tensor:
|
70 |
_, _, H, W = x.shape
|
71 |
patch_H, patch_W = self.patch_size
|
72 |
|
73 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
74 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
|
|
|
|
|
|
|
|
75 |
|
76 |
x = self.proj(x) # B C H W
|
77 |
H, W = x.size(2), x.size(3)
|
|
|
83 |
|
84 |
def flops(self) -> float:
|
85 |
Ho, Wo = self.patches_resolution
|
86 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
if self.norm is not None:
|
88 |
flops += Ho * Wo * self.embed_dim
|
89 |
return flops
|
third_party/{Roma β RoMa}/roma/models/transformer/layers/swiglu_ffn.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/train/__init__.py
RENAMED
File without changes
|
third_party/{Roma β RoMa}/roma/train/train.py
RENAMED
@@ -4,62 +4,41 @@ import roma
|
|
4 |
import torch
|
5 |
import wandb
|
6 |
|
7 |
-
|
8 |
-
def log_param_statistics(named_parameters, norm_type=2):
|
9 |
named_parameters = list(named_parameters)
|
10 |
grads = [p.grad for n, p in named_parameters if p.grad is not None]
|
11 |
-
weight_norms = [
|
12 |
-
|
13 |
-
]
|
14 |
-
names = [n for n, p in named_parameters if p.grad is not None]
|
15 |
param_norm = torch.stack(weight_norms).norm(p=norm_type)
|
16 |
device = grads[0].device
|
17 |
-
grad_norms = torch.stack(
|
18 |
-
[torch.norm(g.detach(), norm_type).to(device) for g in grads]
|
19 |
-
)
|
20 |
nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
|
21 |
nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
|
22 |
total_grad_norm = torch.norm(grad_norms, norm_type)
|
23 |
if torch.any(nans_or_infs):
|
24 |
print(f"These params have nan or inf grads: {nan_inf_names}")
|
25 |
-
wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
|
26 |
-
wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)
|
27 |
-
|
28 |
|
29 |
-
def train_step(
|
30 |
-
train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
|
31 |
-
):
|
32 |
optimizer.zero_grad()
|
33 |
out = model(train_batch)
|
34 |
l = objective(out, train_batch)
|
35 |
grad_scaler.scale(l).backward()
|
36 |
grad_scaler.unscale_(optimizer)
|
37 |
log_param_statistics(model.named_parameters())
|
38 |
-
torch.nn.utils.clip_grad_norm_(
|
39 |
-
model.parameters(), grad_clip_norm
|
40 |
-
) # what should max norm be?
|
41 |
grad_scaler.step(optimizer)
|
42 |
grad_scaler.update()
|
43 |
-
wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
|
44 |
-
if grad_scaler._scale < 1
|
45 |
-
grad_scaler._scale = torch.tensor(1.
|
46 |
-
roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE
|
47 |
return {"train_out": out, "train_loss": l.item()}
|
48 |
|
49 |
|
50 |
def train_k_steps(
|
51 |
-
n_0,
|
52 |
-
k,
|
53 |
-
dataloader,
|
54 |
-
model,
|
55 |
-
objective,
|
56 |
-
optimizer,
|
57 |
-
lr_scheduler,
|
58 |
-
grad_scaler,
|
59 |
-
progress_bar=True,
|
60 |
-
grad_clip_norm=1.0,
|
61 |
-
warmup=None,
|
62 |
-
ema_model=None,
|
63 |
):
|
64 |
for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
|
65 |
batch = next(dataloader)
|
@@ -73,7 +52,7 @@ def train_k_steps(
|
|
73 |
lr_scheduler=lr_scheduler,
|
74 |
grad_scaler=grad_scaler,
|
75 |
n=n,
|
76 |
-
grad_clip_norm=grad_clip_norm,
|
77 |
)
|
78 |
if ema_model is not None:
|
79 |
ema_model.update()
|
@@ -82,10 +61,7 @@ def train_k_steps(
|
|
82 |
lr_scheduler.step()
|
83 |
else:
|
84 |
lr_scheduler.step()
|
85 |
-
[
|
86 |
-
wandb.log({f"lr_group_{grp}": lr})
|
87 |
-
for grp, lr in enumerate(lr_scheduler.get_last_lr())
|
88 |
-
]
|
89 |
|
90 |
|
91 |
def train_epoch(
|
|
|
4 |
import torch
|
5 |
import wandb
|
6 |
|
7 |
+
def log_param_statistics(named_parameters, norm_type = 2):
|
|
|
8 |
named_parameters = list(named_parameters)
|
9 |
grads = [p.grad for n, p in named_parameters if p.grad is not None]
|
10 |
+
weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
|
11 |
+
names = [n for n,p in named_parameters if p.grad is not None]
|
|
|
|
|
12 |
param_norm = torch.stack(weight_norms).norm(p=norm_type)
|
13 |
device = grads[0].device
|
14 |
+
grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
|
|
|
|
|
15 |
nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
|
16 |
nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
|
17 |
total_grad_norm = torch.norm(grad_norms, norm_type)
|
18 |
if torch.any(nans_or_infs):
|
19 |
print(f"These params have nan or inf grads: {nan_inf_names}")
|
20 |
+
wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP)
|
21 |
+
wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP)
|
|
|
22 |
|
23 |
+
def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
|
|
|
|
|
24 |
optimizer.zero_grad()
|
25 |
out = model(train_batch)
|
26 |
l = objective(out, train_batch)
|
27 |
grad_scaler.scale(l).backward()
|
28 |
grad_scaler.unscale_(optimizer)
|
29 |
log_param_statistics(model.named_parameters())
|
30 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
|
|
|
|
|
31 |
grad_scaler.step(optimizer)
|
32 |
grad_scaler.update()
|
33 |
+
wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP)
|
34 |
+
if grad_scaler._scale < 1.:
|
35 |
+
grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
|
36 |
+
roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step
|
37 |
return {"train_out": out, "train_loss": l.item()}
|
38 |
|
39 |
|
40 |
def train_k_steps(
|
41 |
+
n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
):
|
43 |
for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
|
44 |
batch = next(dataloader)
|
|
|
52 |
lr_scheduler=lr_scheduler,
|
53 |
grad_scaler=grad_scaler,
|
54 |
n=n,
|
55 |
+
grad_clip_norm = grad_clip_norm,
|
56 |
)
|
57 |
if ema_model is not None:
|
58 |
ema_model.update()
|
|
|
61 |
lr_scheduler.step()
|
62 |
else:
|
63 |
lr_scheduler.step()
|
64 |
+
[wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
|
|
|
|
|
|
|
65 |
|
66 |
|
67 |
def train_epoch(
|
third_party/{Roma β RoMa}/roma/utils/__init__.py
RENAMED
File without changes
|