Upload 50 files
- README.md +1 -1
- app.py +131 -0
- gradio_image_prompter-0.1.0-py3-none-any.whl +0 -0
- models/__init__.py +15 -0
- models/automatic_mask_generator.py +372 -0
- models/build_sam.py +107 -0
- models/grasp_mods.py +318 -0
- models/modeling/__init__.py +11 -0
- models/modeling/common.py +43 -0
- models/modeling/image_encoder.py +395 -0
- models/modeling/mask_decoder.py +176 -0
- models/modeling/prompt_encoder.py +214 -0
- models/modeling/sam.py +174 -0
- models/modeling/transformer.py +240 -0
- models/predictor.py +269 -0
- models/utils/__init__.py +5 -0
- models/utils/amg.py +346 -0
- models/utils/onnx.py +144 -0
- models/utils/transforms.py +102 -0
- requirements.txt +5 -0
- src/.gitignore +9 -0
- src/LICENSE +201 -0
- src/README.md +48 -0
- src/backend/gradio_image_prompter/__init__.py +3 -0
- src/backend/gradio_image_prompter/image_prompter.py +133 -0
- src/backend/gradio_image_prompter/image_prompter.pyi +134 -0
- src/backend/gradio_image_prompter/templates/component/__vite-browser-external-2447137e.js +4 -0
- src/backend/gradio_image_prompter/templates/component/index.js +0 -0
- src/backend/gradio_image_prompter/templates/component/style.css +1 -0
- src/backend/gradio_image_prompter/templates/component/wrapper-6f348d45-f837cf34.js +2455 -0
- src/backend/gradio_image_prompter/templates/example/index.js +263 -0
- src/backend/gradio_image_prompter/templates/example/style.css +1 -0
- src/demo/__init__.py +0 -0
- src/demo/app.py +9 -0
- src/frontend/Example.svelte +44 -0
- src/frontend/Index.svelte +167 -0
- src/frontend/package-lock.json +718 -0
- src/frontend/package.json +28 -0
- src/frontend/shared/BoxDrawer.svelte +237 -0
- src/frontend/shared/ClearImage.svelte +48 -0
- src/frontend/shared/Image.svelte +15 -0
- src/frontend/shared/ImagePreview.svelte +88 -0
- src/frontend/shared/ImageUploader.svelte +192 -0
- src/frontend/shared/utils.ts +24 -0
- src/pyproject.toml +43 -0
- structures/__init__.py +0 -0
- structures/bounding_box.py +323 -0
- structures/grasp_box.py +127 -0
- structures/image_list.py +67 -0
- structures/segmentation_mask.py +298 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: GraspAnything
-emoji:
+emoji: 🤖✊
 colorFrom: gray
 colorTo: purple
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,131 @@
+import copy
+import numpy as np
+import torch
+
+import sys
+sys.path.append("./")
+from models import sam_model_registry
+from models.grasp_mods import modify_forward
+from models.utils.transforms import ResizeLongestSide
+
+from gradio_image_prompter import ImagePrompter
+from structures.grasp_box import GraspCoder
+img_resize = ResizeLongestSide(1024)
+import cv2
+
+import gradio as gr
+
+from models.grasp_mods import add_inference_method
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model_type = "vit_b"
+
+mean = np.array([103.53, 116.28, 123.675])[:, np.newaxis, np.newaxis]
+std = np.array([57.375, 57.12, 58.395])[:, np.newaxis, np.newaxis]
+
+sam = sam_model_registry[model_type]()
+sam.to(device=device)
+
+sam.forward = modify_forward(sam)
+sam.infer = add_inference_method(sam)
+
+pretrained_model_path = "E:/epoch_9_step_535390.pth"
+
+if pretrained_model_path != "":
+    sd = torch.load(pretrained_model_path)
+    # strip prefix "module." from keys
+    new_sd = {}
+    for k, v in sd.items():
+        if k.startswith("module."):
+            k = k[7:]
+        new_sd[k] = v
+    sam.load_state_dict(new_sd)
+
+sam.eval()
+
+def predict(input, topk):
+    np_image = input["image"]
+    points = input["points"]
+    orig_size = np_image.shape[:2]
+    # normalize image
+    np_image = np_image.transpose(2, 0, 1)
+
+    image = (np_image - mean) / std
+    image = torch.tensor(image).float().to(device)
+    image = image.unsqueeze(0)
+    t_image = img_resize.apply_image_torch(image)
+    t_orig_size = t_image.shape[-2:]
+    # pad to 1024x1024
+    t_image = torch.nn.functional.pad(t_image, (0, 1024 - t_image.shape[-1], 0, 1024 - t_image.shape[-2]))
+
+    # get box prompt
+    valid_boxes = []
+    for point in points:
+        x1, y1, type1, x2, y2, type2 = point
+        if type1 == 2 and type2 == 3:
+            valid_boxes.append([x1, y1, x2, y2])
+    if len(valid_boxes) == 0:
+        return input["image"]  # no box drawn: return the original HWC image
+    t_boxes = np.array(valid_boxes)
+    t_boxes = img_resize.apply_boxes(t_boxes, orig_size)
+    box_torch = torch.as_tensor(t_boxes, dtype=torch.float, device=device)
+    batched_inputs = [{"image": t_image[0], "boxes": box_torch}]
+    with torch.no_grad():
+        outputs = sam.infer(batched_inputs, multimask_output=False)
+    # recover the normalized, padded input image for visualization
+    recovered_img = batched_inputs[0]['image'].cpu().numpy()
+    recovered_img = recovered_img * std + mean
+    recovered_img = recovered_img.transpose(1, 2, 0).clip(0, 255).astype(np.uint8)  # clip before casting to avoid uint8 overflow
+
+    for i in range(len(outputs.pred_masks)):
+        # get predicted mask
+        pred_mask = outputs.pred_masks[i].detach().sigmoid().cpu().numpy() > 0.5
+        pred_mask = pred_mask.transpose(1, 2, 0).repeat(3, axis=2)
+
+        # get predicted grasp
+        pred_logits = outputs.logits[i].detach().cpu().numpy()
+        top_ind = pred_logits[:, 0].argsort()[-topk:][::-1]
+        pred_grasp = outputs.pred_boxes[i].detach().cpu().numpy()[top_ind]
+        coded_grasp = GraspCoder(1024, 1024, None, grasp_annos_reformat=pred_grasp)
+        _ = coded_grasp.decode()
+        decoded_grasp = copy.deepcopy(coded_grasp.grasp_annos)
+
+        # draw mask
+        mask_color = np.array([0, 255, 0])[None, None, :]
+        recovered_img[pred_mask] = recovered_img[pred_mask] * 0.5 + (pred_mask * mask_color)[pred_mask] * 0.5
+
+        # draw grasp
+        recovered_img = np.ascontiguousarray(recovered_img)
+        for grasp in decoded_grasp:
+            grasp = grasp.astype(int)
+            cv2.line(recovered_img, tuple(grasp[0:2]), tuple(grasp[2:4]), (255, 0, 0), 1)
+            cv2.line(recovered_img, tuple(grasp[4:6]), tuple(grasp[6:8]), (255, 0, 0), 1)
+            cv2.line(recovered_img, tuple(grasp[2:4]), tuple(grasp[4:6]), (0, 0, 255), 2)
+            cv2.line(recovered_img, tuple(grasp[6:8]), tuple(grasp[0:2]), (0, 0, 255), 2)
+
+    recovered_img = recovered_img[:t_orig_size[0], :t_orig_size[1]]
+    # resize back to the original size; cv2.resize expects (width, height)
+    recovered_img = cv2.resize(recovered_img, (orig_size[1], orig_size[0]))
+    return recovered_img
+
+if __name__ == "__main__":
+    app = gr.Blocks(title="GraspAnything")
+    with app:
+        gr.Markdown("""
+        # GraspAnything <br>
+        Upload an image and draw a box around the object you want to grasp. Set top k to the number of grasps you want to predict for each object.
+        """)
+        with gr.Column():
+            prompter = ImagePrompter(show_label=False)
+            top_k = gr.Slider(minimum=1, maximum=20, step=1, value=3, label="Top K Grasps")
+        with gr.Column():
+            image_output = gr.Image()
+        btn = gr.Button("Generate!")
+        btn.click(predict,
+                  inputs=[prompter, top_k],
+                  outputs=[image_output])
+    app.launch()
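The predict() function above keys off the point encoding used by gradio_image_prompter: each box drawn in the UI arrives as a six-tuple (x1, y1, 2, x2, y2, 3), where type codes 2 and 3 mark the box's top-left and bottom-right corners. A minimal sketch of exercising predict() without the UI, assuming a valid checkpoint path above; the file names here are hypothetical:

import cv2
test_image = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)  # hypothetical test image
prompt = {
    "image": test_image,
    "points": [[50, 60, 2, 200, 220, 3]],  # one box: corners (50, 60) and (200, 220)
}
vis = predict(prompt, topk=3)  # image with masks and top-3 grasps drawn per box
cv2.imwrite("out.jpg", cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))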
gradio_image_prompter-0.1.0-py3-none-any.whl
ADDED
Binary file (96.2 kB).
models/__init__.py
ADDED
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_sam import (
+    build_sam,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    build_sam_vit_b,
+    sam_model_registry,
+)
+from .predictor import SamPredictor
+from .automatic_mask_generator import SamAutomaticMaskGenerator
models/automatic_mask_generator.py
ADDED
@@ -0,0 +1,372 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling import Sam
+from .predictor import SamPredictor
+from .utils.amg import (
+    MaskData,
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)
+
+
+class SamAutomaticMaskGenerator:
+    def __init__(
+        self,
+        model: Sam,
+        points_per_side: Optional[int] = 32,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.88,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: Optional[List[np.ndarray]] = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+    ) -> None:
+        """
+        Using a SAM model, generates masks for the entire image.
+        Generates a grid of point prompts over the image, then filters
+        low quality and duplicate masks. The default settings are chosen
+        for SAM with a ViT-H backbone.
+
+        Arguments:
+          model (Sam): The SAM model to use for mask prediction.
+          points_per_side (int or None): The number of points to be sampled
+            along one side of the image. The total number of points is
+            points_per_side**2. If None, 'point_grids' must provide explicit
+            point sampling.
+          points_per_batch (int): Sets the number of points run simultaneously
+            by the model. Higher numbers may be faster but use more GPU memory.
+          pred_iou_thresh (float): A filtering threshold in [0,1], using the
+            model's predicted mask quality.
+          stability_score_thresh (float): A filtering threshold in [0,1], using
+            the stability of the mask under changes to the cutoff used to binarize
+            the model's mask predictions.
+          stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+          box_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+            crops of the image. Sets the number of layers to run, where each
+            layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks between different crops.
+          crop_overlap_ratio (float): Sets the degree to which crops overlap.
+            In the first crop layer, crops will overlap by this fraction of
+            the image length. Later layers with more crops scale down this overlap.
+          crop_n_points_downscale_factor (int): The number of points-per-side
+            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+          point_grids (list(np.ndarray) or None): A list over explicit grids
+            of points used for sampling, normalized to [0,1]. The nth grid in the
+            list is used in the nth crop layer. Exclusive with points_per_side.
+          min_mask_region_area (int): If >0, postprocessing will be applied
+            to remove disconnected regions and holes in masks with area smaller
+            than min_mask_region_area. Requires opencv.
+          output_mode (str): The form masks are returned in. Can be 'binary_mask',
+            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+            For large resolutions, 'binary_mask' may consume large amounts of
+            memory.
+        """
+
+        assert (points_per_side is None) != (
+            point_grids is None
+        ), "Exactly one of points_per_side or point_grid must be provided."
+        if points_per_side is not None:
+            self.point_grids = build_all_layer_point_grids(
+                points_per_side,
+                crop_n_layers,
+                crop_n_points_downscale_factor,
+            )
+        elif point_grids is not None:
+            self.point_grids = point_grids
+        else:
+            raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+        assert output_mode in [
+            "binary_mask",
+            "uncompressed_rle",
+            "coco_rle",
+        ], f"Unknown output_mode {output_mode}."
+        if output_mode == "coco_rle":
+            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
+
+        if min_mask_region_area > 0:
+            import cv2  # type: ignore # noqa: F401
+
+        self.predictor = SamPredictor(model)
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.box_nms_thresh = box_nms_thresh
+        self.crop_n_layers = crop_n_layers
+        self.crop_nms_thresh = crop_nms_thresh
+        self.crop_overlap_ratio = crop_overlap_ratio
+        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self.min_mask_region_area = min_mask_region_area
+        self.output_mode = output_mode
+
+    @torch.no_grad()
+    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+        """
+        Generates masks for the given image.
+
+        Arguments:
+          image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+        Returns:
+          list(dict(str, any)): A list over records for masks. Each record is
+            a dict containing the following keys:
+              segmentation (dict(str, any) or np.ndarray): The mask. If
+                output_mode='binary_mask', is an array of shape HW. Otherwise,
+                is a dictionary containing the RLE.
+              bbox (list(float)): The box around the mask, in XYWH format.
+              area (int): The area in pixels of the mask.
+              predicted_iou (float): The model's own prediction of the mask's
+                quality. This is filtered by the pred_iou_thresh parameter.
+              point_coords (list(list(float))): The point coordinates input
+                to the model to generate this mask.
+              stability_score (float): A measure of the mask's quality. This
+                is filtered on using the stability_score_thresh parameter.
+              crop_box (list(float)): The crop of the image used to generate
+                the mask, given in XYWH format.
+        """
+
+        # Generate masks
+        mask_data = self._generate_masks(image)
+
+        # Filter small disconnected regions and holes in masks
+        if self.min_mask_region_area > 0:
+            mask_data = self.postprocess_small_regions(
+                mask_data,
+                self.min_mask_region_area,
+                max(self.box_nms_thresh, self.crop_nms_thresh),
+            )
+
+        # Encode masks
+        if self.output_mode == "coco_rle":
+            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+        elif self.output_mode == "binary_mask":
+            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+        else:
+            mask_data["segmentations"] = mask_data["rles"]
+
+        # Write mask records
+        curr_anns = []
+        for idx in range(len(mask_data["segmentations"])):
+            ann = {
+                "segmentation": mask_data["segmentations"][idx],
+                "area": area_from_rle(mask_data["rles"][idx]),
+                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+                "predicted_iou": mask_data["iou_preds"][idx].item(),
+                "point_coords": [mask_data["points"][idx].tolist()],
+                "stability_score": mask_data["stability_score"][idx].item(),
+                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+            }
+            curr_anns.append(ann)
+
+        return curr_anns
+
+    def _generate_masks(self, image: np.ndarray) -> MaskData:
+        orig_size = image.shape[:2]
+        crop_boxes, layer_idxs = generate_crop_boxes(
+            orig_size, self.crop_n_layers, self.crop_overlap_ratio
+        )
+
+        # Iterate over image crops
+        data = MaskData()
+        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            data.cat(crop_data)
+
+        # Remove duplicate masks between crops
+        if len(crop_boxes) > 1:
+            # Prefer masks from smaller crops
+            scores = 1 / box_area(data["crop_boxes"])
+            scores = scores.to(data["boxes"].device)
+            keep_by_nms = batched_nms(
+                data["boxes"].float(),
+                scores,
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
+                iou_threshold=self.crop_nms_thresh,
+            )
+            data.filter(keep_by_nms)
+
+        data.to_numpy()
+        return data
+
+    def _process_crop(
+        self,
+        image: np.ndarray,
+        crop_box: List[int],
+        crop_layer_idx: int,
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        # Crop the image and calculate embeddings
+        x0, y0, x1, y1 = crop_box
+        cropped_im = image[y0:y1, x0:x1, :]
+        cropped_im_size = cropped_im.shape[:2]
+        self.predictor.set_image(cropped_im)
+
+        # Get points for this crop
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+        points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+        # Generate masks for this crop in batches
+        data = MaskData()
+        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+            data.cat(batch_data)
+            del batch_data
+        self.predictor.reset_image()
+
+        # Remove duplicates within this crop.
+        keep_by_nms = batched_nms(
+            data["boxes"].float(),
+            data["iou_preds"],
+            torch.zeros_like(data["boxes"][:, 0]),  # categories
+            iou_threshold=self.box_nms_thresh,
+        )
+        data.filter(keep_by_nms)
+
+        # Return to the original image frame
+        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+        data["points"] = uncrop_points(data["points"], crop_box)
+        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+        return data
+
+    def _process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
+
+        # Run model on this batch
+        transformed_points = self.predictor.transform.apply_coords(points, im_size)
+        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        masks, iou_preds, _ = self.predictor.predict_torch(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=True,
+            return_logits=True,
+        )
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+        )
+        del masks
+
+        # Filter by predicted IoU
+        if self.pred_iou_thresh > 0.0:
+            keep_mask = data["iou_preds"] > self.pred_iou_thresh
+            data.filter(keep_mask)
+
+        # Calculate stability score
+        data["stability_score"] = calculate_stability_score(
+            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+        )
+        if self.stability_score_thresh > 0.0:
+            keep_mask = data["stability_score"] >= self.stability_score_thresh
+            data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        # Filter boxes that touch crop boundaries
+        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+        if not torch.all(keep_mask):
+            data.filter(keep_mask)
+
+        # Compress to RLE
+        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+        del data["masks"]
+
+        return data
+
+    @staticmethod
+    def postprocess_small_regions(
+        mask_data: MaskData, min_area: int, nms_thresh: float
+    ) -> MaskData:
+        """
+        Removes small disconnected regions and holes in masks, then reruns
+        box NMS to remove any new duplicates.
+
+        Edits mask_data in place.
+
+        Requires open-cv as a dependency.
+        """
+        if len(mask_data["rles"]) == 0:
+            return mask_data
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for rle in mask_data["rles"]:
+            mask = rle_to_mask(rle)
+
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and score=1 to unchanged masks
+            # so NMS will prefer ones that didn't need postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(masks)
+        keep_by_nms = batched_nms(
+            boxes.float(),
+            torch.as_tensor(scores),
+            torch.zeros_like(boxes[:, 0]),  # categories
+            iou_threshold=nms_thresh,
+        )
+
+        # Only recalculate RLEs for masks that have changed
+        for i_mask in keep_by_nms:
+            if scores[i_mask] == 0.0:
+                mask_torch = masks[i_mask].unsqueeze(0)
+                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
+        mask_data.filter(keep_by_nms)
+
+        return mask_data
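SamAutomaticMaskGenerator is exported by models/__init__.py but not exercised by app.py. For reference, a minimal usage sketch, assuming a standard SAM ViT-B checkpoint (the path is hypothetical); generate() prompts the model with a point grid and returns one record per kept mask:

import numpy as np
from models import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"]("sam_vit_b.pth")  # hypothetical checkpoint path
generator = SamAutomaticMaskGenerator(sam, points_per_side=16, pred_iou_thresh=0.9)
image_rgb = np.zeros((480, 640, 3), dtype=np.uint8)  # any HWC uint8 image
records = generator.generate(image_rgb)
for rec in records:
    print(rec["bbox"], rec["area"], rec["predicted_iou"])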
models/build_sam.py
ADDED
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_b(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+
+
+sam_model_registry = {
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
+    "vit_l": build_sam_vit_l,
+    "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    sam = Sam(
+        image_encoder=ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        sam.load_state_dict(state_dict)
+    return sam
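The registry above maps backbone names to builder functions, which is what lets app.py select a model with a single string. A quick sketch (weights are randomly initialized when checkpoint=None, exactly as app.py relies on before loading its own state dict):

from models.build_sam import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint=None)  # randomly initialized ViT-B SAM
n_params = sum(p.numel() for p in sam.parameters())
print(f"vit_b SAM: {n_params / 1e6:.1f}M parameters")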
models/grasp_mods.py
ADDED
@@ -0,0 +1,318 @@
+"""
+Add additional grasp decoder for Segment Anything model.
+The structure should follow the grasp decoder structure in GraspDETR.
+"""
+import torch
+import torch.nn as nn
+from transformers.models.detr.configuration_detr import DetrConfig
+from transformers.models.detr.modeling_detr import DetrHungarianMatcher, DetrLoss, DetrSegmentationOutput, DetrDecoder, sigmoid_focal_loss, dice_loss
+from typing import Any, Dict, List, Tuple
+from transformers.models.detr.modeling_detr import generalized_box_iou
+from transformers.image_transforms import center_to_corners_format
+from scipy.optimize import linear_sum_assignment
+
+def modify_matcher_forward(self):
+    @torch.no_grad()
+    def matcher_forward(outputs, targets):
+
+        batch_size, num_queries = outputs["logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        target_ids = torch.cat([v["class_labels"] for v in targets])
+        target_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+        # but approximate it as 1 - proba[target class].
+        # The 1 is a constant that doesn't change the matching, so it can be omitted.
+        class_cost = -out_prob[:, target_ids]
+
+        # Compute the L1 cost between boxes
+        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
+
+        # Compute the giou cost between boxes
+        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4]))
+
+        # Final cost matrix
+        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
+        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
+
+        sizes = [len(v["boxes"]) for v in targets]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+    return matcher_forward
+
+def modify_grasp_loss_forward(self):
+    def modified_loss_labels(outputs, targets, indices, num_boxes):
+        """
+        Classification loss (NLL): targets dicts must contain the key "class_labels" containing a tensor of dim
+        [nb_target_boxes]
+        """
+        num_classes = 1  # model v9 always uses class-agnostic grasp
+        if "logits" not in outputs:
+            raise KeyError("No logits were found in the outputs")
+        source_logits = outputs["logits"]
+
+        idx = self._get_source_permutation_idx(indices)
+        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(
+            source_logits.shape[:2], num_classes, dtype=torch.int64, device=source_logits.device
+        )
+        target_classes[idx] = target_classes_o
+
+        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes)
+        losses = {"loss_ce": loss_ce}
+
+        return losses
+
+    def modified_loss_boxes(outputs, targets, indices, num_boxes):
+
+        if "pred_boxes" not in outputs:
+            raise KeyError("No predicted boxes found in outputs")
+        idx = self._get_source_permutation_idx(indices)
+        source_boxes = outputs["pred_boxes"][idx]
+        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+
+        losses = {}
+        losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(
+            generalized_box_iou(center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4]))
+        )
+        losses["loss_giou"] = loss_giou.sum() / num_boxes
+        return losses
+    return modified_loss_labels, modified_loss_boxes
+
+def modify_forward(self):
+    """
+    Modify the following methods to make SAM perform grasp detection after segmentation:
+    1. Add a parallel decoder for grasping detection: 1(+1) classes, 5 values to regress (bbox & rotation)
+    Returns:
+        Modified model
+    """
+    # 1. We instantiate a new module in self.base_model, as another decoder
+    self.grasp_decoder_config = DetrConfig()
+    self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
+    self.grasp_query_position_embeddings = nn.Embedding(20, 256).to(self.device)
+    # 2. Base model forward method is not directly used, no modification needs to be done
+    # self.detr.model.forward = modify_base_model_forward(self.detr.model)
+    # 3. Add additional classification head & bbox regression head for grasp_decoder output
+    self.grasp_predictor = torch.nn.Sequential(
+        torch.nn.Linear(256, 256),
+        torch.nn.Linear(256, 256),
+        torch.nn.Linear(256, 5)
+    ).to(self.device)
+    self.grasp_label_classifier = torch.nn.Linear(256, 2).to(self.device)
+    # 4. Add positional embedding
+    # name it as grasp_img_pos_embed to avoid name conflict
+    class ImagePosEmbed(nn.Module):
+        def __init__(self, img_size=64, hidden_dim=256):
+            super().__init__()
+            self.pos_embed = nn.Parameter(
+                torch.randn(1, img_size, img_size, hidden_dim)
+            )
+
+        def forward(self, x):
+            return x + self.pos_embed
+
+    self.grasp_img_pos_embed = ImagePosEmbed().to(self.device)
+
+    def modified_forward(
+        batched_input: List[Dict[str, Any]],
+        multimask_output: bool,
+    ):
+        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
+        image_embeddings = self.image_encoder(input_images)
+
+        outputs = []
+        srcs = []
+        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+            if "point_coords" in image_record:
+                points = (image_record["point_coords"], image_record["point_labels"])
+            else:
+                points = None
+            sparse_embeddings, dense_embeddings = self.prompt_encoder(
+                points=points,
+                boxes=image_record.get("boxes", None),
+                masks=image_record.get("mask_inputs", None),
+            )
+            low_res_masks, iou_predictions, src = self.mask_decoder(
+                image_embeddings=curr_embedding.unsqueeze(0),
+                image_pe=self.prompt_encoder.get_dense_pe(),
+                sparse_prompt_embeddings=sparse_embeddings,
+                dense_prompt_embeddings=dense_embeddings,
+                multimask_output=multimask_output,
+            )
+            outputs.append(
+                {
+                    "iou_predictions": iou_predictions,
+                    "low_res_logits": low_res_masks,
+                }
+            )
+            srcs.append(src[0])
+        srcs = torch.stack(srcs, dim=0)
+        # forward grasp decoder here
+        # 1. Get encoder hidden states
+        grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
+        # 2. Get query embeddings
+        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
+        # repeat to batch size
+        grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
+        grasp_decoder_outputs = self.grasp_decoder(
+            inputs_embeds=torch.zeros_like(grasp_query_pe),
+            attention_mask=None,
+            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
+            query_position_embeddings=grasp_query_pe,
+            encoder_hidden_states=grasp_encoder_hidden_states,
+            encoder_attention_mask=None,
+            output_attentions=False,
+            output_hidden_states=False,
+            return_dict=True,
+        )
+        grasp_sequence_output = grasp_decoder_outputs[0]
+        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
+        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
+
+        # 3. Calculate loss
+        loss, loss_dict = 0, {}
+        if "grasp_labels" in batched_input[0]:
+            config = self.grasp_decoder_config
+            grasp_labels = [{
+                "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
+                "boxes": x["grasp_labels"],
+            } for x in batched_input]
+            # First: create the matcher
+            matcher = DetrHungarianMatcher(
+                class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
+            )
+            matcher.forward = modify_matcher_forward(matcher)
+            # Second: create the criterion
+            losses = ["labels", "boxes"]
+            criterion = DetrLoss(
+                matcher=matcher,
+                num_classes=config.num_labels,
+                eos_coef=config.eos_coefficient,
+                losses=losses,
+            )
+            criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
+            criterion.to(self.device)
+            # Third: compute the losses, based on outputs and labels
+            outputs_loss = {}
+            outputs_loss["logits"] = grasp_logits
+            outputs_loss["pred_boxes"] = pred_grasps
+
+            grasp_loss_dict = criterion(outputs_loss, grasp_labels)
+            # Fourth: compute total loss, as a weighted sum of the various losses
+            weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
+            weight_dict["loss_giou"] = config.giou_loss_coefficient
+            if config.auxiliary_loss:
+                aux_weight_dict = {}
+                for i in range(config.decoder_layers - 1):
+                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+                weight_dict.update(aux_weight_dict)
+            grasp_loss = sum(grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict)
+
+            # merge grasp branch loss into variable loss & loss_dict
+            loss += grasp_loss
+            loss_dict.update(grasp_loss_dict)
+        pred_masks = self.postprocess_masks(
+            torch.cat([x['low_res_logits'] for x in outputs], dim=0),
+            input_size=image_record["image"].shape[-2:],
+            original_size=(1024, 1024),
+        )
+        if 'masks' in batched_input[0]:
+            # 4. Calculate segmentation loss
+            sf_loss = sigmoid_focal_loss(pred_masks.flatten(1),
+                                         torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1), len(batched_input))
+            d_loss = dice_loss(pred_masks.flatten(1),
+                               torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32).flatten(1), len(batched_input))
+            loss += sf_loss + d_loss
+            loss_dict["sf_loss"] = sf_loss
+            loss_dict["d_loss"] = d_loss
+        return DetrSegmentationOutput(
+            loss=loss,
+            loss_dict=loss_dict,
+            logits=grasp_logits,
+            pred_boxes=pred_grasps,
+            pred_masks=pred_masks,
+        )
+
+    return modified_forward
+
+def add_inference_method(self):
+    def infer(
+        batched_input: List[Dict[str, Any]],
+        multimask_output: bool,
+    ):
+        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
+        image_embeddings = self.image_encoder(input_images)
+
+        outputs = []
+        srcs = []
+        curr_embedding = image_embeddings[0]
+        image_record = batched_input[0]
+
+        if "point_coords" in image_record:
+            points = (image_record["point_coords"], image_record["point_labels"])
+        else:
+            points = None
+        sparse_embeddings, dense_embeddings = self.prompt_encoder(
+            points=points,
+            boxes=image_record.get("boxes", None),
+            masks=image_record.get("mask_inputs", None),
+        )
+        low_res_masks, iou_predictions, src = self.mask_decoder(
+            image_embeddings=curr_embedding.unsqueeze(0),
+            image_pe=self.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+        outputs.append(
+            {
+                "iou_predictions": iou_predictions,
+                "low_res_logits": low_res_masks,
+            }
+        )
+        srcs.append(src[0])
+
+        n_queries = iou_predictions.size(0)
+
+        # forward grasp decoder here
+        # 1. Get encoder hidden states
+        grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
+        # 2. Get query embeddings
+        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
+        # repeat to batch size
+        grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
+        grasp_decoder_outputs = self.grasp_decoder(
+            inputs_embeds=torch.zeros_like(grasp_query_pe),
+            attention_mask=None,
+            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
+            query_position_embeddings=grasp_query_pe,
+            encoder_hidden_states=grasp_encoder_hidden_states,
+            encoder_attention_mask=None,
+            output_attentions=False,
+            output_hidden_states=False,
+            return_dict=True,
+        )
+        grasp_sequence_output = grasp_decoder_outputs[0]
+        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
+        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
+        pred_masks = self.postprocess_masks(
+            torch.cat([x['low_res_logits'] for x in outputs], dim=0),
+            input_size=image_record["image"].shape[-2:],
+            original_size=(1024, 1024),
+        )
+        return DetrSegmentationOutput(
+            loss=0,
+            loss_dict={},
+            logits=grasp_logits,
+            pred_boxes=pred_grasps,
+            pred_masks=pred_masks,
+        )
+    return infer
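How the two patches are meant to be wired together (this mirrors app.py; the dummy input below is illustrative only, and note the mask decoder in this repo is modified to also return its pre-upscaling features src, which the grasp decoder consumes):

import torch
from models import sam_model_registry
from models.grasp_mods import modify_forward, add_inference_method

sam = sam_model_registry["vit_b"]()
sam.forward = modify_forward(sam)      # creates grasp decoder, heads, pos. embedding
sam.infer = add_inference_method(sam)  # box-prompted, single-image inference path

dummy = [{"image": torch.zeros(3, 1024, 1024),
          "boxes": torch.tensor([[100.0, 100.0, 400.0, 400.0]])}]
with torch.no_grad():
    out = sam.infer(dummy, multimask_output=False)
# out.pred_boxes holds the 20 five-value grasp encodings per box prompt,
# out.logits their scores, out.pred_masks the segmentation logits
print(out.logits.shape, out.pred_boxes.shape, out.pred_masks.shape)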
models/modeling/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
models/modeling/common.py
ADDED
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
models/modeling/image_encoder.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
from typing import Optional, Tuple, Type
|
12 |
+
|
13 |
+
from .common import LayerNorm2d, MLPBlock
|
14 |
+
|
15 |
+
|
16 |
+
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
17 |
+
class ImageEncoderViT(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
img_size: int = 1024,
|
21 |
+
patch_size: int = 16,
|
22 |
+
in_chans: int = 3,
|
23 |
+
embed_dim: int = 768,
|
24 |
+
depth: int = 12,
|
25 |
+
num_heads: int = 12,
|
26 |
+
mlp_ratio: float = 4.0,
|
27 |
+
out_chans: int = 256,
|
28 |
+
qkv_bias: bool = True,
|
29 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
30 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
31 |
+
use_abs_pos: bool = True,
|
32 |
+
use_rel_pos: bool = False,
|
33 |
+
rel_pos_zero_init: bool = True,
|
34 |
+
window_size: int = 0,
|
35 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
36 |
+
) -> None:
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
img_size (int): Input image size.
|
40 |
+
patch_size (int): Patch size.
|
41 |
+
in_chans (int): Number of input image channels.
|
42 |
+
embed_dim (int): Patch embedding dimension.
|
43 |
+
depth (int): Depth of ViT.
|
44 |
+
num_heads (int): Number of attention heads in each ViT block.
|
45 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
46 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
47 |
+
norm_layer (nn.Module): Normalization layer.
|
48 |
+
act_layer (nn.Module): Activation layer.
|
49 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
50 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
51 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
52 |
+
window_size (int): Window size for window attention blocks.
|
53 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
54 |
+
"""
|
55 |
+
super().__init__()
|
56 |
+
self.img_size = img_size
|
57 |
+
|
58 |
+
self.patch_embed = PatchEmbed(
|
59 |
+
kernel_size=(patch_size, patch_size),
|
60 |
+
stride=(patch_size, patch_size),
|
61 |
+
in_chans=in_chans,
|
62 |
+
embed_dim=embed_dim,
|
63 |
+
)
|
64 |
+
|
65 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
66 |
+
if use_abs_pos:
|
67 |
+
# Initialize absolute positional embedding with pretrain image size.
|
68 |
+
self.pos_embed = nn.Parameter(
|
69 |
+
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
|
70 |
+
)
|
71 |
+
|
72 |
+
self.blocks = nn.ModuleList()
|
73 |
+
for i in range(depth):
|
74 |
+
block = Block(
|
75 |
+
dim=embed_dim,
|
76 |
+
num_heads=num_heads,
|
77 |
+
mlp_ratio=mlp_ratio,
|
78 |
+
qkv_bias=qkv_bias,
|
79 |
+
norm_layer=norm_layer,
|
80 |
+
act_layer=act_layer,
|
81 |
+
use_rel_pos=use_rel_pos,
|
82 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
83 |
+
window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        for blk in self.blocks:
            x = blk(x)

        x = self.neck(x.permute(0, 3, 1, 2))

        return x


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x
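
def _attention_shape_sketch() -> None:
    """Illustrative sketch, not part of the uploaded file: the Attention block
    above is shape-preserving over a B x H x W x C token grid. All sizes here
    are assumed example values, not ones the model necessarily uses."""
    attn = Attention(dim=96, num_heads=8, use_rel_pos=True, input_size=(14, 14))
    x = torch.randn(4, 14, 14, 96)  # e.g. four 14x14 windows of 96-dim tokens
    y = attn(x)
    assert y.shape == x.shape  # qkv, rel-pos terms, softmax, and proj keep the grid shape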

def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn
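
def _window_roundtrip_sketch() -> None:
    """Illustrative sketch, not part of the uploaded file: window_partition pads
    H and W up to multiples of window_size, and window_unpartition crops the
    padding back off, so the pair round-trips exactly. Sizes are example values."""
    x = torch.randn(2, 50, 70, 96)  # B x H x W x C; 50 is not a multiple of 14
    windows, (Hp, Wp) = window_partition(x, 14)  # pads to Hp=56, Wp=70
    assert windows.shape == (2 * (Hp // 14) * (Wp // 14), 14, 14, 96)
    y = window_unpartition(windows, 14, (Hp, Wp), (50, 70))
    assert torch.equal(x, y)  # zero padding is removed without loss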

class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
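A minimal usage sketch for PatchEmbed (illustrative only; the 1024-pixel input and 16-pixel patches are the usual SAM defaults assumed here, not values read from this upload):

    patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768)
    img = torch.randn(1, 3, 1024, 1024)  # B x C x H x W image batch
    tokens = patch_embed(img)            # permuted to B x H/16 x W/16 x embed_dim
    assert tokens.shape == (1, 64, 64, 768)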
models/modeling/mask_decoder.py
ADDED
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
          torch.Tensor: the transformer-updated image embedding (src)
        """
        masks, iou_pred, src = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred, src

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred, src


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
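A hedged wiring sketch for the decoder above (illustrative; image_emb, sparse_emb, dense_emb, and pe are placeholder tensors/modules, and TwoWayTransformer comes from transformer.py in this upload). Note the local modification relative to upstream SAM: forward returns a third value, the transformer-updated image embedding src:

    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048),
    )
    masks, iou_pred, src = decoder(
        image_embeddings=image_emb,           # 1 x 256 x 64 x 64 from the image encoder
        image_pe=pe.get_dense_pe(),           # matching dense positional encoding
        sparse_prompt_embeddings=sparse_emb,  # B x N x 256 point/box tokens
        dense_prompt_embeddings=dense_emb,    # B x 256 x 64 x 64 mask embedding
        multimask_output=True,                # keep mask channels 1..3, drop channel 0
    )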
models/modeling/prompt_encoder.py
ADDED
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

from .common import LayerNorm2d


class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int],
        input_image_size: Tuple[int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (tuple(int, int)): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings


class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """

    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W

    def forward_with_coords(
        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
    ) -> torch.Tensor:
        """Positionally encode points that are not normalized to [0,1]."""
        coords = coords_input.clone()
        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
        return self._pe_encoding(coords.to(torch.float))  # B x N x C
models/modeling/sam.py
ADDED
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List, Tuple

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            # The modified MaskDecoder in this repo also returns the
            # transformer-updated image embedding (src), unused here.
            low_res_masks, iou_predictions, _ = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
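A hedged end-to-end sketch for Sam.forward (illustrative; sam, image_1024, and boxes_1024 are placeholders for a built model and prompts already transformed into its 1024-pixel input frame, e.g. via ResizeLongestSide from models/utils/transforms.py):

    batched_input = [
        {
            "image": image_1024,          # 3 x H x W tensor, already resized
            "original_size": (480, 640),  # (H, W) before the resize
            "boxes": boxes_1024,          # B x 4 XYXY boxes in the resized frame
        }
    ]
    outputs = sam(batched_input, multimask_output=False)
    masks = outputs[0]["masks"]           # B x 1 x 480 x 640 boolean masks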
models/modeling/transformer.py
ADDED
@@ -0,0 +1,240 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import Tensor, nn

import math
from typing import Tuple, Type

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
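A quick shape sketch for TwoWayTransformer (illustrative; the sizes are the SAM-style defaults assumed throughout this upload):

    t = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
    img = torch.randn(1, 256, 64, 64)  # B x C x h x w image embedding
    pe = torch.randn(1, 256, 64, 64)   # matching positional encoding
    tokens = torch.randn(1, 7, 256)    # e.g. 1 IoU token + 4 mask tokens + 2 box corners
    queries, keys = t(image_embedding=img, image_pe=pe, point_embedding=tokens)
    assert queries.shape == (1, 7, 256)   # updated tokens
    assert keys.shape == (1, 4096, 256)   # updated, flattened image embedding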
models/predictor.py
ADDED
@@ -0,0 +1,269 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from models.modeling import Sam
|
11 |
+
|
12 |
+
from typing import Optional, Tuple
|
13 |
+
|
14 |
+
from .utils.transforms import ResizeLongestSide
|
15 |
+
|
16 |
+
|
17 |
+
class SamPredictor:
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
sam_model: Sam,
|
21 |
+
) -> None:
|
22 |
+
"""
|
23 |
+
Uses SAM to calculate the image embedding for an image, and then
|
24 |
+
allow repeated, efficient mask prediction given prompts.
|
25 |
+
|
26 |
+
Arguments:
|
27 |
+
sam_model (Sam): The model to use for mask prediction.
|
28 |
+
"""
|
29 |
+
super().__init__()
|
30 |
+
self.model = sam_model
|
31 |
+
self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
|
32 |
+
self.reset_image()
|
33 |
+
|
34 |
+
def set_image(
|
35 |
+
self,
|
36 |
+
image: np.ndarray,
|
37 |
+
image_format: str = "RGB",
|
38 |
+
) -> None:
|
39 |
+
"""
|
40 |
+
Calculates the image embeddings for the provided image, allowing
|
41 |
+
masks to be predicted with the 'predict' method.
|
42 |
+
|
43 |
+
Arguments:
|
44 |
+
image (np.ndarray): The image for calculating masks. Expects an
|
45 |
+
image in HWC uint8 format, with pixel values in [0, 255].
|
46 |
+
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
47 |
+
"""
|
48 |
+
assert image_format in [
|
49 |
+
"RGB",
|
50 |
+
"BGR",
|
51 |
+
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
52 |
+
if image_format != self.model.image_format:
|
53 |
+
image = image[..., ::-1]
|
54 |
+
|
55 |
+
# Transform the image to the form expected by the model
|
56 |
+
input_image = self.transform.apply_image(image)
|
57 |
+
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
58 |
+
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
59 |
+
|
60 |
+
self.set_torch_image(input_image_torch, image.shape[:2])
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def set_torch_image(
|
64 |
+
self,
|
65 |
+
transformed_image: torch.Tensor,
|
66 |
+
original_image_size: Tuple[int, ...],
|
67 |
+
) -> None:
|
68 |
+
"""
|
69 |
+
Calculates the image embeddings for the provided image, allowing
|
70 |
+
masks to be predicted with the 'predict' method. Expects the input
|
71 |
+
image to be already transformed to the format expected by the model.
|
72 |
+
|
73 |
+
Arguments:
|
74 |
+
transformed_image (torch.Tensor): The input image, with shape
|
75 |
+
1x3xHxW, which has been transformed with ResizeLongestSide.
|
76 |
+
original_image_size (tuple(int, int)): The size of the image
|
77 |
+
before transformation, in (H, W) format.
|
78 |
+
"""
|
79 |
+
assert (
|
80 |
+
len(transformed_image.shape) == 4
|
81 |
+
and transformed_image.shape[1] == 3
|
82 |
+
and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
|
83 |
+
), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
|
84 |
+
self.reset_image()
|
85 |
+
|
86 |
+
self.original_size = original_image_size
|
87 |
+
self.input_size = tuple(transformed_image.shape[-2:])
|
88 |
+
input_image = self.model.preprocess(transformed_image)
|
89 |
+
self.features = self.model.image_encoder(input_image)
|
90 |
+
self.is_image_set = True
|
91 |
+
|
92 |
+
def predict(
|
93 |
+
self,
|
94 |
+
point_coords: Optional[np.ndarray] = None,
|
95 |
+
point_labels: Optional[np.ndarray] = None,
|
96 |
+
box: Optional[np.ndarray] = None,
|
97 |
+
mask_input: Optional[np.ndarray] = None,
|
98 |
+
multimask_output: bool = True,
|
99 |
+
return_logits: bool = False,
|
100 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
101 |
+
"""
|
102 |
+
Predict masks for the given input prompts, using the currently set image.
|
103 |
+
|
104 |
+
Arguments:
|
105 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
106 |
+
model. Each point is in (X,Y) in pixels.
|
107 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
108 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
109 |
+
background point.
|
110 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
111 |
+
model, in XYXY format.
|
112 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
113 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
114 |
+
for SAM, H=W=256.
|
115 |
+
multimask_output (bool): If true, the model will return three masks.
|
116 |
+
For ambiguous input prompts (such as a single click), this will often
|
117 |
+
produce better masks than a single prediction. If only a single
|
118 |
+
mask is needed, the model's predicted quality score can be used
|
119 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
120 |
+
input prompts, multimask_output=False can give better results.
|
121 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
122 |
+
instead of a binary mask.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
126 |
+
number of masks, and (H, W) is the original image size.
|
127 |
+
(np.ndarray): An array of length C containing the model's
|
128 |
+
predictions for the quality of each mask.
|
129 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
130 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
131 |
+
a subsequent iteration as mask input.
|
132 |
+
"""
|
133 |
+
if not self.is_image_set:
|
134 |
+
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
135 |
+
|
136 |
+
# Transform input prompts
|
137 |
+
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
|
138 |
+
if point_coords is not None:
|
139 |
+
assert (
|
140 |
+
point_labels is not None
|
141 |
+
), "point_labels must be supplied if point_coords is supplied."
|
142 |
+
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
143 |
+
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
144 |
+
labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
|
145 |
+
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
146 |
+
if box is not None:
|
147 |
+
box = self.transform.apply_boxes(box, self.original_size)
|
148 |
+
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
149 |
+
box_torch = box_torch[None, :]
|
150 |
+
if mask_input is not None:
|
151 |
+
mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
|
152 |
+
mask_input_torch = mask_input_torch[None, :, :, :]
|
153 |
+
|
154 |
+
masks, iou_predictions, low_res_masks = self.predict_torch(
|
155 |
+
coords_torch,
|
156 |
+
labels_torch,
|
157 |
+
box_torch,
|
158 |
+
mask_input_torch,
|
159 |
+
multimask_output,
|
160 |
+
return_logits=return_logits,
|
161 |
+
)
|
162 |
+
|
163 |
+
masks_np = masks[0].detach().cpu().numpy()
|
164 |
+
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
|
165 |
+
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
|
166 |
+
return masks_np, iou_predictions_np, low_res_masks_np
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def predict_torch(
|
170 |
+
self,
|
171 |
+
point_coords: Optional[torch.Tensor],
|
172 |
+
point_labels: Optional[torch.Tensor],
|
173 |
+
boxes: Optional[torch.Tensor] = None,
|
174 |
+
mask_input: Optional[torch.Tensor] = None,
|
175 |
+
multimask_output: bool = True,
|
176 |
+
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
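The predictor above keeps upstream SAM's two-stage flow: `set_image` runs the image encoder once and caches the embedding, after which prompts are decoded cheaply. A minimal usage sketch, assuming this file exports upstream's `SamPredictor` name with its usual interface (`.transform`, `.set_image`, `.predict_torch`), that `build_sam.py` keeps the `sam_model_registry` mapping, and that `sam_vit_b.pth` and `example.jpg` are hypothetical local files:

```python
import cv2
import torch

from models import sam_model_registry
from models.predictor import SamPredictor  # assumed export, as in upstream SAM

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")  # hypothetical path
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)  # encodes once; later prompts reuse the cached features

# One XYXY box prompt, mapped into the 1024-long-side input frame first.
box = torch.as_tensor([[100, 100, 400, 400]], dtype=torch.float, device=predictor.device)
box = predictor.transform.apply_boxes_torch(box, image.shape[:2])

masks, scores, low_res = predictor.predict_torch(
    point_coords=None, point_labels=None, boxes=box, multimask_output=False
)
print(masks.shape, scores.shape)  # Bx1xHxW boolean masks, Bx1 quality scores
```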
models/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
models/utils/amg.py
ADDED
@@ -0,0 +1,346 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple


class MaskData:
    """
    A structure for storing masks and their related data in batched format.
    Implements basic filtering and concatenation.
    """

    def __init__(self, **kwargs) -> None:
        for v in kwargs.values():
            assert isinstance(
                v, (list, np.ndarray, torch.Tensor)
            ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats = dict(**kwargs)

    def __setitem__(self, key: str, item: Any) -> None:
        assert isinstance(
            item, (list, np.ndarray, torch.Tensor)
        ), "MaskData only supports list, numpy arrays, and torch tensors."
        self._stats[key] = item

    def __delitem__(self, key: str) -> None:
        del self._stats[key]

    def __getitem__(self, key: str) -> Any:
        return self._stats[key]

    def items(self) -> ItemsView[str, Any]:
        return self._stats.items()

    def filter(self, keep: torch.Tensor) -> None:
        for k, v in self._stats.items():
            if v is None:
                self._stats[k] = None
            elif isinstance(v, torch.Tensor):
                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
            elif isinstance(v, np.ndarray):
                self._stats[k] = v[keep.detach().cpu().numpy()]
            elif isinstance(v, list) and keep.dtype == torch.bool:
                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
            elif isinstance(v, list):
                self._stats[k] = [v[i] for i in keep]
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def cat(self, new_stats: "MaskData") -> None:
        for k, v in new_stats.items():
            if k not in self._stats or self._stats[k] is None:
                self._stats[k] = deepcopy(v)
            elif isinstance(v, torch.Tensor):
                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
            elif isinstance(v, np.ndarray):
                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
            elif isinstance(v, list):
                self._stats[k] = self._stats[k] + deepcopy(v)
            else:
                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")

    def to_numpy(self) -> None:
        for k, v in self._stats.items():
            if isinstance(v, torch.Tensor):
                self._stats[k] = v.detach().cpu().numpy()


def is_box_near_crop_edge(
    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
) -> torch.Tensor:
    """Filter masks at the edge of a crop, but not at the edge of the original image."""
    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
    return torch.any(near_crop_edge, dim=1)


def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
    box_xywh = deepcopy(box_xyxy)
    box_xywh[2] = box_xywh[2] - box_xywh[0]
    box_xywh[3] = box_xywh[3] - box_xywh[1]
    return box_xywh


def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
    assert len(args) > 0 and all(
        len(a) == len(args[0]) for a in args
    ), "Batched iteration must have inputs of all the same size."
    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
    for b in range(n_batches):
        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """
    Encodes masks to an uncompressed RLE, in the format expected by
    pycocotools.
    """
    # Put in fortran order and flatten h,w
    b, h, w = tensor.shape
    tensor = tensor.permute(0, 2, 1).flatten(1)

    # Compute change indices
    diff = tensor[:, 1:] ^ tensor[:, :-1]
    change_indices = diff.nonzero()

    # Encode run length
    out = []
    for i in range(b):
        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
        cur_idxs = torch.cat(
            [
                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
                cur_idxs + 1,
                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
            ]
        )
        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
        counts = [] if tensor[i, 0] == 0 else [0]
        counts.extend(btw_idxs.detach().cpu().tolist())
        out.append({"size": [h, w], "counts": counts})
    return out


def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
    """Compute a binary mask from an uncompressed RLE."""
    h, w = rle["size"]
    mask = np.empty(h * w, dtype=bool)
    idx = 0
    parity = False
    for count in rle["counts"]:
        mask[idx : idx + count] = parity
        idx += count
        parity ^= True
    mask = mask.reshape(w, h)
    return mask.transpose()  # Put in C order


def area_from_rle(rle: Dict[str, Any]) -> int:
    return sum(rle["counts"][1::2])
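As an aside, a small round-trip sketch of the RLE helpers just defined (self-contained; the tiny shapes are chosen purely for illustration):

```python
import torch
from models.utils.amg import mask_to_rle_pytorch, rle_to_mask, area_from_rle

# A 1x4x4 batch with a 2x2 foreground square.
mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 1:3, 1:3] = True

rles = mask_to_rle_pytorch(mask)      # [{"size": [4, 4], "counts": [...]}]
recovered = rle_to_mask(rles[0])      # back to a (4, 4) bool numpy array
assert (recovered == mask[0].numpy()).all()
print(area_from_rle(rles[0]))         # 4 foreground pixels
```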


def calculate_stability_score(
    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
) -> torch.Tensor:
    """
    Computes the stability score for a batch of masks. The stability
    score is the IoU between the binary masks obtained by thresholding
    the predicted mask logits at high and low values.
    """
    # One mask is always contained inside the other.
    # Save memory by preventing unnecessary cast to torch.int64
    intersections = (
        (masks > (mask_threshold + threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    unions = (
        (masks > (mask_threshold - threshold_offset))
        .sum(-1, dtype=torch.int16)
        .sum(-1, dtype=torch.int32)
    )
    return intersections / unions


def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points


def build_all_layer_point_grids(
    n_per_side: int, n_layers: int, scale_per_layer: int
) -> List[np.ndarray]:
    """Generates point grids for all crop layers."""
    points_by_layer = []
    for i in range(n_layers + 1):
        n_points = int(n_per_side / (scale_per_layer**i))
        points_by_layer.append(build_point_grid(n_points))
    return points_by_layer


def generate_crop_boxes(
    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
) -> Tuple[List[List[int]], List[int]]:
    """
    Generates a list of crop boxes of different sizes. Each layer
    has (2**i)**2 boxes for the ith layer.
    """
    crop_boxes, layer_idxs = [], []
    im_h, im_w = im_size
    short_side = min(im_h, im_w)

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYXY format
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    return crop_boxes, layer_idxs
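Another quick sketch, showing how the point-grid and crop-box helpers above compose for a 600x800 (HxW) image with one extra crop layer (the overlap ratio is an illustrative value):

```python
from models.utils.amg import build_all_layer_point_grids, generate_crop_boxes

# One extra crop layer: the full image plus a 2x2 grid of overlapping crops.
crop_boxes, layer_idxs = generate_crop_boxes((600, 800), n_layers=1, overlap_ratio=0.5)
print(crop_boxes[0], layer_idxs)  # [0, 0, 800, 600], then four layer-1 crops

# Matching per-layer point grids: 32x32 points on layer 0, 16x16 on layer 1.
grids = build_all_layer_point_grids(n_per_side=32, n_layers=1, scale_per_layer=2)
print([g.shape for g in grids])   # [(1024, 2), (256, 2)]
```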


def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
    # Check if boxes has a channel dimension
    if len(boxes.shape) == 3:
        offset = offset.unsqueeze(1)
    return boxes + offset


def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
    x0, y0, _, _ = crop_box
    offset = torch.tensor([[x0, y0]], device=points.device)
    # Check if points has a channel dimension
    if len(points.shape) == 3:
        offset = offset.unsqueeze(1)
    return points + offset


def uncrop_masks(
    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
) -> torch.Tensor:
    x0, y0, x1, y1 = crop_box
    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
        return masks
    # Coordinate transform masks
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    pad = (x0, pad_x - x0, y0, pad_y - y0)
    return torch.nn.functional.pad(masks, pad, value=0)


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
    from pycocotools import mask as mask_utils  # type: ignore

    h, w = uncompressed_rle["size"]
    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
    return rle


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
    """
    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
    """
    # torch.max below raises an error on empty inputs, just skip in this case
    if torch.numel(masks) == 0:
        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)

    # Normalize shape to CxHxW
    shape = masks.shape
    h, w = shape[-2:]
    if len(shape) > 2:
        masks = masks.flatten(0, -3)
    else:
        masks = masks.unsqueeze(0)

    # Get top and bottom edges
    in_height, _ = torch.max(masks, dim=-1)
    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
    in_height_coords = in_height_coords + h * (~in_height)
    top_edges, _ = torch.min(in_height_coords, dim=-1)

    # Get left and right edges
    in_width, _ = torch.max(masks, dim=-2)
    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
    right_edges, _ = torch.max(in_width_coords, dim=-1)
    in_width_coords = in_width_coords + w * (~in_width)
    left_edges, _ = torch.min(in_width_coords, dim=-1)

    # If the mask is empty the right edge will be to the left of the left edge.
    # Replace these boxes with [0, 0, 0, 0]
    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
    out = out * (~empty_filter).unsqueeze(-1)

    # Return to original shape
    if len(shape) > 2:
        out = out.reshape(*shape[:-2], 4)
    else:
        out = out[0]

    return out
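To close out this module, a minimal check of `batched_mask_to_box`'s XYXY output and its zero-box convention for empty masks (shapes are illustrative):

```python
import torch
from models.utils.amg import batched_mask_to_box

masks = torch.zeros(2, 8, 8, dtype=torch.bool)
masks[0, 2:5, 3:7] = True  # one occupied mask, one left empty

boxes = batched_mask_to_box(masks)
print(boxes[0].tolist())  # [3, 2, 6, 4]: XYXY box around the occupied pixels
print(boxes[1].tolist())  # [0, 0, 0, 0]: empty masks map to the zero box
```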
models/utils/onnx.py
ADDED
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple

from ..modeling import Sam
from .amg import calculate_stability_score


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export script
    for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks
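`SamOnnxModel` is intended to be driven by an export script rather than called directly. The sketch below mirrors the shape conventions of upstream SAM's `export_onnx_model.py`; the checkpoint path and output filename are placeholders, and the dummy sizes assume the standard 1024-pixel input with a 64x64 image embedding and 256x256 mask input:

```python
import torch
from models import sam_model_registry
from models.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")  # hypothetical path
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size  # (64, 64) for 1024 inputs
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(0, 1024, (1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(0, 4, (1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, 256, 256, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

with torch.no_grad():
    torch.onnx.export(
        onnx_model,
        tuple(dummy_inputs.values()),
        "sam_prompts.onnx",  # illustrative output name
        input_names=list(dummy_inputs.keys()),
        output_names=["masks", "iou_predictions", "low_res_masks"],
        dynamic_axes={
            "point_coords": {1: "num_points"},
            "point_labels": {1: "num_points"},
        },
        opset_version=17,
    )
```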
models/utils/transforms.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from copy import deepcopy
from typing import Tuple


class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
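A short sketch of `ResizeLongestSide`'s contract with a 1024 target and illustrative sizes; note that prompt coordinates are mapped with the original (H, W), not the resized one:

```python
import numpy as np
from models.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

image = np.zeros((600, 800, 3), dtype=np.uint8)      # HxWxC uint8
resized = transform.apply_image(image)
print(resized.shape)                                  # (768, 1024, 3): long side -> 1024

# Boxes scale by the same factor, 1024 / 800 = 1.28 on both axes here.
boxes = np.array([[100.0, 50.0, 400.0, 300.0]])
print(transform.apply_boxes(boxes, image.shape[:2]))  # [[128., 64., 512., 384.]]
```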
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
pycocotools
transformers
gradio_image_prompter-0.1.0-py3-none-any.whl
src/.gitignore
ADDED
@@ -0,0 +1,9 @@
.eggs/
dist/
*.pyc
__pycache__/
*.py[cod]
*$py.class
__tmp/*
*.pyi
node_modules
src/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
src/README.md
ADDED
@@ -0,0 +1,48 @@
# Image Prompter for Gradio
A gradio component to upload images and process point/box prompts.

This custom component is developed for the [Tokenize Anything](https://github.com/baaivision/tokenize-anything) gradio demo.

## Installation

### Preliminaries

``gradio`` >= 4.0.0

### Installing Package

```bash
pip install gradio-image-prompter
```

## Quick Start

### Development

```bash
cd gradio-image-prompter
gradio cc install
gradio cc dev
```

### Example

```python
import gradio as gr
from gradio_image_prompter import ImagePrompter

demo = gr.Interface(
    lambda prompts: (prompts["image"], prompts["points"]),
    ImagePrompter(show_label=False),
    [gr.Image(show_label=False), gr.Dataframe(label="Points")],
)
demo.launch()
```

## License
[Apache License 2.0](LICENSE)

## Acknowledgement

We thank the repositories: [SAM](https://github.com/facebookresearch/segment-anything), [GradioBox](https://github.com/ShoufaChen/gradio-box) and [Gradio](https://github.com/gradio-app/gradio).
src/backend/gradio_image_prompter/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .image_prompter import ImagePrompter

__all__ = ["ImagePrompter"]
src/backend/gradio_image_prompter/image_prompter.py
ADDED
@@ -0,0 +1,133 @@
# ------------------------------------------------------------------------
# Copyright (c) 2023-present, PhyscalX. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Gradio ``ImagePrompter`` component."""

from __future__ import annotations

from typing import Optional, List, TypedDict, Union, Literal

import numpy as np
import gradio
from gradio.data_classes import FileData, GradioModel
from gradio_client.documentation import document, set_documentation_group
from PIL import Image as _Image  # using _ to minimize namespace pollution

set_documentation_group("component")


class PromptData(GradioModel):
    image: FileData
    points: List[List[float]]


class PromptValue(TypedDict):
    image: Optional[Union[np.ndarray, _Image.Image, str]]
    points: Optional[List[List[float]]]


@document()
class ImagePrompter(gradio.Image):
    """Create an image prompter to upload images and process point/box prompts."""

    data_model = PromptData

    def __init__(
        self,
        value: str | _Image.Image | np.ndarray | None = None,
        *,
        height: int | None = None,
        width: int | None = None,
        image_mode: Literal[
            "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"
        ] = "RGB",
        sources: list[Literal["upload", "clipboard"]] | None = None,
        type: Literal["numpy", "pil", "filepath"] = "numpy",
        label: str | None = None,
        every: float | None = None,
        show_label: bool | None = None,
        show_download_button: bool = True,
        container: bool = True,
        scale: int | None = None,
        min_width: int = 160,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        elem_classes: list[str] | str | None = None,
        render: bool = True,
        show_share_button: bool | None = None,
    ):
        """
        Parameters:
            value: A PIL Image, numpy array, path or URL for the default value. If callable, it will be called to set the initial value.
            height: Height of the displayed image in pixels.
            width: Width of the displayed image in pixels.
            image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html.
            sources: List of sources for the image.
            type: The format the image is converted to before being passed into the prediction function.
            label: The label for this component.
            every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open.
            show_label: If True, will display label.
            show_download_button: If True, will display button to download image.
            container: If True, will place the component in a container - providing some extra padding around the border.
            scale: Relative width compared to adjacent Components in a Row. Should be an integer.
            min_width: Minimum pixel width, will wrap if not sufficient screen space to satisfy this value.
            interactive: If True, will allow users to upload and edit an image; if False, can only be used to display images.
            visible: If False, component will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM.
            elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM.
            render: If False, component will not be rendered in the Blocks context.
            show_share_button: If True, show a share icon that allows user to share outputs to Hugging Face Spaces Discussions.
        """
        super(ImagePrompter, self).__init__(
            value=value,
            height=height,
            width=width,
            image_mode=image_mode,
            sources=["upload", "clipboard"] if sources is None else sources,
            type=type,
            label=label,
            every=every,
            show_label=show_label,
            show_download_button=show_download_button,
            container=container,
            scale=scale,
            min_width=min_width,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            show_share_button=show_share_button,
        )

    def preprocess(self, x: PromptData) -> PromptValue | None:
        if x is None:
            return x
        im = super().preprocess(x.image)
        return {"image": im, "points": x.points}

    def postprocess(self, y: PromptValue) -> PromptData | None:
        if y is None:
            return None
        image, points = y.get("image", None), y.get("points", [])
        return PromptData(image=super().postprocess(image), points=points)

    def as_example(self, y: PromptValue) -> str | None:
        if y is None:
            return None
        return self.move_resource_to_block_cache(y.get("image", None))
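A minimal Blocks sketch of wiring the component up (component and callback names are illustrative; the `points` payload is whatever the frontend recorded for the drawn points and boxes):

```python
import gradio as gr
from gradio_image_prompter import ImagePrompter

def inspect(prompts):
    # prompts is the PromptValue dict above: {"image": ..., "points": ...}
    return prompts["image"], prompts["points"]

with gr.Blocks() as demo:
    prompter = ImagePrompter(show_label=False)
    button = gr.Button("Read prompts")
    image_out = gr.Image(show_label=False)
    points_out = gr.Dataframe(label="Points")
    button.click(inspect, prompter, [image_out, points_out])

demo.launch()
```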
src/backend/gradio_image_prompter/image_prompter.pyi
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, PhyscalX. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Gradio ``ImagePrompter`` component."""
|
17 |
+
|
18 |
+
from __future__ import annotations
|
19 |
+
|
20 |
+
from typing import Optional, List, TypedDict, Union, Literal
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import gradio
|
24 |
+
from gradio.data_classes import FileData, GradioModel
|
25 |
+
from gradio_client.documentation import document, set_documentation_group
|
26 |
+
from PIL import Image as _Image # using _ to minimize namespace pollution
|
27 |
+
|
28 |
+
set_documentation_group("component")
|
29 |
+
|
30 |
+
|
31 |
+
class PromptData(GradioModel):
|
32 |
+
image: FileData
|
33 |
+
points: List[List[float]]
|
34 |
+
|
35 |
+
|
36 |
+
class PromptValue(TypedDict):
|
37 |
+
image: Optional[Union[np.ndarray, _Image.Image, str]]
|
38 |
+
points: Optional[list[list[float]]]
|
39 |
+
|
40 |
+
from gradio.events import Dependency
|
41 |
+
|
42 |
+
@document()
|
43 |
+
class ImagePrompter(gradio.Image):
|
44 |
+
"""Create an image prompter to upload images and process point/box prompts."""
|
45 |
+
|
46 |
+
data_model = PromptData
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
value: str | _Image.Image | np.ndarray | None = None,
|
51 |
+
*,
|
52 |
+
height: int | None = None,
|
53 |
+
        width: int | None = None,
        image_mode: Literal[
            "1", "L", "P", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"
        ] = "RGB",
        sources: list[Literal["upload", "clipboard"]] | None = None,
        type: Literal["numpy", "pil", "filepath"] = "numpy",
        label: str | None = None,
        every: float | None = None,
        show_label: bool | None = None,
        show_download_button: bool = True,
        container: bool = True,
        scale: int | None = None,
        min_width: int = 160,
        interactive: bool | None = None,
        visible: bool = True,
        elem_id: str | None = None,
        elem_classes: list[str] | str | None = None,
        render: bool = True,
        show_share_button: bool | None = None,
    ):
        """
        Parameters:
            value: A PIL Image, numpy array, path or URL for the default value. If callable, it will be called to set the initial value.
            height: Height of the displayed image in pixels.
            width: Width of the displayed image in pixels.
            image_mode: "RGB" if color, or "L" if black and white. See https://pillow.readthedocs.io/en/stable/handbook/concepts.html.
            sources: List of sources for the image.
            type: The format the image is converted to before being passed into the prediction function.
            label: The label for this component.
            every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open.
            show_label: If True, will display the label.
            show_download_button: If True, will display a button to download the image.
            container: If True, will place the component in a container, providing some extra padding around the border.
            scale: Relative width compared to adjacent Components in a Row. Should be an integer.
            min_width: Minimum pixel width; will wrap if there is not sufficient screen space to satisfy this value.
            interactive: If True, will allow users to upload and edit an image; if False, can only be used to display images.
            visible: If False, the component will be hidden.
            elem_id: An optional string that is assigned as the id of this component in the HTML DOM.
            elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM.
            render: If False, the component will not be rendered in the Blocks context.
            show_share_button: If True, show a share icon that allows users to share outputs to Hugging Face Spaces Discussions.
        """
        super(ImagePrompter, self).__init__(
            value=value,
            height=height,
            width=width,
            image_mode=image_mode,
            sources=["upload", "clipboard"] if sources is None else sources,
            type=type,
            label=label,
            every=every,
            show_label=show_label,
            show_download_button=show_download_button,
            container=container,
            scale=scale,
            min_width=min_width,
            interactive=interactive,
            visible=visible,
            elem_id=elem_id,
            elem_classes=elem_classes,
            render=render,
            show_share_button=show_share_button,
        )

    def preprocess(self, x: PromptData) -> PromptValue | None:
        if x is None:
            return x
        im = super().preprocess(x.image)
        return {"image": im, "points": x.points}

    def postprocess(self, y: PromptValue) -> PromptData | None:
        if y is None:
            return None
        image, points = y.get("image", None), y.get("points", [])
        return PromptData(image=super().postprocess(image), points=points)

    def as_example(self, y: PromptValue) -> str | None:
        if y is None:
            return None
        return self.move_resource_to_block_cache(y.get("image", None))
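
A minimal usage sketch may help here: preprocess above is what hands the prediction function its input, a dict with "image" and "points" keys. The snippet below wires the component into a small gr.Interface; the handler name and output labels are illustrative (the repository ships its own demo in src/demo/app.py), not taken verbatim from this upload.

import gradio as gr
from gradio_image_prompter import ImagePrompter

def echo_prompts(prompts):
    # `prompts` is the dict produced by ImagePrompter.preprocess:
    # {"image": <image in the configured `type`>, "points": <prompt points>}
    return prompts["image"], prompts["points"]

demo = gr.Interface(
    fn=echo_prompts,
    inputs=ImagePrompter(show_label=False),
    outputs=[gr.Image(show_label=False), gr.Dataframe(label="Points")],
)

if __name__ == "__main__":
    demo.launch()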
src/backend/gradio_image_prompter/templates/component/__vite-browser-external-2447137e.js
ADDED
@@ -0,0 +1,4 @@
const e = {};
export {
  e as default
};
src/backend/gradio_image_prompter/templates/component/index.js
ADDED
The diff for this file is too large to render. See raw diff.
src/backend/gradio_image_prompter/templates/component/style.css
ADDED
@@ -0,0 +1 @@
.block.svelte-1t38q2d{position:relative;margin:0;box-shadow:var(--block-shadow);border-width:var(--block-border-width);border-color:var(--block-border-color);border-radius:var(--block-radius);background:var(--block-background-fill);width:100%;line-height:var(--line-sm)}.block.border_focus.svelte-1t38q2d{border-color:var(--color-accent)}.padded.svelte-1t38q2d{padding:var(--block-padding)}.hidden.svelte-1t38q2d{display:none}.hide-container.svelte-1t38q2d{margin:0;box-shadow:none;--block-border-width:0;background:transparent;padding:0;overflow:visible}div.svelte-1hnfib2{margin-bottom:var(--spacing-lg);color:var(--block-info-text-color);font-weight:var(--block-info-text-weight);font-size:var(--block-info-text-size);line-height:var(--line-sm)}span.has-info.svelte-22c38v{margin-bottom:var(--spacing-xs)}span.svelte-22c38v:not(.has-info){margin-bottom:var(--spacing-lg)}span.svelte-22c38v{display:inline-block;position:relative;z-index:var(--layer-4);border:solid var(--block-title-border-width) var(--block-title-border-color);border-radius:var(--block-title-radius);background:var(--block-title-background-fill);padding:var(--block-title-padding);color:var(--block-title-text-color);font-weight:var(--block-title-text-weight);font-size:var(--block-title-text-size);line-height:var(--line-sm)}.hide.svelte-22c38v{margin:0;height:0}label.svelte-9gxdi0{display:inline-flex;align-items:center;z-index:var(--layer-2);box-shadow:var(--block-label-shadow);border:var(--block-label-border-width) solid var(--border-color-primary);border-top:none;border-left:none;border-radius:var(--block-label-radius);background:var(--block-label-background-fill);padding:var(--block-label-padding);pointer-events:none;color:var(--block-label-text-color);font-weight:var(--block-label-text-weight);font-size:var(--block-label-text-size);line-height:var(--line-sm)}.gr-group label.svelte-9gxdi0{border-top-left-radius:0}label.float.svelte-9gxdi0{position:absolute;top:var(--block-label-margin);left:var(--block-label-margin)}label.svelte-9gxdi0:not(.float){position:static;margin-top:var(--block-label-margin);margin-left:var(--block-label-margin)}.hide.svelte-9gxdi0{height:0}span.svelte-9gxdi0{opacity:.8;margin-right:var(--size-2);width:calc(var(--block-label-text-size) - 1px);height:calc(var(--block-label-text-size) - 1px)}.hide-label.svelte-9gxdi0{box-shadow:none;border-width:0;background:transparent;overflow:visible}button.svelte-lpi64a{display:flex;justify-content:center;align-items:center;gap:1px;z-index:var(--layer-2);border-radius:var(--radius-sm);color:var(--block-label-text-color);border:1px solid transparent}button[disabled].svelte-lpi64a{opacity:.5;box-shadow:none}button[disabled].svelte-lpi64a:hover{cursor:not-allowed}.padded.svelte-lpi64a{padding:2px;background:var(--bg-color);box-shadow:var(--shadow-drop);border:1px solid var(--button-secondary-border-color)}button.svelte-lpi64a:hover,button.highlight.svelte-lpi64a{cursor:pointer;color:var(--color-accent)}.padded.svelte-lpi64a:hover{border:2px solid var(--button-secondary-border-color-hover);padding:1px;color:var(--block-label-text-color)}span.svelte-lpi64a{padding:0 1px;font-size:10px}div.svelte-lpi64a{padding:2px;display:flex;align-items:flex-end}.small.svelte-lpi64a{width:14px;height:14px}.large.svelte-lpi64a{width:22px;height:22px}.pending.svelte-lpi64a{animation:svelte-lpi64a-flash .5s infinite}@keyframes 
svelte-lpi64a-flash{0%{opacity:.5}50%{opacity:1}to{opacity:.5}}.transparent.svelte-lpi64a{background:transparent;border:none;box-shadow:none}.empty.svelte-3w3rth{display:flex;justify-content:center;align-items:center;margin-top:calc(0px - var(--size-6));height:var(--size-full)}.icon.svelte-3w3rth{opacity:.5;height:var(--size-5);color:var(--body-text-color)}.small.svelte-3w3rth{min-height:calc(var(--size-32) - 20px)}.large.svelte-3w3rth{min-height:calc(var(--size-64) - 20px)}.unpadded_box.svelte-3w3rth{margin-top:0}.small_parent.svelte-3w3rth{min-height:100%!important}.dropdown-arrow.svelte-145leq6{fill:currentColor}.wrap.svelte-kzcjhc{display:flex;flex-direction:column;justify-content:center;align-items:center;min-height:var(--size-60);color:var(--block-label-text-color);line-height:var(--line-md);height:100%;padding-top:var(--size-3)}.or.svelte-kzcjhc{color:var(--body-text-color-subdued);display:flex}.icon-wrap.svelte-kzcjhc{width:30px;margin-bottom:var(--spacing-lg)}@media (--screen-md){.wrap.svelte-kzcjhc{font-size:var(--text-lg)}}.hovered.svelte-kzcjhc{color:var(--color-accent)}div.svelte-ipfyu7{border-top:1px solid transparent;display:flex;max-height:100%;justify-content:center;gap:var(--spacing-sm);height:auto;align-items:flex-end;padding-bottom:var(--spacing-xl);color:var(--block-label-text-color);flex-shrink:0;width:95%}.show_border.svelte-ipfyu7{border-top:1px solid var(--block-border-color);margin-top:var(--spacing-xxl);box-shadow:var(--shadow-drop)}.source-selection.svelte-lde7lt{display:flex;align-items:center;justify-content:center;border-top:1px solid var(--border-color-primary);width:95%;bottom:0;left:0;right:0;margin-left:auto;margin-right:auto;align-self:flex-end}.icon.svelte-lde7lt{width:22px;height:22px;margin:var(--spacing-lg) var(--spacing-xs);padding:var(--spacing-xs);color:var(--neutral-400);border-radius:var(--radius-md)}.selected.svelte-lde7lt{color:var(--color-accent)}.icon.svelte-lde7lt:hover,.icon.svelte-lde7lt:focus{color:var(--color-accent)}img.svelte-1e0ed51,button.svelte-1e0ed51{width:var(--size-full);height:var(--size-full);object-fit:contain;display:block;border-radius:var(--radius-lg)}.selectable.svelte-1e0ed51{cursor:crosshair}.icon-buttons.svelte-1e0ed51{display:flex;position:absolute;top:6px;right:6px;gap:var(--size-1)}.wrap.svelte-12ckl9l.svelte-12ckl9l{overflow-y:auto;transition:opacity .5s ease-in-out;background:var(--block-background-fill);position:relative;display:flex;flex-direction:column;align-items:center;justify-content:center;min-height:var(--size-40)}.wrap.svelte-12ckl9l.svelte-12ckl9l:after{content:"";position:absolute;top:0;left:0;width:var(--upload-progress-width);height:100%;transition:all .5s ease-in-out;z-index:1}.uploading.svelte-12ckl9l.svelte-12ckl9l{font-size:var(--text-lg);font-family:var(--font);z-index:2}.file-name.svelte-12ckl9l.svelte-12ckl9l{margin:var(--spacing-md);font-size:var(--text-lg);color:var(--body-text-color-subdued)}.file.svelte-12ckl9l.svelte-12ckl9l{font-size:var(--text-md);z-index:2;display:flex;align-items:center}.file.svelte-12ckl9l progress.svelte-12ckl9l{display:inline;height:var(--size-1);width:100%;transition:all .5s ease-in-out;color:var(--color-accent);border:none}.file.svelte-12ckl9l progress[value].svelte-12ckl9l::-webkit-progress-value{background-color:var(--color-accent);border-radius:20px}.file.svelte-12ckl9l 
progress[value].svelte-12ckl9l::-webkit-progress-bar{background-color:var(--border-color-accent);border-radius:20px}.progress-bar.svelte-12ckl9l.svelte-12ckl9l{width:14px;height:14px;border-radius:50%;background:radial-gradient(closest-side,var(--block-background-fill) 64%,transparent 53% 100%),conic-gradient(var(--color-accent) var(--upload-progress-width),var(--border-color-accent) 0);transition:all .5s ease-in-out}button.svelte-1aq8tno{cursor:pointer;width:var(--size-full)}.hidden.svelte-1aq8tno{display:none;height:0!important;position:absolute;width:0;flex-grow:0}.center.svelte-1aq8tno{display:flex;justify-content:center}.flex.svelte-1aq8tno{display:flex;justify-content:center;align-items:center}input.svelte-1aq8tno{display:none}div.svelte-1wj0ocy{display:flex;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-1)}.not-absolute.svelte-1wj0ocy{margin:var(--size-1)}div.svelte-1o7cyxy{display:flex;position:absolute;top:var(--size-2);right:var(--size-2);justify-content:flex-end;gap:var(--spacing-sm);z-index:var(--layer-5)}canvas.svelte-1mnpmgt{display:block;position:absolute;top:0;right:0;bottom:0;left:0;margin:auto}.wrap.svelte-1mnpmgt{position:relative;width:var(--size-full);height:var(--size-full);touch-action:none}img.svelte-1qm7xww{width:var(--size-full);height:var(--size-full)}.upload-container.svelte-1qm7xww{height:100%;flex-shrink:1;max-height:100%}.image-container.svelte-1qm7xww{display:flex;height:100%;flex-direction:column;justify-content:center;align-items:center;max-height:100%}svg.svelte-43sxxs.svelte-43sxxs{width:var(--size-20);height:var(--size-20)}svg.svelte-43sxxs path.svelte-43sxxs{fill:var(--loader-color)}div.svelte-43sxxs.svelte-43sxxs{z-index:var(--layer-2)}.margin.svelte-43sxxs.svelte-43sxxs{margin:var(--size-4)}.wrap.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;justify-content:center;align-items:center;z-index:var(--layer-top);transition:opacity .1s ease-in-out;border-radius:var(--block-radius);background:var(--block-background-fill);padding:0 var(--size-6);max-height:var(--size-screen-h);overflow:hidden;pointer-events:none}.wrap.center.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;left:0}.wrap.default.svelte-1txqlrd.svelte-1txqlrd{top:0;right:0;bottom:0;left:0}.hide.svelte-1txqlrd.svelte-1txqlrd{opacity:0;pointer-events:none}.generating.svelte-1txqlrd.svelte-1txqlrd{animation:svelte-1txqlrd-pulse 2s cubic-bezier(.4,0,.6,1) infinite;border:2px solid var(--color-accent);background:transparent}.translucent.svelte-1txqlrd.svelte-1txqlrd{background:none}@keyframes svelte-1txqlrd-pulse{0%,to{opacity:1}50%{opacity:.5}}.loading.svelte-1txqlrd.svelte-1txqlrd{z-index:var(--layer-2);color:var(--body-text-color)}.eta-bar.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;bottom:0;left:0;transform-origin:left;opacity:.8;z-index:var(--layer-1);transition:10ms;background:var(--background-fill-secondary)}.progress-bar-wrap.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary);background:var(--background-fill-primary);width:55.5%;height:var(--size-4)}.progress-bar.svelte-1txqlrd.svelte-1txqlrd{transform-origin:left;background-color:var(--loader-color);width:var(--size-full);height:var(--size-full)}.progress-level.svelte-1txqlrd.svelte-1txqlrd{display:flex;flex-direction:column;align-items:center;gap:1;z-index:var(--layer-2);width:var(--size-full)}.progress-level-inner.svelte-1txqlrd.svelte-1txqlrd{margin:var(--size-2) 
auto;color:var(--body-text-color);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text.svelte-1txqlrd.svelte-1txqlrd{position:absolute;top:0;right:0;z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono)}.meta-text-center.svelte-1txqlrd.svelte-1txqlrd{display:flex;position:absolute;top:0;right:0;justify-content:center;align-items:center;transform:translateY(var(--size-6));z-index:var(--layer-2);padding:var(--size-1) var(--size-2);font-size:var(--text-sm);font-family:var(--font-mono);text-align:center}.error.svelte-1txqlrd.svelte-1txqlrd{box-shadow:var(--shadow-drop);border:solid 1px var(--error-border-color);border-radius:var(--radius-full);background:var(--error-background-fill);padding-right:var(--size-4);padding-left:var(--size-4);color:var(--error-text-color);font-weight:var(--weight-semibold);font-size:var(--text-lg);line-height:var(--line-lg);font-family:var(--font)}.minimal.svelte-1txqlrd .progress-text.svelte-1txqlrd{background:var(--block-background-fill)}.border.svelte-1txqlrd.svelte-1txqlrd{border:1px solid var(--border-color-primary)}.toast-body.svelte-solcu7{display:flex;position:relative;right:0;left:0;align-items:center;margin:var(--size-6) var(--size-4);margin:auto;border-radius:var(--container-radius);overflow:hidden;pointer-events:auto}.toast-body.error.svelte-solcu7{border:1px solid var(--color-red-700);background:var(--color-red-50)}.dark .toast-body.error.svelte-solcu7{border:1px solid var(--color-red-500);background-color:var(--color-grey-950)}.toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-700);background:var(--color-yellow-50)}.dark .toast-body.warning.svelte-solcu7{border:1px solid var(--color-yellow-500);background-color:var(--color-grey-950)}.toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-700);background:var(--color-grey-50)}.dark .toast-body.info.svelte-solcu7{border:1px solid var(--color-grey-500);background-color:var(--color-grey-950)}.toast-title.svelte-solcu7{display:flex;align-items:center;font-weight:var(--weight-bold);font-size:var(--text-lg);line-height:var(--line-sm);text-transform:capitalize}.toast-title.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-title.error.svelte-solcu7{color:var(--color-red-50)}.toast-title.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-title.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-title.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-title.info.svelte-solcu7{color:var(--color-grey-50)}.toast-close.svelte-solcu7{margin:0 var(--size-3);border-radius:var(--size-3);padding:0px var(--size-1-5);font-size:var(--size-5);line-height:var(--size-5)}.toast-close.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-close.error.svelte-solcu7{color:var(--color-red-500)}.toast-close.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-close.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-close.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-close.info.svelte-solcu7{color:var(--color-grey-500)}.toast-text.svelte-solcu7{font-size:var(--text-lg)}.toast-text.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-text.error.svelte-solcu7{color:var(--color-red-50)}.toast-text.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-text.warning.svelte-solcu7{color:var(--color-yellow-50)}.toast-text.info.svelte-solcu7{color:var(--color-grey-700)}.dark 
.toast-text.info.svelte-solcu7{color:var(--color-grey-50)}.toast-details.svelte-solcu7{margin:var(--size-3) var(--size-3) var(--size-3) 0;width:100%}.toast-icon.svelte-solcu7{display:flex;position:absolute;position:relative;flex-shrink:0;justify-content:center;align-items:center;margin:var(--size-2);border-radius:var(--radius-full);padding:var(--size-1);padding-left:calc(var(--size-1) - 1px);width:35px;height:35px}.toast-icon.error.svelte-solcu7{color:var(--color-red-700)}.dark .toast-icon.error.svelte-solcu7{color:var(--color-red-500)}.toast-icon.warning.svelte-solcu7{color:var(--color-yellow-700)}.dark .toast-icon.warning.svelte-solcu7{color:var(--color-yellow-500)}.toast-icon.info.svelte-solcu7{color:var(--color-grey-700)}.dark .toast-icon.info.svelte-solcu7{color:var(--color-grey-500)}@keyframes svelte-solcu7-countdown{0%{transform:scaleX(1)}to{transform:scaleX(0)}}.timer.svelte-solcu7{position:absolute;bottom:0;left:0;transform-origin:0 0;animation:svelte-solcu7-countdown 10s linear forwards;width:100%;height:var(--size-1)}.timer.error.svelte-solcu7{background:var(--color-red-700)}.dark .timer.error.svelte-solcu7{background:var(--color-red-500)}.timer.warning.svelte-solcu7{background:var(--color-yellow-700)}.dark .timer.warning.svelte-solcu7{background:var(--color-yellow-500)}.timer.info.svelte-solcu7{background:var(--color-grey-700)}.dark .timer.info.svelte-solcu7{background:var(--color-grey-500)}.toast-wrap.svelte-gatr8h{display:flex;position:fixed;top:var(--size-4);right:var(--size-4);flex-direction:column;align-items:end;gap:var(--size-2);z-index:var(--layer-top);width:calc(100% - var(--size-8))}@media (--screen-sm){.toast-wrap.svelte-gatr8h{width:calc(var(--size-96) + var(--size-10))}}.container.svelte-h11ksk img{width:100%;height:100%}.container.selected.svelte-h11ksk{border-color:var(--border-color-accent)}.container.table.svelte-h11ksk{margin:0 auto;border:2px solid var(--border-color-primary);border-radius:var(--radius-lg);overflow:hidden;width:var(--size-20);height:var(--size-20);object-fit:cover}.container.gallery.svelte-h11ksk{height:var(--size-20);max-height:var(--size-20);object-fit:cover}
src/backend/gradio_image_prompter/templates/component/wrapper-6f348d45-f837cf34.js
ADDED
@@ -0,0 +1,2455 @@
import S from "./__vite-browser-external-2447137e.js";
function z(s) {
  return s && s.__esModule && Object.prototype.hasOwnProperty.call(s, "default") ? s.default : s;
}
function gt(s) {
  if (s.__esModule)
    return s;
  var e = s.default;
  if (typeof e == "function") {
    var t = function r() {
      if (this instanceof r) {
        var i = [null];
        i.push.apply(i, arguments);
        var n = Function.bind.apply(e, i);
        return new n();
      }
      return e.apply(this, arguments);
    };
    t.prototype = e.prototype;
  } else
    t = {};
  return Object.defineProperty(t, "__esModule", { value: !0 }), Object.keys(s).forEach(function(r) {
    var i = Object.getOwnPropertyDescriptor(s, r);
    Object.defineProperty(t, r, i.get ? i : {
      enumerable: !0,
      get: function() {
        return s[r];
      }
    });
  }), t;
}
const { Duplex: yt } = S;
function Oe(s) {
  s.emit("close");
}
function vt() {
  !this.destroyed && this._writableState.finished && this.destroy();
}
function Qe(s) {
  this.removeListener("error", Qe), this.destroy(), this.listenerCount("error") === 0 && this.emit("error", s);
}
function St(s, e) {
  let t = !0;
  const r = new yt({
    ...e,
    autoDestroy: !1,
    emitClose: !1,
    objectMode: !1,
    writableObjectMode: !1
  });
  return s.on("message", function(n, o) {
    const l = !o && r._readableState.objectMode ? n.toString() : n;
    r.push(l) || s.pause();
  }), s.once("error", function(n) {
    r.destroyed || (t = !1, r.destroy(n));
  }), s.once("close", function() {
    r.destroyed || r.push(null);
  }), r._destroy = function(i, n) {
    if (s.readyState === s.CLOSED) {
      n(i), process.nextTick(Oe, r);
      return;
    }
    let o = !1;
    s.once("error", function(f) {
      o = !0, n(f);
    }), s.once("close", function() {
      o || n(i), process.nextTick(Oe, r);
    }), t && s.terminate();
  }, r._final = function(i) {
    if (s.readyState === s.CONNECTING) {
      s.once("open", function() {
        r._final(i);
      });
      return;
    }
    s._socket !== null && (s._socket._writableState.finished ? (i(), r._readableState.endEmitted && r.destroy()) : (s._socket.once("finish", function() {
      i();
    }), s.close()));
  }, r._read = function() {
    s.isPaused && s.resume();
  }, r._write = function(i, n, o) {
    if (s.readyState === s.CONNECTING) {
      s.once("open", function() {
        r._write(i, n, o);
      });
      return;
    }
    s.send(i, o);
  }, r.on("end", vt), r.on("error", Qe), r;
}
var Et = St;
const Vs = /* @__PURE__ */ z(Et);
var te = { exports: {} }, U = {
  BINARY_TYPES: ["nodebuffer", "arraybuffer", "fragments"],
  EMPTY_BUFFER: Buffer.alloc(0),
  GUID: "258EAFA5-E914-47DA-95CA-C5AB0DC85B11",
  kForOnEventAttribute: Symbol("kIsForOnEventAttribute"),
  kListener: Symbol("kListener"),
  kStatusCode: Symbol("status-code"),
  kWebSocket: Symbol("websocket"),
  NOOP: () => {
  }
}, bt, xt;
const { EMPTY_BUFFER: kt } = U, Se = Buffer[Symbol.species];
function wt(s, e) {
  if (s.length === 0)
    return kt;
  if (s.length === 1)
    return s[0];
  const t = Buffer.allocUnsafe(e);
  let r = 0;
  for (let i = 0; i < s.length; i++) {
    const n = s[i];
    t.set(n, r), r += n.length;
  }
  return r < e ? new Se(t.buffer, t.byteOffset, r) : t;
}
function Je(s, e, t, r, i) {
  for (let n = 0; n < i; n++)
    t[r + n] = s[n] ^ e[n & 3];
}
function et(s, e) {
  for (let t = 0; t < s.length; t++)
    s[t] ^= e[t & 3];
}
function Ot(s) {
  return s.length === s.buffer.byteLength ? s.buffer : s.buffer.slice(s.byteOffset, s.byteOffset + s.length);
}
function Ee(s) {
  if (Ee.readOnly = !0, Buffer.isBuffer(s))
    return s;
  let e;
  return s instanceof ArrayBuffer ? e = new Se(s) : ArrayBuffer.isView(s) ? e = new Se(s.buffer, s.byteOffset, s.byteLength) : (e = Buffer.from(s), Ee.readOnly = !1), e;
}
te.exports = {
  concat: wt,
  mask: Je,
  toArrayBuffer: Ot,
  toBuffer: Ee,
  unmask: et
};
if (!process.env.WS_NO_BUFFER_UTIL)
  try {
    const s = require("bufferutil");
    xt = te.exports.mask = function(e, t, r, i, n) {
      n < 48 ? Je(e, t, r, i, n) : s.mask(e, t, r, i, n);
    }, bt = te.exports.unmask = function(e, t) {
      e.length < 32 ? et(e, t) : s.unmask(e, t);
    };
  } catch {
  }
var ne = te.exports;
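
The exported mask and unmask helpers above implement RFC 6455 frame masking: every payload byte is XORed with one of the four masking-key bytes, chosen by index & 3, and the optional bufferutil native module is only used for larger buffers. A short Python sketch of the same operation, for illustration only (not part of this upload):

def xor_mask(payload: bytes, key: bytes) -> bytes:
    # XOR each payload byte with key[i % 4]; XOR is its own inverse,
    # so the same function both masks and unmasks (RFC 6455, section 5.3).
    assert len(key) == 4
    return bytes(b ^ key[i & 3] for i, b in enumerate(payload))

masked = xor_mask(b"hello", b"\x01\x02\x03\x04")
assert xor_mask(masked, b"\x01\x02\x03\x04") == b"hello"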
const Ce = Symbol("kDone"), ue = Symbol("kRun");
let Ct = class {
  /**
   * Creates a new `Limiter`.
   *
   * @param {Number} [concurrency=Infinity] The maximum number of jobs allowed
   *     to run concurrently
   */
  constructor(e) {
    this[Ce] = () => {
      this.pending--, this[ue]();
    }, this.concurrency = e || 1 / 0, this.jobs = [], this.pending = 0;
  }
  /**
   * Adds a job to the queue.
   *
   * @param {Function} job The job to run
   * @public
   */
  add(e) {
    this.jobs.push(e), this[ue]();
  }
  /**
   * Removes a job from the queue and runs it if possible.
   *
   * @private
   */
  [ue]() {
    if (this.pending !== this.concurrency && this.jobs.length) {
      const e = this.jobs.shift();
      this.pending++, e(this[Ce]);
    }
  }
};
var Tt = Ct;
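
The Limiter class above is a callback-style concurrency gate: jobs wait in a queue, at most `concurrency` of them run at once, and each job receives a done callback that frees its slot and pulls the next job. A rough Python rendering of the same pattern, illustrative only:

from collections import deque

class Limiter:
    # Run callback-style jobs with at most `concurrency` in flight.
    def __init__(self, concurrency=float("inf")):
        self.concurrency = concurrency
        self.jobs = deque()
        self.pending = 0

    def add(self, job):
        # `job` is a callable that accepts a single `done` callback.
        self.jobs.append(job)
        self._run()

    def _done(self):
        self.pending -= 1
        self._run()

    def _run(self):
        if self.pending < self.concurrency and self.jobs:
            self.pending += 1
            self.jobs.popleft()(self._done)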
const W = S, Te = ne, Lt = Tt, { kStatusCode: tt } = U, Nt = Buffer[Symbol.species], Pt = Buffer.from([0, 0, 255, 255]), se = Symbol("permessage-deflate"), w = Symbol("total-length"), V = Symbol("callback"), C = Symbol("buffers"), J = Symbol("error");
let K, Rt = class {
  /**
   * Creates a PerMessageDeflate instance.
   *
   * @param {Object} [options] Configuration options
   * @param {(Boolean|Number)} [options.clientMaxWindowBits] Advertise support
   *     for, or request, a custom client window size
   * @param {Boolean} [options.clientNoContextTakeover=false] Advertise/
   *     acknowledge disabling of client context takeover
   * @param {Number} [options.concurrencyLimit=10] The number of concurrent
   *     calls to zlib
   * @param {(Boolean|Number)} [options.serverMaxWindowBits] Request/confirm the
   *     use of a custom server window size
   * @param {Boolean} [options.serverNoContextTakeover=false] Request/accept
   *     disabling of server context takeover
   * @param {Number} [options.threshold=1024] Size (in bytes) below which
   *     messages should not be compressed if context takeover is disabled
   * @param {Object} [options.zlibDeflateOptions] Options to pass to zlib on
   *     deflate
   * @param {Object} [options.zlibInflateOptions] Options to pass to zlib on
   *     inflate
   * @param {Boolean} [isServer=false] Create the instance in either server or
   *     client mode
   * @param {Number} [maxPayload=0] The maximum allowed message length
   */
  constructor(e, t, r) {
    if (this._maxPayload = r | 0, this._options = e || {}, this._threshold = this._options.threshold !== void 0 ? this._options.threshold : 1024, this._isServer = !!t, this._deflate = null, this._inflate = null, this.params = null, !K) {
      const i = this._options.concurrencyLimit !== void 0 ? this._options.concurrencyLimit : 10;
      K = new Lt(i);
    }
  }
  /**
   * @type {String}
   */
  static get extensionName() {
    return "permessage-deflate";
  }
  /**
   * Create an extension negotiation offer.
   *
   * @return {Object} Extension parameters
   * @public
   */
  offer() {
    const e = {};
    return this._options.serverNoContextTakeover && (e.server_no_context_takeover = !0), this._options.clientNoContextTakeover && (e.client_no_context_takeover = !0), this._options.serverMaxWindowBits && (e.server_max_window_bits = this._options.serverMaxWindowBits), this._options.clientMaxWindowBits ? e.client_max_window_bits = this._options.clientMaxWindowBits : this._options.clientMaxWindowBits == null && (e.client_max_window_bits = !0), e;
  }
  /**
   * Accept an extension negotiation offer/response.
   *
   * @param {Array} configurations The extension negotiation offers/response
   * @return {Object} Accepted configuration
   * @public
   */
  accept(e) {
    return e = this.normalizeParams(e), this.params = this._isServer ? this.acceptAsServer(e) : this.acceptAsClient(e), this.params;
  }
  /**
   * Releases all resources used by the extension.
   *
   * @public
   */
  cleanup() {
    if (this._inflate && (this._inflate.close(), this._inflate = null), this._deflate) {
      const e = this._deflate[V];
      this._deflate.close(), this._deflate = null, e && e(
        new Error(
          "The deflate stream was closed while data was being processed"
        )
      );
    }
  }
  /**
   * Accept an extension negotiation offer.
   *
   * @param {Array} offers The extension negotiation offers
   * @return {Object} Accepted configuration
   * @private
   */
  acceptAsServer(e) {
    const t = this._options, r = e.find((i) => !(t.serverNoContextTakeover === !1 && i.server_no_context_takeover || i.server_max_window_bits && (t.serverMaxWindowBits === !1 || typeof t.serverMaxWindowBits == "number" && t.serverMaxWindowBits > i.server_max_window_bits) || typeof t.clientMaxWindowBits == "number" && !i.client_max_window_bits));
    if (!r)
      throw new Error("None of the extension offers can be accepted");
    return t.serverNoContextTakeover && (r.server_no_context_takeover = !0), t.clientNoContextTakeover && (r.client_no_context_takeover = !0), typeof t.serverMaxWindowBits == "number" && (r.server_max_window_bits = t.serverMaxWindowBits), typeof t.clientMaxWindowBits == "number" ? r.client_max_window_bits = t.clientMaxWindowBits : (r.client_max_window_bits === !0 || t.clientMaxWindowBits === !1) && delete r.client_max_window_bits, r;
  }
  /**
   * Accept the extension negotiation response.
   *
   * @param {Array} response The extension negotiation response
   * @return {Object} Accepted configuration
   * @private
   */
  acceptAsClient(e) {
    const t = e[0];
    if (this._options.clientNoContextTakeover === !1 && t.client_no_context_takeover)
      throw new Error('Unexpected parameter "client_no_context_takeover"');
    if (!t.client_max_window_bits)
      typeof this._options.clientMaxWindowBits == "number" && (t.client_max_window_bits = this._options.clientMaxWindowBits);
    else if (this._options.clientMaxWindowBits === !1 || typeof this._options.clientMaxWindowBits == "number" && t.client_max_window_bits > this._options.clientMaxWindowBits)
      throw new Error(
        'Unexpected or invalid parameter "client_max_window_bits"'
      );
    return t;
  }
  /**
   * Normalize parameters.
   *
   * @param {Array} configurations The extension negotiation offers/response
   * @return {Array} The offers/response with normalized parameters
   * @private
   */
  normalizeParams(e) {
    return e.forEach((t) => {
      Object.keys(t).forEach((r) => {
        let i = t[r];
        if (i.length > 1)
          throw new Error(`Parameter "${r}" must have only a single value`);
        if (i = i[0], r === "client_max_window_bits") {
          if (i !== !0) {
            const n = +i;
            if (!Number.isInteger(n) || n < 8 || n > 15)
              throw new TypeError(
                `Invalid value for parameter "${r}": ${i}`
              );
            i = n;
          } else if (!this._isServer)
            throw new TypeError(
              `Invalid value for parameter "${r}": ${i}`
            );
        } else if (r === "server_max_window_bits") {
          const n = +i;
          if (!Number.isInteger(n) || n < 8 || n > 15)
            throw new TypeError(
              `Invalid value for parameter "${r}": ${i}`
            );
          i = n;
        } else if (r === "client_no_context_takeover" || r === "server_no_context_takeover") {
          if (i !== !0)
            throw new TypeError(
              `Invalid value for parameter "${r}": ${i}`
            );
        } else
          throw new Error(`Unknown parameter "${r}"`);
        t[r] = i;
      });
    }), e;
  }
  /**
   * Decompress data. Concurrency limited.
   *
   * @param {Buffer} data Compressed data
   * @param {Boolean} fin Specifies whether or not this is the last fragment
   * @param {Function} callback Callback
   * @public
   */
  decompress(e, t, r) {
    K.add((i) => {
      this._decompress(e, t, (n, o) => {
        i(), r(n, o);
      });
    });
  }
  /**
   * Compress data. Concurrency limited.
   *
   * @param {(Buffer|String)} data Data to compress
   * @param {Boolean} fin Specifies whether or not this is the last fragment
   * @param {Function} callback Callback
   * @public
   */
  compress(e, t, r) {
    K.add((i) => {
      this._compress(e, t, (n, o) => {
        i(), r(n, o);
      });
    });
  }
  /**
   * Decompress data.
   *
   * @param {Buffer} data Compressed data
   * @param {Boolean} fin Specifies whether or not this is the last fragment
   * @param {Function} callback Callback
   * @private
   */
  _decompress(e, t, r) {
    const i = this._isServer ? "client" : "server";
    if (!this._inflate) {
      const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
      this._inflate = W.createInflateRaw({
        ...this._options.zlibInflateOptions,
        windowBits: o
      }), this._inflate[se] = this, this._inflate[w] = 0, this._inflate[C] = [], this._inflate.on("error", Bt), this._inflate.on("data", st);
    }
    this._inflate[V] = r, this._inflate.write(e), t && this._inflate.write(Pt), this._inflate.flush(() => {
      const n = this._inflate[J];
      if (n) {
        this._inflate.close(), this._inflate = null, r(n);
        return;
      }
      const o = Te.concat(
        this._inflate[C],
        this._inflate[w]
      );
      this._inflate._readableState.endEmitted ? (this._inflate.close(), this._inflate = null) : (this._inflate[w] = 0, this._inflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._inflate.reset()), r(null, o);
    });
  }
  /**
   * Compress data.
   *
   * @param {(Buffer|String)} data Data to compress
   * @param {Boolean} fin Specifies whether or not this is the last fragment
   * @param {Function} callback Callback
   * @private
   */
  _compress(e, t, r) {
    const i = this._isServer ? "server" : "client";
    if (!this._deflate) {
      const n = `${i}_max_window_bits`, o = typeof this.params[n] != "number" ? W.Z_DEFAULT_WINDOWBITS : this.params[n];
      this._deflate = W.createDeflateRaw({
        ...this._options.zlibDeflateOptions,
        windowBits: o
      }), this._deflate[w] = 0, this._deflate[C] = [], this._deflate.on("data", Ut);
    }
    this._deflate[V] = r, this._deflate.write(e), this._deflate.flush(W.Z_SYNC_FLUSH, () => {
      if (!this._deflate)
        return;
      let n = Te.concat(
        this._deflate[C],
        this._deflate[w]
      );
      t && (n = new Nt(n.buffer, n.byteOffset, n.length - 4)), this._deflate[V] = null, this._deflate[w] = 0, this._deflate[C] = [], t && this.params[`${i}_no_context_takeover`] && this._deflate.reset(), r(null, n);
    });
  }
};
var oe = Rt;
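
PerMessageDeflate.offer above collects the permessage-deflate parameters one side is willing to use; on the wire they travel in the Sec-WebSocket-Extensions header (RFC 7692), where a parameter whose value is true is rendered as a bare token. A hedged Python sketch of rendering such an offer, purely illustrative and not part of this upload:

def format_extension_offer(name: str, params: dict) -> str:
    # Valueless parameters (True) become bare tokens, e.g.
    # "permessage-deflate; client_max_window_bits".
    parts = [name]
    for key, value in params.items():
        parts.append(key if value is True else f"{key}={value}")
    return "; ".join(parts)

print(format_extension_offer(
    "permessage-deflate",
    {"client_no_context_takeover": True, "server_max_window_bits": 10},
))
# permessage-deflate; client_no_context_takeover; server_max_window_bits=10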
function Ut(s) {
  this[C].push(s), this[w] += s.length;
}
function st(s) {
  if (this[w] += s.length, this[se]._maxPayload < 1 || this[w] <= this[se]._maxPayload) {
    this[C].push(s);
    return;
  }
  this[J] = new RangeError("Max payload size exceeded"), this[J].code = "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH", this[J][tt] = 1009, this.removeListener("data", st), this.reset();
}
function Bt(s) {
  this[se]._inflate = null, s[tt] = 1007, this[V](s);
}
var re = { exports: {} };
const $t = {}, Mt = /* @__PURE__ */ Object.freeze(/* @__PURE__ */ Object.defineProperty({
  __proto__: null,
  default: $t
}, Symbol.toStringTag, { value: "Module" })), It = /* @__PURE__ */ gt(Mt);
var Le;
const { isUtf8: Ne } = S, Dt = [
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0 - 15
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 16 - 31
  0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, // 32 - 47
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, // 48 - 63
  0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 64 - 79
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, // 80 - 95
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 96 - 111
  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0 // 112 - 127
];
function Wt(s) {
  return s >= 1e3 && s <= 1014 && s !== 1004 && s !== 1005 && s !== 1006 || s >= 3e3 && s <= 4999;
}
function be(s) {
  const e = s.length;
  let t = 0;
  for (; t < e; )
    if (!(s[t] & 128))
      t++;
    else if ((s[t] & 224) === 192) {
      if (t + 1 === e || (s[t + 1] & 192) !== 128 || (s[t] & 254) === 192)
        return !1;
      t += 2;
    } else if ((s[t] & 240) === 224) {
      if (t + 2 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || s[t] === 224 && (s[t + 1] & 224) === 128 || // Overlong
      s[t] === 237 && (s[t + 1] & 224) === 160)
        return !1;
      t += 3;
    } else if ((s[t] & 248) === 240) {
      if (t + 3 >= e || (s[t + 1] & 192) !== 128 || (s[t + 2] & 192) !== 128 || (s[t + 3] & 192) !== 128 || s[t] === 240 && (s[t + 1] & 240) === 128 || // Overlong
      s[t] === 244 && s[t + 1] > 143 || s[t] > 244)
        return !1;
      t += 4;
    } else
      return !1;
  return !0;
}
re.exports = {
  isValidStatusCode: Wt,
  isValidUTF8: be,
  tokenChars: Dt
};
if (Ne)
  Le = re.exports.isValidUTF8 = function(s) {
    return s.length < 24 ? be(s) : Ne(s);
  };
else if (!process.env.WS_NO_UTF_8_VALIDATE)
  try {
    const s = It;
    Le = re.exports.isValidUTF8 = function(e) {
      return e.length < 32 ? be(e) : s(e);
    };
  } catch {
  }
var ae = re.exports;
const { Writable: At } = S, Pe = oe, {
  BINARY_TYPES: Ft,
  EMPTY_BUFFER: Re,
  kStatusCode: jt,
  kWebSocket: Gt
} = U, { concat: de, toArrayBuffer: Vt, unmask: Ht } = ne, { isValidStatusCode: zt, isValidUTF8: Ue } = ae, X = Buffer[Symbol.species], A = 0, Be = 1, $e = 2, Me = 3, _e = 4, Yt = 5;
let qt = class extends At {
  /**
   * Creates a Receiver instance.
   *
   * @param {Object} [options] Options object
   * @param {String} [options.binaryType=nodebuffer] The type for binary data
   * @param {Object} [options.extensions] An object containing the negotiated
   *     extensions
   * @param {Boolean} [options.isServer=false] Specifies whether to operate in
   *     client or server mode
   * @param {Number} [options.maxPayload=0] The maximum allowed message length
   * @param {Boolean} [options.skipUTF8Validation=false] Specifies whether or
   *     not to skip UTF-8 validation for text and close messages
   */
  constructor(e = {}) {
    super(), this._binaryType = e.binaryType || Ft[0], this._extensions = e.extensions || {}, this._isServer = !!e.isServer, this._maxPayload = e.maxPayload | 0, this._skipUTF8Validation = !!e.skipUTF8Validation, this[Gt] = void 0, this._bufferedBytes = 0, this._buffers = [], this._compressed = !1, this._payloadLength = 0, this._mask = void 0, this._fragmented = 0, this._masked = !1, this._fin = !1, this._opcode = 0, this._totalPayloadLength = 0, this._messageLength = 0, this._fragments = [], this._state = A, this._loop = !1;
  }
  /**
   * Implements `Writable.prototype._write()`.
   *
   * @param {Buffer} chunk The chunk of data to write
   * @param {String} encoding The character encoding of `chunk`
   * @param {Function} cb Callback
   * @private
   */
  _write(e, t, r) {
    if (this._opcode === 8 && this._state == A)
      return r();
    this._bufferedBytes += e.length, this._buffers.push(e), this.startLoop(r);
  }
  /**
   * Consumes `n` bytes from the buffered data.
   *
   * @param {Number} n The number of bytes to consume
   * @return {Buffer} The consumed bytes
   * @private
   */
  consume(e) {
    if (this._bufferedBytes -= e, e === this._buffers[0].length)
      return this._buffers.shift();
    if (e < this._buffers[0].length) {
      const r = this._buffers[0];
      return this._buffers[0] = new X(
        r.buffer,
        r.byteOffset + e,
        r.length - e
      ), new X(r.buffer, r.byteOffset, e);
    }
    const t = Buffer.allocUnsafe(e);
    do {
      const r = this._buffers[0], i = t.length - e;
      e >= r.length ? t.set(this._buffers.shift(), i) : (t.set(new Uint8Array(r.buffer, r.byteOffset, e), i), this._buffers[0] = new X(
        r.buffer,
        r.byteOffset + e,
        r.length - e
      )), e -= r.length;
    } while (e > 0);
    return t;
  }
  /**
   * Starts the parsing loop.
   *
   * @param {Function} cb Callback
   * @private
   */
  startLoop(e) {
    let t;
    this._loop = !0;
    do
      switch (this._state) {
        case A:
          t = this.getInfo();
          break;
        case Be:
          t = this.getPayloadLength16();
          break;
        case $e:
          t = this.getPayloadLength64();
          break;
        case Me:
          this.getMask();
          break;
        case _e:
          t = this.getData(e);
          break;
        default:
          this._loop = !1;
          return;
      }
    while (this._loop);
    e(t);
  }
  /**
   * Reads the first two bytes of a frame.
   *
   * @return {(RangeError|undefined)} A possible error
   * @private
   */
  getInfo() {
    if (this._bufferedBytes < 2) {
      this._loop = !1;
      return;
    }
    const e = this.consume(2);
    if (e[0] & 48)
      return this._loop = !1, g(
        RangeError,
        "RSV2 and RSV3 must be clear",
        !0,
        1002,
        "WS_ERR_UNEXPECTED_RSV_2_3"
      );
    const t = (e[0] & 64) === 64;
    if (t && !this._extensions[Pe.extensionName])
      return this._loop = !1, g(
        RangeError,
        "RSV1 must be clear",
        !0,
        1002,
        "WS_ERR_UNEXPECTED_RSV_1"
      );
    if (this._fin = (e[0] & 128) === 128, this._opcode = e[0] & 15, this._payloadLength = e[1] & 127, this._opcode === 0) {
      if (t)
        return this._loop = !1, g(
          RangeError,
          "RSV1 must be clear",
          !0,
          1002,
          "WS_ERR_UNEXPECTED_RSV_1"
        );
      if (!this._fragmented)
        return this._loop = !1, g(
          RangeError,
          "invalid opcode 0",
          !0,
          1002,
          "WS_ERR_INVALID_OPCODE"
        );
      this._opcode = this._fragmented;
    } else if (this._opcode === 1 || this._opcode === 2) {
      if (this._fragmented)
        return this._loop = !1, g(
          RangeError,
          `invalid opcode ${this._opcode}`,
          !0,
          1002,
          "WS_ERR_INVALID_OPCODE"
        );
      this._compressed = t;
    } else if (this._opcode > 7 && this._opcode < 11) {
      if (!this._fin)
        return this._loop = !1, g(
          RangeError,
          "FIN must be set",
          !0,
          1002,
          "WS_ERR_EXPECTED_FIN"
        );
      if (t)
        return this._loop = !1, g(
          RangeError,
          "RSV1 must be clear",
          !0,
          1002,
          "WS_ERR_UNEXPECTED_RSV_1"
        );
      if (this._payloadLength > 125 || this._opcode === 8 && this._payloadLength === 1)
        return this._loop = !1, g(
          RangeError,
          `invalid payload length ${this._payloadLength}`,
          !0,
          1002,
          "WS_ERR_INVALID_CONTROL_PAYLOAD_LENGTH"
        );
    } else
      return this._loop = !1, g(
        RangeError,
        `invalid opcode ${this._opcode}`,
        !0,
        1002,
        "WS_ERR_INVALID_OPCODE"
      );
    if (!this._fin && !this._fragmented && (this._fragmented = this._opcode), this._masked = (e[1] & 128) === 128, this._isServer) {
      if (!this._masked)
        return this._loop = !1, g(
          RangeError,
          "MASK must be set",
          !0,
          1002,
          "WS_ERR_EXPECTED_MASK"
        );
    } else if (this._masked)
      return this._loop = !1, g(
        RangeError,
        "MASK must be clear",
        !0,
        1002,
        "WS_ERR_UNEXPECTED_MASK"
      );
    if (this._payloadLength === 126)
      this._state = Be;
    else if (this._payloadLength === 127)
      this._state = $e;
    else
      return this.haveLength();
  }
  /**
   * Gets extended payload length (7+16).
   *
   * @return {(RangeError|undefined)} A possible error
   * @private
   */
  getPayloadLength16() {
    if (this._bufferedBytes < 2) {
      this._loop = !1;
      return;
    }
    return this._payloadLength = this.consume(2).readUInt16BE(0), this.haveLength();
  }
  /**
   * Gets extended payload length (7+64).
   *
   * @return {(RangeError|undefined)} A possible error
   * @private
   */
  getPayloadLength64() {
    if (this._bufferedBytes < 8) {
      this._loop = !1;
      return;
    }
    const e = this.consume(8), t = e.readUInt32BE(0);
    return t > Math.pow(2, 53 - 32) - 1 ? (this._loop = !1, g(
      RangeError,
      "Unsupported WebSocket frame: payload length > 2^53 - 1",
      !1,
      1009,
      "WS_ERR_UNSUPPORTED_DATA_PAYLOAD_LENGTH"
    )) : (this._payloadLength = t * Math.pow(2, 32) + e.readUInt32BE(4), this.haveLength());
  }
  /**
   * Payload length has been read.
   *
   * @return {(RangeError|undefined)} A possible error
   * @private
   */
  haveLength() {
    if (this._payloadLength && this._opcode < 8 && (this._totalPayloadLength += this._payloadLength, this._totalPayloadLength > this._maxPayload && this._maxPayload > 0))
      return this._loop = !1, g(
        RangeError,
        "Max payload size exceeded",
        !1,
        1009,
        "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
      );
    this._masked ? this._state = Me : this._state = _e;
  }
  /**
   * Reads mask bytes.
   *
   * @private
   */
  getMask() {
    if (this._bufferedBytes < 4) {
      this._loop = !1;
      return;
    }
    this._mask = this.consume(4), this._state = _e;
  }
  /**
   * Reads data bytes.
   *
   * @param {Function} cb Callback
   * @return {(Error|RangeError|undefined)} A possible error
   * @private
   */
  getData(e) {
    let t = Re;
    if (this._payloadLength) {
      if (this._bufferedBytes < this._payloadLength) {
        this._loop = !1;
        return;
      }
      t = this.consume(this._payloadLength), this._masked && this._mask[0] | this._mask[1] | this._mask[2] | this._mask[3] && Ht(t, this._mask);
    }
    if (this._opcode > 7)
      return this.controlMessage(t);
    if (this._compressed) {
      this._state = Yt, this.decompress(t, e);
      return;
    }
    return t.length && (this._messageLength = this._totalPayloadLength, this._fragments.push(t)), this.dataMessage();
  }
  /**
   * Decompresses data.
   *
   * @param {Buffer} data Compressed data
   * @param {Function} cb Callback
   * @private
   */
  decompress(e, t) {
    this._extensions[Pe.extensionName].decompress(e, this._fin, (i, n) => {
      if (i)
        return t(i);
      if (n.length) {
        if (this._messageLength += n.length, this._messageLength > this._maxPayload && this._maxPayload > 0)
          return t(
            g(
              RangeError,
              "Max payload size exceeded",
              !1,
              1009,
              "WS_ERR_UNSUPPORTED_MESSAGE_LENGTH"
            )
          );
        this._fragments.push(n);
      }
      const o = this.dataMessage();
      if (o)
        return t(o);
      this.startLoop(t);
    });
  }
  /**
   * Handles a data message.
   *
   * @return {(Error|undefined)} A possible error
   * @private
   */
  dataMessage() {
    if (this._fin) {
      const e = this._messageLength, t = this._fragments;
      if (this._totalPayloadLength = 0, this._messageLength = 0, this._fragmented = 0, this._fragments = [], this._opcode === 2) {
        let r;
        this._binaryType === "nodebuffer" ? r = de(t, e) : this._binaryType === "arraybuffer" ? r = Vt(de(t, e)) : r = t, this.emit("message", r, !0);
      } else {
        const r = de(t, e);
        if (!this._skipUTF8Validation && !Ue(r))
          return this._loop = !1, g(
            Error,
            "invalid UTF-8 sequence",
            !0,
            1007,
            "WS_ERR_INVALID_UTF8"
          );
        this.emit("message", r, !1);
      }
    }
    this._state = A;
|
981 |
+
}
|
982 |
+
/**
|
983 |
+
* Handles a control message.
|
984 |
+
*
|
985 |
+
* @param {Buffer} data Data to handle
|
986 |
+
* @return {(Error|RangeError|undefined)} A possible error
|
987 |
+
* @private
|
988 |
+
*/
|
989 |
+
controlMessage(e) {
|
990 |
+
if (this._opcode === 8)
|
991 |
+
if (this._loop = !1, e.length === 0)
|
992 |
+
this.emit("conclude", 1005, Re), this.end();
|
993 |
+
else {
|
994 |
+
const t = e.readUInt16BE(0);
|
995 |
+
if (!zt(t))
|
996 |
+
return g(
|
997 |
+
RangeError,
|
998 |
+
`invalid status code ${t}`,
|
999 |
+
!0,
|
1000 |
+
1002,
|
1001 |
+
"WS_ERR_INVALID_CLOSE_CODE"
|
1002 |
+
);
|
1003 |
+
const r = new X(
|
1004 |
+
e.buffer,
|
1005 |
+
e.byteOffset + 2,
|
1006 |
+
e.length - 2
|
1007 |
+
);
|
1008 |
+
if (!this._skipUTF8Validation && !Ue(r))
|
1009 |
+
return g(
|
1010 |
+
Error,
|
1011 |
+
"invalid UTF-8 sequence",
|
1012 |
+
!0,
|
1013 |
+
1007,
|
1014 |
+
"WS_ERR_INVALID_UTF8"
|
1015 |
+
);
|
1016 |
+
this.emit("conclude", t, r), this.end();
|
1017 |
+
}
|
1018 |
+
else
|
1019 |
+
this._opcode === 9 ? this.emit("ping", e) : this.emit("pong", e);
|
1020 |
+
this._state = A;
|
1021 |
+
}
|
1022 |
+
};
|
1023 |
+
var rt = qt;
|
1024 |
+
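// For orientation, a minimal sketch of the header fields getInfo() reads and
// validates above (field names are assumed; the bundle works on the raw buffer):
function parseFrameHeader(b) {
  return {
    fin: (b[0] & 0x80) === 0x80,     // final fragment of a message
    rsv1: (b[0] & 0x40) === 0x40,    // set only when permessage-deflate was negotiated
    opcode: b[0] & 0x0f,             // 0 continuation, 1 text, 2 binary, 8 close, 9 ping, 10 pong
    masked: (b[1] & 0x80) === 0x80,  // required client→server, forbidden server→client
    payloadLength: b[1] & 0x7f       // 126 → 16-bit length follows, 127 → 64-bit length follows
  };
}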
function g(s, e, t, r, i) {
  const n = new s(
    t ? `Invalid WebSocket frame: ${e}` : e
  );
  return Error.captureStackTrace(n, g), n.code = i, n[jt] = r, n;
}
/* … Sender: frames outgoing data per the HyBi protocol. The static frame()
   method builds the 2–14-byte header (FIN, RSV1, opcode, 7/16/64-bit payload
   length, optional 4-byte masking key) and XOR-masks client payloads;
   close()/ping()/pong() enforce the 125-byte control-payload limit (123 bytes
   for the close reason after the 2-byte status code); send() negotiates
   permessage-deflate through dispatch(), queueing frames while a deflate is
   in flight (enqueue()/dequeue()) and writing them with sendFrame(), which
   corks the socket so header and payload go out together. Exported as `it`. … */
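// A minimal sketch of the 4-byte XOR masking Sender.frame() applies to
// client-to-server payloads (RFC 6455 §5.3); applyMask is a hypothetical
// stand-in for the bundled mask helper, and unmasking is the same XOR:
function applyMask(source, mask, output, offset, length) {
  for (let i = 0; i < length; i++) {
    output[offset + i] = source[i] ^ mask[i & 3];
  }
}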
/* … WHATWG-style event shims: Event, CloseEvent (code/reason/wasClean),
   ErrorEvent (error/message) and MessageEvent (data) classes, plus an
   EventTarget mixin whose addEventListener()/removeEventListener() wrap
   listeners for the "open", "message", "error" and "close" events; Z()
   invokes either a plain listener function or an object's handleEvent().
   Exported as `ts`. … */
/* … Sec-WebSocket-Extensions header handling: ss() parses an offer/response
   string into per-extension parameter objects (quoted values and backslash
   escapes included), and rs() serializes such objects back into a header
   value. Exported as `nt` ({ format, parse }). … */
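// Hypothetical round-trip through the parser/formatter exported above as
// `nt` — the exact result shape is inferred from the code, not documented:
const offer = nt.parse("permessage-deflate; client_max_window_bits");
// offer ≈ { "permessage-deflate": [{ client_max_window_bits: [true] }] }
nt.format(offer); // "permessage-deflate; client_max_window_bits"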
/* … the WebSocket class itself: binaryType/bufferedAmount/extensions/
   isPaused/protocol/readyState/url accessors; setSocket() wires a Receiver
   and Sender to the upgraded socket and emits "open"; close() runs the
   closing handshake with a 30-second destroy timer; pause()/resume() gate
   the underlying socket; ping()/pong()/send() throw while CONNECTING and
   account for bytes sent after close; terminate() destroys the socket.
   CONNECTING/OPEN/CLOSING/CLOSED are defined as 0–3, the
   on{open,error,close,message} attribute accessors are installed, and the
   class is exported as `ft`. … */
/* … ht() (client connection setup): validates the ws:/wss:/ws+unix: URL and
   the protocol version (8 or 13), builds the HTTP upgrade request — a random
   Sec-WebSocket-Key, an optional permessage-deflate offer, the subprotocol
   list, Basic auth taken from the URL — follows up to maxRedirects redirects
   (dropping authorization/cookie headers on cross-origin hops), then checks
   the 101 response: the Upgrade header, the Sec-WebSocket-Accept digest, and
   the negotiated subprotocol and extensions, before handing the socket to
   setSocket(). … */
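// The upgrade request assembled by ht() boils down to headers like these
// (a sketch; the key is freshly random for each connection):
const { randomBytes } = require("crypto");
const upgradeHeaders = {
  "Sec-WebSocket-Version": 13,
  "Sec-WebSocket-Key": randomBytes(16).toString("base64"),
  Connection: "Upgrade",
  Upgrade: "websocket"
};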
/* … connection teardown helpers: ee() emits "error" then "close"; bs()/xs()
   create the plain-net and TLS connections; b() aborts the opening
   handshake; ve() accounts for data sent after close and surfaces
   "WebSocket is not open" errors; plus the receiver/socket listeners for
   "conclude", "drain", "error", "message", "ping", "pong" and the socket's
   "close"/"data"/"end"/"error" events that flush the receiver and finally
   emit "close". … */
/* … Ps(): parses a Sec-WebSocket-Protocol header into a Set of subprotocol
   names, rejecting invalid characters and duplicates. Exported as
   `Rs` ({ parse }). … */
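// Hypothetical usage of the subprotocol parser exported above as `Rs`:
Rs.parse("chat, superchat"); // → Set { "chat", "superchat" }
Rs.parse("chat, chat");      // → SyntaxError: The "chat" subprotocol is duplicated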
/* … WebSocketServer (class As): accepts exactly one of the port/server/
   noServer modes, answers plain HTTP requests on its own server with
   426 Upgrade Required, optionally tracks clients, and implements
   address(), close(), shouldHandle() (path matching) and handleUpgrade(),
   which validates the GET method, the Upgrade and
   Sec-WebSocket-Key/-Version headers, parses subprotocols and
   permessage-deflate offers, runs the optional verifyClient hook, and then
   calls completeUpgrade() … */
|
2374 |
+
* @private
|
2375 |
+
*/
|
2376 |
+
completeUpgrade(e, t, r, i, n, o, l) {
|
2377 |
+
if (!n.readable || !n.writable)
|
2378 |
+
return n.destroy();
|
2379 |
+
if (n[Ds])
|
2380 |
+
throw new Error(
|
2381 |
+
"server.handleUpgrade() was called more than once with the same socket, possibly due to a misconfiguration"
|
2382 |
+
);
|
2383 |
+
if (this._state > Ke)
|
2384 |
+
return H(n, 503);
|
2385 |
+
const a = [
|
2386 |
+
"HTTP/1.1 101 Switching Protocols",
|
2387 |
+
"Upgrade: websocket",
|
2388 |
+
"Connection: Upgrade",
|
2389 |
+
`Sec-WebSocket-Accept: ${Bs("sha1").update(t + Is).digest("base64")}`
|
2390 |
+
], c = new this.options.WebSocket(null);
|
2391 |
+
if (r.size) {
|
2392 |
+
const h = this.options.handleProtocols ? this.options.handleProtocols(r, i) : r.values().next().value;
|
2393 |
+
h && (a.push(`Sec-WebSocket-Protocol: ${h}`), c._protocol = h);
|
2394 |
+
}
|
2395 |
+
if (e[N.extensionName]) {
|
2396 |
+
const h = e[N.extensionName].params, p = qe.format({
|
2397 |
+
[N.extensionName]: [h]
|
2398 |
+
});
|
2399 |
+
a.push(`Sec-WebSocket-Extensions: ${p}`), c._extensions = e;
|
2400 |
+
}
|
2401 |
+
this.emit("headers", a, i), n.write(a.concat(`\r
|
2402 |
+
`).join(`\r
|
2403 |
+
`)), n.removeListener("error", Ze), c.setSocket(n, o, {
|
2404 |
+
maxPayload: this.options.maxPayload,
|
2405 |
+
skipUTF8Validation: this.options.skipUTF8Validation
|
2406 |
+
}), this.clients && (this.clients.add(c), c.on("close", () => {
|
2407 |
+
this.clients.delete(c), this._shouldEmitClose && !this.clients.size && process.nextTick(G, this);
|
2408 |
+
})), l(c, i);
|
2409 |
+
}
|
2410 |
+
}
|
2411 |
+
var Fs = As;
|
2412 |
+
function js(s, e) {
|
2413 |
+
for (const t of Object.keys(e))
|
2414 |
+
s.on(t, e[t]);
|
2415 |
+
return function() {
|
2416 |
+
for (const r of Object.keys(e))
|
2417 |
+
s.removeListener(r, e[r]);
|
2418 |
+
};
|
2419 |
+
}
|
2420 |
+
function G(s) {
|
2421 |
+
s._state = pt, s.emit("close");
|
2422 |
+
}
|
2423 |
+
function Ze() {
|
2424 |
+
this.destroy();
|
2425 |
+
}
|
2426 |
+
function H(s, e, t, r) {
|
2427 |
+
t = t || ie.STATUS_CODES[e], r = {
|
2428 |
+
Connection: "close",
|
2429 |
+
"Content-Type": "text/html",
|
2430 |
+
"Content-Length": Buffer.byteLength(t),
|
2431 |
+
...r
|
2432 |
+
}, s.once("finish", s.destroy), s.end(
|
2433 |
+
`HTTP/1.1 ${e} ${ie.STATUS_CODES[e]}\r
|
2434 |
+
` + Object.keys(r).map((i) => `${i}: ${r[i]}`).join(`\r
|
2435 |
+
`) + `\r
|
2436 |
+
\r
|
2437 |
+
` + t
|
2438 |
+
);
|
2439 |
+
}
|
2440 |
+
function R(s, e, t, r, i) {
|
2441 |
+
if (s.listenerCount("wsClientError")) {
|
2442 |
+
const n = new Error(i);
|
2443 |
+
Error.captureStackTrace(n, R), s.emit("wsClientError", n, t, e);
|
2444 |
+
} else
|
2445 |
+
H(t, r, i);
|
2446 |
+
}
|
2447 |
+
const Zs = /* @__PURE__ */ z(Fs);
|
2448 |
+
export {
|
2449 |
+
qs as Receiver,
|
2450 |
+
Ks as Sender,
|
2451 |
+
Xs as WebSocket,
|
2452 |
+
Zs as WebSocketServer,
|
2453 |
+
Vs as createWebSocketStream,
|
2454 |
+
Xs as default
|
2455 |
+
};
|
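The block above closes out the component's prebuilt `wrapper-*.js` bundle; it is the compiled `ws`-style WebSocket server (note the `Receiver`/`Sender`/`WebSocketServer` exports). The `Sec-WebSocket-Accept` value it computes via `Bs("sha1").update(t + Is).digest("base64")` is the standard RFC 6455 handshake, with `Is` presumably holding the RFC's magic GUID. As a minimal sketch, the same computation in Python (`websocket_accept` is an illustrative name, not from the bundle):

# --- illustrative sketch, not part of the uploaded files ---
# Reproduces the RFC 6455 Sec-WebSocket-Accept computation performed above.
import base64
import hashlib

GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"  # RFC 6455 magic string

def websocket_accept(sec_websocket_key: str) -> str:
    digest = hashlib.sha1((sec_websocket_key + GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")

# Known test vector from RFC 6455, section 1.3:
assert websocket_accept("dGhlIHNhbXBsZSBub25jZQ==") == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="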
src/backend/gradio_image_prompter/templates/example/index.js
ADDED
@@ -0,0 +1,263 @@
const { setContext: ee, getContext: p } = window.__gradio__svelte__internal, v = "WORKER_PROXY_CONTEXT_KEY";
function y() {
  return p(v);
}
function k(l) {
  return l.host === window.location.host || l.host === "localhost:7860" || l.host === "127.0.0.1:7860" || // Ref: https://github.com/gradio-app/gradio/blob/v3.32.0/js/app/src/Index.svelte#L194
  l.host === "lite.local";
}
async function f(l) {
  if (l == null)
    return l;
  const e = new URL(l);
  if (!k(e) || e.protocol !== "http:" && e.protocol !== "https:")
    return l;
  const r = y();
  if (r == null)
    return l;
  const n = e.pathname;
  return r.httpRequest({
    method: "GET",
    path: n,
    headers: {},
    query_string: ""
  }).then((t) => {
    if (t.status !== 200)
      throw new Error(`Failed to get file ${n} from the Wasm worker.`);
    const o = new Blob([t.body], {
      type: t.headers["Content-Type"]
    });
    return URL.createObjectURL(o);
  });
}
const {
  SvelteComponent: w,
  append: C,
  assign: _,
  compute_rest_props: d,
  detach: u,
  element: b,
  empty: E,
  exclude_internal_props: R,
  get_spread_update: O,
  handle_promise: h,
  init: q,
  insert: m,
  noop: c,
  safe_not_equal: T,
  set_attributes: g,
  set_data: P,
  set_style: U,
  src_url_equal: W,
  text: K,
  update_await_block_branch: X
} = window.__gradio__svelte__internal;
function Y(l) {
  let e, r = (
    /*error*/
    l[3].message + ""
  ), n;
  return {
    c() {
      e = b("p"), n = K(r), U(e, "color", "red");
    },
    m(t, o) {
      m(t, e, o), C(e, n);
    },
    p(t, o) {
      o & /*src*/
      1 && r !== (r = /*error*/
      t[3].message + "") && P(n, r);
    },
    d(t) {
      t && u(e);
    }
  };
}
function L(l) {
  let e, r, n = [
    {
      src: r = /*resolved_src*/
      l[2]
    },
    /*$$restProps*/
    l[1]
  ], t = {};
  for (let o = 0; o < n.length; o += 1)
    t = _(t, n[o]);
  return {
    c() {
      e = b("img"), g(e, t);
    },
    m(o, s) {
      m(o, e, s);
    },
    p(o, s) {
      g(e, t = O(n, [
        s & /*src*/
        1 && !W(e.src, r = /*resolved_src*/
        o[2]) && { src: r },
        s & /*$$restProps*/
        2 && /*$$restProps*/
        o[1]
      ]));
    },
    d(o) {
      o && u(e);
    }
  };
}
function N(l) {
  return { c, m: c, p: c, d: c };
}
function S(l) {
  let e, r, n = {
    ctx: l,
    current: null,
    token: null,
    hasCatch: !0,
    pending: N,
    then: L,
    catch: Y,
    value: 2,
    error: 3
  };
  return h(r = f(
    /*src*/
    l[0]
  ), n), {
    c() {
      e = E(), n.block.c();
    },
    m(t, o) {
      m(t, e, o), n.block.m(t, n.anchor = o), n.mount = () => e.parentNode, n.anchor = e;
    },
    p(t, [o]) {
      l = t, n.ctx = l, o & /*src*/
      1 && r !== (r = f(
        /*src*/
        l[0]
      )) && h(r, n) || X(n, l, o);
    },
    i: c,
    o: c,
    d(t) {
      t && u(e), n.block.d(t), n.token = null, n = null;
    }
  };
}
function j(l, e, r) {
  const n = ["src"];
  let t = d(e, n), { src: o = void 0 } = e;
  return l.$$set = (s) => {
    e = _(_({}, e), R(s)), r(1, t = d(e, n)), "src" in s && r(0, o = s.src);
  }, [o, t];
}
class B extends w {
  constructor(e) {
    super(), q(this, e, j, S, T, { src: 0 });
  }
}
const {
  SvelteComponent: F,
  attr: G,
  create_component: I,
  destroy_component: z,
  detach: A,
  element: D,
  init: H,
  insert: J,
  mount_component: M,
  safe_not_equal: Q,
  toggle_class: i,
  transition_in: V,
  transition_out: Z
} = window.__gradio__svelte__internal;
function x(l) {
  let e, r, n;
  return r = new B({
    props: {
      src: (
        /*samples_dir*/
        l[1] + /*value*/
        l[0]
      ),
      alt: ""
    }
  }), {
    c() {
      e = D("div"), I(r.$$.fragment), G(e, "class", "container svelte-h11ksk"), i(
        e,
        "table",
        /*type*/
        l[2] === "table"
      ), i(
        e,
        "gallery",
        /*type*/
        l[2] === "gallery"
      ), i(
        e,
        "selected",
        /*selected*/
        l[3]
      );
    },
    m(t, o) {
      J(t, e, o), M(r, e, null), n = !0;
    },
    p(t, [o]) {
      const s = {};
      o & /*samples_dir, value*/
      3 && (s.src = /*samples_dir*/
      t[1] + /*value*/
      t[0]), r.$set(s), (!n || o & /*type*/
      4) && i(
        e,
        "table",
        /*type*/
        t[2] === "table"
      ), (!n || o & /*type*/
      4) && i(
        e,
        "gallery",
        /*type*/
        t[2] === "gallery"
      ), (!n || o & /*selected*/
      8) && i(
        e,
        "selected",
        /*selected*/
        t[3]
      );
    },
    i(t) {
      n || (V(r.$$.fragment, t), n = !0);
    },
    o(t) {
      Z(r.$$.fragment, t), n = !1;
    },
    d(t) {
      t && A(e), z(r);
    }
  };
}
function $(l, e, r) {
  let { value: n } = e, { samples_dir: t } = e, { type: o } = e, { selected: s = !1 } = e;
  return l.$$set = (a) => {
    "value" in a && r(0, n = a.value), "samples_dir" in a && r(1, t = a.samples_dir), "type" in a && r(2, o = a.type), "selected" in a && r(3, s = a.selected);
  }, [n, t, o, s];
}
class te extends F {
  constructor(e) {
    super(), H(this, e, $, x, Q, {
      value: 0,
      samples_dir: 1,
      type: 2,
      selected: 3
    });
  }
}
export {
  te as default
};
src/backend/gradio_image_prompter/templates/example/style.css
ADDED
@@ -0,0 +1 @@
.container.svelte-h11ksk img{width:100%;height:100%}.container.selected.svelte-h11ksk{border-color:var(--border-color-accent)}.container.table.svelte-h11ksk{margin:0 auto;border:2px solid var(--border-color-primary);border-radius:var(--radius-lg);overflow:hidden;width:var(--size-20);height:var(--size-20);object-fit:cover}.container.gallery.svelte-h11ksk{height:var(--size-20);max-height:var(--size-20);object-fit:cover}
src/demo/__init__.py
ADDED
File without changes
src/demo/app.py
ADDED
@@ -0,0 +1,9 @@
import gradio as gr
from gradio_image_prompter import ImagePrompter

demo = gr.Interface(
    lambda prompts: (prompts["image"], prompts["points"]),
    ImagePrompter(show_label=False),
    [gr.Image(show_label=False), gr.Dataframe(label="Points")],
)
demo.launch()
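Each row of `prompts["points"]` that this demo echoes into the dataframe carries six numbers; the frontend types the field as `number[][6]` in `Index.svelte` further down. A small sketch of splitting such rows into click prompts and box prompts, assuming the SAM-style row convention this component appears to use (a positive click stored as `[x, y, 1, 0, 0, 4]`, a box stored as `[x1, y1, 2, x2, y2, 3]`); `split_prompts` is an illustrative helper, not part of the package:

# --- illustrative sketch, not part of the uploaded files ---
def split_prompts(rows):
    """Split six-number prompt rows into point clicks and boxes.

    Assumes SAM-style labels: 2/3 mark the two corners of a box,
    1 marks a positive point click (padded with 0, 0, 4).
    """
    points, boxes = [], []
    for x1, y1, kind, x2, y2, _tail in rows or []:
        if kind == 2:
            boxes.append([x1, y1, x2, y2])
        else:
            points.append([x1, y1])
    return points, boxes

# Example: one click at (40, 52) and one box from (10, 10) to (200, 160).
pts, bxs = split_prompts([[40, 52, 1, 0, 0, 4], [10, 10, 2, 200, 160, 3]])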
src/frontend/Example.svelte
ADDED
@@ -0,0 +1,44 @@
<script lang="ts">
  import Image from "./shared/Image.svelte";

  export let value: string;
  export let samples_dir: string;
  export let type: "gallery" | "table";
  export let selected = false;
</script>

<div
  class="container"
  class:table={type === "table"}
  class:gallery={type === "gallery"}
  class:selected
>
  <Image src={samples_dir + value} alt="" />
</div>

<style>
  .container :global(img) {
    width: 100%;
    height: 100%;
  }

  .container.selected {
    border-color: var(--border-color-accent);
  }

  .container.table {
    margin: 0 auto;
    border: 2px solid var(--border-color-primary);
    border-radius: var(--radius-lg);
    overflow: hidden;
    width: var(--size-20);
    height: var(--size-20);
    object-fit: cover;
  }

  .container.gallery {
    height: var(--size-20);
    max-height: var(--size-20);
    object-fit: cover;
  }
</style>
src/frontend/Index.svelte
ADDED
@@ -0,0 +1,167 @@
<svelte:options accessors={true} />

<script context="module" lang="ts">
  export { default as BaseImageUploader } from "./shared/ImageUploader.svelte";
  export { default as BaseStaticImage } from "./shared/ImagePreview.svelte";
  export { default as BaseExample } from "./Example.svelte";
  export { default as BaseImage } from "./shared/Image.svelte";
  export { default as BoxDrawer } from "./shared/BoxDrawer.svelte";
</script>

<script lang="ts">
  import type { Gradio, SelectData } from "@gradio/utils";
  import StaticImage from "./shared/ImagePreview.svelte";
  import ImageUploader from "./shared/ImageUploader.svelte";

  import { Block, Empty, UploadText } from "@gradio/atoms";
  import { Image } from "@gradio/icons";
  import { StatusTracker } from "@gradio/statustracker";
  import type { FileData } from "@gradio/client";
  import type { LoadingStatus } from "@gradio/statustracker";
  import { normalise_file } from "@gradio/client";

  export let elem_id = "";
  export let elem_classes: string[] = [];
  export let visible = true;

  export let value: { image: FileData; points: number[][6] } | null = null;
  $: _image = value && normalise_file(value.image, root, proxy_url);
  $: _points = value && value.points;

  export let label: string;
  export let show_label: boolean;
  export let show_download_button: boolean;
  export let root: string;
  export let proxy_url: null | string;

  export let height: number | undefined;
  export let width: number | undefined;

  export let _selectable = false;
  export let container = true;
  export let scale: number | null = null;
  export let min_width: number | undefined = undefined;
  export let loading_status: LoadingStatus;
  export let show_share_button = false;
  export let sources: "upload"[] = ["upload"];
  export let interactive: boolean;
  export let streaming: boolean;

  export let gradio: Gradio<{
    change: never;
    error: string;
    edit: never;
    stream: never;
    drag: never;
    upload: never;
    clear: never;
    select: SelectData;
    share: ShareData;
  }>;

  $: url = _image?.url;
  $: url && gradio.dispatch("change");

  let dragging: boolean;
  let active_tool: null | "webcam" = null;
</script>

{#if !interactive}
  <Block
    {visible}
    variant={"solid"}
    border_mode={dragging ? "focus" : "base"}
    padding={false}
    {elem_id}
    {elem_classes}
    height={height || undefined}
    {width}
    allow_overflow={false}
    {container}
    {scale}
    {min_width}
  >
    <StatusTracker
      autoscroll={gradio.autoscroll}
      i18n={gradio.i18n}
      {...loading_status}
    />
    <StaticImage
      on:select={({ detail }) => gradio.dispatch("select", detail)}
      on:share={({ detail }) => gradio.dispatch("share", detail)}
      on:error={({ detail }) => gradio.dispatch("error", detail)}
      value={_image}
      {label}
      {show_label}
      {show_download_button}
      selectable={_selectable}
      {show_share_button}
      i18n={gradio.i18n}
    />
  </Block>
{:else}
  <Block
    {visible}
    variant={_image === null ? "dashed" : "solid"}
    border_mode={dragging ? "focus" : "base"}
    padding={false}
    {elem_id}
    {elem_classes}
    height={height || undefined}
    {width}
    allow_overflow={false}
    {container}
    {scale}
    {min_width}
  >
    <StatusTracker
      autoscroll={gradio.autoscroll}
      i18n={gradio.i18n}
      {...loading_status}
    />

    <ImageUploader
      bind:active_tool
      bind:value={_image}
      bind:points={_points}
      {root}
      {sources}
      on:points_change={({ detail }) => (value.points = detail)}
      on:edit={() => gradio.dispatch("edit")}
      on:clear={() => {
        value = null;
        gradio.dispatch("clear");
        gradio.dispatch("change");
      }}
      on:stream={() => gradio.dispatch("stream")}
      on:drag={({ detail }) => (dragging = detail)}
      on:upload={({ detail }) => {
        if (value == null) {
          value = { image: detail, points: null };
        } else {
          value.image = detail;
        }
        gradio.dispatch("upload");
      }}
      on:select={({ detail }) => gradio.dispatch("select", detail)}
      on:share={({ detail }) => gradio.dispatch("share", detail)}
      on:error={({ detail }) => {
        loading_status = loading_status;
        loading_status.status = "error";
        gradio.dispatch("error", detail);
      }}
      on:click={() => gradio.dispatch("error", "bad thing happened")}
      on:error
      {label}
      {show_label}
      {streaming}
      i18n={gradio.i18n}
    >
      {#if sources.includes("upload")}
        <UploadText i18n={gradio.i18n} type="image" mode="short" />
      {:else}
        <Empty unpadded_box={true} size="large"><Image /></Empty>
      {/if}
    </ImageUploader>
  </Block>
{/if}
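Two details of the contract above matter for backend code: `value` is `{ image: FileData; points: number[][6] } | null`, and the `on:upload` handler initializes `points: null`, so `points` can arrive as `None` before anything has been drawn. A defensive sketch under those assumptions (`normalize_value` is an illustrative name, not part of the package):

# --- illustrative sketch, not part of the uploaded files ---
def normalize_value(value):
    """Return (image, rows) from an ImagePrompter payload, tolerating nulls."""
    if value is None or value.get("image") is None:
        return None, []
    rows = value.get("points") or []  # "points" is null until the user draws
    if any(len(row) != 6 for row in rows):
        raise ValueError("each prompt row is expected to carry six numbers")
    return value["image"], rows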
src/frontend/package-lock.json
ADDED
@@ -0,0 +1,718 @@
{
  "name": "gradio_image_prompter",
  "version": "0.4.2",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "gradio_image_prompter",
      "version": "0.4.2",
      "license": "ISC",
      "dependencies": {
        "@gradio/atoms": "0.3.1",
        "@gradio/client": "0.8.2",
        "@gradio/icons": "0.3.1",
        "@gradio/statustracker": "0.4.1",
        "@gradio/upload": "0.5.2",
        "@gradio/utils": "0.2.0",
        "@gradio/wasm": "0.3.0",
        "cropperjs": "^1.5.12",
        "lazy-brush": "^1.0.1",
        "resize-observer-polyfill": "^1.5.1"
      }
    },
    "node_modules/@ampproject/remapping": {
      "version": "2.2.1",
      "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.2.1.tgz",
      "integrity": "sha512-lFMjJTrFL3j7L9yBxwYfCq2k6qqwHyzuUl/XBnif78PWTJYyL/dfowQHWE3sp6U6ZzqWiiIZnpTMO96zhkjwtg==",
      "peer": true,
      "dependencies": {
        "@jridgewell/gen-mapping": "^0.3.0",
        "@jridgewell/trace-mapping": "^0.3.9"
      },
      "engines": {
        "node": ">=6.0.0"
      }
    },
    "node_modules/@esbuild/darwin-arm64": {
      "version": "0.19.8",
      "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.19.8.tgz",
      "integrity": "sha512-RQw9DemMbIq35Bprbboyf8SmOr4UXsRVxJ97LgB55VKKeJOOdvsIPy0nFyF2l8U+h4PtBx/1kRf0BelOYCiQcw==",
      "cpu": [
        "arm64"
      ],
      "optional": true,
      "os": [
        "darwin"
      ],
      "engines": {
        "node": ">=12"
      }
    },
    "node_modules/@formatjs/ecma402-abstract": {
      "version": "1.11.4",
      "resolved": "https://registry.npmjs.org/@formatjs/ecma402-abstract/-/ecma402-abstract-1.11.4.tgz",
      "integrity": "sha512-EBikYFp2JCdIfGEb5G9dyCkTGDmC57KSHhRQOC3aYxoPWVZvfWCDjZwkGYHN7Lis/fmuWl906bnNTJifDQ3sXw==",
      "dependencies": {
        "@formatjs/intl-localematcher": "0.2.25",
        "tslib": "^2.1.0"
      }
    },
    "node_modules/@formatjs/fast-memoize": {
      "version": "1.2.1",
      "resolved": "https://registry.npmjs.org/@formatjs/fast-memoize/-/fast-memoize-1.2.1.tgz",
      "integrity": "sha512-Rg0e76nomkz3vF9IPlKeV+Qynok0r7YZjL6syLz4/urSg0IbjPZCB/iYUMNsYA643gh4mgrX3T7KEIFIxJBQeg==",
      "dependencies": {
        "tslib": "^2.1.0"
      }
    },
    "node_modules/@formatjs/icu-messageformat-parser": {
      "version": "2.1.0",
      "resolved": "https://registry.npmjs.org/@formatjs/icu-messageformat-parser/-/icu-messageformat-parser-2.1.0.tgz",
      "integrity": "sha512-Qxv/lmCN6hKpBSss2uQ8IROVnta2r9jd3ymUEIjm2UyIkUCHVcbUVRGL/KS/wv7876edvsPe+hjHVJ4z8YuVaw==",
      "dependencies": {
        "@formatjs/ecma402-abstract": "1.11.4",
        "@formatjs/icu-skeleton-parser": "1.3.6",
        "tslib": "^2.1.0"
      }
    },
    "node_modules/@formatjs/icu-skeleton-parser": {
      "version": "1.3.6",
      "resolved": "https://registry.npmjs.org/@formatjs/icu-skeleton-parser/-/icu-skeleton-parser-1.3.6.tgz",
      "integrity": "sha512-I96mOxvml/YLrwU2Txnd4klA7V8fRhb6JG/4hm3VMNmeJo1F03IpV2L3wWt7EweqNLES59SZ4d6hVOPCSf80Bg==",
      "dependencies": {
        "@formatjs/ecma402-abstract": "1.11.4",
        "tslib": "^2.1.0"
      }
    },
    "node_modules/@formatjs/intl-localematcher": {
      "version": "0.2.25",
      "resolved": "https://registry.npmjs.org/@formatjs/intl-localematcher/-/intl-localematcher-0.2.25.tgz",
      "integrity": "sha512-YmLcX70BxoSopLFdLr1Ds99NdlTI2oWoLbaUW2M406lxOIPzE1KQhRz2fPUkq34xVZQaihCoU29h0KK7An3bhA==",
      "dependencies": {
        "tslib": "^2.1.0"
      }
    },
    "node_modules/@gradio/atoms": {
      "version": "0.3.1",
      "resolved": "https://registry.npmjs.org/@gradio/atoms/-/atoms-0.3.1.tgz",
      "integrity": "sha512-P2u1Qud/EmwfGMD9HZdSkw4L3RznGUE3owBx4lRY7JP/1J3sDqy/wN8pZFex+kPKripX29+IiH6+4TRqSs2zFw==",
      "dependencies": {
        "@gradio/icons": "^0.3.1",
        "@gradio/utils": "^0.2.0"
      }
    },
    "node_modules/@gradio/client": {
      "version": "0.8.2",
      "resolved": "https://registry.npmjs.org/@gradio/client/-/client-0.8.2.tgz",
      "integrity": "sha512-ZWrkJBsVg7ioIHhGV1pqIo4MBL0GPn0SHLeA04cqrsxkWiZHZz9CB5wFtm1kaFtd68ERAgEzR8OYVzzlBd2pyQ==",
      "dependencies": {
        "bufferutil": "^4.0.7",
        "semiver": "^1.1.0",
        "ws": "^8.13.0"
      },
      "engines": {
        "node": ">=18.0.0"
      }
    },
    "node_modules/@gradio/column": {
      "version": "0.1.0",
      "resolved": "https://registry.npmjs.org/@gradio/column/-/column-0.1.0.tgz",
      "integrity": "sha512-P24nqqVnMXBaDA1f/zSN5HZRho4PxP8Dq+7VltPHlmxIEiZYik2AJ4J0LeuIha34FDO0guu/16evdrpvGIUAfw=="
    },
    "node_modules/@gradio/icons": {
      "version": "0.3.1",
      "resolved": "https://registry.npmjs.org/@gradio/icons/-/icons-0.3.1.tgz",
      "integrity": "sha512-ZwgXODKa7irD+spE0RCae8fyixgwKOtds6wHL300n9pIRYzL9QkvS1cQJbz0C6NupFCYRSGTQrV5hoLo7yQCew=="
    },
    "node_modules/@gradio/statustracker": {
      "version": "0.4.1",
      "resolved": "https://registry.npmjs.org/@gradio/statustracker/-/statustracker-0.4.1.tgz",
      "integrity": "sha512-6YV5UDzau/nNid5D25YLZyPGm/tFd9b0a+x0OCHY+aE3cez7PD4v6hWGuQXPNwa/69viRm8YyoQ2Vex7/3updA==",
      "dependencies": {
        "@gradio/atoms": "^0.3.1",
        "@gradio/column": "^0.1.0",
        "@gradio/icons": "^0.3.1",
        "@gradio/utils": "^0.2.0"
      }
    },
    "node_modules/@gradio/theme": {
      "version": "0.2.0",
      "resolved": "https://registry.npmjs.org/@gradio/theme/-/theme-0.2.0.tgz",
      "integrity": "sha512-33c68Nk7oRXLn08OxPfjcPm7S4tXGOUV1I1bVgzdM2YV5o1QBOS1GEnXPZPu/CEYPePLMB6bsDwffrLEyLGWVQ=="
    },
    "node_modules/@gradio/upload": {
      "version": "0.5.2",
      "resolved": "https://registry.npmjs.org/@gradio/upload/-/upload-0.5.2.tgz",
      "integrity": "sha512-IXQZ/+0TG/FSOSjJKE28lUG+vGGboD+YQswyvSK6lOpRHvixiqK+eJo0g3jHvmWO9wZLBrEx3XRv8LSgnVHHzw==",
      "dependencies": {
        "@gradio/atoms": "^0.3.1",
        "@gradio/client": "^0.8.2",
        "@gradio/icons": "^0.3.1",
        "@gradio/upload": "^0.5.2",
        "@gradio/utils": "^0.2.0"
      }
    },
    "node_modules/@gradio/utils": {
      "version": "0.2.0",
      "resolved": "https://registry.npmjs.org/@gradio/utils/-/utils-0.2.0.tgz",
      "integrity": "sha512-YkwzXufi6IxQrlMW+1sFo8Yn6F9NLL69ZoBsbo7QEhms0v5L7pmOTw+dfd7M3dwbRP2lgjrb52i1kAIN3n6aqQ==",
      "dependencies": {
        "@gradio/theme": "^0.2.0",
        "svelte-i18n": "^3.6.0"
      }
    },
    "node_modules/@gradio/wasm": {
      "version": "0.3.0",
      "resolved": "https://registry.npmjs.org/@gradio/wasm/-/wasm-0.3.0.tgz",
      "integrity": "sha512-avgMFBrHUUDzQraBMW9mNgiQMMkObsPzDap0PZV6FgzfDpW8K+R4BBcl+gClq82jRi3ulDjtISTXriUrNNfkrg==",
      "dependencies": {
        "@types/path-browserify": "^1.0.0",
        "path-browserify": "^1.0.1"
      }
    },
    "node_modules/@jridgewell/gen-mapping": {
      "version": "0.3.3",
      "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz",
      "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==",
      "peer": true,
      "dependencies": {
        "@jridgewell/set-array": "^1.0.1",
        "@jridgewell/sourcemap-codec": "^1.4.10",
        "@jridgewell/trace-mapping": "^0.3.9"
      },
      "engines": {
        "node": ">=6.0.0"
      }
    },
    "node_modules/@jridgewell/resolve-uri": {
      "version": "3.1.1",
      "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz",
      "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==",
      "peer": true,
      "engines": {
        "node": ">=6.0.0"
      }
    },
    "node_modules/@jridgewell/set-array": {
      "version": "1.1.2",
      "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz",
      "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==",
      "peer": true,
      "engines": {
        "node": ">=6.0.0"
      }
    },
    "node_modules/@jridgewell/sourcemap-codec": {
      "version": "1.4.15",
      "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz",
      "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==",
      "peer": true
    },
    "node_modules/@jridgewell/trace-mapping": {
      "version": "0.3.20",
      "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.20.tgz",
      "integrity": "sha512-R8LcPeWZol2zR8mmH3JeKQ6QRCFb7XgUhV9ZlGhHLGyg4wpPiPZNQOOWhFZhxKw8u//yTbNGI42Bx/3paXEQ+Q==",
      "peer": true,
      "dependencies": {
        "@jridgewell/resolve-uri": "^3.1.0",
        "@jridgewell/sourcemap-codec": "^1.4.14"
      }
    },
    "node_modules/@types/estree": {
      "version": "1.0.5",
      "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz",
      "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==",
      "peer": true
    },
    "node_modules/@types/path-browserify": {
      "version": "1.0.2",
      "resolved": "https://registry.npmjs.org/@types/path-browserify/-/path-browserify-1.0.2.tgz",
      "integrity": "sha512-ZkC5IUqqIFPXx3ASTTybTzmQdwHwe2C0u3eL75ldQ6T9E9IWFJodn6hIfbZGab73DfyiHN4Xw15gNxUq2FbvBA=="
    },
    "node_modules/acorn": {
      "version": "8.11.2",
      "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.11.2.tgz",
      "integrity": "sha512-nc0Axzp/0FILLEVsm4fNwLCwMttvhEI263QtVPQcbpfZZ3ts0hLsZGOpE6czNlid7CJ9MlyH8reXkpsf3YUY4w==",
      "peer": true,
      "bin": {
        "acorn": "bin/acorn"
      },
      "engines": {
        "node": ">=0.4.0"
      }
    },
    "node_modules/aria-query": {
      "version": "5.3.0",
      "resolved": "https://registry.npmjs.org/aria-query/-/aria-query-5.3.0.tgz",
      "integrity": "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==",
      "peer": true,
      "dependencies": {
        "dequal": "^2.0.3"
      }
    },
    "node_modules/axobject-query": {
      "version": "3.2.1",
      "resolved": "https://registry.npmjs.org/axobject-query/-/axobject-query-3.2.1.tgz",
      "integrity": "sha512-jsyHu61e6N4Vbz/v18DHwWYKK0bSWLqn47eeDSKPB7m8tqMHF9YJ+mhIk2lVteyZrY8tnSj/jHOv4YiTCuCJgg==",
      "peer": true,
      "dependencies": {
        "dequal": "^2.0.3"
      }
    },
    "node_modules/bufferutil": {
      "version": "4.0.8",
      "resolved": "https://registry.npmjs.org/bufferutil/-/bufferutil-4.0.8.tgz",
      "integrity": "sha512-4T53u4PdgsXqKaIctwF8ifXlRTTmEPJ8iEPWFdGZvcf7sbwYo6FKFEX9eNNAnzFZ7EzJAQ3CJeOtCRA4rDp7Pw==",
      "hasInstallScript": true,
      "dependencies": {
        "node-gyp-build": "^4.3.0"
      },
      "engines": {
        "node": ">=6.14.2"
      }
    },
    "node_modules/cli-color": {
      "version": "2.0.3",
      "resolved": "https://registry.npmjs.org/cli-color/-/cli-color-2.0.3.tgz",
      "integrity": "sha512-OkoZnxyC4ERN3zLzZaY9Emb7f/MhBOIpePv0Ycok0fJYT+Ouo00UBEIwsVsr0yoow++n5YWlSUgST9GKhNHiRQ==",
      "dependencies": {
        "d": "^1.0.1",
        "es5-ext": "^0.10.61",
        "es6-iterator": "^2.0.3",
        "memoizee": "^0.4.15",
        "timers-ext": "^0.1.7"
      },
      "engines": {
        "node": ">=0.10"
      }
    },
    "node_modules/code-red": {
      "version": "1.0.4",
      "resolved": "https://registry.npmjs.org/code-red/-/code-red-1.0.4.tgz",
      "integrity": "sha512-7qJWqItLA8/VPVlKJlFXU+NBlo/qyfs39aJcuMT/2ere32ZqvF5OSxgdM5xOfJJ7O429gg2HM47y8v9P+9wrNw==",
      "peer": true,
      "dependencies": {
        "@jridgewell/sourcemap-codec": "^1.4.15",
        "@types/estree": "^1.0.1",
        "acorn": "^8.10.0",
        "estree-walker": "^3.0.3",
        "periscopic": "^3.1.0"
      }
    },
    "node_modules/cropperjs": {
      "version": "1.6.1",
      "resolved": "https://registry.npmjs.org/cropperjs/-/cropperjs-1.6.1.tgz",
      "integrity": "sha512-F4wsi+XkDHCOMrHMYjrTEE4QBOrsHHN5/2VsVAaRq8P7E5z7xQpT75S+f/9WikmBEailas3+yo+6zPIomW+NOA=="
    },
    "node_modules/css-tree": {
      "version": "2.3.1",
      "resolved": "https://registry.npmjs.org/css-tree/-/css-tree-2.3.1.tgz",
      "integrity": "sha512-6Fv1DV/TYw//QF5IzQdqsNDjx/wc8TrMBZsqjL9eW01tWb7R7k/mq+/VXfJCl7SoD5emsJop9cOByJZfs8hYIw==",
      "peer": true,
      "dependencies": {
        "mdn-data": "2.0.30",
        "source-map-js": "^1.0.1"
      },
      "engines": {
        "node": "^10 || ^12.20.0 || ^14.13.0 || >=15.0.0"
      }
    },
    "node_modules/d": {
      "version": "1.0.1",
      "resolved": "https://registry.npmjs.org/d/-/d-1.0.1.tgz",
      "integrity": "sha512-m62ShEObQ39CfralilEQRjH6oAMtNCV1xJyEx5LpRYUVN+EviphDgUc/F3hnYbADmkiNs67Y+3ylmlG7Lnu+FA==",
      "dependencies": {
        "es5-ext": "^0.10.50",
        "type": "^1.0.1"
      }
    },
    "node_modules/deepmerge": {
      "version": "4.3.1",
      "resolved": "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz",
      "integrity": "sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==",
      "engines": {
        "node": ">=0.10.0"
      }
    },
    "node_modules/dequal": {
      "version": "2.0.3",
      "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz",
      "integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==",
      "peer": true,
      "engines": {
        "node": ">=6"
      }
    },
    "node_modules/es5-ext": {
      "version": "0.10.62",
      "resolved": "https://registry.npmjs.org/es5-ext/-/es5-ext-0.10.62.tgz",
      "integrity": "sha512-BHLqn0klhEpnOKSrzn/Xsz2UIW8j+cGmo9JLzr8BiUapV8hPL9+FliFqjwr9ngW7jWdnxv6eO+/LqyhJVqgrjA==",
      "hasInstallScript": true,
      "dependencies": {
        "es6-iterator": "^2.0.3",
        "es6-symbol": "^3.1.3",
        "next-tick": "^1.1.0"
      },
      "engines": {
        "node": ">=0.10"
      }
    },
    "node_modules/es6-iterator": {
      "version": "2.0.3",
      "resolved": "https://registry.npmjs.org/es6-iterator/-/es6-iterator-2.0.3.tgz",
      "integrity": "sha512-zw4SRzoUkd+cl+ZoE15A9o1oQd920Bb0iOJMQkQhl3jNc03YqVjAhG7scf9C5KWRU/R13Orf588uCC6525o02g==",
      "dependencies": {
        "d": "1",
        "es5-ext": "^0.10.35",
        "es6-symbol": "^3.1.1"
      }
    },
    "node_modules/es6-symbol": {
      "version": "3.1.3",
      "resolved": "https://registry.npmjs.org/es6-symbol/-/es6-symbol-3.1.3.tgz",
      "integrity": "sha512-NJ6Yn3FuDinBaBRWl/q5X/s4koRHBrgKAu+yGI6JCBeiu3qrcbJhwT2GeR/EXVfylRk8dpQVJoLEFhK+Mu31NA==",
      "dependencies": {
        "d": "^1.0.1",
        "ext": "^1.1.2"
      }
    },
    "node_modules/es6-weak-map": {
      "version": "2.0.3",
      "resolved": "https://registry.npmjs.org/es6-weak-map/-/es6-weak-map-2.0.3.tgz",
      "integrity": "sha512-p5um32HOTO1kP+w7PRnB+5lQ43Z6muuMuIMffvDN8ZB4GcnjLBV6zGStpbASIMk4DCAvEaamhe2zhyCb/QXXsA==",
      "dependencies": {
        "d": "1",
        "es5-ext": "^0.10.46",
        "es6-iterator": "^2.0.3",
        "es6-symbol": "^3.1.1"
      }
    },
    "node_modules/esbuild": {
      "version": "0.19.8",
      "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.8.tgz",
      "integrity": "sha512-l7iffQpT2OrZfH2rXIp7/FkmaeZM0vxbxN9KfiCwGYuZqzMg/JdvX26R31Zxn/Pxvsrg3Y9N6XTcnknqDyyv4w==",
      "hasInstallScript": true,
      "bin": {
        "esbuild": "bin/esbuild"
      },
      "engines": {
        "node": ">=12"
      },
      "optionalDependencies": {
        "@esbuild/android-arm": "0.19.8",
        "@esbuild/android-arm64": "0.19.8",
        "@esbuild/android-x64": "0.19.8",
        "@esbuild/darwin-arm64": "0.19.8",
        "@esbuild/darwin-x64": "0.19.8",
        "@esbuild/freebsd-arm64": "0.19.8",
        "@esbuild/freebsd-x64": "0.19.8",
        "@esbuild/linux-arm": "0.19.8",
        "@esbuild/linux-arm64": "0.19.8",
        "@esbuild/linux-ia32": "0.19.8",
        "@esbuild/linux-loong64": "0.19.8",
        "@esbuild/linux-mips64el": "0.19.8",
        "@esbuild/linux-ppc64": "0.19.8",
        "@esbuild/linux-riscv64": "0.19.8",
        "@esbuild/linux-s390x": "0.19.8",
        "@esbuild/linux-x64": "0.19.8",
        "@esbuild/netbsd-x64": "0.19.8",
        "@esbuild/openbsd-x64": "0.19.8",
        "@esbuild/sunos-x64": "0.19.8",
        "@esbuild/win32-arm64": "0.19.8",
        "@esbuild/win32-ia32": "0.19.8",
        "@esbuild/win32-x64": "0.19.8"
      }
    },
    "node_modules/estree-walker": {
      "version": "3.0.3",
      "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz",
      "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==",
      "peer": true,
      "dependencies": {
        "@types/estree": "^1.0.0"
      }
    },
    "node_modules/event-emitter": {
      "version": "0.3.5",
      "resolved": "https://registry.npmjs.org/event-emitter/-/event-emitter-0.3.5.tgz",
      "integrity": "sha512-D9rRn9y7kLPnJ+hMq7S/nhvoKwwvVJahBi2BPmx3bvbsEdK3W9ii8cBSGjP+72/LnM4n6fo3+dkCX5FeTQruXA==",
      "dependencies": {
        "d": "1",
        "es5-ext": "~0.10.14"
      }
    },
    "node_modules/ext": {
      "version": "1.7.0",
      "resolved": "https://registry.npmjs.org/ext/-/ext-1.7.0.tgz",
      "integrity": "sha512-6hxeJYaL110a9b5TEJSj0gojyHQAmA2ch5Os+ySCiA1QGdS697XWY1pzsrSjqA9LDEEgdB/KypIlR59RcLuHYw==",
      "dependencies": {
        "type": "^2.7.2"
      }
    },
    "node_modules/ext/node_modules/type": {
      "version": "2.7.2",
      "resolved": "https://registry.npmjs.org/type/-/type-2.7.2.tgz",
      "integrity": "sha512-dzlvlNlt6AXU7EBSfpAscydQ7gXB+pPGsPnfJnZpiNJBDj7IaJzQlBZYGdEi4R9HmPdBv2XmWJ6YUtoTa7lmCw=="
    },
    "node_modules/globalyzer": {
      "version": "0.1.0",
      "resolved": "https://registry.npmjs.org/globalyzer/-/globalyzer-0.1.0.tgz",
      "integrity": "sha512-40oNTM9UfG6aBmuKxk/giHn5nQ8RVz/SS4Ir6zgzOv9/qC3kKZ9v4etGTcJbEl/NyVQH7FGU7d+X1egr57Md2Q=="
    },
    "node_modules/globrex": {
      "version": "0.1.2",
      "resolved": "https://registry.npmjs.org/globrex/-/globrex-0.1.2.tgz",
      "integrity": "sha512-uHJgbwAMwNFf5mLst7IWLNg14x1CkeqglJb/K3doi4dw6q2IvAAmM/Y81kevy83wP+Sst+nutFTYOGg3d1lsxg=="
    },
    "node_modules/intl-messageformat": {
      "version": "9.13.0",
      "resolved": "https://registry.npmjs.org/intl-messageformat/-/intl-messageformat-9.13.0.tgz",
      "integrity": "sha512-7sGC7QnSQGa5LZP7bXLDhVDtQOeKGeBFGHF2Y8LVBwYZoQZCgWeKoPGTa5GMG8g/TzDgeXuYJQis7Ggiw2xTOw==",
      "dependencies": {
        "@formatjs/ecma402-abstract": "1.11.4",
        "@formatjs/fast-memoize": "1.2.1",
        "@formatjs/icu-messageformat-parser": "2.1.0",
        "tslib": "^2.1.0"
      }
    },
    "node_modules/is-promise": {
      "version": "2.2.2",
      "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-2.2.2.tgz",
      "integrity": "sha512-+lP4/6lKUBfQjZ2pdxThZvLUAafmZb8OAxFb8XXtiQmS35INgr85hdOGoEs124ez1FCnZJt6jau/T+alh58QFQ=="
    },
    "node_modules/is-reference": {
      "version": "3.0.2",
      "resolved": "https://registry.npmjs.org/is-reference/-/is-reference-3.0.2.tgz",
      "integrity": "sha512-v3rht/LgVcsdZa3O2Nqs+NMowLOxeOm7Ay9+/ARQ2F+qEoANRcqrjAZKGN0v8ymUetZGgkp26LTnGT7H0Qo9Pg==",
      "peer": true,
      "dependencies": {
        "@types/estree": "*"
      }
    },
    "node_modules/lazy-brush": {
      "version": "1.0.1",
      "resolved": "https://registry.npmjs.org/lazy-brush/-/lazy-brush-1.0.1.tgz",
      "integrity": "sha512-xT/iSClTVi7vLoF8dCWTBhCuOWqsLXCMPa6ucVmVAk6hyNCM5JeS1NLhXqIrJktUg+caEYKlqSOUU4u3cpXzKg=="
    },
    "node_modules/locate-character": {
      "version": "3.0.0",
      "resolved": "https://registry.npmjs.org/locate-character/-/locate-character-3.0.0.tgz",
      "integrity": "sha512-SW13ws7BjaeJ6p7Q6CO2nchbYEc3X3J6WrmTTDto7yMPqVSZTUyY5Tjbid+Ab8gLnATtygYtiDIJGQRRn2ZOiA==",
      "peer": true
    },
    "node_modules/lru-queue": {
      "version": "0.1.0",
      "resolved": "https://registry.npmjs.org/lru-queue/-/lru-queue-0.1.0.tgz",
      "integrity": "sha512-BpdYkt9EvGl8OfWHDQPISVpcl5xZthb+XPsbELj5AQXxIC8IriDZIQYjBJPEm5rS420sjZ0TLEzRcq5KdBhYrQ==",
      "dependencies": {
        "es5-ext": "~0.10.2"
      }
    },
    "node_modules/magic-string": {
      "version": "0.30.5",
      "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.5.tgz",
      "integrity": "sha512-7xlpfBaQaP/T6Vh8MO/EqXSW5En6INHEvEXQiuff7Gku0PWjU3uf6w/j9o7O+SpB5fOAkrI5HeoNgwjEO0pFsA==",
      "peer": true,
      "dependencies": {
        "@jridgewell/sourcemap-codec": "^1.4.15"
      },
      "engines": {
        "node": ">=12"
      }
    },
    "node_modules/mdn-data": {
      "version": "2.0.30",
      "resolved": "https://registry.npmjs.org/mdn-data/-/mdn-data-2.0.30.tgz",
      "integrity": "sha512-GaqWWShW4kv/G9IEucWScBx9G1/vsFZZJUO+tD26M8J8z3Kw5RDQjaoZe03YAClgeS/SWPOcb4nkFBTEi5DUEA==",
      "peer": true
    },
    "node_modules/memoizee": {
      "version": "0.4.15",
      "resolved": "https://registry.npmjs.org/memoizee/-/memoizee-0.4.15.tgz",
      "integrity": "sha512-UBWmJpLZd5STPm7PMUlOw/TSy972M+z8gcyQ5veOnSDRREz/0bmpyTfKt3/51DhEBqCZQn1udM/5flcSPYhkdQ==",
      "dependencies": {
        "d": "^1.0.1",
        "es5-ext": "^0.10.53",
        "es6-weak-map": "^2.0.3",
        "event-emitter": "^0.3.5",
        "is-promise": "^2.2.2",
        "lru-queue": "^0.1.0",
        "next-tick": "^1.1.0",
        "timers-ext": "^0.1.7"
      }
    },
    "node_modules/mri": {
      "version": "1.2.0",
      "resolved": "https://registry.npmjs.org/mri/-/mri-1.2.0.tgz",
      "integrity": "sha512-tzzskb3bG8LvYGFF/mDTpq3jpI6Q9wc3LEmBaghu+DdCssd1FakN7Bc0hVNmEyGq1bq3RgfkCb3cmQLpNPOroA==",
      "engines": {
        "node": ">=4"
      }
    },
    "node_modules/next-tick": {
      "version": "1.1.0",
      "resolved": "https://registry.npmjs.org/next-tick/-/next-tick-1.1.0.tgz",
      "integrity": "sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ=="
    },
    "node_modules/node-gyp-build": {
      "version": "4.7.1",
      "resolved": "https://registry.npmjs.org/node-gyp-build/-/node-gyp-build-4.7.1.tgz",
      "integrity": "sha512-wTSrZ+8lsRRa3I3H8Xr65dLWSgCvY2l4AOnaeKdPA9TB/WYMPaTcrzf3rXvFoVvjKNVnu0CcWSx54qq9GKRUYg==",
      "bin": {
        "node-gyp-build": "bin.js",
        "node-gyp-build-optional": "optional.js",
        "node-gyp-build-test": "build-test.js"
      }
    },
    "node_modules/path-browserify": {
      "version": "1.0.1",
      "resolved": "https://registry.npmjs.org/path-browserify/-/path-browserify-1.0.1.tgz",
      "integrity": "sha512-b7uo2UCUOYZcnF/3ID0lulOJi/bafxa1xPe7ZPsammBSpjSWQkjNxlt635YGS2MiR9GjvuXCtz2emr3jbsz98g=="
    },
    "node_modules/periscopic": {
      "version": "3.1.0",
      "resolved": "https://registry.npmjs.org/periscopic/-/periscopic-3.1.0.tgz",
      "integrity": "sha512-vKiQ8RRtkl9P+r/+oefh25C3fhybptkHKCZSPlcXiJux2tJF55GnEj3BVn4A5gKfq9NWWXXrxkHBwVPUfH0opw==",
      "peer": true,
      "dependencies": {
        "@types/estree": "^1.0.0",
        "estree-walker": "^3.0.0",
        "is-reference": "^3.0.0"
      }
    },
    "node_modules/resize-observer-polyfill": {
      "version": "1.5.1",
      "resolved": "https://registry.npmjs.org/resize-observer-polyfill/-/resize-observer-polyfill-1.5.1.tgz",
      "integrity": "sha512-LwZrotdHOo12nQuZlHEmtuXdqGoOD0OhaxopaNFxWzInpEgaLWoVuAMbTzixuosCx2nEG58ngzW3vxdWoxIgdg=="
    },
    "node_modules/sade": {
      "version": "1.8.1",
      "resolved": "https://registry.npmjs.org/sade/-/sade-1.8.1.tgz",
      "integrity": "sha512-xal3CZX1Xlo/k4ApwCFrHVACi9fBqJ7V+mwhBsuf/1IOKbBy098Fex+Wa/5QMubw09pSZ/u8EY8PWgevJsXp1A==",
      "dependencies": {
        "mri": "^1.1.0"
      },
      "engines": {
        "node": ">=6"
      }
    },
    "node_modules/semiver": {
      "version": "1.1.0",
      "resolved": "https://registry.npmjs.org/semiver/-/semiver-1.1.0.tgz",
      "integrity": "sha512-QNI2ChmuioGC1/xjyYwyZYADILWyW6AmS1UH6gDj/SFUUUS4MBAWs/7mxnkRPc/F4iHezDP+O8t0dO8WHiEOdg==",
      "engines": {
        "node": ">=6"
      }
    },
    "node_modules/source-map-js": {
      "version": "1.0.2",
      "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz",
      "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==",
      "peer": true,
      "engines": {
        "node": ">=0.10.0"
      }
    },
    "node_modules/svelte": {
      "version": "4.2.8",
      "resolved": "https://registry.npmjs.org/svelte/-/svelte-4.2.8.tgz",
      "integrity": "sha512-hU6dh1MPl8gh6klQZwK/n73GiAHiR95IkFsesLPbMeEZi36ydaXL/ZAb4g9sayT0MXzpxyZjR28yderJHxcmYA==",
      "peer": true,
      "dependencies": {
        "@ampproject/remapping": "^2.2.1",
        "@jridgewell/sourcemap-codec": "^1.4.15",
        "@jridgewell/trace-mapping": "^0.3.18",
        "acorn": "^8.9.0",
        "aria-query": "^5.3.0",
        "axobject-query": "^3.2.1",
        "code-red": "^1.0.3",
        "css-tree": "^2.3.1",
        "estree-walker": "^3.0.3",
        "is-reference": "^3.0.1",
        "locate-character": "^3.0.0",
        "magic-string": "^0.30.4",
        "periscopic": "^3.1.0"
      },
      "engines": {
        "node": ">=16"
      }
    },
    "node_modules/svelte-i18n": {
      "version": "3.7.4",
      "resolved": "https://registry.npmjs.org/svelte-i18n/-/svelte-i18n-3.7.4.tgz",
      "integrity": "sha512-yGRCNo+eBT4cPuU7IVsYTYjxB7I2V8qgUZPlHnNctJj5IgbJgV78flsRzpjZ/8iUYZrS49oCt7uxlU3AZv/N5Q==",
      "dependencies": {
        "cli-color": "^2.0.3",
        "deepmerge": "^4.2.2",
        "esbuild": "^0.19.2",
        "estree-walker": "^2",
        "intl-messageformat": "^9.13.0",
        "sade": "^1.8.1",
        "tiny-glob": "^0.2.9"
      },
      "bin": {
        "svelte-i18n": "dist/cli.js"
      },
      "engines": {
        "node": ">= 16"
      },
      "peerDependencies": {
        "svelte": "^3 || ^4"
      }
    },
    "node_modules/svelte-i18n/node_modules/estree-walker": {
      "version": "2.0.2",
      "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz",
      "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w=="
    },
    "node_modules/timers-ext": {
      "version": "0.1.7",
      "resolved": "https://registry.npmjs.org/timers-ext/-/timers-ext-0.1.7.tgz",
      "integrity": "sha512-b85NUNzTSdodShTIbky6ZF02e8STtVVfD+fu4aXXShEELpozH+bCpJLYMPZbsABN2wDH7fJpqIoXxJpzbf0NqQ==",
      "dependencies": {
        "es5-ext": "~0.10.46",
        "next-tick": "1"
      }
    },
    "node_modules/tiny-glob": {
      "version": "0.2.9",
      "resolved": "https://registry.npmjs.org/tiny-glob/-/tiny-glob-0.2.9.tgz",
      "integrity": "sha512-g/55ssRPUjShh+xkfx9UPDXqhckHEsHr4Vd9zX55oSdGZc/MD0m3sferOkwWtp98bv+kcVfEHtRJgBVJzelrzg==",
      "dependencies": {
        "globalyzer": "0.1.0",
        "globrex": "^0.1.2"
      }
    },
    "node_modules/tslib": {
      "version": "2.6.2",
      "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
      "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
    },
    "node_modules/type": {
      "version": "1.2.0",
      "resolved": "https://registry.npmjs.org/type/-/type-1.2.0.tgz",
      "integrity": "sha512-+5nt5AAniqsCnu2cEQQdpzCAh33kVx8n0VoFidKpB1dVVLAN/F+bgVOqOJqOnEnrhp222clB5p3vUlD+1QAnfg=="
    },
    "node_modules/ws": {
      "version": "8.14.2",
      "resolved": "https://registry.npmjs.org/ws/-/ws-8.14.2.tgz",
      "integrity": "sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==",
      "engines": {
        "node": ">=10.0.0"
      },
      "peerDependencies": {
        "bufferutil": "^4.0.1",
        "utf-8-validate": ">=5.0.2"
      },
      "peerDependenciesMeta": {
        "bufferutil": {
          "optional": true
        },
        "utf-8-validate": {
          "optional": true
        }
      }
    }
  }
}
src/frontend/package.json
ADDED
@@ -0,0 +1,28 @@
{
  "name": "gradio_image_prompter",
  "version": "0.4.2",
  "description": "Gradio UI packages",
  "type": "module",
  "author": "",
  "license": "ISC",
  "private": false,
  "dependencies": {
    "@gradio/atoms": "0.3.1",
    "@gradio/client": "0.8.2",
    "@gradio/icons": "0.3.1",
    "@gradio/statustracker": "0.4.1",
    "@gradio/upload": "0.5.2",
    "@gradio/utils": "0.2.0",
    "@gradio/wasm": "0.3.0",
    "cropperjs": "^1.5.12",
    "lazy-brush": "^1.0.1",
    "resize-observer-polyfill": "^1.5.1"
  },
  "main_changeset": true,
  "main": "./Index.svelte",
  "exports": {
    ".": "./Index.svelte",
    "./example": "./Example.svelte",
    "./package.json": "./package.json"
  }
}
src/frontend/shared/BoxDrawer.svelte
ADDED
@@ -0,0 +1,237 @@
<svelte:options accessors={true} />

<script lang="ts">
  import { createEventDispatcher, onDestroy, onMount, tick } from "svelte";

  const dispatch = createEventDispatcher();

  export let width = 0;
  export let height = 0;
  export let natural_width = 0;
  export let natural_height = 0;

  let boxes: Array<Array<number>> = [];
  let points: Array<Array<number>> = [];

  let canvas_container: HTMLElement;
  let canvas: HTMLCanvasElement;
  let ctx: CanvasRenderingContext2D | null;

  let mouse_pressing: boolean = false;
  let mouse_button: number;
  let prev_x: number, prev_y: number;
  let cur_x: number, cur_y: number;

  let old_width = 0;
  let old_height = 0;
  let canvasObserver: ResizeObserver;

  async function set_canvas_size(dimensions: {
    width: number;
    height: number;
  }) {
    await tick();
    canvas.width = dimensions.width;
    canvas.height = dimensions.height;
    canvas.style.width = `${dimensions.width}px`;
    canvas.style.height = `${dimensions.height}px`;
    canvas.style.marginTop = `-${dimensions.height}px`;
  }

  export async function resize_canvas() {
    if (width === old_width && height === old_height) return;
    await set_canvas_size({ width: width, height: height });
    draw_canvas();
    setTimeout(() => {
      old_height = height;
      old_width = width;
    }, 100);
    clear();
  }

  export function clear() {
    boxes = [];
    points = [];
    draw_canvas();
    dispatch("change", points);
    return true;
  }

  export function undo() {
    boxes.pop();
    points.pop();
    draw_canvas();
    dispatch("change", points);
    return true;
  }

  onMount(async () => {
    ctx = canvas.getContext("2d");
    if (ctx) {
      (ctx.lineJoin = "round"), (ctx.lineCap = "round");
      ctx.strokeStyle = "#000";
    }
    canvasObserver = new ResizeObserver(() => {
      resize_canvas();
    });
    canvasObserver.observe(canvas_container);
    draw_loop();
    clear();
  });

  onDestroy(() => {
    canvasObserver.unobserve(canvas_container);
  });

  function get_mouse_pos(e: MouseEvent | TouchEvent | FocusEvent) {
    const rect = canvas.getBoundingClientRect();
    let screenX, screenY: number;
    if (e instanceof MouseEvent) {
      screenX = e.clientX;
      screenY = e.clientY;
    } else if (e instanceof TouchEvent) {
      screenX = e.changedTouches[0].clientX;
      screenY = e.changedTouches[0].clientY;
    } else {
      return { x: prev_x, y: prev_y };
    }
    return { x: screenX - rect.left, y: screenY - rect.top };
  }

  function handle_draw_start(e: MouseEvent | TouchEvent) {
    e.preventDefault();
    (mouse_pressing = true), (mouse_button = 0);
    if (e instanceof MouseEvent) mouse_button = e.button;
    const { x, y } = get_mouse_pos(e);
    (prev_x = x), (prev_y = y);
  }

  function handle_draw_move(e: MouseEvent | TouchEvent) {
    e.preventDefault();
    const { x, y } = get_mouse_pos(e);
    (cur_x = x), (cur_y = y);
  }

  function handle_draw_end(e: MouseEvent | TouchEvent | FocusEvent) {
    e.preventDefault();
    if (mouse_pressing) {
      const { x, y } = get_mouse_pos(e);
      let x1 = Math.min(prev_x, x);
      let y1 = Math.min(prev_y, y);
      let x2 = Math.max(prev_x, x);
      let y2 = Math.max(prev_y, y);
      boxes.push([x1, y1, x2, y2]);
      let scale_x = natural_width / width;
      let scale_y = natural_height / height;
      let is_point = x1 == x2 && y1 == y2;
      points.push([
        Math.round(x1 * scale_x),
        Math.round(y1 * scale_y),
        is_point ? (mouse_button == 0 ? 1 : 0) : 2, // label1
        is_point ? 0 : Math.round(x2 * scale_x),
        is_point ? 0 : Math.round(y2 * scale_y),
        is_point ? 4 : 3, // label2
      ]);
      dispatch("change", points);
    }
    mouse_pressing = false;
  }

  function draw_loop() {
    draw_canvas();
    window.requestAnimationFrame(() => {
      draw_loop();
    });
  }

  function draw_canvas() {
    if (!ctx) return;
    ctx.clearRect(0, 0, width, height);
    if (mouse_pressing && cur_x != prev_x && prev_y != cur_y) {
      let boxes_temp = boxes.slice();
      boxes_temp.push([prev_x, prev_y, cur_x, cur_y]);
      draw_boxes(boxes_temp);
      draw_points(boxes);
    } else {
      draw_boxes(boxes);
      draw_points(boxes);
    }
  }

  function draw_boxes(boxes: Array<Array<number>>) {
    if (!ctx) return;
    ctx.fillStyle = "rgba(0, 0, 0, 0.1)";
    ctx.beginPath();
    boxes.forEach((box: Array<number>) => {
      if (box[0] != box[2] && box[1] != box[3]) {
        ctx.rect(box[0], box[1], box[2] - box[0], box[3] - box[1]);
      }
    });
    ctx.fill();
    ctx.stroke();
  }

  function draw_points(boxes: Array<Array<number>>) {
    if (!ctx) return;
    // Draw foreground points.
    ctx.beginPath();
    ctx.fillStyle = "rgba(0, 255, 255, 1.0)"; // Cyan.
    boxes.forEach((box: Array<number>, index: number) => {
      if (points[index][2] == 1) {
        let radius = Math.sqrt(width * height) * 0.01;
        ctx.moveTo(box[0] + radius, box[1]);
        ctx.arc(box[0], box[1], radius, 0, 2 * Math.PI, false);
      }
    });
    ctx.fill();
    ctx.stroke();
    // Draw background points.
    ctx.beginPath();
    ctx.fillStyle = "rgba(255, 192, 203, 1.0)"; // Pink.
    boxes.forEach((box: Array<number>, index: number) => {
      if (points[index][2] == 0) {
        let radius = Math.sqrt(width * height) * 0.01;
        ctx.moveTo(box[0] + radius, box[1]);
        ctx.arc(box[0], box[1], radius, 0, 2 * Math.PI, false);
      }
    });
    ctx.fill();
    ctx.stroke();
  }
</script>

<div class="wrap" bind:this={canvas_container}>
  <canvas
    bind:this={canvas}
    on:mousedown={handle_draw_start}
    on:mousemove={handle_draw_move}
    on:mouseout={handle_draw_move}
    on:mouseup={handle_draw_end}
    on:touchstart={handle_draw_start}
    on:touchmove={handle_draw_move}
    on:touchend={handle_draw_end}
    on:touchcancel={handle_draw_end}
    on:blur={handle_draw_end}
    on:click|stopPropagation
    style=" z-index: 15"
  />
</div>

<style>
  canvas {
    display: block;
    position: absolute;
    top: 0;
    right: 0;
    bottom: 0;
    left: 0;
    margin: auto;
  }

  .wrap {
    position: relative;
    width: var(--size-full);
    height: var(--size-full);
    touch-action: none;
  }
</style>
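A note on the prompt rows produced by handle_draw_end above: each interaction becomes six numbers in natural-image pixel coordinates. A plain click is stored as (x, y, label, 0, 0, 4), where label is 1 for a left-button (foreground) click and 0 for any other button (background), while a dragged box is stored as (x1, y1, 2, x2, y2, 3). The Python sketch below is not part of the uploaded files (split_prompts is a hypothetical helper); it just shows one way a consumer such as a SAM-style predictor could split these rows back into point and box prompts.

# Minimal sketch (illustrative only) for splitting BoxDrawer's 6-number rows.
import numpy as np

def split_prompts(rows):
    rows = np.asarray(rows, dtype=np.float32).reshape(-1, 6)
    is_box = rows[:, 2] == 2                 # label1 == 2 marks a box row
    boxes = rows[is_box][:, [0, 1, 3, 4]]    # (x1, y1, x2, y2)
    clicks = rows[~is_box]
    point_coords = clicks[:, :2]             # (x, y) in natural-image pixels
    point_labels = clicks[:, 2]              # 1 = foreground, 0 = background
    return point_coords, point_labels, boxes

coords, labels, boxes = split_prompts([[120, 80, 1, 0, 0, 4],
                                       [40, 30, 2, 200, 160, 3]])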
src/frontend/shared/ClearImage.svelte
ADDED
@@ -0,0 +1,48 @@
<script lang="ts">
  import { createEventDispatcher } from "svelte";
  import { IconButton } from "@gradio/atoms";
  import { Undo, Erase, Clear } from "@gradio/icons";

  const dispatch = createEventDispatcher();
</script>

<div>
  <IconButton
    Icon={Undo}
    label="Remove Last Box"
    on:click={(event) => {
      dispatch("remove_box");
      event.stopPropagation();
    }}
  />

  <IconButton
    Icon={Erase}
    label="Remove All Boxes"
    on:click={(event) => {
      dispatch("remove_boxes");
      event.stopPropagation();
    }}
  />

  <IconButton
    Icon={Clear}
    label="Remove Image"
    on:click={(event) => {
      dispatch("remove_image");
      event.stopPropagation();
    }}
  />
</div>

<style>
  div {
    display: flex;
    position: absolute;
    top: var(--size-2);
    right: var(--size-2);
    justify-content: flex-end;
    gap: var(--spacing-sm);
    z-index: var(--layer-5);
  }
</style>
src/frontend/shared/Image.svelte
ADDED
@@ -0,0 +1,15 @@
<script lang="ts">
  import type { HTMLImgAttributes } from "svelte/elements";
  type $$Props = HTMLImgAttributes;

  import { resolve_wasm_src } from "@gradio/wasm/svelte";

  export let src: HTMLImgAttributes["src"] = undefined;
</script>

{#await resolve_wasm_src(src) then resolved_src}
  <!-- svelte-ignore a11y-missing-attribute -->
  <img src={resolved_src} {...$$restProps} />
{:catch error}
  <p style="color: red;">{error.message}</p>
{/await}
src/frontend/shared/ImagePreview.svelte
ADDED
@@ -0,0 +1,88 @@
<script lang="ts">
  import { createEventDispatcher } from "svelte";
  import type { SelectData } from "@gradio/utils";
  import { uploadToHuggingFace } from "@gradio/utils";
  import { BlockLabel, Empty, IconButton, ShareButton } from "@gradio/atoms";
  import { Download } from "@gradio/icons";
  import { get_coordinates_of_clicked_image } from "./utils";

  import { Image } from "@gradio/icons";
  import { type FileData } from "@gradio/client";
  import type { I18nFormatter } from "@gradio/utils";

  export let value: null | FileData;
  export let label: string | undefined = undefined;
  export let show_label: boolean;
  export let show_download_button = true;
  export let selectable = false;
  export let show_share_button = false;
  export let i18n: I18nFormatter;

  const dispatch = createEventDispatcher<{
    change: string;
    select: SelectData;
  }>();

  const handle_click = (evt: MouseEvent): void => {
    let coordinates = get_coordinates_of_clicked_image(evt);
    if (coordinates) {
      dispatch("select", { index: coordinates, value: null });
    }
  };
</script>

<BlockLabel {show_label} Icon={Image} label={label || i18n("image.image")} />
{#if value === null || !value.url}
  <Empty unpadded_box={true} size="large"><Image /></Empty>
{:else}
  <div class="icon-buttons">
    {#if show_download_button}
      <a
        href={value.url}
        target={window.__is_colab__ ? "_blank" : null}
        download={value.orig_name || "image"}
      >
        <IconButton Icon={Download} label={i18n("common.download")} />
      </a>
    {/if}
    {#if show_share_button}
      <ShareButton
        {i18n}
        on:share
        on:error
        formatter={async (value) => {
          if (!value) return "";
          let url = await uploadToHuggingFace(value, "base64");
          return `<img src="${url}" />`;
        }}
        {value}
      />
    {/if}
  </div>
  <button on:click={handle_click}>
    <img src={value.url} alt="" class:selectable loading="lazy" />
  </button>
{/if}

<style>
  img,
  button {
    width: var(--size-full);
    height: var(--size-full);
    object-fit: contain;
    display: block;
    border-radius: var(--radius-lg);
  }

  .selectable {
    cursor: crosshair;
  }

  .icon-buttons {
    display: flex;
    position: absolute;
    top: 6px;
    right: 6px;
    gap: var(--size-1);
  }
</style>
src/frontend/shared/ImageUploader.svelte
ADDED
@@ -0,0 +1,192 @@
<script lang="ts">
  import { createEventDispatcher } from "svelte";
  import { BlockLabel } from "@gradio/atoms";
  import { Image } from "@gradio/icons";
  import type { I18nFormatter } from "@gradio/utils";
  import { get_coordinates_of_clicked_image } from "./utils";
  import { ImagePaste, Upload as UploadIcon } from "@gradio/icons";
  import { Toolbar, IconButton } from "@gradio/atoms";

  import { Upload } from "@gradio/upload";
  import { type FileData, normalise_file } from "@gradio/client";
  import ClearImage from "./ClearImage.svelte";
  import BoxDrawer from "./BoxDrawer.svelte";

  const dispatch = createEventDispatcher();
  let box_drawer: BoxDrawer;

  export let value: null | FileData;
  export let points: null | number[][6];
  export let label: string | undefined = undefined;
  export let show_label: boolean;

  function handle_image_load(event: Event) {
    const element = event.currentTarget as HTMLImageElement;
    box_drawer.width = element.width;
    box_drawer.height = element.height;
    box_drawer.natural_width = element.naturalWidth;
    box_drawer.natural_height = element.naturalHeight;
    box_drawer.resize_canvas();
  }

  function handle_points_change({ detail }: { detail: number[][6] }) {
    points = detail;
    dispatch("points_change", detail);
  }

  export let sources: ("clipboard" | "upload")[] = ["upload", "clipboard"];
  export let streaming = false;
  export let root: string;
  export let i18n: I18nFormatter;

  let upload: Upload;
  let uploading = false;
  export let active_tool: "webcam" | null = null;

  function handle_upload({ detail }: CustomEvent<FileData>): void {
    value = normalise_file(detail, root, null);
    dispatch("upload", detail);
  }

  $: if (uploading) value = null;
  $: value && !value.url && (value = normalise_file(value, root, null));

  let dragging = false;
  $: dispatch("drag", dragging);

  function handle_click(evt: MouseEvent): void {
    let coordinates = get_coordinates_of_clicked_image(evt);
    if (coordinates) {
      dispatch("select", { index: coordinates, value: null });
    }
  }

  const sources_meta = {
    upload: {
      icon: UploadIcon,
      label: i18n("Upload"),
      order: 0,
    },
    clipboard: {
      icon: ImagePaste,
      label: i18n("Paste"),
      order: 2,
    },
  };

  $: sources_list = sources.sort(
    (a, b) => sources_meta[a].order - sources_meta[b].order,
  );

  async function handle_toolbar(
    source: (typeof sources)[number],
  ): Promise<void> {
    switch (source) {
      case "clipboard":
        navigator.clipboard.read().then(async (items) => {
          for (let i = 0; i < items.length; i++) {
            const type = items[i].types.find((t) => t.startsWith("image/"));
            if (type) {
              value = null;
              items[i].getType(type).then(async (blob) => {
                const f = await upload.load_files([
                  new File([blob], `clipboard.${type.replace("image/", "")}`),
                ]);
                value = f?.[0] || null;
              });
              break;
            }
          }
        });
        break;
      case "upload":
        upload.open_file_upload();
        break;
      default:
        break;
    }
  }
</script>

<BlockLabel {show_label} Icon={Image} label={label || "Image"} />

<div data-testid="image" class="image-container">
  {#if value?.url}
    <ClearImage
      on:remove_box={() => {
        box_drawer.undo();
      }}
      on:remove_boxes={() => {
        box_drawer.clear();
      }}
      on:remove_image={() => {
        value = null;
        dispatch("clear");
      }}
    />
  {/if}
  <div class="upload-container">
    <Upload
      hidden={value !== null || active_tool === "webcam"}
      bind:this={upload}
      bind:uploading
      bind:dragging
      filetype="image/*"
      on:load={handle_upload}
      on:error
      {root}
      disable_click={!sources.includes("upload")}
    >
      {#if value === null && !active_tool}
        <slot />
      {/if}
    </Upload>
    {#if value !== null && !streaming}
      <!-- svelte-ignore a11y-click-events-have-key-events-->
      <!-- svelte-ignore a11y-no-noninteractive-element-interactions-->
      <img
        src={value.url}
        alt={value.alt_text}
        on:click={handle_click}
        on:load={handle_image_load}
      />
      <BoxDrawer bind:this={box_drawer} on:change={handle_points_change} />
    {/if}
  </div>
  {#if sources.length > 1 || sources.includes("clipboard")}
    <Toolbar show_border={!value?.url}>
      {#each sources_list as source}
        <IconButton
          on:click={() => handle_toolbar(source)}
          Icon={sources_meta[source].icon}
          size="large"
          label="{source}-image-toolbar-btn"
          padded={false}
        />
      {/each}
    </Toolbar>
  {/if}
</div>

<style>
  img {
    width: var(--size-full);
    height: var(--size-full);
  }

  .upload-container {
    height: 100%;
    flex-shrink: 1;
    max-height: 100%;
  }

  .image-container {
    display: flex;
    height: 100%;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    max-height: 100%;
  }
</style>
src/frontend/shared/utils.ts
ADDED
@@ -0,0 +1,24 @@
export const get_coordinates_of_clicked_image = (
  evt: MouseEvent
): [number, number] | null => {
  let image = evt.currentTarget as HTMLImageElement;

  const imageRect = image.getBoundingClientRect();
  const xScale = image.naturalWidth / imageRect.width;
  const yScale = image.naturalHeight / imageRect.height;
  if (xScale > yScale) {
    // Image is letterboxed vertically: it fills the element's width,
    // leaving empty bands above and below the drawn pixels.
    const displayed_height = image.naturalHeight / xScale;
    const y_offset = (imageRect.height - displayed_height) / 2;
    var x = Math.round((evt.clientX - imageRect.left) * xScale);
    var y = Math.round((evt.clientY - imageRect.top - y_offset) * xScale);
  } else {
    // Image is letterboxed horizontally: empty bands to the left and right.
    const displayed_width = image.naturalWidth / yScale;
    const x_offset = (imageRect.width - displayed_width) / 2;
    var x = Math.round((evt.clientX - imageRect.left - x_offset) * yScale);
    var y = Math.round((evt.clientY - imageRect.top) * yScale);
  }
  // Clicks on the letterbox bands fall outside the natural image and return null.
  if (x < 0 || x >= image.naturalWidth || y < 0 || y >= image.naturalHeight) {
    return null;
  }
  return [x, y];
};
src/pyproject.toml
ADDED
@@ -0,0 +1,43 @@
[build-system]
requires = [
  "hatchling",
  "hatch-requirements-txt",
  "hatch-fancy-pypi-readme>=22.5.0",
]
build-backend = "hatchling.build"

[project]
name = "gradio_image_prompter"
version = "0.1.0"
description = "A gradio component to upload images and process point/box prompts."
readme = "README.md"
license = "apache-2.0"
requires-python = ">=3.8"
url = "https://github.com/PhyscalX/gradio-image-prompter"
authors = [{ name = "PhyscalX", email = "neopenx@gmail.com" }]
keywords = ["gradio-custom-component", "gradio-template-Image"]
# Add dependencies here
dependencies = ["gradio>=4.0,<5.0"]
classifiers = [
  'Development Status :: 3 - Alpha',
  'License :: OSI Approved :: Apache Software License',
  'Operating System :: OS Independent',
  'Programming Language :: Python :: 3',
  'Programming Language :: Python :: 3 :: Only',
  'Programming Language :: Python :: 3.8',
  'Programming Language :: Python :: 3.9',
  'Programming Language :: Python :: 3.10',
  'Programming Language :: Python :: 3.11',
  'Topic :: Scientific/Engineering',
  'Topic :: Scientific/Engineering :: Artificial Intelligence',
  'Topic :: Scientific/Engineering :: Visualization',
]

[project.optional-dependencies]
dev = ["build", "twine"]

[tool.hatch.build]
artifacts = ["/backend/gradio_image_prompter/templates", "*.pyi", "backend/gradio_image_prompter/templates"]

[tool.hatch.build.targets.wheel]
packages = ["/backend/gradio_image_prompter"]
structures/__init__.py
ADDED
File without changes
structures/bounding_box.py
ADDED
@@ -0,0 +1,323 @@
import torch

# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class BoxList(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as a Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    They can contain extra information that is specific to each bounding box,
    such as labels.
    """

    def __init__(self, bbox, image_size, mode="xyxy"):
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        # only do as_tensor if it isn't a "no-op", because it hurts JIT tracing
        if (not isinstance(bbox, torch.Tensor)
                or bbox.dtype != torch.float32 or bbox.device != device):
            bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        if bbox.ndimension() == 1 and bbox.size(-1) == 4:
            bbox = bbox.unsqueeze(0)
        if bbox.ndimension() != 2:
            raise ValueError(
                "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
            )
        if bbox.size(-1) != 4:
            raise ValueError(
                "last dimension of bbox should have a "
                "size of 4, got {}".format(bbox.size(-1))
            )
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}

    # note: _jit_wrap/_jit_unwrap only work if the keys and the sizes don't change in between
    def _jit_unwrap(self):
        return (self.bbox,) + tuple(f for f in (self.get_field(field)
                                                for field in sorted(self.fields()))
                                    if isinstance(f, torch.Tensor))

    def _jit_wrap(self, input_stream):
        self.bbox = input_stream[0]
        num_consumed = 1
        for f in sorted(self.fields()):
            if isinstance(self.extra_fields[f], torch.Tensor):
                self.extra_fields[f] = input_stream[num_consumed]
                num_consumed += 1
        return self, input_stream[num_consumed:]

    def add_field(self, field, field_data):
        self.extra_fields[field] = field_data

    def get_field(self, field):
        return self.extra_fields[field]

    def has_field(self, field):
        return field in self.extra_fields

    def fields(self):
        return list(self.extra_fields.keys())

    def _copy_extra_fields(self, bbox):
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def convert(self, mode):
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
            bbox = BoxList(bbox, self.size, mode=mode)
        else:
            TO_REMOVE = 1
            # NOTE: explicitly specify dim to avoid tracing error in GPU
            bbox = torch.cat(
                (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=1
            )
            bbox = BoxList(bbox, self.size, mode=mode)
        bbox._copy_extra_fields(self)
        return bbox

    def _split_into_xyxy(self):
        if self.mode == "xyxy":
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            TO_REMOVE = 1
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                xmin + (w - TO_REMOVE).clamp(min=0),
                ymin + (h - TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).
        """

        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            scaled_box = self.bbox * ratio
            bbox = BoxList(scaled_box, size, mode=self.mode)
            # bbox._copy_extra_fields(self)
            for k, v in self.extra_fields.items():
                if not isinstance(v, torch.Tensor):
                    v = v.resize(size, *args, **kwargs)
                bbox.add_field(k, v)
            return bbox

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        scaled_xmin = xmin * ratio_width
        scaled_xmax = xmax * ratio_width
        scaled_ymin = ymin * ratio_height
        scaled_ymax = ymax * ratio_height
        scaled_box = torch.cat(
            (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
        )
        bbox = BoxList(scaled_box, size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.resize(size, *args, **kwargs)
            bbox.add_field(k, v)

        return bbox.convert(self.mode)

    def transpose(self, method):
        """
        Transpose bounding box (flip or rotate in 90 degree steps)
        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
          :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
          :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
          :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == FLIP_LEFT_RIGHT:
            TO_REMOVE = 1
            transposed_xmin = image_width - xmax - TO_REMOVE
            transposed_xmax = image_width - xmin - TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        elif method == FLIP_TOP_BOTTOM:
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
        )
        bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.transpose(method)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate.
        """
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
        )
        bbox = BoxList(cropped_box, (w, h), mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.crop(box)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)

    # Tensor-like methods

    def to(self, device):
        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        bbox = BoxList(self.bbox[item], self.size, self.mode)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
        TO_REMOVE = 1
        x1s = self.bbox[:, 0].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y1s = self.bbox[:, 1].clamp(min=0, max=self.size[1] - TO_REMOVE)
        x2s = self.bbox[:, 2].clamp(min=0, max=self.size[0] - TO_REMOVE)
        y2s = self.bbox[:, 3].clamp(min=0, max=self.size[1] - TO_REMOVE)
        self.bbox = torch.stack((x1s, y1s, x2s, y2s), dim=-1)
        if remove_empty:
            box = self.bbox
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        return self

    def area(self):
        if self.mode == 'xyxy':
            TO_REMOVE = 1
            box = self.bbox
            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
        elif self.mode == 'xywh':
            box = self.bbox
            area = box[:, 2] * box[:, 3]
        else:
            raise RuntimeError("Should not be here")

        return area

    def copy_with_fields(self, fields):
        bbox = BoxList(self.bbox, self.size, self.mode)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            bbox.add_field(field, self.get_field(field))
        return bbox

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_boxes={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s

    @staticmethod
    def concate_box_list(list_of_boxes):
        boxes = torch.cat([i.bbox for i in list_of_boxes], dim=0)
        extra_fields_keys = list(list_of_boxes[0].extra_fields.keys())
        extra_fields = {}
        for key in extra_fields_keys:
            extra_fields[key] = torch.cat([i.extra_fields[key] for i in list_of_boxes], dim=0)

        final = list_of_boxes[0].copy_with_fields(extra_fields_keys)

        final.bbox = boxes
        final.extra_fields = extra_fields
        return final


@torch.jit.unused
def _onnx_clip_boxes_to_image(boxes, size):
    # type: (Tensor, Tuple[int, int])
    """
    Clip boxes so that they lie inside an image of size `size`.
    Clip's min/max are traced as constants. Use torch.min/max to WAR this issue.
    Arguments:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format
        size (Tuple[height, width]): size of the image
    Returns:
        clipped_boxes (Tensor[N, 4])
    """
    TO_REMOVE = 1
    device = boxes.device
    dim = boxes.dim()
    boxes_x = boxes[..., 0::2]
    boxes_y = boxes[..., 1::2]

    boxes_x = torch.max(boxes_x, torch.tensor(0., dtype=torch.float).to(device))
    boxes_x = torch.min(boxes_x, torch.tensor(size[1] - TO_REMOVE, dtype=torch.float).to(device))
    boxes_y = torch.max(boxes_y, torch.tensor(0., dtype=torch.float).to(device))
    boxes_y = torch.min(boxes_y, torch.tensor(size[0] - TO_REMOVE, dtype=torch.float).to(device))

    clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
    return clipped_boxes.reshape(boxes.shape)


if __name__ == "__main__":
    bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
    s_bbox = bbox.resize((5, 5))
    print(s_bbox)
    print(s_bbox.bbox)

    t_bbox = bbox.transpose(0)
    print(t_bbox)
    print(t_bbox.bbox)
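As a quick illustration of the xyxy/xywh convention above (a hedged sketch, not part of the file): widths and heights carry the inclusive-pixel TO_REMOVE = 1 correction inherited from maskrcnn-benchmark, so a box spanning pixel columns 0 through 10 has width 11 in xywh mode.

# Illustrative only: mode conversion and area under the TO_REMOVE convention.
import torch

box = BoxList(torch.tensor([[0., 0., 10., 10.]]), image_size=(20, 20), mode="xyxy")
print(box.convert("xywh").bbox)  # tensor([[ 0.,  0., 11., 11.]]) -- inclusive pixel count
print(box.area())                # tensor([121.]) = (10 - 0 + 1) ** 2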
structures/grasp_box.py
ADDED
@@ -0,0 +1,127 @@
import numpy as np


class GraspCoder:
    """
    This class encodes grasp annotations, similar to the BoxCoder class.
    It is supposed to support the following functions:
    1. Encode grasp annotations:
        (x1, y1, x2, y2, x3, y3, x4, y4) -> (x_center, y_center, width, height, sine(theta))
    2. Decode grasp annotations:
        (x_center, y_center, width, height, sine(theta)) -> (x1, y1, x2, y2, x3, y3, x4, y4)
    3. Resize grasp annotations when resizing the image
    4. Transform grasp annotations according to various image augmentations
    One GraspCoder instance should encode the annotations of one image only.
    """

    def __init__(self, height, width, grasp_annos, grasp_annos_reformat=None):
        """
        Args:
            height: height of image
            width: width of image
            grasp_annos: list of numpy.arrays, each of length 8, in format of (x1, y1, x2, y2, x3, y3, x4, y4)
        """
        self.height = height
        self.width = width
        self.grasp_annos = grasp_annos
        self.grasp_annos_reformat = grasp_annos_reformat

    def __len__(self):
        return len(self.grasp_annos)

    def encode(self, normalize=True):
        """
        (x1, y1, x2, y2, x3, y3, x4, y4) -> (x_center, y_center, width, height, sine(theta))
        Args:
            normalize -> bool: return values normalized to 0~1 or not
        Returns:
            grasp_annos_reformat: List of numpy.array
        """
        grasp_annos_reformat = []
        for grasp in self.grasp_annos:
            x1, y1, x2, y2, x3, y3, x4, y4 = tuple(grasp)
            if (x1 + x2) < (x3 + x4):
                x1, y1, x2, y2, x3, y3, x4, y4 = x3, y3, x4, y4, x1, y1, x2, y2
            x_center = (x1 + x3) / 2
            y_center = (y1 + y3) / 2
            width = np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
            height = np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
            sine = ((y1 + y2) / 2 - y_center) / (height / 2)
            if normalize:
                x_center /= self.width
                y_center /= self.height
                width /= self.width
                height /= self.height
                sine = (sine + 1) / 2
            grasp_annos_reformat.append(np.array([x_center, y_center, width, height, sine]))
        self.grasp_annos_reformat = grasp_annos_reformat
        return grasp_annos_reformat

    def decode(self):
        """
        Decode normalized grasp_annos_reformat; this overwrites self.grasp_annos and returns the overwritten value.
        (x_center, y_center, width, height, sine(theta)) -> (x1, y1, x2, y2, x3, y3, x4, y4)
        Returns:
            grasp_annos: List of numpy.array
        """
        grasp_annos = []
        for grasp in self.grasp_annos_reformat:
            x_center, y_center, width, height, sine = tuple(grasp)
            x_center *= self.width
            y_center *= self.height
            width *= self.width
            height *= self.height
            sine = sine * 2 - 1
            cosine = np.sqrt(1 - sine ** 2)
            angle = np.arcsin(sine)
            x1 = x_center + cosine * height / 2 + sine * width / 2
            x2 = x_center + cosine * height / 2 - sine * width / 2
            y1 = y_center + sine * height / 2 - cosine * width / 2
            y2 = y_center + sine * height / 2 + cosine * width / 2
            x3 = x_center * 2 - x1
            x4 = x_center * 2 - x2
            y3 = y_center * 2 - y1
            y4 = y_center * 2 - y2
            grasp_annos.append(np.array([x1, y1, x2, y2, x3, y3, x4, y4]))
        self.grasp_annos = grasp_annos
        return grasp_annos

    def resize(self, new_size):
        """
        Resize the grasp annotations according to the resized image
        Args:
            new_size -> Tuple: (new_width, new_height)

        Returns:
            self
        """
        new_width, new_height = new_size
        grasp_annos = self.grasp_annos
        old_height, old_width = self.height, self.width
        resized_grasp_annos = []
        for grasp in grasp_annos:
            grasp[0::2] = grasp[0::2] / old_width * new_width
            grasp[1::2] = grasp[1::2] / old_height * new_height
            resized_grasp_annos.append(grasp)
        self.grasp_annos = resized_grasp_annos
        self.height, self.width = new_height, new_width

        return self

    def transpose(self, axis):
        """
        For horizontal/vertical flip
        Args:
            axis: 0 represents the X axis, 1 represents the Y axis

        Returns:
            self
        """
        grasp_annos = self.grasp_annos
        flipped_grasp_annos = []
        if axis == 0:
            for grasp in grasp_annos:
                grasp[0::2] = self.width - grasp[0::2]
                flipped_grasp_annos.append(grasp)
        elif axis == 1:
            for grasp in grasp_annos:
                grasp[1::2] = self.height - grasp[1::2]
                flipped_grasp_annos.append(grasp)
        self.grasp_annos = flipped_grasp_annos
        return self
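A short round-trip sketch for GraspCoder (illustrative values, not from the repo): encode() maps the four rectangle corners to the normalized 5-number (x_c, y_c, w, h, sine) form, and decode() recovers the corners. For the axis-aligned rectangle below, sine encodes to 0.5 (i.e. angle 0 after normalization) and decode() reproduces the input exactly.

# Illustrative only: encode/decode round trip on an axis-aligned grasp
# rectangle centred at (50, 40) in a 100x80 image.
import numpy as np

rect = np.array([60., 35., 60., 45., 40., 45., 40., 35.])  # (x1, y1, ..., x4, y4)
coder = GraspCoder(height=80, width=100, grasp_annos=[rect])
encoded = coder.encode(normalize=True)  # [[0.5, 0.5, 0.1, 0.25, 0.5]]
decoded = coder.decode()                # [[60., 35., 60., 45., 40., 45., 40., 35.]]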
structures/image_list.py
ADDED
@@ -0,0 +1,67 @@
import torch


class ImageList(object):
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.
    This works by padding the images to the same size,
    and storing in a field the original sizes of each image
    """

    def __init__(self, tensors, image_sizes):
        """
        Arguments:
            tensors (tensor)
            image_sizes (list[tuple[int, int]])
        """
        self.tensors = tensors
        self.image_sizes = image_sizes

    def to(self, *args, **kwargs):
        cast_tensor = self.tensors.to(*args, **kwargs)
        return ImageList(cast_tensor, self.image_sizes)


def to_image_list(tensors, size_divisible=0):
    """
    tensors can be an ImageList, a torch.Tensor or
    an iterable of Tensors. It can't be a numpy array.
    When tensors is an iterable of Tensors, it pads
    the Tensors with zeros so that they have the same shape.
    """
    if isinstance(tensors, torch.Tensor) and size_divisible > 0:
        tensors = [tensors]

    if isinstance(tensors, ImageList):
        return tensors
    elif isinstance(tensors, torch.Tensor):
        # single tensor shape can be inferred
        assert tensors.dim() == 4
        image_sizes = [tensor.shape[-2:] for tensor in tensors]
        return ImageList(tensors, image_sizes)
    elif isinstance(tensors, (tuple, list)):
        max_size = tuple(max(s) for s in zip(*[img.shape for img in tensors]))

        # TODO Ideally, just remove this and let the model handle arbitrary
        # input sizes
        if size_divisible > 0:
            import math

            stride = size_divisible
            max_size = list(max_size)
            max_size[1] = int(math.ceil(max_size[1] / stride) * stride)
            max_size[2] = int(math.ceil(max_size[2] / stride) * stride)
            max_size = tuple(max_size)

        batch_shape = (len(tensors),) + max_size
        batched_imgs = tensors[0].new(*batch_shape).zero_()
        for img, pad_img in zip(tensors, batched_imgs):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

        image_sizes = [im.shape[-2:] for im in tensors]

        return ImageList(batched_imgs, image_sizes)
    else:
        raise TypeError("Unsupported type for to_image_list: {}".format(type(tensors)))
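A minimal usage sketch (illustrative values): two CHW images of different sizes are zero-padded into one (N, C, H_max, W_max) batch, with size_divisible rounding H and W up to a multiple of the stride so downsampling layers divide evenly.

# Illustrative only: batch two differently-sized images with stride-32 padding.
import torch

imgs = [torch.rand(3, 37, 50), torch.rand(3, 44, 41)]
batch = to_image_list(imgs, size_divisible=32)
print(batch.tensors.shape)  # torch.Size([2, 3, 64, 64])
print(batch.image_sizes)    # original (H, W) per image: [(37, 50), (44, 41)]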
structures/segmentation_mask.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import pycocotools.mask as mask_utils
|
5 |
+
|
6 |
+
# transpose
|
7 |
+
FLIP_LEFT_RIGHT = 0
|
8 |
+
FLIP_TOP_BOTTOM = 1
|
9 |
+
|
10 |
+
|
11 |
+
class MaskList(object):
|
12 |
+
"""
|
13 |
+
This class is unfinished and not meant for use yet
|
14 |
+
It is supposed to contain the binary masks for all instances in a list of 2D tensors (H, W)
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, masks, size, mode):
|
18 |
+
assert(isinstance(masks, list))
|
19 |
+
assert(mode in ['mask', 'rle'])
|
20 |
+
self.masks = masks
|
21 |
+
self.size = size # (image_width, image_height)
|
22 |
+
self.mode = mode
|
23 |
+
|
24 |
+
def transpose(self, method):
|
25 |
+
assert (self.mode == "mask"), "RLE masks cannot be transposed. Please convert them to binary first."
|
26 |
+
if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
|
27 |
+
raise NotImplementedError(
|
28 |
+
"Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
|
29 |
+
)
|
30 |
+
|
31 |
+
# width, height = self.size
|
32 |
+
masks = np.array(self.masks)
|
33 |
+
if masks.ndim == 2:
|
34 |
+
masks = np.expand_dims(masks, axis=0)
|
35 |
+
if method == FLIP_LEFT_RIGHT:
|
36 |
+
masks = np.flip(masks, axis=2)
|
37 |
+
elif method == FLIP_TOP_BOTTOM:
|
38 |
+
masks = np.flip(masks, axis=1)
|
39 |
+
flipped_masks = np.split(masks, masks.shape[0])
|
40 |
+
flipped_masks = [mask.squeeze(0) for mask in flipped_masks]
|
41 |
+
return MaskList(flipped_masks, self.size, self.mode)
|
42 |
+
|
43 |
+
    def resize(self, size, *args, **kwargs):
        """
        Resize the binary masks.
        :param size: tuple, (image_width, image_height)
        :return: a resized MaskList in "mask" mode
        """
        assert (self.mode == "mask"), "RLE masks cannot be resized. Please convert them to binary first."
        # stack instances as channels, (N, H, W) -> (H, W, N), so cv2.resize
        # can resize all masks at once
        cat_mask = np.array(self.masks)
        cat_mask = cat_mask.transpose(1, 2, 0)
        cat_mask = (cat_mask * 255).astype(np.uint8)
        resized_mask = cv2.resize(cat_mask, size)
        # cv2.resize drops the channel axis when there is a single mask
        if resized_mask.ndim == 2:
            resized_mask = np.expand_dims(resized_mask, axis=2)
        resized_mask = resized_mask.transpose(2, 0, 1)
        # threshold back to binary {0, 1}
        resized_mask = resized_mask.astype(int) // 255
        mask_list = np.split(resized_mask, resized_mask.shape[0])
        mask_list = [mask.squeeze(0) for mask in mask_list]
        return MaskList(mask_list, size, "mask")

    def pad(self, size):
        """
        Pad the binary masks to a new size. The new size must be at least as
        large as the original size in both dimensions.
        :param size: new image size, (image_width, image_height)
        :return: a padded MaskList in "mask" mode
        """
        assert (size[0] >= self.size[0] and size[1] >= self.size[1]), "New size must be larger than original size in all dimensions"
        cat_mask = np.array(self.masks)
        if cat_mask.ndim == 2:
            cat_mask = np.expand_dims(cat_mask, axis=0)
        # place the original masks in the top-left corner of a zero canvas
        padded_mask = np.zeros([len(self.masks), size[1], size[0]])
        padded_mask[:, :cat_mask.shape[1], :cat_mask.shape[2]] = cat_mask
        mask_list = np.split(padded_mask, padded_mask.shape[0])
        mask_list = [mask.squeeze(0) for mask in mask_list]
        return MaskList(mask_list, size, "mask")

    def convert(self, mode):
        """
        Convert the masks between mode "mask" and mode "rle".
        :param mode: target mode, "mask" or "rle"
        :return: a MaskList in the target mode
        """
        if mode == self.mode:
            return self
        elif mode == "rle" and self.mode == "mask":
            # use pycocotools to encode binary masks to RLE; encode expects
            # a Fortran-ordered (H, W, N) uint8 array
            rle_mask_list = mask_utils.encode(
                np.asfortranarray(np.array(self.masks).transpose(1, 2, 0).astype(np.uint8))
            )
            return MaskList(rle_mask_list, self.size, "rle")
        elif mode == "mask" and self.mode == "rle":
            # use pycocotools to decode RLE back to binary masks
            bimasks = mask_utils.decode(self.masks)
            mask_list = np.split(bimasks.transpose(2, 0, 1), bimasks.shape[2])
            mask_list = [mask.squeeze(0) for mask in mask_list]
            return MaskList(mask_list, self.size, "mask")
        else:
            raise ValueError("Unsupported conversion from {} to {}".format(self.mode, mode))

    def bbox(self, bbox_mode="xyxy"):
        """
        Generate a bounding box per instance from its binary mask.
        :param bbox_mode: only "xyxy" is implemented
        :return: a list of [x1, y1, x2, y2] boxes, one per mask
        """
        # minimal completion of what was a stub; assumes every mask has at
        # least one nonzero pixel
        assert self.mode == "mask" and bbox_mode == "xyxy"
        boxes = []
        for mask in self.masks:
            ys, xs = np.nonzero(mask)
            boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])
        return boxes

    def __len__(self):
        return len(self.masks)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_masks={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s

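# Illustrative round trip between the two MaskList modes (not part of the
# original file; mask values and size are made up). Encoding to RLE goes
# through pycocotools, and decoding recovers the same binary masks:
#
#   masks = [np.eye(4, dtype=np.uint8), np.ones((4, 4), dtype=np.uint8)]
#   ml = MaskList(masks, size=(4, 4), mode="mask")
#   rle = ml.convert("rle")        # compressed COCO-style RLE dicts
#   back = rle.convert("mask")     # binary masks again
#   assert np.array_equal(np.array(back.masks), np.array(masks))
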
class Polygons(object):
    """
    This class holds a set of polygons that represents a single instance
    of an object mask. The object can be represented as a set of polygons.
    """

    def __init__(self, polygons, size, mode):
        if isinstance(polygons, list):
            polygons = [torch.as_tensor(p, dtype=torch.float32) for p in polygons]
        elif isinstance(polygons, Polygons):
            polygons = polygons.polygons

        self.polygons = polygons
        self.size = size
        self.mode = mode

    def transpose(self, method):
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        flipped_polygons = []
        width, height = self.size
        if method == FLIP_LEFT_RIGHT:
            dim = width
            idx = 0
        elif method == FLIP_TOP_BOTTOM:
            dim = height
            idx = 1

        for poly in self.polygons:
            p = poly.clone()
            TO_REMOVE = 1
            p[idx::2] = dim - poly[idx::2] - TO_REMOVE
            flipped_polygons.append(p)

        return Polygons(flipped_polygons, size=self.size, mode=self.mode)

    def crop(self, box):
        w, h = box[2] - box[0], box[3] - box[1]

        # TODO check if necessary
        w = max(w, 1)
        h = max(h, 1)

        cropped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            # shift coordinates into the crop's frame of reference
            p[0::2] = p[0::2] - box[0]  # .clamp(min=0, max=w)
            p[1::2] = p[1::2] - box[1]  # .clamp(min=0, max=h)
            cropped_polygons.append(p)

        return Polygons(cropped_polygons, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            scaled_polys = [p * ratio for p in self.polygons]
            return Polygons(scaled_polys, size, mode=self.mode)

        ratio_w, ratio_h = ratios
        scaled_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h
            scaled_polygons.append(p)

        return Polygons(scaled_polygons, size=size, mode=self.mode)

    def convert(self, mode):
        width, height = self.size
        if mode == "mask":
            # rasterize the polygons with pycocotools: frPyObjects builds the
            # RLEs, merge unions them, decode yields the binary mask
            rles = mask_utils.frPyObjects(
                [p.detach().numpy() for p in self.polygons], height, width
            )
            rle = mask_utils.merge(rles)
            mask = mask_utils.decode(rle)
            mask = torch.from_numpy(mask)
            # TODO add squeeze?
            return mask
        raise NotImplementedError("Only conversion to 'mask' is implemented")

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_polygons={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s

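# Illustrative only (not part of the original file): rasterizing one square
# polygon to an 8x8 binary mask via Polygons.convert. Coordinates follow the
# COCO polygon format [x1, y1, x2, y2, ...]; the values are made up:
#
#   square = Polygons([[1.0, 1.0, 6.0, 1.0, 6.0, 6.0, 1.0, 6.0]],
#                     size=(8, 8), mode=None)
#   mask = square.convert("mask")  # torch.uint8 tensor of shape (8, 8)
#   print(mask.sum())              # number of pixels inside the square
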
class SegmentationMask(object):
    """
    This class stores the segmentations for all objects in the image
    """

    def __init__(self, polygons, size, mode=None):
        """
        Arguments:
            polygons: a list of lists of lists of numbers. The first
                level of the list corresponds to individual instances,
                the second level to all the polygons that compose the
                object, and the third level to the polygon coordinates.
        """
        assert isinstance(polygons, list)

        self.polygons = [Polygons(p, size, mode) for p in polygons]
        self.size = size
        self.mode = mode

    def transpose(self, method):
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        flipped = []
        for polygon in self.polygons:
            flipped.append(polygon.transpose(method))
        return SegmentationMask(flipped, size=self.size, mode=self.mode)

    def crop(self, box):
        w, h = box[2] - box[0], box[3] - box[1]
        cropped = []
        for polygon in self.polygons:
            cropped.append(polygon.crop(box))
        return SegmentationMask(cropped, size=(w, h), mode=self.mode)

    def resize(self, size, *args, **kwargs):
        scaled = []
        for polygon in self.polygons:
            scaled.append(polygon.resize(size, *args, **kwargs))
        return SegmentationMask(scaled, size=size, mode=self.mode)

    def to(self, *args, **kwargs):
        return self

    def __getitem__(self, item):
        if isinstance(item, (int, slice)):
            selected_polygons = [self.polygons[item]]
        else:
            # advanced indexing on a single dimension
            selected_polygons = []
            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
                item = item.nonzero()
                item = item.squeeze(1) if item.numel() > 0 else item
                item = item.tolist()
            for i in item:
                selected_polygons.append(self.polygons[i])
        return SegmentationMask(selected_polygons, size=self.size, mode=self.mode)

    def __iter__(self):
        return iter(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
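
A minimal usage sketch for SegmentationMask (illustrative only; the coordinates, image sizes, and `structures.segmentation_mask` import path are assumptions): build it from COCO-style nested polygon lists, then flip, resize, and index it, with each operation delegating to the per-instance Polygons objects.

import torch
from structures.segmentation_mask import SegmentationMask, FLIP_LEFT_RIGHT

# two instances, one polygon each, on a 10x10 image
polys = [
    [[1.0, 1.0, 4.0, 1.0, 4.0, 4.0]],
    [[5.0, 5.0, 9.0, 5.0, 9.0, 9.0, 5.0, 9.0]],
]
segm = SegmentationMask(polys, size=(10, 10))
print(segm)                                # SegmentationMask(num_instances=2, ...)
flipped = segm.transpose(FLIP_LEFT_RIGHT)  # mirror every x coordinate
resized = segm.resize((20, 20))            # doubles every coordinate
first = segm[torch.tensor([True, False])]  # boolean-mask indexing keeps instance 0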