github-actions[bot] committed on
Commit 123489f (0 parents)

Sync to HuggingFace Spaces

.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.github/workflows/main.yml ADDED
@@ -0,0 +1,25 @@
1
+ on:
2
+ push:
3
+ branches:
4
+ - main
5
+ jobs:
6
+ huggingface-sync:
7
+ runs-on: ubuntu-latest
8
+ steps:
9
+ - name: Checkout Repository
10
+ uses: actions/checkout@v3
11
+
12
+ - name: Hugging Face Sync
13
+ uses: JacobLinCool/huggingface-sync@v1
14
+ with:
15
+ user: Y-T-G
16
+ space: Blur-Anything
17
+ emoji: 💻
18
+ token: ${{ secrets.HF_TOKEN }}
19
+ github: ${{ secrets.GITHUB_TOKEN }}
20
+ colorFrom: yellow
21
+ colorTo: pink
22
+ sdk: gradio
23
+ app_file: app.py
24
+ pinned: false
25
+ license: mit
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ checkpoints/*
2
+ output/*
3
+ notebook.ipynb
4
+ *.pyc
CHANGELOG.md ADDED
@@ -0,0 +1,12 @@
1
+ # Changelog
2
+
3
+ ## v0.2.0 - 2023-08-11
4
+
5
+ ### MobileSAM
6
+ - Added quantized ONNX MobileSAM model. Pass `--sam_model_type vit_t` to use it.
7
+
8
+ ## v0.1.0 - 2023-05-06
9
+
10
+ ### Blur-Anything Initial Release
11
+ - Added blur implementation
12
+ - Using pims instead of storing frames in memory for better memory usage
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Mohammed Yasin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,42 @@
1
+ ---
2
+ title: Blur Anything
3
+ emoji: 💻
4
+ colorFrom: yellow
5
+ colorTo: pink
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+
11
+ # Blur Anything For Videos
12
+
13
+ Blur Anything is an adaptation of the excellent [Track Anything](https://github.com/gaomingqi/Track-Anything) project, which is in turn based on Meta's Segment Anything and XMem. It lets you blur anything in a video, including faces, license plates, and more.
14
+
15
+ <div>
16
+ <a href="https://huggingface.co/spaces/Y-T-G/Blur-Anything">
17
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square">
18
+ </a>
19
+ </div>
20
+
21
+ ## Get Started
22
+ ```shell
23
+ # Clone the repository:
24
+ git clone https://github.com/Y-T-G/Blur-Anything.git
25
+ cd Blur-Anything
26
+
27
+ # Install dependencies:
28
+ pip install -r requirements.txt
29
+
30
+ # Run the Blur-Anything gradio demo.
31
+ python app.py --device cuda:0
32
+ # python app.py --device cuda:0 --sam_model_type vit_b # for lower memory usage
33
+ ```
34
+
35
+ ## To Do
36
+ - [x] Add a gradio demo
37
+ - [ ] Add support for using a YouTube video URL
38
+ - [ ] Add option to completely black out the object
39
+
40
+ ## Acknowledgements
41
+
42
+ The project is an adaptation of [Track Anything](https://github.com/gaomingqi/Track-Anything) which is based on [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem).
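A note on defaults, inferred from `track_anything.py` and `app.py` further down in this commit: running `app.py` with no flags uses the CPU and the quantized ONNX MobileSAM (`vit_t`) backend, which is how the Hugging Face Space runs it. A minimal sketch:

```shell
# CPU-only run with the default quantized MobileSAM (vit_t) backend;
# checkpoints are downloaded into checkpoints/ on first launch.
python app.py
```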
app.py ADDED
@@ -0,0 +1,880 @@
1
+ import os
2
+ import time
3
+ import requests
4
+ import sys
5
+ import json
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ import pims
12
+
13
+ from export_onnx_model import run_export
14
+ from onnxruntime.quantization import QuantType
15
+ from onnxruntime.quantization.quantize import quantize_dynamic
16
+
17
+ sys.path.append(sys.path[0] + "/tracker")
18
+ sys.path.append(sys.path[0] + "/tracker/model")
19
+
20
+ from track_anything import TrackingAnything
21
+ from track_anything import parse_augment
22
+
23
+ from utils.painter import mask_painter
24
+ from utils.blur import blur_frames_and_write
25
+
26
+
27
+ # download checkpoints
28
+ def download_checkpoint(url, folder, filename):
29
+ os.makedirs(folder, exist_ok=True)
30
+ filepath = os.path.join(folder, filename)
31
+
32
+ if not os.path.exists(filepath):
33
+ print("Downloading checkpoints...")
34
+ response = requests.get(url, stream=True)
35
+ with open(filepath, "wb") as f:
36
+ for chunk in response.iter_content(chunk_size=8192):
37
+ if chunk:
38
+ f.write(chunk)
39
+
40
+ print("Download successful.")
41
+
42
+ return filepath
43
+
44
+
45
+ # convert points input to prompt state
46
+ def get_prompt(click_state, click_input):
47
+ inputs = json.loads(click_input)
48
+ points = click_state[0]
49
+ labels = click_state[1]
50
+ for input in inputs:
51
+ points.append(input[:2])
52
+ labels.append(input[2])
53
+ click_state[0] = points
54
+ click_state[1] = labels
55
+ prompt = {
56
+ "prompt_type": ["click"],
57
+ "input_point": click_state[0],
58
+ "input_label": click_state[1],
59
+ "multimask_output": "False",
60
+ }
61
+ return prompt
62
+
63
+
64
+ # extract frames from upload video
65
+ def get_frames_from_video(video_input, video_state):
66
+ """
67
+ Args:
68
+ video_input: str, path to the uploaded video
69
+ video_state: dict, per-session state to populate
70
+ Returns:
71
+ video_state, a video info string, the first frame, and gradio component updates
72
+ """
73
+ video_path = video_input
74
+ frames = []
75
+ user_name = time.time()
76
+ operation_log = [
77
+ ("", ""),
78
+ (
79
+ "Video uploaded. Click the image to add targets to track and blur.",
80
+ "Normal",
81
+ ),
82
+ ]
83
+ try:
84
+ frames = pims.Video(video_path)
85
+ fps = frames.frame_rate
86
+ image_size = (frames.shape[1], frames.shape[2])
87
+
88
+ except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
89
+ print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
90
+
91
+ # initialize video_state
92
+ video_state = {
93
+ "user_name": user_name,
94
+ "video_name": os.path.split(video_path)[-1],
95
+ "origin_images": frames,
96
+ "painted_images": [0] * len(frames),
97
+ "masks": [0] * len(frames),
98
+ "logits": [None] * len(frames),
99
+ "select_frame_number": 0,
100
+ "fps": fps,
101
+ }
102
+ video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(
103
+ video_state["video_name"], video_state["fps"], len(frames), image_size
104
+ )
105
+ model.samcontroler.sam_controler.reset_image()
106
+ model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
107
+ return (
108
+ video_state,
109
+ video_info,
110
+ video_state["origin_images"][0],
111
+ gr.update(visible=True, maximum=len(frames), value=1),
112
+ gr.update(visible=True, maximum=len(frames), value=len(frames)),
113
+ gr.update(visible=True),
114
+ gr.update(visible=True),
115
+ gr.update(visible=True),
116
+ gr.update(visible=True),
117
+ gr.update(visible=True),
118
+ gr.update(visible=True),
119
+ gr.update(visible=True),
120
+ gr.update(visible=True),
121
+ gr.update(visible=True),
122
+ gr.update(visible=True),
123
+ gr.update(visible=True, value=operation_log),
124
+ )
125
+
126
+
127
+ def run_example(example):
128
+ return example
129
+
130
+
131
+ # get the select frame from gradio slider
132
+ def select_template(image_selection_slider, video_state, interactive_state):
133
+ # images = video_state[1]
134
+ image_selection_slider -= 1
135
+ video_state["select_frame_number"] = image_selection_slider
136
+
137
+ # once select a new template frame, set the image in sam
138
+
139
+ model.samcontroler.sam_controler.reset_image()
140
+ model.samcontroler.sam_controler.set_image(
141
+ video_state["origin_images"][image_selection_slider]
142
+ )
143
+
144
+ # update the masks when select a new template frame
145
+ operation_log = [
146
+ ("", ""),
147
+ (
148
+ "Selected frame {}. Click the image to add a mask for tracking.".format(
149
+ image_selection_slider
150
+ ),
151
+ "Normal",
152
+ ),
153
+ ]
154
+
155
+ return (
156
+ video_state["painted_images"][image_selection_slider],
157
+ video_state,
158
+ interactive_state,
159
+ operation_log,
160
+ )
161
+
162
+
163
+ # set the tracking end frame
164
+ def set_end_number(track_pause_number_slider, video_state, interactive_state):
165
+ interactive_state["track_end_number"] = track_pause_number_slider
166
+ operation_log = [
167
+ ("", ""),
168
+ (
169
+ "Set the tracking finish at frame {}".format(track_pause_number_slider),
170
+ "Normal",
171
+ ),
172
+ ]
173
+
174
+ return (
175
+ interactive_state,
176
+ operation_log,
177
+ )
178
+
179
+
180
+ def get_resize_ratio(resize_ratio_slider, interactive_state):
181
+ interactive_state["resize_ratio"] = resize_ratio_slider
182
+
183
+ return interactive_state
184
+
185
+
186
+ def get_blur_strength(blur_strength_slider, interactive_state):
187
+ interactive_state["blur_strength"] = blur_strength_slider
188
+
189
+ return interactive_state
190
+
191
+
192
+ # use sam to get the mask
193
+ def sam_refine(
194
+ video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData
195
+ ):
196
+ """
197
+ Args:
198
+ template_frame: PIL.Image
199
+ point_prompt: flag for positive or negative button click
200
+ click_state: [[points], [labels]]
201
+ """
202
+ if point_prompt == "Positive":
203
+ coordinate = "[[{},{},1]]".format(evt.index[0], evt.index[1])
204
+ interactive_state["positive_click_times"] += 1
205
+ else:
206
+ coordinate = "[[{},{},0]]".format(evt.index[0], evt.index[1])
207
+ interactive_state["negative_click_times"] += 1
208
+
209
+ # prompt for sam model
210
+ model.samcontroler.sam_controler.reset_image()
211
+ model.samcontroler.sam_controler.set_image(
212
+ video_state["origin_images"][video_state["select_frame_number"]]
213
+ )
214
+ prompt = get_prompt(click_state=click_state, click_input=coordinate)
215
+
216
+ mask, logit, painted_image = model.first_frame_click(
217
+ image=video_state["origin_images"][video_state["select_frame_number"]],
218
+ points=np.array(prompt["input_point"]),
219
+ labels=np.array(prompt["input_label"]),
220
+ multimask=prompt["multimask_output"],
221
+ )
222
+
223
+ video_state["masks"][video_state["select_frame_number"]] = mask
224
+ video_state["logits"][video_state["select_frame_number"]] = logit
225
+ video_state["painted_images"][video_state["select_frame_number"]] = painted_image
226
+
227
+ operation_log = [
228
+ ("", ""),
229
+ (
230
+ "Used SAM to segment. You can add more positive or negative points by clicking, press the Clear clicks button to refresh the image, or press the Add mask button when you are satisfied with the segment.",
231
+ "Normal",
232
+ ),
233
+ ]
234
+ return painted_image, video_state, interactive_state, operation_log
235
+
236
+
237
+ def add_multi_mask(video_state, interactive_state, mask_dropdown):
238
+ try:
239
+ mask = video_state["masks"][video_state["select_frame_number"]]
240
+ interactive_state["multi_mask"]["masks"].append(mask)
241
+ interactive_state["multi_mask"]["mask_names"].append(
242
+ "mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))
243
+ )
244
+ mask_dropdown.append(
245
+ "mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"]))
246
+ )
247
+ select_frame, run_status = show_mask(
248
+ video_state, interactive_state, mask_dropdown
249
+ )
250
+
251
+ operation_log = [
252
+ ("", ""),
253
+ (
254
+ "Added a mask. Use the mask selection for target tracking or blurring.",
255
+ "Normal",
256
+ ),
257
+ ]
258
+ except Exception:
259
+ operation_log = [
260
+ ("Please click the left image to generate mask.", "Error"),
261
+ ("", ""),
262
+ ]
263
+ return (
264
+ interactive_state,
265
+ gr.update(
266
+ choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown
267
+ ),
268
+ select_frame,
269
+ [[], []],
270
+ operation_log,
271
+ )
272
+
273
+
274
+ def clear_click(video_state, click_state):
275
+ click_state = [[], []]
276
+ template_frame = video_state["origin_images"][video_state["select_frame_number"]]
277
+ operation_log = [
278
+ ("", ""),
279
+ ("Clear points history and refresh the image.", "Normal"),
280
+ ]
281
+ return template_frame, click_state, operation_log
282
+
283
+
284
+ def remove_multi_mask(interactive_state, mask_dropdown):
285
+ interactive_state["multi_mask"]["mask_names"] = []
286
+ interactive_state["multi_mask"]["masks"] = []
287
+
288
+ operation_log = [("", ""), ("Removed all masks. Please add new masks.", "Normal")]
289
+ return interactive_state, gr.update(choices=[], value=[]), operation_log
290
+
291
+
292
+ def show_mask(video_state, interactive_state, mask_dropdown):
293
+ mask_dropdown.sort()
294
+ select_frame = video_state["origin_images"][video_state["select_frame_number"]]
295
+
296
+ for i in range(len(mask_dropdown)):
297
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
298
+ mask = interactive_state["multi_mask"]["masks"][mask_number]
299
+ select_frame = mask_painter(
300
+ select_frame, mask.astype("uint8"), mask_color=mask_number + 2
301
+ )
302
+
303
+ operation_log = [
304
+ ("", ""),
305
+ ("Select {} for tracking or blurring".format(mask_dropdown), "Normal"),
306
+ ]
307
+ return select_frame, operation_log
308
+
309
+
310
+ # tracking vos
311
+ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
312
+ operation_log = [
313
+ ("", ""),
314
+ (
315
+ "Track the selected masks, and then you can select the masks for blurring.",
316
+ "Normal",
317
+ ),
318
+ ]
319
+ model.xmem.clear_memory()
320
+ if interactive_state["track_end_number"]:
321
+ following_frames = video_state["origin_images"][
322
+ video_state["select_frame_number"]: interactive_state["track_end_number"]
323
+ ]
324
+ else:
325
+ following_frames = video_state["origin_images"][
326
+ video_state["select_frame_number"]:
327
+ ]
328
+
329
+ if interactive_state["multi_mask"]["masks"]:
330
+ if len(mask_dropdown) == 0:
331
+ mask_dropdown = ["mask_001"]
332
+ mask_dropdown.sort()
333
+ template_mask = interactive_state["multi_mask"]["masks"][
334
+ int(mask_dropdown[0].split("_")[1]) - 1
335
+ ] * (int(mask_dropdown[0].split("_")[1]))
336
+ for i in range(1, len(mask_dropdown)):
337
+ mask_number = int(mask_dropdown[i].split("_")[1]) - 1
338
+ template_mask = np.clip(
339
+ template_mask
340
+ + interactive_state["multi_mask"]["masks"][mask_number]
341
+ * (mask_number + 1),
342
+ 0,
343
+ mask_number + 1,
344
+ )
345
+ video_state["masks"][video_state["select_frame_number"]] = template_mask
346
+ else:
347
+ template_mask = video_state["masks"][video_state["select_frame_number"]]
348
+
349
+ # operation error
350
+ if len(np.unique(template_mask)) == 1:
351
+ template_mask[0][0] = 1
352
+ operation_log = [
353
+ (
354
+ "Error! Please add at least one mask to track by clicking the left image.",
355
+ "Error",
356
+ ),
357
+ ("", ""),
358
+ ]
359
+ # return video_output, video_state, interactive_state, operation_error
360
+ output_path = "./output/track/{}".format(video_state["video_name"])
361
+ fps = video_state["fps"]
362
+ masks, logits, painted_images = model.generator(
363
+ images=following_frames, template_mask=template_mask, write=True, fps=fps, output_path=output_path
364
+ )
365
+ # clear GPU memory
366
+ model.xmem.clear_memory()
367
+
368
+ if interactive_state["track_end_number"]:
369
+ video_state["masks"][
370
+ video_state["select_frame_number"]: interactive_state["track_end_number"]
371
+ ] = masks
372
+ video_state["logits"][
373
+ video_state["select_frame_number"]: interactive_state["track_end_number"]
374
+ ] = logits
375
+ video_state["painted_images"][
376
+ video_state["select_frame_number"]: interactive_state["track_end_number"]
377
+ ] = painted_images
378
+ else:
379
+ video_state["masks"][video_state["select_frame_number"]:] = masks
380
+ video_state["logits"][video_state["select_frame_number"]:] = logits
381
+ video_state["painted_images"][
382
+ video_state["select_frame_number"]:
383
+ ] = painted_images
384
+
385
+ interactive_state["inference_times"] += 1
386
+
387
+ print(
388
+ "For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(
389
+ interactive_state["inference_times"],
390
+ interactive_state["positive_click_times"]
391
+ + interactive_state["negative_click_times"],
392
+ interactive_state["positive_click_times"],
393
+ interactive_state["negative_click_times"],
394
+ )
395
+ )
396
+
397
+ return output_path, video_state, interactive_state, operation_log
398
+
399
+
400
+ def blur_video(video_state, interactive_state, mask_dropdown):
401
+ operation_log = [("", ""), ("Blurred the selected masks.", "Normal")]
402
+
403
+ frames = np.asarray(video_state["origin_images"])[
404
+ video_state["select_frame_number"]:interactive_state["track_end_number"]
405
+ ]
406
+ fps = video_state["fps"]
407
+ output_path = "./output/blur/{}".format(video_state["video_name"])
408
+ blur_masks = np.asarray(video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]])
409
+ if len(mask_dropdown) == 0:
410
+ mask_dropdown = ["mask_001"]
411
+ mask_dropdown.sort()
412
+ # convert mask_dropdown to mask numbers
413
+ blur_mask_numbers = [
414
+ int(mask_dropdown[i].split("_")[1]) for i in range(len(mask_dropdown))
415
+ ]
416
+ # iterate through all masks and remove the masks that are not in mask_dropdown
417
+ unique_masks = np.unique(blur_masks)
418
+ num_masks = len(unique_masks) - 1
419
+ for i in range(1, num_masks + 1):
420
+ if i in blur_mask_numbers:
421
+ continue
422
+ blur_masks[blur_masks == i] = 0
423
+
424
+ # blur video
425
+ try:
426
+ blur_frames_and_write(
427
+ frames,
428
+ blur_masks,
429
+ ratio=interactive_state["resize_ratio"],
430
+ strength=interactive_state["blur_strength"],
431
+ fps=fps,
432
+ output_path=output_path
433
+ )
434
+ except Exception as e:
435
+ print("Exception ", e)
436
+ operation_log = [
437
+ (
438
+ "Error! You are trying to blur without any mask input. Please track the selected mask first, and then press Blur. To speed things up, use the resize ratio slider to scale down the image size.",
439
+ "Error",
440
+ ),
441
+ ("", ""),
442
+ ]
443
+
444
+ return output_path, video_state, interactive_state, operation_log
445
+
446
+
447
+ # generate video after vos inference
448
+ def generate_video_from_frames(frames, output_path, fps=30):
449
+ """
450
+ Generates a video from a list of frames.
451
+
452
+ Args:
453
+ frames (list of numpy arrays): The frames to include in the video.
454
+ output_path (str): The path to save the generated video.
455
+ fps (int, optional): The frame rate of the output video. Defaults to 30.
456
+ """
457
+
458
+ frames = torch.from_numpy(np.asarray(frames))
459
+ if not os.path.exists(os.path.dirname(output_path)):
460
+ os.makedirs(os.path.dirname(output_path))
461
+ torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
462
+ return output_path
463
+
464
+
465
+ # convert to onnx quantized model
466
+ def convert_to_onnx(args, checkpoint, quantized=True):
467
+ """
468
+ Convert the model to onnx format.
469
+
470
+ Args:
471
+ model (nn.Module): The model to convert.
472
+ output_path (str): The path to save the onnx model.
473
+ input_shape (tuple): The input shape of the model.
474
+ quantized (bool, optional): Whether to quantize the model. Defaults to True.
475
+ """
476
+ onnx_output_path = f"{checkpoint.split('.')[-2]}.onnx"
477
+ quant_output_path = f"{checkpoint.split('.')[-2]}_quant.onnx"
478
+
479
+ print("Converting to ONNX quantized model...")
480
+
481
+ if not (os.path.exists(onnx_output_path)):
482
+ run_export(
483
+ model_type=args.sam_model_type,
484
+ checkpoint=checkpoint,
485
+ opset=16,
486
+ output=onnx_output_path,
487
+ return_single_mask=True
488
+ )
489
+
490
+ if quantized and not (os.path.exists(quant_output_path)):
491
+ quantize_dynamic(
492
+ model_input=onnx_output_path,
493
+ model_output=quant_output_path,
494
+ optimize_model=True,
495
+ per_channel=False,
496
+ reduce_range=False,
497
+ weight_type=QuantType.QUInt8,
498
+ )
499
+
500
+ return quant_output_path if quantized else onnx_output_path
501
+
502
+
503
+ # args, defined in track_anything.py
504
+ args = parse_augment()
505
+
506
+ # check and download checkpoints if needed
507
+ SAM_checkpoint_dict = {
508
+ "vit_h": "sam_vit_h_4b8939.pth",
509
+ "vit_l": "sam_vit_l_0b3195.pth",
510
+ "vit_b": "sam_vit_b_01ec64.pth",
511
+ "vit_t": "mobile_sam.pt",
512
+ }
513
+ SAM_checkpoint_url_dict = {
514
+ "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
515
+ "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
516
+ "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
517
+ "vit_t": "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt",
518
+ }
519
+ sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
520
+ sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
521
+ xmem_checkpoint = "XMem-s012.pth"
522
+ xmem_checkpoint_url = (
523
+ "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
524
+ )
525
+
526
+ # initialize SAM, XMem
527
+ folder = "checkpoints"
528
+ sam_pt_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
529
+ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
530
+
531
+ if args.sam_model_type == "vit_t":
532
+ sam_onnx_checkpoint = convert_to_onnx(args, sam_pt_checkpoint, quantized=True)
533
+ else:
534
+ sam_onnx_checkpoint = ""
535
+
536
+ model = TrackingAnything(sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args)
537
+
538
+ title = """<p><h1 align="center">Blur-Anything</h1></p>
539
+ """
540
+ description = """<p>Gradio demo for Blur Anything, a flexible and interactive
541
+ tool for video object tracking, segmentation, and blurring. To
542
+ use it, simply upload your video, or click one of the examples to
543
+ load them. Code: <a
544
+ href="https://github.com/Y-T-G/Blur-Anything">https://github.com/Y-T-G/Blur-Anything</a>
545
+ <a
546
+ href="https://huggingface.co/spaces/Y-T-G/Blur-Anything?duplicate=true"><img
547
+ style="display: inline; margin-top: 0em; margin-bottom: 0em"
548
+ src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
549
+
550
+
551
+ with gr.Blocks() as iface:
552
+ """
553
+ state for
554
+ """
555
+ click_state = gr.State([[], []])
556
+ interactive_state = gr.State(
557
+ {
558
+ "inference_times": 0,
559
+ "negative_click_times": 0,
560
+ "positive_click_times": 0,
561
+ "mask_save": args.mask_save,
562
+ "multi_mask": {"mask_names": [], "masks": []},
563
+ "track_end_number": None,
564
+ "resize_ratio": 1,
565
+ "blur_strength": 3,
566
+ }
567
+ )
568
+
569
+ video_state = gr.State(
570
+ {
571
+ "user_name": "",
572
+ "video_name": "",
573
+ "origin_images": None,
574
+ "painted_images": None,
575
+ "masks": None,
576
+ "blur_masks": None,
577
+ "logits": None,
578
+ "select_frame_number": 0,
579
+ "fps": 30,
580
+ }
581
+ )
582
+ gr.Markdown(title)
583
+ gr.Markdown(description)
584
+ with gr.Row():
585
+ # for user video input
586
+ with gr.Column():
587
+ with gr.Row():
588
+ video_input = gr.Video()
589
+ with gr.Column():
590
+ video_info = gr.Textbox(label="Video Info")
591
+ resize_info = gr.Textbox(
592
+ value="You can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.",
593
+ label="Tips for running this demo.",
594
+ )
595
+ resize_ratio_slider = gr.Slider(
596
+ minimum=0.02,
597
+ maximum=1,
598
+ step=0.02,
599
+ value=1,
600
+ label="Resize ratio",
601
+ visible=True,
602
+ )
603
+
604
+ with gr.Row():
605
+ # put the template frame under the radio button
606
+ with gr.Column():
607
+ # extract frames
608
+ with gr.Column():
609
+ extract_frames_button = gr.Button(
610
+ value="Get video info", interactive=True, variant="primary"
611
+ )
612
+
613
+ # click point settings: negative or positive; mode: continuous or single
614
+ with gr.Row():
615
+ with gr.Row():
616
+ point_prompt = gr.Radio(
617
+ choices=["Positive", "Negative"],
618
+ value="Positive",
619
+ label="Point Prompt",
620
+ interactive=True,
621
+ visible=False,
622
+ )
623
+ remove_mask_button = gr.Button(
624
+ value="Remove mask", interactive=True, visible=False
625
+ )
626
+ clear_button_click = gr.Button(
627
+ value="Clear Clicks", interactive=True, visible=False
628
+ )
629
+ Add_mask_button = gr.Button(
630
+ value="Add mask", interactive=True, visible=False
631
+ )
632
+ template_frame = gr.Image(
633
+ type="pil",
634
+ interactive=True,
635
+ elem_id="template_frame",
636
+ visible=False,
637
+ )
638
+ image_selection_slider = gr.Slider(
639
+ minimum=1,
640
+ maximum=100,
641
+ step=1,
642
+ value=1,
643
+ label="Image Selection",
644
+ visible=False,
645
+ )
646
+ track_pause_number_slider = gr.Slider(
647
+ minimum=1,
648
+ maximum=100,
649
+ step=1,
650
+ value=1,
651
+ label="Track end frames",
652
+ visible=False,
653
+ )
654
+
655
+ with gr.Column():
656
+ run_status = gr.HighlightedText(
657
+ value=[
658
+ ("Text", "Error"),
659
+ ("to be", "Label 2"),
660
+ ("highlighted", "Label 3"),
661
+ ],
662
+ visible=False,
663
+ )
664
+ mask_dropdown = gr.Dropdown(
665
+ multiselect=True,
666
+ value=[],
667
+ label="Mask selection",
668
+ info=".",
669
+ visible=False,
670
+ )
671
+ video_output = gr.Video(visible=False)
672
+ with gr.Row():
673
+ tracking_video_predict_button = gr.Button(
674
+ value="Tracking", visible=False
675
+ )
676
+ blur_video_predict_button = gr.Button(
677
+ value="Blur", visible=False
678
+ )
679
+ with gr.Row():
680
+ blur_strength_slider = gr.Slider(
681
+ minimum=3,
682
+ maximum=15,
683
+ step=2,
684
+ value=3,
685
+ label="Blur Strength",
686
+ visible=False,
687
+ )
688
+
689
+ # first step: get the video information
690
+ extract_frames_button.click(
691
+ fn=get_frames_from_video,
692
+ inputs=[video_input, video_state],
693
+ outputs=[
694
+ video_state,
695
+ video_info,
696
+ template_frame,
697
+ image_selection_slider,
698
+ track_pause_number_slider,
699
+ point_prompt,
700
+ clear_button_click,
701
+ Add_mask_button,
702
+ template_frame,
703
+ tracking_video_predict_button,
704
+ video_output,
705
+ mask_dropdown,
706
+ remove_mask_button,
707
+ blur_video_predict_button,
708
+ blur_strength_slider,
709
+ run_status,
710
+ ],
711
+ )
712
+
713
+ # second step: select images from slider
714
+ image_selection_slider.release(
715
+ fn=select_template,
716
+ inputs=[image_selection_slider, video_state, interactive_state],
717
+ outputs=[template_frame, video_state, interactive_state, run_status],
718
+ api_name="select_image",
719
+ )
720
+ track_pause_number_slider.release(
721
+ fn=set_end_number,
722
+ inputs=[track_pause_number_slider, video_state, interactive_state],
723
+ outputs=[interactive_state, run_status],
724
+ api_name="end_image",
725
+ )
726
+ resize_ratio_slider.release(
727
+ fn=get_resize_ratio,
728
+ inputs=[resize_ratio_slider, interactive_state],
729
+ outputs=[interactive_state],
730
+ api_name="resize_ratio",
731
+ )
732
+
733
+ blur_strength_slider.release(
734
+ fn=get_blur_strength,
735
+ inputs=[blur_strength_slider, interactive_state],
736
+ outputs=[interactive_state],
737
+ api_name="blur_strength",
738
+ )
739
+
740
+ # click select image to get mask using sam
741
+ template_frame.select(
742
+ fn=sam_refine,
743
+ inputs=[video_state, point_prompt, click_state, interactive_state],
744
+ outputs=[template_frame, video_state, interactive_state, run_status],
745
+ )
746
+
747
+ # add different mask
748
+ Add_mask_button.click(
749
+ fn=add_multi_mask,
750
+ inputs=[video_state, interactive_state, mask_dropdown],
751
+ outputs=[
752
+ interactive_state,
753
+ mask_dropdown,
754
+ template_frame,
755
+ click_state,
756
+ run_status,
757
+ ],
758
+ )
759
+
760
+ remove_mask_button.click(
761
+ fn=remove_multi_mask,
762
+ inputs=[interactive_state, mask_dropdown],
763
+ outputs=[interactive_state, mask_dropdown, run_status],
764
+ )
765
+
766
+ # track the video from the selected frame and masks
767
+ tracking_video_predict_button.click(
768
+ fn=vos_tracking_video,
769
+ inputs=[video_state, interactive_state, mask_dropdown],
770
+ outputs=[video_output, video_state, interactive_state, run_status],
771
+ )
772
+
773
+ # blur the video using the selected masks
774
+ blur_video_predict_button.click(
775
+ fn=blur_video,
776
+ inputs=[video_state, interactive_state, mask_dropdown],
777
+ outputs=[video_output, video_state, interactive_state, run_status],
778
+ )
779
+
780
+ # show the selected masks on the template frame
781
+ mask_dropdown.change(
782
+ fn=show_mask,
783
+ inputs=[video_state, interactive_state, mask_dropdown],
784
+ outputs=[template_frame, run_status],
785
+ )
786
+
787
+ # clear input
788
+ video_input.clear(
789
+ lambda: (
790
+ {
791
+ "user_name": "",
792
+ "video_name": "",
793
+ "origin_images": None,
794
+ "painted_images": None,
795
+ "masks": None,
796
+ "blur_masks": None,
797
+ "logits": None,
798
+ "select_frame_number": 0,
799
+ "fps": 30,
800
+ },
801
+ {
802
+ "inference_times": 0,
803
+ "negative_click_times": 0,
804
+ "positive_click_times": 0,
805
+ "mask_save": args.mask_save,
806
+ "multi_mask": {"mask_names": [], "masks": []},
807
+ "track_end_number": 0,
808
+ "resize_ratio": 1,
809
+ "blur_strength": 3,
810
+ },
811
+ [[], []],
812
+ None,
813
+ None,
814
+ gr.update(visible=False),
815
+ gr.update(visible=False),
816
+ gr.update(visible=False),
817
+ gr.update(visible=False),
818
+ gr.update(visible=False),
819
+ gr.update(visible=False),
820
+ gr.update(visible=False),
821
+ gr.update(visible=False),
822
+ gr.update(visible=False),
823
+ gr.update(visible=False, value=[]),
824
+ gr.update(visible=False),
825
+ gr.update(visible=False),
826
+ gr.update(visible=False),
827
+ ),
828
+ [],
829
+ [
830
+ video_state,
831
+ interactive_state,
832
+ click_state,
833
+ video_output,
834
+ template_frame,
835
+ tracking_video_predict_button,
836
+ image_selection_slider,
837
+ track_pause_number_slider,
838
+ point_prompt,
839
+ clear_button_click,
840
+ Add_mask_button,
841
+ template_frame,
842
+ tracking_video_predict_button,
843
+ video_output,
844
+ mask_dropdown,
845
+ remove_mask_button,
846
+ blur_video_predict_button,
847
+ blur_strength_slider,
848
+ run_status,
849
+ ],
850
+ queue=False,
851
+ show_progress=False,
852
+ )
853
+
854
+ # points clear
855
+ clear_button_click.click(
856
+ fn=clear_click,
857
+ inputs=[
858
+ video_state,
859
+ click_state,
860
+ ],
861
+ outputs=[template_frame, click_state, run_status],
862
+ )
863
+ # set example
864
+ gr.Markdown("## Examples")
865
+ gr.Examples(
866
+ examples=[
867
+ os.path.join(os.path.dirname(__file__), "./data/", test_sample)
868
+ for test_sample in [
869
+ "sample-1.mp4",
870
+ "sample-2.mp4",
871
+ ]
872
+ ],
873
+ fn=run_example,
874
+ inputs=[video_input],
875
+ outputs=[video_input],
876
+ )
877
+ iface.queue(concurrency_count=1)
878
+ iface.launch(
879
+ debug=True, enable_queue=True
880
+ )
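For orientation, the click prompt that `get_prompt` assembles (and that `sam_refine` feeds to SAM) has the shape below; the coordinates are purely illustrative:

```python
# Illustrative prompt after one positive and one negative click (x, y in pixels).
prompt = {
    "prompt_type": ["click"],
    "input_point": [[320, 180], [400, 210]],  # accumulated click coordinates
    "input_label": [1, 0],                    # 1 = positive click, 0 = negative click
    "multimask_output": "False",
}
```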
data/sample-1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc49f2d9f5f00775248b8a66228f3e42304bbc391013d23ac66d21ba1f0e5fd2
3
+ size 664422
data/sample-2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45ba5eb410e9d25744946afe61abff9e2ab0916d2f206637636ae30d0decd5e9
3
+ size 1369798
export_onnx_model.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ from mobile_sam import sam_model_registry
10
+ from mobile_sam.utils.onnx import SamOnnxModel
11
+
12
+ import argparse
13
+ import warnings
14
+
15
+ try:
16
+ import onnxruntime # type: ignore
17
+
18
+ onnxruntime_exists = True
19
+ except ImportError:
20
+ onnxruntime_exists = False
21
+
22
+ parser = argparse.ArgumentParser(
23
+ description="Export the SAM prompt encoder and mask decoder to an ONNX model."
24
+ )
25
+
26
+ parser.add_argument(
27
+ "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
28
+ )
29
+
30
+ parser.add_argument(
31
+ "--output", type=str, required=True, help="The filename to save the ONNX model to."
32
+ )
33
+
34
+ parser.add_argument(
35
+ "--model-type",
36
+ type=str,
37
+ required=True,
38
+ help="In ['default', 'vit_h', 'vit_l', 'vit_b', 'vit_t']. Which type of SAM model to export.",
39
+ )
40
+
41
+ parser.add_argument(
42
+ "--return-single-mask",
43
+ action="store_true",
44
+ help=(
45
+ "If true, the exported ONNX model will only return the best mask, "
46
+ "instead of returning multiple masks. For high resolution images "
47
+ "this can improve runtime when upscaling masks is expensive."
48
+ ),
49
+ )
50
+
51
+ parser.add_argument(
52
+ "--opset",
53
+ type=int,
54
+ default=16,
55
+ help="The ONNX opset version to use. Must be >=11",
56
+ )
57
+
58
+ parser.add_argument(
59
+ "--quantize-out",
60
+ type=str,
61
+ default=None,
62
+ help=(
63
+ "If set, will quantize the model and save it with this name. "
64
+ "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
65
+ ),
66
+ )
67
+
68
+ parser.add_argument(
69
+ "--gelu-approximate",
70
+ action="store_true",
71
+ help=(
72
+ "Replace GELU operations with approximations using tanh. Useful "
73
+ "for some runtimes that have slow or unimplemented erf ops, used in GELU."
74
+ ),
75
+ )
76
+
77
+ parser.add_argument(
78
+ "--use-stability-score",
79
+ action="store_true",
80
+ help=(
81
+ "Replaces the model's predicted mask quality score with the stability "
82
+ "score calculated on the low resolution masks using an offset of 1.0. "
83
+ ),
84
+ )
85
+
86
+ parser.add_argument(
87
+ "--return-extra-metrics",
88
+ action="store_true",
89
+ help=(
90
+ "The model will return five results: (masks, scores, stability_scores, "
91
+ "areas, low_res_logits) instead of the usual three. This can be "
92
+ "significantly slower for high resolution outputs."
93
+ ),
94
+ )
95
+
96
+
97
+ def run_export(
98
+ model_type: str,
99
+ checkpoint: str,
100
+ output: str,
101
+ opset: int,
102
+ return_single_mask: bool,
103
+ gelu_approximate: bool = False,
104
+ use_stability_score: bool = False,
105
+ return_extra_metrics=False,
106
+ ):
107
+ print("Loading model...")
108
+ sam = sam_model_registry[model_type](checkpoint=checkpoint)
109
+
110
+ onnx_model = SamOnnxModel(
111
+ model=sam,
112
+ return_single_mask=return_single_mask,
113
+ use_stability_score=use_stability_score,
114
+ return_extra_metrics=return_extra_metrics,
115
+ )
116
+
117
+ if gelu_approximate:
118
+ for n, m in onnx_model.named_modules():
119
+ if isinstance(m, torch.nn.GELU):
120
+ m.approximate = "tanh"
121
+
122
+ dynamic_axes = {
123
+ "point_coords": {1: "num_points"},
124
+ "point_labels": {1: "num_points"},
125
+ }
126
+
127
+ embed_dim = sam.prompt_encoder.embed_dim
128
+ embed_size = sam.prompt_encoder.image_embedding_size
129
+ mask_input_size = [4 * x for x in embed_size]
130
+ dummy_inputs = {
131
+ "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
132
+ "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
133
+ "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
134
+ "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
135
+ "has_mask_input": torch.tensor([1], dtype=torch.float),
136
+ "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
137
+ }
138
+
139
+ _ = onnx_model(**dummy_inputs)
140
+
141
+ output_names = ["masks", "iou_predictions", "low_res_masks"]
142
+
143
+ with warnings.catch_warnings():
144
+ warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
145
+ warnings.filterwarnings("ignore", category=UserWarning)
146
+ with open(output, "wb") as f:
147
+ print(f"Exporting onnx model to {output}...")
148
+ torch.onnx.export(
149
+ onnx_model,
150
+ tuple(dummy_inputs.values()),
151
+ f,
152
+ export_params=True,
153
+ verbose=False,
154
+ opset_version=opset,
155
+ do_constant_folding=True,
156
+ input_names=list(dummy_inputs.keys()),
157
+ output_names=output_names,
158
+ dynamic_axes=dynamic_axes,
159
+ )
160
+
161
+ if onnxruntime_exists:
162
+ ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
163
+ # set cpu provider default
164
+ providers = ["CPUExecutionProvider"]
165
+ ort_session = onnxruntime.InferenceSession(output, providers=providers)
166
+ _ = ort_session.run(None, ort_inputs)
167
+ print("Model has successfully been run with ONNXRuntime.")
168
+
169
+
170
+ def to_numpy(tensor):
171
+ return tensor.cpu().numpy()
172
+
173
+
174
+ if __name__ == "__main__":
175
+ args = parser.parse_args()
176
+ run_export(
177
+ model_type=args.model_type,
178
+ checkpoint=args.checkpoint,
179
+ output=args.output,
180
+ opset=args.opset,
181
+ return_single_mask=args.return_single_mask,
182
+ gelu_approximate=args.gelu_approximate,
183
+ use_stability_score=args.use_stability_score,
184
+ return_extra_metrics=args.return_extra_metrics,
185
+ )
186
+
187
+ if args.quantize_out is not None:
188
+ assert onnxruntime_exists, "onnxruntime is required to quantize the model."
189
+ from onnxruntime.quantization import QuantType # type: ignore
190
+ from onnxruntime.quantization.quantize import quantize_dynamic # type: ignore
191
+
192
+ print(f"Quantizing model and writing to {args.quantize_out}...")
193
+ quantize_dynamic(
194
+ model_input=args.output,
195
+ model_output=args.quantize_out,
196
+ optimize_model=True,
197
+ per_channel=False,
198
+ reduce_range=False,
199
+ weight_type=QuantType.QUInt8,
200
+ )
201
+ print("Done!")
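A hedged example of invoking this script directly, mirroring what `convert_to_onnx` in `app.py` does programmatically (the paths assume the MobileSAM checkpoint has already been downloaded into `checkpoints/`, which `app.py` does on first run):

```shell
# Export the vit_t (MobileSAM) prompt encoder + mask decoder to ONNX,
# then write a dynamically quantized copy alongside it.
python export_onnx_model.py \
    --checkpoint checkpoints/mobile_sam.pt \
    --model-type vit_t \
    --output checkpoints/mobile_sam.onnx \
    --return-single-mask \
    --quantize-out checkpoints/mobile_sam_quant.onnx
```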
packages.txt ADDED
@@ -0,0 +1 @@
1
+ python3-opencv
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
pyproject.toml ADDED
@@ -0,0 +1,31 @@
1
+ [tool.poetry]
2
+ name = "Blur-Anything"
3
+ version = "0.1.0"
4
+ description = "Track and blur any object or person in a video."
5
+ authors = ["Y-T-G <yaseensinbox@gmail.com>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ packages = [{include = "blur_anything"}]
9
+
10
+ [tool.poetry.dependencies]
11
+ python = "^3.9"
12
+ gradio = "^3.28.1"
13
+ numpy = "^1.24.3"
14
+ av = "^10.0.0"
15
+ torch = "^2.0.0"
16
+ opencv-python = "^4.7.0.72"
17
+ psutil = "^5.9.5"
18
+ tqdm = "^4.65.0"
19
+ matplotlib = "^3.7.1"
20
+ segment-anything = {git = "https://github.com/facebookresearch/segment-anything.git"}
21
+ torchvision = "^0.15.1"
22
+ pims = "^0.6.1"
23
+ mobile-sam = {git = "https://github.com/ChaoningZhang/MobileSAM.git"}
24
+ onnxruntime = "^1.15.1"
25
+ timm = "^0.9.5"
26
+ onnx = "^1.14.0"
27
+
28
+
29
+ [build-system]
30
+ requires = ["poetry-core"]
31
+ build-backend = "poetry.core.masonry.api"
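If you prefer Poetry over the pinned `requirements.txt` below, a minimal sketch (assuming Poetry is installed; `--no-root` skips installing the `blur_anything` package declared above, which is not needed to run the demo):

```shell
poetry install --no-root
poetry run python app.py --device cpu
```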
requirements.txt ADDED
@@ -0,0 +1,106 @@
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==4.2.2
5
+ anyio==3.6.2
6
+ async-timeout==4.0.2
7
+ attrs==23.1.0
8
+ av==10.0.0
9
+ certifi==2022.12.7
10
+ charset-normalizer==3.1.0
11
+ click==8.1.3
12
+ cmake==3.26.3
13
+ colorama==0.4.6
14
+ coloredlogs==15.0.1
15
+ contourpy==1.0.7
16
+ cycler==0.11.0
17
+ entrypoints==0.4
18
+ fastapi==0.95.1
19
+ ffmpy==0.3.0
20
+ filelock==3.12.0
21
+ flatbuffers==23.5.26
22
+ fonttools==4.39.3
23
+ frozenlist==1.3.3
24
+ fsspec==2023.4.0
25
+ gradio-client==0.1.4
26
+ gradio==3.28.3
27
+ h11==0.14.0
28
+ httpcore==0.17.0
29
+ httpx==0.24.0
30
+ huggingface-hub==0.14.1
31
+ humanfriendly==10.0
32
+ idna==3.4
33
+ imageio==2.28.1
34
+ importlib-resources==5.12.0
35
+ jinja2==3.1.2
36
+ jsonschema==4.17.3
37
+ kiwisolver==1.4.4
38
+ linkify-it-py==2.0.2
39
+ lit==16.0.2
40
+ markdown-it-py==2.2.0
41
+ markdown-it-py[linkify]==2.2.0
42
+ markupsafe==2.1.2
43
+ matplotlib==3.7.1
44
+ mdit-py-plugins==0.3.3
45
+ mdurl==0.1.2
46
+ mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
47
+ mpmath==1.3.0
48
+ multidict==6.0.4
49
+ networkx==3.1
50
+ numpy==1.24.3
51
+ nvidia-cublas-cu11==11.10.3.66
52
+ nvidia-cuda-cupti-cu11==11.7.101
53
+ nvidia-cuda-nvrtc-cu11==11.7.99
54
+ nvidia-cuda-runtime-cu11==11.7.99
55
+ nvidia-cudnn-cu11==8.5.0.96
56
+ nvidia-cufft-cu11==10.9.0.58
57
+ nvidia-curand-cu11==10.2.10.91
58
+ nvidia-cusolver-cu11==11.4.0.1
59
+ nvidia-cusparse-cu11==11.7.4.91
60
+ nvidia-nccl-cu11==2.14.3
61
+ nvidia-nvtx-cu11==11.7.91
62
+ onnx==1.14.0
63
+ onnxruntime==1.15.1
64
+ opencv-python==4.7.0.72
65
+ orjson==3.8.11
66
+ packaging==23.1
67
+ pandas==2.0.1
68
+ pillow==9.5.0
69
+ pims==0.6.1
70
+ protobuf==4.24.0
71
+ psutil==5.9.5
72
+ pydantic==1.10.7
73
+ pydub==0.25.1
74
+ pygments==2.15.1
75
+ pyparsing==3.0.9
76
+ pyreadline3==3.4.1
77
+ pyrsistent==0.19.3
78
+ python-dateutil==2.8.2
79
+ python-multipart==0.0.6
80
+ pytz==2023.3
81
+ pyyaml==6.0
82
+ requests==2.30.0
83
+ safetensors==0.3.2
84
+ segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
85
+ semantic-version==2.10.0
86
+ setuptools==67.7.2
87
+ six==1.16.0
88
+ slicerator==1.1.0
89
+ sniffio==1.3.0
90
+ starlette==0.26.1
91
+ sympy==1.11.1
92
+ timm==0.9.5
93
+ toolz==0.12.0
94
+ torch==2.0.0
95
+ torchvision==0.15.1
96
+ tqdm==4.65.0
97
+ triton==2.0.0
98
+ typing-extensions==4.5.0
99
+ tzdata==2023.3
100
+ uc-micro-py==1.0.2
101
+ urllib3==2.0.2
102
+ uvicorn==0.22.0
103
+ websockets==11.0.2
104
+ wheel==0.40.0
105
+ yarl==1.9.2
106
+ zipp==3.15.0
track_anything.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ from tqdm import tqdm
3
+
4
+ from utils.interact_tools import SamControler
5
+ from tracker.base_tracker import BaseTracker
6
+ import numpy as np
7
+ import argparse
8
+ import cv2
9
+
10
+ from typing import Optional
11
+
12
+
13
+ class TrackingAnything:
14
+ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, xmem_checkpoint, args):
15
+ self.args = args
16
+ self.sam_pt_checkpoint = sam_pt_checkpoint
17
+ self.sam_onnx_checkpoint = sam_onnx_checkpoint
18
+ self.xmem_checkpoint = xmem_checkpoint
19
+ self.samcontroler = SamControler(
20
+ self.sam_pt_checkpoint, self.sam_onnx_checkpoint, args.sam_model_type, args.device
21
+ )
22
+ self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
23
+
24
+ def first_frame_click(
25
+ self, image: np.ndarray, points: np.ndarray, labels: np.ndarray, multimask=True
26
+ ):
27
+ mask, logit, painted_image = self.samcontroler.first_frame_click(
28
+ image, points, labels, multimask
29
+ )
30
+ return mask, logit, painted_image
31
+
32
+ def generator(
33
+ self,
34
+ images: list,
35
+ template_mask: np.ndarray,
36
+ write: Optional[bool] = False,
37
+ fps: Optional[int] = 30,
38
+ output_path: Optional[str] = "tracking.mp4",
39
+ ):
40
+ masks = []
41
+ logits = []
42
+ painted_images = []
43
+
44
+ if write:
45
+ size = images[0].shape[:2][::-1]
46
+ if not os.path.exists(os.path.dirname(output_path)):
47
+ os.makedirs(os.path.dirname(output_path))
48
+ writer = cv2.VideoWriter(
49
+ output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size
50
+ )
51
+
52
+ for i in tqdm(range(len(images)), desc="Tracking image"):
53
+ if i == 0:
54
+ mask, logit, painted_image = self.xmem.track(images[i], template_mask)
55
+ else:
56
+ mask, logit, painted_image = self.xmem.track(images[i])
57
+
58
+ masks.append(mask)
59
+ logits.append(logit)
60
+
61
+ if write:
62
+ writer.write(painted_image[:,:,::-1])
63
+ else:
64
+ painted_images.append(painted_image)
65
+
66
+ if write:
67
+ writer.release()
68
+
69
+ return masks, logits, painted_images
70
+
71
+
72
+ def parse_augment():
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("--device", type=str, default="cpu")
75
+ parser.add_argument("--sam_model_type", type=str, default="vit_t")
76
+ parser.add_argument(
77
+ "--port",
78
+ type=int,
79
+ default=6080,
80
+ help="only useful when running gradio applications",
81
+ )
82
+ parser.add_argument("--debug", action="store_true")
83
+ parser.add_argument("--mask_save", default=False)
84
+ args = parser.parse_args()
85
+
86
+ if args.debug:
87
+ print(args)
88
+ return args
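A minimal sketch of driving `TrackingAnything` outside the Gradio app. The checkpoint paths and the placeholder mask are assumptions for illustration; in `app.py` the checkpoints are downloaded automatically and the template mask comes from SAM clicks on the first frame:

```python
import sys
import numpy as np
import pims

# app.py adds the tracker packages to sys.path the same way.
sys.path.append(sys.path[0] + "/tracker")
sys.path.append(sys.path[0] + "/tracker/model")

from track_anything import TrackingAnything, parse_augment

args = parse_augment()  # defaults: --device cpu, --sam_model_type vit_t
model = TrackingAnything(
    "checkpoints/mobile_sam.pt",          # assumed SAM checkpoint path
    "checkpoints/mobile_sam_quant.onnx",  # assumed quantized ONNX export
    "checkpoints/XMem-s012.pth",          # assumed XMem checkpoint path
    args,
)

frames = pims.Video("data/sample-1.mp4")
template_mask = np.zeros(frames[0].shape[:2], dtype=np.uint8)
template_mask[100:200, 100:200] = 1  # placeholder object region in frame 0

# Track the object through the clip and write a painted preview video.
masks, logits, _ = model.generator(
    images=frames,
    template_mask=template_mask,
    write=True,
    fps=int(frames.frame_rate),
    output_path="output/track/sample-1.mp4",
)
```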
tracker/base_tracker.py ADDED
@@ -0,0 +1,142 @@
1
+ # import for debugging
2
+ import os
3
+ import glob
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ # import for base_tracker
8
+ import torch
9
+ import yaml
10
+ import torch.nn.functional as F
11
+ from tracker.model.network import XMem
12
+ from inference.inference_core import InferenceCore
13
+ from tracker.util.mask_mapper import MaskMapper
14
+ from torchvision import transforms
15
+ from tracker.util.range_transform import im_normalization
16
+
17
+ from utils.painter import mask_painter
18
+
19
+ dir_path = os.path.dirname(os.path.realpath(__file__))
20
+
21
+
22
+ class BaseTracker:
23
+ def __init__(
24
+ self, xmem_checkpoint, device, sam_model=None, model_type=None
25
+ ) -> None:
26
+ """
27
+ device: model device
28
+ xmem_checkpoint: checkpoint of XMem model
29
+ """
30
+ # load configurations
31
+ with open(f"{dir_path}/config/config.yaml", "r") as stream:
32
+ config = yaml.safe_load(stream)
33
+ # initialise XMem
34
+ network = XMem(config, xmem_checkpoint, map_location=device).eval()
35
+ # initialise InferenceCore
36
+ self.tracker = InferenceCore(network, config)
37
+ # data transformation
38
+ self.im_transform = transforms.Compose(
39
+ [
40
+ transforms.ToTensor(),
41
+ im_normalization,
42
+ ]
43
+ )
44
+ self.device = device
45
+
46
+ # changeable properties
47
+ self.mapper = MaskMapper()
48
+ self.initialised = False
49
+
50
+ # # SAM-based refinement
51
+ # self.sam_model = sam_model
52
+ # self.resizer = Resize([256, 256])
53
+
54
+ @torch.no_grad()
55
+ def resize_mask(self, mask):
56
+ # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
57
+ h, w = mask.shape[-2:]
58
+ min_hw = min(h, w)
59
+ return F.interpolate(
60
+ mask,
61
+ (int(h / min_hw * self.size), int(w / min_hw * self.size)),
62
+ mode="nearest",
63
+ )
64
+
65
+ @torch.no_grad()
66
+ def track(self, frame, first_frame_annotation=None):
67
+ """
68
+ Input:
69
+ frames: numpy arrays (H, W, 3)
70
+ logit: numpy array (H, W), logit
71
+
72
+ Output:
73
+ mask: numpy arrays (H, W)
74
+ logit: numpy arrays, probability map (H, W)
75
+ painted_image: numpy array (H, W, 3)
76
+ """
77
+
78
+ if first_frame_annotation is not None: # first frame mask
79
+ # initialisation
80
+ mask, labels = self.mapper.convert_mask(first_frame_annotation)
81
+ mask = torch.Tensor(mask).to(self.device)
82
+ self.tracker.set_all_labels(list(self.mapper.remappings.values()))
83
+ else:
84
+ mask = None
85
+ labels = None
86
+ # prepare inputs
87
+ frame_tensor = self.im_transform(frame).to(self.device)
88
+ # track one frame
89
+ probs, _ = self.tracker.step(frame_tensor, mask, labels) # logits 2 (bg fg) H W
90
+ # # refine
91
+ # if first_frame_annotation is None:
92
+ # out_mask = self.sam_refinement(frame, logits[1], ti)
93
+
94
+ # convert to mask
95
+ out_mask = torch.argmax(probs, dim=0)
96
+ out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
97
+
98
+ final_mask = np.zeros_like(out_mask)
99
+
100
+ # map back
101
+ for k, v in self.mapper.remappings.items():
102
+ final_mask[out_mask == v] = k
103
+
104
+ num_objs = final_mask.max()
105
+ painted_image = frame
106
+ for obj in range(1, num_objs + 1):
107
+ if np.max(final_mask == obj) == 0:
108
+ continue
109
+ painted_image = mask_painter(
110
+ painted_image, (final_mask == obj).astype("uint8"), mask_color=obj + 1
111
+ )
112
+
113
+ # print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
114
+
115
+ return final_mask, final_mask, painted_image
116
+
117
+ @torch.no_grad()
118
+ def sam_refinement(self, frame, logits, ti):
119
+ """
120
+ refine segmentation results with mask prompt
121
+ """
122
+ # convert to 1, 256, 256
123
+ self.sam_model.set_image(frame)
124
+ mode = "mask"
125
+ logits = logits.unsqueeze(0)
126
+ logits = self.resizer(logits).cpu().numpy()
127
+ prompts = {"mask_input": logits} # 1 256 256
128
+ masks, scores, logits = self.sam_model.predict(
129
+ prompts, mode, multimask=True
130
+ ) # masks (n, h, w), scores (n,), logits (n, 256, 256)
131
+ painted_image = mask_painter(
132
+ frame, masks[np.argmax(scores)].astype("uint8"), mask_alpha=0.8
133
+ )
134
+ painted_image = Image.fromarray(painted_image)
135
+ painted_image.save(f"/ssd1/gaomingqi/refine/{ti:05d}.png")
136
+ self.sam_model.reset_image()
137
+
138
+ @torch.no_grad()
139
+ def clear_memory(self):
140
+ self.tracker.clear_memory()
141
+ self.mapper.clear_labels()
142
+ torch.cuda.empty_cache()
tracker/config/config.yaml ADDED
@@ -0,0 +1,15 @@
1
+ # config info for XMem
2
+ benchmark: False
3
+ disable_long_term: False
4
+ max_mid_term_frames: 10
5
+ min_mid_term_frames: 5
6
+ max_long_term_elements: 1000
7
+ num_prototypes: 128
8
+ top_k: 30
9
+ mem_every: 5
10
+ deep_update_every: -1
11
+ save_scores: False
12
+ flip: False
13
+ size: 480
14
+ enable_long_term: True
15
+ enable_long_term_count_usage: True
tracker/inference/__init__.py ADDED
File without changes
tracker/inference/inference_core.py ADDED
@@ -0,0 +1,149 @@
1
+ from inference.memory_manager import MemoryManager
2
+ from model.network import XMem
3
+ from model.aggregate import aggregate
4
+
5
+ from tracker.util.tensor_util import pad_divide_by, unpad
6
+
7
+
8
+ class InferenceCore:
9
+ def __init__(self, network: XMem, config):
10
+ self.config = config
11
+ self.network = network
12
+ self.mem_every = config["mem_every"]
13
+ self.deep_update_every = config["deep_update_every"]
14
+ self.enable_long_term = config["enable_long_term"]
15
+
16
+ # if deep_update_every < 0, synchronize deep update with memory frame
17
+ self.deep_update_sync = self.deep_update_every < 0
18
+
19
+ self.clear_memory()
20
+ self.all_labels = None
21
+
22
+ def clear_memory(self):
23
+ self.curr_ti = -1
24
+ self.last_mem_ti = 0
25
+ if not self.deep_update_sync:
26
+ self.last_deep_update_ti = -self.deep_update_every
27
+ self.memory = MemoryManager(config=self.config)
28
+
29
+ def update_config(self, config):
30
+ self.mem_every = config["mem_every"]
31
+ self.deep_update_every = config["deep_update_every"]
32
+ self.enable_long_term = config["enable_long_term"]
33
+
34
+ # if deep_update_every < 0, synchronize deep update with memory frame
35
+ self.deep_update_sync = self.deep_update_every < 0
36
+ self.memory.update_config(config)
37
+
38
+ def set_all_labels(self, all_labels):
39
+ # self.all_labels = [l.item() for l in all_labels]
40
+ self.all_labels = all_labels
41
+
42
+ def step(self, image, mask=None, valid_labels=None, end=False):
43
+ # image: 3*H*W
44
+ # mask: num_objects*H*W or None
45
+ self.curr_ti += 1
46
+ image, self.pad = pad_divide_by(image, 16)
47
+ image = image.unsqueeze(0) # add the batch dimension
48
+
49
+ is_mem_frame = (
50
+ (self.curr_ti - self.last_mem_ti >= self.mem_every) or (mask is not None)
51
+ ) and (not end)
52
+ need_segment = (self.curr_ti > 0) and (
53
+ (valid_labels is None) or (len(self.all_labels) != len(valid_labels))
54
+ )
55
+ is_deep_update = (
56
+ (self.deep_update_sync and is_mem_frame)
57
+ or ( # synchronized
58
+ not self.deep_update_sync
59
+ and self.curr_ti - self.last_deep_update_ti >= self.deep_update_every
60
+ ) # no-sync
61
+ ) and (not end)
62
+ is_normal_update = (not self.deep_update_sync or not is_deep_update) and (
63
+ not end
64
+ )
65
+
66
+ key, shrinkage, selection, f16, f8, f4 = self.network.encode_key(
67
+ image, need_ek=(self.enable_long_term or need_segment), need_sk=is_mem_frame
68
+ )
69
+ multi_scale_features = (f16, f8, f4)
70
+
71
+ # segment the current frame if needed
72
+ if need_segment:
73
+ memory_readout = self.memory.match_memory(key, selection).unsqueeze(0)
74
+
75
+ hidden, pred_logits_with_bg, pred_prob_with_bg = self.network.segment(
76
+ multi_scale_features,
77
+ memory_readout,
78
+ self.memory.get_hidden(),
79
+ h_out=is_normal_update,
80
+ strip_bg=False,
81
+ )
82
+ # remove batch dim
83
+ pred_prob_with_bg = pred_prob_with_bg[0]
84
+ pred_prob_no_bg = pred_prob_with_bg[1:]
85
+
86
+ pred_logits_with_bg = pred_logits_with_bg[0]
87
+ pred_logits_no_bg = pred_logits_with_bg[1:]
88
+
89
+ if is_normal_update:
90
+ self.memory.set_hidden(hidden)
91
+ else:
92
+ pred_prob_no_bg = (
93
+ pred_prob_with_bg
94
+ ) = pred_logits_with_bg = pred_logits_no_bg = None
95
+
96
+ # use the input mask if any
97
+ if mask is not None:
98
+ mask, _ = pad_divide_by(mask, 16)
99
+
100
+ if pred_prob_no_bg is not None:
101
+ # if we have a predicted mask, we work on it
102
+ # make pred_prob_no_bg consistent with the input mask
103
+ mask_regions = mask.sum(0) > 0.5
104
+ pred_prob_no_bg[:, mask_regions] = 0
105
+ # shift by 1 because mask/pred_prob_no_bg do not contain background
106
+ mask = mask.type_as(pred_prob_no_bg)
107
+ if valid_labels is not None:
108
+ shift_by_one_non_labels = [
109
+ i
110
+ for i in range(pred_prob_no_bg.shape[0])
111
+ if (i + 1) not in valid_labels
112
+ ]
113
+ # non-labelled objects are copied from the predicted mask
114
+ mask[shift_by_one_non_labels] = pred_prob_no_bg[
115
+ shift_by_one_non_labels
116
+ ]
117
+ pred_prob_with_bg = aggregate(mask, dim=0)
118
+
119
+ # also create new hidden states
120
+ self.memory.create_hidden_state(len(self.all_labels), key)
121
+
122
+ # save as memory if needed
123
+ if is_mem_frame:
124
+ value, hidden = self.network.encode_value(
125
+ image,
126
+ f16,
127
+ self.memory.get_hidden(),
128
+ pred_prob_with_bg[1:].unsqueeze(0),
129
+ is_deep_update=is_deep_update,
130
+ )
131
+ self.memory.add_memory(
132
+ key,
133
+ shrinkage,
134
+ value,
135
+ self.all_labels,
136
+ selection=selection if self.enable_long_term else None,
137
+ )
138
+ self.last_mem_ti = self.curr_ti
139
+
140
+ if is_deep_update:
141
+ self.memory.set_hidden(hidden)
142
+ self.last_deep_update_ti = self.curr_ti
143
+
144
+ if pred_logits_with_bg is None:
145
+ return unpad(pred_prob_with_bg, self.pad), None
146
+ else:
147
+ return unpad(pred_prob_with_bg, self.pad), unpad(
148
+ pred_logits_with_bg, self.pad
149
+ )
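A rough usage sketch of the step() loop above, not how app.py drives it verbatim; model construction, device placement, and the first-frame mask are assumed to come from the surrounding code, and shapes follow the comments in step():

    import torch

    processor = InferenceCore(network, config)   # `network`: a loaded XMem model
    processor.set_all_labels([1])                # a single foreground object

    with torch.no_grad():
        for ti, frame in enumerate(frames):      # frame: 3*H*W float tensor
            if ti == 0:
                # first frame ships with a user mask: num_objects*H*W
                prob, _ = processor.step(frame, first_frame_mask, valid_labels=[1])
            else:
                prob, _ = processor.step(frame)
            pred = torch.argmax(prob, dim=0)     # 0 = background, 1.. = object ids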
tracker/inference/kv_memory_store.py ADDED
@@ -0,0 +1,234 @@
1
+ import torch
2
+ from typing import List
3
+
4
+
5
+ class KeyValueMemoryStore:
6
+ """
7
+ Works as key/value pair storage,
8
+ e.g., working and long-term memory
9
+ """
10
+
11
+ """
12
+ An object group is created when new objects enter the video
13
+ Objects in the same group share the same temporal extent
14
+ i.e., objects initialized in the same frame are in the same group
15
+ For DAVIS/interactive, there is only one object group
16
+ For YouTubeVOS, there can be multiple object groups
17
+ """
18
+
19
+ def __init__(self, count_usage: bool):
20
+ self.count_usage = count_usage
21
+
22
+ # keys are stored in a single tensor and are shared between groups/objects
23
+ # values are stored as a list indexed by object groups
24
+ self.k = None
25
+ self.v = []
26
+ self.obj_groups = []
27
+ # for debugging only
28
+ self.all_objects = []
29
+
30
+ # shrinkage and selection are also single tensors
31
+ self.s = self.e = None
32
+
33
+ # usage
34
+ if self.count_usage:
35
+ self.use_count = self.life_count = None
36
+
37
+ def add(self, key, value, shrinkage, selection, objects: List[int]):
38
+ new_count = torch.zeros(
39
+ (key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32
40
+ )
41
+ new_life = (
42
+ torch.zeros(
43
+ (key.shape[0], 1, key.shape[2]), device=key.device, dtype=torch.float32
44
+ )
45
+ + 1e-7
46
+ )
47
+
48
+ # add the key
49
+ if self.k is None:
50
+ self.k = key
51
+ self.s = shrinkage
52
+ self.e = selection
53
+ if self.count_usage:
54
+ self.use_count = new_count
55
+ self.life_count = new_life
56
+ else:
57
+ self.k = torch.cat([self.k, key], -1)
58
+ if shrinkage is not None:
59
+ self.s = torch.cat([self.s, shrinkage], -1)
60
+ if selection is not None:
61
+ self.e = torch.cat([self.e, selection], -1)
62
+ if self.count_usage:
63
+ self.use_count = torch.cat([self.use_count, new_count], -1)
64
+ self.life_count = torch.cat([self.life_count, new_life], -1)
65
+
66
+ # add the value
67
+ if objects is not None:
68
+ # When objects is given, v is a tensor; used in working memory
69
+ assert isinstance(value, torch.Tensor)
70
+ # First consume objects that are already in the memory bank
71
+ # cannot use set here because we need to preserve order
72
+ # shift by one as background is not part of value
73
+ remaining_objects = [obj - 1 for obj in objects]
74
+ for gi, group in enumerate(self.obj_groups):
75
+ for obj in group:
76
+ # should properly raise an error if there are overlaps in obj_groups
77
+ remaining_objects.remove(obj)
78
+ self.v[gi] = torch.cat([self.v[gi], value[group]], -1)
79
+
80
+ # If there are remaining objects, add them as a new group
81
+ if len(remaining_objects) > 0:
82
+ new_group = list(remaining_objects)
83
+ self.v.append(value[new_group])
84
+ self.obj_groups.append(new_group)
85
+ self.all_objects.extend(new_group)
86
+
87
+ assert (
88
+ sorted(self.all_objects) == self.all_objects
89
+ ), "Objects MUST be inserted in sorted order "
90
+ else:
91
+ # When objects is not given, v is a list that already has the object groups sorted
92
+ # used in long-term memory
93
+ assert isinstance(value, list)
94
+ for gi, gv in enumerate(value):
95
+ if gv is None:
96
+ continue
97
+ if gi < self.num_groups:
98
+ self.v[gi] = torch.cat([self.v[gi], gv], -1)
99
+ else:
100
+ self.v.append(gv)
101
+
102
+ def update_usage(self, usage):
103
+ # increase all life count by 1
104
+ # increase use of indexed elements
105
+ if not self.count_usage:
106
+ return
107
+
108
+ self.use_count += usage.view_as(self.use_count)
109
+ self.life_count += 1
110
+
111
+ def sieve_by_range(self, start: int, end: int, min_size: int):
112
+ # keep only the elements *outside* of this range (with some boundary conditions)
113
+ # i.e., concat (a[:start], a[end:])
114
+ # min_size is only used for values, we do not sieve values under this size
115
+ # (because they are not consolidated)
116
+
117
+ if end == 0:
118
+ # negative 0 would not work as the end index!
119
+ self.k = self.k[:, :, :start]
120
+ if self.count_usage:
121
+ self.use_count = self.use_count[:, :, :start]
122
+ self.life_count = self.life_count[:, :, :start]
123
+ if self.s is not None:
124
+ self.s = self.s[:, :, :start]
125
+ if self.e is not None:
126
+ self.e = self.e[:, :, :start]
127
+
128
+ for gi in range(self.num_groups):
129
+ if self.v[gi].shape[-1] >= min_size:
130
+ self.v[gi] = self.v[gi][:, :, :start]
131
+ else:
132
+ self.k = torch.cat([self.k[:, :, :start], self.k[:, :, end:]], -1)
133
+ if self.count_usage:
134
+ self.use_count = torch.cat(
135
+ [self.use_count[:, :, :start], self.use_count[:, :, end:]], -1
136
+ )
137
+ self.life_count = torch.cat(
138
+ [self.life_count[:, :, :start], self.life_count[:, :, end:]], -1
139
+ )
140
+ if self.s is not None:
141
+ self.s = torch.cat([self.s[:, :, :start], self.s[:, :, end:]], -1)
142
+ if self.e is not None:
143
+ self.e = torch.cat([self.e[:, :, :start], self.e[:, :, end:]], -1)
144
+
145
+ for gi in range(self.num_groups):
146
+ if self.v[gi].shape[-1] >= min_size:
147
+ self.v[gi] = torch.cat(
148
+ [self.v[gi][:, :, :start], self.v[gi][:, :, end:]], -1
149
+ )
150
+
151
+ def remove_obsolete_features(self, max_size: int):
152
+ # normalize with life duration
153
+ usage = self.get_usage().flatten()
154
+
155
+ values, _ = torch.topk(
156
+ usage, k=(self.size - max_size), largest=False, sorted=True
157
+ )
158
+ survived = usage > values[-1]
159
+
160
+ self.k = self.k[:, :, survived]
161
+ self.s = self.s[:, :, survived] if self.s is not None else None
162
+ # Long-term memory does not store ek so this should not be needed
163
+ self.e = self.e[:, :, survived] if self.e is not None else None
164
+ if self.num_groups > 1:
165
+ raise NotImplementedError(
166
+ """The current data structure does not support feature removal with
167
+ multiple object groups (e.g., some objects start to appear later in the video)
168
+ The indices for "survived" is based on keys but not all values are present for every key
169
+ Basically we need to remap the indices for keys to values
170
+ """
171
+ )
172
+ for gi in range(self.num_groups):
173
+ self.v[gi] = self.v[gi][:, :, survived]
174
+
175
+ self.use_count = self.use_count[:, :, survived]
176
+ self.life_count = self.life_count[:, :, survived]
177
+
178
+ def get_usage(self):
179
+ # return normalized usage
180
+ if not self.count_usage:
181
+ raise RuntimeError("I did not count usage!")
182
+ else:
183
+ usage = self.use_count / self.life_count
184
+ return usage
185
+
186
+ def get_all_sliced(self, start: int, end: int):
187
+ # return k, sk, ek, usage in order, sliced by start and end
188
+
189
+ if end == 0:
190
+ # negative 0 would not work as the end index!
191
+ k = self.k[:, :, start:]
192
+ sk = self.s[:, :, start:] if self.s is not None else None
193
+ ek = self.e[:, :, start:] if self.e is not None else None
194
+ usage = self.get_usage()[:, :, start:]
195
+ else:
196
+ k = self.k[:, :, start:end]
197
+ sk = self.s[:, :, start:end] if self.s is not None else None
198
+ ek = self.e[:, :, start:end] if self.e is not None else None
199
+ usage = self.get_usage()[:, :, start:end]
200
+
201
+ return k, sk, ek, usage
202
+
203
+ def get_v_size(self, ni: int):
204
+ return self.v[ni].shape[2]
205
+
206
+ def engaged(self):
207
+ return self.k is not None
208
+
209
+ @property
210
+ def size(self):
211
+ if self.k is None:
212
+ return 0
213
+ else:
214
+ return self.k.shape[-1]
215
+
216
+ @property
217
+ def num_groups(self):
218
+ return len(self.v)
219
+
220
+ @property
221
+ def key(self):
222
+ return self.k
223
+
224
+ @property
225
+ def value(self):
226
+ return self.v
227
+
228
+ @property
229
+ def shrinkage(self):
230
+ return self.s
231
+
232
+ @property
233
+ def selection(self):
234
+ return self.e
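A toy sketch of the add() contract above: keys are 1 x C^k x N and shared across objects, while values are num_objects x C^v x N and grouped by the frame in which the objects first appeared. Tensor sizes are made up for illustration:

    import torch

    store = KeyValueMemoryStore(count_usage=True)

    N = 4                                    # memory elements contributed by one frame
    key = torch.randn(1, 64, N)              # 1 x C^k x N
    shrinkage = torch.rand(1, 1, N) + 1
    selection = torch.rand(1, 64, N)
    value = torch.randn(2, 512, N)           # two objects, C^v x N each

    store.add(key, value, shrinkage, selection, objects=[1, 2])
    print(store.size, store.num_groups)      # 4 1  (both objects share one group)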
tracker/inference/memory_manager.py ADDED
@@ -0,0 +1,373 @@
1
+ import torch
2
+ import warnings
3
+
4
+ from inference.kv_memory_store import KeyValueMemoryStore
5
+ from model.memory_util import *
6
+
7
+
8
+ class MemoryManager:
9
+ """
10
+ Manages all three memory stores and the transition between working/long-term memory
11
+ """
12
+
13
+ def __init__(self, config):
14
+ self.hidden_dim = config["hidden_dim"]
15
+ self.top_k = config["top_k"]
16
+
17
+ self.enable_long_term = config["enable_long_term"]
18
+ self.enable_long_term_usage = config["enable_long_term_count_usage"]
19
+ if self.enable_long_term:
20
+ self.max_mt_frames = config["max_mid_term_frames"]
21
+ self.min_mt_frames = config["min_mid_term_frames"]
22
+ self.num_prototypes = config["num_prototypes"]
23
+ self.max_long_elements = config["max_long_term_elements"]
24
+
25
+ # dimensions will be inferred from input later
26
+ self.CK = self.CV = None
27
+ self.H = self.W = None
28
+
29
+ # The hidden state will be stored in a single tensor for all objects
30
+ # B x num_objects x CH x H x W
31
+ self.hidden = None
32
+
33
+ self.work_mem = KeyValueMemoryStore(count_usage=self.enable_long_term)
34
+ if self.enable_long_term:
35
+ self.long_mem = KeyValueMemoryStore(count_usage=self.enable_long_term_usage)
36
+
37
+ self.reset_config = True
38
+
39
+ def update_config(self, config):
40
+ self.reset_config = True
41
+ self.hidden_dim = config["hidden_dim"]
42
+ self.top_k = config["top_k"]
43
+
44
+ assert self.enable_long_term == config["enable_long_term"], "cannot update this"
45
+ assert (
46
+ self.enable_long_term_usage == config["enable_long_term_count_usage"]
47
+ ), "cannot update this"
48
+
49
+ self.enable_long_term_usage = config["enable_long_term_count_usage"]
50
+ if self.enable_long_term:
51
+ self.max_mt_frames = config["max_mid_term_frames"]
52
+ self.min_mt_frames = config["min_mid_term_frames"]
53
+ self.num_prototypes = config["num_prototypes"]
54
+ self.max_long_elements = config["max_long_term_elements"]
55
+
56
+ def _readout(self, affinity, v):
57
+ # this function is for a single object group
58
+ return v @ affinity
59
+
60
+ def match_memory(self, query_key, selection):
61
+ # query_key: B x C^k x H x W
62
+ # selection: B x C^k x H x W
63
+ num_groups = self.work_mem.num_groups
64
+ h, w = query_key.shape[-2:]
65
+
66
+ query_key = query_key.flatten(start_dim=2)
67
+ selection = selection.flatten(start_dim=2) if selection is not None else None
68
+
69
+ """
70
+ Memory readout using keys
71
+ """
72
+
73
+ if self.enable_long_term and self.long_mem.engaged():
74
+ # Use long-term memory
75
+ long_mem_size = self.long_mem.size
76
+ memory_key = torch.cat([self.long_mem.key, self.work_mem.key], -1)
77
+ shrinkage = torch.cat(
78
+ [self.long_mem.shrinkage, self.work_mem.shrinkage], -1
79
+ )
80
+
81
+ similarity = get_similarity(memory_key, shrinkage, query_key, selection)
82
+ work_mem_similarity = similarity[:, long_mem_size:]
83
+ long_mem_similarity = similarity[:, :long_mem_size]
84
+
85
+ # get the usage with the first group
86
+ # the first group always have all the keys valid
87
+ affinity, usage = do_softmax(
88
+ torch.cat(
89
+ [
90
+ long_mem_similarity[:, -self.long_mem.get_v_size(0) :],
91
+ work_mem_similarity,
92
+ ],
93
+ 1,
94
+ ),
95
+ top_k=self.top_k,
96
+ inplace=True,
97
+ return_usage=True,
98
+ )
99
+ affinity = [affinity]
100
+
101
+ # compute affinity group by group as later groups only have a subset of keys
102
+ for gi in range(1, num_groups):
103
+ if gi < self.long_mem.num_groups:
104
+ # merge working and lt similarities before softmax
105
+ affinity_one_group = do_softmax(
106
+ torch.cat(
107
+ [
108
+ long_mem_similarity[:, -self.long_mem.get_v_size(gi) :],
109
+ work_mem_similarity[:, -self.work_mem.get_v_size(gi) :],
110
+ ],
111
+ 1,
112
+ ),
113
+ top_k=self.top_k,
114
+ inplace=True,
115
+ )
116
+ else:
117
+ # no long-term memory for this group
118
+ affinity_one_group = do_softmax(
119
+ work_mem_similarity[:, -self.work_mem.get_v_size(gi) :],
120
+ top_k=self.top_k,
121
+ inplace=(gi == num_groups - 1),
122
+ )
123
+ affinity.append(affinity_one_group)
124
+
125
+ all_memory_value = []
126
+ for gi, gv in enumerate(self.work_mem.value):
127
+ # merge the working and lt values before readout
128
+ if gi < self.long_mem.num_groups:
129
+ all_memory_value.append(
130
+ torch.cat(
131
+ [self.long_mem.value[gi], self.work_mem.value[gi]], -1
132
+ )
133
+ )
134
+ else:
135
+ all_memory_value.append(gv)
136
+
137
+ """
138
+ Record memory usage for working and long-term memory
139
+ """
140
+ # ignore the index return for long-term memory
141
+ work_usage = usage[:, long_mem_size:]
142
+ self.work_mem.update_usage(work_usage.flatten())
143
+
144
+ if self.enable_long_term_usage:
145
+ # ignore the index return for working memory
146
+ long_usage = usage[:, :long_mem_size]
147
+ self.long_mem.update_usage(long_usage.flatten())
148
+ else:
149
+ # No long-term memory
150
+ similarity = get_similarity(
151
+ self.work_mem.key, self.work_mem.shrinkage, query_key, selection
152
+ )
153
+
154
+ if self.enable_long_term:
155
+ affinity, usage = do_softmax(
156
+ similarity,
157
+ inplace=(num_groups == 1),
158
+ top_k=self.top_k,
159
+ return_usage=True,
160
+ )
161
+
162
+ # Record memory usage for working memory
163
+ self.work_mem.update_usage(usage.flatten())
164
+ else:
165
+ affinity = do_softmax(
166
+ similarity,
167
+ inplace=(num_groups == 1),
168
+ top_k=self.top_k,
169
+ return_usage=False,
170
+ )
171
+
172
+ affinity = [affinity]
173
+
174
+ # compute affinity group by group as later groups only have a subset of keys
175
+ for gi in range(1, num_groups):
176
+ affinity_one_group = do_softmax(
177
+ similarity[:, -self.work_mem.get_v_size(gi) :],
178
+ top_k=self.top_k,
179
+ inplace=(gi == num_groups - 1),
180
+ )
181
+ affinity.append(affinity_one_group)
182
+
183
+ all_memory_value = self.work_mem.value
184
+
185
+ # Shared affinity within each group
186
+ all_readout_mem = torch.cat(
187
+ [self._readout(affinity[gi], gv) for gi, gv in enumerate(all_memory_value)],
188
+ 0,
189
+ )
190
+
191
+ return all_readout_mem.view(all_readout_mem.shape[0], self.CV, h, w)
192
+
193
+ def add_memory(self, key, shrinkage, value, objects, selection=None):
194
+ # key: 1*C*H*W
195
+ # value: 1*num_objects*C*H*W
196
+ # objects contain a list of object indices
197
+ if self.H is None or self.reset_config:
198
+ self.reset_config = False
199
+ self.H, self.W = key.shape[-2:]
200
+ self.HW = self.H * self.W
201
+ if self.enable_long_term:
202
+ # convert from num. frames to num. nodes
203
+ self.min_work_elements = self.min_mt_frames * self.HW
204
+ self.max_work_elements = self.max_mt_frames * self.HW
205
+
206
+ # key: 1*C*N
207
+ # value: num_objects*C*N
208
+ key = key.flatten(start_dim=2)
209
+ shrinkage = shrinkage.flatten(start_dim=2)
210
+ value = value[0].flatten(start_dim=2)
211
+
212
+ self.CK = key.shape[1]
213
+ self.CV = value.shape[1]
214
+
215
+ if selection is not None:
216
+ if not self.enable_long_term:
217
+ warnings.warn(
218
+ "the selection factor is only needed in long-term mode", UserWarning
219
+ )
220
+ selection = selection.flatten(start_dim=2)
221
+
222
+ self.work_mem.add(key, value, shrinkage, selection, objects)
223
+
224
+ # long-term memory cleanup
225
+ if self.enable_long_term:
226
+ # Do memory compression if needed
227
+ if self.work_mem.size >= self.max_work_elements:
228
+ # print('remove memory')
229
+ # Remove obsolete features if needed
230
+ if self.long_mem.size >= (self.max_long_elements - self.num_prototypes):
231
+ self.long_mem.remove_obsolete_features(
232
+ self.max_long_elements - self.num_prototypes
233
+ )
234
+
235
+ self.compress_features()
236
+
237
+ def create_hidden_state(self, n, sample_key):
238
+ # n is the TOTAL number of objects
239
+ h, w = sample_key.shape[-2:]
240
+ if self.hidden is None:
241
+ self.hidden = torch.zeros(
242
+ (1, n, self.hidden_dim, h, w), device=sample_key.device
243
+ )
244
+ elif self.hidden.shape[1] != n:
245
+ self.hidden = torch.cat(
246
+ [
247
+ self.hidden,
248
+ torch.zeros(
249
+ (1, n - self.hidden.shape[1], self.hidden_dim, h, w),
250
+ device=sample_key.device,
251
+ ),
252
+ ],
253
+ 1,
254
+ )
255
+
256
+ assert self.hidden.shape[1] == n
257
+
258
+ def set_hidden(self, hidden):
259
+ self.hidden = hidden
260
+
261
+ def get_hidden(self):
262
+ return self.hidden
263
+
264
+ def compress_features(self):
265
+ HW = self.HW
266
+ candidate_value = []
267
+ total_work_mem_size = self.work_mem.size
268
+ for gv in self.work_mem.value:
269
+ # Some object groups might be added later in the video
270
+ # So not all keys have values associated with all objects
271
+ # We need to keep track of the key->value validity
272
+ mem_size_in_this_group = gv.shape[-1]
273
+ if mem_size_in_this_group == total_work_mem_size:
274
+ # full LT
275
+ candidate_value.append(gv[:, :, HW : -self.min_work_elements + HW])
276
+ else:
277
+ # mem_size is smaller than total_work_mem_size, but at least HW
278
+ assert HW <= mem_size_in_this_group < total_work_mem_size
279
+ if mem_size_in_this_group > self.min_work_elements + HW:
280
+ # part of this object group still goes into LT
281
+ candidate_value.append(gv[:, :, HW : -self.min_work_elements + HW])
282
+ else:
283
+ # this object group cannot go to the LT at all
284
+ candidate_value.append(None)
285
+
286
+ # perform memory consolidation
287
+ prototype_key, prototype_value, prototype_shrinkage = self.consolidation(
288
+ *self.work_mem.get_all_sliced(HW, -self.min_work_elements + HW),
289
+ candidate_value
290
+ )
291
+
292
+ # remove consolidated working memory
293
+ self.work_mem.sieve_by_range(
294
+ HW, -self.min_work_elements + HW, min_size=self.min_work_elements + HW
295
+ )
296
+
297
+ # add to long-term memory
298
+ self.long_mem.add(
299
+ prototype_key,
300
+ prototype_value,
301
+ prototype_shrinkage,
302
+ selection=None,
303
+ objects=None,
304
+ )
305
+ # print(f'long memory size: {self.long_mem.size}')
306
+ # print(f'work memory size: {self.work_mem.size}')
307
+
308
+ def consolidation(
309
+ self,
310
+ candidate_key,
311
+ candidate_shrinkage,
312
+ candidate_selection,
313
+ usage,
314
+ candidate_value,
315
+ ):
316
+ # keys: 1*C*N
317
+ # values: num_objects*C*N
318
+ N = candidate_key.shape[-1]
319
+
320
+ # find the indices with max usage
321
+ _, max_usage_indices = torch.topk(
322
+ usage, k=self.num_prototypes, dim=-1, sorted=True
323
+ )
324
+ prototype_indices = max_usage_indices.flatten()
325
+
326
+ # Prototypes are invalid for out-of-bound groups
327
+ validity = [
328
+ prototype_indices >= (N - gv.shape[2]) if gv is not None else None
329
+ for gv in candidate_value
330
+ ]
331
+
332
+ prototype_key = candidate_key[:, :, prototype_indices]
333
+ prototype_selection = (
334
+ candidate_selection[:, :, prototype_indices]
335
+ if candidate_selection is not None
336
+ else None
337
+ )
338
+
339
+ """
340
+ Potentiation step
341
+ """
342
+ similarity = get_similarity(
343
+ candidate_key, candidate_shrinkage, prototype_key, prototype_selection
344
+ )
345
+
346
+ # convert similarity to affinity
347
+ # need to do it group by group since the softmax normalization would be different
348
+ affinity = [
349
+ do_softmax(similarity[:, -gv.shape[2] :, validity[gi]])
350
+ if gv is not None
351
+ else None
352
+ for gi, gv in enumerate(candidate_value)
353
+ ]
354
+
355
+ # some values can have all-False validity. Weed them out.
356
+ affinity = [
357
+ aff if aff is None or aff.shape[-1] > 0 else None for aff in affinity
358
+ ]
359
+
360
+ # readout the values
361
+ prototype_value = [
362
+ self._readout(affinity[gi], gv) if affinity[gi] is not None else None
363
+ for gi, gv in enumerate(candidate_value)
364
+ ]
365
+
366
+ # readout the shrinkage term
367
+ prototype_shrinkage = (
368
+ self._readout(affinity[0], candidate_shrinkage)
369
+ if candidate_shrinkage is not None
370
+ else None
371
+ )
372
+
373
+ return prototype_key, prototype_value, prototype_shrinkage
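Back-of-the-envelope numbers for the consolidation trigger above, assuming the config shipped in this commit (size 480, 10/5 mid-term frames, 128 prototypes) and the stride-16 key encoder; exact counts depend on the padded input resolution:

    # Roughly, for a 480x864 padded input at stride 16:
    HW = (480 // 16) * (864 // 16)      # ~1620 key tokens per memory frame
    max_work_elements = 10 * HW         # max_mid_term_frames
    min_work_elements = 5 * HW          # min_mid_term_frames
    # Once work memory exceeds max_work_elements, everything older than the
    # newest min_mid_term_frames is consolidated into num_prototypes = 128
    # long-term entries per call, and long-term memory itself is pruned by
    # lowest usage once it approaches max_long_term_elements.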
tracker/model/__init__.py ADDED
File without changes
tracker/model/aggregate.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ # Soft aggregation from STM
6
+ def aggregate(prob, dim, return_logits=False):
7
+ new_prob = torch.cat(
8
+ [torch.prod(1 - prob, dim=dim, keepdim=True), prob], dim
9
+ ).clamp(1e-7, 1 - 1e-7)
10
+ logits = torch.log((new_prob / (1 - new_prob)))
11
+ prob = F.softmax(logits, dim=dim)
12
+
13
+ if return_logits:
14
+ return logits, prob
15
+ else:
16
+ return prob
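The soft aggregation above turns independent per-object probabilities into a mutually exclusive distribution with a derived background channel; a quick shape check with illustrative sizes:

    import torch

    prob = torch.rand(1, 2, 16, 16)         # B x num_objects x H x W, sigmoid outputs
    out = aggregate(prob, dim=1)            # B x (num_objects + 1) x H x W
    print(out.shape)                        # torch.Size([1, 3, 16, 16])
    print(out.sum(dim=1).allclose(torch.ones(1, 16, 16)))  # True: channel 0 is background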
tracker/model/cbam.py ADDED
@@ -0,0 +1,119 @@
1
+ # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class BasicConv(nn.Module):
9
+ def __init__(
10
+ self,
11
+ in_planes,
12
+ out_planes,
13
+ kernel_size,
14
+ stride=1,
15
+ padding=0,
16
+ dilation=1,
17
+ groups=1,
18
+ bias=True,
19
+ ):
20
+ super(BasicConv, self).__init__()
21
+ self.out_channels = out_planes
22
+ self.conv = nn.Conv2d(
23
+ in_planes,
24
+ out_planes,
25
+ kernel_size=kernel_size,
26
+ stride=stride,
27
+ padding=padding,
28
+ dilation=dilation,
29
+ groups=groups,
30
+ bias=bias,
31
+ )
32
+
33
+ def forward(self, x):
34
+ x = self.conv(x)
35
+ return x
36
+
37
+
38
+ class Flatten(nn.Module):
39
+ def forward(self, x):
40
+ return x.view(x.size(0), -1)
41
+
42
+
43
+ class ChannelGate(nn.Module):
44
+ def __init__(self, gate_channels, reduction_ratio=16, pool_types=["avg", "max"]):
45
+ super(ChannelGate, self).__init__()
46
+ self.gate_channels = gate_channels
47
+ self.mlp = nn.Sequential(
48
+ Flatten(),
49
+ nn.Linear(gate_channels, gate_channels // reduction_ratio),
50
+ nn.ReLU(),
51
+ nn.Linear(gate_channels // reduction_ratio, gate_channels),
52
+ )
53
+ self.pool_types = pool_types
54
+
55
+ def forward(self, x):
56
+ channel_att_sum = None
57
+ for pool_type in self.pool_types:
58
+ if pool_type == "avg":
59
+ avg_pool = F.avg_pool2d(
60
+ x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
61
+ )
62
+ channel_att_raw = self.mlp(avg_pool)
63
+ elif pool_type == "max":
64
+ max_pool = F.max_pool2d(
65
+ x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))
66
+ )
67
+ channel_att_raw = self.mlp(max_pool)
68
+
69
+ if channel_att_sum is None:
70
+ channel_att_sum = channel_att_raw
71
+ else:
72
+ channel_att_sum = channel_att_sum + channel_att_raw
73
+
74
+ scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
75
+ return x * scale
76
+
77
+
78
+ class ChannelPool(nn.Module):
79
+ def forward(self, x):
80
+ return torch.cat(
81
+ (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
82
+ )
83
+
84
+
85
+ class SpatialGate(nn.Module):
86
+ def __init__(self):
87
+ super(SpatialGate, self).__init__()
88
+ kernel_size = 7
89
+ self.compress = ChannelPool()
90
+ self.spatial = BasicConv(
91
+ 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2
92
+ )
93
+
94
+ def forward(self, x):
95
+ x_compress = self.compress(x)
96
+ x_out = self.spatial(x_compress)
97
+ scale = torch.sigmoid(x_out) # broadcasting
98
+ return x * scale
99
+
100
+
101
+ class CBAM(nn.Module):
102
+ def __init__(
103
+ self,
104
+ gate_channels,
105
+ reduction_ratio=16,
106
+ pool_types=["avg", "max"],
107
+ no_spatial=False,
108
+ ):
109
+ super(CBAM, self).__init__()
110
+ self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
111
+ self.no_spatial = no_spatial
112
+ if not no_spatial:
113
+ self.SpatialGate = SpatialGate()
114
+
115
+ def forward(self, x):
116
+ x_out = self.ChannelGate(x)
117
+ if not self.no_spatial:
118
+ x_out = self.SpatialGate(x_out)
119
+ return x_out
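CBAM re-weights a feature map with channel attention followed by spatial attention, keeping the shape unchanged; a minimal smoke test, assuming the class is importable from this file:

    import torch

    attn = CBAM(gate_channels=512)
    x = torch.randn(2, 512, 24, 24)
    y = attn(x)
    print(y.shape)                          # torch.Size([2, 512, 24, 24])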
tracker/model/group_modules.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Group-specific modules
3
+ They handle features that also depend on the mask.
4
+ Features are typically of shape
5
+ batch_size * num_objects * num_channels * H * W
6
+
7
+ All of them are permutation equivariant w.r.t. the num_objects dimension
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def interpolate_groups(g, ratio, mode, align_corners):
16
+ batch_size, num_objects = g.shape[:2]
17
+ g = F.interpolate(
18
+ g.flatten(start_dim=0, end_dim=1),
19
+ scale_factor=ratio,
20
+ mode=mode,
21
+ align_corners=align_corners,
22
+ )
23
+ g = g.view(batch_size, num_objects, *g.shape[1:])
24
+ return g
25
+
26
+
27
+ def upsample_groups(g, ratio=2, mode="bilinear", align_corners=False):
28
+ return interpolate_groups(g, ratio, mode, align_corners)
29
+
30
+
31
+ def downsample_groups(g, ratio=1 / 2, mode="area", align_corners=None):
32
+ return interpolate_groups(g, ratio, mode, align_corners)
33
+
34
+
35
+ class GConv2D(nn.Conv2d):
36
+ def forward(self, g):
37
+ batch_size, num_objects = g.shape[:2]
38
+ g = super().forward(g.flatten(start_dim=0, end_dim=1))
39
+ return g.view(batch_size, num_objects, *g.shape[1:])
40
+
41
+
42
+ class GroupResBlock(nn.Module):
43
+ def __init__(self, in_dim, out_dim):
44
+ super().__init__()
45
+
46
+ if in_dim == out_dim:
47
+ self.downsample = None
48
+ else:
49
+ self.downsample = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
50
+
51
+ self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
52
+ self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
53
+
54
+ def forward(self, g):
55
+ out_g = self.conv1(F.relu(g))
56
+ out_g = self.conv2(F.relu(out_g))
57
+
58
+ if self.downsample is not None:
59
+ g = self.downsample(g)
60
+
61
+ return out_g + g
62
+
63
+
64
+ class MainToGroupDistributor(nn.Module):
65
+ def __init__(self, x_transform=None, method="cat", reverse_order=False):
66
+ super().__init__()
67
+
68
+ self.x_transform = x_transform
69
+ self.method = method
70
+ self.reverse_order = reverse_order
71
+
72
+ def forward(self, x, g):
73
+ num_objects = g.shape[1]
74
+
75
+ if self.x_transform is not None:
76
+ x = self.x_transform(x)
77
+
78
+ if self.method == "cat":
79
+ if self.reverse_order:
80
+ g = torch.cat(
81
+ [g, x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)], 2
82
+ )
83
+ else:
84
+ g = torch.cat(
85
+ [x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1), g], 2
86
+ )
87
+ elif self.method == "add":
88
+ g = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) + g
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ return g
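These group modules treat (batch, num_objects) as a fused batch dimension; e.g. GConv2D simply flattens the first two dims, applies a regular Conv2d, and reshapes back. A quick shape check with made-up sizes:

    import torch

    conv = GConv2D(256, 128, kernel_size=3, padding=1)
    g = torch.randn(2, 3, 256, 24, 24)      # B x num_objects x C x H x W
    print(conv(g).shape)                    # torch.Size([2, 3, 128, 24, 24])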
tracker/model/losses.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from collections import defaultdict
6
+
7
+
8
+ def dice_loss(input_mask, cls_gt):
9
+ num_objects = input_mask.shape[1]
10
+ losses = []
11
+ for i in range(num_objects):
12
+ mask = input_mask[:, i].flatten(start_dim=1)
13
+ # background not in mask, so we add one to cls_gt
14
+ gt = (cls_gt == (i + 1)).float().flatten(start_dim=1)
15
+ numerator = 2 * (mask * gt).sum(-1)
16
+ denominator = mask.sum(-1) + gt.sum(-1)
17
+ loss = 1 - (numerator + 1) / (denominator + 1)
18
+ losses.append(loss)
19
+ return torch.cat(losses).mean()
20
+
21
+
22
+ # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
23
+ class BootstrappedCE(nn.Module):
24
+ def __init__(self, start_warm, end_warm, top_p=0.15):
25
+ super().__init__()
26
+
27
+ self.start_warm = start_warm
28
+ self.end_warm = end_warm
29
+ self.top_p = top_p
30
+
31
+ def forward(self, input, target, it):
32
+ if it < self.start_warm:
33
+ return F.cross_entropy(input, target), 1.0
34
+
35
+ raw_loss = F.cross_entropy(input, target, reduction="none").view(-1)
36
+ num_pixels = raw_loss.numel()
37
+
38
+ if it > self.end_warm:
39
+ this_p = self.top_p
40
+ else:
41
+ this_p = self.top_p + (1 - self.top_p) * (
42
+ (self.end_warm - it) / (self.end_warm - self.start_warm)
43
+ )
44
+ loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
45
+ return loss.mean(), this_p
46
+
47
+
48
+ class LossComputer:
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ self.config = config
52
+ self.bce = BootstrappedCE(config["start_warm"], config["end_warm"])
53
+
54
+ def compute(self, data, num_objects, it):
55
+ losses = defaultdict(int)
56
+
57
+ b, t = data["rgb"].shape[:2]
58
+
59
+ losses["total_loss"] = 0
60
+ for ti in range(1, t):
61
+ for bi in range(b):
62
+ loss, p = self.bce(
63
+ data[f"logits_{ti}"][bi : bi + 1, : num_objects[bi] + 1],
64
+ data["cls_gt"][bi : bi + 1, ti, 0],
65
+ it,
66
+ )
67
+ losses["p"] += p / b / (t - 1)
68
+ losses[f"ce_loss_{ti}"] += loss / b
69
+
70
+ losses["total_loss"] += losses["ce_loss_%d" % ti]
71
+ losses[f"dice_loss_{ti}"] = dice_loss(
72
+ data[f"masks_{ti}"], data["cls_gt"][:, ti, 0]
73
+ )
74
+ losses["total_loss"] += losses[f"dice_loss_{ti}"]
75
+
76
+ return losses
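A tiny sanity check of dice_loss above: labels in cls_gt start at 1 because the background channel is not part of input_mask, and a perfect one-object prediction gives a loss of (almost) zero:

    import torch

    input_mask = torch.zeros(1, 1, 4, 4)          # B x num_objects x H x W
    input_mask[0, 0, :2, :] = 1.0
    cls_gt = torch.zeros(1, 4, 4, dtype=torch.long)
    cls_gt[0, :2, :] = 1                          # object 1 occupies the top half
    print(dice_loss(input_mask, cls_gt))          # tensor(0.)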
tracker/model/memory_util.py ADDED
@@ -0,0 +1,87 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from typing import Optional
5
+
6
+
7
+ def get_similarity(mk, ms, qk, qe):
8
+ # used for training/inference and memory reading/memory potentiation
9
+ # mk: B x CK x [N] - Memory keys
10
+ # ms: B x 1 x [N] - Memory shrinkage
11
+ # qk: B x CK x [HW/P] - Query keys
12
+ # qe: B x CK x [HW/P] - Query selection
13
+ # Dimensions in [] are flattened
14
+ CK = mk.shape[1]
15
+ mk = mk.flatten(start_dim=2)
16
+ ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
17
+ qk = qk.flatten(start_dim=2)
18
+ qe = qe.flatten(start_dim=2) if qe is not None else None
19
+
20
+ if qe is not None:
21
+ # See appendix for derivation
22
+ # or you can just trust me ヽ(ー_ー )ノ
23
+ mk = mk.transpose(1, 2)
24
+ a_sq = mk.pow(2) @ qe
25
+ two_ab = 2 * (mk @ (qk * qe))
26
+ b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
27
+ similarity = -a_sq + two_ab - b_sq
28
+ else:
29
+ # similar to STCN if we don't have the selection term
30
+ a_sq = mk.pow(2).sum(1).unsqueeze(2)
31
+ two_ab = 2 * (mk.transpose(1, 2) @ qk)
32
+ similarity = -a_sq + two_ab
33
+
34
+ if ms is not None:
35
+ similarity = similarity * ms / math.sqrt(CK) # B*N*HW
36
+ else:
37
+ similarity = similarity / math.sqrt(CK) # B*N*HW
38
+
39
+ return similarity
40
+
41
+
42
+ def do_softmax(
43
+ similarity, top_k: Optional[int] = None, inplace=False, return_usage=False
44
+ ):
45
+ # normalize similarity with top-k softmax
46
+ # similarity: B x N x [HW/P]
47
+ # use inplace with care
48
+ if top_k is not None:
49
+ values, indices = torch.topk(similarity, k=top_k, dim=1)
50
+
51
+ x_exp = values.exp_()
52
+ x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
53
+ if inplace:
54
+ similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
55
+ affinity = similarity
56
+ else:
57
+ affinity = torch.zeros_like(similarity).scatter_(
58
+ 1, indices, x_exp
59
+ ) # B*N*HW
60
+ else:
61
+ maxes = torch.max(similarity, dim=1, keepdim=True)[0]
62
+ x_exp = torch.exp(similarity - maxes)
63
+ x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
64
+ affinity = x_exp / x_exp_sum
65
+ indices = None
66
+
67
+ if return_usage:
68
+ return affinity, affinity.sum(dim=2)
69
+
70
+ return affinity
71
+
72
+
73
+ def get_affinity(mk, ms, qk, qe):
74
+ # shorthand used in training with no top-k
75
+ similarity = get_similarity(mk, ms, qk, qe)
76
+ affinity = do_softmax(similarity)
77
+ return affinity
78
+
79
+
80
+ def readout(affinity, mv):
81
+ B, CV, T, H, W = mv.shape
82
+
83
+ mo = mv.view(B, CV, T * H * W)
84
+ mem = torch.bmm(mo, affinity)
85
+ mem = mem.view(B, CV, H, W)
86
+
87
+ return mem
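A shape sketch for the similarity/affinity pair above (the anisotropic L2 similarity from XMem): memory keys span N elements, queries span HW pixels, and do_softmax normalizes each query column over the memory dimension. Sizes are illustrative:

    import torch

    B, CK, N, HW = 1, 64, 100, 49
    mk = torch.randn(B, CK, N)               # memory keys
    ms = torch.rand(B, 1, N) + 1             # shrinkage (>= 1)
    qk = torch.randn(B, CK, HW)              # query keys
    qe = torch.rand(B, CK, HW)               # query selection in [0, 1]

    sim = get_similarity(mk, ms, qk, qe)     # B x N x HW
    aff = do_softmax(sim, top_k=30)          # top-k sparse, each column sums to 1
    print(sim.shape, aff.sum(dim=1)[0, :3])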
tracker/model/modules.py ADDED
@@ -0,0 +1,261 @@
1
+ """
2
+ modules.py - This file stores the rather boring network blocks.
3
+
4
+ x - usually means features that only depends on the image
5
+ g - usually means features that also depend on the mask.
6
+ They might have an extra "group" or "num_objects" dimension, hence
7
+ batch_size * num_objects * num_channels * H * W
8
+
9
+ The trailing number of a variable usually denotes the stride
10
+
11
+ """
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from model.group_modules import *
18
+ from model import resnet
19
+ from model.cbam import CBAM
20
+
21
+
22
+ class FeatureFusionBlock(nn.Module):
23
+ def __init__(self, x_in_dim, g_in_dim, g_mid_dim, g_out_dim):
24
+ super().__init__()
25
+
26
+ self.distributor = MainToGroupDistributor()
27
+ self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim)
28
+ self.attention = CBAM(g_mid_dim)
29
+ self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
30
+
31
+ def forward(self, x, g):
32
+ batch_size, num_objects = g.shape[:2]
33
+
34
+ g = self.distributor(x, g)
35
+ g = self.block1(g)
36
+ r = self.attention(g.flatten(start_dim=0, end_dim=1))
37
+ r = r.view(batch_size, num_objects, *r.shape[1:])
38
+
39
+ g = self.block2(g + r)
40
+
41
+ return g
42
+
43
+
44
+ class HiddenUpdater(nn.Module):
45
+ # Used in the decoder, multi-scale feature + GRU
46
+ def __init__(self, g_dims, mid_dim, hidden_dim):
47
+ super().__init__()
48
+ self.hidden_dim = hidden_dim
49
+
50
+ self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
51
+ self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
52
+ self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
53
+
54
+ self.transform = GConv2D(
55
+ mid_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
56
+ )
57
+
58
+ nn.init.xavier_normal_(self.transform.weight)
59
+
60
+ def forward(self, g, h):
61
+ g = (
62
+ self.g16_conv(g[0])
63
+ + self.g8_conv(downsample_groups(g[1], ratio=1 / 2))
64
+ + self.g4_conv(downsample_groups(g[2], ratio=1 / 4))
65
+ )
66
+
67
+ g = torch.cat([g, h], 2)
68
+
69
+ # defined slightly differently than standard GRU,
70
+ # namely the new value is generated before the forget gate.
71
+ # might provide better gradient but frankly it was initially just an
72
+ # implementation error that I never bothered fixing
73
+ values = self.transform(g)
74
+ forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim])
75
+ update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2])
76
+ new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :])
77
+ new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
78
+
79
+ return new_h
80
+
81
+
82
+ class HiddenReinforcer(nn.Module):
83
+ # Used in the value encoder, a single GRU
84
+ def __init__(self, g_dim, hidden_dim):
85
+ super().__init__()
86
+ self.hidden_dim = hidden_dim
87
+ self.transform = GConv2D(
88
+ g_dim + hidden_dim, hidden_dim * 3, kernel_size=3, padding=1
89
+ )
90
+
91
+ nn.init.xavier_normal_(self.transform.weight)
92
+
93
+ def forward(self, g, h):
94
+ g = torch.cat([g, h], 2)
95
+
96
+ # defined slightly differently than standard GRU,
97
+ # namely the new value is generated before the forget gate.
98
+ # might provide better gradient but frankly it was initially just an
99
+ # implementation error that I never bothered fixing
100
+ values = self.transform(g)
101
+ forget_gate = torch.sigmoid(values[:, :, : self.hidden_dim])
102
+ update_gate = torch.sigmoid(values[:, :, self.hidden_dim : self.hidden_dim * 2])
103
+ new_value = torch.tanh(values[:, :, self.hidden_dim * 2 :])
104
+ new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
105
+
106
+ return new_h
107
+
108
+
109
+ class ValueEncoder(nn.Module):
110
+ def __init__(self, value_dim, hidden_dim, single_object=False):
111
+ super().__init__()
112
+
113
+ self.single_object = single_object
114
+ network = resnet.resnet18(pretrained=True, extra_dim=1 if single_object else 2)
115
+ self.conv1 = network.conv1
116
+ self.bn1 = network.bn1
117
+ self.relu = network.relu # 1/2, 64
118
+ self.maxpool = network.maxpool
119
+
120
+ self.layer1 = network.layer1 # 1/4, 64
121
+ self.layer2 = network.layer2 # 1/8, 128
122
+ self.layer3 = network.layer3 # 1/16, 256
123
+
124
+ self.distributor = MainToGroupDistributor()
125
+ self.fuser = FeatureFusionBlock(1024, 256, value_dim, value_dim)
126
+ if hidden_dim > 0:
127
+ self.hidden_reinforce = HiddenReinforcer(value_dim, hidden_dim)
128
+ else:
129
+ self.hidden_reinforce = None
130
+
131
+ def forward(self, image, image_feat_f16, h, masks, others, is_deep_update=True):
132
+ # image_feat_f16 is the feature from the key encoder
133
+ if not self.single_object:
134
+ g = torch.stack([masks, others], 2)
135
+ else:
136
+ g = masks.unsqueeze(2)
137
+ g = self.distributor(image, g)
138
+
139
+ batch_size, num_objects = g.shape[:2]
140
+ g = g.flatten(start_dim=0, end_dim=1)
141
+
142
+ g = self.conv1(g)
143
+ g = self.bn1(g) # 1/2, 64
144
+ g = self.maxpool(g) # 1/4, 64
145
+ g = self.relu(g)
146
+
147
+ g = self.layer1(g) # 1/4
148
+ g = self.layer2(g) # 1/8
149
+ g = self.layer3(g) # 1/16
150
+
151
+ g = g.view(batch_size, num_objects, *g.shape[1:])
152
+ g = self.fuser(image_feat_f16, g)
153
+
154
+ if is_deep_update and self.hidden_reinforce is not None:
155
+ h = self.hidden_reinforce(g, h)
156
+
157
+ return g, h
158
+
159
+
160
+ class KeyEncoder(nn.Module):
161
+ def __init__(self):
162
+ super().__init__()
163
+ network = resnet.resnet50(pretrained=True)
164
+ self.conv1 = network.conv1
165
+ self.bn1 = network.bn1
166
+ self.relu = network.relu # 1/2, 64
167
+ self.maxpool = network.maxpool
168
+
169
+ self.res2 = network.layer1 # 1/4, 256
170
+ self.layer2 = network.layer2 # 1/8, 512
171
+ self.layer3 = network.layer3 # 1/16, 1024
172
+
173
+ def forward(self, f):
174
+ x = self.conv1(f)
175
+ x = self.bn1(x)
176
+ x = self.relu(x) # 1/2, 64
177
+ x = self.maxpool(x) # 1/4, 64
178
+ f4 = self.res2(x) # 1/4, 256
179
+ f8 = self.layer2(f4) # 1/8, 512
180
+ f16 = self.layer3(f8) # 1/16, 1024
181
+
182
+ return f16, f8, f4
183
+
184
+
185
+ class UpsampleBlock(nn.Module):
186
+ def __init__(self, skip_dim, g_up_dim, g_out_dim, scale_factor=2):
187
+ super().__init__()
188
+ self.skip_conv = nn.Conv2d(skip_dim, g_up_dim, kernel_size=3, padding=1)
189
+ self.distributor = MainToGroupDistributor(method="add")
190
+ self.out_conv = GroupResBlock(g_up_dim, g_out_dim)
191
+ self.scale_factor = scale_factor
192
+
193
+ def forward(self, skip_f, up_g):
194
+ skip_f = self.skip_conv(skip_f)
195
+ g = upsample_groups(up_g, ratio=self.scale_factor)
196
+ g = self.distributor(skip_f, g)
197
+ g = self.out_conv(g)
198
+ return g
199
+
200
+
201
+ class KeyProjection(nn.Module):
202
+ def __init__(self, in_dim, keydim):
203
+ super().__init__()
204
+
205
+ self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
206
+ # shrinkage
207
+ self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
208
+ # selection
209
+ self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
210
+
211
+ nn.init.orthogonal_(self.key_proj.weight.data)
212
+ nn.init.zeros_(self.key_proj.bias.data)
213
+
214
+ def forward(self, x, need_s, need_e):
215
+ shrinkage = self.d_proj(x) ** 2 + 1 if (need_s) else None
216
+ selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
217
+
218
+ return self.key_proj(x), shrinkage, selection
219
+
220
+
221
+ class Decoder(nn.Module):
222
+ def __init__(self, val_dim, hidden_dim):
223
+ super().__init__()
224
+
225
+ self.fuser = FeatureFusionBlock(1024, val_dim + hidden_dim, 512, 512)
226
+ if hidden_dim > 0:
227
+ self.hidden_update = HiddenUpdater([512, 256, 256 + 1], 256, hidden_dim)
228
+ else:
229
+ self.hidden_update = None
230
+
231
+ self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8
232
+ self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4
233
+
234
+ self.pred = nn.Conv2d(256, 1, kernel_size=3, padding=1, stride=1)
235
+
236
+ def forward(self, f16, f8, f4, hidden_state, memory_readout, h_out=True):
237
+ batch_size, num_objects = memory_readout.shape[:2]
238
+
239
+ if self.hidden_update is not None:
240
+ g16 = self.fuser(f16, torch.cat([memory_readout, hidden_state], 2))
241
+ else:
242
+ g16 = self.fuser(f16, memory_readout)
243
+
244
+ g8 = self.up_16_8(f8, g16)
245
+ g4 = self.up_8_4(f4, g8)
246
+ logits = self.pred(F.relu(g4.flatten(start_dim=0, end_dim=1)))
247
+
248
+ if h_out and self.hidden_update is not None:
249
+ g4 = torch.cat(
250
+ [g4, logits.view(batch_size, num_objects, 1, *logits.shape[-2:])], 2
251
+ )
252
+ hidden_state = self.hidden_update([g16, g8, g4], hidden_state)
253
+ else:
254
+ hidden_state = None
255
+
256
+ logits = F.interpolate(
257
+ logits, scale_factor=4, mode="bilinear", align_corners=False
258
+ )
259
+ logits = logits.view(batch_size, num_objects, *logits.shape[-2:])
260
+
261
+ return hidden_state, logits
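KeyProjection above produces the three per-pixel quantities the memory reader needs from the stride-16 backbone feature; a quick shape check (the feature size is illustrative):

    import torch

    proj = KeyProjection(1024, keydim=64)        # as instantiated in network.py
    f16 = torch.randn(1, 1024, 30, 54)           # stride-16 feature map
    key, shrinkage, selection = proj(f16, need_s=True, need_e=True)
    print(key.shape, shrinkage.shape, selection.shape)
    # (1, 64, 30, 54)  (1, 1, 30, 54)  (1, 64, 30, 54)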
tracker/model/network.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ This file defines XMem, the highest level nn.Module interface
3
+ During training, it is used by trainer.py
4
+ During evaluation, it is used by inference_core.py
5
+
6
+ It further depends on modules.py which gives more detailed implementations of sub-modules
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from model.aggregate import aggregate
13
+ from model.modules import *
14
+ from model.memory_util import *
15
+
16
+
17
+ class XMem(nn.Module):
18
+ def __init__(self, config, model_path=None, map_location=None):
19
+ """
20
+ model_path/map_location are used in evaluation only
21
+ map_location is for converting models saved in cuda to cpu
22
+ """
23
+ super().__init__()
24
+ model_weights = self.init_hyperparameters(config, model_path, map_location)
25
+
26
+ self.single_object = config.get("single_object", False)
27
+ print(f"Single object mode: {self.single_object}")
28
+
29
+ self.key_encoder = KeyEncoder()
30
+ self.value_encoder = ValueEncoder(
31
+ self.value_dim, self.hidden_dim, self.single_object
32
+ )
33
+
34
+ # Projection from f16 feature space to key/value space
35
+ self.key_proj = KeyProjection(1024, self.key_dim)
36
+
37
+ self.decoder = Decoder(self.value_dim, self.hidden_dim)
38
+
39
+ if model_weights is not None:
40
+ self.load_weights(model_weights, init_as_zero_if_needed=True)
41
+
42
+ def encode_key(self, frame, need_sk=True, need_ek=True):
43
+ # Determine input shape
44
+ if len(frame.shape) == 5:
45
+ # shape is b*t*c*h*w
46
+ need_reshape = True
47
+ b, t = frame.shape[:2]
48
+ # flatten so that we can feed them into a 2D CNN
49
+ frame = frame.flatten(start_dim=0, end_dim=1)
50
+ elif len(frame.shape) == 4:
51
+ # shape is b*c*h*w
52
+ need_reshape = False
53
+ else:
54
+ raise NotImplementedError
55
+
56
+ f16, f8, f4 = self.key_encoder(frame)
57
+ key, shrinkage, selection = self.key_proj(f16, need_sk, need_ek)
58
+
59
+ if need_reshape:
60
+ # B*C*T*H*W
61
+ key = key.view(b, t, *key.shape[-3:]).transpose(1, 2).contiguous()
62
+ if shrinkage is not None:
63
+ shrinkage = (
64
+ shrinkage.view(b, t, *shrinkage.shape[-3:])
65
+ .transpose(1, 2)
66
+ .contiguous()
67
+ )
68
+ if selection is not None:
69
+ selection = (
70
+ selection.view(b, t, *selection.shape[-3:])
71
+ .transpose(1, 2)
72
+ .contiguous()
73
+ )
74
+
75
+ # B*T*C*H*W
76
+ f16 = f16.view(b, t, *f16.shape[-3:])
77
+ f8 = f8.view(b, t, *f8.shape[-3:])
78
+ f4 = f4.view(b, t, *f4.shape[-3:])
79
+
80
+ return key, shrinkage, selection, f16, f8, f4
81
+
82
+ def encode_value(self, frame, image_feat_f16, h16, masks, is_deep_update=True):
83
+ num_objects = masks.shape[1]
84
+ if num_objects != 1:
85
+ others = torch.cat(
86
+ [
87
+ torch.sum(
88
+ masks[:, [j for j in range(num_objects) if i != j]],
89
+ dim=1,
90
+ keepdim=True,
91
+ )
92
+ for i in range(num_objects)
93
+ ],
94
+ 1,
95
+ )
96
+ else:
97
+ others = torch.zeros_like(masks)
98
+
99
+ g16, h16 = self.value_encoder(
100
+ frame, image_feat_f16, h16, masks, others, is_deep_update
101
+ )
102
+
103
+ return g16, h16
104
+
105
+ # Used in training only.
106
+ # This step is replaced by MemoryManager in test time
107
+ def read_memory(
108
+ self, query_key, query_selection, memory_key, memory_shrinkage, memory_value
109
+ ):
110
+ """
111
+ query_key : B * CK * H * W
112
+ query_selection : B * CK * H * W
113
+ memory_key : B * CK * T * H * W
114
+ memory_shrinkage: B * 1 * T * H * W
115
+ memory_value : B * num_objects * CV * T * H * W
116
+ """
117
+ batch_size, num_objects = memory_value.shape[:2]
118
+ memory_value = memory_value.flatten(start_dim=1, end_dim=2)
119
+
120
+ affinity = get_affinity(
121
+ memory_key, memory_shrinkage, query_key, query_selection
122
+ )
123
+ memory = readout(affinity, memory_value)
124
+ memory = memory.view(
125
+ batch_size, num_objects, self.value_dim, *memory.shape[-2:]
126
+ )
127
+
128
+ return memory
129
+
130
+ def segment(
131
+ self,
132
+ multi_scale_features,
133
+ memory_readout,
134
+ hidden_state,
135
+ selector=None,
136
+ h_out=True,
137
+ strip_bg=True,
138
+ ):
139
+
140
+ hidden_state, logits = self.decoder(
141
+ *multi_scale_features, hidden_state, memory_readout, h_out=h_out
142
+ )
143
+ prob = torch.sigmoid(logits)
144
+ if selector is not None:
145
+ prob = prob * selector
146
+
147
+ logits, prob = aggregate(prob, dim=1, return_logits=True)
148
+ if strip_bg:
149
+ # Strip away the background
150
+ prob = prob[:, 1:]
151
+
152
+ return hidden_state, logits, prob
153
+
154
+ def forward(self, mode, *args, **kwargs):
155
+ if mode == "encode_key":
156
+ return self.encode_key(*args, **kwargs)
157
+ elif mode == "encode_value":
158
+ return self.encode_value(*args, **kwargs)
159
+ elif mode == "read_memory":
160
+ return self.read_memory(*args, **kwargs)
161
+ elif mode == "segment":
162
+ return self.segment(*args, **kwargs)
163
+ else:
164
+ raise NotImplementedError
165
+
166
+ def init_hyperparameters(self, config, model_path=None, map_location=None):
167
+ """
168
+ Init three hyperparameters: key_dim, value_dim, and hidden_dim
169
+ If model_path is provided, we load these from the model weights
170
+ The actual parameters are then updated to the config in-place
171
+
172
+ Otherwise we load it either from the config or default
173
+ """
174
+ if model_path is not None:
175
+ # load the model and key/value/hidden dimensions with some hacks
176
+ # config is updated with the loaded parameters
177
+ model_weights = torch.load(model_path, map_location=map_location)
178
+ self.key_dim = model_weights["key_proj.key_proj.weight"].shape[0]
179
+ self.value_dim = model_weights[
180
+ "value_encoder.fuser.block2.conv2.weight"
181
+ ].shape[0]
182
+ self.disable_hidden = (
183
+ "decoder.hidden_update.transform.weight" not in model_weights
184
+ )
185
+ if self.disable_hidden:
186
+ self.hidden_dim = 0
187
+ else:
188
+ self.hidden_dim = (
189
+ model_weights["decoder.hidden_update.transform.weight"].shape[0]
190
+ // 3
191
+ )
192
+ print(
193
+ f"Hyperparameters read from the model weights: "
194
+ f"C^k={self.key_dim}, C^v={self.value_dim}, C^h={self.hidden_dim}"
195
+ )
196
+ else:
197
+ model_weights = None
198
+ # load dimensions from config or default
199
+ if "key_dim" not in config:
200
+ self.key_dim = 64
201
+ print(f"key_dim not found in config. Set to default {self.key_dim}")
202
+ else:
203
+ self.key_dim = config["key_dim"]
204
+
205
+ if "value_dim" not in config:
206
+ self.value_dim = 512
207
+ print(f"value_dim not found in config. Set to default {self.value_dim}")
208
+ else:
209
+ self.value_dim = config["value_dim"]
210
+
211
+ if "hidden_dim" not in config:
212
+ self.hidden_dim = 64
213
+ print(
214
+ f"hidden_dim not found in config. Set to default {self.hidden_dim}"
215
+ )
216
+ else:
217
+ self.hidden_dim = config["hidden_dim"]
218
+
219
+ self.disable_hidden = self.hidden_dim <= 0
220
+
221
+ config["key_dim"] = self.key_dim
222
+ config["value_dim"] = self.value_dim
223
+ config["hidden_dim"] = self.hidden_dim
224
+
225
+ return model_weights
226
+
227
+ def load_weights(self, src_dict, init_as_zero_if_needed=False):
228
+ # Maps SO weight (without other_mask) to MO weight (with other_mask)
229
+ for k in list(src_dict.keys()):
230
+ if k == "value_encoder.conv1.weight":
231
+ if src_dict[k].shape[1] == 4:
232
+ print("Converting weights from single object to multiple objects.")
233
+ pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
234
+ if not init_as_zero_if_needed:
235
+ print("Randomly initialized padding.")
236
+ nn.init.orthogonal_(pads)
237
+ else:
238
+ print("Zero-initialized padding.")
239
+ src_dict[k] = torch.cat([src_dict[k], pads], 1)
240
+
241
+ self.load_state_dict(src_dict)
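A hedged construction sketch: when a model_path is given, init_hyperparameters() reads C^k, C^v, and C^h from the checkpoint and writes them back into the config dict, which MemoryManager later relies on. The checkpoint path below is a placeholder, not the one the app downloads:

    import torch

    config = {"single_object": False}
    network = XMem(config, model_path="checkpoints/XMem.pth",
                   map_location=torch.device("cpu")).eval()
    # config now also contains "key_dim", "value_dim", "hidden_dim"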
tracker/model/resnet.py ADDED
@@ -0,0 +1,191 @@
1
+ """
2
+ resnet.py - A modified ResNet structure
3
+ We append extra channels to the first conv by some network surgery
4
+ """
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils import model_zoo
12
+
13
+
14
+ def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15
+ new_dict = OrderedDict()
16
+
17
+ for k1, v1 in target.state_dict().items():
18
+ if not "num_batches_tracked" in k1:
19
+ if k1 in source_state:
20
+ tar_v = source_state[k1]
21
+
22
+ if v1.shape != tar_v.shape:
23
+ # Init the new segmentation channel with zeros
24
+ # print(v1.shape, tar_v.shape)
25
+ c, _, w, h = v1.shape
26
+ pads = torch.zeros((c, extra_dim, w, h), device=tar_v.device)
27
+ nn.init.orthogonal_(pads)
28
+ tar_v = torch.cat([tar_v, pads], 1)
29
+
30
+ new_dict[k1] = tar_v
31
+
32
+ target.load_state_dict(new_dict)
33
+
34
+
35
+ model_urls = {
36
+ "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
37
+ "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
38
+ }
39
+
40
+
41
+ def conv3x3(in_planes, out_planes, stride=1, dilation=1):
42
+ return nn.Conv2d(
43
+ in_planes,
44
+ out_planes,
45
+ kernel_size=3,
46
+ stride=stride,
47
+ padding=dilation,
48
+ dilation=dilation,
49
+ bias=False,
50
+ )
51
+
52
+
53
+ class BasicBlock(nn.Module):
54
+ expansion = 1
55
+
56
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
57
+ super(BasicBlock, self).__init__()
58
+ self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
59
+ self.bn1 = nn.BatchNorm2d(planes)
60
+ self.relu = nn.ReLU(inplace=True)
61
+ self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
62
+ self.bn2 = nn.BatchNorm2d(planes)
63
+ self.downsample = downsample
64
+ self.stride = stride
65
+
66
+ def forward(self, x):
67
+ residual = x
68
+
69
+ out = self.conv1(x)
70
+ out = self.bn1(out)
71
+ out = self.relu(out)
72
+
73
+ out = self.conv2(out)
74
+ out = self.bn2(out)
75
+
76
+ if self.downsample is not None:
77
+ residual = self.downsample(x)
78
+
79
+ out += residual
80
+ out = self.relu(out)
81
+
82
+ return out
83
+
84
+
85
+ class Bottleneck(nn.Module):
86
+ expansion = 4
87
+
88
+ def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
89
+ super(Bottleneck, self).__init__()
90
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
91
+ self.bn1 = nn.BatchNorm2d(planes)
92
+ self.conv2 = nn.Conv2d(
93
+ planes,
94
+ planes,
95
+ kernel_size=3,
96
+ stride=stride,
97
+ dilation=dilation,
98
+ padding=dilation,
99
+ bias=False,
100
+ )
101
+ self.bn2 = nn.BatchNorm2d(planes)
102
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
103
+ self.bn3 = nn.BatchNorm2d(planes * 4)
104
+ self.relu = nn.ReLU(inplace=True)
105
+ self.downsample = downsample
106
+ self.stride = stride
107
+
108
+ def forward(self, x):
109
+ residual = x
110
+
111
+ out = self.conv1(x)
112
+ out = self.bn1(out)
113
+ out = self.relu(out)
114
+
115
+ out = self.conv2(out)
116
+ out = self.bn2(out)
117
+ out = self.relu(out)
118
+
119
+ out = self.conv3(out)
120
+ out = self.bn3(out)
121
+
122
+ if self.downsample is not None:
123
+ residual = self.downsample(x)
124
+
125
+ out += residual
126
+ out = self.relu(out)
127
+
128
+ return out
129
+
130
+
131
+ class ResNet(nn.Module):
132
+ def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
133
+ self.inplanes = 64
134
+ super(ResNet, self).__init__()
135
+ self.conv1 = nn.Conv2d(
136
+ 3 + extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False
137
+ )
138
+ self.bn1 = nn.BatchNorm2d(64)
139
+ self.relu = nn.ReLU(inplace=True)
140
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
141
+ self.layer1 = self._make_layer(block, 64, layers[0])
142
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
143
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
144
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
145
+
146
+ for m in self.modules():
147
+ if isinstance(m, nn.Conv2d):
148
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
149
+ m.weight.data.normal_(0, math.sqrt(2.0 / n))
150
+ elif isinstance(m, nn.BatchNorm2d):
151
+ m.weight.data.fill_(1)
152
+ m.bias.data.zero_()
153
+
154
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
155
+ downsample = None
156
+ if stride != 1 or self.inplanes != planes * block.expansion:
157
+ downsample = nn.Sequential(
158
+ nn.Conv2d(
159
+ self.inplanes,
160
+ planes * block.expansion,
161
+ kernel_size=1,
162
+ stride=stride,
163
+ bias=False,
164
+ ),
165
+ nn.BatchNorm2d(planes * block.expansion),
166
+ )
167
+
168
+ layers = [block(self.inplanes, planes, stride, downsample)]
169
+ self.inplanes = planes * block.expansion
170
+ for i in range(1, blocks):
171
+ layers.append(block(self.inplanes, planes, dilation=dilation))
172
+
173
+ return nn.Sequential(*layers)
174
+
175
+
176
+ def resnet18(pretrained=True, extra_dim=0):
177
+ model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
178
+ if pretrained:
179
+ load_weights_add_extra_dim(
180
+ model, model_zoo.load_url(model_urls["resnet18"]), extra_dim
181
+ )
182
+ return model
183
+
184
+
185
+ def resnet50(pretrained=True, extra_dim=0):
186
+ model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
187
+ if pretrained:
188
+ load_weights_add_extra_dim(
189
+ model, model_zoo.load_url(model_urls["resnet50"]), extra_dim
190
+ )
191
+ return model
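A minimal usage sketch for the modified ResNet above. It assumes tracker/ is on the Python path (so the module imports as model.resnet), uses pretrained=False to skip the torchvision download, and pulls features layer by layer since the class defines no forward (callers pick the layers they need):

    import torch
    from model.resnet import resnet18

    net = resnet18(pretrained=False, extra_dim=1)    # conv1 now expects 3 + 1 input channels
    x = torch.randn(2, 4, 224, 224)                  # RGB plus one extra (e.g. mask) channel
    f = net.maxpool(net.relu(net.bn1(net.conv1(x))))
    f = net.layer4(net.layer3(net.layer2(net.layer1(f))))
    print(f.shape)                                   # torch.Size([2, 512, 7, 7])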
tracker/model/trainer.py ADDED
@@ -0,0 +1,302 @@
1
+ """
2
+ trainer.py - wrapper and utility functions for network training
3
+ Compute loss, back-prop, update parameters, logging, etc.
4
+ """
5
+ import datetime
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.optim as optim
12
+
13
+ from model.network import XMem
14
+ from model.losses import LossComputer
15
+ from util.log_integrator import Integrator
16
+ from util.image_saver import pool_pairs
17
+
18
+
19
+ class XMemTrainer:
20
+ def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1):
21
+ self.config = config
22
+ self.num_frames = config["num_frames"]
23
+ self.num_ref_frames = config["num_ref_frames"]
24
+ self.deep_update_prob = config["deep_update_prob"]
25
+ self.local_rank = local_rank
26
+
27
+ self.XMem = nn.parallel.DistributedDataParallel(
28
+ XMem(config).cuda(),
29
+ device_ids=[local_rank],
30
+ output_device=local_rank,
31
+ broadcast_buffers=False,
32
+ )
33
+
34
+ # Set up logger when local_rank=0
35
+ self.logger = logger
36
+ self.save_path = save_path
37
+ if logger is not None:
38
+ self.last_time = time.time()
39
+ self.logger.log_string(
40
+ "model_size",
41
+ str(sum([param.nelement() for param in self.XMem.parameters()])),
42
+ )
43
+ self.train_integrator = Integrator(
44
+ self.logger, distributed=True, local_rank=local_rank, world_size=world_size
45
+ )
46
+ self.loss_computer = LossComputer(config)
47
+
48
+ self.train()
49
+ self.optimizer = optim.AdamW(
50
+ filter(lambda p: p.requires_grad, self.XMem.parameters()),
51
+ lr=config["lr"],
52
+ weight_decay=config["weight_decay"],
53
+ )
54
+ self.scheduler = optim.lr_scheduler.MultiStepLR(
55
+ self.optimizer, config["steps"], config["gamma"]
56
+ )
57
+ if config["amp"]:
58
+ self.scaler = torch.cuda.amp.GradScaler()
59
+
60
+ # Logging info
61
+ self.log_text_interval = config["log_text_interval"]
62
+ self.log_image_interval = config["log_image_interval"]
63
+ self.save_network_interval = config["save_network_interval"]
64
+ self.save_checkpoint_interval = config["save_checkpoint_interval"]
65
+ if config["debug"]:
66
+ self.log_text_interval = self.log_image_interval = 1
67
+
68
+ def do_pass(self, data, max_it, it=0):
69
+ # No need to store the gradient outside training
70
+ torch.set_grad_enabled(self._is_train)
71
+
72
+ for k, v in data.items():
73
+ if type(v) != list and type(v) != dict and type(v) != int:
74
+ data[k] = v.cuda(non_blocking=True)
75
+
76
+ out = {}
77
+ frames = data["rgb"]
78
+ first_frame_gt = data["first_frame_gt"].float()
79
+ b = frames.shape[0]
80
+ num_filled_objects = [o.item() for o in data["info"]["num_objects"]]
81
+ num_objects = first_frame_gt.shape[2]
82
+ selector = data["selector"].unsqueeze(2).unsqueeze(2)
83
+
84
+ global_avg = 0
85
+
86
+ with torch.cuda.amp.autocast(enabled=self.config["amp"]):
87
+ # image features never change, compute once
88
+ key, shrinkage, selection, f16, f8, f4 = self.XMem("encode_key", frames)
89
+
90
+ filler_one = torch.zeros(1, dtype=torch.int64)
91
+ hidden = torch.zeros(
92
+ (b, num_objects, self.config["hidden_dim"], *key.shape[-2:])
93
+ )
94
+ v16, hidden = self.XMem(
95
+ "encode_value", frames[:, 0], f16[:, 0], hidden, first_frame_gt[:, 0]
96
+ )
97
+ values = v16.unsqueeze(3) # add the time dimension
98
+
99
+ for ti in range(1, self.num_frames):
100
+ if ti <= self.num_ref_frames:
101
+ ref_values = values
102
+ ref_keys = key[:, :, :ti]
103
+ ref_shrinkage = (
104
+ shrinkage[:, :, :ti] if shrinkage is not None else None
105
+ )
106
+ else:
107
+ # pick num_ref_frames random frames
108
+ # this is not very efficient but I think we would
109
+ # need broadcasting in gather which we don't have
110
+ indices = [
111
+ torch.cat(
112
+ [
113
+ filler_one,
114
+ torch.randperm(ti - 1)[: self.num_ref_frames - 1] + 1,
115
+ ]
116
+ )
117
+ for _ in range(b)
118
+ ]
119
+ ref_values = torch.stack(
120
+ [values[bi, :, :, indices[bi]] for bi in range(b)], 0
121
+ )
122
+ ref_keys = torch.stack(
123
+ [key[bi, :, indices[bi]] for bi in range(b)], 0
124
+ )
125
+ ref_shrinkage = (
126
+ torch.stack(
127
+ [shrinkage[bi, :, indices[bi]] for bi in range(b)], 0
128
+ )
129
+ if shrinkage is not None
130
+ else None
131
+ )
132
+
133
+ # Segment frame ti
134
+ memory_readout = self.XMem(
135
+ "read_memory",
136
+ key[:, :, ti],
137
+ selection[:, :, ti] if selection is not None else None,
138
+ ref_keys,
139
+ ref_shrinkage,
140
+ ref_values,
141
+ )
142
+ hidden, logits, masks = self.XMem(
143
+ "segment",
144
+ (f16[:, ti], f8[:, ti], f4[:, ti]),
145
+ memory_readout,
146
+ hidden,
147
+ selector,
148
+ h_out=(ti < (self.num_frames - 1)),
149
+ )
150
+
151
+ # No need to encode the last frame
152
+ if ti < (self.num_frames - 1):
153
+ is_deep_update = np.random.rand() < self.deep_update_prob
154
+ v16, hidden = self.XMem(
155
+ "encode_value",
156
+ frames[:, ti],
157
+ f16[:, ti],
158
+ hidden,
159
+ masks,
160
+ is_deep_update=is_deep_update,
161
+ )
162
+ values = torch.cat([values, v16.unsqueeze(3)], 3)
163
+
164
+ out[f"masks_{ti}"] = masks
165
+ out[f"logits_{ti}"] = logits
166
+
167
+ if self._do_log or self._is_train:
168
+ losses = self.loss_computer.compute(
169
+ {**data, **out}, num_filled_objects, it
170
+ )
171
+
172
+ # Logging
173
+ if self._do_log:
174
+ self.integrator.add_dict(losses)
175
+ if self._is_train:
176
+ if it % self.log_image_interval == 0 and it != 0:
177
+ if self.logger is not None:
178
+ images = {**data, **out}
179
+ size = (384, 384)
180
+ self.logger.log_cv2(
181
+ "train/pairs",
182
+ pool_pairs(images, size, num_filled_objects),
183
+ it,
184
+ )
185
+
186
+ if self._is_train:
187
+
188
+ if (it) % self.log_text_interval == 0 and it != 0:
189
+ time_spent = time.time() - self.last_time
190
+
191
+ if self.logger is not None:
192
+ self.logger.log_scalar(
193
+ "train/lr", self.scheduler.get_last_lr()[0], it
194
+ )
195
+ self.logger.log_metrics(
196
+ "train", "time", (time_spent) / self.log_text_interval, it
197
+ )
198
+
199
+ global_avg = 0.5 * (global_avg) + 0.5 * (time_spent)
200
+ eta_seconds = global_avg * (max_it - it) / 100
201
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
202
+ print(f"ETA: {eta_string}")
203
+
204
+ self.last_time = time.time()
205
+ self.train_integrator.finalize("train", it)
206
+ self.train_integrator.reset_except_hooks()
207
+
208
+ if it % self.save_network_interval == 0 and it != 0:
209
+ if self.logger is not None:
210
+ self.save_network(it)
211
+
212
+ if it % self.save_checkpoint_interval == 0 and it != 0:
213
+ if self.logger is not None:
214
+ self.save_checkpoint(it)
215
+
216
+ # Backward pass
217
+ self.optimizer.zero_grad(set_to_none=True)
218
+ if self.config["amp"]:
219
+ self.scaler.scale(losses["total_loss"]).backward()
220
+ self.scaler.step(self.optimizer)
221
+ self.scaler.update()
222
+ else:
223
+ losses["total_loss"].backward()
224
+ self.optimizer.step()
225
+
226
+ self.scheduler.step()
227
+
228
+ def save_network(self, it):
229
+ if self.save_path is None:
230
+ print("Saving has been disabled.")
231
+ return
232
+
233
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
234
+ model_path = f"{self.save_path}_{it}.pth"
235
+ torch.save(self.XMem.module.state_dict(), model_path)
236
+ print(f"Network saved to {model_path}.")
237
+
238
+ def save_checkpoint(self, it):
239
+ if self.save_path is None:
240
+ print("Saving has been disabled.")
241
+ return
242
+
243
+ os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
244
+ checkpoint_path = f"{self.save_path}_checkpoint_{it}.pth"
245
+ checkpoint = {
246
+ "it": it,
247
+ "network": self.XMem.module.state_dict(),
248
+ "optimizer": self.optimizer.state_dict(),
249
+ "scheduler": self.scheduler.state_dict(),
250
+ }
251
+ torch.save(checkpoint, checkpoint_path)
252
+ print(f"Checkpoint saved to {checkpoint_path}.")
253
+
254
+ def load_checkpoint(self, path):
255
+ # This method loads everything and should be used to resume training
256
+ map_location = "cuda:%d" % self.local_rank
257
+ checkpoint = torch.load(path, map_location={"cuda:0": map_location})
258
+
259
+ it = checkpoint["it"]
260
+ network = checkpoint["network"]
261
+ optimizer = checkpoint["optimizer"]
262
+ scheduler = checkpoint["scheduler"]
263
+
264
+ map_location = "cuda:%d" % self.local_rank
265
+ self.XMem.module.load_state_dict(network)
266
+ self.optimizer.load_state_dict(optimizer)
267
+ self.scheduler.load_state_dict(scheduler)
268
+
269
+ print("Network weights, optimizer states, and scheduler states loaded.")
270
+
271
+ return it
272
+
273
+ def load_network_in_memory(self, src_dict):
274
+ self.XMem.module.load_weights(src_dict)
275
+ print("Network weight loaded from memory.")
276
+
277
+ def load_network(self, path):
278
+ # This method loads only the network weights and should be used to load a pretrained model
279
+ map_location = "cuda:%d" % self.local_rank
280
+ src_dict = torch.load(path, map_location={"cuda:0": map_location})
281
+
282
+ self.load_network_in_memory(src_dict)
283
+ print(f"Network weight loaded from {path}")
284
+
285
+ def train(self):
286
+ self._is_train = True
287
+ self._do_log = True
288
+ self.integrator = self.train_integrator
289
+ self.XMem.eval()
290
+ return self
291
+
292
+ def val(self):
293
+ self._is_train = False
294
+ self._do_log = True
295
+ self.XMem.eval()
296
+ return self
297
+
298
+ def test(self):
299
+ self._is_train = False
300
+ self._do_log = False
301
+ self.XMem.eval()
302
+ return self
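As a hedged aside, a checkpoint written by save_checkpoint above is a plain dict and can be inspected offline; the path here is hypothetical:

    import torch

    ckpt = torch.load("saves/xmem_checkpoint_10000.pth", map_location="cpu")  # hypothetical path
    print(ckpt["it"])   # iteration counter stored alongside the weights
    print(list(ckpt))   # ['it', 'network', 'optimizer', 'scheduler']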
tracker/util/__init__.py ADDED
File without changes
tracker/util/mask_mapper.py ADDED
@@ -0,0 +1,87 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def all_to_onehot(masks, labels):
6
+ if len(masks.shape) == 3:
7
+ Ms = np.zeros(
8
+ (len(labels), masks.shape[0], masks.shape[1], masks.shape[2]),
9
+ dtype=np.uint8,
10
+ )
11
+ else:
12
+ Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
13
+
14
+ for ni, l in enumerate(labels):
15
+ Ms[ni] = (masks == l).astype(np.uint8)
16
+
17
+ return Ms
18
+
19
+
20
+ class MaskMapper:
21
+ """
22
+ This class is used to convert an indexed mask to a one-hot representation.
23
+ It also takes care of remapping non-contiguous indices.
24
+ It has two modes:
25
+ 1. Default. Only masks with new indices are supposed to go into the remapper.
26
+ This is also the case for YouTubeVOS.
27
+ i.e., regions with index 0 are not "background", but "don't care".
28
+
29
+ 2. Exhaustive. Regions with index 0 are considered "background".
30
+ Every single pixel is considered to be "labeled".
31
+ """
32
+
33
+ def __init__(self):
34
+ self.labels = []
35
+ self.remappings = {}
36
+
37
+ # if coherent, no mapping is required
38
+ self.coherent = True
39
+
40
+ def clear_labels(self):
41
+ self.labels = []
42
+ self.remappings = {}
43
+ # if coherent, no mapping is required
44
+ self.coherent = True
45
+
46
+ def convert_mask(self, mask, exhaustive=False):
47
+ # mask is in index representation, H*W numpy array
48
+ labels = np.unique(mask).astype(np.uint8)
49
+ labels = labels[labels != 0].tolist()
50
+
51
+ new_labels = list(set(labels) - set(self.labels))
52
+ if not exhaustive:
53
+ assert len(new_labels) == len(
54
+ labels
55
+ ), "Old labels found in non-exhaustive mode"
56
+
57
+ # add new remappings
58
+ for i, l in enumerate(new_labels):
59
+ self.remappings[l] = i + len(self.labels) + 1
60
+ if self.coherent and i + len(self.labels) + 1 != l:
61
+ self.coherent = False
62
+
63
+ if exhaustive:
64
+ new_mapped_labels = range(1, len(self.labels) + len(new_labels) + 1)
65
+ else:
66
+ if self.coherent:
67
+ new_mapped_labels = new_labels
68
+ else:
69
+ new_mapped_labels = range(
70
+ len(self.labels) + 1, len(self.labels) + len(new_labels) + 1
71
+ )
72
+
73
+ self.labels.extend(new_labels)
74
+ mask = torch.from_numpy(all_to_onehot(mask, self.labels)).float()
75
+
76
+ # mask num_objects*H*W
77
+ return mask, new_mapped_labels
78
+
79
+ def remap_index_mask(self, mask):
80
+ # mask is in index representation, H*W numpy array
81
+ if self.coherent:
82
+ return mask
83
+
84
+ new_mask = np.zeros_like(mask)
85
+ for l, i in self.remappings.items():
86
+ new_mask[mask == i] = l
87
+ return new_mask
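A small sketch of MaskMapper on a toy index mask, assuming tracker/ is on the path so the module imports as util.mask_mapper:

    import numpy as np
    from util.mask_mapper import MaskMapper

    mapper = MaskMapper()
    mask = np.zeros((4, 4), dtype=np.uint8)
    mask[0, 0] = 3                                  # one object with the non-contiguous label 3
    onehot, new_labels = mapper.convert_mask(mask)
    print(onehot.shape)                             # torch.Size([1, 4, 4]), one channel per label
    print(list(new_labels))                         # [1], i.e. label 3 remapped to 1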
tracker/util/range_transform.py ADDED
@@ -0,0 +1,12 @@
1
+ import torchvision.transforms as transforms
2
+
3
+ im_mean = (124, 116, 104)
4
+
5
+ im_normalization = transforms.Normalize(
6
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
7
+ )
8
+
9
+ inv_im_trans = transforms.Normalize(
10
+ mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
11
+ std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
12
+ )
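A quick sanity sketch: im_normalization followed by inv_im_trans is the identity up to floating-point error (same path assumption as above):

    import torch
    from util.range_transform import im_normalization, inv_im_trans

    img = torch.rand(3, 8, 8)
    print(torch.allclose(inv_im_trans(im_normalization(img)), img, atol=1e-5))  # True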
tracker/util/tensor_util.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch.nn.functional as F
2
+
3
+
4
+ def compute_tensor_iu(seg, gt):
5
+ intersection = (seg & gt).float().sum()
6
+ union = (seg | gt).float().sum()
7
+
8
+ return intersection, union
9
+
10
+
11
+ def compute_tensor_iou(seg, gt):
12
+ intersection, union = compute_tensor_iu(seg, gt)
13
+ iou = (intersection + 1e-6) / (union + 1e-6)
14
+
15
+ return iou
16
+
17
+
18
+ # STM
19
+ def pad_divide_by(in_img, d):
20
+ h, w = in_img.shape[-2:]
21
+
22
+ if h % d > 0:
23
+ new_h = h + d - h % d
24
+ else:
25
+ new_h = h
26
+ if w % d > 0:
27
+ new_w = w + d - w % d
28
+ else:
29
+ new_w = w
30
+ lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
31
+ lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
32
+ pad_array = (int(lw), int(uw), int(lh), int(uh))
33
+ out = F.pad(in_img, pad_array)
34
+ return out, pad_array
35
+
36
+
37
+ def unpad(img, pad):
38
+ if len(img.shape) == 4:
39
+ if pad[2] + pad[3] > 0:
40
+ img = img[:, :, pad[2] : -pad[3], :]
41
+ if pad[0] + pad[1] > 0:
42
+ img = img[:, :, :, pad[0] : -pad[1]]
43
+ elif len(img.shape) == 3:
44
+ if pad[2] + pad[3] > 0:
45
+ img = img[:, pad[2] : -pad[3], :]
46
+ if pad[0] + pad[1] > 0:
47
+ img = img[:, :, pad[0] : -pad[1]]
48
+ else:
49
+ raise NotImplementedError
50
+ return img
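A round-trip sketch for pad_divide_by and unpad (same path assumption as above):

    import torch
    from util.tensor_util import pad_divide_by, unpad

    x = torch.randn(1, 3, 30, 45)
    padded, pads = pad_divide_by(x, 16)   # pad H and W up to multiples of 16
    print(padded.shape)                   # torch.Size([1, 3, 32, 48])
    print(unpad(padded, pads).shape)      # torch.Size([1, 3, 30, 45])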
utils/base_segmenter.py ADDED
@@ -0,0 +1,149 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class BaseSegmenter:
6
+ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device="cuda:0"):
7
+ """
8
+ device: model device
9
+ sam_pt_checkpoint / sam_onnx_checkpoint: paths to the SAM PyTorch and ONNX checkpoints
10
+ model_type: vit_b, vit_l, vit_h, vit_t
11
+ """
12
+ print(f"Initializing BaseSegmenter to {device}")
13
+ assert model_type in [
14
+ "vit_b",
15
+ "vit_l",
16
+ "vit_h",
17
+ "vit_t",
18
+ ], "model_type must be vit_b, vit_l, vit_h or vit_t"
19
+
20
+ self.device = device
21
+ self.torch_dtype = torch.float16 if "cuda" in device else torch.float32
22
+
23
+ if model_type == "vit_t":
24
+ from mobile_sam import sam_model_registry, SamPredictor
25
+ from onnxruntime import InferenceSession
26
+ self.ort_session = InferenceSession(sam_onnx_checkpoint)
27
+ self.predict = self.predict_onnx
28
+ else:
29
+ from segment_anything import sam_model_registry, SamPredictor
30
+ self.predict = self.predict_pt
31
+
32
+ self.model = sam_model_registry[model_type](checkpoint=sam_pt_checkpoint)
33
+ self.model.to(device=self.device)
34
+ self.predictor = SamPredictor(self.model)
35
+ self.embedded = False
36
+
37
+ @torch.no_grad()
38
+ def set_image(self, image: np.ndarray):
39
+ # PIL.Image.open(image_path), 3-channel: RGB
40
+ # image embedding: avoid encoding the same image multiple times
41
+ self.orignal_image = image
42
+ if self.embedded:
43
+ print("repeat embedding, please reset_image.")
44
+ return
45
+ self.predictor.set_image(image)
46
+ self.image_embedding = self.predictor.get_image_embedding().cpu().numpy()
47
+ self.embedded = True
48
+ return
49
+
50
+ @torch.no_grad()
51
+ def reset_image(self):
52
+ # reset image embedding
53
+ self.predictor.reset_image()
54
+ self.embedded = False
55
+
56
+ def predict_pt(self, prompts, mode, multimask=True):
57
+ """
58
+ image: numpy array, h, w, 3
59
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
60
+ prompts['point_coords']: numpy array [N,2]
61
+ prompts['point_labels']: numpy array [1,N]
62
+ prompts['mask_input']: numpy array [1,256,256]
63
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
64
+ mask_outputs: True (return 3 masks), False (return 1 mask only)
65
+ when mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
66
+ """
67
+ assert (
68
+ self.embedded
69
+ ), "prediction is called before set_image (feature embedding)."
70
+ assert mode in ["point", "mask", "both"], "mode must be point, mask, or both"
71
+
72
+ if mode == "point":
73
+ masks, scores, logits = self.predictor.predict(
74
+ point_coords=prompts["point_coords"],
75
+ point_labels=prompts["point_labels"],
76
+ multimask_output=multimask,
77
+ )
78
+ elif mode == "mask":
79
+ masks, scores, logits = self.predictor.predict(
80
+ mask_input=prompts["mask_input"], multimask_output=multimask
81
+ )
82
+ elif mode == "both": # both
83
+ masks, scores, logits = self.predictor.predict(
84
+ point_coords=prompts["point_coords"],
85
+ point_labels=prompts["point_labels"],
86
+ mask_input=prompts["mask_input"],
87
+ multimask_output=multimask,
88
+ )
89
+ else:
90
+ raise ("Not implement now!")
91
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
92
+ return masks, scores, logits
93
+
94
+ def predict_onnx(self, prompts, mode, multimask=True):
95
+ """
96
+ image: numpy array, h, w, 3
97
+ prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
98
+ prompts['point_coords']: numpy array [N,2]
99
+ prompts['point_labels']: numpy array [1,N]
100
+ prompts['mask_input']: numpy array [1,256,256]
101
+ mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
102
+ mask_outputs: True (return 3 masks), False (return 1 mask only)
103
+ when mask_outputs=True, mask_input=logits[np.argmax(scores), :, :][None, :, :]
104
+ """
105
+ assert (
106
+ self.embedded
107
+ ), "prediction is called before set_image (feature embedding)."
108
+ assert mode in ["point", "mask", "both"], "mode must be point, mask, or both"
109
+
110
+ if mode == "point":
111
+ ort_inputs = {
112
+ "image_embeddings": self.image_embedding,
113
+ "point_coords": prompts["point_coords"],
114
+ "point_labels": prompts["point_labels"],
115
+ "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
116
+ "has_mask_input": np.zeros(1, dtype=np.float32),
117
+ "orig_im_size": prompts["orig_im_size"],
118
+ }
119
+ masks, scores, logits = self.ort_session.run(None, ort_inputs)
120
+ masks = masks > self.predictor.model.mask_threshold
121
+
122
+ elif mode == "mask":
123
+ ort_inputs = {
124
+ "image_embeddings": self.image_embedding,
125
+ "point_coords": np.zeros((len(prompts["point_labels"]), 2), dtype=np.float32),
126
+ "point_labels": prompts["point_labels"],
127
+ "mask_input": prompts["mask_input"],
128
+ "has_mask_input": np.ones(1, dtype=np.float32),
129
+ "orig_im_size": prompts["orig_im_size"],
130
+ }
131
+ masks, scores, logits = self.ort_session.run(None, ort_inputs)
132
+ masks = masks > self.predictor.model.mask_threshold
133
+
134
+ elif mode == "both": # both
135
+ ort_inputs = {
136
+ "image_embeddings": self.image_embedding,
137
+ "point_coords": prompts["point_coords"],
138
+ "point_labels": prompts["point_labels"],
139
+ "mask_input": prompts["mask_input"],
140
+ "has_mask_input": np.ones(1, dtype=np.float32),
141
+ "orig_im_size": prompts["orig_im_size"],
142
+ }
143
+ masks, scores, logits = self.ort_session.run(None, ort_inputs)
144
+ masks = masks > self.predictor.model.mask_threshold
145
+
146
+ else:
147
+ raise ("Not implement now!")
148
+ # masks (n, h, w), scores (n,), logits (n, 256, 256)
149
+ return masks[0], scores[0], logits[0]
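The prompt dictionaries expected by predict_pt and predict_onnx follow the docstrings above; a sketch with placeholder values (the commented-out call assumes a BaseSegmenter has already been constructed and given an image via set_image):

    import numpy as np

    prompts = {
        "point_coords": np.array([[320.0, 240.0]], dtype=np.float32),  # [N, 2] pixel coordinates
        "point_labels": np.array([[1]], dtype=np.float32),             # [1, N]; 1 = foreground, 0 = background
        "mask_input": np.zeros((1, 256, 256), dtype=np.float32),       # low-res logits from a previous call
    }
    # The ONNX path additionally needs the original image size:
    # prompts["orig_im_size"] = np.array([480, 640], dtype=np.float32)
    # masks, scores, logits = segmenter.predict(prompts, mode="both", multimask=True)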
utils/blur.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+
5
+
6
+ # resize frames
7
+ def resize_frames(frames, size=None):
8
+ """
9
+ size: (w, h)
10
+ """
11
+ if size is not None:
12
+ frames = [cv2.resize(f, size) for f in frames]
13
+ frames = np.stack(frames, 0)
14
+
15
+ return frames
16
+
17
+
18
+ # resize masks
19
+ def resize_masks(masks, size=None):
20
+ """
21
+ size: (w, h)
22
+ """
23
+ if size is not None:
24
+ masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
25
+ masks = np.stack(masks, 0)
26
+
27
+ return masks
28
+
29
+
30
+ # apply gaussian blur to frame with defined strength
31
+ def apply_blur(frame, strength):
32
+ blurred = cv2.GaussianBlur(frame, (strength, strength), 0)
33
+ return blurred
34
+
35
+
36
+ # blur frames
37
+ def blur_frames_and_write(
38
+ frames, masks, ratio, strength, dilate_radius=15, fps=30, output_path="blurred.mp4"
39
+ ):
40
+ assert frames.shape[:3] == masks.shape, "different size between frames and masks"
41
+ assert ratio > 0 and ratio <= 1, "ratio must be in (0, 1]"
42
+
43
+ # --------------------
44
+ # pre-processing
45
+ # --------------------
46
+ masks = masks.copy()
47
+ masks = np.clip(masks, 0, 1)
48
+ kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius))
49
+ masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
50
+ T, H, W = masks.shape
51
+ masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
52
+ # size: (w, h)
53
+ if ratio == 1:
54
+ size = (W, H)
55
+ binary_masks = masks
56
+ else:
57
+ size = [int(W * ratio), int(H * ratio)]
58
+ size = [
59
+ si + 1 if si % 2 > 0 else si for si in size
60
+ ] # round odd sizes up to even values
61
+ # shortest side should be at least 50
62
+ if min(size) < 50:
63
+ ratio = 50.0 / min(H, W)
64
+ size = [int(W * ratio), int(H * ratio)]
65
+ binary_masks = resize_masks(masks, tuple(size))
66
+ frames = resize_frames(frames, tuple(size)) # T, H, W, 3
67
+
68
+ if not os.path.exists(os.path.dirname(output_path)):
69
+ os.makedirs(os.path.dirname(output_path))
70
+ writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
71
+
72
+ for frame, mask in zip(frames, binary_masks):
73
+ blurred_frame = apply_blur(frame, strength)
74
+ masked = cv2.bitwise_or(blurred_frame, blurred_frame, mask=mask)
75
+ processed = np.where(masked == (0, 0, 0), frame, masked)
76
+
77
+ writer.write(processed[:, :, ::-1])
78
+
79
+ writer.release()
80
+
81
+ return output_path
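A minimal end-to-end sketch of blur_frames_and_write on synthetic frames; the repo root is assumed to be on the path, and the output path and fps are illustrative:

    import numpy as np
    from utils.blur import blur_frames_and_write

    frames = np.random.randint(0, 255, (10, 240, 320, 3), dtype=np.uint8)  # T, H, W, 3
    masks = np.zeros((10, 240, 320), dtype=np.uint8)
    masks[:, 60:180, 80:240] = 1                      # region to blur
    out = blur_frames_and_write(
        frames, masks, ratio=1, strength=25,          # strength must be odd for GaussianBlur
        output_path="output/blurred.mp4",
    )
    print(out)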
utils/interact_tools.py ADDED
@@ -0,0 +1,109 @@
1
+ from PIL import Image
2
+ import numpy as np
3
+ from .base_segmenter import BaseSegmenter
4
+ from .painter import mask_painter, point_painter
5
+
6
+
7
+ mask_color = 3
8
+ mask_alpha = 0.7
9
+ contour_color = 1
10
+ contour_width = 5
11
+ point_color_ne = 8
12
+ point_color_ps = 50
13
+ point_alpha = 0.9
14
+ point_radius = 15
15
+ contour_color = 2
16
+ contour_width = 5
17
+
18
+
19
+ class SamControler:
20
+ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device):
21
+ """
22
+ Initialize the SAM controller.
23
+ """
24
+
25
+ self.sam_controler = BaseSegmenter(sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device)
26
+ self.onnx = model_type == "vit_t"
27
+
28
+ def first_frame_click(
29
+ self,
30
+ image: np.ndarray,
31
+ points: np.ndarray,
32
+ labels: np.ndarray,
33
+ multimask=True,
34
+ mask_color=3,
35
+ ):
36
+ """
37
+ Used on the first frame of the video.
38
+ return: mask, logit, painted image (mask + point)
39
+ """
40
+ # self.sam_controler.set_image(image)
41
+ neg_flag = labels[-1]
42
+
43
+ if self.onnx:
44
+ onnx_coord = np.concatenate([points, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
45
+ onnx_label = np.concatenate([labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
46
+ onnx_coord = self.sam_controler.predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
47
+ prompts = {
48
+ "point_coords": onnx_coord,
49
+ "point_labels": onnx_label,
50
+ "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
51
+ }
52
+
53
+ else:
54
+ prompts = {
55
+ "point_coords": points,
56
+ "point_labels": labels,
57
+ }
58
+
59
+ if neg_flag == 1:
60
+ # find positive
61
+ masks, scores, logits = self.sam_controler.predict(
62
+ prompts, "point", multimask
63
+ )
64
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
65
+
66
+ prompts["mask_input"] = np.expand_dims(logit[None, :, :], 0)
67
+ masks, scores, logits = self.sam_controler.predict(
68
+ prompts, "both", multimask
69
+ )
70
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
71
+
72
+ else:
73
+ # find neg
74
+ masks, scores, logits = self.sam_controler.predict(
75
+ prompts, "point", multimask
76
+ )
77
+ mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
78
+
79
+ assert len(points) == len(labels)
80
+
81
+ painted_image = mask_painter(
82
+ image,
83
+ mask.astype("uint8"),
84
+ mask_color,
85
+ mask_alpha,
86
+ contour_color,
87
+ contour_width,
88
+ )
89
+ painted_image = point_painter(
90
+ painted_image,
91
+ np.squeeze(points[np.argwhere(labels > 0)], axis=1),
92
+ point_color_ne,
93
+ point_alpha,
94
+ point_radius,
95
+ contour_color,
96
+ contour_width,
97
+ )
98
+ painted_image = point_painter(
99
+ painted_image,
100
+ np.squeeze(points[np.argwhere(labels < 1)], axis=1),
101
+ point_color_ps,
102
+ point_alpha,
103
+ point_radius,
104
+ contour_color,
105
+ contour_width,
106
+ )
107
+ painted_image = Image.fromarray(painted_image)
108
+
109
+ return mask, logit, painted_image
utils/painter.py ADDED
@@ -0,0 +1,360 @@
1
+ import cv2
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def colormap(rgb=True):
7
+ color_list = np.array(
8
+ [
9
+ 0.000,
10
+ 0.000,
11
+ 0.000,
12
+ 1.000,
13
+ 1.000,
14
+ 1.000,
15
+ 1.000,
16
+ 0.498,
17
+ 0.313,
18
+ 0.392,
19
+ 0.581,
20
+ 0.929,
21
+ 0.000,
22
+ 0.447,
23
+ 0.741,
24
+ 0.850,
25
+ 0.325,
26
+ 0.098,
27
+ 0.929,
28
+ 0.694,
29
+ 0.125,
30
+ 0.494,
31
+ 0.184,
32
+ 0.556,
33
+ 0.466,
34
+ 0.674,
35
+ 0.188,
36
+ 0.301,
37
+ 0.745,
38
+ 0.933,
39
+ 0.635,
40
+ 0.078,
41
+ 0.184,
42
+ 0.300,
43
+ 0.300,
44
+ 0.300,
45
+ 0.600,
46
+ 0.600,
47
+ 0.600,
48
+ 1.000,
49
+ 0.000,
50
+ 0.000,
51
+ 1.000,
52
+ 0.500,
53
+ 0.000,
54
+ 0.749,
55
+ 0.749,
56
+ 0.000,
57
+ 0.000,
58
+ 1.000,
59
+ 0.000,
60
+ 0.000,
61
+ 0.000,
62
+ 1.000,
63
+ 0.667,
64
+ 0.000,
65
+ 1.000,
66
+ 0.333,
67
+ 0.333,
68
+ 0.000,
69
+ 0.333,
70
+ 0.667,
71
+ 0.000,
72
+ 0.333,
73
+ 1.000,
74
+ 0.000,
75
+ 0.667,
76
+ 0.333,
77
+ 0.000,
78
+ 0.667,
79
+ 0.667,
80
+ 0.000,
81
+ 0.667,
82
+ 1.000,
83
+ 0.000,
84
+ 1.000,
85
+ 0.333,
86
+ 0.000,
87
+ 1.000,
88
+ 0.667,
89
+ 0.000,
90
+ 1.000,
91
+ 1.000,
92
+ 0.000,
93
+ 0.000,
94
+ 0.333,
95
+ 0.500,
96
+ 0.000,
97
+ 0.667,
98
+ 0.500,
99
+ 0.000,
100
+ 1.000,
101
+ 0.500,
102
+ 0.333,
103
+ 0.000,
104
+ 0.500,
105
+ 0.333,
106
+ 0.333,
107
+ 0.500,
108
+ 0.333,
109
+ 0.667,
110
+ 0.500,
111
+ 0.333,
112
+ 1.000,
113
+ 0.500,
114
+ 0.667,
115
+ 0.000,
116
+ 0.500,
117
+ 0.667,
118
+ 0.333,
119
+ 0.500,
120
+ 0.667,
121
+ 0.667,
122
+ 0.500,
123
+ 0.667,
124
+ 1.000,
125
+ 0.500,
126
+ 1.000,
127
+ 0.000,
128
+ 0.500,
129
+ 1.000,
130
+ 0.333,
131
+ 0.500,
132
+ 1.000,
133
+ 0.667,
134
+ 0.500,
135
+ 1.000,
136
+ 1.000,
137
+ 0.500,
138
+ 0.000,
139
+ 0.333,
140
+ 1.000,
141
+ 0.000,
142
+ 0.667,
143
+ 1.000,
144
+ 0.000,
145
+ 1.000,
146
+ 1.000,
147
+ 0.333,
148
+ 0.000,
149
+ 1.000,
150
+ 0.333,
151
+ 0.333,
152
+ 1.000,
153
+ 0.333,
154
+ 0.667,
155
+ 1.000,
156
+ 0.333,
157
+ 1.000,
158
+ 1.000,
159
+ 0.667,
160
+ 0.000,
161
+ 1.000,
162
+ 0.667,
163
+ 0.333,
164
+ 1.000,
165
+ 0.667,
166
+ 0.667,
167
+ 1.000,
168
+ 0.667,
169
+ 1.000,
170
+ 1.000,
171
+ 1.000,
172
+ 0.000,
173
+ 1.000,
174
+ 1.000,
175
+ 0.333,
176
+ 1.000,
177
+ 1.000,
178
+ 0.667,
179
+ 1.000,
180
+ 0.167,
181
+ 0.000,
182
+ 0.000,
183
+ 0.333,
184
+ 0.000,
185
+ 0.000,
186
+ 0.500,
187
+ 0.000,
188
+ 0.000,
189
+ 0.667,
190
+ 0.000,
191
+ 0.000,
192
+ 0.833,
193
+ 0.000,
194
+ 0.000,
195
+ 1.000,
196
+ 0.000,
197
+ 0.000,
198
+ 0.000,
199
+ 0.167,
200
+ 0.000,
201
+ 0.000,
202
+ 0.333,
203
+ 0.000,
204
+ 0.000,
205
+ 0.500,
206
+ 0.000,
207
+ 0.000,
208
+ 0.667,
209
+ 0.000,
210
+ 0.000,
211
+ 0.833,
212
+ 0.000,
213
+ 0.000,
214
+ 1.000,
215
+ 0.000,
216
+ 0.000,
217
+ 0.000,
218
+ 0.167,
219
+ 0.000,
220
+ 0.000,
221
+ 0.333,
222
+ 0.000,
223
+ 0.000,
224
+ 0.500,
225
+ 0.000,
226
+ 0.000,
227
+ 0.667,
228
+ 0.000,
229
+ 0.000,
230
+ 0.833,
231
+ 0.000,
232
+ 0.000,
233
+ 1.000,
234
+ 0.143,
235
+ 0.143,
236
+ 0.143,
237
+ 0.286,
238
+ 0.286,
239
+ 0.286,
240
+ 0.429,
241
+ 0.429,
242
+ 0.429,
243
+ 0.571,
244
+ 0.571,
245
+ 0.571,
246
+ 0.714,
247
+ 0.714,
248
+ 0.714,
249
+ 0.857,
250
+ 0.857,
251
+ 0.857,
252
+ ]
253
+ ).astype(np.float32)
254
+ color_list = color_list.reshape((-1, 3)) * 255
255
+ if not rgb:
256
+ color_list = color_list[:, ::-1]
257
+ return color_list
258
+
259
+
260
+ color_list = colormap()
261
+ color_list = color_list.astype("uint8").tolist()
262
+
263
+
264
+ def vis_add_mask(image, mask, color, alpha):
265
+ color = np.array(color_list[color])
266
+ mask = mask > 0.5
267
+ image[mask] = image[mask] * (1 - alpha) + color * alpha
268
+ return image.astype("uint8")
269
+
270
+
271
+ def point_painter(
272
+ input_image,
273
+ input_points,
274
+ point_color=5,
275
+ point_alpha=0.9,
276
+ point_radius=15,
277
+ contour_color=2,
278
+ contour_width=5,
279
+ ):
280
+ h, w = input_image.shape[:2]
281
+ point_mask = np.zeros((h, w)).astype("uint8")
282
+ for point in input_points:
283
+ point_mask[point[1], point[0]] = 1
284
+
285
+ kernel = cv2.getStructuringElement(2, (point_radius, point_radius))
286
+ point_mask = cv2.dilate(point_mask, kernel)
287
+
288
+ contour_radius = (contour_width - 1) // 2
289
+ dist_transform_fore = cv2.distanceTransform(point_mask, cv2.DIST_L2, 3)
290
+ dist_transform_back = cv2.distanceTransform(1 - point_mask, cv2.DIST_L2, 3)
291
+ dist_map = dist_transform_fore - dist_transform_back
292
+ # ...:::!!!:::...
293
+ contour_radius += 2
294
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
295
+ contour_mask = contour_mask / np.max(contour_mask)
296
+ contour_mask[contour_mask > 0.5] = 1.0
297
+
298
+ # paint mask
299
+ painted_image = vis_add_mask(
300
+ input_image.copy(), point_mask, point_color, point_alpha
301
+ )
302
+ # paint contour
303
+ painted_image = vis_add_mask(
304
+ painted_image.copy(), 1 - contour_mask, contour_color, 1
305
+ )
306
+ return painted_image
307
+
308
+
309
+ def mask_painter(
310
+ input_image,
311
+ input_mask,
312
+ mask_color=5,
313
+ mask_alpha=0.7,
314
+ contour_color=1,
315
+ contour_width=3,
316
+ ):
317
+ assert (
318
+ input_image.shape[:2] == input_mask.shape
319
+ ), "different shape between image and mask"
320
+ # 0: background, 1: foreground
321
+ mask = np.clip(input_mask, 0, 1)
322
+ contour_radius = (contour_width - 1) // 2
323
+
324
+ dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
325
+ dist_transform_back = cv2.distanceTransform(1 - mask, cv2.DIST_L2, 3)
326
+ dist_map = dist_transform_fore - dist_transform_back
327
+ # ...:::!!!:::...
328
+ contour_radius += 2
329
+ contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
330
+ contour_mask = contour_mask / np.max(contour_mask)
331
+ contour_mask[contour_mask > 0.5] = 1.0
332
+
333
+ # paint mask
334
+ painted_image = vis_add_mask(
335
+ input_image.copy(), mask.copy(), mask_color, mask_alpha
336
+ )
337
+ # paint contour
338
+ painted_image = vis_add_mask(
339
+ painted_image.copy(), 1 - contour_mask, contour_color, 1
340
+ )
341
+
342
+ return painted_image
343
+
344
+
345
+ def background_remover(input_image, input_mask):
346
+ """
347
+ input_image: H, W, 3, np.array
348
+ input_mask: H, W, np.array
349
+
350
+ image_wo_background: PIL.Image
351
+ """
352
+ assert (
353
+ input_image.shape[:2] == input_mask.shape
354
+ ), "different shape between image and mask"
355
+ # 0: background, 1: foreground
356
+ mask = np.expand_dims(np.clip(input_mask, 0, 1), axis=2) * 255
357
+ image_wo_background = np.concatenate([input_image, mask], axis=2) # H, W, 4
358
+ image_wo_background = Image.fromarray(image_wo_background).convert("RGBA")
359
+
360
+ return image_wo_background
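A short sketch of the painters above on synthetic data (repo root assumed on the path):

    import numpy as np
    from utils.painter import mask_painter, point_painter

    image = np.full((120, 160, 3), 200, dtype=np.uint8)
    mask = np.zeros((120, 160), dtype=np.uint8)
    mask[30:90, 40:120] = 1
    painted = mask_painter(image, mask, mask_color=3, mask_alpha=0.7)
    painted = point_painter(painted, np.array([[80, 60]]), point_color=8)  # points are (x, y)
    print(painted.shape, painted.dtype)   # (120, 160, 3) uint8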