PhyscalX committed on
Commit 3d2142b
1 Parent(s): b28a01f
app.py ADDED
@@ -0,0 +1,252 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Gradio application."""
17
+
18
+ import argparse
19
+ import multiprocessing as mp
20
+ import os
21
+ import time
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from tokenize_anything import test_engine
27
+ from tokenize_anything.utils.image import im_rescale
28
+ from tokenize_anything.utils.image import im_vstack
29
+
30
+
31
+ def parse_args():
32
+ """Parse arguments."""
33
+ parser = argparse.ArgumentParser(description="Launch gradio app.")
34
+ parser.add_argument("--model-type", type=str, default="tap_vit_l")
35
+ parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_03f8ec.pkl")
36
+ parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
37
+ parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices.")
38
+ return parser.parse_args()
39
+
40
+
41
+ class Predictor(object):
42
+ """Predictor."""
43
+
44
+ def __init__(self, model, kwargs):
45
+ self.model = model
46
+ self.kwargs = kwargs
47
+ self.batch_size = kwargs.get("batch_size", 256)
48
+ self.model.concept_projector.reset_weights(kwargs["concept_weights"])
49
+ self.model.text_decoder.reset_cache(max_batch_size=self.batch_size)
50
+
51
+ def preprocess_images(self, imgs):
52
+ """Preprocess the inference images."""
53
+ im_batch, im_shapes, im_scales = [], [], []
54
+ for img in imgs:
55
+ scaled_imgs, scales = im_rescale(img, scales=[1024])
56
+ im_batch += scaled_imgs
57
+ im_scales += scales
58
+ im_shapes += [x.shape[:2] for x in scaled_imgs]
59
+ im_batch = im_vstack(im_batch, self.model.pixel_mean_value, size=(1024, 1024))
60
+ im_shapes = np.array(im_shapes)
61
+ im_scales = np.array(im_scales).reshape((len(im_batch), -1))
62
+ im_info = np.hstack([im_shapes, im_scales]).astype("float32")
63
+ return im_batch, im_info
64
+
65
+ @torch.inference_mode()
66
+ def get_results(self, examples):
67
+ """Return the results."""
68
+ # Preprocess images and prompts.
69
+ imgs = [example["img"] for example in examples]
70
+ points = np.concatenate([example["points"] for example in examples])
71
+ im_batch, im_info = self.preprocess_images(imgs)
72
+ num_prompts = points.shape[0] if len(points.shape) > 2 else 1
73
+ batch_shape = im_batch.shape[0], num_prompts // im_batch.shape[0]
74
+ batch_points = points.reshape(batch_shape + (-1, 3))
75
+ batch_points[:, :, :, :2] *= im_info[:, None, None, 2:4]
76
+ batch_points = batch_points.reshape(points.shape)
77
+ # Predict tokens and masks.
78
+ inputs = self.model.get_inputs({"img": im_batch})
79
+ inputs.update(self.model.get_features(inputs))
80
+ outputs = self.model.get_outputs(dict(**inputs, **{"points": batch_points}))
81
+ # Select final mask.
82
+ iou_pred = outputs["iou_pred"].cpu().numpy()
83
+ point_score = batch_points[:, 0, 2].__eq__(2).__sub__(0.5)[:, None]
84
+ rank_scores = iou_pred + point_score * ([1000] + [0] * (iou_pred.shape[1] - 1))
85
+ mask_index = np.arange(rank_scores.shape[0]), rank_scores.argmax(1)
86
+ iou_scores = outputs["iou_pred"][mask_index].cpu().numpy().reshape(batch_shape)
87
+ # Upscale masks to the original image resolution.
88
+ mask_pred = outputs["mask_pred"][mask_index][:, None]
89
+ mask_pred = self.model.upscale_masks(mask_pred, im_batch.shape[1:-1])
90
+ mask_pred = mask_pred.view(batch_shape + mask_pred.shape[2:])
91
+ # Predict concepts.
92
+ concepts, scores = self.model.predict_concept(outputs["sem_embeds"][mask_index])
93
+ concepts, scores = [x.reshape(batch_shape) for x in (concepts, scores)]
94
+ # Generate captions.
95
+ sem_tokens = outputs["sem_tokens"][mask_index][:, None, :]
96
+ captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
97
+ # Postprocess results.
98
+ results = []
99
+ for i in range(batch_shape[0]):
100
+ pred_h, pred_w = im_info[i, :2].astype("int")
101
+ masks = mask_pred[i : i + 1, :, :pred_h, :pred_w]
102
+ masks = self.model.upscale_masks(masks, imgs[i].shape[:2])[0]
103
+ results.append(
104
+ {
105
+ "scores": np.stack([iou_scores[i], scores[i]], axis=-1),
106
+ "masks": masks.gt(0).cpu().numpy().astype("uint8"),
107
+ "concepts": concepts[i],
108
+ "captions": captions[i],
109
+ }
110
+ )
111
+ return results
112
+
113
+
114
+ class ServingCommand(object):
115
+ """Command to run serving."""
116
+
117
+ def __init__(self, output_queue):
118
+ self.output_queue = output_queue
119
+ self.output_dict = mp.Manager().dict()
120
+ self.output_index = mp.Value("i", 0)
121
+
122
+ def postprocess_outputs(self, outputs):
123
+ """Postprocess the detection outputs."""
124
+ scores, masks = outputs["scores"], outputs["masks"]
125
+ concepts, captions = outputs["concepts"], outputs["captions"]
126
+ text_template = "{} ({:.2f}, {:.2f}): {}"
127
+ text_contents = concepts, scores[:, 0], scores[:, 1], captions
128
+ texts = np.array([text_template.format(*vals) for vals in zip(*text_contents)])
129
+ return masks, texts
130
+
131
+ def run(self):
132
+ """Main loop to make the serving outputs."""
133
+ while True:
134
+ img_id, outputs = self.output_queue.get()
135
+ self.output_dict[img_id] = self.postprocess_outputs(outputs)
136
+
137
+
138
+ def build_gradio_app(queues, command):
139
+ """Build the gradio application."""
140
+ import cv2
141
+ import gradio as gr
142
+ import gradio_image_prompter as gr_ext
143
+
144
+ title = "Tokenize Anything"
145
+ header = (
146
+ "<div align='center'>"
147
+ f"<h1>{title}</h1>"
148
+ "<h3>A promptable model capable of simultaneously segmenting, recognizing and captioning</h3>"
149
+ "</div>"
150
+ )
151
+ theme = "soft"
152
+ css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
153
+ #anno-img .mask.active {opacity: 0.7}"""
154
+
155
+ def get_examples():
156
+ assets_dir = os.path.join(os.path.dirname(__file__), "../assets")
157
+ app_images = list(filter(lambda x: x.startswith("app_image"), os.listdir(assets_dir)))
158
+ app_images.sort()
159
+ return [{"image": os.path.join(assets_dir, x)} for x in app_images]
160
+
161
+ def on_prompt_opt(index):
162
+ click_img = gr.Image(None, visible=index == 0)
163
+ draw_img = gr.ImageEditor(None, visible=index != 0)
164
+ anno_img = gr.AnnotatedImage(None)
165
+ return click_img, draw_img, anno_img
166
+
167
+ def on_reset_btn():
168
+ click_img, draw_img = gr.Image(None), gr.ImageEditor(None)
169
+ anno_img = gr.AnnotatedImage(None)
170
+ return click_img, draw_img, anno_img
171
+
172
+ def on_submit_btn(click_img, mask_img, prompt, multipoint):
173
+ if prompt == 0:
174
+ img = cv2.imread(click_img["image"])
175
+ points = np.array(click_img["points"]).reshape((-1, 2, 3))
176
+ if multipoint == 1:
177
+ points = points.reshape((-1, 3))
178
+ lt = points[np.where(points[:, 2] == 2)[0]][None, :, :]
179
+ rb = points[np.where(points[:, 2] == 3)[0]][None, :, :]
180
+ poly = points[np.where(points[:, 2] <= 1)[0]][None, :, :]
181
+ points = [lt, rb, poly] if len(lt) > 0 else [poly, np.array([[[0, 0, 4]]])]
182
+ points = np.concatenate(points, axis=1)
183
+ points = (np.array([[[0, 0, 4]]]) if len(points) == 0 else points).astype("float32")
184
+ elif prompt == 1:
185
+ img, points = mask_img["background"][:, :, (2, 1, 0)], []
186
+ for layer in mask_img["layers"]:
187
+ ys, xs = np.nonzero(layer[:, :, 0])
188
+ keep = np.linspace(0, ys.shape[0], 11, dtype="int64")[1:-1]
189
+ points.append(np.stack([xs[keep][None, :], ys[keep][None, :]], 2))
190
+ points = np.concatenate(points).astype("float32")
191
+ points = np.pad(points, [(0, 0), (0, 0), (0, 1)], constant_values=1)
192
+ pad_points = np.array([[[0, 0, 4]]], "float32").repeat(points.shape[0], 0)
193
+ points = np.concatenate([points, pad_points], axis=1)
194
+ inputs = {"img": img, "points": points}
195
+ with command.output_index.get_lock():
196
+ command.output_index.value += 1
197
+ img_id = command.output_index.value
198
+ queues[img_id % len(queues)].put((img_id, inputs))
199
+ while img_id not in command.output_dict:
200
+ time.sleep(0.005)
201
+ masks, texts = command.output_dict.pop(img_id)
202
+ annotations = [(x, y) for x, y in zip(masks, texts)]
203
+ return inputs["img"][:, :, ::-1], annotations
204
+
205
+ app = gr.Blocks(title=title, theme=theme, css=css).__enter__()
206
+ gr.Markdown(header)
207
+ container, column = gr.Row().__enter__(), gr.Column().__enter__()
208
+ click_img = gr_ext.ImagePrompter(type="filepath", show_label=False)
209
+ draw_img = gr.ImageEditor(type="numpy", show_label=False, visible=False)
210
+ interactions = "LeftClick (FG) | MiddleClick (BG) | PressMove (Box) | Draw (Sketch)"
211
+ gr.Markdown("<h3 style='text-align: center'>[🖱️ | 🖐️]: 🌟🌟 {} 🌟🌟 </h3>".format(interactions))
212
+ row = gr.Row().__enter__()
213
+ prompt_opt = gr.Radio(["Point+Box", "Sketch"], label="Prompt", type="index", value="Point+Box")
214
+ point_opt = gr.Radio(["Batch", "Ensemble"], label="Multipoint", type="index", value="Batch")
215
+ _, row = row.__exit__(), gr.Row().__enter__()
216
+ reset_btn, submit_btn = gr.Button("Reset"), gr.Button("Execute")
217
+ _, row = row.__exit__(), gr.Row().__enter__()
218
+ gr.Examples(get_examples(), inputs=[click_img], label="Examples (for Point+Box only)")
219
+ _, _, column = row.__exit__(), column.__exit__(), gr.Column().__enter__()
220
+ anno_img = gr.AnnotatedImage(elem_id="anno-img", show_label=False)
221
+ reset_btn.click(on_reset_btn, [], [click_img, draw_img, anno_img])
222
+ submit_btn.click(on_submit_btn, [click_img, draw_img, prompt_opt, point_opt], [anno_img])
223
+ prompt_opt.change(on_prompt_opt, [prompt_opt], [click_img, draw_img, anno_img])
224
+ column.__exit__(), container.__exit__(), app.__exit__()
225
+ return app
226
+
227
+
228
+ if __name__ == "__main__":
229
+ args = parse_args()
230
+ queues = [mp.Queue(1024) for _ in range(len(args.device) + 1)]
231
+ commands = [
232
+ test_engine.InferenceCommand(
233
+ queues[i],
234
+ queues[-1],
235
+ kwargs={
236
+ "model_type": args.model_type,
237
+ "weights": args.checkpoint,
238
+ "concept_weights": args.concept,
239
+ "device": args.device[i],
240
+ "predictor_type": Predictor,
241
+ "verbose": i == 0,
242
+ },
243
+ )
244
+ for i in range(len(args.device))
245
+ ]
246
+ commands += [ServingCommand(queues[-1])]
247
+ actors = [mp.Process(target=command.run, daemon=True) for command in commands]
248
+ for actor in actors:
249
+ actor.start()
250
+ app = build_gradio_app(queues[:-1], commands[-1])
251
+ app.queue()
252
+ app.launch()
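For reference, a minimal sketch (not part of this commit) of the prompt format that Predictor.get_results consumes: each example pairs a BGR image with a (num_prompts, num_points, 3) array of (x, y, label) points, where labels follow the PromptEncoder convention ([bg, fg, lt, rb, pad] = [0, 1, 2, 3, 4]) and (0, 0, 4) is the padding point on_submit_btn appends. The image and coordinates below are placeholders.

import numpy as np

# One box prompt: top-left corner (label 2) and bottom-right corner (label 3).
img = np.zeros((480, 640, 3), dtype="uint8")                    # placeholder BGR image
points = np.array([[[100, 80, 2], [300, 240, 3]]], "float32")   # (num_prompts, num_points, 3)
example = {"img": img, "points": points}
# A Predictor built through test_engine.InferenceCommand would then return, per image,
# a dict with "scores", "masks", "concepts" and "captions":
# results = predictor.get_results([example])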
concepts/merged_2560.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7a17403190a7a44669136d0ab278b1bb1e095bb68eff178c3e2617b2744bbb7
3
+ size 10514948
models/tap_vit_l_03f8ec.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d63a5aba993c34bf29c0466026136e18e25d2bd4ac9e51b8fc407b76c431707d
3
+ size 811637521
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ opencv-python
2
+ Pillow
3
+ gradio-image-prompter
4
+ torch>=2.0.0
5
+ flash-attn>=2.3.3
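A hedged aside on the last pin: flash-attn is effectively required at inference time, since image_decoder.py and text_decoder.py only guard its import and their attention layers fail once called without it. A quick availability check:

try:
    from flash_attn import flash_attn_func  # noqa: F401
    print("flash-attn is available")
except ImportError:
    print("flash-attn is missing; the decoders' Attention.forward cannot run")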
tokenize_anything/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Simultaneously Segment, Recognize, and Caption Anything with Promptable Tokenization."""
17
+
18
+ from tokenize_anything.build_model import model_registry
19
+ from tokenize_anything.version import __version__
tokenize_anything/build_model.py ADDED
@@ -0,0 +1,114 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Build model."""
17
+
18
+ from functools import partial
19
+ import pickle
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from tokenize_anything.modeling import ConceptProjector
25
+ from tokenize_anything.modeling import ImageDecoder
26
+ from tokenize_anything.modeling import ImageEncoderViT
27
+ from tokenize_anything.modeling import ImageTokenizer
28
+ from tokenize_anything.modeling import PromptEncoder
29
+ from tokenize_anything.modeling import TextDecoder
30
+ from tokenize_anything.modeling import TextTokenizer
31
+
32
+
33
+ def get_device(device_index):
34
+ """Create an available device object."""
35
+ if torch.cuda.is_available():
36
+ return torch.device("cuda", device_index)
37
+ return torch.device("cpu")
38
+
39
+
40
+ def load_weights(module, weights_file, strict=True):
41
+ """Load a weights file."""
42
+ if not weights_file:
43
+ return module._IncompatibleKeys([], [])
44
+ if weights_file.endswith(".pkl"):
45
+ with open(weights_file, "rb") as f:
46
+ state_dict = pickle.load(f)
47
+ for k, v in state_dict.items():
48
+ state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
49
+ else:
50
+ state_dict = torch.load(weights_file)
51
+ return module.load_state_dict(state_dict, strict=strict)
52
+
53
+
54
+ def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
55
+ """Build an image encoder with ViT."""
56
+ return ImageEncoderViT(
57
+ depth=depth,
58
+ embed_dim=embed_dim,
59
+ num_heads=num_heads,
60
+ mlp_ratio=4,
61
+ patch_size=16,
62
+ window_size=16,
63
+ image_size=image_size,
64
+ out_dim=out_dim,
65
+ )
66
+
67
+
68
+ def image_tokenizer(image_encoder, checkpoint=None, device=0, dtype="float16", **kwargs):
69
+ """Build an image tokenizer."""
70
+ image_size = kwargs.get("image_size", 1024)
71
+ prompt_embed_dim = kwargs.get("prompt_embed_dim", 256)
72
+ sem_embed_dim = kwargs.get("sem_embed_dim", 1024)
73
+ text_embed_dim = kwargs.get("text_embed_dim", 512)
74
+ text_decoder_depth = kwargs.get("text_decoder_depth", 12)
75
+ text_seq_len = kwargs.get("text_seq_len", 40)
76
+ text_tokenizer = TextTokenizer()
77
+ model = ImageTokenizer(
78
+ image_encoder=image_encoder(out_dim=prompt_embed_dim, image_size=image_size),
79
+ prompt_encoder=PromptEncoder(embed_dim=prompt_embed_dim, image_size=image_size),
80
+ image_decoder=ImageDecoder(
81
+ depth=2,
82
+ embed_dim=prompt_embed_dim,
83
+ num_heads=prompt_embed_dim // 32,
84
+ num_mask_tokens=4,
85
+ sem_embed_dim=sem_embed_dim,
86
+ ),
87
+ text_tokenizer=text_tokenizer,
88
+ concept_projector=ConceptProjector(),
89
+ text_decoder=TextDecoder(
90
+ depth=text_decoder_depth,
91
+ embed_dim=text_embed_dim,
92
+ num_heads=text_embed_dim // 64,
93
+ mlp_ratio=4,
94
+ prompt_embed_dim=prompt_embed_dim,
95
+ max_seq_len=text_seq_len,
96
+ vocab_size=text_tokenizer.n_words,
97
+ ),
98
+ )
99
+ load_weights(model, checkpoint)
100
+ model = model.to(device=get_device(device))
101
+ model = model.eval() if not kwargs.get("training", False) else model
102
+ model = model.half() if dtype == "float16" else model
103
+ model = model.bfloat16() if dtype == "bfloat16" else model
104
+ model = model.float() if dtype == "float32" else model
105
+ return model
106
+
107
+
108
+ vit_b_encoder = partial(vit_encoder, depth=12, embed_dim=768, num_heads=12)
109
+ vit_l_encoder = partial(vit_encoder, depth=24, embed_dim=1024, num_heads=16)
110
+
111
+ model_registry = {
112
+ "tap_vit_b": partial(image_tokenizer, image_encoder=vit_b_encoder),
113
+ "tap_vit_l": partial(image_tokenizer, image_encoder=vit_l_encoder),
114
+ }
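A minimal usage sketch (not part of this commit), mirroring the kwargs that app.py forwards through test_engine.InferenceCommand; the paths are the repository defaults from parse_args.

from tokenize_anything import model_registry

# Build the ViT-L tokenizer; get_device() falls back to CPU when CUDA is unavailable.
model = model_registry["tap_vit_l"](
    checkpoint="models/tap_vit_l_03f8ec.pkl",
    device=0,
    dtype="float16",
)
# app.py's Predictor then loads the concept weights and sizes the text cache:
model.concept_projector.reset_weights("concepts/merged_2560.pkl")
model.text_decoder.reset_cache(max_batch_size=256)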
tokenize_anything/modeling/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Modeling components."""
17
+
18
+ from tokenize_anything.modeling.concept_projector import ConceptProjector
19
+ from tokenize_anything.modeling.image_decoder import ImageDecoder
20
+ from tokenize_anything.modeling.image_encoder import ImageEncoderViT
21
+ from tokenize_anything.modeling.image_tokenizer import ImageTokenizer
22
+ from tokenize_anything.modeling.prompt_encoder import PromptEncoder
23
+ from tokenize_anything.modeling.text_decoder import TextDecoder
24
+ from tokenize_anything.modeling.text_tokenizer import TextTokenizer
tokenize_anything/modeling/concept_projector.py ADDED
@@ -0,0 +1,74 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Concept projector."""
17
+
18
+ import pickle
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn
23
+
24
+
25
+ class ConceptProjector(nn.Module):
26
+ """Encode and decode concepts using CLIP."""
27
+
28
+ def __init__(self, src_weights=None, tgt_weights=None):
29
+ super(ConceptProjector, self).__init__()
30
+ self.reset_weights(src_weights, tgt_weights)
31
+
32
+ def reset_weights(self, src_weights=None, tgt_weights=None):
33
+ """Reset the normalized projection weights."""
34
+ if src_weights is not None:
35
+ with open(src_weights, "rb") as f:
36
+ self.src_weights, self.concepts = pickle.load(f)
37
+ self.src_weights = torch.from_numpy(self.src_weights)
38
+ self.concepts = np.array(self.concepts)
39
+ if tgt_weights is not None:
40
+ with open(tgt_weights, "rb") as f:
41
+ self.tgt_weights, self.concepts = pickle.load(f)
42
+ self.tgt_weights = torch.from_numpy(self.tgt_weights)
43
+ self.concepts = np.array(self.concepts)
44
+
45
+ @staticmethod
46
+ def maybe_convert(embeds, proj):
47
+ """Convert inputs for safe projection."""
48
+ if embeds.dtype != torch.float32:
49
+ embeds = embeds.float()
50
+ if embeds.device != proj.device:
51
+ proj = proj.to(device=embeds.device)
52
+ return embeds, proj
53
+
54
+ def encode_src(self, src_embeds):
55
+ """Encode source visual embedding via concept projection."""
56
+ src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
57
+ logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
58
+ return nn.functional.log_softmax(logits, dim=-1)
59
+
60
+ def encode_tgt(self, tgt_embeds):
61
+ """Encode target visual embedding via concept projection."""
62
+ tgt_embeds, self.tgt_weights = self.maybe_convert(tgt_embeds, self.tgt_weights)
63
+ logits = nn.functional.normalize(tgt_embeds, dim=-1) @ self.tgt_weights
64
+ return nn.functional.log_softmax(logits, dim=-1)
65
+
66
+ def decode(self, src_embeds, k=1, return_index=False, return_prob=False):
67
+ """Return the top-k concepts of source visual embedding."""
68
+ src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
69
+ logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
70
+ probs = nn.functional.softmax(logits, dim=-1)
71
+ if return_prob:
72
+ return probs.cpu().numpy()
73
+ score, index = [x.cpu().numpy() for x in probs.topk(k, -1)]
74
+ return (index if return_index else self.concepts[index]), score
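A short sketch of how decode() is used downstream by ImageTokenizer.predict_concept, assuming the pickle stores 1024-d projection weights (matching sem_embed_dim in build_model.py).

import torch
from tokenize_anything.modeling import ConceptProjector

projector = ConceptProjector(src_weights="concepts/merged_2560.pkl")
sem_embeds = torch.randn(4, 1024)                 # stand-in for outputs["sem_embeds"]
names, scores = projector.decode(sem_embeds, k=1)
# names: (4, 1) array of concept strings; scores: (4, 1) softmax probabilities.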
tokenize_anything/modeling/image_decoder.py ADDED
@@ -0,0 +1,224 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Image decoder."""
17
+
18
+ try:
19
+ from flash_attn import flash_attn_func
20
+ except ImportError:
21
+ flash_attn_func = None
22
+
23
+ import torch
24
+ from torch import nn
25
+
26
+
27
+ class TransposedLayerNorm(nn.LayerNorm):
28
+ """LayerNorm with pre-transposed spatial axes."""
29
+
30
+ def forward(self, input):
31
+ return super().forward(input.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
32
+
33
+
34
+ class MLP(nn.Module):
35
+ """Two-layer MLP."""
36
+
37
+ def __init__(self, dim, mlp_dim, activation_type="ReLU"):
38
+ super(MLP, self).__init__()
39
+ self.fc1 = nn.Linear(dim, mlp_dim)
40
+ self.fc2 = nn.Linear(mlp_dim, dim)
41
+ self.activation = getattr(nn, activation_type)()
42
+ self.activation.inplace = True
43
+
44
+ def forward(self, x):
45
+ return self.fc2(self.activation(self.fc1(x)))
46
+
47
+
48
+ class Attention(nn.Module):
49
+ """Multi-head attention."""
50
+
51
+ def __init__(self, dim=256, num_heads=8, attn_ratio=1):
52
+ super(Attention, self).__init__()
53
+ qkv_dim = int(dim * attn_ratio)
54
+ self.num_heads = num_heads
55
+ self.head_dim = qkv_dim // num_heads
56
+ self.q_proj = nn.Linear(dim, qkv_dim)
57
+ self.k_proj = nn.Linear(dim, qkv_dim)
58
+ self.v_proj = nn.Linear(dim, qkv_dim)
59
+ self.proj = nn.Linear(qkv_dim, dim)
60
+ self.scale = self.head_dim**-0.5
61
+
62
+ def forward(self, q, k, v):
63
+ q = self.q_proj(q).view((-1, q.size(1), self.num_heads, self.head_dim))
64
+ k = self.k_proj(k).view((-1, k.size(1), self.num_heads, self.head_dim))
65
+ v = self.v_proj(v).view((-1, v.size(1), self.num_heads, self.head_dim))
66
+ o = flash_attn_func(q, k, v, softmax_scale=self.scale)
67
+ return self.proj(o.flatten(2))
68
+
69
+
70
+ class Block(nn.Module):
71
+ """Transformer block."""
72
+
73
+ def __init__(
74
+ self,
75
+ dim=256,
76
+ num_heads=8,
77
+ attn_ratio=0.5,
78
+ mlp_dim=2048,
79
+ dropout=0.1,
80
+ activation_type="ReLU",
81
+ skip_first_query_pos=False,
82
+ ):
83
+ super(Block, self).__init__()
84
+ self.self_attn = Attention(dim, num_heads)
85
+ self.norm1 = nn.LayerNorm(dim)
86
+ self.cross_attn_token_to_image = Attention(dim, num_heads, attn_ratio)
87
+ self.norm2 = nn.LayerNorm(dim)
88
+ self.mlp = MLP(dim, mlp_dim, activation_type)
89
+ self.norm3 = nn.LayerNorm(dim)
90
+ self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
91
+ self.norm4 = nn.LayerNorm(dim)
92
+ self.dropout = nn.Dropout(dropout, inplace=True)
93
+ self.skip_first_query_pos = skip_first_query_pos
94
+
95
+ def forward(self, query, key, query_pos, key_pos):
96
+ if self.skip_first_query_pos:
97
+ query = self.norm1(self.self_attn(query, query, query))
98
+ else:
99
+ q = query + query_pos
100
+ query = self.norm1(self.dropout(self.self_attn(q, q, query)).add_(query))
101
+ q, k = query + query_pos, key + key_pos
102
+ query = self.norm2(self.dropout(self.cross_attn_token_to_image(q, k, key)).add_(query))
103
+ query = self.norm3(self.dropout(self.mlp(query)).add_(query))
104
+ q = query + query_pos
105
+ key = self.norm4(self.cross_attn_image_to_token(k, q, query).add_(key))
106
+ return query, key
107
+
108
+
109
+ class Transformer(nn.Module):
110
+ """Two-way transformer decoder."""
111
+
112
+ def __init__(
113
+ self,
114
+ embed_dim=256,
115
+ num_heads=8,
116
+ attn_ratio=0.5,
117
+ mlp_dim=2048,
118
+ dropout=0.1,
119
+ activation_type="ReLU",
120
+ depth=2,
121
+ ):
122
+ super(Transformer, self).__init__()
123
+ self.blocks = nn.ModuleList(
124
+ Block(
125
+ embed_dim,
126
+ num_heads,
127
+ attn_ratio=attn_ratio,
128
+ mlp_dim=mlp_dim,
129
+ dropout=dropout,
130
+ activation_type=activation_type,
131
+ skip_first_query_pos=i == 0,
132
+ )
133
+ for i in range(depth)
134
+ )
135
+ self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
136
+ self.norm = nn.LayerNorm(embed_dim)
137
+ self.dropout = nn.Dropout(dropout, inplace=True)
138
+
139
+ def forward(self, query, key, query_pos, key_pos):
140
+ for blk in self.blocks:
141
+ query, key = blk(query, key, query_pos, key_pos)
142
+ q, k = query + query_pos, key + key_pos
143
+ query = self.dropout(self.final_attn_token_to_image(q, k, key)).add_(query)
144
+ query = self.norm(query)
145
+ return query, key
146
+
147
+
148
+ class Predictor(nn.Module):
149
+ """MLP predictor."""
150
+
151
+ def __init__(self, in_dim, out_dim, mlp_dim=None, depth=3):
152
+ super(Predictor, self).__init__()
153
+ mlp_dims = [mlp_dim or in_dim] * (depth - 1)
154
+ in_dims, out_dims = [in_dim] + mlp_dims, mlp_dims + [out_dim]
155
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip(in_dims, out_dims))
156
+
157
+ def forward(self, x):
158
+ for fc in self.layers[:-1]:
159
+ x = nn.functional.relu(fc(x), inplace=True)
160
+ return self.layers[-1](x)
161
+
162
+
163
+ class ImageDecoder(nn.Module):
164
+ """Module to decode region tokens and masks."""
165
+
166
+ def __init__(self, depth, embed_dim, num_heads, num_mask_tokens=4, sem_embed_dim=1024):
167
+ super(ImageDecoder, self).__init__()
168
+ self.embed_dim = embed_dim
169
+ self.num_mask_tokens = num_mask_tokens
170
+ self.transformer = Transformer(embed_dim, num_heads=num_heads, depth=depth)
171
+ self.iou_token = nn.Embedding(1, embed_dim)
172
+ self.sem_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
173
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, embed_dim)
174
+ self.output_conv = nn.Sequential(
175
+ nn.ConvTranspose2d(embed_dim, embed_dim // 4, 2, 2),
176
+ TransposedLayerNorm(embed_dim // 4),
177
+ nn.GELU(),
178
+ nn.ConvTranspose2d(embed_dim // 4, embed_dim // 8, 2, 2),
179
+ nn.GELU(),
180
+ )
181
+ self.mask_pred = nn.ModuleList(
182
+ Predictor(embed_dim, embed_dim // 8) for _ in range(num_mask_tokens)
183
+ )
184
+ self.iou_pred = Predictor(embed_dim, self.num_mask_tokens)
185
+ self.sem_pred = Predictor(embed_dim, sem_embed_dim, 1024)
186
+
187
+ def get_outputs(self, inputs):
188
+ img_embeds = inputs["img_embeds"]
189
+ sparse_embeds = inputs["sparse_embeds"]
190
+ ims_per_batch = img_embeds.size(0)
191
+ prompts_per_batch = sparse_embeds.size(0)
192
+ img_embed_size = img_embeds.shape[2:-1]
193
+ # Prepare query.
194
+ tokens = [self.sem_tokens.weight, self.iou_token.weight, self.mask_tokens.weight]
195
+ query = torch.cat(tokens).unsqueeze_(0).expand(prompts_per_batch, -1, -1)
196
+ query = torch.cat((query, sparse_embeds), dim=1)
197
+ num_tokens = query.shape[1] - sparse_embeds.shape[1]
198
+ # Prepare key.
199
+ key = img_embeds.expand(-1, prompts_per_batch // ims_per_batch, -1, -1, -1)
200
+ key = key.flatten(0, 1).flatten(1, 2)
201
+ # Decode.
202
+ query, key = self.transformer(query, key, query, inputs["img_pos"])
203
+ # Upscale key.
204
+ key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
205
+ output_masks = self.output_conv(key).flatten(2)
206
+ # Unpack query.
207
+ tokens = query[:, :num_tokens].unbind(dim=1)
208
+ iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
209
+ mask_tokens = tokens[num_tokens - self.num_mask_tokens :]
210
+ sem_tokens = tokens[: self.num_mask_tokens]
211
+ # Predict.
212
+ mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
213
+ mask_pred = torch.stack(mask_pred, dim=1) @ output_masks
214
+ mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
215
+ mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
216
+ outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
217
+ outputs["sem_tokens"] = torch.stack(sem_tokens, dim=1)
218
+ outputs["sem_embeds"] = self.sem_pred(outputs["sem_tokens"])
219
+ return outputs
220
+
221
+ def forward(self, inputs):
222
+ outputs = self.get_outputs(inputs)
223
+ outputs["iou_pred"] = outputs["iou_pred"].float()
224
+ return outputs
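For orientation, a hedged sketch of the tensor layout get_outputs() expects, using the defaults from build_model.py (embed_dim=256 and a 64x64 embedding grid for 1024x1024 inputs). Shapes are illustrative; the forward pass itself still needs flash-attn with half-precision CUDA tensors, so the call is left commented.

import torch
from tokenize_anything.modeling import ImageDecoder

decoder = ImageDecoder(depth=2, embed_dim=256, num_heads=8, num_mask_tokens=4, sem_embed_dim=1024)
inputs = {
    "img_embeds": torch.randn(1, 1, 64, 64, 256),   # (batch, 1, H, W, C) from get_features()
    "sparse_embeds": torch.randn(2, 3, 256),        # (num_prompts, points_per_prompt, C)
    "img_pos": torch.randn(64 * 64, 256),           # flattened grid positional encoding
}
# outputs = decoder.get_outputs(inputs)
# outputs keys: "iou_pred", "mask_pred", "sem_tokens", "sem_embeds".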
tokenize_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ ##############################################################################
15
+ """Image encoder."""
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ def space_to_depth(input, block_size):
22
+ """Rearrange blocks of spatial data into depth."""
23
+ if input.dim() == 3:
24
+ hXw, c = input.size()[1:]
25
+ h = w = int(hXw**0.5)
26
+ else:
27
+ h, w, c = input.size()[1:]
28
+ h1, w1 = h // block_size, w // block_size
29
+ c1 = (block_size**2) * c
30
+ input = input.reshape((-1, h1, block_size, w1, block_size, c))
31
+ return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h1, w1, c1))
32
+
33
+
34
+ def depth_to_space(input, block_size):
35
+ """Rearrange blocks of depth data into spatial."""
36
+ h1, w1, c1 = input.size()[1:]
37
+ h, w = h1 * block_size, w1 * block_size
38
+ c = c1 // (block_size**2)
39
+ input = input.reshape((-1, h1, w1, block_size, block_size, c))
40
+ return input.permute(0, 1, 3, 2, 4, 5).reshape((-1, h, w, c))
41
+
42
+
43
+ class MLP(nn.Module):
44
+ """Two-layer MLP."""
45
+
46
+ def __init__(self, dim, mlp_ratio=4):
47
+ super(MLP, self).__init__()
48
+ self.fc1 = nn.Linear(dim, int(dim * mlp_ratio))
49
+ self.fc2 = nn.Linear(int(dim * mlp_ratio), dim)
50
+ self.activation = nn.GELU()
51
+
52
+ def forward(self, x):
53
+ return self.fc2(self.activation(self.fc1(x)))
54
+
55
+
56
+ class Attention(nn.Module):
57
+ """Multihead attention."""
58
+
59
+ def __init__(self, dim, num_heads, qkv_bias=True):
60
+ super(Attention, self).__init__()
61
+ self.num_heads = num_heads
62
+ self.head_dim = dim // num_heads
63
+ self.scale = self.head_dim**-0.5
64
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
65
+ self.proj = nn.Linear(dim, dim)
66
+ self.rel_pos_embed = nn.Identity()
67
+
68
+ def forward(self, x):
69
+ qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
70
+ qkv = self.qkv(x).reshape(qkv_shape).permute(2, 0, 3, 1, 4)
71
+ q, k, v = qkv.unbind(dim=0)
72
+ attn = q @ k.transpose(-2, -1).mul(self.scale)
73
+ attn = self.rel_pos_embed(attn)
74
+ o = nn.functional.softmax(attn, dim=-1) @ v
75
+ return self.proj(o.transpose(1, 2).flatten(2))
76
+
77
+
78
+ class Block(nn.Module):
79
+ """Transformer block."""
80
+
81
+ def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True):
82
+ super(Block, self).__init__()
83
+ self.norm1 = nn.LayerNorm(dim)
84
+ self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
85
+ self.norm2 = nn.LayerNorm(dim)
86
+ self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
87
+
88
+ def forward(self, x):
89
+ x = self.attn(self.norm1(x)).add_(x)
90
+ return self.mlp(self.norm2(x)).add_(x)
91
+
92
+
93
+ class Bottleneck(nn.Module):
94
+ """The bottleneck block."""
95
+
96
+ def __init__(self, dim, expansion=2, width=None):
97
+ super(Bottleneck, self).__init__()
98
+ width = width or dim // expansion
99
+ self.conv1 = nn.Conv2d(dim, width, 1, bias=False)
100
+ self.norm1 = nn.SyncBatchNorm(width)
101
+ self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)
102
+ self.norm2 = nn.SyncBatchNorm(width)
103
+ self.conv3 = nn.Conv2d(width, dim, 1, bias=False)
104
+ self.norm3 = nn.SyncBatchNorm(dim)
105
+ self.activation = nn.GELU()
106
+
107
+ def forward(self, x):
108
+ shortcut = x
109
+ x = self.activation(self.norm1(self.conv1(x)))
110
+ x = self.activation(self.norm2(self.conv2(x)))
111
+ return self.norm3(self.conv3(x)).add_(shortcut)
112
+
113
+
114
+ class PatchEmbed(nn.Module):
115
+ """Patch embedding layer."""
116
+
117
+ def __init__(self, dim=768, patch_size=16, bias=True):
118
+ super(PatchEmbed, self).__init__()
119
+ self.proj = nn.Conv2d(3, dim, patch_size, patch_size, bias=bias)
120
+
121
+ def forward(self, x):
122
+ return self.proj(x).flatten(2).transpose(1, 2)
123
+
124
+
125
+ class PosEmbed(nn.Module):
126
+ """Position embedding layer."""
127
+
128
+ def __init__(self, dim, num_patches):
129
+ super(PosEmbed, self).__init__()
130
+ self.dim = dim
131
+ self.num_patches = num_patches
132
+ self.weight = nn.Parameter(torch.zeros(num_patches, dim))
133
+ nn.init.normal_(self.weight, std=0.02)
134
+
135
+ def forward(self, x):
136
+ return x.add_(self.weight)
137
+
138
+
139
+ class RelPosEmbed(nn.Module):
140
+ """Relative position embedding layer."""
141
+
142
+ def __init__(self, num_heads, size):
143
+ super(RelPosEmbed, self).__init__()
144
+ self.register_buffer("index", self.get_index(size))
145
+ self.weight = nn.Parameter(torch.zeros(num_heads, (2 * size - 1) ** 2))
146
+
147
+ @staticmethod
148
+ def get_index(size):
149
+ """Return the relative index."""
150
+ grid = torch.arange(size)
151
+ grid = torch.stack(torch.meshgrid(grid, grid, indexing="ij")).reshape((2, -1))
152
+ coords = grid[:, :, None] - grid[:, None, :] + (size - 1)
153
+ coords[0] *= 2 * size - 1
154
+ return coords.sum(0)
155
+
156
+ def get_bias(self):
157
+ return self.weight[:, self.index]
158
+
159
+ def forward(self, x):
160
+ return x.add_(self.get_bias())
161
+
162
+
163
+ class SimpleFeaturePyramid(nn.Module):
164
+ """Module to create pyramid features."""
165
+
166
+ def __init__(self, embed_dim, out_dim, patch_size=16, min_lvl=4, max_lvl=4):
167
+ super(SimpleFeaturePyramid, self).__init__()
168
+ self.min_lvl, self.max_lvl = min_lvl, max_lvl
169
+ self.input_conv = nn.ModuleList()
170
+ self.lateral_conv = nn.ModuleList()
171
+ self.output_conv = nn.ModuleList()
172
+ patch_lvl = dict((2**i, i) for i in range(6))[patch_size]
173
+ for lvl in [min(i + 2, self.max_lvl) for i in range(4)]:
174
+ if lvl == patch_lvl or lvl < self.min_lvl:
175
+ self.input_conv += [nn.Identity()]
176
+ elif lvl < patch_lvl:
177
+ stride, layers = 2 ** (patch_lvl - lvl), []
178
+ while stride > 1:
179
+ layers += [nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)]
180
+ layers += [nn.SyncBatchNorm(embed_dim), nn.GELU()] if stride > 2 else []
181
+ stride /= 2
182
+ self.input_conv.append(nn.Sequential(*layers))
183
+ elif lvl > patch_lvl:
184
+ stride = 2 ** (lvl - patch_lvl)
185
+ self.input_conv += [nn.MaxPool2d(stride, stride)]
186
+ for _ in range(min_lvl, max_lvl + 1):
187
+ self.lateral_conv.append(
188
+ nn.Sequential(
189
+ nn.Conv2d(embed_dim, out_dim, kernel_size=1, bias=False),
190
+ nn.SyncBatchNorm(out_dim),
191
+ )
192
+ )
193
+ self.output_conv.append(
194
+ nn.Sequential(
195
+ nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, bias=False),
196
+ nn.SyncBatchNorm(out_dim),
197
+ )
198
+ )
199
+
200
+ def forward(self, inputs):
201
+ inputs = inputs + [inputs[-1]] * (4 - len(inputs))
202
+ inputs = [conv(x) for conv, x in zip(self.input_conv, inputs)]
203
+ features = inputs[self.min_lvl - 1 : self.max_lvl]
204
+ laterals = [conv(x) for conv, x in zip(self.lateral_conv, features)]
205
+ return [conv(x) for conv, x in zip(self.output_conv, laterals)]
206
+
207
+
208
+ class ImageEncoderViT(nn.Module):
209
+ """ViT image encoder."""
210
+
211
+ def __init__(
212
+ self,
213
+ depth,
214
+ embed_dim,
215
+ num_heads,
216
+ mlp_ratio=4,
217
+ patch_size=16,
218
+ window_size=16,
219
+ image_size=1024,
220
+ out_dim=256,
221
+ ):
222
+ super(ImageEncoderViT, self).__init__()
223
+ self.embed_dim = embed_dim
224
+ self.image_size = image_size
225
+ self.window_size = window_size or image_size // patch_size
226
+ self.patch_embed = PatchEmbed(embed_dim, patch_size)
227
+ self.pos_embed = PosEmbed(embed_dim, (image_size // patch_size) ** 2)
228
+ self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth))
229
+ for blk in self.blocks:
230
+ blk.attn.rel_pos_embed = RelPosEmbed(num_heads, self.window_size)
231
+ self.norm = nn.LayerNorm(embed_dim)
232
+ self.cross_conv = nn.ModuleList(Bottleneck(embed_dim) for _ in range(4))
233
+ self.neck = SimpleFeaturePyramid(embed_dim, out_dim, patch_size)
234
+ self.cross_indices = list(range(depth // 4 - 1, depth, depth // 4))
235
+
236
+ def forward(self, x):
237
+ x = self.patch_embed(x)
238
+ x = self.pos_embed(x)
239
+ x = space_to_depth(x, self.window_size)
240
+ wmsa_shape = (-1,) + x.shape[1:]
241
+ msa_shape = (-1, self.window_size**2, self.embed_dim)
242
+ x = x.reshape(msa_shape)
243
+ for i, blk in enumerate(self.blocks):
244
+ x = blk(x)
245
+ if i in self.cross_indices or i == len(self.blocks) - 1:
246
+ x = self.norm(x) if i == len(self.blocks) - 1 else x
247
+ x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
248
+ x = x.permute(0, 3, 1, 2)
249
+ if i in self.cross_indices:
250
+ x = self.cross_conv[self.cross_indices.index(i)](x)
251
+ if i in self.cross_indices and i < len(self.blocks) - 1:
252
+ x = x.permute(0, 2, 3, 1)
253
+ x = space_to_depth(x, self.window_size).reshape(msa_shape)
254
+ return self.neck([x])
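A small sketch (not part of this commit) of the tap_vit_l encoder configuration from build_model.py and the single-level pyramid it returns; the SyncBatchNorm layers in the neck expect GPU tensors, so the forward call is left commented.

import torch
from tokenize_anything.modeling import ImageEncoderViT

encoder = ImageEncoderViT(
    depth=24, embed_dim=1024, num_heads=16, mlp_ratio=4,
    patch_size=16, window_size=16, image_size=1024, out_dim=256,
)
x = torch.randn(1, 3, 1024, 1024)
# feats = encoder(x)   # expected: [tensor of shape (1, 256, 64, 64)]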
tokenize_anything/modeling/image_tokenizer.py ADDED
@@ -0,0 +1,201 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Image tokenizer."""
17
+
18
+ import numpy as np
19
+ import torch
20
+ from torch import nn
21
+
22
+
23
+ class ImageTokenizer(nn.Module):
24
+ """Tokenize image regions with visual prompts."""
25
+
26
+ def __init__(
27
+ self,
28
+ image_encoder,
29
+ prompt_encoder,
30
+ image_decoder,
31
+ concept_projector=None,
32
+ text_tokenizer=None,
33
+ text_decoder=None,
34
+ pixel_mean=(103.53, 116.28, 123.675),
35
+ pixel_std=(57.375, 57.12, 58.395),
36
+ ):
37
+ super(ImageTokenizer, self).__init__()
38
+ self.image_encoder = image_encoder
39
+ self.prompt_encoder = prompt_encoder
40
+ self.image_decoder = image_decoder
41
+ self.concept_projector = concept_projector
42
+ self.text_tokenizer = text_tokenizer
43
+ self.text_decoder = text_decoder
44
+ self.pixel_mean_value = pixel_mean # BGR order.
45
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
46
+ self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())
47
+
48
+ def get_inputs(self, inputs):
49
+ """Return the model inputs.
50
+
51
+ Parameters
52
+ ----------
53
+ inputs : dict
54
+ The initial inputs.
55
+
56
+ Returns
57
+ -------
58
+ dict
59
+ The model inputs.
60
+
61
+ """
62
+ if not isinstance(inputs["img"], torch.Tensor):
63
+ inputs["img"] = torch.from_numpy(inputs["img"])
64
+ if inputs["img"].device != self.pixel_mean.device:
65
+ inputs["img"] = inputs["img"].to(device=self.pixel_mean.device)
66
+ inputs["img"] = inputs["img"].to(dtype=self.pixel_mean.dtype)
67
+ inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig)
68
+ inputs["img"] = inputs["img"].permute(0, 3, 1, 2)
69
+ return inputs
70
+
71
+ def get_features(self, inputs):
72
+ """Return the image features.
73
+
74
+ Parameters
75
+ ----------
76
+ inputs : dict
77
+ The inputs.
78
+
79
+ Returns
80
+ -------
81
+ dict
82
+ The image features.
83
+
84
+ """
85
+ features = self.image_encoder(inputs["img"])
86
+ img_embeds = features[0].permute(0, 2, 3, 1).unsqueeze_(1)
87
+ return {"features": features, "img_embeds": img_embeds}
88
+
89
+ def get_outputs(self, inputs):
90
+ """Return the model outputs.
91
+
92
+ Parameters
93
+ ----------
94
+ inputs : dict
95
+ The model inputs.
96
+
97
+ Returns
98
+ -------
99
+ dict
100
+ The model outputs.
101
+
102
+ """
103
+ inputs.update(self.prompt_encoder(inputs))
104
+ return self.image_decoder(inputs)
105
+
106
+ def forward(self, inputs):
107
+ """Define the computation performed at every call.
108
+
109
+ Parameters
110
+ ----------
111
+ inputs : dict
112
+ The initial inputs.
113
+
114
+ Returns
115
+ -------
116
+ dict
117
+ The model outputs.
118
+
119
+ """
120
+ inputs = self.get_inputs(inputs)
121
+ inputs.update(self.get_features(inputs))
122
+ return self.get_outputs(inputs)
123
+
124
+ def upscale_masks(self, masks, size):
125
+ """Upscale masks using bilinear interpolation.
126
+
127
+ Parameters
128
+ ----------
129
+ masks : torch.Tensor
130
+ The input masks.
131
+ size : Union[int, Tuple[int]]
132
+ The output size.
133
+
134
+ Returns
135
+ -------
136
+ torch.Tensor
137
+ The output masks.
138
+
139
+ """
140
+ return nn.functional.interpolate(masks, size, mode="bilinear", align_corners=False)
141
+
142
+ @torch.inference_mode()
143
+ def predict_concept(self, visual_embeds, k=1):
144
+ """Predict top-k concepts based on visual embeddings.
145
+
146
+ Parameters
147
+ ----------
148
+ visual_embeds: torch.Tensor
149
+ The embeddings to predict visual content.
150
+ k : int, optional, default=1
151
+ The number of top concepts to return.
152
+
153
+ Returns
154
+ -------
155
+ Tuple[numpy.ndarray, numpy.ndarray]
156
+ The concept scores and indices.
157
+
158
+ """
159
+ return self.concept_projector.decode(visual_embeds, k)
160
+
161
+ @torch.inference_mode()
162
+ def generate_text(self, visual_tokens, max_gen_len=None, temperature=0):
163
+ """Generate text sequences based on visual tokens.
164
+
165
+ Parameters
166
+ ----------
167
+ visual_tokens: torch.Tensor
168
+ The tokens to prompt visual context.
169
+ max_gen_len : int, optional
170
+ The maximum length of the generated text sequences.
171
+ temperature : float, optional
172
+ The temperature for controlling randomness in sampling.
173
+
174
+ Returns
175
+ -------
176
+ np.ndarray
177
+ An array of generated texts.
178
+
179
+ """
180
+ max_gen_len = max_gen_len or self.text_decoder.max_seq_len
181
+ prompts = self.text_decoder.get_prompts(visual_tokens)
182
+ out_shape = (prompts.size(0), self.text_decoder.max_text_len)
183
+ tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
184
+ tokens[:, 0], prev_pos = self.text_tokenizer.bos_id, 0
185
+ eos_reached = np.array([False] * tokens.shape[0])
186
+ for cur_pos in range(1, max_gen_len):
187
+ decode_seq_len = cur_pos - prev_pos
188
+ x = torch.from_numpy(tokens[:, prev_pos:cur_pos]).to(device=prompts.device)
189
+ logits = self.text_decoder.transformer(prompts, x, prev_pos)
190
+ next_logits = logits[: x.size(0), decode_seq_len - 1]
191
+ if temperature > 0:
192
+ p = nn.functional.softmax(next_logits / temperature, dim=-1)
193
+ next_token = torch.multinomial(p, 1).cpu().numpy().flatten()
194
+ else:
195
+ next_token = next_logits.argmax(-1).cpu().numpy()
196
+ tokens[:, cur_pos] = next_token
197
+ eos_reached |= next_token == self.text_tokenizer.eos_id
198
+ prev_pos, logits, next_logits = cur_pos, None, None
199
+ if eos_reached.all():
200
+ break
201
+ return np.array(self.text_tokenizer.detokenize(tokens))
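A condensed sketch of the call sequence that app.py's Predictor.get_results drives on a built ImageTokenizer. Unlike app.py, it simply keeps the first mask/semantic token rather than the IoU-ranked one; model, im_batch and batch_points are assumed to be prepared as in the build_model.py sketch above and app.py's preprocess_images.

inputs = model.get_inputs({"img": im_batch})                      # normalize + NCHW
inputs.update(model.get_features(inputs))                         # image embeddings
outputs = model.get_outputs(dict(**inputs, **{"points": batch_points}))
masks = model.upscale_masks(outputs["mask_pred"][:, :1], im_batch.shape[1:-1])
concepts, scores = model.predict_concept(outputs["sem_embeds"][:, 0])
captions = model.generate_text(outputs["sem_tokens"][:, :1])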
tokenize_anything/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,100 @@
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ------------------------------------------------------------------------
16
+ """Prompt encoder."""
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ class PromptEncoder(nn.Module):
23
+ """Module to encode geometric prompts."""
24
+
25
+ def __init__(self, embed_dim, image_size):
26
+ super(PromptEncoder, self).__init__()
27
+ self.img_size = [image_size] * 2
28
+ self.point_embed = nn.Embedding(5, embed_dim) # [bg, fg, lt, rb, pad]
29
+ self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
30
+ self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
31
+ self.img_pos = None
32
+
33
+ def to_tensor(self, input):
34
+ """Convert input to tensor."""
35
+ if input is None:
36
+ return input
37
+ if not isinstance(input, torch.Tensor):
38
+ input = torch.from_numpy(input)
39
+ if input.device != self.coord_matrix.device:
40
+ input = input.to(device=self.coord_matrix.device)
41
+ return input
42
+
43
+ def to_points(self, points=None, boxes=None):
44
+ """Convert points or boxes to point prompts."""
45
+ if points is not None:
46
+ if isinstance(points, (tuple, list)):
47
+ coords, labels = points
48
+ else:
49
+ coords, labels = points[:, :, :2], points[:, :, 2]
50
+ coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
51
+ coords = self.to_tensor(coords.clip(0, 1).astype("float32"))
52
+ labels = self.to_tensor(labels.astype("int64"))
53
+ return coords, labels
54
+ if boxes is not None:
55
+ coords = boxes.reshape((-1, 2, 2))
56
+ coords = coords.__add__(0.5).__itruediv__(self.img_size[::-1])
57
+ coords = self.to_tensor(coords.clip(0, 1).astype("float32"))
58
+ labels = self.to_tensor(self.corner_labels)
59
+ return coords, labels
60
+ return None
61
+
62
+ def encode_coords(self, coords):
63
+ """Return the embedding for given coords."""
64
+ pi4, pi2 = 4 * 3.1415926, 2 * 3.1415926
65
+ if self.coord_matrix.dtype != torch.float32:
66
+ self.coord_matrix = self.coord_matrix.float()
67
+ rad = coords.mul(pi4).sub_(pi2) @ self.coord_matrix
68
+ dtype = self.point_embed.weight.dtype
69
+ return torch.cat([rad.sin(), rad.cos()], dim=-1).to(dtype=dtype)
70
+
71
+ def encode_points(self, coords, labels):
72
+ """Return the embedding for given points."""
73
+ embed = self.encode_coords(coords)
74
+ embed.mul_(labels.ne(4).unsqueeze_(-1).float().to(dtype=embed.dtype))
75
+ return embed.add_(self.point_embed(labels))
76
+
77
+ def encode_grid(self, grid_size):
78
+ """Return the embedding for a grid of specified size."""
79
+ grid = torch.ones(*grid_size, dtype=torch.float32)
80
+ y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
81
+ x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
82
+ coords = self.to_tensor(torch.stack([x, y], dim=-1))
83
+ return self.encode_coords(coords)
84
+
85
+ def forward(self, inputs):
86
+ sparse_embeds = []
87
+ if inputs.get("boxes", None) is not None:
88
+ coords, labels = self.to_points(boxes=inputs["boxes"])
89
+ sparse_embeds.append(self.encode_points(coords, labels))
90
+ if inputs.get("points", None) is not None:
91
+ coords, labels = self.to_points(points=inputs["points"])
92
+ sparse_embeds.append(self.encode_points(coords, labels))
93
+ if len(sparse_embeds) > 1:
94
+ sparse_embeds = [torch.cat(sparse_embeds, dim=1)]
95
+ elif len(sparse_embeds) == 0:
96
+ raise ValueError("Expected ``points`` or ``boxes`` prompts.")
97
+ img_embed_size = torch.Size(inputs["img_embeds"].shape[2:-1])
98
+ if self.img_pos is None or self.img_pos.shape[0] != img_embed_size.numel():
99
+ self.img_pos = self.encode_grid(img_embed_size).flatten(0, 1)
100
+ return {"sparse_embeds": sparse_embeds[0], "img_pos": self.img_pos}
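A brief sketch (not part of this commit) of how to_points() normalizes a raw (x, y, label) prompt; the label column follows the [bg, fg, lt, rb, pad] = [0, 1, 2, 3, 4] order of point_embed.

import numpy as np
from tokenize_anything.modeling import PromptEncoder

encoder = PromptEncoder(embed_dim=256, image_size=1024)
points = np.array([[[320, 240, 1], [0, 0, 4]]], "float32")   # one fg click + one padding point
coords, labels = encoder.to_points(points=points)
# coords: (1, 2, 2) float32 tensor scaled into [0, 1]; labels: (1, 2) int64 tensor.
# encode_points() zeroes the coordinate term for label-4 (padding) points before
# adding the per-label embedding.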
tokenize_anything/modeling/text_decoder.py ADDED
@@ -0,0 +1,206 @@
+ # ------------------------------------------------------------------------
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ------------------------------------------------------------------------
+ """Text decoder."""
+
+ try:
+     from flash_attn import flash_attn_func
+     from flash_attn import flash_attn_with_kvcache
+     from flash_attn.layers.rotary import apply_rotary_emb
+ except ImportError:
+     flash_attn_func = None
+     flash_attn_with_kvcache = None
+     apply_rotary_emb = None
+
+ import torch
+ from torch import nn
+
+
+ class TransformerCache(nn.Module):
+     """Transformer cache module."""
+
+     def __init__(self, device=None, dtype=None):
+         super(TransformerCache, self).__init__()
+         self.device = device
+         self.dtype = dtype
+         self.start_pos = 0
+         self.cache_dict = {}
+
+     def init_seq(self, max_batch_size):
+         seq_lens = torch.zeros(max_batch_size, dtype=torch.int32, device=self.device)
+         self.cache_dict["seq_lens"] = seq_lens
+
+     def init_rotary(self, seq_len, dim, theta=10000.0):
+         grid = torch.arange(seq_len, dtype=torch.float32).unsqueeze_(-1)
+         freq = torch.pow(theta, torch.arange(0, dim, 2)[: dim // 2].float().div_(dim))
+         broadcast_freq = grid.mul(freq.reciprocal_().unsqueeze_(0))
+         cache_cos = broadcast_freq.cos().view((-1, dim // 2))
+         cache_sin = broadcast_freq.sin().view((-1, dim // 2))
+         self.cache_dict["cos"] = cache_cos.to(self.device, self.dtype)
+         self.cache_dict["sin"] = cache_sin.to(self.device, self.dtype)
+
+     def init_kv(self, mixer, kv_size):
+         cache_k = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
+         cache_v = torch.zeros(*kv_size, dtype=self.dtype, device=self.device)
+         self.cache_dict[f"{id(mixer)}_k"] = cache_k
+         self.cache_dict[f"{id(mixer)}_v"] = cache_v
+
+     def set_seq(self, start_pos=0, end_pos=None):
+         self.start_pos = start_pos
+         if "seq_lens" in self.cache_dict:
+             self.cache_dict["seq_lens"].fill_(start_pos)
+         if "cos" in self.cache_dict and end_pos is not None:
+             self.cache_dict["seq_cos"] = self.cache_dict["cos"][self.start_pos : end_pos]
+             self.cache_dict["seq_sin"] = self.cache_dict["sin"][self.start_pos : end_pos]
+
+     def forward_rotary(self, q, k, inplace=False):
+         cos = self.cache_dict.get("seq_cos", self.cache_dict.get("cos", None))
+         sin = self.cache_dict.get("seq_sin", self.cache_dict.get("sin", None))
+         if cos is None or sin is None:
+             return q, k
+         q = apply_rotary_emb(q, cos, sin, interleaved=True, inplace=inplace)
+         k = apply_rotary_emb(k, cos, sin, interleaved=True, inplace=inplace)
+         return q, k
+
+     def forward_flash(self, mixer, q, k, v):
+         cache_k = self.cache_dict.get(f"{id(mixer)}_k", None)
+         cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
+         flash_args = {"softmax_scale": mixer.scale, "causal": True}
+         if cache_k is None or cache_v is None:
+             return flash_attn_func(q, k, v, **flash_args)
+         flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
+         return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)
+
+
+ class Attention(nn.Module):
+     """Self-Attention layer."""
+
+     def __init__(self, dim, num_heads, bias=True):
+         super(Attention, self).__init__()
+         self.qkv = nn.Linear(dim, dim * 3, bias=bias)
+         self.proj = nn.Linear(dim, dim, bias=bias)
+         self.head_dim = dim // num_heads
+         self.num_heads = num_heads
+         self.scale = self.head_dim**-0.5
+         self.cache = nn.Module()
+
+     def forward(self, x):
+         qkv_shape = (-1, x.size(1), 3, self.num_heads, self.head_dim)
+         q, k, v = self.qkv(x).view(qkv_shape).unbind(dim=2)
+         q, k = self.cache.forward_rotary(q, k, inplace=True)
+         o = self.cache.forward_flash(self, q, k, v)
+         return self.proj(o.flatten(2))
+
+
+ class MLP(nn.Module):
+     """Two layers MLP."""
+
+     def __init__(self, dim, mlp_dim, bias=True):
+         super(MLP, self).__init__()
+         self.fc1 = nn.Linear(dim, mlp_dim, bias=bias)
+         self.fc2 = nn.Linear(mlp_dim, dim, bias=bias)
+         self.activation = nn.GELU()
+
+     def forward(self, x):
+         return self.fc2(self.activation(self.fc1(x)))
+
+
+ class Block(nn.Module):
+     """Transformer block."""
+
+     def __init__(self, dim, num_heads, mlp_dim, bias=True):
+         super(Block, self).__init__()
+         self.attn = Attention(dim, num_heads, bias=bias)
+         self.mlp = MLP(dim, mlp_dim, bias=bias)
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+
+     def forward(self, x):
+         x = self.attn(self.norm1(x)).add_(x)
+         return self.mlp(self.norm2(x)).add_(x)
+
+
+ class Transformer(nn.Module):
+     """Causal transformer decoder."""
+
+     def __init__(self, depth, dim, num_heads, mlp_dim, vocab_size):
+         super(Transformer, self).__init__()
+         self.dim = dim
+         self.num_heads = num_heads
+         self.head_dim = dim // num_heads
+         self.vocab_size = vocab_size
+         self.tok_embeddings = nn.Embedding(vocab_size, dim)
+         self.blocks = nn.ModuleList(Block(dim, num_heads, mlp_dim) for _ in range(depth))
+         self.norm = nn.LayerNorm(dim)
+         self.text_proj = nn.Linear(dim, vocab_size, bias=False)
+
+     def forward(self, prompts, tokens, start_pos=0):
+         prompt_len = prompts.size(1)
+         start_pos = start_pos + (prompt_len if start_pos > 0 else 0)
+         end_pos = start_pos + tokens.size(1) + (0 if start_pos > 0 else prompt_len)
+         self.cache.set_seq(start_pos, end_pos)
+         x = self.tok_embeddings(tokens)
+         x = x if start_pos > 0 else torch.cat([prompts, x], dim=1)
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x[:, 0 if start_pos > 0 else prompt_len :])
+         return self.text_proj(x).float()
+
+
+ class TextDecoder(nn.Module):
+     """Module to decode texts."""
+
+     def __init__(
+         self,
+         depth,
+         embed_dim,
+         num_heads,
+         mlp_ratio,
+         prompt_embed_dim,
+         max_seq_len,
+         vocab_size,
+     ):
+         super(TextDecoder, self).__init__()
+         self.max_seq_len = max_seq_len
+         self.max_text_len = self.max_seq_len - 1
+         self.encoder = nn.Linear(prompt_embed_dim, embed_dim, bias=False)
+         self.transformer = Transformer(
+             depth=depth,
+             dim=embed_dim,
+             mlp_dim=embed_dim * mlp_ratio,
+             num_heads=num_heads,
+             vocab_size=vocab_size,
+         )
+
+     def reset_cache(self, max_batch_size=1, max_seq_len=None):
+         device, dtype = self.encoder.weight.device, self.encoder.weight.dtype
+         max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
+         num_heads, head_dim = self.transformer.num_heads, self.transformer.head_dim
+         self.transformer.cache = TransformerCache(device=device, dtype=dtype)
+         self.transformer.cache.init_seq(max_batch_size)
+         self.transformer.cache.init_rotary(max_seq_len, head_dim, theta=10000.0)
+         kv_cache_size = (max_batch_size, max_seq_len, num_heads, head_dim)
+         for blk in self.transformer.blocks:
+             blk.attn.__dict__["cache"] = self.transformer.cache
+             self.transformer.cache.init_kv(blk.attn, kv_cache_size) if not self.training else None
+
+     def get_prompts(self, prompt_tokens):
+         return self.encoder(prompt_tokens)
+
+     def get_outputs(self, inputs, start_pos=0):
+         return {"text_pred": self.transformer(inputs["prompts"], inputs["tokens"], start_pos)}
+
+     def forward(self, inputs, start_pos=0):
+         return self.get_outputs(inputs, start_pos)
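For orientation, the decoder above is driven in two phases: a first call with start_pos=0 packs the prompt embeddings plus the initial token into the flash-attention KV cache, and each later call passes a single new token with start_pos > 0 so only that step is recomputed. A rough greedy-decoding sketch under assumed settings (toy hyper-parameters, a random prompt, an assumed BOS id of 1, and a CUDA device with flash-attn available; the shipped model uses its own configuration):

import torch

decoder = TextDecoder(
    depth=2, embed_dim=256, num_heads=8, mlp_ratio=4,
    prompt_embed_dim=256, max_seq_len=40, vocab_size=32000,
).cuda().half().eval()
decoder.reset_cache(max_batch_size=4)  # allocates rotary and per-block KV caches

prompt_tokens = torch.randn(4, 2, 256, dtype=torch.half, device="cuda")
prompts = decoder.get_prompts(prompt_tokens)   # (batch, prompt_len, embed_dim)
tokens = torch.full((4, 1), 1, device="cuda")  # assumed BOS id

with torch.inference_mode():
    for step in range(decoder.max_text_len - 1):
        # Step 0 prefills [prompts, BOS]; later steps feed one token against the cache.
        inputs = {"prompts": prompts, "tokens": tokens[:, step:]}
        logits = decoder(inputs, start_pos=step)["text_pred"]
        next_token = logits[:, -1].argmax(-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)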
tokenize_anything/modeling/text_tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
tokenize_anything/modeling/text_tokenizer.py ADDED
@@ -0,0 +1,127 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # Source from: https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py
+
+ import os
+ from logging import getLogger
+ from typing import List
+
+ from sentencepiece import SentencePieceProcessor
+
+
+ logger = getLogger()
+
+
+ class TextTokenizer:
+     """Tokenizing and encoding/decoding text using SentencePiece."""
+
+     def __init__(self, model_path=None):
+         """
+         Initializes the Tokenizer with a SentencePiece model.
+
+         Args:
+             model_path (str): The path to the SentencePiece model file.
+         """
+         if model_path is None:
+             model_path = os.path.join(
+                 os.path.dirname(os.path.abspath(__file__)), "text_tokenizer.model"
+             )
+         # reload tokenizer
+         assert os.path.isfile(model_path), model_path
+         self.sp_model = SentencePieceProcessor(model_file=model_path)
+         logger.info(f"Reloaded SentencePiece model from {model_path}")
+         # BOS / EOS token IDs
+         self.n_words: int = self.sp_model.vocab_size()
+         self.bos_id: int = self.sp_model.bos_id()
+         self.eos_id: int = self.sp_model.eos_id()
+         self.pad_id: int = self.sp_model.pad_id()
+         self.pad_id += self.n_words if self.pad_id < 0 else 0
+         logger.info(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
+         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+         """
+         Encodes a string into a list of token IDs.
+
+         Args:
+             s (str): The input string to be encoded.
+             bos (bool): Whether to prepend the beginning-of-sequence token.
+             eos (bool): Whether to append the end-of-sequence token.
+
+         Returns:
+             List[int]: A list of token IDs.
+         """
+         assert type(s) is str
+         t = self.sp_model.encode(s)
+         if bos:
+             t = [self.bos_id] + t
+         if eos:
+             t = t + [self.eos_id]
+         return t
+
+     def decode(self, t: List[int]) -> str:
+         """
+         Decodes a list of token IDs into a string.
+
+         Args:
+             t (List[int]): The list of token IDs to be decoded.
+
+         Returns:
+             str: The decoded string.
+         """
+         return self.sp_model.decode(t)
+
+     def tokenize(self, texts, context_length=None):
+         """Encode a list of strings.
+
+         Parameters
+         ----------
+         texts : Union[str, List[str]]
+             The input text(s).
+         context_length : int, optional
+             The max token length.
+
+         Returns
+         -------
+         List[List[int]]
+             The encoded token indices.
+
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+         tokens = [self.encode(text, bos=True, eos=True) for text in texts]
+         if context_length is None:
+             return tokens
+         truncated_tokens = []
+         for k, t in enumerate(tokens):
+             if len(t) > context_length:
+                 t = t[:context_length]
+                 t[-1] = self.eos_id
+             truncated_tokens.append(t)
+         return truncated_tokens
+
+     def detokenize(self, tokens):
+         """Decode a list of token sequences.
+
+         Parameters
+         ----------
+         tokens : Union[List[List[int]], numpy.ndarray]
+             The input tokens.
+
+         Returns
+         -------
+         List[str]
+             The decoded text strings.
+
+         """
+         if hasattr(tokens, "tolist"):
+             tokens = tokens.tolist()
+         texts = []
+         for i in range(len(tokens)):
+             t = tokens[i][1:]
+             try:
+                 eot_idx = t.index(self.eos_id)
+                 t = t[:eot_idx]
+             except ValueError:
+                 pass
+             texts.append(self.decode(t))
+         return texts
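A small usage sketch of the tokenizer round trip. It assumes the bundled text_tokenizer.model has been fetched via git-lfs; the input strings and the padded width of 40 are arbitrary examples.

import numpy as np

from tokenize_anything.modeling.text_tokenizer import TextTokenizer

tokenizer = TextTokenizer()  # falls back to the bundled text_tokenizer.model
ids = tokenizer.tokenize(["a cat on the sofa", "traffic light"], context_length=40)

# Right-pad to a fixed-width array before feeding a batched decoder.
batch = np.full((len(ids), 40), tokenizer.pad_id, dtype="int64")
for i, t in enumerate(ids):
    batch[i, : len(t)] = t

# detokenize strips the BOS token and cuts at EOS, so this is expected
# to round-trip the input strings.
print(tokenizer.detokenize(batch))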
tokenize_anything/test_engine.py ADDED
@@ -0,0 +1,81 @@
+ # ------------------------------------------------------------------------
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ------------------------------------------------------------------------
+ """Engine for testing."""
+
+ import time
+
+ from tokenize_anything.build_model import model_registry
+
+
+ class InferenceCommand(object):
+     """Command to run batched inference."""
+
+     def __init__(self, input_queue, output_queue, kwargs):
+         self.input_queue = input_queue
+         self.output_queue = output_queue
+         self.kwargs = kwargs
+
+     def build_env(self):
+         """Build the environment."""
+         self.batch_size = self.kwargs.get("batch_size", 1)
+         self.batch_timeout = self.kwargs.get("batch_timeout", None)
+
+     def build_model(self):
+         """Build and return the model."""
+         builder = model_registry[self.kwargs["model_type"]]
+         return builder(device=self.kwargs["device"], checkpoint=self.kwargs["weights"])
+
+     def build_predictor(self, model):
+         """Build and return the predictor."""
+         return self.kwargs["predictor_type"](model, self.kwargs)
+
+     def send_results(self, predictor, indices, examples):
+         """Send the inference results."""
+         results = predictor.get_results(examples)
+         if hasattr(predictor, "timers"):
+             time_diffs = dict((k, v.average_time) for k, v in predictor.timers.items())
+             for i, outputs in enumerate(results):
+                 self.output_queue.put((indices[i], time_diffs, outputs))
+         else:
+             for i, outputs in enumerate(results):
+                 self.output_queue.put((indices[i], outputs))
+
+     def run(self):
+         """Main loop to make the inference outputs."""
+         self.build_env()
+         model = self.build_model()
+         predictor = self.build_predictor(model)
+         must_stop = False
+         while not must_stop:
+             indices, examples = [], []
+             deadline, timeout = None, None
+             for i in range(self.batch_size):
+                 if self.batch_timeout and i == 1:
+                     deadline = time.monotonic() + self.batch_timeout
+                 if self.batch_timeout and i >= 1:
+                     timeout = deadline - time.monotonic()
+                 try:
+                     index, example = self.input_queue.get(timeout=timeout)
+                     if index < 0:
+                         must_stop = True
+                         break
+                     indices.append(index)
+                     examples.append(example)
+                 except Exception:
+                     pass
+             if len(examples) == 0:
+                 continue
+             self.send_results(predictor, indices, examples)
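InferenceCommand is meant to sit in a worker process: it pulls (index, example) pairs from an input queue, groups them into batches (waiting at most batch_timeout after the first item), and pushes results to an output queue; a negative index acts as the shutdown sentinel. A hedged wiring sketch follows. EchoPredictor, the registry key, and the checkpoint path below are placeholders, not part of this commit, and a real run still needs a registered model and checkpoint on disk.

import multiprocessing as mp

from tokenize_anything.test_engine import InferenceCommand


class EchoPredictor(object):
    """Placeholder predictor; a real one would wrap the TAP model outputs."""

    def __init__(self, model, kwargs):
        self.model, self.kwargs = model, kwargs

    def get_results(self, examples):
        return [{"num_points": len(example["points"])} for example in examples]


def run_worker(input_queue, output_queue, kwargs):
    # Runs inside a child process and owns its own model and device.
    InferenceCommand(input_queue, output_queue, kwargs).run()


if __name__ == "__main__":
    input_queue, output_queue = mp.Queue(1024), mp.Queue(1024)
    kwargs = {
        "model_type": "tap_vit_l",          # assumed registry key
        "weights": "models/tap_vit_l.pkl",  # placeholder checkpoint path
        "device": 0,
        "predictor_type": EchoPredictor,
        "batch_size": 4,
        "batch_timeout": 0.01,
    }
    worker = mp.Process(target=run_worker, args=(input_queue, output_queue, kwargs))
    worker.start()
    input_queue.put((0, {"points": [[10, 20, 1]]}))  # (index, example)
    print(output_queue.get())                        # (0, {'num_points': 1})
    input_queue.put((-1, None))                      # negative index = stop sentinel
    worker.join()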
tokenize_anything/utils/__init__.py ADDED
@@ -0,0 +1,15 @@
+ # ------------------------------------------------------------------------
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ------------------------------------------------------------------------
tokenize_anything/utils/image.py ADDED
@@ -0,0 +1,73 @@
+ # ------------------------------------------------------------------------
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ------------------------------------------------------------------------
+ """Image utilities."""
+
+ import numpy as np
+ import PIL.Image
+
+
+ def im_resize(img, size=None, scale=None, mode="linear"):
+     """Resize image by the scale or size."""
+     if size is None:
+         if not isinstance(scale, (tuple, list)):
+             scale = (scale, scale)
+         h, w = img.shape[:2]
+         size = int(h * scale[0] + 0.5), int(w * scale[1] + 0.5)
+     else:
+         if not isinstance(size, (tuple, list)):
+             size = (size, size)
+     resize_modes = {"linear": PIL.Image.BILINEAR}
+     img = PIL.Image.fromarray(img)
+     return np.array(img.resize(size[::-1], resize_modes[mode]))
+
+
+ def im_rescale(img, scales, max_size=0):
+     """Rescale image to match the detecting scales."""
+     im_shape = img.shape
+     img_list, img_scales = [], []
+     size_min = np.min(im_shape[:2])
+     size_max = np.max(im_shape[:2])
+     for target_size in scales:
+         im_scale = float(target_size) / float(size_min)
+         target_size_max = max_size if max_size > 0 else target_size
+         if np.round(im_scale * size_max) > target_size_max:
+             im_scale = float(target_size_max) / float(size_max)
+         img_list.append(im_resize(img, scale=im_scale))
+         img_scales.append((im_scale, im_scale))
+     return img_list, img_scales
+
+
+ def im_vstack(arrays, fill_value=None, dtype=None, size=None, align=None):
+     """Stack image arrays in sequence vertically."""
+     if fill_value is None:
+         return np.vstack(arrays)
+     # Compute the max stack shape.
+     max_shape = np.max(np.stack([arr.shape for arr in arrays]), 0)
+     if size is not None and min(size) > 0:
+         max_shape[: len(size)] = size
+     if align is not None and min(align) > 0:
+         align_size = np.ceil(max_shape[: len(align)] / align)
+         max_shape[: len(align)] = align_size.astype("int64") * align
+     # Fill output with the given value.
+     output_dtype = dtype or arrays[0].dtype
+     output_shape = [len(arrays)] + list(max_shape)
+     output = np.empty(output_shape, output_dtype)
+     output[:] = fill_value
+     # Copy arrays.
+     for i, arr in enumerate(arrays):
+         copy_slices = (slice(0, d) for d in arr.shape)
+         output[(i,) + tuple(copy_slices)] = arr
+     return output
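A quick sketch of how these helpers compose: im_rescale resizes so the short side reaches the target scale (capped so the long side stays within max_size or the target itself), and im_vstack pads the rescaled images into one fixed-size batch array. The toy image and the pixel-mean fill value below are arbitrary stand-ins.

import numpy as np

from tokenize_anything.utils.image import im_rescale, im_vstack

img = (np.random.rand(480, 640, 3) * 255).astype("uint8")  # toy HWC image

scaled_imgs, scales = im_rescale(img, scales=[1024])
print(scaled_imgs[0].shape, scales[0])  # (768, 1024, 3) (1.6, 1.6)

batch = im_vstack(scaled_imgs, fill_value=[103.53, 116.28, 123.675], size=(1024, 1024))
print(batch.shape)  # (1, 1024, 1024, 3): padded to the fixed input size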
tokenize_anything/utils/mask.py ADDED
@@ -0,0 +1,45 @@
+ # ------------------------------------------------------------------------
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ------------------------------------------------------------------------
+ """Mask utilities."""
+
+ import numpy as np
+ from pycocotools.mask import encode
+
+
+ def mask_to_box(mask):
+     """Convert binary masks to boxes."""
+     shape, (h, w) = mask.shape, mask.shape[-2:]
+     masks = mask.reshape((-1, h, w)).astype("bool")
+     in_height = np.max(masks, axis=-1)
+     in_width = np.max(masks, axis=-2)
+     in_height_coords = in_height * np.arange(h, dtype="int32")
+     in_width_coords = in_width * np.arange(w, dtype="int32")
+     bottom_edges = np.max(in_height_coords, axis=-1)
+     top_edges = np.min(in_height_coords + h * (~in_height), axis=-1)
+     right_edges = np.max(in_width_coords, axis=-1)
+     left_edges = np.min(in_width_coords + w * (~in_width), axis=-1)
+     is_empty = (right_edges < left_edges) | (bottom_edges < top_edges)
+     boxes = np.stack([left_edges, top_edges, right_edges, bottom_edges], axis=-1)
+     boxes = boxes.astype("float32") * ((~is_empty)[:, None])
+     return boxes.reshape(*shape[:-2], 4) if len(shape) > 2 else boxes[0]
+
+
+ def encode_masks(masks):
+     """Encode a set of masks to RLEs."""
+     rles = encode(np.asfortranarray(masks))
+     for rle in rles:
+         rle["counts"] = rle["counts"].decode()
+     return rles
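A toy example of the two helpers. mask_to_box works on (..., H, W) binary masks and returns xyxy boxes, collapsing empty masks to zeros; encode_masks wraps pycocotools, which expects an H x W x N uint8 array in Fortran order, so the stacked masks are transposed first.

import numpy as np

from tokenize_anything.utils.mask import encode_masks, mask_to_box

masks = np.zeros((2, 64, 64), dtype="uint8")
masks[0, 10:20, 30:40] = 1  # one 10x10 square; the second mask stays empty

print(mask_to_box(masks))  # [[30. 10. 39. 19.], [0. 0. 0. 0.]]

rles = encode_masks(masks.transpose(1, 2, 0))  # H x W x N for pycocotools
print(rles[0]["size"])  # [64, 64]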
tokenize_anything/utils/timer.py ADDED
@@ -0,0 +1,51 @@
+ # ------------------------------------------------------------------------
+ # Copyright (c) 2023-present, BAAI. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ------------------------------------------------------------------------
+ """Timing functions."""
+
+ import contextlib
+ import time
+
+
+ class Timer(object):
+     """Simple timer."""
+
+     def __init__(self):
+         self.total_time = 0.0
+         self.calls = 0
+         self.start_time = 0.0
+         self.diff = 0.0
+         self.average_time = 0.0
+
+     def add_diff(self, diff, n=1, average=True):
+         self.total_time += diff
+         self.calls += n
+         self.average_time = self.total_time / self.calls
+         return self.average_time if average else self.diff
+
+     @contextlib.contextmanager
+     def tic_and_toc(self, n=1, average=True):
+         try:
+             yield self.tic()
+         finally:
+             self.toc(n, average)
+
+     def tic(self):
+         self.start_time = time.time()
+         return self
+
+     def toc(self, n=1, average=True):
+         self.diff = time.time() - self.start_time
+         return self.add_diff(self.diff, n, average)
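Typical use is to wrap repeated inference steps and read off a running average; a small sketch with a sleep standing in for a model call:

import time

from tokenize_anything.utils.timer import Timer

timer = Timer()
for _ in range(5):
    with timer.tic_and_toc():
        time.sleep(0.01)  # stand-in for a model call
print(f"{timer.average_time * 1000:.1f} ms/call over {timer.calls} calls")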
tokenize_anything/version.py ADDED
@@ -0,0 +1,3 @@
+ version = "0.1.0a0"
+ git_version = "None"
+ __version__ = version