go into cdistMatcher even on cpu devices
Browse files- mast3r/fast_nn.py +3 -1
mast3r/fast_nn.py
CHANGED
@@ -132,7 +132,9 @@ def fast_reciprocal_NNs(pts1, pts2, subsample_or_initxy1=8, ret_xy=True, pixel_t
|
|
132 |
old_xy1 = xy1.copy()
|
133 |
old_xy2 = xy2.copy()
|
134 |
|
135 |
-
if
|
|
|
|
|
136 |
pts1 = pts1.to(device)
|
137 |
pts2 = pts2.to(device)
|
138 |
tree1 = cdistMatcher(pts1, device=device)
|
|
|
132 |
old_xy1 = xy1.copy()
|
133 |
old_xy2 = xy2.copy()
|
134 |
|
135 |
+
if 'dist' in matcher_kw or 'block_size' in matcher_kw \
|
136 |
+
or (isinstance(device, str) and device.startswith('cuda')) \
|
137 |
+
or (isinstance(device, torch.device) and device.type.startswith('cuda')):
|
138 |
pts1 = pts1.to(device)
|
139 |
pts2 = pts2.to(device)
|
140 |
tree1 = cdistMatcher(pts1, device=device)
|