wondervictor commited on
Commit
a93afca
·
1 Parent(s): 2ca88bd
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. LICENSE +201 -0
  3. README.md +1 -1
  4. app.py +80 -0
  5. inference.py +189 -0
  6. model/EfficientSAM/.DS_Store +0 -0
  7. model/EfficientSAM/efficient_sam/__init__.py +7 -0
  8. model/EfficientSAM/efficient_sam/build_efficient_sam.py +22 -0
  9. model/EfficientSAM/efficient_sam/efficient_sam.py +306 -0
  10. model/EfficientSAM/efficient_sam/efficient_sam_decoder.py +318 -0
  11. model/EfficientSAM/efficient_sam/efficient_sam_encoder.py +257 -0
  12. model/EfficientSAM/efficient_sam/mlp.py +29 -0
  13. model/EfficientSAM/efficient_sam/two_way_transformer.py +266 -0
  14. model/configuration_evf.py +113 -0
  15. model/evf_effisam.py +313 -0
  16. model/evf_sam.py +303 -0
  17. model/segment_anything/__init__.py +10 -0
  18. model/segment_anything/automatic_mask_generator.py +372 -0
  19. model/segment_anything/build_sam.py +108 -0
  20. model/segment_anything/modeling/__init__.py +11 -0
  21. model/segment_anything/modeling/common.py +43 -0
  22. model/segment_anything/modeling/image_encoder.py +426 -0
  23. model/segment_anything/modeling/mask_decoder.py +191 -0
  24. model/segment_anything/modeling/prompt_encoder.py +238 -0
  25. model/segment_anything/modeling/sam.py +184 -0
  26. model/segment_anything/modeling/transformer.py +242 -0
  27. model/segment_anything/predictor.py +284 -0
  28. model/segment_anything/utils/__init__.py +5 -0
  29. model/segment_anything/utils/amg.py +346 -0
  30. model/segment_anything/utils/onnx.py +157 -0
  31. model/segment_anything/utils/transforms.py +113 -0
  32. model/unilm/beit3/README.md +191 -0
  33. model/unilm/beit3/datasets.py +847 -0
  34. model/unilm/beit3/engine_for_finetuning.py +598 -0
  35. model/unilm/beit3/get_started/get_started_for_captioning.md +176 -0
  36. model/unilm/beit3/get_started/get_started_for_image_classification.md +138 -0
  37. model/unilm/beit3/get_started/get_started_for_nlvr2.md +136 -0
  38. model/unilm/beit3/get_started/get_started_for_retrieval.md +161 -0
  39. model/unilm/beit3/get_started/get_started_for_vqav2.md +144 -0
  40. model/unilm/beit3/glossary.py +190 -0
  41. model/unilm/beit3/modeling_finetune.py +386 -0
  42. model/unilm/beit3/modeling_utils.py +76 -0
  43. model/unilm/beit3/optim_factory.py +128 -0
  44. model/unilm/beit3/randaug.py +340 -0
  45. model/unilm/beit3/requirements.txt +22 -0
  46. model/unilm/beit3/run_beit3_finetuning.py +448 -0
  47. model/unilm/beit3/utils.py +913 -0
  48. requirements.txt +32 -0
  49. utils/ade20k_classes.json +30 -0
  50. utils/aug.py +117 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Evf Sam
3
  emoji: 👀
4
  colorFrom: yellow
5
  colorTo: gray
 
1
  ---
2
+ title: EVF-SAM
3
  emoji: 👀
4
  colorFrom: yellow
5
  colorTo: gray
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import sam_preprocess, beit3_preprocess
3
+ from model.evf_sam import EvfSamModel
4
+ from transformers import AutoTokenizer
5
+ import torch
6
+ import numpy as np
7
+ import sys
8
+
9
+ version = "YxZhang/evf-sam"
10
+ model_type = "ori"
11
+
12
+ tokenizer = AutoTokenizer.from_pretrained(
13
+ version,
14
+ padding_side="right",
15
+ use_fast=False,
16
+ )
17
+
18
+ kwargs = {
19
+ "torch_dtype": torch.half,
20
+ }
21
+ model = EvfSamModel.from_pretrained(version, low_cpu_mem_usage=True,
22
+ **kwargs).cuda().eval()
23
+
24
+
25
+ @torch.no_grad()
26
+ def pred(image_np, prompt):
27
+ original_size_list = [image_np.shape[:2]]
28
+
29
+ image_beit = beit3_preprocess(image_np, 224).to(dtype=model.dtype,
30
+ device=model.device)
31
+
32
+ image_sam, resize_shape = sam_preprocess(image_np, model_type=model_type)
33
+ image_sam = image_sam.to(dtype=model.dtype, device=model.device)
34
+
35
+ input_ids = tokenizer(
36
+ prompt, return_tensors="pt")["input_ids"].to(device=model.device)
37
+
38
+ # infer
39
+ pred_mask = model.inference(
40
+ image_sam.unsqueeze(0),
41
+ image_beit.unsqueeze(0),
42
+ input_ids,
43
+ resize_list=[resize_shape],
44
+ original_size_list=original_size_list,
45
+ )
46
+ pred_mask = pred_mask.detach().cpu().numpy()[0]
47
+ pred_mask = pred_mask > 0
48
+
49
+ visualization = image_np.copy()
50
+ visualization[pred_mask] = (image_np * 0.5 +
51
+ pred_mask[:, :, None].astype(np.uint8) *
52
+ np.array([50, 120, 220]) * 0.5)[pred_mask]
53
+
54
+ return visualization / 255.0, pred_mask.astype(np.float16)
55
+
56
+
57
+ demo = gr.Interface(
58
+ fn=pred,
59
+ inputs=[
60
+ gr.components.Image(type="numpy", label="Image", image_mode="RGB"),
61
+ gr.components.Textbox(
62
+ label="Prompt",
63
+ info=
64
+ "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
65
+ )
66
+ ],
67
+ outputs=[
68
+ gr.components.Image(type="numpy", label="visulization"),
69
+ gr.components.Image(type="numpy", label="mask")
70
+ ],
71
+ examples=[["assets/zebra.jpg", "zebra top left"],
72
+ ["assets/bus.jpg", "bus going to south common"],
73
+ [
74
+ "assets/carrots.jpg",
75
+ "3carrots in center with ice and greenn leaves"
76
+ ]],
77
+ title="EVF-SAM referring expression segmentation",
78
+ allow_flagging="never")
79
+ # demo.launch()
80
+ demo.launch(share=False, server_name="0.0.0.0", server_port=10001)
inference.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torchvision import transforms
10
+ from torchvision.transforms.functional import InterpolationMode
11
+ from transformers import AutoTokenizer, BitsAndBytesConfig
12
+ from model.segment_anything.utils.transforms import ResizeLongestSide
13
+
14
+
15
+
16
+ def parse_args(args):
17
+ parser = argparse.ArgumentParser(description="EVF infer")
18
+ parser.add_argument("--version", required=True)
19
+ parser.add_argument("--vis_save_path", default="./infer", type=str)
20
+ parser.add_argument(
21
+ "--precision",
22
+ default="fp16",
23
+ type=str,
24
+ choices=["fp32", "bf16", "fp16"],
25
+ help="precision for inference",
26
+ )
27
+ parser.add_argument("--image_size", default=224, type=int, help="image size")
28
+ parser.add_argument("--model_max_length", default=512, type=int)
29
+
30
+ parser.add_argument("--local-rank", default=0, type=int, help="node rank")
31
+ parser.add_argument("--load_in_8bit", action="store_true", default=False)
32
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
33
+ parser.add_argument("--model_type", default="ori", choices=["ori", "effi"])
34
+ parser.add_argument("--image_path", type=str, default="assets/zebra.jpg")
35
+ parser.add_argument("--prompt", type=str, default="zebra top left")
36
+
37
+ return parser.parse_args(args)
38
+
39
+
40
+ def sam_preprocess(
41
+ x: np.ndarray,
42
+ pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
43
+ pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
44
+ img_size=1024,
45
+ model_type="ori") -> torch.Tensor:
46
+ '''
47
+ preprocess of Segment Anything Model, including scaling, normalization and padding.
48
+ preprocess differs between SAM and Effi-SAM, where Effi-SAM use no padding.
49
+ input: ndarray
50
+ output: torch.Tensor
51
+ '''
52
+ assert img_size==1024, \
53
+ "both SAM and Effi-SAM receive images of size 1024^2, don't change this setting unless you're sure that your employed model works well with another size."
54
+ x = ResizeLongestSide(img_size).apply_image(x)
55
+ resize_shape = x.shape[:2]
56
+ x = torch.from_numpy(x).permute(2,0,1).contiguous()
57
+
58
+ # Normalize colors
59
+ x = (x - pixel_mean) / pixel_std
60
+ if model_type=="effi":
61
+ x = F.interpolate(x.unsqueeze(0), (img_size, img_size), mode="bilinear").squeeze(0)
62
+ else:
63
+ # Pad
64
+ h, w = x.shape[-2:]
65
+ padh = img_size - h
66
+ padw = img_size - w
67
+ x = F.pad(x, (0, padw, 0, padh))
68
+ return x, resize_shape
69
+
70
+ def beit3_preprocess(x: np.ndarray, img_size=224) -> torch.Tensor:
71
+ '''
72
+ preprocess for BEIT-3 model.
73
+ input: ndarray
74
+ output: torch.Tensor
75
+ '''
76
+ beit_preprocess = transforms.Compose([
77
+ transforms.ToTensor(),
78
+ transforms.Resize((img_size, img_size), interpolation=InterpolationMode.BICUBIC),
79
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
80
+ ])
81
+ return beit_preprocess(x)
82
+
83
+ def init_models(args):
84
+ tokenizer = AutoTokenizer.from_pretrained(
85
+ args.version,
86
+ padding_side="right",
87
+ use_fast=False,
88
+ )
89
+
90
+ torch_dtype = torch.float32
91
+ if args.precision == "bf16":
92
+ torch_dtype = torch.bfloat16
93
+ elif args.precision == "fp16":
94
+ torch_dtype = torch.half
95
+
96
+ kwargs = {"torch_dtype": torch_dtype}
97
+ if args.load_in_4bit:
98
+ kwargs.update(
99
+ {
100
+ "torch_dtype": torch.half,
101
+ "quantization_config": BitsAndBytesConfig(
102
+ llm_int8_skip_modules=["visual_model"],
103
+ load_in_4bit=True,
104
+ bnb_4bit_compute_dtype=torch.float16,
105
+ bnb_4bit_use_double_quant=True,
106
+ bnb_4bit_quant_type="nf4",
107
+ ),
108
+ }
109
+ )
110
+ elif args.load_in_8bit:
111
+ kwargs.update(
112
+ {
113
+ "torch_dtype": torch.half,
114
+ "quantization_config": BitsAndBytesConfig(
115
+ llm_int8_skip_modules=["visual_model"],
116
+ load_in_8bit=True,
117
+ ),
118
+ }
119
+ )
120
+
121
+ if args.model_type=="ori":
122
+ from model.evf_sam import EvfSamModel
123
+ model = EvfSamModel.from_pretrained(
124
+ args.version, low_cpu_mem_usage=True, **kwargs
125
+ )
126
+ elif args.model_type=="effi":
127
+ from model.evf_effisam import EvfEffiSamModel
128
+ model = EvfEffiSamModel.from_pretrained(
129
+ args.version, low_cpu_mem_usage=True, **kwargs
130
+ )
131
+
132
+ if (not args.load_in_4bit) and (not args.load_in_8bit):
133
+ model = model.cuda()
134
+ model.eval()
135
+
136
+ return tokenizer, model
137
+
138
+ def main(args):
139
+ args = parse_args(args)
140
+
141
+ # clarify IO
142
+ image_path = args.image_path
143
+ if not os.path.exists(image_path):
144
+ print("File not found in {}".format(image_path))
145
+ exit()
146
+ prompt = args.prompt
147
+
148
+ os.makedirs(args.vis_save_path, exist_ok=True)
149
+ save_path = "{}/{}_vis.png".format(
150
+ args.vis_save_path, os.path.basename(image_path).split(".")[0]
151
+ )
152
+
153
+ # initialize model and tokenizer
154
+ tokenizer, model = init_models(args)
155
+ # preprocess
156
+ image_np = cv2.imread(image_path)
157
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
158
+ original_size_list = [image_np.shape[:2]]
159
+
160
+ image_beit = beit3_preprocess(image_np, args.image_size).to(dtype=model.dtype, device=model.device)
161
+
162
+ image_sam, resize_shape = sam_preprocess(image_np, model_type=args.model_type)
163
+ image_sam = image_sam.to(dtype=model.dtype, device=model.device)
164
+
165
+ input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device=model.device)
166
+
167
+ # infer
168
+ pred_mask = model.inference(
169
+ image_sam.unsqueeze(0),
170
+ image_beit.unsqueeze(0),
171
+ input_ids,
172
+ resize_list=[resize_shape],
173
+ original_size_list=original_size_list,
174
+ )
175
+ pred_mask = pred_mask.detach().cpu().numpy()[0]
176
+ pred_mask = pred_mask > 0
177
+
178
+ # save visualization
179
+ save_img = image_np.copy()
180
+ save_img[pred_mask] = (
181
+ image_np * 0.5
182
+ + pred_mask[:, :, None].astype(np.uint8) * np.array([50, 120, 220]) * 0.5
183
+ )[pred_mask]
184
+ save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
185
+
186
+ cv2.imwrite(save_path, save_img)
187
+
188
+ if __name__ == "__main__":
189
+ main(sys.argv[1:])
model/EfficientSAM/.DS_Store ADDED
Binary file (10.2 kB). View file
 
model/EfficientSAM/efficient_sam/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ from .build_efficient_sam import (
5
+ build_efficient_sam_vitt,
6
+ build_efficient_sam_vits,
7
+ )
model/EfficientSAM/efficient_sam/build_efficient_sam.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .efficient_sam import build_efficient_sam
8
+
9
+ def build_efficient_sam_vitt(checkpoint=None):
10
+ return build_efficient_sam(
11
+ encoder_patch_embed_dim=192,
12
+ encoder_num_heads=3,
13
+ checkpoint=checkpoint,
14
+ ).eval()
15
+
16
+
17
+ def build_efficient_sam_vits(checkpoint=None):
18
+ return build_efficient_sam(
19
+ encoder_patch_embed_dim=384,
20
+ encoder_num_heads=6,
21
+ checkpoint=checkpoint,
22
+ ).eval()
model/EfficientSAM/efficient_sam/efficient_sam.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Any, List, Tuple, Type
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+ from torch import nn, Tensor
14
+
15
+ from .efficient_sam_decoder import MaskDecoder, PromptEncoder
16
+ from .efficient_sam_encoder import ImageEncoderViT
17
+ from .two_way_transformer import TwoWayAttentionBlock, TwoWayTransformer
18
+
19
+ class EfficientSam(nn.Module):
20
+ mask_threshold: float = 0.0
21
+ image_format: str = "RGB"
22
+
23
+ def __init__(
24
+ self,
25
+ image_encoder: ImageEncoderViT,
26
+ prompt_encoder: PromptEncoder,
27
+ decoder_max_num_input_points: int,
28
+ mask_decoder: MaskDecoder,
29
+ pixel_mean: List[float] = [0.485, 0.456, 0.406],
30
+ pixel_std: List[float] = [0.229, 0.224, 0.225],
31
+ ) -> None:
32
+ """
33
+ SAM predicts object masks from an image and input prompts.
34
+
35
+ Arguments:
36
+ image_encoder (ImageEncoderViT): The backbone used to encode the
37
+ image into image embeddings that allow for efficient mask prediction.
38
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
39
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
40
+ and encoded prompts.
41
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
42
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
43
+ """
44
+ super().__init__()
45
+ self.image_encoder = image_encoder
46
+ self.prompt_encoder = prompt_encoder
47
+ self.decoder_max_num_input_points = decoder_max_num_input_points
48
+ self.mask_decoder = mask_decoder
49
+ self.register_buffer(
50
+ "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False
51
+ )
52
+ self.register_buffer(
53
+ "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False
54
+ )
55
+
56
+ @torch.jit.export
57
+ def predict_masks(
58
+ self,
59
+ image_embeddings: torch.Tensor,
60
+ batched_points: torch.Tensor,
61
+ batched_point_labels: torch.Tensor,
62
+ multimask_output: bool,
63
+ input_h: int,
64
+ input_w: int,
65
+ output_h: int = -1,
66
+ output_w: int = -1,
67
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ """
69
+ Predicts masks given image embeddings and prompts. This only runs the decoder.
70
+
71
+ Arguments:
72
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
73
+ batched_points: A tensor of shape [B, max_num_queries, num_pts, 2]
74
+ batched_point_labels: A tensor of shape [B, max_num_queries, num_pts]
75
+ Returns:
76
+ A tuple of two tensors:
77
+ low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks
78
+ iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
79
+ """
80
+
81
+ batch_size, max_num_queries, num_pts, _ = batched_points.shape
82
+ num_pts = batched_points.shape[2]
83
+ rescaled_batched_points = self.get_rescaled_pts(batched_points, input_h, input_w)
84
+
85
+ if num_pts > self.decoder_max_num_input_points:
86
+ rescaled_batched_points = rescaled_batched_points[
87
+ :, :, : self.decoder_max_num_input_points, :
88
+ ]
89
+ batched_point_labels = batched_point_labels[
90
+ :, :, : self.decoder_max_num_input_points
91
+ ]
92
+ elif num_pts < self.decoder_max_num_input_points:
93
+ rescaled_batched_points = F.pad(
94
+ rescaled_batched_points,
95
+ (0, 0, 0, self.decoder_max_num_input_points - num_pts),
96
+ value=-1.0,
97
+ )
98
+ batched_point_labels = F.pad(
99
+ batched_point_labels,
100
+ (0, self.decoder_max_num_input_points - num_pts),
101
+ value=-1.0,
102
+ )
103
+
104
+ sparse_embeddings = self.prompt_encoder(
105
+ rescaled_batched_points.reshape(
106
+ batch_size * max_num_queries, self.decoder_max_num_input_points, 2
107
+ ),
108
+ batched_point_labels.reshape(
109
+ batch_size * max_num_queries, self.decoder_max_num_input_points
110
+ ),
111
+ )
112
+
113
+ sparse_embeddings = sparse_embeddings.view(
114
+ batch_size,
115
+ max_num_queries,
116
+ sparse_embeddings.shape[1],
117
+ sparse_embeddings.shape[2],
118
+ )
119
+ low_res_masks, iou_predictions = self.mask_decoder(
120
+ image_embeddings,
121
+ self.prompt_encoder.get_dense_pe(),
122
+ sparse_prompt_embeddings=sparse_embeddings,
123
+ multimask_output=multimask_output,
124
+ )
125
+ _, num_predictions, low_res_size, _ = low_res_masks.shape
126
+
127
+ if output_w > 0 and output_h > 0:
128
+ output_masks = F.interpolate(
129
+ low_res_masks, (output_h, output_w), mode="bicubic"
130
+ )
131
+ output_masks = torch.reshape(
132
+ output_masks,
133
+ (batch_size, max_num_queries, num_predictions, output_h, output_w),
134
+ )
135
+ else:
136
+ output_masks = torch.reshape(
137
+ low_res_masks,
138
+ (
139
+ batch_size,
140
+ max_num_queries,
141
+ num_predictions,
142
+ low_res_size,
143
+ low_res_size,
144
+ ),
145
+ )
146
+ iou_predictions = torch.reshape(
147
+ iou_predictions, (batch_size, max_num_queries, num_predictions)
148
+ )
149
+ return output_masks, iou_predictions
150
+
151
+ def get_rescaled_pts(self, batched_points: torch.Tensor, input_h: int, input_w: int):
152
+ return torch.stack(
153
+ [
154
+ torch.where(
155
+ batched_points[..., 0] >= 0,
156
+ batched_points[..., 0] * self.image_encoder.img_size / input_w,
157
+ -1.0,
158
+ ),
159
+ torch.where(
160
+ batched_points[..., 1] >= 0,
161
+ batched_points[..., 1] * self.image_encoder.img_size / input_h,
162
+ -1.0,
163
+ ),
164
+ ],
165
+ dim=-1,
166
+ )
167
+
168
+ @torch.jit.export
169
+ def get_image_embeddings(self, batched_images) -> torch.Tensor:
170
+ """
171
+ Predicts masks end-to-end from provided images and prompts.
172
+ If prompts are not known in advance, using SamPredictor is
173
+ recommended over calling the model directly.
174
+
175
+ Arguments:
176
+ batched_images: A tensor of shape [B, 3, H, W]
177
+ Returns:
178
+ List of image embeddings each of of shape [B, C(i), H(i), W(i)].
179
+ The last embedding corresponds to the final layer.
180
+ """
181
+ batched_images = self.preprocess(batched_images)
182
+ return self.image_encoder(batched_images)
183
+
184
+ def forward(
185
+ self,
186
+ batched_images: torch.Tensor,
187
+ batched_points: torch.Tensor,
188
+ batched_point_labels: torch.Tensor,
189
+ scale_to_original_image_size: bool = True,
190
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
191
+ """
192
+ Predicts masks end-to-end from provided images and prompts.
193
+ If prompts are not known in advance, using SamPredictor is
194
+ recommended over calling the model directly.
195
+
196
+ Arguments:
197
+ batched_images: A tensor of shape [B, 3, H, W]
198
+ batched_points: A tensor of shape [B, num_queries, max_num_pts, 2]
199
+ batched_point_labels: A tensor of shape [B, num_queries, max_num_pts]
200
+
201
+ Returns:
202
+ A list tuples of two tensors where the ith element is by considering the first i+1 points.
203
+ low_res_mask: A tensor of shape [B, 256, 256] of predicted masks
204
+ iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores
205
+ """
206
+ batch_size, _, input_h, input_w = batched_images.shape
207
+ image_embeddings = self.get_image_embeddings(batched_images)
208
+ return self.predict_masks(
209
+ image_embeddings,
210
+ batched_points,
211
+ batched_point_labels,
212
+ multimask_output=True,
213
+ input_h=input_h,
214
+ input_w=input_w,
215
+ output_h=input_h if scale_to_original_image_size else -1,
216
+ output_w=input_w if scale_to_original_image_size else -1,
217
+ )
218
+
219
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
220
+ """Normalize pixel values and pad to a square input."""
221
+ if (
222
+ x.shape[2] != self.image_encoder.img_size
223
+ or x.shape[3] != self.image_encoder.img_size
224
+ ):
225
+ x = F.interpolate(
226
+ x,
227
+ (self.image_encoder.img_size, self.image_encoder.img_size),
228
+ mode="bilinear",
229
+ )
230
+ return (x - self.pixel_mean) / self.pixel_std
231
+
232
+
233
+ def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None):
234
+ img_size = 1024
235
+ encoder_patch_size = 16
236
+ encoder_depth = 12
237
+ encoder_mlp_ratio = 4.0
238
+ encoder_neck_dims = [256, 256]
239
+ decoder_max_num_input_points = 6
240
+ decoder_transformer_depth = 2
241
+ decoder_transformer_mlp_dim = 2048
242
+ decoder_num_heads = 8
243
+ decoder_upscaling_layer_dims = [64, 32]
244
+ num_multimask_outputs = 3
245
+ iou_head_depth = 3
246
+ iou_head_hidden_dim = 256
247
+ activation = "gelu"
248
+ normalization_type = "layer_norm"
249
+ normalize_before_activation = False
250
+
251
+ assert activation == "relu" or activation == "gelu"
252
+ if activation == "relu":
253
+ activation_fn = nn.ReLU
254
+ else:
255
+ activation_fn = nn.GELU
256
+
257
+ image_encoder = ImageEncoderViT(
258
+ img_size=img_size,
259
+ patch_size=encoder_patch_size,
260
+ in_chans=3,
261
+ patch_embed_dim=encoder_patch_embed_dim,
262
+ normalization_type=normalization_type,
263
+ depth=encoder_depth,
264
+ num_heads=encoder_num_heads,
265
+ mlp_ratio=encoder_mlp_ratio,
266
+ neck_dims=encoder_neck_dims,
267
+ act_layer=activation_fn,
268
+ )
269
+
270
+ image_embedding_size = image_encoder.image_embedding_size
271
+ encoder_transformer_output_dim = image_encoder.transformer_output_dim
272
+
273
+ sam = EfficientSam(
274
+ image_encoder=image_encoder,
275
+ prompt_encoder=PromptEncoder(
276
+ embed_dim=encoder_transformer_output_dim,
277
+ image_embedding_size=(image_embedding_size, image_embedding_size),
278
+ input_image_size=(img_size, img_size),
279
+ ),
280
+ decoder_max_num_input_points=decoder_max_num_input_points,
281
+ mask_decoder=MaskDecoder(
282
+ transformer_dim=encoder_transformer_output_dim,
283
+ transformer=TwoWayTransformer(
284
+ depth=decoder_transformer_depth,
285
+ embedding_dim=encoder_transformer_output_dim,
286
+ num_heads=decoder_num_heads,
287
+ mlp_dim=decoder_transformer_mlp_dim,
288
+ activation=activation_fn,
289
+ normalize_before_activation=normalize_before_activation,
290
+ ),
291
+ num_multimask_outputs=num_multimask_outputs,
292
+ activation=activation_fn,
293
+ normalization_type=normalization_type,
294
+ normalize_before_activation=normalize_before_activation,
295
+ iou_head_depth=iou_head_depth - 1,
296
+ iou_head_hidden_dim=iou_head_hidden_dim,
297
+ upscaling_layer_dims=decoder_upscaling_layer_dims,
298
+ ),
299
+ pixel_mean=[0.485, 0.456, 0.406],
300
+ pixel_std=[0.229, 0.224, 0.225],
301
+ )
302
+ if checkpoint is not None:
303
+ with open(checkpoint, "rb") as f:
304
+ state_dict = torch.load(f, map_location="cpu")
305
+ sam.load_state_dict(state_dict["model"])
306
+ return sam
model/EfficientSAM/efficient_sam/efficient_sam_decoder.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple, Type
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from .mlp import MLPBlock
15
+
16
+
17
+ class PromptEncoder(nn.Module):
18
+ def __init__(
19
+ self,
20
+ embed_dim: int,
21
+ image_embedding_size: Tuple[int, int],
22
+ input_image_size: Tuple[int, int],
23
+ ) -> None:
24
+ """
25
+ Encodes prompts for input to SAM's mask decoder.
26
+
27
+ Arguments:
28
+ embed_dim (int): The prompts' embedding dimension
29
+ image_embedding_size (tuple(int, int)): The spatial size of the
30
+ image embedding, as (H, W).
31
+ input_image_size (int): The padded size of the image as input
32
+ to the image encoder, as (H, W).
33
+ """
34
+ super().__init__()
35
+ self.embed_dim = embed_dim
36
+ self.input_image_size = input_image_size
37
+ self.image_embedding_size = image_embedding_size
38
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
39
+ self.invalid_points = nn.Embedding(1, embed_dim)
40
+ self.point_embeddings = nn.Embedding(1, embed_dim)
41
+ self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim)
42
+ self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim)
43
+
44
+ def get_dense_pe(self) -> torch.Tensor:
45
+ """
46
+ Returns the positional encoding used to encode point prompts,
47
+ applied to a dense set of points the shape of the image encoding.
48
+
49
+ Returns:
50
+ torch.Tensor: Positional encoding with shape
51
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
52
+ """
53
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
54
+
55
+ def _embed_points(
56
+ self,
57
+ points: torch.Tensor,
58
+ labels: torch.Tensor,
59
+ ) -> torch.Tensor:
60
+ """Embeds point prompts."""
61
+
62
+ points = points + 0.5 # Shift to center of pixel
63
+ point_embedding = self.pe_layer.forward_with_coords(
64
+ points, self.input_image_size
65
+ )
66
+ invalid_label_ids = torch.eq(labels, -1)[:,:,None]
67
+ point_label_ids = torch.eq(labels, 1)[:,:,None]
68
+ topleft_label_ids = torch.eq(labels, 2)[:,:,None]
69
+ bottomright_label_ids = torch.eq(labels, 3)[:,:,None]
70
+ point_embedding = point_embedding + self.invalid_points.weight[:,None,:] * invalid_label_ids
71
+ point_embedding = point_embedding + self.point_embeddings.weight[:,None,:] * point_label_ids
72
+ point_embedding = point_embedding + self.bbox_top_left_embeddings.weight[:,None,:] * topleft_label_ids
73
+ point_embedding = point_embedding + self.bbox_bottom_right_embeddings.weight[:,None,:] * bottomright_label_ids
74
+ return point_embedding
75
+
76
+ def forward(
77
+ self,
78
+ coords,
79
+ labels,
80
+ ) -> torch.Tensor:
81
+ """
82
+ Embeds different types of prompts, returning both sparse and dense
83
+ embeddings.
84
+
85
+ Arguments:
86
+ points: A tensor of shape [B, 2]
87
+ labels: An integer tensor of shape [B] where each element is 1,2 or 3.
88
+
89
+ Returns:
90
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
91
+ BxNx(embed_dim), where N is determined by the number of input points
92
+ and boxes.
93
+ """
94
+ return self._embed_points(coords, labels)
95
+
96
+
97
+ class PositionEmbeddingRandom(nn.Module):
98
+ """
99
+ Positional encoding using random spatial frequencies.
100
+ """
101
+
102
+ def __init__(self, num_pos_feats: int) -> None:
103
+ super().__init__()
104
+ self.register_buffer(
105
+ "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats))
106
+ )
107
+
108
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
109
+ """Positionally encode points that are normalized to [0,1]."""
110
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
111
+ coords = 2 * coords - 1
112
+ coords = coords @ self.positional_encoding_gaussian_matrix
113
+ coords = 2 * np.pi * coords
114
+ # outputs d_1 x ... x d_n x C shape
115
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
116
+
117
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
118
+ """Generate positional encoding for a grid of the specified size."""
119
+ h, w = size
120
+ device = self.positional_encoding_gaussian_matrix.device
121
+ grid = torch.ones([h, w], device=device, dtype=self.positional_encoding_gaussian_matrix.dtype)
122
+ y_embed = grid.cumsum(dim=0) - 0.5
123
+ x_embed = grid.cumsum(dim=1) - 0.5
124
+ y_embed = y_embed / h
125
+ x_embed = x_embed / w
126
+
127
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
128
+ return pe.permute(2, 0, 1) # C x H x W
129
+
130
+ def forward_with_coords(
131
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
132
+ ) -> torch.Tensor:
133
+ """Positionally encode points that are not normalized to [0,1]."""
134
+ coords = coords_input.clone()
135
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
136
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
137
+ # remove to(float) here, don't know why original implementation add this
138
+ return self._pe_encoding(coords) # B x N x C
139
+
140
+
141
+ class MaskDecoder(nn.Module):
142
+ def __init__(
143
+ self,
144
+ *,
145
+ transformer_dim: int,
146
+ transformer: nn.Module,
147
+ num_multimask_outputs: int,
148
+ activation: Type[nn.Module],
149
+ normalization_type: str,
150
+ normalize_before_activation: bool,
151
+ iou_head_depth: int,
152
+ iou_head_hidden_dim: int,
153
+ upscaling_layer_dims: List[int],
154
+ ) -> None:
155
+ """
156
+ Predicts masks given an image and prompt embeddings, using a
157
+ transformer architecture.
158
+
159
+ Arguments:
160
+ transformer_dim (int): the channel dimension of the transformer
161
+ transformer (nn.Module): the transformer used to predict masks
162
+ num_multimask_outputs (int): the number of masks to predict
163
+ when disambiguating masks
164
+ activation (nn.Module): the type of activation to use when
165
+ upscaling masks
166
+ iou_head_depth (int): the depth of the MLP used to predict
167
+ mask quality
168
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
169
+ used to predict mask quality
170
+ """
171
+ super().__init__()
172
+ self.transformer_dim = transformer_dim
173
+ self.transformer = transformer
174
+
175
+ self.num_multimask_outputs = num_multimask_outputs
176
+
177
+ self.iou_token = nn.Embedding(1, transformer_dim)
178
+ if num_multimask_outputs > 1:
179
+ self.num_mask_tokens = num_multimask_outputs + 1
180
+ else:
181
+ self.num_mask_tokens = 1
182
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
183
+ output_dim_after_upscaling = transformer_dim
184
+
185
+ self.final_output_upscaling_layers = nn.ModuleList([])
186
+ for idx, layer_dims in enumerate(upscaling_layer_dims):
187
+ self.final_output_upscaling_layers.append(
188
+ nn.Sequential(
189
+ nn.ConvTranspose2d(
190
+ output_dim_after_upscaling,
191
+ layer_dims,
192
+ kernel_size=2,
193
+ stride=2,
194
+ ),
195
+ nn.GroupNorm(1, layer_dims)
196
+ if idx < len(upscaling_layer_dims) - 1
197
+ else nn.Identity(),
198
+ activation(),
199
+ )
200
+ )
201
+ output_dim_after_upscaling = layer_dims
202
+
203
+ self.output_hypernetworks_mlps = nn.ModuleList(
204
+ [
205
+ MLPBlock(
206
+ input_dim=transformer_dim,
207
+ hidden_dim=transformer_dim,
208
+ output_dim=output_dim_after_upscaling,
209
+ num_layers=2,
210
+ act=activation,
211
+ )
212
+ for i in range(self.num_mask_tokens)
213
+ ]
214
+ )
215
+
216
+ self.iou_prediction_head = MLPBlock(
217
+ input_dim=transformer_dim,
218
+ hidden_dim=iou_head_hidden_dim,
219
+ output_dim=self.num_mask_tokens,
220
+ num_layers=iou_head_depth,
221
+ act=activation,
222
+ )
223
+
224
+ def forward(
225
+ self,
226
+ image_embeddings: torch.Tensor,
227
+ image_pe: torch.Tensor,
228
+ sparse_prompt_embeddings: torch.Tensor,
229
+ multimask_output: bool,
230
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
231
+ """
232
+ Predict masks given image and prompt embeddings.
233
+
234
+ Arguments:
235
+ image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W]
236
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable).
237
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
238
+ multimask_output (bool): Whether to return multiple masks or a single
239
+ mask.
240
+
241
+ Returns:
242
+ torch.Tensor: batched predicted masks
243
+ torch.Tensor: batched predictions of mask quality
244
+ """
245
+
246
+ (
247
+ batch_size,
248
+ max_num_queries,
249
+ sparse_embed_dim_1,
250
+ sparse_embed_dim_2,
251
+ ) = sparse_prompt_embeddings.shape
252
+
253
+ (
254
+ _,
255
+ image_embed_dim_c,
256
+ image_embed_dim_h,
257
+ image_embed_dim_w,
258
+ ) = image_embeddings.shape
259
+
260
+ # Tile the image embedding for all queries.
261
+ image_embeddings_tiled = torch.tile(
262
+ image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
263
+ ).view(
264
+ batch_size * max_num_queries,
265
+ image_embed_dim_c,
266
+ image_embed_dim_h,
267
+ image_embed_dim_w,
268
+ )
269
+ sparse_prompt_embeddings = sparse_prompt_embeddings.reshape(
270
+ batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2
271
+ )
272
+ masks, iou_pred = self.predict_masks(
273
+ image_embeddings=image_embeddings_tiled,
274
+ image_pe=image_pe,
275
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
276
+ )
277
+
278
+ if multimask_output and self.num_multimask_outputs > 1:
279
+ return masks[:, 1:, :], iou_pred[:, 1:]
280
+ else:
281
+ return masks[:, :1, :], iou_pred[:, :1]
282
+
283
+ def predict_masks(
284
+ self,
285
+ image_embeddings: torch.Tensor,
286
+ image_pe: torch.Tensor,
287
+ sparse_prompt_embeddings: torch.Tensor,
288
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
289
+ """Predicts masks. See 'forward' for more details."""
290
+ # Concatenate output tokens
291
+ output_tokens = torch.cat(
292
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
293
+ )
294
+ output_tokens = output_tokens.unsqueeze(0).expand(
295
+ sparse_prompt_embeddings.size(0), -1, -1
296
+ )
297
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
298
+ # Expand per-image data in batch direction to be per-mask
299
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
300
+ b, c, h, w = image_embeddings.shape
301
+ hs, src = self.transformer(image_embeddings, pos_src, tokens)
302
+ iou_token_out = hs[:, 0, :]
303
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
304
+
305
+ # Upscale mask embeddings and predict masks using the mask tokens
306
+ upscaled_embedding = src.transpose(1, 2).view(b, c, h, w)
307
+
308
+ for upscaling_layer in self.final_output_upscaling_layers:
309
+ upscaled_embedding = upscaling_layer(upscaled_embedding)
310
+ hyper_in_list: List[torch.Tensor] = []
311
+ for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps):
312
+ hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :]))
313
+ hyper_in = torch.stack(hyper_in_list, dim=1)
314
+ b, c, h, w = upscaled_embedding.shape
315
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
316
+ # Generate mask quality predictions
317
+ iou_pred = self.iou_prediction_head(iou_token_out)
318
+ return masks, iou_pred
model/EfficientSAM/efficient_sam/efficient_sam_encoder.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import List, Optional, Tuple, Type
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class LayerNorm2d(nn.Module):
16
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
17
+ super().__init__()
18
+ self.weight = nn.Parameter(torch.ones(num_channels))
19
+ self.bias = nn.Parameter(torch.zeros(num_channels))
20
+ self.eps = eps
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ u = x.mean(1, keepdim=True)
24
+ s = (x - u).pow(2).mean(1, keepdim=True)
25
+ x = (x - u) / torch.sqrt(s + self.eps)
26
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
27
+ return x
28
+
29
+
30
+ class PatchEmbed(nn.Module):
31
+ """2D Image to Patch Embedding"""
32
+
33
+ def __init__(
34
+ self,
35
+ img_size,
36
+ patch_size,
37
+ in_chans,
38
+ embed_dim,
39
+ ):
40
+ super().__init__()
41
+ self.proj = nn.Conv2d(
42
+ in_chans,
43
+ embed_dim,
44
+ kernel_size=(patch_size, patch_size),
45
+ stride=(patch_size, patch_size),
46
+ bias=True,
47
+ )
48
+
49
+ def forward(self, x):
50
+ B, C, H, W = x.shape
51
+ x = self.proj(x)
52
+ return x
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(
57
+ self,
58
+ dim,
59
+ num_heads,
60
+ qkv_bias,
61
+ qk_scale=None,
62
+ ):
63
+ super().__init__()
64
+ self.num_heads = num_heads
65
+ head_dim = dim // num_heads
66
+ self.scale = qk_scale or head_dim**-0.5
67
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
68
+ self.proj = nn.Linear(dim, dim)
69
+
70
+ def forward(self, x):
71
+ B, N, C = x.shape
72
+ qkv = (
73
+ self.qkv(x)
74
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
75
+ .permute(2, 0, 3, 1, 4)
76
+ )
77
+ q, k, v = (
78
+ qkv[0],
79
+ qkv[1],
80
+ qkv[2],
81
+ )
82
+ attn = (q @ k.transpose(-2, -1)) * self.scale
83
+ attn = attn.softmax(dim=-1)
84
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
85
+ x = self.proj(x)
86
+ return x
87
+
88
+
89
+ class Mlp(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_features,
93
+ hidden_features=None,
94
+ out_features=None,
95
+ act_layer=nn.GELU,
96
+ ):
97
+ super().__init__()
98
+ out_features = out_features or in_features
99
+ hidden_features = hidden_features or in_features
100
+ self.fc1 = nn.Linear(in_features, hidden_features)
101
+ self.act = act_layer()
102
+ self.fc2 = nn.Linear(hidden_features, out_features)
103
+
104
+ def forward(self, x):
105
+ x = self.fc1(x)
106
+ x = self.act(x)
107
+ x = self.fc2(x)
108
+ return x
109
+
110
+
111
+ class Block(nn.Module):
112
+ def __init__(
113
+ self,
114
+ dim,
115
+ num_heads,
116
+ mlp_ratio=4.0,
117
+ qkv_bias=False,
118
+ qk_scale=None,
119
+ act_layer=nn.GELU,
120
+ ):
121
+ super().__init__()
122
+ self.norm1 = nn.LayerNorm(dim, eps=1e-6)
123
+ self.attn = Attention(
124
+ dim,
125
+ num_heads=num_heads,
126
+ qkv_bias=qkv_bias,
127
+ qk_scale=qk_scale,
128
+ )
129
+ self.norm2 = nn.LayerNorm(dim, eps=1e-6)
130
+ mlp_hidden_dim = int(dim * mlp_ratio)
131
+ self.mlp = Mlp(
132
+ in_features=dim,
133
+ hidden_features=mlp_hidden_dim,
134
+ act_layer=act_layer,
135
+ )
136
+
137
+ def forward(self, x):
138
+ x = x + self.attn(self.norm1(x))
139
+ x = x + self.mlp(self.norm2(x))
140
+ return x
141
+
142
+
143
+ @torch.jit.export
144
+ def get_abs_pos(
145
+ abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int]
146
+ ) -> torch.Tensor:
147
+ """
148
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
149
+ dimension for the original embeddings.
150
+ Args:
151
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
152
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
153
+ hw (Tuple): size of input image tokens.
154
+
155
+ Returns:
156
+ Absolute positional embeddings after processing with shape (1, H, W, C)
157
+ """
158
+ h = hw[0]
159
+ w = hw[1]
160
+ if has_cls_token:
161
+ abs_pos = abs_pos[:, 1:]
162
+ xy_num = abs_pos.shape[1]
163
+ size = int(math.sqrt(xy_num))
164
+ assert size * size == xy_num
165
+
166
+ if size != h or size != w:
167
+ new_abs_pos = F.interpolate(
168
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
169
+ size=(h, w),
170
+ mode="bicubic",
171
+ align_corners=False,
172
+ )
173
+ return new_abs_pos.permute(0, 2, 3, 1)
174
+ else:
175
+ return abs_pos.reshape(1, h, w, -1)
176
+
177
+
178
+ # Image encoder for efficient SAM.
179
+ class ImageEncoderViT(nn.Module):
180
+ def __init__(
181
+ self,
182
+ img_size: int,
183
+ patch_size: int,
184
+ in_chans: int,
185
+ patch_embed_dim: int,
186
+ normalization_type: str,
187
+ depth: int,
188
+ num_heads: int,
189
+ mlp_ratio: float,
190
+ neck_dims: List[int],
191
+ act_layer: Type[nn.Module],
192
+ ) -> None:
193
+ """
194
+ Args:
195
+ img_size (int): Input image size.
196
+ patch_size (int): Patch size.
197
+ in_chans (int): Number of input image channels.
198
+ patch_embed_dim (int): Patch embedding dimension.
199
+ depth (int): Depth of ViT.
200
+ num_heads (int): Number of attention heads in each ViT block.
201
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
202
+ act_layer (nn.Module): Activation layer.
203
+ """
204
+ super().__init__()
205
+
206
+ self.img_size = img_size
207
+ self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1))
208
+ self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1]
209
+ self.pretrain_use_cls_token = True
210
+ pretrain_img_size = 224
211
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim)
212
+ # Initialize absolute positional embedding with pretrain image size.
213
+ num_patches = (pretrain_img_size // patch_size) * (
214
+ pretrain_img_size // patch_size
215
+ )
216
+ num_positions = num_patches + 1
217
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim))
218
+ self.blocks = nn.ModuleList()
219
+ for i in range(depth):
220
+ vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True)
221
+ self.blocks.append(vit_block)
222
+ self.neck = nn.Sequential(
223
+ nn.Conv2d(
224
+ patch_embed_dim,
225
+ neck_dims[0],
226
+ kernel_size=1,
227
+ bias=False,
228
+ ),
229
+ LayerNorm2d(neck_dims[0]),
230
+ nn.Conv2d(
231
+ neck_dims[0],
232
+ neck_dims[0],
233
+ kernel_size=3,
234
+ padding=1,
235
+ bias=False,
236
+ ),
237
+ LayerNorm2d(neck_dims[0]),
238
+ )
239
+
240
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
241
+ assert (
242
+ x.shape[2] == self.img_size and x.shape[3] == self.img_size
243
+ ), "input image size must match self.img_size"
244
+ x = self.patch_embed(x)
245
+ # B C H W -> B H W C
246
+ x = x.permute(0, 2, 3, 1)
247
+ x = x + get_abs_pos(
248
+ self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]]
249
+ )
250
+ num_patches = x.shape[1]
251
+ assert x.shape[2] == num_patches
252
+ x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3])
253
+ for blk in self.blocks:
254
+ x = blk(x)
255
+ x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2])
256
+ x = self.neck(x.permute(0, 3, 1, 2))
257
+ return x
model/EfficientSAM/efficient_sam/mlp.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+
3
+ from torch import nn
4
+
5
+
6
+ # Lightly adapted from
7
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
8
+ class MLPBlock(nn.Module):
9
+ def __init__(
10
+ self,
11
+ input_dim: int,
12
+ hidden_dim: int,
13
+ output_dim: int,
14
+ num_layers: int,
15
+ act: Type[nn.Module],
16
+ ) -> None:
17
+ super().__init__()
18
+ self.num_layers = num_layers
19
+ h = [hidden_dim] * (num_layers - 1)
20
+ self.layers = nn.ModuleList(
21
+ nn.Sequential(nn.Linear(n, k), act())
22
+ for n, k in zip([input_dim] + h, [hidden_dim] * num_layers)
23
+ )
24
+ self.fc = nn.Linear(hidden_dim, output_dim)
25
+
26
+ def forward(self, x):
27
+ for layer in self.layers:
28
+ x = layer(x)
29
+ return self.fc(x)
model/EfficientSAM/efficient_sam/two_way_transformer.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple, Type
3
+ import torch
4
+ from torch import nn, Tensor
5
+ from .mlp import MLPBlock
6
+
7
+
8
+
9
+
10
+ class TwoWayTransformer(nn.Module):
11
+ def __init__(
12
+ self,
13
+ depth: int,
14
+ embedding_dim: int,
15
+ num_heads: int,
16
+ mlp_dim: int,
17
+ activation: Type[nn.Module],
18
+ normalize_before_activation: bool,
19
+ attention_downsample_rate: int = 2,
20
+ ) -> None:
21
+ """
22
+ A transformer decoder that attends to an input image using
23
+ queries whose positional embedding is supplied.
24
+
25
+ Args:
26
+ depth (int): number of layers in the transformer
27
+ embedding_dim (int): the channel dimension for the input embeddings
28
+ num_heads (int): the number of heads for multihead attention. Must
29
+ divide embedding_dim
30
+ mlp_dim (int): the channel dimension internal to the MLP block
31
+ activation (nn.Module): the activation to use in the MLP block
32
+ """
33
+ super().__init__()
34
+ self.depth = depth
35
+ self.embedding_dim = embedding_dim
36
+ self.num_heads = num_heads
37
+ self.mlp_dim = mlp_dim
38
+ self.layers = nn.ModuleList()
39
+
40
+ for i in range(depth):
41
+ curr_layer = TwoWayAttentionBlock(
42
+ embedding_dim=embedding_dim,
43
+ num_heads=num_heads,
44
+ mlp_dim=mlp_dim,
45
+ activation=activation,
46
+ normalize_before_activation=normalize_before_activation,
47
+ attention_downsample_rate=attention_downsample_rate,
48
+ skip_first_layer_pe=(i == 0),
49
+ )
50
+ self.layers.append(curr_layer)
51
+
52
+ self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock(
53
+ embedding_dim,
54
+ num_heads,
55
+ downsample_rate=attention_downsample_rate,
56
+ )
57
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
58
+
59
+ def forward(
60
+ self,
61
+ image_embedding: Tensor,
62
+ image_pe: Tensor,
63
+ point_embedding: Tensor,
64
+ ) -> Tuple[Tensor, Tensor]:
65
+ """
66
+ Args:
67
+ image_embedding (torch.Tensor): image to attend to. Should be shape
68
+ B x embedding_dim x h x w for any h and w.
69
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
70
+ have the same shape as image_embedding.
71
+ point_embedding (torch.Tensor): the embedding to add to the query points.
72
+ Must have shape B x N_points x embedding_dim for any N_points.
73
+
74
+ Returns:
75
+ torch.Tensor: the processed point_embedding
76
+ torch.Tensor: the processed image_embedding
77
+ """
78
+
79
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
80
+ bs, c, h, w = image_embedding.shape
81
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
82
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
83
+
84
+ # Prepare queries
85
+ queries = point_embedding
86
+ keys = image_embedding
87
+
88
+ # Apply transformer blocks and final layernorm
89
+ for idx, layer in enumerate(self.layers):
90
+ queries, keys = layer(
91
+ queries=queries,
92
+ keys=keys,
93
+ query_pe=point_embedding,
94
+ key_pe=image_pe,
95
+ )
96
+
97
+ # Apply the final attention layer from the points to the image
98
+ q = queries + point_embedding
99
+ k = keys + image_pe
100
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
101
+ queries = queries + attn_out
102
+ queries = self.norm_final_attn(queries)
103
+ return queries, keys
104
+
105
+
106
+ class TwoWayAttentionBlock(nn.Module):
107
+ def __init__(
108
+ self,
109
+ embedding_dim: int,
110
+ num_heads: int,
111
+ mlp_dim: int,
112
+ activation: Type[nn.Module],
113
+ normalize_before_activation: bool,
114
+ attention_downsample_rate: int = 2,
115
+ skip_first_layer_pe: bool = False,
116
+ ) -> None:
117
+ """
118
+ A transformer block with four layers: (1) self-attention of sparse
119
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
120
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
121
+ inputs.
122
+
123
+ Arguments:
124
+ embedding_dim (int): the channel dimension of the embeddings
125
+ num_heads (int): the number of heads in the attention layers
126
+ mlp_dim (int): the hidden dimension of the mlp block
127
+ activation (nn.Module): the activation of the mlp block
128
+ skip_first_layer_pe (bool): skip the PE on the first layer
129
+ """
130
+ super().__init__()
131
+ self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads)
132
+ self.norm1 = nn.LayerNorm(embedding_dim)
133
+
134
+ self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock(
135
+ embedding_dim,
136
+ num_heads,
137
+ downsample_rate=attention_downsample_rate,
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(
142
+ embedding_dim,
143
+ mlp_dim,
144
+ embedding_dim,
145
+ 1,
146
+ activation,
147
+ )
148
+
149
+ self.norm3 = nn.LayerNorm(embedding_dim)
150
+
151
+ self.norm4 = nn.LayerNorm(embedding_dim)
152
+ self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock(
153
+ embedding_dim,
154
+ num_heads,
155
+ downsample_rate=attention_downsample_rate,
156
+ )
157
+
158
+ self.skip_first_layer_pe = skip_first_layer_pe
159
+
160
+ def forward(
161
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
162
+ ) -> Tuple[Tensor, Tensor]:
163
+ # Self attention block
164
+ if not self.skip_first_layer_pe:
165
+ queries = queries + query_pe
166
+ attn_out = self.self_attn(q=queries, k=queries, v=queries)
167
+ queries = queries + attn_out
168
+ queries = self.norm1(queries)
169
+
170
+ # Cross attention block, tokens attending to image embedding
171
+ q = queries + query_pe
172
+ k = keys + key_pe
173
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
174
+ queries = queries + attn_out
175
+ queries = self.norm2(queries)
176
+
177
+ # MLP block
178
+ mlp_out = self.mlp(queries)
179
+ queries = queries + mlp_out
180
+ queries = self.norm3(queries)
181
+
182
+ # Cross attention block, image embedding attending to tokens
183
+ q = queries + query_pe
184
+ k = keys + key_pe
185
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
186
+ keys = keys + attn_out
187
+ keys = self.norm4(keys)
188
+
189
+ return queries, keys
190
+
191
+
192
+ class AttentionForTwoWayAttentionBlock(nn.Module):
193
+ """
194
+ An attention layer that allows for downscaling the size of the embedding
195
+ after projection to queries, keys, and values.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ embedding_dim: int,
201
+ num_heads: int,
202
+ downsample_rate: int = 1,
203
+ ) -> None:
204
+ super().__init__()
205
+ self.embedding_dim = embedding_dim
206
+ self.internal_dim = embedding_dim // downsample_rate
207
+ self.num_heads = num_heads
208
+ assert (
209
+ self.internal_dim % num_heads == 0
210
+ ), "num_heads must divide embedding_dim."
211
+ self.c_per_head = self.internal_dim / num_heads
212
+ self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head)
213
+
214
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
215
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
216
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
217
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
218
+ self._reset_parameters()
219
+
220
+ def _reset_parameters(self) -> None:
221
+ # The fan_out is incorrect, but matches pytorch's initialization
222
+ # for which qkv is a single 3*embedding_dim x embedding_dim matrix
223
+ fan_in = self.embedding_dim
224
+ fan_out = 3 * self.internal_dim
225
+ # Xavier uniform with our custom fan_out
226
+ bnd = math.sqrt(6 / (fan_in + fan_out))
227
+ nn.init.uniform_(self.q_proj.weight, -bnd, bnd)
228
+ nn.init.uniform_(self.k_proj.weight, -bnd, bnd)
229
+ nn.init.uniform_(self.v_proj.weight, -bnd, bnd)
230
+ # out_proj.weight is left with default initialization, like pytorch attention
231
+ nn.init.zeros_(self.q_proj.bias)
232
+ nn.init.zeros_(self.k_proj.bias)
233
+ nn.init.zeros_(self.v_proj.bias)
234
+ nn.init.zeros_(self.out_proj.bias)
235
+
236
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
237
+ b, n, c = x.shape
238
+ x = x.reshape(b, n, num_heads, c // num_heads)
239
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
240
+
241
+ def _recombine_heads(self, x: Tensor) -> Tensor:
242
+ b, n_heads, n_tokens, c_per_head = x.shape
243
+ x = x.transpose(1, 2)
244
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
245
+
246
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
247
+ # Input projections
248
+ q = self.q_proj(q)
249
+ k = self.k_proj(k)
250
+ v = self.v_proj(v)
251
+
252
+ # Separate into heads
253
+ q = self._separate_heads(q, self.num_heads)
254
+ k = self._separate_heads(k, self.num_heads)
255
+ v = self._separate_heads(v, self.num_heads)
256
+
257
+ # Attention
258
+ _, _, _, c_per_head = q.shape
259
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
260
+ attn = attn * self.inv_sqrt_c_per_head
261
+ attn = torch.softmax(attn, dim=-1)
262
+ # Get output
263
+ out = attn @ v
264
+ out = self._recombine_heads(out)
265
+ out = self.out_proj(out)
266
+ return out
model/configuration_evf.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ Evf model configuration"""
21
+
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.utils import logging
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ EVF_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
28
+
29
+
30
+ class EvfConfig(PretrainedConfig):
31
+ r"""
32
+ This is the configuration class to store the configuration of a [`EvfSam`].
33
+
34
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35
+ documentation from [`PretrainedConfig`] for more information.
36
+
37
+ Args:
38
+ hidden_size (`int`, *optional*, defaults to 4096):
39
+ Dimension of the hidden representations.
40
+ pretraining_tp (`int`, *optional*, defaults to `1`):
41
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
42
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
43
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
44
+ issue](https://github.com/pytorch/pytorch/issues/76232).
45
+ rope_scaling (`Dict`, *optional*):
46
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
47
+ strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
48
+ is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
49
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
50
+ these scaling strategies behave:
51
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
52
+ experimental feature, subject to breaking API changes in future versions.
53
+
54
+ Example:
55
+
56
+ ```python
57
+
58
+ >>> configuration = EvfConfig()
59
+ >>> model = EvfSam(configuration)
60
+
61
+ >>> # Accessing the model configuration
62
+ >>> configuration = model.config
63
+ ```"""
64
+ model_type = "evf"
65
+ keys_to_ignore_at_inference = ["past_key_values"]
66
+
67
+ def __init__(
68
+ self,
69
+ hidden_size=768,
70
+ pad_token_id=1,
71
+ bos_token_id=0,
72
+ eos_token_id=2,
73
+ pretraining_tp=1,
74
+ tie_word_embeddings=False,
75
+ rope_scaling=None,
76
+ out_dim=256,
77
+ **kwargs,
78
+ ):
79
+ self.hidden_size = hidden_size
80
+ self.out_dim = out_dim
81
+
82
+ # self.pretraining_tp = pretraining_tp
83
+ # self.rope_scaling = rope_scaling
84
+ # self._rope_scaling_validation()
85
+
86
+ super().__init__(
87
+ pad_token_id=pad_token_id,
88
+ bos_token_id=bos_token_id,
89
+ eos_token_id=eos_token_id,
90
+ tie_word_embeddings=tie_word_embeddings,
91
+ **kwargs,
92
+ )
93
+
94
+ def _rope_scaling_validation(self):
95
+ """
96
+ Validate the `rope_scaling` configuration.
97
+ """
98
+ if self.rope_scaling is None:
99
+ return
100
+
101
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
102
+ raise ValueError(
103
+ "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, "
104
+ f"got {self.rope_scaling}"
105
+ )
106
+ rope_scaling_type = self.rope_scaling.get("type", None)
107
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
108
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
109
+ raise ValueError(
110
+ f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
111
+ )
112
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
113
+ raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")
model/evf_effisam.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
7
+ from .EfficientSAM.efficient_sam.build_efficient_sam import build_efficient_sam_vits, build_efficient_sam_vitt
8
+ from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
9
+ from .configuration_evf import EvfConfig
10
+
11
+
12
+ def dice_loss(
13
+ inputs: torch.Tensor,
14
+ targets: torch.Tensor,
15
+ num_masks: float,
16
+ scale=1000, # 100000.0,
17
+ eps=1e-6,
18
+ ):
19
+ """
20
+ Compute the DICE loss, similar to generalized IOU for masks
21
+ Args:
22
+ inputs: A float tensor of arbitrary shape.
23
+ The predictions for each example.
24
+ targets: A float tensor with the same shape as inputs. Stores the binary
25
+ classification label for each element in inputs
26
+ (0 for the negative class and 1 for the positive class).
27
+ """
28
+ inputs = inputs.sigmoid()
29
+ inputs = inputs.flatten(1, 2)
30
+ targets = targets.flatten(1, 2)
31
+ numerator = 2 * (inputs / scale * targets).sum(-1)
32
+ denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
33
+ loss = 1 - (numerator + eps) / (denominator + eps)
34
+ loss = loss.sum() / (num_masks + 1e-8)
35
+ return loss
36
+
37
+
38
+ def sigmoid_ce_loss(
39
+ inputs: torch.Tensor,
40
+ targets: torch.Tensor,
41
+ num_masks: float,
42
+ ):
43
+ """
44
+ Args:
45
+ inputs: A float tensor of arbitrary shape.
46
+ The predictions for each example.
47
+ targets: A float tensor with the same shape as inputs. Stores the binary
48
+ classification label for each element in inputs
49
+ (0 for the negative class and 1 for the positive class).
50
+ Returns:
51
+ Loss tensor
52
+ """
53
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
54
+ loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
55
+ return loss
56
+
57
+
58
+
59
+ class EvfEffiSamModel(PreTrainedModel):
60
+ config_class = EvfConfig
61
+ def __init__(
62
+ self,
63
+ config,
64
+ **kwargs
65
+ ):
66
+ super(EvfEffiSamModel, self).__init__(config)
67
+
68
+ self.config = config
69
+ self.vision_pretrained = kwargs.get("vision_pretrained", None)
70
+ self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
71
+ self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
72
+ self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
73
+ self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
74
+ self.initialize_evf_modules(config)
75
+
76
+
77
+ def initialize_evf_modules(self, config):
78
+ # EffiSAM
79
+ if config.sam_scale=="tiny":
80
+ self.visual_model = build_efficient_sam_vitt(self.vision_pretrained)
81
+ elif config.sam_scale=="small":
82
+ # vits scale, or without pretrained weight (self.vision_pretrained=None)
83
+ self.visual_model = build_efficient_sam_vits(self.vision_pretrained)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ for param in self.visual_model.parameters():
88
+ param.requires_grad = False
89
+ if self.train_mask_decoder:
90
+ self.visual_model.mask_decoder.train()
91
+ for param in self.visual_model.mask_decoder.parameters():
92
+ param.requires_grad = True
93
+
94
+ # beit-3
95
+ if self.config.mm_extractor_scale == "base":
96
+ beit_config = _get_base_config()
97
+ elif self.config.mm_extractor_scale == "large":
98
+ beit_config = _get_large_config()
99
+ else:
100
+ raise AttributeError(f"model config should contain key 'mm_extractor_scale', with value 'base' or 'large'.")
101
+
102
+ self.mm_extractor = BEiT3Wrapper(beit_config)
103
+ if self.encoder_pretrained is not None:
104
+ beit_state_dict = torch.load(self.encoder_pretrained)["model"]
105
+ self.mm_extractor.load_state_dict(
106
+ beit_state_dict,
107
+ strict=False
108
+ )
109
+
110
+ for param in self.mm_extractor.parameters():
111
+ param.requires_grad = True
112
+
113
+ # Projection layer
114
+ in_dim = config.hidden_size
115
+ assert in_dim==beit_config.encoder_embed_dim, \
116
+ f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
117
+ out_dim = config.out_dim
118
+ text_fc = [
119
+ nn.Linear(in_dim, in_dim),
120
+ nn.ReLU(),
121
+ nn.Linear(in_dim, out_dim)
122
+ ]
123
+ self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
124
+ self.text_hidden_fcs.train()
125
+ for param in self.text_hidden_fcs.parameters():
126
+ param.requires_grad = True
127
+
128
+ def get_visual_embs(self, pixel_values: torch.Tensor):
129
+ with torch.no_grad():
130
+ image_embeddings_list = []
131
+ for i in range(pixel_values.shape[0]):
132
+ torch.cuda.empty_cache()
133
+ image_embeddings = self.visual_model.image_encoder(
134
+ pixel_values[i].unsqueeze(0)
135
+ )
136
+ image_embeddings_list.append(image_embeddings)
137
+ torch.cuda.empty_cache()
138
+ image_embeddings = torch.cat(image_embeddings_list, 0)
139
+ return image_embeddings
140
+
141
+ def forward(
142
+ self,
143
+ images: torch.Tensor,
144
+ images_evf: torch.Tensor,
145
+ input_ids: torch.Tensor,
146
+ attention_masks: torch.Tensor,
147
+ offset: torch.Tensor,
148
+ masks_list: List[torch.Tensor],
149
+ label_list: List[torch.Tensor],
150
+ resize_list: List[tuple],
151
+ inference: bool = False,
152
+ **kwargs,
153
+ ):
154
+ image_embeddings = self.get_visual_embs(images)
155
+ batch_size = image_embeddings.shape[0]
156
+ assert batch_size == len(offset) - 1
157
+
158
+ images_evf_list = []
159
+ for i in range(len(offset) - 1):
160
+ start_i, end_i = offset[i], offset[i + 1]
161
+ images_evf_i = (
162
+ images_evf[i]
163
+ .unsqueeze(0)
164
+ .expand(end_i - start_i, -1, -1, -1)
165
+ .contiguous()
166
+ )
167
+ images_evf_list.append(images_evf_i)
168
+ images_evf = torch.cat(images_evf_list, dim=0)
169
+
170
+ multimask_output = False
171
+ output = self.mm_extractor.beit3(
172
+ visual_tokens=images_evf,
173
+ textual_tokens=input_ids,
174
+ text_padding_position=~attention_masks
175
+ )
176
+
177
+ feat = output["encoder_out"][:, :1, ...]
178
+
179
+ feat = self.text_hidden_fcs[0](feat)
180
+ feat = torch.split(feat, [offset[i+1] - offset[i] for i in range(len(offset)-1)])
181
+
182
+ pred_masks = []
183
+ for i in range(len(feat)):
184
+ sparse_embeddings = feat[i].unsqueeze(0)
185
+ sparse_embeddings = sparse_embeddings.to(feat[i].dtype)
186
+ low_res_masks, iou_predictions = self.visual_model.mask_decoder(
187
+ image_embeddings=image_embeddings[i].unsqueeze(0),
188
+ image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
189
+ sparse_prompt_embeddings=sparse_embeddings,
190
+ multimask_output=multimask_output,
191
+ )
192
+
193
+ if multimask_output:
194
+ sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
195
+ low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)
196
+
197
+ pred_mask = self.postprocess_masks(
198
+ low_res_masks[:, :1],
199
+ input_size=resize_list[i],
200
+ original_size=label_list[i].shape,
201
+ )
202
+ pred_masks.append(pred_mask[:, 0])
203
+
204
+ gt_masks = masks_list
205
+
206
+ if inference:
207
+ return {
208
+ "pred_masks": pred_masks,
209
+ "gt_masks": gt_masks,
210
+ }
211
+
212
+ mask_bce_loss = 0
213
+ mask_dice_loss = 0
214
+ num_masks = 0
215
+ for batch_idx in range(len(pred_masks)):
216
+ gt_mask = gt_masks[batch_idx]
217
+ pred_mask = pred_masks[batch_idx]
218
+
219
+ assert (
220
+ gt_mask.shape[0] == pred_mask.shape[0]
221
+ ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
222
+ gt_mask.shape, pred_mask.shape
223
+ )
224
+ mask_bce_loss += (
225
+ sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
226
+ * gt_mask.shape[0]
227
+ )
228
+ mask_dice_loss += (
229
+ dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
230
+ * gt_mask.shape[0]
231
+ )
232
+ num_masks += gt_mask.shape[0]
233
+
234
+ mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
235
+ mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
236
+ mask_loss = mask_bce_loss + mask_dice_loss
237
+
238
+ loss = mask_loss
239
+
240
+ return {
241
+ "loss": loss,
242
+ "mask_bce_loss": mask_bce_loss,
243
+ "mask_dice_loss": mask_dice_loss,
244
+ "mask_loss": mask_loss,
245
+ }
246
+
247
+ def postprocess_masks(
248
+ self,
249
+ masks: torch.Tensor,
250
+ input_size: Tuple[int, ...],
251
+ original_size: Tuple[int, ...],
252
+ ) -> torch.Tensor:
253
+ """
254
+ pre-process of Effi-SAM is different from SAM, where there is no padding,
255
+ so cropping is not needed in post-process.
256
+ """
257
+
258
+ dtype = masks.dtype
259
+
260
+ # masks = F.interpolate(
261
+ # masks.float(),
262
+ # (1024, 1024),
263
+ # mode="bilinear",
264
+ # align_corners=False,
265
+ # )
266
+ # masks = masks.to(dtype)
267
+ # masks = masks[..., : input_size[0], : input_size[1]]
268
+
269
+ masks = F.interpolate(
270
+ masks, original_size, mode="bilinear", align_corners=False
271
+ )
272
+ masks = masks.to(dtype)
273
+ return masks
274
+
275
+ def inference(
276
+ self,
277
+ images,
278
+ images_evf,
279
+ input_ids,
280
+ resize_list,
281
+ original_size_list,
282
+ multimask_output=False,
283
+ ):
284
+ with torch.no_grad():
285
+ image_embeddings = self.visual_model.image_encoder(images)
286
+
287
+ output = self.mm_extractor.beit3(visual_tokens=images_evf, textual_tokens=input_ids, text_padding_position=torch.zeros_like(input_ids))
288
+
289
+ feat = output["encoder_out"][:, :1, ...]
290
+ feat = self.text_hidden_fcs[0](feat)
291
+ sparse_embeddings = feat.unsqueeze(0)
292
+ sparse_embeddings = sparse_embeddings.to(feat.dtype)
293
+ low_res_masks, iou_predictions = self.visual_model.mask_decoder(
294
+ image_embeddings=image_embeddings,
295
+ image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
296
+ sparse_prompt_embeddings=sparse_embeddings,
297
+ multimask_output=multimask_output,
298
+ )
299
+ if multimask_output:
300
+ sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
301
+ low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)
302
+
303
+ pred_mask = self.postprocess_masks(
304
+ low_res_masks[:, :1],
305
+ input_size=resize_list[0],
306
+ original_size=original_size_list[0],
307
+ )
308
+
309
+ return pred_mask[:, 0]
310
+
311
+
312
+ AutoConfig.register("evf", EvfConfig)
313
+ AutoModelForCausalLM.register(EvfConfig, EvfEffiSamModel)
model/evf_sam.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
7
+ from .segment_anything import build_sam_vit_h
8
+ from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
9
+ from .configuration_evf import EvfConfig
10
+
11
+ def dice_loss(
12
+ inputs: torch.Tensor,
13
+ targets: torch.Tensor,
14
+ num_masks: float,
15
+ scale=1000, # 100000.0,
16
+ eps=1e-6,
17
+ ):
18
+ """
19
+ Compute the DICE loss, similar to generalized IOU for masks
20
+ Args:
21
+ inputs: A float tensor of arbitrary shape.
22
+ The predictions for each example.
23
+ targets: A float tensor with the same shape as inputs. Stores the binary
24
+ classification label for each element in inputs
25
+ (0 for the negative class and 1 for the positive class).
26
+ """
27
+ inputs = inputs.sigmoid()
28
+ inputs = inputs.flatten(1, 2)
29
+ targets = targets.flatten(1, 2)
30
+ numerator = 2 * (inputs / scale * targets).sum(-1)
31
+ denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
32
+ loss = 1 - (numerator + eps) / (denominator + eps)
33
+ loss = loss.sum() / (num_masks + 1e-8)
34
+ return loss
35
+
36
+
37
+ def sigmoid_ce_loss(
38
+ inputs: torch.Tensor,
39
+ targets: torch.Tensor,
40
+ num_masks: float,
41
+ ):
42
+ """
43
+ Args:
44
+ inputs: A float tensor of arbitrary shape.
45
+ The predictions for each example.
46
+ targets: A float tensor with the same shape as inputs. Stores the binary
47
+ classification label for each element in inputs
48
+ (0 for the negative class and 1 for the positive class).
49
+ Returns:
50
+ Loss tensor
51
+ """
52
+ loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
53
+ loss = loss.flatten(1, 2).mean(1).sum() / (num_masks + 1e-8)
54
+ return loss
55
+
56
+
57
+
58
+ class EvfSamModel(PreTrainedModel):
59
+ config_class = EvfConfig
60
+ def __init__(
61
+ self,
62
+ config,
63
+ **kwargs
64
+ ):
65
+ super(EvfSamModel, self).__init__(config)
66
+
67
+ self.config = config
68
+ self.vision_pretrained = kwargs.get("vision_pretrained", None)
69
+ self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
70
+ self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
71
+ self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
72
+ self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
73
+ self.train_prompt_encoder = kwargs.get("train_prompt_encoder", False)
74
+ self.initialize_evf_modules(config)
75
+
76
+
77
+ def initialize_evf_modules(self, config):
78
+ # SAM
79
+ if config.sam_scale=="huge":
80
+ self.visual_model = build_sam_vit_h(self.vision_pretrained)
81
+ else:
82
+ raise NotImplementedError
83
+
84
+ for param in self.visual_model.parameters():
85
+ param.requires_grad = False
86
+ if self.train_mask_decoder:
87
+ self.visual_model.mask_decoder.train()
88
+ for param in self.visual_model.mask_decoder.parameters():
89
+ param.requires_grad = True
90
+ if self.train_prompt_encoder:
91
+ self.visual_model.prompt_encoder.no_mask_embed.requires_grad_(True)
92
+
93
+ # beit-3
94
+ if self.config.mm_extractor_scale == "base":
95
+ beit_config = _get_base_config()
96
+ elif self.config.mm_extractor_scale == "large":
97
+ beit_config = _get_large_config()
98
+ else:
99
+ raise AttributeError(f"model config should contain key 'mm_extractor_scale', with value 'base' or 'large'.")
100
+
101
+ self.mm_extractor = BEiT3Wrapper(beit_config)
102
+ if self.encoder_pretrained is not None:
103
+ beit_state_dict = torch.load(self.encoder_pretrained)["model"]
104
+ self.mm_extractor.load_state_dict(
105
+ beit_state_dict,
106
+ strict=False
107
+ )
108
+
109
+ for param in self.mm_extractor.parameters():
110
+ param.requires_grad = True
111
+
112
+ # Projection layer
113
+ in_dim = config.hidden_size
114
+ assert in_dim==beit_config.encoder_embed_dim, \
115
+ f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
116
+ out_dim = config.out_dim
117
+ text_fc = [
118
+ nn.Linear(in_dim, in_dim),
119
+ nn.ReLU(),
120
+ nn.Linear(in_dim, out_dim)
121
+ ]
122
+ self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
123
+ self.text_hidden_fcs.train()
124
+ for param in self.text_hidden_fcs.parameters():
125
+ param.requires_grad = True
126
+
127
+ def get_visual_embs(self, pixel_values: torch.FloatTensor):
128
+ with torch.no_grad():
129
+ image_embeddings_list = []
130
+ for i in range(pixel_values.shape[0]):
131
+ torch.cuda.empty_cache()
132
+ image_embeddings = self.visual_model.image_encoder(
133
+ pixel_values[i].unsqueeze(0)
134
+ )
135
+ image_embeddings_list.append(image_embeddings)
136
+ torch.cuda.empty_cache()
137
+ image_embeddings = torch.cat(image_embeddings_list, 0)
138
+ return image_embeddings
139
+
140
+ def forward(
141
+ self,
142
+ images: torch.FloatTensor,
143
+ images_evf: torch.FloatTensor,
144
+ input_ids: torch.LongTensor,
145
+ attention_masks: torch.LongTensor,
146
+ offset: torch.LongTensor,
147
+ masks_list: List[torch.FloatTensor],
148
+ label_list: List[torch.Tensor],
149
+ resize_list: List[tuple],
150
+ inference: bool = False,
151
+ **kwargs,
152
+ ):
153
+ image_embeddings = self.get_visual_embs(images)
154
+ batch_size = image_embeddings.shape[0]
155
+ assert batch_size == len(offset) - 1
156
+
157
+ images_evf_list = []
158
+ for i in range(len(offset) - 1):
159
+ start_i, end_i = offset[i], offset[i + 1]
160
+ images_evf_i = (
161
+ images_evf[i]
162
+ .unsqueeze(0)
163
+ .expand(end_i - start_i, -1, -1, -1)
164
+ .contiguous()
165
+ )
166
+ images_evf_list.append(images_evf_i)
167
+ images_evf = torch.cat(images_evf_list, dim=0)
168
+
169
+ multimask_output = False
170
+ output = self.mm_extractor.beit3(
171
+ visual_tokens=images_evf,
172
+ textual_tokens=input_ids,
173
+ text_padding_position=~attention_masks
174
+ )
175
+
176
+ feat = output["encoder_out"][:, :1, ...]
177
+
178
+ feat = self.text_hidden_fcs[0](feat)
179
+ feat = torch.split(feat, [offset[i+1] - offset[i] for i in range(len(offset)-1)])
180
+
181
+ pred_masks = []
182
+ for i in range(len(feat)):
183
+ (
184
+ sparse_embeddings,
185
+ dense_embeddings,
186
+ ) = self.visual_model.prompt_encoder(
187
+ points=None,
188
+ boxes=None,
189
+ masks=None,
190
+ text_embeds=feat[i],
191
+ )
192
+ sparse_embeddings = sparse_embeddings.to(feat[i].dtype)
193
+ low_res_masks, iou_predictions = self.visual_model.mask_decoder(
194
+ image_embeddings=image_embeddings[i].unsqueeze(0),
195
+ image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
196
+ sparse_prompt_embeddings=sparse_embeddings,
197
+ dense_prompt_embeddings=dense_embeddings,
198
+ multimask_output=multimask_output,
199
+ )
200
+
201
+ if multimask_output:
202
+ sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
203
+ low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
204
+
205
+ pred_mask = self.visual_model.postprocess_masks(
206
+ low_res_masks,
207
+ input_size=resize_list[i],
208
+ original_size=label_list[i].shape,
209
+ )
210
+ pred_masks.append(pred_mask[:, 0])
211
+
212
+ gt_masks = masks_list
213
+
214
+ if inference:
215
+ return {
216
+ "pred_masks": pred_masks,
217
+ "gt_masks": gt_masks,
218
+ }
219
+
220
+ mask_bce_loss = 0
221
+ mask_dice_loss = 0
222
+ num_masks = 0
223
+ for batch_idx in range(len(pred_masks)):
224
+ gt_mask = gt_masks[batch_idx]
225
+ pred_mask = pred_masks[batch_idx]
226
+
227
+ assert (
228
+ gt_mask.shape[0] == pred_mask.shape[0]
229
+ ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
230
+ gt_mask.shape, pred_mask.shape
231
+ )
232
+ mask_bce_loss += (
233
+ sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
234
+ * gt_mask.shape[0]
235
+ )
236
+ mask_dice_loss += (
237
+ dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
238
+ * gt_mask.shape[0]
239
+ )
240
+ num_masks += gt_mask.shape[0]
241
+
242
+ mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
243
+ mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
244
+ mask_loss = mask_bce_loss + mask_dice_loss
245
+
246
+ loss = mask_loss
247
+
248
+ return {
249
+ "loss": loss,
250
+ "mask_bce_loss": mask_bce_loss,
251
+ "mask_dice_loss": mask_dice_loss,
252
+ "mask_loss": mask_loss,
253
+ }
254
+
255
+ def inference(
256
+ self,
257
+ images,
258
+ images_evf,
259
+ input_ids,
260
+ resize_list,
261
+ original_size_list,
262
+ multimask_output=False,
263
+ ):
264
+ with torch.no_grad():
265
+ image_embeddings = self.visual_model.image_encoder(images)
266
+ multimask_output = multimask_output
267
+
268
+ output = self.mm_extractor.beit3(visual_tokens=images_evf, textual_tokens=input_ids, text_padding_position=torch.zeros_like(input_ids))
269
+
270
+ feat = output["encoder_out"][:, :1, ...]
271
+ feat = self.text_hidden_fcs[0](feat)
272
+ (
273
+ sparse_embeddings,
274
+ dense_embeddings,
275
+ ) = self.visual_model.prompt_encoder(
276
+ points=None,
277
+ boxes=None,
278
+ masks=None,
279
+ text_embeds=feat,
280
+ )
281
+ sparse_embeddings = sparse_embeddings.to(feat.dtype)
282
+ low_res_masks, iou_predictions = self.visual_model.mask_decoder(
283
+ image_embeddings=image_embeddings,
284
+ image_pe=self.visual_model.prompt_encoder.get_dense_pe(),
285
+ sparse_prompt_embeddings=sparse_embeddings,
286
+ dense_prompt_embeddings=dense_embeddings,
287
+ multimask_output=multimask_output,
288
+ )
289
+ if multimask_output:
290
+ sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
291
+ low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
292
+
293
+ pred_mask = self.visual_model.postprocess_masks(
294
+ low_res_masks,
295
+ input_size=resize_list[0],
296
+ original_size=original_size_list[0],
297
+ )
298
+
299
+ return pred_mask[:, 0]
300
+
301
+
302
+ AutoConfig.register("evf", EvfConfig)
303
+ AutoModelForCausalLM.register(EvfConfig, EvfSamModel)
model/segment_anything/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .automatic_mask_generator import SamAutomaticMaskGenerator
8
+ from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
9
+ build_sam_vit_l, sam_model_registry)
10
+ from .predictor import SamPredictor
model/segment_anything/automatic_mask_generator.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torchvision.ops.boxes import batched_nms, box_area # type: ignore
12
+
13
+ from .modeling import Sam
14
+ from .predictor import SamPredictor
15
+ from .utils.amg import (MaskData, area_from_rle, batch_iterator,
16
+ batched_mask_to_box, box_xyxy_to_xywh,
17
+ build_all_layer_point_grids, calculate_stability_score,
18
+ coco_encode_rle, generate_crop_boxes,
19
+ is_box_near_crop_edge, mask_to_rle_pytorch,
20
+ remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
21
+ uncrop_masks, uncrop_points)
22
+
23
+
24
+ class SamAutomaticMaskGenerator:
25
+ def __init__(
26
+ self,
27
+ model: Sam,
28
+ points_per_side: Optional[int] = 32,
29
+ points_per_batch: int = 64,
30
+ pred_iou_thresh: float = 0.88,
31
+ stability_score_thresh: float = 0.95,
32
+ stability_score_offset: float = 1.0,
33
+ box_nms_thresh: float = 0.7,
34
+ crop_n_layers: int = 0,
35
+ crop_nms_thresh: float = 0.7,
36
+ crop_overlap_ratio: float = 512 / 1500,
37
+ crop_n_points_downscale_factor: int = 1,
38
+ point_grids: Optional[List[np.ndarray]] = None,
39
+ min_mask_region_area: int = 0,
40
+ output_mode: str = "binary_mask",
41
+ ) -> None:
42
+ """
43
+ Using a SAM model, generates masks for the entire image.
44
+ Generates a grid of point prompts over the image, then filters
45
+ low quality and duplicate masks. The default settings are chosen
46
+ for SAM with a ViT-H backbone.
47
+
48
+ Arguments:
49
+ model (Sam): The SAM model to use for mask prediction.
50
+ points_per_side (int or None): The number of points to be sampled
51
+ along one side of the image. The total number of points is
52
+ points_per_side**2. If None, 'point_grids' must provide explicit
53
+ point sampling.
54
+ points_per_batch (int): Sets the number of points run simultaneously
55
+ by the model. Higher numbers may be faster but use more GPU memory.
56
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
57
+ model's predicted mask quality.
58
+ stability_score_thresh (float): A filtering threshold in [0,1], using
59
+ the stability of the mask under changes to the cutoff used to binarize
60
+ the model's mask predictions.
61
+ stability_score_offset (float): The amount to shift the cutoff when
62
+ calculated the stability score.
63
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
64
+ suppression to filter duplicate masks.
65
+ crop_n_layers (int): If >0, mask prediction will be run again on
66
+ crops of the image. Sets the number of layers to run, where each
67
+ layer has 2**i_layer number of image crops.
68
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
69
+ suppression to filter duplicate masks between different crops.
70
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
71
+ In the first crop layer, crops will overlap by this fraction of
72
+ the image length. Later layers with more crops scale down this overlap.
73
+ crop_n_points_downscale_factor (int): The number of points-per-side
74
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
75
+ point_grids (list(np.ndarray) or None): A list over explicit grids
76
+ of points used for sampling, normalized to [0,1]. The nth grid in the
77
+ list is used in the nth crop layer. Exclusive with points_per_side.
78
+ min_mask_region_area (int): If >0, postprocessing will be applied
79
+ to remove disconnected regions and holes in masks with area smaller
80
+ than min_mask_region_area. Requires opencv.
81
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
82
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
83
+ For large resolutions, 'binary_mask' may consume large amounts of
84
+ memory.
85
+ """
86
+
87
+ assert (points_per_side is None) != (
88
+ point_grids is None
89
+ ), "Exactly one of points_per_side or point_grid must be provided."
90
+ if points_per_side is not None:
91
+ self.point_grids = build_all_layer_point_grids(
92
+ points_per_side,
93
+ crop_n_layers,
94
+ crop_n_points_downscale_factor,
95
+ )
96
+ elif point_grids is not None:
97
+ self.point_grids = point_grids
98
+ else:
99
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
100
+
101
+ assert output_mode in [
102
+ "binary_mask",
103
+ "uncompressed_rle",
104
+ "coco_rle",
105
+ ], f"Unknown output_mode {output_mode}."
106
+ if output_mode == "coco_rle":
107
+ from pycocotools import \
108
+ mask as mask_utils # type: ignore # noqa: F401
109
+
110
+ if min_mask_region_area > 0:
111
+ import cv2 # type: ignore # noqa: F401
112
+
113
+ self.predictor = SamPredictor(model)
114
+ self.points_per_batch = points_per_batch
115
+ self.pred_iou_thresh = pred_iou_thresh
116
+ self.stability_score_thresh = stability_score_thresh
117
+ self.stability_score_offset = stability_score_offset
118
+ self.box_nms_thresh = box_nms_thresh
119
+ self.crop_n_layers = crop_n_layers
120
+ self.crop_nms_thresh = crop_nms_thresh
121
+ self.crop_overlap_ratio = crop_overlap_ratio
122
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
123
+ self.min_mask_region_area = min_mask_region_area
124
+ self.output_mode = output_mode
125
+
126
+ @torch.no_grad()
127
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
128
+ """
129
+ Generates masks for the given image.
130
+
131
+ Arguments:
132
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
133
+
134
+ Returns:
135
+ list(dict(str, any)): A list over records for masks. Each record is
136
+ a dict containing the following keys:
137
+ segmentation (dict(str, any) or np.ndarray): The mask. If
138
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
139
+ is a dictionary containing the RLE.
140
+ bbox (list(float)): The box around the mask, in XYWH format.
141
+ area (int): The area in pixels of the mask.
142
+ predicted_iou (float): The model's own prediction of the mask's
143
+ quality. This is filtered by the pred_iou_thresh parameter.
144
+ point_coords (list(list(float))): The point coordinates input
145
+ to the model to generate this mask.
146
+ stability_score (float): A measure of the mask's quality. This
147
+ is filtered on using the stability_score_thresh parameter.
148
+ crop_box (list(float)): The crop of the image used to generate
149
+ the mask, given in XYWH format.
150
+ """
151
+
152
+ # Generate masks
153
+ mask_data = self._generate_masks(image)
154
+
155
+ # Filter small disconnected regions and holes in masks
156
+ if self.min_mask_region_area > 0:
157
+ mask_data = self.postprocess_small_regions(
158
+ mask_data,
159
+ self.min_mask_region_area,
160
+ max(self.box_nms_thresh, self.crop_nms_thresh),
161
+ )
162
+
163
+ # Encode masks
164
+ if self.output_mode == "coco_rle":
165
+ mask_data["segmentations"] = [
166
+ coco_encode_rle(rle) for rle in mask_data["rles"]
167
+ ]
168
+ elif self.output_mode == "binary_mask":
169
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
170
+ else:
171
+ mask_data["segmentations"] = mask_data["rles"]
172
+
173
+ # Write mask records
174
+ curr_anns = []
175
+ for idx in range(len(mask_data["segmentations"])):
176
+ ann = {
177
+ "segmentation": mask_data["segmentations"][idx],
178
+ "area": area_from_rle(mask_data["rles"][idx]),
179
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
180
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
181
+ "point_coords": [mask_data["points"][idx].tolist()],
182
+ "stability_score": mask_data["stability_score"][idx].item(),
183
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
184
+ }
185
+ curr_anns.append(ann)
186
+
187
+ return curr_anns
188
+
189
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
190
+ orig_size = image.shape[:2]
191
+ crop_boxes, layer_idxs = generate_crop_boxes(
192
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
193
+ )
194
+
195
+ # Iterate over image crops
196
+ data = MaskData()
197
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
198
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
199
+ data.cat(crop_data)
200
+
201
+ # Remove duplicate masks between crops
202
+ if len(crop_boxes) > 1:
203
+ # Prefer masks from smaller crops
204
+ scores = 1 / box_area(data["crop_boxes"])
205
+ scores = scores.to(data["boxes"].device)
206
+ keep_by_nms = batched_nms(
207
+ data["boxes"].float(),
208
+ scores,
209
+ torch.zeros_like(data["boxes"][:, 0]), # categories
210
+ iou_threshold=self.crop_nms_thresh,
211
+ )
212
+ data.filter(keep_by_nms)
213
+
214
+ data.to_numpy()
215
+ return data
216
+
217
+ def _process_crop(
218
+ self,
219
+ image: np.ndarray,
220
+ crop_box: List[int],
221
+ crop_layer_idx: int,
222
+ orig_size: Tuple[int, ...],
223
+ ) -> MaskData:
224
+ # Crop the image and calculate embeddings
225
+ x0, y0, x1, y1 = crop_box
226
+ cropped_im = image[y0:y1, x0:x1, :]
227
+ cropped_im_size = cropped_im.shape[:2]
228
+ self.predictor.set_image(cropped_im)
229
+
230
+ # Get points for this crop
231
+ points_scale = np.array(cropped_im_size)[None, ::-1]
232
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
233
+
234
+ # Generate masks for this crop in batches
235
+ data = MaskData()
236
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
237
+ batch_data = self._process_batch(
238
+ points, cropped_im_size, crop_box, orig_size
239
+ )
240
+ data.cat(batch_data)
241
+ del batch_data
242
+ self.predictor.reset_image()
243
+
244
+ # Remove duplicates within this crop.
245
+ keep_by_nms = batched_nms(
246
+ data["boxes"].float(),
247
+ data["iou_preds"],
248
+ torch.zeros_like(data["boxes"][:, 0]), # categories
249
+ iou_threshold=self.box_nms_thresh,
250
+ )
251
+ data.filter(keep_by_nms)
252
+
253
+ # Return to the original image frame
254
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
255
+ data["points"] = uncrop_points(data["points"], crop_box)
256
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
257
+
258
+ return data
259
+
260
+ def _process_batch(
261
+ self,
262
+ points: np.ndarray,
263
+ im_size: Tuple[int, ...],
264
+ crop_box: List[int],
265
+ orig_size: Tuple[int, ...],
266
+ ) -> MaskData:
267
+ orig_h, orig_w = orig_size
268
+
269
+ # Run model on this batch
270
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
271
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
272
+ in_labels = torch.ones(
273
+ in_points.shape[0], dtype=torch.int, device=in_points.device
274
+ )
275
+ masks, iou_preds, _ = self.predictor.predict_torch(
276
+ in_points[:, None, :],
277
+ in_labels[:, None],
278
+ multimask_output=True,
279
+ return_logits=True,
280
+ )
281
+
282
+ # Serialize predictions and store in MaskData
283
+ data = MaskData(
284
+ masks=masks.flatten(0, 1),
285
+ iou_preds=iou_preds.flatten(0, 1),
286
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
287
+ )
288
+ del masks
289
+
290
+ # Filter by predicted IoU
291
+ if self.pred_iou_thresh > 0.0:
292
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
293
+ data.filter(keep_mask)
294
+
295
+ # Calculate stability score
296
+ data["stability_score"] = calculate_stability_score(
297
+ data["masks"],
298
+ self.predictor.model.mask_threshold,
299
+ self.stability_score_offset,
300
+ )
301
+ if self.stability_score_thresh > 0.0:
302
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
303
+ data.filter(keep_mask)
304
+
305
+ # Threshold masks and calculate boxes
306
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
307
+ data["boxes"] = batched_mask_to_box(data["masks"])
308
+
309
+ # Filter boxes that touch crop boundaries
310
+ keep_mask = ~is_box_near_crop_edge(
311
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
312
+ )
313
+ if not torch.all(keep_mask):
314
+ data.filter(keep_mask)
315
+
316
+ # Compress to RLE
317
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
318
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
319
+ del data["masks"]
320
+
321
+ return data
322
+
323
+ @staticmethod
324
+ def postprocess_small_regions(
325
+ mask_data: MaskData, min_area: int, nms_thresh: float
326
+ ) -> MaskData:
327
+ """
328
+ Removes small disconnected regions and holes in masks, then reruns
329
+ box NMS to remove any new duplicates.
330
+
331
+ Edits mask_data in place.
332
+
333
+ Requires open-cv as a dependency.
334
+ """
335
+ if len(mask_data["rles"]) == 0:
336
+ return mask_data
337
+
338
+ # Filter small disconnected regions and holes
339
+ new_masks = []
340
+ scores = []
341
+ for rle in mask_data["rles"]:
342
+ mask = rle_to_mask(rle)
343
+
344
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
345
+ unchanged = not changed
346
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
347
+ unchanged = unchanged and not changed
348
+
349
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
350
+ # Give score=0 to changed masks and score=1 to unchanged masks
351
+ # so NMS will prefer ones that didn't need postprocessing
352
+ scores.append(float(unchanged))
353
+
354
+ # Recalculate boxes and remove any new duplicates
355
+ masks = torch.cat(new_masks, dim=0)
356
+ boxes = batched_mask_to_box(masks)
357
+ keep_by_nms = batched_nms(
358
+ boxes.float(),
359
+ torch.as_tensor(scores),
360
+ torch.zeros_like(boxes[:, 0]), # categories
361
+ iou_threshold=nms_thresh,
362
+ )
363
+
364
+ # Only recalculate RLEs for masks that have changed
365
+ for i_mask in keep_by_nms:
366
+ if scores[i_mask] == 0.0:
367
+ mask_torch = masks[i_mask].unsqueeze(0)
368
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
369
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
370
+ mask_data.filter(keep_by_nms)
371
+
372
+ return mask_data
model/segment_anything/build_sam.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from functools import partial
8
+
9
+ import torch
10
+
11
+ from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam,
12
+ TwoWayTransformer)
13
+
14
+
15
+ def build_sam_vit_h(checkpoint=None):
16
+ return _build_sam(
17
+ encoder_embed_dim=1280,
18
+ encoder_depth=32,
19
+ encoder_num_heads=16,
20
+ encoder_global_attn_indexes=[7, 15, 23, 31],
21
+ checkpoint=checkpoint,
22
+ )
23
+
24
+
25
+ build_sam = build_sam_vit_h
26
+
27
+
28
+ def build_sam_vit_l(checkpoint=None):
29
+ return _build_sam(
30
+ encoder_embed_dim=1024,
31
+ encoder_depth=24,
32
+ encoder_num_heads=16,
33
+ encoder_global_attn_indexes=[5, 11, 17, 23],
34
+ checkpoint=checkpoint,
35
+ )
36
+
37
+
38
+ def build_sam_vit_b(checkpoint=None):
39
+ return _build_sam(
40
+ encoder_embed_dim=768,
41
+ encoder_depth=12,
42
+ encoder_num_heads=12,
43
+ encoder_global_attn_indexes=[2, 5, 8, 11],
44
+ checkpoint=checkpoint,
45
+ )
46
+
47
+
48
+ sam_model_registry = {
49
+ "default": build_sam_vit_h,
50
+ "vit_h": build_sam_vit_h,
51
+ "vit_l": build_sam_vit_l,
52
+ "vit_b": build_sam_vit_b,
53
+ }
54
+
55
+
56
+ def _build_sam(
57
+ encoder_embed_dim,
58
+ encoder_depth,
59
+ encoder_num_heads,
60
+ encoder_global_attn_indexes,
61
+ checkpoint=None,
62
+ ):
63
+ prompt_embed_dim = 256
64
+ image_size = 1024
65
+ vit_patch_size = 16
66
+ image_embedding_size = image_size // vit_patch_size
67
+ sam = Sam(
68
+ image_encoder=ImageEncoderViT(
69
+ depth=encoder_depth,
70
+ embed_dim=encoder_embed_dim,
71
+ img_size=image_size,
72
+ mlp_ratio=4,
73
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
74
+ num_heads=encoder_num_heads,
75
+ patch_size=vit_patch_size,
76
+ qkv_bias=True,
77
+ use_rel_pos=True,
78
+ global_attn_indexes=encoder_global_attn_indexes,
79
+ window_size=14,
80
+ out_chans=prompt_embed_dim,
81
+ ),
82
+ prompt_encoder=PromptEncoder(
83
+ embed_dim=prompt_embed_dim,
84
+ image_embedding_size=(image_embedding_size, image_embedding_size),
85
+ input_image_size=(image_size, image_size),
86
+ mask_in_chans=16,
87
+ ),
88
+ mask_decoder=MaskDecoder(
89
+ num_multimask_outputs=3,
90
+ transformer=TwoWayTransformer(
91
+ depth=2,
92
+ embedding_dim=prompt_embed_dim,
93
+ mlp_dim=2048,
94
+ num_heads=8,
95
+ ),
96
+ transformer_dim=prompt_embed_dim,
97
+ iou_head_depth=3,
98
+ iou_head_hidden_dim=256,
99
+ ),
100
+ pixel_mean=[123.675, 116.28, 103.53],
101
+ pixel_std=[58.395, 57.12, 57.375],
102
+ )
103
+ sam.eval()
104
+ if checkpoint is not None:
105
+ with open(checkpoint, "rb") as f:
106
+ state_dict = torch.load(f)
107
+ sam.load_state_dict(state_dict, strict=False)
108
+ return sam
model/segment_anything/modeling/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .image_encoder import ImageEncoderViT
8
+ from .mask_decoder import MaskDecoder
9
+ from .prompt_encoder import PromptEncoder
10
+ from .sam import Sam
11
+ from .transformer import TwoWayTransformer
model/segment_anything/modeling/common.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Type
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ class MLPBlock(nn.Module):
14
+ def __init__(
15
+ self,
16
+ embedding_dim: int,
17
+ mlp_dim: int,
18
+ act: Type[nn.Module] = nn.GELU,
19
+ ) -> None:
20
+ super().__init__()
21
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23
+ self.act = act()
24
+
25
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
26
+ return self.lin2(self.act(self.lin1(x)))
27
+
28
+
29
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31
+ class LayerNorm2d(nn.Module):
32
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33
+ super().__init__()
34
+ self.weight = nn.Parameter(torch.ones(num_channels))
35
+ self.bias = nn.Parameter(torch.zeros(num_channels))
36
+ self.eps = eps
37
+
38
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
39
+ u = x.mean(1, keepdim=True)
40
+ s = (x - u).pow(2).mean(1, keepdim=True)
41
+ x = (x - u) / torch.sqrt(s + self.eps)
42
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
43
+ return x
model/segment_anything/modeling/image_encoder.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple, Type
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from .common import LayerNorm2d, MLPBlock
14
+
15
+
16
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17
+ class ImageEncoderViT(nn.Module):
18
+ def __init__(
19
+ self,
20
+ img_size: int = 1024,
21
+ patch_size: int = 16,
22
+ in_chans: int = 3,
23
+ embed_dim: int = 768,
24
+ depth: int = 12,
25
+ num_heads: int = 12,
26
+ mlp_ratio: float = 4.0,
27
+ out_chans: int = 256,
28
+ qkv_bias: bool = True,
29
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
30
+ act_layer: Type[nn.Module] = nn.GELU,
31
+ use_abs_pos: bool = True,
32
+ use_rel_pos: bool = False,
33
+ rel_pos_zero_init: bool = True,
34
+ window_size: int = 0,
35
+ global_attn_indexes: Tuple[int, ...] = (),
36
+ ) -> None:
37
+ """
38
+ Args:
39
+ img_size (int): Input image size.
40
+ patch_size (int): Patch size.
41
+ in_chans (int): Number of input image channels.
42
+ embed_dim (int): Patch embedding dimension.
43
+ depth (int): Depth of ViT.
44
+ num_heads (int): Number of attention heads in each ViT block.
45
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
47
+ norm_layer (nn.Module): Normalization layer.
48
+ act_layer (nn.Module): Activation layer.
49
+ use_abs_pos (bool): If True, use absolute positional embeddings.
50
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52
+ window_size (int): Window size for window attention blocks.
53
+ global_attn_indexes (list): Indexes for blocks using global attention.
54
+ """
55
+ super().__init__()
56
+ self.img_size = img_size
57
+ self.embed_dim = embed_dim
58
+ self.out_chans = out_chans
59
+
60
+ self.patch_embed = PatchEmbed(
61
+ kernel_size=(patch_size, patch_size),
62
+ stride=(patch_size, patch_size),
63
+ in_chans=in_chans,
64
+ embed_dim=embed_dim,
65
+ )
66
+
67
+ self.pos_embed: Optional[nn.Parameter] = None
68
+ if use_abs_pos:
69
+ # Initialize absolute positional embedding with pretrain image size.
70
+ self.pos_embed = nn.Parameter(
71
+ torch.zeros(
72
+ 1, img_size // patch_size, img_size // patch_size, embed_dim
73
+ )
74
+ )
75
+
76
+ self.blocks = nn.ModuleList()
77
+ for i in range(depth):
78
+ block = Block(
79
+ dim=embed_dim,
80
+ num_heads=num_heads,
81
+ mlp_ratio=mlp_ratio,
82
+ qkv_bias=qkv_bias,
83
+ norm_layer=norm_layer,
84
+ act_layer=act_layer,
85
+ use_rel_pos=use_rel_pos,
86
+ rel_pos_zero_init=rel_pos_zero_init,
87
+ window_size=window_size if i not in global_attn_indexes else 0,
88
+ input_size=(img_size // patch_size, img_size // patch_size),
89
+ )
90
+ self.blocks.append(block)
91
+
92
+ self.neck = nn.Sequential(
93
+ nn.Conv2d(
94
+ embed_dim,
95
+ out_chans,
96
+ kernel_size=1,
97
+ bias=False,
98
+ ),
99
+ LayerNorm2d(out_chans),
100
+ nn.Conv2d(
101
+ out_chans,
102
+ out_chans,
103
+ kernel_size=3,
104
+ padding=1,
105
+ bias=False,
106
+ ),
107
+ LayerNorm2d(out_chans),
108
+ )
109
+
110
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
111
+ x = self.patch_embed(x)
112
+ if self.pos_embed is not None:
113
+ x = x + self.pos_embed
114
+
115
+ for blk in self.blocks:
116
+ x = blk(x)
117
+
118
+ dtype = x.dtype
119
+ if dtype == torch.float16: # prevent overflow
120
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
121
+ x = self.neck(x.permute(0, 3, 1, 2))
122
+ x = x.to(dtype)
123
+ else:
124
+ x = self.neck(x.permute(0, 3, 1, 2))
125
+ return x
126
+
127
+
128
+ class Block(nn.Module):
129
+ """Transformer blocks with support of window attention and residual propagation blocks"""
130
+
131
+ def __init__(
132
+ self,
133
+ dim: int,
134
+ num_heads: int,
135
+ mlp_ratio: float = 4.0,
136
+ qkv_bias: bool = True,
137
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
138
+ act_layer: Type[nn.Module] = nn.GELU,
139
+ use_rel_pos: bool = False,
140
+ rel_pos_zero_init: bool = True,
141
+ window_size: int = 0,
142
+ input_size: Optional[Tuple[int, int]] = None,
143
+ ) -> None:
144
+ """
145
+ Args:
146
+ dim (int): Number of input channels.
147
+ num_heads (int): Number of attention heads in each ViT block.
148
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
149
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
150
+ norm_layer (nn.Module): Normalization layer.
151
+ act_layer (nn.Module): Activation layer.
152
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
153
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
154
+ window_size (int): Window size for window attention blocks. If it equals 0, then
155
+ use global attention.
156
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
157
+ positional parameter size.
158
+ """
159
+ super().__init__()
160
+ self.norm1 = norm_layer(dim)
161
+ self.attn = Attention(
162
+ dim,
163
+ num_heads=num_heads,
164
+ qkv_bias=qkv_bias,
165
+ use_rel_pos=use_rel_pos,
166
+ rel_pos_zero_init=rel_pos_zero_init,
167
+ input_size=input_size if window_size == 0 else (window_size, window_size),
168
+ )
169
+
170
+ self.norm2 = norm_layer(dim)
171
+ self.mlp = MLPBlock(
172
+ embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
173
+ )
174
+
175
+ self.window_size = window_size
176
+
177
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
178
+ shortcut = x
179
+ x = self.norm1(x)
180
+ # Window partition
181
+ if self.window_size > 0:
182
+ H, W = x.shape[1], x.shape[2]
183
+ x, pad_hw = window_partition(x, self.window_size)
184
+
185
+ x = self.attn(x)
186
+ # Reverse window partition
187
+ if self.window_size > 0:
188
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
189
+
190
+ x = shortcut + x
191
+ x = x + self.mlp(self.norm2(x))
192
+
193
+ return x
194
+
195
+
196
+ class Attention(nn.Module):
197
+ """Multi-head Attention block with relative position embeddings."""
198
+
199
+ def __init__(
200
+ self,
201
+ dim: int,
202
+ num_heads: int = 8,
203
+ qkv_bias: bool = True,
204
+ use_rel_pos: bool = False,
205
+ rel_pos_zero_init: bool = True,
206
+ input_size: Optional[Tuple[int, int]] = None,
207
+ ) -> None:
208
+ """
209
+ Args:
210
+ dim (int): Number of input channels.
211
+ num_heads (int): Number of attention heads.
212
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
213
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
214
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
215
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
216
+ positional parameter size.
217
+ """
218
+ super().__init__()
219
+ self.num_heads = num_heads
220
+ head_dim = dim // num_heads
221
+ self.scale = head_dim**-0.5
222
+
223
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
224
+ self.proj = nn.Linear(dim, dim)
225
+
226
+ self.use_rel_pos = use_rel_pos
227
+ if self.use_rel_pos:
228
+ assert (
229
+ input_size is not None
230
+ ), "Input size must be provided if using relative positional encoding."
231
+ # initialize relative positional embeddings
232
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
233
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ B, H, W, _ = x.shape
237
+ # qkv with shape (3, B, nHead, H * W, C)
238
+ qkv = (
239
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
240
+ )
241
+ # q, k, v with shape (B * nHead, H * W, C)
242
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
243
+
244
+ attn = (q * self.scale) @ k.transpose(-2, -1)
245
+
246
+ if self.use_rel_pos:
247
+ attn = add_decomposed_rel_pos(
248
+ attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
249
+ )
250
+
251
+ attn = attn.softmax(dim=-1)
252
+ x = (
253
+ (attn @ v)
254
+ .view(B, self.num_heads, H, W, -1)
255
+ .permute(0, 2, 3, 1, 4)
256
+ .reshape(B, H, W, -1)
257
+ )
258
+ x = self.proj(x)
259
+
260
+ return x
261
+
262
+
263
+ def window_partition(
264
+ x: torch.Tensor, window_size: int
265
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
266
+ """
267
+ Partition into non-overlapping windows with padding if needed.
268
+ Args:
269
+ x (tensor): input tokens with [B, H, W, C].
270
+ window_size (int): window size.
271
+
272
+ Returns:
273
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
274
+ (Hp, Wp): padded height and width before partition
275
+ """
276
+ B, H, W, C = x.shape
277
+
278
+ pad_h = (window_size - H % window_size) % window_size
279
+ pad_w = (window_size - W % window_size) % window_size
280
+ if pad_h > 0 or pad_w > 0:
281
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
282
+ Hp, Wp = H + pad_h, W + pad_w
283
+
284
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
285
+ windows = (
286
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
287
+ )
288
+ return windows, (Hp, Wp)
289
+
290
+
291
+ def window_unpartition(
292
+ windows: torch.Tensor,
293
+ window_size: int,
294
+ pad_hw: Tuple[int, int],
295
+ hw: Tuple[int, int],
296
+ ) -> torch.Tensor:
297
+ """
298
+ Window unpartition into original sequences and removing padding.
299
+ Args:
300
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
301
+ window_size (int): window size.
302
+ pad_hw (Tuple): padded height and width (Hp, Wp).
303
+ hw (Tuple): original height and width (H, W) before padding.
304
+
305
+ Returns:
306
+ x: unpartitioned sequences with [B, H, W, C].
307
+ """
308
+ Hp, Wp = pad_hw
309
+ H, W = hw
310
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
311
+ x = windows.view(
312
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
313
+ )
314
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
315
+
316
+ if Hp > H or Wp > W:
317
+ x = x[:, :H, :W, :].contiguous()
318
+ return x
319
+
320
+
321
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
322
+ """
323
+ Get relative positional embeddings according to the relative positions of
324
+ query and key sizes.
325
+ Args:
326
+ q_size (int): size of query q.
327
+ k_size (int): size of key k.
328
+ rel_pos (Tensor): relative position embeddings (L, C).
329
+
330
+ Returns:
331
+ Extracted positional embeddings according to relative positions.
332
+ """
333
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
334
+ # Interpolate rel pos if needed.
335
+ if rel_pos.shape[0] != max_rel_dist:
336
+ # Interpolate rel pos.
337
+ rel_pos_resized = F.interpolate(
338
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
339
+ size=max_rel_dist,
340
+ mode="linear",
341
+ )
342
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
343
+ else:
344
+ rel_pos_resized = rel_pos
345
+
346
+ # Scale the coords with short length if shapes for q and k are different.
347
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
348
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
349
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
350
+
351
+ return rel_pos_resized[relative_coords.long()]
352
+
353
+
354
+ def add_decomposed_rel_pos(
355
+ attn: torch.Tensor,
356
+ q: torch.Tensor,
357
+ rel_pos_h: torch.Tensor,
358
+ rel_pos_w: torch.Tensor,
359
+ q_size: Tuple[int, int],
360
+ k_size: Tuple[int, int],
361
+ ) -> torch.Tensor:
362
+ """
363
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
364
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
365
+ Args:
366
+ attn (Tensor): attention map.
367
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
368
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
369
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
370
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
371
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
372
+
373
+ Returns:
374
+ attn (Tensor): attention map with added relative positional embeddings.
375
+ """
376
+ q_h, q_w = q_size
377
+ k_h, k_w = k_size
378
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
379
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
380
+
381
+ B, _, dim = q.shape
382
+ r_q = q.reshape(B, q_h, q_w, dim)
383
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
384
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
385
+
386
+ attn = (
387
+ attn.view(B, q_h, q_w, k_h, k_w)
388
+ + rel_h[:, :, :, :, None]
389
+ + rel_w[:, :, :, None, :]
390
+ ).view(B, q_h * q_w, k_h * k_w)
391
+
392
+ return attn
393
+
394
+
395
+ class PatchEmbed(nn.Module):
396
+ """
397
+ Image to Patch Embedding.
398
+ """
399
+
400
+ def __init__(
401
+ self,
402
+ kernel_size: Tuple[int, int] = (16, 16),
403
+ stride: Tuple[int, int] = (16, 16),
404
+ padding: Tuple[int, int] = (0, 0),
405
+ in_chans: int = 3,
406
+ embed_dim: int = 768,
407
+ ) -> None:
408
+ """
409
+ Args:
410
+ kernel_size (Tuple): kernel size of the projection layer.
411
+ stride (Tuple): stride of the projection layer.
412
+ padding (Tuple): padding size of the projection layer.
413
+ in_chans (int): Number of input image channels.
414
+ embed_dim (int): Patch embedding dimension.
415
+ """
416
+ super().__init__()
417
+
418
+ self.proj = nn.Conv2d(
419
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
420
+ )
421
+
422
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
423
+ x = self.proj(x)
424
+ # B C H W -> B H W C
425
+ x = x.permute(0, 2, 3, 1)
426
+ return x
model/segment_anything/modeling/mask_decoder.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import List, Tuple, Type
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class MaskDecoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ *,
20
+ transformer_dim: int,
21
+ transformer: nn.Module,
22
+ num_multimask_outputs: int = 3,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ iou_head_depth: int = 3,
25
+ iou_head_hidden_dim: int = 256,
26
+ ) -> None:
27
+ """
28
+ Predicts masks given an image and prompt embeddings, using a
29
+ transformer architecture.
30
+
31
+ Arguments:
32
+ transformer_dim (int): the channel dimension of the transformer
33
+ transformer (nn.Module): the transformer used to predict masks
34
+ num_multimask_outputs (int): the number of masks to predict
35
+ when disambiguating masks
36
+ activation (nn.Module): the type of activation to use when
37
+ upscaling masks
38
+ iou_head_depth (int): the depth of the MLP used to predict
39
+ mask quality
40
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
41
+ used to predict mask quality
42
+ """
43
+ super().__init__()
44
+ self.transformer_dim = transformer_dim
45
+ self.transformer = transformer
46
+
47
+ self.num_multimask_outputs = num_multimask_outputs
48
+
49
+ self.iou_token = nn.Embedding(1, transformer_dim)
50
+ self.num_mask_tokens = num_multimask_outputs + 1
51
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52
+
53
+ self.output_upscaling = nn.Sequential(
54
+ nn.ConvTranspose2d(
55
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
56
+ ),
57
+ LayerNorm2d(transformer_dim // 4),
58
+ activation(),
59
+ nn.ConvTranspose2d(
60
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
61
+ ),
62
+ activation(),
63
+ )
64
+ self.output_hypernetworks_mlps = nn.ModuleList(
65
+ [
66
+ MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
67
+ for i in range(self.num_mask_tokens)
68
+ ]
69
+ )
70
+
71
+ self.iou_prediction_head = MLP(
72
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
73
+ )
74
+
75
+ def forward(
76
+ self,
77
+ image_embeddings: torch.Tensor,
78
+ image_pe: torch.Tensor,
79
+ sparse_prompt_embeddings: torch.Tensor,
80
+ dense_prompt_embeddings: torch.Tensor,
81
+ multimask_output: bool,
82
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
83
+ """
84
+ Predict masks given image and prompt embeddings.
85
+
86
+ Arguments:
87
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
88
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
89
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
90
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
91
+ multimask_output (bool): Whether to return multiple masks or a single
92
+ mask.
93
+
94
+ Returns:
95
+ torch.Tensor: batched predicted masks
96
+ torch.Tensor: batched predictions of mask quality
97
+ """
98
+ masks, iou_pred = self.predict_masks(
99
+ image_embeddings=image_embeddings,
100
+ image_pe=image_pe,
101
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
102
+ dense_prompt_embeddings=dense_prompt_embeddings,
103
+ )
104
+
105
+ # Select the correct mask or masks for output
106
+ if multimask_output:
107
+ mask_slice = slice(1, None)
108
+ else:
109
+ mask_slice = slice(0, 1)
110
+ masks = masks[:, mask_slice, :, :]
111
+ iou_pred = iou_pred[:, mask_slice]
112
+
113
+ # Prepare output
114
+ return masks, iou_pred
115
+
116
+ def predict_masks(
117
+ self,
118
+ image_embeddings: torch.Tensor,
119
+ image_pe: torch.Tensor,
120
+ sparse_prompt_embeddings: torch.Tensor,
121
+ dense_prompt_embeddings: torch.Tensor,
122
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
123
+ """Predicts masks. See 'forward' for more details."""
124
+ # Concatenate output tokens
125
+ output_tokens = torch.cat(
126
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
127
+ )
128
+ output_tokens = output_tokens.unsqueeze(0).expand(
129
+ sparse_prompt_embeddings.size(0), -1, -1
130
+ )
131
+
132
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
133
+
134
+ # image_embeddings: [1, C, H, W], tokens: [B, N, C]
135
+ # dense_prompt_embeddings: [B, C, H, W]
136
+ # Expand per-image data in batch direction to be per-mask
137
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
138
+ src = src + dense_prompt_embeddings
139
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
140
+ b, c, h, w = src.shape
141
+
142
+ # Run the transformer
143
+ hs, src = self.transformer(src, pos_src, tokens)
144
+ iou_token_out = hs[:, 0, :]
145
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
146
+
147
+ # Upscale mask embeddings and predict masks using the mask tokens
148
+ src = src.transpose(1, 2).view(b, c, h, w)
149
+ upscaled_embedding = self.output_upscaling(src)
150
+ hyper_in_list: List[torch.Tensor] = []
151
+ for i in range(self.num_mask_tokens):
152
+ hyper_in_list.append(
153
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
154
+ )
155
+ hyper_in = torch.stack(hyper_in_list, dim=1)
156
+ b, c, h, w = upscaled_embedding.shape
157
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
158
+ b, self.num_mask_tokens, h, w
159
+ )
160
+
161
+ # Generate mask quality predictions
162
+ iou_pred = self.iou_prediction_head(iou_token_out)
163
+
164
+ return masks, iou_pred
165
+
166
+
167
+ # Lightly adapted from
168
+ # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
169
+ class MLP(nn.Module):
170
+ def __init__(
171
+ self,
172
+ input_dim: int,
173
+ hidden_dim: int,
174
+ output_dim: int,
175
+ num_layers: int,
176
+ sigmoid_output: bool = False,
177
+ ) -> None:
178
+ super().__init__()
179
+ self.num_layers = num_layers
180
+ h = [hidden_dim] * (num_layers - 1)
181
+ self.layers = nn.ModuleList(
182
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
183
+ )
184
+ self.sigmoid_output = sigmoid_output
185
+
186
+ def forward(self, x):
187
+ for i, layer in enumerate(self.layers):
188
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
189
+ if self.sigmoid_output:
190
+ x = F.sigmoid(x)
191
+ return x
model/segment_anything/modeling/prompt_encoder.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Optional, Tuple, Type
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch import nn
12
+
13
+ from .common import LayerNorm2d
14
+
15
+
16
+ class PromptEncoder(nn.Module):
17
+ def __init__(
18
+ self,
19
+ embed_dim: int,
20
+ image_embedding_size: Tuple[int, int],
21
+ input_image_size: Tuple[int, int],
22
+ mask_in_chans: int,
23
+ activation: Type[nn.Module] = nn.GELU,
24
+ ) -> None:
25
+ """
26
+ Encodes prompts for input to SAM's mask decoder.
27
+
28
+ Arguments:
29
+ embed_dim (int): The prompts' embedding dimension
30
+ image_embedding_size (tuple(int, int)): The spatial size of the
31
+ image embedding, as (H, W).
32
+ input_image_size (int): The padded size of the image as input
33
+ to the image encoder, as (H, W).
34
+ mask_in_chans (int): The number of hidden channels used for
35
+ encoding input masks.
36
+ activation (nn.Module): The activation to use when encoding
37
+ input masks.
38
+ """
39
+ super().__init__()
40
+ self.embed_dim = embed_dim
41
+ self.input_image_size = input_image_size
42
+ self.image_embedding_size = image_embedding_size
43
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44
+
45
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46
+ point_embeddings = [
47
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
48
+ ]
49
+ self.point_embeddings = nn.ModuleList(point_embeddings)
50
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
51
+
52
+ self.mask_input_size = (
53
+ 4 * image_embedding_size[0],
54
+ 4 * image_embedding_size[1],
55
+ )
56
+ self.mask_downscaling = nn.Sequential(
57
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
58
+ LayerNorm2d(mask_in_chans // 4),
59
+ activation(),
60
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
61
+ LayerNorm2d(mask_in_chans),
62
+ activation(),
63
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
64
+ )
65
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
66
+
67
+ def get_dense_pe(self) -> torch.Tensor:
68
+ """
69
+ Returns the positional encoding used to encode point prompts,
70
+ applied to a dense set of points the shape of the image encoding.
71
+
72
+ Returns:
73
+ torch.Tensor: Positional encoding with shape
74
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
75
+ """
76
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
77
+
78
+ def _embed_points(
79
+ self,
80
+ points: torch.Tensor,
81
+ labels: torch.Tensor,
82
+ pad: bool,
83
+ ) -> torch.Tensor:
84
+ """Embeds point prompts."""
85
+ points = points + 0.5 # Shift to center of pixel
86
+ if pad:
87
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
88
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
89
+ points = torch.cat([points, padding_point], dim=1)
90
+ labels = torch.cat([labels, padding_label], dim=1)
91
+ point_embedding = self.pe_layer.forward_with_coords(
92
+ points, self.input_image_size
93
+ )
94
+ point_embedding[labels == -1] = 0.0
95
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
96
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
97
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
98
+ return point_embedding
99
+
100
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
101
+ """Embeds box prompts."""
102
+ boxes = boxes + 0.5 # Shift to center of pixel
103
+ coords = boxes.reshape(-1, 2, 2)
104
+ corner_embedding = self.pe_layer.forward_with_coords(
105
+ coords, self.input_image_size
106
+ )
107
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
108
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
109
+ return corner_embedding
110
+
111
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
112
+ """Embeds mask inputs."""
113
+ mask_embedding = self.mask_downscaling(masks)
114
+ return mask_embedding
115
+
116
+ def _get_batch_size(
117
+ self,
118
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
119
+ boxes: Optional[torch.Tensor],
120
+ masks: Optional[torch.Tensor],
121
+ text_embeds: Optional[torch.Tensor],
122
+ ) -> int:
123
+ """
124
+ Gets the batch size of the output given the batch size of the input prompts.
125
+ """
126
+ if points is not None:
127
+ return points[0].shape[0]
128
+ elif boxes is not None:
129
+ return boxes.shape[0]
130
+ elif masks is not None:
131
+ return masks.shape[0]
132
+ elif text_embeds is not None:
133
+ return text_embeds.shape[0]
134
+ else:
135
+ return 1
136
+
137
+ def _get_device(self) -> torch.device:
138
+ return self.point_embeddings[0].weight.device
139
+
140
+ def forward(
141
+ self,
142
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143
+ boxes: Optional[torch.Tensor],
144
+ masks: Optional[torch.Tensor],
145
+ text_embeds: Optional[torch.Tensor],
146
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
147
+ """
148
+ Embeds different types of prompts, returning both sparse and dense
149
+ embeddings.
150
+
151
+ Arguments:
152
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
153
+ and labels to embed.
154
+ boxes (torch.Tensor or none): boxes to embed
155
+ masks (torch.Tensor or none): masks to embed
156
+
157
+ Returns:
158
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
159
+ BxNx(embed_dim), where N is determined by the number of input points
160
+ and boxes.
161
+ torch.Tensor: dense embeddings for the masks, in the shape
162
+ Bx(embed_dim)x(embed_H)x(embed_W)
163
+ """
164
+ bs = self._get_batch_size(points, boxes, masks, text_embeds)
165
+ sparse_embeddings = torch.empty(
166
+ (bs, 0, self.embed_dim), device=self._get_device()
167
+ )
168
+ if points is not None:
169
+ coords, labels = points
170
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
171
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
172
+ if boxes is not None:
173
+ box_embeddings = self._embed_boxes(boxes)
174
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
175
+
176
+ if text_embeds is not None:
177
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1)
178
+
179
+ if masks is not None:
180
+ dense_embeddings = self._embed_masks(masks)
181
+ else:
182
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
183
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
184
+ )
185
+
186
+ return sparse_embeddings, dense_embeddings
187
+
188
+
189
+ class PositionEmbeddingRandom(nn.Module):
190
+ """
191
+ Positional encoding using random spatial frequencies.
192
+ """
193
+
194
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
195
+ super().__init__()
196
+ if scale is None or scale <= 0.0:
197
+ scale = 1.0
198
+ self.register_buffer(
199
+ "positional_encoding_gaussian_matrix",
200
+ scale * torch.randn((2, num_pos_feats)),
201
+ )
202
+
203
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
204
+ """Positionally encode points that are normalized to [0,1]."""
205
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
206
+ coords = 2 * coords - 1
207
+
208
+ if coords.dtype != self.positional_encoding_gaussian_matrix.dtype:
209
+ coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
210
+
211
+ coords = coords @ self.positional_encoding_gaussian_matrix
212
+ coords = 2 * np.pi * coords
213
+ # outputs d_1 x ... x d_n x C shape
214
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
215
+
216
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
217
+ """Generate positional encoding for a grid of the specified size."""
218
+ h, w = size
219
+ device: Any = self.positional_encoding_gaussian_matrix.device
220
+ grid = torch.ones(
221
+ (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype
222
+ )
223
+ y_embed = grid.cumsum(dim=0) - 0.5
224
+ x_embed = grid.cumsum(dim=1) - 0.5
225
+ y_embed = y_embed / h
226
+ x_embed = x_embed / w
227
+
228
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
229
+ return pe.permute(2, 0, 1) # C x H x W
230
+
231
+ def forward_with_coords(
232
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
233
+ ) -> torch.Tensor:
234
+ """Positionally encode points that are not normalized to [0,1]."""
235
+ coords = coords_input.clone()
236
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
237
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
238
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
model/segment_anything/modeling/sam.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Any, Dict, List, Tuple
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+
13
+ from .image_encoder import ImageEncoderViT
14
+ from .mask_decoder import MaskDecoder
15
+ from .prompt_encoder import PromptEncoder
16
+
17
+
18
+ class Sam(nn.Module):
19
+ mask_threshold: float = 0.0
20
+ image_format: str = "RGB"
21
+
22
+ def __init__(
23
+ self,
24
+ image_encoder: ImageEncoderViT,
25
+ prompt_encoder: PromptEncoder,
26
+ mask_decoder: MaskDecoder,
27
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
28
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
29
+ ) -> None:
30
+ """
31
+ SAM predicts object masks from an image and input prompts.
32
+
33
+ Arguments:
34
+ image_encoder (ImageEncoderViT): The backbone used to encode the
35
+ image into image embeddings that allow for efficient mask prediction.
36
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38
+ and encoded prompts.
39
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
41
+ """
42
+ super().__init__()
43
+ self.image_encoder = image_encoder
44
+ self.prompt_encoder = prompt_encoder
45
+ self.mask_decoder = mask_decoder
46
+ self.register_buffer(
47
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
48
+ )
49
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
50
+
51
+ @property
52
+ def device(self) -> Any:
53
+ return self.pixel_mean.device
54
+
55
+ @torch.no_grad()
56
+ def forward(
57
+ self,
58
+ batched_input: List[Dict[str, Any]],
59
+ multimask_output: bool,
60
+ ) -> List[Dict[str, torch.Tensor]]:
61
+ """
62
+ Predicts masks end-to-end from provided images and prompts.
63
+ If prompts are not known in advance, using SamPredictor is
64
+ recommended over calling the model directly.
65
+
66
+ Arguments:
67
+ batched_input (list(dict)): A list over input images, each a
68
+ dictionary with the following keys. A prompt key can be
69
+ excluded if it is not present.
70
+ 'image': The image as a torch tensor in 3xHxW format,
71
+ already transformed for input to the model.
72
+ 'original_size': (tuple(int, int)) The original size of
73
+ the image before transformation, as (H, W).
74
+ 'point_coords': (torch.Tensor) Batched point prompts for
75
+ this image, with shape BxNx2. Already transformed to the
76
+ input frame of the model.
77
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
78
+ with shape BxN.
79
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
80
+ Already transformed to the input frame of the model.
81
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
82
+ in the form Bx1xHxW.
83
+ multimask_output (bool): Whether the model should predict multiple
84
+ disambiguating masks, or return a single mask.
85
+
86
+ Returns:
87
+ (list(dict)): A list over input images, where each element is
88
+ as dictionary with the following keys.
89
+ 'masks': (torch.Tensor) Batched binary mask predictions,
90
+ with shape BxCxHxW, where B is the number of input prompts,
91
+ C is determined by multimask_output, and (H, W) is the
92
+ original size of the image.
93
+ 'iou_predictions': (torch.Tensor) The model's predictions
94
+ of mask quality, in shape BxC.
95
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
96
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
97
+ to subsequent iterations of prediction.
98
+ """
99
+ input_images = torch.stack(
100
+ [self.preprocess(x["image"]) for x in batched_input], dim=0
101
+ )
102
+ image_embeddings = self.image_encoder(input_images)
103
+
104
+ outputs = []
105
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
106
+ if "point_coords" in image_record:
107
+ points = (image_record["point_coords"], image_record["point_labels"])
108
+ else:
109
+ points = None
110
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
111
+ points=points,
112
+ boxes=image_record.get("boxes", None),
113
+ masks=image_record.get("mask_inputs", None),
114
+ )
115
+ low_res_masks, iou_predictions = self.mask_decoder(
116
+ image_embeddings=curr_embedding.unsqueeze(0),
117
+ image_pe=self.prompt_encoder.get_dense_pe(),
118
+ sparse_prompt_embeddings=sparse_embeddings,
119
+ dense_prompt_embeddings=dense_embeddings,
120
+ multimask_output=multimask_output,
121
+ )
122
+ masks = self.postprocess_masks(
123
+ low_res_masks,
124
+ input_size=image_record["image"].shape[-2:],
125
+ original_size=image_record["original_size"],
126
+ )
127
+ masks = masks > self.mask_threshold
128
+ outputs.append(
129
+ {
130
+ "masks": masks,
131
+ "iou_predictions": iou_predictions,
132
+ "low_res_logits": low_res_masks,
133
+ }
134
+ )
135
+ return outputs
136
+
137
+ def postprocess_masks(
138
+ self,
139
+ masks: torch.Tensor,
140
+ input_size: Tuple[int, ...],
141
+ original_size: Tuple[int, ...],
142
+ ) -> torch.Tensor:
143
+ """
144
+ Remove padding and upscale masks to the original image size.
145
+
146
+ Arguments:
147
+ masks (torch.Tensor): Batched masks from the mask_decoder,
148
+ in BxCxHxW format.
149
+ input_size (tuple(int, int)): The size of the image input to the
150
+ model, in (H, W) format. Used to remove padding.
151
+ original_size (tuple(int, int)): The original size of the image
152
+ before resizing for input to the model, in (H, W) format.
153
+
154
+ Returns:
155
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
156
+ is given by original_size.
157
+ """
158
+
159
+ dtype = masks.dtype
160
+
161
+ masks = F.interpolate(
162
+ masks.float(),
163
+ (self.image_encoder.img_size, self.image_encoder.img_size),
164
+ mode="bilinear",
165
+ align_corners=False,
166
+ )
167
+ # masks = masks.to(dtype)
168
+ masks = masks[..., : input_size[0], : input_size[1]]
169
+ masks = F.interpolate(
170
+ masks, original_size, mode="bilinear", align_corners=False
171
+ )
172
+ return masks
173
+
174
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
175
+ """Normalize pixel values and pad to a square input."""
176
+ # Normalize colors
177
+ x = (x - self.pixel_mean) / self.pixel_std
178
+
179
+ # Pad
180
+ h, w = x.shape[-2:]
181
+ padh = self.image_encoder.img_size - h
182
+ padw = self.image_encoder.img_size - w
183
+ x = F.pad(x, (0, padw, 0, padh))
184
+ return x
model/segment_anything/modeling/transformer.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Tuple, Type
9
+
10
+ import torch
11
+ from torch import Tensor, nn
12
+
13
+ from .common import MLPBlock
14
+
15
+
16
+ class TwoWayTransformer(nn.Module):
17
+ def __init__(
18
+ self,
19
+ depth: int,
20
+ embedding_dim: int,
21
+ num_heads: int,
22
+ mlp_dim: int,
23
+ activation: Type[nn.Module] = nn.ReLU,
24
+ attention_downsample_rate: int = 2,
25
+ ) -> None:
26
+ """
27
+ A transformer decoder that attends to an input image using
28
+ queries whose positional embedding is supplied.
29
+
30
+ Args:
31
+ depth (int): number of layers in the transformer
32
+ embedding_dim (int): the channel dimension for the input embeddings
33
+ num_heads (int): the number of heads for multihead attention. Must
34
+ divide embedding_dim
35
+ mlp_dim (int): the channel dimension internal to the MLP block
36
+ activation (nn.Module): the activation to use in the MLP block
37
+ """
38
+ super().__init__()
39
+ self.depth = depth
40
+ self.embedding_dim = embedding_dim
41
+ self.num_heads = num_heads
42
+ self.mlp_dim = mlp_dim
43
+ self.layers = nn.ModuleList()
44
+
45
+ for i in range(depth):
46
+ self.layers.append(
47
+ TwoWayAttentionBlock(
48
+ embedding_dim=embedding_dim,
49
+ num_heads=num_heads,
50
+ mlp_dim=mlp_dim,
51
+ activation=activation,
52
+ attention_downsample_rate=attention_downsample_rate,
53
+ skip_first_layer_pe=(i == 0),
54
+ )
55
+ )
56
+
57
+ self.final_attn_token_to_image = Attention(
58
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
59
+ )
60
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
61
+
62
+ def forward(
63
+ self,
64
+ image_embedding: Tensor,
65
+ image_pe: Tensor,
66
+ point_embedding: Tensor,
67
+ ) -> Tuple[Tensor, Tensor]:
68
+ """
69
+ Args:
70
+ image_embedding (torch.Tensor): image to attend to. Should be shape
71
+ B x embedding_dim x h x w for any h and w.
72
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
73
+ have the same shape as image_embedding.
74
+ point_embedding (torch.Tensor): the embedding to add to the query points.
75
+ Must have shape B x N_points x embedding_dim for any N_points.
76
+
77
+ Returns:
78
+ torch.Tensor: the processed point_embedding
79
+ torch.Tensor: the processed image_embedding
80
+ """
81
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
82
+ bs, c, h, w = image_embedding.shape
83
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
84
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
85
+
86
+ # Prepare queries
87
+ queries = point_embedding
88
+ keys = image_embedding
89
+
90
+ # Apply transformer blocks and final layernorm
91
+ for layer in self.layers:
92
+ queries, keys = layer(
93
+ queries=queries,
94
+ keys=keys,
95
+ query_pe=point_embedding,
96
+ key_pe=image_pe,
97
+ )
98
+
99
+ # Apply the final attention layer from the points to the image
100
+ q = queries + point_embedding
101
+ k = keys + image_pe
102
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
103
+ queries = queries + attn_out
104
+ queries = self.norm_final_attn(queries)
105
+
106
+ return queries, keys
107
+
108
+
109
+ class TwoWayAttentionBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ embedding_dim: int,
113
+ num_heads: int,
114
+ mlp_dim: int = 2048,
115
+ activation: Type[nn.Module] = nn.ReLU,
116
+ attention_downsample_rate: int = 2,
117
+ skip_first_layer_pe: bool = False,
118
+ ) -> None:
119
+ """
120
+ A transformer block with four layers: (1) self-attention of sparse
121
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
122
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
123
+ inputs.
124
+
125
+ Arguments:
126
+ embedding_dim (int): the channel dimension of the embeddings
127
+ num_heads (int): the number of heads in the attention layers
128
+ mlp_dim (int): the hidden dimension of the mlp block
129
+ activation (nn.Module): the activation of the mlp block
130
+ skip_first_layer_pe (bool): skip the PE on the first layer
131
+ """
132
+ super().__init__()
133
+ self.self_attn = Attention(embedding_dim, num_heads)
134
+ self.norm1 = nn.LayerNorm(embedding_dim)
135
+
136
+ self.cross_attn_token_to_image = Attention(
137
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
138
+ )
139
+ self.norm2 = nn.LayerNorm(embedding_dim)
140
+
141
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
142
+ self.norm3 = nn.LayerNorm(embedding_dim)
143
+
144
+ self.norm4 = nn.LayerNorm(embedding_dim)
145
+ self.cross_attn_image_to_token = Attention(
146
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
147
+ )
148
+
149
+ self.skip_first_layer_pe = skip_first_layer_pe
150
+
151
+ def forward(
152
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
153
+ ) -> Tuple[Tensor, Tensor]:
154
+ # Self attention block
155
+ if self.skip_first_layer_pe:
156
+ queries = self.self_attn(q=queries, k=queries, v=queries)
157
+ else:
158
+ q = queries + query_pe
159
+ attn_out = self.self_attn(q=q, k=q, v=queries)
160
+ queries = queries + attn_out
161
+ queries = self.norm1(queries)
162
+
163
+ # Cross attention block, tokens attending to image embedding
164
+ q = queries + query_pe
165
+ k = keys + key_pe
166
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
167
+ queries = queries + attn_out
168
+ queries = self.norm2(queries)
169
+
170
+ # MLP block
171
+ mlp_out = self.mlp(queries)
172
+ queries = queries + mlp_out
173
+ queries = self.norm3(queries)
174
+
175
+ # Cross attention block, image embedding attending to tokens
176
+ q = queries + query_pe
177
+ k = keys + key_pe
178
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
179
+ keys = keys + attn_out
180
+ keys = self.norm4(keys)
181
+
182
+ return queries, keys
183
+
184
+
185
+ class Attention(nn.Module):
186
+ """
187
+ An attention layer that allows for downscaling the size of the embedding
188
+ after projection to queries, keys, and values.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ embedding_dim: int,
194
+ num_heads: int,
195
+ downsample_rate: int = 1,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.embedding_dim = embedding_dim
199
+ self.internal_dim = embedding_dim // downsample_rate
200
+ self.num_heads = num_heads
201
+ assert (
202
+ self.internal_dim % num_heads == 0
203
+ ), "num_heads must divide embedding_dim."
204
+
205
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
206
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
207
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
208
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
209
+
210
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
211
+ b, n, c = x.shape
212
+ x = x.reshape(b, n, num_heads, c // num_heads)
213
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
214
+
215
+ def _recombine_heads(self, x: Tensor) -> Tensor:
216
+ b, n_heads, n_tokens, c_per_head = x.shape
217
+ x = x.transpose(1, 2)
218
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
219
+
220
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
221
+ # Input projections
222
+ q = self.q_proj(q)
223
+ k = self.k_proj(k)
224
+ v = self.v_proj(v)
225
+
226
+ # Separate into heads
227
+ q = self._separate_heads(q, self.num_heads)
228
+ k = self._separate_heads(k, self.num_heads)
229
+ v = self._separate_heads(v, self.num_heads)
230
+
231
+ # Attention
232
+ _, _, _, c_per_head = q.shape
233
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
234
+ attn = attn / math.sqrt(c_per_head)
235
+ attn = torch.softmax(attn, dim=-1)
236
+
237
+ # Get output
238
+ out = attn @ v
239
+ out = self._recombine_heads(out)
240
+ out = self.out_proj(out)
241
+
242
+ return out
model/segment_anything/predictor.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ from .modeling import Sam
13
+ from .utils.transforms import ResizeLongestSide
14
+
15
+
16
+ class SamPredictor:
17
+ def __init__(
18
+ self,
19
+ sam_model: Sam,
20
+ ) -> None:
21
+ """
22
+ Uses SAM to calculate the image embedding for an image, and then
23
+ allow repeated, efficient mask prediction given prompts.
24
+
25
+ Arguments:
26
+ sam_model (Sam): The model to use for mask prediction.
27
+ """
28
+ super().__init__()
29
+ self.model = sam_model
30
+ self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
31
+ self.reset_image()
32
+
33
+ def set_image(
34
+ self,
35
+ image: np.ndarray,
36
+ image_format: str = "RGB",
37
+ ) -> None:
38
+ """
39
+ Calculates the image embeddings for the provided image, allowing
40
+ masks to be predicted with the 'predict' method.
41
+
42
+ Arguments:
43
+ image (np.ndarray): The image for calculating masks. Expects an
44
+ image in HWC uint8 format, with pixel values in [0, 255].
45
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
46
+ """
47
+ assert image_format in [
48
+ "RGB",
49
+ "BGR",
50
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
51
+ if image_format != self.model.image_format:
52
+ image = image[..., ::-1]
53
+
54
+ # Transform the image to the form expected by the model
55
+ input_image = self.transform.apply_image(image)
56
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
57
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
58
+ None, :, :, :
59
+ ]
60
+
61
+ self.set_torch_image(input_image_torch, image.shape[:2])
62
+
63
+ @torch.no_grad()
64
+ def set_torch_image(
65
+ self,
66
+ transformed_image: torch.Tensor,
67
+ original_image_size: Tuple[int, ...],
68
+ ) -> None:
69
+ """
70
+ Calculates the image embeddings for the provided image, allowing
71
+ masks to be predicted with the 'predict' method. Expects the input
72
+ image to be already transformed to the format expected by the model.
73
+
74
+ Arguments:
75
+ transformed_image (torch.Tensor): The input image, with shape
76
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
77
+ original_image_size (tuple(int, int)): The size of the image
78
+ before transformation, in (H, W) format.
79
+ """
80
+ assert (
81
+ len(transformed_image.shape) == 4
82
+ and transformed_image.shape[1] == 3
83
+ and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
84
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
85
+ self.reset_image()
86
+
87
+ self.original_size = original_image_size
88
+ self.input_size = tuple(transformed_image.shape[-2:])
89
+ input_image = self.model.preprocess(transformed_image)
90
+ self.features = self.model.image_encoder(input_image)
91
+ self.is_image_set = True
92
+
93
+ def predict(
94
+ self,
95
+ point_coords: Optional[np.ndarray] = None,
96
+ point_labels: Optional[np.ndarray] = None,
97
+ box: Optional[np.ndarray] = None,
98
+ mask_input: Optional[np.ndarray] = None,
99
+ multimask_output: bool = True,
100
+ return_logits: bool = False,
101
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
102
+ """
103
+ Predict masks for the given input prompts, using the currently set image.
104
+
105
+ Arguments:
106
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
107
+ model. Each point is in (X,Y) in pixels.
108
+ point_labels (np.ndarray or None): A length N array of labels for the
109
+ point prompts. 1 indicates a foreground point and 0 indicates a
110
+ background point.
111
+ box (np.ndarray or None): A length 4 array given a box prompt to the
112
+ model, in XYXY format.
113
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
114
+ coming from a previous prediction iteration. Has form 1xHxW, where
115
+ for SAM, H=W=256.
116
+ multimask_output (bool): If true, the model will return three masks.
117
+ For ambiguous input prompts (such as a single click), this will often
118
+ produce better masks than a single prediction. If only a single
119
+ mask is needed, the model's predicted quality score can be used
120
+ to select the best mask. For non-ambiguous prompts, such as multiple
121
+ input prompts, multimask_output=False can give better results.
122
+ return_logits (bool): If true, returns un-thresholded masks logits
123
+ instead of a binary mask.
124
+
125
+ Returns:
126
+ (np.ndarray): The output masks in CxHxW format, where C is the
127
+ number of masks, and (H, W) is the original image size.
128
+ (np.ndarray): An array of length C containing the model's
129
+ predictions for the quality of each mask.
130
+ (np.ndarray): An array of shape CxHxW, where C is the number
131
+ of masks and H=W=256. These low resolution logits can be passed to
132
+ a subsequent iteration as mask input.
133
+ """
134
+ if not self.is_image_set:
135
+ raise RuntimeError(
136
+ "An image must be set with .set_image(...) before mask prediction."
137
+ )
138
+
139
+ # Transform input prompts
140
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
141
+ if point_coords is not None:
142
+ assert (
143
+ point_labels is not None
144
+ ), "point_labels must be supplied if point_coords is supplied."
145
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
146
+ coords_torch = torch.as_tensor(
147
+ point_coords, dtype=torch.float, device=self.device
148
+ )
149
+ labels_torch = torch.as_tensor(
150
+ point_labels, dtype=torch.int, device=self.device
151
+ )
152
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
153
+ if box is not None:
154
+ box = self.transform.apply_boxes(box, self.original_size)
155
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
156
+ box_torch = box_torch[None, :]
157
+ if mask_input is not None:
158
+ mask_input_torch = torch.as_tensor(
159
+ mask_input, dtype=torch.float, device=self.device
160
+ )
161
+ mask_input_torch = mask_input_torch[None, :, :, :]
162
+
163
+ masks, iou_predictions, low_res_masks = self.predict_torch(
164
+ coords_torch,
165
+ labels_torch,
166
+ box_torch,
167
+ mask_input_torch,
168
+ multimask_output,
169
+ return_logits=return_logits,
170
+ )
171
+
172
+ masks_np = masks[0].detach().cpu().numpy()
173
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
174
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
175
+ return masks_np, iou_predictions_np, low_res_masks_np
176
+
177
+ @torch.no_grad()
178
+ def predict_torch(
179
+ self,
180
+ point_coords: Optional[torch.Tensor],
181
+ point_labels: Optional[torch.Tensor],
182
+ boxes: Optional[torch.Tensor] = None,
183
+ mask_input: Optional[torch.Tensor] = None,
184
+ multimask_output: bool = True,
185
+ return_logits: bool = False,
186
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
+ """
188
+ Predict masks for the given input prompts, using the currently set image.
189
+ Input prompts are batched torch tensors and are expected to already be
190
+ transformed to the input frame using ResizeLongestSide.
191
+
192
+ Arguments:
193
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
194
+ model. Each point is in (X,Y) in pixels.
195
+ point_labels (torch.Tensor or None): A BxN array of labels for the
196
+ point prompts. 1 indicates a foreground point and 0 indicates a
197
+ background point.
198
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
199
+ model, in XYXY format.
200
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
201
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
202
+ for SAM, H=W=256. Masks returned by a previous iteration of the
203
+ predict method do not need further transformation.
204
+ multimask_output (bool): If true, the model will return three masks.
205
+ For ambiguous input prompts (such as a single click), this will often
206
+ produce better masks than a single prediction. If only a single
207
+ mask is needed, the model's predicted quality score can be used
208
+ to select the best mask. For non-ambiguous prompts, such as multiple
209
+ input prompts, multimask_output=False can give better results.
210
+ return_logits (bool): If true, returns un-thresholded masks logits
211
+ instead of a binary mask.
212
+
213
+ Returns:
214
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
215
+ number of masks, and (H, W) is the original image size.
216
+ (torch.Tensor): An array of shape BxC containing the model's
217
+ predictions for the quality of each mask.
218
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
219
+ of masks and H=W=256. These low res logits can be passed to
220
+ a subsequent iteration as mask input.
221
+ """
222
+ if not self.is_image_set:
223
+ raise RuntimeError(
224
+ "An image must be set with .set_image(...) before mask prediction."
225
+ )
226
+
227
+ if point_coords is not None:
228
+ points = (point_coords, point_labels)
229
+ else:
230
+ points = None
231
+
232
+ # Embed prompts
233
+ sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
234
+ points=points,
235
+ boxes=boxes,
236
+ masks=mask_input,
237
+ )
238
+
239
+ # Predict masks
240
+ low_res_masks, iou_predictions = self.model.mask_decoder(
241
+ image_embeddings=self.features,
242
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
243
+ sparse_prompt_embeddings=sparse_embeddings,
244
+ dense_prompt_embeddings=dense_embeddings,
245
+ multimask_output=multimask_output,
246
+ )
247
+
248
+ # Upscale the masks to the original image resolution
249
+ masks = self.model.postprocess_masks(
250
+ low_res_masks, self.input_size, self.original_size
251
+ )
252
+
253
+ if not return_logits:
254
+ masks = masks > self.model.mask_threshold
255
+
256
+ return masks, iou_predictions, low_res_masks
257
+
258
+ def get_image_embedding(self) -> torch.Tensor:
259
+ """
260
+ Returns the image embeddings for the currently set image, with
261
+ shape 1xCxHxW, where C is the embedding dimension and (H,W) are
262
+ the embedding spatial dimension of SAM (typically C=256, H=W=64).
263
+ """
264
+ if not self.is_image_set:
265
+ raise RuntimeError(
266
+ "An image must be set with .set_image(...) to generate an embedding."
267
+ )
268
+ assert (
269
+ self.features is not None
270
+ ), "Features must exist if an image has been set."
271
+ return self.features
272
+
273
+ @property
274
+ def device(self) -> torch.device:
275
+ return self.model.device
276
+
277
+ def reset_image(self) -> None:
278
+ """Resets the currently set image."""
279
+ self.is_image_set = False
280
+ self.features = None
281
+ self.orig_h = None
282
+ self.orig_w = None
283
+ self.input_h = None
284
+ self.input_w = None
model/segment_anything/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
model/segment_anything/utils/amg.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from copy import deepcopy
9
+ from itertools import product
10
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ class MaskData:
17
+ """
18
+ A structure for storing masks and their related data in batched format.
19
+ Implements basic filtering and concatenation.
20
+ """
21
+
22
+ def __init__(self, **kwargs) -> None:
23
+ for v in kwargs.values():
24
+ assert isinstance(
25
+ v, (list, np.ndarray, torch.Tensor)
26
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
27
+ self._stats = dict(**kwargs)
28
+
29
+ def __setitem__(self, key: str, item: Any) -> None:
30
+ assert isinstance(
31
+ item, (list, np.ndarray, torch.Tensor)
32
+ ), "MaskData only supports list, numpy arrays, and torch tensors."
33
+ self._stats[key] = item
34
+
35
+ def __delitem__(self, key: str) -> None:
36
+ del self._stats[key]
37
+
38
+ def __getitem__(self, key: str) -> Any:
39
+ return self._stats[key]
40
+
41
+ def items(self) -> ItemsView[str, Any]:
42
+ return self._stats.items()
43
+
44
+ def filter(self, keep: torch.Tensor) -> None:
45
+ for k, v in self._stats.items():
46
+ if v is None:
47
+ self._stats[k] = None
48
+ elif isinstance(v, torch.Tensor):
49
+ self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50
+ elif isinstance(v, np.ndarray):
51
+ self._stats[k] = v[keep.detach().cpu().numpy()]
52
+ elif isinstance(v, list) and keep.dtype == torch.bool:
53
+ self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54
+ elif isinstance(v, list):
55
+ self._stats[k] = [v[i] for i in keep]
56
+ else:
57
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58
+
59
+ def cat(self, new_stats: "MaskData") -> None:
60
+ for k, v in new_stats.items():
61
+ if k not in self._stats or self._stats[k] is None:
62
+ self._stats[k] = deepcopy(v)
63
+ elif isinstance(v, torch.Tensor):
64
+ self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65
+ elif isinstance(v, np.ndarray):
66
+ self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67
+ elif isinstance(v, list):
68
+ self._stats[k] = self._stats[k] + deepcopy(v)
69
+ else:
70
+ raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71
+
72
+ def to_numpy(self) -> None:
73
+ for k, v in self._stats.items():
74
+ if isinstance(v, torch.Tensor):
75
+ self._stats[k] = v.detach().cpu().numpy()
76
+
77
+
78
+ def is_box_near_crop_edge(
79
+ boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80
+ ) -> torch.Tensor:
81
+ """Filter masks at the edge of a crop, but not at the edge of the original image."""
82
+ crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83
+ orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84
+ boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85
+ near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86
+ near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87
+ near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88
+ return torch.any(near_crop_edge, dim=1)
89
+
90
+
91
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92
+ box_xywh = deepcopy(box_xyxy)
93
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
94
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
95
+ return box_xywh
96
+
97
+
98
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99
+ assert len(args) > 0 and all(
100
+ len(a) == len(args[0]) for a in args
101
+ ), "Batched iteration must have inputs of all the same size."
102
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103
+ for b in range(n_batches):
104
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105
+
106
+
107
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108
+ """
109
+ Encodes masks to an uncompressed RLE, in the format expected by
110
+ pycoco tools.
111
+ """
112
+ # Put in fortran order and flatten h,w
113
+ b, h, w = tensor.shape
114
+ tensor = tensor.permute(0, 2, 1).flatten(1)
115
+
116
+ # Compute change indices
117
+ diff = tensor[:, 1:] ^ tensor[:, :-1]
118
+ change_indices = diff.nonzero()
119
+
120
+ # Encode run length
121
+ out = []
122
+ for i in range(b):
123
+ cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124
+ cur_idxs = torch.cat(
125
+ [
126
+ torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127
+ cur_idxs + 1,
128
+ torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129
+ ]
130
+ )
131
+ btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132
+ counts = [] if tensor[i, 0] == 0 else [0]
133
+ counts.extend(btw_idxs.detach().cpu().tolist())
134
+ out.append({"size": [h, w], "counts": counts})
135
+ return out
136
+
137
+
138
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139
+ """Compute a binary mask from an uncompressed RLE."""
140
+ h, w = rle["size"]
141
+ mask = np.empty(h * w, dtype=bool)
142
+ idx = 0
143
+ parity = False
144
+ for count in rle["counts"]:
145
+ mask[idx : idx + count] = parity
146
+ idx += count
147
+ parity ^= True
148
+ mask = mask.reshape(w, h)
149
+ return mask.transpose() # Put in C order
150
+
151
+
152
+ def area_from_rle(rle: Dict[str, Any]) -> int:
153
+ return sum(rle["counts"][1::2])
154
+
155
+
156
+ def calculate_stability_score(
157
+ masks: torch.Tensor, mask_threshold: float, threshold_offset: float
158
+ ) -> torch.Tensor:
159
+ """
160
+ Computes the stability score for a batch of masks. The stability
161
+ score is the IoU between the binary masks obtained by thresholding
162
+ the predicted mask logits at high and low values.
163
+ """
164
+ # One mask is always contained inside the other.
165
+ # Save memory by preventing unnecessary cast to torch.int64
166
+ intersections = (
167
+ (masks > (mask_threshold + threshold_offset))
168
+ .sum(-1, dtype=torch.int16)
169
+ .sum(-1, dtype=torch.int32)
170
+ )
171
+ unions = (
172
+ (masks > (mask_threshold - threshold_offset))
173
+ .sum(-1, dtype=torch.int16)
174
+ .sum(-1, dtype=torch.int32)
175
+ )
176
+ return intersections / unions
177
+
178
+
179
+ def build_point_grid(n_per_side: int) -> np.ndarray:
180
+ """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
181
+ offset = 1 / (2 * n_per_side)
182
+ points_one_side = np.linspace(offset, 1 - offset, n_per_side)
183
+ points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
184
+ points_y = np.tile(points_one_side[:, None], (1, n_per_side))
185
+ points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
186
+ return points
187
+
188
+
189
+ def build_all_layer_point_grids(
190
+ n_per_side: int, n_layers: int, scale_per_layer: int
191
+ ) -> List[np.ndarray]:
192
+ """Generates point grids for all crop layers."""
193
+ points_by_layer = []
194
+ for i in range(n_layers + 1):
195
+ n_points = int(n_per_side / (scale_per_layer**i))
196
+ points_by_layer.append(build_point_grid(n_points))
197
+ return points_by_layer
198
+
199
+
200
+ def generate_crop_boxes(
201
+ im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
202
+ ) -> Tuple[List[List[int]], List[int]]:
203
+ """
204
+ Generates a list of crop boxes of different sizes. Each layer
205
+ has (2**i)**2 boxes for the ith layer.
206
+ """
207
+ crop_boxes, layer_idxs = [], []
208
+ im_h, im_w = im_size
209
+ short_side = min(im_h, im_w)
210
+
211
+ # Original image
212
+ crop_boxes.append([0, 0, im_w, im_h])
213
+ layer_idxs.append(0)
214
+
215
+ def crop_len(orig_len, n_crops, overlap):
216
+ return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
217
+
218
+ for i_layer in range(n_layers):
219
+ n_crops_per_side = 2 ** (i_layer + 1)
220
+ overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
221
+
222
+ crop_w = crop_len(im_w, n_crops_per_side, overlap)
223
+ crop_h = crop_len(im_h, n_crops_per_side, overlap)
224
+
225
+ crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
226
+ crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
227
+
228
+ # Crops in XYWH format
229
+ for x0, y0 in product(crop_box_x0, crop_box_y0):
230
+ box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
231
+ crop_boxes.append(box)
232
+ layer_idxs.append(i_layer + 1)
233
+
234
+ return crop_boxes, layer_idxs
235
+
236
+
237
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
238
+ x0, y0, _, _ = crop_box
239
+ offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
240
+ # Check if boxes has a channel dimension
241
+ if len(boxes.shape) == 3:
242
+ offset = offset.unsqueeze(1)
243
+ return boxes + offset
244
+
245
+
246
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
247
+ x0, y0, _, _ = crop_box
248
+ offset = torch.tensor([[x0, y0]], device=points.device)
249
+ # Check if points has a channel dimension
250
+ if len(points.shape) == 3:
251
+ offset = offset.unsqueeze(1)
252
+ return points + offset
253
+
254
+
255
+ def uncrop_masks(
256
+ masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
257
+ ) -> torch.Tensor:
258
+ x0, y0, x1, y1 = crop_box
259
+ if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
260
+ return masks
261
+ # Coordinate transform masks
262
+ pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
263
+ pad = (x0, pad_x - x0, y0, pad_y - y0)
264
+ return torch.nn.functional.pad(masks, pad, value=0)
265
+
266
+
267
+ def remove_small_regions(
268
+ mask: np.ndarray, area_thresh: float, mode: str
269
+ ) -> Tuple[np.ndarray, bool]:
270
+ """
271
+ Removes small disconnected regions and holes in a mask. Returns the
272
+ mask and an indicator of if the mask has been modified.
273
+ """
274
+ import cv2 # type: ignore
275
+
276
+ assert mode in ["holes", "islands"]
277
+ correct_holes = mode == "holes"
278
+ working_mask = (correct_holes ^ mask).astype(np.uint8)
279
+ n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
280
+ sizes = stats[:, -1][1:] # Row 0 is background label
281
+ small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
282
+ if len(small_regions) == 0:
283
+ return mask, False
284
+ fill_labels = [0] + small_regions
285
+ if not correct_holes:
286
+ fill_labels = [i for i in range(n_labels) if i not in fill_labels]
287
+ # If every region is below threshold, keep largest
288
+ if len(fill_labels) == 0:
289
+ fill_labels = [int(np.argmax(sizes)) + 1]
290
+ mask = np.isin(regions, fill_labels)
291
+ return mask, True
292
+
293
+
294
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
295
+ from pycocotools import mask as mask_utils # type: ignore
296
+
297
+ h, w = uncompressed_rle["size"]
298
+ rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
299
+ rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
300
+ return rle
301
+
302
+
303
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
304
+ """
305
+ Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
306
+ an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
307
+ """
308
+ # torch.max below raises an error on empty inputs, just skip in this case
309
+ if torch.numel(masks) == 0:
310
+ return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
311
+
312
+ # Normalize shape to CxHxW
313
+ shape = masks.shape
314
+ h, w = shape[-2:]
315
+ if len(shape) > 2:
316
+ masks = masks.flatten(0, -3)
317
+ else:
318
+ masks = masks.unsqueeze(0)
319
+
320
+ # Get top and bottom edges
321
+ in_height, _ = torch.max(masks, dim=-1)
322
+ in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
323
+ bottom_edges, _ = torch.max(in_height_coords, dim=-1)
324
+ in_height_coords = in_height_coords + h * (~in_height)
325
+ top_edges, _ = torch.min(in_height_coords, dim=-1)
326
+
327
+ # Get left and right edges
328
+ in_width, _ = torch.max(masks, dim=-2)
329
+ in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
330
+ right_edges, _ = torch.max(in_width_coords, dim=-1)
331
+ in_width_coords = in_width_coords + w * (~in_width)
332
+ left_edges, _ = torch.min(in_width_coords, dim=-1)
333
+
334
+ # If the mask is empty the right edge will be to the left of the left edge.
335
+ # Replace these boxes with [0, 0, 0, 0]
336
+ empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
337
+ out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
338
+ out = out * (~empty_filter).unsqueeze(-1)
339
+
340
+ # Return to original shape
341
+ if len(shape) > 2:
342
+ out = out.reshape(*shape[:-2], 4)
343
+ else:
344
+ out = out[0]
345
+
346
+ return out
model/segment_anything/utils/onnx.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.nn import functional as F
12
+
13
+ from ..modeling import Sam
14
+ from .amg import calculate_stability_score
15
+
16
+
17
+ class SamOnnxModel(nn.Module):
18
+ """
19
+ This model should not be called directly, but is used in ONNX export.
20
+ It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21
+ with some functions modified to enable model tracing. Also supports extra
22
+ options controlling what information. See the ONNX export script for details.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ model: Sam,
28
+ return_single_mask: bool,
29
+ use_stability_score: bool = False,
30
+ return_extra_metrics: bool = False,
31
+ ) -> None:
32
+ super().__init__()
33
+ self.mask_decoder = model.mask_decoder
34
+ self.model = model
35
+ self.img_size = model.image_encoder.img_size
36
+ self.return_single_mask = return_single_mask
37
+ self.use_stability_score = use_stability_score
38
+ self.stability_score_offset = 1.0
39
+ self.return_extra_metrics = return_extra_metrics
40
+
41
+ @staticmethod
42
+ def resize_longest_image_size(
43
+ input_image_size: torch.Tensor, longest_side: int
44
+ ) -> torch.Tensor:
45
+ input_image_size = input_image_size.to(torch.float32)
46
+ scale = longest_side / torch.max(input_image_size)
47
+ transformed_size = scale * input_image_size
48
+ transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
49
+ return transformed_size
50
+
51
+ def _embed_points(
52
+ self, point_coords: torch.Tensor, point_labels: torch.Tensor
53
+ ) -> torch.Tensor:
54
+ point_coords = point_coords + 0.5
55
+ point_coords = point_coords / self.img_size
56
+ point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
57
+ point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
58
+
59
+ point_embedding = point_embedding * (point_labels != -1)
60
+ point_embedding = (
61
+ point_embedding
62
+ + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
63
+ )
64
+
65
+ for i in range(self.model.prompt_encoder.num_point_embeddings):
66
+ point_embedding = (
67
+ point_embedding
68
+ + self.model.prompt_encoder.point_embeddings[i].weight
69
+ * (point_labels == i)
70
+ )
71
+
72
+ return point_embedding
73
+
74
+ def _embed_masks(
75
+ self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
76
+ ) -> torch.Tensor:
77
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
78
+ input_mask
79
+ )
80
+ mask_embedding = mask_embedding + (
81
+ 1 - has_mask_input
82
+ ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
83
+ return mask_embedding
84
+
85
+ def mask_postprocessing(
86
+ self, masks: torch.Tensor, orig_im_size: torch.Tensor
87
+ ) -> torch.Tensor:
88
+ masks = F.interpolate(
89
+ masks,
90
+ size=(self.img_size, self.img_size),
91
+ mode="bilinear",
92
+ align_corners=False,
93
+ )
94
+
95
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
96
+ torch.int64
97
+ )
98
+ masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
99
+
100
+ orig_im_size = orig_im_size.to(torch.int64)
101
+ h, w = orig_im_size[0], orig_im_size[1]
102
+ masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
103
+ return masks
104
+
105
+ def select_masks(
106
+ self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
107
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
108
+ # Determine if we should return the multiclick mask or not from the number of points.
109
+ # The reweighting is used to avoid control flow.
110
+ score_reweight = torch.tensor(
111
+ [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
112
+ ).to(iou_preds.device)
113
+ score = iou_preds + (num_points - 2.5) * score_reweight
114
+ best_idx = torch.argmax(score, dim=1)
115
+ masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
116
+ iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
117
+
118
+ return masks, iou_preds
119
+
120
+ @torch.no_grad()
121
+ def forward(
122
+ self,
123
+ image_embeddings: torch.Tensor,
124
+ point_coords: torch.Tensor,
125
+ point_labels: torch.Tensor,
126
+ mask_input: torch.Tensor,
127
+ has_mask_input: torch.Tensor,
128
+ orig_im_size: torch.Tensor,
129
+ ):
130
+ sparse_embedding = self._embed_points(point_coords, point_labels)
131
+ dense_embedding = self._embed_masks(mask_input, has_mask_input)
132
+
133
+ masks, scores = self.model.mask_decoder.predict_masks(
134
+ image_embeddings=image_embeddings,
135
+ image_pe=self.model.prompt_encoder.get_dense_pe(),
136
+ sparse_prompt_embeddings=sparse_embedding,
137
+ dense_prompt_embeddings=dense_embedding,
138
+ )
139
+
140
+ if self.use_stability_score:
141
+ scores = calculate_stability_score(
142
+ masks, self.model.mask_threshold, self.stability_score_offset
143
+ )
144
+
145
+ if self.return_single_mask:
146
+ masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
147
+
148
+ upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
149
+
150
+ if self.return_extra_metrics:
151
+ stability_scores = calculate_stability_score(
152
+ upscaled_masks, self.model.mask_threshold, self.stability_score_offset
153
+ )
154
+ areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
155
+ return upscaled_masks, scores, stability_scores, areas, masks
156
+
157
+ return upscaled_masks, scores, masks
model/segment_anything/utils/transforms.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from copy import deepcopy
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn import functional as F
13
+ from torchvision.transforms.functional import resize # type: ignore
14
+ from torchvision.transforms.functional import to_pil_image
15
+
16
+
17
+ class ResizeLongestSide:
18
+ """
19
+ Resizes images to the longest side 'target_length', as well as provides
20
+ methods for resizing coordinates and boxes. Provides methods for
21
+ transforming both numpy array and batched torch tensors.
22
+ """
23
+
24
+ def __init__(self, target_length: int) -> None:
25
+ self.target_length = target_length
26
+
27
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
28
+ """
29
+ Expects a numpy array with shape HxWxC in uint8 format.
30
+ """
31
+ target_size = self.get_preprocess_shape(
32
+ image.shape[0], image.shape[1], self.target_length
33
+ )
34
+ return np.array(resize(to_pil_image(image), target_size))
35
+
36
+ def apply_coords(
37
+ self, coords: np.ndarray, original_size: Tuple[int, ...]
38
+ ) -> np.ndarray:
39
+ """
40
+ Expects a numpy array of length 2 in the final dimension. Requires the
41
+ original image size in (H, W) format.
42
+ """
43
+ old_h, old_w = original_size
44
+ new_h, new_w = self.get_preprocess_shape(
45
+ original_size[0], original_size[1], self.target_length
46
+ )
47
+ coords = deepcopy(coords).astype(float)
48
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
49
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
50
+ return coords
51
+
52
+ def apply_boxes(
53
+ self, boxes: np.ndarray, original_size: Tuple[int, ...]
54
+ ) -> np.ndarray:
55
+ """
56
+ Expects a numpy array shape Bx4. Requires the original image size
57
+ in (H, W) format.
58
+ """
59
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
60
+ return boxes.reshape(-1, 4)
61
+
62
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
63
+ """
64
+ Expects batched images with shape BxCxHxW and float format. This
65
+ transformation may not exactly match apply_image. apply_image is
66
+ the transformation expected by the model.
67
+ """
68
+ # Expects an image in BCHW format. May not exactly match apply_image.
69
+ target_size = self.get_preprocess_shape(
70
+ image.shape[0], image.shape[1], self.target_length
71
+ )
72
+ return F.interpolate(
73
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
74
+ )
75
+
76
+ def apply_coords_torch(
77
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
78
+ ) -> torch.Tensor:
79
+ """
80
+ Expects a torch tensor with length 2 in the last dimension. Requires the
81
+ original image size in (H, W) format.
82
+ """
83
+ old_h, old_w = original_size
84
+ new_h, new_w = self.get_preprocess_shape(
85
+ original_size[0], original_size[1], self.target_length
86
+ )
87
+ coords = deepcopy(coords).to(torch.float)
88
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
89
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
90
+ return coords
91
+
92
+ def apply_boxes_torch(
93
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
94
+ ) -> torch.Tensor:
95
+ """
96
+ Expects a torch tensor with shape Bx4. Requires the original image
97
+ size in (H, W) format.
98
+ """
99
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
100
+ return boxes.reshape(-1, 4)
101
+
102
+ @staticmethod
103
+ def get_preprocess_shape(
104
+ oldh: int, oldw: int, long_side_length: int
105
+ ) -> Tuple[int, int]:
106
+ """
107
+ Compute the output size given input size and target long side length.
108
+ """
109
+ scale = long_side_length * 1.0 / max(oldh, oldw)
110
+ newh, neww = oldh * scale, oldw * scale
111
+ neww = int(neww + 0.5)
112
+ newh = int(newh + 0.5)
113
+ return (newh, neww)
model/unilm/beit3/README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [(BEiT-3) Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks](https://arxiv.org/abs/2208.10442)
2
+
3
+ Official PyTorch implementation and pretrained models of BEiT-3.
4
+
5
+ The code and pretrained models of **BEiT** can be found at [here](https://github.com/microsoft/unilm/tree/master/beit).
6
+
7
+ The code and pretrained models of **BEiT v2** can be found at [here](https://github.com/microsoft/unilm/tree/master/beit2).
8
+
9
+ - March, 2023: release [the code and pretrained models of **BEiT-3**](https://github.com/microsoft/unilm/tree/master/beit3)
10
+ - March, 2023: [**BEiT-3**](https://arxiv.org/abs/2208.10442) was accepted by **CVPR 2023**.
11
+ - Sept 2022: release [the code and pretrained models of **BEiT v2**](https://github.com/microsoft/unilm/tree/master/beit2)
12
+ - Aug 2022: release preprint [Image as a Foreign Language: BEiT Pretraining for All Vision and Vision-Language Tasks](https://arxiv.org/abs/2208.10442)
13
+ - Aug 2022: release preprint [BEiT v2: Masked Image Modeling with Vector-Quantized Visual Tokenizers](https://arxiv.org/abs/2208.06366)
14
+ - June 2022: release preprint [VL-BEiT: Generative Vision-Language Pretraining](https://arxiv.org/abs/2206.01127)
15
+ - March, 2022: add [linear probe examples](https://github.com/microsoft/unilm/blob/master/beit/get_started_for_image_classification.md#example-linear-probe-on-imagenet)
16
+ - January, 2022: [**BEiT**](https://openreview.net/forum?id=p-BhZSz59o4) was accepted by **ICLR 2022 as Oral presentation** (54 out of 3391).
17
+ - August 2021: [**BEiT**](https://huggingface.co/transformers/master/model_doc/beit.html) is on [HuggingFace](https://github.com/huggingface/transformers)
18
+ - July 2021: BEiT-large achieves **[state-of-the-art results on ADE20K](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k) (a big jump to 57.0 mIoU) for semantic segmentation**.
19
+ - July 2021: BEiT-large achieves **state-of-the-art ImageNet top-1 accuracy (88.6%) under the setting without extra data other than ImageNet-22k**.
20
+ - July 2021: release [the code and pretrained models of **BEiT**](https://github.com/microsoft/unilm/tree/master/beit)
21
+ - June 2021: release preprint [BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
22
+
23
+ ## Pretrained models
24
+
25
+ We provide BEiT-3 weights pretrained on monomodal and multimodal data. Our large-size model outperforms previous large-size models across various vision-language and vision downstream tasks. The models were pretrained with 224x224 resolution.
26
+
27
+ ### Tips
28
+ - For vision-language tasks that require deep fusion, we recommend using `BEiT3-base` and `BEiT3-large`.
29
+ - For image-text retrieval or vision tasks, using `BEiT3-base-itc` and `BEiT3-large-itc` usually achieve better performance.
30
+
31
+ ### Download Checkpoints
32
+
33
+ 1. Models pretrained on ImageNet-21k images, 160 GB text documents, and web-scale image-text pairs (collected from [LAION-400M](https://laion.ai/blog/laion-400-open-dataset/), [English LAION-2B](https://laion.ai/blog/laion-5b/), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), and CC15M).
34
+ - [`BEiT3-base`](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #parameters: 276M
35
+ - [`BEiT3-large`](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224.pth): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #parameters: 746M
36
+
37
+ 2. Perform image-text contrastive intermediate tuning on `BEiT3-base` and `BEiT3-large`.
38
+ - [`BEiT3-base-itc`](https://github.com/addf400/files/releases/download/beit3/beit3_base_itc_patch16_224.pth): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #parameters: 222M
39
+ - [`BEiT3-large-itc`](https://github.com/addf400/files/releases/download/beit3/beit3_large_itc_patch16_224.pth): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #parameters: 674M
40
+
41
+ 3. Add indomain image-text pairs (COCO and VG) to continue training `BEiT3-base` and `BEiT3-large` using masked data modeling. The indomain models achieve better performance on VQAv2 and NLVR2 tasks.
42
+ - [`BEiT3-base-indomain`](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224.pth): #layer=12; hidden=768; FFN factor=4x; #head=12; patch=16x16; #parameters: 276M
43
+ - [`BEiT3-large-indomain`](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224.pth): #layer=24; hidden=1024; FFN factor=4x; #head=16; patch=16x16; #parameters: 746M
44
+
45
+ ### Text Tokenizer
46
+
47
+ [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
48
+ ```
49
+ from transformers import XLMRobertaTokenizer
50
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
51
+ ```
52
+
53
+ ### Architecture
54
+
55
+ We use [Magneto](https://arxiv.org/abs/2210.06423) with decoupled Multiway Transformer as the backbone architecture. Magneto can have better training stability and obtain better performance across modalities (such as vision, and language). The implementation is based on the [torchscale](https://github.com/microsoft/torchscale/blob/main/torchscale/model/BEiT3.py) package.
56
+
57
+
58
+ ## Setup
59
+
60
+ ```
61
+ alias=`whoami | cut -d'.' -f2`; docker run -it --rm --runtime=nvidia --ipc=host --privileged -v /home/${alias}:/home/${alias} pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel bash
62
+ ```
63
+
64
+ Clone the repo and install required packages:
65
+ ```
66
+ git clone https://github.com/microsoft/unilm.git
67
+ cd unilm/beit3
68
+ pip install -r requirements.txt
69
+ ```
70
+
71
+
72
+ ## Fine-tuning on ImageNet-1k (Image Classification)
73
+
74
+ The detailed instructions can be found at [`get_started_for_image_classification.md`](get_started/get_started_for_image_classification.md). We only use vision-related parameters for image classification fine-tuning.
75
+
76
+ | initialized checkpoint | resolution | acc@1 | acc@5 | #params | weight |
77
+ |:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
78
+ | [beit3_base_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth) | 224x224 | 85.4 | 97.6 | 87M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_in1k.pth) |
79
+ | [beit3_base_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224.pth) | 224x224 | 85.4 | 97.6 | 87M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_in1k.pth) |
80
+ | [beit3_large_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224.pth) | 224x224 | 87.6 | 98.3 | 305M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_in1k.pth) |
81
+ | [beit3_large_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224.pth) | 224x224 | 87.5 | 98.3 | 305M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_in1k.pth) |
82
+
83
+
84
+ ## Fine-tuning on VQAv2 (Visual Question Answering)
85
+
86
+ The detailed instructions can be found at [`get_started_for_vqav2.md`](get_started/get_started_for_vqav2.md).
87
+
88
+ | initialized checkpoint | resolution | augmented data | test-dev | test-std | #params | weight |
89
+ |:----------------------------------------|:----------:|:-----:|:-----:|:-----:|:-------:|-------------------|
90
+ | [beit3_base_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth) | 480x480 | - | 77.65 | - | 228M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_480_vqa.pth) |
91
+ | [beit3_base_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224.pth) | 480x480 | - | 78.46 | - | 228M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_480_vqa.pth) |
92
+ | [beit3_large_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224.pth) | 480x480 | - | 81.85 | - | 683M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_480_vqa.pth) |
93
+ | [beit3_large_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224.pth) | 480x480 | - | 82.53 | - | 683M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_480_vqa.pth) |
94
+ | [beit3_large_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224.pth) | 768x768 | VGQA | 82.97 | 83.03 | 684M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_768_vgqaaug_vqa.pth) |
95
+
96
+
97
+ ## Fine-tuning on NLVR2 (Visual Reasoning)
98
+
99
+ The detailed instructions can be found at [`get_started_for_nlvr2.md`](get_started/get_started_for_nlvr2.md).
100
+
101
+ | initialized checkpoint | resolution | dev | test-P | #params | weight |
102
+ |:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
103
+ | [beit3_base_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth) | 224x224 | 83.6 | 84.4 | 226M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224_nlvr2.pth) |
104
+ | [beit3_base_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224.pth) | 224x224 | 84.6 | 85.3 | 226M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224_nlvr2.pth) |
105
+ | [beit3_large_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224.pth) | 224x224 | 88.5 | 89.4 | 681M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224_nlvr2.pth) |
106
+ | [beit3_large_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224.pth) | 224x224 | 89.2 | 90.0 | 681M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_indomain_patch16_224_nlvr2.pth) |
107
+
108
+
109
+ ## Fine-tuning on COCO Captioning and NoCaps (Image Captioning)
110
+
111
+ The detailed instructions can be found at [`get_started_for_image_captioning.md`](get_started/get_started_for_captioning.md).
112
+
113
+ ### COCO Captioning
114
+
115
+ | initialized checkpoint | resolution | test CIDEr | #params | weight |
116
+ |:----------------------------------------|:----------:|:-----:|:-------:|-------------------|
117
+ | [beit3_base_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth) | 480x480 | 133.6 | 271M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_480_coco_captioning.pth) |
118
+ | [beit3_base_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224.pth) | 480x480 | 135.0 | 271M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_480_coco_captioning.pth) |
119
+ | [beit3_large_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224.pth) | 480x480 | 143.2 | 739M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_480_coco_captioning.pth) |
120
+
121
+ ### NoCaps
122
+
123
+ | initialized checkpoint | resolution | val CIDEr | #params | weight |
124
+ |:----------------------------------------|:----------:|:-----:|:-------:|-------------------|
125
+ | [beit3_base_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_224.pth) | 480x480 | 104.4 | 271M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_480_nocaps.pth) |
126
+ | [beit3_base_indomain_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_224.pth) | 480x480 | 105.6 | 271M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_indomain_patch16_480_nocaps.pth) |
127
+ | [beit3_large_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_224.pth) | 480x480 | 120.2 | 739M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_480_nocaps.pth) |
128
+
129
+
130
+ ## Fine-tuning on COCO and Flickr30k Retrieval (Image-Text Retrieval)
131
+
132
+ The detailed instructions can be found at [`get_started_for_retrieval.md`](get_started/get_started_for_retrieval.md).
133
+
134
+ ### COCO Retrieval
135
+
136
+ | initialized checkpoint | resolution | IR@1 | TR@1 | #params | weight |
137
+ |:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
138
+ | [beit3_base_itc_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_itc_patch16_224.pth) | 384x384 | 61.4 | 79.1 | 222M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_384_coco_retrieval.pth) |
139
+ | [beit3_large_itc_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_itc_patch16_224.pth) | 384x384 | 63.4 | 82.1 | 675M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_384_coco_retrieval.pth) |
140
+
141
+ ### Flickr30k Retrieval
142
+
143
+ | initialized checkpoint | resolution | IR@1 | TR@1 | #params | weight |
144
+ |:----------------------------------------|:----------:|:-----:|:-----:|:-------:|-------------------|
145
+ | [beit3_base_itc_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_base_itc_patch16_224.pth) | 384x384 | 86.2 | 96.3 | 222M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_base_patch16_384_f30k_retrieval.pth) |
146
+ | [beit3_large_itc_patch16_224](https://github.com/addf400/files/releases/download/beit3/beit3_large_itc_patch16_224.pth) | 384x384 | 88.1 | 97.2 | 675M | [link](https://github.com/addf400/files/releases/download/beit3/beit3_large_patch16_384_f30k_retrieval.pth) |
147
+
148
+
149
+ ## Citation
150
+
151
+ If you find this repository useful, please consider citing our work:
152
+ ```
153
+ @inproceedings{beit3,
154
+ title={Image as a foreign language: {BEiT} pretraining for vision and vision-language tasks},
155
+ author={Wenhui Wang and Hangbo Bao and Li Dong and Johan Bjorck and Zhiliang Peng and Qiang Liu and Kriti Aggarwal and Owais Khan Mohammed and Saksham Singhal and Subhojit Som and Furu Wei},
156
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
157
+ year={2023}
158
+ }
159
+
160
+ @article{beitv2,
161
+ title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
162
+ author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
163
+ year={2022},
164
+ eprint={2208.06366},
165
+ archivePrefix={arXiv},
166
+ primaryClass={cs.CV}
167
+ }
168
+
169
+ @inproceedings{beit,
170
+ title={{BEiT}: {BERT} Pre-Training of Image Transformers},
171
+ author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
172
+ booktitle={International Conference on Learning Representations},
173
+ year={2022},
174
+ url={https://openreview.net/forum?id=p-BhZSz59o4}
175
+ }
176
+ ```
177
+
178
+
179
+ ## Acknowledgement
180
+
181
+ This repository is built using the [BEiT](https://github.com/microsoft/unilm/tree/master/beit), the [BEiTv2](https://github.com/microsoft/unilm/tree/master/beit2), the [CLIP](https://github.com/openai/CLIP), the [open_clip](https://github.com/mlfoundations/open_clip), the [Oscar](https://github.com/microsoft/Oscar), the [DeiT](https://github.com/facebookresearch/deit), the [Dino](https://github.com/facebookresearch/dino) repository and the [timm](https://github.com/rwightman/pytorch-image-models) library.
182
+
183
+
184
+ ## License
185
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
186
+
187
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
188
+
189
+ ### Contact Information
190
+
191
+ For help or issues using BEiT-3 models, please submit a GitHub issue.
model/unilm/beit3/datasets.py ADDED
@@ -0,0 +1,847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import os
9
+ import json
10
+ import random
11
+ import torch
12
+ import glob
13
+ from collections import defaultdict, Counter
14
+ from torchvision import transforms
15
+ from torchvision.datasets.folder import default_loader
16
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
17
+ from timm.data.transforms import RandomResizedCropAndInterpolation
18
+ from timm.data import create_transform
19
+
20
+ import utils
21
+ from glossary import normalize_word
22
+ from randaug import RandomAugment
23
+
24
+
25
+ class BaseDataset(torch.utils.data.Dataset):
26
+ def __init__(
27
+ self, data_path, split, transform,
28
+ tokenizer, num_max_bpe_tokens, task=None,
29
+ ):
30
+ index_files = self.get_index_files(split, task=task)
31
+ self.tokenizer = tokenizer
32
+ self.num_max_bpe_tokens = num_max_bpe_tokens
33
+ self.data_path = data_path
34
+ items = []
35
+ self.index_files = index_files
36
+
37
+ offset = 0
38
+ for _index_file in index_files:
39
+ index_file = os.path.join(data_path, _index_file)
40
+ with open(index_file, mode="r", encoding="utf-8") as reader:
41
+ for line in reader:
42
+ data = json.loads(line)
43
+ items.append(data)
44
+ print("Load %d image-text pairs from %s. " % (len(items) - offset, index_file))
45
+ offset = len(items)
46
+ self.items = items
47
+ self.bos_token_id = tokenizer.bos_token_id
48
+ self.eos_token_id = tokenizer.eos_token_id
49
+ self.pad_token_id = tokenizer.pad_token_id
50
+ self.loader = default_loader
51
+ self.transform = transform
52
+ self.split = split
53
+
54
+ @staticmethod
55
+ def get_index_files(split):
56
+ raise NotImplementedError()
57
+
58
+ def _get_image(self, image_path: str):
59
+ image_path = os.path.join(self.data_path, image_path)
60
+ image = self.loader(image_path)
61
+ return self.transform(image)
62
+
63
+ def _get_text_segment(self, text_segment, max_len=None):
64
+ if isinstance(text_segment, str):
65
+ tokens = self.tokenizer.tokenize(text_segment)
66
+ else:
67
+ tokens = text_segment[:]
68
+ if len(tokens) == 0:
69
+ raise RuntimeError("The text segment should contains at least one tokens!")
70
+ if max_len is None:
71
+ max_len = self.num_max_bpe_tokens
72
+
73
+ if len(tokens) > max_len - 2:
74
+ tokens = tokens[:max_len - 2]
75
+
76
+ tokens = [self.bos_token_id] + tokens[:] + [self.eos_token_id]
77
+ num_tokens = len(tokens)
78
+ padding_mask = [0] * num_tokens + [1] * (max_len - num_tokens)
79
+ return tokens + [self.pad_token_id] * (max_len - num_tokens), padding_mask, num_tokens
80
+
81
+ def _get_image_text_example(self, index: int, data: dict):
82
+ item = self.items[index]
83
+ img_path = item["image_path"]
84
+ img = self._get_image(img_path)
85
+ data["image"] = img
86
+
87
+ text_segment = item["text_segment"]
88
+ language_tokens, padding_mask, _ = self._get_text_segment(text_segment)
89
+ data["language_tokens"] = language_tokens
90
+ data["padding_mask"] = padding_mask
91
+
92
+ def __getitem__(self, index: int):
93
+ data = dict()
94
+ self._get_image_text_example(index, data)
95
+ return data
96
+
97
+ def __len__(self) -> int:
98
+ return len(self.items)
99
+
100
+ def __repr__(self) -> str:
101
+ head = "Dataset " + self.__class__.__name__
102
+ body = '{' + "\n Number of items: %s," % self.__len__()
103
+ body += "\n data root = %s," % self.data_path
104
+ body += "\n split = %s," % self.split
105
+ body += "\n dataset index files = %s" % str(self.index_files)
106
+ body += "\n num max bpe tokens = %s" % self.num_max_bpe_tokens
107
+ body += "\n transforms = ["
108
+ for t in self.transform.transforms:
109
+ body += "\n %s" % str(t)
110
+ body += "\n ]"
111
+ body += "\n}"
112
+
113
+ return head + body
114
+
115
+
116
+ def _write_data_into_jsonl(items, jsonl_file):
117
+ with open(jsonl_file, mode="w", encoding="utf-8") as writer:
118
+ for data in items:
119
+ writer.write(json.dumps(data, indent=None))
120
+ writer.write('\n')
121
+ print("Write %s with %d items !" % (jsonl_file, len(items)))
122
+
123
+
124
+ def _make_retrieval_coco_karpathy_dataset_index(
125
+ data_path,
126
+ tokenizer,
127
+ split=("train", "restval"),
128
+ split_name="train",
129
+ ):
130
+ coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
131
+ items = []
132
+ image_counter = set()
133
+ print("read %s" % coco_karpathy_split_json_file)
134
+ with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
135
+ data = json.loads(reader.read())
136
+ for item in data["images"]:
137
+ if item["split"] in split:
138
+ image_path = os.path.join(item["filepath"], item["filename"])
139
+ for sent in item["sentences"]:
140
+ tokens = tokenizer.tokenize(sent["raw"])
141
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
142
+ items.append({
143
+ "image_path": image_path,
144
+ "text_segment": token_ids,
145
+ "image_id": len(image_counter),
146
+ })
147
+ if image_path not in image_counter:
148
+ image_counter.add(image_path)
149
+ print("Find %d images and %d image-text pairs for karpathy dataset %s split !" % \
150
+ (len(image_counter), len(items), split_name))
151
+ index_file = os.path.join(data_path, "coco_retrieval.%s.jsonl" % split_name)
152
+ _write_data_into_jsonl(items, index_file)
153
+ pass
154
+
155
+
156
+ def _make_captioning_coco_karpathy_dataset_index(
157
+ data_path,
158
+ tokenizer,
159
+ split=("train", "restval"),
160
+ split_name="train",
161
+ ):
162
+ coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
163
+ items = []
164
+ image_counter = set()
165
+ print("read %s" % coco_karpathy_split_json_file)
166
+ with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
167
+ data = json.loads(reader.read())
168
+ for item in data["images"]:
169
+ if item["split"] in split:
170
+ image_path = os.path.join(item["filepath"], item["filename"])
171
+ if item["split"] in ["train", "restval"]:
172
+ for sent in item["sentences"]:
173
+ tokens = tokenizer.tokenize(sent["raw"])
174
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
175
+ items.append({
176
+ "image_path": image_path,
177
+ "text_segment": token_ids,
178
+ "image_id": item["cocoid"],
179
+ })
180
+ else:
181
+ items.append({
182
+ "image_path": image_path,
183
+ "text_segment": None,
184
+ "image_id": item["cocoid"],
185
+ })
186
+ if image_path not in image_counter:
187
+ image_counter.add(image_path)
188
+ print("Find %d images and %d image-text pairs for karpathy dataset %s split !" % \
189
+ (len(image_counter), len(items), split_name))
190
+ index_file = os.path.join(data_path, "coco_captioning.%s.jsonl" % split_name)
191
+ _write_data_into_jsonl(items, index_file)
192
+ pass
193
+
194
+
195
+ def _make_nocaps_dataset_index(
196
+ data_path,
197
+ split="val",
198
+ ):
199
+ if split == "val":
200
+ json_file = "nocaps_val_4500_captions.json"
201
+ elif split == "test":
202
+ json_file = "nocaps_test_image_info.json"
203
+ nocaps_split_json_file = os.path.join(data_path, json_file)
204
+ items = []
205
+ image_counter = set()
206
+ print("read %s" % nocaps_split_json_file)
207
+ with open(nocaps_split_json_file, mode="r", encoding="utf-8") as reader:
208
+ data = json.loads(reader.read())
209
+ for item in data["images"]:
210
+ image_path = os.path.join(split, item["file_name"])
211
+ items.append({
212
+ "image_path": image_path,
213
+ "text_segment": None,
214
+ "image_id": item["id"],
215
+ })
216
+
217
+ if image_path not in image_counter:
218
+ image_counter.add(image_path)
219
+
220
+ print("Find %d images and %d image-text pairs for nocaps dataset %s split !" % \
221
+ (len(image_counter), len(items), split))
222
+ index_file = os.path.join(data_path, "nocaps.%s.jsonl" % split)
223
+ _write_data_into_jsonl(items, index_file)
224
+
225
+
226
+ class NLVR2Dataset(BaseDataset):
227
+ @staticmethod
228
+ def get_index_files(split, task=None):
229
+ if split == "train":
230
+ return ("nlvr2.train.index.jsonl", )
231
+ elif split == "val":
232
+ return ("nlvr2.dev.index.jsonl", )
233
+ elif split == "test":
234
+ return ("nlvr2.test-P.index.jsonl", )
235
+ else:
236
+ raise RuntimeError("split %s is not found!" % split)
237
+
238
+ def __getitem__(self, index: int):
239
+ data = super().__getitem__(index)
240
+ item = self.items[index]
241
+ img_path = item["image2_path"]
242
+ img = self._get_image(img_path)
243
+ data["image2"] = img
244
+ data["label"] = self.items[index]["label"]
245
+ return data
246
+
247
+ @staticmethod
248
+ def __preprocess_json(preifx, json_file, tokenizer, index_file):
249
+ items = []
250
+ with open(json_file, mode="r", encoding="utf-8") as reader:
251
+ for line in reader:
252
+ data = json.loads(line)
253
+ path = os.path.join(preifx, str(data["directory"])) if "directory" in data else preifx
254
+ path = os.path.join(path, "-".join(data["identifier"].split("-")[:-1]))
255
+ tokens = tokenizer.tokenize(data["sentence"])
256
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
257
+ items.append({
258
+ "image_path": path + "-img0.png",
259
+ "image2_path": path + "-img1.png",
260
+ "text_segment": token_ids,
261
+ "label": 1 if data["label"] == "True" else 0,
262
+ "identifier": data["identifier"],
263
+ })
264
+ _write_data_into_jsonl(items, index_file)
265
+
266
+ @classmethod
267
+ def make_dataset_index(cls, data_path, tokenizer, nlvr_repo_path):
268
+ cls.__preprocess_json(
269
+ preifx="images/train", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/train.json"),
270
+ tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("train")[0]),
271
+ )
272
+ cls.__preprocess_json(
273
+ preifx="dev", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/dev.json"),
274
+ tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("val")[0]),
275
+ )
276
+ cls.__preprocess_json(
277
+ preifx="test1", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/test1.json"),
278
+ tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("test")[0]),
279
+ )
280
+
281
+
282
+ class ImageNetDataset(BaseDataset):
283
+ @staticmethod
284
+ def get_index_files(split, task=None):
285
+ if split == "train":
286
+ return ("imagenet.train.index.jsonl", )
287
+ elif split == "val":
288
+ return ("imagenet.val.index.jsonl", )
289
+ elif split == "test":
290
+ return ("imagenet.val.index.jsonl", )
291
+ else:
292
+ raise RuntimeError("split %s is not found!" % split)
293
+
294
+ def __getitem__(self, index: int):
295
+ data = dict()
296
+ item = self.items[index]
297
+ img_path = item["image_path"]
298
+ img = self._get_image(img_path)
299
+ data["image"] = img
300
+ data["label"] = item["label"]
301
+ return data
302
+
303
+ @staticmethod
304
+ def _find_classes(dir):
305
+ """
306
+ Finds the class folders in a dataset.
307
+ Args:
308
+ dir (string): Root directory path.
309
+ Returns:
310
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
311
+ Ensures:
312
+ No class is a subdirectory of another.
313
+ """
314
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()]
315
+ classes.sort()
316
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
317
+ return classes, class_to_idx
318
+
319
+ @staticmethod
320
+ def _make_imagenet_index(data_path, index_path, data_path_prefix, class_to_idx, split):
321
+ items = []
322
+ index_file = os.path.join(index_path, f"imagenet.{split}.index.jsonl")
323
+ for target_class in sorted(class_to_idx.keys()):
324
+ class_index = class_to_idx[target_class]
325
+ target_dir = os.path.join(data_path, target_class)
326
+ if not os.path.isdir(target_dir):
327
+ continue
328
+ for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
329
+ for fname in sorted(fnames):
330
+ path = os.path.join(root, fname)
331
+ path = path.replace(data_path_prefix, "")
332
+ items.append({
333
+ "image_path": path,
334
+ "label": class_index,
335
+ })
336
+
337
+ _write_data_into_jsonl(items, index_file)
338
+
339
+ @classmethod
340
+ def make_dataset_index(cls, train_data_path, val_data_path, index_path):
341
+ data_path_prefix = train_data_path[:[x[0]==x[1] for x in zip(train_data_path, val_data_path)].index(0)]
342
+ classes, class_to_idx = cls._find_classes(train_data_path)
343
+ cls._make_imagenet_index(
344
+ data_path=train_data_path, index_path=index_path, data_path_prefix=data_path_prefix,
345
+ class_to_idx=class_to_idx, split="train",
346
+ )
347
+ cls._make_imagenet_index(
348
+ data_path=val_data_path, index_path=index_path, data_path_prefix=data_path_prefix,
349
+ class_to_idx=class_to_idx, split="val",
350
+ )
351
+
352
+
353
+ class VQAv2Dataset(BaseDataset):
354
+ def __init__(self, data_path, **kwargs):
355
+ super().__init__(data_path=data_path, **kwargs)
356
+ ans2label_file = os.path.join(data_path, "answer2label.txt")
357
+ ans2label = {}
358
+ label2ans = []
359
+ with open(ans2label_file, mode="r", encoding="utf-8") as reader:
360
+ for i, line in enumerate(reader):
361
+ data = json.loads(line)
362
+ ans = data["answer"]
363
+ label = data["label"]
364
+ label = int(label)
365
+ assert label == i
366
+ ans2label[ans] = i
367
+ label2ans.append(ans)
368
+
369
+ self.ans2label = ans2label
370
+ self.label2ans = label2ans
371
+
372
+ @staticmethod
373
+ def get_index_files(split, task=None):
374
+ if split == "train":
375
+ return ("vqa.train.jsonl", "vqa.trainable_val.jsonl")
376
+ elif split == "val":
377
+ return ("vqa.rest_val.jsonl", )
378
+ elif split == "test":
379
+ return ("vqa.test.jsonl", )
380
+ elif split == "test-dev":
381
+ return ("vqa.test-dev.jsonl", )
382
+ else:
383
+ raise RuntimeError("split %s is not found!" % split)
384
+
385
+ def __getitem__(self, index: int):
386
+ data = super().__getitem__(index)
387
+ if "labels" in self.items[index] and len(self.items[index]["labels"]) > 0:
388
+ labels = [0.] * len(self.label2ans)
389
+ for l, s in zip(self.items[index]["labels"], self.items[index]["scores"]):
390
+ labels[l] = s
391
+ data["labels"] = torch.FloatTensor(labels)
392
+ else:
393
+ data["qid"] = self.items[index]["qid"]
394
+ return data
395
+
396
+ @staticmethod
397
+ def get_score(occurences):
398
+ if occurences == 0:
399
+ return 0.0
400
+ elif occurences == 1:
401
+ return 0.3
402
+ elif occurences == 2:
403
+ return 0.6
404
+ elif occurences == 3:
405
+ return 0.9
406
+ else:
407
+ return 1.0
408
+
409
+ @classmethod
410
+ def make_dataset_index(cls, data_path, tokenizer, annotation_data_path):
411
+ with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_train2014_questions.json"), "r") as fp:
412
+ questions_train2014 = json.load(fp)["questions"]
413
+ with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_val2014_questions.json"), "r") as fp:
414
+ questions_val2014 = json.load(fp)["questions"]
415
+ with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test2015_questions.json"), "r") as fp:
416
+ questions_test2015 = json.load(fp)["questions"]
417
+ with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test-dev2015_questions.json"), "r") as fp:
418
+ questions_test_dev2015 = json.load(fp)["questions"]
419
+
420
+ with open(os.path.join(annotation_data_path, "v2_mscoco_train2014_annotations.json"), "r") as fp:
421
+ annotations_train2014 = json.load(fp)["annotations"]
422
+ with open(os.path.join(annotation_data_path, "v2_mscoco_val2014_annotations.json"), "r") as fp:
423
+ annotations_val2014 = json.load(fp)["annotations"]
424
+
425
+ annotations = dict()
426
+
427
+ for split, questions in zip(
428
+ ["train", "val", "test", "test-dev"],
429
+ [questions_train2014, questions_val2014, questions_test2015, questions_test_dev2015],
430
+ ):
431
+ _annot = defaultdict(dict)
432
+ for q in questions:
433
+ question_text = q["question"]
434
+ tokens = tokenizer.tokenize(question_text)
435
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
436
+
437
+ assert q["question_id"] not in _annot[q["image_id"]]
438
+ _annot[q["image_id"]][q["question_id"]] = {
439
+ "question": question_text,
440
+ "token_ids": token_ids,
441
+ }
442
+
443
+ annotations[split] = _annot
444
+
445
+ all_major_answers = list()
446
+
447
+ for split, annots in zip(
448
+ ["train", "val"], [annotations_train2014, annotations_val2014],
449
+ ):
450
+ # _annot = annotations[split]
451
+ for q in annots:
452
+ all_major_answers.append(q["multiple_choice_answer"])
453
+
454
+ all_major_answers = [normalize_word(word) for word in all_major_answers]
455
+ counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
456
+ ans2label = {k: i for i, k in enumerate(counter.keys())}
457
+ label2ans = list(counter.keys())
458
+
459
+ for split, annots in zip(
460
+ ["train", "val"], [annotations_train2014, annotations_val2014],
461
+ ):
462
+ _annot = annotations[split]
463
+ for q in annots:
464
+ answers = q["answers"]
465
+ answer_count = {}
466
+ for answer in answers:
467
+ answer_ = answer["answer"]
468
+ answer_count[answer_] = answer_count.get(answer_, 0) + 1
469
+
470
+ labels = []
471
+ scores = []
472
+ for answer in answer_count:
473
+ if answer not in ans2label:
474
+ continue
475
+ labels.append(ans2label[answer])
476
+ score = cls.get_score(answer_count[answer])
477
+ scores.append(score)
478
+
479
+ assert "labels" not in _annot[q["image_id"]][q["question_id"]]
480
+ assert "question" in _annot[q["image_id"]][q["question_id"]]
481
+ _annot[q["image_id"]][q["question_id"]]["labels"] = labels
482
+ _annot[q["image_id"]][q["question_id"]]["scores"] = scores
483
+
484
+ for split in ["train", "val"]:
485
+ filtered_annot = dict()
486
+ for ik, iv in annotations[split].items():
487
+ new_q = dict()
488
+ for qk, qv in iv.items():
489
+ if len(qv["labels"]) != 0:
490
+ new_q[qk] = qv
491
+ if len(new_q) != 0:
492
+ filtered_annot[ik] = new_q
493
+ annotations[split] = filtered_annot
494
+
495
+ split2items = {}
496
+ for split in ["train", "val", "test", "test-dev"]:
497
+ annot = annotations[split]
498
+ split_name = {
499
+ "train": "train2014",
500
+ "val": "val2014",
501
+ "test": "test2015",
502
+ "test-dev": "test2015",
503
+ }[split]
504
+ paths = list(glob.glob(f"{data_path}/{split_name}/*.jpg"))
505
+ random.shuffle(paths)
506
+ annot_paths = [path for path in paths \
507
+ if int(path.split("/")[-1].split("_")[-1][:-4]) in annot]
508
+
509
+ if len(paths) == len(annot_paths):
510
+ print("all images have caption annotations")
511
+ else:
512
+ print("not all images have caption annotations")
513
+ print(len(paths), len(annot_paths), len(annot))
514
+
515
+ items = []
516
+ for path in annot_paths:
517
+ iid = int(path.split("/")[-1].split("_")[-1][:-4])
518
+ _annot = annotations[split][iid]
519
+ for qid in _annot:
520
+ q = _annot[qid]
521
+ if split in ["train", "val"]:
522
+ labels = q["labels"]
523
+ scores = q["scores"]
524
+ else:
525
+ labels, scores = [], []
526
+
527
+ items.append({
528
+ "image_path": os.path.join(split_name, path.split('/')[-1]),
529
+ "text_segment": q["token_ids"],
530
+ "labels": labels,
531
+ "scores": scores,
532
+ "qid": qid,
533
+ })
534
+ split2items[split] = items
535
+
536
+ _write_data_into_jsonl(items=items, jsonl_file=os.path.join(data_path, "vqa.%s.jsonl" % split))
537
+
538
+ # Following ViLT, we use 1000 images of the original val set as the final val set
539
+ val_image2items = defaultdict(list)
540
+ for item in split2items["val"]:
541
+ val_image2items[item["image_path"]].append(item)
542
+
543
+ print("Contains %d image and %d pairs for val set!" % (len(val_image2items), len(split2items["val"])))
544
+
545
+ val_images = list(val_image2items.keys())
546
+ random.shuffle(val_images)
547
+ trainable_val = []
548
+ rest_val = []
549
+ for i, image_id in enumerate(val_images):
550
+ if i < 1000:
551
+ rest_val += val_image2items[image_id]
552
+ else:
553
+ trainable_val += val_image2items[image_id]
554
+
555
+ _write_data_into_jsonl(items=trainable_val, jsonl_file=os.path.join(data_path, "vqa.trainable_val.jsonl"))
556
+ _write_data_into_jsonl(items=rest_val, jsonl_file=os.path.join(data_path, "vqa.rest_val.jsonl"))
557
+
558
+ with open(os.path.join(data_path, "answer2label.txt"), mode="w", encoding="utf-8") as writer:
559
+ for ans in ans2label:
560
+ to_json = {
561
+ "answer": ans,
562
+ "label": ans2label[ans]
563
+ }
564
+ writer.write("%s\n" % json.dumps(to_json))
565
+
566
+
567
+ class RetrievalDataset(BaseDataset):
568
+ @staticmethod
569
+ def get_index_files(split, task=None):
570
+ if split == "train":
571
+ return (f"{task}.train.jsonl", )
572
+ elif split == "val":
573
+ return (f"{task}.val.jsonl", )
574
+ elif split == "test":
575
+ return (f"{task}.test.jsonl", )
576
+ else:
577
+ raise RuntimeError("split %s is not found!" % split)
578
+
579
+ def __getitem__(self, index: int):
580
+ data = super().__getitem__(index)
581
+ data["image_id"] = self.items[index]["image_id"]
582
+ return data
583
+
584
+ @staticmethod
585
+ def make_flickr30k_dataset_index(data_path, tokenizer, karpathy_path):
586
+
587
+ with open(os.path.join(karpathy_path, "dataset_flickr30k.json"), "r") as reader:
588
+ captions = json.loads(reader.read())
589
+
590
+ captions = captions["images"]
591
+ split2items = defaultdict(list)
592
+ split2images = defaultdict(set)
593
+
594
+ for each_item in captions:
595
+ image_path = os.path.join("flickr30k-images", each_item["filename"])
596
+ split = each_item["split"]
597
+
598
+ for text_segment in each_item["sentences"]:
599
+ tokens = tokenizer.tokenize(text_segment["raw"])
600
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
601
+
602
+ split2items[split].append({
603
+ "image_path": image_path,
604
+ "text_segment": token_ids,
605
+ "image_id": len(split2images[split]),
606
+ })
607
+
608
+ assert each_item["filename"] not in split2images[split]
609
+ split2images[split].add(each_item["filename"])
610
+
611
+ for split in split2items:
612
+ print("%d images and %d image-text pairs!" % (len(split2images[split]), len(split2items[split])))
613
+ _write_data_into_jsonl(split2items[split], os.path.join(data_path, "flickr30k.%s.jsonl" % split))
614
+
615
+ @staticmethod
616
+ def make_coco_dataset_index(data_path, tokenizer):
617
+ _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
618
+ _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
619
+ _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")
620
+
621
+
622
+ class CaptioningDataset(BaseDataset):
623
+
624
+ def __init__(self, data_path, split, transform,
625
+ tokenizer, num_max_bpe_tokens, task, mask_prob):
626
+ super().__init__(
627
+ data_path=data_path, split=split,
628
+ transform=transform, tokenizer=tokenizer,
629
+ num_max_bpe_tokens=num_max_bpe_tokens, task=task,
630
+ )
631
+ self.mask_token_id = tokenizer.mask_token_id
632
+ self.language_vocab_size = tokenizer.vocab_size
633
+ self.mask_prob = mask_prob
634
+
635
+ @staticmethod
636
+ def get_index_files(split, task=None):
637
+ if split == "train":
638
+ return ("coco_captioning.train.jsonl", )
639
+ elif split == "val":
640
+ return (f"{task}.val.jsonl", )
641
+ elif split == "test":
642
+ return (f"{task}.test.jsonl", )
643
+ else:
644
+ raise RuntimeError("split %s is not found!" % split)
645
+
646
+ def _get_mask_token(self, token):
647
+ p = random.random()
648
+ if p < 0.8:
649
+ return self.mask_token_id
650
+ elif p < 0.9:
651
+ return token
652
+ else:
653
+ return random.randint(3, self.language_vocab_size - 1)
654
+
655
+ def _masking_on_text_tokens(self, tokens, num_tokens, mask_prob):
656
+ bool_masked_pos = [0] * len(tokens)
657
+ to_mask = min(int(num_tokens * mask_prob + 0.5), num_tokens - 1)
658
+ to_mask = max(to_mask, 1)
659
+ num_masked_tokens = 0
660
+ while num_masked_tokens < to_mask:
661
+ i = random.randint(1, num_tokens - 1)
662
+ if bool_masked_pos[i] == 0:
663
+ bool_masked_pos[i] = 1
664
+ tokens[i] = self._get_mask_token(tokens[i])
665
+ num_masked_tokens += 1
666
+
667
+ return tokens, bool_masked_pos
668
+
669
+ def __getitem__(self, index: int):
670
+ data = dict()
671
+ item = self.items[index]
672
+ img_path = item["image_path"]
673
+ img = self._get_image(img_path)
674
+ data["image"] = img
675
+ data["image_id"] = item["image_id"]
676
+
677
+ text_segment = item["text_segment"]
678
+ if text_segment is not None:
679
+ language_tokens, padding_mask, num_tokens = self._get_text_segment(text_segment)
680
+ masked_tokens = language_tokens[:]
681
+ masked_tokens, language_masked_pos = \
682
+ self._masking_on_text_tokens(masked_tokens, num_tokens, self.mask_prob)
683
+ data["language_tokens"] = language_tokens
684
+ data["masked_tokens"] = masked_tokens
685
+ data["language_masked_pos"] = language_masked_pos
686
+ data["padding_mask"] = padding_mask
687
+ return data
688
+
689
+ @staticmethod
690
+ def make_coco_captioning_dataset_index(data_path, tokenizer):
691
+ _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
692
+ _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
693
+ _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")
694
+
695
+ @staticmethod
696
+ def make_nocaps_captioning_dataset_index(data_path):
697
+ _make_nocaps_dataset_index(data_path, split="val")
698
+ _make_nocaps_dataset_index(data_path, split="test")
699
+
700
+
701
+ task2dataset = {
702
+ "nlvr2": NLVR2Dataset,
703
+ "vqav2": VQAv2Dataset,
704
+ "flickr30k": RetrievalDataset,
705
+ "coco_retrieval": RetrievalDataset,
706
+ "coco_captioning": CaptioningDataset,
707
+ "nocaps": CaptioningDataset,
708
+ "imagenet": ImageNetDataset,
709
+ }
710
+
711
+
712
+ def create_dataloader(dataset, is_train, batch_size, num_workers, pin_mem, dist_eval=False):
713
+ if is_train or dist_eval:
714
+ num_tasks = utils.get_world_size()
715
+ global_rank = utils.get_rank()
716
+
717
+ if not is_train and dist_eval and len(dataset) % num_tasks != 0:
718
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
719
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
720
+ 'equal num of samples per-process.')
721
+
722
+ sampler = torch.utils.data.DistributedSampler(
723
+ dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train
724
+ )
725
+ else:
726
+ sampler = torch.utils.data.SequentialSampler(dataset)
727
+
728
+ return torch.utils.data.DataLoader(
729
+ dataset, sampler=sampler,
730
+ batch_size=batch_size,
731
+ num_workers=num_workers,
732
+ pin_memory=pin_mem,
733
+ drop_last=is_train,
734
+ collate_fn=utils.merge_batch_tensors_by_dict_key,
735
+ )
736
+
737
+
738
+ def build_transform(is_train, args):
739
+ if args.task in ["imagenet"]:
740
+ return build_imagenet_transform(is_train, args)
741
+
742
+ if is_train:
743
+ t = [
744
+ RandomResizedCropAndInterpolation(args.input_size, scale=(0.5, 1.0), interpolation=args.train_interpolation),
745
+ transforms.RandomHorizontalFlip(),
746
+ ]
747
+ if args.randaug:
748
+ t.append(
749
+ RandomAugment(
750
+ 2, 7, isPIL=True,
751
+ augs=[
752
+ 'Identity','AutoContrast','Equalize','Brightness','Sharpness',
753
+ 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
754
+ ]))
755
+ t += [
756
+ transforms.ToTensor(),
757
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
758
+ ]
759
+ t = transforms.Compose(t)
760
+ else:
761
+ t = transforms.Compose([
762
+ transforms.Resize((args.input_size, args.input_size), interpolation=3),
763
+ transforms.ToTensor(),
764
+ transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
765
+ ])
766
+
767
+ return t
768
+
769
+
770
+ def build_imagenet_transform(is_train, args):
771
+ resize_im = args.input_size > 32
772
+ if is_train:
773
+ # this should always dispatch to transforms_imagenet_train
774
+ transform = create_transform(
775
+ input_size=args.input_size,
776
+ is_training=True,
777
+ color_jitter=args.color_jitter,
778
+ auto_augment=args.aa,
779
+ interpolation=args.train_interpolation,
780
+ re_prob=args.reprob,
781
+ re_mode=args.remode,
782
+ re_count=args.recount,
783
+ mean=IMAGENET_DEFAULT_MEAN,
784
+ std=IMAGENET_DEFAULT_STD,
785
+ )
786
+ if not resize_im:
787
+ # replace RandomResizedCropAndInterpolation with
788
+ # RandomCrop
789
+ transform.transforms[0] = transforms.RandomCrop(
790
+ args.input_size, padding=4)
791
+ return transform
792
+
793
+ t = []
794
+ if resize_im:
795
+ if args.crop_pct is None:
796
+ args.crop_pct = 1.0
797
+ size = int(args.input_size / args.crop_pct)
798
+ t.append(
799
+ transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
800
+ )
801
+ t.append(transforms.CenterCrop(args.input_size))
802
+
803
+ t.append(transforms.ToTensor())
804
+ t.append(transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD))
805
+ return transforms.Compose(t)
806
+
807
+
808
+ def get_sentencepiece_model_for_beit3(args):
809
+ from transformers import XLMRobertaTokenizer
810
+ return XLMRobertaTokenizer(args.sentencepiece_model)
811
+
812
+
813
+ def create_dataset_by_split(args, split, is_train=True):
814
+ transform = build_transform(is_train=is_train, args=args)
815
+ dataset_class = task2dataset[args.task]
816
+ tokenizer = get_sentencepiece_model_for_beit3(args)
817
+
818
+ opt_kwargs = {}
819
+ if args.task in ["coco_captioning", "nocaps"]:
820
+ opt_kwargs["mask_prob"] = args.captioning_mask_prob
821
+
822
+ dataset = dataset_class(
823
+ data_path=args.data_path, split=split,
824
+ transform=transform, tokenizer=tokenizer,
825
+ num_max_bpe_tokens=args.num_max_bpe_tokens,
826
+ task=args.task, **opt_kwargs,
827
+ )
828
+ if is_train:
829
+ batch_size = args.batch_size
830
+ elif hasattr(args, "eval_batch_size") and args.eval_batch_size is not None:
831
+ batch_size = args.eval_batch_size
832
+ else:
833
+ batch_size = int(args.batch_size * 1.5)
834
+
835
+ return create_dataloader(
836
+ dataset, is_train=is_train, batch_size=batch_size,
837
+ num_workers=args.num_workers, pin_mem=args.pin_mem, dist_eval=args.dist_eval,
838
+ )
839
+
840
+
841
+ def create_downstream_dataset(args, is_eval=False):
842
+ if is_eval:
843
+ return create_dataset_by_split(args, split="test", is_train=False)
844
+ else:
845
+ return \
846
+ create_dataset_by_split(args, split="train", is_train=True), \
847
+ create_dataset_by_split(args, split="val", is_train=True)
model/unilm/beit3/engine_for_finetuning.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import math
9
+ import sys
10
+ import json
11
+ from typing import Iterable, Optional
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from timm.utils import ModelEma
18
+ from timm.utils import accuracy, ModelEma
19
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
20
+ from datasets import get_sentencepiece_model_for_beit3
21
+
22
+ import utils
23
+
24
+
25
+ class TaskHandler(object):
26
+ def __init__(self) -> None:
27
+ self.metric_logger = None
28
+ self.split = None
29
+
30
+ def train_batch(self, model, **kwargs):
31
+ raise NotImplementedError()
32
+
33
+ def eval_batch(self, model, **kwargs):
34
+ raise NotImplementedError()
35
+
36
+ def before_eval(self, metric_logger, data_loader, **kwargs):
37
+ self.metric_logger = metric_logger
38
+ self.split = data_loader.dataset.split
39
+
40
+ def after_eval(self, **kwargs):
41
+ raise NotImplementedError()
42
+
43
+
44
+ class NLVR2Handler(TaskHandler):
45
+ def __init__(self) -> None:
46
+ super().__init__()
47
+ self.criterion = torch.nn.CrossEntropyLoss()
48
+
49
+ def train_batch(self, model, image, image2, language_tokens, padding_mask, label):
50
+ logits = model(
51
+ image_a=image, image_b=image2,
52
+ text_description=language_tokens,
53
+ padding_mask=padding_mask)
54
+ acc = (logits.max(-1)[-1] == label).float().mean()
55
+ return {
56
+ "loss": self.criterion(input=logits, target=label),
57
+ "acc": acc,
58
+ }
59
+
60
+ def eval_batch(self, model, image, image2, language_tokens, padding_mask, label):
61
+ logits = model(
62
+ image_a=image, image_b=image2,
63
+ text_description=language_tokens,
64
+ padding_mask=padding_mask)
65
+ batch_size = language_tokens.shape[0]
66
+ acc = (logits.max(-1)[-1] == label).float().sum(0) * 100.0 / batch_size
67
+ self.metric_logger.meters['acc'].update(acc.item(), n=batch_size)
68
+
69
+ def after_eval(self, **kwargs):
70
+ print('* Acc {acc.global_avg:.3f}'.format(acc=self.metric_logger.acc))
71
+ return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc"
72
+
73
+
74
+ class ImageNetHandler(TaskHandler):
75
+ def __init__(self, args) -> None:
76
+ super().__init__()
77
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
78
+ if mixup_active:
79
+ # smoothing is handled with mixup label transform
80
+ self.criterion = SoftTargetCrossEntropy()
81
+ elif args.label_smoothing > 0.:
82
+ self.criterion = LabelSmoothingCrossEntropy(smoothing=args.label_smoothing)
83
+ else:
84
+ self.criterion = torch.nn.CrossEntropyLoss()
85
+
86
+ def train_batch(self, model, image, label):
87
+ logits = model(image=image)
88
+ return {
89
+ "loss": self.criterion(logits, label),
90
+ }
91
+
92
+ def eval_batch(self, model, image, label):
93
+ logits = model(image=image)
94
+ batch_size = image.shape[0]
95
+ acc1, acc5 = accuracy(logits, label, topk=(1, 5))
96
+ self.metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
97
+ self.metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
98
+
99
+ def after_eval(self, **kwargs):
100
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}'
101
+ .format(top1=self.metric_logger.acc1, top5=self.metric_logger.acc5))
102
+ return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "acc1"
103
+
104
+
105
+ class RetrievalHandler(TaskHandler):
106
+ def __init__(self) -> None:
107
+ super().__init__()
108
+ self.image_feats = []
109
+ self.text_feats = []
110
+ self.image_ids = []
111
+ self.metric_logger = None
112
+
113
+ def train_batch(self, model, image, language_tokens, padding_mask, image_id):
114
+ loss, vision_cls, language_cls = model(
115
+ image=image, text_description=language_tokens, padding_mask=padding_mask)
116
+ return {
117
+ "loss": loss,
118
+ }
119
+
120
+ def before_eval(self, metric_logger, **kwargs):
121
+ self.image_feats.clear()
122
+ self.text_feats.clear()
123
+ self.image_ids.clear()
124
+ self.metric_logger = metric_logger
125
+
126
+ def eval_batch(self, model, image, language_tokens, padding_mask, image_id):
127
+ vision_cls, _ = model(image=image, only_infer=True)
128
+ _, language_cls = model(
129
+ text_description=language_tokens, padding_mask=padding_mask, only_infer=True)
130
+
131
+ self.image_feats.append(vision_cls.clone())
132
+ self.text_feats.append(language_cls.clone())
133
+ self.image_ids.append(image_id.clone())
134
+
135
+ def after_eval(self, **kwargs):
136
+ image_feats = {}
137
+ for feats, ids in zip(self.image_feats, self.image_ids):
138
+ for i, _idx in enumerate(ids):
139
+ idx = _idx.item()
140
+ if idx not in image_feats:
141
+ image_feats[idx] = feats[i]
142
+
143
+ tiids = torch.cat(self.image_ids, dim=0)
144
+ iids = []
145
+ sorted_tensors = []
146
+ for key in sorted(image_feats.keys()):
147
+ sorted_tensors.append(image_feats[key].view(1, -1))
148
+ iids.append(key)
149
+
150
+ image_cls_feats = torch.cat(sorted_tensors, dim=0)
151
+ text_cls_feats = torch.cat(self.text_feats, dim=0)
152
+
153
+ scores = image_cls_feats @ text_cls_feats.t()
154
+ iids = torch.LongTensor(iids).to(scores.device)
155
+
156
+ print("scores: {}".format(scores.size()))
157
+ print("iids: {}".format(iids.size()))
158
+ print("tiids: {}".format(tiids.size()))
159
+
160
+ topk10 = scores.topk(10, dim=1)
161
+ topk5 = scores.topk(5, dim=1)
162
+ topk1 = scores.topk(1, dim=1)
163
+
164
+ topk10_iids = tiids[topk10.indices]
165
+ topk5_iids = tiids[topk5.indices]
166
+ topk1_iids = tiids[topk1.indices]
167
+
168
+ tr_r10 = (iids.unsqueeze(1) == topk10_iids).float().max(dim=1)[0].mean()
169
+ tr_r5 = (iids.unsqueeze(1) == topk5_iids).float().max(dim=1)[0].mean()
170
+ tr_r1 = (iids.unsqueeze(1) == topk1_iids).float().max(dim=1)[0].mean()
171
+
172
+ topk10 = scores.topk(10, dim=0)
173
+ topk5 = scores.topk(5, dim=0)
174
+ topk1 = scores.topk(1, dim=0)
175
+ topk10_iids = iids[topk10.indices]
176
+ topk5_iids = iids[topk5.indices]
177
+ topk1_iids = iids[topk1.indices]
178
+
179
+ ir_r10 = (tiids.unsqueeze(0) == topk10_iids).float().max(dim=0)[0].mean()
180
+ ir_r5 = (tiids.unsqueeze(0) == topk5_iids).float().max(dim=0)[0].mean()
181
+ ir_r1 = (tiids.unsqueeze(0) == topk1_iids).float().max(dim=0)[0].mean()
182
+
183
+ eval_result = {
184
+ "tr_r10": tr_r10.item() * 100.0,
185
+ "tr_r5": tr_r5.item() * 100.0,
186
+ "tr_r1": tr_r1.item() * 100.0,
187
+ "ir_r10": ir_r10.item() * 100.0,
188
+ "ir_r5": ir_r5.item() * 100.0,
189
+ "ir_r1": ir_r1.item() * 100.0,
190
+ "average_score": 100.0 * (tr_r1 + tr_r5 + tr_r10 + ir_r1 + ir_r5 + ir_r10).item() / 6.0,
191
+ }
192
+
193
+ print('* Eval result = %s' % json.dumps(eval_result))
194
+ return eval_result, "average_score"
195
+
196
+
197
+ class VQAHandler(TaskHandler):
198
+ def __init__(self) -> None:
199
+ super().__init__()
200
+ self.predictions = []
201
+ self.criterion = nn.BCEWithLogitsLoss(reduction='mean')
202
+ self.label2ans = None
203
+
204
+ def train_batch(self, model, image, language_tokens, padding_mask, labels):
205
+ logits = model(
206
+ image=image, question=language_tokens,
207
+ padding_mask=padding_mask)
208
+ return {
209
+ "loss": self.criterion(input=logits.float(), target=labels.float()) * labels.shape[1],
210
+ }
211
+
212
+ def before_eval(self, metric_logger, data_loader, **kwargs):
213
+ self.predictions.clear()
214
+ self.metric_logger = metric_logger
215
+ self.label2ans = data_loader.dataset.label2ans
216
+
217
+ def eval_batch(self, model, image, language_tokens, padding_mask, labels=None, qid=None):
218
+ logits = model(
219
+ image=image, question=language_tokens,
220
+ padding_mask=padding_mask)
221
+ batch_size = language_tokens.shape[0]
222
+ if labels is not None:
223
+ scores = utils.VQAScore()(logits, labels) * 100.0
224
+ self.metric_logger.meters['score'].update(scores.item(), n=batch_size)
225
+ else:
226
+ _, preds = logits.max(-1)
227
+ for image_id, pred in zip(qid, preds):
228
+ self.predictions.append({
229
+ "question_id": image_id.item(),
230
+ "answer": self.label2ans[pred.item()],
231
+ })
232
+
233
+ def after_eval(self, **kwargs):
234
+ if len(self.predictions) == 0:
235
+ print('* Score {score.global_avg:.3f}'.format(score=self.metric_logger.score))
236
+ return {k: meter.global_avg for k, meter in self.metric_logger.meters.items()}, "score"
237
+ else:
238
+ return self.predictions, "prediction"
239
+
240
+
241
+ class CaptioningHandler(TaskHandler):
242
+ def __init__(self, args) -> None:
243
+ super().__init__()
244
+ self.predictions = []
245
+ self.criterion = utils.BertCaptioningLoss(args.label_smoothing, args.drop_worst_ratio, args.drop_worst_after)
246
+ self.tokenizer = get_sentencepiece_model_for_beit3(args)
247
+ self.num_beams = args.num_beams
248
+ self.max_len = args.num_max_bpe_tokens
249
+ self.length_penalty = args.length_penalty
250
+ self.vocab_size = args.vocab_size
251
+
252
+ def train_batch(self, model, image, language_tokens, masked_tokens, language_masked_pos, padding_mask, image_id, global_step):
253
+ logits, _ = model(
254
+ image=image, text_ids=masked_tokens, padding_mask=padding_mask, language_masked_pos=language_masked_pos, image_id=image_id)
255
+ masked_labels = language_tokens[language_masked_pos.bool()]
256
+ score = torch.max(logits, -1)[1].data == masked_labels
257
+ acc = torch.sum(score.float()) / torch.sum(language_masked_pos)
258
+ return {
259
+ "loss": self.criterion(logits, masked_labels, global_step),
260
+ "acc": acc
261
+ }
262
+
263
+ def before_eval(self, metric_logger, data_loader, **kwargs):
264
+ self.predictions.clear()
265
+ self.metric_logger = metric_logger
266
+
267
+ def eval_batch(self, model, image, image_id=None):
268
+ cur_len = 2
269
+ num_keep_best = 1
270
+ TOPN_PER_BEAM = 3
271
+
272
+ batch_size = image.size(0)
273
+ mask_id = self.tokenizer.mask_token_id
274
+ cls_id = self.tokenizer.cls_token_id
275
+ pad_id = self.tokenizer.pad_token_id
276
+ sep_id = self.tokenizer.sep_token_id
277
+ eos_token_ids = [sep_id]
278
+
279
+ cls_ids = torch.full(
280
+ (batch_size, 1), cls_id, dtype=torch.long, device=image.device
281
+ )
282
+ mask_ids = torch.full(
283
+ (batch_size, 1), mask_id, dtype=torch.long, device=image.device
284
+ )
285
+ cur_input_ids = torch.cat([cls_ids, mask_ids], dim=1)
286
+ tmp_ids = torch.full(
287
+ (batch_size, self.max_len-1), mask_id, dtype=torch.long, device=image.device
288
+ )
289
+ decoding_results = torch.cat([cls_ids, tmp_ids], dim=1)
290
+
291
+ # Expand input to num beams
292
+ cur_input_ids = cur_input_ids.unsqueeze(1).expand(batch_size, self.num_beams, cur_len)
293
+ cur_input_ids = cur_input_ids.contiguous().view(batch_size * self.num_beams, cur_len) # (batch_size * num_beams, cur_len)
294
+ decoding_results = decoding_results.unsqueeze(1).expand(batch_size, self.num_beams, self.max_len)
295
+ decoding_results = decoding_results.contiguous().view(batch_size * self.num_beams, self.max_len) # (batch_size * num_beams, cur_len)
296
+ image = image.unsqueeze(1).expand(batch_size, self.num_beams, image.size(-3), image.size(-2), image.size(-1))
297
+ image = image.contiguous().view(batch_size * self.num_beams, image.size(-3), image.size(-2), image.size(-1))
298
+
299
+ generated_hyps = [
300
+ utils.BeamHypotheses(
301
+ num_keep_best, self.max_len, length_penalty=self.length_penalty, early_stopping=False
302
+ ) for _ in range(batch_size)
303
+ ]
304
+ # scores for each sentence in the beam
305
+ beam_scores = torch.zeros((batch_size, self.num_beams), dtype=torch.float, device=cur_input_ids.device)
306
+ beam_scores[:, 1:] = -1e9
307
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
308
+
309
+ # done sentences
310
+ done = [False for _ in range(batch_size)]
311
+ incremental_state = {}
312
+
313
+ while cur_len <= self.max_len:
314
+ next_token_idx = 1
315
+ padding_masks = torch.full(
316
+ cur_input_ids.shape, 0, dtype=torch.long, device=image.device
317
+ )
318
+ input_image = image
319
+ if cur_len != 2:
320
+ input_image = None
321
+
322
+ outputs, incremental_state_next = model(
323
+ image=input_image, text_ids=cur_input_ids, language_masked_pos=None,
324
+ padding_mask=padding_masks, text_len=cur_len, incremental_state=incremental_state)
325
+ incremental_state = incremental_state_next
326
+
327
+ # assert outputs.shape[1] == token_len
328
+ scores = outputs[:, next_token_idx, :] # (batch_size * num_beams, vocab_size)
329
+ scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
330
+ assert scores.size() == (batch_size * self.num_beams, self.vocab_size)
331
+ # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
332
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
333
+ # re-organize to group the beam together (we are keeping top hypothesis accross beams)
334
+ _scores = _scores.view(batch_size, self.num_beams * self.vocab_size) # (batch_size, num_beams * vocab_size)
335
+ next_scores, next_words = torch.topk(_scores, TOPN_PER_BEAM * self.num_beams, dim=1, largest=True, sorted=True)
336
+ assert next_scores.size() == next_words.size() == (batch_size, TOPN_PER_BEAM * self.num_beams)
337
+
338
+ # next batch beam content
339
+ # list of (batch_size * num_beams) tuple(next hypothesis score, next word, current position in the batch)
340
+ next_batch_beam = []
341
+ # for each sentence
342
+ for batch_ex in range(batch_size):
343
+ # if we are done with this sentence
344
+ done[batch_ex] = done[batch_ex] or generated_hyps[batch_ex].is_done(next_scores[batch_ex].max().item())
345
+ if done[batch_ex]:
346
+ next_batch_beam.extend([(0, pad_id, 0)] * self.num_beams) # pad the batch
347
+ continue
348
+
349
+ # next sentence beam content
350
+ next_sent_beam = []
351
+ for idx, score in zip(next_words[batch_ex], next_scores[batch_ex]):
352
+ # get beam and word IDs
353
+ beam_id = idx // self.vocab_size
354
+ word_id = idx % self.vocab_size
355
+ # end of sentence, or next word
356
+ # if word_id.item() in eos_token_ids or cur_len + 1 == max_len:
357
+ if (word_id.item() in eos_token_ids and cur_len + 1 <= self.max_len) or (cur_len + 1 == self.max_len):
358
+ generated_hyps[batch_ex].add(
359
+ decoding_results[batch_ex * self.num_beams + beam_id, :cur_len].clone(), score.item()
360
+ )
361
+ else:
362
+ next_sent_beam.append((score, word_id, batch_ex * self.num_beams + beam_id))
363
+ # the beam for next step is full
364
+ if len(next_sent_beam) == self.num_beams:
365
+ break
366
+
367
+ # update next beam content
368
+ if cur_len + 1 == self.max_len:
369
+ assert len(next_sent_beam) == 0
370
+ else:
371
+ assert len(next_sent_beam) == self.num_beams
372
+
373
+ if len(next_sent_beam) == 0:
374
+ next_sent_beam = [(0, pad_id, 0)] * self.num_beams # pad the batch
375
+ next_batch_beam.extend(next_sent_beam)
376
+ assert len(next_batch_beam) == self.num_beams * (batch_ex + 1)
377
+
378
+ # sanity check / prepare next batch
379
+ assert len(next_batch_beam) == batch_size * self.num_beams
380
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
381
+ beam_words = cur_input_ids.new([x[1] for x in next_batch_beam])
382
+ beam_idx = cur_input_ids.new([x[2] for x in next_batch_beam])
383
+
384
+ # re-order batch
385
+ cur_input_ids = cur_input_ids[beam_idx, :]
386
+ decoding_results = decoding_results[beam_idx, :]
387
+ for module in incremental_state:
388
+ for key in incremental_state[module]:
389
+ result = incremental_state[module][key].index_select(0, beam_idx)
390
+ incremental_state[module][key] = result[:,:,:-1,:]
391
+
392
+ next_ids = torch.full(
393
+ (batch_size * self.num_beams, 1), mask_id, dtype=torch.long, device=image.device
394
+ )
395
+ cur_input_ids = torch.cat([beam_words.unsqueeze(1), next_ids], dim=1)
396
+ decoding_results[:, cur_len-1] = beam_words
397
+ # update current length
398
+ cur_len = cur_len + 1
399
+ # stop when we are done with each sentence
400
+ if all(done):
401
+ break
402
+
403
+ # select the best hypotheses
404
+ tgt_len = torch.ones(batch_size, num_keep_best, dtype=torch.long)
405
+ logprobs = torch.zeros(batch_size, num_keep_best,
406
+ dtype=torch.float).fill_(-1e5).to(cur_input_ids.device)
407
+ all_best = []
408
+
409
+ for i, hypotheses in enumerate(generated_hyps):
410
+ best = []
411
+ hyp_scores = torch.tensor([x[0] for x in hypotheses.hyp])
412
+ _, best_indices = torch.topk(hyp_scores,
413
+ min(num_keep_best, len(hyp_scores)), largest=True)
414
+ for best_idx, hyp_idx in enumerate(best_indices):
415
+ conf, best_hyp = hypotheses.hyp[hyp_idx]
416
+ best.append(best_hyp)
417
+ logprobs[i, best_idx] = conf
418
+ tgt_len[i, best_idx] = len(best_hyp) + 1 # +1 for the <EOS> symbol
419
+ all_best.append(best)
420
+
421
+ # generate target batch, pad to the same length
422
+ decoded = cur_input_ids.new(batch_size, num_keep_best, self.max_len).fill_(pad_id)
423
+ for batch_idx, best in enumerate(all_best):
424
+ for best_idx, hypo in enumerate(best):
425
+ decoded[batch_idx, best_idx, : tgt_len[batch_idx, best_idx] - 1] = hypo
426
+ decoded[batch_idx, best_idx, tgt_len[batch_idx, best_idx] - 1] = eos_token_ids[0]
427
+
428
+ captions = self.tokenizer.batch_decode(decoded.squeeze(1), skip_special_tokens=True)
429
+ for qid, pred in zip(image_id, captions):
430
+ self.predictions.append({
431
+ "image_id": qid.item(),
432
+ "caption": pred,
433
+ })
434
+
435
+ def after_eval(self, **kwargs):
436
+ return self.predictions, "prediction"
437
+
438
+
439
+ def get_handler(args):
440
+ if args.task == "nlvr2":
441
+ return NLVR2Handler()
442
+ elif args.task == "vqav2":
443
+ return VQAHandler()
444
+ elif args.task in ("flickr30k", "coco_retrieval"):
445
+ return RetrievalHandler()
446
+ elif args.task in ("coco_captioning", "nocaps"):
447
+ return CaptioningHandler(args)
448
+ elif args.task in ("imagenet"):
449
+ return ImageNetHandler(args)
450
+ else:
451
+ raise NotImplementedError("Sorry, %s is not support." % args.task)
452
+
453
+
454
+ def train_one_epoch(
455
+ model: torch.nn.Module, data_loader: Iterable,
456
+ optimizer: torch.optim.Optimizer, device: torch.device,
457
+ handler: TaskHandler, epoch: int, start_steps: int,
458
+ lr_schedule_values: list, loss_scaler, max_norm: float = 0,
459
+ update_freq: int = 1, model_ema: Optional[ModelEma] = None,
460
+ log_writer: Optional[utils.TensorboardLogger] = None,
461
+ task = None, mixup_fn=None,
462
+ ):
463
+ model.train(True)
464
+ metric_logger = utils.MetricLogger(delimiter=" ")
465
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
466
+ metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
467
+ header = 'Epoch: [{}]'.format(epoch)
468
+ print_freq = 10
469
+
470
+ if loss_scaler is None:
471
+ model.zero_grad()
472
+ model.micro_steps = 0
473
+ else:
474
+ optimizer.zero_grad()
475
+
476
+ for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
477
+ step = data_iter_step // update_freq
478
+ global_step = start_steps + step # global training iteration
479
+ # Update LR & WD for the first acc
480
+ if lr_schedule_values is not None and data_iter_step % update_freq == 0:
481
+ for i, param_group in enumerate(optimizer.param_groups):
482
+ if lr_schedule_values is not None:
483
+ param_group["lr"] = lr_schedule_values[global_step] * param_group["lr_scale"]
484
+ # put input data into cuda
485
+ for tensor_key in data.keys():
486
+ data[tensor_key] = data[tensor_key].to(device, non_blocking=True)
487
+ # print("input %s = %s" % (tensor_key, data[tensor_key]))
488
+ if loss_scaler is None and tensor_key.startswith("image"):
489
+ data[tensor_key] = data[tensor_key].half()
490
+
491
+ # mixup for imagenet finetuning
492
+ if mixup_fn is not None:
493
+ data["image"], data["label"] = mixup_fn(data["image"], data["label"])
494
+
495
+ if task in ["coco_captioning", "nocaps"]:
496
+ data["global_step"] = global_step
497
+
498
+ if loss_scaler is None:
499
+ results = handler.train_batch(model, **data)
500
+ else:
501
+ with torch.cuda.amp.autocast():
502
+ results = handler.train_batch(model, **data)
503
+
504
+ loss = results.pop("loss")
505
+ loss_value = loss.item()
506
+
507
+ if not math.isfinite(loss_value):
508
+ print("Loss is {}, stopping training".format(loss_value))
509
+ sys.exit(1)
510
+
511
+ if loss_scaler is None:
512
+ loss /= update_freq
513
+ model.backward(loss)
514
+ model.step()
515
+
516
+ if (data_iter_step + 1) % update_freq == 0:
517
+ # model.zero_grad()
518
+ # Deepspeed will call step() & model.zero_grad() automatic
519
+ if model_ema is not None:
520
+ model_ema.update(model)
521
+ grad_norm = None
522
+ loss_scale_value = utils.get_loss_scale_for_deepspeed(model)
523
+ else:
524
+ # this attribute is added by timm on one optimizer (adahessian)
525
+ is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
526
+ loss /= update_freq
527
+ grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
528
+ parameters=model.parameters(), create_graph=is_second_order,
529
+ update_grad=(data_iter_step + 1) % update_freq == 0)
530
+ if (data_iter_step + 1) % update_freq == 0:
531
+ optimizer.zero_grad()
532
+ if model_ema is not None:
533
+ model_ema.update(model)
534
+ loss_scale_value = loss_scaler.state_dict()["scale"]
535
+
536
+ torch.cuda.synchronize()
537
+
538
+ metric_logger.update(loss=loss_value)
539
+ metric_logger.update(loss_scale=loss_scale_value)
540
+ min_lr = 10.
541
+ max_lr = 0.
542
+ for group in optimizer.param_groups:
543
+ min_lr = min(min_lr, group["lr"])
544
+ max_lr = max(max_lr, group["lr"])
545
+
546
+ metric_logger.update(lr=max_lr)
547
+ metric_logger.update(min_lr=min_lr)
548
+ weight_decay_value = None
549
+ for group in optimizer.param_groups:
550
+ if group["weight_decay"] > 0:
551
+ weight_decay_value = group["weight_decay"]
552
+ metric_logger.update(weight_decay=weight_decay_value)
553
+ metric_logger.update(grad_norm=grad_norm)
554
+
555
+ if log_writer is not None:
556
+ kwargs = {
557
+ "loss": loss_value,
558
+ }
559
+ for key in results:
560
+ kwargs[key] = results[key]
561
+ log_writer.update(head="train", **kwargs)
562
+
563
+ kwargs = {
564
+ "loss_scale": loss_scale_value,
565
+ "lr": max_lr,
566
+ "min_lr": min_lr,
567
+ "weight_decay": weight_decay_value,
568
+ "grad_norm": grad_norm,
569
+ }
570
+ log_writer.update(head="opt", **kwargs)
571
+ log_writer.set_step()
572
+
573
+ # gather the stats from all processes
574
+ metric_logger.synchronize_between_processes()
575
+ print("Averaged stats:", metric_logger)
576
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
577
+
578
+
579
+ @torch.no_grad()
580
+ def evaluate(data_loader, model, device, handler):
581
+ metric_logger = utils.MetricLogger(delimiter=" ")
582
+ header = 'Test:'
583
+
584
+ # switch to evaluation mode
585
+ model.eval()
586
+ handler.before_eval(metric_logger=metric_logger, data_loader=data_loader)
587
+
588
+ for data in metric_logger.log_every(data_loader, 10, header):
589
+ for tensor_key in data.keys():
590
+ data[tensor_key] = data[tensor_key].to(device, non_blocking=True)
591
+
592
+ with torch.cuda.amp.autocast():
593
+ handler.eval_batch(model=model, **data)
594
+
595
+ # gather the stats from all processes
596
+ metric_logger.synchronize_between_processes()
597
+
598
+ return handler.after_eval()
model/unilm/beit3/get_started/get_started_for_captioning.md ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BEiT-3 on Image Captioning
2
+
3
+ ## COCO Captioning Setup
4
+
5
+ 1. [Setup environment](../README.md#setup).
6
+ 2. Download [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip) and [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize the dataset as following structure:
7
+
8
+ ```
9
+ /path/to/your_data/
10
+ train2014/
11
+ COCO_train2014_000000000009.jpg
12
+ ...
13
+ val2014/
14
+ COCO_val2014_000000000042.jpg
15
+ ...
16
+ dataset_coco.json
17
+ ```
18
+
19
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
20
+ ```
21
+ from datasets import CaptioningDataset
22
+ from transformers import XLMRobertaTokenizer
23
+
24
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
25
+
26
+ CaptioningDataset.make_coco_captioning_dataset_index(
27
+ data_path="/path/to/your_data",
28
+ tokenizer=tokenizer,
29
+ )
30
+ ```
31
+
32
+
33
+ ## NoCaps Setup
34
+
35
+ 1. [Setup environment](README.md#setup).
36
+ 2. Download [NoCaps val set](https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json), [NoCaps test set](https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json) and download imags using the urls in val and test json files, then organize the dataset as following structure:
37
+
38
+ ```
39
+ /path/to/your_data/
40
+ val/
41
+ 09c863d76bcf6b00.jpg
42
+ ...
43
+ test/
44
+ 19dc6913830a0a21.jpg
45
+ ...
46
+ nocaps_val_4500_captions.json
47
+ nocaps_test_image_info.json
48
+ ```
49
+
50
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
51
+ ```
52
+ from datasets import CaptioningDataset
53
+ from transformers import XLMRobertaTokenizer
54
+
55
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
56
+
57
+ CaptioningDataset.make_nocaps_captioning_dataset_index(
58
+ data_path="/path/to/your_data",
59
+ )
60
+ ```
61
+ We use COCO captioning training set as the training data of NoCaps.
62
+
63
+
64
+ ## Example: Fine-tuning BEiT-3 on Captioning
65
+
66
+ The BEiT-3 **base** model can be fine-tuned on captioning tasks using 8 V100-32GB:
67
+
68
+ ```bash
69
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
70
+ --model beit3_base_patch16_480 \
71
+ --input_size 480 \
72
+ --task coco_captioning \
73
+ --batch_size 32 \
74
+ --layer_decay 1.0 \
75
+ --lr 4e-5 \
76
+ --randaug \
77
+ --epochs 10 \
78
+ --warmup_epochs 1 \
79
+ --drop_path 0.1 \
80
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
81
+ --finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
82
+ --data_path /path/to/your_data \
83
+ --output_dir /path/to/save/your_model \
84
+ --log_dir /path/to/save/your_model/log \
85
+ --weight_decay 0.05 \
86
+ --seed 42 \
87
+ --save_ckpt_freq 5 \
88
+ --num_max_bpe_tokens 32 \
89
+ --captioning_mask_prob 0.7 \
90
+ --drop_worst_after 12000 \
91
+ --dist_eval \
92
+ --checkpoint_activations \
93
+ --enable_deepspeed
94
+ ```
95
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
96
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
97
+ - `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
98
+ - `lr`: 4e-5 for COCO captioning and 1e-5 for NoCaps.
99
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
100
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory.
101
+
102
+
103
+ The BEiT-3 **large** model can be fine-tuned on captioning tasks using 8 V100-32GB:
104
+
105
+ ```bash
106
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
107
+ --model beit3_large_patch16_480 \
108
+ --input_size 480 \
109
+ --task coco_captioning \
110
+ --batch_size 32 \
111
+ --layer_decay 1.0 \
112
+ --lr 8e-6 \
113
+ --randaug \
114
+ --epochs 10 \
115
+ --warmup_epochs 1 \
116
+ --drop_path 0.1 \
117
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
118
+ --finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
119
+ --data_path /path/to/your_data \
120
+ --output_dir /path/to/save/your_model \
121
+ --log_dir /path/to/save/your_model/log \
122
+ --weight_decay 0.05 \
123
+ --seed 42 \
124
+ --save_ckpt_freq 5 \
125
+ --num_max_bpe_tokens 32 \
126
+ --captioning_mask_prob 0.7 \
127
+ --drop_worst_after 12000 \
128
+ --dist_eval \
129
+ --checkpoint_activations \
130
+ --enable_deepspeed
131
+ ```
132
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
133
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
134
+ - `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
135
+ - `lr`: 8e-6 for COCO captioning and NoCaps.
136
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
137
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory.
138
+
139
+
140
+ ## Example: Evaluate BEiT-3 Fine-tuned model on Captioning
141
+
142
+ - Get the prediction file of the fine-tuned BEiT3-base model on captioning with 8 V100-32GB:
143
+ ```bash
144
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
145
+ --model beit3_base_patch16_480 \
146
+ --input_size 480 \
147
+ --task coco_captioning \
148
+ --batch_size 16 \
149
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
150
+ --finetune /your_beit3_model_path/beit3_base_patch16_480_coco_captioning.pth \
151
+ --data_path /path/to/your_data \
152
+ --output_dir /path/to/save/your_prediction \
153
+ --eval \
154
+ --dist_eval
155
+ ```
156
+ - `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
157
+ - `--finetune`: **beit3_base_patch16_480_coco_captioning.pth** for COCO captioning and **beit3_base_patch16_480_nocaps.pth** for NoCaps dataset.
158
+
159
+ - Get the prediction file of the fine-tuned BEiT3-large model on captioning with 8 V100-32GB:
160
+ ```bash
161
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
162
+ --model beit3_large_patch16_480 \
163
+ --input_size 480 \
164
+ --task coco_captioning \
165
+ --batch_size 16 \
166
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
167
+ --finetune /your_beit3_model_path/beit3_large_patch16_480_coco_captioning.pth \
168
+ --data_path /path/to/your_data \
169
+ --output_dir /path/to/save/your_prediction \
170
+ --eval \
171
+ --dist_eval
172
+ ```
173
+ - `--task`: **coco_captioning** for COCO captioning and **nocaps** for NoCaps dataset.
174
+ - `--finetune`: **beit3_large_patch16_480_coco_captioning.pth** for COCO captioning and **beit3_large_patch16_480_nocaps.pth** for NoCaps dataset.
175
+
176
+ Please then submit the prediction file in the `output_dir` to the [evaluation server](https://eval.ai/web/challenges/challenge-page/355/overview) to obtain the NoCaps val and test results.
model/unilm/beit3/get_started/get_started_for_image_classification.md ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BEiT-3 on ImageNet-1k (Image Classification)
2
+
3
+
4
+ ## Setup
5
+
6
+ 1. [Setup environment](../README.md#setup).
7
+ 2. Download and extract ImageNet-1k from http://image-net.org/.
8
+
9
+ The directory structure is the standard layout of torchvision's [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder). The training and validation data are expected to be in the `train/` folder and `val/` folder, respectively:
10
+
11
+ ```
12
+ /path/to/imagenet/
13
+ train/
14
+ class1/
15
+ img1.jpeg
16
+ class2/
17
+ img2.jpeg
18
+ val/
19
+ class1/
20
+ img3.jpeg
21
+ class/2
22
+ img4.jpeg
23
+ ```
24
+
25
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
26
+ ```
27
+ from datasets import ImageNetDataset
28
+
29
+ ImageNetDataset.make_dataset_index(
30
+ train_data_path = "/path/to/your_data/train",
31
+ val_data_path = "/path/to/your_data/val",
32
+ index_path = "/path/to/your_data"
33
+ )
34
+ ```
35
+
36
+
37
+ ## Example: Fine-tuning BEiT-3 on ImageNet-1k (Image Classification)
38
+
39
+ The BEiT-3 **base** model can be finetuned on ImageNet-1k using 8 V100-32GB:
40
+
41
+ ```bash
42
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
43
+ --model beit3_base_patch16_224 \
44
+ --task imagenet \
45
+ --batch_size 128 \
46
+ --layer_decay 0.65 \
47
+ --lr 7e-4 \
48
+ --update_freq 1 \
49
+ --epochs 50 \
50
+ --warmup_epochs 5 \
51
+ --drop_path 0.15 \
52
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
53
+ --finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
54
+ --data_path /path/to/your_data \
55
+ --output_dir /path/to/save/your_model \
56
+ --log_dir /path/to/save/your_model/log \
57
+ --weight_decay 0.05 \
58
+ --seed 42 \
59
+ --save_ckpt_freq 5 \
60
+ --dist_eval \
61
+ --mixup 0.8 \
62
+ --cutmix 1.0 \
63
+ --enable_deepspeed
64
+ ```
65
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*128*1 = 1024`.
66
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
67
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
68
+
69
+
70
+ The BEiT-3 **large** model can be finetuned on ImageNet-1k using a DGX box (8 V100-32GB):
71
+
72
+ ```bash
73
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
74
+ --model beit3_large_patch16_224 \
75
+ --task imagenet \
76
+ --batch_size 128 \
77
+ --layer_decay 0.8 \
78
+ --lr 2e-4 \
79
+ --update_freq 1 \
80
+ --epochs 50 \
81
+ --warmup_epochs 5 \
82
+ --drop_path 0.25 \
83
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
84
+ --finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
85
+ --data_path /path/to/your_data \
86
+ --output_dir /path/to/save/your_model \
87
+ --log_dir /path/to/save/your_model/log \
88
+ --weight_decay 0.05 \
89
+ --seed 42 \
90
+ --save_ckpt_freq 5 \
91
+ --dist_eval \
92
+ --mixup 0.8 \
93
+ --cutmix 1.0 \
94
+ --enable_deepspeed \
95
+ --checkpoint_activations
96
+ ```
97
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*128 = 1024`.
98
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
99
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
100
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory
101
+
102
+ ## Example: Evaluate BEiT-3 Finetuned model on ImageNet-1k (Image Classification)
103
+
104
+ - Evaluate our fine-tuned BEiT3-base model on ImageNet val with a single GPU:
105
+ ```bash
106
+ python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
107
+ --model beit3_base_patch16_224 \
108
+ --task imagenet \
109
+ --batch_size 128 \
110
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
111
+ --finetune /your_beit3_model_path/beit3_base_patch16_224_in1k.pth \
112
+ --data_path /path/to/your_data \
113
+ --eval \
114
+ --dist_eval
115
+ ```
116
+
117
+ Expected results:
118
+ ```
119
+ * Acc@1 85.400 Acc@5 97.630
120
+ ```
121
+
122
+ - Evaluate our fine-tuned BEiT3-large model on ImageNet val with a single GPU:
123
+ ```bash
124
+ python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
125
+ --model beit3_large_patch16_224 \
126
+ --task imagenet \
127
+ --batch_size 128 \
128
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
129
+ --finetune /your_beit3_model_path/beit3_large_patch16_224_in1k.pth \
130
+ --data_path /path/to/your_data \
131
+ --eval \
132
+ --dist_eval
133
+ ```
134
+
135
+ Expected results:
136
+ ```
137
+ * Acc@1 87.580 Acc@5 98.326
138
+ ```
model/unilm/beit3/get_started/get_started_for_nlvr2.md ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BEiT-3 on NLVR2 (Visual Reasoning)
2
+
3
+
4
+ ## Setup
5
+
6
+ 1. [Setup environment](../README.md#setup).
7
+ 2. Clone the [repository](https://github.com/lil-lab/nlvr) and sign the [request form](https://goo.gl/forms/yS29stWnFWzrDBFH3) to download the images, then organize the dataset as following structure:
8
+
9
+ ```
10
+ /path/to/your_data/
11
+ images/train/
12
+ 0/train-11670-0-img0.png
13
+ ...
14
+ dev/
15
+ dev-269-0-img0.png
16
+ ...
17
+ test1/
18
+ test1-261-0-img0.png
19
+ ...
20
+ nlvr/ (nlvr repo)
21
+ nlvr/
22
+ nlvr2/
23
+ ```
24
+
25
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
26
+ ```
27
+ from datasets import NLVR2Dataset
28
+ from transformers import XLMRobertaTokenizer
29
+
30
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
31
+
32
+ NLVR2Dataset.make_dataset_index(
33
+ data_path="/path/to/your_data",
34
+ tokenizer=tokenizer,
35
+ nlvr_repo_path="/path/to/your_data/nlvr"
36
+ )
37
+ ```
38
+
39
+
40
+ ## Example: Fine-tuning BEiT-3 on NLVR2 (Visual Reasoning)
41
+
42
+ The BEiT-3 **base** model can be finetuned on NLVR2 using 8 V100-32GB:
43
+
44
+ ```bash
45
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
46
+ --model beit3_base_patch16_224 \
47
+ --task nlvr2 \
48
+ --batch_size 32 \
49
+ --layer_decay 0.65 \
50
+ --lr 7e-4 \
51
+ --epochs 20 \
52
+ --warmup_epochs 5 \
53
+ --drop_path 0.2 \
54
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
55
+ --finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
56
+ --data_path /path/to/your_data \
57
+ --output_dir /path/to/save/your_model \
58
+ --log_dir /path/to/save/your_model/log \
59
+ --weight_decay 0.2 \
60
+ --seed 42 \
61
+ --save_ckpt_freq 5 \
62
+ --enable_deepspeed
63
+ ```
64
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
65
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
66
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
67
+ - `--lr`: 7e-4 for `BEiT3-base`, 5e-4 for `BEiT3-base-indomain`.
68
+
69
+
70
+ The BEiT-3 **large** model can be finetuned on NLVR2 using 8 V100-32GB:
71
+
72
+ ```bash
73
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
74
+ --model beit3_large_patch16_224 \
75
+ --task nlvr2 \
76
+ --batch_size 32 \
77
+ --layer_decay 0.85 \
78
+ --lr 3e-4 \
79
+ --epochs 20 \
80
+ --warmup_epochs 5 \
81
+ --drop_path 0.2 \
82
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
83
+ --finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
84
+ --data_path /path/to/your_data \
85
+ --output_dir /path/to/save/your_model \
86
+ --log_dir /path/to/save/your_model/log \
87
+ --weight_decay 0.2 \
88
+ --seed 42 \
89
+ --save_ckpt_freq 5 \
90
+ --enable_deepspeed \
91
+ --checkpoint_activations
92
+ ```
93
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*32 = 256`.
94
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models).
95
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
96
+ - `--lr`: 3e-4 for `BEiT3-large`, 1e-4 for `BEiT3-large-indomain`.
97
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory.
98
+
99
+
100
+ ## Example: Evaluate BEiT-3 Finetuned model on NLVR2 (Visual Reasoning)
101
+
102
+ - Get the result of our fine-tuned BEiT3-base model on NLVR2 test with 8 V100-32GB:
103
+ ```bash
104
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
105
+ --model beit3_base_patch16_224 \
106
+ --task nlvr2 \
107
+ --batch_size 32 \
108
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
109
+ --finetune /your_beit3_model_path/beit3_base_patch16_224_nlvr2.pth \
110
+ --data_path /path/to/your_data \
111
+ --eval \
112
+ --dist_eval
113
+ ```
114
+
115
+ Expected results:
116
+ ```
117
+ * Acc 84.386
118
+ ```
119
+
120
+ - Get the result of our fine-tuned BEiT3-large model on NLVR2 test with 8 V100-32GB:
121
+ ```bash
122
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
123
+ --model beit3_large_patch16_224 \
124
+ --task nlvr2 \
125
+ --batch_size 32 \
126
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
127
+ --finetune /your_beit3_model_path/beit3_large_patch16_224_nlvr2.pth \
128
+ --data_path /path/to/your_data \
129
+ --eval \
130
+ --dist_eval
131
+ ```
132
+
133
+ Expected results:
134
+ ```
135
+ * Acc 89.437
136
+ ```
model/unilm/beit3/get_started/get_started_for_retrieval.md ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BEiT-3 on Image-text Retrieval
2
+
3
+ ## COCO Retrieval Setup
4
+
5
+ 1. [Setup environment](../README.md#setup).
6
+ 2. Download [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip) and [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize the dataset as following structure:
7
+
8
+ ```
9
+ /path/to/your_data/
10
+ train2014/
11
+ COCO_train2014_000000000009.jpg
12
+ ...
13
+ val2014/
14
+ COCO_val2014_000000000042.jpg
15
+ ...
16
+ dataset_coco.json
17
+ ```
18
+
19
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
20
+ ```
21
+ from datasets import RetrievalDataset
22
+ from transformers import XLMRobertaTokenizer
23
+
24
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
25
+
26
+ RetrievalDataset.make_coco_dataset_index(
27
+ data_path="/path/to/your_data",
28
+ tokenizer=tokenizer,
29
+ )
30
+ ```
31
+
32
+
33
+ ## Flickr30k Retrieval Setup
34
+
35
+ 1. [Setup environment](README.md#setup).
36
+ 2. Sign [flickr images request form](https://forms.illinois.edu/sec/229675) and download [karpathy split](https://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip), then organize the dataset as following structure:
37
+
38
+ ```
39
+ /path/to/your_data/
40
+ flickr30k-images/
41
+ 2923475135.jpg
42
+ ...
43
+ dataset_flickr30k.json
44
+ ```
45
+
46
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
47
+ ```
48
+ from datasets import RetrievalDataset
49
+ from transformers import XLMRobertaTokenizer
50
+
51
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
52
+
53
+ RetrievalDataset.make_flickr30k_dataset_index(
54
+ data_path="/path/to/your_data",
55
+ tokenizer=tokenizer,
56
+ karpathy_path="/path/to/your_data",
57
+ )
58
+ ```
59
+
60
+
61
+ ## Example: Fine-tuning BEiT-3 on Retrieval
62
+
63
+ The BEiT-3 **base** model can be finetuned on retrieval tasks using 16 V100-32GB:
64
+
65
+ ```bash
66
+ python -m torch.distributed.launch --nproc_per_node=16 run_beit3_finetuning.py \
67
+ --model beit3_base_patch16_384 \
68
+ --input_size 384 \
69
+ --task coco_retrieval \
70
+ --batch_size 192 \
71
+ --layer_decay 0.65 \
72
+ --lr 2e-4 \
73
+ --epochs 15 \
74
+ --warmup_epochs 3 \
75
+ --drop_path 0.2 \
76
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
77
+ --finetune /your_beit3_model_path/beit3_base_itc_patch16_224.pth \
78
+ --data_path /path/to/your_data \
79
+ --output_dir /path/to/save/your_model \
80
+ --log_dir /path/to/save/your_model/log \
81
+ --weight_decay 0.05 \
82
+ --seed 42 \
83
+ --save_ckpt_freq 5 \
84
+ --enable_deepspeed \
85
+ --checkpoint_activations
86
+ ```
87
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `192*16 = 3072`.
88
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
89
+ - `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
90
+ - `--lr`: 2e-4 for COCO retrieval, 1e-4 for Flickr30k retrieval
91
+ - `--epochs`: 15 for COCO retrieval, 20 for Flickr30k retrieval
92
+ - `--warmup_epochs`: 3 for COCO retrieval, 5 for Flickr30k retrieval
93
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory
94
+
95
+
96
+ The BEiT-3 **large** model can be finetuned on retrieval tasks using 2x16 V100-32GB:
97
+
98
+ ```bash
99
+ python -m torch.distributed.launch --nproc_per_node=16 --nnodes=2 --node_rank=$NODE_RANK \
100
+ --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT run_beit3_finetuning.py \
101
+ --model beit3_large_patch16_384 \
102
+ --input_size 384 \
103
+ --task coco_retrieval \
104
+ --batch_size 96 \
105
+ --layer_decay 0.85 \
106
+ --lr 5e-5 \
107
+ --epochs 15 \
108
+ --warmup_epochs 3 \
109
+ --drop_path 0.2 \
110
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
111
+ --finetune /your_beit3_model_path/beit3_large_itc_patch16_224.pth \
112
+ --data_path /path/to/your_data \
113
+ --output_dir /path/to/save/your_model \
114
+ --log_dir /path/to/save/your_model/log \
115
+ --weight_decay 0.05 \
116
+ --seed 42 \
117
+ --save_ckpt_freq 5 \
118
+ --enable_deepspeed \
119
+ --checkpoint_activations
120
+ ```
121
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `96*32 = 3072`.
122
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
123
+ - `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
124
+ - `--epochs`: 15 for COCO retrieval, 20 for Flickr30k retrieval
125
+ - `--warmup_epochs`: 3 for COCO retrieval, 5 for Flickr30k retrieval
126
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory
127
+
128
+
129
+ ## Example: Evaluate BEiT-3 Fine-tuned model on COCO Retrieval and Flickr30k Retrieval
130
+
131
+ - Get the results of our fine-tuned BEiT3-base model on retrieval tasks using a single GPU:
132
+ ```bash
133
+ python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
134
+ --model beit3_base_patch16_384 \
135
+ --input_size 384 \
136
+ --task coco_retrieval \
137
+ --batch_size 16 \
138
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
139
+ --finetune /your_beit3_model_path/beit3_base_patch16_384_coco_retrieval.pth \
140
+ --data_path /path/to/your_data \
141
+ --eval \
142
+ --dist_eval
143
+ ```
144
+ - `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
145
+ - `--finetune`: **beit3_base_patch16_384_coco_retrieval.pth** for COCO retrieval, **beit3_base_patch16_384_f30k_retrieval.pth** for Flickr30k retrieval
146
+
147
+ - Get the results of our fine-tuned BEiT3-large model on retrieval tasks using a single GPU:
148
+ ```bash
149
+ python -m torch.distributed.launch --nproc_per_node=1 run_beit3_finetuning.py \
150
+ --model beit3_large_patch16_384 \
151
+ --input_size 384 \
152
+ --task coco_retrieval \
153
+ --batch_size 16 \
154
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
155
+ --finetune /your_beit3_model_path/beit3_large_patch16_384_coco_retrieval.pth \
156
+ --data_path /path/to/your_data \
157
+ --eval \
158
+ --dist_eval
159
+ ```
160
+ - `--task`: **coco_retrieval** for COCO retrieval, **flickr30k** for Flickr30k retrieval
161
+ - `--finetune`: **beit3_large_patch16_384_coco_retrieval.pth** for COCO retrieval, **beit3_large_patch16_384_f30k_retrieval.pth** for Flickr30k retrieval
model/unilm/beit3/get_started/get_started_for_vqav2.md ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fine-tuning BEiT-3 on VQAv2 (Visual Question Answering)
2
+
3
+
4
+ ## Setup
5
+
6
+ 1. [Setup environment](../README.md#setup).
7
+ 2. Download COCO [2014 train images](http://images.cocodataset.org/zips/train2014.zip), [2014 val images](http://images.cocodataset.org/zips/val2014.zip), [2015 test images](http://images.cocodataset.org/zips/test2015.zip), annotations ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip)), and questions ([train](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip), [val](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip), [test](https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip)), then organize the dataset as following structure:
8
+
9
+ ```
10
+ /path/to/your_data/
11
+ train2014/
12
+ COCO_train2014_000000000009.jpg
13
+ ...
14
+ val2014/
15
+ COCO_val2014_000000000042.jpg
16
+ ...
17
+ test2015/
18
+ COCO_test2015_000000000001.jpg
19
+ ...
20
+ vqa/
21
+ v2_OpenEnded_mscoco_train2014_questions.json
22
+ v2_OpenEnded_mscoco_val2014_questions.json
23
+ v2_OpenEnded_mscoco_test2015_questions.json
24
+ v2_OpenEnded_mscoco_test-dev2015_questions.json
25
+ v2_mscoco_train2014_annotations.json
26
+ v2_mscoco_val2014_annotations.json
27
+ ```
28
+
29
+ We then generate the index json files using the following command. [beit3.spm](https://github.com/addf400/files/releases/download/beit3/beit3.spm) is the sentencepiece model used for tokenizing texts.
30
+ ```
31
+ from datasets import VQAv2Dataset
32
+ from transformers import XLMRobertaTokenizer
33
+
34
+ tokenizer = XLMRobertaTokenizer("/your_beit3_model_path/beit3.spm")
35
+
36
+ VQAv2Dataset.make_dataset_index(
37
+ data_path="/path/to/your_data",
38
+ tokenizer=tokenizer,
39
+ annotation_data_path="/path/to/your_data/vqa",
40
+ )
41
+ ```
42
+
43
+
44
+ ## Example: Fine-tuning BEiT-3 on VQAv2 (Visual Question Answering)
45
+
46
+ The BEiT-3 **base** model can be finetuned on VQAv2 using 8 V100-32GB:
47
+
48
+ ```bash
49
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
50
+ --model beit3_base_patch16_480 \
51
+ --input_size 480 \
52
+ --task vqav2 \
53
+ --batch_size 16 \
54
+ --layer_decay 1.0 \
55
+ --lr 3e-5 \
56
+ --update_freq 1 \
57
+ --randaug \
58
+ --epochs 10 \
59
+ --warmup_epochs 1 \
60
+ --drop_path 0.1 \
61
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
62
+ --finetune /your_beit3_model_path/beit3_base_patch16_224.pth \
63
+ --data_path /path/to/your_data \
64
+ --output_dir /path/to/save/your_model \
65
+ --log_dir /path/to/save/your_model/log \
66
+ --weight_decay 0.01 \
67
+ --seed 42 \
68
+ --save_ckpt_freq 5 \
69
+ --task_head_lr_weight 20 \
70
+ --opt_betas 0.9 0.98 \
71
+ --enable_deepspeed
72
+ ```
73
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*16 = 128`.
74
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
75
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
76
+
77
+
78
+ The BEiT-3 **large** model can be finetuned on VQAv2 using 8 V100-32GB:
79
+
80
+ ```bash
81
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
82
+ --model beit3_large_patch16_480 \
83
+ --input_size 480 \
84
+ --task vqav2 \
85
+ --batch_size 16 \
86
+ --layer_decay 1.0 \
87
+ --lr 2e-5 \
88
+ --update_freq 1 \
89
+ --randaug \
90
+ --epochs 10 \
91
+ --warmup_epochs 1 \
92
+ --drop_path 0.15 \
93
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
94
+ --finetune /your_beit3_model_path/beit3_large_patch16_224.pth \
95
+ --data_path /path/to/your_data \
96
+ --output_dir /path/to/save/your_model \
97
+ --log_dir /path/to/save/your_model/log \
98
+ --weight_decay 0.01 \
99
+ --seed 42 \
100
+ --save_ckpt_freq 5 \
101
+ --task_head_lr_weight 20 \
102
+ --opt_betas 0.9 0.98 \
103
+ --enable_deepspeed \
104
+ --checkpoint_activations
105
+ ```
106
+ - `--batch_size`: batch size per GPU. Effective batch size = `number of GPUs` * `--batch_size` * `--update_freq`. So in the above example, the effective batch size is `8*16 = 128`.
107
+ - `--finetune`: weight path of your pretrained models; please download the pretrained model weights in [README.md](../README.md#pretrained-models)
108
+ - `--enable_deepspeed`: optional. If you use apex, please enable deepspeed.
109
+ - `--checkpoint_activations`: using gradient checkpointing for saving GPU memory
110
+
111
+
112
+ ## Example: Evaluate BEiT-3 Finetuned model on VQAv2 (Visual Question Answering)
113
+
114
+ - Get the prediction file of the fine-tuned BEiT3-base model on VQAv2 test with 8 V100-32GB:
115
+ ```bash
116
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
117
+ --model beit3_base_patch16_480 \
118
+ --input_size 480 \
119
+ --task vqav2 \
120
+ --batch_size 16 \
121
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
122
+ --finetune /your_beit3_model_path/beit3_base_patch16_480_vqa.pth \
123
+ --data_path /path/to/your_data \
124
+ --output_dir /path/to/save/your_prediction \
125
+ --eval \
126
+ --dist_eval
127
+ ```
128
+
129
+ - Get the prediction file of the fine-tuned BEiT3-large model on VQAv2 test with 8 V100-32GB:
130
+ ```bash
131
+ python -m torch.distributed.launch --nproc_per_node=8 run_beit3_finetuning.py \
132
+ --model beit3_large_patch16_480 \
133
+ --input_size 480 \
134
+ --task vqav2 \
135
+ --batch_size 16 \
136
+ --sentencepiece_model /your_beit3_model_path/beit3.spm \
137
+ --finetune /your_beit3_model_path/beit3_large_patch16_480_vqa.pth \
138
+ --data_path /path/to/your_data \
139
+ --output_dir /path/to/save/your_prediction \
140
+ --eval \
141
+ --dist_eval
142
+ ```
143
+
144
+ Please then submit the prediction file in the `output_dir` to the [evaluation server](https://eval.ai/web/challenges/challenge-page/830/overview) to obtain the VQAv2 test-dev and test-std results.
model/unilm/beit3/glossary.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ contractions = {
4
+ "aint": "ain't",
5
+ "arent": "aren't",
6
+ "cant": "can't",
7
+ "couldve": "could've",
8
+ "couldnt": "couldn't",
9
+ "couldn'tve": "couldn't've",
10
+ "couldnt've": "couldn't've",
11
+ "didnt": "didn't",
12
+ "doesnt": "doesn't",
13
+ "dont": "don't",
14
+ "hadnt": "hadn't",
15
+ "hadnt've": "hadn't've",
16
+ "hadn'tve": "hadn't've",
17
+ "hasnt": "hasn't",
18
+ "havent": "haven't",
19
+ "hed": "he'd",
20
+ "hed've": "he'd've",
21
+ "he'dve": "he'd've",
22
+ "hes": "he's",
23
+ "howd": "how'd",
24
+ "howll": "how'll",
25
+ "hows": "how's",
26
+ "Id've": "I'd've",
27
+ "I'dve": "I'd've",
28
+ "Im": "I'm",
29
+ "Ive": "I've",
30
+ "isnt": "isn't",
31
+ "itd": "it'd",
32
+ "itd've": "it'd've",
33
+ "it'dve": "it'd've",
34
+ "itll": "it'll",
35
+ "let's": "let's",
36
+ "maam": "ma'am",
37
+ "mightnt": "mightn't",
38
+ "mightnt've": "mightn't've",
39
+ "mightn'tve": "mightn't've",
40
+ "mightve": "might've",
41
+ "mustnt": "mustn't",
42
+ "mustve": "must've",
43
+ "neednt": "needn't",
44
+ "notve": "not've",
45
+ "oclock": "o'clock",
46
+ "oughtnt": "oughtn't",
47
+ "ow's'at": "'ow's'at",
48
+ "'ows'at": "'ow's'at",
49
+ "'ow'sat": "'ow's'at",
50
+ "shant": "shan't",
51
+ "shed've": "she'd've",
52
+ "she'dve": "she'd've",
53
+ "she's": "she's",
54
+ "shouldve": "should've",
55
+ "shouldnt": "shouldn't",
56
+ "shouldnt've": "shouldn't've",
57
+ "shouldn'tve": "shouldn't've",
58
+ "somebody'd": "somebodyd",
59
+ "somebodyd've": "somebody'd've",
60
+ "somebody'dve": "somebody'd've",
61
+ "somebodyll": "somebody'll",
62
+ "somebodys": "somebody's",
63
+ "someoned": "someone'd",
64
+ "someoned've": "someone'd've",
65
+ "someone'dve": "someone'd've",
66
+ "someonell": "someone'll",
67
+ "someones": "someone's",
68
+ "somethingd": "something'd",
69
+ "somethingd've": "something'd've",
70
+ "something'dve": "something'd've",
71
+ "somethingll": "something'll",
72
+ "thats": "that's",
73
+ "thered": "there'd",
74
+ "thered've": "there'd've",
75
+ "there'dve": "there'd've",
76
+ "therere": "there're",
77
+ "theres": "there's",
78
+ "theyd": "they'd",
79
+ "theyd've": "they'd've",
80
+ "they'dve": "they'd've",
81
+ "theyll": "they'll",
82
+ "theyre": "they're",
83
+ "theyve": "they've",
84
+ "twas": "'twas",
85
+ "wasnt": "wasn't",
86
+ "wed've": "we'd've",
87
+ "we'dve": "we'd've",
88
+ "weve": "we've",
89
+ "werent": "weren't",
90
+ "whatll": "what'll",
91
+ "whatre": "what're",
92
+ "whats": "what's",
93
+ "whatve": "what've",
94
+ "whens": "when's",
95
+ "whered": "where'd",
96
+ "wheres": "where's",
97
+ "whereve": "where've",
98
+ "whod": "who'd",
99
+ "whod've": "who'd've",
100
+ "who'dve": "who'd've",
101
+ "wholl": "who'll",
102
+ "whos": "who's",
103
+ "whove": "who've",
104
+ "whyll": "why'll",
105
+ "whyre": "why're",
106
+ "whys": "why's",
107
+ "wont": "won't",
108
+ "wouldve": "would've",
109
+ "wouldnt": "wouldn't",
110
+ "wouldnt've": "wouldn't've",
111
+ "wouldn'tve": "wouldn't've",
112
+ "yall": "y'all",
113
+ "yall'll": "y'all'll",
114
+ "y'allll": "y'all'll",
115
+ "yall'd've": "y'all'd've",
116
+ "y'alld've": "y'all'd've",
117
+ "y'all'dve": "y'all'd've",
118
+ "youd": "you'd",
119
+ "youd've": "you'd've",
120
+ "you'dve": "you'd've",
121
+ "youll": "you'll",
122
+ "youre": "you're",
123
+ "youve": "you've",
124
+ }
125
+
126
+ manual_map = {
127
+ "none": "0",
128
+ "zero": "0",
129
+ "one": "1",
130
+ "two": "2",
131
+ "three": "3",
132
+ "four": "4",
133
+ "five": "5",
134
+ "six": "6",
135
+ "seven": "7",
136
+ "eight": "8",
137
+ "nine": "9",
138
+ "ten": "10",
139
+ }
140
+ articles = ["a", "an", "the"]
141
+ period_strip = re.compile("(?!<=\d)(\.)(?!\d)")
142
+ comma_strip = re.compile("(\d)(\,)(\d)")
143
+ punct = [
144
+ ";",
145
+ r"/",
146
+ "[",
147
+ "]",
148
+ '"',
149
+ "{",
150
+ "}",
151
+ "(",
152
+ ")",
153
+ "=",
154
+ "+",
155
+ "\\",
156
+ "_",
157
+ "-",
158
+ ">",
159
+ "<",
160
+ "@",
161
+ "`",
162
+ ",",
163
+ "?",
164
+ "!",
165
+ ]
166
+
167
+
168
+ def normalize_word(token):
169
+ _token = token
170
+ for p in punct:
171
+ if (p + " " in token or " " + p in token) or (
172
+ re.search(comma_strip, token) != None
173
+ ):
174
+ _token = _token.replace(p, "")
175
+ else:
176
+ _token = _token.replace(p, " ")
177
+ token = period_strip.sub("", _token, re.UNICODE)
178
+
179
+ _token = []
180
+ temp = token.lower().split()
181
+ for word in temp:
182
+ word = manual_map.setdefault(word, word)
183
+ if word not in articles:
184
+ _token.append(word)
185
+ for i, word in enumerate(_token):
186
+ if word in contractions:
187
+ _token[i] = contractions[word]
188
+ token = " ".join(_token)
189
+ token = token.replace(",", "")
190
+ return token
model/unilm/beit3/modeling_finetune.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from timm.models.registry import register_model
12
+ import numpy as np
13
+
14
+ import utils
15
+ from modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
16
+
17
+
18
+ class TwoLayerMLP(nn.Module):
19
+ def __init__(
20
+ self,
21
+ in_features,
22
+ hidden_features,
23
+ out_features,
24
+ norm_layer,
25
+ norm_input=True,
26
+ ):
27
+ super().__init__()
28
+ self.norm1 = norm_layer(in_features) if norm_input else nn.Identity()
29
+ self.dense1 = nn.Linear(in_features, hidden_features)
30
+ self.norm2 = norm_layer(hidden_features)
31
+ self.act = nn.GELU()
32
+ self.dense2 = nn.Linear(hidden_features, out_features)
33
+
34
+ def forward(self, x):
35
+ x = self.norm1(x)
36
+ x = self.dense1(x)
37
+ x = self.norm2(x)
38
+ x = self.act(x)
39
+ return self.dense2(x)
40
+
41
+
42
+ class Pooler(nn.Module):
43
+ def __init__(self, input_features, output_features, norm_layer):
44
+ super().__init__()
45
+ self.norm = norm_layer(input_features)
46
+ self.dense = nn.Linear(input_features, output_features)
47
+ self.activation = nn.Tanh()
48
+
49
+ def forward(self, x):
50
+ cls_rep = x[:, 0, :]
51
+ cls_rep = self.norm(cls_rep)
52
+ pooled_output = self.dense(cls_rep)
53
+ pooled_output = self.activation(pooled_output)
54
+ return pooled_output
55
+
56
+
57
+ class BEiT3ForVisualReasoning(BEiT3Wrapper):
58
+ def __init__(
59
+ self,
60
+ args,
61
+ num_classes,
62
+ norm_layer=nn.LayerNorm,
63
+ **kwargs
64
+ ):
65
+ super(BEiT3ForVisualReasoning, self).__init__(args=args)
66
+ embed_dim = args.encoder_embed_dim
67
+ self.head = TwoLayerMLP(
68
+ in_features=embed_dim * 4,
69
+ hidden_features=embed_dim * 2,
70
+ out_features=num_classes,
71
+ norm_layer=norm_layer,
72
+ )
73
+ init_scale = 0.001
74
+ self.head.apply(self._init_weights)
75
+ if isinstance(self.head.dense1, nn.Linear):
76
+ self.head.dense1.weight.data.mul_(init_scale)
77
+ self.head.dense1.bias.data.mul_(init_scale)
78
+
79
+ if isinstance(self.head.dense2, nn.Linear):
80
+ self.head.dense2.weight.data.mul_(init_scale)
81
+ self.head.dense2.bias.data.mul_(init_scale)
82
+
83
+ def forward(self, image_a, image_b, text_description, padding_mask, **kwargs):
84
+ bsz, _ = text_description.size()
85
+
86
+ vision_input = torch.cat((image_a, image_b), dim=0)
87
+ language_input = torch.cat((text_description, text_description), dim=0)
88
+ padding_mask = torch.cat((padding_mask, padding_mask), dim=0)
89
+
90
+ outputs = self.beit3(
91
+ textual_tokens=language_input,
92
+ visual_tokens=vision_input,
93
+ text_padding_position=padding_mask,
94
+ )
95
+ x = outputs["encoder_out"]
96
+ multiway_split_position = outputs["multiway_split_position"]
97
+
98
+ vision_cls = x[:, 0, :]
99
+ language_cls = x[:, multiway_split_position, :]
100
+ cls_rep = torch.cat((vision_cls, language_cls), dim=-1)
101
+ a, b = torch.split(cls_rep, split_size_or_sections=[bsz, bsz], dim=0)
102
+ cls_rep = torch.cat((a, b), dim=-1)
103
+ return self.head(cls_rep)
104
+
105
+
106
+ class BEiT3ForImageClassification(BEiT3Wrapper):
107
+ def __init__(
108
+ self,
109
+ args,
110
+ num_classes,
111
+ norm_layer=nn.LayerNorm,
112
+ **kwargs
113
+ ):
114
+ super(BEiT3ForImageClassification, self).__init__(args=args)
115
+ embed_dim = args.encoder_embed_dim
116
+ self.fc_norm = norm_layer(embed_dim)
117
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
118
+
119
+ self.fc_norm.apply(self._init_weights)
120
+ self.head.apply(self._init_weights)
121
+ init_scale = 0.001
122
+ if isinstance(self.head, nn.Linear):
123
+ self.head.weight.data.mul_(init_scale)
124
+ self.head.bias.data.mul_(init_scale)
125
+
126
+ def forward(self, image, **kwargs):
127
+ x = self.beit3(textual_tokens=None, visual_tokens=image)["encoder_out"]
128
+ t = x[:, 1:, :]
129
+ cls_x = self.fc_norm(t.mean(1))
130
+ return self.head(cls_x)
131
+
132
+
133
+ class BEiT3ForCaptioning(BEiT3Wrapper):
134
+ def __init__(
135
+ self,
136
+ args,
137
+ **kwargs
138
+ ):
139
+ super(BEiT3ForCaptioning, self).__init__(args=args)
140
+ embed_dim = args.encoder_embed_dim
141
+ self.mlm_head = nn.Linear(embed_dim, args.vocab_size)
142
+ self.mlm_head.apply(self._init_weights)
143
+
144
+ def forward(self, image, text_ids, padding_mask, language_masked_pos, text_len=None, incremental_state=None, **kwargs):
145
+ text_len = text_len if text_len is not None else text_ids.size(1)
146
+ image_len = self.beit3.vision_embed.num_position_embeddings()
147
+ max_len = text_len + image_len
148
+ uni_mask = torch.zeros((max_len, max_len), dtype=torch.long, device=text_ids.device)
149
+ i_start, i_end = 0, image_len
150
+ t_start, t_end = image_len, max_len
151
+ # triangle mask for caption to caption
152
+ uni_mask[t_start:t_end, t_start:t_end] = torch.tril(torch.ones(text_len, text_len, dtype=torch.long, device=text_ids.device))
153
+ # full attention for caption to image
154
+ uni_mask[t_start:t_end, i_start:i_end] = 1
155
+ # full attention for image to image
156
+ uni_mask[i_start:i_end, i_start:i_end] = 1
157
+ uni_mask = 1-uni_mask
158
+
159
+ if incremental_state is not None:
160
+ for idx in range(self.get_num_layers()):
161
+ if idx not in incremental_state:
162
+ incremental_state[idx] = {}
163
+
164
+ # for incremental decoding
165
+ positions = None
166
+ if image is None:
167
+ uni_mask = uni_mask[-2:]
168
+ padding_mask = None
169
+ # start position (2 (fairseq starts at 2) + cur_position) is equal to text_len
170
+ positions = torch.arange(text_len, text_ids.size(1) + text_len, device=text_ids.device).long().unsqueeze(0)
171
+
172
+ outputs = self.beit3(
173
+ textual_tokens=text_ids,
174
+ visual_tokens=image,
175
+ text_padding_position=padding_mask,
176
+ attn_mask=uni_mask,
177
+ incremental_state=incremental_state,
178
+ positions=positions,
179
+ )
180
+ if image is not None:
181
+ text_feats = outputs["encoder_out"][:, image_len:]
182
+ else:
183
+ text_feats = outputs["encoder_out"]
184
+
185
+ if language_masked_pos is not None:
186
+ text_feats = text_feats[language_masked_pos.bool()]
187
+
188
+ return self.mlm_head(text_feats), incremental_state
189
+
190
+
191
+ class BEiT3ForVisualQuestionAnswering(BEiT3Wrapper):
192
+ def __init__(
193
+ self,
194
+ args,
195
+ num_classes,
196
+ norm_layer=nn.LayerNorm,
197
+ **kwargs
198
+ ):
199
+ super(BEiT3ForVisualQuestionAnswering, self).__init__(args=args)
200
+ embed_dim = args.encoder_embed_dim
201
+ self.pooler = Pooler(
202
+ input_features=embed_dim,
203
+ output_features=embed_dim,
204
+ norm_layer=norm_layer,
205
+ )
206
+ self.pooler.apply(self._init_weights)
207
+ self.head = nn.Sequential(
208
+ nn.Linear(embed_dim, embed_dim * 2),
209
+ norm_layer(embed_dim * 2),
210
+ nn.GELU(),
211
+ nn.Linear(embed_dim * 2, num_classes),
212
+ )
213
+ self.head.apply(self._init_weights)
214
+
215
+ def forward(self, image, question, padding_mask, **kwargs):
216
+ outputs = self.beit3(
217
+ textual_tokens=question,
218
+ visual_tokens=image,
219
+ text_padding_position=padding_mask,
220
+ )
221
+ x = outputs["encoder_out"]
222
+ cls_rep = self.pooler(x)
223
+ return self.head(cls_rep)
224
+
225
+
226
+ class BEiT3ForRetrieval(BEiT3Wrapper):
227
+ def __init__(
228
+ self,
229
+ args,
230
+ **kwargs
231
+ ):
232
+ super(BEiT3ForRetrieval, self).__init__(args=args)
233
+ embed_dim = args.encoder_embed_dim
234
+ self.language_head = nn.Linear(embed_dim, embed_dim, bias=False)
235
+ self.vision_head = nn.Linear(embed_dim, embed_dim, bias=False)
236
+ self.language_head.apply(self._init_weights)
237
+ self.vision_head.apply(self._init_weights)
238
+ self.criterion = utils.ClipLoss(
239
+ rank=utils.get_rank(),
240
+ world_size=utils.get_world_size(),
241
+ )
242
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
243
+
244
+ def forward(self, image=None, text_description=None, padding_mask=None, only_infer=False, **kwargs):
245
+ if image is not None:
246
+ outputs = self.beit3(
247
+ textual_tokens=None,
248
+ visual_tokens=image,
249
+ text_padding_position=None,
250
+ )
251
+ x = outputs["encoder_out"]
252
+ vision_cls = self.vision_head(x[:, 0, :])
253
+ vision_cls = F.normalize(vision_cls, dim=-1)
254
+ else:
255
+ vision_cls = None
256
+
257
+ if text_description is not None:
258
+ outputs = self.beit3(
259
+ textual_tokens=text_description,
260
+ visual_tokens=None,
261
+ text_padding_position=padding_mask,
262
+ )
263
+ x = outputs["encoder_out"]
264
+ language_cls = self.language_head(x[:, 0, :])
265
+ language_cls = F.normalize(language_cls, dim=-1)
266
+ else:
267
+ language_cls = None
268
+
269
+ if only_infer:
270
+ return vision_cls, language_cls
271
+ else:
272
+ loss, logits_per_image, logits_per_text = self.criterion(
273
+ vision_cls, language_cls, self.logit_scale.exp())
274
+ return loss, vision_cls, language_cls
275
+
276
+
277
+ @register_model
278
+ def beit3_base_patch16_224_imageclassification(pretrained=False, **kwargs):
279
+ args = _get_base_config(**kwargs)
280
+ args.normalize_output = False
281
+ model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
282
+ return model
283
+
284
+
285
+ @register_model
286
+ def beit3_large_patch16_224_imageclassification(pretrained=False, **kwargs):
287
+ args = _get_large_config(**kwargs)
288
+ args.normalize_output = False
289
+ model = BEiT3ForImageClassification(args, num_classes=1000, **kwargs)
290
+ return model
291
+
292
+
293
+ @register_model
294
+ def beit3_base_patch16_224_nlvr2(pretrained=False, **kwargs):
295
+ args = _get_base_config(**kwargs)
296
+ model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs)
297
+ return model
298
+
299
+
300
+ @register_model
301
+ def beit3_large_patch16_224_nlvr2(pretrained=False, **kwargs):
302
+ args = _get_large_config(**kwargs)
303
+ model = BEiT3ForVisualReasoning(args, num_classes=2, **kwargs)
304
+ return model
305
+
306
+
307
+ @register_model
308
+ def beit3_base_patch16_384_vqav2(pretrained=False, **kwargs):
309
+ args = _get_base_config(img_size=384, **kwargs)
310
+ args.normalize_output = False
311
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
312
+ return model
313
+
314
+
315
+ @register_model
316
+ def beit3_base_patch16_480_vqav2(pretrained=False, **kwargs):
317
+ args = _get_base_config(img_size=480, **kwargs)
318
+ args.normalize_output = False
319
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
320
+ return model
321
+
322
+
323
+ @register_model
324
+ def beit3_large_patch16_384_vqav2(pretrained=False, **kwargs):
325
+ args = _get_large_config(img_size=384, **kwargs)
326
+ args.normalize_output = False
327
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
328
+ return model
329
+
330
+
331
+ @register_model
332
+ def beit3_large_patch16_480_vqav2(pretrained=False, **kwargs):
333
+ args = _get_large_config(img_size=480, **kwargs)
334
+ args.normalize_output = False
335
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
336
+ return model
337
+
338
+
339
+ @register_model
340
+ def beit3_large_patch16_768_vqav2(pretrained=False, **kwargs):
341
+ args = _get_large_config(img_size=768, **kwargs)
342
+ args.normalize_output = False
343
+ model = BEiT3ForVisualQuestionAnswering(args, num_classes=3129, **kwargs)
344
+ return model
345
+
346
+
347
+ @register_model
348
+ def beit3_base_patch16_224_captioning(pretrained=False, **kwargs):
349
+ args = _get_base_config(**kwargs)
350
+ model = BEiT3ForCaptioning(args, **kwargs)
351
+ return model
352
+
353
+
354
+ @register_model
355
+ def beit3_base_patch16_480_captioning(pretrained=False, **kwargs):
356
+ args = _get_base_config(img_size=480, **kwargs)
357
+ model = BEiT3ForCaptioning(args, **kwargs)
358
+ return model
359
+
360
+
361
+ @register_model
362
+ def beit3_large_patch16_480_captioning(pretrained=False, **kwargs):
363
+ args = _get_large_config(img_size=480, **kwargs)
364
+ model = BEiT3ForCaptioning(args, **kwargs)
365
+ return model
366
+
367
+
368
+ @register_model
369
+ def beit3_base_patch16_224_retrieval(pretrained=False, **kwargs):
370
+ args = _get_base_config(**kwargs)
371
+ model = BEiT3ForRetrieval(args, **kwargs)
372
+ return model
373
+
374
+
375
+ @register_model
376
+ def beit3_base_patch16_384_retrieval(pretrained=False, **kwargs):
377
+ args = _get_base_config(img_size=384, **kwargs)
378
+ model = BEiT3ForRetrieval(args, **kwargs)
379
+ return model
380
+
381
+
382
+ @register_model
383
+ def beit3_large_patch16_384_retrieval(pretrained=False, **kwargs):
384
+ args = _get_large_config(img_size=384, **kwargs)
385
+ model = BEiT3ForRetrieval(args, **kwargs)
386
+ return model
model/unilm/beit3/modeling_utils.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.models.layers import trunc_normal_ as __call_trunc_normal_
12
+
13
+ from torchscale.model.BEiT3 import BEiT3
14
+ from torchscale.architecture.config import EncoderConfig
15
+
16
+
17
+ def trunc_normal_(tensor, mean=0., std=1.):
18
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
19
+
20
+
21
+ def _get_base_config(
22
+ img_size=224, patch_size=16, drop_path_rate=0,
23
+ checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
24
+ ):
25
+ return EncoderConfig(
26
+ img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
27
+ layernorm_embedding=False, normalize_output=True, no_output_layer=True,
28
+ drop_path_rate=drop_path_rate, encoder_embed_dim=768, encoder_attention_heads=12,
29
+ encoder_ffn_embed_dim=int(768 * mlp_ratio), encoder_layers=12,
30
+ checkpoint_activations=checkpoint_activations,
31
+ )
32
+
33
+
34
+ def _get_large_config(
35
+ img_size=224, patch_size=16, drop_path_rate=0,
36
+ checkpoint_activations=None, mlp_ratio=4, vocab_size=64010, **kwargs
37
+ ):
38
+ return EncoderConfig(
39
+ img_size=img_size, patch_size=patch_size, vocab_size=vocab_size, multiway=True,
40
+ layernorm_embedding=False, normalize_output=True, no_output_layer=True,
41
+ drop_path_rate=drop_path_rate, encoder_embed_dim=1024, encoder_attention_heads=16,
42
+ encoder_ffn_embed_dim=int(1024 * mlp_ratio), encoder_layers=24,
43
+ checkpoint_activations=checkpoint_activations,
44
+ )
45
+
46
+
47
+ class BEiT3Wrapper(nn.Module):
48
+ def __init__(self, args, **kwargs):
49
+ super().__init__()
50
+ self.args = args
51
+ self.beit3 = BEiT3(args)
52
+ self.apply(self._init_weights)
53
+
54
+ def fix_init_weight(self):
55
+ def rescale(param, layer_id):
56
+ param.div_(math.sqrt(2.0 * layer_id))
57
+
58
+ for layer_id, layer in enumerate(self.blocks):
59
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
60
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
61
+
62
+ def get_num_layers(self):
63
+ return self.beit3.encoder.num_layers
64
+
65
+ @torch.jit.ignore
66
+ def no_weight_decay(self):
67
+ return {'pos_embed', 'cls_token', 'beit3.encoder.embed_positions.A.weight', 'beit3.vision_embed.cls_token', 'logit_scale'}
68
+
69
+ def _init_weights(self, m):
70
+ if isinstance(m, nn.Linear):
71
+ trunc_normal_(m.weight, std=.02)
72
+ if isinstance(m, nn.Linear) and m.bias is not None:
73
+ nn.init.constant_(m.bias, 0)
74
+ elif isinstance(m, nn.LayerNorm):
75
+ nn.init.constant_(m.bias, 0)
76
+ nn.init.constant_(m.weight, 1.0)
model/unilm/beit3/optim_factory.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ from torch import optim as optim
9
+ from timm.optim.lookahead import Lookahead
10
+
11
+ import json
12
+
13
+
14
+ def get_num_layer_for_vit(var_name, num_max_layer):
15
+ if "embed" in var_name:
16
+ return 0
17
+ elif var_name in (
18
+ "cls_token", "mask_token", "pos_embed", "language_pos_embed",
19
+ "word_embeddings.weight", "vision_cls_token", "vision_pos_embed"
20
+ ):
21
+ return 0
22
+ elif var_name.startswith("patch_embed"):
23
+ return 0
24
+ elif var_name.startswith("rel_pos_bias"):
25
+ return num_max_layer - 1
26
+ elif "layers." in var_name:
27
+ layer_id = int(var_name.split('layers.')[1].split('.')[0])
28
+ return layer_id + 1
29
+ else:
30
+ return num_max_layer - 1
31
+
32
+
33
+ def get_is_head_flag_for_vit(var_name, num_max_layer):
34
+ if var_name.startswith("head"):
35
+ return 1
36
+ # elif var_name.startswith("pooler"):
37
+ # return 1
38
+ else:
39
+ return 0
40
+
41
+
42
+ class LayerDecayValueAssigner(object):
43
+ def __init__(self, values, scale_handler=None):
44
+ self.scale_handler = scale_handler or get_num_layer_for_vit
45
+ self.values = values
46
+
47
+ def get_scale(self, layer_id):
48
+ return self.values[layer_id]
49
+
50
+ def get_layer_id(self, var_name):
51
+ return self.scale_handler(var_name, len(self.values))
52
+
53
+
54
+ # The implementation code is modified from Timm (https://github.com/huggingface/pytorch-image-models/tree/main/timm
55
+ def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
56
+ parameter_group_names = {}
57
+ parameter_group_vars = {}
58
+
59
+ for name, param in model.named_parameters():
60
+ if not param.requires_grad:
61
+ continue # frozen weights
62
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
63
+ group_name = "no_decay"
64
+ this_weight_decay = 0.
65
+ else:
66
+ group_name = "decay"
67
+ this_weight_decay = weight_decay
68
+ if get_num_layer is not None:
69
+ layer_id = get_num_layer(name)
70
+ group_name = "layer_%d_%s" % (layer_id, group_name)
71
+ else:
72
+ layer_id = None
73
+
74
+ if group_name not in parameter_group_names:
75
+ if get_layer_scale is not None:
76
+ scale = get_layer_scale(layer_id)
77
+ else:
78
+ scale = 1.
79
+
80
+ parameter_group_names[group_name] = {
81
+ "weight_decay": this_weight_decay,
82
+ "params": [],
83
+ "lr_scale": scale
84
+ }
85
+ parameter_group_vars[group_name] = {
86
+ "weight_decay": this_weight_decay,
87
+ "params": [],
88
+ "lr_scale": scale
89
+ }
90
+
91
+ parameter_group_vars[group_name]["params"].append(param)
92
+ parameter_group_names[group_name]["params"].append(name)
93
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
94
+ return list(parameter_group_vars.values())
95
+
96
+
97
+ def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
98
+ opt_lower = args.opt.lower()
99
+ weight_decay = args.weight_decay
100
+ if weight_decay and filter_bias_and_bn:
101
+ skip = {}
102
+ if skip_list is not None:
103
+ skip = skip_list
104
+ elif hasattr(model, 'no_weight_decay'):
105
+ skip = model.no_weight_decay()
106
+ parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
107
+ weight_decay = 0.
108
+ else:
109
+ parameters = model.parameters()
110
+
111
+ opt_args = dict(lr=args.lr, weight_decay=weight_decay)
112
+ if hasattr(args, 'opt_eps') and args.opt_eps is not None:
113
+ opt_args['eps'] = args.opt_eps
114
+ if hasattr(args, 'opt_betas') and args.opt_betas is not None:
115
+ opt_args['betas'] = args.opt_betas
116
+
117
+ opt_split = opt_lower.split('_')
118
+ opt_lower = opt_split[-1]
119
+ if opt_lower == 'adamw':
120
+ optimizer = optim.AdamW(parameters, **opt_args)
121
+ else:
122
+ raise ValueError("Invalid optimizer")
123
+
124
+ if len(opt_split) > 1:
125
+ if opt_split[0] == 'lookahead':
126
+ optimizer = Lookahead(optimizer)
127
+
128
+ return optimizer
model/unilm/beit3/randaug.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+
5
+ ## aug functions
6
+ def identity_func(img):
7
+ return img
8
+
9
+
10
+ def autocontrast_func(img, cutoff=0):
11
+ '''
12
+ same output as PIL.ImageOps.autocontrast
13
+ '''
14
+ n_bins = 256
15
+
16
+ def tune_channel(ch):
17
+ n = ch.size
18
+ cut = cutoff * n // 100
19
+ if cut == 0:
20
+ high, low = ch.max(), ch.min()
21
+ else:
22
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
23
+ low = np.argwhere(np.cumsum(hist) > cut)
24
+ low = 0 if low.shape[0] == 0 else low[0]
25
+ high = np.argwhere(np.cumsum(hist[::-1]) > cut)
26
+ high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
27
+ if high <= low:
28
+ table = np.arange(n_bins)
29
+ else:
30
+ scale = (n_bins - 1) / (high - low)
31
+ offset = -low * scale
32
+ table = np.arange(n_bins) * scale + offset
33
+ table[table < 0] = 0
34
+ table[table > n_bins - 1] = n_bins - 1
35
+ table = table.clip(0, 255).astype(np.uint8)
36
+ return table[ch]
37
+
38
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
39
+ out = cv2.merge(channels)
40
+ return out
41
+
42
+
43
+ def equalize_func(img):
44
+ '''
45
+ same output as PIL.ImageOps.equalize
46
+ PIL's implementation is different from cv2.equalize
47
+ '''
48
+ n_bins = 256
49
+
50
+ def tune_channel(ch):
51
+ hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
52
+ non_zero_hist = hist[hist != 0].reshape(-1)
53
+ step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
54
+ if step == 0: return ch
55
+ n = np.empty_like(hist)
56
+ n[0] = step // 2
57
+ n[1:] = hist[:-1]
58
+ table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
59
+ return table[ch]
60
+
61
+ channels = [tune_channel(ch) for ch in cv2.split(img)]
62
+ out = cv2.merge(channels)
63
+ return out
64
+
65
+
66
+ def rotate_func(img, degree, fill=(0, 0, 0)):
67
+ '''
68
+ like PIL, rotate by degree, not radians
69
+ '''
70
+ H, W = img.shape[0], img.shape[1]
71
+ center = W / 2, H / 2
72
+ M = cv2.getRotationMatrix2D(center, degree, 1)
73
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
74
+ return out
75
+
76
+
77
+ def solarize_func(img, thresh=128):
78
+ '''
79
+ same output as PIL.ImageOps.posterize
80
+ '''
81
+ table = np.array([el if el < thresh else 255 - el for el in range(256)])
82
+ table = table.clip(0, 255).astype(np.uint8)
83
+ out = table[img]
84
+ return out
85
+
86
+
87
+ def color_func(img, factor):
88
+ '''
89
+ same output as PIL.ImageEnhance.Color
90
+ '''
91
+ ## implementation according to PIL definition, quite slow
92
+ # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
93
+ # out = blend(degenerate, img, factor)
94
+ # M = (
95
+ # np.eye(3) * factor
96
+ # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
97
+ # )[np.newaxis, np.newaxis, :]
98
+ M = (
99
+ np.float32([
100
+ [0.886, -0.114, -0.114],
101
+ [-0.587, 0.413, -0.587],
102
+ [-0.299, -0.299, 0.701]]) * factor
103
+ + np.float32([[0.114], [0.587], [0.299]])
104
+ )
105
+ out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
106
+ return out
107
+
108
+
109
+ def contrast_func(img, factor):
110
+ """
111
+ same output as PIL.ImageEnhance.Contrast
112
+ """
113
+ mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
114
+ table = np.array([(
115
+ el - mean) * factor + mean
116
+ for el in range(256)
117
+ ]).clip(0, 255).astype(np.uint8)
118
+ out = table[img]
119
+ return out
120
+
121
+
122
+ def brightness_func(img, factor):
123
+ '''
124
+ same output as PIL.ImageEnhance.Contrast
125
+ '''
126
+ table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
127
+ out = table[img]
128
+ return out
129
+
130
+
131
+ def sharpness_func(img, factor):
132
+ '''
133
+ The differences the this result and PIL are all on the 4 boundaries, the center
134
+ areas are same
135
+ '''
136
+ kernel = np.ones((3, 3), dtype=np.float32)
137
+ kernel[1][1] = 5
138
+ kernel /= 13
139
+ degenerate = cv2.filter2D(img, -1, kernel)
140
+ if factor == 0.0:
141
+ out = degenerate
142
+ elif factor == 1.0:
143
+ out = img
144
+ else:
145
+ out = img.astype(np.float32)
146
+ degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
147
+ out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
148
+ out = out.astype(np.uint8)
149
+ return out
150
+
151
+
152
+ def shear_x_func(img, factor, fill=(0, 0, 0)):
153
+ H, W = img.shape[0], img.shape[1]
154
+ M = np.float32([[1, factor, 0], [0, 1, 0]])
155
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
156
+ return out
157
+
158
+
159
+ def translate_x_func(img, offset, fill=(0, 0, 0)):
160
+ '''
161
+ same output as PIL.Image.transform
162
+ '''
163
+ H, W = img.shape[0], img.shape[1]
164
+ M = np.float32([[1, 0, -offset], [0, 1, 0]])
165
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
166
+ return out
167
+
168
+
169
+ def translate_y_func(img, offset, fill=(0, 0, 0)):
170
+ '''
171
+ same output as PIL.Image.transform
172
+ '''
173
+ H, W = img.shape[0], img.shape[1]
174
+ M = np.float32([[1, 0, 0], [0, 1, -offset]])
175
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
176
+ return out
177
+
178
+
179
+ def posterize_func(img, bits):
180
+ '''
181
+ same output as PIL.ImageOps.posterize
182
+ '''
183
+ out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
184
+ return out
185
+
186
+
187
+ def shear_y_func(img, factor, fill=(0, 0, 0)):
188
+ H, W = img.shape[0], img.shape[1]
189
+ M = np.float32([[1, 0, 0], [factor, 1, 0]])
190
+ out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8)
191
+ return out
192
+
193
+
194
+ def cutout_func(img, pad_size, replace=(0, 0, 0)):
195
+ replace = np.array(replace, dtype=np.uint8)
196
+ H, W = img.shape[0], img.shape[1]
197
+ rh, rw = np.random.random(2)
198
+ pad_size = pad_size // 2
199
+ ch, cw = int(rh * H), int(rw * W)
200
+ x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
201
+ y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
202
+ out = img.copy()
203
+ out[x1:x2, y1:y2, :] = replace
204
+ return out
205
+
206
+
207
+ ### level to args
208
+ def enhance_level_to_args(MAX_LEVEL):
209
+ def level_to_args(level):
210
+ return ((level / MAX_LEVEL) * 1.8 + 0.1,)
211
+ return level_to_args
212
+
213
+
214
+ def shear_level_to_args(MAX_LEVEL, replace_value):
215
+ def level_to_args(level):
216
+ level = (level / MAX_LEVEL) * 0.3
217
+ if np.random.random() > 0.5: level = -level
218
+ return (level, replace_value)
219
+
220
+ return level_to_args
221
+
222
+
223
+ def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
224
+ def level_to_args(level):
225
+ level = (level / MAX_LEVEL) * float(translate_const)
226
+ if np.random.random() > 0.5: level = -level
227
+ return (level, replace_value)
228
+
229
+ return level_to_args
230
+
231
+
232
+ def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
233
+ def level_to_args(level):
234
+ level = int((level / MAX_LEVEL) * cutout_const)
235
+ return (level, replace_value)
236
+
237
+ return level_to_args
238
+
239
+
240
+ def solarize_level_to_args(MAX_LEVEL):
241
+ def level_to_args(level):
242
+ level = int((level / MAX_LEVEL) * 256)
243
+ return (level, )
244
+ return level_to_args
245
+
246
+
247
+ def none_level_to_args(level):
248
+ return ()
249
+
250
+
251
+ def posterize_level_to_args(MAX_LEVEL):
252
+ def level_to_args(level):
253
+ level = int((level / MAX_LEVEL) * 4)
254
+ return (level, )
255
+ return level_to_args
256
+
257
+
258
+ def rotate_level_to_args(MAX_LEVEL, replace_value):
259
+ def level_to_args(level):
260
+ level = (level / MAX_LEVEL) * 30
261
+ if np.random.random() < 0.5:
262
+ level = -level
263
+ return (level, replace_value)
264
+
265
+ return level_to_args
266
+
267
+
268
+ func_dict = {
269
+ 'Identity': identity_func,
270
+ 'AutoContrast': autocontrast_func,
271
+ 'Equalize': equalize_func,
272
+ 'Rotate': rotate_func,
273
+ 'Solarize': solarize_func,
274
+ 'Color': color_func,
275
+ 'Contrast': contrast_func,
276
+ 'Brightness': brightness_func,
277
+ 'Sharpness': sharpness_func,
278
+ 'ShearX': shear_x_func,
279
+ 'TranslateX': translate_x_func,
280
+ 'TranslateY': translate_y_func,
281
+ 'Posterize': posterize_func,
282
+ 'ShearY': shear_y_func,
283
+ }
284
+
285
+ translate_const = 10
286
+ MAX_LEVEL = 10
287
+ replace_value = (128, 128, 128)
288
+ arg_dict = {
289
+ 'Identity': none_level_to_args,
290
+ 'AutoContrast': none_level_to_args,
291
+ 'Equalize': none_level_to_args,
292
+ 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
293
+ 'Solarize': solarize_level_to_args(MAX_LEVEL),
294
+ 'Color': enhance_level_to_args(MAX_LEVEL),
295
+ 'Contrast': enhance_level_to_args(MAX_LEVEL),
296
+ 'Brightness': enhance_level_to_args(MAX_LEVEL),
297
+ 'Sharpness': enhance_level_to_args(MAX_LEVEL),
298
+ 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
299
+ 'TranslateX': translate_level_to_args(
300
+ translate_const, MAX_LEVEL, replace_value
301
+ ),
302
+ 'TranslateY': translate_level_to_args(
303
+ translate_const, MAX_LEVEL, replace_value
304
+ ),
305
+ 'Posterize': posterize_level_to_args(MAX_LEVEL),
306
+ 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
307
+ }
308
+
309
+
310
+ class RandomAugment(object):
311
+
312
+ def __init__(self, N=2, M=10, isPIL=False, augs=[]):
313
+ self.N = N
314
+ self.M = M
315
+ self.isPIL = isPIL
316
+ if augs:
317
+ self.augs = augs
318
+ else:
319
+ self.augs = list(arg_dict.keys())
320
+
321
+ def get_random_ops(self):
322
+ sampled_ops = np.random.choice(self.augs, self.N)
323
+ return [(op, 0.5, self.M) for op in sampled_ops]
324
+
325
+ def __call__(self, img):
326
+ if self.isPIL:
327
+ img = np.array(img)
328
+ ops = self.get_random_ops()
329
+ for name, prob, level in ops:
330
+ if np.random.random() > prob:
331
+ continue
332
+ args = arg_dict[name](level)
333
+ img = func_dict[name](img, *args)
334
+ return img
335
+
336
+
337
+ if __name__ == '__main__':
338
+ a = RandomAugment()
339
+ img = np.random.randn(32, 32, 3)
340
+ a(img)
model/unilm/beit3/requirements.txt ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm==0.4.12
4
+ Pillow
5
+ blobfile
6
+ mypy
7
+ numpy
8
+ pytest
9
+ requests
10
+ einops
11
+ tensorboardX
12
+ scipy
13
+ ftfy
14
+ opencv-python
15
+ sentencepiece
16
+ pyarrow
17
+ torchmetrics==0.7.3
18
+ transformers
19
+ deepspeed==0.4.0
20
+ pycocotools
21
+ pycocoevalcap
22
+ torchscale==0.2.0
model/unilm/beit3/run_beit3_finetuning.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import argparse
9
+ import datetime
10
+ import numpy as np
11
+ import time
12
+ import torch
13
+ import torch.backends.cudnn as cudnn
14
+ import json
15
+ import os
16
+
17
+ from pathlib import Path
18
+
19
+ from timm.data.mixup import Mixup
20
+ from timm.models import create_model
21
+ from timm.utils import ModelEma
22
+ from optim_factory import create_optimizer, get_parameter_groups, \
23
+ LayerDecayValueAssigner, get_is_head_flag_for_vit
24
+
25
+ from engine_for_finetuning import train_one_epoch, get_handler, evaluate
26
+ from datasets import create_downstream_dataset
27
+ from utils import NativeScalerWithGradNormCount as NativeScaler
28
+ import utils
29
+ import modeling_finetune
30
+
31
+
32
+ def get_args():
33
+ parser = argparse.ArgumentParser('BEiT fine-tuning and evaluation script for image classification', add_help=False)
34
+
35
+ # Model parameters
36
+ parser.add_argument('--model', default='beit_base_patch16_224', type=str, metavar='MODEL',
37
+ help='Name of model to train')
38
+ parser.add_argument('--task', type=str, required=True,
39
+ choices=['nlvr2', 'vqav2', 'flickr30k', 'coco_retrieval', 'coco_captioning', 'nocaps', 'imagenet'],
40
+ help='Name of task to fine-tuning')
41
+
42
+ parser.add_argument('--input_size', default=224, type=int,
43
+ help='images input size')
44
+ parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
45
+ help='Drop path rate (default: 0.1)')
46
+
47
+ parser.add_argument('--checkpoint_activations', action='store_true', default=None,
48
+ help='Enable checkpointing to save your memory.')
49
+ parser.add_argument('--sentencepiece_model', type=str, required=True,
50
+ help='Sentencepiece model path for the pretrained model.')
51
+ parser.add_argument('--vocab_size', type=int, default=64010)
52
+ parser.add_argument('--num_max_bpe_tokens', type=int, default=64)
53
+
54
+ parser.add_argument('--model_ema', action='store_true', default=False)
55
+ parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
56
+ parser.add_argument('--model_ema_force_cpu', action='store_true', default=False, help='')
57
+
58
+ # Optimizer parameters
59
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
60
+ help='Optimizer (default: "adamw"')
61
+ parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
62
+ help='Optimizer Epsilon (default: 1e-8)')
63
+ parser.add_argument('--opt_betas', default=[0.9, 0.999], type=float, nargs='+', metavar='BETA',
64
+ help='Optimizer Betas (default: 0.9, 0.999, use opt default)')
65
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
66
+ help='Clip gradient norm (default: None, no clipping)')
67
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
68
+ help='SGD momentum (default: 0.9)')
69
+ parser.add_argument('--weight_decay', type=float, default=0.05,
70
+ help='weight decay (default: 0.05)')
71
+
72
+ parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
73
+ help='learning rate (default: 5e-4)')
74
+ parser.add_argument('--layer_decay', type=float, default=0.9)
75
+ parser.add_argument('--task_head_lr_weight', type=float, default=0)
76
+
77
+ parser.add_argument('--warmup_lr', type=float, default=1e-6, metavar='LR',
78
+ help='warmup learning rate (default: 1e-6)')
79
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
80
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
81
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
82
+ help='epochs to warmup LR, if scheduler supports')
83
+ parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
84
+ help='num of steps to warmup LR, will overload warmup_epochs if set > 0')
85
+
86
+ parser.add_argument('--batch_size', default=64, type=int)
87
+ parser.add_argument('--eval_batch_size', default=None, type=int)
88
+ parser.add_argument('--epochs', default=20, type=int)
89
+ parser.add_argument('--update_freq', default=1, type=int)
90
+ parser.add_argument('--save_ckpt_freq', default=5, type=int)
91
+
92
+ # Augmentation parameters
93
+ parser.add_argument('--randaug', action='store_true', default=False)
94
+ parser.add_argument('--train_interpolation', type=str, default='bicubic',
95
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
96
+
97
+ # Finetuning params
98
+ parser.add_argument('--finetune', default='',
99
+ help='finetune from checkpoint')
100
+ parser.add_argument('--model_key', default='model|module', type=str)
101
+ parser.add_argument('--model_prefix', default='', type=str)
102
+
103
+ # Dataset parameters
104
+ parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
105
+ help='dataset path')
106
+
107
+ parser.add_argument('--output_dir', default='',
108
+ help='path where to save, empty for no saving')
109
+ parser.add_argument('--log_dir', default=None,
110
+ help='path where to tensorboard log')
111
+ parser.add_argument('--device', default='cuda',
112
+ help='device to use for training / testing')
113
+ parser.add_argument('--seed', default=0, type=int)
114
+ parser.add_argument('--resume', default='',
115
+ help='resume from checkpoint')
116
+ parser.add_argument('--auto_resume', action='store_true')
117
+ parser.add_argument('--no_auto_resume', action='store_false', dest='auto_resume')
118
+ parser.set_defaults(auto_resume=True)
119
+
120
+ parser.add_argument('--save_ckpt', action='store_true')
121
+ parser.add_argument('--no_save_ckpt', action='store_false', dest='save_ckpt')
122
+ parser.set_defaults(save_ckpt=True)
123
+
124
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
125
+ help='start epoch')
126
+ parser.add_argument('--eval', action='store_true',
127
+ help='Perform evaluation only')
128
+ parser.add_argument('--dist_eval', action='store_true', default=False,
129
+ help='Enabling distributed evaluation')
130
+ parser.add_argument('--num_workers', default=10, type=int)
131
+ parser.add_argument('--pin_mem', action='store_true',
132
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
133
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
134
+ parser.set_defaults(pin_mem=True)
135
+
136
+ # distributed training parameters
137
+ parser.add_argument('--world_size', default=1, type=int,
138
+ help='number of distributed processes')
139
+ parser.add_argument('--local_rank', default=-1, type=int)
140
+ parser.add_argument('--dist_on_itp', action='store_true')
141
+ parser.add_argument('--dist_url', default='env://',
142
+ help='url used to set up distributed training')
143
+
144
+ # parameter for dump predictions (VQA, COCO captioning, NoCaps)
145
+ parser.add_argument('--task_cache_path', default=None, type=str)
146
+
147
+ # parameter for imagenet finetuning
148
+ parser.add_argument('--nb_classes', default=1000, type=int,
149
+ help='number of the classification types')
150
+ parser.add_argument('--mixup', type=float, default=0,
151
+ help='mixup alpha, mixup enabled if > 0.')
152
+ parser.add_argument('--cutmix', type=float, default=0,
153
+ help='cutmix alpha, cutmix enabled if > 0.')
154
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
155
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
156
+ parser.add_argument('--mixup_prob', type=float, default=1.0,
157
+ help='Probability of performing mixup or cutmix when either/both is enabled')
158
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
159
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
160
+ parser.add_argument('--mixup_mode', type=str, default='batch',
161
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
162
+
163
+ # augmentation parameters for imagenet finetuning
164
+ parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
165
+ help='Color jitter factor (default: 0.4)')
166
+ parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
167
+ help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)')
168
+ parser.add_argument('--smoothing', type=float, default=0.1,
169
+ help='Label smoothing (default: 0.1)')
170
+
171
+ # evaluation parameters for imagenet
172
+ parser.add_argument('--crop_pct', type=float, default=None)
173
+
174
+ # random Erase params for imagenet finetuning
175
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
176
+ help='Random erase prob (default: 0.25)')
177
+ parser.add_argument('--remode', type=str, default='pixel',
178
+ help='Random erase mode (default: "pixel")')
179
+ parser.add_argument('--recount', type=int, default=1,
180
+ help='Random erase count (default: 1)')
181
+ parser.add_argument('--resplit', action='store_true', default=False,
182
+ help='Do not random erase first (clean) augmentation split')
183
+
184
+ # parameter for captioning finetuning
185
+ parser.add_argument('--captioning_mask_prob', type=float, default=0.6)
186
+ parser.add_argument('--drop_worst_ratio', type=float, default=0.2)
187
+ parser.add_argument('--drop_worst_after', type=int, default=12000)
188
+ parser.add_argument('--num_beams', type=int, default=3)
189
+ parser.add_argument('--length_penalty', type=float, default=0.6)
190
+
191
+ # label smoothing for imagenet and captioning
192
+ parser.add_argument('--label_smoothing', type=float, default=0.1)
193
+
194
+ # deepspeed parameters
195
+ parser.add_argument('--enable_deepspeed', action='store_true', default=False)
196
+ parser.add_argument('--initial_scale_power', type=int, default=16)
197
+ parser.add_argument('--zero_stage', default=0, type=int,
198
+ help='ZeRO optimizer stage (default: 0)')
199
+
200
+ known_args, _ = parser.parse_known_args()
201
+
202
+ if known_args.enable_deepspeed:
203
+ try:
204
+ import deepspeed
205
+ from deepspeed import DeepSpeedConfig
206
+ parser = deepspeed.add_config_arguments(parser)
207
+ ds_init = deepspeed.initialize
208
+ except:
209
+ print("Please 'pip install deepspeed==0.4.0'")
210
+ exit(0)
211
+ else:
212
+ ds_init = None
213
+
214
+ return parser.parse_args(), ds_init
215
+
216
+
217
+ def main(args, ds_init):
218
+ utils.init_distributed_mode(args)
219
+
220
+ if ds_init is not None:
221
+ utils.create_ds_config(args)
222
+
223
+ if args.task_cache_path is None:
224
+ args.task_cache_path = args.output_dir
225
+
226
+ print(args)
227
+
228
+ device = torch.device(args.device)
229
+
230
+ # fix the seed for reproducibility
231
+ seed = args.seed + utils.get_rank()
232
+ torch.manual_seed(seed)
233
+ np.random.seed(seed)
234
+ # random.seed(seed)
235
+
236
+ cudnn.benchmark = True
237
+
238
+ if utils.get_rank() == 0 and args.log_dir is not None:
239
+ os.makedirs(args.log_dir, exist_ok=True)
240
+ log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
241
+ else:
242
+ log_writer = None
243
+
244
+ data_loader_train, data_loader_val = create_downstream_dataset(args)
245
+
246
+ if not args.model.endswith(args.task):
247
+ if args.task in ("flickr30k", "coco_retrieval"):
248
+ model_config = "%s_retrieval" % args.model
249
+ elif args.task in ("coco_captioning", "nocaps"):
250
+ model_config = "%s_captioning" % args.model
251
+ elif args.task in ("imagenet"):
252
+ model_config = "%s_imageclassification" % args.model
253
+ else:
254
+ model_config = "%s_%s" % (args.model, args.task)
255
+ else:
256
+ model_config = args.model
257
+ print("model_config = %s" % model_config)
258
+ model = create_model(
259
+ model_config,
260
+ pretrained=False,
261
+ drop_path_rate=args.drop_path,
262
+ vocab_size=args.vocab_size,
263
+ checkpoint_activations=args.checkpoint_activations,
264
+ )
265
+
266
+ if args.finetune:
267
+ utils.load_model_and_may_interpolate(args.finetune, model, args.model_key, args.model_prefix)
268
+
269
+ model.to(device)
270
+
271
+ model_ema = None
272
+ if args.model_ema:
273
+ # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
274
+ model_ema = ModelEma(
275
+ model,
276
+ decay=args.model_ema_decay,
277
+ device='cpu' if args.model_ema_force_cpu else '',
278
+ resume='')
279
+ print("Using EMA with decay = %.8f" % args.model_ema_decay)
280
+
281
+ model_without_ddp = model
282
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
283
+
284
+ print("Model = %s" % str(model_without_ddp))
285
+ print('number of params:', n_parameters)
286
+
287
+ total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
288
+ num_training_steps_per_epoch = len(data_loader_train.dataset) // total_batch_size
289
+ print("LR = %.8f" % args.lr)
290
+ print("Batch size = %d" % total_batch_size)
291
+ print("Update frequent = %d" % args.update_freq)
292
+ print("Number of training examples = %d" % len(data_loader_train.dataset))
293
+ print("Number of training training per epoch = %d" % num_training_steps_per_epoch)
294
+
295
+ num_layers = model_without_ddp.get_num_layers()
296
+ if args.layer_decay < 1.0:
297
+ lrs = list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
298
+ assigner = LayerDecayValueAssigner(lrs)
299
+ elif args.task_head_lr_weight > 1:
300
+ assigner = LayerDecayValueAssigner([1.0, args.task_head_lr_weight], scale_handler=get_is_head_flag_for_vit)
301
+ else:
302
+ assigner = None
303
+
304
+ if assigner is not None:
305
+ print("Assigned values = %s" % str(assigner.values))
306
+
307
+ skip_weight_decay_list = model.no_weight_decay()
308
+
309
+ if args.distributed:
310
+ torch.distributed.barrier()
311
+ if args.enable_deepspeed:
312
+ loss_scaler = None
313
+ optimizer_params = get_parameter_groups(
314
+ model, args.weight_decay, skip_weight_decay_list,
315
+ assigner.get_layer_id if assigner is not None else None,
316
+ assigner.get_scale if assigner is not None else None)
317
+ model, optimizer, _, _ = ds_init(
318
+ args=args, model=model, model_parameters=optimizer_params,
319
+ dist_init_required=not args.distributed,
320
+ )
321
+
322
+ print("model.gradient_accumulation_steps() = %d" % model.gradient_accumulation_steps())
323
+ assert model.gradient_accumulation_steps() == args.update_freq
324
+ else:
325
+ if args.distributed:
326
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
327
+ model_without_ddp = model.module
328
+
329
+ optimizer = create_optimizer(
330
+ args, model_without_ddp, skip_list=skip_weight_decay_list,
331
+ get_num_layer=assigner.get_layer_id if assigner is not None else None,
332
+ get_layer_scale=assigner.get_scale if assigner is not None else None)
333
+ loss_scaler = NativeScaler()
334
+
335
+ lr_schedule_values = utils.cosine_scheduler(
336
+ args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
337
+ warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
338
+ )
339
+
340
+ utils.auto_load_model(
341
+ args=args, model=model, model_without_ddp=model_without_ddp,
342
+ optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
343
+
344
+ task_handler = get_handler(args)
345
+
346
+ # mixup for imagenet
347
+ mixup_fn = None
348
+ if args.task in ["imagenet", "in1k"]:
349
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
350
+ if mixup_active:
351
+ print("Mixup is activated!")
352
+ mixup_fn = Mixup(
353
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
354
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
355
+ label_smoothing=args.label_smoothing, num_classes=args.nb_classes)
356
+
357
+ if args.eval:
358
+ data_loader_test = create_downstream_dataset(args, is_eval=True)
359
+ if args.task in ["nlvr2", "flickr30k", "coco_retrieval", "imagenet"]:
360
+ ext_test_stats, task_key = evaluate(data_loader_test, model, device, task_handler)
361
+ print(f"Accuracy of the network on the {len(data_loader_test.dataset)} test images: {ext_test_stats[task_key]:.3f}%")
362
+ exit(0)
363
+ elif args.task == "vqav2":
364
+ result, _ = evaluate(data_loader_test, model, device, task_handler)
365
+ utils.dump_predictions(args, result, "vqav2_test")
366
+ exit(0)
367
+ elif args.task in ["coco_captioning", "nocaps"]:
368
+ predictions, _ = evaluate(data_loader_test, model, device, task_handler)
369
+ prediction_file = utils.dump_predictions(args, predictions, "{}_test".format(args.task))
370
+ if utils.is_main_process() and args.task == "coco_captioning":
371
+ captioning_result = utils.coco_caption_eval(args.output_dir, prediction_file, "{}_test".format(args.task))
372
+ result_file = os.path.join(args.output_dir, f"{args.task}_result.json")
373
+ print(json.dumps(captioning_result))
374
+ utils.write_result_to_jsonl(captioning_result, result_file)
375
+ exit(0)
376
+
377
+ print(f"Start training for {args.epochs} epochs")
378
+ start_time = time.time()
379
+
380
+ max_accuracy = 0.0
381
+ for epoch in range(args.start_epoch, args.epochs):
382
+ if args.distributed:
383
+ data_loader_train.sampler.set_epoch(epoch)
384
+ if log_writer is not None:
385
+ log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
386
+ train_stats = train_one_epoch(
387
+ model, data_loader_train, optimizer, device, task_handler, epoch,
388
+ epoch * num_training_steps_per_epoch, lr_schedule_values, loss_scaler,
389
+ args.clip_grad, args.update_freq, model_ema, log_writer, args.task, mixup_fn,
390
+ )
391
+ if args.output_dir and args.save_ckpt:
392
+ if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
393
+ utils.save_model(
394
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
395
+ loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
396
+ if data_loader_val is not None:
397
+ if args.task not in ["coco_captioning", "nocaps"]:
398
+ test_stats, task_key = evaluate(data_loader_val, model, device, task_handler)
399
+ else:
400
+ predictions, _ = evaluate(data_loader_val, model, device, task_handler)
401
+ prediction_file = utils.dump_predictions(args, predictions, f"{args.task}_val_e{epoch}")
402
+ result_file = os.path.join(args.output_dir, f"{args.task}_result_val_e{epoch}.json")
403
+ task_key = "CIDEr"
404
+ if utils.is_main_process():
405
+ test_stats = utils.coco_caption_eval(args.output_dir, prediction_file, "{}_val".format(args.task))
406
+ utils.write_result_to_jsonl(test_stats, result_file)
407
+ torch.distributed.barrier()
408
+ if not utils.is_main_process():
409
+ test_stats = utils.read_result_from_jsonl(result_file)
410
+
411
+ print(f"Performance of the network on the {len(data_loader_val.dataset)} val images: {test_stats[task_key]:.1f}%")
412
+ if max_accuracy < test_stats[task_key]:
413
+ max_accuracy = test_stats[task_key]
414
+ if args.output_dir and args.save_ckpt:
415
+ utils.save_model(
416
+ args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
417
+ loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
418
+
419
+ print(f'Max performance: {max_accuracy:.2f}%')
420
+ if log_writer is not None:
421
+ log_writer.update(acc=test_stats[task_key], head="perf", step=epoch)
422
+
423
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
424
+ **{f'val_{k}': v for k, v in test_stats.items()},
425
+ 'epoch': epoch,
426
+ 'n_parameters': n_parameters}
427
+ else:
428
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
429
+ # **{f'test_{k}': v for k, v in test_stats.items()},
430
+ 'epoch': epoch,
431
+ 'n_parameters': n_parameters}
432
+
433
+ if args.output_dir and utils.is_main_process():
434
+ if log_writer is not None:
435
+ log_writer.flush()
436
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
437
+ f.write(json.dumps(log_stats) + "\n")
438
+
439
+ total_time = time.time() - start_time
440
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
441
+ print('Training time {}'.format(total_time_str))
442
+
443
+
444
+ if __name__ == '__main__':
445
+ opts, ds_init = get_args()
446
+ if opts.output_dir:
447
+ Path(opts.output_dir).mkdir(parents=True, exist_ok=True)
448
+ main(opts, ds_init)
model/unilm/beit3/utils.py ADDED
@@ -0,0 +1,913 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Image as a Foreign Language: BEiT Pretraining for Vision and Vision-Language Tasks (https://arxiv.org/abs/2208.10442)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit3
4
+ # Copyright (c) 2023 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # --------------------------------------------------------'
7
+
8
+ import datetime
9
+ import io
10
+ import os
11
+ import math
12
+ import time
13
+ import json
14
+ import argparse
15
+ import numpy as np
16
+ from pathlib import Path
17
+ from collections import defaultdict, deque
18
+ from timm.utils import get_state_dict
19
+
20
+ import torch
21
+ import torch.distributed as dist
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ from torch._six import inf
25
+ from torchmetrics import Metric
26
+ from tensorboardX import SummaryWriter
27
+
28
+
29
+ def bool_flag(s):
30
+ """
31
+ Parse boolean arguments from the command line.
32
+ """
33
+ FALSY_STRINGS = {"off", "false", "0"}
34
+ TRUTHY_STRINGS = {"on", "true", "1"}
35
+ if s.lower() in FALSY_STRINGS:
36
+ return False
37
+ elif s.lower() in TRUTHY_STRINGS:
38
+ return True
39
+ else:
40
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
41
+
42
+
43
+ class SmoothedValue(object):
44
+ """Track a series of values and provide access to smoothed values over a
45
+ window or the global series average.
46
+ """
47
+
48
+ def __init__(self, window_size=20, fmt=None):
49
+ if fmt is None:
50
+ fmt = "{median:.4f} ({global_avg:.4f})"
51
+ self.deque = deque(maxlen=window_size)
52
+ self.total = 0.0
53
+ self.count = 0
54
+ self.fmt = fmt
55
+
56
+ def update(self, value, n=1):
57
+ self.deque.append(value)
58
+ self.count += n
59
+ self.total += value * n
60
+
61
+ def synchronize_between_processes(self):
62
+ """
63
+ Warning: does not synchronize the deque!
64
+ """
65
+ if not is_dist_avail_and_initialized():
66
+ return
67
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
68
+ dist.barrier()
69
+ dist.all_reduce(t)
70
+ t = t.tolist()
71
+ self.count = int(t[0])
72
+ self.total = t[1]
73
+
74
+ @property
75
+ def median(self):
76
+ d = torch.tensor(list(self.deque))
77
+ return d.median().item()
78
+
79
+ @property
80
+ def avg(self):
81
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
82
+ return d.mean().item()
83
+
84
+ @property
85
+ def global_avg(self):
86
+ return self.total / self.count
87
+
88
+ @property
89
+ def max(self):
90
+ return max(self.deque)
91
+
92
+ @property
93
+ def value(self):
94
+ return self.deque[-1]
95
+
96
+ def __str__(self):
97
+ return self.fmt.format(
98
+ median=self.median,
99
+ avg=self.avg,
100
+ global_avg=self.global_avg,
101
+ max=self.max,
102
+ value=self.value)
103
+
104
+
105
+ class MetricLogger(object):
106
+ def __init__(self, delimiter="\t"):
107
+ self.meters = defaultdict(SmoothedValue)
108
+ self.delimiter = delimiter
109
+
110
+ def update(self, **kwargs):
111
+ for k, v in kwargs.items():
112
+ if v is None:
113
+ continue
114
+ if isinstance(v, torch.Tensor):
115
+ v = v.item()
116
+ assert isinstance(v, (float, int))
117
+ self.meters[k].update(v)
118
+
119
+ def __getattr__(self, attr):
120
+ if attr in self.meters:
121
+ return self.meters[attr]
122
+ if attr in self.__dict__:
123
+ return self.__dict__[attr]
124
+ raise AttributeError("'{}' object has no attribute '{}'".format(
125
+ type(self).__name__, attr))
126
+
127
+ def __str__(self):
128
+ loss_str = []
129
+ for name, meter in self.meters.items():
130
+ loss_str.append(
131
+ "{}: {}".format(name, str(meter))
132
+ )
133
+ return self.delimiter.join(loss_str)
134
+
135
+ def synchronize_between_processes(self):
136
+ for meter in self.meters.values():
137
+ meter.synchronize_between_processes()
138
+
139
+ def add_meter(self, name, meter):
140
+ self.meters[name] = meter
141
+
142
+ def log_every(self, iterable, print_freq, header=None):
143
+ i = 0
144
+ if not header:
145
+ header = ''
146
+ start_time = time.time()
147
+ end = time.time()
148
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
149
+ data_time = SmoothedValue(fmt='{avg:.4f}')
150
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
151
+ log_msg = [
152
+ header,
153
+ '[{0' + space_fmt + '}/{1}]',
154
+ 'eta: {eta}',
155
+ '{meters}',
156
+ 'time: {time}',
157
+ 'data: {data}'
158
+ ]
159
+ if torch.cuda.is_available():
160
+ log_msg.append('max mem: {memory:.0f}')
161
+ log_msg = self.delimiter.join(log_msg)
162
+ MB = 1024.0 * 1024.0
163
+ for obj in iterable:
164
+ data_time.update(time.time() - end)
165
+ yield obj
166
+ iter_time.update(time.time() - end)
167
+ if i % print_freq == 0 or i == len(iterable) - 1:
168
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
169
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
170
+ if torch.cuda.is_available():
171
+ print(log_msg.format(
172
+ i, len(iterable), eta=eta_string,
173
+ meters=str(self),
174
+ time=str(iter_time), data=str(data_time),
175
+ memory=torch.cuda.max_memory_allocated() / MB))
176
+ else:
177
+ print(log_msg.format(
178
+ i, len(iterable), eta=eta_string,
179
+ meters=str(self),
180
+ time=str(iter_time), data=str(data_time)))
181
+ i += 1
182
+ end = time.time()
183
+ total_time = time.time() - start_time
184
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
185
+ print('{} Total time: {} ({:.4f} s / it)'.format(
186
+ header, total_time_str, total_time / len(iterable)))
187
+
188
+
189
+ class TensorboardLogger(object):
190
+ def __init__(self, log_dir):
191
+ self.writer = SummaryWriter(logdir=log_dir)
192
+ self.step = 0
193
+
194
+ def set_step(self, step=None):
195
+ if step is not None:
196
+ self.step = step
197
+ else:
198
+ self.step += 1
199
+
200
+ def update(self, head='scalar', step=None, **kwargs):
201
+ for k, v in kwargs.items():
202
+ if v is None:
203
+ continue
204
+ if isinstance(v, torch.Tensor):
205
+ v = v.item()
206
+ assert isinstance(v, (float, int))
207
+ self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
208
+
209
+ def flush(self):
210
+ self.writer.flush()
211
+
212
+
213
+ def _load_checkpoint_for_ema(model_ema, checkpoint):
214
+ """
215
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
216
+ """
217
+ mem_file = io.BytesIO()
218
+ torch.save(checkpoint, mem_file)
219
+ mem_file.seek(0)
220
+ model_ema._load_checkpoint(mem_file)
221
+
222
+
223
+ def setup_for_distributed(is_master):
224
+ """
225
+ This function disables printing when not in master process
226
+ """
227
+ import builtins as __builtin__
228
+ builtin_print = __builtin__.print
229
+
230
+ def print(*args, **kwargs):
231
+ force = kwargs.pop('force', False)
232
+ if is_master or force:
233
+ builtin_print(*args, **kwargs)
234
+
235
+ __builtin__.print = print
236
+
237
+
238
+ def is_dist_avail_and_initialized():
239
+ if not dist.is_available():
240
+ return False
241
+ if not dist.is_initialized():
242
+ return False
243
+ return True
244
+
245
+
246
+ def get_world_size():
247
+ if not is_dist_avail_and_initialized():
248
+ return 1
249
+ return dist.get_world_size()
250
+
251
+
252
+ def get_rank():
253
+ if not is_dist_avail_and_initialized():
254
+ return 0
255
+ return dist.get_rank()
256
+
257
+
258
+ def is_main_process():
259
+ return get_rank() == 0
260
+
261
+
262
+ def save_on_master(*args, **kwargs):
263
+ if is_main_process():
264
+ torch.save(*args, **kwargs)
265
+
266
+
267
+ def _get_rank_env():
268
+ if "RANK" in os.environ:
269
+ return int(os.environ["RANK"])
270
+ else:
271
+ return int(os.environ['OMPI_COMM_WORLD_RANK'])
272
+
273
+
274
+ def _get_local_rank_env():
275
+ if "LOCAL_RANK" in os.environ:
276
+ return int(os.environ["LOCAL_RANK"])
277
+ else:
278
+ return int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
279
+
280
+
281
+ def _get_world_size_env():
282
+ if "WORLD_SIZE" in os.environ:
283
+ return int(os.environ["WORLD_SIZE"])
284
+ else:
285
+ return int(os.environ['OMPI_COMM_WORLD_SIZE'])
286
+
287
+
288
+ # The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git)
289
+ def init_distributed_mode(args):
290
+ if args.dist_on_itp:
291
+ args.rank = _get_rank_env()
292
+ args.world_size = _get_world_size_env() # int(os.environ['OMPI_COMM_WORLD_SIZE'])
293
+ args.gpu = _get_local_rank_env()
294
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
295
+ os.environ['LOCAL_RANK'] = str(args.gpu)
296
+ os.environ['RANK'] = str(args.rank)
297
+ os.environ['WORLD_SIZE'] = str(args.world_size)
298
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
299
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
300
+ args.rank = int(os.environ["RANK"])
301
+ args.world_size = int(os.environ['WORLD_SIZE'])
302
+ args.gpu = int(os.environ['LOCAL_RANK'])
303
+ elif 'SLURM_PROCID' in os.environ:
304
+ args.rank = int(os.environ['SLURM_PROCID'])
305
+ args.gpu = args.rank % torch.cuda.device_count()
306
+ else:
307
+ print('Not using distributed mode')
308
+ args.distributed = False
309
+ return
310
+
311
+ args.distributed = True
312
+
313
+ torch.cuda.set_device(args.gpu)
314
+ args.dist_backend = 'nccl'
315
+ print('| distributed init (rank {}): {}, gpu {}'.format(
316
+ args.rank, args.dist_url, args.gpu), flush=True)
317
+ torch.distributed.init_process_group(
318
+ backend=args.dist_backend, init_method=args.dist_url,
319
+ world_size=args.world_size, rank=args.rank,
320
+ timeout=datetime.timedelta(0, 7200)
321
+ )
322
+ torch.distributed.barrier()
323
+ setup_for_distributed(args.rank == 0)
324
+
325
+
326
+ def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
327
+ missing_keys = []
328
+ unexpected_keys = []
329
+ error_msgs = []
330
+ # copy state_dict so _load_from_state_dict can modify it
331
+ metadata = getattr(state_dict, '_metadata', None)
332
+ state_dict = state_dict.copy()
333
+ if metadata is not None:
334
+ state_dict._metadata = metadata
335
+
336
+ def load(module, prefix=''):
337
+ local_metadata = {} if metadata is None else metadata.get(
338
+ prefix[:-1], {})
339
+ module._load_from_state_dict(
340
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
341
+ for name, child in module._modules.items():
342
+ if child is not None:
343
+ load(child, prefix + name + '.')
344
+
345
+ load(model, prefix=prefix)
346
+
347
+ warn_missing_keys = []
348
+ ignore_missing_keys = []
349
+ for key in missing_keys:
350
+ keep_flag = True
351
+ for ignore_key in ignore_missing.split('|'):
352
+ if ignore_key in key:
353
+ keep_flag = False
354
+ break
355
+ if keep_flag:
356
+ warn_missing_keys.append(key)
357
+ else:
358
+ ignore_missing_keys.append(key)
359
+
360
+ missing_keys = warn_missing_keys
361
+
362
+ if len(missing_keys) > 0:
363
+ print("Weights of {} not initialized from pretrained model: {}".format(
364
+ model.__class__.__name__, missing_keys))
365
+ if len(unexpected_keys) > 0:
366
+ print("Weights from pretrained model not used in {}: {}".format(
367
+ model.__class__.__name__, unexpected_keys))
368
+ if len(ignore_missing_keys) > 0:
369
+ print("Ignored weights of {} not initialized from pretrained model: {}".format(
370
+ model.__class__.__name__, ignore_missing_keys))
371
+ if len(error_msgs) > 0:
372
+ print('\n'.join(error_msgs))
373
+
374
+
375
+ class NativeScalerWithGradNormCount:
376
+ state_dict_key = "amp_scaler"
377
+
378
+ def __init__(self):
379
+ self._scaler = torch.cuda.amp.GradScaler()
380
+
381
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
382
+ self._scaler.scale(loss).backward(create_graph=create_graph)
383
+ if update_grad:
384
+ if clip_grad is not None:
385
+ assert parameters is not None
386
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
387
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
388
+ else:
389
+ self._scaler.unscale_(optimizer)
390
+ norm = get_grad_norm_(parameters)
391
+ self._scaler.step(optimizer)
392
+ self._scaler.update()
393
+ else:
394
+ norm = None
395
+ return norm
396
+
397
+ def state_dict(self):
398
+ return self._scaler.state_dict()
399
+
400
+ def load_state_dict(self, state_dict):
401
+ self._scaler.load_state_dict(state_dict)
402
+
403
+
404
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
405
+ if isinstance(parameters, torch.Tensor):
406
+ parameters = [parameters]
407
+ parameters = [p for p in parameters if p.grad is not None]
408
+ norm_type = float(norm_type)
409
+ if len(parameters) == 0:
410
+ return torch.tensor(0.)
411
+ device = parameters[0].grad.device
412
+ if norm_type == inf:
413
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
414
+ else:
415
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
416
+ return total_norm
417
+
418
+
419
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
420
+ start_warmup_value=0, warmup_steps=-1, sched_type="cos"):
421
+ warmup_schedule = np.array([])
422
+ warmup_iters = warmup_epochs * niter_per_ep
423
+ if warmup_steps > 0:
424
+ warmup_iters = warmup_steps
425
+ print("Set warmup steps = %d" % warmup_iters)
426
+ if warmup_epochs > 0:
427
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
428
+
429
+ if sched_type == "cos":
430
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
431
+ schedule = np.array([
432
+ final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
433
+ elif sched_type == "linear":
434
+ schedule = np.linspace(base_value, final_value, epochs * niter_per_ep - warmup_iters)
435
+ else:
436
+ raise NotImplementedError()
437
+
438
+ schedule = np.concatenate((warmup_schedule, schedule))
439
+
440
+ assert len(schedule) == epochs * niter_per_ep
441
+ return schedule
442
+
443
+
444
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
445
+ output_dir = Path(args.output_dir)
446
+ if loss_scaler is not None:
447
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch)]
448
+ for checkpoint_path in checkpoint_paths:
449
+ to_save = {
450
+ 'model': model_without_ddp.state_dict(),
451
+ 'optimizer': optimizer.state_dict(),
452
+ 'epoch': epoch,
453
+ 'scaler': loss_scaler.state_dict(),
454
+ 'args': args,
455
+ }
456
+
457
+ if model_ema is not None:
458
+ to_save['model_ema'] = get_state_dict(model_ema)
459
+
460
+ save_on_master(to_save, checkpoint_path)
461
+ else:
462
+ client_state = {'epoch': epoch, "args": args}
463
+ if model_ema is not None:
464
+ client_state['model_ema'] = get_state_dict(model_ema)
465
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch, client_state=client_state)
466
+
467
+
468
+ def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
469
+ output_dir = Path(args.output_dir)
470
+ if loss_scaler is not None:
471
+ # torch.amp
472
+ if args.auto_resume and len(args.resume) == 0:
473
+ import glob
474
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
475
+ latest_ckpt = -1
476
+ for ckpt in all_checkpoints:
477
+ t = ckpt.split('-')[-1].split('.')[0]
478
+ if t.isdigit():
479
+ latest_ckpt = max(int(t), latest_ckpt)
480
+ if latest_ckpt >= 0:
481
+ args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
482
+ print("Auto resume checkpoint: %s" % args.resume)
483
+
484
+ if args.resume:
485
+ if args.resume.startswith('https'):
486
+ checkpoint = torch.hub.load_state_dict_from_url(
487
+ args.resume, map_location='cpu', check_hash=True)
488
+ else:
489
+ checkpoint = torch.load(args.resume, map_location='cpu')
490
+ model_without_ddp.load_state_dict(checkpoint['model'])
491
+ print("Resume checkpoint %s" % args.resume)
492
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint:
493
+ optimizer.load_state_dict(checkpoint['optimizer'])
494
+ args.start_epoch = checkpoint['epoch'] + 1
495
+ if hasattr(args, 'model_ema') and args.model_ema:
496
+ _load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
497
+ if 'scaler' in checkpoint:
498
+ loss_scaler.load_state_dict(checkpoint['scaler'])
499
+ print("With optim & sched!")
500
+ else:
501
+ # deepspeed, only support '--auto_resume'.
502
+ if args.auto_resume:
503
+ import glob
504
+ all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*'))
505
+ latest_ckpt = -1
506
+ for ckpt in all_checkpoints:
507
+ t = ckpt.split('-')[-1].split('.')[0]
508
+ if t.isdigit():
509
+ latest_ckpt = max(int(t), latest_ckpt)
510
+ if latest_ckpt >= 0:
511
+ args.resume = os.path.join(output_dir, 'checkpoint-%d' % latest_ckpt)
512
+ print("Auto resume checkpoint: %d" % latest_ckpt)
513
+ _, client_states = model.load_checkpoint(args.output_dir, tag='checkpoint-%d' % latest_ckpt)
514
+ args.start_epoch = client_states['epoch'] + 1
515
+ if model_ema is not None:
516
+ if args.model_ema:
517
+ _load_checkpoint_for_ema(model_ema, client_states['model_ema'])
518
+
519
+
520
+ # The implementation code is modified from DeiT (https://github.com/facebookresearch/deit.git)
521
+ def load_model_and_may_interpolate(ckpt_path, model, model_key, model_prefix):
522
+ if ckpt_path.startswith('https'):
523
+ checkpoint = torch.hub.load_state_dict_from_url(
524
+ ckpt_path, map_location='cpu', check_hash=True)
525
+ else:
526
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
527
+
528
+ print("Load ckpt from %s" % ckpt_path)
529
+ checkpoint_model = None
530
+ for model_key in model_key.split('|'):
531
+ if model_key in checkpoint:
532
+ checkpoint_model = checkpoint[model_key]
533
+ print("Load state_dict by model_key = %s" % model_key)
534
+ break
535
+
536
+ if checkpoint_model is None:
537
+ checkpoint_model = checkpoint
538
+
539
+ state_dict = model.state_dict()
540
+ for k in ['head.weight', 'head.bias']:
541
+ if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
542
+ print(f"Removing key {k} from pretrained checkpoint")
543
+ del checkpoint_model[k]
544
+
545
+ # interpolate position embedding
546
+ for pos_embed_key in ("vision_pos_embed", "pos_embed", "beit3.encoder.embed_positions.A.weight"):
547
+ if pos_embed_key in checkpoint_model:
548
+ pos_embed_checkpoint = checkpoint_model[pos_embed_key]
549
+ embedding_size = pos_embed_checkpoint.shape[-1]
550
+ if pos_embed_key == "beit3.encoder.embed_positions.A.weight":
551
+ # being consistent with Fairseq, which starts from 2 for position embedding
552
+ torchscale_model = True
553
+ num_patches = model.beit3.vision_embed.num_patches
554
+ num_extra_tokens = model.beit3.vision_embed.num_position_embeddings() + 2 - num_patches
555
+ else:
556
+ torchscale_model = False
557
+ num_patches = model.patch_embed.num_patches
558
+ num_extra_tokens = getattr(model, pos_embed_key).shape[-2] - num_patches
559
+ # height (== width) for the checkpoint position embedding
560
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
561
+ # height (== width) for the new position embedding
562
+ new_size = int(num_patches ** 0.5)
563
+ # class_token and dist_token are kept unchanged
564
+ if orig_size != new_size:
565
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
566
+ if torchscale_model:
567
+ extra_tokens = pos_embed_checkpoint[:num_extra_tokens].unsqueeze(0)
568
+ # only the position tokens are interpolated
569
+ pos_tokens = pos_embed_checkpoint[num_extra_tokens:]
570
+ else:
571
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
572
+ # only the position tokens are interpolated
573
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
574
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
575
+ pos_tokens = torch.nn.functional.interpolate(
576
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
577
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
578
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
579
+ if torchscale_model:
580
+ new_pos_embed = new_pos_embed.squeeze(0)
581
+ checkpoint_model[pos_embed_key] = new_pos_embed
582
+
583
+ load_state_dict(model, checkpoint_model, prefix=model_prefix)
584
+
585
+
586
+ def create_ds_config(args):
587
+ args.deepspeed_config = os.path.join(args.output_dir, "deepspeed_config.json")
588
+ with open(args.deepspeed_config, mode="w") as writer:
589
+ ds_config = {
590
+ "train_batch_size": args.batch_size * args.update_freq * get_world_size(),
591
+ "train_micro_batch_size_per_gpu": args.batch_size,
592
+ "steps_per_print": 1000,
593
+ "optimizer": {
594
+ "type": "Adam",
595
+ "adam_w_mode": True,
596
+ "params": {
597
+ "lr": args.lr,
598
+ "weight_decay": args.weight_decay,
599
+ "bias_correction": True,
600
+ "betas": [
601
+ args.opt_betas[0],
602
+ args.opt_betas[1]
603
+ ],
604
+ "eps": args.opt_eps
605
+ }
606
+ },
607
+ "fp16": {
608
+ "enabled": True,
609
+ "loss_scale": 0,
610
+ "initial_scale_power": getattr(args, "initial_scale_power", 12),
611
+ "loss_scale_window": 1000,
612
+ "hysteresis": 2,
613
+ "min_loss_scale": 1
614
+ },
615
+ "amp": {
616
+ "enabled": False,
617
+ "opt_level": "O2"
618
+ }
619
+ }
620
+
621
+ if args.clip_grad is not None:
622
+ ds_config.update({'gradient_clipping': args.clip_grad})
623
+
624
+ if args.zero_stage == 1:
625
+ ds_config.update({"zero_optimization": {"stage": args.zero_stage, "reduce_bucket_size": 5e8}})
626
+ elif args.zero_stage > 1:
627
+ raise NotImplementedError()
628
+
629
+ writer.write(json.dumps(ds_config, indent=2))
630
+
631
+
632
+ def merge_batch_tensors_by_dict_key(batch):
633
+ batch_tensors = {}
634
+ for tensor_key in batch[0]:
635
+ if isinstance(batch[0][tensor_key], torch.Tensor):
636
+ batch_tensors[tensor_key] = torch.stack([d[tensor_key] for d in batch])
637
+ else:
638
+ batch_tensors[tensor_key] = torch.tensor([d[tensor_key] for d in batch], dtype=torch.long)
639
+ return batch_tensors
640
+
641
+
642
+ def get_loss_scale_for_deepspeed(model):
643
+ optimizer = model.optimizer
644
+ loss_scale = None
645
+ if hasattr(optimizer, 'loss_scale'):
646
+ loss_scale = optimizer.loss_scale
647
+ elif hasattr(optimizer, 'cur_scale'):
648
+ loss_scale = optimizer.cur_scale
649
+ return loss_scale
650
+
651
+
652
+ class GatherLayer(torch.autograd.Function):
653
+ """
654
+ Gather tensors from all workers with support for backward propagation:
655
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
656
+ """
657
+ @staticmethod
658
+ def forward(ctx, x):
659
+ output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
660
+ dist.all_gather(output, x)
661
+ return tuple(output)
662
+ @staticmethod
663
+ def backward(ctx, *grads):
664
+ all_gradients = torch.stack(grads)
665
+ dist.all_reduce(all_gradients)
666
+ return all_gradients[dist.get_rank()]
667
+
668
+
669
+ def gather_features(
670
+ image_features,
671
+ text_features,
672
+ ):
673
+ gathered_image_features = GatherLayer.apply(image_features)
674
+ gathered_text_features = GatherLayer.apply(text_features)
675
+ all_image_features = torch.cat(gathered_image_features)
676
+ all_text_features = torch.cat(gathered_text_features)
677
+
678
+ return all_image_features, all_text_features
679
+
680
+
681
+ # The implementation code is modified from open_clip (https://github.com/mlfoundations/open_clip.git)
682
+ class ClipLoss(nn.Module):
683
+
684
+ def __init__(
685
+ self,
686
+ cache_labels=False,
687
+ rank=0,
688
+ world_size=1,
689
+ ):
690
+ super().__init__()
691
+ self.cache_labels = cache_labels
692
+ self.rank = rank
693
+ self.world_size = world_size
694
+
695
+ # cache state
696
+ self.prev_num_logits = 0
697
+ self.labels = {}
698
+
699
+ def forward(self, image_features, text_features, logit_scale):
700
+ device = image_features.device
701
+ if self.world_size > 1:
702
+ all_image_features, all_text_features = gather_features(
703
+ image_features, text_features
704
+ )
705
+
706
+ logits_per_image = logit_scale * image_features @ all_text_features.T
707
+ logits_per_text = logit_scale * text_features @ all_image_features.T
708
+ else:
709
+ logits_per_image = logit_scale * image_features @ text_features.T
710
+ logits_per_text = logit_scale * text_features @ image_features.T
711
+
712
+ # calculated ground-truth and cache if enabled
713
+ num_logits = logits_per_image.shape[0]
714
+ if self.prev_num_logits != num_logits or device not in self.labels:
715
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
716
+ if self.world_size > 1:
717
+ labels = labels + num_logits * self.rank
718
+ if self.cache_labels:
719
+ self.labels[device] = labels
720
+ self.prev_num_logits = num_logits
721
+ else:
722
+ labels = self.labels[device]
723
+
724
+ total_loss = (
725
+ F.cross_entropy(logits_per_image, labels) +
726
+ F.cross_entropy(logits_per_text, labels)
727
+ ) / 2
728
+ return total_loss, logits_per_image, logits_per_text
729
+
730
+
731
+ def write_result_to_jsonl(test_stats, result_file):
732
+ with open(result_file, mode="w", encoding="utf-8") as writer:
733
+ writer.write(json.dumps(test_stats, indent=None))
734
+
735
+
736
+ def read_result_from_jsonl(result_file):
737
+ with open(result_file, mode="r", encoding="utf-8") as reader:
738
+ return json.load(reader)
739
+
740
+
741
+ # The implementation code is from ViLT (https://github.com/dandelin/ViLT.git)
742
+ class VQAScore(Metric):
743
+ def __init__(self, dist_sync_on_step=False):
744
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
745
+ self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
746
+ self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
747
+
748
+ def update(self, logits, target):
749
+ logits, target = (
750
+ logits.detach().float().to(self.score.device),
751
+ target.detach().float().to(self.score.device),
752
+ )
753
+ logits = torch.max(logits, 1)[1]
754
+ one_hots = torch.zeros(*target.size()).to(target)
755
+ one_hots.scatter_(1, logits.view(-1, 1), 1)
756
+ scores = one_hots * target
757
+
758
+ self.score += scores.sum()
759
+ self.total += len(logits)
760
+
761
+ def compute(self):
762
+ return self.score / self.total
763
+
764
+
765
+ class BertCaptioningLoss(nn.Module):
766
+ def __init__(self, label_smoothing, drop_worst_ratio, drop_worst_after):
767
+ super().__init__()
768
+ self.label_smoothing = label_smoothing
769
+ self.drop_worst_ratio = drop_worst_ratio
770
+ self.drop_worst_after = drop_worst_after
771
+ self.log_soft = nn.LogSoftmax(dim=1)
772
+ self.kl = nn.KLDivLoss(reduction='none')
773
+ self.iter = 0
774
+
775
+ def forward(self, logits, target, iter):
776
+ eps = self.label_smoothing
777
+ n_class = logits.size(1)
778
+ one_hot = torch.zeros_like(logits).scatter(1, target.view(-1, 1), 1)
779
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
780
+ log_prb = self.log_soft(logits)
781
+ loss = self.kl(log_prb, one_hot).sum(1)
782
+
783
+ if self.drop_worst_ratio > 0 and iter > self.drop_worst_after:
784
+ loss, _ = torch.topk(loss,
785
+ k=int(loss.shape[0] * (1-self.drop_worst_ratio)),
786
+ largest=False)
787
+ loss = loss.mean()
788
+
789
+ return loss
790
+
791
+
792
+ class BeamHypotheses(object):
793
+ def __init__(self, n_hyp, max_length, length_penalty, early_stopping):
794
+ """
795
+ Initialize n-best list of hypotheses.
796
+ """
797
+ self.max_length = max_length - 1 # ignoring bos_token
798
+ self.length_penalty = length_penalty
799
+ self.early_stopping = early_stopping
800
+ self.n_hyp = n_hyp
801
+ self.hyp = []
802
+ self.worst_score = 1e9
803
+
804
+ def __len__(self):
805
+ """
806
+ Number of hypotheses in the list.
807
+ """
808
+ return len(self.hyp)
809
+
810
+ def add(self, hyp, sum_logprobs):
811
+ """
812
+ Add a new hypothesis to the list.
813
+ """
814
+ score = sum_logprobs / len(hyp) ** self.length_penalty
815
+ if len(self) < self.n_hyp or score > self.worst_score:
816
+ self.hyp.append((score, hyp))
817
+ if len(self) > self.n_hyp:
818
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
819
+ del self.hyp[sorted_scores[0][1]]
820
+ self.worst_score = sorted_scores[1][0]
821
+ else:
822
+ self.worst_score = min(score, self.worst_score)
823
+
824
+ def is_done(self, best_sum_logprobs):
825
+ """
826
+ If there are enough hypotheses and that none of the hypotheses being generated
827
+ can become better than the worst one in the heap, then we are done with this sentence.
828
+ """
829
+ if len(self) < self.n_hyp:
830
+ return False
831
+ elif self.early_stopping:
832
+ return True
833
+ else:
834
+ return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
835
+
836
+
837
+ def dump_predictions(args, result, file_suffix):
838
+ global_rank = get_rank()
839
+ jsons = None
840
+ if global_rank >= 0:
841
+ output_file = os.path.join(args.task_cache_path, f"submit_{global_rank}_{file_suffix}.json")
842
+ with open(output_file, "w") as fp:
843
+ json.dump(result, fp, indent=2)
844
+ torch.distributed.barrier()
845
+
846
+ if global_rank == 0:
847
+ world_size = get_world_size()
848
+ jsons = []
849
+ for i in range(world_size):
850
+ each_file = os.path.join(args.task_cache_path, f"submit_{i}_{file_suffix}.json")
851
+ with open(each_file, "r") as fp:
852
+ jsons += json.load(fp)
853
+
854
+ new_jsons = []
855
+ res_dict = dict()
856
+ if args.task in ["coco_captioning", "nocaps"]:
857
+ qid_key = "image_id"
858
+ else:
859
+ # for VQAv2
860
+ qid_key = "question_id"
861
+ for item in jsons:
862
+ if item[qid_key] in res_dict:
863
+ continue
864
+ new_jsons.append(item)
865
+ res_dict[item[qid_key]] = item
866
+ jsons = new_jsons
867
+
868
+ torch.distributed.barrier()
869
+ os.remove(output_file)
870
+ else:
871
+ jsons = result
872
+
873
+ result_file = os.path.join(args.output_dir, f"submit_{file_suffix}.json")
874
+ if jsons is not None:
875
+ with open(result_file, "w") as fp:
876
+ json.dump(jsons, fp, indent=2)
877
+ print("Infer %d examples into %s" % (len(jsons), result_file))
878
+ return result_file
879
+
880
+
881
+ # The evaluation code is from BLIP (https://github.com/salesforce/BLIP)
882
+ # For nocaps, please submit the prediction file to the evaluate server (https://eval.ai/web/challenges/challenge-page/355/overview) to obtain the final results
883
+ def coco_caption_eval(gt_dir, results_file, split):
884
+ from pycocotools.coco import COCO
885
+ from pycocoevalcap.eval import COCOEvalCap
886
+ from torchvision.datasets.utils import download_url
887
+
888
+ urls = {'coco_captioning_val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
889
+ 'coco_captioning_test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json',
890
+ 'nocaps_val': 'https://github.com/addf400/files/releases/download/beit3/nocaps_val_gt.json'}
891
+ filenames = {'coco_captioning_val':'coco_karpathy_val_gt.json',
892
+ 'coco_captioning_test':'coco_karpathy_test_gt.json',
893
+ 'nocaps_val':'nocaps_val_gt.json'}
894
+
895
+ download_url(urls[split], gt_dir)
896
+ annotation_file = os.path.join(gt_dir, filenames[split])
897
+
898
+ # create coco object and coco_result object
899
+ coco = COCO(annotation_file)
900
+ coco_result = coco.loadRes(results_file)
901
+
902
+ # create coco_eval object by taking coco and coco_result
903
+ coco_eval = COCOEvalCap(coco, coco_result)
904
+
905
+ # evaluate results
906
+ # SPICE will take a few minutes the first time, but speeds up due to caching
907
+ coco_eval.evaluate()
908
+
909
+ res_dict = dict()
910
+ for metric, score in coco_eval.eval.items():
911
+ res_dict[metric] = score
912
+
913
+ return res_dict
requirements.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ packaging
2
+ sentencepiece
3
+ einops==0.4.1
4
+ fastapi==0.100.1
5
+ markdown2==2.4.10
6
+ numpy==1.24.2
7
+ openai==0.27.8
8
+ opencv_python==4.8.0.74
9
+ Pillow==9.4.0
10
+ pycocotools==2.0.6
11
+ ray==2.6.1
12
+ Requests==2.31.0
13
+ shortuuid==1.0.11
14
+ tqdm==4.64.1
15
+ transformers==4.31.0
16
+ uvicorn==0.23.2
17
+ scipy==1.11.2
18
+ bitsandbytes==0.41.1
19
+ timm==0.4.12
20
+ blobfile
21
+ mypy
22
+ pytest
23
+ requests
24
+ tensorboardX
25
+ ftfy
26
+ opencv-python
27
+ pyarrow
28
+ torchmetrics==0.7.3
29
+ deepspeed
30
+ pycocoevalcap
31
+ torchscale==0.2.0
32
+ gradio
utils/ade20k_classes.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road",
3
+ "bed", "windowpane", "grass", "cabinet", "sidewalk",
4
+ "person", "earth", "door", "table", "mountain", "plant",
5
+ "curtain", "chair", "car", "water", "painting", "sofa",
6
+ "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
7
+ "seat", "fence", "desk", "rock", "wardrobe", "lamp",
8
+ "bathtub", "railing", "cushion", "base", "box", "column",
9
+ "signboard", "chest of drawers", "counter", "sand", "sink",
10
+ "skyscraper", "fireplace", "refrigerator", "grandstand",
11
+ "path", "stairs", "runway", "case", "pool table", "pillow",
12
+ "screen door", "stairway", "river", "bridge", "bookcase",
13
+ "blind", "coffee table", "toilet", "flower", "book", "hill",
14
+ "bench", "countertop", "stove", "palm", "kitchen island",
15
+ "computer", "swivel chair", "boat", "bar", "arcade machine",
16
+ "hovel", "bus", "towel", "light", "truck", "tower",
17
+ "chandelier", "awning", "streetlight", "booth",
18
+ "television receiver", "airplane", "dirt track", "apparel",
19
+ "pole", "land", "bannister", "escalator", "ottoman", "bottle",
20
+ "buffet", "poster", "stage", "van", "ship", "fountain",
21
+ "conveyer belt", "canopy", "washer", "plaything",
22
+ "swimming pool", "stool", "barrel", "basket", "waterfall",
23
+ "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
24
+ "step", "tank", "trade name", "microwave", "pot", "animal",
25
+ "bicycle", "lake", "dishwasher", "screen", "blanket",
26
+ "sculpture", "hood", "sconce", "vase", "traffic light",
27
+ "tray", "ashcan", "fan", "pier", "crt screen", "plate",
28
+ "monitor", "bulletin board", "shower", "radiator", "glass",
29
+ "clock", "flag"
30
+ ]
utils/aug.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from copy import deepcopy
8
+ from typing import Tuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn import functional as F
13
+ from torchvision.transforms.functional import resize # type: ignore
14
+ from torchvision.transforms.functional import to_pil_image
15
+ import random
16
+
17
+
18
+ class RandomScale:
19
+ """
20
+ Resizes images to the longest side 'target_length', as well as provides
21
+ methods for resizing coordinates and boxes. Provides methods for
22
+ transforming both numpy array and batched torch tensors.
23
+ """
24
+
25
+ def __init__(self, max_length: int, min_length: int) -> None:
26
+ self.max_length = max_length
27
+ self.min_length = min_length
28
+
29
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
30
+ """
31
+ Expects a numpy array with shape HxWxC in uint8 format.
32
+ """
33
+ target_size = self.get_preprocess_shape(
34
+ image.shape[0], image.shape[1], self.max_length, self.min_length
35
+ )
36
+ return np.array(resize(to_pil_image(image), target_size))
37
+
38
+ def apply_coords(
39
+ self, coords: np.ndarray, original_size: Tuple[int, ...]
40
+ ) -> np.ndarray:
41
+ """
42
+ Expects a numpy array of length 2 in the final dimension. Requires the
43
+ original image size in (H, W) format.
44
+ """
45
+ old_h, old_w = original_size
46
+ new_h, new_w = self.get_preprocess_shape(
47
+ original_size[0], original_size[1], self.max_length, self.min_length
48
+ )
49
+ coords = deepcopy(coords).astype(float)
50
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
51
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
52
+ return coords
53
+
54
+ def apply_boxes(
55
+ self, boxes: np.ndarray, original_size: Tuple[int, ...]
56
+ ) -> np.ndarray:
57
+ """
58
+ Expects a numpy array shape Bx4. Requires the original image size
59
+ in (H, W) format.
60
+ """
61
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
62
+ return boxes.reshape(-1, 4)
63
+
64
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
65
+ """
66
+ Expects batched images with shape BxCxHxW and float format. This
67
+ transformation may not exactly match apply_image. apply_image is
68
+ the transformation expected by the model.
69
+ """
70
+ # Expects an image in BCHW format. May not exactly match apply_image.
71
+ target_size = self.get_preprocess_shape(
72
+ image.shape[0], image.shape[1], self.max_length, self.min_length
73
+ )
74
+ return F.interpolate(
75
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
76
+ )
77
+
78
+ def apply_coords_torch(
79
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
80
+ ) -> torch.Tensor:
81
+ """
82
+ Expects a torch tensor with length 2 in the last dimension. Requires the
83
+ original image size in (H, W) format.
84
+ """
85
+ old_h, old_w = original_size
86
+ new_h, new_w = self.get_preprocess_shape(
87
+ original_size[0], original_size[1], self.max_length, self.min_length
88
+ )
89
+ coords = deepcopy(coords).to(torch.float)
90
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
91
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
92
+ return coords
93
+
94
+ def apply_boxes_torch(
95
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
96
+ ) -> torch.Tensor:
97
+ """
98
+ Expects a torch tensor with shape Bx4. Requires the original image
99
+ size in (H, W) format.
100
+ """
101
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
102
+ return boxes.reshape(-1, 4)
103
+
104
+ @staticmethod
105
+ def get_preprocess_shape(
106
+ oldh: int, oldw: int, max_length: int, min_length: int
107
+ ) -> Tuple[int, int]:
108
+ """
109
+ Compute the output size given input size and target long side length.
110
+ """
111
+ max_scale = max_length * 1.0 / max(oldh, oldw)
112
+ min_scale = min_length * 1.0 / max(oldh, oldw)
113
+ scale = min_scale + random.random() * (max_scale-min_scale)
114
+ newh, neww = oldh * scale, oldw * scale
115
+ neww = int(neww + 0.5)
116
+ newh = int(newh + 0.5)
117
+ return (newh, neww)