Realcat commited on
Commit
aebdae7
·
1 Parent(s): 7dc6568

add: ray dashboard port and serve port

Browse files
Files changed (7) hide show
  1. README.md +1 -1
  2. api/__init__.py +42 -0
  3. api/client.py +4 -11
  4. api/config/api.yaml +51 -0
  5. api/server.py +116 -145
  6. api/types.py +0 -16
  7. requirements.txt +3 -0
README.md CHANGED
@@ -107,7 +107,7 @@ docker run -it -p 7860:7860 vincentqin/image-matching-webui:latest python app.py
107
 
108
  ### Run demo
109
  ``` bash
110
- python3 ./app.py
111
  ```
112
  then open http://localhost:7860 in your browser.
113
 
 
107
 
108
  ### Run demo
109
  ``` bash
110
+ python3 app.py
111
  ```
112
  then open http://localhost:7860 in your browser.
113
 
api/__init__.py CHANGED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import List
3
+ from pydantic import BaseModel
4
+ import base64
5
+ import io
6
+ import numpy as np
7
+ from fastapi.exceptions import HTTPException
8
+ from PIL import Image
9
+ from pathlib import Path
10
+
11
+ sys.path.append(str(Path(__file__).parents[1]))
12
+ from hloc import logger
13
+
14
+
15
+ class ImagesInput(BaseModel):
16
+ data: List[str] = []
17
+ max_keypoints: List[int] = []
18
+ timestamps: List[str] = []
19
+ grayscale: bool = False
20
+ image_hw: List[List[int]] = [[], []]
21
+ feature_type: int = 0
22
+ rotates: List[float] = []
23
+ scales: List[float] = []
24
+ reference_points: List[List[float]] = []
25
+ binarize: bool = False
26
+
27
+
28
+ def decode_base64_to_image(encoding):
29
+ if encoding.startswith("data:image/"):
30
+ encoding = encoding.split(";")[1].split(",")[1]
31
+ try:
32
+ image = Image.open(io.BytesIO(base64.b64decode(encoding)))
33
+ return image
34
+ except Exception as e:
35
+ logger.warning(f"API cannot decode image: {e}")
36
+ raise HTTPException(
37
+ status_code=500, detail="Invalid encoded image"
38
+ ) from e
39
+
40
+
41
+ def to_base64_nparray(encoding: str) -> np.ndarray:
42
+ return np.array(decode_base64_to_image(encoding)).astype("uint8")
api/client.py CHANGED
@@ -9,7 +9,7 @@ import cv2
9
  import numpy as np
10
  import requests
11
 
12
- ENDPOINT = "http://127.0.0.1:8001"
13
  if "REMOTE_URL_RAILWAY" in os.environ:
14
  ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
15
 
@@ -23,10 +23,8 @@ API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
23
  def read_image(path: str) -> str:
24
  """
25
  Read an image from a file, encode it as a JPEG and then as a base64 string.
26
-
27
  Args:
28
  path (str): The path to the image to read.
29
-
30
  Returns:
31
  str: The base64 encoded image.
32
  """
@@ -45,12 +43,10 @@ def read_image(path: str) -> str:
45
  def do_api_requests(url=API_URL_EXTRACT, **kwargs):
46
  """
47
  Helper function to send an API request to the image matching service.
48
-
49
  Args:
50
  url (str): The URL of the API endpoint to use. Defaults to the
51
  feature extraction endpoint.
52
  **kwargs: Additional keyword arguments to pass to the API.
53
-
54
  Returns:
55
  List[Dict[str, np.ndarray]]: A list of dictionaries containing the
56
  extracted features. The keys are "keypoints", "descriptors", and
@@ -99,11 +95,9 @@ def do_api_requests(url=API_URL_EXTRACT, **kwargs):
99
  def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
100
  """
101
  Send a request to the API to generate a match between two images.
102
-
103
  Args:
104
  path0 (str): The path to the first image.
105
  path1 (str): The path to the second image.
106
-
107
  Returns:
108
  Dict[str, np.ndarray]: A dictionary containing the generated matches.
109
  The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
@@ -134,10 +128,8 @@ def send_request_extract(
134
  ) -> List[Dict[str, np.ndarray]]:
135
  """
136
  Send a request to the API to extract features from an image.
137
-
138
  Args:
139
  input_images (str): The path to the image.
140
-
141
  Returns:
142
  List[Dict[str, np.ndarray]]: A list of dictionaries containing the
143
  extracted features. The keys are "keypoints", "descriptors", and
@@ -152,7 +144,8 @@ def send_request_extract(
152
  url=API_URL_EXTRACT,
153
  **inputs,
154
  )
155
- print("Keypoints detected: {}".format(len(response[0]["keypoints"])))
 
156
 
157
  # draw matching, debug only
158
  if viz:
@@ -214,7 +207,7 @@ if __name__ == "__main__":
214
  # )
215
 
216
  # request extract
217
- for i in range(10):
218
  t1 = time.time()
219
  preds = send_request_extract(args.image0)
220
  t2 = time.time()
 
9
  import numpy as np
10
  import requests
11
 
12
+ ENDPOINT = "http://127.0.0.1:8000"
13
  if "REMOTE_URL_RAILWAY" in os.environ:
14
  ENDPOINT = os.environ["REMOTE_URL_RAILWAY"]
15
 
 
23
  def read_image(path: str) -> str:
24
  """
25
  Read an image from a file, encode it as a JPEG and then as a base64 string.
 
26
  Args:
27
  path (str): The path to the image to read.
 
28
  Returns:
29
  str: The base64 encoded image.
30
  """
 
43
  def do_api_requests(url=API_URL_EXTRACT, **kwargs):
44
  """
45
  Helper function to send an API request to the image matching service.
 
46
  Args:
47
  url (str): The URL of the API endpoint to use. Defaults to the
48
  feature extraction endpoint.
49
  **kwargs: Additional keyword arguments to pass to the API.
 
50
  Returns:
51
  List[Dict[str, np.ndarray]]: A list of dictionaries containing the
52
  extracted features. The keys are "keypoints", "descriptors", and
 
95
  def send_request_match(path0: str, path1: str) -> Dict[str, np.ndarray]:
96
  """
97
  Send a request to the API to generate a match between two images.
 
98
  Args:
99
  path0 (str): The path to the first image.
100
  path1 (str): The path to the second image.
 
101
  Returns:
102
  Dict[str, np.ndarray]: A dictionary containing the generated matches.
103
  The keys are "keypoints0", "keypoints1", "matches0", and "matches1",
 
128
  ) -> List[Dict[str, np.ndarray]]:
129
  """
130
  Send a request to the API to extract features from an image.
 
131
  Args:
132
  input_images (str): The path to the image.
 
133
  Returns:
134
  List[Dict[str, np.ndarray]]: A list of dictionaries containing the
135
  extracted features. The keys are "keypoints", "descriptors", and
 
144
  url=API_URL_EXTRACT,
145
  **inputs,
146
  )
147
+ # breakpoint()
148
+ # print("Keypoints detected: {}".format(len(response[0]["keypoints"])))
149
 
150
  # draw matching, debug only
151
  if viz:
 
207
  # )
208
 
209
  # request extract
210
+ for i in range(1000):
211
  t1 = time.time()
212
  preds = send_request_extract(args.image0)
213
  t2 = time.time()
api/config/api.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was generated using the `serve build` command on Ray v2.38.0.
2
+
3
+ proxy_location: EveryNode
4
+ http_options:
5
+ host: 0.0.0.0
6
+ port: 8000
7
+
8
+ grpc_options:
9
+ port: 9000
10
+ grpc_servicer_functions: []
11
+
12
+ logging_config:
13
+ encoding: TEXT
14
+ log_level: INFO
15
+ logs_dir: null
16
+ enable_access_log: true
17
+
18
+ applications:
19
+ - name: app1
20
+ route_prefix: /
21
+ import_path: api.server:service
22
+ runtime_env: {}
23
+ deployments:
24
+ - name: ImageMatchingService
25
+ num_replicas: 4
26
+ ray_actor_options:
27
+ num_cpus: 2.0
28
+ num_gpus: 1.0
29
+
30
+ api:
31
+ feature:
32
+ output: feats-superpoint-n4096-rmax1600
33
+ model:
34
+ name: superpoint
35
+ nms_radius: 3
36
+ max_keypoints: 4096
37
+ keypoint_threshold: 0.005
38
+ preprocessing:
39
+ grayscale: True
40
+ force_resize: True
41
+ resize_max: 1600
42
+ width: 640
43
+ height: 480
44
+ dfactor: 8
45
+ matcher:
46
+ output: matches-NN-mutual
47
+ model:
48
+ name: nearest_neighbor
49
+ do_mutual_check: True
50
+ match_threshold: 0.2
51
+ dense: False
api/server.py CHANGED
@@ -1,24 +1,21 @@
1
  # server.py
2
- import base64
3
- import io
4
- import sys
5
  import warnings
6
  from pathlib import Path
7
  from typing import Any, Dict, Optional, Union
 
 
 
 
8
 
9
  import cv2
10
  import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
- import uvicorn
14
  from fastapi import FastAPI, File, UploadFile
15
- from fastapi.exceptions import HTTPException
16
  from fastapi.responses import JSONResponse
17
  from PIL import Image
18
 
19
- sys.path.append(str(Path(__file__).parents[1]))
20
-
21
- from api.types import ImagesInput
22
  from hloc import DEVICE, extract_features, logger, match_dense, match_features
23
  from hloc.utils.viz import add_text, plot_keypoints
24
  from ui import get_version
@@ -26,23 +23,16 @@ from ui.utils import filter_matches, get_feature_model, get_model
26
  from ui.viz import display_matches, fig2im, plot_images
27
 
28
  warnings.simplefilter("ignore")
29
-
30
-
31
- def decode_base64_to_image(encoding):
32
- if encoding.startswith("data:image/"):
33
- encoding = encoding.split(";")[1].split(",")[1]
34
- try:
35
- image = Image.open(io.BytesIO(base64.b64decode(encoding)))
36
- return image
37
- except Exception as e:
38
- logger.warning(f"API cannot decode image: {e}")
39
- raise HTTPException(
40
- status_code=500, detail="Invalid encoded image"
41
- ) from e
42
-
43
-
44
- def to_base64_nparray(encoding: str) -> np.ndarray:
45
- return np.array(decode_base64_to_image(encoding)).astype("uint8")
46
 
47
 
48
  class ImageMatchingAPI(torch.nn.Module):
@@ -68,14 +58,12 @@ class ImageMatchingAPI(torch.nn.Module):
68
  ) -> None:
69
  """
70
  Initializes an instance of the ImageMatchingAPI class.
71
-
72
  Args:
73
  conf (dict): A dictionary containing the configuration parameters.
74
  device (str, optional): The device to use for computation. Defaults to "cpu".
75
  detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
76
  max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
77
  match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
78
-
79
  Returns:
80
  None
81
  """
@@ -170,13 +158,22 @@ class ImageMatchingAPI(torch.nn.Module):
170
  pred = match_features.match_images(self.matcher, pred0, pred1)
171
  return pred
172
 
 
 
 
 
 
 
 
 
 
 
 
173
  @torch.inference_mode()
174
  def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
175
  """Extract features from a single image.
176
-
177
  Args:
178
  img0 (np.ndarray): image
179
-
180
  Returns:
181
  Dict[str, np.ndarray]: feature dict
182
  """
@@ -190,17 +187,13 @@ class ImageMatchingAPI(torch.nn.Module):
190
  pred = extract_features.extract(
191
  self.extractor, img0, self.extract_conf["preprocessing"]
192
  )
193
- pred = {
194
- k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
195
- for k, v in pred.items()
196
- }
197
  # back to origin scale
198
  s0 = pred["original_size"] / pred["size"]
199
  pred["keypoints_orig"] = (
200
  match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
201
  )
202
  # TODO: rotate back
203
-
204
  binarize = kwargs.get("binarize", False)
205
  if binarize:
206
  assert "descriptors" in pred
@@ -216,13 +209,11 @@ class ImageMatchingAPI(torch.nn.Module):
216
  ) -> Dict[str, np.ndarray]:
217
  """
218
  Forward pass of the image matching API.
219
-
220
  Args:
221
  img0: A 3D NumPy array of shape (H, W, C) representing the first image.
222
  Values are in the range [0, 1] and are in RGB mode.
223
  img1: A 3D NumPy array of shape (H, W, C) representing the second image.
224
  Values are in the range [0, 1] and are in RGB mode.
225
-
226
  Returns:
227
  A dictionary containing the following keys:
228
  - image0_orig: The original image 0.
@@ -252,11 +243,9 @@ class ImageMatchingAPI(torch.nn.Module):
252
  Filter matches using RANSAC. If keypoints are available, filter by keypoints.
253
  If lines are available, filter by lines. If both keypoints and lines are
254
  available, filter by keypoints.
255
-
256
  Args:
257
  pred (Dict[str, Any]): dict of matches, including original keypoints.
258
  See :func:`filter_matches` for the expected keys.
259
-
260
  Returns:
261
  Dict[str, Any]: filtered matches
262
  """
@@ -275,10 +264,8 @@ class ImageMatchingAPI(torch.nn.Module):
275
  ) -> None:
276
  """
277
  Visualize the matches.
278
-
279
  Args:
280
  log_path (Path, optional): The directory to save the images. Defaults to None.
281
-
282
  Returns:
283
  None
284
  """
@@ -349,96 +336,95 @@ class ImageMatchingAPI(torch.nn.Module):
349
  plt.close("all")
350
 
351
 
 
 
 
 
 
352
  class ImageMatchingService:
353
  def __init__(self, conf: dict, device: str):
354
  self.conf = conf
355
  self.api = ImageMatchingAPI(conf=conf, device=device)
356
- self.app = FastAPI()
357
- self.register_routes()
358
 
359
- def register_routes(self):
 
 
360
 
361
- @self.app.get("/version")
362
- async def version():
363
- return {"version": get_version()}
364
 
365
- @self.app.post("/v1/match")
366
- async def match(
367
- image0: UploadFile = File(...), image1: UploadFile = File(...)
368
- ):
369
- """
370
- Handle the image matching request and return the processed result.
 
 
 
 
 
 
 
 
 
 
 
371
 
372
- Args:
373
- image0 (UploadFile): The first image file for matching.
374
- image1 (UploadFile): The second image file for matching.
375
 
376
- Returns:
377
- JSONResponse: A JSON response containing the filtered match results
378
- or an error message in case of failure.
379
- """
380
- try:
381
- # Load the images from the uploaded files
382
- image0_array = self.load_image(image0)
383
- image1_array = self.load_image(image1)
384
 
385
- # Perform image matching using the API
386
- output = self.api(image0_array, image1_array)
387
 
388
- # Keys to skip in the output
389
- skip_keys = ["image0_orig", "image1_orig"]
 
 
 
390
 
391
- # Postprocess the output to filter unwanted data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  pred = self.postprocess(output, skip_keys)
393
-
394
- # Return the filtered prediction as a JSON response
395
- return JSONResponse(content=pred)
396
- except Exception as e:
397
- # Return an error message with status code 500 in case of exception
398
- return JSONResponse(content={"error": str(e)}, status_code=500)
399
-
400
- @self.app.post("/v1/extract")
401
- async def extract(input_info: ImagesInput):
402
- """
403
- Extract keypoints and descriptors from images.
404
-
405
- Args:
406
- input_info: An object containing the image data and options.
407
-
408
- Returns:
409
- A list of dictionaries containing the keypoints and descriptors.
410
- """
411
- try:
412
- preds = []
413
- for i, input_image in enumerate(input_info.data):
414
- # Load the image from the input data
415
- image_array = to_base64_nparray(input_image)
416
- # Extract keypoints and descriptors
417
- output = self.api.extract(
418
- image_array,
419
- max_keypoints=input_info.max_keypoints[i],
420
- binarize=input_info.binarize,
421
- )
422
- # Do not return the original image and image_orig
423
- # skip_keys = ["image", "image_orig"]
424
- skip_keys = []
425
-
426
- # Postprocess the output
427
- pred = self.postprocess(output, skip_keys)
428
- preds.append(pred)
429
- # Return the list of extracted features
430
- return JSONResponse(content=preds)
431
- except Exception as e:
432
- # Return an error message if an exception occurs
433
- return JSONResponse(content={"error": str(e)}, status_code=500)
434
 
435
  def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
436
  """
437
  Reads an image from a file path or an UploadFile object.
438
-
439
  Args:
440
  file_path: A file path or an UploadFile object.
441
-
442
  Returns:
443
  A numpy array representing the image.
444
  """
@@ -462,38 +448,23 @@ class ImageMatchingService:
462
  return pred
463
 
464
  def run(self, host: str = "0.0.0.0", port: int = 8001):
465
- uvicorn.run(self.app, host=host, port=port)
466
-
467
-
468
- if __name__ == "__main__":
469
- conf = {
470
- "feature": {
471
- "output": "feats-superpoint-n4096-rmax1600",
472
- "model": {
473
- "name": "superpoint",
474
- "nms_radius": 3,
475
- "max_keypoints": 4096,
476
- "keypoint_threshold": 0.005,
477
- },
478
- "preprocessing": {
479
- "grayscale": True,
480
- "force_resize": True,
481
- "resize_max": 1600,
482
- "width": 640,
483
- "height": 480,
484
- "dfactor": 8,
485
- },
486
- },
487
- "matcher": {
488
- "output": "matches-NN-mutual",
489
- "model": {
490
- "name": "nearest_neighbor",
491
- "do_mutual_check": True,
492
- "match_threshold": 0.2,
493
- },
494
- },
495
- "dense": False,
496
- }
497
 
498
- service = ImageMatchingService(conf=conf, device=DEVICE)
499
- service.run()
 
 
1
  # server.py
 
 
 
2
  import warnings
3
  from pathlib import Path
4
  from typing import Any, Dict, Optional, Union
5
+ import yaml
6
+
7
+ import ray
8
+ from ray import serve
9
 
10
  import cv2
11
  import matplotlib.pyplot as plt
12
  import numpy as np
13
  import torch
 
14
  from fastapi import FastAPI, File, UploadFile
 
15
  from fastapi.responses import JSONResponse
16
  from PIL import Image
17
 
18
+ from api import ImagesInput, to_base64_nparray
 
 
19
  from hloc import DEVICE, extract_features, logger, match_dense, match_features
20
  from hloc.utils.viz import add_text, plot_keypoints
21
  from ui import get_version
 
23
  from ui.viz import display_matches, fig2im, plot_images
24
 
25
  warnings.simplefilter("ignore")
26
+ app = FastAPI()
27
+ if ray.is_initialized():
28
+ ray.shutdown()
29
+ ray.init(
30
+ dashboard_port=8265,
31
+ ignore_reinit_error=True,
32
+ )
33
+ serve.start(
34
+ http_options={"host": "0.0.0.0", "port": 8000},
35
+ )
 
 
 
 
 
 
 
36
 
37
 
38
  class ImageMatchingAPI(torch.nn.Module):
 
58
  ) -> None:
59
  """
60
  Initializes an instance of the ImageMatchingAPI class.
 
61
  Args:
62
  conf (dict): A dictionary containing the configuration parameters.
63
  device (str, optional): The device to use for computation. Defaults to "cpu".
64
  detect_threshold (float, optional): The threshold for detecting keypoints. Defaults to 0.015.
65
  max_keypoints (int, optional): The maximum number of keypoints to extract. Defaults to 1024.
66
  match_threshold (float, optional): The threshold for matching keypoints. Defaults to 0.2.
 
67
  Returns:
68
  None
69
  """
 
158
  pred = match_features.match_images(self.matcher, pred0, pred1)
159
  return pred
160
 
161
+ def _convert_pred(self, pred):
162
+ ret = {
163
+ k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
164
+ for k, v in pred.items()
165
+ }
166
+ ret = {
167
+ k: v[0].cpu().detach().numpy() if isinstance(v, list) else v
168
+ for k, v in ret.items()
169
+ }
170
+ return ret
171
+
172
  @torch.inference_mode()
173
  def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
174
  """Extract features from a single image.
 
175
  Args:
176
  img0 (np.ndarray): image
 
177
  Returns:
178
  Dict[str, np.ndarray]: feature dict
179
  """
 
187
  pred = extract_features.extract(
188
  self.extractor, img0, self.extract_conf["preprocessing"]
189
  )
190
+ pred = self._convert_pred(pred)
 
 
 
191
  # back to origin scale
192
  s0 = pred["original_size"] / pred["size"]
193
  pred["keypoints_orig"] = (
194
  match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
195
  )
196
  # TODO: rotate back
 
197
  binarize = kwargs.get("binarize", False)
198
  if binarize:
199
  assert "descriptors" in pred
 
209
  ) -> Dict[str, np.ndarray]:
210
  """
211
  Forward pass of the image matching API.
 
212
  Args:
213
  img0: A 3D NumPy array of shape (H, W, C) representing the first image.
214
  Values are in the range [0, 1] and are in RGB mode.
215
  img1: A 3D NumPy array of shape (H, W, C) representing the second image.
216
  Values are in the range [0, 1] and are in RGB mode.
 
217
  Returns:
218
  A dictionary containing the following keys:
219
  - image0_orig: The original image 0.
 
243
  Filter matches using RANSAC. If keypoints are available, filter by keypoints.
244
  If lines are available, filter by lines. If both keypoints and lines are
245
  available, filter by keypoints.
 
246
  Args:
247
  pred (Dict[str, Any]): dict of matches, including original keypoints.
248
  See :func:`filter_matches` for the expected keys.
 
249
  Returns:
250
  Dict[str, Any]: filtered matches
251
  """
 
264
  ) -> None:
265
  """
266
  Visualize the matches.
 
267
  Args:
268
  log_path (Path, optional): The directory to save the images. Defaults to None.
 
269
  Returns:
270
  None
271
  """
 
336
  plt.close("all")
337
 
338
 
339
+ @serve.deployment(
340
+ num_replicas=4,
341
+ ray_actor_options={"num_cpus": 2, "num_gpus": 1}
342
+ )
343
+ @serve.ingress(app)
344
  class ImageMatchingService:
345
  def __init__(self, conf: dict, device: str):
346
  self.conf = conf
347
  self.api = ImageMatchingAPI(conf=conf, device=device)
 
 
348
 
349
+ @app.get("/")
350
+ def root(self):
351
+ return "Hello, world!"
352
 
353
+ @app.get("/version")
354
+ async def version(self):
355
+ return {"version": get_version()}
356
 
357
+ @app.post("/v1/match")
358
+ async def match(
359
+ self, image0: UploadFile = File(...), image1: UploadFile = File(...)
360
+ ):
361
+ """
362
+ Handle the image matching request and return the processed result.
363
+ Args:
364
+ image0 (UploadFile): The first image file for matching.
365
+ image1 (UploadFile): The second image file for matching.
366
+ Returns:
367
+ JSONResponse: A JSON response containing the filtered match results
368
+ or an error message in case of failure.
369
+ """
370
+ try:
371
+ # Load the images from the uploaded files
372
+ image0_array = self.load_image(image0)
373
+ image1_array = self.load_image(image1)
374
 
375
+ # Perform image matching using the API
376
+ output = self.api(image0_array, image1_array)
 
377
 
378
+ # Keys to skip in the output
379
+ skip_keys = ["image0_orig", "image1_orig"]
 
 
 
 
 
 
380
 
381
+ # Postprocess the output to filter unwanted data
382
+ pred = self.postprocess(output, skip_keys)
383
 
384
+ # Return the filtered prediction as a JSON response
385
+ return JSONResponse(content=pred)
386
+ except Exception as e:
387
+ # Return an error message with status code 500 in case of exception
388
+ return JSONResponse(content={"error": str(e)}, status_code=500)
389
 
390
+ @app.post("/v1/extract")
391
+ async def extract(self, input_info: ImagesInput):
392
+ """
393
+ Extract keypoints and descriptors from images.
394
+ Args:
395
+ input_info: An object containing the image data and options.
396
+ Returns:
397
+ A list of dictionaries containing the keypoints and descriptors.
398
+ """
399
+ try:
400
+ preds = []
401
+ for i, input_image in enumerate(input_info.data):
402
+ # Load the image from the input data
403
+ image_array = to_base64_nparray(input_image)
404
+ # Extract keypoints and descriptors
405
+ output = self.api.extract(
406
+ image_array,
407
+ max_keypoints=input_info.max_keypoints[i],
408
+ binarize=input_info.binarize,
409
+ )
410
+ # Do not return the original image and image_orig
411
+ # skip_keys = ["image", "image_orig"]
412
+ skip_keys = []
413
+
414
+ # Postprocess the output
415
  pred = self.postprocess(output, skip_keys)
416
+ preds.append(pred)
417
+ # Return the list of extracted features
418
+ return JSONResponse(content=preds)
419
+ except Exception as e:
420
+ # Return an error message if an exception occurs
421
+ return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
 
423
  def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
424
  """
425
  Reads an image from a file path or an UploadFile object.
 
426
  Args:
427
  file_path: A file path or an UploadFile object.
 
428
  Returns:
429
  A numpy array representing the image.
430
  """
 
448
  return pred
449
 
450
  def run(self, host: str = "0.0.0.0", port: int = 8001):
451
+ import uvicorn
452
+ uvicorn.run(app, host=host, port=port)
453
+
454
+
455
+ def read_config(config_path: Path) -> dict:
456
+ with open(config_path, "r") as f:
457
+ conf = yaml.safe_load(f)
458
+ return conf
459
+
460
+
461
+ # api server
462
+ conf = read_config(Path(__file__).parent / "config/api.yaml")
463
+ service = ImageMatchingService.bind(conf=conf["api"], device=DEVICE)
464
+
465
+ # handle = serve.run(service, route_prefix="/")
466
+ # serve run api.server_ray:service
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
+ # build to generate config file
469
+ # serve build api.server_ray:service -o api/config/ray.yaml
470
+ # serve run api/config/ray.yaml
api/types.py DELETED
@@ -1,16 +0,0 @@
1
- from typing import List
2
-
3
- from pydantic import BaseModel
4
-
5
-
6
- class ImagesInput(BaseModel):
7
- data: List[str] = []
8
- max_keypoints: List[int] = []
9
- timestamps: List[str] = []
10
- grayscale: bool = False
11
- image_hw: List[List[int]] = [[], []]
12
- feature_type: int = 0
13
- rotates: List[float] = []
14
- scales: List[float] = []
15
- reference_points: List[List[float]] = []
16
- binarize: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -36,3 +36,6 @@ roma #dust3r
36
  tqdm
37
  yacs
38
  fastapi
 
 
 
 
36
  tqdm
37
  yacs
38
  fastapi
39
+ uvicorn
40
+ ray
41
+ ray[serve]