Spaces:
Running
Running
Realcat
committed on
Commit
·
aebdae7
1
Parent(s):
7dc6568
add: ray dashboard port and serve port
Browse files
- README.md +1 -1
- api/__init__.py +42 -0
- api/client.py +4 -11
- api/config/api.yaml +51 -0
- api/server.py +116 -145
- api/types.py +0 -16
- requirements.txt +3 -0
README.md
CHANGED
@@ -107,7 +107,7 @@ docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py
|
|
107 |
|
108 |
### Run demo
|
109 |
``` bash
|
110 |
-
python3
|
111 |
```
|
112 |
then open http://localhost:7860 in your browser.
|
113 |
|
|
|
107 |
|
108 |
### Run demo
|
109 |
``` bash
|
110 |
+
python3 app.py
|
111 |
```
|
112 |
then open http://localhost:7860 in your browser.
|
113 |
|
api/__init__.py
CHANGED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from typing import List
|
3 |
+
from pydantic import BaseModel
|
4 |
+
import base64
|
5 |
+
import io
|
6 |
+
import numpy as np
|
7 |
+
from fastapi.exceptions import HTTPException
|
8 |
+
from PIL import Image
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
sys.path.append(str(Path(__file__).parents[1]))
|
12 |
+
from hloc import logger
|
13 |
+
|
14 |
+
|
15 |
+
class ImagesInput(BaseModel):
|
16 |
+
data: List[str] = []
|
17 |
+
max_keypoints: List[int] = []
|
18 |
+
timestamps: List[str] = []
|
19 |
+
grayscale: bool = False
|
20 |
+
image_hw: List[List[int]] = [[], []]
|
21 |
+
feature_type: int = 0
|
22 |
+
rotates: List[float] = []
|
23 |
+
scales: List[float] = []
|
24 |
+
reference_points: List[List[float]] = []
|
25 |
+
binarize: bool = False
|
26 |
+
|
27 |
+
|
28 |
+
def decode_base64_to_image(encoding):
|
29 |
+
if encoding.startswith("data:image/"):
|
30 |
+
encoding = encoding.split(";")[1].split(",")[1]
|
31 |
+
try:
|
32 |
+
image = Image.open(io.BytesIO(base64.b64decode(encoding)))
|
33 |
+
return image
|
34 |
+
except Exception as e:
|
35 |
+
logger.warning(f"API cannot decode image: {e}")
|
36 |
+
raise HTTPException(
|
37 |
+
status_code=500, detail="Invalid encoded image"
|
38 |
+
) from e
|
39 |
+
|
40 |
+
|
41 |
+
def to_base64_nparray(encoding: str) -> np.ndarray:
|
42 |
+
return np.array(decode_base64_to_image(encoding)).astype("uint8")
|
api/client.py
CHANGED
@@ -9,7 +9,7 @@ import cv2
|
|
9 |
import numpy as np
|
10 |
import requests
|
11 |
|
12 |
-
ENDPOINT = "http://127.0.0.1:
|
13 |
if "REMOTE_URL_RAILWAY" in os.environ:
|
14 |
ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
|
15 |
|
@@ -23,10 +23,8 @@ API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
|
|
23 |
def read_image(path: str) -> str:
|
24 |
"""
|
25 |
Read an image from a file, encode it as a JPEG and then as a base64 string.
|
26 |
-
|
27 |
Args:
|
28 |
path (str): The path to the image to read.
|
29 |
-
|
30 |
Returns:
|
31 |
str: The base64 encoded image.
|
32 |
"""
|
@@ -45,12 +43,10 @@ def read_image(path: str) -> str:
|
|
45 |
def do_api_requests(url=API_URL_EXTRACT, **kwargs):
|
46 |
"""
|
47 |
Helper function to send an API request to the image matching service.
|
48 |
-
|
49 |
Args:
|
50 |
url (str): The URL of the API endpoint to use. Defaults to the
|
51 |
feature extraction endpoint.
|
52 |
**kwargs: Additional keyword arguments to pass to the API.
|
53 |
-
|
54 |
Returns:
|
55 |
List[Dict[str, np.ndarray]]: A list of dictionaries containing the
|
56 |
extracted features. The keys are "keypoints", "descriptors", and
|
@@ -99,11 +95,9 @@ def do_api_requests(url=API_URL_EXTRACT, **kwargs):
|
|
99 |
def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
|
100 |
"""
|
101 |
Send a request to the API to generate a match between two images.
|
102 |
-
|
103 |
Args:
|
104 |
path0 (str): The path to the first image.
|
105 |
path1 (str): The path to the second image.
|
106 |
-
|
107 |
Returns:
|
108 |
Dict[str, np.ndarray]: A dictionary containing the generated matches.
|
109 |
The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
|
@@ -134,10 +128,8 @@ def send_request_extract(
|
|
134 |
) -> List[Dict[str, np.ndarray]]:
|
135 |
"""
|
136 |
Send a request to the API to extract features from an image.
|
137 |
-
|
138 |
Args:
|
139 |
input_images (str): The path to the image.
|
140 |
-
|
141 |
Returns:
|
142 |
List[Dict[str, np.ndarray]]: A list of dictionaries containing the
|
143 |
extracted features. The keys are "keypoints", "descriptors", and
|
@@ -152,7 +144,8 @@ def send_request_extract(
|
|
152 |
url=API_URL_EXTRACT,
|
153 |
**inputs,
|
154 |
)
|
155 |
-
|
|
|
156 |
|
157 |
# draw matching, debug only
|
158 |
if viz:
|
@@ -214,7 +207,7 @@ if __name__ == "__main__":
|
|
214 |
# )
|
215 |
|
216 |
# request extract
|
217 |
-
for i in range(
|
218 |
t1 = time.time()
|
219 |
preds = send_request_extract(args.image0)
|
220 |
t2 = time.time()
|
|
|
9 |
import numpy as np
|
10 |
import requests
|
11 |
|
12 |
+
ENDPOINT = "http://127.0.0.1:8000"
|
13 |
if "REMOTE_URL_RAILWAY" in os.environ:
|
14 |
ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
|
15 |
|
|
|
23 |
def read_image(path: str) -> str:
|
24 |
"""
|
25 |
Read an image from a file, encode it as a JPEG and then as a base64 string.
|
|
|
26 |
Args:
|
27 |
path (str): The path to the image to read.
|
|
|
28 |
Returns:
|
29 |
str: The base64 encoded image.
|
30 |
"""
|
|
|
43 |
def do_api_requests(url=API_URL_EXTRACT, **kwargs):
|
44 |
"""
|
45 |
Helper function to send an API request to the image matching service.
|
|
|
46 |
Args:
|
47 |
url (str): The URL of the API endpoint to use. Defaults to the
|
48 |
feature extraction endpoint.
|
49 |
**kwargs: Additional keyword arguments to pass to the API.
|
|
|
50 |
Returns:
|
51 |
List[Dict[str, np.ndarray]]: A list of dictionaries containing the
|
52 |
extracted features. The keys are "keypoints", "descriptors", and
|
|
|
95 |
def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
|
96 |
"""
|
97 |
Send a request to the API to generate a match between two images.
|
|
|
98 |
Args:
|
99 |
path0 (str): The path to the first image.
|
100 |
path1 (str): The path to the second image.
|
|
|
101 |
Returns:
|
102 |
Dict[str, np.ndarray]: A dictionary containing the generated matches.
|
103 |
The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
|
|
|
128 |
) -> List[Dict[str, np.ndarray]]:
|
129 |
"""
|
130 |
Send a request to the API to extract features from an image.
|
|
|
131 |
Args:
|
132 |
input_images (str): The path to the image.
|
|
|
133 |
Returns:
|
134 |
List[Dict[str, np.ndarray]]: A list of dictionaries containing the
|
135 |
extracted features. The keys are "keypoints", "descriptors", and
|
|
|
144 |
url=API_URL_EXTRACT,
|
145 |
**inputs,
|
146 |
)
|
147 |
+
# breakpoint()
|
148 |
+
# print("Keypoints detected: {}".format(len(response[0]["keypoints"])))
|
149 |
|
150 |
# draw matching, debug only
|
151 |
if viz:
|
|
|
207 |
# )
|
208 |
|
209 |
# request extract
|
210 |
+
for i in range(1000):
|
211 |
t1 = time.time()
|
212 |
preds = send_request_extract(args.image0)
|
213 |
t2 = time.time()
|
api/config/api.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file was generated using the `serve build` command on Ray v2.38.0.
|
2 |
+
|
3 |
+
proxy_location: EveryNode
|
4 |
+
http_options:
|
5 |
+
host: 0.0.0.0
|
6 |
+
port: 8000
|
7 |
+
|
8 |
+
grpc_options:
|
9 |
+
port: 9000
|
10 |
+
grpc_servicer_functions: []
|
11 |
+
|
12 |
+
logging_config:
|
13 |
+
encoding: TEXT
|
14 |
+
log_level: INFO
|
15 |
+
logs_dir: null
|
16 |
+
enable_access_log: true
|
17 |
+
|
18 |
+
applications:
|
19 |
+
- name: app1
|
20 |
+
route_prefix: /
|
21 |
+
import_path: api.server:service
|
22 |
+
runtime_env: {}
|
23 |
+
deployments:
|
24 |
+
- name: ImageMatchingService
|
25 |
+
num_replicas: 4
|
26 |
+
ray_actor_options:
|
27 |
+
num_cpus: 2.0
|
28 |
+
num_gpus: 1.0
|
29 |
+
|
30 |
+
api:
|
31 |
+
feature:
|
32 |
+
output: feats-superpoint-n4096-rmax1600
|
33 |
+
model:
|
34 |
+
name: superpoint
|
35 |
+
nms_radius: 3
|
36 |
+
max_keypoints: 4096
|
37 |
+
keypoint_threshold: 0.005
|
38 |
+
preprocessing:
|
39 |
+
grayscale: True
|
40 |
+
force_resize: True
|
41 |
+
resize_max: 1600
|
42 |
+
width: 640
|
43 |
+
height: 480
|
44 |
+
dfactor: 8
|
45 |
+
matcher:
|
46 |
+
output: matches-NN-mutual
|
47 |
+
model:
|
48 |
+
name: nearest_neighbor
|
49 |
+
do_mutual_check: True
|
50 |
+
match_threshold: 0.2
|
51 |
+
dense: False
|
api/server.py
CHANGED
@@ -1,24 +1,21 @@
|
|
1 |
# server.py
|
2 |
-
import base64
|
3 |
-
import io
|
4 |
-
import sys
|
5 |
import warnings
|
6 |
from pathlib import Path
|
7 |
from typing import Any, Dict, Optional, Union
|
|
|
|
|
|
|
|
|
8 |
|
9 |
import cv2
|
10 |
import matplotlib.pyplot as plt
|
11 |
import numpy as np
|
12 |
import torch
|
13 |
-
import uvicorn
|
14 |
from fastapi import FastAPI, File, UploadFile
|
15 |
-
from fastapi.exceptions import HTTPException
|
16 |
from fastapi.responses import JSONResponse
|
17 |
from PIL import Image
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
from api.types import ImagesInput
|
22 |
from hloc import DEVICE, extract_features, logger, match_dense, match_features
|
23 |
from hloc.utils.viz import add_text, plot_keypoints
|
24 |
from ui import get_version
|
@@ -26,23 +23,16 @@ from ui.utils import filter_matches, get_feature_model, get_model
|
|
26 |
from ui.viz import display_matches, fig2im, plot_images
|
27 |
|
28 |
warnings.simplefilter("ignore")
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
raise HTTPException(
|
40 |
-
status_code=500, detail="Invalid encoded image"
|
41 |
-
) from e
|
42 |
-
|
43 |
-
|
44 |
-
def to_base64_nparray(encoding: str) -> np.ndarray:
|
45 |
-
return np.array(decode_base64_to_image(encoding)).astype("uint8")
|
46 |
|
47 |
|
48 |
class ImageMatchingAPI(torch.nn.Module):
|
@@ -68,14 +58,12 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
68 |
) -> None:
|
69 |
"""
|
70 |
Initializes an instance of the ImageMatchingAPI class.
|
71 |
-
|
72 |
Args:
|
73 |
conf (dict): A dictionary containing the configuration parameters.
|
74 |
device (str, optional): The device to use for computation. Defaults to "cpu".
|
75 |
detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
|
76 |
max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
|
77 |
match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
|
78 |
-
|
79 |
Returns:
|
80 |
None
|
81 |
"""
|
@@ -170,13 +158,22 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
170 |
pred = match_features.match_images(self.matcher, pred0, pred1)
|
171 |
return pred
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
@torch.inference_mode()
|
174 |
def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
|
175 |
"""Extract features from a single image.
|
176 |
-
|
177 |
Args:
|
178 |
img0 (np.ndarray): image
|
179 |
-
|
180 |
Returns:
|
181 |
Dict[str, np.ndarray]: feature dict
|
182 |
"""
|
@@ -190,17 +187,13 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
190 |
pred = extract_features.extract(
|
191 |
self.extractor, img0, self.extract_conf["preprocessing"]
|
192 |
)
|
193 |
-
pred = {
|
194 |
-
k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
|
195 |
-
for k, v in pred.items()
|
196 |
-
}
|
197 |
# back to origin scale
|
198 |
s0 = pred["original_size"] / pred["size"]
|
199 |
pred["keypoints_orig"] = (
|
200 |
match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
|
201 |
)
|
202 |
# TODO: rotate back
|
203 |
-
|
204 |
binarize = kwargs.get("binarize", False)
|
205 |
if binarize:
|
206 |
assert "descriptors" in pred
|
@@ -216,13 +209,11 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
216 |
) -> Dict[str, np.ndarray]:
|
217 |
"""
|
218 |
Forward pass of the image matching API.
|
219 |
-
|
220 |
Args:
|
221 |
img0: A 3D NumPy array of shape (H, W, C) representing the first image.
|
222 |
Values are in the range [0, 1] and are in RGB mode.
|
223 |
img1: A 3D NumPy array of shape (H, W, C) representing the second image.
|
224 |
Values are in the range [0, 1] and are in RGB mode.
|
225 |
-
|
226 |
Returns:
|
227 |
A dictionary containing the following keys:
|
228 |
- image0_orig: The original image 0.
|
@@ -252,11 +243,9 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
252 |
Filter matches using RANSAC. If keypoints are available, filter by keypoints.
|
253 |
If lines are available, filter by lines. If both keypoints and lines are
|
254 |
available, filter by keypoints.
|
255 |
-
|
256 |
Args:
|
257 |
pred (Dict[str, Any]): dict of matches, including original keypoints.
|
258 |
See :func:`filter_matches` for the expected keys.
|
259 |
-
|
260 |
Returns:
|
261 |
Dict[str, Any]: filtered matches
|
262 |
"""
|
@@ -275,10 +264,8 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
275 |
) -> None:
|
276 |
"""
|
277 |
Visualize the matches.
|
278 |
-
|
279 |
Args:
|
280 |
log_path (Path, optional): The directory to save the images. Defaults to None.
|
281 |
-
|
282 |
Returns:
|
283 |
None
|
284 |
"""
|
@@ -349,96 +336,95 @@ class ImageMatchingAPI(torch.nn.Module):
|
|
349 |
plt.close("all")
|
350 |
|
351 |
|
|
|
|
|
|
|
|
|
|
|
352 |
class ImageMatchingService:
|
353 |
def __init__(self, conf: dict, device: str):
|
354 |
self.conf = conf
|
355 |
self.api = ImageMatchingAPI(conf=conf, device=device)
|
356 |
-
self.app = FastAPI()
|
357 |
-
self.register_routes()
|
358 |
|
359 |
-
|
|
|
|
|
360 |
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
image1 (UploadFile): The second image file for matching.
|
375 |
|
376 |
-
|
377 |
-
|
378 |
-
or an error message in case of failure.
|
379 |
-
"""
|
380 |
-
try:
|
381 |
-
# Load the images from the uploaded files
|
382 |
-
image0_array = self.load_image(image0)
|
383 |
-
image1_array = self.load_image(image1)
|
384 |
|
385 |
-
|
386 |
-
|
387 |
|
388 |
-
|
389 |
-
|
|
|
|
|
|
|
390 |
|
391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
pred = self.postprocess(output, skip_keys)
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
@self.app.post("/v1/extract")
|
401 |
-
async def extract(input_info: ImagesInput):
|
402 |
-
"""
|
403 |
-
Extract keypoints and descriptors from images.
|
404 |
-
|
405 |
-
Args:
|
406 |
-
input_info: An object containing the image data and options.
|
407 |
-
|
408 |
-
Returns:
|
409 |
-
A list of dictionaries containing the keypoints and descriptors.
|
410 |
-
"""
|
411 |
-
try:
|
412 |
-
preds = []
|
413 |
-
for i, input_image in enumerate(input_info.data):
|
414 |
-
# Load the image from the input data
|
415 |
-
image_array = to_base64_nparray(input_image)
|
416 |
-
# Extract keypoints and descriptors
|
417 |
-
output = self.api.extract(
|
418 |
-
image_array,
|
419 |
-
max_keypoints=input_info.max_keypoints[i],
|
420 |
-
binarize=input_info.binarize,
|
421 |
-
)
|
422 |
-
# Do not return the original image and image_orig
|
423 |
-
# skip_keys = ["image", "image_orig"]
|
424 |
-
skip_keys = []
|
425 |
-
|
426 |
-
# Postprocess the output
|
427 |
-
pred = self.postprocess(output, skip_keys)
|
428 |
-
preds.append(pred)
|
429 |
-
# Return the list of extracted features
|
430 |
-
return JSONResponse(content=preds)
|
431 |
-
except Exception as e:
|
432 |
-
# Return an error message if an exception occurs
|
433 |
-
return JSONResponse(content={"error": str(e)}, status_code=500)
|
434 |
|
435 |
def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
|
436 |
"""
|
437 |
Reads an image from a file path or an UploadFile object.
|
438 |
-
|
439 |
Args:
|
440 |
file_path: A file path or an UploadFile object.
|
441 |
-
|
442 |
Returns:
|
443 |
A numpy array representing the image.
|
444 |
"""
|
@@ -462,38 +448,23 @@ class ImageMatchingService:
|
|
462 |
return pred
|
463 |
|
464 |
def run(self, host: str = "0.0.0.0", port: int = 8001):
|
465 |
-
uvicorn
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
"resize_max": 1600,
|
482 |
-
"width": 640,
|
483 |
-
"height": 480,
|
484 |
-
"dfactor": 8,
|
485 |
-
},
|
486 |
-
},
|
487 |
-
"matcher": {
|
488 |
-
"output": "matches-NN-mutual",
|
489 |
-
"model": {
|
490 |
-
"name": "nearest_neighbor",
|
491 |
-
"do_mutual_check": True,
|
492 |
-
"match_threshold": 0.2,
|
493 |
-
},
|
494 |
-
},
|
495 |
-
"dense": False,
|
496 |
-
}
|
497 |
|
498 |
-
|
499 |
-
|
|
|
|
1 |
# server.py
|
|
|
|
|
|
|
2 |
import warnings
|
3 |
from pathlib import Path
|
4 |
from typing import Any, Dict, Optional, Union
|
5 |
+
import yaml
|
6 |
+
|
7 |
+
import ray
|
8 |
+
from ray import serve
|
9 |
|
10 |
import cv2
|
11 |
import matplotlib.pyplot as plt
|
12 |
import numpy as np
|
13 |
import torch
|
|
|
14 |
from fastapi import FastAPI, File, UploadFile
|
|
|
15 |
from fastapi.responses import JSONResponse
|
16 |
from PIL import Image
|
17 |
|
18 |
+
from api import ImagesInput, to_base64_nparray
|
|
|
|
|
19 |
from hloc import DEVICE, extract_features, logger, match_dense, match_features
|
20 |
from hloc.utils.viz import add_text, plot_keypoints
|
21 |
from ui import get_version
|
|
|
23 |
from ui.viz import display_matches, fig2im, plot_images
|
24 |
|
25 |
warnings.simplefilter("ignore")
|
26 |
+
app = FastAPI()
|
27 |
+
if ray.is_initialized():
|
28 |
+
ray.shutdown()
|
29 |
+
ray.init(
|
30 |
+
dashboard_port=8265,
|
31 |
+
ignore_reinit_error=True,
|
32 |
+
)
|
33 |
+
serve.start(
|
34 |
+
http_options={"host": "0.0.0.0", "port": 8000},
|
35 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
class ImageMatchingAPI(torch.nn.Module):
|
|
|
58 |
) -> None:
|
59 |
"""
|
60 |
Initializes an instance of the ImageMatchingAPI class.
|
|
|
61 |
Args:
|
62 |
conf (dict): A dictionary containing the configuration parameters.
|
63 |
device (str, optional): The device to use for computation. Defaults to "cpu".
|
64 |
detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
|
65 |
max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
|
66 |
match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
|
|
|
67 |
Returns:
|
68 |
None
|
69 |
"""
|
|
|
158 |
pred = match_features.match_images(self.matcher, pred0, pred1)
|
159 |
return pred
|
160 |
|
161 |
+
def _convert_pred(self, pred):
|
162 |
+
ret = {
|
163 |
+
k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
|
164 |
+
for k, v in pred.items()
|
165 |
+
}
|
166 |
+
ret = {
|
167 |
+
k: v[0].cpu().detach().numpy() if isinstance(v, list) else v
|
168 |
+
for k, v in ret.items()
|
169 |
+
}
|
170 |
+
return ret
|
171 |
+
|
172 |
@torch.inference_mode()
|
173 |
def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
|
174 |
"""Extract features from a single image.
|
|
|
175 |
Args:
|
176 |
img0 (np.ndarray): image
|
|
|
177 |
Returns:
|
178 |
Dict[str, np.ndarray]: feature dict
|
179 |
"""
|
|
|
187 |
pred = extract_features.extract(
|
188 |
self.extractor, img0, self.extract_conf["preprocessing"]
|
189 |
)
|
190 |
+
pred = self._convert_pred(pred)
|
|
|
|
|
|
|
191 |
# back to origin scale
|
192 |
s0 = pred["original_size"] / pred["size"]
|
193 |
pred["keypoints_orig"] = (
|
194 |
match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
|
195 |
)
|
196 |
# TODO: rotate back
|
|
|
197 |
binarize = kwargs.get("binarize", False)
|
198 |
if binarize:
|
199 |
assert "descriptors" in pred
|
|
|
209 |
) -> Dict[str, np.ndarray]:
|
210 |
"""
|
211 |
Forward pass of the image matching API.
|
|
|
212 |
Args:
|
213 |
img0: A 3D NumPy array of shape (H, W, C) representing the first image.
|
214 |
Values are in the range [0, 1] and are in RGB mode.
|
215 |
img1: A 3D NumPy array of shape (H, W, C) representing the second image.
|
216 |
Values are in the range [0, 1] and are in RGB mode.
|
|
|
217 |
Returns:
|
218 |
A dictionary containing the following keys:
|
219 |
- image0_orig: The original image 0.
|
|
|
243 |
Filter matches using RANSAC. If keypoints are available, filter by keypoints.
|
244 |
If lines are available, filter by lines. If both keypoints and lines are
|
245 |
available, filter by keypoints.
|
|
|
246 |
Args:
|
247 |
pred (Dict[str, Any]): dict of matches, including original keypoints.
|
248 |
See :func:`filter_matches` for the expected keys.
|
|
|
249 |
Returns:
|
250 |
Dict[str, Any]: filtered matches
|
251 |
"""
|
|
|
264 |
) -> None:
|
265 |
"""
|
266 |
Visualize the matches.
|
|
|
267 |
Args:
|
268 |
log_path (Path, optional): The directory to save the images. Defaults to None.
|
|
|
269 |
Returns:
|
270 |
None
|
271 |
"""
|
|
|
336 |
plt.close("all")
|
337 |
|
338 |
|
339 |
+
@serve.deployment(
|
340 |
+
num_replicas=4,
|
341 |
+
ray_actor_options={"num_cpus": 2, "num_gpus": 1}
|
342 |
+
)
|
343 |
+
@serve.ingress(app)
|
344 |
class ImageMatchingService:
|
345 |
def __init__(self, conf: dict, device: str):
|
346 |
self.conf = conf
|
347 |
self.api = ImageMatchingAPI(conf=conf, device=device)
|
|
|
|
|
348 |
|
349 |
+
@app.get("/")
|
350 |
+
def root(self):
|
351 |
+
return "Hello, world!"
|
352 |
|
353 |
+
@app.get("/version")
|
354 |
+
async def version(self):
|
355 |
+
return {"version": get_version()}
|
356 |
|
357 |
+
@app.post("/v1/match")
|
358 |
+
async def match(
|
359 |
+
self, image0: UploadFile = File(...), image1: UploadFile = File(...)
|
360 |
+
):
|
361 |
+
"""
|
362 |
+
Handle the image matching request and return the processed result.
|
363 |
+
Args:
|
364 |
+
image0 (UploadFile): The first image file for matching.
|
365 |
+
image1 (UploadFile): The second image file for matching.
|
366 |
+
Returns:
|
367 |
+
JSONResponse: A JSON response containing the filtered match results
|
368 |
+
or an error message in case of failure.
|
369 |
+
"""
|
370 |
+
try:
|
371 |
+
# Load the images from the uploaded files
|
372 |
+
image0_array = self.load_image(image0)
|
373 |
+
image1_array = self.load_image(image1)
|
374 |
|
375 |
+
# Perform image matching using the API
|
376 |
+
output = self.api(image0_array, image1_array)
|
|
|
377 |
|
378 |
+
# Keys to skip in the output
|
379 |
+
skip_keys = ["image0_orig", "image1_orig"]
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
|
381 |
+
# Postprocess the output to filter unwanted data
|
382 |
+
pred = self.postprocess(output, skip_keys)
|
383 |
|
384 |
+
# Return the filtered prediction as a JSON response
|
385 |
+
return JSONResponse(content=pred)
|
386 |
+
except Exception as e:
|
387 |
+
# Return an error message with status code 500 in case of exception
|
388 |
+
return JSONResponse(content={"error": str(e)}, status_code=500)
|
389 |
|
390 |
+
@app.post("/v1/extract")
|
391 |
+
async def extract(self, input_info: ImagesInput):
|
392 |
+
"""
|
393 |
+
Extract keypoints and descriptors from images.
|
394 |
+
Args:
|
395 |
+
input_info: An object containing the image data and options.
|
396 |
+
Returns:
|
397 |
+
A list of dictionaries containing the keypoints and descriptors.
|
398 |
+
"""
|
399 |
+
try:
|
400 |
+
preds = []
|
401 |
+
for i, input_image in enumerate(input_info.data):
|
402 |
+
# Load the image from the input data
|
403 |
+
image_array = to_base64_nparray(input_image)
|
404 |
+
# Extract keypoints and descriptors
|
405 |
+
output = self.api.extract(
|
406 |
+
image_array,
|
407 |
+
max_keypoints=input_info.max_keypoints[i],
|
408 |
+
binarize=input_info.binarize,
|
409 |
+
)
|
410 |
+
# Do not return the original image and image_orig
|
411 |
+
# skip_keys = ["image", "image_orig"]
|
412 |
+
skip_keys = []
|
413 |
+
|
414 |
+
# Postprocess the output
|
415 |
pred = self.postprocess(output, skip_keys)
|
416 |
+
preds.append(pred)
|
417 |
+
# Return the list of extracted features
|
418 |
+
return JSONResponse(content=preds)
|
419 |
+
except Exception as e:
|
420 |
+
# Return an error message if an exception occurs
|
421 |
+
return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
|
423 |
def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
|
424 |
"""
|
425 |
Reads an image from a file path or an UploadFile object.
|
|
|
426 |
Args:
|
427 |
file_path: A file path or an UploadFile object.
|
|
|
428 |
Returns:
|
429 |
A numpy array representing the image.
|
430 |
"""
|
|
|
448 |
return pred
|
449 |
|
450 |
def run(self, host: str = "0.0.0.0", port: int = 8001):
|
451 |
+
import uvicorn
|
452 |
+
uvicorn.run(app, host=host, port=port)
|
453 |
+
|
454 |
+
|
455 |
+
def read_config(config_path: Path) -> dict:
|
456 |
+
with open(config_path, "r") as f:
|
457 |
+
conf = yaml.safe_load(f)
|
458 |
+
return conf
|
459 |
+
|
460 |
+
|
461 |
+
# api server
|
462 |
+
conf = read_config(Path(__file__).parent / "config/api.yaml")
|
463 |
+
service = ImageMatchingService.bind(conf=conf["api"], device=DEVICE)
|
464 |
+
|
465 |
+
# handle = serve.run(service, route_prefix="/")
|
466 |
+
# serve run api.server:service
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
|
468 |
+
# build to generate config file
|
469 |
+
# serve build api.server:service -o api/config/ray.yaml
|
470 |
+
# serve run api/config/ray.yaml
|
api/types.py
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
from typing import List
|
2 |
-
|
3 |
-
from pydantic import BaseModel
|
4 |
-
|
5 |
-
|
6 |
-
class ImagesInput(BaseModel):
|
7 |
-
data: List[str] = []
|
8 |
-
max_keypoints: List[int] = []
|
9 |
-
timestamps: List[str] = []
|
10 |
-
grayscale: bool = False
|
11 |
-
image_hw: List[List[int]] = [[], []]
|
12 |
-
feature_type: int = 0
|
13 |
-
rotates: List[float] = []
|
14 |
-
scales: List[float] = []
|
15 |
-
reference_points: List[List[float]] = []
|
16 |
-
binarize: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -36,3 +36,6 @@ roma #dust3r
|
|
36 |
tqdm
|
37 |
yacs
|
38 |
fastapi
|
|
|
|
|
|
|
|
36 |
tqdm
|
37 |
yacs
|
38 |
fastapi
|
39 |
+
uvicorn
|
40 |
+
ray
|
41 |
+
ray[serve]
|