ikuinen99 committed on
Commit
e4bd7f9
1 Parent(s): 435d80f
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. InputSans-Regular.ttf +0 -0
  2. app.py +299 -4
  3. bubogpt/__init__.py +31 -0
  4. bubogpt/common/__init__.py +0 -0
  5. bubogpt/common/config.py +473 -0
  6. bubogpt/common/dist_utils.py +137 -0
  7. bubogpt/common/gradcam.py +24 -0
  8. bubogpt/common/logger.py +195 -0
  9. bubogpt/common/optims.py +119 -0
  10. bubogpt/common/registry.py +333 -0
  11. bubogpt/common/utils.py +424 -0
  12. bubogpt/configs/datasets/aud_img_neg/default.yaml +10 -0
  13. bubogpt/configs/datasets/audioset/defaults.yaml +5 -0
  14. bubogpt/configs/datasets/bbc/defaults.yaml +5 -0
  15. bubogpt/configs/datasets/cc12m/defaults.yaml +5 -0
  16. bubogpt/configs/datasets/cc_sbu/align.yaml +5 -0
  17. bubogpt/configs/datasets/cc_sbu/defaults.yaml +5 -0
  18. bubogpt/configs/datasets/clotho/align.yaml +5 -0
  19. bubogpt/configs/datasets/freesound/defaults.yaml +5 -0
  20. bubogpt/configs/datasets/laion/defaults.yaml +5 -0
  21. bubogpt/configs/datasets/soundbible/defaults.yaml +5 -0
  22. bubogpt/configs/datasets/vggss/align.yaml +6 -0
  23. bubogpt/configs/default.yaml +5 -0
  24. bubogpt/configs/models/mmgpt4.yaml +30 -0
  25. bubogpt/datasets/__init__.py +0 -0
  26. bubogpt/datasets/builders/__init__.py +90 -0
  27. bubogpt/datasets/builders/audio_base_dataset_builder.py +142 -0
  28. bubogpt/datasets/builders/audio_image_text_builder.py +105 -0
  29. bubogpt/datasets/builders/audio_text_pair_builder.py +88 -0
  30. bubogpt/datasets/builders/image_base_dataset_builder.py +238 -0
  31. bubogpt/datasets/builders/image_text_pair_builder.py +189 -0
  32. bubogpt/datasets/builders/multimodal_base_dataset_builder.py +74 -0
  33. bubogpt/datasets/data_utils.py +215 -0
  34. bubogpt/datasets/datasets/__init__.py +0 -0
  35. bubogpt/datasets/datasets/audio_caption/__init__.py +1 -0
  36. bubogpt/datasets/datasets/audio_caption/audio_caption_datasets.py +70 -0
  37. bubogpt/datasets/datasets/audio_image/__init__.py +0 -0
  38. bubogpt/datasets/datasets/audio_image/audio_image_datasets.py +92 -0
  39. bubogpt/datasets/datasets/base_dataset.py +79 -0
  40. bubogpt/datasets/datasets/dataloader_utils.py +162 -0
  41. bubogpt/datasets/datasets/image_caption/__init__.py +0 -0
  42. bubogpt/datasets/datasets/image_caption/cc_sbu_dataset.py +68 -0
  43. bubogpt/datasets/datasets/image_caption/image_caption_datasets.py +73 -0
  44. bubogpt/datasets/datasets/image_caption/laion_dataset.py +31 -0
  45. bubogpt/datasets/datasets/image_caption/llava_dataset.py +72 -0
  46. bubogpt/datasets/datasets/mixins/__init__.py +0 -0
  47. bubogpt/datasets/datasets/mixins/mixins.py +30 -0
  48. bubogpt/models/Qformer.py +1216 -0
  49. bubogpt/models/__init__.py +200 -0
  50. bubogpt/models/base_model.py +247 -0
InputSans-Regular.ttf ADDED
Binary file (128 kB).
 
app.py CHANGED
@@ -1,7 +1,302 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ # import sys
5
+ # import os
6
+ #
7
+ # BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
+ # sys.path.append(BASE_DIR)
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.backends.cudnn as cudnn
13
  import gradio as gr
14
 
15
+ from constants.constant import LIGHTER_COLOR_MAP_HEX
16
+ # NOTE: Must import LlamaTokenizer before `bubogpt.common.config`
17
+ # otherwise, it will cause seg fault when `llama_tokenizer.decode` is called
18
+
19
+ from grounding_model import GroundingModule
20
+ from match import MatchModule
21
+ from bubogpt.common.config import Config
22
+ from bubogpt.common.dist_utils import get_rank
23
+ from bubogpt.common.registry import registry
24
+ from eval_scripts.conversation import Chat, CONV_X, DummyChat
25
+ # NOTE&TODO: putting this before the bubogpt import causes a circular import,
26
+ # possibly because `imagebind` imports `bubogpt` and `bubogpt` also imports `imagebind`
27
+ from imagebind.models.image_bind import ModalityType
28
+ # from ner import NERModule
29
+ from tagging_model import TaggingModule
30
+
31
+
32
+
33
+ def parse_args():
34
+ parser = argparse.ArgumentParser(description="Qualitative")
35
+ parser.add_argument("--cfg-path", help="path to configuration file.", deafult='./eval_configs/mmgpt4_eval.yaml')
36
+ parser.add_argument("--dummy", action="store_true", help="Debug Mode")
37
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
38
+ parser.add_argument(
39
+ "--options",
40
+ nargs="+",
41
+ help="override some settings in the used config, the key-value pair "
42
+ "in xxx=yyy format will be merged into config file (deprecate), "
43
+ "change to --cfg-options instead.",
44
+ )
45
+ parser.add_argument("--ground-all", action="store_true")
46
+ args = parser.parse_args()
47
+ return args
48
+
49
+
50
+ def setup_seeds(config):
51
+ seed = config.run_cfg.seed + get_rank()
52
+
53
+ random.seed(seed)
54
+ np.random.seed(seed)
55
+ torch.manual_seed(seed)
56
+
57
+ cudnn.benchmark = False
58
+ cudnn.deterministic = True
59
+
60
+
61
+ # ========================================
62
+ # Model Initialization
63
+ # ========================================
64
+
65
+ print('Initializing Chat')
66
+ args = parse_args()
67
+
68
+ assert args.dummy or (args.cfg_path is not None), "Invalid Config! Set --dummy or configure the cfg_path!"
69
+
70
+ if not args.dummy:
71
+ cfg = Config(args)
72
+
73
+ # Create processors
74
+ vis_processor_cfg = cfg.datasets_cfg.default.vis_processor.eval
75
+ aud_processor_cfg = cfg.datasets_cfg.default.audio_processor.eval
76
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
77
+ aud_processor = registry.get_processor_class(aud_processor_cfg.name).from_config(aud_processor_cfg)
78
+ processors = {ModalityType.VISION: vis_processor, ModalityType.AUDIO: aud_processor}
79
+
80
+ # Create model
81
+ model_config = cfg.model_cfg
82
+ model_config.device_8bit = args.gpu_id
83
+ model_cls = registry.get_model_class(model_config.arch)
84
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
85
+ chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
86
+ else:
87
+ model = None
88
+ chat = DummyChat()
89
+
90
+ match = MatchModule(model='gpt-4')
91
+ tagging_module = TaggingModule(device='cuda:{}'.format(args.gpu_id))
92
+ grounding_dino = GroundingModule(device='cuda:{}'.format(args.gpu_id))
93
+ print('Initialization Finished')
94
+
95
+
96
+ # ========================================
97
+ # Gradio Setting
98
+ # ========================================
99
+
100
+ def gradio_reset(chat_state, emb_list):
101
+ if chat_state is not None:
102
+ chat_state.messages = []
103
+ if emb_list is not None:
104
+ emb_list = []
105
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=False), \
106
+ gr.update(value=None, interactive=True), \
107
+ gr.update(placeholder='Please upload your image/audio first', interactive=False), \
108
+ gr.update(value=None), \
109
+ gr.update(value="Upload & Start Chat", interactive=True), \
110
+ chat_state, emb_list, gr.update(value={})
111
+
112
+
113
+ def upload_x(gr_img, gr_aud, chat_state):
114
+ if gr_img is None and gr_aud is None:
115
+ return None, None, None, gr.update(interactive=True), chat_state, None, {}
116
+ chat_state = CONV_X.copy()
117
+ emb_list = []
118
+ if gr_img is not None:
119
+ chat.upload_img(gr_img, chat_state, emb_list)
120
+ state = {
121
+ 'tags': tagging_module(gr_img)
122
+ }
123
+ # print(state)
124
+ else:
125
+ state = {}
126
+ if gr_aud is not None:
127
+ chat.upload_aud(gr_aud, chat_state, emb_list)
128
+ return gr.update(interactive=False), gr.update(interactive=False), \
129
+ gr.update(interactive=True, placeholder='Type and press Enter'), \
130
+ gr.update(value="Start Chatting", interactive=False), \
131
+ chat_state, emb_list, state
132
+
133
+
134
+ def gradio_ask(user_message, chatbot, chat_state, text_output, last_answer):
135
+ if len(user_message) == 0:
136
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state, \
137
+ gr.update(value=None, color_map=None, show_legend=False), gr.update(value=None)
138
+ if last_answer is not None:
139
+ chatbot[-1][1] = last_answer
140
+ chat.ask(user_message, chat_state)
141
+ if text_output is not None:
142
+ os.makedirs('results', exist_ok=True)
143
+ # print("****** Text output is:", text_output)
144
+ chatbot[-1][1] = ''.join(map(lambda x: x[0], text_output))
145
+ chatbot = chatbot + [[user_message, None]]
146
+ return '', chatbot, chat_state, gr.update(value=None, color_map=None, show_legend=False), gr.update(value=None)
147
+
148
+
149
+ def gradio_answer(image, chatbot, chat_state, emb_list, num_beams, temperature, entity_state):
150
+ llm_message = chat.answer(conversation=chat_state,
151
+ emb_list=emb_list,
152
+ num_beams=num_beams,
153
+ temperature=temperature,
154
+ max_new_tokens=300,
155
+ max_length=2000)[0]
156
+ if image is not None:
157
+ # new_entity_state = entity_state.value()
158
+ # new_entity_state.update({"answer": llm_message})
159
+ entity_state["answer"] = llm_message
160
+ rich_text, match_state, color_map = match(llm_message, entity_state)
161
+ print("Original Color Map: ", color_map)
162
+ color_map = {key: LIGHTER_COLOR_MAP_HEX[color_map[key]] for key in color_map}
163
+ print("Modified Color Map: ", color_map)
164
+ chatbot[-1][1] = "The answer can be found in the textbox below and I'm trying my best to highlight the " \
165
+ "corresponding region on the image."
166
+ # new_entity_state.update({"match_state": match_state})
167
+ entity_state['match_state'] = match_state # item_id -> local_id
168
+ new_grounded_image = grounding_dino.draw(image, entity_state)
169
+ show_legend = bool(match_state)
170
+ print('gradio_answer ==> current state: ', entity_state)
171
+
172
+ # if args.ground_all:
173
+ # ground_img, local_results = grounding_dino.prompt2mask(image,
174
+ # '.'.join(map(lambda x: x, state['entity'])),
175
+ # state=state)
176
+ # else:
177
+ # ground_img = None
178
+ return chatbot, chat_state, emb_list, \
179
+ gr.update(value=rich_text, color_map=color_map, show_legend=show_legend), \
180
+ gr.update(value=entity_state), \
181
+ gr.update(value=llm_message), gr.update(value=new_grounded_image)
182
+ else:
183
+ chatbot[-1][1] = llm_message
184
+ return chatbot, chat_state, emb_list, \
185
+ gr.update(value=None), \
186
+ entity_state, \
187
+ gr.update(value=None), gr.update(value=None)
188
+
189
+ def grounding_fn(image, chatbot, entity_state):
190
+ # print("Grounding fn: ", entity_state)
191
+ if image and entity_state:
192
+ ground_img, local_results = grounding_dino.prompt2mask2(
193
+ image, ','.join(map(lambda x: x, entity_state['tags'])), state=entity_state
194
+ )
195
+ entity_state['grounding'] = {
196
+ 'full': ground_img,
197
+ 'local': local_results
198
+ }
199
+ print('grounding_fn ==> current state: ', entity_state)
200
+ return chatbot, gr.update(value=ground_img, interactive=False), entity_state
201
+ return chatbot, gr.update(value=None, interactive=False), entity_state
202
+
203
+
204
+ def select_fn(image, ground_img, entity_state, evt: gr.SelectData):
205
+ if image is None:
206
+ return gr.update(value=None, interactive=False)
207
+ item, label = evt.value[0], evt.value[1]
208
+
209
+ if label is None:
210
+ return ground_img
211
+ print('select_fn ==> current state: ', entity_state)
212
+ if 'grounding' not in entity_state:
213
+ ground_img, local_results = grounding_dino.prompt2mask2(image,
214
+ ','.join(map(lambda x: x[0], entity_state['tags'])),
215
+ state=entity_state)
216
+ entity_state['grounding'] = {
217
+ 'full': ground_img,
218
+ 'local': local_results
219
+ }
220
+ # local_img = entity_state['grounding']['local'][entity]['image']
221
+ # print("DEBUG INFO: ", entity_state)
222
+ local_img = grounding_dino.draw(image, entity_state, item.lower())
223
+ return gr.update(value=local_img, interactive=False)
224
+
225
+
226
+ title = """<h1 align="center">Demo of BuboGPT</h1>"""
227
+ description = """<h3>This is the demo of BuboGPT. Upload and start chatting!</h3>"""
228
+ # article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
229
+ # """
230
+
231
+ # TODO show examples below
232
+
233
+ with gr.Blocks() as demo:
234
+ gr.Markdown(title)
235
+ gr.Markdown(description)
236
+ # gr.Markdown(article)
237
+
238
+ with gr.Row():
239
+ with gr.Column(scale=0.5):
240
+ image = gr.Image(type="pil")
241
+ grounded_image = gr.Image(type="pil", interactive=False)
242
+ audio = gr.Audio()
243
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
244
+ clear = gr.Button("Restart")
245
+
246
+ num_beams = gr.Slider(
247
+ minimum=1,
248
+ maximum=10,
249
+ value=1,
250
+ step=1,
251
+ interactive=True,
252
+ label="beam search numbers",
253
+ )
254
+
255
+ temperature = gr.Slider(
256
+ minimum=0.1,
257
+ maximum=2.0,
258
+ value=1.0,
259
+ step=0.1,
260
+ interactive=True,
261
+ label="Temperature",
262
+ )
263
+
264
+ with gr.Column():
265
+ chat_state = gr.State()
266
+ last_answer = gr.State()
267
+ entity_state = gr.State(value={})
268
+ emb_list = gr.State()
269
+ chatbot = gr.Chatbot(label='BindGPT-4')
270
+ text_output = gr.HighlightedText(value=None, label="Response", show_legend=False)
271
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image/audio first', interactive=False)
272
+
273
+ upload_button.click(
274
+ upload_x, [image, audio, chat_state],
275
+ [image, audio, text_input, upload_button, chat_state, emb_list, entity_state]).then(
276
+ grounding_fn,
277
+ [image, chatbot, entity_state],
278
+ [chatbot, grounded_image, entity_state]
279
+ )
280
+
281
+ text_input.submit(gradio_ask,
282
+ [text_input, chatbot, chat_state, text_output, last_answer],
283
+ [text_input, chatbot, chat_state, text_output, last_answer]
284
+ ).then(
285
+ gradio_answer,
286
+ [image, chatbot, chat_state, emb_list, num_beams, temperature, entity_state],
287
+ [chatbot, chat_state, emb_list, text_output, entity_state, last_answer, grounded_image]
288
+ )
289
+
290
+ clear.click(gradio_reset,
291
+ [chat_state, emb_list],
292
+ [chatbot, image, grounded_image, audio, text_input, text_output,
293
+ upload_button, chat_state, emb_list, entity_state],
294
+ queue=False)
295
+
296
+ text_output.select(
297
+ select_fn,
298
+ [image, grounded_image, entity_state],
299
+ [grounded_image]
300
+ )
301
 
302
+ demo.launch(enable_queue=True)
 
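A minimal sketch of the chained `.click(...).then(...)` event pattern that app.py relies on above, with hypothetical stand-in handlers; it assumes only that `gradio` is installed, not the BuboGPT modules:

import gradio as gr

def fake_upload(img, state):
    # Stand-in for upload_x: freeze the image widget and seed the entity state.
    return gr.update(interactive=False), {"tags": ["example"]}

def fake_ground(img, state):
    # Stand-in for grounding_fn: echo the image back as the "grounded" result.
    return img, state

with gr.Blocks() as sketch:
    image = gr.Image(type="pil")
    grounded_image = gr.Image(interactive=False)
    entity_state = gr.State(value={})
    upload_button = gr.Button("Upload & Start Chat")
    # The second step runs only after the first one returns, mirroring
    # upload_button.click(upload_x, ...).then(grounding_fn, ...) above.
    upload_button.click(fake_upload, [image, entity_state], [image, entity_state]).then(
        fake_ground, [image, entity_state], [grounded_image, entity_state]
    )

if __name__ == "__main__":
    sketch.launch()
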
bubogpt/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from bubogpt.common.registry import registry
14
+
15
+ from bubogpt.datasets.builders import *
16
+ from bubogpt.models import *
17
+ from bubogpt.processors import *
18
+ from bubogpt.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
bubogpt/common/__init__.py ADDED
File without changes
bubogpt/common/config.py ADDED
@@ -0,0 +1,473 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from bubogpt.common.registry import registry
14
+
15
+ # logging.info = print
16
+
17
+
18
+ class Config:
19
+ def __init__(self, args):
20
+ self.config = {}
21
+
22
+ self.args = args
23
+
24
+ # Register the config and configuration for setup
25
+ registry.register("configuration", self)
26
+
27
+ user_config = self._build_opt_list(self.args.options)
28
+
29
+ config = OmegaConf.load(self.args.cfg_path)
30
+
31
+ runner_config = self.build_runner_config(config)
32
+ model_config = self.build_model_config(config, **user_config)
33
+ if not config.run.evaluate:
34
+ dataset_config = self.build_dataset_config(config)
35
+ else:
36
+ dataset_config = OmegaConf.create({"datasets": config.datasets})
37
+
38
+ # Validate the user-provided runner configuration
39
+ # model and dataset configuration are supposed to be validated by the respective classes
40
+ # [TODO] validate the model/dataset configuration
41
+ # self._validate_runner_config(runner_config)
42
+
43
+ # Override the default configuration with user options.
44
+ self.config = OmegaConf.merge(
45
+ runner_config, model_config, dataset_config, user_config
46
+ )
47
+
48
+ def _validate_runner_config(self, runner_config):
49
+ """
50
+ This method validates the configuration, such that
51
+ 1) all the user specified options are valid;
52
+ 2) no type mismatches between the user specified options and the config.
53
+ """
54
+ runner_config_validator = create_runner_config_validator()
55
+ runner_config_validator.validate(runner_config)
56
+
57
+ def _build_opt_list(self, opts):
58
+ opts_dot_list = self._convert_to_dot_list(opts)
59
+ return OmegaConf.from_dotlist(opts_dot_list)
60
+
61
+ @staticmethod
62
+ def build_model_config(config, **kwargs):
63
+ model = config.get("model", None)
64
+ assert model is not None, "Missing model configuration file."
65
+
66
+ model_cls = registry.get_model_class(model.arch)
67
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
68
+
69
+ model_type = kwargs.get("model.model_type", None)
70
+ if not model_type:
71
+ model_type = model.get("model_type", None)
72
+ # else use the model type selected by user.
73
+
74
+ assert model_type is not None, "Missing model_type."
75
+
76
+ model_config_path = model_cls.default_config_path(model_type=model_type)
77
+
78
+ model_config = OmegaConf.create()
79
+ # hierarchy override, customized config > default config
80
+ model_config = OmegaConf.merge(
81
+ model_config,
82
+ OmegaConf.load(model_config_path),
83
+ {"model": config["model"]},
84
+ )
85
+
86
+ return model_config
87
+
88
+ @staticmethod
89
+ def build_runner_config(config):
90
+ return {"run": config.run}
91
+
92
+ @staticmethod
93
+ def build_dataset_config(config):
94
+ datasets = config.get("datasets", None)
95
+ if datasets is None:
96
+ raise KeyError(
97
+ "Expecting 'datasets' as the root key for dataset configuration."
98
+ )
99
+
100
+ dataset_config = OmegaConf.create()
101
+
102
+ for dataset_name in datasets:
103
+ builder_cls = registry.get_builder_class(dataset_name)
104
+
105
+ dataset_config_type = datasets[dataset_name].get("type", "default")
106
+ dataset_config_path = builder_cls.default_config_path(
107
+ type=dataset_config_type
108
+ )
109
+
110
+ # hierarchy override, customized config > default config
111
+ dataset_config = OmegaConf.merge(
112
+ dataset_config,
113
+ OmegaConf.load(dataset_config_path) if dataset_config_path is not None else {},
114
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
115
+ )
116
+
117
+ return dataset_config
118
+
119
+ def _convert_to_dot_list(self, opts):
120
+ if opts is None:
121
+ opts = []
122
+
123
+ if len(opts) == 0:
124
+ return opts
125
+
126
+ has_equal = opts[0].find("=") != -1
127
+
128
+ if has_equal:
129
+ return opts
130
+
131
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
132
+
133
+ def get_config(self):
134
+ return self.config
135
+
136
+ @property
137
+ def run_cfg(self):
138
+ return self.config.run
139
+
140
+ @property
141
+ def datasets_cfg(self):
142
+ return self.config.datasets
143
+
144
+ @property
145
+ def model_cfg(self):
146
+ return self.config.model
147
+
148
+ def pretty_print(self):
149
+ logging.info("\n===== Running Parameters =====")
150
+ logging.info(self._convert_node_to_json(self.config.run))
151
+
152
+ logging.info("\n====== Dataset Attributes ======")
153
+ datasets = self.config.datasets
154
+
155
+ for dataset in datasets:
156
+ if dataset in self.config.datasets:
157
+ logging.info(f"\n======== {dataset} =======")
158
+ dataset_config = self.config.datasets[dataset]
159
+ logging.info(self._convert_node_to_json(dataset_config))
160
+ else:
161
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
162
+
163
+ logging.info(f"\n====== Model Attributes ======")
164
+ logging.info(self._convert_node_to_json(self.config.model))
165
+
166
+ def _convert_node_to_json(self, node):
167
+ container = OmegaConf.to_container(node, resolve=True)
168
+ return json.dumps(container, indent=4, sort_keys=True)
169
+
170
+ def to_dict(self):
171
+ return OmegaConf.to_container(self.config)
172
+
173
+
174
+ def node_to_dict(node):
175
+ return OmegaConf.to_container(node)
176
+
177
+
178
+ class ConfigValidator:
179
+ """
180
+ This is a preliminary implementation to centralize and validate the configuration.
181
+ May be altered in the future.
182
+
183
+ A helper class to validate configurations from yaml file.
184
+
185
+ This serves the following purposes:
186
+ 1. Ensure all the options in the yaml are defined, raise error if not.
187
+ 2. when type mismatches are found, the validator will raise an error.
188
+ 3. a central place to store and display helpful messages for supported configurations.
189
+
190
+ """
191
+
192
+ class _Argument:
193
+ def __init__(self, name, choices=None, type=None, help=None):
194
+ self.name = name
195
+ self.val = None
196
+ self.choices = choices
197
+ self.type = type
198
+ self.help = help
199
+
200
+ def __str__(self):
201
+ s = f"{self.name}={self.val}"
202
+ if self.type is not None:
203
+ s += f", ({self.type})"
204
+ if self.choices is not None:
205
+ s += f", choices: {self.choices}"
206
+ if self.help is not None:
207
+ s += f", ({self.help})"
208
+ return s
209
+
210
+ def __init__(self, description):
211
+ self.description = description
212
+
213
+ self.arguments = dict()
214
+
215
+ self.parsed_args = None
216
+
217
+ def __getitem__(self, key):
218
+ assert self.parsed_args is not None, "No arguments parsed yet."
219
+
220
+ return self.parsed_args[key]
221
+
222
+ def __str__(self) -> str:
223
+ return self.format_help()
224
+
225
+ def add_argument(self, *args, **kwargs):
226
+ """
227
+ Assume the first argument is the name of the argument.
228
+ """
229
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
230
+
231
+ def validate(self, config=None):
232
+ """
233
+ Convert yaml config (dict-like) to list, required by argparse.
234
+ """
235
+ for k, v in config.items():
236
+ assert (
237
+ k in self.arguments
238
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
239
+
240
+ if self.arguments[k].type is not None:
241
+ try:
242
+ self.arguments[k].val = self.arguments[k].type(v)
243
+ except ValueError:
244
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
245
+
246
+ if self.arguments[k].choices is not None:
247
+ assert (
248
+ v in self.arguments[k].choices
249
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
250
+
251
+ return config
252
+
253
+ def format_arguments(self):
254
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
255
+
256
+ def format_help(self):
257
+ # description + key-value pair string for each argument
258
+ help_msg = str(self.description)
259
+ return help_msg + ", available arguments: " + self.format_arguments()
260
+
261
+ def print_help(self):
262
+ # display help message
263
+ print(self.format_help())
264
+
265
+
266
+ def create_runner_config_validator():
267
+ validator = ConfigValidator(description="Runner configurations")
268
+
269
+ validator.add_argument(
270
+ "runner",
271
+ type=str,
272
+ choices=["runner_base", "runner_iter"],
273
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
274
+ runner runs based on iters. Default: runner_base""",
275
+ )
276
+ # add argumetns for training dataset ratios
277
+ validator.add_argument(
278
+ "train_dataset_ratios",
279
+ type=Dict[str, float],
280
+ help="""Ratios of training dataset. This is used in iteration-based runner.
281
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
282
+ Default: None""",
283
+ )
284
+ validator.add_argument(
285
+ "max_iters",
286
+ type=float,
287
+ help="Maximum number of iterations to run.",
288
+ )
289
+ validator.add_argument(
290
+ "max_epoch",
291
+ type=int,
292
+ help="Maximum number of epochs to run.",
293
+ )
294
+ # add arguments for iters_per_inner_epoch
295
+ validator.add_argument(
296
+ "iters_per_inner_epoch",
297
+ type=float,
298
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
299
+ )
300
+ lr_scheds_choices = registry.list_lr_schedulers()
301
+ validator.add_argument(
302
+ "lr_sched",
303
+ type=str,
304
+ choices=lr_scheds_choices,
305
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
306
+ )
307
+ task_choices = registry.list_tasks()
308
+ validator.add_argument(
309
+ "task",
310
+ type=str,
311
+ choices=task_choices,
312
+ help="Task to use, from {}".format(task_choices),
313
+ )
314
+ # add arguments for init_lr
315
+ validator.add_argument(
316
+ "init_lr",
317
+ type=float,
318
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
319
+ )
320
+ # add arguments for min_lr
321
+ validator.add_argument(
322
+ "min_lr",
323
+ type=float,
324
+ help="Minimum learning rate (after decay).",
325
+ )
326
+ # add arguments for warmup_lr
327
+ validator.add_argument(
328
+ "warmup_lr",
329
+ type=float,
330
+ help="Starting learning rate for warmup.",
331
+ )
332
+ # add arguments for learning rate decay rate
333
+ validator.add_argument(
334
+ "lr_decay_rate",
335
+ type=float,
336
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
337
+ )
338
+ # add arguments for weight decay
339
+ validator.add_argument(
340
+ "weight_decay",
341
+ type=float,
342
+ help="Weight decay rate.",
343
+ )
344
+ # add arguments for training batch size
345
+ validator.add_argument(
346
+ "batch_size_train",
347
+ type=int,
348
+ help="Training batch size.",
349
+ )
350
+ # add arguments for evaluation batch size
351
+ validator.add_argument(
352
+ "batch_size_eval",
353
+ type=int,
354
+ help="Evaluation batch size, including validation and testing.",
355
+ )
356
+ # add arguments for number of workers for data loading
357
+ validator.add_argument(
358
+ "num_workers",
359
+ help="Number of workers for data loading.",
360
+ )
361
+ # add arguments for warm up steps
362
+ validator.add_argument(
363
+ "warmup_steps",
364
+ type=int,
365
+ help="Number of warmup steps. Required if a warmup schedule is used.",
366
+ )
367
+ # add arguments for random seed
368
+ validator.add_argument(
369
+ "seed",
370
+ type=int,
371
+ help="Random seed.",
372
+ )
373
+ # add arguments for output directory
374
+ validator.add_argument(
375
+ "output_dir",
376
+ type=str,
377
+ help="Output directory to save checkpoints and logs.",
378
+ )
379
+ # add arguments for whether only use evaluation
380
+ validator.add_argument(
381
+ "evaluate",
382
+ help="Whether to only evaluate the model. If true, training will not be performed.",
383
+ )
384
+ # add arguments for splits used for training, e.g. ["train", "val"]
385
+ validator.add_argument(
386
+ "train_splits",
387
+ type=list,
388
+ help="Splits to use for training.",
389
+ )
390
+ # add arguments for splits used for validation, e.g. ["val"]
391
+ validator.add_argument(
392
+ "valid_splits",
393
+ type=list,
394
+ help="Splits to use for validation. If not provided, will skip the validation.",
395
+ )
396
+ # add arguments for splits used for testing, e.g. ["test"]
397
+ validator.add_argument(
398
+ "test_splits",
399
+ type=list,
400
+ help="Splits to use for testing. If not provided, will skip the testing.",
401
+ )
402
+ # add arguments for accumulating gradient for iterations
403
+ validator.add_argument(
404
+ "accum_grad_iters",
405
+ type=int,
406
+ help="Number of iterations to accumulate gradient for.",
407
+ )
408
+
409
+ # ====== distributed training ======
410
+ validator.add_argument(
411
+ "device",
412
+ type=str,
413
+ choices=["cpu", "cuda"],
414
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
415
+ )
416
+ validator.add_argument(
417
+ "world_size",
418
+ type=int,
419
+ help="Number of processes participating in the job.",
420
+ )
421
+ validator.add_argument("dist_url", type=str)
422
+ validator.add_argument("distributed", type=bool)
423
+ # add arguments to opt using distributed sampler during evaluation or not
424
+ validator.add_argument(
425
+ "use_dist_eval_sampler",
426
+ type=bool,
427
+ help="Whether to use distributed sampler during evaluation or not.",
428
+ )
429
+
430
+ # ====== task specific ======
431
+ # generation task specific arguments
432
+ # add arguments for maximal length of text output
433
+ validator.add_argument(
434
+ "max_len",
435
+ type=int,
436
+ help="Maximal length of text output.",
437
+ )
438
+ # add arguments for minimal length of text output
439
+ validator.add_argument(
440
+ "min_len",
441
+ type=int,
442
+ help="Minimal length of text output.",
443
+ )
444
+ # add arguments number of beams
445
+ validator.add_argument(
446
+ "num_beams",
447
+ type=int,
448
+ help="Number of beams used for beam search.",
449
+ )
450
+
451
+ # vqa task specific arguments
452
+ # add arguments for number of answer candidates
453
+ validator.add_argument(
454
+ "num_ans_candidates",
455
+ type=int,
456
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
457
+ )
458
+ # add arguments for inference method
459
+ validator.add_argument(
460
+ "inference_method",
461
+ type=str,
462
+ choices=["genearte", "rank"],
463
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
464
+ )
465
+
466
+ # ====== model specific ======
467
+ validator.add_argument(
468
+ "k_test",
469
+ type=int,
470
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
471
+ )
472
+
473
+ return validator
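
A minimal sketch of driving the Config class above from a script; the YAML path is hypothetical and the referenced model/dataset builders are assumed to be registered already:

import argparse
from bubogpt.common.config import Config

parser = argparse.ArgumentParser()
parser.add_argument("--cfg-path", default="eval_configs/mmgpt4_eval.yaml")
parser.add_argument("--options", nargs="+", default=None)

# "--options key value" pairs are folded into "key=value" by _convert_to_dot_list,
# so this overrides run.seed from the command line.
args = parser.parse_args(["--options", "run.seed", "42"])

cfg = Config(args)
cfg.pretty_print()          # logs the run / dataset / model sections
print(cfg.run_cfg.seed)     # the user option wins over the YAML value
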
bubogpt/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+
107
+ def main_process(func):
108
+ @functools.wraps(func)
109
+ def wrapper(*args, **kwargs):
110
+ rank, _ = get_dist_info()
111
+ if rank == 0:
112
+ return func(*args, **kwargs)
113
+
114
+ return wrapper
115
+
116
+
117
+ def download_cached_file(url, check_hash=True, progress=False):
118
+ """
119
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121
+ """
122
+
123
+ def get_cached_file_path():
124
+ # a hack to sync the file path across processes
125
+ parts = torch.hub.urlparse(url)
126
+ filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
+
129
+ return cached_file
130
+
131
+ if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
+
134
+ if is_dist_avail_and_initialized():
135
+ dist.barrier()
136
+
137
+ return get_cached_file_path()
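
A sketch of how the helpers above are typically used in a launch script; it assumes torchrun (or nothing at all) provides the RANK/WORLD_SIZE/LOCAL_RANK environment variables, and falls back to single-process mode otherwise:

import argparse
from bubogpt.common.dist_utils import init_distributed_mode, get_rank, main_process

args = argparse.Namespace(dist_url="env://", distributed=True)
init_distributed_mode(args)   # prints "Not using distributed mode" when RANK is unset

@main_process
def save_checkpoint(path):
    # Only rank 0 executes this body; other ranks return None.
    print(f"rank {get_rank()} writing {path}")

save_checkpoint("checkpoint_0.pth")
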
bubogpt/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
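
A sketch of calling getAttMap above on toy data; it assumes a float RGB image in [0, 1] and a low-resolution attention map, which is what the function expects:

import numpy as np
from bubogpt.common.gradcam import getAttMap

img = np.random.rand(224, 224, 3)   # float RGB image in [0, 1]
att = np.random.rand(14, 14)        # e.g. a 14x14 patch-level attention map
overlay = getAttMap(img, att, blur=True, overlap=True)
print(overlay.shape)                # (224, 224, 3): the attention blended onto the image
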
bubogpt/common/logger.py ADDED
@@ -0,0 +1,195 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from bubogpt.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
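
A sketch of the intended MetricLogger / SmoothedValue usage above, with a list standing in for a dataloader (single process, so cross-process synchronization is a no-op):

import torch
from bubogpt.common.logger import MetricLogger, SmoothedValue, setup_logger

setup_logger()
metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

batches = [torch.rand(4, 3) for _ in range(10)]   # stand-in for a dataloader
for batch in metric_logger.log_every(batches, print_freq=5, header="Train:"):
    loss = batch.mean()                           # stand-in for a training step
    metric_logger.update(loss=loss, lr=1e-4)

print("Averaged stats:", metric_logger.global_avg())
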
bubogpt/common/optims.py ADDED
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from bubogpt.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
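
A sketch of stepping the registered LinearWarmupCosineLRScheduler above once per iteration, on a dummy SGD optimizer; the epoch/iteration counts are arbitrary:

import torch
from bubogpt.common.optims import LinearWarmupCosineLRScheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-4)
scheduler = LinearWarmupCosineLRScheduler(
    optimizer, max_epoch=3, iters_per_epoch=100,
    min_lr=1e-5, init_lr=1e-4, warmup_steps=50, warmup_start_lr=1e-6,
)

for epoch in range(3):
    for it in range(100):
        scheduler.step(cur_epoch=epoch, cur_step=it)  # linear warmup for 50 iters, cosine decay after
print(optimizer.param_groups[0]["lr"])                # close to min_lr at the end
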
bubogpt/common/registry.py ADDED
@@ -0,0 +1,333 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+
10
+ class Registry:
11
+ mapping = {
12
+ "builder_name_mapping": {},
13
+ "task_name_mapping": {},
14
+ "processor_name_mapping": {},
15
+ "model_name_mapping": {},
16
+ "lr_scheduler_name_mapping": {},
17
+ "runner_name_mapping": {},
18
+ "state": {},
19
+ "paths": {},
20
+ }
21
+
22
+ @classmethod
23
+ def register_builder(cls, name):
24
+ r"""Register a dataset builder to registry with key 'name'
25
+
26
+ Args:
27
+ name: Key with which the builder will be registered.
28
+
29
+ Usage:
30
+
31
+ from bubogpt.common.registry import registry
32
+ from bubogpt.datasets.base_dataset_builder import BaseDatasetBuilder
33
+ """
34
+
35
+ def wrap(builder_cls):
36
+ # TODO: merge them or split builders by modality
37
+ from bubogpt.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
38
+ from bubogpt.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
39
+ from bubogpt.datasets.builders.multimodal_base_dataset_builder import MultimodalBaseDatasetBuilder
40
+
41
+ assert issubclass(
42
+ builder_cls, (ImageBaseDatasetBuilder, AudioBaseDatasetBuilder, MultimodalBaseDatasetBuilder)
43
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
44
+ builder_cls
45
+ )
46
+ if name in cls.mapping["builder_name_mapping"]:
47
+ raise KeyError(
48
+ "Name '{}' already registered for {}.".format(
49
+ name, cls.mapping["builder_name_mapping"][name]
50
+ )
51
+ )
52
+ cls.mapping["builder_name_mapping"][name] = builder_cls
53
+ return builder_cls
54
+
55
+ return wrap
56
+
57
+ @classmethod
58
+ def register_task(cls, name):
59
+ r"""Register a task to registry with key 'name'
60
+
61
+ Args:
62
+ name: Key with which the task will be registered.
63
+
64
+ Usage:
65
+
66
+ from bubogpt.common.registry import registry
67
+ """
68
+
69
+ def wrap(task_cls):
70
+ from bubogpt.tasks.base_task import BaseTask
71
+
72
+ assert issubclass(
73
+ task_cls, BaseTask
74
+ ), "All tasks must inherit BaseTask class"
75
+ if name in cls.mapping["task_name_mapping"]:
76
+ raise KeyError(
77
+ "Name '{}' already registered for {}.".format(
78
+ name, cls.mapping["task_name_mapping"][name]
79
+ )
80
+ )
81
+ cls.mapping["task_name_mapping"][name] = task_cls
82
+ return task_cls
83
+
84
+ return wrap
85
+
86
+ @classmethod
87
+ def register_model(cls, name):
88
+ r"""Register a task to registry with key 'name'
89
+
90
+ Args:
91
+ name: Key with which the task will be registered.
92
+
93
+ Usage:
94
+
95
+ from bubogpt.common.registry import registry
96
+ """
97
+
98
+ def wrap(model_cls):
99
+ from bubogpt.models import BaseModel
100
+
101
+ assert issubclass(
102
+ model_cls, BaseModel
103
+ ), "All models must inherit BaseModel class"
104
+ if name in cls.mapping["model_name_mapping"]:
105
+ raise KeyError(
106
+ "Name '{}' already registered for {}.".format(
107
+ name, cls.mapping["model_name_mapping"][name]
108
+ )
109
+ )
110
+ cls.mapping["model_name_mapping"][name] = model_cls
111
+ return model_cls
112
+
113
+ return wrap
114
+
115
+ @classmethod
116
+ def register_processor(cls, name):
117
+ r"""Register a processor to registry with key 'name'
118
+
119
+ Args:
120
+ name: Key with which the task will be registered.
121
+
122
+ Usage:
123
+
124
+ from bubogpt.common.registry import registry
125
+ """
126
+
127
+ def wrap(processor_cls):
128
+ from bubogpt.processors import BaseProcessor
129
+
130
+ assert issubclass(
131
+ processor_cls, BaseProcessor
132
+ ), "All processors must inherit BaseProcessor class"
133
+ if name in cls.mapping["processor_name_mapping"]:
134
+ raise KeyError(
135
+ "Name '{}' already registered for {}.".format(
136
+ name, cls.mapping["processor_name_mapping"][name]
137
+ )
138
+ )
139
+ cls.mapping["processor_name_mapping"][name] = processor_cls
140
+ return processor_cls
141
+
142
+ return wrap
143
+
144
+ @classmethod
145
+ def register_lr_scheduler(cls, name):
146
+ r"""Register a model to registry with key 'name'
147
+
148
+ Args:
149
+ name: Key with which the task will be registered.
150
+
151
+ Usage:
152
+
153
+ from bubogpt.common.registry import registry
154
+ """
155
+
156
+ def wrap(lr_sched_cls):
157
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
158
+ raise KeyError(
159
+ "Name '{}' already registered for {}.".format(
160
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
161
+ )
162
+ )
163
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
164
+ return lr_sched_cls
165
+
166
+ return wrap
167
+
168
+ @classmethod
169
+ def register_runner(cls, name):
170
+ r"""Register a model to registry with key 'name'
171
+
172
+ Args:
173
+ name: Key with which the task will be registered.
174
+
175
+ Usage:
176
+
177
+ from bubogpt.common.registry import registry
178
+ """
179
+
180
+ def wrap(runner_cls):
181
+ if name in cls.mapping["runner_name_mapping"]:
182
+ raise KeyError(
183
+ "Name '{}' already registered for {}.".format(
184
+ name, cls.mapping["runner_name_mapping"][name]
185
+ )
186
+ )
187
+ cls.mapping["runner_name_mapping"][name] = runner_cls
188
+ return runner_cls
189
+
190
+ return wrap
191
+
192
+ @classmethod
193
+ def register_path(cls, name, path):
194
+ r"""Register a path to registry with key 'name'
195
+
196
+ Args:
197
+ name: Key with which the path will be registered.
198
+
199
+ Usage:
200
+
201
+ from bubogpt.common.registry import registry
202
+ """
203
+ assert isinstance(path, str), "All path must be str."
204
+ if name in cls.mapping["paths"]:
205
+ raise KeyError("Name '{}' already registered.".format(name))
206
+ cls.mapping["paths"][name] = path
207
+
208
+ @classmethod
209
+ def register(cls, name, obj):
210
+ r"""Register an item to registry with key 'name'
211
+
212
+ Args:
213
+ name: Key with which the item will be registered.
214
+
215
+ Usage::
216
+
217
+ from bubogpt.common.registry import registry
218
+
219
+ registry.register("config", {})
220
+ """
221
+ path = name.split(".")
222
+ current = cls.mapping["state"]
223
+
224
+ for part in path[:-1]:
225
+ if part not in current:
226
+ current[part] = {}
227
+ current = current[part]
228
+
229
+ current[path[-1]] = obj
230
+
231
+ # @classmethod
232
+ # def get_trainer_class(cls, name):
233
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
234
+
235
+ @classmethod
236
+ def get_builder_class(cls, name):
237
+ return cls.mapping["builder_name_mapping"].get(name, None)
238
+
239
+ @classmethod
240
+ def get_model_class(cls, name):
241
+ return cls.mapping["model_name_mapping"].get(name, None)
242
+
243
+ @classmethod
244
+ def get_task_class(cls, name):
245
+ return cls.mapping["task_name_mapping"].get(name, None)
246
+
247
+ @classmethod
248
+ def get_processor_class(cls, name):
249
+ return cls.mapping["processor_name_mapping"].get(name, None)
250
+
251
+ @classmethod
252
+ def get_lr_scheduler_class(cls, name):
253
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
254
+
255
+ @classmethod
256
+ def get_runner_class(cls, name):
257
+ return cls.mapping["runner_name_mapping"].get(name, None)
258
+
259
+ @classmethod
260
+ def list_runners(cls):
261
+ return sorted(cls.mapping["runner_name_mapping"].keys())
262
+
263
+ @classmethod
264
+ def list_models(cls):
265
+ return sorted(cls.mapping["model_name_mapping"].keys())
266
+
267
+ @classmethod
268
+ def list_tasks(cls):
269
+ return sorted(cls.mapping["task_name_mapping"].keys())
270
+
271
+ @classmethod
272
+ def list_processors(cls):
273
+ return sorted(cls.mapping["processor_name_mapping"].keys())
274
+
275
+ @classmethod
276
+ def list_lr_schedulers(cls):
277
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
278
+
279
+ @classmethod
280
+ def list_datasets(cls):
281
+ return sorted(cls.mapping["builder_name_mapping"].keys())
282
+
283
+ @classmethod
284
+ def get_path(cls, name):
285
+ return cls.mapping["paths"].get(name, None)
286
+
287
+ @classmethod
288
+ def get(cls, name, default=None, no_warning=False):
289
+ r"""Get an item from registry with key 'name'
290
+
291
+ Args:
292
+ name (string): Key whose value needs to be retrieved.
293
+ default: If passed and key is not in registry, default value will
294
+ be returned with a warning. Default: None
295
+ no_warning (bool): If passed as True, warning when key doesn't exist
296
+ will not be generated. Useful for MMF's
297
+ internal operations. Default: False
298
+ """
299
+ original_name = name
300
+ name = name.split(".")
301
+ value = cls.mapping["state"]
302
+ for subname in name:
303
+ value = value.get(subname, default)
304
+ if value is default:
305
+ break
306
+
307
+ if (
308
+ "writer" in cls.mapping["state"]
309
+ and value == default
310
+ and no_warning is False
311
+ ):
312
+ cls.mapping["state"]["writer"].warning(
313
+ "Key {} is not present in registry, returning default value "
314
+ "of {}".format(original_name, default)
315
+ )
316
+ return value
317
+
318
+ @classmethod
319
+ def unregister(cls, name):
320
+ r"""Remove an item from registry with key 'name'
321
+
322
+ Args:
323
+ name: Key which needs to be removed.
324
+ Usage::
325
+
326
+ from mmf.common.registry import registry
327
+
328
+ config = registry.unregister("config")
329
+ """
330
+ return cls.mapping["state"].pop(name, None)
331
+
332
+
333
+ registry = Registry()
bubogpt/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from bubogpt.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
bubogpt/configs/datasets/aud_img_neg/default.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+ aud_img_neg:
3
+ data_type: audio_image
4
+ build_info:
5
+ image:
6
+ storage: /path/to/cc_sbu_align
7
+ ann_files: ['filter_cap.json']
8
+ audio:
9
+ storage: /path/to/clotho
10
+ ann_files: ['audio_cap.json']
bubogpt/configs/datasets/audioset/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ audioset:
3
+ data_type: audio
4
+ build_info:
5
+ storage: /path/to/AudioSet_SL/AudioSet_SL{00..54}.tar
bubogpt/configs/datasets/bbc/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ bbc:
3
+ data_type: audio
4
+ build_info:
5
+ storage: /path/to/BBC_Sound_Effects/BBC_Sound_Effects{000000..000062}.tar
bubogpt/configs/datasets/cc12m/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc12m:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc12m_web/{000000..002221}.tar
bubogpt/configs/datasets/cc_sbu/align.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc_sbu_align:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc_sbu_align
bubogpt/configs/datasets/cc_sbu/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc_sbu:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc_sbu_dataset/{00000..01255}.tar
bubogpt/configs/datasets/clotho/align.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ clotho_align:
3
+ data_type: audio
4
+ build_info:
5
+ storage: /path/to/clotho
bubogpt/configs/datasets/freesound/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ freesound:
3
+ data_type: audio
4
+ build_info:
5
+ storage: /path/to/wavcaps/web_datasets/FreeSound/FreeSound{000000..000524}.tar
bubogpt/configs/datasets/laion/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ laion:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/laion_dataset/{00000..10488}.tar
bubogpt/configs/datasets/soundbible/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ soundbible:
3
+ data_type: audio
4
+ build_info:
5
+ storage: /path/to/SoundBible0.tar
bubogpt/configs/datasets/vggss/align.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ vggss_align:
3
+ data_type: audio_image
4
+ build_info:
5
+ storage: /path/to/vggss
6
+ ann_files: ["vggss_mult_prefix.json"]
bubogpt/configs/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ env:
2
+ # For default users
3
+ # cache_root: "cache"
4
+ # For internal use with persistent storage
5
+ cache_root: "/export/home/.cache/bubogpt"
bubogpt/configs/models/mmgpt4.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: mm_gpt4
3
+
4
+ # Imagebind
5
+ freeze_imagebind: True
6
+
7
+ # Q-Former
8
+ freeze_qformer: True
9
+ q_former_model: "/path/to/blip2_pretrained_flant5xxl.pth"
10
+ num_query_token: 32
11
+
12
+ # Vicuna
13
+ llama_model: "/path/to/vicuna-7b-v0/"
14
+
15
+ # generation configs
16
+ prompt: ""
17
+
18
+ preprocess:
19
+ vis_processor:
20
+ train:
21
+ name: "imagebind_vision_train"
22
+ image_size: 224
23
+ eval:
24
+ name: "imagebind_vision_eval"
25
+ image_size: 224
26
+ text_processor:
27
+ train:
28
+ name: "imagebind_caption"
29
+ eval:
30
+ name: "imagebind_caption"
bubogpt/datasets/__init__.py ADDED
File without changes
bubogpt/datasets/builders/__init__.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from bubogpt.datasets.builders.image_base_dataset_builder import load_dataset_config
9
+ from bubogpt.datasets.builders.image_text_pair_builder import (
10
+ CCSBUBuilderImage,
11
+ LaionBuilderImage,
12
+ CCSBUAlignBuilderImage,
13
+ LlavaInstruct150Builder,
14
+ )
15
+ from bubogpt.datasets.builders.audio_text_pair_builder import (
16
+ BBCBuilder,
17
+ AudioSetBuilder,
18
+ SoundBibleBuilder,
19
+ FreeSoundBuilder
20
+ )
21
+ from bubogpt.datasets.builders.audio_image_text_builder import (
22
+ VGGSSBuilderAudioImage
23
+ )
24
+ from bubogpt.common.registry import registry
25
+
26
+ __all__ = [
27
+ "CCSBUBuilderImage",
28
+ "LaionBuilderImage",
29
+ "CCSBUAlignBuilderImage",
30
+ "LlavaInstruct150Builder",
31
+ # Audio builders
32
+ "BBCBuilder",
33
+ "AudioSetBuilder",
34
+ "SoundBibleBuilder",
35
+ "FreeSoundBuilder",
36
+ # Audio Image builders
37
+ "VGGSSBuilderAudioImage"
38
+ ]
39
+
40
+
41
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
42
+ """
43
+ Example
44
+
45
+ >>> dataset = load_dataset("coco_caption", cfg=None)
46
+ >>> splits = dataset.keys()
47
+ >>> print([len(dataset[split]) for split in splits])
48
+
49
+ """
50
+ if cfg_path is None:
51
+ cfg = None
52
+ else:
53
+ cfg = load_dataset_config(cfg_path)
54
+
55
+ try:
56
+ builder = registry.get_builder_class(name)(cfg)
57
+ except TypeError:
58
+ print(
59
+ f"Dataset {name} not found. Available datasets:\n"
60
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
61
+ )
62
+ exit(1)
63
+
64
+ if vis_path is not None:
65
+ if data_type is None:
66
+ # use default data type in the config
67
+ data_type = builder.config.data_type
68
+
69
+ assert (
70
+ data_type in builder.config.build_info
71
+ ), f"Invalid data_type {data_type} for {name}."
72
+
73
+ builder.config.build_info.get(data_type).storage = vis_path
74
+
75
+ dataset = builder.build_datasets()
76
+ return dataset
77
+
78
+
79
+ class DatasetZoo:
80
+ def __init__(self) -> None:
81
+ self.dataset_zoo = {
82
+ k: list(v.DATASET_CONFIG_DICT.keys())
83
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
84
+ }
85
+
86
+ def get_names(self):
87
+ return list(self.dataset_zoo.keys())
88
+
89
+
90
+ dataset_zoo = DatasetZoo()
bubogpt/datasets/builders/audio_base_dataset_builder.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import shutil
4
+ import warnings
5
+
6
+ from omegaconf import OmegaConf
7
+ import torch.distributed as dist
8
+ from torchvision.datasets.utils import download_url
9
+
10
+ import bubogpt.common.utils as utils
11
+ from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process
12
+ from bubogpt.common.registry import registry
13
+ from bubogpt.datasets.builders import load_dataset_config
14
+ from bubogpt.processors.base_processor import BaseProcessor
15
+
16
+
17
+ class AudioBaseDatasetBuilder:
18
+ train_dataset_cls, eval_dataset_cls = None, None
19
+
20
+ def __init__(self, cfg=None):
21
+ super().__init__()
22
+
23
+ if cfg is None:
24
+ # help to create datasets from default config.
25
+ self.config = load_dataset_config(self.default_config_path())
26
+ elif isinstance(cfg, str):
27
+ self.config = load_dataset_config(cfg)
28
+ else:
29
+ # when called from task.build_dataset()
30
+ self.config = cfg
31
+
32
+ self.data_type = self.config.data_type
33
+
34
+ self.audio_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
35
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
36
+
37
+ def build_datasets(self):
38
+ # download, split, etc...
39
+ # only called on 1 GPU/TPU in distributed
40
+
41
+ if is_main_process():
42
+ self._download_data()
43
+
44
+ if is_dist_avail_and_initialized():
45
+ dist.barrier()
46
+
47
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
48
+ logging.info("Building datasets...")
49
+ datasets = self.build() # dataset['train'/'val'/'test']
50
+
51
+ return datasets
52
+
53
+ def build_processors(self):
54
+ aud_proc_cfg = self.config.get("audio_processor")
55
+ txt_proc_cfg = self.config.get("text_processor")
56
+
57
+ if aud_proc_cfg is not None:
58
+ aud_train_cfg = aud_proc_cfg.get("train")
59
+ aud_eval_cfg = aud_proc_cfg.get("eval")
60
+
61
+ self.audio_processors["train"] = self._build_proc_from_cfg(aud_train_cfg)
62
+ self.audio_processors["eval"] = self._build_proc_from_cfg(aud_eval_cfg)
63
+
64
+ if txt_proc_cfg is not None:
65
+ txt_train_cfg = txt_proc_cfg.get("train")
66
+ txt_eval_cfg = txt_proc_cfg.get("eval")
67
+
68
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
69
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
70
+
71
+ @staticmethod
72
+ def _build_proc_from_cfg(cfg):
73
+ return (
74
+ registry.get_processor_class(cfg.name).from_config(cfg)
75
+ if cfg is not None
76
+ else None
77
+ )
78
+
79
+ @classmethod
80
+ def default_config_path(cls, type="default"):
81
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
82
+
83
+ def _download_data(self):
84
+ self._download_ann()
85
+ self._download_aud()
86
+
87
+ def _download_ann(self):
88
+ """
89
+ Download annotation files if necessary.
90
+ All the audio-language datasets should have annotations of unified format.
91
+
92
+ storage_path can be:
93
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
94
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
95
+
96
+ Local annotation paths should be relative.
97
+ """
98
+ anns = self.config.build_info.annotations
99
+
100
+ splits = anns.keys()
101
+
102
+ cache_root = registry.get_path("cache_root")
103
+
104
+ for split in splits:
105
+ info = anns[split]
106
+
107
+ urls, storage_paths = info.get("url", None), info.storage
108
+
109
+ if isinstance(urls, str):
110
+ urls = [urls]
111
+ if isinstance(storage_paths, str):
112
+ storage_paths = [storage_paths]
113
+
114
+ assert len(urls) == len(storage_paths)
115
+
116
+ for url_or_filename, storage_path in zip(urls, storage_paths):
117
+ # if storage_path is relative, make it full by prefixing with cache_root.
118
+ if not os.path.isabs(storage_path):
119
+ storage_path = os.path.join(cache_root, storage_path)
120
+
121
+ dirname = os.path.dirname(storage_path)
122
+ if not os.path.exists(dirname):
123
+ os.makedirs(dirname)
124
+
125
+ if os.path.isfile(url_or_filename):
126
+ src, dst = url_or_filename, storage_path
127
+ if not os.path.exists(dst):
128
+ shutil.copyfile(src=src, dst=dst)
129
+ else:
130
+ logging.info("Using existing file {}.".format(dst))
131
+ else:
132
+ if os.path.isdir(storage_path):
133
+ # if only dirname is provided, suffix with basename of URL.
134
+ raise ValueError(
135
+ "Expecting storage_path to be a file path, got directory {}".format(
136
+ storage_path
137
+ )
138
+ )
139
+ else:
140
+ filename = os.path.basename(storage_path)
141
+
142
+ download_url(url=url_or_filename, root=dirname, filename=filename)
bubogpt/datasets/builders/audio_image_text_builder.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import warnings
4
+
5
+ from bubogpt.common.registry import registry
6
+ from bubogpt.datasets.builders.multimodal_base_dataset_builder import MultimodalBaseDatasetBuilder
7
+ from bubogpt.datasets.datasets.audio_image.audio_image_datasets import AudioLocalizationDataset, AudioImageNegDataset
8
+
9
+
10
+ @registry.register_builder("vggss_align")
11
+ class VGGSSBuilderAudioImage(MultimodalBaseDatasetBuilder):
12
+ train_dataset_cls = AudioLocalizationDataset
13
+
14
+ DATASET_CONFIG_DICT = {
15
+ "default": "configs/datasets/vggss/align.yaml",
16
+ "3k": "configs/datasets/vggss/align3k.yaml",
17
+ "5k": "configs/datasets/vggss/align5k.yaml",
18
+ "31k": "configs/datasets/vggss/align31k.yaml",
19
+ }
20
+
21
+ def build_datasets(self):
22
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
23
+ logging.info("Building datasets...")
24
+ self.build_processors()
25
+
26
+ build_info = self.config.build_info
27
+ storage_path = build_info.storage
28
+
29
+ datasets = dict()
30
+
31
+ if not os.path.exists(storage_path):
32
+ warnings.warn("storage path {} does not exist.".format(storage_path))
33
+ print("Building datasets with: ", self.get_ann_files())
34
+
35
+ # create datasets
36
+ dataset_cls = self.train_dataset_cls
37
+ datasets['train'] = dataset_cls(
38
+ processors={**{
39
+ modal: self.processors[modal]["train"] for modal in self.data_type
40
+ }, **{
41
+ "text": self.processors["text"]["train"]
42
+ }},
43
+ roots={
44
+ modal: os.path.join(storage_path, f"{modal}s") for modal in self.data_type
45
+ },
46
+ # ann_paths=[os.path.join(storage_path, 'vggsound_balanced.json')],
47
+ ann_paths=self.get_ann_files(),
48
+ )
49
+
50
+ return datasets
51
+
52
+ def get_ann_files(self):
53
+ ann_files = self.config.build_info.get("ann_files", ["vggsound_balanced.json"])
54
+ return [os.path.join(self.config.build_info.storage, fname) for fname in ann_files]
55
+
56
+
57
+ @registry.register_builder("aud_img_neg")
58
+ class NegBuilderAudioImage(MultimodalBaseDatasetBuilder):
59
+ train_dataset_cls = AudioImageNegDataset
60
+
61
+ DATASET_CONFIG_DICT = {
62
+ "default": "configs/datasets/aud_img_neg/default.yaml",
63
+ }
64
+
65
+ def build_datasets(self):
66
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
67
+ logging.info("Building datasets...")
68
+ self.build_processors()
69
+
70
+ build_info = self.config.build_info
71
+ # storage_path = build_info.storage
72
+ storage_path = {
73
+ "image": build_info.image.storage,
74
+ "audio": build_info.audio.storage,
75
+ }
76
+ ann_files = {
77
+ "image": build_info.image.ann_files,
78
+ "audio": build_info.audio.ann_files,
79
+ }
80
+ ann_paths = {
81
+ modal: [os.path.join(storage_path[modal], fname) for fname in ann_files[modal]] for modal in self.data_type
82
+ }
83
+
84
+ datasets = dict()
85
+
86
+ for path in storage_path.values():
87
+ if not os.path.exists(path):
88
+ warnings.warn("storage path {} does not exist.".format(path))
89
+ print("Building datasets with: ", ann_paths)
90
+
91
+ # create datasets
92
+ dataset_cls = self.train_dataset_cls
93
+ datasets['train'] = dataset_cls(
94
+ processors={**{
95
+ modal: self.processors[modal]["train"] for modal in self.data_type
96
+ }, **{
97
+ "text": self.processors["text"]["train"]
98
+ }},
99
+ roots={
100
+ modal: os.path.join(storage_path[modal], f"{modal}") for modal in self.data_type
101
+ },
102
+ ann_paths=ann_paths,
103
+ )
104
+
105
+ return datasets
bubogpt/datasets/builders/audio_text_pair_builder.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from bubogpt.common.registry import registry
6
+ from bubogpt.datasets.builders.audio_base_dataset_builder import AudioBaseDatasetBuilder
7
+ from bubogpt.datasets.datasets.audio_caption import GenericAudioDataset, AudioCaptionDataset
8
+
9
+
10
+ class GenericAudioBuilder(AudioBaseDatasetBuilder):
11
+ train_dataset_cls = GenericAudioDataset
12
+
13
+ def _download_ann(self):
14
+ pass
15
+
16
+ def _download_aud(self):
17
+ pass
18
+
19
+ def build(self):
20
+ self.build_processors()
21
+
22
+ build_info = self.config.build_info
23
+
24
+ datasets = dict()
25
+ split = "train"
26
+
27
+ # create datasets
28
+ dataset_cls = self.train_dataset_cls
29
+ datasets[split] = dataset_cls(
30
+ audio_processor=self.audio_processors[split],
31
+ text_processor=self.text_processors[split],
32
+ location=build_info.storage,
33
+ ).inner_dataset
34
+
35
+ return datasets
36
+
37
+
38
+ @registry.register_builder("bbc")
39
+ class BBCBuilder(GenericAudioBuilder):
40
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/bbc/defaults.yaml"}
41
+
42
+
43
+ @registry.register_builder("audioset")
44
+ class AudioSetBuilder(GenericAudioBuilder):
45
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/audioset/defaults.yaml"}
46
+
47
+
48
+ @registry.register_builder("soundbible")
49
+ class SoundBibleBuilder(GenericAudioBuilder):
50
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/soundbible/defaults.yaml"}
51
+
52
+
53
+ @registry.register_builder("freesound")
54
+ class FreeSoundBuilder(GenericAudioBuilder):
55
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/freesound/defaults.yaml"}
56
+
57
+
58
+ @registry.register_builder("clotho_align")
59
+ class ClothoAlignBuilderAudio(GenericAudioBuilder):
60
+ train_dataset_cls = AudioCaptionDataset
61
+
62
+ DATASET_CONFIG_DICT = {
63
+ "default": "configs/datasets/clotho/align.yaml",
64
+ }
65
+
66
+ def build_datasets(self):
67
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
68
+ logging.info("Building datasets...")
69
+ self.build_processors()
70
+
71
+ build_info = self.config.build_info
72
+ storage_path = build_info.storage
73
+
74
+ datasets = dict()
75
+
76
+ if not os.path.exists(storage_path):
77
+ warnings.warn("storage path {} does not exist.".format(storage_path))
78
+
79
+ # create datasets
80
+ dataset_cls = self.train_dataset_cls
81
+ datasets['train'] = dataset_cls(
82
+ audio_processor=self.audio_processors["train"],
83
+ text_processor=self.text_processors["train"],
84
+ audio_root=os.path.join(storage_path, 'all'),
85
+ ann_paths=[os.path.join(storage_path, 'audio_cap.json')],
86
+ )
87
+
88
+ return datasets
bubogpt/datasets/builders/image_base_dataset_builder.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is from
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import warnings
13
+
14
+ from omegaconf import OmegaConf
15
+ import torch.distributed as dist
16
+ from torchvision.datasets.utils import download_url
17
+
18
+ import bubogpt.common.utils as utils
19
+ from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20
+ from bubogpt.common.registry import registry
21
+ from bubogpt.processors.base_processor import BaseProcessor
22
+
23
+
24
+ class ImageBaseDatasetBuilder:
25
+ train_dataset_cls, eval_dataset_cls = None, None
26
+
27
+ def __init__(self, cfg=None):
28
+ super().__init__()
29
+
30
+ if cfg is None:
31
+ # help to create datasets from default config.
32
+ self.config = load_dataset_config(self.default_config_path())
33
+ elif isinstance(cfg, str):
34
+ self.config = load_dataset_config(cfg)
35
+ else:
36
+ # when called from task.build_dataset()
37
+ self.config = cfg
38
+
39
+ self.data_type = self.config.data_type
40
+
41
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
42
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
+
44
+ def build_datasets(self):
45
+ # download, split, etc...
46
+ # only called on 1 GPU/TPU in distributed
47
+
48
+ if is_main_process():
49
+ self._download_data()
50
+
51
+ if is_dist_avail_and_initialized():
52
+ dist.barrier()
53
+
54
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
55
+ logging.info("Building datasets...")
56
+ datasets = self.build() # dataset['train'/'val'/'test']
57
+
58
+ return datasets
59
+
60
+ def build_processors(self):
61
+ vis_proc_cfg = self.config.get("vis_processor")
62
+ txt_proc_cfg = self.config.get("text_processor")
63
+
64
+ if vis_proc_cfg is not None:
65
+ vis_train_cfg = vis_proc_cfg.get("train")
66
+ vis_eval_cfg = vis_proc_cfg.get("eval")
67
+
68
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
69
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
70
+
71
+ if txt_proc_cfg is not None:
72
+ txt_train_cfg = txt_proc_cfg.get("train")
73
+ txt_eval_cfg = txt_proc_cfg.get("eval")
74
+
75
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
76
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
77
+
78
+ @staticmethod
79
+ def _build_proc_from_cfg(cfg):
80
+ return (
81
+ registry.get_processor_class(cfg.name).from_config(cfg)
82
+ if cfg is not None
83
+ else None
84
+ )
85
+
86
+ @classmethod
87
+ def default_config_path(cls, type="default"):
88
+ if cls.DATASET_CONFIG_DICT[type] is None:
89
+ return None
90
+ else:
91
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
92
+
93
+ def _download_data(self):
94
+ self._download_ann()
95
+ self._download_vis()
96
+
97
+ def _download_ann(self):
98
+ """
99
+ Download annotation files if necessary.
100
+ All the vision-language datasets should have annotations of unified format.
101
+
102
+ storage_path can be:
103
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
104
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
105
+
106
+ Local annotation paths should be relative.
107
+ """
108
+ anns = self.config.build_info.annotations
109
+
110
+ splits = anns.keys()
111
+
112
+ cache_root = registry.get_path("cache_root")
113
+
114
+ for split in splits:
115
+ info = anns[split]
116
+
117
+ urls, storage_paths = info.get("url", None), info.storage
118
+
119
+ if isinstance(urls, str):
120
+ urls = [urls]
121
+ if isinstance(storage_paths, str):
122
+ storage_paths = [storage_paths]
123
+
124
+ assert len(urls) == len(storage_paths)
125
+
126
+ for url_or_filename, storage_path in zip(urls, storage_paths):
127
+ # if storage_path is relative, make it full by prefixing with cache_root.
128
+ if not os.path.isabs(storage_path):
129
+ storage_path = os.path.join(cache_root, storage_path)
130
+
131
+ dirname = os.path.dirname(storage_path)
132
+ if not os.path.exists(dirname):
133
+ os.makedirs(dirname)
134
+
135
+ if os.path.isfile(url_or_filename):
136
+ src, dst = url_or_filename, storage_path
137
+ if not os.path.exists(dst):
138
+ shutil.copyfile(src=src, dst=dst)
139
+ else:
140
+ logging.info("Using existing file {}.".format(dst))
141
+ else:
142
+ if os.path.isdir(storage_path):
143
+ # if only dirname is provided, suffix with basename of URL.
144
+ raise ValueError(
145
+ "Expecting storage_path to be a file path, got directory {}".format(
146
+ storage_path
147
+ )
148
+ )
149
+ else:
150
+ filename = os.path.basename(storage_path)
151
+
152
+ download_url(url=url_or_filename, root=dirname, filename=filename)
153
+
154
+ def _download_vis(self):
155
+
156
+ storage_path = self.config.build_info.get(self.data_type).storage
157
+ storage_path = utils.get_cache_path(storage_path)
158
+
159
+ if not os.path.exists(storage_path):
160
+ warnings.warn(
161
+ f"""
162
+ The specified path {storage_path} for visual inputs does not exist.
163
+ Please provide a correct path to the visual inputs or
164
+ refer to datasets/download_scripts/README.md for downloading instructions.
165
+ """
166
+ )
167
+
168
+ def build(self):
169
+ """
170
+ Create by split datasets inheriting torch.utils.data.Datasets.
171
+
172
+ # build() can be dataset-specific. Overwrite to customize.
173
+ """
174
+ self.build_processors()
175
+
176
+ build_info = self.config.build_info
177
+
178
+ ann_info = build_info.annotations
179
+ vis_info = build_info.get(self.data_type)
180
+
181
+ datasets = dict()
182
+ for split in ann_info.keys():
183
+ if split not in ["train", "val", "test"]:
184
+ continue
185
+
186
+ is_train = split == "train"
187
+
188
+ # processors
189
+ vis_processor = (
190
+ self.vis_processors["train"]
191
+ if is_train
192
+ else self.vis_processors["eval"]
193
+ )
194
+ text_processor = (
195
+ self.text_processors["train"]
196
+ if is_train
197
+ else self.text_processors["eval"]
198
+ )
199
+
200
+ # annotation path
201
+ ann_paths = ann_info.get(split).storage
202
+ if isinstance(ann_paths, str):
203
+ ann_paths = [ann_paths]
204
+
205
+ abs_ann_paths = []
206
+ for ann_path in ann_paths:
207
+ if not os.path.isabs(ann_path):
208
+ ann_path = utils.get_cache_path(ann_path)
209
+ abs_ann_paths.append(ann_path)
210
+ ann_paths = abs_ann_paths
211
+
212
+ # visual data storage path
213
+ vis_path = os.path.join(vis_info.storage, split)
214
+
215
+ if not os.path.isabs(vis_path):
216
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
217
+ vis_path = utils.get_cache_path(vis_path)
218
+
219
+ if not os.path.exists(vis_path):
220
+ warnings.warn("storage path {} does not exist.".format(vis_path))
221
+
222
+ # create datasets
223
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
224
+ datasets[split] = dataset_cls(
225
+ vis_processor=vis_processor,
226
+ text_processor=text_processor,
227
+ ann_paths=ann_paths,
228
+ vis_root=vis_path,
229
+ )
230
+
231
+ return datasets
232
+
233
+
234
+ def load_dataset_config(cfg_path):
235
+ cfg = OmegaConf.load(cfg_path).datasets
236
+ cfg = cfg[list(cfg.keys())[0]]
237
+
238
+ return cfg
bubogpt/datasets/builders/image_text_pair_builder.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from bubogpt.common.registry import registry
6
+ from bubogpt.datasets.builders.image_base_dataset_builder import ImageBaseDatasetBuilder
7
+ from bubogpt.datasets.datasets.image_caption.laion_dataset import LaionDataset
8
+ from bubogpt.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset, \
9
+ CCSBUAlignDatasetImageImageCaptionDataset, CCDataset
10
+ from bubogpt.datasets.datasets.image_caption.llava_dataset import LlavaInstruct150Dataset
11
+
12
+ @registry.register_builder("cc_sbu")
13
+ class CCSBUBuilderImage(ImageBaseDatasetBuilder):
14
+ train_dataset_cls = CCSBUDataset
15
+
16
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
17
+
18
+ def _download_ann(self):
19
+ pass
20
+
21
+ def _download_vis(self):
22
+ pass
23
+
24
+ def build(self):
25
+ self.build_processors()
26
+
27
+ build_info = self.config.build_info
28
+
29
+ datasets = dict()
30
+ split = "train"
31
+
32
+ # create datasets
33
+ # [NOTE] return inner_datasets (wds.DataPipeline)
34
+ dataset_cls = self.train_dataset_cls
35
+ datasets[split] = dataset_cls(
36
+ vision_processor=self.vis_processors[split],
37
+ text_processor=self.text_processors[split],
38
+ location=build_info.storage,
39
+ ).inner_dataset
40
+
41
+ return datasets
42
+
43
+
44
+ @registry.register_builder("laion")
45
+ class LaionBuilderImage(ImageBaseDatasetBuilder):
46
+ train_dataset_cls = LaionDataset
47
+
48
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
49
+
50
+ def _download_ann(self):
51
+ pass
52
+
53
+ def _download_vis(self):
54
+ pass
55
+
56
+ def build(self):
57
+ self.build_processors()
58
+
59
+ build_info = self.config.build_info
60
+
61
+ datasets = dict()
62
+ split = "train"
63
+
64
+ # create datasets
65
+ # [NOTE] return inner_datasets (wds.DataPipeline)
66
+ dataset_cls = self.train_dataset_cls
67
+ datasets[split] = dataset_cls(
68
+ vision_processor=self.vis_processors[split],
69
+ text_processor=self.text_processors[split],
70
+ location=build_info.storage,
71
+ ).inner_dataset
72
+
73
+ return datasets
74
+
75
+
76
+ @registry.register_builder("cc_sbu_align")
77
+ class CCSBUAlignBuilderImage(ImageBaseDatasetBuilder):
78
+ train_dataset_cls = CCSBUAlignDatasetImageImageCaptionDataset
79
+
80
+ DATASET_CONFIG_DICT = {
81
+ "default": "configs/datasets/cc_sbu/align.yaml",
82
+ }
83
+
84
+ def build_datasets(self):
85
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
86
+ logging.info("Building datasets...")
87
+ self.build_processors()
88
+
89
+ build_info = self.config.build_info
90
+ storage_path = build_info.storage
91
+
92
+ datasets = dict()
93
+
94
+ if not os.path.exists(storage_path):
95
+ warnings.warn("storage path {} does not exist.".format(storage_path))
96
+
97
+ # create datasets
98
+ dataset_cls = self.train_dataset_cls
99
+ datasets['train'] = dataset_cls(
100
+ vision_processor=self.vis_processors["train"],
101
+ text_processor=self.text_processors["train"],
102
+ ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
103
+ vis_root=os.path.join(storage_path, 'image'),
104
+ )
105
+
106
+ return datasets
107
+
108
+
109
+ @registry.register_builder("cc12m")
110
+ class CC12MBuilder(ImageBaseDatasetBuilder):
111
+ train_dataset_cls = CCDataset
112
+
113
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/cc12m/defaults.yaml"}
114
+
115
+ def _download_ann(self):
116
+ pass
117
+
118
+ def _download_vis(self):
119
+ pass
120
+
121
+ def build(self):
122
+ self.build_processors()
123
+
124
+ build_info = self.config.build_info
125
+
126
+ datasets = dict()
127
+ split = "train"
128
+
129
+ # create datasets
130
+ # [NOTE] return inner_datasets (wds.DataPipeline)
131
+ dataset_cls = self.train_dataset_cls
132
+ datasets[split] = dataset_cls(
133
+ vis_processor=self.vis_processors[split],
134
+ text_processor=self.text_processors[split],
135
+ location=build_info.storage,
136
+ ).inner_dataset
137
+
138
+ return datasets
139
+
140
+
141
+ @registry.register_builder("llava_instruct150")
142
+ class LlavaInstruct150Builder(ImageBaseDatasetBuilder):
143
+ train_dataset_cls = LlavaInstruct150Dataset
144
+
145
+ DATASET_CONFIG_DICT = {"default": None}
146
+
147
+ def _download_ann(self):
148
+ pass
149
+
150
+ def _download_vis(self):
151
+ pass
152
+
153
+
154
+ def build(self):
155
+ self.build_processors()
156
+
157
+ datasets = dict()
158
+ split = "train"
159
+ dataset_cls = self.train_dataset_cls
160
+ datasets[split] = dataset_cls(
161
+ vis_processor=self.vis_processors[split],
162
+ text_processor=self.text_processors[split],
163
+ vis_root="/path/to/dataset/COCO_2014",
164
+ ann_paths=[os.path.join("/path/to/dataset/llava/annotations", subset + '.json')
165
+ for subset in ["complex_reasoning_77k", "conversation_58k", "detail_23k"]],
166
+ )
167
+ return datasets
168
+
169
+
170
+ # from bubogpt.datasets.builders.image_text_pair_builder import LlavaInstruct150Builder
171
+
172
+ if __name__ == "__main__":
173
+ from omegaconf import OmegaConf
174
+ from itertools import islice
175
+
176
+ data_cfg = OmegaConf.create({
177
+ "vis_processor": {"train": {"name": "imagebind_vision_train", "image_size": 224}},
178
+ "text_processor": {"train": {"name": "imagebind_caption"}},
179
+ "data_type": "image",
180
+ })
181
+
182
+ builder = LlavaInstruct150Builder(data_cfg)
183
+
184
+ datasets = builder.build_datasets()
185
+
186
+ datasets["train"].check_existence()
187
+
188
+ for sample in islice(datasets["train"], 10):
189
+ print(sample["vision"].shape, sample["prompt"], sample["text_input"])
bubogpt/datasets/builders/multimodal_base_dataset_builder.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch.distributed as dist
4
+
5
+ import bubogpt.common.utils as utils
6
+ from bubogpt.common.dist_utils import is_dist_avail_and_initialized, is_main_process
7
+ from bubogpt.common.registry import registry
8
+ from bubogpt.datasets.builders import load_dataset_config
9
+ from bubogpt.processors.base_processor import BaseProcessor
10
+
11
+
12
+ class MultimodalBaseDatasetBuilder():
13
+ train_dataset_cls, eval_dataset_cls = None, None
14
+
15
+ def __init__(self, cfg=None):
16
+ super().__init__()
17
+
18
+ if cfg is None:
19
+ # help to create datasets from default config.
20
+ self.config = load_dataset_config(self.default_config_path())
21
+ elif isinstance(cfg, str):
22
+ self.config = load_dataset_config(cfg)
23
+ else:
24
+ # when called from task.build_dataset()
25
+ self.config = cfg
26
+
27
+ self.data_type = self.config.data_type.split("_")
28
+ # It will be a list like ["audio", "image"], etc.
29
+
30
+ # Add "text" manually here.
31
+
32
+ self.processors = {modal: {"train": BaseProcessor(), "eval": BaseProcessor()}
33
+ for modal in [*self.data_type, "text"]}
34
+
35
+ def build_datasets(self):
36
+ # download, split, etc...
37
+ # only called on 1 GPU/TPU in distributed
38
+
39
+ if is_main_process():
40
+ self._download_data()
41
+
42
+ if is_dist_avail_and_initialized():
43
+ dist.barrier()
44
+
45
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
46
+ logging.info("Building datasets...")
47
+ datasets = self.build() # dataset['train'/'val'/'test']
48
+
49
+ return datasets
50
+
51
+ def build_processors(self):
52
+ for modal in [*self.data_type, "text"]:
53
+ proc_cfg = self.config.get("{}_processor".format(modal))
54
+ if proc_cfg is not None:
55
+ train_cfg = proc_cfg.get("train")
56
+ eval_cfg = proc_cfg.get("eval")
57
+ self.processors[modal]["train"] = self._build_proc_from_cfg(train_cfg)
58
+ self.processors[modal]["eval"] = self._build_proc_from_cfg(eval_cfg)
59
+
60
+
61
+ @staticmethod
62
+ def _build_proc_from_cfg(cfg):
63
+ return (
64
+ registry.get_processor_class(cfg.name).from_config(cfg)
65
+ if cfg is not None
66
+ else None
67
+ )
68
+
69
+ @classmethod
70
+ def default_config_path(cls, type="default"):
71
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
72
+
73
+ def _download_data(self):
74
+ pass
bubogpt/datasets/data_utils.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import random
10
+ from typing import List, Iterable
11
+
12
+ import decord
13
+ import webdataset as wds
14
+ import torch
15
+ from torch.utils.data import IterableDataset, Dataset, ConcatDataset
16
+
17
+ from bubogpt.common.registry import registry
18
+
19
+ decord.bridge.set_bridge("torch")
20
+ MAX_INT = registry.get("MAX_INT")
21
+
22
+
23
+ class WrappedConcatDataset(ConcatDataset):
24
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
25
+ super().__init__(datasets)
26
+
27
+ def collater(self, samples):
28
+ # TODO For now only supports datasets with same underlying collater implementations
29
+
30
+ all_keys = set()
31
+ for s in samples:
32
+ all_keys.update(s)
33
+
34
+ shared_keys = all_keys
35
+ for s in samples:
36
+ shared_keys = shared_keys & set(s.keys())
37
+
38
+ samples_shared_keys = []
39
+ for s in samples:
40
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
41
+
42
+ return self.datasets[0].collater(samples_shared_keys)
43
+
44
+
45
+ class WrappedChainDataset(wds.DataPipeline):
46
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
47
+
48
+ This class is useful to assemble different existing dataset streams. The
49
+ chaining operation is done on-the-fly, so concatenating large-scale
50
+ datasets with this class will be efficient.
51
+
52
+ Args:
53
+ datasets (iterable of IterableDataset): datasets to be chained together
54
+ """
55
+
56
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
57
+ super().__init__()
58
+ self.datasets = datasets
59
+ self.prob = []
60
+ self.names = []
61
+ for dataset in self.datasets:
62
+ if hasattr(dataset, 'name'):
63
+ self.names.append(dataset.name)
64
+ else:
65
+ self.names.append('Unknown')
66
+ if hasattr(dataset, 'sample_ratio'):
67
+ self.prob.append(dataset.sample_ratio)
68
+ else:
69
+ self.prob.append(1)
70
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
71
+
72
+ def __iter__(self):
73
+ datastreams = [iter(dataset) for dataset in self.datasets]
74
+ while True:
75
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
76
+ yield next(select_datastream)
77
+
78
+
79
+ def apply_to_sample(f, sample):
80
+ if len(sample) == 0:
81
+ return {}
82
+
83
+ def _apply(x):
84
+ if torch.is_tensor(x):
85
+ return f(x)
86
+ elif isinstance(x, dict):
87
+ return {key: _apply(value) for key, value in x.items()}
88
+ elif isinstance(x, list):
89
+ return [_apply(x) for x in x]
90
+ else:
91
+ return x
92
+
93
+ return _apply(sample)
94
+
95
+
96
+ def move_to_cuda(sample):
97
+ def _move_to_cuda(tensor):
98
+ return tensor.cuda()
99
+
100
+ return apply_to_sample(_move_to_cuda, sample)
101
+
102
+
103
+ def move_to_cpu(sample):
104
+ def _move_to_cpu(tensor):
105
+ return tensor.cpu()
106
+
107
+ return apply_to_sample(_move_to_cpu, sample)
108
+
109
+
110
+ def prepare_sample(samples, cuda_enabled=True):
111
+ if cuda_enabled:
112
+ samples = move_to_cuda(samples)
113
+
114
+ # TODO fp16 support
115
+
116
+ return samples
117
+
118
+
119
+ def reorg_datasets_by_split(datasets):
120
+ """
121
+ Organizes datasets by split.
122
+
123
+ Args:
124
+ datasets: dict of torch.utils.data.Dataset objects by name.
125
+
126
+ Returns:
127
+ Dict of datasets by split {split_name: List[Datasets]}.
128
+ """
129
+ # if len(datasets) == 1:
130
+ # return datasets[list(datasets.keys())[0]]
131
+ # else:
132
+ reorg_datasets = dict()
133
+
134
+ # reorganize by split
135
+ for _, dataset in datasets.items():
136
+ for split_name, dataset_split in dataset.items():
137
+ if split_name not in reorg_datasets:
138
+ reorg_datasets[split_name] = [dataset_split]
139
+ else:
140
+ reorg_datasets[split_name].append(dataset_split)
141
+
142
+ return reorg_datasets
143
+
144
+
145
+ def concat_datasets(datasets):
146
+ """
147
+ Concatenates multiple datasets into a single dataset.
148
+
149
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
150
+ generic IterableDataset because it requires creating separate samplers.
151
+
152
+ Now only supports conctenating training datasets and assuming validation and testing
153
+ have only a single dataset. This is because metrics should not be computed on the concatenated
154
+ datasets.
155
+
156
+ Args:
157
+ datasets: dict of torch.utils.data.Dataset objects by split.
158
+
159
+ Returns:
160
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
161
+ "val" and "test" remain the same.
162
+
163
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
164
+ a tuple, where the first element is a concatenated map-style dataset and the second
165
+ element is a chained DataPipeline dataset.
166
+
167
+ """
168
+ # concatenate datasets in the same split
169
+ for split_name in datasets:
170
+ if split_name != "train":
171
+ assert (
172
+ len(datasets[split_name]) == 1
173
+ ), "Do not support multiple {} datasets.".format(split_name)
174
+ datasets[split_name] = datasets[split_name][0]
175
+ else:
176
+ iterable_datasets, map_datasets = [], []
177
+ for dataset in datasets[split_name]:
178
+ if isinstance(dataset, wds.DataPipeline):
179
+ logging.info(
180
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
181
+ dataset
182
+ )
183
+ )
184
+ iterable_datasets.append(dataset)
185
+ elif isinstance(dataset, IterableDataset):
186
+ raise NotImplementedError(
187
+ "Do not support concatenation of generic IterableDataset."
188
+ )
189
+ else:
190
+ map_datasets.append(dataset)
191
+
192
+ # if len(iterable_datasets) > 0:
193
+ # concatenate map-style datasets and iterable-style datasets separately
194
+ if len(iterable_datasets) > 1:
195
+ chained_datasets = (
196
+ WrappedChainDataset(iterable_datasets)
197
+ )
198
+ elif len(iterable_datasets) == 1:
199
+ chained_datasets = iterable_datasets[0]
200
+ else:
201
+ chained_datasets = None
202
+
203
+ concat_datasets = (
204
+ WrappedConcatDataset(map_datasets) if len(map_datasets) > 0 else None
205
+ )
206
+
207
+ train_datasets = concat_datasets, chained_datasets
208
+ train_datasets = tuple([x for x in train_datasets if x is not None])
209
+ train_datasets = (
210
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
211
+ )
212
+
213
+ datasets[split_name] = train_datasets
214
+
215
+ return datasets
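The two helpers above are easiest to see end to end on toy inputs. Below is a minimal sketch, assuming two map-style datasets standing in for the real builder outputs; the split names and sizes are illustrative only.

import torch
from torch.utils.data import TensorDataset
from bubogpt.datasets.data_utils import reorg_datasets_by_split, concat_datasets

# Two "builder outputs", each a dict of datasets keyed by split name.
toy_a = {"train": TensorDataset(torch.arange(4))}
toy_b = {"train": TensorDataset(torch.arange(6))}

splits = reorg_datasets_by_split({"toy_a": toy_a, "toy_b": toy_b})
# -> {"train": [toy_a_train, toy_b_train]}
merged = concat_datasets(splits)
# -> {"train": concatenation of the two map-style datasets}
print(len(merged["train"]))  # 10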
bubogpt/datasets/datasets/__init__.py ADDED
File without changes
bubogpt/datasets/datasets/audio_caption/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from bubogpt.datasets.datasets.audio_caption.audio_caption_datasets import GenericAudioDataset, AudioCaptionDataset
bubogpt/datasets/datasets/audio_caption/audio_caption_datasets.py ADDED
@@ -0,0 +1,70 @@
1
+ import json
2
+ import os
3
+ import torchaudio
4
+ import random
5
+ import tempfile
6
+
7
+ from torch.utils.data import Dataset, default_collate
8
+ import webdataset as wds
9
+ from bubogpt.datasets.datasets.base_dataset import BaseDualDataset
10
+
11
+
12
+ class GenericAudioDataset(BaseDualDataset):
13
+ def __init__(self, audio_processor, text_processor, location):
14
+ super().__init__(x_processor=audio_processor, text_processor=text_processor)
15
+
16
+ self.inner_dataset = wds.DataPipeline(
17
+ wds.ResampledShards(location),
18
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
19
+ wds.shuffle(1000, handler=wds.warn_and_continue),
20
+ wds.decode(wds.torch_audio, handler=wds.warn_and_continue),
21
+ wds.to_tuple("flac", "json", handler=wds.warn_and_continue),
22
+ wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
23
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
24
+ )
25
+
26
+ def to_dict(self, sample):
27
+ return {
28
+ "audio": sample[0],
29
+ # [clips_per_video, channel, mel_bins, time_steps]
30
+ "text_input": self.text_processor(sample[1]["caption"]),
31
+ }
32
+
33
+
34
+ class AudioCaptionDataset(BaseDualDataset):
35
+ def __init__(self, audio_processor, text_processor, audio_root, ann_paths):
36
+ """
37
+ vis_root (string): Root directory of images (e.g. coco/images/)
38
+ ann_root (string): directory to store the annotation file
39
+ """
40
+ super().__init__(audio_processor, text_processor, audio_root, ann_paths)
41
+
42
+ self.audio_ids = {}
43
+ n = 0
44
+ for ann in self.annotation:
45
+ audio_id = ann["audio_id"]
46
+ if audio_id not in self.audio_ids.keys():
47
+ self.audio_ids[audio_id] = n
48
+ n += 1
49
+
50
+ with open("prompts/alignment_audio.txt") as f:
51
+ self.prompts = f.read().splitlines()
52
+ print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts))
53
+
54
+ def __getitem__(self, index):
55
+
56
+ # TODO this assumes image input, not general enough
57
+ ann = self.annotation[index]
58
+
59
+ audio_file = ann["audio_id"] + ".wav"
60
+ audio_path = os.path.join(self.x_root, audio_file)
61
+ audio = torchaudio.load(audio_path)
62
+ audio = self.x_processor(audio)
63
+ caption = self.text_processor(ann["caption"])
64
+
65
+ return {
66
+ "audio": audio,
67
+ "text_input": caption,
68
+ # "audio_id": self.audio_ids[ann["audio_id"]],
69
+ "prompt": random.choice(self.prompts),
70
+ }
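As a rough sketch of how this dataset is consumed (the real instantiation happens in the registered builders driven by the YAML configs), something like the following should work. The paths are placeholders and the identity processors stand in for the real audio/text processors; note the class also reads prompts/alignment_audio.txt relative to the working directory.

from bubogpt.datasets.datasets.audio_caption import AudioCaptionDataset

identity = lambda x: x  # placeholder for the real audio / text processors

dataset = AudioCaptionDataset(
    audio_processor=identity,
    text_processor=identity,
    audio_root="/path/to/audio",            # assumed layout: <audio_id>.wav files
    ann_paths=["/path/to/train.json"],      # assumed: {"annotations": [{"audio_id", "caption"}, ...]}
)
sample = dataset[0]  # {"audio": ..., "text_input": ..., "prompt": ...}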
bubogpt/datasets/datasets/audio_image/__init__.py ADDED
File without changes
bubogpt/datasets/datasets/audio_image/audio_image_datasets.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import random
3
+ import json
4
+ import torchaudio
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+ from bubogpt.datasets.datasets.base_dataset import BaseMultiSourceDataset
8
+ import webdataset as wds
9
+
10
+
11
+ class AudioLocalizationDataset(BaseMultiSourceDataset):
12
+ def __init__(self, processors, roots, ann_paths):
13
+ super().__init__(processors, roots, ann_paths)
14
+
15
+ with open("prompts/alignment_audio_image_region.txt") as f:
16
+ self.prompts = f.read().splitlines()
17
+ print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts))
18
+
19
+ def __getitem__(self, index):
20
+ ann = self.annotation[index]
21
+
22
+ audio_file = ann["audio_id"] + ".wav"
23
+ image_file = ann["image_id"] + ".jpg"
24
+ audio_path = os.path.join(self.roots["audio"], audio_file)
25
+ image_path = os.path.join(self.roots["image"], image_file)
26
+
27
+ audio = torchaudio.load(audio_path)
28
+ image = Image.open(image_path).convert("RGB")
29
+ audio = self.processors["audio"](audio)
30
+ image = self.processors["image"](image)
31
+ caption = self.processors["text"](ann["caption"])
32
+
33
+ return {
34
+ "audio": audio,
35
+ "vision": image,
36
+ "text_input": caption,
37
+ "prompt": random.choice(self.prompts),
38
+ }
39
+
40
+
41
+ class AudioImageNegDataset(Dataset):
42
+ def __init__(self, processors, roots, ann_paths) -> None:
43
+ super().__init__()
44
+
45
+ self.processors = processors
46
+ self.roots = roots
47
+ self.ann_paths = ann_paths
48
+
49
+ self.img_annotation = []
50
+ for ann_path in ann_paths['image']:
51
+ self.img_annotation.extend(json.load(open(ann_path, "r"))['annotations'])
52
+
53
+ self.aud_annotation = []
54
+ for ann_path in ann_paths['audio']:
55
+ self.aud_annotation.extend(json.load(open(ann_path, "r"))['annotations'])
56
+
57
+ with open("prompts/alignment_audio_image_neg.txt") as f:
58
+ self.prompts = f.read().splitlines()
59
+ print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts))
60
+
61
+ def __len__(self):
62
+ return len(self.img_annotation)
63
+
64
+ def __getitem__(self, index):
65
+
66
+ img_ann = self.img_annotation[index]
67
+
68
+ img_file = '{}.jpg'.format(img_ann["image_id"])
69
+ image_path = os.path.join(self.roots['image'], img_file)
70
+ image = Image.open(image_path).convert("RGB")
71
+ image = self.processors['image'](image)
72
+
73
+ aud_index = random.randint(0, len(self.aud_annotation)-1)
74
+ aud_ann = self.aud_annotation[aud_index]
75
+
76
+ audio_file = aud_ann["audio_id"] + ".wav"
77
+ audio_path = os.path.join(self.roots['audio'], audio_file)
78
+ audio = torchaudio.load(audio_path)
79
+ audio = self.processors['audio'](audio)
80
+ prompt = random.choice(self.prompts)
81
+ if "related" in prompt:
82
+ prefix = "They seem unrelated. "
83
+ else:
84
+ prefix = "They seem unrelated. " if random.random() < 0.5 else ""
85
+ caption = self.processors['text'](prefix + img_ann["caption"] + aud_ann["caption"])
86
+
87
+ return {
88
+ 'audio': audio,
89
+ 'vision': image,
90
+ 'text_input': caption,
91
+ 'prompt': prompt,
92
+ }
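The caption logic in AudioImageNegDataset is the subtle part: each image is paired with a randomly drawn, unrelated audio clip, and the combined caption sometimes starts with an explicit disclaimer. A standalone sketch of just that decision (the inputs are made up):

import random

def build_neg_caption(img_caption, aud_caption, prompt):
    # Prompts that ask whether the two inputs are "related" always get the
    # disclaimer; otherwise it is prepended with 50% probability.
    if "related" in prompt:
        prefix = "They seem unrelated. "
    else:
        prefix = "They seem unrelated. " if random.random() < 0.5 else ""
    return prefix + img_caption + aud_caption

print(build_neg_caption("A dog runs on a beach. ", "Rain taps on a tin roof.",
                        "Are the image and the audio related?"))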
bubogpt/datasets/datasets/base_dataset.py ADDED
@@ -0,0 +1,79 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import json
9
+ from typing import Iterable
10
+
11
+ from torch.utils.data import Dataset
12
+ from torch.utils.data.dataloader import default_collate
13
+
14
+
15
+ class BaseDualDataset(Dataset):
16
+ def __init__(
17
+ self, x_processor=None, text_processor=None, x_root=None, ann_paths=[]
18
+ ):
19
+ """
20
+ x_root (string): Root directory of data in modality X (e.g. coco/images/, etc.)
21
+ ann_root (string): directory to store the annotation file
22
+ """
23
+ self.x_root = x_root
24
+
25
+ self.annotation = []
26
+ for ann_path in ann_paths:
27
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28
+
29
+ self.x_processor = x_processor
30
+ self.text_processor = text_processor
31
+
32
+ self._add_instance_ids()
33
+
34
+ def __len__(self):
35
+ return len(self.annotation)
36
+
37
+ def collater(self, samples):
38
+ return default_collate(samples)
39
+
40
+ def set_processors(self, x_processor, text_processor):
41
+ self.x_processor = x_processor
42
+ self.text_processor = text_processor
43
+
44
+ def _add_instance_ids(self, key="instance_id"):
45
+ for idx, ann in enumerate(self.annotation):
46
+ ann[key] = str(idx)
47
+
48
+
49
+ class BaseMultiSourceDataset(Dataset):
50
+ def __init__(
51
+ self, processors=None, roots=None, ann_paths=[]
52
+ ):
53
+ """
54
+ processors (Dict[str, Processor]): The processors of different modalities.
55
+ roots (Dict[str, str]): The roots of different modalities, Deprecated
56
+ ann_root (string): directory to store the annotation file
57
+ """
58
+ self.roots = roots
59
+
60
+ self.annotation = []
61
+ for ann_path in ann_paths:
62
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
63
+
64
+ self.processors = processors
65
+
66
+ self._add_instance_ids()
67
+
68
+ def __len__(self):
69
+ return len(self.annotation)
70
+
71
+ def collater(self, samples):
72
+ return default_collate(samples)
73
+
74
+ def set_processors(self, processors):
75
+ self.processors = processors
76
+
77
+ def _add_instance_ids(self, key="instance_id"):
78
+ for idx, ann in enumerate(self.annotation):
79
+ ann[key] = str(idx)
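A minimal sketch of what a BaseDualDataset subclass has to provide, namely __getitem__; the annotation keys used below ("image", "caption") are assumptions for illustration and mirror the image datasets later in this commit.

import os
from PIL import Image
from bubogpt.datasets.datasets.base_dataset import BaseDualDataset

class ToyImageTextDataset(BaseDualDataset):
    def __getitem__(self, index):
        ann = self.annotation[index]  # loaded from ann_paths by the base class
        image = Image.open(os.path.join(self.x_root, ann["image"])).convert("RGB")
        return {
            "vision": self.x_processor(image),
            "text_input": self.text_processor(ann["caption"]),
        }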
bubogpt/datasets/datasets/dataloader_utils.py ADDED
@@ -0,0 +1,162 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import time
9
+ import random
10
+ import torch
11
+ from bubogpt.datasets.data_utils import move_to_cuda
12
+ from torch.utils.data import DataLoader
13
+
14
+
15
+ class MultiIterLoader:
16
+ """
17
+ A simple wrapper for iterating over multiple iterators.
18
+
19
+ Args:
20
+ loaders (List[Loader]): List of Iterator loaders.
21
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
+ """
23
+
24
+ def __init__(self, loaders, ratios=None):
25
+ # assert all loaders has __next__ method
26
+ for loader in loaders:
27
+ assert hasattr(
28
+ loader, "__next__"
29
+ ), "Loader {} has no __next__ method.".format(loader)
30
+
31
+ if ratios is None:
32
+ ratios = [1.0] * len(loaders)
33
+ else:
34
+ assert len(ratios) == len(loaders)
35
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36
+
37
+ self.loaders = loaders
38
+ self.ratios = ratios
39
+
40
+ def __next__(self):
41
+ # random sample from each loader by ratio
42
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43
+ return next(self.loaders[loader_idx])
44
+
45
+
46
+ class PrefetchLoader(object):
47
+ """
48
+ Modified from https://github.com/ChenRocks/UNITER.
49
+
50
+ overlap compute and cuda data transfer
51
+ (copied and then modified from nvidia apex)
52
+ """
53
+
54
+ def __init__(self, loader):
55
+ self.loader = loader
56
+ self.stream = torch.cuda.Stream()
57
+
58
+ def __iter__(self):
59
+ loader_it = iter(self.loader)
60
+ self.preload(loader_it)
61
+ batch = self.next(loader_it)
62
+ while batch is not None:
63
+ is_tuple = isinstance(batch, tuple)
64
+ if is_tuple:
65
+ task, batch = batch
66
+
67
+ if is_tuple:
68
+ yield task, batch
69
+ else:
70
+ yield batch
71
+ batch = self.next(loader_it)
72
+
73
+ def __len__(self):
74
+ return len(self.loader)
75
+
76
+ def preload(self, it):
77
+ try:
78
+ self.batch = next(it)
79
+ except StopIteration:
80
+ self.batch = None
81
+ return
82
+ # if record_stream() doesn't work, another option is to make sure
83
+ # device inputs are created on the main stream.
84
+ # self.next_input_gpu = torch.empty_like(self.next_input,
85
+ # device='cuda')
86
+ # self.next_target_gpu = torch.empty_like(self.next_target,
87
+ # device='cuda')
88
+ # Need to make sure the memory allocated for next_* is not still in use
89
+ # by the main stream at the time we start copying to next_*:
90
+ # self.stream.wait_stream(torch.cuda.current_stream())
91
+ with torch.cuda.stream(self.stream):
92
+ self.batch = move_to_cuda(self.batch)
93
+ # more code for the alternative if record_stream() doesn't work:
94
+ # copy_ will record the use of the pinned source tensor in this
95
+ # side stream.
96
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98
+ # self.next_input = self.next_input_gpu
99
+ # self.next_target = self.next_target_gpu
100
+
101
+ def next(self, it):
102
+ torch.cuda.current_stream().wait_stream(self.stream)
103
+ batch = self.batch
104
+ if batch is not None:
105
+ record_cuda_stream(batch)
106
+ self.preload(it)
107
+ return batch
108
+
109
+ def __getattr__(self, name):
110
+ method = self.loader.__getattribute__(name)
111
+ return method
112
+
113
+
114
+ def record_cuda_stream(batch):
115
+ if isinstance(batch, torch.Tensor):
116
+ batch.record_stream(torch.cuda.current_stream())
117
+ elif isinstance(batch, list) or isinstance(batch, tuple):
118
+ for t in batch:
119
+ record_cuda_stream(t)
120
+ elif isinstance(batch, dict):
121
+ for t in batch.values():
122
+ record_cuda_stream(t)
123
+ else:
124
+ pass
125
+
126
+
127
+ class IterLoader:
128
+ """
129
+ A wrapper to convert DataLoader as an infinite iterator.
130
+
131
+ Modified from:
132
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133
+ """
134
+
135
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136
+ self._dataloader = dataloader
137
+ self.iter_loader = iter(self._dataloader)
138
+ self._use_distributed = use_distributed
139
+ self._epoch = 0
140
+
141
+ @property
142
+ def epoch(self) -> int:
143
+ return self._epoch
144
+
145
+ def __next__(self):
146
+ try:
147
+ data = next(self.iter_loader)
148
+ except StopIteration:
149
+ self._epoch += 1
150
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151
+ self._dataloader.sampler.set_epoch(self._epoch)
152
+ time.sleep(2) # Prevent possible deadlock during epoch transition
153
+ self.iter_loader = iter(self._dataloader)
154
+ data = next(self.iter_loader)
155
+
156
+ return data
157
+
158
+ def __iter__(self):
159
+ return self
160
+
161
+ def __len__(self):
162
+ return len(self._dataloader)
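How the three wrappers above are typically stacked (a sketch, not the project's runner code): each DataLoader is optionally wrapped in PrefetchLoader, turned into an infinite IterLoader, and the IterLoaders are then mixed by MultiIterLoader according to sampling ratios. PrefetchLoader allocates a CUDA stream, so it is only used when a GPU is available; the toy datasets and ratios are placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset
from bubogpt.datasets.datasets.dataloader_utils import (
    IterLoader, MultiIterLoader, PrefetchLoader,
)

loader_a = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)
loader_b = DataLoader(TensorDataset(torch.arange(16).float()), batch_size=4)

def wrap(loader):
    if torch.cuda.is_available():
        loader = PrefetchLoader(loader)  # overlap host-to-device copies with compute
    return IterLoader(loader)            # never raises StopIteration; tracks epochs

train_loader = MultiIterLoader([wrap(loader_a), wrap(loader_b)], ratios=[2.0, 1.0])
batch = next(train_loader)               # ~2/3 of draws come from loader_a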
bubogpt/datasets/datasets/image_caption/__init__.py ADDED
File without changes
bubogpt/datasets/datasets/image_caption/cc_sbu_dataset.py ADDED
@@ -0,0 +1,68 @@
1
+ import os
2
+ from PIL import Image
3
+ import webdataset as wds
4
+ from bubogpt.datasets.datasets.base_dataset import BaseDualDataset
5
+ from bubogpt.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset
6
+
7
+
8
+ class CCSBUDataset(BaseDualDataset):
9
+ def __init__(self, vision_processor, text_processor, location):
10
+ super().__init__(x_processor=vision_processor, text_processor=text_processor)
11
+
12
+ self.inner_dataset = wds.DataPipeline(
13
+ wds.ResampledShards(location),
14
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
15
+ wds.shuffle(1000, handler=wds.warn_and_continue),
16
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
17
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18
+ wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
19
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
20
+ )
21
+
22
+ def to_dict(self, sample):
23
+ return {
24
+ "vision": sample[0],
25
+ "text_input": self.text_processor(sample[1]["caption"]),
26
+ }
27
+
28
+
29
+ class CCSBUAlignDataset(ImageCaptionDataset):
30
+
31
+ def __getitem__(self, index):
32
+
33
+ # TODO this assumes image input, not general enough
34
+ ann = self.annotation[index]
35
+
36
+ img_file = '{}.jpg'.format(ann["image_id"])
37
+ image_path = os.path.join(self.x_root, img_file)
38
+ image = Image.open(image_path).convert("RGB")
39
+
40
+ image = self.x_processor(image)
41
+ caption = ann["caption"]
42
+
43
+ return {
44
+ "vision": image,
45
+ "text_input": caption,
46
+ "image_id": self.img_ids[ann["image_id"]],
47
+ }
48
+
49
+
50
+ class CCDataset(BaseDualDataset):
51
+ def __init__(self, vis_processor, text_processor, location):
52
+ super().__init__(x_processor=vis_processor, text_processor=text_processor)
53
+
54
+ self.inner_dataset = wds.DataPipeline(
55
+ wds.ResampledShards(location),
56
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
57
+ wds.shuffle(1000, handler=wds.warn_and_continue),
58
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
59
+ wds.to_tuple("jpg", "txt", handler=wds.warn_and_continue),
60
+ wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
61
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
62
+ )
63
+
64
+ def to_dict(self, sample):
65
+ return {
66
+ "vision": sample[0],
67
+ "text_input": sample[1],
68
+ }
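CCSBUDataset, CCDataset, LaionDataset (below) and GenericAudioDataset all follow the same pattern: the wrapper class only holds processors, and the actual stream lives in .inner_dataset, a webdataset DataPipeline over tar shards. A hedged sketch of reading from it; the shard pattern and identity processors are placeholders.

from bubogpt.datasets.datasets.image_caption.cc_sbu_dataset import CCSBUDataset

dataset = CCSBUDataset(
    vision_processor=lambda x: x,                   # placeholder processors
    text_processor=lambda x: x,
    location="/path/to/cc_sbu/{00000..00010}.tar",  # assumed shard pattern
)
sample = next(iter(dataset.inner_dataset))          # {"vision": ..., "text_input": ...}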
bubogpt/datasets/datasets/image_caption/image_caption_datasets.py ADDED
@@ -0,0 +1,73 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+
10
+ from bubogpt.datasets.datasets.base_dataset import BaseDualDataset
11
+ from PIL import Image
12
+
13
+ from bubogpt.datasets.datasets.mixins.mixins import __ImageDisplMixin
14
+
15
+
16
+ class ImageCaptionDataset(BaseDualDataset, __ImageDisplMixin):
17
+ def __init__(self, vision_processor, text_processor, vis_root, ann_paths):
18
+ """
19
+ vis_root (string): Root directory of images (e.g. coco/images/)
20
+ ann_root (string): directory to store the annotation file
21
+ """
22
+ super().__init__(vision_processor, text_processor, vis_root, ann_paths)
23
+
24
+ self.img_ids = {}
25
+ n = 0
26
+ for ann in self.annotation:
27
+ img_id = ann["image_id"]
28
+ if img_id not in self.img_ids.keys():
29
+ self.img_ids[img_id] = n
30
+ n += 1
31
+
32
+ def __getitem__(self, index):
33
+
34
+ # TODO this assumes image input, not general enough
35
+ ann = self.annotation[index]
36
+
37
+ img_file = '{:0>12}.jpg'.format(ann["image_id"])
38
+ image_path = os.path.join(self.x_root, img_file)
39
+ image = Image.open(image_path).convert("RGB")
40
+
41
+ image = self.x_processor(image)
42
+ caption = self.text_processor(ann["caption"])
43
+
44
+ return {
45
+ "vision": image,
46
+ "text_input": caption,
47
+ "image_id": self.img_ids[ann["image_id"]],
48
+ }
49
+
50
+
51
+ class CaptionEvalDataset(BaseDualDataset, __ImageDisplMixin):
52
+ def __init__(self, vision_processor, text_processor, x_root, ann_paths):
53
+ """
54
+ vis_root (string): Root directory of images (e.g. coco/images/)
55
+ ann_root (string): directory to store the annotation file
56
+ split (string): val or test
57
+ """
58
+ super().__init__(vision_processor, text_processor, x_root, ann_paths)
59
+
60
+ def __getitem__(self, index):
61
+
62
+ ann = self.annotation[index]
63
+
64
+ image_path = os.path.join(self.x_root, ann["image"])
65
+ image = Image.open(image_path).convert("RGB")
66
+
67
+ image = self.x_processor(image)
68
+
69
+ return {
70
+ "vision": image,
71
+ "image_id": ann["image_id"],
72
+ "instance_id": ann["instance_id"],
73
+ }
bubogpt/datasets/datasets/image_caption/laion_dataset.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import webdataset as wds
9
+ from bubogpt.datasets.datasets.base_dataset import BaseDualDataset
10
+
11
+
12
+ class LaionDataset(BaseDualDataset):
13
+ def __init__(self, vision_processor, text_processor, location):
14
+ super().__init__(x_processor=vision_processor, text_processor=text_processor)
15
+
16
+ self.inner_dataset = wds.DataPipeline(
17
+ wds.ResampledShards(location),
18
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
19
+ wds.shuffle(1000, handler=wds.warn_and_continue),
20
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
21
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22
+ wds.map_tuple(self.x_processor, handler=wds.warn_and_continue),
23
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
24
+ )
25
+
26
+ def to_dict(self, sample):
27
+ return {
28
+ "vision": sample[0],
29
+ "text_input": self.text_processor(sample[1]["caption"]),
30
+ }
31
+
bubogpt/datasets/datasets/image_caption/llava_dataset.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import json
3
+ import random
4
+ from PIL import Image
5
+ import webdataset as wds
6
+ from bubogpt.datasets.datasets.base_dataset import BaseDualDataset
7
+ from bubogpt.datasets.datasets.image_caption.image_caption_datasets import ImageCaptionDataset
8
+
9
+
10
+ class LlavaInstruct150Dataset(BaseDualDataset):
11
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
12
+ super().__init__(x_processor=vis_processor, text_processor=text_processor)
13
+ self.vis_root = vis_root
14
+ self.ann_paths = ann_paths
15
+
16
+ self.data_list = data_list = []
17
+ # for split in ["complex_reasoning_77k", "conversation_58k", "detail_23k"]:
18
+ # with open(os.path.join(vis_root, f'annotations/{split}.json'), 'r') as f:
19
+ # data_list.extend(json.load(f))
20
+ for ann_path in ann_paths:
21
+ with open(ann_path) as f:
22
+ data_list.extend(json.load(f))
23
+
24
+ self.annotation = []
25
+ for item in data_list:
26
+ image_id = item['id']
27
+ conversations = item['conversations']
28
+ for conv_id in range(len(conversations) //2 ):
29
+ question = conversations[2*conv_id]['value']
30
+ answer = conversations[2 * conv_id+1]['value']
31
+ self.annotation.append({'image_id':image_id, 'question':question, 'answer':answer})
32
+
33
+ # llava prompts
34
+ self.prompts = [
35
+ "<Vision><ModalityHere></Vision> <question>",
36
+ "<Vision><ModalityHere></Vision> Quesion: <question>",
37
+ "<Vision><ModalityHere></Vision> <question> A detail answer to the question is",
38
+ "<Vision><ModalityHere></Vision> Quesion: <question> detail answer:",
39
+ "<Vision><ModalityHere></Vision> Based on the image, respond to this question with a detail answer: <question> Answer:",
40
+ "<Vision><ModalityHere></Vision> Use the provided image to answer the question: <question>",
41
+ "<Vision><ModalityHere></Vision> What is the answer to the following question? <question>",
42
+ ]
43
+ print(f"==> {self.__class__.__name__} using prompts: ", "\n " + "\n ".join(self.prompts))
44
+ # self.prompt_template = '###Human: {} ###Assistant: '
45
+
46
+ def __getitem__(self, index):
47
+ ann = self.annotation[index]
48
+
49
+ image_path = os.path.join(self.vis_root, "train2014/COCO_train2014_{:0>12}.jpg".format(ann["image_id"]))
50
+ image = Image.open(image_path).convert("RGB")
51
+ image = self.x_processor(image)
52
+
53
+ question = ann['question']
54
+ question = question.replace('<image>\n', '').replace('\n<image>', '')
55
+ # prompt = self.prompt_template.format(random.choice(self.prompts))
56
+ prompt = random.choice(self.prompts)
57
+ prompt = prompt.replace('<question>', question)
58
+
59
+ return {
60
+ "vision": image,
61
+ "prompt": prompt,
62
+ "text_input": ann["answer"],
63
+ }
64
+
65
+ def check_existence(self):
66
+ from tqdm import tqdm
67
+ for i in tqdm(range(len(self.data_list))):
68
+ image_id = self.data_list[i]["id"]
69
+ image_path = os.path.join(self.vis_root, "train2014/COCO_train2014_{:0>12}.jpg".format(image_id))
70
+ if not os.path.exists(image_path):
71
+ print(f'Image does not exist: {image_path}')
72
+ print("Checking sucessful!")
bubogpt/datasets/datasets/mixins/__init__.py ADDED
File without changes
bubogpt/datasets/datasets/mixins/mixins.py ADDED
@@ -0,0 +1,30 @@
1
+ from collections import OrderedDict
2
+
3
+
4
+ class __ImageDisplMixin:
5
+ def displ_item(self, index):
6
+ sample, ann = self.__getitem__(index), self.annotation[index]
7
+
8
+ return OrderedDict(
9
+ {
10
+ "file": ann["image"],
11
+ "caption": ann["caption"],
12
+ "vision": sample["vision"],
13
+ }
14
+ )
15
+
16
+
17
+ class __AudioDisplMixin:
18
+ def displ_item(self, index):
19
+ sample, ann = self.__getitem__(index), self.annotation[index]
20
+
21
+ # TODO: Finish the Audio Display Mixin
22
+ '''
23
+ return OrderedDict(
24
+ {
25
+ }
26
+ )
27
+ '''
28
+
29
+ raise NotImplementedError
30
+
bubogpt/models/Qformer.py ADDED
@@ -0,0 +1,1216 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warn(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ if query_embeds is not None:
1192
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1193
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
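The `(1.0 - extended_attention_mask) * -10000.0` step above turns a 0/1 keep-mask into an additive bias on the raw attention scores. A minimal sketch of how the decoder branch combines the causal and padding masks, using only plain PyTorch with made-up shapes (illustrative only, not part of this commit):

import torch

# Hypothetical toy input: batch of 1, sequence of 4 tokens, last token is padding.
batch_size, seq_length = 1, 4
attention_mask = torch.tensor([[1, 1, 1, 0]]).float()

# Lower-triangular causal mask: position i may attend to positions j <= i.
seq_ids = torch.arange(seq_length)
causal_mask = (
    seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
).float()

# Broadcast to [batch, heads, query, key] and fold in the padding mask.
extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]

# 0.0 where attention is allowed, -10000.0 where it is masked; this bias is
# added to the attention scores before the softmax, effectively removing the
# masked positions.
bias = (1.0 - extended) * -10000.0
print(bias[0, 0])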
bubogpt/models/__init__.py ADDED
@@ -0,0 +1,200 @@
+ """
+  Copyright (c) 2022, salesforce.com, inc.
+  All rights reserved.
+  SPDX-License-Identifier: BSD-3-Clause
+  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import logging
+ import torch
+ from omegaconf import OmegaConf
+
+ from bubogpt.common.registry import registry
+ from bubogpt.models.base_model import BaseModel
+ from bubogpt.models.blip2 import Blip2Base
+ from bubogpt.processors.base_processor import BaseProcessor
+ from bubogpt.models.mm_gpt4 import MMGPT4
+
+
+ __all__ = [
+     "load_model",
+     "BaseModel",
+     "Blip2Base",
+     "MMGPT4"
+ ]
+
+
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
+     """
+     Load supported models.
+
+     To list all available models and types in registry:
+     >>> from bubogpt.models import model_zoo
+     >>> print(model_zoo)
+
+     Args:
+         name (str): name of the model.
+         model_type (str): type of the model.
+         is_eval (bool): whether the model is in eval mode. Default: False.
+         device (str): device to use. Default: "cpu".
+         checkpoint (str): path or URL to the checkpoint. Default: None.
+             Note that the checkpoint is expected to have the same keys in its state_dict as the model.
+
+     Returns:
+         model (torch.nn.Module): model.
+     """
+
+     model = registry.get_model_class(name).from_pretrained(model_type=model_type)
+
+     if checkpoint is not None:
+         model.load_checkpoint(checkpoint)
+
+     if is_eval:
+         model.eval()
+
+     if device == "cpu":
+         model = model.float()
+
+     return model.to(device)
+
+
+ def load_preprocess(config):
+     """
+     Load preprocessor configs and construct preprocessors.
+
+     If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
+
+     Args:
+         config (dict): preprocessor configs.
+
+     Returns:
+         vis_processors (dict): preprocessors for visual inputs.
+         txt_processors (dict): preprocessors for text inputs.
+
+         Key is "train" or "eval" for processors used in training and evaluation respectively.
+     """
+
+     def _build_proc_from_cfg(cfg):
+         return (
+             registry.get_processor_class(cfg.name).from_config(cfg)
+             if cfg is not None
+             else BaseProcessor()
+         )
+
+     vis_processors = dict()
+     txt_processors = dict()
+
+     vis_proc_cfg = config.get("vis_processor")
+     txt_proc_cfg = config.get("text_processor")
+
+     if vis_proc_cfg is not None:
+         vis_train_cfg = vis_proc_cfg.get("train")
+         vis_eval_cfg = vis_proc_cfg.get("eval")
+     else:
+         vis_train_cfg = None
+         vis_eval_cfg = None
+
+     vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
+     vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
+
+     if txt_proc_cfg is not None:
+         txt_train_cfg = txt_proc_cfg.get("train")
+         txt_eval_cfg = txt_proc_cfg.get("eval")
+     else:
+         txt_train_cfg = None
+         txt_eval_cfg = None
+
+     txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
+     txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
+
+     return vis_processors, txt_processors
+
+
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
+     """
+     Load model and its related preprocessors.
+
+     List all available models and types in registry:
+     >>> from bubogpt.models import model_zoo
+     >>> print(model_zoo)
+
+     Args:
+         name (str): name of the model.
+         model_type (str): type of the model.
+         is_eval (bool): whether the model is in eval mode. Default: False.
+         device (str): device to use. Default: "cpu".
+
+     Returns:
+         model (torch.nn.Module): model.
+         vis_processors (dict): preprocessors for visual inputs.
+         txt_processors (dict): preprocessors for text inputs.
+     """
+     model_cls = registry.get_model_class(name)
+
+     # load model
+     model = model_cls.from_pretrained(model_type=model_type)
+
+     if is_eval:
+         model.eval()
+
+     # load preprocess
+     cfg = OmegaConf.load(model_cls.default_config_path(model_type))
+     if cfg is not None:
+         preprocess_cfg = cfg.preprocess
+
+         vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+     else:
+         vis_processors, txt_processors = None, None
+         logging.info(
+             f"""No default preprocess for model {name} ({model_type}).
+                 This can happen if the model is not finetuned on downstream datasets,
+                 or it is not intended for direct use without finetuning.
+             """
+         )
+
+     if device == "cpu" or device == torch.device("cpu"):
+         model = model.float()
+
+     return model.to(device), vis_processors, txt_processors
+
+
+ class ModelZoo:
+     """
+     A utility class to create string representation of available model architectures and types.
+
+     >>> from bubogpt.models import model_zoo
+     >>> # list all available models
+     >>> print(model_zoo)
+     >>> # show total number of models
+     >>> print(len(model_zoo))
+     """
+
+     def __init__(self) -> None:
+         self.model_zoo = {
+             k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
+             for k, v in registry.mapping["model_name_mapping"].items()
+         }
+
+     def __str__(self) -> str:
+         return (
+             "=" * 50
+             + "\n"
+             + f"{'Architectures':<30} {'Types'}\n"
+             + "=" * 50
+             + "\n"
+             + "\n".join(
+                 [
+                     f"{name:<30} {', '.join(types)}"
+                     for name, types in self.model_zoo.items()
+                 ]
+             )
+         )
+
+     def __iter__(self):
+         return iter(self.model_zoo.items())
+
+     def __len__(self):
+         return sum([len(v) for v in self.model_zoo.values()])
+
+
+ model_zoo = ModelZoo()
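For reference, loading a model through this module typically looks like the sketch below. The architecture and type strings are placeholders, not names confirmed by this commit; `print(model_zoo)` lists what is actually registered.

import torch
from bubogpt.models import load_model_and_preprocess, model_zoo

print(model_zoo)  # table of registered architectures and their model types

device = "cuda" if torch.cuda.is_available() else "cpu"
# "mm_gpt4" / "pretrain_vicuna" are placeholder names -- substitute a pair listed by model_zoo.
model, vis_processors, txt_processors = load_model_and_preprocess(
    name="mm_gpt4",
    model_type="pretrain_vicuna",
    is_eval=True,
    device=device,
)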
bubogpt/models/base_model.py ADDED
@@ -0,0 +1,247 @@
+ """
+  Copyright (c) 2022, salesforce.com, inc.
+  All rights reserved.
+  SPDX-License-Identifier: BSD-3-Clause
+  For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ """
+
+ import logging
+ import os
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from bubogpt.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+ from bubogpt.common.utils import get_abs_path, is_url
+ from omegaconf import OmegaConf
+
+
+ class BaseModel(nn.Module):
+     """Base class for models."""
+
+     def __init__(self):
+         super().__init__()
+
+     @property
+     def device(self):
+         return list(self.parameters())[0].device
+
+     def load_checkpoint(self, url_or_filename):
+         """
+         Load from a finetuned checkpoint.
+
+         This should expect no mismatch in the model keys and the checkpoint keys.
+         """
+
+         if is_url(url_or_filename):
+             cached_file = download_cached_file(
+                 url_or_filename, check_hash=False, progress=True
+             )
+             checkpoint = torch.load(cached_file, map_location="cpu")
+         elif os.path.isfile(url_or_filename):
+             checkpoint = torch.load(url_or_filename, map_location="cpu")
+         else:
+             raise RuntimeError("checkpoint url or path is invalid")
+
+         if "model" in checkpoint.keys():
+             state_dict = checkpoint["model"]
+         else:
+             state_dict = checkpoint
+
+         msg = self.load_state_dict(state_dict, strict=False)
+
+         logging.info("Missing keys {}".format(msg.missing_keys))
+         logging.info("load checkpoint from %s" % url_or_filename)
+
+         return msg
+
+     @classmethod
+     def from_pretrained(cls, model_type):
+         """
+         Build a pretrained model from default configuration file, specified by model_type.
+
+         Args:
+             - model_type (str): model type, specifying architecture and checkpoints.
+
+         Returns:
+             - model (nn.Module): pretrained or finetuned model, depending on the configuration.
+         """
+         model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
+         model = cls.from_config(model_cfg)
+
+         return model
+
+     @classmethod
+     def default_config_path(cls, model_type):
+         assert (
+             model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
+         ), "Unknown model type {}".format(model_type)
+         return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
+
+     def load_checkpoint_from_config(self, cfg, **kwargs):
+         """
+         Load checkpoint as specified in the config file.
+
+         If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
+         When loading the pretrained model, each task-specific architecture may define its
+         own load_from_pretrained() method.
+         """
+         load_finetuned = cfg.get("load_finetuned", True)
+         if load_finetuned:
+             finetune_path = cfg.get("finetuned", None)
+             assert (
+                 finetune_path is not None
+             ), "Found load_finetuned is True, but finetune_path is None."
+             self.load_checkpoint(url_or_filename=finetune_path)
+         else:
+             # load pre-trained weights
+             pretrain_path = cfg.get("pretrained", None)
+             assert pretrain_path is not None, "Found load_finetuned is False, but pretrain_path is None."
+             self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
+
+     def before_evaluation(self, **kwargs):
+         pass
+
+     def show_n_params(self, return_str=True):
+         tot = 0
+         for p in self.parameters():
+             w = 1
+             for x in p.shape:
+                 w *= x
+             tot += w
+         if return_str:
+             if tot >= 1e6:
+                 return "{:.1f}M".format(tot / 1e6)
+             else:
+                 return "{:.1f}K".format(tot / 1e3)
+         else:
+             return tot
+
+
+ class BaseEncoder(nn.Module):
+     """
+     Base class for primitive encoders, such as ViT, TimeSformer, etc.
+     """
+
+     def __init__(self):
+         super().__init__()
+
+     def forward_features(self, samples, **kwargs):
+         raise NotImplementedError
+
+     @property
+     def device(self):
+         return list(self.parameters())[0].device
+
+
+ class SharedQueueMixin:
+     @torch.no_grad()
+     def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
+         # gather keys before updating queue
+         image_feats = concat_all_gather(image_feat)
+         text_feats = concat_all_gather(text_feat)
+
+         batch_size = image_feats.shape[0]
+
+         ptr = int(self.queue_ptr)
+         assert self.queue_size % batch_size == 0  # for simplicity
+
+         # replace the keys at ptr (dequeue and enqueue)
+         self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
+         self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
+
+         if idxs is not None:
+             idxs = concat_all_gather(idxs)
+             self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
+
+         ptr = (ptr + batch_size) % self.queue_size  # move pointer
+         self.queue_ptr[0] = ptr
+
+
+ class MomentumDistilationMixin:
+     @torch.no_grad()
+     def copy_params(self):
+         for model_pair in self.model_pairs:
+             for param, param_m in zip(
+                 model_pair[0].parameters(), model_pair[1].parameters()
+             ):
+                 param_m.data.copy_(param.data)  # initialize
+                 param_m.requires_grad = False  # not update by gradient
+
+     @torch.no_grad()
+     def _momentum_update(self):
+         for model_pair in self.model_pairs:
+             for param, param_m in zip(
+                 model_pair[0].parameters(), model_pair[1].parameters()
+             ):
+                 param_m.data = param_m.data * self.momentum + param.data * (
+                     1.0 - self.momentum
+                 )
+
+
+ class GatherLayer(torch.autograd.Function):
+     """
+     Gather tensors from all workers with support for backward propagation:
+     This implementation does not cut the gradients as torch.distributed.all_gather does.
+     """
+
+     @staticmethod
+     def forward(ctx, x):
+         output = [
+             torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
+         ]
+         torch.distributed.all_gather(output, x)
+         return tuple(output)
+
+     @staticmethod
+     def backward(ctx, *grads):
+         all_gradients = torch.stack(grads)
+         torch.distributed.all_reduce(all_gradients)
+         return all_gradients[torch.distributed.get_rank()]
+
+
+ def all_gather_with_grad(tensors):
+     """
+     Performs all_gather operation on the provided tensors.
+     Graph remains connected for backward grad computation.
+     """
+     # Queue the gathered tensors
+     world_size = torch.distributed.get_world_size()
+     # There is no need for reduction in the single-proc case
+     if world_size == 1:
+         return tensors
+
+     # tensor_all = GatherLayer.apply(tensors)
+     tensor_all = GatherLayer.apply(tensors)
+
+     return torch.cat(tensor_all, dim=0)
+
+
+ @torch.no_grad()
+ def concat_all_gather(tensor):
+     """
+     Performs all_gather operation on the provided tensors.
+     *** Warning ***: torch.distributed.all_gather has no gradient.
+     """
+     # if use distributed training
+     if not is_dist_avail_and_initialized():
+         return tensor
+
+     tensors_gather = [
+         torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+     ]
+     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+     output = torch.cat(tensors_gather, dim=0)
+     return output
+
+
+ def tile(x, dim, n_tile):
+     init_dim = x.size(dim)
+     repeat_idx = [1] * x.dim()
+     repeat_idx[dim] = n_tile
+     x = x.repeat(*(repeat_idx))
+     order_index = torch.LongTensor(
+         np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
+     )
+     return torch.index_select(x, dim, order_index.to(x.device))
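A small sketch (not part of the commit) of what `tile` produces: it repeats entries along a dimension in place, so rows come out as 0,0,1,1,... rather than the 0,1,0,1,... a plain `Tensor.repeat` would give, which is the layout typically wanted when expanding a batch for beam search.

import torch
from bubogpt.models.base_model import tile

x = torch.tensor([[1, 2], [3, 4]])

# Repeat each row twice, keeping repeated copies adjacent.
print(tile(x, dim=0, n_tile=2))
# tensor([[1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4]])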